# Multi-Organ Cell Instance Segmentation and Classification (Hover-Net)

## 1. Setup: Imports and Configuration

In [2]:
import sys
print(sys.executable)


/ext3/miniforge3/bin/python


In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import models
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import xml.etree.ElementTree as ET  

import tifffile
import os
import cv2
from tqdm.notebook import tqdm
from scipy.ndimage import distance_transform_edt, label, center_of_mass, find_objects
from scipy.stats import mode
from skimage.segmentation import watershed
from skimage.feature import peak_local_max
from skimage.color import rgb2hed, hed2rgb
import random
import warnings
import glob
from sklearn.model_selection import train_test_split
import collections
import time
import timm 

warnings.filterwarnings('ignore')


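Convert the XML polygon annotations to `.npz` masks once, up front, for faster processing. This keeps the GPU efficiently utilized during training (no XML parsing or polygon rasterization inside the data loader), and no resources are wasted storing the images themselves: only the instance and type masks are saved.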

In [None]:
def preprocess_xml_to_npz(image_dir, mask_dir):
    """
    Rasterize XML polygon annotations into compressed .npz mask files.

    For every ``<id>.xml`` in ``image_dir`` the matching ``<id>.tif`` is opened
    only to learn its height/width, then each annotated region is painted into:

    * ``instance_map`` (uint16): a fresh id (1, 2, ...) per region, 0 = background
    * ``type_map`` (uint8): class id per pixel, 0 = background

    Both arrays are written to ``mask_dir/<id>.npz``. Images that already have
    an .npz, whose .tif cannot be read, whose XML cannot be processed, or that
    end up with no usable region are skipped.

    Args:
        image_dir: folder containing the .tif images and .xml annotations.
        mask_dir: output folder for the .npz files (created if missing).

    Returns:
        None.
    """

    print(f"Starting preprocessing...")
    print(f"Image/XML Source: {image_dir}")
    print(f"NPZ Mask Destination: {mask_dir}")

    os.makedirs(mask_dir, exist_ok=True)

    # Annotation name -> integer label; 0 stays reserved for background.
    label_ids = {"Epithelial": 1, "Lymphocyte": 2, "Neutrophil": 3, "Macrophage": 4}

    annotation_paths = sorted(glob.glob(os.path.join(image_dir, "*.xml")))
    if len(annotation_paths) == 0:
        print(f"Error: No .xml files found in {image_dir}. Please check your path.")
        return

    print(f"Found {len(annotation_paths)} XML files to process.")

    for xml_path in tqdm(annotation_paths, desc="Processing XMLs"):
        image_id = os.path.basename(xml_path).replace(".xml", "")
        tif_path = os.path.join(image_dir, f"{image_id}.tif")
        npz_path = os.path.join(mask_dir, f"{image_id}.npz")

        # Already converted on a previous run.
        if os.path.exists(npz_path):
            continue

        # Only the spatial size of the image is needed here.
        try:
            with tifffile.TiffFile(tif_path) as tif:
                height_width = tif.pages[0].shape[:2]
        except Exception as e:
            print(f"\nWarning: Could not read {tif_path}. Skipping. Error: {e}")
            continue

        instance_map = np.zeros(height_width, dtype=np.uint16)
        type_map = np.zeros(height_width, dtype=np.uint8)
        next_instance_id = 1

        try:
            for annotation in ET.parse(xml_path).getroot().findall("Annotation"):
                attribute = annotation.find("Attributes/Attribute")
                class_id = None if attribute is None else label_ids.get(attribute.get("Name"))
                if class_id is None:
                    # Unlabelled annotation or a class we do not train on.
                    continue

                for region in annotation.findall("Regions/Region"):
                    points = [
                        [round(float(vertex.get("X"))), round(float(vertex.get("Y")))]
                        for vertex in region.findall("Vertices/Vertex")
                    ]
                    if len(points) == 0:
                        continue

                    outline = np.array(points, dtype=np.int32)

                    # Later regions overwrite earlier ones where they overlap.
                    cv2.fillPoly(instance_map, [outline], next_instance_id)
                    cv2.fillPoly(type_map, [outline], class_id)
                    next_instance_id += 1

        except ET.ParseError as e:
            print(f"\nWarning: Failed to parse {xml_path}. Skipping. Error: {e}")
            continue
        except Exception as e:
            print(f"\nWarning: An error occurred processing {xml_path}. Skipping. Error: {e}")
            continue

        # Write a file only when at least one region was painted.
        if next_instance_id > 1:
            np.savez_compressed(npz_path, instance_map=instance_map, type_map=type_map)

    print(f"--- XML Preprocessing Complete ---")
    print(f"All .npz masks are saved in: {mask_dir}")


# Folder holding the training .tif images together with their .xml annotations.
IMAGE_DIR = "train"

# Folder that receives one .npz (instance_map + type_map) per annotated image.
NEW_MASK_DIR = "mask_new"

# Images that already have an .npz in NEW_MASK_DIR are skipped by the function,
# so re-running this cell only converts what is still missing.
preprocess_xml_to_npz(IMAGE_DIR, NEW_MASK_DIR)


Starting preprocessing...
Image/XML Source: train
NPZ Mask Destination: mask_new
Found 209 XML files to process.


--- XML Preprocessing Complete ---
All .npz masks are saved in: mask_new


In [None]:
# --- Paths ---
DATA_DIR = "train"                    # Directory with training .tif images
MASK_DIR = "mask_new"                 # Directory with training .npz masks
TEST_DIR = "test_final"               # presumably the test images used for the submission; confirm against the inference cell

# --- Training hyper-parameters ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
PATCH_SIZE = 256      # side length (pixels) of the square patches fed to the model
OVERLAP = 64          # not used in the training cells; presumably tile overlap for inference (TODO confirm)
BATCH_SIZE = 8
EPOCHS = 50           # This model was earlier trained for 50 epochs and then retrained for another 50, making it 100 epochs in total
LR = 1e-4
NUM_WORKERS = 2       # DataLoader worker processes
VALID_SPLIT = 0.2     # fraction of images held out for validation
MODEL_SAVE_PATH = "best_model_new.pth"
SUBMISSION_PATH = "submission_new.csv"

# --- Classes ---
NUM_CLASSES = 4 
TP_MAP_CHANNELS = NUM_CLASSES + 1 # (Background + 4 cell types)

# One weight per cell type, in class-id order 1..4 (see CLASS_MAP below).
# A background weight is prepended later when the loss is constructed.
CLASS_LOSS_WEIGHTS = torch.tensor([1.0, 1.0, 10.0, 10.0]).to(DEVICE)

SUBMISSION_CLASSES = ['Epithelial', 'Lymphocyte', 'Neutrophil', 'Macrophage']
# Internal class id -> class name; id 0 is background and has no entry.
CLASS_MAP = {
    1: 'Epithelial',
    2: 'Lymphocyte',
    3: 'Neutrophil',
    4: 'Macrophage'
}

print(f"Using device: {DEVICE}")
print(f"Data Dirs: {DATA_DIR}, {MASK_DIR}, {TEST_DIR}")
print(f"Model will be saved to: {MODEL_SAVE_PATH}")
print(f"Submission will be saved to: {SUBMISSION_PATH}")

Using device: cuda
Data Dirs: train, mask_new, test_final
Model will be saved to: best_model_new_3.pth
Submission will be saved to: submission_new_2.csv


## 2. RLE Helper Functions

In [None]:
def rle_encode_instance_mask(mask: np.ndarray) -> str:
    """
    Encode an (H, W) instance mask as a string of ``value start length`` triples.

    The mask is read in Fortran (column-major) order. Every maximal run of equal,
    strictly positive values yields one triple; ``start`` is the 1-based position
    of the run in the flattened mask. Background (0) runs are not emitted.
    A mask without any positive pixel encodes as ``"0"``.
    """
    flat = mask.flatten(order="F").astype(np.int32)

    # Zero sentinels on both sides guarantee every run has a boundary before and
    # after it, and make boundary indices equal to 1-based pixel positions.
    padded = np.concatenate([[0], flat, [0]])
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1

    # Consecutive boundaries delimit one run each.
    run_starts = boundaries[:-1]
    run_lengths = np.diff(boundaries)
    run_values = padded[run_starts]

    foreground = run_values > 0
    triples = np.stack(
        [run_values[foreground], run_starts[foreground], run_lengths[foreground]],
        axis=1,
    )

    if triples.size == 0:
        return "0"

    return " ".join(str(number) for number in triples.ravel().tolist())


def rle_decode_instance_mask(rle: str, shape: tuple[int, int]) -> np.ndarray:
    """
    Convert an RLE triple string back into an instance mask of shape (H, W).

    Args:
        rle: whitespace-separated ``value start length`` triples, where ``start``
            is 1-based in Fortran (column-major) order. An empty/None value or
            one of the sentinels ``"0"``, ``"1"``, ``"nan"`` means "no instances".
        shape: ``(H, W)`` of the mask to rebuild.

    Returns:
        ``np.uint16`` array of shape ``shape``; 0 is background.

    Raises:
        ValueError: if the tokens do not form complete triples, or a run starts
            before pixel 1, has a negative length, or extends past the mask.
    """
    if not rle or str(rle).strip() in ("1", "0", "nan"):
        return np.zeros(shape, dtype=np.uint16)

    tokens = list(map(int, str(rle).split()))
    if len(tokens) % 3 != 0:
        raise ValueError(
            f"RLE string must contain (value, start, length) triples; got {len(tokens)} numbers."
        )

    n_pixels = shape[0] * shape[1]
    flat = np.zeros(n_pixels, dtype=np.uint16)

    for val, start, length in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        # Reject runs that slicing would otherwise silently clip, or wrap around
        # through a negative index.
        if start < 1 or length < 0 or start - 1 + length > n_pixels:
            raise ValueError(
                f"RLE run (value={val}, start={start}, length={length}) "
                f"does not fit a mask of shape {shape}."
            )
        # RLE start is 1-based
        flat[start - 1:start - 1 + length] = val

    # Reshape using Fortran (column-major) order
    return flat.reshape(shape, order="F")

# Round-trip self-check: a tiny mask with two instances must survive
# encode -> decode unchanged.
test_mask = np.zeros((5, 5), dtype=np.uint16)
test_mask[0:2, 2:4] = 1  # instance 1: 2x2 block in the top rows
test_mask[3:5, 0:2] = 2  # instance 2: 2x2 block in the bottom-left corner

print("Original Mask:")
print(test_mask)

rle_string = rle_encode_instance_mask(test_mask)
print(f"\nRLE String: {rle_string}")

decoded_mask = rle_decode_instance_mask(rle_string, (5, 5))
print("\nDecoded Mask:")
print(decoded_mask)

# Fail loudly if the two helpers ever stop being inverses of each other.
assert np.array_equal(test_mask, decoded_mask), "RLE functions are not inverses!"

Original Mask:
[[0 0 1 1 0]
 [0 0 1 1 0]
 [0 0 0 0 0]
 [2 2 0 0 0]
 [2 2 0 0 0]]

RLE String: 2 4 2 2 9 2 1 11 2 1 16 2

Decoded Mask:
[[0 0 1 1 0]
 [0 0 1 1 0]
 [0 0 0 0 0]
 [2 2 0 0 0]
 [2 2 0 0 0]]


## 3. Stratified Train/Validation Split

This is a critical step. We can't use a simple random split. If we do, we might end up with no rare cells (Neutrophils, Macrophages) in our validation set, giving us a misleadingly high score.

We will:
1.  Load every type_map from the MASK_DIR.
2.  For each image, check if it contains at least one pixel of a rare class (Neutrophil or Macrophage).
3.  Create a "stratification key" for each image (e.g., has_rare_cell).
4.  Use sklearn.model_selection.train_test_split to create an 80/20 split that respects this key, ensuring both train and val sets get a similar percentage of rare-cell-containing images.

In [None]:
def create_stratified_split(mask_dir, valid_split=0.2, random_state=42):
    """
    Split the image ids found in ``mask_dir`` into train and validation sets,
    stratified on whether an image contains any rare-class pixel.

    Each ``<id>.npz`` is loaded, its ``type_map`` inspected, and the image gets
    key 1 if class 3 (Neutrophil) or 4 (Macrophage) occurs in it, else key 0.
    ``train_test_split`` then keeps the proportion of key-1 images similar in
    both subsets.

    Args:
        mask_dir: folder containing the .npz masks.
        valid_split: fraction of images assigned to validation.
        random_state: seed forwarded to ``train_test_split``.

    Returns:
        ``(train_ids, val_ids)`` as lists of image ids, or ``(None, None)`` when
        no mask file exists or none could be loaded.
    """
    print("Creating stratified train/validation split...")
    image_ids = []
    stratify_keys = []

    mask_files = sorted(glob.glob(os.path.join(mask_dir, "*.npz")))
    if not mask_files:
        print(f"Error: No .npz mask files found in {mask_dir}. Make sure your MASK_DIR is correct.")
        return None, None

    # Our internal mapping: 3=Neutrophil, 4=Macrophage
    rare_class_indices = [3, 4]

    for mask_path in tqdm(mask_files, desc="Analyzing masks"):
        image_id = os.path.basename(mask_path).replace(".npz", "")
        try:
            with np.load(mask_path) as data:
                type_map = data['type_map']

            # Stratification key: 0=common only, 1=has rare cell
            has_rare = int(np.isin(type_map, rare_class_indices).any())

        except Exception as e:
            print(f"Warning: Could not load or process {mask_path}. Skipping. Error: {e}")
            continue

        # Record id and key together, only after both were obtained, so the two
        # lists can never get out of step when a file fails half-way through.
        image_ids.append(image_id)
        stratify_keys.append(has_rare)

    if not image_ids:
        print("Error: No valid masks were processed.")
        return None, None

    # Perform the stratified split
    train_ids, val_ids, train_keys, val_keys = train_test_split(
        image_ids,
        stratify_keys,
        test_size=valid_split,
        random_state=random_state,
        stratify=stratify_keys
    )

    print(f"Total images: {len(image_ids)}")
    print(f"Training images: {len(train_ids)}")
    print(f"Validation images: {len(val_ids)}")
    print(f"Rare images in train: {np.sum(train_keys)} / {len(train_keys)} ({np.sum(train_keys)/len(train_keys)*100:.1f}%)")
    print(f"Rare images in val: {np.sum(val_keys)} / {len(val_ids)} ({np.sum(val_keys)/len(val_ids)*100:.1f}%)")
    print("Split created successfully.")

    return train_ids, val_ids

# Build the split once; both names are None if no usable mask was found.
train_ids, val_ids = create_stratified_split(MASK_DIR, valid_split=VALID_SPLIT)

Creating stratified train/validation split...


Total images: 209
Training images: 167
Validation images: 42
Rare images in train: 106 / 167 (63.5%)
Rare images in val: 27 / 42 (64.3%)
Split created successfully.


## 4. Data Loading and Preprocessing (Final Version)

This section defines the core PyTorch Dataset. 
Its main jobs:
1.  Load an image and its corresponding mask from the provided ID list.
2.  Handle data augmentation (geometric and H&E color transforms).
3.  Use the robust **"Pad-then-Crop"** logic to ensure all output patches are `PATCH_SIZE x PATCH_SIZE`.
4.  On-the-fly, generate the three target maps:
    * `np_map` (Nucleus Pixel): Binary, 0 or 1.
    * `hv_map` (Horizontal/Vertical): 2-channel float vector map.
    * `tp_map` (Type Pixel): Long tensor with class indices `0` (BG), `1` (Epi), `2` (Lym), `3` (Neu), `4` (Mac).

In [None]:
def get_hv_map(instance_map):
    """
    Build the per-instance vertical/horizontal offset map used as HV target.

    For every instance (non-zero id) each pixel stores the vector from that
    pixel to the instance centroid, with each axis independently scaled by the
    instance's largest absolute offset so values lie in [-1, 1].

    Args:
        instance_map: 2-D integer array (H, W); 0 is background.

    Returns:
        float32 array of shape (H, W, 2). Channel 0 holds the vertical (y)
        component, channel 1 the horizontal (x) component. Background is 0.
    """
    H, W = instance_map.shape
    hv_map = np.zeros((H, W, 2), dtype=np.float32)

    # Drop the background id explicitly. Slicing off the first unique value is
    # wrong when the patch has no background pixel at all: it would discard the
    # smallest real instance instead.
    instance_ids = np.unique(instance_map)
    instance_ids = instance_ids[instance_ids != 0]

    for inst_id in instance_ids:
        y_coords, x_coords = np.where(instance_map == inst_id)

        if y_coords.shape[0] == 0:
            continue

        # Centroid of the instance (mean pixel coordinate).
        y_center, x_center = np.mean(y_coords), np.mean(x_coords)

        # Vectors from each pixel to the centroid.
        x_vec = x_center - x_coords
        y_vec = y_center - y_coords

        # Normalize each axis; the epsilon guards single-row/column instances
        # whose maximum offset is 0.
        x_vec = x_vec / (np.max(np.abs(x_vec)) + 1e-8)
        y_vec = y_vec / (np.max(np.abs(y_vec)) + 1e-8)

        # Assign vectors to the map
        hv_map[y_coords, x_coords, 0] = y_vec
        hv_map[y_coords, x_coords, 1] = x_vec

    return hv_map

def augment_H_and_E(img, h_scale=0.1, e_scale=0.1, h_bias=0.1, e_bias=0.1):
    """
    Randomly jitter the Hematoxylin and Eosin stain intensities of an RGB image.

    The image is cast to uint8 if needed, converted to HED space, the H and E
    channels are each multiplied by ``1 + U(-scale, scale)`` and shifted by
    ``U(-bias, bias)``, and the result is converted back to RGB. The third (DAB)
    channel is passed through untouched.

    Returns:
        uint8 RGB image. If the HED conversion raises, the (uint8) input is
        returned unchanged.
    """
    rgb_u8 = img if img.dtype == np.uint8 else img.astype(np.uint8)

    try:
        stains = rgb2hed(rgb_u8)
    except Exception:
        return rgb_u8

    # Random draws happen in the order: H gain, H offset, E gain, E offset.
    h_gain = 1.0 + np.random.uniform(-h_scale, h_scale)
    h_offset = np.random.uniform(-h_bias, h_bias)
    hematoxylin = (stains[:, :, 0] * h_gain) + h_offset  # nuclei stain

    e_gain = 1.0 + np.random.uniform(-e_scale, e_scale)
    e_offset = np.random.uniform(-e_bias, e_bias)
    eosin = (stains[:, :, 1] * e_gain) + e_offset  # cytoplasm/stroma stain

    jittered = np.stack([hematoxylin, eosin, stains[:, :, 2]], axis=-1)

    # hed2rgb may leave the unit range after jittering, so clip before scaling.
    rgb_unit = np.clip(hed2rgb(jittered), 0, 1)
    return (rgb_unit * 255).astype(np.uint8)


class CellDataset(Dataset):
    """
    Patch dataset yielding one random ``patch_size`` x ``patch_size`` crop per item.

    Each item is ``(image, targets)`` where ``image`` is a float tensor (3, P, P)
    scaled to [0, 1] and ``targets`` is a dict with:

    * ``"np"``: float (1, P, P) binary nucleus mask
    * ``"hv"``: float (2, P, P) map produced by ``get_hv_map``
    * ``"tp"``: long (P, P) class index per pixel, 0 = background

    Any exception while building an item is printed and ``None`` is returned
    instead, so the DataLoader needs a collate function that drops ``None``
    entries (e.g. ``collate_fn_skip_errors``).

    NOTE(review): the crop position is drawn at random even when
    ``augment=False``, so a validation dataset does not see the same patches
    on every pass.
    """

    def __init__(self, image_dir, mask_dir, image_ids, patch_size=256, augment=True):
        """
        Args:
            image_dir: folder with ``<id>.tif`` images.
            mask_dir: folder with ``<id>.npz`` masks (keys ``instance_map``, ``type_map``).
            image_ids: ids to use; ids missing either file are dropped with a warning.
            patch_size: side length of the square crop.
            augment: enable geometric/colour/blur/cutout augmentation.
        """
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_ids = image_ids 
        self.patch_size = patch_size
        self.augment = augment

        # image id -> {"img_path", "mask_path"}; only paths are cached, pixel
        # data is read from disk on every __getitem__ call.
        self.image_info = {}
        print(f"Initializing dataset with {len(self.image_ids)} images...")
        for img_id in tqdm(image_ids, desc="Caching image info"):
            img_path = os.path.join(self.image_dir, f"{img_id}.tif")
            mask_path = os.path.join(self.mask_dir, f"{img_id}.npz")
            if os.path.exists(img_path) and os.path.exists(mask_path):
                self.image_info[img_id] = {
                    "img_path": img_path,
                    "mask_path": mask_path
                }
            else:
                print(f"Warning: Missing file for {img_id}. Img: {os.path.exists(img_path)}, Mask: {os.path.exists(mask_path)}")

        # Keep only the ids for which both files exist.
        self.image_ids = list(self.image_info.keys()) 
        print(f"Dataset initialized. Found {len(self.image_ids)} valid image-mask pairs.")

    def __len__(self):
        """Number of items per epoch: ten random patches for every valid image."""
        return len(self.image_ids) * 10 # 10 patches per image

    def __getitem__(self, idx):
        """Return one randomly cropped (and optionally augmented) patch, or None on error."""
        # idx runs over 10x the image count, so wrap it back onto the image list.
        image_id = self.image_ids[idx % len(self.image_ids)]
        info = self.image_info[image_id]

        try:
            img = tifffile.imread(info["img_path"])

            # Drop a 4th (alpha) channel if present.
            if img.ndim == 3 and img.shape[-1] == 4:
                img = img[:, :, :3]

            # Expand single-channel images to 3 channels.
            if img.ndim == 2:
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

            # Bring non-uint8 data into the 0..255 range: rescale by the maximum
            # when values exceed 255, otherwise just cast.
            if img.dtype != np.uint8:
                if img.max() > 255:
                    img = (img / img.max() * 255).astype(np.uint8)
                else:
                    img = img.astype(np.uint8)

            with np.load(info["mask_path"]) as data:
                instance_map = data['instance_map']
                type_map = data['type_map']
            H, W, _ = img.shape

            # Pad-then-crop: zero-pad bottom/right so the image is at least
            # patch_size in both dimensions, then take a random window.
            pad_h = max(0, self.patch_size - H)
            pad_w = max(0, self.patch_size - W)

            img_padded = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), mode='constant', constant_values=0)
            inst_padded = np.pad(instance_map, ((0, pad_h), (0, pad_w)), mode='constant', constant_values=0)
            type_padded = np.pad(type_map, ((0, pad_h), (0, pad_w)), mode='constant', constant_values=0)

            H_pad, W_pad, _ = img_padded.shape

            # Top-left corner of the crop (upper bound of randint is exclusive).
            x = np.random.randint(0, W_pad - self.patch_size + 1)
            y = np.random.randint(0, H_pad - self.patch_size + 1)

            img_patch = img_padded[y:y+self.patch_size, x:x+self.patch_size]
            inst_patch = inst_padded[y:y+self.patch_size, x:x+self.patch_size]
            type_patch = type_padded[y:y+self.patch_size, x:x+self.patch_size]

            # Data Augmentation
            if self.augment:
                # Geometric transforms are applied identically to image and both
                # masks; the colour/blur/cutout steps below touch the image only.
                if random.random() > 0.5: # Horizontal Flip
                    img_patch = np.fliplr(img_patch)
                    inst_patch = np.fliplr(inst_patch)
                    type_patch = np.fliplr(type_patch)
                if random.random() > 0.5: # Vertical Flip
                    img_patch = np.flipud(img_patch)
                    inst_patch = np.flipud(inst_patch)
                    type_patch = np.flipud(type_patch)
                if random.random() > 0.5: # 90-degree Rotation
                    k = random.choice([1, 2, 3])
                    img_patch = np.rot90(img_patch, k)
                    inst_patch = np.rot90(inst_patch, k)
                    type_patch = np.rot90(type_patch, k)

                # Color (H&E)
                if random.random() > 0.5:
                    img_patch = augment_H_and_E(img_patch,
                                                h_scale=0.2, e_scale=0.2,
                                                h_bias=0.2, e_bias=0.2)

                # Contrast Jitter
                if random.random() > 0.5:
                    # Multiply all pixel values by a factor in [0.8, 1.2]
                    factor = np.random.uniform(0.8, 1.2)
                    img_patch = np.clip(img_patch.astype(np.float32) * factor, 0, 255).astype(np.uint8)

                # Add Gaussian Blur
                if random.random() > 0.5:
                    # 5x5 kernel; sigma argument 0 (OpenCV then derives sigma from the kernel size)
                    img_patch = cv2.GaussianBlur(img_patch, (5, 5), 0)

                # CoarseDropout (Cutout) to fight overfitting
                if random.random() > 0.5:
                    max_holes = 8
                    max_h_size = int(self.patch_size * 0.1) # 10% of patch size
                    max_w_size = int(self.patch_size * 0.1)
                    # randint's upper bound is exclusive: 1..max_holes-1 holes,
                    # each 1..max_size-1 pixels per side.
                    # NOTE(review): randint(1, max_h_size) raises when
                    # int(patch_size * 0.1) <= 1, i.e. for patch sizes below 20.
                    for _ in range(np.random.randint(1, max_holes)):
                        y1 = np.random.randint(0, self.patch_size - max_h_size + 1)
                        x1 = np.random.randint(0, self.patch_size - max_w_size + 1)
                        h = np.random.randint(1, max_h_size)
                        w = np.random.randint(1, max_w_size)
                        y2 = y1 + h
                        x2 = x1 + w
                        # Only the image is blanked; the target maps keep their labels.
                        img_patch[y1:y2, x1:x2] = 0

            # Targets are derived after augmentation so they match the final geometry.
            # NP Map (Nucleus vs. Background)
            np_map = (inst_patch > 0).astype(np.float32)

            # HV Map (Horizontal/Vertical vectors)
            hv_map = get_hv_map(inst_patch.copy())

            # TP Map (Pixel-wise class labels)
            tp_map = (type_patch * np_map).astype(np.int64)

            # .copy() materialises the flipped/rotated views before handing them to torch.
            img_tensor = torch.from_numpy(img_patch.copy().transpose(2, 0, 1)).float() / 255.0
            np_tensor = torch.from_numpy(np_map).float().unsqueeze(0) 
            hv_tensor = torch.from_numpy(hv_map.transpose(2, 0, 1)).float()
            tp_tensor = torch.from_numpy(tp_map).long() 

            return img_tensor, {"np": np_tensor, "hv": hv_tensor, "tp": tp_tensor}

        except Exception as e:
            # Best-effort: report the failure and let the collate function drop this item.
            print(f"Error loading item for {image_id}. Error: {e}. Returning None.")
            return None

def collate_fn_skip_errors(batch):
    """
    Collate a batch while discarding samples that failed to load.

    Samples equal to ``None`` are removed; the rest are passed to PyTorch's
    default collate. If nothing is left, ``(None, None)`` is returned so the
    caller can detect and skip the empty batch.
    """
    usable_samples = list(filter(lambda sample: sample is not None, batch))
    if len(usable_samples) == 0:
        return None, None

    return torch.utils.data.dataloader.default_collate(usable_samples)

## 5. Model Definition (Hover-Net)

We will define the Hover-Net architecture. It consists of:
1.  A pre-trained **ConvNeXt-Tiny** backbone as an encoder.
2.  A **Feature Pyramid Network (FPN)** decoder to create a rich, multi-scale feature map.
3.  Three separate prediction heads (small conv blocks) branching from the decoder for the `NP`, `HV`, and `TP` maps.

In [None]:
class HoverNet(nn.Module):
    """
    Hover-Net style network: ConvNeXt-Tiny encoder, FPN-like decoder, three heads.

    ``forward`` returns a dict with
    * ``"np"``: (B, 1, H, W) nucleus-vs-background logits,
    * ``"hv"``: (B, 2, H, W) offset map squashed to [-1, 1] by ``tanh``,
    * ``"tp"``: (B, num_classes, H, W) class logits.

    The decoder upsamples the /4-scale pyramid level twice by 2x, so the heads
    predict at the input resolution.
    """

    def __init__(self, num_classes: int):
        """
        Args:
            num_classes: number of output channels of the TP head.
        """
        super().__init__()

        print("Loading ImageNet-pre-trained ConvNeXt-Tiny backbone from timm...")
        # features_only=True makes the backbone return its intermediate feature
        # maps; forward() unpacks four of them.
        self.backbone = timm.create_model(
            'convnext_tiny',
            pretrained=True,
            features_only=True
        )
        print("ConvNeXt-Tiny backbone loaded successfully.")
        
        # Channel sizes for convnext_tiny 
        # c2: /4 scale,  96 channels
        # c3: /8 scale, 192 channels
        # c4: /16 scale, 384 channels
        # c5: /32 scale, 768 channels
        
        # Decoder (FPN-like)
        # 1x1 lateral convs project every encoder level to a common 256 channels.
        self.lateral_c5 = nn.Conv2d(768, 256, 1)
        self.lateral_c4 = nn.Conv2d(384, 256, 1)
        self.lateral_c3 = nn.Conv2d(192, 256, 1)
        self.lateral_c2 = nn.Conv2d(96,  256, 1)

        # Top-down upsampling layers
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

        # Smoothing (3x3 convs)
        self.smooth_p4 = nn.Conv2d(256, 256, 3, padding=1)
        self.smooth_p3 = nn.Conv2d(256, 256, 3, padding=1)
        self.smooth_p2 = nn.Conv2d(256, 256, 3, padding=1)

        # Final Decoder Block 
        # Two (2x upsample -> 3x3 conv -> ReLU) stages: /4 -> /2 -> full resolution.
        self.final_upsample1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.final_conv1 = nn.Conv2d(256, 128, 3, padding=1)
        self.final_relu1 = nn.ReLU(inplace=True)

        self.final_upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.final_conv2 = nn.Conv2d(128, 64, 3, padding=1)
        self.final_relu2 = nn.ReLU(inplace=True)

        # 1. Nucleus Pixel (NP) Head
        # Emits raw logits (no sigmoid here).
        self.head_np = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Conv2d(64, 1, 1) 
        )

        # 2. Horizontal-Vertical (HV) Head
        # Tanh bounds the two regression channels to [-1, 1].
        self.head_hv = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Conv2d(64, 2, 1),
            nn.Tanh()
        )

        # 3. Nucleus Type (TP) Head
        # Emits raw per-class logits (no softmax here).
        self.head_tp = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Conv2d(64, num_classes, 1)
        )

    def forward(self, x):
        """
        Run encoder, decoder and the three heads.

        Args:
            x: image batch (B, 3, H, W). The 2x top-down upsampling must line up
                with the encoder feature sizes, so H and W presumably need to be
                multiples of 32 (TODO confirm for other patch sizes).

        Returns:
            dict with keys ``"np"``, ``"hv"``, ``"tp"``.
        """
        c2, c3, c4, c5 = self.backbone(x)

        # Decoder (FPN)
        # Top-down pathway: upsample the coarser level, add the lateral
        # projection of the next finer one, then smooth with a 3x3 conv.
        p5 = self.lateral_c5(c5)

        p4 = self.upsample(p5) + self.lateral_c4(c4)
        p4 = self.smooth_p4(p4)

        p3 = self.upsample(p4) + self.lateral_c3(c3)
        p3 = self.smooth_p3(p3)

        p2 = self.upsample(p3) + self.lateral_c2(c2)
        p2 = self.smooth_p2(p2)

        # Final Decoder Block
        d1 = self.final_upsample1(p2)
        d1 = self.final_relu1(self.final_conv1(d1))

        d2 = self.final_upsample2(d1)
        d2 = self.final_relu2(self.final_conv2(d2))

        # Heads
        # All three heads share the same 64-channel decoder output.
        out_np = self.head_np(d2) 
        out_hv = self.head_hv(d2)
        out_tp = self.head_tp(d2)

        return {"np": out_np, "hv": out_hv, "tp": out_tp}


# Smoke test: build the model, push a random batch through it and check the
# output shapes. Any failure (including a failed assert) is only printed, so the
# notebook keeps running; check the cell output before relying on the model.
try:
    # Fallback defaults so this cell also works when the configuration cell has
    # not been executed in the current kernel.
    if 'TP_MAP_CHANNELS' not in locals():
         TP_MAP_CHANNELS = 5 # BG + 4 classes
    if 'DEVICE' not in locals():
        DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    if 'PATCH_SIZE' not in locals():
        PATCH_SIZE = 256

    model = HoverNet(num_classes=TP_MAP_CHANNELS).to(DEVICE)
    print("Model instantiated successfully.")
    print(f"TP Head output channels: {TP_MAP_CHANNELS}")

    # Batch of 2 random "images" at the training patch size.
    dummy_input = torch.randn(2, 3, PATCH_SIZE, PATCH_SIZE).to(DEVICE)
    output = model(dummy_input)
    print(f"Output shapes:")
    print(f"  NP: {output['np'].shape}")
    print(f"  HV: {output['hv'].shape}")
    print(f"  TP: {output['tp'].shape}")
    # Every head must predict at the input resolution.
    assert output['np'].shape == (2, 1, PATCH_SIZE, PATCH_SIZE)
    assert output['hv'].shape == (2, 2, PATCH_SIZE, PATCH_SIZE)
    assert output['tp'].shape == (2, TP_MAP_CHANNELS, PATCH_SIZE, PATCH_SIZE)
    print("Model test forward pass successful.")
    del model, dummy_input, output # Free memory
except Exception as e:
    print(f"Error during model test: {e}")

Loading ImageNet-pre-trained ConvNeXt-Tiny backbone from timm...
ConvNeXt-Tiny backbone loaded successfully.
Model instantiated successfully.
TP Head output channels: 5
Output shapes:
  NP: torch.Size([2, 1, 256, 256])
  HV: torch.Size([2, 2, 256, 256])
  TP: torch.Size([2, 5, 256, 256])
Model test forward pass successful.


## 6. Loss Functions

We need a combined loss function. We will use:
1.  **BCEWithLogitsLoss** for the `NP` (binary segmentation) head.
2.  **Mean Squared Error (MSE)** for the `HV` (regression) head. We only calculate this loss on pixels that are *inside* a nucleus (i.e., where `np_true == 1`).
3.  **CrossEntropyLoss** for the `TP` (classification) head.

**CRITICAL:** To handle the class imbalance, we will use a **Weighted Cross Entropy Loss** (or Focal Loss). We will provide our `CLASS_LOSS_WEIGHTS` `[1, 1, 10, 10]` to the `TP` loss function and tell it to `ignore_index=0` (the background class).

Let's implement **Focal Loss**, as it's generally superior to simple weighted CE for extreme imbalance.

In [None]:
class FocalLoss(nn.Module):
    """
    Focal Loss for multi-class classification.
    Assumes logits as input and class indices as target.

    Per element the loss is ``alpha[t] * (1 - p_t) ** gamma * (-log p_t)`` where
    ``p_t`` is the softmax probability of the target class ``t``.

    Args:
        alpha: optional 1-D tensor with one weight per class.
        gamma: focusing parameter; 0 reduces to (weighted) cross-entropy.
        reduction: ``'mean'``, ``'sum'`` or anything else for the per-element loss.
        ignore_index: target value that contributes no loss.

    ``'mean'`` averages over the elements that are NOT ignored, so the value
    does not shrink as the share of ignored pixels grows. If every element is
    ignored the result is 0.
    """
    def __init__(self, alpha=None, gamma=2.0, reduction='mean', ignore_index=-100):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha 
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, inputs, targets):
        """
        Args:
            inputs: logits, class dimension at index 1, e.g. (B, C, H, W) or (B, C).
            targets: class indices without the class dimension, e.g. (B, H, W) or (B,).

        Returns:
            Scalar tensor for ``'mean'``/``'sum'``, otherwise a tensor shaped
            like ``targets`` with 0 at ignored positions.
        """
        valid = targets != self.ignore_index

        # Unweighted negative log-likelihood (0 at ignored positions). It must
        # be unweighted so that exp(-ce) really is the target-class probability;
        # folding class weights in here would turn it into p_t ** weight.
        ce_loss = F.cross_entropy(inputs, targets,
                                  ignore_index=self.ignore_index,
                                  reduction='none')

        # Probability of the correct class.
        pt = torch.exp(-ce_loss)

        # The clamp keeps a rounding-induced pt slightly above 1 from producing
        # NaN under a fractional gamma.
        focal_loss = ((1 - pt).clamp(min=0) ** self.gamma) * ce_loss

        if self.alpha is not None:
            alpha = self.alpha.to(device=inputs.device, dtype=focal_loss.dtype)
            # Ignored targets may be out of range (e.g. -100); point them at
            # class 0 for the lookup. Their loss is already 0.
            safe_targets = targets.masked_fill(~valid, 0)
            focal_loss = focal_loss * alpha[safe_targets]

        if self.reduction == 'mean':
            # Normalise by the number of contributing elements, at least 1.
            return focal_loss.sum() / valid.sum().clamp(min=1).to(focal_loss.dtype)
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss


class CombinedHoverLoss(nn.Module):
    """
    Weighted sum of the three Hover-Net head losses.

    * NP head: ``BCEWithLogitsLoss`` against the binary nucleus mask.
    * HV head: squared error, counted only on nucleus pixels.
    * TP head: ``FocalLoss`` (gamma 1.5) with optional class weights.

    ``total = 2.0 * np + 2.0 * hv + 1.0 * tp``

    Args:
        tp_class_weights: optional 1-D tensor with one weight per foreground
            class. A weight of 0 for the background class (index 0) is
            prepended before it is passed to the focal loss.
        ignore_index: target class index excluded from the TP loss.
    """
    def __init__(self, tp_class_weights=None, ignore_index=0):
        super().__init__()

        self.loss_np = nn.BCEWithLogitsLoss()
        self.loss_hv = nn.MSELoss(reduction='none') 

        if tp_class_weights is not None:
            # Create the background weight on the same device and dtype as the
            # supplied weights instead of a global device, so torch.cat cannot
            # hit a device mismatch.
            background_weight = tp_class_weights.new_zeros(1)
            full_weights = torch.cat([background_weight, tp_class_weights])
        else:
            full_weights = None

        self.loss_tp = FocalLoss(alpha=full_weights, gamma=1.5, ignore_index=ignore_index)

        self.ignore_index = ignore_index

    def forward(self, preds, targets):
        """
        Args:
            preds: dict with ``"np"`` (B, 1, H, W) logits, ``"hv"`` (B, 2, H, W)
                and ``"tp"`` (B, C, H, W) logits.
            targets: dict with ``"np"`` (B, 1, H, W) in {0, 1}, ``"hv"``
                (B, 2, H, W) and ``"tp"`` (B, H, W) class indices.

        Returns:
            ``(total_loss, {"np": ..., "hv": ..., "tp": ...})``, all scalar tensors.
        """
        pred_np = preds['np'] # (B, 1, H, W)
        pred_hv = preds['hv'] # (B, 2, H, W)
        pred_tp = preds['tp'] # (B, C, H, W)

        true_np = targets['np'] # (B, 1, H, W)
        true_hv = targets['hv'] # (B, 2, H, W)
        true_tp = targets['tp'] # (B, H, W) - class indices

        # NP Loss (Binary Segmentation)
        loss_np = self.loss_np(pred_np, true_np)

        # HV Loss (Regression): per-pixel squared error, zeroed outside nuclei.
        # true_np broadcasts over the two HV channels.
        loss_hv_per_pixel = self.loss_hv(pred_hv, true_hv) 
        masked_loss_hv = loss_hv_per_pixel * true_np # (B, 2, H, W)

        # Both channels are summed but the divisor is the nucleus pixel count;
        # the epsilon covers patches without any nucleus.
        loss_hv = masked_loss_hv.sum() / (true_np.sum() + 1e-8)

        loss_tp = self.loss_tp(pred_tp, true_tp)

        total_loss = 2.0 * loss_np + 2.0 * loss_hv + 1.0 * loss_tp

        return total_loss, {"np": loss_np, "hv": loss_hv, "tp": loss_tp}

# Smoke test for the combined loss on random tensors. Only checks that the loss
# can be built and evaluated; the printed values carry no meaning. Failures are
# printed rather than raised.
try:
    criterion = CombinedHoverLoss(tp_class_weights=CLASS_LOSS_WEIGHTS, ignore_index=0).to(DEVICE)
    print("Loss function instantiated successfully.")

    # Random predictions on a small 64x64 grid, shaped like the model outputs.
    dummy_preds = {
        "np": torch.randn(BATCH_SIZE, 1, 64, 64).to(DEVICE),
        "hv": torch.randn(BATCH_SIZE, 2, 64, 64).to(DEVICE),
        "tp": torch.randn(BATCH_SIZE, TP_MAP_CHANNELS, 64, 64).to(DEVICE)
    }
    # Random targets: binary NP mask, unbounded HV values, class ids in [0, TP_MAP_CHANNELS).
    dummy_targets = {
        "np": torch.randint(0, 2, (BATCH_SIZE, 1, 64, 64)).float().to(DEVICE),
        "hv": torch.randn(BATCH_SIZE, 2, 64, 64).to(DEVICE),
        "tp": torch.randint(0, TP_MAP_CHANNELS, (BATCH_SIZE, 64, 64)).long().to(DEVICE)
    }

    total_loss, losses = criterion(dummy_preds, dummy_targets)
    print(f"Loss calculation successful:")
    print(f"  Total: {total_loss.item():.4f}")
    print(f"  NP Loss: {losses['np'].item():.4f}")
    print(f"  HV Loss: {losses['hv'].item():.4f}")
    print(f"  TP Loss: {losses['tp'].item():.4f}")
    # Release the temporary objects (and their GPU memory) before training.
    del criterion, dummy_preds, dummy_targets, total_loss, losses
except Exception as e:
    print(f"Error during loss function test: {e}")

Loss function instantiated successfully.
Loss calculation successful:
  Total: 18.2983
  NP Loss: 0.8088
  HV Loss: 4.0461
  TP Loss: 8.5886


## 7. Training and Validation Loop

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """
    Runs a single training epoch.

    Args:
        model: network returning a dict of predictions.
        loader: iterable of ``(images, targets)`` batches; a batch whose images
            or targets are ``None`` is skipped.
        criterion: callable returning ``(total_loss, {"np", "hv", "tp"})``.
        optimizer: optimizer stepping the model parameters.
        device: device the batches are moved to.

    Returns:
        ``(avg_loss, avg_losses)``: mean total loss and a mapping with the mean
        of each component, both averaged over the batches that were actually
        optimised (skipped batches do not count). ``(0.0, {})`` for an empty
        loader.
    """
    model.train()
    epoch_loss = 0.0
    epoch_losses = collections.defaultdict(float)

    num_batches = len(loader)
    if num_batches == 0:
        print("Warning: train_loader is empty.")
        return 0.0, epoch_losses

    # Batches that contributed to the running sums; this, not len(loader), is
    # the right divisor once batches can be skipped.
    processed_batches = 0

    pbar = tqdm(loader, desc="Training", leave=False, total=num_batches)
    for images, targets in pbar:
        # The collate function yields (None, None) when every sample failed.
        if images is None or targets is None:
            continue

        images = images.to(device)
        targets = {k: v.to(device) for k, v in targets.items()}

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        preds = model(images)

        # Calculate loss
        total_loss, losses = criterion(preds, targets)

        # Never backpropagate a NaN: it would poison the weights.
        if torch.isnan(total_loss):
            print("Warning: NaN loss detected. Skipping batch.")
            continue

        # Backward pass
        total_loss.backward()

        # Optimize
        optimizer.step()

        processed_batches += 1
        epoch_loss += total_loss.item()
        for k, v in losses.items():
            epoch_losses[k] += v.item()

        pbar.set_postfix(loss=f"{total_loss.item():.4f}",
                         np=f"{losses['np'].item():.4f}",
                         hv=f"{losses['hv'].item():.4f}",
                         tp=f"{losses['tp'].item():.4f}")

    # max(..., 1) keeps the all-batches-skipped case from dividing by zero.
    divisor = max(processed_batches, 1)
    avg_loss = epoch_loss / divisor
    avg_losses = {k: v / divisor for k, v in epoch_losses.items()}
    return avg_loss, avg_losses

def validate_epoch(model, loader, criterion, device):
    """
    Runs a single validation epoch.

    Args:
    - model: network invoked as `model(images)`; switched to eval mode here.
    - loader: sized iterable yielding `(images, targets)` batches, where
      `targets` is a dict of tensors. A batch whose `images` or `targets` is
      None is skipped.
    - criterion: callable returning `(total_loss, losses)`; `losses` is a dict
      of scalar tensors.
    - device: device the batch tensors are moved to.

    Returns:
    - (avg_loss, avg_losses): mean total loss and mean of every entry of
      `losses`, averaged over the batches that were actually evaluated.
      Skipped batches (None batches, NaN losses) do not count towards the
      average. `(0.0, <empty dict>)` is returned when nothing was evaluated.
    """
    model.eval()
    epoch_loss = 0.0
    epoch_losses = collections.defaultdict(float)

    num_batches = len(loader)
    if num_batches == 0:
        print("Warning: val_loader is empty.")
        return 0.0, epoch_losses

    # Only batches that contribute a loss are counted, so skipped batches do
    # not make the validation loss look artificially low.
    processed_batches = 0

    with torch.no_grad():
        pbar = tqdm(loader, desc="Validating", leave=False, total=num_batches)
        for images, targets in pbar:
            if images is None or targets is None:
                continue

            images = images.to(device)
            targets = {k: v.to(device) for k, v in targets.items()}

            # Forward pass
            preds = model(images)

            # Calculate loss
            total_loss, losses = criterion(preds, targets)

            if torch.isnan(total_loss):
                continue

            # Read the scalar once; every .item() call forces a device sync.
            loss_value = total_loss.item()
            epoch_loss += loss_value
            for k, v in losses.items():
                epoch_losses[k] += v.item()
            processed_batches += 1

            pbar.set_postfix(loss=f"{loss_value:.4f}")

    if processed_batches == 0:
        print("Warning: no valid validation batches were processed this epoch.")
        return 0.0, epoch_losses

    avg_loss = epoch_loss / processed_batches
    avg_losses = {k: v / processed_batches for k, v in epoch_losses.items()}
    return avg_loss, avg_losses


# Training driver: builds the datasets/loaders, model, loss, optimizer and LR
# scheduler, then runs EPOCHS epochs and checkpoints the weights with the
# lowest validation loss. All names created here (train_loader, model, ...)
# become notebook globals.
if train_ids is None or not train_ids:
    print("Cannot start training: train_ids is empty. Check Stratified Split step.")
else:
    print("Initializing datasets and dataloaders...")

    # Augmentation is enabled for the training split only.
    train_dataset = CellDataset(DATA_DIR, MASK_DIR, train_ids, patch_size=PATCH_SIZE, augment=True)
    val_dataset = CellDataset(DATA_DIR, MASK_DIR, val_ids, patch_size=PATCH_SIZE, augment=False)

    # NOTE(review): collate_fn_skip_errors is defined elsewhere; the epoch
    # functions above skip batches whose images/targets are None, so it
    # presumably yields None entries for unreadable samples - confirm.
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        collate_fn=collate_fn_skip_errors, 
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        collate_fn=collate_fn_skip_errors,
        pin_memory=True
    )
    print("Dataloaders created.")

    model = HoverNet(num_classes=TP_MAP_CHANNELS).to(DEVICE)
    # NOTE(review): ignore_index=0 presumably excludes background pixels from
    # the type (TP) loss - confirm against CombinedHoverLoss.
    criterion = CombinedHoverLoss(tp_class_weights=CLASS_LOSS_WEIGHTS, ignore_index=0).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4) 
    # The scheduler is stepped with the validation loss once per epoch (below).
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

    print("Starting training...")
    best_val_loss = float('inf')

    for epoch in range(1, EPOCHS + 1):
        print(f"--- Epoch {epoch}/{EPOCHS} ---")
        start_time = time.time()

        train_loss, train_losses_breakdown = train_epoch(model, train_loader, criterion, optimizer, DEVICE)
        val_loss, val_losses_breakdown = validate_epoch(model, val_loader, criterion, DEVICE)

        end_time = time.time()
        epoch_mins = (end_time - start_time) / 60

        print(f"Epoch {epoch} Summary | Time: {epoch_mins:.2f}m")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"    (NP: {train_losses_breakdown.get('np', 0):.4f}, HV: {train_losses_breakdown.get('hv', 0):.4f}, TP: {train_losses_breakdown.get('tp', 0):.4f}) ")
        print(f"  Valid Loss: {val_loss:.4f}")
        print(f"    (NP: {val_losses_breakdown.get('np', 0):.4f}, HV: {val_losses_breakdown.get('hv', 0):.4f}, TP: {val_losses_breakdown.get('tp', 0):.4f}) ")

        scheduler.step(val_loss)

        # Checkpoint only on strict improvement of the validation loss; the
        # file at MODEL_SAVE_PATH is overwritten each time.
        if val_loss < best_val_loss:
            print(f"==> Validation loss improved from {best_val_loss:.4f} to {val_loss:.4f}. Saving model to {MODEL_SAVE_PATH}")
            best_val_loss = val_loss
            torch.save(model.state_dict(), MODEL_SAVE_PATH)
        else:
            print(f"Validation loss did not improve from {best_val_loss:.4f}.")

    print("--- Training Finished ---")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Model saved to {MODEL_SAVE_PATH}")


Initializing datasets and dataloaders...
Initializing dataset with 167 images...


Dataset initialized. Found 167 valid image-mask pairs.
Initializing dataset with 42 images...


Dataset initialized. Found 42 valid image-mask pairs.
Dataloaders created.
Loading ImageNet-pre-trained ConvNeXt-Tiny backbone from timm...
ConvNeXt-Tiny backbone loaded successfully.
Starting training...
--- Epoch 1/50 ---


KeyboardInterrupt: 

## 8. Inference and Post-processing

This is the final step. We will:
1.  Load our best saved checkpoint (the weights file at `MODEL_SAVE_PATH`).
2.  Find all images in the `TEST_DIR`.
3.  For each test image:
    a.  Perform **sliding window inference** (with overlap) to get full-sized `NP`, `HV`, and `TP` prediction maps.
    b.  Run the **post-processing** pipeline on these maps to generate instance segmentations.
    c.  **Encode** the instances into the 4 required RLE strings.
4.  Save all RLE strings to `submission.csv`.

### 8a. Post-processing Pipeline

This is the most complex part of the inference. We must convert the model's three output maps into final instance masks.

1.  **NP Map** (`pred_np`): Threshold this to get a binary mask of all nuclei.
2.  **HV Map** (`pred_hv`): We don't use this directly for Watershed. Instead, we use it to find *markers* (seed points) for each nucleus. We'll use the NP map and `skimage.feature.peak_local_max` with a distance transform to get good, separated markers. A more advanced method would use the HV map to find energy basins, but this is a robust start.
3.  **TP Map** (`pred_tp`): Take an `argmax` to get a pixel-wise class prediction `(H, W)`.
4.  **Watershed:** Use `skimage.segmentation.watershed` with our `markers` and the *negated distance transform* (`-distance`) of the thresholded `NP` mask as the landscape, so that each nucleus centre becomes a basin. This will "flood-fill" from each marker until it meets another, separating touching nuclei.
5.  **Assign Classes:** For each final instance ID from Watershed, find all the pixels belonging to it. Look up those pixels in the `TP Map (argmax)` and assign the **majority vote** (mode) class to that instance.
6.  **Format:** Separate the instances into 4 final masks (one per class) and encode.

In [None]:
def post_process(pred_np, pred_hv, pred_tp, np_thresh=0.6, marker_min_dist=9):
    """
    Post-processes the raw model outputs into class-specific instance masks.

    Args:
    - pred_np: (H, W) numpy array, logits or sigmoids for nucleus probability.
      A 3-D array is squeezed first.
    - pred_hv: (2, H, W) numpy array, (y, x) vectors. Accepted for interface
      compatibility; it is not used by this implementation.
    - pred_tp: (5, H, W) numpy array, logits for [BG, C1, C2, C3, C4].
    - np_thresh: Threshold to binarize the NP map.
    - marker_min_dist: Minimum distance between seeds for watershed.

    Returns:
    - A dictionary {class_name: instance_mask} for the 4 cell types. Each mask
      is a uint16 array in which 0 is background and instances of that class
      are numbered 1, 2, ... Instances whose majority type is background (or
      not in CLASS_MAP) are dropped.
    """

    if pred_np.ndim == 3:
        pred_np = pred_np.squeeze()

    # Heuristic: values outside [0, 1] mean the map holds logits, so squash
    # them to probabilities before thresholding.
    if np.max(pred_np) > 1.0 or np.min(pred_np) < 0.0:
        pred_np = 1 / (1 + np.exp(-pred_np))

    binary_mask = (pred_np > np_thresh).astype(np.uint8)

    # Pixel-wise class prediction.
    type_map = np.argmax(pred_tp, axis=0).astype(np.uint8)

    # Seeds are the local maxima of the distance-to-background map.
    distance = distance_transform_edt(binary_mask)

    coords = peak_local_max(distance, min_distance=marker_min_dist, labels=binary_mask)
    markers = np.zeros_like(binary_mask, dtype=bool)
    markers[tuple(coords.T)] = True

    markers, num_features = label(markers)

    final_instance_masks = {
        name: np.zeros(pred_np.shape, dtype=np.uint16) for name in SUBMISSION_CLASSES
    }

    if num_features == 0:
        return final_instance_masks

    instance_map = watershed(-distance, markers, mask=binary_mask)

    next_instance_id_per_class = {name: 1 for name in SUBMISSION_CLASSES}

    # find_objects yields one bounding box per label (index i -> label i + 1),
    # so each nucleus is handled inside its own small crop instead of building
    # a full-image boolean mask for every instance.
    for inst_id, bbox in enumerate(find_objects(instance_map), start=1):
        if bbox is None:  # label value not present in instance_map
            continue

        inst_pixels = (instance_map[bbox] == inst_id)

        # Majority vote over the instance's pixels; argmax returns the
        # smallest class index on ties, like scipy.stats.mode.
        inst_class_idx = int(np.bincount(type_map[bbox][inst_pixels]).argmax())

        if inst_class_idx == 0 or inst_class_idx not in CLASS_MAP:
            continue

        class_name = CLASS_MAP[inst_class_idx]

        new_inst_id = next_instance_id_per_class[class_name]

        # final_instance_masks[class_name][bbox] is a view, so this writes
        # into the full-size mask.
        final_instance_masks[class_name][bbox][inst_pixels] = new_inst_id

        next_instance_id_per_class[class_name] += 1

    return final_instance_masks

### 8b. Sliding Window Inference

We can't pass a 4000x4000 image to the model. We must:
1.  Create overlapping patches from the large image.
2.  Run the model on each patch.
3.  Stitch the predictions back together, averaging the results in the overlapping regions to get a smooth final map.

In [None]:
def sliding_window_inference(model, image, patch_size, overlap, device):
    """
    Performs sliding window inference on a large image.
    Returns full-sized prediction maps.

    Args:
    - model: network called on a (1, C, patch_size, patch_size) float tensor
      scaled to [0, 1]; must return a dict with 'np', 'hv' and 'tp' tensors.
    - image: (H, W, C) numpy array with values in the 0-255 range.
    - patch_size: side length of the square patches fed to the model.
    - overlap: number of pixels shared by neighbouring patches.
    - device: device the patch tensors are moved to.

    Returns:
    - (pred_np, pred_hv, pred_tp): arrays of shape (H, W), (2, H, W) and
      (num_tp_channels, H, W). Overlapping predictions are averaged. The NP
      map holds the raw model output; no sigmoid is applied here.

    Raises:
    - ValueError: if `overlap` is not smaller than `patch_size`.
    """
    model.eval()

    H, W, C = image.shape
    stride = patch_size - overlap
    if stride <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than patch_size ({patch_size})."
        )

    # Pad bottom/right so the patch grid tiles the image exactly. Images
    # smaller than one patch are padded up to a full patch; otherwise no
    # window would fit and nothing would be predicted.
    if H < patch_size:
        pad_h = patch_size - H
    else:
        pad_h = (stride - (H - patch_size) % stride) % stride
    if W < patch_size:
        pad_w = patch_size - W
    else:
        pad_w = (stride - (W - patch_size) % stride) % stride

    padded_image = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), mode='constant')
    H_pad, W_pad, _ = padded_image.shape

    pred_map_np = np.zeros((H_pad, W_pad), dtype=np.float32)
    pred_map_hv = np.zeros((2, H_pad, W_pad), dtype=np.float32)
    # Allocated on the first patch so the channel count follows the model
    # output instead of a notebook-level constant.
    pred_map_tp = None

    # Number of patches that covered each pixel.
    count_map = np.zeros((H_pad, W_pad), dtype=np.float32)

    with torch.no_grad():
        for y in range(0, H_pad - patch_size + 1, stride):
            for x in range(0, W_pad - patch_size + 1, stride):
                patch = padded_image[y:y+patch_size, x:x+patch_size]

                patch_tensor = torch.from_numpy(patch.transpose(2, 0, 1)).float() / 255.0
                patch_tensor = patch_tensor.unsqueeze(0).to(device)

                preds = model(patch_tensor)

                pred_np_patch = preds['np'].squeeze().cpu().numpy() # (H, W)
                pred_hv_patch = preds['hv'].squeeze().cpu().numpy() # (2, H, W)
                pred_tp_patch = preds['tp'].squeeze().cpu().numpy() # (C_tp, H, W)

                if pred_map_tp is None:
                    pred_map_tp = np.zeros(
                        (pred_tp_patch.shape[0], H_pad, W_pad), dtype=np.float32
                    )

                pred_map_np[y:y+patch_size, x:x+patch_size] += pred_np_patch
                pred_map_hv[:, y:y+patch_size, x:x+patch_size] += pred_hv_patch
                pred_map_tp[:, y:y+patch_size, x:x+patch_size] += pred_tp_patch

                count_map[y:y+patch_size, x:x+patch_size] += 1.0

    # Guard against 0/0 for any pixel no patch covered (only possible with a
    # negative overlap, where the stride exceeds the patch size).
    count_map[count_map == 0] = 1

    final_pred_np = pred_map_np / count_map
    final_pred_hv = pred_map_hv / count_map
    final_pred_tp = pred_map_tp / count_map

    # Crop the padding away again.
    final_pred_np = final_pred_np[0:H, 0:W]
    final_pred_hv = final_pred_hv[:, 0:H, 0:W]
    final_pred_tp = final_pred_tp[:, 0:H, 0:W]

    return final_pred_np, final_pred_hv, final_pred_tp

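**Weighted Panoptic Quality (wPQ):** the helpers below compute the Panoptic Quality (PQ) of each cell class on a validation image and combine the per-class scores into a single weighted average using `CLASS_WEIGHTS`.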


In [None]:
from scipy.optimize import linear_sum_assignment

def sliding_window_inference(model, image, patch_size, overlap, device):
    """
    Runs the model over overlapping patches of `image` and stitches the
    per-patch predictions into full-size maps.

    Args:
    - model: network called on a (1, C, patch_size, patch_size) float tensor
      scaled to [0, 1]; must return a dict with 'np', 'hv' and 'tp' tensors.
    - image: (H, W, C) numpy array with values in the 0-255 range.
    - patch_size: side length of the square patches fed to the model.
    - overlap: number of pixels shared by neighbouring patches.
    - device: device the patch tensors are moved to.

    Returns:
    - (pred_np, pred_hv, pred_tp): arrays of shape (H, W), (2, H, W) and
      (num_tp_channels, H, W). Overlapping predictions are averaged. The NP
      map holds sigmoid probabilities.

    Raises:
    - ValueError: if `overlap` is not smaller than `patch_size`.
    """
    model.eval()
    H, W, C = image.shape
    stride = patch_size - overlap
    if stride <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than patch_size ({patch_size})."
        )

    # Pad bottom/right so the patch grid tiles the image exactly. Images
    # smaller than one patch are padded up to a full patch; otherwise no
    # window would fit and the maps would silently stay all-zero.
    if H < patch_size:
        pad_h = patch_size - H
    else:
        pad_h = (stride - (H - patch_size) % stride) % stride
    if W < patch_size:
        pad_w = patch_size - W
    else:
        pad_w = (stride - (W - patch_size) % stride) % stride

    padded_image = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), mode='constant')
    H_pad, W_pad, _ = padded_image.shape
    pred_map_np = np.zeros((H_pad, W_pad), dtype=np.float32)
    pred_map_hv = np.zeros((2, H_pad, W_pad), dtype=np.float32)
    # Allocated on the first patch so the channel count follows the model
    # output instead of a notebook-level constant.
    pred_map_tp = None
    count_map = np.zeros((H_pad, W_pad), dtype=np.float32)
    with torch.no_grad():
        for y in range(0, H_pad - patch_size + 1, stride):
            for x in range(0, W_pad - patch_size + 1, stride):
                patch = padded_image[y:y+patch_size, x:x+patch_size]
                patch_tensor = torch.from_numpy(patch.transpose(2, 0, 1)).float() / 255.0
                patch_tensor = patch_tensor.unsqueeze(0).to(device)
                preds = model(patch_tensor)
                # NP logits become probabilities before averaging.
                pred_np_patch = torch.sigmoid(preds['np']).squeeze().cpu().numpy()
                pred_hv_patch = preds['hv'].squeeze().cpu().numpy()
                pred_tp_patch = preds['tp'].squeeze().cpu().numpy()
                if pred_map_tp is None:
                    pred_map_tp = np.zeros(
                        (pred_tp_patch.shape[0], H_pad, W_pad), dtype=np.float32
                    )
                pred_map_np[y:y+patch_size, x:x+patch_size] += pred_np_patch
                pred_map_hv[:, y:y+patch_size, x:x+patch_size] += pred_hv_patch
                pred_map_tp[:, y:y+patch_size, x:x+patch_size] += pred_tp_patch
                count_map[y:y+patch_size, x:x+patch_size] += 1.0
    # Avoid 0/0 for any pixel that no patch covered.
    count_map[count_map == 0] = 1
    final_pred_np = pred_map_np / count_map
    final_pred_hv = pred_map_hv / count_map
    final_pred_tp = pred_map_tp / count_map
    # Crop the padding away again.
    final_pred_np = final_pred_np[0:H, 0:W]
    final_pred_hv = final_pred_hv[:, 0:H, 0:W]
    final_pred_tp = final_pred_tp[:, 0:H, 0:W]
    return final_pred_np, final_pred_hv, final_pred_tp

def get_pq(pred_mask, gt_mask, iou_thresh=0.5):
    """
    Calculates Panoptic Quality (PQ) between two instance masks of one class.

    Args:
    - pred_mask: integer array; 0 is background, other values are instance ids.
    - gt_mask: integer array of the same shape, labelled the same way.
    - iou_thresh: minimum IoU for an assigned (prediction, ground truth) pair
      to count as a true positive.

    Returns:
    - (pq, iou_sum, tp, fp, fn). When both masks have no instances the result
      is (1.0, 0, 0, 0, 0); when exactly one is empty, PQ is 0.0 and fp/fn
      hold the instance counts of the prediction / ground truth.

    Raises:
    - ValueError: if the two masks have different shapes.
    """
    pred_mask = np.asarray(pred_mask)
    gt_mask = np.asarray(gt_mask)
    if pred_mask.shape != gt_mask.shape:
        raise ValueError(
            f"pred_mask shape {pred_mask.shape} does not match gt_mask shape {gt_mask.shape}."
        )

    # Filter label 0 explicitly: slicing [1:] would drop a real instance
    # whenever a mask contains no background pixel.
    pred_vals, pred_inv = np.unique(pred_mask, return_inverse=True)
    gt_vals, gt_inv = np.unique(gt_mask, return_inverse=True)
    pred_keep = pred_vals != 0
    gt_keep = gt_vals != 0
    num_pred = int(pred_keep.sum())
    num_gt = int(gt_keep.sum())

    if num_pred == 0 and num_gt == 0:
        return 1.0, 0, 0, 0, 0
    elif num_pred == 0 or num_gt == 0:
        return 0.0, 0, 0, num_pred, num_gt

    # One pass over the pixels: joint[i, j] counts the pixels carrying
    # prediction label i and ground-truth label j (background included).
    joint = np.bincount(
        pred_inv.ravel() * len(gt_vals) + gt_inv.ravel(),
        minlength=len(pred_vals) * len(gt_vals),
    ).reshape(len(pred_vals), len(gt_vals))
    pred_area = joint.sum(axis=1)
    gt_area = joint.sum(axis=0)

    intersection = joint[pred_keep][:, gt_keep]
    union = pred_area[pred_keep][:, None] + gt_area[gt_keep][None, :] - intersection
    # Every kept label owns at least one pixel, so union is never zero.
    iou_matrix = (intersection / union).astype(np.float32)

    # One-to-one assignment that maximises the total IoU.
    row_ind, col_ind = linear_sum_assignment(-iou_matrix)
    assigned_iou = iou_matrix[row_ind, col_ind]
    matched = assigned_iou >= iou_thresh

    tp = int(matched.sum())
    iou_sum = float(assigned_iou[matched].sum(dtype=np.float64))
    # Each assignment uses a distinct prediction and a distinct ground truth.
    fp = num_pred - tp
    fn = num_gt - tp
    pq = iou_sum / (tp + 0.5 * fp + 0.5 * fn + 1e-8)
    return pq, iou_sum, tp, fp, fn

def compute_pq_for_image(pred_masks_by_class, gt_masks_by_class):
    """
    Computes the weighted Panoptic Quality (wPQ) of one image.

    Args:
    - pred_masks_by_class: dict {class_name: predicted instance mask}.
    - gt_masks_by_class: dict {class_name: ground-truth instance mask}.
      A class missing from either dict is treated as an all-background mask.

    Returns:
    - (wpq, pq_scores): the CLASS_WEIGHTS-weighted mean of the per-class PQ
      values, and a dict {class_name: pq} for every class in CLASS_WEIGHTS.
    """
    # Template for classes missing from a dict. It is resolved once, and
    # without indexing into a possibly empty dict.
    if gt_masks_by_class:
        empty_mask = np.zeros_like(next(iter(gt_masks_by_class.values())))
    elif pred_masks_by_class:
        empty_mask = np.zeros_like(next(iter(pred_masks_by_class.values())))
    else:
        empty_mask = np.zeros((1, 1), dtype=np.uint16)

    pq_scores = {}
    total_pq = 0.0
    total_weight = 0.0
    for class_name, class_weight in CLASS_WEIGHTS.items():
        pred_mask = pred_masks_by_class.get(class_name, empty_mask)
        gt_mask = gt_masks_by_class.get(class_name, empty_mask)

        pq, iou_sum, tp, fp, fn = get_pq(pred_mask, gt_mask)
        pq_scores[class_name] = pq
        total_pq += pq * class_weight
        total_weight += class_weight

    # No classes / zero total weight: report 0 instead of dividing by zero.
    if total_weight == 0:
        return 0.0, pq_scores

    wpq = total_pq / total_weight
    return wpq, pq_scores

def evaluate_model_wpq(model, image_ids, device):
    """
    Scores a model with wPQ over a list of image ids.

    For every id the image is read from DATA_DIR, run through sliding-window
    inference and post-processing, and compared with the ground truth stored
    in MASK_DIR (an .npz holding 'instance_map' and 'type_map'). Images that
    are missing or raise an error are reported and skipped.

    Args:
    - model: network passed on to `sliding_window_inference`.
    - image_ids: ids; `<id>.tif` and `<id>.npz` are the expected file names.
    - device: device used for inference.

    Returns:
    - (avg_wpq, avg_pq_by_class): mean wPQ over the scored images (0.0 if
      none was scored) and a pandas Series with the mean PQ of each class.
    """
    model.eval()
    wpq_per_image = []
    pq_per_class = collections.defaultdict(list)
    total_images = len(image_ids)

    print(f"Starting wPQ evaluation on {total_images} validation images...")

    for position, image_id in enumerate(image_ids, start=1):

        print(f"\n[Image {position}/{total_images}] STARTING: {image_id}")
        started_at = time.time()

        try:
            # --- load the image and bring it to 3-channel uint8 ---
            img_path = os.path.join(DATA_DIR, f"{image_id}.tif")
            if not os.path.exists(img_path):
                print(f"  Error: Image file not found at {img_path}. Skipping.")
                continue
            image = tifffile.imread(img_path)

            if image.ndim == 3 and image.shape[-1] == 4:
                image = image[:, :, :3]  # keep only the first three channels
            if image.ndim == 2:
                image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
            if image.dtype != np.uint8:
                peak = image.max()
                if peak > 255:
                    image = (image / peak * 255).astype(np.uint8)
                else:
                    image = image.astype(np.uint8)
            H, W, _ = image.shape
            print(f"  Loaded image {image_id} with shape ({H}, {W})")

            # --- predict ---
            print(f"  Running sliding window inference...")
            pred_np, pred_hv, pred_tp = sliding_window_inference(
                model, image, PATCH_SIZE, OVERLAP, device
            )

            print(f"  Running post-processing...")
            pred_masks_by_class = post_process(pred_np, pred_hv, pred_tp, np_thresh=0.6)

            # --- load the ground truth and split it per class ---
            print(f"  Loading ground truth masks...")
            mask_path = os.path.join(MASK_DIR, f"{image_id}.npz")
            if not os.path.exists(mask_path):
                print(f"  Error: Mask file not found at {mask_path}. Skipping.")
                continue
            with np.load(mask_path) as data:
                gt_instance_map = data['instance_map']
                gt_type_map = data['type_map']

            gt_masks_by_class = {}
            for class_id, class_name in CLASS_MAP.items():
                ids_in_class = np.unique(gt_instance_map[gt_type_map == class_id])
                ids_in_class = ids_in_class[ids_in_class != 0]
                class_mask = np.zeros((H, W), dtype=np.uint16)
                # Renumber this class's instances 1..K.
                for relabelled_id, original_id in enumerate(ids_in_class, 1):
                    class_mask[gt_instance_map == original_id] = relabelled_id
                gt_masks_by_class[class_name] = class_mask

            # --- score ---
            print(f"  Calculating wPQ score...")
            wpq, pq_by_class = compute_pq_for_image(pred_masks_by_class, gt_masks_by_class)

            wpq_per_image.append(wpq)
            for class_name, pq in pq_by_class.items():
                pq_per_class[class_name].append(pq)

            finished_at = time.time()
            print(f"  FINISHED: {image_id} | Time: {finished_at - started_at:.2f}s | wPQ: {wpq:.4f}")

        except Exception as e:
            # Best effort: one bad image must not abort the whole evaluation.
            print(f"  --- CRITICAL ERROR on {image_id}: {e}. Skipping. ---")
            continue

    if wpq_per_image:
        avg_wpq = np.mean(wpq_per_image)
    else:
        avg_wpq = 0.0

    class_means = {}
    for class_name, scores in pq_per_class.items():
        class_means[class_name] = np.mean(scores)
    avg_pq_by_class = pd.Series(class_means)

    return avg_wpq, avg_pq_by_class

In [None]:
# Type-map index -> class name. Index 0 (background) is intentionally absent.
CLASS_MAP = {
    1: 'Epithelial',
    2: 'Lymphocyte',
    3: 'Neutrophil',
    4: 'Macrophage'
}
# Per-class weights used by compute_pq_for_image when averaging PQ into wPQ.
CLASS_WEIGHTS = {
    "Epithelial": 1,
    "Lymphocyte": 1,
    "Neutrophil": 10,
    "Macrophage": 10
}
DATA_DIR = "train"
MASK_DIR = "mask_new"
TEST_DIR = "test_final"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
PATCH_SIZE = 256
OVERLAP = 64
TP_MAP_CHANNELS = len(CLASS_MAP) + 1 # BG + 4 classes
MODEL_SAVE_PATH = "best_model_new.pth"

# Evaluation driver: score the saved checkpoint with wPQ on the validation
# images that are small enough to evaluate quickly.
print(f"Loading best model from {MODEL_SAVE_PATH}...")
if not os.path.exists(MODEL_SAVE_PATH):
    print(f"Error: Model file not found at {MODEL_SAVE_PATH}. Please run training first.")
elif 'val_ids' not in locals() or not val_ids:
    print("Error: 'val_ids' not found. Please run the stratified split cell first.")
else:
    print("Finding 'fast' validation images...")

    # Images with fewer pixels than this go into the "fast" set.
    PIXEL_THRESHOLD = 1000000

    val_ids_fast = []
    val_ids_slow = []

    for image_id in val_ids:
        img_path = os.path.join(DATA_DIR, f"{image_id}.tif")
        try:
            # Only the shape of the first TIFF page is inspected here.
            with tifffile.TiffFile(img_path) as tif:
                shape = tif.pages[0].shape

            num_pixels = shape[0] * shape[1]

            if num_pixels < PIXEL_THRESHOLD:
                val_ids_fast.append(image_id)
            else:
                val_ids_slow.append(image_id)
        except Exception as e:
            # Unreadable shape: keep the image in the fast set anyway.
            print(f"Warning: Could not read shape for {image_id}. Error: {e}. Adding to fast list.")
            val_ids_fast.append(image_id)

    print(f"Total validation images: {len(val_ids)}")
    print(f"Found {len(val_ids_fast)} images below {PIXEL_THRESHOLD} pixels (fast set).")
    print(f"Found {len(val_ids_slow)} images above threshold (slow set): {val_ids_slow}")
    print("---")

    model_for_eval = HoverNet(num_classes=TP_MAP_CHANNELS).to(DEVICE)
    model_for_eval.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
    model_for_eval.eval()
    print("Model loaded successfully.")

    # Only the fast subset is scored; the slow images are listed above but
    # not evaluated.
    print(f"Starting FAST wPQ evaluation on {len(val_ids_fast)} validation images...")
    avg_wpq, avg_pq_by_class = evaluate_model_wpq(model_for_eval, val_ids_fast, DEVICE)

    print("\n--- FAST Evaluation Complete ---")
    print(f"  (Score is based on {len(val_ids_fast)} / {len(val_ids)} images)")
    print(f"  Average wPQ Score: {avg_wpq:.4f}")
    print("\n  Average PQ by Class:")
    print(avg_pq_by_class.to_string(float_format="%.4f"))

Loading best model from best_model_new_2.pth...
Finding 'fast' validation images...
Total validation images: 42
Found 38 images below 1000000 pixels (fast set).
Found 4 images above threshold (slow set): ['slide93', 'slide61', 'slide162', 'slide72']
---
Loading ImageNet-pre-trained ConvNeXt-Tiny backbone from timm...
ConvNeXt-Tiny backbone loaded successfully.
Model loaded successfully.
Starting FAST wPQ evaluation on 38 validation images...
Starting wPQ evaluation on 38 validation images...

[Image 1/38] STARTING: slide158
  Loaded image slide158 with shape (488, 586)
  Running sliding window inference...
  Running post-processing...
  Loading ground truth masks...
  Calculating wPQ score...
  FINISHED: slide158 | Time: 0.42s | wPQ: 0.7897

[Image 2/38] STARTING: slide32
  Loaded image slide32 with shape (270, 480)
  Running sliding window inference...
  Running post-processing...
  Loading ground truth masks...
  Calculating wPQ score...
  FINISHED: slide32 | Time: 0.35s | wPQ: 0.677

### 8c. Generate Submission File

In [None]:
def generate_submission(model_path, test_dir, submission_path):
    """
    Runs inference on every .tif image in `test_dir` and writes the results
    to a CSV file.

    Args:
    - model_path: path of the saved HoverNet state dict.
    - test_dir: directory searched for `*.tif` test images.
    - submission_path: destination of the CSV file.

    Each CSV row holds the image id plus one encoded instance mask per cell
    class. If an image fails at any stage, the error is printed and "0" is
    written for all four classes of that image. Nothing is written when the
    model file or the test images are missing.
    """
    print("--- Starting Inference and Submission Generation ---")

    if not os.path.exists(model_path):
        print(f"Error: Model file not found at {model_path}. Cannot generate submission.")
        return

    print(f"Loading model from {model_path}...")
    model = HoverNet(num_classes=TP_MAP_CHANNELS).to(DEVICE)
    state_dict = torch.load(model_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model.eval()
    print("Model loaded successfully.")

    test_image_paths = sorted(glob.glob(os.path.join(test_dir, "*.tif")))
    if len(test_image_paths) == 0:
        print(f"Error: No .tif images found in {test_dir}. Cannot generate submission.")
        return

    print(f"Found {len(test_image_paths)} test images.")

    # Column order of a submission row (after "image_id").
    class_columns = ("Epithelial", "Lymphocyte", "Neutrophil", "Macrophage")
    submission_data = []

    for img_path in tqdm(test_image_paths, desc="Processing Test Images"):
        # The id is the file name up to its first dot.
        image_id = os.path.basename(img_path).split('.')[0]

        try:
            image = tifffile.imread(img_path)

            # Bring the image to 3-channel uint8.
            if image.ndim == 3 and image.shape[-1] == 4:
                image = image[:, :, :3]
            if image.ndim == 2:
                image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
            if image.dtype != np.uint8:
                peak = image.max()
                if peak > 255:
                    image = (image / peak * 255).astype(np.uint8)
                else:
                    image = image.astype(np.uint8)

            pred_np, pred_hv, pred_tp = sliding_window_inference(
                model, image, patch_size=PATCH_SIZE, overlap=OVERLAP, device=DEVICE
            )

            instance_masks_by_class = post_process(
                pred_np, pred_hv, pred_tp, np_thresh=0.5, marker_min_dist=9
            )

            rle_strings = {}
            for class_name in SUBMISSION_CLASSES:
                rle_strings[class_name] = rle_encode_instance_mask(
                    instance_masks_by_class[class_name]
                )

            row = {"image_id": image_id}
            for column in class_columns:
                row[column] = rle_strings[column]

        except Exception as e:
            # Keep one row per image so the submission stays complete.
            print(f"Error processing {image_id}. Appending '0' for all classes. Error: {e}")
            row = {"image_id": image_id}
            for column in class_columns:
                row[column] = "0"

        submission_data.append(row)

    submission_df = pd.DataFrame(submission_data)
    ordered_columns = ["image_id"] + SUBMISSION_CLASSES
    submission_df = submission_df[ordered_columns]

    submission_df.to_csv(submission_path, index=False)
    print(f"\nSubmission file saved to: {submission_path}")
    print(submission_df.head())

# Build the submission CSV from the checkpoint at MODEL_SAVE_PATH using the
# images in TEST_DIR. SUBMISSION_PATH is defined in an earlier cell.
generate_submission(
    model_path=MODEL_SAVE_PATH,
    test_dir=TEST_DIR,
    submission_path=SUBMISSION_PATH
)

--- Starting Inference and Submission Generation ---
Loading model from best_model_new_2.pth...
Loading ImageNet-pre-trained ConvNeXt-Tiny backbone from timm...
ConvNeXt-Tiny backbone loaded successfully.
Model loaded successfully.
Found 40 test images.



Submission file saved to: submission_new_2.csv
  image_id                                         Epithelial  \
0   slide1                                                  0   
1  slide10                                                  0   
2  slide11  64 308 4 114 572 22 64 938 11 114 1204 25 64 1...   
3  slide12                                                  0   
4  slide13                                                  0   

                                          Lymphocyte  \
0  1 31121 4 1 31276 8 1 31432 10 1 31588 12 1 31...   
1  527 500 17 527 1135 18 527 1769 19 440 2327 6 ...   
2  1 414286 7 1 414917 11 1 415549 13 1 416181 15...   
3                                                  0   
4                                                  0   

                                          Neutrophil  \
0                                                  0   
1  1 507800 7 1 508433 11 1 509067 14 1 509700 17...   
2                                                  0   


Performing Grid Search to calculate threshold and window size

In [None]:
from scipy.optimize import linear_sum_assignment

# Type-map index -> class name. Index 0 (background) is intentionally absent.
CLASS_MAP = {
    1: 'Epithelial',
    2: 'Lymphocyte',
    3: 'Neutrophil',
    4: 'Macrophage'
}
# Per-class weights used by compute_pq_for_image when averaging PQ into wPQ.
CLASS_WEIGHTS = {
    "Epithelial": 1,
    "Lymphocyte": 1,
    "Neutrophil": 10,
    "Macrophage": 10
}
DATA_DIR = "train"
MASK_DIR = "mask_new"
TEST_DIR = "test_final"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Lower-case alias holding the same value as DEVICE.
device = "cuda" if torch.cuda.is_available() else "cpu"
PATCH_SIZE = 256
OVERLAP = 64
TP_MAP_CHANNELS = len(CLASS_MAP) + 1 # BG + 4 classes
# NOTE(review): this rebinds MODEL_SAVE_PATH to a different checkpoint file
# than the earlier evaluation cell ("best_model_new.pth") - confirm intended.
MODEL_SAVE_PATH = "best_model_new_3.pth"

def post_process(pred_np, pred_hv, pred_tp, np_thresh=0.5, marker_min_dist=7):
    """
    Post-processes the raw model outputs into class-specific instance masks.
    pred_np is expected to be probabilities (0.0 to 1.0)

    Args:
    - pred_np: (H, W) numpy array of nucleus probabilities.
    - pred_hv: accepted for interface compatibility; not used here.
    - pred_tp: (num_types, H, W) numpy array of per-class scores; channel 0 is
      background.
    - np_thresh: threshold used to binarize `pred_np`.
    - marker_min_dist: minimum distance between watershed seeds.

    Returns:
    - A dictionary {class_name: instance_mask} with one uint16 mask per entry
      of CLASS_MAP. In each mask 0 is background and that class's instances
      are numbered 1, 2, ... Instances whose majority type is background (or
      not in CLASS_MAP) are dropped.
    """
    binary_mask = (pred_np > np_thresh).astype(np.uint8)

    # Pixel-wise class prediction.
    type_map = np.argmax(pred_tp, axis=0).astype(np.uint8)

    # Seeds are the local maxima of the distance-to-background map.
    distance = distance_transform_edt(binary_mask)
    coords = peak_local_max(distance, min_distance=marker_min_dist, labels=binary_mask)
    markers = np.zeros_like(binary_mask, dtype=bool)
    markers[tuple(coords.T)] = True
    markers, num_features = label(markers)

    final_instance_masks = {
        name: np.zeros(pred_np.shape, dtype=np.uint16) for name in CLASS_MAP.values()
    }

    if num_features == 0:
        return final_instance_masks

    instance_map = watershed(-distance, markers, mask=binary_mask)

    next_instance_id_per_class = {name: 1 for name in CLASS_MAP.values()}

    # find_objects yields one bounding box per label (index i -> label i + 1),
    # so each nucleus is handled inside its own small crop instead of building
    # a full-image boolean mask for every instance.
    for inst_id, bbox in enumerate(find_objects(instance_map), start=1):
        if bbox is None:  # label value not present in instance_map
            continue

        inst_pixels = (instance_map[bbox] == inst_id)

        # Majority vote via bincount: independent of the SciPy version's
        # `mode` return type, and ties resolve to the smallest class index.
        inst_class_idx = int(np.bincount(type_map[bbox][inst_pixels]).argmax())

        if inst_class_idx == 0 or inst_class_idx not in CLASS_MAP:
            continue

        class_name = CLASS_MAP[inst_class_idx]
        new_inst_id = next_instance_id_per_class[class_name]
        # final_instance_masks[class_name][bbox] is a view, so this writes
        # into the full-size mask.
        final_instance_masks[class_name][bbox][inst_pixels] = new_inst_id
        next_instance_id_per_class[class_name] += 1

    return final_instance_masks

def sliding_window_inference(model, image, patch_size, overlap, device):
    """
    Performs sliding window inference on a large image.
    Returns full-sized prediction maps.

    Args:
    - model: network called on a (1, C, patch_size, patch_size) float tensor
      scaled to [0, 1]; must return a dict with 'np', 'hv' and 'tp' tensors.
    - image: (H, W, C) numpy array with values in the 0-255 range.
    - patch_size: side length of the square patches fed to the model.
    - overlap: number of pixels shared by neighbouring patches.
    - device: device the patch tensors are moved to.

    Returns:
    - (pred_np, pred_hv, pred_tp): arrays of shape (H, W), (2, H, W) and
      (num_tp_channels, H, W). Overlapping predictions are averaged. The NP
      map holds sigmoid probabilities.

    Raises:
    - ValueError: if `overlap` is not smaller than `patch_size`.
    """
    model.eval()
    H, W, C = image.shape
    stride = patch_size - overlap
    if stride <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than patch_size ({patch_size})."
        )

    # Pad bottom/right so the patch grid tiles the image exactly. Images
    # smaller than one patch are padded up to a full patch; otherwise no
    # window would fit and the maps would silently stay all-zero.
    if H < patch_size:
        pad_h = patch_size - H
    else:
        pad_h = (stride - (H - patch_size) % stride) % stride
    if W < patch_size:
        pad_w = patch_size - W
    else:
        pad_w = (stride - (W - patch_size) % stride) % stride

    padded_image = np.pad(image, ((0, pad_h), (0, pad_w), (0, 0)), mode='constant')
    H_pad, W_pad, _ = padded_image.shape

    pred_map_np = np.zeros((H_pad, W_pad), dtype=np.float32)
    pred_map_hv = np.zeros((2, H_pad, W_pad), dtype=np.float32)
    # Allocated on the first patch so the channel count follows the model
    # output instead of a notebook-level constant.
    pred_map_tp = None
    # Number of patches that covered each pixel.
    count_map = np.zeros((H_pad, W_pad), dtype=np.float32)

    with torch.no_grad():
        for y in range(0, H_pad - patch_size + 1, stride):
            for x in range(0, W_pad - patch_size + 1, stride):
                patch = padded_image[y:y+patch_size, x:x+patch_size]
                patch_tensor = torch.from_numpy(patch.transpose(2, 0, 1)).float() / 255.0
                patch_tensor = patch_tensor.unsqueeze(0).to(device)
                preds = model(patch_tensor)

                # NP logits become probabilities before averaging.
                pred_np_patch = torch.sigmoid(preds['np']).squeeze().cpu().numpy()
                pred_hv_patch = preds['hv'].squeeze().cpu().numpy()
                pred_tp_patch = preds['tp'].squeeze().cpu().numpy()

                if pred_map_tp is None:
                    pred_map_tp = np.zeros(
                        (pred_tp_patch.shape[0], H_pad, W_pad), dtype=np.float32
                    )

                pred_map_np[y:y+patch_size, x:x+patch_size] += pred_np_patch
                pred_map_hv[:, y:y+patch_size, x:x+patch_size] += pred_hv_patch
                pred_map_tp[:, y:y+patch_size, x:x+patch_size] += pred_tp_patch
                count_map[y:y+patch_size, x:x+patch_size] += 1.0

    count_map[count_map == 0] = 1 # Avoid division by zero
    final_pred_np = pred_map_np / count_map
    final_pred_hv = pred_map_hv / count_map
    final_pred_tp = pred_map_tp / count_map

    # Crop the padding away again.
    final_pred_np = final_pred_np[0:H, 0:W]
    final_pred_hv = final_pred_hv[:, 0:H, 0:W]
    final_pred_tp = final_pred_tp[:, 0:H, 0:W]

    return final_pred_np, final_pred_hv, final_pred_tp

def get_pq(pred_mask, gt_mask, iou_thresh=0.5):
    """Calculates Panoptic Quality for a single class.

    Args:
    - pred_mask: integer array; 0 is background, other values are instance ids.
    - gt_mask: integer array of the same shape, labelled the same way.
    - iou_thresh: minimum IoU for an assigned (prediction, ground truth) pair
      to count as a true positive.

    Returns:
    - (pq, iou_sum, tp, fp, fn). When both masks have no instances the result
      is (1.0, 0, 0, 0, 0); when exactly one is empty, PQ is 0.0 and fp/fn
      hold the instance counts of the prediction / ground truth.

    Raises:
    - ValueError: if the two masks have different shapes.
    """
    pred_mask = np.asarray(pred_mask)
    gt_mask = np.asarray(gt_mask)
    if pred_mask.shape != gt_mask.shape:
        raise ValueError(
            f"pred_mask shape {pred_mask.shape} does not match gt_mask shape {gt_mask.shape}."
        )

    # Drop label 0 by value; slicing [1:] would discard a real instance
    # whenever a mask contains no background pixel.
    pred_vals, pred_inv = np.unique(pred_mask, return_inverse=True)
    gt_vals, gt_inv = np.unique(gt_mask, return_inverse=True)
    pred_keep = pred_vals != 0
    gt_keep = gt_vals != 0
    num_pred = int(pred_keep.sum())
    num_gt = int(gt_keep.sum())

    if num_pred == 0 and num_gt == 0:
        return 1.0, 0, 0, 0, 0 # Both empty, perfect score
    elif num_pred == 0 or num_gt == 0:
        return 0.0, 0, 0, num_pred, num_gt # One empty

    # Joint histogram of (prediction label, ground-truth label) pairs: every
    # pairwise intersection comes from one pass over the pixels.
    joint = np.bincount(
        pred_inv.ravel() * len(gt_vals) + gt_inv.ravel(),
        minlength=len(pred_vals) * len(gt_vals),
    ).reshape(len(pred_vals), len(gt_vals))
    pred_area = joint.sum(axis=1)
    gt_area = joint.sum(axis=0)

    intersection = joint[pred_keep][:, gt_keep]
    union = pred_area[pred_keep][:, None] + gt_area[gt_keep][None, :] - intersection
    # Every kept label owns at least one pixel, so union is never zero.
    iou_matrix = (intersection / union).astype(np.float32)

    # One-to-one assignment that maximises the total IoU.
    row_ind, col_ind = linear_sum_assignment(-iou_matrix)
    assigned_iou = iou_matrix[row_ind, col_ind]
    matched = assigned_iou >= iou_thresh

    tp = int(matched.sum())
    iou_sum = float(assigned_iou[matched].sum(dtype=np.float64))

    # Each assignment uses a distinct prediction and a distinct ground truth.
    fp = num_pred - tp
    fn = num_gt - tp

    pq = iou_sum / (tp + 0.5 * fp + 0.5 * fn + 1e-8)

    return pq, iou_sum, tp, fp, fn

def compute_pq_for_image(pred_masks_by_class, gt_masks_by_class):
    """Weighted Panoptic Quality (wPQ) of one image over every class in CLASS_WEIGHTS.

    Args:
        pred_masks_by_class: Mapping of class name -> predicted instance mask.
        gt_masks_by_class: Mapping of class name -> ground-truth instance mask.

    Returns:
        Tuple ``(wpq, pq_scores)``: the weight-normalised mean of the per-class
        PQ values, and a dict with the PQ of each class.

    A class missing from either mapping is scored against an all-zero mask.
    """
    # Shape for the all-zero stand-in mask: ground truth first, then
    # predictions, then a 1x1 placeholder if both mappings are empty.
    if gt_masks_by_class:
        ref_shape = next(iter(gt_masks_by_class.values())).shape
    elif pred_masks_by_class:
        ref_shape = next(iter(pred_masks_by_class.values())).shape
    else:
        ref_shape = (1, 1)

    per_class_pq = {}
    weighted_pq_sum = 0.0
    weight_sum = 0.0

    for name, weight in CLASS_WEIGHTS.items():
        if name in pred_masks_by_class:
            predicted = pred_masks_by_class[name]
        else:
            predicted = np.zeros(ref_shape, dtype=np.uint16)

        if name in gt_masks_by_class:
            truth = gt_masks_by_class[name]
        else:
            truth = np.zeros(ref_shape, dtype=np.uint16)

        # Only the PQ value is needed here; the match statistics are ignored.
        class_pq = get_pq(predicted, truth)[0]

        per_class_pq[name] = class_pq
        weighted_pq_sum += class_pq * weight
        weight_sum += weight

    # The epsilon keeps the division defined if the weights sum to zero.
    return weighted_pq_sum / (weight_sum + 1e-8), per_class_pq

import copy  # stdlib; used to hand post_process private copies of cached predictions


def _load_rgb_uint8(img_path):
    """Read a TIFF from ``img_path`` and coerce it to a 3-channel uint8 array."""
    image = tifffile.imread(img_path)
    if image.ndim == 3 and image.shape[-1] == 4:
        image = image[:, :, :3]  # keep the first three channels only
    if image.ndim == 2:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if image.dtype != np.uint8:
        if image.max() > 255:
            # Rescale wide-range data into 0..255 before narrowing the dtype.
            image = (image / image.max() * 255).astype(np.uint8)
        else:
            image = image.astype(np.uint8)
    return image


def _load_gt_masks_by_class(mask_path, shape, class_map):
    """Build one relabelled (1..N) uint16 instance mask per class from an .npz file.

    ``class_map`` maps the class ids stored in ``type_map`` to class names;
    ``shape`` is the (H, W) of the masks to create.
    """
    with np.load(mask_path) as data:
        gt_instance_map = data['instance_map']
        gt_type_map = data['type_map']

    gt_masks = {}
    for class_id, class_name in class_map.items():
        class_instances = np.unique(gt_instance_map[gt_type_map == class_id])
        class_instances = class_instances[class_instances != 0]
        gt_mask = np.zeros(shape, dtype=np.uint16)
        for new_id, inst_id in enumerate(class_instances, 1):
            gt_mask[gt_instance_map == inst_id] = new_id
        gt_masks[class_name] = gt_mask
    return gt_masks


# 2D grid search over the post-processing parameters (np_thresh, marker_min_dist),
# scored by the mean wPQ on the small ("fast") validation images.
print(f"Loading best model from {MODEL_SAVE_PATH}...")
if not os.path.exists(MODEL_SAVE_PATH):
    print(f"Error: Model file not found at {MODEL_SAVE_PATH}. Please run training first.")
elif 'val_ids' not in locals() or not val_ids:
    print("Error: 'val_ids' not found. Please run the stratified split cell first.")
elif 'HoverNet' not in locals():
    print("Error: 'HoverNet' class is not defined. Please run the model definition cell first.")
else:
    print("Finding 'fast' validation images...")
    PIXEL_THRESHOLD = 500000
    val_ids_fast = []
    for image_id in val_ids:
        img_path = os.path.join(DATA_DIR, f"{image_id}.tif")
        try:
            # Only the page metadata is read here, not the pixel data.
            with tifffile.TiffFile(img_path) as tif:
                shape = tif.pages[0].shape
            num_pixels = shape[0] * shape[1]
            if num_pixels < PIXEL_THRESHOLD:
                val_ids_fast.append(image_id)
        except Exception:
            # Best effort: if the size cannot be determined, keep the image.
            val_ids_fast.append(image_id) # Add it just in case

    print(f"Running 2D Grid Search on {len(val_ids_fast)} fast images...")

    model_for_eval = HoverNet(num_classes=TP_MAP_CHANNELS).to(DEVICE)
    model_for_eval.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=DEVICE))
    model_for_eval.eval()
    print("Model loaded successfully.")

    thresholds_to_test = [0.5, 0.6, 0.7, 0.8]
    dists_to_test = [5, 7, 9]

    # The network output and the ground-truth masks do not depend on the
    # post-processing parameters, so compute them once per image instead of
    # once per (threshold, distance) combination.
    cached_eval_inputs = {}
    for image_id in tqdm(val_ids_fast, desc="Caching predictions"):
        try:
            img_path = os.path.join(DATA_DIR, f"{image_id}.tif")
            image = _load_rgb_uint8(img_path)
            H, W, _ = image.shape

            # Use the same DEVICE constant the model was moved to above.
            pred_np, pred_hv, pred_tp = sliding_window_inference(
                model_for_eval, image, PATCH_SIZE, OVERLAP, DEVICE
            )

            mask_path = os.path.join(MASK_DIR, f"{image_id}.npz")
            gt_masks_by_class = _load_gt_masks_by_class(mask_path, (H, W), CLASS_MAP)

            cached_eval_inputs[image_id] = (pred_np, pred_hv, pred_tp, gt_masks_by_class)
        except Exception as e:
            # A failing image is reported once and left out of every combination.
            print(f"ERROR on {image_id}: {e}")
            continue

    best_wpq = -1.0
    best_thresh = -1.0
    best_dist = -1.0

    for dist in dists_to_test:
        for thresh in thresholds_to_test:
            print(f"\n--- Testing combo: (np_thresh = {thresh}, min_dist = {dist}) ---")

            all_wpq_scores = []
            pbar = tqdm(cached_eval_inputs.items(), desc=f"Testing (t={thresh}, d={dist})")

            for image_id, (pred_np, pred_hv, pred_tp, gt_masks_by_class) in pbar:
                try:
                    # post_process is defined elsewhere; pass copies so the cached
                    # maps stay intact even if it modifies its inputs in place.
                    pred_masks_by_class = post_process(
                        copy.deepcopy(pred_np),
                        copy.deepcopy(pred_hv),
                        copy.deepcopy(pred_tp),
                        np_thresh=thresh,
                        marker_min_dist=dist
                    )

                    wpq, _ = compute_pq_for_image(pred_masks_by_class, gt_masks_by_class)
                    all_wpq_scores.append(wpq)
                    pbar.set_postfix(avg_wPQ=f"{np.mean(all_wpq_scores):.4f}")

                except Exception as e:
                    print(f"ERROR on {image_id}: {e}")
                    continue

            if all_wpq_scores:
                avg_wpq = np.mean(all_wpq_scores)
            else:
                avg_wpq = 0.0 # Assign 0 if all images failed

            print(f"  Result for (t={thresh}, d={dist}): avg_wPQ = {avg_wpq:.4f}")

            if avg_wpq > best_wpq:
                best_wpq = avg_wpq
                best_thresh = thresh
                best_dist = dist

    print("\n--- 2D Grid Search Complete ---")
    print(f"Best wPQ Score: {best_wpq:.4f}")
    print(f"Found at: np_thresh = {best_thresh}, marker_min_dist = {best_dist}")



Loading best model from best_model_new_3.pth...
Finding 'fast' validation images...
Running 2D Grid Search on 31 fast images...
Loading ImageNet-pre-trained ConvNeXt-Tiny backbone from timm...
ConvNeXt-Tiny backbone loaded successfully.
Model loaded successfully.

--- Testing combo: (np_thresh = 0.5, min_dist = 5) ---


  Result for (t=0.5, d=5): avg_wPQ = 0.6566

--- Testing combo: (np_thresh = 0.6, min_dist = 5) ---


  Result for (t=0.6, d=5): avg_wPQ = 0.6314

--- Testing combo: (np_thresh = 0.7, min_dist = 5) ---


  Result for (t=0.7, d=5): avg_wPQ = 0.6282

--- Testing combo: (np_thresh = 0.8, min_dist = 5) ---


  Result for (t=0.8, d=5): avg_wPQ = 0.6023

--- Testing combo: (np_thresh = 0.5, min_dist = 7) ---


  Result for (t=0.5, d=7): avg_wPQ = 0.6714

--- Testing combo: (np_thresh = 0.6, min_dist = 7) ---


  Result for (t=0.6, d=7): avg_wPQ = 0.6420

--- Testing combo: (np_thresh = 0.7, min_dist = 7) ---


  Result for (t=0.7, d=7): avg_wPQ = 0.6325

--- Testing combo: (np_thresh = 0.8, min_dist = 7) ---


  Result for (t=0.8, d=7): avg_wPQ = 0.6083

--- Testing combo: (np_thresh = 0.5, min_dist = 9) ---


  Result for (t=0.5, d=9): avg_wPQ = 0.6763

--- Testing combo: (np_thresh = 0.6, min_dist = 9) ---


  Result for (t=0.6, d=9): avg_wPQ = 0.6456

--- Testing combo: (np_thresh = 0.7, min_dist = 9) ---


  Result for (t=0.7, d=9): avg_wPQ = 0.6381

--- Testing combo: (np_thresh = 0.8, min_dist = 9) ---


KeyboardInterrupt: 