In [None]:
import os
from os.path import join, basename
import glob
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchio as tio
from torch.utils.data import DataLoader
import yaml
import json
from addict import Dict
from model import nnUnet
from dataset import TotalSegmentatorData
from metrics import DiceScore
from utils import OneHot, RunModelOnPatches
import matplotlib.pyplot as plt

In [None]:
# Experiment config to evaluate. Override with the TS_CONFIG environment variable
# so the notebook is not tied to one machine's absolute path.
cfile = os.environ.get(
    "TS_CONFIG",
    "/home/isaiah/TotalSegmentator/results/20230330/config_01.yaml",
)
# Use the first GPU when available, otherwise fall back to CPU so the notebook
# still runs (slowly) on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Load the experiment config and label map, then list the test subjects.
# NOTE(review): yaml.Loader can construct arbitrary Python objects; only load
# trusted config files (yaml.safe_load suffices if the config is plain data).
with open(cfile, "r") as f:
    cfgs = Dict(yaml.load(f, Loader=yaml.Loader))
with open(cfgs.paths.labels_src, "r") as f:
    label_dict = Dict(json.load(f))
test_dir = join(cfgs.paths.data_dest, "test")
testfiles = sorted(glob.glob(join(test_dir, "*.npz")))
if not testfiles:
    # Fail fast instead of silently "evaluating" an empty test set.
    raise FileNotFoundError(f"No .npz test files found in {test_dir}")
# Subject IDs are the file names up to the first "." (e.g. "abc.npz" -> "abc").
testset = [basename(path).split(".")[0] for path in testfiles]

In [None]:
# Show which test subject IDs were discovered.
print(str(testset))

In [None]:
# Load the trained checkpoint onto the evaluation device and pull out the
# model weights stored under the "model" key.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
ckpt = torch.load(f=cfgs.paths.model_ckpts_dest, map_location=device)
modelparams = ckpt["model"]

In [None]:
# Flatten the state dict into a list of tensors, in the state dict's key order.
all_params = list(modelparams.values())

In [None]:
# Print the max and min of every tensor in the checkpoint (state-dict order),
# useful as a quick sanity check on the loaded weights.
print("max min")
for param in all_params:
    if param.numel() == 0:
        # .max()/.min() raise on empty tensors; report them instead of aborting.
        print("(empty)")
        continue
    print(param.max().item(), param.min().item())

In [None]:
# Build the network from the config, move it to the evaluation device, and
# restore the checkpoint weights.
net = nnUnet(
    cfgs.model_params.channels,
    cfgs.model_params.num_classes,
)
net = net.to(device)
net.load_state_dict(modelparams)

In [None]:
# Test-split dataset and a loader over it (one subject per batch, fixed order).
data = TotalSegmentatorData(
    device,
    join(cfgs.paths.data_dest, "test/"),
    cfgs.test_dataset_params,
)
testloader = DataLoader(data, batch_size=1, shuffle=False)

In [None]:
# Run patch-based inference over the test set, collecting per-subject patient
# IDs, ground-truth label names, predicted label names and Dice scores.
# Voxels labelled IGNORE_LABEL are excluded from Dice (masked out, then zeroed).
# NOTE(review): NUM_CLASSES is presumably the model's class count — confirm it
# matches cfgs.model_params.num_classes.
NUM_CLASSES = 105
IGNORE_LABEL = 105

results = Dict()
pred_labels = []
gt_labels = []
dice_scores = []
patient_ids = []
net.eval()
with torch.no_grad():
    for pat, loc, im, gt in testloader:
        # NOTE(review): with the DataLoader's batching, `pat` is the collated
        # batch value, not necessarily a bare ID — confirm before using it as a key.
        patient_ids.append(pat)
        # Label names present in the ground truth. This runs before IGNORE_LABEL
        # is zeroed out, so it is looked up too. label_dict is an addict.Dict,
        # so a missing key yields an empty Dict rather than a KeyError.
        lbl_indices = torch.unique(gt.to(torch.int64)).tolist()
        gt_labels.append([label_dict[str(i)] for i in lbl_indices])
        # Inference is patch-based only; a separate full-volume forward pass
        # would be discarded and may exhaust GPU memory on large scans.
        # NOTE(review): the two 128s are presumably patch size/stride — confirm
        # against utils.RunModelOnPatches.
        logits = RunModelOnPatches(net, im, NUM_CLASSES, 128, 128, device)
        mask = ~torch.eq(gt, IGNORE_LABEL)
        gt[gt == IGNORE_LABEL] = 0
        gt_oh = OneHot(gt, NUM_CLASSES)
        preds = logits.argmax(1)
        dice_scores.append(DiceScore(preds, gt_oh, mask))
        lbl_indices = torch.unique(preds).tolist()
        pred_labels.append([label_dict[str(i)] for i in lbl_indices])