In [3]:
import os
from typing import Union, Tuple, Literal

import torch

import micro_sam.training as sam_training
from micro_sam.training.util import ConvertToSemanticSamInputs

from medico_sam.util import LinearWarmUpScheduler

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image


In [2]:
class GrayscaleSegmentationDataset(Dataset):
    """Paired grayscale-image / label-mask dataset read from two directories.

    Images and masks are paired purely by their position in the sorted directory
    listings, so both directories must hold corresponding files whose names sort
    in the same order.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the label masks.
        patch_shape: Target (height, width); every image and mask is resized to it.

    Raises:
        ValueError: If the two directories do not contain the same number of files
            (positional pairing would otherwise be misaligned or fail later).
    """

    def __init__(self, image_dir, mask_dir, patch_shape):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_files = sorted(os.listdir(image_dir))
        self.mask_files = sorted(os.listdir(mask_dir))
        self.patch_shape = patch_shape

        # Pairing is by index, so a count mismatch means at least one sample has
        # no partner; fail early instead of training on misaligned pairs.
        if len(self.image_files) != len(self.mask_files):
            raise ValueError(
                f"Found {len(self.image_files)} files in '{image_dir}' but "
                f"{len(self.mask_files)} files in '{mask_dir}'; "
                "every image needs exactly one mask."
            )

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, resize and tensorize the pair at position ``idx``.

        Returns:
            A tuple ``(image_tensor, mask_tensor)``: the image as produced by
            ``transforms.ToTensor`` and the mask as a ``long`` tensor produced by
            ``transforms.PILToTensor``.
        """
        image_path = os.path.join(self.image_dir, self.image_files[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_files[idx])

        # Load both as single-channel ("L") images; the context managers close the
        # underlying files as soon as the converted copies exist.
        # NOTE(review): convert("L") on a palette ("P") mask yields luminance, not
        # palette indices — confirm the masks are stored as plain grayscale labels.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("L")
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert("L")

        # PIL expects (W, H), patch_shape is (H, W).
        target_size = self.patch_shape[::-1]
        image = image.resize(target_size)
        # Masks hold class ids: nearest-neighbour keeps them intact, whereas an
        # interpolating filter would blend neighbouring labels into new values.
        mask = mask.resize(target_size, resample=Image.NEAREST)

        # Convert to tensors
        image_tensor = transforms.ToTensor()(image)           # (1, H, W), float in [0, 1]
        mask_tensor = transforms.PILToTensor()(mask).long()   # (1, H, W), integer labels

        return image_tensor, mask_tensor

In [None]:
# Root folder of the dataset. get_data_loaders reads "<root>/<split>/images" and
# "<root>/<split>/masks" below it; the path is relative to the working directory.
DATA_ROOT = "data"


def get_data_loaders(
    data_path: Union[os.PathLike, str],
    split: Literal["train", "val"],
    patch_shape: Tuple[int, int],
    batch_size: int = 1,
    shuffle: bool = True,
) -> DataLoader:
    """Build the data loader for one split of the dataset.

    Despite the plural name, a single ``DataLoader`` is returned per call.

    Args:
        data_path: Dataset root; images are read from ``<data_path>/<split>/images``
            and masks from ``<data_path>/<split>/masks``.
        split: Name of the split sub-folder ("train" or "val").
        patch_shape: (height, width) every sample is resized to.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader draws samples in random order.

    Returns:
        The ``DataLoader`` for the requested split, carrying an extra ``shuffle``
        attribute that mirrors the ``shuffle`` argument.
    """
    image_dir = os.path.join(data_path, split, "images")
    mask_dir = os.path.join(data_path, split, "masks")

    dataset = GrayscaleSegmentationDataset(image_dir, mask_dir, patch_shape)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        # Pinned host memory only helps host-to-GPU copies; on CPU-only machines it
        # is useless and makes PyTorch emit a warning.
        pin_memory=torch.cuda.is_available(),
    )

    # DataLoader does not expose its shuffle setting, so attach it manually.
    # NOTE(review): the original note says this is needed "to make it compatible" —
    # presumably with the trainer consuming the loader; confirm.
    dataloader.shuffle = shuffle

    return dataloader

def finetune_semantic_sam_2d():
    """Fine-tune a Segment Anything model for 2d semantic segmentation.

    All settings are hard-coded below: the data is read from DATA_ROOT via
    get_data_loaders, the model comes from micro_sam and training is driven by
    micro_sam's SemanticSamTrainer. Edit the values in the "settings" section to
    adapt the run.
    """
    # --- device -----------------------------------------------------------------
    # Falls back to the CPU when no CUDA device is visible; replace this with an
    # explicit device if a more complex set-up requires a specific GPU.
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    print("Using device:", device)

    # --- settings ---------------------------------------------------------------
    model_type = "vit_b_lm"                  # which Segment Anything model to start from
    checkpoint_path = None                   # set to resume from a custom checkpoint
    num_classes = 3                          # 1 background class and 'n' semantic foreground classes
    checkpoint_name = "oimhs_semantic_sam"   # name under which checkpoints are stored
    patch_shape = (662, 662)                 # (height, width) of the 2d training patches
    n_epochs = 100
    learning_rate = 1e-4

    # --- model ------------------------------------------------------------------
    sam_model = sam_training.get_trainable_sam_model(
        model_type=model_type,
        checkpoint_path=checkpoint_path,
        device=device,
        num_multimask_outputs=num_classes,
        flexible_load_checkpoint=True,
    )
    sam_model.to(device)

    # --- optimisation -----------------------------------------------------------
    adamw = torch.optim.AdamW(sam_model.parameters(), lr=learning_rate, weight_decay=0.1)
    plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        adamw, mode="min", factor=0.9, patience=5,
    )
    # The plateau scheduler is wrapped by a linear warm-up configured with 4 warm-up epochs.
    warmup_scheduler = LinearWarmUpScheduler(
        adamw, warmup_epochs=4, main_scheduler=plateau_scheduler,
    )

    # --- data -------------------------------------------------------------------
    loaders = {
        split: get_data_loaders(DATA_ROOT, split, patch_shape)
        for split in ("train", "val")
    }

    # Creates all the training data for a batch (inputs and labels).
    batch_converter = ConvertToSemanticSamInputs()

    # --- training ---------------------------------------------------------------
    # The trainer performs the semantic segmentation training and validation
    # (implemented using "torch_em").
    trainer_settings = dict(
        name=checkpoint_name,
        model=sam_model,
        device=device,
        train_loader=loaders["train"],
        val_loader=loaders["val"],
        optimizer=adamw,
        lr_scheduler=warmup_scheduler,
        convert_inputs=batch_converter,
        num_classes=num_classes,
        dice_weight=0.5,
        mixed_precision=True,
        compile_model=False,
        log_image_interval=100,
    )
    semantic_trainer = sam_training.SemanticSamTrainer(**trainer_settings)
    semantic_trainer.fit(epochs=n_epochs)


def main():
    """Entry point: run the 2d semantic SAM fine-tuning with its built-in settings."""
    finetune_semantic_sam_2d()


# Inside a notebook kernel ``__name__`` is "__main__" as well, so executing this
# cell starts the training run immediately.
if __name__ == "__main__":
    main()

Using device: cpu


For this object, empty constructor arguments will be used.
The trainer can probably not be correctly deserialized via 'DefaultTrainer.from_checkpoint'.
For this object, empty constructor arguments will be used.
The trainer can probably not be correctly deserialized via 'DefaultTrainer.from_checkpoint'.
For this object, empty constructor arguments will be used.
The trainer can probably not be correctly deserialized via 'DefaultTrainer.from_checkpoint'.


Start fitting for 17900 iterations /  100 epochs
with 179 iterations per epoch
Training with mixed precision


