In [23]:
import json
import os

import cv2
import numpy as np
import pandas as pd
from pycocotools.coco import COCO
from torchvision.utils import save_image
from torch.utils.data import Dataset
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm

import albumentations as A
from albumentations.pytorch import ToTensorV2

In [24]:
# Root of the source dataset (images + COCO json files) and of the converted output.
src_data_dir = '../data/'
dst_data_dir = '../data/mmseg/'

# Split to convert. Change this single value to 'train', 'val' or 'test';
# the annotation file path below is derived from it.
mode = 'test'
if mode not in ('train', 'val', 'test'):
    raise ValueError(f"mode must be 'train', 'val' or 'test', got {mode!r}")

# os.path.join avoids the doubled separator that plain string concatenation
# produced ('../data//test.json') because src_data_dir already ends with '/'.
src_data_json_path = os.path.join(src_data_dir, f'{mode}.json')

In [25]:
# Class names ordered by the pixel value they receive in a mask:
# position 0 is the background, positions 1..10 are the object categories.
classes = [
    'Background',
    'General trash',
    'Paper',
    'Paper pack',
    'Metal',
    'Glass',
    'Plastic',
    'Styrofoam',
    'Plastic bag',
    'Battery',
    'Clothing',
]

def get_classname(classID, cats):
    """Return the ``'name'`` of the first entry in ``cats`` whose ``'id'`` equals ``classID``.

    Args:
        classID: Category id to look up.
        cats: Sequence of dicts, each carrying ``'id'`` and ``'name'`` keys.

    Returns:
        The matching name, or the literal string ``"None"`` when no entry matches.
    """
    matching_names = (entry['name'] for entry in cats if entry['id'] == classID)
    return next(matching_names, "None")

class CustomDataLoader(Dataset):
    """Dataset over a COCO-format annotation file.

    Each item is an image loaded with OpenCV and, for labelled splits, a 2D
    mask whose pixel values are indices into the module-level ``classes`` list.

    Args:
        data_dir: Directory that every image's ``file_name`` is joined onto.
        data_json_path: Path of the COCO json file handed to ``COCO``.
        mode: ``'train'`` or ``'val'`` to return ``(image, mask, image_info)``;
            ``'test'`` to return ``(image, image_info)``.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            (``mask`` is omitted in test mode); its ``"image"`` / ``"mask"``
            entries replace the raw arrays.

    Raises:
        ValueError: If ``mode`` is not ``'train'``, ``'val'`` or ``'test'``.
    """

    _MASK_MODES = ('train', 'val')
    _SUPPORTED_MODES = _MASK_MODES + ('test',)

    def __init__(self, data_dir, data_json_path, mode = 'train', transform = None):
        super().__init__()
        # Fail early: an unknown mode previously made __getitem__ return None silently.
        if mode not in self._SUPPORTED_MODES:
            raise ValueError(
                f"mode must be one of {self._SUPPORTED_MODES}, got {mode!r}"
            )
        self.data_dir = data_dir
        self.mode = mode
        self.transform = transform

        self.coco = COCO(data_json_path)
        # Positional index -> COCO image id. Looking ids up through this list
        # (instead of using the index itself as an id) keeps the dataset valid
        # when ids are not exactly 0..N-1; for ids 0..N-1 the mapping is the identity.
        self.image_ids = sorted(self.coco.getImgIds())
        # Categories do not change between items, so load them once.
        self._cats = self.coco.loadCats(self.coco.getCatIds())

    def __getitem__(self, index: int):
        """Return the sample at position ``index``.

        Returns:
            ``(image, mask, image_info)`` in train/val mode, ``(image, image_info)``
            in test mode. ``image`` is a float32 array in OpenCV's BGR channel
            order with values left unscaled; ``mask`` is an int8 array.

        Raises:
            IndexError: If ``index`` is outside the dataset.
            FileNotFoundError: If the image file cannot be read.
        """
        # List indexing raises IndexError when out of range.
        image_id = self.image_ids[index]
        image_infos = self.coco.loadImgs([image_id])[0]
        images = self._load_image(image_infos)

        if self.mode in self._MASK_MODES:
            masks = self._build_mask(image_infos)
            # The transform receives image and mask together and returns both.
            if self.transform is not None:
                transformed = self.transform(image=images, mask=masks)
                images = transformed["image"]
                masks = transformed["mask"]
            return images, masks, image_infos

        # mode == 'test': there are no annotations, only the image is transformed.
        if self.transform is not None:
            transformed = self.transform(image=images)
            images = transformed["image"]
        return images, image_infos

    def __len__(self) -> int:
        """Return the number of images listed in the annotation file."""
        return len(self.image_ids)

    def _load_image(self, image_infos):
        """Read the image described by ``image_infos`` as a float32 array."""
        image_path = os.path.join(self.data_dir, image_infos['file_name'])
        image = cv2.imread(image_path)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # No BGR->RGB conversion and no /255 scaling is applied here.
        return image.astype(np.float32)

    def _build_mask(self, image_infos):
        """Rasterise the image's annotations into a (height, width) int8 class-index mask."""
        ann_ids = self.coco.getAnnIds(imgIds=image_infos['id'])
        anns = self.coco.loadAnns(ann_ids)

        # Background = 0; General trash = 1, ..., Clothing = 10 (see ``classes``).
        masks = np.zeros((image_infos["height"], image_infos["width"]))
        # Paint in ascending order of the first polygon's coordinate count;
        # where annotations overlap, the one painted later overwrites the earlier.
        # NOTE(review): this key assumes polygon-style 'segmentation' lists - confirm
        # no annotation uses another encoding or an empty list.
        anns = sorted(anns, key=lambda ann: len(ann['segmentation'][0]))
        for ann in anns:
            class_name = get_classname(ann['category_id'], self._cats)
            pixel_value = classes.index(class_name)
            masks[self.coco.annToMask(ann) == 1] = pixel_value
        return masks.astype(np.int8)
    

In [26]:
# Build the dataset for the configured split; no augmentation is applied.
dataset = CustomDataLoader(
    data_dir=src_data_dir,
    data_json_path=src_data_json_path,
    mode=mode,
    transform=None,
)
# Display the first sample as a quick sanity check.
dataset[0]

loading annotations into memory...
Done (t=0.00s)
creating index...
index created!


(array([[[ 77.,  86.,  89.],
         [ 70.,  79.,  82.],
         [ 63.,  72.,  75.],
         ...,
         [ 18.,  20.,  28.],
         [ 18.,  20.,  28.],
         [ 18.,  20.,  28.]],
 
        [[ 68.,  77.,  80.],
         [ 66.,  75.,  78.],
         [ 61.,  70.,  73.],
         ...,
         [ 18.,  20.,  28.],
         [ 17.,  19.,  27.],
         [ 17.,  19.,  27.]],
 
        [[ 46.,  55.,  58.],
         [ 47.,  56.,  59.],
         [ 45.,  54.,  57.],
         ...,
         [ 17.,  19.,  27.],
         [ 17.,  19.,  27.],
         [ 16.,  18.,  26.]],
 
        ...,
 
        [[ 46.,  58.,  76.],
         [ 33.,  45.,  63.],
         [ 31.,  43.,  61.],
         ...,
         [ 82., 100., 111.],
         [ 43.,  55.,  67.],
         [ 44.,  54.,  64.]],
 
        [[ 31.,  45.,  63.],
         [ 23.,  37.,  55.],
         [ 25.,  39.,  57.],
         ...,
         [ 72.,  94., 105.],
         [ 38.,  54.,  66.],
         [ 44.,  56.,  68.]],
 
        [[ 31.,  45.,  63.],
 

In [27]:
# Labelled splits get an image folder and an annotation folder named after the
# split; any other mode is treated as 'test' and only gets an image folder.
has_annotations = mode in ('train', 'val')
if has_annotations:
    images_save_dir = os.path.join(dst_data_dir, 'images', mode)
    annotations_save_dir = os.path.join(dst_data_dir, 'annotations', mode)
else:  # mode == 'test'
    images_save_dir = os.path.join(dst_data_dir, 'test')
    annotations_save_dir = None

# Create whichever output folders are missing and report each one created.
for save_dir in (images_save_dir, annotations_save_dir):
    if save_dir and not os.path.exists(save_dir):
        os.makedirs(save_dir)
        print('A directory - ' + save_dir + ' is created.')


def _imwrite_checked(path, array):
    """Write ``array`` to ``path`` with cv2.imwrite, raising if the write fails."""
    # cv2.imwrite reports failure through its return value instead of raising,
    # so an unchecked call can silently leave files missing from the export.
    if not cv2.imwrite(path, array):
        raise IOError(f'Failed to write {path}')


if mode in ('train', 'val', 'test'):
    for idx in tqdm(range(len(dataset))):
        # Samples are (img, mask, image_infos) for train/val and (img, image_infos)
        # for test, so the image is always first and the metadata always last.
        sample = dataset[idx]
        img, image_infos = sample[0], sample[-1]
        file_stem = f'{image_infos["id"]:04}'

        _imwrite_checked(os.path.join(images_save_dir, file_stem + '.jpg'), img)
        if has_annotations:
            # Masks are stored as PNG so the class indices are kept losslessly.
            _imwrite_checked(os.path.join(annotations_save_dir, file_stem + '.png'), sample[1])

A directory - ../data/mmseg/test is created.


100%|██████████| 819/819 [00:08<00:00, 93.59it/s]
