In [23]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import random

import json
import torch
import wandb
import sys
from torch.utils.data import DataLoader, Subset, RandomSampler
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch.nn as nn
from torchmetrics.classification import BinaryF1Score
from torchmetrics import Dice
from safetensors.torch import load_model
import segmentation_models_pytorch as smp


# Add the parent directory to sys.path
sys.path.append(os.path.dirname(os.path.realpath(os.path.abspath(""))))

from unet.train_dataset import DeadwoodDataset
from unet.unet_model import UNet
import matplotlib.pyplot as plt
import seaborn as sns

In [24]:
api = wandb.Api()
run_id = "vrluycuo"
experiment = api.run(f"deadtees/standing-deadwood-unet-pro/{run_id}")

In [25]:
val_samples = 100
epoch = 199
fold = 0

In [8]:
experiment.config

{'gamma': 2,
 'momentum': 0.999,
 'loss': 'tverskyfocal',
 'images_dir': '/net/scratch/jmoehring/tiles',
 'gradient_clipping': 1,
 'experiments_dir': '/net/scratch/cmosig/experiment_dir_deadwood_segmentation',
 'epochs': 200,
 'epoch_val_samples': 0,
 'lr_patience': 5,
 'register_file': '/net/scratch/jmoehring/tiles/register_new.csv',
 'gradient_accumulation': 1,
 'no_folds': 3,
 'batch_size': 12,
 'learning_rate': 1e-05,
 'experiment_name': 'segformer_b1',
 'epoch_train_samples': 11000,
 'test_size': 0.2,
 'balancing_factor': 1,
 'run_fold': -1,
 'val_every': 15,
 'num_workers': 8,
 'amp': True,
 'bce_weight': 0.9,
 'beta': 0.9,
 'weight_decay': 0.0001,
 'alpha': 0.1,
 'encoder_name': 'mit_b1',
 'save_checkpoint': 'true',
 'bins': [0, 0.02, 0.04, 0.06, 0.08, 0.1, 0.12, 0.14, 0.16, 0.18, 0.2],
 'encoder_weights': 'imagenet',
 'random_seed': 11,
 'use_wandb': 'true',
 'pos_weight': 12}

In [9]:
random.seed(experiment.config["random_seed"])
np.random.seed(experiment.config["random_seed"])
torch.manual_seed(experiment.config["random_seed"])
torch.cuda.manual_seed_all(experiment.config["random_seed"])

In [11]:
register_df = pd.read_csv("/net/scratch/jmoehring/tiles/register_new.csv")
indexes = register_df[register_df["biome"] == 12].index

In [12]:
dataset = DeadwoodDataset(
    register_df=register_df,
    images_dir="/net/scratch/jmoehring/tiles",
    no_folds=experiment.config["no_folds"],
    random_seed=experiment.config["random_seed"],
    test_size=experiment.config["test_size"],
    register_indexes=None,
    verbose=True,
)

In [13]:
loader_args = {
    "batch_size": 1,
    "num_workers": 1,
    "pin_memory": True,
    "shuffle": False,
}
g = torch.Generator()
g.manual_seed(experiment.config["random_seed"])
val_set = dataset.get_test_set()
val_loader = DataLoader(val_set, generator=g, **loader_args)

# only sample a subset of the validation set
if val_samples > 0:
    loader_args["shuffle"] = False
    sampler = RandomSampler(val_set, num_samples=val_samples, generator=g)
    val_loader = DataLoader(val_set, sampler=sampler, generator=g, **loader_args)

In [17]:
experiment_name = experiment.name.replace(f"_fold{fold}", "")
model_path = f"/net/scratch/cmosig/experiment_dir_deadwood_segmentation/{experiment_name}/fold_{fold}_epoch_{epoch}/model.safetensors"
print(model_path)
import torch
# preferably use GPU
device = torch.device("cuda")
# model with three input channels (RGB)
if experiment.config["encoder_name"] != "unet":
    model = smp.Unet(
        encoder_name=experiment.config["encoder_name"],
        encoder_weights=experiment.config["encoder_weights"],
        in_channels=3,
        classes=1,
    ).to(memory_format=torch.channels_last)
else:
    model = UNet(
        n_channels=3,
        n_classes=1,
    ).to(memory_format=torch.channels_last)
load_model(torch.compile(model), model_path)
model = model.to(memory_format=torch.channels_last, device=device)

model.eval()

/net/scratch/cmosig/experiment_dir_deadwood_segmentation/segformer_b1/fold_0_epoch_199/model.safetensors


Unet(
  (encoder): MixVisionTransformerEncoder(
    (patch_embed1): OverlapPatchEmbed(
      (proj): Conv2d(3, 64, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
      (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed2): OverlapPatchEmbed(
      (proj): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed3): OverlapPatchEmbed(
      (proj): Conv2d(128, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed4): OverlapPatchEmbed(
      (proj): Conv2d(320, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (block1): ModuleList(
      (0): Block(
        (norm1): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (q): Linear(in_features=64

In [18]:
run_name = f"{experiment.name}_fold_{fold}_epoch_{epoch}_eval"
wandb.init(project="standing-deadwood-unet-pro", name=run_name, resume=False)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mcmosig[0m ([33mdeadtees[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [19]:
text_table = wandb.Table(columns=["step", "base_file_name", "register_index"])

In [20]:
def get_log_image(input):
    # Ensure the input tensor is of the correct dtype (if it's not already)
    if input.dtype != torch.uint8:
        input = input.byte()

    # Create an empty tensor to hold the RGB output (3 channels for RGB)
    # The shape will be [height, width, 3] where 3 is for RGB channels
    rgb_image = torch.zeros((input.shape[0], input.shape[1], 3), dtype=torch.uint8)

    # Set the RGB values for pixels with value 1 (Yellow: [255, 255, 0])
    rgb_image[input == 1] = torch.tensor([255, 255, 0], dtype=torch.uint8)

    # Set the RGB values for pixels with value 255 (Red: [255, 0, 0])
    rgb_image[input == 255] = torch.tensor([255, 0, 0], dtype=torch.uint8)

    # transform to [3, height, width]
    rgb_image = rgb_image.permute(2, 0, 1)

    return rgb_image

In [21]:
step = 0

for batch, (images, true_masks, _, indexes) in tqdm(
    enumerate(val_loader), total=len(val_loader)
):
    images = images.to(memory_format=torch.channels_last, device=device)
    true_masks = true_masks.to(device=device)

    with torch.no_grad():
        pred_masks = model(images)
        pred_masks = torch.sigmoid(pred_masks)
        # extend the dataframe by all the results of the batch
        for i in range(len(images)):
            if step < val_samples:
                register_index = indexes[i].item()
                image = images[i].float().cpu()
                ground_truth = get_log_image(true_masks[i].float().cpu().squeeze())
                prediction = get_log_image(
                    (pred_masks[i] > 0.5).float().cpu().squeeze()
                )
                wandb.log(
                    {
                        "image": wandb.Image(image),
                        "true_mask": wandb.Image(ground_truth),
                        "pred_mask": wandb.Image(prediction),
                        "step": step,
                    },
                )
                text_table.add_data(
                    step,
                    register_df.iloc[register_index]["base_file_name"],
                    register_index,
                )
                step += 1
            else:
                break

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:51<00:00,  1.94it/s]


In [22]:
wandb.log({"index_table": text_table})