# U-Net Semantic Segmentation

In [1]:
import os
import sys
from pathlib import Path

# Make the project root (parent of the notebook's working directory) importable so that
# the `src.*` imports in later cells resolve.
BASE_FOLDER = Path.cwd().parent
sys.path.append(str(BASE_FOLDER))

wandb_project = "unet_change"
# Passed as `logging=` to train_model below.
# NOTE: this name shadows the stdlib `logging` module inside the notebook namespace.
logging = "remote"
# Read the W&B API key from the environment instead of committing it to the notebook.
# The placeholder is kept as a fallback so the cell still runs when the variable is unset.
wandb_key = os.environ.get("WANDB_API_KEY", "YOUR_WANDB_API_KEY_HERE")

# Training

In [None]:
from src.trainers.unet_segmentation.train import train_model

# Train with the logging / W&B settings from the configuration cell. The returned
# datamodule, task and trainer are kept in the notebook namespace.
train_kwargs = {
    "logging": logging,
    "wandb_project": wandb_project,
    "wandb_key": wandb_key,
}
datamodule, task, trainer = train_model(**train_kwargs)

# Sweeps

### 1 - Sweep Resnet18

```config-sweep-resnet18.yaml```

Testing different batch sizes, patch sizes, numbers of training batches per epoch, and the use of pretrained model weights with a ResNet18 backbone.

Batch size, patch size, training batches: No significant trend observable. Pretrained model weights = True seems to work better.

### 2 - Sweep Backbones

Testing different backbones:
- resnet50: No outstanding performance
- resnext50_32x4d: Better with 512 training batches. ValJaccard ~0.25
- efficientnet-b4: Requires many GPU resources, but performs better than the other backbones. 512 training batches lead to good results faster than 1024. ValJaccard ~0.27
- mit_b2: Performance similar to efficientnet-b4. ValJaccard ~0.27
- HRNet https://doi.org/10.3390/rs13163087: Not working with torchgeo.
- MobileNet v4: https://doi.org/10.48550/arXiv.2404.10518

### 3 - Sweep Input Bands

Adding more input bands with each training run (mit_b2, batch size 32, patch size 256)
- Generally: The more the better
- IoU:
    - RGB: 0.07
    - RGB + CIR: 0.1
    - RGB + CIR + Elevation: 0.2
    - RGB + CIR + Elevation + Derived Layers: 0.25
- Adding Spot-4 data (```sweep-spot-4```) does *not* lead to significantly higher scores!

In [None]:
from src.trainers.unet_segmentation_sweep.sweep import start_sweep

# Launch the hyper-parameter sweep. No arguments are passed here, so the sweep
# configuration is presumably resolved inside the sweep module.
start_sweep()

# Load saved Datamodule and Task

In [15]:
from src.trainers.utils import load_config_from_yaml

# Saved experiment to evaluate. It must contain config.yaml (read here) and
# model_paths.yaml (read in the next code cell).
experiment_dir = Path("../models/best-parameters-4.3")
config_file = experiment_dir / "config.yaml"
config = load_config_from_yaml(config_file)

Load the task (best checkpoint) and the datamodule if they are not already loaded

In [None]:
import torch
import yaml

from src.trainers.unet_segmentation.unet_segmentation import MultiClassSemanticSegmentationTask
from src.trainers.unet_segmentation.train import get_datamodule

# Choose the device first so the checkpoint can be mapped onto it while loading.
# Without `map_location`, a checkpoint saved on GPU may fail to load on a CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load task: model_paths.yaml stores the best checkpoint path, which is joined onto experiment_dir.
with open(experiment_dir / "model_paths.yaml", "r") as file:
    model_path = yaml.safe_load(file)["best_model_path"]
task = MultiClassSemanticSegmentationTask.load_from_checkpoint(
    experiment_dir / model_path,
    map_location=device,
)
task.to(device).eval()

# Get datamodule for the 2024 cross-temporal prediction year.
datamodule = get_datamodule(
    config=config,
    prediction_year=2024,
)
datamodule.setup()

# Plot Samples from Test Dataset

In [None]:
# Plot model predictions for a few samples, using the task's own plotting helper.
n_plot_samples = 6
fig = task.plot_prediction_samples(
    datamodule=datamodule,
    experiment_dir=experiment_dir,
    n_samples=n_plot_samples,
)
fig.show()

# Inference

In [4]:
import torch
from torchgeo.datasets import BoundingBox
import wandb

from src.inference.eval import infer_on_whole_image

## Inference on Cross-Temporal Prediction Dataset

In [None]:
# Use no worker processes for inference (presumably forwarded to the datamodule's loaders).
datamodule.num_workers = 0

# Optional ROI to restrict inference to a small test area, e.g.
# BoundingBox(minx=470800, maxx=473000, miny=5270500, maxy=5272000, mint=0.0, maxt=9.223372036854776e+18)
roi = None
inference_settings = dict(
    overlap=64,
    delta=8,
    inference_batch_size=4,
)
infer_on_whole_image(
    datamodule=datamodule,
    task=task,
    experiment_dir=experiment_dir,
    predict_on_test_ds=False,
    roi=roi,
    output_filename="prediction_cross-temporal_2024",
    **inference_settings,
)

Inference:  28%|██▊       | 5285/18634 [17:27<54:35,  4.08it/s]  

In [None]:
# The `datamodule` loaded above was built with prediction_year=2024. Reusing it here would
# write 2024 predictions into a file labelled 2020, so build a dedicated 2020 datamodule
# and leave the 2024 one untouched for the other cells.
datamodule_2020 = get_datamodule(
    config=config,
    prediction_year=2020,
)
datamodule_2020.setup()
datamodule_2020.num_workers = 0

# Optional ROI to restrict inference to a small test area, e.g.
# BoundingBox(minx=470800, maxx=473000, miny=5270500, maxy=5272000, mint=0.0, maxt=9.223372036854776e+18)
roi = None
infer_on_whole_image(
    datamodule=datamodule_2020,
    task=task,
    experiment_dir=experiment_dir,
    overlap=64,
    delta=8,
    predict_on_test_ds=False,
    roi=roi,
    output_filename="prediction_cross-temporal_2020",
    inference_batch_size=4,
)

## Inference on Test Dataset

In [None]:
# Predict on the in-domain test dataset (predict_on_test_ds=True).
datamodule.num_workers = 0

# Optional ROI for a small test area, e.g.
# BoundingBox(minx=470800, maxx=473000, miny=5270500, maxy=5272000, mint=0.0, maxt=9.223372036854776e+18)
roi = None
test_inference_kwargs = {
    "overlap": 64,
    "delta": 8,
    "predict_on_test_ds": True,
    "roi": roi,
    "output_filename": "prediction_in-domain_2013",
    "inference_batch_size": 4,
}
infer_on_whole_image(
    datamodule=datamodule,
    task=task,
    experiment_dir=experiment_dir,
    **test_inference_kwargs,
)

# Change Map

In [9]:
import os

from src.inference.eval import generate_change_map
from src.trainers.utils import compute_final_metrics

# Root of the HabitAlp data. Set the HABITALP_DATA_FOLDER environment variable on machines
# where the data does not live at the default (absolute, Windows) location.
data_folder = Path(os.environ.get("HABITALP_DATA_FOLDER", "D:/Nextcloud/HabitAlp2.0/Originaldaten"))
# NOTE: this rebinds `experiment_dir`, replacing the saved-model directory set earlier.
experiment_dir = data_folder / "model_output/clay_v1_base_rgb_cir_ndsm/"

## Change Map with Cross-Temporal Prediction Dataset

In [13]:
# Change map between the 2013 mask and the 2020 cross-temporal prediction.
# The mask path is built from `data_folder` (as in the 2003-2013 cell) instead of repeating
# the absolute data root as a string.
generate_change_map(
    mask_path=data_folder / "processed/mask/classes_v3_2013.tif",
    prediction_path=experiment_dir / "training_rois_2020_model-f64vg6b6_v40_mIoU=0.3306.tif",
    output_path=experiment_dir / "change_map_2013-2020-f64vg6b6_v40.tif",
)

In [None]:
# Multi-class change labels. For a binary evaluation use ["No change", "Change"] instead.
class_names = [
    "No change",
    "Mature Tree Density Loss",
    "Old Growth Density Loss",
    "Forest Setback YoungLoss",
    "Forest Stage Progression",
    "Forest Density Gain",
    "Early Forest Establishment",
    "Clearcut Loss",
    "Other Transition"]

num_classes = len(class_names)  # includes "No change" (9 for the list above)

multiclass_metrics, metrics_each_label, figure_collection = compute_final_metrics(
    reference_change_map=data_folder / "habitalp_change/habitalp_change_2013_2020.tif",
    # Evaluate the change map written by the generate_change_map cell directly above, so a
    # Restart & Run All scores the freshly generated file rather than an older one.
    prediction_change_map=experiment_dir / "change_map_2013-2020-f64vg6b6_v40.tif",
    num_classes=num_classes,
    class_names=class_names,
)

print(multiclass_metrics)

In [None]:
# Log the 2013 -> 2020 change metrics to the existing W&B run "nt7lhzek".
# resume="must" requires that run to already exist in `wandb_project`.
figure_keys = (
    "confusion_matrix",
    "accuracy_each_class",
    "jaccard_each_class",
    "precision_each_class",
    "recall_each_class",
    "f1score_each_class",
)
with wandb.init(project=wandb_project, id="nt7lhzek", resume="must") as run:
    if num_classes >= 2:
        metric_name = "change_metrics_binary" if num_classes == 2 else "change_metrics_2013_2020"
        # log Multiclass metrics
        run.summary[metric_name] = multiclass_metrics

        # Log metrics for each class. Copy everything except the confusion matrix into a new
        # dict instead of popping it, so re-running this cell does not raise KeyError.
        per_class_metrics = {
            name: values
            for name, values in metrics_each_label.items()
            if name != "ConfusionMatrix"
        }
        tensor_stack = torch.stack(list(per_class_metrics.values())).T
        change_metrics_each_class_table = wandb.Table(
            columns=list(per_class_metrics.keys()),
            data=[row.tolist() for row in tensor_stack],
        )
        run.log({f"{metric_name}/table_each_class": change_metrics_each_class_table})

        # Log figures: figure_collection is indexed in the order of figure_keys.
        for position, figure_key in enumerate(figure_keys):
            run.log({f"{metric_name}/{figure_key}": wandb.Image(figure_collection[position])})

## Change Map with Test Dataset

In [12]:
# Change map between the 2003 mask and a saved model prediction.
# NOTE(review): the prediction file has a double ".tif.tif" extension — confirm the file on
# disk is really named this way.
change_map_inputs = {
    "mask_path": data_folder / "processed/mask/classes_v3_2003.tif",
    "prediction_path": experiment_dir / "roi_example_model-ocg72h08_v97_mIoU=0.3695.tif.tif",
}
generate_change_map(
    **change_map_inputs,
    output_path=experiment_dir / "change_map_2003-2013_roi_example-ocg72h08_v97.tif",
)

In [None]:
# Multi-class change labels. For a binary evaluation use ["No change", "Change"] instead.
class_names = [
    "No change",
    "Mature Tree Density Loss",
    "Old Growth Density Loss",
    "Forest Setback YoungLoss",
    "Forest Stage Progression",
    "Forest Density Gain",
    "Early Forest Establishment",
    "Clearcut Loss",
    "Other Transition"]

num_classes = len(class_names)  # includes "No change" (9 for the list above)

multiclass_metrics, metrics_each_label, figure_collection = compute_final_metrics(
    # This change map covers 2003 -> 2013, so it must be scored against the 2003-2013
    # reference, not the 2013-2020 one.
    # NOTE(review): confirm the file name of the 2003-2013 reference change map on disk.
    reference_change_map=data_folder / "habitalp_change/habitalp_change_2003_2013.tif",
    # Evaluate the change map written by the generate_change_map cell directly above.
    prediction_change_map=experiment_dir / "change_map_2003-2013_roi_example-ocg72h08_v97.tif",
    num_classes=num_classes,
    class_names=class_names,
)

print(multiclass_metrics)

In [None]:
# Log the 2003 -> 2013 change metrics to the existing W&B run "nt7lhzek".
# resume="must" requires that run to already exist in `wandb_project`.
figure_keys = (
    "confusion_matrix",
    "accuracy_each_class",
    "jaccard_each_class",
    "precision_each_class",
    "recall_each_class",
    "f1score_each_class",
)
with wandb.init(project=wandb_project, id="nt7lhzek", resume="must") as run:
    if num_classes >= 2:
        metric_name = "change_metrics_binary_2003-2013" if num_classes == 2 else "change_metrics_2003-2013"
        # log Multiclass metrics
        run.summary[metric_name] = multiclass_metrics

        # Log metrics for each class. Copy everything except the confusion matrix into a new
        # dict instead of popping it, so re-running this cell does not raise KeyError.
        per_class_metrics = {
            name: values
            for name, values in metrics_each_label.items()
            if name != "ConfusionMatrix"
        }
        tensor_stack = torch.stack(list(per_class_metrics.values())).T
        change_metrics_each_class_table = wandb.Table(
            columns=list(per_class_metrics.keys()),
            data=[row.tolist() for row in tensor_stack],
        )
        run.log({f"{metric_name}/table_each_class": change_metrics_each_class_table})

        # Log figures: figure_collection is indexed in the order of figure_keys.
        for position, figure_key in enumerate(figure_keys):
            run.log({f"{metric_name}/{figure_key}": wandb.Image(figure_collection[position])})

# Physical Constraints Module

In [25]:
import os

from src.inference.utils import physical_constraints_module

# Root of the HabitAlp data. Set the HABITALP_DATA_FOLDER environment variable on machines
# where the data does not live at the default (absolute, Windows) location.
data_folder = Path(os.environ.get("HABITALP_DATA_FOLDER", "D:/Nextcloud/HabitAlp2.0/Originaldaten"))
# NOTE(review): "proability_weighted_bleding" looks misspelled but is presumably the actual
# folder name on disk; keep it in sync with the directory rather than "fixing" it here.
experiment_dir = data_folder / "model_output/clay_v1_base_rgb_cir_ndsm/proability_weighted_bleding_for_inference"

# Inputs for the physical constraints module.
elevation_dir = data_folder / "processed/elevation_gis_stmk_photogrammetry_2022_2023"
mask_path = data_folder / "processed/mask/classes_v3_2013.tif"
prediction_path = experiment_dir / "prediction_training_roi_2_2020_Clayv1_64.tif"
dtm_path = elevation_dir / "dtm_2010_2012_aligned_cross-temporal_0.2m.tif"
slope_path = elevation_dir / "slope_2010_2012_aligned_cross-temporal_0.2m.tif"

In [None]:
# Run the physical constraints module on the 2020 roi_2 prediction. The positional inputs
# are passed in the same order as defined in the setup cell, and one mask per constraint is
# requested via export_mask_for_each_constraint=True.
constraint_inputs = (mask_path, prediction_path, dtm_path, slope_path, experiment_dir)
physical_constraints_module(
    *constraint_inputs,
    output_name="prediction_2020_roi_2_cross-temporal_with_constraint_violation_reset_to_mask",
    export_mask_for_each_constraint=True,
)

In [31]:
from src.inference.eval import generate_change_map

# NOTE(review): the constraints cell above uses output_name "prediction_2020_roi_2_...", but
# this cell reads the roi_1 file. Under Restart & Run All it therefore consumes a result from
# an earlier run (if that file exists). Align the ROI here and in the metrics cell below with
# the constraints cell.
constrained_prediction = experiment_dir / "prediction_2020_roi_1_cross-temporal_with_constraint_violation_reset_to_mask.tif"
change_map_output = experiment_dir / "change_map_cross-temporal_2020_roi_1_with_constraint_violation_reset_to_mask.tif"
generate_change_map(
    mask_path=mask_path,
    prediction_path=constrained_prediction,
    output_path=change_map_output,
)

In [None]:
from src.trainers.utils import compute_final_metrics

# Multi-class change labels. For a binary evaluation use ["No change", "Change"] instead.
class_names = [
    "No change",
    "Mature Tree Density Loss",
    "Old Growth Density Loss",
    "Forest Setback YoungLoss",
    "Forest Stage Progression",
    "Forest Density Gain",
    "Early Forest Establishment",
    "Clearcut Loss",
    "Other Transition",
]
num_classes = len(class_names)  # includes "No change" (9 for the list above)

# Must match the output_path of the constraint-based generate_change_map cell above.
constrained_change_map = experiment_dir / "change_map_cross-temporal_2020_roi_1_with_constraint_violation_reset_to_mask.tif"
multiclass_metrics, metrics_each_label, figure_collection = compute_final_metrics(
    reference_change_map=data_folder / "habitalp_change/habitalp_change_2013_2020.tif",
    prediction_change_map=constrained_change_map,
    num_classes=num_classes,
    class_names=class_names,
)

print(multiclass_metrics)

# Shutdown OS

In [7]:
import os
import sys

# Shut down the machine after the notebook has run. "shutdown /s /t 1" is Windows command
# syntax, so it is only issued on Windows. Set SHUTDOWN_WHEN_DONE = False to skip it
# (e.g. when using Run All interactively).
SHUTDOWN_WHEN_DONE = True

if SHUTDOWN_WHEN_DONE and sys.platform == "win32":
    os.system("shutdown /s /t 1")
elif SHUTDOWN_WHEN_DONE:
    print(f"Skipping shutdown: 'shutdown /s /t 1' is Windows-only (platform: {sys.platform}).")

0