pgsqlite/pgsqlite.py (16 changes: 14 additions & 2 deletions)
@@ -181,7 +181,9 @@ def __init__(self, sqlite_filename: str, pg_conninfo: str, show_sample_data: boo
         self.max_import_concurrency = max_import_concurrency
         db = Database(self.sqlite_filename)
         self._tables = {t.name: ParsedTable(t) for t in db.tables}
+        self._constr_names = []
+        self._constr_names_counter = 0
 
     @property
     def tables(self):
         return self._tables.values()
@@ -237,6 +239,11 @@ def get_table_sql(self, table: ParsedTable) -> SQL:
             transpiled_pks_to_add = [table.get_transpiled_colname(pk) for pk in pks_to_add]
             all_column_sql = all_column_sql + SQL(",\n")
             pk_name = f"PK_{table.source_name}_" + ''.join(pks_to_add)
+            if pk_name in self._constr_names:
+                self._constr_names_counter += 1
+                pk_name = f"{pk_name}_{self._constr_names_counter}"
+            else:
+                self._constr_names.append(pk_name)
             pk_sql = SQL(" CONSTRAINT {pk_name} PRIMARY KEY ({pks})").format(
                 table_name=Identifier(table.transpiled_name),
                 pk_name=Identifier(pk_name), pks=SQL(", ").join(
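
Why the guard is needed (context, not part of the diff): pk_name concatenates the table name with ''.join(pks_to_add), which drops any separator between column names, so two distinct tables can derive the same constraint name, and the second CREATE TABLE would then fail in PostgreSQL, where a primary-key constraint's backing index name must be unique per schema. A minimal standalone sketch of the deduplication scheme, using hypothetical table and column names:

# Standalone mirror of the dedup logic above; names here are hypothetical.
_constr_names: list[str] = []
_constr_names_counter = 0

def unique_constraint_name(name: str) -> str:
    """Return name unchanged on first use, else suffix it with a shared counter."""
    global _constr_names_counter
    if name in _constr_names:
        _constr_names_counter += 1
        return f"{name}_{_constr_names_counter}"
    _constr_names.append(name)
    return name

# Table "t_a" with PK column "b" and table "t" with PK column "a_b"
# both map to "PK_t_a_b"; the second occurrence gets a numeric suffix.
assert unique_constraint_name("PK_t_a_b") == "PK_t_a_b"
assert unique_constraint_name("PK_t_a_b") == "PK_t_a_b_1"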
@@ -269,6 +276,11 @@ def get_fk_sql(self, table: ParsedTable) -> SQL:
         # create the foreign keys after the tables to avoid having to figure out the dep graph
         for fk in table.src_table.foreign_keys:
             fk_name = f"FK_{fk.other_table}_{fk.other_column}"
+            if fk_name in self._constr_names:
+                self._constr_names_counter += 1
+                fk_name = f"{fk_name}_{self._constr_names_counter}"
+            else:
+                self._constr_names.append(fk_name)
             fk_sql = SQL("ALTER TABLE {table_name} ADD CONSTRAINT {key_name} FOREIGN KEY ({column}) REFERENCES {other_table} ({other_column})").format(
                 table_name=Identifier(table.transpiled_name),
                 column=Identifier(table.get_transpiled_colname(fk.column)),
@@ -586,4 +598,4 @@ async def create_all_indexes():
     logger.debug(json.dumps(loader.get_summary(), indent=2))
 
     if args.drop_tables_after_import:
-        loader._drop_tables()
\ No newline at end of file
+        loader._drop_tables()
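
The foreign-key path is the likelier collision in practice, since FK_{fk.other_table}_{fk.other_column} depends only on the referenced table and column: any two tables pointing at the same target derive the same name. A self-contained sketch (hypothetical table names) of how the shared state keeps names unique across both the primary-key and foreign-key paths:

# Self-contained mirror of self._constr_names / self._constr_names_counter.
class ConstraintNames:
    def __init__(self) -> None:
        self.seen: list[str] = []
        self.counter = 0

    def unique(self, name: str) -> str:
        """First use passes through; reuse gets a globally increasing suffix."""
        if name in self.seen:
            self.counter += 1
            return f"{name}_{self.counter}"
        self.seen.append(name)
        return name

names = ConstraintNames()
# Hypothetical: "orders" and "invoices" both reference users.id.
assert names.unique("FK_users_id") == "FK_users_id"
assert names.unique("FK_users_id") == "FK_users_id_1"
# The counter is shared across PK and FK names, so a later collision
# continues the global sequence rather than restarting per name:
assert names.unique("PK_t_a_b") == "PK_t_a_b"
assert names.unique("PK_t_a_b") == "PK_t_a_b_2"

Note that, as in the patch, a suffixed name is not added back to the seen list; uniqueness still holds because the counter only moves forward, provided no base name happens to equal a previously issued suffixed name.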