In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the local `utils` package) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import os
from utils.common import (
    preprocess_mask2former,
    m2f_extract_pred_maps_and_masks,
    mask2former_auto_image_processor,
    MASK_IGNORE_VALUE,
)
from utils.dataset_utils import (
    get_cadisv2_dataset,
    get_cataract1k_dataset,
    ZEISS_CATEGORIES,
)
from transformers import (
    Mask2FormerForUniversalSegmentation,
    SwinModel,
    SwinConfig,
    Mask2FormerConfig,
)
from torch.utils.data import DataLoader
import evaluate
import torch.optim as optim
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import numpy as np



In [26]:
NUM_CLASSES = len(ZEISS_CATEGORIES) - 3 + 1  # Remove class incremental, add background
SWIN_BACKBONE = "microsoft/swin-large-patch4-window12-384"
SWIN_OUT_FEATURES = ("stage1", "stage2", "stage3", "stage4")

# Download the pretrained Swin model and its configuration
swin_model = SwinModel.from_pretrained(
    SWIN_BACKBONE, out_features=list(SWIN_OUT_FEATURES)
)
swin_config = SwinConfig.from_pretrained(
    SWIN_BACKBONE, out_features=list(SWIN_OUT_FEATURES)
)

# Create Mask2Former configuration based on Swin's configuration
mask2former_config = Mask2FormerConfig(
    backbone_config=swin_config, num_labels=NUM_CLASSES, ignore_value=MASK_IGNORE_VALUE
)

# Create the (randomly initialised) Mask2Former model with this configuration
model = Mask2FormerForUniversalSegmentation(mask2former_config)

# Reuse the pretrained backbone parameters.
# Tensors are matched by *name* (and shape) rather than by position, so a
# difference in parameter ordering or count between the two modules cannot
# silently shift or truncate the transfer.
pretrained_params = dict(swin_model.named_parameters())
encoder_param_names = set()

with torch.no_grad():  # in-place copies into leaf parameters must not be tracked
    for name, m2f_param in model.model.pixel_level_module.encoder.named_parameters():
        encoder_param_names.add(name)
        swin_param = pretrained_params.get(name)

        if swin_param is not None and swin_param.shape == m2f_param.shape:
            m2f_param.copy_(swin_param)
            continue

        print(f"Not Matched: {name} has no compatible pretrained counterpart")

# Report pretrained tensors that the Mask2Former encoder does not use
for name in pretrained_params:
    if name not in encoder_param_names:
        print(f"Not Matched: pretrained {name} is unused")



config.json:   0%|          | 0.00/71.8k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/791M [00:00<?, ?B/s]

In [4]:
# Domain incremental datasets
# Domain incremental datasets
train_cadis_ds, val_cadis_ds, test_cadis_ds = get_cadisv2_dataset(
    "../data/CaDISv2", domain_incremental=True
)
train_cataract_ds, val_cataract_ds, test_cataract_ds = get_cataract1k_dataset(
    "../data/cataract-1k", domain_incremental=True
)

# First case: CADIS + CATARACT at the same time
merged_train = torch.utils.data.ConcatDataset([train_cataract_ds, train_cadis_ds])
merged_val = torch.utils.data.ConcatDataset([val_cataract_ds, val_cadis_ds])
merged_test = torch.utils.data.ConcatDataset([test_cataract_ds, test_cadis_ds])

# Define dataloader params
N_WORKERS = 1
BATCH_SIZE = 2
SHUFFLE = True  # applied to the training loader only
DROP_LAST = True  # applied to the training loader only
# Pinned host memory only helps host-to-GPU copies; skip it on CPU-only machines
PIN_MEMORY = torch.cuda.is_available()


def _make_loader(dataset, *, shuffle, drop_last):
    """Build a DataLoader over ``dataset`` with the shared notebook settings."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=N_WORKERS,
        drop_last=drop_last,
        pin_memory=PIN_MEMORY,
        collate_fn=preprocess_mask2former,
    )


# Define dataloaders.
# Training shuffles and drops the incomplete last batch. Validation and test
# must see every sample exactly once in a fixed order, otherwise the reported
# loss / mIoU silently ignore part of the split.
train_merged_loader = _make_loader(merged_train, shuffle=SHUFFLE, drop_last=DROP_LAST)
val_merged_loader = _make_loader(merged_val, shuffle=False, drop_last=False)
test_merged_loader = _make_loader(merged_test, shuffle=False, drop_last=False)

In [5]:
# Check if CUDA is available, otherwise use CPU
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

Using device: cpu


In [6]:
# Training
NUM_EPOCHS = 50
LEARNING_RATE = 1e-4
LR_MULTIPLIER = 0.1
BACKBONE_LR = LEARNING_RATE * LR_MULTIPLIER  # the pretrained backbone trains with a smaller LR
# NOTE(review): 0.5 is far larger than commonly used AdamW weight decay values — confirm it is intended.
WEIGHT_DECAY = 0.5
# dice = Dice(average='micro')

# lambda_CE=5.0
# lambda_dice=5.0
metric = evaluate.load("mean_iou")

ENCODER_PREFIX = "model.pixel_level_module.encoder"
DECODER_PREFIX = "model.pixel_level_module.decoder"
TRANSFORMER_PREFIX = "model.transformer_module"

# Split the parameters into optimizer groups in a single pass.
# Every parameter must land in exactly one group: anything whose name does not
# match a prefix (e.g. a head registered outside `model.`, such as the
# top-level `class_predictor` of Mask2FormerForUniversalSegmentation) goes to
# `other_params`; otherwise it would never be handed to the optimizer and
# would stay at its random initialisation.
encoder_params, decoder_params, transformer_params, other_params = [], [], [], []
for name, param in model.named_parameters():
    if name.startswith(ENCODER_PREFIX):
        encoder_params.append(param)
    elif name.startswith(DECODER_PREFIX):
        decoder_params.append(param)
    elif name.startswith(TRANSFORMER_PREFIX):
        transformer_params.append(param)
    else:
        other_params.append(param)

param_groups = [
    {"params": encoder_params, "lr": BACKBONE_LR},
    {"params": decoder_params},
    {"params": transformer_params},
]
if other_params:
    param_groups.append({"params": other_params})

# Groups without an explicit "lr" use the default LEARNING_RATE
optimizer = optim.AdamW(
    param_groups,
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

# Polynomial decay over NUM_EPOCHS steps; it only takes effect if
# `scheduler.step()` is called once per epoch by the training loop.
scheduler = optim.lr_scheduler.PolynomialLR(
    optimizer, total_iters=NUM_EPOCHS, power=0.9
)

In [None]:
# Tensorboard setup
out_dir="outputs/"
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
if not os.path.exists(out_dir+"runs"):
    os.makedirs(out_dir+"runs")
%load_ext tensorboard
%tensorboard --logdir outputs/runs

In [9]:
# Tensorboard logging
# Tensorboard logging
writer = SummaryWriter(log_dir=out_dir + "runs")

# Model checkpointing
model_name = "m2f_swin_backbone"
model_dir = out_dir + "models/"
best_model_dir = model_dir + f"{model_name}/best_models/"
final_model_dir = model_dir + f"{model_name}/final_model/"

# Create each directory if it is missing. Each existence check targets the
# directory that is about to be created, so the best/final model folders are
# created even when `model_dir` already exists.
for message, directory in (
    ("Store weights in: ", model_dir),
    ("Store best model weights in: ", best_model_dir),
    ("Store final model weights in: ", final_model_dir),
):
    if not os.path.exists(directory):
        print(message, directory)
        os.makedirs(directory, exist_ok=True)

In [11]:
# WARNING: recursively deletes the whole `outputs` directory (TensorBoard runs and saved models).
# NOTE(review): in a top-to-bottom run this executes right after the cells that create
# `outputs/` and open the SummaryWriter on it — confirm this cell should stay in the notebook.
!rm -r outputs

In [9]:
# Debugging aid: make CUDA kernel launches synchronous so errors are reported
# at the call that caused them (this slows GPU execution; remove for real runs).
# The variable has to live in the kernel process itself — a `!VAR=value` shell
# escape would only set it in a short-lived subshell.
# NOTE(review): presumably this must run before the first CUDA call to take effect — confirm.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
# For storing the model
def _move_batch_to_device(batch, device):
    """Move the tensors passed to the model onto ``device`` (in place) and return ``batch``."""
    batch["pixel_values"] = batch["pixel_values"].to(device)
    batch["pixel_mask"] = batch["pixel_mask"].to(device)
    batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
    batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
    return batch


# For storing the model
best_val_metric = -np.inf

# Move model to device
model.to(device)

for epoch in range(NUM_EPOCHS):
    model.train()
    train_running_loss = 0.0
    val_running_loss = 0.0
    # Count the samples actually processed: with `drop_last=True` the loader
    # yields fewer samples than `len(dataset)`, which would bias the averages.
    train_samples_seen = 0
    val_samples_seen = 0

    # Set up tqdm for the training loop
    train_loader = tqdm(
        train_merged_loader, desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Training"
    )

    for batch in train_loader:
        batch = _move_batch_to_device(batch, device)

        # Compute output and loss
        outputs = model(**batch)
        loss = outputs.loss

        # Compute gradient and perform step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record losses (weighted by batch size so the epoch value is a per-sample mean)
        batch_size = batch["pixel_values"].size(0)
        current_loss = loss.item() * batch_size
        train_running_loss += current_loss
        train_samples_seen += batch_size
        train_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, mask2former_auto_image_processor
        )
        metric.add_batch(references=masks, predictions=pred_maps)

    # After compute the batches that were added are deleted
    # NOTE(review): ignore_index is hard-coded to 255 while the model config uses
    # MASK_IGNORE_VALUE — confirm the two agree.
    mean_train_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=255, reduce_labels=False
    )["mean_iou"]

    # Validation phase
    model.eval()
    val_loader = tqdm(
        val_merged_loader, desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Validation"
    )
    with torch.no_grad():
        for batch in val_loader:
            batch = _move_batch_to_device(batch, device)

            # Compute output and loss
            outputs = model(**batch)
            loss = outputs.loss

            # Record losses
            batch_size = batch["pixel_values"].size(0)
            current_loss = loss.item() * batch_size
            val_running_loss += current_loss
            val_samples_seen += batch_size
            val_loader.set_postfix(loss=f"{current_loss:.4f}")

            # Extract and compute metrics
            pred_maps, masks = m2f_extract_pred_maps_and_masks(
                batch, outputs, mask2former_auto_image_processor
            )
            metric.add_batch(references=masks, predictions=pred_maps)

    # After compute the batches that were added are deleted
    mean_val_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=255, reduce_labels=False
    )["mean_iou"]

    # max(..., 1) guards against a division by zero if a loader yields no batch
    epoch_train_loss = train_running_loss / max(train_samples_seen, 1)
    epoch_val_loss = val_running_loss / max(val_samples_seen, 1)

    writer.add_scalar(f"Loss/train_{model_name}", epoch_train_loss, epoch + 1)
    writer.add_scalar(f"Loss/val_{model_name}", epoch_val_loss, epoch + 1)
    writer.add_scalar(f"mIoU/train_{model_name}", mean_train_iou, epoch + 1)
    writer.add_scalar(f"mIoU/val_{model_name}", mean_val_iou, epoch + 1)

    # Keep the checkpoint with the best validation mIoU seen so far
    if mean_val_iou > best_val_metric:
        best_val_metric = mean_val_iou
        model.save_pretrained(best_model_dir)

    # Advance the polynomial LR schedule once per epoch (total_iters=NUM_EPOCHS);
    # without this call the learning rate would never decay.
    scheduler.step()

    tqdm.write(
        f"Epoch {epoch + 1}/{NUM_EPOCHS}, Train Loss: {epoch_train_loss:.4f}, Train mIoU: {mean_train_iou:.4f}, Validation Loss: {epoch_val_loss:.4f}, Validation mIoU: {mean_val_iou:.4f}"
    )

In [None]:
# Evaluate on the merged test split.
# NOTE(review): this uses the weights currently held by `model` (the state after the
# last training epoch), not the checkpoint saved to `best_model_dir` — confirm intended.
model.eval()
test_running_loss = 0
# Count processed samples so the mean stays correct even if the loader drops
# an incomplete last batch (`drop_last=True`).
test_samples_seen = 0
test_loader = tqdm(test_merged_loader, desc="Test loop")
with torch.no_grad():
    for batch in test_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]

        # Compute output and loss
        outputs = model(**batch)
        loss = outputs.loss

        # Record losses (weighted by batch size to obtain a per-sample mean)
        batch_size = batch["pixel_values"].size(0)
        current_loss = loss.item() * batch_size
        test_running_loss += current_loss
        test_samples_seen += batch_size
        test_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, mask2former_auto_image_processor
        )
        metric.add_batch(references=masks, predictions=pred_maps)

# After compute the batches that were added are deleted
mean_test_iou = metric.compute(
    num_labels=NUM_CLASSES, ignore_index=255, reduce_labels=False
)["mean_iou"]
# max(..., 1) guards against a division by zero if the loader yields no batch
final_test_loss = test_running_loss / max(test_samples_seen, 1)
print(f"Test Loss: {final_test_loss:.4f}, Test mIoU: {mean_test_iou:.4f}")

In [None]:
"""
TODO:
  1. Consider Weights and Biases. I've heard only good things about it and maybe we can give it a try.
"""