In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!mkdir dataset
!unzip -q /content/drive/MyDrive/MedicalData/LizardDataset.zip -d dataset

In [3]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import torch

CLASSES = ['Neutrophil', 'Epithelial', 'Lymphocyte', 'Plasma', 'Neutrophil', 'Connective tissue']

def visualize_lizard_sample(image, inst_map, labels, bboxes, centroids, class_names=None):
    # Convert tensors to NumPy arrays
    img = image.detach().cpu()
    H, W = img.shape[1], img.shape[2]
    image_np = img.permute(1, 2, 0).numpy().astype(np.float32)
    image_np = (image_np * 255).astype(np.uint8)     # De-normalize for display if needed

    inst_map_np = inst_map.detach().cpu().numpy()
    bboxes_np = bboxes.detach().cpu().numpy().astype(np.float32)
    cents_np = centroids.detach().cpu().numpy().astype(np.float32)
    labels_np = labels.detach().cpu().numpy().astype(np.int32)

    fig, axs = plt.subplots(1, 3, figsize=(18, 6))

    # 1. Original Image
    axs[0].imshow(image_np)
    axs[0].set_title("Original Image")
    axs[0].axis('off')

    # 2. Instance Segmentation Map
    axs[1].imshow(inst_map_np, cmap='nipy_spectral')
    axs[1].set_title("Instance Segmentation")
    axs[1].axis('off')

    # 3. Bounding Boxes, Centroids, Class Labels
    axs[2].imshow(image_np)
    axs[2].set_title("BBoxes, Centroids, Labels")
    axs[2].axis('off')

    for i in range(len(labels)):
        x1, y1, x2, y2 = bboxes[i]
        cx, cy = centroids[i]
        class_id = labels[i] - 1
        if class_names:
            class_label = str(class_id.item()) # str(class_names[class_id])

        # Draw bounding box
        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                 linewidth=2, edgecolor='red', facecolor='none')
        axs[2].add_patch(rect)

        # Draw centroid
        axs[2].plot(cx, cy, 'bo', markersize=4)

        # Label text
        axs[2].text(x1, y1 - 5, class_label, color='yellow', fontsize=8, backgroundcolor='black')

    plt.tight_layout()
    plt.show()

In [4]:
import os
from pathlib import Path

import numpy as np
import scipy.io as sio
from PIL import Image

import torch
from torch.utils.data import Dataset

import albumentations as A
from albumentations.pytorch import ToTensorV2

import matplotlib.pyplot as plt
import matplotlib.patches as patches

class LizardDataset(Dataset):
    def __init__(self, image_root, label_root, transform=None, debug=False, padding_colour=(114, 114, 114)):
        self.image_paths, self.label_paths = [], []
        self.transform = transform
        self.debug = debug

        # Gather all image-label pairs robustly
        for subdir in sorted(os.listdir(image_root)):
            subdir_path = os.path.join(image_root, subdir)
            if not (os.path.isdir(subdir_path) and subdir in ['lizard_images1', 'lizard_images2']):
                continue

            # find first child directory deterministically
            child_dirs = sorted([d for d in os.listdir(subdir_path)
                                 if os.path.isdir(os.path.join(subdir_path, d))])
            if not child_dirs:
                continue
            image_folder = os.path.join(subdir_path, child_dirs[0])

            for fname in sorted(os.listdir(image_folder)):
                if fname.lower().endswith(('.jpg', '.png', '.jpeg', '.tif', '.tiff')):
                    img_p = os.path.join(image_folder, fname)
                    mat_name = os.path.splitext(fname)[0] + '.mat'
                    lbl_p = os.path.join(label_root, 'Lizard_Labels', 'Labels', mat_name)
                    if os.path.exists(lbl_p):
                        self.image_paths.append(img_p)
                        self.label_paths.append(lbl_p)

        self.aug = A.Compose(
            [
                A.LongestMaxSize(max_size=512),              # keep aspect ratio
                A.PadIfNeeded(min_height=512, min_width=512, position='center', border_mode=0, value=padding_colour, mask_value=0),
                # A.HorizontalFlip(p=0.5),
                # A.RandomBrightnessContrast(p=0.2),
                # A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                ToTensorV2(),
            ],
            bbox_params=A.BboxParams(format='pascal_voc', label_fields=['bbox_labels']),
            # Keep keypoints if you still want them; you can also omit keypoints entirely and recompute from boxes
            keypoint_params=A.KeypointParams(format='xy', remove_invisible=False),
        )

        print(f"Found {len(self.image_paths)} images and {len(self.label_paths)} labels.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label_path = self.label_paths[idx]

        image = np.array(Image.open(img_path).convert('RGB'))
        label = sio.loadmat(label_path)

        inst_map = np.asarray(label['inst_map']).astype(np.uint8)
        bboxs = np.asarray(label['bbox']).squeeze()       # (N, 4): (y1, y2, x1, x2)
        centroids = np.asarray(label['centroid']).squeeze()
        classes = np.asarray(label['class']).squeeze()

        # ensure consistent shapes for N==1
        if bboxs.ndim == 1 and bboxs.size == 4:
            bboxs = bboxs[None, :]
        if centroids.ndim == 1 and centroids.size == 2:
            centroids = centroids[None, :]
        if classes.ndim == 0:
            classes = classes[None]

        bbox_list, bbox_labels, kpts = [], [], []
        for i in range(len(bboxs)):
            y1, y2, x1, x2 = bboxs[i].astype(float)
            bbox_list.append([x1, y1, x2, y2])  # pascal_voc: x_min,y_min,x_max,y_max
            bbox_labels.append(int(classes[i]))
            kpts.append((float(centroids[i][0]), float(centroids[i][1])))

        if len(bbox_list) == 0:
            aug = self.aug(image=image, masks=[inst_map], bboxes=[], bbox_labels=[], keypoints=[])
            return (
                aug['image'],
                (
                    torch.tensor(aug['masks'][0], dtype=torch.long),
                    torch.empty(0, dtype=torch.int64),
                    torch.empty((0, 4), dtype=torch.float32),
                    torch.empty((0, 2), dtype=torch.float32),
                ),
            )

        if self.debug:
            img_before_t = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
            inst_before_t = torch.from_numpy(inst_map.astype(np.int64))
            boxes_before_t = torch.tensor(bbox_list, dtype=torch.float32) if len(bbox_list) else torch.empty((0,4),dtype=torch.float32)
            cents_before_t = torch.tensor(kpts, dtype=torch.float32) if len(kpts) else torch.empty((0,2),dtype=torch.float32)
            labels_before_t = torch.tensor(bbox_labels, dtype=torch.int64) if len(bbox_labels) else torch.empty((0,),dtype=torch.int64)
            print(img_before_t.shape)
            # labels_before_t = labels_before_t.numpy()

            visualize_lizard_sample(img_before_t, inst_before_t, labels_before_t, boxes_before_t, cents_before_t, class_names=CLASSES)

        aug = self.aug(
            image=image,
            masks=[inst_map],
            bboxes=bbox_list,
            bbox_labels=bbox_labels,
            keypoints=kpts,
        )

        image_t = aug['image'] / 255.0
        inst_t = torch.tensor(aug['masks'][0], dtype=torch.long)
        bboxes_t = torch.tensor(aug['bboxes'], dtype=torch.float32)

        # Recompute centroids from post-aug boxes (robust to filtering)
        if bboxes_t.numel() > 0:
            cx = (bboxes_t[:, 0] + bboxes_t[:, 2]) * 0.5
            cy = (bboxes_t[:, 1] + bboxes_t[:, 3]) * 0.5
            centroids_t = torch.stack([cx, cy], dim=1)
        else:
            centroids_t = torch.empty((0, 2), dtype=torch.float32)

        labels_t = torch.tensor(aug['bbox_labels'], dtype=torch.int64)

        # if self.debug:
        #    visualize_lizard_sample(image_t, inst_t, labels_t, bboxes_t, centroids_t, class_names=CLASSES)

        return image_t, (inst_t, labels_t, bboxes_t, centroids_t)

In [5]:
from torch.utils.data import DataLoader

def custom_collate_fn(batch):
    images = torch.stack([item[0] for item in batch])
    inst_maps = torch.stack([item[1][0] for item in batch])
    classes = [item[1][1] for item in batch]
    bboxes = [item[1][2] for item in batch]
    centroids = [item[1][3] for item in batch]

    return images, (inst_maps, classes, bboxes, centroids)

dataset = LizardDataset(
    image_root='/content/dataset/',
    label_root='/content/dataset/lizard_labels/',
    transform=None,
    debug = False,
    padding_colour=(114, 114, 114)
)

dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=custom_collate_fn, num_workers=0) # workers has to be 0 while debugging

Found 238 images and 238 labels.


  A.PadIfNeeded(min_height=512, min_width=512, position='center', border_mode=0, value=padding_colour, mask_value=0),


# Visualisation

In [None]:
for batch_id, (images, (inst_maps, classes, bboxes, centroids)) in enumerate(dataloader):
    # print(f"Image batch shape:     {images.shape}")        # [B, 3, 256, 256]
    # print(f"Inst map batch shape:  {inst_maps.shape}")     # [B, 256, 256]")
    # print(f"Num images in batch:   {len(classes)}")         # B
    # print(f"Classes[0] shape:      {classes[0].shape}")     # [N_0]
    # print(classes[0])
    #print(f"BBoxes[0] shape:       {bboxes[0].shape}")      # [N_0, 4]
    # print(f"Centroids[0] shape:    {centroids[0].shape}")   # [N_0, 2]
    visualize_lizard_sample(images[0], inst_maps[0], classes[0], bboxes[0], centroids[0], class_names=CLASSES)
    if (batch_id + 1) % 5 == 0:
      break

Output hidden; open in https://colab.research.google.com to view.

In [None]:
def visualize_lizard_sample_with_semantic_segmentation(image, inst_map, semantic_map, labels, bboxes, centroids, class_names=None):
    # Convert tensors to NumPy arrays (same style as your first function)
    img = image.detach().cpu()
    H, W = img.shape[1], img.shape[2]
    image_np = img.permute(1, 2, 0).numpy().astype(np.float32)
    image_np = (image_np * 255).astype(np.uint8)  # de-normalize for display if needed

    inst_map_np = inst_map.detach().cpu().numpy()
    semantic_map_np = semantic_map.detach().cpu().numpy()
    bboxes_np = bboxes.detach().cpu().numpy().astype(np.float32)
    cents_np = centroids.detach().cpu().numpy().astype(np.float32)
    labels_np = labels.detach().cpu().numpy().astype(np.int32)

    fig, axs = plt.subplots(1, 4, figsize=(24, 6))

    # 1. Original Image
    axs[0].imshow(image_np)
    axs[0].set_title("Original Image")
    axs[0].axis('off')

    # 2. Instance Segmentation Map
    axs[1].imshow(inst_map_np, cmap='nipy_spectral')
    axs[1].set_title("Instance Segmentation")
    axs[1].axis('off')

    # 3. Semantic Segmentation Map
    axs[2].imshow(semantic_map_np, cmap='gray')
    axs[2].set_title("Semantic Segmentation")
    axs[2].axis('off')

    # 4. BBoxes, Centroids, Class Labels
    axs[3].imshow(image_np)
    axs[3].set_title("BBoxes, Centroids, Labels")
    axs[3].axis('off')

    for i in range(len(labels_np)):
        x1, y1, x2, y2 = bboxes_np[i]
        cx, cy = cents_np[i]
        class_id = labels_np[i] - 1

        # Keep behavior similar to your first function; ensure class_label is always defined
        if class_names:
            # If you later want names, swap to: str(class_names[class_id])
            class_label = str(int(class_id))
        else:
            class_label = str(int(class_id))

        # Draw bounding box
        rect = patches.Rectangle(
            (x1, y1), x2 - x1, y2 - y1,
            linewidth=2, edgecolor='red', facecolor='none'
        )
        axs[3].add_patch(rect)

        # Draw centroid
        axs[3].plot(cx, cy, 'bo', markersize=4)

        # Label text
        axs[3].text(x1, max(y1 - 5, 0), class_label,
                    color='yellow', fontsize=8, backgroundcolor='black')

    plt.tight_layout()
    plt.show()

In [None]:
# Fetch a batch
for batch_id, (images, (inst_maps, classes, bboxes, centroids)) in enumerate(dataloader):
    semantic_maps = (inst_maps > 0).long()

    visualize_lizard_sample_with_semantic_segmentation(images[0], inst_maps[0], semantic_maps[0], classes[0], bboxes[0], centroids[0], class_names=CLASSES)
    if (batch_id + 1) % 5 == 0:
        break

Output hidden; open in https://colab.research.google.com to view.

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.colors as mcolors
import numpy as np
import torch

def visualize_lizard_sample_with_semantic_segmentation(image, inst_map, semantic_map, labels, bboxes, centroids, class_names=None):
    image_np = image.permute(1, 2, 0).cpu().numpy()
    image_np = (image_np * 255).astype(np.uint8)

    class_names = ['Background'] + class_names

    inst_map_np = inst_map.cpu().numpy()
    semantic_map_np = semantic_map.cpu().numpy()
    bboxes = bboxes.cpu().numpy()
    centroids = centroids.cpu().numpy()
    labels = labels.cpu().numpy()

    fig, axs = plt.subplots(1, 4, figsize=(28, 6))

    # 1. Original Image
    axs[0].imshow(image_np)
    axs[0].set_title("Original Image")
    axs[0].axis('off')

    # 2. Instance Segmentation
    axs[1].imshow(inst_map_np, cmap='nipy_spectral')
    axs[1].set_title("Instance Segmentation")
    axs[1].axis('off')

    # 3. Semantic Segmentation (Multi-class)
    # Use tab10 colormap (10 discrete colors), set bounds and ticks
    num_classes = 7
    cmap = plt.get_cmap('tab10', num_classes)
    bounds = np.arange(num_classes + 1) - 0.5
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    im = axs[2].imshow(semantic_map_np, cmap=cmap, norm=norm)
    axs[2].set_title("Semantic Segmentation (0–6)")
    axs[2].axis('off')
    cbar = plt.colorbar(im, ax=axs[2], ticks=range(num_classes), shrink=0.8)
    if class_names:
        cbar.ax.set_yticklabels([class_names[i] for i in range(num_classes)])

    # 4. Bounding Boxes, Centroids, Labels
    axs[3].imshow(image_np)
    axs[3].set_title("BBoxes, Centroids, Labels")
    axs[3].axis('off')

    for i in range(len(labels)):
        x1, y1, x2, y2 = bboxes[i]
        cx, cy = centroids[i]
        class_id = labels[i] - 1
        if class_names:
            class_label = str(class_id) # str(class_names[class_id])

        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                 linewidth=2, edgecolor='red', facecolor='none')
        axs[3].add_patch(rect)
        axs[3].plot(cx, cy, 'bo', markersize=4)
        axs[3].text(x1, y1 - 5, class_label, color='yellow', fontsize=8, backgroundcolor='black')

    plt.tight_layout()
    plt.show()


In [None]:
for batch_id, (images, (inst_maps, classes, bboxes, centroids)) in enumerate(dataloader):
    for i in range(images.shape[0]):
        image = images[i]
        inst_map = inst_maps[i]
        class_ids = classes[i]

        inst_map_np = inst_map.cpu().numpy()
        class_np = class_ids.cpu().numpy()

        semantic_map = np.zeros_like(inst_map_np, dtype=np.uint8)

        instance_ids = np.unique(inst_map_np)
        instance_ids = instance_ids[instance_ids != 0]

        for j, inst_id in enumerate(instance_ids):
            if j < len(class_np):
                semantic_map[inst_map_np == inst_id] = class_np[j]

        semantic_map = torch.tensor(semantic_map, dtype=torch.int64)

        # Visualize
        visualize_lizard_sample_with_semantic_segmentation(
            image, inst_map, semantic_map, class_ids, bboxes[i], centroids[i], class_names=CLASSES
        )

    if (batch_id + 1) % 5 == 0:
        break

Output hidden; open in https://colab.research.google.com to view.

In [None]:
for batch_id, (images, (inst_maps, classes, bboxes, centroids)) in enumerate(dataloader):
    for i in range(images.shape[0]):
        image = images[i]
        inst_map = inst_maps[i]
        class_ids = classes[i]
        box_list = bboxes[i]

        inst_map_np = inst_map.cpu().numpy()
        class_np = class_ids.cpu().numpy()
        box_np = box_list.cpu().numpy()

        semantic_map = np.zeros_like(inst_map_np, dtype=np.uint8)

        # Go through all bounding boxes
        for j in range(len(box_np)):
            x1, y1, x2, y2 = box_np[j].astype(int)
            cls = class_np[j]

            # Clip to valid range
            x1, y1 = max(x1, 0), max(y1, 0)
            x2, y2 = min(x2, inst_map_np.shape[1]), min(y2, inst_map_np.shape[0])

            # Crop region from inst_map
            region = inst_map_np[y1:y2, x1:x2]
            instance_ids_in_box = np.unique(region)
            instance_ids_in_box = instance_ids_in_box[instance_ids_in_box != 0]

            for inst_id in instance_ids_in_box:
                semantic_map[inst_map_np == inst_id] = cls

        semantic_map = torch.tensor(semantic_map, dtype=torch.int64)

        # Visualize
        visualize_lizard_sample_with_semantic_segmentation(
            image, inst_map, semantic_map, class_ids, box_list, centroids[i], class_names=CLASSES
        )

    if (batch_id + 1) % 5 == 0:
        break

Output hidden; open in https://colab.research.google.com to view.

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.colors as mcolors
import numpy as np
import torch

def compare_semantic_segmentations_with_detections(
    image,
    semantic_binary,
    semantic_from_instance,
    semantic_from_bbox,
    bboxes,
    centroids,
    class_labels,
    class_names=None
):
    class_names = ['Background'] + class_names

    # Convert all tensors to NumPy arrays
    image_np = image.permute(1, 2, 0).cpu().numpy()
    image_np = (image_np * 255).astype(np.uint8)

    sem_bin = semantic_binary.cpu().numpy()
    sem_inst = semantic_from_instance.cpu().numpy()
    sem_bbox = semantic_from_bbox.cpu().numpy()
    diff = (sem_inst != sem_bbox).astype(np.uint8)

    bboxes = bboxes.cpu().numpy()
    centroids = centroids.cpu().numpy()
    class_labels = class_labels.cpu().numpy()

    num_classes = 7
    cmap = plt.get_cmap('tab10', num_classes)
    bounds = np.arange(num_classes + 1) - 0.5
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    fig, axs = plt.subplots(1, 5, figsize=(32, 6))

    # 1. Binary segmentation
    axs[0].imshow(sem_bin, cmap='gray')
    axs[0].set_title("1. Binary Segmentation (0 = BG, 1 = FG)")
    axs[0].axis('off')

    # 2. Semantic from instance IDs
    im2 = axs[1].imshow(sem_inst, cmap=cmap, norm=norm)
    axs[1].set_title("2. Semantic from Instance IDs")
    axs[1].axis('off')
    cbar2 = plt.colorbar(im2, ax=axs[1], ticks=range(num_classes), shrink=0.8)
    if class_names:
        cbar2.ax.set_yticklabels([class_names[i] for i in range(num_classes)])

    # 3. Semantic from bounding boxes
    im3 = axs[2].imshow(sem_bbox, cmap=cmap, norm=norm)
    axs[2].set_title("3. Semantic from Bounding Boxes")
    axs[2].axis('off')
    cbar3 = plt.colorbar(im3, ax=axs[2], ticks=range(num_classes), shrink=0.8)
    if class_names:
        cbar3.ax.set_yticklabels([class_names[i] for i in range(num_classes)])

    # 4. Difference map
    axs[3].imshow(diff, cmap='gray', vmin=0, vmax=1)
    axs[3].set_title("4. Difference Map (White = mismatch)")
    axs[3].axis('off')

    # 5. Detections
    axs[4].imshow(image_np)
    axs[4].set_title("5. Detections (BBoxes + Labels + Centroids)")
    axs[4].axis('off')

    for i in range(len(bboxes)):
        x1, y1, x2, y2 = bboxes[i]
        cx, cy = centroids[i]
        cls_id = class_labels[i]
        label = str(cls_id)# class_names[cls_id] if class_names else str(cls_id)

        # Draw bounding box
        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                 linewidth=2, edgecolor='red', facecolor='none')
        axs[4].add_patch(rect)

        # Draw centroid
        axs[4].plot(cx, cy, 'bo', markersize=4)

        # Draw label
        axs[4].text(x1, y1 - 5, label, color='yellow', fontsize=8, backgroundcolor='black')

    plt.tight_layout()
    plt.show()

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.colors as mcolors
import numpy as np
import torch
from scipy.ndimage import label as cc_label, generate_binary_structure

def compare_semantic_segmentations_with_detections2(
    image,
    semantic_binary,
    semantic_from_instance,
    semantic_from_bbox,
    bboxes,
    centroids,
    class_labels,
    class_names=None,
    connectivity=8,      # 4 or 8 connectivity
    min_cc_area=1        # ignore components smaller than this (in pixels)
):
    if class_names is not None:
        class_names = ['Background'] + list(class_names)

    # --- Tensors -> NumPy ---
    image_np = image.detach().cpu().permute(1, 2, 0).numpy()
    image_np = (np.clip(image_np, 0, 1) * 255).astype(np.uint8)

    sem_bin = semantic_binary.detach().cpu().numpy().astype(bool)
    sem_inst = semantic_from_instance.detach().cpu().numpy()
    sem_bbox = semantic_from_bbox.detach().cpu().numpy()
    diff = (sem_inst != sem_bbox).astype(np.uint8)

    bboxes = bboxes.detach().cpu().numpy()
    centroids = centroids.detach().cpu().numpy()
    class_labels = class_labels.detach().cpu().numpy()

    # --- Connected components on the binary map ---
    structure = generate_binary_structure(2, 2 if connectivity == 8 else 1)
    labeled_cc, num_cc = cc_label(sem_bin, structure=structure)

    # Optional area filtering
    if min_cc_area > 1 and num_cc > 0:
        areas = np.bincount(labeled_cc.ravel())[1:]  # skip background (0)
        keep_labels = np.where(areas >= min_cc_area)[0] + 1
        num_cc_kept = int(keep_labels.size)
    else:
        num_cc_kept = int(num_cc)

    print("Length class labels", len(class_labels))
    print("Length bboxes", len(bboxes))
    print(f"Connected components (all): {num_cc}")
    if min_cc_area > 1:
        print(f"Connected components (area ≥ {min_cc_area}): {num_cc_kept}")

    # --- Colormap for semantic maps ---
    if class_names is not None:
        num_classes = len(class_names)
    else:
        num_classes = int(max(sem_inst.max(), sem_bbox.max())) + 1 if sem_inst.size else 1

    cmap_name = 'tab10' if num_classes <= 10 else 'tab20'
    cmap = plt.get_cmap(cmap_name, num_classes)
    bounds = np.arange(num_classes + 1) - 0.5
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    # --- Figure ---
    fig, axs = plt.subplots(1, 5, figsize=(32, 6))

    # 1. Binary segmentation (+ CC count)
    axs[0].imshow(sem_bin, cmap='gray')
    title_cc = f"1. Binary Segmentation (CCs={num_cc_kept}" + (f", min_area={min_cc_area}" if min_cc_area > 1 else "") + ")"
    axs[0].set_title(title_cc)
    axs[0].axis('off')

    # 2. Semantic from instance IDs
    im2 = axs[1].imshow(sem_inst, cmap=cmap, norm=norm)
    axs[1].set_title("2. Semantic from Instance IDs")
    axs[1].axis('off')
    cbar2 = plt.colorbar(im2, ax=axs[1], ticks=range(num_classes), shrink=0.8)
    if class_names is not None:
        cbar2.ax.set_yticklabels(class_names)

    # 3. Semantic from bounding boxes
    im3 = axs[2].imshow(sem_bbox, cmap=cmap, norm=norm)
    axs[2].set_title("3. Semantic from Bounding Boxes")
    axs[2].axis('off')
    cbar3 = plt.colorbar(im3, ax=axs[2], ticks=range(num_classes), shrink=0.8)
    if class_names is not None:
        cbar3.ax.set_yticklabels(class_names)

    # 4. Difference map
    axs[3].imshow(diff, cmap='gray', vmin=0, vmax=1)
    axs[3].set_title("4. Difference Map (White = mismatch)")
    axs[3].axis('off')

    # 5. Detections
    axs[4].imshow(image_np)
    axs[4].set_title("5. Detections (BBoxes + Labels + Centroids)")
    axs[4].axis('off')

    for i in range(len(bboxes)):
        x1, y1, x2, y2 = bboxes[i]
        cx, cy = centroids[i]
        cls_id = class_labels[i]
        label_txt = str(int(cls_id))  # or map via class_names if desired

        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                 linewidth=2, edgecolor='red', facecolor='none')
        axs[4].add_patch(rect)
        axs[4].plot(cx, cy, 'bo', markersize=4)
        axs[4].text(x1, max(y1 - 5, 0), label_txt, color='yellow',
                    fontsize=8, backgroundcolor='black')

    plt.tight_layout()
    plt.show()

    # Optionally return counts if you want to use them programmatically
    return num_cc, num_cc_kept


In [None]:
for batch_id, (images, (inst_maps, classes, bboxes, centroids)) in enumerate(dataloader):
    for i in range(images.shape[0]):
        image = images[i]
        inst_map = inst_maps[i]
        class_ids = classes[i]
        box_list = bboxes[i]
        centroid_list = centroids[i]

        inst_map_np = inst_map.cpu().numpy()
        class_np = class_ids.cpu().numpy()
        box_np = box_list.cpu().numpy()

        # 1. OLD semantic (binary)
        semantic_old = (inst_map > 0).long()

        # 2. New semantic via bounding boxes
        semantic_from_bbox = np.zeros_like(inst_map_np, dtype=np.uint8)
        for j in range(len(box_np)):
            x1, y1, x2, y2 = box_np[j].astype(int)
            cls = class_np[j]
            x1, y1 = max(x1, 0), max(y1, 0)
            x2, y2 = min(x2, inst_map_np.shape[1]), min(y2, inst_map_np.shape[0])
            region = inst_map_np[y1:y2, x1:x2]
            instance_ids_in_box = np.unique(region)
            instance_ids_in_box = instance_ids_in_box[instance_ids_in_box != 0]
            for inst_id in instance_ids_in_box:
                semantic_from_bbox[inst_map_np == inst_id] = cls
        semantic_from_bbox = torch.tensor(semantic_from_bbox, dtype=torch.int64)

        # 3. New semantic via instance order
        semantic_from_instance_order = np.zeros_like(inst_map_np, dtype=np.uint8)
        instance_ids = np.unique(inst_map_np)
        instance_ids = instance_ids[instance_ids != 0]
        for j, inst_id in enumerate(instance_ids):
            if j < len(class_np):
                semantic_from_instance_order[inst_map_np == inst_id] = class_np[j]
        semantic_from_instance_order = torch.tensor(semantic_from_instance_order, dtype=torch.int64)

        # Convert inputs to tensors before passing
        box_tensor = torch.tensor(box_list, dtype=torch.float32)
        centroid_tensor = torch.tensor(centroid_list, dtype=torch.float32)
        class_tensor = torch.tensor(class_ids, dtype=torch.int64)

        # Visualize comparison
        compare_semantic_segmentations_with_detections2(
            image,
            semantic_old,
            semantic_from_instance_order,
            semantic_from_bbox,
            box_tensor,
            centroid_tensor,
            class_tensor,
            CLASSES
        )

    if (batch_id + 1) % 5 == 0:
        break

Output hidden; open in https://colab.research.google.com to view.

In [None]:
# Fetch a batch
for batch_id, (images, (inst_maps, classes, bboxes, centroids)) in enumerate(dataloader):
    semantic_maps = (inst_maps > 0).long()

    visualize_lizard_sample_with_semantic_segmentation(images[0], inst_maps[0], semantic_maps[0], classes[0], bboxes[0], centroids[0], class_names=CLASSES)
    if (batch_id + 1) % 5 == 0:
        break

# Networks

## Segmentation

In [None]:
import os
import json

# ==== EXPERIMENT CONFIGURATION ====
EXP_NO = 0
LEARNING_RATE = 1e-3
BATCH_SIZE = 8
NUM_EPOCHS = 20
NUM_CLASSES = 7
K_FOLDS = 5
RANDOM_SEED = 42
NOTEBOOK_VERSION = 1
OPTIMIZER = "adam"  # "adam" or "sgd"

# ==== DIRECTORY SETUP ====
ROOT_EXPERIMENT = "/content/drive/MyDrive/MedicalData/Experiments/"
EXPERIMENT_PATH = os.path.join(ROOT_EXPERIMENT, f"exp_{EXP_NO}")
os.makedirs(EXPERIMENT_PATH, exist_ok=True)

# ==== SAVE CONFIGURATION ====
CONFIG = {
    "EXP_NO": EXP_NO,
    "LEARNING_RATE": LEARNING_RATE,
    "BATCH_SIZE": BATCH_SIZE,
    "NUM_EPOCHS": NUM_EPOCHS,
    "NUM_CLASSES": NUM_CLASSES,
    "K_FOLDS": K_FOLDS,
    "RANDOM_SEED": RANDOM_SEED,
    "NOTEBOOK_VERSION": NOTEBOOK_VERSION,
    "OPTIMIZER": OPTIMIZER.upper()
}

with open(os.path.join(EXPERIMENT_PATH, "config.json"), "w") as f:
    json.dump(CONFIG, f, indent=4)

print(f"Experiment directory and config saved at: {EXPERIMENT_PATH}")

Experiment directory and config saved at: /content/drive/MyDrive/MedicalData/Experiments/exp_0


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    def __init__(self, in_channels=3, out_classes=7):
        super(UNet, self).__init__()

        def conv_block(in_c, out_c):
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, 3, padding=1),
                nn.ReLU(inplace=True)
            )

        self.enc1 = conv_block(in_channels, 64)
        self.enc2 = conv_block(64, 128)
        self.enc3 = conv_block(128, 256)

        self.pool = nn.MaxPool2d(2)

        # Decoder blocks expect concatenated channels (enc + dec)
        self.dec3 = conv_block(256 + 128, 128)
        self.dec2 = conv_block(128 + 64, 64)

        self.final = nn.Conv2d(64, out_classes, kernel_size=1)

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)               # [B, 64, H, W]
        e2 = self.enc2(self.pool(e1))   # [B, 128, H/2, W/2]
        e3 = self.enc3(self.pool(e2))   # [B, 256, H/4, W/4]

        # Decoder
        d3 = F.interpolate(e3, scale_factor=2, mode='bilinear', align_corners=False)
        d3 = self.dec3(torch.cat([d3, e2], dim=1))  # [B, 256+128, H/2, W/2] → [B, 128, H/2, W/2]

        d2 = F.interpolate(d3, scale_factor=2, mode='bilinear', align_corners=False)
        d2 = self.dec2(torch.cat([d2, e1], dim=1))  # [B, 128+64, H, W] → [B, 64, H, W]

        out = self.final(d2)  # [B, num_classes, H, W]
        return out

In [None]:
import os
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Subset, DataLoader
from sklearn.model_selection import KFold

def compute_iou(preds, targets, num_classes):
    ious = []
    preds = preds.view(-1)
    targets = targets.view(-1)
    for cls in range(num_classes):
        pred_inds = preds == cls
        target_inds = targets == cls
        intersection = (pred_inds & target_inds).sum().item()
        union = (pred_inds | target_inds).sum().item()
        if union == 0:
            ious.append(float('nan'))
        else:
            ious.append(intersection / union)
    mean_iou = np.nanmean(ious)
    return mean_iou, ious

def run_epoch(model, dataloader, optimizer, criterion, device, num_classes, is_train=True):
    model.train() if is_train else model.eval()

    total_loss = 0
    total_iou = 0
    total_class_ious = np.zeros(num_classes)
    class_counts = np.zeros(num_classes)
    num_batches = 0

    for images, (inst_maps, _, _, _) in dataloader:
        sem_seg_targets = (inst_maps > 0).long()
        images = images.to(device).float()
        targets = sem_seg_targets.to(device)

        if is_train:
            optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, targets)

        if is_train:
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        num_batches += 1

        with torch.no_grad():
            preds = torch.argmax(outputs, dim=1)
            mean_iou, class_ious = compute_iou(preds, targets, num_classes)
            total_iou += mean_iou
            for i, val in enumerate(class_ious):
                if not np.isnan(val):
                    total_class_ious[i] += val
                    class_counts[i] += 1

    avg_loss = total_loss / num_batches
    avg_iou = total_iou / num_batches
    avg_class_ious = total_class_ious / np.maximum(class_counts, 1)

    return avg_loss, avg_iou, avg_class_ious

def plot_metrics(metrics, folder_path):
    epochs = np.arange(1, len(metrics["train_loss"]) + 1)
    plt.figure()
    plt.plot(epochs, metrics["train_loss"], label="Train Loss")
    plt.plot(epochs, metrics["val_loss"], label="Val Loss")
    plt.plot(epochs, metrics["train_miou"], label="Train mIoU")
    plt.plot(epochs, metrics["val_miou"], label="Val mIoU")
    plt.xlabel("Epoch")
    plt.ylabel("Value")
    plt.legend()
    plt.title("Training Metrics")
    plt.savefig(os.path.join(folder_path, "metric_plot.png"))
    plt.close()

def plot_class_iou_hist(class_ious, folder_path):
    final_class_ious = np.array(class_ious[-1])
    plt.figure()
    plt.bar(range(len(final_class_ious)), final_class_ious)
    plt.xlabel("Class")
    plt.ylabel("IoU")
    plt.title("Final Epoch Class IoUs")
    plt.savefig(os.path.join(folder_path, "class_iou_hist.png"))
    plt.close()

def cross_validate(model_class, dataset, device, output_dir):
    kfold = KFold(n_splits=K_FOLDS, shuffle=True, random_state=RANDOM_SEED)
    os.makedirs(output_dir, exist_ok=True)

    for fold, (train_ids, val_ids) in enumerate(kfold.split(dataset)):
        print(f"\n===== Fold {fold+1}/{K_FOLDS} =====")

        fold_dir = os.path.join(output_dir, f"fold_{fold}")
        os.makedirs(fold_dir, exist_ok=True)

        train_loader = DataLoader(Subset(dataset, train_ids), batch_size=BATCH_SIZE, shuffle=True, collate_fn=custom_collate_fn)
        val_loader = DataLoader(Subset(dataset, val_ids), batch_size=BATCH_SIZE, shuffle=False, collate_fn=custom_collate_fn)

        model = model_class().to(device)

        if OPTIMIZER.lower() == "adam":
            optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
        elif OPTIMIZER.lower() == "sgd":
            optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)
        else:
            raise ValueError(f"Unsupported optimizer: {OPTIMIZER}")

        criterion = nn.CrossEntropyLoss()
        best_val_miou = -1
        best_model_path = os.path.join(fold_dir, "model_best.pth")

        metrics = {
            "train_loss": [],
            "val_loss": [],
            "train_miou": [],
            "val_miou": [],
            "val_class_ious": []
        }

        for epoch in range(NUM_EPOCHS):
            train_loss, train_miou, _ = run_epoch(model, train_loader, optimizer, criterion, device, NUM_CLASSES, is_train=True)
            val_loss, val_miou, val_class_ious = run_epoch(model, val_loader, optimizer, criterion, device, NUM_CLASSES, is_train=False)

            metrics["train_loss"].append(train_loss)
            metrics["val_loss"].append(val_loss)
            metrics["train_miou"].append(train_miou)
            metrics["val_miou"].append(val_miou)
            metrics["val_class_ious"].append(val_class_ious)

            print(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f} | Train mIoU: {train_miou:.4f}")
            print(f"           Val Loss: {val_loss:.4f} | Val mIoU: {val_miou:.4f}")

            if val_miou > best_val_miou:
                best_val_miou = val_miou
                torch.save(model.state_dict(), best_model_path)

        np.save(os.path.join(fold_dir, "metrics.npy"), metrics)
        plot_metrics(metrics, fold_dir)
        plot_class_iou_hist(metrics["val_class_ious"], fold_dir)

In [None]:
model = UNet(in_channels=3, out_classes=NUM_CLASSES)

if OPTIMIZER.lower() == "adam":
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
elif OPTIMIZER.lower() == "sgd":
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)
else:
    assert False

cross_validate(
    model_class=lambda: UNet(in_channels=3, out_classes=NUM_CLASSES),
    dataset=dataset,
    device='cuda',
    output_dir=EXPERIMENT_PATH
)


===== Fold 1/5 =====


  inst_map = torch.tensor(aug['masks'][0], dtype=torch.int32)


[Epoch 1] Train Loss: 0.6094 | Train mIoU: 0.3885
           Val Loss: 0.3907 | Val mIoU: 0.4210
[Epoch 2] Train Loss: 0.3881 | Train mIoU: 0.4436
           Val Loss: 0.3299 | Val mIoU: 0.4913
[Epoch 3] Train Loss: 0.3510 | Train mIoU: 0.5264
           Val Loss: 0.3065 | Val mIoU: 0.5520
[Epoch 4] Train Loss: 0.3285 | Train mIoU: 0.5526
           Val Loss: 0.2896 | Val mIoU: 0.5801
[Epoch 5] Train Loss: 0.3238 | Train mIoU: 0.5663
           Val Loss: 0.3010 | Val mIoU: 0.5946
[Epoch 6] Train Loss: 0.3179 | Train mIoU: 0.5722
           Val Loss: 0.2796 | Val mIoU: 0.5956
[Epoch 7] Train Loss: 0.3049 | Train mIoU: 0.5940
           Val Loss: 0.2858 | Val mIoU: 0.5831
[Epoch 8] Train Loss: 0.3149 | Train mIoU: 0.5856
           Val Loss: 0.2714 | Val mIoU: 0.6003
[Epoch 9] Train Loss: 0.3096 | Train mIoU: 0.5859
           Val Loss: 0.2836 | Val mIoU: 0.5684
[Epoch 10] Train Loss: 0.3169 | Train mIoU: 0.5799
           Val Loss: 0.2817 | Val mIoU: 0.6255
[Epoch 11] Train Loss: 0.2923

## Detection

In [None]:
import torchvision
from torchvision.models.detection import maskrcnn_resnet50_fpn

def get_maskrcnn_model(num_classes=6):
    model = maskrcnn_resnet50_fpn(pretrained=True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features

    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
        in_features, num_classes
    )
    model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(
        model.roi_heads.mask_predictor.conv5_mask.in_channels,
        256,
        num_classes
    )
    return model

In [None]:
# --- Helper: build torchvision targets from your batch ---
def _boxes_to_int(b, H, W):
    # clamp and convert to int pixel indices
    x1 = max(0, min(W - 1, int(round(float(b[0])))))
    y1 = max(0, min(H - 1, int(round(float(b[1])))))
    x2 = max(0, min(W - 1, int(round(float(b[2])))))
    y2 = max(0, min(H - 1, int(round(float(b[3])))))
    # ensure proper ordering
    x1, x2 = (x1, x2) if x1 <= x2 else (x2, x1)
    y1, y2 = (y1, y2) if y1 <= y2 else (y2, y1)
    return x1, y1, x2, y2

def _derive_instance_masks_from_inst_map(inst_map_t, boxes_t, labels_t):
    """
    Convert a per-pixel inst_map (H,W) with unique instance IDs into a stack of instance masks (N,H,W),
    one per bbox/label. We map each bbox to the most frequent non-zero instance id inside the box.
    """
    H, W = inst_map_t.shape[-2], inst_map_t.shape[-1]
    inst_np = inst_map_t.cpu().numpy()  # int64/long
    masks = []
    kept_boxes = []
    kept_labels = []

    for i in range(len(boxes_t)):
        x1, y1, x2, y2 = _boxes_to_int(boxes_t[i], H, W)
        if x2 <= x1 or y2 <= y1:
            continue  # degenerate box after aug
        crop = inst_np[y1:y2+1, x1:x2+1]
        # pick the dominant nonzero instance id in this crop
        vals, counts = np.unique(crop, return_counts=True)
        nonzero = [(v, c) for v, c in zip(vals, counts) if v != 0]
        if not nonzero:
            # fall back: if no instance id in the crop, skip this target
            continue
        inst_id = max(nonzero, key=lambda t: t[1])[0]
        mask = (inst_np == inst_id)
        masks.append(torch.from_numpy(mask.astype(np.uint8)))
        kept_boxes.append(boxes_t[i])
        kept_labels.append(labels_t[i])

    if len(masks) == 0:
        # Return properly shaped empties
        return (
            torch.empty((0, H, W), dtype=torch.uint8),
            torch.empty((0, 4), dtype=boxes_t.dtype, device=boxes_t.device),
            torch.empty((0,), dtype=labels_t.dtype, device=labels_t.device),
        )
    masks_t = torch.stack(masks, dim=0).to(inst_map_t.device)
    boxes_t = torch.stack(kept_boxes, dim=0) if isinstance(kept_boxes[0], torch.Tensor) else torch.tensor(kept_boxes, device=inst_map_t.device)
    labels_t = torch.stack(kept_labels, dim=0) if isinstance(kept_labels[0], torch.Tensor) else torch.tensor(kept_labels, device=inst_map_t.device)
    return masks_t, boxes_t, labels_t


def build_targets_from_batch(inst_maps, classes, bboxes):
    """
    Convert your collated batch to torchvision-style targets: a list of dicts,
    each with keys: 'boxes' (FloatTensor[N,4]), 'labels' (Int64Tensor[N]),
    'masks' (UInt8Tensor[N,H,W]).
    """
    batch_targets = []
    B = inst_maps.shape[0]
    for i in range(B):
        inst_map_i = inst_maps[i]                           # (H,W) long
        labels_i   = classes[i]                             # (Ni,)
        boxes_i    = bboxes[i]                              # (Ni,4)

        # Ensure correct dtypes/devices
        labels_i = labels_i.to(dtype=torch.int64, device=inst_map_i.device)
        boxes_i  = boxes_i.to(dtype=torch.float32, device=inst_map_i.device)

        masks_i, boxes_i, labels_i = _derive_instance_masks_from_inst_map(inst_map_i, boxes_i, labels_i)

        tgt = {
            "boxes":  boxes_i,                              # [Ni,4] x1,y1,x2,y2
            "labels": labels_i,                             # [Ni]
            "masks":  masks_i,                              # [Ni,H,W] uint8 (0/1)
        }
        batch_targets.append(tgt)
    return batch_targets

In [None]:
# --- Add this helper ---
def _sanitize_targets(targets, num_classes):
    sane = []
    for t in targets:
        boxes  = t["boxes"]
        labels = torch.tensor(t["labels"], dtype=torch.int64) - 1
        masks  = t.get("masks", None)

        # valid labels: 1..num_classes-1
        valid = (labels >= 0) & (labels <= (num_classes - 1))
        if valid.numel() == 0:
            sane.append({"boxes": boxes.new_zeros((0,4)), "labels": labels.new_zeros((0,), dtype=torch.int64),
                         "masks": masks.new_zeros((0, *masks.shape[-2:])) if masks is not None else None})
            continue

        if valid.sum().item() != labels.numel():
            # drop invalid ones
            boxes  = boxes[valid]
            labels = labels[valid]
            masks  = masks[valid] if masks is not None and masks.numel() > 0 else masks

        # drop degenerate boxes (area <= 0) to be safe
        if boxes.numel() > 0:
            x1, y1, x2, y2 = boxes.unbind(1)
            good = (x2 > x1) & (y2 > y1) & torch.isfinite(boxes).all(dim=1)
            boxes  = boxes[good]
            labels = labels[good]
            masks  = masks[good] if masks is not None and masks.numel() > 0 else masks

        sane.append({"boxes": boxes, "labels": labels, "masks": masks})
    return sane

def _get_num_classes(model):
    bp = model.roi_heads.box_predictor
    lin = getattr(bp, "cls_score", None) or getattr(bp, "cls_logits", None)
    return lin.out_features

In [None]:
import math
from collections import defaultdict

import torch
import torch.nn as nn
import torch.optim as optim

# --- Training loop ---
def train_maskrcnn(
    model,
    dataloader,
    num_epochs=10,
    lr=1e-3,
    weight_decay=1e-4,
    clip_grad_norm=0.0,           # set >0 to enable grad clipping
    use_amp=True,
    lr_scheduler_fn=None,         # e.g., lambda opt: torch.optim.lr_scheduler.StepLR(opt, step_size=5, gamma=0.1)
    device=None,
    print_every=20,
):
    """
    Trains Mask R-CNN using your dataloader (images, (inst_maps, classes, bboxes, _)).
    Returns the trained model.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp and (device == "cuda"))
    scheduler = lr_scheduler_fn(optimizer) if lr_scheduler_fn else None

    for epoch in range(1, num_epochs + 1):
        running = defaultdict(float)
        count = 0

        for step, batch in enumerate(dataloader, start=1):
            images, (inst_maps, classes, bboxes, _centroids) = batch

            # Move images to device
            images = [img.to(device) for img in images]  # list[Tensor[C,H,W]]
            inst_maps = inst_maps.to(device)

            # Build torchvision targets
            targets = build_targets_from_batch(inst_maps, classes, bboxes)

            num_classes_model = _get_num_classes(model)
            targets = _sanitize_targets(targets, num_classes=num_classes_model)

            if any(t["labels"].numel() for t in targets):
              all_labels = torch.cat([t["labels"].cpu() for t in targets if t["labels"].numel()], 0)
              # pass
              print("Label min/max:", int(all_labels.min()), int(all_labels.max()),
                    "| num_classes:", num_classes_model)

            # If every image ended up with 0 instances (rare), skip the step to avoid NaNs
            if all(t["boxes"].numel() == 0 for t in targets):
                continue

            optimizer.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast(enabled=scaler.is_enabled()):
                loss_dict = model(images, targets)  # training mode => returns losses
                print(loss_dict)
                loss = sum(loss_dict.values())

            if torch.isfinite(loss):
                if scaler.is_enabled():
                    scaler.scale(loss).backward()
                    if clip_grad_norm and clip_grad_norm > 0:
                        scaler.unscale_(optimizer)
                        nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    if clip_grad_norm and clip_grad_norm > 0:
                        nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
                    optimizer.step()
            else:
                # Skip bad batch
                continue

            # Bookkeeping
            count += 1
            running["loss"] += float(loss.detach().cpu())
            for k, v in loss_dict.items():
                running[k] += float(v.detach().cpu())

            if step % print_every == 0:
                denom = max(1, count)
                msg = f"[Epoch {epoch:02d} | Step {step:04d}] "
                msg += " ".join([f"{k}: {running[k]/denom:.4f}" for k in ["loss", *sorted([kk for kk in loss_dict.keys()])]])
                print(msg)

        # Epoch end
        if count > 0:
            avg_loss = running["loss"] / count
            print(f"Epoch {epoch:02d} done. Avg loss: {avg_loss:.4f}")
        else:
            print(f"Epoch {epoch:02d} done. (no usable batches)")

        if scheduler is not None:
            scheduler.step()

    return model

In [None]:
for step, batch in enumerate(dataloader, start=1):
    images, (inst_maps, classes, bboxes, _centroids) = batch
    targets = build_targets_from_batch(inst_maps, classes, bboxes)

  inst_t = torch.tensor(aug['masks'][0], dtype=torch.long)


KeyboardInterrupt: 

In [None]:
model = get_maskrcnn_model(num_classes=6)
trained_model = train_maskrcnn(model, dataloader, num_epochs=20, lr=5e-4, weight_decay=1e-4, clip_grad_norm=1.0)

AcceleratorError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


# Error

In [None]:
import os, torch
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # better stack traces next time
print("CUDA available:", torch.cuda.is_available())
x = torch.randn(1, device="cuda")  # should succeed

CUDA available: True


In [None]:
# Recreate model fresh
model = get_maskrcnn_model(num_classes=6)

# CPU smoke test (ensures weights and shapes are fine)
_ = model.eval()
dummy = [torch.rand(3, 256, 256)]  # float in [0,1]
with torch.no_grad():
    out = model(dummy)  # should run on CPU

# Now move to GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)



Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /root/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth


100%|██████████| 170M/170M [00:00<00:00, 236MB/s]


MaskRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(in

In [None]:
import torch, torchvision
print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("cuda available:", torch.cuda.is_available())

torch: 2.8.0+cu126
torchvision: 0.23.0+cu126
cuda available: True


In [None]:
for step, batch in enumerate(dataloader, start=1):
    images, (inst_maps, classes, bboxes, _centroids) = batch

    images = [img.to(device).float() for img in images]
    for i in range(len(images)):
        images[i].clamp_(0, 1)

  inst_t = torch.tensor(aug['masks'][0], dtype=torch.long)


In [None]:
model = get_maskrcnn_model(num_classes=6)
trained_model = train_maskrcnn(model, dataloader, num_epochs=20, lr=5e-4, weight_decay=1e-4, clip_grad_norm=1.0, device="cpu")

  scaler = torch.cuda.amp.GradScaler(enabled=use_amp and (device == "cuda"))
  inst_t = torch.tensor(aug['masks'][0], dtype=torch.long)
  labels = torch.tensor(t["labels"], dtype=torch.int64) - 1


Label min/max: 0 5 | num_classes: 6


  with torch.cuda.amp.autocast(enabled=scaler.is_enabled()):


{'loss_classifier': tensor(1.8837, grad_fn=<NllLossBackward0>), 'loss_box_reg': tensor(0.3697, grad_fn=<DivBackward0>), 'loss_mask': tensor(1.5287, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), 'loss_objectness': tensor(8.4098, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), 'loss_rpn_box_reg': tensor(0.6213, grad_fn=<DivBackward0>)}
Label min/max: 0 5 | num_classes: 6
{'loss_classifier': tensor(0.9100, grad_fn=<NllLossBackward0>), 'loss_box_reg': tensor(0.3237, grad_fn=<DivBackward0>), 'loss_mask': tensor(0.8419, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), 'loss_objectness': tensor(0.6624, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), 'loss_rpn_box_reg': tensor(0.4849, grad_fn=<DivBackward0>)}
Label min/max: 0 5 | num_classes: 6
{'loss_classifier': tensor(1.1601, grad_fn=<NllLossBackward0>), 'loss_box_reg': tensor(0.3644, grad_fn=<DivBackward0>), 'loss_mask': tensor(0.6513, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), 'loss_objectness': tensor(0.3756, grad_fn=<Bin

# Transform dataset to YOLO

In [16]:
IMAGE_ROOT = '/content/dataset/'  # root that contains lizard_images1/ lizard_images2/
LABEL_ROOT = '/content/dataset/lizard_labels/'   # root that contains Lizard_Labels/Labels/*.mat

# Where to write the YOLO dataset in Drive
OUT_DIR = "/content/drive/MyDrive/yolo_lizard_512/"
DATASET_NAME = "yolo_lizard_512"

# Export image/mask format
IMG_EXT = ".jpeg"
MASK_EXT = ".jpeg"   # consider ".png" if you prefer lossless
JPEG_QUALITY = 95

# Export size must match your aug below
TARGET_SIZE = 512

import os
os.makedirs(OUT_DIR, exist_ok=True)
os.makedirs(f"{OUT_DIR}/images", exist_ok=True)
os.makedirs(f"{OUT_DIR}/labels", exist_ok=True)
os.makedirs(f"{OUT_DIR}/masks", exist_ok=True)
print("Output dirs ready:", OUT_DIR)

Output dirs ready: /content/drive/MyDrive/yolo_lizard_512/


In [14]:
import os
from pathlib import Path
import numpy as np
import scipy.io as sio
from PIL import Image

import torch
from torch.utils.data import Dataset

import albumentations as A
from albumentations.pytorch import ToTensorV2

# (Optional) quick visualizer for sanity checks later
import matplotlib.pyplot as plt
import matplotlib.patches as patches

def draw_boxes(ax, boxes, labels=None, color="lime", lw=2):
    for i, b in enumerate(boxes):
        x1, y1, x2, y2 = [float(v) for v in b]
        rect = patches.Rectangle((x1, y1), x2-x1, y2-y1, linewidth=lw, edgecolor=color, facecolor='none')
        ax.add_patch(rect)
        if labels is not None:
            ax.text(x1, y1-2, str(labels[i]), fontsize=8, color=color, bbox=dict(fc="black", alpha=0.3, pad=1))

class LizardDataset(Dataset):
    def __init__(self, image_root, label_root, transform=None, debug=False, padding_colour=(114, 114, 114)):
        self.image_paths, self.label_paths = [], []
        self.transform = transform
        self.debug = debug

        print(image_root)

        for subdir in sorted(os.listdir(image_root)):
            subdir_path = os.path.join(image_root, subdir)
            if not (os.path.isdir(subdir_path) and subdir in ['lizard_images1', 'lizard_images2']):
                continue

            child_dirs = sorted([d for d in os.listdir(subdir_path)
                                 if os.path.isdir(os.path.join(subdir_path, d))])
            if not child_dirs:
                continue
            image_folder = os.path.join(subdir_path, child_dirs[0])

            for fname in sorted(os.listdir(image_folder)):
                if fname.lower().endswith(('.jpg', '.png', '.jpeg', '.tif', '.tiff')):
                    img_p = os.path.join(image_folder, fname)
                    mat_name = os.path.splitext(fname)[0] + '.mat'
                    lbl_p = os.path.join(label_root, 'Lizard_Labels', 'Labels', mat_name)
                    if os.path.exists(lbl_p):
                        self.image_paths.append(img_p)
                        self.label_paths.append(lbl_p)

        self.aug = A.Compose(
            [
                A.LongestMaxSize(max_size=TARGET_SIZE),
                A.PadIfNeeded(min_height=TARGET_SIZE, min_width=TARGET_SIZE,
                              position='center', border_mode=0,
                              value=(114,114,114), mask_value=0),
                ToTensorV2(),
            ],
            bbox_params=A.BboxParams(format='pascal_voc', label_fields=['bbox_labels']),
            keypoint_params=A.KeypointParams(format='xy', remove_invisible=False),
        )

        print(f"Found {len(self.image_paths)} images and {len(self.label_paths)} labels.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label_path = self.label_paths[idx]

        image = np.array(Image.open(img_path).convert('RGB'))
        label = sio.loadmat(label_path)

        inst_map = np.asarray(label['inst_map']).astype(np.uint8)
        bboxs = np.asarray(label['bbox']).squeeze()       # (N, 4): (y1, y2, x1, x2)
        centroids = np.asarray(label['centroid']).squeeze()
        classes = np.asarray(label['class']).squeeze()

        if bboxs.ndim == 1 and bboxs.size == 4:
            bboxs = bboxs[None, :]
        if centroids.ndim == 1 and centroids.size == 2:
            centroids = centroids[None, :]
        if classes.ndim == 0:
            classes = classes[None]

        bbox_list, bbox_labels, kpts = [], [], []
        for i in range(len(bboxs)):
            y1, y2, x1, x2 = bboxs[i].astype(float)
            bbox_list.append([x1, y1, x2, y2])  # x_min,y_min,x_max,y_max
            bbox_labels.append(int(classes[i]))
            kpts.append((float(centroids[i][0]), float(centroids[i][1])))

        aug = self.aug(
            image=image,
            masks=[inst_map],
            bboxes=bbox_list,
            bbox_labels=bbox_labels,
            keypoints=kpts,
        )

        image_t = aug['image'] / 255.0          # C,H,W in [0,1]
        inst_t = torch.tensor(aug['masks'][0], dtype=torch.long)
        bboxes_t = torch.tensor(aug['bboxes'], dtype=torch.float32)
        labels_t = torch.tensor(aug['bbox_labels'], dtype=torch.int64)

        return image_t, (inst_t, labels_t, bboxes_t), Path(img_path).stem

In [18]:
import shutil
from tqdm import tqdm
import math

def _clamp_boxes(boxes, w, h):
    if boxes.numel() == 0:
        return boxes
    boxes[:, 0] = boxes[:, 0].clamp(0, w-1)  # x1
    boxes[:, 2] = boxes[:, 2].clamp(0, w-1)  # x2
    boxes[:, 1] = boxes[:, 1].clamp(0, h-1)  # y1
    boxes[:, 3] = boxes[:, 3].clamp(0, h-1)  # y2
    # ensure x2>=x1, y2>=y1
    boxes[:, 2] = torch.maximum(boxes[:, 2], boxes[:, 0])
    boxes[:, 3] = torch.maximum(boxes[:, 3], boxes[:, 1])
    return boxes

def _to_yolo_lines(boxes, labels, w, h):
    """boxes: tensor [N,4] in xyxy pixel coords; returns list of 'c cx cy w h' normalized strings"""
    if boxes.numel() == 0:
        return []
    cx = (boxes[:, 0] + boxes[:, 2]) / 2.0 / w
    cy = (boxes[:, 1] + boxes[:, 3]) / 2.0 / h
    bw = (boxes[:, 2] - boxes[:, 0]) / w
    bh = (boxes[:, 3] - boxes[:, 1]) / h
    lines = []
    for i in range(len(labels)):
        c = int(labels[i])
        lines.append(f"{c} {cx[i].item():.6f} {cy[i].item():.6f} {bw[i].item():.6f} {bh[i].item():.6f}")
    return lines

def shift_classes_to_zero_index(labels):
    labels = labels - 1
    return labels

def save_image_and_mask(image_t, mask_t, out_image_path, out_mask_path, img_ext=IMG_EXT, mask_ext=MASK_EXT, jpeg_quality=95):
    # image_t: C,H,W in [0,1]; mask_t: H,W long
    C, H, W = image_t.shape
    img = (image_t.clamp(0,1).permute(1,2,0).cpu().numpy() * 255.0).round().astype(np.uint8)
    img_pil = Image.fromarray(img)
    if img_ext.lower() in [".jpg", ".jpeg"]:
        img_pil.save(out_image_path, quality=jpeg_quality, subsampling=0, format="JPEG")
    else:
        img_pil.save(out_image_path)

    mask_arr = mask_t.cpu().numpy().astype(np.uint8)
    mask_pil = Image.fromarray(mask_arr)
    if mask_ext.lower() in [".jpg", ".jpeg"]:
        mask_pil.save(out_mask_path, quality=jpeg_quality, subsampling=0, format="JPEG")
    else:
        mask_pil.save(out_mask_path)

def export_to_yolo(ds, out_dir):
    images_dir = Path(out_dir) / "images"
    labels_dir = Path(out_dir) / "labels"
    masks_dir  = Path(out_dir) / "masks"

    total = len(ds)
    shifted_any = False
    empty_labels = 0
    written = 0

    for i in tqdm(range(total), desc="Converting"):
        image_t, (inst_t, labels_t, boxes_t), stem = ds[i]
        # Clamp and filter degenerate boxes
        _, H, W = image_t.shape
        boxes_t = _clamp_boxes(boxes_t.clone(), W, H)

        # drop zero-size boxes
        if boxes_t.numel() > 0:
            wh = (boxes_t[:, 2] - boxes_t[:, 0]) * (boxes_t[:, 3] - boxes_t[:, 1])
            keep = wh > 0
            boxes_t = boxes_t[keep]
            labels_t = labels_t[keep]

        labels_t = shift_classes_to_zero_index(labels_t)

        # write image + mask
        out_image_path = images_dir / f"{stem}{IMG_EXT}"
        out_mask_path  = masks_dir  / f"{stem}{MASK_EXT}"
        save_image_and_mask(image_t, inst_t, out_image_path, out_mask_path, IMG_EXT, MASK_EXT, JPEG_QUALITY)

        # write YOLO label file
        out_label_path = labels_dir / f"{stem}.txt"
        lines = _to_yolo_lines(boxes_t, labels_t, W, H)
        if len(lines) == 0:
            empty_labels += 1
            open(out_label_path, "w").close()
        else:
            with open(out_label_path, "w") as f:
                f.write("\n".join(lines))
        written += 1

    print(f"\nDone. Wrote {written} samples.")
    print(f"Empty label files: {empty_labels}")
    if shifted_any:
        print("Note: Detected 1-indexed classes in .mat and shifted to 0-index for YOLO.")

def zip_dataset(out_dir, dataset_name):
    base_dir = Path(out_dir)
    zip_path = str(base_dir.parent / f"{dataset_name}.zip")
    # Remove old zip if exists
    if os.path.exists(zip_path):
        os.remove(zip_path)
    shutil.make_archive(str(base_dir.parent / dataset_name), 'zip', base_dir)
    print("Zipped dataset at:", zip_path)
    return zip_path


In [19]:
# Create dataset and export
ds = LizardDataset(IMAGE_ROOT, LABEL_ROOT, debug=False)

export_to_yolo(ds, OUT_DIR)
zip_path = zip_dataset(OUT_DIR, DATASET_NAME)

print("All set! You can find it here in Drive:", zip_path)

  A.PadIfNeeded(min_height=TARGET_SIZE, min_width=TARGET_SIZE,


/content/dataset/
Found 238 images and 238 labels.


  inst_t = torch.tensor(aug['masks'][0], dtype=torch.long)
Converting: 100%|██████████| 238/238 [00:25<00:00,  9.48it/s]



Done. Wrote 238 samples.
Empty label files: 0
Zipped dataset at: /content/drive/MyDrive/yolo_lizard_512.zip
All set! You can find it here in Drive: /content/drive/MyDrive/yolo_lizard_512.zip
