diff --git a/docs/json.go b/docs/json.go index 8972a86b8c..922e733be1 100644 --- a/docs/json.go +++ b/docs/json.go @@ -18,10 +18,11 @@ type jsonTable struct { } type jsonColumn struct { - Name string `json:"name"` - Type string `json:"type"` - IsPrimaryKey bool `json:"is_primary_key,omitempty"` - IsIncrementalKey bool `json:"is_incremental_key,omitempty"` + Name string `json:"name"` + Type string `json:"type"` + IsPrimaryKey bool `json:"is_primary_key,omitempty"` + IsPrimaryKeyComponent bool `json:"is_primary_key_component,omitempty"` + IsIncrementalKey bool `json:"is_incremental_key,omitempty"` } func (g *Generator) renderTablesAsJSON(dir string) error { @@ -44,10 +45,11 @@ func (g *Generator) jsonifyTables(tables schema.Tables) []jsonTable { jsonColumns := make([]jsonColumn, len(table.Columns)) for c, col := range table.Columns { jsonColumns[c] = jsonColumn{ - Name: col.Name, - Type: col.Type.String(), - IsPrimaryKey: col.PrimaryKey, - IsIncrementalKey: col.IncrementalKey, + Name: col.Name, + Type: col.Type.String(), + IsPrimaryKey: col.PrimaryKey, + IsPrimaryKeyComponent: col.PrimaryKeyComponent, + IsIncrementalKey: col.IncrementalKey, } } jsonTables[i] = jsonTable{ diff --git a/scheduler/scheduler_test.go b/scheduler/scheduler_test.go index 254c6b7d6f..2db5e5e3b8 100644 --- a/scheduler/scheduler_test.go +++ b/scheduler/scheduler_test.go @@ -101,6 +101,23 @@ func testTableSuccessWithCQIDPK() *schema.Table { } } +func testTableSuccessWithPKComponents() *schema.Table { + cqID := schema.CqIDColumn + cqID.PrimaryKey = true + return &schema.Table{ + Name: "test_table_succes_vpk__cq_id", + Resolver: testResolverSuccess, + Columns: []schema.Column{ + cqID, + { + Name: "test_column", + Type: arrow.PrimitiveTypes.Int64, + PrimaryKeyComponent: true, + }, + }, + } +} + func testTableResolverPanic() *schema.Table { return &schema.Table{ Name: "test_table_resolver_panic", @@ -270,6 +287,16 @@ var syncTestCases = []syncTestCase{ }, deterministicCQID: false, }, + { + table: 
testTableSuccessWithPKComponents(), + data: []scalar.Vector{ + { + // This value will not be validated as it will be randomly set by the scheduler + &scalar.UUID{}, + &scalar.Int{Value: 3, Valid: true}, + }, + }, + }, } func TestScheduler(t *testing.T) { @@ -289,77 +316,6 @@ func TestScheduler(t *testing.T) { } } -func TestScheduler_Cancellation(t *testing.T) { - data := make([]any, 100) - - tests := []struct { - name string - data []any - cancel bool - messageCount int - }{ - { - name: "should consume all message", - data: data, - cancel: false, - messageCount: len(data) + 1, // 9 data + 1 migration message - }, - { - name: "should not consume all message on cancel", - data: data, - cancel: true, - messageCount: len(data) + 1, // 9 data + 1 migration message - }, - } - - for _, strategy := range AllStrategies { - strategy := strategy - for _, tc := range tests { - tc := tc - t.Run(fmt.Sprintf("%s_%s", tc.name, strategy.String()), func(t *testing.T) { - logger := zerolog.New(zerolog.NewTestWriter(t)) - if tc.cancel { - logger = zerolog.Nop() // FIXME without this, zerolog usage causes a race condition when tests are run with `-race -count=100` - } - sc := NewScheduler(WithLogger(logger), WithStrategy(strategy)) - - messages := make(chan message.SyncMessage) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - go func() { - err := sc.Sync( - ctx, - &testExecutionClient{}, - []*schema.Table{testTableSuccessWithData(tc.data)}, - messages, - ) - if tc.cancel { - assert.Equal(t, err, context.Canceled) - } else { - require.NoError(t, err) - } - close(messages) - }() - - messageConsumed := 0 - for range messages { - if tc.cancel { - cancel() - } - messageConsumed++ - } - - if tc.cancel { - assert.NotEqual(t, tc.messageCount, messageConsumed) - } else { - assert.Equal(t, tc.messageCount, messageConsumed) - } - }) - } - } -} - // nolint:revive func testSyncTable(t *testing.T, tc syncTestCase, strategy Strategy, deterministicCQID bool) { ctx := 
context.Background() @@ -411,7 +367,7 @@ func testSyncTable(t *testing.T, tc syncTestCase, strategy Strategy, determinist initialTable := tables.Get(v.Table.Name) pks := migratedTable.PrimaryKeys() - if deterministicCQID && initialTable.Columns.Get(schema.CqIDColumn.Name) != nil { + if (deterministicCQID || len(migratedTable.PrimaryKeyComponents()) > 0) && initialTable.Columns.Get(schema.CqIDColumn.Name) != nil { if len(pks) != 1 { t.Fatalf("expected 1 pk. got %d", len(pks)) } @@ -433,3 +389,74 @@ func testSyncTable(t *testing.T, tc syncTestCase, strategy Strategy, determinist t.Fatalf("expected %d resources. got %d", len(tc.data), i) } } + +func TestScheduler_Cancellation(t *testing.T) { + data := make([]any, 100) + + tests := []struct { + name string + data []any + cancel bool + messageCount int + }{ + { + name: "should consume all message", + data: data, + cancel: false, + messageCount: len(data) + 1, // 100 data + 1 migration message + }, + { + name: "should not consume all message on cancel", + data: data, + cancel: true, + messageCount: len(data) + 1, // 100 data + 1 migration message + }, + } + + for _, strategy := range AllStrategies { + strategy := strategy + for _, tc := range tests { + tc := tc + t.Run(fmt.Sprintf("%s_%s", tc.name, strategy.String()), func(t *testing.T) { + logger := zerolog.New(zerolog.NewTestWriter(t)) + if tc.cancel { + logger = zerolog.Nop() // FIXME without this, zerolog usage causes a race condition when tests are run with `-race -count=100` + } + sc := NewScheduler(WithLogger(logger), WithStrategy(strategy)) + + messages := make(chan message.SyncMessage) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go func() { + err := sc.Sync( + ctx, + &testExecutionClient{}, + []*schema.Table{testTableSuccessWithData(tc.data)}, + messages, + ) + if tc.cancel { + assert.Equal(t, err, context.Canceled) + } else { + require.NoError(t, err) + } + close(messages) + }() + + messageConsumed := 0 + for range messages { + if 
tc.cancel { + cancel() + } + messageConsumed++ + } + + if tc.cancel { + assert.NotEqual(t, tc.messageCount, messageConsumed) + } else { + assert.Equal(t, tc.messageCount, messageConsumed) + } + }) + } + } +} diff --git a/schema/arrow.go b/schema/arrow.go index e882625334..f732f3adb2 100644 --- a/schema/arrow.go +++ b/schema/arrow.go @@ -5,10 +5,11 @@ import ( ) const ( - MetadataUnique = "cq:extension:unique" - MetadataPrimaryKey = "cq:extension:primary_key" - MetadataConstraintName = "cq:extension:constraint_name" - MetadataIncremental = "cq:extension:incremental" + MetadataUnique = "cq:extension:unique" + MetadataPrimaryKey = "cq:extension:primary_key" + MetadataPrimaryKeyComponent = "cq:extension:primary_key_component" + MetadataConstraintName = "cq:extension:constraint_name" + MetadataIncremental = "cq:extension:incremental" MetadataTrue = "true" MetadataFalse = "false" diff --git a/schema/column.go b/schema/column.go index 3fd93e1f47..6c4474cca2 100644 --- a/schema/column.go +++ b/schema/column.go @@ -43,6 +43,9 @@ type Column struct { IncrementalKey bool `json:"incremental_key"` // Unique requires the destinations supporting this to mark this column as unique Unique bool `json:"unique"` + + // PrimaryKeyComponent is a flag that indicates if the column is used as part of the input to calculate the value of `_cq_id`. 
+ PrimaryKeyComponent bool `json:"primary_key_component"` } // NewColumnFromArrowField creates a new Column from an arrow.Field @@ -64,14 +67,18 @@ func NewColumnFromArrowField(f arrow.Field) Column { v, ok = f.Metadata.GetValue(MetadataIncremental) column.IncrementalKey = ok && v == MetadataTrue + v, ok = f.Metadata.GetValue(MetadataPrimaryKeyComponent) + column.PrimaryKeyComponent = ok && v == MetadataTrue + return column } func (c Column) ToArrowField() arrow.Field { mdKV := map[string]string{ - MetadataPrimaryKey: MetadataFalse, - MetadataUnique: MetadataFalse, - MetadataIncremental: MetadataFalse, + MetadataPrimaryKey: MetadataFalse, + MetadataUnique: MetadataFalse, + MetadataIncremental: MetadataFalse, + MetadataPrimaryKeyComponent: MetadataFalse, } if c.PrimaryKey { mdKV[MetadataPrimaryKey] = MetadataTrue @@ -82,6 +89,9 @@ func (c Column) ToArrowField() arrow.Field { if c.IncrementalKey { mdKV[MetadataIncremental] = MetadataTrue } + if c.PrimaryKeyComponent { + mdKV[MetadataPrimaryKeyComponent] = MetadataTrue + } return arrow.Field{ Name: c.Name, @@ -93,13 +103,14 @@ func (c Column) ToArrowField() arrow.Field { func (c Column) MarshalJSON() ([]byte, error) { type Alias struct { - Name string `json:"name"` - Type string `json:"type"` - Description string `json:"description"` - PrimaryKey bool `json:"primary_key"` - NotNull bool `json:"not_null"` - Unique bool `json:"unique"` - IncrementalKey bool `json:"incremental_key"` + Name string `json:"name"` + Type string `json:"type"` + Description string `json:"description"` + PrimaryKey bool `json:"primary_key"` + NotNull bool `json:"not_null"` + Unique bool `json:"unique"` + IncrementalKey bool `json:"incremental_key"` + PrimaryKeyComponent bool `json:"primary_key_component"` } var alias Alias alias.Name = c.Name @@ -109,6 +120,7 @@ func (c Column) MarshalJSON() ([]byte, error) { alias.NotNull = c.NotNull alias.Unique = c.Unique alias.IncrementalKey = c.IncrementalKey + alias.PrimaryKeyComponent = 
c.PrimaryKeyComponent return json.Marshal(alias) } @@ -130,6 +142,10 @@ func (c Column) String() string { if c.IncrementalKey { sb.WriteString(":IncrementalKey") } + + if c.PrimaryKeyComponent { + sb.WriteString(":PrimaryKeyComponent") + } return sb.String() } diff --git a/schema/doc.go b/schema/doc.go deleted file mode 100644 index 1a1354872c..0000000000 --- a/schema/doc.go +++ /dev/null @@ -1,2 +0,0 @@ -// Package schema defines types supported by tables in source plugins -package schema diff --git a/schema/resource.go b/schema/resource.go index fa17c8ada7..3e826ba8ee 100644 --- a/schema/resource.go +++ b/schema/resource.go @@ -3,6 +3,7 @@ package schema import ( "crypto/sha256" "fmt" + "hash" "slices" "github.com/cloudquery/plugin-sdk/v4/scalar" @@ -79,21 +80,34 @@ func (r *Resource) GetValues() scalar.Vector { //nolint:revive func (r *Resource) CalculateCQID(deterministicCQID bool) error { + // if `PrimaryKeyComponent` is set, we calculate the CQID based on those components + pkComponents := r.Table.PrimaryKeyComponents() + if len(pkComponents) > 0 { + return r.storeCQID(uuid.NewSHA1(uuid.UUID{}, calculateCqIDValue(r, pkComponents).Sum(nil))) + } + + // If deterministicCQID is false, we generate a random CQID if !deterministicCQID { return r.storeCQID(uuid.New()) } names := r.Table.PrimaryKeys() + // If there are no primary keys or if CQID is the only PK, we generate a random CQID if len(names) == 0 || (len(names) == 1 && names[0] == CqIDColumn.Name) { return r.storeCQID(uuid.New()) } - slices.Sort(names) + + return r.storeCQID(uuid.NewSHA1(uuid.UUID{}, calculateCqIDValue(r, names).Sum(nil))) +} + +func calculateCqIDValue(r *Resource, cols []string) hash.Hash { h := sha256.New() - for _, name := range names { + slices.Sort(cols) + for _, col := range cols { // We need to include the column name in the hash because the same value can be present in multiple columns and therefore lead to the same hash - h.Write([]byte(name)) - h.Write([]byte(r.Get(name).String())) 
+ h.Write([]byte(col)) + h.Write([]byte(r.Get(col).String())) } - return r.storeCQID(uuid.NewSHA1(uuid.UUID{}, h.Sum(nil))) + return h } func (r *Resource) storeCQID(value uuid.UUID) error { diff --git a/schema/table.go b/schema/table.go index ab058d7219..10e2237454 100644 --- a/schema/table.go +++ b/schema/table.go @@ -589,6 +589,17 @@ func (t *Table) IncrementalKeys() []string { return incrementalKeys } +func (t *Table) PrimaryKeyComponents() []string { + var primaryKeyComponents []string + for _, c := range t.Columns { + if c.PrimaryKeyComponent { + primaryKeyComponents = append(primaryKeyComponents, c.Name) + } + } + + return primaryKeyComponents +} + func (t *Table) TableNames() []string { ret := []string{t.Name} for _, rel := range t.Relations { diff --git a/transformers/struct.go b/transformers/struct.go index 0148cb714d..105924ed6a 100644 --- a/transformers/struct.go +++ b/transformers/struct.go @@ -26,6 +26,8 @@ type structTransformer struct { structFieldsToUnwrap []string pkFields []string pkFieldsFound []string + pkComponentFields []string + pkComponentFieldsFound []string } type NameTransformer func(reflect.StructField) (string, error) @@ -117,6 +119,13 @@ func WithPrimaryKeys(fields ...string) StructTransformerOption { } } +// WithPrimaryKeyComponents allows to specify what struct fields should be used as primary key components +func WithPrimaryKeyComponents(fields ...string) StructTransformerOption { + return func(t *structTransformer) { + t.pkComponentFields = fields + } +} + func TransformWithStruct(st any, opts ...StructTransformerOption) schema.Transform { t := &structTransformer{ nameTransformer: DefaultNameTransformer, @@ -159,6 +168,10 @@ func TransformWithStruct(st any, opts ...StructTransformerOption) schema.Transfo if diff := funk.SubtractString(t.pkFields, t.pkFieldsFound); len(diff) > 0 { return fmt.Errorf("failed to create all of the desired primary keys: %v", diff) } + + if diff := funk.SubtractString(t.pkComponentFields, 
t.pkComponentFieldsFound); len(diff) > 0 { + return fmt.Errorf("failed to find all of the desired primary key components: %v", diff) + } return nil } } @@ -286,6 +299,16 @@ func (t *structTransformer) addColumnFromField(field reflect.StructField, parent } } + for _, pk := range t.pkComponentFields { + if pk == path { + // use path to allow the following + // 1. Don't duplicate the PK fields if the unwrapped struct contains a fields with the same name + // 2. Allow specifying the nested unwrapped field as part of the PK. + column.PrimaryKeyComponent = true + t.pkComponentFieldsFound = append(t.pkComponentFieldsFound, pk) + } + } + t.table.Columns = append(t.table.Columns, column) return nil