In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchvision import transforms
from torchvision.utils import make_grid
from torchvision.datasets import CIFAR10

from torch.utils.data import random_split
from torch.utils.data.dataloader import DataLoader

import matplotlib.pyplot as plt
%matplotlib inline

## Prepare data

### Download the CIFAR10 dataset

In [None]:
# Channel statistics handed to Normalize in both pipelines
# (the commonly used ImageNet mean / std values).
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Training pipeline: resize to 224x224, random horizontal flip (p=0.7),
# convert to a tensor, then normalize each channel.
train_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.7),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
)

# Evaluation pipeline: identical to the training one, minus the random flip.
test_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
)

# Official CIFAR10 training split (50K images; 40K remain for training once
# the 10K validation subset is split off in a later cell).
train_ds = CIFAR10(root="data/", train=True, download=True, transform=train_transform)

# Official CIFAR10 test split (10K images).
test_ds = CIFAR10(root="data/", train=False, download=True, transform=test_transform)

In [None]:
# Split the training data into train and validation subsets.
#
# random_split draws a random permutation; passing a seeded generator makes
# the split identical on every run, so validation numbers are comparable
# across kernel restarts.
#
# NOTE: this cell rebinds train_ds, so re-running it without re-running the
# dataset cell first would split the already reduced subset again.
# NOTE(review): validation_ds is a view onto the same underlying dataset as
# train_ds, so it also goes through train_transform (including the random flip).
validation_size = 10000
train_size = len(train_ds) - validation_size
split_generator = torch.Generator().manual_seed(42)
train_ds, validation_ds = random_split(
    train_ds, [train_size, validation_size], generator=split_generator
)

In [None]:
batch_size = 54

# Validation and test loaders share the same settings: fixed order, no shuffling.
eval_loader_kwargs = dict(batch_size=batch_size, shuffle=False)

# Only the training loader reshuffles its samples every epoch.
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
validation_dl = DataLoader(validation_ds, **eval_loader_kwargs)
test_dl = DataLoader(test_ds, **eval_loader_kwargs)

### Preview the train dataset

In [None]:
# Take a single batch from the training loader for the preview.
images, _ = next(iter(train_dl))

# The dataset transform normalizes every channel with these statistics.
# Undo that before plotting: imshow clips float RGB data to [0, 1], so the
# normalized values would be shown with distorted colours.
channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
channel_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
images = (images * channel_std + channel_mean).clamp(0, 1)

plt.figure(figsize=(16, 8))
plt.axis('off')

# make_grid tiles the batch into one image with 2px of padding between tiles;
# nrow is the number of images per grid row, i.e. the grid is 16 images wide.
imgs_grid = make_grid(images, nrow=16)

# permute() moves the channel axis from first to last (H, W, C),
# which is the layout matplotlib expects.
plt.imshow(imgs_grid.permute(1, 2, 0))
plt.show()

### Preview the test dataset

In [None]:
# Take the first batch from the test loader for the preview.
images, _ = next(iter(test_dl))

# The dataset transform normalizes every channel with these statistics.
# Undo that before plotting: imshow clips float RGB data to [0, 1], so the
# normalized values would be shown with distorted colours.
channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
channel_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
images = (images * channel_std + channel_mean).clamp(0, 1)

plt.figure(figsize=(16, 8))
plt.axis('off')

# make_grid tiles the batch into one image with 2px of padding between tiles;
# nrow is the number of images per grid row, i.e. the grid is 16 images wide.
imgs_grid = make_grid(images, nrow=16)

# permute() moves the channel axis from first to last (H, W, C),
# which is the layout matplotlib expects.
plt.imshow(imgs_grid.permute(1, 2, 0))
plt.show()

### Preview the validation dataset

In [None]:
# Take the first batch from the validation loader for the preview.
images, _ = next(iter(validation_dl))

# The dataset transform normalizes every channel with these statistics.
# Undo that before plotting: imshow clips float RGB data to [0, 1], so the
# normalized values would be shown with distorted colours.
channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
channel_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
images = (images * channel_std + channel_mean).clamp(0, 1)

plt.figure(figsize=(16, 8))
plt.axis('off')

# make_grid tiles the batch into one image with 2px of padding between tiles;
# nrow is the number of images per grid row, i.e. the grid is 16 images wide.
imgs_grid = make_grid(images, nrow=16)

# permute() moves the channel axis from first to last (H, W, C),
# which is the layout matplotlib expects.
plt.imshow(imgs_grid.permute(1, 2, 0))
plt.show()

## Build the VGG-16 model

In [None]:
class VGG16(nn.Module):
    """VGG-16 style classifier: 13 convolutional layers plus 3 fully connected layers.

    Every convolution is 3x3 with padding 1, so only the five 2x2 / stride-2
    max-pool steps change the spatial size. ``fc1`` expects 25088 = 512 * 7 * 7
    flattened features, which is what a 224x224 input is reduced to after the
    five pooling steps.

    Args:
        num_classes: Number of output logits produced by the last layer.
            Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super(VGG16, self).__init__()

        # Block 1: 3 -> 64 channels
        self.conv1_1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)

        # Block 2: 64 -> 128 channels
        self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)

        # Block 3: 128 -> 256 channels
        self.conv3_1 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.conv3_2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.conv3_3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)

        # Block 4: 256 -> 512 channels
        self.conv4_1 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.conv4_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv4_3 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)

        # Block 5: 512 -> 512 channels
        self.conv5_1 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv5_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv5_3 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)

        # A single parameter-free pooling layer is reused after every block.
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Classifier head
        self.fc1 = nn.Linear(25088, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Image batch of shape (N, 3, H, W). H and W must reduce to 7x7
                after the five pooling steps (e.g. 224x224) to match ``fc1``.

        Returns:
            Tensor of shape (N, num_classes) with unnormalized scores.
        """
        conv_blocks = (
            (self.conv1_1, self.conv1_2),
            (self.conv2_1, self.conv2_2),
            (self.conv3_1, self.conv3_2, self.conv3_3),
            (self.conv4_1, self.conv4_2, self.conv4_3),
            (self.conv5_1, self.conv5_2, self.conv5_3),
        )

        y = x
        for block in conv_blocks:
            for conv in block:
                y = F.relu(conv(y))
            # Halve the spatial resolution at the end of every block.
            y = self.maxpool(y)

        # Flatten (N, C, H, W) -> (N, C*H*W); for 224x224 inputs this is
        # (N, 512, 7, 7) -> (N, 25088).
        y = y.reshape(y.shape[0], -1)

        # F.dropout defaults to training=True, which would keep dropout active
        # after model.eval(); tie it to the module's mode so that inference
        # is deterministic.
        y = F.relu(self.fc1(y))
        y = F.dropout(y, p=0.5, training=self.training)

        y = F.relu(self.fc2(y))
        y = F.dropout(y, p=0.5, training=self.training)

        y = self.fc3(y)

        return y




### Prepare for training

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

In [None]:
# Build the network and move its parameters to the selected device.
model = VGG16().to(device)

# Cross-entropy loss on the raw logits, optimized with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# NOTE(review): this flag is not read anywhere in the visible part of the notebook.
load_model = True

In [None]:
# Train for 50 epochs; after every epoch report the mean training loss and
# the accuracy on the validation set.
for epoch in range(50):
    print(f"Epoch: {epoch}")

    # --- training pass ---
    # Put mode-dependent layers into training behaviour (validation below
    # switches the model to eval mode at the end of every epoch).
    model.train()
    loss_ep = 0

    for data, targets in train_dl:
        data = data.to(device=device)
        targets = targets.to(device=device)

        optimizer.zero_grad()
        # Call the module itself rather than .forward() so that registered
        # hooks are not bypassed.
        output = model(data)
        loss = criterion(output, targets)

        loss.backward()
        optimizer.step()

        loss_ep += loss.item()
    print(f"Loss in epoch {epoch} :::: {loss_ep/len(train_dl)}")

    # --- validation pass ---
    model.eval()
    num_correct = 0
    num_samples = 0

    with torch.no_grad():
        for data, targets in validation_dl:
            data = data.to(device=device)
            targets = targets.to(device=device)

            scores = model(data)
            _, predictions = scores.max(1)
            # .item() keeps the counter a plain Python int, so the report
            # below prints a number instead of a tensor repr.
            num_correct += (predictions == targets).sum().item()
            num_samples += predictions.size(0)

    print(
        f"Got {num_correct} / {num_samples} with accuracy {float(num_correct) / float(num_samples) * 100:.2f}"
    )


### Save the trained model

In [None]:
# Persist only the learned parameters (the state dict), not the whole module object.
trained_weights = model.state_dict()
torch.save(trained_weights, "vgg16_cifar.pt")

### Load the trained model

In [None]:
# Rebuild the architecture and restore the saved parameters.
# map_location="cpu" lets a checkpoint written from a CUDA model be loaded on
# a machine without a GPU; the freshly built model lives on the CPU anyway
# and is moved to the target device later.
model = VGG16()
model.load_state_dict(torch.load("vgg16_cifar.pt", map_location="cpu"))

In [None]:
# Ask PyTorch to release the unused GPU memory it is caching before testing.
torch.cuda.empty_cache()

### Test the model

In [None]:
num_correct = 0
num_samples = 0

torch.cuda.empty_cache()

model.eval()
model = model.to(device=device)

for batch_idx, (inputs, outputs) in enumerate(test_dl):
    inputs = inputs.to(device=device)
    outputs = outputs.to(device=device)

    scores = model(inputs)
    
    #print(scores)
    break
    _, predictions = scores.max(1)