In [None]:
import itertools as it
import json
import pickle
import re
import matplotlib.pyplot as plt
import torch
import torchvision

from typing import Literal, Callable, Tuple, TypedDict
from jaxtyping import Array, Float, UInt
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image

In [None]:
!mkdir dataset

In [None]:
# Download the RefCOCOg archive and save it as dataset/refcocog.tar.gz.
!wget "https://docs.google.com/uc?export=download&confirm=t&id=1P8a1g76lDJ8cMIXjNDdboaRR5-HsVmUb" -O "dataset/refcocog.tar.gz"

In [None]:
# Unpack the archive into dataset/ (later cells read from dataset/refcocog/...).
!tar -xf dataset/refcocog.tar.gz -C dataset

In [None]:
# Delete the archive once it has been extracted.
!rm dataset/refcocog.tar.gz

In [None]:
# Seed PyTorch's global RNG so that random draws in later cells
# (e.g. the torch.randint colour pick in the preview cell) repeat across runs.
torch.manual_seed(17)

In [None]:
# Dataset partition names; compared against ref['split'] when a CocoDataset is built.
Split = Literal['train', 'test', 'val']


class Info(TypedDict, total=True):
    """The ``info`` record of the COCO ``instances.json`` file.

    The trailing comments show an example value for each field.
    """
    description: str  # This is stable 1.0 version of the 2014 MS COCO dataset.
    url: str  # http://mscoco.org/
    version: str  # 1.0
    year: int  # 2014
    contributor: str  # Microsoft COCO group
    date_created: str  # 2015-01-27 09:11:52.357475


class Image(TypedDict, total=True):
    """One entry of the ``images`` list in the COCO ``instances.json`` file."""
    license: int  # each image has an associated licence id
    file_name: str  # file name of the image
    coco_url: str  # example http://mscoco.org/images/131074
    height: int  # image height in pixels
    width: int  # image width in pixels
    date_captured: str  # example '2013-11-21 01:03:06'
    flickr_url: str  # example http://farm9.staticflickr.com/8308/7908210548_33e
    id: int  # id of the image


class License(TypedDict, total=True):
    """One entry of the ``licenses`` list in ``instances.json``."""
    url: str  # example http://creativecommons.org/licenses/by-nc-sa/2.0/
    id: int  # id of the licence
    name: str  # example 'Attribution-NonCommercial-ShareAlike License'


class Annotation(TypedDict, total=True):
    """One entry of the ``annotations`` list in ``instances.json``: a single labelled object in an image."""
    # NOTE(review): the example below is a nested list of numbers, not a str — confirm the intended annotation.
    segmentation: str  # description of the mask; example [[44.17, 217.83, 36.21, 219.37, 33.64, 214.49, 31.08, 204.74, 36.47, 202.68, 44.17, 203.2]]
    area: int  # number of pixels of the described object
    iscrowd: Literal[1, 0]  # Crowd annotations (iscrowd=1) are used to label large groups of objects (e.g. a crowd of people)
    image_id: int  # id of the target image
    # NOTE(review): instances are read with json.load, so at runtime this is a plain list
    # (it is passed to torch.tensor by CocoDataset) — confirm the array annotation.
    bbox: UInt[Array, '4']  # bounding box coordinates [xmin, ymin, width, height]
    category_id: int  # id of the object's category
    id: int  # annotation id


class Category(TypedDict, total=True):
    """One entry of the ``categories`` list in ``instances.json``."""
    supercategory: str  # example 'vehicle'
    id: int  # category id
    name: str  # example 'airplane'


class Instances(TypedDict, total=True):
    """Top-level structure of ``dataset/refcocog/annotations/instances.json``."""
    info: Info  # dataset-level metadata
    images: list[Image]  # one record per image
    licenses: list[License]  # licences referenced by Image['license']
    annotations: list[Annotation]  # one record per labelled object
    categories: list[Category]  # object categories referenced by category_id


class Sentence(TypedDict, total=True):
    """One referring expression attached to a ``Ref``."""
    tokens: list[str]  # tokenized version of referring expression
    raw: str  # unprocessed referring expression
    sent: str  # referring expression with mild processing, lower case, spell correction, etc.
    sent_id: int  # unique referring expression id


class Ref(TypedDict, total=True):
    """One referred object from the ``refs(umd).p`` pickle: an annotated object plus the sentences that describe it."""
    image_id: int  # unique image id
    split: Split  # dataset partition this ref belongs to
    sentences: list[Sentence]  # referring expressions for the object
    file_name: str  # file name of image relative to img_root
    category_id: int  # object category label
    ann_id: int  # id of object annotation in instances.json
    sent_ids: list[int]  # same ids as nested sentences[...][sent_id]
    ref_id: int  # unique id for referring expression
    ref_id: int  # unique id for refering expression

In [None]:
def fix_ref(x: Ref) -> Ref:
    """
    Normalise the image file name of a ref.

    The dict is modified in place and the very same object is returned.

    :param x: ref whose ``file_name`` should be passed through ``fix_filename``
    :return:  ``x`` itself, with ``file_name`` rewritten
    """
    original_name = x['file_name']
    x['file_name'] = fix_filename(original_name)
    return x


def fix_filename(x: str) -> str:
    """
    Drop the trailing ``_[annotation_id]`` from a ref file name.

    Not idempotent: the final ``_<digits>`` group before ``.jpg`` is always
    removed, so a second call would strip the image id as well.

    :param x: COCO_..._[image_id]_[annotation_id].jpg
    :return:  COCO_..._[image_id].jpg
    """
    # Raw string: '\d' in a plain literal is an invalid escape sequence.
    return re.sub(r'_\d+\.jpg$', '.jpg', x)

In [None]:
# Load every referring expression and normalise its image file name.
# NOTE: pickle.load can execute arbitrary code — only use an archive from a trusted source.
# The context manager closes the file even if unpickling or fix_ref raises.
with open('dataset/refcocog/annotations/refs(umd).p', 'rb') as f:
    refs: list[Ref] = [
        fix_ref(ref)
        for ref in pickle.load(f)
    ]

In [None]:
# Load the COCO instance annotations (images, annotations, categories, ...).
# The context manager closes the file even if parsing fails; the encoding is
# pinned so the result does not depend on the platform's default locale.
with open('dataset/refcocog/annotations/instances.json', 'r', encoding='utf-8') as f:
    instances: Instances = json.load(f)

In [None]:
# Index the annotations by their id so a ref's 'ann_id' can be resolved directly.
id2annotation = dict(
    (annotation['id'], annotation)
    for annotation in instances['annotations']
)

In [None]:
# Short aliases used in the CocoDataset signatures below.
# NOTE(review): padding() unpacks an image tensor as (channels, _, _), so the axis
# names 'R G B' look like plain labels for three axes rather than a layout — confirm.
I = UInt[torch.Tensor, 'R G B']  # image: unsigned-integer tensor with three axes
P = Sentence  # one referring-expression record
B = Float[torch.Tensor, '4']  # bounding box: float tensor with 4 entries


class CocoDataset(Dataset[Tuple[Tuple[I, str], B]]):
    """
    RefCOCOg referring expressions of one split, served as ``((image, prompt), bbox)``.

    The items are built from the module-level ``refs`` list and ``id2annotation``
    mapping, so both must exist before the dataset is constructed. Image files
    are read from ``dataset/refcocog/images/`` lazily, on item access.
    """

    def __init__(
            self,
            split: Split,
            img_transform: Callable[[I], I] = lambda x: x,
            prompt_transform: Callable[[list[P]], str] = lambda x: x[0]['sent'],
            bb_transform: Callable[[B], B] = lambda x: x,
    ):
        """
        :param split: keep only the refs whose ``split`` field equals this value
        :param img_transform: applied to the decoded image tensor (identity by default)
        :param prompt_transform: reduces the sentences of a ref to one prompt
                                 (default: ``sent`` of the first sentence)
        :param bb_transform: applied to the bounding box tensor (identity by default)
        """
        self.img_transform = img_transform
        self.prompt_transform = prompt_transform
        self.bb_transform = bb_transform

        # ((file name, sentences), bbox) per ref; the images themselves are not loaded here.
        self.items: list[tuple[tuple[str, list[P]], B]] = [
            (
                (ref['file_name'], ref['sentences']),
                # Explicit dtype: a box made only of integer literals would otherwise
                # be inferred as int64 and disagree with the float alias B.
                torch.tensor(id2annotation[ref['ann_id']]['bbox'], dtype=torch.float32),
            )
            for ref in refs
            if ref['split'] == split
        ]

    def __len__(self) -> int:
        """Number of refs in the selected split."""
        return len(self.items)

    def __getitem__(self, item: int) -> Tuple[Tuple[I, str], B]:
        """
        Read the image of the ``item``-th ref and apply the three transforms.

        :param item: index into ``self.items``
        :return: ``((transformed image, prompt), transformed bbox)``
        """
        (file_name, sentences), bbox = self.items[item]
        # Force 3 channels: in the default mode a grayscale file decodes to a
        # single-channel tensor, which breaks the 3-channel image contract.
        img = read_image(
            'dataset/refcocog/images/' + file_name,
            mode=torchvision.io.ImageReadMode.RGB,
        )
        return (
            (
                self.img_transform(img),
                self.prompt_transform(sentences)
            ),
            self.bb_transform(bbox)
        )

In [None]:
# Largest height and width over all images listed in instances.json;
# together they give the canvas size that padding() fills.
max_h = max(image['height'] for image in instances['images'])
max_w = max(image['width'] for image in instances['images'])
    
def padding(x: I) -> I:
    """
    Zero-pad an image on its bottom and right edges up to ``max_h`` x ``max_w``.

    The image stays anchored at the top-left corner, so bounding boxes given
    as [xmin, ymin, width, height] remain valid for the padded result.

    :param x: image tensor laid out as (channels, height, width)
    :return:  uint8 tensor of shape (3, max_h, max_w) containing ``x``
    """
    # Image tensors are (C, H, W): the height axis comes before the width axis.
    z = torch.zeros((3, max_h, max_w), dtype=torch.uint8)
    c, h, w = x.size()
    z[0:c, 0:h, 0:w] = x
    return z

In [None]:
def shortest(ps: list[P]) -> str:
    """
    Pick the referring expression with the fewest characters.

    :param ps: sentences of one ref (must not be empty)
    :return:  the shortest ``sent`` string; on a tie the earliest one wins
    """
    candidates = (p['sent'] for p in ps)
    return min(candidates, key=len)

In [None]:
# Training batches of 10 in shuffled order: images are padded to one common
# size and the shortest referring expression is used as the prompt.
train_dataloader_custom = DataLoader(
    CocoDataset(split='train', img_transform=padding, prompt_transform=shortest),
    batch_size=10,
    shuffle=True,
)

# Test items are served one at a time, untransformed and in dataset order.
test_dataloader_custom = DataLoader(
    CocoDataset(split='test'),
    batch_size=1,
    shuffle=False,
)

In [None]:
# Preview one training batch: every image with its bounding box and prompt.
fig, axs = plt.subplots(2, 5, figsize=(25, 10))

(imgs, prompts), outputs = next(iter(train_dataloader_custom))

# Boxes come as (x, y, width, height); drawing needs corner coordinates.
bboxes = torchvision.ops.box_convert(outputs, in_fmt='xywh', out_fmt='xyxy')

for ax, img, prompt, bbox in zip(axs.flat, imgs, prompts, bboxes):
    # One random colour per image.
    r, g, b = torch.randint(0, 256, [3]).tolist()
    img_bbox = torchvision.utils.draw_bounding_boxes(
        image=img, boxes=bbox.unsqueeze(0), colors=(r, g, b), width=2,
    )
    # Channels-first -> channels-last for matplotlib.
    ax.imshow(img_bbox.permute(1, 2, 0))
    ax.set_title(prompt)
    ax.set_axis_off()