In [7]:
import sys

sys.path.append('../')

# torch
import torch.nn as nn
from torch.utils.data import DataLoader

from xraydataset import XRayDataset, split_data
from utils import get_sorted_files_by_type, label2rgb, decode_rle_to_mask

from constants import TRAIN_DATA_DIR, CLASSES, PALETTE, TEST_DATA_DIR

from argparse import ArgumentParser

import albumentations as A

import os

import numpy as np

import shutil

from PIL import Image, ImageDraw
import cv2

import wandb

import pandas as pd

import datetime

In [8]:
def create_pred_mask_dict(csv_path, input_size):
    df = pd.read_csv(csv_path)

    mask_dict = dict()

    # 그룹화하여 처리
    grouped = df.groupby('image_name')
    for idx, (image_name, group) in enumerate(grouped):
        if idx > 10:
            break
        print(f'creating mask for {image_name}...')
        masks = dict()
        mask_test = []
        for _, row in group.iterrows():
            classname = row['class']
            rle = row['rle']
            if isinstance(rle, str):
                mask = decode_rle_to_mask(rle, 2048, 2048)
                mask_resized = np.array(Image.fromarray(mask).resize((input_size, input_size)))
                masks[classname]=mask_resized
                mask_test.append(mask_resized)

        mask_dict[image_name] = mask_test
    print('mask creation from csv is done')
    return mask_dict

In [9]:
def draw_outline(image, label, is_binary = False):

    draw = ImageDraw.Draw(image)

    for i, class_label in enumerate(label):
        if class_label.max() > 0:  # Only process if the class is present in the image
            color = PALETTE[i] if not is_binary else 1

            

    return image

In [10]:
def visual_dataset(visual_loader, mask_dict):
    save_dir = 'visualize/'

    if os.path.exists(save_dir):  
        shutil.rmtree(save_dir)

    os.makedirs(save_dir, exist_ok=True)    

    for image_names, images in visual_loader:
        for image_name, image in zip(image_names, images):
            img = image.permute(1, 2, 0).numpy()
            img = (img * 255).astype(np.uint8)
            
            mask = mask_dict.get(image_name, None)
            if mask is not None:  
                for i, class_label in enumerate(mask):
                    print(class_label.shape)
                    img[class_label == 1] = PALETTE[i]
                    
                img = Image.fromarray(img)  # [C, H, W] -> [H, W, C]로 변환   
                img.save(os.path.join(save_dir, image_name))

In [11]:
def visualize_test(csv_path, input_size=1024):
    image_root = os.path.join(TEST_DATA_DIR, 'DCM')

    pngs = get_sorted_files_by_type(image_root, 'png')

    visualize_dataset = XRayDataset(image_files=np.array(pngs), label_files=None, transforms=A.Resize(input_size, input_size))

    visual_loader = DataLoader(
        dataset=visualize_dataset, 
        batch_size=8,
        shuffle=False,
        num_workers=1,
        drop_last=False,
    )

    mask_dict = create_pred_mask_dict(csv_path, input_size)
    
    visual_dataset(visual_loader, mask_dict)

In [12]:
visualize_test('../output.csv')

creating mask for image1661319116107.png...
creating mask for image1661319145363.png...
creating mask for image1661319356239.png...
creating mask for image1661319390106.png...
creating mask for image1661320372752.png...
creating mask for image1661320397148.png...
creating mask for image1661320538919.png...
creating mask for image1661320557045.png...
creating mask for image1661320671343.png...
creating mask for image1661320722689.png...
creating mask for image1661320864475.png...
mask creation from csv is done
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(1024, 1024)
(102