In [None]:
from huggingface_hub import notebook_login
from datasets import load_dataset
from transformers import (
    AutoTokenizer, 
    AutoModelForSequenceClassification, 
    TrainingArguments, 
    Trainer,
    DataCollatorWithPadding,
    TrainerCallback
)
import evaluate
import numpy as np
import mlflow
from mlflow.models import infer_signature

# Point MLflow at the tracking server on localhost and select the experiment
mlflow.set_tracking_uri(uri="http://localhost:5000")
mlflow.set_experiment(experiment_name="imdb_sentiment_analysis")

# Fetch the `imdb` dataset and the tokenizer matching the DistilBERT checkpoint
imdb = load_dataset(path="imdb")
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="distilbert-base-uncased",
)

def preprocess_function(examples):
    """Tokenize the ``"text"`` entry of ``examples`` with truncation enabled.

    Uses the module-level ``tokenizer`` and returns whatever it produces.
    No padding is requested here.
    """
    texts = examples["text"]
    encoded = tokenizer(texts, truncation=True)
    return encoded

# Tokenize the dataset in batches; padding is left to the collator
tokenized_imdb = imdb.map(function=preprocess_function, batched=True)
data_collator = DataCollatorWithPadding(tokenizer)

# Accuracy metric consumed by compute_metrics
accuracy = evaluate.load(path="accuracy")

def compute_metrics(eval_pred):
    """Compute accuracy for an evaluation pass and log it to MLflow.

    ``eval_pred`` is unpacked into ``(predictions, labels)``. The predicted
    class is the arg-max over axis 1 of the predictions. The dictionary
    returned by ``accuracy.compute`` is printed, sent to MLflow (no explicit
    step is supplied) and returned unchanged.
    """
    raw_scores, references = eval_pred
    predicted_classes = np.argmax(raw_scores, axis=1)
    results = accuracy.compute(
        predictions=predicted_classes,
        references=references,
    )

    # Mirror the evaluation result into the active MLflow run
    print("Logging evaluation metrics to MLflow:", results)
    mlflow.log_metrics(results)
    return results

# Custom callback to log metrics per batch
class MLflowBatchLoggingCallback(TrainerCallback):
    """Callback that forwards the ``loss`` entry of log records to MLflow."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log ``logs["loss"]`` as ``batch_loss`` at ``state.global_step``.

        Records that are empty/``None`` or that carry no ``"loss"`` key are
        ignored.
        """
        if not logs or "loss" not in logs:
            return

        loss_value = logs["loss"]
        current_step = state.global_step
        print(f"Logging batch loss: {loss_value} at step {current_step}")
        mlflow.log_metric("batch_loss", loss_value, step=current_step)

# Label vocabulary for the two-class sentiment head; the inverse mapping is
# derived from the forward one so the two always agree.
id2label = dict(enumerate(("NEGATIVE", "POSITIVE")))
label2id = {label: index for index, label in id2label.items()}

with mlflow.start_run() as run:
    # Hyperparameters are defined once and reused for both the MLflow
    # parameter log and TrainingArguments, so the logged values can never
    # drift from the values actually used for training.
    model_name = "distilbert-base-uncased"
    learning_rate = 2e-5
    batch_size = 16
    num_epochs = 2
    weight_decay = 0.01

    # Log parameters to MLflow
    mlflow.log_params({
        "model_name": model_name,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "num_epochs": num_epochs,
        "weight_decay": weight_decay,
    })

    # Load model
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=2,
        id2label=id2label,
        label2id=label2id,
    )

    # Training arguments
    training_args = TrainingArguments(
        output_dir="my_awesome_model",
        learning_rate=learning_rate,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=num_epochs,
        weight_decay=weight_decay,
        evaluation_strategy="steps",   # Evaluate every `eval_steps` steps (not every step)
        eval_steps=100,                # Evaluation frequency
        save_strategy="steps",         # Save checkpoints every `save_steps` steps
        save_steps=100,                # Kept equal to eval_steps for load_best_model_at_end
        logging_steps=1,               # Emit a log record every single step
        load_best_model_at_end=True,
        push_to_hub=False,
    )

    # Custom trainer to log evaluation metrics to MLflow
    class MLflowTrainer(Trainer):
        """Trainer whose ``log_metrics`` also sends eval metrics to MLflow."""

        def log_metrics(self, split, metrics, **kwargs):
            """Run the stock ``log_metrics`` and mirror ``eval`` metrics to MLflow.

            Keys that already start with ``eval_`` (the form returned by
            ``Trainer.evaluate()``) are logged as-is; the prefix is added only
            to bare keys, so MLflow never receives ``eval_eval_*`` names.
            Metrics are logged at the trainer's current global step, matching
            the step axis used for ``batch_loss``.
            """
            super().log_metrics(split, metrics, **kwargs)
            if split != "eval":
                return
            mlflow.log_metrics(
                {
                    (k if k.startswith("eval_") else f"eval_{k}"): v
                    for k, v in metrics.items()
                },
                step=self.state.global_step,
            )

    # Initialize the trainer with batch logging callback
    trainer = MLflowTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_imdb["train"],
        eval_dataset=tokenized_imdb["test"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[MLflowBatchLoggingCallback()],  # Add batch logging callback
    )

    # Train the model and log metrics
    trainer.train()

    # Log model to MLflow. The signature is inferred from two tokenized
    # training examples (inputs) and a hand-written 2x2 array standing in for
    # the model output.
    mlflow.pytorch.log_model(
        model,
        artifact_path="model",
        signature=infer_signature(
            tokenized_imdb["train"][:2]["input_ids"],
            np.array([[0.1, 0.9], [0.9, 0.1]])
        )
    )


: 