In [None]:
import os

import json
import torch
import wandb
import sys
from torch.utils.data import DataLoader, Subset, RandomSampler
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch.nn as nn
from torchmetrics.classification import BinaryF1Score
from torchmetrics import Dice

# Make the project root (the parent of the notebook's working directory)
# importable so that the local `unet` package imported below resolves.
_notebook_dir = os.path.realpath(os.path.abspath(""))
sys.path.append(os.path.dirname(_notebook_dir))
del _notebook_dir

from unet.dataset import DeadwoodDataset
from unet.dice_score import dice_coeff, confusion_values, confusion_tensor
from unet.evaluate import evaluate
from unet.unet_model import UNet
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# W&B id of the training run whose checkpoints are evaluated below
run_id: str = "garg20i1"

In [None]:
# import importlib
# importlib.reload(sys.modules['unet.dice_score'])

In [None]:
# Look up the training run through the W&B public API; its name and config
# drive the checkpoint path and the data split below.
api = wandb.Api()
experiment = api.run(path=f"jmoehring/standing-deadwood-unet-pro/{run_id}")

In [None]:
# Display the training run's config; later cells read its "data" section.
experiment.config

In [None]:
# --- evaluation settings ---
# number of validation images to sample; 0 means the whole validation fold
no_val_samples: int = 0
# checkpoint epoch to evaluate for each fold (list index = fold number)
epoch_models = [19, 12, 20]

fold: int = 2
epoch: int = epoch_models[fold]

# --- checkpoint location ---
checkpoint_dir = f"/net/scratch/jmoehring/checkpoints/{experiment.name}"
model_checkpoint = f"fold_{fold}_epoch_{epoch}.pth"

# --- data parameters, copied from the training run's config
# (presumably so the fold split reproduces the one used in training) ---
_data_cfg = experiment.config["data"]
no_folds: int = _data_cfg["no_folds"]
random_seed: int = _data_cfg["random_seed"]
batch_size: int = _data_cfg["batch_size"]
test_size: float = _data_cfg["test_size"]
del _data_cfg

In [None]:
# Load the dataset register referenced by the training run's config.
register_df = pd.read_csv(
    filepath_or_buffer=experiment.config["data"]["register_file"]
)

In [None]:
# Rebuild the dataset with the training run's split parameters.
dataset = DeadwoodDataset(
    register_df=register_df,
    test_size=test_size,
    random_seed=random_seed,
    no_folds=no_folds,
)

In [None]:
# Validation loader for the selected fold.
# NOTE(review): batch size is hard-coded to 64 rather than taken from the run
# config's `batch_size` — confirm this is intended.
loader_args = {
    "batch_size": 64,
    "num_workers": 12,
    "pin_memory": True,
    "shuffle": True,
}
_, val_set = dataset.get_train_val_fold(fold)

# Optionally evaluate only a random subset of the validation fold.
# The sampler has to be handed to the DataLoader constructor: DataLoader raises
# ValueError if `.sampler` is reassigned after init, and `sampler` is mutually
# exclusive with shuffle=True, so shuffle is switched off first.
if no_val_samples > 0:
    loader_args["shuffle"] = False
    loader_args["sampler"] = RandomSampler(val_set, num_samples=no_val_samples)

val_loader = DataLoader(val_set, **loader_args)

In [None]:
# preferably use GPU
# Prefer the GPU; fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device_ids = [0]
# UNet with three input channels (RGB) and a single output channel
model = UNet(n_channels=3, n_classes=1, bilinear=True)
if torch.cuda.device_count() > 1:
    device_ids = [0, 1]
# Wrap in DataParallel before loading, presumably because the checkpoint was
# saved from a DataParallel-wrapped model ("module."-prefixed keys) — confirm.
model = nn.DataParallel(model, device_ids=device_ids)
# map_location lets a checkpoint saved from GPU tensors load on the CPU fallback.
model.load_state_dict(
    torch.load(os.path.join(checkpoint_dir, model_checkpoint), map_location=device)
)
model = model.to(memory_format=torch.channels_last, device=device)

model.eval()

In [None]:
# Open a separate W&B run for this evaluation, named after the evaluated checkpoint.
run_name = f"{experiment.name}_fold_{fold}_epoch_{epoch}_eval"
wandb.init(
    name=run_name,
    project="standing-deadwood-unet-pro",
    resume=False,
)

In [None]:
# Per-image evaluation: one row (biome, resolution bin, precision, recall, F1)
# for every image in the validation loader.
image_count = 0  # only used by the disabled wandb image logging below

eval_records = []
with torch.no_grad():
    for images, true_masks, images_metas in tqdm(val_loader, total=len(val_loader)):
        images = images.to(memory_format=torch.channels_last, device=device)
        true_masks = true_masks.to(device=device)

        pred_masks = model(images)
        pred_masks = torch.sigmoid(pred_masks)
        pred_masks_sig = (pred_masks > 0.5).float()
        # Squeeze only the channel dim: a bare .squeeze() would also drop the
        # batch dim when the final batch holds a single image, breaking the
        # per-image indexing below.
        precision, recall, f1 = confusion_values(
            pred_masks_sig.squeeze(1), true_masks.squeeze(1)
        )
        # Collect plain Python scalars; .item() works for CPU and CUDA tensors
        # alike, whereas .numpy() fails on CUDA tensors.
        for i in range(len(images)):
            eval_records.append(
                {
                    "biome": images_metas["biome"][i].item(),
                    "resolution_bin": images_metas["resolution_bin"][i].item(),
                    "precision": precision[i].item(),
                    "recall": recall[i].item(),
                    "f1": f1[i].item(),
                }
            )
            # if image_count < 50 and true_masks[i].sum() > 0:
            #     merged_confusion_tensor = confusion_tensor(
            #         pred_masks_sig[i].squeeze(), true_masks[i].squeeze()
            #     )
            #     wandb.log(
            #         {
            #             "segmentation": wandb.Image(
            #                 images[i].float().cpu(),
            #                 masks={
            #                     "confusion": {
            #                         "mask_data": merged_confusion_tensor.float()
            #                         .cpu()
            #                         .squeeze()
            #                         .numpy(),
            #                         "class_labels": {
            #                             1: "true_positive",
            #                             2: "false_positive",
            #                             3: "false_negative",
            #                         },
            #                     },
            #                 },
            #             )
            #         }
            #     )
            #     image_count += 1

# Build the frame once (concatenating row by row is quadratic and gives every
# row index 0); explicit columns keep the schema even if no rows were produced.
eval_df = pd.DataFrame(
    eval_records, columns=["biome", "resolution_bin", "precision", "recall", "f1"]
)

In [None]:
# Coerce the metric columns to plain float64 for aggregation and plotting.
eval_df = eval_df.astype({"f1": float, "precision": float, "recall": float})

In [None]:
# Sanity check: row count, column dtypes and non-null counts of the per-image results.
eval_df.info()

In [None]:
# Persist the per-image results. The row index carries no information, so it is
# not written; otherwise reloading with pd.read_csv adds an "Unnamed: 0" column.
eval_df.to_csv(
    f"/net/scratch/jmoehring/eval_{experiment.name}_fold_{fold}_epoch_{epoch}.csv",
    index=False,
)

In [None]:
# eval_df = pd.read_csv(
#     f"/net/scratch/jmoehring/eval_{experiment.name}_fold_{fold}_epoch_{epoch}.csv"
# )

In [None]:
# add new column biome names
# Human-readable names for the numeric biome codes, used as heatmap row labels.
# Codes missing from this mapping become NaN and are dropped by pivot_table.
biome_names = {
    4: "Temperate Broadleaf and Mixed Forests",
    5: "Temperate Coniferous Forests",
    6: "Boreal Forests/Taiga",
    12: "Mediterranean Forests",
}
eval_df["biome_name"] = eval_df["biome"].map(biome_names)

In [None]:
# plot seaborn heatmap of dice scores with biome and resolution as x and y axis
# Heatmap of mean per-image F1 (equal to the Dice coefficient for binary masks)
# for each biome x resolution-bin combination.
sns.set_theme()
pivoted = eval_df.pivot_table(
    index="biome_name", columns="resolution_bin", values="f1", aggfunc="mean"
)
fig, ax = plt.subplots()
sns.heatmap(pivoted, cmap="rocket", annot=True, fmt=".2f", ax=ax)
ax.set(
    title=f"Mean per-image F1 (fold {fold}, epoch {epoch})",
    xlabel="resolution bin",
    ylabel="biome",
)
plt.show()