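# Playground

Evaluate the `best_model` checkpoint of a wandb run (`run_id`, project `tiser/UDA`) on the CC359 validation split and report the mean surface Dice at a 1 mm tolerance.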

In [27]:
import wandb
import torch
import seaborn_image as isns
from pathlib import Path

from uda import UNetConfig, UNet, CC359, CC359Config, HParams
from uda.utils import reshape_to_volume
from uda.metrics import surface_dice
from tqdm.notebook import tqdm

# Prefer the first GPU when CUDA is visible; otherwise evaluate on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
print(f"Using device: {device}")

# Where wandb artefacts are restored to, and where the CC359 data is read from.
files_dir = Path("/tmp/files")
data_dir = Path("/tmp/data/CC359")

# wandb run whose configs and checkpoint are evaluated in this notebook.
run_id = "3c4h822p"

Using device: cuda:0


In [28]:
# Download the run's config files and best checkpoint from wandb into files_dir.
# wandb.restore returns an open file handle (or None if the run has no such file).
# Only the on-disk copy is needed here, so close each handle right away. Fail
# loudly on a missing file instead of erroring later in from_file/from_pretrained.
run_path = f"tiser/UDA/{run_id}"
for artifact_name in ("config/cc359.yml", "config/hparams.yml", "config/unet.yml", "best_model"):
    handle = wandb.restore(artifact_name, run_path, root=files_dir, replace=True)
    if handle is None:
        raise FileNotFoundError(f"{artifact_name!r} not found in wandb run {run_path}")
    handle.close()

In [29]:
# Dataset and hyper-parameter configs restored from the wandb run.
dataset_conf = CC359Config.from_file(files_dir.joinpath("config", "cc359.yml"))
hparams = HParams.from_file(files_dir.joinpath("config", "hparams.yml"))

# Validation split. shuffle=False keeps batches in dataset order; the inference cell
# concatenates batch predictions and reshapes them back into volumes, relying on that order.
val_dataset = CC359(data_dir, dataset_conf, train=False)
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    batch_size=hparams.val_batch_size,
    shuffle=False,
)

Loading files:   0%|          | 0/20 [00:00<?, ?it/s]

In [30]:
# Rebuild the UNet from its restored architecture config and the restored `best_model` file.
unet_conf = UNetConfig.from_file(files_dir.joinpath("config", "unet.yml"))
model = UNet.from_pretrained(files_dir.joinpath("best_model"), unet_conf)

In [31]:
# Run the model over the whole validation split, keeping every batch's prediction on the CPU.
model.eval().to(device)
batch_preds, batch_targets = [], []
try:
    with torch.no_grad():
        for x, y_true in tqdm(val_loader):
            batch_preds.append(model(x.to(device)).cpu())
            batch_targets.append(y_true)
finally:
    # Move the weights back off the accelerator even if inference fails or is interrupted.
    model.cpu()

if not batch_preds:
    raise ValueError("val_loader yielded no batches; nothing to evaluate")

# val_loader is not shuffled, so the concatenated batches are in dataset order and can be
# stitched back into full volumes.
# NOTE(review): `.round()` binarises at 0.5, which assumes the model outputs probabilities
# (e.g. a final sigmoid) rather than logits — confirm in UNet.
preds = reshape_to_volume(torch.cat(batch_preds).round(), val_dataset.PADDING_SHAPE, val_dataset.patch_dims)
targets = reshape_to_volume(torch.cat(batch_targets), val_dataset.PADDING_SHAPE, val_dataset.patch_dims)

preds.shape

  0%|          | 0/160 [00:00<?, ?it/s]

torch.Size([20, 192, 256, 256])

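3D-UNet — mean surface Dice (1 mm tolerance). Note: this cell's execution count (`In [26]`) is lower than the cells above, and the identical cell further below gives a different number. Its output therefore comes from an earlier kernel state (a different model/run), not from the `run_id` currently set above.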

In [26]:
# Surface Dice at a 1 mm boundary tolerance, presumably one score per volume
# (the progress bar counts 20 items, matching the 20 reassembled volumes); show the mean.
sf_dice = surface_dice(
    preds,
    targets,
    val_dataset.spacing_mm,
    tolerance_mm=1,
    prog_bar=True,
)

torch.mean(sf_dice)

Computing surface dice:   0%|          | 0/20 [00:00<?, ?it/s]

tensor(0.9043)

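2D-UNet — mean surface Dice (1 mm tolerance) for the run selected by `run_id` above, computed from the predictions of the preceding inference cell.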

In [32]:
# Mean surface Dice of the current predictions, with a 1 mm tolerance on the boundary distance.
sf_dice = surface_dice(
    preds, targets, val_dataset.spacing_mm,
    tolerance_mm=1, prog_bar=True,
)

sf_dice.mean()

Computing surface dice:   0%|          | 0/20 [00:00<?, ?it/s]

tensor(0.9370)