diff --git a/spanner/integration_test.go b/spanner/integration_test.go index dfa7a94c321b..8bf7d76d4cf2 100644 --- a/spanner/integration_test.go +++ b/spanner/integration_test.go @@ -114,6 +114,16 @@ var ( PRIMARY KEY(AccountId) )`, `CREATE INDEX AccountByNickname ON Accounts(Nickname)`, + `CREATE TABLE Types ( + RowID BIGINT PRIMARY KEY, + String VARCHAR, + Bytes BYTEA, + Int64a BIGINT, + Bool BOOL, + Float64 DOUBLE PRECISION, + Numeric NUMERIC, + JSONB jsonb + )`, } singerDBStatements = []string{ @@ -3453,6 +3463,111 @@ func TestIntegration_PGNumeric(t *testing.T) { } } +func TestIntegration_PGJSONB(t *testing.T) { + onlyRunForPGTest(t) + skipEmulatorTest(t) + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + client, _, cleanup := prepareIntegrationTestForPG(ctx, t, DefaultSessionPoolConfig, singerDBPGStatements) + defer cleanup() + + type Message struct { + Name string + Body string + Time int64 + } + msg := Message{"Alice", "Hello", 1294706395881547000} + jsonStr := `{"Name":"Alice","Body":"Hello","Time":1294706395881547000}` + var unmarshalledJSONstruct interface{} + json.Unmarshal([]byte(jsonStr), &unmarshalledJSONstruct) + + tests := []struct { + col string + val interface{} + want interface{} + }{ + {col: "JSONB", val: PGJsonB{Value: msg, Valid: true}, want: PGJsonB{Value: unmarshalledJSONstruct, Valid: true}}, + {col: "JSONB", val: PGJsonB{Value: msg, Valid: false}, want: PGJsonB{}}, + } + + // Write rows into table first using DML. + statements := make([]Statement, 0) + for i, test := range tests { + stmt := NewStatement(fmt.Sprintf("INSERT INTO Types (RowId, %s) VALUES ($1, $2)", test.col)) + // Note: We are not setting the parameter type here to ensure that it + // can be automatically recognized when it is actually needed. + stmt.Params["p1"] = i + stmt.Params["p2"] = test.val + statements = append(statements, stmt) + } + _, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error { + rowCounts, err := tx.BatchUpdate(ctx, statements) + if err != nil { + return err + } + if len(rowCounts) != len(tests) { + return fmt.Errorf("rowCounts length mismatch\nGot: %v\nWant: %v", len(rowCounts), len(tests)) + } + for i, c := range rowCounts { + if c != 1 { + return fmt.Errorf("row count mismatch for row %v:\nGot: %v\nWant: %v", i, c, 1) + } + } + return nil + }) + if err != nil { + t.Fatalf("failed to insert values using DML: %v", err) + } + // Delete all the rows so we can insert them using mutations as well. + _, err = client.Apply(ctx, []*Mutation{Delete("Types", AllKeys())}) + if err != nil { + t.Fatalf("failed to delete all rows: %v", err) + } + + // Verify that we can insert the rows using mutations. + var muts []*Mutation + for i, test := range tests { + muts = append(muts, InsertOrUpdate("Types", []string{"RowID", test.col}, []interface{}{i, test.val})) + } + if _, err := client.Apply(ctx, muts, ApplyAtLeastOnce()); err != nil { + t.Fatal(err) + } + + for i, test := range tests { + row, err := client.Single().ReadRow(ctx, "Types", []interface{}{i}, []string{test.col}) + if err != nil { + t.Fatalf("Unable to fetch row %v: %v", i, err) + } + verifyDirectPathRemoteAddress(t) + // Create new instance of type of test.want. + want := test.want + if want == nil { + want = test.val + } + gotp := reflect.New(reflect.TypeOf(want)) + if err := row.Column(0, gotp.Interface()); err != nil { + t.Errorf("%d: col:%v val:%#v, %v", i, test.col, test.val, err) + continue + } + got := reflect.Indirect(gotp).Interface() + + // One of the test cases is checking NaN handling. Given + // NaN!=NaN, we can't use reflect to test for it. + if isNaN(got) && isNaN(want) { + continue + } + + // Check non-NaN cases. + if !testEqual(got, want) { + t.Errorf("%d: col:%v val:%#v, got %#v, want %#v", i, test.col, test.val, got, want) + continue + } + } + +} + func readPGSingerTable(iter *RowIterator) ([][]interface{}, error) { defer iter.Stop() var vals [][]interface{} diff --git a/spanner/protoutils.go b/spanner/protoutils.go index ebe03c1b0cf7..6f4d3ac04bb5 100644 --- a/spanner/protoutils.go +++ b/spanner/protoutils.go @@ -81,6 +81,10 @@ func jsonType() *sppb.Type { return &sppb.Type{Code: sppb.TypeCode_JSON} } +func pgJsonbType() *sppb.Type { + return &sppb.Type{Code: sppb.TypeCode_JSON, TypeAnnotation: sppb.TypeAnnotationCode_PG_JSONB} +} + func bytesProto(b []byte) *proto3.Value { return &proto3.Value{Kind: &proto3.Value_StringValue{StringValue: base64.StdEncoding.EncodeToString(b)}} } diff --git a/spanner/value.go b/spanner/value.go index 1927c560e2e5..5d8cd8de943e 100644 --- a/spanner/value.go +++ b/spanner/value.go @@ -884,6 +884,59 @@ type NullRow struct { Valid bool // Valid is true if Row is not NULL. } +// PGJsonB represents a Cloud Spanner PGJsonB that may be NULL. +type PGJsonB struct { + Value interface{} // Val contains the value when it is non-NULL, and nil when NULL. + Valid bool // Valid is true if PGJsonB is not NULL. + // This is here to support customer wrappers around PGJsonB type, this will help during getDecodableSpannerType + // to differentiate between PGJsonB and NullJSON types. + _ bool +} + +// IsNull implements NullableValue.IsNull for PGJsonB. +func (n PGJsonB) IsNull() bool { + return !n.Valid +} + +// String implements Stringer.String for PGJsonB. +func (n PGJsonB) String() string { + if !n.Valid { + return nullString + } + b, err := json.Marshal(n.Value) + if err != nil { + return fmt.Sprintf("error: %v", err) + } + return fmt.Sprintf("%v", string(b)) +} + +// MarshalJSON implements json.Marshaler.MarshalJSON for PGJsonB. +func (n PGJsonB) MarshalJSON() ([]byte, error) { + if n.Valid { + return json.Marshal(n.Value) + } + return jsonNullBytes, nil +} + +// UnmarshalJSON implements json.Unmarshaler.UnmarshalJSON for PGJsonB. +func (n *PGJsonB) UnmarshalJSON(payload []byte) error { + if payload == nil { + return fmt.Errorf("payload should not be nil") + } + if bytes.Equal(payload, jsonNullBytes) { + n.Valid = false + return nil + } + var v interface{} + err := json.Unmarshal(payload, &v) + if err != nil { + return fmt.Errorf("payload cannot be converted to a struct: got %v, err: %w", string(payload), err) + } + n.Value = v + n.Valid = true + return nil +} + // GenericColumnValue represents the generic encoded value and type of the // column. See google.spanner.v1.ResultSet proto for details. This can be // useful for proxying query results when the result types are not known in @@ -1638,6 +1691,44 @@ func decodeValue(v *proto3.Value, t *sppb.Type, ptr interface{}, opts ...decodeO return err } *p = y + case *PGJsonB: + if p == nil { + return errNilDst(p) + } + if code != sppb.TypeCode_JSON || typeAnnotation != sppb.TypeAnnotationCode_PG_JSONB { + return errTypeMismatch(code, acode, ptr) + } + if isNull { + *p = PGJsonB{} + break + } + x := v.GetStringValue() + var y interface{} + err := json.Unmarshal([]byte(x), &y) + if err != nil { + return err + } + *p = PGJsonB{Value: y, Valid: true} + case *[]PGJsonB: + if p == nil { + return errNilDst(p) + } + if acode != sppb.TypeCode_JSON || typeAnnotation != sppb.TypeAnnotationCode_PG_JSONB { + return errTypeMismatch(code, acode, ptr) + } + if isNull { + *p = nil + break + } + x, err := getListValue(v) + if err != nil { + return err + } + y, err := decodePGJsonBArray(x) + if err != nil { + return err + } + *p = y case *time.Time: var nt NullTime if isNull { @@ -1928,6 +2019,7 @@ const ( spannerTypeNullNumeric spannerTypeNullJSON spannerTypePGNumeric + spannerTypePGJsonB spannerTypeArrayOfNonNullString spannerTypeArrayOfByteArray spannerTypeArrayOfNonNullInt64 @@ -1945,6 +2037,7 @@ const ( spannerTypeArrayOfNullTime spannerTypeArrayOfNullDate spannerTypeArrayOfPGNumeric + spannerTypeArrayOfPGJsonB ) // supportsNull returns true for the Go types that can hold a null value from @@ -1977,6 +2070,7 @@ var typeOfNullDate = reflect.TypeOf(NullDate{}) var typeOfNullNumeric = reflect.TypeOf(NullNumeric{}) var typeOfNullJSON = reflect.TypeOf(NullJSON{}) var typeOfPGNumeric = reflect.TypeOf(PGNumeric{}) +var typeOfPGJsonB = reflect.TypeOf(PGJsonB{}) // getDecodableSpannerType returns the corresponding decodableSpannerType of // the given pointer. @@ -2011,6 +2105,9 @@ func getDecodableSpannerType(ptr interface{}, isPtr bool) decodableSpannerType { if t.ConvertibleTo(typeOfNullJSON) { return spannerTypeNullJSON } + if t.ConvertibleTo(typeOfPGJsonB) { + return spannerTypePGJsonB + } case reflect.Struct: t := val.Type() if t.ConvertibleTo(typeOfNonNullNumeric) { @@ -2049,6 +2146,9 @@ func getDecodableSpannerType(ptr interface{}, isPtr bool) decodableSpannerType { if t.ConvertibleTo(typeOfPGNumeric) { return spannerTypePGNumeric } + if t.ConvertibleTo(typeOfPGJsonB) { + return spannerTypePGJsonB + } case reflect.Slice: kind := val.Type().Elem().Kind() switch kind { @@ -2107,6 +2207,9 @@ func getDecodableSpannerType(ptr interface{}, isPtr bool) decodableSpannerType { if t.ConvertibleTo(typeOfPGNumeric) { return spannerTypeArrayOfPGNumeric } + if t.ConvertibleTo(typeOfPGJsonB) { + return spannerTypeArrayOfPGJsonB + } case reflect.Slice: // The only array-of-array type that is supported is [][]byte. kind := val.Type().Elem().Elem().Kind() @@ -2267,6 +2370,21 @@ func (dsc decodableSpannerType) decodeValueToCustomType(v *proto3.Value, t *sppb return err } result = &NullJSON{y, true} + case spannerTypePGJsonB: + if code != sppb.TypeCode_JSON || typeAnnotation != sppb.TypeAnnotationCode_PG_JSONB { + return errTypeMismatch(code, acode, ptr) + } + if isNull { + result = &PGJsonB{} + break + } + x := v.GetStringValue() + var y interface{} + err := json.Unmarshal([]byte(x), &y) + if err != nil { + return err + } + result = &PGJsonB{Value: y, Valid: true} case spannerTypeNonNullTime, spannerTypeNullTime: var nt NullTime err := parseNullTime(v, &nt, code, isNull) @@ -2435,6 +2553,23 @@ func (dsc decodableSpannerType) decodeValueToCustomType(v *proto3.Value, t *sppb return err } result = y + case spannerTypeArrayOfPGJsonB: + if acode != sppb.TypeCode_JSON || atypeAnnotation != sppb.TypeAnnotationCode_PG_JSONB { + return errTypeMismatch(code, acode, ptr) + } + if isNull { + ptr = nil + return nil + } + x, err := getListValue(v) + if err != nil { + return err + } + y, err := decodeGenericArray(reflect.TypeOf(ptr).Elem(), x, pgJsonbType(), "PGJSONB") + if err != nil { + return err + } + result = y case spannerTypeArrayOfNonNullTime, spannerTypeArrayOfNullTime: if acode != sppb.TypeCode_TIMESTAMP { return errTypeMismatch(code, acode, ptr) @@ -2812,6 +2947,20 @@ func decodeNullJSONArray(pb *proto3.ListValue) ([]NullJSON, error) { return a, nil } +// decodeJsonBArray decodes proto3.ListValue pb into a JsonB slice. +func decodePGJsonBArray(pb *proto3.ListValue) ([]PGJsonB, error) { + if pb == nil { + return nil, errNilListValue("PGJSONB") + } + a := make([]PGJsonB, len(pb.Values)) + for i, v := range pb.Values { + if err := decodeValue(v, pgJsonbType(), &a[i]); err != nil { + return nil, errDecodeArrayElement(i, v, "PGJSONB", err) + } + } + return a, nil +} + // decodeNullJSONArray decodes proto3.ListValue pb into a NullJSON pointer. func decodeNullJSONArrayToNullJSON(pb *proto3.ListValue) (*NullJSON, error) { if pb == nil { @@ -3457,6 +3606,23 @@ func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) { } } pt = listType(jsonType()) + case PGJsonB: + if v.Valid { + b, err := json.Marshal(v.Value) + if err != nil { + return nil, nil, err + } + pb.Kind = stringKind(string(b)) + } + return pb, pgJsonbType(), nil + case []PGJsonB: + if v != nil { + pb, err = encodeArray(len(v), func(i int) interface{} { return v[i] }) + if err != nil { + return nil, nil, err + } + } + pt = listType(pgJsonbType()) case *big.Rat: switch LossOfPrecisionHandling { case NumericError: @@ -3648,6 +3814,8 @@ func convertCustomTypeValue(sourceType decodableSpannerType, v interface{}) (int destination = reflect.Indirect(reflect.New(reflect.TypeOf(NullNumeric{}))) case spannerTypeNullJSON: destination = reflect.Indirect(reflect.New(reflect.TypeOf(NullJSON{}))) + case spannerTypePGJsonB: + destination = reflect.Indirect(reflect.New(reflect.TypeOf(PGJsonB{}))) case spannerTypePGNumeric: destination = reflect.Indirect(reflect.New(reflect.TypeOf(PGNumeric{}))) case spannerTypeArrayOfNonNullString: @@ -3730,6 +3898,11 @@ func convertCustomTypeValue(sourceType decodableSpannerType, v interface{}) (int return []NullJSON(nil), nil } destination = reflect.MakeSlice(reflect.TypeOf([]NullJSON{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap()) + case spannerTypeArrayOfPGJsonB: + if reflect.ValueOf(v).IsNil() { + return []PGJsonB(nil), nil + } + destination = reflect.MakeSlice(reflect.TypeOf([]PGJsonB{}), reflect.ValueOf(v).Len(), reflect.ValueOf(v).Cap()) case spannerTypeArrayOfPGNumeric: if reflect.ValueOf(v).IsNil() { return []PGNumeric(nil), nil diff --git a/spanner/value_test.go b/spanner/value_test.go index 01dd88eb3c0a..f8f3cd0a8a66 100644 --- a/spanner/value_test.go +++ b/spanner/value_test.go @@ -205,6 +205,7 @@ func TestEncodeValue(t *testing.T) { type CustomDate civil.Date type CustomNumeric big.Rat type CustomPGNumeric PGNumeric + type CustomPGJSONB PGJsonB type CustomNullString NullString type CustomNullInt64 NullInt64 @@ -258,6 +259,7 @@ func TestEncodeValue(t *testing.T) { tNumeric = numericType() tJSON = jsonType() tPGNumeric = pgNumericType() + tPGJsonb = pgJsonbType() ) for i, test := range []struct { in interface{} @@ -333,6 +335,13 @@ func TestEncodeValue(t *testing.T) { {[]NullJSON{{msg, true}, {msg, false}}, listProto(stringProto(jsonStr), nullProto()), listType(tJSON), "[]NullJSON"}, {NullJSON{[]Message{}, true}, stringProto(emptyArrayJSONStr), tJSON, "a json string with empty array to NullJSON"}, {NullJSON{ptrMsg, true}, stringProto(nullValueJSONStr), tJSON, "a json string with null value to NullJSON"}, + // PG JSONB + {PGJsonB{Value: msg, Valid: true}, stringProto(jsonStr), tPGJsonb, "PGJsonB with value"}, + {PGJsonB{Value: msg, Valid: false}, nullProto(), tPGJsonb, "PGJsonB with null"}, + {[]PGJsonB(nil), nullProto(), listType(tPGJsonb), "null []PGJsonB"}, + {[]PGJsonB{{Value: msg, Valid: true}, {Value: msg, Valid: false}}, listProto(stringProto(jsonStr), nullProto()), listType(tPGJsonb), "[]PGJsonB"}, + {PGJsonB{Value: []Message{}, Valid: true}, stringProto(emptyArrayJSONStr), tPGJsonb, "a json string with empty array to PGJsonB"}, + {PGJsonB{Value: ptrMsg, Valid: true}, stringProto(nullValueJSONStr), tPGJsonb, "a json string with null value to PGJsonB"}, // PG NUMERIC {PGNumeric{"123.456", true}, stringProto("123.456"), tPGNumeric, "PG Numeric"}, {PGNumeric{Valid: false}, nullProto(), tPGNumeric, "PG Numeric with a null value"}, @@ -459,6 +468,11 @@ func TestEncodeValue(t *testing.T) { {CustomPGNumeric{Valid: false}, nullProto(), tPGNumeric, "PG Numeric with a null value"}, {[]CustomPGNumeric(nil), nullProto(), listType(tPGNumeric), "null []PGNumeric"}, {[]CustomPGNumeric{{"123.456", true}, {Valid: false}}, listProto(stringProto("123.456"), nullProto()), listType(tPGNumeric), "[]PGNumeric"}, + // CUSTOM PG JSONB + {CustomPGJSONB{Value: msg, Valid: true}, stringProto(jsonStr), tPGJsonb, "CustomPGJSONB with value"}, + {CustomPGJSONB{Value: msg, Valid: false}, nullProto(), tPGJsonb, "CustomPGJSONB with null"}, + {[]CustomPGJSONB(nil), nullProto(), listType(tPGJsonb), "null []CustomPGJSONB"}, + {[]CustomPGJSONB{{Value: msg, Valid: true}, {Value: msg, Valid: false}}, listProto(stringProto(jsonStr), nullProto()), listType(tPGJsonb), "[]CustomPGJSONB"}, } { got, gotType, err := encodeValue(test.in) if err != nil {