In [None]:
from dataclasses import dataclass, field
from pathlib import Path
@dataclass
class Config:
    data_root: Path = Path("/home/krschap/data")
    output_dir: Path = Path("outputs")
    seed: int = 64

    train_regions: list[str] = field(
        default_factory=lambda: ['ramp_dhaka_bangladesh', 'ramp_barishal_bangladesh','ramp_sylhet_bangladesh']
    )
    val_regions: list[str] = field(
        default_factory=lambda: ['ramp_coxs_bazar_bangladesh']
    )
    test_regions: list[str] = field(
        default_factory=lambda: ['stage1_ramp_sample']
    )

    val_split: float = 0.2
    pretrained_model: str = "facebook/mask2former-swin-base-IN21k-coco-instance" # https://huggingface.co/facebook/mask2former-swin-base-IN21k-coco-instance

    stage1_epochs: int = 10
    stage1_batch_size: int = 8
    stage1_loader_num_samples: int = 500

    # hyper params
    stage1_dice_weight: float = 5.0
    stage1_mask_weight: float = 5.0
    stage1_class_weight: float = 5.0
    stage1_learning_rate: float = 0.00001
    stage1_weight_decay: float = 0.0001 # penalty on large weights to prevent overfitting
    stage1_early_stopping_patience: int = 5

    num_workers: int = 32
    use_wandb: bool = True
    wandb_project: str = "building-seg-mask2former"
    wandb_run_name: str = "default_run"

In [2]:
import random 
import numpy as np 
import torch 

def set_seed(seed:int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def get_all_ramp_regions(root: Path) -> list[str]:
    regions = [d.name for d in root.iterdir() if d.is_dir() and d.name.startswith("ramp_")]
    if not regions:
        raise ValueError(f"No RAMP regions found in {root}")
    return sorted(regions)


def split_regions(regions: list[str], val_ratio: float = 0.2, seed: int = 42):
    rng = random.Random(seed)
    shuffled = regions.copy()
    rng.shuffle(shuffled)
    split_idx = int(len(shuffled) * (1 - val_ratio))
    return shuffled[:split_idx], shuffled[split_idx:]

In [3]:
cfg = Config()
set_seed(cfg.seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [4]:
cfg.train_regions = ['stage1_ramp_sample']
cfg.val_regions = ['stage2_sample']

In [5]:
all_regions = get_all_ramp_regions(cfg.data_root)
if cfg.train_regions and cfg.val_regions:
    train_regions, val_regions = cfg.train_regions, cfg.val_regions
else:
    train_regions, val_regions = split_regions(all_regions, cfg.val_split, cfg.seed)

In [6]:
train_regions, val_regions

(['stage1_ramp_sample'], ['stage2_sample'])

In [7]:
from torchgeo.datasets import RasterDataset, VectorDataset
# from rasterio.crs import CRS
from pyproj import CRS

class RAMPImageDataset(RasterDataset):
    filename_glob = "*.tif"
    is_image = True
    all_bands = ("R", "G", "B")
    rgb_bands = ("R", "G", "B")

class RAMPMaskDataset(VectorDataset):
    filename_glob = "*.geojson"

    def __init__(self, paths, **kwargs):
        super().__init__(paths=paths, task="instance_segmentation", **kwargs)


def get_ramp_dataset(root: Path, regions: list[str]):
    image_paths , label_paths = [], []

    for region in regions:
        region_path = root / region
        img_path, lbl_path = region_path / "source", region_path / "labels"
        if img_path.exists() and lbl_path.exists():
            image_paths.append(img_path)
            label_paths.append(lbl_path)

    if not image_paths or not label_paths:
        raise ValueError(f"No valid regions found in {root}")

    target_crs = CRS.from_epsg(3857)

    print("Loading images ...")
    images = RAMPImageDataset(paths=image_paths, crs=target_crs, cache=True, res=0.4)
    
    print(f"Loaded {len(images)} image tiles from regions: {regions}")
    print("Loading labels ...")
    masks = RAMPMaskDataset(paths=label_paths, crs=target_crs, res=0.4)
    print(f"Loaded {len(masks)} masks tiles from regions: {regions}")
    return images & masks, label_paths

In [8]:
train_dataset, train_label_paths = get_ramp_dataset(cfg.data_root, cfg.train_regions)
val_dataset, val_label_paths = get_ramp_dataset(cfg.data_root, cfg.val_regions)

Loading images ...
Loaded 21 image tiles from regions: ['stage1_ramp_sample']
Loading labels ...
Loaded 21 masks tiles from regions: ['stage1_ramp_sample']
Loading images ...
Loaded 8 image tiles from regions: ['stage2_sample']
Loading labels ...
Loaded 8 masks tiles from regions: ['stage2_sample']


In [9]:

from torch.utils.data import DataLoader
from torchgeo.samplers import RandomGeoSampler
import multiprocessing as mp
from transformers import Mask2FormerImageProcessor


image_processor = Mask2FormerImageProcessor.from_pretrained(
    cfg.pretrained_model,
    num_labels=2,
    do_reduce_labels=True,
    ignore_index=255, 
    size=256,
    do_normalize=True,
)

def collate_fn_mask2former(batch): # source : https://debuggercafe.com/fine-tuning-mask2former/ 
    
    images = [sample['image'].float() for sample in batch]  
    inputs = image_processor(images=images, return_tensors="pt")
    
    
    mask_labels = []
    class_labels = []

    for sample in batch:
        mask = sample["mask"]
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)

        instance_masks = []
        instance_classes = []

        for i in range(mask.shape[0]):
            instance_masks.append(mask[i].float())
            instance_classes.append(1)

        if instance_masks:
            mask_labels.append(torch.stack(instance_masks))
            class_labels.append(torch.tensor(instance_classes, dtype=torch.long))
        else:
            H, W = mask.shape[-2:]
            mask_labels.append(torch.zeros((0, H, W), dtype=torch.float32))
            class_labels.append(torch.tensor([255], dtype=torch.long))


    inputs['mask_labels'] = mask_labels 
    inputs['class_labels'] = class_labels  
    
    return inputs
    
    
def create_dataloader(dataset, batch_size, num_samples, num_workers= mp.cpu_count(), is_train=True):
    sampler = RandomGeoSampler(dataset, size=256, length=num_samples)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers, pin_memory=True, drop_last=is_train, collate_fn=collate_fn_mask2former)


  from .autonotebook import tqdm as notebook_tqdm
  image_processor = cls(**image_processor_dict)


In [10]:
train_loader = create_dataloader(train_dataset, batch_size=cfg.stage1_batch_size,num_samples=cfg.stage1_loader_num_samples,num_workers=cfg.num_workers,is_train=True)
val_loader = create_dataloader(val_dataset, batch_size=cfg.stage1_batch_size,num_samples=cfg.stage1_loader_num_samples,num_workers=cfg.num_workers,is_train=False)

In [11]:
from transformers import Mask2FormerConfig

base_config = Mask2FormerConfig.from_pretrained(cfg.pretrained_model)
base_config.num_labels = 2
base_config.ignore_index = 255
base_config.id2label = {0: "background", 1: "building"}
base_config.label2id = {"background": 0, "building": 1}
base_config.class_weight = cfg.stage1_class_weight or 5.0 # default value of this pretrained model is 2.0
base_config.dice_weight = cfg.stage1_dice_weight or 5.0# mask quality , 5.0 default
base_config.mask_weight = cfg.stage1_mask_weight or 5.0 # mask prediction, 5.0 default
print(base_config)

Mask2FormerConfig {
  "activation_function": "relu",
  "architectures": [
    "Mask2FormerForUniversalSegmentation"
  ],
  "backbone": null,
  "backbone_config": {
    "attention_probs_dropout_prob": 0.0,
    "depths": [
      2,
      2,
      18,
      2
    ],
    "drop_path_rate": 0.3,
    "embed_dim": 128,
    "encoder_stride": 32,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.0,
    "hidden_size": 1024,
    "image_size": 224,
    "initializer_range": 0.02,
    "layer_norm_eps": 1e-05,
    "mlp_ratio": 4.0,
    "model_type": "swin",
    "num_channels": 3,
    "num_heads": [
      4,
      8,
      16,
      32
    ],
    "num_layers": 4,
    "out_features": [
      "stage1",
      "stage2",
      "stage3",
      "stage4"
    ],
    "out_indices": [
      1,
      2,
      3,
      4
    ],
    "patch_size": 4,
    "path_norm": true,
    "qkv_bias": true,
    "stage_names": [
      "stem",
      "stage1",
      "stage2",
      "stage3",
      "stage4"
    ],
    "use_absol

In [12]:
from transformers import Mask2FormerForUniversalSegmentation

model = Mask2FormerForUniversalSegmentation.from_pretrained(
    cfg.pretrained_model,
    config=base_config,
    ignore_mismatched_sizes=True)

Some weights of Mask2FormerForUniversalSegmentation were not initialized from the model checkpoint at facebook/mask2former-swin-base-IN21k-coco-instance and are newly initialized because the shapes did not match:
- class_predictor.weight: found shape torch.Size([81, 256]) in the checkpoint and torch.Size([3, 256]) in the model instantiated
- class_predictor.bias: found shape torch.Size([81]) in the checkpoint and torch.Size([3]) in the model instantiated
- criterion.empty_weight: found shape torch.Size([81]) in the checkpoint and torch.Size([3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
print(f"Model moved to {device}")

Model moved to cuda


In [14]:
from transformers import TrainingArguments, EarlyStoppingCallback
training_args = TrainingArguments(
    output_dir=cfg.output_dir, 
    learning_rate=cfg.stage1_learning_rate,
    per_device_train_batch_size=cfg.stage1_batch_size,
    per_device_eval_batch_size=cfg.stage1_batch_size,
    num_train_epochs=cfg.stage1_epochs,
    weight_decay=cfg.stage1_weight_decay, 
    logging_strategy="epoch",
    save_strategy="epoch",  
    eval_strategy="epoch",
    load_best_model_at_end=True, 
    metric_for_best_model="eval_loss",
    greater_is_better=False, 
    dataloader_num_workers=cfg.num_workers, 
    remove_unused_columns=False,
    fp16=True,) # find more here : https://huggingface.co/docs/transformers/v5.0.0rc2/en/main_classes/trainer#transformers.TrainingArguments

In [15]:
from transformers import Trainer

class TorchGeoTrainer(Trainer):
    def __init__(self, *args, train_loader=None, val_loader=None, **kwargs):
        super().__init__(*args, **kwargs)
        self._train_loader = train_loader
        self._val_loader = val_loader

    def get_train_dataloader(self):
        return self._train_loader

    def get_eval_dataloader(self, eval_dataset=None):
        return self._val_loader


In [16]:
trainer = TorchGeoTrainer(
    model=model,
    args=training_args,
    train_loader=train_loader,
    val_loader=val_loader,
    eval_dataset=val_dataset,
    callbacks=[
        EarlyStoppingCallback(
            early_stopping_patience=cfg.stage1_early_stopping_patience,
            early_stopping_threshold=0.01
        )
    ],
)


In [17]:
trainer.train()

[34m[1mwandb[0m: [wandb.login()] Loaded credentials for https://api.wandb.ai from /home/krschap/.netrc.
[34m[1mwandb[0m: Currently logged in as: [33mkrschap[0m ([33mkrschap-ubs[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch,Training Loss,Validation Loss


RuntimeError: cannot sample n_sample > prob_dist.size(-1) samples without replacement