In [1]:
import lightning
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
from torchmetrics import MetricCollection
from torchmetrics.classification import Accuracy, FBetaScore, JaccardIndex
from torchmetrics.wrappers import ClasswiseWrapper

from torchgeo.trainers import SemanticSegmentationTask
from torchgeo.datasets.utils import RGBBandsMissingError, unbind_samples

from prithvi_pytorch import PrithviEncoderDecoder
from prithvi_pytorch.datasets import HLSBurnScarsDataModule

# Prithvi-100M checkpoint and its YAML config. Both are passed to
# PrithviEncoderDecoder in PrithviSegmentationTask.configure_models.
# The paths are relative, so they resolve against the kernel's working directory.
CKPT_PATH = "weights/Prithvi_100M.pt"
CFG_PATH = "weights/Prithvi_100M_config.yaml"

In [2]:
class FocalJaccardLoss(nn.Module):
    """Combined segmentation loss: multiclass focal loss plus multiclass Jaccard loss.

    The two terms are summed with equal weight.
    """

    def __init__(self):
        """Instantiate both loss terms from ``segmentation_models_pytorch``."""
        super().__init__()
        losses = smp.losses
        # NOTE: the attribute is called ``ce_loss`` but holds a focal loss.
        self.ce_loss = losses.FocalLoss(mode="multiclass", normalized=True)
        self.jaccard_loss = losses.JaccardLoss(mode="multiclass")

    def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Evaluate both terms on the same inputs and return their sum.

        Args:
            preds: model output, forwarded unchanged to both loss terms.
            targets: ground-truth mask, forwarded unchanged to both loss terms.

        Returns:
            ``ce_loss(preds, targets) + jaccard_loss(preds, targets)``.
        """
        focal_term = self.ce_loss(preds, targets)
        jaccard_term = self.jaccard_loss(preds, targets)
        return focal_term + jaccard_term


class PrithviSegmentationTask(SemanticSegmentationTask):
    """Semantic segmentation task built around a ``PrithviEncoderDecoder``.

    Overrides the loss, metric and model configuration hooks of
    ``SemanticSegmentationTask`` as well as the train/val/test steps.

    Metrics are accumulated with ``update()`` in every step and are computed,
    logged and reset once per epoch in the ``on_*_epoch_end`` hooks. Each
    logged value therefore describes exactly one epoch of one stage.
    """

    def configure_losses(self):
        """Use the combined focal + Jaccard loss as the criterion."""
        self.criterion = FocalJaccardLoss()

    def configure_metrics(self):
        """Create the train/val/test metric collections.

        Each collection holds accuracy, F1 and IoU in three flavours:
        micro-averaged ("Overall*"), macro-averaged ("Average*") and per class
        (wrapped in ``ClasswiseWrapper``). ``ignore_index`` from the
        hyperparameters is applied to every metric so that they all score the
        same set of pixels.
        """
        num_classes = self.hparams["num_classes"]
        ignore_index = self.hparams["ignore_index"]

        def accuracy(average):
            return Accuracy(
                task="multiclass",
                num_classes=num_classes,
                ignore_index=ignore_index,
                average=average,
                multidim_average="global",
            )

        def f1_score(average):
            return FBetaScore(
                task="multiclass",
                num_classes=num_classes,
                beta=1.0,
                ignore_index=ignore_index,
                average=average,
                multidim_average="global",
            )

        def iou(average):
            return JaccardIndex(
                task="multiclass",
                num_classes=num_classes,
                ignore_index=ignore_index,
                average=average,
            )

        self.train_metrics = MetricCollection(
            {
                "OverallAccuracy": accuracy("micro"),
                "OverallF1Score": f1_score("micro"),
                "OverallIoU": iou("micro"),
                "AverageAccuracy": accuracy("macro"),
                "AverageF1Score": f1_score("macro"),
                "AverageIoU": iou("macro"),
                "Accuracy": ClasswiseWrapper(accuracy("none")),
                "F1Score": ClasswiseWrapper(f1_score("none")),
                "IoU": ClasswiseWrapper(iou("none")),
            },
            prefix="train_",
        )
        self.val_metrics = self.train_metrics.clone(prefix="val_")
        self.test_metrics = self.train_metrics.clone(prefix="test_")

    def configure_models(self):
        """Build the Prithvi encoder-decoder from the checkpoint and config files."""
        self.model = PrithviEncoderDecoder(
            num_classes=self.hparams["num_classes"],
            cfg_path=CFG_PATH,
            ckpt_path=CKPT_PATH,
            in_chans=self.hparams["in_channels"],
            img_size=224,
            freeze_encoder=False,
            num_neck_filters=32,
        )

    def _shared_step(self, batch, stage, metrics):
        """Run the forward pass, log ``{stage}_loss`` and accumulate metrics.

        Args:
            batch: dict with an ``"image"`` and a ``"mask"`` entry.
            stage: prefix of the logged loss (``"train"``, ``"val"`` or ``"test"``).
            metrics: the ``MetricCollection`` belonging to ``stage``.

        Returns:
            Tuple ``(loss, y_hat_hard)`` where ``y_hat_hard`` is the argmax of
            the model output over dimension 1.
        """
        x = batch["image"]
        y = batch["mask"]
        y_hat = self(x)
        y_hat_hard = y_hat.argmax(dim=1)
        loss = self.criterion(y_hat, y)
        self.log(f"{stage}_loss", loss)
        # Only accumulate here. Computing every step without a reset would log
        # a running aggregate over all earlier steps and epochs.
        metrics.update(y_hat_hard, y)
        return loss, y_hat_hard

    def _log_epoch_metrics(self, metrics):
        """Compute and log the accumulated metrics, then clear their state."""
        self.log_dict(metrics.compute())
        metrics.reset()

    def _log_prediction_figure(self, batch, batch_idx, y_hat_hard):
        """Plot the first sample of ``batch`` and send the figure to the logger.

        Does nothing unless the trainer's datamodule has a ``plot`` method and
        the logger's experiment has ``add_figure``. Note that ``batch`` is
        modified in place: a ``"prediction"`` entry is added and the image,
        mask and prediction are moved to the CPU.
        """
        can_plot = (
            hasattr(self.trainer, "datamodule")
            and hasattr(self.trainer.datamodule, "plot")
            and self.logger
            and hasattr(self.logger, "experiment")
            and hasattr(self.logger.experiment, "add_figure")
        )
        if not can_plot:
            return

        datamodule = self.trainer.datamodule
        batch["prediction"] = y_hat_hard
        for key in ["image", "mask", "prediction"]:
            batch[key] = batch[key].cpu()
        sample = unbind_samples(batch)[0]

        try:
            fig = datamodule.plot(sample)
        except RGBBandsMissingError:
            # Plotting is best effort: skip samples that cannot be rendered.
            return

        if fig is not None:
            summary_writer = self.logger.experiment
            summary_writer.add_figure(
                f"image/{batch_idx}", fig, global_step=self.global_step
            )
            # Close this specific figure rather than whichever one is current.
            plt.close(fig)

    def training_step(self, batch, batch_idx, dataloader_idx=0):
        """Compute and return the training loss for one batch."""
        loss, _ = self._shared_step(batch, "train", self.train_metrics)
        return loss

    def on_train_epoch_end(self):
        """Log the training metrics of the finished epoch and reset them."""
        self._log_epoch_metrics(self.train_metrics)

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        """Log the validation loss and plot a sample for the first 10 batches."""
        _, y_hat_hard = self._shared_step(batch, "val", self.val_metrics)

        # Plot samples
        if batch_idx < 10:
            self._log_prediction_figure(batch, batch_idx, y_hat_hard)

    def on_validation_epoch_end(self):
        """Log the validation metrics of the finished epoch and reset them."""
        self._log_epoch_metrics(self.val_metrics)

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        """Log the test loss and accumulate the test metrics for one batch."""
        self._shared_step(batch, "test", self.test_metrics)

    def on_test_epoch_end(self):
        """Log the test metrics of the finished run and reset them."""
        self._log_epoch_metrics(self.test_metrics)

In [3]:
# Seed the Python, NumPy and PyTorch RNGs (and dataloader workers) before the
# model and data loaders are created, so a Restart & Run All is repeatable.
SEED = 42
lightning.seed_everything(SEED, workers=True)

# Task with 6 input channels and 2 output classes.
module = PrithviSegmentationTask(in_channels=6, num_classes=2, lr=1e-4, patience=10)

# Data is read from a path relative to the kernel's working directory.
datamodule = HLSBurnScarsDataModule(
    root="data/hls_burn_scars",
    batch_size=8,
    num_workers=8,
)

In [4]:
# Single-GPU trainer (device index 0) with mixed 16-bit precision, the default
# logger and a budget of 20 epochs.
trainer = lightning.Trainer(
    accelerator="gpu",
    devices=[0],
    logger=True,
    max_epochs=20,
    precision="16-mixed",
)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [5]:
# Fit the segmentation task on the datamodule using the trainer configured above.
trainer.fit(model=module, datamodule=datamodule)

Missing logger folder: /workspace/storage/github/prithvi-pytorch/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name          | Type                  | Params
--------------------------------------------------------
0 | criterion     | FocalJaccardLoss      | 0     
1 | train_metrics | MetricCollection      | 0     
2 | val_metrics   | MetricCollection      | 0     
3 | test_metrics  | MetricCollection      | 0     
4 | model         | PrithviEncoderDecoder | 113 M 
--------------------------------------------------------
112 M     Trainable params
252 K     Non-trainable params
113 M     Total params
452.012   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [None]:
# Run the test loop on the datamodule. No model is passed here, so the Trainer
# chooses which weights to evaluate.
# NOTE(review): presumably the best checkpoint from the fit above; confirm, or
# pass model=module / ckpt_path explicitly.
trainer.test(datamodule=datamodule)