Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/model grouping #2

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ _pip-pypi-elemeno: clean
test:
pytest

dev:
pip install -r requirements.txt
pip install -r requirements-dev.txt

pip-testpypi: clean _pip-testpypi

pip-pypi: clean _pip-pypi
Expand Down Expand Up @@ -51,4 +55,4 @@ bump-major:
bump-dev:
bumpversion build --tag --verbose
@echo "New version: v$$(python setup.py --version)"
@echo "Make sure to push the new tag to GitHub"
@echo "Make sure to push the new tag to GitHub"
2 changes: 2 additions & 0 deletions elemeno.yaml.TEMPLATE
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ cos:
use_ssl: False
bucket: elemeno-cos
registry:
endpoint_url: http://localhost:9005
tracking_url: http://mlflow.tracking.url:80
ignore_tls: True
feature_store:
feast_config_path: .
registry: gs://elemeno-feature-store/generic_registry
Expand Down
8 changes: 4 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
feast[redis,gcp,aws]==0.31.0
GitPython==3.1.26
Jinja2==3.0.3
mlflow==1.26.1
mlflow==2.4.0
numpy>=1.22
omegaconf>=2.1.0
onnx==1.10.2
# onnx==1.10.2
pandas>=1.2.4
pandas_gbq==0.15.0
setuptools>=59.5.0
sphinx_rtd_theme==1.0.0
tf2onnx==1.9.3
#onnxruntime>=1.8.1
onnxruntime>=1.8.1
elasticsearch[async]==8.1.3
redshift_connector==2.0.904
sqlalchemy-redshift==0.8.9
Expand All @@ -20,7 +20,7 @@ psycopg2-binary==2.9.4
pyspark==3.3.0
asyncio==3.4.3
minio>=7.1.3
#pip install -i https://test.pypi.org/simple/ onnxruntime==1.8.2.dev20210816004 # do when using mac arm64
# pip install -i https://test.pypi.org/simple/ onnxruntime==1.8.2.dev20210816004 # do when using mac arm64
#torch==1.10.2
# for 3.10 install 1.11

25 changes: 20 additions & 5 deletions src/elemeno_ai_sdk/ml/registry.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
import torch
import mlflow
import mlflow # type: ignore
import mlflow.pyfunc
import onnx
from elemeno_ai_sdk.config import Configs
from typing import Dict, List, Optional


class ModelRegistry:

Expand All @@ -12,12 +14,25 @@ def __init__(self):
self._config = cfg
os.environ["AWS_ACCESS_KEY_ID"] = cfg.cos.key_id
os.environ["AWS_SECRET_ACCESS_KEY"] = cfg.cos.secret
#TODO fix this to load from config file
os.environ["MLFLOW_S3_IGNORE_TLS"] = "true"
minio_host = cfg.cos.host
os.environ["MLFLOW_S3_ENDPOINT_URL"] = f"http://{minio_host}"
os.environ["MLFLOW_S3_IGNORE_TLS"] = cfg.registry.ignore_tsl
os.environ["MLFLOW_S3_ENDPOINT_URL"] = cfg.registry.endpoint_url
mlflow.set_tracking_uri(cfg.registry.tracking_url)
self.client = mlflow.MlflowClient()

def tag_model(self, model_name: str, tags: Dict[str, str]) -> None:
    """Attach every entry of ``tags`` to a registered model.

    Calls ``self.client.set_registered_model_tag`` once per key/value pair,
    following the iteration order of the mapping.

    Args:
        model_name: Name of the registered model to tag.
        tags: Mapping of tag key to tag value.
    """
    for key, value in tags.items():
        registry_client = self.client
        registry_client.set_registered_model_tag(
            name=model_name,
            key=key,
            value=value,
        )

def save_model(self, model_file: str, model_name: str, tags: Optional[Dict[str, str]] = None) -> None:
    """Log a pyfunc model in a new MLflow run and register it as ``model_name``.

    ``tags`` is passed to ``mlflow.start_run``; when it is not ``None`` the
    same tags are then applied to the registered model via ``tag_model``.

    NOTE(review): ``mlflow.pyfunc.log_model`` is called with only
    ``artifact_path`` and ``registered_model_name``. MLflow appears to
    require ``python_model`` or ``loader_module`` as well -- confirm
    against the MLflow version in use.

    Args:
        model_file: Value forwarded as ``artifact_path`` to ``log_model``.
        model_name: Name under which the model is registered.
        tags: Optional mapping of tag key to tag value.
    """
    with mlflow.start_run(tags=tags):
        mlflow.pyfunc.log_model(
            artifact_path=model_file,
            registered_model_name=model_name,
        )

    # Nothing to tag on the registered model when no tags were supplied.
    if tags is None:
        return
    self.tag_model(model_name=model_name, tags=tags)

def get_models_by_tag(self, tag: str) -> List[str]:
    """Return the names of registered models whose ``value`` tag equals ``tag``.

    The search is delegated to ``self.client.search_registered_models`` with
    a filter on the tag whose key is literally ``value``.

    NOTE(review): the tag key is hard-coded to ``value``; confirm that this
    is the key the grouping feature writes.

    Args:
        tag: Tag value to match. It is embedded in the filter as a
            single-quoted string, so it must not contain a single quote.

    Returns:
        The ``name`` attribute of every model returned by the search.

    Raises:
        ValueError: If ``tag`` contains a single quote, which would break
            out of the quoted filter value.
    """
    if "'" in tag:
        raise ValueError(f"tag must not contain a single quote: {tag!r}")
    # MLflow requires string values in tag comparisons to be quoted; an
    # unquoted value is rejected by the filter parser.
    filter_string = f"tag.value = '{tag}'"
    models = self.client.search_registered_models(filter_string=filter_string)
    return [model.name for model in models]

def get_latest_model_torch(self, model_name: str, device: str):
"""Loads the most recent model registered and in stage Production
Expand Down