In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import torch
from utils.common import preprocess_mask2former, m2f_extract_pred_maps_and_masks, mask2former_auto_image_processor
from utils.dataset_utils import get_cadisv2_dataset, get_cataract1k_dataset, ZEISS_CATEGORIES
from transformers import Mask2FormerForUniversalSegmentation
from torch.utils.data import DataLoader
import evaluate
import torch.optim as optim
from tqdm import tqdm

In [2]:
NUM_CLASSES = len(ZEISS_CATEGORIES) - 3 + 1 # Remove class incremental, add background

# Load model
model = Mask2FormerForUniversalSegmentation.from_pretrained(
    "facebook/mask2former-swin-large-ade-semantic",
    num_labels=NUM_CLASSES - 1,
    ignore_mismatched_sizes=True,
)

Some weights of Mask2FormerForUniversalSegmentation were not initialized from the model checkpoint at facebook/mask2former-swin-large-ade-semantic and are newly initialized because the shapes did not match:
- class_predictor.bias: found shape torch.Size([151]) in the checkpoint and torch.Size([12]) in the model instantiated
- class_predictor.weight: found shape torch.Size([151, 256]) in the checkpoint and torch.Size([12, 256]) in the model instantiated
- criterion.empty_weight: found shape torch.Size([151]) in the checkpoint and torch.Size([12]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# Domain incremental datasets
train_cadis_ds, val_cadis_ds, test_cadis_ds = get_cadisv2_dataset(
    "data/CaDISv2", domain_incremental=True
)
train_cataract_ds, val_cataract_ds, test_cataract_ds = get_cataract1k_dataset(
    "data/catract-1k", domain_incremental=True
)

# First case: CADIS + CATARACT at the same time
merged_train = torch.utils.data.ConcatDataset([train_cataract_ds, train_cadis_ds])
merged_val = torch.utils.data.ConcatDataset([val_cataract_ds, val_cadis_ds])

# Define dataloader params
N_WORKERS = 1
BATCH_SIZE = 2
SHUFFLE = True
DROP_LAST = True

# Define dataloader
train_merged_loader = DataLoader(
    train_cadis_ds,
    batch_size=BATCH_SIZE,
    shuffle=SHUFFLE,
    num_workers=N_WORKERS,
    drop_last=DROP_LAST,
    pin_memory=True,
    collate_fn=preprocess_mask2former,
)
val_merged_loader = DataLoader(
    merged_val,
    batch_size=BATCH_SIZE,
    shuffle=SHUFFLE,
    num_workers=N_WORKERS,
    drop_last=DROP_LAST,
    pin_memory=True,
    collate_fn=preprocess_mask2former,
)

In [5]:
# Check if CUDA is available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

Using device: cpu


In [6]:
# Training
NUM_EPOCHS = 50
LEARNING_RATE = 1e-4
LR_MULTIPLIER = 0.1
BACKBONE_LR = LEARNING_RATE * LR_MULTIPLIER
WEIGHT_DECAY = 0.5
# dice = Dice(average='micro')

# lambda_CE=5.0
# lambda_dice=5.0
metric = evaluate.load("mean_iou")
encoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.encoder")
]
decoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.decoder")
]
transformer_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.transformer_module")
]
optimizer = optim.AdamW(
    [
        {"params": encoder_params, "lr": BACKBONE_LR},
        {"params": decoder_params},
        {"params": transformer_params},
    ],
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

scheduler = optim.lr_scheduler.PolynomialLR(
    optimizer, total_iters=NUM_EPOCHS, power=0.9
)

In [7]:
# Move model to device
model.to(device)

for epoch in range(NUM_EPOCHS):
    model.train()
    train_running_loss = 0.0
    val_running_loss = 0.0

    # Set up tqdm for the training loop
    train_loader = tqdm(
        train_merged_loader, desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Training"
    )
    for batch in train_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]

        # Compute output and loss
        outputs = model(**batch)
        loss = outputs.loss

        # Compute gradient and perform step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        train_running_loss += current_loss
        train_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, mask2former_auto_image_processor
        )
        metric.add_batch(references=masks, predictions=pred_maps)

    # After compute the batches that were added are deleted
    mean_train_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=255, reduce_labels=True
    )["mean_iou"]

    # Validation phase
    model.eval()
    val_loader = tqdm(
        val_merged_loader, desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Validation"
    )
    with torch.no_grad():
        for batch in val_loader:
            # Move everything to the device
            batch["pixel_values"] = batch["pixel_values"].to(device)
            batch["pixel_mask"] = batch["pixel_mask"].to(device)
            batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
            batch["class_labels"] = [
                entry.to(device) for entry in batch["class_labels"]
            ]

            # Compute output and loss
            outputs = model(**batch)
            loss = outputs.loss

            # Record losses
            current_loss = loss.item() * batch["pixel_values"].size(0)
            val_running_loss += current_loss
            val_loader.set_postfix(loss=f"{current_loss:.4f}")

            # Extract and compute metrics
            pred_maps, masks = m2f_extract_pred_maps_and_masks(
                batch, outputs, mask2former_auto_image_processor
            )
            metric.add_batch(references=masks, predictions=pred_maps)

    # After compute the batches that were added are deleted
    mean_val_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=255, reduce_labels=True
    )["mean_iou"]

    epoch_train_loss = train_running_loss / len(train_merged_loader.dataset)
    epoch_val_loss = val_running_loss / len(val_merged_loader.dataset)
    tqdm.write(
        f"Epoch {epoch + 1}/{NUM_EPOCHS}, Train Loss: {epoch_train_loss:.4f}, Train mIoU: {mean_train_iou:.4f}, Validation Loss: {epoch_val_loss:.4f}, Validation mIoU: {mean_val_iou:.4f}"
    )

Epoch 1/50 Training:   3%|▎         | 47/1775 [05:15<3:13:32,  6.72s/it, loss=101.6175]


KeyboardInterrupt: 

In [None]:
"""
TODO:
  1. The preprocessor reduces the size of all images AND masks, so that we can train them.
    This raises the question if such a training procedure is good enough. We need to try and
    train a model, then postprocess the predicted maps the original image size and validate
    the data visually.

  2. We currently load a Mask2Former model that is pretrained using the swin backbone on
    ADE. Can we reinitialize the other weights s.t. we end up having a pretrained backbone
    and untrained pixel decoder and transformer decoder.

    => Mask2Former config with pretrained backbone and newly initialized model.

  3. Reuse Gözde code for Tensorboard and consider Weights and Biases. I've heard only
    good things about it and maybe we can give it a try.
  

"""