In [None]:
"""
Custom PyTorch datasets and data loaders: a VQA-style question/answer dataset and an MSCOCO image-folder dataset.
"""
import argparse
import json
import os
import pickle
import re
import sys
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset

import utils
from utils.types import PathT
# from __future__ import print_function

In [None]:
class MyDataset(Dataset):
    """
    VQA-style dataset pairing encoded questions and answers with per-image features.

    Each sample returned by ``__getitem__`` is ``(v, q, a, item, q_length)``.
    Several helpers used by ``__init__``/``__getitem__`` are still template stubs that raise
    ``NotImplementedError``; implement them for your data format.
    """

    def __init__(self, image_features_path, questions_path, answers_path, answerable_only=False,
                 vocab_path="vocab.json"):
        """
        :param image_features_path: path to the image-feature file used by the image helpers.
        :param questions_path: path to the questions JSON (must contain a ``'questions'`` list whose
            items carry an ``'image_id'``).
        :param answers_path: path to the answers JSON.
        :param answerable_only: if True, only the indices returned by ``_find_answerable`` are exposed.
        :param vocab_path: path to the vocabulary JSON; must contain the keys ``'question'`` and ``'answer'``.
        """
        super(MyDataset, self).__init__()
        with open(questions_path, 'r') as fd:
            questions_json = json.load(fd)
        with open(answers_path, 'r') as fd:
            answers_json = json.load(fd)
        with open(vocab_path, 'r') as fd:
            vocab_json = json.load(fd)
        self._check_integrity(questions_json, answers_json)

        # vocab
        self.vocab = vocab_json
        self.token_to_index = self.vocab['question']
        self.answer_to_index = self.vocab['answer']

        # q and a
        # NOTE(review): prepare_questions / prepare_answers are not defined in this notebook; they must be
        # defined or imported before instantiation. The raw lists are materialised before encoding so the
        # encoders may inspect self.questions / self.answers (e.g. to find a maximum length).
        self.questions = list(prepare_questions(questions_json))
        self.answers = list(prepare_answers(answers_json))
        self.questions = [self._encode_question(q) for q in self.questions]
        self.answers = [self._encode_answers(a) for a in self.answers]

        # v
        self.image_features_path = image_features_path
        self.coco_id_to_index = self._create_coco_id_to_index()
        self.coco_ids = [q['image_id'] for q in questions_json['questions']]

        # only use questions that have at least one answer?
        self.answerable_only = answerable_only
        if self.answerable_only:
            self.answerable = self._find_answerable()

    def __getitem__(self, item):
        """
        :param item: sample index; when ``answerable_only`` is set it indexes ``self.answerable``.
        :return: ``(v, q, a, item, q_length)`` where ``item`` is the index into the full question list.
        """
        if self.answerable_only:
            # change of indices to only address answerable questions
            item = self.answerable[item]

        q, q_length = self.questions[item]
        a = self.answers[item]
        image_id = self.coco_ids[item]
        v = self._load_image(image_id)
        # since batches are re-ordered for PackedSequence's, the original question order is lost
        # we return `item` so that the order of (v, q, a) triples can be restored if desired
        # without shuffling in the dataloader, these will be in the order that they appear in the q and a json's.
        return v, q, a, item, q_length

    def __len__(self) -> int:
        """
        :return: number of samples -- the answerable subset when ``answerable_only`` is set,
            otherwise every question (consistent with the indexing in ``__getitem__``).
        """
        if self.answerable_only:
            return len(self.answerable)
        return len(self.questions)

    # ----- template helpers that still need an implementation -----

    def _check_integrity(self, questions_json, answers_json) -> None:
        """Validate that the loaded questions and answers JSON belong together. TODO: implement."""
        raise NotImplementedError("MyDataset._check_integrity is not implemented")

    def _encode_question(self, question) -> Tuple:
        """Encode one prepared question; must return a pair ``(q, q_length)`` (unpacked in ``__getitem__``)."""
        raise NotImplementedError("MyDataset._encode_question is not implemented")

    def _encode_answers(self, answers) -> Any:
        """Encode the prepared answers of one question; the result is returned as ``a`` by ``__getitem__``."""
        raise NotImplementedError("MyDataset._encode_answers is not implemented")

    def _create_coco_id_to_index(self) -> Dict:
        """Build a lookup keyed by image id for ``self.image_features_path`` (presumably id -> row index). TODO."""
        raise NotImplementedError("MyDataset._create_coco_id_to_index is not implemented")

    def _find_answerable(self) -> List[int]:
        """Return the indices into ``self.questions``/``self.answers`` that should be exposed. TODO: implement."""
        raise NotImplementedError("MyDataset._find_answerable is not implemented")

    def _load_image(self, image_id) -> Any:
        """Return the visual input ``v`` for ``image_id``. TODO: implement."""
        raise NotImplementedError("MyDataset._load_image is not implemented")

    # ----- legacy template helpers (path-based variant) -----

    def _get_features(self) -> Any:
        """
        Unpickle the whole features file at ``self.path`` into memory.

        NOTE(review): the current ``__init__`` never sets ``self.path``; set it before calling this.
        SECURITY: ``pickle.load`` can execute arbitrary code -- only use it on trusted files.
        :return: the unpickled object.
        """
        with open(self.path, "rb") as features_file:
            return pickle.load(features_file)

    def _get_entries(self) -> List:
        """
        Build one entry (via ``_get_entry``) per value of the ``self.features`` dict.

        :return: list of ``{'x': ..., 'y': ...}`` dicts in ``self.features`` iteration order.
        """
        return [self._get_entry(item) for item in self.features.values()]

    @staticmethod
    def _get_entry(item: Dict) -> Dict:
        """
        :param item: a record with keys ``'input'`` and ``'label'``.
        :return: ``{'x': item['input'], 'y': one-hot}`` where ``y`` is ``[1, 0]`` for a truthy label
            and ``[0, 1]`` otherwise.
        """
        x = item['input']
        y = torch.Tensor([1, 0]) if item['label'] else torch.Tensor([0, 1])

        return {'x': x, 'y': y}

In [None]:
    class CocoImages(data.Dataset):
        """Dataset over the ``.jpg`` images in a single directory, keyed by numeric COCO id.

        The id is parsed from the file name: the text after the last ``_`` and before the first ``.``
        (e.g. ``COCO_train2014_000000000009.jpg`` -> 9). Samples are served in ascending id order and
        each one is ``(coco_id, image)``, with ``transform`` applied to the RGB PIL image when given.
        """

        def __init__(self, path, transform=None):
            """
            :param path: directory scanned (non-recursively) for ``.jpg`` files.
            :param transform: optional callable applied to each loaded RGB image.
            :raises ValueError: if a ``.jpg`` file name does not contain a numeric id in that position.
            """
            super(CocoImages, self).__init__()
            self.path = path
            self.id_to_filename = self._find_images()
            # sorting makes the iteration order deterministic, independent of os.listdir ordering
            self.sorted_ids = sorted(self.id_to_filename.keys())
            print('found {} images in {}'.format(len(self), self.path))
            self.transform = transform

        def _find_images(self):
            """Map each parsed COCO id to its file name for every ``.jpg`` file in ``self.path``."""
            id_to_filename = {}
            for filename in os.listdir(self.path):
                if not filename.endswith('.jpg'):
                    continue
                coco_id = int(filename.split('_')[-1].split('.')[0])
                id_to_filename[coco_id] = filename
            return id_to_filename

        def __getitem__(self, item):
            """Return ``(coco_id, image)`` for the ``item``-th smallest COCO id."""
            coco_id = self.sorted_ids[item]
            path = os.path.join(self.path, self.id_to_filename[coco_id])
            # the context manager releases the file handle; convert() returns a new, fully loaded image
            with Image.open(path) as raw_img:
                img = raw_img.convert('RGB')

            if self.transform is not None:
                img = self.transform(img)
            return coco_id, img

        def __len__(self):
            return len(self.sorted_ids)
    
    class Composite(data.Dataset):
        """Dataset that concatenates several datasets into one contiguous index range.

        Useful for combining splits of a dataset. Index ``i`` maps into the first sub-dataset until it is
        exhausted, then into the second, and so on. Negative indices count from the end of the whole
        composite, like a Python sequence.
        """

        def __init__(self, datasets):
            """:param datasets: sequence of objects supporting ``len()`` and integer indexing."""
            super(Composite, self).__init__()
            self.datasets = datasets

        def __getitem__(self, item):
            """
            :raises IndexError: if ``item`` is outside ``[-len(self), len(self))``.
            """
            if item < 0:
                # normalise against the total length; without this, -1 would index the FIRST sub-dataset
                item += len(self)
                if item < 0:
                    raise IndexError('Index too small for composite dataset')
            for dataset in self.datasets:
                if item < len(dataset):
                    return dataset[item]
                item -= len(dataset)
            raise IndexError('Index too large for composite dataset')

        def __len__(self):
            return sum(map(len, self.datasets))
    
    def get_transform(target_size, central_fraction=1.0):
        """Build the image preprocessing pipeline: resize, center crop, tensor conversion, normalisation.

        The shorter image side is resized to ``int(target_size / central_fraction)`` and a
        ``target_size`` x ``target_size`` center crop is taken, so the crop covers roughly
        ``central_fraction`` of the shorter side (e.g. 224 / 0.875 -> resize to 256, crop 224).

        :param target_size: side length of the square output crop.
        :param central_fraction: fraction of the resized shorter side kept by the crop; must be > 0.
        :return: a ``torchvision.transforms.Compose`` pipeline.
        """
        return transforms.Compose([
            # Resize replaces the deprecated (and later removed) transforms.Scale with the same int semantics
            transforms.Resize(int(target_size / central_fraction)),
            transforms.CenterCrop(target_size),
            transforms.ToTensor(),
            # commonly used ImageNet per-channel mean/std
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

In [None]:
# Preprocessing / batching settings for the raw MSCOCO image loaders.
IMAGE_SIZE = 224
CENTRAL_FRACTION = 0.875  # resize the shorter side to 224 / 0.875 = 256, then center-crop 224
BATCH_SIZE = 16

transform = get_transform(IMAGE_SIZE, CENTRAL_FRACTION)
paths = ["data/train2014", "data/val2014"]
train_dir, val_dir = paths

train_dataset = Composite([CocoImages(train_dir, transform=transform)])
val_dataset = Composite([CocoImages(val_dir, transform=transform)])


def make_image_loader(dataset, num_workers):
    """Unshuffled DataLoader over ``dataset`` (batches follow the dataset's sorted-id order)."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=True,
    )


train_data_loader = make_image_loader(train_dataset, num_workers=8)
# NOTE(review): validation loads in the main process (num_workers=0) unlike train's 8 workers; confirm intent.
val_data_loader = make_image_loader(val_dataset, num_workers=0)

In [None]:
# MyDataset needs three inputs per split: an image-features file, a questions JSON and an answers JSON.
# The previous calls passed a single path (the signature of the old commented-out template), which
# raises TypeError against the current __init__.
# TODO(review): the questions/answers paths below are placeholders -- replace them with the real files.
TRAIN_FEATURES_PATH = "data/train.pkl"
TRAIN_QUESTIONS_PATH = "data/train_questions.json"
TRAIN_ANSWERS_PATH = "data/train_answers.json"
VAL_FEATURES_PATH = "data/validation.pkl"
VAL_QUESTIONS_PATH = "data/validation_questions.json"
VAL_ANSWERS_PATH = "data/validation_answers.json"

train_dataset = MyDataset(
    image_features_path=TRAIN_FEATURES_PATH,
    questions_path=TRAIN_QUESTIONS_PATH,
    answers_path=TRAIN_ANSWERS_PATH,
)
val_dataset = MyDataset(
    image_features_path=VAL_FEATURES_PATH,
    questions_path=VAL_QUESTIONS_PATH,
    answers_path=VAL_ANSWERS_PATH,
)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=1)
# Evaluation gains nothing from shuffling; a fixed order keeps results reproducible
# (MyDataset.__getitem__ also returns `item`, so predictions can be mapped back to questions).
eval_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=1)

In [None]:
# for img, ques, ans, q_len in tq: