## Experimental notebook: TerraTorch semantic segmentation with a Clay backbone (HabitAlp data)

In [2]:
import os
import sys
import warnings
from pathlib import Path

# Make the parent directory importable so the project-local `src` package resolves.
sys.path.append('..')
# Silence library warnings for this exploratory run.
warnings.filterwarnings('ignore')

import torch
import wandb
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger
from terratorch.tasks import SemanticSegmentationTask

from src.data.datamodules.unet_change import SegmentationDataModule


In [14]:

# Root of the HabitAlp data. Set the HABITALP_DATA_DIR environment variable to
# override the default on machines where the Nextcloud folder lives elsewhere.
data_folder = Path(
    os.environ.get("HABITALP_DATA_DIR", "/home/hkristen/Nextcloud/HabitAlp2.0/Originaldaten")
)

# Input raster layers, keyed by layer name. NOTE(review): presumably the
# datamodule stacks them into channels in this insertion order — confirm, since
# the Clay band list in the model config must follow the same order.
input_data = {
    "rgb": data_folder / "processed/orthophoto_gis_stmk/flug_2013_2015_rgb.tif",
    # "falschfarben" = false-colour (colour-infrared) orthophoto.
    "cir": data_folder / "processed/orthophoto_gis_stmk/falschfarben_2013_2015.tif",
    "dtm": data_folder / "processed/elevation_waldstmk/dtm.tif",
    # "dsm": data_folder / "processed/elevation_waldstmk/dsm.tif",
    "ndsm": data_folder / "processed/elevation_waldstmk/ndsm.tif",
    # Optional terrain derivatives, disabled for this run:
    # "slope": data_folder / "processed/elevation_waldstmk/slope.tif",
    # "aspect": data_folder / "processed/elevation_waldstmk/aspect.tif",
    # "tri": data_folder / "processed/elevation_waldstmk/tri.tif",
    # "tpi": data_folder / "processed/elevation_waldstmk/tpi.tif",
    # "roughness": data_folder / "processed/elevation_waldstmk/roughness.tif",
    # "curvature": data_folder / "processed/elevation_waldstmk/curvature.tif",
    # "planform_curvature": data_folder / "processed/elevation_waldstmk/planform_curvature.tif",
    # "profile_curvature": data_folder / "processed/elevation_waldstmk/profile_curvature.tif",
}
mask_path = data_folder / "processed/mask/classes_v3.tif"
roi_shape_path = data_folder / "roi/habitalp_2013_boundary.gpkg"
target_class_definition_path = data_folder / "habitalp_target_classes/Zielklassen 2024 v3.csv"
# Misspelled name kept as an alias because later cells still reference it.
target_class_definiton_path = target_class_definition_path

experiments_pth = '/home/hkristen/habitalp2/src/models/experiments'
experiment_name = 'TERRATORCH_TEST_6_CLAY_TEST_15'

In [15]:
# Runtime / optimisation settings for this short (2-epoch) run.
batch_size = 6            # samples per step, handed to the datamodule
max_epochs = 2
accelerator = "auto"      # let Lightning choose the available device
lr = 1e-4                 # learning rate, handed to the segmentation task

In [16]:
# Build the datamodule from the raster layers, label mask and ROI polygon
# configured above. n_classes=23 matches the 23 names in datamodule.class_names.
datamodule = SegmentationDataModule(
    input_data,
    mask_path,
    roi_shape_path,
    target_class_definition_path=target_class_definiton_path,
    n_classes=23,
    patch_size=(256, 256),
    batch_size=batch_size,
    train_batches_per_epoch=1024,
    val_batches_per_epoch=64,
    num_workers=4,
)

# setup() assigns the spatial cells to train/val/test so the splits can be sized.
datamodule.setup(stage='fit')

print(f"Train dataset size: {len(datamodule.train_dataset)}")
print(f"Val dataset size: {len(datamodule.val_dataset)}")
print(f"Test dataset size: {len(datamodule.test_dataset)}")

Converting RasterDataset res from (0.5, 0.5) to (0.19989701994387044, 0.19989701994387044)
Converting RasterDataset res from (0.5, 0.5) to (0.19989701994387044, 0.19989701994387044)
Converting RasterDataset res from (0.19989701994387069, 0.19989701994386977) to (0.19989701994387044, 0.19989701994387044)
Assigned 47/10/9 cells to train/val/test datasets.
Train dataset size: 47
Val dataset size: 10
Test dataset size: 9


In [17]:
# Display the class names exposed by the datamodule (23 entries in the output);
# presumably read from the target-class CSV passed above — confirm.
datamodule.class_names

['Gewässer',
 'Kiesbank, Sandbank, fluviatil',
 'Erosionsfläche, Rinne',
 'Lockermaterial',
 'Fels',
 'Nadel-Jungwuchs, Dickung, Feldgehölz',
 'Laub-Jungwuchs, Dickung, Feldgehölz',
 'Nadel-Stangenholz DG0-70',
 'Nadel-Stangenholz DG80-100',
 'Laub-Stangenholz',
 'Nadel-Baumholz DG0-70',
 'Nadel-Baumholz DG80-100',
 'Laub-Baumholz DG0-70',
 'Laub-Baumholz DG80-100',
 'Nadel-Altbestand, mehrsch. DG0-70',
 'Nadel-Altbestand, mehrsch. DG80-100',
 'Laub-Altbestand, mehrsch. DG0-70',
 'Laub-Altbestand, mehrsch. DG80-100',
 'Schlagflächen',
 'Krummholzgürtel',
 'Grünland, Saumvegetation',
 'Berggrünland, Heide',
 'Priorität 1']

In [18]:
# Number of target classes configured on the datamodule (n_classes=23 at construction).
datamodule.n_classes

23

### Model: Clay v1 (base) backbone with an FCN decoder

In [19]:
# Band names tell the Clay encoder which wavelength to associate with each input
# channel, so their order must match the channel order produced by the datamodule.
# NOTE(review): the batch has 6 channels from rgb + cir + dtm + ndsm, so "cir"
# presumably contributes only the NIR channel — confirm. The last layer is *ndsm*
# but is labelled "dsm" here; the factory appears to map both terrain bands to a
# 0.0 wavelength (see the training output), so confirm the label is intentional.
clay_bands = ["red", "green", "blue", "nir", "dtm", "dsm"]

model_args = dict(
    backbone="clay_v1_base",
    decoder="FCNDecoder",
    # Derived from the band list so the two can never drift apart.
    in_channels=len(clay_bands),
    bands=clay_bands,
    # Mask labels reach 23 (see the batch check), so the head needs
    # n_classes + 1 = 24 outputs to cover label indices 0..23.
    num_classes=datamodule.n_classes + 1,
    pretrained=True,
    num_frames=1,
)

In [20]:
# Sanity check: the number of band names should equal in_channels (6).
len(model_args['bands'])

6

In [21]:
# Per-class loss weights: one entry per output class of the segmentation head.
# NOTE(review): how these values were computed is not recorded in this notebook —
# document their provenance next to the computation.
class_weights = [
    3.05754852e-09, 5.61588104e-08, 2.55059813e-07, 1.31145305e-08,
    1.18684274e-08, 2.41300286e-09, 1.39145003e-08, 2.71112867e-08,
    1.55613875e-08, 8.56254672e-09, 3.42943184e-08, 4.17894324e-09,
    4.42492833e-09, 2.31283622e-08, 1.43593138e-08, 5.54229727e-09,
    1.03878198e-08, 4.66937618e-08, 1.60617068e-08, 1.63495321e-08,
    2.86372574e-09, 1.65979327e-08, 6.29546752e-09, 4.21235049e-08,
]
# Fail fast with a readable message rather than a shape error inside the loss.
assert len(class_weights) == model_args["num_classes"], (
    f"class_weights has {len(class_weights)} entries but the model predicts "
    f"{model_args['num_classes']} classes"
)

# NOTE(review): class_names has n_classes (23) entries while the head predicts
# n_classes + 1 (24) classes, and ignore_index is None, so label 0 (if it occurs)
# is trained as an ordinary class. Confirm whether 0 is nodata (ignore_index=0?)
# and whether class names should be offset to line up with label indices.
task = SemanticSegmentationTask(
    model_args,
    "ClayModelFactory",
    loss="ce",
    lr=lr,
    ignore_index=None,
    optimizer="AdamW",
    optimizer_hparams={"weight_decay": 0.05},
    freeze_backbone=False,
    plot_on_val=False,
    class_names=datamodule.class_names,
    class_weights=class_weights,
)

INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (('made-with-clay/Clay', 'v1/clay-v1-base.ckpt'))


## Train

In [22]:
# Run-mode flag (not used in the visible cells). NOTE: this name shadows the
# stdlib `logging` module inside the notebook namespace.
logging = 'remote'
wandb_project = 'semantic-segmentation-terratorch'
# Never hard-code API keys in a notebook: read it from the environment. If the
# variable is unset, key=None lets wandb use its own credential lookup.
wandb_key = os.environ.get("WANDB_API_KEY")

wandb.login(key=wandb_key)
# log_model="all" uploads every checkpoint saved during training as a W&B artifact.
wb_logger = WandbLogger(
    name=experiment_name,
    project=wandb_project,
    log_model="all",
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/hkristen/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mhkristen[0m ([33muniversity-of-graz-geo[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [23]:
# Lightning trainer with Weights & Biases as its only logger.
trainer = Trainer(
    logger=[wb_logger],
    max_epochs=max_epochs,
    accelerator=accelerator,
)

INFO: You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
INFO:lightning.pytorch.utilities.rank_zero:You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [24]:
# Pull one training batch to sanity-check tensor shapes, dtypes and label range
# before launching the fit.
sample_batch = next(iter(datamodule.train_dataloader()))
print("Input shape:", sample_batch["image"].shape)
print("Input dtype:", sample_batch["image"].dtype)
print("Mask shape:", sample_batch["mask"].shape)
print("Mask dtype:", sample_batch["mask"].dtype)
print("Unique values in mask:", torch.unique(sample_batch["mask"]))

# Fail fast on configuration mismatches. Image batches are laid out as
# (batch, channels, H, W), per the printed shape.
n_image_channels = sample_batch["image"].shape[1]
assert n_image_channels == model_args["in_channels"], (
    f"datamodule yields {n_image_channels} channels, "
    f"model expects {model_args['in_channels']}"
)
mask_labels = sample_batch["mask"]
assert int(mask_labels.min()) >= 0 and int(mask_labels.max()) < model_args["num_classes"], (
    f"mask labels span [{int(mask_labels.min())}, {int(mask_labels.max())}], "
    f"outside [0, {model_args['num_classes']})"
)

Input shape: torch.Size([6, 6, 256, 256])
Input dtype: torch.float32
Mask shape: torch.Size([6, 256, 256])
Mask dtype: torch.int64
Unique values in mask: tensor([ 1,  6,  8,  9, 11, 15, 19, 21, 23])


In [None]:
# Train (and validate) the Clay segmentation task on the datamodule's splits.
# NOTE(review): the repeated `tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])`
# lines in this cell's output come from a print() outside this notebook (they look
# like per-band wavelengths for red/green/blue/nir/dtm/dsm) — consider removing it.
trainer.fit(task, datamodule=datamodule)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | model         | PixelWiseModel   | 96.0 M | train
1 | criterion     | CrossEntropyLoss | 0      | train
2 | train_metrics | MetricCollection | 0      | train
3 | val_metrics   | MetricCollection | 0      | train
4 | test_metrics  | ModuleList       | 0      | train
-----------------------------------------------------------
96.0 M    Trainable params
0         Non-trainable params
96.0 M    Total params
384.178   Total estimated model params size (MB)
234       Modules in train mode
0         Modules in eval mode
INFO:lightning.pytorch.callbacks.model_summary:
  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | model         | PixelWiseModel   | 96.0 M | train
1 | crite

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])


Training: |          | 0/? [00:00<?, ?it/s]

tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930,

Validation: |          | 0/? [00:00<?, ?it/s]

tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930, 0.8420, 0.0000, 0.0000])
tensor([0.6650, 0.5600, 0.4930,