In [None]:
import os
import warnings 
warnings.filterwarnings('ignore')

import json
import shutil

import torch
from torch.utils.data import Dataset
import cv2

import numpy as np
from tqdm import tqdm

# Library used for preprocessing (COCO annotation API)
from pycocotools.coco import COCO

#!pip install albumentations==0.4.6
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Report the runtime environment and pick the compute device.
print('pytorch version: {}'.format(torch.__version__))
cuda_available = torch.cuda.is_available()
print('GPU 사용 가능 여부: {}'.format(cuda_available))

# The device name can only be queried when CUDA is usable; on a CPU-only
# machine the call raises, which would abort the notebook before the
# CPU fallback below is ever reached.
if cuda_available:
    print(torch.cuda.get_device_name(0))
print(torch.cuda.device_count())

# Store the device string according to GPU availability
device = "cuda" if cuda_available else "cpu"

# Root directory of the dataset, relative to the notebook's working directory
dataset_path  = '../../../../data'

In [None]:
def get_classname(classID, cats):
    """Look up a category name by id.

    Args:
        classID: Category id to search for.
        cats: Sequence of category dicts, each with 'id' and 'name' keys.

    Returns:
        The 'name' of the first entry whose 'id' equals ``classID``, or the
        string "None" when no entry matches.
    """
    matching_names = (cat['name'] for cat in cats if cat['id'] == classID)
    return next(matching_names, "None")

In [None]:
# Class names used for the segmentation masks.
# CustomDataLoader writes each category's position in this list into the mask
# as its pixel value (via category_names.index), so the order here defines the
# label ids. Index 0 ("Background") matches the zero fill of an empty mask.
category_names = [
    "Background",
    "General trash",
    "Paper",
    "Paper pack",
    "Metal",
    "Glass",
    "Plastic",
    "Styrofoam",
    "Plastic bag",
    "Battery",
    "Clothing",
]

In [None]:
class CustomDataLoader(Dataset):
    """Dataset over a COCO-format annotation file yielding images and class-index masks.

    Args:
        data_dir: Path to the COCO annotation JSON (passed straight to ``COCO``).
        mode: ``'train'`` or ``'val'`` items are ``(image, mask, image_info)``;
            ``'test'`` items are ``(image, image_info)``.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            (``mask`` is omitted in test mode); its result must be a mapping
            with an "image" key and, for train/val, a "mask" key.

    Image files are resolved relative to the module-level ``dataset_path``.
    Mask pixel values are indices into the module-level ``category_names``.
    """
    def __init__(self, data_dir, mode = 'train', transform = None):
        super().__init__()
        self.mode = mode
        self.transform = transform
        self.coco = COCO(data_dir)
        # The image id list does not change after loading, so compute it once
        # instead of rebuilding it on every __getitem__ / __len__ call.
        self.image_ids = self.coco.getImgIds()
        # category id -> mask pixel value; built lazily on first train/val item
        # so that test-mode annotation files are never asked for categories.
        self._cat_id_to_pixel = None

    def _category_pixel_values(self):
        """Return (building once) the mapping from COCO category id to mask pixel value."""
        if self._cat_id_to_pixel is None:
            mapping = {}
            for cat in self.coco.loadCats(self.coco.getCatIds()):
                if cat['name'] in category_names:
                    # setdefault keeps the first entry if an id is listed twice
                    mapping.setdefault(cat['id'], category_names.index(cat['name']))
            self._cat_id_to_pixel = mapping
        return self._cat_id_to_pixel

    def __getitem__(self, index: int):
        """Load the ``index``-th image and, in train/val mode, its segmentation mask.

        Raises:
            FileNotFoundError: If the image file cannot be read.
            ValueError: If an annotation's category has no entry in
                ``category_names``, or if ``self.mode`` is not supported.
        """
        # The dataset is indexed like a list
        image_id = self.image_ids[index]
        image_infos = self.coco.loadImgs(image_id)[0]

        # Load the image with cv2; imread signals failure by returning None,
        # so check explicitly to report the offending path.
        image_path = os.path.join(dataset_path, image_infos['file_name'])
        images = cv2.imread(image_path)
        if images is None:
            raise FileNotFoundError(f'Could not read image: {image_path}')
        # BGR -> RGB, scaled to [0, 1] as float32
        images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.float32)
        images /= 255.0

        if (self.mode in ('train', 'val')):
            ann_ids = self.coco.getAnnIds(imgIds=image_infos['id'])
            anns = self.coco.loadAnns(ann_ids)
            pixel_values = self._category_pixel_values()

            # masks: 2D array of shape (height, width)
            # each pixel holds its class index; Background = 0
            masks = np.zeros((image_infos["height"], image_infos["width"]), dtype=np.int8)

            # General trash = 1, ... , Clothing = 10
            # Paint larger annotations first so that smaller objects lying
            # on top of them are not overwritten.
            for ann in sorted(anns, key=lambda a: a['area'], reverse=True):
                category_id = ann['category_id']
                if category_id not in pixel_values:
                    raise ValueError(
                        f'category id {category_id!r} has no entry in category_names'
                    )
                masks[self.coco.annToMask(ann) == 1] = pixel_values[category_id]

            # Apply the (albumentations-style) transform to image and mask together
            if self.transform is not None:
                transformed = self.transform(image=images, mask=masks)
                images = transformed["image"]
                masks = transformed["mask"]
            return images, masks, image_infos

        if self.mode == 'test':
            # Apply the (albumentations-style) transform to the image only
            if self.transform is not None:
                transformed = self.transform(image=images)
                images = transformed["image"]
            return images, image_infos

        # Previously an unknown mode fell through and returned None.
        raise ValueError(
            f"Unsupported mode {self.mode!r}; expected 'train', 'val' or 'test'"
        )

    def __len__(self) -> int:
        # Number of images in the annotation file
        return len(self.image_ids)

In [None]:
# Custom collate function used when batching with DataLoader
def collate_fn(batch):
    """Transpose a batch of per-sample tuples into a tuple of per-field tuples.

    Example: ``[(img1, mask1), (img2, mask2)]`` becomes
    ``((img1, img2), (mask1, mask2))``. No stacking is performed.
    """
    per_field = zip(*batch)
    return tuple(per_field)

import albumentations as A
from albumentations.pytorch import ToTensorV2

# Transform pipeline containing only the tensor conversion step
# (no resizing or augmentation is configured here).
train_transform = A.Compose([ToTensorV2()])

In [None]:
# Generate the segmentation-mask PNG files
def convert_ann(json_file):
    """Write one class-index mask PNG per image listed in a fold annotation file.

    Reads ``<dataset_path>/stratified_kfold/<json_file>`` and writes each mask to
    ``<dataset_path>/mmseg/ann_dir/<fold>/<image id, zero-padded to 4>.png``,
    where ``<fold>`` is the part of ``json_file`` before its first dot.

    Args:
        json_file: File name of the COCO annotation JSON, e.g. ``'train0.json'``.

    Raises:
        OSError: If a mask image could not be written.
    """
    json_path = os.path.join(dataset_path, 'stratified_kfold', json_file)

    # Dataset in 'train' mode so that every item carries its mask
    train_dataset = CustomDataLoader(data_dir=json_path, mode='train', transform=train_transform)

    # One sample per batch, in file order, loaded in the main process
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=1,
                                               shuffle=False,
                                               num_workers=0,
                                               collate_fn=collate_fn)

    fold = json_file.split('.')[0]
    save_dir = os.path.join(dataset_path, 'mmseg', 'ann_dir', fold)

    # makedirs also creates missing parents and tolerates an existing directory
    os.makedirs(save_dir, exist_ok=True)

    print()
    print(f'*****{fold}*****')
    print('saving mask.png')
    for _, masks, image_infos in tqdm(train_loader):
        # collate_fn yields tuples of length batch_size (= 1); unwrap them
        image_infos = image_infos[0]
        # Class indices are small non-negative ints; hand cv2 an unsigned
        # 8-bit array explicitly rather than relying on an implicit conversion.
        mask = masks[0].numpy().astype(np.uint8)

        out_path = os.path.join(save_dir, f"{image_infos['id']:04}.png")
        # imwrite reports failure through its return value; do not ignore it
        if not cv2.imwrite(out_path, mask):
            raise OSError(f'Failed to write mask: {out_path}')

In [None]:
# Copy images into the mmseg directory layout
def copy_img(json_file, data_root='../../../../data'):
    """Copy every image listed in a fold annotation file into the mmseg image directory.

    Reads ``<data_root>/stratified_kfold/<json_file>`` and copies each image to
    ``<data_root>/mmseg/img_dir/<fold>/<image id, zero-padded to 4>.jpg``,
    where ``<fold>`` is the part of ``json_file`` before its first dot.

    Args:
        json_file: File name of the annotation JSON, e.g. ``'train0.json'``.
        data_root: Dataset root directory. Defaults to the path this notebook
            has always used.
    """
    fold = json_file.split('.')[0]

    save_dir = os.path.join(data_root, 'mmseg', 'img_dir', fold)
    # makedirs also creates missing parents and tolerates an existing directory
    os.makedirs(save_dir, exist_ok=True)

    json_path = os.path.join(data_root, 'stratified_kfold', json_file)
    with open(json_path, 'r', encoding='UTF-8') as fold_json:
        fold_images = json.load(fold_json)['images']

    for img in fold_images:
        # 'file_name' is relative to the dataset root
        shutil.copyfile(
            os.path.join(data_root, img['file_name']),
            os.path.join(save_dir, f"{img['id']:04}.jpg")
        )
    print(f'{fold} done')

In [None]:
# Copy test images into the mmseg directory layout
def copy_img_test(json_file, data_root='../../../../data'):
    """Copy every image listed in an annotation file at the dataset root into the mmseg image directory.

    Reads ``<data_root>/<json_file>`` and copies each image to
    ``<data_root>/mmseg/img_dir/<name>/<image id, zero-padded to 4>.jpg``,
    where ``<name>`` is the part of ``json_file`` before its first dot.

    Args:
        json_file: File name of the annotation JSON, e.g. ``'test.json'``.
        data_root: Dataset root directory. Defaults to the path this notebook
            has always used.
    """
    fold = json_file.split('.')[0]

    save_dir = os.path.join(data_root, 'mmseg', 'img_dir', fold)
    # makedirs also creates missing parents and tolerates an existing directory
    os.makedirs(save_dir, exist_ok=True)

    # Unlike the fold files, this JSON sits directly in the dataset root
    json_path = os.path.join(data_root, json_file)
    with open(json_path, 'r', encoding='UTF-8') as fold_json:
        fold_images = json.load(fold_json)['images']

    for img in fold_images:
        # 'file_name' is relative to the dataset root
        shutil.copyfile(
            os.path.join(data_root, img['file_name']),
            os.path.join(save_dir, f"{img['id']:04}.jpg")
        )
    print(f'{fold} done')

In [None]:
# Create <dataset_path>/mmseg/ann_dir (the original notes give
# '/opt/ml/input/data/mmseg/ann_dir' as an example location).
# makedirs creates both levels in one call and does not fail when the
# directories already exist, so the cell can be re-run safely.
save_dir = os.path.join(dataset_path, 'mmseg', 'ann_dir')
os.makedirs(save_dir, exist_ok=True)

In [None]:
# Export masks for folds 0-4: all train splits first, then all val splits
for split in ('train', 'val'):
    for fold_idx in range(5):
        convert_ann(f'{split}{fold_idx}.json')

In [None]:
# Create the image output directory; makedirs also creates a missing
# 'mmseg' parent and does not fail when the directory already exists.
os.makedirs('../../../../data/mmseg/img_dir', exist_ok=True)

In [None]:
# Copy images for folds 0-4 (all train splits first, then all val splits),
# followed by the test set.
for split in ('train', 'val'):
    for fold_idx in range(5):
        copy_img(f'{split}{fold_idx}.json')
copy_img_test('test.json')