# COCO dataset
> Dataclasses defining the COCO dataset and how to convert it to/from a dict.

In [None]:
# default_exp coco

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# hide
from IPython.display import display

In [None]:
# export

import logging
import random
from abc import abstractmethod
from datetime import datetime
from dataclasses_json import dataclass_json
from dataclasses import fields, asdict, field, replace
from pydantic.dataclasses import dataclass
from typing import *
from pathlib import Path
from cocorepr.utils import sanitize_filename

Executing <Handle IOLoop.add_future.<locals>.<lambda>(<Future finis...queues.py:248>) at /Users/ysem/miniconda3/envs/cocorepr36/lib/python3.6/site-packages/tornado/ioloop.py:688 created at /Users/ysem/miniconda3/envs/cocorepr36/lib/python3.6/site-packages/tornado/concurrent.py:184> took 0.102 seconds


In [None]:
# export
# Module-scoped logger (instead of the root logger): records carry this
# module's name and still propagate to the handlers configured on the root.
logger = logging.getLogger(__name__)

In [None]:
# export
@dataclass_json
@dataclass
class CocoElement:
    """ Common base of all COCO entities defined in this module.

        Subclasses are expected to override `collection_name` and `is_valid`;
        the implementations here only raise `NotImplementedError`.
    """

    def to_dict_skip_nulls(self):
        """ Dump the element to a dict like `self.to_dict()`, leaving out
            every field whose value is `None` (nested elements included).
            WARNING! If you explicitly set a `None` value to a field
                     that has a non-`None` default value, it still
                     won't be dumped and will be deserialized wrongly.
        """
        def _without_nulls(pairs):
            # `asdict` passes the (field name, value) pairs of each dataclass
            # it converts, so nested elements are filtered as well.
            return {name: value for name, value in pairs if value is not None}

        return asdict(self, dict_factory=_without_nulls)

    def is_valid(self) -> bool:
        raise NotImplementedError

    @property
    def collection_name(self) -> str:
        raise NotImplementedError

In [None]:
# hide
@dataclass
class SampleCocoElement(CocoElement):
    a: Optional[int] = None
    b: int = 3

# 'b' has default 3, so we have to preserve it
d = SampleCocoElement(a=2).to_dict_skip_nulls()
assert d == {'a': 2, 'b': 3}, d

# 'a' keeps its default None, so it is skipped
d = SampleCocoElement().to_dict_skip_nulls()
assert d == {'b': 3}, d

# 'b' gets non-default value 10, so we have to preserve it
d = SampleCocoElement(b=10).to_dict_skip_nulls()
assert d == {'b': 10}, d

# 'b' gets non-int value '10', but pydantic parses it to int
d = SampleCocoElement(b='10').to_dict_skip_nulls()
assert d == {'b': 10}, d

# 'b' gets non-int value 'abrakadabra', so pydantic cannot parse it and fails
try:
    d = SampleCocoElement(b='abrakadabra').to_dict_skip_nulls()
except Exception:
    # Expected path. `Exception` (not a bare `except:`) so that
    # KeyboardInterrupt / SystemExit are not swallowed by the test.
    pass
else:
    raise RuntimeError("Pydantic did not raise error!")

In [None]:
# hide
@dataclass
class SampleCocoElementList(CocoElement):
    c: List[SampleCocoElement]

# 'b' has default 3, so we have to preserve it
d = SampleCocoElementList(c=[SampleCocoElement()]).to_dict_skip_nulls()
assert d == {'c': [{'b': 3}]}, d

In [None]:
# export

@dataclass
class CocoInfo(CocoElement):
    """ The `info` record of a dataset; every field is optional. """
    year: Optional[int] = None
    version: Optional[str] = None
    description: Optional[str] = None
    contributor: Optional[str] = None
    url: Optional[str] = None
    date_created: Optional[str] = None

    def is_valid(self) -> bool:
        # no restrictions on the format
        return True

    @property
    def collection_name(self) -> str:
        return "info"

In [None]:
coco_info_dict = {
 'year': 2020,
 'version': 'v1',
 'description': 'desc',
 'contributor': 'me',
 'url': 'http://url',
}

coco_info = CocoInfo.from_dict(coco_info_dict)
coco_info_dict2 = coco_info.to_dict_skip_nulls()

display(coco_info)
display(coco_info_dict2)
assert coco_info_dict2 == coco_info_dict, coco_info_dict2

assert coco_info.collection_name == 'info'
assert coco_info.is_valid()

CocoInfo(year=2020, version='v1', description='desc', contributor='me', url='http://url', date_created=None)

{'year': 2020,
 'version': 'v1',
 'description': 'desc',
 'contributor': 'me',
 'url': 'http://url'}

In [None]:
# export

@dataclass
class CocoLicense(CocoElement):
    """ One entry of the `licenses` collection; `id` and `name` are required. """
    id: str
    name: str
    url: Optional[str] = None

    def is_valid(self) -> bool:
        # no restrictions on the format
        return True

    @property
    def collection_name(self) -> str:
        return "licenses"

In [None]:
coco_license_dict = {
 'id': '1',
 'name': 'Attribution-NonCommercial-NoDerivatives 4.0 International',
 'url': 'https://creativecommons.org/licenses/by-nc-nd/4.0/',
}

coco_license = CocoLicense.from_dict(coco_license_dict)
coco_license_dict2 = coco_license.to_dict_skip_nulls()
display(coco_license)
display(coco_license_dict2)

assert coco_license.is_valid()

assert coco_license_dict2 == coco_license_dict, coco_license_dict2
assert coco_license.collection_name == 'licenses'

# --
# test minimal required fields
assert CocoLicense.from_dict({'id': 2, 'name': 'Apache 2.0'}) == CocoLicense(id=2, name='Apache 2.0')

CocoLicense(id='1', name='Attribution-NonCommercial-NoDerivatives 4.0 International', url='https://creativecommons.org/licenses/by-nc-nd/4.0/')

{'id': '1',
 'name': 'Attribution-NonCommercial-NoDerivatives 4.0 International',
 'url': 'https://creativecommons.org/licenses/by-nc-nd/4.0/'}

In [None]:
# export

@dataclass
class CocoImage(CocoElement):
    """ One entry of the `images` collection; `id` and `coco_url` are required. """
    id: str
    coco_url: str
    width: Optional[int] = None
    height: Optional[int] = None
    # NOTE(review): declared as int although `CocoLicense.id` is a str — confirm
    # which type the license reference is supposed to have.
    license: Optional[int] = None
    file_name: Optional[str] = None
    flickr_url: Optional[str] = None
    date_captured: Optional[str] = None

    def is_valid(self) -> bool:
        """ An image is valid when both `id` and `coco_url` are non-empty. """
        # Plain truthiness checks instead of `assert`: assertions are removed
        # under `python -O`, which would make every image pass validation.
        return bool(self.id) and bool(self.coco_url)

    @property
    def collection_name(self):
        return "images"

    def get_file_name(self) -> str:
        """ Returns `file_name` if set, otherwise the last path component of `coco_url`. """
        # NOTE(review): the URL is treated as a plain path, so a query string or
        # fragment would end up in the returned name — confirm URLs never carry one.
        return self.file_name or Path(self.coco_url).name

In [None]:
coco_image_dict = {
 'id': '204800',
 'license': 1,
 'coco_url': 'https://outforz.s3.amazonaws.com/media/public/content/2021/01/10/e2e76667-f7e.jpg',
 'width': 1920,
 'height': 2560,
 'file_name': 'e2e76667-f7e.jpg',
 'date_captured': '2021-01-05 13:18:13',
}


coco_image = CocoImage.from_dict(coco_image_dict)
display(coco_image)

assert coco_image.is_valid()

assert coco_image.to_dict_skip_nulls() == coco_image_dict, coco_image.to_dict_skip_nulls()
assert coco_image.collection_name == 'images'

# --
# test minimal required fields
assert CocoImage.from_dict({'id': 2, 'coco_url': 'http://image'}) == CocoImage(id=2, coco_url='http://image')

assert CocoImage(id=1, coco_url='http://abc.com/keks.jpg').get_file_name() == 'keks.jpg'

CocoImage(id='204800', coco_url='https://outforz.s3.amazonaws.com/media/public/content/2021/01/10/e2e76667-f7e.jpg', width=1920, height=2560, license=1, file_name='e2e76667-f7e.jpg', flickr_url=None, date_captured='2021-01-05 13:18:13')

In [None]:
# export

@dataclass
class CocoAnnotation(CocoElement):
    """ Base annotation: an `id` plus the `image_id` of the image it belongs to. """
    id: str
    image_id: str

    @property
    def collection_name(self):
        # Matches the `annotations` field of `CocoDataset`, the same way the
        # image/license elements name their collections.
        return "annotations"

    def is_valid(self) -> bool:
        """ An annotation is valid when both `id` and `image_id` are non-empty. """
        # Plain truthiness checks instead of `assert`: assertions are removed
        # under `python -O`, which would make every annotation pass validation.
        return bool(self.id) and bool(self.image_id)

    def get_file_name(self) -> str:
        """ Returns the file name derived from the annotation id: `<id>.png`. """
        return f'{self.id}.png'

@dataclass
class CocoObjectDetectionAnnotation(CocoAnnotation):
    """ Object-detection annotation: a category plus a bounding box. """
    category_id: str
    # Validated below as exactly four non-negative numbers (x, y, w, h).
    # NOTE: declared as ints, so float coordinates are coerced to ints on
    # construction (the notebook example turns 196.7 into 196).
    bbox: Optional[Tuple[int, ...]]
    supercategory: Optional[str] = None
    area: Optional[int] = None
    iscrowd: Optional[int] = None

    def is_valid(self) -> bool:
        """ Valid when the base annotation is valid, `category_id` is non-empty
            and `bbox` holds exactly four non-negative integers.
        """
        if not super().is_valid():
            return False
        if not self.category_id:
            return False
        try:
            x, y, w, h = map(int, self.bbox)
        except (TypeError, ValueError, OverflowError):
            # bbox is None, has the wrong number of items, or holds values
            # that cannot be converted to int
            return False
        # Explicit comparison instead of `assert`: assertions are removed under
        # `python -O`, which would let negative coordinates pass validation.
        return x >= 0 and y >= 0 and w >= 0 and h >= 0


In [None]:
c = CocoObjectDetectionAnnotation(id=1, image_id=2, category_id=3, bbox=[])
assert not c.is_valid()

c = replace(c, bbox=[1, -2, 3, 4])
assert not c.is_valid()

c = replace(c, bbox=[1, 2, 3, 4])
assert c.is_valid()

In [None]:
# export

@dataclass
class CocoCategory(CocoElement):
    """ Base category, identified by `id`. """
    id: str

    @property
    def collection_name(self):
        # Matches the `categories` field of the dataset classes, the same way
        # the image/license elements name their collections.
        return "categories"

    # NOTE(review): `@abstractmethod` is not enforced here because the class
    # does not use `ABCMeta`; subclasses that do not override `get_alias`
    # can still be instantiated and raise `NotImplementedError` only when called.
    @abstractmethod
    def get_alias(self):
        raise NotImplementedError

    def is_valid(self) -> bool:
        """ A category is valid when its `id` is non-empty. """
        # Plain truthiness check instead of `assert`: assertions are removed
        # under `python -O`, which would make every category pass validation.
        return bool(self.id)


@dataclass
class CocoObjectDetectionCategory(CocoCategory):
    """ Object-detection category: the base `id` plus a required `name`. """
    name: str
    supercategory: Optional[str] = None

    def is_valid(self) -> bool:
        # valid base category AND a non-empty name
        return super().is_valid() and bool(self.name)

    def get_dir_name(self):
        # `<sanitized name>--<id>`
        safe_name = sanitize_filename(self.name)
        return f'{safe_name}--{self.id}'

In [None]:
c = CocoObjectDetectionCategory(id='12345678', name='Бреф Кольорова вода Евкаліпт НОВИНКА!!! 50 г')
assert c.is_valid()
a = c.get_dir_name()
assert a == 'Бреф_Кольорова_вода_Евкаліпт_НОВИНКА_50_г--12345678', a

In [None]:
# export

@dataclass
class CocoDataset(CocoElement):
    """ A whole COCO dataset: several collections (lists of elements)
        plus the individual `info` element.
    """
    annotations: List[CocoAnnotation] = field(default_factory=list)
    images: List[CocoImage] = field(default_factory=list)
    # A factory instead of a single `CocoInfo()` instance: a shared instance
    # would be reused by every dataset created without `info`, and Python 3.11+
    # dataclasses reject unhashable defaults with `ValueError`.
    info: CocoInfo = field(default_factory=CocoInfo)
    licenses: List[CocoLicense] = field(default_factory=list)

    @classmethod
    def get_non_collective_elements(cls):
        """ Names of the fields that hold a single element rather than a list. """
        # TODO: rename to get_individual_elements()
        return ['info']

    @classmethod
    def get_collective_elements(cls):
        """ Sorted names of the fields that hold lists of elements. """
        # `fields()` accepts the class itself, so no throwaway instance is built.
        non_collective = set(cls.get_non_collective_elements())
        return sorted(f.name for f in fields(cls) if f.name not in non_collective)

    def to_full_str(self):
        """ Short summary: class name plus the size of every collection. """
        counts = ', '.join(
            f'{k}={len(getattr(self, k))}' for k in self.get_collective_elements()
        )
        return f'{self.__class__.__name__}({counts})'

    def is_valid(self) -> bool:
        """ True when every element of every collection and every individual
            element is valid.
        """
        # Driven by the declared fields rather than a hard-coded list, so the
        # collections added by subclasses (e.g. `categories`) are checked too.
        for name in self.get_collective_elements():
            if not all(el.is_valid() for el in getattr(self, name)):
                return False
        return all(
            getattr(self, name).is_valid()
            for name in self.get_non_collective_elements()
        )

In [None]:
non_col = CocoDataset.get_non_collective_elements()
col = CocoDataset.get_collective_elements()

assert non_col == ['info'], non_col
assert col == ['annotations', 'images', 'licenses'], col

In [None]:
# export

@dataclass
class CocoObjectDetectionDataset(CocoDataset):
    """ `CocoDataset` specialised for object detection: annotations carry a
        category and a bbox, and a `categories` collection is added.
    """
    # Re-declares the inherited field with the object-detection annotation type.
    annotations: List[CocoObjectDetectionAnnotation] = field(default_factory=list)
    # Extra collection that the base dataset does not have.
    categories: List[CocoObjectDetectionCategory] = field(default_factory=list)

In [None]:
non_col = CocoObjectDetectionDataset.get_non_collective_elements()
col = CocoObjectDetectionDataset.get_collective_elements()

assert non_col == ['info'], non_col
assert col == ['annotations', 'categories', 'images', 'licenses'], col

In [None]:
# Build a small dataset by hand and dump it.
# NOTE: the bbox is passed as floats but `bbox` is declared as `Tuple[int, ...]`,
# so the dump below shows it coerced to ints: (196, 254, 9, 23).
dataset = CocoObjectDetectionDataset(
    info=CocoInfo(year=2017, version='1.0', description='COCO 2017 Dataset', contributor='COCO Consortium', url='http://cocodataset.org', date_created='2017/09/01'),
    images=[CocoImage(id='362343', coco_url='http://image')], 
    annotations=[
      CocoObjectDetectionAnnotation(id='402717', image_id='362343', category_id='10', bbox=(196.7, 254.52, 9.89, 23.19), iscrowd=0),
    ],
    categories=[CocoObjectDetectionCategory(id='10', name="person")],
)

assert dataset.is_valid()
dataset.to_dict_skip_nulls()

{'annotations': [{'id': '402717',
   'image_id': '362343',
   'category_id': '10',
   'bbox': (196, 254, 9, 23),
   'iscrowd': 0}],
 'images': [{'id': '362343', 'coco_url': 'http://image'}],
 'info': {'year': 2017,
  'version': '1.0',
  'description': 'COCO 2017 Dataset',
  'contributor': 'COCO Consortium',
  'url': 'http://cocodataset.org',
  'date_created': '2017/09/01'},
 'licenses': [],
 'categories': [{'id': '10', 'name': 'person'}]}

In [None]:
dataset2 = CocoObjectDetectionDataset.from_dict(dataset.to_dict())
assert dataset2 == dataset, display(dataset2, dataset)

In [None]:
dataset.to_full_str()

'CocoObjectDetectionDataset(annotations=1, categories=1, images=1, licenses=0)'

In [None]:
d = {'info': {'year': 2017,
  'version': '1.0',
  'description': 'COCO 2017 Dataset',
  'contributor': 'COCO Consortium',
  'url': 'http://cocodataset.org',
  'date_created': '2017/09/01'},
 'images': [
     {'id': '204800',
      'width': 1920,
      'height': 2560,
      'file_name': 'e2e76667-f7e.jpg',
      'license': 1,
      'coco_url': 'https://e2e76667-f7e.jpg',
      'date_captured': '2021-01-05 13:18:13'},
 ],
 'annotations': [{'id': '402717',
   'image_id': '362343',
   'category_id': '10',
   'bbox': (196.7, 254.52, 9.89, 23.19),
   'iscrowd': 0}],
    'categories':[]}

CocoObjectDetectionDataset.from_dict(d)

CocoObjectDetectionDataset(annotations=[CocoObjectDetectionAnnotation(id='402717', image_id='362343', category_id='10', bbox=(196, 254, 9, 23), supercategory=None, area=None, iscrowd=0)], images=[CocoImage(id='204800', coco_url='https://e2e76667-f7e.jpg', width=1920, height=2560, license=1, file_name='e2e76667-f7e.jpg', flickr_url=None, date_captured='2021-01-05 13:18:13')], info=CocoInfo(year=2017, version='1.0', description='COCO 2017 Dataset', contributor='COCO Consortium', url='http://cocodataset.org', date_created='2017/09/01'), licenses=[], categories=[])

In [None]:
raw = {
    "id": '204800',
    "license": 1,
    "coco_url": "https://e2e76667-f7e.jpg",
    "width": 1920,
    "height": 2560,
    "file_name": "e2e76667-f7e.jpg",
    "date_captured": "2021-01-05 13:18:13"
}
display(CocoImage.from_dict(raw))
del raw["date_captured"]
display(CocoImage.from_dict(raw))
del raw["id"]
try:
    display(CocoImage.from_dict(raw))
except KeyError:
    pass
else:
    assert False, "no exception"

CocoImage(id='204800', coco_url='https://e2e76667-f7e.jpg', width=1920, height=2560, license=1, file_name='e2e76667-f7e.jpg', flickr_url=None, date_captured='2021-01-05 13:18:13')

CocoImage(id='204800', coco_url='https://e2e76667-f7e.jpg', width=1920, height=2560, license=1, file_name='e2e76667-f7e.jpg', flickr_url=None, date_captured=None)

In [None]:
# export

# Registry: dataset kind name -> dataset class; looked up by `get_dataset_class`.
MAP_COCO_TYPE_TO_DATASET_CLASS = {
    "object_detection": CocoObjectDetectionDataset,
}

def get_dataset_class(coco_kind: str):
    """ Returns the dataset class registered for `coco_kind`.

        Raises `ValueError` if the kind is not present in
        `MAP_COCO_TYPE_TO_DATASET_CLASS`.
    """
    try:
        return MAP_COCO_TYPE_TO_DATASET_CLASS[coco_kind]
    except KeyError:
        # The message used to reference an undefined name, so this path raised
        # NameError instead of the intended ValueError.
        raise ValueError(f"Not supported dataset kind: {coco_kind}") from None

In [None]:
get_dataset_class("object_detection")

__main__.CocoObjectDetectionDataset

In [None]:
# export

def merge_datasets(d1: CocoDataset, d2: CocoDataset) -> CocoDataset:
    """ Merges two datasets of the same class into a new one.

        If either argument is `None`, the other one is returned as is.
        List-valued keys are merged by element `id` and the result is sorted
        by `str(id)`; other keys must be equal, unless one side is empty.

        Raises `ValueError` if two elements share an `id` but differ, and
        `AssertionError` if the inputs are not datasets of the same class,
        or if an individual element differs.
    """
    if d1 is None:
        return d2
    if d2 is None:
        return d1
    assert isinstance(d1, CocoDataset), (type(d1), d1)
    # report the offending argument (d2), not d1
    assert isinstance(d2, CocoDataset), (type(d2), d2)
    t1 = type(d1)
    t2 = type(d2)
    assert t1 == t2, f'Cannot merge datasets: {t1} != {t2}'

    D1 = d1.to_dict()
    D2 = d2.to_dict()
    K1 = set(D1.keys())
    K2 = set(D2.keys())
    assert K1 == K2, f'Cannot merge datasets: {K1} != {K2}'

    res = {}
    # iterate the dict (not the key set) to keep a deterministic key order
    for k in D1:
        if isinstance(D1[k], list):
            v1 = {x['id']: x for x in (D1[k] or [])}
            v2 = {x['id']: x for x in (D2[k] or [])}
            # One pass over v1 finds every conflict: an id present on both
            # sides must map to equal elements.
            for i, x in v1.items():
                if i in v2 and x != v2[i]:
                    raise ValueError(f'Invalid "{k}" of id={i}: {x} != {v2[i]}')
            v_res = {**v1, **v2}
            res[k] = sorted(v_res.values(), key=lambda x: str(x['id']))
            # we are converting ID to str since sometimes its integer
        else:
            v1 = D1[k] or {}
            v2 = D2[k] or {}
            if not v1:
                res[k] = v2
            elif not v2:
                res[k] = v1
            else:
                assert v1 == v2, f'key={k}: unexpectedly: {v1} != {v2}'
                res[k] = v1

    return t1.from_dict(res)

In [None]:
d1 = {'info': {},
 'images': [{'id': '1', 'coco_url': 'https://image1.jpg'}],
 'annotations': [{'id': '10',
   'image_id': '1',
   'category_id': '1',
   'bbox': (4,3,2,1)}],
 'categories': [{'id': '1', 'name': 'human'}]}

d2 = {'info': {},
 'images': [{'id': '1', 'coco_url': 'https://image1.jpg'},
            {'id': '2', 'coco_url': 'https://image2.jpg'}],
 'annotations': [{'id': '10',
   'image_id': '1',
   'category_id': '1',
   'bbox': (4,3,2,1)},
 {'id': '11',
   'image_id': '2',
   'category_id': '2',
   'bbox': (1,2,3,4)}],
 'categories': [{'id': '2', 'name': 'animal'}]}

In [None]:
d_res = merge_datasets(CocoObjectDetectionDataset.from_dict(d1),
                       CocoObjectDetectionDataset.from_dict(d2))
res = d_res.to_dict_skip_nulls()
display(res)
assert res == {'images': [{'id': '1', 'coco_url': 'https://image1.jpg'},
  {'id': '2', 'coco_url': 'https://image2.jpg'}],
 'info': {},
 'licenses': [],
 'annotations': [{'id': '10',
   'image_id': '1',
   'category_id': '1',
   'bbox': (4, 3, 2, 1)},
  {'id': '11', 'image_id': '2', 'category_id': '2', 'bbox': (1, 2, 3, 4)}],
 'categories': [{'id': '1', 'name': 'human'}, {'id': '2', 'name': 'animal'}]}, res

{'annotations': [{'id': '10',
   'image_id': '1',
   'category_id': '1',
   'bbox': (4, 3, 2, 1)},
  {'id': '11', 'image_id': '2', 'category_id': '2', 'bbox': (1, 2, 3, 4)}],
 'images': [{'id': '1', 'coco_url': 'https://image1.jpg'},
  {'id': '2', 'coco_url': 'https://image2.jpg'}],
 'info': {},
 'licenses': [],
 'categories': [{'id': '1', 'name': 'human'}, {'id': '2', 'name': 'animal'}]}

In [None]:
# Introduce a conflict: annotation '10' now has a different bbox in `d2`
# than in `d1`, so merging must fail with ValueError.
# NOTE: this mutates the `d2` dict from the earlier cell in place, and the
# next cell builds its dataset from the mutated `d2`.
d2['annotations'][0]['bbox'] = (100, 101, 102, 103)
assert d1['annotations'][0]['id'] == d2['annotations'][0]['id'] \
    and d2['annotations'][0]['bbox'] != d1['annotations'][0]['bbox']

try:
    d_res = merge_datasets(CocoObjectDetectionDataset.from_dict(d1),
                           CocoObjectDetectionDataset.from_dict(d2))
except ValueError:
    pass
else:
    assert False, 'test failed'

In [None]:
assert merge_datasets(None, CocoObjectDetectionDataset.from_dict(d2)) == CocoObjectDetectionDataset.from_dict(d2)

In [None]:
# export
def shuffle(arr):
    """ Returns a new list with the items of `arr` in random order;
        `arr` itself is left untouched.
    """
    size = len(arr)
    return random.sample(arr, size)

In [None]:
a = [1,2,3,4,5]
b = shuffle(a)
a, b

([1, 2, 3, 4, 5], [5, 1, 3, 4, 2])

In [None]:
# export 
def cut_annotations_per_category(coco: CocoDataset, max_annotations_per_category: int) -> CocoDataset:
    """ Returns a copy of the input dataset where each class (category)
        contains up to `max_annotations_per_category` annotations (crops).

        Categories with more annotations than the limit are randomly
        sub-sampled. Only the images referenced by the kept annotations are
        retained; `annotations` and `images` come back sorted by `str(id)`.

        Annotations are grouped by their own `category_id`, so an annotation
        whose category is not declared in `coco.categories` is still capped
        and kept. An annotation whose image is missing from `coco.images` is
        kept as well (this function does not check references; that is what
        `remove_invalid_elements` does).

        Raises `ValueError` if `max_annotations_per_category` is negative.
    """
    if max_annotations_per_category < 0:
        # a negative limit would silently slice from the end instead of capping
        raise ValueError(
            f'max_annotations_per_category must be >= 0, got {max_annotations_per_category}'
        )

    imgid2img = {img.id: img for img in coco.images}
    # Seed the groups from the declared categories so that they are visited
    # (and sampled) in declaration order; undeclared categories are appended.
    catid2anns = {cat.id: [] for cat in coco.categories}
    for ann in coco.annotations:
        catid2anns.setdefault(ann.category_id, []).append(ann)

    images = {}
    annotations = {}
    for anns in catid2anns.values():
        if len(anns) > max_annotations_per_category:
            anns = shuffle(anns)[:max_annotations_per_category]
        for ann in anns:
            annotations[ann.id] = ann
            img = imgid2img.get(ann.image_id)
            if img is not None:
                images[ann.image_id] = img

    # ids are compared as strings, as in the other helpers of this module
    return replace(
        coco,
        annotations=sorted(annotations.values(), key=lambda x: str(x.id)),
        images=sorted(images.values(), key=lambda x: str(x.id)),
    )

In [None]:
res = cut_annotations_per_category(
    CocoObjectDetectionDataset.from_dict(
        {'info': {},
         'images': [{'id': '1', 'coco_url': 'https://image1.jpg'}],
         'annotations': [
             {'id': '10', 'image_id': '1', 'category_id': '1', 'bbox': (4,3,2,1)},
             {'id': '11', 'image_id': '1', 'category_id': '1', 'bbox': (4,3,2,1)},
             {'id': '12', 'image_id': '1', 'category_id': '2', 'bbox': (4,3,2,1)},
             {'id': '13', 'image_id': '1', 'category_id': '2', 'bbox': (4,3,2,1)},
         ],
         'categories': [
             {'id': '1', 'name': 'animal'},
             {'id': '2', 'name': 'animal'},
         ]}
    ),
    max_annotations_per_category=1
)
assert len(res.annotations) == 2, res.to_dict_skip_nulls()

In [None]:
# export
from collections import defaultdict

def remove_invalid_elements(coco: CocoDataset) -> CocoDataset:
    """ Returns a copy of the dataset without invalid or dangling elements.

        An annotation is kept when it is valid itself and references a valid
        image and a valid category. Only the images and categories referenced
        by a kept annotation are retained. The three collections come back
        sorted by `str(id)`.
    """
    def _valid_by_id(elements):
        return {el.id: el for el in elements if el.is_valid()}

    def _sorted_by_id(id2el):
        return sorted(id2el.values(), key=lambda el: str(el.id))

    valid_anns = _valid_by_id(coco.annotations)
    valid_imgs = _valid_by_id(coco.images)
    valid_cats = _valid_by_id(coco.categories)

    kept_anns, kept_imgs, kept_cats = {}, {}, {}
    for ann in valid_anns.values():
        img = valid_imgs.get(ann.image_id)
        cat = valid_cats.get(ann.category_id)
        if img is None or cat is None:
            # dangling reference: the image or the category is missing/invalid
            continue
        kept_anns[ann.id] = ann
        kept_imgs[img.id] = img
        kept_cats[cat.id] = cat

    # TODO: filter also licenses
    # TODO: filter also get_non_collective_elements (info)
    return replace(
        coco,
        annotations=_sorted_by_id(kept_anns),
        images=_sorted_by_id(kept_imgs),
        categories=_sorted_by_id(kept_cats),
    )

In [None]:
d = {'info': {},
 'images': [{'id': '1', 'coco_url': 'https://image1.jpg'},
            {'id': '', 'coco_url': 'https://image2.jpg'},
            {'id': '2', 'coco_url': ''},
            {'id': '3','coco_url': 'https://valid-but-unused'},
           ],
 'annotations': [
     {'id': '10',
      'image_id': '1',
      'category_id': '1',
      'bbox': (4,3,2,1)},
     {'id': '',
       'image_id': '2',
       'category_id': '1',
       'bbox': (1,2,3,4)},
     {'id': 'ANN-01',
       'image_id': '1',
       'category_id': '1',
       'bbox': (1,2,3,4)},
     {'id': '3',
       'image_id': '2',
       'category_id': '',
       'bbox': (1,2,3,4)},
     {'id': '4',
       'image_id': '2',
       'category_id': '2',
       'bbox': (1,2,3,0)},
     {'id': '5',
       'image_id': '2',
       'category_id': '2',
       'bbox': (1,2,-2,4)},
 ],
 'categories': [
     {'id': '1', 'name': 'animal'},
     {'id': '', 'name': 'nobody'},
 ]}

c = CocoObjectDetectionDataset.from_dict(d)
c2 = remove_invalid_elements(c)
d2 = c2.to_dict_skip_nulls()
assert d2 == {
    'annotations': [
      {
        'id': '10',
        'image_id': '1',
        'category_id': '1',
        'bbox': (4, 3, 2, 1)
      },
      {
        'id': 'ANN-01',
        'image_id': '1',
        'category_id': '1',
        'bbox': (1,2,3,4)
      },
    ],
    'images': [
         {'id': '1', 'coco_url': 'https://image1.jpg'}],
     'info': {},
     'licenses': [],
     'categories': [{'id': '1', 'name': 'animal'}],
}, d2