In [1]:
# Expose only GPU 0 to this kernel; this runs before torch is imported in the next cell.
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [2]:
import os

import json
import torch
import wandb
import sys
from torch.utils.data import DataLoader, Subset, RandomSampler
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch.nn as nn
from torchmetrics.classification import BinaryF1Score
from torchmetrics import Dice

# Add the parent directory to sys.path
sys.path.append(os.path.dirname(os.path.realpath(os.path.abspath(""))))

from unet.train_dataset import DeadwoodDataset
from unet.unet_model import UNet
import matplotlib.pyplot as plt
import seaborn as sns

In [3]:
# W&B id of the training run whose checkpoints are evaluated in this notebook.
run_id = "8zcpefnk"
# Query the run through the public API (no new run is started here).
api = wandb.Api()
experiment = api.run(f"jmoehring/standing-deadwood-unet-pro/{run_id}")

In [4]:
# Evaluation settings:
#   val_samples - size of the random validation subset; values <= 0 disable
#                 subsampling of the validation loader.
#   epoch       - checkpoint epoch to load (presumably the last one of the
#                 60-epoch run, i.e. 0-based indexing -- confirm).
#   fold        - cross-validation fold whose checkpoint is loaded.
val_samples, epoch, fold = 100, 59, 0

In [5]:
# Hyperparameters logged with the training run. images_dir, no_folds and random_seed
# are reused below to rebuild the dataset, and experiments_dir to locate the checkpoint.
experiment.config

{'amp': True,
 'bins': [0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.12, 0.14, 0.16, 0.18, 0.2],
 'epochs': 60,
 'momentum': 0.999,
 'no_folds': 3,
 'test_size': 0,
 'use_wandb': True,
 'batch_size': 110,
 'bce_weight': 0.5,
 'images_dir': '/net/scratch/jmoehring',
 'pos_weight': 40,
 'lr_patience': 5,
 'num_workers': 32,
 'random_seed': 10,
 'weight_decay': 1e-08,
 'learning_rate': 1e-05,
 'register_file': '/net/scratch/jmoehring/tiles_register_biome_bin.csv',
 'experiments_dir': '/net/home/jmoehring/experiments',
 'save_checkpoint': True,
 'balancing_factor': 1,
 'epoch_val_samples': 0,
 'gradient_clipping': 1,
 'epoch_train_samples': 50000}

In [6]:
# NOTE(review): this register differs from the run's `register_file`
# (tiles_register_biome_bin.csv). If the fold split depends on the register,
# the validation tiles may not match those held out during training -- confirm.
register_df = pd.read_csv("/net/scratch/jmoehring/tiles_register_512.csv")
# Rebuild the dataset with the same split parameters as the training run so that
# get_train_val_fold reproduces the training-time folds (presumably deterministic
# given these arguments -- confirm in DeadwoodDataset). test_size is taken from
# the run config rather than hard-coded, so a non-zero training test split is honoured.
dataset = DeadwoodDataset(
    register_df=register_df,
    images_dir=experiment.config["images_dir"],
    no_folds=experiment.config["no_folds"],
    random_seed=experiment.config["random_seed"],
    test_size=experiment.config["test_size"],
)

In [7]:
# DataLoader settings for evaluation (independent of the training batch size).
loader_args = {
    "batch_size": 16,
    "num_workers": 12,
    "pin_memory": True,
    "shuffle": True,
}
# Use the validation split of the same fold as the loaded checkpoint. The previous
# hard-coded 0 would presumably evaluate other folds' models on data they were
# trained on.
_, val_set = dataset.get_train_val_fold(fold)

if val_samples > 0:
    # Only evaluate a random subset of the validation set. DataLoader rejects
    # `shuffle=True` together with a sampler, so the sampler does the shuffling.
    # Seeding it with the run's seed makes the drawn subset reproducible on re-runs.
    loader_args["shuffle"] = False
    sampler = RandomSampler(
        val_set,
        num_samples=val_samples,
        generator=torch.Generator().manual_seed(experiment.config["random_seed"]),
    )
    val_loader = DataLoader(val_set, sampler=sampler, **loader_args)
else:
    # val_samples <= 0: iterate over the whole validation fold.
    val_loader = DataLoader(val_set, **loader_args)

In [8]:
# Checkpoint for the chosen fold/epoch inside the training run's experiment directory.
model_path = f"{experiment.config['experiments_dir']}/{experiment.name}/fold_{fold}_epoch_{epoch}.pt"
# preferably use GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model with three input channels (RGB) and a single output channel
model = UNet(n_channels=3, n_classes=1, bilinear=True)
# Wrap in DataParallel *before* loading: presumably the checkpoint was saved from a
# DataParallel model, so its state-dict keys carry the `module.` prefix -- confirm.
model = nn.DataParallel(model)
# map_location lets a GPU-saved checkpoint load on the CPU fallback as well.
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(memory_format=torch.channels_last, device=device)

model.eval()

DataParallel(
  (module): UNet(
    (inc): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (down1): Down(
      (maxpool_conv): Sequential(
        (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (1): DoubleConv(
          (double_conv): Sequential(
            (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
            (3): Conv2d(128, 128, 

In [9]:
# Start a fresh W&B run (resume=False) that will hold the evaluation images.
run_name = f"{experiment.name}_fold_{fold}_epoch_{epoch}_eval"
wandb.init(
    project="standing-deadwood-unet-pro",
    name=run_name,
    resume=False,
    # Record which training run / checkpoint was evaluated so the eval run is traceable.
    config={
        "source_run_id": run_id,
        "fold": fold,
        "epoch": epoch,
        "val_samples": val_samples,
    },
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjmoehring[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
# Run the model over the validation loader and log, per sample, the input image,
# the ground-truth mask and the predicted probability mask to W&B (one
# `wandb.log` step per sample). val_samples <= 0 means "no limit", matching how
# the loader cell treats it; previously that case logged nothing at all.
image_count = 0

for images, true_masks, images_metas in tqdm(val_loader, total=len(val_loader)):
    if 0 < val_samples <= image_count:
        break  # logging budget exhausted; skip running the model on further batches

    images = images.to(memory_format=torch.channels_last, device=device)
    true_masks = true_masks.to(device=device)

    with torch.no_grad():
        # Sigmoid maps the single-channel raw outputs to per-pixel probabilities.
        pred_masks = torch.sigmoid(model(images))

    # Move each batch to the CPU once instead of once per sample.
    images_cpu = images.float().cpu()
    true_masks_cpu = true_masks.float().cpu()
    pred_masks_cpu = pred_masks.float().cpu()

    for i in range(len(images_cpu)):
        if 0 < val_samples <= image_count:
            break
        wandb.log(
            {
                "image": wandb.Image(images_cpu[i]),
                "true_mask": wandb.Image(true_masks_cpu[i]),
                "pred_mask": wandb.Image(pred_masks_cpu[i]),
            },
        )
        image_count += 1

100%|██████████| 7/7 [00:13<00:00,  1.92s/it]
