Skip to content

Commit

Permalink
support direct connections to ent from the rest api (#1932)
Browse files Browse the repository at this point in the history
Signed-off-by: Marco Deicas <mdeicas@google.com>
  • Loading branch information
mdeicas committed Jun 12, 2024
1 parent bf65123 commit e2486e1
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 15 deletions.
23 changes: 19 additions & 4 deletions cmd/guacrest/cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,20 @@ var flags = struct {

tlsCertFile string
tlsKeyFile string

dbDirectConnection bool
dbDriver string
dbAddress string
}{}

var rootCmd = &cobra.Command{
Use: "guacrest",
Short: "Guac REST API Server",
Long: "The Guac REST API Server provides usable and analysis-focused endpoints. " +
"It is backed by the GraphQL API Server, which must be running for this server " +
"to work.",
Long: "The Guac REST API Server provides usable and analysis-focused endpoints.\n\n " +
"The default data backend is the GraphQL API Server, which must be " +
"running for this server to work. Some endpoints are optimized with a direct " +
"connection to the database that backs the GraphQL API. To enable this, " +
"set the db flags.",
Version: version.Version,
Run: func(command *cobra.Command, args []string) {
flags.restAPIServerPort = viper.GetInt("rest-api-server-port")
Expand All @@ -49,6 +55,10 @@ var rootCmd = &cobra.Command{
flags.tlsCertFile = viper.GetString("rest-api-tls-cert-file")
flags.tlsKeyFile = viper.GetString("rest-api-tls-key-file")

flags.dbDriver = viper.GetString("db-driver")
flags.dbAddress = viper.GetString("db-address")
flags.dbDirectConnection = viper.GetBool("db-direct-connection")

startServer()
},
}
Expand All @@ -68,9 +78,14 @@ func init() {
"rest-api-server-port",
"rest-api-tls-cert-file",
"rest-api-tls-key-file",

// configuration of direct database connection
"db-direct-connection",
"db-driver",
"db-address",
})
if err != nil {
fmt.Fprintf(os.Stderr, "failed to setup flag: %v", err)
fmt.Fprintf(os.Stderr, "failed to setup flags: %v", err)
os.Exit(1)
}

Expand Down
46 changes: 39 additions & 7 deletions cmd/guacrest/cmd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import (

"github.com/Khan/genqlient/graphql"
"github.com/go-chi/chi"
"github.com/guacsec/guac/pkg/assembler/backends/ent"
"github.com/guacsec/guac/pkg/assembler/backends/ent/backend"
"github.com/guacsec/guac/pkg/cli"
gen "github.com/guacsec/guac/pkg/guacrest/generated"
"github.com/guacsec/guac/pkg/guacrest/server"
Expand All @@ -40,7 +42,7 @@ func startServer() {
httpClient := &http.Client{Transport: cli.HTTPHeaderTransport(ctx, flags.headerFile, http.DefaultTransport)}
gqlClient := getGraphqlServerClientOrExit(ctx, httpClient)

restApiHandler := gen.Handler(gen.NewStrictHandler(server.NewDefaultServer(gqlClient), nil))
restApiHandler := gen.Handler(gen.NewStrictHandler(getRestApiHandlerOrExit(ctx, gqlClient), nil))

router := chi.NewRouter()
router.Use(server.AddLoggerToCtxMiddleware, server.LogRequestsMiddleware)
Expand All @@ -55,8 +57,8 @@ func startServer() {
proto = "https"
}

logger.Infof("Connect to the server at %s://0.0.0.0:%d/", proto, flags.restAPIServerPort)
logger.Info("Starting Server")
logger.Infof("connect to the server at %s://0.0.0.0:%d/", proto, flags.restAPIServerPort)
logger.Info("starting Server")
go func() {
var err error
if proto == "https" {
Expand All @@ -65,14 +67,14 @@ func startServer() {
err = server.ListenAndServe()
}
if err != nil && err != http.ErrServerClosed {
logger.Errorf("Server finished with error: %s", err)
logger.Errorf("server finished with error: %s", err)
}
}()

sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
s := <-sigs
logger.Infof("Signal recieved: %s, shutting down gracefully\n", s.String())
logger.Infof("signal recieved: %s, shutting down gracefully\n", s.String())

done := make(chan bool, 1)
ctx, cf := context.WithCancel(ctx)
Expand All @@ -90,6 +92,36 @@ func startServer() {
cf()
}

// get the service handler
// if an ent address is provided, get the handler backed by ent
func getRestApiHandlerOrExit(ctx context.Context, gqlClient graphql.Client) gen.StrictServerInterface {
logger := logging.FromContext(ctx)
if flags.dbDirectConnection {
logger.Infof("directly connecting to the Ent backend for optimized endpoint" +
"implementation. This is an experimental feature")
ent := getEntClientOrExit(ctx)
handler := server.NewEntConnectedServer(ent, gqlClient)
return handler
}
return server.NewDefaultServer(gqlClient)
}

func getEntClientOrExit(ctx context.Context) *ent.Client {
logger := logging.FromContext(ctx)
client, err := backend.GetReadOnlyClient(ctx, &backend.BackendOptions{
DriverName: flags.dbDriver,
Address: flags.dbAddress,
Debug: false,
// starting up the REST API shouldn't lead to a database migration, restart
// the graphql server instead
AutoMigrate: false,
})
if err != nil {
logger.Fatalf("error getting the Ent client: %s", err)
}
return client
}

// get the graphql client and test the connection
func getGraphqlServerClientOrExit(ctx context.Context, httpClient *http.Client) graphql.Client {
logger := logging.FromContext(ctx)
Expand All @@ -98,7 +130,7 @@ func getGraphqlServerClientOrExit(ctx context.Context, httpClient *http.Client)
// expected here
gqlBaseAddr, ok := strings.CutSuffix(flags.gqlServerAddress, "query")
if !ok {
logger.Fatalf("Unexpected GraphQL server address. URL does not end in %q", "query")
logger.Fatalf("unexpected GraphQL server address. URL does not end in %q", "query")
}

gqlHealthzEndpoint := fmt.Sprintf("%s/healthz", gqlBaseAddr)
Expand All @@ -111,6 +143,6 @@ func getGraphqlServerClientOrExit(ctx context.Context, httpClient *http.Client)
gqlHealthzEndpoint, code)
}

logger.Info("Successfully connected to Graphql Server")
logger.Info("successfully connected to Graphql Server")
return graphql.NewClient(flags.gqlServerAddress, httpClient)
}
14 changes: 14 additions & 0 deletions pkg/assembler/backends/ent/backend/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

"entgo.io/ent/dialect"
"github.com/guacsec/guac/pkg/assembler/backends/ent"
"github.com/guacsec/guac/pkg/assembler/backends/ent/hook"
"github.com/guacsec/guac/pkg/assembler/backends/ent/migrate"
"github.com/guacsec/guac/pkg/logging"

Expand All @@ -35,6 +36,19 @@ type BackendOptions struct {
AutoMigrate bool
}

// GetReadOnlyClient sets up the ent backend and returns a read-only client.
func GetReadOnlyClient(ctx context.Context, options *BackendOptions) (*ent.Client, error) {
client, err := SetupBackend(ctx, options)
if err != nil {
return nil, err
}
// https://entgo.io/docs/hooks/#mutation
client.Use(hook.Reject(
ent.OpCreate | ent.OpUpdate | ent.OpUpdateOne | ent.OpDelete | ent.OpDeleteOne,
))
return client, nil
}

// SetupBackend sets up the ent backend, preparing the database and returning a client
func SetupBackend(ctx context.Context, options *BackendOptions) (*ent.Client, error) {
logger := logging.FromContext(ctx)
Expand Down
1 change: 1 addition & 0 deletions pkg/cli/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ func init() {
set.String("rest-api-server-port", "8081", "port to serve the REST API from")
set.String("rest-api-tls-cert-file", "", "path to the TLS certificate in PEM format for rest api server")
set.String("rest-api-tls-key-file", "", "path to the TLS key in PEM format for rest api server")
set.Bool("db-direct-connection", false, "[experimental] connect directly to the database that backs the gql API for optimized endpoint implementations")

set.String("verifier-key-path", "", "path to pem file to verify dsse")
set.String("verifier-key-id", "", "ID of the key to be stored")
Expand Down
8 changes: 4 additions & 4 deletions pkg/guacrest/helpers/artifact_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ func Test_FindArtifactWithDigest_ArtifactFound(t *testing.T) {
Algorithm: "sha256",
Digest: "abc",
}}
_, err := gql.IngestArtifact(ctx, gqlClient, idOrArtifactSpec)
if err != nil {
t.Fatalf("Error ingesting test data")
}
_, err := gql.IngestArtifact(ctx, gqlClient, idOrArtifactSpec)
if err != nil {
t.Fatalf("Error ingesting test data")
}

res, err := helpers.FindArtifactWithDigest(ctx, gqlClient, "abc")
assert.NoError(t, err)
Expand Down
40 changes: 40 additions & 0 deletions pkg/guacrest/server/ent.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//
// Copyright 2024 The GUAC Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package server

import (
"github.com/Khan/genqlient/graphql"
"github.com/guacsec/guac/pkg/assembler/backends/ent"
)

// EntConnectedServer implements the REST API interface, using by default the
// GrapQL API Server as a backend, but also allows overriding the default
// handlers to ones that directly use the ENT backend.
//
// This is an experimental feature.
type EntConnectedServer struct {
ent *ent.Client
*DefaultServer
}

func NewEntConnectedServer(ent *ent.Client, gqlClient graphql.Client) *EntConnectedServer {
return &EntConnectedServer{
ent: ent,
DefaultServer: NewDefaultServer(gqlClient),
}
}

// Override DefaultServer with methods here

0 comments on commit e2486e1

Please sign in to comment.