Skip to content

Commit

Permalink
model: Consolidated self.location as a property of baseclass
Browse files Browse the repository at this point in the history
- model: pytorch: Updated to use self.location property
- model: spacy: Updated to use self.location property
- model: tensorflow_hub: Updated to use self.location property
  • Loading branch information
programmer290399 committed Nov 30, 2021
1 parent 66f965a commit 669ac3d
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 72 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Dev CMD to remove unused imports, `$ dffml service dev lint imports`
- Helper for creating a blank generic Python package
`$ dffml service dev create blank mypackage`
- Added `is_trained` flag to all models
- `is_trained` flag to all models
- Dynamic `location` property to `Model` baseclass.
### Changed
- Calls to hashlib now go through helper functions
- Build docs using `dffml service dev docs`
Expand Down
21 changes: 11 additions & 10 deletions dffml/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,7 @@ async def __aenter__(self):
in ["zip", "tar"],
]
):
temp_dir = self._get_directory()
self.location = temp_dir
self.create_temp_directory()
else:
self._make_config_location()

Expand Down Expand Up @@ -196,10 +195,9 @@ async def _run_operation(self, input_path, output_path, dataflow):
async for _, _ in run(dataflow):
pass

def _get_directory(self) -> pathlib.Path:
def create_temp_directory(self):
if not hasattr(self, "temp_dir"):
self.temp_dir = pathlib.Path(mkdtemp())
return self.temp_dir

def _make_config_location(self):
"""
Expand All @@ -212,6 +210,14 @@ def _make_config_location(self):
if not location.is_dir():
location.mkdir(mode=MODE_BITS_SECURE, parents=True)

@property
def location(self):
return (
self.config.location
if not hasattr(self, "temp_dir")
else self.temp_dir
)


class SimpleModelNoContext:
"""
Expand Down Expand Up @@ -298,12 +304,7 @@ def disk_path(self, extention: Optional[str] = None):
if "features" in exported:
exported["features"] = dict(sorted(exported["features"].items()))
# Hash the exported config
return pathlib.Path(
self.config.location
if not hasattr(self, "temp_dir")
else self.temp_dir,
"Model",
)
return pathlib.Path(self.location, "Model",)

def applicable_features(self, features):
usable = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,7 @@ def __init__(self, config):

async def __aenter__(self) -> "AutoSklearnModel":
await super().__aenter__()
self.path = self.filepath(
self.config.location
if not hasattr(self, "temp_dir")
else self.temp_dir,
"trained_model.sav",
)
self.path = self.filepath(self.location, "trained_model.sav",)
self.load_model()
return self

Expand Down
10 changes: 1 addition & 9 deletions model/pytorch/dffml_model_pytorch/pytorch_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,17 +362,9 @@ def createModel(self):
"Can't use createModel method from PyTorchModel"
)

@property
def base_path(self):
return (
self.config.location
if not hasattr(self, "temp_dir")
else self.temp_dir
)

@property
def model_path(self):
return self.base_path / "model.pt"
return self.location / "model.pt"

def _classifications(self, cids):
"""
Expand Down
16 changes: 4 additions & 12 deletions model/scikit/dffml_model_scikit/scikit_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,20 +58,12 @@ def __init__(self, config) -> None:
self.clf = None
self.joblib = importlib.import_module("joblib")

@property
def _filepath(self):
return (
self.config.location
if not hasattr(self, "temp_dir")
else self.temp_dir
)

async def __aenter__(self) -> "Scikit":
await super().__aenter__()
self.saved_filepath = self._filepath / "Scikit.json"
self.saved_filepath = self.location / "Scikit.json"
if self.saved_filepath.is_file():
self.saved = json.loads(self.saved_filepath.read_text())
self.clf_path = self._filepath / "ScikitFeatures.joblib"
self.clf_path = self.location / "ScikitFeatures.joblib"
if self.clf_path.is_file():
self.clf = self.joblib.load(str(self.clf_path))
self.is_trained = True
Expand All @@ -87,10 +79,10 @@ async def __aexit__(self, exc_type, exc_value, traceback):
class ScikitUnsprvised(Scikit):
async def __aenter__(self) -> "ScikitUnsprvised":
await super().__aenter__()
self.saved_filepath = self._filepath / "Scikit.json"
self.saved_filepath = self.location / "Scikit.json"
if self.saved_filepath.is_file():
self.saved = json.loads(self.saved_filepath.read_text())
self.clf_path = self._filepath / "ScikitUnsupervised.json"
self.clf_path = self.location / "ScikitUnsupervised.json"
if self.clf_path.is_file():
self.clf = self.joblib.load(str(self.clf_path))
self.is_trained = True
Expand Down
7 changes: 1 addition & 6 deletions model/spacy/dffml_model_spacy/ner/ner_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,12 +295,7 @@ def __init__(self, config):

async def __aenter__(self):
await super().__aenter__()
self.path = (
self.config.location
if not hasattr(self, "temp_dir")
else self.temp_dir
)
self.model_path = self.path / "ner"
self.model_path = self.location / "ner"
if self.model_path.exists():
self.nlp = spacy.load(self.model_path)
self.logger.debug("Loaded model from disk.")
Expand Down
10 changes: 1 addition & 9 deletions model/tensorflow/dffml_model_tensorflow/tf_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,9 @@ def _applicable_features(self):
if name in self.feature_columns
]

@property
def base_path(self):
return (
self.config.location
if not hasattr(self, "temp_dir")
else self.temp_dir
)

@property
def model_path(self):
return self.base_path / "DNNModel"
return self.location / "DNNModel"

async def __aenter__(self):
await super().__aenter__()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -350,14 +350,6 @@ def __init__(self, config):
def model(self):
return self._model

@property
def base_path(self):
return (
self.config.location
if not hasattr(self, "temp_dir")
else self.temp_dir
)

@property
def model_folder_path(self):
_to_hash = self.features + [
Expand All @@ -366,7 +358,7 @@ def model_folder_path(self):
self.config.model_path,
]
model_name = secure_hash("".join(_to_hash), algorithm="sha384")
model_folder_path = self.base_path / model_name
model_folder_path = self.location / model_name
if not model_folder_path.exists():
model_folder_path.mkdir(parents=True, exist_ok=True)
return model_folder_path
Expand Down
12 changes: 2 additions & 10 deletions model/vowpalWabbit/dffml_model_vowpalWabbit/vw_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,14 +324,6 @@ def __init__(self, config) -> None:
self.clf = None
self.is_trained = self.model_filename.exists()

@property
def _base_path(self):
return (
self.config.location
if not hasattr(self, "temp_dir")
else self.temp_dir
)

def applicable_features(self, features):
usable = []
for feature in features:
Expand Down Expand Up @@ -364,7 +356,7 @@ def _feature_predict_hash(self):
@property
def model_filename(self):
model_name = self._features_hash + ".vw"
return self._base_path / model_name
return self.location / model_name

def _load_model(self):
formatted_args = ""
Expand Down Expand Up @@ -399,7 +391,7 @@ def _saved_file_path(self):
saved_file_name = (
secure_hash(self.config.predict.name, algorithm="sha384") + ".json"
)
return self._base_path / saved_file_name
return self.location / saved_file_name

async def __aenter__(self) -> "VWModel":
await super().__aenter__()
Expand Down

0 comments on commit 669ac3d

Please sign in to comment.