In [None]:
from terratorch.datasets.od_augmentation import CopyPasteObjectDetectionDataset
from terratorch.datasets.generic_od_dataset import GenericObjectDetectionDataset
import os
# Notebook configuration, overridable through environment variables.
# NOTE(review): the fallback paths are specific to one developer machine;
# set the environment variables to run the notebook elsewhere.
image_dir = os.getenv("image_dir", "/home/romeokienzler/Downloads/swisstopo/")
object_folder = os.getenv("object_folder", "/home/romeokienzler/Downloads/objects/")
tile_cache_dir = os.getenv("tile_cache_dir", "/home/romeokienzler/Downloads/swisstopo_tile_cache/")
# Optional: stays None when unset, in which case training starts from scratch.
checkpoint_path = os.getenv("checkpoint_path")

# Object Detection Data Augmentation with Copy-Paste

This notebook demonstrates how to use the Copy-Paste augmentation technique for object detection datasets.

## Requirements

For this notebook to work with real data, you need:

1. **Images**: A folder of images (JPG, PNG, TIFF, etc.)
2. **Annotations**: YOLO format `.txt` files with the same name as images, containing:
   ```
   <class_id> <center_x_norm> <center_y_norm> <width_norm> <height_norm>
   ```
   where all four coordinates are normalized to the range [0, 1]. Each line describes one object, so an image with several objects has several lines.

3. **Objects for pasting**: RGBA PNG images in the `object_folder`

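**Note**: The environment variables must point to real data for training to work: `image_dir` to the folder containing the images and their annotation files, `object_folder` to the folder with the RGBA objects to paste, and `tile_cache_dir` to the directory where the tile cache is stored. Without annotation files, the dataset will be empty and training will fail. Optionally, set `checkpoint_path` to resume training from an existing checkpoint.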

In [None]:
# Base dataset: reads the images (and their annotations) found in image_dir.
dataset = GenericObjectDetectionDataset(image_dir=image_dir)

# Inspect the first sample: tensor shapes first, then the raw box coordinates.
sample = dataset[0]
for key in ("image", "boxes", "labels"):
    print(sample[key].shape)
print(sample["boxes"])



In [None]:
from terratorch.datasets.od_tiled_dataset_wrapper import TiledDataset
# Cut every base image into non-overlapping 512x512 tiles, cached in tile_cache_dir.
tiled_dataset = TiledDataset(
    base_dataset=dataset,
    cache_dir=tile_cache_dir,
    rebuild=False,
    tile_size=(512, 512),
    overlap=0,
    # Keep tiles without boxes: the objects are added later by the augmentation.
    skip_empty_boxes=False,
)

In [None]:
# Wrap the tiled dataset so that objects from object_folder are pasted into the tiles.
dataset_aug = CopyPasteObjectDetectionDataset(
    base_dataset=tiled_dataset,
    object_folder=object_folder,
    paste_prob=1,              # always apply the augmentation
    max_objects=3,
    scale_range=(0.01, 0.01),  # identical bounds: a single fixed scale
)

# Quick sanity check on the first augmented sample.
sample = dataset_aug[0]
image_shape = sample["image"].shape
mask_shape = sample["mask"].shape
print(image_shape, mask_shape)

In [None]:
# Plot the first image of the original (non-augmented) dataset with its boxes.
import terratorch.visualization as ttv
sample_original = dataset[0]
# Fix: plot sample_original rather than `sample`, which still holds the
# augmented sample created in the previous cell.
ttv.plot_boxes_labels(sample_original["image"], sample_original["boxes"])

In [None]:
from torch.utils.data import DataLoader, random_split
import lightning as L
import torch

class GenericDataModule(L.LightningDataModule):
    """Lightning data module built from one dataset or from explicit splits.

    Either pass ``dataset`` (it is randomly split into train/val/test in
    ``setup`` according to ``split_ratio``) or pass ``train_dataset`` and,
    optionally, the other splits directly.

    Args:
        dataset: Full dataset to split; ignored when ``train_dataset`` is given.
        train_dataset: Explicit training split.
        val_dataset: Explicit validation split.
        test_dataset: Explicit test split.
        predict_dataset: Dataset served by ``predict_dataloader``.
        split_ratio: ``(train, val, test)`` fractions; must sum to 1.
        seed: Seed of the generator used for the random split.
        batch_size: Batch size for every dataloader.
        num_workers: Number of worker processes per dataloader.
        pin_memory: Forwarded to ``DataLoader``.
        persistent_workers: Forwarded to ``DataLoader`` when ``num_workers > 0``.
        collate_fn: Forwarded to ``DataLoader``.
        shuffle_train: Whether the training dataloader shuffles.

    Raises:
        ValueError: If neither ``dataset`` nor ``train_dataset`` is provided.
    """

    def __init__(
        self,
        dataset=None,
        train_dataset=None,
        val_dataset=None,
        test_dataset=None,
        predict_dataset=None,
        split_ratio=(0.6, 0.2, 0.2),
        seed=42,
        batch_size=1,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True,
        collate_fn=None,
        shuffle_train=True,
    ):
        super().__init__()

        # At least one data source is required to serve a training split.
        if dataset is None and train_dataset is None:
            raise ValueError(
                "You must provide either `dataset` or `train_dataset`."
            )

        # Data sources
        self.dataset = dataset
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.predict_dataset = predict_dataset

        # Random-split settings
        self.split_ratio = split_ratio
        self.seed = seed

        # DataLoader settings
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.persistent_workers = persistent_workers
        self.collate_fn = collate_fn
        self.shuffle_train = shuffle_train

    def setup(self, stage=None):
        """Create the train/val/test splits unless a training split already exists.

        Raises:
            ValueError: If there is nothing to split or ``split_ratio`` does not sum to 1.
        """
        needs_split = self.train_dataset is None
        if not needs_split:
            # Splits were passed explicitly or created by an earlier call.
            return

        if self.dataset is None:
            raise ValueError("Dataset is None but no explicit splits were provided.")

        total = len(self.dataset)
        frac_train, frac_val, frac_test = self.split_ratio

        ratio_error = abs(frac_train + frac_val + frac_test - 1.0)
        if not ratio_error < 1e-6:
            raise ValueError(f"split_ratio must sum to 1, got {self.split_ratio}")

        train_len = int(total * frac_train)
        val_len = int(total * frac_val)
        # Whatever is lost to truncation above ends up in the test split.
        lengths = [train_len, val_len, total - train_len - val_len]

        # A seeded generator keeps the split reproducible.
        rng = torch.Generator().manual_seed(self.seed)
        splits = random_split(self.dataset, lengths, generator=rng)
        self.train_dataset, self.val_dataset, self.test_dataset = splits

    def _loader(self, dataset, shuffle=False):
        """Build a ``DataLoader`` for ``dataset``; return ``None`` when it is missing."""
        if dataset is None:
            return None

        loader_options = {
            "batch_size": self.batch_size,
            "shuffle": shuffle,
            "num_workers": self.num_workers,
            "pin_memory": self.pin_memory,
            # Only request persistent workers when workers are actually used.
            "persistent_workers": self.persistent_workers and self.num_workers > 0,
            "collate_fn": self.collate_fn,
        }
        return DataLoader(dataset, **loader_options)

    def train_dataloader(self):
        """Return the training dataloader, shuffled according to ``shuffle_train``."""
        return self._loader(self.train_dataset, shuffle=self.shuffle_train)

    def val_dataloader(self):
        """Return the validation dataloader (never shuffled)."""
        return self._loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the test dataloader (never shuffled)."""
        return self._loader(self.test_dataset, shuffle=False)

    def predict_dataloader(self):
        """Return the prediction dataloader (never shuffled)."""
        return self._loader(self.predict_dataset, shuffle=False)


In [None]:
def detection_collate(batch):
    """Collate detection samples into a single batch dictionary.

    Images are stacked into one tensor, while boxes and labels stay as
    per-sample lists because each sample may hold a different number of objects.

    Args:
        batch: Sequence of sample dicts with "image", "boxes" and "labels" entries.

    Returns:
        Dict with "image" (stacked tensor), "boxes" (list) and "labels" (list).
    """
    collated = {"image": torch.stack([item["image"] for item in batch])}  # [B, C, H, W]

    for key in ("boxes", "labels"):
        collated[key] = [item[key] for item in batch]

    return collated



# Data module over the augmented dataset; the train/val/test split happens in setup().
dm = GenericDataModule(
    dataset=dataset_aug,
    collate_fn=detection_collate,  # boxes/labels stay as per-sample lists
    num_workers=4,
    batch_size=4,
)

In [None]:
# Create the train/val/test splits up front so the dataloaders can be inspected
# before training; GenericDataModule.setup does nothing once a train split exists.
dm.setup("fit")

In [None]:
# Smoke test: pull a single batch to verify that loading and collation work.
train_loader = dm.train_dataloader()
batch = next(iter(train_loader))

In [None]:
from terratorch.tasks import ObjectDetectionTask
from terratorch.models.object_detection_model_factory import ObjectDetectionModelFactory

# Neck stages, passed to the model factory in this order.
neck_pipeline = [
    {"name": "SelectIndices", "indices": [2, 5, 8, 11]},
    {"name": "ReshapeTokensToImage", "remove_cls_token": False},
    {"name": "LearnedInterpolateToPyramidal"},
    {"name": "FeaturePyramidNetworkNeck"},
]

# Arguments forwarded to the model factory.
detector_args = {
    "framework": "faster-rcnn",
    "backbone": "terramind_v1_tiny",
    "backbone_pretrained": True,
    "num_classes": 2,  # one per entry in detector_classes below
    # Both sizes equal the 512x512 tile size used by the tiled dataset.
    "framework_min_size": 512,
    "framework_max_size": 512,
    "backbone_modalities": ["RGB"],
    "in_channels": 3,
    "necks": neck_pipeline,
}

detector_classes = ["Background", "SimpleObject"]

# Neither the backbone nor the decoder is frozen.
model = ObjectDetectionTask(
    model_factory="ObjectDetectionModelFactory",
    model_args=detector_args,
    freeze_backbone=False,
    freeze_decoder=False,
    class_names=detector_classes,
)


In [None]:
import torch
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from torch.utils.data import Subset
import terratorch.visualization as ttv


class LogDetectionGrid(L.Callback):
    """Log a grid of validation predictions as a figure after validation epochs.

    Derives from ``L.Callback`` (``lightning`` is imported as ``L`` in the
    data-module cell) so the callback comes from the same Lightning package as
    the ``L.Trainer`` it is passed to; mixing ``pytorch_lightning`` and
    ``lightning`` objects is not supported.

    Args:
        grid_size: Number of rows and of columns in the plotted grid.
        every_n_epochs: Log only on epochs divisible by this value.
        score_thr: Only predictions with a score above this threshold are drawn.
    """

    def __init__(self, grid_size: int = 4, every_n_epochs: int = 1, score_thr=0.3):
        super().__init__()
        self.grid_size = grid_size
        self.every_n_epochs = every_n_epochs
        self.score_thr = score_thr

    @torch.no_grad()
    def on_validation_epoch_end(self, trainer, pl_module):
        """Predict on the first validation batch and log the result as "val/predictions"."""
        epoch = trainer.current_epoch
        if epoch % self.every_n_epochs != 0:
            return

        # The figure can only be stored by loggers whose experiment offers
        # add_figure; skip early (with a warning) instead of failing with an
        # AttributeError after the inference work has been done.
        experiment = getattr(trainer.logger, "experiment", None)
        if not hasattr(experiment, "add_figure"):
            import warnings

            warnings.warn(
                "LogDetectionGrid: the active logger does not support `add_figure`; "
                "skipping prediction grid."
            )
            return

        pl_module.eval()

        # NOTE(review): this builds a fresh validation dataloader on every call.
        loader = trainer.datamodule.val_dataloader()
        batch = next(iter(loader))

        images = batch["image"].to(pl_module.device)

        outputs = pl_module(images)
        preds = outputs.predictions

        fig, axes = plt.subplots(
            self.grid_size,
            self.grid_size,
            figsize=(self.grid_size * 4, self.grid_size * 4),
            tight_layout=True,
            squeeze=False,  # always a 2D array, so `.flat` also works for grid_size=1
        )

        try:
            panels = list(axes.flat)

            # zip stops at whichever is shorter: the grid or the predictions.
            for i, (ax, pred) in enumerate(zip(panels, preds)):
                keep = pred["scores"] > self.score_thr

                ttv.plot_boxes_labels(
                    image=images[i].cpu(),
                    boxes=pred["boxes"][keep].cpu(),
                    labels=pred["labels"][keep].cpu(),
                    scores=pred["scores"][keep].cpu(),
                    ax=ax,
                    show=False,
                )

            # Hide the axes of every panel, including those left empty.
            for ax in panels:
                ax.axis("off")

            experiment.add_figure(
                "val/predictions",
                fig,
                global_step=epoch,
            )
        finally:
            # Release the figure even when plotting or logging fails.
            plt.close(fig)


In [None]:
# Trainer with the prediction-grid callback attached.
trainer = L.Trainer(
    accelerator="auto",
    max_epochs=100,
    callbacks=[LogDetectionGrid(grid_size=4)],
)

# A non-empty checkpoint_path resumes training from that checkpoint; an unset or
# empty value collapses to None, which trains from scratch.
trainer.fit(model, dm, ckpt_path=checkpoint_path or None)