In [None]:
# Work from the project directory so the relative paths used below ("data/...") resolve.
%cd crop-type-segmentation/

In [None]:
from pathlib import Path
import math
import re

from matplotlib import pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from shapely.geometry import Polygon
from rastervision.core.data import (
    RasterioSource,
    MinMaxTransformer,
    TemporalMultiRasterSource,
    Scene,
    SemanticSegmentationLabelSource,
    ClassConfig,
    NanTransformer,
    ReclassTransformer,
)
from rastervision.pytorch_learner import SemanticSegmentationSlidingWindowGeoDataset
from terratorch.models import PrithviModelFactory
from terratorch.datasets import HLSBands
from torchmetrics import JaccardIndex
from sklearn.metrics import precision_recall_fscore_support
import wandb

from cropland_data_layer_class_table import class_info

In [None]:
class PrithviSemanticSegmentation(pl.LightningModule):
    """Semantic segmentation of multi-temporal imagery with a Prithvi backbone.

    Builds a TerraTorch ``prithvi_vit_100`` encoder with an ``FCNDecoder`` head,
    freezes the encoder, and trains the remaining parameters with Adam on
    pixel-wise cross-entropy. Loss, IoU and precision/recall/F1 are logged under
    ``train/...`` and ``val/...``.

    Args:
        num_classes: Number of segmentation classes (size of the logit dimension).
        in_channels: Number of spectral bands per frame.
        num_frames: Number of time steps per sample.
        decoder_num_convs: Number of conv layers in the FCN decoder.
        img_size: Input tile size in pixels, passed to the model factory.
        learning_rate: Adam learning rate.
    """

    def __init__(
        self,
        num_classes,
        in_channels,
        num_frames,
        decoder_num_convs,
        img_size,
        learning_rate,
    ):
        super().__init__()
        # Store constructor args in checkpoints (so `load_from_checkpoint` works
        # without re-passing them) and send them to the logger.
        self.save_hyperparameters()
        model_factory = PrithviModelFactory()
        self.model = model_factory.build_model(
            task="segmentation",
            backbone="prithvi_vit_100",
            decoder="FCNDecoder",
            decoder_num_convs=decoder_num_convs,
            in_channels=in_channels,
            bands=[
                HLSBands.BLUE,
                HLSBands.GREEN,
                HLSBands.RED,
                HLSBands.NIR_NARROW,
                HLSBands.SWIR_1,
                HLSBands.SWIR_2,
            ],
            num_classes=num_classes,
            pretrained=True,
            num_frames=num_frames,
            head_dropout=0.0,
            img_size=img_size,
        )
        self.learning_rate = learning_rate

        # Keep the pretrained encoder fixed; only the decoder/head learn.
        for param in self.model.encoder.parameters():
            param.requires_grad = False

        # One metric instance per stage so train and val accumulations never mix.
        self.jaccard_index = JaccardIndex(task="multiclass", num_classes=num_classes)
        self.val_jaccard_index = JaccardIndex(
            task="multiclass", num_classes=num_classes
        )

    def _shared_step(self, batch, stage, iou_metric, log_loss_on_step):
        """Run the model on one batch, log ``{stage}/...`` metrics, return the loss.

        Args:
            batch: ``(x, y)`` pair; presumably x is (B, C, T, H, W) as produced by
                ``custom_collate_fn`` and y holds integer class ids - TODO confirm.
            stage: Metric-name prefix ("train" or "val").
            iou_metric: The stage's own ``JaccardIndex`` instance.
            log_loss_on_step: Whether the loss is also logged per step.
        """
        x, y = batch
        logits = self.model(x).output
        loss = F.cross_entropy(logits, y)
        self.log(
            f"{stage}/loss",
            loss,
            prog_bar=True,
            on_step=log_loss_on_step,
            on_epoch=True,
        )

        pred = torch.argmax(logits, dim=1)
        # Accumulate into the metric and log the metric object itself, so that
        # Lightning computes IoU over the whole epoch and then resets it,
        # instead of averaging per-batch IoU values.
        iou_metric.update(pred, y)
        self.log(f"{stage}/iou", iou_metric, on_step=False, on_epoch=True)

        self._log_precision_recall_f1(stage, y, pred)
        return loss

    def _log_precision_recall_f1(self, stage, y, pred):
        """Log batch-level macro and weighted precision/recall/F1 (scikit-learn).

        The scores are computed per batch on CPU and averaged over the epoch by
        Lightning, so they approximate, rather than equal, epoch-level scores.
        ``zero_division=0`` reports the same 0 that scikit-learn's default
        uses for ill-defined classes, but without emitting a warning every batch.
        """
        y_flat = y.flatten().cpu().numpy()
        pred_flat = pred.flatten().cpu().numpy()
        for average in ("macro", "weighted"):
            precision, recall, f1, _ = precision_recall_fscore_support(
                y_flat, pred_flat, average=average, zero_division=0
            )
            self.log_dict(
                {
                    f"{stage}/precision_{average}": precision,
                    f"{stage}/recall_{average}": recall,
                    f"{stage}/f1_{average}": f1,
                },
                on_step=False,
                on_epoch=True,
            )

    def training_step(self, batch, batch_idx):
        """Compute and log training loss/metrics; return the loss to optimize."""
        return self._shared_step(
            batch, "train", self.jaccard_index, log_loss_on_step=True
        )

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss/metrics (epoch-level only)."""
        return self._shared_step(
            batch, "val", self.val_jaccard_index, log_loss_on_step=False
        )

    def configure_optimizers(self):
        """Adam over all module parameters.

        Frozen encoder parameters never receive gradients, so Adam leaves them
        unchanged. Passing all parameters keeps the optimizer state layout
        compatible with existing checkpoints.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [None]:
# Modify the following configuration as needed
config = {}
# "Other" (0) + 15 frequent crops (1-15) + "Developed" (16); see the class mapping below.
config["num_classes"] = 17
config["months"] = [2, 4, 5, 6, 7, 8, 9]
# One frame per monthly composite; derived so it cannot drift from `months`.
config["num_frames"] = len(config["months"])
config["img_size"] = 224
config["learning_rate"] = 0.001
config["decoder_num_convs"] = 1
config["channels"] = [0, 1, 2, 3, 4, 5]
# One model input channel per selected raster band.
config["in_channels"] = len(config["channels"])
config["batch_size"] = 5
config["wandb_project"] = "test"
config["wandb_name"] = "test"
config["max_epochs"] = 1
config["num_workers"] = 4 if torch.cuda.is_available() else 0
config["ckpt_path"] = None
config["seed"] = 42

# Seed Python/NumPy/torch (and DataLoader workers) so shuffling and decoder
# initialization are reproducible across runs.
pl.seed_everything(config["seed"], workers=True)

### Load the data

In [None]:
data_dir = Path("data")
months = config["months"]
# Monthly composites are named with a zero-padded month, e.g.
# "Landsat9_Composite_2022_04.tiff". Match exact file names rather than a
# one-digit character-class regex, which mis-handles months 10-12.
wanted_names = {f"Landsat9_Composite_2022_{month:02d}.tiff" for month in months}
# Zero-padded names make lexical order chronological.
l9_images = sorted(
    img
    for img in data_dir.glob("Landsat9_Composite_2022_*.tiff")
    if img.name in wanted_names
)
# Fail fast: every configured month must have an image, otherwise the number of
# frames would not match config["num_frames"].
missing_images = sorted(wanted_names - {img.name for img in l9_images})
assert not missing_images, f"Missing Landsat 9 composites: {missing_images}"

In [None]:
# Show the selected monthly composites (sorted by file name).
l9_images

In [None]:
# Display color and human-readable name for every entry of the CDL class table,
# kept in class_info order.
colors, names = [], []
for entry in class_info:
    colors.append(entry["Color"])
    names.append(entry["Description"])

In [None]:
# Map CDL class IDs to compact training labels:
# - crops that cover more than 1% of pixels keep their own label (1-15),
# - all developed-area classes are merged into one label (16),
# - every other class becomes "Other" (0).
most_frequent_crops = {
    3: 1,
    6: 2,
    24: 3,
    36: 4,
    37: 5,
    54: 6,
    61: 7,
    75: 8,
    76: 9,
    111: 10,
    142: 11,
    152: 12,
    176: 13,
    195: 14,
    220: 15,
}
developed_classes = [82, 121, 122, 123, 124]
other_label = 0
developed_label = 16

mapping = {}
for item in class_info:
    value = int(item["Value"])
    if value in most_frequent_crops:
        mapping[value] = most_frequent_crops[value]
    elif value in developed_classes:
        mapping[value] = developed_label
    else:
        mapping[value] = other_label

# ReclassTransformer only rewrites values listed in `mapping`. Any raw code
# missing from class_info would otherwise reach the loss as an out-of-range
# label, so send every remaining 8-bit code to "Other" as well.
# (Assumes the CDL raster is 8-bit - TODO confirm.)
for value in range(256):
    mapping.setdefault(value, other_label)

assert set(mapping.values()) <= set(range(config["num_classes"])), (
    "class mapping produces labels outside [0, num_classes)"
)

In [None]:
# NOTE(review): `names`/`colors` list every CDL class, while the remapped labels
# only span 0..num_classes-1. Presumably Raster Vision takes the null-class id
# from the position of "Other" in `names`; confirm that it matches label 0.
class_config = ClassConfig(names=names, colors=colors, null_class="Other")

# Labels come from the CDL raster, remapped on read via `mapping`.
cdl_raster = RasterioSource(
    uris="data/Cropland_Data_Layer_2022.tiff",
    raster_transformers=[ReclassTransformer(mapping)],
)
label_source = SemanticSegmentationLabelSource(
    raster_source=cdl_raster, class_config=class_config
)

In [None]:
# One RasterioSource per monthly composite, reading the configured bands with a
# NanTransformer (NaN -> 0) followed by a MinMaxTransformer.
raster_sources = [
    RasterioSource(
        str(image_path),
        channel_order=config["channels"],
        raster_transformers=[NanTransformer(to_value=0), MinMaxTransformer()],
    )
    for image_path in l9_images
]

In [None]:
# Combine the per-month sources (in `l9_images` order) into one temporal raster source.
time_series = TemporalMultiRasterSource(raster_sources)

In [None]:
# Extent of the first composite as a dict (its "xmax"/"ymax" keys are used for
# the train/val split). NOTE(review): assumes all composites share this extent.
extent = raster_sources[0].bbox.extent.to_dict()

In [None]:
# Split the raster by row: rows up to `train_percent` of the height form the
# training AOI and the remaining rows form the validation AOI.
train_percent = 0.7
split_y = extent["ymax"] * train_percent
scene_width = extent["xmax"]
train_aoi = Polygon.from_bounds(xmin=0, ymin=0, xmax=scene_width, ymax=int(split_y))
val_aoi = Polygon.from_bounds(
    xmin=0, ymin=math.ceil(split_y), xmax=scene_width, ymax=extent["ymax"]
)

In [None]:
# Both scenes read the same imagery and labels; only the AOI differs.
train_scene, val_scene = (
    Scene(
        id=scene_id,
        raster_source=time_series,
        label_source=label_source,
        aoi_polygons=[aoi],
    )
    for scene_id, aoi in (("train", train_aoi), ("val", val_aoi))
)

In [None]:
# Non-overlapping tiles (stride equals window size) with padding=0.
window_kwargs = dict(size=config["img_size"], stride=config["img_size"], padding=0)
train_dataset = SemanticSegmentationSlidingWindowGeoDataset(train_scene, **window_kwargs)
val_dataset = SemanticSegmentationSlidingWindowGeoDataset(val_scene, **window_kwargs)

In [None]:
# Number of tiles in each split.
for split_name, split_dataset in (("train", train_dataset), ("val", val_dataset)):
    print(f"len {split_name} dataset: {len(split_dataset)}")

In [None]:
def custom_collate_fn(batch):
    """Collate Raster Vision samples into a batch laid out for Prithvi.

    Stacked image chips arrive as (B, T, C, H, W); Prithvi expects
    (B, C, T, H, W), so the time and channel axes are swapped. Targets are
    stacked when they are tensors and otherwise converted with ``torch.tensor``.
    """
    chips, labels = zip(*batch)
    # Swap time and channel axes: (B, T, C, H, W) -> (B, C, T, H, W).
    images = torch.stack(chips).permute(0, 2, 1, 3, 4)
    if not isinstance(labels[0], torch.Tensor):
        return images, torch.tensor(labels)
    return images, torch.stack(labels)

In [None]:
# Both loaders share batch size, worker count and collation; only training shuffles.
loader_kwargs = dict(
    batch_size=config["batch_size"],
    num_workers=config["num_workers"],
    collate_fn=custom_collate_fn,
)
train_dl = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_dl = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

### Visualize a batch

In [None]:
# Pull one training batch to inspect its shapes.
x, y = next(iter(train_dl))
for tag, batch_tensor in (("x", x), ("y", y)):
    print(f"{tag} shape: {batch_tensor.shape}")

In [None]:
# Show the first time step of each sample as an RGB composite above its label
# mask. Channels [2, 1, 0] give R, G, B assuming the raster bands follow the
# BLUE, GREEN, RED, ... order declared for the model.
frame_idx = 0
images = x[:, [2, 1, 0], frame_idx, :, :]

batch_size = images.shape[0]

# squeeze=False keeps `axes` 2-D even when the batch holds a single sample.
fig, axes = plt.subplots(
    2, batch_size, figsize=(3 * batch_size, 6), squeeze=False
)

for i in range(batch_size):
    img = images[i].permute(1, 2, 0).numpy()
    axes[0, i].imshow(img)
    axes[0, i].axis("off")
    axes[0, i].set_title(f"Image {i + 1}")

    mask = y[i].numpy()
    # Labels run from 0 to num_classes - 1; spanning that full range keeps the
    # highest classes from sharing one color.
    axes[1, i].imshow(mask, cmap="tab20", vmin=0, vmax=config["num_classes"] - 1)
    axes[1, i].axis("off")
    axes[1, i].set_title(f"Mask {i + 1}")

plt.tight_layout()
plt.show()

### Load and train the model

In [None]:
# Every model hyperparameter is taken directly from `config`.
model_kwargs = {
    key: config[key]
    for key in (
        "num_classes",
        "in_channels",
        "num_frames",
        "decoder_num_convs",
        "img_size",
        "learning_rate",
    )
}
model = PrithviSemanticSegmentation(**model_kwargs)

In [None]:
# Pass `config` at init time (forwarded to wandb.init) so the run's settings are
# recorded even if training fails before the final cell runs.
wandb_logger = WandbLogger(
    name=config["wandb_name"],
    save_dir="wandb-logs",
    project=config["wandb_project"],
    config=config,
)

In [None]:
# For a quick smoke test, uncomment one of the options below
# (limit_val_batches is the validation counterpart during `fit`).
trainer = pl.Trainer(
    max_epochs=config["max_epochs"],
    logger=wandb_logger,
    log_every_n_steps=1,
    # fast_dev_run=1,
    # limit_train_batches=1,
    # limit_val_batches=1,
)

In [None]:
# Upload the configuration before training so it is attached to the run even
# if training fails partway through.
wandb_logger.experiment.config.update(config)

# Always close the W&B run, marking it failed (exit code 1) if fit raises.
exit_code = 1
try:
    trainer.fit(
        model,
        train_dataloaders=train_dl,
        val_dataloaders=val_dl,
        ckpt_path=config["ckpt_path"],
    )
    exit_code = 0
finally:
    wandb.finish(exit_code=exit_code)