In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.transforms as T
import torch.nn.functional as F
from torchvision.io import read_image, ImageReadMode
from PIL import Image
import os
import json
import pickle
import numpy as np

class RefCocoG_Dataset(Dataset):
    """
    RefCOCOg referring-expression dataset restricted to a single split.

    Two annotation files living in ``<root_dir>/annotations`` are merged:

    * ``annotations_f`` -- a pickle file holding a list of reference records
      (each providing ``ann_id``, ``image_id``, ``split`` and ``sentences``);
    * ``instances_f`` -- a JSON file with ``images`` (``id``, ``file_name``)
      and ``annotations`` (``id``, ``bbox``) lists.

    Every sample is the list ``[image, bounding_box, prompts]`` where the
    bounding box is ``[x_min, y_min, x_max, y_max]`` computed from the
    annotation's ``bbox`` entry. The bounding box is NOT modified by
    ``transform``; only the image is.

    Args:
        root_dir: Directory containing the ``annotations`` and ``images`` folders.
        annotations_f: File name of the pickle reference file.
        instances_f: File name of the JSON instances file.
        split: Value of the ``split`` field a reference must have to be kept.
        transform: Callable applied to the loaded image. When ``None`` a
            bicubic resize to 224 followed by a 224x224 center crop is used.
        target_transform: Stored on the instance; ``__getitem__`` does not
            apply it.
    """

    # Merged annotations of every split, shared by all instances so that the
    # annotation files are parsed only once per set of source files.
    full_annotations = None
    # (root_dir, annotations_f, instances_f) that `full_annotations` was built from.
    _full_annotations_source = None

    def __init__(self, root_dir, annotations_f, instances_f, split='train', transform=None, target_transform=None) -> None:
        super().__init__()

        self.root_dir = root_dir
        self.annotations_f = annotations_f
        self.instances_f = instances_f

        self.split = split

        self.transform = transform
        self.target_transform = target_transform

        if self.transform is None:
            self.transform = T.Compose([
                # antialias must be a boolean: 'warn' is torchvision's internal
                # deprecation sentinel, not a supported user-facing value.
                T.Resize(size=224, interpolation=T.InterpolationMode.BICUBIC, max_size=None, antialias=True),
                T.CenterCrop(size=(224, 224))
            ])

        self.get_annotations()

        # Freeze the sample order once so that __getitem__ is O(1) instead of
        # rebuilding the key list on every access.
        self.annotation_ids = list(self.annotations)
        self.image_names = [
            self.annotations[annotation_id]['image']['actual_file_name']
            for annotation_id in self.annotation_ids
        ]

    def _filter_split(self, annotations):
        """Return the entries of `annotations` whose reference belongs to `self.split`."""
        return {
            annotation_id: entry
            for annotation_id, entry in annotations.items()
            if entry['image']['split'] == self.split
        }

    def get_annotations(self):
        """
        Populate `self.annotations` with the samples of `self.split`.

        The merged annotations of all splits are cached on the class. The cache
        is reused only when it was built from the same
        (root_dir, annotations_f, instances_f) triple; otherwise the files are
        parsed again and the cache is replaced.
        """
        source = (self.root_dir, self.annotations_f, self.instances_f)

        cached = RefCocoG_Dataset.full_annotations
        if cached and RefCocoG_Dataset._full_annotations_source == source:
            self.annotations = self._filter_split(cached)
            return

        annotations_dir = os.path.join(self.root_dir, 'annotations')

        # Load pickle data
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # this at annotation files obtained from a trusted source.
        with open(os.path.join(annotations_dir, self.annotations_f), 'rb') as file:
            self.data = pickle.load(file)

        # Load instances
        with open(os.path.join(annotations_dir, self.instances_f), 'rb') as file:
            self.instances = json.load(file)

        # Map every image id to the file name stored in the instances file
        images_actual_file_names = {
            image['id']: image['file_name']
            for image in self.instances['images']
        }

        # Match data between the two files and build the actual dataset,
        # keyed by annotation id
        merged = {}

        for reference in self.data:
            entry = merged.setdefault(reference['ann_id'], {})
            entry['image'] = reference
            reference['actual_file_name'] = images_actual_file_names[reference['image_id']]

        for annotation in self.instances['annotations']:
            entry = merged.get(annotation['id'])
            if entry is not None:
                entry['annotation'] = annotation

        RefCocoG_Dataset.full_annotations = merged
        RefCocoG_Dataset._full_annotations_source = source

        # Keep only samples from the given split
        self.annotations = self._filter_split(merged)

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.image_names)

    def corner_size_to_corners(self, bounding_box):
        """
        Transform (top_left_x, top_left_y, width, height) bounding box representation
        into (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
        """

        return [
            bounding_box[0],
            bounding_box[1],
            bounding_box[0] + bounding_box[2],
            bounding_box[1] + bounding_box[3]
        ]

    def __getitem__(self, idx):
        """
        Load the sample at position `idx`.

        Returns:
            ``[image, bounding_box, prompts]``: the (transformed) image read
            from ``<root_dir>/images``, the corner-format bounding box and the
            list of sentences attached to the reference.
        """
        image_name = self.image_names[idx]
        annotation_id = self.annotation_ids[idx]
        entry = self.annotations[annotation_id]

        # Read the image file as an RGB tensor
        image = read_image(os.path.join(self.root_dir, 'images', image_name), ImageReadMode.RGB)

        # Every reference can carry several sentences describing the same object
        prompts = [sentence['sent'] for sentence in entry['image']['sentences']]

        # Bounding box of the referred object, converted to corner format
        bounding_box = self.corner_size_to_corners(entry['annotation']['bbox'])

        # Apply the transform if given
        if self.transform:
            image = self.transform(image)

        # Return the sample as a list
        return [
            image,
            bounding_box,
            prompts,
        ]

In [2]:
# Build one dataset per split, in train / val / test order. The annotation
# files are parsed by the first instance and shared with the following ones.
dataset_train, dataset_val, dataset_test = (
    RefCocoG_Dataset('refcocog', 'refs(umd).p', 'instances.json', split=split_name)
    for split_name in ('train', 'val', 'test')
)

dataset_splits = [dataset_train, dataset_val, dataset_test]

In [3]:
# Sanity check: total number of merged annotations, then the train / val / test sizes
(
    len(RefCocoG_Dataset.full_annotations),
    len(dataset_train.annotations),
    len(dataset_val.annotations),
    len(dataset_test.annotations),
)

(49820, 42224, 2573, 5023)

In [11]:
def collate_differently_sized_prompts(batch):
    """
    Collate samples whose third element (the prompts) can differ in length.

    Each sample is indexed as (image, bounding box, prompts). The images are
    stacked along a new leading dimension, while bounding boxes and prompts
    are returned as plain Python lists with one entry per sample.
    """
    def column(position):
        # Collect the element at `position` from every sample, in batch order
        return [sample[position] for sample in batch]

    image_tensors, boxes, sentences = column(0), column(1), column(2)

    return torch.stack(image_tensors, dim=0), boxes, sentences

def get_data(dataset_splits, batch_size=64, test_batch_size=256, num_workers=0):
    """
    Build the train, validation and test data loaders.

    Args:
        dataset_splits: Indexable holding the training, validation and test
            datasets at positions 0, 1 and 2.
        batch_size: Batch size of the training loader.
        test_batch_size: Batch size of the validation and test loaders.
        num_workers: Number of worker processes used by every loader. The
            default of 0 loads the data in the main process.

    Returns:
        The tuple ``(train_loader, val_loader, test_loader)``.
    """
    training_data = dataset_splits[0]
    validation_data = dataset_splits[1]
    test_data = dataset_splits[2]

    # Options shared by the three loaders
    shared_options = dict(collate_fn=collate_differently_sized_prompts, num_workers=num_workers)

    # Training batches are shuffled and the last incomplete batch is dropped
    train_loader = torch.utils.data.DataLoader(training_data, batch_size, shuffle=True, drop_last=True, **shared_options)

    # Evaluation keeps the dataset order and every sample
    val_loader = torch.utils.data.DataLoader(validation_data, test_batch_size, shuffle=False, **shared_options)
    test_loader = torch.utils.data.DataLoader(test_data, test_batch_size, shuffle=False, **shared_options)

    return train_loader, val_loader, test_loader

In [12]:
# Use the same batch size for training and for evaluation
train_loader, val_loader, test_loader = get_data(
    dataset_splits,
    batch_size=128,
    test_batch_size=128,
)

In [13]:
# Walk through the test loader once and report, for every batch, the shape of
# the stacked images and the number of bounding boxes and prompt lists.
for batch_idx, (image, bounding_box, prompts) in enumerate(test_loader):
    print(
        f'-- Batch index: {batch_idx} --',
        image.shape,
        len(bounding_box),
        len(prompts),
        sep='\n',
    )

-- Batch index: 0 --
torch.Size([128, 3, 224, 224])
128
128
-- Batch index: 1 --
torch.Size([128, 3, 224, 224])
128
128
-- Batch index: 2 --
torch.Size([128, 3, 224, 224])
128
128
-- Batch index: 3 --
torch.Size([128, 3, 224, 224])
128
128
-- Batch index: 4 --
torch.Size([128, 3, 224, 224])
128
128
-- Batch index: 5 --
torch.Size([128, 3, 224, 224])
128
128
-- Batch index: 6 --
torch.Size([128, 3, 224, 224])
128
128
-- Batch index: 7 --
torch.Size([128, 3, 224, 224])
128
128
-- Batch index: 8 --
torch.Size([128, 3, 224, 224])
128
128
-- Batch index: 9 --
torch.Size([128, 3, 224, 224])
128
128
-- Batch index: 10 --
torch.Size([128, 3, 224, 224])
128
128
-- Batch index: 11 --
torch.Size([128, 3, 224, 224])
128
128
-- Batch index: 12 --
torch.Size([128, 3, 224, 224])
128
128
-- Batch index: 13 --
torch.Size([128, 3, 224, 224])
128
128
-- Batch index: 14 --
torch.Size([128, 3, 224, 224])
128
128
-- Batch index: 15 --
torch.Size([128, 3, 224, 224])
128
128
-- Batch index: 16 --
torch.Size([