In [1]:
import os
import cv2
import albumentations as A
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()

from skimage.color import label2rgb
from glob import glob
plt.rcParams['axes.grid'] = False

In [2]:
def label_mask(mask, colors=None, num_classes=11):
    """Colorize a class-index mask for display.

    Args:
        mask: class-index mask, either (H, W) or (H, W, C). For a 3-D mask only
            the last channel is read (presumably the id is repeated in every
            channel, e.g. a grayscale PNG loaded with cv2.imread's default
            colour flag -- TODO confirm for the dataset).
        colors: sequence of RGB triples indexed by class id. Defaults to the
            notebook-level ``palette`` (looked up at call time because that
            cell runs after this one).
        num_classes: ids ``1 .. num_classes - 1`` are painted. Id 0 and any
            other value are left black.

    Returns:
        (H, W, 3) uint8 RGB image.
    """
    if colors is None:
        colors = palette
    if mask.ndim == 3:
        mask = mask[:, :, -1]
    label = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    for class_id in range(1, num_classes):
        label[mask == class_id] = colors[class_id]
    return label

def augment_and_show(transforms, image, mask=None):
    """Apply ``transforms`` to an image (and optional mask) and plot before/after.

    Args:
        transforms: callable used as ``transforms(image=..., mask=...)`` that
            returns a dict with an ``'image'`` entry and, when a mask is given,
            a ``'mask'`` entry (e.g. an albumentations ``Compose``).
        image: 3-channel BGR image, as returned by ``cv2.imread``.
        mask: optional class-index mask, either (H, W) or (H, W, C). It is
            drawn with ``label_mask``'s palette colours.

    Returns:
        ``(augmented_image, augmented_mask)`` as produced by ``transforms``.
        ``augmented_mask`` is whatever ``transforms`` returned under ``'mask'``,
        or ``None`` if that key is absent.
    """
    transformed = transforms(image=image, mask=mask)

    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_aug_rgb = cv2.cvtColor(transformed['image'], cv2.COLOR_BGR2RGB)

    if mask is None:
        _, ax = plt.subplots(1, 2, figsize=(16, 8))

        ax[0].imshow(image_rgb)
        ax[0].set_title('Original image')

        ax[1].imshow(image_aug_rgb)
        ax[1].set_title('Augmented image')
    else:
        def _index_layout(m):
            # label_mask colours by the LAST channel of a 3-D array.
            # 3-D masks: keep channel 0, which is the channel the old BGR->RGB
            # swap moved to the end.
            # 2-D masks: just add a channel axis. Running them through
            # skimage's label2rgb first turned the ids into floats in [0, 1],
            # so nothing matched the integer ids label_mask looks for.
            return m[:, :, :1] if m.ndim == 3 else m[:, :, np.newaxis]

        mask_vis = label_mask(_index_layout(mask))
        mask_aug_vis = label_mask(_index_layout(transformed['mask']))

        _, ax = plt.subplots(2, 2, figsize=(8, 8))

        ax[0, 0].imshow(image_rgb)
        ax[0, 0].set_title('Original image')

        ax[0, 1].imshow(image_aug_rgb)
        ax[0, 1].set_title('Augmented image')

        ax[1, 0].imshow(mask_vis, interpolation='nearest')
        ax[1, 0].set_title('Original mask')

        ax[1, 1].imshow(mask_aug_vis, interpolation='nearest')
        ax[1, 1].set_title('Augmented mask')

    # .get: some albumentations versions omit 'mask' when mask=None was passed.
    return transformed['image'], transformed.get('mask')

In [3]:
# Class names, indexed by label id (0 = background).
classes = [
    'Background', 'General trash', 'Paper', 'Paper pack', 'Metal', 'Glass',
    'Plastic', 'Styrofoam', 'Plastic bag', 'Battery', 'Clothing',
]

# RGB colour per label id, used by label_mask: palette[i] pairs with classes[i].
# The last entry (index 11) has no matching class name.
palette = [
    [0, 0, 0],        # 0  Background
    [192, 0, 128],    # 1  General trash
    [0, 128, 192],    # 2  Paper
    [0, 128, 64],     # 3  Paper pack
    [128, 0, 0],      # 4  Metal
    [64, 0, 128],     # 5  Glass
    [64, 0, 192],     # 6  Plastic
    [192, 128, 64],   # 7  Styrofoam
    [192, 192, 128],  # 8  Plastic bag
    [64, 64, 128],    # 9  Battery
    [128, 0, 192],    # 10 Clothing
    [64, 64, 64],     # 11 (no class)
]

# Normalisation settings (values look like the usual ImageNet mean/std on a
# 0-255 scale). Not referenced by the cells in this notebook section.
img_norm_cfg = {
    'mean': [123.675, 116.28, 103.53],
    'std': [58.395, 57.12, 57.375],
    'to_rgb': True,
}

## 적용할 augmentation 미리보기

1. file_path : image와 annotation이 저장된 기존 경로
2. ignore_label : 제외하고 싶은 label 목록 (이 label이 하나라도 포함된 이미지는 augmentation 대상에서 제외됩니다)
3. transforms : 적용할 augmentation 기법들
    - 원하는 augmentation을 추가하실 수 있습니다.

In [4]:
# --- Dataset root: holds images/{train,valid} and annotations/{train,valid} ---
file_path = '/opt/ml/input/mmseg2'

# --- Label ids to exclude ---
# An image whose mask contains any of these ids is not augmented.
# Id -> class: 0 Background, 1 General trash, 2 Paper, 3 Paper pack, 4 Metal,
# 5 Glass, 6 Plastic, 7 Styrofoam, 8 Plastic bag, 9 Battery, 10 Clothing.
# Currently excluded: Paper (2), Plastic (6), Plastic bag (8).
ignore_label = [2, 6, 8]

# --- Augmentation pipeline (add more albumentations transforms here) ---
# Both flips use p=1, so every image is flipped both ways
# (equivalent to a 180-degree rotation).
transforms = A.Compose([A.HorizontalFlip(p=1), A.VerticalFlip(p=1)])

In [None]:
# Id of the training image to preview.
image_id = '0022'
mask = cv2.imread(os.path.join(file_path, 'annotations', 'train', image_id + '.png'))
image = cv2.imread(os.path.join(file_path, 'images', 'train', image_id + '.jpg'))

# cv2.imread returns None (it does not raise) when a file is missing or
# unreadable. Fail here with a clear message instead of deep inside the transforms.
assert image is not None, f'could not read image {image_id!r} under {file_path}'
assert mask is not None, f'could not read annotation {image_id!r} under {file_path}'

# augment_and_show applies `transforms` itself. A separate transforms(...) call
# here would be wasted work, and for random transforms it would not match the plot.
show = augment_and_show(transforms, image, mask)

## 이미지 추가

train / valid 폴더의 각 이미지에 `transforms`를 적용하고, 파일명 앞에 `1`을 붙여 원본과 같은 폴더에 저장합니다.

In [None]:
# Write an augmented copy of every image/annotation pair into the same folders.
# Each output file name is the source name prefixed with '1'.
# Images whose mask contains any id in `ignore_label` are skipped.
# NOTE(review): the 'valid' split is augmented too, which changes what the model
# is evaluated on -- confirm that is intended.
# NOTE(review): outputs are written next to the inputs. Re-running this cell
# picks up the previous run's '1...' files and augments them again. Existing
# files are never overwritten.
for phase in ['train', 'valid']:
    image_dir = os.path.join(file_path, 'images', phase)
    annotation_dir = os.path.join(file_path, 'annotations', phase)
    image_paths = sorted(glob(os.path.join(image_dir, '*.jpg')))
    annotation_paths = sorted(glob(os.path.join(annotation_dir, '*.png')))
    print(f'{phase}_image_count : {len(image_paths)}')
    print(f'{phase}_annotation_count : {len(annotation_paths)}')

    for image_path in image_paths:
        image_name = os.path.splitext(os.path.basename(image_path))[0]
        # Pair by file name, not by sorted position. With positional pairing,
        # one missing annotation misaligns every later pair.
        mask_path = os.path.join(annotation_dir, image_name + '.png')
        new_image_name = '1' + image_name
        new_image_path = os.path.join(image_dir, new_image_name + '.jpg')
        new_mask_path = os.path.join(annotation_dir, new_image_name + '.png')

        if not os.path.exists(mask_path):
            print(f'skip {image_name}: no annotation')
            continue
        if os.path.exists(new_image_path) or os.path.exists(new_mask_path):
            # Never clobber an existing file ('1' + id may collide with a real id).
            continue

        image = cv2.imread(image_path)
        mask = cv2.imread(mask_path)
        if image is None or mask is None:
            print(f'skip {image_name}: unreadable file')
            continue

        # Test every pixel against every ignored id. The old `ignore_label in mask`
        # broadcast the list across the channel axis, so it only worked by
        # coincidence for exactly 3 ids.
        if np.isin(mask, ignore_label).any():
            continue

        # Augment in RGB so colour-aware transforms behave as intended,
        # then convert back to BGR for cv2.imwrite.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        transformed = transforms(image=image_rgb, mask=mask)
        image_aug = cv2.cvtColor(transformed['image'], cv2.COLOR_RGB2BGR)

        # Save the mask as a single index channel (channel 0, the same channel
        # label_mask effectively displays). The old [:512, :512] crop is gone:
        # it cropped only the mask, never the image.
        mask_aug = transformed['mask'][:, :, 0]

        cv2.imwrite(new_image_path, image_aug)
        cv2.imwrite(new_mask_path, mask_aug)

    image_paths = glob(os.path.join(image_dir, '*.jpg'))
    annotation_paths = glob(os.path.join(annotation_dir, '*.png'))

    print('####################### 이미지 추가 이후 ##########################')
    print(f'{phase}_image_count : {len(image_paths)}')
    print(f'{phase}_annotation_count : {len(annotation_paths)}')
    print('##################################################################')
