In [2]:
import os
import numpy as np
import torch
import torch.nn.functional as F
import nibabel as nib
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

from thesisproject.models import UNet
from thesisproject.predict import Predict
from thesisproject.data import ImagePairDataset

In [3]:
class Square_pad:
    """Zero-pad a tensor so that every dimension is as long as its longest one.

    Works for tensors of any rank. For each dimension the missing length ``d``
    is split so that ``ceil(d / 2)`` zeros are inserted *before* the data and
    ``floor(d / 2)`` zeros *after* it (this is the split the 3-D volumes in this
    notebook have always received).
    """

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """Return ``image`` zero-padded to a hyper-cube.

        Args:
            image: tensor to pad; every one of its dimensions is padded.

        Returns:
            A new tensor whose every dimension equals ``max(image.shape)``,
            with the same dtype as ``image``. A 0-dimensional tensor is
            returned unchanged (there is nothing to pad).
        """
        if image.dim() == 0:
            return image

        target = max(image.shape)
        padding = []
        # F.pad expects (before, after) pairs starting from the LAST dimension,
        # so walk the shape in reverse order.
        for size in reversed(image.shape):
            missing = target - size
            padding.extend((missing - missing // 2, missing // 2))  # (ceil, floor)

        return F.pad(image, tuple(padding), "constant", 0)
    
def test_collate(image):
    return image

volume_transform = Square_pad()  # shared, attribute-free padder applied to each loaded volume

In [4]:
# Names of the segmented structures, handed to UNet as `class_names`.
# NOTE(review): "Medial femoral cartilage" is spelled out while the others use
# "cart."; left unchanged in case these strings must match names used elsewhere.
label_keys = [
    "Lateral femoral cart.",
    "Lateral meniscus",
    "Lateral tibial cart.",
    "Medial femoral cartilage",
    "Medial meniscus",
    "Medial tibial cart.",
    "Patellar cart.",
    "Tibia",
]

# NOTE(review): presumably 1 input channel and 9 output classes
# (the 8 names above plus a background class?) — confirm against UNet's signature.
net = UNet(1, 9, class_names=label_keys)

In [7]:
# Restore trained weights, mapping every tensor to the CPU.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load("../model_checkpoint.pt", map_location=torch.device('cpu'))
net.load_state_dict(checkpoint["model_state_dict"])

# Inference only: put layers such as dropout / batch-norm into evaluation mode.
# Harmless if Predict already does this internally.
net.eval()

predict = Predict(net, batch_size=8, show_progress=True)

In [19]:
# Filenames of the images to segment, expected one per line.
# ndmin=1 keeps a 1-D array even when the file holds a single entry; without it
# np.loadtxt returns a 0-d array, which breaks both the count below and the
# `sorted(image_files)` iteration later on.
# NOTE(review): loadtxt splits on whitespace and treats '#' as a comment, so
# filenames containing spaces or '#' would be mangled.
image_files = np.loadtxt("../subject_images.txt", dtype=str, ndmin=1)
print(f"Found {len(image_files)} images")

Found 1298 images


In [9]:
def filename_to_subject_info(filename: str):
    """Parse subject id, knee side and visit number from an image filename.

    Expected fixed-position layout (inferred from the slicing):
    characters 0-6 are the subject id, character 7 is skipped, the knee token
    ("Left" or "Right") starts at index 8, the next two characters are skipped
    (presumably separators — confirm), and the following two characters are the
    visit number.

    Args:
        filename: bare filename (no directory part).

    Returns:
        ``(subject_id, knee, visit)`` as ``(str, str, int)``.

    Raises:
        ValueError: if no "Left"/"Right" token starts at index 8, or the visit
            characters are not an integer.
        IndexError: if the filename is shorter than 9 characters.
    """
    subject_id = filename[:7]
    knee = "Right" if filename[8] == "R" else "Left"
    knee_end = 8 + len(knee)
    # Fail loudly instead of silently returning garbage slices for an
    # unexpected filename layout.
    if filename[8:knee_end] != knee:
        raise ValueError(
            f"Cannot parse knee side from {filename!r}: "
            "expected 'Left' or 'Right' starting at index 8"
        )
    # Two characters separate the knee token from the two-digit visit number.
    visit_start = knee_end + 2
    visit = int(filename[visit_start:visit_start + 2])
    return subject_id, knee, visit

In [18]:
# NOTE(review): machine-specific absolute path to an external drive; consider
# making it configurable.
path = "/Volumes/Expansion/NIFTY/"

# Number of (sorted) files to process. The recorded run stops after the first
# volume — the CPU progress bar below estimates over an hour for its first view
# alone. Set to None to process every file.
MAX_FILES = 1

for file_index, file in enumerate(sorted(image_files)):
    if MAX_FILES is not None and file_index >= MAX_FILES:
        break

    subject_id, knee, visit = filename_to_subject_info(file)
    print(f"subject {subject_id}, {knee} knee, visit {visit}:")

    image_obj = nib.load(os.path.join(path, file))
    # Read the voxel data once and build the padded tensor from it.
    image_data = image_obj.get_fdata()
    image_tensor = volume_transform(torch.from_numpy(image_data))

    # Inference only: no autograd graph is needed (harmless if Predict
    # already disables gradients itself).
    with torch.no_grad():
        prediction = predict(image_tensor.float())
    # TODO: persist `prediction` per file — it is overwritten on the next iteration.

subject 9008561, Left knee, visit 0:


First view:   0%|                                   | 0/1152 [00:29<?, ?slice/s]
First view:   8%|█▊                      | 88/1152 [05:12<1:02:22,  3.52s/slice]

KeyboardInterrupt: 