In [None]:
import os
from random import randint

import numpy as np
import scipy as sp
import torch
from matplotlib import pyplot as plt
from PIL import Image
from scipy.io import savemat

from eval import eval_volumes
from predict import predict_vol_from_vol
from unet.unet_model import UNet
from utils.dataset import CTMaskDataset
from utils.utils import plotSomeImages, generateNpySlices, loadMatData, generateSplits

## Create training data.

In [None]:
# Enumerate every [patient, day] volume index: patients 1-22, days 1-3.
# Day is the outer loop, so the list is grouped by day.
all_idxs = [
    [patient, day]
    for day in range(1, 4)
    for patient in range(1, 23)
]

# Mask classes to include in the training data.
mask_names = ['spine_mask', 'stern_mask', 'pelvi_mask']

# Split the volumes whose .mat data matches mask_names into validation and
# training sets.
val_idxs, trn_idxs = generateSplits(all_idxs, mask_names=mask_names)

# Write the training .npy slices from trn_idxs[0]
# (presumably the first training split/fold — confirm generateSplits' return layout).
generateNpySlices(trn_idxs[0], mask_names=mask_names)

## Verify training data.

In [None]:
# Create a CTMaskDataset from the generated training data.
# augment=False — presumably disables augmentation so the samples shown below are unmodified.
dataset = CTMaskDataset(augment=False)
# Fail early with a clear message; sampling from an empty dataset in the next
# cell would otherwise surface as a confusing IndexError/ValueError.
assert len(dataset) > 0, "CTMaskDataset is empty - were the training .npy slices generated?"
# Show how many samples were found.
len(dataset)

In [None]:
# Display a random sample from the training data.
assert len(dataset) > 0, "CTMaskDataset is empty - nothing to display."
# random.randint is inclusive at BOTH ends, so the last valid index is len - 1;
# randint(0, len(dataset)) could pick an out-of-range index.
ridx = randint(0, len(dataset) - 1)
verify = dataset[ridx]
# 'ct' and 'target' are tensors; squeeze drops singleton dims before converting to numpy.
ct = verify['ct'].squeeze().numpy()
target = verify['target'].squeeze().numpy()
imgs = {'ct': ct, 'target': target}
# Presumably a 1 x 2 grid (ct | target) — confirm against plotSomeImages.
plotSomeImages(imgs, 1, 2)

## Evaluate the model on selected volumes (`vol_idxs`)

In [None]:
# Set up UNet. The architecture must match the checkpoint being loaded.
# Fall back to CPU when CUDA is unavailable so the notebook still runs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# n_classes = 1 for binary segmentation, n + 1 (background + n masks) for multi-class.
net = UNet(n_channels=1, n_classes=4, bilinear=False)
# Load weights. map_location remaps tensors saved on another device (e.g. a
# GPU-trained checkpoint) onto `device`, which is required on CPU-only machines.
MODEL_PATH = 'C:/.py_workspace/mveUNet/unet2D/.runs/multiclass_testing/3class_results/model_state.pth'
net.load_state_dict(torch.load(MODEL_PATH, map_location=device))
# Move to the device and switch to eval mode so any dropout / batch-norm layers
# use inference behaviour; the module is the cell's displayed output.
net.to(device).eval()

In [None]:
# Evaluate the loaded model on a list of [patient, day] volume indices.
vol_idxs = [[2, 1], [2, 3]]
# Evaluation never needs gradients. Disabling autograd matches the prediction
# cell, avoids building computation graphs, and keeps GPU memory down.
with torch.no_grad():
    eval_results = eval_volumes(net, device, vol_idxs)
# Show eval_volumes' return value as the cell output.
eval_results

## Generate prediction volume from model.

In [None]:
# Set up UNet. The architecture must match the checkpoint being loaded.
# This repeats the evaluation-section setup so the prediction section can be
# run on its own after a kernel restart.
# Fall back to CPU when CUDA is unavailable so the notebook still runs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# n_classes = 1 for binary segmentation, n + 1 (background + n masks) for multi-class.
net = UNet(n_channels=1, n_classes=4, bilinear=False)
# Load weights. map_location remaps tensors saved on another device (e.g. a
# GPU-trained checkpoint) onto `device`, which is required on CPU-only machines.
MODEL_PATH = 'C:/.py_workspace/mveUNet/unet2D/.runs/multiclass_testing/3class_results/model_state.pth'
net.load_state_dict(torch.load(MODEL_PATH, map_location=device))
# Move to the device and switch to eval mode so any dropout / batch-norm layers
# use inference behaviour; the module is the cell's displayed output.
net.to(device).eval()

In [None]:
# Predict a full volume for one [patient, day] index.
vol_idx = [22, 3]

# Load the 'ct' and 'pt' source data for that volume (ct first, then pt).
ct_data, pt_data = (loadMatData(vol_idx, data=modality) for modality in ('ct', 'pt'))

# Inference only: no autograd graph is needed.
with torch.no_grad():
    pred_vol = predict_vol_from_vol(net, device, vol_idx)

In [None]:
# Save the source data and each predicted class channel as a compressed .mat file.
# Channel 0 of pred_vol is background; channels 1..3 are assumed to follow the
# spine / stern / pelvi order used for mask_names during training — confirm.
class_names = ['bg', 'spine', 'stern', 'pelvi']
# Guard against a model/class-list mismatch, which would otherwise silently
# drop or mislabel channels (or raise a bare IndexError).
assert pred_vol.shape[0] == len(class_names), (
    "pred_vol has %d channels but %d class names are defined"
    % (pred_vol.shape[0], len(class_names)))

pred_dict = {'ct': ct_data, 'pt': pt_data}
pred_dict.update({name: pred_vol[channel, :, :, :]
                  for channel, name in enumerate(class_names)})

savepath = "C:/.py_workspace/mveUNet/unet2D/.predictions/"
filename = "patient%d_day%d_pred.mat" % (vol_idx[0], vol_idx[1])
# savemat does not create missing directories.
os.makedirs(savepath, exist_ok=True)
# savemat comes from scipy.io, imported explicitly: `import scipy` alone does not
# guarantee the scipy.io subpackage is loaded on every SciPy version.
savemat(os.path.join(savepath, filename), pred_dict, do_compression=True)