Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

# Custom Trainers

In this tutorial, we demonstrate how to extend a TorchGeo ["trainer class"](https://torchgeo.readthedocs.io/en/latest/api/trainers.html). TorchGeo provides several trainer classes: pre-made PyTorch Lightning Modules that make it easy to train models for tasks such as semantic segmentation, classification, and change detection using TorchGeo's [prebuilt DataModules](https://torchgeo.readthedocs.io/en/latest/api/datamodules.html). While the trainers aim to provide sensible defaults and customization options for common tasks, they cannot cover every situation (e.g. researchers will likely want to implement and use their own architectures, loss functions, optimizers, etc. in the training routine). If you run into such a situation, you can simply extend the trainer class you are interested in and write custom logic to override the default functionality.

This tutorial shows how to do exactly this to customize a learning rate schedule, logging, and model checkpointing for a semantic segmentation task using the [LandCoverAI](https://landcover.ai.linuxpolska.com/) dataset.

It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the "Open in Colab" button above to get started.

## Setup

As always, we install TorchGeo.

In [None]:
%pip install torchgeo

## Imports

Next, we import TorchGeo and any other libraries we need.

In [2]:
from torchgeo.trainers import SemanticSegmentationTask
from torchgeo.datamodules import LandCoverAIDataModule
from torchmetrics import MetricCollection
from torchmetrics.classification import (
    Accuracy,
    FBetaScore,
    Precision,
    Recall,
    JaccardIndex,
)

import lightning.pytorch as pl
from lightning.pytorch.callbacks import ModelCheckpoint
import torch

from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import AdamW

# Get rid of the pesky warnings raised by kornia
# UserWarning: Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.
import warnings

warnings.filterwarnings("ignore", category=UserWarning, module="torch.nn.functional")

## Custom SemanticSegmentationTask

Now, we create a `CustomSemanticSegmentationTask` class that inherits from `SemanticSegmentationTask` and overrides a few methods:
- `__init__`: We add two new parameters `tmax` and `eta_min` to control the learning rate scheduler
- `configure_optimizers`: We use the `CosineAnnealingLR` learning rate scheduler instead of the default `ReduceLROnPlateau`
- `configure_metrics`: We add a "MeanIoU" metric (which we will use to evaluate the model's performance) and a variety of other classification metrics
- `configure_callbacks`: We demonstrate how to stack `ModelCheckpoint` callbacks to save the best checkpoint as well as periodic checkpoints
- `on_train_epoch_start`: We log the learning rate at the start of each epoch so we can easily see how it decays over a training run

Overall, these demonstrate how to customize the training routine to investigate specific research questions (e.g. the effect of the scheduler on test performance).
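
To make the roles of `tmax` and `eta_min` concrete, here is a minimal pure-Python sketch of the closed-form cosine annealing expression given in the PyTorch documentation for `CosineAnnealingLR`. The helper name `cosine_annealing_lr` is hypothetical and exists only for illustration; the numbers mirror the settings used later in this tutorial (`lr=1e-3`, `tmax=50`, `eta_min=1e-6`).

```python
import math


def cosine_annealing_lr(epoch: int, base_lr: float, t_max: int, eta_min: float) -> float:
    """Closed-form cosine annealing: decays from base_lr to eta_min over t_max epochs."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


# Evaluate the schedule at a few epochs using the tutorial's settings
for epoch in (0, 25, 50, 100):
    lr = cosine_annealing_lr(epoch, base_lr=1e-3, t_max=50, eta_min=1e-6)
    print(f"epoch {epoch:3d}: lr = {lr:.6e}")
```

Under this closed form, the rate reaches `eta_min` at epoch `tmax` and then climbs back toward the base rate by epoch `2 * tmax`. That is worth keeping in mind when a training run lasts longer than `tmax` epochs, as the run configured below (at least 150 epochs) does.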

In [3]:
class CustomSemanticSegmentationTask(SemanticSegmentationTask):

    # any keywords we add here between *args and **kwargs will be found in self.hparams
    def __init__(self, *args, tmax=50, eta_min=1e-6, **kwargs) -> None:
        super().__init__(*args, **kwargs)  # pass args and kwargs to the parent class

    def configure_optimizers(
        self,
    ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
        """Initialize the optimizer and learning rate scheduler.

        Returns:
            Optimizer and learning rate scheduler.
        """
        tmax: int = self.hparams["tmax"]
        eta_min: float = self.hparams["eta_min"]

        optimizer = AdamW(self.parameters(), lr=self.hparams["lr"])
        scheduler = CosineAnnealingLR(optimizer, T_max=tmax, eta_min=eta_min)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": self.monitor},
        }

    def configure_metrics(self) -> None:
        """Initialize the performance metrics."""
        num_classes: int = self.hparams["num_classes"]

        self.train_metrics = MetricCollection(
            {
                "OverallAccuracy": Accuracy(
                    task="multiclass", num_classes=num_classes, average="micro"
                ),
                "OverallPrecision": Precision(
                    task="multiclass", num_classes=num_classes, average="micro"
                ),
                "OverallRecall": Recall(
                    task="multiclass", num_classes=num_classes, average="micro"
                ),
                "OverallF1Score": FBetaScore(
                    task="multiclass",
                    num_classes=num_classes,
                    beta=1.0,
                    average="micro",
                ),
                "MeanIoU": JaccardIndex(
                    num_classes=num_classes, task="multiclass", average="macro"
                ),
            },
            prefix="train_",
        )
        self.val_metrics = self.train_metrics.clone(prefix="val_")
        self.test_metrics = self.train_metrics.clone(prefix="test_")

    def configure_callbacks(self):
        """Initialize callbacks for saving the best and latest models.

        Returns:
            List of callbacks to apply.
        """
        return [
            ModelCheckpoint(every_n_epochs=50, save_top_k=-1),
            ModelCheckpoint(monitor=self.monitor, mode=self.mode, save_top_k=5),
        ]

    def on_train_epoch_start(self) -> None:
        """Log the learning rate at the start of each training epoch."""
        lr = self.optimizers().param_groups[0]["lr"]
        # NOTE: add_scalar assumes a TensorBoard logger (its experiment is a SummaryWriter)
        self.logger.experiment.add_scalar("lr", lr, self.current_epoch)
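
The metric collection above mixes `average="micro"` (the "Overall" metrics) with `average="macro"` (MeanIoU). The following pure-Python sketch shows the difference on a toy example. The labels are made up for illustration and the calculation is a hand-rolled approximation of the idea, not the torchmetrics implementation.

```python
from collections import Counter

NUM_CLASSES = 3
# Toy per-pixel labels (hypothetical data, not taken from LandCover.ai)
targets = [0, 0, 0, 0, 1, 1, 2]
preds = [0, 0, 0, 1, 1, 1, 2]

# Micro average: pool every pixel, then compute a single score
micro_accuracy = sum(t == p for t, p in zip(targets, preds)) / len(targets)

# Macro average: compute IoU = TP / (TP + FP + FN) per class, then take the unweighted mean.
# Every class appears in this toy data, so no denominator is zero.
pairs = Counter(zip(targets, preds))
ious = []
for c in range(NUM_CLASSES):
    tp = pairs[(c, c)]
    fp = sum(n for (t, p), n in pairs.items() if p == c and t != c)
    fn = sum(n for (t, p), n in pairs.items() if t == c and p != c)
    ious.append(tp / (tp + fp + fn))
mean_iou = sum(ious) / NUM_CLASSES

print(f"micro accuracy: {micro_accuracy:.4f}")
print(f"per-class IoU:  {[round(v, 4) for v in ious]}")
print(f"macro mean IoU: {mean_iou:.4f}")
```

Micro averaging is dominated by the most frequent classes, while macro averaging weights every class equally, which is one reason to track MeanIoU on imbalanced land cover data. Note also that for single-label multiclass predictions, micro-averaged precision, recall, and F1 coincide with micro accuracy, so the four "Overall" entries should report the same value.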

## Train model

The remainder of the tutorial is straightforward and follows the typical [PyTorch Lightning](https://lightning.ai/) training routine. We instantiate a `DataModule` for the LandCover.AI dataset, instantiate a `CustomSemanticSegmentationTask` with a U-Net model and a ResNet-50 backbone, then train the model using a Lightning trainer.

In [4]:
dm = LandCoverAIDataModule(
    root="data/",
    batch_size=64,
    num_workers=8,
    download=True,
)

In [5]:
task = CustomSemanticSegmentationTask(
    model="unet",
    backbone="resnet50",
    weights=True,
    in_channels=3,
    num_classes=6,
    loss="ce",
    lr=1e-3,
    tmax=50,
)

In [6]:
# validate that the task's hyperparameters are as expected
task.hparams

"backbone":        resnet50
"class_weights":   None
"eta_min":         1e-06
"freeze_backbone": False
"freeze_decoder":  False
"ignore":          weights
"ignore_index":    None
"in_channels":     3
"loss":            ce
"lr":              0.001
"model":           unet
"num_classes":     6
"num_filters":     3
"patience":        10
"tmax":            50

In [7]:
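accelerator = "gpu" if torch.cuda.is_available() else "cpu"
# Use the first GPU if available; the CPU accelerator expects an integer device count
devices = [0] if accelerator == "gpu" else 1

trainer = pl.Trainer(
    accelerator=accelerator, devices=devices, min_epochs=150, max_epochs=300, log_every_n_steps=50
)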

GPU available: True (cuda), used: True


TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [8]:
trainer.fit(task, dm)

The following callbacks returned in `LightningModule.configure_callbacks` will override existing callbacks passed to Trainer: ModelCheckpoint


Downloading https://landcover.ai.linuxpolska.com/download/landcover.ai.v1.zip to data/landcover.ai.v1.zip


100%|██████████| 1538212277/1538212277 [01:25<00:00, 17913845.14it/s]


Processed M-33-20-D-c-4-2 1/41
Processed M-33-20-D-d-3-3 2/41
Processed M-33-32-B-b-4-4 3/41
Processed M-33-48-A-c-4-4 4/41
Processed M-33-7-A-d-2-3 5/41
Processed M-33-7-A-d-3-2 6/41
Processed M-34-32-B-a-4-3 7/41
Processed M-34-32-B-b-1-3 8/41
Processed M-34-5-D-d-4-2 9/41
Processed M-34-51-C-b-2-1 10/41
Processed M-34-51-C-d-4-1 11/41
Processed M-34-55-B-b-4-1 12/41
Processed M-34-56-A-b-1-4 13/41
Processed M-34-6-A-d-2-2 14/41
Processed M-34-65-D-a-4-4 15/41
Processed M-34-65-D-c-4-2 16/41
Processed M-34-65-D-d-4-1 17/41
Processed M-34-68-B-a-1-3 18/41
Processed M-34-77-B-c-2-3 19/41
Processed N-33-104-A-c-1-1 20/41
Processed N-33-119-C-c-3-3 21/41
Processed N-33-130-A-d-3-3 22/41
Processed N-33-130-A-d-4-4 23/41
Processed N-33-139-C-d-2-2 24/41
Processed N-33-139-C-d-2-4 25/41
Processed N-33-139-D-c-1-3 26/41
Processed N-33-60-D-c-4-2 27/41
Processed N-33-60-D-d-1-2 28/41
Processed N-33-96-D-d-1-1 29/41
Processed N-34-106-A-b-3-4 30/41
Processed N-34-106-A-c-1-3 31/41
Processed N-

You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


Processed N-34-97-D-c-2-4 41/41


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name          | Type             | Params
---------------------------------------------------
0 | criterion     | CrossEntropyLoss | 0     
1 | train_metrics | MetricCollection | 0     
2 | val_metrics   | MetricCollection | 0     
3 | test_metrics  | MetricCollection | 0     
4 | model         | Unet             | 32.5 M
---------------------------------------------------
32.5 M    Trainable params
0         Non-trainable params
32.5 M    Total params
130.087   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

/home/calebrobinson/.conda/envs/geo/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


## Test model

Finally, we test the model on the test set and visualize the results.

In [9]:
# If you are starting from a checkpoint, run this cell
task = CustomSemanticSegmentationTask.load_from_checkpoint("lightning_logs/version_3/checkpoints/epoch=0-step=117.ckpt")

TypeError: SemanticSegmentationTask.__init__() got an unexpected keyword argument 'ignore'
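
The `TypeError` above is consistent with the `"ignore": weights` entry shown in `task.hparams`. When the checkpoint is restored, the saved hyperparameters appear to be passed back to `__init__` as keyword arguments, and `**kwargs` then forwards `ignore` to a parent constructor that does not accept it. One possible workaround is to drop that key before calling the parent. The sketch below illustrates the pattern with plain-Python stand-ins: `Parent` and `Child` are hypothetical classes, not TorchGeo code, and the fix is an assumption rather than something this tutorial confirms.

```python
class Parent:
    """Hypothetical stand-in for SemanticSegmentationTask: it does not accept `ignore`."""

    def __init__(self, lr: float = 1e-3) -> None:
        self.lr = lr


class Child(Parent):
    """Hypothetical stand-in for CustomSemanticSegmentationTask."""

    def __init__(self, *args, tmax: int = 50, eta_min: float = 1e-6, **kwargs) -> None:
        # Drop the saved "ignore" entry, if present, before forwarding to the parent
        kwargs.pop("ignore", None)
        super().__init__(*args, **kwargs)
        self.tmax = tmax
        self.eta_min = eta_min


# Simulates restoring from saved hyperparameters that include an "ignore" entry
saved_hparams = {"lr": 1e-3, "tmax": 50, "eta_min": 1e-6, "ignore": "weights"}
task_sketch = Child(**saved_hparams)
print(task_sketch.lr, task_sketch.tmax, task_sketch.eta_min)
```

Applied to the custom task, the same idea would mean adding `kwargs.pop("ignore", None)` as the first line of `CustomSemanticSegmentationTask.__init__`, then retraining or re-saving so that the checkpoint loads cleanly.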

In [None]:
trainer.test(task, dm)