In [None]:
# 라이브러리 및 모듈 import
from pycocotools.coco import COCO
import numpy as np
import cv2
import os
import torch
from torch.utils.data import Dataset
import albumentations as A
import pandas as pd

import matplotlib.pyplot as plt

import matplotlib.patches as patches
from collections import Counter
import json

In [None]:
class CustomDataset(Dataset):
    '''
    COCO-format detection dataset backed by the pycocotools COCO API.

      annotation: path to the COCO annotation json file
      data_dir: folder that each image's `file_name` is joined onto
      transforms: albumentations-style callable (resize, crop, ToTensor, etc.) invoked as
                  transforms(image=..., bboxes=..., labels=...); boxes are passed in
                  (x_min, y_min, x_max, y_max) order

    __getitem__(index) uses `index` directly as a COCO image id and returns
    (image, target, image_id). When `transforms` is set, target['boxes'] is an (N, 4)
    float tensor in (y_min, x_min, y_max, x_max) order -- the plotting cells unpack it
    that way. Without transforms, target['boxes'] is an (N, 4) numpy array in
    (x_min, y_min, x_max, y_max) order.
    '''

    # Upper bound on re-running `transforms` while it leaves no boxes. Retrying can only
    # help random transforms; for a deterministic transform (or an image without
    # annotations) an unbounded retry would never terminate.
    max_transform_attempts = 10

    def __init__(self, annotation, data_dir, transforms=None):
        super().__init__()
        self.data_dir = data_dir

        # load the COCO annotations via the COCO API
        self.coco = COCO(annotation)
        self.predictions = {
            "images": self.coco.dataset["images"].copy(),
            "categories": self.coco.dataset["categories"].copy(),
            "annotations": None
        }
        self.transforms = transforms

    def set_transform(self, transforms):
        '''Replace the transform applied in __getitem__.'''
        self.transforms = transforms

    @staticmethod
    def _anns_to_arrays(anns):
        '''Convert COCO annotation dicts into (boxes, labels, areas, is_crowds).

        boxes is an (N, 4) float numpy array converted from COCO (x, y, w, h) to
        (x_min, y_min, x_max, y_max). The reshape keeps it 2-D even when N == 0, so an
        image without annotations does not fail on the column arithmetic.
        labels/is_crowds are int64 tensors, areas is a float32 tensor.
        '''
        boxes = np.array([ann['bbox'] for ann in anns], dtype=np.float64).reshape(-1, 4)
        boxes[:, 2] += boxes[:, 0]
        boxes[:, 3] += boxes[:, 1]

        labels = torch.as_tensor([ann['category_id'] for ann in anns], dtype=torch.int64)
        areas = torch.as_tensor([ann['area'] for ann in anns], dtype=torch.float32)
        is_crowds = torch.as_tensor([ann['iscrowd'] for ann in anns], dtype=torch.int64)
        return boxes, labels, areas, is_crowds

    def _apply_transforms(self, image, boxes, labels):
        '''Run self.transforms and return (image, boxes, labels).

        The transform is retried up to `max_transform_attempts` times while it drops every
        box. If no attempt keeps a box, the last transformed image is returned with empty
        (0, 4) boxes and empty labels instead of looping forever.
        Returned boxes are a float32 tensor in (y_min, x_min, y_max, x_max) order; labels
        are an int64 tensor.
        '''
        label_list = [int(label) for label in labels]
        for _ in range(max(1, self.max_transform_attempts)):
            sample = self.transforms(image=image, bboxes=boxes, labels=label_list)
            if len(sample['bboxes']) > 0:
                break

        out_boxes = torch.as_tensor(np.asarray(sample['bboxes'], dtype=np.float32).reshape(-1, 4))
        # yxyx -- the plotting cells unpack boxes as (y_min, x_min, y_max, x_max)
        out_boxes = out_boxes[:, [1, 0, 3, 2]]
        out_labels = torch.as_tensor([int(label) for label in sample['labels']], dtype=torch.int64)
        return sample['image'], out_boxes, out_labels

    def __getitem__(self, index: int):
        # `index` is used directly as a COCO image id
        # NOTE(review): relies on image ids being 0..len-1 for len()-based indexing to line up
        image_id = self.coco.getImgIds(imgIds=index)
        image_info = self.coco.loadImgs(image_id)[0]

        image_path = os.path.join(self.data_dir, image_info['file_name'])
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None instead of raising for a missing/unreadable file
            raise FileNotFoundError(f"cannot read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        ann_ids = self.coco.getAnnIds(imgIds=image_info['id'])
        boxes, labels, areas, is_crowds = self._anns_to_arrays(self.coco.loadAnns(ann_ids))

        target = {'boxes': boxes, 'labels': labels, 'image_id': torch.tensor([index]), 'area': areas,
                  'iscrowd': is_crowds}

        if self.transforms:
            image, target['boxes'], target['labels'] = self._apply_transforms(image, boxes, labels)

        return image, target, image_id

    def __len__(self) -> int:
        '''Number of images in the annotation file.'''
        return len(self.coco.getImgIds())

In [None]:
# Albumentation을 이용, augmentation 선언
def multi_view_transform(size=256):
    '''Resize-only albumentations pipeline for the multi-image grid view.

    size: side length in pixels; images are resized to size x size (default 256).
    Boxes are expected in 'pascal_voc' (x_min, y_min, x_max, y_max) format, with
    their class ids supplied through the 'labels' field.
    '''
    return A.Compose([
        A.Resize(size, size),
        # optional augmentations, currently disabled:
        # A.Flip(p=0.5),
        # ToTensorV2(p=1.0)
    ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})

def single_view_transform(size=1024):
    '''Resize-only albumentations pipeline for the single-image view.

    size: side length in pixels; images are resized to size x size (default 1024).
    Boxes are expected in 'pascal_voc' (x_min, y_min, x_max, y_max) format, with
    their class ids supplied through the 'labels' field.
    '''
    return A.Compose([
        A.Resize(size, size),
    ], bbox_params={'format': 'pascal_voc', 'label_fields': ['labels']})

def collate_fn(batch):
    '''Transpose a sequence of samples into one tuple per field.

    e.g. [(img1, t1, id1), (img2, t2, id2)] -> ((img1, img2), (t1, t2), (id1, id2)).
    Fields are kept as tuples rather than stacked, so per-sample sizes may differ.
    '''
    per_field = zip(*batch)
    return tuple(per_field)

In [None]:
# Display names indexed by category id: full names for the single-image view,
# short names for the crowded grid view.
class_full_name_list = ['General trash', 'Paper', 'Paper pack', 'Metal', 'Glass', 'Plastic', 'Styrofoam', 'Plastic bag', 'Battery', 'Clothing']
class_short_name_list = ['General', 'Paper', 'Pack', 'Metal', 'Glass', 'Plastic', 'Foam', 'Bag', 'Battery', 'Cloth']

# One colour per class, written as 0-255 RGB and scaled to matplotlib's 0-1 float range.
class_color_list = [
    [channel / 255 for channel in rgb]
    for rgb in (
        (250, 0, 50),     # General trash
        (0, 255, 0),      # Paper
        (0, 180, 80),     # Paper pack
        (185, 185, 185),  # Metal
        (100, 100, 100),  # Glass
        (200, 50, 150),   # Plastic
        (50, 150, 200),   # Styrofoam
        (50, 200, 150),   # Plastic bag
        (200, 200, 200),  # Battery
        (255, 255, 255),  # Clothing
    )
]

# Text-label box style per class: rounded box, edge and fill in the class colour at 40% opacity.
# (alternative brighter fill: 'fc': np.clip(np.array(RGB)*10, 0., 1.))
boxform = [
    dict(boxstyle='round', ec=color, fc=color, alpha=0.4)
    for color in class_color_list
]

# Training annotations (COCO json); each image's file_name is joined onto data_dir.
annotation = '../../dataset/train.json'
data_dir = '../../dataset'

# Same annotations loaded twice, once per view: the grid pipeline and the single-image pipeline.
multi_view_dataset, single_view_dataset = (
    CustomDataset(annotation, data_dir, build_transform())
    for build_transform in (multi_view_transform, single_view_transform)
)

In [None]:
# Index of the image to show (used as a COCO image id by the dataset); -1 picks one at random.
pick = 3289
text_vis = True  # draw the class name above each box

if pick == -1:
    # draw a scalar: randint(..., 1) would return a 1-element array and pass it on as the index
    # NOTE(review): assumes image ids are 0..len-1
    pick = int(np.random.randint(0, len(single_view_dataset)))
image, target, image_id = single_view_dataset[pick]

# .tolist() gives plain ints, so the title reads {0: 3} rather than {np.int64(0): 3} on NumPy 2
label_ids = target['labels'].tolist()

fig, ax = plt.subplots(figsize=(12, 13))

ax.imshow(image)
ax.set_title(f"label_{Counter(label_ids)}, total : {len(label_ids)}", fontsize=24)
ax.axis('off')

for box, label in zip(target['boxes'].tolist(), label_ids):
    # the dataset returns transformed boxes in (y_min, x_min, y_max, x_max) order
    y_min, x_min, y_max, x_max = box
    rect = patches.Rectangle(
        (x_min, y_min),
        x_max - x_min,
        y_max - y_min,
        edgecolor=class_color_list[label],
        fill=False,
    )
    ax.add_patch(rect)
    if text_vis:
        ax.text(x_min, y_min, class_full_name_list[label], fontsize=16,
                horizontalalignment='left', verticalalignment='bottom', bbox=boxform[label])

plt.show()

In [None]:
n_rows, n_cols = 4, 4
vis_cut = 6              # show per-box class names only when an image has fewer boxes than this
use_random_pick = False  # True: sample n_rows*n_cols random images instead of the fixed list below

# Image indices (used as COCO image ids by the dataset), one per subplot in row-major order.
if use_random_pick:
    # NOTE(review): assumes image ids are 0..len-1
    pick = np.random.randint(0, len(multi_view_dataset), n_rows * n_cols).tolist()
else:
    pick = [1671, 1020, 3825, 2521,
            4778, 4841, 4492, 1797,
            1678, 1955, 2416, 2981,
            394, 4047, 946, 4197]
assert len(pick) >= n_rows * n_cols, "need one image index per subplot"

# squeeze=False keeps axes 2-D even if n_rows or n_cols is changed to 1
fig, axes = plt.subplots(n_rows, n_cols, sharex=True, sharey=True, figsize=(30, 32), squeeze=False)

for ax, img_idx in zip(axes.flat, pick):
    image, target, image_id = multi_view_dataset[img_idx]
    # plain ints, so titles do not render as np.int64(...) on NumPy 2
    label_ids = target['labels'].tolist()
    text_vis = len(target['boxes']) < vis_cut

    ax.imshow(image)
    if text_vis:
        ax.set_title(f"labels = {label_ids}")
    else:
        ax.set_title(f"num_label = {len(label_ids)}")
    ax.axis('off')

    for box, label in zip(target['boxes'].tolist(), label_ids):
        # the dataset returns transformed boxes in (y_min, x_min, y_max, x_max) order
        y_min, x_min, y_max, x_max = box
        rect = patches.Rectangle(
            (x_min, y_min),
            x_max - x_min,
            y_max - y_min,
            edgecolor=class_color_list[label],
            fill=False,
            alpha=0.5
        )
        ax.add_patch(rect)
        if text_vis:
            ax.text(x_min, y_min, class_short_name_list[label], fontsize=8,
                    horizontalalignment='left', verticalalignment='bottom', bbox=boxform[label])

plt.tight_layout()
plt.show()

In [None]:
# Load the raw COCO json to inspect annotation statistics with pandas.
with open(annotation) as json_file:
    anns = json.load(json_file)

# top-level keys; e.g. json.dumps(anns['annotations'][0], indent=4) shows one record
print(anns.keys())
print()

# Category names in file order (also used as the title of the per-class crop grid).
label_name = [category['name'] for category in anns['categories']]
print(f"labels : {label_name}")
# Look names up by category id rather than list position, so the mapping does not depend
# on the categories being listed in id order starting at 0.
category_name_by_id = {category['id']: category['name'] for category in anns['categories']}

df = pd.json_normalize(anns['annotations'])
# split the COCO bbox [x, y, w, h] into separate columns
df[["X", "Y", "W", "H"]] = df['bbox'].tolist()
df = df.drop(columns='bbox')
df['sqrt_area'] = np.sqrt(df['area'])  # side length of a square with the same area
df['category_name'] = df['category_id'].map(category_name_by_id)
df = df[['id', 'image_id', 'category_id', 'category_name', 'area', 'sqrt_area', 'X', 'Y', 'W', 'H', 'iscrowd']]
if df['iscrowd'].nunique(dropna=False) == 1:
    # a constant column carries no information
    df = df.drop(columns='iscrowd')

df.sample(10)

In [None]:
object_pick = -1  # annotation id to show; -1 picks a random annotation

if object_pick == -1:
    object_pick = df.sample(1)['id'].values[0]

# look the annotation row up once
obj = df.loc[df['id'] == object_pick].iloc[0]
pick = int(obj['image_id'])
label, name = obj['category_id'], obj['category_name']
# COCO bbox: (X, Y) is the top-left corner, W the width, H the height; truncated to whole pixels
bx, by, bw, bh = (int(obj[col]) for col in ('X', 'Y', 'W', 'H'))
image, target, image_id = single_view_dataset[pick]

# NOTE(review): the crop applies raw annotation coordinates to the image returned by
# single_view_dataset, which has been through its resize transform; this is only exact
# if the source image already has the transform's output size -- confirm.
fig, ax = plt.subplots(figsize=(7, 7))

# rows span Y..Y+H and columns span X..X+W (inclusive, hence +1)
ax.imshow(image[by:by + bh + 1, bx:bx + bw + 1, :])
ax.set_title(f"[{pick:0>4d}.jpg/{object_pick}] [{label}]{name} ({bw},{bh})", fontsize=24)
ax.axis('off')

plt.show()

In [None]:
n_rows = 10  # maximum number of crops shown per class
# Subset of df to sample objects from; None (or df) means every annotation. Other examples:
# df[df['sqrt_area'] < 25], df[df['sqrt_area'] < 30], df[df['X'] < 5]
query = df[df['X'] < 5]
if query is None:
    query = df

# one column per category; category_id is used directly as the column index
num_class = len(label_name)
items = [set() for _ in range(num_class)]
# oversample with replacement, then keep up to n_rows distinct objects per class
sampled_objects = query.groupby('category_id').sample(n_rows * 2, replace=True)[
    ['id', 'image_id', 'category_id', 'category_name', 'X', 'Y', 'W', 'H']
].values.tolist()

for obj in sampled_objects:
    bucket = items[obj[2]]
    if len(bucket) < n_rows:
        bucket.add(tuple(obj))

max_len = max((len(bucket) for bucket in items), default=0)

# squeeze=False keeps axes 2-D when max_len == 1; at least one row so an empty query
# still yields a (blank) figure instead of a ValueError
fig, axes = plt.subplots(max(max_len, 1), num_class, sharex=False, sharey=False,
                         figsize=(30, 32), squeeze=False)
fig.suptitle(f"{label_name}", fontsize=28)

# hide every frame first, so cells left empty (class with fewer crops) stay blank
for ax in axes.flat:
    ax.axis('off')

for col in range(num_class):
    for row, (obj_id, img_id, label, name, x, y, w, h) in enumerate(items[col]):
        bx, by, bw, bh = int(x), int(y), int(w), int(h)
        image, target, image_id = single_view_dataset[img_id]

        # NOTE(review): raw annotation coordinates applied to the transformed image -- see the
        # single-object cell; only exact if the transform keeps the image size.
        ax = axes[row, col]
        ax.imshow(image[by:by + bh + 1, bx:bx + bw + 1, :])
        ax.set_title(f"[{img_id:0>4d}.jpg/{obj_id}], [{bw},{bh}]", fontsize=10)

plt.tight_layout(rect=[0, 0, 1, 0.97])
plt.show()