In [1]:
import torch
import json
import os
from utils import get_dataloader
from dataset import BratsDataset
import numpy as np
import matplotlib.pyplot as plt

# Names of the three prediction channels, in channel order; used to label the
# per-class PNG files written by process_all.
classes = ['WT', 'TC', 'ET']

# Root directory for exported images. Defaults to the original author's local
# path; set the BRATS_OUT_PATH environment variable to redirect the output
# without editing the notebook.
OUT_PATH = os.environ.get("BRATS_OUT_PATH", "/home/jiwoo/바탕화면/final/brats")

def process_all(id, pred, model_sign, slide2show=85):
    """Save one slice of a binarised prediction as PNG images.

    Creates ``{OUT_PATH}/{id}`` if needed and writes into it:

    * ``{model_sign}_{slide2show}_Pred_TOTAL.png`` - the combined label map
      returned by ``re_process`` (values 0..3), rendered with the 'copper'
      colormap;
    * ``{model_sign}_{slide2show}_Pred_{class}.png`` - one grayscale image per
      entry of ``classes``, taken from ``pred[i][slide2show]``.

    Args:
        id: Case identifier; used as the output sub-directory name.
        pred: Per-class prediction volume indexed as ``pred[class][slice]``
            (presumably a 0/1 array of shape (3, depth, H, W) - confirm
            against the caller).
        model_sign: Short tag identifying the model; used as filename prefix.
        slide2show: Index of the slice to export.
    """
    plt_save_path = f"{OUT_PATH}/{id}"
    os.makedirs(plt_save_path, exist_ok=True)
    file_prefix = f"{plt_save_path}/{model_sign}_{slide2show}_Pred"

    res = re_process(pred, slide2show)
    # Pin the colour range: without vmin/vmax imsave rescales every image to
    # its own min/max, so a given label would change colour depending on
    # which other labels happen to be present in the slice.
    plt.imsave(f"{file_prefix}_TOTAL.png", res, cmap='copper', vmin=0, vmax=3)
    for i, clas in enumerate(classes):
        # Fixed 0..1 range so that a uniform mask (all 0 or all 1) is still
        # rendered with the correct intensity.
        plt.imsave(f"{file_prefix}_{clas}.png", pred[i][slide2show],
                   cmap='gray', vmin=0, vmax=1)
    # Ground-truth masks are not exported: no target array is passed in.

def re_process(pred, slide2show):
    """Merge the three binary channels of one slice into a single label map.

    The first three entries of ``pred`` are read at index ``slide2show``
    (named wt / tc / et in the original code, matching ``classes``).
    Each pixel of the result is the sum of:

    * 1 where channel 1 is set and channel 2 is not,
    * 2 where channel 0 is set and channel 1 is not,
    * 3 where channel 2 is set,

    with any sum above 4 replaced by 3, so the output holds values 0..3.

    Args:
        pred: Array-like indexed as ``pred[channel][slice]``.
        slide2show: Slice index to extract.

    Returns:
        Integer array with the same spatial shape as one slice.
    """
    whole, core, enhancing = (pred[channel][slide2show] for channel in range(3))

    core_only_mask = (core == 1) & (enhancing == 0)
    outer_only_mask = (whole == 1) & (core == 0)
    enhancing_mask = enhancing == 1

    label_map = (
        np.where(core_only_mask, 1, 0)
        + np.where(outer_only_mask, 2, 0)
        + np.where(enhancing_mask, 3, 0)
    )
    # A pixel flagged as both "outer only" and "enhancing" sums to 5;
    # fold it back onto the enhancing label.
    return np.where(label_map > 4, 3, label_map)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the registry describing every trained model (a list of dicts, each
# carrying a 'model_name' key).
with open('models/model_subscriptions.json', 'r', encoding='utf-8') as f:
    models_info = json.load(f)

model_name = "ch3_32_interval_3_240"
model_info = next((entry for entry in models_info if entry['model_name'] == model_name), None)
if model_info is None:
    # Fail here with a clear message instead of a later
    # "'NoneType' object is not subscriptable" when model_info is indexed.
    raise ValueError(f"model {model_name!r} not found in models/model_subscriptions.json")
model_info

{'model_name': 'ch3_32_interval_3_240',
 'depth / in_channel / n_channel': [52, 3, 32],
 'img_size': 240,
 'used_channel': ['-t1n.nii.gz', '-t1c.nii.gz', '-t2f.nii.gz'],
 'val score(dice/jaccard)': [88, 81],
 'batch/total epoch/best epoch': [1, 50, 33],
 'resize_info': [0, 155, 3],
 'run time(m)': 1360.0}

In [3]:
# The checkpoint stores the whole pickled model object (not a state dict), so
# only load files from a trusted source: unpickling can execute arbitrary code.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine,
# which the device fallback above otherwise could not support.
model = torch.load(os.path.join('models', model_name, f"{model_name}.pth"),
                   map_location=device).to(device)
model.eval()

# NOTE(review): img_depth=64 does not match the depth of 155 shown by
# pred.shape below - confirm how get_dataloader uses this argument.
test_dataloader = get_dataloader(dataset=BratsDataset, phase="test", img_depth=64, img_width=240,
                                 data_type=model_info['used_channel'], batch_size=1)

# Take the first test batch only.
data_batch = next(iter(test_dataloader))
batch_id, images, targets = data_batch['Id'], data_batch['image'], data_batch['mask']
batch_id

['BraTS-GLI-00715-001']

In [4]:
images = images.to(device)
targets = targets.detach().numpy()

# Inference only: disable autograd so no graph or intermediate activations
# are kept for the whole 3D volume.
with torch.no_grad():
    logits = model(images)

# Per-voxel probabilities, moved to host memory as a NumPy array.
pred = torch.sigmoid(logits).cpu().numpy()

# Binarise the probabilities; a voxel is foreground when p >= threshold.
threshold = 0.33
pred = (pred >= threshold).astype(int)

In [5]:
# Display the shape of the binarised prediction array as the cell output.
pred.shape

(1, 3, 155, 240, 240)

In [6]:
# One slice index per case in the batch (the batch holds a single case here).
slides = [113]
for case_id, case_pred, slice_index in zip(batch_id, pred, slides):
    # 'u' is the model tag used as the filename prefix.
    process_all(case_id, case_pred, 'u', slide2show=slice_index)