In [1]:
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import nibabel as nib
import torch
import torch.nn as nn
from model.unet import UNet
from skimage.morphology import remove_small_holes, remove_small_objects
from lib.viewer3D import ImageSliceViewer3D
%matplotlib inline
from skimage.transform import resize

In [2]:
def morph(x, axis):
    """Clean a boolean mask: drop small connected objects, then fill small holes.

    Written for ``np.apply_over_axes``, which calls ``func(a, axis)``. ``axis`` is
    ignored, so every call cleans the whole array at once.
    NOTE(review): with axes ``[1, 2]`` the same clean-up therefore runs twice on
    the full 3-D mask; if a per-slice 2-D clean-up was intended, apply it slice
    by slice instead.
    """
    return remove_small_holes(remove_small_objects(x))


# Volumes are read as f'{datadir}{i}{name}' for the selected subject number i.
datadir = "./data/"

# files[0]   : rNa_lr -> network input channel 0
# files[1:4] : PD, T1, T2 -> network input channels 1-3
# files[-1]  : rNa_hr -> reference the prediction is evaluated against
files = ["rNa_lr.nii", "PD.nii", "T1.nii", "T2.nii", "rNa_hr.nii"]

device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): the saved output of this cell is a load_state_dict mismatch.
# The checkpoint's conv blocks hold parameters only at indices 0 and 2, while the
# current model.unet.UNet expects parameters at 0, 1, 3 and 4, with running stats
# at 1 and 4 (likely BatchNorm). The checkpoint appears to come from a different
# UNet variant (its file name says ResUNET). Restore that architecture rather than
# passing strict=False, which would leave the missing layers at their init values.
net = UNet(n_channels=4)
# map_location lets a checkpoint saved from a GPU run load on a CPU-only machine.
net.load_state_dict(
    torch.load('./checkpoints/ResUNET_SR_L1-lr0.01Ba8.pth', map_location=device)
)
net.to(device)
net.eval()

i = 7  # subject number: selects Masks/Data{i}/ and the {datadir}{i}<name>.nii volumes

# Slices whose mean LR intensity is at or below this fraction of the brightest
# slice's mean are excluded from inference and evaluation.
SLICE_THRESHOLD = 0.4


def load_normalized(path):
    """Load a NIfTI volume as a float array with NaNs set to 0, scaled to a max of 1.

    A volume whose maximum is not positive (e.g. all zeros) is returned unscaled
    instead of being turned into NaNs by a division by zero.
    """
    vol = nib.load(path).get_fdata()
    vol[np.isnan(vol)] = 0.0
    peak = vol.max()
    if peak > 0:
        vol /= peak
    return vol


print(f'Loading masks for file: {i}')
mask_dir = f'{datadir}Masks/Data{i}'
# Voxels where the summed c1/c2/c3 PD maps are positive.
mask = sum(nib.load(f'{mask_dir}/c{k}PD.nii').get_fdata() for k in (1, 2, 3)) > 0.0
mask = np.apply_over_axes(morph, mask, [1, 2])

print(f'Working on file: {i}')
# `img` is the LR input volume (files[0]); it is not overwritten below.
img = load_normalized(f'{datadir}{i}{files[0]}')
mmag = np.mean(img, (1, 2))  # mean intensity per slice along axis 0
idx = mmag > np.max(mmag) * SLICE_THRESHOLD

# Network input of shape (n_selected_slices, 4, H, W): files[0..3] stacked as
# channels. H and W come from the data rather than a hard-coded 160 x 160.
# The inputs are not masked; only the prediction is multiplied by the mask.
bimg = np.stack(
    [img[idx]] + [load_normalized(f'{datadir}{i}{name}')[idx] for name in files[1:4]],
    axis=1,
)

img2 = load_normalized(f'{datadir}{i}{files[-1]}')
imghr = img2[idx, :, :]

timglr = torch.as_tensor(bimg.astype(np.float32)).to(device)
# Inference only: no_grad avoids building an autograd graph for the whole stack.
with torch.no_grad():
    imgsr = net(timglr).cpu().numpy().squeeze() * mask[idx]

RuntimeError: Error(s) in loading state_dict for UNet:
	Missing key(s) in state_dict: "inc.conv.conv.1.weight", "inc.conv.conv.1.bias", "inc.conv.conv.1.running_mean", "inc.conv.conv.1.running_var", "inc.conv.conv.3.weight", "inc.conv.conv.3.bias", "inc.conv.conv.4.weight", "inc.conv.conv.4.bias", "inc.conv.conv.4.running_mean", "inc.conv.conv.4.running_var", "down1.mpconv.1.conv.1.weight", "down1.mpconv.1.conv.1.bias", "down1.mpconv.1.conv.1.running_mean", "down1.mpconv.1.conv.1.running_var", "down1.mpconv.1.conv.3.weight", "down1.mpconv.1.conv.3.bias", "down1.mpconv.1.conv.4.weight", "down1.mpconv.1.conv.4.bias", "down1.mpconv.1.conv.4.running_mean", "down1.mpconv.1.conv.4.running_var", "down2.mpconv.1.conv.1.weight", "down2.mpconv.1.conv.1.bias", "down2.mpconv.1.conv.1.running_mean", "down2.mpconv.1.conv.1.running_var", "down2.mpconv.1.conv.3.weight", "down2.mpconv.1.conv.3.bias", "down2.mpconv.1.conv.4.weight", "down2.mpconv.1.conv.4.bias", "down2.mpconv.1.conv.4.running_mean", "down2.mpconv.1.conv.4.running_var", "down3.mpconv.1.conv.1.weight", "down3.mpconv.1.conv.1.bias", "down3.mpconv.1.conv.1.running_mean", "down3.mpconv.1.conv.1.running_var", "down3.mpconv.1.conv.3.weight", "down3.mpconv.1.conv.3.bias", "down3.mpconv.1.conv.4.weight", "down3.mpconv.1.conv.4.bias", "down3.mpconv.1.conv.4.running_mean", "down3.mpconv.1.conv.4.running_var", "up1.conv.conv.1.weight", "up1.conv.conv.1.bias", "up1.conv.conv.1.running_mean", "up1.conv.conv.1.running_var", "up1.conv.conv.3.weight", "up1.conv.conv.3.bias", "up1.conv.conv.4.weight", "up1.conv.conv.4.bias", "up1.conv.conv.4.running_mean", "up1.conv.conv.4.running_var", "up2.conv.conv.1.weight", "up2.conv.conv.1.bias", "up2.conv.conv.1.running_mean", "up2.conv.conv.1.running_var", "up2.conv.conv.3.weight", "up2.conv.conv.3.bias", "up2.conv.conv.4.weight", "up2.conv.conv.4.bias", "up2.conv.conv.4.running_mean", "up2.conv.conv.4.running_var", "up3.conv.conv.1.weight", "up3.conv.conv.1.bias", "up3.conv.conv.1.running_mean", "up3.conv.conv.1.running_var", 
"up3.conv.conv.3.weight", "up3.conv.conv.3.bias", "up3.conv.conv.4.weight", "up3.conv.conv.4.bias", "up3.conv.conv.4.running_mean", "up3.conv.conv.4.running_var". 
	Unexpected key(s) in state_dict: "inc.conv.conv.2.weight", "inc.conv.conv.2.bias", "down1.mpconv.1.conv.2.weight", "down1.mpconv.1.conv.2.bias", "down2.mpconv.1.conv.2.weight", "down2.mpconv.1.conv.2.bias", "down3.mpconv.1.conv.2.weight", "down3.mpconv.1.conv.2.bias", "up1.conv.conv.2.weight", "up1.conv.conv.2.bias", "up2.conv.conv.2.weight", "up2.conv.conv.2.bias", "up3.conv.conv.2.weight", "up3.conv.conv.2.bias". 

In [None]:
# Inspect the HR reference next to the network prediction in the lib.viewer3D viewer.
viewer_figsize = (12, 12)
ImageSliceViewer3D(imghr, imgsr, figsize=viewer_figsize)

In [None]:
# Mean absolute error between the prediction and the HR reference.
# NOTE(review): imgsr is multiplied by the mask but imghr is not, so any HR
# signal outside the mask counts fully towards this error.
abs_err = np.abs(imgsr - imghr)
abs_err.mean()

In [None]:
from piqa import SSIM

# SSIM between the prediction and the HR reference on a few central slices.
# Each slice becomes a one-channel image: (slices, H, W) -> (slices, 1, H, W).
ssim = SSIM(n_channels=1)
x = torch.as_tensor(np.expand_dims(imgsr.astype(np.float32), 1))
y = torch.as_tensor(np.expand_dims(imghr.astype(np.float32), 1))
N = x.shape[0]
# Up to 4 slices around the middle one. Clamp the start at 0: for N < 4,
# N // 2 - 2 is negative and would wrap to the end, selecting too few slices.
centre = slice(max(N // 2 - 2, 0), N // 2 + 2)
ssim(x[centre], y[centre])

In [None]:
from skimage.transform import rescale, resize

# Baseline for the SSIM above: the LR network input itself (channel 0 of bimg),
# resized to the spatial grid of `y`. bimg[:, 0] is already restricted to the
# selected slices `idx`, so it lines up slice-for-slice with `y`. Resizing a
# full, unselected volume would squeeze all of its slices into N and misalign them.
lr_input = bimg[:, 0]
imgscaled = resize(lr_input, (y.shape[0], *y.shape[2:]), anti_aliasing=False).astype(np.float32)
z = torch.as_tensor(np.expand_dims(imgscaled, 1))
# Central window N//2-2 : N//2+2 as in the SR evaluation, start clamped at 0.
centre = slice(max(N // 2 - 2, 0), N // 2 + 2)
ssim(z[centre], y[centre])

In [None]:

# LR network input, HR reference and resized baseline. bimg[:, 0] holds the LR
# input on the same selected slices as imghr; a full, unselected volume would
# not line up with the other two.
ImageSliceViewer3D(bimg[:, 0], imghr, imgscaled, figsize=(12, 12))