# MNIST PyTorch `conv2d` model

In this document we

- read the training data and plot it, checking labels
- reproducibly split the training data
- make simplest model and train it for one epoch on CPU and GPU
- test logging with [MLflow](https://mlflow.org/docs/latest/index.html)

*Table of Contents*

1. [Packages](#packages)
2. [Config](#config)
3. [Data split](#data-split)
4. [`DataPipe`s and `DataLoader2`](#dataloader2)
5. [As-Simple-As-Possible model](#asap-model)
6. [Training](#training)

***
<a id="packages"></a>
## 1. Packages

In [None]:
import os
import pathlib
import platform
import subprocess
import time
from functools import partial
from typing import Tuple

import matplotlib.pyplot as plt
%matplotlib inline
import mlflow
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchdata.dataloader2 as dataloader2
from omegaconf import OmegaConf as oc
from sklearn.model_selection import StratifiedKFold
from torchdata.datapipes.map import SequenceWrapper

import kaggland.digrec.data.load as load
import kaggland.utils.preprocessing.aug.image2d as aug_image2d

***
<a id="config"></a>
## 2. Config

In [None]:
# Record the version of the interpreter that is actually running this notebook.
# Shelling out to `python --version` reports whichever `python` comes first on
# PATH, which may differ from the kernel's interpreter or not exist at all.
python_version = platform.python_version()

# Run-wide settings: library versions (for reproducibility) and the device switch.
base_config = {
    "info": {
        "python_version": python_version,
        "numpy_version": str(np.__version__),
        "torch_version": str(torch.__version__),
    },
    "use_cuda": True,
}

# Location of the training CSV, split sizes and per-loader settings.
data_config = {
    "data": {
        "path_to_data": pathlib.Path.home() / "data" / "kaggle" / "digit-recognizer" / "train.csv",
        "num_splits": 10,
        "num_val_splits": 3,
        "training": {
            "batch_size": 64,
            "num_workers": 2,
        },
        "validation": {
            "batch_size": 64,
            "num_workers": 2,
        },
    },
}

In [None]:
# One row per ResNet block, in the order the blocks are applied:
# (name, in_channels, out_channels, kernel width, padding).
# The last block's `out_channels` is an OmegaConf interpolation that is filled
# in from `model.num_out_channels` when the config is resolved below.
_block_specs = [
    ("resnet1", 1, 64, 5, 2),
    ("resnet2", 64, 32, 3, 1),
    ("resnet3", 32, 16, 3, 1),
    ("resnet4", 16, "${model.num_out_channels}", 3, 1),
]

model_config = {
    "model": {
        "blocks": [
            {
                "name": block_name,
                "params": {
                    "in_channels": n_in,
                    "out_channels": n_out,
                    "kernel_size": (width, width),
                    "padding": pad,
                    "conv_weight_init_fn": "kaiming_normal_",
                },
            }
            for block_name, n_in, n_out, width, pad in _block_specs
        ],
        "num_out_channels": 10,
    }
}

# Merge the three partial configs into one and resolve interpolations in place.
config = oc.merge(*(oc.create(part) for part in (base_config, data_config, model_config)))
oc.resolve(config)

config.data, config.model.blocks

***
<a id="data-split"></a>
## 3. Data split

### 3.1 Make split and prepare data

In [None]:
def make_data(
    path_to_data: pathlib.Path,
    num_splits: int = 10,
    num_val_splits: int = 3,
    random_state: int = 0,
):
    """Read the MNIST dataset and split it into training and validation parts.

    Args:
        path_to_data: location of the data file, passed on to `load.load`.
        num_splits: number of stratified folds the data is cut into.
        num_val_splits: how many of those folds are kept as validation splits.
        random_state: seed forwarded to `make_train_val_splits`.

    Returns:
        A pair `(train_data, val_data)`. `train_data` is a dict with the keys
        `"images"`, `"labels"` and `"indices"`. `val_data` maps the position of
        each validation split (0, 1, ...) to a dict with the same three keys.

    Raises:
        TypeError: if `path_to_data` is not a `pathlib.Path`.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    # Any concrete `pathlib.Path` is accepted, not only `PosixPath`.
    if not isinstance(path_to_data, pathlib.Path):
        raise TypeError(f"{path_to_data=} needs to be a pathlib.Path")

    images, labels = load.load(path_to_data)

    train_indices, val_splits = make_train_val_splits(
        labels=labels,
        num_splits=num_splits,
        num_val_splits=num_val_splits,
        random_state=random_state,
    )

    train_data = {
        "images": images[train_indices],
        "labels": labels[train_indices],
        "indices": train_indices,
    }
    val_data = {
        idx: {"images": images[val_indices], "labels": labels[val_indices], "indices": val_indices}
        for idx, val_indices in enumerate(val_splits)
    }

    return train_data, val_data

def make_train_val_splits(
    labels: np.ndarray, num_splits: int, num_val_splits: int, random_state: int=0
) -> Tuple[np.ndarray, list]:
    """Return training/validation index splits using `StratifiedKFold`.

    The samples are cut into `num_splits` stratified folds. The first
    `num_splits - num_val_splits` folds are concatenated into one array of
    training indices; each remaining fold becomes its own validation split.
    The indices inside every returned array are shuffled.

    Args:
        labels: class label of every sample; used for stratification.
        num_splits: total number of folds.
        num_val_splits: number of folds kept for validation; must satisfy
            `0 <= num_val_splits < num_splits`.
        random_state: seed for both the fold assignment and the shuffling.

    Returns:
        `(train_indices, val_splits)`: an index array and a list of index arrays.

    Raises:
        ValueError: if `num_val_splits` would leave no training fold or is negative.
    """
    # Fail loudly: a too-large or negative value would otherwise slice the fold
    # list in unintended ways and silently give a wrong split.
    if not 0 <= num_val_splits < num_splits:
        raise ValueError(
            f"num_val_splits must be in [0, num_splits), got {num_val_splits=} with {num_splits=}"
        )

    cross_val_splits = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=random_state)

    # Only the held-out part of every (rest, fold) pair is needed.
    indices_splits = [idx_ith_fold for _, idx_ith_fold in cross_val_splits.split(labels, y=labels)]

    # get indices of train and validation splits
    num_train_splits = num_splits - num_val_splits
    train_indices = np.concatenate(indices_splits[:num_train_splits])
    val_splits = indices_splits[num_train_splits:]

    # Shuffle with a private generator. It yields the same permutations as
    # seeding the global NumPy RNG did, without reseeding it for other code.
    rng = np.random.RandomState(random_state)
    train_indices = train_indices[rng.permutation(len(train_indices))]
    val_splits = [val_indices[rng.permutation(len(val_indices))] for val_indices in val_splits]

    return train_indices, val_splits

# Build the split used by the rest of the notebook and list which digit
# classes are present in the training part and in every validation split.
train_data, val_data = make_data(config.data.path_to_data)
train_images, train_labels = train_data["images"], train_data["labels"]

print("unique train_labels:", np.unique(train_labels))

for val_img_lab in val_data.values():
    val_labels = val_img_lab["labels"]
    print("unique val_labels:", np.unique(val_labels))

Do the training and validation splits together contain every sample index exactly once?

In [None]:
train_indices = train_data["indices"]
val_splits = tuple(value["indices"] for value in val_data.values())

indices_splits = np.concatenate((train_indices, *val_splits))
# True only if the splits hold each of 0 .. N-1 exactly once. `np.array_equal`
# also returns False when an index is duplicated, because the de-duplicated
# array is then shorter than the expected range. A plain `==` cannot compare
# arrays of different length element by element.
np.array_equal(np.unique(indices_splits), np.arange(len(indices_splits)))

### 3.2 Utility function: compute `train_images` mean and standard deviation

In [None]:
def compute_mean_std(train_images: np.ndarray, np_type=np.float32) -> Tuple[np.ndarray, np.ndarray]:
    """Return the mean and standard deviation taken over all values of `train_images`.

    Both statistics are computed over the whole array (no axis is given) and
    then cast to `np_type`.
    """
    mean_value, std_value = (
        statistic(train_images).astype(np_type) for statistic in (np.mean, np.std)
    )
    return mean_value, std_value

### 3.3 Tests

#### Sanity check plot

In [None]:
# num = 12
# val_images = val_data[0]["images"]
# val_labels = val_data[0]["labels"]
# fig, ax = plt.subplots(1, num, figsize=(num*2, 2))
# for i in range(num):
#     img = val_images[i]
#     lab = val_labels[i]
#     ax[i].imshow(img, cmap="gray")
#     ax[i].set(title=f"{lab}")
# fig.tight_layout()

***
<a id="dataloader2"></a>
## 4. `DataPipe`s and `DataLoader2` 

### 4.1 Preprocessing transforms

In [None]:
def prepro_normalize_images(image: np.ndarray, mean_train_images: float, std_train_images: float) -> np.ndarray:
    """Cast `image` to float32, subtract `mean_train_images` and divide by `std_train_images`."""
    image_as_float = image.astype(np.float32)
    centered_image = image_as_float - mean_train_images
    return centered_image / std_train_images

def prepro_transforms_fn(image_label: Tuple[np.ndarray, np.ndarray], mean_train_images: float, std_train_images: float) -> Tuple[np.ndarray, np.ndarray]:
    """Preprocess one `(image, label)` sample.

    The image is normalised with `prepro_normalize_images`; the label is cast
    to int64. The pair is returned in the same order.
    """
    raw_image = image_label[0]
    raw_label = image_label[1]
    return (
        prepro_normalize_images(raw_image, mean_train_images, std_train_images),
        raw_label.astype(np.int64),
    )

### 4.2 Collate function

In [None]:
def collate_MNIST_mini_batch(many_images_labels):
    """Stack a list of `(image, label)` samples into one mini-batch.

    Returns a float32 tensor of shape `(batch, 1, *image.shape)`, where the
    image shape is taken from the first sample, and an int64 tensor of shape
    `(batch,)` holding the labels.
    """
    batch_size = len(many_images_labels)
    image_shape = tuple(many_images_labels[0][0].shape)

    samples = torch.empty((batch_size, 1) + image_shape, dtype=torch.float32)
    targets = torch.tensor(
        [int(sample_label) for _, sample_label in many_images_labels],
        dtype=torch.int64,
    )

    # Every row of `samples` is overwritten here; the singleton axis 1 is filled
    # with the image itself.
    for row, (sample_image, _) in enumerate(many_images_labels):
        samples[row, 0] = torch.from_numpy(sample_image)

    # NOTE(review): gradient tracking is switched on for the input batch itself.
    # Presumably only needed if input gradients are inspected - confirm with callers.
    return samples.requires_grad_(True), targets

### 4.3 Training/validation pipelines

In [None]:
def make_MNIST_training_datapipe(train_data, data_config, random_state: int=100):
    """Make a datapipe for training data.

    Args:
        train_data: dict with the keys `"images"` and `"labels"`.
        data_config: config node providing `training.batch_size`.
        random_state: seed given to `torch.manual_seed` and to the augmentation.

    Returns:
        A datapipe that yields collated mini-batches of shuffled, normalised
        training samples.
    """
    torch.manual_seed(random_state)

    images = train_data["images"]
    labels = train_data["labels"]

    # NOTE(review): the return value is discarded, so this presumably augments
    # `images` in place - confirm against `aug_image2d.random_batch_translation`.
    aug_image2d.random_batch_translation(
        images=images,
        labels=labels,
        max_translation_offset=3,
        prob_thresh=0.3,
        random_seed=random_state,
    )

    # Normalisation statistics are computed after the augmentation call.
    mean_train_images, std_train_images = compute_mean_std(images)
    normalize_sample = partial(
        prepro_transforms_fn,
        mean_train_images=mean_train_images,
        std_train_images=std_train_images,
    )

    # Pipeline: pair images with labels -> shuffle -> shard across workers
    # -> preprocess each sample -> batch -> collate into tensors.
    pipe = SequenceWrapper(images).zip(SequenceWrapper(labels))
    pipe = pipe.shuffle().set_seed(0)
    pipe = pipe.sharding_filter()
    pipe = pipe.map(normalize_sample)
    pipe = pipe.batch(data_config.training.batch_size)
    training_datapipe = pipe.collate(collate_MNIST_mini_batch)

    return training_datapipe


def make_MNIST_validation_datapipe(train_data, val_data, data_config, idx_val_split: int=0, random_state: int=100):
    """Make DataPipe for validation data, using the validation split at index `idx_val_split`.

    Args:
        train_data: dict with the key `"images"`; only used to compute the
            normalisation statistics, so validation samples are scaled with
            the mean and standard deviation of the training images.
        val_data: mapping from split index to a dict with `"images"` and `"labels"`.
        data_config: config node providing `validation.batch_size`.
        idx_val_split: key of the validation split to read from `val_data`.
        random_state: seed given to `torch.manual_seed`.

    Returns:
        A datapipe that yields collated mini-batches of normalised validation samples.
    """
    torch.manual_seed(random_state)

    val_images = val_data[idx_val_split]["images"]
    val_labels = val_data[idx_val_split]["labels"]

    # The training labels are not needed here; only the training images are
    # read, to obtain the normalisation statistics.
    mean_train_images, std_train_images = compute_mean_std(train_data["images"])

    validation_datapipe = (
        SequenceWrapper(val_images)
        .zip(SequenceWrapper(val_labels))
        .shuffle()
        .set_seed(0)
        .sharding_filter()
        .map(partial(prepro_transforms_fn,
                     mean_train_images=mean_train_images,
                     std_train_images=std_train_images))
        .batch(data_config.validation.batch_size)
        .collate(collate_MNIST_mini_batch)
    )

    return validation_datapipe

### 4.4 `DataLoader2`

In [None]:
def make_dataloader2(datapipe, num_workers: int=1):
    """Wrap `datapipe` in a `DataLoader2` backed by a multiprocessing reading service.

    `num_workers` is forwarded to the reading service; both of its prefetch
    counts are set to 0.
    """
    service_options = {
        "num_workers": num_workers,
        "worker_prefetch_cnt": 0,
        "main_prefetch_cnt": 0,
    }
    return dataloader2.DataLoader2(
        datapipe=datapipe,
        reading_service=dataloader2.MultiProcessingReadingService(**service_options),
    )

### 4.5 Tests

In [None]:
# Rebuild data, datapipes and loaders from scratch for the checks below.
train_data, val_data = make_data(config.data.path_to_data)

train_dp = make_MNIST_training_datapipe(
    train_data=train_data,
    data_config=config.data,
    random_state=100,
)
val0_dp = make_MNIST_validation_datapipe(
    train_data=train_data,
    val_data=val_data,
    data_config=config.data,
    idx_val_split=0,
    random_state=100,
)

train_loader = make_dataloader2(datapipe=train_dp, num_workers=config.data.training.num_workers)
val0_loader = make_dataloader2(datapipe=val0_dp, num_workers=config.data.validation.num_workers)

#### DAG training DataPipe

In [None]:
# import torchdata.datapipes.utils as dputils
# dp_graph = dputils.to_graph(train_dp)
# dp_graph.view();

#### Sample from training DataLoader2

In [None]:
# Fresh data, datapipes and loaders, so this cell does not depend on how far
# earlier cells have consumed the loaders.
train_data, val_data = make_data(config.data.path_to_data)
train_dp = make_MNIST_training_datapipe(
    train_data=train_data,
    data_config=config.data,
    random_state=100,
)
val0_dp = make_MNIST_validation_datapipe(
    train_data=train_data,
    val_data=val_data,
    data_config=config.data,
    idx_val_split=0,
    random_state=100,
)
train_loader = make_dataloader2(datapipe=train_dp, num_workers=config.data.training.num_workers)
val0_loader = make_dataloader2(datapipe=val0_dp, num_workers=config.data.validation.num_workers)

dp_train_loader_iter = iter(train_loader)

# Pull two mini-batches and report container type, dtypes and shapes.
for _ in range(2):
    batch = next(dp_train_loader_iter)
    images, labels = batch[0], batch[1]
    print(
        f"type(batch): {type(batch)}, l={len(batch)}",
        f"images: (dtype={images.dtype}, shape={images.shape}, mean={images.mean():.3f})",
        f"labels: (dtype={labels.dtype}, shape={labels.shape})",
        "---",
        sep="\n",
    )

#### Sanity check plots

In [None]:
num = 5

# Show the first sample of each of the first `num` mini-batches, titled with
# its label, pixel sum, dtype and pixel mean.
fig, ax = plt.subplots(1, num, figsize=(num*4, 4))
for i, batch in enumerate(train_loader):
    batch_images, batch_labels = batch[0], batch[1]
    # The batch tracks gradients, so it must be detached before converting to NumPy.
    image = batch_images[0][0].detach().numpy()
    label = batch_labels[0]
    if i >= num:
        break
    panel = ax[i]
    panel.imshow(image, cmap="gray")
    panel.set(title=f"{label}, {image.sum():.3f}, {image.dtype}, {image.mean():.3f}")
fig.tight_layout()

#### Sanity check: does `DataLoader2` return the same number of samples per digit?

First, get all `DataPipe` labels

In [None]:
# Collect every label the loader yields during one pass into one flat tensor.
dp_labels = torch.zeros(size=(len(train_labels),), dtype=torch.uint8)
train_loader = make_dataloader2(train_dp, num_workers=config.data.training.num_workers)

idx = 0
for batch in train_loader:
    batch_labels = batch[1]
    # Copy the whole batch of labels at once instead of element by element.
    next_idx = idx + len(batch_labels)
    dp_labels[idx:next_idx] = batch_labels
    idx = next_idx

In [None]:
# Per-digit counts of the labels seen through the DataPipe (u1/cnt1) and of the
# labels fed into it (u3/cnt3); the two printed lines should agree.
u1, cnt1 = np.unique(dp_labels, return_counts=True)
u3, cnt3 = np.unique(train_labels, return_counts=True)

print(f"DataPipe  -> u1: {u1}, counts: {cnt1}")
print(f"Input lab -> u3: {u3}, counts: {cnt3}")

#### Benchmark: one epoch over the `DataPipe`

In [None]:
# Time one full pass over the training loader. A few pixels of every batch are
# summed so that each batch is actually read.
train_loader = make_dataloader2(train_dp, num_workers=config.data.training.num_workers)

start_time = time.time()

total = 0
for batch in train_loader:
    images = batch[0][0]
    # Detach before accumulating: the batches track gradients, and summing them
    # directly would chain every batch of the epoch into one autograd graph.
    total += images[0][14, 10:12].detach().sum()

end_time = time.time()

print(f"Elapsed time epoch: {end_time - start_time:.3f} seconds")
print(f"Total: {total}")

***
<a id="asap-model"></a>
## 5. As-Simple-As-Possible model

### 5.1 Configurable ResNet block

In [None]:
class ResNetBlock(nn.Module):
    """Convolutional block with a 1x1-convolution shortcut.

    The main path is convolution -> batch norm -> ReLU. The shortcut is a
    bias-free 1x1 convolution of the block input. `forward` returns the sum of
    the two paths.

    `resnet_block_config` must provide `in_channels`, `out_channels`,
    `kernel_size`, `padding` and `conv_weight_init_fn`. The last one is the name
    of a function in `torch.nn.init`, which is called on the main convolution's
    weight with `nonlinearity="relu"`.
    """

    def __init__(self, resnet_block_config):
        super().__init__()

        cfg = resnet_block_config
        n_in, n_out = cfg.in_channels, cfg.out_channels

        # Submodules are created in this fixed order: main conv, shortcut conv,
        # batch norm, activation.
        self.conv = nn.Conv2d(
            n_in, n_out, kernel_size=cfg.kernel_size, padding=cfg.padding, bias=False
        )
        self.conv_1x1 = nn.Conv2d(n_in, n_out, kernel_size=1, padding=0, bias=False)
        self.batch_norm = nn.BatchNorm2d(n_out)
        self.relu = nn.ReLU()

        # Only the main convolution is re-initialised; the shortcut keeps the
        # initialisation it received at construction.
        init_conv_weight = getattr(torch.nn.init, cfg.conv_weight_init_fn)
        init_conv_weight(self.conv.weight, nonlinearity="relu")
        torch.nn.init.constant_(self.batch_norm.weight, 0.5)
        torch.nn.init.zeros_(self.batch_norm.bias)

    def forward(self, x):
        main_path = self.relu(self.batch_norm(self.conv(x)))
        shortcut = self.conv_1x1(x)
        return main_path + shortcut

### 5.2 Model

In [None]:
class Conv2dModelV1(nn.Module):
    """Stack of `ResNetBlock`s followed by a linear classification head.

    Args:
        model_config: config node with `blocks` (a sequence of entries holding a
            `"name"` and the `"params"` passed to `ResNetBlock`) and
            `num_out_channels` (number of output classes).
        image_size: `(height, width)` of the feature map that reaches the head.
            Defaults to the 28x28 size that was previously hard-coded.

    The head flattens the feature map and applies linear -> batch norm -> ReLU.
    Its input size is `height * width * num_out_channels`, so it assumes that
    the blocks keep the spatial size and that the last block outputs
    `num_out_channels` channels - NOTE(review): confirm for new block configs.
    """

    def __init__(self, model_config, image_size: Tuple[int, int] = (28, 28)):
        super().__init__()
        self.blocks = nn.ModuleDict([
            (block_config["name"], ResNetBlock(block_config["params"]))
            for block_config in model_config.blocks
        ])
        self.num_out_channels = model_config.num_out_channels
        height, width = image_size
        self.head = nn.ModuleDict([
            ("linear_head", nn.Linear(height * width * self.num_out_channels, self.num_out_channels)),
            # Must match the linear layer's output size, not a fixed 10.
            ("batch_norm_head", nn.BatchNorm1d(num_features=self.num_out_channels)),
            ("relu_head", nn.ReLU()),
        ])

    def forward(self, x):
        for block in self.blocks.values():
            x = block(x)

        # Flatten everything except the batch axis before the linear head.
        x = x.view(x.size(0), -1)
        for block in self.head.values():
            x = block(x)
        return x

# Instantiate the model from the merged config and report its size.
model = Conv2dModelV1(config.model)
num_trainable_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print("num trainable params:", num_trainable_params)

***
<a id="training"></a>
## 6. Training 

### 6.1 Loss function and optimizer

In [None]:
# Cross-entropy loss on the model outputs and integer class targets;
# Adam over all model parameters with learning rate 1e-3.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

### 6.2 Training loop

In [None]:
def _train_one_epoch(train_loader, model, loss_fn, optimizer, device):
    """Run one optimisation pass over `train_loader`; return the last batch loss as a float."""
    model.train(True)
    last_loss = None
    for idx_batch, batch in enumerate(train_loader):
        samples = batch[0].to(device=device)
        targets = batch[1].to(device=device)

        optimizer.zero_grad()
        loss = loss_fn(model(samples), targets)
        loss.backward()
        optimizer.step()
        # Keep only the value; it is converted to a float once, after the loop.
        last_loss = loss.detach()

        if idx_batch % 200 == 0:
            print(".", end="", flush=True)

    # An empty loader leaves nothing to report.
    return float("nan") if last_loss is None else last_loss.item()


def _count_correct_predictions(val_loader, model, device):
    """Evaluate `model` on `val_loader`; return `(num_correct, num_total)`."""
    model.train(False)
    num_correct, num_total = 0, 0
    with torch.inference_mode():
        for idx_batch, batch in enumerate(val_loader):
            samples, targets = batch[0].to(device=device), batch[1]

            # Compare on the CPU whatever device the model runs on.
            predictions = torch.argmax(model(samples), dim=1).cpu()
            num_correct += int((predictions == targets.cpu()).sum())
            num_total += len(targets)

            if idx_batch % 200 == 0:
                print(".", end="", flush=True)
    return num_correct, num_total


def training_loop(config, n_epochs, optimizer, model, loss_fn):
    """Training loop for MNIST classification task.

    For every epoch a new training datapipe and loader are built (seeded with
    the epoch number), the model is trained on it, and the accuracy on
    validation split 0 is measured. A progress dot is printed every 200
    batches, followed by a summary with the last training loss, the validation
    accuracy and the elapsed time.

    Args:
        config: merged config providing `use_cuda`, `data.path_to_data` and the
            loader settings under `data.training` / `data.validation`.
        n_epochs: number of epochs to run.
        optimizer: optimizer already bound to the parameters of `model`.
        model: the network to train; it is moved to the selected device in place.
        loss_fn: callable taking `(model_output, targets)` and returning the loss.
    """
    train_data, val_data = make_data(config.data.path_to_data)

    # Honour `config.use_cuda` for the model as well as for the batches.
    # Previously the model was always moved to the GPU, which failed on
    # CPU-only machines and mismatched CPU batches when the flag was off.
    device = torch.device("cuda" if config.use_cuda else "cpu")
    model.to(device)

    for epoch in range(1, n_epochs + 1):
        print("Training: ", end="")
        train_dp = make_MNIST_training_datapipe(
            train_data,
            config.data,
            random_state=epoch
        )
        train_loader = make_dataloader2(train_dp, num_workers=config.data.training.num_workers)
        # Timing starts after the training loader is built and covers training
        # plus validation.
        epoch_start_time = time.time()
        final_training_loss = _train_one_epoch(train_loader, model, loss_fn, optimizer, device)
        print()

        print("Validation: ", end="")
        val0_dp = make_MNIST_validation_datapipe(
            train_data,
            val_data=val_data,
            idx_val_split=0,
            data_config=config.data,
            random_state=epoch
        )
        val0_loader = make_dataloader2(val0_dp, num_workers=config.data.validation.num_workers)
        num_correct, num_total = _count_correct_predictions(val0_loader, model, device)
        print()
        # ---
        epoch_end_time = time.time()

        # Avoid a ZeroDivisionError when the validation loader yielded nothing.
        accuracy = 100 * (num_correct / num_total) if num_total else float("nan")
        print(f" --> Epoch {epoch}")
        print(f"Final training loss: {final_training_loss:.3f}")
        print(f"Validation correct predictions: {accuracy:.2f}%")
        print(f"Time: {epoch_end_time - epoch_start_time:.2f} seconds")
        print("-"*60)

In [None]:
# Train the model created above for a fixed number of epochs.
num_training_epochs = 5
training_loop(
    config=config,
    n_epochs=num_training_epochs,
    optimizer=optimizer,
    model=model,
    loss_fn=loss_fn,
)