Skip to content

Commit

Permalink
Allow specification of a server column default in migrations.
Browse files Browse the repository at this point in the history
Replaces #2803, thanks @b40yd
  • Loading branch information
coleifer committed Nov 2, 2023
1 parent 3547d5c commit 7e2c227
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ https://github.com/coleifer/peewee/releases
## master

* Add bitwise and other helper methods to `BigBitField`, #2802.
* Add `add_column_default` and `drop_column_default` migrator methods for
specifying a server-side default value, #2803.
* The new `star` attribute was causing issues for users who had a field named
star on their models. This attribute is now renamed to `__star__`. #2796.

Expand Down
25 changes: 25 additions & 0 deletions docs/peewee/playhouse.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3116,6 +3116,31 @@ Adding or dropping table constraints:
# Add a UNIQUE constraint on the first and last names.
migrate(migrator.add_unique('person', 'first_name', 'last_name'))
Adding or dropping a database-level default value for a column:

.. code-block:: python
# Add a default value for a status column.
migrate(migrator.add_column_default(
'entries',
'status',
'draft'))
# Remove the default.
migrate(migrator.drop_column_default('entries', 'status'))
# Use a function for the default value (does not work with Sqlite):
migrate(migrator.add_column_default(
'entries',
'timestamp',
fn.now()))
# Or alternatively (works with Sqlite):
migrate(migrator.add_column_default(
'entries',
'timestamp',
'now()'))
.. note::
Postgres users may need to set the search-path when using a non-standard
schema. This can be done as follows:
Expand Down
47 changes: 47 additions & 0 deletions playhouse/migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,32 @@ def drop_not_null(self, table, column):
._alter_column(self.make_context(), table, column)
.literal(' DROP NOT NULL'))

@operation
def add_column_default(self, table, column, default):
if default is None:
raise ValueError('`default` must be not None/NULL.')
if callable_(default):
default = default()
# Try to handle SQL functions and string literals, otherwise pass as a
# bound value.
if isinstance(default, str) and default.endswith((')', "'")):
default = SQL(default)

return (self
._alter_table(self.make_context(), table)
.literal(' ALTER COLUMN ')
.sql(Entity(column))
.literal(' SET DEFAULT ')
.sql(default))

@operation
def drop_column_default(self, table, column):
return (self
._alter_table(self.make_context(), table)
.literal(' ALTER COLUMN ')
.sql(Entity(column))
.literal(' DROP DEFAULT'))

@operation
def alter_column_type(self, table, column, field, cast=None):
# ALTER TABLE <table> ALTER COLUMN <column>
Expand Down Expand Up @@ -866,6 +892,27 @@ def _drop_not_null(column_name, column_def):
return column_def.replace('NOT NULL', '')
return self._update_column(table, column, _drop_not_null)

@operation
def add_column_default(self, table, column, default):
if default is None:
raise ValueError('`default` must be not None/NULL.')
if callable_(default):
default = default()
if (isinstance(default, str) and not default.endswith((')', "'"))
and not default.isdigit()):
default = "'%s'" % default
def _add_default(column_name, column_def):
# Try to handle SQL functions and string literals, otherwise quote.
return column_def + ' DEFAULT %s' % default
return self._update_column(table, column, _add_default)

@operation
def drop_column_default(self, table, column):
def _drop_default(column_name, column_def):
col = re.sub(r'DEFAULT\s+[\w"\'\(\)]+(\s|$)', '', column_def, re.I)
return col.strip()
return self._update_column(table, column, _drop_default)

@operation
def alter_column_type(self, table, column, field, cast=None):
if cast is not None:
Expand Down
23 changes: 23 additions & 0 deletions tests/migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,29 @@ class Meta:
def test_rename_gh380_sqlite_legacy(self):
self.test_rename_gh380(legacy=True)

def test_add_default_drop_default(self):
with self.database.transaction():
migrate(self.migrator.add_column_default('person', 'first_name',
default='x'))

p = Person.create(last_name='Last')
p_db = Person.get(Person.last_name == 'Last')
self.assertEqual(p_db.first_name, 'x')

with self.database.transaction():
migrate(self.migrator.drop_column_default('person', 'first_name'))

if IS_MYSQL:
# MySQL, even though the column is NOT NULL, does not seem to be
# enforcing the constraint(?).
Person.create(last_name='Last2')
p_db = Person.get(Person.last_name == 'Last2')
self.assertEqual(p_db.first_name, '')
else:
with self.assertRaises(IntegrityError):
with self.database.transaction():
Person.create(last_name='Last2')

def test_add_not_null(self):
self._create_people()

Expand Down

0 comments on commit 7e2c227

Please sign in to comment.