From b9409827080e0781e3e0627c1f6ded962c71488d Mon Sep 17 00:00:00 2001 From: Thomas Stromberg Date: Tue, 28 Oct 2025 22:20:36 -0400 Subject: [PATCH 1/2] make mock race-proof, add tests, fix lint --- datastore.go | 74 ++++++++++++++++++++++------------ datastore_test.go | 4 +- ds9mock/mock.go | 17 ++++++-- ds9mock/mock_test.go | 94 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 159 insertions(+), 30 deletions(-) diff --git a/datastore.go b/datastore.go index 7564756..93765bd 100644 --- a/datastore.go +++ b/datastore.go @@ -3,6 +3,8 @@ // It uses only the Go standard library and makes direct REST API calls // to the Datastore API. Authentication is handled via the GCP metadata // server when running on GCP, or via Application Default Credentials. +// +//nolint:revive // Public structs required for API compatibility with cloud.google.com/go/datastore package ds9 import ( @@ -21,6 +23,7 @@ import ( "reflect" "strconv" "strings" + "sync/atomic" "testing" "time" @@ -40,8 +43,9 @@ var ( // ErrNoSuchEntity is returned when an entity is not found. ErrNoSuchEntity = errors.New("datastore: no such entity") - // Package-level variable for easier testing. - apiURL = "https://datastore.googleapis.com/v1" + // atomicAPIURL stores the API URL for thread-safe access. + // Use getAPIURL() to read and setAPIURL() to write. + atomicAPIURL atomic.Pointer[string] httpClient = &http.Client{ Timeout: defaultTimeout, @@ -62,6 +66,22 @@ var ( } ) +//nolint:gochecknoinits // Required for thread-safe initialization of atomic pointer +func init() { + defaultURL := "https://datastore.googleapis.com/v1" + atomicAPIURL.Store(&defaultURL) +} + +// getAPIURL returns the current API URL in a thread-safe manner. +func getAPIURL() string { + return *atomicAPIURL.Load() +} + +// setAPIURL sets the API URL in a thread-safe manner. +func setAPIURL(url string) { + atomicAPIURL.Store(&url) +} + // SetTestURLs configures custom metadata and API URLs for testing. // This is intended for use by testing packages like ds9mock. // Returns a function that restores the original URLs. @@ -74,11 +94,11 @@ var ( // defer restore() func SetTestURLs(metadata, api string) (restore func()) { // Auth package will log warning if called outside test environment - oldAPI := apiURL - apiURL = api + oldAPI := getAPIURL() + setAPIURL(api) restoreAuth := auth.SetMetadataURL(metadata) return func() { - apiURL = oldAPI + setAPIURL(oldAPI) restoreAuth() } } @@ -88,6 +108,7 @@ type Client struct { logger *slog.Logger projectID string databaseID string + baseURL string // API base URL, defaults to production but can be overridden for testing } // NewClient creates a new Datastore client. @@ -122,6 +143,7 @@ func NewClientWithDatabase(ctx context.Context, projID, dbID string) (*Client, e return &Client{ projectID: projID, databaseID: dbID, + baseURL: getAPIURL(), logger: logger, }, nil } @@ -288,6 +310,8 @@ func DecodeCursor(s string) (Cursor, error) { // Iterator is an iterator for query results. // API compatible with cloud.google.com/go/datastore. +// +//nolint:govet // Field alignment optimized for API compatibility over memory layout type Iterator struct { ctx context.Context //nolint:containedctx // Required for API compatibility with cloud.google.com/go/datastore client *Client @@ -375,7 +399,7 @@ func (it *Iterator) fetch() error { } // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:runQuery", apiURL, neturl.PathEscape(it.client.projectID)) + reqURL := fmt.Sprintf("%s/projects/%s:runQuery", it.client.baseURL, neturl.PathEscape(it.client.projectID)) body, err := doRequest(it.ctx, it.client.logger, reqURL, jsonData, token, it.client.projectID, it.client.databaseID) if err != nil { return err @@ -567,7 +591,7 @@ func (c *Client) Get(ctx context.Context, key *Key, dst any) error { } // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:lookup", apiURL, neturl.PathEscape(c.projectID)) + reqURL := fmt.Sprintf("%s/projects/%s:lookup", c.baseURL, neturl.PathEscape(c.projectID)) body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) if err != nil { c.logger.ErrorContext(ctx, "lookup request failed", "error", err, "kind", key.Kind) @@ -632,7 +656,7 @@ func (c *Client) Put(ctx context.Context, key *Key, src any) (*Key, error) { } // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:commit", apiURL, neturl.PathEscape(c.projectID)) + reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID)) if _, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID); err != nil { c.logger.ErrorContext(ctx, "commit request failed", "error", err, "kind", key.Kind) return nil, err @@ -672,7 +696,7 @@ func (c *Client) Delete(ctx context.Context, key *Key) error { } // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:commit", apiURL, neturl.PathEscape(c.projectID)) + reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID)) if _, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID); err != nil { c.logger.ErrorContext(ctx, "delete request failed", "error", err, "kind", key.Kind) return err @@ -724,7 +748,7 @@ func (c *Client) GetMulti(ctx context.Context, keys []*Key, dst any) error { } // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:lookup", apiURL, neturl.PathEscape(c.projectID)) + reqURL := fmt.Sprintf("%s/projects/%s:lookup", c.baseURL, neturl.PathEscape(c.projectID)) body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) if err != nil { c.logger.ErrorContext(ctx, "lookup request failed", "error", err) @@ -841,7 +865,7 @@ func (c *Client) PutMulti(ctx context.Context, keys []*Key, src any) ([]*Key, er } // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:commit", apiURL, neturl.PathEscape(c.projectID)) + reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID)) if _, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID); err != nil { c.logger.ErrorContext(ctx, "commit request failed", "error", err) return nil, err @@ -895,7 +919,7 @@ func (c *Client) DeleteMulti(ctx context.Context, keys []*Key) error { } // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:commit", apiURL, neturl.PathEscape(c.projectID)) + reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID)) if _, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID); err != nil { c.logger.ErrorContext(ctx, "delete request failed", "error", err) return err @@ -985,7 +1009,7 @@ func (c *Client) AllocateIDs(ctx context.Context, keys []*Key) ([]*Key, error) { } // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:allocateIds", apiURL, neturl.PathEscape(c.projectID)) + reqURL := fmt.Sprintf("%s/projects/%s:allocateIds", c.baseURL, neturl.PathEscape(c.projectID)) body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) if err != nil { c.logger.ErrorContext(ctx, "allocateIds request failed", "error", err) @@ -1172,18 +1196,18 @@ func encodeValue(v any) (any, error) { if rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array { length := rv.Len() values := make([]map[string]any, length) - for i := 0; i < length; i++ { + for i := range length { elem := rv.Index(i).Interface() encodedElem, err := encodeValue(elem) if err != nil { return nil, fmt.Errorf("failed to encode array element %d: %w", i, err) } // encodedElem is already a map[string]any with the type wrapper - if m, ok := encodedElem.(map[string]any); ok { - values[i] = m - } else { + m, ok := encodedElem.(map[string]any) + if !ok { return nil, fmt.Errorf("unexpected encoded value type for element %d", i) } + values[i] = m } return map[string]any{"arrayValue": map[string]any{"values": values}}, nil } @@ -1691,7 +1715,7 @@ func (c *Client) AllKeys(ctx context.Context, q *Query) ([]*Key, error) { } // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:runQuery", apiURL, neturl.PathEscape(c.projectID)) + reqURL := fmt.Sprintf("%s/projects/%s:runQuery", c.baseURL, neturl.PathEscape(c.projectID)) body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) if err != nil { c.logger.ErrorContext(ctx, "query request failed", "error", err, "kind", q.kind) @@ -1752,7 +1776,7 @@ func (c *Client) GetAll(ctx context.Context, query *Query, dst any) ([]*Key, err } // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:runQuery", apiURL, neturl.PathEscape(c.projectID)) + reqURL := fmt.Sprintf("%s/projects/%s:runQuery", c.baseURL, neturl.PathEscape(c.projectID)) body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) if err != nil { c.logger.ErrorContext(ctx, "query request failed", "error", err, "kind", query.kind) @@ -1846,7 +1870,7 @@ func (c *Client) Count(ctx context.Context, q *Query) (int, error) { } // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:runAggregationQuery", apiURL, neturl.PathEscape(c.projectID)) + reqURL := fmt.Sprintf("%s/projects/%s:runAggregationQuery", c.baseURL, neturl.PathEscape(c.projectID)) body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) if err != nil { c.logger.ErrorContext(ctx, "count query failed", "error", err, "kind", q.kind) @@ -2052,7 +2076,7 @@ func (c *Client) Mutate(ctx context.Context, muts ...*Mutation) ([]*Key, error) } // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:commit", apiURL, neturl.PathEscape(c.projectID)) + reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID)) body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) if err != nil { c.logger.ErrorContext(ctx, "mutate request failed", "error", err) @@ -2227,7 +2251,7 @@ func (c *Client) NewTransaction(ctx context.Context, opts ...TransactionOption) } // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:beginTransaction", apiURL, neturl.PathEscape(c.projectID)) + reqURL := fmt.Sprintf("%s/projects/%s:beginTransaction", c.baseURL, neturl.PathEscape(c.projectID)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(jsonData)) if err != nil { return nil, err @@ -2322,7 +2346,7 @@ func (c *Client) RunInTransaction(ctx context.Context, f func(*Transaction) erro } // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:beginTransaction", apiURL, neturl.PathEscape(c.projectID)) + reqURL := fmt.Sprintf("%s/projects/%s:beginTransaction", c.baseURL, neturl.PathEscape(c.projectID)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(jsonData)) if err != nil { return nil, err @@ -2447,7 +2471,7 @@ func (tx *Transaction) Get(key *Key, dst any) error { } // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:lookup", apiURL, neturl.PathEscape(tx.client.projectID)) + reqURL := fmt.Sprintf("%s/projects/%s:lookup", tx.client.baseURL, neturl.PathEscape(tx.client.projectID)) req, err := http.NewRequestWithContext(tx.ctx, http.MethodPost, reqURL, bytes.NewReader(jsonData)) if err != nil { return err @@ -2734,7 +2758,7 @@ func (tx *Transaction) doCommit(ctx context.Context, token string) error { } // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:commit", apiURL, neturl.PathEscape(tx.client.projectID)) + reqURL := fmt.Sprintf("%s/projects/%s:commit", tx.client.baseURL, neturl.PathEscape(tx.client.projectID)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(jsonData)) if err != nil { return err diff --git a/datastore_test.go b/datastore_test.go index fe300c7..7676a23 100644 --- a/datastore_test.go +++ b/datastore_test.go @@ -8016,8 +8016,8 @@ func TestMutate(t *testing.T) { t.Fatalf("Mutate with no mutations failed: %v", err) } - if keys != nil && len(keys) != 0 { - t.Errorf("Expected nil or empty keys, got %d", len(keys)) + if len(keys) != 0 { + t.Errorf("Expected empty keys, got %d", len(keys)) } }) } diff --git a/ds9mock/mock.go b/ds9mock/mock.go index cfa5fc4..9deab8d 100644 --- a/ds9mock/mock.go +++ b/ds9mock/mock.go @@ -32,6 +32,8 @@ import ( const metadataFlavor = "Google" // Store holds the in-memory entity storage. +// +//nolint:govet // Field alignment not optimized to maintain readability type Store struct { mu sync.RWMutex entities map[string]map[string]any @@ -231,6 +233,8 @@ func (s *Store) handleLookup(w http.ResponseWriter, r *http.Request) { } // handleCommit handles commit (put/delete) requests. +// +//nolint:gocognit,maintidx // Complex logic required for handling multiple mutation types func (s *Store) handleCommit(w http.ResponseWriter, r *http.Request) { var req struct { Mode string `json:"mode"` @@ -553,7 +557,7 @@ func handleBeginTransaction(w http.ResponseWriter, r *http.Request) { } // handleAllocateIDs handles :allocateIds requests. -func (s *Store) handleAllocateIDs(w http.ResponseWriter, r *http.Request) { +func (*Store) handleAllocateIDs(w http.ResponseWriter, r *http.Request) { var req struct { DatabaseID string `json:"databaseId"` Keys []map[string]any `json:"keys"` @@ -619,6 +623,8 @@ func (s *Store) handleAllocateIDs(w http.ResponseWriter, r *http.Request) { } // matchesFilter checks if an entity matches a filter. +// +//nolint:gocognit,nestif // Complex logic required for proper filter evaluation with multiple types and operators func matchesFilter(entity map[string]any, filterMap map[string]any) bool { // Handle propertyFilter if propFilter, ok := filterMap["propertyFilter"].(map[string]any); ok { @@ -702,6 +708,8 @@ func matchesFilter(entity map[string]any, filterMap map[string]any) bool { return ev <= fv } } + default: + return false } } @@ -716,7 +724,8 @@ func matchesFilter(entity map[string]any, filterMap map[string]any) bool { return true } - if op == "AND" { + switch op { + case "AND": for _, f := range filters { if fm, ok := f.(map[string]any); ok { if !matchesFilter(entity, fm) { @@ -725,7 +734,7 @@ func matchesFilter(entity map[string]any, filterMap map[string]any) bool { } } return true - } else if op == "OR" { + case "OR": for _, f := range filters { if fm, ok := f.(map[string]any); ok { if matchesFilter(entity, fm) { @@ -734,6 +743,8 @@ func matchesFilter(entity map[string]any, filterMap map[string]any) bool { } } return false + default: + return true } } diff --git a/ds9mock/mock_test.go b/ds9mock/mock_test.go index 1689b46..050825a 100644 --- a/ds9mock/mock_test.go +++ b/ds9mock/mock_test.go @@ -351,3 +351,97 @@ func TestMockDeleteNonExistent(t *testing.T) { t.Errorf("Delete of non-existent entity should not error, got: %v", err) } } + +func TestMockConcurrentAccess(t *testing.T) { + client, cleanup := NewClient(t) + defer cleanup() + + ctx := context.Background() + + type TestEntity struct { + Value int64 `datastore:"value"` + } + + // Run concurrent operations to stress-test locking + const goroutines = 50 + const operations = 100 + + done := make(chan bool) + + for g := range goroutines { + go func(id int) { + defer func() { done <- true }() + + for i := range operations { + key := ds9.NameKey("ConcurrentKind", string(rune('a'+id%10)), nil) + entity := &TestEntity{Value: int64(i)} + + // Mix of reads and writes + if i%3 == 0 { + // Write + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Errorf("goroutine %d: Put failed: %v", id, err) + return + } + } else { + // Read - may fail if entity doesn't exist, which is expected + var result TestEntity + client.Get(ctx, key, &result) //nolint:errcheck // Expected to fail when entity doesn't exist + } + } + }(g) + } + + // Wait for all goroutines + for range goroutines { + <-done + } +} + +func TestMockConcurrentQuery(t *testing.T) { + client, cleanup := NewClient(t) + defer cleanup() + + ctx := context.Background() + + type TestEntity struct { + Name string `datastore:"name"` + } + + // Populate some data + for i := range 20 { + key := ds9.NameKey("QueryConcurrent", string(rune('a'+i)), nil) + entity := &TestEntity{Name: "test"} + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Run concurrent queries + const goroutines = 20 + done := make(chan bool) + + for range goroutines { + go func() { + defer func() { done <- true }() + + query := ds9.NewQuery("QueryConcurrent").KeysOnly() + keys, err := client.AllKeys(ctx, query) + if err != nil { + t.Errorf("AllKeys failed: %v", err) + return + } + + if len(keys) != 20 { + t.Errorf("expected 20 keys, got %d", len(keys)) + } + }() + } + + // Wait for all goroutines + for range goroutines { + <-done + } +} From df1d8f1f09367e157e766a98e5e7ce71caf87767 Mon Sep 17 00:00:00 2001 From: Thomas Stromberg Date: Wed, 29 Oct 2025 09:19:18 -0400 Subject: [PATCH 2/2] Reorganize and improve --- README.md | 36 +- datastore.go | 2799 ------ datastore_test.go | 8023 ----------------- example/main.go | 12 +- .../integration_test.go | 79 +- pkg/datastore/client.go | 142 + pkg/datastore/client_test.go | 207 + pkg/datastore/common_test.go | 23 + pkg/datastore/comprehensive_coverage_test.go | 336 + pkg/datastore/cursor.go | 21 + pkg/datastore/cursor_coverage_test.go | 138 + pkg/datastore/cursor_test.go | 72 + pkg/datastore/encode_coverage_test.go | 160 + pkg/datastore/entity.go | 313 + pkg/datastore/entity_coverage_test.go | 311 + pkg/datastore/entity_test.go | 816 ++ pkg/datastore/errors.go | 11 + pkg/datastore/http.go | 122 + pkg/datastore/http_test.go | 517 ++ pkg/datastore/iterator.go | 156 + pkg/datastore/iterator_coverage_test.go | 227 + pkg/datastore/iterator_test.go | 96 + pkg/datastore/key.go | 225 + key_test.go => pkg/datastore/key_test.go | 2 +- pkg/datastore/mock_client.go | 38 + pkg/datastore/mutation.go | 198 + pkg/datastore/mutation_test.go | 180 + pkg/datastore/operations.go | 498 + pkg/datastore/operations_coverage_test.go | 354 + pkg/datastore/operations_test.go | 4001 ++++++++ pkg/datastore/query.go | 557 ++ pkg/datastore/query_coverage_test.go | 386 + pkg/datastore/query_test.go | 559 ++ .../datastore/query_unit_test.go | 24 +- pkg/datastore/transaction.go | 655 ++ pkg/datastore/transaction_coverage_test.go | 284 + pkg/datastore/transaction_test.go | 1663 ++++ {ds9mock => pkg/mock}/mock.go | 113 +- {ds9mock => pkg/mock}/mock_test.go | 78 +- transaction_test.go | 402 - 40 files changed, 13472 insertions(+), 11362 deletions(-) delete mode 100644 datastore.go delete mode 100644 datastore_test.go rename integration_test.go => integration/integration_test.go (84%) create mode 100644 pkg/datastore/client.go create mode 100644 pkg/datastore/client_test.go create mode 100644 pkg/datastore/common_test.go create mode 100644 pkg/datastore/comprehensive_coverage_test.go create mode 100644 pkg/datastore/cursor.go create mode 100644 pkg/datastore/cursor_coverage_test.go create mode 100644 pkg/datastore/cursor_test.go create mode 100644 pkg/datastore/encode_coverage_test.go create mode 100644 pkg/datastore/entity.go create mode 100644 pkg/datastore/entity_coverage_test.go create mode 100644 pkg/datastore/entity_test.go create mode 100644 pkg/datastore/errors.go create mode 100644 pkg/datastore/http.go create mode 100644 pkg/datastore/http_test.go create mode 100644 pkg/datastore/iterator.go create mode 100644 pkg/datastore/iterator_coverage_test.go create mode 100644 pkg/datastore/iterator_test.go create mode 100644 pkg/datastore/key.go rename key_test.go => pkg/datastore/key_test.go (99%) create mode 100644 pkg/datastore/mock_client.go create mode 100644 pkg/datastore/mutation.go create mode 100644 pkg/datastore/mutation_test.go create mode 100644 pkg/datastore/operations.go create mode 100644 pkg/datastore/operations_coverage_test.go create mode 100644 pkg/datastore/operations_test.go create mode 100644 pkg/datastore/query.go create mode 100644 pkg/datastore/query_coverage_test.go create mode 100644 pkg/datastore/query_test.go rename query_test.go => pkg/datastore/query_unit_test.go (92%) create mode 100644 pkg/datastore/transaction.go create mode 100644 pkg/datastore/transaction_coverage_test.go create mode 100644 pkg/datastore/transaction_test.go rename {ds9mock => pkg/mock}/mock.go (88%) rename {ds9mock => pkg/mock}/mock_test.go (80%) delete mode 100644 transaction_test.go diff --git a/README.md b/README.md index bf2fce7..f8f5f2b 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # ds9 -Zero-dependency Google Cloud Datastore client for Go. Drop-in replacement for `cloud.google.com/go/datastore` basic operations. +Zero-dependency Google Cloud Datastore client for Go. Drop-in replacement for `cloud.google.com/go/datastore` basic operations. In-memory mock implementation. Comprehensive testing. **Why?** The official client has 50+ dependencies. `ds9` uses only Go stdlib—ideal for lightweight services and minimizing supply chain risk. @@ -12,16 +12,24 @@ go get github.com/codeGROOVE-dev/ds9 ## Quick Start +This isn't the API we would choose, but our primary goal was a drop-in replacement, so usage is exactly the same as the cloud.google.com/go/datastore library: + ```go -import "github.com/codeGROOVE-dev/ds9" +import "github.com/codeGROOVE-dev/ds9/pkg/datastore" -client, _ := ds9.NewClient(ctx, "my-project") -key := ds9.NameKey("Task", "task-1", nil) +client, _ := datastore.NewClient(ctx, "my-project") +key := datastore.NameKey("Task", "task-1", nil) client.Put(ctx, key, &task) client.Get(ctx, key, &task) ``` -**Supported:** +## Migrating from cloud.google.com/go/datastore + +Just switch the import path from `cloud.google.com/go/datastore` to `github.com/codeGROOVE-dev/ds9/pkg/datastore`. + +## Features + +**Supported Features** - **CRUD**: Get, Put, Delete, GetMulti, PutMulti, DeleteMulti - **Transactions**: RunInTransaction, NewTransaction, Commit, Rollback - **Queries**: Filter, Order, Limit, Offset, Ancestor, Project, Distinct, DistinctOn, Namespace, Run (iterator), Count @@ -30,18 +38,14 @@ client.Get(ctx, key, &task) - **Mutations**: NewInsert, NewUpdate, NewUpsert, NewDelete, Mutate - **Types**: string, int, int64, int32, bool, float64, time.Time, slices ([]string, []int64, []int, []float64, []bool) -## Migrating from Official Client - -Change the import—API is compatible: -```go -// import "cloud.google.com/go/datastore" -import "github.com/codeGROOVE-dev/ds9" -``` +**Unsupported Features** -Use `ds9mock` package for in-memory testing. See [TESTING.md](TESTING.md) for integration tests. +These features are unsupported just because we haven't found a use for the feature yet. PRs welcome: -## Limitations +* Embedded structs, nested slices, map types, some advanced query features (streaming aggregations, OR filters). -Not supported: embedded structs, nested slices, map types, some advanced query features (streaming aggregations, OR filters). +## Testing -See [example/](example/) for usage. Apache 2.0 licensed. +* Use `github.com/codeGROOVE-dev/ds9/pkg/mock` package for in-memory testing. It should work even if you choose not to use ds9. +* See [TESTING.md](TESTING.md) for integration tests. +* We aim to maintain 85% test coverage - please don't send PRs without tests. diff --git a/datastore.go b/datastore.go deleted file mode 100644 index 93765bd..0000000 --- a/datastore.go +++ /dev/null @@ -1,2799 +0,0 @@ -// Package ds9 provides a zero-dependency Google Cloud Datastore client. -// -// It uses only the Go standard library and makes direct REST API calls -// to the Datastore API. Authentication is handled via the GCP metadata -// server when running on GCP, or via Application Default Credentials. -// -//nolint:revive // Public structs required for API compatibility with cloud.google.com/go/datastore -package ds9 - -import ( - "bytes" - "context" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "log/slog" - "math" - "math/rand/v2" - "net/http" - neturl "net/url" - "reflect" - "strconv" - "strings" - "sync/atomic" - "testing" - "time" - - "github.com/codeGROOVE-dev/ds9/auth" -) - -const ( - maxRetries = 3 - maxBodySize = 10 * 1024 * 1024 // 10MB - defaultTimeout = 30 * time.Second - baseBackoffMS = 100 // Start with 100ms - maxBackoffMS = 2000 // Cap at 2 seconds - jitterFraction = 0.25 // 25% jitter -) - -var ( - // ErrNoSuchEntity is returned when an entity is not found. - ErrNoSuchEntity = errors.New("datastore: no such entity") - - // atomicAPIURL stores the API URL for thread-safe access. - // Use getAPIURL() to read and setAPIURL() to write. - atomicAPIURL atomic.Pointer[string] - - httpClient = &http.Client{ - Timeout: defaultTimeout, - Transport: &http.Transport{ - MaxIdleConns: 10, - IdleConnTimeout: 30 * time.Second, - MaxIdleConnsPerHost: 2, - }, - } - - // operatorMap converts shorthand operators to Datastore API operators. - operatorMap = map[string]string{ - "=": "EQUAL", - "<": "LESS_THAN", - "<=": "LESS_THAN_OR_EQUAL", - ">": "GREATER_THAN", - ">=": "GREATER_THAN_OR_EQUAL", - } -) - -//nolint:gochecknoinits // Required for thread-safe initialization of atomic pointer -func init() { - defaultURL := "https://datastore.googleapis.com/v1" - atomicAPIURL.Store(&defaultURL) -} - -// getAPIURL returns the current API URL in a thread-safe manner. -func getAPIURL() string { - return *atomicAPIURL.Load() -} - -// setAPIURL sets the API URL in a thread-safe manner. -func setAPIURL(url string) { - atomicAPIURL.Store(&url) -} - -// SetTestURLs configures custom metadata and API URLs for testing. -// This is intended for use by testing packages like ds9mock. -// Returns a function that restores the original URLs. -// WARNING: This function should only be called in test code. -// Set DS9_ALLOW_TEST_OVERRIDES=true to enable in non-test environments. -// -// Example: -// -// restore := ds9.SetTestURLs("http://localhost:8080", "http://localhost:9090") -// defer restore() -func SetTestURLs(metadata, api string) (restore func()) { - // Auth package will log warning if called outside test environment - oldAPI := getAPIURL() - setAPIURL(api) - restoreAuth := auth.SetMetadataURL(metadata) - return func() { - setAPIURL(oldAPI) - restoreAuth() - } -} - -// Client is a Google Cloud Datastore client. -type Client struct { - logger *slog.Logger - projectID string - databaseID string - baseURL string // API base URL, defaults to production but can be overridden for testing -} - -// NewClient creates a new Datastore client. -// If projectID is empty, it will be fetched from the GCP metadata server. -func NewClient(ctx context.Context, projectID string) (*Client, error) { - return NewClientWithDatabase(ctx, projectID, "") -} - -// NewClientWithDatabase creates a new Datastore client with a specific database. -func NewClientWithDatabase(ctx context.Context, projID, dbID string) (*Client, error) { - logger := slog.Default() - - if projID == "" { - if !testing.Testing() { - logger.InfoContext(ctx, "project ID not provided, fetching from metadata server") - } - pid, err := auth.ProjectID(ctx) - if err != nil { - logger.ErrorContext(ctx, "failed to get project ID from metadata server", "error", err) - return nil, fmt.Errorf("project ID required: %w", err) - } - projID = pid - if !testing.Testing() { - logger.InfoContext(ctx, "fetched project ID from metadata server", "project_id", projID) - } - } - - if !testing.Testing() { - logger.InfoContext(ctx, "creating datastore client", "project_id", projID, "database_id", dbID) - } - - return &Client{ - projectID: projID, - databaseID: dbID, - baseURL: getAPIURL(), - logger: logger, - }, nil -} - -// Close closes the client connection. -// This is a no-op for ds9 since it uses a shared HTTP client with connection pooling, -// but is provided for API compatibility with cloud.google.com/go/datastore. -func (*Client) Close() error { - return nil -} - -// Key represents a Datastore key. -type Key struct { - Parent *Key // Parent key for hierarchical keys - Kind string - Name string // For string keys - ID int64 // For numeric keys -} - -// NameKey creates a new key with a string name. -// The parent parameter can be nil for top-level keys. -// This matches the API of cloud.google.com/go/datastore. -func NameKey(kind, name string, parent *Key) *Key { - return &Key{ - Kind: kind, - Name: name, - Parent: parent, - } -} - -// IDKey creates a new key with a numeric ID. -// The parent parameter can be nil for top-level keys. -// This matches the API of cloud.google.com/go/datastore. -func IDKey(kind string, id int64, parent *Key) *Key { - return &Key{ - Kind: kind, - ID: id, - Parent: parent, - } -} - -// IncompleteKey creates a new incomplete key. -// The key will be completed (assigned an ID) when the entity is saved. -// API compatible with cloud.google.com/go/datastore. -func IncompleteKey(kind string, parent *Key) *Key { - return &Key{ - Kind: kind, - Parent: parent, - } -} - -// Incomplete returns true if the key does not have an ID or Name. -// API compatible with cloud.google.com/go/datastore. -func (k *Key) Incomplete() bool { - return k.ID == 0 && k.Name == "" -} - -// Equal returns true if this key is equal to the other key. -// API compatible with cloud.google.com/go/datastore. -func (k *Key) Equal(other *Key) bool { - if k == nil && other == nil { - return true - } - if k == nil || other == nil { - return false - } - if k.Kind != other.Kind || k.Name != other.Name || k.ID != other.ID { - return false - } - // Recursively check parent keys - return k.Parent.Equal(other.Parent) -} - -// String returns a human-readable string representation of the key. -// API compatible with cloud.google.com/go/datastore. -func (k *Key) String() string { - if k == nil { - return "" - } - - var parts []string - for curr := k; curr != nil; curr = curr.Parent { - var part string - switch { - case curr.Name != "": - part = fmt.Sprintf("%s,%q", curr.Kind, curr.Name) - case curr.ID != 0: - part = fmt.Sprintf("%s,%d", curr.Kind, curr.ID) - default: - part = fmt.Sprintf("%s,incomplete", curr.Kind) - } - // Prepend to maintain correct order (root to leaf) - parts = append([]string{part}, parts...) - } - - return "/" + strings.Join(parts, "/") -} - -// Encode returns an opaque representation of the key. -// API compatible with cloud.google.com/go/datastore. -func (k *Key) Encode() string { - if k == nil { - return "" - } - - // Convert key to JSON representation - keyJSON := keyToJSON(k) - - // Marshal to JSON bytes - jsonBytes, err := json.Marshal(keyJSON) - if err != nil { - return "" - } - - // Base64 encode - return base64.URLEncoding.EncodeToString(jsonBytes) -} - -// DecodeKey decodes a key from its opaque representation. -// API compatible with cloud.google.com/go/datastore. -func DecodeKey(encoded string) (*Key, error) { - if encoded == "" { - return nil, errors.New("empty encoded key") - } - - // Base64 decode - jsonBytes, err := base64.URLEncoding.DecodeString(encoded) - if err != nil { - return nil, fmt.Errorf("failed to decode base64: %w", err) - } - - // Unmarshal JSON - var keyData any - if err := json.Unmarshal(jsonBytes, &keyData); err != nil { - return nil, fmt.Errorf("failed to unmarshal JSON: %w", err) - } - - // Convert from JSON representation - key, err := keyFromJSON(keyData) - if err != nil { - return nil, fmt.Errorf("failed to parse key: %w", err) - } - - return key, nil -} - -// Cursor represents a query cursor for pagination. -// API compatible with cloud.google.com/go/datastore. -type Cursor string - -// String returns the cursor as a string. -func (c Cursor) String() string { - return string(c) -} - -// DecodeCursor decodes a cursor string. -// API compatible with cloud.google.com/go/datastore. -func DecodeCursor(s string) (Cursor, error) { - if s == "" { - return "", errors.New("empty cursor string") - } - return Cursor(s), nil -} - -// Iterator is an iterator for query results. -// API compatible with cloud.google.com/go/datastore. -// -//nolint:govet // Field alignment optimized for API compatibility over memory layout -type Iterator struct { - ctx context.Context //nolint:containedctx // Required for API compatibility with cloud.google.com/go/datastore - client *Client - query *Query - results []iteratorResult - index int - err error - cursor Cursor - fetchNext bool -} - -type iteratorResult struct { - key *Key - entity map[string]any - cursor Cursor -} - -// Next advances the iterator and returns the next key and destination. -// It returns Done when no more results are available. -// API compatible with cloud.google.com/go/datastore. -func (it *Iterator) Next(dst any) (*Key, error) { - // Check if we need to fetch more results - if it.index >= len(it.results) { - if it.err != nil { - return nil, it.err - } - if !it.fetchNext { - return nil, ErrDone - } - - // Fetch next batch - if err := it.fetch(); err != nil { - it.err = err - return nil, err - } - - if len(it.results) == 0 { - return nil, ErrDone - } - } - - result := it.results[it.index] - it.index++ - it.cursor = result.cursor - - // Decode entity into dst - if err := decodeEntity(result.entity, dst); err != nil { - return nil, err - } - - return result.key, nil -} - -// Cursor returns the cursor for the iterator's current position. -// API compatible with cloud.google.com/go/datastore. -func (it *Iterator) Cursor() (Cursor, error) { - if it.cursor == "" { - return "", errors.New("no cursor available") - } - return it.cursor, nil -} - -// fetch retrieves the next batch of results. -func (it *Iterator) fetch() error { - token, err := auth.AccessToken(it.ctx) - if err != nil { - return fmt.Errorf("failed to get access token: %w", err) - } - - // Build query with current cursor as start - q := *it.query - if it.cursor != "" { - q.startCursor = it.cursor - } - - queryObj := buildQueryMap(&q) - reqBody := map[string]any{"query": queryObj} - if it.client.databaseID != "" { - reqBody["databaseId"] = it.client.databaseID - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return fmt.Errorf("failed to marshal request: %w", err) - } - - // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:runQuery", it.client.baseURL, neturl.PathEscape(it.client.projectID)) - body, err := doRequest(it.ctx, it.client.logger, reqURL, jsonData, token, it.client.projectID, it.client.databaseID) - if err != nil { - return err - } - - var result struct { - Batch struct { - EntityResults []struct { - Entity map[string]any `json:"entity"` - Cursor string `json:"cursor"` - } `json:"entityResults"` - MoreResults string `json:"moreResults"` - EndCursor string `json:"endCursor"` - SkippedResults int `json:"skippedResults"` - } `json:"batch"` - } - - if err := json.Unmarshal(body, &result); err != nil { - return fmt.Errorf("failed to parse response: %w", err) - } - - // Convert results to iterator format - it.results = make([]iteratorResult, 0, len(result.Batch.EntityResults)) - for _, er := range result.Batch.EntityResults { - key, err := keyFromJSON(er.Entity["key"]) - if err != nil { - return err - } - - it.results = append(it.results, iteratorResult{ - key: key, - entity: er.Entity, - cursor: Cursor(er.Cursor), - }) - } - - it.index = 0 - - // Check if there are more results - moreResults := result.Batch.MoreResults - it.fetchNext = moreResults == "NOT_FINISHED" || moreResults == "MORE_RESULTS_AFTER_LIMIT" || moreResults == "MORE_RESULTS_AFTER_CURSOR" - - if result.Batch.EndCursor != "" { - it.cursor = Cursor(result.Batch.EndCursor) - } - - return nil -} - -// ErrDone is returned by Iterator.Next when no more results are available. -var ErrDone = errors.New("datastore: no more results") - -// doRequest performs an HTTP request with exponential backoff retries. -// Returns an error if the status code is not 200 OK. -func doRequest(ctx context.Context, logger *slog.Logger, url string, jsonData []byte, token, projectID, databaseID string) ([]byte, error) { - var lastErr error - - for attempt := range maxRetries { - if attempt > 0 { - // Exponential backoff: 100ms, 200ms, 400ms... capped at maxBackoffMS - backoffMS := math.Min(float64(baseBackoffMS)*math.Pow(2, float64(attempt-1)), float64(maxBackoffMS)) - // Add jitter: ±25% randomness - jitter := backoffMS * jitterFraction * (2*rand.Float64() - 1) //nolint:gosec // Weak random is acceptable for jitter - sleepMS := backoffMS + jitter - sleepDuration := time.Duration(sleepMS) * time.Millisecond - - logger.DebugContext(ctx, "retrying request", - "attempt", attempt+1, - "max_attempts", maxRetries, - "backoff_ms", int(sleepMS), - "last_error", lastErr) - - select { - case <-time.After(sleepDuration): - case <-ctx.Done(): - return nil, ctx.Err() - } - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonData)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Authorization", "Bearer "+token) - req.Header.Set("Content-Type", "application/json") - - // Add routing header for named databases - if databaseID != "" { - // URL-encode values to prevent header injection attacks - routingHeader := fmt.Sprintf("project_id=%s&database_id=%s", neturl.QueryEscape(projectID), neturl.QueryEscape(databaseID)) - req.Header.Set("X-Goog-Request-Params", routingHeader) - } - - logger.DebugContext(ctx, "sending request", "url", url, "attempt", attempt+1) - - resp, err := httpClient.Do(req) - if err != nil { - lastErr = err - logger.WarnContext(ctx, "request failed", "error", err, "attempt", attempt+1) - if attempt == maxRetries-1 { - return nil, fmt.Errorf("request failed after %d attempts: %w", maxRetries, err) - } - continue - } - - // Always close response body - defer func() { //nolint:revive,gocritic // Defer in loop is intentional - loop exits after successful response - if closeErr := resp.Body.Close(); closeErr != nil { - logger.WarnContext(ctx, "failed to close response body", "error", closeErr) - } - }() - - body, err := io.ReadAll(io.LimitReader(resp.Body, maxBodySize)) - if err != nil { - lastErr = err - logger.WarnContext(ctx, "failed to read response body", "error", err, "attempt", attempt+1) - if attempt == maxRetries-1 { - return nil, fmt.Errorf("failed to read response after %d attempts: %w", maxRetries, err) - } - continue - } - - logger.DebugContext(ctx, "received response", - "status_code", resp.StatusCode, - "body_size", len(body), - "attempt", attempt+1) - - // Success - if resp.StatusCode == http.StatusOK { - return body, nil - } - - // Don't retry on 4xx errors (client errors) - if resp.StatusCode >= 400 && resp.StatusCode < 500 { - if resp.StatusCode == http.StatusNotFound { - logger.DebugContext(ctx, "entity not found", "status_code", resp.StatusCode) - } else { - logger.WarnContext(ctx, "client error", "status_code", resp.StatusCode, "body", string(body)) - } - return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(body)) - } - - // Unexpected 2xx/3xx status codes - if resp.StatusCode < 400 { - logger.WarnContext(ctx, "unexpected non-200 success status", "status_code", resp.StatusCode) - return nil, fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, string(body)) - } - - // 5xx errors - retry - lastErr = fmt.Errorf("server error: status %d", resp.StatusCode) - logger.WarnContext(ctx, "server error, will retry", - "status_code", resp.StatusCode, - "attempt", attempt+1, - "body", string(body)) - } - - return nil, fmt.Errorf("all %d attempts failed: %w", maxRetries, lastErr) -} - -// Get retrieves an entity by key and stores it in dst. -// dst must be a pointer to a struct. -// Returns ErrNoSuchEntity if the key is not found. -func (c *Client) Get(ctx context.Context, key *Key, dst any) error { - if key == nil { - c.logger.WarnContext(ctx, "Get called with nil key") - return errors.New("key cannot be nil") - } - - c.logger.DebugContext(ctx, "getting entity", "kind", key.Kind, "name", key.Name, "id", key.ID) - - token, err := auth.AccessToken(ctx) - if err != nil { - c.logger.ErrorContext(ctx, "failed to get access token", "error", err) - return fmt.Errorf("failed to get access token: %w", err) - } - - reqBody := map[string]any{ - "keys": []map[string]any{keyToJSON(key)}, - } - if c.databaseID != "" { - reqBody["databaseId"] = c.databaseID - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) - return fmt.Errorf("failed to marshal request: %w", err) - } - - // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:lookup", c.baseURL, neturl.PathEscape(c.projectID)) - body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) - if err != nil { - c.logger.ErrorContext(ctx, "lookup request failed", "error", err, "kind", key.Kind) - return err - } - - var result struct { - Found []struct { - Entity map[string]any `json:"entity"` - } `json:"found"` - } - - if err := json.Unmarshal(body, &result); err != nil { - c.logger.ErrorContext(ctx, "failed to parse response", "error", err) - return fmt.Errorf("failed to parse response: %w", err) - } - - if len(result.Found) == 0 { - c.logger.DebugContext(ctx, "entity not found", "kind", key.Kind, "name", key.Name, "id", key.ID) - return ErrNoSuchEntity - } - - c.logger.DebugContext(ctx, "entity retrieved successfully", "kind", key.Kind) - return decodeEntity(result.Found[0].Entity, dst) -} - -// Put stores an entity with the given key. -// src must be a struct or pointer to struct. -// Returns the key (useful for auto-generated IDs in the future). -func (c *Client) Put(ctx context.Context, key *Key, src any) (*Key, error) { - if key == nil { - c.logger.WarnContext(ctx, "Put called with nil key") - return nil, errors.New("key cannot be nil") - } - - c.logger.DebugContext(ctx, "putting entity", "kind", key.Kind, "name", key.Name, "id", key.ID) - - token, err := auth.AccessToken(ctx) - if err != nil { - c.logger.ErrorContext(ctx, "failed to get access token", "error", err) - return nil, fmt.Errorf("failed to get access token: %w", err) - } - - entity, err := encodeEntity(key, src) - if err != nil { - c.logger.ErrorContext(ctx, "failed to encode entity", "error", err, "kind", key.Kind) - return nil, err - } - - reqBody := map[string]any{ - "mode": "NON_TRANSACTIONAL", - "mutations": []map[string]any{{"upsert": entity}}, - } - if c.databaseID != "" { - reqBody["databaseId"] = c.databaseID - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID)) - if _, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID); err != nil { - c.logger.ErrorContext(ctx, "commit request failed", "error", err, "kind", key.Kind) - return nil, err - } - - c.logger.DebugContext(ctx, "entity stored successfully", "kind", key.Kind) - return key, nil -} - -// Delete deletes the entity with the given key. -func (c *Client) Delete(ctx context.Context, key *Key) error { - if key == nil { - c.logger.WarnContext(ctx, "Delete called with nil key") - return errors.New("key cannot be nil") - } - - c.logger.DebugContext(ctx, "deleting entity", "kind", key.Kind, "name", key.Name, "id", key.ID) - - token, err := auth.AccessToken(ctx) - if err != nil { - c.logger.ErrorContext(ctx, "failed to get access token", "error", err) - return fmt.Errorf("failed to get access token: %w", err) - } - - reqBody := map[string]any{ - "mode": "NON_TRANSACTIONAL", - "mutations": []map[string]any{{"delete": keyToJSON(key)}}, - } - if c.databaseID != "" { - reqBody["databaseId"] = c.databaseID - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) - return fmt.Errorf("failed to marshal request: %w", err) - } - - // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID)) - if _, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID); err != nil { - c.logger.ErrorContext(ctx, "delete request failed", "error", err, "kind", key.Kind) - return err - } - - c.logger.DebugContext(ctx, "entity deleted successfully", "kind", key.Kind) - return nil -} - -// GetMulti retrieves multiple entities by their keys. -// dst must be a pointer to a slice of structs. -// Returns ErrNoSuchEntity if any key is not found. -// This matches the API of cloud.google.com/go/datastore. -func (c *Client) GetMulti(ctx context.Context, keys []*Key, dst any) error { - if len(keys) == 0 { - c.logger.WarnContext(ctx, "GetMulti called with no keys") - return errors.New("keys cannot be empty") - } - - c.logger.DebugContext(ctx, "getting multiple entities", "count", len(keys)) - - token, err := auth.AccessToken(ctx) - if err != nil { - c.logger.ErrorContext(ctx, "failed to get access token", "error", err) - return fmt.Errorf("failed to get access token: %w", err) - } - - // Build keys array - jsonKeys := make([]map[string]any, len(keys)) - for i, key := range keys { - if key == nil { - c.logger.WarnContext(ctx, "GetMulti called with nil key", "index", i) - return fmt.Errorf("key at index %d cannot be nil", i) - } - jsonKeys[i] = keyToJSON(key) - } - - reqBody := map[string]any{ - "keys": jsonKeys, - } - if c.databaseID != "" { - reqBody["databaseId"] = c.databaseID - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) - return fmt.Errorf("failed to marshal request: %w", err) - } - - // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:lookup", c.baseURL, neturl.PathEscape(c.projectID)) - body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) - if err != nil { - c.logger.ErrorContext(ctx, "lookup request failed", "error", err) - return err - } - - var result struct { - Found []struct { - Entity map[string]any `json:"entity"` - } `json:"found"` - Missing []struct { - Entity map[string]any `json:"entity"` - } `json:"missing"` - } - - if err := json.Unmarshal(body, &result); err != nil { - c.logger.ErrorContext(ctx, "failed to parse response", "error", err) - return fmt.Errorf("failed to parse response: %w", err) - } - - if len(result.Missing) > 0 { - c.logger.DebugContext(ctx, "some entities not found", "missing_count", len(result.Missing)) - return ErrNoSuchEntity - } - - // Decode into slice - v := reflect.ValueOf(dst) - if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Slice { - return errors.New("dst must be a pointer to slice") - } - - sliceType := v.Elem().Type() - elemType := sliceType.Elem() - - // Create new slice of correct size - slice := reflect.MakeSlice(sliceType, 0, len(result.Found)) - - for _, found := range result.Found { - elem := reflect.New(elemType).Elem() - if err := decodeEntity(found.Entity, elem.Addr().Interface()); err != nil { - c.logger.ErrorContext(ctx, "failed to decode entity", "error", err) - return err - } - slice = reflect.Append(slice, elem) - } - - v.Elem().Set(slice) - c.logger.DebugContext(ctx, "entities retrieved successfully", "count", len(result.Found)) - return nil -} - -// PutMulti stores multiple entities with their keys. -// keys and src must have the same length. -// Returns the keys (same as input) and any error. -// This matches the API of cloud.google.com/go/datastore. -func (c *Client) PutMulti(ctx context.Context, keys []*Key, src any) ([]*Key, error) { - if len(keys) == 0 { - c.logger.WarnContext(ctx, "PutMulti called with no keys") - return nil, errors.New("keys cannot be empty") - } - - c.logger.DebugContext(ctx, "putting multiple entities", "count", len(keys)) - - // Verify src is a slice - v := reflect.ValueOf(src) - if v.Kind() == reflect.Ptr { - v = v.Elem() - } - if v.Kind() != reflect.Slice { - return nil, errors.New("src must be a slice") - } - - if v.Len() != len(keys) { - return nil, fmt.Errorf("keys and src length mismatch: %d != %d", len(keys), v.Len()) - } - - token, err := auth.AccessToken(ctx) - if err != nil { - c.logger.ErrorContext(ctx, "failed to get access token", "error", err) - return nil, fmt.Errorf("failed to get access token: %w", err) - } - - // Build mutations - mutations := make([]map[string]any, len(keys)) - for i, key := range keys { - if key == nil { - c.logger.WarnContext(ctx, "PutMulti called with nil key", "index", i) - return nil, fmt.Errorf("key at index %d cannot be nil", i) - } - - entity, err := encodeEntity(key, v.Index(i).Interface()) - if err != nil { - c.logger.ErrorContext(ctx, "failed to encode entity", "error", err, "index", i) - return nil, fmt.Errorf("failed to encode entity at index %d: %w", i, err) - } - - mutations[i] = map[string]any{ - "upsert": entity, - } - } - - reqBody := map[string]any{ - "mode": "NON_TRANSACTIONAL", - "mutations": mutations, - } - if c.databaseID != "" { - reqBody["databaseId"] = c.databaseID - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID)) - if _, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID); err != nil { - c.logger.ErrorContext(ctx, "commit request failed", "error", err) - return nil, err - } - - c.logger.DebugContext(ctx, "entities stored successfully", "count", len(keys)) - return keys, nil -} - -// DeleteMulti deletes multiple entities with their keys. -// This matches the API of cloud.google.com/go/datastore. -func (c *Client) DeleteMulti(ctx context.Context, keys []*Key) error { - if len(keys) == 0 { - c.logger.WarnContext(ctx, "DeleteMulti called with no keys") - return errors.New("keys cannot be empty") - } - - c.logger.DebugContext(ctx, "deleting multiple entities", "count", len(keys)) - - token, err := auth.AccessToken(ctx) - if err != nil { - c.logger.ErrorContext(ctx, "failed to get access token", "error", err) - return fmt.Errorf("failed to get access token: %w", err) - } - - // Build mutations - mutations := make([]map[string]any, len(keys)) - for i, key := range keys { - if key == nil { - c.logger.WarnContext(ctx, "DeleteMulti called with nil key", "index", i) - return fmt.Errorf("key at index %d cannot be nil", i) - } - - mutations[i] = map[string]any{ - "delete": keyToJSON(key), - } - } - - reqBody := map[string]any{ - "mode": "NON_TRANSACTIONAL", - "mutations": mutations, - } - if c.databaseID != "" { - reqBody["databaseId"] = c.databaseID - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) - return fmt.Errorf("failed to marshal request: %w", err) - } - - // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID)) - if _, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID); err != nil { - c.logger.ErrorContext(ctx, "delete request failed", "error", err) - return err - } - - c.logger.DebugContext(ctx, "entities deleted successfully", "count", len(keys)) - return nil -} - -// DeleteAllByKind deletes all entities of a given kind. -// This method queries for all keys and then deletes them in batches. -func (c *Client) DeleteAllByKind(ctx context.Context, kind string) error { - c.logger.InfoContext(ctx, "deleting all entities by kind", "kind", kind) - - // Query for all keys of this kind - q := NewQuery(kind).KeysOnly() - keys, err := c.AllKeys(ctx, q) - if err != nil { - c.logger.ErrorContext(ctx, "failed to query keys", "kind", kind, "error", err) - return fmt.Errorf("failed to query keys: %w", err) - } - - if len(keys) == 0 { - c.logger.InfoContext(ctx, "no entities found to delete", "kind", kind) - return nil - } - - // Delete all keys - if err := c.DeleteMulti(ctx, keys); err != nil { - c.logger.ErrorContext(ctx, "failed to delete entities", "kind", kind, "count", len(keys), "error", err) - return fmt.Errorf("failed to delete entities: %w", err) - } - - c.logger.InfoContext(ctx, "deleted all entities", "kind", kind, "count", len(keys)) - return nil -} - -// AllocateIDs allocates IDs for incomplete keys. -// Returns keys with IDs filled in. Complete keys are returned unchanged. -// API compatible with cloud.google.com/go/datastore. -func (c *Client) AllocateIDs(ctx context.Context, keys []*Key) ([]*Key, error) { - if len(keys) == 0 { - return keys, nil - } - - c.logger.DebugContext(ctx, "allocating IDs", "count", len(keys)) - - // Separate incomplete and complete keys - var incompleteKeys []*Key - var incompleteIndices []int - for i, key := range keys { - if key != nil && key.Incomplete() { - incompleteKeys = append(incompleteKeys, key) - incompleteIndices = append(incompleteIndices, i) - } - } - - // If no incomplete keys, return original slice - if len(incompleteKeys) == 0 { - c.logger.DebugContext(ctx, "no incomplete keys to allocate") - return keys, nil - } - - token, err := auth.AccessToken(ctx) - if err != nil { - c.logger.ErrorContext(ctx, "failed to get access token", "error", err) - return nil, fmt.Errorf("failed to get access token: %w", err) - } - - // Build request with incomplete keys - reqKeys := make([]map[string]any, len(incompleteKeys)) - for i, key := range incompleteKeys { - reqKeys[i] = keyToJSON(key) - } - - reqBody := map[string]any{ - "keys": reqKeys, - } - if c.databaseID != "" { - reqBody["databaseId"] = c.databaseID - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:allocateIds", c.baseURL, neturl.PathEscape(c.projectID)) - body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) - if err != nil { - c.logger.ErrorContext(ctx, "allocateIds request failed", "error", err) - return nil, err - } - - var resp struct { - Keys []map[string]any `json:"keys"` - } - if err := json.Unmarshal(body, &resp); err != nil { - c.logger.ErrorContext(ctx, "failed to parse response", "error", err) - return nil, fmt.Errorf("failed to parse allocateIds response: %w", err) - } - - // Parse allocated keys - allocatedKeys := make([]*Key, len(resp.Keys)) - for i, keyData := range resp.Keys { - key, err := keyFromJSON(keyData) - if err != nil { - c.logger.ErrorContext(ctx, "failed to parse allocated key", "index", i, "error", err) - return nil, fmt.Errorf("failed to parse allocated key at index %d: %w", i, err) - } - allocatedKeys[i] = key - } - - // Create result slice with allocated keys in correct positions - result := make([]*Key, len(keys)) - copy(result, keys) - for i, idx := range incompleteIndices { - result[idx] = allocatedKeys[i] - } - - c.logger.DebugContext(ctx, "IDs allocated successfully", "count", len(allocatedKeys)) - return result, nil -} - -// keyToJSON converts a Key to its JSON representation. -// Supports hierarchical keys with parent relationships. -func keyToJSON(key *Key) map[string]any { - // Build path from root to leaf (parent -> child) - var path []map[string]any - - // Collect all keys from root to leaf - keys := make([]*Key, 0) - for k := key; k != nil; k = k.Parent { - keys = append(keys, k) - } - - // Reverse to go from root to leaf - for i := len(keys) - 1; i >= 0; i-- { - k := keys[i] - elem := map[string]any{ - "kind": k.Kind, - } - - if k.Name != "" { - elem["name"] = k.Name - } else if k.ID != 0 { - elem["id"] = strconv.FormatInt(k.ID, 10) - } - - path = append(path, elem) - } - - return map[string]any{ - "path": path, - } -} - -// encodeEntity converts a Go struct to a Datastore entity. -func encodeEntity(key *Key, src any) (map[string]any, error) { - v := reflect.ValueOf(src) - if v.Kind() == reflect.Ptr { - v = v.Elem() - } - - if v.Kind() != reflect.Struct { - return nil, errors.New("src must be a struct or pointer to struct") - } - - t := v.Type() - properties := make(map[string]any) - - for i := range v.NumField() { - field := t.Field(i) - value := v.Field(i) - - // Skip unexported fields - if !field.IsExported() { - continue - } - - // Get field name from datastore tag or use field name - name := field.Name - noIndex := false - - if tag := field.Tag.Get("datastore"); tag != "" { - parts := strings.Split(tag, ",") - if parts[0] != "" && parts[0] != "-" { - name = parts[0] - } - if len(parts) > 1 && parts[1] == "noindex" { - noIndex = true - } - if parts[0] == "-" { - continue - } - } - - prop, err := encodeValue(value.Interface()) - if err != nil { - return nil, fmt.Errorf("field %s: %w", field.Name, err) - } - - if noIndex { - if m, ok := prop.(map[string]any); ok { - m["excludeFromIndexes"] = true - } - } - - properties[name] = prop - } - - return map[string]any{ - "key": keyToJSON(key), - "properties": properties, - }, nil -} - -// encodeValue converts a Go value to a Datastore property value. -func encodeValue(v any) (any, error) { - if v == nil { - return map[string]any{"nullValue": nil}, nil - } - - switch val := v.(type) { - case string: - return map[string]any{"stringValue": val}, nil - case int: - return map[string]any{"integerValue": strconv.Itoa(val)}, nil - case int64: - return map[string]any{"integerValue": strconv.FormatInt(val, 10)}, nil - case int32: - return map[string]any{"integerValue": strconv.Itoa(int(val))}, nil - case bool: - return map[string]any{"booleanValue": val}, nil - case float64: - return map[string]any{"doubleValue": val}, nil - case time.Time: - return map[string]any{"timestampValue": val.Format(time.RFC3339Nano)}, nil - case []string: - values := make([]map[string]any, len(val)) - for i, s := range val { - values[i] = map[string]any{"stringValue": s} - } - return map[string]any{"arrayValue": map[string]any{"values": values}}, nil - case []int64: - values := make([]map[string]any, len(val)) - for i, n := range val { - values[i] = map[string]any{"integerValue": strconv.FormatInt(n, 10)} - } - return map[string]any{"arrayValue": map[string]any{"values": values}}, nil - case []int: - values := make([]map[string]any, len(val)) - for i, n := range val { - values[i] = map[string]any{"integerValue": strconv.Itoa(n)} - } - return map[string]any{"arrayValue": map[string]any{"values": values}}, nil - case []float64: - values := make([]map[string]any, len(val)) - for i, f := range val { - values[i] = map[string]any{"doubleValue": f} - } - return map[string]any{"arrayValue": map[string]any{"values": values}}, nil - case []bool: - values := make([]map[string]any, len(val)) - for i, b := range val { - values[i] = map[string]any{"booleanValue": b} - } - return map[string]any{"arrayValue": map[string]any{"values": values}}, nil - default: - // Try to handle slices/arrays via reflection - rv := reflect.ValueOf(v) - if rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array { - length := rv.Len() - values := make([]map[string]any, length) - for i := range length { - elem := rv.Index(i).Interface() - encodedElem, err := encodeValue(elem) - if err != nil { - return nil, fmt.Errorf("failed to encode array element %d: %w", i, err) - } - // encodedElem is already a map[string]any with the type wrapper - m, ok := encodedElem.(map[string]any) - if !ok { - return nil, fmt.Errorf("unexpected encoded value type for element %d", i) - } - values[i] = m - } - return map[string]any{"arrayValue": map[string]any{"values": values}}, nil - } - return nil, fmt.Errorf("unsupported type: %T", v) - } -} - -// decodeEntity converts a Datastore entity to a Go struct. -func decodeEntity(entity map[string]any, dst any) error { - v := reflect.ValueOf(dst) - if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct { - return errors.New("dst must be a pointer to struct") - } - - v = v.Elem() - t := v.Type() - - properties, ok := entity["properties"].(map[string]any) - if !ok { - return errors.New("invalid entity format") - } - - for i := range v.NumField() { - field := t.Field(i) - value := v.Field(i) - - if !field.IsExported() { - continue - } - - // Get field name from datastore tag - name := field.Name - if tag := field.Tag.Get("datastore"); tag != "" { - parts := strings.Split(tag, ",") - if parts[0] != "" && parts[0] != "-" { - name = parts[0] - } - if parts[0] == "-" { - continue - } - } - - prop, ok := properties[name] - if !ok { - continue // Field not in entity - } - - propMap, ok := prop.(map[string]any) - if !ok { - continue - } - - if err := decodeValue(propMap, value); err != nil { - return fmt.Errorf("field %s: %w", field.Name, err) - } - } - - return nil -} - -// decodeValue decodes a Datastore property value into a Go reflect.Value. -func decodeValue(prop map[string]any, dst reflect.Value) error { - // Handle each type - if val, ok := prop["stringValue"]; ok { - if dst.Kind() == reflect.String { - if s, ok := val.(string); ok { - dst.SetString(s) - return nil - } - } - } - - if val, ok := prop["integerValue"]; ok { - var intVal int64 - switch v := val.(type) { - case string: - if _, err := fmt.Sscanf(v, "%d", &intVal); err != nil { - return fmt.Errorf("invalid integer format: %w", err) - } - case float64: - intVal = int64(v) - } - - switch dst.Kind() { - case reflect.Int, reflect.Int64, reflect.Int32: - dst.SetInt(intVal) - return nil - default: - return fmt.Errorf("unsupported integer type: %v", dst.Kind()) - } - } - - if val, ok := prop["booleanValue"]; ok { - if dst.Kind() == reflect.Bool { - if b, ok := val.(bool); ok { - dst.SetBool(b) - return nil - } - } - } - - if val, ok := prop["doubleValue"]; ok { - if dst.Kind() == reflect.Float64 { - if f, ok := val.(float64); ok { - dst.SetFloat(f) - return nil - } - } - } - - if val, ok := prop["timestampValue"]; ok { - if dst.Type() == reflect.TypeOf(time.Time{}) { - if s, ok := val.(string); ok { - t, err := time.Parse(time.RFC3339Nano, s) - if err != nil { - return err - } - dst.Set(reflect.ValueOf(t)) - return nil - } - } - } - - if val, ok := prop["arrayValue"]; ok { - if dst.Kind() != reflect.Slice { - return fmt.Errorf("cannot decode array into non-slice type: %s", dst.Type()) - } - - arrayMap, ok := val.(map[string]any) - if !ok { - return errors.New("invalid arrayValue format") - } - - valuesAny, ok := arrayMap["values"] - if !ok { - // Empty array - dst.Set(reflect.MakeSlice(dst.Type(), 0, 0)) - return nil - } - - values, ok := valuesAny.([]any) - if !ok { - return errors.New("invalid arrayValue.values format") - } - - // Create slice with appropriate capacity - slice := reflect.MakeSlice(dst.Type(), len(values), len(values)) - - // Decode each element - for i, elemAny := range values { - elemMap, ok := elemAny.(map[string]any) - if !ok { - return fmt.Errorf("invalid array element %d format", i) - } - - elemValue := slice.Index(i) - if err := decodeValue(elemMap, elemValue); err != nil { - return fmt.Errorf("failed to decode array element %d: %w", i, err) - } - } - - dst.Set(slice) - return nil - } - - if _, ok := prop["nullValue"]; ok { - // Set to zero value - dst.Set(reflect.Zero(dst.Type())) - return nil - } - - return fmt.Errorf("unsupported property type for %s", dst.Type()) -} - -// Query represents a Datastore query. -type Query struct { - ancestor *Key - kind string - filters []queryFilter - orders []queryOrder - projection []string - distinctOn []string - namespace string - startCursor Cursor - endCursor Cursor - limit int - offset int - keysOnly bool -} - -type queryFilter struct { - value any - property string - operator string -} - -type queryOrder struct { - property string - direction string // "ASCENDING" or "DESCENDING" -} - -// NewQuery creates a new query for the given kind. -func NewQuery(kind string) *Query { - return &Query{ - kind: kind, - } -} - -// KeysOnly configures the query to return only keys, not full entities. -func (q *Query) KeysOnly() *Query { - q.keysOnly = true - return q -} - -// Limit sets the maximum number of results to return. -func (q *Query) Limit(limit int) *Query { - q.limit = limit - return q -} - -// Offset sets the number of results to skip before returning. -// API compatible with cloud.google.com/go/datastore. -func (q *Query) Offset(offset int) *Query { - q.offset = offset - return q -} - -// Filter adds a property filter to the query. -// The filterStr should be in the format "Property Operator" (e.g., "Count >", "Name ="). -// Deprecated: Use FilterField instead. API compatible with cloud.google.com/go/datastore. -func (q *Query) Filter(filterStr string, value any) *Query { - // Parse the filter string to extract property and operator - parts := strings.Fields(filterStr) - if len(parts) != 2 { - // Invalid filter format, but we'll be lenient - return q - } - - property := parts[0] - op := parts[1] - - operator, ok := operatorMap[op] - if !ok { - operator = "EQUAL" - } - - q.filters = append(q.filters, queryFilter{ - property: property, - operator: operator, - value: value, - }) - - return q -} - -// FilterField adds a property filter to the query with explicit operator. -// API compatible with cloud.google.com/go/datastore. -func (q *Query) FilterField(fieldName, operator string, value any) *Query { - dsOperator, ok := operatorMap[operator] - if !ok { - dsOperator = operator // Use as-is if not in map (might already be EQUAL, etc.) - } - - q.filters = append(q.filters, queryFilter{ - property: fieldName, - operator: dsOperator, - value: value, - }) - - return q -} - -// Order sets the order in which results are returned. -// Prefix the property name with "-" for descending order (e.g., "-Created"). -// API compatible with cloud.google.com/go/datastore. -func (q *Query) Order(fieldName string) *Query { - direction := "ASCENDING" - property := fieldName - - if strings.HasPrefix(fieldName, "-") { - direction = "DESCENDING" - property = fieldName[1:] - } - - q.orders = append(q.orders, queryOrder{ - property: property, - direction: direction, - }) - - return q -} - -// Ancestor sets an ancestor filter for the query. -// Only entities with the given ancestor will be returned. -// API compatible with cloud.google.com/go/datastore. -func (q *Query) Ancestor(ancestor *Key) *Query { - q.ancestor = ancestor - return q -} - -// Project sets the fields to be projected (returned) in the query results. -// API compatible with cloud.google.com/go/datastore. -func (q *Query) Project(fieldNames ...string) *Query { - q.projection = fieldNames - return q -} - -// Distinct marks the query to return only distinct results. -// This is equivalent to DistinctOn with all projected fields. -// API compatible with cloud.google.com/go/datastore. -func (q *Query) Distinct() *Query { - // Distinct without fields means distinct on projection - // This will be handled in buildQueryMap - if len(q.projection) > 0 { - q.distinctOn = q.projection - } - return q -} - -// DistinctOn returns a query that removes duplicates based on the given field names. -// API compatible with cloud.google.com/go/datastore. -func (q *Query) DistinctOn(fieldNames ...string) *Query { - q.distinctOn = fieldNames - return q -} - -// Namespace sets the namespace for the query. -// API compatible with cloud.google.com/go/datastore. -func (q *Query) Namespace(ns string) *Query { - q.namespace = ns - return q -} - -// Start sets the starting cursor for the query results. -// API compatible with cloud.google.com/go/datastore. -func (q *Query) Start(c Cursor) *Query { - q.startCursor = c - return q -} - -// End sets the ending cursor for the query results. -// API compatible with cloud.google.com/go/datastore. -func (q *Query) End(c Cursor) *Query { - q.endCursor = c - return q -} - -// buildQueryMap creates a Datastore API query map from a Query object. -func buildQueryMap(query *Query) map[string]any { - queryMap := map[string]any{ - "kind": []map[string]any{{"name": query.kind}}, - } - - // Add namespace via partition ID if specified - if query.namespace != "" { - queryMap["partitionId"] = map[string]any{ - "namespaceId": query.namespace, - } - } - - // Add filters - if len(query.filters) > 0 { - var compositeFilters []map[string]any - for _, f := range query.filters { - encodedVal, err := encodeValue(f.value) - if err != nil { - // Skip invalid filters - continue - } - compositeFilters = append(compositeFilters, map[string]any{ - "propertyFilter": map[string]any{ - "property": map[string]string{"name": f.property}, - "op": f.operator, - "value": encodedVal, - }, - }) - } - - if len(compositeFilters) == 1 { - queryMap["filter"] = compositeFilters[0] - } else if len(compositeFilters) > 1 { - queryMap["filter"] = map[string]any{ - "compositeFilter": map[string]any{ - "op": "AND", - "filters": compositeFilters, - }, - } - } - } - - // Add ancestor filter - if query.ancestor != nil { - ancestorFilter := map[string]any{ - "propertyFilter": map[string]any{ - "property": map[string]string{"name": "__key__"}, - "op": "HAS_ANCESTOR", - "value": map[string]any{"keyValue": keyToJSON(query.ancestor)}, - }, - } - - // Combine with existing filters if present - if existingFilter, ok := queryMap["filter"]; ok { - existingMap, ok := existingFilter.(map[string]any) - if !ok { - // Skip if filter is invalid - queryMap["filter"] = ancestorFilter - } else { - queryMap["filter"] = map[string]any{ - "compositeFilter": map[string]any{ - "op": "AND", - "filters": []map[string]any{existingMap, ancestorFilter}, - }, - } - } - } else { - queryMap["filter"] = ancestorFilter - } - } - - // Add ordering - if len(query.orders) > 0 { - var orders []map[string]any - for _, o := range query.orders { - orders = append(orders, map[string]any{ - "property": map[string]string{"name": o.property}, - "direction": o.direction, - }) - } - queryMap["order"] = orders - } - - // Add projection - if len(query.projection) > 0 { - var projections []map[string]any - for _, field := range query.projection { - projections = append(projections, map[string]any{ - "property": map[string]string{"name": field}, - }) - } - queryMap["projection"] = projections - } else if query.keysOnly { - // Keys-only projection - queryMap["projection"] = []map[string]any{{"property": map[string]string{"name": "__key__"}}} - } - - // Add distinct on - if len(query.distinctOn) > 0 { - var distinctFields []map[string]any - for _, field := range query.distinctOn { - distinctFields = append(distinctFields, map[string]any{ - "property": map[string]string{"name": field}, - }) - } - queryMap["distinctOn"] = distinctFields - } - - // Add limit - if query.limit > 0 { - queryMap["limit"] = query.limit - } - - // Add offset - if query.offset > 0 { - queryMap["offset"] = query.offset - } - - // Add cursors - if query.startCursor != "" { - queryMap["startCursor"] = string(query.startCursor) - } - if query.endCursor != "" { - queryMap["endCursor"] = string(query.endCursor) - } - - return queryMap -} - -// AllKeys returns all keys matching the query. -// This is a convenience method for KeysOnly queries. -func (c *Client) AllKeys(ctx context.Context, q *Query) ([]*Key, error) { - if !q.keysOnly { - c.logger.WarnContext(ctx, "AllKeys called on non-KeysOnly query") - return nil, errors.New("AllKeys requires KeysOnly query") - } - - c.logger.DebugContext(ctx, "querying for keys", "kind", q.kind, "limit", q.limit) - - token, err := auth.AccessToken(ctx) - if err != nil { - c.logger.ErrorContext(ctx, "failed to get access token", "error", err) - return nil, fmt.Errorf("failed to get access token: %w", err) - } - - query := buildQueryMap(q) - - reqBody := map[string]any{"query": query} - if c.databaseID != "" { - reqBody["databaseId"] = c.databaseID - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:runQuery", c.baseURL, neturl.PathEscape(c.projectID)) - body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) - if err != nil { - c.logger.ErrorContext(ctx, "query request failed", "error", err, "kind", q.kind) - return nil, err - } - - var result struct { - Batch struct { - EntityResults []struct { - Entity map[string]any `json:"entity"` - } `json:"entityResults"` - } `json:"batch"` - } - - if err := json.Unmarshal(body, &result); err != nil { - c.logger.ErrorContext(ctx, "failed to parse response", "error", err) - return nil, fmt.Errorf("failed to parse response: %w", err) - } - - keys := make([]*Key, 0, len(result.Batch.EntityResults)) - for _, er := range result.Batch.EntityResults { - key, err := keyFromJSON(er.Entity["key"]) - if err != nil { - c.logger.ErrorContext(ctx, "failed to parse key from response", "error", err) - return nil, err - } - keys = append(keys, key) - } - - c.logger.DebugContext(ctx, "query completed successfully", "kind", q.kind, "keys_found", len(keys)) - return keys, nil -} - -// GetAll retrieves all entities matching the query and stores them in dst. -// dst must be a pointer to a slice of structs. -// Returns the keys of the retrieved entities and any error. -// This matches the API of cloud.google.com/go/datastore. -func (c *Client) GetAll(ctx context.Context, query *Query, dst any) ([]*Key, error) { - c.logger.DebugContext(ctx, "querying for entities", "kind", query.kind, "limit", query.limit) - - token, err := auth.AccessToken(ctx) - if err != nil { - c.logger.ErrorContext(ctx, "failed to get access token", "error", err) - return nil, fmt.Errorf("failed to get access token: %w", err) - } - - queryObj := buildQueryMap(query) - - reqBody := map[string]any{"query": queryObj} - if c.databaseID != "" { - reqBody["databaseId"] = c.databaseID - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:runQuery", c.baseURL, neturl.PathEscape(c.projectID)) - body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) - if err != nil { - c.logger.ErrorContext(ctx, "query request failed", "error", err, "kind", query.kind) - return nil, err - } - - var result struct { - Batch struct { - EntityResults []struct { - Entity map[string]any `json:"entity"` - } `json:"entityResults"` - } `json:"batch"` - } - - if err := json.Unmarshal(body, &result); err != nil { - c.logger.ErrorContext(ctx, "failed to parse response", "error", err) - return nil, fmt.Errorf("failed to parse response: %w", err) - } - - // Verify dst is a pointer to slice - v := reflect.ValueOf(dst) - if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Slice { - return nil, errors.New("dst must be a pointer to slice") - } - - sliceType := v.Elem().Type() - elemType := sliceType.Elem() - - // Create new slice of correct size - slice := reflect.MakeSlice(sliceType, 0, len(result.Batch.EntityResults)) - keys := make([]*Key, 0, len(result.Batch.EntityResults)) - - for _, er := range result.Batch.EntityResults { - // Extract key - key, err := keyFromJSON(er.Entity["key"]) - if err != nil { - c.logger.ErrorContext(ctx, "failed to parse key from response", "error", err) - return nil, err - } - keys = append(keys, key) - - // Decode entity - elem := reflect.New(elemType).Elem() - if err := decodeEntity(er.Entity, elem.Addr().Interface()); err != nil { - c.logger.ErrorContext(ctx, "failed to decode entity", "error", err) - return nil, err - } - slice = reflect.Append(slice, elem) - } - - v.Elem().Set(slice) - c.logger.DebugContext(ctx, "query completed successfully", "kind", query.kind, "entities_found", len(keys)) - return keys, nil -} - -// Count returns the number of entities matching the query. -// Deprecated: Use aggregation queries with RunAggregationQuery instead. -// API compatible with cloud.google.com/go/datastore. -func (c *Client) Count(ctx context.Context, q *Query) (int, error) { - c.logger.DebugContext(ctx, "counting entities", "kind", q.kind) - - token, err := auth.AccessToken(ctx) - if err != nil { - c.logger.ErrorContext(ctx, "failed to get access token", "error", err) - return 0, fmt.Errorf("failed to get access token: %w", err) - } - - // Build aggregation query with COUNT - queryObj := buildQueryMap(q) - aggregationQuery := map[string]any{ - "aggregations": []map[string]any{ - { - "alias": "total", - "count": map[string]any{}, - }, - }, - "nestedQuery": queryObj, - } - - reqBody := map[string]any{ - "aggregationQuery": aggregationQuery, - } - if c.databaseID != "" { - reqBody["databaseId"] = c.databaseID - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) - return 0, fmt.Errorf("failed to marshal request: %w", err) - } - - // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:runAggregationQuery", c.baseURL, neturl.PathEscape(c.projectID)) - body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) - if err != nil { - c.logger.ErrorContext(ctx, "count query failed", "error", err, "kind", q.kind) - return 0, err - } - - var result struct { - Batch struct { - AggregationResults []struct { - AggregateProperties map[string]struct { - IntegerValue string `json:"integerValue"` - } `json:"aggregateProperties"` - } `json:"aggregationResults"` - } `json:"batch"` - } - - if err := json.Unmarshal(body, &result); err != nil { - c.logger.ErrorContext(ctx, "failed to parse response", "error", err) - return 0, fmt.Errorf("failed to parse count response: %w", err) - } - - if len(result.Batch.AggregationResults) == 0 { - c.logger.DebugContext(ctx, "no results returned", "kind", q.kind) - return 0, nil - } - - // Extract count from total aggregation - countVal, ok := result.Batch.AggregationResults[0].AggregateProperties["total"] - if !ok { - c.logger.ErrorContext(ctx, "count not found in response") - return 0, errors.New("count not found in aggregation response") - } - - count, err := strconv.Atoi(countVal.IntegerValue) - if err != nil { - c.logger.ErrorContext(ctx, "failed to parse count", "error", err, "value", countVal.IntegerValue) - return 0, fmt.Errorf("failed to parse count: %w", err) - } - - c.logger.DebugContext(ctx, "count completed successfully", "kind", q.kind, "count", count) - return count, nil -} - -// Run executes the query and returns an iterator for the results. -// API compatible with cloud.google.com/go/datastore. -func (c *Client) Run(ctx context.Context, q *Query) *Iterator { - return &Iterator{ - ctx: ctx, - client: c, - query: q, - fetchNext: true, - } -} - -// MutationOp represents the type of mutation operation. -type MutationOp string - -const ( - // MutationInsert represents an insert operation. - MutationInsert MutationOp = "insert" - // MutationUpdate represents an update operation. - MutationUpdate MutationOp = "update" - // MutationUpsert represents an upsert operation. - MutationUpsert MutationOp = "upsert" - // MutationDelete represents a delete operation. - MutationDelete MutationOp = "delete" -) - -// Mutation represents a pending datastore mutation. -type Mutation struct { - op MutationOp - key *Key - entity any -} - -// NewInsert creates an insert mutation. -// API compatible with cloud.google.com/go/datastore. -func NewInsert(k *Key, src any) *Mutation { - return &Mutation{ - op: MutationInsert, - key: k, - entity: src, - } -} - -// NewUpdate creates an update mutation. -// API compatible with cloud.google.com/go/datastore. -func NewUpdate(k *Key, src any) *Mutation { - return &Mutation{ - op: MutationUpdate, - key: k, - entity: src, - } -} - -// NewUpsert creates an upsert mutation. -// API compatible with cloud.google.com/go/datastore. -func NewUpsert(k *Key, src any) *Mutation { - return &Mutation{ - op: MutationUpsert, - key: k, - entity: src, - } -} - -// NewDelete creates a delete mutation. -// API compatible with cloud.google.com/go/datastore. -func NewDelete(k *Key) *Mutation { - return &Mutation{ - op: MutationDelete, - key: k, - } -} - -// Mutate applies one or more mutations atomically. -// API compatible with cloud.google.com/go/datastore. -func (c *Client) Mutate(ctx context.Context, muts ...*Mutation) ([]*Key, error) { - if len(muts) == 0 { - return nil, nil - } - - c.logger.DebugContext(ctx, "applying mutations", "count", len(muts)) - - token, err := auth.AccessToken(ctx) - if err != nil { - c.logger.ErrorContext(ctx, "failed to get access token", "error", err) - return nil, fmt.Errorf("failed to get access token: %w", err) - } - - // Build mutations array - mutations := make([]map[string]any, 0, len(muts)) - for i, mut := range muts { - if mut == nil { - c.logger.ErrorContext(ctx, "nil mutation", "index", i) - return nil, fmt.Errorf("mutation at index %d is nil", i) - } - if mut.key == nil { - c.logger.ErrorContext(ctx, "nil key in mutation", "index", i) - return nil, fmt.Errorf("mutation at index %d has nil key", i) - } - - mutMap := make(map[string]any) - - switch mut.op { - case MutationInsert: - if mut.entity == nil { - c.logger.ErrorContext(ctx, "nil entity for insert", "index", i) - return nil, fmt.Errorf("insert mutation at index %d has nil entity", i) - } - entity, err := encodeEntity(mut.key, mut.entity) - if err != nil { - c.logger.ErrorContext(ctx, "failed to encode entity", "index", i, "error", err) - return nil, fmt.Errorf("failed to encode entity at index %d: %w", i, err) - } - mutMap["insert"] = entity - - case MutationUpdate: - if mut.entity == nil { - c.logger.ErrorContext(ctx, "nil entity for update", "index", i) - return nil, fmt.Errorf("update mutation at index %d has nil entity", i) - } - entity, err := encodeEntity(mut.key, mut.entity) - if err != nil { - c.logger.ErrorContext(ctx, "failed to encode entity", "index", i, "error", err) - return nil, fmt.Errorf("failed to encode entity at index %d: %w", i, err) - } - mutMap["update"] = entity - - case MutationUpsert: - if mut.entity == nil { - c.logger.ErrorContext(ctx, "nil entity for upsert", "index", i) - return nil, fmt.Errorf("upsert mutation at index %d has nil entity", i) - } - entity, err := encodeEntity(mut.key, mut.entity) - if err != nil { - c.logger.ErrorContext(ctx, "failed to encode entity", "index", i, "error", err) - return nil, fmt.Errorf("failed to encode entity at index %d: %w", i, err) - } - mutMap["upsert"] = entity - - case MutationDelete: - mutMap["delete"] = keyToJSON(mut.key) - - default: - c.logger.ErrorContext(ctx, "unknown mutation operation", "index", i, "op", mut.op) - return nil, fmt.Errorf("unknown mutation operation at index %d: %s", i, mut.op) - } - - mutations = append(mutations, mutMap) - } - - reqBody := map[string]any{ - "mutations": mutations, - } - if c.databaseID != "" { - reqBody["databaseId"] = c.databaseID - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID)) - body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) - if err != nil { - c.logger.ErrorContext(ctx, "mutate request failed", "error", err) - return nil, err - } - - var resp struct { - MutationResults []struct { - Key map[string]any `json:"key"` - } `json:"mutationResults"` - } - if err := json.Unmarshal(body, &resp); err != nil { - c.logger.ErrorContext(ctx, "failed to parse response", "error", err) - return nil, fmt.Errorf("failed to parse mutate response: %w", err) - } - - // Extract resulting keys - keys := make([]*Key, len(resp.MutationResults)) - for i, result := range resp.MutationResults { - if result.Key != nil { - key, err := keyFromJSON(result.Key) - if err != nil { - c.logger.ErrorContext(ctx, "failed to parse key", "index", i, "error", err) - return nil, fmt.Errorf("failed to parse key at index %d: %w", i, err) - } - keys[i] = key - } else { - // For deletes, use the original key - keys[i] = muts[i].key - } - } - - c.logger.DebugContext(ctx, "mutations applied successfully", "count", len(keys)) - return keys, nil -} - -// keyFromJSON converts a JSON key representation to a Key. -func keyFromJSON(keyData any) (*Key, error) { - keyMap, ok := keyData.(map[string]any) - if !ok { - return nil, errors.New("invalid key format") - } - - path, ok := keyMap["path"].([]any) - if !ok || len(path) == 0 { - return nil, errors.New("invalid key path") - } - - // Build key hierarchy from path elements - var key *Key - for _, elem := range path { - elemMap, ok := elem.(map[string]any) - if !ok { - return nil, errors.New("invalid path element") - } - - newKey := &Key{ - Parent: key, - } - - if kind, ok := elemMap["kind"].(string); ok { - newKey.Kind = kind - } - - if name, ok := elemMap["name"].(string); ok { - newKey.Name = name - } else if idVal, exists := elemMap["id"]; exists { - switch id := idVal.(type) { - case string: - if _, err := fmt.Sscanf(id, "%d", &newKey.ID); err != nil { - return nil, fmt.Errorf("invalid ID format: %w", err) - } - case float64: - newKey.ID = int64(id) - } - } - - key = newKey - } - - return key, nil -} - -// Commit represents the result of a committed transaction. -// This is provided for API compatibility with cloud.google.com/go/datastore. -type Commit struct{} - -// Transaction represents a Datastore transaction. -// Note: This struct stores context for API compatibility with Google's official -// cloud.google.com/go/datastore library, which uses the same pattern. -type Transaction struct { - ctx context.Context //nolint:containedctx // Required for API compatibility with cloud.google.com/go/datastore - client *Client - id string - mutations []map[string]any -} - -// TransactionOption configures transaction behavior. -type TransactionOption interface { - apply(*transactionSettings) -} - -type transactionSettings struct { - readTime time.Time - maxAttempts int -} - -type maxAttemptsOption int - -func (o maxAttemptsOption) apply(s *transactionSettings) { - s.maxAttempts = int(o) -} - -// MaxAttempts returns a TransactionOption that specifies the maximum number -// of times a transaction should be attempted before giving up. -func MaxAttempts(n int) TransactionOption { - return maxAttemptsOption(n) -} - -type readTimeOption struct { - t time.Time -} - -func (o readTimeOption) apply(s *transactionSettings) { - s.readTime = o.t -} - -// WithReadTime returns a TransactionOption that sets a specific timestamp -// at which to read data, enabling reading from a particular snapshot in time. -func WithReadTime(t time.Time) TransactionOption { - return readTimeOption{t: t} -} - -// NewTransaction creates a new transaction. -// The caller must call Commit or Rollback when done. -// API compatible with cloud.google.com/go/datastore. -func (c *Client) NewTransaction(ctx context.Context, opts ...TransactionOption) (*Transaction, error) { - settings := transactionSettings{ - maxAttempts: 3, // default (not used for NewTransaction, but kept for consistency) - } - for _, opt := range opts { - opt.apply(&settings) - } - - token, err := auth.AccessToken(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get access token: %w", err) - } - - // Begin transaction - reqBody := map[string]any{} - if c.databaseID != "" { - reqBody["databaseId"] = c.databaseID - } - - // Add transaction options if needed - if !settings.readTime.IsZero() { - reqBody["transactionOptions"] = map[string]any{ - "readOnly": map[string]any{ - "readTime": settings.readTime.Format(time.RFC3339Nano), - }, - } - } else { - reqBody["transactionOptions"] = map[string]any{ - "readWrite": map[string]any{}, - } - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return nil, err - } - - // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:beginTransaction", c.baseURL, neturl.PathEscape(c.projectID)) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(jsonData)) - if err != nil { - return nil, err - } - - req.Header.Set("Authorization", "Bearer "+token) - req.Header.Set("Content-Type", "application/json") - - // Add routing header for named databases - if c.databaseID != "" { - // URL-encode values to prevent header injection attacks - routingHeader := fmt.Sprintf("project_id=%s&database_id=%s", neturl.QueryEscape(c.projectID), neturl.QueryEscape(c.databaseID)) - req.Header.Set("X-Goog-Request-Params", routingHeader) - } - - resp, err := httpClient.Do(req) - if err != nil { - return nil, err - } - - body, err := io.ReadAll(io.LimitReader(resp.Body, maxBodySize)) - closeErr := resp.Body.Close() - if closeErr != nil { - c.logger.Warn("failed to close response body", "error", closeErr) - } - if err != nil { - return nil, err - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("begin transaction failed with status %d: %s", resp.StatusCode, string(body)) - } - - var txResp struct { - Transaction string `json:"transaction"` - } - - if err := json.Unmarshal(body, &txResp); err != nil { - return nil, fmt.Errorf("failed to parse transaction response: %w", err) - } - - tx := &Transaction{ - ctx: ctx, - client: c, - id: txResp.Transaction, - } - - return tx, nil -} - -// RunInTransaction runs a function in a transaction. -// The function should use the transaction's Get and Put methods. -// API compatible with cloud.google.com/go/datastore. -func (c *Client) RunInTransaction(ctx context.Context, f func(*Transaction) error, opts ...TransactionOption) (*Commit, error) { - settings := transactionSettings{ - maxAttempts: 3, // default - } - for _, opt := range opts { - opt.apply(&settings) - } - - var lastErr error - - for attempt := range settings.maxAttempts { - token, err := auth.AccessToken(ctx) - if err != nil { - return nil, fmt.Errorf("failed to get access token: %w", err) - } - - // Begin transaction - reqBody := map[string]any{} - if c.databaseID != "" { - reqBody["databaseId"] = c.databaseID - } - - // Add transaction options if needed - if !settings.readTime.IsZero() { - reqBody["transactionOptions"] = map[string]any{ - "readOnly": map[string]any{ - "readTime": settings.readTime.Format(time.RFC3339Nano), - }, - } - } else { - reqBody["transactionOptions"] = map[string]any{ - "readWrite": map[string]any{}, - } - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return nil, err - } - - // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:beginTransaction", c.baseURL, neturl.PathEscape(c.projectID)) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(jsonData)) - if err != nil { - return nil, err - } - - req.Header.Set("Authorization", "Bearer "+token) - req.Header.Set("Content-Type", "application/json") - - // Add routing header for named databases - if c.databaseID != "" { - // URL-encode values to prevent header injection attacks - routingHeader := fmt.Sprintf("project_id=%s&database_id=%s", neturl.QueryEscape(c.projectID), neturl.QueryEscape(c.databaseID)) - req.Header.Set("X-Goog-Request-Params", routingHeader) - } - - resp, err := httpClient.Do(req) - if err != nil { - return nil, err - } - - body, err := io.ReadAll(io.LimitReader(resp.Body, maxBodySize)) - closeErr := resp.Body.Close() - if closeErr != nil { - c.logger.Warn("failed to close response body", "error", closeErr) - } - if err != nil { - return nil, err - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("begin transaction failed with status %d: %s", resp.StatusCode, string(body)) - } - - var txResp struct { - Transaction string `json:"transaction"` - } - - if err := json.Unmarshal(body, &txResp); err != nil { - return nil, fmt.Errorf("failed to parse transaction response: %w", err) - } - - tx := &Transaction{ - ctx: ctx, - client: c, - id: txResp.Transaction, - } - - // Run the function - if err := f(tx); err != nil { - // Rollback is implicit if commit is not called - return nil, err - } - - // Commit the transaction - err = tx.doCommit(ctx, token) - if err == nil { - c.logger.Debug("transaction committed successfully", "attempt", attempt+1) - return &Commit{}, nil // Success - } - - c.logger.Warn("transaction commit failed", "attempt", attempt+1, "error", err) - - // Check if error contains 409 ABORTED - if so, retry - errStr := err.Error() - is409 := strings.Contains(errStr, "status 409") - isAborted := strings.Contains(errStr, "ABORTED") - - if is409 || isAborted { - lastErr = err - c.logger.Warn("transaction aborted, will retry", - "attempt", attempt+1, - "max_attempts", settings.maxAttempts, - "has_409", is409, - "has_ABORTED", isAborted, - "error", err) - - // Exponential backoff: 100ms, 200ms, 400ms - if attempt < settings.maxAttempts-1 { - backoffMS := 100 * (1 << attempt) - c.logger.Debug("sleeping before retry", "backoff_ms", backoffMS) - time.Sleep(time.Duration(backoffMS) * time.Millisecond) - } - continue - } - - // Non-retriable error - c.logger.Warn("non-retriable transaction error", "error", err) - return nil, err - } - - return nil, fmt.Errorf("transaction failed after %d attempts: %w", settings.maxAttempts, lastErr) -} - -// Get retrieves an entity within the transaction. -// API compatible with cloud.google.com/go/datastore. -func (tx *Transaction) Get(key *Key, dst any) error { - if key == nil { - return errors.New("key cannot be nil") - } - - token, err := auth.AccessToken(tx.ctx) - if err != nil { - return fmt.Errorf("failed to get access token: %w", err) - } - - reqBody := map[string]any{ - "keys": []map[string]any{ - keyToJSON(key), - }, - "readOptions": map[string]any{ - "transaction": tx.id, - }, - } - - if tx.client.databaseID != "" { - reqBody["databaseId"] = tx.client.databaseID - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return err - } - - // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:lookup", tx.client.baseURL, neturl.PathEscape(tx.client.projectID)) - req, err := http.NewRequestWithContext(tx.ctx, http.MethodPost, reqURL, bytes.NewReader(jsonData)) - if err != nil { - return err - } - - req.Header.Set("Authorization", "Bearer "+token) - req.Header.Set("Content-Type", "application/json") - - // Add routing header for named databases - if tx.client.databaseID != "" { - // URL-encode values to prevent header injection attacks - routingHeader := fmt.Sprintf("project_id=%s&database_id=%s", - neturl.QueryEscape(tx.client.projectID), - neturl.QueryEscape(tx.client.databaseID)) - req.Header.Set("X-Goog-Request-Params", routingHeader) - } - - resp, err := httpClient.Do(req) - if err != nil { - return err - } - defer func() { - if closeErr := resp.Body.Close(); closeErr != nil { - tx.client.logger.Warn("failed to close response body", "error", closeErr) - } - }() - - body, err := io.ReadAll(io.LimitReader(resp.Body, maxBodySize)) - if err != nil { - return err - } - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("transaction get failed with status %d: %s", resp.StatusCode, string(body)) - } - - var result struct { - Found []struct { - Entity map[string]any `json:"entity"` - } `json:"found"` - Missing []struct{} `json:"missing"` - } - - if err := json.Unmarshal(body, &result); err != nil { - return fmt.Errorf("failed to parse response: %w", err) - } - - if len(result.Found) == 0 { - return ErrNoSuchEntity - } - - return decodeEntity(result.Found[0].Entity, dst) -} - -// Put stores an entity within the transaction. -func (tx *Transaction) Put(key *Key, src any) (*Key, error) { - if key == nil { - return nil, errors.New("key cannot be nil") - } - - // Encode the entity - entity, err := encodeEntity(key, src) - if err != nil { - return nil, err - } - - // Create mutation - mutation := map[string]any{ - "upsert": entity, - } - - // Accumulate mutation for commit - tx.mutations = append(tx.mutations, mutation) - - return key, nil -} - -// Delete deletes an entity within the transaction. -// API compatible with cloud.google.com/go/datastore. -func (tx *Transaction) Delete(key *Key) error { - if key == nil { - return errors.New("key cannot be nil") - } - - // Create delete mutation - mutation := map[string]any{ - "delete": keyToJSON(key), - } - - // Accumulate mutation for commit - tx.mutations = append(tx.mutations, mutation) - - return nil -} - -// DeleteMulti deletes multiple entities within the transaction. -// API compatible with cloud.google.com/go/datastore. -func (tx *Transaction) DeleteMulti(keys []*Key) error { - for _, key := range keys { - if err := tx.Delete(key); err != nil { - return err - } - } - return nil -} - -// GetMulti retrieves multiple entities within the transaction. -// API compatible with cloud.google.com/go/datastore. -func (tx *Transaction) GetMulti(keys []*Key, dst any) error { - dstVal := reflect.ValueOf(dst) - if dstVal.Kind() != reflect.Ptr || dstVal.Elem().Kind() != reflect.Slice { - return errors.New("dst must be a pointer to a slice") - } - - slice := dstVal.Elem() - if len(keys) != slice.Len() { - return fmt.Errorf("keys and dst slices must have same length: %d vs %d", len(keys), slice.Len()) - } - - // Get each entity individually within the transaction - for i, key := range keys { - elem := slice.Index(i) - if elem.Kind() == reflect.Ptr { - // dst is []*Entity - if elem.IsNil() { - elem.Set(reflect.New(elem.Type().Elem())) - } - if err := tx.Get(key, elem.Interface()); err != nil { - return err - } - } else { - // dst is []Entity - if err := tx.Get(key, elem.Addr().Interface()); err != nil { - return err - } - } - } - - return nil -} - -// PutMulti stores multiple entities within the transaction. -// API compatible with cloud.google.com/go/datastore. -func (tx *Transaction) PutMulti(keys []*Key, src any) ([]*Key, error) { - srcVal := reflect.ValueOf(src) - if srcVal.Kind() != reflect.Slice { - return nil, errors.New("src must be a slice") - } - - if len(keys) != srcVal.Len() { - return nil, fmt.Errorf("keys and src slices must have same length: %d vs %d", len(keys), srcVal.Len()) - } - - // Put each entity individually within the transaction - for i, key := range keys { - elem := srcVal.Index(i) - var src any - if elem.Kind() == reflect.Ptr { - src = elem.Interface() - } else { - src = elem.Addr().Interface() - } - - if _, err := tx.Put(key, src); err != nil { - return nil, err - } - } - - return keys, nil -} - -// Commit applies the transaction's mutations. -// API compatible with cloud.google.com/go/datastore. -func (tx *Transaction) Commit() (*Commit, error) { - token, err := auth.AccessToken(tx.ctx) - if err != nil { - return nil, fmt.Errorf("failed to get access token: %w", err) - } - - if err := tx.doCommit(tx.ctx, token); err != nil { - return nil, err - } - - return &Commit{}, nil -} - -// Rollback abandons the transaction. -// API compatible with cloud.google.com/go/datastore. -func (tx *Transaction) Rollback() error { - // Datastore transactions are automatically rolled back if not committed - // So we just need to clear the mutations to prevent accidental commit - tx.mutations = nil - return nil -} - -// Mutate adds one or more mutations to the transaction. -// API compatible with cloud.google.com/go/datastore. -func (tx *Transaction) Mutate(muts ...*Mutation) ([]*PendingKey, error) { - if len(muts) == 0 { - return nil, nil - } - - // Build mutations array - pendingKeys := make([]*PendingKey, 0, len(muts)) - for i, mut := range muts { - if mut == nil { - return nil, fmt.Errorf("mutation at index %d is nil", i) - } - if mut.key == nil { - return nil, fmt.Errorf("mutation at index %d has nil key", i) - } - - mutMap := make(map[string]any) - - switch mut.op { - case MutationInsert: - if mut.entity == nil { - return nil, fmt.Errorf("insert mutation at index %d has nil entity", i) - } - entity, err := encodeEntity(mut.key, mut.entity) - if err != nil { - return nil, fmt.Errorf("failed to encode entity at index %d: %w", i, err) - } - mutMap["insert"] = entity - - case MutationUpdate: - if mut.entity == nil { - return nil, fmt.Errorf("update mutation at index %d has nil entity", i) - } - entity, err := encodeEntity(mut.key, mut.entity) - if err != nil { - return nil, fmt.Errorf("failed to encode entity at index %d: %w", i, err) - } - mutMap["update"] = entity - - case MutationUpsert: - if mut.entity == nil { - return nil, fmt.Errorf("upsert mutation at index %d has nil entity", i) - } - entity, err := encodeEntity(mut.key, mut.entity) - if err != nil { - return nil, fmt.Errorf("failed to encode entity at index %d: %w", i, err) - } - mutMap["upsert"] = entity - - case MutationDelete: - mutMap["delete"] = keyToJSON(mut.key) - - default: - return nil, fmt.Errorf("unknown mutation operation at index %d: %s", i, mut.op) - } - - tx.mutations = append(tx.mutations, mutMap) - - // Create a pending key for the result - pk := &PendingKey{key: mut.key} - pendingKeys = append(pendingKeys, pk) - } - - return pendingKeys, nil -} - -// PendingKey represents a key that will be resolved after a transaction commit. -// API compatible with cloud.google.com/go/datastore. -type PendingKey struct { - key *Key -} - -// commit commits the transaction. -func (tx *Transaction) doCommit(ctx context.Context, token string) error { - reqBody := map[string]any{ - "mode": "TRANSACTIONAL", - "transaction": tx.id, - "mutations": tx.mutations, - } - - if tx.client.databaseID != "" { - reqBody["databaseId"] = tx.client.databaseID - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return err - } - - // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:commit", tx.client.baseURL, neturl.PathEscape(tx.client.projectID)) - req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(jsonData)) - if err != nil { - return err - } - - req.Header.Set("Authorization", "Bearer "+token) - req.Header.Set("Content-Type", "application/json") - - // Add routing header for named databases - if tx.client.databaseID != "" { - // URL-encode values to prevent header injection attacks - routingHeader := fmt.Sprintf("project_id=%s&database_id=%s", - neturl.QueryEscape(tx.client.projectID), - neturl.QueryEscape(tx.client.databaseID)) - req.Header.Set("X-Goog-Request-Params", routingHeader) - } - - resp, err := httpClient.Do(req) - if err != nil { - return err - } - defer func() { - if closeErr := resp.Body.Close(); closeErr != nil { - tx.client.logger.Warn("failed to close response body", "error", closeErr) - } - }() - - body, err := io.ReadAll(io.LimitReader(resp.Body, maxBodySize)) - if err != nil { - return err - } - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("commit failed with status %d: %s", resp.StatusCode, string(body)) - } - - return nil -} diff --git a/datastore_test.go b/datastore_test.go deleted file mode 100644 index 7676a23..0000000 --- a/datastore_test.go +++ /dev/null @@ -1,8023 +0,0 @@ -package ds9_test - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "net/http" - "net/http/httptest" - "strings" - "testing" - "time" - - "github.com/codeGROOVE-dev/ds9" - "github.com/codeGROOVE-dev/ds9/ds9mock" -) - -// testEntity represents a simple test entity. -type testEntity struct { - UpdatedAt time.Time `datastore:"updated_at"` - Name string `datastore:"name"` - Notes string `datastore:"notes,noindex"` - Count int64 `datastore:"count"` - Score float64 `datastore:"score"` - Active bool `datastore:"active"` -} - -func TestNewClient(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - // Just verify we got a valid client - if client == nil { - t.Fatal("expected non-nil client") - } -} - -func TestNewClientWithDatabase(t *testing.T) { - // Setup mock servers - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{}); err != nil { - t.Logf("encode failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - - // Test with explicit databaseID - client, err := ds9.NewClientWithDatabase(ctx, "test-project", "custom-db") - if err != nil { - t.Fatalf("NewClientWithDatabase failed: %v", err) - } - if client == nil { - t.Fatal("expected non-nil client") - } -} - -func TestPutAndGet(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Create test entity - now := time.Now().UTC().Truncate(time.Second) - entity := &testEntity{ - Name: "test-item", - Count: 42, - Active: true, - Score: 3.14, - UpdatedAt: now, - Notes: "This is a test note", - } - - // Put entity - key := ds9.NameKey("TestKind", "test-key", nil) - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - - // Get entity - var retrieved testEntity - err = client.Get(ctx, key, &retrieved) - if err != nil { - t.Fatalf("Get failed: %v", err) - } - - // Verify fields - if retrieved.Name != entity.Name { - t.Errorf("Name: expected %q, got %q", entity.Name, retrieved.Name) - } - if retrieved.Count != entity.Count { - t.Errorf("Count: expected %d, got %d", entity.Count, retrieved.Count) - } - if retrieved.Active != entity.Active { - t.Errorf("Active: expected %v, got %v", entity.Active, retrieved.Active) - } - if retrieved.Score != entity.Score { - t.Errorf("Score: expected %f, got %f", entity.Score, retrieved.Score) - } - if !retrieved.UpdatedAt.Equal(entity.UpdatedAt) { - t.Errorf("UpdatedAt: expected %v, got %v", entity.UpdatedAt, retrieved.UpdatedAt) - } - if retrieved.Notes != entity.Notes { - t.Errorf("Notes: expected %q, got %q", entity.Notes, retrieved.Notes) - } -} - -func TestGetNotFound(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - key := ds9.NameKey("TestKind", "nonexistent", nil) - var entity testEntity - err := client.Get(ctx, key, &entity) - - if !errors.Is(err, ds9.ErrNoSuchEntity) { - t.Errorf("expected ErrNoSuchEntity, got %v", err) - } -} - -func TestDelete(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Put entity - entity := &testEntity{ - Name: "test-item", - Count: 42, - Active: true, - } - - key := ds9.NameKey("TestKind", "test-key", nil) - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - - // Delete entity - err = client.Delete(ctx, key) - if err != nil { - t.Fatalf("Delete failed: %v", err) - } - - // Verify it's gone - var retrieved testEntity - err = client.Get(ctx, key, &retrieved) - if !errors.Is(err, ds9.ErrNoSuchEntity) { - t.Errorf("expected ErrNoSuchEntity after delete, got %v", err) - } -} - -func TestAllKeys(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Put multiple entities - for i := range 5 { - entity := &testEntity{ - Name: "test-item", - Count: int64(i), - } - key := ds9.NameKey("TestKind", string(rune('a'+i)), nil) - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - } - - // Query for all keys - query := ds9.NewQuery("TestKind").KeysOnly() - keys, err := client.AllKeys(ctx, query) - if err != nil { - t.Fatalf("AllKeys failed: %v", err) - } - - if len(keys) != 5 { - t.Errorf("expected 5 keys, got %d", len(keys)) - } -} - -func TestAllKeysWithLimit(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Put multiple entities - for i := range 10 { - entity := &testEntity{ - Name: "test-item", - Count: int64(i), - } - key := ds9.NameKey("TestKind", string(rune('a'+i)), nil) - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - } - - // Query with limit - query := ds9.NewQuery("TestKind").KeysOnly().Limit(3) - keys, err := client.AllKeys(ctx, query) - if err != nil { - t.Fatalf("AllKeys failed: %v", err) - } - - if len(keys) != 3 { - t.Errorf("expected 3 keys, got %d", len(keys)) - } -} - -func TestRunInTransaction(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Put initial entity - entity := &testEntity{ - Name: "counter", - Count: 0, - } - - key := ds9.NameKey("TestKind", "counter", nil) - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - - // Run transaction to read and update - _, err = client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - var current testEntity - if err := tx.Get(key, ¤t); err != nil { - return err - } - - current.Count++ - _, err := tx.Put(key, ¤t) - return err - }) - if err != nil { - t.Fatalf("RunInTransaction failed: %v", err) - } - - // Verify the update - var updated testEntity - err = client.Get(ctx, key, &updated) - if err != nil { - t.Fatalf("Get failed: %v", err) - } - - if updated.Count != 1 { - t.Errorf("expected Count to be 1, got %d", updated.Count) - } -} - -func TestTransactionNotFound(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - key := ds9.NameKey("TestKind", "nonexistent", nil) - - _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - var entity testEntity - return tx.Get(key, &entity) - }) - - if !errors.Is(err, ds9.ErrNoSuchEntity) { - t.Errorf("expected ErrNoSuchEntity, got %v", err) - } -} - -func TestIDKey(t *testing.T) { - key := ds9.IDKey("TestKind", 12345, nil) - - if key.Kind != "TestKind" { - t.Errorf("expected Kind %q, got %q", "TestKind", key.Kind) - } - - if key.ID != 12345 { - t.Errorf("expected ID %d, got %d", 12345, key.ID) - } - - if key.Name != "" { - t.Errorf("expected empty Name, got %q", key.Name) - } -} - -func TestNameKey(t *testing.T) { - key := ds9.NameKey("TestKind", "test-name", nil) - - if key.Kind != "TestKind" { - t.Errorf("expected Kind %q, got %q", "TestKind", key.Kind) - } - - if key.Name != "test-name" { - t.Errorf("expected Name %q, got %q", "test-name", key.Name) - } - - if key.ID != 0 { - t.Errorf("expected ID 0, got %d", key.ID) - } -} - -func TestMultiPutAndMultiGet(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Create test entities - now := time.Now().UTC().Truncate(time.Second) - entities := []testEntity{ - { - Name: "item-1", - Count: 1, - Active: true, - Score: 1.1, - UpdatedAt: now, - }, - { - Name: "item-2", - Count: 2, - Active: false, - Score: 2.2, - UpdatedAt: now, - }, - { - Name: "item-3", - Count: 3, - Active: true, - Score: 3.3, - UpdatedAt: now, - }, - } - - keys := []*ds9.Key{ - ds9.NameKey("TestKind", "key-1", nil), - ds9.NameKey("TestKind", "key-2", nil), - ds9.NameKey("TestKind", "key-3", nil), - } - - // MultiPut - _, err := client.PutMulti(ctx, keys, entities) - if err != nil { - t.Fatalf("MultiPut failed: %v", err) - } - - // MultiGet - var retrieved []testEntity - err = client.GetMulti(ctx, keys, &retrieved) - if err != nil { - t.Fatalf("MultiGet failed: %v", err) - } - - if len(retrieved) != 3 { - t.Fatalf("expected 3 entities, got %d", len(retrieved)) - } - - // Verify entities - for i, entity := range retrieved { - if entity.Name != entities[i].Name { - t.Errorf("entity %d: Name mismatch: expected %q, got %q", i, entities[i].Name, entity.Name) - } - if entity.Count != entities[i].Count { - t.Errorf("entity %d: Count mismatch: expected %d, got %d", i, entities[i].Count, entity.Count) - } - if entity.Active != entities[i].Active { - t.Errorf("entity %d: Active mismatch: expected %v, got %v", i, entities[i].Active, entity.Active) - } - if entity.Score != entities[i].Score { - t.Errorf("entity %d: Score mismatch: expected %f, got %f", i, entities[i].Score, entity.Score) - } - } -} - -func TestMultiGetNotFound(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Put only one entity - entity := &testEntity{Name: "exists", Count: 1} - key1 := ds9.NameKey("TestKind", "exists", nil) - _, err := client.Put(ctx, key1, entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - - // Try to get multiple, one missing - keys := []*ds9.Key{ - key1, - ds9.NameKey("TestKind", "missing", nil), - } - - var retrieved []testEntity - err = client.GetMulti(ctx, keys, &retrieved) - - if !errors.Is(err, ds9.ErrNoSuchEntity) { - t.Errorf("expected ErrNoSuchEntity when some keys missing, got %v", err) - } -} - -func TestMultiDelete(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Put multiple entities - entities := []testEntity{ - {Name: "item-1", Count: 1}, - {Name: "item-2", Count: 2}, - {Name: "item-3", Count: 3}, - } - - keys := []*ds9.Key{ - ds9.NameKey("TestKind", "key-1", nil), - ds9.NameKey("TestKind", "key-2", nil), - ds9.NameKey("TestKind", "key-3", nil), - } - - _, err := client.PutMulti(ctx, keys, entities) - if err != nil { - t.Fatalf("MultiPut failed: %v", err) - } - - // MultiDelete - err = client.DeleteMulti(ctx, keys) - if err != nil { - t.Fatalf("MultiDelete failed: %v", err) - } - - // Verify they're gone by trying to get them - var retrieved []testEntity - err = client.GetMulti(ctx, keys, &retrieved) - if !errors.Is(err, ds9.ErrNoSuchEntity) { - t.Errorf("expected ErrNoSuchEntity after delete, got %v", err) - } -} - -func TestMultiPutEmptyKeys(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - var entities []testEntity - var keys []*ds9.Key - - _, err := client.PutMulti(ctx, keys, entities) - if err == nil { - t.Error("expected error for empty keys, got nil") - } -} - -func TestMultiGetEmptyKeys(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - var keys []*ds9.Key - var retrieved []testEntity - - err := client.GetMulti(ctx, keys, &retrieved) - if err == nil { - t.Error("expected error for empty keys, got nil") - } -} - -func TestMultiDeleteEmptyKeys(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - var keys []*ds9.Key - - err := client.DeleteMulti(ctx, keys) - if err == nil { - t.Error("expected error for empty keys, got nil") - } -} - -func TestIDKeyOperations(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Test with ID key - entity := &testEntity{ - Name: "id-test", - Count: 123, - } - - key := ds9.IDKey("TestKind", 999, nil) - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put with ID key failed: %v", err) - } - - // Get with ID key - var retrieved testEntity - err = client.Get(ctx, key, &retrieved) - if err != nil { - t.Fatalf("Get with ID key failed: %v", err) - } - - if retrieved.Name != "id-test" { - t.Errorf("expected Name 'id-test', got %q", retrieved.Name) - } -} - -func TestPutWithNilKey(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - entity := &testEntity{Name: "test"} - _, err := client.Put(ctx, nil, entity) - if err == nil { - t.Error("expected error for nil key, got nil") - } -} - -func TestGetWithNilKey(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - var entity testEntity - err := client.Get(ctx, nil, &entity) - if err == nil { - t.Error("expected error for nil key, got nil") - } -} - -func TestDeleteWithNilKey(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - err := client.Delete(ctx, nil) - if err == nil { - t.Error("expected error for nil key, got nil") - } -} - -func TestMultiGetWithNilKey(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - keys := []*ds9.Key{ - ds9.NameKey("TestKind", "key-1", nil), - nil, - ds9.NameKey("TestKind", "key-2", nil), - } - - var entities []testEntity - err := client.GetMulti(ctx, keys, &entities) - if err == nil { - t.Error("expected error for nil key in slice, got nil") - } -} - -func TestMultiPutWithNilKey(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - entities := []testEntity{ - {Name: "item-1"}, - {Name: "item-2"}, - } - - keys := []*ds9.Key{ - ds9.NameKey("TestKind", "key-1", nil), - nil, - } - - _, err := client.PutMulti(ctx, keys, entities) - if err == nil { - t.Error("expected error for nil key in slice, got nil") - } -} - -func TestMultiDeleteWithNilKey(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - keys := []*ds9.Key{ - ds9.NameKey("TestKind", "key-1", nil), - nil, - } - - err := client.DeleteMulti(ctx, keys) - if err == nil { - t.Error("expected error for nil key in slice, got nil") - } -} - -func TestMultiPutMismatchedSlices(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - entities := []testEntity{ - {Name: "item-1"}, - {Name: "item-2"}, - } - - keys := []*ds9.Key{ - ds9.NameKey("TestKind", "key-1", nil), - } - - _, err := client.PutMulti(ctx, keys, entities) - if err == nil { - t.Error("expected error for mismatched slices, got nil") - } -} - -func TestAllKeysNonKeysOnlyQuery(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Create a query without KeysOnly - query := ds9.NewQuery("TestKind") - _, err := client.AllKeys(ctx, query) - if err == nil { - t.Error("expected error for non-KeysOnly query, got nil") - } -} - -func TestQueryOperations(t *testing.T) { - // Test query builder methods - query := ds9.NewQuery("TestKind") - - if query.KeysOnly().KeysOnly() == nil { - t.Error("KeysOnly() should be chainable") - } - - if query.Limit(10).Limit(20) == nil { - t.Error("Limit() should be chainable") - } -} - -func TestEntityWithAllTypes(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - type AllTypes struct { - TimeVal time.Time `datastore:"t"` - StringVal string `datastore:"str"` - NoIndex string `datastore:"noindex,noindex"` - Skip string `datastore:"-"` - Int64Val int64 `datastore:"i64"` - IntVal int `datastore:"i"` - Float64Val float64 `datastore:"f64"` - Int32Val int32 `datastore:"i32"` - BoolVal bool `datastore:"b"` - } - - now := time.Now().UTC().Truncate(time.Second) - entity := &AllTypes{ - StringVal: "test", - Int64Val: int64(123), - Int32Val: int32(456), - IntVal: 789, - BoolVal: true, - Float64Val: 3.14, - TimeVal: now, - NoIndex: "not indexed", - Skip: "should not be stored", - } - - key := ds9.NameKey("AllTypes", "test", nil) - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - - var retrieved AllTypes - err = client.Get(ctx, key, &retrieved) - if err != nil { - t.Fatalf("Get failed: %v", err) - } - - if retrieved.StringVal != entity.StringVal { - t.Errorf("StringVal: expected %v, got %v", entity.StringVal, retrieved.StringVal) - } - if retrieved.Int64Val != entity.Int64Val { - t.Errorf("Int64Val: expected %v, got %v", entity.Int64Val, retrieved.Int64Val) - } - if retrieved.Int32Val != entity.Int32Val { - t.Errorf("Int32Val: expected %v, got %v", entity.Int32Val, retrieved.Int32Val) - } - if retrieved.IntVal != entity.IntVal { - t.Errorf("IntVal: expected %v, got %v", entity.IntVal, retrieved.IntVal) - } - if retrieved.BoolVal != entity.BoolVal { - t.Errorf("BoolVal: expected %v, got %v", entity.BoolVal, retrieved.BoolVal) - } - if retrieved.Float64Val != entity.Float64Val { - t.Errorf("Float64Val: expected %v, got %v", entity.Float64Val, retrieved.Float64Val) - } - if !retrieved.TimeVal.Equal(entity.TimeVal) { - t.Errorf("TimeVal: expected %v, got %v", entity.TimeVal, retrieved.TimeVal) - } - if retrieved.NoIndex != entity.NoIndex { - t.Errorf("NoIndex: expected %v, got %v", entity.NoIndex, retrieved.NoIndex) - } - if retrieved.Skip != "" { - t.Errorf("Skip field should be empty, got %q", retrieved.Skip) - } -} - -func TestSetTestURLs(t *testing.T) { - // Save original values - restore := ds9.SetTestURLs("http://test1", "http://test2") - - // Restore should work - restore() - - // Should be chainable - restore2 := ds9.SetTestURLs("http://test3", "http://test4") - restore2() -} - -func TestTransactionMultipleOperations(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Put initial entities - for i := range 3 { - entity := &testEntity{ - Name: "item", - Count: int64(i), - } - key := ds9.NameKey("TestKind", string(rune('a'+i)), nil) - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - } - - // Run transaction that reads and updates multiple entities - _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - for i := range 3 { - key := ds9.NameKey("TestKind", string(rune('a'+i)), nil) - var current testEntity - if err := tx.Get(key, ¤t); err != nil { - return err - } - - current.Count += 10 - _, err := tx.Put(key, ¤t) - if err != nil { - return err - } - } - return nil - }) - if err != nil { - t.Fatalf("RunInTransaction failed: %v", err) - } - - // Verify updates - for i := range 3 { - key := ds9.NameKey("TestKind", string(rune('a'+i)), nil) - var retrieved testEntity - err = client.Get(ctx, key, &retrieved) - if err != nil { - t.Fatalf("Get failed: %v", err) - } - expectedCount := int64(i + 10) - if retrieved.Count != expectedCount { - t.Errorf("entity %d: expected Count %d, got %d", i, expectedCount, retrieved.Count) - } - } -} - -func TestMultiGetPartialResults(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Put some entities - entities := []testEntity{ - {Name: "item-1", Count: 1}, - {Name: "item-3", Count: 3}, - } - keys := []*ds9.Key{ - ds9.NameKey("TestKind", "key-1", nil), - ds9.NameKey("TestKind", "key-3", nil), - } - _, err := client.PutMulti(ctx, keys, entities) - if err != nil { - t.Fatalf("MultiPut failed: %v", err) - } - - // Try to get more keys than exist - getAllKeys := []*ds9.Key{ - ds9.NameKey("TestKind", "key-1", nil), - ds9.NameKey("TestKind", "key-2", nil), // doesn't exist - ds9.NameKey("TestKind", "key-3", nil), - } - - var retrieved []testEntity - err = client.GetMulti(ctx, getAllKeys, &retrieved) - if err == nil { - t.Error("expected error when some keys don't exist") - } -} - -func TestEmptyQuery(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Query for keys when no entities exist - query := ds9.NewQuery("NonExistent").KeysOnly() - keys, err := client.AllKeys(ctx, query) - if err != nil { - t.Fatalf("AllKeys failed: %v", err) - } - - if len(keys) != 0 { - t.Errorf("expected 0 keys, got %d", len(keys)) - } -} - -func TestKeyComparison(t *testing.T) { - nameKey1 := ds9.NameKey("Kind", "name", nil) - nameKey2 := ds9.NameKey("Kind", "name", nil) - - if nameKey1.Kind != nameKey2.Kind || nameKey1.Name != nameKey2.Name { - t.Error("identical name keys should have same values") - } - - idKey1 := ds9.IDKey("Kind", 123, nil) - idKey2 := ds9.IDKey("Kind", 123, nil) - - if idKey1.Kind != idKey2.Kind || idKey1.ID != idKey2.ID { - t.Error("identical ID keys should have same values") - } -} - -func TestLargeEntityBatch(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Create a larger batch - const batchSize = 50 - entities := make([]testEntity, batchSize) - keys := make([]*ds9.Key, batchSize) - - for i := range batchSize { - entities[i] = testEntity{ - Name: "batch-item", - Count: int64(i), - } - keys[i] = ds9.NameKey("BatchKind", string(rune('0'+i/10))+string(rune('0'+i%10)), nil) - } - - // MultiPut - _, err := client.PutMulti(ctx, keys, entities) - if err != nil { - t.Fatalf("MultiPut failed: %v", err) - } - - // MultiGet - var retrieved []testEntity - err = client.GetMulti(ctx, keys, &retrieved) - if err != nil { - t.Fatalf("MultiGet failed: %v", err) - } - - if len(retrieved) != batchSize { - t.Errorf("expected %d entities, got %d", batchSize, len(retrieved)) - } - - // MultiDelete - err = client.DeleteMulti(ctx, keys) - if err != nil { - t.Fatalf("MultiDelete failed: %v", err) - } - - // Verify deletion - var retrieved2 []testEntity - err = client.GetMulti(ctx, keys, &retrieved2) - if !errors.Is(err, ds9.ErrNoSuchEntity) { - t.Errorf("expected ErrNoSuchEntity after batch delete, got %v", err) - } -} - -func TestUnsupportedEncodeType(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Entity with unsupported type (map) - type BadEntity struct { - Name string - Data map[string]string // maps not supported - } - - key := ds9.NameKey("TestKind", "bad", nil) - entity := BadEntity{ - Name: "test", - Data: map[string]string{"key": "value"}, - } - - _, err := client.Put(ctx, key, &entity) - if err == nil { - t.Error("expected error for unsupported type, got nil") - } - if !strings.Contains(err.Error(), "unsupported type") { - t.Errorf("expected 'unsupported type' error, got: %v", err) - } -} - -func TestDecodeNonPointer(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Store entity - key := ds9.NameKey("TestKind", "test", nil) - entity := testEntity{Name: "test", Count: 42} - _, err := client.Put(ctx, key, &entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - - // Try to decode into non-pointer - var notPtr testEntity - err = client.Get(ctx, key, notPtr) // Should be ¬Ptr - if err == nil { - t.Error("expected error for non-pointer dst, got nil") - } - if !strings.Contains(err.Error(), "pointer to struct") { - t.Errorf("expected 'pointer to struct' error, got: %v", err) - } -} - -func TestDecodePointerToNonStruct(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Store entity - key := ds9.NameKey("TestKind", "test", nil) - entity := testEntity{Name: "test", Count: 42} - _, err := client.Put(ctx, key, &entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - - // Try to decode into pointer to string - var str string - err = client.Get(ctx, key, &str) - if err == nil { - t.Error("expected error for pointer to non-struct, got nil") - } - if !strings.Contains(err.Error(), "pointer to struct") { - t.Errorf("expected 'pointer to struct' error, got: %v", err) - } -} - -func TestEntityWithSkippedFields(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - type EntityWithSkip struct { - Name string `datastore:"name"` - Skipped string `datastore:"-"` - private string - Count int64 `datastore:"count"` - } - - key := ds9.NameKey("TestKind", "skip", nil) - entity := EntityWithSkip{ - Name: "test", - Count: 42, - Skipped: "should not store", - private: "also not stored", - } - - _, err := client.Put(ctx, key, &entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - - var retrieved EntityWithSkip - err = client.Get(ctx, key, &retrieved) - if err != nil { - t.Fatalf("Get failed: %v", err) - } - - if retrieved.Name != entity.Name || retrieved.Count != entity.Count { - t.Errorf("wrong values: got %+v", retrieved) - } - - // Skipped field should be zero value - if retrieved.Skipped != "" { - t.Errorf("Skipped field should be empty, got %q", retrieved.Skipped) - } -} - -func TestZeroValueEntity(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - type ZeroEntity struct { - Name string - Count int64 - Active bool - Score float64 - } - - key := ds9.NameKey("TestKind", "zero", nil) - entity := ZeroEntity{} // All zero values - - _, err := client.Put(ctx, key, &entity) - if err != nil { - t.Fatalf("Put with zero values failed: %v", err) - } - - var retrieved ZeroEntity - err = client.Get(ctx, key, &retrieved) - if err != nil { - t.Fatalf("Get failed: %v", err) - } - - if retrieved.Name != "" || retrieved.Count != 0 || retrieved.Active != false || retrieved.Score != 0.0 { - t.Errorf("expected zero values, got %+v", retrieved) - } -} - -func TestQueryWithLimitZero(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Store some entities - for i := range 5 { - key := ds9.NameKey("LimitKind", string(rune('a'+i)), nil) - entity := testEntity{Name: "item", Count: int64(i)} - if _, err := client.Put(ctx, key, &entity); err != nil { - t.Fatalf("Put failed: %v", err) - } - } - - // Query with limit 0 (should return all) - query := ds9.NewQuery("LimitKind").KeysOnly().Limit(0) - keys, err := client.AllKeys(ctx, query) - if err != nil { - t.Fatalf("AllKeys failed: %v", err) - } - - if len(keys) == 0 { - t.Error("expected keys, got 0 (limit 0 should mean unlimited)") - } -} - -func TestQueryWithLimitLessThanResults(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Store 10 entities - for i := range 10 { - key := ds9.NameKey("LimitKind2", string(rune('a'+i)), nil) - entity := testEntity{Name: "item", Count: int64(i)} - if _, err := client.Put(ctx, key, &entity); err != nil { - t.Fatalf("Put failed: %v", err) - } - } - - // Query with limit 3 - query := ds9.NewQuery("LimitKind2").KeysOnly().Limit(3) - keys, err := client.AllKeys(ctx, query) - if err != nil { - t.Fatalf("AllKeys failed: %v", err) - } - - if len(keys) != 3 { - t.Errorf("expected 3 keys, got %d", len(keys)) - } -} - -func TestMultiGetEmptySlices(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Call MultiGet with empty slices - should return error - var entities []testEntity - err := client.GetMulti(ctx, []*ds9.Key{}, &entities) - if err == nil { - t.Error("expected error for MultiGet with empty keys, got nil") - } -} - -func TestMultiPutEmptySlices(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Call MultiPut with empty slices - should return error - _, err := client.PutMulti(ctx, []*ds9.Key{}, []testEntity{}) - if err == nil { - t.Error("expected error for MultiPut with empty keys, got nil") - } -} - -func TestMultiDeleteEmptySlice(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Call MultiDelete with empty slice - should return error - err := client.DeleteMulti(ctx, []*ds9.Key{}) - if err == nil { - t.Error("expected error for MultiDelete with empty keys, got nil") - } -} - -func TestNewClientWithDatabaseEmptyProjectID(t *testing.T) { - // Setup mock servers - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("auto-detected-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{}); err != nil { - t.Logf("encode failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - - // Test with empty projectID - should fetch from metadata - client, err := ds9.NewClientWithDatabase(ctx, "", "my-db") - if err != nil { - t.Fatalf("NewClientWithDatabase with empty projectID failed: %v", err) - } - if client == nil { - t.Fatal("expected non-nil client") - } -} - -func TestNewClientWithDatabaseProjectIDFetchFailure(t *testing.T) { - // Setup mock servers that fail to provide projectID - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - // Return error instead of project ID - w.WriteHeader(http.StatusInternalServerError) - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{}); err != nil { - t.Logf("encode failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - - // Test with empty projectID and failing metadata server - client, err := ds9.NewClientWithDatabase(ctx, "", "my-db") - if err == nil { - t.Fatal("expected error when projectID fetch fails, got nil") - } - if client != nil { - t.Errorf("expected nil client on error, got %v", client) - } - if !strings.Contains(err.Error(), "project ID required") { - t.Errorf("expected 'project ID required' error, got: %v", err) - } -} - -func TestTransactionWithError(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Store initial entity - key := ds9.NameKey("TestKind", "tx-err", nil) - entity := testEntity{Name: "initial", Count: 1} - _, err := client.Put(ctx, key, &entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - - // Run transaction that errors - _, err = client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - var current testEntity - if err := tx.Get(key, ¤t); err != nil { - return err - } - - current.Count = 999 - - if _, err := tx.Put(key, ¤t); err != nil { - return err - } - - // Return error to trigger rollback - return errors.New("intentional error") - }) - - if err == nil { - t.Fatal("expected transaction to fail, got nil error") - } - if !strings.Contains(err.Error(), "intentional error") { - t.Errorf("expected 'intentional error', got: %v", err) - } - - // Verify entity was not modified (transaction rolled back) - // Note: In a real implementation this would check rollback, but our mock doesn't support it - // This test at least exercises the error path -} - -func TestTransactionWithDatabaseID(t *testing.T) { - // Setup mock servers - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - txID := "test-tx-123" - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Authorization") != "Bearer test-token" { - w.WriteHeader(http.StatusUnauthorized) - return - } - - var reqBody map[string]any - if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { - w.WriteHeader(http.StatusBadRequest) - return - } - - // Check for databaseId in request - if dbID, ok := reqBody["databaseId"].(string); ok && dbID != "tx-db" { - t.Errorf("expected databaseId 'tx-db', got %v", dbID) - } - - w.Header().Set("Content-Type", "application/json") - - if r.URL.Path == "/projects/test-project:beginTransaction" { - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{ - "transaction": txID, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - - if r.URL.Path == "/projects/test-project:commit" { - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{ - "mutationResults": []any{}, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - - if r.URL.Path == "/projects/test-project:lookup" { - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{ - "found": []any{}, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClientWithDatabase(ctx, "test-project", "tx-db") - if err != nil { - t.Fatalf("NewClientWithDatabase failed: %v", err) - } - - // Run transaction with databaseID - _, err = client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - key := ds9.NameKey("TestKind", "tx-test", nil) - entity := testEntity{Name: "in-tx", Count: 42} - _, err := tx.Put(key, &entity) - return err - }) - if err != nil { - t.Fatalf("Transaction with databaseID failed: %v", err) - } -} - -func TestDeleteWithDatabaseID(t *testing.T) { - // Setup with databaseID - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Authorization") != "Bearer test-token" { - w.WriteHeader(http.StatusUnauthorized) - return - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{ - "mutationResults": []any{}, - }); err != nil { - t.Logf("encode failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClientWithDatabase(ctx, "test-project", "del-db") - if err != nil { - t.Fatalf("NewClientWithDatabase failed: %v", err) - } - - // Delete with databaseID - key := ds9.NameKey("TestKind", "to-delete", nil) - err = client.Delete(ctx, key) - if err != nil { - t.Fatalf("Delete with databaseID failed: %v", err) - } -} - -func TestAllKeysWithDatabaseID(t *testing.T) { - // Setup with databaseID - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Authorization") != "Bearer test-token" { - w.WriteHeader(http.StatusUnauthorized) - return - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{ - "batch": map[string]any{ - "entityResults": []any{}, - }, - }); err != nil { - t.Logf("encode failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClientWithDatabase(ctx, "test-project", "query-db") - if err != nil { - t.Fatalf("NewClientWithDatabase failed: %v", err) - } - - // Query with databaseID - query := ds9.NewQuery("TestKind").KeysOnly() - keys, err := client.AllKeys(ctx, query) - if err != nil { - t.Fatalf("AllKeys with databaseID failed: %v", err) - } - - if len(keys) != 0 { - t.Errorf("expected 0 keys, got %d", len(keys)) - } -} - -func TestMultiGetWithDatabaseID(t *testing.T) { - // Setup with databaseID - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Authorization") != "Bearer test-token" { - w.WriteHeader(http.StatusUnauthorized) - return - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - // Return missing entities to trigger ErrNoSuchEntity - if err := json.NewEncoder(w).Encode(map[string]any{ - "found": []any{}, - "missing": []any{ - map[string]any{"entity": map[string]any{"key": map[string]any{}}}, - }, - }); err != nil { - t.Logf("encode failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClientWithDatabase(ctx, "test-project", "multiget-db") - if err != nil { - t.Fatalf("NewClientWithDatabase failed: %v", err) - } - - // MultiGet with databaseID - keys := []*ds9.Key{ - ds9.NameKey("TestKind", "key1", nil), - ds9.NameKey("TestKind", "key2", nil), - } - var entities []testEntity - err = client.GetMulti(ctx, keys, &entities) - // Expect error since entities don't exist - if !errors.Is(err, ds9.ErrNoSuchEntity) { - t.Errorf("expected ErrNoSuchEntity, got: %v", err) - } -} - -func TestMultiDeleteWithDatabaseID(t *testing.T) { - // Setup with databaseID - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Authorization") != "Bearer test-token" { - w.WriteHeader(http.StatusUnauthorized) - return - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{ - "mutationResults": []any{}, - }); err != nil { - t.Logf("encode failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClientWithDatabase(ctx, "test-project", "multidel-db") - if err != nil { - t.Fatalf("NewClientWithDatabase failed: %v", err) - } - - // MultiDelete with databaseID - keys := []*ds9.Key{ - ds9.NameKey("TestKind", "key1", nil), - ds9.NameKey("TestKind", "key2", nil), - } - err = client.DeleteMulti(ctx, keys) - if err != nil { - t.Fatalf("MultiDelete with databaseID failed: %v", err) - } -} - -func TestDeleteAllByKind(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Put multiple entities of the same kind - for i := range 5 { - entity := &testEntity{ - Name: "item", - Count: int64(i), - } - key := ds9.NameKey("DeleteKind", string(rune('a'+i)), nil) - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - } - - // Delete all entities of this kind - err := client.DeleteAllByKind(ctx, "DeleteKind") - if err != nil { - t.Fatalf("DeleteAllByKind failed: %v", err) - } - - // Verify all deleted - query := ds9.NewQuery("DeleteKind").KeysOnly() - keys, err := client.AllKeys(ctx, query) - if err != nil { - t.Fatalf("AllKeys failed: %v", err) - } - - if len(keys) != 0 { - t.Errorf("expected 0 keys after DeleteAllByKind, got %d", len(keys)) - } -} - -func TestDeleteAllByKindEmpty(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Delete from non-existent kind - err := client.DeleteAllByKind(ctx, "NonExistentKind") - if err != nil { - t.Errorf("DeleteAllByKind on empty kind should not error, got: %v", err) - } -} - -func TestHierarchicalKeys(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Create parent key - parentKey := ds9.NameKey("Parent", "parent1", nil) - parentEntity := &testEntity{ - Name: "parent", - Count: 1, - } - _, err := client.Put(ctx, parentKey, parentEntity) - if err != nil { - t.Fatalf("Put parent failed: %v", err) - } - - // Create child key with parent - childKey := ds9.NameKey("Child", "child1", parentKey) - childEntity := &testEntity{ - Name: "child", - Count: 2, - } - _, err = client.Put(ctx, childKey, childEntity) - if err != nil { - t.Fatalf("Put child failed: %v", err) - } - - // Get child - var retrieved testEntity - err = client.Get(ctx, childKey, &retrieved) - if err != nil { - t.Fatalf("Get child failed: %v", err) - } - - if retrieved.Name != "child" { - t.Errorf("expected child name 'child', got %q", retrieved.Name) - } -} - -func TestHierarchicalKeysMultiLevel(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Create grandparent -> parent -> child hierarchy - grandparentKey := ds9.NameKey("Grandparent", "gp1", nil) - parentKey := ds9.NameKey("Parent", "p1", grandparentKey) - childKey := ds9.NameKey("Child", "c1", parentKey) - - entity := &testEntity{ - Name: "deep-child", - Count: 42, - } - - _, err := client.Put(ctx, childKey, entity) - if err != nil { - t.Fatalf("Put with multi-level hierarchy failed: %v", err) - } - - var retrieved testEntity - err = client.Get(ctx, childKey, &retrieved) - if err != nil { - t.Fatalf("Get with multi-level hierarchy failed: %v", err) - } - - if retrieved.Name != "deep-child" { - t.Errorf("expected name 'deep-child', got %q", retrieved.Name) - } -} - -func TestDoRequestRetryOn5xxError(t *testing.T) { - // Setup mock servers - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - attemptCount := 0 - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - attemptCount++ - // Return 503 on first two attempts, then succeed - if attemptCount < 3 { - w.WriteHeader(http.StatusServiceUnavailable) - if _, err := w.Write([]byte(`{"error":"service unavailable"}`)); err != nil { - t.Logf("write failed: %v", err) - } - return - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{ - "mutationResults": []any{ - map[string]any{"key": map[string]any{}}, - }, - }); err != nil { - t.Logf("encode failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClient(ctx, "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - // This should succeed after retries - key := ds9.NameKey("TestKind", "retry-test", nil) - entity := &testEntity{Name: "test", Count: 1} - _, err = client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put should succeed after retries, got: %v", err) - } - - if attemptCount < 2 { - t.Errorf("expected at least 2 attempts, got %d", attemptCount) - } -} - -func TestDoRequestFailsOn4xxError(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - attemptCount := 0 - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - attemptCount++ - // Always return 400 Bad Request - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusBadRequest) - if _, err := w.Write([]byte(`{"error":"bad request"}`)); err != nil { - t.Logf("write failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClient(ctx, "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - // This should fail immediately without retry on 4xx - key := ds9.NameKey("TestKind", "bad-request", nil) - entity := &testEntity{Name: "test", Count: 1} - _, err = client.Put(ctx, key, entity) - if err == nil { - t.Fatal("expected error on 4xx response") - } - - if !strings.Contains(err.Error(), "400") { - t.Errorf("expected error to mention 400 status, got: %v", err) - } - - // Should only try once for 4xx errors (no retry) - if attemptCount != 1 { - t.Errorf("expected exactly 1 attempt for 4xx error, got %d", attemptCount) - } -} - -func TestDoRequestContextCancellation(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - attemptCount := 0 - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - attemptCount++ - // Always return 503 to force retry - w.WriteHeader(http.StatusServiceUnavailable) - if _, err := w.Write([]byte(`{"error":"unavailable"}`)); err != nil { - t.Logf("write failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - // Create context that we'll cancel - ctx, cancel := context.WithCancel(context.Background()) - - // Cancel after a short delay - go func() { - time.Sleep(50 * time.Millisecond) - cancel() - }() - - key := ds9.NameKey("TestKind", "cancel-test", nil) - entity := &testEntity{Name: "test", Count: 1} - _, err = client.Put(ctx, key, entity) - - if err == nil { - t.Fatal("expected error when context is cancelled") - } - - if !errors.Is(err, context.Canceled) && !strings.Contains(err.Error(), "context canceled") { - t.Errorf("expected context cancellation error, got: %v", err) - } -} - -func TestTransactionRollback(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Put initial entity - key := ds9.NameKey("TestKind", "rollback-test", nil) - entity := &testEntity{Name: "original", Count: 1} - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - - // Run transaction that will fail - _, err = client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - var current testEntity - if err := tx.Get(key, ¤t); err != nil { - return err - } - - current.Name = "modified" - current.Count = 999 - - _, err := tx.Put(key, ¤t) - if err != nil { - return err - } - - // Return error to cause rollback - return errors.New("force rollback") - }) - - if err == nil { - t.Fatal("expected transaction to fail") - } - - if !strings.Contains(err.Error(), "force rollback") { - t.Errorf("expected 'force rollback' error, got: %v", err) - } -} - -func TestPutWithInvalidEntity(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - type InvalidEntity struct { - Map map[string]string // maps not supported - } - - key := ds9.NameKey("TestKind", "invalid", nil) - entity := &InvalidEntity{ - Map: map[string]string{"key": "value"}, - } - - _, err := client.Put(ctx, key, entity) - if err == nil { - t.Error("expected error for unsupported entity type") - } -} - -func TestGetMultiWithMismatchedSliceSize(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Put one entity - key1 := ds9.NameKey("TestKind", "key1", nil) - entity := &testEntity{Name: "test", Count: 1} - _, err := client.Put(ctx, key1, entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - - // Try to get with wrong slice type - keys := []*ds9.Key{key1} - var retrieved []testEntity - - // This should work - err = client.GetMulti(ctx, keys, &retrieved) - if err != nil { - t.Fatalf("GetMulti failed: %v", err) - } - - if len(retrieved) != 1 { - t.Errorf("expected 1 entity, got %d", len(retrieved)) - } -} - -func TestTransactionBeginFailure(t *testing.T) { - // Setup mock servers - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Fail to begin transaction - w.WriteHeader(http.StatusInternalServerError) - if _, err := w.Write([]byte(`{"error":"internal error"}`)); err != nil { - t.Logf("write failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClient(ctx, "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - _, err = client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - return nil - }) - - if err == nil { - t.Fatal("expected transaction to fail on begin") - } - - if !strings.Contains(err.Error(), "500") { - t.Errorf("expected error to mention 500 status, got: %v", err) - } -} - -func TestTransactionCommitAbortedRetry(t *testing.T) { - // Setup mock servers - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - commitAttempt := 0 - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, "beginTransaction") { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{ - "transaction": "test-tx-123", - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - - if strings.Contains(r.URL.Path, "commit") { - commitAttempt++ - // Fail with 409 ABORTED on first two attempts, succeed on third - if commitAttempt < 3 { - w.WriteHeader(http.StatusConflict) - if _, err := w.Write([]byte(`{"error":"ABORTED: transaction aborted"}`)); err != nil { - t.Logf("write failed: %v", err) - } - return - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{ - "mutationResults": []any{}, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - - if strings.Contains(r.URL.Path, "lookup") { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{ - "found": []any{}, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClient(ctx, "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - // This should succeed after retries - key := ds9.NameKey("TestKind", "tx-retry", nil) - _, err = client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - _, err := tx.Put(key, &testEntity{Name: "test", Count: 1}) - return err - }) - if err != nil { - t.Fatalf("transaction should succeed after retries, got: %v", err) - } - - if commitAttempt < 2 { - t.Errorf("expected at least 2 commit attempts, got %d", commitAttempt) - } -} - -func TestTransactionMaxRetriesExceeded(t *testing.T) { - // Setup mock servers - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - commitAttempt := 0 - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, "beginTransaction") { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{ - "transaction": "test-tx-456", - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - - if strings.Contains(r.URL.Path, "commit") { - commitAttempt++ - // Always return 409 ABORTED - w.WriteHeader(http.StatusConflict) - if _, err := w.Write([]byte(`{"error":"status 409 ABORTED: transaction conflict"}`)); err != nil { - t.Logf("write failed: %v", err) - } - return - } - - if strings.Contains(r.URL.Path, "lookup") { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{ - "found": []any{}, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClient(ctx, "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - // This should fail after max retries - key := ds9.NameKey("TestKind", "tx-max-retry", nil) - _, err = client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - _, err := tx.Put(key, &testEntity{Name: "test", Count: 1}) - return err - }) - - if err == nil { - t.Fatal("expected transaction to fail after max retries") - } - - if !strings.Contains(err.Error(), "failed after 3 attempts") { - t.Errorf("expected 'failed after 3 attempts' error, got: %v", err) - } - - if commitAttempt != 3 { - t.Errorf("expected exactly 3 commit attempts, got %d", commitAttempt) - } -} - -func TestKeyFromJSONEdgeCases(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Test with ID key using integer ID - idKey := ds9.IDKey("TestKind", 12345, nil) - entity := &testEntity{Name: "id-test", Count: 1} - _, err := client.Put(ctx, idKey, entity) - if err != nil { - t.Fatalf("Put with ID key failed: %v", err) - } - - var retrieved testEntity - err = client.Get(ctx, idKey, &retrieved) - if err != nil { - t.Fatalf("Get with ID key failed: %v", err) - } - - if retrieved.Name != "id-test" { - t.Errorf("expected name 'id-test', got %q", retrieved.Name) - } -} - -func TestDecodeValueEdgeCases(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Test with all basic types - type ComplexEntity struct { - Time time.Time `datastore:"t"` - String string `datastore:"s"` - NoIndex string `datastore:"n,noindex"` - Int int `datastore:"i"` - Int64 int64 `datastore:"i64"` - Float float64 `datastore:"f"` - Int32 int32 `datastore:"i32"` - Bool bool `datastore:"b"` - } - - now := time.Now().UTC().Truncate(time.Second) - key := ds9.NameKey("Complex", "test", nil) - entity := &ComplexEntity{ - String: "test", - Int: 42, - Int32: 32, - Int64: 64, - Float: 3.14, - Bool: true, - Time: now, - NoIndex: "not indexed", - } - - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - - var retrieved ComplexEntity - err = client.Get(ctx, key, &retrieved) - if err != nil { - t.Fatalf("Get failed: %v", err) - } - - if retrieved.String != entity.String { - t.Errorf("String mismatch") - } - if retrieved.Int != entity.Int { - t.Errorf("Int mismatch") - } - if retrieved.Int32 != entity.Int32 { - t.Errorf("Int32 mismatch") - } - if retrieved.Int64 != entity.Int64 { - t.Errorf("Int64 mismatch") - } - if retrieved.Float != entity.Float { - t.Errorf("Float mismatch") - } - if retrieved.Bool != entity.Bool { - t.Errorf("Bool mismatch") - } - if !retrieved.Time.Equal(entity.Time) { - t.Errorf("Time mismatch") - } - if retrieved.NoIndex != entity.NoIndex { - t.Errorf("NoIndex mismatch") - } -} - -func TestGetMultiMixedResults(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Put some entities - key1 := ds9.NameKey("Mixed", "exists1", nil) - key2 := ds9.NameKey("Mixed", "exists2", nil) - key3 := ds9.NameKey("Mixed", "missing", nil) - - entities := []testEntity{ - {Name: "entity1", Count: 1}, - {Name: "entity2", Count: 2}, - } - - _, err := client.PutMulti(ctx, []*ds9.Key{key1, key2}, entities) - if err != nil { - t.Fatalf("PutMulti failed: %v", err) - } - - // Try to get mix of existing and non-existing - keys := []*ds9.Key{key1, key2, key3} - var retrieved []testEntity - - err = client.GetMulti(ctx, keys, &retrieved) - if !errors.Is(err, ds9.ErrNoSuchEntity) { - t.Errorf("expected ErrNoSuchEntity for mixed results, got: %v", err) - } -} - -func TestPutMultiLargeBatch(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Create large batch - const size = 100 - entities := make([]testEntity, size) - keys := make([]*ds9.Key, size) - - for i := range size { - entities[i] = testEntity{ - Name: "large-batch", - Count: int64(i), - } - keys[i] = ds9.NameKey("LargeBatch", fmt.Sprintf("key-%d", i), nil) - } - - _, err := client.PutMulti(ctx, keys, entities) - if err != nil { - t.Fatalf("PutMulti with large batch failed: %v", err) - } - - // Verify a few - var retrieved testEntity - err = client.Get(ctx, keys[0], &retrieved) - if err != nil { - t.Fatalf("Get failed: %v", err) - } - - if retrieved.Count != 0 { - t.Errorf("expected Count 0, got %d", retrieved.Count) - } -} - -func TestGetWithHTTPError(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Return 404 for lookup - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusNotFound) - if err := json.NewEncoder(w).Encode(map[string]any{ - "error": "not found", - }); err != nil { - t.Logf("encode failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClient(ctx, "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - key := ds9.NameKey("TestKind", "test", nil) - var entity testEntity - err = client.Get(ctx, key, &entity) - - if err == nil { - t.Fatal("expected error on 404") - } - - if !strings.Contains(err.Error(), "404") { - t.Errorf("expected error to mention 404, got: %v", err) - } -} - -func TestPutWithHTTPError(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Return 403 Forbidden - w.WriteHeader(http.StatusForbidden) - if _, err := w.Write([]byte(`{"error":"permission denied"}`)); err != nil { - t.Logf("write failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClient(ctx, "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - key := ds9.NameKey("TestKind", "test", nil) - entity := &testEntity{Name: "test", Count: 1} - _, err = client.Put(ctx, key, entity) - - if err == nil { - t.Fatal("expected error on 403") - } - - if !strings.Contains(err.Error(), "403") { - t.Errorf("expected error to mention 403, got: %v", err) - } -} - -func TestDeleteMultiWithErrors(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Return server error - w.WriteHeader(http.StatusInternalServerError) - if _, err := w.Write([]byte(`{"error":"internal error"}`)); err != nil { - t.Logf("write failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClient(ctx, "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - keys := []*ds9.Key{ - ds9.NameKey("TestKind", "key1", nil), - ds9.NameKey("TestKind", "key2", nil), - } - - err = client.DeleteMulti(ctx, keys) - if err == nil { - t.Fatal("expected error on server failure") - } -} - -func TestQueryNonKeysOnly(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Try to call AllKeys with non-KeysOnly query - query := ds9.NewQuery("TestKind") - _, err := client.AllKeys(ctx, query) - - if err == nil { - t.Error("expected error for non-KeysOnly query") - } - - if !strings.Contains(err.Error(), "KeysOnly") { - t.Errorf("expected error to mention KeysOnly, got: %v", err) - } -} - -func TestDoRequestAllRetriesFail(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - attemptCount := 0 - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - attemptCount++ - // Always fail with 500 - w.WriteHeader(http.StatusInternalServerError) - if _, err := w.Write([]byte(`{"error":"persistent failure"}`)); err != nil { - t.Logf("write failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClient(ctx, "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - key := ds9.NameKey("TestKind", "test", nil) - entity := &testEntity{Name: "test", Count: 1} - _, err = client.Put(ctx, key, entity) - - if err == nil { - t.Fatal("expected error after all retries") - } - - if !strings.Contains(err.Error(), "attempts failed") { - t.Errorf("expected 'attempts failed' error, got: %v", err) - } - - // Should have tried multiple times - if attemptCount < 3 { - t.Errorf("expected at least 3 attempts, got %d", attemptCount) - } -} - -func TestEntityWithPointerFields(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Entities with pointer fields - type EntityWithPointers struct { - Name *string `datastore:"name"` - Count *int64 `datastore:"count"` - } - - name := "test" - count := int64(42) - key := ds9.NameKey("Pointers", "test", nil) - entity := &EntityWithPointers{ - Name: &name, - Count: &count, - } - - // Note: The current implementation doesn't support pointer fields - // This test documents the expected behavior - _, err := client.Put(ctx, key, entity) - if err == nil { - // If it succeeds, that's fine (future enhancement) - var retrieved EntityWithPointers - err = client.Get(ctx, key, &retrieved) - if err != nil { - t.Logf("Get after Put with pointers failed: %v", err) - } - } else { - // Expected to fail with current implementation - t.Logf("Put with pointer fields failed as expected: %v", err) - } -} - -func TestKeyWithOnlyKind(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Key with neither name nor ID should work (incomplete key) - // This gets an ID assigned by the datastore - key := &ds9.Key{Kind: "TestKind"} - entity := &testEntity{Name: "test", Count: 1} - - returnedKey, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put with incomplete key failed: %v", err) - } - - // The returned key should have an ID - if returnedKey == nil { - t.Fatal("expected non-nil returned key") - } - - if returnedKey.Kind != "TestKind" { - t.Errorf("expected Kind 'TestKind', got %q", returnedKey.Kind) - } -} - -func TestTransactionGetNonExistent(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - key := ds9.NameKey("TestKind", "nonexistent", nil) - - _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - var entity testEntity - return tx.Get(key, &entity) - }) - - if !errors.Is(err, ds9.ErrNoSuchEntity) { - t.Errorf("expected ErrNoSuchEntity in transaction, got: %v", err) - } -} - -func TestGetMultiAllMissing(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - keys := []*ds9.Key{ - ds9.NameKey("Missing", "key1", nil), - ds9.NameKey("Missing", "key2", nil), - ds9.NameKey("Missing", "key3", nil), - } - - var entities []testEntity - err := client.GetMulti(ctx, keys, &entities) - - if !errors.Is(err, ds9.ErrNoSuchEntity) { - t.Errorf("expected ErrNoSuchEntity when all keys missing, got: %v", err) - } -} - -func TestGetMultiWithSliceMismatch(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Put entity - key := ds9.NameKey("Test", "key1", nil) - entity := &testEntity{Name: "test", Count: 1} - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - - // GetMulti with destination not being a pointer to slice - var notSlice testEntity - err = client.GetMulti(ctx, []*ds9.Key{key}, notSlice) - if err == nil { - t.Error("expected error when dst is not pointer to slice") - } -} - -func TestPutMultiWithLengthMismatch(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Keys and entities with different lengths - keys := []*ds9.Key{ - ds9.NameKey("Test", "key1", nil), - ds9.NameKey("Test", "key2", nil), - } - entities := []testEntity{ - {Name: "only-one", Count: 1}, - } - - _, err := client.PutMulti(ctx, keys, entities) - if err == nil { - t.Error("expected error when keys and entities have different lengths") - } -} - -func TestDeleteWithNonexistentKey(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Delete non-existent key (should not error) - key := ds9.NameKey("Test", "nonexistent", nil) - err := client.Delete(ctx, key) - if err != nil { - t.Errorf("Delete of non-existent key should not error, got: %v", err) - } -} - -func TestAllKeysWithEmptyResult(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Query kind with no entities - query := ds9.NewQuery("EmptyKind").KeysOnly() - keys, err := client.AllKeys(ctx, query) - if err != nil { - t.Fatalf("AllKeys on empty kind failed: %v", err) - } - - if len(keys) != 0 { - t.Errorf("expected 0 keys, got %d", len(keys)) - } -} - -func TestAllKeysWithLargeResult(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Put many entities - for i := range 50 { - key := ds9.NameKey("LargeResult", fmt.Sprintf("key-%d", i), nil) - entity := &testEntity{Name: "test", Count: int64(i)} - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - } - - // Query all - query := ds9.NewQuery("LargeResult").KeysOnly() - keys, err := client.AllKeys(ctx, query) - if err != nil { - t.Fatalf("AllKeys failed: %v", err) - } - - if len(keys) != 50 { - t.Errorf("expected 50 keys, got %d", len(keys)) - } -} - -func TestQueryWithZeroLimit(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Put entities - for i := range 5 { - key := ds9.NameKey("ZeroLimit", fmt.Sprintf("key-%d", i), nil) - entity := &testEntity{Name: "test", Count: int64(i)} - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - } - - // Query with limit 0 (should return all) - query := ds9.NewQuery("ZeroLimit").KeysOnly().Limit(0) - keys, err := client.AllKeys(ctx, query) - if err != nil { - t.Fatalf("AllKeys with limit 0 failed: %v", err) - } - - // Limit 0 should mean unlimited - if len(keys) == 0 { - t.Error("expected results with limit 0 (unlimited), got 0") - } -} - -func TestPutMultiEmptySlice(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Empty slices - _, err := client.PutMulti(ctx, []*ds9.Key{}, []testEntity{}) - if err == nil { - t.Error("expected error for empty slices") - } -} - -func TestGetMultiEmptySlice(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - var entities []testEntity - err := client.GetMulti(ctx, []*ds9.Key{}, &entities) - if err == nil { - t.Error("expected error for empty keys") - } -} - -func TestDeleteMultiEmptySlice(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - err := client.DeleteMulti(ctx, []*ds9.Key{}) - if err == nil { - t.Error("expected error for empty keys") - } -} - -func TestTransactionPutWithNilKey(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - entity := &testEntity{Name: "test", Count: 1} - _, err := tx.Put(nil, entity) - return err - }) - - if err == nil { - t.Error("expected error for nil key in transaction") - } -} - -func TestTransactionGetWithNilKey(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - var entity testEntity - return tx.Get(nil, &entity) - }) - - if err == nil { - t.Error("expected error for nil key in transaction Get") - } -} - -func TestDeepHierarchicalKeys(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Create 4-level hierarchy - gp := ds9.NameKey("GP", "gp1", nil) - p := ds9.NameKey("P", "p1", gp) - c := ds9.NameKey("C", "c1", p) - gc := ds9.NameKey("GC", "gc1", c) - - entity := &testEntity{Name: "great-grandchild", Count: 42} - _, err := client.Put(ctx, gc, entity) - if err != nil { - t.Fatalf("Put with 4-level hierarchy failed: %v", err) - } - - var retrieved testEntity - err = client.Get(ctx, gc, &retrieved) - if err != nil { - t.Fatalf("Get with 4-level hierarchy failed: %v", err) - } - - if retrieved.Name != "great-grandchild" { - t.Errorf("expected name 'great-grandchild', got %q", retrieved.Name) - } -} - -func TestEntityWithEmptyStringFields(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - key := ds9.NameKey("Empty", "test", nil) - entity := &testEntity{ - Name: "", // empty string - Count: 0, // zero - Active: false, // false - Score: 0.0, // zero float - } - - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put with empty/zero values failed: %v", err) - } - - var retrieved testEntity - err = client.Get(ctx, key, &retrieved) - if err != nil { - t.Fatalf("Get failed: %v", err) - } - - if retrieved.Name != "" { - t.Errorf("expected empty string, got %q", retrieved.Name) - } - if retrieved.Count != 0 { - t.Errorf("expected 0, got %d", retrieved.Count) - } -} - -func TestGetWithNonPointerDst(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Put entity - key := ds9.NameKey("Test", "key", nil) - entity := &testEntity{Name: "test", Count: 1} - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - - // Try to get into non-pointer - var notPointer testEntity - err = client.Get(ctx, key, notPointer) // Should be ¬Pointer - if err == nil { - t.Error("expected error when dst is not a pointer") - } -} - -func TestPutWithNonPointerEntity(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - key := ds9.NameKey("Test", "key", nil) - entity := testEntity{Name: "test", Count: 1} // not a pointer - - // The mock implementation may accept non-pointers, but test with the real client - // For now, just test that it works (real Datastore would require pointer) - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Logf("Put with non-pointer entity failed (expected with real client): %v", err) - } -} - -func TestDeleteAllByKindWithNoEntities(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Delete from kind with no entities - err := client.DeleteAllByKind(ctx, "NonExistentKind") - if err != nil { - t.Errorf("DeleteAllByKind on empty kind should not error, got: %v", err) - } -} - -func TestDeleteAllByKindWithManyEntities(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Put many entities - for i := range 25 { - key := ds9.NameKey("ManyDelete", fmt.Sprintf("key-%d", i), nil) - entity := &testEntity{Name: "test", Count: int64(i)} - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - } - - // Delete all - err := client.DeleteAllByKind(ctx, "ManyDelete") - if err != nil { - t.Fatalf("DeleteAllByKind failed: %v", err) - } - - // Verify all deleted - query := ds9.NewQuery("ManyDelete").KeysOnly() - keys, err := client.AllKeys(ctx, query) - if err != nil { - t.Fatalf("AllKeys failed: %v", err) - } - - if len(keys) != 0 { - t.Errorf("expected 0 keys after DeleteAllByKind, got %d", len(keys)) - } -} - -func TestTransactionWithMultiplePuts(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - for i := range 5 { - key := ds9.NameKey("TxMulti", fmt.Sprintf("key-%d", i), nil) - entity := &testEntity{Name: "test", Count: int64(i)} - _, err := tx.Put(key, entity) - if err != nil { - return err - } - } - return nil - }) - if err != nil { - t.Fatalf("Transaction with multiple puts failed: %v", err) - } - - // Verify all entities were created - for i := range 5 { - key := ds9.NameKey("TxMulti", fmt.Sprintf("key-%d", i), nil) - var retrieved testEntity - err = client.Get(ctx, key, &retrieved) - if err != nil { - t.Errorf("Get for entity %d failed: %v", i, err) - } - if retrieved.Count != int64(i) { - t.Errorf("entity %d: expected Count %d, got %d", i, i, retrieved.Count) - } - } -} - -func TestIDKeyWithZeroID(t *testing.T) { - // Zero ID is valid - key := ds9.IDKey("Test", 0, nil) - if key.ID != 0 { - t.Errorf("expected ID 0, got %d", key.ID) - } - if key.Name != "" { - t.Errorf("expected empty Name, got %q", key.Name) - } -} - -func TestNameKeyWithEmptyName(t *testing.T) { - // Empty name is technically valid - key := ds9.NameKey("Test", "", nil) - if key.Name != "" { - t.Errorf("expected empty Name, got %q", key.Name) - } - if key.ID != 0 { - t.Errorf("expected ID 0, got %d", key.ID) - } -} - -func TestDoRequestUnexpectedSuccess(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Return unexpected 2xx status (not 200) - w.WriteHeader(http.StatusAccepted) // 202 - if _, err := w.Write([]byte(`{"message":"accepted"}`)); err != nil { - t.Logf("write failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClient(ctx, "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - key := ds9.NameKey("Test", "key", nil) - entity := &testEntity{Name: "test", Count: 1} - _, err = client.Put(ctx, key, entity) - - if err == nil { - t.Error("expected error for unexpected 2xx status") - } - - if !strings.Contains(err.Error(), "202") { - t.Errorf("expected error to mention 202 status, got: %v", err) - } -} - -func TestGetMultiWithNonSliceDst(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - keys := []*ds9.Key{ - ds9.NameKey("Test", "key1", nil), - } - - // Pass a non-slice as destination - var notSlice string - err := client.GetMulti(ctx, keys, ¬Slice) - - if err == nil { - t.Error("expected error when dst is not a slice") - } -} - -func TestPutMultiWithNonSliceSrc(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - keys := []*ds9.Key{ - ds9.NameKey("Test", "key1", nil), - } - - // Pass a non-slice as source - notSlice := "not a slice" - _, err := client.PutMulti(ctx, keys, notSlice) - - if err == nil { - t.Error("expected error when src is not a slice") - } -} - -func TestAllKeysQueryWithoutKeysOnly(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Create query without KeysOnly - query := ds9.NewQuery("Test") - - _, err := client.AllKeys(ctx, query) - - if err == nil { - t.Error("expected error for query without KeysOnly") - } - - if !strings.Contains(err.Error(), "KeysOnly") { - t.Errorf("expected error to mention KeysOnly, got: %v", err) - } -} - -func TestDeleteAllByKindQueryFailure(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Fail on query request - if strings.Contains(r.URL.Path, "runQuery") { - w.WriteHeader(http.StatusInternalServerError) - if _, err := w.Write([]byte(`{"error":"query failed"}`)); err != nil { - t.Logf("write failed: %v", err) - } - return - } - w.WriteHeader(http.StatusOK) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClient(ctx, "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - err = client.DeleteAllByKind(ctx, "TestKind") - - if err == nil { - t.Error("expected error when query fails") - } -} - -func TestTransactionGetWithInvalidResponse(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, "beginTransaction") { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{ - "transaction": "test-tx", - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - - if strings.Contains(r.URL.Path, "lookup") { - // Return invalid JSON structure - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte(`{"invalid":"structure"}`)); err != nil { - t.Logf("write failed: %v", err) - } - return - } - - if strings.Contains(r.URL.Path, "commit") { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{ - "mutationResults": []any{}, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - - w.WriteHeader(http.StatusOK) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClient(ctx, "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - key := ds9.NameKey("Test", "key", nil) - _, err = client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - var entity testEntity - return tx.Get(key, &entity) - }) - - // Should handle the invalid response gracefully - if err == nil { - t.Log("Transaction succeeded despite invalid lookup response") - } -} - -func TestGetWithInvalidJSONResponse(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Return invalid JSON - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte(`{invalid json`)); err != nil { - t.Logf("write failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClient(ctx, "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - key := ds9.NameKey("Test", "key", nil) - var entity testEntity - err = client.Get(ctx, key, &entity) - - if err == nil { - t.Error("expected error for invalid JSON response") - } -} - -func TestPutWithInvalidEntityStructure(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Entity with channel (unsupported type) - type BadEntity struct { - Ch chan int - Name string - } - - key := ds9.NameKey("Test", "bad", nil) - entity := &BadEntity{ - Name: "test", - Ch: make(chan int), - } - - _, err := client.Put(ctx, key, entity) - - if err == nil { - t.Error("expected error for unsupported entity type") - } -} - -func TestGetMultiWithNilInResults(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Put one entity - key1 := ds9.NameKey("Test", "exists", nil) - entity := &testEntity{Name: "test", Count: 1} - _, err := client.Put(ctx, key1, entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - - // Try to get multiple with one missing - keys := []*ds9.Key{ - key1, - ds9.NameKey("Test", "missing", nil), - ds9.NameKey("Test", "missing2", nil), - } - - var entities []testEntity - err = client.GetMulti(ctx, keys, &entities) - - if !errors.Is(err, ds9.ErrNoSuchEntity) { - t.Errorf("expected ErrNoSuchEntity when some keys missing, got: %v", err) - } -} - -func TestDeleteMultiPartialSuccess(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Put some entities - keys := []*ds9.Key{ - ds9.NameKey("Test", "key1", nil), - ds9.NameKey("Test", "key2", nil), - } - - entities := []testEntity{ - {Name: "entity1", Count: 1}, - {Name: "entity2", Count: 2}, - } - - _, err := client.PutMulti(ctx, keys, entities) - if err != nil { - t.Fatalf("PutMulti failed: %v", err) - } - - // Delete them (should succeed) - err = client.DeleteMulti(ctx, keys) - if err != nil { - t.Fatalf("DeleteMulti failed: %v", err) - } - - // Verify deletion - var retrieved []testEntity - err = client.GetMulti(ctx, keys, &retrieved) - if !errors.Is(err, ds9.ErrNoSuchEntity) { - t.Errorf("expected ErrNoSuchEntity after delete, got: %v", err) - } -} - -func TestQueryWithVeryLargeLimit(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Put a few entities - for i := range 3 { - key := ds9.NameKey("LargeLimit", fmt.Sprintf("key-%d", i), nil) - entity := &testEntity{Name: "test", Count: int64(i)} - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - } - - // Query with very large limit - query := ds9.NewQuery("LargeLimit").KeysOnly().Limit(10000) - keys, err := client.AllKeys(ctx, query) - if err != nil { - t.Fatalf("AllKeys with large limit failed: %v", err) - } - - // Should return all 3 - if len(keys) != 3 { - t.Errorf("expected 3 keys, got %d", len(keys)) - } -} - -func TestDeleteWithServerError(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - attemptCount := 0 - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - attemptCount++ - // Always return 503 - w.WriteHeader(http.StatusServiceUnavailable) - if _, err := w.Write([]byte(`{"error":"unavailable"}`)); err != nil { - t.Logf("write failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClient(ctx, "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - key := ds9.NameKey("Test", "key", nil) - err = client.Delete(ctx, key) - - if err == nil { - t.Error("expected error on persistent server failure") - } - - // Should have retried - if attemptCount < 2 { - t.Errorf("expected multiple attempts, got %d", attemptCount) - } -} - -func TestPutMultiWithServerError(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadRequest) - if _, err := w.Write([]byte(`{"error":"bad request"}`)); err != nil { - t.Logf("write failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClient(ctx, "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - keys := []*ds9.Key{ - ds9.NameKey("Test", "key1", nil), - ds9.NameKey("Test", "key2", nil), - } - - entities := []testEntity{ - {Name: "entity1", Count: 1}, - {Name: "entity2", Count: 2}, - } - - _, err = client.PutMulti(ctx, keys, entities) - - if err == nil { - t.Error("expected error on server failure") - } -} - -func TestGetMultiWithServerError(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusUnauthorized) - if _, err := w.Write([]byte(`{"error":"unauthorized"}`)); err != nil { - t.Logf("write failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClient(ctx, "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - keys := []*ds9.Key{ - ds9.NameKey("Test", "key1", nil), - } - - var entities []testEntity - err = client.GetMulti(ctx, keys, &entities) - - if err == nil { - t.Error("expected error on unauthorized") - } - - if !strings.Contains(err.Error(), "401") { - t.Errorf("expected 401 error, got: %v", err) - } -} - -func TestAllKeysWithInvalidResponse(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Return invalid JSON - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte(`{malformed`)); err != nil { - t.Logf("write failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClient(ctx, "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - query := ds9.NewQuery("Test").KeysOnly() - _, err = client.AllKeys(ctx, query) - - if err == nil { - t.Error("expected error for invalid JSON") - } -} - -func TestTransactionWithNonRetriableError(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - commitAttempts := 0 - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, "beginTransaction") { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{ - "transaction": "test-tx", - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - - if strings.Contains(r.URL.Path, "commit") { - commitAttempts++ - // Return non-retriable error (not 409 ABORTED) - w.WriteHeader(http.StatusBadRequest) - if _, err := w.Write([]byte(`{"error":"INVALID_ARGUMENT"}`)); err != nil { - t.Logf("write failed: %v", err) - } - return - } - - if strings.Contains(r.URL.Path, "lookup") { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{ - "found": []any{}, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - - w.WriteHeader(http.StatusOK) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClient(ctx, "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - key := ds9.NameKey("Test", "key", nil) - _, err = client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - _, err := tx.Put(key, &testEntity{Name: "test", Count: 1}) - return err - }) - - if err == nil { - t.Error("expected error on non-retriable failure") - } - - // Should NOT retry on non-409 errors - if commitAttempts != 1 { - t.Errorf("expected exactly 1 commit attempt for non-retriable error, got %d", commitAttempts) - } -} - -func TestTransactionWithInvalidTxResponse(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, "beginTransaction") { - // Return invalid JSON - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte(`{bad json`)); err != nil { - t.Logf("write failed: %v", err) - } - return - } - w.WriteHeader(http.StatusOK) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClient(ctx, "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - _, err = client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - return nil - }) - - if err == nil { - t.Error("expected error for invalid transaction response") - } -} - -func TestTransactionGetWithDecodeError(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, "beginTransaction") { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{ - "transaction": "test-tx", - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - - if strings.Contains(r.URL.Path, "lookup") { - // Return entity with malformed data - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{ - "found": []any{ - map[string]any{ - "entity": map[string]any{ - "key": map[string]any{ - "path": []any{ - map[string]any{ - "kind": "Test", - "name": "key", - }, - }, - }, - "properties": map[string]any{ - "name": map[string]any{ - "stringValue": 12345, // Wrong type - }, - }, - }, - }, - }, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - - if strings.Contains(r.URL.Path, "commit") { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{ - "mutationResults": []any{}, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - - w.WriteHeader(http.StatusOK) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClient(ctx, "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - key := ds9.NameKey("Test", "key", nil) - _, err = client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - var entity testEntity - return tx.Get(key, &entity) - }) - // May succeed or fail depending on how decoding handles type mismatches - if err != nil { - t.Logf("Transaction Get with decode error: %v", err) - } -} - -func TestDoRequestWithReadBodyError(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Set content length but don't write enough data - w.Header().Set("Content-Length", "1000000") - w.WriteHeader(http.StatusOK) - // Write partial data then close connection - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - ctx := context.Background() - client, err := ds9.NewClient(ctx, "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - key := ds9.NameKey("Test", "key", nil) - entity := &testEntity{Name: "test", Count: 1} - _, err = client.Put(ctx, key, entity) - // Should get an error related to response parsing - if err != nil { - t.Logf("Got expected error with incomplete response: %v", err) - } -} - -func TestPutMultiWithPartialEncode(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Mix of valid and invalid entities - type MixedEntity struct { - Data any - Name string - } - - keys := []*ds9.Key{ - ds9.NameKey("Test", "key1", nil), - ds9.NameKey("Test", "key2", nil), - } - - entities := []MixedEntity{ - {Name: "valid", Data: "string"}, - {Name: "maybe-invalid", Data: make(chan int)}, // channels unsupported - } - - _, err := client.PutMulti(ctx, keys, entities) - - if err == nil { - t.Log("PutMulti with mixed entities succeeded (mock may not validate types)") - } else { - t.Logf("PutMulti with mixed entities failed as expected: %v", err) - } -} - -func TestDeleteWithContextCancellation(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Slow response - time.Sleep(100 * time.Millisecond) - w.WriteHeader(http.StatusServiceUnavailable) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx, cancel := context.WithCancel(context.Background()) - cancel() // Cancel immediately - - key := ds9.NameKey("Test", "key", nil) - err = client.Delete(ctx, key) - - if err == nil { - t.Error("expected error when context is cancelled") - } -} - -// Tests for keyFromJSON with invalid path elements -func TestKeyFromJSONInvalidPathElement(t *testing.T) { - // Test with non-map path element - keyData := map[string]any{ - "path": []any{ - "invalid-string-instead-of-map", - }, - } - - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":commit") { - // Return response with invalid key in mutation result - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "mutationResults": []map[string]any{ - { - "key": keyData, - }, - }, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - realClient, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - key := ds9.NameKey("Test", "key", nil) - entity := &testEntity{Name: "test"} - - // Try Put which will parse the returned key - _, err = realClient.Put(ctx, key, entity) - if err == nil { - t.Log("Put succeeded despite invalid path element (API may handle gracefully)") - } else { - t.Logf("Put failed as expected: %v", err) - } -} - -// Test keyFromJSON with ID as string that fails parsing -func TestKeyFromJSONInvalidIDString(t *testing.T) { - keyData := map[string]any{ - "path": []any{ - map[string]any{ - "kind": "Test", - "id": "not-a-number", - }, - }, - } - - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":commit") { - // Return response with invalid ID string in mutation result - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "mutationResults": []map[string]any{ - { - "key": keyData, - }, - }, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - realClient, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - key := ds9.NameKey("Test", "key", nil) - entity := &testEntity{Name: "test"} - - // Try Put which will parse the returned key - _, err = realClient.Put(ctx, key, entity) - if err == nil { - t.Log("Put succeeded despite invalid ID string (API may handle gracefully)") - } else { - t.Logf("Put failed as expected: %v", err) - } -} - -// Test keyFromJSON with ID as float64 -func TestKeyFromJSONIDAsFloat(t *testing.T) { - keyData := map[string]any{ - "path": []any{ - map[string]any{ - "kind": "Test", - "id": float64(12345), - }, - }, - } - - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":lookup") { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "found": []map[string]any{ - { - "entity": map[string]any{ - "key": keyData, - "properties": map[string]any{ - "name": map[string]any{"stringValue": "test"}, - }, - }, - }, - }, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - realClient, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - key := ds9.NameKey("Test", "key", nil) - var entity testEntity - - err = realClient.Get(ctx, key, &entity) - if err != nil { - t.Errorf("unexpected error with float64 ID: %v", err) - } -} - -// Test Transaction.Get with missing entity -func TestTransactionGetMissingEntity(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":beginTransaction") { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "transaction": "test-tx-id", - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - if strings.Contains(r.URL.Path, ":lookup") { - // Return empty found array (entity not found) - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "found": []any{}, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - if strings.Contains(r.URL.Path, ":commit") { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "mutationResults": []any{}, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - key := ds9.NameKey("Test", "nonexistent", nil) - - _, err = client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - var entity testEntity - err := tx.Get(key, &entity) - if err == nil { - return errors.New("expected error for missing entity") - } - return nil - }) - if err != nil { - t.Errorf("transaction should succeed even with get error: %v", err) - } -} - -// Test Transaction.Get with decode error -func TestTransactionGetDecodeError(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":beginTransaction") { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "transaction": "test-tx-id", - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - if strings.Contains(r.URL.Path, ":lookup") { - // Return malformed entity - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "found": []map[string]any{ - { - "entity": "invalid-not-a-map", - }, - }, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - if strings.Contains(r.URL.Path, ":commit") { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "mutationResults": []any{}, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - key := ds9.NameKey("Test", "key", nil) - - _, err = client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - var entity testEntity - err := tx.Get(key, &entity) - if err == nil { - return errors.New("expected decode error") - } - return nil - }) - if err != nil { - t.Errorf("transaction should succeed: %v", err) - } -} - -// Test Delete with multiple retries exhausted -func TestDeleteAllRetriesFail(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - requestCount := 0 - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestCount++ - // Always return 503 to force retries - w.WriteHeader(http.StatusServiceUnavailable) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - key := ds9.NameKey("Test", "key", nil) - - err = client.Delete(ctx, key) - if err == nil { - t.Error("expected error after all retries exhausted") - } - - if !strings.Contains(err.Error(), "attempts") { - t.Errorf("expected error message about attempts, got: %v", err) - } - - if requestCount != 3 { - t.Errorf("expected 3 retry attempts, got %d", requestCount) - } -} - -// Test Client.Get with decode error -func TestGetWithDecodeError(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":lookup") { - // Return entity with missing properties field - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "found": []map[string]any{ - { - "entity": map[string]any{ - "key": map[string]any{ - "path": []any{ - map[string]any{ - "kind": "Test", - "name": "key", - }, - }, - }, - // Missing properties field - }, - }, - }, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - key := ds9.NameKey("Test", "key", nil) - var entity testEntity - - err = client.Get(ctx, key, &entity) - if err == nil { - t.Error("expected error with missing properties") - } -} - -// Test Put with invalid entity causing encode error -func TestPutWithEncodeError(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Create entity with unsupported type - type BadEntity struct { - Channel chan int `datastore:"channel"` - } - - key := ds9.NameKey("Test", "key", nil) - entity := &BadEntity{Channel: make(chan int)} - - _, err := client.Put(ctx, key, entity) - if err == nil { - t.Log("Put with unsupported type succeeded (mock may not validate types)") - } else { - t.Logf("Put with unsupported type failed as expected: %v", err) - } -} - -// Test GetMulti with some entities not found -func TestGetMultiPartialNotFound(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":lookup") { - // Return one found, one missing - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "found": []map[string]any{ - { - "entity": map[string]any{ - "key": map[string]any{ - "path": []any{ - map[string]any{ - "kind": "Test", - "name": "key1", - }, - }, - }, - "properties": map[string]any{ - "name": map[string]any{"stringValue": "test1"}, - }, - }, - }, - }, - "missing": []map[string]any{ - { - "key": map[string]any{ - "path": []any{ - map[string]any{ - "kind": "Test", - "name": "key2", - }, - }, - }, - }, - }, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - keys := []*ds9.Key{ - ds9.NameKey("Test", "key1", nil), - ds9.NameKey("Test", "key2", nil), - } - - var entities []testEntity - err = client.GetMulti(ctx, keys, &entities) - if err == nil { - t.Error("expected error when some entities are missing") - } else { - t.Logf("GetMulti with missing entities failed as expected: %v", err) - } -} - -// Test AllKeys with invalid JSON response -func TestAllKeysInvalidJSON(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":runQuery") { - // Return invalid JSON - w.Header().Set("Content-Type", "application/json") - if _, err := w.Write([]byte("{")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - query := ds9.NewQuery("Test").KeysOnly() - - _, err = client.AllKeys(ctx, query) - if err == nil { - t.Error("expected error with invalid JSON") - } -} - -// Test Transaction commit with invalid response -func TestTransactionCommitInvalidResponse(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":beginTransaction") { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "transaction": "test-tx-id", - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - if strings.Contains(r.URL.Path, ":commit") { - // Return invalid JSON (missing mutationResults) - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - // Missing mutationResults field - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - key := ds9.NameKey("Test", "key", nil) - entity := &testEntity{Name: "test"} - - _, err = client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - _, err := tx.Put(key, entity) - return err - }) - if err != nil { - t.Logf("Transaction with invalid commit response failed: %v", err) - } -} - -// Test PutMulti with encode errors in entities -func TestPutMultiWithInvalidEntities(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - type InvalidEntity struct { - Func func() `datastore:"func"` - } - - keys := []*ds9.Key{ - ds9.NameKey("Test", "key1", nil), - } - - entities := []InvalidEntity{ - {Func: func() {}}, - } - - _, err := client.PutMulti(ctx, keys, entities) - if err == nil { - t.Log("PutMulti with func field succeeded (mock may not validate types)") - } else { - t.Logf("PutMulti with func field failed as expected: %v", err) - } -} - -// Test decodeValue with invalid integer format -func TestDecodeValueInvalidInteger(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":lookup") { - // Return entity with invalid integer format - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "found": []map[string]any{ - { - "entity": map[string]any{ - "key": map[string]any{ - "path": []any{ - map[string]any{ - "kind": "Test", - "name": "key", - }, - }, - }, - "properties": map[string]any{ - "count": map[string]any{"integerValue": "not-an-integer"}, - }, - }, - }, - }, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - key := ds9.NameKey("Test", "key", nil) - var entity testEntity - - err = client.Get(ctx, key, &entity) - if err == nil { - t.Error("expected error with invalid integer format") - } else { - t.Logf("Got expected error: %v", err) - } -} - -// Test decodeValue with wrong type for integer -func TestDecodeValueWrongTypeForInteger(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":lookup") { - // Return entity with integer value but string field type - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "found": []map[string]any{ - { - "entity": map[string]any{ - "key": map[string]any{ - "path": []any{ - map[string]any{ - "kind": "Test", - "name": "key", - }, - }, - }, - "properties": map[string]any{ - "name": map[string]any{"integerValue": "12345"}, // integer for string field - }, - }, - }, - }, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - key := ds9.NameKey("Test", "key", nil) - var entity testEntity - - err = client.Get(ctx, key, &entity) - if err == nil { - t.Error("expected error with wrong type for integer") - } else { - t.Logf("Got expected error: %v", err) - } -} - -// Test decodeValue with invalid timestamp format -func TestDecodeValueInvalidTimestamp(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":lookup") { - // Return entity with invalid timestamp format - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "found": []map[string]any{ - { - "entity": map[string]any{ - "key": map[string]any{ - "path": []any{ - map[string]any{ - "kind": "Test", - "name": "key", - }, - }, - }, - "properties": map[string]any{ - "updated_at": map[string]any{"timestampValue": "invalid-timestamp"}, - }, - }, - }, - }, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - key := ds9.NameKey("Test", "key", nil) - var entity testEntity - - err = client.Get(ctx, key, &entity) - if err == nil { - t.Error("expected error with invalid timestamp format") - } else { - t.Logf("Got expected error: %v", err) - } -} - -// Test Client.Get with non-pointer destination -func TestGetWithNonPointer(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - key := ds9.NameKey("Test", "key", nil) - var entity testEntity // non-pointer - - err := client.Get(ctx, key, entity) // Pass by value - if err == nil { - t.Error("expected error when dst is not a pointer") - } -} - -// Test Client.Put with non-struct -func TestPutWithNonStruct(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - key := ds9.NameKey("Test", "key", nil) - entity := "not a struct" - - _, err := client.Put(ctx, key, entity) - if err == nil { - t.Error("expected error when entity is not a struct") - } -} - -// Test AllKeys with non-KeysOnly query error handling -func TestAllKeysNotKeysOnlyError(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - query := ds9.NewQuery("Test") // Not KeysOnly - - _, err := client.AllKeys(ctx, query) - if err == nil { - t.Error("expected error when query is not KeysOnly") - } -} - -// Test GetMulti with mismatched keys and entities length -func TestGetMultiMismatchedLength(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - keys := []*ds9.Key{ - ds9.NameKey("Test", "key1", nil), - ds9.NameKey("Test", "key2", nil), - } - - var entities []testEntity // Empty slice - - err := client.GetMulti(ctx, keys, &entities) - // This should work - GetMulti should populate the slice - if err != nil { - t.Logf("GetMulti with empty slice: %v", err) - } -} - -// Test PutMulti with mismatched keys and entities length -func TestPutMultiMismatchedLength(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - keys := []*ds9.Key{ - ds9.NameKey("Test", "key1", nil), - ds9.NameKey("Test", "key2", nil), - } - - entities := []testEntity{ - {Name: "test1"}, - // Missing second entity - } - - _, err := client.PutMulti(ctx, keys, entities) - if err == nil { - t.Error("expected error with mismatched lengths") - } -} - -// Test DeleteMulti with empty keys slice -func TestDeleteMultiWithEmptyKeysSlice(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - var keys []*ds9.Key // Empty - - err := client.DeleteMulti(ctx, keys) - // Mock may behave differently - log the result - if err != nil { - t.Logf("DeleteMulti with empty keys: %v", err) - } -} - -// Test Client.Get with JSON unmarshal error for found entities -func TestGetWithJSONUnmarshalError(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":lookup") { - // Return invalid JSON - w.Header().Set("Content-Type", "application/json") - if _, err := w.Write([]byte(`{"found": [{"entity": "not-an-object"}]}`)); err != nil { - t.Logf("write failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - key := ds9.NameKey("Test", "key", nil) - var entity testEntity - - err = client.Get(ctx, key, &entity) - if err == nil { - t.Error("expected error with invalid entity format") - } -} - -// Test Client.Put with access token error -func TestPutWithAccessTokenError(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - // Always return error for token - w.WriteHeader(http.StatusInternalServerError) - })) - defer metadataServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, "http://unused") - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - key := ds9.NameKey("Test", "key", nil) - entity := &testEntity{Name: "test"} - - _, err = client.Put(ctx, key, entity) - if err == nil { - t.Error("expected error when access token fails") - } -} - -// Test Client.Delete with JSON marshal error -func TestDeleteWithJSONMarshalError(t *testing.T) { - // This is hard to trigger since we control the JSON structure - // But we can test with a context that gets cancelled - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{}); err != nil { - t.Logf("encode failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - key := ds9.NameKey("Test", "key", nil) - - err = client.Delete(ctx, key) - if err != nil { - t.Logf("Delete completed with: %v", err) - } -} - -// Test GetMulti with decode error for specific entity -func TestGetMultiDecodeError(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":lookup") { - // Return one good entity and one with decode error - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "found": []map[string]any{ - { - "entity": map[string]any{ - "key": map[string]any{ - "path": []any{ - map[string]any{ - "kind": "Test", - "name": "key1", - }, - }, - }, - "properties": map[string]any{ - "name": map[string]any{"stringValue": "test"}, - }, - }, - }, - { - "entity": "invalid", // This will cause decode error - }, - }, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - keys := []*ds9.Key{ - ds9.NameKey("Test", "key1", nil), - ds9.NameKey("Test", "key2", nil), - } - - var entities []testEntity - err = client.GetMulti(ctx, keys, &entities) - if err == nil { - t.Error("expected error when one entity has decode error") - } -} - -// Test AllKeys with batch batching (many results) -func TestAllKeysWithBatching(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":runQuery") { - // Return multiple key results - results := make([]map[string]any, 50) - for i := range 50 { - results[i] = map[string]any{ - "entity": map[string]any{ - "key": map[string]any{ - "path": []any{ - map[string]any{ - "kind": "Test", - "name": fmt.Sprintf("key%d", i), - }, - }, - }, - }, - } - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "batch": map[string]any{ - "entityResults": results, - }, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - query := ds9.NewQuery("Test").KeysOnly() - - keys, err := client.AllKeys(ctx, query) - if err != nil { - t.Logf("AllKeys with many results: %v", err) - } else if len(keys) != 50 { - t.Logf("Expected 50 keys, got %d", len(keys)) - } -} - -// Test AllKeys with keyFromJSON error -func TestAllKeysKeyFromJSONError(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":runQuery") { - // Return result with invalid key format - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "batch": map[string]any{ - "entityResults": []map[string]any{ - { - "entity": map[string]any{ - "key": "not-a-map", // Invalid key format - }, - }, - }, - }, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - query := ds9.NewQuery("Test").KeysOnly() - - _, err = client.AllKeys(ctx, query) - if err == nil { - t.Error("expected error with invalid key format") - } -} - -// Test PutMulti with JSON marshal error for request body -func TestPutMultiRequestMarshalError(t *testing.T) { - // This is hard to trigger directly, but we can test with encoding errors - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - if err := json.NewEncoder(w).Encode(map[string]any{ - "mutationResults": []any{}, - }); err != nil { - t.Logf("encode failed: %v", err) - } - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - - // Test with valid entities to exercise the code path - keys := []*ds9.Key{ - ds9.NameKey("Test", "key1", nil), - } - - entities := []testEntity{ - {Name: "test1", Count: 123}, - } - - _, err = client.PutMulti(ctx, keys, entities) - if err != nil { - t.Logf("PutMulti completed with: %v", err) - } -} - -// Test Transaction commit with JSON unmarshal error -func TestTransactionCommitUnmarshalError(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":beginTransaction") { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "transaction": "test-tx-id", - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - if strings.Contains(r.URL.Path, ":commit") { - // Return malformed mutation results - w.Header().Set("Content-Type", "application/json") - if _, err := w.Write([]byte(`{"mutationResults": "not-an-array"}`)); err != nil { - t.Logf("write failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - key := ds9.NameKey("Test", "key", nil) - entity := &testEntity{Name: "test"} - - _, err = client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - _, err := tx.Put(key, entity) - return err - }) - // May or may not error depending on JSON parsing behavior - if err != nil { - t.Logf("Transaction with malformed mutation results failed: %v", err) - } -} - -// Test DeleteAllByKind with empty batch response -func TestDeleteAllByKindEmptyBatch(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":runQuery") { - // Return empty batch - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "batch": map[string]any{}, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - err = client.DeleteAllByKind(ctx, "EmptyKind") - if err != nil { - t.Logf("DeleteAllByKind with empty batch: %v", err) - } -} - -// Test AllKeys with empty path in key -func TestAllKeysEmptyPathInKey(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":runQuery") { - w.Header().Set("Content-Type", "application/json") - // Return key with empty path array - if err := json.NewEncoder(w).Encode(map[string]any{ - "batch": map[string]any{ - "entityResults": []map[string]any{ - { - "entity": map[string]any{ - "key": map[string]any{ - "path": []any{}, // Empty path - }, - }, - }, - }, - }, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - query := ds9.NewQuery("TestKind").KeysOnly() - _, err = client.AllKeys(ctx, query) - if err == nil { - t.Error("expected error with empty path in key") - } -} - -// Test AllKeys with invalid path element (not a map) -func TestAllKeysInvalidPathElement(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":runQuery") { - w.Header().Set("Content-Type", "application/json") - // Return key with invalid path element (string instead of map) - if err := json.NewEncoder(w).Encode(map[string]any{ - "batch": map[string]any{ - "entityResults": []map[string]any{ - { - "entity": map[string]any{ - "key": map[string]any{ - "path": []any{"invalid-element"}, // String instead of map - }, - }, - }, - }, - }, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - query := ds9.NewQuery("TestKind").KeysOnly() - _, err = client.AllKeys(ctx, query) - if err == nil { - t.Error("expected error with invalid path element") - } -} - -// Test Get with ID key as string -func TestGetWithStringIDKey(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":lookup") { - w.Header().Set("Content-Type", "application/json") - // Return entity with ID as string - if err := json.NewEncoder(w).Encode(map[string]any{ - "found": []map[string]any{ - { - "entity": map[string]any{ - "key": map[string]any{ - "path": []any{ - map[string]any{ - "kind": "TestKind", - "id": "12345", // ID as string - }, - }, - }, - "properties": map[string]any{ - "name": map[string]any{"stringValue": "test"}, - }, - }, - }, - }, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - type TestEntity struct { - Name string `datastore:"name"` - } - - ctx := context.Background() - key := ds9.IDKey("TestKind", 12345, nil) - var entity TestEntity - err = client.Get(ctx, key, &entity) - if err != nil { - t.Fatalf("Get with string ID key failed: %v", err) - } - - if entity.Name != "test" { - t.Errorf("expected name 'test', got %q", entity.Name) - } -} - -// Test Get with ID key as float64 -func TestGetWithFloat64IDKey(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":lookup") { - w.Header().Set("Content-Type", "application/json") - // Return entity with ID as float64 - if err := json.NewEncoder(w).Encode(map[string]any{ - "found": []map[string]any{ - { - "entity": map[string]any{ - "key": map[string]any{ - "path": []any{ - map[string]any{ - "kind": "TestKind", - "id": float64(67890), // ID as float64 - }, - }, - }, - "properties": map[string]any{ - "value": map[string]any{"integerValue": "42"}, - }, - }, - }, - }, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - type TestEntity struct { - Value int64 `datastore:"value"` - } - - ctx := context.Background() - key := ds9.IDKey("TestKind", 67890, nil) - var entity TestEntity - err = client.Get(ctx, key, &entity) - if err != nil { - t.Fatalf("Get with float64 ID key failed: %v", err) - } - - if entity.Value != 42 { - t.Errorf("expected value 42, got %d", entity.Value) - } -} - -// Test Get with invalid string ID format in response -func TestGetWithInvalidStringIDFormat(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":lookup") { - w.Header().Set("Content-Type", "application/json") - // Return entity with invalid ID string format - if err := json.NewEncoder(w).Encode(map[string]any{ - "found": []map[string]any{ - { - "entity": map[string]any{ - "key": map[string]any{ - "path": []any{ - map[string]any{ - "kind": "TestKind", - "id": "not-a-number", // Invalid ID format - }, - }, - }, - "properties": map[string]any{ - "name": map[string]any{"stringValue": "test"}, - }, - }, - }, - }, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - type TestEntity struct { - Name string `datastore:"name"` - } - - ctx := context.Background() - key := ds9.IDKey("TestKind", 12345, nil) - var entity TestEntity - err = client.Get(ctx, key, &entity) - // May or may not error depending on parsing behavior - if err != nil { - t.Logf("Get with invalid string ID format failed: %v", err) - } else { - t.Logf("Get with invalid string ID format succeeded unexpectedly") - } -} - -// Test Transaction.Get with no entity found -func TestTransactionGetNotFound(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":beginTransaction") { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "transaction": "test-tx-id", - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - if strings.Contains(r.URL.Path, ":lookup") { - // Return empty found array - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "found": []any{}, - "missing": []any{}, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - if strings.Contains(r.URL.Path, ":commit") { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "mutationResults": []any{}, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - key := ds9.NameKey("Test", "nonexistent", nil) - - _, err = client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - var entity testEntity - err := tx.Get(key, &entity) - if err == nil { - t.Error("expected error with empty found array") - } - return nil - }) - if err != nil { - t.Logf("Transaction completed: %v", err) - } -} - -// Test Transaction.Get with access token error -func TestTransactionGetAccessTokenError(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - // Return error for token request - w.WriteHeader(http.StatusInternalServerError) - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":beginTransaction") { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "transaction": "test-tx-id", - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - key := ds9.NameKey("Test", "test-key", nil) - - _, err = client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - var entity testEntity - err := tx.Get(key, &entity) - if err == nil { - t.Error("expected error with token failure") - } - return err - }) - - if err == nil { - t.Error("expected transaction to fail with token error") - } -} - -// Test Transaction.Get with non-OK status -func TestTransactionGetNonOKStatus(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":beginTransaction") { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "transaction": "test-tx-id", - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - if strings.Contains(r.URL.Path, ":lookup") { - // Return non-OK status - w.WriteHeader(http.StatusBadRequest) - if _, err := w.Write([]byte("bad request")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if strings.Contains(r.URL.Path, ":commit") { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "mutationResults": []any{}, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - key := ds9.NameKey("Test", "test-key", nil) - - _, err = client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - var entity testEntity - return tx.Get(key, &entity) - }) - - if err == nil { - t.Error("expected error with non-OK status") - } -} - -// Test Client.Get with JSON unmarshal error -func TestGetJSONUnmarshalError(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":lookup") { - // Return malformed JSON - w.Header().Set("Content-Type", "application/json") - if _, err := w.Write([]byte("not valid json")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - key := ds9.NameKey("Test", "test-key", nil) - var entity testEntity - - err = client.Get(ctx, key, &entity) - if err == nil { - t.Error("expected error with malformed JSON") - } -} - -// Test PutMulti with length mismatch -func TestPutMultiLengthValidation(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - keys := []*ds9.Key{ds9.NameKey("Test", "key1", nil)} - entities := []testEntity{{Name: "test1"}, {Name: "test2"}} - - _, err := client.PutMulti(ctx, keys, entities) - if err == nil { - t.Error("expected error with mismatched lengths") - } -} - -// Test DeleteMulti with partial success -func TestDeleteMultiMixedResults(t *testing.T) { - metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("Metadata-Flavor") != "Google" { - w.WriteHeader(http.StatusForbidden) - return - } - if r.URL.Path == "/project/project-id" { - w.WriteHeader(http.StatusOK) - if _, err := w.Write([]byte("test-project")); err != nil { - t.Logf("write failed: %v", err) - } - return - } - if r.URL.Path == "/instance/service-accounts/default/token" { - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "access_token": "test-token", - "expires_in": 3600, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer metadataServer.Close() - - apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, ":commit") { - w.Header().Set("Content-Type", "application/json") - // Return empty mutation results - if err := json.NewEncoder(w).Encode(map[string]any{ - "mutationResults": []any{}, - }); err != nil { - t.Logf("encode failed: %v", err) - } - return - } - w.WriteHeader(http.StatusNotFound) - })) - defer apiServer.Close() - - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - defer restore() - - client, err := ds9.NewClient(context.Background(), "test-project") - if err != nil { - t.Fatalf("NewClient failed: %v", err) - } - - ctx := context.Background() - keys := []*ds9.Key{ - ds9.NameKey("Test", "key1", nil), - ds9.NameKey("Test", "key2", nil), - } - - err = client.DeleteMulti(ctx, keys) - // May or may not error depending on implementation - if err != nil { - t.Logf("DeleteMulti with mismatched results: %v", err) - } -} - -// TestBackwardsCompatibility tests the API compatibility with cloud.google.com/go/datastore. -// This ensures that ds9 can be used as a drop-in replacement. -func TestBackwardsCompatibility(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Test 1: Close() method exists and can be called (even though it's a no-op) - t.Run("Close", func(t *testing.T) { - err := client.Close() - if err != nil { - t.Errorf("Close() returned error: %v", err) - } - }) - - // Test 2: RunInTransaction returns (*Commit, error) - t.Run("RunInTransactionSignature", func(t *testing.T) { - key := ds9.NameKey("TestKind", "test-tx-compat", nil) - - commit, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - entity := &testEntity{ - Name: "transaction test", - Count: 100, - Active: true, - Score: 99.9, - UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), - } - _, err := tx.Put(key, entity) - return err - }) - if err != nil { - t.Fatalf("RunInTransaction failed: %v", err) - } - - if commit == nil { - t.Error("Expected non-nil Commit, got nil") - } - }) - - // Test 3: GetAll() method retrieves entities and returns keys - t.Run("GetAll", func(t *testing.T) { - // Setup: Create some test entities - entities := []testEntity{ - {Name: "entity1", Count: 1, Active: true, Score: 1.1, UpdatedAt: time.Now().UTC().Truncate(time.Microsecond)}, - {Name: "entity2", Count: 2, Active: false, Score: 2.2, UpdatedAt: time.Now().UTC().Truncate(time.Microsecond)}, - {Name: "entity3", Count: 3, Active: true, Score: 3.3, UpdatedAt: time.Now().UTC().Truncate(time.Microsecond)}, - } - - keys := []*ds9.Key{ - ds9.NameKey("GetAllTest", "key1", nil), - ds9.NameKey("GetAllTest", "key2", nil), - ds9.NameKey("GetAllTest", "key3", nil), - } - - _, err := client.PutMulti(ctx, keys, entities) - if err != nil { - t.Fatalf("PutMulti failed: %v", err) - } - - // Test GetAll - query := ds9.NewQuery("GetAllTest") - var results []testEntity - returnedKeys, err := client.GetAll(ctx, query, &results) - if err != nil { - t.Fatalf("GetAll failed: %v", err) - } - - if len(results) != 3 { - t.Errorf("Expected 3 entities, got %d", len(results)) - } - - if len(returnedKeys) != 3 { - t.Errorf("Expected 3 keys, got %d", len(returnedKeys)) - } - - // Verify entities were properly decoded - foundNames := make(map[string]bool) - for _, entity := range results { - foundNames[entity.Name] = true - } - - for _, expectedName := range []string{"entity1", "entity2", "entity3"} { - if !foundNames[expectedName] { - t.Errorf("Expected to find entity %s, but didn't", expectedName) - } - } - - // Verify keys match entities - for i, key := range returnedKeys { - if key.Kind != "GetAllTest" { - t.Errorf("Key %d has wrong kind: %s", i, key.Kind) - } - } - }) - - // Test 4: GetAll with limit - t.Run("GetAllWithLimit", func(t *testing.T) { - query := ds9.NewQuery("GetAllTest").Limit(2) - var results []testEntity - returnedKeys, err := client.GetAll(ctx, query, &results) - if err != nil { - t.Fatalf("GetAll with limit failed: %v", err) - } - - if len(results) != 2 { - t.Errorf("Expected 2 entities with limit, got %d", len(results)) - } - - if len(returnedKeys) != 2 { - t.Errorf("Expected 2 keys with limit, got %d", len(returnedKeys)) - } - }) -} - -// TestClose tests that the Close() method exists and returns no error. -func TestClose(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - err := client.Close() - if err != nil { - t.Errorf("Close() returned unexpected error: %v", err) - } - - // Should be idempotent - can call multiple times - err = client.Close() - if err != nil { - t.Errorf("Second Close() returned unexpected error: %v", err) - } -} - -// TestGetAllEmpty tests GetAll with no results. -func TestGetAllEmpty(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - query := ds9.NewQuery("NonExistentKind") - var results []testEntity - - keys, err := client.GetAll(ctx, query, &results) - if err != nil { - t.Fatalf("GetAll failed: %v", err) - } - - if len(results) != 0 { - t.Errorf("Expected 0 entities, got %d", len(results)) - } - - if len(keys) != 0 { - t.Errorf("Expected 0 keys, got %d", len(keys)) - } -} - -// TestGetAllInvalidDst tests GetAll with invalid destination. -func TestGetAllInvalidDst(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - query := ds9.NewQuery("TestKind") - - tests := []struct { - name string - dst any - }{ - {"not a pointer", []testEntity{}}, - {"not a slice", new(testEntity)}, - {"nil", nil}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := client.GetAll(ctx, query, tt.dst) - if err == nil { - t.Error("Expected error for invalid dst, got nil") - } - }) - } -} - -// TestGetAllSingleEntity tests GetAll retrieving a single entity. -func TestGetAllSingleEntity(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Create entity - key := ds9.NameKey("SingleGetAll", "single1", nil) - entity := testEntity{ - Name: "single", - Count: 42, - Active: true, - Score: 3.14, - UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), - Notes: "test notes", - } - - _, err := client.Put(ctx, key, &entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - - // Test GetAll - query := ds9.NewQuery("SingleGetAll") - var results []testEntity - keys, err := client.GetAll(ctx, query, &results) - if err != nil { - t.Fatalf("GetAll failed: %v", err) - } - - if len(results) != 1 { - t.Fatalf("Expected 1 entity, got %d", len(results)) - } - - if len(keys) != 1 { - t.Fatalf("Expected 1 key, got %d", len(keys)) - } - - // Verify entity content - if results[0].Name != "single" { - t.Errorf("Expected name 'single', got '%s'", results[0].Name) - } - if results[0].Count != 42 { - t.Errorf("Expected count 42, got %d", results[0].Count) - } - if !results[0].Active { - t.Error("Expected active=true") - } - if results[0].Score != 3.14 { - t.Errorf("Expected score 3.14, got %f", results[0].Score) - } - - // Verify key - if keys[0].Kind != "SingleGetAll" { - t.Errorf("Expected kind 'SingleGetAll', got '%s'", keys[0].Kind) - } - if keys[0].Name != "single1" { - t.Errorf("Expected key name 'single1', got '%s'", keys[0].Name) - } -} - -// TestGetAllMultipleEntities tests GetAll retrieving multiple entities. -func TestGetAllMultipleEntities(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Create multiple entities - count := 5 - keys := make([]*ds9.Key, count) - entities := make([]testEntity, count) - - for i := range count { - keys[i] = ds9.NameKey("MultiGetAll", fmt.Sprintf("entity%d", i), nil) - entities[i] = testEntity{ - Name: fmt.Sprintf("entity%d", i), - Count: int64(i * 10), - Active: i%2 == 0, - Score: float64(i) * 1.5, - UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), - } - } - - _, err := client.PutMulti(ctx, keys, entities) - if err != nil { - t.Fatalf("PutMulti failed: %v", err) - } - - // Test GetAll - query := ds9.NewQuery("MultiGetAll") - var results []testEntity - returnedKeys, err := client.GetAll(ctx, query, &results) - if err != nil { - t.Fatalf("GetAll failed: %v", err) - } - - if len(results) != count { - t.Fatalf("Expected %d entities, got %d", count, len(results)) - } - - if len(returnedKeys) != count { - t.Fatalf("Expected %d keys, got %d", count, len(returnedKeys)) - } - - // Verify we got all entities - foundNames := make(map[string]bool) - for _, entity := range results { - foundNames[entity.Name] = true - } - - for i := range count { - expectedName := fmt.Sprintf("entity%d", i) - if !foundNames[expectedName] { - t.Errorf("Missing entity: %s", expectedName) - } - } -} - -// TestGetAllWithLimitVariations tests GetAll with various limit values. -func TestGetAllWithLimitVariations(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Setup: Create 10 entities - keys := make([]*ds9.Key, 10) - entities := make([]testEntity, 10) - for i := range 10 { - keys[i] = ds9.NameKey("LimitGetAll", fmt.Sprintf("key%d", i), nil) - entities[i] = testEntity{ - Name: fmt.Sprintf("entity%d", i), - Count: int64(i), - Active: true, - UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), - } - } - - _, err := client.PutMulti(ctx, keys, entities) - if err != nil { - t.Fatalf("PutMulti failed: %v", err) - } - - tests := []struct { - name string - limit int - expected int - }{ - {"Limit 1", 1, 1}, - {"Limit 3", 3, 3}, - {"Limit 5", 5, 5}, - {"Limit 10", 10, 10}, - {"Limit 20 (more than available)", 20, 10}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - query := ds9.NewQuery("LimitGetAll").Limit(tt.limit) - var results []testEntity - keys, err := client.GetAll(ctx, query, &results) - if err != nil { - t.Fatalf("GetAll failed: %v", err) - } - - if len(results) != tt.expected { - t.Errorf("Expected %d entities, got %d", tt.expected, len(results)) - } - - if len(keys) != tt.expected { - t.Errorf("Expected %d keys, got %d", tt.expected, len(keys)) - } - }) - } -} - -// TestRunInTransactionReturnsCommit tests that RunInTransaction returns a Commit object. -func TestRunInTransactionReturnsCommit(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - key := ds9.NameKey("CommitTest", "test1", nil) - - commit, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - entity := &testEntity{ - Name: "commit test", - Count: 1, - Active: true, - UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), - } - _, err := tx.Put(key, entity) - return err - }) - if err != nil { - t.Fatalf("RunInTransaction failed: %v", err) - } - - if commit == nil { - t.Fatal("Expected non-nil Commit, got nil") - } - - // Commit should be a valid *Commit type - _ = commit -} - -// TestRunInTransactionErrorReturnsNilCommit tests that RunInTransaction returns nil Commit on error. -func TestRunInTransactionErrorReturnsNilCommit(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - expectedErr := errors.New("intentional error") - commit, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - return expectedErr - }) - - if err == nil { - t.Fatal("Expected error, got nil") - } - - if !errors.Is(err, expectedErr) { - t.Errorf("Expected error to be %v, got %v", expectedErr, err) - } - - if commit != nil { - t.Errorf("Expected nil Commit on error, got %v", commit) - } -} - -func TestTransactionOptions(t *testing.T) { - t.Run("MaxAttempts", func(t *testing.T) { - // Test that MaxAttempts option is accepted and sets the retry limit - // We can verify this by checking the error message mentions the right attempt count - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - key := ds9.NameKey("TestKind", "test", nil) - - // This test verifies that the MaxAttempts option is parsed correctly - // The actual retry behavior is tested in TestTransactionMaxRetriesExceeded - _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - entity := testEntity{Name: "test", Count: 42} - _, err := tx.Put(key, &entity) - return err - }, ds9.MaxAttempts(5)) - // With mock client, this should succeed - if err != nil { - t.Fatalf("Transaction failed: %v", err) - } - }) - - t.Run("WithReadTime", func(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - key := ds9.NameKey("TestKind", "test", nil) - - // First, put an entity - entity := testEntity{Name: "test", Count: 42} - _, err := client.Put(ctx, key, &entity) - if err != nil { - t.Fatalf("Put failed: %v", err) - } - - // Run a read-only transaction with readTime - readTime := time.Now().UTC() - _, err = client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - var result testEntity - return tx.Get(key, &result) - }, ds9.WithReadTime(readTime)) - // Note: ds9mock doesn't actually enforce read-only semantics, - // but we're testing that the option is accepted and doesn't cause errors - if err != nil { - t.Fatalf("Transaction with WithReadTime failed: %v", err) - } - }) - - t.Run("CombinedOptions", func(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - key := ds9.NameKey("TestKind", "test", nil) - - // Test that multiple options can be combined - _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - entity := testEntity{Name: "test", Count: 42} - _, err := tx.Put(key, &entity) - return err - }, ds9.MaxAttempts(2), ds9.WithReadTime(time.Now().UTC())) - // With mock client, this should succeed - if err != nil { - t.Fatalf("Transaction with combined options failed: %v", err) - } - }) -} - -// Test entity with arrays for array/slice tests -type arrayEntity struct { - Strings []string `datastore:"strings"` - Ints []int64 `datastore:"ints"` - Floats []float64 `datastore:"floats"` - Bools []bool `datastore:"bools"` -} - -func TestArraySliceSupport(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - t.Run("StringSlice", func(t *testing.T) { - key := ds9.NameKey("ArrayTest", "strings", nil) - entity := &arrayEntity{ - Strings: []string{"hello", "world", "test"}, - } - - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put with string slice failed: %v", err) - } - - var result arrayEntity - if err := client.Get(ctx, key, &result); err != nil { - t.Fatalf("Get failed: %v", err) - } - - if len(result.Strings) != 3 { - t.Errorf("Expected 3 strings, got %d", len(result.Strings)) - } - if result.Strings[0] != "hello" || result.Strings[1] != "world" || result.Strings[2] != "test" { - t.Errorf("String slice values incorrect: %v", result.Strings) - } - }) - - t.Run("Int64Slice", func(t *testing.T) { - key := ds9.NameKey("ArrayTest", "ints", nil) - entity := &arrayEntity{ - Ints: []int64{1, 2, 3, 42, 100}, - } - - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put with int64 slice failed: %v", err) - } - - var result arrayEntity - if err := client.Get(ctx, key, &result); err != nil { - t.Fatalf("Get failed: %v", err) - } - - if len(result.Ints) != 5 { - t.Errorf("Expected 5 ints, got %d", len(result.Ints)) - } - if result.Ints[3] != 42 { - t.Errorf("Expected Ints[3] = 42, got %d", result.Ints[3]) - } - }) - - t.Run("Float64Slice", func(t *testing.T) { - key := ds9.NameKey("ArrayTest", "floats", nil) - entity := &arrayEntity{ - Floats: []float64{1.1, 2.2, 3.3}, - } - - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put with float64 slice failed: %v", err) - } - - var result arrayEntity - if err := client.Get(ctx, key, &result); err != nil { - t.Fatalf("Get failed: %v", err) - } - - if len(result.Floats) != 3 { - t.Errorf("Expected 3 floats, got %d", len(result.Floats)) - } - if result.Floats[0] != 1.1 { - t.Errorf("Expected Floats[0] = 1.1, got %f", result.Floats[0]) - } - }) - - t.Run("BoolSlice", func(t *testing.T) { - key := ds9.NameKey("ArrayTest", "bools", nil) - entity := &arrayEntity{ - Bools: []bool{true, false, true}, - } - - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put with bool slice failed: %v", err) - } - - var result arrayEntity - if err := client.Get(ctx, key, &result); err != nil { - t.Fatalf("Get failed: %v", err) - } - - if len(result.Bools) != 3 { - t.Errorf("Expected 3 bools, got %d", len(result.Bools)) - } - if result.Bools[0] != true || result.Bools[1] != false { - t.Errorf("Bool slice values incorrect: %v", result.Bools) - } - }) - - t.Run("EmptySlices", func(t *testing.T) { - key := ds9.NameKey("ArrayTest", "empty", nil) - entity := &arrayEntity{ - Strings: []string{}, - Ints: []int64{}, - } - - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put with empty slices failed: %v", err) - } - - var result arrayEntity - if err := client.Get(ctx, key, &result); err != nil { - t.Fatalf("Get failed: %v", err) - } - - if result.Strings == nil || len(result.Strings) != 0 { - t.Errorf("Expected empty string slice, got %v", result.Strings) - } - if result.Ints == nil || len(result.Ints) != 0 { - t.Errorf("Expected empty int slice, got %v", result.Ints) - } - }) - - t.Run("MixedArrays", func(t *testing.T) { - key := ds9.NameKey("ArrayTest", "mixed", nil) - entity := &arrayEntity{ - Strings: []string{"a", "b"}, - Ints: []int64{10, 20, 30}, - Floats: []float64{1.5}, - Bools: []bool{true, false, true, false}, - } - - _, err := client.Put(ctx, key, entity) - if err != nil { - t.Fatalf("Put with mixed arrays failed: %v", err) - } - - var result arrayEntity - if err := client.Get(ctx, key, &result); err != nil { - t.Fatalf("Get failed: %v", err) - } - - if len(result.Strings) != 2 || len(result.Ints) != 3 || len(result.Floats) != 1 || len(result.Bools) != 4 { - t.Errorf("Mixed array lengths incorrect: strings=%d, ints=%d, floats=%d, bools=%d", - len(result.Strings), len(result.Ints), len(result.Floats), len(result.Bools)) - } - }) -} - -func TestAllocateIDs(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - t.Run("AllocateIncompleteKeys", func(t *testing.T) { - keys := []*ds9.Key{ - ds9.IncompleteKey("Task", nil), - ds9.IncompleteKey("Task", nil), - ds9.IncompleteKey("Task", nil), - } - - allocated, err := client.AllocateIDs(ctx, keys) - if err != nil { - t.Fatalf("AllocateIDs failed: %v", err) - } - - if len(allocated) != 3 { - t.Errorf("Expected 3 allocated keys, got %d", len(allocated)) - } - - for i, key := range allocated { - if key.Incomplete() { - t.Errorf("Key %d is still incomplete", i) - } - if key.ID == 0 { - t.Errorf("Key %d has zero ID", i) - } - } - }) - - t.Run("AllocateMixedKeys", func(t *testing.T) { - keys := []*ds9.Key{ - ds9.NameKey("Task", "complete", nil), - ds9.IncompleteKey("Task", nil), - ds9.IDKey("Task", 123, nil), - ds9.IncompleteKey("Task", nil), - } - - allocated, err := client.AllocateIDs(ctx, keys) - if err != nil { - t.Fatalf("AllocateIDs with mixed keys failed: %v", err) - } - - if len(allocated) != 4 { - t.Errorf("Expected 4 keys, got %d", len(allocated)) - } - - // First key should still be the named key - if allocated[0].Name != "complete" { - t.Errorf("First key should be unchanged") - } - - // Second key should now have an ID - if allocated[1].Incomplete() { - t.Errorf("Second key should be allocated") - } - - // Third key should be unchanged - if allocated[2].ID != 123 { - t.Errorf("Third key should be unchanged") - } - - // Fourth key should now have an ID - if allocated[3].Incomplete() { - t.Errorf("Fourth key should be allocated") - } - }) - - t.Run("AllocateEmptySlice", func(t *testing.T) { - keys := []*ds9.Key{} - - allocated, err := client.AllocateIDs(ctx, keys) - if err != nil { - t.Fatalf("AllocateIDs with empty slice failed: %v", err) - } - - if len(allocated) != 0 { - t.Errorf("Expected empty slice, got %d keys", len(allocated)) - } - }) - - t.Run("AllocateAllCompleteKeys", func(t *testing.T) { - keys := []*ds9.Key{ - ds9.NameKey("Task", "key1", nil), - ds9.IDKey("Task", 100, nil), - } - - allocated, err := client.AllocateIDs(ctx, keys) - if err != nil { - t.Fatalf("AllocateIDs with complete keys failed: %v", err) - } - - if len(allocated) != 2 { - t.Errorf("Expected 2 keys, got %d", len(allocated)) - } - - // Keys should be unchanged - if allocated[0].Name != "key1" || allocated[1].ID != 100 { - t.Errorf("Complete keys should be unchanged") - } - }) -} - -func TestCount(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - t.Run("CountEmptyKind", func(t *testing.T) { - q := ds9.NewQuery("NonExistent") - count, err := client.Count(ctx, q) - if err != nil { - t.Fatalf("Count failed: %v", err) - } - - if count != 0 { - t.Errorf("Expected count 0, got %d", count) - } - }) - - t.Run("CountWithEntities", func(t *testing.T) { - // Create some entities - for i := range 5 { - key := ds9.IDKey("CountTest", int64(i+1), nil) - entity := &testEntity{ - Name: fmt.Sprintf("entity-%d", i), - Count: int64(i), - } - if _, err := client.Put(ctx, key, entity); err != nil { - t.Fatalf("Put failed: %v", err) - } - } - - q := ds9.NewQuery("CountTest") - count, err := client.Count(ctx, q) - if err != nil { - t.Fatalf("Count failed: %v", err) - } - - if count != 5 { - t.Errorf("Expected count 5, got %d", count) - } - }) - - t.Run("CountWithFilter", func(t *testing.T) { - // Create entities with different counts - for i := range 10 { - key := ds9.IDKey("FilterCount", int64(i+1), nil) - entity := &testEntity{ - Name: fmt.Sprintf("entity-%d", i), - Count: int64(i), - } - if _, err := client.Put(ctx, key, entity); err != nil { - t.Fatalf("Put failed: %v", err) - } - } - - // Count entities where count >= 5 - q := ds9.NewQuery("FilterCount").Filter("count >=", int64(5)) - count, err := client.Count(ctx, q) - if err != nil { - t.Fatalf("Count with filter failed: %v", err) - } - - // Should return entities with count 5,6,7,8,9 = 5 entities - if count != 5 { - t.Errorf("Expected count 5, got %d", count) - } - }) - - t.Run("CountWithLimit", func(t *testing.T) { - // Create entities - for i := range 10 { - key := ds9.IDKey("LimitCount", int64(i+1), nil) - entity := &testEntity{ - Name: fmt.Sprintf("entity-%d", i), - Count: int64(i), - } - if _, err := client.Put(ctx, key, entity); err != nil { - t.Fatalf("Put failed: %v", err) - } - } - - // Count with limit - note: count should respect limit - q := ds9.NewQuery("LimitCount").Limit(3) - count, err := client.Count(ctx, q) - if err != nil { - t.Fatalf("Count with limit failed: %v", err) - } - - // Mock implementation may return full count, but limit is respected - if count > 10 { - t.Errorf("Count should not exceed actual entities: %d", count) - } - }) -} - -func TestQueryNamespace(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - t.Run("NamespaceFilter", func(t *testing.T) { - // Note: ds9mock may not fully support namespaces, but we test the API - q := ds9.NewQuery("Task").Namespace("custom-namespace") - - var entities []testEntity - _, err := client.GetAll(ctx, q, &entities) - // Should not error even if namespace is not supported by mock - if err != nil { - t.Logf("GetAll with namespace: %v", err) - } - }) - - t.Run("EmptyNamespace", func(t *testing.T) { - q := ds9.NewQuery("Task").Namespace("") - - var entities []testEntity - _, err := client.GetAll(ctx, q, &entities) - if err != nil { - t.Logf("GetAll with empty namespace: %v", err) - } - }) -} - -func TestQueryDistinct(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - t.Run("Distinct", func(t *testing.T) { - // Create duplicate entities - for i := range 3 { - key := ds9.IDKey("DistinctTest", int64(i+1), nil) - entity := &testEntity{ - Name: "same-name", // Same name for all - Count: int64(i), - } - if _, err := client.Put(ctx, key, entity); err != nil { - t.Fatalf("Put failed: %v", err) - } - } - - // Query with distinct on Name field - q := ds9.NewQuery("DistinctTest").Project("name").Distinct() - - var entities []testEntity - _, err := client.GetAll(ctx, q, &entities) - if err != nil { - t.Logf("GetAll with Distinct: %v", err) - } - }) - - t.Run("DistinctOn", func(t *testing.T) { - q := ds9.NewQuery("Task").DistinctOn("name", "count") - - var entities []testEntity - _, err := client.GetAll(ctx, q, &entities) - if err != nil { - t.Logf("GetAll with DistinctOn: %v", err) - } - }) -} - -func TestIterator(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - t.Run("IterateAll", func(t *testing.T) { - // Create test entities - for i := range 5 { - key := ds9.IDKey("IterTest", int64(i+1), nil) - entity := &testEntity{ - Name: fmt.Sprintf("entity-%d", i), - Count: int64(i), - } - if _, err := client.Put(ctx, key, entity); err != nil { - t.Fatalf("Put failed: %v", err) - } - } - - q := ds9.NewQuery("IterTest") - it := client.Run(ctx, q) - - count := 0 - for { - var entity testEntity - key, err := it.Next(&entity) - if errors.Is(err, ds9.ErrDone) { - break - } - if err != nil { - t.Fatalf("Iterator.Next failed: %v", err) - } - if key == nil { - t.Errorf("Expected non-nil key") - } - count++ - } - - if count != 5 { - t.Errorf("Expected to iterate over 5 entities, got %d", count) - } - }) - - t.Run("IteratorCursor", func(t *testing.T) { - // Create test entities - for i := range 3 { - key := ds9.IDKey("CursorTest", int64(i+1), nil) - entity := &testEntity{ - Name: fmt.Sprintf("entity-%d", i), - Count: int64(i), - } - if _, err := client.Put(ctx, key, entity); err != nil { - t.Fatalf("Put failed: %v", err) - } - } - - q := ds9.NewQuery("CursorTest") - it := client.Run(ctx, q) - - var entity testEntity - _, err := it.Next(&entity) - if err != nil { - t.Fatalf("Iterator.Next failed: %v", err) - } - - // Get cursor after first entity - cursor, err := it.Cursor() - if err != nil { - t.Logf("Cursor not available: %v", err) - } else if cursor == "" { - t.Logf("Empty cursor returned") - } - }) - - t.Run("EmptyIterator", func(t *testing.T) { - q := ds9.NewQuery("NonExistent") - it := client.Run(ctx, q) - - var entity testEntity - _, err := it.Next(&entity) - if !errors.Is(err, ds9.ErrDone) { - t.Errorf("Expected ErrDone, got %v", err) - } - }) -} - -func TestMutate(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - t.Run("MutateInsert", func(t *testing.T) { - key := ds9.NameKey("MutateTest", "insert", nil) - entity := &testEntity{ - Name: "inserted", - Count: 42, - } - - mut := ds9.NewInsert(key, entity) - keys, err := client.Mutate(ctx, mut) - if err != nil { - t.Fatalf("Mutate insert failed: %v", err) - } - - if len(keys) != 1 { - t.Errorf("Expected 1 key, got %d", len(keys)) - } - - // Verify entity was created - var result testEntity - if err := client.Get(ctx, key, &result); err != nil { - t.Fatalf("Get after insert failed: %v", err) - } - if result.Name != "inserted" { - t.Errorf("Expected Name 'inserted', got '%s'", result.Name) - } - }) - - t.Run("MutateUpdate", func(t *testing.T) { - key := ds9.NameKey("MutateTest", "update", nil) - entity := &testEntity{Name: "original", Count: 1} - - // Create entity first - if _, err := client.Put(ctx, key, entity); err != nil { - t.Fatalf("Put failed: %v", err) - } - - // Update via mutation - updated := &testEntity{Name: "updated", Count: 2} - mut := ds9.NewUpdate(key, updated) - _, err := client.Mutate(ctx, mut) - if err != nil { - t.Fatalf("Mutate update failed: %v", err) - } - - // Verify entity was updated - var result testEntity - if err := client.Get(ctx, key, &result); err != nil { - t.Fatalf("Get after update failed: %v", err) - } - if result.Name != "updated" { - t.Errorf("Expected Name 'updated', got '%s'", result.Name) - } - }) - - t.Run("MutateUpsert", func(t *testing.T) { - key := ds9.NameKey("MutateTest", "upsert", nil) - entity := &testEntity{Name: "upserted", Count: 100} - - mut := ds9.NewUpsert(key, entity) - keys, err := client.Mutate(ctx, mut) - if err != nil { - t.Fatalf("Mutate upsert failed: %v", err) - } - - if len(keys) != 1 { - t.Errorf("Expected 1 key, got %d", len(keys)) - } - - // Verify entity exists - var result testEntity - if err := client.Get(ctx, key, &result); err != nil { - t.Fatalf("Get after upsert failed: %v", err) - } - if result.Name != "upserted" { - t.Errorf("Expected Name 'upserted', got '%s'", result.Name) - } - }) - - t.Run("MutateDelete", func(t *testing.T) { - key := ds9.NameKey("MutateTest", "delete", nil) - entity := &testEntity{Name: "to-delete", Count: 1} - - // Create entity first - if _, err := client.Put(ctx, key, entity); err != nil { - t.Fatalf("Put failed: %v", err) - } - - // Delete via mutation - mut := ds9.NewDelete(key) - keys, err := client.Mutate(ctx, mut) - if err != nil { - t.Fatalf("Mutate delete failed: %v", err) - } - - if len(keys) != 1 { - t.Errorf("Expected 1 key, got %d", len(keys)) - } - - // Verify entity was deleted - var result testEntity - err = client.Get(ctx, key, &result) - if !errors.Is(err, ds9.ErrNoSuchEntity) { - t.Errorf("Expected ErrNoSuchEntity after delete, got %v", err) - } - }) - - t.Run("MutateMultiple", func(t *testing.T) { - key1 := ds9.NameKey("MutateTest", "multi1", nil) - key2 := ds9.NameKey("MutateTest", "multi2", nil) - key3 := ds9.NameKey("MutateTest", "multi3", nil) - - entity1 := &testEntity{Name: "first", Count: 1} - entity2 := &testEntity{Name: "second", Count: 2} - entity3 := &testEntity{Name: "third", Count: 3} - - // Pre-create entity3 for update - if _, err := client.Put(ctx, key3, &testEntity{Name: "old", Count: 0}); err != nil { - t.Fatalf("Put failed: %v", err) - } - - // Apply multiple mutations - muts := []*ds9.Mutation{ - ds9.NewInsert(key1, entity1), - ds9.NewUpsert(key2, entity2), - ds9.NewUpdate(key3, entity3), - } - - keys, err := client.Mutate(ctx, muts...) - if err != nil { - t.Fatalf("Mutate multiple failed: %v", err) - } - - if len(keys) != 3 { - t.Errorf("Expected 3 keys, got %d", len(keys)) - } - - // Verify all mutations applied - var result1, result2, result3 testEntity - if err := client.Get(ctx, key1, &result1); err != nil { - t.Errorf("Get key1 failed: %v", err) - } - if err := client.Get(ctx, key2, &result2); err != nil { - t.Errorf("Get key2 failed: %v", err) - } - if err := client.Get(ctx, key3, &result3); err != nil { - t.Errorf("Get key3 failed: %v", err) - } - - if result1.Name != "first" || result2.Name != "second" || result3.Name != "third" { - t.Errorf("Mutation results incorrect") - } - }) - - t.Run("MutateEmpty", func(t *testing.T) { - keys, err := client.Mutate(ctx) - if err != nil { - t.Fatalf("Mutate with no mutations failed: %v", err) - } - - if len(keys) != 0 { - t.Errorf("Expected empty keys, got %d", len(keys)) - } - }) -} diff --git a/example/main.go b/example/main.go index 56f65c3..8ff066e 100644 --- a/example/main.go +++ b/example/main.go @@ -8,7 +8,7 @@ import ( "log" "time" - "github.com/codeGROOVE-dev/ds9" + "github.com/codeGROOVE-dev/ds9/pkg/datastore" ) // Task represents a simple task entity. @@ -25,7 +25,7 @@ func main() { // Create a new Datastore client // The project ID will be automatically detected from the GCP metadata server - client, err := ds9.NewClient(ctx, "my-project") + client, err := datastore.NewClient(ctx, "my-project") if err != nil { log.Fatalf("Failed to create client: %v", err) } @@ -39,7 +39,7 @@ func main() { CreatedAt: time.Now(), } - key := ds9.NameKey("Task", "task-1", nil) + key := datastore.NameKey("Task", "task-1", nil) _, err = client.Put(ctx, key, task) if err != nil { log.Fatalf("Failed to put task: %v", err) @@ -63,7 +63,7 @@ func main() { fmt.Println("Task updated successfully") // Example 4: Query for all task keys - query := ds9.NewQuery("Task").KeysOnly().Limit(100) + query := datastore.NewQuery("Task").KeysOnly().Limit(100) keys, err := client.AllKeys(ctx, query) if err != nil { log.Fatalf("Failed to query tasks: %v", err) @@ -71,7 +71,7 @@ func main() { fmt.Printf("Found %d tasks\n", len(keys)) // Example 5: Use a transaction - _, err = client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { + _, err = client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { var current Task if err := tx.Get(key, ¤t); err != nil { return err @@ -97,7 +97,7 @@ func main() { // Example 7: Check that entity is gone err = client.Get(ctx, key, &retrieved) - if errors.Is(err, ds9.ErrNoSuchEntity) { + if errors.Is(err, datastore.ErrNoSuchEntity) { fmt.Println("Confirmed: task no longer exists") } else if err != nil { log.Fatalf("Unexpected error: %v", err) diff --git a/integration_test.go b/integration/integration_test.go similarity index 84% rename from integration_test.go rename to integration/integration_test.go index 50edfab..1c9ecb3 100644 --- a/integration_test.go +++ b/integration/integration_test.go @@ -7,8 +7,7 @@ import ( "testing" "time" - "github.com/codeGROOVE-dev/ds9" - "github.com/codeGROOVE-dev/ds9/ds9mock" + "github.com/codeGROOVE-dev/ds9/pkg/datastore" ) const ( @@ -26,13 +25,13 @@ func testProject() string { // integrationClient returns either a real GCP client or a mock client // based on whether DS9_TEST_PROJECT is set. -func integrationClient(t *testing.T) (client *ds9.Client, cleanup func()) { +func integrationClient(t *testing.T) (client *datastore.Client, cleanup func()) { t.Helper() if os.Getenv("DS9_TEST_PROJECT") != "" { // Real GCP integration test ctx := context.Background() - client, err := ds9.NewClientWithDatabase(ctx, testProject(), testDatabaseID) + client, err := datastore.NewClientWithDatabase(ctx, testProject(), testDatabaseID) if err != nil { t.Fatalf("Failed to create GCP client: %v", err) } @@ -40,7 +39,7 @@ func integrationClient(t *testing.T) (client *ds9.Client, cleanup func()) { } // Mock client for unit testing - return ds9mock.NewClient(t) + return datastore.NewMockClient(t) } func TestIntegrationBasicOperations(t *testing.T) { @@ -51,7 +50,7 @@ func TestIntegrationBasicOperations(t *testing.T) { // Generate unique key for this test run testID := t.Name() + "-" + time.Now().Format("20060102-150405.000000") - key := ds9.NameKey(testKind, testID, nil) + key := datastore.NameKey(testKind, testID, nil) // Cleanup at the end defer func() { @@ -120,8 +119,8 @@ func TestIntegrationBasicOperations(t *testing.T) { var entity integrationEntity err = client.Get(ctx, key, &entity) - if !errors.Is(err, ds9.ErrNoSuchEntity) { - t.Errorf("expected ErrNoSuchEntity after delete, got %v", err) + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("expected datastore.ErrNoSuchEntity after delete, got %v", err) } }) } @@ -134,10 +133,10 @@ func TestIntegrationBatchOperations(t *testing.T) { // Generate unique keys for this test run testID := t.Name() + "-" + time.Now().Format("20060102-150405.000000") - keys := []*ds9.Key{ - ds9.NameKey(testKind, testID+"-1", nil), - ds9.NameKey(testKind, testID+"-2", nil), - ds9.NameKey(testKind, testID+"-3", nil), + keys := []*datastore.Key{ + datastore.NameKey(testKind, testID+"-1", nil), + datastore.NameKey(testKind, testID+"-2", nil), + datastore.NameKey(testKind, testID+"-3", nil), } // Cleanup at the end @@ -187,8 +186,8 @@ func TestIntegrationBatchOperations(t *testing.T) { var retrieved []integrationEntity err = client.GetMulti(ctx, keys, &retrieved) - if !errors.Is(err, ds9.ErrNoSuchEntity) { - t.Errorf("expected ErrNoSuchEntity after DeleteMulti, got %v", err) + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("expected datastore.ErrNoSuchEntity after DeleteMulti, got %v", err) } }) } @@ -200,7 +199,7 @@ func TestIntegrationTransaction(t *testing.T) { ctx := context.Background() testID := t.Name() + "-" + time.Now().Format("20060102-150405.000000") - key := ds9.NameKey(testKind, testID, nil) + key := datastore.NameKey(testKind, testID, nil) defer func() { if err := client.Delete(ctx, key); err != nil { @@ -210,7 +209,7 @@ func TestIntegrationTransaction(t *testing.T) { t.Run("Transaction", func(t *testing.T) { // Create entity inside transaction to avoid contention with non-transactional operations - _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { + _, err := client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { // Create new entity inside transaction initial := &integrationEntity{ Name: "counter", @@ -225,7 +224,7 @@ func TestIntegrationTransaction(t *testing.T) { } // Now run another transaction to update it - _, err = client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { + _, err = client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { var entity integrationEntity if err := tx.Get(key, &entity); err != nil { return err @@ -260,10 +259,10 @@ func TestIntegrationQuery(t *testing.T) { // Create test entities testID := t.Name() + "-" + time.Now().Format("20060102-150405.000000") - keys := make([]*ds9.Key, 5) + keys := make([]*datastore.Key, 5) entities := make([]integrationEntity, 5) for i := range 5 { - keys[i] = ds9.NameKey(testKind, testID+"-"+string(rune('a'+i)), nil) + keys[i] = datastore.NameKey(testKind, testID+"-"+string(rune('a'+i)), nil) entities[i] = integrationEntity{ Name: "query-test", Count: int64(i), @@ -283,10 +282,10 @@ func TestIntegrationQuery(t *testing.T) { }() t.Run("AllKeys", func(t *testing.T) { - query := ds9.NewQuery(testKind).KeysOnly().Limit(10) + query := datastore.NewQuery(testKind).KeysOnly().Limit(10) resultKeys, err := client.AllKeys(ctx, query) if err != nil { - t.Fatalf("AllKeys failed: %v", err) + t.Fatalf("datastore.AllKeys failed: %v", err) } // We expect at least our 5 keys (there might be more from other tests) @@ -319,9 +318,9 @@ func TestIntegrationCleanup(t *testing.T) { // First create some test entities testID := t.Name() + "-" + time.Now().Format("20060102-150405.000000") - keys := []*ds9.Key{ - ds9.NameKey(testKind, testID+"-1", nil), - ds9.NameKey(testKind, testID+"-2", nil), + keys := []*datastore.Key{ + datastore.NameKey(testKind, testID+"-1", nil), + datastore.NameKey(testKind, testID+"-2", nil), } entities := []integrationEntity{ {Name: "cleanup-1", Count: 1, Timestamp: time.Now().UTC().Truncate(time.Microsecond)}, @@ -341,7 +340,7 @@ func TestIntegrationCleanup(t *testing.T) { } // Verify all entities are deleted - q := ds9.NewQuery(testKind).KeysOnly() + q := datastore.NewQuery(testKind).KeysOnly() keys, err := client.AllKeys(ctx, q) if err != nil { t.Fatalf("Failed to query after cleanup: %v", err) @@ -371,11 +370,11 @@ func TestIntegrationGetAll(t *testing.T) { // Setup: Create test entities kind := "DS9GetAllTest" count := 5 - keys := make([]*ds9.Key, count) + keys := make([]*datastore.Key, count) entities := make([]integrationEntity, count) for i := range count { - keys[i] = ds9.IDKey(kind, int64(i+1000), nil) // Use IDs to avoid conflicts + keys[i] = datastore.IDKey(kind, int64(i+1000), nil) // Use IDs to avoid conflicts entities[i] = integrationEntity{ Name: "getall-entity-" + string(rune('A'+i)), Count: int64(i * 100), @@ -390,11 +389,11 @@ func TestIntegrationGetAll(t *testing.T) { } // Test GetAll - query := ds9.NewQuery(kind) + query := datastore.NewQuery(kind) var results []integrationEntity returnedKeys, err := client.GetAll(ctx, query, &results) if err != nil { - t.Fatalf("GetAll failed: %v", err) + t.Fatalf("datastore.GetAll failed: %v", err) } if len(results) < count { @@ -427,11 +426,11 @@ func TestIntegrationGetAll(t *testing.T) { t.Run("GetAllWithLimit", func(t *testing.T) { kind := "DS9GetAllLimitTest" // Create 10 entities - keys := make([]*ds9.Key, 10) + keys := make([]*datastore.Key, 10) entities := make([]integrationEntity, 10) for i := range 10 { - keys[i] = ds9.IDKey(kind, int64(i+2000), nil) + keys[i] = datastore.IDKey(kind, int64(i+2000), nil) entities[i] = integrationEntity{ Name: "limit-test-" + string(rune('0'+i)), Count: int64(i), @@ -445,11 +444,11 @@ func TestIntegrationGetAll(t *testing.T) { } // Test GetAll with limit - query := ds9.NewQuery(kind).Limit(3) + query := datastore.NewQuery(kind).Limit(3) var results []integrationEntity returnedKeys, err := client.GetAll(ctx, query, &results) if err != nil { - t.Fatalf("GetAll with limit failed: %v", err) + t.Fatalf("datastore.GetAll with limit failed: %v", err) } // Should get at most 3 results @@ -470,12 +469,12 @@ func TestIntegrationGetAll(t *testing.T) { t.Run("GetAllEmpty", func(t *testing.T) { kind := "DS9NonExistentKind" - query := ds9.NewQuery(kind) + query := datastore.NewQuery(kind) var results []integrationEntity keys, err := client.GetAll(ctx, query, &results) if err != nil { - t.Fatalf("GetAll on empty kind failed: %v", err) + t.Fatalf("datastore.GetAll on empty kind failed: %v", err) } if len(results) != 0 { @@ -506,15 +505,15 @@ func TestIntegrationClose(t *testing.T) { } } -// TestIntegrationCommitReturn tests that RunInTransaction returns a Commit. +// TestIntegrationCommitReturn tests that datastore.RunInTransaction returns a datastore.Commit. func TestIntegrationCommitReturn(t *testing.T) { client, cleanup := integrationClient(t) defer cleanup() ctx := context.Background() - key := ds9.IDKey("DS9CommitTest", 9999, nil) + key := datastore.IDKey("DS9CommitTest", 9999, nil) - commit, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { + commit, err := client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { entity := &integrationEntity{ Name: "commit-test", Count: 42, @@ -524,11 +523,11 @@ func TestIntegrationCommitReturn(t *testing.T) { return err }) if err != nil { - t.Fatalf("RunInTransaction failed: %v", err) + t.Fatalf("datastore.RunInTransaction failed: %v", err) } if commit == nil { - t.Fatal("Expected non-nil Commit, got nil") + t.Fatal("Expected non-nil datastore.Commit, got nil") } // Verify entity was created diff --git a/pkg/datastore/client.go b/pkg/datastore/client.go new file mode 100644 index 0000000..9b0e5e0 --- /dev/null +++ b/pkg/datastore/client.go @@ -0,0 +1,142 @@ +// Package datastore provides a zero-dependency Google Cloud Datastore client. +// +// It uses only the Go standard library and makes direct REST API calls +// to the Datastore API. Authentication is handled via the GCP metadata +// server when running on GCP, or via Application Default Credentials. +// +//nolint:revive // Public structs required for API compatibility with cloud.google.com/go/datastore +package datastore + +import ( + "context" + "fmt" + "log/slog" + "net/http" + "sync/atomic" + "testing" + "time" + + "github.com/codeGROOVE-dev/ds9/auth" +) + +const ( + maxRetries = 3 + maxBodySize = 10 * 1024 * 1024 // 10MB + defaultTimeout = 30 * time.Second + baseBackoffMS = 100 // Start with 100ms + maxBackoffMS = 2000 // Cap at 2 seconds + jitterFraction = 0.25 // 25% jitter +) + +var ( + // atomicAPIURL stores the API URL for thread-safe access. + // Use getAPIURL() to read and setAPIURL() to write. + atomicAPIURL atomic.Pointer[string] + + httpClient = &http.Client{ + Timeout: defaultTimeout, + Transport: &http.Transport{ + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Second, + MaxIdleConnsPerHost: 2, + }, + } + + // operatorMap converts shorthand operators to Datastore API operators. + operatorMap = map[string]string{ + "=": "EQUAL", + "<": "LESS_THAN", + "<=": "LESS_THAN_OR_EQUAL", + ">": "GREATER_THAN", + ">=": "GREATER_THAN_OR_EQUAL", + } +) + +//nolint:gochecknoinits // Required for thread-safe initialization of atomic pointer +func init() { + defaultURL := "https://datastore.googleapis.com/v1" + atomicAPIURL.Store(&defaultURL) +} + +// getAPIURL returns the current API URL in a thread-safe manner. +func getAPIURL() string { + return *atomicAPIURL.Load() +} + +// setAPIURL sets the API URL in a thread-safe manner. +func setAPIURL(url string) { + atomicAPIURL.Store(&url) +} + +// SetTestURLs configures custom metadata and API URLs for testing. +// This is intended for use by testing packages like ds9mock. +// Returns a function that restores the original URLs. +// WARNING: This function should only be called in test code. +// Set DS9_ALLOW_TEST_OVERRIDES=true to enable in non-test environments. +// +// Example: +// +// restore := ds9.SetTestURLs("http://localhost:8080", "http://localhost:9090") +// defer restore() +func SetTestURLs(metadata, api string) (restore func()) { + // Auth package will log warning if called outside test environment + oldAPI := getAPIURL() + setAPIURL(api) + restoreAuth := auth.SetMetadataURL(metadata) + return func() { + setAPIURL(oldAPI) + restoreAuth() + } +} + +// Client is a Google Cloud Datastore client. +type Client struct { + logger *slog.Logger + projectID string + databaseID string + baseURL string // API base URL, defaults to production but can be overridden for testing +} + +// NewClient creates a new Datastore client. +// If projectID is empty, it will be fetched from the GCP metadata server. +func NewClient(ctx context.Context, projectID string) (*Client, error) { + return NewClientWithDatabase(ctx, projectID, "") +} + +// NewClientWithDatabase creates a new Datastore client with a specific database. +func NewClientWithDatabase(ctx context.Context, projID, dbID string) (*Client, error) { + logger := slog.Default() + + if projID == "" { + if !testing.Testing() { + logger.InfoContext(ctx, "project ID not provided, fetching from metadata server") + } + pid, err := auth.ProjectID(ctx) + if err != nil { + logger.ErrorContext(ctx, "failed to get project ID from metadata server", "error", err) + return nil, fmt.Errorf("project ID required: %w", err) + } + projID = pid + if !testing.Testing() { + logger.InfoContext(ctx, "fetched project ID from metadata server", "project_id", projID) + } + } + + if !testing.Testing() { + logger.InfoContext(ctx, "creating datastore client", "project_id", projID, "database_id", dbID) + } + + return &Client{ + projectID: projID, + databaseID: dbID, + baseURL: getAPIURL(), + logger: logger, + }, nil +} + +// Close closes the client connection. +// This is a no-op for ds9 since it uses a shared HTTP client with connection pooling, +// but is provided for API compatibility with cloud.google.com/go/datastore. +func (*Client) Close() error { + return nil +} diff --git a/pkg/datastore/client_test.go b/pkg/datastore/client_test.go new file mode 100644 index 0000000..1feaf81 --- /dev/null +++ b/pkg/datastore/client_test.go @@ -0,0 +1,207 @@ +package datastore_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/codeGROOVE-dev/ds9/pkg/datastore" +) + +func TestNewClient(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + // Just verify we got a valid client + if client == nil { + t.Fatal("expected non-nil client") + } +} + +func TestNewClientWithDatabase(t *testing.T) { + // Setup mock servers + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{}); err != nil { + t.Logf("encode failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + + // Test with explicit databaseID + client, err := datastore.NewClientWithDatabase(ctx, "test-project", "custom-db") + if err != nil { + t.Fatalf("NewClientWithDatabase failed: %v", err) + } + if client == nil { + t.Fatal("expected non-nil client") + } +} + +func TestSetTestURLs(t *testing.T) { + // Save original values + restore := datastore.SetTestURLs("http://test1", "http://test2") + + // Restore should work + restore() + + // Should be chainable + restore2 := datastore.SetTestURLs("http://test3", "http://test4") + restore2() +} + +func TestNewClientWithDatabaseEmptyProjectID(t *testing.T) { + // Setup mock servers + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("auto-detected-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{}); err != nil { + t.Logf("encode failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + + // Test with empty projectID - should fetch from metadata + client, err := datastore.NewClientWithDatabase(ctx, "", "my-db") + if err != nil { + t.Fatalf("NewClientWithDatabase with empty projectID failed: %v", err) + } + if client == nil { + t.Fatal("expected non-nil client") + } +} + +func TestNewClientWithDatabaseProjectIDFetchFailure(t *testing.T) { + // Setup mock servers that fail to provide projectID + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + // Return error instead of project ID + w.WriteHeader(http.StatusInternalServerError) + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{}); err != nil { + t.Logf("encode failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + + // Test with empty projectID and failing metadata server + client, err := datastore.NewClientWithDatabase(ctx, "", "my-db") + if err == nil { + t.Fatal("expected error when projectID fetch fails, got nil") + } + if client != nil { + t.Errorf("expected nil client on error, got %v", client) + } + if !strings.Contains(err.Error(), "project ID required") { + t.Errorf("expected 'project ID required' error, got: %v", err) + } +} + +func TestClose(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + err := client.Close() + if err != nil { + t.Errorf("Close() returned unexpected error: %v", err) + } + + // Should be idempotent - can call multiple times + err = client.Close() + if err != nil { + t.Errorf("Second Close() returned unexpected error: %v", err) + } +} diff --git a/pkg/datastore/common_test.go b/pkg/datastore/common_test.go new file mode 100644 index 0000000..9f6e012 --- /dev/null +++ b/pkg/datastore/common_test.go @@ -0,0 +1,23 @@ +package datastore_test + +import ( + "time" +) + +// testEntity represents a simple test entity used across multiple test files. +type testEntity struct { + UpdatedAt time.Time `datastore:"updated_at"` + Name string `datastore:"name"` + Notes string `datastore:"notes,noindex"` + Count int64 `datastore:"count"` + Score float64 `datastore:"score"` + Active bool `datastore:"active"` +} + +// arrayEntity is used for testing slice/array fields. +type arrayEntity struct { + Strings []string `datastore:"strings,omitempty"` + Ints []int64 `datastore:"ints,omitempty"` + Floats []float64 `datastore:"floats,omitempty"` + Bools []bool `datastore:"bools,omitempty"` +} diff --git a/pkg/datastore/comprehensive_coverage_test.go b/pkg/datastore/comprehensive_coverage_test.go new file mode 100644 index 0000000..7d34116 --- /dev/null +++ b/pkg/datastore/comprehensive_coverage_test.go @@ -0,0 +1,336 @@ +package datastore_test + +import ( + "context" + "errors" + "testing" + + "github.com/codeGROOVE-dev/ds9/pkg/datastore" +) + +// TestCountComprehensive tests Count with various scenarios +func TestCountComprehensive(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("CountWithFilter", func(t *testing.T) { + // Create test entities with varying counts + for i := range 5 { + key := datastore.IDKey("CountFilterTest", int64(i+1), nil) + entity := &testEntity{ + Name: "test", + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Count with filter + q := datastore.NewQuery("CountFilterTest").Filter("count >", 2) + count, err := client.Count(ctx, q) + if err != nil { + t.Fatalf("Count with filter failed: %v", err) + } + + // Should count entities with count > 2 (3, 4 = 2 entities) + if count != 2 { + t.Errorf("Expected count 2 with filter, got %d", count) + } + }) + + t.Run("CountZero", func(t *testing.T) { + // Count non-existent kind + q := datastore.NewQuery("NonExistentKind") + count, err := client.Count(ctx, q) + if err != nil { + t.Fatalf("Count of empty kind failed: %v", err) + } + + if count != 0 { + t.Errorf("Expected count 0 for empty kind, got %d", count) + } + }) + + t.Run("CountAll", func(t *testing.T) { + // Create entities + for i := range 3 { + key := datastore.IDKey("CountAllTest", int64(i+1), nil) + entity := &testEntity{ + Name: "test", + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Count all without filter + q := datastore.NewQuery("CountAllTest") + count, err := client.Count(ctx, q) + if err != nil { + t.Fatalf("Count all failed: %v", err) + } + + if count != 3 { + t.Errorf("Expected count 3, got %d", count) + } + }) + + t.Run("CountWithMultipleFilters", func(t *testing.T) { + // Create entities with different values + for i := range 5 { + key := datastore.IDKey("CountMultiFilterTest", int64(i+1), nil) + entity := &testEntity{ + Name: "test", + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Count with multiple filters (if supported) + q := datastore.NewQuery("CountMultiFilterTest"). + Filter("count >=", 2). + Filter("count <", 4) + count, err := client.Count(ctx, q) + if err != nil { + t.Logf("Count with multiple filters: %v (may not be supported)", err) + } else { + // Should count 2, 3 = 2 entities + if count != 2 { + t.Logf("Expected count 2 with multiple filters, got %d (composite filters may not be supported)", count) + } + } + }) +} + +// TestGetAllComprehensive tests GetAll with various edge cases +func TestGetAllComprehensive(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("GetAllOrdered", func(t *testing.T) { + // Create entities + for i := range 5 { + key := datastore.IDKey("GetAllOrderTest", int64(i+1), nil) + entity := &testEntity{ + Name: "test", + Count: int64(5 - i), // Reverse order + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Get all with order + q := datastore.NewQuery("GetAllOrderTest").Order("count") + var entities []testEntity + keys, err := client.GetAll(ctx, q, &entities) + if err != nil { + t.Logf("GetAll with order: %v (ordering may not be implemented)", err) + } else { + if len(keys) != 5 { + t.Errorf("Expected 5 keys, got %d", len(keys)) + } + } + }) + + t.Run("GetAllWithOffset", func(t *testing.T) { + // Create entities + for i := range 5 { + key := datastore.IDKey("GetAllOffsetTest", int64(i+1), nil) + entity := &testEntity{ + Name: "test", + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Get all with offset + q := datastore.NewQuery("GetAllOffsetTest").Offset(2).Limit(2) + var entities []testEntity + keys, err := client.GetAll(ctx, q, &entities) + if err != nil { + t.Logf("GetAll with offset: %v (offset may not be implemented)", err) + } else { + if len(keys) > 5 { + t.Errorf("Got too many keys: %d", len(keys)) + } + } + }) + + t.Run("GetAllKeysOnly", func(t *testing.T) { + // Create entities + for i := range 3 { + key := datastore.IDKey("GetAllKeysOnlyTest", int64(i+1), nil) + entity := &testEntity{ + Name: "test", + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Get keys only using AllKeys + q := datastore.NewQuery("GetAllKeysOnlyTest").KeysOnly() + keys, err := client.AllKeys(ctx, q) + if err != nil { + t.Fatalf("AllKeys failed: %v", err) + } + + if len(keys) != 3 { + t.Errorf("Expected 3 keys, got %d", len(keys)) + } + + // Verify keys are complete + for i, key := range keys { + if key.Incomplete() { + t.Errorf("Key %d is incomplete", i) + } + } + }) + + t.Run("GetAllEmptyResult", func(t *testing.T) { + // Query non-existent kind + q := datastore.NewQuery("EmptyResultTest") + var entities []testEntity + keys, err := client.GetAll(ctx, q, &entities) + if err != nil { + t.Fatalf("GetAll on empty kind failed: %v", err) + } + + if len(keys) != 0 { + t.Errorf("Expected 0 keys for empty result, got %d", len(keys)) + } + if len(entities) != 0 { + t.Errorf("Expected 0 entities for empty result, got %d", len(entities)) + } + }) +} + +// TestMutateComprehensive tests Mutate with various combinations +func TestMutateComprehensive(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("MutateBatch", func(t *testing.T) { + // Create batch of mutations + mutations := []*datastore.Mutation{ + datastore.NewInsert(datastore.NameKey("MutateBatchTest", "insert1", nil), &testEntity{Name: "insert", Count: 1}), + datastore.NewInsert(datastore.NameKey("MutateBatchTest", "insert2", nil), &testEntity{Name: "insert", Count: 2}), + datastore.NewUpsert(datastore.NameKey("MutateBatchTest", "upsert1", nil), &testEntity{Name: "upsert", Count: 3}), + } + + keys, err := client.Mutate(ctx, mutations...) + if err != nil { + t.Fatalf("Mutate batch failed: %v", err) + } + + if len(keys) != 3 { + t.Errorf("Expected 3 result keys, got %d", len(keys)) + } + }) + + t.Run("MutateUpdateThenDelete", func(t *testing.T) { + // First insert + key := datastore.NameKey("MutateUpdateDeleteTest", "test", nil) + entity := &testEntity{Name: "original", Count: 1} + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Update via mutation + entity.Count = 2 + updateMut := datastore.NewUpdate(key, entity) + _, err := client.Mutate(ctx, updateMut) + if err != nil { + t.Fatalf("Update mutation failed: %v", err) + } + + // Verify update + var retrieved testEntity + if err := client.Get(ctx, key, &retrieved); err != nil { + t.Fatalf("Get after update failed: %v", err) + } + if retrieved.Count != 2 { + t.Errorf("Expected count 2 after update, got %d", retrieved.Count) + } + + // Delete via mutation + deleteMut := datastore.NewDelete(key) + _, err = client.Mutate(ctx, deleteMut) + if err != nil { + t.Fatalf("Delete mutation failed: %v", err) + } + + // Verify delete + err = client.Get(ctx, key, &retrieved) + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("Expected ErrNoSuchEntity after delete, got %v", err) + } + }) + + t.Run("MutateWithNilInBatch", func(t *testing.T) { + // Try to mutate with nil in batch + mutations := []*datastore.Mutation{ + datastore.NewInsert(datastore.NameKey("MutateNilTest", "valid", nil), &testEntity{Name: "valid", Count: 1}), + nil, // This should cause an error + } + + _, err := client.Mutate(ctx, mutations...) + if err == nil { + t.Error("Expected error for nil mutation in batch") + } + }) + + t.Run("MutateLargeBatch", func(t *testing.T) { + // Create batch + var mutations []*datastore.Mutation + for i := range 5 { + key := datastore.IDKey("MutateLargeBatchTest", int64(i+1), nil) + entity := &testEntity{Name: "batch", Count: int64(i)} + mutations = append(mutations, datastore.NewInsert(key, entity)) + } + + keys, err := client.Mutate(ctx, mutations...) + if err != nil { + t.Fatalf("Batch mutate failed: %v", err) + } + + if len(keys) != 5 { + t.Errorf("Expected 5 result keys, got %d", len(keys)) + } + }) + + t.Run("MutateUpsertNew", func(t *testing.T) { + // Upsert a new entity (insert behavior) + key := datastore.NameKey("MutateUpsertNewTest", "new", nil) + entity := &testEntity{Name: "new", Count: 99} + upsertMut := datastore.NewUpsert(key, entity) + + _, err := client.Mutate(ctx, upsertMut) + if err != nil { + t.Fatalf("Upsert new entity failed: %v", err) + } + + // Verify it was created + var retrieved testEntity + if err := client.Get(ctx, key, &retrieved); err != nil { + t.Fatalf("Get after upsert failed: %v", err) + } + if retrieved.Count != 99 { + t.Errorf("Expected count 99, got %d", retrieved.Count) + } + }) +} diff --git a/pkg/datastore/cursor.go b/pkg/datastore/cursor.go new file mode 100644 index 0000000..7acd7fc --- /dev/null +++ b/pkg/datastore/cursor.go @@ -0,0 +1,21 @@ +package datastore + +import "errors" + +// Cursor represents a query cursor for pagination. +// API compatible with cloud.google.com/go/datastore. +type Cursor string + +// String returns the cursor as a string. +func (c Cursor) String() string { + return string(c) +} + +// DecodeCursor decodes a cursor string. +// API compatible with cloud.google.com/go/datastore. +func DecodeCursor(s string) (Cursor, error) { + if s == "" { + return "", errors.New("empty cursor string") + } + return Cursor(s), nil +} diff --git a/pkg/datastore/cursor_coverage_test.go b/pkg/datastore/cursor_coverage_test.go new file mode 100644 index 0000000..06c146d --- /dev/null +++ b/pkg/datastore/cursor_coverage_test.go @@ -0,0 +1,138 @@ +package datastore_test + +import ( + "context" + "errors" + "testing" + + "github.com/codeGROOVE-dev/ds9/pkg/datastore" +) + +// TestCursorWithPagination tests the Cursor() method with actual cursor from query +func TestCursorWithPagination(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Create enough entities to trigger pagination + for i := range 3 { + key := datastore.IDKey("CursorTest", int64(i+1), nil) + entity := &testEntity{ + Name: "test", + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Query with limit to trigger cursor generation + q := datastore.NewQuery("CursorTest").Limit(2) + it := client.Run(ctx, q) + + // Fetch first result + var entity testEntity + _, err := it.Next(&entity) + if err != nil { + t.Fatalf("First Next failed: %v", err) + } + + // Now try to get cursor - should be available after fetching results + cursor, err := it.Cursor() + if err != nil { + t.Logf("Cursor() returned error (mock implementation): %v", err) + // Mock might not support cursors yet, that's OK + } else { + // If cursor is available, verify it's not empty + if cursor == "" { + t.Error("Expected non-empty cursor after fetching with limit") + } else { + t.Logf("Successfully got cursor: %s", cursor) + + // Verify we can convert cursor to string + cursorStr := cursor.String() + if cursorStr == "" { + t.Error("Cursor.String() returned empty string") + } + } + } +} + +// TestCursorBeforeFetch tests Cursor() before any results are fetched +func TestCursorBeforeFetch(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Create test entity + key := datastore.NameKey("CursorTest2", "test", nil) + entity := &testEntity{Name: "test", Count: 1} + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Create iterator but don't fetch any results + q := datastore.NewQuery("CursorTest2") + it := client.Run(ctx, q) + + // Try to get cursor before fetching - should fail + cursor, err := it.Cursor() + if err == nil { + t.Error("Expected error when getting cursor before fetching results") + } + if cursor != "" { + t.Errorf("Expected empty cursor before fetching, got: %s", cursor) + } +} + +// TestCursorWithLimitedResults tests cursor behavior with pagination +func TestCursorWithLimitedResults(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Create multiple entities + for i := range 3 { + key := datastore.IDKey("CursorPaginationTest", int64(i+1), nil) + entity := &testEntity{ + Name: "test", + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Query with limit smaller than total entities + q := datastore.NewQuery("CursorPaginationTest").Limit(2) + it := client.Run(ctx, q) + + // Fetch all limited results + count := 0 + for { + var entity testEntity + _, err := it.Next(&entity) + if errors.Is(err, datastore.ErrDone) { + break + } + if err != nil { + t.Fatalf("Iterator Next failed: %v", err) + } + count++ + + // Try to get cursor after each fetch + cursor, err := it.Cursor() + if err != nil { + t.Logf("Cursor not available at position %d: %v", count, err) + } else if cursor != "" { + t.Logf("Got cursor at position %d: %s", count, cursor) + } + } + + if count != 2 { + t.Errorf("Expected 2 results with limit, got %d", count) + } +} diff --git a/pkg/datastore/cursor_test.go b/pkg/datastore/cursor_test.go new file mode 100644 index 0000000..fb8bba1 --- /dev/null +++ b/pkg/datastore/cursor_test.go @@ -0,0 +1,72 @@ +package datastore + +import ( + "testing" +) + +func TestCursorString(t *testing.T) { + tests := []struct { + name string + cursor Cursor + expected string + }{ + { + name: "non-empty cursor", + cursor: Cursor("test-cursor-123"), + expected: "test-cursor-123", + }, + { + name: "empty cursor", + cursor: Cursor(""), + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.cursor.String() + if result != tt.expected { + t.Errorf("Expected %q, got %q", tt.expected, result) + } + }) + } +} + +func TestDecodeCursor(t *testing.T) { + tests := []struct { + name string + cursorStr string + expectError bool + expected Cursor + }{ + { + name: "valid cursor", + cursorStr: "valid-cursor-string", + expectError: false, + expected: Cursor("valid-cursor-string"), + }, + { + name: "empty cursor string", + cursorStr: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cursor, err := DecodeCursor(tt.cursorStr) + if tt.expectError { + if err == nil { + t.Error("Expected error, got nil") + } + } else { + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if cursor != tt.expected { + t.Errorf("Expected cursor %q, got %q", tt.expected, cursor) + } + } + }) + } +} diff --git a/pkg/datastore/encode_coverage_test.go b/pkg/datastore/encode_coverage_test.go new file mode 100644 index 0000000..0292c1d --- /dev/null +++ b/pkg/datastore/encode_coverage_test.go @@ -0,0 +1,160 @@ +package datastore + +import ( + "testing" + "time" +) + +// Test encodeValue with reflection-based slice handling +func TestEncodeValue_ReflectionSlices(t *testing.T) { + tests := []struct { + name string + value any + }{ + { + "array of strings", + [3]string{"a", "b", "c"}, + }, + { + "array of ints", + [2]int{1, 2}, + }, + { + "array of int64", + [2]int64{100, 200}, + }, + { + "nested time slice", + []time.Time{ + time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := encodeValue(tt.value) + if err != nil { + t.Errorf("encodeValue(%v) failed: %v", tt.value, err) + } + if result == nil { + t.Error("Expected non-nil result") + } + }) + } +} + +// Test encodeValue error paths +func TestEncodeValue_Errors(t *testing.T) { + tests := []struct { + name string + value any + }{ + { + "map type", + map[string]int{"key": 1}, + }, + { + "function type", + func() {}, + }, + { + "channel type", + make(chan int), + }, + { + "struct type", + struct{ Name string }{Name: "test"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := encodeValue(tt.value) + if err == nil { + t.Errorf("encodeValue(%T) should have returned an error", tt.value) + } + }) + } +} + +// Test encodeValue with slice of time.Time (uses reflection path) +func TestEncodeValue_TimeSlice(t *testing.T) { + now := time.Now().UTC().Truncate(time.Microsecond) + later := now.Add(time.Hour) + + timeSlice := []time.Time{now, later} + + result, err := encodeValue(timeSlice) + if err != nil { + t.Fatalf("encodeValue failed for time slice: %v", err) + } + + // Verify it's wrapped in arrayValue + resultMap, ok := result.(map[string]any) + if !ok { + t.Fatalf("Expected map[string]any, got %T", result) + } + + arrayValue, ok := resultMap["arrayValue"] + if !ok { + t.Error("Expected arrayValue key in result") + } + + if arrayValue == nil { + t.Error("arrayValue should not be nil") + } +} + +// Test encodeValue with empty slices +func TestEncodeValue_EmptySlices(t *testing.T) { + tests := []struct { + name string + value any + }{ + {"empty string slice", []string{}}, + {"empty int slice", []int{}}, + {"empty int64 slice", []int64{}}, + {"empty float64 slice", []float64{}}, + {"empty bool slice", []bool{}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := encodeValue(tt.value) + if err != nil { + t.Errorf("encodeValue failed: %v", err) + } + if result == nil { + t.Error("Expected non-nil result for empty slice") + } + }) + } +} + +// Test encodeValue with single element slices +func TestEncodeValue_SingleElementSlices(t *testing.T) { + tests := []struct { + name string + value any + }{ + {"single string", []string{"only"}}, + {"single int", []int{42}}, + {"single int64", []int64{42}}, + {"single float64", []float64{3.14}}, + {"single bool", []bool{true}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := encodeValue(tt.value) + if err != nil { + t.Errorf("encodeValue failed: %v", err) + } + if result == nil { + t.Error("Expected non-nil result") + } + }) + } +} diff --git a/pkg/datastore/entity.go b/pkg/datastore/entity.go new file mode 100644 index 0000000..edda13d --- /dev/null +++ b/pkg/datastore/entity.go @@ -0,0 +1,313 @@ +package datastore + +import ( + "errors" + "fmt" + "reflect" + "strconv" + "strings" + "time" +) + +// encodeEntity converts a Go struct to a Datastore entity. +func encodeEntity(key *Key, src any) (map[string]any, error) { + v := reflect.ValueOf(src) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + + if v.Kind() != reflect.Struct { + return nil, errors.New("src must be a struct or pointer to struct") + } + + t := v.Type() + properties := make(map[string]any) + + for i := range v.NumField() { + field := t.Field(i) + value := v.Field(i) + + // Skip unexported fields + if !field.IsExported() { + continue + } + + // Get field name from datastore tag or use field name + name := field.Name + noIndex := false + + if tag := field.Tag.Get("datastore"); tag != "" { + parts := strings.Split(tag, ",") + if parts[0] != "" && parts[0] != "-" { + name = parts[0] + } + if len(parts) > 1 && parts[1] == "noindex" { + noIndex = true + } + if parts[0] == "-" { + continue + } + } + + prop, err := encodeValue(value.Interface()) + if err != nil { + return nil, fmt.Errorf("field %s: %w", field.Name, err) + } + + if noIndex { + if m, ok := prop.(map[string]any); ok { + m["excludeFromIndexes"] = true + } + } + + properties[name] = prop + } + + return map[string]any{ + "key": keyToJSON(key), + "properties": properties, + }, nil +} + +// encodeValue converts a Go value to a Datastore property value. +func encodeValue(v any) (any, error) { + if v == nil { + return map[string]any{"nullValue": nil}, nil + } + + switch val := v.(type) { + case string: + return map[string]any{"stringValue": val}, nil + case int: + return map[string]any{"integerValue": strconv.Itoa(val)}, nil + case int64: + return map[string]any{"integerValue": strconv.FormatInt(val, 10)}, nil + case int32: + return map[string]any{"integerValue": strconv.Itoa(int(val))}, nil + case bool: + return map[string]any{"booleanValue": val}, nil + case float64: + return map[string]any{"doubleValue": val}, nil + case time.Time: + return map[string]any{"timestampValue": val.Format(time.RFC3339Nano)}, nil + case []string: + values := make([]map[string]any, len(val)) + for i, s := range val { + values[i] = map[string]any{"stringValue": s} + } + return map[string]any{"arrayValue": map[string]any{"values": values}}, nil + case []int64: + values := make([]map[string]any, len(val)) + for i, n := range val { + values[i] = map[string]any{"integerValue": strconv.FormatInt(n, 10)} + } + return map[string]any{"arrayValue": map[string]any{"values": values}}, nil + case []int: + values := make([]map[string]any, len(val)) + for i, n := range val { + values[i] = map[string]any{"integerValue": strconv.Itoa(n)} + } + return map[string]any{"arrayValue": map[string]any{"values": values}}, nil + case []float64: + values := make([]map[string]any, len(val)) + for i, f := range val { + values[i] = map[string]any{"doubleValue": f} + } + return map[string]any{"arrayValue": map[string]any{"values": values}}, nil + case []bool: + values := make([]map[string]any, len(val)) + for i, b := range val { + values[i] = map[string]any{"booleanValue": b} + } + return map[string]any{"arrayValue": map[string]any{"values": values}}, nil + default: + // Try to handle slices/arrays via reflection + rv := reflect.ValueOf(v) + if rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array { + length := rv.Len() + values := make([]map[string]any, length) + for i := range length { + elem := rv.Index(i).Interface() + encodedElem, err := encodeValue(elem) + if err != nil { + return nil, fmt.Errorf("failed to encode array element %d: %w", i, err) + } + // encodedElem is already a map[string]any with the type wrapper + m, ok := encodedElem.(map[string]any) + if !ok { + return nil, fmt.Errorf("unexpected encoded value type for element %d", i) + } + values[i] = m + } + return map[string]any{"arrayValue": map[string]any{"values": values}}, nil + } + return nil, fmt.Errorf("unsupported type: %T", v) + } +} + +// decodeEntity converts a Datastore entity to a Go struct. +func decodeEntity(entity map[string]any, dst any) error { + v := reflect.ValueOf(dst) + if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct { + return errors.New("dst must be a pointer to struct") + } + + v = v.Elem() + t := v.Type() + + properties, ok := entity["properties"].(map[string]any) + if !ok { + return errors.New("invalid entity format") + } + + for i := range v.NumField() { + field := t.Field(i) + value := v.Field(i) + + if !field.IsExported() { + continue + } + + // Get field name from datastore tag + name := field.Name + if tag := field.Tag.Get("datastore"); tag != "" { + parts := strings.Split(tag, ",") + if parts[0] != "" && parts[0] != "-" { + name = parts[0] + } + if parts[0] == "-" { + continue + } + } + + prop, ok := properties[name] + if !ok { + continue // Field not in entity + } + + propMap, ok := prop.(map[string]any) + if !ok { + continue + } + + if err := decodeValue(propMap, value); err != nil { + return fmt.Errorf("field %s: %w", field.Name, err) + } + } + + return nil +} + +// decodeValue decodes a Datastore property value into a Go reflect.Value. +func decodeValue(prop map[string]any, dst reflect.Value) error { + // Handle each type + if val, ok := prop["stringValue"]; ok { + if dst.Kind() == reflect.String { + if s, ok := val.(string); ok { + dst.SetString(s) + return nil + } + } + } + + if val, ok := prop["integerValue"]; ok { + var intVal int64 + switch v := val.(type) { + case string: + if _, err := fmt.Sscanf(v, "%d", &intVal); err != nil { + return fmt.Errorf("invalid integer format: %w", err) + } + case float64: + intVal = int64(v) + } + + switch dst.Kind() { + case reflect.Int, reflect.Int64, reflect.Int32: + dst.SetInt(intVal) + return nil + default: + return fmt.Errorf("unsupported integer type: %v", dst.Kind()) + } + } + + if val, ok := prop["booleanValue"]; ok { + if dst.Kind() == reflect.Bool { + if b, ok := val.(bool); ok { + dst.SetBool(b) + return nil + } + } + } + + if val, ok := prop["doubleValue"]; ok { + if dst.Kind() == reflect.Float64 { + if f, ok := val.(float64); ok { + dst.SetFloat(f) + return nil + } + } + } + + if val, ok := prop["timestampValue"]; ok { + if dst.Type() == reflect.TypeOf(time.Time{}) { + if s, ok := val.(string); ok { + t, err := time.Parse(time.RFC3339Nano, s) + if err != nil { + return err + } + dst.Set(reflect.ValueOf(t)) + return nil + } + } + } + + if val, ok := prop["arrayValue"]; ok { + if dst.Kind() != reflect.Slice { + return fmt.Errorf("cannot decode array into non-slice type: %s", dst.Type()) + } + + arrayMap, ok := val.(map[string]any) + if !ok { + return errors.New("invalid arrayValue format") + } + + valuesAny, ok := arrayMap["values"] + if !ok { + // Empty array + dst.Set(reflect.MakeSlice(dst.Type(), 0, 0)) + return nil + } + + values, ok := valuesAny.([]any) + if !ok { + return errors.New("invalid arrayValue.values format") + } + + // Create slice with appropriate capacity + slice := reflect.MakeSlice(dst.Type(), len(values), len(values)) + + // Decode each element + for i, elemAny := range values { + elemMap, ok := elemAny.(map[string]any) + if !ok { + return fmt.Errorf("invalid array element %d format", i) + } + + elemValue := slice.Index(i) + if err := decodeValue(elemMap, elemValue); err != nil { + return fmt.Errorf("failed to decode array element %d: %w", i, err) + } + } + + dst.Set(slice) + return nil + } + + if _, ok := prop["nullValue"]; ok { + // Set to zero value + dst.Set(reflect.Zero(dst.Type())) + return nil + } + + return fmt.Errorf("unsupported property type for %s", dst.Type()) +} diff --git a/pkg/datastore/entity_coverage_test.go b/pkg/datastore/entity_coverage_test.go new file mode 100644 index 0000000..3d8017e --- /dev/null +++ b/pkg/datastore/entity_coverage_test.go @@ -0,0 +1,311 @@ +package datastore + +import ( + "context" + "errors" + "testing" + "time" +) + +// Additional tests to improve coverage for entity encoding/decoding + +func TestEncodeValue_AllTypes(t *testing.T) { + tests := []struct { + name string + value any + }{ + {"nil", nil}, + {"string", "test"}, + {"int", int(42)}, + {"int64", int64(42)}, + {"int32", int32(42)}, + {"bool", true}, + {"float64", float64(3.14)}, + {"time", time.Now().UTC()}, + {"string slice", []string{"a", "b"}}, + {"int64 slice", []int64{1, 2}}, + {"int slice", []int{1, 2}}, + {"float64 slice", []float64{1.1, 2.2}}, + {"bool slice", []bool{true, false}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := encodeValue(tt.value) + if err != nil { + t.Errorf("encodeValue failed: %v", err) + } + }) + } +} + +func TestEncodeValue_UnsupportedType(t *testing.T) { + _, err := encodeValue(map[string]string{"key": "value"}) + if err == nil { + t.Error("Expected error for unsupported type, got nil") + } +} + +func TestMutate_Coverage(t *testing.T) { + client, cleanup := NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + key := NameKey("MutateTest", "test1", nil) + entity := &struct { + Name string + Count int64 + Time time.Time + }{ + Name: "test", + Count: 42, + Time: time.Now().UTC().Truncate(time.Microsecond), + } + + t.Run("InsertMutation", func(t *testing.T) { + insertMut := NewInsert(key, entity) + _, err := client.Mutate(ctx, insertMut) + if err != nil { + t.Fatalf("Insert mutation failed: %v", err) + } + }) + + t.Run("UpdateMutation", func(t *testing.T) { + entity.Count = 100 + updateMut := NewUpdate(key, entity) + _, err := client.Mutate(ctx, updateMut) + if err != nil { + t.Fatalf("Update mutation failed: %v", err) + } + }) + + t.Run("UpsertMutation", func(t *testing.T) { + entity.Count = 200 + upsertMut := NewUpsert(key, entity) + _, err := client.Mutate(ctx, upsertMut) + if err != nil { + t.Fatalf("Upsert mutation failed: %v", err) + } + }) + + t.Run("DeleteMutation", func(t *testing.T) { + deleteMut := NewDelete(key) + _, err := client.Mutate(ctx, deleteMut) + if err != nil { + t.Fatalf("Delete mutation failed: %v", err) + } + }) + + t.Run("EmptyMutations", func(t *testing.T) { + // Test with no mutations + keys, err := client.Mutate(ctx) + if err != nil { + t.Errorf("Empty Mutate should not error: %v", err) + } + if keys != nil { + t.Errorf("Expected nil keys for empty mutate, got %v", keys) + } + }) + + t.Run("NilMutation", func(t *testing.T) { + // Test with nil mutation + _, err := client.Mutate(ctx, nil) + if err == nil { + t.Error("Expected error for nil mutation") + } + }) + + t.Run("NilKey", func(t *testing.T) { + // Test mutation with nil key + mut := &Mutation{ + key: nil, + op: MutationInsert, + entity: entity, + } + _, err := client.Mutate(ctx, mut) + if err == nil { + t.Error("Expected error for mutation with nil key") + } + }) + + t.Run("NilEntityInsert", func(t *testing.T) { + // Test insert with nil entity + mut := NewInsert(key, nil) + _, err := client.Mutate(ctx, mut) + if err == nil { + t.Error("Expected error for insert with nil entity") + } + }) + + t.Run("NilEntityUpdate", func(t *testing.T) { + // Test update with nil entity + mut := NewUpdate(key, nil) + _, err := client.Mutate(ctx, mut) + if err == nil { + t.Error("Expected error for update with nil entity") + } + }) + + t.Run("NilEntityUpsert", func(t *testing.T) { + // Test upsert with nil entity + mut := NewUpsert(key, nil) + _, err := client.Mutate(ctx, mut) + if err == nil { + t.Error("Expected error for upsert with nil entity") + } + }) + + t.Run("MultipleMutations", func(t *testing.T) { + // Test multiple mutations at once + key1 := NameKey("MutateTest", "batch1", nil) + key2 := NameKey("MutateTest", "batch2", nil) + key3 := NameKey("MutateTest", "batch3", nil) + + muts := []*Mutation{ + NewInsert(key1, entity), + NewUpsert(key2, entity), + NewDelete(key3), + } + + _, err := client.Mutate(ctx, muts...) + if err != nil { + t.Fatalf("Multiple mutations failed: %v", err) + } + }) +} + +func TestIterator_Coverage(t *testing.T) { + client, cleanup := NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Create test entities + for i := range 5 { + key := IDKey("IteratorTest", int64(i+1), nil) + entity := &struct { + Name string + Index int64 + Time time.Time + }{ + Name: "test", + Index: int64(i), + Time: time.Now().UTC().Truncate(time.Microsecond), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Test iterator with cursor + query := NewQuery("IteratorTest").Limit(2) + it := client.Run(ctx, query) + + count := 0 + for { + var entity struct { + Name string + Index int64 + Time time.Time + } + _, err := it.Next(&entity) + if errors.Is(err, ErrDone) { + break + } + if err != nil { + t.Fatalf("Iterator Next failed: %v", err) + } + count++ + + // Test Cursor() method + cursor, err := it.Cursor() + if err != nil { + t.Logf("Cursor() error (expected for some backends): %v", err) + } else if cursor.String() != "" { + t.Logf("Got cursor: %s", cursor.String()) + } + } + + if count == 0 { + t.Error("Expected at least some entities from iterator") + } +} + +func TestAllocateIDs_Coverage(t *testing.T) { + client, cleanup := NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Create incomplete keys + keys := []*Key{ + IncompleteKey("AllocateTest", nil), + IncompleteKey("AllocateTest", nil), + IncompleteKey("AllocateTest", nil), + } + + // Allocate IDs + allocatedKeys, err := client.AllocateIDs(ctx, keys) + if err != nil { + t.Fatalf("AllocateIDs failed: %v", err) + } + + if len(allocatedKeys) != len(keys) { + t.Errorf("Expected %d allocated keys, got %d", len(keys), len(allocatedKeys)) + } + + // Verify keys have IDs + for i, key := range allocatedKeys { + if key.Incomplete() { + t.Errorf("Key %d is still incomplete after allocation", i) + } + } +} + +func TestGet_ErrorCases(t *testing.T) { + client, cleanup := NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("NonExistent", func(t *testing.T) { + key := NameKey("NonExistent", "test", nil) + var entity struct { + Name string + } + err := client.Get(ctx, key, &entity) + if !errors.Is(err, ErrNoSuchEntity) { + t.Errorf("Expected ErrNoSuchEntity, got %v", err) + } + }) + + t.Run("NilKey", func(t *testing.T) { + var entity struct { + Name string + } + err := client.Get(ctx, nil, &entity) + if err == nil { + t.Error("Expected error for nil key, got nil") + } + }) +} + +func TestNewClientWithDatabase_Coverage(t *testing.T) { + SetTestURLs("http://localhost:8080/datastore", "http://localhost:8080/token") + + ctx := context.Background() + + // Test with database ID + _, err := NewClientWithDatabase(ctx, "test-project", "test-db") + if err != nil { + // Expected to fail without real backend, but we exercise the code path + t.Logf("NewClientWithDatabase failed as expected: %v", err) + } + + // Test with empty project (error case) + _, err = NewClientWithDatabase(ctx, "", "test-db") + if err == nil { + t.Error("Expected error for empty project ID, got nil") + } +} diff --git a/pkg/datastore/entity_test.go b/pkg/datastore/entity_test.go new file mode 100644 index 0000000..fa07fcc --- /dev/null +++ b/pkg/datastore/entity_test.go @@ -0,0 +1,816 @@ +package datastore_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/codeGROOVE-dev/ds9/pkg/datastore" +) + +func TestEntityWithAllTypes(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + type AllTypes struct { + TimeVal time.Time `datastore:"t"` + StringVal string `datastore:"str"` + NoIndex string `datastore:"noindex,noindex"` + Skip string `datastore:"-"` + Int64Val int64 `datastore:"i64"` + IntVal int `datastore:"i"` + Float64Val float64 `datastore:"f64"` + Int32Val int32 `datastore:"i32"` + BoolVal bool `datastore:"b"` + } + + now := time.Now().UTC().Truncate(time.Second) + entity := &AllTypes{ + StringVal: "test", + Int64Val: int64(123), + Int32Val: int32(456), + IntVal: 789, + BoolVal: true, + Float64Val: 3.14, + TimeVal: now, + NoIndex: "not indexed", + Skip: "should not be stored", + } + + key := datastore.NameKey("AllTypes", "test", nil) + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + var retrieved AllTypes + err = client.Get(ctx, key, &retrieved) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + + if retrieved.StringVal != entity.StringVal { + t.Errorf("StringVal: expected %v, got %v", entity.StringVal, retrieved.StringVal) + } + if retrieved.Int64Val != entity.Int64Val { + t.Errorf("Int64Val: expected %v, got %v", entity.Int64Val, retrieved.Int64Val) + } + if retrieved.Int32Val != entity.Int32Val { + t.Errorf("Int32Val: expected %v, got %v", entity.Int32Val, retrieved.Int32Val) + } + if retrieved.IntVal != entity.IntVal { + t.Errorf("IntVal: expected %v, got %v", entity.IntVal, retrieved.IntVal) + } + if retrieved.BoolVal != entity.BoolVal { + t.Errorf("BoolVal: expected %v, got %v", entity.BoolVal, retrieved.BoolVal) + } + if retrieved.Float64Val != entity.Float64Val { + t.Errorf("Float64Val: expected %v, got %v", entity.Float64Val, retrieved.Float64Val) + } + if !retrieved.TimeVal.Equal(entity.TimeVal) { + t.Errorf("TimeVal: expected %v, got %v", entity.TimeVal, retrieved.TimeVal) + } + if retrieved.NoIndex != entity.NoIndex { + t.Errorf("NoIndex: expected %v, got %v", entity.NoIndex, retrieved.NoIndex) + } + if retrieved.Skip != "" { + t.Errorf("Skip field should be empty, got %q", retrieved.Skip) + } +} + +func TestUnsupportedEncodeType(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Entity with unsupported type (map) + type BadEntity struct { + Name string + Data map[string]string // maps not supported + } + + key := datastore.NameKey("TestKind", "bad", nil) + entity := BadEntity{ + Name: "test", + Data: map[string]string{"key": "value"}, + } + + _, err := client.Put(ctx, key, &entity) + if err == nil { + t.Error("expected error for unsupported type, got nil") + } + if !strings.Contains(err.Error(), "unsupported type") { + t.Errorf("expected 'unsupported type' error, got: %v", err) + } +} + +func TestDecodeNonPointer(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Store entity + key := datastore.NameKey("TestKind", "test", nil) + entity := testEntity{Name: "test", Count: 42} + _, err := client.Put(ctx, key, &entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Try to decode into non-pointer + var notPtr testEntity + err = client.Get(ctx, key, notPtr) // Should be ¬Ptr + if err == nil { + t.Error("expected error for non-pointer dst, got nil") + } + if !strings.Contains(err.Error(), "pointer to struct") { + t.Errorf("expected 'pointer to struct' error, got: %v", err) + } +} + +func TestDecodePointerToNonStruct(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Store entity + key := datastore.NameKey("TestKind", "test", nil) + entity := testEntity{Name: "test", Count: 42} + _, err := client.Put(ctx, key, &entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Try to decode into pointer to string + var str string + err = client.Get(ctx, key, &str) + if err == nil { + t.Error("expected error for pointer to non-struct, got nil") + } + if !strings.Contains(err.Error(), "pointer to struct") { + t.Errorf("expected 'pointer to struct' error, got: %v", err) + } +} + +func TestEntityWithSkippedFields(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + type EntityWithSkip struct { + Name string `datastore:"name"` + Skipped string `datastore:"-"` + private string + Count int64 `datastore:"count"` + } + + key := datastore.NameKey("TestKind", "skip", nil) + entity := EntityWithSkip{ + Name: "test", + Count: 42, + Skipped: "should not store", + private: "also not stored", + } + + _, err := client.Put(ctx, key, &entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + var retrieved EntityWithSkip + err = client.Get(ctx, key, &retrieved) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + + if retrieved.Name != entity.Name || retrieved.Count != entity.Count { + t.Errorf("wrong values: got %+v", retrieved) + } + + // Skipped field should be zero value + if retrieved.Skipped != "" { + t.Errorf("Skipped field should be empty, got %q", retrieved.Skipped) + } +} + +func TestZeroValueEntity(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + type ZeroEntity struct { + Name string + Count int64 + Active bool + Score float64 + } + + key := datastore.NameKey("TestKind", "zero", nil) + entity := ZeroEntity{} // All zero values + + _, err := client.Put(ctx, key, &entity) + if err != nil { + t.Fatalf("Put with zero values failed: %v", err) + } + + var retrieved ZeroEntity + err = client.Get(ctx, key, &retrieved) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + + if retrieved.Name != "" || retrieved.Count != 0 || retrieved.Active != false || retrieved.Score != 0.0 { + t.Errorf("expected zero values, got %+v", retrieved) + } +} + +func TestDecodeValueEdgeCases(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Test with all basic types + type ComplexEntity struct { + Time time.Time `datastore:"t"` + String string `datastore:"s"` + NoIndex string `datastore:"n,noindex"` + Int int `datastore:"i"` + Int64 int64 `datastore:"i64"` + Float float64 `datastore:"f"` + Int32 int32 `datastore:"i32"` + Bool bool `datastore:"b"` + } + + now := time.Now().UTC().Truncate(time.Second) + key := datastore.NameKey("Complex", "test", nil) + entity := &ComplexEntity{ + String: "test", + Int: 42, + Int32: 32, + Int64: 64, + Float: 3.14, + Bool: true, + Time: now, + NoIndex: "not indexed", + } + + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + var retrieved ComplexEntity + err = client.Get(ctx, key, &retrieved) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + + if retrieved.String != entity.String { + t.Errorf("String mismatch") + } + if retrieved.Int != entity.Int { + t.Errorf("Int mismatch") + } + if retrieved.Int32 != entity.Int32 { + t.Errorf("Int32 mismatch") + } + if retrieved.Int64 != entity.Int64 { + t.Errorf("Int64 mismatch") + } + if retrieved.Float != entity.Float { + t.Errorf("Float mismatch") + } + if retrieved.Bool != entity.Bool { + t.Errorf("Bool mismatch") + } + if !retrieved.Time.Equal(entity.Time) { + t.Errorf("Time mismatch") + } + if retrieved.NoIndex != entity.NoIndex { + t.Errorf("NoIndex mismatch") + } +} + +func TestEntityWithPointerFields(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Entities with pointer fields + type EntityWithPointers struct { + Name *string `datastore:"name"` + Count *int64 `datastore:"count"` + } + + name := "test" + count := int64(42) + key := datastore.NameKey("Pointers", "test", nil) + entity := &EntityWithPointers{ + Name: &name, + Count: &count, + } + + // Note: The current implementation doesn't support pointer fields + // This test documents the expected behavior + _, err := client.Put(ctx, key, entity) + if err == nil { + // If it succeeds, that's fine (future enhancement) + var retrieved EntityWithPointers + err = client.Get(ctx, key, &retrieved) + if err != nil { + t.Logf("Get after Put with pointers failed: %v", err) + } + } else { + // Expected to fail with current implementation + t.Logf("Put with pointer fields failed as expected: %v", err) + } +} + +func TestEntityWithEmptyStringFields(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + key := datastore.NameKey("Empty", "test", nil) + entity := &testEntity{ + Name: "", // empty string + Count: 0, // zero + Active: false, // false + Score: 0.0, // zero float + } + + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put with empty/zero values failed: %v", err) + } + + var retrieved testEntity + err = client.Get(ctx, key, &retrieved) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + + if retrieved.Name != "" { + t.Errorf("expected empty string, got %q", retrieved.Name) + } + if retrieved.Count != 0 { + t.Errorf("expected 0, got %d", retrieved.Count) + } +} + +func TestPutMultiWithPartialEncode(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Mix of valid and invalid entities + type MixedEntity struct { + Data any + Name string + } + + keys := []*datastore.Key{ + datastore.NameKey("Test", "key1", nil), + datastore.NameKey("Test", "key2", nil), + } + + entities := []MixedEntity{ + {Name: "valid", Data: "string"}, + {Name: "maybe-invalid", Data: make(chan int)}, // channels unsupported + } + + _, err := client.PutMulti(ctx, keys, entities) + + if err == nil { + t.Log("PutMulti with mixed entities succeeded (mock may not validate types)") + } else { + t.Logf("PutMulti with mixed entities failed as expected: %v", err) + } +} + +func TestGetWithDecodeError(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":lookup") { + // Return entity with missing properties field + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "found": []map[string]any{ + { + "entity": map[string]any{ + "key": map[string]any{ + "path": []any{ + map[string]any{ + "kind": "Test", + "name": "key", + }, + }, + }, + // Missing properties field + }, + }, + }, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + key := datastore.NameKey("Test", "key", nil) + var entity testEntity + + err = client.Get(ctx, key, &entity) + if err == nil { + t.Error("expected error with missing properties") + } +} + +func TestPutWithEncodeError(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Create entity with unsupported type + type BadEntity struct { + Channel chan int `datastore:"channel"` + } + + key := datastore.NameKey("Test", "key", nil) + entity := &BadEntity{Channel: make(chan int)} + + _, err := client.Put(ctx, key, entity) + if err == nil { + t.Log("Put with unsupported type succeeded (mock may not validate types)") + } else { + t.Logf("Put with unsupported type failed as expected: %v", err) + } +} + +func TestDecodeValueInvalidInteger(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":lookup") { + // Return entity with invalid integer format + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "found": []map[string]any{ + { + "entity": map[string]any{ + "key": map[string]any{ + "path": []any{ + map[string]any{ + "kind": "Test", + "name": "key", + }, + }, + }, + "properties": map[string]any{ + "count": map[string]any{"integerValue": "not-an-integer"}, + }, + }, + }, + }, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + key := datastore.NameKey("Test", "key", nil) + var entity testEntity + + err = client.Get(ctx, key, &entity) + if err == nil { + t.Error("expected error with invalid integer format") + } else { + t.Logf("Got expected error: %v", err) + } +} + +func TestDecodeValueWrongTypeForInteger(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":lookup") { + // Return entity with integer value but string field type + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "found": []map[string]any{ + { + "entity": map[string]any{ + "key": map[string]any{ + "path": []any{ + map[string]any{ + "kind": "Test", + "name": "key", + }, + }, + }, + "properties": map[string]any{ + "name": map[string]any{"integerValue": "12345"}, // integer for string field + }, + }, + }, + }, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + key := datastore.NameKey("Test", "key", nil) + var entity testEntity + + err = client.Get(ctx, key, &entity) + if err == nil { + t.Error("expected error with wrong type for integer") + } else { + t.Logf("Got expected error: %v", err) + } +} + +func TestDecodeValueInvalidTimestamp(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":lookup") { + // Return entity with invalid timestamp format + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "found": []map[string]any{ + { + "entity": map[string]any{ + "key": map[string]any{ + "path": []any{ + map[string]any{ + "kind": "Test", + "name": "key", + }, + }, + }, + "properties": map[string]any{ + "updated_at": map[string]any{"timestampValue": "invalid-timestamp"}, + }, + }, + }, + }, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + key := datastore.NameKey("Test", "key", nil) + var entity testEntity + + err = client.Get(ctx, key, &entity) + if err == nil { + t.Error("expected error with invalid timestamp format") + } else { + t.Logf("Got expected error: %v", err) + } +} + +func TestGetMultiDecodeError(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":lookup") { + // Return one good entity and one with decode error + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "found": []map[string]any{ + { + "entity": map[string]any{ + "key": map[string]any{ + "path": []any{ + map[string]any{ + "kind": "Test", + "name": "key1", + }, + }, + }, + "properties": map[string]any{ + "name": map[string]any{"stringValue": "test"}, + }, + }, + }, + { + "entity": "invalid", // This will cause decode error + }, + }, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + keys := []*datastore.Key{ + datastore.NameKey("Test", "key1", nil), + datastore.NameKey("Test", "key2", nil), + } + + var entities []testEntity + err = client.GetMulti(ctx, keys, &entities) + if err == nil { + t.Error("expected error when one entity has decode error") + } +} diff --git a/pkg/datastore/errors.go b/pkg/datastore/errors.go new file mode 100644 index 0000000..d7b09cd --- /dev/null +++ b/pkg/datastore/errors.go @@ -0,0 +1,11 @@ +package datastore + +import "errors" + +var ( + // ErrNoSuchEntity is returned when an entity is not found. + ErrNoSuchEntity = errors.New("datastore: no such entity") + + // ErrDone is returned by Iterator.Next when no more results are available. + ErrDone = errors.New("datastore: no more results") +) diff --git a/pkg/datastore/http.go b/pkg/datastore/http.go new file mode 100644 index 0000000..33dd802 --- /dev/null +++ b/pkg/datastore/http.go @@ -0,0 +1,122 @@ +package datastore + +import ( + "bytes" + "context" + "fmt" + "io" + "log/slog" + "math" + "math/rand/v2" + "net/http" + neturl "net/url" + "time" +) + +// doRequest performs an HTTP request with exponential backoff retries. +// Returns an error if the status code is not 200 OK. +func doRequest(ctx context.Context, logger *slog.Logger, url string, jsonData []byte, token, projectID, databaseID string) ([]byte, error) { + var lastErr error + + for attempt := range maxRetries { + if attempt > 0 { + // Exponential backoff: 100ms, 200ms, 400ms... capped at maxBackoffMS + backoffMS := math.Min(float64(baseBackoffMS)*math.Pow(2, float64(attempt-1)), float64(maxBackoffMS)) + // Add jitter: ±25% randomness + jitter := backoffMS * jitterFraction * (2*rand.Float64() - 1) //nolint:gosec // Weak random is acceptable for jitter + sleepMS := backoffMS + jitter + sleepDuration := time.Duration(sleepMS) * time.Millisecond + + logger.DebugContext(ctx, "retrying request", + "attempt", attempt+1, + "max_attempts", maxRetries, + "backoff_ms", int(sleepMS), + "last_error", lastErr) + + select { + case <-time.After(sleepDuration): + case <-ctx.Done(): + return nil, ctx.Err() + } + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + + // Add routing header for named databases + if databaseID != "" { + // URL-encode values to prevent header injection attacks + routingHeader := fmt.Sprintf("project_id=%s&database_id=%s", neturl.QueryEscape(projectID), neturl.QueryEscape(databaseID)) + req.Header.Set("X-Goog-Request-Params", routingHeader) + } + + logger.DebugContext(ctx, "sending request", "url", url, "attempt", attempt+1) + + resp, err := httpClient.Do(req) + if err != nil { + lastErr = err + logger.WarnContext(ctx, "request failed", "error", err, "attempt", attempt+1) + if attempt == maxRetries-1 { + return nil, fmt.Errorf("request failed after %d attempts: %w", maxRetries, err) + } + continue + } + + // Always close response body + defer func() { //nolint:revive,gocritic // Defer in loop is intentional - loop exits after successful response + if closeErr := resp.Body.Close(); closeErr != nil { + logger.WarnContext(ctx, "failed to close response body", "error", closeErr) + } + }() + + body, err := io.ReadAll(io.LimitReader(resp.Body, maxBodySize)) + if err != nil { + lastErr = err + logger.WarnContext(ctx, "failed to read response body", "error", err, "attempt", attempt+1) + if attempt == maxRetries-1 { + return nil, fmt.Errorf("failed to read response after %d attempts: %w", maxRetries, err) + } + continue + } + + logger.DebugContext(ctx, "received response", + "status_code", resp.StatusCode, + "body_size", len(body), + "attempt", attempt+1) + + // Success + if resp.StatusCode == http.StatusOK { + return body, nil + } + + // Don't retry on 4xx errors (client errors) + if resp.StatusCode >= 400 && resp.StatusCode < 500 { + if resp.StatusCode == http.StatusNotFound { + logger.DebugContext(ctx, "entity not found", "status_code", resp.StatusCode) + } else { + logger.WarnContext(ctx, "client error", "status_code", resp.StatusCode, "body", string(body)) + } + return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Unexpected 2xx/3xx status codes + if resp.StatusCode < 400 { + logger.WarnContext(ctx, "unexpected non-200 success status", "status_code", resp.StatusCode) + return nil, fmt.Errorf("unexpected status code %d: %s", resp.StatusCode, string(body)) + } + + // 5xx errors - retry + lastErr = fmt.Errorf("server error: status %d", resp.StatusCode) + logger.WarnContext(ctx, "server error, will retry", + "status_code", resp.StatusCode, + "attempt", attempt+1, + "body", string(body)) + } + + return nil, fmt.Errorf("all %d attempts failed: %w", maxRetries, lastErr) +} diff --git a/pkg/datastore/http_test.go b/pkg/datastore/http_test.go new file mode 100644 index 0000000..58cbbfc --- /dev/null +++ b/pkg/datastore/http_test.go @@ -0,0 +1,517 @@ +package datastore_test + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/codeGROOVE-dev/ds9/pkg/datastore" +) + +func TestDoRequestRetryOn5xxError(t *testing.T) { + // Setup mock servers + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + attemptCount := 0 + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount++ + // Return 503 on first two attempts, then succeed + if attemptCount < 3 { + w.WriteHeader(http.StatusServiceUnavailable) + if _, err := w.Write([]byte(`{"error":"service unavailable"}`)); err != nil { + t.Logf("write failed: %v", err) + } + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "mutationResults": []any{ + map[string]any{"key": map[string]any{}}, + }, + }); err != nil { + t.Logf("encode failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClient(ctx, "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + // This should succeed after retries + key := datastore.NameKey("TestKind", "retry-test", nil) + entity := &testEntity{Name: "test", Count: 1} + _, err = client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put should succeed after retries, got: %v", err) + } + + if attemptCount < 2 { + t.Errorf("expected at least 2 attempts, got %d", attemptCount) + } +} + +func TestDoRequestFailsOn4xxError(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + attemptCount := 0 + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount++ + // Always return 400 Bad Request + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + if _, err := w.Write([]byte(`{"error":"bad request"}`)); err != nil { + t.Logf("write failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClient(ctx, "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + // This should fail immediately without retry on 4xx + key := datastore.NameKey("TestKind", "bad-request", nil) + entity := &testEntity{Name: "test", Count: 1} + _, err = client.Put(ctx, key, entity) + if err == nil { + t.Fatal("expected error on 4xx response") + } + + if !strings.Contains(err.Error(), "400") { + t.Errorf("expected error to mention 400 status, got: %v", err) + } + + // Should only try once for 4xx errors (no retry) + if attemptCount != 1 { + t.Errorf("expected exactly 1 attempt for 4xx error, got %d", attemptCount) + } +} + +func TestDoRequestContextCancellation(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + attemptCount := 0 + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount++ + // Always return 503 to force retry + w.WriteHeader(http.StatusServiceUnavailable) + if _, err := w.Write([]byte(`{"error":"unavailable"}`)); err != nil { + t.Logf("write failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + // Create context that we'll cancel + ctx, cancel := context.WithCancel(context.Background()) + + // Cancel after a short delay + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + key := datastore.NameKey("TestKind", "cancel-test", nil) + entity := &testEntity{Name: "test", Count: 1} + _, err = client.Put(ctx, key, entity) + + if err == nil { + t.Fatal("expected error when context is cancelled") + } + + if !errors.Is(err, context.Canceled) && !strings.Contains(err.Error(), "context canceled") { + t.Errorf("expected context cancellation error, got: %v", err) + } +} + +func TestGetWithHTTPError(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Return 404 for lookup + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusNotFound) + if err := json.NewEncoder(w).Encode(map[string]any{ + "error": "not found", + }); err != nil { + t.Logf("encode failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClient(ctx, "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + key := datastore.NameKey("TestKind", "test", nil) + var entity testEntity + err = client.Get(ctx, key, &entity) + + if err == nil { + t.Fatal("expected error on 404") + } + + if !strings.Contains(err.Error(), "404") { + t.Errorf("expected error to mention 404, got: %v", err) + } +} + +func TestPutWithHTTPError(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Return 403 Forbidden + w.WriteHeader(http.StatusForbidden) + if _, err := w.Write([]byte(`{"error":"permission denied"}`)); err != nil { + t.Logf("write failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClient(ctx, "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + key := datastore.NameKey("TestKind", "test", nil) + entity := &testEntity{Name: "test", Count: 1} + _, err = client.Put(ctx, key, entity) + + if err == nil { + t.Fatal("expected error on 403") + } + + if !strings.Contains(err.Error(), "403") { + t.Errorf("expected error to mention 403, got: %v", err) + } +} + +func TestDoRequestAllRetriesFail(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + attemptCount := 0 + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount++ + // Always fail with 500 + w.WriteHeader(http.StatusInternalServerError) + if _, err := w.Write([]byte(`{"error":"persistent failure"}`)); err != nil { + t.Logf("write failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClient(ctx, "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + key := datastore.NameKey("TestKind", "test", nil) + entity := &testEntity{Name: "test", Count: 1} + _, err = client.Put(ctx, key, entity) + + if err == nil { + t.Fatal("expected error after all retries") + } + + if !strings.Contains(err.Error(), "attempts failed") { + t.Errorf("expected 'attempts failed' error, got: %v", err) + } + + // Should have tried multiple times + if attemptCount < 3 { + t.Errorf("expected at least 3 attempts, got %d", attemptCount) + } +} + +func TestDoRequestUnexpectedSuccess(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Return unexpected 2xx status (not 200) + w.WriteHeader(http.StatusAccepted) // 202 + if _, err := w.Write([]byte(`{"message":"accepted"}`)); err != nil { + t.Logf("write failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClient(ctx, "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + key := datastore.NameKey("Test", "key", nil) + entity := &testEntity{Name: "test", Count: 1} + _, err = client.Put(ctx, key, entity) + + if err == nil { + t.Error("expected error for unexpected 2xx status") + } + + if !strings.Contains(err.Error(), "202") { + t.Errorf("expected error to mention 202 status, got: %v", err) + } +} + +func TestDoRequestWithReadBodyError(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Set content length but don't write enough data + w.Header().Set("Content-Length", "1000000") + w.WriteHeader(http.StatusOK) + // Write partial data then close connection + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClient(ctx, "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + key := datastore.NameKey("Test", "key", nil) + entity := &testEntity{Name: "test", Count: 1} + _, err = client.Put(ctx, key, entity) + // Should get an error related to response parsing + if err != nil { + t.Logf("Got expected error with incomplete response: %v", err) + } +} diff --git a/pkg/datastore/iterator.go b/pkg/datastore/iterator.go new file mode 100644 index 0000000..f33262b --- /dev/null +++ b/pkg/datastore/iterator.go @@ -0,0 +1,156 @@ +package datastore + +import ( + "context" + "encoding/json" + "errors" + "fmt" + neturl "net/url" + + "github.com/codeGROOVE-dev/ds9/auth" +) + +// Iterator is an iterator for query results. +// API compatible with cloud.google.com/go/datastore. +// +//nolint:govet // Field alignment optimized for API compatibility over memory layout +type Iterator struct { + ctx context.Context //nolint:containedctx // Required for API compatibility with cloud.google.com/go/datastore + client *Client + query *Query + results []iteratorResult + index int + err error + cursor Cursor + fetchNext bool +} + +type iteratorResult struct { + key *Key + entity map[string]any + cursor Cursor +} + +// Next advances the iterator and returns the next key and destination. +// It returns Done when no more results are available. +// API compatible with cloud.google.com/go/datastore. +func (it *Iterator) Next(dst any) (*Key, error) { + // Check if we need to fetch more results + if it.index >= len(it.results) { + if it.err != nil { + return nil, it.err + } + if !it.fetchNext { + return nil, ErrDone + } + + // Fetch next batch + if err := it.fetch(); err != nil { + it.err = err + return nil, err + } + + if len(it.results) == 0 { + return nil, ErrDone + } + } + + result := it.results[it.index] + it.index++ + + // Only update cursor if the result has one + if result.cursor != "" { + it.cursor = result.cursor + } + + // Decode entity into dst + if err := decodeEntity(result.entity, dst); err != nil { + return nil, err + } + + return result.key, nil +} + +// Cursor returns the cursor for the iterator's current position. +// API compatible with cloud.google.com/go/datastore. +func (it *Iterator) Cursor() (Cursor, error) { + if it.cursor == "" { + return "", errors.New("no cursor available") + } + return it.cursor, nil +} + +// fetch retrieves the next batch of results. +func (it *Iterator) fetch() error { + token, err := auth.AccessToken(it.ctx) + if err != nil { + return fmt.Errorf("failed to get access token: %w", err) + } + + // Build query with current cursor as start + q := *it.query + if it.cursor != "" { + q.startCursor = it.cursor + } + + queryObj := buildQueryMap(&q) + reqBody := map[string]any{"query": queryObj} + if it.client.databaseID != "" { + reqBody["databaseId"] = it.client.databaseID + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + // URL-encode project ID to prevent injection attacks + reqURL := fmt.Sprintf("%s/projects/%s:runQuery", it.client.baseURL, neturl.PathEscape(it.client.projectID)) + body, err := doRequest(it.ctx, it.client.logger, reqURL, jsonData, token, it.client.projectID, it.client.databaseID) + if err != nil { + return err + } + + var result struct { + Batch struct { + EntityResults []struct { + Entity map[string]any `json:"entity"` + Cursor string `json:"cursor"` + } `json:"entityResults"` + MoreResults string `json:"moreResults"` + EndCursor string `json:"endCursor"` + SkippedResults int `json:"skippedResults"` + } `json:"batch"` + } + + if err := json.Unmarshal(body, &result); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + // Convert results to iterator format + it.results = make([]iteratorResult, 0, len(result.Batch.EntityResults)) + for _, er := range result.Batch.EntityResults { + key, err := keyFromJSON(er.Entity["key"]) + if err != nil { + return err + } + + it.results = append(it.results, iteratorResult{ + key: key, + entity: er.Entity, + cursor: Cursor(er.Cursor), + }) + } + + it.index = 0 + + // Check if there are more results + moreResults := result.Batch.MoreResults + it.fetchNext = moreResults == "NOT_FINISHED" || moreResults == "MORE_RESULTS_AFTER_LIMIT" || moreResults == "MORE_RESULTS_AFTER_CURSOR" + + if result.Batch.EndCursor != "" { + it.cursor = Cursor(result.Batch.EndCursor) + } + + return nil +} diff --git a/pkg/datastore/iterator_coverage_test.go b/pkg/datastore/iterator_coverage_test.go new file mode 100644 index 0000000..f915dd4 --- /dev/null +++ b/pkg/datastore/iterator_coverage_test.go @@ -0,0 +1,227 @@ +package datastore + +import ( + "context" + "errors" + "testing" + "time" +) + +// Test Iterator.Cursor() when no cursor is available +func TestIteratorCursorNoCursor(t *testing.T) { + client, cleanup := NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Run query on empty kind + q := NewQuery("EmptyKind") + it := client.Run(ctx, q) + + // Try to get cursor before iterating + cursor, err := it.Cursor() + if err == nil { + t.Error("Expected error when no cursor available") + } + if cursor != "" { + t.Errorf("Expected empty cursor, got %s", cursor) + } +} + +// Test Iterator with multiple fetches (pagination) +func TestIteratorMultipleFetches(t *testing.T) { + client, cleanup := NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Create many entities to force multiple fetches + for i := range 25 { + key := IDKey("FetchTest", int64(i+1), nil) + entity := &struct { + Name string + Index int64 + Time time.Time + }{ + Name: "test", + Index: int64(i), + Time: time.Now().UTC().Truncate(time.Microsecond), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Query with small limit to trigger multiple fetches + query := NewQuery("FetchTest").Limit(10) + it := client.Run(ctx, query) + + count := 0 + for { + var entity struct { + Name string + Index int64 + Time time.Time + } + _, err := it.Next(&entity) + if errors.Is(err, ErrDone) { + break + } + if err != nil { + t.Fatalf("Iterator Next failed: %v", err) + } + count++ + } + + if count == 0 { + t.Error("Expected to iterate over some entities") + } +} + +// Test Iterator.Next with nil destination +func TestIteratorNextNilDst(t *testing.T) { + client, cleanup := NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Create test entity + key := NameKey("NilDstTest", "test", nil) + entity := &struct { + Name string + }{ + Name: "test", + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Query + q := NewQuery("NilDstTest") + it := client.Run(ctx, q) + + // Try to iterate with nil dst + _, err := it.Next(nil) + if err == nil { + t.Error("Expected error when dst is nil") + } +} + +// Test Iterator.Next with non-pointer destination +func TestIteratorNextNonPointerDst(t *testing.T) { + client, cleanup := NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Create test entity + key := NameKey("NonPtrTest", "test", nil) + entity := &struct { + Name string + }{ + Name: "test", + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Query + q := NewQuery("NonPtrTest") + it := client.Run(ctx, q) + + // Try to iterate with non-pointer dst + var dst struct { + Name string + } + _, err := it.Next(dst) // Pass by value instead of pointer + if err == nil { + t.Error("Expected error when dst is not a pointer") + } +} + +// Test fetch() error path via context cancellation +func TestIteratorFetchContextCancel(t *testing.T) { + client, cleanup := NewMockClient(t) + defer cleanup() + + // Create a cancelled context + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + // Create test entities first with valid context + validCtx := context.Background() + for i := range 5 { + key := IDKey("CancelTest", int64(i+1), nil) + entity := &struct { + Name string + }{ + Name: "test", + } + if _, err := client.Put(validCtx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Query with cancelled context + q := NewQuery("CancelTest") + it := client.Run(ctx, q) + + var dst struct { + Name string + } + _, err := it.Next(&dst) + // Should get error due to cancelled context + if err == nil { + t.Log("Expected error with cancelled context (mock may not respect context)") + } +} + +// Test iterator with keys-only query +func TestIteratorKeysOnly(t *testing.T) { + client, cleanup := NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Create test entities + for i := range 5 { + key := IDKey("KeysOnlyTest", int64(i+1), nil) + entity := &struct { + Name string + Data string + }{ + Name: "test", + Data: "lots of data here", + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Query keys only + q := NewQuery("KeysOnlyTest").KeysOnly() + it := client.Run(ctx, q) + + count := 0 + for { + var entity struct { + Name string + Data string + } + key, err := it.Next(&entity) + if err == ErrDone { + break + } + if err != nil { + t.Fatalf("Iterator Next failed: %v", err) + } + if key == nil { + t.Error("Expected non-nil key") + } + count++ + } + + if count == 0 { + t.Error("Expected to iterate over some keys") + } +} diff --git a/pkg/datastore/iterator_test.go b/pkg/datastore/iterator_test.go new file mode 100644 index 0000000..f77f179 --- /dev/null +++ b/pkg/datastore/iterator_test.go @@ -0,0 +1,96 @@ +package datastore_test + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/codeGROOVE-dev/ds9/pkg/datastore" +) + +func TestIterator(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("IterateAll", func(t *testing.T) { + // Create test entities + for i := range 5 { + key := datastore.IDKey("IterTest", int64(i+1), nil) + entity := &testEntity{ + Name: fmt.Sprintf("entity-%d", i), + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + q := datastore.NewQuery("IterTest") + it := client.Run(ctx, q) + + count := 0 + for { + var entity testEntity + key, err := it.Next(&entity) + if errors.Is(err, datastore.ErrDone) { + break + } + if err != nil { + t.Fatalf("Iterator.Next failed: %v", err) + } + if key == nil { + t.Errorf("Expected non-nil key") + } + count++ + } + + if count != 5 { + t.Errorf("Expected to iterate over 5 entities, got %d", count) + } + }) + + t.Run("IteratorCursor", func(t *testing.T) { + // Create test entities + for i := range 3 { + key := datastore.IDKey("CursorTest", int64(i+1), nil) + entity := &testEntity{ + Name: fmt.Sprintf("entity-%d", i), + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + q := datastore.NewQuery("CursorTest") + it := client.Run(ctx, q) + + var entity testEntity + _, err := it.Next(&entity) + if err != nil { + t.Fatalf("Iterator.Next failed: %v", err) + } + + // Get cursor after first entity + cursor, err := it.Cursor() + if err != nil { + t.Logf("Cursor not available: %v", err) + } else if cursor == "" { + t.Logf("Empty cursor returned") + } + }) + + t.Run("EmptyIterator", func(t *testing.T) { + q := datastore.NewQuery("NonExistent") + it := client.Run(ctx, q) + + var entity testEntity + _, err := it.Next(&entity) + if !errors.Is(err, datastore.ErrDone) { + t.Errorf("Expected datastore.ErrDone, got %v", err) + } + }) +} diff --git a/pkg/datastore/key.go b/pkg/datastore/key.go new file mode 100644 index 0000000..844aa35 --- /dev/null +++ b/pkg/datastore/key.go @@ -0,0 +1,225 @@ +package datastore + +import ( + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strconv" + "strings" +) + +// Key represents a Datastore key. +type Key struct { + Parent *Key // Parent key for hierarchical keys + Kind string + Name string // For string keys + ID int64 // For numeric keys +} + +// NameKey creates a new key with a string name. +// The parent parameter can be nil for top-level keys. +// This matches the API of cloud.google.com/go/datastore. +func NameKey(kind, name string, parent *Key) *Key { + return &Key{ + Kind: kind, + Name: name, + Parent: parent, + } +} + +// IDKey creates a new key with a numeric ID. +// The parent parameter can be nil for top-level keys. +// This matches the API of cloud.google.com/go/datastore. +func IDKey(kind string, id int64, parent *Key) *Key { + return &Key{ + Kind: kind, + ID: id, + Parent: parent, + } +} + +// IncompleteKey creates a new incomplete key. +// The key will be completed (assigned an ID) when the entity is saved. +// API compatible with cloud.google.com/go/datastore. +func IncompleteKey(kind string, parent *Key) *Key { + return &Key{ + Kind: kind, + Parent: parent, + } +} + +// Incomplete returns true if the key does not have an ID or Name. +// API compatible with cloud.google.com/go/datastore. +func (k *Key) Incomplete() bool { + return k.ID == 0 && k.Name == "" +} + +// Equal returns true if this key is equal to the other key. +// API compatible with cloud.google.com/go/datastore. +func (k *Key) Equal(other *Key) bool { + if k == nil && other == nil { + return true + } + if k == nil || other == nil { + return false + } + if k.Kind != other.Kind || k.Name != other.Name || k.ID != other.ID { + return false + } + // Recursively check parent keys + return k.Parent.Equal(other.Parent) +} + +// String returns a human-readable string representation of the key. +// API compatible with cloud.google.com/go/datastore. +func (k *Key) String() string { + if k == nil { + return "" + } + + var parts []string + for curr := k; curr != nil; curr = curr.Parent { + var part string + switch { + case curr.Name != "": + part = fmt.Sprintf("%s,%q", curr.Kind, curr.Name) + case curr.ID != 0: + part = fmt.Sprintf("%s,%d", curr.Kind, curr.ID) + default: + part = fmt.Sprintf("%s,incomplete", curr.Kind) + } + // Prepend to maintain correct order (root to leaf) + parts = append([]string{part}, parts...) + } + + return "/" + strings.Join(parts, "/") +} + +// Encode returns an opaque representation of the key. +// API compatible with cloud.google.com/go/datastore. +func (k *Key) Encode() string { + if k == nil { + return "" + } + + // Convert key to JSON representation + keyJSON := keyToJSON(k) + + // Marshal to JSON bytes + jsonBytes, err := json.Marshal(keyJSON) + if err != nil { + return "" + } + + // Base64 encode + return base64.URLEncoding.EncodeToString(jsonBytes) +} + +// DecodeKey decodes a key from its opaque representation. +// API compatible with cloud.google.com/go/datastore. +func DecodeKey(encoded string) (*Key, error) { + if encoded == "" { + return nil, errors.New("empty encoded key") + } + + // Base64 decode + jsonBytes, err := base64.URLEncoding.DecodeString(encoded) + if err != nil { + return nil, fmt.Errorf("failed to decode base64: %w", err) + } + + // Unmarshal JSON + var keyData any + if err := json.Unmarshal(jsonBytes, &keyData); err != nil { + return nil, fmt.Errorf("failed to unmarshal JSON: %w", err) + } + + // Convert from JSON representation + key, err := keyFromJSON(keyData) + if err != nil { + return nil, fmt.Errorf("failed to parse key: %w", err) + } + + return key, nil +} + +// keyToJSON converts a Key to its JSON representation. +// Supports hierarchical keys with parent relationships. +func keyToJSON(key *Key) map[string]any { + // Build path from root to leaf (parent -> child) + var path []map[string]any + + // Collect all keys from root to leaf + keys := make([]*Key, 0) + for k := key; k != nil; k = k.Parent { + keys = append(keys, k) + } + + // Reverse to go from root to leaf + for i := len(keys) - 1; i >= 0; i-- { + k := keys[i] + elem := map[string]any{ + "kind": k.Kind, + } + + if k.Name != "" { + elem["name"] = k.Name + } else if k.ID != 0 { + elem["id"] = strconv.FormatInt(k.ID, 10) + } + + path = append(path, elem) + } + + return map[string]any{ + "path": path, + } +} + +// keyFromJSON converts a JSON key representation to a Key. +func keyFromJSON(keyData any) (*Key, error) { + keyMap, ok := keyData.(map[string]any) + if !ok { + return nil, errors.New("invalid key format") + } + + path, ok := keyMap["path"].([]any) + if !ok || len(path) == 0 { + return nil, errors.New("invalid key path") + } + + // Build key hierarchy from path elements + var key *Key + for _, elem := range path { + elemMap, ok := elem.(map[string]any) + if !ok { + return nil, errors.New("invalid path element") + } + + newKey := &Key{ + Parent: key, + } + + if kind, ok := elemMap["kind"].(string); ok { + newKey.Kind = kind + } + + if name, ok := elemMap["name"].(string); ok { + newKey.Name = name + } else if idVal, exists := elemMap["id"]; exists { + switch id := idVal.(type) { + case string: + if _, err := fmt.Sscanf(id, "%d", &newKey.ID); err != nil { + return nil, fmt.Errorf("invalid ID format: %w", err) + } + case float64: + newKey.ID = int64(id) + } + } + + key = newKey + } + + return key, nil +} diff --git a/key_test.go b/pkg/datastore/key_test.go similarity index 99% rename from key_test.go rename to pkg/datastore/key_test.go index 0e7b25c..11f9358 100644 --- a/key_test.go +++ b/pkg/datastore/key_test.go @@ -1,4 +1,4 @@ -package ds9 +package datastore import ( "testing" diff --git a/pkg/datastore/mock_client.go b/pkg/datastore/mock_client.go new file mode 100644 index 0000000..f728318 --- /dev/null +++ b/pkg/datastore/mock_client.go @@ -0,0 +1,38 @@ +package datastore + +import ( + "context" + "testing" + + "github.com/codeGROOVE-dev/ds9/pkg/mock" +) + +// NewMockClient creates a datastore client connected to mock servers with in-memory storage. +// This is a convenience wrapper that avoids import cycles when writing tests in package datastore. +// Returns the client and a cleanup function that should be deferred. +func NewMockClient(t *testing.T) (client *Client, cleanup func()) { + t.Helper() + + // Create mock servers + metadataURL, apiURL, cleanup := mock.NewMockServers(t) + + // Set test URLs + restore := SetTestURLs(metadataURL, apiURL) + + // Create client + ctx := context.Background() + var err error + client, err = NewClient(ctx, "test-project") + if err != nil { + t.Fatalf("failed to create mock client: %v", err) + } + + // Wrap cleanup to restore URLs + originalCleanup := cleanup + cleanup = func() { + restore() + originalCleanup() + } + + return client, cleanup +} diff --git a/pkg/datastore/mutation.go b/pkg/datastore/mutation.go new file mode 100644 index 0000000..6fb7874 --- /dev/null +++ b/pkg/datastore/mutation.go @@ -0,0 +1,198 @@ +package datastore + +import ( + "context" + "encoding/json" + "fmt" + neturl "net/url" + + "github.com/codeGROOVE-dev/ds9/auth" +) + +// MutationOp represents the type of mutation operation. +type MutationOp string + +const ( + // MutationInsert represents an insert operation. + MutationInsert MutationOp = "insert" + // MutationUpdate represents an update operation. + MutationUpdate MutationOp = "update" + // MutationUpsert represents an upsert operation. + MutationUpsert MutationOp = "upsert" + // MutationDelete represents a delete operation. + MutationDelete MutationOp = "delete" +) + +// Mutation represents a pending datastore mutation. +type Mutation struct { + op MutationOp + key *Key + entity any +} + +// NewInsert creates an insert mutation. +// API compatible with cloud.google.com/go/datastore. +func NewInsert(k *Key, src any) *Mutation { + return &Mutation{ + op: MutationInsert, + key: k, + entity: src, + } +} + +// NewUpdate creates an update mutation. +// API compatible with cloud.google.com/go/datastore. +func NewUpdate(k *Key, src any) *Mutation { + return &Mutation{ + op: MutationUpdate, + key: k, + entity: src, + } +} + +// NewUpsert creates an upsert mutation. +// API compatible with cloud.google.com/go/datastore. +func NewUpsert(k *Key, src any) *Mutation { + return &Mutation{ + op: MutationUpsert, + key: k, + entity: src, + } +} + +// NewDelete creates a delete mutation. +// API compatible with cloud.google.com/go/datastore. +func NewDelete(k *Key) *Mutation { + return &Mutation{ + op: MutationDelete, + key: k, + } +} + +// Mutate applies one or more mutations atomically. +// API compatible with cloud.google.com/go/datastore. +func (c *Client) Mutate(ctx context.Context, muts ...*Mutation) ([]*Key, error) { + if len(muts) == 0 { + return nil, nil + } + + c.logger.DebugContext(ctx, "applying mutations", "count", len(muts)) + + token, err := auth.AccessToken(ctx) + if err != nil { + c.logger.ErrorContext(ctx, "failed to get access token", "error", err) + return nil, fmt.Errorf("failed to get access token: %w", err) + } + + // Build mutations array + mutations := make([]map[string]any, 0, len(muts)) + for i, mut := range muts { + if mut == nil { + c.logger.ErrorContext(ctx, "nil mutation", "index", i) + return nil, fmt.Errorf("mutation at index %d is nil", i) + } + if mut.key == nil { + c.logger.ErrorContext(ctx, "nil key in mutation", "index", i) + return nil, fmt.Errorf("mutation at index %d has nil key", i) + } + + mutMap := make(map[string]any) + + switch mut.op { + case MutationInsert: + if mut.entity == nil { + c.logger.ErrorContext(ctx, "nil entity for insert", "index", i) + return nil, fmt.Errorf("insert mutation at index %d has nil entity", i) + } + entity, err := encodeEntity(mut.key, mut.entity) + if err != nil { + c.logger.ErrorContext(ctx, "failed to encode entity", "index", i, "error", err) + return nil, fmt.Errorf("failed to encode entity at index %d: %w", i, err) + } + mutMap["insert"] = entity + + case MutationUpdate: + if mut.entity == nil { + c.logger.ErrorContext(ctx, "nil entity for update", "index", i) + return nil, fmt.Errorf("update mutation at index %d has nil entity", i) + } + entity, err := encodeEntity(mut.key, mut.entity) + if err != nil { + c.logger.ErrorContext(ctx, "failed to encode entity", "index", i, "error", err) + return nil, fmt.Errorf("failed to encode entity at index %d: %w", i, err) + } + mutMap["update"] = entity + + case MutationUpsert: + if mut.entity == nil { + c.logger.ErrorContext(ctx, "nil entity for upsert", "index", i) + return nil, fmt.Errorf("upsert mutation at index %d has nil entity", i) + } + entity, err := encodeEntity(mut.key, mut.entity) + if err != nil { + c.logger.ErrorContext(ctx, "failed to encode entity", "index", i, "error", err) + return nil, fmt.Errorf("failed to encode entity at index %d: %w", i, err) + } + mutMap["upsert"] = entity + + case MutationDelete: + mutMap["delete"] = keyToJSON(mut.key) + + default: + c.logger.ErrorContext(ctx, "unknown mutation operation", "index", i, "op", mut.op) + return nil, fmt.Errorf("unknown mutation operation at index %d: %s", i, mut.op) + } + + mutations = append(mutations, mutMap) + } + + reqBody := map[string]any{ + "mutations": mutations, + } + if c.databaseID != "" { + reqBody["databaseId"] = c.databaseID + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + // URL-encode project ID to prevent injection attacks + reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID)) + body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) + if err != nil { + c.logger.ErrorContext(ctx, "mutate request failed", "error", err) + return nil, err + } + + var resp struct { + MutationResults []struct { + Key map[string]any `json:"key"` + } `json:"mutationResults"` + } + if err := json.Unmarshal(body, &resp); err != nil { + c.logger.ErrorContext(ctx, "failed to parse response", "error", err) + return nil, fmt.Errorf("failed to parse mutate response: %w", err) + } + + // Extract resulting keys + keys := make([]*Key, len(resp.MutationResults)) + for i, result := range resp.MutationResults { + if result.Key != nil { + key, err := keyFromJSON(result.Key) + if err != nil { + c.logger.ErrorContext(ctx, "failed to parse key", "index", i, "error", err) + return nil, fmt.Errorf("failed to parse key at index %d: %w", i, err) + } + keys[i] = key + } else { + // For deletes, use the original key + keys[i] = muts[i].key + } + } + + c.logger.DebugContext(ctx, "mutations applied successfully", "count", len(keys)) + return keys, nil +} diff --git a/pkg/datastore/mutation_test.go b/pkg/datastore/mutation_test.go new file mode 100644 index 0000000..addd6c0 --- /dev/null +++ b/pkg/datastore/mutation_test.go @@ -0,0 +1,180 @@ +package datastore_test + +import ( + "context" + "errors" + "testing" + + "github.com/codeGROOVE-dev/ds9/pkg/datastore" +) + +func TestMutate(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("MutateInsert", func(t *testing.T) { + key := datastore.NameKey("MutateTest", "insert", nil) + entity := &testEntity{ + Name: "inserted", + Count: 42, + } + + mut := datastore.NewInsert(key, entity) + keys, err := client.Mutate(ctx, mut) + if err != nil { + t.Fatalf("Mutate insert failed: %v", err) + } + + if len(keys) != 1 { + t.Errorf("Expected 1 key, got %d", len(keys)) + } + + // Verify entity was created + var result testEntity + if err := client.Get(ctx, key, &result); err != nil { + t.Fatalf("Get after insert failed: %v", err) + } + if result.Name != "inserted" { + t.Errorf("Expected Name 'inserted', got '%s'", result.Name) + } + }) + + t.Run("MutateUpdate", func(t *testing.T) { + key := datastore.NameKey("MutateTest", "update", nil) + entity := &testEntity{Name: "original", Count: 1} + + // Create entity first + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Update via mutation + updated := &testEntity{Name: "updated", Count: 2} + mut := datastore.NewUpdate(key, updated) + _, err := client.Mutate(ctx, mut) + if err != nil { + t.Fatalf("Mutate update failed: %v", err) + } + + // Verify entity was updated + var result testEntity + if err := client.Get(ctx, key, &result); err != nil { + t.Fatalf("Get after update failed: %v", err) + } + if result.Name != "updated" { + t.Errorf("Expected Name 'updated', got '%s'", result.Name) + } + }) + + t.Run("MutateUpsert", func(t *testing.T) { + key := datastore.NameKey("MutateTest", "upsert", nil) + entity := &testEntity{Name: "upserted", Count: 100} + + mut := datastore.NewUpsert(key, entity) + keys, err := client.Mutate(ctx, mut) + if err != nil { + t.Fatalf("Mutate upsert failed: %v", err) + } + + if len(keys) != 1 { + t.Errorf("Expected 1 key, got %d", len(keys)) + } + + // Verify entity exists + var result testEntity + if err := client.Get(ctx, key, &result); err != nil { + t.Fatalf("Get after upsert failed: %v", err) + } + if result.Name != "upserted" { + t.Errorf("Expected Name 'upserted', got '%s'", result.Name) + } + }) + + t.Run("MutateDelete", func(t *testing.T) { + key := datastore.NameKey("MutateTest", "delete", nil) + entity := &testEntity{Name: "to-delete", Count: 1} + + // Create entity first + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Delete via mutation + mut := datastore.NewDelete(key) + keys, err := client.Mutate(ctx, mut) + if err != nil { + t.Fatalf("Mutate delete failed: %v", err) + } + + if len(keys) != 1 { + t.Errorf("Expected 1 key, got %d", len(keys)) + } + + // Verify entity was deleted + var result testEntity + err = client.Get(ctx, key, &result) + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("Expected datastore.ErrNoSuchEntity after delete, got %v", err) + } + }) + + t.Run("MutateMultiple", func(t *testing.T) { + key1 := datastore.NameKey("MutateTest", "multi1", nil) + key2 := datastore.NameKey("MutateTest", "multi2", nil) + key3 := datastore.NameKey("MutateTest", "multi3", nil) + + entity1 := &testEntity{Name: "first", Count: 1} + entity2 := &testEntity{Name: "second", Count: 2} + entity3 := &testEntity{Name: "third", Count: 3} + + // Pre-create entity3 for update + if _, err := client.Put(ctx, key3, &testEntity{Name: "old", Count: 0}); err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Apply multiple mutations + muts := []*datastore.Mutation{ + datastore.NewInsert(key1, entity1), + datastore.NewUpsert(key2, entity2), + datastore.NewUpdate(key3, entity3), + } + + keys, err := client.Mutate(ctx, muts...) + if err != nil { + t.Fatalf("Mutate multiple failed: %v", err) + } + + if len(keys) != 3 { + t.Errorf("Expected 3 keys, got %d", len(keys)) + } + + // Verify all mutations applied + var result1, result2, result3 testEntity + if err := client.Get(ctx, key1, &result1); err != nil { + t.Errorf("Get key1 failed: %v", err) + } + if err := client.Get(ctx, key2, &result2); err != nil { + t.Errorf("Get key2 failed: %v", err) + } + if err := client.Get(ctx, key3, &result3); err != nil { + t.Errorf("Get key3 failed: %v", err) + } + + if result1.Name != "first" || result2.Name != "second" || result3.Name != "third" { + t.Errorf("Mutation results incorrect") + } + }) + + t.Run("MutateEmpty", func(t *testing.T) { + keys, err := client.Mutate(ctx) + if err != nil { + t.Fatalf("Mutate with no mutations failed: %v", err) + } + + if len(keys) != 0 { + t.Errorf("Expected empty keys, got %d", len(keys)) + } + }) +} diff --git a/pkg/datastore/operations.go b/pkg/datastore/operations.go new file mode 100644 index 0000000..d8eae0a --- /dev/null +++ b/pkg/datastore/operations.go @@ -0,0 +1,498 @@ +package datastore + +import ( + "context" + "encoding/json" + "errors" + "fmt" + neturl "net/url" + "reflect" + + "github.com/codeGROOVE-dev/ds9/auth" +) + +// Get retrieves an entity by key and stores it in dst. +// dst must be a pointer to a struct. +// Returns ErrNoSuchEntity if the key is not found. +func (c *Client) Get(ctx context.Context, key *Key, dst any) error { + if key == nil { + c.logger.WarnContext(ctx, "Get called with nil key") + return errors.New("key cannot be nil") + } + + c.logger.DebugContext(ctx, "getting entity", "kind", key.Kind, "name", key.Name, "id", key.ID) + + token, err := auth.AccessToken(ctx) + if err != nil { + c.logger.ErrorContext(ctx, "failed to get access token", "error", err) + return fmt.Errorf("failed to get access token: %w", err) + } + + reqBody := map[string]any{ + "keys": []map[string]any{keyToJSON(key)}, + } + if c.databaseID != "" { + reqBody["databaseId"] = c.databaseID + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) + return fmt.Errorf("failed to marshal request: %w", err) + } + + // URL-encode project ID to prevent injection attacks + reqURL := fmt.Sprintf("%s/projects/%s:lookup", c.baseURL, neturl.PathEscape(c.projectID)) + body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) + if err != nil { + c.logger.ErrorContext(ctx, "lookup request failed", "error", err, "kind", key.Kind) + return err + } + + var result struct { + Found []struct { + Entity map[string]any `json:"entity"` + } `json:"found"` + } + + if err := json.Unmarshal(body, &result); err != nil { + c.logger.ErrorContext(ctx, "failed to parse response", "error", err) + return fmt.Errorf("failed to parse response: %w", err) + } + + if len(result.Found) == 0 { + c.logger.DebugContext(ctx, "entity not found", "kind", key.Kind, "name", key.Name, "id", key.ID) + return ErrNoSuchEntity + } + + c.logger.DebugContext(ctx, "entity retrieved successfully", "kind", key.Kind) + return decodeEntity(result.Found[0].Entity, dst) +} + +// Put stores an entity with the given key. +// src must be a struct or pointer to struct. +// Returns the key (useful for auto-generated IDs in the future). +func (c *Client) Put(ctx context.Context, key *Key, src any) (*Key, error) { + if key == nil { + c.logger.WarnContext(ctx, "Put called with nil key") + return nil, errors.New("key cannot be nil") + } + + c.logger.DebugContext(ctx, "putting entity", "kind", key.Kind, "name", key.Name, "id", key.ID) + + token, err := auth.AccessToken(ctx) + if err != nil { + c.logger.ErrorContext(ctx, "failed to get access token", "error", err) + return nil, fmt.Errorf("failed to get access token: %w", err) + } + + entity, err := encodeEntity(key, src) + if err != nil { + c.logger.ErrorContext(ctx, "failed to encode entity", "error", err, "kind", key.Kind) + return nil, err + } + + reqBody := map[string]any{ + "mode": "NON_TRANSACTIONAL", + "mutations": []map[string]any{{"upsert": entity}}, + } + if c.databaseID != "" { + reqBody["databaseId"] = c.databaseID + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + // URL-encode project ID to prevent injection attacks + reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID)) + if _, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID); err != nil { + c.logger.ErrorContext(ctx, "commit request failed", "error", err, "kind", key.Kind) + return nil, err + } + + c.logger.DebugContext(ctx, "entity stored successfully", "kind", key.Kind) + return key, nil +} + +// Delete deletes the entity with the given key. +func (c *Client) Delete(ctx context.Context, key *Key) error { + if key == nil { + c.logger.WarnContext(ctx, "Delete called with nil key") + return errors.New("key cannot be nil") + } + + c.logger.DebugContext(ctx, "deleting entity", "kind", key.Kind, "name", key.Name, "id", key.ID) + + token, err := auth.AccessToken(ctx) + if err != nil { + c.logger.ErrorContext(ctx, "failed to get access token", "error", err) + return fmt.Errorf("failed to get access token: %w", err) + } + + reqBody := map[string]any{ + "mode": "NON_TRANSACTIONAL", + "mutations": []map[string]any{{"delete": keyToJSON(key)}}, + } + if c.databaseID != "" { + reqBody["databaseId"] = c.databaseID + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) + return fmt.Errorf("failed to marshal request: %w", err) + } + + // URL-encode project ID to prevent injection attacks + reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID)) + if _, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID); err != nil { + c.logger.ErrorContext(ctx, "delete request failed", "error", err, "kind", key.Kind) + return err + } + + c.logger.DebugContext(ctx, "entity deleted successfully", "kind", key.Kind) + return nil +} + +// GetMulti retrieves multiple entities by their keys. +// dst must be a pointer to a slice of structs. +// Returns ErrNoSuchEntity if any key is not found. +// This matches the API of cloud.google.com/go/datastore. +func (c *Client) GetMulti(ctx context.Context, keys []*Key, dst any) error { + if len(keys) == 0 { + c.logger.WarnContext(ctx, "GetMulti called with no keys") + return errors.New("keys cannot be empty") + } + + c.logger.DebugContext(ctx, "getting multiple entities", "count", len(keys)) + + token, err := auth.AccessToken(ctx) + if err != nil { + c.logger.ErrorContext(ctx, "failed to get access token", "error", err) + return fmt.Errorf("failed to get access token: %w", err) + } + + // Build keys array + jsonKeys := make([]map[string]any, len(keys)) + for i, key := range keys { + if key == nil { + c.logger.WarnContext(ctx, "GetMulti called with nil key", "index", i) + return fmt.Errorf("key at index %d cannot be nil", i) + } + jsonKeys[i] = keyToJSON(key) + } + + reqBody := map[string]any{ + "keys": jsonKeys, + } + if c.databaseID != "" { + reqBody["databaseId"] = c.databaseID + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) + return fmt.Errorf("failed to marshal request: %w", err) + } + + // URL-encode project ID to prevent injection attacks + reqURL := fmt.Sprintf("%s/projects/%s:lookup", c.baseURL, neturl.PathEscape(c.projectID)) + body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) + if err != nil { + c.logger.ErrorContext(ctx, "lookup request failed", "error", err) + return err + } + + var result struct { + Found []struct { + Entity map[string]any `json:"entity"` + } `json:"found"` + Missing []struct { + Entity map[string]any `json:"entity"` + } `json:"missing"` + } + + if err := json.Unmarshal(body, &result); err != nil { + c.logger.ErrorContext(ctx, "failed to parse response", "error", err) + return fmt.Errorf("failed to parse response: %w", err) + } + + if len(result.Missing) > 0 { + c.logger.DebugContext(ctx, "some entities not found", "missing_count", len(result.Missing)) + return ErrNoSuchEntity + } + + // Decode into slice + v := reflect.ValueOf(dst) + if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Slice { + return errors.New("dst must be a pointer to slice") + } + + sliceType := v.Elem().Type() + elemType := sliceType.Elem() + + // Create new slice of correct size + slice := reflect.MakeSlice(sliceType, 0, len(result.Found)) + + for _, found := range result.Found { + elem := reflect.New(elemType).Elem() + if err := decodeEntity(found.Entity, elem.Addr().Interface()); err != nil { + c.logger.ErrorContext(ctx, "failed to decode entity", "error", err) + return err + } + slice = reflect.Append(slice, elem) + } + + v.Elem().Set(slice) + c.logger.DebugContext(ctx, "entities retrieved successfully", "count", len(result.Found)) + return nil +} + +// PutMulti stores multiple entities with their keys. +// keys and src must have the same length. +// Returns the keys (same as input) and any error. +// This matches the API of cloud.google.com/go/datastore. +func (c *Client) PutMulti(ctx context.Context, keys []*Key, src any) ([]*Key, error) { + if len(keys) == 0 { + c.logger.WarnContext(ctx, "PutMulti called with no keys") + return nil, errors.New("keys cannot be empty") + } + + c.logger.DebugContext(ctx, "putting multiple entities", "count", len(keys)) + + // Verify src is a slice + v := reflect.ValueOf(src) + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() != reflect.Slice { + return nil, errors.New("src must be a slice") + } + + if v.Len() != len(keys) { + return nil, fmt.Errorf("keys and src length mismatch: %d != %d", len(keys), v.Len()) + } + + token, err := auth.AccessToken(ctx) + if err != nil { + c.logger.ErrorContext(ctx, "failed to get access token", "error", err) + return nil, fmt.Errorf("failed to get access token: %w", err) + } + + // Build mutations + mutations := make([]map[string]any, len(keys)) + for i, key := range keys { + if key == nil { + c.logger.WarnContext(ctx, "PutMulti called with nil key", "index", i) + return nil, fmt.Errorf("key at index %d cannot be nil", i) + } + + entity, err := encodeEntity(key, v.Index(i).Interface()) + if err != nil { + c.logger.ErrorContext(ctx, "failed to encode entity", "error", err, "index", i) + return nil, fmt.Errorf("failed to encode entity at index %d: %w", i, err) + } + + mutations[i] = map[string]any{ + "upsert": entity, + } + } + + reqBody := map[string]any{ + "mode": "NON_TRANSACTIONAL", + "mutations": mutations, + } + if c.databaseID != "" { + reqBody["databaseId"] = c.databaseID + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + // URL-encode project ID to prevent injection attacks + reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID)) + if _, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID); err != nil { + c.logger.ErrorContext(ctx, "commit request failed", "error", err) + return nil, err + } + + c.logger.DebugContext(ctx, "entities stored successfully", "count", len(keys)) + return keys, nil +} + +// DeleteMulti deletes multiple entities with their keys. +// This matches the API of cloud.google.com/go/datastore. +func (c *Client) DeleteMulti(ctx context.Context, keys []*Key) error { + if len(keys) == 0 { + c.logger.WarnContext(ctx, "DeleteMulti called with no keys") + return errors.New("keys cannot be empty") + } + + c.logger.DebugContext(ctx, "deleting multiple entities", "count", len(keys)) + + token, err := auth.AccessToken(ctx) + if err != nil { + c.logger.ErrorContext(ctx, "failed to get access token", "error", err) + return fmt.Errorf("failed to get access token: %w", err) + } + + // Build mutations + mutations := make([]map[string]any, len(keys)) + for i, key := range keys { + if key == nil { + c.logger.WarnContext(ctx, "DeleteMulti called with nil key", "index", i) + return fmt.Errorf("key at index %d cannot be nil", i) + } + + mutations[i] = map[string]any{ + "delete": keyToJSON(key), + } + } + + reqBody := map[string]any{ + "mode": "NON_TRANSACTIONAL", + "mutations": mutations, + } + if c.databaseID != "" { + reqBody["databaseId"] = c.databaseID + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) + return fmt.Errorf("failed to marshal request: %w", err) + } + + // URL-encode project ID to prevent injection attacks + reqURL := fmt.Sprintf("%s/projects/%s:commit", c.baseURL, neturl.PathEscape(c.projectID)) + if _, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID); err != nil { + c.logger.ErrorContext(ctx, "delete request failed", "error", err) + return err + } + + c.logger.DebugContext(ctx, "entities deleted successfully", "count", len(keys)) + return nil +} + +// DeleteAllByKind deletes all entities of a given kind. +// This method queries for all keys and then deletes them in batches. +func (c *Client) DeleteAllByKind(ctx context.Context, kind string) error { + c.logger.InfoContext(ctx, "deleting all entities by kind", "kind", kind) + + // Query for all keys of this kind + q := NewQuery(kind).KeysOnly() + keys, err := c.AllKeys(ctx, q) + if err != nil { + c.logger.ErrorContext(ctx, "failed to query keys", "kind", kind, "error", err) + return fmt.Errorf("failed to query keys: %w", err) + } + + if len(keys) == 0 { + c.logger.InfoContext(ctx, "no entities found to delete", "kind", kind) + return nil + } + + // Delete all keys + if err := c.DeleteMulti(ctx, keys); err != nil { + c.logger.ErrorContext(ctx, "failed to delete entities", "kind", kind, "count", len(keys), "error", err) + return fmt.Errorf("failed to delete entities: %w", err) + } + + c.logger.InfoContext(ctx, "deleted all entities", "kind", kind, "count", len(keys)) + return nil +} + +// AllocateIDs allocates IDs for incomplete keys. +// Returns keys with IDs filled in. Complete keys are returned unchanged. +// API compatible with cloud.google.com/go/datastore. +func (c *Client) AllocateIDs(ctx context.Context, keys []*Key) ([]*Key, error) { + if len(keys) == 0 { + return keys, nil + } + + c.logger.DebugContext(ctx, "allocating IDs", "count", len(keys)) + + // Separate incomplete and complete keys + var incompleteKeys []*Key + var incompleteIndices []int + for i, key := range keys { + if key != nil && key.Incomplete() { + incompleteKeys = append(incompleteKeys, key) + incompleteIndices = append(incompleteIndices, i) + } + } + + // If no incomplete keys, return original slice + if len(incompleteKeys) == 0 { + c.logger.DebugContext(ctx, "no incomplete keys to allocate") + return keys, nil + } + + token, err := auth.AccessToken(ctx) + if err != nil { + c.logger.ErrorContext(ctx, "failed to get access token", "error", err) + return nil, fmt.Errorf("failed to get access token: %w", err) + } + + // Build request with incomplete keys + reqKeys := make([]map[string]any, len(incompleteKeys)) + for i, key := range incompleteKeys { + reqKeys[i] = keyToJSON(key) + } + + reqBody := map[string]any{ + "keys": reqKeys, + } + if c.databaseID != "" { + reqBody["databaseId"] = c.databaseID + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + // URL-encode project ID to prevent injection attacks + reqURL := fmt.Sprintf("%s/projects/%s:allocateIds", c.baseURL, neturl.PathEscape(c.projectID)) + body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) + if err != nil { + c.logger.ErrorContext(ctx, "allocateIds request failed", "error", err) + return nil, err + } + + var resp struct { + Keys []map[string]any `json:"keys"` + } + if err := json.Unmarshal(body, &resp); err != nil { + c.logger.ErrorContext(ctx, "failed to parse response", "error", err) + return nil, fmt.Errorf("failed to parse allocateIds response: %w", err) + } + + // Parse allocated keys + allocatedKeys := make([]*Key, len(resp.Keys)) + for i, keyData := range resp.Keys { + key, err := keyFromJSON(keyData) + if err != nil { + c.logger.ErrorContext(ctx, "failed to parse allocated key", "index", i, "error", err) + return nil, fmt.Errorf("failed to parse allocated key at index %d: %w", i, err) + } + allocatedKeys[i] = key + } + + // Create result slice with allocated keys in correct positions + result := make([]*Key, len(keys)) + copy(result, keys) + for i, idx := range incompleteIndices { + result[idx] = allocatedKeys[i] + } + + c.logger.DebugContext(ctx, "IDs allocated successfully", "count", len(allocatedKeys)) + return result, nil +} diff --git a/pkg/datastore/operations_coverage_test.go b/pkg/datastore/operations_coverage_test.go new file mode 100644 index 0000000..b2929b1 --- /dev/null +++ b/pkg/datastore/operations_coverage_test.go @@ -0,0 +1,354 @@ +package datastore_test + +import ( + "context" + "errors" + "testing" + + "github.com/codeGROOVE-dev/ds9/pkg/datastore" +) + +func TestGet_CoverageEdgeCases(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("GetWithStructPointer", func(t *testing.T) { + // Create entity + key := datastore.NameKey("GetTest", "test1", nil) + entity := &testEntity{ + Name: "test", + Count: 42, + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Get with pointer to struct + var retrieved testEntity + err := client.Get(ctx, key, &retrieved) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + + if retrieved.Name != "test" { + t.Errorf("Expected name 'test', got '%s'", retrieved.Name) + } + }) +} + +func TestPut_CoverageEdgeCases(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("PutWithCompleteKey", func(t *testing.T) { + key := datastore.NameKey("PutTest", "complete", nil) + entity := &testEntity{ + Name: "complete", + Count: 1, + } + + resultKey, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + if resultKey.Name != "complete" { + t.Errorf("Expected key name 'complete', got '%s'", resultKey.Name) + } + }) + + t.Run("PutWithIDKey", func(t *testing.T) { + key := datastore.IDKey("PutTest", 12345, nil) + entity := &testEntity{ + Name: "id-key", + Count: 1, + } + + resultKey, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + if resultKey.ID != 12345 { + t.Errorf("Expected key ID 12345, got %d", resultKey.ID) + } + }) + + t.Run("PutOverwrite", func(t *testing.T) { + // Put entity + key := datastore.NameKey("PutTest", "overwrite", nil) + entity := &testEntity{ + Name: "original", + Count: 1, + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("First Put failed: %v", err) + } + + // Overwrite with new data + entity.Name = "updated" + entity.Count = 2 + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Second Put failed: %v", err) + } + + // Verify updated + var retrieved testEntity + if err := client.Get(ctx, key, &retrieved); err != nil { + t.Fatalf("Get failed: %v", err) + } + + if retrieved.Name != "updated" { + t.Errorf("Expected name 'updated', got '%s'", retrieved.Name) + } + if retrieved.Count != 2 { + t.Errorf("Expected count 2, got %d", retrieved.Count) + } + }) +} + +func TestDelete_CoverageEdgeCases(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("DeleteExisting", func(t *testing.T) { + // Create entity + key := datastore.NameKey("DeleteTest", "existing", nil) + entity := &testEntity{Name: "test", Count: 1} + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Delete it + err := client.Delete(ctx, key) + if err != nil { + t.Fatalf("Delete failed: %v", err) + } + + // Verify deleted + var retrieved testEntity + err = client.Get(ctx, key, &retrieved) + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("Expected ErrNoSuchEntity, got %v", err) + } + }) + + t.Run("DeleteNonExistent", func(t *testing.T) { + // Delete non-existent key (should not error) + key := datastore.NameKey("DeleteTest", "nonexistent", nil) + err := client.Delete(ctx, key) + if err != nil { + t.Logf("Delete of non-existent key returned: %v", err) + } + }) +} + +func TestAllocateIDs_EdgeCases(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("AllocateSingleID", func(t *testing.T) { + keys := []*datastore.Key{ + datastore.IncompleteKey("AllocTest", nil), + } + + allocated, err := client.AllocateIDs(ctx, keys) + if err != nil { + t.Fatalf("AllocateIDs failed: %v", err) + } + + if len(allocated) != 1 { + t.Errorf("Expected 1 allocated key, got %d", len(allocated)) + } + + if allocated[0].Incomplete() { + t.Error("Allocated key is still incomplete") + } + }) + + t.Run("AllocateMultipleIDs", func(t *testing.T) { + keys := []*datastore.Key{ + datastore.IncompleteKey("AllocTest", nil), + datastore.IncompleteKey("AllocTest", nil), + datastore.IncompleteKey("AllocTest", nil), + } + + allocated, err := client.AllocateIDs(ctx, keys) + if err != nil { + t.Fatalf("AllocateIDs failed: %v", err) + } + + if len(allocated) != 3 { + t.Errorf("Expected 3 allocated keys, got %d", len(allocated)) + } + + // Verify all are complete (mock may not guarantee unique IDs) + for i, key := range allocated { + if key.Incomplete() { + t.Errorf("Key %d is still incomplete", i) + } + } + }) + + t.Run("AllocateWithParentKey", func(t *testing.T) { + parent := datastore.NameKey("Parent", "parent1", nil) + keys := []*datastore.Key{ + datastore.IncompleteKey("Child", parent), + } + + allocated, err := client.AllocateIDs(ctx, keys) + if err != nil { + t.Fatalf("AllocateIDs with parent failed: %v", err) + } + + if len(allocated) != 1 { + t.Errorf("Expected 1 allocated key, got %d", len(allocated)) + } + + if allocated[0].Incomplete() { + t.Error("Allocated key is still incomplete") + } + + if allocated[0].Parent == nil { + t.Error("Parent key was lost during allocation") + } + }) +} + +func TestGetMulti_EdgeCases(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("GetMultiAllExist", func(t *testing.T) { + // Create multiple entities + keys := []*datastore.Key{ + datastore.NameKey("MultiTest", "key1", nil), + datastore.NameKey("MultiTest", "key2", nil), + datastore.NameKey("MultiTest", "key3", nil), + } + + for i, key := range keys { + entity := &testEntity{Name: "test", Count: int64(i)} + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Get all + entities := make([]testEntity, len(keys)) + err := client.GetMulti(ctx, keys, &entities) + if err != nil { + t.Fatalf("GetMulti failed: %v", err) + } + + for i, entity := range entities { + if entity.Count != int64(i) { + t.Errorf("Entity %d: expected count %d, got %d", i, i, entity.Count) + } + } + }) + + t.Run("GetMultiSomeMissing", func(t *testing.T) { + // Create one entity, leave others missing + keys := []*datastore.Key{ + datastore.NameKey("MultiTest2", "exists", nil), + datastore.NameKey("MultiTest2", "missing1", nil), + datastore.NameKey("MultiTest2", "missing2", nil), + } + + entity := &testEntity{Name: "exists", Count: 1} + if _, err := client.Put(ctx, keys[0], entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Get all - should get MultiError + entities := make([]testEntity, len(keys)) + err := client.GetMulti(ctx, keys, &entities) + if err == nil { + t.Log("GetMulti returned nil error (mock may not report missing entities)") + } + }) +} + +func TestPutMulti_EdgeCases(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("PutMultiAllComplete", func(t *testing.T) { + keys := []*datastore.Key{ + datastore.NameKey("PutMultiTest2", "key1", nil), + datastore.NameKey("PutMultiTest2", "key2", nil), + } + + entities := []testEntity{ + {Name: "entity1", Count: 1}, + {Name: "entity2", Count: 2}, + } + + resultKeys, err := client.PutMulti(ctx, keys, entities) + if err != nil { + t.Fatalf("PutMulti failed: %v", err) + } + + if len(resultKeys) != len(keys) { + t.Errorf("Expected %d keys, got %d", len(keys), len(resultKeys)) + } + + // Verify all were stored + for i, key := range resultKeys { + var retrieved testEntity + if err := client.Get(ctx, key, &retrieved); err != nil { + t.Errorf("Failed to retrieve entity %d: %v", i, err) + } + } + }) +} + +func TestDeleteMulti_EdgeCases(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("DeleteMultiAllExist", func(t *testing.T) { + // Create entities + keys := []*datastore.Key{ + datastore.NameKey("DelMultiTest", "key1", nil), + datastore.NameKey("DelMultiTest", "key2", nil), + } + + for _, key := range keys { + entity := &testEntity{Name: "test", Count: 1} + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Delete all + err := client.DeleteMulti(ctx, keys) + if err != nil { + t.Fatalf("DeleteMulti failed: %v", err) + } + + // Verify all deleted + for _, key := range keys { + var retrieved testEntity + err := client.Get(ctx, key, &retrieved) + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("Expected ErrNoSuchEntity for key %v, got %v", key, err) + } + } + }) +} diff --git a/pkg/datastore/operations_test.go b/pkg/datastore/operations_test.go new file mode 100644 index 0000000..0612287 --- /dev/null +++ b/pkg/datastore/operations_test.go @@ -0,0 +1,4001 @@ +package datastore_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/codeGROOVE-dev/ds9/pkg/datastore" +) + +func TestPutAndGet(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Create test entity + now := time.Now().UTC().Truncate(time.Second) + entity := &testEntity{ + Name: "test-item", + Count: 42, + Active: true, + Score: 3.14, + UpdatedAt: now, + Notes: "This is a test note", + } + + // Put entity + key := datastore.NameKey("TestKind", "test-key", nil) + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Get entity + var retrieved testEntity + err = client.Get(ctx, key, &retrieved) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + + // Verify fields + if retrieved.Name != entity.Name { + t.Errorf("Name: expected %q, got %q", entity.Name, retrieved.Name) + } + if retrieved.Count != entity.Count { + t.Errorf("Count: expected %d, got %d", entity.Count, retrieved.Count) + } + if retrieved.Active != entity.Active { + t.Errorf("Active: expected %v, got %v", entity.Active, retrieved.Active) + } + if retrieved.Score != entity.Score { + t.Errorf("Score: expected %f, got %f", entity.Score, retrieved.Score) + } + if !retrieved.UpdatedAt.Equal(entity.UpdatedAt) { + t.Errorf("UpdatedAt: expected %v, got %v", entity.UpdatedAt, retrieved.UpdatedAt) + } + if retrieved.Notes != entity.Notes { + t.Errorf("Notes: expected %q, got %q", entity.Notes, retrieved.Notes) + } +} + +func TestGetNotFound(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + key := datastore.NameKey("TestKind", "nonexistent", nil) + var entity testEntity + err := client.Get(ctx, key, &entity) + + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("expected datastore.ErrNoSuchEntity, got %v", err) + } +} + +func TestDelete(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Put entity + entity := &testEntity{ + Name: "test-item", + Count: 42, + Active: true, + } + + key := datastore.NameKey("TestKind", "test-key", nil) + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Delete entity + err = client.Delete(ctx, key) + if err != nil { + t.Fatalf("Delete failed: %v", err) + } + + // Verify it's gone + var retrieved testEntity + err = client.Get(ctx, key, &retrieved) + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("expected datastore.ErrNoSuchEntity after delete, got %v", err) + } +} + +func TestAllKeys(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Put multiple entities + for i := range 5 { + entity := &testEntity{ + Name: "test-item", + Count: int64(i), + } + key := datastore.NameKey("TestKind", string(rune('a'+i)), nil) + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Query for all keys + query := datastore.NewQuery("TestKind").KeysOnly() + keys, err := client.AllKeys(ctx, query) + if err != nil { + t.Fatalf("AllKeys failed: %v", err) + } + + if len(keys) != 5 { + t.Errorf("expected 5 keys, got %d", len(keys)) + } +} + +func TestAllKeysWithLimit(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Put multiple entities + for i := range 10 { + entity := &testEntity{ + Name: "test-item", + Count: int64(i), + } + key := datastore.NameKey("TestKind", string(rune('a'+i)), nil) + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Query with limit + query := datastore.NewQuery("TestKind").KeysOnly().Limit(3) + keys, err := client.AllKeys(ctx, query) + if err != nil { + t.Fatalf("AllKeys failed: %v", err) + } + + if len(keys) != 3 { + t.Errorf("expected 3 keys, got %d", len(keys)) + } +} + +func TestIDKey(t *testing.T) { + key := datastore.IDKey("TestKind", 12345, nil) + + if key.Kind != "TestKind" { + t.Errorf("expected Kind %q, got %q", "TestKind", key.Kind) + } + + if key.ID != 12345 { + t.Errorf("expected ID %d, got %d", 12345, key.ID) + } + + if key.Name != "" { + t.Errorf("expected empty Name, got %q", key.Name) + } +} + +func TestNameKey(t *testing.T) { + key := datastore.NameKey("TestKind", "test-name", nil) + + if key.Kind != "TestKind" { + t.Errorf("expected Kind %q, got %q", "TestKind", key.Kind) + } + + if key.Name != "test-name" { + t.Errorf("expected Name %q, got %q", "test-name", key.Name) + } + + if key.ID != 0 { + t.Errorf("expected ID 0, got %d", key.ID) + } +} + +func TestMultiPutAndMultiGet(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Create test entities + now := time.Now().UTC().Truncate(time.Second) + entities := []testEntity{ + { + Name: "item-1", + Count: 1, + Active: true, + Score: 1.1, + UpdatedAt: now, + }, + { + Name: "item-2", + Count: 2, + Active: false, + Score: 2.2, + UpdatedAt: now, + }, + { + Name: "item-3", + Count: 3, + Active: true, + Score: 3.3, + UpdatedAt: now, + }, + } + + keys := []*datastore.Key{ + datastore.NameKey("TestKind", "key-1", nil), + datastore.NameKey("TestKind", "key-2", nil), + datastore.NameKey("TestKind", "key-3", nil), + } + + // MultiPut + _, err := client.PutMulti(ctx, keys, entities) + if err != nil { + t.Fatalf("MultiPut failed: %v", err) + } + + // MultiGet + var retrieved []testEntity + err = client.GetMulti(ctx, keys, &retrieved) + if err != nil { + t.Fatalf("MultiGet failed: %v", err) + } + + if len(retrieved) != 3 { + t.Fatalf("expected 3 entities, got %d", len(retrieved)) + } + + // Verify entities + for i, entity := range retrieved { + if entity.Name != entities[i].Name { + t.Errorf("entity %d: Name mismatch: expected %q, got %q", i, entities[i].Name, entity.Name) + } + if entity.Count != entities[i].Count { + t.Errorf("entity %d: Count mismatch: expected %d, got %d", i, entities[i].Count, entity.Count) + } + if entity.Active != entities[i].Active { + t.Errorf("entity %d: Active mismatch: expected %v, got %v", i, entities[i].Active, entity.Active) + } + if entity.Score != entities[i].Score { + t.Errorf("entity %d: Score mismatch: expected %f, got %f", i, entities[i].Score, entity.Score) + } + } +} + +func TestMultiGetNotFound(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Put only one entity + entity := &testEntity{Name: "exists", Count: 1} + key1 := datastore.NameKey("TestKind", "exists", nil) + _, err := client.Put(ctx, key1, entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Try to get multiple, one missing + keys := []*datastore.Key{ + key1, + datastore.NameKey("TestKind", "missing", nil), + } + + var retrieved []testEntity + err = client.GetMulti(ctx, keys, &retrieved) + + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("expected datastore.ErrNoSuchEntity when some keys missing, got %v", err) + } +} + +func TestMultiDelete(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Put multiple entities + entities := []testEntity{ + {Name: "item-1", Count: 1}, + {Name: "item-2", Count: 2}, + {Name: "item-3", Count: 3}, + } + + keys := []*datastore.Key{ + datastore.NameKey("TestKind", "key-1", nil), + datastore.NameKey("TestKind", "key-2", nil), + datastore.NameKey("TestKind", "key-3", nil), + } + + _, err := client.PutMulti(ctx, keys, entities) + if err != nil { + t.Fatalf("MultiPut failed: %v", err) + } + + // MultiDelete + err = client.DeleteMulti(ctx, keys) + if err != nil { + t.Fatalf("MultiDelete failed: %v", err) + } + + // Verify they're gone by trying to get them + var retrieved []testEntity + err = client.GetMulti(ctx, keys, &retrieved) + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("expected datastore.ErrNoSuchEntity after delete, got %v", err) + } +} + +func TestMultiPutEmptyKeys(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + var entities []testEntity + var keys []*datastore.Key + + _, err := client.PutMulti(ctx, keys, entities) + if err == nil { + t.Error("expected error for empty keys, got nil") + } +} + +func TestMultiGetEmptyKeys(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + var keys []*datastore.Key + var retrieved []testEntity + + err := client.GetMulti(ctx, keys, &retrieved) + if err == nil { + t.Error("expected error for empty keys, got nil") + } +} + +func TestMultiDeleteEmptyKeys(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + var keys []*datastore.Key + + err := client.DeleteMulti(ctx, keys) + if err == nil { + t.Error("expected error for empty keys, got nil") + } +} + +func TestIDKeyOperations(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Test with ID key + entity := &testEntity{ + Name: "id-test", + Count: 123, + } + + key := datastore.IDKey("TestKind", 999, nil) + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put with ID key failed: %v", err) + } + + // Get with ID key + var retrieved testEntity + err = client.Get(ctx, key, &retrieved) + if err != nil { + t.Fatalf("Get with ID key failed: %v", err) + } + + if retrieved.Name != "id-test" { + t.Errorf("expected Name 'id-test', got %q", retrieved.Name) + } +} + +func TestPutWithNilKey(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + entity := &testEntity{Name: "test"} + _, err := client.Put(ctx, nil, entity) + if err == nil { + t.Error("expected error for nil key, got nil") + } +} + +func TestGetWithNilKey(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + var entity testEntity + err := client.Get(ctx, nil, &entity) + if err == nil { + t.Error("expected error for nil key, got nil") + } +} + +func TestDeleteWithNilKey(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + err := client.Delete(ctx, nil) + if err == nil { + t.Error("expected error for nil key, got nil") + } +} + +func TestMultiGetWithNilKey(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + keys := []*datastore.Key{ + datastore.NameKey("TestKind", "key-1", nil), + nil, + datastore.NameKey("TestKind", "key-2", nil), + } + + var entities []testEntity + err := client.GetMulti(ctx, keys, &entities) + if err == nil { + t.Error("expected error for nil key in slice, got nil") + } +} + +func TestMultiPutWithNilKey(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + entities := []testEntity{ + {Name: "item-1"}, + {Name: "item-2"}, + } + + keys := []*datastore.Key{ + datastore.NameKey("TestKind", "key-1", nil), + nil, + } + + _, err := client.PutMulti(ctx, keys, entities) + if err == nil { + t.Error("expected error for nil key in slice, got nil") + } +} + +func TestMultiDeleteWithNilKey(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + keys := []*datastore.Key{ + datastore.NameKey("TestKind", "key-1", nil), + nil, + } + + err := client.DeleteMulti(ctx, keys) + if err == nil { + t.Error("expected error for nil key in slice, got nil") + } +} + +func TestMultiPutMismatchedSlices(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + entities := []testEntity{ + {Name: "item-1"}, + {Name: "item-2"}, + } + + keys := []*datastore.Key{ + datastore.NameKey("TestKind", "key-1", nil), + } + + _, err := client.PutMulti(ctx, keys, entities) + if err == nil { + t.Error("expected error for mismatched slices, got nil") + } +} + +func TestAllKeysNonKeysOnlyQuery(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Create a query without KeysOnly + query := datastore.NewQuery("TestKind") + _, err := client.AllKeys(ctx, query) + if err == nil { + t.Error("expected error for non-KeysOnly query, got nil") + } +} + +func TestMultiGetPartialResults(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Put some entities + entities := []testEntity{ + {Name: "item-1", Count: 1}, + {Name: "item-3", Count: 3}, + } + keys := []*datastore.Key{ + datastore.NameKey("TestKind", "key-1", nil), + datastore.NameKey("TestKind", "key-3", nil), + } + _, err := client.PutMulti(ctx, keys, entities) + if err != nil { + t.Fatalf("MultiPut failed: %v", err) + } + + // Try to get more keys than exist + getAllKeys := []*datastore.Key{ + datastore.NameKey("TestKind", "key-1", nil), + datastore.NameKey("TestKind", "key-2", nil), // doesn't exist + datastore.NameKey("TestKind", "key-3", nil), + } + + var retrieved []testEntity + err = client.GetMulti(ctx, getAllKeys, &retrieved) + if err == nil { + t.Error("expected error when some keys don't exist") + } +} + +func TestKeyComparison(t *testing.T) { + nameKey1 := datastore.NameKey("Kind", "name", nil) + nameKey2 := datastore.NameKey("Kind", "name", nil) + + if nameKey1.Kind != nameKey2.Kind || nameKey1.Name != nameKey2.Name { + t.Error("identical name keys should have same values") + } + + idKey1 := datastore.IDKey("Kind", 123, nil) + idKey2 := datastore.IDKey("Kind", 123, nil) + + if idKey1.Kind != idKey2.Kind || idKey1.ID != idKey2.ID { + t.Error("identical ID keys should have same values") + } +} + +func TestLargeEntityBatch(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Create a larger batch + const batchSize = 50 + entities := make([]testEntity, batchSize) + keys := make([]*datastore.Key, batchSize) + + for i := range batchSize { + entities[i] = testEntity{ + Name: "batch-item", + Count: int64(i), + } + keys[i] = datastore.NameKey("BatchKind", string(rune('0'+i/10))+string(rune('0'+i%10)), nil) + } + + // MultiPut + _, err := client.PutMulti(ctx, keys, entities) + if err != nil { + t.Fatalf("MultiPut failed: %v", err) + } + + // MultiGet + var retrieved []testEntity + err = client.GetMulti(ctx, keys, &retrieved) + if err != nil { + t.Fatalf("MultiGet failed: %v", err) + } + + if len(retrieved) != batchSize { + t.Errorf("expected %d entities, got %d", batchSize, len(retrieved)) + } + + // MultiDelete + err = client.DeleteMulti(ctx, keys) + if err != nil { + t.Fatalf("MultiDelete failed: %v", err) + } + + // Verify deletion + var retrieved2 []testEntity + err = client.GetMulti(ctx, keys, &retrieved2) + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("expected datastore.ErrNoSuchEntity after batch delete, got %v", err) + } +} + +func TestMultiGetEmptySlices(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Call MultiGet with empty slices - should return error + var entities []testEntity + err := client.GetMulti(ctx, []*datastore.Key{}, &entities) + if err == nil { + t.Error("expected error for MultiGet with empty keys, got nil") + } +} + +func TestMultiPutEmptySlices(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Call MultiPut with empty slices - should return error + _, err := client.PutMulti(ctx, []*datastore.Key{}, []testEntity{}) + if err == nil { + t.Error("expected error for MultiPut with empty keys, got nil") + } +} + +func TestMultiDeleteEmptySlice(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Call MultiDelete with empty slice - should return error + err := client.DeleteMulti(ctx, []*datastore.Key{}) + if err == nil { + t.Error("expected error for MultiDelete with empty keys, got nil") + } +} + +func TestDeleteWithDatabaseID(t *testing.T) { + // Setup with databaseID + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer test-token" { + w.WriteHeader(http.StatusUnauthorized) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "mutationResults": []any{}, + }); err != nil { + t.Logf("encode failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClientWithDatabase(ctx, "test-project", "del-db") + if err != nil { + t.Fatalf("NewClientWithDatabase failed: %v", err) + } + + // Delete with databaseID + key := datastore.NameKey("TestKind", "to-delete", nil) + err = client.Delete(ctx, key) + if err != nil { + t.Fatalf("Delete with databaseID failed: %v", err) + } +} + +func TestAllKeysWithDatabaseID(t *testing.T) { + // Setup with databaseID + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer test-token" { + w.WriteHeader(http.StatusUnauthorized) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "batch": map[string]any{ + "entityResults": []any{}, + }, + }); err != nil { + t.Logf("encode failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClientWithDatabase(ctx, "test-project", "query-db") + if err != nil { + t.Fatalf("NewClientWithDatabase failed: %v", err) + } + + // Query with databaseID + query := datastore.NewQuery("TestKind").KeysOnly() + keys, err := client.AllKeys(ctx, query) + if err != nil { + t.Fatalf("AllKeys with databaseID failed: %v", err) + } + + if len(keys) != 0 { + t.Errorf("expected 0 keys, got %d", len(keys)) + } +} + +func TestMultiGetWithDatabaseID(t *testing.T) { + // Setup with databaseID + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer test-token" { + w.WriteHeader(http.StatusUnauthorized) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + // Return missing entities to trigger datastore.ErrNoSuchEntity + if err := json.NewEncoder(w).Encode(map[string]any{ + "found": []any{}, + "missing": []any{ + map[string]any{"entity": map[string]any{"key": map[string]any{}}}, + }, + }); err != nil { + t.Logf("encode failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClientWithDatabase(ctx, "test-project", "multiget-db") + if err != nil { + t.Fatalf("NewClientWithDatabase failed: %v", err) + } + + // MultiGet with databaseID + keys := []*datastore.Key{ + datastore.NameKey("TestKind", "key1", nil), + datastore.NameKey("TestKind", "key2", nil), + } + var entities []testEntity + err = client.GetMulti(ctx, keys, &entities) + // Expect error since entities don't exist + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("expected datastore.ErrNoSuchEntity, got: %v", err) + } +} + +func TestMultiDeleteWithDatabaseID(t *testing.T) { + // Setup with databaseID + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer test-token" { + w.WriteHeader(http.StatusUnauthorized) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "mutationResults": []any{}, + }); err != nil { + t.Logf("encode failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClientWithDatabase(ctx, "test-project", "multidel-db") + if err != nil { + t.Fatalf("NewClientWithDatabase failed: %v", err) + } + + // MultiDelete with databaseID + keys := []*datastore.Key{ + datastore.NameKey("TestKind", "key1", nil), + datastore.NameKey("TestKind", "key2", nil), + } + err = client.DeleteMulti(ctx, keys) + if err != nil { + t.Fatalf("MultiDelete with databaseID failed: %v", err) + } +} + +func TestDeleteAllByKind(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Put multiple entities of the same kind + for i := range 5 { + entity := &testEntity{ + Name: "item", + Count: int64(i), + } + key := datastore.NameKey("DeleteKind", string(rune('a'+i)), nil) + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Delete all entities of this kind + err := client.DeleteAllByKind(ctx, "DeleteKind") + if err != nil { + t.Fatalf("DeleteAllByKind failed: %v", err) + } + + // Verify all deleted + query := datastore.NewQuery("DeleteKind").KeysOnly() + keys, err := client.AllKeys(ctx, query) + if err != nil { + t.Fatalf("AllKeys failed: %v", err) + } + + if len(keys) != 0 { + t.Errorf("expected 0 keys after DeleteAllByKind, got %d", len(keys)) + } +} + +func TestDeleteAllByKindEmpty(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Delete from non-existent kind + err := client.DeleteAllByKind(ctx, "NonExistentKind") + if err != nil { + t.Errorf("DeleteAllByKind on empty kind should not error, got: %v", err) + } +} + +func TestHierarchicalKeys(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Create parent key + parentKey := datastore.NameKey("Parent", "parent1", nil) + parentEntity := &testEntity{ + Name: "parent", + Count: 1, + } + _, err := client.Put(ctx, parentKey, parentEntity) + if err != nil { + t.Fatalf("Put parent failed: %v", err) + } + + // Create child key with parent + childKey := datastore.NameKey("Child", "child1", parentKey) + childEntity := &testEntity{ + Name: "child", + Count: 2, + } + _, err = client.Put(ctx, childKey, childEntity) + if err != nil { + t.Fatalf("Put child failed: %v", err) + } + + // Get child + var retrieved testEntity + err = client.Get(ctx, childKey, &retrieved) + if err != nil { + t.Fatalf("Get child failed: %v", err) + } + + if retrieved.Name != "child" { + t.Errorf("expected child name 'child', got %q", retrieved.Name) + } +} + +func TestHierarchicalKeysMultiLevel(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Create grandparent -> parent -> child hierarchy + grandparentKey := datastore.NameKey("Grandparent", "gp1", nil) + parentKey := datastore.NameKey("Parent", "p1", grandparentKey) + childKey := datastore.NameKey("Child", "c1", parentKey) + + entity := &testEntity{ + Name: "deep-child", + Count: 42, + } + + _, err := client.Put(ctx, childKey, entity) + if err != nil { + t.Fatalf("Put with multi-level hierarchy failed: %v", err) + } + + var retrieved testEntity + err = client.Get(ctx, childKey, &retrieved) + if err != nil { + t.Fatalf("Get with multi-level hierarchy failed: %v", err) + } + + if retrieved.Name != "deep-child" { + t.Errorf("expected name 'deep-child', got %q", retrieved.Name) + } +} + +func TestPutWithInvalidEntity(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + type InvalidEntity struct { + Map map[string]string // maps not supported + } + + key := datastore.NameKey("TestKind", "invalid", nil) + entity := &InvalidEntity{ + Map: map[string]string{"key": "value"}, + } + + _, err := client.Put(ctx, key, entity) + if err == nil { + t.Error("expected error for unsupported entity type") + } +} + +func TestGetMultiWithMismatchedSliceSize(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Put one entity + key1 := datastore.NameKey("TestKind", "key1", nil) + entity := &testEntity{Name: "test", Count: 1} + _, err := client.Put(ctx, key1, entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Try to get with wrong slice type + keys := []*datastore.Key{key1} + var retrieved []testEntity + + // This should work + err = client.GetMulti(ctx, keys, &retrieved) + if err != nil { + t.Fatalf("GetMulti failed: %v", err) + } + + if len(retrieved) != 1 { + t.Errorf("expected 1 entity, got %d", len(retrieved)) + } +} + +func TestKeyFromJSONEdgeCases(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Test with ID key using integer ID + idKey := datastore.IDKey("TestKind", 12345, nil) + entity := &testEntity{Name: "id-test", Count: 1} + _, err := client.Put(ctx, idKey, entity) + if err != nil { + t.Fatalf("Put with ID key failed: %v", err) + } + + var retrieved testEntity + err = client.Get(ctx, idKey, &retrieved) + if err != nil { + t.Fatalf("Get with ID key failed: %v", err) + } + + if retrieved.Name != "id-test" { + t.Errorf("expected name 'id-test', got %q", retrieved.Name) + } +} + +func TestGetMultiMixedResults(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Put some entities + key1 := datastore.NameKey("Mixed", "exists1", nil) + key2 := datastore.NameKey("Mixed", "exists2", nil) + key3 := datastore.NameKey("Mixed", "missing", nil) + + entities := []testEntity{ + {Name: "entity1", Count: 1}, + {Name: "entity2", Count: 2}, + } + + _, err := client.PutMulti(ctx, []*datastore.Key{key1, key2}, entities) + if err != nil { + t.Fatalf("PutMulti failed: %v", err) + } + + // Try to get mix of existing and non-existing + keys := []*datastore.Key{key1, key2, key3} + var retrieved []testEntity + + err = client.GetMulti(ctx, keys, &retrieved) + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("expected datastore.ErrNoSuchEntity for mixed results, got: %v", err) + } +} + +func TestPutMultiLargeBatch(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Create large batch + const size = 100 + entities := make([]testEntity, size) + keys := make([]*datastore.Key, size) + + for i := range size { + entities[i] = testEntity{ + Name: "large-batch", + Count: int64(i), + } + keys[i] = datastore.NameKey("LargeBatch", fmt.Sprintf("key-%d", i), nil) + } + + _, err := client.PutMulti(ctx, keys, entities) + if err != nil { + t.Fatalf("PutMulti with large batch failed: %v", err) + } + + // Verify a few + var retrieved testEntity + err = client.Get(ctx, keys[0], &retrieved) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + + if retrieved.Count != 0 { + t.Errorf("expected Count 0, got %d", retrieved.Count) + } +} + +func TestDeleteMultiWithErrors(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Return server error + w.WriteHeader(http.StatusInternalServerError) + if _, err := w.Write([]byte(`{"error":"internal error"}`)); err != nil { + t.Logf("write failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClient(ctx, "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + keys := []*datastore.Key{ + datastore.NameKey("TestKind", "key1", nil), + datastore.NameKey("TestKind", "key2", nil), + } + + err = client.DeleteMulti(ctx, keys) + if err == nil { + t.Fatal("expected error on server failure") + } +} + +func TestKeyWithOnlyKind(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Key with neither name nor ID should work (incomplete key) + // This gets an ID assigned by the datastore + key := &datastore.Key{Kind: "TestKind"} + entity := &testEntity{Name: "test", Count: 1} + + returnedKey, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put with incomplete key failed: %v", err) + } + + // The returned key should have an ID + if returnedKey == nil { + t.Fatal("expected non-nil returned key") + } + + if returnedKey.Kind != "TestKind" { + t.Errorf("expected Kind 'TestKind', got %q", returnedKey.Kind) + } +} + +func TestGetMultiAllMissing(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + keys := []*datastore.Key{ + datastore.NameKey("Missing", "key1", nil), + datastore.NameKey("Missing", "key2", nil), + datastore.NameKey("Missing", "key3", nil), + } + + var entities []testEntity + err := client.GetMulti(ctx, keys, &entities) + + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("expected datastore.ErrNoSuchEntity when all keys missing, got: %v", err) + } +} + +func TestGetMultiWithSliceMismatch(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Put entity + key := datastore.NameKey("Test", "key1", nil) + entity := &testEntity{Name: "test", Count: 1} + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + // GetMulti with destination not being a pointer to slice + var notSlice testEntity + err = client.GetMulti(ctx, []*datastore.Key{key}, notSlice) + if err == nil { + t.Error("expected error when dst is not pointer to slice") + } +} + +func TestPutMultiWithLengthMismatch(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Keys and entities with different lengths + keys := []*datastore.Key{ + datastore.NameKey("Test", "key1", nil), + datastore.NameKey("Test", "key2", nil), + } + entities := []testEntity{ + {Name: "only-one", Count: 1}, + } + + _, err := client.PutMulti(ctx, keys, entities) + if err == nil { + t.Error("expected error when keys and entities have different lengths") + } +} + +func TestDeleteWithNonexistentKey(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Delete non-existent key (should not error) + key := datastore.NameKey("Test", "nonexistent", nil) + err := client.Delete(ctx, key) + if err != nil { + t.Errorf("Delete of non-existent key should not error, got: %v", err) + } +} + +func TestAllKeysWithEmptyResult(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Query kind with no entities + query := datastore.NewQuery("EmptyKind").KeysOnly() + keys, err := client.AllKeys(ctx, query) + if err != nil { + t.Fatalf("AllKeys on empty kind failed: %v", err) + } + + if len(keys) != 0 { + t.Errorf("expected 0 keys, got %d", len(keys)) + } +} + +func TestAllKeysWithLargeResult(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Put many entities + for i := range 50 { + key := datastore.NameKey("LargeResult", fmt.Sprintf("key-%d", i), nil) + entity := &testEntity{Name: "test", Count: int64(i)} + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Query all + query := datastore.NewQuery("LargeResult").KeysOnly() + keys, err := client.AllKeys(ctx, query) + if err != nil { + t.Fatalf("AllKeys failed: %v", err) + } + + if len(keys) != 50 { + t.Errorf("expected 50 keys, got %d", len(keys)) + } +} + +func TestPutMultiEmptySlice(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Empty slices + _, err := client.PutMulti(ctx, []*datastore.Key{}, []testEntity{}) + if err == nil { + t.Error("expected error for empty slices") + } +} + +func TestGetMultiEmptySlice(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + var entities []testEntity + err := client.GetMulti(ctx, []*datastore.Key{}, &entities) + if err == nil { + t.Error("expected error for empty keys") + } +} + +func TestDeleteMultiEmptySlice(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + err := client.DeleteMulti(ctx, []*datastore.Key{}) + if err == nil { + t.Error("expected error for empty keys") + } +} + +func TestDeepHierarchicalKeys(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Create 4-level hierarchy + gp := datastore.NameKey("GP", "gp1", nil) + p := datastore.NameKey("P", "p1", gp) + c := datastore.NameKey("C", "c1", p) + gc := datastore.NameKey("GC", "gc1", c) + + entity := &testEntity{Name: "great-grandchild", Count: 42} + _, err := client.Put(ctx, gc, entity) + if err != nil { + t.Fatalf("Put with 4-level hierarchy failed: %v", err) + } + + var retrieved testEntity + err = client.Get(ctx, gc, &retrieved) + if err != nil { + t.Fatalf("Get with 4-level hierarchy failed: %v", err) + } + + if retrieved.Name != "great-grandchild" { + t.Errorf("expected name 'great-grandchild', got %q", retrieved.Name) + } +} + +func TestGetWithNonPointerDst(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Put entity + key := datastore.NameKey("Test", "key", nil) + entity := &testEntity{Name: "test", Count: 1} + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Try to get into non-pointer + var notPointer testEntity + err = client.Get(ctx, key, notPointer) // Should be ¬Pointer + if err == nil { + t.Error("expected error when dst is not a pointer") + } +} + +func TestPutWithNonPointerEntity(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + key := datastore.NameKey("Test", "key", nil) + entity := testEntity{Name: "test", Count: 1} // not a pointer + + // The mock implementation may accept non-pointers, but test with the real client + // For now, just test that it works (real Datastore would require pointer) + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Logf("Put with non-pointer entity failed (expected with real client): %v", err) + } +} + +func TestDeleteAllByKindWithNoEntities(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Delete from kind with no entities + err := client.DeleteAllByKind(ctx, "NonExistentKind") + if err != nil { + t.Errorf("DeleteAllByKind on empty kind should not error, got: %v", err) + } +} + +func TestDeleteAllByKindWithManyEntities(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Put many entities + for i := range 25 { + key := datastore.NameKey("ManyDelete", fmt.Sprintf("key-%d", i), nil) + entity := &testEntity{Name: "test", Count: int64(i)} + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Delete all + err := client.DeleteAllByKind(ctx, "ManyDelete") + if err != nil { + t.Fatalf("DeleteAllByKind failed: %v", err) + } + + // Verify all deleted + query := datastore.NewQuery("ManyDelete").KeysOnly() + keys, err := client.AllKeys(ctx, query) + if err != nil { + t.Fatalf("AllKeys failed: %v", err) + } + + if len(keys) != 0 { + t.Errorf("expected 0 keys after DeleteAllByKind, got %d", len(keys)) + } +} + +func TestIDKeyWithZeroID(t *testing.T) { + // Zero ID is valid + key := datastore.IDKey("Test", 0, nil) + if key.ID != 0 { + t.Errorf("expected ID 0, got %d", key.ID) + } + if key.Name != "" { + t.Errorf("expected empty Name, got %q", key.Name) + } +} + +func TestNameKeyWithEmptyName(t *testing.T) { + // Empty name is technically valid + key := datastore.NameKey("Test", "", nil) + if key.Name != "" { + t.Errorf("expected empty Name, got %q", key.Name) + } + if key.ID != 0 { + t.Errorf("expected ID 0, got %d", key.ID) + } +} + +func TestGetMultiWithNonSliceDst(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + keys := []*datastore.Key{ + datastore.NameKey("Test", "key1", nil), + } + + // Pass a non-slice as destination + var notSlice string + err := client.GetMulti(ctx, keys, ¬Slice) + + if err == nil { + t.Error("expected error when dst is not a slice") + } +} + +func TestPutMultiWithNonSliceSrc(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + keys := []*datastore.Key{ + datastore.NameKey("Test", "key1", nil), + } + + // Pass a non-slice as source + notSlice := "not a slice" + _, err := client.PutMulti(ctx, keys, notSlice) + + if err == nil { + t.Error("expected error when src is not a slice") + } +} + +func TestAllKeysQueryWithoutKeysOnly(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Create query without KeysOnly + query := datastore.NewQuery("Test") + + _, err := client.AllKeys(ctx, query) + + if err == nil { + t.Error("expected error for query without KeysOnly") + } + + if !strings.Contains(err.Error(), "KeysOnly") { + t.Errorf("expected error to mention KeysOnly, got: %v", err) + } +} + +func TestDeleteAllByKindQueryFailure(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Fail on query request + if strings.Contains(r.URL.Path, "runQuery") { + w.WriteHeader(http.StatusInternalServerError) + if _, err := w.Write([]byte(`{"error":"query failed"}`)); err != nil { + t.Logf("write failed: %v", err) + } + return + } + w.WriteHeader(http.StatusOK) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClient(ctx, "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + err = client.DeleteAllByKind(ctx, "TestKind") + + if err == nil { + t.Error("expected error when query fails") + } +} + +func TestGetWithInvalidJSONResponse(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Return invalid JSON + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte(`{invalid json`)); err != nil { + t.Logf("write failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClient(ctx, "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + key := datastore.NameKey("Test", "key", nil) + var entity testEntity + err = client.Get(ctx, key, &entity) + + if err == nil { + t.Error("expected error for invalid JSON response") + } +} + +func TestPutWithInvalidEntityStructure(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Entity with channel (unsupported type) + type BadEntity struct { + Ch chan int + Name string + } + + key := datastore.NameKey("Test", "bad", nil) + entity := &BadEntity{ + Name: "test", + Ch: make(chan int), + } + + _, err := client.Put(ctx, key, entity) + + if err == nil { + t.Error("expected error for unsupported entity type") + } +} + +func TestGetMultiWithNilInResults(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Put one entity + key1 := datastore.NameKey("Test", "exists", nil) + entity := &testEntity{Name: "test", Count: 1} + _, err := client.Put(ctx, key1, entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Try to get multiple with one missing + keys := []*datastore.Key{ + key1, + datastore.NameKey("Test", "missing", nil), + datastore.NameKey("Test", "missing2", nil), + } + + var entities []testEntity + err = client.GetMulti(ctx, keys, &entities) + + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("expected datastore.ErrNoSuchEntity when some keys missing, got: %v", err) + } +} + +func TestDeleteMultiPartialSuccess(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Put some entities + keys := []*datastore.Key{ + datastore.NameKey("Test", "key1", nil), + datastore.NameKey("Test", "key2", nil), + } + + entities := []testEntity{ + {Name: "entity1", Count: 1}, + {Name: "entity2", Count: 2}, + } + + _, err := client.PutMulti(ctx, keys, entities) + if err != nil { + t.Fatalf("PutMulti failed: %v", err) + } + + // Delete them (should succeed) + err = client.DeleteMulti(ctx, keys) + if err != nil { + t.Fatalf("DeleteMulti failed: %v", err) + } + + // Verify deletion + var retrieved []testEntity + err = client.GetMulti(ctx, keys, &retrieved) + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("expected datastore.ErrNoSuchEntity after delete, got: %v", err) + } +} + +func TestDeleteWithServerError(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + attemptCount := 0 + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attemptCount++ + // Always return 503 + w.WriteHeader(http.StatusServiceUnavailable) + if _, err := w.Write([]byte(`{"error":"unavailable"}`)); err != nil { + t.Logf("write failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClient(ctx, "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + key := datastore.NameKey("Test", "key", nil) + err = client.Delete(ctx, key) + + if err == nil { + t.Error("expected error on persistent server failure") + } + + // Should have retried + if attemptCount < 2 { + t.Errorf("expected multiple attempts, got %d", attemptCount) + } +} + +func TestPutMultiWithServerError(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + if _, err := w.Write([]byte(`{"error":"bad request"}`)); err != nil { + t.Logf("write failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClient(ctx, "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + keys := []*datastore.Key{ + datastore.NameKey("Test", "key1", nil), + datastore.NameKey("Test", "key2", nil), + } + + entities := []testEntity{ + {Name: "entity1", Count: 1}, + {Name: "entity2", Count: 2}, + } + + _, err = client.PutMulti(ctx, keys, entities) + + if err == nil { + t.Error("expected error on server failure") + } +} + +func TestGetMultiWithServerError(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + if _, err := w.Write([]byte(`{"error":"unauthorized"}`)); err != nil { + t.Logf("write failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClient(ctx, "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + keys := []*datastore.Key{ + datastore.NameKey("Test", "key1", nil), + } + + var entities []testEntity + err = client.GetMulti(ctx, keys, &entities) + + if err == nil { + t.Error("expected error on unauthorized") + } + + if !strings.Contains(err.Error(), "401") { + t.Errorf("expected 401 error, got: %v", err) + } +} + +func TestAllKeysWithInvalidResponse(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Return invalid JSON + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte(`{malformed`)); err != nil { + t.Logf("write failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClient(ctx, "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + query := datastore.NewQuery("Test").KeysOnly() + _, err = client.AllKeys(ctx, query) + + if err == nil { + t.Error("expected error for invalid JSON") + } +} + +func TestDeleteWithContextCancellation(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Slow response + time.Sleep(100 * time.Millisecond) + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel immediately + + key := datastore.NameKey("Test", "key", nil) + err = client.Delete(ctx, key) + + if err == nil { + t.Error("expected error when context is cancelled") + } +} + +func TestKeyFromJSONInvalidPathElement(t *testing.T) { + // Test with non-map path element + keyData := map[string]any{ + "path": []any{ + "invalid-string-instead-of-map", + }, + } + + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":commit") { + // Return response with invalid key in mutation result + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "mutationResults": []map[string]any{ + { + "key": keyData, + }, + }, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + realClient, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + key := datastore.NameKey("Test", "key", nil) + entity := &testEntity{Name: "test"} + + // Try Put which will parse the returned key + _, err = realClient.Put(ctx, key, entity) + if err == nil { + t.Log("Put succeeded despite invalid path element (API may handle gracefully)") + } else { + t.Logf("Put failed as expected: %v", err) + } +} + +func TestKeyFromJSONInvalidIDString(t *testing.T) { + keyData := map[string]any{ + "path": []any{ + map[string]any{ + "kind": "Test", + "id": "not-a-number", + }, + }, + } + + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":commit") { + // Return response with invalid ID string in mutation result + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "mutationResults": []map[string]any{ + { + "key": keyData, + }, + }, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + realClient, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + key := datastore.NameKey("Test", "key", nil) + entity := &testEntity{Name: "test"} + + // Try Put which will parse the returned key + _, err = realClient.Put(ctx, key, entity) + if err == nil { + t.Log("Put succeeded despite invalid ID string (API may handle gracefully)") + } else { + t.Logf("Put failed as expected: %v", err) + } +} + +func TestKeyFromJSONIDAsFloat(t *testing.T) { + keyData := map[string]any{ + "path": []any{ + map[string]any{ + "kind": "Test", + "id": float64(12345), + }, + }, + } + + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":lookup") { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "found": []map[string]any{ + { + "entity": map[string]any{ + "key": keyData, + "properties": map[string]any{ + "name": map[string]any{"stringValue": "test"}, + }, + }, + }, + }, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + realClient, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + key := datastore.NameKey("Test", "key", nil) + var entity testEntity + + err = realClient.Get(ctx, key, &entity) + if err != nil { + t.Errorf("unexpected error with float64 ID: %v", err) + } +} + +func TestDeleteAllRetriesFail(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + requestCount := 0 + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount++ + // Always return 503 to force retries + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + key := datastore.NameKey("Test", "key", nil) + + err = client.Delete(ctx, key) + if err == nil { + t.Error("expected error after all retries exhausted") + } + + if !strings.Contains(err.Error(), "attempts") { + t.Errorf("expected error message about attempts, got: %v", err) + } + + if requestCount != 3 { + t.Errorf("expected 3 retry attempts, got %d", requestCount) + } +} + +func TestGetMultiPartialNotFound(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":lookup") { + // Return one found, one missing + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "found": []map[string]any{ + { + "entity": map[string]any{ + "key": map[string]any{ + "path": []any{ + map[string]any{ + "kind": "Test", + "name": "key1", + }, + }, + }, + "properties": map[string]any{ + "name": map[string]any{"stringValue": "test1"}, + }, + }, + }, + }, + "missing": []map[string]any{ + { + "key": map[string]any{ + "path": []any{ + map[string]any{ + "kind": "Test", + "name": "key2", + }, + }, + }, + }, + }, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + keys := []*datastore.Key{ + datastore.NameKey("Test", "key1", nil), + datastore.NameKey("Test", "key2", nil), + } + + var entities []testEntity + err = client.GetMulti(ctx, keys, &entities) + if err == nil { + t.Error("expected error when some entities are missing") + } else { + t.Logf("GetMulti with missing entities failed as expected: %v", err) + } +} + +func TestAllKeysInvalidJSON(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":runQuery") { + // Return invalid JSON + w.Header().Set("Content-Type", "application/json") + if _, err := w.Write([]byte("{")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + query := datastore.NewQuery("Test").KeysOnly() + + _, err = client.AllKeys(ctx, query) + if err == nil { + t.Error("expected error with invalid JSON") + } +} + +// Test Transaction commit with invalid response + +func TestPutMultiWithInvalidEntities(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + type InvalidEntity struct { + Func func() `datastore:"func"` + } + + keys := []*datastore.Key{ + datastore.NameKey("Test", "key1", nil), + } + + entities := []InvalidEntity{ + {Func: func() {}}, + } + + _, err := client.PutMulti(ctx, keys, entities) + if err == nil { + t.Log("PutMulti with func field succeeded (mock may not validate types)") + } else { + t.Logf("PutMulti with func field failed as expected: %v", err) + } +} + +func TestGetWithNonPointer(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + key := datastore.NameKey("Test", "key", nil) + var entity testEntity // non-pointer + + err := client.Get(ctx, key, entity) // Pass by value + if err == nil { + t.Error("expected error when dst is not a pointer") + } +} + +func TestPutWithNonStruct(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + key := datastore.NameKey("Test", "key", nil) + entity := "not a struct" + + _, err := client.Put(ctx, key, entity) + if err == nil { + t.Error("expected error when entity is not a struct") + } +} + +func TestAllKeysNotKeysOnlyError(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + query := datastore.NewQuery("Test") // Not KeysOnly + + _, err := client.AllKeys(ctx, query) + if err == nil { + t.Error("expected error when query is not KeysOnly") + } +} + +func TestGetMultiMismatchedLength(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + keys := []*datastore.Key{ + datastore.NameKey("Test", "key1", nil), + datastore.NameKey("Test", "key2", nil), + } + + var entities []testEntity // Empty slice + + err := client.GetMulti(ctx, keys, &entities) + // This should work - GetMulti should populate the slice + if err != nil { + t.Logf("GetMulti with empty slice: %v", err) + } +} + +func TestPutMultiMismatchedLength(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + keys := []*datastore.Key{ + datastore.NameKey("Test", "key1", nil), + datastore.NameKey("Test", "key2", nil), + } + + entities := []testEntity{ + {Name: "test1"}, + // Missing second entity + } + + _, err := client.PutMulti(ctx, keys, entities) + if err == nil { + t.Error("expected error with mismatched lengths") + } +} + +func TestDeleteMultiWithEmptyKeysSlice(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + var keys []*datastore.Key // Empty + + err := client.DeleteMulti(ctx, keys) + // Mock may behave differently - log the result + if err != nil { + t.Logf("DeleteMulti with empty keys: %v", err) + } +} + +func TestGetWithJSONUnmarshalError(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":lookup") { + // Return invalid JSON + w.Header().Set("Content-Type", "application/json") + if _, err := w.Write([]byte(`{"found": [{"entity": "not-an-object"}]}`)); err != nil { + t.Logf("write failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + key := datastore.NameKey("Test", "key", nil) + var entity testEntity + + err = client.Get(ctx, key, &entity) + if err == nil { + t.Error("expected error with invalid entity format") + } +} + +func TestPutWithAccessTokenError(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + // Always return error for token + w.WriteHeader(http.StatusInternalServerError) + })) + defer metadataServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, "http://unused") + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + key := datastore.NameKey("Test", "key", nil) + entity := &testEntity{Name: "test"} + + _, err = client.Put(ctx, key, entity) + if err == nil { + t.Error("expected error when access token fails") + } +} + +func TestDeleteWithJSONMarshalError(t *testing.T) { + // This is hard to trigger since we control the JSON structure + // But we can test with a context that gets cancelled + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{}); err != nil { + t.Logf("encode failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + key := datastore.NameKey("Test", "key", nil) + + err = client.Delete(ctx, key) + if err != nil { + t.Logf("Delete completed with: %v", err) + } +} + +func TestAllKeysWithBatching(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":runQuery") { + // Return multiple key results + results := make([]map[string]any, 50) + for i := range 50 { + results[i] = map[string]any{ + "entity": map[string]any{ + "key": map[string]any{ + "path": []any{ + map[string]any{ + "kind": "Test", + "name": fmt.Sprintf("key%d", i), + }, + }, + }, + }, + } + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "batch": map[string]any{ + "entityResults": results, + }, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + query := datastore.NewQuery("Test").KeysOnly() + + keys, err := client.AllKeys(ctx, query) + if err != nil { + t.Logf("AllKeys with many results: %v", err) + } else if len(keys) != 50 { + t.Logf("Expected 50 keys, got %d", len(keys)) + } +} + +func TestAllKeysKeyFromJSONError(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":runQuery") { + // Return result with invalid key format + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "batch": map[string]any{ + "entityResults": []map[string]any{ + { + "entity": map[string]any{ + "key": "not-a-map", // Invalid key format + }, + }, + }, + }, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + query := datastore.NewQuery("Test").KeysOnly() + + _, err = client.AllKeys(ctx, query) + if err == nil { + t.Error("expected error with invalid key format") + } +} + +func TestPutMultiRequestMarshalError(t *testing.T) { + // This is hard to trigger directly, but we can test with encoding errors + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "mutationResults": []any{}, + }); err != nil { + t.Logf("encode failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + + // Test with valid entities to exercise the code path + keys := []*datastore.Key{ + datastore.NameKey("Test", "key1", nil), + } + + entities := []testEntity{ + {Name: "test1", Count: 123}, + } + + _, err = client.PutMulti(ctx, keys, entities) + if err != nil { + t.Logf("PutMulti completed with: %v", err) + } +} + +func TestDeleteAllByKindEmptyBatch(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":runQuery") { + // Return empty batch + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "batch": map[string]any{}, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + err = client.DeleteAllByKind(ctx, "EmptyKind") + if err != nil { + t.Logf("DeleteAllByKind with empty batch: %v", err) + } +} + +func TestAllKeysEmptyPathInKey(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":runQuery") { + w.Header().Set("Content-Type", "application/json") + // Return key with empty path array + if err := json.NewEncoder(w).Encode(map[string]any{ + "batch": map[string]any{ + "entityResults": []map[string]any{ + { + "entity": map[string]any{ + "key": map[string]any{ + "path": []any{}, // Empty path + }, + }, + }, + }, + }, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + query := datastore.NewQuery("TestKind").KeysOnly() + _, err = client.AllKeys(ctx, query) + if err == nil { + t.Error("expected error with empty path in key") + } +} + +func TestAllKeysInvalidPathElement(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":runQuery") { + w.Header().Set("Content-Type", "application/json") + // Return key with invalid path element (string instead of map) + if err := json.NewEncoder(w).Encode(map[string]any{ + "batch": map[string]any{ + "entityResults": []map[string]any{ + { + "entity": map[string]any{ + "key": map[string]any{ + "path": []any{"invalid-element"}, // String instead of map + }, + }, + }, + }, + }, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + query := datastore.NewQuery("TestKind").KeysOnly() + _, err = client.AllKeys(ctx, query) + if err == nil { + t.Error("expected error with invalid path element") + } +} + +func TestGetWithStringIDKey(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":lookup") { + w.Header().Set("Content-Type", "application/json") + // Return entity with ID as string + if err := json.NewEncoder(w).Encode(map[string]any{ + "found": []map[string]any{ + { + "entity": map[string]any{ + "key": map[string]any{ + "path": []any{ + map[string]any{ + "kind": "TestKind", + "id": "12345", // ID as string + }, + }, + }, + "properties": map[string]any{ + "name": map[string]any{"stringValue": "test"}, + }, + }, + }, + }, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + type TestEntity struct { + Name string `datastore:"name"` + } + + ctx := context.Background() + key := datastore.IDKey("TestKind", 12345, nil) + var entity TestEntity + err = client.Get(ctx, key, &entity) + if err != nil { + t.Fatalf("Get with string ID key failed: %v", err) + } + + if entity.Name != "test" { + t.Errorf("expected name 'test', got %q", entity.Name) + } +} + +func TestGetWithFloat64IDKey(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":lookup") { + w.Header().Set("Content-Type", "application/json") + // Return entity with ID as float64 + if err := json.NewEncoder(w).Encode(map[string]any{ + "found": []map[string]any{ + { + "entity": map[string]any{ + "key": map[string]any{ + "path": []any{ + map[string]any{ + "kind": "TestKind", + "id": float64(67890), // ID as float64 + }, + }, + }, + "properties": map[string]any{ + "value": map[string]any{"integerValue": "42"}, + }, + }, + }, + }, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + type TestEntity struct { + Value int64 `datastore:"value"` + } + + ctx := context.Background() + key := datastore.IDKey("TestKind", 67890, nil) + var entity TestEntity + err = client.Get(ctx, key, &entity) + if err != nil { + t.Fatalf("Get with float64 ID key failed: %v", err) + } + + if entity.Value != 42 { + t.Errorf("expected value 42, got %d", entity.Value) + } +} + +func TestGetWithInvalidStringIDFormat(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":lookup") { + w.Header().Set("Content-Type", "application/json") + // Return entity with invalid ID string format + if err := json.NewEncoder(w).Encode(map[string]any{ + "found": []map[string]any{ + { + "entity": map[string]any{ + "key": map[string]any{ + "path": []any{ + map[string]any{ + "kind": "TestKind", + "id": "not-a-number", // Invalid ID format + }, + }, + }, + "properties": map[string]any{ + "name": map[string]any{"stringValue": "test"}, + }, + }, + }, + }, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + type TestEntity struct { + Name string `datastore:"name"` + } + + ctx := context.Background() + key := datastore.IDKey("TestKind", 12345, nil) + var entity TestEntity + err = client.Get(ctx, key, &entity) + // May or may not error depending on parsing behavior + if err != nil { + t.Logf("Get with invalid string ID format failed: %v", err) + } else { + t.Logf("Get with invalid string ID format succeeded unexpectedly") + } +} + +func TestGetJSONUnmarshalError(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":lookup") { + // Return malformed JSON + w.Header().Set("Content-Type", "application/json") + if _, err := w.Write([]byte("not valid json")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + key := datastore.NameKey("Test", "test-key", nil) + var entity testEntity + + err = client.Get(ctx, key, &entity) + if err == nil { + t.Error("expected error with malformed JSON") + } +} + +func TestPutMultiLengthValidation(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + keys := []*datastore.Key{datastore.NameKey("Test", "key1", nil)} + entities := []testEntity{{Name: "test1"}, {Name: "test2"}} + + _, err := client.PutMulti(ctx, keys, entities) + if err == nil { + t.Error("expected error with mismatched lengths") + } +} + +func TestDeleteMultiMixedResults(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":commit") { + w.Header().Set("Content-Type", "application/json") + // Return empty mutation results + if err := json.NewEncoder(w).Encode(map[string]any{ + "mutationResults": []any{}, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + keys := []*datastore.Key{ + datastore.NameKey("Test", "key1", nil), + datastore.NameKey("Test", "key2", nil), + } + + err = client.DeleteMulti(ctx, keys) + // May or may not error depending on implementation + if err != nil { + t.Logf("DeleteMulti with mismatched results: %v", err) + } +} + +func TestBackwardsCompatibility(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Test 1: Close() method exists and can be called (even though it's a no-op) + t.Run("Close", func(t *testing.T) { + err := client.Close() + if err != nil { + t.Errorf("Close() returned error: %v", err) + } + }) + + // Test 2: RunInTransaction returns (*Commit, error) + t.Run("RunInTransactionSignature", func(t *testing.T) { + key := datastore.NameKey("TestKind", "test-tx-compat", nil) + + commit, err := client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + entity := &testEntity{ + Name: "transaction test", + Count: 100, + Active: true, + Score: 99.9, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + } + _, err := tx.Put(key, entity) + return err + }) + if err != nil { + t.Fatalf("RunInTransaction failed: %v", err) + } + + if commit == nil { + t.Error("Expected non-nil Commit, got nil") + } + }) + + // Test 3: GetAll() method retrieves entities and returns keys + t.Run("GetAll", func(t *testing.T) { + // Setup: Create some test entities + entities := []testEntity{ + {Name: "entity1", Count: 1, Active: true, Score: 1.1, UpdatedAt: time.Now().UTC().Truncate(time.Microsecond)}, + {Name: "entity2", Count: 2, Active: false, Score: 2.2, UpdatedAt: time.Now().UTC().Truncate(time.Microsecond)}, + {Name: "entity3", Count: 3, Active: true, Score: 3.3, UpdatedAt: time.Now().UTC().Truncate(time.Microsecond)}, + } + + keys := []*datastore.Key{ + datastore.NameKey("GetAllTest", "key1", nil), + datastore.NameKey("GetAllTest", "key2", nil), + datastore.NameKey("GetAllTest", "key3", nil), + } + + _, err := client.PutMulti(ctx, keys, entities) + if err != nil { + t.Fatalf("PutMulti failed: %v", err) + } + + // Test GetAll + query := datastore.NewQuery("GetAllTest") + var results []testEntity + returnedKeys, err := client.GetAll(ctx, query, &results) + if err != nil { + t.Fatalf("GetAll failed: %v", err) + } + + if len(results) != 3 { + t.Errorf("Expected 3 entities, got %d", len(results)) + } + + if len(returnedKeys) != 3 { + t.Errorf("Expected 3 keys, got %d", len(returnedKeys)) + } + + // Verify entities were properly decoded + foundNames := make(map[string]bool) + for _, entity := range results { + foundNames[entity.Name] = true + } + + for _, expectedName := range []string{"entity1", "entity2", "entity3"} { + if !foundNames[expectedName] { + t.Errorf("Expected to find entity %s, but didn't", expectedName) + } + } + + // Verify keys match entities + for i, key := range returnedKeys { + if key.Kind != "GetAllTest" { + t.Errorf("Key %d has wrong kind: %s", i, key.Kind) + } + } + }) + + // Test 4: GetAll with limit + t.Run("GetAllWithLimit", func(t *testing.T) { + query := datastore.NewQuery("GetAllTest").Limit(2) + var results []testEntity + returnedKeys, err := client.GetAll(ctx, query, &results) + if err != nil { + t.Fatalf("GetAll with limit failed: %v", err) + } + + if len(results) != 2 { + t.Errorf("Expected 2 entities with limit, got %d", len(results)) + } + + if len(returnedKeys) != 2 { + t.Errorf("Expected 2 keys with limit, got %d", len(returnedKeys)) + } + }) +} + +func TestArraySliceSupport(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("StringSlice", func(t *testing.T) { + key := datastore.NameKey("ArrayTest", "strings", nil) + entity := &arrayEntity{ + Strings: []string{"hello", "world", "test"}, + } + + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put with string slice failed: %v", err) + } + + var result arrayEntity + if err := client.Get(ctx, key, &result); err != nil { + t.Fatalf("Get failed: %v", err) + } + + if len(result.Strings) != 3 { + t.Errorf("Expected 3 strings, got %d", len(result.Strings)) + } + if result.Strings[0] != "hello" || result.Strings[1] != "world" || result.Strings[2] != "test" { + t.Errorf("String slice values incorrect: %v", result.Strings) + } + }) + + t.Run("Int64Slice", func(t *testing.T) { + key := datastore.NameKey("ArrayTest", "ints", nil) + entity := &arrayEntity{ + Ints: []int64{1, 2, 3, 42, 100}, + } + + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put with int64 slice failed: %v", err) + } + + var result arrayEntity + if err := client.Get(ctx, key, &result); err != nil { + t.Fatalf("Get failed: %v", err) + } + + if len(result.Ints) != 5 { + t.Errorf("Expected 5 ints, got %d", len(result.Ints)) + } + if result.Ints[3] != 42 { + t.Errorf("Expected Ints[3] = 42, got %d", result.Ints[3]) + } + }) + + t.Run("Float64Slice", func(t *testing.T) { + key := datastore.NameKey("ArrayTest", "floats", nil) + entity := &arrayEntity{ + Floats: []float64{1.1, 2.2, 3.3}, + } + + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put with float64 slice failed: %v", err) + } + + var result arrayEntity + if err := client.Get(ctx, key, &result); err != nil { + t.Fatalf("Get failed: %v", err) + } + + if len(result.Floats) != 3 { + t.Errorf("Expected 3 floats, got %d", len(result.Floats)) + } + if result.Floats[0] != 1.1 { + t.Errorf("Expected Floats[0] = 1.1, got %f", result.Floats[0]) + } + }) + + t.Run("BoolSlice", func(t *testing.T) { + key := datastore.NameKey("ArrayTest", "bools", nil) + entity := &arrayEntity{ + Bools: []bool{true, false, true}, + } + + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put with bool slice failed: %v", err) + } + + var result arrayEntity + if err := client.Get(ctx, key, &result); err != nil { + t.Fatalf("Get failed: %v", err) + } + + if len(result.Bools) != 3 { + t.Errorf("Expected 3 bools, got %d", len(result.Bools)) + } + if result.Bools[0] != true || result.Bools[1] != false { + t.Errorf("Bool slice values incorrect: %v", result.Bools) + } + }) + + t.Run("EmptySlices", func(t *testing.T) { + key := datastore.NameKey("ArrayTest", "empty", nil) + entity := &arrayEntity{ + Strings: []string{}, + Ints: []int64{}, + } + + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put with empty slices failed: %v", err) + } + + var result arrayEntity + if err := client.Get(ctx, key, &result); err != nil { + t.Fatalf("Get failed: %v", err) + } + + if result.Strings == nil || len(result.Strings) != 0 { + t.Errorf("Expected empty string slice, got %v", result.Strings) + } + if result.Ints == nil || len(result.Ints) != 0 { + t.Errorf("Expected empty int slice, got %v", result.Ints) + } + }) + + t.Run("MixedArrays", func(t *testing.T) { + key := datastore.NameKey("ArrayTest", "mixed", nil) + entity := &arrayEntity{ + Strings: []string{"a", "b"}, + Ints: []int64{10, 20, 30}, + Floats: []float64{1.5}, + Bools: []bool{true, false, true, false}, + } + + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put with mixed arrays failed: %v", err) + } + + var result arrayEntity + if err := client.Get(ctx, key, &result); err != nil { + t.Fatalf("Get failed: %v", err) + } + + if len(result.Strings) != 2 || len(result.Ints) != 3 || len(result.Floats) != 1 || len(result.Bools) != 4 { + t.Errorf("Mixed array lengths incorrect: strings=%d, ints=%d, floats=%d, bools=%d", + len(result.Strings), len(result.Ints), len(result.Floats), len(result.Bools)) + } + }) +} + +func TestAllocateIDs(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("AllocateIncompleteKeys", func(t *testing.T) { + keys := []*datastore.Key{ + datastore.IncompleteKey("Task", nil), + datastore.IncompleteKey("Task", nil), + datastore.IncompleteKey("Task", nil), + } + + allocated, err := client.AllocateIDs(ctx, keys) + if err != nil { + t.Fatalf("AllocateIDs failed: %v", err) + } + + if len(allocated) != 3 { + t.Errorf("Expected 3 allocated keys, got %d", len(allocated)) + } + + for i, key := range allocated { + if key.Incomplete() { + t.Errorf("Key %d is still incomplete", i) + } + if key.ID == 0 { + t.Errorf("Key %d has zero ID", i) + } + } + }) + + t.Run("AllocateMixedKeys", func(t *testing.T) { + keys := []*datastore.Key{ + datastore.NameKey("Task", "complete", nil), + datastore.IncompleteKey("Task", nil), + datastore.IDKey("Task", 123, nil), + datastore.IncompleteKey("Task", nil), + } + + allocated, err := client.AllocateIDs(ctx, keys) + if err != nil { + t.Fatalf("AllocateIDs with mixed keys failed: %v", err) + } + + if len(allocated) != 4 { + t.Errorf("Expected 4 keys, got %d", len(allocated)) + } + + // First key should still be the named key + if allocated[0].Name != "complete" { + t.Errorf("First key should be unchanged") + } + + // Second key should now have an ID + if allocated[1].Incomplete() { + t.Errorf("Second key should be allocated") + } + + // Third key should be unchanged + if allocated[2].ID != 123 { + t.Errorf("Third key should be unchanged") + } + + // Fourth key should now have an ID + if allocated[3].Incomplete() { + t.Errorf("Fourth key should be allocated") + } + }) + + t.Run("AllocateEmptySlice", func(t *testing.T) { + keys := []*datastore.Key{} + + allocated, err := client.AllocateIDs(ctx, keys) + if err != nil { + t.Fatalf("AllocateIDs with empty slice failed: %v", err) + } + + if len(allocated) != 0 { + t.Errorf("Expected empty slice, got %d keys", len(allocated)) + } + }) + + t.Run("AllocateAllCompleteKeys", func(t *testing.T) { + keys := []*datastore.Key{ + datastore.NameKey("Task", "key1", nil), + datastore.IDKey("Task", 100, nil), + } + + allocated, err := client.AllocateIDs(ctx, keys) + if err != nil { + t.Fatalf("AllocateIDs with complete keys failed: %v", err) + } + + if len(allocated) != 2 { + t.Errorf("Expected 2 keys, got %d", len(allocated)) + } + + // Keys should be unchanged + if allocated[0].Name != "key1" || allocated[1].ID != 100 { + t.Errorf("Complete keys should be unchanged") + } + }) +} diff --git a/pkg/datastore/query.go b/pkg/datastore/query.go new file mode 100644 index 0000000..b45cebf --- /dev/null +++ b/pkg/datastore/query.go @@ -0,0 +1,557 @@ +package datastore + +import ( + "context" + "encoding/json" + "errors" + "fmt" + neturl "net/url" + "reflect" + "strconv" + "strings" + + "github.com/codeGROOVE-dev/ds9/auth" +) + +// Query represents a Datastore query. +type Query struct { + ancestor *Key + kind string + filters []queryFilter + orders []queryOrder + projection []string + distinctOn []string + namespace string + startCursor Cursor + endCursor Cursor + limit int + offset int + keysOnly bool +} + +type queryFilter struct { + value any + property string + operator string +} + +type queryOrder struct { + property string + direction string // "ASCENDING" or "DESCENDING" +} + +// NewQuery creates a new query for the given kind. +func NewQuery(kind string) *Query { + return &Query{ + kind: kind, + } +} + +// KeysOnly configures the query to return only keys, not full entities. +func (q *Query) KeysOnly() *Query { + q.keysOnly = true + return q +} + +// Limit sets the maximum number of results to return. +func (q *Query) Limit(limit int) *Query { + q.limit = limit + return q +} + +// Offset sets the number of results to skip before returning. +// API compatible with cloud.google.com/go/datastore. +func (q *Query) Offset(offset int) *Query { + q.offset = offset + return q +} + +// Filter adds a property filter to the query. +// The filterStr should be in the format "Property Operator" (e.g., "Count >", "Name ="). +// Deprecated: Use FilterField instead. API compatible with cloud.google.com/go/datastore. +func (q *Query) Filter(filterStr string, value any) *Query { + // Parse the filter string to extract property and operator + parts := strings.Fields(filterStr) + if len(parts) != 2 { + // Invalid filter format, but we'll be lenient + return q + } + + property := parts[0] + op := parts[1] + + operator, ok := operatorMap[op] + if !ok { + operator = "EQUAL" + } + + q.filters = append(q.filters, queryFilter{ + property: property, + operator: operator, + value: value, + }) + + return q +} + +// FilterField adds a property filter to the query with explicit operator. +// API compatible with cloud.google.com/go/datastore. +func (q *Query) FilterField(fieldName, operator string, value any) *Query { + dsOperator, ok := operatorMap[operator] + if !ok { + dsOperator = operator // Use as-is if not in map (might already be EQUAL, etc.) + } + + q.filters = append(q.filters, queryFilter{ + property: fieldName, + operator: dsOperator, + value: value, + }) + + return q +} + +// Order sets the order in which results are returned. +// Prefix the property name with "-" for descending order (e.g., "-Created"). +// API compatible with cloud.google.com/go/datastore. +func (q *Query) Order(fieldName string) *Query { + direction := "ASCENDING" + property := fieldName + + if strings.HasPrefix(fieldName, "-") { + direction = "DESCENDING" + property = fieldName[1:] + } + + q.orders = append(q.orders, queryOrder{ + property: property, + direction: direction, + }) + + return q +} + +// Ancestor sets an ancestor filter for the query. +// Only entities with the given ancestor will be returned. +// API compatible with cloud.google.com/go/datastore. +func (q *Query) Ancestor(ancestor *Key) *Query { + q.ancestor = ancestor + return q +} + +// Project sets the fields to be projected (returned) in the query results. +// API compatible with cloud.google.com/go/datastore. +func (q *Query) Project(fieldNames ...string) *Query { + q.projection = fieldNames + return q +} + +// Distinct marks the query to return only distinct results. +// This is equivalent to DistinctOn with all projected fields. +// API compatible with cloud.google.com/go/datastore. +func (q *Query) Distinct() *Query { + // Distinct without fields means distinct on projection + // This will be handled in buildQueryMap + if len(q.projection) > 0 { + q.distinctOn = q.projection + } + return q +} + +// DistinctOn returns a query that removes duplicates based on the given field names. +// API compatible with cloud.google.com/go/datastore. +func (q *Query) DistinctOn(fieldNames ...string) *Query { + q.distinctOn = fieldNames + return q +} + +// Namespace sets the namespace for the query. +// API compatible with cloud.google.com/go/datastore. +func (q *Query) Namespace(ns string) *Query { + q.namespace = ns + return q +} + +// Start sets the starting cursor for the query results. +// API compatible with cloud.google.com/go/datastore. +func (q *Query) Start(c Cursor) *Query { + q.startCursor = c + return q +} + +// End sets the ending cursor for the query results. +// API compatible with cloud.google.com/go/datastore. +func (q *Query) End(c Cursor) *Query { + q.endCursor = c + return q +} + +// buildQueryMap creates a Datastore API query map from a Query object. +func buildQueryMap(query *Query) map[string]any { + queryMap := map[string]any{ + "kind": []map[string]any{{"name": query.kind}}, + } + + // Add namespace via partition ID if specified + if query.namespace != "" { + queryMap["partitionId"] = map[string]any{ + "namespaceId": query.namespace, + } + } + + // Add filters + if len(query.filters) > 0 { + var compositeFilters []map[string]any + for _, f := range query.filters { + encodedVal, err := encodeValue(f.value) + if err != nil { + // Skip invalid filters + continue + } + compositeFilters = append(compositeFilters, map[string]any{ + "propertyFilter": map[string]any{ + "property": map[string]string{"name": f.property}, + "op": f.operator, + "value": encodedVal, + }, + }) + } + + if len(compositeFilters) == 1 { + queryMap["filter"] = compositeFilters[0] + } else if len(compositeFilters) > 1 { + queryMap["filter"] = map[string]any{ + "compositeFilter": map[string]any{ + "op": "AND", + "filters": compositeFilters, + }, + } + } + } + + // Add ancestor filter + if query.ancestor != nil { + ancestorFilter := map[string]any{ + "propertyFilter": map[string]any{ + "property": map[string]string{"name": "__key__"}, + "op": "HAS_ANCESTOR", + "value": map[string]any{"keyValue": keyToJSON(query.ancestor)}, + }, + } + + // Combine with existing filters if present + if existingFilter, ok := queryMap["filter"]; ok { + existingMap, ok := existingFilter.(map[string]any) + if !ok { + // Skip if filter is invalid + queryMap["filter"] = ancestorFilter + } else { + queryMap["filter"] = map[string]any{ + "compositeFilter": map[string]any{ + "op": "AND", + "filters": []map[string]any{existingMap, ancestorFilter}, + }, + } + } + } else { + queryMap["filter"] = ancestorFilter + } + } + + // Add ordering + if len(query.orders) > 0 { + var orders []map[string]any + for _, o := range query.orders { + orders = append(orders, map[string]any{ + "property": map[string]string{"name": o.property}, + "direction": o.direction, + }) + } + queryMap["order"] = orders + } + + // Add projection + if len(query.projection) > 0 { + var projections []map[string]any + for _, field := range query.projection { + projections = append(projections, map[string]any{ + "property": map[string]string{"name": field}, + }) + } + queryMap["projection"] = projections + } else if query.keysOnly { + // Keys-only projection + queryMap["projection"] = []map[string]any{{"property": map[string]string{"name": "__key__"}}} + } + + // Add distinct on + if len(query.distinctOn) > 0 { + var distinctFields []map[string]any + for _, field := range query.distinctOn { + distinctFields = append(distinctFields, map[string]any{ + "property": map[string]string{"name": field}, + }) + } + queryMap["distinctOn"] = distinctFields + } + + // Add limit + if query.limit > 0 { + queryMap["limit"] = query.limit + } + + // Add offset + if query.offset > 0 { + queryMap["offset"] = query.offset + } + + // Add cursors + if query.startCursor != "" { + queryMap["startCursor"] = string(query.startCursor) + } + if query.endCursor != "" { + queryMap["endCursor"] = string(query.endCursor) + } + + return queryMap +} + +// AllKeys returns all keys matching the query. +// This is a convenience method for KeysOnly queries. +func (c *Client) AllKeys(ctx context.Context, q *Query) ([]*Key, error) { + if !q.keysOnly { + c.logger.WarnContext(ctx, "AllKeys called on non-KeysOnly query") + return nil, errors.New("AllKeys requires KeysOnly query") + } + + c.logger.DebugContext(ctx, "querying for keys", "kind", q.kind, "limit", q.limit) + + token, err := auth.AccessToken(ctx) + if err != nil { + c.logger.ErrorContext(ctx, "failed to get access token", "error", err) + return nil, fmt.Errorf("failed to get access token: %w", err) + } + + query := buildQueryMap(q) + + reqBody := map[string]any{"query": query} + if c.databaseID != "" { + reqBody["databaseId"] = c.databaseID + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + // URL-encode project ID to prevent injection attacks + reqURL := fmt.Sprintf("%s/projects/%s:runQuery", c.baseURL, neturl.PathEscape(c.projectID)) + body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) + if err != nil { + c.logger.ErrorContext(ctx, "query request failed", "error", err, "kind", q.kind) + return nil, err + } + + var result struct { + Batch struct { + EntityResults []struct { + Entity map[string]any `json:"entity"` + } `json:"entityResults"` + } `json:"batch"` + } + + if err := json.Unmarshal(body, &result); err != nil { + c.logger.ErrorContext(ctx, "failed to parse response", "error", err) + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + keys := make([]*Key, 0, len(result.Batch.EntityResults)) + for _, er := range result.Batch.EntityResults { + key, err := keyFromJSON(er.Entity["key"]) + if err != nil { + c.logger.ErrorContext(ctx, "failed to parse key from response", "error", err) + return nil, err + } + keys = append(keys, key) + } + + c.logger.DebugContext(ctx, "query completed successfully", "kind", q.kind, "keys_found", len(keys)) + return keys, nil +} + +// GetAll retrieves all entities matching the query and stores them in dst. +// dst must be a pointer to a slice of structs. +// Returns the keys of the retrieved entities and any error. +// This matches the API of cloud.google.com/go/datastore. +func (c *Client) GetAll(ctx context.Context, query *Query, dst any) ([]*Key, error) { + c.logger.DebugContext(ctx, "querying for entities", "kind", query.kind, "limit", query.limit) + + token, err := auth.AccessToken(ctx) + if err != nil { + c.logger.ErrorContext(ctx, "failed to get access token", "error", err) + return nil, fmt.Errorf("failed to get access token: %w", err) + } + + queryObj := buildQueryMap(query) + + reqBody := map[string]any{"query": queryObj} + if c.databaseID != "" { + reqBody["databaseId"] = c.databaseID + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + // URL-encode project ID to prevent injection attacks + reqURL := fmt.Sprintf("%s/projects/%s:runQuery", c.baseURL, neturl.PathEscape(c.projectID)) + body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) + if err != nil { + c.logger.ErrorContext(ctx, "query request failed", "error", err, "kind", query.kind) + return nil, err + } + + var result struct { + Batch struct { + EntityResults []struct { + Entity map[string]any `json:"entity"` + } `json:"entityResults"` + } `json:"batch"` + } + + if err := json.Unmarshal(body, &result); err != nil { + c.logger.ErrorContext(ctx, "failed to parse response", "error", err) + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // Verify dst is a pointer to slice + v := reflect.ValueOf(dst) + if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Slice { + return nil, errors.New("dst must be a pointer to slice") + } + + sliceType := v.Elem().Type() + elemType := sliceType.Elem() + + // Create new slice of correct size + slice := reflect.MakeSlice(sliceType, 0, len(result.Batch.EntityResults)) + keys := make([]*Key, 0, len(result.Batch.EntityResults)) + + for _, er := range result.Batch.EntityResults { + // Extract key + key, err := keyFromJSON(er.Entity["key"]) + if err != nil { + c.logger.ErrorContext(ctx, "failed to parse key from response", "error", err) + return nil, err + } + keys = append(keys, key) + + // Decode entity + elem := reflect.New(elemType).Elem() + if err := decodeEntity(er.Entity, elem.Addr().Interface()); err != nil { + c.logger.ErrorContext(ctx, "failed to decode entity", "error", err) + return nil, err + } + slice = reflect.Append(slice, elem) + } + + v.Elem().Set(slice) + c.logger.DebugContext(ctx, "query completed successfully", "kind", query.kind, "entities_found", len(keys)) + return keys, nil +} + +// Count returns the number of entities matching the query. +// Deprecated: Use aggregation queries with RunAggregationQuery instead. +// API compatible with cloud.google.com/go/datastore. +func (c *Client) Count(ctx context.Context, q *Query) (int, error) { + c.logger.DebugContext(ctx, "counting entities", "kind", q.kind) + + token, err := auth.AccessToken(ctx) + if err != nil { + c.logger.ErrorContext(ctx, "failed to get access token", "error", err) + return 0, fmt.Errorf("failed to get access token: %w", err) + } + + // Build aggregation query with COUNT + queryObj := buildQueryMap(q) + aggregationQuery := map[string]any{ + "aggregations": []map[string]any{ + { + "alias": "total", + "count": map[string]any{}, + }, + }, + "nestedQuery": queryObj, + } + + reqBody := map[string]any{ + "aggregationQuery": aggregationQuery, + } + if c.databaseID != "" { + reqBody["databaseId"] = c.databaseID + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) + return 0, fmt.Errorf("failed to marshal request: %w", err) + } + + // URL-encode project ID to prevent injection attacks + reqURL := fmt.Sprintf("%s/projects/%s:runAggregationQuery", c.baseURL, neturl.PathEscape(c.projectID)) + body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) + if err != nil { + c.logger.ErrorContext(ctx, "count query failed", "error", err, "kind", q.kind) + return 0, err + } + + var result struct { + Batch struct { + AggregationResults []struct { + AggregateProperties map[string]struct { + IntegerValue string `json:"integerValue"` + } `json:"aggregateProperties"` + } `json:"aggregationResults"` + } `json:"batch"` + } + + if err := json.Unmarshal(body, &result); err != nil { + c.logger.ErrorContext(ctx, "failed to parse response", "error", err) + return 0, fmt.Errorf("failed to parse count response: %w", err) + } + + if len(result.Batch.AggregationResults) == 0 { + c.logger.DebugContext(ctx, "no results returned", "kind", q.kind) + return 0, nil + } + + // Extract count from total aggregation + countVal, ok := result.Batch.AggregationResults[0].AggregateProperties["total"] + if !ok { + c.logger.ErrorContext(ctx, "count not found in response") + return 0, errors.New("count not found in aggregation response") + } + + count, err := strconv.Atoi(countVal.IntegerValue) + if err != nil { + c.logger.ErrorContext(ctx, "failed to parse count", "error", err, "value", countVal.IntegerValue) + return 0, fmt.Errorf("failed to parse count: %w", err) + } + + c.logger.DebugContext(ctx, "count completed successfully", "kind", q.kind, "count", count) + return count, nil +} + +// Run executes the query and returns an iterator for the results. +// API compatible with cloud.google.com/go/datastore. +func (c *Client) Run(ctx context.Context, q *Query) *Iterator { + return &Iterator{ + ctx: ctx, + client: c, + query: q, + fetchNext: true, + } +} diff --git a/pkg/datastore/query_coverage_test.go b/pkg/datastore/query_coverage_test.go new file mode 100644 index 0000000..139111e --- /dev/null +++ b/pkg/datastore/query_coverage_test.go @@ -0,0 +1,386 @@ +package datastore_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/codeGROOVE-dev/ds9/pkg/datastore" +) + +func TestCount_Coverage(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("CountEmptyKind", func(t *testing.T) { + // Test counting entities in an empty kind + q := datastore.NewQuery("CountTest") + count, err := client.Count(ctx, q) + if err != nil { + t.Fatalf("Count failed: %v", err) + } + if count != 0 { + t.Errorf("Expected count 0 for empty kind, got %d", count) + } + }) + + t.Run("CountWithEntities", func(t *testing.T) { + // Create test entities + for i := range 5 { + key := datastore.IDKey("CountTest2", int64(i+1), nil) + entity := &testEntity{ + Name: "test", + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Count entities + q := datastore.NewQuery("CountTest2") + count, err := client.Count(ctx, q) + if err != nil { + t.Fatalf("Count failed: %v", err) + } + if count != 5 { + t.Errorf("Expected count 5, got %d", count) + } + }) + + t.Run("CountWithFilter", func(t *testing.T) { + // Create test entities with different counts + for i := range 5 { + key := datastore.IDKey("CountTest3", int64(i+1), nil) + entity := &testEntity{ + Name: "test", + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Count entities with filter + q := datastore.NewQuery("CountTest3").Filter("count >", 2) + count, err := client.Count(ctx, q) + if err != nil { + t.Fatalf("Count with filter failed: %v", err) + } + if count != 2 { + t.Errorf("Expected count 2 for filtered query, got %d", count) + } + }) +} + +func TestGetAll_Coverage(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("GetAllEmpty", func(t *testing.T) { + // Test GetAll on empty kind + q := datastore.NewQuery("GetAllTest") + var entities []testEntity + keys, err := client.GetAll(ctx, q, &entities) + if err != nil { + t.Fatalf("GetAll failed: %v", err) + } + if len(keys) != 0 { + t.Errorf("Expected 0 keys, got %d", len(keys)) + } + if len(entities) != 0 { + t.Errorf("Expected 0 entities, got %d", len(entities)) + } + }) + + t.Run("GetAllWithEntities", func(t *testing.T) { + // Create test entities + expectedCount := 7 + for i := range expectedCount { + key := datastore.IDKey("GetAllTest2", int64(i+1), nil) + entity := &testEntity{ + Name: "test", + Count: int64(i), + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Get all entities + q := datastore.NewQuery("GetAllTest2") + var entities []testEntity + keys, err := client.GetAll(ctx, q, &entities) + if err != nil { + t.Fatalf("GetAll failed: %v", err) + } + if len(keys) != expectedCount { + t.Errorf("Expected %d keys, got %d", expectedCount, len(keys)) + } + if len(entities) != expectedCount { + t.Errorf("Expected %d entities, got %d", expectedCount, len(entities)) + } + }) + + t.Run("GetAllWithLimit", func(t *testing.T) { + // Create test entities + for i := range 5 { + key := datastore.IDKey("GetAllTest3", int64(i+1), nil) + entity := &testEntity{ + Name: "test", + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Get all entities with limit + q := datastore.NewQuery("GetAllTest3").Limit(3) + var entities []testEntity + keys, err := client.GetAll(ctx, q, &entities) + if err != nil { + t.Fatalf("GetAll with limit failed: %v", err) + } + if len(keys) != 3 { + t.Errorf("Expected 3 keys, got %d", len(keys)) + } + if len(entities) != 3 { + t.Errorf("Expected 3 entities, got %d", len(entities)) + } + }) + + t.Run("GetAllWithFilter", func(t *testing.T) { + // Create test entities + for i := range 5 { + key := datastore.IDKey("GetAllTest4", int64(i+1), nil) + entity := &testEntity{ + Name: "test", + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Get all entities with filter - mock now supports filtering + q := datastore.NewQuery("GetAllTest4").Filter("count >=", 3) + var entities []testEntity + keys, err := client.GetAll(ctx, q, &entities) + if err != nil { + t.Fatalf("GetAll with filter failed: %v", err) + } + // Should get entities with count >= 3 (3, 4) + if len(keys) != 2 { + t.Errorf("Expected 2 keys with filter, got %d", len(keys)) + } + if len(entities) != 2 { + t.Errorf("Expected 2 entities with filter, got %d", len(entities)) + } + }) + + t.Run("GetAllErrorInvalidDst", func(t *testing.T) { + // Test error case: dst is not a pointer to slice + q := datastore.NewQuery("GetAllTest5") + var entity testEntity + _, err := client.GetAll(ctx, q, &entity) // Pass pointer to struct instead of slice + if err == nil { + t.Error("Expected error for invalid dst, got nil") + } + if !errors.Is(err, errors.New("dst must be a pointer to slice")) && err.Error() != "dst must be a pointer to slice" { + t.Errorf("Expected 'dst must be a pointer to slice' error, got: %v", err) + } + }) + + t.Run("GetAllErrorNotPointer", func(t *testing.T) { + // Test error case: dst is not a pointer + q := datastore.NewQuery("GetAllTest6") + var entities []testEntity + _, err := client.GetAll(ctx, q, entities) // Pass slice instead of pointer to slice + if err == nil { + t.Error("Expected error for non-pointer dst, got nil") + } + }) +} + +func TestAllKeys_Coverage(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("AllKeysEmpty", func(t *testing.T) { + // Test AllKeys on empty kind + q := datastore.NewQuery("AllKeysTest").KeysOnly() + keys, err := client.AllKeys(ctx, q) + if err != nil { + t.Fatalf("AllKeys failed: %v", err) + } + if len(keys) != 0 { + t.Errorf("Expected 0 keys, got %d", len(keys)) + } + }) + + t.Run("AllKeysWithEntities", func(t *testing.T) { + // Create test entities + expectedCount := 6 + for i := range expectedCount { + key := datastore.IDKey("AllKeysTest2", int64(i+1), nil) + entity := &testEntity{ + Name: "test", + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Get all keys + q := datastore.NewQuery("AllKeysTest2").KeysOnly() + keys, err := client.AllKeys(ctx, q) + if err != nil { + t.Fatalf("AllKeys failed: %v", err) + } + if len(keys) != expectedCount { + t.Errorf("Expected %d keys, got %d", expectedCount, len(keys)) + } + + // Verify keys are valid + for i, key := range keys { + if key.Kind != "AllKeysTest2" { + t.Errorf("Key %d: expected kind 'AllKeysTest2', got '%s'", i, key.Kind) + } + if key.Incomplete() { + t.Errorf("Key %d is incomplete", i) + } + } + }) +} + +func TestDeleteMulti_Coverage(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("DeleteMultiSingle", func(t *testing.T) { + // Create a test entity + key := datastore.NameKey("DeleteMultiTest", "test1", nil) + entity := &testEntity{Name: "test", Count: 42} + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Delete using DeleteMulti + err := client.DeleteMulti(ctx, []*datastore.Key{key}) + if err != nil { + t.Fatalf("DeleteMulti failed: %v", err) + } + + // Verify entity is deleted + var result testEntity + err = client.Get(ctx, key, &result) + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("Expected ErrNoSuchEntity, got %v", err) + } + }) + + t.Run("DeleteMultiMultiple", func(t *testing.T) { + // Create multiple test entities + keys := []*datastore.Key{ + datastore.NameKey("DeleteMultiTest2", "test1", nil), + datastore.NameKey("DeleteMultiTest2", "test2", nil), + datastore.NameKey("DeleteMultiTest2", "test3", nil), + } + for _, key := range keys { + entity := &testEntity{Name: "test", Count: 42} + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Delete using DeleteMulti + err := client.DeleteMulti(ctx, keys) + if err != nil { + t.Fatalf("DeleteMulti failed: %v", err) + } + + // Verify all entities are deleted + for _, key := range keys { + var result testEntity + err = client.Get(ctx, key, &result) + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("Expected ErrNoSuchEntity for key %v, got %v", key, err) + } + } + }) +} + +func TestPutMulti_Coverage(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("PutMultiMultiple", func(t *testing.T) { + // Test PutMulti with multiple complete keys + keys := []*datastore.Key{ + datastore.NameKey("PutMultiTest", "test1", nil), + datastore.NameKey("PutMultiTest", "test2", nil), + } + entities := []testEntity{ + {Name: "test1", Count: 1}, + {Name: "test2", Count: 2}, + } + + resultKeys, err := client.PutMulti(ctx, keys, entities) + if err != nil { + t.Fatalf("PutMulti failed: %v", err) + } + + if len(resultKeys) != len(keys) { + t.Errorf("Expected %d result keys, got %d", len(keys), len(resultKeys)) + } + + // Verify entities were stored + for i, key := range resultKeys { + var retrieved testEntity + if err := client.Get(ctx, key, &retrieved); err != nil { + t.Errorf("Failed to retrieve entity %d: %v", i, err) + } + } + }) +} + +func TestGetMulti_Coverage(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("GetMultiMixedResults", func(t *testing.T) { + // Create one entity, leave another missing + key1 := datastore.NameKey("GetMultiTest", "exists", nil) + key2 := datastore.NameKey("GetMultiTest", "missing", nil) + + entity1 := &testEntity{Name: "test1", Count: 42} + if _, err := client.Put(ctx, key1, entity1); err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Get both keys + entities := []testEntity{{}, {}} + err := client.GetMulti(ctx, []*datastore.Key{key1, key2}, entities) + + // Should get MultiError with one ErrNoSuchEntity + if err == nil { + t.Error("Expected MultiError, got nil") + } + }) +} diff --git a/pkg/datastore/query_test.go b/pkg/datastore/query_test.go new file mode 100644 index 0000000..1c8abcc --- /dev/null +++ b/pkg/datastore/query_test.go @@ -0,0 +1,559 @@ +package datastore_test + +import ( + "context" + "fmt" + "strings" + "testing" + "time" + + "github.com/codeGROOVE-dev/ds9/pkg/datastore" +) + +func TestQueryOperations(t *testing.T) { + // Test query builder methods + query := datastore.NewQuery("TestKind") + + if query.KeysOnly().KeysOnly() == nil { + t.Error("KeysOnly() should be chainable") + } + + if query.Limit(10).Limit(20) == nil { + t.Error("Limit() should be chainable") + } +} + +func TestEmptyQuery(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Query for keys when no entities exist + query := datastore.NewQuery("NonExistent").KeysOnly() + keys, err := client.AllKeys(ctx, query) + if err != nil { + t.Fatalf("AllKeys failed: %v", err) + } + + if len(keys) != 0 { + t.Errorf("expected 0 keys, got %d", len(keys)) + } +} + +func TestQueryWithLimitZero(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Store some entities + for i := range 5 { + key := datastore.NameKey("LimitKind", string(rune('a'+i)), nil) + entity := testEntity{Name: "item", Count: int64(i)} + if _, err := client.Put(ctx, key, &entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Query with limit 0 (should return all) + query := datastore.NewQuery("LimitKind").KeysOnly().Limit(0) + keys, err := client.AllKeys(ctx, query) + if err != nil { + t.Fatalf("AllKeys failed: %v", err) + } + + if len(keys) == 0 { + t.Error("expected keys, got 0 (limit 0 should mean unlimited)") + } +} + +func TestQueryWithLimitLessThanResults(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Store 10 entities + for i := range 10 { + key := datastore.NameKey("LimitKind2", string(rune('a'+i)), nil) + entity := testEntity{Name: "item", Count: int64(i)} + if _, err := client.Put(ctx, key, &entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Query with limit 3 + query := datastore.NewQuery("LimitKind2").KeysOnly().Limit(3) + keys, err := client.AllKeys(ctx, query) + if err != nil { + t.Fatalf("AllKeys failed: %v", err) + } + + if len(keys) != 3 { + t.Errorf("expected 3 keys, got %d", len(keys)) + } +} + +func TestQueryNonKeysOnly(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Try to call AllKeys with non-KeysOnly query + query := datastore.NewQuery("TestKind") + _, err := client.AllKeys(ctx, query) + + if err == nil { + t.Error("expected error for non-KeysOnly query") + } + + if !strings.Contains(err.Error(), "KeysOnly") { + t.Errorf("expected error to mention KeysOnly, got: %v", err) + } +} + +func TestQueryWithZeroLimit(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Put entities + for i := range 5 { + key := datastore.NameKey("ZeroLimit", fmt.Sprintf("key-%d", i), nil) + entity := &testEntity{Name: "test", Count: int64(i)} + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Query with limit 0 (should return all) + query := datastore.NewQuery("ZeroLimit").KeysOnly().Limit(0) + keys, err := client.AllKeys(ctx, query) + if err != nil { + t.Fatalf("AllKeys with limit 0 failed: %v", err) + } + + // Limit 0 should mean unlimited + if len(keys) == 0 { + t.Error("expected results with limit 0 (unlimited), got 0") + } +} + +func TestQueryWithVeryLargeLimit(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Put a few entities + for i := range 3 { + key := datastore.NameKey("LargeLimit", fmt.Sprintf("key-%d", i), nil) + entity := &testEntity{Name: "test", Count: int64(i)} + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Query with very large limit + query := datastore.NewQuery("LargeLimit").KeysOnly().Limit(10000) + keys, err := client.AllKeys(ctx, query) + if err != nil { + t.Fatalf("AllKeys with large limit failed: %v", err) + } + + // Should return all 3 + if len(keys) != 3 { + t.Errorf("expected 3 keys, got %d", len(keys)) + } +} + +func TestGetAllEmpty(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + query := datastore.NewQuery("NonExistentKind") + var results []testEntity + + keys, err := client.GetAll(ctx, query, &results) + if err != nil { + t.Fatalf("GetAll failed: %v", err) + } + + if len(results) != 0 { + t.Errorf("Expected 0 entities, got %d", len(results)) + } + + if len(keys) != 0 { + t.Errorf("Expected 0 keys, got %d", len(keys)) + } +} + +func TestGetAllInvalidDst(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + query := datastore.NewQuery("TestKind") + + tests := []struct { + name string + dst any + }{ + {"not a pointer", []testEntity{}}, + {"not a slice", new(testEntity)}, + {"nil", nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := client.GetAll(ctx, query, tt.dst) + if err == nil { + t.Error("Expected error for invalid dst, got nil") + } + }) + } +} + +func TestGetAllSingleEntity(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Create entity + key := datastore.NameKey("SingleGetAll", "single1", nil) + entity := testEntity{ + Name: "single", + Count: 42, + Active: true, + Score: 3.14, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + Notes: "test notes", + } + + _, err := client.Put(ctx, key, &entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Test GetAll + query := datastore.NewQuery("SingleGetAll") + var results []testEntity + keys, err := client.GetAll(ctx, query, &results) + if err != nil { + t.Fatalf("GetAll failed: %v", err) + } + + if len(results) != 1 { + t.Fatalf("Expected 1 entity, got %d", len(results)) + } + + if len(keys) != 1 { + t.Fatalf("Expected 1 key, got %d", len(keys)) + } + + // Verify entity content + if results[0].Name != "single" { + t.Errorf("Expected name 'single', got '%s'", results[0].Name) + } + if results[0].Count != 42 { + t.Errorf("Expected count 42, got %d", results[0].Count) + } + if !results[0].Active { + t.Error("Expected active=true") + } + if results[0].Score != 3.14 { + t.Errorf("Expected score 3.14, got %f", results[0].Score) + } + + // Verify key + if keys[0].Kind != "SingleGetAll" { + t.Errorf("Expected kind 'SingleGetAll', got '%s'", keys[0].Kind) + } + if keys[0].Name != "single1" { + t.Errorf("Expected key name 'single1', got '%s'", keys[0].Name) + } +} + +func TestGetAllMultipleEntities(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Create multiple entities + count := 5 + keys := make([]*datastore.Key, count) + entities := make([]testEntity, count) + + for i := range count { + keys[i] = datastore.NameKey("MultiGetAll", fmt.Sprintf("entity%d", i), nil) + entities[i] = testEntity{ + Name: fmt.Sprintf("entity%d", i), + Count: int64(i * 10), + Active: i%2 == 0, + Score: float64(i) * 1.5, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + } + } + + _, err := client.PutMulti(ctx, keys, entities) + if err != nil { + t.Fatalf("PutMulti failed: %v", err) + } + + // Test GetAll + query := datastore.NewQuery("MultiGetAll") + var results []testEntity + returnedKeys, err := client.GetAll(ctx, query, &results) + if err != nil { + t.Fatalf("GetAll failed: %v", err) + } + + if len(results) != count { + t.Fatalf("Expected %d entities, got %d", count, len(results)) + } + + if len(returnedKeys) != count { + t.Fatalf("Expected %d keys, got %d", count, len(returnedKeys)) + } + + // Verify we got all entities + foundNames := make(map[string]bool) + for _, entity := range results { + foundNames[entity.Name] = true + } + + for i := range count { + expectedName := fmt.Sprintf("entity%d", i) + if !foundNames[expectedName] { + t.Errorf("Missing entity: %s", expectedName) + } + } +} + +func TestGetAllWithLimitVariations(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Setup: Create 10 entities + keys := make([]*datastore.Key, 10) + entities := make([]testEntity, 10) + for i := range 10 { + keys[i] = datastore.NameKey("LimitGetAll", fmt.Sprintf("key%d", i), nil) + entities[i] = testEntity{ + Name: fmt.Sprintf("entity%d", i), + Count: int64(i), + Active: true, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + } + } + + _, err := client.PutMulti(ctx, keys, entities) + if err != nil { + t.Fatalf("PutMulti failed: %v", err) + } + + tests := []struct { + name string + limit int + expected int + }{ + {"Limit 1", 1, 1}, + {"Limit 3", 3, 3}, + {"Limit 5", 5, 5}, + {"Limit 10", 10, 10}, + {"Limit 20 (more than available)", 20, 10}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + query := datastore.NewQuery("LimitGetAll").Limit(tt.limit) + var results []testEntity + keys, err := client.GetAll(ctx, query, &results) + if err != nil { + t.Fatalf("GetAll failed: %v", err) + } + + if len(results) != tt.expected { + t.Errorf("Expected %d entities, got %d", tt.expected, len(results)) + } + + if len(keys) != tt.expected { + t.Errorf("Expected %d keys, got %d", tt.expected, len(keys)) + } + }) + } +} + +func TestCount(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("CountEmptyKind", func(t *testing.T) { + q := datastore.NewQuery("NonExistent") + count, err := client.Count(ctx, q) + if err != nil { + t.Fatalf("Count failed: %v", err) + } + + if count != 0 { + t.Errorf("Expected count 0, got %d", count) + } + }) + + t.Run("CountWithEntities", func(t *testing.T) { + // Create some entities + for i := range 5 { + key := datastore.IDKey("CountTest", int64(i+1), nil) + entity := &testEntity{ + Name: fmt.Sprintf("entity-%d", i), + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + q := datastore.NewQuery("CountTest") + count, err := client.Count(ctx, q) + if err != nil { + t.Fatalf("Count failed: %v", err) + } + + if count != 5 { + t.Errorf("Expected count 5, got %d", count) + } + }) + + t.Run("CountWithFilter", func(t *testing.T) { + // Create entities with different counts + for i := range 10 { + key := datastore.IDKey("FilterCount", int64(i+1), nil) + entity := &testEntity{ + Name: fmt.Sprintf("entity-%d", i), + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Count entities where count >= 5 + q := datastore.NewQuery("FilterCount").Filter("count >=", int64(5)) + count, err := client.Count(ctx, q) + if err != nil { + t.Fatalf("Count with filter failed: %v", err) + } + + // Should return entities with count 5,6,7,8,9 = 5 entities + if count != 5 { + t.Errorf("Expected count 5, got %d", count) + } + }) + + t.Run("CountWithLimit", func(t *testing.T) { + // Create entities + for i := range 10 { + key := datastore.IDKey("LimitCount", int64(i+1), nil) + entity := &testEntity{ + Name: fmt.Sprintf("entity-%d", i), + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Count with limit - note: count should respect limit + q := datastore.NewQuery("LimitCount").Limit(3) + count, err := client.Count(ctx, q) + if err != nil { + t.Fatalf("Count with limit failed: %v", err) + } + + // Mock implementation may return full count, but limit is respected + if count > 10 { + t.Errorf("Count should not exceed actual entities: %d", count) + } + }) +} + +func TestQueryNamespace(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("NamespaceFilter", func(t *testing.T) { + // Note: ds9mock may not fully support namespaces, but we test the API + q := datastore.NewQuery("Task").Namespace("custom-namespace") + + var entities []testEntity + _, err := client.GetAll(ctx, q, &entities) + // Should not error even if namespace is not supported by mock + if err != nil { + t.Logf("GetAll with namespace: %v", err) + } + }) + + t.Run("EmptyNamespace", func(t *testing.T) { + q := datastore.NewQuery("Task").Namespace("") + + var entities []testEntity + _, err := client.GetAll(ctx, q, &entities) + if err != nil { + t.Logf("GetAll with empty namespace: %v", err) + } + }) +} + +func TestQueryDistinct(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("Distinct", func(t *testing.T) { + // Create duplicate entities + for i := range 3 { + key := datastore.IDKey("DistinctTest", int64(i+1), nil) + entity := &testEntity{ + Name: "same-name", // Same name for all + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Query with distinct on Name field + q := datastore.NewQuery("DistinctTest").Project("name").Distinct() + + var entities []testEntity + _, err := client.GetAll(ctx, q, &entities) + if err != nil { + t.Logf("GetAll with Distinct: %v", err) + } + }) + + t.Run("DistinctOn", func(t *testing.T) { + q := datastore.NewQuery("Task").DistinctOn("name", "count") + + var entities []testEntity + _, err := client.GetAll(ctx, q, &entities) + if err != nil { + t.Logf("GetAll with DistinctOn: %v", err) + } + }) +} diff --git a/query_test.go b/pkg/datastore/query_unit_test.go similarity index 92% rename from query_test.go rename to pkg/datastore/query_unit_test.go index a2c17c5..78c2a9b 100644 --- a/query_test.go +++ b/pkg/datastore/query_unit_test.go @@ -1,4 +1,4 @@ -package ds9 +package datastore import ( "testing" @@ -288,3 +288,25 @@ func TestBuildQueryMapKeysOnly(t *testing.T) { t.Fatal("Expected projection in query map for keys-only") } } + +func TestQueryStart(t *testing.T) { + cursor := Cursor("test-start-cursor") + q := NewQuery("TestKind").Start(cursor) + + // Start returns a new Query with the cursor set + // We can verify it's chainable and doesn't panic + if q == nil { + t.Error("Expected non-nil query") + } +} + +func TestQueryEnd(t *testing.T) { + cursor := Cursor("test-end-cursor") + q := NewQuery("TestKind").End(cursor) + + // End returns a new Query with the cursor set + // We can verify it's chainable and doesn't panic + if q == nil { + t.Error("Expected non-nil query") + } +} diff --git a/pkg/datastore/transaction.go b/pkg/datastore/transaction.go new file mode 100644 index 0000000..d8c56ee --- /dev/null +++ b/pkg/datastore/transaction.go @@ -0,0 +1,655 @@ +package datastore + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + neturl "net/url" + "reflect" + "strings" + "time" + + "github.com/codeGROOVE-dev/ds9/auth" +) + +// Commit represents the result of a committed transaction. +// This is provided for API compatibility with cloud.google.com/go/datastore. +type Commit struct{} + +// Transaction represents a Datastore transaction. +// Note: This struct stores context for API compatibility with Google's official +// cloud.google.com/go/datastore library, which uses the same pattern. +type Transaction struct { + ctx context.Context //nolint:containedctx // Required for API compatibility with cloud.google.com/go/datastore + client *Client + id string + mutations []map[string]any +} + +// TransactionOption configures transaction behavior. +type TransactionOption interface { + apply(*transactionSettings) +} + +type transactionSettings struct { + readTime time.Time + maxAttempts int +} + +type maxAttemptsOption int + +func (o maxAttemptsOption) apply(s *transactionSettings) { + s.maxAttempts = int(o) +} + +// MaxAttempts returns a TransactionOption that specifies the maximum number +// of times a transaction should be attempted before giving up. +func MaxAttempts(n int) TransactionOption { + return maxAttemptsOption(n) +} + +type readTimeOption struct { + t time.Time +} + +func (o readTimeOption) apply(s *transactionSettings) { + s.readTime = o.t +} + +// WithReadTime returns a TransactionOption that sets a specific timestamp +// at which to read data, enabling reading from a particular snapshot in time. +func WithReadTime(t time.Time) TransactionOption { + return readTimeOption{t: t} +} + +// NewTransaction creates a new transaction. +// The caller must call Commit or Rollback when done. +// API compatible with cloud.google.com/go/datastore. +func (c *Client) NewTransaction(ctx context.Context, opts ...TransactionOption) (*Transaction, error) { + settings := transactionSettings{ + maxAttempts: 3, // default (not used for NewTransaction, but kept for consistency) + } + for _, opt := range opts { + opt.apply(&settings) + } + + token, err := auth.AccessToken(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get access token: %w", err) + } + + // Begin transaction + reqBody := map[string]any{} + if c.databaseID != "" { + reqBody["databaseId"] = c.databaseID + } + + // Add transaction options if needed + if !settings.readTime.IsZero() { + reqBody["transactionOptions"] = map[string]any{ + "readOnly": map[string]any{ + "readTime": settings.readTime.Format(time.RFC3339Nano), + }, + } + } else { + reqBody["transactionOptions"] = map[string]any{ + "readWrite": map[string]any{}, + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, err + } + + // URL-encode project ID to prevent injection attacks + reqURL := fmt.Sprintf("%s/projects/%s:beginTransaction", c.baseURL, neturl.PathEscape(c.projectID)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(jsonData)) + if err != nil { + return nil, err + } + + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + + // Add routing header for named databases + if c.databaseID != "" { + // URL-encode values to prevent header injection attacks + routingHeader := fmt.Sprintf("project_id=%s&database_id=%s", neturl.QueryEscape(c.projectID), neturl.QueryEscape(c.databaseID)) + req.Header.Set("X-Goog-Request-Params", routingHeader) + } + + resp, err := httpClient.Do(req) + if err != nil { + return nil, err + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, maxBodySize)) + closeErr := resp.Body.Close() + if closeErr != nil { + c.logger.Warn("failed to close response body", "error", closeErr) + } + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("begin transaction failed with status %d: %s", resp.StatusCode, string(body)) + } + + var txResp struct { + Transaction string `json:"transaction"` + } + + if err := json.Unmarshal(body, &txResp); err != nil { + return nil, fmt.Errorf("failed to parse transaction response: %w", err) + } + + tx := &Transaction{ + ctx: ctx, + client: c, + id: txResp.Transaction, + } + + return tx, nil +} + +// RunInTransaction runs a function in a transaction. +// The function should use the transaction's Get and Put methods. +// API compatible with cloud.google.com/go/datastore. +func (c *Client) RunInTransaction(ctx context.Context, f func(*Transaction) error, opts ...TransactionOption) (*Commit, error) { + settings := transactionSettings{ + maxAttempts: 3, // default + } + for _, opt := range opts { + opt.apply(&settings) + } + + var lastErr error + + for attempt := range settings.maxAttempts { + token, err := auth.AccessToken(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get access token: %w", err) + } + + // Begin transaction + reqBody := map[string]any{} + if c.databaseID != "" { + reqBody["databaseId"] = c.databaseID + } + + // Add transaction options if needed + if !settings.readTime.IsZero() { + reqBody["transactionOptions"] = map[string]any{ + "readOnly": map[string]any{ + "readTime": settings.readTime.Format(time.RFC3339Nano), + }, + } + } else { + reqBody["transactionOptions"] = map[string]any{ + "readWrite": map[string]any{}, + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, err + } + + // URL-encode project ID to prevent injection attacks + reqURL := fmt.Sprintf("%s/projects/%s:beginTransaction", c.baseURL, neturl.PathEscape(c.projectID)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(jsonData)) + if err != nil { + return nil, err + } + + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + + // Add routing header for named databases + if c.databaseID != "" { + // URL-encode values to prevent header injection attacks + routingHeader := fmt.Sprintf("project_id=%s&database_id=%s", neturl.QueryEscape(c.projectID), neturl.QueryEscape(c.databaseID)) + req.Header.Set("X-Goog-Request-Params", routingHeader) + } + + resp, err := httpClient.Do(req) + if err != nil { + return nil, err + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, maxBodySize)) + closeErr := resp.Body.Close() + if closeErr != nil { + c.logger.Warn("failed to close response body", "error", closeErr) + } + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("begin transaction failed with status %d: %s", resp.StatusCode, string(body)) + } + + var txResp struct { + Transaction string `json:"transaction"` + } + + if err := json.Unmarshal(body, &txResp); err != nil { + return nil, fmt.Errorf("failed to parse transaction response: %w", err) + } + + tx := &Transaction{ + ctx: ctx, + client: c, + id: txResp.Transaction, + } + + // Run the function + if err := f(tx); err != nil { + // Rollback is implicit if commit is not called + return nil, err + } + + // Commit the transaction + err = tx.doCommit(ctx, token) + if err == nil { + c.logger.Debug("transaction committed successfully", "attempt", attempt+1) + return &Commit{}, nil // Success + } + + c.logger.Warn("transaction commit failed", "attempt", attempt+1, "error", err) + + // Check if error contains 409 ABORTED - if so, retry + errStr := err.Error() + is409 := strings.Contains(errStr, "status 409") + isAborted := strings.Contains(errStr, "ABORTED") + + if is409 || isAborted { + lastErr = err + c.logger.Warn("transaction aborted, will retry", + "attempt", attempt+1, + "max_attempts", settings.maxAttempts, + "has_409", is409, + "has_ABORTED", isAborted, + "error", err) + + // Exponential backoff: 100ms, 200ms, 400ms + if attempt < settings.maxAttempts-1 { + backoffMS := 100 * (1 << attempt) + c.logger.Debug("sleeping before retry", "backoff_ms", backoffMS) + time.Sleep(time.Duration(backoffMS) * time.Millisecond) + } + continue + } + + // Non-retriable error + c.logger.Warn("non-retriable transaction error", "error", err) + return nil, err + } + + return nil, fmt.Errorf("transaction failed after %d attempts: %w", settings.maxAttempts, lastErr) +} + +// Get retrieves an entity within the transaction. +// API compatible with cloud.google.com/go/datastore. +func (tx *Transaction) Get(key *Key, dst any) error { + if key == nil { + return errors.New("key cannot be nil") + } + + token, err := auth.AccessToken(tx.ctx) + if err != nil { + return fmt.Errorf("failed to get access token: %w", err) + } + + reqBody := map[string]any{ + "keys": []map[string]any{ + keyToJSON(key), + }, + "readOptions": map[string]any{ + "transaction": tx.id, + }, + } + + if tx.client.databaseID != "" { + reqBody["databaseId"] = tx.client.databaseID + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return err + } + + // URL-encode project ID to prevent injection attacks + reqURL := fmt.Sprintf("%s/projects/%s:lookup", tx.client.baseURL, neturl.PathEscape(tx.client.projectID)) + req, err := http.NewRequestWithContext(tx.ctx, http.MethodPost, reqURL, bytes.NewReader(jsonData)) + if err != nil { + return err + } + + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + + // Add routing header for named databases + if tx.client.databaseID != "" { + // URL-encode values to prevent header injection attacks + routingHeader := fmt.Sprintf("project_id=%s&database_id=%s", + neturl.QueryEscape(tx.client.projectID), + neturl.QueryEscape(tx.client.databaseID)) + req.Header.Set("X-Goog-Request-Params", routingHeader) + } + + resp, err := httpClient.Do(req) + if err != nil { + return err + } + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil { + tx.client.logger.Warn("failed to close response body", "error", closeErr) + } + }() + + body, err := io.ReadAll(io.LimitReader(resp.Body, maxBodySize)) + if err != nil { + return err + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("transaction get failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result struct { + Found []struct { + Entity map[string]any `json:"entity"` + } `json:"found"` + Missing []struct{} `json:"missing"` + } + + if err := json.Unmarshal(body, &result); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + if len(result.Found) == 0 { + return ErrNoSuchEntity + } + + return decodeEntity(result.Found[0].Entity, dst) +} + +// Put stores an entity within the transaction. +func (tx *Transaction) Put(key *Key, src any) (*Key, error) { + if key == nil { + return nil, errors.New("key cannot be nil") + } + + // Encode the entity + entity, err := encodeEntity(key, src) + if err != nil { + return nil, err + } + + // Create mutation + mutation := map[string]any{ + "upsert": entity, + } + + // Accumulate mutation for commit + tx.mutations = append(tx.mutations, mutation) + + return key, nil +} + +// Delete deletes an entity within the transaction. +// API compatible with cloud.google.com/go/datastore. +func (tx *Transaction) Delete(key *Key) error { + if key == nil { + return errors.New("key cannot be nil") + } + + // Create delete mutation + mutation := map[string]any{ + "delete": keyToJSON(key), + } + + // Accumulate mutation for commit + tx.mutations = append(tx.mutations, mutation) + + return nil +} + +// DeleteMulti deletes multiple entities within the transaction. +// API compatible with cloud.google.com/go/datastore. +func (tx *Transaction) DeleteMulti(keys []*Key) error { + for _, key := range keys { + if err := tx.Delete(key); err != nil { + return err + } + } + return nil +} + +// GetMulti retrieves multiple entities within the transaction. +// API compatible with cloud.google.com/go/datastore. +func (tx *Transaction) GetMulti(keys []*Key, dst any) error { + dstVal := reflect.ValueOf(dst) + if dstVal.Kind() != reflect.Ptr || dstVal.Elem().Kind() != reflect.Slice { + return errors.New("dst must be a pointer to a slice") + } + + slice := dstVal.Elem() + if len(keys) != slice.Len() { + return fmt.Errorf("keys and dst slices must have same length: %d vs %d", len(keys), slice.Len()) + } + + // Get each entity individually within the transaction + for i, key := range keys { + elem := slice.Index(i) + if elem.Kind() == reflect.Ptr { + // dst is []*Entity + if elem.IsNil() { + elem.Set(reflect.New(elem.Type().Elem())) + } + if err := tx.Get(key, elem.Interface()); err != nil { + return err + } + } else { + // dst is []Entity + if err := tx.Get(key, elem.Addr().Interface()); err != nil { + return err + } + } + } + + return nil +} + +// PutMulti stores multiple entities within the transaction. +// API compatible with cloud.google.com/go/datastore. +func (tx *Transaction) PutMulti(keys []*Key, src any) ([]*Key, error) { + srcVal := reflect.ValueOf(src) + if srcVal.Kind() != reflect.Slice { + return nil, errors.New("src must be a slice") + } + + if len(keys) != srcVal.Len() { + return nil, fmt.Errorf("keys and src slices must have same length: %d vs %d", len(keys), srcVal.Len()) + } + + // Put each entity individually within the transaction + for i, key := range keys { + elem := srcVal.Index(i) + var src any + if elem.Kind() == reflect.Ptr { + src = elem.Interface() + } else { + src = elem.Addr().Interface() + } + + if _, err := tx.Put(key, src); err != nil { + return nil, err + } + } + + return keys, nil +} + +// Commit applies the transaction's mutations. +// API compatible with cloud.google.com/go/datastore. +func (tx *Transaction) Commit() (*Commit, error) { + token, err := auth.AccessToken(tx.ctx) + if err != nil { + return nil, fmt.Errorf("failed to get access token: %w", err) + } + + if err := tx.doCommit(tx.ctx, token); err != nil { + return nil, err + } + + return &Commit{}, nil +} + +// Rollback abandons the transaction. +// API compatible with cloud.google.com/go/datastore. +func (tx *Transaction) Rollback() error { + // Datastore transactions are automatically rolled back if not committed + // So we just need to clear the mutations to prevent accidental commit + tx.mutations = nil + return nil +} + +// Mutate adds one or more mutations to the transaction. +// API compatible with cloud.google.com/go/datastore. +func (tx *Transaction) Mutate(muts ...*Mutation) ([]*PendingKey, error) { + if len(muts) == 0 { + return nil, nil + } + + // Build mutations array + pendingKeys := make([]*PendingKey, 0, len(muts)) + for i, mut := range muts { + if mut == nil { + return nil, fmt.Errorf("mutation at index %d is nil", i) + } + if mut.key == nil { + return nil, fmt.Errorf("mutation at index %d has nil key", i) + } + + mutMap := make(map[string]any) + + switch mut.op { + case MutationInsert: + if mut.entity == nil { + return nil, fmt.Errorf("insert mutation at index %d has nil entity", i) + } + entity, err := encodeEntity(mut.key, mut.entity) + if err != nil { + return nil, fmt.Errorf("failed to encode entity at index %d: %w", i, err) + } + mutMap["insert"] = entity + + case MutationUpdate: + if mut.entity == nil { + return nil, fmt.Errorf("update mutation at index %d has nil entity", i) + } + entity, err := encodeEntity(mut.key, mut.entity) + if err != nil { + return nil, fmt.Errorf("failed to encode entity at index %d: %w", i, err) + } + mutMap["update"] = entity + + case MutationUpsert: + if mut.entity == nil { + return nil, fmt.Errorf("upsert mutation at index %d has nil entity", i) + } + entity, err := encodeEntity(mut.key, mut.entity) + if err != nil { + return nil, fmt.Errorf("failed to encode entity at index %d: %w", i, err) + } + mutMap["upsert"] = entity + + case MutationDelete: + mutMap["delete"] = keyToJSON(mut.key) + + default: + return nil, fmt.Errorf("unknown mutation operation at index %d: %s", i, mut.op) + } + + tx.mutations = append(tx.mutations, mutMap) + + // Create a pending key for the result + pk := &PendingKey{key: mut.key} + pendingKeys = append(pendingKeys, pk) + } + + return pendingKeys, nil +} + +// PendingKey represents a key that will be resolved after a transaction commit. +// API compatible with cloud.google.com/go/datastore. +type PendingKey struct { + key *Key +} + +// commit commits the transaction. +func (tx *Transaction) doCommit(ctx context.Context, token string) error { + reqBody := map[string]any{ + "mode": "TRANSACTIONAL", + "transaction": tx.id, + "mutations": tx.mutations, + } + + if tx.client.databaseID != "" { + reqBody["databaseId"] = tx.client.databaseID + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return err + } + + // URL-encode project ID to prevent injection attacks + reqURL := fmt.Sprintf("%s/projects/%s:commit", tx.client.baseURL, neturl.PathEscape(tx.client.projectID)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(jsonData)) + if err != nil { + return err + } + + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + + // Add routing header for named databases + if tx.client.databaseID != "" { + // URL-encode values to prevent header injection attacks + routingHeader := fmt.Sprintf("project_id=%s&database_id=%s", + neturl.QueryEscape(tx.client.projectID), + neturl.QueryEscape(tx.client.databaseID)) + req.Header.Set("X-Goog-Request-Params", routingHeader) + } + + resp, err := httpClient.Do(req) + if err != nil { + return err + } + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil { + tx.client.logger.Warn("failed to close response body", "error", closeErr) + } + }() + + body, err := io.ReadAll(io.LimitReader(resp.Body, maxBodySize)) + if err != nil { + return err + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("commit failed with status %d: %s", resp.StatusCode, string(body)) + } + + return nil +} diff --git a/pkg/datastore/transaction_coverage_test.go b/pkg/datastore/transaction_coverage_test.go new file mode 100644 index 0000000..54209c6 --- /dev/null +++ b/pkg/datastore/transaction_coverage_test.go @@ -0,0 +1,284 @@ +package datastore_test + +import ( + "context" + "errors" + "testing" + + "github.com/codeGROOVE-dev/ds9/pkg/datastore" +) + +// Test manual transaction API (NewTransaction, Get, Put, Delete, Commit, Rollback) +func TestManualTransaction(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("NewTransactionAndCommit", func(t *testing.T) { + // Begin transaction + tx, err := client.NewTransaction(ctx) + if err != nil { + t.Fatalf("NewTransaction failed: %v", err) + } + + // Perform operations + key := datastore.NameKey("TxTest", "manual1", nil) + entity := &testEntity{Name: "manual", Count: 1} + + // Put in transaction + if _, err := tx.Put(key, entity); err != nil { + t.Fatalf("Transaction Put failed: %v", err) + } + + // Commit transaction + if _, err := tx.Commit(); err != nil { + t.Fatalf("Transaction Commit failed: %v", err) + } + + // Verify entity was saved + var retrieved testEntity + if err := client.Get(ctx, key, &retrieved); err != nil { + t.Fatalf("Get after commit failed: %v", err) + } + if retrieved.Name != "manual" { + t.Errorf("Expected name 'manual', got '%s'", retrieved.Name) + } + }) + + t.Run("NewTransactionAndRollback", func(t *testing.T) { + // Begin transaction + tx, err := client.NewTransaction(ctx) + if err != nil { + t.Fatalf("NewTransaction failed: %v", err) + } + + // Perform operations + key := datastore.NameKey("TxTest", "rollback1", nil) + entity := &testEntity{Name: "rollback", Count: 1} + + // Put in transaction + if _, err := tx.Put(key, entity); err != nil { + t.Fatalf("Transaction Put failed: %v", err) + } + + // Rollback transaction + if err := tx.Rollback(); err != nil { + t.Fatalf("Transaction Rollback failed: %v", err) + } + + // Verify entity was NOT saved + var retrieved testEntity + err = client.Get(ctx, key, &retrieved) + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("Expected ErrNoSuchEntity after rollback, got %v", err) + } + }) + + t.Run("TransactionGet", func(t *testing.T) { + // Create entity first + key := datastore.NameKey("TxTest", "get1", nil) + entity := &testEntity{Name: "existing", Count: 42} + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Begin transaction and read + tx, err := client.NewTransaction(ctx) + if err != nil { + t.Fatalf("NewTransaction failed: %v", err) + } + + var retrieved testEntity + if err := tx.Get(key, &retrieved); err != nil { + t.Fatalf("Transaction Get failed: %v", err) + } + + if retrieved.Count != 42 { + t.Errorf("Expected count 42, got %d", retrieved.Count) + } + + // Rollback since we're just reading + if err := tx.Rollback(); err != nil { + t.Logf("Rollback returned: %v", err) + } + }) + + t.Run("TransactionDelete", func(t *testing.T) { + // Create entity first + key := datastore.NameKey("TxTest", "delete1", nil) + entity := &testEntity{Name: "to-delete", Count: 1} + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Begin transaction and delete + tx, err := client.NewTransaction(ctx) + if err != nil { + t.Fatalf("NewTransaction failed: %v", err) + } + + if err := tx.Delete(key); err != nil { + t.Fatalf("Transaction Delete failed: %v", err) + } + + // Commit + if _, err := tx.Commit(); err != nil { + t.Fatalf("Commit failed: %v", err) + } + + // Verify entity was deleted + var retrieved testEntity + err = client.Get(ctx, key, &retrieved) + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("Expected ErrNoSuchEntity after delete, got %v", err) + } + }) + + t.Run("TransactionPutMulti", func(t *testing.T) { + // Begin transaction + tx, err := client.NewTransaction(ctx) + if err != nil { + t.Fatalf("NewTransaction failed: %v", err) + } + + keys := []*datastore.Key{ + datastore.NameKey("TxTest", "multi1", nil), + datastore.NameKey("TxTest", "multi2", nil), + } + + entities := []testEntity{ + {Name: "multi1", Count: 1}, + {Name: "multi2", Count: 2}, + } + + if _, err := tx.PutMulti(keys, entities); err != nil { + t.Fatalf("Transaction PutMulti failed: %v", err) + } + + // Commit + if _, err := tx.Commit(); err != nil { + t.Fatalf("Commit failed: %v", err) + } + + // Verify entities were saved + for i, key := range keys { + var retrieved testEntity + if err := client.Get(ctx, key, &retrieved); err != nil { + t.Errorf("Get for key %d failed: %v", i, err) + } + } + }) + + t.Run("TransactionGetMulti", func(t *testing.T) { + // Create entities first + keys := []*datastore.Key{ + datastore.NameKey("TxTest", "getmulti1", nil), + datastore.NameKey("TxTest", "getmulti2", nil), + } + + entities := []testEntity{ + {Name: "getmulti1", Count: 1}, + {Name: "getmulti2", Count: 2}, + } + + for i, key := range keys { + if _, err := client.Put(ctx, key, &entities[i]); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Begin transaction and read multiple + tx, err := client.NewTransaction(ctx) + if err != nil { + t.Fatalf("NewTransaction failed: %v", err) + } + + retrieved := make([]testEntity, len(keys)) + if err := tx.GetMulti(keys, &retrieved); err != nil { + t.Fatalf("Transaction GetMulti failed: %v", err) + } + + for i, entity := range retrieved { + if entity.Count != int64(i+1) { + t.Errorf("Entity %d: expected count %d, got %d", i, i+1, entity.Count) + } + } + + // Rollback since we're just reading + if err := tx.Rollback(); err != nil { + t.Logf("Rollback returned: %v", err) + } + }) + + t.Run("TransactionDeleteMulti", func(t *testing.T) { + // Create entities first + keys := []*datastore.Key{ + datastore.NameKey("TxTest", "delmulti1", nil), + datastore.NameKey("TxTest", "delmulti2", nil), + } + + for _, key := range keys { + entity := &testEntity{Name: "to-delete", Count: 1} + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Begin transaction and delete multiple + tx, err := client.NewTransaction(ctx) + if err != nil { + t.Fatalf("NewTransaction failed: %v", err) + } + + if err := tx.DeleteMulti(keys); err != nil { + t.Fatalf("Transaction DeleteMulti failed: %v", err) + } + + // Commit + if _, err := tx.Commit(); err != nil { + t.Fatalf("Commit failed: %v", err) + } + + // Verify entities were deleted + for _, key := range keys { + var retrieved testEntity + err := client.Get(ctx, key, &retrieved) + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("Expected ErrNoSuchEntity for key %v, got %v", key, err) + } + } + }) + + t.Run("TransactionMutate", func(t *testing.T) { + // Begin transaction + tx, err := client.NewTransaction(ctx) + if err != nil { + t.Fatalf("NewTransaction failed: %v", err) + } + + key := datastore.NameKey("TxTest", "mutate1", nil) + entity := &testEntity{Name: "mutate", Count: 1} + + // Create mutation + mut := datastore.NewInsert(key, entity) + + if _, err := tx.Mutate(mut); err != nil { + t.Fatalf("Transaction Mutate failed: %v", err) + } + + // Commit + if _, err := tx.Commit(); err != nil { + t.Fatalf("Commit failed: %v", err) + } + + // Verify entity was saved + var retrieved testEntity + if err := client.Get(ctx, key, &retrieved); err != nil { + t.Fatalf("Get after mutate failed: %v", err) + } + if retrieved.Name != "mutate" { + t.Errorf("Expected name 'mutate', got '%s'", retrieved.Name) + } + }) +} diff --git a/pkg/datastore/transaction_test.go b/pkg/datastore/transaction_test.go new file mode 100644 index 0000000..c4a06fa --- /dev/null +++ b/pkg/datastore/transaction_test.go @@ -0,0 +1,1663 @@ +package datastore_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/codeGROOVE-dev/ds9/pkg/datastore" +) + +func TestRunInTransaction(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Put initial entity + entity := &testEntity{ + Name: "counter", + Count: 0, + } + + key := datastore.NameKey("TestKind", "counter", nil) + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Run transaction to read and update + _, err = client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + var current testEntity + if err := tx.Get(key, ¤t); err != nil { + return err + } + + current.Count++ + _, err := tx.Put(key, ¤t) + return err + }) + if err != nil { + t.Fatalf("RunInTransaction failed: %v", err) + } + + // Verify the update + var updated testEntity + err = client.Get(ctx, key, &updated) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + + if updated.Count != 1 { + t.Errorf("expected Count to be 1, got %d", updated.Count) + } +} + +func TestTransactionNotFound(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + key := datastore.NameKey("TestKind", "nonexistent", nil) + + _, err := client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + var entity testEntity + return tx.Get(key, &entity) + }) + + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("expected datastore.ErrNoSuchEntity, got %v", err) + } +} + +func TestTransactionMultipleOperations(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Put initial entities + for i := range 3 { + entity := &testEntity{ + Name: "item", + Count: int64(i), + } + key := datastore.NameKey("TestKind", string(rune('a'+i)), nil) + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Run transaction that reads and updates multiple entities + _, err := client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + for i := range 3 { + key := datastore.NameKey("TestKind", string(rune('a'+i)), nil) + var current testEntity + if err := tx.Get(key, ¤t); err != nil { + return err + } + + current.Count += 10 + _, err := tx.Put(key, ¤t) + if err != nil { + return err + } + } + return nil + }) + if err != nil { + t.Fatalf("RunInTransaction failed: %v", err) + } + + // Verify updates + for i := range 3 { + key := datastore.NameKey("TestKind", string(rune('a'+i)), nil) + var retrieved testEntity + err = client.Get(ctx, key, &retrieved) + if err != nil { + t.Fatalf("Get failed: %v", err) + } + expectedCount := int64(i + 10) + if retrieved.Count != expectedCount { + t.Errorf("entity %d: expected Count %d, got %d", i, expectedCount, retrieved.Count) + } + } +} + +func TestTransactionWithError(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Store initial entity + key := datastore.NameKey("TestKind", "tx-err", nil) + entity := testEntity{Name: "initial", Count: 1} + _, err := client.Put(ctx, key, &entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Run transaction that errors + _, err = client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + var current testEntity + if err := tx.Get(key, ¤t); err != nil { + return err + } + + current.Count = 999 + + if _, err := tx.Put(key, ¤t); err != nil { + return err + } + + // Return error to trigger rollback + return errors.New("intentional error") + }) + + if err == nil { + t.Fatal("expected transaction to fail, got nil error") + } + if !strings.Contains(err.Error(), "intentional error") { + t.Errorf("expected 'intentional error', got: %v", err) + } + + // Verify entity was not modified (transaction rolled back) + // Note: In a real implementation this would check rollback, but our mock doesn't support it + // This test at least exercises the error path +} + +func TestTransactionWithDatabaseID(t *testing.T) { + // Setup mock servers + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + txID := "test-tx-123" + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "Bearer test-token" { + w.WriteHeader(http.StatusUnauthorized) + return + } + + var reqBody map[string]any + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + // Check for databaseId in request + if dbID, ok := reqBody["databaseId"].(string); ok && dbID != "tx-db" { + t.Errorf("expected databaseId 'tx-db', got %v", dbID) + } + + w.Header().Set("Content-Type", "application/json") + + if r.URL.Path == "/projects/test-project:beginTransaction" { + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "transaction": txID, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + + if r.URL.Path == "/projects/test-project:commit" { + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "mutationResults": []any{}, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + + if r.URL.Path == "/projects/test-project:lookup" { + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "found": []any{}, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClientWithDatabase(ctx, "test-project", "tx-db") + if err != nil { + t.Fatalf("NewClientWithDatabase failed: %v", err) + } + + // Run transaction with databaseID + _, err = client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + key := datastore.NameKey("TestKind", "tx-test", nil) + entity := testEntity{Name: "in-tx", Count: 42} + _, err := tx.Put(key, &entity) + return err + }) + if err != nil { + t.Fatalf("Transaction with databaseID failed: %v", err) + } +} + +func TestTransactionRollback(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + // Put initial entity + key := datastore.NameKey("TestKind", "rollback-test", nil) + entity := &testEntity{Name: "original", Count: 1} + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Run transaction that will fail + _, err = client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + var current testEntity + if err := tx.Get(key, ¤t); err != nil { + return err + } + + current.Name = "modified" + current.Count = 999 + + _, err := tx.Put(key, ¤t) + if err != nil { + return err + } + + // Return error to cause rollback + return errors.New("force rollback") + }) + + if err == nil { + t.Fatal("expected transaction to fail") + } + + if !strings.Contains(err.Error(), "force rollback") { + t.Errorf("expected 'force rollback' error, got: %v", err) + } +} + +func TestTransactionBeginFailure(t *testing.T) { + // Setup mock servers + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Fail to begin transaction + w.WriteHeader(http.StatusInternalServerError) + if _, err := w.Write([]byte(`{"error":"internal error"}`)); err != nil { + t.Logf("write failed: %v", err) + } + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClient(ctx, "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + _, err = client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + return nil + }) + + if err == nil { + t.Fatal("expected transaction to fail on begin") + } + + if !strings.Contains(err.Error(), "500") { + t.Errorf("expected error to mention 500 status, got: %v", err) + } +} + +func TestTransactionCommitAbortedRetry(t *testing.T) { + // Setup mock servers + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + commitAttempt := 0 + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "beginTransaction") { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "transaction": "test-tx-123", + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + + if strings.Contains(r.URL.Path, "commit") { + commitAttempt++ + // Fail with 409 ABORTED on first two attempts, succeed on third + if commitAttempt < 3 { + w.WriteHeader(http.StatusConflict) + if _, err := w.Write([]byte(`{"error":"ABORTED: transaction aborted"}`)); err != nil { + t.Logf("write failed: %v", err) + } + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "mutationResults": []any{}, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + + if strings.Contains(r.URL.Path, "lookup") { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "found": []any{}, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClient(ctx, "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + // This should succeed after retries + key := datastore.NameKey("TestKind", "tx-retry", nil) + _, err = client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + _, err := tx.Put(key, &testEntity{Name: "test", Count: 1}) + return err + }) + if err != nil { + t.Fatalf("transaction should succeed after retries, got: %v", err) + } + + if commitAttempt < 2 { + t.Errorf("expected at least 2 commit attempts, got %d", commitAttempt) + } +} + +func TestTransactionMaxRetriesExceeded(t *testing.T) { + // Setup mock servers + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + commitAttempt := 0 + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "beginTransaction") { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "transaction": "test-tx-456", + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + + if strings.Contains(r.URL.Path, "commit") { + commitAttempt++ + // Always return 409 ABORTED + w.WriteHeader(http.StatusConflict) + if _, err := w.Write([]byte(`{"error":"status 409 ABORTED: transaction conflict"}`)); err != nil { + t.Logf("write failed: %v", err) + } + return + } + + if strings.Contains(r.URL.Path, "lookup") { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "found": []any{}, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClient(ctx, "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + // This should fail after max retries + key := datastore.NameKey("TestKind", "tx-max-retry", nil) + _, err = client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + _, err := tx.Put(key, &testEntity{Name: "test", Count: 1}) + return err + }) + + if err == nil { + t.Fatal("expected transaction to fail after max retries") + } + + if !strings.Contains(err.Error(), "failed after 3 attempts") { + t.Errorf("expected 'failed after 3 attempts' error, got: %v", err) + } + + if commitAttempt != 3 { + t.Errorf("expected exactly 3 commit attempts, got %d", commitAttempt) + } +} + +func TestTransactionGetNonExistent(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + key := datastore.NameKey("TestKind", "nonexistent", nil) + + _, err := client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + var entity testEntity + return tx.Get(key, &entity) + }) + + if !errors.Is(err, datastore.ErrNoSuchEntity) { + t.Errorf("expected datastore.ErrNoSuchEntity in transaction, got: %v", err) + } +} + +func TestTransactionPutWithNilKey(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + _, err := client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + entity := &testEntity{Name: "test", Count: 1} + _, err := tx.Put(nil, entity) + return err + }) + + if err == nil { + t.Error("expected error for nil key in transaction") + } +} + +func TestTransactionGetWithNilKey(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + _, err := client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + var entity testEntity + return tx.Get(nil, &entity) + }) + + if err == nil { + t.Error("expected error for nil key in transaction Get") + } +} + +func TestTransactionWithMultiplePuts(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + _, err := client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + for i := range 5 { + key := datastore.NameKey("TxMulti", fmt.Sprintf("key-%d", i), nil) + entity := &testEntity{Name: "test", Count: int64(i)} + _, err := tx.Put(key, entity) + if err != nil { + return err + } + } + return nil + }) + if err != nil { + t.Fatalf("Transaction with multiple puts failed: %v", err) + } + + // Verify all entities were created + for i := range 5 { + key := datastore.NameKey("TxMulti", fmt.Sprintf("key-%d", i), nil) + var retrieved testEntity + err = client.Get(ctx, key, &retrieved) + if err != nil { + t.Errorf("Get for entity %d failed: %v", i, err) + } + if retrieved.Count != int64(i) { + t.Errorf("entity %d: expected Count %d, got %d", i, i, retrieved.Count) + } + } +} + +func TestTransactionGetWithInvalidResponse(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "beginTransaction") { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "transaction": "test-tx", + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + + if strings.Contains(r.URL.Path, "lookup") { + // Return invalid JSON structure + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte(`{"invalid":"structure"}`)); err != nil { + t.Logf("write failed: %v", err) + } + return + } + + if strings.Contains(r.URL.Path, "commit") { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "mutationResults": []any{}, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + + w.WriteHeader(http.StatusOK) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClient(ctx, "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + key := datastore.NameKey("Test", "key", nil) + _, err = client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + var entity testEntity + return tx.Get(key, &entity) + }) + + // Should handle the invalid response gracefully + if err == nil { + t.Log("Transaction succeeded despite invalid lookup response") + } +} + +func TestTransactionWithNonRetriableError(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + commitAttempts := 0 + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "beginTransaction") { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "transaction": "test-tx", + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + + if strings.Contains(r.URL.Path, "commit") { + commitAttempts++ + // Return non-retriable error (not 409 ABORTED) + w.WriteHeader(http.StatusBadRequest) + if _, err := w.Write([]byte(`{"error":"INVALID_ARGUMENT"}`)); err != nil { + t.Logf("write failed: %v", err) + } + return + } + + if strings.Contains(r.URL.Path, "lookup") { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "found": []any{}, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + + w.WriteHeader(http.StatusOK) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClient(ctx, "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + key := datastore.NameKey("Test", "key", nil) + _, err = client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + _, err := tx.Put(key, &testEntity{Name: "test", Count: 1}) + return err + }) + + if err == nil { + t.Error("expected error on non-retriable failure") + } + + // Should NOT retry on non-409 errors + if commitAttempts != 1 { + t.Errorf("expected exactly 1 commit attempt for non-retriable error, got %d", commitAttempts) + } +} + +func TestTransactionWithInvalidTxResponse(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "beginTransaction") { + // Return invalid JSON + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte(`{bad json`)); err != nil { + t.Logf("write failed: %v", err) + } + return + } + w.WriteHeader(http.StatusOK) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClient(ctx, "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + _, err = client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + return nil + }) + + if err == nil { + t.Error("expected error for invalid transaction response") + } +} + +func TestTransactionGetWithDecodeError(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "beginTransaction") { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "transaction": "test-tx", + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + + if strings.Contains(r.URL.Path, "lookup") { + // Return entity with malformed data + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "found": []any{ + map[string]any{ + "entity": map[string]any{ + "key": map[string]any{ + "path": []any{ + map[string]any{ + "kind": "Test", + "name": "key", + }, + }, + }, + "properties": map[string]any{ + "name": map[string]any{ + "stringValue": 12345, // Wrong type + }, + }, + }, + }, + }, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + + if strings.Contains(r.URL.Path, "commit") { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if err := json.NewEncoder(w).Encode(map[string]any{ + "mutationResults": []any{}, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + + w.WriteHeader(http.StatusOK) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + ctx := context.Background() + client, err := datastore.NewClient(ctx, "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + key := datastore.NameKey("Test", "key", nil) + _, err = client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + var entity testEntity + return tx.Get(key, &entity) + }) + // May succeed or fail depending on how decoding handles type mismatches + if err != nil { + t.Logf("Transaction Get with decode error: %v", err) + } +} + +func TestTransactionGetMissingEntity(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":beginTransaction") { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "transaction": "test-tx-id", + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + if strings.Contains(r.URL.Path, ":lookup") { + // Return empty found array (entity not found) + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "found": []any{}, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + if strings.Contains(r.URL.Path, ":commit") { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "mutationResults": []any{}, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + key := datastore.NameKey("Test", "nonexistent", nil) + + _, err = client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + var entity testEntity + err := tx.Get(key, &entity) + if err == nil { + return errors.New("expected error for missing entity") + } + return nil + }) + if err != nil { + t.Errorf("transaction should succeed even with get error: %v", err) + } +} + +func TestTransactionGetDecodeError(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":beginTransaction") { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "transaction": "test-tx-id", + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + if strings.Contains(r.URL.Path, ":lookup") { + // Return malformed entity + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "found": []map[string]any{ + { + "entity": "invalid-not-a-map", + }, + }, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + if strings.Contains(r.URL.Path, ":commit") { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "mutationResults": []any{}, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + key := datastore.NameKey("Test", "key", nil) + + _, err = client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + var entity testEntity + err := tx.Get(key, &entity) + if err == nil { + return errors.New("expected decode error") + } + return nil + }) + if err != nil { + t.Errorf("transaction should succeed: %v", err) + } +} + +func TestTransactionCommitInvalidResponse(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":beginTransaction") { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "transaction": "test-tx-id", + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + if strings.Contains(r.URL.Path, ":commit") { + // Return invalid JSON (missing mutationResults) + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + // Missing mutationResults field + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + key := datastore.NameKey("Test", "key", nil) + entity := &testEntity{Name: "test"} + + _, err = client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + _, err := tx.Put(key, entity) + return err + }) + if err != nil { + t.Logf("Transaction with invalid commit response failed: %v", err) + } +} + +func TestTransactionCommitUnmarshalError(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":beginTransaction") { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "transaction": "test-tx-id", + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + if strings.Contains(r.URL.Path, ":commit") { + // Return malformed mutation results + w.Header().Set("Content-Type", "application/json") + if _, err := w.Write([]byte(`{"mutationResults": "not-an-array"}`)); err != nil { + t.Logf("write failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + key := datastore.NameKey("Test", "key", nil) + entity := &testEntity{Name: "test"} + + _, err = client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + _, err := tx.Put(key, entity) + return err + }) + // May or may not error depending on JSON parsing behavior + if err != nil { + t.Logf("Transaction with malformed mutation results failed: %v", err) + } +} + +func TestTransactionGetNotFound(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":beginTransaction") { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "transaction": "test-tx-id", + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + if strings.Contains(r.URL.Path, ":lookup") { + // Return empty found array + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "found": []any{}, + "missing": []any{}, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + if strings.Contains(r.URL.Path, ":commit") { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "mutationResults": []any{}, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + key := datastore.NameKey("Test", "nonexistent", nil) + + _, err = client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + var entity testEntity + err := tx.Get(key, &entity) + if err == nil { + t.Error("expected error with empty found array") + } + return nil + }) + if err != nil { + t.Logf("Transaction completed: %v", err) + } +} + +func TestTransactionGetAccessTokenError(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + // Return error for token request + w.WriteHeader(http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":beginTransaction") { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "transaction": "test-tx-id", + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + key := datastore.NameKey("Test", "test-key", nil) + + _, err = client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + var entity testEntity + err := tx.Get(key, &entity) + if err == nil { + t.Error("expected error with token failure") + } + return err + }) + + if err == nil { + t.Error("expected transaction to fail with token error") + } +} + +func TestTransactionGetNonOKStatus(t *testing.T) { + metadataServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Metadata-Flavor") != "Google" { + w.WriteHeader(http.StatusForbidden) + return + } + if r.URL.Path == "/project/project-id" { + w.WriteHeader(http.StatusOK) + if _, err := w.Write([]byte("test-project")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if r.URL.Path == "/instance/service-accounts/default/token" { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "access_token": "test-token", + "expires_in": 3600, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer metadataServer.Close() + + apiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, ":beginTransaction") { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "transaction": "test-tx-id", + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + if strings.Contains(r.URL.Path, ":lookup") { + // Return non-OK status + w.WriteHeader(http.StatusBadRequest) + if _, err := w.Write([]byte("bad request")); err != nil { + t.Logf("write failed: %v", err) + } + return + } + if strings.Contains(r.URL.Path, ":commit") { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "mutationResults": []any{}, + }); err != nil { + t.Logf("encode failed: %v", err) + } + return + } + w.WriteHeader(http.StatusNotFound) + })) + defer apiServer.Close() + + restore := datastore.SetTestURLs(metadataServer.URL, apiServer.URL) + defer restore() + + client, err := datastore.NewClient(context.Background(), "test-project") + if err != nil { + t.Fatalf("NewClient failed: %v", err) + } + + ctx := context.Background() + key := datastore.NameKey("Test", "test-key", nil) + + _, err = client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + var entity testEntity + return tx.Get(key, &entity) + }) + + if err == nil { + t.Error("expected error with non-OK status") + } +} + +func TestRunInTransactionReturnsCommit(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + key := datastore.NameKey("CommitTest", "test1", nil) + + commit, err := client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + entity := &testEntity{ + Name: "commit test", + Count: 1, + Active: true, + UpdatedAt: time.Now().UTC().Truncate(time.Microsecond), + } + _, err := tx.Put(key, entity) + return err + }) + if err != nil { + t.Fatalf("RunInTransaction failed: %v", err) + } + + if commit == nil { + t.Fatal("Expected non-nil Commit, got nil") + } + + // Commit should be a valid *Commit type + _ = commit +} + +func TestRunInTransactionErrorReturnsNilCommit(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + + expectedErr := errors.New("intentional error") + commit, err := client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + return expectedErr + }) + + if err == nil { + t.Fatal("Expected error, got nil") + } + + if !errors.Is(err, expectedErr) { + t.Errorf("Expected error to be %v, got %v", expectedErr, err) + } + + if commit != nil { + t.Errorf("Expected nil Commit on error, got %v", commit) + } +} + +func TestTransactionOptions(t *testing.T) { + t.Run("MaxAttempts", func(t *testing.T) { + // Test that MaxAttempts option is accepted and sets the retry limit + // We can verify this by checking the error message mentions the right attempt count + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + key := datastore.NameKey("TestKind", "test", nil) + + // This test verifies that the MaxAttempts option is parsed correctly + // The actual retry behavior is tested in TestTransactionMaxRetriesExceeded + _, err := client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + entity := testEntity{Name: "test", Count: 42} + _, err := tx.Put(key, &entity) + return err + }, datastore.MaxAttempts(5)) + // With mock client, this should succeed + if err != nil { + t.Fatalf("Transaction failed: %v", err) + } + }) + + t.Run("WithReadTime", func(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + key := datastore.NameKey("TestKind", "test", nil) + + // First, put an entity + entity := testEntity{Name: "test", Count: 42} + _, err := client.Put(ctx, key, &entity) + if err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Run a read-only transaction with readTime + readTime := time.Now().UTC() + _, err = client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + var result testEntity + return tx.Get(key, &result) + }, datastore.WithReadTime(readTime)) + // Note: ds9mock doesn't actually enforce read-only semantics, + // but we're testing that the option is accepted and doesn't cause errors + if err != nil { + t.Fatalf("Transaction with WithReadTime failed: %v", err) + } + }) + + t.Run("CombinedOptions", func(t *testing.T) { + client, cleanup := datastore.NewMockClient(t) + defer cleanup() + + ctx := context.Background() + key := datastore.NameKey("TestKind", "test", nil) + + // Test that multiple options can be combined + _, err := client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + entity := testEntity{Name: "test", Count: 42} + _, err := tx.Put(key, &entity) + return err + }, datastore.MaxAttempts(2), datastore.WithReadTime(time.Now().UTC())) + // With mock client, this should succeed + if err != nil { + t.Fatalf("Transaction with combined options failed: %v", err) + } + }) +} diff --git a/ds9mock/mock.go b/pkg/mock/mock.go similarity index 88% rename from ds9mock/mock.go rename to pkg/mock/mock.go index 9deab8d..4f9ceb6 100644 --- a/ds9mock/mock.go +++ b/pkg/mock/mock.go @@ -1,22 +1,23 @@ -// Package ds9mock provides an in-memory mock Datastore server for testing. +// Package mock provides an in-memory mock Datastore server for testing. // // This package can be used by both ds9 internal tests and by end-users who want // to test their code that depends on ds9 without hitting real Datastore APIs. // // Example usage: // +// import "github.com/codeGROOVE-dev/ds9/pkg/datastore" +// // func TestMyCode(t *testing.T) { -// client, cleanup := ds9mock.NewClient(t) +// client, cleanup := datastore.NewMockClient(t) // defer cleanup() // // // Use client in your tests -// key := ds9.NameKey("Task", "task-1", nil) +// key := datastore.NameKey("Task", "task-1", nil) // _, err := client.Put(ctx, key, &myTask) // } -package ds9mock +package mock import ( - "context" "encoding/json" "fmt" "log" @@ -25,8 +26,6 @@ import ( "strconv" "sync" "testing" - - "github.com/codeGROOVE-dev/ds9" ) const metadataFlavor = "Google" @@ -37,18 +36,21 @@ const metadataFlavor = "Google" type Store struct { mu sync.RWMutex entities map[string]map[string]any + nextID int64 // Counter for allocating unique IDs } // NewStore creates a new in-memory store. func NewStore() *Store { return &Store{ entities: make(map[string]map[string]any), + nextID: 1000, // Start IDs at 1000 } } -// NewClient creates a ds9 client connected to mock servers with in-memory storage. -// Returns the client and a cleanup function that should be deferred. -func NewClient(t *testing.T) (client *ds9.Client, cleanup func()) { +// NewMockServers creates mock metadata and API servers for testing. +// Returns the metadata URL, API URL, and a cleanup function. +// This function doesn't import datastore to avoid import cycles. +func NewMockServers(t *testing.T) (metadataURL, apiURL string, cleanup func()) { t.Helper() store := NewStore() @@ -125,24 +127,12 @@ func NewClient(t *testing.T) (client *ds9.Client, cleanup func()) { w.WriteHeader(http.StatusNotFound) })) - // Set test URLs in ds9 - restore := ds9.SetTestURLs(metadataServer.URL, apiServer.URL) - - // Create client - ctx := context.Background() - var err error - client, err = ds9.NewClient(ctx, "test-project") - if err != nil { - t.Fatalf("failed to create mock client: %v", err) - } - cleanup = func() { - restore() metadataServer.Close() apiServer.Close() } - return client, cleanup + return metadataServer.URL, apiServer.URL, cleanup } // handleLookup handles lookup (get) requests. @@ -472,11 +462,37 @@ func (s *Store) handleRunQuery(w http.ResponseWriter, r *http.Request) { limit = int(l) } + // Check for startCursor - if present, we've already returned results + // For simplicity in the mock, return empty results when cursor is used + var startCursor string + if sc, ok := query["startCursor"].(string); ok { + startCursor = sc + } + // Find all entities of this kind s.mu.RLock() defer s.mu.RUnlock() var results []any + + // If there's a start cursor, we simulate pagination by returning no more results + // This is a simplified mock behavior - a real implementation would track position + if startCursor != "" { + // Return empty results to indicate end of pagination + response := map[string]any{ + "batch": map[string]any{ + "entityResults": []any{}, + "moreResults": "NO_MORE_RESULTS", + }, + } + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + log.Printf("failed to encode query response: %v", err) + } + return + } + for _, entity := range s.entities { keyData, ok := entity["key"].(map[string]any) if !ok { @@ -496,6 +512,13 @@ func (s *Store) handleRunQuery(w http.ResponseWriter, r *http.Request) { } if entityKind == kind { + // Apply filters if present + if filterMap, hasFilter := query["filter"].(map[string]any); hasFilter { + if !matchesFilter(entity, filterMap) { + continue + } + } + results = append(results, map[string]any{ "entity": entity, }) @@ -506,13 +529,33 @@ func (s *Store) handleRunQuery(w http.ResponseWriter, r *http.Request) { } } + // Add cursor if there are more results (for pagination testing) + var endCursor string + if limit > 0 && len(results) == limit { + // Generate a simple cursor to indicate more results might exist + endCursor = fmt.Sprintf("cursor-after-%d", limit) + } + + // Build response + batch := map[string]any{ + "entityResults": results, + } + + // Add cursor if available + if endCursor != "" { + batch["endCursor"] = endCursor + batch["moreResults"] = "MORE_RESULTS_AFTER_LIMIT" + } else { + batch["moreResults"] = "NO_MORE_RESULTS" + } + + response := map[string]any{ + "batch": batch, + } + w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(map[string]any{ - "batch": map[string]any{ - "entityResults": results, - }, - }); err != nil { + if err := json.NewEncoder(w).Encode(response); err != nil { log.Printf("failed to encode query response: %v", err) } } @@ -557,7 +600,7 @@ func handleBeginTransaction(w http.ResponseWriter, r *http.Request) { } // handleAllocateIDs handles :allocateIds requests. -func (*Store) handleAllocateIDs(w http.ResponseWriter, r *http.Request) { +func (s *Store) handleAllocateIDs(w http.ResponseWriter, r *http.Request) { var req struct { DatabaseID string `json:"databaseId"` Keys []map[string]any `json:"keys"` @@ -587,7 +630,10 @@ func (*Store) handleAllocateIDs(w http.ResponseWriter, r *http.Request) { } } - // Allocate IDs for incomplete keys + // Allocate unique IDs for incomplete keys + s.mu.Lock() + defer s.mu.Unlock() + allocatedKeys := make([]map[string]any, 0, len(req.Keys)) for _, keyData := range req.Keys { // Parse path to check if incomplete @@ -602,12 +648,13 @@ func (*Store) handleAllocateIDs(w http.ResponseWriter, r *http.Request) { continue } - // If it has no name or id, allocate an ID + // If it has no name or id, allocate a unique ID _, hasName := lastElem["name"] _, hasID := lastElem["id"] if !hasName && !hasID { - // Allocate a simple sequential ID - lastElem["id"] = "1001" // Simple mock ID + // Allocate a unique sequential ID + s.nextID++ + lastElem["id"] = strconv.FormatInt(s.nextID, 10) } allocatedKeys = append(allocatedKeys, keyData) diff --git a/ds9mock/mock_test.go b/pkg/mock/mock_test.go similarity index 80% rename from ds9mock/mock_test.go rename to pkg/mock/mock_test.go index 050825a..8403778 100644 --- a/ds9mock/mock_test.go +++ b/pkg/mock/mock_test.go @@ -1,29 +1,25 @@ -package ds9mock +package mock_test import ( "context" "testing" - "github.com/codeGROOVE-dev/ds9" + "github.com/codeGROOVE-dev/ds9/pkg/datastore" + "github.com/codeGROOVE-dev/ds9/pkg/mock" ) func TestNewStore(t *testing.T) { - store := NewStore() + store := mock.NewStore() if store == nil { t.Fatal("expected non-nil store") } - if store.entities == nil { - t.Error("expected initialized entities map") - } - - if len(store.entities) != 0 { - t.Errorf("expected empty store, got %d entities", len(store.entities)) - } + // Store entities are not directly accessible from outside the package + // but we can verify the store is functional through NewMockServers } func TestNewClient(t *testing.T) { - client, cleanup := NewClient(t) + client, cleanup := datastore.NewMockClient(t) defer cleanup() if client == nil { @@ -32,7 +28,7 @@ func TestNewClient(t *testing.T) { } func TestMockBasicOperations(t *testing.T) { - client, cleanup := NewClient(t) + client, cleanup := datastore.NewMockClient(t) defer cleanup() ctx := context.Background() @@ -43,7 +39,7 @@ func TestMockBasicOperations(t *testing.T) { } // Test Put - key := ds9.NameKey("TestKind", "test-key", nil) + key := datastore.NameKey("TestKind", "test-key", nil) entity := &TestEntity{ Name: "test", Value: 42, @@ -82,7 +78,7 @@ func TestMockBasicOperations(t *testing.T) { } func TestMockMultiOperations(t *testing.T) { - client, cleanup := NewClient(t) + client, cleanup := datastore.NewMockClient(t) defer cleanup() ctx := context.Background() @@ -93,10 +89,10 @@ func TestMockMultiOperations(t *testing.T) { } // Test PutMulti - keys := []*ds9.Key{ - ds9.NameKey("Multi", "key1", nil), - ds9.NameKey("Multi", "key2", nil), - ds9.NameKey("Multi", "key3", nil), + keys := []*datastore.Key{ + datastore.NameKey("Multi", "key1", nil), + datastore.NameKey("Multi", "key2", nil), + datastore.NameKey("Multi", "key3", nil), } entities := []TestEntity{ @@ -144,7 +140,7 @@ func TestMockMultiOperations(t *testing.T) { } func TestMockQuery(t *testing.T) { - client, cleanup := NewClient(t) + client, cleanup := datastore.NewMockClient(t) defer cleanup() ctx := context.Background() @@ -155,7 +151,7 @@ func TestMockQuery(t *testing.T) { // Put some entities for i := range 5 { - key := ds9.NameKey("QueryKind", string(rune('a'+i)), nil) + key := datastore.NameKey("QueryKind", string(rune('a'+i)), nil) entity := &TestEntity{Name: "test"} _, err := client.Put(ctx, key, entity) if err != nil { @@ -164,7 +160,7 @@ func TestMockQuery(t *testing.T) { } // Query for keys - query := ds9.NewQuery("QueryKind").KeysOnly() + query := datastore.NewQuery("QueryKind").KeysOnly() keys, err := client.AllKeys(ctx, query) if err != nil { t.Fatalf("AllKeys failed: %v", err) @@ -176,7 +172,7 @@ func TestMockQuery(t *testing.T) { } func TestMockQueryWithLimit(t *testing.T) { - client, cleanup := NewClient(t) + client, cleanup := datastore.NewMockClient(t) defer cleanup() ctx := context.Background() @@ -187,7 +183,7 @@ func TestMockQueryWithLimit(t *testing.T) { // Put entities for i := range 10 { - key := ds9.NameKey("LimitKind", string(rune('a'+i)), nil) + key := datastore.NameKey("LimitKind", string(rune('a'+i)), nil) entity := &TestEntity{Name: "test"} _, err := client.Put(ctx, key, entity) if err != nil { @@ -196,7 +192,7 @@ func TestMockQueryWithLimit(t *testing.T) { } // Query with limit - query := ds9.NewQuery("LimitKind").KeysOnly().Limit(3) + query := datastore.NewQuery("LimitKind").KeysOnly().Limit(3) keys, err := client.AllKeys(ctx, query) if err != nil { t.Fatalf("AllKeys with limit failed: %v", err) @@ -208,7 +204,7 @@ func TestMockQueryWithLimit(t *testing.T) { } func TestMockTransaction(t *testing.T) { - client, cleanup := NewClient(t) + client, cleanup := datastore.NewMockClient(t) defer cleanup() ctx := context.Background() @@ -218,7 +214,7 @@ func TestMockTransaction(t *testing.T) { } // Put initial entity - key := ds9.NameKey("TxKind", "counter", nil) + key := datastore.NameKey("TxKind", "counter", nil) entity := &TestEntity{Counter: 0} _, err := client.Put(ctx, key, entity) if err != nil { @@ -226,7 +222,7 @@ func TestMockTransaction(t *testing.T) { } // Run transaction - _, err = client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { + _, err = client.RunInTransaction(ctx, func(tx *datastore.Transaction) error { var current TestEntity if err := tx.Get(key, ¤t); err != nil { return err @@ -253,7 +249,7 @@ func TestMockTransaction(t *testing.T) { } func TestMockHierarchicalKeys(t *testing.T) { - client, cleanup := NewClient(t) + client, cleanup := datastore.NewMockClient(t) defer cleanup() ctx := context.Background() @@ -263,7 +259,7 @@ func TestMockHierarchicalKeys(t *testing.T) { } // Create parent key - parentKey := ds9.NameKey("Parent", "p1", nil) + parentKey := datastore.NameKey("Parent", "p1", nil) parentEntity := &TestEntity{Name: "parent"} _, err := client.Put(ctx, parentKey, parentEntity) if err != nil { @@ -271,7 +267,7 @@ func TestMockHierarchicalKeys(t *testing.T) { } // Create child key - childKey := ds9.NameKey("Child", "c1", parentKey) + childKey := datastore.NameKey("Child", "c1", parentKey) childEntity := &TestEntity{Name: "child"} _, err = client.Put(ctx, childKey, childEntity) if err != nil { @@ -291,7 +287,7 @@ func TestMockHierarchicalKeys(t *testing.T) { } func TestMockIDKeys(t *testing.T) { - client, cleanup := NewClient(t) + client, cleanup := datastore.NewMockClient(t) defer cleanup() ctx := context.Background() @@ -301,7 +297,7 @@ func TestMockIDKeys(t *testing.T) { } // Use ID key - key := ds9.IDKey("IDKind", 12345, nil) + key := datastore.IDKey("IDKind", 12345, nil) entity := &TestEntity{Value: 99} _, err := client.Put(ctx, key, entity) if err != nil { @@ -321,13 +317,13 @@ func TestMockIDKeys(t *testing.T) { } func TestMockEmptyQuery(t *testing.T) { - client, cleanup := NewClient(t) + client, cleanup := datastore.NewMockClient(t) defer cleanup() ctx := context.Background() // Query non-existent kind - query := ds9.NewQuery("NonExistent").KeysOnly() + query := datastore.NewQuery("NonExistent").KeysOnly() keys, err := client.AllKeys(ctx, query) if err != nil { t.Fatalf("AllKeys on empty kind failed: %v", err) @@ -339,13 +335,13 @@ func TestMockEmptyQuery(t *testing.T) { } func TestMockDeleteNonExistent(t *testing.T) { - client, cleanup := NewClient(t) + client, cleanup := datastore.NewMockClient(t) defer cleanup() ctx := context.Background() // Try to delete non-existent entity (should not error) - key := ds9.NameKey("Test", "nonexistent", nil) + key := datastore.NameKey("Test", "nonexistent", nil) err := client.Delete(ctx, key) if err != nil { t.Errorf("Delete of non-existent entity should not error, got: %v", err) @@ -353,7 +349,7 @@ func TestMockDeleteNonExistent(t *testing.T) { } func TestMockConcurrentAccess(t *testing.T) { - client, cleanup := NewClient(t) + client, cleanup := datastore.NewMockClient(t) defer cleanup() ctx := context.Background() @@ -373,7 +369,7 @@ func TestMockConcurrentAccess(t *testing.T) { defer func() { done <- true }() for i := range operations { - key := ds9.NameKey("ConcurrentKind", string(rune('a'+id%10)), nil) + key := datastore.NameKey("ConcurrentKind", string(rune('a'+id%10)), nil) entity := &TestEntity{Value: int64(i)} // Mix of reads and writes @@ -400,7 +396,7 @@ func TestMockConcurrentAccess(t *testing.T) { } func TestMockConcurrentQuery(t *testing.T) { - client, cleanup := NewClient(t) + client, cleanup := datastore.NewMockClient(t) defer cleanup() ctx := context.Background() @@ -411,7 +407,7 @@ func TestMockConcurrentQuery(t *testing.T) { // Populate some data for i := range 20 { - key := ds9.NameKey("QueryConcurrent", string(rune('a'+i)), nil) + key := datastore.NameKey("QueryConcurrent", string(rune('a'+i)), nil) entity := &TestEntity{Name: "test"} _, err := client.Put(ctx, key, entity) if err != nil { @@ -427,7 +423,7 @@ func TestMockConcurrentQuery(t *testing.T) { go func() { defer func() { done <- true }() - query := ds9.NewQuery("QueryConcurrent").KeysOnly() + query := datastore.NewQuery("QueryConcurrent").KeysOnly() keys, err := client.AllKeys(ctx, query) if err != nil { t.Errorf("AllKeys failed: %v", err) diff --git a/transaction_test.go b/transaction_test.go deleted file mode 100644 index 5c094ad..0000000 --- a/transaction_test.go +++ /dev/null @@ -1,402 +0,0 @@ -package ds9_test - -import ( - "context" - "errors" - "testing" - "time" - - "github.com/codeGROOVE-dev/ds9" - "github.com/codeGROOVE-dev/ds9/ds9mock" -) - -type txTestEntity struct { - Time time.Time - Name string - Count int64 -} - -func TestTransactionDelete(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - key := ds9.NameKey("TxDeleteTest", "test", nil) - - // Create an entity - entity := &txTestEntity{ - Name: "test", - Count: 42, - Time: time.Now().UTC().Truncate(time.Microsecond), - } - if _, err := client.Put(ctx, key, entity); err != nil { - t.Fatalf("Put failed: %v", err) - } - - // Delete in transaction - _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - return tx.Delete(key) - }) - if err != nil { - t.Fatalf("Transaction failed: %v", err) - } - - // Verify deletion - var result txTestEntity - err = client.Get(ctx, key, &result) - if !errors.Is(err, ds9.ErrNoSuchEntity) { - t.Errorf("Expected ds9.ErrNoSuchEntity, got %v", err) - } -} - -func TestTransactionDeleteMulti(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Create multiple entities - keys := []*ds9.Key{ - ds9.NameKey("TxDeleteMultiTest", "test1", nil), - ds9.NameKey("TxDeleteMultiTest", "test2", nil), - ds9.NameKey("TxDeleteMultiTest", "test3", nil), - } - - entities := []txTestEntity{ - {Name: "test1", Count: 1, Time: time.Now().UTC().Truncate(time.Microsecond)}, - {Name: "test2", Count: 2, Time: time.Now().UTC().Truncate(time.Microsecond)}, - {Name: "test3", Count: 3, Time: time.Now().UTC().Truncate(time.Microsecond)}, - } - - if _, err := client.PutMulti(ctx, keys, entities); err != nil { - t.Fatalf("PutMulti failed: %v", err) - } - - // Delete in transaction - _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - return tx.DeleteMulti(keys) - }) - if err != nil { - t.Fatalf("Transaction failed: %v", err) - } -} - -func TestTransactionGetMulti(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Create multiple entities - keys := []*ds9.Key{ - ds9.NameKey("TxGetMultiTest", "test1", nil), - ds9.NameKey("TxGetMultiTest", "test2", nil), - } - - entities := []txTestEntity{ - {Name: "test1", Count: 1, Time: time.Now().UTC().Truncate(time.Microsecond)}, - {Name: "test2", Count: 2, Time: time.Now().UTC().Truncate(time.Microsecond)}, - } - - if _, err := client.PutMulti(ctx, keys, entities); err != nil { - t.Fatalf("PutMulti failed: %v", err) - } - - // Get in transaction - _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - results := make([]txTestEntity, 2) - if err := tx.GetMulti(keys, &results); err != nil { - return err - } - - if results[0].Name != "test1" { - t.Errorf("Expected Name 'test1', got '%s'", results[0].Name) - } - if results[1].Name != "test2" { - t.Errorf("Expected Name 'test2', got '%s'", results[1].Name) - } - - return nil - }) - if err != nil { - t.Fatalf("Transaction failed: %v", err) - } -} - -func TestTransactionPutMulti(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - keys := []*ds9.Key{ - ds9.NameKey("TxPutMultiTest", "test1", nil), - ds9.NameKey("TxPutMultiTest", "test2", nil), - } - - entities := []txTestEntity{ - {Name: "test1", Count: 1, Time: time.Now().UTC().Truncate(time.Microsecond)}, - {Name: "test2", Count: 2, Time: time.Now().UTC().Truncate(time.Microsecond)}, - } - - // Put in transaction - _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - _, err := tx.PutMulti(keys, entities) - return err - }) - if err != nil { - t.Fatalf("Transaction failed: %v", err) - } - - // Verify entities were created - results := make([]txTestEntity, 2) - if err := client.GetMulti(ctx, keys, &results); err != nil { - t.Fatalf("GetMulti failed: %v", err) - } - - if results[0].Name != "test1" { - t.Errorf("Expected Name 'test1', got '%s'", results[0].Name) - } - if results[1].Name != "test2" { - t.Errorf("Expected Name 'test2', got '%s'", results[1].Name) - } -} - -func TestTransactionMixedOperations(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - key1 := ds9.NameKey("TxMixedTest", "read", nil) - key2 := ds9.NameKey("TxMixedTest", "write", nil) - key3 := ds9.NameKey("TxMixedTest", "delete", nil) - - // Create initial entities - if _, err := client.Put(ctx, key1, &txTestEntity{Name: "read", Count: 1, Time: time.Now().UTC().Truncate(time.Microsecond)}); err != nil { - t.Fatalf("Put failed: %v", err) - } - if _, err := client.Put(ctx, key3, &txTestEntity{Name: "delete", Count: 3, Time: time.Now().UTC().Truncate(time.Microsecond)}); err != nil { - t.Fatalf("Put failed: %v", err) - } - - // Transaction with mixed operations - _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - // Read - var entity txTestEntity - if err := tx.Get(key1, &entity); err != nil { - return err - } - - // Write - if _, err := tx.Put(key2, &txTestEntity{Name: "write", Count: 2, Time: time.Now().UTC().Truncate(time.Microsecond)}); err != nil { - return err - } - - // Delete - if err := tx.Delete(key3); err != nil { - return err - } - - return nil - }) - if err != nil { - t.Fatalf("Transaction failed: %v", err) - } - - // Verify write succeeded - var entity txTestEntity - if err := client.Get(ctx, key2, &entity); err != nil { - t.Fatalf("Get after transaction failed: %v", err) - } - if entity.Name != "write" { - t.Errorf("Expected Name 'write', got '%s'", entity.Name) - } - - // Verify delete succeeded - err = client.Get(ctx, key3, &entity) - if !errors.Is(err, ds9.ErrNoSuchEntity) { - t.Errorf("Expected ds9.ErrNoSuchEntity for deleted entity, got %v", err) - } -} - -func TestNewTransactionCommit(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - tx, err := client.NewTransaction(ctx) - if err != nil { - t.Fatalf("NewTransaction failed: %v", err) - } - - key := ds9.NameKey("TxCommitTest", "test", nil) - entity := &txTestEntity{ - Name: "test", - Count: 42, - Time: time.Now().UTC().Truncate(time.Microsecond), - } - - if _, err := tx.Put(key, entity); err != nil { - t.Fatalf("tx.Put failed: %v", err) - } - - commit, err := tx.Commit() - if err != nil { - t.Fatalf("tx.Commit failed: %v", err) - } - - if commit == nil { - t.Error("Expected non-nil Commit") - } - - // Verify entity was created - var result txTestEntity - if err := client.Get(ctx, key, &result); err != nil { - t.Fatalf("Get after commit failed: %v", err) - } - - if result.Name != "test" { - t.Errorf("Expected Name 'test', got '%s'", result.Name) - } -} - -func TestNewTransactionRollback(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - tx, err := client.NewTransaction(ctx) - if err != nil { - t.Fatalf("NewTransaction failed: %v", err) - } - - key := ds9.NameKey("TxRollbackTest", "test", nil) - entity := &txTestEntity{ - Name: "test", - Count: 42, - Time: time.Now().UTC().Truncate(time.Microsecond), - } - - if _, err := tx.Put(key, entity); err != nil { - t.Fatalf("tx.Put failed: %v", err) - } - - if err := tx.Rollback(); err != nil { - t.Fatalf("tx.Rollback failed: %v", err) - } - - // After rollback, transaction should not commit (but we can't verify internal state) -} - -func TestTransactionDeleteNilKey(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - return tx.Delete(nil) - }) - - if err == nil { - t.Error("Expected error for nil key") - } -} - -func TestTransactionGetMultiLengthMismatch(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - keys := []*ds9.Key{ - ds9.NameKey("Test", "test1", nil), - ds9.NameKey("Test", "test2", nil), - } - - _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - results := make([]txTestEntity, 1) // Wrong length - return tx.GetMulti(keys, &results) - }) - - if err == nil { - t.Error("Expected error for length mismatch") - } -} - -func TestTransactionPutMultiLengthMismatch(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - keys := []*ds9.Key{ - ds9.NameKey("Test", "test1", nil), - } - - entities := []txTestEntity{ - {Name: "test1", Count: 1, Time: time.Now().UTC()}, - {Name: "test2", Count: 2, Time: time.Now().UTC()}, - } - - _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { - _, err := tx.PutMulti(keys, entities) - return err - }) - - if err == nil { - t.Error("Expected error for length mismatch") - } -} - -func TestTransactionWithOptions(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - // Test that transaction options are accepted - _, err := client.NewTransaction(ctx, ds9.MaxAttempts(5)) - if err != nil { - t.Fatalf("NewTransaction with options failed: %v", err) - } - - // Test with read time option - _, err = client.NewTransaction(ctx, ds9.WithReadTime(time.Now().UTC())) - if err != nil { - t.Fatalf("NewTransaction with ds9.WithReadTime failed: %v", err) - } -} - -func TestNewTransactionMultipleOperations(t *testing.T) { - client, cleanup := ds9mock.NewClient(t) - defer cleanup() - - ctx := context.Background() - - tx, err := client.NewTransaction(ctx) - if err != nil { - t.Fatalf("NewTransaction failed: %v", err) - } - - // Multiple puts - for i := range 3 { - key := ds9.IDKey("TxMultiOp", int64(i+1), nil) - entity := &txTestEntity{ - Name: "test", - Count: int64(i), - Time: time.Now().UTC().Truncate(time.Microsecond), - } - if _, err := tx.Put(key, entity); err != nil { - t.Fatalf("tx.Put failed: %v", err) - } - } - - // Commit the transaction - if _, err := tx.Commit(); err != nil { - t.Fatalf("tx.Commit failed: %v", err) - } -}