In [None]:
!pip install nibabel elasticdeform

In [40]:
import os
import torch
import torch.nn.functional as F
import nibabel as nib
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from tqdm.notebook import trange

from thesisproject.utils import create_animation, create_overlay_figure
from thesisproject.predict import Predict 
from thesisproject.models import UNet

plt.rcParams['figure.figsize'] = [20, 10]

In [2]:
net = UNet(1, 9)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
net.to(device)

checkpoint_path = os.path.join("model_saves", "model_checkpoint.pt")
checkpoint = torch.load(checkpoint_path)
net.load_state_dict(checkpoint['model_state_dict'])

predict = Predict(net, batch_size=8, show_progress=True)

In [3]:
scan = nib.load("/notebooks/knee_data/train/images/9003406_20041118_SAG_3D_DESS_LEFT_016610296205.nii.gz").get_fdata()
scan_tensor = torch.from_numpy(scan).float().to(device)
scan_tensor -= torch.min(scan_tensor)
scan_tensor /= torch.max(scan_tensor)

label = nib.load("/notebooks/knee_data/train/labels/9003406_20041118_SAG_3D_DESS_LEFT_016610296205.nii.gz").get_fdata()
label_tensor = torch.from_numpy(label).long()

In [4]:
prediction = predict(scan_tensor)

Third view: 100%|██████████| 928/928 [00:31<00:00, 29.61slice/s] 


In [34]:
print(np.nonzero(prediction)[2])

tensor([ 57, 300,  57])


In [None]:
frames = []
h, w, d = scan_tensor.shape

for i in trange(d):
    scan_slice = scan_tensor[:, :, i].unsqueeze(0).unsqueeze(0)
    label_slice = label_tensor[:, :, i].unsqueeze(0).unsqueeze(0)
    pred_slice = prediction[:, :, i].unsqueeze(0).unsqueeze(0)
    
    tmp_fig, tmp_axs = create_overlay_figure(scan_slice, label_slice, pred_slice, images_per_batch=1)
    tmp_fig.show()

    if i==0:
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        print(ax)
        ax.set(animated=True)
        ax.remove()
        ax.figure = fig
        fig.add_axes(ax)
        plt.close(dummy_fig)
        
    fig.tight_layout()
        
    frames.append([ax])

ani = animation.ArtistAnimation(fig, frames, interval=50, blit=True, repeat_delay=1000)

In [17]:
ani.save("pred.mp4")

In [42]:
i = 300
scan_slice = scan_tensor[:, i, :].unsqueeze(0).unsqueeze(0)
label_slice = label_tensor[:, i, :].unsqueeze(0).unsqueeze(0)
pred_slice = F.one_hot(prediction[:, i, :].unsqueeze(0).unsqueeze(0), dims=1)


tmp_fig, tmp_axs = create_overlay_figure(scan_slice, label_slice, pred_slice, images_per_batch=1)
tmp_fig.show()

TypeError: one_hot() got an unexpected keyword argument 'dims'