In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# from torchsummary import summary
from torchinfo import summary
import os
import matplotlib.pyplot as plt
import numpy as np

import sys
sys.path.append('../')
from dataset import CIFAR10_captioning
from vgg import vgg13_bn
from lstm import lstm
from encdec_model import EncoderDecoder

%matplotlib inline

  warn(f"Failed to load image Python extension: {e}")


## 1. Load CIFAR10-captioning Dataset

In [2]:
# Number of samples per mini-batch; used by both the train and test DataLoaders.
batch_size = 256

In [3]:
# Dataset root: the 'data' directory one level above the current working
# directory (<cwd>/../data). This depends on where the kernel was started.
data_dir_path = os.path.join(os.getcwd(), '../', 'data')

In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
# Convert each image to a tensor, then standardize every RGB channel with a
# fixed mean/std (presumably CIFAR-10 training-set statistics — TODO confirm).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2471, 0.2435, 0.2616),
    ),
])

# Both splits are read from data_dir_path.
# Set download=True to download the data first.
trainset = CIFAR10_captioning(root=data_dir_path, train=True, download=False, transform=transform)
testset = CIFAR10_captioning(root=data_dir_path, train=False, download=False, transform=transform)

# Only the training loader shuffles its samples.
trainloader = torch.utils.data.DataLoader(
    dataset=trainset, batch_size=batch_size, shuffle=True, num_workers=2
)
testloader = torch.utils.data.DataLoader(
    dataset=testset, batch_size=batch_size, shuffle=False, num_workers=2
)

classes = trainset.classes
print(classes)

('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


## 2. Define model

In [48]:
# Sizes handed to EncoderDecoder: number of classes and number of vocabulary entries.
encoder_output_size = len(trainset.classes)
decoder_output_size = len(trainset.vocab)

encoder = vgg13_bn(pretrained=True, device=device).to(device)

# The first and last positional arguments are both the vocabulary size.
# The meaning of 32 and 1 is defined in lstm.py (presumably hidden size and
# number of layers — TODO confirm).
decoder = lstm(decoder_output_size, 32, 1, decoder_output_size).to(device)

model = EncoderDecoder(
    encoder,
    decoder,
    encoder_output_size=encoder_output_size,
    decoder_output_size=decoder_output_size,
).to(device)

# Last expression of the cell: display the layer summary for one 3x32x32 image.
summary(model, input_size=(1, 3, 32, 32))

Layer (type:depth-idx)                   Output Shape              Param #
EncoderDecoder                           [1, 6, 16]                --
├─VGG: 1-1                               [1, 10]                   --
│    └─Sequential: 2-1                   [1, 512, 1, 1]            --
│    │    └─Conv2d: 3-1                  [1, 64, 32, 32]           1,792
│    │    └─BatchNorm2d: 3-2             [1, 64, 32, 32]           128
│    │    └─ReLU: 3-3                    [1, 64, 32, 32]           --
│    │    └─Conv2d: 3-4                  [1, 64, 32, 32]           36,928
│    │    └─BatchNorm2d: 3-5             [1, 64, 32, 32]           128
│    │    └─ReLU: 3-6                    [1, 64, 32, 32]           --
│    │    └─MaxPool2d: 3-7               [1, 64, 16, 16]           --
│    │    └─Conv2d: 3-8                  [1, 128, 16, 16]          73,856
│    │    └─BatchNorm2d: 3-9             [1, 128, 16, 16]          256
│    │    └─ReLU: 3-10                   [1, 128, 16, 16]          --
│

## 3. Training

In [49]:
# Training hyperparameters.
# NOTE(review): the logged run in this notebook shows token accuracy constant
# at ~0.833 for every epoch shown; worth checking whether these values (or the
# model itself) need tuning.
# Number of full passes over the training set.
epochs = 20
# Learning rate passed to the SGD optimizer.
lr = 0.1
# Weight-decay coefficient passed to the SGD optimizer.
weight_decay = 0.0005

In [50]:
# Loss applied to the model's per-token scores against the target token ids.
criteration = nn.CrossEntropyLoss()
# Plain SGD over all model parameters (no momentum argument is given),
# using the `lr` and `weight_decay` values defined earlier in the notebook.
optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)

In [51]:
class RunningAverage():
    """Weighted running mean of a stream of scalar values.

    Each call to ``update(value, n)`` records ``value`` as the mean of ``n``
    samples. Calling the instance returns the mean over all recorded samples.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard everything recorded so far."""
        self.sum = 0
        self.count = 0

    def update(self, value, n=1):
        """Record ``value`` with weight ``n`` (e.g. a batch mean and its batch size)."""
        self.sum += value * n
        self.count += n

    def __call__(self):
        """Return the weighted mean, or NaN when nothing has been recorded."""
        # An empty average has no meaningful value; returning NaN avoids a
        # ZeroDivisionError (e.g. when a data loader yields no batches).
        if self.count == 0:
            return float("nan")
        return self.sum / self.count

In [52]:
def train(model, criteration, optimizer, trainloader):
    """Train ``model`` for one pass over ``trainloader``.

    Relies on the notebook-level ``device`` and ``RunningAverage``.

    Args:
        model: Called as ``model(inputs)``; its output is treated as
            (batch_size, target_length, vocab_size) scores.
        criteration: Loss callable, invoked with the scores permuted to
            (batch_size, vocab_size, target_length) and the target token ids.
        optimizer: Optimizer stepped once per batch.
        trainloader: Iterable yielding ``(inputs, (captions, labels))``.
            ``labels`` is not used here.

    Returns:
        Tuple ``(mean_loss, mean_token_accuracy)``, both weighted by batch size.
    """
    model.train()

    running_training_loss = RunningAverage()
    running_training_acc = RunningAverage()

    for inputs, (captions, _labels) in trainloader:
        inputs = inputs.to(device)
        captions = captions.to(device)

        # Targets skip the first caption position (presumably a start token — TODO confirm).
        targets = captions[:, 1:]
        n_samples = inputs.size(0)
        # Derive the target length from the batch instead of hard-coding 6,
        # matching how test() computes it.
        target_length = captions.size(1) - 1

        optimizer.zero_grad()

        outputs = model(inputs)  # (batch_size, target_length, vocab_size)
        # CrossEntropyLoss expects the class dimension second: (batch, vocab, length).
        loss = criteration(outputs.permute(0, 2, 1), targets)
        loss.backward()
        optimizer.step()

        running_training_loss.update(loss.item(), n_samples)

        # detach() replaces the deprecated `.data`; no gradient is needed for metrics.
        predicted = outputs.detach().argmax(dim=2)
        correct = (predicted == targets).sum().item()
        running_training_acc.update(correct / (n_samples * target_length), n_samples)

    return running_training_loss(), running_training_acc()

In [53]:
def test(model, criteration, testloader, n_examples=-1):
    """Evaluate ``model`` over ``testloader`` without updating any weights.

    Relies on the notebook-level ``device`` and ``RunningAverage``.

    Args:
        model: Called as ``model(inputs)``; its output is treated as
            (batch_size, target_length, vocab_size) scores.
        criteration: Loss callable, invoked with the scores permuted to
            (batch_size, vocab_size, target_length) and the target token ids.
        testloader: Iterable yielding ``(inputs, (captions, labels))``.
            ``labels` is not used here.
        n_examples: When not -1, also return the first ``n_examples`` inputs
            and predicted token ids of the LAST batch. The loader must then
            yield at least one batch.

    Returns:
        ``(mean_loss, mean_token_accuracy)`` when ``n_examples == -1``,
        otherwise ``(mean_loss, mean_token_accuracy, inputs, predicted)``.
        The averages are weighted by batch size.
    """
    model.eval()

    running_test_loss = RunningAverage()
    running_test_acc = RunningAverage()

    # The whole evaluation runs without gradient tracking. This replaces the
    # deprecated `.data` access that was used for the predictions.
    with torch.no_grad():
        for inputs, (captions, _labels) in testloader:
            inputs = inputs.to(device)
            captions = captions.to(device)

            # Targets skip the first caption position (presumably a start token — TODO confirm).
            targets = captions[:, 1:]
            n_samples = inputs.size(0)
            target_length = captions.size(1) - 1

            outputs = model(inputs)
            # CrossEntropyLoss expects the class dimension second: (batch, vocab, length).
            loss = criteration(outputs.permute(0, 2, 1), targets)

            running_test_loss.update(loss.item(), n_samples)

            predicted = outputs.argmax(dim=2)
            correct = (predicted == targets).sum().item()
            running_test_acc.update(correct / (n_samples * target_length), n_samples)

    if n_examples != -1:
        # `inputs` and `predicted` still hold the last batch of the loop.
        return running_test_loss(), running_test_acc(), inputs[:n_examples], predicted[:n_examples]
    else:
        return running_test_loss(), running_test_acc()

In [56]:
# One training pass followed by one evaluation pass per epoch.
for e in range(epochs):
    train_loss, train_acc = train(model, criteration, optimizer, trainloader)

    if e != epochs - 1:
        test_loss, test_acc = test(model, criteration, testloader)
    else:
        # Final epoch only: also keep 4 example images and their predicted captions.
        test_loss, test_acc, example_images, example_captions = test(
            model, criteration, testloader, n_examples=4
        )

    print(f"Epoch: {e}, Training Loss: {train_loss}, Training Accuracy: {train_acc}, Test Loss: {test_loss}, Test Accuracy: {test_acc}")

Epoch: 0, Training Loss: 2.0714442247009277, Training Accuracy: 0.8333333333333338, Test Loss: 2.066313655471802, Test Accuracy: 0.833333333333333
Epoch: 1, Training Loss: 2.0628743252563475, Training Accuracy: 0.8333333333333338, Test Loss: 2.059933357620239, Test Accuracy: 0.833333333333333
Epoch: 2, Training Loss: 2.0578512464904786, Training Accuracy: 0.8333333333333338, Test Loss: 2.0560204425811768, Test Accuracy: 0.833333333333333
Epoch: 3, Training Loss: 2.0546613621520997, Training Accuracy: 0.8333333333333338, Test Loss: 2.0534371028900145, Test Accuracy: 0.833333333333333
Epoch: 4, Training Loss: 2.052493119735718, Training Accuracy: 0.8333333333333338, Test Loss: 2.051624613571167, Test Accuracy: 0.833333333333333
Epoch: 5, Training Loss: 2.050934086532593, Training Accuracy: 0.8333333333333338, Test Loss: 2.050286424255371, Test Accuracy: 0.833333333333333
Epoch: 6, Training Loss: 2.049757568588257, Training Accuracy: 0.8333333333333338, Test Loss: 2.049253356552124, Test 

In [58]:
# Predicted token ids for the 4 example images taken from the last test batch
# of the final epoch.
example_captions

tensor([[ 0,  1,  2,  3, 15, 15],
        [ 0,  1,  2,  3, 15, 15],
        [ 0,  1,  2,  3, 15, 15],
        [ 0,  1,  2,  3, 15, 15]])