In [1]:
import os
import sys

import torch
import wandb
from torch.utils.data import DataLoader
from tqdm import tqdm

# Add the parent directory to sys.path
sys.path.append(os.path.dirname(os.path.realpath(os.path.abspath(""))))

from unet.dataset import DeadwoodDataset
from unet.dice_score import dice_coeff
from unet.evaluate import evaluate
from unet.unet_model import UNet

In [2]:
# data paths
images_dir = "/net/scratch/jmoehring/tiles/images"
masks_dir = "/net/scratch/jmoehring/tiles/masks"
checkpoint_dir = "/net/scratch/jmoehring/checkpoints"

# data params
no_folds: int = 10
fold: int = 0
val_percent: float = 0.1
random_seed: int = 42
batch_size: int = 16

In [3]:
dataset = DeadwoodDataset(
    images_dir,
    masks_dir,
    n_folds=no_folds,
    random_seed=random_seed,
)

In [4]:
# get test fold
_, test_set = dataset.get_fold(fold)

In [5]:
loader_args = {
    "batch_size": batch_size,
    "num_workers": 12,
    "pin_memory": True,
    "shuffle": True,
}
test_loader = DataLoader(test_set, **loader_args)

In [6]:
# preferably use GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model with three input channels (RGB)
model = UNet(n_channels=3, n_classes=1, bilinear=True)
model.load_state_dict(
    torch.load(
        os.path.join(checkpoint_dir, "deadwood_fold_0_1704409757.066416_epoch5.pth")
    )
)
model = model.to(memory_format=torch.channels_last, device=device)
model.eval()

UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
 

In [7]:
run = "deadwood_fold_0_1704409757.066416_test3"
wandb.init(project="standing-deadwood-unet", name=run, resume=True)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjmoehring[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
image_count = 0
dice_score = 0
for batch, (images, true_masks) in tqdm(enumerate(test_loader), total=len(test_loader)):
    images = images.to(memory_format=torch.channels_last, device=device)
    true_masks = true_masks.to(memory_format=torch.channels_last, device=device)
    with torch.no_grad():
        pred_masks = model(images)
        pred_masks = torch.sigmoid(pred_masks)
        pred_masks_sig = (pred_masks > 0.5).float()
        dice_score += dice_coeff(pred_masks_sig, true_masks).item()
        for i in range(images.shape[0]):
            # only add if true masks has ones in it
            if true_masks[i].sum() > 0 and image_count < 100:
                wandb.log(
                    {
                        "image": wandb.Image(images[i].cpu()),
                        "true_mask": wandb.Image(true_masks[i].float().cpu()),
                        "pred_mask": wandb.Image(pred_masks[i].float().cpu()),
                        "pred_mask_sig": wandb.Image(pred_masks_sig[i].float().cpu()),
                    }
                )
                image_count += 1
wandb.log({"dice_score": dice_score / len(test_loader)})

3605it [12:34,  4.78it/s]
