In [1]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import random

import json
import torch
import wandb
import sys
from torch.utils.data import DataLoader, Subset, RandomSampler
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch.nn as nn
from torchmetrics.classification import BinaryF1Score
from torchmetrics import Dice
from safetensors.torch import load_model
import segmentation_models_pytorch as smp


# Add the parent directory to sys.path
sys.path.append(os.path.dirname(os.path.realpath(os.path.abspath(""))))

from unet.train_dataset import DeadwoodDataset
from unet.unet_model import UNet
import matplotlib.pyplot as plt
import seaborn as sns

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fetch the training run from the W&B public API; its config and name are used
# by the cells below to rebuild the dataset, model and checkpoint path.
api = wandb.Api()
run_id = "a7jnzozl"
experiment = api.run(f"jmoehring/standing-deadwood-unet-pro/{run_id}")

In [3]:
# Evaluation settings.
# val_samples: size of the random test-set subset that is drawn and logged
# (values <= 0 skip the subsampling when the loader is built).
val_samples = 100
# epoch / fold: select the checkpoint directory "fold_{fold}_epoch_{epoch}".
epoch = 74
fold = 2

In [4]:
# Display the configuration stored with the run for reference.
experiment.config

{'amp': True,
 'beta': 0.7,
 'bins': [0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.12, 0.14, 0.16, 0.18, 0.2],
 'loss': 'tverskyfocal',
 'alpha': 0.3,
 'gamma': 2,
 'epochs': 100,
 'momentum': 0.999,
 'no_folds': 3,
 'run_fold': 2,
 'test_size': 0.2,
 'use_wandb': 'true',
 'val_every': 5,
 'batch_size': 3,
 'bce_weight': 0.9,
 'images_dir': '/lscratch/standing-deadwood/tiles',
 'pos_weight': 12,
 'lr_patience': 5,
 'num_workers': 8,
 'random_seed': 11,
 'encoder_name': 'unet',
 'weight_decay': 0.0001,
 'learning_rate': 1e-05,
 'register_file': '/lscratch/standing-deadwood/tiles/register.csv',
 'experiment_name': '50k_100e_vanilla_tversky_a03b07g2',
 'experiments_dir': '/lscratch/standing-deadwood',
 'save_checkpoint': 'true',
 'balancing_factor': 0.9,
 'epoch_val_samples': 0,
 'gradient_clipping': 1,
 'epoch_train_samples': 50000,
 'gradient_accumulation': 1}

In [5]:
# Seed every RNG in use (Python, NumPy, PyTorch CPU and CUDA) with the seed
# stored in the run configuration.
seed = experiment.config["random_seed"]
for seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_rng(seed)

In [6]:
# Load the tile register CSV.
register_df = pd.read_csv("/net/home/jmoehring/scratch/tiles/register.csv")
# Row labels of all register entries whose "biome" column equals 12.
# NOTE(review): this selection is not handed to the dataset (it is built with
# register_indexes=None) and the name `indexes` is rebound by the evaluation
# loop further down — confirm whether the biome filter was meant to be applied.
indexes = register_df[register_df["biome"] == 12].index

In [7]:
# Build the dataset with the fold count, seed and test split size stored in the
# run config (presumably so the split matches the one used for training —
# TODO confirm against DeadwoodDataset).
dataset = DeadwoodDataset(
    register_df=register_df,
    images_dir="/net/home/jmoehring/scratch/tiles",
    no_folds=experiment.config["no_folds"],
    random_seed=experiment.config["random_seed"],
    test_size=experiment.config["test_size"],
    register_indexes=None,  # no index restriction is passed here
    verbose=True,
)

In [8]:
# Dedicated generator, seeded from the run config, so the sampled subset is
# the same on every run of this cell.
g = torch.Generator()
g.manual_seed(experiment.config["random_seed"])

val_set = dataset.get_test_set()

# Tiles are evaluated one at a time. Ordering is decided by the sampler, so
# shuffle stays off (DataLoader rejects shuffle=True together with a sampler).
loader_args = {
    "batch_size": 1,
    "num_workers": 1,
    "pin_memory": True,
    "shuffle": False,
}

# Draw a random subset of `val_samples` items when requested; with
# sampler=None the loader walks the whole set sequentially.
sampler = (
    RandomSampler(val_set, num_samples=val_samples, generator=g)
    if val_samples > 0
    else None
)
val_loader = DataLoader(val_set, sampler=sampler, generator=g, **loader_args)

In [9]:
# Checkpoint location: the experiment directory is the run name with
# "_fold{fold}" removed.
experiment_name = experiment.name.replace(f"_fold{fold}", "")
model_path = f"/net/home/jmoehring/experiments/{experiment_name}/fold_{fold}_epoch_{epoch}/model.safetensors"

# Prefer the GPU, but fall back to the CPU when CUDA is not available instead
# of failing on the first .to(device) call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model with three input channels (RGB) and a single output class.
if experiment.config["encoder_name"] != "unet":
    model = smp.Unet(
        encoder_name=experiment.config["encoder_name"],
        # Not every run config stores this key; the checkpoint loaded below
        # replaces the weights anyway, so None is a safe default.
        encoder_weights=experiment.config.get("encoder_weights"),
        in_channels=3,
        classes=1,
    )
else:
    model = UNet(n_channels=3, n_classes=1)

# Load the trained weights, then move to the target device and memory format
# in a single step.
load_model(model, model_path)
model = model.to(memory_format=torch.channels_last, device=device)

# Switch to inference mode; as the last expression this also displays the model.
model.eval()

UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
 

In [10]:
# Start a new W&B run (resume disabled) named after the evaluated run, fold and
# epoch; the images and table logged below go to this run.
run_name = f"{experiment.name}_fold_{fold}_epoch_{epoch}_eval"
wandb.init(project="standing-deadwood-unet-pro", name=run_name, resume=False)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjmoehring[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
# Table that maps each logged step to the tile's file name and register index.
text_table = wandb.Table(columns=["step", "base_file_name", "register_index"])

In [12]:
def get_log_image(input):
    """Render a 2-D label mask as a colour-coded RGB image tensor.

    Pixels equal to 1 become yellow ``[255, 255, 0]``, pixels equal to 255
    become red ``[255, 0, 0]``, and every other pixel stays black.

    Args:
        input: Tensor indexed as ``[height, width]``. Non-``uint8`` tensors
            are converted with ``.byte()`` before the comparison.

    Returns:
        ``uint8`` tensor of shape ``[3, height, width]`` on the same device
        as ``input``.
    """
    mask = input if input.dtype == torch.uint8 else input.byte()

    # Allocate the canvas and the colours on the mask's device so the masked
    # assignments below never mix devices.
    target = mask.device
    yellow = torch.tensor([255, 255, 0], dtype=torch.uint8, device=target)
    red = torch.tensor([255, 0, 0], dtype=torch.uint8, device=target)

    # Work channel-last ([height, width, 3]) so a boolean mask selects whole pixels.
    rgb_image = torch.zeros(
        (mask.shape[0], mask.shape[1], 3), dtype=torch.uint8, device=target
    )
    rgb_image[mask == 1] = yellow
    rgb_image[mask == 255] = red

    # Channel-first layout: [3, height, width].
    return rgb_image.permute(2, 0, 1)

In [13]:
# Run the model over the sampled tiles and log image / ground truth /
# prediction triplets to W&B.
step = 0
# A non-positive val_samples means "no subsampling" when the loader is built,
# so treat it as "no cap" here as well instead of logging nothing.
log_limit = val_samples if val_samples > 0 else None

for images, true_masks, _, indexes in tqdm(val_loader, total=len(val_loader)):
    # Stop iterating (and running the model) once the cap is reached.
    if log_limit is not None and step >= log_limit:
        break

    images = images.to(memory_format=torch.channels_last, device=device)
    true_masks = true_masks.to(device=device)

    with torch.no_grad():
        # The model output is passed through a sigmoid and thresholded at 0.5 below.
        pred_masks = torch.sigmoid(model(images))

    for i in range(len(images)):
        if log_limit is not None and step >= log_limit:
            break

        register_index = indexes[i].item()
        image = images[i].float().cpu()
        ground_truth = get_log_image(true_masks[i].float().cpu().squeeze())
        prediction = get_log_image((pred_masks[i] > 0.5).float().cpu().squeeze())
        wandb.log(
            {
                "image": wandb.Image(image),
                "true_mask": wandb.Image(ground_truth),
                "pred_mask": wandb.Image(prediction),
                "step": step,
            },
        )
        # NOTE(review): .iloc is positional; this assumes the indices returned
        # by the dataset are row positions in register_df — confirm.
        text_table.add_data(
            step,
            register_df.iloc[register_index]["base_file_name"],
            register_index,
        )
        step += 1

100%|██████████| 100/100 [00:50<00:00,  1.96it/s]


In [14]:
# Upload the step -> tile lookup table to the evaluation run.
# NOTE(review): no wandb.finish() is visible in this notebook — confirm the run
# is closed elsewhere so the table is flushed.
wandb.log({"index_table": text_table})