Skip to content

Commit

Permalink
Fix circular dependencies (#1796)
Browse files Browse the repository at this point in the history
  • Loading branch information
alishakawaguchi committed Apr 20, 2024
1 parent 1e543fd commit c9ef2f9
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 30 deletions.
30 changes: 23 additions & 7 deletions backend/pkg/table-dependency/table-dependency.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,13 @@ type ConstraintColumns struct {
}

func GetRunConfigs(dependencies dbschemas.TableDependency, tables []string, subsets map[string]string) []*RunConfig {
depsMap := map[string][]string{}
filteredDepsMap := map[string][]string{} // only include tables that are in tables arg list
foreignKeyMap := map[string]map[string]string{} // map: table -> foreign key table -> foreign key column
configs := []*RunConfig{}

for table, constraints := range dependencies {
foreignKeyMap[table] = map[string]string{}
for _, constraint := range constraints.Constraints {
depsMap[table] = append(depsMap[table], constraint.ForeignKey.Table)
foreignKeyMap[table][constraint.ForeignKey.Table] = constraint.ForeignKey.Column
if slices.Contains(tables, table) && slices.Contains(tables, constraint.ForeignKey.Table) {
filteredDepsMap[table] = append(filteredDepsMap[table], constraint.ForeignKey.Table)
Expand Down Expand Up @@ -340,10 +338,16 @@ func getMultiTableCircularDependencies(dependencyMap map[string][]string) [][]st
return multiTableCycles
}

func GetTablesOrderedByDependency(dependencyMap map[string][]string) ([]string, error) {
// OrderedTablesResult is the value returned by GetTablesOrderedByDependency.
// It carries the computed table ordering together with a flag telling the
// caller whether multi-table circular dependencies were detected, so the
// caller can decide whether that ordering is acceptable for its use case.
type OrderedTablesResult struct {
// OrderedTables lists the tables in the order in which
// GetTablesOrderedByDependency appended them.
OrderedTables []string
// HasCycles is true when getMultiTableCircularDependencies reported at
// least one cycle for the dependency map that was ordered.
HasCycles bool
}

func GetTablesOrderedByDependency(dependencyMap map[string][]string) (*OrderedTablesResult, error) {
hasCycles := false
cycles := getMultiTableCircularDependencies(dependencyMap)
if len(cycles) > 0 {
return nil, fmt.Errorf("unable to handle circular dependencies: %+v", cycles)
hasCycles = true
}

tableMap := map[string]struct{}{}
Expand All @@ -370,19 +374,31 @@ func GetTablesOrderedByDependency(dependencyMap map[string][]string) ([]string,
prevTableLen = len(tableMap)
for table := range tableMap {
deps := dependencyMap[table]
if isReady(seenTables, deps, table) {
if isReady(seenTables, deps, table, cycles) {
orderedTables = append(orderedTables, table)
seenTables[table] = struct{}{}
delete(tableMap, table)
}
}
}

return orderedTables, nil
return &OrderedTablesResult{OrderedTables: orderedTables, HasCycles: hasCycles}, nil
}

func isReady(seen map[string]struct{}, deps []string, table string) bool {
func isReady(seen map[string]struct{}, deps []string, table string, cycles [][]string) bool {
// allow circular dependencies
circularDeps := getTableCirularDependencies(table, cycles)
circularDepsMap := map[string]struct{}{}
for _, cycle := range circularDeps {
for _, t := range cycle {
circularDepsMap[t] = struct{}{}
}
}
for _, d := range deps {
_, cdOk := circularDepsMap[d]
if cdOk {
return true
}
_, ok := seen[d]
// allow self dependencies
if !ok && d != table {
Expand Down
22 changes: 15 additions & 7 deletions backend/pkg/table-dependency/table-dependency_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -694,8 +694,12 @@ func Test_GetTablesOrderedByDependency_CircularDependency(t *testing.T) {
"c": {"a"},
}

_, err := GetTablesOrderedByDependency(dependencies)
assert.Error(t, err)
resp, err := GetTablesOrderedByDependency(dependencies)
assert.NoError(t, err)
assert.Equal(t, resp.HasCycles, true)
for _, e := range resp.OrderedTables {
assert.Contains(t, []string{"a", "b", "c"}, e)
}
}

func Test_GetTablesOrderedByDependency_Dependencies(t *testing.T) {
Expand All @@ -712,7 +716,9 @@ func Test_GetTablesOrderedByDependency_Dependencies(t *testing.T) {

actual, err := GetTablesOrderedByDependency(dependencies)
assert.NoError(t, err)
for idx, table := range actual {
assert.Equal(t, actual.HasCycles, false)

for idx, table := range actual.OrderedTables {
assert.Contains(t, expected[idx], table)
}
}
Expand All @@ -728,11 +734,12 @@ func Test_GetTablesOrderedByDependency_Mixed(t *testing.T) {
expected := []string{"countries", "regions", "jobs", "locations"}
actual, err := GetTablesOrderedByDependency(dependencies)
assert.NoError(t, err)
assert.Len(t, actual, len(expected))
for _, table := range actual {
assert.Equal(t, actual.HasCycles, false)
assert.Len(t, actual.OrderedTables, len(expected))
for _, table := range actual.OrderedTables {
assert.Contains(t, expected, table)
}
assert.Equal(t, "locations", actual[len(actual)-1])
assert.Equal(t, "locations", actual.OrderedTables[len(actual.OrderedTables)-1])
}

func Test_GetTablesOrderedByDependency_BrokenDependencies_NoLoop(t *testing.T) {
Expand All @@ -758,7 +765,8 @@ func Test_GetTablesOrderedByDependency_NestedDependencies(t *testing.T) {
expected := []string{"d", "c", "b", "a"}
actual, err := GetTablesOrderedByDependency(dependencies)
assert.NoError(t, err)
assert.Equal(t, expected[0], actual[0])
assert.Equal(t, expected[0], actual.OrderedTables[0])
assert.Equal(t, actual.HasCycles, false)
}

func TestCycleKey(t *testing.T) {
Expand Down
18 changes: 12 additions & 6 deletions cli/internal/cmds/neosync/sync/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -560,12 +560,15 @@ func runDestinationInitStatements(ctx context.Context, cmd *cmdConfig, syncConfi
}
defer pool.Close()
if cmd.Destination.InitSchema {
orderedTables, err := tabledependency.GetTablesOrderedByDependency(dependencyMap)
orderedTablesResp, err := tabledependency.GetTablesOrderedByDependency(dependencyMap)
if err != nil {
return err
}
if orderedTablesResp.HasCycles {
return errors.New("init schema: unable to handle circular dependencies")
}
orderedInitStatements := []string{}
for _, t := range orderedTables {
for _, t := range orderedTablesResp.OrderedTables {
orderedInitStatements = append(orderedInitStatements, schemaConfig.InitTableStatementsMap[t])
}
err = dbschemas_postgres.BatchExecStmts(ctx, pool, batchSize, orderedInitStatements)
Expand All @@ -589,25 +592,28 @@ func runDestinationInitStatements(ctx context.Context, cmd *cmdConfig, syncConfi
return err
}
} else if cmd.Destination.TruncateBeforeInsert {
orderedTables, err := tabledependency.GetTablesOrderedByDependency(dependencyMap)
orderedTablesResp, err := tabledependency.GetTablesOrderedByDependency(dependencyMap)
if err != nil {
return err
}
orderedTruncateStatement := dbschemas_postgres.BuildTruncateStatement(orderedTables)
orderedTruncateStatement := dbschemas_postgres.BuildTruncateStatement(orderedTablesResp.OrderedTables)
err = dbschemas_postgres.BatchExecStmts(ctx, pool, batchSize, []string{orderedTruncateStatement})
if err != nil {
fmt.Println("Error truncating tables:", err) //nolint:forbidigo
return err
}
}
} else if cmd.Destination.Driver == mysqlDriver {
orderedTables, err := tabledependency.GetTablesOrderedByDependency(dependencyMap)
orderedTablesResp, err := tabledependency.GetTablesOrderedByDependency(dependencyMap)
if err != nil {
return err
}
if cmd.Destination.InitSchema && orderedTablesResp.HasCycles {
return errors.New("init schema: unable to handle circular dependencies")
}
orderedInitStatements := []string{}
orderedTableTruncateStatements := []string{}
for _, t := range orderedTables {
for _, t := range orderedTablesResp.OrderedTables {
orderedInitStatements = append(orderedInitStatements, schemaConfig.InitTableStatementsMap[t])
orderedTableTruncateStatements = append(orderedTableTruncateStatements, schemaConfig.TruncateTableStatementsMap[t])
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,8 @@ func createSqlUpdateBenthosConfig(
newResp.metriclabels = append(newResp.metriclabels, metrics.NewEqLabel(metrics.IsUpdateConfigLabel, "true"))
output := buildOutputArgs(newResp, tm)
newResp.Columns = output.Columns
// add self to dependency so that insert always runs before update
newResp.DependsOn = append(newResp.DependsOn, &tabledependency.DependsOn{Table: insertConfig.updateConfig.Table, Columns: output.WhereCols})
if newResp.Config.Input.SqlSelect != nil {
newResp.Config.Input.SqlSelect.Where = insertConfig.Config.Input.SqlSelect.Where // keep the where clause the same as insert
} else if newResp.Config.Input.PooledSqlRaw != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1999,7 +1999,7 @@ output:
updateConfig := getBenthosConfigByName(resp.BenthosConfigs, "public.users.update")
assert.NotNil(t, updateConfig)
assert.Equal(t, updateConfig.Name, "public.users.update")
assert.Equal(t, updateConfig.DependsOn, []*tabledependency.DependsOn{{Table: "public.user_account_associations", Columns: []string{"id"}}})
assert.Equal(t, updateConfig.DependsOn, []*tabledependency.DependsOn{{Table: "public.user_account_associations", Columns: []string{"id"}}, {Table: "public.users", Columns: []string{"id"}}})
out1, err := yaml.Marshal(updateConfig.Config)
assert.NoError(t, err)
assert.Equal(
Expand Down Expand Up @@ -2427,7 +2427,7 @@ output:
updateConfig := getBenthosConfigByName(resp.BenthosConfigs, "public.users.update")
assert.NotNil(t, updateConfig)
assert.Equal(t, updateConfig.Name, "public.users.update")
assert.Equal(t, updateConfig.DependsOn, []*tabledependency.DependsOn{{Table: "public.user_account_associations", Columns: []string{"id"}}})
assert.Equal(t, updateConfig.DependsOn, []*tabledependency.DependsOn{{Table: "public.user_account_associations", Columns: []string{"id"}}, {Table: "public.users", Columns: []string{"id"}}})
out1, err := yaml.Marshal(updateConfig.Config)

assert.NoError(t, err)
Expand Down Expand Up @@ -3069,7 +3069,7 @@ func Test_BenthosBuilder_GenerateBenthosConfigs_Basic_Mysql_Mysql_With_Circular_
updateConfig := getBenthosConfigByName(resp.BenthosConfigs, "public.users.update")
assert.NotNil(t, updateConfig)
assert.Equal(t, updateConfig.Name, "public.users.update")
assert.Equal(t, updateConfig.DependsOn, []*tabledependency.DependsOn{{Table: "public.user_account_associations", Columns: []string{"id"}}})
assert.Equal(t, updateConfig.DependsOn, []*tabledependency.DependsOn{{Table: "public.user_account_associations", Columns: []string{"id"}}, {Table: "public.users", Columns: []string{"id"}}})
out1, err := yaml.Marshal(updateConfig.Config)
assert.NoError(t, err)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,16 @@ func (b *initStatementBuilder) RunSqlInitTableStatements(
// create statements
if initSchema {
tableForeignDependencyMap := getFilteredForeignToPrimaryTableMap(tableDependencies, uniqueTables)
orderedTables, err := tabledependency.GetTablesOrderedByDependency(tableForeignDependencyMap)
orderedTablesResp, err := tabledependency.GetTablesOrderedByDependency(tableForeignDependencyMap)
if err != nil {
return nil, err
}
if orderedTablesResp.HasCycles {
return nil, errors.New("init schema: unable to handle circular dependencies")
}

tableCreateStmts := []string{}
for _, table := range orderedTables {
for _, table := range orderedTablesResp.OrderedTables {
split := strings.Split(table, ".")
// todo: make this more efficient to reduce amount of times we have to connect to the source database
initStmt, err := b.getCreateStatementFromPostgres(
Expand Down Expand Up @@ -259,13 +263,13 @@ func (b *initStatementBuilder) RunSqlInitTableStatements(
}
} else if truncateBeforeInsert {
tablePrimaryDependencyMap := getFilteredForeignToPrimaryTableMap(tableDependencies, uniqueTables)
orderedTables, err := tabledependency.GetTablesOrderedByDependency(tablePrimaryDependencyMap)
orderedTablesResp, err := tabledependency.GetTablesOrderedByDependency(tablePrimaryDependencyMap)
if err != nil {
return nil, err
}

orderedTableTruncate := []string{}
for _, table := range orderedTables {
for _, table := range orderedTablesResp.OrderedTables {
split := strings.Split(table, ".")
orderedTableTruncate = append(orderedTableTruncate, fmt.Sprintf(`%q.%q`, split[0], split[1]))
}
Expand Down Expand Up @@ -314,13 +318,16 @@ func (b *initStatementBuilder) RunSqlInitTableStatements(
// create statements
if initSchema {
tableForeignDependencyMap := getFilteredForeignToPrimaryTableMap(tableDependencies, uniqueTables)
orderedTables, err := tabledependency.GetTablesOrderedByDependency(tableForeignDependencyMap)
orderedTablesResp, err := tabledependency.GetTablesOrderedByDependency(tableForeignDependencyMap)
if err != nil {
return nil, err
}
if orderedTablesResp.HasCycles {
return nil, errors.New("init schema: unable to handle circular dependencies")
}
// todo: make this more efficient to reduce amount of times we have to connect to the source database
tableCreateStmts := []string{}
for _, table := range orderedTables {
for _, table := range orderedTablesResp.OrderedTables {
split := strings.Split(table, ".")
initStmt, err := b.getCreateStatementFromMysql(
ctx,
Expand Down
3 changes: 2 additions & 1 deletion worker/pkg/workflows/datasync/workflow/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ func Workflow(wfctx workflow.Context, req *WorkflowRequest) (*WorkflowResponse,
if err != nil {
return nil, err
}
logger.Info("completed RetrieveActivityOptions.")

ctx = workflow.WithActivityOptions(wfctx, *actOptResp.SyncActivityOptions)
logger.Info("scheduling RunSqlInitTableStatements for execution.")
Expand All @@ -89,6 +90,7 @@ func Workflow(wfctx workflow.Context, req *WorkflowRequest) (*WorkflowResponse,
if err != nil {
return nil, err
}
logger.Info("completed RunSqlInitTableStatements.")

started := sync.Map{}
completed := sync.Map{}
Expand Down Expand Up @@ -272,7 +274,6 @@ func invokeSync(
settable.SetError(fmt.Errorf("unable to marshal benthos config: %w", err))
return
}

logger.Info("scheduling Sync for execution.")

var result sync_activity.SyncResponse
Expand Down

0 comments on commit c9ef2f9

Please sign in to comment.