In [None]:
# Make the project packages (models/, utils/) importable from this notebook.
import sys
import os 
project_root = os.path.abspath('/home/eiden/eiden/octc-cascade')
sys.path.append(project_root)

import torch 
from torchvision import transforms
from torch.utils.data import DataLoader
import os
import cv2
import numpy as np

from models.segment import load_segment_model 
from models.inpaint import load_inpaint_model

# Tolerate multiple OpenMP runtimes being loaded in one process instead of aborting.
# NOTE(review): this is set after torch and cv2 are already imported; if the duplicate-OpenMP
# error is raised at import time, this line may need to move above those imports — confirm.
os.environ['KMP_DUPLICATE_LIB_OK']='True'
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# NOTE(review): the wildcard import hides which names this notebook takes from utils;
# prefer importing the needed names explicitly.
from utils.__init__ import *
from utils.dataset import Inference_Cascade_CustomDataset

import matplotlib.pyplot as plt

In [None]:
# Test-set locations. Only the image directory is consumed below;
# Inference_Cascade_CustomDataset is not given the mask directory.
ts_img_dir = '/mnt/HDD/octc/seg_data/test_img'
ts_mask_dir = '/mnt/HDD/octc/seg_data/test_mask'

# Model input resolution. These globals are also read by cascade_models_load
# when it builds the segmentation and inpainting networks.
width, height = 512, 512

# torchvision's Resize takes (height, width). The order only did not matter
# before because the images are square.
test_transform = transforms.Compose([
    transforms.Resize((height, width)),
    transforms.ToTensor(),
    # Normalization to [-1, 1] is intentionally disabled:
    # transforms.Normalize(mean=[0.5], std=[0.5])
])

test_dataset = Inference_Cascade_CustomDataset(
    image_dir=ts_img_dir,
    transform=test_transform,
    seed=627,
)

# Keep the file order (shuffle=False) so outputs line up with input paths.
ts_batch = 9
test_loader = DataLoader(dataset=test_dataset, batch_size=ts_batch, shuffle=False)

In [None]:
# Sanity check: display the first transformed test image before running the models.
first_batch = next(iter(test_loader), None)
if first_batch is not None:
    images, paths = first_batch
    preview = images[0]
    # A single-channel (1, H, W) tensor is shown as a 2-D grayscale image;
    # multi-channel tensors are moved to (H, W, C) as imshow expects.
    preview = preview[0] if preview.shape[0] == 1 else preview.permute(1, 2, 0)
    plt.imshow(preview, cmap='gray')
    plt.title(os.path.basename(paths[0]))
    plt.show()


In [None]:
class cascade_models_load:
    """Load the segmentation and inpainting models used by the two-stage cascade.

    Checkpoint paths are expected to look like ``<root>/<run_name>/<file>.pt``.
    The run directory name (``<run_name>``) serves as the model's display name.
    Its first ``_``-separated token selects the architecture passed to the
    project model loaders.

    Args:
        seg_model_path: Segmentation checkpoint; must contain ``'model_state_dict'``.
        inpaint_model_path: Inpainting checkpoint; must contain ``'netG_state_dict'``.
        device: Used as ``torch.load(map_location=...)`` and for ``.to(device)``.
        width, height: Input resolution forwarded to the model builders. When
            omitted, the notebook-level ``width`` / ``height`` globals are used,
            which is the original behaviour.
    """

    def __init__(self, seg_model_path, inpaint_model_path, device, width=None, height=None):
        self.seg_model_path = seg_model_path
        self.inpaint_model_path = inpaint_model_path
        self.device = device
        self.width = width
        self.height = height
        # Run-directory names, used to build the cascade's output name.
        self.seg_model_name = self._parse_checkpoint_path(seg_model_path)[2]
        self.inpaint_model_name = self._parse_checkpoint_path(inpaint_model_path)[2]

    @staticmethod
    def _parse_checkpoint_path(checkpoint_path):
        """Split ``<root>/<run_name>/<file>`` into ``(save_dir, file_name, run_name)``."""
        model_save_path = os.path.dirname(checkpoint_path)
        model_version = os.path.basename(checkpoint_path)
        run_name = os.path.basename(model_save_path)
        return model_save_path, model_version, run_name

    def _input_size(self):
        """Return ``(width, height)``, falling back to the notebook globals when unset."""
        w = self.width if self.width is not None else width
        h = self.height if self.height is not None else height
        return w, h

    def init_seg_model(self):
        """Resolve the segmentation architecture from the checkpoint path and load it."""
        model_save_path, model_version, run_name = self._parse_checkpoint_path(self.seg_model_path)
        arch_prefix = run_name.split('_')[0]
        # Any run directory starting with 'monai' is mapped to the 'monai_swinunet' loader key.
        model_name = 'monai_swinunet' if arch_prefix == 'monai' else arch_prefix
        print(f" Model save path : {model_save_path}")
        print(f" Model version : {model_version}")
        print(f" Model name : {model_name}")

        self.load_seg_model(model_save_path, model_version, model_name)

    def load_seg_model(self, model_save_path, model_version, model_name):
        """Build the segmentation network and load its ``'model_state_dict'`` weights."""
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
        checkpoint = torch.load(
            os.path.join(model_save_path, model_version), map_location=self.device
        )['model_state_dict']

        w, h = self._input_size()
        seg_model_loader = load_segment_model.segmentation_models_loader(
            model_name=model_name, width=w, height=h
        )
        self.seg_model = seg_model_loader.load_model().to(self.device)
        self.seg_model.load_state_dict(checkpoint)

    def init_inpaint_model(self):
        """Resolve the inpainting architecture from the checkpoint path and load it."""
        model_save_path, model_version, run_name = self._parse_checkpoint_path(self.inpaint_model_path)
        model_name = run_name.split('_')[0]
        print(f" Model save path : {model_save_path}")
        print(f" Model version : {model_version}")
        print(f" Model name : {model_name}")

        self.load_inpaint_model(model_save_path, model_version, model_name)

    def load_inpaint_model(self, model_save_path, model_version, model_name):
        """Build the inpainting generator and load its ``'netG_state_dict'`` weights."""
        # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
        checkpoint = torch.load(
            os.path.join(model_save_path, model_version), map_location=self.device
        )['netG_state_dict']

        w, h = self._input_size()
        inpaint_model_loader = load_inpaint_model.inpainting_models_loader(
            model_name=model_name, width=w, height=h
        )
        self.inpaint_model = inpaint_model_loader.load_model().to(self.device)
        self.inpaint_model.load_state_dict(checkpoint)

    def get_cascade_model_name(self):
        """Return ``'<seg_run_name>@<inpaint_run_name>'``."""
        return f"{self.seg_model_name}@{self.inpaint_model_name}"

    def load_models(self):
        """Load both models and return ``(seg_model, inpaint_model)``."""
        self.init_seg_model()
        self.init_inpaint_model()

        return self.seg_model, self.inpaint_model
        
# Checkpoints for the two cascade stages, laid out as <root>/<run_name>/<file>.pt.
SEG_MODEL_PATH = '/mnt/HDD/oci-seg_models/monai_swinunet_v4_240530/model_400.pt'
INPAINT_MODEL_PATH = '/mnt/HDD/oci_models/aotgan/OCI-GAN_v3_240508/model_64.pt'
# Alternative inpainting checkpoint (VAE):
# INPAINT_MODEL_PATH = '/mnt/HDD/oci_models/models/VAE_v1_240510/model_27.pt'

cascade_model_loader = cascade_models_load(SEG_MODEL_PATH, INPAINT_MODEL_PATH, device)
seg_model, inpaint_model = cascade_model_loader.load_models()
# '<seg_run>@<inpaint_run>' — used below as the output directory name.
cascade_model_name = cascade_model_loader.get_cascade_model_name()

In [None]:
def plot_cascade_result(image, segment_output, inpaint_mask, inpaint_input, pred_image, inpaint_output, save_path, save=False):
    """Show the six intermediate stages of the cascade for one sample in a 2x3 grid.

    Each image argument is expected to be a 2-D array; the inference cell passes
    channel-0 slices converted with ``.cpu().numpy()``.

    Args:
        image: Input to the segmentation model.
        segment_output: Thresholded segmentation result.
        inpaint_mask: Mask handed to the inpainting model (after pre-processing).
        inpaint_input: Image handed to the inpainting model.
        pred_image: Raw inpainting prediction.
        inpaint_output: Composite of the input and the prediction (after post-processing).
        save_path: Destination file for the figure; only used when ``save`` is True.
        save: When True, also write the figure to ``save_path``. The default False
            keeps the previous display-only behaviour.
    """
    panels = [
        (image, 'Segment Input'),
        (segment_output, 'Segment Result'),
        (inpaint_mask, 'Inpainting Mask[PreProcess]'),
        (inpaint_input, 'Inpainting Input'),
        (pred_image, 'Inpainting Prediction'),
        (inpaint_output, 'Inpainting Results[PostProcess]'),
    ]

    fig, axes = plt.subplots(2, 3, figsize=(12, 8), dpi=256)
    for ax, (array, title) in zip(axes.flat, panels):
        ax.imshow(array, cmap='gray')
        ax.set_title(title)
    fig.tight_layout()

    if save and save_path:
        fig.savefig(save_path)
    plt.show()
    # Close explicitly: this is called once per sample inside a loop.
    plt.close(fig)
    


def mask_preprocessing(images, seg_masks, kernel_size=5, dilate_iterations=4, erode_iterations=2):
    """Build the inpainting-model inputs from a batch of segmentation masks.

    Each segmentation mask goes through ``dilate_iterations`` dilations and then
    ``erode_iterations`` erosions with a rectangular ``kernel_size`` x ``kernel_size``
    kernel. With more dilations than erosions, the net effect enlarges the mask
    and fills small gaps. The defaults reproduce the original hard-coded values.

    Args:
        images: Image batch tensor. It is repeated to 3 channels below, so it is
            presumably (B, 1, H, W) — TODO confirm.
        seg_masks: Segmentation mask batch with a singleton channel axis at dim 1
            (``squeeze(1)`` requires it), e.g. (B, 1, H, W).
        kernel_size: Side length of the rectangular structuring element.
        dilate_iterations: Number of dilation passes.
        erode_iterations: Number of erosion passes applied after dilation.

    Returns:
        ``(input_images, inpaint_masks)``:
            input_images: ``images`` repeated to 3 channels for the inpainting model.
            inpaint_masks: ``images`` multiplied by the processed mask, i.e. image
                intensities inside the grown region and 0 elsewhere (not a 0/1 mask).
    """
    # OpenCV morphology works on individual 2-D numpy arrays.
    masks_np = seg_masks.detach().cpu().numpy().squeeze(1)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_size, kernel_size))
    processed = [
        cv2.erode(
            cv2.dilate(mask, kernel, iterations=dilate_iterations),
            kernel,
            iterations=erode_iterations,
        )
        for mask in masks_np
    ]
    # Back to a (B, 1, H, W) tensor on the same device as the images.
    processed_masks = torch.from_numpy(np.array(processed)).unsqueeze(1).to(images.device)

    # Multiplying by the image removes the background outside the grown region.
    # NOTE(review): the result holds intensities, so pixels that are exactly 0 inside the
    # region are indistinguishable from background for any `!= 0` test downstream.
    # Confirm this matches how the inpainting model was trained.
    inpaint_masks = images * processed_masks

    # The inpainting model expects 3-channel input.
    input_images = images.repeat(1, 3, 1, 1)

    return input_images, inpaint_masks

def compute_composite_images(input_images, pred_images, inpaint_masks):
    """Paste inpainting predictions into the input wherever the mask is non-zero.

    Args:
        input_images: Tensor of shape (B, C, H, W) fed to the inpainting model.
        pred_images: Predictions with the same shape as ``input_images``.
        inpaint_masks: Tensor broadcastable to ``input_images`` along the channel
            axis, e.g. (B, 1, H, W) or (B, C, H, W). Any non-zero value marks a
            pixel to take from ``pred_images``.

    Returns:
        A new tensor equal to ``pred_images`` where the mask is non-zero and to
        ``input_images`` everywhere else. ``input_images`` itself is not modified.
    """
    # Build the selector once; expand_as broadcasts a single-channel mask across
    # every channel without copying and also accepts an already multi-channel mask.
    replace = (inpaint_masks != 0).expand_as(input_images)
    comp_images = input_images.clone()
    comp_images[replace] = pred_images[replace]
    return comp_images

In [None]:
# Binarisation threshold applied to the sigmoid of the segmentation logits.
threshold = 0.5
# Number of test batches to process and plot (the notebook previews a single batch).
max_batches = 1
save_dir = os.path.join('/mnt/HDD/oci_cascade_models', cascade_model_name)
os.makedirs(save_dir, exist_ok=True)

seg_model.eval()
inpaint_model.eval()
with torch.no_grad():
    for batch_idx, (images, paths) in enumerate(test_loader):
        images = images.to(device)

        # Stage 1: segment the medical marks and binarise the prediction.
        seg_outputs = (torch.sigmoid(seg_model(images)) > threshold).float()
        # Stage 2 pre-processing: grow the masks and build the inpainting inputs.
        inpaint_inputs, inpaint_masks = mask_preprocessing(images, seg_outputs)

        ################################################
        # pred_images = inpaint_model(inpaint_inputs, seg_outputs) # ocigan
        pred_images = inpaint_model(inpaint_inputs, inpaint_masks) # ocigan
        # pred_images, _, _, _ = inpaint_model(images, inpaint_masks) # vae
        # pred_images = inpaint_model(images, inpaint_masks) # unet
        ################################################

        # Post-processing: keep the prediction only where the mask is non-zero.
        comp_images = compute_composite_images(inpaint_inputs, pred_images, inpaint_masks)

        # Result shapes.
        print('\n', f"Image shape : {images.shape}, Segment shape : {seg_outputs.shape}")
        print(f"Inpaint Input shape : {inpaint_inputs.shape} Inpaint output shape : {comp_images.shape}")

        # One output file per input image, named after the input file.
        save_paths = [os.path.join(save_dir, os.path.basename(path)) for path in paths]
        # Iterate over the samples of *this* batch. len(test_loader) is the number of
        # batches and would index out of range or skip samples.
        for i in range(images.size(0)):
            plot_cascade_result(
                images[i, 0].cpu().numpy(), seg_outputs[i, 0].cpu().numpy(), inpaint_masks[i, 0].cpu().numpy(),
                inpaint_inputs[i, 0].cpu().numpy(), pred_images[i, 0].cpu().numpy(), comp_images[i, 0].cpu().numpy(),
                save_path=save_paths[i],
            )

        # Stop before the loader fetches another batch.
        if batch_idx + 1 >= max_batches:
            break
            