diff --git a/schema/table.go b/schema/table.go index 58b4bf3493..a0a0940f23 100644 --- a/schema/table.go +++ b/schema/table.go @@ -7,6 +7,7 @@ import ( "github.com/apache/arrow/go/v13/arrow" "github.com/cloudquery/plugin-sdk/v3/internal/glob" + "golang.org/x/exp/slices" ) // TableResolver is the main entry point when a table is sync is called. @@ -173,9 +174,10 @@ func (tt Tables) FilterDfsFunc(include, exclude func(*Table) bool, skipDependent } func (tt Tables) ToArrowSchemas() Schemas { - schemas := make(Schemas, 0, len(tt.FlattenTables())) - for _, t := range tt.FlattenTables() { - schemas = append(schemas, t.ToArrowSchema()) + flattened := tt.FlattenTables() + schemas := make(Schemas, len(flattened)) + for i, t := range flattened { + schemas[i] = t.ToArrowSchema() } return schemas } @@ -228,19 +230,22 @@ func (tt Tables) FilterDfs(tables, skipTables []string, skipDependentTables bool func (tt Tables) FlattenTables() Tables { tables := make(Tables, 0, len(tt)) for _, t := range tt { - tables = append(tables, t) + table := *t + table.Relations = nil + tables = append(tables, &table) tables = append(tables, t.Relations.FlattenTables()...) } - tableNames := make(map[string]bool) - dedupedTables := make(Tables, 0, len(tables)) + + seen := make(map[string]struct{}) + deduped := make(Tables, 0, len(tables)) for _, t := range tables { - if _, found := tableNames[t.Name]; !found { - dedupedTables = append(dedupedTables, t) - tableNames[t.Name] = true + if _, found := seen[t.Name]; !found { + deduped = append(deduped, t) + seen[t.Name] = struct{}{} } } - return dedupedTables + return slices.Clip(deduped) } func (tt Tables) TableNames() []string { diff --git a/schema/table_test.go b/schema/table_test.go index cc3aadb1d0..c70c2b3f2b 100644 --- a/schema/table_test.go +++ b/schema/table_test.go @@ -5,6 +5,7 @@ import ( "github.com/apache/arrow/go/v13/arrow" "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" ) var testTable = &Table{ @@ -19,11 +20,25 @@ var testTable = &Table{ } func TestTablesFlatten(t *testing.T) { - tables := Tables{testTable}.FlattenTables() - if len(tables) != 2 { - t.Fatal("expected 2 tables") + srcTables := Tables{testTable} + tables := srcTables.FlattenTables() + require.Equal(t, 1, len(srcTables)) // verify that the source Tables were left untouched + require.Equal(t, 1, len(testTable.Relations)) + require.Equal(t, 2, len(tables)) + for _, table := range tables { + require.Nil(t, table.Relations) } - tables = Tables{testTable}.FlattenTables() + + srcTables = Tables{testTable, testTable} + tables = srcTables.FlattenTables() + require.Equal(t, 2, len(srcTables)) // verify that the source Tables were left untouched + require.Equal(t, 1, len(testTable.Relations)) + require.Equal(t, 2, len(tables)) + for _, table := range tables { + require.Nil(t, table.Relations) + } + + tables = tables.FlattenTables() if len(tables) != 2 { t.Fatal("expected 2 tables") }