diff --git a/go/cmd/api-devserver/main.go b/go/cmd/api-devserver/main.go index 81c48be9f07..b5efb889031 100644 --- a/go/cmd/api-devserver/main.go +++ b/go/cmd/api-devserver/main.go @@ -6,14 +6,19 @@ import ( "errors" "flag" "fmt" + "log/slog" "os" "os/exec" "os/signal" "path/filepath" "syscall" + "cloud.google.com/go/datastore" + "cloud.google.com/go/storage" "github.com/google/osv.dev/go/internal/api" + db "github.com/google/osv.dev/go/internal/database/datastore" "github.com/google/osv.dev/go/logger" + "github.com/google/osv.dev/go/osv/clients" ) const ( @@ -65,11 +70,7 @@ func run() error { if !*noBackend { logger.InfoContext(ctx, "Starting Go API backend natively", "port", *backendPort) - go func() { - if err := api.RunServer(ctx, *backendPort); err != nil { - logger.ErrorContext(ctx, "Go API server exited", "error", err) - } - }() + go runBackend(ctx, *backendPort) } logger.InfoContext(ctx, "Starting ESPv2 container", "port", *espPort, "backendPort", *backendPort) @@ -149,3 +150,41 @@ func runCmdAsync(cmd *exec.Cmd) <-chan error { return out } + +func runBackend(ctx context.Context, port int) { + project := os.Getenv("GOOGLE_CLOUD_PROJECT") + if project == "" { + logger.ErrorContext(ctx, "GOOGLE_CLOUD_PROJECT environment variable is not set") + return + } + datastoreID := os.Getenv("DATASTORE_DATABASE_ID") // empty string is the (default) database + dbClient, err := datastore.NewClientWithDatabase(ctx, project, datastoreID) + if err != nil { + logger.ErrorContext(ctx, "failed to create datastore client", "error", err) + return + } + defer dbClient.Close() + gcsClient, err := storage.NewClient(ctx) + if err != nil { + logger.ErrorContext(ctx, "Failed to create storage client", slog.Any("error", err)) + return + } + defer gcsClient.Close() + vulnBucket := os.Getenv("OSV_VULNERABILITIES_BUCKET") + if vulnBucket == "" { + logger.ErrorContext(ctx, "OSV_VULNERABILITIES_BUCKET environment variable is not set") + return + } + vulnStore := db.NewVulnerabilityStore(db.VulnStoreConfig{ + Client: dbClient, + GCS: clients.NewGCSClient(gcsClient, vulnBucket), + }) + relationsStore := db.NewRelationsStore(dbClient) + if err := api.RunServer(ctx, api.ServerOptions{ + Port: port, + VulnStore: vulnStore, + RelationsStore: relationsStore, + }); err != nil { + logger.ErrorContext(ctx, "Go API server exited", "error", err) + } +} diff --git a/go/cmd/api/main.go b/go/cmd/api/main.go index fcc2725203a..e324bb1895f 100644 --- a/go/cmd/api/main.go +++ b/go/cmd/api/main.go @@ -3,13 +3,19 @@ package main import ( "context" + "errors" "flag" + "log/slog" "os" "os/signal" "syscall" + "cloud.google.com/go/datastore" + "cloud.google.com/go/storage" "github.com/google/osv.dev/go/internal/api" + db "github.com/google/osv.dev/go/internal/database/datastore" "github.com/google/osv.dev/go/logger" + "github.com/google/osv.dev/go/osv/clients" ) func main() { @@ -28,5 +34,38 @@ func run() error { ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer stop() - return api.RunServer(ctx, *port) + project := os.Getenv("GOOGLE_CLOUD_PROJECT") + if project == "" { + logger.ErrorContext(ctx, "GOOGLE_CLOUD_PROJECT environment variable is not set") + return errors.New("GOOGLE_CLOUD_PROJECT environment variable is not set") + } + datastoreID := os.Getenv("DATASTORE_DATABASE_ID") // empty string is the (default) database + dbClient, err := datastore.NewClientWithDatabase(ctx, project, datastoreID) + if err != nil { + logger.ErrorContext(ctx, "Failed to create datastore client", slog.Any("error", err)) + return err + } + defer dbClient.Close() + gcsClient, err := storage.NewClient(ctx) + if err != nil { + logger.ErrorContext(ctx, "Failed to create storage client", slog.Any("error", err)) + return err + } + defer gcsClient.Close() + vulnBucket := os.Getenv("OSV_VULNERABILITIES_BUCKET") + if vulnBucket == "" { + logger.ErrorContext(ctx, "OSV_VULNERABILITIES_BUCKET environment variable is not set") + return errors.New("OSV_VULNERABILITIES_BUCKET environment variable is not set") + } + vulnStore := db.NewVulnerabilityStore(db.VulnStoreConfig{ + Client: dbClient, + GCS: clients.NewGCSClient(gcsClient, vulnBucket), + }) + relationsStore := db.NewRelationsStore(dbClient) + + return api.RunServer(ctx, api.ServerOptions{ + Port: *port, + VulnStore: vulnStore, + RelationsStore: relationsStore, + }) } diff --git a/go/internal/api/get_vuln_by_id.go b/go/internal/api/get_vuln_by_id.go new file mode 100644 index 00000000000..975bdec4b23 --- /dev/null +++ b/go/internal/api/get_vuln_by_id.go @@ -0,0 +1,59 @@ +package api + +import ( + "context" + "errors" + "log/slog" + "strings" + + "github.com/google/osv.dev/go/internal/models" + "github.com/google/osv.dev/go/logger" + "github.com/ossf/osv-schema/bindings/go/osvschema" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + pb "osv.dev/bindings/go/api" +) + +//nolint:revive // complains about 'Id' instead of 'ID', but that matches the API (the proto). +func (s *server) GetVulnById(ctx context.Context, params *pb.GetVulnByIdParameters) (*osvschema.Vulnerability, error) { + id := params.GetId() + if len(id) == 0 { + return nil, status.Error(codes.InvalidArgument, "ID is required") + } + // Datastore has a limit of how large indexed properties can be (1500 bytes). + // Vulnerability IDs are not going to be over 100 characters. + if len(id) > 100 { + return nil, status.Error(codes.InvalidArgument, "ID is too long") + } + vulnerability, err := s.vulnStore.Get(ctx, id) + if err == nil { + return vulnerability, nil + } + if !errors.Is(err, models.ErrNotFound) { + logger.ErrorContext(ctx, "failed to get vulnerability from store", + slog.String("id", id), + slog.Any("error", err), + ) + + return nil, status.Errorf(codes.Internal, "error getting vulnerability: %v", err) + } + + // Check for aliases + aliases, err := s.relationsStore.GetAliases(ctx, id) + if err != nil { + if errors.Is(err, models.ErrNotFound) { + return nil, status.Error(codes.NotFound, "Vulnerability not found") + } + + logger.ErrorContext(ctx, "failed to check aliases for vulnerability", + slog.String("id", id), + slog.Any("error", err), + ) + + return nil, status.Errorf(codes.Internal, "error getting vulnerability: %v", err) + } + + aliasStrs := strings.Join(aliases.Aliases, " ") + + return nil, status.Errorf(codes.NotFound, "Vulnerability not found, but the following aliases were: %s", aliasStrs) +} diff --git a/go/internal/api/get_vuln_by_id_test.go b/go/internal/api/get_vuln_by_id_test.go new file mode 100644 index 00000000000..079c6f4ceb9 --- /dev/null +++ b/go/internal/api/get_vuln_by_id_test.go @@ -0,0 +1,183 @@ +package api + +import ( + "context" + "errors" + "iter" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/osv.dev/go/internal/models" + "github.com/ossf/osv-schema/bindings/go/osvschema" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/testing/protocmp" + + pb "osv.dev/bindings/go/api" +) + +type mockVulnerabilityStore struct { + vuln *osvschema.Vulnerability + err error +} + +func (m *mockVulnerabilityStore) Get(_ context.Context, _ string) (*osvschema.Vulnerability, error) { + if m.err != nil { + return nil, m.err + } + if m.vuln == nil { + return nil, models.ErrNotFound + } + + return m.vuln, nil +} + +func (m *mockVulnerabilityStore) ListBySource(_ context.Context, _ string, _ bool) iter.Seq2[*models.VulnSourceRef, error] { + panic("unimplemented") +} + +func (m *mockVulnerabilityStore) GetSourceModified(_ context.Context, _ string) (time.Time, error) { + panic("unimplemented") +} + +func (m *mockVulnerabilityStore) GetWithMetadata(_ context.Context, _ string) (*osvschema.Vulnerability, *models.VulnSourceRef, error) { + panic("unimplemented") +} + +func (m *mockVulnerabilityStore) Write(_ context.Context, _ models.WriteRequest) error { + panic("unimplemented") +} + +type mockRelationsStore struct { + aliases *models.GetAliasResult + err error +} + +func (m *mockRelationsStore) GetAliases(_ context.Context, _ string) (*models.GetAliasResult, error) { + if m.err != nil { + return nil, m.err + } + if m.aliases == nil { + return nil, models.ErrNotFound + } + + return m.aliases, nil +} + +func (m *mockRelationsStore) GetRelated(_ context.Context, _ string) (*models.GetRelatedResult, error) { + panic("unimplemented") +} + +func (m *mockRelationsStore) GetUpstream(_ context.Context, _ string) (*models.GetUpstreamResult, error) { + panic("unimplemented") +} + +func TestGetVulnById(t *testing.T) { + ctx := context.Background() + + testVuln := &osvschema.Vulnerability{ + Id: "TEST-1", + } + + tests := []struct { + name string + id string + mockVuln *osvschema.Vulnerability + mockVulnErr error + mockAliases *models.GetAliasResult + mockAliasesErr error + want *osvschema.Vulnerability + wantErrCode codes.Code + wantErrMsg string + }{ + { + name: "Success", + id: "TEST-1", + mockVuln: testVuln, + want: testVuln, + }, + { + name: "Empty ID", + id: "", + wantErrCode: codes.InvalidArgument, + wantErrMsg: "ID is required", + }, + { + name: "Too Long ID", + id: string(make([]byte, 101)), + wantErrCode: codes.InvalidArgument, + wantErrMsg: "ID is too long", + }, + { + name: "Not Found - No Aliases", + id: "TEST-1", + wantErrCode: codes.NotFound, + wantErrMsg: "Vulnerability not found", + }, + { + name: "Not Found - With Aliases", + id: "TEST-1", + mockAliases: &models.GetAliasResult{ + Aliases: []string{"ALIAS-1", "ALIAS-2"}, + }, + wantErrCode: codes.NotFound, + wantErrMsg: "Vulnerability not found, but the following aliases were: ALIAS-1 ALIAS-2", + }, + { + name: "VulnStore Error", + id: "TEST-1", + mockVulnErr: errors.New("internal GCS error"), + wantErrCode: codes.Internal, + wantErrMsg: "error getting vulnerability", + }, + { + name: "RelationsStore Error", + id: "TEST-1", + mockAliasesErr: errors.New("internal Datastore error"), + wantErrCode: codes.Internal, + wantErrMsg: "error getting vulnerability", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &server{ + vulnStore: &mockVulnerabilityStore{ + vuln: tt.mockVuln, + err: tt.mockVulnErr, + }, + relationsStore: &mockRelationsStore{ + aliases: tt.mockAliases, + err: tt.mockAliasesErr, + }, + } + + got, err := s.GetVulnById(ctx, &pb.GetVulnByIdParameters{Id: tt.id}) + + if tt.wantErrCode != codes.OK { + if err == nil { + t.Fatalf("GetVulnById() expected error, got nil") + } + st, ok := status.FromError(err) + if !ok { + t.Fatalf("GetVulnById() expected gRPC status error, got %v", err) + } + if st.Code() != tt.wantErrCode { + t.Errorf("GetVulnById() error code = %v, want %v", st.Code(), tt.wantErrCode) + } + if tt.wantErrMsg != "" && !strings.Contains(st.Message(), tt.wantErrMsg) { + t.Errorf("GetVulnById() error message = %q, want to contain %q", st.Message(), tt.wantErrMsg) + } + } else { + if err != nil { + t.Fatalf("GetVulnById() unexpected error: %v", err) + } + if diff := cmp.Diff(tt.want, got, protocmp.Transform()); diff != "" { + t.Errorf("GetVulnById() mismatch (-want +got):\n%s", diff) + } + } + }) + } +} diff --git a/go/internal/api/server.go b/go/internal/api/server.go index ffc703fcff7..354a4717f84 100644 --- a/go/internal/api/server.go +++ b/go/internal/api/server.go @@ -6,6 +6,7 @@ import ( "fmt" "net" + "github.com/google/osv.dev/go/internal/models" "github.com/google/osv.dev/go/logger" "google.golang.org/grpc" pb "osv.dev/bindings/go/api" @@ -13,20 +14,32 @@ import ( type server struct { pb.UnimplementedOSVServer + + vulnStore models.VulnerabilityStore + relationsStore models.RelationsStore +} + +type ServerOptions struct { + Port int + VulnStore models.VulnerabilityStore + RelationsStore models.RelationsStore } // RunServer starts the gRPC server and handles graceful shutdown. -func RunServer(ctx context.Context, port int) error { - lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) +func RunServer(ctx context.Context, opts ServerOptions) error { + lis, err := net.Listen("tcp", fmt.Sprintf(":%d", opts.Port)) if err != nil { logger.ErrorContext(ctx, "failed to listen", "error", err) return err } s := grpc.NewServer() - pb.RegisterOSVServer(s, &server{}) + pb.RegisterOSVServer(s, &server{ + vulnStore: opts.VulnStore, + relationsStore: opts.RelationsStore, + }) - logger.InfoContext(ctx, "server listening", "port", port) + logger.InfoContext(ctx, "server listening", "port", opts.Port) serveErr := make(chan error, 1) go func() {