In [None]:
import os
from os.path import join, basename
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchio as tio
from torch.utils.data import DataLoader, DistributedSampler
import yaml
import json
from addict import Dict
from model import nnUnet
from dataset import TotalSegmentatorData
from metrics import DiceScore
from utils import OneHot, RunModelOnPatches
import matplotlib.pyplot as plt

In [None]:
def SetupDDP(self, rank, world_size, master_addr="localhost", master_port="12355", backend="nccl"):
    """Initialise the default torch.distributed process group for single-node DDP.

    ``self`` is not used; it is kept so existing call sites that pass a first
    positional argument keep working.

    Args:
        rank: Unique identifier of each process
        world_size: Total number of processes
        master_addr: Rendezvous host address written to MASTER_ADDR (default "localhost").
        master_port: Rendezvous port written to MASTER_PORT (default "12355").
        backend: torch.distributed backend (default "nccl"; "gloo" works on CPU).
    """
    # The notebook's import cell does not import this name, so bring it in here.
    from torch.distributed import init_process_group

    os.environ["MASTER_ADDR"] = master_addr
    # Environment variables must be strings; accept an int port for convenience.
    os.environ["MASTER_PORT"] = str(master_port)
    init_process_group(backend=backend, rank=rank, world_size=world_size)

def ShutdownDDP(self):
    """Tear down the default torch.distributed process group.

    ``self`` is not used; it is kept only for call-site compatibility.
    """
    # The notebook's import cell does not import this name, so bring it in here.
    from torch.distributed import destroy_process_group

    destroy_process_group()

In [None]:
# Paths and runtime options for this evaluation run.
# NOTE(review): absolute, user-specific paths — adjust for your machine.
RUN_DIR = "/home/isaiah/TotalSegmentator/results/20230330"
cfile = join(RUN_DIR, "config_01.yaml")
clsfile = "/home/isaiah/TotalSegmentatorProj/metadata/classes.json"
testsetdir = "/home/isaiah/TotalSegmentator/preprocessed2/val/"
ckptfile = join(RUN_DIR, "nnunet_ckpt_01.pth")
use_ddp = False  # NOTE(review): not read anywhere in this notebook
# Fall back to CPU so the notebook can still run on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Run configuration. NOTE(review): yaml.Loader can construct arbitrary Python
# objects; prefer yaml.safe_load unless the config relies on Python-specific tags.
with open(cfile, "r") as f:
    cfgs = Dict(yaml.load(f, Loader=yaml.Loader))
# Label lookup, indexed by the label value converted to a string.
with open(clsfile, "r") as f:
    label_dict = Dict(json.load(f))
# One .npz file per test case; the case id is the file name up to its first ".".
testfiles = sorted(glob.glob(join(testsetdir, "*.npz")))
testset = [basename(path).split(".")[0] for path in testfiles]

In [None]:
# Case ids discovered in the test-set directory (sorted file stems).
print(f"{testset}")

In [None]:
# Deserialize the checkpoint with its tensors remapped onto `device`, then keep
# only the entry stored under the "model" key (passed to load_state_dict below).
ckpt = torch.load(f=ckptfile, map_location=device)
modelparams = ckpt["model"]

In [None]:
# Materialize the checkpoint's values (in mapping order) as a plain list.
all_params = list(modelparams.values())

In [None]:
# Sanity check on the checkpoint: one "max min" line per entry of all_params.
# Zero-element tensors are reported explicitly because .max() raises on them.
print("max min")
for param in all_params:
    if param.numel() == 0:
        print("empty tensor")
        continue
    print(param.max().item(), param.min().item())

In [None]:
# Build the network from the run configuration, move it to the target device,
# then load the checkpoint weights (the returned key report is the cell output).
net = nnUnet(cfgs.model_params.channels, cfgs.model_params.num_classes)
net = net.to(device)
net.load_state_dict(modelparams)

In [None]:
# Test dataset and a loader yielding one case per batch in dataset order
# (batch_size=1 / no shuffling are DataLoader's defaults, spelled out here).
data = TotalSegmentatorData(device, testsetdir, cfgs.test_dataset_params)
testloader = DataLoader(data, batch_size=1, shuffle=False)

In [None]:
# Debug inference: run the network on the first test case that passes the
# `pat` filter and keep pat/loc/im/gt/logits/preds/labels for the cells below.
all_dice_scores = Dict()  # NOTE(review): not filled anywhere in this notebook yet
net.eval()
with torch.no_grad():
    for pat, loc, im, gt in testloader:
        print(pat)
        print(im.shape)
        # NOTE(review): assumes `pat` collates to a one-element numeric tensor;
        # a string case id would make this comparison raise.
        if pat < 1:
            continue
        # Label names present in the ground truth.
        lbl_indices = torch.unique(gt.to(torch.int64)).tolist()
        labels = [label_dict[str(i)] for i in lbl_indices]
        # Whole-volume forward pass. Earlier experiments used patch-wise inference
        # instead (RunModelOnPatches(net, im, 105, 128, 128, device)) and prepared
        # gt for Dice by zeroing label 105 and one-hot encoding it (OneHot(gt, 105)).
        logits = net(im.to(device))
        # Assumes the class dimension is axis 1 of logits — TODO confirm.
        preds = logits.argmax(1)
        break
    else:
        # Without this, later cells would hit a NameError on a fresh kernel, or
        # silently reuse stale `preds`/`logits`/`labels` from a previous run.
        raise RuntimeError("No test case passed the `pat < 1` filter; nothing to inspect.")

In [None]:
# Label names occurring anywhere in the predicted segmentation.
lbl_indices = preds.unique().tolist()
pred_labels = [label_dict[str(idx)] for idx in lbl_indices]

In [None]:
# Data type of the network output.
print(f"{logits.dtype}")

In [None]:
# Label names found in the prediction.
print(f"{pred_labels}")

In [None]:
# Label names found in the ground truth, followed by a blank line.
print(f"{labels}\n")

In [None]:
# Grid of input-volume slices taken along the last axis. Indices are evenly
# spaced over the actual depth, so the grid covers the whole volume without
# repeating slices (when depth >= number of panels).
nrows, ncols = 4, 4
im_vol = im.cpu().squeeze()  # .cpu(): imshow cannot consume CUDA tensors
slice_idx = np.linspace(0, im_vol.shape[-1] - 1, nrows * ncols).round().astype(int).tolist()
fig, axs = plt.subplots(nrows, ncols, figsize=(12, 12))
for ax, k in zip(axs.flat, slice_idx):
    ax.imshow(im_vol[:, :, k], cmap="bone")
    ax.set_title(f"slice {k}")
plt.show()

In [None]:
# Grid of predicted-label slices (first batch element) along the last axis,
# evenly spaced over the actual depth so no slice is repeated (depth permitting).
nrows, ncols = 4, 4
pred_vol = preds.cpu()[0]
slice_idx = np.linspace(0, pred_vol.shape[-1] - 1, nrows * ncols).round().astype(int).tolist()
fig, axs = plt.subplots(nrows, ncols, figsize=(12, 12))
for ax, k in zip(axs.flat, slice_idx):
    ax.imshow(pred_vol[:, :, k], cmap="jet")
    ax.set_title(f"slice {k}")
plt.show()

In [None]:
# Grid of ground-truth label slices along the last axis, evenly spaced over the
# actual depth so no slice is repeated (depth permitting).
nrows, ncols = 4, 4
gt_vol = gt.cpu().squeeze()  # .cpu(): imshow cannot consume CUDA tensors
slice_idx = np.linspace(0, gt_vol.shape[-1] - 1, nrows * ncols).round().astype(int).tolist()
fig, axs = plt.subplots(nrows, ncols, figsize=(12, 12))
for ax, k in zip(axs.flat, slice_idx):
    ax.imshow(gt_vol[:, :, k], cmap="jet")
    ax.set_title(f"slice {k}")
plt.show()

In [None]:
# CT volume of case s0021 from the original (non-preprocessed) dataset directory.
image_file = join("/home/dataset/TotalSegmentor/Totalsegmentator_dataset", "s0021", "ct.nii.gz")

In [None]:
import nibabel as nib

# Load the CT with nibabel. The image object is bound to `ct_img` rather than
# `im` so this cell does not overwrite the `im` tensor from the inference cell,
# which the input-slice plotting cell reads.
ct_img = nib.load(image_file)
arr = ct_img.get_fdata()
print(type(arr))
print(arr.shape)

In [None]:
# Coarse control-grid spacing and half of it (printed). NOTE(review): this
# appears to mirror torchio's RandomElasticDeformation guidance, where
# grid spacing = bounds / (num_control_points - 2) and a max displacement larger
# than half the spacing may cause folding — confirm.
sh = np.asarray(arr.shape)  # NOTE(review): computed but unused; the shape below is hard-coded
image_shape = np.asarray((185, 185, 218))  # presumably a spatial shape in voxels; `sh` may be intended — confirm
voxel_spacing = np.array((1., 1., 1.))     # assumed 1 mm isotropic spacing — confirm
num_control_points = np.array((9, 9, 9))
grid_spacing = (image_shape * voxel_spacing) / (num_control_points - 2)
print(grid_spacing / 2)

In [None]:
# Grid of raw CT slices along the last axis, evenly spaced over the actual
# depth so the grid spans the whole volume without repeating slices.
nrows, ncols = 7, 4
ct_vol = arr.squeeze()
slice_idx = np.linspace(0, ct_vol.shape[-1] - 1, nrows * ncols).round().astype(int).tolist()
fig, axs = plt.subplots(nrows, ncols, figsize=(12, 20))
for ax, k in zip(axs.flat, slice_idx):
    ax.imshow(ct_vol[:, :, k], cmap="bone")
    ax.set_title(f"slice {k}")
plt.show()