In [None]:
# need this library to handle pickling bc Colab doesn't support protocol 5 which we used when pickling
# !pip install pickle5

In [1]:
from tqdm import tqdm
import os
import torch
import torchvision
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms, models
import numpy as np
from random import shuffle
import pickle
from itertools import chain
from math import ceil
import matplotlib.pyplot as plt

### Add to console to prevent runtime from disconnecting
```
function ClickConnect(){
    console.log("Clicked to stay connected"); 
    document.getElementById("header-background").click()
}
setInterval(ClickConnect,60000)
```

### Load Data from Google Cloud

In [None]:
!gcloud auth login --no-launch-browser

In [None]:
!mkdir train_arrays
!mkdir test_arrays
!mkdir train_images
!mkdir test_images

In [None]:
!gsutil -m rsync gs://sai_training/train_images train_images
!gsutil -m rsync gs://sai_training/test_images test_images
!gsutil -m rsync gs://sai_training/train_arrays train_arrays
!gsutil -m rsync gs://sai_training/test_arrays test_arrays

In [2]:
torch.cuda.is_available()

True

In [3]:
# load model and move it to GPU
gpu = torch.device('cuda')
cpu = torch.device('cpu')
efficientnet_b6 = models.efficientnet_b6(pretrained = True)
efficientnet_b6.eval()
efficientnet_b6.to(gpu)

EfficientNet(
  (features): Sequential(
    (0): ConvNormActivation(
      (0): Conv2d(3, 56, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(56, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): ConvNormActivation(
            (0): Conv2d(56, 56, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=56, bias=False)
            (1): BatchNorm2d(56, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(56, 14, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(14, 56, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): ConvNormActivatio

In [4]:
# use optimized image loader
torchvision.set_image_backend('accimage')

In [5]:
# normalize using the convention for all pretrained torchvision classifications models
normalize = transforms.Compose([
    transforms.Lambda(lambda x: x.float()),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# apply some data augmenting/model resiliency techniques and then normalize
augment_and_normalize = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    # transforms.ColorJitter(brightness = (0.5,1.2), saturation = 0.5, contrast = (0.2, 2), hue = 0.08),
    normalize
])

In [6]:
num_captions_per_image = 5

class ImageCaptionDataset(Dataset):
    def __init__(self, img_dir, caption_array_dir, id_list, transform = None):
        # assumes that captions are downloaded as jpgs (with no extra processing) and saved in the folder img_dir
        self.img_dir = img_dir
        self.num_images = sum(filename[-4:] == '.jpg' for filename in os.listdir(self.img_dir))
        self.num_captions = num_captions_per_image * self.num_images
        # assumes that captions are already preprocessed and represented as numpy arrays in the folder caption_array_dir
        self.caption_array_dir = caption_array_dir
        # list of image ids used for both images and caption arrays
        self.id_list = id_list
        self.transform = transform if transform else normalize

    def __len__(self):
        return self.num_captions

    def __getitem__(self, index):
        # find corresponding image and caption indices
        i = index % num_captions_per_image
        index = index // num_captions_per_image
        # find image id
        index = self.id_list[index]
        # filenames are of the form id.jpg where the id is padded with zeroes to the left until it has length 12
        filename = str(index).zfill(12) + '.jpg'
        # caption arrays have format id_i.jpg where id is not padded with zeroes
        with open(f"{self.caption_array_dir}/{index}_{i}.npy", mode = "rb") as f:
            arr = np.load(f)
        img = torchvision.io.read_image(f"{self.img_dir}/{filename}")
        # convert to RGB if grayscale
        if img.shape[0] == 1:
            img = img.repeat(3, 1, 1)
        elif img.shape[0] != 3:
            print("improper shape: ", tuple(img.shape))
            return
        # apply transform for images and just create an equivalent tensor for caption array
        return torch.from_numpy(arr).float(), self.transform(img)

In [7]:
class SimilarSizeBatchSampler(Sampler):
    def __init__(self, caption_lengths, batch_size):
        self.length = ceil(len(caption_lengths) / batch_size)
        pairs = list(enumerate(caption_lengths))
        # shuffle first so that the sorting doesn't produce the same ordering every single time
        shuffle(pairs)
        reordered_pairs = sorted(pairs, key = lambda x : x[1])
        # grab index i.e. the first value from every pair
        reordered_indices = list(zip(*reordered_pairs))[0]
        reordered = list(reordered_indices)
        batches = []
        for i in range(self.length):
            batches.append(reordered[i * batch_size : (i + 1) * batch_size])
        # each batch contains samples that are similar in length, but we want to change up the order
        # for the different sized batches in order to make the mini batch gradient "more stochastic"
        # while still having each mini batch be made up of samples with similar length 
        shuffle(batches)
        self.batches = batches

    def __iter__(self):
        return iter(self.batches)

    def __len__(self):
        return self.length

In [8]:
train_image_dir = 'train_images'
train_caption_dir = 'train_arrays'
# retrieve saved id list
with open("train_ids.pkl", mode = "rb") as f:
    train_ids = pickle.load(f)
train_dataset = ImageCaptionDataset(train_image_dir, train_caption_dir, train_ids, transform = augment_and_normalize)

In [9]:
batch_size = 50
num_workers = 0
# retrieve saved caption length list
with open("train_caption_lengths.pkl", mode = "rb") as f:
    train_caption_lengths = pickle.load(f)

# pads tensor by adding zeroes at top to extend tensor length to length
# this is necessary since the sequences have different lengths
def pad_top(tensor, length):
    m, n = tensor.shape
    assert length >= m, f"tensor is already too long: {m} > {length}"
    return torch.cat((torch.zeros(length - m, n), tensor))

# stacks images and caption arrays after padding arrays to make all of them the same length
def custom_collate(batch):
    captions, imgs = list(zip(*batch))
    imgs = torch.stack(imgs)
    length = max(caption.shape[0] for caption in captions)
    captions = torch.stack([pad_top(caption, length) for caption in captions])
    return captions, imgs

train_dataloader = DataLoader(train_dataset, collate_fn = custom_collate, num_workers = num_workers,
                              batch_sampler = SimilarSizeBatchSampler(train_caption_lengths, batch_size))
train_iter = iter(train_dataloader)

In [10]:
test_image_dir = 'test_images'
test_caption_dir = 'test_arrays'
# retrieve saved id list
with open("test_ids.pkl", mode = "rb") as f:
    test_ids = pickle.load(f)
test_dataset = ImageCaptionDataset(test_image_dir, test_caption_dir, test_ids)

In [11]:
test_batch_size = 200
test_num_workers = 10
# retrieve saved caption length list
with open("test_caption_lengths.pkl", mode = "rb") as f:
    test_caption_lengths = pickle.load(f)

test_dataloader = DataLoader(test_dataset, collate_fn = custom_collate, num_workers = test_num_workers,
                             batch_sampler = SimilarSizeBatchSampler(test_caption_lengths, test_batch_size))
test_iter = iter(test_dataloader)

In [12]:
word_dim = 301
hidden_dim = 1200
num_layers = 1
output_dim = 1000
lstm = nn.LSTM(input_size = word_dim, hidden_size = hidden_dim, num_layers = num_layers, batch_first = True, proj_size = output_dim)
seq = nn.Sequential(
    nn.GELU(),
    nn.Linear(1000, 1000),
    nn.Softmax(dim = 1),
)

lstm.to(gpu)
seq.to(gpu)

def model(batch):
    # first element of returned pair is tensor of output states
    # and we want to grab the final output state
    return seq(lstm(batch)[0][:, -1, :])

cos_sim = nn.CosineSimilarity(dim = 1)
optimizer = optim.Adam(chain(lstm.parameters(), seq.parameters()), lr=1e-3, weight_decay=1e-5)

In [None]:
train_losses = []
test_losses= []
train_length = len(train_ids) * num_captions_per_image
test_length = len(test_ids) * num_captions_per_image

for epoch in tqdm(range(1), desc="Epoch"):
    total_loss = 0
    counter = 0
    for train_captions, train_images in train_iter:
        print("start", torch.cuda.memory_allocated())
        n = train_images.shape[0]
        train_captions, train_images = train_captions.to(gpu), train_images.to(gpu)
        with torch.no_grad():
            train_classes = efficientnet_b6.forward(train_images)
            del train_images
        train_outputs = model(train_captions)
        loss = torch.mean(cos_sim(train_outputs, train_classes))
        total_loss += loss.item() * n
        loss.backward()
        optimizer.step()
        print("post step", torch.cuda.memory_allocated())
        optimizer.zero_grad()
        counter += 1
        if counter == 10: break
    train_losses.append(total_loss / train_length)

    # with torch.no_grad():
    #     test_loss = 0
    #     for test_captions, test_images in test_iter:
    #         n = test_images.shape[0]
    #         test_captions, test_images = test_captions.to(gpu), test_images.to(gpu)
    #         test_classes = efficientnet_b6.forward(test_images)
    #         test_outputs = model(test_captions)
    #         test_loss += torch.sum(cos_sim(test_outputs, test_classes)).item()
    #     test_losses.append(test_loss / test_length)
    #     print("test: ", torch.cuda.memory_allocated())

In [None]:
plt.plot(train_losses, label="Train Loss")
plt.plot(test_losses, label="Test Loss")
plt.set_title("Loss Over Epochs")
plt.legend()
plt.show()

In [None]:
torch.rand(3, 4, 5)[:,-1,:].shape