In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import warnings
warnings.filterwarnings("ignore")

import lightning.pytorch as pl
pl.seed_everything(0)

Seed set to 0


0

# Set up
*   Install libraries
*   Load dataset

In [None]:
# ! rm -rf dataset
# ! rm -rf library

In [None]:
# ! pip install terratorch xarray-spatial

In [None]:
# ! mkdir -p dataset

In [None]:
# ! wget https://huggingface.co/datasets/hk-kaden-kim/Small_S2_CloudCover_Seg/resolve/main/train.zip -P dataset
# ! unzip -q dataset/train.zip -d dataset
# ! mv dataset/train dataset/public
# ! rm dataset/train.zip

In [None]:
# ! wget https://huggingface.co/datasets/hk-kaden-kim/Small_S2_CloudCover_Seg/resolve/main/test.zip -P dataset
# ! unzip -q dataset/test.zip -d dataset
# ! mv dataset/test dataset/private
# ! rm dataset/test.zip

In [None]:
# ! mkdir -p library
# ! wget https://github.com/hk-kaden-kim/S2-CloudCover/raw/refs/heads/main/library/__init__.py -P library
# ! wget https://github.com/hk-kaden-kim/S2-CloudCover/raw/refs/heads/main/library/analysis.py -P library

In [None]:
# ! mkdir -p library/datasets
# ! wget https://github.com/hk-kaden-kim/S2-CloudCover/raw/refs/heads/main/library/datasets/__init__.py -P library/datasets
# ! wget https://github.com/hk-kaden-kim/S2-CloudCover/raw/refs/heads/main/library/datasets/sen2cloud.py -P library/datasets

In [None]:
# ! mkdir -p library/datamodules
# ! wget https://github.com/hk-kaden-kim/S2-CloudCover/raw/refs/heads/main/library/datamodules/__init__.py -P library/datamodules
# ! wget https://github.com/hk-kaden-kim/S2-CloudCover/raw/refs/heads/main/library/datamodules/sen2cloud.py -P library/datamodules

In [None]:
# ! rm -rf dataset/__MACOSX

# Lightning Trainers

In [3]:
from library.datamodules.sen2cloud import Sen2CloudDataModule

BATCH_SIZE = 8 # 32

datamodule = Sen2CloudDataModule(
    data_root = './dataset',
    batch_size = BATCH_SIZE,
    means = [2631.64794921875, 2636.205078125, 2545.404052734375, 3444.451416015625],
    stds = [3057.195068359375, 2818.6640625, 2720.16796875, 2450.236328125],
)

datamodule.setup("fit")
datamodule.setup("test")

train_dataset = datamodule.train_dataset
val_dataset = datamodule.val_dataset

test_dataset = datamodule.test_dataset
len(train_dataset), len(val_dataset), len(test_dataset)

INFO:numexpr.utils:NumExpr defaulting to 8 threads.
INFO:albumentations.check_version:A new version of Albumentations is available: 2.0.5 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations


wxc_downscaling not installed
wxc_downscaling not installed


(414, 171, 547)

In [21]:
from terratorch.tasks import SemanticSegmentationTask

LOSS = 'ce'
# class_weights = [0.6,0.4]
LEARNING_RATE = 1e-3
OPTIMIZER = 'AdamW'
OPTIMIZER_HPARAMS = {"weight_decay": 0.05}

model_args={
    # Backbone (Encoder)
    "backbone": "resnet34",
    "backbone_in_chans": 4,
    "backbone_channels": [64, 128, 256, 512],
    "backbone_pretrained": True,
    # "backbone_bands": ["BLUE", "GREEN", "RED", "NIR_NARROW"],

    # Decoder
    "decoder": "Unet",
    "decoder_channels": [512, 256, 128, 64],
    # "decoder_decoder_channels": [256, 128, 64, 32, 16],

    # Head
    "head_dropout": 0.1,
    "num_classes": 2,
}

# Model
task = SemanticSegmentationTask(
    model_args=model_args,
    model_factory="EncoderDecoderFactory",
    loss=LOSS,
    # class_weights=class_weights,
    lr=LEARNING_RATE,
    optimizer=OPTIMIZER,
    optimizer_hparams=OPTIMIZER_HPARAMS,
    freeze_backbone=True, # True. Only to speed up fine-tuning
    freeze_decoder=False,
    class_names=['No', 'Cloud'],  # optionally define class names
    plot_on_val=0,
)

INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/resnet34.a1_in1k)
INFO:timm.models._hub:[timm/resnet34.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
INFO:timm.models._builder:Converted input conv conv1 pretrained weights from 3 to 4 channel(s)
INFO:timm.models._builder:Missing keys (fc.weight, fc.bias) discovered while loading pretrained weights. This is expected if model is being adapted.


4 (256, 128, 64, 32, 16)


ValueError: Model depth is 4, but you provide `decoder_channels` for 5 blocks.

In [None]:
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint, RichProgressBar
from lightning.pytorch.loggers import TensorBoardLogger

EPOCH = 20

checkpoint_callback = ModelCheckpoint(mode="max",
                                      monitor="val/Multiclass_Jaccard_Index",
                                      filename="best-{epoch:02d}",)
early_stopping_callback = EarlyStopping(mode="min",
                                        monitor=f"val/loss",
                                        patience=5)
logger = TensorBoardLogger(save_dir='output',
                           version=f"E{EPOCH}_B{BATCH_SIZE}_{LOSS}_LR{LEARNING_RATE}",
                           name=f"{model_args['backbone']}_{model_args['decoder']}")

trainer = Trainer(
    devices=1, # Number of GPUs. Interactive mode recommended with 1 device
    precision="16-mixed",
    callbacks=[
      RichProgressBar(),
      checkpoint_callback, # Set to save max val/IoU and last model
      early_stopping_callback,
      LearningRateMonitor(logging_interval="epoch"),
      ],
    logger=logger,
    max_epochs=EPOCH,
    default_root_dir='output',
    log_every_n_steps=1,
    check_val_every_n_epoch=1,
)

In [None]:
import torch

torch.cuda.empty_cache()

free_mem, total_mem = torch.cuda.mem_get_info()
print(f"Free Memory: {free_mem/1024**2} MB")
print(f"Total Memory: {total_mem/1024**2} MB")
print(f"Memory Allocated: {torch.cuda.memory_allocated(0)/1024**2} MB")
print(f"Memory Reserved: {torch.cuda.memory_reserved(0)/1024**2} MB")
print(f"Max Memory Allocated: {torch.cuda.max_memory_allocated(0)/1024**2} MB")
print(f"Max Memory Reserved: {torch.cuda.max_memory_reserved(0)/1024**2} MB")

In [None]:
_ = trainer.fit(model=task, datamodule=datamodule)

In [None]:
res = trainer.test(model=task, datamodule=datamodule) # Check default saved model here

In [None]:
! zip -r output.zip output

In [None]:
%load_ext tensorboard
%tensorboard --logdir output

In [None]:
"""
Backbone : prithvi_eo_v2_300 (Decoder Finetune)
Decoder : UNetDecoder
Epoch : 20
Batch : 8
loss : ce
lr : 1e-3
Trainable params: 20.3 M
Non-trainable params: 303 M
Total params: 323 M
Total estimated model params size (MB): 1.3 K
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│      test/Multiclass_Accuracy       │          0.890679121017456          │
│      test/Multiclass_F1_Score       │         0.8906790018081665          │
│    test/Multiclass_Jaccard_Index    │          0.798092246055603          │
│ test/Multiclass_Jaccard_Index_Micro │         0.8029047846794128          │
│              test/loss              │         0.31213536858558655         │
│    test/multiclassaccuracy_Cloud    │         0.8777081370353699          │
│     test/multiclassaccuracy_No      │         0.9105421304702759          │
│  test/multiclassjaccardindex_Cloud  │         0.8292639851570129          │
│   test/multiclassjaccardindex_No    │         0.7669205665588379          │
└─────────────────────────────────────┴─────────────────────────────────────┘

Backbone : prithvi_eo_v1_100 (Decoder Finetune)
Decoder : UNetDecoder
Epoch : 20
Batch : 8
loss : ce
lr : 1e-3
Trainable params: 15.5 M
Non-trainable params: 85.8 M
Total params: 101 M
Total estimated model params size (MB): 405
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│      test/Multiclass_Accuracy       │         0.8720716238021851          │
│      test/Multiclass_F1_Score       │         0.8720716238021851          │
│    test/Multiclass_Jaccard_Index    │         0.7667739987373352          │
│ test/Multiclass_Jaccard_Index_Micro │         0.7731622457504272          │
│              test/loss              │         0.3706841468811035          │
│    test/multiclassaccuracy_Cloud    │         0.8750652074813843          │
│     test/multiclassaccuracy_No      │         0.8674874305725098          │
│  test/multiclassjaccardindex_Cloud  │         0.8053733110427856          │
│   test/multiclassjaccardindex_No    │         0.7281746864318848          │
└─────────────────────────────────────┴─────────────────────────────────────┘

"""

# CLI tool

You find an example for SMP models in `configs/burnscars_smp.yaml` that you can run with `terratorch fit -c configs/burnscars_smp.yaml`.