Skip to content

Commit

Permalink
mutable json field tracking (#24)
Browse files Browse the repository at this point in the history
* mutable json field tracking

* mutable json tests
  • Loading branch information
eyadgaran committed Oct 29, 2020
1 parent 592ccf9 commit 50f63ce
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 28 deletions.
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
elif sys.version_info < (3, 6): # Python 3.5
version_based_dependencies = [
'scikit-learn<0.23.0',
'scipy<1.5.0', # Scikit-learn dependency
'scipy<1.5.0', # Scikit-learn dependency
'pandas<1.0.0',
'markupsafe<2.0.0',
]
Expand Down Expand Up @@ -54,6 +54,7 @@
install_requires=[
'sqlalchemy>=1.3.7', # Unified json_serializer/deserializer for sqlite
'sqlalchemy-mixins',
'sqlalchemy-json',
'alembic',
'numpy',
'cloudpickle',
Expand All @@ -70,7 +71,7 @@
zip_safe=False,
test_suite='simpleml.tests.load_tests',
tests_require=['nose'],
entry_points = {
entry_points={
'console_scripts': [
'simpleml-test=simpleml.tests:run_tests',
'simpleml-unit-test=simpleml.tests.unit:run_tests',
Expand Down
6 changes: 3 additions & 3 deletions simpleml/metrics/base_metric.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from simpleml.persistables.base_persistable import Persistable, GUID, JSON
from simpleml.persistables.base_persistable import Persistable, GUID, MutableJSON
from simpleml.registries import MetricRegistry
from simpleml.utils.errors import MetricError
from sqlalchemy import Column, ForeignKey, UniqueConstraint, Index, func
Expand All @@ -22,7 +22,7 @@ class AbstractMetric(with_metaclass(MetricRegistry, Persistable)):
'''
__abstract__ = True

values = Column(JSON, nullable=False)
values = Column(MutableJSON, nullable=False)

object_type = 'METRIC'

Expand Down Expand Up @@ -123,4 +123,4 @@ class Metric(AbstractMetric):
UniqueConstraint('name', 'model_id', 'version', name='metric_name_model_version_unique'),
# Index for searching through friendly names
Index('metric_name_index', 'name'),
)
)
24 changes: 12 additions & 12 deletions simpleml/migrations/versions/9df691c76c63_base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def upgrade():
sa.Column('version', sa.Integer(), nullable=False),
sa.Column('version_description', sa.String(), nullable=True),
sa.Column('has_external_files', sa.Boolean(), nullable=True),
sa.Column('filepaths', simpleml.persistables.sqlalchemy_types.JSON(), nullable=True),
sa.Column('metadata', simpleml.persistables.sqlalchemy_types.JSON(), nullable=True),
sa.Column('filepaths', simpleml.persistables.sqlalchemy_types.MutableJSON(), nullable=True),
sa.Column('metadata', simpleml.persistables.sqlalchemy_types.MutableJSON(), nullable=True),
sa.Column('pipeline_id', simpleml.persistables.sqlalchemy_types.GUID(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('name', 'version', name='dataset_name_version_unique')
Expand All @@ -51,9 +51,9 @@ def upgrade():
sa.Column('version', sa.Integer(), nullable=False),
sa.Column('version_description', sa.String(), nullable=True),
sa.Column('has_external_files', sa.Boolean(), nullable=True),
sa.Column('filepaths', simpleml.persistables.sqlalchemy_types.JSON(), nullable=True),
sa.Column('metadata', simpleml.persistables.sqlalchemy_types.JSON(), nullable=True),
sa.Column('values', simpleml.persistables.sqlalchemy_types.JSON(), nullable=False),
sa.Column('filepaths', simpleml.persistables.sqlalchemy_types.MutableJSON(), nullable=True),
sa.Column('metadata', simpleml.persistables.sqlalchemy_types.MutableJSON(), nullable=True),
sa.Column('values', simpleml.persistables.sqlalchemy_types.MutableJSON(), nullable=False),
sa.Column('model_id', simpleml.persistables.sqlalchemy_types.GUID(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('name', 'model_id', 'version', name='metric_name_model_version_unique')
Expand All @@ -72,10 +72,10 @@ def upgrade():
sa.Column('version', sa.Integer(), nullable=False),
sa.Column('version_description', sa.String(), nullable=True),
sa.Column('has_external_files', sa.Boolean(), nullable=True),
sa.Column('filepaths', simpleml.persistables.sqlalchemy_types.JSON(), nullable=True),
sa.Column('metadata', simpleml.persistables.sqlalchemy_types.JSON(), nullable=True),
sa.Column('params', simpleml.persistables.sqlalchemy_types.JSON(), nullable=True),
sa.Column('feature_metadata', simpleml.persistables.sqlalchemy_types.JSON(), nullable=True),
sa.Column('filepaths', simpleml.persistables.sqlalchemy_types.MutableJSON(), nullable=True),
sa.Column('metadata', simpleml.persistables.sqlalchemy_types.MutableJSON(), nullable=True),
sa.Column('params', simpleml.persistables.sqlalchemy_types.MutableJSON(), nullable=True),
sa.Column('feature_metadata', simpleml.persistables.sqlalchemy_types.MutableJSON(), nullable=True),
sa.Column('pipeline_id', simpleml.persistables.sqlalchemy_types.GUID(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('name', 'version', name='model_name_version_unique')
Expand All @@ -94,9 +94,9 @@ def upgrade():
sa.Column('version', sa.Integer(), nullable=False),
sa.Column('version_description', sa.String(), nullable=True),
sa.Column('has_external_files', sa.Boolean(), nullable=True),
sa.Column('filepaths', simpleml.persistables.sqlalchemy_types.JSON(), nullable=True),
sa.Column('metadata', simpleml.persistables.sqlalchemy_types.JSON(), nullable=True),
sa.Column('params', simpleml.persistables.sqlalchemy_types.JSON(), nullable=True),
sa.Column('filepaths', simpleml.persistables.sqlalchemy_types.MutableJSON(), nullable=True),
sa.Column('metadata', simpleml.persistables.sqlalchemy_types.MutableJSON(), nullable=True),
sa.Column('params', simpleml.persistables.sqlalchemy_types.MutableJSON(), nullable=True),
sa.Column('dataset_id', simpleml.persistables.sqlalchemy_types.GUID(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('name', 'version', name='pipeline_name_version_unique')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sqlalchemy import MetaData, Column
from sqlalchemy.orm import scoped_session, sessionmaker

from simpleml.persistables.sqlalchemy_types import GUID, JSON
from simpleml.persistables.sqlalchemy_types import GUID, MutableJSON
from simpleml.persistables.base_sqlalchemy import BaseSQLAlchemy

LOGGER = logging.getLogger(__name__)
Expand All @@ -33,8 +33,8 @@ class UpgradeTableModel(BaseSQLAlchemy):
__abstract__ = True
metadata = MetaData()
id = Column(GUID, primary_key=True)
metadata_ = Column('metadata', JSON, default={})
filepaths = Column(JSON)
metadata_ = Column('metadata', MutableJSON, default={})
filepaths = Column(MutableJSON)


class DatasetModel(UpgradeTableModel):
Expand Down
6 changes: 3 additions & 3 deletions simpleml/models/base_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from simpleml.persistables.base_persistable import Persistable, GUID, JSON
from simpleml.persistables.base_persistable import Persistable, GUID, MutableJSON
from simpleml.registries import ModelRegistry
from simpleml.persistables.saving import ExternalArtifactsMixin
from simpleml.utils.errors import ModelError
Expand Down Expand Up @@ -38,8 +38,8 @@ class AbstractModel(with_metaclass(ModelRegistry, Persistable)):
__abstract__ = True

# Additional model specific metadata
params = Column(JSON, default={})
feature_metadata = Column(JSON, default={})
params = Column(MutableJSON, default={})
feature_metadata = Column(MutableJSON, default={})

object_type = 'MODEL'

Expand Down
6 changes: 3 additions & 3 deletions simpleml/persistables/base_persistable.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Dict, Union, Optional, Any
from sqlalchemy import Column, func, String, Boolean, Integer

from simpleml.persistables.sqlalchemy_types import GUID, JSON
from simpleml.persistables.sqlalchemy_types import GUID, MutableJSON
from simpleml.persistables.base_sqlalchemy import SimplemlCoreSqlalchemy
from simpleml.persistables.saving import AllSaveMixin
from simpleml.persistables.hashing import CustomHasherMixin
Expand Down Expand Up @@ -95,10 +95,10 @@ class Persistable(with_metaclass(MetaRegistry, SimplemlCoreSqlalchemy, AllSaveMi

# Persistence of fitted states
has_external_files = Column(Boolean, default=False)
filepaths = Column(JSON, default={})
filepaths = Column(MutableJSON, default={})

# Generic store and metadata for all child objects
metadata_ = Column('metadata', JSON, default={})
metadata_ = Column('metadata', MutableJSON, default={})

# Internal Registry for all allowed external files
# Does not need to be persisted because it gets populated on import
Expand Down
6 changes: 6 additions & 0 deletions simpleml/persistables/sqlalchemy_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from sqlalchemy.types import TypeDecorator, CHAR, JSON as SQLJSON, Text
from sqlalchemy.dialects.postgresql import UUID, JSONB
from sqlalchemy_json import mutable_json_type
import uuid


Expand Down Expand Up @@ -57,3 +58,8 @@ def load_dialect_impl(self, dialect):
return dialect.type_descriptor(JSONB(astext_type=Text()))
else:
return dialect.type_descriptor(SQLJSON())


# Mutable version of JSON field
# https://docs.sqlalchemy.org/en/13/core/type_basics.html?highlight=json#sqlalchemy.types.JSON
MutableJSON = mutable_json_type(dbtype=JSON, nested=True)
4 changes: 2 additions & 2 deletions simpleml/pipelines/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from simpleml.persistables.base_persistable import Persistable
from simpleml.persistables.saving import ExternalArtifactsMixin
from simpleml.registries import PipelineRegistry
from simpleml.persistables.sqlalchemy_types import GUID, JSON
from simpleml.persistables.sqlalchemy_types import GUID, MutableJSON

from simpleml.pipelines.external_pipelines import DefaultPipeline, SklearnPipeline
from simpleml.pipelines.validation_split_mixins import Split
Expand Down Expand Up @@ -44,7 +44,7 @@ class AbstractPipeline(with_metaclass(PipelineRegistry, Persistable)):
__abstract__ = True

# Additional pipeline specific metadata
params = Column(JSON, default={})
params = Column(MutableJSON, default={})

object_type = 'PIPELINE'

Expand Down
43 changes: 43 additions & 0 deletions simpleml/tests/integration/test_sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
'''
SqlAlchemy specific tests
'''

__author__ = 'Elisha Yadgaran'


import unittest

from simpleml.datasets import Dataset


class MutableJSONTests(unittest.TestCase):
'''Default sqlalchemy behavior treats JSON data as immutable'''

def test_modifying_json_field(self):
'''
Top level JSON change
'''
persistable = Dataset(name='top_level_json_modification_test')
persistable._external_file = 'datadata'
persistable.save()

persistable.metadata_['new_key'] = 'blah'
self.assertIn(persistable, persistable._session.dirty)
persistable._session.refresh(persistable)

def test_modifying_nested_json_field(self):
'''
Nested JSON change
'''
persistable = Dataset(name='nested_json_modification_test')
persistable.metadata_['new_key'] = {}
persistable._external_file = 'datadata'
persistable.save()

persistable.metadata_['new_key']['sub_key'] = 'blah'
self.assertIn(persistable, persistable._session.dirty)
persistable._session.refresh(persistable)


if __name__ == '__main__':
unittest.main(verbosity=2)

0 comments on commit 50f63ce

Please sign in to comment.