In [None]:
# need this library to handle pickling bc Colab doesn't support protocol 5 
# which we used when pickling
!pip install pickle5

In [1]:
from tqdm import tqdm
import os
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import numpy as np
from random import randint
import pickle5 as pickle

### Add to the browser console to prevent the Colab runtime from disconnecting
```
function ClickConnect(){
    console.log("Clicked on connect button"); 
    document.querySelector("#ok").click()
}
setInterval(ClickConnect,60000)
```

### Load Data from Google Cloud

In [None]:
# authenticate with Google Cloud so the storage bucket can be read;
# --no-launch-browser prints a URL to open manually instead of launching a browser
!gcloud auth login --no-launch-browser

In [7]:
!mkdir train_arrays
!mkdir test_arrays
!mkdir train_images
!mkdir test_images

In [None]:
# copy each bucket folder into the local folder of the same name; -m runs the transfers in parallel
# NOTE(review): without -r only top-level objects are synced — confirm the bucket folders are flat
!gsutil -m rsync gs://sai_data/train_images train_images
!gsutil -m rsync gs://sai_data/test_images test_images
!gsutil -m rsync gs://sai_data/train_arrays train_arrays
!gsutil -m rsync gs://sai_data/test_arrays test_arrays

In [None]:
# sanity check: displays True when PyTorch can see a CUDA GPU
torch.cuda.is_available()

In [3]:
# load model and move it to GPU
gpu = torch.device('cuda')
cpu = torch.device('cpu')
# pretrained = True requests torchvision's pretrained weights for EfficientNet-B6
efficientnet_b6 = models.efficientnet_b6(pretrained = True)
# eval() puts the module in evaluation mode (affects layers such as dropout and batch norm)
efficientnet_b6.eval()
# nn.Module.to moves the parameters in place; as the last expression it also displays the model
efficientnet_b6.to(gpu)

In [13]:
# use optimized image loader
# NOTE(review): this selects the backend for torchvision's PIL-style image loaders; the dataset
# in this notebook reads files with torchvision.io.read_image, which presumably does not go
# through this backend — confirm the setting is still needed and that accimage is installed
torchvision.set_image_backend('accimage')

In [5]:
# normalize using the convention for all pretrained torchvision classification models
normalize = transforms.Compose([
    # the ImageNet mean/std below assume pixel values in [0, 1]; images read as uint8
    # are in [0, 255], so convert AND rescale instead of only casting with .float()
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# apply some data augmenting/model resiliency techniques and then normalize
augment_and_normalize = transforms.Compose([
    # resize the shorter side to 256, then take a random 224x224 patch so that
    # every sample has the same spatial size and can be stacked into a batch
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness = (0.5,1.2), saturation = 0.5, contrast = (0.2, 2), hue = 0.08),
    normalize
])

In [6]:
class ImageCaptionDataset(Dataset):
    """Dataset that pairs each image with one randomly chosen caption array.

    Args:
        img_dir: folder with the images, stored as ``<id zero-padded to 12 digits>.jpg``.
        caption_array_dir: folder with the preprocessed captions, stored as ``<id>_<n>.npy``.
        id_list: image ids; the position in this list is the dataset index.
        transform: callable applied to the image tensor; when omitted the
            module-level ``normalize`` transform is used.
    """

    def __init__(self, img_dir, caption_array_dir, id_list, transform = None):
        # assumes that images are downloaded as jpgs (with no extra processing)
        # and saved in the folder img_dir
        self.img_dir = img_dir
        # assumes that captions are already preprocessed and represented as numpy arrays in
        # the folder caption_array_dir
        self.caption_array_dir = caption_array_dir
        # list of image ids used for both images and caption arrays
        self.id_list = id_list
        # explicit None check so that a falsy-but-valid callable is not silently replaced
        self.transform = transform if transform is not None else normalize

    def __len__(self):
        # __getitem__ looks indices up in id_list, so the dataset size is the number
        # of ids, not the number of jpg files that happen to be in img_dir
        return len(self.id_list)

    def __getitem__(self, index):
        """Return ``(caption_tensor, transformed_image)`` for the id at position ``index``.

        Raises:
            ValueError: if the image has neither 1 nor 3 channels.
        """
        img_id = self.id_list[index]
        # filenames are of the form id.jpg where the id is padded with zeroes to the left
        # until it has length 12
        filename = str(img_id).zfill(12) + '.jpg'
        # each image comes with at least 5 captions, so choose one at random
        # caption arrays have format id_n.npy where id is not padded with zeroes
        # and n is an integer between 0 and 4 indicating which of the 5 captions is represented
        caption_idx = randint(0, 4)
        with open(f"{self.caption_array_dir}/{img_id}_{caption_idx}.npy", mode = "rb") as f:
            arr = np.load(f)
        img = torchvision.io.read_image(f"{self.img_dir}/{filename}")
        # convert to RGB if grayscale
        if img.shape[0] == 1:
            img = img.repeat(3, 1, 1)
        elif img.shape[0] != 3:
            # fail loudly here: returning None would only surface later as an
            # unrelated error inside the collate function
            raise ValueError(f"improper shape {tuple(img.shape)} for image {filename}")
        # apply transform for images and just create an equivalent tensor for caption array
        return torch.from_numpy(arr), self.transform(img)

In [7]:
train_image_dir = 'train_images'
train_caption_dir = 'train_arrays'
# retrieve saved id list
# NOTE(review): pickle.load can execute arbitrary code — only load train_ids.pkl from a trusted source
with open("train_ids.pkl", mode = "rb") as f:
    train_ids = pickle.load(f)
# training images go through the random augmentation pipeline before normalization
train_dataset = ImageCaptionDataset(train_image_dir, train_caption_dir, train_ids, transform = augment_and_normalize)

In [None]:
batch_size = 20
num_workers = 5


def pad_top(tensor, length):
    """Prepend rows of zeroes to a 2-D tensor so that it ends up with `length` rows.

    Needed because the caption sequences have different lengths.
    Fails an assertion when the tensor already has more than `length` rows.
    """
    rows, cols = tensor.shape
    assert length >= rows, f"tensor is already too long: {rows} > {length}"
    filler = torch.zeros(length - rows, cols)
    return torch.cat((filler, tensor))


def custom_collate(batch):
    """Turn a list of (caption, image) samples into a (captions, images) pair of batched tensors.

    Images are stacked as they are; every caption array is first padded at the
    top to the length of the longest caption in the batch.
    """
    captions, imgs = zip(*batch)
    imgs = torch.stack(imgs)
    longest = max(caption.shape[0] for caption in captions)
    padded = []
    for caption in captions:
        padded.append(pad_top(caption, longest))
    return torch.stack(padded), imgs

# batches are assembled with custom_collate so that captions of different lengths can be stacked;
# shuffle = True asks the loader to draw the samples in random order
train_dataloader = DataLoader(train_dataset, collate_fn = custom_collate, batch_size = batch_size,
                              shuffle = True, num_workers = num_workers)

In [22]:
test_image_dir = 'test_images'
test_caption_dir = 'test_arrays'
# retrieve saved id list
# NOTE(review): pickle.load can execute arbitrary code — only load test_ids.pkl from a trusted source
with open("test_ids.pkl", mode = "rb") as f:
    test_ids = pickle.load(f)
# deterministic resize + center crop so that every test image has the same 224x224 shape;
# with normalization alone the raw images keep their individual sizes and cannot be
# stacked into a batch by the collate function
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    normalize
])
test_dataset = ImageCaptionDataset(test_image_dir, test_caption_dir, test_ids, transform = test_transform)

In [None]:
# reuse the training batch size and worker count for the test loader
test_batch_size = batch_size
test_num_workers = num_workers
# no point in shuffling since we're not calculating batch gradients for test data
# (shuffle is simply left at the DataLoader default here)
test_dataloader = DataLoader(test_dataset, collate_fn = custom_collate, batch_size = test_batch_size,
                             num_workers = test_num_workers)

In [None]:
# keep one iterator around so that each next(train_iter) call yields a new training batch
train_iter = iter(train_dataloader)

In [20]:
# run a few batches through the network and print the allocated GPU memory
# before and after each one to check that it does not grow
class_batches = []
for _ in range(3):
    print("before: ", torch.cuda.memory_allocated())
    # use no_grad to let gpu free memory that's no longer being used
    # otherwise it holds onto the memory to compute gradients
    with torch.no_grad():
        # call the module itself instead of .forward() so that registered hooks are run;
        # kept as a single expression so no GPU tensor stays referenced after the append
        class_batches.append(efficientnet_b6(next(train_iter)[1].to(gpu)).to(cpu))
    print("after: ", torch.cuda.memory_allocated())

before:  173611008
after:  173611008
before:  173611008
after:  173611008
before:  173611008
after:  173611008
