Skip to content

Commit

Permalink
feat: add private endpoint and gRPC test cases (#306)
Browse files Browse the repository at this point in the history
Because

- add private API for internal usage

This commit

- add private API implementation 
- add gRPC test case for private APIs
  • Loading branch information
Phelan164 committed Mar 21, 2023
1 parent 6e32f6c commit bb3c193
Show file tree
Hide file tree
Showing 93 changed files with 8,954 additions and 1,868 deletions.
79 changes: 59 additions & 20 deletions cmd/main/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,17 +117,20 @@ func main() {
grpcServerOpts = append(grpcServerOpts, grpc.Creds(creds))
}

grpcS := grpc.NewServer(grpcServerOpts...)
reflection.Register(grpcS)
privateGrpcS := grpc.NewServer(grpcServerOpts...)
reflection.Register(privateGrpcS)

publicGrpcS := grpc.NewServer(grpcServerOpts...)
reflection.Register(publicGrpcS)

triton := triton.NewTriton()
defer triton.Close()

mgmtPrivateServiceClient, mgmtPrivateServiceClientConn := external.InitMgmtPrivateServiceClient()
defer mgmtPrivateServiceClientConn.Close()

pipelineServiceClient, pipelineServiceClientConn := external.InitPipelinePublicServiceClient()
defer pipelineServiceClientConn.Close()
pipelinePublicServiceClient, pipelinePublicServiceClientConn := external.InitPipelinePublicServiceClient()
defer pipelinePublicServiceClientConn.Close()

redisClient := redis.NewClient(&config.Config.Cache.Redis.RedisOptions)
defer redisClient.Close()
Expand All @@ -146,16 +149,30 @@ func main() {

repository := repository.NewRepository(db)

service := service.NewService(repository, triton, pipelinePublicServiceClient, redisClient, temporalClient)

modelPB.RegisterModelPublicServiceServer(
grpcS,
handler.NewHandler(
service.NewService(repository, triton, pipelineServiceClient, redisClient, temporalClient),
triton))
publicGrpcS,
handler.NewPublicHandler(service, triton))

modelPB.RegisterModelPrivateServiceServer(
privateGrpcS,
handler.NewPrivateHandler(service, triton))

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

gwS := runtime.NewServeMux(
privateGwS := runtime.NewServeMux(
runtime.WithForwardResponseOption(httpResponseModifier),
runtime.WithErrorHandler(errorHandler),
runtime.WithIncomingHeaderMatcher(customMatcher),
runtime.WithMarshalerOption(runtime.MIMEWildcard, &runtime.JSONPb{
MarshalOptions: util.MarshalOptions,
UnmarshalOptions: util.UnmarshalOptions,
}),
)

publicGwS := runtime.NewServeMux(
runtime.WithForwardResponseOption(httpResponseModifier),
runtime.WithErrorHandler(errorHandler),
runtime.WithIncomingHeaderMatcher(customMatcher),
Expand All @@ -166,17 +183,17 @@ func main() {
)

// Register custom route for POST /v1alpha/models/{name=models/*/instances/*}/test-multipart which makes model inference for REST multiple-part form-data
if err := gwS.HandlePath("POST", "/v1alpha/{name=models/*/instances/*}/test-multipart", appendCustomHeaderMiddleware(handler.HandleTestModelInstanceByUpload)); err != nil {
if err := publicGwS.HandlePath("POST", "/v1alpha/{name=models/*/instances/*}/test-multipart", appendCustomHeaderMiddleware(handler.HandleTestModelInstanceByUpload)); err != nil {
panic(err)
}

// Register custom route for POST /v1alpha/models/{name=models/*/instances/*}/trigger-multipart which makes model inference for REST multiple-part form-data
if err := gwS.HandlePath("POST", "/v1alpha/{name=models/*/instances/*}/trigger-multipart", appendCustomHeaderMiddleware(handler.HandleTriggerModelInstanceByUpload)); err != nil {
if err := publicGwS.HandlePath("POST", "/v1alpha/{name=models/*/instances/*}/trigger-multipart", appendCustomHeaderMiddleware(handler.HandleTriggerModelInstanceByUpload)); err != nil {
panic(err)
}

// Register custom route for POST /models/multipart which uploads model for REST multiple-part form-data
if err := gwS.HandlePath("POST", "/v1alpha/models/multipart", appendCustomHeaderMiddleware(handler.HandleCreateModelByMultiPartFormData)); err != nil {
if err := publicGwS.HandlePath("POST", "/v1alpha/models/multipart", appendCustomHeaderMiddleware(handler.HandleCreateModelByMultiPartFormData)); err != nil {
panic(err)
}

Expand All @@ -197,27 +214,48 @@ func main() {
} else {
dialOpts = []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}
}

if err := modelPB.RegisterModelPublicServiceHandlerFromEndpoint(ctx, gwS, fmt.Sprintf(":%v", config.Config.Server.Port), dialOpts); err != nil {
if err := modelPB.RegisterModelPrivateServiceHandlerFromEndpoint(ctx, privateGwS, fmt.Sprintf(":%v", config.Config.Server.PrivatePort), dialOpts); err != nil {
logger.Fatal(err.Error())
}
httpServer := &http.Server{
Addr: fmt.Sprintf(":%v", config.Config.Server.Port),
Handler: grpcHandlerFunc(grpcS, gwS, config.Config.Server.CORSOrigins),
if err := modelPB.RegisterModelPublicServiceHandlerFromEndpoint(ctx, publicGwS, fmt.Sprintf(":%v", config.Config.Server.PublicPort), dialOpts); err != nil {
logger.Fatal(err.Error())
}

privateHttpServer := &http.Server{
Addr: fmt.Sprintf(":%v", config.Config.Server.PrivatePort),
Handler: grpcHandlerFunc(privateGrpcS, privateGwS, config.Config.Server.CORSOrigins),
}

publicHttpServer := &http.Server{
Addr: fmt.Sprintf(":%v", config.Config.Server.PublicPort),
Handler: grpcHandlerFunc(publicGrpcS, publicGwS, config.Config.Server.CORSOrigins),
}

// Wait for interrupt signal to gracefully shutdown the server with a timeout of 5 seconds.
quitSig := make(chan os.Signal, 1)
errSig := make(chan error)
if config.Config.Server.HTTPS.Cert != "" && config.Config.Server.HTTPS.Key != "" {
go func() {
if err := httpServer.ListenAndServeTLS(config.Config.Server.HTTPS.Cert, config.Config.Server.HTTPS.Key); err != nil {
if err := privateHttpServer.ListenAndServeTLS(config.Config.Server.HTTPS.Cert, config.Config.Server.HTTPS.Key); err != nil {
errSig <- err
}
}()
} else {
go func() {
if err := privateHttpServer.ListenAndServe(); err != nil {
errSig <- err
}
}()
}
if config.Config.Server.HTTPS.Cert != "" && config.Config.Server.HTTPS.Key != "" {
go func() {
if err := publicHttpServer.ListenAndServeTLS(config.Config.Server.HTTPS.Cert, config.Config.Server.HTTPS.Key); err != nil {
errSig <- err
}
}()
} else {
go func() {
if err := httpServer.ListenAndServe(); err != nil {
if err := publicHttpServer.ListenAndServe(); err != nil {
errSig <- err
}
}()
Expand All @@ -237,7 +275,8 @@ func main() {
usg.TriggerSingleReporter(ctx)
}
logger.Info("Shutting down server...")
grpcS.GracefulStop()
privateGrpcS.GracefulStop()
publicGrpcS.GracefulStop()
}

}
17 changes: 9 additions & 8 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ import (

// ServerConfig defines HTTP server configurations
type ServerConfig struct {
Port int `koanf:"port"`
HTTPS struct {
PrivatePort int `koanf:"privateport"`
PublicPort int `koanf:"publicport"`
HTTPS struct {
Cert string `koanf:"cert"`
Key string `koanf:"key"`
}
Expand Down Expand Up @@ -54,9 +55,9 @@ type TritonServerConfig struct {

// MgmtBackendConfig related to mgmt-backend
type MgmtBackendConfig struct {
Host string `koanf:"host"`
AdminPort int `koanf:"adminport"`
HTTPS struct {
Host string `koanf:"host"`
PublicPort int `koanf:"publicport"`
HTTPS struct {
Cert string `koanf:"cert"`
Key string `koanf:"key"`
}
Expand All @@ -78,9 +79,9 @@ type UsageServerConfig struct {

// PipelineBackendConfig related to pipeline-backend
type PipelineBackendConfig struct {
Host string `koanf:"host"`
Port int `koanf:"port"`
HTTPS struct {
Host string `koanf:"host"`
PublicPort int `koanf:"publicport"`
HTTPS struct {
Cert string `koanf:"cert"`
Key string `koanf:"key"`
}
Expand Down
7 changes: 4 additions & 3 deletions config/config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
server:
port: 8083
privateport: 3083
publicport: 8083
https:
cert:
key:
Expand All @@ -26,7 +27,7 @@ tritonserver:
modelstore: /model-repository
mgmtbackend:
host: mgmt-backend
adminport: 3084
publicport: 3084
https:
cert:
key:
Expand All @@ -40,7 +41,7 @@ usageserver:
port: 443
pipelinebackend:
host: pipeline-backend
port: 8081
publicport: 8081
https:
cert:
key:
Expand Down
18 changes: 12 additions & 6 deletions integration-test/const.js
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
let proto, host, port
let proto, host, publicPort, privatePort

if (__ENV.MODE == "api-gateway") {
// api-gateway mode
proto = "http"
host = "api-gateway"
port = 8080
publicPort = 8080
privatePort = 3083
} else if (__ENV.MODE == "localhost") {
// localhost mode for GitHub Actions
proto = "http"
host = "localhost"
port = 8080
publicPort = 8080
privatePort = 3083
} else {
// direct microservice mode
proto = "http"
host = "model-backend"
port = 8083
publicPort = 8083
privatePort = 3083
}

export const gRPCHost = `${host}:${port}`
export const apiHost = `${proto}://${host}:${port}`
export const gRPCPrivateHost = `${host}:${privatePort}`
export const apiPrivateHost = `${proto}://${host}:${privatePort}`

export const gRPCPublicHost = `${host}:${publicPort}`
export const apiPublicHost = `${proto}://${host}:${publicPort}`

export const cls_model = open(`${__ENV.TEST_FOLDER_ABS_PATH}/integration-test/data/dummy-cls-model.zip`, "b");
export const det_model = open(`${__ENV.TEST_FOLDER_ABS_PATH}/integration-test//data/dummy-det-model.zip`, "b");
Expand Down
23 changes: 17 additions & 6 deletions integration-test/grpc.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
import * as createModel from "./grpc_create_model.js"
import * as updateModel from "./grpc_update_model.js"
import * as queryModel from "./grpc_query_model.js"
import * as queryModelPrivate from "./grpc_query_model_private.js"
import * as deployModel from "./grpc_deploy_model.js"
import * as inferModel from "./grpc_infer_model.js"
import * as publishModel from "./grpc_publish_model.js"
Expand All @@ -25,28 +26,38 @@ export const options = {
};

const client = new grpc.Client();
client.load(['proto'], 'model_definition.proto');
client.load(['proto'], 'model.proto');
client.load(['proto'], 'model_public_service.proto');
client.load(['proto'], 'healthcheck.proto');
client.load(['proto/vdp/model/v1alpha'], 'model_definition.proto');
client.load(['proto/vdp/model/v1alpha'], 'model.proto');
client.load(['proto/vdp/model/v1alpha'], 'model_private_service.proto');
client.load(['proto/vdp/model/v1alpha'], 'model_public_service.proto');
client.load(['proto/vdp/model/v1alpha'], 'healthcheck.proto');

export function setup() { }

export default () => {
// Liveness check
{
group("Model API: Liveness", () => {
client.connect(constant.gRPCHost, {
client.connect(constant.gRPCPublicHost, {
plaintext: true
});
const response = client.invoke('vdp.model.v1alpha.ModelPublicService/Liveness', {});
console.log(response.message);
check(response, {
'Status is OK': (r) => r && r.status === grpc.StatusOK,
'Response status is SERVING_STATUS_SERVING': (r) => r && r.message.healthCheckResponse.status === "SERVING_STATUS_SERVING",
});
client.close()
});
}

// Private API
if (__ENV.MODE != "api-gateway" && __ENV.MODE != "localhost") {
queryModelPrivate.GetModel()
queryModelPrivate.ListModels()
queryModelPrivate.LookUpModel()
}

// Create model API
createModel.CreateModel()

Expand Down Expand Up @@ -82,7 +93,7 @@ export default () => {
};

export function teardown() {
client.connect(constant.gRPCHost, {
client.connect(constant.gRPCPublicHost, {
plaintext: true
});
group("Model API: Delete all models created by this test", () => {
Expand Down
10 changes: 5 additions & 5 deletions integration-test/grpc_create_model.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ import {
} from "https://jslib.k6.io/k6-utils/1.1.0/index.js";

const client = new grpc.Client();
client.load(['proto'], 'model_definition.proto');
client.load(['proto'], 'model.proto');
client.load(['proto'], 'model_public_service.proto');
client.load(['proto/vdp/model/v1alpha'], 'model_definition.proto');
client.load(['proto/vdp/model/v1alpha'], 'model.proto');
client.load(['proto/vdp/model/v1alpha'], 'model_public_service.proto');

import * as constant from "./const.js"

Expand All @@ -20,7 +20,7 @@ const model_def_name = "model-definitions/github"
export function CreateModel() {
// CreateModelBinaryFileUpload check
group("Model API: CreateModelBinaryFileUpload", () => {
client.connect(constant.gRPCHost, {
client.connect(constant.gRPCPublicHost, {
plaintext: true
});
check(client.invoke('vdp.model.v1alpha.ModelPublicService/CreateModelBinaryFileUpload', {}), {
Expand All @@ -33,7 +33,7 @@ export function CreateModel() {

// CreateModel check
group("Model API: CreateModel with GitHub", () => {
client.connect(constant.gRPCHost, {
client.connect(constant.gRPCPublicHost, {
plaintext: true
});
let model_id = randomString(10)
Expand Down
10 changes: 5 additions & 5 deletions integration-test/grpc_deploy_model.js
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ import {
import * as constant from "./const.js"

const client = new grpc.Client();
client.load(['proto'], 'model_definition.proto');
client.load(['proto'], 'model.proto');
client.load(['proto'], 'model_public_service.proto');
client.load(['proto/vdp/model/v1alpha'], 'model_definition.proto');
client.load(['proto/vdp/model/v1alpha'], 'model.proto');
client.load(['proto/vdp/model/v1alpha'], 'model_public_service.proto');

const model_def_name = "model-definitions/local"

export function DeployUndeployModel() {
// Deploy ModelInstance check
group("Model API: Deploy ModelInstance", () => {
client.connect(constant.gRPCHost, {
client.connect(constant.gRPCPublicHost, {
plaintext: true
});

Expand All @@ -39,7 +39,7 @@ export function DeployUndeployModel() {
fd_cls.append("description", model_description);
fd_cls.append("model_definition", model_def_name);
fd_cls.append("content", http.file(constant.cls_model, "dummy-cls-model.zip"));
let createClsModelRes = http.request("POST", `${constant.apiHost}/v1alpha/models/multipart`, fd_cls.body(), {
let createClsModelRes = http.request("POST", `${constant.apiPublicHost}/v1alpha/models/multipart`, fd_cls.body(), {
headers: genHeader(`multipart/form-data; boundary=${fd_cls.boundary}`),
})
check(createClsModelRes, {
Expand Down

0 comments on commit bb3c193

Please sign in to comment.