In [None]:
import itertools
import os
import random
from abc import ABCMeta, abstractmethod
from contextlib import redirect_stdout
from sys import stderr
import matplotlib.pyplot as plt

from pycocotools.coco import COCO
from skimage.io import imread

# Root directory of the COCO test fixtures (annotation JSON files and images).
# The default path looks docker-specific, so it can be overridden with the
# ABYSS_DATA_DIR environment variable; when unset, the original path is used.
DATA_DIR = os.environ.get("ABYSS_DATA_DIR", "/home/docker/src/abyss/deep-learning/data")

# COCO data format

In [None]:
from abyss_deep_learning.datasets.coco import ImageClassificationDataset
from abyss_deep_learning.datasets.translators import AnnotationTranslator

# Test COCO Realisations

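## Data type: Classification, Task: Classification

Each annotation's `caption` field holds comma-separated labels; `BasicCsvCaptions` splits it on commas to produce the image's labels.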

In [None]:
class BasicCsvCaptions(AnnotationTranslator):
    '''Translator that turns a comma-separated ``caption`` field into labels.

    Annotations lacking a ``caption`` key are rejected by ``filter``. For the
    remaining ones, ``translate`` splits the caption string on ``,`` (without
    stripping whitespace) and returns the pieces as the label list.
    '''

    def __init__(self):
        '''Stateless translator; the base-class initialiser is not invoked.'''

    def filter(self, annotation):
        '''Return True when ``annotation`` carries a ``caption`` entry.'''
        has_caption = 'caption' in annotation
        return has_caption

    def translate(self, annotation):
        '''Return the caption split on commas, e.g. ``"a,b"`` -> ``["a", "b"]``.'''
        caption_text = annotation['caption']
        return caption_text.split(',')

# Caption-based dataset: labels come from each annotation's comma-separated
# 'caption' field via BasicCsvCaptions. cached=False presumably means images are
# re-read on each access rather than kept in memory — confirm in the dataset class.
ds = ImageClassificationDataset(
    os.path.join(DATA_DIR, "coco-caption.json"),
    image_dir=DATA_DIR,
    translator=BasicCsvCaptions(),
    cached=False,
)

In [None]:
# Show one item from ds.sample() and the first item of a non-endless
# (endless=False) generator pass. next(..., None) replaces the for/break idiom
# and makes an empty generator explicit instead of silently printing nothing.
sample_image, sample_caption = ds.sample()
print("sample:", sample_image.shape, sample_caption)

print("generated:")
first_item = next(iter(ds.generator(endless=False)), None)
if first_item is None:
    print("generator yielded no items")
else:
    image, label = first_item
    print(image.shape, label)
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.set_title("generated label: {}".format(label))
    ax.axis("off")

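## Data type: Object detection, Task: Classification

Each image is labelled with the `category_id` of its segmentation annotations (from `coco-segmentation.json`), and images are resized to 299×299 by `image_transformer`.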

In [None]:
from skimage.transform import resize

def image_transformer(image, output_shape=(299, 299)):
    """Resize ``image`` to a fixed spatial size for the classifier input.

    Parameters
    ----------
    image : array-like
        Image to resize; presumably H x W or H x W x C as loaded by the
        dataset — confirm against ImageClassificationDataset.
    output_shape : tuple of int, optional
        Target (rows, cols). Defaults to (299, 299), the size previously
        hard-coded here (presumably the input size of the downstream
        network — confirm).

    Returns
    -------
    ndarray
        Output of ``skimage.transform.resize`` with ``mode='constant'`` and
        ``cval=0`` (pixels sampled outside the image are filled with 0).
        ``preserve_range`` is left at its default, so values follow
        skimage's float-image conventions rather than the input range.
    """
    return resize(image, output_shape, mode='constant', cval=0)

class CaptionsFromCatId(AnnotationTranslator):
    '''Translator that labels segmentation annotations with their category id.'''
    def __init__(self):
        # Stateless; deliberately does not call the base-class initialiser
        # (same as the original implementation).
        pass
    def filter(self, annotation):
        '''Use only annotations that have both a segmentation and a category id.

        ``translate`` reads ``annotation['category_id']``, so the key is
        required here too (mirroring BasicCsvCaptions, whose filter checks the
        key its translate reads). Presumably the dataset only calls
        ``translate`` on annotations that pass ``filter`` — confirm.
        '''
        return 'segmentation' in annotation and 'category_id' in annotation
    def translate(self, annotation):
        '''Return the annotation's ``category_id`` wrapped in a one-element list.'''
        return [annotation['category_id']]

# Segmentation-derived dataset: labels are category ids produced by
# CaptionsFromCatId; preprocess_data=image_transformer is presumably applied to
# each loaded image (resizing it), and cached=True presumably keeps loaded data
# in memory — confirm both in ImageClassificationDataset.
ds = ImageClassificationDataset(
    os.path.join(DATA_DIR, "coco-segmentation.json"),
    image_dir=DATA_DIR,
    translator=CaptionsFromCatId(),
    preprocess_data=image_transformer,
    cached=True,
)


In [None]:
# %%timeit -n1 -r3
# Uncomment the magic above (it must stay the first line of the cell) to
# compare this cell's timing between cached=True and cached=False datasets.
# next(..., None) replaces the for/break idiom and makes an empty generator
# explicit instead of silently printing nothing.
sample_image, sample_caption = ds.sample()
print("sample:", sample_image.shape, sample_caption)

print("generated:")
first_item = next(iter(ds.generator(endless=False)), None)
if first_item is None:
    print("generator yielded no items")
else:
    image, label = first_item
    print(image.shape, label)
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.set_title("generated label: {}".format(label))
    ax.axis("off")

In [None]:
# `ds` is now the segmentation-based dataset from the cells above (the caption
# dataset was rebound). print_class_stats presumably prints per-class label
# statistics — see ImageClassificationDataset for the exact output.
ds.print_class_stats()