In [1]:
import torch
import torchvision
from torchvision.models.detection.mask_rcnn import MaskRCNN_ResNet50_FPN_Weights
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import cv2
import numpy as np
from util.general_utils import rle_to_mask


In [2]:
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

In [3]:
def get_model(num_classes=2):
    """
    Creates a Mask R-CNN model for instance segmentation.
    num_classes = 2 → background + ship
    """
    # Load the pretrained COCO model
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(
        weights=MaskRCNN_ResNet50_FPN_Weights.COCO_V1
    )

    # Replace classification head (COCO has 91 classes)
    in_features = model.roi_heads.box_predictor.cls_score.in_features  # type: ignore
    model.roi_heads.box_predictor = \
        torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)

    # Replace mask head (COCO has 91 classes)
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels # type: ignore
    hidden = 256
    model.roi_heads.mask_predictor = \
        torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(
            in_features_mask, hidden, num_classes
        )

    return model

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

print(device)

model = get_model()
model.to(device)

# Some people use Adam, but SGD works best for detection models
optimizer = optim.SGD(
    model.parameters(),
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005
)

# StepLR reduces the learning rate every 3 epochs
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

# TensorBoard logger
writer = SummaryWriter(log_dir="runs/maskrcnn_train")


cuda


In [5]:
from torch.utils.data import Dataset


class ShipDataset(Dataset):
    """
    PyTorch Dataset for instance segmentation of ships.
    Loads images lazily (one at a time) and decodes RLE masks on-demand.
    """

    def __init__(self,
                 rle_dict: dict[str, list[str]],
                 image_root: str,
                 transforms=None):
        """
        Args:
            rle_dict: {"img1.jpg": ["rle1", "rle2", ...], ...}
            image_root: folder path: "./data/images"
            transforms: torchvision transforms to apply to the image
        """
        self.rle_dict = rle_dict
        self.image_root = image_root
        self.transforms = transforms
        self.img_ids = list(rle_dict.keys())

    def __len__(self) -> int:
        """Number of images in the dataset"""
        return len(self.img_ids)

    def __getitem__(self, idx: int):
        """
        Load ONE image + all its instance masks.
        Returns the exact format Mask R-CNN expects.
        """

        img_id = self.img_ids[idx]
        rles = self.rle_dict[img_id]

        # -------------------
        # 1. Load image lazily
        # -------------------
        img_path = f"{self.image_root}/{img_id}"
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError(img_path)

        img = img[:, :, ::-1]          # BGR → RGB
        img = img.astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img).permute(2, 0, 1)  # HWC → CHW

        H, W = img_tensor.shape[1:]

        # -------------------
        # 2. Decode all RLE masks
        # -------------------
        masks = []
        boxes = []

        for rle in rles:
            mask = rle_to_mask(rle, H, W).astype(np.uint8)
            
             # skip empty masks
            if mask.sum() == 0:
                continue
            
            # check if bounding box CAN be formed
            ys, xs = np.where(mask == 1)
            if len(xs) == 0:
                continue  # skip mask entirely
            
            masks.append(mask)

            x1, y1 = xs.min(), ys.min()
            x2, y2 = xs.max(), ys.max()
            boxes.append([x1, y1, x2, y2])
        
        # rcnn will go crazy if there are no valid objects   
        if len(boxes) == 0:
            # No valid objects in this image → skip and collate fn will handle it
            return None

        # Convert to torch tensors
        masks = torch.as_tensor(np.stack(masks), dtype=torch.uint8)  # [N,H,W]
        boxes = torch.as_tensor(boxes, dtype=torch.float32)          # [N,4]
        labels = torch.ones((len(boxes),), dtype=torch.int64)        # all 1 = ship class

        # -------------------
        # 3. Create COCO-style target dict
        # -------------------
        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": torch.tensor([idx]),
            "area": (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]),
            "iscrowd": torch.zeros((len(boxes),), dtype=torch.int64),
        }

        # -------------------
        # 4. Apply transforms (we dont need them now)
        # -------------------
        # if self.transforms:
        #     img_tensor = self.transforms(img_tensor)

        return img_tensor, target


In [6]:
from tqdm import tqdm
from torch.utils.data import DataLoader

In [7]:
def train_one_epoch(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    data_loader: DataLoader,
    device: str,
    epoch: int
) -> float:
    """
    Trains Mask R-CNN for one epoch.

    Args:
        model: Mask R-CNN model.
        optimizer: SGD/Adam optimizer.
        data_loader: PyTorch DataLoader that yields (images, targets).
        device: 'cuda' or 'cpu'.
        epoch: Current epoch number.

    Returns:
        Average loss over the epoch (float).
    """

    model.train()
    total_loss: float = 0.0

    pbar = tqdm(data_loader, desc=f"Epoch {epoch}")

    for batch in pbar:
        
        if batch is None:
            continue  # skip this batch entirely

        images, targets = batch
        
        images: list[torch.Tensor] = [img.to(device) for img in images]

        targets: list[dict[str, torch.Tensor]] = [
            {key: val.to(device) for key, val in t.items()}
            for t in targets
        ]

        loss_dict: dict[str, torch.Tensor] = model(images, targets)

        losses: torch.Tensor = sum(loss_dict.values(), torch.tensor(0.0, device=device))
        
        total_loss += losses.item()

        optimizer.zero_grad()   
        losses.backward()
        optimizer.step()

  
        pbar.set_postfix(loss=float(losses.item()))

    return total_loss / len(data_loader)

In [8]:
import os 

In [9]:
def train(model, train_loader, val_loader, epochs):
    os.makedirs("checkpoints", exist_ok=True)
    for epoch in range(1, epochs + 1):
        train_loss = train_one_epoch(model, optimizer, train_loader, device, epoch)

        # Step learning rate down
        lr_scheduler.step()

        # Log to TensorBoard
        writer.add_scalar("Loss/train", train_loss, epoch)

        print(f"Epoch {epoch}/{epochs} - Loss: {train_loss:.4f}")

        # Save checkpoint
        torch.save(model.state_dict(), f"checkpoints/maskrcnn_epoch_{epoch}.pth")

    writer.close()


In [10]:
import pandas as pd
from sklearn.model_selection import train_test_split

In [11]:
df = pd.read_csv("./data/segmentations.csv")

rle_dict = (
    df.groupby("ImageId")["EncodedPixels"]
      .apply(list)
      .to_dict()
)

im_ids = list(rle_dict.keys())

im_ids[:3]

test_im_ids, temp_im_ids = train_test_split(im_ids, test_size=0.4, random_state=42)
val_im_ids, train_im_ids = train_test_split(temp_im_ids, test_size=0.5, random_state=42)

train_rle_dict = {im_id: rle_dict[im_id] for im_id in train_im_ids}


shipDataset = ShipDataset(rle_dict=train_rle_dict,
                          image_root="./data/images",
                          transforms=None)

In [12]:
from numpy.typing import NDArray

In [13]:
def collate_fn(batch : list[tuple[NDArray, dict[str, NDArray]]]):
    batch = [b for b in batch if b is not None]
    if len(batch) == 0:
        return None  # tell the train loop to skip this batch
    return tuple(zip(*batch))

In [14]:
train(
    model=model,
    train_loader=DataLoader(shipDataset, batch_size=4, shuffle=True, collate_fn=collate_fn),
    val_loader=None,
    epochs=5
)

Epoch 1:   0%|          | 0/9628 [00:00<?, ?it/s]

Epoch 1: 100%|██████████| 9628/9628 [51:30<00:00,  3.12it/s, loss=0.176]  


Epoch 1/5 - Loss: 0.3682


Epoch 2: 100%|██████████| 9628/9628 [40:10<00:00,  3.99it/s, loss=0.699]  


Epoch 2/5 - Loss: 0.3162


Epoch 3: 100%|██████████| 9628/9628 [33:10<00:00,  4.84it/s, loss=0.585] 


Epoch 3/5 - Loss: 0.3002


Epoch 4: 100%|██████████| 9628/9628 [32:51<00:00,  4.88it/s, loss=0.442] 


Epoch 4/5 - Loss: 0.2575


Epoch 5: 100%|██████████| 9628/9628 [33:20<00:00,  4.81it/s, loss=0.476] 


Epoch 5/5 - Loss: 0.2470


In [15]:
from typing import Dict, List
from util.general_utils import compute_iou_matrix, average_f_score_of_image, rles_to_masks

@torch.inference_mode()
def evaluate_model(model, rle_dict: Dict[str, List[str]], image_root: str, device="cuda"):
    """
    Evaluates Faster-RCNN segmentation performance using F2 score.
    rle_dict maps ImageId -> list of GT RLE strings.
    """
    model.eval()
    f2_scores = []

    for img_id, gt_rles in tqdm(rle_dict.items(), desc="Evaluating"):
        # --- Load image ---
        img_path = f"{image_root}/{img_id}"
        img_d = cv2.imread(img_path)
        
        if img_d is None:
            raise FileNotFoundError(img_path)
        
        img = img_d[:, :, ::-1]  # BGR→RGB
        img_tensor = torch.from_numpy(img.astype(np.float32) / 255.).permute(2,0,1).to(device)
        
        # --- Run model ---
        pred = model([img_tensor])[0]  
        pred_masks = pred["masks"].squeeze(1).cpu().numpy() > 0.5  # [N,H,W]

        # --- Prepare GT masks ---
        H, W = img_tensor.shape[1:]
        gt_masks = rles_to_masks(gt_rles, H, W)

        # --- Compute IoU matrix ---
        iou_mat = compute_iou_matrix(gt_masks, pred_masks)

        # --- Compute F2 score for this image ---
        f2 = average_f_score_of_image(iou_mat)
        f2_scores.append(f2)

    return np.mean(f2_scores), f2_scores


In [16]:
_mean_f2, f2_scores = evaluate_model(
    model=model,
    rle_dict={im_id: rle_dict[im_id] for im_id in val_im_ids},
    image_root="./data/images",
    device="cuda"
)

print(f"Mean F2 Score on Validation Set: {_mean_f2:.4f}")

Evaluating: 100%|██████████| 38511/38511 [1:02:18<00:00, 10.30it/s]

Mean F2 Score on Validation Set: 0.0811



