Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove class registry #32

Merged
merged 3 commits into from
May 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 0 additions & 68 deletions src/sqlalchemyseed/class_registry.py

This file was deleted.

10 changes: 10 additions & 0 deletions src/sqlalchemyseed/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,13 @@ class UnsupportedClassError(Exception):
class NotInModuleError(Exception):
"""Raised when a value is not found in module"""
pass


class InvalidModelPath(Exception):
"""Raised when an invalid model path is invoked"""
pass


class UnsupportedClassError(Exception):
"""Raised when an unsupported class is invoked"""
pass
10 changes: 3 additions & 7 deletions src/sqlalchemyseed/seeder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from sqlalchemy.orm.relationships import RelationshipProperty
from sqlalchemy.sql import schema

from . import class_registry, validator, errors, util
from . import validator, errors, util


class AbstractSeeder(abc.ABC):
Expand Down Expand Up @@ -146,7 +146,6 @@ class Seeder(AbstractSeeder):

def __init__(self, session: sqlalchemy.orm.Session = None, ref_prefix="!"):
self.session = session
self._class_registry = class_registry.ClassRegistry()
self._instances = []
self.ref_prefix = ref_prefix

Expand All @@ -156,15 +155,14 @@ def instances(self):

def get_model_class(self, entity, parent: Entity):
if self.__model_key in entity:
return self._class_registry.register_class(entity[self.__model_key])
return util.get_model_class(entity[self.__model_key])
# parent is not None
return parent.referenced_class

def seed(self, entities, add_to_session=True):
validator.validate(entities=entities, ref_prefix=self.ref_prefix)

self._instances.clear()
self._class_registry.clear()

self._pre_seed(entities)

Expand Down Expand Up @@ -231,7 +229,6 @@ class HybridSeeder(AbstractSeeder):

def __init__(self, session: sqlalchemy.orm.Session, ref_prefix: str = '!'):
self.session = session
self._class_registry = class_registry.ClassRegistry()
self._instances = []
self.ref_prefix = ref_prefix

Expand All @@ -245,7 +242,7 @@ def get_model_class(self, entity, parent: Entity):

if self.__model_key in entity:
class_path = entity[self.__model_key]
return self._class_registry.register_class(class_path)
return util.get_model_class(class_path)

# parent is not None
return parent.referenced_class
Expand All @@ -255,7 +252,6 @@ def seed(self, entities):
entities=entities, ref_prefix=self.ref_prefix)

self._instances.clear()
self._class_registry.clear()

self._pre_seed(entities)

Expand Down
88 changes: 74 additions & 14 deletions src/sqlalchemyseed/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,55 @@
"""


from functools import lru_cache
import importlib
from typing import Iterable

from sqlalchemy import inspect
from sqlalchemyseed import errors


def iter_ref_kwargs(kwargs: dict, ref_prefix: str):
"""Iterate kwargs with name prefix or references"""
"""
Iterate kwargs with name prefix or references
"""
for attr_name, value in kwargs.items():
if attr_name.startswith(ref_prefix):
# removed prefix
yield attr_name[len(ref_prefix):], value


def iter_kwargs_with_prefix(kwargs: dict, prefix: str):
"""
Iterate kwargs(dict) that has the specified prefix.
"""
for key, value in kwargs.items():
if str(key).startswith(prefix):
yield key, value


def iterate_json(json: dict, key_prefix: str):
"""
Iterate through json that has matching key prefix
"""
for key, value in json.items():
has_prefix = str(key).startswith(key_prefix)

if has_prefix:
# removed prefix
yield key[len(key_prefix):], value


def iterate_json_no_prefix(json: dict, key_prefix: str):
"""
Iterate through json that has no matching key prefix
"""
for key, value in json.items():
has_prefix = str(key).startswith(key_prefix)
if not has_prefix:
yield key, value


def iter_non_ref_kwargs(kwargs: dict, ref_prefix: str):
"""Iterate kwargs, skipping item with name prefix or references"""
for attr_name, value in kwargs.items():
Expand All @@ -33,22 +71,44 @@ def is_supported_class(class_):
def generate_repr(instance: object) -> str:
"""
Generate repr of object instance

Example:
```
class Person(Base):
...
def __repr__(self):
return generate_repr(self)
```

Output format:
```
"<Person(id='1',name='John Doe')>"
```
"""
class_name = instance.__class__.__name__
insp = inspect(instance)
attributes = {column.key: column.value for column in insp.attrs}
str_attributes = ",".join(f"{k}='{v}'" for k, v in attributes.items())
return f"<{class_name}({str_attributes})>"


def find_item(json: Iterable, keys: list):
"""
Finds item of json from keys
"""
return find_item(json[keys[0]], keys[1:]) if keys else json


# check if class is a sqlalchemy model
def is_model(class_):
"""
Check if class is a sqlalchemy model
"""
insp = inspect(class_, raiseerr=False)
return insp is not None and insp.is_mapper


# get sqlalchemy model class from path
@lru_cache(maxsize=None)
def get_model_class(path: str):
"""
Get sqlalchemy model class from path
"""
try:
module_name, class_name = path.rsplit(".", 1)
module = importlib.import_module(module_name)
except (ImportError, AttributeError) as e:
raise errors.InvalidModelPath(path=path, error=e)

class_ = getattr(module, class_name)
if not is_model(class_):
raise errors.UnsupportedClassError(path=path)

return class_
62 changes: 0 additions & 62 deletions tests/test_class_registry.py

This file was deleted.