In [None]:
%load_ext tensorboard
%load_ext autoreload
%autoreload 2
%cd ..

In [None]:
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, random_split
import pytorch_lightning as pl
from torchsummary import summary
from tqdm import tqdm
import matplotlib.pyplot as plt

from torchvision import transforms

from src.data.utils import simple_collate_fn
from src.data import NSCLCDataset
from src.models.unet import UNet
from src.visualization import plot_batch, plot_true_vs_pred
from src.preprocess import DEFAULT_TRANSFORM

In [None]:
# Load the trained UNet and put it in inference mode on the target device.
# (`UNet` is imported in the imports cell; no need to re-import it here.)
device = "cuda:1"  # NOTE(review): hard-coded second GPU; adjust ("cuda:0"/"cpu") per machine.

# NOTE(review): `load_from_checkpoint` expects a checkpoint *file* (e.g. `*.ckpt`);
# "models" looks like a directory — confirm this points at the intended checkpoint.
CHECKPOINT_PATH = "models"

# map_location loads the weights straight onto `device` instead of the device
# they were saved from.
net = UNet.load_from_checkpoint(CHECKPOINT_PATH, map_location=device)
net.to(device).eval();

In [None]:
# Evaluate on the held-out test split.
# NOTE(review): `get_common_ids` is not imported anywhere in this notebook, so this
# cell raises NameError on a fresh kernel — add its import to the imports cell
# (TODO: confirm which `src` module defines it).
# NOTE(review): the split below takes the scans after the first 70% + 20% of
# `ct_ids` in the order `get_common_ids` returns them; this only yields a clean
# test set if training used the same ordering and ratios — confirm.
RAW_CT_DIR = "data/raw/NSCLC-Radiomics/"
GROUND_TRUTH_DIR = "data/processed/NSCLC_ground_truths/"
METADATA_PATH = "data/processed/NSCLC-Radiomics_metadata_v2.csv"

ct_ids = get_common_ids(RAW_CT_DIR, GROUND_TRUTH_DIR)
train_ratio = 0.7
val_ratio = 0.2
num_train_scans = int(len(ct_ids) * train_ratio)
num_val_scans = int(len(ct_ids) * val_ratio)
test_scans = ct_ids[num_train_scans + num_val_scans:]
assert len(test_scans) > 0, "test split is empty; check ct_ids and the split ratios"

test_ds = NSCLCDataset(metadata_path=METADATA_PATH, ct_ids=test_scans)
test_loader = DataLoader(
    test_ds,
    batch_size=8,
    collate_fn=simple_collate_fn,
    num_workers=4,
    shuffle=False,
    pin_memory=True,
)

trainer = pl.Trainer()
# Pass the loader positionally: the keyword is `test_dataloaders` in older
# PyTorch Lightning releases and `dataloaders` in newer ones (old name removed).
# NOTE: Lightning may move `net` to the Trainer's device (CPU by default) and
# change its train/eval mode; later cells should re-apply `.to(device).eval()`.
trainer.test(net, test_loader)

In [None]:
# Visual sanity check: ground-truth vs. predicted masks for one test batch.
# Re-apply device and eval mode first: `trainer.test` in the previous cell may have
# moved `net` to the Trainer's device (CPU by default) and left it in train mode,
# which would cause a device mismatch or train-mode BatchNorm/Dropout here.
net.to(device).eval()

X_test, y_test = next(iter(test_loader))
with torch.no_grad():
    preds = net(X_test.to(device))
    # assumes preds are (batch, classes, H, W) scores — argmax over dim=1 gives
    # per-pixel class labels; TODO confirm against the UNet output shape
    pred_masks = torch.argmax(preds, dim=1)
    pred_masks = pred_masks.cpu()
plot_true_vs_pred(X_test, y_test, pred_masks,
                  mask_alpha=0.2)