In [19]:
import os
import random
import time
import json
import warnings 
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import cv2

import numpy as np
import pandas as pd
from tqdm import tqdm

# 전처리를 위한 라이브러리
from pycocotools.coco import COCO
import torchvision
import torchvision.transforms as transforms

import albumentations as A
from albumentations.pytorch import ToTensorV2

# 시각화를 위한 라이브러리
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
from matplotlib.patches import Patch
import webcolors
from collections import Counter, defaultdict

plt.rcParams['axes.grid'] = False

print('pytorch version: {}'.format(torch.__version__))
print('GPU 사용 가능 여부: {}'.format(torch.cuda.is_available()))

# Only query device details when CUDA is actually available:
# torch.cuda.get_device_name(0) raises on CPU-only machines, which would abort
# a Restart & Run All. device_count() is safe and simply reports 0 there.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
print(torch.cuda.device_count())

# Store the compute device depending on whether a GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

pytorch version: 1.7.1
GPU 사용 가능 여부: True
Tesla V100-SXM2-32GB
1


In [20]:
# Dataset configuration. All annotation files live directly under dataset_path;
# anns_file_path is the one used below to build the COCO index.
batch_size = 1
dataset_path = '../input/data'
anns_file_path = f'{dataset_path}/train_all.json'
train_path = f'{dataset_path}/train.json'
val_path = f'{dataset_path}/val.json'

In [21]:
# Preprocessing pipeline: resize each image (and, when passed, its mask) to 1024x1024.
transform1 = A.Compose([A.Resize(height=1024, width=1024)])


# Segmentation class names. The position of a name in this list is the pixel
# value written into the masks; unannotated pixels stay 0 (the first entry).
# NOTE(review): 'Backgroud' is misspelled; kept verbatim in case other cells
# match on this exact string.
category_names = [
    'Backgroud',
    'General trash',
    'Paper',
    'Paper pack',
    'Metal',
    'Glass',
    'Plastic',
    'Styrofoam',
    'Plastic bag',
    'Battery',
    'Clothing',
]


def get_classname(classID, cats):
    """Return the name of the category whose ``'id'`` equals ``classID``.

    Args:
        classID: category id to look up.
        cats: list of category dicts, each carrying ``'id'`` and ``'name'`` keys
            (the notebook passes the result of ``coco.loadCats``).

    Returns:
        The ``'name'`` of the first matching category, or the *string*
        ``"None"`` (not the ``None`` object) when nothing matches.
    """
    matching_names = (cat['name'] for cat in cats if cat['id'] == classID)
    return next(matching_names, "None")



In [22]:
# Build the COCO index over the annotation file (prints its own loading log),
# then select the preprocessing pipeline applied to every image/mask pair.
coco = COCO(anns_file_path)
transform = transform1

loading annotations into memory...
Done (t=5.26s)
creating index...
index created!


In [23]:
# Load the raw annotation JSON and collect every image id in file order.
# JSON text is UTF-8, so decode it explicitly instead of relying on the
# platform's locale encoding (which breaks non-ASCII names on some systems).
with open(anns_file_path, encoding='utf-8') as json_file:
    origin_data_json = json.load(json_file)
indexes = [image['id'] for image in origin_data_json['images']]

In [25]:
# Export every annotated image as a resized (image, mask) pair:
#   ./resized_data/image/<id>.jpg  -- resized RGB image scaled to [0, 1] before saving
#   ./resized_data/mask/<id>.npy   -- resized int8 mask of category pixel values
save_obj = {}

# plt.imsave / np.save do not create missing folders, so make sure they exist.
os.makedirs('./resized_data/image', exist_ok=True)
os.makedirs('./resized_data/mask', exist_ok=True)

# The category list does not depend on the image, so load it once.
cat_ids = coco.getCatIds()
cats = coco.loadCats(cat_ids)

for index in tqdm(indexes):
    image_infos = coco.loadImgs(index)[0]

    image_path = os.path.join(dataset_path, image_infos['file_name'])
    images = cv2.imread(image_path)
    # cv2.imread reports an unreadable/missing file by returning None rather than raising.
    if images is None:
        raise FileNotFoundError(f'could not read image: {image_path}')
    images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB).astype(np.float32)
    images /= 255.0

    ann_ids = coco.getAnnIds(imgIds=image_infos['id'])
    anns = coco.loadAnns(ann_ids)

    masks = np.zeros((image_infos["height"], image_infos["width"]))

    # Paint annotations in ascending len(segmentation[0]); where masks overlap,
    # annotations painted later overwrite earlier ones.
    # NOTE(review): assumes list-style (polygon) segmentations; an RLE dict
    # here would make segmentation[0] raise KeyError.
    anns = sorted(anns, key=lambda ann: len(ann['segmentation'][0]), reverse=False)
    for ann in anns:
        className = get_classname(ann['category_id'], cats)
        pixel_value = category_names.index(className)
        masks[coco.annToMask(ann) == 1] = pixel_value
    masks = masks.astype(np.int8)

    # Resize image and mask together with the albumentations pipeline.
    transformed = transform(image=images, mask=masks)
    images = transformed["image"]
    masks = transformed["mask"]

    plt.imsave(f'./resized_data/image/{index}.jpg', images)
    np.save(f'./resized_data/mask/{index}', masks)