In [20]:
"""Evaluate model on test set and report results"""

import os
import sys
from tqdm import tqdm
import torch
from monai.inferers import sliding_window_inference
from monai.data import decollate_batch
from monai.transforms import (
    AsDiscrete,
    Compose,
    EnsureType,
)
from monai.utils import set_determinism

from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import random_split
import torch
import os
import sys
import yaml
import random
import glob
import pandas as pd
from pathlib import Path
sys.path.append("/home/local/VANDERBILT/litz/github/MASILab/lobe_seg")
from dataloader import train_dataloader, val_dataloader
from models import unet256, unet512

In [27]:
def load_config(config_name, config_dir):
    with open(os.path.join(config_dir, config_name)) as file:
        config = yaml.load(file, Loader=yaml.FullLoader)
    return config

def test(config,
         config_id,
         device,
         model,
         model_path,
         test_metric,
         test_loader):

    device = torch.device(config["device"])
    model.load_state_dict(torch.load(model_path))
    model.eval()
    post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=6)])
    post_label = Compose([EnsureType(), AsDiscrete(to_onehot=6)])
    with torch.no_grad():
        for test_data in tqdm(test_loader):
            test_inputs, test_labels = (
                test_data["image"].to(device),
                test_data["label"].to(device),
            )
            roi_size = config["crop_shape"]
            sw_batch_size=4
            test_outputs = sliding_window_inference(test_inputs, roi_size, sw_batch_size, model)
            print(test_outputs.shape)
            test_outputs = [post_pred(i) for i in decollate_batch(test_outputs)]
            test_labels = [post_label(i) for i in decollate_batch(test_labels)]
            # Append metric of each class to buffer
            test_metric(y_pred=test_outputs, y=test_labels)
        # Record metrics and compute mean over test set
        test_dices = test_metric.aggregate()
        class_means = torch.mean(test_dices, dim=0)
        mean = torch.mean(test_dices)
        test_dices_df = pd.DataFrame(test_dices.detach().cpu().numpy())

    # Log best dice
    # print(f"All scores: {test_dices_df}")
    print(f"Average class scores: {class_means}")
    print(f"Average score overall: {mean}")


In [None]:
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="1"
device = torch.device("cuda:0")
print(device)

In [14]:
config_id = "0327unet512"
MODEL_DIR = os.path.join(config["model_dir"], config_id)
model_path = os.path.join(MODEL_DIR, f"{config_id}_best_model.pth")

In [28]:
CONFIG_DIR = "/home/local/VANDERBILT/litz/github/MASILab/lobe_seg/configs"
config = load_config(f"Config_{config_id}.YAML", CONFIG_DIR)
DATA_DIR = config["test_dir"]

# Set randomness
set_determinism(seed=config["random_seed"])
random.seed(config["random_seed"])

# Load data
images = sorted(glob.glob(os.path.join(DATA_DIR, "*.nii.gz")))
images = images[:2]
test_loader = val_dataloader(config, images)

# Initialize Model and test metric
if config["model"] == 'unet512':
    model = unet512(6).to(device)
else:
    model = unet256(6).to(device)
# Set metric to compute average over each class
test_metric = DiceMetric(include_background=False, reduction="none")

test(config,
     config_id,
     device,
     model,
     model_path,
     test_metric,
     test_loader)

Validation sample size: 2


  0%|                                                                                      | 0/2 [00:00<?, ?it/s]

torch.Size([1, 6, 313, 290, 283])


 50%|███████████████████████████████████████                                       | 1/2 [00:11<00:11, 11.62s/it]

torch.Size([1, 6, 351, 272, 349])


100%|██████████████████████████████████████████████████████████████████████████████| 2/2 [00:16<00:00,  8.37s/it]

Average class scores: tensor([0.9720, 0.9675, 0.9657, 0.9154, 0.9781], device='cuda:0')
Average score overall: 0.9597461819648743



