Skip to content

Commit

Permalink
Allow for multiple foreign_key in CreateQueryBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
SebastianBoyd committed Oct 12, 2023
1 parent ded17eb commit 41b2a5e
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 35 deletions.
73 changes: 44 additions & 29 deletions pypika/queries.py
Expand Up @@ -1707,6 +1707,36 @@ def replace_table(self, current_table: Optional[Table], new_table: Optional[Tabl
self.fields = [field.replace_table(current_table, new_table) for field in self.fields]


class ForeignKey:
"""Represents a foreign key constraint."""

def __init__(
self,
columns: List[Column],
reference_table: Union[str, Table],
reference_columns: List[Column],
on_delete: ReferenceOption = None,
on_update: ReferenceOption = None,
) -> None:
self.columns = columns
self.reference_table = reference_table
self.reference_columns = reference_columns
self.on_delete = on_delete
self.on_update = on_update

def get_sql(self, **kwargs: Any) -> str:
foreign_key_sql = "FOREIGN KEY ({columns}) REFERENCES {table_name} ({reference_columns})".format(
columns=",".join(column.get_name_sql(**kwargs) for column in self.columns),
table_name=self.reference_table.get_sql(**kwargs),
reference_columns=",".join(column.get_name_sql(**kwargs) for column in self.reference_columns),
)
if self.on_delete:
foreign_key_sql += " ON DELETE " + self.on_delete.value
if self.on_update:
foreign_key_sql += " ON UPDATE " + self.on_update.value
return foreign_key_sql


class CreateQueryBuilder:
"""
Query builder used to build CREATE queries.
Expand All @@ -1729,11 +1759,7 @@ def __init__(self, dialect: Optional[Dialects] = None) -> None:
self._uniques = []
self._if_not_exists = False
self.dialect = dialect
self._foreign_key = None
self._foreign_key_reference_table = None
self._foreign_key_reference = None
self._foreign_key_on_update: ReferenceOption = None
self._foreign_key_on_delete: ReferenceOption = None
self._foreign_keys = []

def _set_kwargs_defaults(self, kwargs: dict) -> None:
kwargs.setdefault("quote_char", self.QUOTE_CHAR)
Expand Down Expand Up @@ -1908,19 +1934,19 @@ def foreign_key(
Update option.
:raises AttributeError:
If the foreign key is already defined.
:return:
CreateQueryBuilder.
"""
if self._foreign_key:
raise AttributeError("'Query' object already has attribute foreign_key")
self._foreign_key = self._prepare_columns_input(columns)
self._foreign_key_reference_table = reference_table
self._foreign_key_reference = self._prepare_columns_input(reference_columns)
self._foreign_key_on_delete = on_delete
self._foreign_key_on_update = on_update

self._foreign_keys.append(
ForeignKey(
columns=self._prepare_columns_input(columns),
reference_table=reference_table,
reference_columns=self._prepare_columns_input(reference_columns),
on_delete=on_delete,
on_update=on_update,
)
)

@builder
def as_select(self, query_builder: QueryBuilder) -> "CreateQueryBuilder":
Expand Down Expand Up @@ -2017,28 +2043,17 @@ def _primary_key_clause(self, **kwargs) -> str:
columns=",".join(column.get_name_sql(**kwargs) for column in self._primary_key)
)

def _foreign_key_clause(self, **kwargs) -> str:
clause = "FOREIGN KEY ({columns}) REFERENCES {table_name} ({reference_columns})".format(
columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key),
table_name=self._foreign_key_reference_table.get_sql(**kwargs),
reference_columns=",".join(column.get_name_sql(**kwargs) for column in self._foreign_key_reference),
)
if self._foreign_key_on_delete:
clause += " ON DELETE " + self._foreign_key_on_delete.value
if self._foreign_key_on_update:
clause += " ON UPDATE " + self._foreign_key_on_update.value

return clause
def _foreign_key_clauses(self, **kwargs) -> str:
return [foreign_key.get_sql(**kwargs) for foreign_key in self._foreign_keys]

def _body_sql(self, **kwargs) -> str:
clauses = self._column_clauses(**kwargs)
clauses += self._period_for_clauses(**kwargs)
clauses += self._unique_key_clauses(**kwargs)
clauses += self._foreign_key_clauses(**kwargs)

if self._primary_key:
clauses.append(self._primary_key_clause(**kwargs))
if self._foreign_key:
clauses.append(self._foreign_key_clause(**kwargs))

return ",".join(clauses)

Expand Down
31 changes: 25 additions & 6 deletions pypika/tests/test_create.py
Expand Up @@ -99,6 +99,31 @@ def test_create_table_with_columns(self):
str(q),
)

with self.subTest("with multiple foreign key constrains"):
secondary_table = Table("secondary_table")
cref, dref = Columns(("c", "INT"), ("d", "VARCHAR(100)"))
q = (
Query.create_table(self.new_table)
.columns(self.foo, self.bar)
.foreign_key([self.foo], self.existing_table, [cref])
.foreign_key(
[self.bar],
secondary_table,
[dref],
on_delete=ReferenceOption.cascade,
on_update=ReferenceOption.restrict,
)
)

self.assertEqual(
'CREATE TABLE "abc" ('
'"a" INT,'
'"b" VARCHAR(100),'
'FOREIGN KEY ("a") REFERENCES "efg" ("c"),'
'FOREIGN KEY ("b") REFERENCES "secondary_table" ("d") ON DELETE CASCADE ON UPDATE RESTRICT)',
str(q),
)

with self.subTest("with unique keys"):
q = (
Query.create_table(self.new_table)
Expand Down Expand Up @@ -156,12 +181,6 @@ def test_create_table_with_select_and_columns_fails(self):
with self.assertRaises(AttributeError):
Query.create_table(self.new_table).as_select(select).columns(self.foo, self.bar)

with self.subTest("repeated foreign key"):
with self.assertRaises(AttributeError):
Query.create_table(self.new_table).foreign_key([self.foo], self.existing_table, [self.bar]).foreign_key(
[self.foo], self.existing_table, [self.bar]
)

def test_create_table_as_select_not_query_raises_error(self):
with self.assertRaises(TypeError):
Query.create_table(self.new_table).as_select("abc")
Expand Down

0 comments on commit 41b2a5e

Please sign in to comment.