# Instance Segmentation for Cell Detection

This notebook implements an optimized Mask R-CNN for instance segmentation based on reference code.
Key features include:
* Optional mixed precision training with torch.amp for faster computation and reduced memory usage (currently switched off via `use_mixed_precision` in `main_train`)
* Memory optimization with explicit CUDA cache clearing
* AdamW optimizer and CosineAnnealingLR scheduler for better convergence
* Enhanced data augmentation with torchvision transforms v2 (flips, elastic deformation, color jitter)
* Support for both resnet50 and resnet50_v2 backbones
* Optimized anchor sizes for cell detection

In [None]:
# Cell 1 - Imports
import os
import json
import numpy as np
import skimage.io as sio
from tqdm import tqdm
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision.models.detection import (
    maskrcnn_resnet50_fpn,
    maskrcnn_resnet50_fpn_v2,
    MaskRCNN_ResNet50_FPN_Weights,
    MaskRCNN_ResNet50_FPN_V2_Weights,
)
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from PIL import Image
import gc
from pycocotools import mask as mask_utils
import time
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.colors import ListedColormap
import random

In [None]:
# Cell 2 - Memory Optimization
def clear_memory():
    """Run the garbage collector and, when CUDA is present, empty its cache.

    On CUDA machines the currently allocated and reserved GPU memory (in GB)
    is printed after the cache has been released.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    bytes_per_gb = 1024**3
    allocated_gb = torch.cuda.memory_allocated() / bytes_per_gb
    reserved_gb = torch.cuda.memory_reserved() / bytes_per_gb
    print(f"CUDA Memory: {allocated_gb:.2f}GB / {reserved_gb:.2f}GB")

In [None]:
# Cell 3 - Utility Functions
def decode_maskobj(mask_obj):
    """
    Decode an encoded mask object with pycocotools and return the result.
    """
    decoded = mask_utils.decode(mask_obj)
    return decoded


def encode_mask(binary_mask):
    """
    Encode a binary mask with pycocotools and return the mask object.

    The mask is laid out in Fortran order and cast to uint8 before encoding;
    the encoded ``counts`` bytes are turned into a UTF-8 string so the result
    can be written to JSON.
    """
    fortran_mask = np.asfortranarray(binary_mask).astype(np.uint8)
    encoded = mask_utils.encode(fortran_mask)
    counts_text = encoded["counts"].decode("utf-8")
    encoded["counts"] = counts_text
    return encoded


def read_maskfile(filepath):
    """
    Load the mask image stored at ``filepath`` via scikit-image.
    """
    mask_array = sio.imread(filepath)
    return mask_array

In [None]:
# Cell 4 - Dataset
class CellInstanceDataset(Dataset):
    """
    A dataset for cell instance segmentation.

    Every sample is a sub-directory of ``root_dir`` that contains ``image.tif``
    and, optionally, ``class1.tif`` .. ``class4.tif``. In a class mask each
    distinct non-zero value is treated as one instance; 0 is background.
    """

    # Class mask files are named class<idx>.tif and <idx> doubles as the label.
    CLASS_IDS = range(1, 5)

    def __init__(self, root_dir, transform=None):
        """
        Index the sample folders found under ``root_dir``.

        Args:
            root_dir: Directory holding one sub-directory per sample.
            transform: Optional callable ``(image, target) -> (image, target)``.
        """
        self.root_dir = root_dir
        self.transforms = transform
        # Only entries that actually hold an image are samples; stray files or
        # tool folders (e.g. notebook checkpoints) would otherwise be indexed
        # and fail later when the image is opened.
        self.image_dirs = sorted(
            name
            for name in os.listdir(root_dir)
            if os.path.isfile(os.path.join(root_dir, name, "image.tif"))
        )

    def _collect_instances(self, image_name, class_idx, masks, boxes, labels):
        """Append every valid instance of one class mask file to the given lists."""
        class_path = os.path.join(self.root_dir, image_name, f"class{class_idx}.tif")
        if not os.path.exists(class_path):
            return
        try:
            class_mask = read_maskfile(class_path)
        except Exception:
            # Best effort: an unreadable mask file only drops that class.
            print(f"WARN:{class_path}, continue")
            return

        obj_ids = np.unique(class_mask)
        obj_ids = obj_ids[obj_ids != 0]
        for obj_id in obj_ids:
            binary_mask = class_mask == obj_id
            rows, cols = np.where(binary_mask)[:2]
            if rows.size == 0 or cols.size == 0:
                continue
            xmin, xmax = np.min(cols), np.max(cols)
            ymin, ymax = np.min(rows), np.max(rows)
            # Box corners are the extreme pixel indices, so an instance that is
            # a single pixel wide or tall has zero extent and is skipped.
            if xmax <= xmin or ymax <= ymin:
                print(f"WARN illegal box in {image_name} id={obj_id}")
                continue
            boxes.append([xmin, ymin, xmax, ymax])
            masks.append(binary_mask)
            labels.append(class_idx)

    def __getitem__(self, idx):
        """
        Load sample ``idx``.

        Returns:
            ``(image, target)`` where ``target`` has the keys ``boxes``
            (float32, N x 4, x1 y1 x2 y2), ``labels`` (int64, N), ``masks``
            (uint8, N x H x W) and ``image_id``. If a transform was given, its
            output is returned instead.
        """
        image_name = self.image_dirs[idx]
        image_path = os.path.join(self.root_dir, image_name, "image.tif")
        image = Image.open(image_path).convert("RGB")

        masks, boxes, labels = [], [], []
        for class_idx in self.CLASS_IDS:
            self._collect_instances(image_name, class_idx, masks, boxes, labels)

        if not boxes:
            # No usable instance: return correctly shaped empty tensors.
            height, width = image.height, image.width
            masks = torch.zeros((0, height, width), dtype=torch.uint8)
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
        else:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            labels = torch.as_tensor(labels, dtype=torch.int64)
            masks = torch.as_tensor(
                np.stack(masks).astype(np.uint8), dtype=torch.uint8
            )

        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": torch.tensor([idx]),
        }
        if self.transforms is not None:
            image, target = self.transforms(image, target)
        return image, target

    def __len__(self):
        """
        Return the number of indexed samples.
        """
        return len(self.image_dirs)

In [None]:
# Cell 5 - Data Transforms
from torchvision.transforms import functional as F
from torchvision.transforms import v2 as T


class MaskRCNNTransforms:
    """
    Joint image/target transforms for Mask R-CNN.

    Training applies random flips, an occasional elastic deformation and color
    jitter before converting to a normalized tensor; evaluation only converts
    and normalizes.
    """

    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]

    def __init__(self, is_train=True):
        """
        Build the transform pipeline.

        Args:
            is_train: When True the random augmentations are included.
        """
        augmentations = []
        if is_train:
            augmentations = [
                T.RandomHorizontalFlip(p=0.5),
                T.RandomVerticalFlip(p=0.5),
                T.RandomApply([T.ElasticTransform(alpha=30, sigma=12)], p=0.3),
                T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.2),
            ]
        self.transforms = T.Compose(
            augmentations
            + [
                T.ToTensor(),
                T.Normalize(mean=self._MEAN, std=self._STD),
            ]
        )

    def __call__(self, image, target):
        """
        Apply the transforms to the image and target.

        The v2 transforms only move annotations together with the image when
        they are ``tv_tensors``; plain tensors that accompany an image are
        passed through unchanged. ``boxes`` and ``masks`` are therefore wrapped
        before the pipeline runs and converted back to plain tensors afterwards
        so the returned target has the same types as the input.

        Args:
            image: PIL image (or a tensor whose last two dims are H, W).
            target: Dict that may contain ``boxes`` (N x 4, x1 y1 x2 y2) and
                ``masks`` (N x H x W); other entries are left as they are.

        Returns:
            ``(image, target)`` after transformation. The input dict is not
            modified.
        """
        # Local import: torchvision is already a dependency of this notebook.
        from torchvision import tv_tensors

        if hasattr(image, "height") and hasattr(image, "width"):
            canvas_size = (image.height, image.width)
        else:
            canvas_size = tuple(image.shape[-2:])

        wrapped = dict(target)
        if "boxes" in wrapped:
            wrapped["boxes"] = tv_tensors.BoundingBoxes(
                wrapped["boxes"], format="XYXY", canvas_size=canvas_size
            )
        if "masks" in wrapped:
            wrapped["masks"] = tv_tensors.Mask(wrapped["masks"])

        image, wrapped = self.transforms(image, wrapped)

        for key in ("boxes", "masks"):
            if key in wrapped:
                wrapped[key] = wrapped[key].as_subclass(torch.Tensor)
        return image, wrapped

In [None]:
# Cell 6 - Collate Function
def collate_fn(batch):
    """
    Transpose a batch of samples into one tuple per sample field.

    A list of ``(image, target)`` pairs becomes ``(images, targets)``; nothing
    is stacked, so differently sized items can share a batch.
    """
    per_field = [*zip(*batch)]
    return tuple(per_field)

In [None]:
# Cell 7 - Model
def _replace_heads(model, num_classes):
    """Swap the box and mask predictors for new heads with ``num_classes`` outputs."""
    in_features_box = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features_box, num_classes)

    in_channels_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_channels_mask, 256, num_classes
    )


def get_model(num_classes=5, model_type="resnet50_v2", min_size=512, max_size=512):
    """
    Create a Mask R-CNN model with performance improvements.

    Args:
        num_classes: Number of output classes including background.
        model_type: ``"resnet50"`` or ``"resnet50_v2"``.
        min_size: Minimum image side passed to the torchvision model builder.
        max_size: Maximum image side passed to the torchvision model builder.

    Returns:
        The configured model. If building the requested model fails, a
        ``resnet50`` model with halved ``min_size``/``max_size`` and fewer
        detections per image is returned instead.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    # Validate before the try block so a typo is reported to the caller instead
    # of being swallowed by the fallback path below.
    if model_type not in ("resnet50", "resnet50_v2"):
        raise ValueError(
            f"Unknown model_type '{model_type}', expected 'resnet50' or 'resnet50_v2'"
        )

    try:
        # Load specified backbone with pretrained weights
        if model_type == "resnet50":
            weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
            model = maskrcnn_resnet50_fpn(
                weights=weights, progress=True, min_size=min_size, max_size=max_size
            )
        else:
            weights = MaskRCNN_ResNet50_FPN_V2_Weights.DEFAULT
            model = maskrcnn_resnet50_fpn_v2(
                weights=weights, progress=True, min_size=min_size, max_size=max_size
            )

        # Use smaller anchor sizes intended for tiny cell objects.
        # NOTE(review): recent torchvision AnchorGenerator builds its
        # cell_anchors in __init__, so assigning these attributes afterwards
        # may not change the anchors that are actually used - confirm.
        if hasattr(model, "rpn") and hasattr(model.rpn, "anchor_generator"):
            anchor_generator = model.rpn.anchor_generator
            anchor_generator.sizes = ((8,), (16,), (32,), (64,), (128,))
            anchor_generator.aspect_ratios = ((0.5, 1.0, 2.0),) * len(
                anchor_generator.sizes
            )

        _replace_heads(model, num_classes)

        # Adjust detections per image and NMS settings
        if hasattr(model, "roi_heads"):
            model.roi_heads.detections_per_img = 200  # Increased from 100
            if hasattr(model.roi_heads, "nms_thresh"):
                model.roi_heads.nms_thresh = 0.3  # Lower NMS threshold
            model.roi_heads.score_thresh = 0.01  # Lower score threshold

    except Exception as e:
        print(f"Error creating model: {e}")
        # In case of error, try with more conservative memory settings. Both
        # model types fall back to the resnet50 builder at half resolution.
        clear_memory()
        weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
        model = maskrcnn_resnet50_fpn(
            weights=weights, min_size=min_size // 2, max_size=max_size // 2
        )

        # Replace BOTH heads so the box and mask predictors agree on the
        # number of classes.
        _replace_heads(model, num_classes)

        # Reduce detections to a minimum
        model.roi_heads.detections_per_img = 50

    clear_memory()
    return model

In [None]:
# Cell 8 - Training Function with Mixed Precision
def train_one_epoch(model, optimizer, data_loader, device, mixed_precision=True):
    """
    Train the model for one epoch and report loss and box mAP on the same batches.

    Args:
        model: Detection model returning a loss dict in train mode and
            predictions in eval mode.
        optimizer: Optimizer stepping the model parameters.
        data_loader: Iterable of ``(images, targets)`` batches.
        device: Device the batches are moved to.
        mixed_precision: When True, run the forward pass under CUDA autocast
            and scale the loss with a GradScaler.

    Returns:
        ``(avg_loss, map_score, map50_score)`` over the batches that completed.
    """
    model.train()
    metric = MeanAveragePrecision(iou_type="bbox")
    total_loss = 0.0
    num_batches = 0
    # A disabled scaler makes scale/unscale_/update no-ops and step() a plain
    # optimizer.step(), so one code path serves both modes. The scaler is
    # created per call, i.e. its scale factor starts fresh every epoch.
    scaler = torch.amp.GradScaler("cuda", enabled=mixed_precision)

    for batch_idx, (images, targets) in enumerate(tqdm(data_loader)):
        try:
            # Move data to device
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            optimizer.zero_grad()
            with torch.amp.autocast("cuda", enabled=mixed_precision):
                loss_dict = model(images, targets)
                losses = sum(loss for loss in loss_dict.values())

            # Skip batches with infinite or NaN losses
            if not torch.isfinite(losses):
                print(f"[Error Batch {batch_idx}] infinite loss: {losses}")
                continue

            scaler.scale(losses).backward()
            # Gradients must be unscaled before clipping; otherwise max_norm
            # is compared against loss-scaled gradient norms.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
            scaler.step(optimizer)
            scaler.update()

            # Calculate mAP metrics with a second, gradient-free forward pass
            model.eval()
            with torch.no_grad():
                preds = model(images)
            model.train()

            preds_cpu = [
                {key: p[key].cpu() for key in ("boxes", "scores", "labels")}
                for p in preds
            ]
            gts_cpu = [
                {key: t[key].cpu() for key in ("boxes", "labels")} for t in targets
            ]
            metric.update(preds_cpu, gts_cpu)

            total_loss += losses.item()
            num_batches += 1

            # Periodically clear memory
            if batch_idx % 5 == 0:
                clear_memory()
        except Exception as e:
            # Best effort: a failing batch is reported and skipped.
            print(f"[Error in Batch {batch_idx}]: {e}")
            clear_memory()
            continue

    avg_loss = total_loss / max(num_batches, 1)
    stats = metric.compute()
    map_score = stats["map"].item()
    map50_score = stats["map_50"].item()
    print(
        f"\n[Epoch Summary] Loss: {avg_loss:.4f} | mAP: {map_score:.4f} | mAP@50: {map50_score:.4f}\n"
    )
    return avg_loss, map_score, map50_score

In [None]:
# Cell 9 - Validation Function
def validate(model, loader, device):
    """
    Evaluate box mAP of the model on a data loader.

    Args:
        model: Detection model; it is put into eval mode.
        loader: Iterable of ``(images, targets)`` batches.
        device: Device the images are moved to for the forward pass.

    Returns:
        ``(map_score, map50_score)``.
    """
    model.eval()
    metric = MeanAveragePrecision(iou_type="bbox")
    with torch.no_grad():
        for images, targets in tqdm(loader, desc="Validating"):
            images = [img.to(device) for img in images]
            with torch.amp.autocast("cuda"):
                preds = model(images)
            preds_cpu = [
                {key: p[key].cpu() for key in ("boxes", "scores", "labels")}
                for p in preds
            ]
            # The metric is computed on CPU and the model does not take
            # targets in eval mode, so targets (including their mask stacks)
            # are never copied to the device.
            gts_cpu = [
                {key: t[key].cpu() for key in ("boxes", "labels")} for t in targets
            ]
            metric.update(preds_cpu, gts_cpu)
    stats = metric.compute()
    map_score = stats["map"].item()
    map50_score = stats["map_50"].item()
    print(f"Validation: mAP: {map_score:.4f} | mAP@50: {map50_score:.4f}")
    return map_score, map50_score

In [None]:
# Cell 10 - Inference and Submission
def inference(
    model,
    test_dir,
    output_json,
    image_id_map,
    confidence_threshold=0.1,
    mask_threshold=0.6,
):
    """
    Run the model on every image in ``test_dir`` and write the predictions as JSON.

    Args:
        model: Trained detection model; it is put into eval mode.
        test_dir: Directory containing the test image files.
        output_json: Path of the JSON file to write.
        image_id_map: Dict mapping image file names to image ids.
        confidence_threshold: Detections scoring below this are dropped.
        mask_threshold: Mask values above this count as foreground.

    Each prediction holds ``image_id``, ``bbox`` (x, y, width, height),
    ``score``, ``category_id`` and an RLE ``segmentation``.
    """
    # Imported locally so this function does not depend on which module the
    # notebook-level name ``F`` is bound to at call time (it is assigned by two
    # different cells).
    from torchvision.transforms import functional as TF

    clear_memory()
    model.eval()
    # Follow the model's device instead of assuming CUDA.
    device = next(model.parameters()).device
    use_amp = device.type == "cuda"
    # List the directory once; sorting makes the output order reproducible.
    image_files = sorted(os.listdir(test_dir))
    num_images = len(image_files)
    results = []
    with torch.no_grad():
        for image_file in tqdm(image_files):
            if image_file not in image_id_map:
                # Without an id the predictions could not be stored anyway.
                print(f"Error processing {image_file}: no entry in image_id_map")
                continue
            image_id = image_id_map[image_file]
            try:
                # Load and preprocess image
                image_path = os.path.join(test_dir, image_file)
                image = Image.open(image_path).convert("RGB")

                # Convert to tensor and normalize
                image_tensor = TF.to_tensor(image).unsqueeze(0)
                image_tensor = TF.normalize(
                    image_tensor, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ).to(device)
                # Run inference with mixed precision (CUDA only)
                with torch.amp.autocast("cuda", enabled=use_amp):
                    output = model(image_tensor)[0]

                # Filter results based on confidence threshold
                keep = output["scores"] >= confidence_threshold
                scores = output["scores"][keep].cpu()
                labels = output["labels"][keep].cpu()
                masks = output["masks"][keep, 0].cpu()  # (N,H,W)
                boxes = output["boxes"][keep].cpu()  # (N,4) x1,y1,x2,y2

                for score, mask, label, box in zip(scores, masks, labels, boxes):
                    # Convert x1,y1,x2,y2 to x,y,width,height
                    x1, y1, x2, y2 = box.tolist()
                    width = max(0.0, x2 - x1)
                    height = max(0.0, y2 - y1)
                    # Skip invalid boxes
                    if width <= 1 or height <= 1:
                        continue

                    # Binarize mask; skip empty masks
                    mask_bin = (mask > mask_threshold).numpy().astype(np.uint8)
                    if mask_bin.sum() == 0:
                        continue

                    # Encode mask to RLE format
                    try:
                        rle = encode_mask(mask_bin)
                    except Exception as e:
                        print(f"Error encoding mask: {e}")
                        continue

                    # Get category ID and ensure it's valid (1-4)
                    category_id = int(label)
                    if category_id < 1 or category_id > 4:
                        print(f"Warning: Invalid category_id {category_id}, skipping")
                        continue

                    results.append(
                        {
                            "image_id": image_id,
                            "bbox": [float(v) for v in (x1, y1, width, height)],
                            "score": float(score),
                            "category_id": category_id,
                            "segmentation": {
                                "size": list(mask_bin.shape),
                                "counts": rle["counts"],
                            },
                        }
                    )
                # Occasionally clear memory on larger test sets
                if num_images > 50 and len(results) % 10 == 0:
                    clear_memory()
            except Exception as e:
                # Best effort: a failing image is reported and skipped.
                print(f"Error processing {image_file}: {e}")
                continue

    # Print statistics about results
    print(f"Generated {len(results)} predictions across {num_images} images")
    # Count predictions by category
    categories = {}
    for r in results:
        cat = r["category_id"]
        categories[cat] = categories.get(cat, 0) + 1
    print("Predictions by category:")
    for cat, count in categories.items():
        print(f"  Category {cat}: {count} predictions")
    # Calculate average score (0 when there are no predictions)
    avg_score = sum(r["score"] for r in results) / max(len(results), 1)
    print(f"Average confidence score: {avg_score:.4f}")
    # Save results
    with open(output_json, "w") as f:
        json.dump(results, f)
    print(f"Saved {len(results)} predictions to {output_json}")

In [None]:
# Cell 11 - Main (Training)
def main_train():
    """
    Train the model, or load the existing checkpoint if one is present.

    The training folder is split 90/10 into train and validation subsets with
    a fixed seed. The checkpoint with the best ``val mAP + val mAP@50`` is
    written to ``model_path``.

    Returns:
        The model; after training this is the model as of the last epoch
        (the best weights are the ones stored in the checkpoint file).
    """
    train_root = "/kaggle/input/dataset/train"
    model_path = "/kaggle/working/maskrcnn_model.pth"
    # Set up training parameters
    model_type = "resnet50_v2"  # Use the more powerful backbone
    num_epochs = 50
    batch_size = 2  # Reduced batch size to prevent OOM
    learning_rate = 1e-4
    use_mixed_precision = False  # Mixed precision is off; set to True to enable it
    # Set up device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # Initialize transforms
    train_transform = MaskRCNNTransforms(is_train=True)
    val_transform = MaskRCNNTransforms(is_train=False)
    # Create model with performance improvements
    model = get_model(num_classes=5, model_type=model_type)
    model.to(device)
    # Check if model already exists
    if os.path.exists(model_path):
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        model.load_state_dict(torch.load(model_path, map_location=device))
        print("Model already trained, loaded checkpoint.")
        return model

    # Two views of the same folder: one with training augmentation and one
    # with evaluation transforms, so validation is not randomly augmented.
    train_dataset = CellInstanceDataset(train_root, transform=train_transform)
    val_dataset = CellInstanceDataset(train_root, transform=val_transform)
    # Split dataset into train and validation
    train_size = int(0.9 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_ds, val_split = torch.utils.data.random_split(
        train_dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42),
    )
    # Reuse the held-out indices on the evaluation view of the data.
    val_ds = torch.utils.data.Subset(val_dataset, val_split.indices)

    loader_kwargs = dict(
        batch_size=batch_size,
        collate_fn=collate_fn,
        num_workers=2,
        pin_memory=True,
    )
    train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    # Set up optimizer with weight decay
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(params, lr=learning_rate, weight_decay=1e-4)
    # Learning rate scheduler
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

    # Training loop. Starting below any possible score guarantees that the
    # first epoch writes a checkpoint even if its validation score is 0.
    best_map = float("-inf")
    start_time = time.time()
    for epoch in range(num_epochs):
        print(f"--- Epoch {epoch + 1}/{num_epochs} ---")
        epoch_start = time.time()
        # Train one epoch
        train_loss, train_map, train_map50 = train_one_epoch(
            model, optimizer, train_loader, device, mixed_precision=use_mixed_precision
        )
        # Validate
        val_map, val_map50 = validate(model, val_loader, device)
        # Update learning rate
        scheduler.step()
        current_lr = optimizer.param_groups[0]["lr"]
        # Save model if it's the best so far
        map_sum = val_map + val_map50
        if map_sum > best_map:
            best_map = map_sum
            torch.save(model.state_dict(), model_path)
            print(
                f"New best model saved with mAP {val_map:.4f} / mAP@50 {val_map50:.4f}"
            )
        # Report progress
        epoch_time = time.time() - epoch_start
        print(
            f"Epoch completed in {epoch_time:.2f}s. "
            f"Current learning rate: {current_lr:.2e}"
        )
        # Clear memory between epochs
        clear_memory()

    # Report training time
    total_time = time.time() - start_time
    print(f"\nTraining completed in {total_time/60:.2f} minutes.")
    print(f"Best mAP sum: {best_map:.4f}")
    print(f"Final model saved to {model_path}")
    return model

In [None]:
# Cell 12 - Visualization Functions
def ensure_dir(directory):
    """Create ``directory`` (and missing parents) if it doesn't exist.

    Prints a message only when the directory was newly created. Creation is
    attempted directly rather than checked first, so a directory that appears
    between a check and the creation cannot cause a failure.
    """
    try:
        os.makedirs(directory)
    except FileExistsError:
        return
    print(f"Created directory: {directory}")
def visualize_predictions(
    image_path,
    results,
    image_id_map,
    num_samples=5,
    figsize=(15, 10),
    save_dir="visualizations",
):
    """
    Visualize instance segmentation predictions on images.

    For each sampled image one row is drawn: the original image on the left
    and the image with boxes, labels and mask overlay on the right. The figure
    is saved to ``save_dir`` and shown. Nothing is drawn or saved when there is
    no prediction that can be matched to an image file.

    Args:
        image_path: Path to the test images directory
        results: List of prediction results from inference
        image_id_map: Dictionary mapping image filenames to IDs
        num_samples: Number of random samples to visualize
        figsize: Figure size for the plot
        save_dir: Directory to save the visualizations
    """
    # Create save directory
    ensure_dir(save_dir)

    # Group results by image_id
    results_by_image = {}
    for result in results:
        results_by_image.setdefault(result["image_id"], []).append(result)

    # Reverse lookup id -> file name, built once; the first file wins when
    # several files share an id.
    file_by_id = {}
    for file_name, mapped_id in image_id_map.items():
        file_by_id.setdefault(mapped_id, file_name)

    # Only images whose file is known can be drawn.
    candidates = [
        image_id for image_id in results_by_image if image_id in file_by_id
    ]
    if not candidates or num_samples < 1:
        print("No predictions to visualize.")
        return

    # Select random samples
    sample_images = random.sample(candidates, min(num_samples, len(candidates)))

    # One row per image actually drawn; squeeze=False keeps axes 2-D even for
    # a single row.
    fig, axes = plt.subplots(len(sample_images), 2, figsize=figsize, squeeze=False)

    # Define colors for different categories (RGB values)
    colors = [
        (1.0, 0.0, 0.0),  # Red
        (0.0, 1.0, 0.0),  # Green
        (0.0, 0.0, 1.0),  # Blue
        (1.0, 1.0, 0.0),  # Yellow
    ]

    for row, image_id in enumerate(sample_images):
        image_file = file_by_id[image_id]
        original_ax, prediction_ax = axes[row, 0], axes[row, 1]

        # Load and display original image
        img = Image.open(os.path.join(image_path, image_file)).convert("RGB")
        original_ax.imshow(img)
        original_ax.set_title(f"Original Image: {image_file}")
        original_ax.axis("off")

        # Display predictions
        prediction_ax.imshow(img)
        # RGBA overlay that starts fully transparent
        mask_overlay = np.zeros((img.height, img.width, 4), dtype=np.float32)
        for pred in results_by_image[image_id]:
            bbox = pred["bbox"]
            category_id = pred["category_id"]
            score = pred["score"]
            # Categories 1-4 map to the four colors; the modulo keeps any
            # other id from raising an IndexError.
            color = colors[(category_id - 1) % len(colors)]
            # Draw bounding box (bbox is x, y, width, height)
            prediction_ax.add_patch(
                patches.Rectangle(
                    (bbox[0], bbox[1]),
                    bbox[2],
                    bbox[3],
                    linewidth=2,
                    edgecolor=color,
                    facecolor="none",
                )
            )
            # Add label
            prediction_ax.text(
                bbox[0],
                bbox[1] - 5,
                f"Class {category_id} ({score:.2f})",
                color=color,
                fontsize=8,
                bbox=dict(facecolor="white", alpha=0.7),
            )
            # Decode and overlay mask
            if "segmentation" in pred:
                mask = decode_maskobj(pred["segmentation"])
                mask_overlay[mask > 0] = [*color, 0.3]

        # Overlay masks
        prediction_ax.imshow(mask_overlay)
        prediction_ax.set_title(f"Predictions: {image_file}")
        prediction_ax.axis("off")

    plt.tight_layout()
    # Save the figure
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    save_path = os.path.join(save_dir, f"predictions_{timestamp}.png")
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    print(f"Saved predictions visualization to: {save_path}")
    plt.show()
    plt.close()

def plot_training_metrics(
    train_losses, train_maps, val_maps, figsize=(15, 5),
    save_dir="visualizations"
):
    """
    Draw the loss curve and the mAP curves side by side and save the figure.

    Args:
        train_losses: Per-epoch training losses
        train_maps: Per-epoch training mAP scores
        val_maps: Per-epoch validation mAP scores
        figsize: Size of the figure
        save_dir: Folder the PNG is written to (created when missing)
    """
    ensure_dir(save_dir)
    fig, panels = plt.subplots(1, 2, figsize=figsize)
    loss_panel, map_panel = panels

    # Left panel: training loss
    loss_panel.plot(train_losses, label="Training Loss")
    loss_panel.set_xlabel("Epoch")
    loss_panel.set_ylabel("Loss")
    loss_panel.set_title("Training Loss over Epochs")
    loss_panel.legend()
    loss_panel.grid(True)

    # Right panel: training and validation mAP
    for series, series_label in (
        (train_maps, "Training mAP"),
        (val_maps, "Validation mAP"),
    ):
        map_panel.plot(series, label=series_label)
    map_panel.set_xlabel("Epoch")
    map_panel.set_ylabel("mAP")
    map_panel.set_title("mAP Scores over Epochs")
    map_panel.legend()
    map_panel.grid(True)

    plt.tight_layout()

    # Write the figure under a timestamped name, then display it
    stamp = time.strftime("%Y%m%d_%H%M%S")
    out_file = os.path.join(save_dir, f"training_metrics_{stamp}.png")
    plt.savefig(out_file, dpi=300, bbox_inches="tight")
    print(f"Saved training metrics plot to: {out_file}")
    plt.show()
    plt.close()

In [None]:
# Cell 13 - Main (Testing/Inference)
def main_test(model=None):
    """
    Run inference on the test images, save the results and visualize them.

    Args:
        model: Model to use. When None, a model is built and the checkpoint at
            ``model_path`` is loaded; if no checkpoint exists the function
            prints a message and returns.

    The ROI score threshold of the model is changed for the duration of the
    inference run and restored afterwards.
    """
    # Load paths
    test_root = "/kaggle/input/dataset/test_release"
    image_id_map_path = "/kaggle/input/dataset/test_image_name_to_ids.json"
    model_path = "/kaggle/working/maskrcnn_model.pth"
    results_path = "test-results.json"
    # Create visualization directory
    vis_dir = "/kaggle/working/visualizations"
    ensure_dir(vis_dir)
    # Load image ID mapping (either a dict or a list of file_name/id records)
    with open(image_id_map_path, "r") as f:
        image_id_map = json.load(f)
    if isinstance(image_id_map, list):
        image_id_map = {item["file_name"]: item["id"] for item in image_id_map}
    print(f"Loaded image ID mapping with {len(image_id_map)} entries")
    # Set up device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load model if not provided
    if model is None:
        if not os.path.exists(model_path):
            print("No trained model found. Run training first!")
            return
        model_type = "resnet50_v2"  # Use the more powerful backbone for inference
        model = get_model(num_classes=5, model_type=model_type)
        model.to(device)
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        model.load_state_dict(torch.load(model_path, map_location=device))
        print("Loaded pretrained model for inference.")
    # Set model to evaluation mode
    model.eval()
    # Override the score threshold of the model's ROI heads for this run
    has_roi_heads = hasattr(model, "roi_heads")
    if has_roi_heads:
        original_threshold = model.roi_heads.score_thresh
        model.roi_heads.score_thresh = 0.1
        print(f"Set model ROI score threshold to {model.roi_heads.score_thresh}")
    try:
        inference(
            model,
            test_root,
            results_path,
            image_id_map,
            confidence_threshold=0.5,
            mask_threshold=0.6,
        )
    finally:
        # Leave the caller's model configured as it was passed in.
        if has_roi_heads:
            model.roi_heads.score_thresh = original_threshold
    print("Inference complete. Results saved to test-results.json")
    # Load results for visualization
    with open(results_path, "r") as f:
        results = json.load(f)
    if not results:
        # min/max of an empty list would raise; there is nothing to show.
        print("\nNo predictions were generated; skipping visualization.")
        return
    # Visualize predictions
    print("\nVisualizing predictions...")
    visualize_predictions(
        test_root, results, image_id_map, num_samples=5, save_dir=vis_dir
    )
    # Print some statistics about the predictions
    print("\nPrediction Statistics:")
    scores = [r["score"] for r in results]
    print(f"Average confidence score: {np.mean(scores):.4f}")
    print(f"Min confidence score: {np.min(scores):.4f}")
    print(f"Max confidence score: {np.max(scores):.4f}")
    # Count predictions by category
    categories = {}
    for r in results:
        cat = r["category_id"]
        categories[cat] = categories.get(cat, 0) + 1
    print("\nPredictions by category:")
    for cat, count in sorted(categories.items()):
        print(f"  Category {cat}: {count} predictions")

In [None]:
# Train the model (or load the saved checkpoint if one exists) and keep it for inference.
model = main_train()

In [None]:
# Run inference on the test images with the model returned by main_train().
main_test(model)