In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

import cv2

import torch
import albumentations as A

import mmcv
from mmcv import Config
from mmcv.runner import load_checkpoint
from mmcv.parallel import MMDataParallel

from mmseg.apis import single_gpu_test
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.models import build_segmentor

import pycocotools
from pycocotools.coco import COCO


In [None]:

# Load the config saved in the swin-b work dir, then patch it for test-time inference.
# Single source of truth for the work dir (config file and checkpoint both live here).
WORK_DIR = '/opt/ml/Git/p_stage_img_seg/work_dirs/swin-b/'
cfg = Config.fromfile(os.path.join(WORK_DIR, 'swin-b.py'))

# Test-split dataset / dataloader settings.
cfg.data.test.test_mode = True
cfg.data.samples_per_gpu = 2
cfg.data.workers_per_gpu = 2

# NOTE(review): cfg.seed is only recorded in the config; nothing in this notebook
# calls a seeding function with it (the test loader below uses shuffle=False).
cfg.seed = 123456
cfg.gpu_ids = [0]
cfg.work_dir = WORK_DIR

# Nothing is trained in this notebook, so clear the training-time model config.
cfg.model.train_cfg = None

In [None]:
# Test split (test_mode was enabled in the config cell) and a single-process,
# non-shuffled loader, so predictions come back in dataset order.
dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(
    dataset,
    shuffle=False,
    dist=False,
    workers_per_gpu=cfg.data.workers_per_gpu,
    samples_per_gpu=cfg.data.samples_per_gpu,
)

In [None]:
# Build the segmentor from the config and load the latest checkpoint from the work dir.
checkpoint_path = os.path.join(cfg.work_dir, 'latest.pth')

# NOTE(review): in mmseg configs test_cfg usually lives under cfg.model, so the
# top-level lookup is expected to return None here — confirm for this config.
model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))  # build segmentor
checkpoint = load_checkpoint(model, checkpoint_path)  # load weights into `model`

# Copy class names and palette from the dataset onto the model, then wrap it for GPU 0.
model.CLASSES = dataset.CLASSES
model.PALETTE = dataset.PALETTE
model = MMDataParallel(model.cuda(), device_ids=[0])


In [None]:
# Release cached, unused GPU memory before the inference pass.
torch.cuda.empty_cache()
output = single_gpu_test(model, data_loader) # compute predictions for the whole test loader
# NOTE(review): later cells index output[i] together with COCO img_ids[i], so `output` is
# assumed to hold one 2-D class-id mask per test image, in dataset order — confirm.

## Load COCO json

In [None]:
# Load the COCO test annotations; they provide the file name of every test image.
image_path = '/opt/ml/segmentation/input/data/'
coco = COCO(os.path.join(image_path, 'test.json'))
img_ids = coco.getImgIds()

# Accumulators for the submission rows.
prediction_strings = []
file_names = []

# Later cells pair output[i] with img_ids[i] purely by position. This catches a
# count mismatch (which would drop images or raise IndexError); it still assumes
# both sequences follow the same image order.
assert len(img_ids) == len(output), (
    f"{len(img_ids)} COCO test images vs {len(output)} predictions"
)

## Extract semantic mask and image

In [None]:
from pathlib import Path

In [None]:
def merge_image(insert_image, insert_mask, class_id, base_image=None, base_mask=None):
    """Paste the pixels of one class from ``insert_image`` onto a base image.

    Wherever ``insert_mask == class_id`` the output takes its pixel from
    ``insert_image`` and its mask value from ``insert_mask``; everywhere else it
    takes them from ``base_image`` / ``base_mask``.

    Args:
        insert_image: (H, W, C) image to copy pixels from (e.g. a cv2 BGR image).
        insert_mask: (H, W) per-pixel class-id mask aligned with ``insert_image``.
        class_id: class id whose pixels are copied.
        base_image: (H, W, C) image supplying the remaining pixels. Defaults to an
            all-white (255) uint8 image with the shape of ``insert_image``.
        base_mask: (H, W) mask supplying the remaining mask values. Defaults to
            all zeros (uint8) with the shape of ``insert_mask``.

    Returns:
        (merged_image, merged_mask): ``merged_image`` is uint8 with the shape of
        ``insert_image``; ``merged_mask`` has the shape of ``insert_mask`` and
        the dtype NumPy promotes ``insert_mask``/``base_mask`` to.
    """
    if base_image is None:
        base_image = np.full(insert_image.shape, 255, dtype=np.uint8)
    if base_mask is None:
        base_mask = np.zeros(insert_mask.shape, dtype=np.uint8)

    hit = insert_mask == class_id
    # One broadcast over all channels; channel order is irrelevant here (cv2 images
    # are BGR, so the old "R/G/B" per-channel labels were wrong anyway).
    merged_image = np.where(hit[..., None], insert_image, base_image).astype(np.uint8, copy=False)
    merged_mask = np.where(hit, insert_mask, base_mask)
    return merged_image, merged_mask

In [None]:
classes = ("Background", "General trash", "Paper", "Paper pack", "Metal", "Glass", "Plastic", "Styrofoam", "Plastic bag", "Battery", "Clothing")

# ------------------------------------------- #
target_category = 1  # index into `classes`
# ------------------------------------------- #

save_dir = f"/opt/ml/segmentation/extract_image/{target_category}"
Path(f"{save_dir}/image/").mkdir(exist_ok=True, parents=True)
Path(f"{save_dir}/mask/").mkdir(exist_ok=True, parents=True)

for i, seg_mask in tqdm(enumerate(output), total=len(output), desc=classes[target_category]):
    # Skip predictions that contain no pixel of the target class.
    if not (seg_mask == target_category).any():
        continue
    image_info = coco.loadImgs(img_ids[i])[0]
    image_file = os.path.join(image_path, image_info['file_name'])
    image = cv2.imread(image_file)
    if image is None:  # cv2.imread signals a missing/unreadable file by returning None
        raise FileNotFoundError(f"could not read image: {image_file}")
    catImage, catSegMask = merge_image(image, seg_mask, target_category)
    # Flatten "<batch_dir>/<name>.jpg" into "<batch_dir>_<name>.png" to keep names unique.
    _path = Path(image_info['file_name'])
    image_name = f"{save_dir}/image/{_path.parent}_{_path.stem}.png"
    mask_name = f"{save_dir}/mask/{_path.parent}_{_path.stem}.mask.png"
    # cv2.imwrite reports failure via its return value instead of raising.
    if not cv2.imwrite(image_name, catImage):
        raise OSError(f"failed to write {image_name}")
    if not cv2.imwrite(mask_name, catSegMask):
        raise OSError(f"failed to write {mask_name}")

## Build submission (downsample predictions to 256x256 and write CSV)

In [None]:
def prepare_submission_mask(output, img_size=(512, 512), out_size=(256, 256)):
    """Resize one predicted class-id mask to the submission resolution and flatten it.

    Args:
        output: 2-D per-pixel class-id mask for one image.
        img_size: shape of the throwaway image passed alongside the mask, because
            the albumentations transform is called with an ``image`` argument.
            It only sizes that dummy image.
        out_size: (height, width) of the submission mask; defaults to 256x256.

    Returns:
        1-D array of ``out_size[0] * out_size[1]`` class ids, row-major.
        (albumentations resizes masks with nearest-neighbour interpolation by
        default, so class ids are not blended.)
    """
    # Tuple defaults: the previous list default was a shared mutable object.
    dummy_image = np.zeros(img_size)
    height, width = out_size
    resized_mask = A.Resize(height, width)(image=dummy_image, mask=output)['mask']
    return resized_mask.flatten()

In [None]:

# Rebuild the rows from scratch so re-running this cell never duplicates entries.
prediction_strings = []
file_names = []
for i, out in tqdm(enumerate(output), total=len(output)):
    image_info = coco.loadImgs(img_ids[i])[0]
    prediction_strings.append(' '.join(map(str, prepare_submission_mask(out).tolist())))
    file_names.append(image_info['file_name'])

# NOTE(review): the file name says "pointrend" but the config loaded at the top is
# swin-b — rename if that is unintended.
submission_path = '/opt/ml/Git/p_stage_img_seg/submission/mmseg_pointrend_submission.yb007.csv'
os.makedirs(os.path.dirname(submission_path), exist_ok=True)  # to_csv does not create dirs
submission = pd.DataFrame({'image_id': file_names, 'PredictionString': prediction_strings})
submission.to_csv(submission_path, index=None)
submission.head()

In [None]:
# Visual spot-check: raw model prediction next to the downsampled submission mask.
num_preview = min(10, len(output))
vmax = len(dataset.CLASSES) - 1  # fixed colour range so class colours match across figures
for i in range(num_preview):
    fig, (ax_raw, ax_sub) = plt.subplots(1, 2, figsize=(8, 4))
    ax_raw.imshow(output[i], vmin=0, vmax=vmax)
    ax_raw.set_title(f"prediction {i} {tuple(output[i].shape)}")
    ax_sub.imshow(prepare_submission_mask(output[i]).reshape(256, 256), vmin=0, vmax=vmax)
    ax_sub.set_title(f"submission mask {i} (256, 256)")
    for ax in (ax_raw, ax_sub):
        ax.axis('off')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures in the loop