Skip to content
Permalink
Browse files Browse the repository at this point in the history
Use context to store and populate origin
  • Loading branch information
easyCZ authored and corneliusludmann committed Feb 23, 2023
1 parent 76896ac commit 673ab68
Show file tree
Hide file tree
Showing 10 changed files with 191 additions and 20 deletions.
9 changes: 3 additions & 6 deletions components/public-api-server/pkg/auth/context.go
Expand Up @@ -25,8 +25,6 @@ const (
type Token struct {
Type TokenType
Value string
// Only relevant for CookieTokenType
OriginHeader string
}

func NewAccessToken(token string) Token {
Expand All @@ -36,11 +34,10 @@ func NewAccessToken(token string) Token {
}
}

func NewCookieToken(cookie string, origin string) Token {
func NewCookieToken(cookie string) Token {
return Token{
Type: CookieTokenType,
Value: cookie,
OriginHeader: origin,
Type: CookieTokenType,
Value: cookie,
}
}

Expand Down
2 changes: 1 addition & 1 deletion components/public-api-server/pkg/auth/context_test.go
Expand Up @@ -20,7 +20,7 @@ func TestTokenToAndFromContext_AccessToken(t *testing.T) {
}

func TestTokenToAndFromContext_CookieToken(t *testing.T) {
token := NewCookieToken("my_token", "gitpod.io")
token := NewCookieToken("my_token")

extracted, err := TokenFromContext(TokenToContext(context.Background(), token))
require.NoError(t, err)
Expand Down
3 changes: 1 addition & 2 deletions components/public-api-server/pkg/auth/middleware.go
Expand Up @@ -41,9 +41,8 @@ func tokenFromRequest(ctx context.Context, req connect.AnyRequest) (Token, error
}

cookie := req.Header().Get("Cookie")
origin := req.Header().Get("Origin")
if cookie != "" {
return NewCookieToken(cookie, origin), nil
return NewCookieToken(cookie), nil
}

return Token{}, connect.NewError(connect.CodeUnauthenticated, fmt.Errorf("No access token or cookie credentials available on request."))
Expand Down
27 changes: 27 additions & 0 deletions components/public-api-server/pkg/origin/context.go
@@ -0,0 +1,27 @@
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.

package origin

import (
"context"
)

type contextKey int

const (
originContextKey contextKey = iota
)

func ToContext(ctx context.Context, origin string) context.Context {
return context.WithValue(ctx, originContextKey, origin)
}

func FromContext(ctx context.Context) string {
if val, ok := ctx.Value(originContextKey).(string); ok {
return val
}

return ""
}
17 changes: 17 additions & 0 deletions components/public-api-server/pkg/origin/context_test.go
@@ -0,0 +1,17 @@
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.

package origin

import (
"context"
"testing"

"github.com/stretchr/testify/require"
)

func TestToFromContext(t *testing.T) {
require.Equal(t, "some-origin", FromContext(ToContext(context.Background(), "some-origin")), "origin stored on context is extracted")
require.Equal(t, "", FromContext(context.Background()), "context without origin value returns empty")
}
48 changes: 48 additions & 0 deletions components/public-api-server/pkg/origin/middleware.go
@@ -0,0 +1,48 @@
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.

package origin

import (
"context"

"github.com/bufbuild/connect-go"
)

func NewInterceptor() *Interceptor {
return &Interceptor{}
}

type Interceptor struct{}

func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
if req.Spec().IsClient {
req.Header().Add("Origin", FromContext(ctx))
} else {
origin := req.Header().Get("Origin")
ctx = ToContext(ctx, origin)
}

return next(ctx, req)
})
}

func (a *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, s connect.Spec) connect.StreamingClientConn {
conn := next(ctx, s)
conn.RequestHeader().Add("Origin", FromContext(ctx))

return conn
}
}

func (a *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
origin := conn.RequestHeader().Get("Origin")
ctx = ToContext(ctx, origin)

return next(ctx, conn)
}
}
36 changes: 36 additions & 0 deletions components/public-api-server/pkg/origin/middleware_test.go
@@ -0,0 +1,36 @@
// Copyright (c) 2023 Gitpod GmbH. All rights reserved.
// Licensed under the GNU Affero General Public License (AGPL).
// See License.AGPL.txt in the project root for license information.

package origin

import (
"context"
"testing"

"github.com/bufbuild/connect-go"
"github.com/stretchr/testify/require"
)

func TestInterceptor_Unary(t *testing.T) {
requestPaylaod := "request"
origin := "my-origin"

type response struct {
origin string
}

handler := connect.UnaryFunc(func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) {
origin := FromContext(ctx)
return connect.NewResponse(&response{origin: origin}), nil
})

ctx := context.Background()
request := connect.NewRequest(&requestPaylaod)
request.Header().Add("Origin", origin)

interceptor := NewInterceptor()
resp, err := interceptor.WrapUnary(handler)(ctx, request)
require.NoError(t, err)
require.Equal(t, &response{origin: origin}, resp.Any())
}
30 changes: 23 additions & 7 deletions components/public-api-server/pkg/proxy/conn.go
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/gitpod-io/gitpod/common-go/log"
gitpod "github.com/gitpod-io/gitpod/gitpod-protocol"
"github.com/gitpod-io/gitpod/public-api-server/pkg/auth"
"github.com/gitpod-io/gitpod/public-api-server/pkg/origin"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus"

lru "github.com/hashicorp/golang-lru"
Expand Down Expand Up @@ -41,14 +42,14 @@ func (p *NoConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.AP
opts := gitpod.ConnectToServerOpts{
Context: ctx,
Log: logger,
Origin: origin.FromContext(ctx),
}

switch token.Type {
case auth.AccessTokenType:
opts.Token = token.Value
case auth.CookieTokenType:
opts.Cookie = token.Value
opts.Origin = token.OriginHeader
default:
return nil, errors.New("unknown token type")
}
Expand Down Expand Up @@ -83,11 +84,12 @@ func NewConnectionPool(address *url.URL, poolSize int) (*ConnectionPool, error)

return &ConnectionPool{
cache: cache,
connConstructor: func(token auth.Token) (gitpod.APIInterface, error) {
connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) {
opts := gitpod.ConnectToServerOpts{
// We're using Background context as we want the connection to persist beyond the lifecycle of a single request
Context: context.Background(),
Log: log.Log,
Origin: origin.FromContext(ctx),
CloseHandler: func(_ error) {
cache.Remove(token)
connectionPoolSize.Dec()
Expand All @@ -99,7 +101,6 @@ func NewConnectionPool(address *url.URL, poolSize int) (*ConnectionPool, error)
opts.Token = token.Value
case auth.CookieTokenType:
opts.Cookie = token.Value
opts.Origin = token.OriginHeader
default:
return nil, errors.New("unknown token type")
}
Expand All @@ -120,15 +121,23 @@ func NewConnectionPool(address *url.URL, poolSize int) (*ConnectionPool, error)

}

type conenctionPoolCacheKey struct {
token auth.Token
origin string
}

type ConnectionPool struct {
connConstructor func(token auth.Token) (gitpod.APIInterface, error)
connConstructor func(context.Context, auth.Token) (gitpod.APIInterface, error)

// cache stores token to connection mapping
cache *lru.Cache
}

func (p *ConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) {
cached, found := p.cache.Get(token)
origin := origin.FromContext(ctx)

cacheKey := p.cacheKey(token, origin)
cached, found := p.cache.Get(cacheKey)
reportCacheOutcome(found)
if found {
conn, ok := cached.(*gitpod.APIoverJSONRPC)
Expand All @@ -137,17 +146,24 @@ func (p *ConnectionPool) Get(ctx context.Context, token auth.Token) (gitpod.APII
}
}

conn, err := p.connConstructor(token)
conn, err := p.connConstructor(ctx, token)
if err != nil {
return nil, fmt.Errorf("failed to create new connection to server: %w", err)
}

p.cache.Add(token, conn)
p.cache.Add(cacheKey, conn)
connectionPoolSize.Inc()

return conn, nil
}

func (p *ConnectionPool) cacheKey(token auth.Token, origin string) conenctionPoolCacheKey {
return conenctionPoolCacheKey{
token: token,
origin: origin,
}
}

func getEndpointBasedOnToken(t auth.Token, u *url.URL) (string, error) {
switch t.Type {
case auth.AccessTokenType:
Expand Down
37 changes: 33 additions & 4 deletions components/public-api-server/pkg/proxy/conn_test.go
Expand Up @@ -11,6 +11,7 @@ import (

gitpod "github.com/gitpod-io/gitpod/gitpod-protocol"
"github.com/gitpod-io/gitpod/public-api-server/pkg/auth"
"github.com/gitpod-io/gitpod/public-api-server/pkg/origin"
"github.com/golang/mock/gomock"
lru "github.com/hashicorp/golang-lru"
"github.com/stretchr/testify/require"
Expand All @@ -25,7 +26,7 @@ func TestConnectionPool(t *testing.T) {
require.NoError(t, err)
pool := &ConnectionPool{
cache: cache,
connConstructor: func(token auth.Token) (gitpod.APIInterface, error) {
connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) {
return srv, nil
},
}
Expand All @@ -45,8 +46,36 @@ func TestConnectionPool(t *testing.T) {
_, err = pool.Get(context.Background(), bazToken)
require.NoError(t, err)
require.Equal(t, 2, pool.cache.Len(), "must keep only last two connectons")
require.True(t, pool.cache.Contains(barToken))
require.True(t, pool.cache.Contains(bazToken))
require.True(t, pool.cache.Contains(pool.cacheKey(barToken, "")))
require.True(t, pool.cache.Contains(pool.cacheKey(bazToken, "")))
}

func TestConnectionPool_ByDistinctOrigins(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
srv := gitpod.NewMockAPIInterface(ctrl)

cache, err := lru.New(2)
require.NoError(t, err)
pool := &ConnectionPool{
cache: cache,
connConstructor: func(ctx context.Context, token auth.Token) (gitpod.APIInterface, error) {
return srv, nil
},
}

token := auth.NewAccessToken("foo")

ctxWithOriginA := origin.ToContext(context.Background(), "originA")
ctxWithOriginB := origin.ToContext(context.Background(), "originB")

_, err = pool.Get(ctxWithOriginA, token)
require.NoError(t, err)
require.Equal(t, 1, pool.cache.Len())

_, err = pool.Get(ctxWithOriginB, token)
require.NoError(t, err)
require.Equal(t, 2, pool.cache.Len())
}

func TestEndpointBasedOnToken(t *testing.T) {
Expand All @@ -57,7 +86,7 @@ func TestEndpointBasedOnToken(t *testing.T) {
require.NoError(t, err)
require.Equal(t, "wss://gitpod.io/api/v1", endpointForAccessToken)

endpointForCookie, err := getEndpointBasedOnToken(auth.NewCookieToken("foo", "server"), u)
endpointForCookie, err := getEndpointBasedOnToken(auth.NewCookieToken("foo"), u)
require.NoError(t, err)
require.Equal(t, "wss://gitpod.io/api/gitpod", endpointForCookie)
}
2 changes: 2 additions & 0 deletions components/public-api-server/pkg/server/server.go
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/gitpod-io/gitpod/public-api-server/pkg/apiv1"
"github.com/gitpod-io/gitpod/public-api-server/pkg/auth"
"github.com/gitpod-io/gitpod/public-api-server/pkg/billingservice"
"github.com/gitpod-io/gitpod/public-api-server/pkg/origin"
"github.com/gitpod-io/gitpod/public-api-server/pkg/proxy"
"github.com/gitpod-io/gitpod/public-api-server/pkg/webhooks"
"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -112,6 +113,7 @@ func register(srv *baseserver.Server, connPool proxy.ServerConnectionPool, expCl
NewMetricsInterceptor(connectMetrics),
NewLogInterceptor(log.Log),
auth.NewServerInterceptor(),
origin.NewInterceptor(),
),
}

Expand Down

0 comments on commit 673ab68

Please sign in to comment.