In [1]:
import importlib.util
import os
import pickle
from random import randint

import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from tqdm import tqdm

In [None]:
# Load the pretrained EfficientNet-B6 classifier and use it as a frozen feature extractor.
# It IS moved to `device`; because ImageCaptionDataset calls it inside __getitem__, keep
# num_workers = 0 in the DataLoaders when `device` is CUDA (the CUDA runtime does not
# support being used from forked DataLoader worker processes).
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# NOTE: torchvision >= 0.13 deprecates `pretrained=True` in favour of
# `weights=models.EfficientNet_B6_Weights.DEFAULT`; kept here for older torchvision versions.
efficientnet_b6 = models.efficientnet_b6(pretrained = True)
efficientnet_b6.eval()
# the model is only used for inference, so don't track gradients through its weights
efficientnet_b6.requires_grad_(False)
# trailing semicolon suppresses the (very long) module repr in the cell output
efficientnet_b6.to(device);

In [3]:
# Show whether a CUDA device is visible to PyTorch (last expression is displayed by the notebook).
torch.cuda.is_available()

False


In [3]:
# The accimage backend only affects PIL-style loaders (e.g. torchvision.datasets.folder.default_loader);
# torchvision.io.read_image, which ImageCaptionDataset uses, does not consult it.
# Select it only when accimage is actually installed, since those loaders would otherwise
# fail when they try to import it.
if importlib.util.find_spec('accimage') is not None:
    torchvision.set_image_backend('accimage')

In [11]:
# Normalize using the convention for all pretrained torchvision classification models:
# float images scaled to [0, 1], then standardized with the ImageNet mean/std.
normalize = transforms.Compose([
    # torchvision.io.read_image returns uint8 tensors in [0, 255]; ConvertImageDtype rescales
    # them to [0, 1] (a bare .float() keeps the 0-255 range, so Normalize would produce
    # values ~255x too large for the pretrained weights).
    transforms.ConvertImageDtype(torch.float),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Training-time augmentation (resize, random 224x224 crop, horizontal flip, color jitter)
# applied to the uint8 image, followed by the same normalization as above.
augment_and_normalize = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness = (0.5,1.2), saturation = 0.5, contrast = (0.2, 2), hue = 0.08),
    normalize
])

In [19]:
# Dataset pairing each image's EfficientNet-B6 output with one of its caption arrays.
# Both returned tensors are placed on the same device as the global `efficientnet_b6` model.
class ImageCaptionDataset(Dataset):
    """Yield (model output, caption array) pairs for a list of image ids.

    Args:
        img_dir: folder holding the images, named ``<id zero-padded to 12 digits>.jpg``.
        caption_array_dir: folder holding preprocessed caption arrays named
            ``<id>_<n>.npy`` (id not padded, n in 0..4).
        id_list: image ids; dataset index ``i`` refers to ``id_list[i]``.
        transform: callable applied to the image tensor before the model;
            defaults to the global ``normalize`` pipeline when falsy.
    """

    def __init__(self, img_dir, caption_array_dir, id_list, transform = None):
        # images are assumed to be downloaded as jpgs (no extra processing) into img_dir
        self.img_dir = img_dir
        # captions are assumed to be preprocessed into numpy arrays saved in caption_array_dir
        self.caption_array_dir = caption_array_dir
        # list of image ids used for both images and caption arrays
        self.id_list = id_list
        self.transform = transform if transform else normalize

    def __len__(self):
        # __getitem__ indexes into id_list, so that list (not the number of .jpg files on
        # disk) defines the dataset size; this also avoids a directory scan on every call.
        return len(self.id_list)

    def __getitem__(self, index):
        """Return ``(model_output, caption_tensor)`` for ``id_list[index]``.

        Returns None (after printing a message) when the image has neither 1 nor 3 channels.
        """
        image_id = self.id_list[index]
        # image filenames are id.jpg with the id left-padded with zeroes to length 12
        filename = str(image_id).zfill(12) + '.jpg'
        # each image comes with at least 5 captions, so choose one of the arrays
        # <id>_0.npy .. <id>_4.npy at random (id not padded here)
        i = randint(0, 4)
        with open(f"{self.caption_array_dir}/{image_id}_{i}.npy", mode = "rb") as f:
            arr = np.load(f)
        img = torchvision.io.read_image(f"{self.img_dir}/{filename}")
        # convert to RGB if grayscale
        if img.shape[0] == 1:
            img = img.repeat(3, 1, 1)
        elif img.shape[0] != 3:
            # best-effort skip: callers (the identity collate_fn) receive None for this item
            print("improper shape: ", tuple(img.shape), filename)
            return None
        # run the model where its weights live; a CPU input fails against a CUDA model
        model_device = next(efficientnet_b6.parameters()).device
        batch = torch.unsqueeze(self.transform(img), 0).to(model_device)
        # inference only: without no_grad every item would keep the full autograd graph
        # of EfficientNet-B6 alive, exhausting memory across a batch
        with torch.no_grad():
            features = efficientnet_b6(batch)
        return features, torch.from_numpy(arr).to(model_device)

In [20]:
# Build the training split: augmented images paired with their caption arrays.
train_image_dir, train_caption_dir = 'train_images', 'train_arrays'
# retrieve the saved list of training image ids; pickle.load can execute arbitrary code,
# so only load id files you produced yourself
with open("train_ids.pkl", "rb") as f:
    train_ids = pickle.load(f)
train_dataset = ImageCaptionDataset(
    train_image_dir,
    train_caption_dir,
    train_ids,
    transform = augment_and_normalize,
)

In [21]:
# DataLoader settings shared by the train and test loaders; num_workers stays 0 because the
# dataset runs the (possibly CUDA-resident) model inside __getitem__.
batch_size, num_workers = 50, 0
def identity(x):
    """Collate function that hands the list of samples back unchanged.

    Used instead of the default collate so samples from the same batch are not
    stacked into one tensor: caption array dimensions don't line up between samples.
    """
    samples = x
    return samples
# Shuffled training loader; the identity collate yields each batch as a plain list of samples.
train_dataloader = DataLoader(
    train_dataset,
    batch_size = batch_size,
    shuffle = True,
    num_workers = num_workers,
    collate_fn = identity,
)

In [8]:
# Build the evaluation split: no transform is passed, so images get the dataset's default
# `normalize` pipeline (no augmentation).
test_image_dir, test_caption_dir = 'test_images', 'test_arrays'
# retrieve the saved list of test image ids; pickle.load can execute arbitrary code,
# so only load id files you produced yourself
with open("test_ids.pkl", "rb") as f:
    test_ids = pickle.load(f)
test_dataset = ImageCaptionDataset(test_image_dir, test_caption_dir, test_ids)

In [9]:
# Evaluation loader, using the same batch size, worker count and identity collate as training.
test_dataloader = DataLoader(
    test_dataset,
    batch_size = batch_size,
    shuffle = True,
    num_workers = num_workers,
    collate_fn = identity,
)

In [23]:
# Create an iterator over the training loader so batches can be pulled and inspected one at a time.
train_iter = iter(train_dataloader)

In [24]:
# Pull one batch for inspection. With the identity collate_fn this is a list of whatever
# ImageCaptionDataset.__getitem__ returned per item: a (model output, caption tensor) pair,
# or None for images that were skipped.
n = next(train_iter)