In [5]:
# default_exp datasets

# datasets

This module contains the definitions of all the datasets used in the experiments.

> PyTorch `Dataset` classes for image captioning; currently the Flickr8k dataset.

In [1]:
#export
from torch.utils.data import Dataset
from PIL import Image

import json
from collections import namedtuple
from pathlib import Path

# The bound name matches the typename so that instances can be pickled by
# reference (pickle looks the class up by its typename in this module).
# `annotations` is kept as a backward-compatible alias for the original name.
Annotations = namedtuple('Annotations', ['image_id', 'sentences'])
annotations = Annotations


class Flickr8k(Dataset):
    """Flickr8k captioning dataset backed by an image folder and a JSON annotation file.

    Indexing yields ``(img_id, img, target)``: the image file name, the RGB image
    (after ``transform``, if given) and the list of tokenised captions (after
    ``target_transform``, if given).
    """

    # Partition names accepted by the ``split`` argument.
    SPLITS = ('train', 'val', 'test')

    def __init__(self, img_dir, ann_file, split='train', transform=None, target_transform=None):
        """
        Args:
            img_dir (str or Path): Directory that holds the Flickr image files.
            ann_file (str or Path): JSON file with the annotations. It must contain an
                ``'images'`` list whose entries provide ``'filename'``, ``'split'`` and
                ``'sentences'`` (each sentence providing ``'tokens'``).
            split ['train', 'val', 'test']: Which partition to load.
            transform: Optional callable applied to each image.
            target_transform: Optional callable applied to each image's list of captions.

        Raises:
            ValueError: If ``split`` is not one of ``SPLITS``.
        """
        # Explicit exception instead of `assert`, which is stripped under `python -O`.
        if split not in self.SPLITS:
            raise ValueError(
                "split must be one of %s, got %r" % (list(self.SPLITS), split)
            )

        self.img_dir = Path(img_dir)
        self.split = split
        self.transform = transform
        self.target_transform = target_transform

        # Separate handle name so the `ann_file` argument is not shadowed.
        with open(ann_file, 'r', encoding='utf-8') as ann_handle:
            ann_json = json.load(ann_handle)

        # Select records by their own 'split' label rather than by fixed slice
        # positions, so the result does not depend on the ordering or size of the
        # file. File order is preserved within the selected partition.
        self.annotations = [
            Annotations(
                image['filename'],
                [sentence['tokens'] for sentence in image['sentences']],
            )
            for image in ann_json['images']
            if image['split'] == self.split
        ]

        print('loading %s complete'%(self.split))

    def __len__(self):
        """Number of images in the loaded partition."""
        return len(self.annotations)

    def __getitem__(self, index):
        """Return ``(img_id, img, target)`` for the record at ``index``."""
        record = self.annotations[index]

        # `convert` returns an independent image, so the source file can be closed
        # as soon as the context manager exits.
        with Image.open(self.img_dir / record.image_id) as source:
            img = source.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        # Captions: one token list per sentence.
        target = record.sentences
        if self.target_transform is not None:
            target = self.target_transform(target)

        return record.image_id, img, target

In [2]:
# Machine-specific absolute locations of the Flickr8k image folder and the caption
# annotation JSON used by the cells that follow; adjust both to your local copy.
flickr8k_dir = '/home/jithin/datasets/imageCaptioning/flicker8k/Flicker8k_Dataset'
captions_file = '/home/jithin/datasets/imageCaptioning/captions/dataset_flickr8k.json'

In [3]:
# Load the 'val' partition and show how many images it holds.
dataset = Flickr8k(
    img_dir=flickr8k_dir,
    ann_file=captions_file,
    split='val',
)
len(dataset)

loading val complete


1000

In [4]:
# Inspect the first sample; `Flickr8k.__getitem__` returns (img_id, img, target).
# NOTE(review): the saved output below shows only (img, target), so it appears
# stale relative to the current class definition — re-run this cell to refresh it.
dataset[0]

(<PIL.Image.Image image mode=RGB size=500x333 at 0x7F84BC346BE0>,
 [['the',
   'boy',
   'laying',
   'face',
   'down',
   'on',
   'a',
   'skateboard',
   'is',
   'being',
   'pushed',
   'along',
   'the',
   'ground',
   'by',
   'another',
   'boy'],
  ['two', 'girls', 'play', 'on', 'a', 'skateboard', 'in', 'a', 'courtyard'],
  ['two', 'people', 'play', 'on', 'a', 'long', 'skateboard'],
  ['two',
   'small',
   'children',
   'in',
   'red',
   'shirts',
   'playing',
   'on',
   'a',
   'skateboard'],
  ['two',
   'young',
   'children',
   'on',
   'a',
   'skateboard',
   'going',
   'across',
   'a',
   'sidewalk']])