In [13]:
from sqlalchemy import create_engine
from sqlalchemy.engine.base import Engine
from sqlalchemy_utils import database_exists, create_database

from sqlalchemy.orm import declarative_base, Session, sessionmaker
from sqlalchemy import Column, Integer, String, TypeDecorator
# from sqlalchemy.types import Column, String, TypeDecorator
# from sqlalchemy.ext.declarative import declarative_base
import pandas as pd

class HexByteString(TypeDecorator):
    """
        Column type that persists ``bytes`` as a hexadecimal string.

        Used to store pickled model weights in PostgreSQL through a plain
        ``String`` column: values are hex-encoded on the way in and decoded
        back to ``bytes`` on the way out. ``None`` passes through unchanged
        in both directions so that NULL columns work.
    """

    impl = String
    # The type holds no per-instance state, so it is safe for SQLAlchemy to
    # use it as part of a statement cache key (avoids the 1.4 cache warning).
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """
            Encode a Python value for storage.

            Returns the hex string for ``bytes`` input and ``None`` for
            ``None``. Raises ``TypeError`` for any other type.
        """
        if value is None:
            # NULL must be allowed: a row may be flushed before a model is fitted.
            return None
        if not isinstance(value, bytes):
            raise TypeError("HexByteString columns support only bytes values.")
        return value.hex()

    def process_result_value(self, value, dialect):
        """
            Decode a stored hex string back to ``bytes``.

            Only a database NULL maps to ``None``; an empty string decodes to
            ``b''`` so that empty payloads round-trip correctly.
        """
        if value is None:
            return None
        return bytes.fromhex(value)


# Base = declarative_base()
# class MyModel(Base):
#     data = Column(HexByteString)

import json

from typing import Dict, Any, List
from functools import lru_cache

In [14]:
## Connect to db
def connec_to_db() -> Engine:
    """
        Build an engine for the PostgreSQL database, creating the database
        when it does not exist yet.

        Connection settings are read from the environment variables
        ``POSTGRES_HOST``, ``POSTGRES_USER``, ``POSTGRES_PASSWORD`` and
        ``POSTGRES_DB``. When a variable is unset, the previous development
        default is used, so existing setups keep working.

        Returns:
            Engine bound to the configured database.
    """
    # Local imports keep this cell runnable on its own.
    import os
    from urllib.parse import quote_plus

    postgress_url = os.environ.get("POSTGRES_HOST", "172.19.0.2")
    postgress_password = os.environ.get("POSTGRES_PASSWORD", "password")
    postgress_user = os.environ.get("POSTGRES_USER", "postgres")
    postgress_db = os.environ.get("POSTGRES_DB", "test")

    # User and password are URL-quoted so that characters such as '@' or '/'
    # in a real password do not corrupt the connection URL.
    engine = create_engine(
        f"postgresql+psycopg2://{quote_plus(postgress_user)}:"
        f"{quote_plus(postgress_password)}@{postgress_url}/"
        f"{postgress_db}"
    )
    if not database_exists(engine.url):
        create_database(engine.url)

    return engine

In [15]:
import pickle
from io import BytesIO

Base = declarative_base()


class ModelInstance(Base):
    """
        ORM record for one model instance.

        A row stores the dotted import path of the model class, its
        constructor parameters as JSON, the pickled fitted model, and the
        feature/target column names captured at fit time.
    """
    __tablename__ = "model_instance"

    model_name = Column(String, primary_key=True)
    model_type = Column(String)
    fit_params_json = Column(String)       # JSON object passed as **kwargs to the model class
    python_library_path = Column(String)   # dotted path, e.g. "sklearn.ensemble.RandomForestClassifier"
    model_bin = Column(HexByteString)      # pickled fitted model
    features = Column(String)              # JSON list of feature column names
    target_column = Column(String)

    def __repr__(self):
        # One field per line; the newline separators were previously missing
        # after every field except the first.
        return (
            f"model_name={self.model_name}\n"
            f"model_type={self.model_type}\n"
            f"fit_params_json={self.fit_params_json}\n"
            f"python_library_path={self.python_library_path}"
        )

    def _get_features(self):
        """
            Return the stored feature column names as a Python list.
        """
        # NOTE(review): this debug dump to the working directory looks like a
        # leftover; it is kept because later notebook cells read the file.
        # The handle is now closed deterministically.
        with open("xyu.txt", 'w+') as debug_file:
            debug_file.write(self.features)
        return json.loads(self.features)

    def fit(self,
            data: pd.DataFrame,
            target_column: str = 'y'
           ) -> None:
        """
            Instantiate the configured model class, fit it and store the result.

            Every column of ``data`` except ``target_column`` is used as a
            feature. The feature list, target name and pickled model are
            written to this instance; persisting them is up to the caller.

            Raises:
                ValueError: if ``target_column`` is not a column of ``data``.
        """
        if target_column not in data.columns:
            raise ValueError(
                f"target column {target_column!r} not found in data columns"
            )

        model_class = self._import_sklearn_model_class()
        # A missing parameter JSON is treated as "no constructor arguments".
        fit_params = json.loads(self.fit_params_json or "{}")
        model = model_class(**fit_params)

        self.target_column = target_column
        self.features = json.dumps(
            [column for column in data.columns if column != target_column]
        )

        model.fit(X = data[self._get_features()],
                  y = data[self.target_column]
                 )

        self.model_bin = ModelInstance._model_to_buff(model)

    def predict(self, data: pd.DataFrame) -> Dict[Any, Any]:
        """
            Predict for every row of ``data`` using the stored model.

            Returns a dict mapping each index label of ``data`` to its
            prediction. ``data`` itself is not modified.
        """
        model = self._get_model()
        predictions = model.predict(X = data[self._get_features()])
        # Build a separate Series instead of adding a 'predict' column to the
        # caller's frame, which would leak into a later fit() as a feature.
        return pd.Series(predictions, index=data.index).to_dict()

    def _import_sklearn_model_class(self):
        """
            Resolve ``python_library_path`` to the class object it names.

            The text after the last dot is the class name; everything before
            it is the module to import, so paths of any depth are supported.

            Raises:
                ValueError: if the path contains no module part.
        """
        import importlib

        # NOTE(review): the path comes from the database and is imported
        # dynamically; only store trusted values in this column.
        module_path, _, class_name = self.python_library_path.rpartition('.')
        if not module_path:
            raise ValueError(
                f"python_library_path must be a dotted path, got {self.python_library_path!r}"
            )
        return getattr(importlib.import_module(module_path), class_name)

    @classmethod
    def _model_to_buff(cls, model_python) -> bytes:
        """
            Serialize a model object to bytes with pickle.
        """
        return pickle.dumps(model_python)

    @classmethod
    def _buff_to_model(cls, model_bin) -> Any:
        """
            Deserialize a model object from pickled bytes.
        """
        # SECURITY: pickle can execute arbitrary code while loading; only
        # unpickle bytes written by this application to a trusted database.
        return pickle.loads(model_bin)

    def _get_model(self) -> Any:
        """
            Return the fitted model unpickled from ``model_bin``.
        """
        return ModelInstance._buff_to_model(self.model_bin)

    

In [4]:
## Drop the table (if present) and recreate it from the ORM metadata
engine = connec_to_db()
# checkfirst=True makes the drop a no-op on a fresh database, so this cell
# also works on the very first run instead of failing on a missing table.
ModelInstance.__table__.drop(engine, checkfirst=True)
Base.metadata.create_all(engine)

# NOTE: rebinding `Session` shadows the sqlalchemy.orm.Session class imported
# in the first cell; from here on `Session` is the session factory.
Session = sessionmaker(bind=engine)
session = Session()

In [5]:
# ToDo: check whether a model with this name already exists before inserting
model_instance = ModelInstance(
    model_name="test_real90",
    model_type="RandomForestClassifier",
    fit_params_json="{}",
    python_library_path="sklearn.ensemble.RandomForestClassifier",
)

# Fit on the iris dataset, then persist the fitted instance.
data = pd.read_csv('fastapi_microservice/datasets/iris/data.csv')
model_instance.fit(data=data, target_column="y")

session.add(model_instance)
session.commit()

In [14]:
# Remove the debug dump that ModelInstance._get_features writes to the working directory.
!rm -rf xyu.txt

In [6]:
# Show the feature list most recently dumped by ModelInstance._get_features.
!cat xyu.txt

["0", "1", "2", "3"]

In [10]:
# o = model_instance.get_model()
# .predict(data_input=data)

In [12]:
data = pd.read_csv('fastapi_microservice/datasets/iris/data.csv')
# Session.get is the SQLAlchemy 1.4+ primary-key lookup; it replaces the
# legacy Query.get call.
session.get(ModelInstance, "test_real90").predict(data)

{0: 0,
 1: 0,
 2: 0,
 3: 0,
 4: 0,
 5: 0,
 6: 0,
 7: 0,
 8: 0,
 9: 0,
 10: 0,
 11: 0,
 12: 0,
 13: 0,
 14: 0,
 15: 0,
 16: 0,
 17: 0,
 18: 0,
 19: 0,
 20: 0,
 21: 0,
 22: 0,
 23: 0,
 24: 0,
 25: 0,
 26: 0,
 27: 0,
 28: 0,
 29: 0,
 30: 0,
 31: 0,
 32: 0,
 33: 0,
 34: 0,
 35: 0,
 36: 0,
 37: 0,
 38: 0,
 39: 0,
 40: 0,
 41: 0,
 42: 0,
 43: 0,
 44: 0,
 45: 0,
 46: 0,
 47: 0,
 48: 0,
 49: 0,
 50: 1,
 51: 1,
 52: 1,
 53: 1,
 54: 1,
 55: 1,
 56: 1,
 57: 1,
 58: 1,
 59: 1,
 60: 1,
 61: 1,
 62: 1,
 63: 1,
 64: 1,
 65: 1,
 66: 1,
 67: 1,
 68: 1,
 69: 1,
 70: 1,
 71: 1,
 72: 1,
 73: 1,
 74: 1,
 75: 1,
 76: 1,
 77: 1,
 78: 1,
 79: 1,
 80: 1,
 81: 1,
 82: 1,
 83: 1,
 84: 1,
 85: 1,
 86: 1,
 87: 1,
 88: 1,
 89: 1,
 90: 1,
 91: 1,
 92: 1,
 93: 1,
 94: 1,
 95: 1,
 96: 1,
 97: 1,
 98: 1,
 99: 1,
 100: 2,
 101: 2,
 102: 2,
 103: 2,
 104: 2,
 105: 2,
 106: 2,
 107: 2,
 108: 2,
 109: 2,
 110: 2,
 111: 2,
 112: 2,
 113: 2,
 114: 2,
 115: 2,
 116: 2,
 117: 2,
 118: 2,
 119: 2,
 120: 2,
 121: 2,
 122: 2,
 12

In [16]:
data = pd.read_csv('fastapi_microservice/datasets/iris/data.csv')
# Load the row once and reuse it. The model accessor is `_get_model`
# (there is no `get_model`), and "test_real90" is the only model this
# notebook inserts after the table is recreated.
stored_model = session.get(ModelInstance, "test_real90")
stored_model._get_model().predict(data[stored_model._get_features()])


array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

In [20]:
# first() fetches a single row instead of loading every stored model
# (including its pickled weights) just to display one; it yields None on an
# empty table rather than raising IndexError.
session.query(ModelInstance).first()

model_name=test_real90
model_type=RandomForestClassifierfit_params_json={}python_library_path=sklearn.ensemble.RandomForestClassifier