From 3dda7b27ec536637d8ebaa20937fc8019c930481 Mon Sep 17 00:00:00 2001 From: go_vargo <44902466+govargo@users.noreply.github.com> Date: Wed, 23 Aug 2023 19:52:53 +0900 Subject: [PATCH] feat(spanner/spannertest): support INSERT DML (#7820) Co-authored-by: rahul2393 --- spanner/spannertest/README.md | 1 - spanner/spannertest/db.go | 58 ++++++++++++++++++ spanner/spannertest/integration_test.go | 80 +++++++++++++++++++++++-- 3 files changed, 133 insertions(+), 6 deletions(-) diff --git a/spanner/spannertest/README.md b/spanner/spannertest/README.md index e737bd6810f2..0714328af0a1 100644 --- a/spanner/spannertest/README.md +++ b/spanner/spannertest/README.md @@ -33,7 +33,6 @@ by ascending esotericism: - case insensitivity of table and column names and query aliases - transaction simulation - FOREIGN KEY and CHECK constraints -- INSERT DML statements - set operations (UNION, INTERSECT, EXCEPT) - STRUCT types - partition support diff --git a/spanner/spannertest/db.go b/spanner/spannertest/db.go index 6babaf93e04a..31cebd9cebbd 100644 --- a/spanner/spannertest/db.go +++ b/spanner/spannertest/db.go @@ -1262,6 +1262,64 @@ func (d *database) Execute(stmt spansql.DMLStmt, params queryParams) (int, error } } return n, nil + case *spansql.Insert: + t, err := d.table(stmt.Table) + if err != nil { + return 0, err + } + + t.mu.Lock() + defer t.mu.Unlock() + + ec := evalContext{ + cols: t.cols, + params: params, + } + + values := make(row, len(t.cols)) + input := stmt.Input.(spansql.Values) + if len(input) > 0 { + for i := 0; i < len(input); i++ { + val := input[i] + for k, v := range val { + switch v := v.(type) { + // if spanner.Statement.Params is not empty, scratch row with ec.parameters + case spansql.Param: + values[k] = ec.params[t.cols[k].Name.SQL()].Value + // if nil is included in parameters, pass nil + case spansql.ID: + cutset := `""` + str := strings.Trim(v.SQL(), cutset) + if str == "nil" { + values[k] = nil + } else { + expr, err := ec.evalExpr(v) + if err != nil { + return 0, status.Errorf(codes.InvalidArgument, "invalid parameter format") + } + values[k] = expr + } + // if parameter is embedded in SQL as string, not in statement.Params, analyze parameters + default: + expr, err := ec.evalExpr(v) + if err != nil { + return 0, status.Errorf(codes.InvalidArgument, "invalid parameter format") + } + values[k] = expr + } + } + } + } + + // pk check if the primary key already exists + pk := values[:t.pkCols] + rowNum, found := t.rowForPK(pk) + if found { + return 0, status.Errorf(codes.AlreadyExists, "row already in table") + } + t.insertRow(rowNum, values) + + return 1, nil } } diff --git a/spanner/spannertest/integration_test.go b/spanner/spannertest/integration_test.go index cfe75c171e97..5dc3c0958aa7 100644 --- a/spanner/spannertest/integration_test.go +++ b/spanner/spannertest/integration_test.go @@ -708,15 +708,83 @@ func TestIntegration_ReadsAndQueries(t *testing.T) { spanner.Insert("SomeStrings", []string{"i", "str"}, []interface{}{1, "abar"}), spanner.Insert("SomeStrings", []string{"i", "str"}, []interface{}{2, nil}), spanner.Insert("SomeStrings", []string{"i", "str"}, []interface{}{3, "bbar"}), - - spanner.Insert("Updateable", []string{"id", "first", "last"}, []interface{}{0, "joe", nil}), - spanner.Insert("Updateable", []string{"id", "first", "last"}, []interface{}{1, "doe", "joan"}), - spanner.Insert("Updateable", []string{"id", "first", "last"}, []interface{}{2, "wong", "wong"}), }) if err != nil { t.Fatalf("Inserting sample data: %v", err) } + // Perform INSERT DML; the results are checked later on. + n = 0 + _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { + for _, u := range []string{ + `INSERT INTO Updateable (id, first, last) VALUES (0, "joe", nil)`, + `INSERT INTO Updateable (id, first, last) VALUES (1, "doe", "joan")`, + `INSERT INTO Updateable (id, first, last) VALUES (2, "wong", "wong")`, + } { + nr, err := tx.Update(ctx, spanner.NewStatement(u)) + if err != nil { + return err + } + n += nr + } + return nil + }) + if err != nil { + t.Fatalf("Inserting with DML: %v", err) + } + if n != 3 { + t.Errorf("Inserting with DML affected %d rows, want 3", n) + } + + // Perform INSERT DML with statement.Params; the results are checked later on. + n = 0 + _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { + stmt := spanner.Statement{ + SQL: "INSERT INTO Updateable (id, first, last) VALUES (@id, @first, @last)", + Params: map[string]interface{}{ + "id": 3, + "first": "tom", + "last": "jerry", + }, + } + nr, err := tx.Update(ctx, stmt) + if err != nil { + return err + } + n += nr + return nil + }) + if err != nil { + t.Fatalf("Inserting with DML: %v", err) + } + if n != 1 { + t.Errorf("Inserting with DML affected %d rows, want 1", n) + } + + // Perform INSERT DML with statement.Params and inline parameter; the results are checked later on. + n = 0 + _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { + stmt := spanner.Statement{ + SQL: `INSERT INTO Updateable (id, first, last) VALUES (@id, "jim", @last)`, + Params: map[string]interface{}{ + "id": 4, + "last": nil, + }, + } + nr, err := tx.Update(ctx, stmt) + if err != nil { + return err + } + n += nr + return nil + }) + if err != nil { + t.Fatalf("Inserting with DML: %v", err) + } + if n != 1 { + t.Errorf("Inserting with DML affected %d rows, want 1", n) + } + // Perform UPDATE DML; the results are checked later on. n = 0 _, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *spanner.ReadWriteTransaction) error { @@ -724,7 +792,7 @@ func TestIntegration_ReadsAndQueries(t *testing.T) { `UPDATE Updateable SET last = "bloggs" WHERE id = 0`, `UPDATE Updateable SET first = last, last = first WHERE id = 1`, `UPDATE Updateable SET last = DEFAULT WHERE id = 2`, - `UPDATE Updateable SET first = "noname" WHERE id = 3`, // no id=3 + `UPDATE Updateable SET first = "noname" WHERE id = 5`, // no id=5 } { nr, err := tx.Update(ctx, spanner.NewStatement(u)) if err != nil { @@ -1156,6 +1224,8 @@ func TestIntegration_ReadsAndQueries(t *testing.T) { {int64(0), "joe", "bloggs"}, {int64(1), "joan", "doe"}, {int64(2), "wong", nil}, + {int64(3), "tom", "jerry"}, + {int64(4), "jim", nil}, }, }, // Regression test for aggregating no rows; it used to return an empty row.