In [6]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
import warnings
warnings.filterwarnings("ignore")

# Set up
*   Install libraries
*   Load dataset

In [None]:
# ! rm -rf dataset
# ! rm -rf library

In [None]:
# ! pip install terratorch

In [None]:
# ! mkdir -p dataset
# ! mkdir -p library

In [None]:
# ! wget https://huggingface.co/datasets/hk-kaden-kim/Small_S2_CloudCover_Seg/resolve/main/train.zip -P dataset
# ! unzip -q dataset/train.zip -d dataset
# ! rm dataset/train.zip

In [None]:
# ! wget https://huggingface.co/datasets/hk-kaden-kim/Small_S2_CloudCover_Seg/resolve/main/test.zip -P dataset
# ! unzip -q dataset/test.zip -d dataset
# ! rm dataset/test.zip

In [None]:
# ! wget https://github.com/hk-kaden-kim/S2-CloudCover/raw/refs/heads/main/library/__init__.py -P library
# ! wget https://github.com/hk-kaden-kim/S2-CloudCover/raw/refs/heads/main/library/analysis.py -P library
# ! wget https://github.com/hk-kaden-kim/S2-CloudCover/raw/refs/heads/main/library/dataset.py -P library

In [None]:
# ! rm -rf dataset/__MACOSX

# Lightning Trainers

In [8]:
from terratorch.datamodules import TorchNonGeoDataModule

from library.dataset import CustomNonGeoDataModule, CustomCloudCoverDetection

import albumentations as A
from albumentations.pytorch import ToTensorV2

BATCH_SIZE = 16 # 32

datamodule = TorchNonGeoDataModule(

    # Dataset Module
    cls = CustomNonGeoDataModule,
    batch_size = BATCH_SIZE,
    num_workers = 2,
    train_aug=[
        A.D4(),   # D4 package : e | r90 | r180 | r270 | v | hvt | h | t
        A.ToFloat(max_value=65536.0),   
        ToTensorV2(),
        ],
    val_aug=[
        A.ToFloat(max_value=65536.0),   
        ToTensorV2()
        ],
    test_aug=[
        A.ToFloat(max_value=65536.0),   
        ToTensorV2()
        ],

    # Dataset
    dataset_class = CustomCloudCoverDetection,
    root = './dataset',
    bands = ['B02', 'B03', 'B04', 'B08'],
    download=False
)

INFO:numexpr.utils:NumExpr defaulting to 4 threads.
INFO:albumentations.check_version:A new version of Albumentations is available: 2.0.5 (you have 1.4.10). Upgrade using: pip install --upgrade albumentations


In [27]:
from terratorch.tasks import SemanticSegmentationTask
from library.analysis import CustomSemanticSegmentationTask

LOSS = 'ce'
LEARNING_RATE = 1e-3
OPTIMIZER = 'AdamW'
OPTIMIZER_HPARAMS = {"weight_decay": 0.05}

model_args = {
        "backbone":"resnet34", # see smp_encoders.keys()
        'model': 'DeepLabV3Plus', # 'DeepLabV3', 'DeepLabV3Plus', 'FPN', 'Linknet', 'MAnet', 'PAN', 'PSPNet', 'Unet', 'UnetPlusPlus'
        "bands": ['B02', 'B03', 'B04', 'B08'], # We use the 4 bands from the Sentinel-2 dataset.
        "in_channels": 4,
        "num_classes": 2,
        "pretrained": True,
}

task = SemanticSegmentationTask(
    model_args=model_args,
    model_factory="SMPModelFactory",
    loss=LOSS,
    lr=LEARNING_RATE,
    optimizer=OPTIMIZER,
    optimizer_hparams=OPTIMIZER_HPARAMS,
    freeze_backbone=True,
    freeze_decoder=False,
    class_names=['No', 'Cloud'],
    plot_on_val=0,
)

In [28]:
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint, RichProgressBar
from lightning.pytorch.loggers import TensorBoardLogger

EPOCH = 30
MAIN_METRIC = 'Multiclass_Jaccard_Index'

checkpoint_callback = ModelCheckpoint(mode="max",
                                      monitor=f"val/{MAIN_METRIC}", # Variable to monitor
                                      filename="best-{epoch:02d}",)
early_stopping_callback = EarlyStopping(mode="max",
                                        monitor=f"val/{MAIN_METRIC}",
                                        min_delta=0.0001,
                                        patience=5)

logger = TensorBoardLogger(save_dir='output',
                           version=f"E{EPOCH}_B{BATCH_SIZE}_{LOSS}_LR{LEARNING_RATE}",
                           name=f"{model_args['model']}_{model_args['backbone']}")

trainer = Trainer(
    devices=1, # Number of GPUs. Interactive mode recommended with 1 device
    precision="16-mixed",
    callbacks=[
        RichProgressBar(),
        checkpoint_callback,
        # early_stopping_callback,
        LearningRateMonitor(logging_interval="epoch"),
    ],
    logger=logger,
    max_epochs=EPOCH,
    default_root_dir='output',
    log_every_n_steps=1,
    check_val_every_n_epoch=1
)

INFO: Using 16bit Automatic Mixed Precision (AMP)
INFO:lightning.pytorch.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [29]:
_ = trainer.fit(model=task, datamodule=datamodule)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


INFO: `Trainer.fit` stopped: `max_epochs=30` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=30` reached.


In [30]:
res = trainer.test(model=task, datamodule=datamodule)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [None]:
# ! zip -r DeepLabV3_resnet34.zip output/Unet_resnet34

In [None]:
# %load_ext tensorboard
# %tensorboard --logdir output

In [None]:
"""
Model : UNet with ResNet34
Epoch : 30
Batch : 16
loss : ce
lr : 1e-3
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│      test/Multiclass_Accuracy       │         0.8920501470565796          │
│      test/Multiclass_F1_Score       │         0.8920501470565796          │
│    test/Multiclass_Jaccard_Index    │         0.7992327809333801          │
│ test/Multiclass_Jaccard_Index_Micro │         0.8051358461380005          │
│              test/loss              │         0.29767701029777527         │
│    test/multiclassaccuracy_Cloud    │         0.8943116664886475          │
│     test/multiclassaccuracy_No      │         0.8885871171951294          │
│  test/multiclassjaccardindex_Cloud  │         0.8336586356163025          │
│   test/multiclassjaccardindex_No    │         0.7648069262504578          │
└─────────────────────────────────────┴─────────────────────────────────────┘

Model : UNet++ with ResNet34
Epoch : 30
Batch : 16
loss : ce
lr : 1e-3
Trainable params: 4.8 M
Non-trainable params: 21.3 M
Total params: 26.1 M
Total estimated model params size (MB): 104
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│      test/Multiclass_Accuracy       │         0.8686501383781433          │
│      test/Multiclass_F1_Score       │         0.8686501383781433          │
│    test/Multiclass_Jaccard_Index    │         0.7505015134811401          │
│ test/Multiclass_Jaccard_Index_Micro │         0.7677997350692749          │
│              test/loss              │         0.3378654420375824          │
│    test/multiclassaccuracy_Cloud    │         0.9641627669334412          │
│     test/multiclassaccuracy_No      │         0.7223877906799316          │
│  test/multiclassjaccardindex_Cloud  │         0.8161967992782593          │
│   test/multiclassjaccardindex_No    │          0.684806227684021          │
└─────────────────────────────────────┴─────────────────────────────────────┘

Model : DeepLabV3 with ResNet34
Epoch : 30
Batch : 16
loss : ce
lr : 1e-3
Trainable params: 4.7 M                                                                                            
Non-trainable params: 21.3 M                                                                                       
Total params: 26.0 M                                                                                               
Total estimated model params size (MB): 104 
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│      test/Multiclass_Accuracy       │         0.8701466917991638          │
│      test/Multiclass_F1_Score       │         0.8701466917991638          │
│    test/Multiclass_Jaccard_Index    │         0.7591824531555176          │
│ test/Multiclass_Jaccard_Index_Micro │         0.7701413631439209          │
│              test/loss              │         0.34615635871887207         │
│    test/multiclassaccuracy_Cloud    │         0.9183960556983948          │
│     test/multiclassaccuracy_No      │         0.7962605953216553          │
│  test/multiclassjaccardindex_Cloud  │         0.8105546832084656          │
│   test/multiclassjaccardindex_No    │         0.7078101634979248          │
└─────────────────────────────────────┴─────────────────────────────────────┘

Model : DeepLabV3+ with ResNet34
Epoch : 30
Batch : 16
loss : ce
lr : 1e-3
Trainable params: 1.2 M                                                                                            
Non-trainable params: 21.3 M                                                                                       
Total params: 22.4 M                                                                                               
Total estimated model params size (MB): 89  
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃             Test metric             ┃            DataLoader 0             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│      test/Multiclass_Accuracy       │         0.8643276691436768          │
│      test/Multiclass_F1_Score       │         0.8643276691436768          │
│    test/Multiclass_Jaccard_Index    │         0.7454323768615723          │
│ test/Multiclass_Jaccard_Index_Micro │         0.7610713243484497          │
│              test/loss              │         0.3959507942199707          │
│    test/multiclassaccuracy_Cloud    │         0.9470270276069641          │
│     test/multiclassaccuracy_No      │         0.7376868724822998          │
│  test/multiclassjaccardindex_Cloud  │         0.8085288405418396          │
│   test/multiclassjaccardindex_No    │         0.6823359727859497          │
└─────────────────────────────────────┴─────────────────────────────────────┘
"""

# CLI tool

You find an example for SMP models in `configs/burnscars_smp.yaml` that you can run with `terratorch fit -c configs/burnscars_smp.yaml`.