# API

In [4]:
from pathlib import Path
import json
import datetime
import copy
import os

# Third Parties
import numpy as np
from PIL import Image
import imgaug.augmentables as ia
import albumentations
import albumentations as A

# Torch
from torch.utils.data import Dataset
import torchvision as tv

# Types
from typing import List, Tuple, Optional
import dataclasses
from pydantic.dataclasses import dataclass
from pydantic import validate_arguments, validator, BaseModel

from effdet.east import get_input_image_and_bboxes, scale_bboxes, create_ground_truth

In [2]:
@dataclass
class COCO_Text:
    """
    Helper class for reading and querying COCO-Text annotations.

    After construction the annotation JSON is loaded into memory and indexed
    into the attributes ``dataset``, ``anns``, ``imgToAnns``, ``imgs``,
    ``cats``, ``train``, ``val`` and ``test``.

    :param annotation_file (Path): location of annotation file
    """
    # dataclasses.field() requires ``metadata`` to be a mapping (or None).
    annotation_file: Path = dataclasses.field(
        metadata={"description": "location of annotation file"})

    def __post_init_post_parse__(self,):
        """
        Initialise the (empty) indices and, when an annotation file is set,
        load it and build the indices.
        :return:
        """
        # load dataset
        self.dataset = {}
        self.anns = {}
        self.imgToAnns = {}
        self.catToImgs = {}
        self.imgs = {}
        self.cats = {}
        self.val = []
        self.test = []
        self.train = []
        if self.annotation_file is not None:
            # NOTE: asserts are stripped under ``python -O``; open() below
            # still fails (FileNotFoundError) for a missing file in that case.
            assert os.path.isfile(self.annotation_file), "file does not exist"
            print('loading annotations into memory...')
            time_t = datetime.datetime.now(datetime.timezone.utc)
            # Context manager closes the handle; JSON text is UTF-8.
            with open(self.annotation_file, 'r', encoding='utf-8') as f:
                self.dataset = json.load(f)
            print(datetime.datetime.now(datetime.timezone.utc) - time_t)
            self.createIndex()

    @staticmethod
    def _as_list(value):
        # Normalise a filter argument: None -> [], list -> unchanged,
        # any other single value (e.g. an int id or one tuple) -> [value].
        if value is None:
            return []
        return value if isinstance(value, list) else [value]

    def createIndex(self):
        """
        Build the lookup tables from ``self.dataset``.

        JSON object keys are strings, so ids are converted to ``int``.
        Images are also grouped by their ``'set'`` entry into ``self.train``,
        ``self.val`` and ``self.test``.
        :return:
        """
        # create index
        print('creating index...')
        raw_imgs = self.dataset['imgs']
        self.imgToAnns = {int(cocoid): annids
                          for cocoid, annids in self.dataset['imgToAnns'].items()}
        self.imgs = {int(cocoid): img for cocoid, img in raw_imgs.items()}
        self.anns = {int(annid): ann
                     for annid, ann in self.dataset['anns'].items()}
        self.cats = self.dataset['cats']

        # Single pass over the images instead of one pass per split.
        splits = {'val': [], 'test': [], 'train': []}
        for cocoid, img in raw_imgs.items():
            split = img['set']
            if split in splits:
                splits[split].append(int(cocoid))
        self.val = splits['val']
        self.test = splits['test']
        self.train = splits['train']
        print('index created!')

    def info(self):
        """
        Print information about the annotation file.
        :return:
        """
        for key, value in self.dataset['info'].items():
            print(f'{key}: {value}')

    def filtering(self, filterDict, criteria):
        """
        Return the keys of ``filterDict`` whose value satisfies every criterion.
        :param filterDict (dict)  : mapping to filter
        :param criteria (iterable): callables taking a value and returning a truthy/falsy result
        :return: keys (list)      : keys in the iteration order of ``filterDict``
        """
        return [key for key in filterDict
                if all(criterion(filterDict[key]) for criterion in criteria)]

    def getAnnByCat(self, properties):
        """
        Get ann ids that satisfy given properties
        :param properties (list of tuples of the form [(category type, category)] e.g., [('readability','readable')]
            : get anns for given categories - anns have to satisfy all given property tuples
        :return: ids (int array)       : integer array of ann ids
        """
        # Bind a/b as lambda defaults so each predicate keeps its own pair
        # (avoids the late-binding closure pitfall).
        return self.filtering(
            self.anns, [lambda d, x=a, y=b: d[x] == y for (a, b) in properties])

    def getAnnIds(self, imgIds=None, catIds=None, areaRng=None):
        """
        Get ann ids that satisfy given filter conditions. default skips that filter
        :param imgIds  (int or int list) : get anns for given imgs
               catIds  (list of tuples of the form [(category type, category)] e.g., [('readability','readable')]
                : get anns for given cats
               areaRng (float array)   : get anns whose area lies strictly inside the range (e.g. [0 inf])
        :return: ids (int list)        : ann ids, in index order
        """
        imgIds = self._as_list(imgIds)
        catIds = self._as_list(catIds)
        areaRng = [] if areaRng is None else areaRng

        if imgIds:
            # Flatten the per-image ann lists; unknown image ids are skipped.
            anns = [annid
                    for imgId in imgIds if imgId in self.imgToAnns
                    for annid in self.imgToAnns[imgId]]
        else:
            anns = list(self.anns)

        if catIds:
            wanted = set(self.getAnnByCat(catIds))
            # Order-preserving, de-duplicated intersection.
            anns = list(dict.fromkeys(a for a in anns if a in wanted))

        if len(areaRng) != 0:
            lower, upper = areaRng[0], areaRng[1]
            anns = [a for a in anns if lower < self.anns[a]['area'] < upper]
        return anns

    def getImgIds(self, imgIds=None, catIds=None):
        '''
        Get img ids that satisfy given filter conditions.
        :param imgIds (int or int list) : restrict the result to these ids
        :param catIds (list of tuples of the form [(category type, category)])
            : keep imgs having at least one ann that satisfies all given tuples
        :return: ids (int list)  : list of img ids without duplicates
        '''
        imgIds = self._as_list(imgIds)
        catIds = self._as_list(catIds)

        if not imgIds and not catIds:
            return list(self.imgs)
        if not catIds:
            return list(dict.fromkeys(imgIds))

        matching = dict.fromkeys(
            self.anns[annid]['image_id'] for annid in self.getAnnByCat(catIds))
        if not imgIds:
            # No image restriction: every image with a matching ann qualifies.
            # (Intersecting with an empty id set would always yield nothing.)
            return list(matching)
        return [imgId for imgId in dict.fromkeys(imgIds) if imgId in matching]

    def loadAnn(self, img_id: int):
        """
        Load all anns that belong to the specified image.
        :param img_id (int)          : integer id specifying img
        :return: anns (object array) : loaded ann objects of that img
        """
        return [self.anns[ann_id] for ann_id in self.getAnnIds(img_id)]

    def loadImg(self, id: int):
        """
        Load the img with the specified id.
        :param id (int)        : integer id specifying img
        :return: img (object)  : loaded img object
        """
        return self.imgs[id]


In [19]:
@dataclasses.dataclass
class COCOTextDataset(Dataset):
    """
    Torch ``Dataset`` yielding ``(image, ground_truth)`` pairs built from
    COCO-Text images and their text bounding boxes.

    Each item is produced by loading the image and its annotations,
    optionally applying the albumentations pipeline to image and boxes, and
    then passing the result through ``get_input_image_and_bboxes``,
    ``scale_bboxes`` and ``create_ground_truth`` from ``effdet.east``.
    """

    # ids of the images served by this dataset (indexed by item position)
    img_ids: List[int]
    # directory containing the image files named by each img's 'file_name'
    img_dir: str
    # annotation index used to look up image records and their annotations
    ct: COCO_Text
    # forwarded to get_input_image_and_bboxes / create_ground_truth
    image_size: Tuple[int, int] = dataclasses.field(default=(768, 768))
    # forwarded to get_input_image_and_bboxes / create_ground_truth
    scale: int = dataclasses.field(default=4)
    # optional augmentation pipeline, called with ``image=`` and ``bboxes=``
    transforms: Optional[albumentations.Compose] = dataclasses.field(
        default=None)

    # Unannotated on purpose: a plain class attribute, not a dataclass field.
    # Converts the image to a tensor, then maps values v -> 2 * v - 1.
    to_torch = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Lambda(lambda x: 2 * x - 1)
    ])

    def __len__(self,):
        """Return the number of images in the dataset."""
        return len(self.img_ids)

    def _img_and_annotations(self, idx):
        """
        Load the RGB image and the annotations for the item at ``idx``.
        :param idx (int): position in ``self.img_ids``
        :return: (image, anns): image as a numpy array, list of ann objects
        """
        # Use the instance's own annotation index, not a module-level global.
        img = self.ct.loadImg(self.img_ids[idx])
        ann = self.ct.loadAnn(img['id'])

        path = os.path.join(self.img_dir, img['file_name'])
        # Context manager closes the underlying file once pixels are copied.
        with Image.open(path) as pil_image:
            image = np.array(pil_image.convert('RGB'))
        return image, ann

    def _make_bboxes_from_annotations(self, anns, shape):
        """
        Convert annotation ``'bbox'`` entries (x, y, width, height) into
        imgaug bounding boxes with corner coordinates.
        :param anns (list): ann objects, each with a 'bbox' entry
        :param shape (tuple): shape of the image the boxes live on
        :return: bboxes (BoundingBoxesOnImage)
        """
        bboxes = []
        for a in anns:
            x0, y0, w, h = a['bbox']
            y1, x1 = y0 + h, x0 + w
            bboxes.append(ia.BoundingBox(x0, y0, x1, y1))
        bboxes = ia.BoundingBoxesOnImage(bboxes, shape=shape)
        return bboxes

    def __getitem__(self, idx):
        """
        Build the training pair for the item at ``idx``.
        :param idx (int): position in ``self.img_ids``
        :return: (image, gt_image): image after ``to_torch`` and the first
            value returned by ``create_ground_truth``
        """
        image, anns = self._img_and_annotations(idx)
        bboxes = self._make_bboxes_from_annotations(anns, image.shape)

        # HERE COMES THE AUGMENTATIONS
        if self.transforms is not None:
            transformed = self.transforms(
                image=image, bboxes=bboxes.to_xyxy_array())
            # Keep the augmented image: the augmented boxes are expressed in
            # its coordinate frame, so the pair must stay together.
            image = transformed['image']
            bboxes = ia.BoundingBoxesOnImage.from_xyxy_array(
                transformed['bboxes'], image.shape)

        # SCALES BOXES AND GENERATE GT
        image, bboxes = get_input_image_and_bboxes(
            image, bboxes, self.image_size, self.scale)
        bboxes, mask_boxes = scale_bboxes(bboxes)
        gt_image, _ = create_ground_truth(
            bboxes, mask_boxes, self.image_size, self.scale)

        image = self.to_torch(image)

        return image, gt_image

In [20]:
# Load the COCO-Text annotation file and build the lookup indices.
ct = COCO_Text('COCO/COCO_Text.json')

loading annotations into memory...
0:00:01.378568
creating index...
index created!


In [21]:
# Print the key/value pairs of the annotation file's 'info' section.
ct.info()

url: http://vision.cornell.edu/se3/coco-text/
date_created: 2017-03-28
version: 1.4
description: This is 1.4 version of the 2017 COCO-Text dataset.
author: COCO-Text group


In [22]:
# Augmentation pipeline: a single random 250x250 crop. Bounding boxes are
# passed in 'pascal_voc' format (x_min, y_min, x_max, y_max) with no
# accompanying label fields.
transforms = A.Compose([
    A.RandomCrop(250, 250)
], bbox_params=A.BboxParams(format='pascal_voc',  label_fields=[]))

In [23]:
# With no filters, getImgIds() returns the id of every indexed image.
imgIds = ct.getImgIds()
# Directory the image files are read from.
dataDir = 'COCO/train2014/'

In [24]:
# Tiny two-image dataset for checking the pipeline end to end.
ds_debug = COCOTextDataset(imgIds[:2], dataDir, ct, transforms=transforms)

In [25]:
# Fetch the first sample: x is the input image, y the generated ground truth.
x, y = ds_debug[0]

In [26]:
# Inspect the shapes of the image tensor and the ground-truth array.
x.shape, y.shape

(torch.Size([3, 768, 768]), (5, 192, 192))