# Fine-Tuning Transformers with MLflow for Enhanced Model Management

### Role of MLflow in Model Lifecycle 

- Training Cycle Logging: Keeping a detailed log of the training cycle, including parameters, metrics, and intermediate results.
- Model Logging and Management: Separately logging the trained model, tracking its version, and managing its lifecycle after training.
- Inference and Deployment: Using the logged model for inference, ensuring a smooth transition from training to deployment.

In [1]:
%env TOKENIZERS_PARALLELISM=false

import warnings

warnings.filterwarnings("ignore", category=UserWarning)

env: TOKENIZERS_PARALLELISM=false


### Preparing the Dataset and Environment for Fine-Tuning
#### Key Steps for This Section

1. **Importing Necessary Libraries**
2. **Loading the Dataset**
3. **Splitting the Dataset**

In [None]:
import evaluate
import numpy as np
from datasets import load_dataset

from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    pipeline,
)

import mlflow

sms_dataset = load_dataset("sms_spam")

# split train/test by an 80/20 ratio

sms_train_test = sms_dataset["train"].train_test_split(test_size=0.2)
train_dataset = sms_train_test["train"]
test_dataset = sms_train_test["test"]
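`train_test_split(test_size=0.2)` shuffles before it cuts, so without a `seed` argument every run of the cell above produces a different split. The stdlib sketch below mimics that behaviour under stated assumptions: `split_80_20` is a hypothetical helper, and rounding the test share up with `math.ceil` is how we understand the library to size a fractional test split. Passing `seed=` to `train_test_split` gives the same reproducibility in the real call.

```python
import math
import random


def split_80_20(rows, test_size=0.2, seed=22):
    """Hypothetical helper: shuffle a copy of `rows` with a fixed seed, then split it."""
    rng = random.Random(seed)                       # private RNG, global random state untouched
    shuffled = list(rows)
    rng.shuffle(shuffled)
    n_test = math.ceil(test_size * len(shuffled))   # round the test share up (assumption)
    return shuffled[n_test:], shuffled[:n_test]     # (train, test)


train_rows, test_rows = split_80_20(range(10))
print(len(train_rows), len(test_rows))  # 8 2
```

Because the seed is fixed, calling `split_80_20` twice on the same rows returns identical splits.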

### Tokenization and Dataset Preparation 
#### Tokenization Process 
- **Loading the Tokenizer**
- **Defining the Tokenization Function**
- **Applying Tokenization to the Dataset**

In [None]:
# load the tokenizer that matches the pretrained model
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")


def tokenize_function(examples):
    # pad or truncate every message to exactly 128 tokens
    return tokenizer(
        examples["sms"],
        padding="max_length",
        truncation=True,
        max_length=128,
    )


seed = 22

# transform data into model inputs: tokenize in batches, drop the raw text, shuffle reproducibly

train_tokenized = train_dataset.map(tokenize_function, batched=True)
train_tokenized = train_tokenized.remove_columns(["sms"]).shuffle(seed=seed)

test_tokenized = test_dataset.map(tokenize_function, batched=True)
test_tokenized = test_tokenized.remove_columns(["sms"]).shuffle(seed=seed)
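To see what `padding="max_length"`, `truncation=True`, and `max_length` do to each message, here is a toy stand-in. It is a hypothetical whitespace "tokenizer" with a made-up vocabulary and arbitrary pad/unknown ids; the real DistilBERT tokenizer splits into WordPiece subwords and adds special tokens such as `[CLS]` and `[SEP]`, which this sketch skips.

```python
def toy_encode(text, vocab, max_length=8, pad_id=0, unk_id=1):
    """Hypothetical whitespace 'tokenizer' showing truncation and max-length padding."""
    ids = [vocab.get(token, unk_id) for token in text.lower().split()]
    ids = ids[:max_length]                      # truncation=True: cut long messages
    n_pad = max_length - len(ids)               # padding="max_length": fill the rest
    return {
        "input_ids": ids + [pad_id] * n_pad,
        "attention_mask": [1] * len(ids) + [0] * n_pad,  # 0 marks padding positions
    }


vocab = {"free": 2, "prize": 3, "call": 4, "now": 5}
print(toy_encode("Call now free prize", vocab, max_length=6))
# {'input_ids': [4, 5, 2, 3, 0, 0], 'attention_mask': [1, 1, 1, 1, 0, 0]}
```

Every encoded message has the same length, which is what lets the Trainer stack them into fixed-size batches; the attention mask tells the model which positions are padding.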


### Model Initialization and Label Mapping 
#### Setting Up the Label Mapping
- Defining Label Mappings
#### Initializing the Model 
- Model Selection
- Model Configuration

In [None]:
# label mapping
id2label = {0: "ham", 1: "spam"}
label2id = {"ham": 0, "spam": 1}

# model selection: DistilBERT with a newly initialized two-class classification head

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=2,
    label2id=label2id,
    id2label=id2label,
)
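The hand-written `id2label` and `label2id` dictionaries must be exact inverses, or the model config will report the wrong class names. A minimal sketch that derives one from the other and decodes predicted class ids; `decode` is a hypothetical helper, not part of the Transformers API.

```python
id2label = {0: "ham", 1: "spam"}
# Derive the inverse mapping so the two dictionaries can never drift apart.
label2id = {label: idx for idx, label in id2label.items()}


def decode(pred_ids):
    """Turn predicted class ids (e.g. argmax outputs) into human-readable labels."""
    return [id2label[i] for i in pred_ids]


print(label2id)            # {'ham': 0, 'spam': 1}
print(decode([1, 0, 1]))   # ['spam', 'ham', 'spam']
```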

### Setting up Evaluation Metrics

#### Choosing and Loading the Metric 
- Metric Selection
- Loading the Metric
#### Defining the Metric Computation Function
- Function for Metric Computation
- Processing Predictions


In [None]:
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    return metric.compute(predictions=predictions, references=labels)
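`compute_metrics` receives raw logits, picks the highest-scoring class per row with `np.argmax`, and compares those predictions with the labels. The pure-Python sketch below reproduces that logic on made-up logits so the arithmetic is visible; `argmax` and `accuracy_from_logits` are illustrative helpers. During evaluation the Trainer prefixes returned metric names with `eval_`, so this value is logged as `eval_accuracy`.

```python
def argmax(row):
    """Index of the largest value; ties resolve to the first index, like np.argmax."""
    return max(range(len(row)), key=row.__getitem__)


def accuracy_from_logits(logits, labels):
    """Pure-Python stand-in for compute_metrics: argmax per row, then accuracy."""
    predictions = [argmax(row) for row in logits]
    correct = sum(p == y for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}


# Made-up logits for four messages (columns: ham, spam) and their true labels.
toy_logits = [[2.0, -1.0], [0.1, 0.9], [1.5, 0.2], [-0.3, 0.4]]
toy_labels = [0, 1, 1, 1]
print(accuracy_from_logits(toy_logits, toy_labels))  # {'accuracy': 0.75}
```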



### Configuring the Training Environment 

#### Training Arguments Configuration
- **Defining the Output Directory**
- **Specifying Training Arguments**

#### Initializing the Trainer
- **Creating the Trainer Instance**
- **Role of the Trainer**


In [None]:
training_output_dir = "/tmp/sms_trainer"

training_args = TrainingArguments(
    output_dir=training_output_dir,
    # evaluate on the test split at the end of every epoch;
    # newer transformers releases rename this argument to `eval_strategy`
    evaluation_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    logging_steps=8,
    num_train_epochs=3,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=test_tokenized,
    compute_metrics=compute_metrics,
)
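A quick back-of-the-envelope check of what these arguments imply. The sketch assumes a single device, no gradient accumulation, the last partial batch kept, and roughly 4,459 training rows (80% of the 5,574 messages in `sms_spam`, with the test share rounded up); `training_schedule` is a hypothetical helper.

```python
import math


def training_schedule(n_train, batch_size=8, epochs=3, logging_steps=8):
    """Back-of-the-envelope step counts for the TrainingArguments above."""
    steps_per_epoch = math.ceil(n_train / batch_size)   # last partial batch still counts
    total_steps = steps_per_epoch * epochs
    return {
        "steps_per_epoch": steps_per_epoch,
        "total_steps": total_steps,
        "loss_logs": total_steps // logging_steps,      # one training-loss entry every 8 steps
        "eval_passes": epochs,                          # evaluation_strategy="epoch"
    }


print(training_schedule(4459))
# {'steps_per_epoch': 558, 'total_steps': 1674, 'loss_logs': 209, 'eval_passes': 3}
```

With `logging_steps=8`, that works out to about two hundred loss points in MLflow, which is plenty for a smooth training curve.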

In [None]:
# assumes a local tracking server, e.g. started with `mlflow server --host 127.0.0.1 --port 8080`
mlflow.set_tracking_uri("http://localhost:8080")
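If no tracking server is listening on that URI, logging calls fail later, in the middle of the workflow. A small pre-flight check can catch this early. The sketch below is a hypothetical helper that only tests whether a TCP connection to the URI's host and port succeeds; it does not verify that the listener is actually MLflow.

```python
import socket
from urllib.parse import urlparse


def tracking_server_reachable(uri, timeout=1.0):
    """Hypothetical pre-flight check: can we open a TCP connection to the tracking URI?"""
    parsed = urlparse(uri)
    port = parsed.port or (443 if parsed.scheme == "https" else 80)
    try:
        # Only proves something is listening on host:port, not that it is MLflow.
        with socket.create_connection((parsed.hostname, port), timeout=timeout):
            return True
    except OSError:
        return False


if not tracking_server_reachable("http://localhost:8080"):
    print("Start the server first: mlflow server --host 127.0.0.1 --port 8080")
```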

### Integrating MLflow for Experiment Tracking

#### Setting Up the Experiment
- **Naming the Experiment**
- **Role of MLflow in Training**

#### Benefits of Experiment Tracking

- **Organization**: Keeps your training runs organized and easily accessible.
- **Comparability**: Allows easy comparison of different training runs to understand the impact of changes in parameters or data.
- **Reproducibility**: Enhances the reproducibility of experiments by logging all necessary details.

In [None]:
# creates the experiment if it does not exist yet, then makes it active
mlflow.set_experiment("Spam Classifier Training")
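Comparability in practice usually means picking the best run by a metric. Against a live server, `mlflow.search_runs(experiment_names=["Spam Classifier Training"], order_by=["metrics.eval_accuracy DESC"])` returns a pandas DataFrame whose columns use `params.` and `metrics.` prefixes. The stdlib sketch below applies the same selection to plain dicts; the rows and numbers are made up, and `best_run` is a hypothetical helper.

```python
def best_run(runs, metric="metrics.eval_accuracy"):
    """Pick the run with the highest value for `metric`, skipping runs that never logged it."""
    scored = [run for run in runs if run.get(metric) is not None]
    if not scored:
        raise ValueError(f"no run has logged {metric!r}")
    return max(scored, key=lambda run: run[metric])


# Made-up rows shaped like mlflow.search_runs() output ("params." / "metrics." prefixes).
runs = [
    {"run_id": "a1", "params.num_train_epochs": "3", "metrics.eval_accuracy": 0.981},
    {"run_id": "b2", "params.num_train_epochs": "5", "metrics.eval_accuracy": 0.987},
    {"run_id": "c3", "params.num_train_epochs": "1", "metrics.eval_accuracy": None},
]
print(best_run(runs)["run_id"])  # b2
```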


### Starting the Training Process with MLflow

#### Initiating the MLflow Run

- Starting an MLflow Run
- Training the Model

#### Monitoring the Training Progress

- **Loss**
- **Learning rate**
- **Epoch Progress**

In [None]:
with mlflow.start_run() as run:
    trainer.train()
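The learning-rate and epoch values that appear in the logged metrics follow a predictable shape. The sketch assumes the Trainer defaults as we understand them when `TrainingArguments` does not override them (base rate 5e-5, linear decay, no warmup) and uses the step counts implied by about 4,459 training rows at batch size 8 over 3 epochs (558 steps per epoch, 1,674 in total). Both functions are illustrative, not Transformers APIs.

```python
def linear_lr(step, total_steps, base_lr=5e-5, warmup_steps=0):
    """Learning rate under a linear warmup-then-decay schedule (warmup is 0 by default)."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


def epoch_progress(step, steps_per_epoch):
    """Fractional epoch, the form in which progress shows up in the training logs."""
    return step / steps_per_epoch


# Halfway through training: the rate has decayed to half, and we are 1.5 epochs in.
print(linear_lr(837, 1674), epoch_progress(837, 558))
```

A learning-rate curve in MLflow that falls in a straight line to zero is therefore expected, not a sign of a problem.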

### Creating a Pipeline with the Fine-Tuned Model 

#### Setting Up the Inference Pipeline
- Pipeline Creation
- Model Integration
- Configuring the Pipeline

#### Device Configuration for Different Platforms 
#### Importance of a Customized Pipeline

In [None]:
tunned_pipeline = pipeline(
    task="text-classification",
    model=trainer.model,
    batch_size=8,
    tokenizer=tokenizer,
    # "mps" targets Apple Silicon GPUs; use "cuda" on NVIDIA GPUs or "cpu" elsewhere
    device="mps",
)
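Hard-coding `device="mps"` only works on Apple Silicon. One way to keep the notebook portable is to choose the device from availability flags. The helper below is hypothetical, and its preference order (NVIDIA GPU, then Apple Silicon, then CPU) is our own choice; the commented lines show the real PyTorch checks you would feed into it.

```python
def pick_pipeline_device(cuda_available, mps_available):
    """Hypothetical helper: prefer an NVIDIA GPU, then Apple Silicon, then the CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


# With PyTorch installed, feed in the real availability checks:
#   import torch
#   device = pick_pipeline_device(torch.cuda.is_available(), torch.backends.mps.is_available())
print(pick_pipeline_device(cuda_available=False, mps_available=True))  # mps
```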

### Validating the Fine-Tuned Model 

#### Importance of Model Validation

- **Assessing Model Performance**
- **Avoiding Costly Redos**

#### Evaluating with a Test Query
- **Test Query**
- **Observing the Output**

#### Validating Before Logging to MLflow

In [None]:
quick_check = (
    "I have a question regarding the project development timeline and allocated resources; "
    "specifically, how certain are you that John and Ringo can work together on writing this next song? "
    "Do we need to get Paul involved here, or do you truly believe, as you said, 'nah, they got this'?"
)

tunned_pipeline(quick_check)
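A text-classification pipeline returns a list with one dict per input, holding a `label` (taken from `id2label`) and a `score`. When reading those outputs, it can help to act on "spam" only when the model is confident. The sketch below uses made-up outputs, and both `flag_spam` and its 0.9 threshold are hypothetical choices rather than part of the pipeline.

```python
def flag_spam(outputs, threshold=0.9):
    """Hypothetical post-processing: only trust a 'spam' label when its score clears a threshold."""
    return [out["label"] == "spam" and out["score"] >= threshold for out in outputs]


# Made-up outputs in the list-of-dicts shape a text-classification pipeline returns.
sample_outputs = [
    {"label": "ham", "score": 0.998},
    {"label": "spam", "score": 0.97},
    {"label": "spam", "score": 0.62},
]
print(flag_spam(sample_outputs))  # [False, True, False]
```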

### Model Configuration and Signature Inference 

#### Configuring the Model for MLflow 

- **Setting Model Configuration:**

#### Inferring the Model Signature 

- **Purpose of Signature Inference**
- **Using mlflow.models.infer_signature**
- **Including Model Parameters**

In [None]:
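# default inference parameters, recorded in the signature's params schema
model_config = {"batch_size": 8}

# infer input/output schemas from sample inputs and the pipeline's outputs on them
signature = mlflow.models.infer_signature(
    ["This is a test!", "And this is also a test."],
    mlflow.transformers.generate_signature_output(
        tunned_pipeline, ["This is a test response!", "So is this."]
    ),
    params=model_config,
)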
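Including `params=model_config` in the signature lets callers override `batch_size` at inference time, for example with `mlflow.pyfunc.load_model(model_info.model_uri).predict(texts, params={"batch_size": 2})` in recent MLflow 2.x versions. The toy model below reflects our understanding of how such params are resolved: defaults fill gaps, known keys override, and unknown keys are ignored (MLflow warns about them). `resolve_params` is hypothetical and is not MLflow's implementation; `top_k` is just an example of an unknown key.

```python
def resolve_params(schema_defaults, overrides=None):
    """Toy model of signature params: defaults fill gaps, known keys override, unknown keys are dropped."""
    overrides = overrides or {}
    resolved = dict(schema_defaults)
    ignored = []
    for key, value in overrides.items():
        if key in schema_defaults:
            resolved[key] = value
        else:
            ignored.append(key)      # MLflow, as we understand it, warns and ignores these
    return resolved, ignored


model_config = {"batch_size": 8}
print(resolve_params(model_config))                                 # ({'batch_size': 8}, [])
print(resolve_params(model_config, {"batch_size": 2, "top_k": 3}))  # ({'batch_size': 2}, ['top_k'])
```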

### Logging the Fine-Tuned Model to MLflow

#### Accessing the Existing Run Used for Training

- **Reopening the Training Run:** We call mlflow.start_run() with the run_id of the training run, which reopens that run instead of starting a new one. The model is therefore logged in a separate step but stored alongside the training parameters and metrics.

#### Logging the Model in MLflow
- **Using mlflow.transformers.log_model:** We log our fine-tuned model using this function. It is designed specifically for models from the Transformers library, which keeps the process streamlined.
- **Specifying Model Information:**
    - transformers_model: the fine-tuned pipeline to log
    - artifact_path: the location within the run's artifacts where the model is stored
    - signature: the input, output, and params schema inferred earlier
    - input_example: a sample input saved together with the model
    - model_config: the default inference parameters (here, batch_size)

#### Importance of Model Logging 

- Version Control
- Model Management
- Reproducibility and Sharing

In [None]:
with mlflow.start_run(run_id=run.info.run_id):
    model_info = mlflow.transformers.log_model(
        transformers_model=tunned_pipeline,
        artifact_path="fine_tuned",
        signature=signature,
        input_example=["Pass in a string", "And have it mark as spam or not."],
        model_config=model_config,
    )

### Loading and Testing the Model from MLflow

#### Loading the Model from MLflow

- **Using mlflow.transformers.load_model**
- **Retrieving Model URI**

#### Testing the Model with Validation Text

- **Preparing Validation Text**
- **Evaluating Model Output**

##### Testing the model after loading it from MLflow is essential for several reasons:
- **Validation of Logging Process**
- **Practical Performance Assessment**
- **Demonstrating End-to-End Workflow:** Showcases a complete workflow from training through logging and loading to using the model, which is vital for understanding the entire model lifecycle.

In [None]:
# by default, load_model returns a ready-to-use transformers pipeline
loaded = mlflow.transformers.load_model(model_uri=model_info.model_uri)

validation_text = (
    "Want to learn how to make MILLIONS with no effort? Click HERE now! See for yourself! "
    "Guaranteed to make you instantly rich! "
    "Don't miss out you could be a winner!"
)

loaded(validation_text)
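To reload the model in a fresh session you need its URI. In the MLflow 2.x releases this tutorial appears to target, a model logged inside a run gets a URI of the form `runs:/<run_id>/<artifact_path>`, here ending in `fine_tuned` (newer MLflow versions may return a `models:/` URI instead). The helpers below are hypothetical and only build and split that string. A call such as `mlflow.transformers.load_model(runs_uri(run.info.run_id))` would then reload the model without keeping `model_info` around.

```python
RUNS_PREFIX = "runs:/"


def runs_uri(run_id, artifact_path="fine_tuned"):
    """Hypothetical helper: build a runs:/ model URI from a run id and artifact path."""
    return f"{RUNS_PREFIX}{run_id}/{artifact_path}"


def parse_runs_uri(uri):
    """Split a runs:/ URI back into (run_id, artifact_path)."""
    if not uri.startswith(RUNS_PREFIX):
        raise ValueError(f"not a runs:/ URI: {uri!r}")
    run_id, _, artifact_path = uri[len(RUNS_PREFIX):].partition("/")
    return run_id, artifact_path


uri = runs_uri("0123abcd")       # placeholder run id
print(uri)                       # runs:/0123abcd/fine_tuned
print(parse_runs_uri(uri))       # ('0123abcd', 'fine_tuned')
```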