### verify that GPU is available

In [3]:
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("GPU not available, using CPU.")

Using GPU: NVIDIA A100-SXM4-40GB


### upload my files from the histo-segmentation folder to iDrive

In [1]:
from google.colab import files
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# install all other necessary dependencies
!pip install torch torchvision monai kornia matplotlib scikit-image

Collecting monai
  Downloading monai-1.5.0-py3-none-any.whl.metadata (13 kB)
Collecting kornia
  Downloading kornia-0.8.1-py2.py3-none-any.whl.metadata (17 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_

### download the dataset

In [None]:
# download and extract the NuInsSeg Dataset to the drive
%%bash

# Set up directory in Google Drive
DRIVE_DIR="/content/drive/MyDrive/histo-segmentation/NuInsSeg"
ZIP_FILE="NuInsSeg.zip"
ZIP_PATH="$DRIVE_DIR/$ZIP_FILE"
DEST_DIR="$DRIVE_DIR"

mkdir -p "$DEST_DIR"

echo "Initiating NuInsSeg dataset download from Zenodo..."

# Download the dataset if not already present
if [ ! -f "$ZIP_PATH" ]; then
  echo "Downloading NuInsSeg.zip..."
  wget -O "$ZIP_PATH" "https://zenodo.org/record/10518968/files/NuInsSeg.zip?download=1"
else
  echo "Zip already exists. Skipping download."
fi

# Extract the dataset if not already extracted
if [ -d "$DEST_DIR/human bladder" ] || [ -d "$DEST_DIR/mouse brain" ]; then
  echo "Dataset already extracted. Skipping."
else
  echo "Extracting NuInsSeg.zip..."
  unzip -o -q "$ZIP_PATH" -d "$DEST_DIR"
fi

echo "NuInsSeg dataset is ready at $DEST_DIR"

Initiating NuInsSeg dataset download from Zenodo...
Zip already exists. Skipping download.
Dataset already extracted. Skipping.
NuInsSeg dataset is ready at /content/drive/MyDrive/histo-segmentation/NuInsSeg


In [None]:
# download and extract the NuInsSeg Dataset to the drive
%%bash
cd /content/drive/MyDrive/histo-segmentation
git clone https://github.com/yformer/EfficientSAM.git
# Set up directory in Google Drive
# DRIVE_DIR="/content/drive/MyDrive/histo-segmentation"
# ZIP_FILE="NuInsSeg.zip"
# ZIP_PATH="$DRIVE_DIR/$ZIP_FILE"
# DEST_DIR="$DRIVE_DIR"

Cloning into 'EfficientSAM'...
Updating files:  92% (36/39)Updating files:  94% (37/39)Updating files:  97% (38/39)Updating files: 100% (39/39)Updating files: 100% (39/39), done.


In [13]:
import sys
sys.modules.pop('utils.visualize', None)
sys.modules.pop('utils.util', None)
sys.modules.pop('utils.loss_functions', None)
sys.modules.pop('dataset', None)

sys.modules.pop('dataset', None)
sys.modules.pop('model', None)
sys.modules.pop('train', None)
sys.modules.pop('EfficientSAM', None)
sys.modules.pop('EfficientSAM.efficient_sam', None)


In [15]:
import sys
sys.path.append('/content/drive/MyDrive/medical_image_computing/histo-segmentation')

from utils.loss_functions import DiceLoss, clDiceLoss, FocalLoss, MCELoss
from utils.visualize import visualize_predictions, display_visualizations_inline, save_visualizations_to_drive
from utils.metrics import plot_metrics

from prepare_data import NuInsSegDatasetV2
from lora_sam import AdaptiveLoRA_EfficientSAM
from EfficientSAM.efficient_sam.efficient_sam import build_efficient_sam
from train import train_one_epoch, evaluate_model, get_loss

import torch
import os
import sys
import json
import random
from datetime import datetime
from torch.utils.data import Subset
from tqdm import tqdm
from torchvision.utils import save_image
import os

import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, SubsetRandomSampler
from sklearn.model_selection import KFold

import wandb
from tqdm import tqdm
from skimage.measure import label
from monai.metrics import DiceMetric, MeanIoU, PanopticQualityMetric
import matplotlib.pyplot as plt

class Config:
    def __init__(self):
        self.seed = 42
        self.root_dir = "/content/drive/MyDrive/medical_image_computing/histo-segmentation/NuInsSeg"
        self.num_epochs = 12
        self.batch_size = 4
        self.num_folds = 5
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.patience = 10
        self.early_stopping_min_delta = 0.0005
        self.apply_augmentation = False
        self.visualize_results = True
        self.num_vis_samples = 3
        self.rank = 8
        self.alpha = 32
        self.learning_rate = 1e-4
        self.max_scheduler_iter = 10
        self.min_learning_rate = 1e-6
        self.lambda_focal = 1.0
        self.lambda_dice = 1.0
        self.lambda_boundary = 1.0
        self.lambda_contrastive = 1.0
        self.loss_alpha = 0.75
        self.loss_gamma = 2.0
        self.loss_type = 'dice'
        self.lora_dropout = 0.05
        self.use_gradient_checkpointing = False

class EarlyStopping:
    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, val_loss, model, path):
        if val_loss < self.best_loss - self.min_delta:
            tqdm.write(f"Validation loss decreased ({self.best_loss:.4f} --> {val_loss:.4f}). Saving model...")
            self.best_loss = val_loss
            self.counter = 0
            model.save_lora_state(path)
            return True
        else:
            self.counter += 1
            tqdm.write(f"EarlyStopping counter: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
                tqdm.write("Early stopping triggered")
            return False

def print_model_parameter_stats(model, name="Model"):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{name} parameter count: {total:,} total | {trainable:,} trainable ({trainable / total:.2%})")


def build_sam():
    return build_efficient_sam(
        encoder_patch_embed_dim=192, encoder_num_heads=3,
        checkpoint="/content/drive/MyDrive/medical_image_computing/histo-segmentation/EfficientSAM/weights/efficient_sam_vitt.pt"
    ).eval()

def main():
    cfg = Config()
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)

    dataset = NuInsSegDatasetV2(
        root_dir=cfg.root_dir,
        apply_augmentation=cfg.apply_augmentation,
        stain_normalize=True,
        use_albumentations=False,
        subset_fraction=1,
        subset_seed=cfg.seed
    )

    n = len(dataset)
    kf = KFold(n_splits=cfg.num_folds, shuffle=True, random_state=cfg.seed)
    results = {'folds': {}, 'avg_train_losses': [], 'avg_val_losses': []}

    for fold, (tr, val) in enumerate(kf.split(dataset)):
        tqdm.write(f"==== Fold {fold+1}/{cfg.num_folds} ====")
        results['folds'][str(fold)] = {'train_losses': [], 'val_losses': []}

        train_loader = DataLoader(dataset, batch_size=cfg.batch_size, sampler=SubsetRandomSampler(tr))
        val_loader = DataLoader(dataset, batch_size=cfg.batch_size, sampler=SubsetRandomSampler(val))

        sam = build_sam().to(cfg.device)
        for p in sam.parameters():
            p.requires_grad = False

        model = AdaptiveLoRA_EfficientSAM(
            config=cfg,
            sam_model=sam,
            rank=cfg.rank,
            alpha=cfg.alpha,
            dropout=cfg.lora_dropout,
            use_checkpoint=cfg.use_gradient_checkpointing
        ).to(cfg.device)

        print_model_parameter_stats(model, name="EfficientSAM + LoRA")
        opt = Adam(model.get_lora_params(), lr=cfg.learning_rate)
        sched = CosineAnnealingLR(opt, T_max=cfg.max_scheduler_iter, eta_min=cfg.min_learning_rate)
        crit = get_loss(cfg)

        stopper = EarlyStopping(patience=cfg.patience, min_delta=cfg.early_stopping_min_delta)
        fold_ckpt = os.path.join(cfg.root_dir, "checkpoints", f"fold_{fold}_best.pth")
        os.makedirs(os.path.dirname(fold_ckpt), exist_ok=True)

        best_visuals = None

        for ep in range(cfg.num_epochs):
            tqdm.write(f"===========Epoch {ep}===========")
            train_loss = train_one_epoch(model, train_loader, crit, opt, sched, cfg.device, cfg)
            val_loss, dice, iou, pq = evaluate_model(model, val_loader, crit, cfg.device, cfg)

            tqdm.write(f"  Validation Loss: {val_loss:.4f}, Dice: {dice:.4f}, Jaccard: {iou:.4f}, PQ: {pq:.4f}")
            results['folds'][str(fold)]['train_losses'].append(train_loss)
            results['folds'][str(fold)]['val_losses'].append(val_loss)

            tqdm.write(f"Fold {fold+1}, Epoch {ep+1}: train {train_loss:.4f}, val {val_loss:.4f}, dice {dice:.4f}, iou {iou:.4f}, pq {pq:.4f}")

            saved = stopper(val_loss, model, fold_ckpt)
            if saved and cfg.visualize_results:
                best_visuals = visualize_predictions(
                    model, dataset, cfg.device,
                    output_dir=None, num_samples=cfg.num_vis_samples
                )
                tqdm.write("New best model saved. Visuals updated.")

            if stopper.early_stop:
                break

        if cfg.visualize_results and best_visuals:
            display_visualizations_inline(best_visuals, fold)

            pics = []
            from torchvision.transforms import ToPILImage
            tp = ToPILImage()
            for o in best_visuals:
                pics.append(wandb.Image(tp(o['image'].cpu()), masks={
                    "pred": {"mask_data": o['prediction'].cpu(), "class_labels": {1: "Prediction"}},
                    "gt": {"mask_data": o['ground_truth'].cpu(), "class_labels": {1: "Ground Truth"}}
                }))

    results['avg_train_losses'] = np.mean([results['folds'][f]['train_losses'] for f in results['folds']], axis=0).tolist()
    results['avg_val_losses'] = np.mean([results['folds'][f]['val_losses'] for f in results['folds']], axis=0).tolist()

    plot_metrics(results, os.path.join(cfg.root_dir, "metrics"), show_inline=True)
    tqdm.write("Training complete.")

if __name__ == "__main__":
    main()

Output hidden; open in https://colab.research.google.com to view.