# Assignment 2 - train.ipynb
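Model version control and experiment tracking with MLflow: trains three benchmark SMS spam classifiers (Naive Bayes, logistic regression, linear SVM) on TF-IDF features, logs their validation and test metrics, registers each model, and selects the best one by validation AUCPR.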

In [1]:
import pandas as pd
import mlflow
import mlflow.sklearn
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, average_precision_score



## Setup

In [2]:
# Name of the MLflow experiment that groups every run created in this notebook.
experiment_name = 'sms_spam_benchmarks'

# Track runs in a SQLite database file and make the experiment active.
mlflow.set_tracking_uri('sqlite:///mlflow.db')
mlflow.set_experiment(experiment_name)

# Load the three data splits, read in train -> validation -> test order.
train_df, validation_df, test_df = map(
    pd.read_csv,
    ('train.csv', 'validation.csv', 'test.csv'),
)

# Quick sanity check: (rows, columns) of each split on a single line.
print(*(split.shape for split in (train_df, validation_df, test_df)))

2026/02/17 00:07:24 INFO mlflow.store.db.utils: Creating initial MLflow database tables...


2026/02/17 00:07:24 INFO mlflow.store.db.utils: Updating database tables


INFO  [alembic.runtime.migration] Context impl SQLiteImpl.


INFO  [alembic.runtime.migration] Will assume non-transactional DDL.


INFO  [alembic.runtime.migration] Running upgrade  -> 451aebb31d03, add metric step


INFO  [alembic.runtime.migration] Running upgrade 451aebb31d03 -> 90e64c465722, migrate user column to tags


INFO  [alembic.runtime.migration] Running upgrade 90e64c465722 -> 181f10493468, allow nulls for metric values


INFO  [alembic.runtime.migration] Running upgrade 181f10493468 -> df50e92ffc5e, Add Experiment Tags Table


INFO  [alembic.runtime.migration] Running upgrade df50e92ffc5e -> 7ac759974ad8, Update run tags with larger limit


INFO  [alembic.runtime.migration] Running upgrade 7ac759974ad8 -> 89d4b8295536, create latest metrics table


INFO  [89d4b8295536_create_latest_metrics_table_py] Migration complete!


INFO  [alembic.runtime.migration] Running upgrade 89d4b8295536 -> 2b4d017a5e9b, add model registry tables to db


INFO  [2b4d017a5e9b_add_model_registry_tables_to_db_py] Adding registered_models and model_versions tables to database.


INFO  [2b4d017a5e9b_add_model_registry_tables_to_db_py] Migration complete!


INFO  [alembic.runtime.migration] Running upgrade 2b4d017a5e9b -> cfd24bdc0731, Update run status constraint with killed


INFO  [alembic.runtime.migration] Running upgrade cfd24bdc0731 -> 0a8213491aaa, drop_duplicate_killed_constraint


INFO  [alembic.runtime.migration] Running upgrade 0a8213491aaa -> 728d730b5ebd, add registered model tags table


INFO  [alembic.runtime.migration] Running upgrade 728d730b5ebd -> 27a6a02d2cf1, add model version tags table


INFO  [alembic.runtime.migration] Running upgrade 27a6a02d2cf1 -> 84291f40a231, add run_link to model_version


INFO  [alembic.runtime.migration] Running upgrade 84291f40a231 -> a8c4a736bde6, allow nulls for run_id


INFO  [alembic.runtime.migration] Running upgrade a8c4a736bde6 -> 39d1c3be5f05, add_is_nan_constraint_for_metrics_tables_if_necessary


INFO  [alembic.runtime.migration] Running upgrade 39d1c3be5f05 -> c48cb773bb87, reset_default_value_for_is_nan_in_metrics_table_for_mysql


INFO  [alembic.runtime.migration] Running upgrade c48cb773bb87 -> bd07f7e963c5, create index on run_uuid


INFO  [alembic.runtime.migration] Running upgrade bd07f7e963c5 -> 0c779009ac13, add deleted_time field to runs table


INFO  [alembic.runtime.migration] Running upgrade 0c779009ac13 -> cc1f77228345, change param value length to 500


INFO  [alembic.runtime.migration] Running upgrade cc1f77228345 -> 97727af70f4d, Add creation_time and last_update_time to experiments table


INFO  [alembic.runtime.migration] Running upgrade 97727af70f4d -> 3500859a5d39, Add Model Aliases table


INFO  [alembic.runtime.migration] Running upgrade 3500859a5d39 -> 7f2a7d5fae7d, add datasets inputs input_tags tables


INFO  [alembic.runtime.migration] Running upgrade 7f2a7d5fae7d -> 2d6e25af4d3e, increase max param val length from 500 to 8000


INFO  [alembic.runtime.migration] Running upgrade 2d6e25af4d3e -> acf3f17fdcc7, add storage location field to model versions


INFO  [alembic.runtime.migration] Running upgrade acf3f17fdcc7 -> 867495a8f9d4, add trace tables


INFO  [alembic.runtime.migration] Running upgrade 867495a8f9d4 -> 5b0e9adcef9c, add cascade deletion to trace tables foreign keys


INFO  [alembic.runtime.migration] Running upgrade 5b0e9adcef9c -> 4465047574b1, increase max dataset schema size


INFO  [alembic.runtime.migration] Running upgrade 4465047574b1 -> f5a4f2784254, increase run tag value limit to 8000


INFO  [alembic.runtime.migration] Running upgrade f5a4f2784254 -> 0584bdc529eb, add cascading deletion to datasets from experiments


INFO  [alembic.runtime.migration] Running upgrade 0584bdc529eb -> 400f98739977, add logged model tables


INFO  [alembic.runtime.migration] Running upgrade 400f98739977 -> 6953534de441, add step to inputs table


INFO  [alembic.runtime.migration] Running upgrade 6953534de441 -> bda7b8c39065, increase_model_version_tag_value_limit


INFO  [alembic.runtime.migration] Context impl SQLiteImpl.


INFO  [alembic.runtime.migration] Will assume non-transactional DDL.


2026/02/17 00:07:25 INFO mlflow.tracking.fluent: Experiment with name 'sms_spam_benchmarks' does not exist. Creating a new experiment.


(3900, 2) (836, 2) (836, 2)


## Helper Functions

In [3]:
def build_vectorizer():
    """Create a new, unfitted TF-IDF vectorizer.

    Returns:
        A ``TfidfVectorizer`` configured with ``stop_words='english'`` and
        ``max_features=5000``.
    """
    tfidf_options = {'stop_words': 'english', 'max_features': 5000}
    return TfidfVectorizer(**tfidf_options)


def build_model(model_name: str):
    """Instantiate the benchmark classifier registered under ``model_name``.

    Args:
        model_name: Short key of the estimator to build: ``'nb'``
            (MultinomialNB), ``'lr'`` (LogisticRegression) or ``'svm'``
            (LinearSVC).

    Returns:
        A new, unfitted estimator for the requested key.

    Raises:
        KeyError: If ``model_name`` is not one of the known keys. The message
            lists the valid keys.
    """
    # Store (class, kwargs) pairs instead of instances so that only the
    # requested estimator is constructed on each call.
    registry = {
        'nb': (MultinomialNB, {}),
        'lr': (LogisticRegression, {'max_iter': 1000, 'random_state': 42}),
        'svm': (LinearSVC, {'max_iter': 2000, 'random_state': 42}),
    }

    if model_name not in registry:
        # Still a KeyError (as a plain dict lookup would raise), but with the
        # valid options spelled out.
        raise KeyError(
            f'Unknown model_name {model_name!r}; expected one of {sorted(registry)}'
        )

    estimator_cls, estimator_kwargs = registry[model_name]
    return estimator_cls(**estimator_kwargs)


def get_scores(model, X_vec, y_true):
    """Evaluate a fitted classifier on one data split.

    Args:
        model: Fitted estimator exposing ``predict`` and, optionally,
            ``predict_proba`` or ``decision_function``.
        X_vec: Feature matrix passed to the estimator's prediction methods.
        y_true: Ground-truth labels for ``X_vec``.

    Returns:
        dict with the keys ``'accuracy'``, ``'precision'``, ``'recall'``,
        ``'f1'`` and ``'aucpr'``.
    """
    predicted = model.predict(X_vec)

    # AUCPR needs a ranking score: prefer the second predict_proba column,
    # then the decision_function output, and only as a last resort the hard
    # predictions themselves.
    if hasattr(model, 'predict_proba'):
        ranking = model.predict_proba(X_vec)[:, 1]
    elif hasattr(model, 'decision_function'):
        ranking = model.decision_function(X_vec)
    else:
        ranking = predicted

    scores = {'accuracy': accuracy_score(y_true, predicted)}

    # These three share the same call shape; zero_division=0 makes an
    # undefined ratio score 0.
    for label, metric_fn in (
        ('precision', precision_score),
        ('recall', recall_score),
        ('f1', f1_score),
    ):
        scores[label] = metric_fn(y_true, predicted, zero_division=0)

    scores['aucpr'] = average_precision_score(y_true, ranking)
    return scores

## Build, Track, and Register 3 Benchmark Models

In [4]:
benchmark_models = ['nb', 'lr', 'svm']
results = []

# The TF-IDF features and the targets do not depend on which classifier is
# being trained, so build them once and reuse them for every benchmark model.
# The vectorizer is fitted on the training split only; validation and test are
# transformed with that fitted vocabulary.
vectorizer = build_vectorizer()

X_train = vectorizer.fit_transform(train_df['message'])
X_val = vectorizer.transform(validation_df['message'])
X_test = vectorizer.transform(test_df['message'])

y_train = train_df['target']
y_val = validation_df['target']
y_test = test_df['target']

for model_name in benchmark_models:
    # One MLflow run per benchmark model. Binding the run with `as run` gives
    # direct access to its id without querying the active run again.
    with mlflow.start_run(run_name=model_name) as run:
        model = build_model(model_name)
        model.fit(X_train, y_train)

        val_metrics = get_scores(model, X_val, y_val)
        test_metrics = get_scores(model, X_test, y_test)

        mlflow.log_param('model_name', model_name)
        mlflow.log_param('vectorizer', 'tfidf_stopwords_english_max5000')
        mlflow.log_param('random_state', 42)

        # Metric names carry a validation_/test_ prefix so both splits can be
        # stored side by side in the same run.
        mlflow.log_metrics({
            'validation_accuracy': val_metrics['accuracy'],
            'validation_precision': val_metrics['precision'],
            'validation_recall': val_metrics['recall'],
            'validation_f1': val_metrics['f1'],
            'validation_aucpr': val_metrics['aucpr'],
            'test_accuracy': test_metrics['accuracy'],
            'test_precision': test_metrics['precision'],
            'test_recall': test_metrics['recall'],
            'test_f1': test_metrics['f1'],
            'test_aucpr': test_metrics['aucpr'],
        })

        # Log the fitted classifier and register it as sms_spam_<model_name>.
        # NOTE(review): only the classifier is logged here; the fitted
        # vectorizer is not saved by this cell, so a consumer of the
        # registered model must rebuild the same TF-IDF features - confirm
        # this is intended.
        registered_model_name = f'sms_spam_{model_name}'
        mlflow.sklearn.log_model(
            sk_model=model,
            artifact_path='model',
            registered_model_name=registered_model_name,
        )

        results.append({
            'model': model_name,
            'validation_aucpr': val_metrics['aucpr'],
            'test_aucpr': test_metrics['aucpr'],
            'run_id': run.info.run_id,
            'registered_model': registered_model_name,
        })

# Rank the models by validation AUCPR (best first); test AUCPR is shown for
# reference only and is not used for the selection.
results_df = pd.DataFrame(results).sort_values('validation_aucpr', ascending=False).reset_index(drop=True)
print('Benchmark model selection metric (AUCPR):')
display(results_df[['model', 'validation_aucpr', 'test_aucpr']])

best_model = results_df.iloc[0]['model']
print(f'Best model selected using validation AUCPR: {best_model}')





2026/02/17 00:07:26 INFO mlflow.store.db.utils: Creating initial MLflow database tables...


2026/02/17 00:07:26 INFO mlflow.store.db.utils: Updating database tables


INFO  [alembic.runtime.migration] Context impl SQLiteImpl.


INFO  [alembic.runtime.migration] Will assume non-transactional DDL.


Successfully registered model 'sms_spam_nb'.
Created version '1' of model 'sms_spam_nb'.
  norm2_w = weights @ weights if weights.ndim == 1 else squared_norm(weights)
  norm2_w = weights @ weights if weights.ndim == 1 else squared_norm(weights)
  norm2_w = weights @ weights if weights.ndim == 1 else squared_norm(weights)




Successfully registered model 'sms_spam_lr'.
Created version '1' of model 'sms_spam_lr'.




Benchmark model selection metric (AUCPR):


Successfully registered model 'sms_spam_svm'.
Created version '1' of model 'sms_spam_svm'.


Unnamed: 0,model,validation_aucpr,test_aucpr
0,svm,0.977774,0.990482
1,lr,0.972808,0.974854
2,nb,0.955947,0.96858


Best model selected using validation AUCPR: svm


## Check Out AUCPR from MLflow Runs

In [5]:
# Columns to show: run name, the two AUCPR metrics, and the run id.
view_cols = ['tags.mlflow.runName', 'metrics.validation_aucpr', 'metrics.test_aucpr', 'run_id']

# Resolve the experiment by name, then fetch its runs ordered by validation
# AUCPR, highest first.
exp = mlflow.get_experiment_by_name(experiment_name)
runs = mlflow.search_runs(experiment_ids=[exp.experiment_id], order_by=['metrics.validation_aucpr DESC'])

print('AUCPR per benchmark model from MLflow tracking:')
display(runs.loc[:, view_cols])

AUCPR per benchmark model from MLflow tracking:


Unnamed: 0,tags.mlflow.runName,metrics.validation_aucpr,metrics.test_aucpr,run_id
0,svm,0.977774,0.990482,fdf323c36a5e4953b0901e8a753443d9
1,lr,0.972808,0.974854,74ed0562e83c410fb3c1b76f7ae8c2b8
2,nb,0.955947,0.96858,9532e19f05c248f7b7ac5e2e0847ca51
