# MNIST PyTorch `conv2d` model

In this document we

- read the training data and plot it, checking labels
- reproducibly split the training data
- build the simplest possible model and train it for one epoch on CPU and GPU
- test logging with [MLflow](https://mlflow.org/docs/latest/index.html)

*Table of Contents*

1. [Packages](#packages)
2. [Config](#config)
3. [Data split](#data-split)
4. [`DataPipe`s and `DataLoader2`](#dataloader2)
5. [As-Simple-As-Possible model](#asap-model)
6. [Training](#training)

***
<a id="packages"></a>
## 1. Packages

In [None]:
import mlflow
import mlflow.models
import mlflow.pytorch

import os
import pathlib
import subprocess
import time

from typing import Tuple
from functools import partial

import numpy as np
import torch
from torchdata.datapipes.map import SequenceWrapper
import torch.nn as nn
import torch.optim as optim

from omegaconf import OmegaConf as oc
import matplotlib.pyplot as plt
%matplotlib inline
import mlflow

import kaggland.digrec.data.data as digrec_data
import kaggland.digrec.data.transforms as digrec_transforms

import kaggland.utils.preprocessing.aug.image2d as aug_image2d
import kaggland.utils.data.dataloader2 as utils_dataloader2

***
<a id="config"></a>
## 2. Config

In [None]:
import platform

# Version of the interpreter that is actually running this kernel.
# Spawning `python --version` reports whichever `python` comes first on PATH,
# which may be a different interpreter (or may not exist at all).
python_version = platform.python_version()

# Environment information and global switches.
base_config = {
    "info" : {
        "python_version" : python_version,
        "numpy_version" : str(np.__version__),
        "torch_version" : str(torch.__version__),
    },
    "use_cuda": True,
}

# Location of the training CSV, split sizes and per-stage loader settings.
data_config = {
    "data": {
        "path_to_data" : pathlib.Path.home() / "data" / "kaggle" / "digit-recognizer" / "train.csv",
        "images_datatype": "f4",
        "num_splits": 10,
        "num_val_splits": 3,
        "training": {
            "batch_size": 64,
            "num_workers": 2,
        },
        "validation": {
            "batch_size": 64,
            "num_workers": 2,
        },
    },
}

In [None]:
def _resnet_block_entry(name, in_channels, out_channels, kernel_size, padding):
    """Return the config entry (name + constructor params) for one ResNet block."""
    return {
        "name": name,
        "params": {
            "in_channels": in_channels,
            "out_channels": out_channels,
            "kernel_size": kernel_size,
            "padding": padding,
            "conv_weight_init_fn": "kaiming_normal_",
        },
    }


# Four blocks chained 1 -> 64 -> 32 -> 16 -> num_out_channels; the last
# block's width is an OmegaConf interpolation of `model.num_out_channels`.
model_config = {
    "model": {
        "blocks": [
            _resnet_block_entry("resnet1", 1, 64, (5, 5), 2),
            _resnet_block_entry("resnet2", 64, 32, (3, 3), 1),
            _resnet_block_entry("resnet3", 32, 16, (3, 3), 1),
            _resnet_block_entry("resnet4", 16, "${model.num_out_channels}", (3, 3), 1),
        ],
        "num_out_channels": 10,
    }
}

# Merge the three partial configs into one and resolve interpolations in place.
config = oc.merge(
    *(oc.create(partial_config) for partial_config in (base_config, data_config, model_config))
)
oc.resolve(config)

config.data, config.model.blocks

***
<a id="data-split"></a>
## 3. Data split

### 3.1 Make split and prepare data

In [None]:
# Load the CSV and split it into one training set and several validation splits.
train_data, val_data = digrec_data.make(
    config.data.path_to_data,
    config.data.images_datatype,
)
train_images, train_labels = train_data["images"], train_data["labels"]

print("unique train_labels:", np.unique(train_labels))

# Every validation split should contain the same label set as the training data.
for val_img_lab in val_data.values():
    val_labels = val_img_lab["labels"]
    print("unique val_labels:", np.unique(val_labels))

Do the training and validation splits together cover every sample index exactly once?

In [None]:
train_indices = train_data["indices"]
val_splits = tuple(value["indices"] for value in val_data.values())

# The splits partition the data iff the sorted unique indices are exactly
# 0..N-1, where N is the total number of indices across all splits.
# np.array_equal returns False on a length mismatch (i.e. duplicated indices),
# whereas `==` on arrays of different shapes raises on recent NumPy versions.
indices_splits = np.concatenate((train_indices, *val_splits))
np.array_equal(np.unique(indices_splits), np.arange(len(indices_splits)))

### 3.2 Util function: compute `train_images` mean and standard deviation

In [None]:
# NOTE: the datapipe factories defined in Section 4 also use `utils_image2d`,
# so this cell must be executed before they are called.
import kaggland.utils.image2d as utils_image2d

# Mean and standard deviation of the training images, as returned by the
# project helper (presumably one value per channel -- TODO confirm).
train_chan_mean, train_chan_std = utils_image2d.compute_channels_mean_std(train_images)
print(train_chan_mean, train_chan_std)

### 3.3 Tests

#### Sanity check plot

In [None]:
# num = 12
# val_images = val_data[0]["images"]
# val_labels = val_data[0]["labels"]
# fig, ax = plt.subplots(1, num, figsize=(num*2, 2))
# for i in range(num):
#     img = val_images[i]
#     lab = val_labels[i]
#     ax[i].imshow(img, cmap="gray")
#     ax[i].set(title=f"{lab}")
# fig.tight_layout()

***
<a id="dataloader2"></a>
## 4. `DataPipe`s and `DataLoader2` 

### 4.2 Collate function

In [None]:
def collate_MNIST_mini_batch(many_images_labels):
    """Collate a list of ``(image, label)`` pairs into a mini-batch of tensors.

    Args:
        many_images_labels: Non-empty sequence of ``(image, label)`` pairs.
            Every image is a NumPy array and all images share the shape of the
            first one; every label must be convertible with ``int()``.

    Returns:
        A tuple ``(batched_samples, batched_labels)`` where ``batched_samples``
        is a ``float32`` tensor of shape ``(len(batch), *image.shape)`` with
        ``requires_grad`` enabled and ``batched_labels`` is an ``int64``
        tensor of shape ``(len(batch),)``.

    Raises:
        IndexError: If ``many_images_labels`` is empty.
    """
    batch_size = len(many_images_labels)
    image_shape = many_images_labels[0][0].shape

    batched_samples = torch.zeros(size=(batch_size, *image_shape), dtype=torch.float32)
    batched_labels = torch.zeros(size=(batch_size,), dtype=torch.int64)

    for idx, (image, label) in enumerate(many_images_labels):
        # Fill the whole sample slot rather than channel 0 only, so images
        # with any number of channels are supported (identical for (1, H, W)).
        batched_samples[idx] = torch.from_numpy(image)
        batched_labels[idx] = int(label)

    return batched_samples.requires_grad_(True), batched_labels

### 4.3 Training/validation pipelines

In [None]:
def make_MNIST_training_datapipe(train_data, data_config, random_state: int=100):
    """Build the DataPipe that yields collated training mini-batches.

    Args:
        train_data: Mapping with the keys "images" and "labels".
        data_config: Config section providing `training.batch_size`.
        random_state: Seed for `torch.manual_seed` and for the random
            translation augmentation.

    Returns:
        A DataPipe whose items are produced by `collate_MNIST_mini_batch`.
    """

    torch.manual_seed(random_state)

    images = train_data["images"]
    labels = train_data["labels"]

    # The return value is discarded, so the augmentation presumably modifies
    # `images` in place -- TODO confirm against `random_batch_translation`.
    aug_image2d.random_batch_translation(
        images=images,
        labels=labels,
        max_translation_offset=3,
        prob_thresh=0.3,
        random_seed=random_state,
    )

    # Statistics are computed after the augmentation call above.
    chan_mean, chan_std = utils_image2d.compute_channels_mean_std(images)
    per_sample_transform = partial(
        digrec_transforms.train_fn,
        chan_mean=chan_mean,
        chan_std=chan_std,
    )

    # (image, label) pairs -> shuffle -> shard -> transform -> batch -> collate
    paired_dp = SequenceWrapper(images).zip(SequenceWrapper(labels))
    shuffled_dp = paired_dp.shuffle().set_seed(0)
    transformed_dp = shuffled_dp.sharding_filter().map(per_sample_transform)
    batched_dp = transformed_dp.batch(data_config.training.batch_size)

    return batched_dp.collate(collate_MNIST_mini_batch)


def make_MNIST_validation_datapipe(train_data, val_data, data_config, idx_val_split: int=0, random_state: int=100):
    """Make DataPipe for validation data, using the validation split at index `idx_val_split`.

    Args:
        train_data: Mapping with the key "images"; only used to compute the
            mean/std statistics passed to the per-sample transform.
        val_data: Mapping from split index to a mapping with the keys
            "images" and "labels".
        data_config: Config section providing `validation.batch_size`.
        idx_val_split: Key of the validation split to serve.
        random_state: Seed passed to `torch.manual_seed`.

    Returns:
        A DataPipe whose items are produced by `collate_MNIST_mini_batch`.
    """

    torch.manual_seed(random_state)

    train_images = train_data["images"]
    val_images = val_data[idx_val_split]["images"]
    val_labels = val_data[idx_val_split]["labels"]

    # Validation samples are transformed with the *training* statistics.
    train_chan_mean, train_chan_std = utils_image2d.compute_channels_mean_std(train_images)

    # NOTE(review): this reuses `digrec_transforms.train_fn` and shuffles the
    # validation samples, mirroring the training pipeline -- confirm that no
    # dedicated validation transform is intended.
    validation_datapipe = (
        SequenceWrapper(val_images)
        .zip(SequenceWrapper(val_labels))
        .shuffle()
        .set_seed(0)
        .sharding_filter()
        .map(partial(digrec_transforms.train_fn,
                     chan_mean=train_chan_mean,
                     chan_std=train_chan_std))
        .batch(data_config.validation.batch_size)
        .collate(collate_MNIST_mini_batch)
    )

    return validation_datapipe

### 4.6 DataLoader2

### 4.7 Tests

In [None]:
# Pass the configured image dtype explicitly, as every other call to
# `digrec_data.make` in this notebook does.
train_data, val_data = digrec_data.make(config.data.path_to_data, config.data.images_datatype)
train_dp = make_MNIST_training_datapipe(train_data, config.data, random_state=100)
val0_dp = make_MNIST_validation_datapipe(train_data, val_data=val_data, idx_val_split=0, data_config=config.data, random_state=100)
train_loader = utils_dataloader2.make(train_dp, num_workers=config.data.training.num_workers)
val0_loader = utils_dataloader2.make(val0_dp, num_workers=config.data.validation.num_workers)

#### DAG training DataPipe

In [None]:
# import torchdata.datapipes.utils as dputils
# dp_graph = dputils.to_graph(train_dp)
# dp_graph.view();

#### Sample from training DataLoader2

In [None]:
# Rebuild data, pipelines and loaders from scratch for this check.
train_data, val_data = digrec_data.make(
    config.data.path_to_data,
    config.data.images_datatype,
)
train_dp = make_MNIST_training_datapipe(train_data, config.data, random_state=100)
val0_dp = make_MNIST_validation_datapipe(
    train_data,
    val_data=val_data,
    idx_val_split=0,
    data_config=config.data,
    random_state=100,
)
train_loader = utils_dataloader2.make(train_dp, num_workers=config.data.training.num_workers)
val0_loader = utils_dataloader2.make(val0_dp, num_workers=config.data.validation.num_workers)

dp_train_loader_iter = iter(train_loader)

# Pull two mini-batches and report the type, dtype and shape of their parts.
for _ in range(2):
    batch = next(dp_train_loader_iter)
    images, labels = batch[0], batch[1]
    print(
        f"type(batch): {type(batch)}, l={len(batch)}",
        f"images: (dtype={images.dtype}, shape={images.shape}, mean={images.mean():.3f})",
        f"labels: (dtype={labels.dtype}, shape={labels.shape})",
        "---",
        sep="\n",
    )

#### Sanity check plots

In [None]:
num = 5

# One panel per mini-batch, showing channel 0 of the first sample together
# with its label and a few summary statistics.
fig, ax = plt.subplots(1, num, figsize=(num*4, 4))
for i, batch in enumerate(train_loader):
    image = batch[0][0][0].detach().numpy()
    label = batch[1][0]
    if i >= num:
        break
    panel = ax[i]
    panel.imshow(image, cmap="gray")
    panel.set(title=f"{label}, {image.sum():.3f}, {image.dtype}, {image.mean():.3f}")
fig.tight_layout()

#### Sanity check: does `DataLoader2` return the same number of samples per digit?

First, get all `DataPipe` labels

In [None]:
train_loader = utils_dataloader2.make(train_dp, num_workers=config.data.training.num_workers)

# Concatenate the label tensor of every mini-batch of one full pass.
# Unlike a preallocated buffer of len(train_labels) entries, this neither
# leaves spurious zeros behind when fewer labels arrive nor overflows when
# more arrive, so the counts compared in the next cell are the real ones.
dp_labels = torch.cat([batch[1] for batch in train_loader]).to(torch.uint8)

In [None]:
# Label histogram of what the DataPipe produced ...
u1, cnt1 = np.unique(dp_labels, return_counts=True)
# u2, cnt2 = np.unique(labels, return_counts=True)
# ... versus the label histogram of the raw training labels.
u3, cnt3 = np.unique(train_labels, return_counts=True)

# The two lines should show identical unique values and counts.
print(f"DataPipe  -> u1: {u1}, counts: {cnt1}")
print(f"Input lab -> u3: {u3}, counts: {cnt3}")

#### Benchmark: one epoch over the training `DataPipe`

In [None]:
train_loader = utils_dataloader2.make(train_dp, num_workers=config.data.training.num_workers)

# perf_counter() is the monotonic, high-resolution clock intended for
# measuring elapsed time; time.time() can jump when the system clock changes.
start_time = time.perf_counter()

total = 0.0
for batch in train_loader:
    images = batch[0][0]
    # Touch a couple of pixels of the first sample so the batch is really
    # materialised. The collate function returns samples with requires_grad
    # enabled, so accumulate a plain float: summing the tensors themselves
    # would keep an autograd graph alive for the whole epoch.
    total += images[0][14, 10:12].sum().item()

end_time = time.perf_counter()

print(f"Elapsed time epoch: {end_time - start_time:.3f} seconds")
print(f"Total: {total}")

***
<a id="asap-model"></a>
## 5. As-Simple-As-Possible model

### 5.1 Configurable ResNet block

In [None]:
class ResNetBlock(nn.Module):
    """Residual block: conv -> batch norm -> ReLU, plus a 1x1-conv shortcut.

    The shortcut is a bias-free 1x1 convolution of the block input, so the
    input can be added to the main branch even when `in_channels` and
    `out_channels` differ.

    Args:
        resnet_block_config: Object exposing the attributes `in_channels`,
            `out_channels`, `kernel_size`, `padding` and
            `conv_weight_init_fn` (the name of a function in `torch.nn.init`).
    """

    def __init__(self, resnet_block_config):
        super().__init__()

        self.conv = nn.Conv2d(
            in_channels=resnet_block_config.in_channels,
            out_channels=resnet_block_config.out_channels,
            kernel_size=resnet_block_config.kernel_size,
            padding=resnet_block_config.padding,
            bias=False
        )
        self.conv_1x1 = nn.Conv2d(
            in_channels=resnet_block_config.in_channels,
            out_channels=resnet_block_config.out_channels,
            kernel_size=1,
            padding=0,
            bias=False
        )
        self.batch_norm = nn.BatchNorm2d(
            num_features=resnet_block_config.out_channels
        )
        self.relu = nn.ReLU()

        conv_weight_init_fn = getattr(torch.nn.init, resnet_block_config.conv_weight_init_fn)
        try:
            conv_weight_init_fn(self.conv.weight, nonlinearity="relu")
        except TypeError:
            # Only the Kaiming initialisers accept `nonlinearity`; other
            # initialisers (e.g. `xavier_uniform_`) reject the keyword before
            # touching the tensor, so retry with their defaults.
            conv_weight_init_fn(self.conv.weight)
        torch.nn.init.constant_(self.batch_norm.weight, 0.5)
        torch.nn.init.zeros_(self.batch_norm.bias)

    def forward(self, x):
        """Return `relu(batch_norm(conv(x))) + conv_1x1(x)`."""
        out = self.conv(x)
        out_1x1 = self.conv_1x1(x)
        out = self.batch_norm(out)
        out = self.relu(out)
        return out + out_1x1

### 5.2 Model

In [None]:
class Conv2dModelV1(nn.Module):
    """Stack of `ResNetBlock`s followed by a linear classification head.

    The head flattens the output of the last block and expects
    `28 * 28 * num_out_channels` features, i.e. the blocks must keep the
    28x28 spatial size and end with `num_out_channels` channels.

    Args:
        model_config: Object exposing `blocks` (an iterable of entries with
            the keys "name" and "params") and `num_out_channels`.
    """

    def __init__(self, model_config):
        super().__init__()
        self.blocks = nn.ModuleDict([
            (block_config["name"], ResNetBlock(block_config["params"]))
            for block_config in model_config.blocks
        ])
        self.num_out_channels = model_config.num_out_channels
        self.head = nn.ModuleDict([
            ("linear_head", nn.Linear(28 * 28 * self.num_out_channels, self.num_out_channels)),
            # Must match the width of the linear layer above; a hard-coded 10
            # breaks the forward pass for any other `num_out_channels`.
            ("batch_norm_head", nn.BatchNorm1d(num_features=self.num_out_channels)),
            ("relu_head", nn.ReLU()),
        ])

    def forward(self, x):
        """Apply the blocks in insertion order, flatten, then apply the head."""
        for block in self.blocks.values():
            x = block(x)

        x = x.view(x.size(0), -1)
        for block in self.head.values():
            x = block(x)
        return x

# Build the model from the resolved config and report its trainable size.
model = Conv2dModelV1(config.model)
num_trainable_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print("num trainable params:", num_trainable_params)

In [None]:
model

***
<a id="training"></a>
## 6. Training 

### 6.1 Training loop

In [None]:
def training_loop(config, n_epochs, model, learning_rate: float=5e-4):
    """Train `model` for the MNIST classification task and track the run in MLflow.

    For every epoch the data is re-read with `digrec_data.make`, one training
    pass is run with Adam and cross-entropy loss, and the accuracy on
    validation split 0 is computed and logged as the MLflow metric "accuracy"
    (with the epoch as step). The config is logged as the artifact
    "config.yaml" and the model is logged after the final epoch.

    Args:
        config: Config exposing `use_cuda` and the `data` section used by the
            datapipe factories; must be serialisable with `oc.to_yaml`.
        n_epochs: Number of epochs; epochs are numbered 1..n_epochs.
        model: The `nn.Module` to train. It is moved to the selected device
            and updated in place.
        learning_rate: Learning rate passed to `optim.Adam`.
    """

    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    device = torch.device("cuda" if config.use_cuda else "cpu")

    # set MLflow tracking server URI
    mlflow.set_tracking_uri(uri="http://127.0.0.1:5051")
    mlflow.autolog(disable=True)

    # start MLflow run
    with mlflow.start_run() as run:
        print(f"Run ID: {run.info.run_id}")
        mlflow.log_param(key="n_epochs", value=n_epochs)
        mlflow.log_param(key="learning_rate", value=learning_rate)

        # log config
        yaml_str = oc.to_yaml(config)
        with open("config.yaml", "w") as config_file:
            config_file.write(yaml_str)
        mlflow.log_artifact("config.yaml")

        model.to(device)

        # The model signature only describes dtypes/shapes, so it is inferred
        # once from the first validation batch instead of for every batch.
        signature = None

        for epoch in range(1, n_epochs+1):
            print("Training: ", end="")
            model.train(True)
            # Data is re-read every epoch before the training pipe is built.
            train_data, val_data = digrec_data.make(config.data.path_to_data, config.data.images_datatype)

            train_dp = make_MNIST_training_datapipe(
                train_data.copy(),
                config.data,
                random_state=epoch
            )
            train_loader = utils_dataloader2.make(train_dp, num_workers=config.data.training.num_workers)
            epoch_start_time = time.time()
            loss = float("nan")  # reported as-is if the loader yields no batch
            for idx_batch, batch in enumerate(train_loader):
                samples = batch[0].to(device)
                targets = batch[1].to(device)

                optimizer.zero_grad()
                output_model = model(samples)
                loss = loss_fn(output_model, targets)
                loss.backward()
                optimizer.step()

                if idx_batch % 200 == 0:
                    print(".", end="", flush=True)
            print()

            print("Validation: ", end="")
            model.train(False)
            val0_dp = make_MNIST_validation_datapipe(
                train_data.copy(),
                val_data=val_data.copy(),
                idx_val_split=0,
                data_config=config.data,
                random_state=epoch
            )
            val0_loader = utils_dataloader2.make(val0_dp, num_workers=config.data.validation.num_workers)
            num_correct = 0
            num_total = 0
            with torch.inference_mode():
                for idx_batch, batch in enumerate(val0_loader):
                    samples, targets = batch[0], batch[1]
                    samples = samples.to(device)

                    output_model = model(samples)
                    # Always bring predictions back to the CPU so they can be
                    # compared with the (CPU) targets on both device paths.
                    model_predictions = torch.argmax(output_model, dim=1).cpu()

                    if signature is None:
                        # detach(): the collate function returns samples with
                        # requires_grad enabled, and numpy() rejects those.
                        signature = mlflow.models.infer_signature(
                            samples.detach().cpu().numpy(),
                            model_predictions.numpy(),
                        )

                    num_correct += int((model_predictions == targets).sum())
                    num_total += len(targets)

                    if idx_batch % 200 == 0:
                        print(".", end="", flush=True)
            print()
            # ---
            accuracy = num_correct / num_total if num_total else float("nan")
            mlflow.log_metric("accuracy", accuracy, step=epoch)

            epoch_end_time = time.time()
            print(f" --> Epoch {epoch}")
            print(f"Final training loss: {loss:.3f}")
            print(f"Validation correct predictions: {100*accuracy:.2f}%")
            print(f"Time: {epoch_end_time - epoch_start_time:.2f} seconds")
            print("-"*60)

            # Epochs run from 1 to n_epochs inclusive, so the last one is
            # `n_epochs` (comparing with `n_epochs - 1` logged the model one
            # epoch early and never for a single-epoch run).
            if epoch == n_epochs:
                mlflow.pytorch.log_model(pytorch_model=model, artifact_path="model", signature=signature)
            

In [None]:
# Train the model built above for two epochs; `learning_rate` overrides the
# function's default of 5e-4. Requires the MLflow tracking server configured
# inside `training_loop` to be reachable.
training_loop(
    config=config,
    n_epochs=2,
    model=model,
    learning_rate=1e-4
)

## 7. Save model

In [None]:
import datetime

now = datetime.datetime.now()

# Save under the current user's home directory (as the data path in the config
# does) instead of a user-specific absolute path, and make sure the target
# directory exists before writing.
models_dir = pathlib.Path.home() / "tmp" / "pytorch" / "models" / "digrec"
models_dir.mkdir(parents=True, exist_ok=True)

# ISO date (YYYY-MM-DD); the previous `{now.year:05}` produced a five-digit year.
model_filename = str(models_dir / f"weights-{now.date().isoformat()}.pth")

# Pickles the whole module (architecture + weights), which the next cell reloads.
torch.save(model, model_filename)

In [None]:
# Reload the module pickled by the previous cell.
# NOTE: torch.load unpickles the file, which can execute arbitrary code --
# only load files you created yourself.
model2 = torch.load(model_filename)
# NOTE(review): moved to the GPU unconditionally, regardless of `config.use_cuda`.
model2.cuda()

In [None]:
# `make_data` is not defined anywhere in this notebook; the loader used by
# every other cell is `digrec_data.make`, called with the configured dtype.
train_data, val_data = digrec_data.make(config.data.path_to_data, config.data.images_datatype)

In [None]:
val1_dp = make_MNIST_validation_datapipe(
    train_data.copy(),
    val_data=val_data.copy(),
    idx_val_split=1,
    data_config=config.data,
    random_state=0
)
# `make_dataloader2` is not defined in this notebook; loaders are built with
# `utils_dataloader2.make` everywhere else.
val1_loader = utils_dataloader2.make(val1_dp, num_workers=config.data.validation.num_workers)

# Evaluate the *reloaded* model, with batch norm using its running statistics.
model2.train(False)
with torch.inference_mode():
    for idx_batch, batch in enumerate(val1_loader):
        # Only the first four mini-batches are inspected.
        if idx_batch > 3:
            break

        samples = batch[0].cuda()

        output_model = model2(samples)
        model_predictions = torch.argmax(output_model, dim=1)
        print(f"{model_predictions=}")