In [1]:
import json
import os

from pathlib import Path

import albumentations as albu
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import segmentation_models_pytorch as smp
import sklearn
import sklearn.model_selection
import tifffile
import torch

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from rasterio import features
from shapely.geometry import shape
from skimage import measure
from timm.optim import create_optimizer_v2
from timm.scheduler import create_scheduler_v2
from tqdm import tqdm


  from .autonotebook import tqdm as notebook_tqdm
  check_for_updates()


In [None]:
# Seed the random number generators so repeated runs are comparable.
seed_everything(42)

# Dataset location, relative to the notebook.
data_root = Path("./data")

# Class order: position i is channel i of the masks and of the model output.
class_names = [
    "grassland_shrubland",
    "logging",
    "mining",
    "plantation",
]

# Number of training epochs; also used as the length of the cosine LR schedule.
epochs = 200

Seed set to 42


42

### Define dataset class to load images and masks for training and validation

In [None]:
def load_mask(mask_path):
    """Read a ``.npy`` label mask and return it channels-last, divided by 255.

    Args:
        mask_path: Path of the ``.npy`` file to read.

    Returns:
        A ``float32`` array of shape ``(1024, 1024, 4)``.

    Raises:
        AssertionError: If the stored array is not ``(4, 1024, 1024)``.
    """
    raw = np.load(mask_path)
    assert raw.shape == (4, 1024, 1024)
    channels_last = np.transpose(raw, (1, 2, 0))  # (C, H, W) -> (H, W, C)
    return channels_last.astype(np.float32) / 255.0


def load_image(image_path):
    """Read a 12-band TIFF tile as ``float32`` with non-finite values replaced.

    Args:
        image_path: Path of the ``.tif`` file to read.

    Returns:
        A ``float32`` array of shape ``(1024, 1024, 12)``; NaN/inf entries are
        substituted using the defaults of ``np.nan_to_num``.

    Raises:
        AssertionError: If the decoded array is not ``(1024, 1024, 12)``.
    """
    bands = tifffile.imread(image_path)
    assert bands.shape == (1024, 1024, 12)
    finite = np.nan_to_num(bands)
    return finite.astype(np.float32)


# Precomputed per-band mean and std (12 bands). Built once at definition time
# instead of on every call, and shaped (12, 1, 1) so they broadcast over a
# channels-first (12, H, W) image.
_BAND_MEAN = np.array(
    [
        280.827,
        328.215,
        553.243,
        393.551,
        911.256,
        2394.626,
        2925.688,
        3160.688,
        3176.124,
        3275.213,
        1721.096,
        849.122,
    ],
    dtype=np.float32,
).reshape(12, 1, 1)
_BAND_STD = np.array(
    [
        284.234,
        240.134,
        310.143,
        392.992,
        405.232,
        615.245,
        790.267,
        852.903,
        824.679,
        809.612,
        636.423,
        500.186,
    ],
    dtype=np.float32,
).reshape(12, 1, 1)


def normalize_image(image):
    """Standardize a channels-first image with the precomputed band statistics.

    Args:
        image: Array whose shape broadcasts against ``(12, 1, 1)``; callers in
            this notebook pass ``(12, H, W)``.

    Returns:
        ``(image - mean) / std`` computed band-wise. The input is not modified.
    """
    return (image - _BAND_MEAN) / _BAND_STD


class TrainValDataset(torch.utils.data.Dataset):
    """Paired image/mask dataset over ``train_<i>.tif`` / ``train_<i>.npy`` files.

    Args:
        data_root: Directory containing ``train_images/`` and ``train_masks/``.
            Accepts a ``str`` or a ``Path``.
        sample_indices: Iterable of integer sample ids selecting the files.
        augmentations: Optional callable invoked as
            ``augmentations(image=..., mask=...)`` that returns a dict with the
            same keys (e.g. an ``albumentations.Compose``).
    """

    def __init__(self, data_root, sample_indices, augmentations=None):
        # Coerce so plain strings work, as the other helpers in this notebook do.
        data_root = Path(data_root)
        sample_indices = list(sample_indices)  # allow one-shot iterables
        self.image_paths = [data_root / "train_images" / f"train_{i}.tif" for i in sample_indices]
        self.mask_paths = [data_root / "train_masks" / f"train_{i}.npy" for i in sample_indices]
        self.augmentations = augmentations

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, optionally augment, and normalize sample ``idx``.

        Returns:
            A dict with ``image`` (12, H, W), ``mask`` (4, H, W) and the source
            ``image_path`` / ``mask_path`` as strings.
        """
        image_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]
        sample = {"image": load_image(image_path), "mask": load_mask(mask_path)}

        # Augmentations operate on channels-last arrays, so apply them before
        # moving the channel axis to the front.
        if self.augmentations:
            sample = self.augmentations(**sample)

        sample["image"] = normalize_image(sample["image"].transpose(2, 0, 1))  # (12, H, W)
        sample["mask"] = sample["mask"].transpose(2, 0, 1)  # (4, H, W)
        sample["image_path"] = str(image_path)
        sample["mask_path"] = str(mask_path)
        return sample


### Define U-Net model using pytorch-lightning

In [None]:
class Model(pl.LightningModule):
    """Lightning module wrapping a U-Net for 4-channel multilabel segmentation.

    The network takes 12-channel images and emits 4 logit maps, one per entry
    of the module-level ``class_names``. Training minimizes Dice + BCE; per-step
    confusion statistics are buffered and turned into per-class F1 scores at
    the end of each train/validation epoch.
    """

    def __init__(self):
        """Build the network, the two loss terms and the per-epoch buffers."""
        super().__init__()

        # Initialize UNet with EfficientNetV2-S encoder
        self.model = smp.create_model(
            arch="unet",
            encoder_name="tu-tf_efficientnetv2_s",
            encoder_weights="imagenet",
            in_channels=12,
            classes=4,
        )

        # Define losses (both consume raw logits)
        self.dice_loss_fn = smp.losses.DiceLoss(mode=smp.losses.MULTILABEL_MODE, from_logits=True)
        self.bce_loss_fn = smp.losses.SoftBCEWithLogitsLoss(smooth_factor=0.0)

        # Per-step outputs collected by shared_step and consumed (then cleared)
        # by the corresponding on_*_epoch_end hook.
        self.training_step_outputs = []
        self.validation_step_outputs = []

    def forward(self, image):
        """Return the raw logits of the wrapped segmentation model."""
        return self.model(image)

    def shared_step(self, batch, stage):
        """Compute the loss for one batch and buffer its statistics.

        Args:
            batch: Dict with at least ``"image"`` and ``"mask"`` tensors.
            stage: ``"train"`` stores into the training buffer; any other value
                stores into the validation buffer.

        Returns:
            The (attached) scalar loss, for Lightning to backpropagate.
        """
        image, mask = batch["image"], batch["mask"]
        logits = self(image)

        # Combine Dice + BCE losses
        loss = self.dice_loss_fn(logits, mask) + self.bce_loss_fn(logits, mask)

        # Compute classification stats on predictions thresholded at 0.5
        prob_mask = logits.sigmoid()
        tp, fp, fn, tn = smp.metrics.get_stats(
            (prob_mask > 0.5).long(),
            mask.long(),
            mode=smp.losses.MULTILABEL_MODE,
        )

        # Detach results for aggregation; kept on CPU until the epoch ends
        output = {
            "loss": loss.detach().cpu(),
            "tp": tp.detach().cpu(),
            "fp": fp.detach().cpu(),
            "fn": fn.detach().cpu(),
            "tn": tn.detach().cpu(),
        }

        if stage == "train":
            self.training_step_outputs.append(output)
        else:
            self.validation_step_outputs.append(output)

        return loss

    def training_step(self, batch, batch_idx):
        """Lightning hook: delegate to ``shared_step`` with stage ``"train"``."""
        return self.shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        """Lightning hook: delegate to ``shared_step`` with stage ``"val"``."""
        return self.shared_step(batch, "val")

    def shared_epoch_end(self, outputs, stage):
        """Aggregate buffered step outputs and log loss and F1 under ``{stage}/``.

        Logs ``{stage}/loss``, ``{stage}/f1/<class_name>`` for every entry of
        ``class_names`` and their mean as ``{stage}/f1``.
        """
        def log(name, val, prog_bar=False):
            # Buffered values live on CPU; move them to the module's device
            # before handing them to Lightning's logger.
            self.log(f"{stage}/{name}", val.to(self.device), sync_dist=True, prog_bar=prog_bar)

        # Average loss over the buffered steps
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()
        log("loss", avg_loss, prog_bar=True)

        # Concatenate stats along the first axis
        # NOTE(review): presumably (num_images, num_classes) after concatenation — confirm against smp.metrics.get_stats
        tp = torch.cat([x["tp"] for x in outputs])
        fp = torch.cat([x["fp"] for x in outputs])
        fn = torch.cat([x["fn"] for x in outputs])
        tn = torch.cat([x["tn"] for x in outputs])

        # Compute F1 per class (column i corresponds to class_names[i])
        f1_scores = {
            class_name: smp.metrics.f1_score(tp[:, i], fp[:, i], fn[:, i], tn[:, i], reduction="macro-imagewise")
            for i, class_name in enumerate(class_names)
        }
        for name, score in f1_scores.items():
            log(f"f1/{name}", score)

        # Mean F1 across classes
        avg_f1 = torch.stack(list(f1_scores.values())).mean()
        log("f1", avg_f1, prog_bar=True)

    def on_train_epoch_end(self):
        """Lightning hook: log training metrics, then empty the buffer."""
        self.shared_epoch_end(self.training_step_outputs, "train")
        self.training_step_outputs.clear()

    def on_validation_epoch_end(self):
        """Lightning hook: log validation metrics, then empty the buffer."""
        self.shared_epoch_end(self.validation_step_outputs, "val")
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """Return the AdamW optimizer and an epoch-stepped cosine LR schedule."""
        # Set up AdamW optimizer
        # NOTE(review): filter_bias_and_bn presumably exempts bias/norm parameters from weight decay — confirm in timm docs
        optimizer = create_optimizer_v2(
            self.parameters(),
            opt="adamw",
            lr=1e-4,
            weight_decay=1e-2,
            filter_bias_and_bn=True,
        )

        # Set up cosine LR scheduler over `epochs` epochs; warmup_epochs=0, so
        # the warmup settings are configured but no warmup epochs are requested
        scheduler, _ = create_scheduler_v2(
            optimizer,
            sched="cosine",
            num_epochs=epochs,
            min_lr=0.0,
            warmup_lr=1e-5,
            warmup_epochs=0,
            warmup_prefix=False,
            step_on_epochs=True,
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
            },
        }

    def lr_scheduler_step(self, scheduler, metric):
        """Lightning hook override: step the scheduler with an explicit epoch."""
        # timm scheduler requires current epoch explicitly
        scheduler.step(epoch=self.current_epoch)


### Set up the PyTorch Lightning trainer

In [None]:
# Checkpoints and TensorBoard logs are written under this directory.
train_output_dir = data_root / "training_result"

# Hold out 20% of the 176 labelled tiles (train_0.tif ... train_175.tif) for validation.
sample_indices = list(range(176))
train_indices, val_indices = sklearn.model_selection.train_test_split(
    sample_indices,
    test_size=0.2,
    random_state=42,
)

# Training-only augmentations: a mild affine jitter, a random 512x512 crop,
# then flips / transposes / 90-degree rotations.
# NOTE(review): interpolation=2 and border_mode=0 are presumably OpenCV enum
# codes (INTER_CUBIC, BORDER_CONSTANT) — confirm against cv2.
augmentations = albu.Compose(
    [
        albu.ShiftScaleRotate(
            shift_limit=0.0625,
            scale_limit=0.1,
            rotate_limit=15,
            interpolation=2,
            border_mode=0,
            value=0,
            mask_value=0,
            p=0.5,
        ),
        albu.RandomCrop(height=512, width=512, p=1),
        albu.HorizontalFlip(p=0.5),
        albu.VerticalFlip(p=0.5),
        albu.Transpose(p=0.5),
        albu.RandomRotate90(p=0.5),
    ]
)

# Datasets: augmented for training, untouched full-size tiles for validation.
train_dataset = TrainValDataset(data_root, train_indices, augmentations=augmentations)
val_dataset = TrainValDataset(data_root, val_indices, augmentations=None)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=8,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=8,
)

# Keep only the single checkpoint with the highest validation F1 (weights only).
checkpoint_callback = ModelCheckpoint(
    dirpath=train_output_dir,
    filename="best_f1_05",
    monitor="val/f1",
    mode="max",
    save_top_k=1,
    save_last=False,
    save_weights_only=True,
)
lr_monitor = LearningRateMonitor(logging_interval="step")
tb_logger = TensorBoardLogger(train_output_dir, name=None)

# Single-GPU mixed-precision run; validation happens every 5th epoch.
trainer = Trainer(
    accelerator="gpu",
    devices=[0],
    strategy="ddp_notebook",
    precision="16-mixed",
    max_epochs=epochs,
    check_val_every_n_epoch=5,
    log_every_n_steps=5,
    deterministic=True,
    benchmark=False,
    sync_batchnorm=False,
    default_root_dir=os.getcwd(),
    callbacks=[checkpoint_callback, lr_monitor],
    logger=[tb_logger],
)

# Fresh model instance to be trained.
model = Model()


Using 16bit Automatic Mixed Precision (AMP)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


### Start training

The trained model is saved as `data/training_result/best_f1_05.ckpt`.

In [7]:
# Run the training loop; validation uses `val_loader` at the interval set on the trainer.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

/opt/conda/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /workspace/data/training_result exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name         | Type                  | Params | Mode 
----------------

Epoch 199: 100%|██████████| 9/9 [00:14<00:00,  0.63it/s, v_num=0, train/loss=0.406, train/f1=0.800, val/loss=0.694, val/f1=0.552]

`Trainer.fit` stopped: `max_epochs=200` reached.


Epoch 199: 100%|██████████| 9/9 [00:14<00:00,  0.63it/s, v_num=0, train/loss=0.406, train/f1=0.800, val/loss=0.694, val/f1=0.552]



Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

### Compute the evaluation metric for the validation set


In [None]:
def run_inference(model, loader, pred_output_dir):
    """Run ``model`` over ``loader`` and save per-image probability maps.

    For every image, the sigmoid of the model output is written as a
    ``float16`` ``.npy`` file named after the image (extension swapped to
    ``.npy``) inside ``pred_output_dir``.

    Args:
        model: A ``torch.nn.Module`` returning logits. This function does not
            change its train/eval mode, so call ``model.eval()`` beforehand.
        loader: Iterable of batches; each batch is a dict with an ``"image"``
            tensor and an ``"image_path"`` sequence of the same length.
        pred_output_dir: Output directory (``str`` or ``Path``); created if
            it does not exist.
    """
    pred_output_dir = Path(pred_output_dir)
    pred_output_dir.mkdir(parents=True, exist_ok=True)

    # Follow the model's own device rather than assuming CUDA, so the same
    # function also works for a model kept on the CPU.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        # Parameter-free model: fall back to CUDA when present, else CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for batch in tqdm(loader):
        images = batch["image"].to(device)

        with torch.no_grad():
            probs = model(images).sigmoid()

        # Single device-to-host transfer per batch instead of one per sample.
        probs = probs.cpu().numpy().astype(np.float16)

        # Save predicted probability masks, one file per input image
        for image_path, prob in zip(batch["image_path"], probs):
            out_name = Path(image_path).with_suffix(".npy").name
            np.save(pred_output_dir / out_name, prob)


In [None]:
# Rebuild the network and restore the best checkpoint saved during training.
del model  # drop the trained instance first to free memory

best_ckpt_path = train_output_dir / "best_f1_05.ckpt"

model = Model()
state = torch.load(best_ckpt_path)["state_dict"]
model.load_state_dict(state)
model = model.cuda().eval()

# Write validation-set probability maps to disk so they can be scored below.
val_pred_dir = data_root / "val_preds"
run_inference(model, val_loader, val_pred_dir)


  model.load_state_dict(torch.load(train_output_dir / "best_f1_05.ckpt")["state_dict"])
100%|██████████| 9/9 [00:03<00:00,  2.31it/s]


In [None]:
def compute_f1_score(pred_mask, truth_mask):
    """Pixel-wise F1 between a predicted and a ground-truth mask.

    A pixel counts as positive when its value is ``> 0`` and as negative when
    it equals ``0``. Precision and recall default to ``1.0`` when their
    denominator is zero, so two empty masks score ``1.0``.

    Args:
        pred_mask: Array of shape ``(1024, 1024)``.
        truth_mask: Array of shape ``(1024, 1024)``.

    Returns:
        The F1 score, or ``0.0`` when precision + recall is zero.

    Raises:
        AssertionError: If either mask is not ``(1024, 1024)``.
    """
    expected_shape = (1024, 1024)
    assert pred_mask.shape == expected_shape
    assert truth_mask.shape == expected_shape

    pred_pos, pred_neg = pred_mask > 0, pred_mask == 0
    truth_pos, truth_neg = truth_mask > 0, truth_mask == 0

    tp = (pred_pos & truth_pos).sum()
    fp = (pred_pos & truth_neg).sum()
    fn = (pred_neg & truth_pos).sum()

    # Nothing predicted -> precision is vacuously perfect.
    if tp + fp > 0:
        precision = tp / (tp + fp)
    else:
        precision = 1.0

    # Nothing to find -> recall is vacuously perfect.
    if tp + fn > 0:
        recall = tp / (tp + fn)
    else:
        recall = 1.0

    if precision + recall > 0:
        return 2 * precision * recall / (precision + recall)
    return 0.0


score_thresh = 0.5  # probability cut-off used to binarize predictions
min_area = 10000    # a class with fewer predicted pixels than this is treated as absent

val_f1_scores = {}

for idx in sorted(val_indices):
    fn = f"train_{idx}"

    # Binarized prediction and ground truth, both (4, 1024, 1024)
    pred = np.load(val_pred_dir / f"{fn}.npy") > score_thresh
    truth = np.load(data_root / "train_masks" / f"{fn}.npy")

    per_class_f1 = {}
    for i, class_name in enumerate(class_names):
        # Keep the class prediction only when its total area reaches `min_area`.
        pred_mask = pred[i] if pred[i].sum() >= min_area else np.zeros_like(pred[i])
        per_class_f1[class_name] = compute_f1_score(pred_mask, truth[i])
    val_f1_scores[fn] = per_class_f1

# One row per image, one column per class
val_f1_scores = pd.DataFrame(val_f1_scores).T
val_f1_scores["all_classes"] = val_f1_scores.mean(axis=1)  # mean F1 per image
val_f1_scores.loc["all_images"] = val_f1_scores.mean()     # mean F1 across all images

overall_f1 = val_f1_scores.loc["all_images", "all_classes"]
print(f"val F1 score: {overall_f1:.4f}")
val_f1_scores


val f1 score: 0.6814


Unnamed: 0,grassland_shrubland,logging,mining,plantation,all_classes
train_9,0.925728,1.0,1.0,1.0,0.981432
train_12,0.0,1.0,1.0,0.0,0.5
train_15,0.044791,1.0,1.0,0.794052,0.709711
train_16,0.702484,1.0,1.0,0.906406,0.902222
train_18,0.532832,1.0,1.0,0.990736,0.880892
train_19,0.0,1.0,1.0,0.667145,0.666786
train_24,1.0,1.0,1.0,0.90391,0.975977
train_29,0.0,1.0,1.0,0.622041,0.65551
train_30,0.950854,1.0,1.0,1.0,0.987713
train_31,0.45162,0.0,0.0,0.0,0.112905


### Predict the evaluation images and generate a submission JSON file

Next, run the same inference on the evaluation images as was done for the validation set, then convert the predictions into a submission JSON file.

The submission JSON file will be saved as `data/submission.json`.

In [11]:
class TestDataset(torch.utils.data.Dataset):
    """Image-only dataset over ``evaluation_<i>.tif`` files.

    Args:
        data_root: Directory containing ``evaluation_images/``. Accepts a
            ``str`` or a ``Path``.
        num_images: Number of evaluation tiles, read as
            ``evaluation_0.tif`` ... ``evaluation_{num_images - 1}.tif``.
            Defaults to 118.
    """

    def __init__(self, data_root, num_images=118):
        # Coerce so plain strings work, as the other helpers in this notebook do.
        data_root = Path(data_root)
        self.image_paths = [
            data_root / "evaluation_images" / f"evaluation_{i}.tif"
            for i in range(num_images)
        ]

    def __len__(self):
        """Return the number of evaluation images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load and normalize image ``idx``.

        Returns:
            A dict with ``image`` (12, H, W) and the source ``image_path`` as
            a string.
        """
        image_path = self.image_paths[idx]
        image = load_image(image_path).transpose(2, 0, 1)  # (12, H, W)
        return {
            "image": normalize_image(image),
            # metadata used to name the saved prediction
            "image_path": str(image_path),
        }

In [12]:
# Iterate the evaluation tiles in file order (no shuffling), four at a time.
test_dataset = TestDataset(data_root)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=8,
)

# Save probability maps for every evaluation image.
test_pred_dir = data_root / "test_preds"
run_inference(model, test_loader, test_pred_dir)

100%|██████████| 30/30 [00:09<00:00,  3.28it/s]


`detect_polygons()` below extracts isolated areas as polygons from the predicted mask.

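The key parameter is `min_area`: if the total predicted area of a class in an image is smaller than `min_area` pixels, every prediction for that class in that image is discarded. Small predicted areas are often false positives, which lower the evaluation score.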

In [13]:
def detect_polygons(pred_dir, score_thresh, min_area):
    """Turn saved probability maps into per-class polygons.

    Every ``*.npy`` file in ``pred_dir`` is binarized at ``score_thresh``. For
    each class channel, the whole channel is discarded when its total
    positive area is below ``min_area``; otherwise each connected region is
    converted into a polygon, buffered by 0.5 and simplified with tolerance
    0.5.

    Args:
        pred_dir: Directory containing the ``.npy`` predictions (``str`` or
            ``Path``), one ``(num_classes, H, W)`` array per image.
        score_thresh: Probability threshold for binarization.
        min_area: Minimum total number of positive pixels a class must have
            in an image to be kept.

    Returns:
        ``{"<image>.tif": {class_name: [polygon, ...]}}`` with one entry per
        prediction file and one key per entry of ``class_names``.
    """
    pred_dir = Path(pred_dir)
    pred_paths = sorted(pred_dir.glob("*.npy"))

    polygons_all_imgs = {}
    for pred_path in tqdm(pred_paths):
        mask = np.load(pred_path) > score_thresh  # binarize all channels at once

        polygons_all_classes = {}
        for i, class_name in enumerate(class_names):
            mask_for_a_class = mask[i]
            if mask_for_a_class.sum() < min_area:
                # Total predicted area too small: treat the class as absent.
                mask_for_a_class = np.zeros_like(mask_for_a_class)

            # Label connected regions. Cast to int32 rather than uint8: with
            # more than 255 regions a uint8 cast wraps around, turning region
            # 256 into background and merging later regions with earlier ones.
            label = measure.label(mask_for_a_class, connectivity=2, background=0).astype(np.int32)

            # An explicit boolean mask excludes the background (label 0).
            polygons = []
            for geom, _value in features.shapes(label, mask=label > 0):
                poly = shape(geom).buffer(0.5).simplify(tolerance=0.5)
                polygons.append(poly)
            polygons_all_classes[class_name] = polygons

        polygons_all_imgs[pred_path.with_suffix(".tif").name] = polygons_all_classes

    return polygons_all_imgs

In [14]:
test_pred_polygons = detect_polygons(test_pred_dir, score_thresh=score_thresh, min_area=min_area)

submission_save_path = data_root / "submission.json"

images = []
for img_id in range(118):  # evaluation_0.tif to evaluation_117.tif
    file_name = f"evaluation_{img_id}.tif"
    polygons_by_class = test_pred_polygons[file_name]

    # One annotation per polygon; the exterior ring is flattened to
    # [x0, y0, x1, y1, ..., xN, yN].
    annotations = [
        {
            "class": class_name,
            "segmentation": [coord for xy in poly.exterior.coords for coord in xy],
        }
        for class_name in class_names
        for poly in polygons_by_class[class_name]
    ]
    images.append({"file_name": file_name, "annotations": annotations})

with open(submission_save_path, "w", encoding="utf-8") as f:
    json.dump({"images": images}, f, indent=4)

100%|██████████| 118/118 [00:11<00:00,  9.91it/s]
