## Load the trained model and a validation test batch

In [1]:
import torch
from model import UNet3d

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet3d(in_channels=3, n_classes=3, n_channels=32)
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a
# GPU still loads when this notebook falls back to the CPU.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(
    torch.load('saved_model/best-basic-e100.pth', map_location=device)
)
model = model.to(device)
model.eval()

UNet3d(
  (conv): DoubleConv(
    (double_conv): Sequential(
      (0): Conv3d(3, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): GroupNorm(8, 32, eps=1e-05, affine=True)
      (2): ReLU(inplace=True)
      (3): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (4): GroupNorm(8, 32, eps=1e-05, affine=True)
      (5): ReLU(inplace=True)
    )
  )
  (enc1): Down(
    (encoder): Sequential(
      (0): MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (1): GroupNorm(8, 64, eps=1e-05, affine=True)
          (2): ReLU(inplace=True)
          (3): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (4): GroupNorm(8, 64, eps=1e-05, affine=True)
          (5): ReLU(inplace=True)
        )
      )
    )
  )
  (

In [2]:
from dataset import BratsDataset
from utils import get_dataloader

# Validation split; presumably len() is the number of batches -- confirm in utils.
val_dataloader = get_dataloader(phase="val", dataset=BratsDataset)
len(val_dataloader)

32

In [3]:
# Only the first validation batch is needed for this qualitative check.
val_batches = iter(val_dataloader)
test_batch = next(val_batches)

In [4]:
# Show which fields each batch carries (recorded output: 'Id', 'image', 'mask').
test_batch.keys()

dict_keys(['Id', 'image', 'mask'])

In [5]:
# Split the batch dict into case ids, input volumes and ground-truth masks.
batch_id = test_batch['Id']
images = test_batch['image']
targets = test_batch['mask']
for item in (batch_id, images.shape, targets.shape):
    print(item)

['BraTS-GLI-01344-000', 'BraTS-GLI-00703-001', 'BraTS-GLI-00097-000', 'BraTS-GLI-01147-000']
torch.Size([4, 3, 78, 120, 120])
torch.Size([4, 3, 78, 120, 120])


In [6]:
# Place inputs and targets on the same device as the model.
images, targets = images.to(device), targets.to(device)

In [7]:
# Rebuild from the untouched batch rather than re-indexing `targets`, so this
# cell is idempotent: running `targets = targets[0]` twice would silently
# drop the channel axis. Reading straight from the CPU batch also avoids a
# pointless GPU round trip.
targets = test_batch['mask'][0].detach().cpu().numpy()  # first sample only
targets.shape

(3, 78, 120, 120)

### Load the original 155-slice segmentation mask

In [8]:
import os

import nibabel as nib
import numpy as np

data_path = 'brats_data'
phase = 'val'
id_ = batch_id[0]
data_type = '-seg.nii.gz'
# Read the segmentation for the first case of the batch and move its last
# axis (the 155-slice axis, per the recorded output) to the front.
_whole_slices = np.asarray(
    nib.load(os.path.join(data_path, phase, id_, f"{id_}{data_type}")).dataobj
).transpose(2, 0, 1)
_whole_slices.shape  # (155, 240, 240)

(155, 240, 240)

In [9]:
def preprocess_mask_labels(mask, verbose=True):
    """Split a label volume into stacked binary WT / TC / ET masks.

    NOTE(review): the label meanings assumed here (1 = NCR/NET,
    2 = the region excluded from the core, 3 = enhancing tumour) follow the
    channel definitions below; confirm against the dataset's label spec.

    Args:
        mask: Integer label array of any shape (here (155, 240, 240)).
        verbose: If True (default, matching the notebook's recorded output),
            print the input shape before processing.

    Returns:
        Array of shape (3, *mask.shape) with the same dtype as ``mask``:
        channel 0 = whole tumour (labels 1, 2, 3), channel 1 = tumour core
        (labels 1, 3), channel 2 = enhancing tumour (label 3). Every other
        value, including 0, maps to 0.
    """
    if verbose:
        print("before preprocess", mask.shape)
    # np.isin gives strictly 0/1 masks. A value-by-value remap would leave any
    # unexpected label (e.g. a legacy "4") untouched in every channel.
    mask_WT = np.isin(mask, (1, 2, 3)).astype(mask.dtype)  # whole tumour
    mask_TC = np.isin(mask, (1, 3)).astype(mask.dtype)     # tumour core: excludes 2
    mask_ET = np.isin(mask, (3,)).astype(mask.dtype)       # enhancing: 3 only
    return np.stack([mask_WT, mask_TC, mask_ET])

In [10]:
# Expand the raw label volume into stacked WT / TC / ET channels.
whole_slices = preprocess_mask_labels(mask=_whole_slices)
print("after preprocess", whole_slices.shape)

before preprocess (155, 240, 240)
after preprocess (3, 155, 240, 240)


## Run the model and post-process its output (threshold, resize to original size)

In [11]:
# Probability cut-off for turning sigmoid outputs into binary masks.
threshold = 0.33
with torch.no_grad():
    # Run a single-sample batch (the first volume) through the model.
    logits = model(images[0].unsqueeze(0))
    probs = torch.sigmoid(logits).detach().cpu().numpy()
    pred = (probs >= threshold).astype(int)
    print(pred.min(), pred.max(), pred.shape)

torch.Size([1, 32, 78, 120, 120])
torch.Size([1, 64, 78, 60, 60])
torch.Size([1, 128, 78, 30, 30])
torch.Size([1, 256, 78, 15, 15])
torch.Size([1, 256, 78, 7, 7])
0 1 (1, 3, 78, 120, 120)


In [12]:
from skimage.transform import resize

# Drop the batch axis exactly once; the guard keeps this cell safe to re-run.
if pred.ndim == 5:
    pred = pred[0]  # (3, 78, 120, 120)
# Upsample to the original volume size (per channel). order=0 (nearest
# neighbour) keeps the prediction binary: the default linear interpolation
# would produce fractional values along mask boundaries.
resized_pred = resize(
    pred,
    (pred.shape[0],) + whole_slices.shape[1:],  # (3, 155, 240, 240)
    order=0,
    preserve_range=True,
    anti_aliasing=False,
).astype(pred.dtype)

In [13]:
# Compare model-resolution and original-resolution prediction shapes.
tuple(arr.shape for arr in (pred, resized_pred))

((3, 78, 120, 120), (3, 155, 240, 240))

In [14]:
# One dict per volume, keyed by tumour sub-region; channel order is WT, TC, ET.
mask_set, pred_set, ori_mask_set, ori_pred_set = (
    {region: volume[channel] for channel, region in enumerate(("WT", "TC", "ET"))}
    for volume in (targets, pred, whole_slices, resized_pred)
)

In [15]:
import matplotlib.pyplot as plt
def comp_gt_pred(percent, slices, gt, pred):
    """Plot ground truth (top row) against prediction (bottom row) at one slice.

    Args:
        percent: Slice position as a percentage of ``slices``.
        slices: Number of slices along the first axis of every volume.
        gt: Mapping of class name -> volume indexed by slice first.
        pred: Mapping with the same class names as ``gt``.
    """
    # Clamp so percent >= 100 picks the last slice instead of indexing past
    # the end, and a negative percent does not wrap around to the back.
    slide2show = max(0, min(int(percent / 100 * slices), slices - 1))
    print(slide2show)
    # One column per class; squeeze=False keeps `axes` 2-D even for one class.
    fig, axes = plt.subplots(2, max(len(gt), 1), squeeze=False)
    fig.suptitle(f"Compare at slice {slide2show}/{slices}({percent}%)")
    for col, clas in enumerate(gt):
        top, bottom = axes[0, col], axes[1, col]
        top.set_title(f'GT of {clas}')
        top.imshow(gt[clas][slide2show], cmap='gray')
        top.axis('off')
        bottom.set_title(f'prediction of {clas}')
        bottom.imshow(pred[clas][slide2show], cmap='gray')
        bottom.axis('off')
    plt.show()

In [None]:
# for percent in range(30,50,4):
#     comp_gt_pred(percent,78, mask_set, pred_set)

In [None]:
# for percent in range(30,50,4):
#     comp_gt_pred(percent,155, ori_mask_set, ori_pred_set)

## Slice mapping (cropped slice index → original slice index)

In [18]:
# Map each of the 78 kept slice indices to its original index (every 2nd of
# 0..154 -- presumably how the dataset subsamples; confirm in dataset.py).
# Mode 'w' rather than 'a': re-running the cell must rewrite the file, not
# append a duplicate copy of the mapping.
with open('slice_to_78.txt', 'w') as f:
    for cut_idx, real_idx in enumerate(range(0, 155, 2)):
        f.write(f"{cut_idx}, {real_idx}\n")

In [19]:
# Map each of 50 slice indices to original indices 3, 6, ..., 150.
# Mode 'w' rather than 'a' so re-running does not append duplicates.
with open('slice_to_50.txt', 'w') as f:
    for cut_idx, real_idx in enumerate(range(3, 153, 3)):
        f.write(f"{cut_idx}, {real_idx}\n")

In [21]:
# Slice subsampling config: range(start_num, end_num, interval).
start_num, end_num, interval = 0, 155, 2
(start_num, end_num, interval)

(0, 155, 2)

In [46]:
import json

# Reuse the start/end/interval configured in the cell above instead of
# prompting with input(): an interactive prompt blocks Restart & Run All and
# makes the written mapping depend on whatever was typed.
dt = [start_num, end_num, interval]
# [[cut_idx, real_idx], ...] pairs, i.e. cropped slice index -> original index.
data = [[cut_idx, real_idx] for cut_idx, real_idx in enumerate(range(*dt))]
with open('data.json', 'w') as f:
    json.dump(data, f)
