In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchsummary import summary
import os
import matplotlib.pyplot as plt
import numpy as np

import sys
sys.path.append('../')
from dataset import CIFAR10_captioning
from vgg import vgg13_bn
from lstm import lstm
from encdec_model import EncoderDecoder

%matplotlib inline

## 1. Load CIFAR10-captioning Dataset

In [None]:
# Number of samples per mini-batch; passed to both the training and the test DataLoader.
batch_size = 256

In [None]:
# Dataset root: the "data" directory one level above the kernel's current
# working directory. The path is built from os.getcwd(), so it is only
# correct when the kernel is started from the folder containing this notebook.
data_dir_path = os.path.join(
    os.path.join(os.getcwd(), '../'),
    'data',
)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Image pre-processing: convert to a tensor, then normalise each of the three
# channels with a fixed mean / std (presumably CIFAR-10 statistics - confirm).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2471, 0.2435, 0.2616),
    ),
])

# Set download=True to download the data.
trainset = CIFAR10_captioning(
    root=data_dir_path,
    train=True,
    download=False,
    transform=transform,
)
# Training batches are reshuffled every epoch.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

testset = CIFAR10_captioning(
    root=data_dir_path,
    train=False,
    download=False,
    transform=transform,
)
# Test batches keep the dataset order.
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
)

# Class names as exposed by the training split.
classes = trainset.classes
print(classes)

## 2. Define model

In [None]:
# Image encoder: VGG-13 with batch norm, created with pretrained=True.
encoder = vgg13_bn(pretrained=True, device=device)
encoder = encoder.to(device)

# Caption decoder. The last positional argument is the vocabulary size.
# NOTE(review): the meaning of the first three positional arguments (10, 16, 1)
# is defined by `lstm` in lstm.py and is not visible here - confirm and
# consider passing them by keyword.
decoder = lstm(10, 16, 1, len(trainset.vocab))
decoder = decoder.to(device)

# Combine both halves into the captioning model and move it to the device.
model = EncoderDecoder(encoder, decoder)
model = model.to(device)

# Print a layer-by-layer summary for a single 3x32x32 input.
# NOTE(review): the training loop calls model(inputs, target_length); confirm
# that the model can also be called with the image tensor alone, which is how
# summary() is expected to invoke it.
summary(model, (3, 32, 32))

## 3. Training

In [None]:
# Optimisation hyper-parameters.
epochs = 100           # number of full passes over the training loader
lr = 1e-3              # learning rate handed to the optimizer
weight_decay = 5e-4    # weight decay handed to the optimizer

In [None]:
# Loss function used for the caption tokens.
# NOTE: the variable name is spelled "critertion" on purpose - the training
# loop refers to it by this exact name.
critertion = nn.CrossEntropyLoss()

# Adam over every parameter of the encoder-decoder model.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

In [20]:
class RunningAverage():
    """Accumulate a weighted running mean of scalar values.

    Call ``update(value, n)`` to add a value that represents ``n`` samples,
    and call the instance itself to read the current mean. ``reset()`` clears
    the accumulated state.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard everything accumulated so far."""
        self.sum = 0
        self.count = 0

    def update(self, value, n=1):
        """Add ``value`` with weight ``n`` (e.g. a batch mean and the batch size)."""
        self.sum += value * n
        self.count += n

    def __call__(self):
        """Return the weighted mean, or 0.0 if nothing has been recorded yet."""
        # Guard against ZeroDivisionError when queried before the first
        # update() or right after reset().
        if self.count == 0:
            return 0.0
        return self.sum / self.count

In [None]:
# Meters for the mean token-level loss and the token-level accuracy.
running_training_loss = RunningAverage()
running_training_acc = RunningAverage()

# Make sure layers such as batch norm / dropout are in training mode, even if
# the model was switched to eval mode earlier in this kernel session.
model.train()

for epoch in range(epochs):
    # Start every epoch with fresh meters so the printed numbers describe the
    # current epoch instead of an average over the whole run so far.
    running_training_loss.reset()
    running_training_acc.reset()

    for i, (inputs, captions, _labels) in enumerate(trainloader):
        inputs = inputs.to(device)
        captions = captions.to(device)

        # The first caption token is skipped; the model is asked to produce
        # the remaining captions.size(1) - 1 tokens.
        targets = captions[:, 1:]
        target_length = targets.size(1)

        optimizer.zero_grad()

        outputs = model(inputs, target_length)  # (batch_size, target_length, vocab_size)

        # CrossEntropyLoss expects the class scores on dim 1. Flatten the
        # batch and time axes so the logits are (batch * length, vocab_size)
        # and the targets are (batch * length,).
        loss = critertion(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
        loss.backward()
        optimizer.step()

        # The loss is a mean over tokens, so weight it by the token count.
        num_tokens = targets.numel()
        running_training_loss.update(loss.item(), num_tokens)

        # Token-level accuracy: fraction of positions whose arg-max prediction
        # equals the target token.
        predicted = outputs.detach().argmax(dim=2)
        correct = (predicted == targets).sum().item()
        running_training_acc.update(correct / num_tokens, num_tokens)

        if i % 100 == 0:
            print(f"Epoch: {epoch}, Batch: {i}, Loss: {running_training_loss()}, Accuracy: {running_training_acc()}")