In [46]:
from PIL import Image
from tqdm import tqdm
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import numpy as np
from random import randint
import pickle

In [93]:
class ImageCaptionDataset(Dataset):
    """Pairs each image with one randomly chosen pre-encoded caption array.

    Images are read from ``img_dir`` as ``<id zero-padded to 12 digits>.jpg``.
    Captions are read from ``caption_array_dir`` as ``<id>_<k>.npy``, where ``k``
    is drawn uniformly from ``range(num_captions)`` on every access. The same
    index can therefore yield different captions across epochs.

    Args:
        img_dir: directory containing the ``.jpg`` images.
        caption_array_dir: directory containing the ``.npy`` caption arrays.
        id_list: sequence of image ids; dataset index ``i`` maps to ``id_list[i]``.
        transform: callable applied to the RGB PIL image. Defaults to
            ``transforms.ToTensor()``.
        num_captions: number of caption files stored per image. The default of
            5 matches the previously hard-coded ``randint(0, 4)``.

    Each item is ``(transform(image), caption_tensor)``, where ``caption_tensor``
    is ``torch.from_numpy`` of the loaded array.
    """

    def __init__(self, img_dir, caption_array_dir, id_list, transform=None, num_captions=5):
        if num_captions < 1:
            raise ValueError("num_captions must be >= 1")
        self.img_dir = img_dir
        self.caption_array_dir = caption_array_dir
        self.id_list = id_list
        self.transform = transform if transform is not None else transforms.ToTensor()
        self.num_captions = num_captions

    def __len__(self):
        # Items are addressed through id_list, so the length must come from it.
        # Counting *.jpg files in img_dir can disagree with id_list, which lets
        # the sampler request indices that id_list does not contain.
        return len(self.id_list)

    def _load_caption(self, image_id):
        """Return one randomly chosen caption array for ``image_id`` as a tensor."""
        k = randint(0, self.num_captions - 1)
        path = os.path.join(self.caption_array_dir, f"{image_id}_{k}.npy")
        with open(path, mode="rb") as f:
            return torch.from_numpy(np.load(f))

    def _load_image(self, image_id):
        """Return the image for ``image_id`` as an in-memory RGB PIL image."""
        path = os.path.join(self.img_dir, str(image_id).zfill(12) + '.jpg')
        # convert() loads the pixel data and returns a new image object, so the
        # file handle can be released when the with-block exits.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        image_id = self.id_list[index]
        # Caption first, then image: same order (and RNG use) as before.
        caption = self._load_caption(image_id)
        image = self._load_image(image_id)
        return self.transform(image), caption

In [94]:
# --- Training split -------------------------------------------------------
train_image_dir = 'train_images'
train_caption_dir = 'caption_train_arrays'

# Ids of the training images. pickle.load can execute arbitrary code, so only
# load id files you generated yourself.
with open("train_ids.pkl", mode="rb") as f:
    train_ids = pickle.load(f)

# Training-time pipeline: random geometric and photometric augmentation,
# then tensor conversion and per-channel normalisation (the usual ImageNet
# mean/std values).
augment_and_normalize = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=(0.5, 1.2),
            contrast=(0.2, 2),
            saturation=0.5,
            hue=0.08,
        ),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

train_dataset = ImageCaptionDataset(
    img_dir=train_image_dir,
    caption_array_dir=train_caption_dir,
    id_list=train_ids,
    transform=augment_and_normalize,
)

In [101]:
batch_size = 50
num_workers = 5


def identity(batch):
    """Collate function that returns the batch unchanged.

    The DataLoader passes it the list of ``(image, caption)`` samples for one
    batch, so no stacking happens. Presumably this is because caption arrays
    differ in length; confirm.

    This is a named ``def`` rather than a lambda (PEP 8). A lambda can never be
    pickled, and DataLoader workers must pickle ``collate_fn`` under the
    ``spawn`` start method. NOTE(review): under spawn, functions and classes
    defined in a notebook's ``__main__`` may still fail to unpickle in workers.
    Move them to a ``.py`` module if that happens.
    """
    return batch


train_dataloader = DataLoader(
    train_dataset,
    collate_fn=identity,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

In [96]:
# --- Test split -----------------------------------------------------------
test_image_dir = 'test_images'
test_caption_dir = 'caption_test_arrays'

# Ids of the test images. pickle.load can execute arbitrary code, so only load
# id files you generated yourself.
with open("test_ids.pkl", mode="rb") as f:
    test_ids = pickle.load(f)

# Tensor conversion + per-channel normalisation, with no augmentation.
normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Evaluation must be deterministic: the previous code fed the test set through
# the *training* augmentation pipeline (random crop/flip/colour jitter). Use
# the deterministic counterpart of the training geometry instead: resize to
# 256, take the central 224 crop, then normalise.
eval_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    normalize,
])

test_dataset = ImageCaptionDataset(
    img_dir=test_image_dir,
    caption_array_dir=test_caption_dir,
    id_list=test_ids,
    transform=eval_transform,
)

In [97]:
# Evaluation does not need a random order. A fixed order makes runs easier to
# compare and debug. The dataset still samples one of the captions at random
# per access.
test_dataloader = DataLoader(
    test_dataset,
    collate_fn=identity,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)

In [None]:
# --- Model ----------------------------------------------------------------
# EfficientNet-B6 from torchvision, initialised with its pretrained weights.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=` (e.g. `models.EfficientNet_B6_Weights.IMAGENET1K_V1`). Switch
# once the installed torchvision version is confirmed. Also confirm that the
# 224px crops used above are intended for this backbone.
efficientnet_b6 = models.efficientnet_b6(
    pretrained=True,
)