In [None]:
# 라이브러리 및 모듈 import
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
import numpy as np
import pandas as pd
import cv2
import os
import torch
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
import pandas as pd
from tqdm.notebook import tqdm
import json

import matplotlib.pyplot as plt
import seaborn as sns

import matplotlib.patches as patches
from collections import Counter

# Select seaborn's "darkgrid" style once, here in the setup cell.
sns.set_style("darkgrid")

## Data path Setting

In [None]:
# Root folder of the dataset; the COCO-style annotation file sits directly inside it.
data_dir = '../../dataset'
annotation = f'{data_dir}/train.json'

## DataFrame 생성

In [None]:
# Load the annotation JSON and flatten its 'annotations' list into one row per object.
# Explicit UTF-8 so the read does not depend on the platform's default encoding.
with open(annotation, encoding='utf-8') as json_file:
    anns = json.load(json_file)

print(anns.keys())
# To inspect a section, e.g.: print(json.dumps(anns['categories'], indent=4))
print()

# Category names in file order; later cells use this list for tick labels.
label_name = [ann_dict['name'] for ann_dict in anns['categories']]
print(f"labels : {label_name}")

# Look names up by each category's own `id` instead of by list position, so the
# mapping stays correct even if ids are not 0..n-1 in file order.
category_name_by_id = {ann_dict['id']: ann_dict['name'] for ann_dict in anns['categories']}

df = pd.json_normalize(anns['annotations'])
# Split each 4-element bbox into separate columns.
# NOTE(review): X, Y, W, H presumably follow COCO's [x, y, width, height] order — confirm.
df[["X","Y","W","H"]] = list(df.bbox)
df = df.drop(columns='bbox')
df['sqrt_area'] = np.sqrt(df['area'])
df['category_name'] = df['category_id'].map(category_name_by_id)
# Fail loudly if an annotation references a category id missing from 'categories'.
assert df['category_name'].notna().all(), "annotation with unknown category_id"

df = df[['id', 'image_id', 'category_id', 'category_name', 'area', 'sqrt_area', 'X', 'Y', 'W', 'H', 'iscrowd']]
# A constant 'iscrowd' column carries no information, so drop it.
# Reassign rather than drop in place: `df` is a column slice at this point.
if df['iscrowd'].nunique(dropna=False) == 1:
    df = df.drop(columns='iscrowd')

print(f"num of image : {df['image_id'].nunique(dropna=False)}")
df.sample(10)

## DF Describe

In [None]:
# Summary statistics of the numeric columns, rounded to two decimals for readability.
df.describe().round(2)

## Class 분포
- 전체 오브젝트들의 class 분포를 살펴봅니다.

In [None]:
# Bar chart of the number of annotated objects per class, with the count written above each bar.
fig, ax = plt.subplots(figsize=(16, 9))
sns.countplot(x='category_id', data=df, ax=ax)
ax.set(xticks=range(len(label_name)), xticklabels=label_name)

# Series.iteritems() was removed in pandas 2.0; Series.items() yields the same (index, value) pairs.
# NOTE(review): uses category_id as the bar's x position, i.e. assumes ids are 0..n-1
# and every class occurs at least once — confirm for this dataset.
for idx, val in df['category_id'].value_counts().sort_index().items():
    ax.text(x=idx, y=val, s=val,
            va='bottom', ha='center',
            fontsize=10, fontweight='semibold'
           )

plt.show()

## Image당 지표 확인
- 전체 데이터에서 이미지당 포함된 class의 개수와 object의 개수를 추출합니다.

In [None]:
# Per-image summary: number of objects (count) and number of distinct classes (nunique).
img_by = df.groupby('image_id')[['category_id']].agg(['count', 'nunique'])
# Flatten the ('category_id', <stat>) column pairs into 'cat_<stat>' names.
img_by.columns = [f'cat_{stat}' for _, stat in img_by.columns]
img_by.sample(10)

In [None]:
# Summary statistics of the per-image object count and distinct-class count.
img_by.describe()

In [None]:
fig, ax = plt.subplots(figsize=(16, 9))
# One box per column of img_by, drawn on a shared axis.
sns.boxplot(data=img_by, ax=ax)
plt.show()

## Image당 지표 시각화

In [None]:
# Joint view of the per-image metrics: a box plot of object count by number of
# distinct classes, with count plots of each metric along the top and right edges.
fig = plt.figure(figsize=(12, 12))
# 7x7 grid: row 0 is the top strip, column 6 is the right strip,
# and the remaining 6x6 area holds the main plot.
gs = fig.add_gridspec(7, 7)
axes = [None for _ in range(3)]

axes[0] = fig.add_subplot(gs[0, :6])   # top strip: images per distinct-class count
axes[1] = fig.add_subplot(gs[1:, :6])  # main panel: object count by distinct-class count
axes[2] = fig.add_subplot(gs[1:, 6])   # right strip: images per object count

sns.boxplot(x='cat_nunique', y='cat_count', data=img_by, width=0.3, linewidth=2, fliersize=4, ax=axes[1])
# Remember the main panel's y-range so the right strip can be aligned to it.
ylim = axes[1].get_ylim()

sns.countplot(x='cat_nunique', data=img_by, ax=axes[0])
axes[0].xaxis.set_visible(False)
# Series.iteritems() was removed in pandas 2.0; Series.items() yields the same (index, value) pairs.
# NOTE(review): x=idx-1 assumes cat_nunique takes the contiguous values 1, 2, 3, ...
# so that value k sits at bar position k-1 — confirm.
for idx, val in img_by['cat_nunique'].value_counts().items():
    if val > 2000:
        # Tall bar: write the count inside the bar, in white, just below its top.
        axes[0].text(x=idx-1, y=val-150, s=val,
                     va='top', ha='center',
                     fontsize=10, fontweight='semibold', color='w'
                    )
    else:
        # Short bar: write the count on top of the bar.
        axes[0].text(x=idx-1, y=val, s=val,
                     va='bottom', ha='center',
                     fontsize=10, fontweight='semibold'
                    )

# Give every integer object count from 0 to the maximum its own slot.
sns.countplot(y='cat_count', data=img_by, order=list(range(img_by['cat_count'].max()+1)), ax=axes[2])
axes[2].set_ylim(ylim)
axes[2].yaxis.set_visible(False)
# Reference line at 100 images.
axes[2].axvline(x=100, color='royalblue', linestyle='--', linewidth=1, alpha=0.5)
# NOTE(review): the '100' label is anchored at x=200 while the line is at x=100 —
# presumably offset for legibility; confirm.
axes[2].text(x=200, y=-1, s='100', color='royalblue',
             va='top', ha='left',
             fontsize=10, fontweight='semibold', alpha=0.7
            )

plt.tight_layout()
plt.show()

## Image 단위 지표 요약
- 대부분의 이미지는 오브젝트 5개 이하, 클래스 종류 2개 이하를 포함한다.


In [None]:
# sqrt(area) per class: density curves on the left, box plots on the right.
# The dashed line in both panels marks the overall mean of sqrt_area.
fig, axes = plt.subplots(1, 2, figsize=(20, 10))
kde_ax, box_ax = axes
mean_sqrt_area = df['sqrt_area'].mean()
mean_line = dict(color='tomato', linestyle='--', linewidth=1)

sns.kdeplot(data=df, x='sqrt_area', hue='category_name', bw_method=0.2, ax=kde_ax)
kde_ax.axvline(x=mean_sqrt_area, **mean_line)

sns.boxplot(
    data=df, x='category_id', y='sqrt_area',
    width=0.3, linewidth=2, fliersize=4, ax=box_ax,
)
box_ax.set(xticks=range(len(label_name)), xticklabels=label_name)
box_ax.axhline(y=mean_sqrt_area, **mean_line)
plt.show()

In [None]:
# 2-D density of the W and H columns, one contour set per class.
fig, ax = plt.subplots(figsize=(16, 9))
sns.kdeplot(
    data=df,
    x='W', y='H',
    hue='category_name',
    ax=ax,
)
plt.show()

In [None]:
# W vs H scatter with a regression line per class.
# lmplot is figure-level: it builds its own figure, so no Axes is created here.
sns.lmplot(
    data=df,
    x='W', y='H',
    hue='category_name',
    scatter_kws={'alpha': 0.1},
    height=10,
)
plt.show()

In [None]:
# Same W vs H regression plot, split into one facet row per class.
# lmplot is figure-level: it builds its own figure, so no Axes is created here.
sns.lmplot(
    data=df,
    x='W', y='H',
    hue='category_name',
    row='category_name',
    scatter_kws={'alpha': 0.1},
)
plt.show()

In [None]:
# 2-D density of the X and Y columns, one contour set per class.
fig, ax = plt.subplots(figsize=(12, 12))
sns.kdeplot(
    data=df,
    x='X', y='Y',
    hue='category_name',
    ax=ax,
)
plt.show()

In [None]:
# X vs Y scatter with a regression line per class.
# lmplot is figure-level: it builds its own figure, so no Axes is created here.
sns.lmplot(
    data=df,
    x='X', y='Y',
    hue='category_name',
    scatter_kws={'alpha': 0.1},
    height=10,
)
plt.show()

In [None]:
# Rows whose X value is below 10.
df.loc[df['X'].lt(10)]

In [None]:
# Rows whose Y value is below 10.
df.loc[df['Y'].lt(10)]

In [None]:
# Rows whose W value is below 10.
df.loc[df['W'].lt(10)]

In [None]:
# Rows whose H value is below 10.
df.loc[df['H'].lt(10)]

In [None]:
# The 25 rows with the largest H (ties broken by W), in ascending order.
df.sort_values(by=['H', 'W']).iloc[-25:]

In [None]:
# Number of rows where both H and W exceed 1000.
int((df['H'].gt(1000) & df['W'].gt(1000)).sum())

In [None]:
# Per-image object count and distinct-class count, keeping the
# ('category_id', <stat>) two-level column labels.
a = df.groupby('image_id')[['category_id']].agg(['count', 'nunique'])

In [None]:
# Show the two-level column labels produced by the aggregation.
a.columns

In [None]:
# image_ids of the 16 images with the most objects, in ascending order of count.
a.sort_values(by=('category_id', 'count')).index[-16:].tolist()

In [None]:
# image_id of every row whose X or Y exceeds 1000 (one entry per row, so ids may repeat).
df.loc[df['X'].gt(1000) | df['Y'].gt(1000), 'image_id'].tolist()