In [1]:
from pathlib import Path
import math

from matplotlib import pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger
from shapely.geometry import Polygon
from rastervision.core.data import RasterioSource, MinMaxTransformer, TemporalMultiRasterSource, Scene, SemanticSegmentationLabelSource, ClassConfig, NanTransformer, ReclassTransformer
from rastervision.pytorch_learner import SemanticSegmentationSlidingWindowGeoDataset
from terratorch.models import PrithviModelFactory
from terratorch.datasets import HLSBands

from cropland_data_layer_class_table import class_info

In [2]:
class PrithviSemanticSegmentation(pl.LightningModule):
    """LightningModule for semantic segmentation with a frozen Prithvi encoder.

    Builds a TerraTorch ``prithvi_vit_100`` segmentation model with an FCN
    decoder for six HLS bands, then freezes the encoder so that only the
    decoder/head parameters are optimised. Batches are ``(x, y)`` pairs; ``x``
    is presumably laid out as (B, C, T, H, W) as produced by
    ``custom_collate_fn`` — TODO confirm against the TerraTorch version in use.

    Args:
        num_classes: Number of segmentation classes predicted by the head
            (must match the label reclassing; 16 by default).
        num_frames: Number of time steps per input sample.
        lr: Learning rate for the Adam optimiser.
    """

    def __init__(self, num_classes: int = 16, num_frames: int = 7, lr: float = 0.001):
        super().__init__()
        # Store constructor args in checkpoints so load_from_checkpoint can
        # rebuild the module with the same settings.
        self.save_hyperparameters()
        self.lr = lr
        model_factory = PrithviModelFactory()
        self.model = model_factory.build_model(
            task="segmentation",
            backbone="prithvi_vit_100",
            decoder="FCNDecoder",
            decoder_num_convs=1,
            in_channels=6,
            bands=[
                HLSBands.BLUE,
                HLSBands.GREEN,
                HLSBands.RED,
                HLSBands.NIR_NARROW,
                HLSBands.SWIR_1,
                HLSBands.SWIR_2,
            ],
            num_classes=num_classes,
            pretrained=True,
            num_frames=num_frames,
            head_dropout=0.0,
            img_size=224,
        )

        # Freeze the pretrained encoder; only the decoder and head are trained.
        for param in self.model.encoder.parameters():
            param.requires_grad = False

    def _shared_step(self, batch):
        """Return the pixel-wise cross-entropy loss of the model on ``batch``."""
        x, y = batch
        logits = self.model(x).output
        return F.cross_entropy(logits, y)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        loss = self._shared_step(batch)
        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss (epoch-level only) for one batch."""
        loss = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Adam over the trainable (non-frozen) parameters only."""
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.Adam(trainable_params, lr=self.lr)

### Load the data

In [3]:
# All inputs for this notebook live under one workspace data directory.
data_dir = Path('/workspace/data')
# Monthly Landsat 9 composites for months 02-09 of 2022. The month numbers are
# zero-padded, so sorting the paths puts them in chronological order.
l9_images = sorted(data_dir.glob('Landsat9_Composite_2022_0[2-9].tiff'))
# Fail early with a clear message instead of an IndexError later, when
# raster_sources[0] is used to compute the extent.
assert l9_images, f"No Landsat 9 composites matched in {data_dir}"
# NOTE(review): the pattern can match up to 8 months (02-09), but the model is
# built with num_frames=7 — confirm which months are expected on disk.

In [None]:
# Show which monthly composites matched the glob; they are turned into raster
# sources in this (sorted) order further down.
l9_images

In [5]:
# Display colour and human-readable name for every entry in the CDL class table,
# kept in table order.
colors, names = [], []
for entry in class_info:
    colors.append(entry["Color"])
    names.append(entry["Description"])

In [6]:
# Reclassify Cropland Data Layer (CDL) values into the 16 training classes:
#   * crops that each cover more than 1% of pixels keep their own class (1-14),
#   * all developed-area classes are merged into a single class (15),
#   * every other CDL value becomes "Other" (0).
most_frequent_crops = {1:1, 6:2, 24:3, 28:4, 36:5, 37:6, 54:7, 61:8, 75:9, 76:10, 152:11, 176:12, 211:13, 220:14}
developed_classes = [82, 121, 122, 123, 124]
mapping = {
    cdl_value: most_frequent_crops.get(cdl_value, 15 if cdl_value in developed_classes else 0)
    for cdl_value in (int(entry["Value"]) for entry in class_info)
}

In [7]:
# Raster Vision treats class IDs as indices into ClassConfig.names, so the config
# must describe the 16 *reclassified* IDs produced by `mapping`, not the full CDL
# table. ID 0 is "Other" (the null class), IDs 1-14 are the frequent crops in
# the order given by `most_frequent_crops`, and ID 15 is the merged developed class.
cdl_info_by_value = {int(item["Value"]): item for item in class_info}
crop_values_by_id = sorted(most_frequent_crops, key=most_frequent_crops.get)
reclass_names = (
    ["Other"]
    + [cdl_info_by_value[value]["Description"] for value in crop_values_by_id]
    + ["Developed"]
)
reclass_colors = (
    ["black"]
    + [cdl_info_by_value[value]["Color"] for value in crop_values_by_id]
    + ["gray"]
)
# One name per reclassified ID (0..max ID).
assert len(reclass_names) == max(mapping.values()) + 1, "class names out of sync with mapping"

class_config = ClassConfig(names=reclass_names, colors=reclass_colors, null_class="Other")
label_source = SemanticSegmentationLabelSource(
    raster_source=RasterioSource(
        uris=str(data_dir / 'Cropland_Data_Layer_2022.tiff'),
        raster_transformers=[ReclassTransformer(mapping)],
    ),
    class_config=class_config,
)

In [8]:
# One RasterioSource per monthly composite, in the sorted order of `l9_images`.
# All six bands are read; NanTransformer(to_value=0) is listed before
# MinMaxTransformer in the transformer chain.
raster_sources = [
    RasterioSource(
        str(image_path),
        channel_order=list(range(6)),
        raster_transformers=[NanTransformer(to_value=0), MinMaxTransformer()],
    )
    for image_path in l9_images
]

In [9]:
# Combine the per-month sources into a single temporal raster source; frames
# presumably follow the order of `raster_sources` — confirm in Raster Vision docs.
time_series = TemporalMultiRasterSource(raster_sources)

In [10]:
# Extent of the first monthly composite as a dict; its "ymax"/"xmax" entries
# drive the train/val split below.
extent = raster_sources[0].bbox.extent.to_dict()

In [11]:
# Split the scene by rows: the first 70% of rows are for training, the rest for
# validation. Flooring the train bound and ceiling the val bound keeps the two
# AOIs from overlapping when the split row is fractional.
train_percent = 0.7
split_row = extent["ymax"] * train_percent
full_width = extent["xmax"]
train_aoi = Polygon.from_bounds(xmin=0, ymin=0, xmax=full_width, ymax=int(split_row))
val_aoi = Polygon.from_bounds(xmin=0, ymin=math.ceil(split_row), xmax=full_width, ymax=extent["ymax"])

In [12]:
# Both scenes share the same imagery and labels; only the AOI differs.
train_scene, val_scene = (
    Scene(id=scene_id, raster_source=time_series, label_source=label_source, aoi_polygons=[aoi])
    for scene_id, aoi in (("train", train_aoi), ("val", val_aoi))
)

In [13]:
# Non-overlapping 224x224 chips (stride == size), matching the model's img_size.
chip_kwargs = dict(size=224, stride=224, padding=0)
train_dataset = SemanticSegmentationSlidingWindowGeoDataset(train_scene, **chip_kwargs)
val_dataset = SemanticSegmentationSlidingWindowGeoDataset(val_scene, **chip_kwargs)

In [None]:
# Number of chips in each split.
for split_name, split_ds in (("train", train_dataset), ("val", val_dataset)):
    print(f"len {split_name} dataset: {len(split_ds)}")

In [15]:
def custom_collate_fn(batch):
    """Collate ``(image, target)`` samples into one batch laid out for Prithvi.

    Raster Vision's stacked images are (B, T, C, H, W); the Prithvi model
    expects (B, C, T, H, W), so the time and channel axes are swapped after
    stacking. Tensor targets are stacked; any other targets (e.g. Python
    scalars) are converted with ``torch.tensor``.

    Args:
        batch: Sequence of ``(image, target)`` pairs.

    Returns:
        ``(images, targets)`` tuple of batched tensors.
    """
    images, labels = zip(*batch)
    # (B, T, C, H, W) -> (B, C, T, H, W)
    batched_images = torch.stack(images).permute(0, 2, 1, 3, 4)
    targets_are_tensors = isinstance(labels[0], torch.Tensor)
    batched_labels = torch.stack(labels) if targets_are_tensors else torch.tensor(labels)
    return batched_images, batched_labels

In [16]:
# A dedicated, seeded generator makes the shuffled training order reproducible
# across notebook runs without touching the global RNG.
DATALOADER_SEED = 42
BATCH_SIZE = 5
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=0, collate_fn=custom_collate_fn)
train_dl = DataLoader(
    train_dataset,
    shuffle=True,
    generator=torch.Generator().manual_seed(DATALOADER_SEED),
    **loader_kwargs,
)
val_dl = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

### Visualize a batch

In [None]:
# Pull one training batch and report tensor shapes as a sanity check.
x, y = next(iter(train_dl))
for tensor_name, tensor in (("x", x), ("y", y)):
    print(f"{tensor_name} shape: {tensor.shape}")

In [None]:
# Channels 2, 1, 0 of the first time step, i.e. red/green/blue assuming the
# composites' band order matches the BLUE, GREEN, RED, ... list given to the model.
images = x[:, [2, 1, 0], 0, :, :]

batch_size = images.shape[0]

# squeeze=False keeps `axes` 2-D even for a single-sample batch, so
# axes[row, col] indexing always works.
fig, axes = plt.subplots(2, batch_size, figsize=(3 * batch_size, 6), squeeze=False)

for i in range(batch_size):
    # (C, H, W) -> (H, W, C) for imshow.
    rgb = images[i].permute(1, 2, 0).numpy()
    axes[0, i].imshow(rgb)
    axes[0, i].axis('off')
    axes[0, i].set_title(f'Image {i + 1}')

    # Fixed vmin/vmax (class IDs 0-15) so a class gets the same colour in every panel.
    axes[1, i].imshow(y[i].numpy(), cmap='tab20', vmin=0, vmax=15)
    axes[1, i].axis('off')
    axes[1, i].set_title(f'Mask {i + 1}')

plt.tight_layout()
plt.show()

### Load and train the model

In [19]:
# Seed Python/NumPy/torch RNGs so the randomly initialised decoder head (and any
# later global-RNG use) is reproducible across runs.
pl.seed_everything(42)
model = PrithviSemanticSegmentation()

In [20]:
# Anchor logs under the same workspace root as the data (/workspace/csv-logs).
# The previous 'workspace/csv-logs' was relative to the current working directory.
csv_logger = CSVLogger(save_dir=data_dir.parent / 'csv-logs')

In [None]:
# Set to True (or an int n) to use Lightning's fast_dev_run debug mode, which
# runs only a few batches as a quick smoke test. False is Lightning's default.
FAST_DEV_RUN = False

trainer = pl.Trainer(
    logger=csv_logger,
    max_epochs=1,
    fast_dev_run=FAST_DEV_RUN,
)

In [None]:
# Train the model, using `val_dl` for validation.
trainer.fit(
    model=model,
    train_dataloaders=train_dl,
    val_dataloaders=val_dl,
)