In [None]:
import glob
import json
import os
from itertools import islice
from os.path import join, basename

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchio as tio
import yaml
from addict import Dict
from torch.distributed import destroy_process_group, init_process_group
from torch.utils.data import DataLoader, DistributedSampler

from model import nnUnet
from dataset import TotalSegmentatorData
from metrics import DiceScore
from utils import OneHot

In [None]:
def SetupDDP(self, rank, world_size, backend="nccl",
             master_addr="localhost", master_port="12355"):
    """Initialise the default ``torch.distributed`` process group for this process.

    Args:
        self: Unused; kept so existing call sites that pass it keep working.
        rank: Unique identifier of each process.
        world_size: Total number of processes.
        backend: Process-group backend. Defaults to ``"nccl"`` (GPU); pass
            ``"gloo"`` when running on CPU.
        master_addr: Rendezvous host, written to ``MASTER_ADDR``.
        master_port: Rendezvous port, written to ``MASTER_PORT``.
    """
    # The default env:// rendezvous used by init_process_group reads these variables.
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)
    init_process_group(backend=backend, rank=rank, world_size=world_size)

def ShutdownDDP(self):
    """Tear down the default process group created by ``SetupDDP``.

    Args:
        self: Unused; kept so existing call sites that pass it keep working.

    Does nothing when no process group is initialised, so cleanup code can
    call it unconditionally instead of hitting an error from
    ``destroy_process_group``.
    """
    if torch.distributed.is_initialized():
        destroy_process_group()

In [None]:
# Run configuration. Both roots can be overridden with environment variables so the
# notebook runs outside the author's machine; the defaults reproduce the original paths.
DATA_ROOT = os.environ.get("TOTALSEG_DATA_ROOT", "/home/isaiah/TotalSegmentator")
PROJECT_ROOT = os.environ.get("TOTALSEG_PROJECT_ROOT", "/home/isaiah/TotalSegmentatorProj")

# NOTE(review): the config comes from the 20230330 run but the checkpoint from
# 20230329 — confirm this pairing is intended.
cfile = join(DATA_ROOT, "results", "20230330", "config_01.yaml")
clsfile = join(PROJECT_ROOT, "metadata", "classes.json")
testsetdir = join(DATA_ROOT, "preprocessed2", "test", "")  # trailing "" keeps the final "/"
ckptfile = join(DATA_ROOT, "results", "20230329", "nnunet_ckpt_01.pth")
use_ddp = False
device = torch.device("cpu")

In [None]:
# Load the experiment config and class metadata; `with` closes both files.
# NOTE(review): yaml.Loader can construct arbitrary Python objects. It is kept in case
# the config relies on that; switch to yaml.safe_load if the file is plain YAML.
with open(cfile, "r") as f:
    cfgs = Dict(yaml.load(f, Loader=yaml.Loader))
with open(clsfile, "r") as f:
    label_dict = Dict(json.load(f))

# Test case IDs are the file names up to the first "." (e.g. "<id>.npz" -> "<id>").
testfiles = sorted(glob.glob(join(testsetdir, "*.npz")))
assert testfiles, f"no .npz files found in {testsetdir}"
testset = [basename(path).split(".")[0] for path in testfiles]

In [None]:
# Test case IDs found on disk (bare expression -> Jupyter rich display).
testset

In [None]:
# Load the trained checkpoint onto the configured device.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
ckpt = torch.load(f=ckptfile, map_location=device)
modelparams = ckpt["model"]  # weights later handed to net.load_state_dict

In [None]:
# Materialise the checkpoint tensors in state_dict order for the sanity check below.
all_params = list(modelparams.values())

In [None]:
# Weight sanity check: one "max min" row per checkpoint tensor.
print("max min")
for param in all_params:
    hi, lo = param.max().item(), param.min().item()
    print(hi, lo)

In [None]:
# Build the network from the run config and load the trained weights.
# The evaluation loop sends its inputs to `device`, so the model must live there too;
# otherwise switching `device` away from CPU breaks inference.
net = nnUnet(cfgs.model_params.channels, cfgs.model_params.num_classes)
net.to(device)
net.load_state_dict(modelparams)

In [None]:
# Test-split dataset. batch_size=1 / shuffle=False are DataLoader's defaults, spelled
# out because the evaluation loop handles exactly one case per batch.
data = TotalSegmentatorData(device, testsetdir, cfgs.test_dataset_params)
testloader = DataLoader(dataset=data, batch_size=1, shuffle=False)

In [None]:
# Patch-wise Dice evaluation on the test split.
#
# The MAX_* limits replace the debugging `break`s that stopped after the first patch of
# the first patient; set either to None to evaluate everything.
PATCH_SIZE = 128
PATCH_OVERLAP = 64
IGNORE_LABEL = 105  # voxels with this label are masked out of the Dice score
NUM_CLASSES = 105   # one-hot depth for OneHot; NOTE(review): compare with cfgs.model_params.num_classes
MAX_PATIENTS = 1
MAX_PATCHES_PER_PATIENT = 1


def score_patches(net, image, label, device, patch_size=PATCH_SIZE,
                  patch_overlap=PATCH_OVERLAP, ignore_label=IGNORE_LABEL,
                  num_classes=NUM_CLASSES, max_patches=None):
    """Run `net` over a grid of patches of one subject and score each patch.

    Args:
        net: Segmentation model, called on a batch of one patch.
        image: Image tensor for one subject, in a layout tio.ScalarImage accepts
            (assumed (C, W, H, D) — TODO confirm against the dataset).
        label: Label tensor for the same subject, same layout as `image`.
        device: Device each patch is moved to before inference.
        patch_size: Forwarded to tio.GridSampler.
        patch_overlap: Forwarded to tio.GridSampler.
        ignore_label: Label value excluded from scoring (folded into class 0).
        num_classes: Number of classes of the one-hot target.
        max_patches: Stop after this many patches; None scores every patch.

    Returns:
        dict mapping each patch location (tuple of ints) to
        ``DiceScore(...).squeeze().tolist()`` for that patch.
    """
    subject = tio.Subject({
        "image": tio.ScalarImage(tensor=image),
        "seg": tio.LabelMap(tensor=label),
    })
    sampler = tio.GridSampler(subject, patch_size=patch_size,
                              patch_overlap=patch_overlap, padding_mode=None)
    scores = {}
    for patch in islice(sampler, max_patches):
        inp = patch["image"].data.to(device).unsqueeze(0)
        target = patch["seg"].data.to(device).unsqueeze(0)
        # Location tensor -> tuple so it can serve as a dict key.
        location = tuple(patch[tio.LOCATION].tolist())
        out = net(inp)
        mask = target != ignore_label
        # Fold the ignore label into class 0 without mutating the patch tensor in place.
        target = target.masked_fill(~mask, 0)
        dice = DiceScore(F.softmax(out, dim=1), OneHot(target, num_classes), mask)
        scores[location] = dice.squeeze().tolist()
    return scores


all_dice_scores = Dict()
net.eval()
with torch.no_grad():
    for pat, _loc, im, gt in islice(testloader, MAX_PATIENTS):
        # NOTE(review): squeeze(1) assumes the loader yields (1, 1, ...) tensors, and
        # pat.item() assumes a one-element tensor id — confirm against TotalSegmentatorData.
        all_dice_scores[pat.item()] = score_patches(
            net, im.squeeze(1), gt.squeeze(1), device,
            max_patches=MAX_PATCHES_PER_PATIENT,
        )

In [None]:
# Dice scores keyed by patient, then by patch (bare expression -> Jupyter rich display).
all_dice_scores