Skip to content

Commit

Permalink
feat(gateway): add api key support (#10)
Browse files Browse the repository at this point in the history
Signed-off-by: Praveen Yadav <pyadav9678@gmail.com>
  • Loading branch information
pyadav committed Feb 29, 2024
1 parent 4c95111 commit d3fc6cd
Show file tree
Hide file tree
Showing 25 changed files with 2,136 additions and 351 deletions.
5 changes: 5 additions & 0 deletions gateway/cmd/serve.go
Expand Up @@ -10,6 +10,7 @@ import (

_ "github.com/jackc/pgx/v5/stdlib"
"github.com/missingstudio/studio/backend/config"
"github.com/missingstudio/studio/backend/core/apikey"
"github.com/missingstudio/studio/backend/core/connection"
"github.com/missingstudio/studio/backend/core/prompt"
"github.com/missingstudio/studio/backend/internal/api"
Expand Down Expand Up @@ -65,6 +66,9 @@ func Serve(cfg *config.Config) error {
connectionRepository := postgres.NewConnectionRepository(dbc)
connectionService := connection.NewService(connectionRepository)

apikeyRepository := postgres.NewAPIKeyRepository(dbc)
apikeyService := apikey.NewService(apikeyRepository)

promptRepository := postgres.NewPromptRepository(dbc)
promptService := prompt.NewService(promptRepository)

Expand All @@ -74,6 +78,7 @@ func Serve(cfg *config.Config) error {
providerService,
connectionService,
promptService,
apikeyService,
)

if err := server.Serve(ctx, logger, cfg.App, deps); err != nil {
Expand Down
16 changes: 16 additions & 0 deletions gateway/core/apikey/apikey.go
@@ -0,0 +1,16 @@
package apikey

import (
"time"

"github.com/google/uuid"
)

type APIKey struct {
Id uuid.UUID `json:"id"`
Name string `json:"name"`
Value string `json:"value"`
MaskedValue string `json:"masked_value"`
CreatedAt time.Time `json:"created_at"`
LastUsedAt time.Time `json:"last_used_at"`
}
50 changes: 50 additions & 0 deletions gateway/core/apikey/service.go
@@ -0,0 +1,50 @@
package apikey

import (
"context"
)

type Repository interface {
GetAll(context.Context) ([]APIKey, error)
Create(context.Context, APIKey) (APIKey, error)
Get(context.Context, string) (APIKey, error)
GetByToken(context.Context, string) (APIKey, error)
Update(context.Context, APIKey) (APIKey, error)
DeleteByID(context.Context, string) error
}

var _ Repository = &Service{}

type Service struct {
apikeyRepo Repository
}

func NewService(apikeyRepo Repository) *Service {
return &Service{
apikeyRepo: apikeyRepo,
}
}

func (s *Service) GetAll(ctx context.Context) ([]APIKey, error) {
return s.apikeyRepo.GetAll(ctx)
}

func (s *Service) Create(ctx context.Context, api APIKey) (APIKey, error) {
return s.apikeyRepo.Create(ctx, api)
}

func (s *Service) Get(ctx context.Context, id string) (APIKey, error) {
return s.apikeyRepo.Get(ctx, id)
}

func (s *Service) GetByToken(ctx context.Context, id string) (APIKey, error) {
return s.apikeyRepo.GetByToken(ctx, id)
}

func (s *Service) Update(ctx context.Context, api APIKey) (APIKey, error) {
return s.apikeyRepo.Update(ctx, api)
}

func (s *Service) DeleteByID(ctx context.Context, id string) error {
return s.apikeyRepo.DeleteByID(ctx, id)
}
4 changes: 4 additions & 0 deletions gateway/internal/api/deps.go
Expand Up @@ -3,6 +3,7 @@ package api
import (
"log/slog"

"github.com/missingstudio/studio/backend/core/apikey"
"github.com/missingstudio/studio/backend/core/connection"
"github.com/missingstudio/studio/backend/core/prompt"
"github.com/missingstudio/studio/backend/internal/ingester"
Expand All @@ -17,6 +18,7 @@ type Deps struct {
ProviderService *providers.Service
ConnectionService *connection.Service
PromptService *prompt.Service
APIKeyService *apikey.Service
}

func NewDeps(
Expand All @@ -26,6 +28,7 @@ func NewDeps(
ps *providers.Service,
cs *connection.Service,
pms *prompt.Service,
aks *apikey.Service,
) *Deps {
return &Deps{
Logger: logger,
Expand All @@ -34,5 +37,6 @@ func NewDeps(
ProviderService: ps,
ConnectionService: cs,
PromptService: pms,
APIKeyService: aks,
}
}
118 changes: 118 additions & 0 deletions gateway/internal/api/v1/apikeys.go
@@ -0,0 +1,118 @@
package v1

import (
"context"

"connectrpc.com/connect"
"github.com/google/uuid"
"github.com/missingstudio/studio/backend/core/apikey"
"github.com/missingstudio/studio/backend/pkg/utils"
"github.com/missingstudio/studio/common/errors"
llmv1 "github.com/missingstudio/studio/protos/pkg/llm/v1"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/timestamppb"
)

func (s *V1Handler) ListAPIKeys(ctx context.Context, req *connect.Request[emptypb.Empty]) (*connect.Response[llmv1.ListAPIKeysResponse], error) {
keys, err := s.apikeyService.GetAll(ctx)
if err != nil {
return nil, errors.NewInternalError(err.Error())
}

data := []*llmv1.APIKey{}
for _, k := range keys {
data = append(data, &llmv1.APIKey{
Id: k.Id.String(),
Name: k.Name,
MaskedValue: utils.MaskString(k.Value),
CreatedAt: timestamppb.New(k.CreatedAt),
LastUsedAt: timestamppb.New(k.LastUsedAt),
})
}

return connect.NewResponse(&llmv1.ListAPIKeysResponse{
Keys: data,
}), nil
}

func (s *V1Handler) CreateAPIKey(ctx context.Context, req *connect.Request[llmv1.CreateAPIKeyRequest]) (*connect.Response[llmv1.CreateAPIKeyResponse], error) {
securekey, err := utils.GenerateSecureAPIKey()
if err != nil {
return nil, errors.New(err)
}

key := apikey.APIKey{
Name: req.Msg.Name,
Value: securekey,
}

newkey, err := s.apikeyService.Create(ctx, key)
if err != nil {
return nil, errors.New(err)
}

return connect.NewResponse(&llmv1.CreateAPIKeyResponse{
Key: &llmv1.APIKey{
Id: newkey.Id.String(),
Name: newkey.Name,
Value: newkey.Value,
CreatedAt: timestamppb.New(newkey.CreatedAt),
LastUsedAt: timestamppb.New(newkey.LastUsedAt),
},
}), nil
}

func (s *V1Handler) GetAPIKey(ctx context.Context, req *connect.Request[llmv1.GetAPIKeyRequest]) (*connect.Response[llmv1.GetAPIKeyResponse], error) {
key, err := s.apikeyService.Get(ctx, req.Msg.Id)
if err != nil {
return nil, errors.NewNotFound(err.Error())
}

k := &llmv1.APIKey{
Id: key.Id.String(),
Name: key.Name,
MaskedValue: utils.MaskString(key.Value),
CreatedAt: timestamppb.New(key.CreatedAt),
LastUsedAt: timestamppb.New(key.LastUsedAt),
}

return connect.NewResponse(&llmv1.GetAPIKeyResponse{
Key: k,
}), nil
}

func (s *V1Handler) UpdateAPIKey(ctx context.Context, req *connect.Request[llmv1.UpdateAPIKeyRequest]) (*connect.Response[llmv1.UpdateAPIKeyResponse], error) {
parsedUUID, err := uuid.Parse(req.Msg.Id)
if err != nil {
return nil, errors.New(err)
}

key := apikey.APIKey{
Id: parsedUUID,
Name: req.Msg.Name,
}

updatedkey, err := s.apikeyService.Update(ctx, key)
if err != nil {
return nil, errors.New(err)
}

return connect.NewResponse(&llmv1.UpdateAPIKeyResponse{
Key: &llmv1.APIKey{
Id: updatedkey.Id.String(),
Name: updatedkey.Name,
MaskedValue: utils.MaskString(updatedkey.Value),
CreatedAt: timestamppb.New(updatedkey.CreatedAt),
LastUsedAt: timestamppb.New(updatedkey.LastUsedAt),
},
}), nil
}

func (s *V1Handler) DeleteAPIKey(ctx context.Context, req *connect.Request[llmv1.DeleteAPIKeyRequest]) (*connect.Response[emptypb.Empty], error) {
err := s.apikeyService.DeleteByID(ctx, req.Msg.Id)
if err != nil {
return nil, errors.New(err)
}

return connect.NewResponse(&emptypb.Empty{}), nil
}
4 changes: 2 additions & 2 deletions gateway/internal/api/v1/providers.go
Expand Up @@ -14,7 +14,7 @@ import (
"google.golang.org/protobuf/types/known/structpb"
)

func (s *V1Handler) ListProviders(ctx context.Context, req *connect.Request[emptypb.Empty]) (*connect.Response[llmv1.ProvidersResponse], error) {
func (s *V1Handler) ListProviders(ctx context.Context, req *connect.Request[emptypb.Empty]) (*connect.Response[llmv1.ListProvidersResponse], error) {
providers := s.providerService.GetProviders()

data := []*llmv1.Provider{}
Expand All @@ -27,7 +27,7 @@ func (s *V1Handler) ListProviders(ctx context.Context, req *connect.Request[empt
})
}

return connect.NewResponse(&llmv1.ProvidersResponse{
return connect.NewResponse(&llmv1.ListProvidersResponse{
Providers: data,
}), nil
}
Expand Down
4 changes: 4 additions & 0 deletions gateway/internal/api/v1/v1.go
Expand Up @@ -8,6 +8,7 @@ import (
"connectrpc.com/otelconnect"
"connectrpc.com/validate"
"connectrpc.com/vanguard"
"github.com/missingstudio/studio/backend/core/apikey"
"github.com/missingstudio/studio/backend/core/connection"
"github.com/missingstudio/studio/backend/core/prompt"
"github.com/missingstudio/studio/backend/internal/api"
Expand All @@ -24,6 +25,7 @@ type V1Handler struct {
ingester ingester.Ingester
providerService *providers.Service
connectionService *connection.Service
apikeyService *apikey.Service
promptService *prompt.Service
}

Expand All @@ -33,6 +35,7 @@ func NewHandlerV1(d *api.Deps) *V1Handler {
providerService: d.ProviderService,
connectionService: d.ConnectionService,
promptService: d.PromptService,
apikeyService: d.APIKeyService,
}
}

Expand All @@ -52,6 +55,7 @@ func Register(d *api.Deps) (http.Handler, error) {
stdInterceptors := []connect.Interceptor{
validateInterceptor,
otelconnectInterceptor,
interceptor.NewAPIKeyInterceptor(d.Logger, d.APIKeyService),
interceptor.HeadersInterceptor(),
interceptor.RateLimiterInterceptor(d.RateLimiter),
interceptor.RetryInterceptor(),
Expand Down
21 changes: 20 additions & 1 deletion gateway/internal/interceptor/auth.go
Expand Up @@ -5,14 +5,19 @@ import (
"log/slog"

"connectrpc.com/connect"
"github.com/missingstudio/studio/backend/core/apikey"
"github.com/missingstudio/studio/backend/internal/constants"
"github.com/missingstudio/studio/backend/internal/errors"
)

// NewAPIKeyInterceptor returns interceptor which is checking if api key exits
func NewAPIKeyInterceptor(logger *slog.Logger) connect.UnaryInterceptorFunc {
func NewAPIKeyInterceptor(logger *slog.Logger, aks *apikey.Service) connect.UnaryInterceptorFunc {
return connect.UnaryInterceptorFunc(func(next connect.UnaryFunc) connect.UnaryFunc {
return connect.UnaryFunc(func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
if authenticationSkipList[req.Spec().Procedure] {
return next(ctx, req)
}

apiHeader := req.Header().Get(constants.XMSAPIKey)
if apiHeader == "" {
logger.Info("request without api key",
Expand All @@ -22,7 +27,21 @@ func NewAPIKeyInterceptor(logger *slog.Logger) connect.UnaryInterceptorFunc {
return nil, errors.ErrUnauthenticated
}

apikey := req.Header().Get(constants.XMSAPIKey)
if _, err := aks.GetByToken(context.Background(), apikey); err != nil {
return nil, errors.ErrUnauthenticated
}

return next(ctx, req)
})
})
}

// authenticationSkipList stores path to skip authentication, by default its enabled for all requests
var authenticationSkipList = map[string]bool{
"/llm.v1.LLMService/ListModels": true,
"/llm.v1.LLMService/ListProviders": true,
"/llm.v1.LLMService/GetProviderConfig": true,
"/llm.v1.LLMService/ListAPIKeys": true,
"/llm.v1.LLMService/CreateAPIKey": true,
}
28 changes: 28 additions & 0 deletions gateway/internal/storage/postgres/apikey.go
@@ -0,0 +1,28 @@
package postgres

import (
"database/sql"
"time"

"github.com/google/uuid"
"github.com/missingstudio/studio/backend/core/apikey"
)

type APIKey struct {
ID uuid.UUID `db:"id"`
Name string `db:"name"`
Value []byte `db:"value"`
LastUsedAt sql.NullTime `db:"last_used_at"`
UpdatedAt time.Time `db:"updated_at"`
CreatedAt time.Time `db:"created_at"`
}

func (c APIKey) ToAPIKey() (apikey.APIKey, error) {
return apikey.APIKey{
Id: c.ID,
Name: c.Name,
Value: string(c.Value),
LastUsedAt: c.LastUsedAt.Time,
CreatedAt: c.CreatedAt,
}, nil
}

0 comments on commit d3fc6cd

Please sign in to comment.