Skip to content

Commit

Permalink
add time column support, allow any lowercase fixture filenames (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
briancappello committed Aug 16, 2023
1 parent a7d6eb4 commit 4f59db8
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.idea/
*.egg-info
*.pyc
.coverage
Expand Down
6 changes: 4 additions & 2 deletions py_yaml_fixtures/factories/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from datetime import date, datetime, timedelta
from datetime import date, datetime, time, timedelta
from functools import lru_cache
from types import FunctionType
from typing import *
Expand Down Expand Up @@ -70,7 +70,7 @@ def _get_existing(self, identifier: Identifier, data: Dict[str, Any]):
filter_kwargs = {k: v for k, v in data.items()
if (k in relationships and hasattr(v, '__mapper__'))
or v is None
or isinstance(v, (bool, int, str, float, date, datetime))}
or isinstance(v, (bool, int, str, float))}
if not filter_kwargs:
return None

Expand Down Expand Up @@ -116,6 +116,8 @@ def maybe_convert_values(self,
continue
elif col.type.python_type == date:
rv[col_name] = self.date_factory(value)
elif col.type.python_type == time:
rv[col_name] = time(*[int(x) for x in value.split(':')])
elif col.type.python_type == datetime:
rv[col_name] = self.datetime_factory(value)
elif col.type.python_type == timedelta:
Expand Down
4 changes: 2 additions & 2 deletions py_yaml_fixtures/fixtures_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def _load_data(self):
rendered_yaml = env.get_template(filepath).render()
data = yaml.load(rendered_yaml, Loader=yaml.FullLoader)
if data:
if filename in MULTI_CLASS_FILENAMES:
if filename.islower():
for class_name in data:
model_identifiers[class_name] = list(
data[class_name].keys())
Expand All @@ -173,7 +173,7 @@ def _load_from_yaml(self, filepath: str, model_identifiers: Dict[str, List[str]]

identifier_data = {}
filename = os.path.basename(filepath)
if filename in MULTI_CLASS_FILENAMES:
if filename.islower():
for class_name in data:
d, self.relationships[class_name] = self._post_process_yaml_data(
data[class_name], self.factory.get_relationships(class_name))
Expand Down

0 comments on commit 4f59db8

Please sign in to comment.