# Baseten custom model deployment example

<a href="https://colab.research.google.com/github/basetenlabs/demos/blob/main/deployment/baseten_custom.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%pip install --upgrade sklearn baseten

In [None]:
# Train a random-forest classifier on the classic Iris dataset.

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()
# Feature matrix and class labels, taken from the loaded dataset bunch.
data_x, data_y = iris['data'], iris['target']

model = RandomForestClassifier()
model.fit(data_x, data_y)

In [None]:
# Scaffold a new Truss in ./rfc_truss. The following cells fill in the
# serialized model, the serving code and the Python requirements.

import truss

TRUSS_DIR = "rfc_truss"
truss.init(TRUSS_DIR)

In [None]:
# Generate model files

import joblib

# Serialize the trained model into the Truss's data/ folder; the serving code
# below loads it back from 'data/rfc_model.pkl'.
joblib.dump(model, 'rfc_truss/data/rfc_model.pkl', compress=True)

# Model class written to rfc_truss/model/model.py.
MODEL_CODE_AS_STR = """
import joblib


class Model:
    def load(self):
        # Relative path: assumes the serving process runs from the Truss
        # root (where data/ lives) -- TODO confirm against the Truss docs.
        # joblib.load accepts a filename and closes the file itself, so no
        # file handle is left open.
        self.model = joblib.load('data/rfc_model.pkl')

    def predict(self, model_input):
        output = self.model.predict(model_input)
        # scikit-learn returns a NumPy array; convert it to a plain list so
        # the response is JSON-serializable.
        return {"predictions": output.tolist()}
"""

with open("rfc_truss/model/model.py", "w", encoding="utf-8") as py_file:
    py_file.write(MODEL_CODE_AS_STR)

In [None]:
# Add dependencies

import joblib
import sklearn
import truss

tr = truss.from_directory("rfc_truss")
# Pin the serving environment to the exact library versions that pickled the
# model. The install cell uses `--upgrade` without version pins, so a
# hard-coded pin can drift from the version that actually produced the
# pickle. scikit-learn does not support loading pickles across versions, and
# this can fail outright for tree-based estimators.
tr.add_python_requirement(f"joblib=={joblib.__version__}")
tr.add_python_requirement(f"scikit-learn=={sklearn.__version__}")

In [None]:
# Model deployment

import getpass
import os

import baseten

# Read the API key from the BASETEN_API_KEY environment variable, or prompt
# for it, instead of pasting it into the notebook. A pasted key leaks when the
# notebook is shared or committed. `api_key` is reused by the REST example
# below.
api_key = os.environ.get("BASETEN_API_KEY") or getpass.getpass("Baseten API key: ")
baseten.login(api_key)

baseten.deploy_truss(
    tr,
    model_name="Iris RFC Model (custom deployment)",
)

In [None]:
# Once the deployment has finished, invoke the new model through the SDK.

model_input = [[0, 0, 0, 0]]  # a single sample with four feature values
deployed_model_id = "PASTE VERSION ID HERE"  # copy the version ID from the deployed model's page

deployed_model = baseten.deployed_model_version_id(deployed_model_id)
deployed_model.predict(model_input)

In [None]:
# Or call your model as an API

import requests
import json

endpoint = f"https://app.baseten.co/model_versions/{deployed_model_id}/predict"
auth = {
    "Authorization": f"Api-Key {api_key}",
    # requests sets no Content-Type when `data` is a pre-encoded string,
    # so declare the JSON body explicitly.
    "Content-Type": "application/json",
}
data = json.dumps(model_input)

resp = requests.post(
    endpoint,
    headers=auth,
    data=data,
    timeout=120,  # requests never times out unless a timeout is given
)
# Surface HTTP errors (bad key, unknown version ID, ...) together with the
# server's response body, instead of failing later on an unexpected payload.
if not resp.ok:
    raise RuntimeError(
        f"Prediction request failed with HTTP {resp.status_code}: {resp.text}"
    )

resp.json()