In [1]:
import json
import numpy as np
import os
import torch
from torch.utils.data import Dataset, DataLoader

In [13]:
# Root directory of the CLEVR sample data. Set the CLEVR_SAMPLE_DIR environment
# variable to run the notebook on another machine; when it is unset (or empty)
# the original location is used, so every path below keeps its previous value.
DATA_ROOT = os.environ.get('CLEVR_SAMPLE_DIR') or '/home/chrisams/Datasets/CLEVR_sample'

# Validation split: questions json, image features directory, question
# embeddings and the question -> image mapper written by this notebook.
DATASET_VAL = 'CLEVR_val'
QUESTIONS_VAL_PATH = f'{DATA_ROOT}/questions/CLEVR_val_questions_sample.json'
FEATS_VAL_DIR = f'{DATA_ROOT}/regions-MiniCLEVR/regions-miniCLEVR-val'
QUESTIONS_EMBEDDING_VAL = f'{DATA_ROOT}/questions_val_glove_embeddings.npy'
MAPPER_VAL_PATH = 'q2img_val_mapper.json'

# Training split: same layout as the validation split.
DATASET_TRAIN = 'CLEVR_train'
QUESTIONS_TRAIN_PATH = f'{DATA_ROOT}/questions/CLEVR_train_questions_sample.json'
FEATS_TRAIN_DIR = f'{DATA_ROOT}/regions-MiniCLEVR/regions-miniCLEVR-train'
QUESTIONS_EMBEDDING_TRAIN = f'{DATA_ROOT}/questions_train_glove_embeddings.npy'
MAPPER_TRAIN_PATH = 'q2img_train_mapper.json'

In [14]:
def get_q2img_mapper(questions_path, output_filename):
    '''Saves a json file with the mapper of question to img.
    This will be used by the dataloader.

    Args:
        questions_path: path to a json file whose top-level object has an
            'image_index' entry.
        output_filename: path of the json file to write; it receives the
            value of 'image_index' unchanged.

    Raises:
        KeyError: if the questions file has no 'image_index' entry. In that
            case `output_filename` is not created or modified.
    '''
    with open(questions_path, 'r') as f:
        data = json.load(f)
    # Look the mapper up before opening the output: opening with 'w' truncates
    # the file, so a missing key would otherwise wipe an existing mapper.
    image_index = data['image_index']
    with open(output_filename, 'w') as f:
        json.dump(image_index, f)

In [15]:
# Write one question -> image mapper file per split (validation, then train).
get_q2img_mapper(questions_path=QUESTIONS_VAL_PATH,
                 output_filename=MAPPER_VAL_PATH)
get_q2img_mapper(questions_path=QUESTIONS_TRAIN_PATH,
                 output_filename=MAPPER_TRAIN_PATH)

In [16]:
# # TEST (disabled): check that each saved mapper equals `image_index` in its questions file.
# with open('q2img_val_mapper.json', 'r') as f:
#     data1 = json.load(f)
# with open(QUESTIONS_VAL_PATH, 'r') as f:
#     data2 = json.load(f)['image_index']

# assert data1 == data2

# with open('q2img_train_mapper.json', 'r') as f:
#     data1 = json.load(f)
# with open(QUESTIONS_TRAIN_PATH, 'r') as f:
#     data2 = json.load(f)['image_index']

# assert data1 == data2

In [17]:
class BottomFeaturesDataset(Dataset):
    '''Dataset pairing question embeddings with per-image features.

    Item `idx` is a dict with two tensors:
      * 'question': row `idx` of the embeddings array loaded from disk;
      * 'image': the 'features' entry stored in the feature file of the image
        that question `idx` maps to.
    '''

    # Datasets whose feature files are named '<dataset>_<image index zero-padded
    # to 6 digits>.npy'.
    SUPPORTED_DATASETS = ('CLEVR_val', 'CLEVR_train')

    def __init__(self, feats_dir, questions_path, mapper_path, dataset):
        '''Loads the mapper and the question embeddings into memory.

        Args:
            feats_dir: directory containing one feature file per image.
            questions_path: path to the .npy file with the question
                embeddings (despite the name, not the questions json).
            mapper_path: path to the json question -> image mapper.
            dataset: dataset name used as feature file prefix; must be one of
                `SUPPORTED_DATASETS`.

        Raises:
            ValueError: if `dataset` is not supported. Failing here avoids a
                later AttributeError on the first item access.
        '''
        if dataset not in self.SUPPORTED_DATASETS:
            raise ValueError(
                f'Unsupported dataset {dataset!r}; expected one of {self.SUPPORTED_DATASETS}')
        self.dataset = dataset
        self.q2img_mapper = self.load_mapper(mapper_path)
        self.q_emb = np.load(questions_path)
        self.feats_dir = feats_dir

    def get_img_name(self, x):
        '''Returns the feature file name for image index `x`.

        Defined as a method (not a lambda stored on the instance) so the
        dataset stays picklable, which DataLoader worker processes need when
        they are started with the spawn method.
        '''
        return f'{self.dataset}_{str(x).zfill(6)}.npy'

    def load_mapper(self, mapper_path):
        '''Reads the json mapper and returns its values as a list.

        The json object is keyed by integer-like strings; values are returned
        ordered by the numeric value of their key, so list position `i` holds
        the image index of the i-th smallest question key.
        '''
        with open(mapper_path, 'r') as f:
            mapper = json.load(f)
        # Sort the original keys numerically ('10' after '2') and look values
        # up with the key as stored, so non-canonical keys such as '01' work.
        return [mapper[key] for key in sorted(mapper, key=int)]

    def __len__(self):
        '''Number of questions listed in the mapper.'''
        return len(self.q2img_mapper)

    def __getitem__(self, idx):
        '''Returns {'image': tensor, 'question': tensor} for question `idx`.'''
        question = self.q_emb[idx]
        img_filename = self.get_img_name(self.q2img_mapper[idx])
        img_path = os.path.join(self.feats_dir, img_filename)
        # NOTE: allow_pickle=True is required because the file holds a pickled
        # Python object (read with .item() and indexed by 'features').
        # Unpickling can execute arbitrary code: only load trusted files.
        img = np.load(img_path, allow_pickle=True).item()['features']
        return {'image': torch.from_numpy(img),
                'question': torch.from_numpy(question)}

In [18]:
# Validation-split dataset: region features + GloVe question embeddings.
bottom_feats_dataset = BottomFeaturesDataset(
    feats_dir=FEATS_VAL_DIR,
    questions_path=QUESTIONS_EMBEDDING_VAL,
    mapper_path=MAPPER_VAL_PATH,
    dataset=DATASET_VAL,
)

In [19]:
# Shuffled batches of 4 items, loaded by 4 worker processes.
dataloader = DataLoader(
    dataset=bottom_feats_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
)

In [20]:
# Smoke test: print the tensor shapes of the first four batches, then stop.
for i_batch, sample_batched in enumerate(dataloader):
    print(f"{i_batch} {sample_batched['image'].shape} {sample_batched['question'].shape}")

    # Batch index 3 is the 4th batch: nothing more to inspect after it.
    if i_batch == 3:
        break

0 torch.Size([4, 36, 2048]) torch.Size([4, 25, 300])
1 torch.Size([4, 36, 2048]) torch.Size([4, 25, 300])
2 torch.Size([4, 36, 2048]) torch.Size([4, 25, 300])
3 torch.Size([4, 36, 2048]) torch.Size([4, 25, 300])
