diff --git a/schema/table.go b/schema/table.go index de984eecd1..6aa49fad76 100644 --- a/schema/table.go +++ b/schema/table.go @@ -275,6 +275,31 @@ func (tt Tables) FlattenTables() Tables { return slices.Clip(deduped) } +// UnflattenTables returns a new Tables copy with the relations unflattened. This is the +// opposite operation of FlattenTables. +func (tt Tables) UnflattenTables() (Tables, error) { + tables := make(Tables, 0, len(tt)) + for _, t := range tt { + table := *t + tables = append(tables, &table) + } + topLevel := make([]*Table, 0, len(tt)) + // build relations + for _, table := range tables { + if table.Parent == nil { + topLevel = append(topLevel, table) + continue + } + parent := tables.Get(table.Parent.Name) + if parent == nil { + return nil, fmt.Errorf("parent table %s not found", table.Parent.Name) + } + table.Parent = parent + parent.Relations = append(parent.Relations, table) + } + return slices.Clip(topLevel), nil +} + func (tt Tables) TableNames() []string { ret := []string{} for _, t := range tt { diff --git a/schema/table_test.go b/schema/table_test.go index 25d14deb22..26b10d483c 100644 --- a/schema/table_test.go +++ b/schema/table_test.go @@ -16,6 +16,7 @@ var testTable = &Table{ { Name: "test2", Columns: []Column{}, + Parent: &Table{Name: "test"}, }, }, } @@ -45,6 +46,15 @@ func TestTablesFlatten(t *testing.T) { } } +func TestTablesUnflatten(t *testing.T) { + srcTables := Tables{testTable} + tables, err := srcTables.FlattenTables().UnflattenTables() + require.NoError(t, err) + require.Equal(t, 1, len(srcTables)) // verify that the source Tables were left untouched + require.Equal(t, 1, len(tables)) // verify that the tables are equal to what we started with + require.Equal(t, 1, len(tables[0].Relations)) +} + func TestTablesFilterDFS(t *testing.T) { tests := []struct { name string