Skip to content

Commit

Permalink
fix server_default
Browse files Browse the repository at this point in the history
  • Loading branch information
jadbin committed Jul 24, 2019
1 parent c94a3a4 commit bc01b3f
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions guniflask/modelgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import sqlalchemy
from sqlalchemy import ForeignKeyConstraint, UniqueConstraint, PrimaryKeyConstraint, CheckConstraint, ForeignKey
import inflect
from sqlalchemy.util import OrderedDict

from guniflask.utils.template import string_camelcase, string_lowercase_underscore

Expand Down Expand Up @@ -56,7 +57,7 @@ def render(self, path):
module_name = convert_to_valid_identifier(model.table.name)
with open(join(path, module_name + '.py'), 'w', encoding='utf-8') as f:
f.write('# coding=utf-8\n\n')
f.write(self.render_imports())
f.write(self.render_imports(model))
tables_content = self.render_secondary_tables(model)
if tables_content:
f.write('\n')
Expand All @@ -69,8 +70,22 @@ def render(self, path):
for m in model_modules:
f.write('from .{} import {}\n'.format(m['module'], m['class']))

def render_imports(self):
return 'from {} import db\n'.format(self.name)
def render_imports(self, model):
d = OrderedDict()
for col in model.table.columns:
if col.server_default:
d.setdefault('sqlalchemy', ('text', '_text'))

imports = ''
for k, v in d.items():
if isinstance(v, tuple):
imports += 'from {} import {} as {}\n'.format(k, v[0], v[1])
else:
imports += 'from {} import {}\n'.format(k, v)
if len(d) > 0:
imports += '\n'
imports += 'from {} import db\n'.format(self.name)
return imports

def render_model(self, model):
header_str = "class {}(db.Model):\n".format(model.class_name)
Expand Down Expand Up @@ -122,7 +137,12 @@ def render_column(self, column, show_name=False):
if column.comment:
kwargs.append('comment')
if column.server_default:
server_default = 'server_default="{}"'.format(str(column.server_default.arg))
default_expr = self.get_compiled_expression(column.server_default.arg)
if '\n' in default_expr:
server_default = 'server_default=text("""\\\n{0}""")'.format(default_expr)
else:
default_expr = default_expr.replace('"', '\\"')
server_default = 'server_default=_text("{0}")'.format(default_expr)
extra_kwargs = self.get_extra_column_kwargs(column)
return "db.Column({})".format(', '.join(([repr(column.name)] if show_name else []) +
[self.render_column_type(column.type)] +
Expand Down Expand Up @@ -162,6 +182,10 @@ def render_relationship(self, relationship):
sorted(relationship.kwargs.keys())])
return '{} = db.relationship({})'.format(relationship.preferred_name, kwargs_str)

def get_compiled_expression(self, statement):
return str(statement.compile(
self.metadata.bind, compile_kwargs={"literal_binds": True}))


def convert_to_valid_identifier(name):
name = string_lowercase_underscore(name)
Expand Down

0 comments on commit bc01b3f

Please sign in to comment.