diff --git a/cmd/guacrest/cmd/root.go b/cmd/guacrest/cmd/root.go index 44fafd9863..ddf3c5635d 100644 --- a/cmd/guacrest/cmd/root.go +++ b/cmd/guacrest/cmd/root.go @@ -33,14 +33,20 @@ var flags = struct { tlsCertFile string tlsKeyFile string + + dbDirectConnection bool + dbDriver string + dbAddress string }{} var rootCmd = &cobra.Command{ Use: "guacrest", Short: "Guac REST API Server", - Long: "The Guac REST API Server provides usable and analysis-focused endpoints. " + - "It is backed by the GraphQL API Server, which must be running for this server " + - "to work.", + Long: "The Guac REST API Server provides usable and analysis-focused endpoints.\n\n " + + "The default data backend is the GraphQL API Server, which must be " + + "running for this server to work. Some endpoints are optimized with a direct " + + "connection to the database that backs the GraphQL API. To enable this, " + + "set the db flags.", Version: version.Version, Run: func(command *cobra.Command, args []string) { flags.restAPIServerPort = viper.GetInt("rest-api-server-port") @@ -49,6 +55,10 @@ var rootCmd = &cobra.Command{ flags.tlsCertFile = viper.GetString("rest-api-tls-cert-file") flags.tlsKeyFile = viper.GetString("rest-api-tls-key-file") + flags.dbDriver = viper.GetString("db-driver") + flags.dbAddress = viper.GetString("db-address") + flags.dbDirectConnection = viper.GetBool("db-direct-connection") + startServer() }, } @@ -68,9 +78,14 @@ func init() { "rest-api-server-port", "rest-api-tls-cert-file", "rest-api-tls-key-file", + + // configuration of direct database connection + "db-direct-connection", + "db-driver", + "db-address", }) if err != nil { - fmt.Fprintf(os.Stderr, "failed to setup flag: %v", err) + fmt.Fprintf(os.Stderr, "failed to setup flags: %v", err) os.Exit(1) } diff --git a/cmd/guacrest/cmd/server.go b/cmd/guacrest/cmd/server.go index 664e6e374f..f7860fea0e 100644 --- a/cmd/guacrest/cmd/server.go +++ b/cmd/guacrest/cmd/server.go @@ -27,6 +27,8 @@ import ( "github.com/Khan/genqlient/graphql" "github.com/go-chi/chi" + "github.com/guacsec/guac/pkg/assembler/backends/ent" + "github.com/guacsec/guac/pkg/assembler/backends/ent/backend" "github.com/guacsec/guac/pkg/cli" gen "github.com/guacsec/guac/pkg/guacrest/generated" "github.com/guacsec/guac/pkg/guacrest/server" @@ -40,7 +42,7 @@ func startServer() { httpClient := &http.Client{Transport: cli.HTTPHeaderTransport(ctx, flags.headerFile, http.DefaultTransport)} gqlClient := getGraphqlServerClientOrExit(ctx, httpClient) - restApiHandler := gen.Handler(gen.NewStrictHandler(server.NewDefaultServer(gqlClient), nil)) + restApiHandler := gen.Handler(gen.NewStrictHandler(getRestApiHandlerOrExit(ctx, gqlClient), nil)) router := chi.NewRouter() router.Use(server.AddLoggerToCtxMiddleware, server.LogRequestsMiddleware) @@ -55,8 +57,8 @@ func startServer() { proto = "https" } - logger.Infof("Connect to the server at %s://0.0.0.0:%d/", proto, flags.restAPIServerPort) - logger.Info("Starting Server") + logger.Infof("connect to the server at %s://0.0.0.0:%d/", proto, flags.restAPIServerPort) + logger.Info("starting Server") go func() { var err error if proto == "https" { @@ -65,14 +67,14 @@ func startServer() { err = server.ListenAndServe() } if err != nil && err != http.ErrServerClosed { - logger.Errorf("Server finished with error: %s", err) + logger.Errorf("server finished with error: %s", err) } }() sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) s := <-sigs - logger.Infof("Signal recieved: %s, shutting down gracefully\n", s.String()) + logger.Infof("signal recieved: %s, shutting down gracefully\n", s.String()) done := make(chan bool, 1) ctx, cf := context.WithCancel(ctx) @@ -90,6 +92,36 @@ func startServer() { cf() } +// get the service handler +// if an ent address is provided, get the handler backed by ent +func getRestApiHandlerOrExit(ctx context.Context, gqlClient graphql.Client) gen.StrictServerInterface { + logger := logging.FromContext(ctx) + if flags.dbDirectConnection { + logger.Infof("directly connecting to the Ent backend for optimized endpoint" + + "implementation. This is an experimental feature") + ent := getEntClientOrExit(ctx) + handler := server.NewEntConnectedServer(ent, gqlClient) + return handler + } + return server.NewDefaultServer(gqlClient) +} + +func getEntClientOrExit(ctx context.Context) *ent.Client { + logger := logging.FromContext(ctx) + client, err := backend.GetReadOnlyClient(ctx, &backend.BackendOptions{ + DriverName: flags.dbDriver, + Address: flags.dbAddress, + Debug: false, + // starting up the REST API shouldn't lead to a database migration, restart + // the graphql server instead + AutoMigrate: false, + }) + if err != nil { + logger.Fatalf("error getting the Ent client: %s", err) + } + return client +} + // get the graphql client and test the connection func getGraphqlServerClientOrExit(ctx context.Context, httpClient *http.Client) graphql.Client { logger := logging.FromContext(ctx) @@ -98,7 +130,7 @@ func getGraphqlServerClientOrExit(ctx context.Context, httpClient *http.Client) // expected here gqlBaseAddr, ok := strings.CutSuffix(flags.gqlServerAddress, "query") if !ok { - logger.Fatalf("Unexpected GraphQL server address. URL does not end in %q", "query") + logger.Fatalf("unexpected GraphQL server address. URL does not end in %q", "query") } gqlHealthzEndpoint := fmt.Sprintf("%s/healthz", gqlBaseAddr) @@ -111,6 +143,6 @@ func getGraphqlServerClientOrExit(ctx context.Context, httpClient *http.Client) gqlHealthzEndpoint, code) } - logger.Info("Successfully connected to Graphql Server") + logger.Info("successfully connected to Graphql Server") return graphql.NewClient(flags.gqlServerAddress, httpClient) } diff --git a/pkg/assembler/backends/ent/backend/migrations.go b/pkg/assembler/backends/ent/backend/migrations.go index 11bea823c4..d9ba65a551 100644 --- a/pkg/assembler/backends/ent/backend/migrations.go +++ b/pkg/assembler/backends/ent/backend/migrations.go @@ -22,6 +22,7 @@ import ( "entgo.io/ent/dialect" "github.com/guacsec/guac/pkg/assembler/backends/ent" + "github.com/guacsec/guac/pkg/assembler/backends/ent/hook" "github.com/guacsec/guac/pkg/assembler/backends/ent/migrate" "github.com/guacsec/guac/pkg/logging" @@ -35,6 +36,19 @@ type BackendOptions struct { AutoMigrate bool } +// GetReadOnlyClient sets up the ent backend and returns a read-only client. +func GetReadOnlyClient(ctx context.Context, options *BackendOptions) (*ent.Client, error) { + client, err := SetupBackend(ctx, options) + if err != nil { + return nil, err + } + // https://entgo.io/docs/hooks/#mutation + client.Use(hook.Reject( + ent.OpCreate | ent.OpUpdate | ent.OpUpdateOne | ent.OpDelete | ent.OpDeleteOne, + )) + return client, nil +} + // SetupBackend sets up the ent backend, preparing the database and returning a client func SetupBackend(ctx context.Context, options *BackendOptions) (*ent.Client, error) { logger := logging.FromContext(ctx) diff --git a/pkg/cli/store.go b/pkg/cli/store.go index 2fa32d2c28..360fde4495 100644 --- a/pkg/cli/store.go +++ b/pkg/cli/store.go @@ -85,6 +85,7 @@ func init() { set.String("rest-api-server-port", "8081", "port to serve the REST API from") set.String("rest-api-tls-cert-file", "", "path to the TLS certificate in PEM format for rest api server") set.String("rest-api-tls-key-file", "", "path to the TLS key in PEM format for rest api server") + set.Bool("db-direct-connection", false, "[experimental] connect directly to the database that backs the gql API for optimized endpoint implementations") set.String("verifier-key-path", "", "path to pem file to verify dsse") set.String("verifier-key-id", "", "ID of the key to be stored") diff --git a/pkg/guacrest/helpers/artifact_test.go b/pkg/guacrest/helpers/artifact_test.go index deaea9754b..4d50ac9982 100644 --- a/pkg/guacrest/helpers/artifact_test.go +++ b/pkg/guacrest/helpers/artifact_test.go @@ -42,10 +42,10 @@ func Test_FindArtifactWithDigest_ArtifactFound(t *testing.T) { Algorithm: "sha256", Digest: "abc", }} - _, err := gql.IngestArtifact(ctx, gqlClient, idOrArtifactSpec) - if err != nil { - t.Fatalf("Error ingesting test data") - } + _, err := gql.IngestArtifact(ctx, gqlClient, idOrArtifactSpec) + if err != nil { + t.Fatalf("Error ingesting test data") + } res, err := helpers.FindArtifactWithDigest(ctx, gqlClient, "abc") assert.NoError(t, err) diff --git a/pkg/guacrest/server/ent.go b/pkg/guacrest/server/ent.go new file mode 100644 index 0000000000..44cfcf97f9 --- /dev/null +++ b/pkg/guacrest/server/ent.go @@ -0,0 +1,40 @@ +// +// Copyright 2024 The GUAC Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "github.com/Khan/genqlient/graphql" + "github.com/guacsec/guac/pkg/assembler/backends/ent" +) + +// EntConnectedServer implements the REST API interface, using by default the +// GrapQL API Server as a backend, but also allows overriding the default +// handlers to ones that directly use the ENT backend. +// +// This is an experimental feature. +type EntConnectedServer struct { + ent *ent.Client + *DefaultServer +} + +func NewEntConnectedServer(ent *ent.Client, gqlClient graphql.Client) *EntConnectedServer { + return &EntConnectedServer{ + ent: ent, + DefaultServer: NewDefaultServer(gqlClient), + } +} + +// Override DefaultServer with methods here