In [2]:
import torch
import torchvision.transforms as transforms
import torch.utils.data as data
import os
import pickle
import numpy as np
from PIL import Image
import json
import h5py

In [5]:
class CocoDataset(data.Dataset):
    """COCO image/tag dataset compatible with torch.utils.data.DataLoader.

    Each item is an ``(image, target)`` pair: the image file listed in the
    annotation file, passed through ``self.transform``, and a 1-D tensor of
    vocabulary ids for the image's tags.
    """

    # Public split name -> split labels accepted from the annotation file.
    # 'train' and 'restval' are merged into a single training set.
    _SPLITS = {
        'train': ('train', 'restval'),
        'restval': ('train', 'restval'),
        'val': ('val',),
        'test': ('test',),
    }

    def __init__(self, root, origin_file, split, img_tags, vocab):
        """Load annotations, tags and vocabulary and select one split's images.

        Args:
            root: image directory; image paths in ``origin_file`` are joined
                onto it.
            origin_file: path to the JSON annotation file. Its ``'images'``
                list entries are read for ``'imgid'``, ``'split'``,
                ``'filepath'`` and ``'filename'``.
            split: one of 'train', 'restval' (both select train + restval),
                'val' or 'test'.
            img_tags: path to a JSON file mapping ``str(imgid)`` to a list of
                tag words.
            vocab: path to a JSON file whose ``'word_map'`` maps word -> id.

        Raises:
            ValueError: if ``split`` is not one of the names above.
        """
        if split not in self._SPLITS:
            raise ValueError('unknown split %r; expected one of %s'
                             % (split, sorted(self._SPLITS)))
        self.root = root
        self.split = list(self._SPLITS[split])

        with open(origin_file, 'r') as j:
            self.origin_file = json.load(j)
        images = self.origin_file['images']

        # imgids of the images belonging to the requested split.
        self.images_id = [img['imgid'] for img in images
                          if img['split'] in self.split]
        # imgid -> position in the 'images' list, so lookups do not rely on
        # an imgid happening to equal its list index.
        self._imgid_to_index = {img['imgid']: pos
                                for pos, img in enumerate(images)}

        with open(img_tags, 'r') as j:
            self.img_tags = json.load(j)

        with open(vocab, 'r') as j:
            self.vocab = json.load(j)

        # NOTE(review): the same random crop/flip augmentation is applied to
        # every split, including 'val' and 'test' -- confirm this is intended.
        self.transform = transforms.Compose([
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))])

    def __getitem__(self, index):
        """Return the ``(image, target)`` pair for the index-th image of the split.

        Returns:
            image: the RGB image after ``self.transform`` (a 3 x 224 x 224
                tensor with the default transform).
            target: 1-D LongTensor of vocabulary ids, one per tag.

        Raises:
            KeyError: if the image has no entry in ``img_tags`` or a tag is
                missing from the vocabulary's ``word_map``.
        """
        word2id = self.vocab['word_map']
        img_id = self.images_id[index]
        entry = self.origin_file['images'][self._imgid_to_index[img_id]]

        path = os.path.join(self.root, entry['filepath'], entry['filename'])
        # The context manager closes the file handle once the pixels are
        # copied out by convert().
        with Image.open(path) as raw:
            image = raw.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # Tags are lower-cased before lookup so they match vocabulary keys.
        tags = [word2id[token.lower()] for token in self.img_tags[str(img_id)]]
        target = torch.tensor(tags, dtype=torch.long)
        return image, target

    def __len__(self):
        """Number of images in the selected split."""
        return len(self.images_id)


def collate_fn(data):
    """Merge a list of (image, tags) samples into a zero-padded mini-batch.

    The default collate_fn cannot stack variable-length tag sequences, so the
    tag tensors are right-padded with 0 here.

    Args:
        data: list of (image, tags) tuples, where
            - image: tensor; all images must share one shape
              (3 x 224 x 224 with CocoDataset's transform);
            - tags: 1-D tensor of word ids; variable length.

    Returns:
        images: tensor of shape (batch_size, *image.shape).
        targets: LongTensor of shape (batch_size, max_length), zero-padded.
        lengths: list of the unpadded length of each row of ``targets``.

    Samples are ordered by tag length, longest first. The caller's list is
    not reordered.
    """
    # sorted() instead of list.sort() so the caller's list keeps its order.
    batch = sorted(data, key=lambda sample: len(sample[1]), reverse=True)
    images, tag_seqs = zip(*batch)

    # Tuple of 3-D tensors -> one 4-D tensor.
    images = torch.stack(images, 0)

    # Tuple of 1-D tensors -> one zero-padded 2-D tensor.
    lengths = [len(seq) for seq in tag_seqs]
    targets = torch.zeros(len(tag_seqs), max(lengths), dtype=torch.long)
    for row, (seq, end) in enumerate(zip(tag_seqs, lengths)):
        targets[row, :end] = seq[:end]
    return images, targets, lengths


def get_loader(root, origin_file, split, img_tags, vocab, batch_size, shuffle, num_workers):
    """Build a DataLoader over a CocoDataset for one split.

    Each iteration yields ``(images, targets, lengths)`` as produced by
    ``collate_fn``:
        images:  tensor of shape (batch_size, 3, 224, 224);
        targets: tensor of shape (batch_size, padded_length);
        lengths: list with the valid length of each row of ``targets``.
    """
    dataset = CocoDataset(root=root, origin_file=origin_file, split=split,
                          img_tags=img_tags, vocab=vocab)
    loader_options = dict(batch_size=batch_size,
                          shuffle=shuffle,
                          num_workers=num_workers,
                          collate_fn=collate_fn)
    return torch.utils.data.DataLoader(dataset=dataset, **loader_options)

In [6]:
# Dataset locations. Set the COCO_ROOT environment variable to point at the
# image/annotation root on another machine; the default is the original path.
root = os.environ.get('COCO_ROOT', '/home/lkk/datasets/coco2014')
origin_file = os.path.join(root, 'dataset_coco.json')
img_tags = './img_tags.json'
voc = './vocab.json'

In [7]:
# Shuffled training loader: 8 images per batch, one worker process.
d = get_loader(root, origin_file, 'train', img_tags, voc,
               batch_size=8, shuffle=True, num_workers=1)

In [9]:
dd = None  # placeholder for one inspected batch, filled by the next cell

In [19]:
# Pull a single batch for inspection without rebinding `d`, so the loader
# stays usable and this cell can be re-run.
dd = next(iter(d))
# images, zero-padded tag ids, valid lengths (the tuple returned by collate_fn)
a, b, _ = dd

3

In [15]:
# Image batch shape: (batch_size, channels, height, width).
a.size()

torch.Size([8, 3, 224, 224])

In [17]:
b  # presumably the padded tag-id targets of the inspected batch (output rows are zero-padded)

tensor([[147, 766,  62, 508,  32, 126, 235,  30, 104,  58, 129,  24,  10,   5,
          23, 299,  42,  35, 467, 360, 514, 516,   6,  75,   1, 285],
        [791,  19, 263,  38,  95, 239, 120,  32, 248,  87,  53,   3, 129,  17,
         111,   4,  18, 393, 193, 165, 653, 119, 557, 135,  70,  55],
        [391, 370,   9,  15,   3, 127, 312,  96,  23,   2, 222,  36, 611, 246,
          13, 182, 407, 187, 245, 476,  94,   0,   0,   0,   0,   0],
        [  2, 437,  53,   9,  19,  10,   6, 312,  44, 155, 822, 786,  15, 553,
          26, 169,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
        [ 11,  89, 180, 721,   3,   4,  18, 195, 133, 134,  39, 139,  47,  45,
         179,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
        [ 11, 273, 957,   1,  24, 356,   4, 367, 277, 659, 101, 272,  22,   6,
          47,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
        [712,  87, 625,  43, 476, 100,   3, 127,  34, 658, 673, 173, 197, 201,
         480,   0,   0,   0

In [4]:
# Rebuild the train + restval imgid list outside the Dataset for inspection.
# The annotation file is loaded here so the cell does not depend on a `file`
# variable left over in the kernel.
with open(origin_file, 'r') as j:
    file = json.load(j)
images_id = [img['imgid'] for img in file['images']
             if img['split'] in ['train', 'restval']]

In [5]:
len(images_id)  # number of images in the train + restval splits

113287

In [6]:
# Tag file: maps str(imgid) -> list of tag words.
with open(img_tags) as tags_fh:
    img_tagss = json.load(tags_fh)

In [7]:
len(img_tagss)  # tagged images; larger than len(images_id), presumably because every split is included

123287

In [11]:
# Snapshot the keys so entries can be popped from the dict while looping.
keys = list(img_tagss)

In [62]:
# NOTE(review): this reports 10000 entries while In [7] reported 123287; img_tagss was
# changed by cells not shown in this notebook (hidden kernel state) -- confirm before relying on it.
len(img_tagss)

10000

In [None]:
# Keep only the tags of train/restval images. images_id holds integer imgids
# while JSON object keys are strings, so membership is checked on strings;
# a set makes each check O(1) instead of scanning the list.
train_keys = {str(img_id) for img_id in images_id}
for ID in keys:
    if str(ID) not in train_keys:
        # default of None keeps the cell re-runnable once keys are already gone
        img_tagss.pop(str(ID), None)