In [None]:
# Uncomment to install the extra packages this notebook imports.
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
# %pip install -q mlflow torchmetrics torchinfo

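Before running the cells below, start a local MLflow tracking server in a separate terminal (this notebook logs to `http://localhost:5000`):

`$ mlflow server`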

In [None]:
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.utils.data import DataLoader

from torchvision import datasets
import torchvision.transforms.v2 as T

from torchinfo import summary
from torchmetrics.classification import Accuracy

import mlflow

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Preprocessing pipeline handed to both dataset splits; the steps run in order.
transform = T.Compose(
    transforms=[
        T.ToImage(),                            # wrap the raw sample as an image tensor
        T.ToDtype(torch.float32, scale=True),   # cast to float32, rescaling the values
        T.ToPureTensor(),                       # strip the wrapper back to a plain tensor
    ]
)

In [None]:
# The two splits share the download directory and the preprocessing pipeline;
# only the `train` flag differs (True -> training split, False -> test split).
training_data, test_data = (
    datasets.FashionMNIST(
        root="data",
        train=is_train,
        download=True,
        transform=transform,
    )
    for is_train in (True, False)
)

In [None]:
class ImageClassifier(nn.Module):
    """Small convolutional network that maps an image batch to 10 output scores.

    The network is four ``Conv2d -> BatchNorm2d -> LeakyReLU`` stages (the
    second and fourth use stride 2), followed by ``Flatten`` and a
    ``LazyLinear`` layer whose input width is inferred on the first forward
    pass.
    """

    # (in_channels, out_channels, stride) of each convolutional stage, in order.
    _STAGES = ((1, 32, 1), (32, 32, 2), (32, 64, 1), (64, 64, 2))

    def __init__(self):
        super().__init__()

        layers = []
        for in_ch, out_ch, stride in self._STAGES:
            layers += [
                nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(num_features=out_ch),
                nn.LeakyReLU(),
            ]
        layers.append(nn.Flatten())
        layers.append(nn.LazyLinear(10))

        # Unpacking keeps every layer at the same index inside the Sequential.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Pass ``x`` through the sequential stack and return its output."""
        return self.model(x)

In [None]:
# Send all subsequent MLflow tracking calls to the tracking server at localhost:5000.
mlflow.set_tracking_uri("http://localhost:5000")

In [None]:
# Hyperparameters for this run; they are also collected in `params` below.
num_epoch = 5
learning_rate = 1e-3
batch_size = 64

# Reshuffle the training set every epoch. With drop_last=True and a fixed
# order, the same trailing partial batch would be discarded in every epoch and
# those samples would never contribute to training.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

loss_fn = nn.CrossEntropyLoss()
metric_fn = Accuracy(task="multiclass", num_classes=10).to(device)
model = ImageClassifier().to(device)

# ImageClassifier ends in nn.LazyLinear, whose parameters stay uninitialized
# until the first forward pass. Do one dry run on a single sample, after the
# model is on its device and before the optimizer receives the parameters.
# eval mode + no_grad keep the BatchNorm running statistics untouched and
# avoid building an autograd graph for the throwaway forward pass.
model.eval()
with torch.no_grad():
    model(training_data[0][0].unsqueeze(0).to(device))
model.train()

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Run configuration logged to MLflow; components are recorded by class name.
params = {
    "num_epoch": num_epoch,
    "learning_rate": learning_rate,
    "batch_size": batch_size,
    "loss_function": loss_fn.__class__.__name__,
    "metric_function": metric_fn.__class__.__name__,
    "optimizer": optimizer.__class__.__name__,
}

In [None]:
# Display the run configuration that the training cell logs to MLflow.
params

{'num_epoch': 5,
 'learning_rate': 0.001,
 'batch_size': 64,
 'loss_function': 'CrossEntropyLoss',
 'metric_function': 'MulticlassAccuracy',
 'optimizer': 'SGD'}

In [None]:
# Train for `num_epoch` epochs inside a single MLflow run. Training metrics are
# logged every `log_interval` batches against a global batch counter;
# evaluation metrics are logged once per epoch against the epoch index.
with mlflow.start_run() as run:
    mlflow.log_params(params)

    log_interval = 100  # log/print training metrics every this many batches
    num_train_batches = len(train_dataloader)
    global_step = -1  # incremented before it is used, so logged steps start at 0
    for epoch in range(num_epoch):
        # Train
        model.train()
        for batch_idx, (X, y) in enumerate(train_dataloader):
            X, y = X.to(device), y.to(device)

            y_pred = model(X)
            loss = loss_fn(y_pred, y)
            accuracy = metric_fn(y_pred, y)

            loss.backward()
            optimizer.step()
            # Clear gradients right after the update so the next backward starts clean.
            optimizer.zero_grad()

            global_step += 1

            if batch_idx % log_interval == 0:
                train_loss = loss.item()
                train_accuracy = accuracy.item()
                mlflow.log_metric("loss", train_loss, step=global_step)
                mlflow.log_metric("accuracy", train_accuracy, step=global_step)
                print(f"loss: {train_loss:.4f} accuracy: {train_accuracy:.4f} [{batch_idx} / {num_train_batches}]")

        # Evaluate
        model.eval()
        eval_loss, eval_accuracy, num_eval_samples = 0.0, 0.0, 0
        with torch.inference_mode():
            for X, y in test_dataloader:
                X, y = X.to(device), y.to(device)
                y_pred = model(X)
                # Weight each per-batch mean by its batch size: the final test
                # batch can be smaller, and a plain mean over batches would
                # over-weight its samples.
                n_samples = y.shape[0]
                eval_loss += loss_fn(y_pred, y).item() * n_samples
                eval_accuracy += metric_fn(y_pred, y).item() * n_samples
                num_eval_samples += n_samples

        eval_loss /= num_eval_samples
        eval_accuracy /= num_eval_samples
        mlflow.log_metric("eval_loss", eval_loss, step=epoch)
        mlflow.log_metric("eval_accuracy", eval_accuracy, step=epoch)

        print(f"Eval metrics: \nAccuracy: {eval_accuracy:.4f}, Avg loss: {eval_loss:.4f} \n")

loss: 2.3500 accuracy: 0.1562 [0 / 937]
loss: 1.3271 accuracy: 0.6250 [100 / 937]
loss: 0.7336 accuracy: 0.8438 [200 / 937]
loss: 0.9166 accuracy: 0.7031 [300 / 937]
loss: 0.6618 accuracy: 0.7656 [400 / 937]
loss: 0.6582 accuracy: 0.7812 [500 / 937]
loss: 0.6868 accuracy: 0.7500 [600 / 937]
loss: 0.5952 accuracy: 0.7812 [700 / 937]
loss: 0.6896 accuracy: 0.7344 [800 / 937]
loss: 0.5305 accuracy: 0.8125 [900 / 937]
Eval metrics: 
Accuracy: 0.8136, Avg loss: 0.5492 

loss: 0.4168 accuracy: 0.8906 [0 / 937]
loss: 0.5689 accuracy: 0.8281 [100 / 937]
loss: 0.3137 accuracy: 0.8906 [200 / 937]
loss: 0.6015 accuracy: 0.8281 [300 / 937]
loss: 0.4966 accuracy: 0.8125 [400 / 937]
loss: 0.4761 accuracy: 0.7969 [500 / 937]
loss: 0.5055 accuracy: 0.8438 [600 / 937]
loss: 0.5284 accuracy: 0.7969 [700 / 937]
loss: 0.6101 accuracy: 0.7969 [800 / 937]
loss: 0.4348 accuracy: 0.8438 [900 / 937]
Eval metrics: 
Accuracy: 0.8397, Avg loss: 0.4616 

loss: 0.3159 accuracy: 0.8750 [0 / 937]
loss: 0.4659 accurac