From 227444282636c04edac883d0771ba695f4f0c86e Mon Sep 17 00:00:00 2001 From: Nick Zelei <2420177+nickzelei@users.noreply.github.com> Date: Wed, 24 Apr 2024 13:58:49 -0700 Subject: [PATCH] Nick/neos 1034 update postgres constraint queries to not use information schema (#1829) --- .../db/dbschemas/postgresql/mock_Querier.go | 164 +--- .../gen/go/db/dbschemas/postgresql/querier.go | 4 +- .../go/db/dbschemas/postgresql/system.sql.go | 277 +++--- backend/pkg/dbschemas/postgres/postgres.go | 256 +++-- .../pkg/dbschemas/postgres/postgres_test.go | 70 +- .../sql/postgresql/queries/system.sql | 152 ++- backend/pkg/sqlmanager/postgres-manager.go | 49 +- .../connection-data.go | 29 +- .../connection-data_test.go | 156 +-- backend/sqlc.yaml | 98 +- cli/go.mod | 1 + cli/internal/cmds/neosync/sync/sync.go | 12 +- .../gen-benthos-configs/benthos-builder.go | 2 +- .../benthos-builder_test.go | 900 +++++++++--------- .../init-statement-builder.go | 8 +- .../init-statement-builder_test.go | 44 +- 16 files changed, 1054 insertions(+), 1168 deletions(-) diff --git a/backend/gen/go/db/dbschemas/postgresql/mock_Querier.go b/backend/gen/go/db/dbschemas/postgresql/mock_Querier.go index c00b106b9..696089c24 100644 --- a/backend/gen/go/db/dbschemas/postgresql/mock_Querier.go +++ b/backend/gen/go/db/dbschemas/postgresql/mock_Querier.go @@ -140,66 +140,6 @@ func (_c *MockQuerier_GetDatabaseTableSchema_Call) RunAndReturn(run func(context return _c } -// GetForeignKeyConstraints provides a mock function with given fields: ctx, db, tableschema -func (_m *MockQuerier) GetForeignKeyConstraints(ctx context.Context, db DBTX, tableschema string) ([]*GetForeignKeyConstraintsRow, error) { - ret := _m.Called(ctx, db, tableschema) - - if len(ret) == 0 { - panic("no return value specified for GetForeignKeyConstraints") - } - - var r0 []*GetForeignKeyConstraintsRow - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, DBTX, string) ([]*GetForeignKeyConstraintsRow, error)); ok { - return rf(ctx, db, tableschema) - } - if rf, ok := ret.Get(0).(func(context.Context, DBTX, string) []*GetForeignKeyConstraintsRow); ok { - r0 = rf(ctx, db, tableschema) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]*GetForeignKeyConstraintsRow) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, DBTX, string) error); ok { - r1 = rf(ctx, db, tableschema) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockQuerier_GetForeignKeyConstraints_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetForeignKeyConstraints' -type MockQuerier_GetForeignKeyConstraints_Call struct { - *mock.Call -} - -// GetForeignKeyConstraints is a helper method to define mock.On call -// - ctx context.Context -// - db DBTX -// - tableschema string -func (_e *MockQuerier_Expecter) GetForeignKeyConstraints(ctx interface{}, db interface{}, tableschema interface{}) *MockQuerier_GetForeignKeyConstraints_Call { - return &MockQuerier_GetForeignKeyConstraints_Call{Call: _e.mock.On("GetForeignKeyConstraints", ctx, db, tableschema)} -} - -func (_c *MockQuerier_GetForeignKeyConstraints_Call) Run(run func(ctx context.Context, db DBTX, tableschema string)) *MockQuerier_GetForeignKeyConstraints_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(DBTX), args[2].(string)) - }) - return _c -} - -func (_c *MockQuerier_GetForeignKeyConstraints_Call) Return(_a0 []*GetForeignKeyConstraintsRow, _a1 error) *MockQuerier_GetForeignKeyConstraints_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockQuerier_GetForeignKeyConstraints_Call) RunAndReturn(run func(context.Context, DBTX, string) ([]*GetForeignKeyConstraintsRow, error)) *MockQuerier_GetForeignKeyConstraints_Call { - _c.Call.Return(run) - return _c -} - // GetPostgresRolePermissions provides a mock function with given fields: ctx, db, role func (_m *MockQuerier) GetPostgresRolePermissions(ctx context.Context, db DBTX, role interface{}) ([]*GetPostgresRolePermissionsRow, error) { ret := _m.Called(ctx, db, role) @@ -260,66 +200,6 @@ func (_c *MockQuerier_GetPostgresRolePermissions_Call) RunAndReturn(run func(con return _c } -// GetPrimaryKeyConstraints provides a mock function with given fields: ctx, db, tableschema -func (_m *MockQuerier) GetPrimaryKeyConstraints(ctx context.Context, db DBTX, tableschema string) ([]*GetPrimaryKeyConstraintsRow, error) { - ret := _m.Called(ctx, db, tableschema) - - if len(ret) == 0 { - panic("no return value specified for GetPrimaryKeyConstraints") - } - - var r0 []*GetPrimaryKeyConstraintsRow - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, DBTX, string) ([]*GetPrimaryKeyConstraintsRow, error)); ok { - return rf(ctx, db, tableschema) - } - if rf, ok := ret.Get(0).(func(context.Context, DBTX, string) []*GetPrimaryKeyConstraintsRow); ok { - r0 = rf(ctx, db, tableschema) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]*GetPrimaryKeyConstraintsRow) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, DBTX, string) error); ok { - r1 = rf(ctx, db, tableschema) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockQuerier_GetPrimaryKeyConstraints_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetPrimaryKeyConstraints' -type MockQuerier_GetPrimaryKeyConstraints_Call struct { - *mock.Call -} - -// GetPrimaryKeyConstraints is a helper method to define mock.On call -// - ctx context.Context -// - db DBTX -// - tableschema string -func (_e *MockQuerier_Expecter) GetPrimaryKeyConstraints(ctx interface{}, db interface{}, tableschema interface{}) *MockQuerier_GetPrimaryKeyConstraints_Call { - return &MockQuerier_GetPrimaryKeyConstraints_Call{Call: _e.mock.On("GetPrimaryKeyConstraints", ctx, db, tableschema)} -} - -func (_c *MockQuerier_GetPrimaryKeyConstraints_Call) Run(run func(ctx context.Context, db DBTX, tableschema string)) *MockQuerier_GetPrimaryKeyConstraints_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(DBTX), args[2].(string)) - }) - return _c -} - -func (_c *MockQuerier_GetPrimaryKeyConstraints_Call) Return(_a0 []*GetPrimaryKeyConstraintsRow, _a1 error) *MockQuerier_GetPrimaryKeyConstraints_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockQuerier_GetPrimaryKeyConstraints_Call) RunAndReturn(run func(context.Context, DBTX, string) ([]*GetPrimaryKeyConstraintsRow, error)) *MockQuerier_GetPrimaryKeyConstraints_Call { - _c.Call.Return(run) - return _c -} - // GetTableConstraints provides a mock function with given fields: ctx, db, arg func (_m *MockQuerier) GetTableConstraints(ctx context.Context, db DBTX, arg *GetTableConstraintsParams) ([]*GetTableConstraintsRow, error) { ret := _m.Called(ctx, db, arg) @@ -380,29 +260,29 @@ func (_c *MockQuerier_GetTableConstraints_Call) RunAndReturn(run func(context.Co return _c } -// GetUniqueConstraints provides a mock function with given fields: ctx, db, tableschema -func (_m *MockQuerier) GetUniqueConstraints(ctx context.Context, db DBTX, tableschema string) ([]*GetUniqueConstraintsRow, error) { - ret := _m.Called(ctx, db, tableschema) +// GetTableConstraintsBySchema provides a mock function with given fields: ctx, db, schema +func (_m *MockQuerier) GetTableConstraintsBySchema(ctx context.Context, db DBTX, schema []string) ([]*GetTableConstraintsBySchemaRow, error) { + ret := _m.Called(ctx, db, schema) if len(ret) == 0 { - panic("no return value specified for GetUniqueConstraints") + panic("no return value specified for GetTableConstraintsBySchema") } - var r0 []*GetUniqueConstraintsRow + var r0 []*GetTableConstraintsBySchemaRow var r1 error - if rf, ok := ret.Get(0).(func(context.Context, DBTX, string) ([]*GetUniqueConstraintsRow, error)); ok { - return rf(ctx, db, tableschema) + if rf, ok := ret.Get(0).(func(context.Context, DBTX, []string) ([]*GetTableConstraintsBySchemaRow, error)); ok { + return rf(ctx, db, schema) } - if rf, ok := ret.Get(0).(func(context.Context, DBTX, string) []*GetUniqueConstraintsRow); ok { - r0 = rf(ctx, db, tableschema) + if rf, ok := ret.Get(0).(func(context.Context, DBTX, []string) []*GetTableConstraintsBySchemaRow); ok { + r0 = rf(ctx, db, schema) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).([]*GetUniqueConstraintsRow) + r0 = ret.Get(0).([]*GetTableConstraintsBySchemaRow) } } - if rf, ok := ret.Get(1).(func(context.Context, DBTX, string) error); ok { - r1 = rf(ctx, db, tableschema) + if rf, ok := ret.Get(1).(func(context.Context, DBTX, []string) error); ok { + r1 = rf(ctx, db, schema) } else { r1 = ret.Error(1) } @@ -410,32 +290,32 @@ func (_m *MockQuerier) GetUniqueConstraints(ctx context.Context, db DBTX, tables return r0, r1 } -// MockQuerier_GetUniqueConstraints_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetUniqueConstraints' -type MockQuerier_GetUniqueConstraints_Call struct { +// MockQuerier_GetTableConstraintsBySchema_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetTableConstraintsBySchema' +type MockQuerier_GetTableConstraintsBySchema_Call struct { *mock.Call } -// GetUniqueConstraints is a helper method to define mock.On call +// GetTableConstraintsBySchema is a helper method to define mock.On call // - ctx context.Context // - db DBTX -// - tableschema string -func (_e *MockQuerier_Expecter) GetUniqueConstraints(ctx interface{}, db interface{}, tableschema interface{}) *MockQuerier_GetUniqueConstraints_Call { - return &MockQuerier_GetUniqueConstraints_Call{Call: _e.mock.On("GetUniqueConstraints", ctx, db, tableschema)} +// - schema []string +func (_e *MockQuerier_Expecter) GetTableConstraintsBySchema(ctx interface{}, db interface{}, schema interface{}) *MockQuerier_GetTableConstraintsBySchema_Call { + return &MockQuerier_GetTableConstraintsBySchema_Call{Call: _e.mock.On("GetTableConstraintsBySchema", ctx, db, schema)} } -func (_c *MockQuerier_GetUniqueConstraints_Call) Run(run func(ctx context.Context, db DBTX, tableschema string)) *MockQuerier_GetUniqueConstraints_Call { +func (_c *MockQuerier_GetTableConstraintsBySchema_Call) Run(run func(ctx context.Context, db DBTX, schema []string)) *MockQuerier_GetTableConstraintsBySchema_Call { _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(DBTX), args[2].(string)) + run(args[0].(context.Context), args[1].(DBTX), args[2].([]string)) }) return _c } -func (_c *MockQuerier_GetUniqueConstraints_Call) Return(_a0 []*GetUniqueConstraintsRow, _a1 error) *MockQuerier_GetUniqueConstraints_Call { +func (_c *MockQuerier_GetTableConstraintsBySchema_Call) Return(_a0 []*GetTableConstraintsBySchemaRow, _a1 error) *MockQuerier_GetTableConstraintsBySchema_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *MockQuerier_GetUniqueConstraints_Call) RunAndReturn(run func(context.Context, DBTX, string) ([]*GetUniqueConstraintsRow, error)) *MockQuerier_GetUniqueConstraints_Call { +func (_c *MockQuerier_GetTableConstraintsBySchema_Call) RunAndReturn(run func(context.Context, DBTX, []string) ([]*GetTableConstraintsBySchemaRow, error)) *MockQuerier_GetTableConstraintsBySchema_Call { _c.Call.Return(run) return _c } diff --git a/backend/gen/go/db/dbschemas/postgresql/querier.go b/backend/gen/go/db/dbschemas/postgresql/querier.go index 0f9413eaf..fe4589b42 100644 --- a/backend/gen/go/db/dbschemas/postgresql/querier.go +++ b/backend/gen/go/db/dbschemas/postgresql/querier.go @@ -11,11 +11,9 @@ import ( type Querier interface { GetDatabaseSchema(ctx context.Context, db DBTX) ([]*GetDatabaseSchemaRow, error) GetDatabaseTableSchema(ctx context.Context, db DBTX, arg *GetDatabaseTableSchemaParams) ([]*GetDatabaseTableSchemaRow, error) - GetForeignKeyConstraints(ctx context.Context, db DBTX, tableschema string) ([]*GetForeignKeyConstraintsRow, error) GetPostgresRolePermissions(ctx context.Context, db DBTX, role interface{}) ([]*GetPostgresRolePermissionsRow, error) - GetPrimaryKeyConstraints(ctx context.Context, db DBTX, tableschema string) ([]*GetPrimaryKeyConstraintsRow, error) GetTableConstraints(ctx context.Context, db DBTX, arg *GetTableConstraintsParams) ([]*GetTableConstraintsRow, error) - GetUniqueConstraints(ctx context.Context, db DBTX, tableschema string) ([]*GetUniqueConstraintsRow, error) + GetTableConstraintsBySchema(ctx context.Context, db DBTX, schema []string) ([]*GetTableConstraintsBySchemaRow, error) } var _ Querier = (*Queries)(nil) diff --git a/backend/gen/go/db/dbschemas/postgresql/system.sql.go b/backend/gen/go/db/dbschemas/postgresql/system.sql.go index 074e4614a..2cb6af229 100644 --- a/backend/gen/go/db/dbschemas/postgresql/system.sql.go +++ b/backend/gen/go/db/dbschemas/postgresql/system.sql.go @@ -231,89 +231,18 @@ func (q *Queries) GetDatabaseTableSchema(ctx context.Context, db DBTX, arg *GetD return items, nil } -const getForeignKeyConstraints = `-- name: GetForeignKeyConstraints :many -SELECT - rc.constraint_name, - rc.constraint_schema AS schema_name, - fk.table_name, - fk.column_name, - c.is_nullable, - pk.table_schema AS foreign_schema_name, - pk.table_name AS foreign_table_name, - pk.column_name AS foreign_column_name -FROM - information_schema.referential_constraints rc -JOIN information_schema.key_column_usage fk ON - fk.constraint_catalog = rc.constraint_catalog AND - fk.constraint_schema = rc.constraint_schema AND - fk.constraint_name = rc.constraint_name -JOIN information_schema.key_column_usage pk ON - pk.constraint_catalog = rc.unique_constraint_catalog AND - pk.constraint_schema = rc.unique_constraint_schema AND - pk.constraint_name = rc.unique_constraint_name -JOIN information_schema.columns c ON - c.table_schema = fk.table_schema AND - c.table_name = fk.table_name AND - c.column_name = fk.column_name -WHERE - rc.constraint_schema = $1 -ORDER BY - rc.constraint_name, - fk.ordinal_position -` - -type GetForeignKeyConstraintsRow struct { - ConstraintName string - SchemaName string - TableName string - ColumnName string - IsNullable string - ForeignSchemaName string - ForeignTableName string - ForeignColumnName string -} - -func (q *Queries) GetForeignKeyConstraints(ctx context.Context, db DBTX, tableschema string) ([]*GetForeignKeyConstraintsRow, error) { - rows, err := db.Query(ctx, getForeignKeyConstraints, tableschema) - if err != nil { - return nil, err - } - defer rows.Close() - var items []*GetForeignKeyConstraintsRow - for rows.Next() { - var i GetForeignKeyConstraintsRow - if err := rows.Scan( - &i.ConstraintName, - &i.SchemaName, - &i.TableName, - &i.ColumnName, - &i.IsNullable, - &i.ForeignSchemaName, - &i.ForeignTableName, - &i.ForeignColumnName, - ); err != nil { - return nil, err - } - items = append(items, &i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - const getPostgresRolePermissions = `-- name: GetPostgresRolePermissions :many SELECT - rtg.table_schema as table_schema, - rtg.table_name as table_name, + rtg.table_schema as table_schema, + rtg.table_name as table_name, rtg.privilege_type as privilege_type -FROM +FROM information_schema.role_table_grants as rtg -WHERE - table_schema NOT IN ('pg_catalog', 'information_schema') +WHERE + table_schema NOT IN ('pg_catalog', 'information_schema') AND grantee = $1 -ORDER BY - table_schema, +ORDER BY + table_schema, table_name ` @@ -343,73 +272,42 @@ func (q *Queries) GetPostgresRolePermissions(ctx context.Context, db DBTX, role return items, nil } -const getPrimaryKeyConstraints = `-- name: GetPrimaryKeyConstraints :many -SELECT - tc.table_schema AS schema_name, - tc.table_name as table_name, - tc.constraint_name as constraint_name, - kcu.column_name as column_name -FROM - information_schema.table_constraints AS tc -JOIN information_schema.key_column_usage AS kcu - ON tc.constraint_name = kcu.constraint_name - AND tc.table_schema = kcu.table_schema -WHERE - tc.table_schema = $1 - AND tc.constraint_type = 'PRIMARY KEY' -ORDER BY - tc.table_name, - kcu.column_name -` - -type GetPrimaryKeyConstraintsRow struct { - SchemaName string - TableName string - ConstraintName string - ColumnName string -} - -func (q *Queries) GetPrimaryKeyConstraints(ctx context.Context, db DBTX, tableschema string) ([]*GetPrimaryKeyConstraintsRow, error) { - rows, err := db.Query(ctx, getPrimaryKeyConstraints, tableschema) - if err != nil { - return nil, err - } - defer rows.Close() - var items []*GetPrimaryKeyConstraintsRow - for rows.Next() { - var i GetPrimaryKeyConstraintsRow - if err := rows.Scan( - &i.SchemaName, - &i.TableName, - &i.ConstraintName, - &i.ColumnName, - ); err != nil { - return nil, err - } - items = append(items, &i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - const getTableConstraints = `-- name: GetTableConstraints :many SELECT - nsp.nspname AS db_schema, - rel.relname AS table_name, con.conname AS constraint_name, - pg_get_constraintdef(con.oid) AS constraint_definition + con.contype::text AS constraint_type, + con.connamespace::regnamespace::text AS schema_name, + con.conrelid::regclass::text AS table_name, + CASE + WHEN con.contype IN ('f', 'p', 'u') THEN array_agg(att.attname) + ELSE NULL + END::text[] AS constraint_columns, + array_agg(att.attnotnull)::bool[] AS notnullable, + CASE + WHEN con.contype = 'f' THEN fn_cl.relnamespace::regnamespace::text + ELSE '' + END AS foreign_schema_name, + CASE + WHEN con.contype = 'f' THEN con.confrelid::regclass::text + ELSE '' + END AS foreign_table_name, + CASE + WHEN con.contype = 'f' THEN array_agg(fk_att.attname)::text[] + ELSE NULL::text[] + END AS foreign_column_names, + pg_get_constraintdef(con.oid)::text AS constraint_definition FROM pg_catalog.pg_constraint con -INNER JOIN pg_catalog.pg_class rel - ON - rel.oid = con.conrelid -INNER JOIN pg_catalog.pg_namespace nsp - ON - nsp.oid = connamespace +LEFT JOIN + pg_catalog.pg_attribute fk_att ON fk_att.attrelid = con.confrelid AND fk_att.attnum = ANY(con.confkey) +LEFT JOIN + pg_catalog.pg_attribute att ON att.attrelid = con.conrelid AND att.attnum = ANY(con.conkey) +LEFT JOIN + pg_catalog.pg_class fn_cl ON fn_cl.oid = con.confrelid WHERE - nsp.nspname = $1 AND rel.relname = $2 + con.connamespace::regnamespace::text = $1 AND con.conrelid::regclass::text = $2 +GROUP BY + con.oid, con.conname, con.conrelid, fn_cl.relnamespace, con.confrelid, con.contype ` type GetTableConstraintsParams struct { @@ -418,9 +316,15 @@ type GetTableConstraintsParams struct { } type GetTableConstraintsRow struct { - DbSchema string - TableName string ConstraintName string + ConstraintType string + SchemaName string + TableName string + ConstraintColumns []string + Notnullable []bool + ForeignSchemaName string + ForeignTableName string + ForeignColumnNames []string ConstraintDefinition string } @@ -434,9 +338,15 @@ func (q *Queries) GetTableConstraints(ctx context.Context, db DBTX, arg *GetTabl for rows.Next() { var i GetTableConstraintsRow if err := rows.Scan( - &i.DbSchema, - &i.TableName, &i.ConstraintName, + &i.ConstraintType, + &i.SchemaName, + &i.TableName, + &i.ConstraintColumns, + &i.Notnullable, + &i.ForeignSchemaName, + &i.ForeignTableName, + &i.ForeignColumnNames, &i.ConstraintDefinition, ); err != nil { return nil, err @@ -449,46 +359,77 @@ func (q *Queries) GetTableConstraints(ctx context.Context, db DBTX, arg *GetTabl return items, nil } -const getUniqueConstraints = `-- name: GetUniqueConstraints :many +const getTableConstraintsBySchema = `-- name: GetTableConstraintsBySchema :many SELECT - tc.table_schema AS schema_name, - tc.table_name AS table_name, - tc.constraint_name AS constraint_name, - kcu.column_name AS column_name + con.conname AS constraint_name, + con.contype::text AS constraint_type, + con.connamespace::regnamespace::text AS schema_name, + con.conrelid::regclass::text AS table_name, + CASE + WHEN con.contype IN ('f', 'p', 'u') THEN array_agg(att.attname) + ELSE NULL + END::text[] AS constraint_columns, + array_agg(att.attnotnull)::bool[] AS notnullable, + CASE + WHEN con.contype = 'f' THEN fn_cl.relnamespace::regnamespace::text + ELSE '' + END AS foreign_schema_name, + CASE + WHEN con.contype = 'f' THEN con.confrelid::regclass::text + ELSE '' + END AS foreign_table_name, + CASE + WHEN con.contype = 'f' THEN array_agg(fk_att.attname)::text[] + ELSE NULL::text[] + END AS foreign_column_names, + pg_get_constraintdef(con.oid)::text AS constraint_definition FROM - information_schema.table_constraints AS tc -JOIN information_schema.key_column_usage AS kcu - ON tc.constraint_name = kcu.constraint_name - AND tc.table_schema = kcu.table_schema + pg_catalog.pg_constraint con +LEFT JOIN + pg_catalog.pg_attribute fk_att ON fk_att.attrelid = con.confrelid AND fk_att.attnum = ANY(con.confkey) +LEFT JOIN + pg_catalog.pg_attribute att ON att.attrelid = con.conrelid AND att.attnum = ANY(con.conkey) +LEFT JOIN + pg_catalog.pg_class fn_cl ON fn_cl.oid = con.confrelid WHERE - tc.table_schema = $1 - AND tc.constraint_type = 'UNIQUE' -ORDER BY - tc.table_name, - kcu.column_name + con.connamespace::regnamespace::text = ANY($1::text[]) +GROUP BY + con.oid, con.conname, con.conrelid, fn_cl.relnamespace, con.confrelid, con.contype ` -type GetUniqueConstraintsRow struct { - SchemaName string - TableName string - ConstraintName string - ColumnName string +type GetTableConstraintsBySchemaRow struct { + ConstraintName string + ConstraintType string + SchemaName string + TableName string + ConstraintColumns []string + Notnullable []bool + ForeignSchemaName string + ForeignTableName string + ForeignColumnNames []string + ConstraintDefinition string } -func (q *Queries) GetUniqueConstraints(ctx context.Context, db DBTX, tableschema string) ([]*GetUniqueConstraintsRow, error) { - rows, err := db.Query(ctx, getUniqueConstraints, tableschema) +func (q *Queries) GetTableConstraintsBySchema(ctx context.Context, db DBTX, schema []string) ([]*GetTableConstraintsBySchemaRow, error) { + rows, err := db.Query(ctx, getTableConstraintsBySchema, schema) if err != nil { return nil, err } defer rows.Close() - var items []*GetUniqueConstraintsRow + var items []*GetTableConstraintsBySchemaRow for rows.Next() { - var i GetUniqueConstraintsRow + var i GetTableConstraintsBySchemaRow if err := rows.Scan( + &i.ConstraintName, + &i.ConstraintType, &i.SchemaName, &i.TableName, - &i.ConstraintName, - &i.ColumnName, + &i.ConstraintColumns, + &i.Notnullable, + &i.ForeignSchemaName, + &i.ForeignTableName, + &i.ForeignColumnNames, + &i.ConstraintDefinition, ); err != nil { return nil, err } diff --git a/backend/pkg/dbschemas/postgres/postgres.go b/backend/pkg/dbschemas/postgres/postgres.go index b7e881fef..71f3163d0 100644 --- a/backend/pkg/dbschemas/postgres/postgres.go +++ b/backend/pkg/dbschemas/postgres/postgres.go @@ -6,6 +6,7 @@ import ( "strings" pg_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/postgresql" + "github.com/nucleuscloud/neosync/backend/internal/nucleusdb" dbschemas "github.com/nucleuscloud/neosync/backend/pkg/dbschemas" "golang.org/x/sync/errgroup" @@ -113,32 +114,40 @@ func buildNullableText(record *pg_queries.GetDatabaseTableSchemaRow) string { // Key is schema.table value is list of tables that key depends on func GetPostgresTableDependencies( - constraints []*pg_queries.GetForeignKeyConstraintsRow, -) dbschemas.TableDependency { + constraintRows []*pg_queries.GetTableConstraintsBySchemaRow, +) (dbschemas.TableDependency, error) { tableConstraints := map[string]*dbschemas.TableConstraints{} - for _, c := range constraints { - tableName := dbschemas.BuildTable(c.SchemaName, c.TableName) - - constraint, ok := tableConstraints[tableName] - if !ok { - tableConstraints[tableName] = &dbschemas.TableConstraints{ - Constraints: []*dbschemas.ForeignConstraint{ - {Column: c.ColumnName, IsNullable: dbschemas.ConvertIsNullableToBool(c.IsNullable), ForeignKey: &dbschemas.ForeignKey{ - Table: dbschemas.BuildTable(c.ForeignSchemaName, c.ForeignTableName), - Column: c.ForeignColumnName, - }}, + for _, row := range constraintRows { + if len(row.ConstraintColumns) != len(row.ForeignColumnNames) { + return nil, fmt.Errorf("length of columns was not equal to length of foreign key cols: %d %d", len(row.ConstraintColumns), len(row.ForeignColumnNames)) + } + if len(row.ConstraintColumns) != len(row.Notnullable) { + return nil, fmt.Errorf("length of columns was not equal to length of not nullable cols: %d %d", len(row.ConstraintColumns), len(row.Notnullable)) + } + + tableName := dbschemas.BuildTable(row.SchemaName, row.TableName) + for idx, colname := range row.ConstraintColumns { + fkcol := row.ForeignColumnNames[idx] + notnullable := row.Notnullable[idx] + + constraints, ok := tableConstraints[tableName] + constraint := &dbschemas.ForeignConstraint{ + Column: colname, + IsNullable: !notnullable, ForeignKey: &dbschemas.ForeignKey{ + Table: dbschemas.BuildTable(row.ForeignSchemaName, row.ForeignTableName), + Column: fkcol, }, } - } else { - constraint.Constraints = append(constraint.Constraints, &dbschemas.ForeignConstraint{ - Column: c.ColumnName, IsNullable: dbschemas.ConvertIsNullableToBool(c.IsNullable), ForeignKey: &dbschemas.ForeignKey{ - Table: dbschemas.BuildTable(c.ForeignSchemaName, c.ForeignTableName), - Column: c.ForeignColumnName, - }, - }) + if ok { + constraints.Constraints = append(constraints.Constraints, constraint) + } else { + tableConstraints[tableName] = &dbschemas.TableConstraints{ + Constraints: []*dbschemas.ForeignConstraint{constraint}, + } + } } } - return tableConstraints + return tableConstraints, nil } func GetUniqueSchemaColMappings( @@ -181,95 +190,91 @@ func ptr[T any](val T) *T { return &val } -func GetAllPostgresFkConstraints( - pgquerier pg_queries.Querier, +func GetAllPostgresForeignKeyConstraints( ctx context.Context, conn pg_queries.DBTX, - uniqueSchemas []string, -) ([]*pg_queries.GetForeignKeyConstraintsRow, error) { - holder := make([][]*pg_queries.GetForeignKeyConstraintsRow, len(uniqueSchemas)) - errgrp, errctx := errgroup.WithContext(ctx) - for idx := range uniqueSchemas { - idx := idx - schema := uniqueSchemas[idx] - errgrp.Go(func() error { - constraints, err := pgquerier.GetForeignKeyConstraints(errctx, conn, schema) - if err != nil { - return err - } - holder[idx] = constraints - return nil - }) + pgquerier pg_queries.Querier, + schemas []string, +) ([]*pg_queries.GetTableConstraintsBySchemaRow, error) { + if len(schemas) == 0 { + return []*pg_queries.GetTableConstraintsBySchemaRow{}, nil } - - if err := errgrp.Wait(); err != nil { + rows, err := pgquerier.GetTableConstraintsBySchema(ctx, conn, schemas) + if err != nil && !nucleusdb.IsNoRows(err) { return nil, err + } else if err != nil && nucleusdb.IsNoRows(err) { + return []*pg_queries.GetTableConstraintsBySchemaRow{}, nil } - output := []*pg_queries.GetForeignKeyConstraintsRow{} - for _, schemas := range holder { - output = append(output, schemas...) + output := []*pg_queries.GetTableConstraintsBySchemaRow{} + for _, row := range rows { + if row.ConstraintType != "f" { + continue + } + output = append(output, row) } return output, nil } -func GetAllPostgresPkConstraints( - pgquerier pg_queries.Querier, +func GetAllPostgresPrimaryKeyConstraints( ctx context.Context, conn pg_queries.DBTX, - uniqueSchemas []string, -) ([]*pg_queries.GetPrimaryKeyConstraintsRow, error) { - holder := make([][]*pg_queries.GetPrimaryKeyConstraintsRow, len(uniqueSchemas)) - errgrp, errctx := errgroup.WithContext(ctx) - for idx := range uniqueSchemas { - idx := idx - schema := uniqueSchemas[idx] - errgrp.Go(func() error { - constraints, err := pgquerier.GetPrimaryKeyConstraints(errctx, conn, schema) - if err != nil { - return err - } - holder[idx] = constraints - return nil - }) + pgquerier pg_queries.Querier, + schemas []string, +) ([]*pg_queries.GetTableConstraintsBySchemaRow, error) { + if len(schemas) == 0 { + return []*pg_queries.GetTableConstraintsBySchemaRow{}, nil } - - if err := errgrp.Wait(); err != nil { + rows, err := pgquerier.GetTableConstraintsBySchema(ctx, conn, schemas) + if err != nil && !nucleusdb.IsNoRows(err) { return nil, err + } else if err != nil && nucleusdb.IsNoRows(err) { + return []*pg_queries.GetTableConstraintsBySchemaRow{}, nil } - output := []*pg_queries.GetPrimaryKeyConstraintsRow{} - for _, schemas := range holder { - output = append(output, schemas...) + output := []*pg_queries.GetTableConstraintsBySchemaRow{} + for _, row := range rows { + if row.ConstraintType != "p" { + continue + } + output = append(output, row) } return output, nil } -func GetPostgresTablePrimaryKeys( - primaryKeyConstraints []*pg_queries.GetPrimaryKeyConstraintsRow, -) map[string][]string { - pkConstraintMap := map[string][]*pg_queries.GetPrimaryKeyConstraintsRow{} - for _, c := range primaryKeyConstraints { - _, ok := pkConstraintMap[c.ConstraintName] - if ok { - pkConstraintMap[c.ConstraintName] = append(pkConstraintMap[c.ConstraintName], c) +func GetAllPostgresPrimaryKeyConstraintsByTableCols( + ctx context.Context, + conn pg_queries.DBTX, + pgquerier pg_queries.Querier, + schemas []string, +) (map[string][]string, error) { + if len(schemas) == 0 { + return map[string][]string{}, nil + } + rows, err := pgquerier.GetTableConstraintsBySchema(ctx, conn, schemas) + if err != nil && !nucleusdb.IsNoRows(err) { + return nil, err + } else if err != nil && nucleusdb.IsNoRows(err) { + return map[string][]string{}, nil + } + + output := map[string][]string{} + for _, row := range rows { + if row.ConstraintType != "p" { + continue + } + key := dbschemas.BuildTable(row.SchemaName, row.TableName) + if _, ok := output[key]; ok { + output[key] = append(output[key], row.ConstraintColumns...) } else { - pkConstraintMap[c.ConstraintName] = []*pg_queries.GetPrimaryKeyConstraintsRow{c} + output[key] = append([]string{}, row.ConstraintColumns...) } } - pkMap := map[string][]string{} - for _, constraints := range pkConstraintMap { - for _, c := range constraints { - key := dbschemas.BuildTable(c.SchemaName, c.TableName) - _, ok := pkMap[key] - if ok { - pkMap[key] = append(pkMap[key], c.ColumnName) - } else { - pkMap[key] = []string{c.ColumnName} - } - } + + for key, val := range output { + output[key] = dedupeSlice(val) } - return pkMap + return output, nil } func BuildTruncateStatement( @@ -317,63 +322,52 @@ func EscapePgColumn(col string) string { return fmt.Sprintf("%q", col) } -func GetAllPostgresUniqueConstraints( - pgquerier pg_queries.Querier, +// Returns a map by table name and lists all columns that are a part of a unique constraint +func GetAllPostgresUniqueConstraintsByTableCols( ctx context.Context, conn pg_queries.DBTX, - uniqueSchemas []string, -) ([]*pg_queries.GetUniqueConstraintsRow, error) { - holder := make([][]*pg_queries.GetUniqueConstraintsRow, len(uniqueSchemas)) - errgrp, errctx := errgroup.WithContext(ctx) - for idx := range uniqueSchemas { - idx := idx - schema := uniqueSchemas[idx] - errgrp.Go(func() error { - constraints, err := pgquerier.GetUniqueConstraints(errctx, conn, schema) - if err != nil { - return err - } - holder[idx] = constraints - return nil - }) + pgquerier pg_queries.Querier, + schemas []string, +) (map[string][]string, error) { + if len(schemas) == 0 { + return map[string][]string{}, nil } - - if err := errgrp.Wait(); err != nil { + rows, err := pgquerier.GetTableConstraintsBySchema(ctx, conn, schemas) + if err != nil && !nucleusdb.IsNoRows(err) { return nil, err + } else if err != nil && nucleusdb.IsNoRows(err) { + return map[string][]string{}, nil + } + + output := map[string][]string{} + for _, row := range rows { + if row.ConstraintType != "u" { + continue + } + key := dbschemas.BuildTable(row.SchemaName, row.TableName) + if _, ok := output[key]; ok { + output[key] = append(output[key], row.ConstraintColumns...) + } else { + output[key] = append([]string{}, row.ConstraintColumns...) + } } - output := []*pg_queries.GetUniqueConstraintsRow{} - for _, schemas := range holder { - output = append(output, schemas...) + for key, val := range output { + output[key] = dedupeSlice(val) } return output, nil } -func GetPostgresTableUniqueConstraints( - uniqueConstraints []*pg_queries.GetUniqueConstraintsRow, -) map[string][]string { - uniqueConstraintMap := map[string][]*pg_queries.GetUniqueConstraintsRow{} - for _, c := range uniqueConstraints { - _, ok := uniqueConstraintMap[c.ConstraintName] - if ok { - uniqueConstraintMap[c.ConstraintName] = append(uniqueConstraintMap[c.ConstraintName], c) - } else { - uniqueConstraintMap[c.ConstraintName] = []*pg_queries.GetUniqueConstraintsRow{c} - } +func dedupeSlice(input []string) []string { + set := map[string]any{} + for _, i := range input { + set[i] = struct{}{} } - pkMap := map[string][]string{} - for _, constraints := range uniqueConstraintMap { - for _, c := range constraints { - key := dbschemas.BuildTable(c.SchemaName, c.TableName) - _, ok := pkMap[key] - if ok { - pkMap[key] = append(pkMap[key], c.ColumnName) - } else { - pkMap[key] = []string{c.ColumnName} - } - } + output := make([]string, 0, len(set)) + for key := range set { + output = append(output, key) } - return pkMap + return output } func GetPostgresRolePermissions(pgquerier pg_queries.Querier, diff --git a/backend/pkg/dbschemas/postgres/postgres_test.go b/backend/pkg/dbschemas/postgres/postgres_test.go index 7361ee0f0..c0ef855c1 100644 --- a/backend/pkg/dbschemas/postgres/postgres_test.go +++ b/backend/pkg/dbschemas/postgres/postgres_test.go @@ -7,28 +7,29 @@ import ( "github.com/jackc/pgx/v5/pgconn" pg_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/postgresql" dbschemas "github.com/nucleuscloud/neosync/backend/pkg/dbschemas" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" ) func Test_GetPostgresTableDependencies(t *testing.T) { - constraints := []*pg_queries.GetForeignKeyConstraintsRow{ - {ConstraintName: "fk_account_user_associations_account_id", SchemaName: "neosync_api", TableName: "account_user_associations", ColumnName: "account_id", ForeignSchemaName: "neosync_api", ForeignTableName: "accounts", ForeignColumnName: "id", IsNullable: "NO"}, - {ConstraintName: "fk_account_user_associations_user_id", SchemaName: "neosync_api", TableName: "account_user_associations", ColumnName: "user_id", ForeignSchemaName: "neosync_api", ForeignTableName: "users", ForeignColumnName: "id", IsNullable: "NO"}, - {ConstraintName: "fk_connections_accounts_id", SchemaName: "neosync_api", TableName: "connections", ColumnName: "account_id", ForeignSchemaName: "neosync_api", ForeignTableName: "accounts", ForeignColumnName: "id", IsNullable: "NO"}, - {ConstraintName: "fk_connections_created_by_users_id", SchemaName: "neosync_api", TableName: "connections", ColumnName: "created_by_id", ForeignSchemaName: "neosync_api", ForeignTableName: "users", ForeignColumnName: "id", IsNullable: "YES"}, - {ConstraintName: "fk_connections_updated_by_users_id", SchemaName: "neosync_api", TableName: "connections", ColumnName: "updated_by_id", ForeignSchemaName: "neosync_api", ForeignTableName: "users", ForeignColumnName: "id", IsNullable: "NO"}, - {ConstraintName: "fk_jobdstconassoc_conn_id_conn_id", SchemaName: "neosync_api", TableName: "job_destination_connection_associations", ColumnName: "connection_id", ForeignSchemaName: "neosync_api", ForeignTableName: "connections", ForeignColumnName: "id", IsNullable: "NO"}, - {ConstraintName: "fk_jobdstconassoc_job_id_jobs_id", SchemaName: "neosync_api", TableName: "job_destination_connection_associations", ColumnName: "job_id", ForeignSchemaName: "neosync_api", ForeignTableName: "jobs", ForeignColumnName: "id", IsNullable: "NO"}, - {ConstraintName: "fk_jobs_accounts_id", SchemaName: "neosync_api", TableName: "jobs", ColumnName: "account_id", ForeignSchemaName: "neosync_api", ForeignTableName: "accounts", ForeignColumnName: "id", IsNullable: "NO"}, - {ConstraintName: "fk_jobs_accounts_id", SchemaName: "neosync_api", TableName: "jobs", ColumnName: "connection_source_id", ForeignSchemaName: "neosync_api", ForeignTableName: "connections", ForeignColumnName: "id", IsNullable: "NO"}, - {ConstraintName: "fk_jobs_created_by_users_id", SchemaName: "neosync_api", TableName: "jobs", ColumnName: "created_by_id", ForeignSchemaName: "neosync_api", ForeignTableName: "users", ForeignColumnName: "id", IsNullable: "YES"}, - {ConstraintName: "fk_jobs_updated_by_users_id", SchemaName: "neosync_api", TableName: "jobs", ColumnName: "updated_by_id", ForeignSchemaName: "neosync_api", ForeignTableName: "users", ForeignColumnName: "id", IsNullable: "NO"}, - {ConstraintName: "fk_user_identity_provider_user_id", SchemaName: "neosync_api", TableName: "user_identity_provider_associations", ColumnName: "user_id", ForeignSchemaName: "neosync_api", ForeignTableName: "users", ForeignColumnName: "id", IsNullable: "NO"}, + constraints := []*pg_queries.GetTableConstraintsBySchemaRow{ + {ConstraintName: "fk_account_user_associations_account_id", SchemaName: "neosync_api", TableName: "account_user_associations", ConstraintColumns: []string{"account_id"}, ForeignSchemaName: "neosync_api", ForeignTableName: "accounts", ForeignColumnNames: []string{"id"}, Notnullable: []bool{true}, ConstraintType: "f"}, + {ConstraintName: "fk_account_user_associations_user_id", SchemaName: "neosync_api", TableName: "account_user_associations", ConstraintColumns: []string{"user_id"}, ForeignSchemaName: "neosync_api", ForeignTableName: "users", ForeignColumnNames: []string{"id"}, Notnullable: []bool{true}, ConstraintType: "f"}, + {ConstraintName: "fk_connections_accounts_id", SchemaName: "neosync_api", TableName: "connections", ConstraintColumns: []string{"account_id"}, ForeignSchemaName: "neosync_api", ForeignTableName: "accounts", ForeignColumnNames: []string{"id"}, Notnullable: []bool{true}, ConstraintType: "f"}, + {ConstraintName: "fk_connections_created_by_users_id", SchemaName: "neosync_api", TableName: "connections", ConstraintColumns: []string{"created_by_id"}, ForeignSchemaName: "neosync_api", ForeignTableName: "users", ForeignColumnNames: []string{"id"}, Notnullable: []bool{false}, ConstraintType: "f"}, + {ConstraintName: "fk_connections_updated_by_users_id", SchemaName: "neosync_api", TableName: "connections", ConstraintColumns: []string{"updated_by_id"}, ForeignSchemaName: "neosync_api", ForeignTableName: "users", ForeignColumnNames: []string{"id"}, Notnullable: []bool{true}, ConstraintType: "f"}, + {ConstraintName: "fk_jobdstconassoc_conn_id_conn_id", SchemaName: "neosync_api", TableName: "job_destination_connection_associations", ConstraintColumns: []string{"connection_id"}, ForeignSchemaName: "neosync_api", ForeignTableName: "connections", ForeignColumnNames: []string{"id"}, Notnullable: []bool{true}, ConstraintType: "f"}, + {ConstraintName: "fk_jobdstconassoc_job_id_jobs_id", SchemaName: "neosync_api", TableName: "job_destination_connection_associations", ConstraintColumns: []string{"job_id"}, ForeignSchemaName: "neosync_api", ForeignTableName: "jobs", ForeignColumnNames: []string{"id"}, Notnullable: []bool{true}, ConstraintType: "f"}, + {ConstraintName: "fk_jobs_accounts_id", SchemaName: "neosync_api", TableName: "jobs", ConstraintColumns: []string{"account_id"}, ForeignSchemaName: "neosync_api", ForeignTableName: "accounts", ForeignColumnNames: []string{"id"}, Notnullable: []bool{true}, ConstraintType: "f"}, + {ConstraintName: "fk_jobs_accounts_id", SchemaName: "neosync_api", TableName: "jobs", ConstraintColumns: []string{"connection_source_id"}, ForeignSchemaName: "neosync_api", ForeignTableName: "connections", ForeignColumnNames: []string{"id"}, Notnullable: []bool{true}, ConstraintType: "f"}, + {ConstraintName: "fk_jobs_created_by_users_id", SchemaName: "neosync_api", TableName: "jobs", ConstraintColumns: []string{"created_by_id"}, ForeignSchemaName: "neosync_api", ForeignTableName: "users", ForeignColumnNames: []string{"id"}, Notnullable: []bool{false}, ConstraintType: "f"}, + {ConstraintName: "fk_jobs_updated_by_users_id", SchemaName: "neosync_api", TableName: "jobs", ConstraintColumns: []string{"updated_by_id"}, ForeignSchemaName: "neosync_api", ForeignTableName: "users", ForeignColumnNames: []string{"id"}, Notnullable: []bool{true}, ConstraintType: "f"}, + {ConstraintName: "fk_user_identity_provider_user_id", SchemaName: "neosync_api", TableName: "user_identity_provider_associations", ConstraintColumns: []string{"user_id"}, ForeignSchemaName: "neosync_api", ForeignTableName: "users", ForeignColumnNames: []string{"id"}, Notnullable: []bool{true}, ConstraintType: "f"}, } - td := GetPostgresTableDependencies(constraints) - assert.Equal(t, td, dbschemas.TableDependency{ + td, err := GetPostgresTableDependencies(constraints) + require.NoError(t, err) + require.Equal(t, td, dbschemas.TableDependency{ "neosync_api.account_user_associations": {Constraints: []*dbschemas.ForeignConstraint{ {Column: "account_id", IsNullable: false, ForeignKey: &dbschemas.ForeignKey{Table: "neosync_api.accounts", Column: "id"}}, {Column: "user_id", IsNullable: false, ForeignKey: &dbschemas.ForeignKey{Table: "neosync_api.users", Column: "id"}}, @@ -55,16 +56,17 @@ func Test_GetPostgresTableDependencies(t *testing.T) { } func Test_GetPostgresTableDependenciesExtraEdgeCases(t *testing.T) { - constraints := []*pg_queries.GetForeignKeyConstraintsRow{ - {ConstraintName: "t1_b_c_fkey", SchemaName: "neosync_api", TableName: "t1", ColumnName: "b", ForeignSchemaName: "neosync_api", ForeignTableName: "account_user_associations", ForeignColumnName: "account_id", IsNullable: "NO"}, - {ConstraintName: "t1_b_c_fkey", SchemaName: "neosync_api", TableName: "t1", ColumnName: "c", ForeignSchemaName: "neosync_api", ForeignTableName: "account_user_associations", ForeignColumnName: "user_id", IsNullable: "NO"}, - {ConstraintName: "t2_b_fkey", SchemaName: "neosync_api", TableName: "t2", ColumnName: "b", ForeignSchemaName: "neosync_api", ForeignTableName: "t2", ForeignColumnName: "a", IsNullable: "NO"}, - {ConstraintName: "t3_b_fkey", SchemaName: "neosync_api", TableName: "t3", ColumnName: "b", ForeignSchemaName: "neosync_api", ForeignTableName: "t4", ForeignColumnName: "a", IsNullable: "NO"}, - {ConstraintName: "t4_b_fkey", SchemaName: "neosync_api", TableName: "t4", ColumnName: "b", ForeignSchemaName: "neosync_api", ForeignTableName: "t3", ForeignColumnName: "a", IsNullable: "NO"}, + constraints := []*pg_queries.GetTableConstraintsBySchemaRow{ + {ConstraintName: "t1_b_c_fkey", SchemaName: "neosync_api", TableName: "t1", ConstraintColumns: []string{"b"}, ForeignSchemaName: "neosync_api", ForeignTableName: "account_user_associations", ForeignColumnNames: []string{"account_id"}, Notnullable: []bool{true}, ConstraintType: "f"}, + {ConstraintName: "t1_b_c_fkey", SchemaName: "neosync_api", TableName: "t1", ConstraintColumns: []string{"c"}, ForeignSchemaName: "neosync_api", ForeignTableName: "account_user_associations", ForeignColumnNames: []string{"user_id"}, Notnullable: []bool{true}, ConstraintType: "f"}, + {ConstraintName: "t2_b_fkey", SchemaName: "neosync_api", TableName: "t2", ConstraintColumns: []string{"b"}, ForeignSchemaName: "neosync_api", ForeignTableName: "t2", ForeignColumnNames: []string{"a"}, Notnullable: []bool{true}, ConstraintType: "f"}, + {ConstraintName: "t3_b_fkey", SchemaName: "neosync_api", TableName: "t3", ConstraintColumns: []string{"b"}, ForeignSchemaName: "neosync_api", ForeignTableName: "t4", ForeignColumnNames: []string{"a"}, Notnullable: []bool{true}, ConstraintType: "f"}, + {ConstraintName: "t4_b_fkey", SchemaName: "neosync_api", TableName: "t4", ConstraintColumns: []string{"b"}, ForeignSchemaName: "neosync_api", ForeignTableName: "t3", ForeignColumnNames: []string{"a"}, Notnullable: []bool{true}, ConstraintType: "f"}, } - td := GetPostgresTableDependencies(constraints) - assert.Equal(t, td, dbschemas.TableDependency{ + td, err := GetPostgresTableDependencies(constraints) + require.NoError(t, err) + require.Equal(t, td, dbschemas.TableDependency{ "neosync_api.t1": {Constraints: []*dbschemas.ForeignConstraint{ {Column: "b", IsNullable: false, ForeignKey: &dbschemas.ForeignKey{Table: "neosync_api.account_user_associations", Column: "account_id"}}, {Column: "c", IsNullable: false, ForeignKey: &dbschemas.ForeignKey{Table: "neosync_api.account_user_associations", Column: "user_id"}}, @@ -171,7 +173,7 @@ func Test_GenerateCreateTableStatement(t *testing.T) { for _, testcase := range cases { t.Run(t.Name(), func(t *testing.T) { actual := generateCreateTableStatement(testcase.schema, testcase.table, testcase.rows, testcase.constraints) - assert.Equal(t, testcase.expected, actual) + require.Equal(t, testcase.expected, actual) }) } } @@ -186,12 +188,12 @@ func Test_GetUniqueSchemaColMappings(t *testing.T) { {TableSchema: "neosync_api", TableName: "accounts", ColumnName: "id"}, }, ) - assert.Contains(t, mappings, "public.users", "job mappings are a subset of the present database schemas") - assert.Contains(t, mappings, "neosync_api.accounts", "job mappings are a subset of the present database schemas") - assert.Contains(t, mappings["public.users"], "id", "") - assert.Contains(t, mappings["public.users"], "created_by", "") - assert.Contains(t, mappings["public.users"], "updated_by", "") - assert.Contains(t, mappings["neosync_api.accounts"], "id", "") + require.Contains(t, mappings, "public.users", "job mappings are a subset of the present database schemas") + require.Contains(t, mappings, "neosync_api.accounts", "job mappings are a subset of the present database schemas") + require.Contains(t, mappings["public.users"], "id", "") + require.Contains(t, mappings["public.users"], "created_by", "") + require.Contains(t, mappings["public.users"], "updated_by", "") + require.Contains(t, mappings["neosync_api.accounts"], "id", "") } func Test_BatchExecStmts(t *testing.T) { @@ -226,14 +228,14 @@ func Test_BatchExecStmts(t *testing.T) { } err := BatchExecStmts(ctx, dbtx, tt.batchSize, tt.statements) - assert.NoError(t, err) + require.NoError(t, err) }) } } func Test_EscapePgColumns(t *testing.T) { - assert.Empty(t, EscapePgColumns(nil)) - assert.Equal( + require.Empty(t, EscapePgColumns(nil)) + require.Equal( t, EscapePgColumns([]string{"foo", "bar", "baz"}), []string{`"foo"`, `"bar"`, `"baz"`}, diff --git a/backend/pkg/dbschemas/sql/postgresql/queries/system.sql b/backend/pkg/dbschemas/sql/postgresql/queries/system.sql index b38de45d8..b999713c1 100644 --- a/backend/pkg/dbschemas/sql/postgresql/queries/system.sql +++ b/backend/pkg/dbschemas/sql/postgresql/queries/system.sql @@ -127,98 +127,88 @@ ORDER BY -- name: GetTableConstraints :many SELECT - nsp.nspname AS db_schema, - rel.relname AS table_name, con.conname AS constraint_name, - pg_get_constraintdef(con.oid) AS constraint_definition + con.contype::text AS constraint_type, + con.connamespace::regnamespace::text AS schema_name, + con.conrelid::regclass::text AS table_name, + CASE + WHEN con.contype IN ('f', 'p', 'u') THEN array_agg(att.attname) + ELSE NULL + END::text[] AS constraint_columns, + array_agg(att.attnotnull)::bool[] AS notnullable, + CASE + WHEN con.contype = 'f' THEN fn_cl.relnamespace::regnamespace::text + ELSE '' + END AS foreign_schema_name, + CASE + WHEN con.contype = 'f' THEN con.confrelid::regclass::text + ELSE '' + END AS foreign_table_name, + CASE + WHEN con.contype = 'f' THEN array_agg(fk_att.attname)::text[] + ELSE NULL::text[] + END AS foreign_column_names, + pg_get_constraintdef(con.oid)::text AS constraint_definition FROM pg_catalog.pg_constraint con -INNER JOIN pg_catalog.pg_class rel - ON - rel.oid = con.conrelid -INNER JOIN pg_catalog.pg_namespace nsp - ON - nsp.oid = connamespace +LEFT JOIN + pg_catalog.pg_attribute fk_att ON fk_att.attrelid = con.confrelid AND fk_att.attnum = ANY(con.confkey) +LEFT JOIN + pg_catalog.pg_attribute att ON att.attrelid = con.conrelid AND att.attnum = ANY(con.conkey) +LEFT JOIN + pg_catalog.pg_class fn_cl ON fn_cl.oid = con.confrelid WHERE - nsp.nspname = sqlc.arg('schema') AND rel.relname = sqlc.arg('table'); + con.connamespace::regnamespace::text = sqlc.arg('schema') AND con.conrelid::regclass::text = sqlc.arg('table') +GROUP BY + con.oid, con.conname, con.conrelid, fn_cl.relnamespace, con.confrelid, con.contype; --- name: GetForeignKeyConstraints :many +-- name: GetTableConstraintsBySchema :many SELECT - rc.constraint_name, - rc.constraint_schema AS schema_name, - fk.table_name, - fk.column_name, - c.is_nullable, - pk.table_schema AS foreign_schema_name, - pk.table_name AS foreign_table_name, - pk.column_name AS foreign_column_name -FROM - information_schema.referential_constraints rc -JOIN information_schema.key_column_usage fk ON - fk.constraint_catalog = rc.constraint_catalog AND - fk.constraint_schema = rc.constraint_schema AND - fk.constraint_name = rc.constraint_name -JOIN information_schema.key_column_usage pk ON - pk.constraint_catalog = rc.unique_constraint_catalog AND - pk.constraint_schema = rc.unique_constraint_schema AND - pk.constraint_name = rc.unique_constraint_name -JOIN information_schema.columns c ON - c.table_schema = fk.table_schema AND - c.table_name = fk.table_name AND - c.column_name = fk.column_name -WHERE - rc.constraint_schema = sqlc.arg('tableSchema') -ORDER BY - rc.constraint_name, - fk.ordinal_position; - --- name: GetPrimaryKeyConstraints :many -SELECT - tc.table_schema AS schema_name, - tc.table_name as table_name, - tc.constraint_name as constraint_name, - kcu.column_name as column_name -FROM - information_schema.table_constraints AS tc -JOIN information_schema.key_column_usage AS kcu - ON tc.constraint_name = kcu.constraint_name - AND tc.table_schema = kcu.table_schema -WHERE - tc.table_schema = sqlc.arg('tableSchema') - AND tc.constraint_type = 'PRIMARY KEY' -ORDER BY - tc.table_name, - kcu.column_name; - - --- name: GetUniqueConstraints :many -SELECT - tc.table_schema AS schema_name, - tc.table_name AS table_name, - tc.constraint_name AS constraint_name, - kcu.column_name AS column_name + con.conname AS constraint_name, + con.contype::text AS constraint_type, + con.connamespace::regnamespace::text AS schema_name, + con.conrelid::regclass::text AS table_name, + CASE + WHEN con.contype IN ('f', 'p', 'u') THEN array_agg(att.attname) + ELSE NULL + END::text[] AS constraint_columns, + array_agg(att.attnotnull)::bool[] AS notnullable, + CASE + WHEN con.contype = 'f' THEN fn_cl.relnamespace::regnamespace::text + ELSE '' + END AS foreign_schema_name, + CASE + WHEN con.contype = 'f' THEN con.confrelid::regclass::text + ELSE '' + END AS foreign_table_name, + CASE + WHEN con.contype = 'f' THEN array_agg(fk_att.attname)::text[] + ELSE NULL::text[] + END AS foreign_column_names, + pg_get_constraintdef(con.oid)::text AS constraint_definition FROM - information_schema.table_constraints AS tc -JOIN information_schema.key_column_usage AS kcu - ON tc.constraint_name = kcu.constraint_name - AND tc.table_schema = kcu.table_schema + pg_catalog.pg_constraint con +LEFT JOIN + pg_catalog.pg_attribute fk_att ON fk_att.attrelid = con.confrelid AND fk_att.attnum = ANY(con.confkey) +LEFT JOIN + pg_catalog.pg_attribute att ON att.attrelid = con.conrelid AND att.attnum = ANY(con.conkey) +LEFT JOIN + pg_catalog.pg_class fn_cl ON fn_cl.oid = con.confrelid WHERE - tc.table_schema = sqlc.arg('tableSchema') - AND tc.constraint_type = 'UNIQUE' -ORDER BY - tc.table_name, - kcu.column_name; + con.connamespace::regnamespace::text = ANY(sqlc.arg('schema')::text[]) +GROUP BY + con.oid, con.conname, con.conrelid, fn_cl.relnamespace, con.confrelid, con.contype; -- name: GetPostgresRolePermissions :many SELECT - rtg.table_schema as table_schema, - rtg.table_name as table_name, + rtg.table_schema as table_schema, + rtg.table_name as table_name, rtg.privilege_type as privilege_type -FROM +FROM information_schema.role_table_grants as rtg -WHERE - table_schema NOT IN ('pg_catalog', 'information_schema') +WHERE + table_schema NOT IN ('pg_catalog', 'information_schema') AND grantee = sqlc.arg('role') -ORDER BY - table_schema, - table_name; \ No newline at end of file +ORDER BY + table_schema, + table_name; diff --git a/backend/pkg/sqlmanager/postgres-manager.go b/backend/pkg/sqlmanager/postgres-manager.go index 34aa6478c..d89e8dbcc 100644 --- a/backend/pkg/sqlmanager/postgres-manager.go +++ b/backend/pkg/sqlmanager/postgres-manager.go @@ -45,37 +45,56 @@ func (p *PostgresManager) GetDatabaseSchema(ctx context.Context) ([]*DatabaseSch } func (p *PostgresManager) GetAllForeignKeyConstraints(ctx context.Context, schemas []string) ([]*ForeignKeyConstraintsRow, error) { - fkConstraints, err := dbschemas_postgres.GetAllPostgresFkConstraints(p.querier, ctx, p.pool, schemas) + constraints, err := dbschemas_postgres.GetAllPostgresForeignKeyConstraints(ctx, p.pool, p.querier, schemas) if err != nil { return nil, fmt.Errorf("unable to get database foreign keys for postgres connection: %w", err) } result := []*ForeignKeyConstraintsRow{} - for _, row := range fkConstraints { - result = append(result, &ForeignKeyConstraintsRow{ - SchemaName: row.SchemaName, - TableName: row.TableName, - ColumnName: row.ColumnName, - IsNullable: row.IsNullable, - ConstraintName: row.ConstraintName, - ForeignSchemaName: row.ForeignSchemaName, - ForeignTableName: row.ForeignTableName, - ForeignColumnName: row.ForeignColumnName, - }) + for _, row := range constraints { + if len(row.ConstraintColumns) != len(row.ForeignColumnNames) { + return nil, fmt.Errorf("length of columns was not equal to length of foreign key cols: %d %d", len(row.ConstraintColumns), len(row.ForeignColumnNames)) + } + if len(row.ConstraintColumns) != len(row.Notnullable) { + return nil, fmt.Errorf("length of columns was not equal to length of not nullable cols: %d %d", len(row.ConstraintColumns), len(row.Notnullable)) + } + + for idx, colname := range row.ConstraintColumns { + fkcol := row.ForeignColumnNames[idx] + notnullable := row.Notnullable[idx] + + result = append(result, &ForeignKeyConstraintsRow{ + SchemaName: row.SchemaName, + TableName: row.TableName, + ColumnName: colname, + IsNullable: convertNotNullableToNullableText(notnullable), + ConstraintName: row.ConstraintName, + ForeignSchemaName: row.ForeignSchemaName, + ForeignTableName: row.ForeignTableName, + ForeignColumnName: fkcol, + }) + } } return result, nil } +func convertNotNullableToNullableText(notnullable bool) string { + if notnullable { + return "NO" + } + return "YES" +} + func (p *PostgresManager) GetAllPrimaryKeyConstraints(ctx context.Context, schemas []string) ([]*PrimaryKeyConstraintsRow, error) { - fkConstraints, err := dbschemas_postgres.GetAllPostgresPkConstraints(p.querier, ctx, p.pool, schemas) + constraints, err := dbschemas_postgres.GetAllPostgresPrimaryKeyConstraints(ctx, p.pool, p.querier, schemas) if err != nil { return nil, fmt.Errorf("unable to get database primary keys for postgres connection: %w", err) } result := []*PrimaryKeyConstraintsRow{} - for _, row := range fkConstraints { + for _, row := range constraints { result = append(result, &PrimaryKeyConstraintsRow{ SchemaName: row.SchemaName, TableName: row.TableName, - ColumnName: row.ColumnName, + ColumnName: row.ConstraintColumns[0], // todo: hack, this should be fixed to support primary keys ConstraintName: row.ConstraintName, }) } diff --git a/backend/services/mgmt/v1alpha1/connection-data-service/connection-data.go b/backend/services/mgmt/v1alpha1/connection-data-service/connection-data.go index 297b8b58a..d1211fea9 100644 --- a/backend/services/mgmt/v1alpha1/connection-data-service/connection-data.go +++ b/backend/services/mgmt/v1alpha1/connection-data-service/connection-data.go @@ -602,7 +602,7 @@ func (s *Service) GetConnectionForeignConstraints( } defer conn.Close() - allConstraints, err := dbschemas_postgres.GetAllPostgresFkConstraints(s.pgquerier, ctx, db, schemas) + allConstraints, err := dbschemas_postgres.GetAllPostgresForeignKeyConstraints(ctx, db, s.pgquerier, schemas) if err != nil && !nucleusdb.IsNoRows(err) { return nil, err } else if err != nil && nucleusdb.IsNoRows(err) { @@ -610,7 +610,11 @@ func (s *Service) GetConnectionForeignConstraints( TableConstraints: map[string]*mgmtv1alpha1.ForeignConstraintTables{}, }), nil } - td = dbschemas_postgres.GetPostgresTableDependencies(allConstraints) + tdeps, err := dbschemas_postgres.GetPostgresTableDependencies(allConstraints) + if err != nil { + return nil, err + } + td = tdeps default: return nil, errors.New("unsupported fk connection") @@ -703,16 +707,11 @@ func (s *Service) GetConnectionPrimaryConstraints( } defer conn.Close() - allConstraints, err := dbschemas_postgres.GetAllPostgresPkConstraints(s.pgquerier, ctx, db, schemas) - if err != nil && !nucleusdb.IsNoRows(err) { + pcon, err := dbschemas_postgres.GetAllPostgresPrimaryKeyConstraintsByTableCols(ctx, db, s.pgquerier, schemas) + if err != nil { return nil, err - } else if err != nil && nucleusdb.IsNoRows(err) { - return connect.NewResponse(&mgmtv1alpha1.GetConnectionPrimaryConstraintsResponse{ - TableConstraints: map[string]*mgmtv1alpha1.PrimaryConstraint{}, - }), nil } - pc = dbschemas_postgres.GetPostgresTablePrimaryKeys(allConstraints) - + pc = pcon default: return nil, errors.New("unsupported fk connection") } @@ -1010,15 +1009,11 @@ func (s *Service) GetConnectionUniqueConstraints( } defer conn.Close() - allConstraints, err := dbschemas_postgres.GetAllPostgresUniqueConstraints(s.pgquerier, ctx, db, schemas) - if err != nil && !nucleusdb.IsNoRows(err) { + ucon, err := dbschemas_postgres.GetAllPostgresUniqueConstraintsByTableCols(ctx, db, s.pgquerier, schemas) + if err != nil { return nil, err - } else if err != nil && nucleusdb.IsNoRows(err) { - return connect.NewResponse(&mgmtv1alpha1.GetConnectionUniqueConstraintsResponse{ - TableConstraints: map[string]*mgmtv1alpha1.UniqueConstraint{}, - }), nil } - uc = dbschemas_postgres.GetPostgresTableUniqueConstraints(allConstraints) + uc = ucon default: return nil, errors.New("unsupported unique constraint connection") diff --git a/backend/services/mgmt/v1alpha1/connection-data-service/connection-data_test.go b/backend/services/mgmt/v1alpha1/connection-data-service/connection-data_test.go index cef1bf264..12e7ea438 100644 --- a/backend/services/mgmt/v1alpha1/connection-data-service/connection-data_test.go +++ b/backend/services/mgmt/v1alpha1/connection-data-service/connection-data_test.go @@ -27,8 +27,8 @@ import ( awsmanager "github.com/nucleuscloud/neosync/backend/internal/aws" "github.com/nucleuscloud/neosync/backend/internal/nucleusdb" "github.com/nucleuscloud/neosync/backend/pkg/sqlconnect" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" ) const ( @@ -106,10 +106,10 @@ func Test_GetConnectionSchema_AwsS3(t *testing.T) { {Schema: "public", Table: "regions", Column: "region_name"}, } - assert.NoError(t, err) - assert.NotNil(t, resp) - assert.Equal(t, 2, len(resp.Msg.GetSchemas())) - assert.ElementsMatch(t, expected, resp.Msg.Schemas) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, 2, len(resp.Msg.GetSchemas())) + require.ElementsMatch(t, expected, resp.Msg.Schemas) } func Test_GetConnectionSchema_Postgres(t *testing.T) { @@ -158,10 +158,10 @@ func Test_GetConnectionSchema_Postgres(t *testing.T) { }) } - assert.NoError(t, err) - assert.NotNil(t, resp) - assert.Equal(t, 2, len(resp.Msg.GetSchemas())) - assert.ElementsMatch(t, expected, resp.Msg.Schemas) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, 2, len(resp.Msg.GetSchemas())) + require.ElementsMatch(t, expected, resp.Msg.Schemas) } func Test_GetConnectionSchema_Mysql(t *testing.T) { @@ -201,9 +201,9 @@ func Test_GetConnectionSchema_Mysql(t *testing.T) { }, }) - assert.NoError(t, err) - assert.NotNil(t, resp) - assert.Equal(t, 2, len(resp.Msg.GetSchemas())) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, 2, len(resp.Msg.GetSchemas())) if err := m.SqlMock.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) } @@ -230,10 +230,10 @@ func Test_GetConnectionSchema_NoRows(t *testing.T) { }, }) - assert.NoError(t, err) - assert.NotNil(t, resp) - assert.Equal(t, 0, len(resp.Msg.GetSchemas())) - assert.ElementsMatch(t, []*mgmtv1alpha1.DatabaseColumn{}, resp.Msg.Schemas) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, 0, len(resp.Msg.GetSchemas())) + require.ElementsMatch(t, []*mgmtv1alpha1.DatabaseColumn{}, resp.Msg.Schemas) if err := m.SqlMock.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) } @@ -260,8 +260,8 @@ func Test_GetConnectionSchema_Error(t *testing.T) { }, }) - assert.Error(t, err) - assert.Nil(t, resp) + require.Error(t, err) + require.Nil(t, resp) if err := m.SqlMock.ExpectationsWereMet(); err != nil { t.Errorf("there were unfulfilled expectations: %s", err) } @@ -314,9 +314,9 @@ func Test_GetConnectionForeignConstraints_Mysql(t *testing.T) { }, }) - assert.Nil(t, err) - assert.Len(t, resp.Msg.TableConstraints, 1) - assert.EqualValues(t, map[string]*mgmtv1alpha1.ForeignConstraintTables{ + require.Nil(t, err) + require.Len(t, resp.Msg.TableConstraints, 1) + require.EqualValues(t, map[string]*mgmtv1alpha1.ForeignConstraintTables{ "public.user_account_associations": {Constraints: []*mgmtv1alpha1.ForeignConstraint{ {Column: "user_id", IsNullable: false, ForeignKey: &mgmtv1alpha1.ForeignKey{Table: "public.users", Column: "id"}}, }}, @@ -350,17 +350,18 @@ func Test_GetConnectionForeignConstraints_Postgres(t *testing.T) { ColumnName: "name", }, }, nil) - m.PgQueierMock.On("GetForeignKeyConstraints", mock.Anything, mock.Anything, mock.Anything). - Return([]*pg_queries.GetForeignKeyConstraintsRow{ + m.PgQueierMock.On("GetTableConstraintsBySchema", mock.Anything, mock.Anything, mock.Anything). + Return([]*pg_queries.GetTableConstraintsBySchemaRow{ { - ConstraintName: "fk_user_account_associations_user_id_users_id", - SchemaName: "public", - TableName: "user_account_associations", - ColumnName: "user_id", - IsNullable: "NO", - ForeignSchemaName: "public", - ForeignTableName: "users", - ForeignColumnName: "id", + ConstraintName: "fk_user_account_associations_user_id_users_id", + SchemaName: "public", + TableName: "user_account_associations", + ConstraintColumns: []string{"user_id"}, + Notnullable: []bool{true}, + ForeignSchemaName: "public", + ForeignTableName: "users", + ForeignColumnNames: []string{"id"}, + ConstraintType: "f", }, }, nil) @@ -370,9 +371,10 @@ func Test_GetConnectionForeignConstraints_Postgres(t *testing.T) { }, }) - assert.Nil(t, err) - assert.Len(t, resp.Msg.TableConstraints, 1) - assert.EqualValues(t, map[string]*mgmtv1alpha1.ForeignConstraintTables{ + require.NoError(t, err) + require.NotNil(t, resp) + require.Len(t, resp.Msg.TableConstraints, 1) + require.EqualValues(t, map[string]*mgmtv1alpha1.ForeignConstraintTables{ "public.user_account_associations": {Constraints: []*mgmtv1alpha1.ForeignConstraint{ {Column: "user_id", IsNullable: false, ForeignKey: &mgmtv1alpha1.ForeignKey{Table: "public.users", Column: "id"}}, }}, @@ -422,9 +424,9 @@ func Test_GetConnectionPrimaryConstraints_Mysql(t *testing.T) { }, }) - assert.Nil(t, err) - assert.Len(t, resp.Msg.TableConstraints, 1) - assert.EqualValues(t, map[string]*mgmtv1alpha1.PrimaryConstraint{ + require.Nil(t, err) + require.Len(t, resp.Msg.TableConstraints, 1) + require.EqualValues(t, map[string]*mgmtv1alpha1.PrimaryConstraint{ "public.users": {Columns: []string{"id"}}, }, resp.Msg.TableConstraints) } @@ -456,13 +458,14 @@ func Test_GetConnectionPrimaryConstraints_Postgres(t *testing.T) { ColumnName: "name", }, }, nil) - m.PgQueierMock.On("GetPrimaryKeyConstraints", mock.Anything, mock.Anything, mock.Anything). - Return([]*pg_queries.GetPrimaryKeyConstraintsRow{ + m.PgQueierMock.On("GetTableConstraintsBySchema", mock.Anything, mock.Anything, mock.Anything). + Return([]*pg_queries.GetTableConstraintsBySchemaRow{ { - ConstraintName: "pk_users_id", - SchemaName: "public", - TableName: "users", - ColumnName: "id", + ConstraintName: "pk_users_id", + SchemaName: "public", + TableName: "users", + ConstraintColumns: []string{"id"}, + ConstraintType: "p", }, }, nil) @@ -472,9 +475,9 @@ func Test_GetConnectionPrimaryConstraints_Postgres(t *testing.T) { }, }) - assert.Nil(t, err) - assert.Len(t, resp.Msg.TableConstraints, 1) - assert.EqualValues(t, map[string]*mgmtv1alpha1.PrimaryConstraint{ + require.Nil(t, err) + require.Len(t, resp.Msg.TableConstraints, 1) + require.EqualValues(t, map[string]*mgmtv1alpha1.PrimaryConstraint{ "public.users": {Columns: []string{"id"}}, }, resp.Msg.TableConstraints) } @@ -519,10 +522,10 @@ func Test_GetConnectionInitStatements_Mysql_Create(t *testing.T) { }) expectedInit := "CREATE TABLE IF NOT EXISTS public.users;" - assert.Nil(t, err) - assert.Len(t, resp.Msg.TableInitStatements, 1) - assert.Len(t, resp.Msg.TableTruncateStatements, 0) - assert.Equal(t, expectedInit, resp.Msg.TableInitStatements["public.users"]) + require.Nil(t, err) + require.Len(t, resp.Msg.TableInitStatements, 1) + require.Len(t, resp.Msg.TableTruncateStatements, 0) + require.Equal(t, expectedInit, resp.Msg.TableInitStatements["public.users"]) } func Test_GetConnectionInitStatements_Mysql_Truncate(t *testing.T) { @@ -562,10 +565,10 @@ func Test_GetConnectionInitStatements_Mysql_Truncate(t *testing.T) { }) expectedTruncate := "TRUNCATE TABLE `public`.`users`;" - assert.Nil(t, err) - assert.Len(t, resp.Msg.TableInitStatements, 0) - assert.Len(t, resp.Msg.TableTruncateStatements, 1) - assert.Equal(t, expectedTruncate, resp.Msg.TableTruncateStatements["public.users"]) + require.Nil(t, err) + require.Len(t, resp.Msg.TableInitStatements, 0) + require.Len(t, resp.Msg.TableTruncateStatements, 1) + require.Equal(t, expectedTruncate, resp.Msg.TableTruncateStatements["public.users"]) } func Test_GetConnectionInitStatements_Postgres_Create(t *testing.T) { @@ -636,10 +639,10 @@ func Test_GetConnectionInitStatements_Postgres_Create(t *testing.T) { }) expectedInit := "CREATE TABLE IF NOT EXISTS \"public\".\"users\" (\"id\" uuid NOT NULL DEFAULT gen_random_uuid(), \"name\" varchar(40) NULL, CONSTRAINT users_pkey PRIMARY KEY (id));" - assert.Nil(t, err) - assert.Len(t, resp.Msg.TableInitStatements, 1) - assert.Len(t, resp.Msg.TableTruncateStatements, 0) - assert.Equal(t, expectedInit, resp.Msg.TableInitStatements["public.users"]) + require.Nil(t, err) + require.Len(t, resp.Msg.TableInitStatements, 1) + require.Len(t, resp.Msg.TableTruncateStatements, 0) + require.Equal(t, expectedInit, resp.Msg.TableInitStatements["public.users"]) } func Test_GetConnectionInitStatements_Postgres_Truncate(t *testing.T) { @@ -681,10 +684,10 @@ func Test_GetConnectionInitStatements_Postgres_Truncate(t *testing.T) { }) expectedTruncate := "TRUNCATE TABLE \"public\".\"users\" CASCADE;" - assert.Nil(t, err) - assert.Len(t, resp.Msg.TableInitStatements, 0) - assert.Len(t, resp.Msg.TableTruncateStatements, 1) - assert.Equal(t, expectedTruncate, resp.Msg.TableTruncateStatements["public.users"]) + require.Nil(t, err) + require.Len(t, resp.Msg.TableInitStatements, 0) + require.Len(t, resp.Msg.TableTruncateStatements, 1) + require.Equal(t, expectedTruncate, resp.Msg.TableTruncateStatements["public.users"]) } type serviceMocks struct { @@ -867,7 +870,7 @@ func Test_isValidTable(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { actual := isValidTable(tt.table, tt.columns) - assert.Equal(t, tt.expected, actual) + require.Equal(t, tt.expected, actual) }) } } @@ -920,7 +923,7 @@ func Test_isValidSchema(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { actual := isValidSchema(tt.schema, tt.columns) - assert.Equal(t, tt.expected, actual) + require.Equal(t, tt.expected, actual) }) } } @@ -968,9 +971,9 @@ func Test_GetConnectionUniqueConstraints_Mysql(t *testing.T) { }, }) - assert.Nil(t, err) - assert.Len(t, resp.Msg.TableConstraints, 1) - assert.EqualValues(t, map[string]*mgmtv1alpha1.UniqueConstraint{ + require.Nil(t, err) + require.Len(t, resp.Msg.TableConstraints, 1) + require.EqualValues(t, map[string]*mgmtv1alpha1.UniqueConstraint{ "public.users": {Columns: []string{"id"}}, }, resp.Msg.TableConstraints) } @@ -1002,13 +1005,14 @@ func Test_GetConnectionUniqueConstraints_Postgres(t *testing.T) { ColumnName: "name", }, }, nil) - m.PgQueierMock.On("GetUniqueConstraints", mock.Anything, mock.Anything, mock.Anything). - Return([]*pg_queries.GetUniqueConstraintsRow{ + m.PgQueierMock.On("GetTableConstraintsBySchema", mock.Anything, mock.Anything, mock.Anything). + Return([]*pg_queries.GetTableConstraintsBySchemaRow{ { - ConstraintName: "id", - SchemaName: "public", - TableName: "users", - ColumnName: "id", + ConstraintName: "id", + SchemaName: "public", + TableName: "users", + ConstraintColumns: []string{"id"}, + ConstraintType: "u", }, }, nil) @@ -1018,9 +1022,9 @@ func Test_GetConnectionUniqueConstraints_Postgres(t *testing.T) { }, }) - assert.Nil(t, err) - assert.Len(t, resp.Msg.TableConstraints, 1) - assert.EqualValues(t, map[string]*mgmtv1alpha1.UniqueConstraint{ + require.Nil(t, err) + require.Len(t, resp.Msg.TableConstraints, 1) + require.EqualValues(t, map[string]*mgmtv1alpha1.UniqueConstraint{ "public.users": {Columns: []string{"id"}}, }, resp.Msg.TableConstraints) } diff --git a/backend/sqlc.yaml b/backend/sqlc.yaml index a997da64c..b1061fc5e 100644 --- a/backend/sqlc.yaml +++ b/backend/sqlc.yaml @@ -93,56 +93,60 @@ sql: emit_params_struct_pointers: true emit_pointers_for_null_types: true overrides: - - column: information_schema.columns.table_schema - go_type: string - - column: information_schema.columns.table_name - go_type: string - - column: information_schema.columns.column_name - go_type: string - - column: information_schema.columns.ordinal_position - go_type: int - - column: information_schema.columns.column_default - go_type: "string" - nullable: true # this only appears to work on models - - column: information_schema.columns.is_nullable - go_type: string - - column: information_schema.columns.data_type - go_type: string - - column: information_schema.columns.character_maximum_length - go_type: int - - column: information_schema.columns.numeric_precision - go_type: int - - column: information_schema.columns.numeric_scale - go_type: int - - column: information_schema.tables.table_schema - go_type: string - - column: information_schema.tables.table_name - go_type: string - - column: information_schema.referential_constraints.constraint_name - go_type: string - - column: information_schema.referential_constraints.constraint_schema - go_type: string - - column: information_schema.key_column_usage.table_schema - go_type: string - - column: information_schema.key_column_usage.table_name - go_type: string - - column: information_schema.key_column_usage.column_name - go_type: string - - column: information_schema.table_constraints.table_schema - go_type: string - - column: information_schema.table_constraints.table_name - go_type: string - - column: information_schema.table_constraints.constraint_name - go_type: string - - column: information_schema.constraint_column_usage.table_schema - go_type: string - - column: information_schema.constraint_column_usage.table_name - go_type: string - - column: information_schema.constraint_column_usage.column_name - go_type: string + # - column: information_schema.columns.table_schema + # go_type: string + # - column: information_schema.columns.table_name + # go_type: string + # - column: information_schema.columns.column_name + # go_type: string + # - column: information_schema.columns.ordinal_position + # go_type: int + # - column: information_schema.columns.column_default + # go_type: "string" + # nullable: true # this only appears to work on models + # - column: information_schema.columns.is_nullable + # go_type: string + # - column: information_schema.columns.data_type + # go_type: string + # - column: information_schema.columns.character_maximum_length + # go_type: int + # - column: information_schema.columns.numeric_precision + # go_type: int + # - column: information_schema.columns.numeric_scale + # go_type: int + # - column: information_schema.tables.table_schema + # go_type: string + # - column: information_schema.tables.table_name + # go_type: string + # - column: information_schema.referential_constraints.constraint_name + # go_type: string + # - column: information_schema.referential_constraints.constraint_schema + # go_type: string + # - column: information_schema.key_column_usage.table_schema + # go_type: string + # - column: information_schema.key_column_usage.table_name + # go_type: string + # - column: information_schema.key_column_usage.column_name + # go_type: string + # - column: information_schema.table_constraints.table_schema + # go_type: string + # - column: information_schema.table_constraints.table_name + # go_type: string + # - column: information_schema.table_constraints.constraint_name + # go_type: string + # - column: information_schema.constraint_column_usage.table_schema + # go_type: string + # - column: information_schema.constraint_column_usage.table_name + # go_type: string + # - column: information_schema.constraint_column_usage.column_name + # go_type: string - column: information_schema.role_table_grants.table_schema go_type: string - column: information_schema.role_table_grants.table_name go_type: string - column: information_schema.role_table_grants.privilege_type go_type: string + - column: pg_catalog.pg_constraint.connamespace # this is normally an int but we are casting it in queries to the friendly name + go_type: string + - column: pg_catalog.pg_constraint.conrelid # this is normally an int but we are casting it in queries to the friendly name + go_type: string diff --git a/cli/go.mod b/cli/go.mod index 2ddfd030d..691ceb26b 100644 --- a/cli/go.mod +++ b/cli/go.mod @@ -253,6 +253,7 @@ require ( golang.org/x/tools v0.18.0 // indirect golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect google.golang.org/appengine v1.6.8 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240401170217-c3f982113cda // indirect google.golang.org/protobuf v1.33.0 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/jcmturner/aescts.v1 v1.0.1 // indirect diff --git a/cli/internal/cmds/neosync/sync/sync.go b/cli/internal/cmds/neosync/sync/sync.go index e3282b953..39ad76a9d 100644 --- a/cli/internal/cmds/neosync/sync/sync.go +++ b/cli/internal/cmds/neosync/sync/sync.go @@ -1199,11 +1199,15 @@ func getDestinationForeignConstraints(ctx context.Context, connectionDriver Driv } cctx, cancel := context.WithDeadline(ctx, time.Now().Add(5*time.Second)) defer cancel() - allConstraints, err := dbschemas_postgres.GetAllPostgresFkConstraints(pgquerier, cctx, pool, schemas) + allConstraints, err := dbschemas_postgres.GetAllPostgresForeignKeyConstraints(cctx, pool, pgquerier, schemas) if err != nil { return nil, err } - constraints = dbschemas_postgres.GetPostgresTableDependencies(allConstraints) + tableDeps, err := dbschemas_postgres.GetPostgresTableDependencies(allConstraints) + if err != nil { + return nil, err + } + constraints = tableDeps case mysqlDriver: mysqlquerier := mysql_queries.New() conn, err := sql.Open(string(connectionDriver), connectionUrl) @@ -1240,11 +1244,11 @@ func getDestinationPrimaryKeyConstraints(ctx context.Context, connectionDriver D } cctx, cancel := context.WithDeadline(ctx, time.Now().Add(5*time.Second)) defer cancel() - allConstraints, err := dbschemas_postgres.GetAllPostgresPkConstraints(pgquerier, cctx, pool, schemas) + pcon, err := dbschemas_postgres.GetAllPostgresPrimaryKeyConstraintsByTableCols(cctx, pool, pgquerier, schemas) if err != nil { return nil, err } - pc = dbschemas_postgres.GetPostgresTablePrimaryKeys(allConstraints) + pc = pcon case mysqlDriver: mysqlquerier := mysql_queries.New() conn, err := sql.Open(string(connectionDriver), connectionUrl) diff --git a/worker/pkg/workflows/datasync/activities/gen-benthos-configs/benthos-builder.go b/worker/pkg/workflows/datasync/activities/gen-benthos-configs/benthos-builder.go index 6afa4e6d6..0b8574554 100644 --- a/worker/pkg/workflows/datasync/activities/gen-benthos-configs/benthos-builder.go +++ b/worker/pkg/workflows/datasync/activities/gen-benthos-configs/benthos-builder.go @@ -146,7 +146,6 @@ func (b *benthosBuilder) GenerateBenthosConfigs( } groupedColInfoMap = groupedSchemas - allConstraints, err := db.GetAllForeignKeyConstraints(ctx, uniqueSchemas) if err != nil { return nil, fmt.Errorf("unable to retrieve database foreign key constraints: %w", err) @@ -158,6 +157,7 @@ func (b *benthosBuilder) GenerateBenthosConfigs( if err != nil { return nil, fmt.Errorf("unable to get all primary key constraints: %w", err) } + slogger.Info(fmt.Sprintf("found %d primary key constraints for database", len(primaryKeys))) primaryKeyMap := sql_manager.GetTablePrimaryKeysMap(primaryKeys) tables := filterNullTables(groupedMappings) diff --git a/worker/pkg/workflows/datasync/activities/gen-benthos-configs/benthos-builder_test.go b/worker/pkg/workflows/datasync/activities/gen-benthos-configs/benthos-builder_test.go index e613395c0..8ef9d5155 100644 --- a/worker/pkg/workflows/datasync/activities/gen-benthos-configs/benthos-builder_test.go +++ b/worker/pkg/workflows/datasync/activities/gen-benthos-configs/benthos-builder_test.go @@ -22,8 +22,8 @@ import ( tabledependency "github.com/nucleuscloud/neosync/backend/pkg/table-dependency" pg_models "github.com/nucleuscloud/neosync/backend/sql/postgresql/models" "github.com/nucleuscloud/neosync/worker/pkg/workflows/datasync/activities/shared" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "gopkg.in/yaml.v3" _ "github.com/benthosdev/benthos/v4/public/components/aws" @@ -164,15 +164,15 @@ func Test_BenthosBuilder_GenerateBenthosConfigs_Basic_Generate_Pg(t *testing.T) &GenerateBenthosConfigsRequest{JobId: "123", WorkflowId: "123"}, slog.Default(), ) - assert.Nil(t, err) - assert.NotEmpty(t, resp.BenthosConfigs) - assert.Len(t, resp.BenthosConfigs, 1) + require.Nil(t, err) + require.NotEmpty(t, resp.BenthosConfigs) + require.Len(t, resp.BenthosConfigs, 1) bc := resp.BenthosConfigs[0] - assert.Equal(t, bc.Name, "public.users") - assert.Empty(t, bc.DependsOn) + require.Equal(t, bc.Name, "public.users") + require.Empty(t, bc.DependsOn) out, err := yaml.Marshal(bc.Config) - assert.NoError(t, err) - assert.Equal( + require.NoError(t, err) + require.Equal( t, strings.TrimSpace(` input: @@ -234,18 +234,18 @@ output: // create a new streambuilder instance so we can access the SetYaml method benthosenv := service.NewEnvironment() err = neosync_benthos_sql.RegisterPooledSqlInsertOutput(benthosenv, nil, false) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_sql.RegisterPooledSqlRawInput(benthosenv, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_error.RegisterErrorProcessor(benthosenv, nil) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_error.RegisterErrorOutput(benthosenv, nil) - assert.NoError(t, err) + require.NoError(t, err) newSB := benthosenv.NewStreamBuilder() // SetYAML parses a full Benthos config and uses it to configure the builder. err = newSB.SetYAML(string(out)) - assert.NoError(t, err) + require.NoError(t, err) } func Test_BenthosBuilder_GenerateBenthosConfigs_Metrics(t *testing.T) { @@ -353,15 +353,15 @@ func Test_BenthosBuilder_GenerateBenthosConfigs_Metrics(t *testing.T) { &GenerateBenthosConfigsRequest{JobId: "123", WorkflowId: "123"}, slog.Default(), ) - assert.Nil(t, err) - assert.NotEmpty(t, resp.BenthosConfigs) - assert.Len(t, resp.BenthosConfigs, 1) + require.Nil(t, err) + require.NotEmpty(t, resp.BenthosConfigs) + require.Len(t, resp.BenthosConfigs, 1) bc := resp.BenthosConfigs[0] - assert.Equal(t, bc.Name, "public.users") - assert.Empty(t, bc.DependsOn) + require.Equal(t, bc.Name, "public.users") + require.Empty(t, bc.DependsOn) out, err := yaml.Marshal(bc.Config) - assert.NoError(t, err) - assert.Equal( + require.NoError(t, err) + require.Equal( t, strings.TrimSpace(` input: @@ -433,20 +433,20 @@ metrics: // create a new streambuilder instance so we can access the SetYaml method benthosenv := service.NewEnvironment() err = neosync_benthos_sql.RegisterPooledSqlInsertOutput(benthosenv, nil, false) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_sql.RegisterPooledSqlRawInput(benthosenv, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_error.RegisterErrorProcessor(benthosenv, nil) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_error.RegisterErrorOutput(benthosenv, nil) - assert.NoError(t, err) + require.NoError(t, err) err = benthos_metrics.RegisterOtelMetricsExporter(benthosenv, nil) - assert.NoError(t, err) + require.NoError(t, err) newSB := benthosenv.NewStreamBuilder() // SetYAML parses a full Benthos config and uses it to configure the builder. err = newSB.SetYAML(string(out)) - assert.NoError(t, err) + require.NoError(t, err) } func Test_BenthosBuilder_GenerateBenthosConfigs_Generate_Pg_Pg(t *testing.T) { @@ -549,15 +549,15 @@ func Test_BenthosBuilder_GenerateBenthosConfigs_Generate_Pg_Pg(t *testing.T) { &GenerateBenthosConfigsRequest{JobId: "123", WorkflowId: "123"}, slog.Default(), ) - assert.Nil(t, err) - assert.NotEmpty(t, resp.BenthosConfigs) - assert.Len(t, resp.BenthosConfigs, 1) + require.Nil(t, err) + require.NotEmpty(t, resp.BenthosConfigs) + require.Len(t, resp.BenthosConfigs, 1) bc := resp.BenthosConfigs[0] - assert.Equal(t, bc.Name, "public.users") - assert.Empty(t, bc.DependsOn) + require.Equal(t, bc.Name, "public.users") + require.Empty(t, bc.DependsOn) out, err := yaml.Marshal(bc.Config) - assert.NoError(t, err) - assert.Equal( + require.NoError(t, err) + require.Equal( t, strings.TrimSpace(string(out)), strings.TrimSpace(` @@ -619,18 +619,18 @@ output: // create a new streambuilder instance so we can access the SetYaml method benthosenv := service.NewEnvironment() err = neosync_benthos_sql.RegisterPooledSqlInsertOutput(benthosenv, nil, false) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_sql.RegisterPooledSqlRawInput(benthosenv, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_error.RegisterErrorProcessor(benthosenv, nil) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_error.RegisterErrorOutput(benthosenv, nil) - assert.NoError(t, err) + require.NoError(t, err) newSB := benthosenv.NewStreamBuilder() // SetYAML parses a full Benthos config and uses it to configure the builder. err = newSB.SetYAML(string(out)) - assert.NoError(t, err) + require.NoError(t, err) } func Test_BenthosBuilder_GenerateBenthosConfigs_PrimaryKey_Transformer_Pg_Pg(t *testing.T) { @@ -778,30 +778,38 @@ func Test_BenthosBuilder_GenerateBenthosConfigs_PrimaryKey_Transformer_Pg_Pg(t * ColumnName: "buyer_id", }, }, nil) - pgquerier.On("GetForeignKeyConstraints", mock.Anything, mock.Anything, mock.Anything). - Return([]*pg_queries.GetForeignKeyConstraintsRow{ + pgquerier.On("GetTableConstraintsBySchema", mock.Anything, mock.Anything, mock.Anything). + Once(). + Return([]*pg_queries.GetTableConstraintsBySchemaRow{ { - ConstraintName: "fk_user_account_associations_user_id_users_id", - SchemaName: "public", - TableName: "orders", - ColumnName: "buyer_id", - ForeignSchemaName: "public", - ForeignTableName: "users", - ForeignColumnName: "id", + ConstraintName: "fk_user_account_associations_user_id_users_id", + SchemaName: "public", + TableName: "orders", + ConstraintColumns: []string{"buyer_id"}, + ForeignSchemaName: "public", + ForeignTableName: "users", + ForeignColumnNames: []string{"id"}, + ConstraintType: "f", + Notnullable: []bool{true}, }, }, nil) - pgquerier.On("GetPrimaryKeyConstraints", mock.Anything, mock.Anything, mock.Anything). - Return([]*pg_queries.GetPrimaryKeyConstraintsRow{{ - SchemaName: "public", - TableName: "users", - ConstraintName: "users", - ColumnName: "id", - }, { - SchemaName: "public", - TableName: "orders", - ConstraintName: "orders", - ColumnName: "id", - }}, nil) + pgquerier.On("GetTableConstraintsBySchema", mock.Anything, mock.Anything, mock.Anything). + Once(). + Return( + []*pg_queries.GetTableConstraintsBySchemaRow{{ + SchemaName: "public", + TableName: "users", + ConstraintName: "users", + ConstraintColumns: []string{"id"}, + ConstraintType: "p", + }, { + SchemaName: "public", + TableName: "orders", + ConstraintName: "orders", + ConstraintColumns: []string{"id"}, + ConstraintType: "p", + }}, nil, + ) bbuilder := newBenthosBuilder(*mockSqlAdapter, mockJobClient, mockConnectionClient, mockTransformerClient, mockJobId, mockRunId, redisConfig, false) resp, err := bbuilder.GenerateBenthosConfigs( @@ -810,17 +818,17 @@ func Test_BenthosBuilder_GenerateBenthosConfigs_PrimaryKey_Transformer_Pg_Pg(t * slog.Default(), ) - assert.Nil(t, err) - assert.NotEmpty(t, resp.BenthosConfigs) - assert.Len(t, resp.BenthosConfigs, 2) + require.NoError(t, err) + require.NotEmpty(t, resp.BenthosConfigs) + require.Len(t, resp.BenthosConfigs, 2) bc := getBenthosConfigByName(resp.BenthosConfigs, "public.users") - assert.Equal(t, bc.Name, "public.users") - assert.Len(t, bc.RedisConfig, 1) - assert.Equal(t, bc.RedisConfig[0].Table, "public.users") - assert.Equal(t, bc.RedisConfig[0].Column, "id") + require.Equal(t, bc.Name, "public.users") + require.Len(t, bc.RedisConfig, 1) + require.Equal(t, bc.RedisConfig[0].Table, "public.users") + require.Equal(t, bc.RedisConfig[0].Column, "id") out, err := yaml.Marshal(bc.Config) - assert.NoError(t, err) - assert.Equal( + require.NoError(t, err) + require.Equal( t, strings.TrimSpace(string(out)), strings.TrimSpace(` @@ -880,11 +888,11 @@ output: ) bc = getBenthosConfigByName(resp.BenthosConfigs, "public.orders") - assert.Equal(t, bc.Name, "public.orders") - assert.Empty(t, bc.RedisConfig) + require.Equal(t, bc.Name, "public.orders") + require.Empty(t, bc.RedisConfig) out, err = yaml.Marshal(bc.Config) - assert.NoError(t, err) - assert.Equal( + require.NoError(t, err) + require.Equal( t, strings.TrimSpace(string(out)), strings.TrimSpace(` @@ -945,19 +953,19 @@ output: benthosenv := service.NewEnvironment() err = neosync_benthos_sql.RegisterPooledSqlInsertOutput(benthosenv, nil, false) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_sql.RegisterPooledSqlRawInput(benthosenv, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_error.RegisterErrorProcessor(benthosenv, nil) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_error.RegisterErrorOutput(benthosenv, nil) - assert.NoError(t, err) + require.NoError(t, err) newSB := benthosenv.NewStreamBuilder() // SetYAML parses a full Benthos config and uses it to configure the builder. err = newSB.SetYAML(string(out)) - assert.NoError(t, err) + require.NoError(t, err) } func Test_BenthosBuilder_GenerateBenthosConfigs_PrimaryKey_Passthrough_Pg_Pg(t *testing.T) { @@ -1093,30 +1101,38 @@ func Test_BenthosBuilder_GenerateBenthosConfigs_PrimaryKey_Passthrough_Pg_Pg(t * ColumnName: "buyer_id", }, }, nil) - pgquerier.On("GetForeignKeyConstraints", mock.Anything, mock.Anything, mock.Anything). - Return([]*pg_queries.GetForeignKeyConstraintsRow{ + pgquerier.On("GetTableConstraintsBySchema", mock.Anything, mock.Anything, mock.Anything). + Once(). + Return([]*pg_queries.GetTableConstraintsBySchemaRow{ { - ConstraintName: "fk_user_account_associations_user_id_users_id", - SchemaName: "public", - TableName: "orders", - ColumnName: "buyer_id", - ForeignSchemaName: "public", - ForeignTableName: "users", - ForeignColumnName: "id", + ConstraintName: "fk_user_account_associations_user_id_users_id", + SchemaName: "public", + TableName: "orders", + ConstraintColumns: []string{"buyer_id"}, + ForeignSchemaName: "public", + ForeignTableName: "users", + ForeignColumnNames: []string{"id"}, + ConstraintType: "f", + Notnullable: []bool{true}, }, }, nil) - pgquerier.On("GetPrimaryKeyConstraints", mock.Anything, mock.Anything, mock.Anything). - Return([]*pg_queries.GetPrimaryKeyConstraintsRow{{ - SchemaName: "public", - TableName: "users", - ConstraintName: "users", - ColumnName: "id", - }, { - SchemaName: "public", - TableName: "orders", - ConstraintName: "orders", - ColumnName: "id", - }}, nil) + pgquerier.On("GetTableConstraintsBySchema", mock.Anything, mock.Anything, mock.Anything). + Once(). + Return( + []*pg_queries.GetTableConstraintsBySchemaRow{{ + SchemaName: "public", + TableName: "users", + ConstraintName: "users", + ConstraintColumns: []string{"id"}, + ConstraintType: "p", + }, { + SchemaName: "public", + TableName: "orders", + ConstraintName: "orders", + ConstraintColumns: []string{"id"}, + ConstraintType: "p", + }}, nil, + ) bbuilder := newBenthosBuilder(*mockSqlAdapter, mockJobClient, mockConnectionClient, mockTransformerClient, mockJobId, mockRunId, nil, false) resp, err := bbuilder.GenerateBenthosConfigs( @@ -1125,15 +1141,15 @@ func Test_BenthosBuilder_GenerateBenthosConfigs_PrimaryKey_Passthrough_Pg_Pg(t * slog.Default(), ) - assert.Nil(t, err) - assert.NotEmpty(t, resp.BenthosConfigs) - assert.Len(t, resp.BenthosConfigs, 2) + require.Nil(t, err) + require.NotEmpty(t, resp.BenthosConfigs) + require.Len(t, resp.BenthosConfigs, 2) bc := getBenthosConfigByName(resp.BenthosConfigs, "public.users") - assert.Equal(t, bc.Name, "public.users") - assert.Empty(t, bc.RedisConfig) + require.Equal(t, bc.Name, "public.users") + require.Empty(t, bc.RedisConfig) out, err := yaml.Marshal(bc.Config) - assert.NoError(t, err) - assert.Equal( + require.NoError(t, err) + require.Equal( t, strings.TrimSpace(string(out)), strings.TrimSpace(` @@ -1181,11 +1197,11 @@ output: ) bc = getBenthosConfigByName(resp.BenthosConfigs, "public.orders") - assert.Equal(t, bc.Name, "public.orders") - assert.Empty(t, bc.RedisConfig) + require.Equal(t, bc.Name, "public.orders") + require.Empty(t, bc.RedisConfig) out, err = yaml.Marshal(bc.Config) - assert.NoError(t, err) - assert.Equal( + require.NoError(t, err) + require.Equal( t, strings.TrimSpace(string(out)), strings.TrimSpace(` @@ -1234,16 +1250,16 @@ output: benthosenv := service.NewEnvironment() err = neosync_benthos_sql.RegisterPooledSqlInsertOutput(benthosenv, nil, false) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_sql.RegisterPooledSqlRawInput(benthosenv, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_error.RegisterErrorOutput(benthosenv, nil) - assert.NoError(t, err) + require.NoError(t, err) newSB := benthosenv.NewStreamBuilder() // SetYAML parses a full Benthos config and uses it to configure the builder. err = newSB.SetYAML(string(out)) - assert.NoError(t, err) + require.NoError(t, err) } func Test_BenthosBuilder_GenerateBenthosConfigs_CircularDependency_PrimaryKey_Transformer_Pg_Pg(t *testing.T) { @@ -1365,25 +1381,33 @@ func Test_BenthosBuilder_GenerateBenthosConfigs_CircularDependency_PrimaryKey_Tr ColumnName: "parent_id", }, }, nil) - pgquerier.On("GetForeignKeyConstraints", mock.Anything, mock.Anything, mock.Anything). - Return([]*pg_queries.GetForeignKeyConstraintsRow{ + pgquerier.On("GetTableConstraintsBySchema", mock.Anything, mock.Anything, mock.Anything). + Once(). + Return([]*pg_queries.GetTableConstraintsBySchemaRow{ { - ConstraintName: "fk_user_account_associations_user_id_users_id", - SchemaName: "public", - TableName: "jobs", - ColumnName: "parent_id", - ForeignSchemaName: "public", - ForeignTableName: "jobs", - ForeignColumnName: "id", + ConstraintName: "fk_user_account_associations_user_id_users_id", + SchemaName: "public", + TableName: "jobs", + ConstraintColumns: []string{"parent_id"}, + ForeignSchemaName: "public", + ForeignTableName: "jobs", + ForeignColumnNames: []string{"id"}, + ConstraintType: "f", + Notnullable: []bool{false}, }, }, nil) - pgquerier.On("GetPrimaryKeyConstraints", mock.Anything, mock.Anything, mock.Anything). - Return([]*pg_queries.GetPrimaryKeyConstraintsRow{{ - SchemaName: "public", - TableName: "jobs", - ConstraintName: "job", - ColumnName: "id", - }}, nil) + pgquerier.On("GetTableConstraintsBySchema", mock.Anything, mock.Anything, mock.Anything). + Once(). + Return( + []*pg_queries.GetTableConstraintsBySchemaRow{{ + SchemaName: "public", + TableName: "jobs", + ConstraintName: "job", + ConstraintColumns: []string{"id"}, + ConstraintType: "p", + }}, nil, + ) + bbuilder := newBenthosBuilder(*mockSqlAdapter, mockJobClient, mockConnectionClient, mockTransformerClient, mockJobId, mockRunId, redisConfig, false) resp, err := bbuilder.GenerateBenthosConfigs( @@ -1392,17 +1416,17 @@ func Test_BenthosBuilder_GenerateBenthosConfigs_CircularDependency_PrimaryKey_Tr slog.Default(), ) - assert.Nil(t, err) - assert.NotEmpty(t, resp.BenthosConfigs) - assert.Len(t, resp.BenthosConfigs, 2) + require.NoError(t, err) + require.NotEmpty(t, resp.BenthosConfigs) + require.Len(t, resp.BenthosConfigs, 2) bc := getBenthosConfigByName(resp.BenthosConfigs, "public.jobs") - assert.Equal(t, bc.Name, "public.jobs") - assert.Len(t, bc.RedisConfig, 1) - assert.Equal(t, bc.RedisConfig[0].Table, "public.jobs") - assert.Equal(t, bc.RedisConfig[0].Column, "id") + require.Equal(t, bc.Name, "public.jobs") + require.Len(t, bc.RedisConfig, 1) + require.Equal(t, bc.RedisConfig[0].Table, "public.jobs") + require.Equal(t, bc.RedisConfig[0].Column, "id") out, err := yaml.Marshal(bc.Config) - assert.NoError(t, err) - assert.Equal( + require.NoError(t, err) + require.Equal( t, strings.TrimSpace(string(out)), strings.TrimSpace(` @@ -1461,11 +1485,11 @@ output: ) bc = getBenthosConfigByName(resp.BenthosConfigs, "public.jobs.update") - assert.Equal(t, bc.Name, "public.jobs.update") - assert.Empty(t, bc.RedisConfig) + require.Equal(t, bc.Name, "public.jobs.update") + require.Empty(t, bc.RedisConfig) out, err = yaml.Marshal(bc.Config) - assert.NoError(t, err) - assert.Equal( + require.NoError(t, err) + require.Equal( t, strings.TrimSpace(string(out)), strings.TrimSpace(` @@ -1535,20 +1559,20 @@ output: benthosenv := service.NewEnvironment() err = neosync_benthos_sql.RegisterPooledSqlInsertOutput(benthosenv, nil, false) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_sql.RegisterPooledSqlUpdateOutput(benthosenv, nil) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_sql.RegisterPooledSqlRawInput(benthosenv, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_error.RegisterErrorProcessor(benthosenv, nil) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_error.RegisterErrorOutput(benthosenv, nil) - assert.NoError(t, err) + require.NoError(t, err) newSB := benthosenv.NewStreamBuilder() // SetYAML parses a full Benthos config and uses it to configure the builder. err = newSB.SetYAML(string(out)) - assert.NoError(t, err) + require.NoError(t, err) } func Test_BenthosBuilder_GenerateBenthosConfigs_Basic_Pg_Pg_With_Constraints(t *testing.T) { @@ -1684,30 +1708,39 @@ func Test_BenthosBuilder_GenerateBenthosConfigs_Basic_Pg_Pg_With_Constraints(t * ColumnName: "user_id", }, }, nil) - pgquerier.On("GetForeignKeyConstraints", mock.Anything, mock.Anything, mock.Anything). - Return([]*pg_queries.GetForeignKeyConstraintsRow{ + pgquerier.On("GetTableConstraintsBySchema", mock.Anything, mock.Anything, mock.Anything). + Once(). + Return([]*pg_queries.GetTableConstraintsBySchemaRow{ { - ConstraintName: "fk_user_account_associations_user_id_users_id", - SchemaName: "public", - TableName: "user_account_associations", - ColumnName: "user_id", - ForeignSchemaName: "public", - ForeignTableName: "users", - ForeignColumnName: "id", + ConstraintName: "fk_user_account_associations_user_id_users_id", + SchemaName: "public", + TableName: "user_account_associations", + ConstraintColumns: []string{"user_id"}, + ForeignSchemaName: "public", + ForeignTableName: "users", + ForeignColumnNames: []string{"id"}, + ConstraintType: "f", + Notnullable: []bool{true}, }, }, nil) - pgquerier.On("GetPrimaryKeyConstraints", mock.Anything, mock.Anything, mock.Anything). - Return([]*pg_queries.GetPrimaryKeyConstraintsRow{{ - SchemaName: "public", - TableName: "users", - ConstraintName: "name", - ColumnName: "id", - }, { - SchemaName: "public", - TableName: "user_account_associations", - ConstraintName: "acc_assoc_constraint", - ColumnName: "id", - }}, nil) + pgquerier.On("GetTableConstraintsBySchema", mock.Anything, mock.Anything, mock.Anything). + Once(). + Return( + []*pg_queries.GetTableConstraintsBySchemaRow{{ + SchemaName: "public", + TableName: "users", + ConstraintName: "name", + ConstraintColumns: []string{"id"}, + ConstraintType: "p", + }, { + SchemaName: "public", + TableName: "user_account_associations", + ConstraintName: "acc_assoc_constraint", + ConstraintColumns: []string{"id"}, + ConstraintType: "p", + }}, nil, + ) + bbuilder := newBenthosBuilder(*mockSqlAdapter, mockJobClient, mockConnectionClient, mockTransformerClient, mockJobId, mockRunId, nil, false) resp, err := bbuilder.GenerateBenthosConfigs( @@ -1715,38 +1748,38 @@ func Test_BenthosBuilder_GenerateBenthosConfigs_Basic_Pg_Pg_With_Constraints(t * &GenerateBenthosConfigsRequest{JobId: "123", WorkflowId: "123"}, slog.Default(), ) - assert.Nil(t, err) - assert.NotEmpty(t, resp.BenthosConfigs) - assert.Len(t, resp.BenthosConfigs, 2) + require.Nil(t, err) + require.NotEmpty(t, resp.BenthosConfigs) + require.Len(t, resp.BenthosConfigs, 2) bc := getBenthosConfigByName(resp.BenthosConfigs, "public.users") - assert.NotNil(t, bc) - assert.Equal(t, bc.Name, "public.users") - assert.Empty(t, bc.DependsOn) + require.NotNil(t, bc) + require.Equal(t, bc.Name, "public.users") + require.Empty(t, bc.DependsOn) out, err := yaml.Marshal(bc.Config) - assert.NoError(t, err) + require.NoError(t, err) bc2 := getBenthosConfigByName(resp.BenthosConfigs, "public.user_account_associations") - assert.Equal(t, bc2.Name, "public.user_account_associations") - assert.Equal(t, bc2.DependsOn, []*tabledependency.DependsOn{{Table: "public.users", Columns: []string{"id"}}}) + require.Equal(t, bc2.Name, "public.user_account_associations") + require.Equal(t, bc2.DependsOn, []*tabledependency.DependsOn{{Table: "public.users", Columns: []string{"id"}}}) out2, err := yaml.Marshal(bc2.Config) - assert.NoError(t, err) + require.NoError(t, err) benthosenv := service.NewEnvironment() err = neosync_benthos_sql.RegisterPooledSqlInsertOutput(benthosenv, nil, false) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_sql.RegisterPooledSqlRawInput(benthosenv, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_error.RegisterErrorOutput(benthosenv, nil) - assert.NoError(t, err) + require.NoError(t, err) newSB := benthosenv.NewStreamBuilder() // SetYAML parses a full Benthos config and uses it to configure the builder. err = newSB.SetYAML(string(out)) - assert.NoError(t, err) + require.NoError(t, err) err = newSB.SetYAML(string(out2)) - assert.NoError(t, err) + require.NoError(t, err) } func Test_BenthosBuilder_GenerateBenthosConfigs_Basic_Pg_Pg_With_Circular_Dependency(t *testing.T) { @@ -1895,43 +1928,49 @@ func Test_BenthosBuilder_GenerateBenthosConfigs_Basic_Pg_Pg_With_Circular_Depend ColumnName: "user_id", }, }, nil) - pgquerier.On("GetForeignKeyConstraints", mock.Anything, mock.Anything, mock.Anything). - Return([]*pg_queries.GetForeignKeyConstraintsRow{ + pgquerier.On("GetTableConstraintsBySchema", mock.Anything, mock.Anything, mock.Anything). + Once(). + Return([]*pg_queries.GetTableConstraintsBySchemaRow{ { - ConstraintName: "fk_user_account_associations_user_id_users_id", - SchemaName: "public", - TableName: "user_account_associations", - ColumnName: "user_id", - ForeignSchemaName: "public", - ForeignTableName: "users", - ForeignColumnName: "id", - IsNullable: "NO", + ConstraintName: "fk_user_account_associations_user_id_users_id", + SchemaName: "public", + TableName: "user_account_associations", + ConstraintColumns: []string{"user_id"}, + ForeignSchemaName: "public", + ForeignTableName: "users", + ForeignColumnNames: []string{"id"}, + ConstraintType: "f", + Notnullable: []bool{true}, }, { - ConstraintName: "fk_users_user_assoc_id_user_account_associations_id", - SchemaName: "public", - TableName: "users", - ColumnName: "user_assoc_id", - ForeignSchemaName: "public", - ForeignTableName: "user_account_associations", - ForeignColumnName: "id", - IsNullable: "YES", + ConstraintName: "fk_users_user_assoc_id_user_account_associations_id", + SchemaName: "public", + TableName: "users", + ConstraintColumns: []string{"user_assoc_id"}, + ForeignSchemaName: "public", + ForeignTableName: "user_account_associations", + ForeignColumnNames: []string{"id"}, + ConstraintType: "f", + Notnullable: []bool{false}, }, }, nil) - pgquerier.On("GetPrimaryKeyConstraints", mock.Anything, mock.Anything, mock.Anything).Return([]*pg_queries.GetPrimaryKeyConstraintsRow{ - { - ConstraintName: "pkey-user-id", - SchemaName: "public", - TableName: "users", - ColumnName: "id", - }, - { - ConstraintName: "pkey-user-assoc-id", - SchemaName: "public", - TableName: "users_account_associations", - ColumnName: "id", - }, - }, nil) + pgquerier.On("GetTableConstraintsBySchema", mock.Anything, mock.Anything, mock.Anything). + Once(). + Return( + []*pg_queries.GetTableConstraintsBySchemaRow{{ + SchemaName: "public", + TableName: "users", + ConstraintName: "pk-user-id", + ConstraintColumns: []string{"id"}, + ConstraintType: "p", + }, { + SchemaName: "public", + TableName: "users_account_associations", + ConstraintName: "pk-user-assoc-id", + ConstraintColumns: []string{"id"}, + ConstraintType: "p", + }}, nil, + ) bbuilder := newBenthosBuilder(*mockSqlAdapter, mockJobClient, mockConnectionClient, mockTransformerClient, mockJobId, mockRunId, nil, false) resp, err := bbuilder.GenerateBenthosConfigs( @@ -1939,17 +1978,17 @@ func Test_BenthosBuilder_GenerateBenthosConfigs_Basic_Pg_Pg_With_Circular_Depend &GenerateBenthosConfigsRequest{JobId: "123", WorkflowId: "123"}, slog.Default(), ) - assert.Nil(t, err) - assert.NotEmpty(t, resp.BenthosConfigs) - assert.Len(t, resp.BenthosConfigs, 3) + require.Nil(t, err) + require.NotEmpty(t, resp.BenthosConfigs) + require.Len(t, resp.BenthosConfigs, 3) insertConfig := getBenthosConfigByName(resp.BenthosConfigs, "public.users") - assert.NotNil(t, insertConfig) - assert.Equal(t, insertConfig.Name, "public.users") - assert.Empty(t, insertConfig.DependsOn) + require.NotNil(t, insertConfig) + require.Equal(t, insertConfig.Name, "public.users") + require.Empty(t, insertConfig.DependsOn) out, err := yaml.Marshal(insertConfig.Config) - assert.NoError(t, err) - assert.Equal( + require.NoError(t, err) + require.Equal( t, strings.TrimSpace(string(out)), strings.TrimSpace(` @@ -1997,12 +2036,12 @@ output: ) updateConfig := getBenthosConfigByName(resp.BenthosConfigs, "public.users.update") - assert.NotNil(t, updateConfig) - assert.Equal(t, updateConfig.Name, "public.users.update") - assert.Equal(t, updateConfig.DependsOn, []*tabledependency.DependsOn{{Table: "public.user_account_associations", Columns: []string{"id"}}, {Table: "public.users", Columns: []string{"id"}}}) + require.NotNil(t, updateConfig) + require.Equal(t, updateConfig.Name, "public.users.update") + require.Equal(t, updateConfig.DependsOn, []*tabledependency.DependsOn{{Table: "public.user_account_associations", Columns: []string{"id"}}, {Table: "public.users", Columns: []string{"id"}}}) out1, err := yaml.Marshal(updateConfig.Config) - assert.NoError(t, err) - assert.Equal( + require.NoError(t, err) + require.Equal( t, strings.TrimSpace(string(out1)), strings.TrimSpace(` @@ -2049,11 +2088,11 @@ output: ) bc2 := getBenthosConfigByName(resp.BenthosConfigs, "public.user_account_associations") - assert.Equal(t, bc2.Name, "public.user_account_associations") - assert.Equal(t, bc2.DependsOn, []*tabledependency.DependsOn{{Table: "public.users", Columns: []string{"id"}}}) + require.Equal(t, bc2.Name, "public.user_account_associations") + require.Equal(t, bc2.DependsOn, []*tabledependency.DependsOn{{Table: "public.users", Columns: []string{"id"}}}) out2, err := yaml.Marshal(bc2.Config) - assert.NoError(t, err) - assert.Equal( + require.NoError(t, err) + require.Equal( t, strings.TrimSpace(string(out2)), strings.TrimSpace(` @@ -2102,21 +2141,21 @@ output: benthosenv := service.NewEnvironment() err = neosync_benthos_sql.RegisterPooledSqlInsertOutput(benthosenv, nil, false) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_sql.RegisterPooledSqlUpdateOutput(benthosenv, nil) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_sql.RegisterPooledSqlRawInput(benthosenv, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_error.RegisterErrorOutput(benthosenv, nil) - assert.NoError(t, err) + require.NoError(t, err) newSB := benthosenv.NewStreamBuilder() // SetYAML parses a full Benthos config and uses it to configure the builder. err = newSB.SetYAML(string(out)) - assert.NoError(t, err) + require.NoError(t, err) err = newSB.SetYAML(string(out2)) - assert.NoError(t, err) + require.NoError(t, err) } func Test_BenthosBuilder_GenerateBenthosConfigs_Basic_Pg_Pg_With_Circular_Dependency_S3(t *testing.T) { @@ -2295,43 +2334,50 @@ func Test_BenthosBuilder_GenerateBenthosConfigs_Basic_Pg_Pg_With_Circular_Depend ColumnName: "user_id", }, }, nil) - pgquerier.On("GetForeignKeyConstraints", mock.Anything, mock.Anything, mock.Anything). - Return([]*pg_queries.GetForeignKeyConstraintsRow{ + pgquerier.On("GetTableConstraintsBySchema", mock.Anything, mock.Anything, mock.Anything). + Once(). + Return([]*pg_queries.GetTableConstraintsBySchemaRow{ { - ConstraintName: "fk_user_account_associations_user_id_users_id", - SchemaName: "public", - TableName: "user_account_associations", - ColumnName: "user_id", - ForeignSchemaName: "public", - ForeignTableName: "users", - ForeignColumnName: "id", - IsNullable: "NO", + ConstraintName: "fk_user_account_associations_user_id_users_id", + SchemaName: "public", + TableName: "user_account_associations", + ConstraintColumns: []string{"user_id"}, + ForeignSchemaName: "public", + ForeignTableName: "users", + ForeignColumnNames: []string{"id"}, + ConstraintType: "f", + Notnullable: []bool{true}, }, { - ConstraintName: "fk_users_user_assoc_id_user_account_associations_id", - SchemaName: "public", - TableName: "users", - ColumnName: "user_assoc_id", - ForeignSchemaName: "public", - ForeignTableName: "user_account_associations", - ForeignColumnName: "id", - IsNullable: "YES", + ConstraintName: "fk_users_user_assoc_id_user_account_associations_id", + SchemaName: "public", + TableName: "users", + ConstraintColumns: []string{"user_assoc_id"}, + ForeignSchemaName: "public", + ForeignTableName: "user_account_associations", + ForeignColumnNames: []string{"id"}, + ConstraintType: "f", + Notnullable: []bool{false}, }, }, nil) - pgquerier.On("GetPrimaryKeyConstraints", mock.Anything, mock.Anything, mock.Anything).Return([]*pg_queries.GetPrimaryKeyConstraintsRow{ - { - ConstraintName: "pkey-user-id", - SchemaName: "public", - TableName: "users", - ColumnName: "id", - }, - { - ConstraintName: "pkey-user-assoc-id", - SchemaName: "public", - TableName: "users_account_associations", - ColumnName: "id", - }, - }, nil) + pgquerier.On("GetTableConstraintsBySchema", mock.Anything, mock.Anything, mock.Anything). + Once(). + Return( + []*pg_queries.GetTableConstraintsBySchemaRow{{ + SchemaName: "public", + TableName: "users", + ConstraintName: "pk-user-id", + ConstraintColumns: []string{"id"}, + ConstraintType: "p", + }, { + SchemaName: "public", + TableName: "users_account_associations", + ConstraintName: "pk-user-assoc-id", + ConstraintColumns: []string{"id"}, + ConstraintType: "p", + }}, nil, + ) + bbuilder := newBenthosBuilder(*mockSqlAdapter, mockJobClient, mockConnectionClient, mockTransformerClient, mockJobId, mockRunId, nil, false) resp, err := bbuilder.GenerateBenthosConfigs( @@ -2339,18 +2385,18 @@ func Test_BenthosBuilder_GenerateBenthosConfigs_Basic_Pg_Pg_With_Circular_Depend &GenerateBenthosConfigsRequest{JobId: "123", WorkflowId: "123"}, slog.Default(), ) - assert.Nil(t, err) - assert.NotEmpty(t, resp.BenthosConfigs) - assert.Len(t, resp.BenthosConfigs, 3) + require.Nil(t, err) + require.NotEmpty(t, resp.BenthosConfigs) + require.Len(t, resp.BenthosConfigs, 3) insertConfig := getBenthosConfigByName(resp.BenthosConfigs, "public.users") - assert.NotNil(t, insertConfig) - assert.Equal(t, insertConfig.Name, "public.users") - assert.Empty(t, insertConfig.DependsOn) + require.NotNil(t, insertConfig) + require.Equal(t, insertConfig.Name, "public.users") + require.Empty(t, insertConfig.DependsOn) out, err := yaml.Marshal(insertConfig.Config) - assert.NoError(t, err) - assert.Equal( + require.NoError(t, err) + require.Equal( t, strings.TrimSpace(string(out)), strings.TrimSpace(` @@ -2425,13 +2471,13 @@ output: ) updateConfig := getBenthosConfigByName(resp.BenthosConfigs, "public.users.update") - assert.NotNil(t, updateConfig) - assert.Equal(t, updateConfig.Name, "public.users.update") - assert.Equal(t, updateConfig.DependsOn, []*tabledependency.DependsOn{{Table: "public.user_account_associations", Columns: []string{"id"}}, {Table: "public.users", Columns: []string{"id"}}}) + require.NotNil(t, updateConfig) + require.Equal(t, updateConfig.Name, "public.users.update") + require.Equal(t, updateConfig.DependsOn, []*tabledependency.DependsOn{{Table: "public.user_account_associations", Columns: []string{"id"}}, {Table: "public.users", Columns: []string{"id"}}}) out1, err := yaml.Marshal(updateConfig.Config) - assert.NoError(t, err) - assert.Equal( + require.NoError(t, err) + require.Equal( t, strings.TrimSpace(string(out1)), strings.TrimSpace(` @@ -2478,12 +2524,12 @@ output: ) bc2 := getBenthosConfigByName(resp.BenthosConfigs, "public.user_account_associations") - assert.Equal(t, bc2.Name, "public.user_account_associations") - assert.Equal(t, bc2.DependsOn, []*tabledependency.DependsOn{{Table: "public.users", Columns: []string{"id"}}}) + require.Equal(t, bc2.Name, "public.user_account_associations") + require.Equal(t, bc2.DependsOn, []*tabledependency.DependsOn{{Table: "public.users", Columns: []string{"id"}}}) out2, err := yaml.Marshal(bc2.Config) - assert.NoError(t, err) - assert.Equal( + require.NoError(t, err) + require.Equal( t, strings.TrimSpace(string(out2)), strings.TrimSpace(` @@ -2559,19 +2605,19 @@ output: benthosenv := service.NewEnvironment() err = neosync_benthos_sql.RegisterPooledSqlInsertOutput(benthosenv, nil, false) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_sql.RegisterPooledSqlRawInput(benthosenv, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_error.RegisterErrorOutput(benthosenv, nil) - assert.NoError(t, err) + require.NoError(t, err) newSB := benthosenv.NewStreamBuilder() // SetYAML parses a full Benthos config and uses it to configure the builder. err = newSB.SetYAML(string(out)) - assert.NoError(t, err) + require.NoError(t, err) err = newSB.SetYAML(string(out2)) - assert.NoError(t, err) + require.NoError(t, err) } func Test_BenthosBuilder_GenerateBenthosConfigs_Basic_Mysql_Mysql(t *testing.T) { @@ -2738,16 +2784,16 @@ func Test_BenthosBuilder_GenerateBenthosConfigs_Basic_Mysql_Mysql(t *testing.T) &GenerateBenthosConfigsRequest{JobId: "123", WorkflowId: "123"}, slog.Default(), ) - assert.Nil(t, err) - assert.NotEmpty(t, resp.BenthosConfigs) - assert.Len(t, resp.BenthosConfigs, 2) + require.Nil(t, err) + require.NotEmpty(t, resp.BenthosConfigs) + require.Len(t, resp.BenthosConfigs, 2) bc := getBenthosConfigByName(resp.BenthosConfigs, "public.users") - assert.Equal(t, bc.Name, "public.users") - assert.Empty(t, bc.DependsOn) + require.Equal(t, bc.Name, "public.users") + require.Empty(t, bc.DependsOn) out, err := yaml.Marshal(bc.Config) - assert.NoError(t, err) - assert.Equal( + require.NoError(t, err) + require.Equal( t, strings.TrimSpace(string(out)), strings.TrimSpace(` @@ -2795,11 +2841,11 @@ output: ) bc2 := getBenthosConfigByName(resp.BenthosConfigs, "public.user_account_associations") - assert.Equal(t, bc2.Name, "public.user_account_associations") - assert.Equal(t, bc2.DependsOn, []*tabledependency.DependsOn{{Table: "public.users", Columns: []string{"id"}}}) + require.Equal(t, bc2.Name, "public.user_account_associations") + require.Equal(t, bc2.DependsOn, []*tabledependency.DependsOn{{Table: "public.users", Columns: []string{"id"}}}) out2, err := yaml.Marshal(bc2.Config) - assert.NoError(t, err) - assert.Equal( + require.NoError(t, err) + require.Equal( t, strings.TrimSpace(string(out2)), strings.TrimSpace(` @@ -2848,20 +2894,20 @@ output: benthosenv := service.NewEnvironment() err = neosync_benthos_sql.RegisterPooledSqlInsertOutput(benthosenv, nil, false) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_sql.RegisterPooledSqlRawInput(benthosenv, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_error.RegisterErrorOutput(benthosenv, nil) - assert.NoError(t, err) + require.NoError(t, err) newSB := benthosenv.NewStreamBuilder() // SetYAML parses a full Benthos config and uses it to configure the builder. err = newSB.SetYAML(string(out)) - assert.NoError(t, err) + require.NoError(t, err) // SetYAML parses a full Benthos config and uses it to configure the builder. err = newSB.SetYAML(string(out2)) - assert.NoError(t, err) + require.NoError(t, err) } func Test_BenthosBuilder_GenerateBenthosConfigs_Basic_Mysql_Mysql_With_Circular_Dependency(t *testing.T) { @@ -3055,50 +3101,50 @@ func Test_BenthosBuilder_GenerateBenthosConfigs_Basic_Mysql_Mysql_With_Circular_ &GenerateBenthosConfigsRequest{JobId: "123", WorkflowId: "123"}, slog.Default(), ) - assert.Nil(t, err) - assert.NotEmpty(t, resp.BenthosConfigs) - assert.Len(t, resp.BenthosConfigs, 3) + require.Nil(t, err) + require.NotEmpty(t, resp.BenthosConfigs) + require.Len(t, resp.BenthosConfigs, 3) insertConfig := getBenthosConfigByName(resp.BenthosConfigs, "public.users") - assert.NotNil(t, insertConfig) - assert.Equal(t, insertConfig.Name, "public.users") - assert.Empty(t, insertConfig.DependsOn) + require.NotNil(t, insertConfig) + require.Equal(t, insertConfig.Name, "public.users") + require.Empty(t, insertConfig.DependsOn) out, err := yaml.Marshal(insertConfig.Config) - assert.NoError(t, err) + require.NoError(t, err) updateConfig := getBenthosConfigByName(resp.BenthosConfigs, "public.users.update") - assert.NotNil(t, updateConfig) - assert.Equal(t, updateConfig.Name, "public.users.update") - assert.Equal(t, updateConfig.DependsOn, []*tabledependency.DependsOn{{Table: "public.user_account_associations", Columns: []string{"id"}}, {Table: "public.users", Columns: []string{"id"}}}) + require.NotNil(t, updateConfig) + require.Equal(t, updateConfig.Name, "public.users.update") + require.Equal(t, updateConfig.DependsOn, []*tabledependency.DependsOn{{Table: "public.user_account_associations", Columns: []string{"id"}}, {Table: "public.users", Columns: []string{"id"}}}) out1, err := yaml.Marshal(updateConfig.Config) - assert.NoError(t, err) + require.NoError(t, err) bc2 := getBenthosConfigByName(resp.BenthosConfigs, "public.user_account_associations") - assert.Equal(t, bc2.Name, "public.user_account_associations") - assert.Equal(t, bc2.DependsOn, []*tabledependency.DependsOn{{Table: "public.users", Columns: []string{"id"}}}) + require.Equal(t, bc2.Name, "public.user_account_associations") + require.Equal(t, bc2.DependsOn, []*tabledependency.DependsOn{{Table: "public.users", Columns: []string{"id"}}}) out2, err := yaml.Marshal(bc2.Config) - assert.NoError(t, err) + require.NoError(t, err) benthosenv := service.NewEnvironment() err = neosync_benthos_sql.RegisterPooledSqlInsertOutput(benthosenv, nil, false) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_sql.RegisterPooledSqlUpdateOutput(benthosenv, nil) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_sql.RegisterPooledSqlRawInput(benthosenv, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) err = neosync_benthos_error.RegisterErrorOutput(benthosenv, nil) - assert.NoError(t, err) + require.NoError(t, err) newSB := benthosenv.NewStreamBuilder() // SetYAML parses a full Benthos config and uses it to configure the builder. err = newSB.SetYAML(string(out)) - assert.NoError(t, err) + require.NoError(t, err) err = newSB.SetYAML(string(out1)) - assert.NoError(t, err) + require.NoError(t, err) err = newSB.SetYAML(string(out2)) - assert.NoError(t, err) + require.NoError(t, err) } func getBenthosConfigByName(resps []*BenthosConfigResponse, name string) *BenthosConfigResponse { @@ -3169,8 +3215,8 @@ func Test_ProcessorConfigEmpty(t *testing.T) { mockRunId, nil, ) - assert.Nil(t, err) - assert.Empty(t, res[0].Config.StreamConfig.Pipeline.Processors) + require.Nil(t, err) + require.Empty(t, res[0].Config.StreamConfig.Pipeline.Processors) } func Test_buildBenthosSqlSourceConfigResponses_skipTable(t *testing.T) { @@ -3229,8 +3275,8 @@ func Test_buildBenthosSqlSourceConfigResponses_skipTable(t *testing.T) { mockRunId, nil, ) - assert.Nil(t, err) - assert.Len(t, res, 0) + require.Nil(t, err) + require.Len(t, res, 0) } func Test_ProcessorConfigEmptyJavascript(t *testing.T) { mockTransformerClient := mgmtv1alpha1connect.NewMockTransformersServiceClient(t) @@ -3295,8 +3341,8 @@ func Test_ProcessorConfigEmptyJavascript(t *testing.T) { mockRunId, nil, ) - assert.Nil(t, err) - assert.Empty(t, res[0].Config.StreamConfig.Pipeline.Processors) + require.Nil(t, err) + require.Empty(t, res[0].Config.StreamConfig.Pipeline.Processors) } func Test_ProcessorConfigMultiJavascript(t *testing.T) { @@ -3365,11 +3411,11 @@ func Test_ProcessorConfigMultiJavascript(t *testing.T) { mockRunId, nil, ) - assert.Nil(t, err) + require.Nil(t, err) out, err := yaml.Marshal(res[0].Config.Pipeline.Processors) - assert.NoError(t, err) - assert.Equal( + require.NoError(t, err) + require.Equal( t, strings.TrimSpace(` - javascript: @@ -3467,13 +3513,13 @@ func Test_ProcessorConfigMutationAndJavascript(t *testing.T) { nil, ) - assert.Nil(t, err) + require.Nil(t, err) - assert.Len(t, res[0].Config.Pipeline.Processors, 3) + require.Len(t, res[0].Config.Pipeline.Processors, 3) out, err := yaml.Marshal(res[0].Config.Pipeline.Processors) - assert.NoError(t, err) - assert.Equal( + require.NoError(t, err) + require.Equal( t, strings.TrimSpace(` - mutation: root."email" = generate_email(max_length:40) @@ -3513,7 +3559,7 @@ func TestAreMappingsSubsetOfSchemas(t *testing.T) { {Schema: "public", Table: "users", Column: "created_by"}, }, ) - assert.True(t, ok, "job mappings are a subset of the present database schemas") + require.True(t, ok, "job mappings are a subset of the present database schemas") ok = areMappingsSubsetOfSchemas( map[string]map[string]*dbschemas_utils.ColumnInfo{ @@ -3525,7 +3571,7 @@ func TestAreMappingsSubsetOfSchemas(t *testing.T) { {Schema: "public", Table: "users", Column: "id2"}, }, ) - assert.False(t, ok, "job mappings contain mapping that is not in the source schema") + require.False(t, ok, "job mappings contain mapping that is not in the source schema") ok = areMappingsSubsetOfSchemas( map[string]map[string]*dbschemas_utils.ColumnInfo{ @@ -3538,7 +3584,7 @@ func TestAreMappingsSubsetOfSchemas(t *testing.T) { {Schema: "public", Table: "users", Column: "created_by"}, }, ) - assert.False(t, ok, "job mappings contain more mappings than are present in the source schema") + require.False(t, ok, "job mappings contain more mappings than are present in the source schema") } func TestShouldHaltOnSchemaAddition(t *testing.T) { @@ -3554,7 +3600,7 @@ func TestShouldHaltOnSchemaAddition(t *testing.T) { {Schema: "public", Table: "users", Column: "created_by"}, }, ) - assert.False(t, ok, "job mappings are valid set of database schemas") + require.False(t, ok, "job mappings are valid set of database schemas") ok = shouldHaltOnSchemaAddition( map[string]map[string]*dbschemas_utils.ColumnInfo{ @@ -3571,7 +3617,7 @@ func TestShouldHaltOnSchemaAddition(t *testing.T) { {Schema: "public", Table: "users", Column: "created_by"}, }, ) - assert.True(t, ok, "job mappings are missing database schema mappings") + require.True(t, ok, "job mappings are missing database schema mappings") ok = shouldHaltOnSchemaAddition( map[string]map[string]*dbschemas_utils.ColumnInfo{ @@ -3584,7 +3630,7 @@ func TestShouldHaltOnSchemaAddition(t *testing.T) { {Schema: "public", Table: "users", Column: "id"}, }, ) - assert.True(t, ok, "job mappings are missing table column") + require.True(t, ok, "job mappings are missing table column") ok = shouldHaltOnSchemaAddition( map[string]map[string]*dbschemas_utils.ColumnInfo{ @@ -3598,7 +3644,7 @@ func TestShouldHaltOnSchemaAddition(t *testing.T) { {Schema: "public", Table: "users", Column: "updated_by"}, }, ) - assert.True(t, ok, "job mappings have same column count, but missing specific column") + require.True(t, ok, "job mappings have same column count, but missing specific column") } func Test_buildProcessorConfigsMutation(t *testing.T) { @@ -3607,30 +3653,30 @@ func Test_buildProcessorConfigsMutation(t *testing.T) { ctx := context.Background() output, err := buildProcessorConfigs(ctx, mockTransformerClient, []*mgmtv1alpha1.JobMapping{}, map[string]*dbschemas_utils.ColumnInfo{}, map[string]*dbschemas_utils.ForeignKey{}, []string{}, mockJobId, mockRunId, nil) - assert.Nil(t, err) - assert.Empty(t, output) + require.Nil(t, err) + require.Empty(t, output) output, err = buildProcessorConfigs(ctx, mockTransformerClient, []*mgmtv1alpha1.JobMapping{}, map[string]*dbschemas_utils.ColumnInfo{}, map[string]*dbschemas_utils.ForeignKey{}, []string{}, mockJobId, mockRunId, nil) - assert.Nil(t, err) - assert.Empty(t, output) + require.Nil(t, err) + require.Empty(t, output) output, err = buildProcessorConfigs(ctx, mockTransformerClient, []*mgmtv1alpha1.JobMapping{ {Schema: "public", Table: "users", Column: "id"}, }, map[string]*dbschemas_utils.ColumnInfo{}, map[string]*dbschemas_utils.ForeignKey{}, []string{}, mockJobId, mockRunId, nil) - assert.Nil(t, err) - assert.Empty(t, output) + require.Nil(t, err) + require.Empty(t, output) output, err = buildProcessorConfigs(ctx, mockTransformerClient, []*mgmtv1alpha1.JobMapping{ {Schema: "public", Table: "users", Column: "id", Transformer: &mgmtv1alpha1.JobMappingTransformer{}}, }, map[string]*dbschemas_utils.ColumnInfo{}, map[string]*dbschemas_utils.ForeignKey{}, []string{}, mockJobId, mockRunId, nil) - assert.Nil(t, err) - assert.Empty(t, output) + require.Nil(t, err) + require.Empty(t, output) output, err = buildProcessorConfigs(ctx, mockTransformerClient, []*mgmtv1alpha1.JobMapping{ {Schema: "public", Table: "users", Column: "id", Transformer: &mgmtv1alpha1.JobMappingTransformer{Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_PASSTHROUGH}}, }, map[string]*dbschemas_utils.ColumnInfo{}, map[string]*dbschemas_utils.ForeignKey{}, []string{}, mockJobId, mockRunId, nil) - assert.Nil(t, err) - assert.Empty(t, output) + require.Nil(t, err) + require.Empty(t, output) output, err = buildProcessorConfigs(ctx, mockTransformerClient, []*mgmtv1alpha1.JobMapping{ {Schema: "public", Table: "users", Column: "id", Transformer: &mgmtv1alpha1.JobMappingTransformer{Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_NULL, Config: &mgmtv1alpha1.TransformerConfig{ @@ -3645,9 +3691,9 @@ func Test_buildProcessorConfigsMutation(t *testing.T) { }}}, }, map[string]*dbschemas_utils.ColumnInfo{}, map[string]*dbschemas_utils.ForeignKey{}, []string{}, mockJobId, mockRunId, nil) - assert.Nil(t, err) + require.Nil(t, err) - assert.Equal(t, *output[0].Mutation, "root.\"id\" = null\nroot.\"name\" = null") + require.Equal(t, *output[0].Mutation, "root.\"id\" = null\nroot.\"name\" = null") jsT := mgmtv1alpha1.SystemTransformer{ Name: "stage", @@ -3683,8 +3729,8 @@ func Test_buildProcessorConfigsMutation(t *testing.T) { output, err = buildProcessorConfigs(ctx, mockTransformerClient, []*mgmtv1alpha1.JobMapping{ {Schema: "public", Table: "users", Column: "email", Transformer: &mgmtv1alpha1.JobMappingTransformer{Source: jsT.Source, Config: jsT.Config}}}, groupedSchemas, map[string]*dbschemas_utils.ForeignKey{}, []string{}, mockJobId, mockRunId, nil) - assert.Nil(t, err) - assert.Equal(t, *output[0].Mutation, `root."email" = transform_email(email:this."email",preserve_domain:true,preserve_length:false,excluded_domains:[],max_length:40)`) + require.Nil(t, err) + require.Equal(t, *output[0].Mutation, `root."email" = transform_email(email:this."email",preserve_domain:true,preserve_length:false,excluded_domains:[],max_length:40)`) } const transformJsCodeFnStr = `var payload = value+=" hello";return payload;` @@ -3712,8 +3758,8 @@ func Test_buildProcessorConfigsJavascript(t *testing.T) { res, err := buildProcessorConfigs(ctx, mockTransformerClient, []*mgmtv1alpha1.JobMapping{ {Schema: "public", Table: "users", Column: "address", Transformer: &mgmtv1alpha1.JobMappingTransformer{Source: jsT.Source, Config: jsT.Config}}}, map[string]*dbschemas_utils.ColumnInfo{}, map[string]*dbschemas_utils.ForeignKey{}, []string{}, mockJobId, mockRunId, nil) - assert.NoError(t, err) - assert.Equal(t, ` + require.NoError(t, err) + require.Equal(t, ` (() => { function fn_address(value, input){ @@ -3752,8 +3798,8 @@ func Test_buildProcessorConfigsGenerateJavascript(t *testing.T) { res, err := buildProcessorConfigs(ctx, mockTransformerClient, []*mgmtv1alpha1.JobMapping{ {Schema: "public", Table: "users", Column: "test", Transformer: &mgmtv1alpha1.JobMappingTransformer{Source: jsT.Source, Config: jsT.Config}}}, map[string]*dbschemas_utils.ColumnInfo{}, map[string]*dbschemas_utils.ForeignKey{}, []string{}, mockJobId, mockRunId, nil) - assert.NoError(t, err) - assert.Equal(t, ` + require.NoError(t, err) + require.Equal(t, ` (() => { function fn_test(){ @@ -3798,8 +3844,8 @@ func Test_buildProcessorConfigsJavascriptMultiLineScript(t *testing.T) { res, err := buildProcessorConfigs(ctx, mockTransformerClient, []*mgmtv1alpha1.JobMapping{ {Schema: "public", Table: "users", Column: nameCol, Transformer: &mgmtv1alpha1.JobMappingTransformer{Source: jsT.Source, Config: jsT.Config}}}, map[string]*dbschemas_utils.ColumnInfo{}, map[string]*dbschemas_utils.ForeignKey{}, []string{}, mockJobId, mockRunId, nil) - assert.NoError(t, err) - assert.Equal(t, ` + require.NoError(t, err) + require.Equal(t, ` (() => { function fn_name(value, input){ @@ -3856,8 +3902,8 @@ func Test_buildProcessorConfigsJavascriptMultiple(t *testing.T) { {Schema: "public", Table: "users", Column: nameCol, Transformer: &mgmtv1alpha1.JobMappingTransformer{Source: jsT.Source, Config: jsT.Config}}, {Schema: "public", Table: "users", Column: col2, Transformer: &mgmtv1alpha1.JobMappingTransformer{Source: jsT2.Source, Config: jsT2.Config}}}, map[string]*dbschemas_utils.ColumnInfo{}, map[string]*dbschemas_utils.ForeignKey{}, []string{}, mockJobId, mockRunId, nil) - assert.NoError(t, err) - assert.Equal(t, ` + require.NoError(t, err) + require.Equal(t, ` (() => { function fn_name(value, input){ @@ -3918,8 +3964,8 @@ func Test_buildProcessorConfigsTransformAndGenerateJavascript(t *testing.T) { {Schema: "public", Table: "users", Column: nameCol, Transformer: &mgmtv1alpha1.JobMappingTransformer{Source: jsT.Source, Config: jsT.Config}}, {Schema: "public", Table: "users", Column: col2, Transformer: &mgmtv1alpha1.JobMappingTransformer{Source: jsT2.Source, Config: jsT2.Config}}}, map[string]*dbschemas_utils.ColumnInfo{}, map[string]*dbschemas_utils.ForeignKey{}, []string{}, mockJobId, mockRunId, nil) - assert.NoError(t, err) - assert.Equal(t, ` + require.NoError(t, err) + require.Equal(t, ` (() => { function fn_name(value, input){ @@ -3952,7 +3998,7 @@ func Test_ShouldProcessColumnTrue(t *testing.T) { } res := shouldProcessColumn(val) - assert.Equal(t, true, res) + require.Equal(t, true, res) } func Test_ShouldProcessColumnFalse(t *testing.T) { @@ -3966,14 +4012,14 @@ func Test_ShouldProcessColumnFalse(t *testing.T) { } res := shouldProcessColumn(val) - assert.Equal(t, false, res) + require.Equal(t, false, res) } func Test_ConstructJsFunctionTransformJs(t *testing.T) { s := mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_JAVASCRIPT res := constructJsFunction(transformJsCodeFnStr, "col", s) - assert.Equal(t, ` + require.Equal(t, ` function fn_col(value, input){ var payload = value+=" hello";return payload; }; @@ -3984,7 +4030,7 @@ func Test_ConstructJsFunctionGenerateJS(t *testing.T) { s := mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_JAVASCRIPT res := constructJsFunction(generateJSCodeFnStr, "col", s) - assert.Equal(t, ` + require.Equal(t, ` function fn_col(){ var payload = "hello";return payload; }; @@ -4004,7 +4050,7 @@ func Test_ConstructBenthosJsProcessorTransformJS(t *testing.T) { res := constructBenthosJsProcessor(jsFunctions, benthosOutputs) - assert.Equal(t, ` + require.Equal(t, ` (() => { function fn_name(value, input){ @@ -4031,7 +4077,7 @@ func Test_ConstructBenthosJsProcessorGenerateJS(t *testing.T) { res := constructBenthosJsProcessor(jsFunctions, benthosOutputs) - assert.Equal(t, ` + require.Equal(t, ` (() => { function fn_name(){ @@ -4048,13 +4094,13 @@ benthos.v0_msg_set_structured(output); func Test_ConstructBenthosOutputTranformJs(t *testing.T) { s := mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_TRANSFORM_JAVASCRIPT res := constructBenthosJavascriptObject("col", s) - assert.Equal(t, `output["col"] = fn_col(input["col"], input);`, res) + require.Equal(t, `output["col"] = fn_col(input["col"], input);`, res) } func Test_ConstructBenthosOutputGenerateJs(t *testing.T) { s := mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_JAVASCRIPT res := constructBenthosJavascriptObject("col", s) - assert.Equal(t, `output["col"] = fn_col();`, res) + require.Equal(t, `output["col"] = fn_col();`, res) } func Test_buildProcessorConfigsJavascriptEmpty(t *testing.T) { @@ -4078,8 +4124,8 @@ func Test_buildProcessorConfigsJavascriptEmpty(t *testing.T) { resp, err := buildProcessorConfigs(ctx, mockTransformerClient, []*mgmtv1alpha1.JobMapping{ {Schema: "public", Table: "users", Column: "id", Transformer: &mgmtv1alpha1.JobMappingTransformer{Source: jsT.Source, Config: jsT.Config}}}, map[string]*dbschemas_utils.ColumnInfo{}, map[string]*dbschemas_utils.ForeignKey{}, []string{}, mockJobId, mockRunId, nil) - assert.NoError(t, err) - assert.Empty(t, resp) + require.NoError(t, err) + require.Empty(t, resp) } func Test_convertUserDefinedFunctionConfig(t *testing.T) { @@ -4137,8 +4183,8 @@ func Test_convertUserDefinedFunctionConfig(t *testing.T) { } resp, err := convertUserDefinedFunctionConfig(ctx, mockTransformerClient, jmt) - assert.NoError(t, err) - assert.Equal(t, resp, expected) + require.NoError(t, err) + require.Equal(t, resp, expected) } func MockJobMappingTransformer(source int32, transformerId string) db_queries.NeosyncApiTransformer { @@ -4149,15 +4195,15 @@ func MockJobMappingTransformer(source int32, transformerId string) db_queries.Ne } func Test_buildPlainInsertArgs(t *testing.T) { - assert.Empty(t, buildPlainInsertArgs(nil)) - assert.Empty(t, buildPlainInsertArgs([]string{})) - assert.Equal(t, buildPlainInsertArgs([]string{"foo", "bar", "baz"}), `root = [this."foo", this."bar", this."baz"]`) + require.Empty(t, buildPlainInsertArgs(nil)) + require.Empty(t, buildPlainInsertArgs([]string{})) + require.Equal(t, buildPlainInsertArgs([]string{"foo", "bar", "baz"}), `root = [this."foo", this."bar", this."baz"]`) } func Test_buildPlainColumns(t *testing.T) { - assert.Empty(t, buildPlainColumns(nil)) - assert.Empty(t, buildPlainColumns([]*mgmtv1alpha1.JobMapping{})) - assert.Equal( + require.Empty(t, buildPlainColumns(nil)) + require.Empty(t, buildPlainColumns([]*mgmtv1alpha1.JobMapping{})) + require.Equal( t, buildPlainColumns([]*mgmtv1alpha1.JobMapping{ {Column: "foo"}, @@ -4170,58 +4216,58 @@ func Test_buildPlainColumns(t *testing.T) { func Test_splitTableKey(t *testing.T) { schema, table := splitTableKey("foo") - assert.Equal(t, schema, "public") - assert.Equal(t, table, "foo") + require.Equal(t, schema, "public") + require.Equal(t, table, "foo") schema, table = splitTableKey("neosync.foo") - assert.Equal(t, schema, "neosync") - assert.Equal(t, table, "foo") + require.Equal(t, schema, "neosync") + require.Equal(t, table, "foo") } func Test_buildBenthosS3Credentials(t *testing.T) { - assert.Nil(t, buildBenthosS3Credentials(nil)) + require.Nil(t, buildBenthosS3Credentials(nil)) - assert.Equal( + require.Equal( t, buildBenthosS3Credentials(&mgmtv1alpha1.AwsS3Credentials{}), &neosync_benthos.AwsCredentials{}, ) - assert.Equal( + require.Equal( t, buildBenthosS3Credentials(&mgmtv1alpha1.AwsS3Credentials{Profile: shared.Ptr("foo")}), &neosync_benthos.AwsCredentials{Profile: "foo"}, ) - assert.Equal( + require.Equal( t, buildBenthosS3Credentials(&mgmtv1alpha1.AwsS3Credentials{AccessKeyId: shared.Ptr("foo")}), &neosync_benthos.AwsCredentials{Id: "foo"}, ) - assert.Equal( + require.Equal( t, buildBenthosS3Credentials(&mgmtv1alpha1.AwsS3Credentials{SecretAccessKey: shared.Ptr("foo")}), &neosync_benthos.AwsCredentials{Secret: "foo"}, ) - assert.Equal( + require.Equal( t, buildBenthosS3Credentials(&mgmtv1alpha1.AwsS3Credentials{SessionToken: shared.Ptr("foo")}), &neosync_benthos.AwsCredentials{Token: "foo"}, ) - assert.Equal( + require.Equal( t, buildBenthosS3Credentials(&mgmtv1alpha1.AwsS3Credentials{FromEc2Role: shared.Ptr(true)}), &neosync_benthos.AwsCredentials{FromEc2Role: true}, ) - assert.Equal( + require.Equal( t, buildBenthosS3Credentials(&mgmtv1alpha1.AwsS3Credentials{RoleArn: shared.Ptr("foo")}), &neosync_benthos.AwsCredentials{Role: "foo"}, ) - assert.Equal( + require.Equal( t, buildBenthosS3Credentials(&mgmtv1alpha1.AwsS3Credentials{RoleExternalId: shared.Ptr("foo")}), &neosync_benthos.AwsCredentials{RoleExternalId: "foo"}, ) - assert.Equal( + require.Equal( t, buildBenthosS3Credentials(&mgmtv1alpha1.AwsS3Credentials{ Profile: shared.Ptr("profile"), @@ -4251,8 +4297,8 @@ func Test_computeMutationFunction_null(t *testing.T) { Source: mgmtv1alpha1.TransformerSource_TRANSFORMER_SOURCE_GENERATE_NULL, }, }, &dbschemas_utils.ColumnInfo{}) - assert.NoError(t, err) - assert.Equal(t, val, "null") + require.NoError(t, err) + require.Equal(t, val, "null") } func Test_computeMutationFunction_Validate_Bloblang_Output(t *testing.T) { @@ -4636,9 +4682,9 @@ func Test_computeMutationFunction_Validate_Bloblang_Output(t *testing.T) { }, }, emailColInfo) - assert.NoError(t, err) + require.NoError(t, err) _, err = bloblang.Parse(val) - assert.NoError(t, err, fmt.Sprintf("transformer lint failed, check that the transformer string is being constructed correctly. Failing source: %s", transformer.Source)) + require.NoError(t, err, fmt.Sprintf("transformer lint failed, check that the transformer string is being constructed correctly. Failing source: %s", transformer.Source)) }) } } @@ -4715,11 +4761,11 @@ func Test_computeMutationFunction_handles_Db_Maxlen(t *testing.T) { for _, tc := range testcases { t.Run(t.Name(), func(t *testing.T) { out, err := computeMutationFunction(tc.jm, tc.ci) - assert.NoError(t, err) - assert.NotNil(t, out) - assert.Equal(t, tc.expected, out, "computed bloblang string was not expected") + require.NoError(t, err) + require.NotNil(t, out) + require.Equal(t, tc.expected, out, "computed bloblang string was not expected") _, err = bloblang.Parse(out) - assert.NoError(t, err) + require.NoError(t, err) }) } } @@ -4741,8 +4787,8 @@ func Test_buildBranchCacheConfigs_null(t *testing.T) { } resp, err := buildBranchCacheConfigs(cols, constraints, mockJobId, mockRunId, nil) - assert.NoError(t, err) - assert.Len(t, resp, 0) + require.NoError(t, err) + require.Len(t, resp, 0) } func Test_buildBranchCacheConfigs_missing_redis(t *testing.T) { @@ -4762,7 +4808,7 @@ func Test_buildBranchCacheConfigs_missing_redis(t *testing.T) { } _, err := buildBranchCacheConfigs(cols, constraints, mockJobId, mockRunId, nil) - assert.Error(t, err) + require.Error(t, err) } func Test_buildBranchCacheConfigs_success(t *testing.T) { @@ -4792,10 +4838,10 @@ func Test_buildBranchCacheConfigs_success(t *testing.T) { resp, err := buildBranchCacheConfigs(cols, constraints, mockJobId, mockRunId, redisConfig) - assert.NoError(t, err) - assert.Len(t, resp, 1) - assert.Equal(t, *resp[0].RequestMap, `root = if this."user_id" == null { deleted() } else { this }`) - assert.Equal(t, *resp[0].ResultMap, `root."user_id" = this`) + require.NoError(t, err) + require.Len(t, resp, 1) + require.Equal(t, *resp[0].RequestMap, `root = if this."user_id" == null { deleted() } else { this }`) + require.Equal(t, *resp[0].ResultMap, `root."user_id" = this`) } func Test_buildBranchCacheConfigs_self_referencing(t *testing.T) { @@ -4819,22 +4865,22 @@ func Test_buildBranchCacheConfigs_self_referencing(t *testing.T) { } resp, err := buildBranchCacheConfigs(cols, constraints, mockJobId, mockRunId, redisConfig) - assert.NoError(t, err) - assert.Len(t, resp, 0) + require.NoError(t, err) + require.Len(t, resp, 0) } func Test_ConverStringSliceToStringEmptySlice(t *testing.T) { slc := []string{} res, err := convertStringSliceToString(slc) - assert.NoError(t, err) - assert.Equal(t, "[]", res) + require.NoError(t, err) + require.Equal(t, "[]", res) } func Test_ConverStringSliceToStringNotEmptySlice(t *testing.T) { slc := []string{"gmail.com", "yahoo.com"} res, err := convertStringSliceToString(slc) - assert.NoError(t, err) - assert.Equal(t, `["gmail.com","yahoo.com"]`, res) + require.NoError(t, err) + require.Equal(t, `["gmail.com","yahoo.com"]`, res) } diff --git a/worker/pkg/workflows/datasync/activities/run-sql-init-table-stmts/init-statement-builder.go b/worker/pkg/workflows/datasync/activities/run-sql-init-table-stmts/init-statement-builder.go index 3a1444009..add7e7205 100644 --- a/worker/pkg/workflows/datasync/activities/run-sql-init-table-stmts/init-statement-builder.go +++ b/worker/pkg/workflows/datasync/activities/run-sql-init-table-stmts/init-statement-builder.go @@ -137,11 +137,15 @@ func (b *initStatementBuilder) RunSqlInitTableStatements( pool := b.pgpool[sourceConnection.Id] sourceConnectionId = sourceConnection.Id - allConstraints, err := dbschemas_postgres.GetAllPostgresFkConstraints(b.pgquerier, ctx, pool, uniqueSchemas) + allConstraints, err := dbschemas_postgres.GetAllPostgresForeignKeyConstraints(ctx, pool, b.pgquerier, uniqueSchemas) if err != nil { return nil, fmt.Errorf("unable to retrieve postgres foreign key constraints: %w", err) } - tableDependencies = dbschemas_postgres.GetPostgresTableDependencies(allConstraints) + tableDeps, err := dbschemas_postgres.GetPostgresTableDependencies(allConstraints) + if err != nil { + return nil, fmt.Errorf("uanble to build postgres table deps from fk constraints: %w", err) + } + tableDependencies = tableDeps case *mgmtv1alpha1.JobSourceOptions_Mysql: sourceConnection, err := b.getConnectionById(ctx, jobSourceConfig.Mysql.ConnectionId) diff --git a/worker/pkg/workflows/datasync/activities/run-sql-init-table-stmts/init-statement-builder_test.go b/worker/pkg/workflows/datasync/activities/run-sql-init-table-stmts/init-statement-builder_test.go index 8262f1b5c..863b1d120 100644 --- a/worker/pkg/workflows/datasync/activities/run-sql-init-table-stmts/init-statement-builder_test.go +++ b/worker/pkg/workflows/datasync/activities/run-sql-init-table-stmts/init-statement-builder_test.go @@ -375,8 +375,8 @@ func Test_InitStatementBuilder_Pg_TruncateCascade(t *testing.T) { }, }, }), nil) - pgquerier.On("GetForeignKeyConstraints", mock.Anything, mock.Anything, mock.Anything). - Return([]*pg_queries.GetForeignKeyConstraintsRow{}, nil) + pgquerier.On("GetTableConstraintsBySchema", mock.Anything, mock.Anything, mock.Anything). + Return([]*pg_queries.GetTableConstraintsBySchemaRow{}, nil) var cmdtag pgconn.CommandTag allowedQueries := []string{ @@ -524,16 +524,18 @@ func Test_InitStatementBuilder_Pg_Truncate(t *testing.T) { }, }, }), nil) - pgquerier.On("GetForeignKeyConstraints", mock.Anything, mock.Anything, mock.Anything). - Return([]*pg_queries.GetForeignKeyConstraintsRow{ + pgquerier.On("GetTableConstraintsBySchema", mock.Anything, mock.Anything, mock.Anything). + Return([]*pg_queries.GetTableConstraintsBySchemaRow{ { - ConstraintName: "fk_user_account_id", - SchemaName: "public", - TableName: "users", - ColumnName: "account_id", - ForeignSchemaName: "public", - ForeignTableName: "accounts", - ForeignColumnName: "id", + ConstraintName: "fk_user_account_id", + SchemaName: "public", + TableName: "users", + ConstraintColumns: []string{"account_id"}, + ForeignSchemaName: "public", + ForeignTableName: "accounts", + ForeignColumnNames: []string{"id"}, + Notnullable: []bool{true}, + ConstraintType: "f", }, }, nil) @@ -678,16 +680,18 @@ func Test_InitStatementBuilder_Pg_InitSchema(t *testing.T) { }, }, }), nil) - pgquerier.On("GetForeignKeyConstraints", mock.Anything, mock.Anything, mock.Anything). - Return([]*pg_queries.GetForeignKeyConstraintsRow{ + pgquerier.On("GetTableConstraintsBySchema", mock.Anything, mock.Anything, mock.Anything). + Return([]*pg_queries.GetTableConstraintsBySchemaRow{ { - ConstraintName: "fk_user_account_id", - SchemaName: "public", - TableName: "users", - ColumnName: "account_id", - ForeignSchemaName: "public", - ForeignTableName: "accounts", - ForeignColumnName: "id", + ConstraintName: "fk_user_account_id", + SchemaName: "public", + TableName: "users", + ConstraintColumns: []string{"account_id"}, + ForeignSchemaName: "public", + ForeignTableName: "accounts", + ForeignColumnNames: []string{"id"}, + Notnullable: []bool{true}, + ConstraintType: "f", }, }, nil) pgquerier.On("GetTableConstraints", mock.Anything, mock.Anything, &pg_queries.GetTableConstraintsParams{