# Fetch models from the Hugging Face Hub, then tag and store them in the MLflow Unity Catalog registry

In [0]:
%pip install --upgrade accelerate mlflow-skinny optree>=0.13.0 torch torchvision transformers


%restart_python

In [0]:
import mlflow
from mlflow import MlflowClient


# Register models against the "databricks-uc" registry URI.
mlflow.set_registry_uri("databricks-uc")

# Client used in later cells to assign aliases to registered model versions.
client = MlflowClient()

In [0]:
# Catalog and schema that prefix every registered model name in this notebook
# (models are registered as "<CATALOG>.<SCHEMA>.<model name>").
CATALOG = "amine_elhelou"
SCHEMA = "ray_gtm_examples"
# NOTE(review): VOLUME is not referenced by the cells shown here; presumably it
# names the volume holding the audio data - confirm against the other notebooks.
VOLUME = "transcribe-data"

## ASR Model

In [0]:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline


# Run on the first GPU when CUDA is available; otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float32  # switch to torch.float16 for model size reduction

model_id = "openai/whisper-large-v3-turbo"

# Load the speech-to-text model and move it onto the selected device.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

# The processor supplies the tokenizer and feature extractor the pipeline needs.
processor = AutoProcessor.from_pretrained(model_id)

# Inference settings for the transcription pipeline.
model_args = {
    "chunk_length_s": 30,
    "language": "en",
}

# `model_kwargs` is forwarded to `from_pretrained`, and the model here is
# already instantiated, so inference settings placed there never take effect.
# Pass them as pipeline / generation arguments instead.
asr_pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
    chunk_length_s=model_args["chunk_length_s"],
    generate_kwargs={"language": model_args["language"]},
)

In [0]:
import mlflow


# The registered model is named after the last "/"-separated segment of model_id.
ASR_MODEL_NAME = model_id.rpartition("/")[-1]
asr_registered_name = f"{CATALOG}.{SCHEMA}.{ASR_MODEL_NAME}"

# Log the transcription pipeline in its own run; passing registered_model_name
# registers the logged model under the catalog/schema namespace in the same call.
with mlflow.start_run(run_name="whisper-transcriber-log-pipeline"):
    model_info = mlflow.transformers.log_model(
        transformers_model=asr_pipe,
        artifact_path="whisper_transcriber",
        # NOTE(review): looks like a placeholder path rather than a real audio file.
        input_example="/path/to/audio.file",
        registered_model_name=asr_registered_name,
    )

In [0]:
# Point the "production" alias at the ASR model version registered above.
client.set_registered_model_alias(
    alias="production",
    name=f"{CATALOG}.{SCHEMA}.{ASR_MODEL_NAME}",
    version=model_info.registered_model_version,
)

## PII Redaction Model

In [0]:
# NOTE: this rebinds `model_id`, which the ASR cells also read; re-running the
# ASR logging cell after this one would derive its name from this value.
model_id = "iiiorg/piiranha-v1-detect-personal-information"
PII_MODEL_NAME = model_id.split("/")[-1]  # for registered model name

# Fall back to CPU when CUDA is unavailable instead of hard-coding "cuda:0",
# matching the device selection used for the ASR model.
pii_device = "cuda:0" if torch.cuda.is_available() else "cpu"
pii_pipeline = pipeline("ner", model=model_id, device=pii_device)

# Log the NER pipeline in its own run; passing registered_model_name registers
# it under the catalog/schema namespace in the same call.
with mlflow.start_run(run_name="pii-redactor-log-pipeline"):
    pii_model_info = mlflow.transformers.log_model(
        transformers_model=pii_pipeline,
        artifact_path="pii_model",
        input_example="Sample text with PII",
        registered_model_name=f"{CATALOG}.{SCHEMA}.{PII_MODEL_NAME}",
    )

In [0]:
# Point the "production" alias at the PII model version registered above.
client.set_registered_model_alias(
    alias="production",
    name=f"{CATALOG}.{SCHEMA}.{PII_MODEL_NAME}",
    version=pii_model_info.registered_model_version,
)