Skip to content

Commit

Permalink
refactor: create api handler
Browse files Browse the repository at this point in the history
  • Loading branch information
pyadav committed Jan 28, 2024
1 parent 239d947 commit 4c980bb
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 49 deletions.
33 changes: 33 additions & 0 deletions mobius/internal/api/v1/chatcompletions.go
@@ -0,0 +1,33 @@
package v1

import (
"context"
"errors"

"connectrpc.com/connect"
"github.com/missingstudio/studio/backend/internal/providers"
"github.com/missingstudio/studio/backend/internal/providers/base"
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
)

func (s *V1Handler) ChatCompletions(
ctx context.Context,
req *connect.Request[llmv1.CompletionRequest],
) (*connect.Response[llmv1.CompletionResponse], error) {
provider, err := providers.GetProvider(ctx)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, err)
}

completionProvider, ok := provider.(base.ChatCompilationInterface)
if !ok {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("method not implemented"))
}

data, err := completionProvider.ChatCompilation(ctx, req.Msg)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, err)
}

return connect.NewResponse(data), nil
}
37 changes: 37 additions & 0 deletions mobius/internal/api/v1/v1.go
@@ -0,0 +1,37 @@
package v1

import (
"fmt"
"net/http"

"connectrpc.com/connect"
"connectrpc.com/validate"
"connectrpc.com/vanguard"
"github.com/missingstudio/studio/backend/pkg/utils"
"github.com/missingstudio/studio/protos/pkg/llm/llmv1connect"
)

type V1Handler struct {
llmv1connect.UnimplementedLLMServiceHandler
}

func Register() (http.Handler, error) {
validateInterceptor, err := validate.NewInterceptor()
if err != nil {
return nil, fmt.Errorf("validate interceptor not created: %w", err)
}

compress1KB := connect.WithCompressMinBytes(1024)
services := []*vanguard.Service{
vanguard.NewService(llmv1connect.NewLLMServiceHandler(
&V1Handler{},
compress1KB,
connect.WithInterceptors(validateInterceptor, utils.ProviderInterceptor()),
)),
}
transcoderOptions := []vanguard.TranscoderOption{
vanguard.WithUnknownHandler(utils.Custom404handler()),
}

return vanguard.NewTranscoder(services, transcoderOptions...)
}
53 changes: 4 additions & 49 deletions mobius/internal/connectrpc/mux.go
@@ -1,20 +1,13 @@
package connectrpc

import (
"context"
"errors"
"fmt"
"net/http"

"connectrpc.com/connect"
"connectrpc.com/grpchealth"
"connectrpc.com/grpcreflect"
"connectrpc.com/validate"
"connectrpc.com/vanguard"
"github.com/missingstudio/studio/backend/internal/providers"
"github.com/missingstudio/studio/backend/internal/providers/base"
"github.com/missingstudio/studio/backend/pkg/utils"
llmv1 "github.com/missingstudio/studio/protos/pkg/llm"
v1 "github.com/missingstudio/studio/backend/internal/api/v1"
"github.com/missingstudio/studio/protos/pkg/llm/llmv1connect"
)

Expand All @@ -23,33 +16,17 @@ type Deps struct{}
func NewConnectMux(d Deps) (*http.ServeMux, error) {
mux := http.NewServeMux()

validateInterceptor, err := validate.NewInterceptor()
if err != nil {
return nil, fmt.Errorf("validate interceptor not created: %w", err)
}

compress1KB := connect.WithCompressMinBytes(1024)
services := []*vanguard.Service{
vanguard.NewService(llmv1connect.NewLLMServiceHandler(
&LLMServer{},
compress1KB,
connect.WithInterceptors(validateInterceptor, utils.ProviderInterceptor()),
)),
}
transcoderOptions := []vanguard.TranscoderOption{
vanguard.WithUnknownHandler(Custom404handler()),
}

transcoder, err := vanguard.NewTranscoder(services, transcoderOptions...)
v1Handler, err := v1.Register()
if err != nil {
return nil, fmt.Errorf("failed to create transcoder: %w", err)
return nil, fmt.Errorf("failed to create handler: %w", err)
}

mux.Handle("/", v1Handler)
mux.Handle("/ping", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Pong\n")
}))

mux.Handle("/", transcoder)
mux.Handle(grpchealth.NewHandler(
grpchealth.NewStaticChecker(llmv1connect.LLMServiceName),
compress1KB,
Expand All @@ -65,25 +42,3 @@ func NewConnectMux(d Deps) (*http.ServeMux, error) {
type LLMServer struct {
llmv1connect.UnimplementedLLMServiceHandler
}

func (s *LLMServer) ChatCompletions(
ctx context.Context,
req *connect.Request[llmv1.CompletionRequest],
) (*connect.Response[llmv1.CompletionResponse], error) {
provider, err := providers.GetProvider(ctx)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, err)
}

completionProvider, ok := provider.(base.ChatCompilationInterface)
if !ok {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("method not implemented"))
}

data, err := completionProvider.ChatCompilation(ctx, req.Msg)
if err != nil {
return nil, connect.NewError(connect.CodeInternal, err)
}

return connect.NewResponse(data), nil
}
2 changes: 2 additions & 0 deletions mobius/pkg/server/server.go
Expand Up @@ -25,6 +25,8 @@ func Serve(ctx context.Context) error {
return connectsrv.Shutdown()
},
})

slog.Info("server started")
<-wait

slog.Info("graceful shutdown complete")
Expand Down
9 changes: 9 additions & 0 deletions mobius/pkg/utils/middleware.go
@@ -0,0 +1,9 @@
package utils

import "net/http"

func Custom404handler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
http.Error(w, "custom 404 error", http.StatusNotFound)
})
}

0 comments on commit 4c980bb

Please sign in to comment.