# Example usage of IDGDatasetBase

IDGDatasetBase is a dataset loader for image captioning.

## Parameters

* dataset_path
        
        Path to the dataset created by shells/pre_process.sh.
        By default, the dataset is saved to data/captions/converted/XXX/YYY.pkl.
        
* vocab_path
        
        Path to the vocabulary dictionary created by shells/pre_process.sh.
        By default, the vocabulary dictionary is saved to data/vocab/XXX.pkl.

* img_root

        Path to the root directory that contains the MSCOCO images.
        This has to be specified when raw_img is True.
        
* img_feature_root

        Path to the root directory that contains the image features.
        This has to be specified when raw_img is False.

* raw_caption

        When raw_caption is True, it returns a list of tokenized captions.
        If False, it returns the encoded caption as a numpy.ndarray.

* raw_img

        When raw_img is True, it uses raw images downloaded from the MSCOCO dataset.
        If False, it uses image features processed beforehand.
        However, this repository doesn't contain preprocessed image features,
        so it can't be False.

* img_mean

        This parameter is used to preprocess images.
        The default value is "imagenet": the ImageNet mean value is subtracted from each image.
        If None, the original RGB values are used.
        You can also specify per-channel mean values explicitly, e.g. (123.68, 116.78, 103.94).

* img_size
        
        Output size of the image. The default size is (244, 244).
        If an image is larger or smaller than img_size, it is automatically resized.

* preload_features

        When preload_features is True, all features preprocessed beforehand are preloaded into RAM.
        This consumes a lot of RAM: if the dataset is MSCOCO train 2014 and each feature size is 2048,
        it takes about 7GB.
        However, this parameter can't be used either, because this repository doesn't contain
        preprocessed image features.
        If you want to use preprocessed image features, you need to extract them yourself, e.g. with ResNet.

In [3]:
import chainer
from IDGDataset import IDGDatasetBase

In [4]:
# set paths
# If you run shells/pre_process.sh, the converted dataset (image paths plus
# captions) and the vocabulary dictionary are saved to the locations below.

dataset_path = "data/captions/converted/MSCOCO_captions/train2014.pkl"
vocab_path = "data/vocab/mscoco_train2014_vocab.pkl"
# Root of the raw MSCOCO images; used because raw_img is True below.
img_root = "data/images/original"
# With raw_img=False you would pass img_feature_root instead, e.g.:
# img_feature_root = "data/images/ResNet50"

In [5]:
# set other configurations
# see help(IDGDatasetBase) for details.
raw_caption = False    # False: captions are returned encoded, not as token lists
raw_img = True         # True: read raw MSCOCO images from img_root
img_mean = "imagenet"  # subtract the ImageNet mean from each image
img_size = (224, 224)  # images are resized to this output size
# preload_features only affects preprocessed image features, which this
# repository does not ship. With raw_img=True there is nothing to preload,
# so keep it off rather than requesting a large in-RAM preload.
preload_features = False

In [6]:
# Load the dataset. IDGDatasetBase wraps chainer.dataset.DatasetMixin, so
# train_data[i] returns one (img, caption) example.
train_data = IDGDatasetBase(
    dataset_path,
    vocab_path,
    img_root=img_root,
    raw_caption=raw_caption,
    raw_img=raw_img,
    img_mean=img_mean,
    img_size=img_size,
    preload_features=preload_features,
)

In [11]:
# Indexing the dataset returns a (preprocessed image, encoded caption) pair;
# index 1 is used here only as an example sample.
img, caption = train_data[1]

print(img)
print(caption)

[[[ -85.255325    -81.398186    -73.19411    ...  -14.928864
     -3.4185715    35.591484  ]
  [ -82.90839     -93.30635     -76.30635    ...   -9.265175
      7.305847     47.3059    ]
  [ -83.29614     -70.13288     -89.00023    ...   -6.489624
     -3.969635     -7.7549973 ]
  ...
  [ -89.29612     -89.04106     -92.7349     ...  146.63242
    146.6528      145.20386   ]
  [ -84.09198     -84.88788     -84.17357    ...  146.81613
    145.63248     146.13248   ]
  [ -89.22446     -77.29615     -88.765656   ...  148.16312
    142.32632     142.47937   ]]

 [[ -78.58512     -80.09532     -63.83002    ...  -30.493385
    -19.513618     36.353554  ]
  [ -83.33002     -83.360634    -79.08512    ...  -24.074593
     -7.656563     46.28225   ]
  [ -77.09532     -76.47288     -77.860634   ...  -11.809265
    -15.411659    -13.941849  ]
  ...
  [-100.717766    -98.31984     -97.72798    ...   51.363853
     48.75158      46.221     ]
  [ -98.197395   -100.156525    -92.411606   ...   48.11904

In [12]:
# Round trip between the two caption representations:
#   index2token: encoded caption (word indices) -> list of tokens
#   token2index: list of tokens -> encoded caption (word indices)
img, caption = train_data[1]

dec_caption = train_data.index2token(caption)
print('decoded caption\n{}'.format(dec_caption))

re_encoded_caption = train_data.token2index(dec_caption)
print('encoded caption\n{}'.format(re_encoded_caption))

decoded caption
['<SOS>', 'a', 'long', 'restaurant', 'table', 'with', '<UNK>', 'rounded', 'back', 'chairs', '<EOS>']
encoded caption
[1, 3, 248, 340, 23, 8, 0, 4069, 164, 285, 2]


In [13]:
# You can get the list of words via get_word_ids. Note that it is read as an
# attribute (no call parentheses); len() is applied to it in the next cell.
word_ids = train_data.get_word_ids

print(word_ids)



In [14]:
# Vocabulary size: how many entries word_ids holds.
num_words = len(word_ids)

print('The number of words: {:d}'.format(num_words))

The number of words: 8823


In [15]:
# You can get the <UNK> ratio via get_unk_ratio (presumably the fraction of
# caption tokens mapped to <UNK> -- confirm in IDGDataset). Like get_word_ids,
# it is read as an attribute, not called; the value is a number.
unk_ratio = train_data.get_unk_ratio

print('unk ratio: %.3f' % unk_ratio)

unk ratio: 0.061


In [17]:
# Just pass train_data to chainer.iterators.SerialIterator. Each call to
# next(train_iter) yields a list of (img, caption) examples, batchsize per batch.
batchsize = 128
train_iter = chainer.iterators.SerialIterator(train_data, batchsize)

[(array([[[125.44494   , 126.18566   , 127.19463   , ..., 122.03393   ,
           122.11216   , 121.677055  ],
          [125.44492   , 125.061     , 126.061     , ..., 122.061     ,
           123.04983   , 121.68133   ],
          [125.44494   , 125.061     , 124.00518   , ..., 121.375725  ,
           122.061     , 122.061     ],
          ...,
          [-47.308098  , -52.520134  , -35.594345  , ...,  -1.64357   ,
           -20.79663   ,  -9.23629   ],
          [-28.163506  , -21.274345  , -38.323265  , ...,  -0.43891144,
            -4.829262  , -14.042778  ],
          [-30.009758  , -24.83329   , -11.090401  , ...,  -3.5963516 ,
             4.012726  , -18.525627  ]],
  
         [[ 91.279045  ,  90.89512   ,  91.19483   , ...,  86.19392   ,
            87.27216   ,  86.83706   ],
          [ 92.58262   ,  91.35992   ,  92.130394  , ...,  86.221     ,
            88.20985   ,  86.84135   ],
          [ 90.60493   ,  90.90627   ,  91.96133   , ...,  86.221     ,
            8