In [None]:
%load_ext lab_black
%config IPCompleter.greedy=True

In [None]:
import multiprocessing
import time
from pathlib import Path
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as T

from torchsummary import summary

torch.backends.cudnn.benchmark = False

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
from matplotlib.colorbar import ColorbarBase

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
import seaborn as sns
import pprint

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# DataLoader worker processes; 0 loads batches in the main process.
n_workers = 0

# OneCycleLR

In [None]:
# Illustrate the OneCycleLR schedule on a throwaway one-layer model.
# The scheduler is stepped `batch_steps` times per simulated epoch (as it would be
# once per batch during real training); the LR is recorded once per epoch.
tmp_model = torch.nn.Linear(2, 1)
lr = 0.1
batch_steps = 128
n_epochs = 1000
optimizer = torch.optim.Adam(tmp_model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=lr, steps_per_epoch=batch_steps, epochs=n_epochs
)
lrs = []

for epoch in range(n_epochs):
    optimizer.step()
    lrs.append(optimizer.param_groups[0]["lr"])
    # `_step` rather than `_`: assigning `_` would clobber IPython's output cache
    for _step in range(batch_steps):
        scheduler.step()

fig, ax = plt.subplots()
ax.plot(range(1, n_epochs + 1), lrs)
ax.set_title("OneCycleLR Illustration")
# one point per simulated epoch (= batch_steps scheduler steps), not per iteration
ax.set_xlabel("Epoch")
ax.set_ylabel("Learning Rate")
plt.show()

# Dataloader Example

The following class reads the data for the third assignment and wraps it in a PyTorch `Dataset`, so it can be fed directly to a `DataLoader` to train your model.

Due to the file-size limit on Moodle, the data for this assignment must be downloaded from:

https://drive.google.com/file/d/1khzPamThzWScipEfMmOPevtfWV7Tx6UL/view?usp=sharing


Make sure that the file "hw3.npz" is located properly (in this example, it should be in the same folder as this notebook).

 



In [None]:
class STLData(Dataset):
    """Torch ``Dataset`` over one split of the HW3 STL data stored in an ``.npz`` archive.

    The archive holds six positionally-saved arrays: ``arr_0``/``arr_1`` are the
    train images/labels, ``arr_2``/``arr_3`` validation, ``arr_4``/``arr_5`` test.

    Args:
        mode: split selector, matched by substring ("train", "val", "test" are
            checked in that order).
        transform: optional callable applied to each image in ``__getitem__``
            (e.g. ``T.ToTensor()``).
        path: location of the archive; defaults to ``hw3.npz`` in the working
            directory.

    Raises:
        ValueError: if ``mode`` names none of the three splits.
    """

    # split name -> (images key, labels key) inside the npz archive
    _SPLIT_KEYS = (
        ("train", ("arr_0", "arr_1")),
        ("val", ("arr_2", "arr_3")),
        ("test", ("arr_4", "arr_5")),
    )

    def __init__(self, mode="", transform=None, path="hw3.npz"):
        keys = next((k for split, k in self._SPLIT_KEYS if split in mode), None)
        if keys is None:
            raise ValueError("mode should be 'train', 'val' or 'test'")

        # `with` closes the NpzFile handle; indexing it reads the arrays into memory
        with np.load(path) as data:
            images, self.labels = data[keys[0]], data[keys[1]]

        # Keep images as uint8: T.ToTensor() only rescales to [0, 1] for uint8
        # arrays; float arrays would be converted to tensors without rescaling.
        self.images = images.astype(np.uint8, copy=False)
        self.transform = transform

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image goes through ``transform`` if one was given."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        sample = self.images[idx, :]
        label = self.labels[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample, label

Here is an example of how you can create a dataloader.
First, read the data. Note that the `STLData` class accepts the torchvision transforms required in HW3.

In [None]:
# Build the three splits with ToTensor (uint8 HWC -> float CHW in [0, 1]).
train_set, val_set, test_set = (
    STLData(split, T.ToTensor()) for split in ("train", "val", "test")
)

batch_size = 100
n_workers = 0  # load batches in the main process
trainloader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True, num_workers=n_workers
)
image_batch, labels = next(iter(trainloader))

# Without T.ToTensor() the raw batch is channels-last; move channels to dim 1.
if trainloader.dataset.transform is None:
    image_batch = image_batch.permute(0, 3, 1, 2)

# 2x4 grid of the first eight images, captioned with their integer labels
fig, ax_arr = plt.subplots(2, 4)
for i, ax in enumerate(ax_arr.flat):
    ax.imshow(image_batch[i].permute(1, 2, 0))  # CHW -> HWC for imshow
    ax.axes.get_yaxis().set_visible(False)
    ax.set_xlabel(labels[i].item())
    ax.set_xticklabels([])
fig.set_figheight(5)
fig.set_figwidth(10)
plt.subplots_adjust(wspace=0.3, hspace=0.01)
plt.show()

The cell above shows how to build a dataloader with a batch size of 100 for the training data.

# Defining class labels

In [None]:
# Class names of the STL dataset, in label order (used as plot tick labels).
class_labels = "airplane bird car cat deer dog horse monkey ship truck".split()

# Define our main functions

In [None]:
def un_normalize(img, mean, std, inplace=True):
    """Invert ``T.Normalize``: compute ``img * std + mean`` per channel.

    Args:
        img: normalized image tensor, a single image (C, H, W) or a batch
            (B, C, H, W); channels are the third-from-last dimension.
        mean, std: per-channel sequences (e.g. lists of 3 floats).
        inplace: when True (default, the original behaviour) ``img`` itself is
            overwritten; pass False to leave the input untouched.

    Returns:
        The un-normalized tensor (``img`` itself when ``inplace`` is True).
    """
    # (C, 1, 1) float32 tensors on img's device broadcast over H, W and any
    # leading batch dim, and avoid a CPU/GPU device mismatch
    mean = torch.as_tensor(mean, dtype=torch.float32, device=img.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=torch.float32, device=img.device).view(-1, 1, 1)

    if not inplace:
        img = img.clone()
    return img.mul_(std).add_(mean)

In [None]:
def order_tensor(t_: torch.Tensor):
    """Return a 4-D batch in channels-first (B, C, H, W) layout.

    ``T.ToTensor()`` already yields channels-first images, but a raw uint8 batch
    from ``STLData`` without a transform is channels-last (B, H, W, C). The layout
    is detected heuristically: the last dim is treated as channels only when it
    is smaller than BOTH spatial dims, so a non-square channels-first batch such
    as (B, 3, 96, 64) is not permuted by mistake.

    Raises:
        NotImplementedError: for tensors that are not 4-D.
    """
    if t_.ndim != 4:
        raise NotImplementedError(
            f"order_tensor expects a 4-D batch, got a {t_.ndim}-D tensor"
        )

    h, w, c = t_.shape[1:4]
    if c < h and c < w:
        # (B, H, W, C) -> (B, C, H, W)
        return t_.permute(0, 3, 1, 2).contiguous()
    return t_

In [None]:
def plot_log(log, model_config, save=False, select=True):
    """Plot training curves: losses on the left axis, accuracies on the right.

    Args:
        log: dict of metric name -> per-epoch values (as returned by
            ``train_model``). Keys containing "loss" go on the left axis, keys
            containing "acc" on the right; any other key is ignored.
        model_config: the training config; reads "model" (its ``_name``),
            "lr", "batch_size" and, when saving, "num_epochs".
        save: save the figure to ``./LR_{lr}_{num_epochs}.jpg``.
        select: mark the minimum of every loss curve and the maximum of every
            accuracy curve with a diamond, and return those points.

    Returns:
        If ``select`` is True, a dict mapping each key of ``log`` to
        ``(epoch, value)`` of its best point (1-based epoch; ``None`` for keys
        that were not plotted). Otherwise ``None``.
    """
    fig, ax1 = plt.subplots()
    fig.set_figheight(7.5)
    fig.set_figwidth(12)
    # ax1 for loss, ax2 (sharing the x axis) for accuracy
    ax2 = ax1.twinx()

    color = iter(cm.rainbow(np.linspace(0, 1, len(log))))

    # best point per key, format: (epoch id, data value)
    selected = dict.fromkeys(log)

    for key, values in log.items():
        c = next(color)
        values = np.asarray(values)
        if "loss" in key:
            ax, pick, prefix = ax1, np.argmin, "Min."
        elif "acc" in key:
            ax, pick, prefix = ax2, np.argmax, "Max."
        else:
            continue
        if values.size == 0:
            continue

        key_str = key.replace("_", " ").title()
        # x axis derived from the data, so partial logs still line up
        ax.plot(np.arange(1, len(values) + 1), values, color=c, label=key_str)

        if select:
            best = int(pick(values))
            x, y = best + 1, values[best]
            ax.plot(
                x,
                y,
                color=c,
                label=f"{prefix} {key_str}",
                markersize=16,
                marker="d",
                alpha=0.5,
            )
            selected[key] = (x, y)

    ax1.set_ylabel("Loss")
    ax1.set_xlabel("Number of Epochs")
    # accuracies from test_model are fractions in [0, 1], not percentages
    ax2.set_ylabel("Accuracy")

    # title comes from the model_config argument, not a notebook global
    model = model_config["model"]
    model_name = getattr(model, "_name", type(model).__name__)
    title = (
        f"{model_name} Learning Rate={model_config['lr']} "
        f"Batch Size={model_config['batch_size']}"
    )
    best_val = selected.get("val_acc")
    if best_val is not None:
        title += f" Max Val Acc={best_val[1]} @ Epoch {best_val[0]}"
    ax1.set_title(title)

    # 7 = 'center right', anchored just outside the axes
    fig.legend(loc=7, bbox_to_anchor=(1.1, 0.5))

    if save:
        # saved after the title is set; "tight" keeps the outside legend in frame
        fig.savefig(
            f"./LR_{model_config['lr']}_{model_config['num_epochs']}.jpg",
            bbox_inches="tight",
        )

    plt.show()

    if select:
        return selected

In [None]:
@torch.no_grad()
def test_model(net, data_generator, loss_fn, transform=None):
    """Function to easily test model on specified dataset"""
    batch_loss, batch_steps = 0.0, 0
    correct_pred, total_pred = 0, 0

    net = net.to(device)
    net.eval()
    for batch_id, (data, label) in enumerate(data_generator):
        data, label = data.to(device), label.long().to(device)

        if data_generator.dataset.transform is None:
            data = order_tensor(data)
        if transform is not None:
            data = transform(data.cuda())

        output = net(data)
        batch_loss += loss_fn(output, label).item()
        batch_steps += 1

        # indices where probability is maximum
        _, pred_label = torch.max(output, 1)
        correct_pred += (pred_label == label).sum().item()
        total_pred += label.shape[0]

    # average loss/acc across ALL batches
    # i.e. ACROSS specified dataset
    avg_loss = batch_loss / batch_steps
    avg_acc = correct_pred / total_pred

    return avg_loss, avg_acc

In [None]:
def _make_stl_loader(mode, batch_size, shuffle):
    """DataLoader over ``STLData(mode, T.ToTensor())`` using the notebook's worker count."""
    return DataLoader(
        STLData(mode, T.ToTensor()),
        batch_size=batch_size,
        num_workers=n_workers,
        shuffle=shuffle,
        pin_memory=False,
    )


def _make_optimizer(net, config):
    """Build the optimizer named by ``config["optimizer"]`` (substring "Adam" or "SGD")."""
    name = config["optimizer"]
    if "Adam" in name:
        return optim.Adam(
            net.parameters(), lr=config["lr"], weight_decay=config["weight_decay"]
        )
    if "SGD" in name:
        return optim.SGD(
            net.parameters(),
            lr=config["lr"],
            momentum=config["momentum"],
            weight_decay=config["weight_decay"],
        )
    # previously fell through and failed later with UnboundLocalError
    raise ValueError(f"unsupported optimizer {name!r}; expected 'Adam' or 'SGD'")


def train_model(config):
    """Train ``config["model"]`` on the STL train split, evaluating after every epoch.

    Config keys:
        model, optimizer ("Adam"/"SGD"), lr, weight_decay, momentum (SGD only),
        batch_size, num_epochs -- required.
        lr_scheduler -- optional dict; when present, a OneCycleLR schedule with
            ``max_lr=lr`` is stepped once per batch, and its "div_factor" entry
            (default 1e4) is passed as OneCycleLR's ``final_div_factor``.
        train_transform / val_transform / test_transform -- optional tensor
            transforms applied per batch on ``device``.
        log_training, log_interval -- print metrics every ``log_interval``
            epochs and return the metric log.
        save_model -- save a checkpoint to ``./models/<name>/`` whenever the
            validation accuracy improves.

    Returns:
        When ``log_training`` is truthy, a dict of per-epoch numpy arrays
        ("train_loss", "val_loss", "train_acc", "val_acc", "test_acc");
        otherwise ``None``.
    """
    num_epochs = config["num_epochs"]
    logger = {
        key: np.zeros(num_epochs)
        for key in ("train_loss", "val_loss", "train_acc", "val_acc", "test_acc")
    }
    log_training = config.get("log_training", False)
    log_interval = config.get("log_interval", 1)
    save_model = config.get("save_model", False)

    #### LOAD DATA ####
    b_size = config["batch_size"]
    # optional per-batch transforms (None when not specified)
    train_transform = config.get("train_transform")
    val_transform = config.get("val_transform")
    test_transform = config.get("test_transform")

    train_dataloader = _make_stl_loader("train", b_size, shuffle=True)
    val_dataloader = _make_stl_loader("val", b_size, shuffle=False)
    test_dataloader = _make_stl_loader("test", b_size, shuffle=False)

    #### INSTANTIATE MODEL ####
    net = config["model"].to(device)
    loss_function = nn.CrossEntropyLoss()
    # mixed precision is CUDA-only here; disabling it keeps the loop valid on CPU
    use_amp = str(device).startswith("cuda")
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    optimizer = _make_optimizer(net, config)

    # https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.OneCycleLR.html
    scheduler = None
    if config.get("lr_scheduler") is not None:
        div_factor = config["lr_scheduler"].get("div_factor")
        div_factor = 1e4 if div_factor is None else div_factor
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=config["lr"],
            steps_per_epoch=len(train_dataloader),
            epochs=num_epochs,
            final_div_factor=div_factor,
        )

    #### BEGIN TRAINING ####
    start_time = time.time()
    best_val_acc = 0

    print(f"Train Transforms: {train_transform}")
    print(f"Test Transforms: {test_transform}")
    print(f"Val Transforms: {val_transform}")

    for j in range(num_epochs):
        ## START OF EPOCH ##
        train_loss, train_steps = 0.0, 0
        net.train()

        for data, label in train_dataloader:
            data, label = data.to(device), label.long().to(device)

            if train_dataloader.dataset.transform is None:
                data = order_tensor(data)
            if train_transform is not None:
                # data already lives on `device`
                data = train_transform(data)

            # forward
            with torch.cuda.amp.autocast(enabled=use_amp):
                output = net(data)
                loss = loss_function(output, label)

            # backward (loss scaling is a pass-through when AMP is disabled)
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item()
            train_steps += 1

            if scheduler is not None:
                # OneCycleLR is defined per batch, so it steps inside the batch loop
                scheduler.step()

        ## END OF EPOCH ##
        # average training loss over the epoch's batches
        train_loss /= train_steps

        # NOTE: train accuracy is measured with train_transform (augmented images)
        _, train_acc = test_model(net, train_dataloader, loss_function, train_transform)
        val_loss, val_acc = test_model(
            net, val_dataloader, loss_function, val_transform
        )
        _, test_acc = test_model(net, test_dataloader, loss_function, test_transform)

        logger["train_loss"][j] = train_loss
        logger["val_loss"][j] = val_loss
        logger["train_acc"][j] = train_acc
        logger["val_acc"][j] = val_acc
        logger["test_acc"][j] = test_acc

        if log_training and (j + 1) % log_interval == 0:
            print(
                f"Epoch:{j+1}/{num_epochs}",
                f"Train Loss: {logger['train_loss'][j]:.4f}",
                f"Train Acc: {logger['train_acc'][j]:.4f}",
                f"Val Loss: {logger['val_loss'][j]:.4f}",
                f"Val Acc: {logger['val_acc'][j]:.4f}",
                f"Test Acc: {logger['test_acc'][j]:.4f}",
            )

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            if save_model:
                ckpt_dir = Path.cwd() / "models" / net._name
                ckpt_dir.mkdir(parents=True, exist_ok=True)
                # zero-pad the epoch so files sort, e.g. epoch 10 of 300 -> "010"
                checkpoint_num = str(j + 1).zfill(len(str(num_epochs)))
                torch.save(
                    net.state_dict(), ckpt_dir / f"{net._name}_{checkpoint_num}.pt"
                )

    print(f"{num_epochs} epochs took {time.time() - start_time:.2f}s")

    if log_training:
        return logger

In [None]:
@torch.no_grad()
def store_model_outputs(net, data_generator, transform=None):
    images = []  # (N,)
    labels = []  # (N,)
    outputs = []  # (N,10)

    net.eval()
    # loop through specified dataset, collect all scores per class
    for batch_id, (data, label) in enumerate(data_generator):
        data, label = data.to(device), label.to(device)

        if data_generator.dataset.transform is None:
            data = order_tensor(data)
        # don't transform since we're just storing model outputs for Q1.2,Q1.3
        if transform is not None:
            data = transform(data)

        output = net(data)
        images += data
        labels += label
        outputs += output

    # convert lists of tensors to single tensor and overwrite variables
    images = torch.stack(images).cpu()
    labels = torch.stack(labels).cpu()  # torch.Tensor(labels) also works
    outputs = torch.stack(outputs).cpu()

    return labels, outputs, images

In [None]:
def get_topk_by_class(net, dataloader, n_img, correct=True):
    """Return the ``n_img`` most confident images for every class.

    Only correctly (``correct=True``) or wrongly (otherwise) predicted samples
    are considered. For each class the kept images are ranked by the softmax
    probability of that class, and the top ``n_img`` are concatenated class by
    class into one CPU tensor.
    """
    labels, outputs, images = store_model_outputs(net, dataloader)
    probs = F.softmax(outputs, dim=1)
    hit = probs.argmax(dim=1) == labels
    # boolean mask of the samples to keep
    keep = ~hit if correct != True else hit

    images, probs = images[keep], probs[keep]

    per_class = []
    for cls in range(probs.shape[1]):
        ranked = probs[:, cls].argsort(descending=True)
        per_class.extend(images[ranked[:n_img]])

    return torch.stack(per_class).cpu()

In [None]:
def visualize_model_outputs(net, dataloader, correct=True):
    """Show the model's top-5 images per class as a grid with one row per class.

    Args:
        net: trained classifier.
        dataloader: loader to draw images from.
        correct: True for the most confident correct predictions, False for the
            most confident wrong ones.
    """
    num_img = 5
    display_img = get_topk_by_class(net, dataloader, num_img, correct)
    n_imgs, _, img_h = display_img.shape[:3]

    grid = torchvision.utils.make_grid(display_img, nrow=num_img, padding=0)
    fig, ax = plt.subplots(figsize=(10, 20))
    ax.imshow(grid.permute(1, 2, 0), interpolation="nearest", aspect="auto")
    ax.get_xaxis().set_visible(False)

    # one y tick at the vertical centre of each grid row, labelled with its class
    grid_height = int(n_imgs / num_img * img_h)
    ax.set_yticks(list(range(int(img_h / 2), grid_height, int(img_h))))
    ax.set_yticklabels(class_labels)

    mode_str = "Correct" if correct else "Wrong"
    ax.set_title(
        f"Visualizing Top {num_img} {mode_str} image predictions for each class"
    )

In [None]:
def make_confusion_matrix(
    net,
    data_generator,
    labels,
    dataset_name=None,
):
    """Plot the confusion matrix of ``net`` on ``data_generator``.

    Args:
        net: trained classifier.
        data_generator: loader of ``(data, label)`` batches.
        labels: class names used for both axes; their count must equal the
            number of model outputs.
        dataset_name: name shown in the title (e.g. "Test"). When omitted, the
            split is guessed by comparing the dataset length with the STL split
            lengths, falling back to "Unknown" when nothing matches.
    """
    true_labels, outputs, _ = store_model_outputs(net, data_generator)
    n_classes = outputs.shape[1]
    pred_labels = outputs.argmax(1)

    # rows = true class, columns = predicted class (named to avoid shadowing `cm`)
    cm_counts = confusion_matrix(
        true_labels.numpy(), pred_labels.numpy(), labels=np.arange(n_classes)
    )

    cm_df = pd.DataFrame(cm_counts, index=labels, columns=labels)
    plt.figure(figsize=(11, 8))
    sns.heatmap(cm_df, annot=True, fmt=".0f")
    plt.yticks(rotation=0)
    plt.ylabel("True label", rotation=0)
    plt.xlabel("Predicted label")

    if dataset_name is None:
        dataset_name = _guess_split_name(data_generator)

    plt.title(f"Confusion Matrix for {dataset_name} Dataset")
    plt.show()


def _guess_split_name(data_generator):
    """Best-effort split name for a DataLoader over STLData, matched by dataset length."""
    if not isinstance(data_generator, DataLoader):
        return "Unknown"
    n = len(data_generator.dataset)
    for mode, name in (("test", "Test"), ("train", "Train"), ("val", "Validation")):
        # each STLData(...) reads the archive, so stop at the first match
        if n == len(STLData(mode)):
            return name
    return "Unknown"

In [None]:
def get_mean_std(mode="", per_image=True, batch_size=None):
    """Per-channel mean and std of an STL split (pixels scaled to [0, 1] by ToTensor).

    Args:
        mode: "train", "val" or "test".
        per_image: when True (default, the statistic used for the transforms in
            this notebook) return the average over images of each image's
            per-channel mean and std. The averaged per-image std ignores the
            spread between image means, so it is smaller than the std over all
            pixels. Pass False for the true per-channel std over every pixel.
        batch_size: images per batch; ``None`` loads the whole split as one batch.

    Returns:
        ``(mean, std)``: tensors of shape (C,).

    Raises:
        ValueError: if ``mode`` is not one of the three splits.
    """
    if mode not in ("train", "val", "test"):
        raise ValueError("mode should be 'train', 'val' or 'test'")
    dataset = STLData(mode, T.ToTensor())
    dataloader = DataLoader(
        dataset,
        batch_size=len(dataset) if batch_size is None else batch_size,
        num_workers=0,
        shuffle=False,
        pin_memory=False,
    )

    mean_sum, std_sum, n_images = 0, 0, 0
    px_sum, px_sq_sum, n_pixels = 0, 0, 0
    for data, _ in dataloader:
        if dataloader.dataset.transform is None:
            data = order_tensor(data)

        # no augmentation (ColorJitter etc.): statistics describe the raw images
        data = data.flatten(2).to("cpu")  # (B, C, H*W)
        if per_image:
            # mean/std of each image per channel, summed over the batch
            mean_sum += data.mean(dim=2).sum(dim=0)
            std_sum += data.std(dim=2).sum(dim=0)
            n_images += data.shape[0]
        else:
            # float64 running sums for E[x] and E[x^2] over all pixels
            d64 = data.double()
            px_sum += d64.sum(dim=(0, 2))
            px_sq_sum += d64.square().sum(dim=(0, 2))
            n_pixels += d64.shape[0] * d64.shape[2]

    if per_image:
        return mean_sum / n_images, std_sum / n_images

    mean = px_sum / n_pixels
    std = (px_sq_sum / n_pixels - mean.square()).clamp_min(0).sqrt()
    return mean.float(), std.float()

In [None]:
@torch.no_grad()
def occlusion_single_img(net, img, true_label, kernel: tuple, stride: int):
    """Occlusion-sensitivity map for a single image.

    A ``kernel``-sized patch of zeros is slid over the image with the given
    stride. For each position the softmax probability that ``net`` assigns to
    ``true_label`` on the occluded image is recorded; low values mark regions
    the prediction depends on.

    Args:
        net: classifier returning logits of shape (1, n_classes) for one image.
        img: a single image, (C, H, W) or (1, C, H, W), on the same device as ``net``.
        true_label: class index whose probability is mapped, as an int or a
            one-element tensor.
        kernel: occluder size as (height, width).
        stride: step between occluder positions in pixels (both directions).

    Returns:
        CPU tensor of shape (H_out, W_out) with
        H_out = (H - kernel[0]) // stride + 1 and W_out = (W - kernel[1]) // stride + 1.

    Raises:
        ValueError: if ``img`` is neither 3-D nor 4-D.
    """
    if img.ndim == 3:
        # add the batch dimension the network expects
        img = img.unsqueeze(0)
    elif img.ndim != 4:
        raise ValueError(f"expected a 3-D or 4-D image tensor, got {img.ndim}-D")

    H, W = img.shape[2:4]
    k_h, k_w = kernel
    # rows pair with the kernel height, columns with the kernel width
    H_out = (H - k_h) // stride + 1
    W_out = (W - k_w) // stride + 1
    heatmap = torch.zeros((H_out, W_out))

    label_idx = int(true_label.item()) if torch.is_tensor(true_label) else int(true_label)

    net.eval()
    for i in range(H_out):
        for j in range(W_out):
            top, left = i * stride, j * stride
            # zero out one patch on a copy (equivalent to a non-inplace erase with v=0)
            img_mod = img.clone()
            img_mod[..., top : top + k_h, left : left + k_w] = 0

            probs = F.softmax(net(img_mod), dim=1)
            heatmap[i, j] = probs[0, label_idx].item()

    return heatmap

# ShallowCNN

In [None]:
class ShallowCNN(nn.Module):
    """Three strided convolutions, a max-pool and a two-layer MLP head.

    For (B, 3, 96, 96) input the spatial size goes 96 -> 45 -> 21 -> 10 through
    the convolutions and -> 3 after the 3x3/stride-3 max-pool, so the flattened
    feature size is 128 * 3 * 3 = 1152.
    """

    def __init__(self):
        super().__init__()
        # attribute names form the state_dict keys of the saved checkpoints
        self.conv1 = nn.Conv2d(3, 96, (7, 7), stride=2, padding=0)
        self.conv2 = nn.Conv2d(96, 64, (5, 5), stride=2, padding=0)
        self.conv3 = nn.Conv2d(64, 128, (3, 3), stride=2, padding=0)

        self.fc1 = nn.Linear(1152, 128)
        self.fc2 = nn.Linear(128, 10)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=3)
        self.relu = nn.ReLU()

        self._name = type(self).__name__

    def forward(self, x):
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))

        # pool, then flatten everything except the batch dim
        features = torch.flatten(self.maxpool(x), 1)
        return self.fc2(self.relu(self.fc1(features)))

## Train

In [None]:
# ShallowCNN run: Adam with a OneCycleLR schedule (configured in train_model).
# LR range considered for the peak: 2.21e-4 .. 8.05e-4
shallow_net = ShallowCNN().to(device)
model_cfg = dict(
    model=shallow_net,
    optimizer="Adam",
    momentum=0.9,  # only used by SGD
    weight_decay=0,
    lr=2.21e-4,  # peak (max_lr) of the one-cycle schedule
    lr_scheduler=dict(div_factor=1e2),  # passed as OneCycleLR final_div_factor
    batch_size=128,
    log_training=True,
    log_interval=10,
    save_model=True,
    num_epochs=60,
)

In [None]:
# log_shallow = train_model(model_cfg)

## Plot

In [None]:
# Training above is commented out; plot only when its log exists so a fresh
# "Restart & Run All" does not stop here with a NameError.
if "log_shallow" in globals():
    plot_log(log_shallow, model_cfg)
else:
    print("log_shallow not found: run train_model(model_cfg) above first.")

## Load model & evaluate

In [None]:
shallow_net = ShallowCNN().to(device)

# select 33.pt
model_path = f"models/{shallow_net._name}/select/{shallow_net._name}_33.pt"
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
shallow_net.load_state_dict(torch.load(model_path, map_location=device))
shallow_net.eval()

test_dataloader = DataLoader(
    STLData("test", T.ToTensor()),
    batch_size=256,
    num_workers=n_workers,
    shuffle=False,
    pin_memory=False,
)
_, test_acc = test_model(shallow_net, test_dataloader, nn.CrossEntropyLoss())
print("Test Accuracy:", test_acc)

## Visualization

In [None]:
# Validation split in a fixed order, used by the visualizations below.
val_dataloader = DataLoader(
    dataset=STLData("val", T.ToTensor()),
    batch_size=128,
    shuffle=False,
    num_workers=n_workers,
    pin_memory=False,
)

### correct images

In [None]:
# Most confident correct validation predictions, five per class.
visualize_model_outputs(net=shallow_net, dataloader=val_dataloader, correct=True)

### wrong images

In [None]:
# Most confident wrong validation predictions, five per class.
visualize_model_outputs(net=shallow_net, dataloader=val_dataloader, correct=False)

## Confusion matrix

### CM Train

In [None]:
# Confusion matrix on the training split (sample order does not affect counts).
train_dataloader = DataLoader(
    dataset=STLData("train", T.ToTensor()),
    batch_size=128,
    shuffle=True,
    num_workers=n_workers,
    pin_memory=False,
)
make_confusion_matrix(net=shallow_net, data_generator=train_dataloader, labels=class_labels)

### CM val

In [None]:
# Confusion matrix on the validation split.
val_dataloader = DataLoader(
    dataset=STLData("val", T.ToTensor()),
    batch_size=128,
    shuffle=False,
    num_workers=n_workers,
    pin_memory=False,
)
make_confusion_matrix(net=shallow_net, data_generator=val_dataloader, labels=class_labels)

### CM Test

In [None]:
# Confusion matrix on the test split.
test_dataloader = DataLoader(
    dataset=STLData("test", T.ToTensor()),
    batch_size=128,
    shuffle=False,
    num_workers=n_workers,
    pin_memory=False,
)
make_confusion_matrix(net=shallow_net, data_generator=test_dataloader, labels=class_labels)

# DeepCNN

In [None]:
class DeepCNN(nn.Module):
    """Four conv blocks (each halving the resolution), global average pooling, linear head.

    For (B, 3, 96, 96) input the blocks produce (B, 192, 6, 6); the 6x6 average
    pool then yields one 192-dim feature vector per image.
    """

    def __init__(self):
        super().__init__()
        # attribute and block names form the state_dict keys of saved checkpoints
        self.blocks = self._build_blocks()
        # global average pooling over the 6x6 output of the conv blocks
        self.gap = nn.AvgPool2d(kernel_size=6, stride=1)
        self.fc1 = nn.Linear(192, 10)

        self._name = type(self).__name__

    def _build_blocks(self):
        """Chain conv blocks 3 -> 32 -> 64 -> 128 -> 192 channels, named Conv-Blk-1..4."""
        widths = [3, 32, 64, 128, 192]
        named_blocks = OrderedDict(
            (f"Conv-Blk-{n}", self._create_conv_block(c_in, c_out))
            for n, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1)
        )
        return nn.Sequential(named_blocks)

    def _create_conv_block(self, in_channels, out_channels):
        """3x3 stride-2 conv, 1x1 conv and 3x3 conv, each followed by a ReLU."""
        layers = []
        for conv in (
            nn.Conv2d(in_channels, out_channels, (3, 3), stride=2, padding=1),
            nn.Conv2d(out_channels, out_channels, (1, 1), stride=1, padding=0),
            nn.Conv2d(out_channels, out_channels, (3, 3), stride=1, padding=1),
        ):
            layers += [conv, nn.ReLU()]
        return nn.Sequential(*layers)

    def forward(self, x):
        pooled = self.gap(self.blocks(x))
        # flatten all dimensions except batch
        return self.fc1(pooled.flatten(1))

## No transforms

### Train

In [None]:
# DeepCNN run without augmentation: Adam with a OneCycleLR schedule (see train_model).
deep_net = DeepCNN().to(device)
model_cfg = dict(
    model=deep_net,
    optimizer="Adam",
    momentum=0.9,  # only used by SGD
    weight_decay=0,
    lr=2.5e-4,  # peak (max_lr) of the one-cycle schedule
    lr_scheduler=dict(div_factor=2),  # passed as OneCycleLR final_div_factor
    batch_size=128,
    log_training=True,
    log_interval=10,
    save_model=True,
    num_epochs=150,
)

In [None]:
# log_deep = train_model(model_cfg)

In [None]:
# Training above is commented out; plot only when its log exists so a fresh
# "Restart & Run All" does not stop here with a NameError.
if "log_deep" in globals():
    plot_log(log_deep, model_cfg)
else:
    print("log_deep not found: run train_model(model_cfg) above first.")

### Eval

In [None]:
deep_net = DeepCNN().to(device)
# select 128.pt
model_path = f"./models/{deep_net._name}/select/{deep_net._name}_128.pt"
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
deep_net.load_state_dict(torch.load(model_path, map_location=device))
deep_net.eval()

test_dataloader = DataLoader(
    STLData("test", T.ToTensor()),
    batch_size=model_cfg["batch_size"],
    num_workers=n_workers,
    shuffle=False,
    # pinned host memory only helps host->GPU copies
    pin_memory=(device == "cuda"),
)
_, test_acc = test_model(deep_net, test_dataloader, nn.CrossEntropyLoss())
print("Test Accuracy:", test_acc)

## With transforms

### calculate mean and std and define transforms

In [None]:
# Per-channel statistics of the training split, used to normalize all splits.
t_mean, t_std = get_mean_std("train")
print("mean values:", t_mean)
print("std values:", t_std)

# T.Normalize is given plain Python lists
t_mean, t_std = t_mean.tolist(), t_std.tolist()
print("modified mean values:", t_mean)
print("modified std values:", t_std)

# evaluation splits: normalization only
val_transform = nn.Sequential(T.Normalize(mean=t_mean, std=t_std))
test_transform = nn.Sequential(T.Normalize(mean=t_mean, std=t_std))

# training split: normalization followed by random augmentations
train_transform = nn.Sequential(
    T.Normalize(mean=t_mean, std=t_std),
    T.RandomRotation(degrees=45),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.05),
    T.RandomGrayscale(p=0.2),
    T.RandomErasing(p=0.5),
)

### Visualize transformed images

In [None]:
# Preview the training augmentations and check the statistics of the
# normalized + augmented training split. Intermediate names avoid `tmp`
# (used as a training log elsewhere) and `display` (IPython's display()).
dataset = STLData("train", T.ToTensor())
dataloader = DataLoader(
    dataset,
    batch_size=128,
    num_workers=0,
    shuffle=True,
    pin_memory=False,
)

augmented_batches = []
for data, _ in dataloader:
    if dataloader.dataset.transform is None:
        data = order_tensor(data)
    augmented_batches.append(train_transform(data))
img = torch.cat(augmented_batches)
del augmented_batches

print(img.shape)
print("Channel std after normalization: ", img.std(dim=[0, 2, 3]))
print("Channel means after normalization: ", img.mean(dim=[0, 2, 3]))
print("Max/Min values after normalization: ", img.max(), img.min())

n_img = 256
preview = img[:n_img]
out = torchvision.utils.make_grid(preview, nrow=8, padding=0)  # 8 images per row
fig, ax = plt.subplots(figsize=(15, 45))
ax.imshow(out.permute(1, 2, 0), interpolation="nearest", aspect="auto")
ax.axis("off")
plt.show()

### Train

In [None]:
# DeepCNN run with augmentation: the transforms are applied per batch on `device`.
deep_net = DeepCNN().to(device)
model_cfg = dict(
    model=deep_net,
    optimizer="Adam",
    momentum=0.9,  # only used by SGD
    weight_decay=0,
    lr=1e-3,  # peak (max_lr) of the one-cycle schedule
    lr_scheduler=dict(div_factor=1e3),  # passed as OneCycleLR final_div_factor
    batch_size=64,
    log_training=True,
    log_interval=10,
    save_model=True,
    num_epochs=300,
    train_transform=train_transform,
    val_transform=val_transform,
    test_transform=test_transform,
)

In [None]:
# tmp = train_model(model_cfg)

In [None]:
# `tmp` only holds a training log if the commented-out train_model call above
# ran; the name is also reused for tensors elsewhere in this notebook.
if isinstance(globals().get("tmp"), dict):
    plot_log(tmp, model_cfg)
else:
    print("No training log in `tmp`: run train_model(model_cfg) above first.")

In [None]:
# old plot
plot_log(tmp, model_cfg)

### Eval

In [None]:
net = DeepCNN().to(device)
# select 282.pt or 269.pt
model_path = f"./models/{net._name}/select/{net._name}_282.pt"
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
net.load_state_dict(torch.load(model_path, map_location=device))
net.eval()

test_dataloader = DataLoader(
    STLData("test", T.ToTensor()),
    batch_size=model_cfg["batch_size"],
    num_workers=n_workers,
    shuffle=False,
    pin_memory=False,
)
# the model was trained on normalized inputs, so evaluate with test_transform
_, test_acc = test_model(net, test_dataloader, nn.CrossEntropyLoss(), test_transform)
print("Test Accuracy:", test_acc)

# CustomCNN

In [None]:
class CustomCNN(nn.Module):
    """Four conv blocks with BatchNorm/Dropout, concatenated max+avg pooling, MLP head.

    For (B, 3, 96, 96) input the blocks produce (B, 192, 6, 6); 6x6 max- and
    average-pooling each give 192 features, concatenated to 384.
    """

    def __init__(self):
        super().__init__()
        # attribute and block names form the state_dict keys of saved checkpoints
        self.blocks = self._build_blocks()
        self.fc1 = nn.Linear(384, 192)
        self.fc2 = nn.Linear(192, 10)
        self.ap = nn.AvgPool2d(kernel_size=6, stride=1, padding=0)
        self.mp = nn.MaxPool2d(kernel_size=6, stride=1, padding=0)
        self.lrelu = nn.LeakyReLU()
        self.dropout = nn.Dropout(0.25)

        self._name = type(self).__name__

    def _build_blocks(self):
        """Chain conv blocks 3 -> 32 -> 64 -> 128 -> 192 channels, named Conv-Blk-1..4."""
        widths = [3, 32, 64, 128, 192]
        named_blocks = OrderedDict(
            (f"Conv-Blk-{n}", self._create_conv_block(c_in, c_out))
            for n, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1)
        )
        return nn.Sequential(named_blocks)

    def _create_conv_block(self, in_channels, out_channels):
        """5x5 stride-2 conv then 3x3 conv, each followed by LeakyReLU, BatchNorm, Dropout."""
        stages = []
        for conv in (
            nn.Conv2d(in_channels, out_channels, (5, 5), stride=2, padding=2),
            nn.Conv2d(out_channels, out_channels, (3, 3), stride=1, padding=1),
        ):
            stages += [
                conv,
                nn.LeakyReLU(),
                nn.BatchNorm2d(out_channels),
                nn.Dropout(0.25),
            ]
        return nn.Sequential(*stages)

    def forward(self, x):
        feats = self.blocks(x)
        # max-pooled and average-pooled descriptors side by side along channels
        pooled = torch.cat([self.mp(feats), self.ap(feats)], dim=1).flatten(1)
        hidden = self.dropout(self.lrelu(self.fc1(pooled)))
        return self.fc2(hidden)

In [None]:
# Module tree of the custom architecture (nn.Module's repr lists every layer).
print(repr(CustomCNN()))

## calculate mean, std

In [None]:
# Per-channel training-split statistics, converted to lists for T.Normalize.
t_mean, t_std = (stat.tolist() for stat in get_mean_std("train"))
print("mean values:", t_mean)
print("std values:", t_std)

# evaluation splits: normalization only
val_transform = nn.Sequential(T.Normalize(mean=t_mean, std=t_std))
test_transform = nn.Sequential(T.Normalize(mean=t_mean, std=t_std))

# training split: normalization followed by random augmentations
train_transform = nn.Sequential(
    T.Normalize(mean=t_mean, std=t_std),
    T.RandomRotation(degrees=45),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.05),
    T.RandomGrayscale(p=0.2),
    T.RandomErasing(p=0.5),
)

## train

In [None]:
# CustomCNN run with augmentation: Adam with a OneCycleLR schedule (see train_model).
custom_net = CustomCNN().to(device)
model_cfg = dict(
    model=custom_net,
    optimizer="Adam",
    momentum=0.9,  # only used by SGD
    weight_decay=0,
    lr=1e-3,  # peak (max_lr) of the one-cycle schedule
    lr_scheduler=dict(div_factor=1e3),  # passed as OneCycleLR final_div_factor
    batch_size=64,
    log_training=True,
    log_interval=10,
    save_model=True,
    num_epochs=300,
    train_transform=train_transform,
    val_transform=val_transform,
    test_transform=test_transform,
)

In [None]:
# log_custom = train_model(model_cfg)

In [None]:
# Training above is commented out; plot only when its log exists so a fresh
# "Restart & Run All" does not stop here with a NameError.
if "log_custom" in globals():
    plot_log(log_custom, model_cfg)
else:
    print("log_custom not found: run train_model(model_cfg) above first.")

## eval

In [None]:
net = CustomCNN().to(device)
# select 294.pt
model_path = f"./models/{net._name}/select/{net._name}_294.pt"
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
net.load_state_dict(torch.load(model_path, map_location=device))
net.eval()

test_dataloader = DataLoader(
    STLData("test", T.ToTensor()),
    batch_size=model_cfg["batch_size"],
    num_workers=n_workers,
    shuffle=False,
    pin_memory=False,
)
# the model was trained on normalized inputs, so evaluate with test_transform
_, test_acc = test_model(net, test_dataloader, nn.CrossEntropyLoss(), test_transform)
print("Test Accuracy:", test_acc)

# Model comparison

In [None]:
deep_cnn = DeepCNN().to(device)
# torchsummary builds its dummy input on "cuda" by default; pass the notebook's
# device so this cell also works on CPU-only machines
summary(deep_cnn, (3, 96, 96), device=device)

In [None]:
custom_cnn = CustomCNN().to(device)
# torchsummary builds its dummy input on "cuda" by default; pass the notebook's
# device so this cell also works on CPU-only machines
summary(custom_cnn, (3, 96, 96), device=device)

# Occlusion Sensitivity for ShallowCNN

## Load trained model

In [None]:
##### LOAD MODEL
# Reuse the auto-detected `device` instead of hard-coding "cuda", so this
# section also runs on CPU-only machines.
net = ShallowCNN().to(device)
# select 33.pt
model_path = f"./models/{net._name}/select/{net._name}_33.pt"
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine
net.load_state_dict(torch.load(model_path, map_location=device))
net.eval()

In [None]:
# Wrongly classified validation images: for each class, the K images with the
# highest predicted probability for that class.
K = 5  # number of top images per class
dataloader = DataLoader(
    dataset=STLData("val", T.ToTensor()),
    batch_size=512,
    shuffle=True,
    num_workers=0,
    pin_memory=False,
)
images = get_topk_by_class(net, dataloader, n_img=K, correct=False)

In [None]:
# every image is paired with its heatmap; four panels (two pairs) per row
n_cols = 4
n_rows = (2 * images.shape[0]) // n_cols

In [None]:
# Interleave each image with a blank placeholder; the plotting cell draws the
# occlusion heatmap of the preceding image in the placeholder's slot.
display = []
for curr_img in images:
    display.append(curr_img)
    # placeholder matches the image's shape/dtype instead of assuming (3, 96, 96)
    display.append(torch.zeros_like(curr_img))
# (2 * n_images, C, H, W)
display = torch.stack(display)

In [None]:
# Occlusion-sensitivity figure. Panels alternate image / heatmap. Each heatmap
# shows the probability of the class the image was ranked under by
# get_topk_by_class as an 8x8 patch is occluded.
# rc_context keeps the high DPI from leaking into later figures.
with mpl.rc_context({"figure.dpi": 600}):
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(5, 25), facecolor="white")
    grid = display.reshape(n_rows, n_cols, *display.shape[1:])

    for i in range(n_rows):
        for j in range(n_cols):
            panel = ax[i, j]
            if j % 2 == 1:
                # odd column: heatmap of the image drawn just to the left
                prev_img = grid[i, j - 1].unsqueeze(0).to(device)
                # `images` holds K images per class in class order; each row has
                # n_cols // 2 images, so derive the class from the image index
                # (not from the row index)
                img_idx = (i * n_cols + j - 1) // 2
                label = img_idx // K
                heatmap = occlusion_single_img(
                    net, prev_img, label, kernel=(8, 8), stride=8
                ).cpu()
                shown_img = panel.imshow(
                    heatmap, interpolation="nearest", aspect="equal"
                )
                cbar = fig.colorbar(shown_img, ax=panel, format="%.4f")
                cbar.ax.tick_params(labelsize=5)
            else:
                panel.imshow(
                    grid[i, j].permute(1, 2, 0), interpolation="nearest", aspect="equal"
                )
            panel.axis("off")

    plt.subplots_adjust(wspace=0.1, hspace=0.15)
    plt.savefig("heatmap.svg", format="svg", dpi=600)
    plt.show()