In [None]:
from my_model import MyModel

In [None]:
import pandas as pd
import numpy as np
import mlflow
from mlflow.models.signature import ModelSignature
from mlflow.types import DataType, Schema, ColSpec, ParamSchema, ParamSpec

# Model signature: a single string column in ("prompt"), a single string
# column out ("text_from_llm"), and an empty inference-parameter schema.
input_schema = Schema([ColSpec(DataType.string, "prompt")])
output_schema = Schema([ColSpec(DataType.string, "text_from_llm")])
parameters = ParamSchema([])

signature = ModelSignature(
    inputs=input_schema,
    outputs=output_schema,
    params=parameters,
)

# One-row example frame matching the input schema above.
input_example = pd.DataFrame({"prompt": ["What is machine learning?"]})

In [None]:
# Make sure a registered model named `model_name` exists: try to create it,
# and fall back to fetching the existing entry if the registry rejects the
# creation request.
client = mlflow.MlflowClient()
model_name = "llama2-guanaco-sft"
registered_model = None
try:
    registered_model = client.create_registered_model(model_name)
except mlflow.exceptions.MlflowException:
    # Only MLflow errors trigger the fallback, so KeyboardInterrupt, SystemExit
    # and programming errors are no longer silently swallowed. If the name is
    # already registered, the lookup below succeeds; for any other registry
    # failure the lookup raises as well and the error chain shows both causes.
    registered_model = client.get_registered_model(model_name)

In [None]:
import peft
import trl
import torch
import transformers


from mlflow.store.artifact.runs_artifact_repo import RunsArtifactRepository

# Base torch release: keep only the part of the version string before any "+"
# suffix so the requirement below is a plain version number.
torch_version = torch.__version__.split("+")[0]

# Packages recorded alongside the model; the first four are pinned to the
# versions installed in the current environment.
pip_requirements = [
    f"torch=={torch_version}",
    f"transformers=={transformers.__version__}",
    f"peft=={peft.__version__}",
    f"trl=={trl.__version__}",
    "einops",
    "sentencepiece",
]

# Inside a fresh MLflow run: log the MyModel wrapper as a pyfunc model together
# with its signature and input example, then register the logged artifacts as
# a new version of the registered model `model_name`.
with mlflow.start_run() as run:
    model_info = mlflow.pyfunc.log_model(
        model_name,
        python_model=MyModel(),
        # NOTE: the artifacts mapping is critical -- this dict is what the
        # load_context() method of MyModel reads.
        artifacts={"snapshot": "/mnt/"},
        pip_requirements=pip_requirements,
        input_example=input_example,
        signature=signature,
    )
    runs_uri = model_info.model_uri
    print(runs_uri)

    # Resolve the logged model's URI to its underlying artifact location and
    # create a new model version from it, linked to this run.
    model_src = RunsArtifactRepository.get_underlying_uri(runs_uri)
    mv = client.create_model_version(model_name, model_src, run.info.run_id)

    # Summary of the new model version, one field per line.
    print(
        "Name: {}".format(mv.name),
        "Version: {}".format(mv.version),
        "Description: {}".format(mv.description),
        "Status: {}".format(mv.status),
        "Stage: {}".format(mv.current_stage),
        sep="\n",
    )