Skip to content

Commit

Permalink
Updates logging tags, transform float/int to use relative min/max (#1798
Browse files Browse the repository at this point in the history
)
  • Loading branch information
nickzelei committed Apr 20, 2024
1 parent 6d010c4 commit 6ceed2c
Show file tree
Hide file tree
Showing 28 changed files with 399 additions and 231 deletions.
5 changes: 3 additions & 2 deletions backend/internal/auth/jwt/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,10 @@ func getCombinedScopesAndPermissions(scope string, permissions []string) []strin
}

func GetTokenDataFromCtx(ctx context.Context) (*TokenContextData, error) {
data, ok := ctx.Value(TokenContextKey{}).(*TokenContextData)
val := ctx.Value(TokenContextKey{})
data, ok := val.(*TokenContextData)
if !ok {
return nil, nucleuserrors.NewUnauthenticated("ctx does not contain TokenContextData or unable to cast struct")
return nil, nucleuserrors.NewUnauthenticated(fmt.Sprintf("ctx does not contain TokenContextData or unable to cast struct: %T", val))
}
return data, nil
}
6 changes: 5 additions & 1 deletion backend/internal/cmds/mgmt/serve/connect/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
awsmanager "github.com/nucleuscloud/neosync/backend/internal/aws"
up_cmd "github.com/nucleuscloud/neosync/backend/internal/cmds/mgmt/migrate/up"
auth_interceptor "github.com/nucleuscloud/neosync/backend/internal/connect/interceptors/auth"
authlogging_interceptor "github.com/nucleuscloud/neosync/backend/internal/connect/interceptors/auth_logging"
logger_interceptor "github.com/nucleuscloud/neosync/backend/internal/connect/interceptors/logger"
logging_interceptor "github.com/nucleuscloud/neosync/backend/internal/connect/interceptors/logging"
neosynclogger "github.com/nucleuscloud/neosync/backend/internal/logger"
Expand Down Expand Up @@ -158,7 +159,7 @@ func serve(ctx context.Context) error {
return err
}
loggerInterceptor := logger_interceptor.NewInterceptor(logger)
loggingInterceptor := logging_interceptor.NewInterceptor(logger)
loggingInterceptor := logging_interceptor.NewInterceptor()

stdInterceptors := []connect.Interceptor{
otelInterceptor,
Expand Down Expand Up @@ -202,12 +203,14 @@ func serve(ctx context.Context) error {
apikeyClient,
).InjectTokenCtx,
),
authlogging_interceptor.NewInterceptor(db),
)
jwtOnlyAuthInterceptors = append(
jwtOnlyAuthInterceptors,
auth_interceptor.NewInterceptor(
jwtclient.InjectTokenCtx,
),
authlogging_interceptor.NewInterceptor(db),
)
authSvcInterceptors = append(
authSvcInterceptors,
Expand All @@ -221,6 +224,7 @@ func serve(ctx context.Context) error {
mgmtv1alpha1connect.AuthServiceRefreshCliProcedure,
},
),
authlogging_interceptor.NewInterceptor(db),
)
}

Expand Down
69 changes: 69 additions & 0 deletions backend/internal/connect/interceptors/auth_logging/interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package authlogging_interceptor

import (
"context"

"connectrpc.com/connect"
"github.com/nucleuscloud/neosync/backend/internal/auth/tokenctx"
logger_interceptor "github.com/nucleuscloud/neosync/backend/internal/connect/interceptors/logger"
"github.com/nucleuscloud/neosync/backend/internal/nucleusdb"
)

type Interceptor struct {
db *nucleusdb.NucleusDb
}

func NewInterceptor(db *nucleusdb.NucleusDb) connect.Interceptor {
return &Interceptor{db: db}
}

func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) {
return next(setAuthValues(ctx, i.db), request)
}
}

func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
return next(ctx, spec)
}
}

func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
return next(setAuthValues(ctx, i.db), conn)
}
}

func setAuthValues(ctx context.Context, db *nucleusdb.NucleusDb) context.Context {
vals := getAuthValues(ctx, db)
logger := logger_interceptor.GetLoggerFromContextOrDefault(ctx).With(vals...)
return logger_interceptor.SetLoggerContext(ctx, logger)
}

func getAuthValues(ctx context.Context, db *nucleusdb.NucleusDb) []any {
tokenCtxResp, err := tokenctx.GetTokenCtx(ctx)
if err != nil {
return []any{}
}
output := []any{}

if tokenCtxResp.JwtContextData != nil {
output = append(output, "authUserId", tokenCtxResp.JwtContextData.AuthUserId)

user, err := db.Q.GetUserByProviderSub(ctx, db.Db, tokenCtxResp.JwtContextData.AuthUserId)
if err == nil {
output = append(output, "userId", nucleusdb.UUIDString(user.ID))
}
} else if tokenCtxResp.ApiKeyContextData != nil {
output = append(output, "apiKeyType", tokenCtxResp.ApiKeyContextData.ApiKeyType)
if tokenCtxResp.ApiKeyContextData.ApiKey != nil {
output = append(output,
"apiKeyId", nucleusdb.UUIDString(tokenCtxResp.ApiKeyContextData.ApiKey.ID),
"accountId", nucleusdb.UUIDString(tokenCtxResp.ApiKeyContextData.ApiKey.AccountID),
"userId", nucleusdb.UUIDString(tokenCtxResp.ApiKeyContextData.ApiKey.UserID),
)
}
}
return output
}
198 changes: 198 additions & 0 deletions backend/internal/connect/interceptors/auth_logging/interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
package authlogging_interceptor

import (
"context"
"errors"
"log/slog"
"net/http"
"net/http/httptest"
"os"
"testing"

"connectrpc.com/connect"
"github.com/google/uuid"
db_queries "github.com/nucleuscloud/neosync/backend/gen/go/db"
mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1"
"github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect"
"github.com/nucleuscloud/neosync/backend/internal/apikey"
auth_apikey "github.com/nucleuscloud/neosync/backend/internal/auth/apikey"
auth_jwt "github.com/nucleuscloud/neosync/backend/internal/auth/jwt"
logger_interceptor "github.com/nucleuscloud/neosync/backend/internal/connect/interceptors/logger"
"github.com/nucleuscloud/neosync/backend/internal/nucleusdb"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)

func Test_Interceptor_WrapUnary_JwtContextData_ValidUser(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))

mockDbtx := nucleusdb.NewMockDBTX(t)
mockQuerier := db_queries.NewMockQuerier(t)

genuuid, _ := nucleusdb.ToUuid(uuid.NewString())
mockQuerier.On("GetUserByProviderSub", mock.Anything, mock.Anything, "auth-user-id").
Return(db_queries.NeosyncApiUser{ID: genuuid}, nil)

mux := http.NewServeMux()
mux.Handle(mgmtv1alpha1connect.UserAccountServiceGetUserProcedure, connect.NewUnaryHandler(
mgmtv1alpha1connect.UserAccountServiceGetUserProcedure,
func(ctx context.Context, r *connect.Request[mgmtv1alpha1.GetUserRequest]) (*connect.Response[mgmtv1alpha1.GetUserResponse], error) {
return connect.NewResponse(&mgmtv1alpha1.GetUserResponse{UserId: "123"}), nil
},
connect.WithInterceptors(
logger_interceptor.NewInterceptor(logger),
&mockAuthInterceptor{data: &auth_jwt.TokenContextData{AuthUserId: "auth-user-id"}},
NewInterceptor(nucleusdb.New(mockDbtx, mockQuerier)),
),
))

srv := startHTTPServer(t, mux)
client := mgmtv1alpha1connect.NewUserAccountServiceClient(srv.Client(), srv.URL)
_, err := client.GetUser(context.Background(), connect.NewRequest(&mgmtv1alpha1.GetUserRequest{}))
require.NoError(t, err)
}

func Test_Interceptor_WrapUnary_JwtContextData_NoUser_NoFail(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))

mockDbtx := nucleusdb.NewMockDBTX(t)
mockQuerier := db_queries.NewMockQuerier(t)

mux := http.NewServeMux()
mux.Handle(mgmtv1alpha1connect.UserAccountServiceGetUserProcedure, connect.NewUnaryHandler(
mgmtv1alpha1connect.UserAccountServiceGetUserProcedure,
func(ctx context.Context, r *connect.Request[mgmtv1alpha1.GetUserRequest]) (*connect.Response[mgmtv1alpha1.GetUserResponse], error) {
return connect.NewResponse(&mgmtv1alpha1.GetUserResponse{UserId: "123"}), nil
},
connect.WithInterceptors(
logger_interceptor.NewInterceptor(logger),
NewInterceptor(nucleusdb.New(mockDbtx, mockQuerier)),
),
))

srv := startHTTPServer(t, mux)
client := mgmtv1alpha1connect.NewUserAccountServiceClient(srv.Client(), srv.URL)
_, err := client.GetUser(context.Background(), connect.NewRequest(&mgmtv1alpha1.GetUserRequest{}))
require.NoError(t, err)
}

type mockAuthInterceptor struct {
data *auth_jwt.TokenContextData
}

func (i *mockAuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) {
return next(context.WithValue(ctx, auth_jwt.TokenContextKey{}, i.data), request)
}
}

func (i *mockAuthInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn {
return next(ctx, spec)
}
}

func (i *mockAuthInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
return next(ctx, conn)
}
}

func startHTTPServer(tb testing.TB, h http.Handler) *httptest.Server {
tb.Helper()
srv := httptest.NewUnstartedServer(h)
srv.EnableHTTP2 = true
srv.Start()
tb.Cleanup(srv.Close)
return srv
}

func Test_getAuthValues_NoTokenCtx(t *testing.T) {
vals := getAuthValues(context.Background(), &nucleusdb.NucleusDb{})
require.Empty(t, vals)
}

func Test_getAuthValues_Valid_Jwt(t *testing.T) {
mockDbtx := nucleusdb.NewMockDBTX(t)
mockQuerier := db_queries.NewMockQuerier(t)

uuidstr := uuid.NewString()
genuuid, _ := nucleusdb.ToUuid(uuidstr)
mockQuerier.On("GetUserByProviderSub", mock.Anything, mock.Anything, "auth-user-id").
Return(db_queries.NeosyncApiUser{ID: genuuid}, nil)

ctx := context.WithValue(context.Background(), auth_jwt.TokenContextKey{}, &auth_jwt.TokenContextData{
AuthUserId: "auth-user-id",
})

vals := getAuthValues(ctx, nucleusdb.New(mockDbtx, mockQuerier))
require.Equal(
t,
[]any{"authUserId", "auth-user-id", "userId", uuidstr},
vals,
)
}

func Test_getAuthValues_Valid_Jwt_No_User(t *testing.T) {
mockDbtx := nucleusdb.NewMockDBTX(t)
mockQuerier := db_queries.NewMockQuerier(t)

mockQuerier.On("GetUserByProviderSub", mock.Anything, mock.Anything, "auth-user-id").
Return(db_queries.NeosyncApiUser{}, errors.New("test err"))

ctx := context.WithValue(context.Background(), auth_jwt.TokenContextKey{}, &auth_jwt.TokenContextData{
AuthUserId: "auth-user-id",
})

vals := getAuthValues(ctx, nucleusdb.New(mockDbtx, mockQuerier))
require.Equal(
t,
[]any{"authUserId", "auth-user-id"},
vals,
)
}

func Test_getAuthValues_Valid_ApiKey(t *testing.T) {
mockDbtx := nucleusdb.NewMockDBTX(t)
mockQuerier := db_queries.NewMockQuerier(t)

apikeyid := uuid.NewString()
accountid := uuid.NewString()
userid := uuid.NewString()

apikeyuuid, _ := nucleusdb.ToUuid(apikeyid)
accountiduuid, _ := nucleusdb.ToUuid(accountid)
useriduuid, _ := nucleusdb.ToUuid(userid)

ctx := context.WithValue(context.Background(), auth_apikey.TokenContextKey{}, &auth_apikey.TokenContextData{
ApiKeyType: apikey.AccountApiKey,
ApiKey: &db_queries.NeosyncApiAccountApiKey{
ID: apikeyuuid,
AccountID: accountiduuid,
UserID: useriduuid,
},
})

vals := getAuthValues(ctx, nucleusdb.New(mockDbtx, mockQuerier))
require.Equal(
t,
[]any{"apiKeyType", apikey.AccountApiKey, "apiKeyId", apikeyid, "accountId", accountid, "userId", userid},
vals,
)
}

func Test_getAuthValues_Valid_ApiKey_No_Apikey(t *testing.T) {
mockDbtx := nucleusdb.NewMockDBTX(t)
mockQuerier := db_queries.NewMockQuerier(t)

ctx := context.WithValue(context.Background(), auth_apikey.TokenContextKey{}, &auth_apikey.TokenContextData{
ApiKeyType: apikey.AccountApiKey,
})

vals := getAuthValues(ctx, nucleusdb.New(mockDbtx, mockQuerier))
require.Equal(
t,
[]any{"apiKeyType", apikey.AccountApiKey},
vals,
)
}
4 changes: 2 additions & 2 deletions backend/internal/connect/interceptors/logger/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func NewInterceptor(logger *slog.Logger) connect.Interceptor {

func (i *Interceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, request connect.AnyRequest) (connect.AnyResponse, error) {
newCtx := setLoggerContext(ctx, clonelogger(i.logger))
newCtx := SetLoggerContext(ctx, clonelogger(i.logger))
return next(newCtx, request)
}
}
Expand All @@ -32,7 +32,7 @@ func (i *Interceptor) WrapStreamingClient(next connect.StreamingClientFunc) conn

func (i *Interceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
newCtx := setLoggerContext(ctx, clonelogger(i.logger))
newCtx := SetLoggerContext(ctx, clonelogger(i.logger))
return next(newCtx, conn)
}
}
Expand Down
2 changes: 1 addition & 1 deletion backend/internal/connect/interceptors/logger/logger-ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ func GetLoggerFromContextOrDefault(ctx context.Context) *slog.Logger {
return data.GetLogger()
}

func setLoggerContext(ctx context.Context, logger *slog.Logger) context.Context {
func SetLoggerContext(ctx context.Context, logger *slog.Logger) context.Context {
return context.WithValue(ctx, loggerContextKey{}, &loggerContextData{logger: logger})
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func Test_GetLoggerFromContextOrDefault(t *testing.T) {

func Test_GetLoggerFromContextOrDefault_NonDefault(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := setLoggerContext(context.Background(), logger)
ctx := SetLoggerContext(context.Background(), logger)
ctxlogger := GetLoggerFromContextOrDefault(ctx)
assert.Equal(t, logger, ctxlogger)
}
Loading

0 comments on commit 6ceed2c

Please sign in to comment.