In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import matplotlib.pyplot as plt
import torch
from torchvision.datasets import CIFAR100
from torchvision.transforms import Normalize, ToTensor, Lambda, Compose
import torch.nn.functional as F
import numpy as np
from copy import deepcopy
from tqdm import tqdm
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.optim as optim
from ipywidgets import interact, fixed
import ipywidgets as widgets
import cifar_names as cn

from tqdm import tqdm

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
device

In [None]:
image_transforms = Compose([
    ToTensor(),
    # Per-channel normalization with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
    # Assuming ToTensor yields values in [0, 1], the result lies in [-1, 1].
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def _fine_and_superclass(label):
    """Return ``(tensor(label), tensor(superclass))`` for a fine CIFAR-100 label.

    The superclass is looked up in ``cn.SUPERCLASS_MAPPING``.
    """
    superclass = cn.SUPERCLASS_MAPPING[label]
    return torch.tensor(label), torch.tensor(superclass)


label_transforms = Compose([Lambda(_fine_and_superclass)])

In [None]:
cifar100_path = Path("data/")


def _load_cifar100(is_train: bool) -> CIFAR100:
    """Build one CIFAR-100 split under ``cifar100_path`` with ``download=True``.

    Images go through ``image_transforms``; targets go through ``label_transforms``.
    """
    return CIFAR100(
        cifar100_path,
        train=is_train,
        download=True,
        transform=image_transforms,
        target_transform=label_transforms,
    )


cifar100_train = _load_cifar100(True)
cifar100_test = _load_cifar100(False)

In [None]:
def show_example(i: int) -> None:
    """Plot training example ``i`` titled with its fine class and superclass names.

    The dataset's image transform normalizes with mean 0.5 / std 0.5, which puts
    pixels in roughly [-1, 1]. ``imshow`` clips float RGB data to [0, 1], so the
    normalization is undone (x * 0.5 + 0.5) before plotting. Otherwise the
    negative half of the range is rendered as black.

    Args:
        i: index into ``cifar100_train`` (0 <= i < len(cifar100_train)).
    """
    img, (c, sc) = cifar100_train[i]
    # CHW -> HWC for matplotlib, after mapping [-1, 1] back to [0, 1].
    rgb = (img * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0)
    plt.imshow(rgb)
    plt.title(
        f"class: {cn.CIFAR100_LABELS_LIST[int(c)]}, "
        f"superclass: {cn.SUPERCLASS_LIST[int(sc)]}"
    )
    plt.show()


# IntSlider's `max` is inclusive, so the last valid index is len - 1.
interact(show_example, i=widgets.IntSlider(max=len(cifar100_train) - 1))

In [None]:
BATCH_SIZE = 32

# Only the training split is shuffled; the test split keeps dataset order.
loader_train = DataLoader(cifar100_train, batch_size=BATCH_SIZE, shuffle=True)
loader_test = DataLoader(cifar100_test, batch_size=BATCH_SIZE)

# Long-lived iterators so later cells can pull single batches with next(...).
iter_train, iter_test = iter(loader_train), iter(loader_test)

In [None]:
class ConvBlock(nn.Module):
    """Three 3x3 conv + ReLU stages followed by BatchNorm.

    All convs use kernel_size=3, padding=1 and the default stride 1, so the
    spatial size is preserved. The first conv maps ``channels_in`` to
    ``channels_out``; the other two keep ``channels_out``.

    Args:
        channels_in: number of input channels.
        channels_out: number of output channels.
    """

    def __init__(self, channels_in: int, channels_out: int):
        super().__init__()
        self.conv_1 = nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1)
        self.conv_2 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1)
        self.conv_3 = nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1)
        # ReLU has no parameters or state, so one instance is reused after every conv.
        self.relu = nn.ReLU()
        self.batch_norm = nn.BatchNorm2d(channels_out)

    def forward(self, X):
        """Apply (conv -> ReLU) three times, then BatchNorm.

        A nonlinearity is needed between the convs: without one, the three
        convolutions compose into a single linear map, and the extra layers add
        parameters but no expressive power.
        """
        h = self.relu(self.conv_1(X))
        h = self.relu(self.conv_2(h))
        h = self.relu(self.conv_3(h))
        return self.batch_norm(h)

In [None]:
# Channel widths of successive stages. Between stages, a stride-2 3x3 conv
# (padding=1) halves the spatial size and widens channels, then a ConvBlock
# refines at that width. For 32x32 inputs (presumably CIFAR) the five
# downsamples give a 1x1 map.
_STAGE_WIDTHS = (32, 64, 128, 256, 512, 1024)

model_base = nn.Sequential(
    ConvBlock(3, _STAGE_WIDTHS[0]),
    *(
        layer
        for width_in, width_out in zip(_STAGE_WIDTHS, _STAGE_WIDTHS[1:])
        for layer in (
            nn.Conv2d(width_in, width_out, kernel_size=3, padding=1, stride=2),
            ConvBlock(width_out, width_out),
        )
    ),
)

In [None]:
from time import sleep

In [None]:
iters = 5
epochs = 10
# Per-iteration histories; nothing appends to these yet (the training step is TODO).
loss_hist = []
acc_hist = []
loss_val_hist = []
acc_val_hist = []


def _running_mean(values: list) -> str:
    """Format the mean of ``values`` for the progress bar, or "n/a" when empty."""
    return str(np.mean(values)) if values else "n/a"


for i in range(iters):
    epochbar = tqdm(range(epochs))
    # Start empty rather than [0]. A seeded zero would be averaged together with
    # the real per-batch values once they are appended, biasing the displayed
    # mean, and until then it reports a misleading loss/accuracy of 0.0.
    ls = []
    acc = []
    for _ in epochbar:
        # TODO: training step that appends batch loss / accuracy to `ls` / `acc`.
        epochbar.set_description(
            f"iter: {i} |\t"
            f"train_loss: {_running_mean(ls)} |\t"
            f"train_acc: {_running_mean(acc)}"
        )

In [None]:
# Pull one training batch: images plus the (fine class, superclass) label pair.
batch = next(iter_train)
x, (y_c, y_s) = batch

In [None]:
# Size of the sampled image batch (displayed as a torch.Size).
x.size()

In [None]:
# Shape probe. Run in eval mode without autograd so that re-running this cell
# does not update BatchNorm running statistics or build a gradient graph. The
# model's previous train/eval mode is restored afterwards.
was_training = model_base.training
model_base.eval()
try:
    with torch.no_grad():
        out = model_base(x)
finally:
    model_base.train(was_training)
out.shape