Skip to content

Commit

Permalink
refs #323, code style refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
wwwjfy committed Sep 6, 2018
1 parent c78a340 commit 7d7b9c4
Showing 1 changed file with 22 additions and 24 deletions.
46 changes: 22 additions & 24 deletions gino/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,26 @@ def __getattr__(self, item):
_none = object()


def _get_column(model, column_or_name) -> Column:
if isinstance(column_or_name, str):
return getattr(model, column_or_name)

if isinstance(column_or_name, Column):
if column_or_name in model:
return column_or_name
raise AttributeError('Column {} does not belong to model {}'.format(
column_or_name, model))

raise TypeError('Unknown column {} with type {}'.
format(column_or_name, type(column_or_name)))


class ModelLoader(Loader):
def __init__(self, model, *column_names, **extras):
def __init__(self, model, *columns, **extras):
self.model = model
self._distinct = None
if column_names:
self.columns = self._column_loader(model, column_names)
if columns:
self.columns = [_get_column(model, name) for name in columns]
else:
self.columns = model
self.extras = dict((key, self.get(value))
Expand Down Expand Up @@ -121,30 +135,14 @@ def get_from(self):
getattr(subloader, 'on_clause', None))
return rv

def load(self, *column_names, **extras):
if column_names:
self.columns = self._column_loader(self.model, column_names)
def load(self, *columns, **extras):
if columns:
self.columns = [_get_column(self.model, name) for name in columns]

self.extras.update((key, self.get(value))
for key, value in extras.items())
return self

@classmethod
def _column_loader(cls, model, column_names):
def column_formatter(column_name):
if isinstance(column_name, str):
return getattr(model, column_name)
elif isinstance(column_name, Column):
if column_name not in model:
raise AttributeError('Column {} does not belong '
'to this model'.format(column_name))
return column_name
else:
raise TypeError('Unknown column name {} type {}'.
format(column_name, type(column_name)))

return [column_formatter(column_name) for column_name in column_names]

def on(self, on_clause):
self.on_clause = on_clause
return self
Expand All @@ -163,8 +161,8 @@ def none_as_none(self, enabled=True):


class AliasLoader(ModelLoader):
def __init__(self, alias, *column_names, **extras):
super().__init__(alias, *column_names, **extras)
def __init__(self, alias, *columns, **extras):
super().__init__(alias, *columns, **extras)


class ColumnLoader(Loader):
Expand Down

0 comments on commit 7d7b9c4

Please sign in to comment.