In [39]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset, Dataset
from torch.optim import SGD, Adam
from torchvision.datasets import CIFAR100
from torchvision import transforms


import matplotlib.pyplot as plt

In [5]:
device = torch.device('cuda')

## Augmentation pipeline

### For Input

In [6]:
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomAutocontrast(),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), # for standardisation with mean and variance
])

In [7]:
transform_other = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

## Download Data

In [8]:
image_path = './'

In [9]:
# Applies augmentation while loading the dataset
cifar_train = CIFAR100( 
     root=image_path, 
     train=True ,
     transform = transform_train,
     download = True
)
cifar_test = CIFAR100( 
     root=image_path, 
     train=False,
     transform = transform_other,
     download = True
)


Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting ./cifar-100-python.tar.gz to ./
Files already downloaded and verified


## Dataset preparation
- **Note**: I avoid using the entire training data as it is computationally expensive. So I stick to using subset of training data. To cope up with performance I use augmentations.

In [10]:
print(f'Total train data: {len(cifar_train)}')
print(f'Total test data: {len(cifar_test)}')


Total train data: 50000
Total test data: 10000


In [11]:
torch.manual_seed(2)
# index_train = torch.randperm(50000)
index_test_val = torch.randperm(10000)

In [12]:
# data_train = Subset(cifar_train, index_train[:20000])
data_val = Subset(cifar_test, index_test_val[:5000])
data_test = Subset(cifar_test, index_test_val[5000:])

In [32]:
data = DataLoader(cifar_train, batch_size=256, shuffle=True)
data_vl = DataLoader(data_val, batch_size=64)

## Model, Loss, Optimizer

In [28]:
model = nn.Sequential(
    nn.Conv2d(3, 96, 3),
    nn.ReLU(),
    nn.Conv2d(in_channels=96,out_channels=96, stride=2, kernel_size=3),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Conv2d(in_channels=96,out_channels=192, kernel_size=3),
    nn.ReLU(),
    nn.Conv2d(in_channels=192,out_channels=192, stride=2, kernel_size=3),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.BatchNorm2d(192),
    nn.Flatten(1),
    nn.Linear(192* 5 * 5, 256),
    nn.ReLU(),
    nn.Linear(256, 100)
)

In [33]:
next(iter(data))[0].shape

torch.Size([256, 3, 32, 32])

In [36]:
model(next(iter(data))[0]).shape

torch.Size([256, 100])

In [37]:
model = model.to(device)
model

Sequential(
  (0): Conv2d(3, 96, kernel_size=(3, 3), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2))
  (3): ReLU()
  (4): Dropout(p=0.2, inplace=False)
  (5): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1))
  (6): ReLU()
  (7): Conv2d(192, 192, kernel_size=(3, 3), stride=(2, 2))
  (8): ReLU()
  (9): Dropout(p=0.5, inplace=False)
  (10): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (11): Flatten(start_dim=1, end_dim=-1)
  (12): Linear(in_features=4800, out_features=256, bias=True)
  (13): ReLU()
  (14): Linear(in_features=256, out_features=100, bias=True)
)

In [41]:
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.01)

## Train loop

In [42]:
def train(model ,epoch, data, data_vl):
    train_loss = [0] * epoch
    train_accuracy = [0] * epoch
    
    valid_loss = [0] * epoch
    valid_accuracy = [0] * epoch
    
    for i in (range(epoch)):
        model.train()
        for x_batch, y_batch in data:

            x_batch = x_batch.to(device)
            y_batch = torch.tensor(y_batch)
            y_batch = y_batch.to(device)

            pred = model(x_batch)
            loss = loss_fn(pred, y_batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            
            train_loss[i] += loss.item() * x_batch.size(0)
            correct = torch.sum(torch.argmax(torch.softmax(pred,axis=1), axis=1) == y_batch)
#             print(correct, x_batch.size(0))
            train_accuracy[i] += correct.to(torch.float32).item()
        print(train_accuracy[i], len(data.dataset))
        train_loss[i] /= len(data.dataset)
        train_accuracy[i] /= len(data.dataset)
        
        model.eval()
        with torch.no_grad():
            for x_batch, y_batch in data_vl:
                x_batch = x_batch.to(device)
                y_batch = torch.tensor(y_batch)
                y_batch = y_batch.to(device)
                pred = model(x_batch)
                loss = loss_fn(pred, y_batch)
                valid_loss[i] += loss.item() * x_batch.size(0)
                correct = torch.sum(torch.argmax(torch.softmax(pred,axis=1), axis=1) == y_batch)
                valid_accuracy[i] += correct.item()
            valid_loss[i] /= len(data_vl.dataset)
            valid_accuracy[i] /= len(data_vl.dataset)
        print(f'Epoch {i+1} accuracy: {train_accuracy[i]:.4f} val_accuracy:{valid_accuracy[i]:.4f}')
        print(f'Epoch {i+1} loss: {train_loss[i]:.4f} val_loss:{valid_loss[i]:.4f}')
        print()
    return train_loss, train_accuracy, valid_loss, valid_accuracy

In [43]:
epoch = 50
train_loss, train_accuracy, valid_loss, valid_accuracy = train(model, epoch, data, data_vl)

  del sys.path[0]


3412.0 50000




Epoch 1 accuracy: 0.0682 val_accuracy:0.0832
Epoch 1 loss: 4.1190 val_loss:3.9113

5648.0 50000
Epoch 2 accuracy: 0.1130 val_accuracy:0.1304
Epoch 2 loss: 3.7855 val_loss:4.1675

7029.0 50000
Epoch 3 accuracy: 0.1406 val_accuracy:0.1590
Epoch 3 loss: 3.6275 val_loss:3.4847

8278.0 50000
Epoch 4 accuracy: 0.1656 val_accuracy:0.1938
Epoch 4 loss: 3.4718 val_loss:3.9393

9549.0 50000
Epoch 5 accuracy: 0.1910 val_accuracy:0.2106
Epoch 5 loss: 3.3525 val_loss:21.9745

10328.0 50000
Epoch 6 accuracy: 0.2066 val_accuracy:0.2226
Epoch 6 loss: 3.2483 val_loss:3.2063

11341.0 50000
Epoch 7 accuracy: 0.2268 val_accuracy:0.2574
Epoch 7 loss: 3.1552 val_loss:3.0153

12438.0 50000
Epoch 8 accuracy: 0.2488 val_accuracy:0.2742
Epoch 8 loss: 3.0367 val_loss:2.8978

13157.0 50000
Epoch 9 accuracy: 0.2631 val_accuracy:0.3038
Epoch 9 loss: 2.9469 val_loss:2.8000

13738.0 50000
Epoch 10 accuracy: 0.2748 val_accuracy:0.3090
Epoch 10 loss: 2.8724 val_loss:2.7411

14391.0 50000
Epoch 11 accuracy: 0.2878 val_a

## Remarks
- Unless or untill I use more sophistications to avoid overfitting and learn reasonable patterns from such a small image the performance wont be that high