In [1]:
import torch
import torchvision.transforms as transforms
import torch.utils.data as data
import os
import pickle
import numpy as np
from PIL import Image
import json
import h5py

In [2]:
class CocoDataset(data.Dataset):
    """COCO image/tag dataset compatible with torch.utils.data.DataLoader.

    Each item is a pair ``(image, target)``: the transformed RGB image and a
    1-D ``torch.long`` tensor with the vocabulary ids of the image's tags.
    """

    # Maps the requested split name to the ``split`` labels selected from the
    # annotation file.  'train' and 'restval' always select both labels.
    _SPLIT_GROUPS = {
        'train': ['train', 'restval'],
        'restval': ['train', 'restval'],
        'val': ['val'],
        'test': ['test'],
    }

    def __init__(self, root, origin_file, split, img_tags, vocab, transform=None):
        """Load the annotation, tag and vocabulary files and select a split.

        Args:
            root: image directory; image paths from the annotation file are
                resolved relative to it.
            origin_file: path to a JSON file with an ``'images'`` list whose
                entries provide ``'imgid'``, ``'split'``, ``'filepath'`` and
                ``'filename'``.
            split: one of ``'train'``, ``'restval'``, ``'val'`` or ``'test'``.
            img_tags: path to a JSON file mapping ``str(imgid)`` to a list of
                tag strings.
            vocab: path to a JSON file containing a ``'word_map'`` dictionary
                (token -> integer id).
            transform: optional callable applied to the PIL image.  When
                omitted, a random-crop / random-flip / normalize pipeline is
                used.

        Raises:
            ValueError: if ``split`` is not a recognised split name.
        """
        # Fail early with a clear message instead of an AttributeError on
        # ``self.split`` further down.
        if split not in self._SPLIT_GROUPS:
            raise ValueError(
                "unknown split %r; expected one of %s"
                % (split, sorted(self._SPLIT_GROUPS)))

        self.root = root
        self.split = list(self._SPLIT_GROUPS[split])

        with open(origin_file, 'r') as j:
            self.origin_file = json.load(j)

        # Keep the list position of every selected entry separately from its
        # 'imgid': the position indexes ``origin_file['images']`` while the
        # id keys ``img_tags``.  The two are not guaranteed to be equal.
        selected = [(position, entry['imgid'])
                    for position, entry in enumerate(self.origin_file['images'])
                    if entry['split'] in self.split]
        self._positions = [position for position, _ in selected]
        self.images_id = [img_id for _, img_id in selected]

        with open(img_tags, 'r') as j:
            self.img_tags = json.load(j)

        with open(vocab, 'r') as j:
            self.vocab = json.load(j)

        if transform is None:
            # NOTE(review): this default applies random augmentation to every
            # split; pass a deterministic ``transform`` for 'val' / 'test'
            # when reproducible evaluation is needed.
            transform = transforms.Compose([
                # pad_if_needed keeps RandomCrop from raising on images that
                # are smaller than 224 pixels in either dimension.
                transforms.RandomCrop(224, pad_if_needed=True),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))])
        self.transform = transform

    def _encode_tags(self, img_id):
        """Return the tags of ``img_id`` as a 1-D long tensor of word ids.

        Tags are lower-cased before the vocabulary lookup; a tag missing from
        the vocabulary raises ``KeyError``.
        """
        word2id = self.vocab['word_map']
        ids = [word2id[token.lower()] for token in self.img_tags[str(img_id)]]
        # Word ids are indices, so store them as integers rather than floats.
        return torch.tensor(ids, dtype=torch.long)

    def __getitem__(self, index):
        """Returns one data pair (image, tag-id tensor)."""
        entry = self.origin_file['images'][self._positions[index]]
        path = os.path.join(self.root, entry['filepath'], entry['filename'])

        # The context manager closes the underlying file as soon as the RGB
        # copy has been made.
        with Image.open(path) as raw:
            image = raw.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        target = self._encode_tags(entry['imgid'])
        return image, target

    def __len__(self):
        """Number of images in the selected split."""
        return len(self.images_id)


def collate_fn(data):
    """Creates mini-batch tensors from a list of (image, caption) tuples.

    A custom collate_fn is needed because the default one cannot merge
    variable-length captions; here they are zero-padded to a common length.

    Args:
        data: iterable of (image, caption) tuples. It is not modified.
            - image: torch tensor; all images in the batch must share one
              shape, e.g. (3, 224, 224).
            - caption: 1-D torch tensor of word ids; variable length.

    Returns:
        images: torch tensor of shape (batch_size, *image_shape).
        targets: long tensor of shape (batch_size, padded_length), padded
            with zeros on the right.
        lengths: list; valid length of each padded caption, in descending
            order.

    Raises:
        ValueError: if the batch is empty.
    """
    # NOTE: the parameter name shadows the ``torch.utils.data`` alias inside
    # this function; it is kept because it is part of the signature.
    # sorted() (unlike list.sort) leaves the caller's batch list untouched.
    ordered = sorted(data, key=lambda pair: len(pair[1]), reverse=True)
    if not ordered:
        raise ValueError("collate_fn received an empty batch")
    images, captions = zip(*ordered)

    # Merge images (from tuple of 3D tensor to 4D tensor).
    images = torch.stack(images, 0)

    # Merge captions (from tuple of 1D tensor to 2D tensor).
    lengths = [len(cap) for cap in captions]
    targets = torch.zeros(len(captions), max(lengths), dtype=torch.long)
    for row, (cap, end) in enumerate(zip(captions, lengths)):
        targets[row, :end] = cap[:end]
    return images, targets, lengths


def get_loader(root, origin_file, split,img_tags, vocab, batch_size, shuffle, num_workers):
    """Build a torch.utils.data.DataLoader over the custom COCO dataset.

    Every iteration of the returned loader yields (images, captions, lengths):
        images: a tensor of shape (batch_size, 3, 224, 224).
        captions: a tensor of shape (batch_size, padded_length).
        lengths: a list of length batch_size with each caption's valid length.
    """
    dataset = CocoDataset(root, origin_file, split, img_tags, vocab)

    # collate_fn pads the variable-length captions of each batch.
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,
    )

In [3]:
# Input locations.  NOTE(review): `root` is a machine-specific absolute path;
# adjust it before running on another machine.
root = '/home/lkk/datasets/coco2014'
origin_file = '/'.join((root, 'dataset_coco.json'))
img_tags, voc = './img_tags.json', './vocab.json'

In [4]:
# Training loader: batches of 8, shuffled, loaded by one worker process.
d = get_loader(
    root=root,
    origin_file=origin_file,
    split='train',
    img_tags=img_tags,
    vocab=voc,
    batch_size=8,
    shuffle=True,
    num_workers=1,
)

In [10]:
# Placeholder; the preview loop below stores a batch of targets here.
test = None

In [11]:
# Peek at the first three batches: print the image-batch shape, the padded
# target shape and the per-sample lengths; `test` keeps the latest targets.
for i, batch in enumerate(d):
    a, b, c = batch
    test = b
    print(a.shape)
    print(b.shape)
    print(c)
    if i >= 2:
        break

torch.Size([8, 3, 224, 224])
torch.Size([8, 22])
[22, 20, 18, 17, 16, 16, 11, 9]
torch.Size([8, 3, 224, 224])
torch.Size([8, 24])
[24, 19, 17, 16, 16, 13, 12, 9]
torch.Size([8, 3, 224, 224])
torch.Size([8, 20])
[20, 17, 16, 16, 15, 13, 11, 11]


In [12]:
# Display the padded target tensor kept from the last batch printed above
# (rows are ordered by descending length; trailing zeros are padding).
test

tensor([[106, 268, 431,   1,  28, 242, 943, 216, 133, 114, 334, 590, 101,  74,
          56, 198, 150, 505, 240,  12],
        [933, 105,  80, 493,  27, 556, 666,  36,  53,  84,  19,  90,  61,  76,
         201, 967,  17,   0,   0,   0],
        [435,   1,  58,  73,  24, 759,  35,  18,  77, 540, 585,  70, 144,   6,
         689,  75,   0,   0,   0,   0],
        [110, 406,   1,  28,   3,  24,  34, 295, 444, 213,   5, 523, 156,   7,
          47,  12,   0,   0,   0,   0],
        [126,  48, 482, 405,   9,  13,   4, 120,  60, 182,  33,   6,  81, 555,
           2,   0,   0,   0,   0,   0],
        [865, 292,  86,  97, 124,  14, 190,  54,  22,  95,  79, 935,  55,   0,
           0,   0,   0,   0,   0,   0],
        [ 97, 108,  10, 111, 259, 408, 139, 167,  21,  17,  12,   0,   0,   0,
           0,   0,   0,   0,   0,   0],
        [ 40, 841, 449,  13, 118, 170, 325, 461, 479,   6, 388,   0,   0,   0,
           0,   0,   0,   0,   0,   0]])

In [13]:
# Number of batches the loader yields per epoch.
len(d)

14161