In [1]:
import os
# Restrict CUDA to device index 0; this is set before torch is imported in the next cell.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [2]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.datasets import CIFAR100
from torchvision.transforms import ToTensor
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay

In [3]:
# Seed PyTorch's default random generator so repeated runs start from the same state.
torch.manual_seed(seed=42)

<torch._C.Generator at 0x7f1ef3cc1850>

In [4]:
class Baseline(nn.Module):
    """Small CNN classifier: three unpadded 3x3 convolutions, then three linear layers.

    The first linear layer expects 43264 flattened features, i.e. 64 * 26 * 26:
    each unpadded 3x3 convolution trims two pixels per spatial dimension, so a
    32x32 input reaches the flatten step as a 64-channel 26x26 map.

    ``forward`` returns the 100-way output of the last linear layer together
    with the ReLU-activated output of every convolution.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: channel widths 3 -> 64 -> 128 -> 64, no padding.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding="valid")
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding="valid")
        self.conv3 = nn.Conv2d(128, 64, kernel_size=3, padding="valid")

        self.flatten = nn.Flatten()

        # Classifier head: 43264 -> 256 -> 64 -> 100.
        self.fc1 = nn.Linear(43264, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 100)

    def forward(self, x):
        """Run the network on a batch ``x``.

        Returns:
            tuple: ``(logits, exit_outputs)`` where ``logits`` is the output of
            ``fc3`` and ``exit_outputs`` is a list holding the post-ReLU
            activation of ``conv1``, ``conv2`` and ``conv3`` in that order.
        """
        exit_outputs = []
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
            exit_outputs.append(x)

        features = self.flatten(x)

        # The linear layers are applied back to back with no activation between them.
        logits = self.fc3(self.fc2(self.fc1(features)))
        return logits, exit_outputs

In [5]:
# Use the first CUDA device when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [6]:
# Training split of CIFAR-100, fetched into ./data when it is not already there.
dataset = CIFAR100(
    root="./data",
    train=True,
    download=True,
    transform=ToTensor(),
)
# Held-out test split; no download is requested, so it reads the files under ./data.
test_dataset = CIFAR100(
    root="./data",
    train=False,
    transform=ToTensor(),
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|███████████████████████| 169001437/169001437 [00:15<00:00, 11200242.55it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data


In [7]:
# Hold out `val_size` samples of the training set for validation.
batch_size = 128
val_size = 5000
train_size = len(dataset) - val_size
train_ds, val_ds = random_split(dataset, lengths=[train_size, val_size])

# Only the training loader shuffles; the evaluation loaders use a doubled batch size.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=batch_size * 2, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size * 2, num_workers=4)

In [8]:
# Instantiate the network, then move its parameters to the selected device.
model = Baseline()
model = model.to(device)

In [9]:
# Training setup: upper bound on epochs, cross-entropy loss, SGD with momentum.
epochs = 50
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=3e-3, momentum=0.9)

In [10]:
# Train for up to `epochs` epochs with early stopping on validation loss.
# After every epoch the model is evaluated on `val_loader`; whenever the
# (rounded) mean validation loss does not get worse, the weights are saved to
# "cifar100_baseline_s.h5". Training stops once the loss has failed to improve
# for 10 consecutive epochs.
best_val_epoch, best_val_loss = 0, 1e6
break_flag = 0  # consecutive epochs without a validation-loss improvement
for epoch in range(epochs):  # loop over the dataset multiple times
    # ---- training pass ----
    model.train()
    t_loss = 0
    correct = 0
    total = 0
    for i, data in enumerate(train_loader):
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs, _ = model(images)
        loss = criterion(outputs, labels)
        t_loss += loss.item()
        loss.backward()
        optimizer.step()
        # Accuracy bookkeeping works on a detached view so it never touches the graph.
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    # Mean of the per-batch losses (i is the index of the last batch).
    t_loss = t_loss / (i+1)
    t_loss = round(t_loss, 5)
    t_acc = round(100*(correct / total), 5)

    # ---- validation pass ----
    model.eval()
    v_loss = 0
    correct = 0
    total = 0
    # No gradients are needed for evaluation; disabling autograd avoids building
    # a graph for every validation batch.
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs, _ = model(images)
            loss = criterion(outputs, labels)
            v_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    v_loss = v_loss/(i+1)
    v_loss = round(v_loss, 5)
    v_acc = round(100*(correct / total), 5)

    # ---- checkpointing / early stopping ----
    if v_loss <= best_val_loss:
        torch.save(model.state_dict(), "cifar100_baseline_s.h5")
        best_val_epoch = epoch + 1
        best_val_loss = v_loss
        break_flag = 0
    else:
        break_flag += 1
    print(f'Epoch[{epoch+1}]: t_loss: {t_loss} t_acc: {t_acc} v_loss: {v_loss} v_acc: {v_acc}')
    if break_flag >9 :
        break
print('Finished Training')
print('Best model saved at epoch: ', best_val_epoch)

Epoch[1]: t_loss: 4.5125 t_acc: 2.58444 v_loss: 4.18967 v_acc: 6.48
Epoch[2]: t_loss: 3.95849 t_acc: 10.05778 v_loss: 3.81763 v_acc: 12.34
Epoch[3]: t_loss: 3.73266 t_acc: 13.84222 v_loss: 3.67824 v_acc: 15.48
Epoch[4]: t_loss: 3.56998 t_acc: 17.08889 v_loss: 3.57207 v_acc: 16.06
Epoch[5]: t_loss: 3.42936 t_acc: 19.45333 v_loss: 3.40389 v_acc: 20.28
Epoch[6]: t_loss: 3.3127 t_acc: 21.43556 v_loss: 3.34443 v_acc: 22.04
Epoch[7]: t_loss: 3.22754 t_acc: 23.14889 v_loss: 3.29716 v_acc: 22.42
Epoch[8]: t_loss: 3.12793 t_acc: 24.99556 v_loss: 3.21594 v_acc: 23.26
Epoch[9]: t_loss: 3.01663 t_acc: 27.25556 v_loss: 3.14099 v_acc: 25.3
Epoch[10]: t_loss: 2.9035 t_acc: 29.31333 v_loss: 3.06493 v_acc: 25.84
Epoch[11]: t_loss: 2.7831 t_acc: 31.51111 v_loss: 3.00909 v_acc: 27.64
Epoch[12]: t_loss: 2.65864 t_acc: 33.96889 v_loss: 2.93523 v_acc: 28.68
Epoch[13]: t_loss: 2.52694 t_acc: 36.77111 v_loss: 2.91266 v_acc: 29.62
Epoch[14]: t_loss: 2.40229 t_acc: 39.51111 v_loss: 2.88949 v_acc: 30.74
Epoch[15

In [11]:
# Evaluate the best checkpoint on the test set and collect per-sample
# predictions and labels in `pred` / `actual`.
model.load_state_dict(torch.load("cifar100_baseline_s.h5", map_location='cpu'))
# Put the model in inference mode explicitly instead of relying on whatever
# mode an earlier cell left it in.
model.eval()
correct = 0
total = 0
pred, actual = [], []
with torch.no_grad():
    for i, data in enumerate(test_loader):
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs, _ = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        # extend() appends in place; `pred = pred + [...]` re-copied the whole list every batch.
        pred.extend(predicted.cpu().numpy())
        actual.extend(labels.cpu().numpy())
print(f'Test accuracy: {100 * correct /total}')

Test accuracy: 31.26
