In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive.
# The training loop later in this notebook writes checkpoints under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchvision
from torch import device
from tqdm.notebook import tqdm
# Seed the global PyTorch and NumPy random number generators with a fixed value.
torch.manual_seed(47)
np.random.seed(47)

# NOTE(review): a class named ResNet18 is also defined in a later cell of this notebook;
# once that cell runs it rebinds the name imported here. Confirm which implementation is intended.
from ResNet import ResNet18

In [None]:
class ResUnit(nn.Module):
    """Residual unit made of two 3x3 convolutions with dropout on the main branch.

    Main branch: conv3x3 (given stride) -> BatchNorm -> ReLU -> conv3x3
    (stride 1) -> BatchNorm -> Dropout(0.2). The shortcut branch is added to
    that result and a final ReLU is applied.

    Args:
        p: number of input channels.
        stride: stride of the first convolution (and of the shortcut projection).
        exp: channel multiplier; the unit outputs ``p * exp`` channels.
    """

    def __init__(self, p=64, stride=2,exp=1):
        super().__init__()
        width = p * exp

        # Main branch. Attribute names double as state_dict keys, so they are kept as-is.
        self.c1 = nn.Conv2d(p, width, kernel_size=3, stride=stride, padding=1, bias=False)
        self.b1 = nn.BatchNorm2d(width)
        self.relu = nn.ReLU()
        self.c2 = nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1, bias=False)
        self.b2 = nn.BatchNorm2d(width)
        self.drp = nn.Dropout(0.2)
        self.relu2 = nn.ReLU()

        # Shortcut branch: an empty Sequential passes the input through unchanged;
        # a 1x1 conv + BatchNorm projection is used whenever the main branch changes
        # the spatial size (stride) or the channel count (exp).
        shape_preserved = stride == 1 and exp == 1
        if shape_preserved:
            self.residual = nn.Sequential()
        else:
            self.residual = nn.Sequential(
                nn.Conv2d(p, width, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width),
            )

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        main = self.c1(x)
        main = self.b1(main)
        main = self.relu(main)
        main = self.c2(main)
        main = self.b2(main)
        main = self.drp(main)
        shortcut = self.residual(x)
        return self.relu2(main + shortcut)


class ResNet18(nn.Module):
    """ResNet-18-style image classifier assembled from ``ResUnit`` blocks.

    Layout: a stride-1 3x3 conv stem (so the input resolution is kept), two
    stride-1 units at width ``p``, then three stages that each halve the
    resolution and multiply the channel count by ``expansion``, followed by
    global average pooling and a linear classifier.

    Args:
        p: number of channels produced by the stem / first stage.
        expansion: channel multiplier applied at each of the three downsampling stages.
        num_classes: size of the output logits vector.
    """

    def __init__(self,p,expansion,num_classes=100):
        super(ResNet18, self).__init__()
        # The original comment states the expected input is 32x32x3.
        # Stem: 3 -> p channels, stride 1 / padding 1 keeps the spatial size.
        # (The constructor argument `p` is used as given; it is no longer overwritten with 64.)
        l=[
            nn.Conv2d(3, p, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
            nn.BatchNorm2d(p),
            nn.ReLU(inplace=True),
        ]

        # Stage 1: two units that keep resolution and width.
        l.append(ResUnit(p, 1))
        l.append(ResUnit(p, 1))

        # Stages 2-4: a stride-2 unit that widens by `expansion`, then a stride-1 unit
        # at the new width. Layer order is unchanged, so Sequential indices (and
        # therefore state_dict keys) match the previous definition.
        for _ in range(3):
            l.append(ResUnit(p, 2, expansion))
            p *= expansion
            l.append(ResUnit(p, 1))

        # Head: global average pool to 1x1 and flatten to (batch, p).
        l.extend([
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Flatten(),
        ])

        self.model = nn.Sequential(*l)
        # After three widenings p == initial_p * expansion**3.
        self.cl=nn.Linear(p,num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        x = self.model(x)
        x = self.cl(x)
        return x

# Training the model

In [None]:
# Hyperparameters and run configuration

# Model
p = 64                  # channel width passed to the ResNet18 constructor
exp = 2                 # channel expansion factor

# Optimisation
epochs = 500
batch_size = 32
lr = 0.01               # initial learning rate
weight_decay = 5e-4

# Loss function and compute device (GPU when one is available, otherwise CPU)
crit = nn.CrossEntropyLoss()
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)


In [None]:
# Build the network, optimizer and learning-rate schedule.
# `exp` and `weight_decay` now come from the hyperparameter cell instead of being
# re-typed here as literals (previously 2 and 1e-6, which left the declared
# `exp` / `weight_decay` settings unused).
resnet = ResNet18(p, exp, 100).to(device)
opti = torch.optim.SGD(resnet.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)
# T_max=epochs: the cosine decay reaches its minimum after `epochs` calls to scheduler.step().
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opti, T_max=epochs)

In [None]:
# Per-channel mean / std used for normalisation; shared by both pipelines.
_norm_mean = (0.5071, 0.4867, 0.4408)
_norm_std = (0.2675, 0.2565, 0.2761)

# Training pipeline: tensor conversion, normalisation, then random horizontal and vertical flips.
transf = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(_norm_mean, _norm_std),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomVerticalFlip(),
    ]
)

# Evaluation pipeline: tensor conversion and normalisation only (no augmentation).
tr = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(_norm_mean, _norm_std),
    ]
)

In [None]:
# CIFAR-100 splits, downloaded into ./data when not already present.
# The train split uses the augmenting pipeline; the train=False split is used for validation.
_data_root = "./data"
train_data = torchvision.datasets.CIFAR100(root=_data_root, train=True, transform=transf, download=True)
val_data = torchvision.datasets.CIFAR100(root=_data_root, train=False, transform=tr, download=True)

In [None]:
# Batch iterators: training batches are reshuffled every epoch, validation order is fixed.
_loader_kwargs = dict(batch_size=batch_size, num_workers=2)
train_dataloader = torch.utils.data.DataLoader(train_data, shuffle=True, **_loader_kwargs)
val_dataloader = torch.utils.data.DataLoader(val_data, shuffle=False, **_loader_kwargs)

In [None]:
# Print the module tree of the network as a quick check of the assembled architecture.
print(resnet)

In [None]:
def train(resnet,crit,opti,scheduler,train_dataloader,val_dataloader,epochs,save=True):
    """Train ``resnet`` for ``epochs`` epochs and validate after every epoch.

    Args:
        resnet: the model to optimise; batches are moved to the device of its parameters.
        crit: loss callable invoked as ``crit(outputs, labels)``.
        opti: optimizer updating ``resnet``'s parameters.
        scheduler: learning-rate scheduler; stepped once per epoch.
        train_dataloader: iterable of ``(inputs, labels)`` training batches.
        val_dataloader: iterable of ``(inputs, labels)`` validation batches.
        epochs: number of passes over ``train_dataloader``.
        save: when true, write ``resnet.state_dict()`` to Google Drive every 10th epoch
            (epochs 0, 10, 20, ...).

    Returns:
        ``(losses, accuracies, val_losses, val_accuracies)`` where ``losses`` holds one
        value per training batch and the other three lists hold one value per epoch.
        ``val_losses`` entries are the mean of the per-batch validation losses.
    """
    # Use the model's own device instead of relying on a notebook-level global.
    first_param = next(resnet.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    losses = []
    accuracies = []
    val_losses = []
    val_accuracies = []
    for epoch in range(epochs):
        # Validation at the end of each epoch switches the model to eval mode, so
        # training mode has to be restored at the start of every epoch; otherwise
        # BatchNorm/Dropout would run in inference mode from the second epoch on.
        resnet.train()
        epoch_losses = []
        correct, total = 0, 0
        for inputs, labels in tqdm(train_dataloader):

            inputs = inputs.to(model_device)
            labels = labels.to(model_device)

            outputs = resnet(inputs)
            loss = crit(outputs, labels)

            opti.zero_grad()
            loss.backward()
            opti.step()

            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
            epoch_losses.append(loss.item())

        # The schedule is configured in epochs (T_max=epochs in this notebook), so it
        # advances once per epoch rather than once per batch.
        scheduler.step()

        losses.extend(epoch_losses)
        train_loss = float(np.mean(epoch_losses)) if epoch_losses else float("nan")
        accuracies.append(float(correct)/float(total) if total else float("nan"))
        print(f"Epoch {epoch+1}/{epochs}, Loss: {train_loss}, Accuracy: {accuracies[-1]}")

        resnet.eval()
        val_batch_losses = []
        val_correct, val_total = 0, 0
        with torch.no_grad():
            for val_in,val_l in val_dataloader:
                val_in = val_in.to(model_device)
                val_l = val_l.to(model_device)

                val_out = resnet(val_in)

                val_batch_losses.append(crit(val_out, val_l).item())

                val_correct += (val_out.argmax(dim=1) == val_l).sum().item()
                val_total += val_l.size(0)

        # Report the loss over the whole validation pass, not just its final batch.
        val_loss = float(np.mean(val_batch_losses)) if val_batch_losses else float("nan")
        val_losses.append(val_loss)
        val_accuracies.append(float(val_correct)/float(val_total) if val_total else float("nan"))
        print(f"Validation Loss: {val_loss}, Validation Accuracy: {val_accuracies[-1]}")

        if epoch % 10 == 0 and save:
            torch.save(resnet.state_dict(), f"/content/drive/MyDrive/resnet_{epoch}_epoch.pth")

    return losses, accuracies, val_losses, val_accuracies


In [None]:
# Run the training loop with the objects built in the cells above.
# `save` is left at its default (True), so checkpoints are written during the run.
losses, accuracies, val_losses, val_accuracies = train(
    resnet,
    crit,
    opti,
    scheduler,
    train_dataloader,
    val_dataloader,
    epochs,
)