Skip to content

Commit

Permalink
Merge pull request #530 from xliiv/default_repr
Browse files Browse the repository at this point in the history
Add default __repr__ for Model class
  • Loading branch information
davidism committed Sep 26, 2017
2 parents 80a7cce + bceb09b commit 5135ba2
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 3 deletions.
6 changes: 5 additions & 1 deletion flask_sqlalchemy/__init__.py
Expand Up @@ -30,7 +30,7 @@
from sqlalchemy.orm.exc import UnmappedClassError
from sqlalchemy.orm.session import Session as SessionBase

from ._compat import itervalues, string_types, xrange
from ._compat import itervalues, string_types, xrange, to_str

__version__ = '2.2.1'

Expand Down Expand Up @@ -662,6 +662,10 @@ class Model(object):
#: Equivalent to ``db.session.query(Model)`` unless :attr:`query_class` has been changed.
query = None

def __repr__(self):
pk = ', '.join(to_str(value) for value in inspect(self).identity)
return '<{0} {1}>'.format(type(self).__name__, pk)


class SQLAlchemy(object):
"""This class is used to control the SQLAlchemy integration to one
Expand Down
21 changes: 19 additions & 2 deletions flask_sqlalchemy/_compat.py
Expand Up @@ -10,7 +10,6 @@
"""
import sys


PY2 = sys.version_info[0] == 2


Expand All @@ -25,6 +24,15 @@ def itervalues(d):

string_types = (unicode, bytes)

def to_str(x, charset='utf8', errors='strict'):
if x is None or isinstance(x, str):
return x

if isinstance(x, unicode):
return x.encode(charset, errors)

return str(x)

else:
def iteritems(d):
return iter(d.items())
Expand All @@ -34,4 +42,13 @@ def itervalues(d):

xrange = range

string_types = (str, )
string_types = (str,)

def to_str(x, charset='utf8', errors='strict'):
if x is None or isinstance(x, str):
return x

if isinstance(x, bytes):
return x.decode(charset, errors)

return str(x)
31 changes: 31 additions & 0 deletions tests/test_model_class.py
@@ -1,4 +1,6 @@
# coding=utf8
import flask_sqlalchemy as fsa
from flask_sqlalchemy._compat import to_str


def test_custom_query_class(app):
Expand All @@ -11,3 +13,32 @@ class SomeModel(db.Model):
id = db.Column(db.Integer, primary_key=True)

assert isinstance(SomeModel(), CustomModelClass)


def test_repr(db):
class User(db.Model):
name = db.Column(db.String, primary_key=True)

class Report(db.Model):
id = db.Column(db.Integer, primary_key=True, autoincrement=False)
user_name = db.Column(db.ForeignKey(User.name), primary_key=True)

db.create_all()

u = User(name='test')
db.session.add(u)
db.session.flush()
assert repr(u) == '<User test>'
assert repr(u) == str(u)

u2 = User(name=u'🐍')
db.session.add(u2)
db.session.flush()
assert repr(u2) == to_str(u'<User 🐍>')
assert repr(u2) == str(u2)

r = Report(id=2, user_name=u.name)
db.session.add(r)
db.session.flush()
assert repr(r) == '<Report 2, test>'
assert repr(u) == str(u)

0 comments on commit 5135ba2

Please sign in to comment.