In [1]:
%load_ext autoreload
%autoreload 2

In [6]:
import torch
import os
from utils.common import (
    m2f_dataset_collate,
    m2f_extract_pred_maps_and_masks,
    BG_VALUE_255,
    set_seed,
    pixel_mean_std,
    CADIS_PIXEL_MEAN,
    CADIS_PIXEL_STD,
    CAT1K_PIXEL_MEAN,
    CAT1K_PIXEL_STD,
)
from utils.dataset_utils import (
    get_cadisv2_dataset,
    get_cataract1k_dataset,
    ZEISS_CATEGORIES,
)
from utils.medical_datasets import Mask2FormerDataset
from transformers import (
    Mask2FormerForUniversalSegmentation,
    SwinModel,
    SwinConfig,
    Mask2FormerConfig,
    AutoImageProcessor,
    Mask2FormerImageProcessor
)
from torch.utils.data import DataLoader
import evaluate
import torch.optim as optim
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from dotenv import load_dotenv
import wandb

In [7]:
set_seed(42) # seed everything

In [8]:
NUM_CLASSES = len(ZEISS_CATEGORIES) - 3  # Remove class incremental
SWIN_BACKBONE = "microsoft/swin-tiny-patch4-window7-224"#"microsoft/swin-large-patch4-window12-384"

# Download pretrained swin model
swin_model = SwinModel.from_pretrained(
    SWIN_BACKBONE, out_features=["stage1", "stage2", "stage3", "stage4"]
)
swin_config = SwinConfig.from_pretrained(
    SWIN_BACKBONE, out_features=["stage1", "stage2", "stage3", "stage4"]
)

# Create Mask2Former configuration based on Swin's configuration
mask2former_config = Mask2FormerConfig(
    backbone_config=swin_config, num_labels=NUM_CLASSES #, ignore_value=BG_VALUE
)

# Create the Mask2Former model with this configuration
model = Mask2FormerForUniversalSegmentation(mask2former_config)

# Reuse pretrained parameters
for swin_param, m2f_param in zip(
    swin_model.named_parameters(),
    model.model.pixel_level_module.encoder.named_parameters(),
):
    m2f_param_name = f"model.pixel_level_module.encoder.{m2f_param[0]}"

    if swin_param[0] == m2f_param[0]:
        model.state_dict()[m2f_param_name].copy_(swin_param[1])
        continue

    print(f"Not Matched: {m2f_param[0]} != {swin_param[0]}")



Not Matched: hidden_states_norms.stage1.weight != layernorm.weight
Not Matched: hidden_states_norms.stage1.bias != layernorm.bias


In [9]:
# Helper function to load datasets
def load_dataset(dataset_getter, data_path, domain_incremental):
    return dataset_getter(data_path, domain_incremental=domain_incremental)


# Helper function to create dataloaders for a dataset
def create_dataloaders(
    dataset, batch_size, shuffle, num_workers, drop_last, pin_memory, collate_fn
):
    return {
        "train": DataLoader(
            dataset["train"],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        ),
        "val": DataLoader(
            dataset["val"],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        ),
        "test": DataLoader(
            dataset["test"],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        ),
    }


# Load datasets
datasets = {
    "A": load_dataset(get_cadisv2_dataset, "../../storage/data/CaDISv2", True),
    "B": load_dataset(get_cataract1k_dataset, "../../storage/data/cataract-1k", True),
}

# pixel_mean_A,pixel_std_A=pixel_mean_std(datasets["A"][0])
pixel_mean_A = CADIS_PIXEL_MEAN
pixel_std_A = CADIS_PIXEL_STD

# pixel_mean_B,pixel_std_B=pixel_mean_std(datasets["B"][0])
pixel_mean_B = CAT1K_PIXEL_MEAN
pixel_std_B = CAT1K_PIXEL_STD

# Define preprocessor
swin_processor = AutoImageProcessor.from_pretrained(SWIN_BACKBONE)
m2f_preprocessor_A = Mask2FormerImageProcessor(
    reduce_labels=True,
    ignore_index=255,
    do_resize=False,
    do_rescale=True,
    do_normalize=True,
    image_std=pixel_std_A,
    image_mean=pixel_mean_A,
)

m2f_preprocessor_B = Mask2FormerImageProcessor(
    reduce_labels=True,
    ignore_index=255,
    do_resize=False,
    do_rescale=True,
    do_normalize=True,
    image_std=pixel_std_B,
    image_mean=pixel_mean_B,
)

# Create Mask2Former Datasets
m2f_datasets = {
    "A": {
        "train": Mask2FormerDataset(datasets["A"][0], m2f_preprocessor_A),
        "val": Mask2FormerDataset(datasets["A"][1], m2f_preprocessor_A),
        "test": Mask2FormerDataset(datasets["A"][2], m2f_preprocessor_A),
    },
    "B": {
        "train": Mask2FormerDataset(datasets["B"][0], m2f_preprocessor_B),
        "val": Mask2FormerDataset(datasets["B"][1], m2f_preprocessor_B),
        "test": Mask2FormerDataset(datasets["B"][2], m2f_preprocessor_B),
    },
}

# DataLoader parameters
N_WORKERS = 4
BATCH_SIZE = 16
SHUFFLE = True
DROP_LAST = True

dataloader_params = {
    "batch_size": BATCH_SIZE,
    "shuffle": SHUFFLE,
    "num_workers": N_WORKERS,
    "drop_last": DROP_LAST,
    "pin_memory": True,
    "collate_fn": m2f_dataset_collate,
}

# Create DataLoaders
dataloaders = {
    key: create_dataloaders(m2f_datasets[key], **dataloader_params)
    for key in m2f_datasets
}

print(dataloaders)



{'A': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7f1b0ea9b0e0>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7f1b0e8ee600>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x7f1b0ea99850>}, 'B': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7f1b0ea99cd0>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7f1b0ea98cb0>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x7f1b0ea98bc0>}}


In [10]:
# Check if CUDA is available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

Using device: cuda


In [11]:
# Tensorboard setup
out_dir="outputs/"
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
if not os.path.exists(out_dir+"runs"):
    os.makedirs(out_dir+"runs")
%load_ext tensorboard
%tensorboard --logdir outputs/runs

Reusing TensorBoard on port 6006 (pid 1178), started 0:17:01 ago. (Use '!kill 1178' to kill it.)

In [12]:
!CUDA_LAUNCH_BLOCKING=1

In [13]:
# Training
NUM_EPOCHS = 200
LEARNING_RATE = 1e-4
LR_MULTIPLIER = 0.1
BACKBONE_LR = LEARNING_RATE * LR_MULTIPLIER
WEIGHT_DECAY = 0.05
PATIENCE=15
metric = evaluate.load("mean_iou") # mIoU will be used to pick the best performing model using val set
encoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.encoder")
]
decoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.decoder")
]
transformer_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.transformer_module")
]
optimizer = optim.AdamW(
    [
        {"params": encoder_params, "lr": BACKBONE_LR},
        {"params": decoder_params},
        {"params": transformer_params},
    ],
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

scheduler = optim.lr_scheduler.PolynomialLR(
    optimizer, total_iters=NUM_EPOCHS, power=0.9
)

In [14]:
# WandB for team usage !!!!

wandb.login() # use this one if a different person is going to run the notebook
#wandb.login(relogin=False) # if the same person in the last run is going to run the notebook again

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mge85ket[0m ([33mcontinual-learning-tum[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [15]:
wandb.init(
    project="M2F_original",
    config={
        "learning_rate": LEARNING_RATE,
        "learning_rate_multiplier": LR_MULTIPLIER,
        "backbone_learning_rate": BACKBONE_LR,
        "learning_rate_scheduler": scheduler.__class__.__name__,
        "optimizer": optimizer.__class__.__name__,
        "backbone": SWIN_BACKBONE,
        "m2f_preprocessor": m2f_preprocessor_B.__dict__,
        "m2f_model_config": model.config
    },
    name="M2F-Swin-Tiny-Train_Cataract1k",
    notes="M2F with tiny Swin backbone pretrained on ImageNet-1K. \
        Scenario: Train on B, Test on B"
)

In [16]:
# Tensorboard logging
writer = SummaryWriter(log_dir=out_dir + "runs")

# Model checkpointing
base_model_name="m2f_swin_backbone_train_cataract1k"
model_dir = out_dir + "models/"
if not os.path.exists(model_dir):
    print("Store weights in: ", model_dir)
    os.makedirs(model_dir)

best_model_dir = model_dir + f"{base_model_name}/best_model/"
if not os.path.exists(best_model_dir):
    print("Store best model weights in: ", best_model_dir)
    os.makedirs(best_model_dir)
final_model_dir = model_dir + f"{base_model_name}/final_model/"
if not os.path.exists(final_model_dir):
    print("Store final model weights in: ", final_model_dir)
    os.makedirs(final_model_dir)

Store best model weights in:  outputs/models/m2f_swin_backbone_train_cataract1k/best_model/
Store final model weights in:  outputs/models/m2f_swin_backbone_train_cataract1k/final_model/


In [17]:
# Save the preprocessor
m2f_preprocessor_B.save_pretrained(model_dir + base_model_name)

['outputs/models/m2f_swin_backbone_train_cataract1k/preprocessor_config.json']

In [18]:
# To avoid making stupid errors
CURR_TASK = "B"

# For storing the model
best_val_metric = -np.inf

# Move model to device
model.to(device)
counter=0
for epoch in range(NUM_EPOCHS):
    model.train()
    train_running_loss = 0.0
    val_running_loss = 0.0

    # Set up tqdm for the training loop
    train_loader = tqdm(
        dataloaders[CURR_TASK]["train"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Training"
    )

    for batch in train_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]

        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss

        # Compute gradient and perform step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        train_running_loss += current_loss
        train_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_B
        )
        metric.add_batch(references=masks, predictions=pred_maps)

    # After compute the batches that were added are deleted
    mean_train_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
    )["mean_iou"]

    # Validation phase
    model.eval()
    val_loader = tqdm(
        dataloaders[CURR_TASK]["val"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Validation"
    )
    with torch.no_grad():
        for batch in val_loader:
            # Move everything to the device
            batch["pixel_values"] = batch["pixel_values"].to(device)
            batch["pixel_mask"] = batch["pixel_mask"].to(device)
            batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
            batch["class_labels"] = [
                entry.to(device) for entry in batch["class_labels"]
            ]
            # Compute output and loss
            outputs = model(**batch)

            loss = outputs.loss
            # Record losses
            current_loss = loss.item() * batch["pixel_values"].size(0)
            val_running_loss += current_loss
            val_loader.set_postfix(loss=f"{current_loss:.4f}")

            # Extract and compute metrics
            pred_maps, masks = m2f_extract_pred_maps_and_masks(
                batch, outputs, m2f_preprocessor_B
            )
            metric.add_batch(references=masks, predictions=pred_maps)

    # After compute the batches that were added are deleted
    mean_val_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
    )["mean_iou"]

    epoch_train_loss = train_running_loss / len(dataloaders[CURR_TASK]["train"].dataset)
    epoch_val_loss = val_running_loss / len(dataloaders[CURR_TASK]["val"].dataset)

    writer.add_scalar(f"Loss/train_{base_model_name}_{CURR_TASK}", epoch_train_loss, epoch + 1)
    writer.add_scalar(f"Loss/val_{base_model_name}_{CURR_TASK}", epoch_val_loss, epoch + 1)
    writer.add_scalar(f"mIoU/train_{base_model_name}_{CURR_TASK}", mean_train_iou, epoch + 1)
    writer.add_scalar(f"mIoU/val_{base_model_name}_{CURR_TASK}", mean_val_iou, epoch + 1)

    wandb.log({
        f"Loss/train_{CURR_TASK}": epoch_train_loss,
        f"Loss/val_{CURR_TASK}": epoch_val_loss,
        f"mIoU/train_{CURR_TASK}": mean_train_iou,
        f"mIoU/val_{CURR_TASK}": mean_val_iou
    })


    tqdm.write(
        f"Epoch {epoch + 1}/{NUM_EPOCHS}, Train Loss: {epoch_train_loss:.4f}, Train mIoU: {mean_train_iou:.4f}, Validation Loss: {epoch_val_loss:.4f}, Validation mIoU: {mean_val_iou:.4f}"
    )
    if mean_val_iou > best_val_metric:
        best_val_metric = mean_val_iou
        model.save_pretrained(f"{best_model_dir}{CURR_TASK}/")
        counter=0
    else:
        counter+=1
        if counter == PATIENCE:
            print("Early stopping at epoch",epoch)
            break

Epoch 1/200 Training: 100%|██████████| 112/112 [04:02<00:00,  2.16s/it, loss=1005.3492]
Epoch 1/200 Validation: 100%|██████████| 14/14 [00:28<00:00,  2.06s/it, loss=1034.2212]


Epoch 1/200, Train Loss: 75.5773, Train mIoU: 0.0561, Validation Loss: 63.7610, Validation mIoU: 0.0527


Epoch 2/200 Training: 100%|██████████| 112/112 [04:25<00:00,  2.37s/it, loss=456.7245]
Epoch 2/200 Validation: 100%|██████████| 14/14 [00:29<00:00,  2.08s/it, loss=443.1598]


Epoch 2/200, Train Loss: 46.6233, Train mIoU: 0.1180, Validation Loss: 29.6270, Validation mIoU: 0.2017


Epoch 3/200 Training: 100%|██████████| 112/112 [04:39<00:00,  2.49s/it, loss=277.5189]
Epoch 3/200 Validation: 100%|██████████| 14/14 [00:26<00:00,  1.86s/it, loss=312.1281]


Epoch 3/200, Train Loss: 24.7570, Train mIoU: 0.2643, Validation Loss: 17.9317, Validation mIoU: 0.3077


Epoch 4/200 Training: 100%|██████████| 112/112 [04:43<00:00,  2.53s/it, loss=218.5298]
Epoch 4/200 Validation: 100%|██████████| 14/14 [00:24<00:00,  1.73s/it, loss=213.4838]


Epoch 4/200, Train Loss: 16.5261, Train mIoU: 0.3603, Validation Loss: 14.5305, Validation mIoU: 0.4005


Epoch 5/200 Training: 100%|██████████| 112/112 [04:23<00:00,  2.35s/it, loss=283.2661]
Epoch 5/200 Validation: 100%|██████████| 14/14 [00:26<00:00,  1.90s/it, loss=207.1202]


Epoch 5/200, Train Loss: 13.9841, Train mIoU: 0.4384, Validation Loss: 14.2592, Validation mIoU: 0.4985


Epoch 6/200 Training: 100%|██████████| 112/112 [04:32<00:00,  2.43s/it, loss=170.4574]
Epoch 6/200 Validation: 100%|██████████| 14/14 [00:23<00:00,  1.71s/it, loss=160.7390]


Epoch 6/200, Train Loss: 12.7808, Train mIoU: 0.5229, Validation Loss: 11.9468, Validation mIoU: 0.5536


Epoch 7/200 Training: 100%|██████████| 112/112 [04:22<00:00,  2.34s/it, loss=178.7817]
Epoch 7/200 Validation: 100%|██████████| 14/14 [00:21<00:00,  1.52s/it, loss=195.4693]


Epoch 7/200, Train Loss: 11.5162, Train mIoU: 0.6232, Validation Loss: 11.8902, Validation mIoU: 0.6239


Epoch 8/200 Training: 100%|██████████| 112/112 [04:51<00:00,  2.60s/it, loss=151.7322]
Epoch 8/200 Validation: 100%|██████████| 14/14 [00:25<00:00,  1.82s/it, loss=169.8750]


Epoch 8/200, Train Loss: 10.9962, Train mIoU: 0.6969, Validation Loss: 10.9025, Validation mIoU: 0.6095


Epoch 9/200 Training: 100%|██████████| 112/112 [04:55<00:00,  2.64s/it, loss=200.2686]
Epoch 9/200 Validation: 100%|██████████| 14/14 [00:25<00:00,  1.81s/it, loss=144.3062]


Epoch 9/200, Train Loss: 10.0828, Train mIoU: 0.7432, Validation Loss: 10.1736, Validation mIoU: 0.8034


Epoch 10/200 Training: 100%|██████████| 112/112 [03:54<00:00,  2.09s/it, loss=141.2014]
Epoch 10/200 Validation: 100%|██████████| 14/14 [00:22<00:00,  1.61s/it, loss=186.5227]


Epoch 10/200, Train Loss: 9.3311, Train mIoU: 0.8310, Validation Loss: 9.4750, Validation mIoU: 0.8526


Epoch 11/200 Training: 100%|██████████| 112/112 [04:34<00:00,  2.45s/it, loss=136.5857]
Epoch 11/200 Validation: 100%|██████████| 14/14 [00:21<00:00,  1.55s/it, loss=159.2937]


Epoch 11/200, Train Loss: 8.7547, Train mIoU: 0.8564, Validation Loss: 9.4845, Validation mIoU: 0.7024


Epoch 12/200 Training: 100%|██████████| 112/112 [03:55<00:00,  2.11s/it, loss=119.9671]
Epoch 12/200 Validation: 100%|██████████| 14/14 [00:29<00:00,  2.13s/it, loss=170.6537]


Epoch 12/200, Train Loss: 8.1992, Train mIoU: 0.8753, Validation Loss: 9.3101, Validation mIoU: 0.8035


Epoch 13/200 Training: 100%|██████████| 112/112 [04:34<00:00,  2.45s/it, loss=145.1824]
Epoch 13/200 Validation: 100%|██████████| 14/14 [00:29<00:00,  2.12s/it, loss=130.2578]


Epoch 13/200, Train Loss: 7.9346, Train mIoU: 0.8984, Validation Loss: 9.0380, Validation mIoU: 0.8238


Epoch 14/200 Training: 100%|██████████| 112/112 [04:19<00:00,  2.32s/it, loss=129.4955]
Epoch 14/200 Validation: 100%|██████████| 14/14 [00:26<00:00,  1.87s/it, loss=129.5777]


Epoch 14/200, Train Loss: 7.9298, Train mIoU: 0.8975, Validation Loss: 8.9417, Validation mIoU: 0.7770


Epoch 15/200 Training: 100%|██████████| 112/112 [04:27<00:00,  2.39s/it, loss=116.6620]
Epoch 15/200 Validation: 100%|██████████| 14/14 [00:28<00:00,  2.07s/it, loss=141.9222]


Epoch 15/200, Train Loss: 7.5682, Train mIoU: 0.9167, Validation Loss: 8.7759, Validation mIoU: 0.8661


Epoch 16/200 Training: 100%|██████████| 112/112 [04:26<00:00,  2.38s/it, loss=107.9743]
Epoch 16/200 Validation: 100%|██████████| 14/14 [00:26<00:00,  1.90s/it, loss=116.5517]


Epoch 16/200, Train Loss: 7.3831, Train mIoU: 0.9232, Validation Loss: 8.4447, Validation mIoU: 0.8371


Epoch 17/200 Training: 100%|██████████| 112/112 [04:22<00:00,  2.34s/it, loss=119.3054]
Epoch 17/200 Validation: 100%|██████████| 14/14 [00:25<00:00,  1.80s/it, loss=124.5542]


Epoch 17/200, Train Loss: 7.3554, Train mIoU: 0.9141, Validation Loss: 8.1975, Validation mIoU: 0.8541


Epoch 18/200 Training: 100%|██████████| 112/112 [04:09<00:00,  2.23s/it, loss=116.9357]
Epoch 18/200 Validation: 100%|██████████| 14/14 [00:25<00:00,  1.84s/it, loss=150.1043]


Epoch 18/200, Train Loss: 7.0820, Train mIoU: 0.9208, Validation Loss: 8.5689, Validation mIoU: 0.7813


Epoch 19/200 Training: 100%|██████████| 112/112 [04:07<00:00,  2.21s/it, loss=98.2649] 
Epoch 19/200 Validation: 100%|██████████| 14/14 [00:23<00:00,  1.68s/it, loss=187.2308]


Epoch 19/200, Train Loss: 6.7037, Train mIoU: 0.9393, Validation Loss: 8.1149, Validation mIoU: 0.8525


Epoch 20/200 Training: 100%|██████████| 112/112 [04:31<00:00,  2.42s/it, loss=123.9750]
Epoch 20/200 Validation: 100%|██████████| 14/14 [00:28<00:00,  2.07s/it, loss=142.1499]


Epoch 20/200, Train Loss: 8.1609, Train mIoU: 0.8820, Validation Loss: 9.6510, Validation mIoU: 0.6926


Epoch 21/200 Training: 100%|██████████| 112/112 [04:12<00:00,  2.25s/it, loss=109.7539]
Epoch 21/200 Validation: 100%|██████████| 14/14 [00:22<00:00,  1.62s/it, loss=136.8968]


Epoch 21/200, Train Loss: 7.2986, Train mIoU: 0.8841, Validation Loss: 8.9689, Validation mIoU: 0.8472


Epoch 22/200 Training: 100%|██████████| 112/112 [04:05<00:00,  2.20s/it, loss=99.6331] 
Epoch 22/200 Validation: 100%|██████████| 14/14 [00:26<00:00,  1.90s/it, loss=116.2064]


Epoch 22/200, Train Loss: 7.1549, Train mIoU: 0.9162, Validation Loss: 8.5157, Validation mIoU: 0.8304


Epoch 23/200 Training: 100%|██████████| 112/112 [04:08<00:00,  2.21s/it, loss=87.9806] 
Epoch 23/200 Validation: 100%|██████████| 14/14 [00:19<00:00,  1.39s/it, loss=112.4945]


Epoch 23/200, Train Loss: 6.6452, Train mIoU: 0.9170, Validation Loss: 8.2561, Validation mIoU: 0.8683


Epoch 24/200 Training: 100%|██████████| 112/112 [04:00<00:00,  2.15s/it, loss=95.3512] 
Epoch 24/200 Validation: 100%|██████████| 14/14 [00:26<00:00,  1.88s/it, loss=169.1621]


Epoch 24/200, Train Loss: 6.4900, Train mIoU: 0.9452, Validation Loss: 7.9266, Validation mIoU: 0.7951


Epoch 25/200 Training: 100%|██████████| 112/112 [04:11<00:00,  2.24s/it, loss=85.4631] 
Epoch 25/200 Validation: 100%|██████████| 14/14 [00:26<00:00,  1.88s/it, loss=110.5565]


Epoch 25/200, Train Loss: 6.0841, Train mIoU: 0.9402, Validation Loss: 8.7427, Validation mIoU: 0.8370


Epoch 26/200 Training: 100%|██████████| 112/112 [04:09<00:00,  2.23s/it, loss=87.2481] 
Epoch 26/200 Validation: 100%|██████████| 14/14 [00:27<00:00,  1.94s/it, loss=135.7147]


Epoch 26/200, Train Loss: 6.0309, Train mIoU: 0.9389, Validation Loss: 7.8734, Validation mIoU: 0.7716


Epoch 27/200 Training: 100%|██████████| 112/112 [04:06<00:00,  2.20s/it, loss=88.2843] 
Epoch 27/200 Validation: 100%|██████████| 14/14 [00:23<00:00,  1.67s/it, loss=110.0062]


Epoch 27/200, Train Loss: 5.8112, Train mIoU: 0.9520, Validation Loss: 8.2624, Validation mIoU: 0.8875


Epoch 28/200 Training: 100%|██████████| 112/112 [03:54<00:00,  2.10s/it, loss=91.4766] 
Epoch 28/200 Validation: 100%|██████████| 14/14 [00:11<00:00,  1.24it/s, loss=117.9502]


Epoch 28/200, Train Loss: 6.6771, Train mIoU: 0.9198, Validation Loss: 8.4984, Validation mIoU: 0.7635


Epoch 29/200 Training: 100%|██████████| 112/112 [02:14<00:00,  1.20s/it, loss=93.8823] 
Epoch 29/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.29it/s, loss=130.0483]


Epoch 29/200, Train Loss: 6.0221, Train mIoU: 0.9423, Validation Loss: 7.9589, Validation mIoU: 0.8411


Epoch 30/200 Training: 100%|██████████| 112/112 [02:17<00:00,  1.23s/it, loss=98.7947] 
Epoch 30/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.28it/s, loss=114.3894]


Epoch 30/200, Train Loss: 6.2123, Train mIoU: 0.9379, Validation Loss: 8.1811, Validation mIoU: 0.8171


Epoch 31/200 Training: 100%|██████████| 112/112 [02:11<00:00,  1.18s/it, loss=88.5705] 
Epoch 31/200 Validation: 100%|██████████| 14/14 [00:15<00:00,  1.10s/it, loss=138.6827]


Epoch 31/200, Train Loss: 5.8942, Train mIoU: 0.9560, Validation Loss: 8.0357, Validation mIoU: 0.8632


Epoch 32/200 Training: 100%|██████████| 112/112 [02:23<00:00,  1.28s/it, loss=83.5863] 
Epoch 32/200 Validation: 100%|██████████| 14/14 [00:11<00:00,  1.20it/s, loss=120.5308]


Epoch 32/200, Train Loss: 5.6216, Train mIoU: 0.9588, Validation Loss: 8.3198, Validation mIoU: 0.8461


Epoch 33/200 Training: 100%|██████████| 112/112 [02:19<00:00,  1.25s/it, loss=96.4502] 
Epoch 33/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.41it/s, loss=121.3018]


Epoch 33/200, Train Loss: 5.5729, Train mIoU: 0.9611, Validation Loss: 7.5850, Validation mIoU: 0.7510


Epoch 34/200 Training: 100%|██████████| 112/112 [02:35<00:00,  1.39s/it, loss=81.9006] 
Epoch 34/200 Validation: 100%|██████████| 14/14 [00:12<00:00,  1.11it/s, loss=119.7700]


Epoch 34/200, Train Loss: 5.6075, Train mIoU: 0.9600, Validation Loss: 7.7490, Validation mIoU: 0.7654


Epoch 35/200 Training: 100%|██████████| 112/112 [02:50<00:00,  1.52s/it, loss=78.6634]
Epoch 35/200 Validation: 100%|██████████| 14/14 [00:12<00:00,  1.14it/s, loss=91.7693] 


Epoch 35/200, Train Loss: 5.4461, Train mIoU: 0.9625, Validation Loss: 7.8459, Validation mIoU: 0.8375


Epoch 36/200 Training: 100%|██████████| 112/112 [02:18<00:00,  1.24s/it, loss=90.4501] 
Epoch 36/200 Validation: 100%|██████████| 14/14 [00:13<00:00,  1.02it/s, loss=108.9191]


Epoch 36/200, Train Loss: 5.4865, Train mIoU: 0.9573, Validation Loss: 8.5012, Validation mIoU: 0.7426


Epoch 37/200 Training: 100%|██████████| 112/112 [02:14<00:00,  1.20s/it, loss=79.4322]
Epoch 37/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.32it/s, loss=121.0195]


Epoch 37/200, Train Loss: 5.5642, Train mIoU: 0.9447, Validation Loss: 8.6511, Validation mIoU: 0.7469


Epoch 38/200 Training: 100%|██████████| 112/112 [02:52<00:00,  1.54s/it, loss=89.8768]
Epoch 38/200 Validation: 100%|██████████| 14/14 [00:11<00:00,  1.17it/s, loss=104.0121]


Epoch 38/200, Train Loss: 5.3639, Train mIoU: 0.9612, Validation Loss: 8.1783, Validation mIoU: 0.7978


Epoch 39/200 Training: 100%|██████████| 112/112 [02:13<00:00,  1.20s/it, loss=87.4207] 
Epoch 39/200 Validation: 100%|██████████| 14/14 [00:15<00:00,  1.12s/it, loss=124.7681]


Epoch 39/200, Train Loss: 5.2418, Train mIoU: 0.9542, Validation Loss: 7.9088, Validation mIoU: 0.8776


Epoch 40/200 Training: 100%|██████████| 112/112 [02:19<00:00,  1.25s/it, loss=84.8645] 
Epoch 40/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.35it/s, loss=133.7952]


Epoch 40/200, Train Loss: 5.6825, Train mIoU: 0.9587, Validation Loss: 8.5264, Validation mIoU: 0.7715


Epoch 41/200 Training: 100%|██████████| 112/112 [02:23<00:00,  1.28s/it, loss=79.3472]
Epoch 41/200 Validation: 100%|██████████| 14/14 [00:12<00:00,  1.15it/s, loss=137.8056]


Epoch 41/200, Train Loss: 5.3715, Train mIoU: 0.9617, Validation Loss: 7.8416, Validation mIoU: 0.7692


Epoch 42/200 Training: 100%|██████████| 112/112 [02:21<00:00,  1.27s/it, loss=70.9277]
Epoch 42/200 Validation: 100%|██████████| 14/14 [00:12<00:00,  1.16it/s, loss=108.0276]


Epoch 42/200, Train Loss: 5.2162, Train mIoU: 0.9647, Validation Loss: 7.8950, Validation mIoU: 0.8257
Early stopping at epoch 41


## Test on A

In [19]:
# Load best model and evaluate on test
model = Mask2FormerForUniversalSegmentation.from_pretrained(f"{best_model_dir}{CURR_TASK}/").to(device)

In [20]:
model.eval()
test_running_loss = 0
test_loader = tqdm(dataloaders[CURR_TASK]["test"], desc="Test loop")
with torch.no_grad():
    for batch in test_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss
        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        test_running_loss += current_loss
        test_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_B
        )
        metric.add_batch(references=masks, predictions=pred_maps)
        
# After compute the batches that were added are deleted
test_metrics_B = metric.compute(
    num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
)
mean_test_iou = test_metrics_B["mean_iou"]
final_test_loss = test_running_loss / len(dataloaders[CURR_TASK]["test"].dataset)
wandb.log({
    f"Loss/test_{CURR_TASK}": final_test_loss,
    f"mIoU/test_{CURR_TASK}": mean_test_iou
})
print(f"Test Loss: {final_test_loss:.4f}, Test mIoU: {mean_test_iou:.4f}")
wandb.finish()

Test loop: 100%|██████████| 14/14 [00:10<00:00,  1.36it/s, loss=136.1198]


Test Loss: 8.2671, Test mIoU: 0.8541


0,1
Loss/test_B,▁
Loss/train_B,█▅▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/val_B,█▄▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mIoU/test_B,▁
mIoU/train_B,▁▁▃▃▄▅▅▆▆▇▇▇▇▇█████▇████████████████████
mIoU/val_B,▁▂▃▄▅▅▆▆▇█▆▇▇▇███▇█▆██▇█▇█▇█▇██▇▇█▇▇▇█▇▇

0,1
Loss/test_B,8.26706
Loss/train_B,5.21622
Loss/val_B,7.89498
mIoU/test_B,0.85406
mIoU/train_B,0.96474
mIoU/val_B,0.82568
