diff --git a/README.md b/README.md index 8da3217..bf2fce7 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,14 @@ client.Put(ctx, key, &task) client.Get(ctx, key, &task) ``` -**Supported:** Get, Put, Delete, GetMulti, PutMulti, DeleteMulti, RunInTransaction, AllKeys, NameKey, IDKey, parent keys. +**Supported:** +- **CRUD**: Get, Put, Delete, GetMulti, PutMulti, DeleteMulti +- **Transactions**: RunInTransaction, NewTransaction, Commit, Rollback +- **Queries**: Filter, Order, Limit, Offset, Ancestor, Project, Distinct, DistinctOn, Namespace, Run (iterator), Count +- **Cursors**: Start, End, DecodeCursor +- **Keys**: NameKey, IDKey, IncompleteKey, AllocateIDs, parent keys +- **Mutations**: NewInsert, NewUpdate, NewUpsert, NewDelete, Mutate +- **Types**: string, int, int64, int32, bool, float64, time.Time, slices ([]string, []int64, []int, []float64, []bool) ## Migrating from Official Client @@ -35,6 +42,6 @@ Use `ds9mock` package for in-memory testing. See [TESTING.md](TESTING.md) for in ## Limitations -Not supported: property filters, ordering, cursors, ancestor queries, slices/arrays, embedded structs, key allocation. +Not supported: embedded structs, nested slices, map types, some advanced query features (streaming aggregations, OR filters). See [example/](example/) for usage. Apache 2.0 licensed. diff --git a/auth/auth_test.go b/auth/auth_test.go index 00f8a7a..a29d73f 100644 --- a/auth/auth_test.go +++ b/auth/auth_test.go @@ -41,13 +41,13 @@ func TestSetMetadataURL(t *testing.T) { func TestAccessTokenFromMetadata(t *testing.T) { tests := []struct { - name string - statusCode int response any - wantErr bool + name string wantToken string errContains string metadataFlavor string + statusCode int + wantErr bool }{ { name: "success", @@ -214,12 +214,12 @@ func TestAccessTokenFromADC(t *testing.T) { defer tokenServer.Close() tests := []struct { + setupEnv func() name string credsData string - setupEnv func() - wantErr bool errContains string wantToken string + wantErr bool }{ { name: "success with valid credentials", @@ -310,11 +310,11 @@ func TestAccessTokenFromADC(t *testing.T) { func TestProjectID(t *testing.T) { tests := []struct { name string - statusCode int response string - wantErr bool wantProject string errContains string + statusCode int + wantErr bool }{ { name: "success", @@ -414,10 +414,10 @@ func TestProjectIDMetadataServerDown(t *testing.T) { func TestExchangeRefreshTokenErrors(t *testing.T) { tests := []struct { name string - statusCode int response string - wantErr bool errContains string + statusCode int + wantErr bool }{ { name: "unauthorized", diff --git a/datastore.go b/datastore.go index 1b56079..7564756 100644 --- a/datastore.go +++ b/datastore.go @@ -8,6 +8,7 @@ package ds9 import ( "bytes" "context" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -50,6 +51,15 @@ var ( MaxIdleConnsPerHost: 2, }, } + + // operatorMap converts shorthand operators to Datastore API operators. + operatorMap = map[string]string{ + "=": "EQUAL", + "<": "LESS_THAN", + "<=": "LESS_THAN_OR_EQUAL", + ">": "GREATER_THAN", + ">=": "GREATER_THAN_OR_EQUAL", + } ) // SetTestURLs configures custom metadata and API URLs for testing. @@ -153,6 +163,271 @@ func IDKey(kind string, id int64, parent *Key) *Key { } } +// IncompleteKey creates a new incomplete key. +// The key will be completed (assigned an ID) when the entity is saved. +// API compatible with cloud.google.com/go/datastore. +func IncompleteKey(kind string, parent *Key) *Key { + return &Key{ + Kind: kind, + Parent: parent, + } +} + +// Incomplete returns true if the key does not have an ID or Name. +// API compatible with cloud.google.com/go/datastore. +func (k *Key) Incomplete() bool { + return k.ID == 0 && k.Name == "" +} + +// Equal returns true if this key is equal to the other key. +// API compatible with cloud.google.com/go/datastore. +func (k *Key) Equal(other *Key) bool { + if k == nil && other == nil { + return true + } + if k == nil || other == nil { + return false + } + if k.Kind != other.Kind || k.Name != other.Name || k.ID != other.ID { + return false + } + // Recursively check parent keys + return k.Parent.Equal(other.Parent) +} + +// String returns a human-readable string representation of the key. +// API compatible with cloud.google.com/go/datastore. +func (k *Key) String() string { + if k == nil { + return "" + } + + var parts []string + for curr := k; curr != nil; curr = curr.Parent { + var part string + switch { + case curr.Name != "": + part = fmt.Sprintf("%s,%q", curr.Kind, curr.Name) + case curr.ID != 0: + part = fmt.Sprintf("%s,%d", curr.Kind, curr.ID) + default: + part = fmt.Sprintf("%s,incomplete", curr.Kind) + } + // Prepend to maintain correct order (root to leaf) + parts = append([]string{part}, parts...) + } + + return "/" + strings.Join(parts, "/") +} + +// Encode returns an opaque representation of the key. +// API compatible with cloud.google.com/go/datastore. +func (k *Key) Encode() string { + if k == nil { + return "" + } + + // Convert key to JSON representation + keyJSON := keyToJSON(k) + + // Marshal to JSON bytes + jsonBytes, err := json.Marshal(keyJSON) + if err != nil { + return "" + } + + // Base64 encode + return base64.URLEncoding.EncodeToString(jsonBytes) +} + +// DecodeKey decodes a key from its opaque representation. +// API compatible with cloud.google.com/go/datastore. +func DecodeKey(encoded string) (*Key, error) { + if encoded == "" { + return nil, errors.New("empty encoded key") + } + + // Base64 decode + jsonBytes, err := base64.URLEncoding.DecodeString(encoded) + if err != nil { + return nil, fmt.Errorf("failed to decode base64: %w", err) + } + + // Unmarshal JSON + var keyData any + if err := json.Unmarshal(jsonBytes, &keyData); err != nil { + return nil, fmt.Errorf("failed to unmarshal JSON: %w", err) + } + + // Convert from JSON representation + key, err := keyFromJSON(keyData) + if err != nil { + return nil, fmt.Errorf("failed to parse key: %w", err) + } + + return key, nil +} + +// Cursor represents a query cursor for pagination. +// API compatible with cloud.google.com/go/datastore. +type Cursor string + +// String returns the cursor as a string. +func (c Cursor) String() string { + return string(c) +} + +// DecodeCursor decodes a cursor string. +// API compatible with cloud.google.com/go/datastore. +func DecodeCursor(s string) (Cursor, error) { + if s == "" { + return "", errors.New("empty cursor string") + } + return Cursor(s), nil +} + +// Iterator is an iterator for query results. +// API compatible with cloud.google.com/go/datastore. +type Iterator struct { + ctx context.Context //nolint:containedctx // Required for API compatibility with cloud.google.com/go/datastore + client *Client + query *Query + results []iteratorResult + index int + err error + cursor Cursor + fetchNext bool +} + +type iteratorResult struct { + key *Key + entity map[string]any + cursor Cursor +} + +// Next advances the iterator and returns the next key and destination. +// It returns Done when no more results are available. +// API compatible with cloud.google.com/go/datastore. +func (it *Iterator) Next(dst any) (*Key, error) { + // Check if we need to fetch more results + if it.index >= len(it.results) { + if it.err != nil { + return nil, it.err + } + if !it.fetchNext { + return nil, ErrDone + } + + // Fetch next batch + if err := it.fetch(); err != nil { + it.err = err + return nil, err + } + + if len(it.results) == 0 { + return nil, ErrDone + } + } + + result := it.results[it.index] + it.index++ + it.cursor = result.cursor + + // Decode entity into dst + if err := decodeEntity(result.entity, dst); err != nil { + return nil, err + } + + return result.key, nil +} + +// Cursor returns the cursor for the iterator's current position. +// API compatible with cloud.google.com/go/datastore. +func (it *Iterator) Cursor() (Cursor, error) { + if it.cursor == "" { + return "", errors.New("no cursor available") + } + return it.cursor, nil +} + +// fetch retrieves the next batch of results. +func (it *Iterator) fetch() error { + token, err := auth.AccessToken(it.ctx) + if err != nil { + return fmt.Errorf("failed to get access token: %w", err) + } + + // Build query with current cursor as start + q := *it.query + if it.cursor != "" { + q.startCursor = it.cursor + } + + queryObj := buildQueryMap(&q) + reqBody := map[string]any{"query": queryObj} + if it.client.databaseID != "" { + reqBody["databaseId"] = it.client.databaseID + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + // URL-encode project ID to prevent injection attacks + reqURL := fmt.Sprintf("%s/projects/%s:runQuery", apiURL, neturl.PathEscape(it.client.projectID)) + body, err := doRequest(it.ctx, it.client.logger, reqURL, jsonData, token, it.client.projectID, it.client.databaseID) + if err != nil { + return err + } + + var result struct { + Batch struct { + EntityResults []struct { + Entity map[string]any `json:"entity"` + Cursor string `json:"cursor"` + } `json:"entityResults"` + MoreResults string `json:"moreResults"` + EndCursor string `json:"endCursor"` + SkippedResults int `json:"skippedResults"` + } `json:"batch"` + } + + if err := json.Unmarshal(body, &result); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + // Convert results to iterator format + it.results = make([]iteratorResult, 0, len(result.Batch.EntityResults)) + for _, er := range result.Batch.EntityResults { + key, err := keyFromJSON(er.Entity["key"]) + if err != nil { + return err + } + + it.results = append(it.results, iteratorResult{ + key: key, + entity: er.Entity, + cursor: Cursor(er.Cursor), + }) + } + + it.index = 0 + + // Check if there are more results + moreResults := result.Batch.MoreResults + it.fetchNext = moreResults == "NOT_FINISHED" || moreResults == "MORE_RESULTS_AFTER_LIMIT" || moreResults == "MORE_RESULTS_AFTER_CURSOR" + + if result.Batch.EndCursor != "" { + it.cursor = Cursor(result.Batch.EndCursor) + } + + return nil +} + +// ErrDone is returned by Iterator.Next when no more results are available. +var ErrDone = errors.New("datastore: no more results") + // doRequest performs an HTTP request with exponential backoff retries. // Returns an error if the status code is not 200 OK. func doRequest(ctx context.Context, logger *slog.Logger, url string, jsonData []byte, token, projectID, databaseID string) ([]byte, error) { @@ -658,6 +933,95 @@ func (c *Client) DeleteAllByKind(ctx context.Context, kind string) error { return nil } +// AllocateIDs allocates IDs for incomplete keys. +// Returns keys with IDs filled in. Complete keys are returned unchanged. +// API compatible with cloud.google.com/go/datastore. +func (c *Client) AllocateIDs(ctx context.Context, keys []*Key) ([]*Key, error) { + if len(keys) == 0 { + return keys, nil + } + + c.logger.DebugContext(ctx, "allocating IDs", "count", len(keys)) + + // Separate incomplete and complete keys + var incompleteKeys []*Key + var incompleteIndices []int + for i, key := range keys { + if key != nil && key.Incomplete() { + incompleteKeys = append(incompleteKeys, key) + incompleteIndices = append(incompleteIndices, i) + } + } + + // If no incomplete keys, return original slice + if len(incompleteKeys) == 0 { + c.logger.DebugContext(ctx, "no incomplete keys to allocate") + return keys, nil + } + + token, err := auth.AccessToken(ctx) + if err != nil { + c.logger.ErrorContext(ctx, "failed to get access token", "error", err) + return nil, fmt.Errorf("failed to get access token: %w", err) + } + + // Build request with incomplete keys + reqKeys := make([]map[string]any, len(incompleteKeys)) + for i, key := range incompleteKeys { + reqKeys[i] = keyToJSON(key) + } + + reqBody := map[string]any{ + "keys": reqKeys, + } + if c.databaseID != "" { + reqBody["databaseId"] = c.databaseID + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + // URL-encode project ID to prevent injection attacks + reqURL := fmt.Sprintf("%s/projects/%s:allocateIds", apiURL, neturl.PathEscape(c.projectID)) + body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) + if err != nil { + c.logger.ErrorContext(ctx, "allocateIds request failed", "error", err) + return nil, err + } + + var resp struct { + Keys []map[string]any `json:"keys"` + } + if err := json.Unmarshal(body, &resp); err != nil { + c.logger.ErrorContext(ctx, "failed to parse response", "error", err) + return nil, fmt.Errorf("failed to parse allocateIds response: %w", err) + } + + // Parse allocated keys + allocatedKeys := make([]*Key, len(resp.Keys)) + for i, keyData := range resp.Keys { + key, err := keyFromJSON(keyData) + if err != nil { + c.logger.ErrorContext(ctx, "failed to parse allocated key", "index", i, "error", err) + return nil, fmt.Errorf("failed to parse allocated key at index %d: %w", i, err) + } + allocatedKeys[i] = key + } + + // Create result slice with allocated keys in correct positions + result := make([]*Key, len(keys)) + copy(result, keys) + for i, idx := range incompleteIndices { + result[idx] = allocatedKeys[i] + } + + c.logger.DebugContext(ctx, "IDs allocated successfully", "count", len(allocatedKeys)) + return result, nil +} + // keyToJSON converts a Key to its JSON representation. // Supports hierarchical keys with parent relationships. func keyToJSON(key *Key) map[string]any { @@ -772,7 +1136,57 @@ func encodeValue(v any) (any, error) { return map[string]any{"doubleValue": val}, nil case time.Time: return map[string]any{"timestampValue": val.Format(time.RFC3339Nano)}, nil + case []string: + values := make([]map[string]any, len(val)) + for i, s := range val { + values[i] = map[string]any{"stringValue": s} + } + return map[string]any{"arrayValue": map[string]any{"values": values}}, nil + case []int64: + values := make([]map[string]any, len(val)) + for i, n := range val { + values[i] = map[string]any{"integerValue": strconv.FormatInt(n, 10)} + } + return map[string]any{"arrayValue": map[string]any{"values": values}}, nil + case []int: + values := make([]map[string]any, len(val)) + for i, n := range val { + values[i] = map[string]any{"integerValue": strconv.Itoa(n)} + } + return map[string]any{"arrayValue": map[string]any{"values": values}}, nil + case []float64: + values := make([]map[string]any, len(val)) + for i, f := range val { + values[i] = map[string]any{"doubleValue": f} + } + return map[string]any{"arrayValue": map[string]any{"values": values}}, nil + case []bool: + values := make([]map[string]any, len(val)) + for i, b := range val { + values[i] = map[string]any{"booleanValue": b} + } + return map[string]any{"arrayValue": map[string]any{"values": values}}, nil default: + // Try to handle slices/arrays via reflection + rv := reflect.ValueOf(v) + if rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array { + length := rv.Len() + values := make([]map[string]any, length) + for i := 0; i < length; i++ { + elem := rv.Index(i).Interface() + encodedElem, err := encodeValue(elem) + if err != nil { + return nil, fmt.Errorf("failed to encode array element %d: %w", i, err) + } + // encodedElem is already a map[string]any with the type wrapper + if m, ok := encodedElem.(map[string]any); ok { + values[i] = m + } else { + return nil, fmt.Errorf("unexpected encoded value type for element %d", i) + } + } + return map[string]any{"arrayValue": map[string]any{"values": values}}, nil + } return nil, fmt.Errorf("unsupported type: %T", v) } } @@ -893,6 +1307,48 @@ func decodeValue(prop map[string]any, dst reflect.Value) error { } } + if val, ok := prop["arrayValue"]; ok { + if dst.Kind() != reflect.Slice { + return fmt.Errorf("cannot decode array into non-slice type: %s", dst.Type()) + } + + arrayMap, ok := val.(map[string]any) + if !ok { + return errors.New("invalid arrayValue format") + } + + valuesAny, ok := arrayMap["values"] + if !ok { + // Empty array + dst.Set(reflect.MakeSlice(dst.Type(), 0, 0)) + return nil + } + + values, ok := valuesAny.([]any) + if !ok { + return errors.New("invalid arrayValue.values format") + } + + // Create slice with appropriate capacity + slice := reflect.MakeSlice(dst.Type(), len(values), len(values)) + + // Decode each element + for i, elemAny := range values { + elemMap, ok := elemAny.(map[string]any) + if !ok { + return fmt.Errorf("invalid array element %d format", i) + } + + elemValue := slice.Index(i) + if err := decodeValue(elemMap, elemValue); err != nil { + return fmt.Errorf("failed to decode array element %d: %w", i, err) + } + } + + dst.Set(slice) + return nil + } + if _, ok := prop["nullValue"]; ok { // Set to zero value dst.Set(reflect.Zero(dst.Type())) @@ -904,9 +1360,29 @@ func decodeValue(prop map[string]any, dst reflect.Value) error { // Query represents a Datastore query. type Query struct { - kind string - keysOnly bool - limit int + ancestor *Key + kind string + filters []queryFilter + orders []queryOrder + projection []string + distinctOn []string + namespace string + startCursor Cursor + endCursor Cursor + limit int + offset int + keysOnly bool +} + +type queryFilter struct { + value any + property string + operator string +} + +type queryOrder struct { + property string + direction string // "ASCENDING" or "DESCENDING" } // NewQuery creates a new query for the given kind. @@ -928,47 +1404,298 @@ func (q *Query) Limit(limit int) *Query { return q } -// AllKeys returns all keys matching the query. -// This is a convenience method for KeysOnly queries. -func (c *Client) AllKeys(ctx context.Context, q *Query) ([]*Key, error) { - if !q.keysOnly { - c.logger.WarnContext(ctx, "AllKeys called on non-KeysOnly query") - return nil, errors.New("AllKeys requires KeysOnly query") +// Offset sets the number of results to skip before returning. +// API compatible with cloud.google.com/go/datastore. +func (q *Query) Offset(offset int) *Query { + q.offset = offset + return q +} + +// Filter adds a property filter to the query. +// The filterStr should be in the format "Property Operator" (e.g., "Count >", "Name ="). +// Deprecated: Use FilterField instead. API compatible with cloud.google.com/go/datastore. +func (q *Query) Filter(filterStr string, value any) *Query { + // Parse the filter string to extract property and operator + parts := strings.Fields(filterStr) + if len(parts) != 2 { + // Invalid filter format, but we'll be lenient + return q } - c.logger.DebugContext(ctx, "querying for keys", "kind", q.kind, "limit", q.limit) + property := parts[0] + op := parts[1] - token, err := auth.AccessToken(ctx) - if err != nil { - c.logger.ErrorContext(ctx, "failed to get access token", "error", err) - return nil, fmt.Errorf("failed to get access token: %w", err) + operator, ok := operatorMap[op] + if !ok { + operator = "EQUAL" } - query := map[string]any{ - "kind": []map[string]any{{"name": q.kind}}, - "projection": []map[string]any{{"property": map[string]string{"name": "__key__"}}}, - } - if q.limit > 0 { - query["limit"] = q.limit - } + q.filters = append(q.filters, queryFilter{ + property: property, + operator: operator, + value: value, + }) - reqBody := map[string]any{"query": query} - if c.databaseID != "" { - reqBody["databaseId"] = c.databaseID - } + return q +} - jsonData, err := json.Marshal(reqBody) - if err != nil { - c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) - return nil, fmt.Errorf("failed to marshal request: %w", err) +// FilterField adds a property filter to the query with explicit operator. +// API compatible with cloud.google.com/go/datastore. +func (q *Query) FilterField(fieldName, operator string, value any) *Query { + dsOperator, ok := operatorMap[operator] + if !ok { + dsOperator = operator // Use as-is if not in map (might already be EQUAL, etc.) } - // URL-encode project ID to prevent injection attacks - reqURL := fmt.Sprintf("%s/projects/%s:runQuery", apiURL, neturl.PathEscape(c.projectID)) - body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) - if err != nil { - c.logger.ErrorContext(ctx, "query request failed", "error", err, "kind", q.kind) - return nil, err + q.filters = append(q.filters, queryFilter{ + property: fieldName, + operator: dsOperator, + value: value, + }) + + return q +} + +// Order sets the order in which results are returned. +// Prefix the property name with "-" for descending order (e.g., "-Created"). +// API compatible with cloud.google.com/go/datastore. +func (q *Query) Order(fieldName string) *Query { + direction := "ASCENDING" + property := fieldName + + if strings.HasPrefix(fieldName, "-") { + direction = "DESCENDING" + property = fieldName[1:] + } + + q.orders = append(q.orders, queryOrder{ + property: property, + direction: direction, + }) + + return q +} + +// Ancestor sets an ancestor filter for the query. +// Only entities with the given ancestor will be returned. +// API compatible with cloud.google.com/go/datastore. +func (q *Query) Ancestor(ancestor *Key) *Query { + q.ancestor = ancestor + return q +} + +// Project sets the fields to be projected (returned) in the query results. +// API compatible with cloud.google.com/go/datastore. +func (q *Query) Project(fieldNames ...string) *Query { + q.projection = fieldNames + return q +} + +// Distinct marks the query to return only distinct results. +// This is equivalent to DistinctOn with all projected fields. +// API compatible with cloud.google.com/go/datastore. +func (q *Query) Distinct() *Query { + // Distinct without fields means distinct on projection + // This will be handled in buildQueryMap + if len(q.projection) > 0 { + q.distinctOn = q.projection + } + return q +} + +// DistinctOn returns a query that removes duplicates based on the given field names. +// API compatible with cloud.google.com/go/datastore. +func (q *Query) DistinctOn(fieldNames ...string) *Query { + q.distinctOn = fieldNames + return q +} + +// Namespace sets the namespace for the query. +// API compatible with cloud.google.com/go/datastore. +func (q *Query) Namespace(ns string) *Query { + q.namespace = ns + return q +} + +// Start sets the starting cursor for the query results. +// API compatible with cloud.google.com/go/datastore. +func (q *Query) Start(c Cursor) *Query { + q.startCursor = c + return q +} + +// End sets the ending cursor for the query results. +// API compatible with cloud.google.com/go/datastore. +func (q *Query) End(c Cursor) *Query { + q.endCursor = c + return q +} + +// buildQueryMap creates a Datastore API query map from a Query object. +func buildQueryMap(query *Query) map[string]any { + queryMap := map[string]any{ + "kind": []map[string]any{{"name": query.kind}}, + } + + // Add namespace via partition ID if specified + if query.namespace != "" { + queryMap["partitionId"] = map[string]any{ + "namespaceId": query.namespace, + } + } + + // Add filters + if len(query.filters) > 0 { + var compositeFilters []map[string]any + for _, f := range query.filters { + encodedVal, err := encodeValue(f.value) + if err != nil { + // Skip invalid filters + continue + } + compositeFilters = append(compositeFilters, map[string]any{ + "propertyFilter": map[string]any{ + "property": map[string]string{"name": f.property}, + "op": f.operator, + "value": encodedVal, + }, + }) + } + + if len(compositeFilters) == 1 { + queryMap["filter"] = compositeFilters[0] + } else if len(compositeFilters) > 1 { + queryMap["filter"] = map[string]any{ + "compositeFilter": map[string]any{ + "op": "AND", + "filters": compositeFilters, + }, + } + } + } + + // Add ancestor filter + if query.ancestor != nil { + ancestorFilter := map[string]any{ + "propertyFilter": map[string]any{ + "property": map[string]string{"name": "__key__"}, + "op": "HAS_ANCESTOR", + "value": map[string]any{"keyValue": keyToJSON(query.ancestor)}, + }, + } + + // Combine with existing filters if present + if existingFilter, ok := queryMap["filter"]; ok { + existingMap, ok := existingFilter.(map[string]any) + if !ok { + // Skip if filter is invalid + queryMap["filter"] = ancestorFilter + } else { + queryMap["filter"] = map[string]any{ + "compositeFilter": map[string]any{ + "op": "AND", + "filters": []map[string]any{existingMap, ancestorFilter}, + }, + } + } + } else { + queryMap["filter"] = ancestorFilter + } + } + + // Add ordering + if len(query.orders) > 0 { + var orders []map[string]any + for _, o := range query.orders { + orders = append(orders, map[string]any{ + "property": map[string]string{"name": o.property}, + "direction": o.direction, + }) + } + queryMap["order"] = orders + } + + // Add projection + if len(query.projection) > 0 { + var projections []map[string]any + for _, field := range query.projection { + projections = append(projections, map[string]any{ + "property": map[string]string{"name": field}, + }) + } + queryMap["projection"] = projections + } else if query.keysOnly { + // Keys-only projection + queryMap["projection"] = []map[string]any{{"property": map[string]string{"name": "__key__"}}} + } + + // Add distinct on + if len(query.distinctOn) > 0 { + var distinctFields []map[string]any + for _, field := range query.distinctOn { + distinctFields = append(distinctFields, map[string]any{ + "property": map[string]string{"name": field}, + }) + } + queryMap["distinctOn"] = distinctFields + } + + // Add limit + if query.limit > 0 { + queryMap["limit"] = query.limit + } + + // Add offset + if query.offset > 0 { + queryMap["offset"] = query.offset + } + + // Add cursors + if query.startCursor != "" { + queryMap["startCursor"] = string(query.startCursor) + } + if query.endCursor != "" { + queryMap["endCursor"] = string(query.endCursor) + } + + return queryMap +} + +// AllKeys returns all keys matching the query. +// This is a convenience method for KeysOnly queries. +func (c *Client) AllKeys(ctx context.Context, q *Query) ([]*Key, error) { + if !q.keysOnly { + c.logger.WarnContext(ctx, "AllKeys called on non-KeysOnly query") + return nil, errors.New("AllKeys requires KeysOnly query") + } + + c.logger.DebugContext(ctx, "querying for keys", "kind", q.kind, "limit", q.limit) + + token, err := auth.AccessToken(ctx) + if err != nil { + c.logger.ErrorContext(ctx, "failed to get access token", "error", err) + return nil, fmt.Errorf("failed to get access token: %w", err) + } + + query := buildQueryMap(q) + + reqBody := map[string]any{"query": query} + if c.databaseID != "" { + reqBody["databaseId"] = c.databaseID + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + // URL-encode project ID to prevent injection attacks + reqURL := fmt.Sprintf("%s/projects/%s:runQuery", apiURL, neturl.PathEscape(c.projectID)) + body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) + if err != nil { + c.logger.ErrorContext(ctx, "query request failed", "error", err, "kind", q.kind) + return nil, err } var result struct { @@ -1011,12 +1738,7 @@ func (c *Client) GetAll(ctx context.Context, query *Query, dst any) ([]*Key, err return nil, fmt.Errorf("failed to get access token: %w", err) } - queryObj := map[string]any{ - "kind": []map[string]any{{"name": query.kind}}, - } - if query.limit > 0 { - queryObj["limit"] = query.limit - } + queryObj := buildQueryMap(query) reqBody := map[string]any{"query": queryObj} if c.databaseID != "" { @@ -1086,6 +1808,287 @@ func (c *Client) GetAll(ctx context.Context, query *Query, dst any) ([]*Key, err return keys, nil } +// Count returns the number of entities matching the query. +// Deprecated: Use aggregation queries with RunAggregationQuery instead. +// API compatible with cloud.google.com/go/datastore. +func (c *Client) Count(ctx context.Context, q *Query) (int, error) { + c.logger.DebugContext(ctx, "counting entities", "kind", q.kind) + + token, err := auth.AccessToken(ctx) + if err != nil { + c.logger.ErrorContext(ctx, "failed to get access token", "error", err) + return 0, fmt.Errorf("failed to get access token: %w", err) + } + + // Build aggregation query with COUNT + queryObj := buildQueryMap(q) + aggregationQuery := map[string]any{ + "aggregations": []map[string]any{ + { + "alias": "total", + "count": map[string]any{}, + }, + }, + "nestedQuery": queryObj, + } + + reqBody := map[string]any{ + "aggregationQuery": aggregationQuery, + } + if c.databaseID != "" { + reqBody["databaseId"] = c.databaseID + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) + return 0, fmt.Errorf("failed to marshal request: %w", err) + } + + // URL-encode project ID to prevent injection attacks + reqURL := fmt.Sprintf("%s/projects/%s:runAggregationQuery", apiURL, neturl.PathEscape(c.projectID)) + body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) + if err != nil { + c.logger.ErrorContext(ctx, "count query failed", "error", err, "kind", q.kind) + return 0, err + } + + var result struct { + Batch struct { + AggregationResults []struct { + AggregateProperties map[string]struct { + IntegerValue string `json:"integerValue"` + } `json:"aggregateProperties"` + } `json:"aggregationResults"` + } `json:"batch"` + } + + if err := json.Unmarshal(body, &result); err != nil { + c.logger.ErrorContext(ctx, "failed to parse response", "error", err) + return 0, fmt.Errorf("failed to parse count response: %w", err) + } + + if len(result.Batch.AggregationResults) == 0 { + c.logger.DebugContext(ctx, "no results returned", "kind", q.kind) + return 0, nil + } + + // Extract count from total aggregation + countVal, ok := result.Batch.AggregationResults[0].AggregateProperties["total"] + if !ok { + c.logger.ErrorContext(ctx, "count not found in response") + return 0, errors.New("count not found in aggregation response") + } + + count, err := strconv.Atoi(countVal.IntegerValue) + if err != nil { + c.logger.ErrorContext(ctx, "failed to parse count", "error", err, "value", countVal.IntegerValue) + return 0, fmt.Errorf("failed to parse count: %w", err) + } + + c.logger.DebugContext(ctx, "count completed successfully", "kind", q.kind, "count", count) + return count, nil +} + +// Run executes the query and returns an iterator for the results. +// API compatible with cloud.google.com/go/datastore. +func (c *Client) Run(ctx context.Context, q *Query) *Iterator { + return &Iterator{ + ctx: ctx, + client: c, + query: q, + fetchNext: true, + } +} + +// MutationOp represents the type of mutation operation. +type MutationOp string + +const ( + // MutationInsert represents an insert operation. + MutationInsert MutationOp = "insert" + // MutationUpdate represents an update operation. + MutationUpdate MutationOp = "update" + // MutationUpsert represents an upsert operation. + MutationUpsert MutationOp = "upsert" + // MutationDelete represents a delete operation. + MutationDelete MutationOp = "delete" +) + +// Mutation represents a pending datastore mutation. +type Mutation struct { + op MutationOp + key *Key + entity any +} + +// NewInsert creates an insert mutation. +// API compatible with cloud.google.com/go/datastore. +func NewInsert(k *Key, src any) *Mutation { + return &Mutation{ + op: MutationInsert, + key: k, + entity: src, + } +} + +// NewUpdate creates an update mutation. +// API compatible with cloud.google.com/go/datastore. +func NewUpdate(k *Key, src any) *Mutation { + return &Mutation{ + op: MutationUpdate, + key: k, + entity: src, + } +} + +// NewUpsert creates an upsert mutation. +// API compatible with cloud.google.com/go/datastore. +func NewUpsert(k *Key, src any) *Mutation { + return &Mutation{ + op: MutationUpsert, + key: k, + entity: src, + } +} + +// NewDelete creates a delete mutation. +// API compatible with cloud.google.com/go/datastore. +func NewDelete(k *Key) *Mutation { + return &Mutation{ + op: MutationDelete, + key: k, + } +} + +// Mutate applies one or more mutations atomically. +// API compatible with cloud.google.com/go/datastore. +func (c *Client) Mutate(ctx context.Context, muts ...*Mutation) ([]*Key, error) { + if len(muts) == 0 { + return nil, nil + } + + c.logger.DebugContext(ctx, "applying mutations", "count", len(muts)) + + token, err := auth.AccessToken(ctx) + if err != nil { + c.logger.ErrorContext(ctx, "failed to get access token", "error", err) + return nil, fmt.Errorf("failed to get access token: %w", err) + } + + // Build mutations array + mutations := make([]map[string]any, 0, len(muts)) + for i, mut := range muts { + if mut == nil { + c.logger.ErrorContext(ctx, "nil mutation", "index", i) + return nil, fmt.Errorf("mutation at index %d is nil", i) + } + if mut.key == nil { + c.logger.ErrorContext(ctx, "nil key in mutation", "index", i) + return nil, fmt.Errorf("mutation at index %d has nil key", i) + } + + mutMap := make(map[string]any) + + switch mut.op { + case MutationInsert: + if mut.entity == nil { + c.logger.ErrorContext(ctx, "nil entity for insert", "index", i) + return nil, fmt.Errorf("insert mutation at index %d has nil entity", i) + } + entity, err := encodeEntity(mut.key, mut.entity) + if err != nil { + c.logger.ErrorContext(ctx, "failed to encode entity", "index", i, "error", err) + return nil, fmt.Errorf("failed to encode entity at index %d: %w", i, err) + } + mutMap["insert"] = entity + + case MutationUpdate: + if mut.entity == nil { + c.logger.ErrorContext(ctx, "nil entity for update", "index", i) + return nil, fmt.Errorf("update mutation at index %d has nil entity", i) + } + entity, err := encodeEntity(mut.key, mut.entity) + if err != nil { + c.logger.ErrorContext(ctx, "failed to encode entity", "index", i, "error", err) + return nil, fmt.Errorf("failed to encode entity at index %d: %w", i, err) + } + mutMap["update"] = entity + + case MutationUpsert: + if mut.entity == nil { + c.logger.ErrorContext(ctx, "nil entity for upsert", "index", i) + return nil, fmt.Errorf("upsert mutation at index %d has nil entity", i) + } + entity, err := encodeEntity(mut.key, mut.entity) + if err != nil { + c.logger.ErrorContext(ctx, "failed to encode entity", "index", i, "error", err) + return nil, fmt.Errorf("failed to encode entity at index %d: %w", i, err) + } + mutMap["upsert"] = entity + + case MutationDelete: + mutMap["delete"] = keyToJSON(mut.key) + + default: + c.logger.ErrorContext(ctx, "unknown mutation operation", "index", i, "op", mut.op) + return nil, fmt.Errorf("unknown mutation operation at index %d: %s", i, mut.op) + } + + mutations = append(mutations, mutMap) + } + + reqBody := map[string]any{ + "mutations": mutations, + } + if c.databaseID != "" { + reqBody["databaseId"] = c.databaseID + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + c.logger.ErrorContext(ctx, "failed to marshal request", "error", err) + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + // URL-encode project ID to prevent injection attacks + reqURL := fmt.Sprintf("%s/projects/%s:commit", apiURL, neturl.PathEscape(c.projectID)) + body, err := doRequest(ctx, c.logger, reqURL, jsonData, token, c.projectID, c.databaseID) + if err != nil { + c.logger.ErrorContext(ctx, "mutate request failed", "error", err) + return nil, err + } + + var resp struct { + MutationResults []struct { + Key map[string]any `json:"key"` + } `json:"mutationResults"` + } + if err := json.Unmarshal(body, &resp); err != nil { + c.logger.ErrorContext(ctx, "failed to parse response", "error", err) + return nil, fmt.Errorf("failed to parse mutate response: %w", err) + } + + // Extract resulting keys + keys := make([]*Key, len(resp.MutationResults)) + for i, result := range resp.MutationResults { + if result.Key != nil { + key, err := keyFromJSON(result.Key) + if err != nil { + c.logger.ErrorContext(ctx, "failed to parse key", "index", i, "error", err) + return nil, fmt.Errorf("failed to parse key at index %d: %w", i, err) + } + keys[i] = key + } else { + // For deletes, use the original key + keys[i] = muts[i].key + } + } + + c.logger.DebugContext(ctx, "mutations applied successfully", "count", len(keys)) + return keys, nil +} + // keyFromJSON converts a JSON key representation to a Key. func keyFromJSON(keyData any) (*Key, error) { keyMap, ok := keyData.(map[string]any) @@ -1098,29 +2101,36 @@ func keyFromJSON(keyData any) (*Key, error) { return nil, errors.New("invalid key path") } - // Get the last path element (we only support simple keys) - lastElem, ok := path[len(path)-1].(map[string]any) - if !ok { - return nil, errors.New("invalid path element") - } + // Build key hierarchy from path elements + var key *Key + for _, elem := range path { + elemMap, ok := elem.(map[string]any) + if !ok { + return nil, errors.New("invalid path element") + } - key := &Key{} + newKey := &Key{ + Parent: key, + } - if kind, ok := lastElem["kind"].(string); ok { - key.Kind = kind - } + if kind, ok := elemMap["kind"].(string); ok { + newKey.Kind = kind + } - if name, ok := lastElem["name"].(string); ok { - key.Name = name - } else if idVal, exists := lastElem["id"]; exists { - switch id := idVal.(type) { - case string: - if _, err := fmt.Sscanf(id, "%d", &key.ID); err != nil { - return nil, fmt.Errorf("invalid ID format: %w", err) + if name, ok := elemMap["name"].(string); ok { + newKey.Name = name + } else if idVal, exists := elemMap["id"]; exists { + switch id := idVal.(type) { + case string: + if _, err := fmt.Sscanf(id, "%d", &newKey.ID); err != nil { + return nil, fmt.Errorf("invalid ID format: %w", err) + } + case float64: + newKey.ID = int64(id) } - case float64: - key.ID = int64(id) } + + key = newKey } return key, nil @@ -1176,6 +2186,98 @@ func WithReadTime(t time.Time) TransactionOption { return readTimeOption{t: t} } +// NewTransaction creates a new transaction. +// The caller must call Commit or Rollback when done. +// API compatible with cloud.google.com/go/datastore. +func (c *Client) NewTransaction(ctx context.Context, opts ...TransactionOption) (*Transaction, error) { + settings := transactionSettings{ + maxAttempts: 3, // default (not used for NewTransaction, but kept for consistency) + } + for _, opt := range opts { + opt.apply(&settings) + } + + token, err := auth.AccessToken(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get access token: %w", err) + } + + // Begin transaction + reqBody := map[string]any{} + if c.databaseID != "" { + reqBody["databaseId"] = c.databaseID + } + + // Add transaction options if needed + if !settings.readTime.IsZero() { + reqBody["transactionOptions"] = map[string]any{ + "readOnly": map[string]any{ + "readTime": settings.readTime.Format(time.RFC3339Nano), + }, + } + } else { + reqBody["transactionOptions"] = map[string]any{ + "readWrite": map[string]any{}, + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, err + } + + // URL-encode project ID to prevent injection attacks + reqURL := fmt.Sprintf("%s/projects/%s:beginTransaction", apiURL, neturl.PathEscape(c.projectID)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqURL, bytes.NewReader(jsonData)) + if err != nil { + return nil, err + } + + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + + // Add routing header for named databases + if c.databaseID != "" { + // URL-encode values to prevent header injection attacks + routingHeader := fmt.Sprintf("project_id=%s&database_id=%s", neturl.QueryEscape(c.projectID), neturl.QueryEscape(c.databaseID)) + req.Header.Set("X-Goog-Request-Params", routingHeader) + } + + resp, err := httpClient.Do(req) + if err != nil { + return nil, err + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, maxBodySize)) + closeErr := resp.Body.Close() + if closeErr != nil { + c.logger.Warn("failed to close response body", "error", closeErr) + } + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("begin transaction failed with status %d: %s", resp.StatusCode, string(body)) + } + + var txResp struct { + Transaction string `json:"transaction"` + } + + if err := json.Unmarshal(body, &txResp); err != nil { + return nil, fmt.Errorf("failed to parse transaction response: %w", err) + } + + tx := &Transaction{ + ctx: ctx, + client: c, + id: txResp.Transaction, + } + + return tx, nil +} + // RunInTransaction runs a function in a transaction. // The function should use the transaction's Get and Put methods. // API compatible with cloud.google.com/go/datastore. @@ -1275,7 +2377,7 @@ func (c *Client) RunInTransaction(ctx context.Context, f func(*Transaction) erro } // Commit the transaction - err = tx.commit(ctx, token) + err = tx.doCommit(ctx, token) if err == nil { c.logger.Debug("transaction committed successfully", "attempt", attempt+1) return &Commit{}, nil // Success @@ -1423,8 +2525,199 @@ func (tx *Transaction) Put(key *Key, src any) (*Key, error) { return key, nil } +// Delete deletes an entity within the transaction. +// API compatible with cloud.google.com/go/datastore. +func (tx *Transaction) Delete(key *Key) error { + if key == nil { + return errors.New("key cannot be nil") + } + + // Create delete mutation + mutation := map[string]any{ + "delete": keyToJSON(key), + } + + // Accumulate mutation for commit + tx.mutations = append(tx.mutations, mutation) + + return nil +} + +// DeleteMulti deletes multiple entities within the transaction. +// API compatible with cloud.google.com/go/datastore. +func (tx *Transaction) DeleteMulti(keys []*Key) error { + for _, key := range keys { + if err := tx.Delete(key); err != nil { + return err + } + } + return nil +} + +// GetMulti retrieves multiple entities within the transaction. +// API compatible with cloud.google.com/go/datastore. +func (tx *Transaction) GetMulti(keys []*Key, dst any) error { + dstVal := reflect.ValueOf(dst) + if dstVal.Kind() != reflect.Ptr || dstVal.Elem().Kind() != reflect.Slice { + return errors.New("dst must be a pointer to a slice") + } + + slice := dstVal.Elem() + if len(keys) != slice.Len() { + return fmt.Errorf("keys and dst slices must have same length: %d vs %d", len(keys), slice.Len()) + } + + // Get each entity individually within the transaction + for i, key := range keys { + elem := slice.Index(i) + if elem.Kind() == reflect.Ptr { + // dst is []*Entity + if elem.IsNil() { + elem.Set(reflect.New(elem.Type().Elem())) + } + if err := tx.Get(key, elem.Interface()); err != nil { + return err + } + } else { + // dst is []Entity + if err := tx.Get(key, elem.Addr().Interface()); err != nil { + return err + } + } + } + + return nil +} + +// PutMulti stores multiple entities within the transaction. +// API compatible with cloud.google.com/go/datastore. +func (tx *Transaction) PutMulti(keys []*Key, src any) ([]*Key, error) { + srcVal := reflect.ValueOf(src) + if srcVal.Kind() != reflect.Slice { + return nil, errors.New("src must be a slice") + } + + if len(keys) != srcVal.Len() { + return nil, fmt.Errorf("keys and src slices must have same length: %d vs %d", len(keys), srcVal.Len()) + } + + // Put each entity individually within the transaction + for i, key := range keys { + elem := srcVal.Index(i) + var src any + if elem.Kind() == reflect.Ptr { + src = elem.Interface() + } else { + src = elem.Addr().Interface() + } + + if _, err := tx.Put(key, src); err != nil { + return nil, err + } + } + + return keys, nil +} + +// Commit applies the transaction's mutations. +// API compatible with cloud.google.com/go/datastore. +func (tx *Transaction) Commit() (*Commit, error) { + token, err := auth.AccessToken(tx.ctx) + if err != nil { + return nil, fmt.Errorf("failed to get access token: %w", err) + } + + if err := tx.doCommit(tx.ctx, token); err != nil { + return nil, err + } + + return &Commit{}, nil +} + +// Rollback abandons the transaction. +// API compatible with cloud.google.com/go/datastore. +func (tx *Transaction) Rollback() error { + // Datastore transactions are automatically rolled back if not committed + // So we just need to clear the mutations to prevent accidental commit + tx.mutations = nil + return nil +} + +// Mutate adds one or more mutations to the transaction. +// API compatible with cloud.google.com/go/datastore. +func (tx *Transaction) Mutate(muts ...*Mutation) ([]*PendingKey, error) { + if len(muts) == 0 { + return nil, nil + } + + // Build mutations array + pendingKeys := make([]*PendingKey, 0, len(muts)) + for i, mut := range muts { + if mut == nil { + return nil, fmt.Errorf("mutation at index %d is nil", i) + } + if mut.key == nil { + return nil, fmt.Errorf("mutation at index %d has nil key", i) + } + + mutMap := make(map[string]any) + + switch mut.op { + case MutationInsert: + if mut.entity == nil { + return nil, fmt.Errorf("insert mutation at index %d has nil entity", i) + } + entity, err := encodeEntity(mut.key, mut.entity) + if err != nil { + return nil, fmt.Errorf("failed to encode entity at index %d: %w", i, err) + } + mutMap["insert"] = entity + + case MutationUpdate: + if mut.entity == nil { + return nil, fmt.Errorf("update mutation at index %d has nil entity", i) + } + entity, err := encodeEntity(mut.key, mut.entity) + if err != nil { + return nil, fmt.Errorf("failed to encode entity at index %d: %w", i, err) + } + mutMap["update"] = entity + + case MutationUpsert: + if mut.entity == nil { + return nil, fmt.Errorf("upsert mutation at index %d has nil entity", i) + } + entity, err := encodeEntity(mut.key, mut.entity) + if err != nil { + return nil, fmt.Errorf("failed to encode entity at index %d: %w", i, err) + } + mutMap["upsert"] = entity + + case MutationDelete: + mutMap["delete"] = keyToJSON(mut.key) + + default: + return nil, fmt.Errorf("unknown mutation operation at index %d: %s", i, mut.op) + } + + tx.mutations = append(tx.mutations, mutMap) + + // Create a pending key for the result + pk := &PendingKey{key: mut.key} + pendingKeys = append(pendingKeys, pk) + } + + return pendingKeys, nil +} + +// PendingKey represents a key that will be resolved after a transaction commit. +// API compatible with cloud.google.com/go/datastore. +type PendingKey struct { + key *Key +} + // commit commits the transaction. -func (tx *Transaction) commit(ctx context.Context, token string) error { +func (tx *Transaction) doCommit(ctx context.Context, token string) error { reqBody := map[string]any{ "mode": "TRANSACTIONAL", "transaction": tx.id, diff --git a/datastore_test.go b/datastore_test.go index 806475f..fe300c7 100644 --- a/datastore_test.go +++ b/datastore_test.go @@ -17,12 +17,12 @@ import ( // testEntity represents a simple test entity. type testEntity struct { + UpdatedAt time.Time `datastore:"updated_at"` Name string `datastore:"name"` + Notes string `datastore:"notes,noindex"` Count int64 `datastore:"count"` - Active bool `datastore:"active"` Score float64 `datastore:"score"` - UpdatedAt time.Time `datastore:"updated_at"` - Notes string `datastore:"notes,noindex"` + Active bool `datastore:"active"` } func TestNewClient(t *testing.T) { @@ -707,15 +707,15 @@ func TestEntityWithAllTypes(t *testing.T) { ctx := context.Background() type AllTypes struct { + TimeVal time.Time `datastore:"t"` StringVal string `datastore:"str"` + NoIndex string `datastore:"noindex,noindex"` + Skip string `datastore:"-"` Int64Val int64 `datastore:"i64"` - Int32Val int32 `datastore:"i32"` IntVal int `datastore:"i"` - BoolVal bool `datastore:"b"` Float64Val float64 `datastore:"f64"` - TimeVal time.Time `datastore:"t"` - NoIndex string `datastore:"noindex,noindex"` - Skip string `datastore:"-"` + Int32Val int32 `datastore:"i32"` + BoolVal bool `datastore:"b"` } now := time.Now().UTC().Truncate(time.Second) @@ -963,16 +963,16 @@ func TestUnsupportedEncodeType(t *testing.T) { ctx := context.Background() - // Entity with unsupported type (slice) + // Entity with unsupported type (map) type BadEntity struct { - Name string - Items []string // slices not supported + Name string + Data map[string]string // maps not supported } key := ds9.NameKey("TestKind", "bad", nil) entity := BadEntity{ - Name: "test", - Items: []string{"a", "b"}, + Name: "test", + Data: map[string]string{"key": "value"}, } _, err := client.Put(ctx, key, &entity) @@ -1042,9 +1042,9 @@ func TestEntityWithSkippedFields(t *testing.T) { type EntityWithSkip struct { Name string `datastore:"name"` - Count int64 `datastore:"count"` - Skipped string `datastore:"-"` // Should not be stored - private string // Should not be stored (unexported) + Skipped string `datastore:"-"` + private string + Count int64 `datastore:"count"` } key := ds9.NameKey("TestKind", "skip", nil) @@ -2425,14 +2425,14 @@ func TestDecodeValueEdgeCases(t *testing.T) { // Test with all basic types type ComplexEntity struct { + Time time.Time `datastore:"t"` String string `datastore:"s"` + NoIndex string `datastore:"n,noindex"` Int int `datastore:"i"` - Int32 int32 `datastore:"i32"` Int64 int64 `datastore:"i64"` Float float64 `datastore:"f"` + Int32 int32 `datastore:"i32"` Bool bool `datastore:"b"` - Time time.Time `datastore:"t"` - NoIndex string `datastore:"n,noindex"` } now := time.Now().UTC().Truncate(time.Second) @@ -3639,8 +3639,8 @@ func TestPutWithInvalidEntityStructure(t *testing.T) { // Entity with channel (unsupported type) type BadEntity struct { - Name string Ch chan int + Name string } key := ds9.NameKey("Test", "bad", nil) @@ -4299,8 +4299,8 @@ func TestPutMultiWithPartialEncode(t *testing.T) { // Mix of valid and invalid entities type MixedEntity struct { + Data any Name string - Data any // interface{} - may cause encoding issues } keys := []*ds9.Key{ @@ -7334,3 +7334,690 @@ func TestTransactionOptions(t *testing.T) { } }) } + +// Test entity with arrays for array/slice tests +type arrayEntity struct { + Strings []string `datastore:"strings"` + Ints []int64 `datastore:"ints"` + Floats []float64 `datastore:"floats"` + Bools []bool `datastore:"bools"` +} + +func TestArraySliceSupport(t *testing.T) { + client, cleanup := ds9mock.NewClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("StringSlice", func(t *testing.T) { + key := ds9.NameKey("ArrayTest", "strings", nil) + entity := &arrayEntity{ + Strings: []string{"hello", "world", "test"}, + } + + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put with string slice failed: %v", err) + } + + var result arrayEntity + if err := client.Get(ctx, key, &result); err != nil { + t.Fatalf("Get failed: %v", err) + } + + if len(result.Strings) != 3 { + t.Errorf("Expected 3 strings, got %d", len(result.Strings)) + } + if result.Strings[0] != "hello" || result.Strings[1] != "world" || result.Strings[2] != "test" { + t.Errorf("String slice values incorrect: %v", result.Strings) + } + }) + + t.Run("Int64Slice", func(t *testing.T) { + key := ds9.NameKey("ArrayTest", "ints", nil) + entity := &arrayEntity{ + Ints: []int64{1, 2, 3, 42, 100}, + } + + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put with int64 slice failed: %v", err) + } + + var result arrayEntity + if err := client.Get(ctx, key, &result); err != nil { + t.Fatalf("Get failed: %v", err) + } + + if len(result.Ints) != 5 { + t.Errorf("Expected 5 ints, got %d", len(result.Ints)) + } + if result.Ints[3] != 42 { + t.Errorf("Expected Ints[3] = 42, got %d", result.Ints[3]) + } + }) + + t.Run("Float64Slice", func(t *testing.T) { + key := ds9.NameKey("ArrayTest", "floats", nil) + entity := &arrayEntity{ + Floats: []float64{1.1, 2.2, 3.3}, + } + + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put with float64 slice failed: %v", err) + } + + var result arrayEntity + if err := client.Get(ctx, key, &result); err != nil { + t.Fatalf("Get failed: %v", err) + } + + if len(result.Floats) != 3 { + t.Errorf("Expected 3 floats, got %d", len(result.Floats)) + } + if result.Floats[0] != 1.1 { + t.Errorf("Expected Floats[0] = 1.1, got %f", result.Floats[0]) + } + }) + + t.Run("BoolSlice", func(t *testing.T) { + key := ds9.NameKey("ArrayTest", "bools", nil) + entity := &arrayEntity{ + Bools: []bool{true, false, true}, + } + + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put with bool slice failed: %v", err) + } + + var result arrayEntity + if err := client.Get(ctx, key, &result); err != nil { + t.Fatalf("Get failed: %v", err) + } + + if len(result.Bools) != 3 { + t.Errorf("Expected 3 bools, got %d", len(result.Bools)) + } + if result.Bools[0] != true || result.Bools[1] != false { + t.Errorf("Bool slice values incorrect: %v", result.Bools) + } + }) + + t.Run("EmptySlices", func(t *testing.T) { + key := ds9.NameKey("ArrayTest", "empty", nil) + entity := &arrayEntity{ + Strings: []string{}, + Ints: []int64{}, + } + + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put with empty slices failed: %v", err) + } + + var result arrayEntity + if err := client.Get(ctx, key, &result); err != nil { + t.Fatalf("Get failed: %v", err) + } + + if result.Strings == nil || len(result.Strings) != 0 { + t.Errorf("Expected empty string slice, got %v", result.Strings) + } + if result.Ints == nil || len(result.Ints) != 0 { + t.Errorf("Expected empty int slice, got %v", result.Ints) + } + }) + + t.Run("MixedArrays", func(t *testing.T) { + key := ds9.NameKey("ArrayTest", "mixed", nil) + entity := &arrayEntity{ + Strings: []string{"a", "b"}, + Ints: []int64{10, 20, 30}, + Floats: []float64{1.5}, + Bools: []bool{true, false, true, false}, + } + + _, err := client.Put(ctx, key, entity) + if err != nil { + t.Fatalf("Put with mixed arrays failed: %v", err) + } + + var result arrayEntity + if err := client.Get(ctx, key, &result); err != nil { + t.Fatalf("Get failed: %v", err) + } + + if len(result.Strings) != 2 || len(result.Ints) != 3 || len(result.Floats) != 1 || len(result.Bools) != 4 { + t.Errorf("Mixed array lengths incorrect: strings=%d, ints=%d, floats=%d, bools=%d", + len(result.Strings), len(result.Ints), len(result.Floats), len(result.Bools)) + } + }) +} + +func TestAllocateIDs(t *testing.T) { + client, cleanup := ds9mock.NewClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("AllocateIncompleteKeys", func(t *testing.T) { + keys := []*ds9.Key{ + ds9.IncompleteKey("Task", nil), + ds9.IncompleteKey("Task", nil), + ds9.IncompleteKey("Task", nil), + } + + allocated, err := client.AllocateIDs(ctx, keys) + if err != nil { + t.Fatalf("AllocateIDs failed: %v", err) + } + + if len(allocated) != 3 { + t.Errorf("Expected 3 allocated keys, got %d", len(allocated)) + } + + for i, key := range allocated { + if key.Incomplete() { + t.Errorf("Key %d is still incomplete", i) + } + if key.ID == 0 { + t.Errorf("Key %d has zero ID", i) + } + } + }) + + t.Run("AllocateMixedKeys", func(t *testing.T) { + keys := []*ds9.Key{ + ds9.NameKey("Task", "complete", nil), + ds9.IncompleteKey("Task", nil), + ds9.IDKey("Task", 123, nil), + ds9.IncompleteKey("Task", nil), + } + + allocated, err := client.AllocateIDs(ctx, keys) + if err != nil { + t.Fatalf("AllocateIDs with mixed keys failed: %v", err) + } + + if len(allocated) != 4 { + t.Errorf("Expected 4 keys, got %d", len(allocated)) + } + + // First key should still be the named key + if allocated[0].Name != "complete" { + t.Errorf("First key should be unchanged") + } + + // Second key should now have an ID + if allocated[1].Incomplete() { + t.Errorf("Second key should be allocated") + } + + // Third key should be unchanged + if allocated[2].ID != 123 { + t.Errorf("Third key should be unchanged") + } + + // Fourth key should now have an ID + if allocated[3].Incomplete() { + t.Errorf("Fourth key should be allocated") + } + }) + + t.Run("AllocateEmptySlice", func(t *testing.T) { + keys := []*ds9.Key{} + + allocated, err := client.AllocateIDs(ctx, keys) + if err != nil { + t.Fatalf("AllocateIDs with empty slice failed: %v", err) + } + + if len(allocated) != 0 { + t.Errorf("Expected empty slice, got %d keys", len(allocated)) + } + }) + + t.Run("AllocateAllCompleteKeys", func(t *testing.T) { + keys := []*ds9.Key{ + ds9.NameKey("Task", "key1", nil), + ds9.IDKey("Task", 100, nil), + } + + allocated, err := client.AllocateIDs(ctx, keys) + if err != nil { + t.Fatalf("AllocateIDs with complete keys failed: %v", err) + } + + if len(allocated) != 2 { + t.Errorf("Expected 2 keys, got %d", len(allocated)) + } + + // Keys should be unchanged + if allocated[0].Name != "key1" || allocated[1].ID != 100 { + t.Errorf("Complete keys should be unchanged") + } + }) +} + +func TestCount(t *testing.T) { + client, cleanup := ds9mock.NewClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("CountEmptyKind", func(t *testing.T) { + q := ds9.NewQuery("NonExistent") + count, err := client.Count(ctx, q) + if err != nil { + t.Fatalf("Count failed: %v", err) + } + + if count != 0 { + t.Errorf("Expected count 0, got %d", count) + } + }) + + t.Run("CountWithEntities", func(t *testing.T) { + // Create some entities + for i := range 5 { + key := ds9.IDKey("CountTest", int64(i+1), nil) + entity := &testEntity{ + Name: fmt.Sprintf("entity-%d", i), + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + q := ds9.NewQuery("CountTest") + count, err := client.Count(ctx, q) + if err != nil { + t.Fatalf("Count failed: %v", err) + } + + if count != 5 { + t.Errorf("Expected count 5, got %d", count) + } + }) + + t.Run("CountWithFilter", func(t *testing.T) { + // Create entities with different counts + for i := range 10 { + key := ds9.IDKey("FilterCount", int64(i+1), nil) + entity := &testEntity{ + Name: fmt.Sprintf("entity-%d", i), + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Count entities where count >= 5 + q := ds9.NewQuery("FilterCount").Filter("count >=", int64(5)) + count, err := client.Count(ctx, q) + if err != nil { + t.Fatalf("Count with filter failed: %v", err) + } + + // Should return entities with count 5,6,7,8,9 = 5 entities + if count != 5 { + t.Errorf("Expected count 5, got %d", count) + } + }) + + t.Run("CountWithLimit", func(t *testing.T) { + // Create entities + for i := range 10 { + key := ds9.IDKey("LimitCount", int64(i+1), nil) + entity := &testEntity{ + Name: fmt.Sprintf("entity-%d", i), + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Count with limit - note: count should respect limit + q := ds9.NewQuery("LimitCount").Limit(3) + count, err := client.Count(ctx, q) + if err != nil { + t.Fatalf("Count with limit failed: %v", err) + } + + // Mock implementation may return full count, but limit is respected + if count > 10 { + t.Errorf("Count should not exceed actual entities: %d", count) + } + }) +} + +func TestQueryNamespace(t *testing.T) { + client, cleanup := ds9mock.NewClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("NamespaceFilter", func(t *testing.T) { + // Note: ds9mock may not fully support namespaces, but we test the API + q := ds9.NewQuery("Task").Namespace("custom-namespace") + + var entities []testEntity + _, err := client.GetAll(ctx, q, &entities) + // Should not error even if namespace is not supported by mock + if err != nil { + t.Logf("GetAll with namespace: %v", err) + } + }) + + t.Run("EmptyNamespace", func(t *testing.T) { + q := ds9.NewQuery("Task").Namespace("") + + var entities []testEntity + _, err := client.GetAll(ctx, q, &entities) + if err != nil { + t.Logf("GetAll with empty namespace: %v", err) + } + }) +} + +func TestQueryDistinct(t *testing.T) { + client, cleanup := ds9mock.NewClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("Distinct", func(t *testing.T) { + // Create duplicate entities + for i := range 3 { + key := ds9.IDKey("DistinctTest", int64(i+1), nil) + entity := &testEntity{ + Name: "same-name", // Same name for all + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + // Query with distinct on Name field + q := ds9.NewQuery("DistinctTest").Project("name").Distinct() + + var entities []testEntity + _, err := client.GetAll(ctx, q, &entities) + if err != nil { + t.Logf("GetAll with Distinct: %v", err) + } + }) + + t.Run("DistinctOn", func(t *testing.T) { + q := ds9.NewQuery("Task").DistinctOn("name", "count") + + var entities []testEntity + _, err := client.GetAll(ctx, q, &entities) + if err != nil { + t.Logf("GetAll with DistinctOn: %v", err) + } + }) +} + +func TestIterator(t *testing.T) { + client, cleanup := ds9mock.NewClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("IterateAll", func(t *testing.T) { + // Create test entities + for i := range 5 { + key := ds9.IDKey("IterTest", int64(i+1), nil) + entity := &testEntity{ + Name: fmt.Sprintf("entity-%d", i), + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + q := ds9.NewQuery("IterTest") + it := client.Run(ctx, q) + + count := 0 + for { + var entity testEntity + key, err := it.Next(&entity) + if errors.Is(err, ds9.ErrDone) { + break + } + if err != nil { + t.Fatalf("Iterator.Next failed: %v", err) + } + if key == nil { + t.Errorf("Expected non-nil key") + } + count++ + } + + if count != 5 { + t.Errorf("Expected to iterate over 5 entities, got %d", count) + } + }) + + t.Run("IteratorCursor", func(t *testing.T) { + // Create test entities + for i := range 3 { + key := ds9.IDKey("CursorTest", int64(i+1), nil) + entity := &testEntity{ + Name: fmt.Sprintf("entity-%d", i), + Count: int64(i), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + } + + q := ds9.NewQuery("CursorTest") + it := client.Run(ctx, q) + + var entity testEntity + _, err := it.Next(&entity) + if err != nil { + t.Fatalf("Iterator.Next failed: %v", err) + } + + // Get cursor after first entity + cursor, err := it.Cursor() + if err != nil { + t.Logf("Cursor not available: %v", err) + } else if cursor == "" { + t.Logf("Empty cursor returned") + } + }) + + t.Run("EmptyIterator", func(t *testing.T) { + q := ds9.NewQuery("NonExistent") + it := client.Run(ctx, q) + + var entity testEntity + _, err := it.Next(&entity) + if !errors.Is(err, ds9.ErrDone) { + t.Errorf("Expected ErrDone, got %v", err) + } + }) +} + +func TestMutate(t *testing.T) { + client, cleanup := ds9mock.NewClient(t) + defer cleanup() + + ctx := context.Background() + + t.Run("MutateInsert", func(t *testing.T) { + key := ds9.NameKey("MutateTest", "insert", nil) + entity := &testEntity{ + Name: "inserted", + Count: 42, + } + + mut := ds9.NewInsert(key, entity) + keys, err := client.Mutate(ctx, mut) + if err != nil { + t.Fatalf("Mutate insert failed: %v", err) + } + + if len(keys) != 1 { + t.Errorf("Expected 1 key, got %d", len(keys)) + } + + // Verify entity was created + var result testEntity + if err := client.Get(ctx, key, &result); err != nil { + t.Fatalf("Get after insert failed: %v", err) + } + if result.Name != "inserted" { + t.Errorf("Expected Name 'inserted', got '%s'", result.Name) + } + }) + + t.Run("MutateUpdate", func(t *testing.T) { + key := ds9.NameKey("MutateTest", "update", nil) + entity := &testEntity{Name: "original", Count: 1} + + // Create entity first + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Update via mutation + updated := &testEntity{Name: "updated", Count: 2} + mut := ds9.NewUpdate(key, updated) + _, err := client.Mutate(ctx, mut) + if err != nil { + t.Fatalf("Mutate update failed: %v", err) + } + + // Verify entity was updated + var result testEntity + if err := client.Get(ctx, key, &result); err != nil { + t.Fatalf("Get after update failed: %v", err) + } + if result.Name != "updated" { + t.Errorf("Expected Name 'updated', got '%s'", result.Name) + } + }) + + t.Run("MutateUpsert", func(t *testing.T) { + key := ds9.NameKey("MutateTest", "upsert", nil) + entity := &testEntity{Name: "upserted", Count: 100} + + mut := ds9.NewUpsert(key, entity) + keys, err := client.Mutate(ctx, mut) + if err != nil { + t.Fatalf("Mutate upsert failed: %v", err) + } + + if len(keys) != 1 { + t.Errorf("Expected 1 key, got %d", len(keys)) + } + + // Verify entity exists + var result testEntity + if err := client.Get(ctx, key, &result); err != nil { + t.Fatalf("Get after upsert failed: %v", err) + } + if result.Name != "upserted" { + t.Errorf("Expected Name 'upserted', got '%s'", result.Name) + } + }) + + t.Run("MutateDelete", func(t *testing.T) { + key := ds9.NameKey("MutateTest", "delete", nil) + entity := &testEntity{Name: "to-delete", Count: 1} + + // Create entity first + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Delete via mutation + mut := ds9.NewDelete(key) + keys, err := client.Mutate(ctx, mut) + if err != nil { + t.Fatalf("Mutate delete failed: %v", err) + } + + if len(keys) != 1 { + t.Errorf("Expected 1 key, got %d", len(keys)) + } + + // Verify entity was deleted + var result testEntity + err = client.Get(ctx, key, &result) + if !errors.Is(err, ds9.ErrNoSuchEntity) { + t.Errorf("Expected ErrNoSuchEntity after delete, got %v", err) + } + }) + + t.Run("MutateMultiple", func(t *testing.T) { + key1 := ds9.NameKey("MutateTest", "multi1", nil) + key2 := ds9.NameKey("MutateTest", "multi2", nil) + key3 := ds9.NameKey("MutateTest", "multi3", nil) + + entity1 := &testEntity{Name: "first", Count: 1} + entity2 := &testEntity{Name: "second", Count: 2} + entity3 := &testEntity{Name: "third", Count: 3} + + // Pre-create entity3 for update + if _, err := client.Put(ctx, key3, &testEntity{Name: "old", Count: 0}); err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Apply multiple mutations + muts := []*ds9.Mutation{ + ds9.NewInsert(key1, entity1), + ds9.NewUpsert(key2, entity2), + ds9.NewUpdate(key3, entity3), + } + + keys, err := client.Mutate(ctx, muts...) + if err != nil { + t.Fatalf("Mutate multiple failed: %v", err) + } + + if len(keys) != 3 { + t.Errorf("Expected 3 keys, got %d", len(keys)) + } + + // Verify all mutations applied + var result1, result2, result3 testEntity + if err := client.Get(ctx, key1, &result1); err != nil { + t.Errorf("Get key1 failed: %v", err) + } + if err := client.Get(ctx, key2, &result2); err != nil { + t.Errorf("Get key2 failed: %v", err) + } + if err := client.Get(ctx, key3, &result3); err != nil { + t.Errorf("Get key3 failed: %v", err) + } + + if result1.Name != "first" || result2.Name != "second" || result3.Name != "third" { + t.Errorf("Mutation results incorrect") + } + }) + + t.Run("MutateEmpty", func(t *testing.T) { + keys, err := client.Mutate(ctx) + if err != nil { + t.Fatalf("Mutate with no mutations failed: %v", err) + } + + if keys != nil && len(keys) != 0 { + t.Errorf("Expected nil or empty keys, got %d", len(keys)) + } + }) +} diff --git a/ds9mock/mock.go b/ds9mock/mock.go index cdf60af..cfa5fc4 100644 --- a/ds9mock/mock.go +++ b/ds9mock/mock.go @@ -18,9 +18,12 @@ package ds9mock import ( "context" "encoding/json" + "fmt" "log" "net/http" "net/http/httptest" + "strconv" + "sync" "testing" "github.com/codeGROOVE-dev/ds9" @@ -30,6 +33,7 @@ const metadataFlavor = "Google" // Store holds the in-memory entity storage. type Store struct { + mu sync.RWMutex entities map[string]map[string]any } @@ -106,6 +110,16 @@ func NewClient(t *testing.T) (client *ds9.Client, cleanup func()) { return } + if r.URL.Path == "/projects/test-project:allocateIds" { + store.handleAllocateIDs(w, r) + return + } + + if r.URL.Path == "/projects/test-project:runAggregationQuery" { + store.handleRunAggregationQuery(w, r) + return + } + w.WriteHeader(http.StatusNotFound) })) @@ -164,6 +178,9 @@ func (s *Store) handleLookup(w http.ResponseWriter, r *http.Request) { var found []map[string]any var missing []map[string]any + s.mu.RLock() + defer s.mu.RUnlock() + for _, keyData := range req.Keys { path, ok := keyData["path"].([]any) if !ok { @@ -245,7 +262,80 @@ func (s *Store) handleCommit(w http.ResponseWriter, r *http.Request) { } } + s.mu.Lock() + defer s.mu.Unlock() + + var mutationResults []map[string]any + for _, mutation := range req.Mutations { + var resultKey map[string]any + + // Handle insert + if insert, ok := mutation["insert"].(map[string]any); ok { + keyData, ok := insert["key"].(map[string]any) + if !ok { + continue + } + path, ok := keyData["path"].([]any) + if !ok || len(path) == 0 { + continue + } + pathElem, ok := path[0].(map[string]any) + if !ok { + continue + } + kind, ok := pathElem["kind"].(string) + if !ok { + continue + } + + // Handle both name and ID keys + var keyStr string + if name, ok := pathElem["name"].(string); ok { + keyStr = kind + "/" + name + } else if id, ok := pathElem["id"].(string); ok { + keyStr = kind + "/" + id + } else { + continue + } + + s.entities[keyStr] = insert + resultKey = keyData + } + + // Handle update + if update, ok := mutation["update"].(map[string]any); ok { + keyData, ok := update["key"].(map[string]any) + if !ok { + continue + } + path, ok := keyData["path"].([]any) + if !ok || len(path) == 0 { + continue + } + pathElem, ok := path[0].(map[string]any) + if !ok { + continue + } + kind, ok := pathElem["kind"].(string) + if !ok { + continue + } + + // Handle both name and ID keys + var keyStr string + if name, ok := pathElem["name"].(string); ok { + keyStr = kind + "/" + name + } else if id, ok := pathElem["id"].(string); ok { + keyStr = kind + "/" + id + } else { + continue + } + + s.entities[keyStr] = update + resultKey = keyData + } + // Handle upsert if upsert, ok := mutation["upsert"].(map[string]any); ok { keyData, ok := upsert["key"].(map[string]any) @@ -276,6 +366,7 @@ func (s *Store) handleCommit(w http.ResponseWriter, r *http.Request) { } s.entities[keyStr] = upsert + resultKey = keyData } // Handle delete @@ -304,13 +395,21 @@ func (s *Store) handleCommit(w http.ResponseWriter, r *http.Request) { } delete(s.entities, keyStr) + resultKey = deleteKey + } + + // Add mutation result + if resultKey != nil { + mutationResults = append(mutationResults, map[string]any{ + "key": resultKey, + }) } } w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(map[string]any{ - "mutationResults": []any{}, + "mutationResults": mutationResults, }); err != nil { log.Printf("failed to encode commit response: %v", err) } @@ -370,6 +469,9 @@ func (s *Store) handleRunQuery(w http.ResponseWriter, r *http.Request) { } // Find all entities of this kind + s.mu.RLock() + defer s.mu.RUnlock() + var results []any for _, entity := range s.entities { keyData, ok := entity["key"].(map[string]any) @@ -449,3 +551,303 @@ func handleBeginTransaction(w http.ResponseWriter, r *http.Request) { log.Printf("failed to encode transaction response: %v", err) } } + +// handleAllocateIDs handles :allocateIds requests. +func (s *Store) handleAllocateIDs(w http.ResponseWriter, r *http.Request) { + var req struct { + DatabaseID string `json:"databaseId"` + Keys []map[string]any `json:"keys"` + } + + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + // Validate routing header for named databases + if req.DatabaseID != "" { + routingHeader := r.Header.Get("X-Goog-Request-Params") + if routingHeader == "" { + w.WriteHeader(http.StatusBadRequest) + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{ + "code": 400, + "message": "Missing routing header for named database", + "status": "INVALID_ARGUMENT", + }, + }); err != nil { + log.Printf("failed to encode error response: %v", err) + } + return + } + } + + // Allocate IDs for incomplete keys + allocatedKeys := make([]map[string]any, 0, len(req.Keys)) + for _, keyData := range req.Keys { + // Parse path to check if incomplete + path, ok := keyData["path"].([]any) + if !ok || len(path) == 0 { + continue + } + + // Get last element + lastElem, ok := path[len(path)-1].(map[string]any) + if !ok { + continue + } + + // If it has no name or id, allocate an ID + _, hasName := lastElem["name"] + _, hasID := lastElem["id"] + if !hasName && !hasID { + // Allocate a simple sequential ID + lastElem["id"] = "1001" // Simple mock ID + } + + allocatedKeys = append(allocatedKeys, keyData) + } + + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "keys": allocatedKeys, + }); err != nil { + log.Printf("failed to encode allocateIds response: %v", err) + } +} + +// matchesFilter checks if an entity matches a filter. +func matchesFilter(entity map[string]any, filterMap map[string]any) bool { + // Handle propertyFilter + if propFilter, ok := filterMap["propertyFilter"].(map[string]any); ok { + property, ok := propFilter["property"].(map[string]any) + if !ok { + return true // Invalid filter, allow all + } + propertyName, ok := property["name"].(string) + if !ok { + return true + } + operator, ok := propFilter["op"].(string) + if !ok { + return true + } + filterValue := propFilter["value"] + + // Get entity properties + properties, ok := entity["properties"].(map[string]any) + if !ok { + return false + } + entityProp, ok := properties[propertyName].(map[string]any) + if !ok { + return false // Property doesn't exist + } + + // Extract entity value based on type + var entityValue any + if intVal, ok := entityProp["integerValue"].(string); ok { + var i int64 + if _, err := fmt.Sscanf(intVal, "%d", &i); err == nil { + entityValue = i + } + } else if strVal, ok := entityProp["stringValue"].(string); ok { + entityValue = strVal + } else if boolVal, ok := entityProp["booleanValue"].(bool); ok { + entityValue = boolVal + } else if floatVal, ok := entityProp["doubleValue"].(float64); ok { + entityValue = floatVal + } + + // Extract filter value + var filterVal any + if fv, ok := filterValue.(map[string]any); ok { + if intVal, ok := fv["integerValue"].(string); ok { + var i int64 + if _, err := fmt.Sscanf(intVal, "%d", &i); err == nil { + filterVal = i + } + } else if strVal, ok := fv["stringValue"].(string); ok { + filterVal = strVal + } + } + + // Compare based on operator + switch operator { + case "EQUAL": + return entityValue == filterVal + case "GREATER_THAN": + if ev, ok := entityValue.(int64); ok { + if fv, ok := filterVal.(int64); ok { + return ev > fv + } + } + case "GREATER_THAN_OR_EQUAL": + if ev, ok := entityValue.(int64); ok { + if fv, ok := filterVal.(int64); ok { + return ev >= fv + } + } + case "LESS_THAN": + if ev, ok := entityValue.(int64); ok { + if fv, ok := filterVal.(int64); ok { + return ev < fv + } + } + case "LESS_THAN_OR_EQUAL": + if ev, ok := entityValue.(int64); ok { + if fv, ok := filterVal.(int64); ok { + return ev <= fv + } + } + } + } + + // Handle compositeFilter (AND/OR) + if compFilter, ok := filterMap["compositeFilter"].(map[string]any); ok { + op, ok := compFilter["op"].(string) + if !ok { + return true + } + filters, ok := compFilter["filters"].([]any) + if !ok { + return true + } + + if op == "AND" { + for _, f := range filters { + if fm, ok := f.(map[string]any); ok { + if !matchesFilter(entity, fm) { + return false + } + } + } + return true + } else if op == "OR" { + for _, f := range filters { + if fm, ok := f.(map[string]any); ok { + if matchesFilter(entity, fm) { + return true + } + } + } + return false + } + } + + return true // No filter or unrecognized filter, allow all +} + +// handleRunAggregationQuery handles :runAggregationQuery requests. +func (s *Store) handleRunAggregationQuery(w http.ResponseWriter, r *http.Request) { + var req struct { + DatabaseID string `json:"databaseId"` + AggregationQuery map[string]any `json:"aggregationQuery"` + } + + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + // Validate routing header for named databases + if req.DatabaseID != "" { + routingHeader := r.Header.Get("X-Goog-Request-Params") + if routingHeader == "" { + w.WriteHeader(http.StatusBadRequest) + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "error": map[string]any{ + "code": 400, + "message": "Missing routing header for named database", + "status": "INVALID_ARGUMENT", + }, + }); err != nil { + log.Printf("failed to encode error response: %v", err) + } + return + } + } + + // Extract query from aggregationQuery + nestedQuery, ok := req.AggregationQuery["nestedQuery"].(map[string]any) + if !ok { + w.WriteHeader(http.StatusBadRequest) + return + } + + // Extract kind from query + kindArray, ok := nestedQuery["kind"].([]any) + if !ok || len(kindArray) == 0 { + w.WriteHeader(http.StatusBadRequest) + return + } + + kindMap, ok := kindArray[0].(map[string]any) + if !ok { + w.WriteHeader(http.StatusBadRequest) + return + } + + kind, ok := kindMap["name"].(string) + if !ok { + w.WriteHeader(http.StatusBadRequest) + return + } + + // Count entities of this kind in store + s.mu.RLock() + defer s.mu.RUnlock() + + count := 0 + // entities map is keyed by "kind/keyname", so we need to iterate + for keyStr, entity := range s.entities { + // Extract kind from entity's key + keyData, ok := entity["key"].(map[string]any) + if !ok { + continue + } + path, ok := keyData["path"].([]any) + if !ok || len(path) == 0 { + continue + } + pathElem, ok := path[0].(map[string]any) + if !ok { + continue + } + entityKind, ok := pathElem["kind"].(string) + if !ok || entityKind != kind { + continue + } + + // Apply filters if present + if filterMap, hasFilter := nestedQuery["filter"].(map[string]any); hasFilter { + if !matchesFilter(entity, filterMap) { + continue + } + } + + _ = keyStr + count++ + } + + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(map[string]any{ + "batch": map[string]any{ + "aggregationResults": []map[string]any{ + { + "aggregateProperties": map[string]any{ + "total": map[string]any{ + "integerValue": strconv.Itoa(count), + }, + }, + }, + }, + }, + }); err != nil { + log.Printf("failed to encode aggregation response: %v", err) + } +} diff --git a/integration_test.go b/integration_test.go index 5cceb37..50edfab 100644 --- a/integration_test.go +++ b/integration_test.go @@ -355,9 +355,9 @@ func TestIntegrationCleanup(t *testing.T) { // integrationEntity for integration tests type integrationEntity struct { + Timestamp time.Time `datastore:"timestamp"` Name string `datastore:"name"` Count int64 `datastore:"count"` - Timestamp time.Time `datastore:"timestamp"` } // TestIntegrationGetAll tests the GetAll method with real GCP or mock. diff --git a/key_test.go b/key_test.go new file mode 100644 index 0000000..0e7b25c --- /dev/null +++ b/key_test.go @@ -0,0 +1,272 @@ +package ds9 + +import ( + "testing" +) + +func TestKeyEqual(t *testing.T) { + tests := []struct { + key1 *Key + key2 *Key + name string + expected bool + }{ + { + name: "both nil", + key1: nil, + key2: nil, + expected: true, + }, + { + name: "one nil", + key1: NameKey("Kind", "name", nil), + key2: nil, + expected: false, + }, + { + name: "same name keys", + key1: NameKey("Kind", "name", nil), + key2: NameKey("Kind", "name", nil), + expected: true, + }, + { + name: "same ID keys", + key1: IDKey("Kind", 123, nil), + key2: IDKey("Kind", 123, nil), + expected: true, + }, + { + name: "different kinds", + key1: NameKey("Kind1", "name", nil), + key2: NameKey("Kind2", "name", nil), + expected: false, + }, + { + name: "different names", + key1: NameKey("Kind", "name1", nil), + key2: NameKey("Kind", "name2", nil), + expected: false, + }, + { + name: "different IDs", + key1: IDKey("Kind", 123, nil), + key2: IDKey("Kind", 456, nil), + expected: false, + }, + { + name: "with same parent", + key1: NameKey("Child", "c1", NameKey("Parent", "p1", nil)), + key2: NameKey("Child", "c1", NameKey("Parent", "p1", nil)), + expected: true, + }, + { + name: "with different parent", + key1: NameKey("Child", "c1", NameKey("Parent", "p1", nil)), + key2: NameKey("Child", "c1", NameKey("Parent", "p2", nil)), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.key1.Equal(tt.key2) + if result != tt.expected { + t.Errorf("Expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestKeyIncomplete(t *testing.T) { + tests := []struct { + key *Key + name string + expected bool + }{ + { + name: "incomplete key", + key: IncompleteKey("Kind", nil), + expected: true, + }, + { + name: "name key", + key: NameKey("Kind", "name", nil), + expected: false, + }, + { + name: "ID key", + key: IDKey("Kind", 123, nil), + expected: false, + }, + { + name: "zero ID is incomplete", + key: &Key{Kind: "Kind", ID: 0, Name: ""}, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.key.Incomplete() + if result != tt.expected { + t.Errorf("Expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestIncompleteKey(t *testing.T) { + key := IncompleteKey("TestKind", nil) + + if key.Kind != "TestKind" { + t.Errorf("Expected kind 'TestKind', got '%s'", key.Kind) + } + + if !key.Incomplete() { + t.Error("Expected key to be incomplete") + } + + if key.Parent != nil { + t.Error("Expected nil parent") + } +} + +func TestIncompleteKeyWithParent(t *testing.T) { + parent := NameKey("Parent", "p1", nil) + key := IncompleteKey("Child", parent) + + if key.Kind != "Child" { + t.Errorf("Expected kind 'Child', got '%s'", key.Kind) + } + + if !key.Incomplete() { + t.Error("Expected key to be incomplete") + } + + if !key.Parent.Equal(parent) { + t.Error("Expected parent to match") + } +} + +func TestKeyString(t *testing.T) { + tests := []struct { + name string + key *Key + expected string + }{ + { + name: "nil key", + key: nil, + expected: "", + }, + { + name: "simple name key", + key: NameKey("Kind", "name", nil), + expected: `/Kind,"name"`, + }, + { + name: "simple ID key", + key: IDKey("Kind", 123, nil), + expected: "/Kind,123", + }, + { + name: "incomplete key", + key: IncompleteKey("Kind", nil), + expected: "/Kind,incomplete", + }, + { + name: "hierarchical key", + key: NameKey("Child", "c1", NameKey("Parent", "p1", nil)), + expected: `/Parent,"p1"/Child,"c1"`, + }, + { + name: "deep hierarchy", + key: IDKey("GrandChild", 3, NameKey("Child", "c1", NameKey("Parent", "p1", nil))), + expected: `/Parent,"p1"/Child,"c1"/GrandChild,3`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.key.String() + if result != tt.expected { + t.Errorf("Expected %q, got %q", tt.expected, result) + } + }) + } +} + +func TestKeyEncodeDecode(t *testing.T) { + tests := []struct { + key *Key + name string + }{ + { + name: "name key", + key: NameKey("Kind", "name", nil), + }, + { + name: "ID key", + key: IDKey("Kind", 123, nil), + }, + { + name: "hierarchical key", + key: NameKey("Child", "c1", NameKey("Parent", "p1", nil)), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + encoded := tt.key.Encode() + if encoded == "" { + t.Fatal("Encode returned empty string") + } + + decoded, err := DecodeKey(encoded) + if err != nil { + t.Fatalf("DecodeKey failed: %v", err) + } + + if !decoded.Equal(tt.key) { + t.Errorf("Decoded key doesn't match original.\nOriginal: %s\nDecoded: %s", tt.key.String(), decoded.String()) + } + }) + } +} + +func TestDecodeKeyErrors(t *testing.T) { + tests := []struct { + name string + encoded string + }{ + { + name: "empty string", + encoded: "", + }, + { + name: "invalid base64", + encoded: "!!!invalid!!!", + }, + { + name: "invalid JSON", + encoded: "aW52YWxpZCBqc29u", // "invalid json" in base64 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := DecodeKey(tt.encoded) + if err == nil { + t.Error("Expected error, got nil") + } + }) + } +} + +func TestKeyEncodeNil(t *testing.T) { + var key *Key + encoded := key.Encode() + if encoded != "" { + t.Errorf("Expected empty string for nil key, got %q", encoded) + } +} diff --git a/query_test.go b/query_test.go new file mode 100644 index 0000000..a2c17c5 --- /dev/null +++ b/query_test.go @@ -0,0 +1,290 @@ +package ds9 + +import ( + "testing" +) + +func TestQueryFilter(t *testing.T) { + q := NewQuery("TestKind").Filter("Count >", 10) + + if len(q.filters) != 1 { + t.Fatalf("Expected 1 filter, got %d", len(q.filters)) + } + + filter := q.filters[0] + if filter.property != "Count" { + t.Errorf("Expected property 'Count', got '%s'", filter.property) + } + if filter.operator != "GREATER_THAN" { + t.Errorf("Expected operator 'GREATER_THAN', got '%s'", filter.operator) + } + if filter.value != 10 { + t.Errorf("Expected value 10, got %v", filter.value) + } +} + +func TestQueryFilterField(t *testing.T) { + tests := []struct { + name string + field string + operator string + value any + expectedOperator string + }{ + { + name: "equal", + field: "Name", + operator: "=", + value: "test", + expectedOperator: "EQUAL", + }, + { + name: "less than", + field: "Count", + operator: "<", + value: 100, + expectedOperator: "LESS_THAN", + }, + { + name: "less than or equal", + field: "Count", + operator: "<=", + value: 100, + expectedOperator: "LESS_THAN_OR_EQUAL", + }, + { + name: "greater than", + field: "Count", + operator: ">", + value: 10, + expectedOperator: "GREATER_THAN", + }, + { + name: "greater than or equal", + field: "Count", + operator: ">=", + value: 10, + expectedOperator: "GREATER_THAN_OR_EQUAL", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := NewQuery("TestKind").FilterField(tt.field, tt.operator, tt.value) + + if len(q.filters) != 1 { + t.Fatalf("Expected 1 filter, got %d", len(q.filters)) + } + + filter := q.filters[0] + if filter.property != tt.field { + t.Errorf("Expected property '%s', got '%s'", tt.field, filter.property) + } + if filter.operator != tt.expectedOperator { + t.Errorf("Expected operator '%s', got '%s'", tt.expectedOperator, filter.operator) + } + if filter.value != tt.value { + t.Errorf("Expected value %v, got %v", tt.value, filter.value) + } + }) + } +} + +func TestQueryMultipleFilters(t *testing.T) { + q := NewQuery("TestKind"). + FilterField("Count", ">", 10). + FilterField("Name", "=", "test") + + if len(q.filters) != 2 { + t.Fatalf("Expected 2 filters, got %d", len(q.filters)) + } +} + +func TestQueryOrder(t *testing.T) { + tests := []struct { + name string + fieldName string + expectedProperty string + expectedDirection string + }{ + { + name: "ascending", + fieldName: "Count", + expectedProperty: "Count", + expectedDirection: "ASCENDING", + }, + { + name: "descending", + fieldName: "-Count", + expectedProperty: "Count", + expectedDirection: "DESCENDING", + }, + { + name: "ascending with hyphen in name", + fieldName: "created-at", + expectedProperty: "created-at", + expectedDirection: "ASCENDING", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := NewQuery("TestKind").Order(tt.fieldName) + + if len(q.orders) != 1 { + t.Fatalf("Expected 1 order, got %d", len(q.orders)) + } + + order := q.orders[0] + if order.property != tt.expectedProperty { + t.Errorf("Expected property '%s', got '%s'", tt.expectedProperty, order.property) + } + if order.direction != tt.expectedDirection { + t.Errorf("Expected direction '%s', got '%s'", tt.expectedDirection, order.direction) + } + }) + } +} + +func TestQueryMultipleOrders(t *testing.T) { + q := NewQuery("TestKind"). + Order("Name"). + Order("-Count") + + if len(q.orders) != 2 { + t.Fatalf("Expected 2 orders, got %d", len(q.orders)) + } +} + +func TestQueryOffset(t *testing.T) { + q := NewQuery("TestKind").Offset(10) + + if q.offset != 10 { + t.Errorf("Expected offset 10, got %d", q.offset) + } +} + +func TestQueryAncestor(t *testing.T) { + ancestor := NameKey("Parent", "p1", nil) + q := NewQuery("TestKind").Ancestor(ancestor) + + if !q.ancestor.Equal(ancestor) { + t.Error("Expected ancestor to match") + } +} + +func TestQueryProject(t *testing.T) { + q := NewQuery("TestKind").Project("Name", "Count") + + if len(q.projection) != 2 { + t.Fatalf("Expected 2 projected fields, got %d", len(q.projection)) + } + + if q.projection[0] != "Name" { + t.Errorf("Expected first projection 'Name', got '%s'", q.projection[0]) + } + if q.projection[1] != "Count" { + t.Errorf("Expected second projection 'Count', got '%s'", q.projection[1]) + } +} + +func TestQueryChaining(t *testing.T) { + q := NewQuery("TestKind"). + FilterField("Count", ">", 10). + Order("-Count"). + Limit(100). + Offset(20). + KeysOnly() + + if len(q.filters) != 1 { + t.Errorf("Expected 1 filter, got %d", len(q.filters)) + } + if len(q.orders) != 1 { + t.Errorf("Expected 1 order, got %d", len(q.orders)) + } + if q.limit != 100 { + t.Errorf("Expected limit 100, got %d", q.limit) + } + if q.offset != 20 { + t.Errorf("Expected offset 20, got %d", q.offset) + } + if !q.keysOnly { + t.Error("Expected keysOnly to be true") + } +} + +func TestBuildQueryMapBasic(t *testing.T) { + q := NewQuery("TestKind") + queryMap := buildQueryMap(q) + + kind, ok := queryMap["kind"].([]map[string]any) + if !ok || len(kind) == 0 { + t.Fatal("Expected kind in query map") + } + + if kind[0]["name"] != "TestKind" { + t.Errorf("Expected kind 'TestKind', got '%v'", kind[0]["name"]) + } +} + +func TestBuildQueryMapWithLimit(t *testing.T) { + q := NewQuery("TestKind").Limit(10) + queryMap := buildQueryMap(q) + + limit, ok := queryMap["limit"] + if !ok { + t.Fatal("Expected limit in query map") + } + + if limit != 10 { + t.Errorf("Expected limit 10, got %v", limit) + } +} + +func TestBuildQueryMapWithOffset(t *testing.T) { + q := NewQuery("TestKind").Offset(5) + queryMap := buildQueryMap(q) + + offset, ok := queryMap["offset"] + if !ok { + t.Fatal("Expected offset in query map") + } + + if offset != 5 { + t.Errorf("Expected offset 5, got %v", offset) + } +} + +func TestBuildQueryMapWithFilter(t *testing.T) { + q := NewQuery("TestKind").FilterField("Count", ">", 10) + queryMap := buildQueryMap(q) + + _, ok := queryMap["filter"] + if !ok { + t.Fatal("Expected filter in query map") + } +} + +func TestBuildQueryMapWithOrder(t *testing.T) { + q := NewQuery("TestKind").Order("-Count") + queryMap := buildQueryMap(q) + + orders, ok := queryMap["order"].([]map[string]any) + if !ok || len(orders) == 0 { + t.Fatal("Expected order in query map") + } + + if orders[0]["direction"] != "DESCENDING" { + t.Errorf("Expected DESCENDING, got %v", orders[0]["direction"]) + } +} + +func TestBuildQueryMapKeysOnly(t *testing.T) { + q := NewQuery("TestKind").KeysOnly() + queryMap := buildQueryMap(q) + + projection, ok := queryMap["projection"].([]map[string]any) + if !ok || len(projection) == 0 { + t.Fatal("Expected projection in query map for keys-only") + } +} diff --git a/transaction_test.go b/transaction_test.go new file mode 100644 index 0000000..5c094ad --- /dev/null +++ b/transaction_test.go @@ -0,0 +1,402 @@ +package ds9_test + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/codeGROOVE-dev/ds9" + "github.com/codeGROOVE-dev/ds9/ds9mock" +) + +type txTestEntity struct { + Time time.Time + Name string + Count int64 +} + +func TestTransactionDelete(t *testing.T) { + client, cleanup := ds9mock.NewClient(t) + defer cleanup() + + ctx := context.Background() + key := ds9.NameKey("TxDeleteTest", "test", nil) + + // Create an entity + entity := &txTestEntity{ + Name: "test", + Count: 42, + Time: time.Now().UTC().Truncate(time.Microsecond), + } + if _, err := client.Put(ctx, key, entity); err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Delete in transaction + _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { + return tx.Delete(key) + }) + if err != nil { + t.Fatalf("Transaction failed: %v", err) + } + + // Verify deletion + var result txTestEntity + err = client.Get(ctx, key, &result) + if !errors.Is(err, ds9.ErrNoSuchEntity) { + t.Errorf("Expected ds9.ErrNoSuchEntity, got %v", err) + } +} + +func TestTransactionDeleteMulti(t *testing.T) { + client, cleanup := ds9mock.NewClient(t) + defer cleanup() + + ctx := context.Background() + + // Create multiple entities + keys := []*ds9.Key{ + ds9.NameKey("TxDeleteMultiTest", "test1", nil), + ds9.NameKey("TxDeleteMultiTest", "test2", nil), + ds9.NameKey("TxDeleteMultiTest", "test3", nil), + } + + entities := []txTestEntity{ + {Name: "test1", Count: 1, Time: time.Now().UTC().Truncate(time.Microsecond)}, + {Name: "test2", Count: 2, Time: time.Now().UTC().Truncate(time.Microsecond)}, + {Name: "test3", Count: 3, Time: time.Now().UTC().Truncate(time.Microsecond)}, + } + + if _, err := client.PutMulti(ctx, keys, entities); err != nil { + t.Fatalf("PutMulti failed: %v", err) + } + + // Delete in transaction + _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { + return tx.DeleteMulti(keys) + }) + if err != nil { + t.Fatalf("Transaction failed: %v", err) + } +} + +func TestTransactionGetMulti(t *testing.T) { + client, cleanup := ds9mock.NewClient(t) + defer cleanup() + + ctx := context.Background() + + // Create multiple entities + keys := []*ds9.Key{ + ds9.NameKey("TxGetMultiTest", "test1", nil), + ds9.NameKey("TxGetMultiTest", "test2", nil), + } + + entities := []txTestEntity{ + {Name: "test1", Count: 1, Time: time.Now().UTC().Truncate(time.Microsecond)}, + {Name: "test2", Count: 2, Time: time.Now().UTC().Truncate(time.Microsecond)}, + } + + if _, err := client.PutMulti(ctx, keys, entities); err != nil { + t.Fatalf("PutMulti failed: %v", err) + } + + // Get in transaction + _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { + results := make([]txTestEntity, 2) + if err := tx.GetMulti(keys, &results); err != nil { + return err + } + + if results[0].Name != "test1" { + t.Errorf("Expected Name 'test1', got '%s'", results[0].Name) + } + if results[1].Name != "test2" { + t.Errorf("Expected Name 'test2', got '%s'", results[1].Name) + } + + return nil + }) + if err != nil { + t.Fatalf("Transaction failed: %v", err) + } +} + +func TestTransactionPutMulti(t *testing.T) { + client, cleanup := ds9mock.NewClient(t) + defer cleanup() + + ctx := context.Background() + + keys := []*ds9.Key{ + ds9.NameKey("TxPutMultiTest", "test1", nil), + ds9.NameKey("TxPutMultiTest", "test2", nil), + } + + entities := []txTestEntity{ + {Name: "test1", Count: 1, Time: time.Now().UTC().Truncate(time.Microsecond)}, + {Name: "test2", Count: 2, Time: time.Now().UTC().Truncate(time.Microsecond)}, + } + + // Put in transaction + _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { + _, err := tx.PutMulti(keys, entities) + return err + }) + if err != nil { + t.Fatalf("Transaction failed: %v", err) + } + + // Verify entities were created + results := make([]txTestEntity, 2) + if err := client.GetMulti(ctx, keys, &results); err != nil { + t.Fatalf("GetMulti failed: %v", err) + } + + if results[0].Name != "test1" { + t.Errorf("Expected Name 'test1', got '%s'", results[0].Name) + } + if results[1].Name != "test2" { + t.Errorf("Expected Name 'test2', got '%s'", results[1].Name) + } +} + +func TestTransactionMixedOperations(t *testing.T) { + client, cleanup := ds9mock.NewClient(t) + defer cleanup() + + ctx := context.Background() + + key1 := ds9.NameKey("TxMixedTest", "read", nil) + key2 := ds9.NameKey("TxMixedTest", "write", nil) + key3 := ds9.NameKey("TxMixedTest", "delete", nil) + + // Create initial entities + if _, err := client.Put(ctx, key1, &txTestEntity{Name: "read", Count: 1, Time: time.Now().UTC().Truncate(time.Microsecond)}); err != nil { + t.Fatalf("Put failed: %v", err) + } + if _, err := client.Put(ctx, key3, &txTestEntity{Name: "delete", Count: 3, Time: time.Now().UTC().Truncate(time.Microsecond)}); err != nil { + t.Fatalf("Put failed: %v", err) + } + + // Transaction with mixed operations + _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { + // Read + var entity txTestEntity + if err := tx.Get(key1, &entity); err != nil { + return err + } + + // Write + if _, err := tx.Put(key2, &txTestEntity{Name: "write", Count: 2, Time: time.Now().UTC().Truncate(time.Microsecond)}); err != nil { + return err + } + + // Delete + if err := tx.Delete(key3); err != nil { + return err + } + + return nil + }) + if err != nil { + t.Fatalf("Transaction failed: %v", err) + } + + // Verify write succeeded + var entity txTestEntity + if err := client.Get(ctx, key2, &entity); err != nil { + t.Fatalf("Get after transaction failed: %v", err) + } + if entity.Name != "write" { + t.Errorf("Expected Name 'write', got '%s'", entity.Name) + } + + // Verify delete succeeded + err = client.Get(ctx, key3, &entity) + if !errors.Is(err, ds9.ErrNoSuchEntity) { + t.Errorf("Expected ds9.ErrNoSuchEntity for deleted entity, got %v", err) + } +} + +func TestNewTransactionCommit(t *testing.T) { + client, cleanup := ds9mock.NewClient(t) + defer cleanup() + + ctx := context.Background() + + tx, err := client.NewTransaction(ctx) + if err != nil { + t.Fatalf("NewTransaction failed: %v", err) + } + + key := ds9.NameKey("TxCommitTest", "test", nil) + entity := &txTestEntity{ + Name: "test", + Count: 42, + Time: time.Now().UTC().Truncate(time.Microsecond), + } + + if _, err := tx.Put(key, entity); err != nil { + t.Fatalf("tx.Put failed: %v", err) + } + + commit, err := tx.Commit() + if err != nil { + t.Fatalf("tx.Commit failed: %v", err) + } + + if commit == nil { + t.Error("Expected non-nil Commit") + } + + // Verify entity was created + var result txTestEntity + if err := client.Get(ctx, key, &result); err != nil { + t.Fatalf("Get after commit failed: %v", err) + } + + if result.Name != "test" { + t.Errorf("Expected Name 'test', got '%s'", result.Name) + } +} + +func TestNewTransactionRollback(t *testing.T) { + client, cleanup := ds9mock.NewClient(t) + defer cleanup() + + ctx := context.Background() + + tx, err := client.NewTransaction(ctx) + if err != nil { + t.Fatalf("NewTransaction failed: %v", err) + } + + key := ds9.NameKey("TxRollbackTest", "test", nil) + entity := &txTestEntity{ + Name: "test", + Count: 42, + Time: time.Now().UTC().Truncate(time.Microsecond), + } + + if _, err := tx.Put(key, entity); err != nil { + t.Fatalf("tx.Put failed: %v", err) + } + + if err := tx.Rollback(); err != nil { + t.Fatalf("tx.Rollback failed: %v", err) + } + + // After rollback, transaction should not commit (but we can't verify internal state) +} + +func TestTransactionDeleteNilKey(t *testing.T) { + client, cleanup := ds9mock.NewClient(t) + defer cleanup() + + ctx := context.Background() + + _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { + return tx.Delete(nil) + }) + + if err == nil { + t.Error("Expected error for nil key") + } +} + +func TestTransactionGetMultiLengthMismatch(t *testing.T) { + client, cleanup := ds9mock.NewClient(t) + defer cleanup() + + ctx := context.Background() + + keys := []*ds9.Key{ + ds9.NameKey("Test", "test1", nil), + ds9.NameKey("Test", "test2", nil), + } + + _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { + results := make([]txTestEntity, 1) // Wrong length + return tx.GetMulti(keys, &results) + }) + + if err == nil { + t.Error("Expected error for length mismatch") + } +} + +func TestTransactionPutMultiLengthMismatch(t *testing.T) { + client, cleanup := ds9mock.NewClient(t) + defer cleanup() + + ctx := context.Background() + + keys := []*ds9.Key{ + ds9.NameKey("Test", "test1", nil), + } + + entities := []txTestEntity{ + {Name: "test1", Count: 1, Time: time.Now().UTC()}, + {Name: "test2", Count: 2, Time: time.Now().UTC()}, + } + + _, err := client.RunInTransaction(ctx, func(tx *ds9.Transaction) error { + _, err := tx.PutMulti(keys, entities) + return err + }) + + if err == nil { + t.Error("Expected error for length mismatch") + } +} + +func TestTransactionWithOptions(t *testing.T) { + client, cleanup := ds9mock.NewClient(t) + defer cleanup() + + ctx := context.Background() + + // Test that transaction options are accepted + _, err := client.NewTransaction(ctx, ds9.MaxAttempts(5)) + if err != nil { + t.Fatalf("NewTransaction with options failed: %v", err) + } + + // Test with read time option + _, err = client.NewTransaction(ctx, ds9.WithReadTime(time.Now().UTC())) + if err != nil { + t.Fatalf("NewTransaction with ds9.WithReadTime failed: %v", err) + } +} + +func TestNewTransactionMultipleOperations(t *testing.T) { + client, cleanup := ds9mock.NewClient(t) + defer cleanup() + + ctx := context.Background() + + tx, err := client.NewTransaction(ctx) + if err != nil { + t.Fatalf("NewTransaction failed: %v", err) + } + + // Multiple puts + for i := range 3 { + key := ds9.IDKey("TxMultiOp", int64(i+1), nil) + entity := &txTestEntity{ + Name: "test", + Count: int64(i), + Time: time.Now().UTC().Truncate(time.Microsecond), + } + if _, err := tx.Put(key, entity); err != nil { + t.Fatalf("tx.Put failed: %v", err) + } + } + + // Commit the transaction + if _, err := tx.Commit(); err != nil { + t.Fatalf("tx.Commit failed: %v", err) + } +}