In [1]:
%load_ext autoreload
%autoreload 2

In [6]:
import torch
import os
from utils.common import (
    m2f_dataset_collate,
    m2f_extract_pred_maps_and_masks,
    BG_VALUE,
    set_seed,
    pixel_mean_std,
    CADIS_PIXEL_MEAN,
    CADIS_PIXEL_STD,
    CAT1K_PIXEL_MEAN,
    CAT1K_PIXEL_STD,
    FULL_MERGE_PIXEL_MEAN,
    FULL_MERGE_PIXEL_STD,
    REPLAY32_PIXEL_MEAN,
    REPLAY32_PIXEL_STD,
)
from utils.dataset_utils import (
    get_cadisv2_dataset,
    get_cataract1k_dataset,
    ZEISS_CATEGORIES,
)
from utils.medical_datasets import Mask2FormerDataset
from transformers import (
    Mask2FormerForUniversalSegmentation,
    SwinModel,
    SwinConfig,
    Mask2FormerConfig,
    AutoImageProcessor,
    Mask2FormerImageProcessor
)
from torch.utils.data import DataLoader
import evaluate
import torch.optim as optim
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from dotenv import load_dotenv
import wandb

In [7]:
set_seed(42) # seed everything

In [8]:
NUM_CLASSES = len(ZEISS_CATEGORIES) - 3  # Remove class incremental
SWIN_BACKBONE = "microsoft/swin-large-patch4-window12-384"

# Download pretrained swin model
swin_model = SwinModel.from_pretrained(
    SWIN_BACKBONE, out_features=["stage1", "stage2", "stage3", "stage4"]
)
swin_config = SwinConfig.from_pretrained(
    SWIN_BACKBONE, out_features=["stage1", "stage2", "stage3", "stage4"]
)

# Create Mask2Former configuration based on Swin's configuration
mask2former_config = Mask2FormerConfig(
    backbone_config=swin_config, num_labels=NUM_CLASSES #, ignore_value=BG_VALUE
)

# Create the Mask2Former model with this configuration
model = Mask2FormerForUniversalSegmentation(mask2former_config)

# Reuse pretrained parameters
for swin_param, m2f_param in zip(
    swin_model.named_parameters(),
    model.model.pixel_level_module.encoder.named_parameters(),
):
    m2f_param_name = f"model.pixel_level_module.encoder.{m2f_param[0]}"

    if swin_param[0] == m2f_param[0]:
        model.state_dict()[m2f_param_name].copy_(swin_param[1])
        continue

    print(f"Not Matched: {m2f_param[0]} != {swin_param[0]}")



Not Matched: hidden_states_norms.stage1.weight != layernorm.weight
Not Matched: hidden_states_norms.stage1.bias != layernorm.bias


In [6]:
# Helper function to load datasets
def load_dataset(dataset_getter, data_path, domain_incremental):
    return dataset_getter(data_path, domain_incremental=domain_incremental)

# Helper function to create dataloaders for a dataset
def create_dataloaders(dataset, batch_size, shuffle, num_workers, drop_last, pin_memory, collate_fn):
    return {
        "train": DataLoader(dataset["train"], batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=drop_last, pin_memory=pin_memory, collate_fn=collate_fn),
        "val": DataLoader(dataset["val"], batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=drop_last, pin_memory=pin_memory, collate_fn=collate_fn),
        "test": DataLoader(dataset["test"], batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, drop_last=drop_last, pin_memory=pin_memory, collate_fn=collate_fn),
    }

# Load datasets
datasets = {
    "A": load_dataset(get_cadisv2_dataset, "../../storage/data/CaDISv2", True),
    "B": load_dataset(get_cataract1k_dataset, "../../storage/data/cataract-1k", True)
}

#pixel_mean_A,pixel_std_A=pixel_mean_std(datasets["A"][0])
pixel_mean_A=CADIS_PIXEL_MEAN
pixel_std_A=CADIS_PIXEL_STD

#pixel_mean_B,pixel_std_B=pixel_mean_std(datasets["B"][0])
pixel_mean_B=CAT1K_PIXEL_MEAN
pixel_std_B=CAT1K_PIXEL_STD

# Define preprocessor
swin_processor = AutoImageProcessor.from_pretrained(SWIN_BACKBONE)
m2f_preprocessor_A = Mask2FormerImageProcessor(
    reduce_labels=True,
    ignore_index=255,
    do_resize=False,
    do_rescale=True,
    do_normalize=True,
    image_std=pixel_std_A, 
    image_mean=pixel_mean_A 
)

m2f_preprocessor_B = Mask2FormerImageProcessor(
    reduce_labels=True,
    ignore_index=255,
    do_resize=False,
    do_rescale=True,
    do_normalize=True,
    image_std=pixel_std_B,
    image_mean=pixel_mean_B
)

# Create Mask2Former Datasets
m2f_datasets={"A":{"train": Mask2FormerDataset(datasets["A"][0], m2f_preprocessor_A),
    "val": Mask2FormerDataset(datasets["A"][1], m2f_preprocessor_A),
    "test": Mask2FormerDataset(datasets["A"][2], m2f_preprocessor_A)},
            "B":{"train": Mask2FormerDataset(datasets["B"][0], m2f_preprocessor_B),
    "val": Mask2FormerDataset(datasets["B"][1], m2f_preprocessor_B),
    "test": Mask2FormerDataset(datasets["B"][2], m2f_preprocessor_B)}}

# DataLoader parameters
N_WORKERS = 4
BATCH_SIZE = 8
SHUFFLE = True
DROP_LAST = True

dataloader_params = {
    "batch_size": BATCH_SIZE,
    "shuffle": SHUFFLE,
    "num_workers": N_WORKERS,
    "drop_last": DROP_LAST,
    "pin_memory": True,
    "collate_fn": m2f_dataset_collate
}

# Create DataLoaders
dataloaders = {key: create_dataloaders(m2f_datasets[key], **dataloader_params) for key in m2f_datasets}

print(dataloaders)



{'A': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7f3c39457950>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7f3c394575f0>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x7f3c394579e0>}, 'B': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7f3c39457cb0>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7f3c39457d40>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x7f3c39454740>}}


In [7]:
# Check if CUDA is available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

Using device: cuda


In [8]:
# Training
NUM_EPOCHS = 5
LEARNING_RATE = 1e-4
LR_MULTIPLIER = 0.1
BACKBONE_LR = LEARNING_RATE * LR_MULTIPLIER
WEIGHT_DECAY = 0.5
# dice = Dice(average='micro')

# lambda_CE=5.0
# lambda_dice=5.0
metric = evaluate.load("mean_iou")
encoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.encoder")
]
decoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.decoder")
]
transformer_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.transformer_module")
]
optimizer = optim.AdamW(
    [
        {"params": encoder_params, "lr": BACKBONE_LR},
        {"params": decoder_params},
        {"params": transformer_params},
    ],
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

scheduler = optim.lr_scheduler.PolynomialLR(
    optimizer, total_iters=NUM_EPOCHS, power=0.9
)

In [2]:
# WandB for team usage !!!!

wandb.login() # use this one if a different person is going to run the notebook
#wandb.login(relogin=False) # if the same person in the last run is going to run the notebook again


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mge85ket[0m ([33mcontinual-learning-tum[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [11]:
wandb.init(
    project="M2F_original",
    config={
        "learning_rate": LEARNING_RATE,
        "learning_rate_multiplier": LR_MULTIPLIER,
        "backbone_learning_rate": BACKBONE_LR,
        "learning_rate_scheduler": scheduler.__class__.__name__,
        "optimizer": optimizer.__class__.__name__,
        "backbone": SWIN_BACKBONE,
        "m2f_preprocessor": m2f_preprocessor.__dict__,
        "m2f_model_config": model.config
    },
    name="M2F-Swin-Large-Naive-Forgetting",
    notes="M2F with large Swin backbone pretrained on ImageNet-1K. \
        Scenario: Train on A, Train on B, Test forgetting on A"
)

[34m[1mwandb[0m: Currently logged in as: [33mge85ket[0m ([33mcontinual-learning-tum[0m). Use [1m`wandb login --relogin`[0m to force relogin


0,1
Loss/train_A,▁

0,1
Loss/train_A,2


In [9]:
# Tensorboard setup
out_dir="outputs/"
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
if not os.path.exists(out_dir+"runs"):
    os.makedirs(out_dir+"runs")
%load_ext tensorboard
%tensorboard --logdir outputs/runs

In [14]:
# Tensorboard logging
writer = SummaryWriter(log_dir=out_dir + "runs")

# Model checkpointing
model_name = "m2f_swin_backbone_naive_forgetting"
model_dir = out_dir + "models/"
if not os.path.exists(model_dir):
    print("Store weights in: ", model_dir)
    os.makedirs(model_dir)

best_model_dir = model_dir + f"{model_name}/best_models/"
if not os.path.exists(model_dir):
    print("Store best model weights in: ", best_model_dir)
    os.makedirs(best_model_dir)
final_model_dir = model_dir + f"{model_name}/final_model/"
if not os.path.exists(model_dir):
    print("Store final model weights in: ", final_model_dir)
    os.makedirs(final_model_dir)

In [15]:
# Save the preprocessor
m2f_preprocessor.save_pretrained(model_dir + model_name)

['outputs/models/m2f_swin_backbone_naive_forgetting/preprocessor_config.json']

In [12]:
#!rm -r outputs

In [13]:
#!CUDA_LAUNCH_BLOCKING=1

# First train on dataset A

In [16]:
# To avoid making stupid errors
CURR_TASK = "A"

# For storing the model
best_val_metric = -np.inf

# Move model to device
model.to(device)

for epoch in range(NUM_EPOCHS):
    model.train()
    train_running_loss = 0.0
    val_running_loss = 0.0

    # Set up tqdm for the training loop
    train_loader = tqdm(
        dataloaders[CURR_TASK]["train"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Training"
    )

    for batch in train_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]

        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss

        # Compute gradient and perform step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        train_running_loss += current_loss
        train_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor
        )
        metric.add_batch(references=masks, predictions=pred_maps)

    # After compute the batches that were added are deleted
    mean_train_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE, reduce_labels=False
    )["mean_iou"]

    # Validation phase
    model.eval()
    val_loader = tqdm(
        dataloaders[CURR_TASK]["val"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Validation"
    )
    with torch.no_grad():
        for batch in val_loader:
            # Move everything to the device
            batch["pixel_values"] = batch["pixel_values"].to(device)
            batch["pixel_mask"] = batch["pixel_mask"].to(device)
            batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
            batch["class_labels"] = [
                entry.to(device) for entry in batch["class_labels"]
            ]
            # Compute output and loss
            outputs = model(**batch)

            loss = outputs.loss
            # Record losses
            current_loss = loss.item() * batch["pixel_values"].size(0)
            val_running_loss += current_loss
            val_loader.set_postfix(loss=f"{current_loss:.4f}")

            # Extract and compute metrics
            pred_maps, masks = m2f_extract_pred_maps_and_masks(
                batch, outputs, m2f_preprocessor
            )
            metric.add_batch(references=masks, predictions=pred_maps)

    # After compute the batches that were added are deleted
    mean_val_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE, reduce_labels=False
    )["mean_iou"]

    epoch_train_loss = train_running_loss / len(dataloaders[CURR_TASK]["train"].dataset)
    epoch_val_loss = val_running_loss / len(dataloaders[CURR_TASK]["val"].dataset)

    writer.add_scalar(f"Loss/train_{model_name}_{CURR_TASK}", epoch_train_loss, epoch + 1)
    writer.add_scalar(f"Loss/val_{model_name}_{CURR_TASK}", epoch_val_loss, epoch + 1)
    writer.add_scalar(f"mIoU/train_{model_name}_{CURR_TASK}", mean_train_iou, epoch + 1)
    writer.add_scalar(f"mIoU/val_{model_name}_{CURR_TASK}", mean_val_iou, epoch + 1)

    wandb.log({
        f"Loss/train_{CURR_TASK}": epoch_train_loss,
        f"Loss/val_{CURR_TASK}": epoch_val_loss,
        f"mIoU/train_{CURR_TASK}": mean_train_iou,
        f"mIoU/val_{CURR_TASK}": mean_val_iou
    })

    if mean_val_iou > best_val_metric:
        best_val_metric = mean_val_iou
        model.save_pretrained(f"{best_model_dir}{CURR_TASK}/")

    tqdm.write(
        f"Epoch {epoch + 1}/{NUM_EPOCHS}, Train Loss: {epoch_train_loss:.4f}, Train mIoU: {mean_train_iou:.4f}, Validation Loss: {epoch_val_loss:.4f}, Validation mIoU: {mean_val_iou:.4f}"
    )

Epoch 1/5 Training: 100%|██████████| 443/443 [20:36<00:00,  2.79s/it, loss=157.1899] 
  acc = total_area_intersect / total_area_label
Epoch 1/5 Validation: 100%|██████████| 66/66 [02:16<00:00,  2.07s/it, loss=235.9049]


Epoch 1/5, Train Loss: 44.8594, Train mIoU: 0.2049, Validation Loss: 19.5540, Validation mIoU: 0.4475


Epoch 2/5 Training: 100%|██████████| 443/443 [25:48<00:00,  3.50s/it, loss=108.9452]
Epoch 2/5 Validation: 100%|██████████| 66/66 [01:57<00:00,  1.78s/it, loss=84.8620] 


Epoch 2/5, Train Loss: 15.8552, Train mIoU: 0.5744, Validation Loss: 15.8540, Validation mIoU: 0.6726


Epoch 3/5 Training: 100%|██████████| 443/443 [23:24<00:00,  3.17s/it, loss=108.8004]
Epoch 3/5 Validation: 100%|██████████| 66/66 [02:39<00:00,  2.42s/it, loss=215.1912]


Epoch 3/5, Train Loss: 12.8508, Train mIoU: 0.7110, Validation Loss: 13.5776, Validation mIoU: 0.6996


Epoch 4/5 Training: 100%|██████████| 443/443 [24:06<00:00,  3.27s/it, loss=91.0029] 
Epoch 4/5 Validation: 100%|██████████| 66/66 [02:42<00:00,  2.46s/it, loss=116.2157]


Epoch 4/5, Train Loss: 11.5175, Train mIoU: 0.7435, Validation Loss: 13.4987, Validation mIoU: 0.6997


Epoch 5/5 Training: 100%|██████████| 443/443 [20:13<00:00,  2.74s/it, loss=82.6482] 
Epoch 5/5 Validation: 100%|██████████| 66/66 [02:38<00:00,  2.40s/it, loss=94.3391] 


Epoch 5/5, Train Loss: 10.5654, Train mIoU: 0.7725, Validation Loss: 12.9103, Validation mIoU: 0.7094


# Test results on A

In [19]:
# Load best model and evaluate on test
model = Mask2FormerForUniversalSegmentation.from_pretrained(f"{best_model_dir}{CURR_TASK}/").to(device)

In [20]:
model.eval()
test_running_loss = 0
test_loader = tqdm(dataloaders[CURR_TASK]["test"], desc="Test loop")
with torch.no_grad():
    for batch in test_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss
        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        test_running_loss += current_loss
        test_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor
        )
        metric.add_batch(references=masks, predictions=pred_maps)

# After compute the batches that were added are deleted
mean_test_iou = metric.compute(
    num_labels=NUM_CLASSES, ignore_index=BG_VALUE, reduce_labels=False
)["mean_iou"]
final_test_loss = test_running_loss / len(dataloaders[CURR_TASK]["test"].dataset)
wandb.log({
    f"Loss/test_{CURR_TASK}": final_test_loss,
    f"mIoU/test_{CURR_TASK}": mean_test_iou
})
print(f"Test Loss: {final_test_loss:.4f}, Test mIoU: {mean_test_iou:.4f}")

Test loop: 100%|██████████| 73/73 [02:23<00:00,  1.97s/it, loss=158.3683]


Test Loss: 15.9287, Test mIoU: 0.7058


# Now train on B and forget A

In [21]:
# Training
NUM_EPOCHS = 5
LEARNING_RATE = 1e-4
LR_MULTIPLIER = 0.1
BACKBONE_LR = LEARNING_RATE * LR_MULTIPLIER
WEIGHT_DECAY = 0.5
# dice = Dice(average='micro')

# lambda_CE=5.0
# lambda_dice=5.0
encoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.encoder")
]
decoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.decoder")
]
transformer_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.transformer_module")
]
optimizer = optim.AdamW(
    [
        {"params": encoder_params, "lr": BACKBONE_LR},
        {"params": decoder_params},
        {"params": transformer_params},
    ],
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

scheduler = optim.lr_scheduler.PolynomialLR(
    optimizer, total_iters=NUM_EPOCHS, power=0.9
)

In [22]:
# To avoid making stupid errors
CURR_TASK = "B"

# For storing the model
best_val_metric = -np.inf

# Move model to device
model.to(device)

for epoch in range(NUM_EPOCHS):
    model.train()
    train_running_loss = 0.0
    val_running_loss = 0.0

    # Set up tqdm for the training loop
    train_loader = tqdm(
        dataloaders[CURR_TASK]["train"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Training"
    )

    for batch in train_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]

        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss

        # Compute gradient and perform step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        train_running_loss += current_loss
        train_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor
        )
        metric.add_batch(references=masks, predictions=pred_maps)

    # After compute the batches that were added are deleted
    mean_train_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE, reduce_labels=False
    )["mean_iou"]

    # Validation phase
    model.eval()
    val_loader = tqdm(
        dataloaders[CURR_TASK]["val"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Validation"
    )
    with torch.no_grad():
        for batch in val_loader:
            # Move everything to the device
            batch["pixel_values"] = batch["pixel_values"].to(device)
            batch["pixel_mask"] = batch["pixel_mask"].to(device)
            batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
            batch["class_labels"] = [
                entry.to(device) for entry in batch["class_labels"]
            ]
            # Compute output and loss
            outputs = model(**batch)

            loss = outputs.loss
            # Record losses
            current_loss = loss.item() * batch["pixel_values"].size(0)
            val_running_loss += current_loss
            val_loader.set_postfix(loss=f"{current_loss:.4f}")

            # Extract and compute metrics
            pred_maps, masks = m2f_extract_pred_maps_and_masks(
                batch, outputs, m2f_preprocessor
            )
            metric.add_batch(references=masks, predictions=pred_maps)

    # After compute the batches that were added are deleted
    mean_val_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE, reduce_labels=False
    )["mean_iou"]

    epoch_train_loss = train_running_loss / len(dataloaders[CURR_TASK]["train"].dataset)
    epoch_val_loss = val_running_loss / len(dataloaders[CURR_TASK]["val"].dataset)

    writer.add_scalar(f"Loss/train_{model_name}_{CURR_TASK}", epoch_train_loss, epoch + 1)
    writer.add_scalar(f"Loss/val_{model_name}_{CURR_TASK}", epoch_val_loss, epoch + 1)
    writer.add_scalar(f"mIoU/train_{model_name}_{CURR_TASK}", mean_train_iou, epoch + 1)
    writer.add_scalar(f"mIoU/val_{model_name}_{CURR_TASK}", mean_val_iou, epoch + 1)

    wandb.log({
        f"Loss/train_{CURR_TASK}": epoch_train_loss,
        f"Loss/val_{CURR_TASK}": epoch_val_loss,
        f"mIoU/train_{CURR_TASK}": mean_train_iou,
        f"mIoU/val_{CURR_TASK}": mean_val_iou
    })

    if mean_val_iou > best_val_metric:
        best_val_metric = mean_val_iou
        model.save_pretrained(f"{best_model_dir}{CURR_TASK}/")

    tqdm.write(
        f"Epoch {epoch + 1}/{NUM_EPOCHS}, Train Loss: {epoch_train_loss:.4f}, Train mIoU: {mean_train_iou:.4f}, Validation Loss: {epoch_val_loss:.4f}, Validation mIoU: {mean_val_iou:.4f}"
    )

Epoch 1/5 Training: 100%|██████████| 225/225 [10:28<00:00,  2.79s/it, loss=140.3550]
Epoch 1/5 Validation: 100%|██████████| 28/28 [00:53<00:00,  1.92s/it, loss=147.4061]


Epoch 1/5, Train Loss: 21.2275, Train mIoU: 0.4917, Validation Loss: 17.3837, Validation mIoU: 0.6450


Epoch 2/5 Training: 100%|██████████| 225/225 [09:51<00:00,  2.63s/it, loss=88.4616] 
Epoch 2/5 Validation: 100%|██████████| 28/28 [00:51<00:00,  1.84s/it, loss=90.7679] 


Epoch 2/5, Train Loss: 16.5085, Train mIoU: 0.6278, Validation Loss: 11.5713, Validation mIoU: 0.7547


Epoch 3/5 Training: 100%|██████████| 225/225 [11:22<00:00,  3.03s/it, loss=71.2189] 
Epoch 3/5 Validation: 100%|██████████| 28/28 [00:49<00:00,  1.78s/it, loss=75.3303] 


Epoch 3/5, Train Loss: 11.1197, Train mIoU: 0.7375, Validation Loss: 9.1890, Validation mIoU: 0.7599


Epoch 4/5 Training: 100%|██████████| 225/225 [12:14<00:00,  3.26s/it, loss=74.6497] 
Epoch 4/5 Validation: 100%|██████████| 28/28 [00:49<00:00,  1.77s/it, loss=58.7313] 


Epoch 4/5, Train Loss: 9.1472, Train mIoU: 0.7742, Validation Loss: 8.5610, Validation mIoU: 0.7252


Epoch 5/5 Training: 100%|██████████| 225/225 [11:06<00:00,  2.96s/it, loss=80.4140] 
Epoch 5/5 Validation: 100%|██████████| 28/28 [00:51<00:00,  1.85s/it, loss=83.1679] 


Epoch 5/5, Train Loss: 8.9787, Train mIoU: 0.7828, Validation Loss: 10.4601, Validation mIoU: 0.7588


# Test results on B first

In [23]:
# Load best model and evaluate on test
model = Mask2FormerForUniversalSegmentation.from_pretrained(f"{best_model_dir}{CURR_TASK}/").to(device)

In [24]:
model.eval()
test_running_loss = 0
test_loader = tqdm(dataloaders[CURR_TASK]["test"], desc="Test loop")
with torch.no_grad():
    for batch in test_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss
        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        test_running_loss += current_loss
        test_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor
        )
        metric.add_batch(references=masks, predictions=pred_maps)

# After compute the batches that were added are deleted
mean_test_iou = metric.compute(
    num_labels=NUM_CLASSES, ignore_index=BG_VALUE, reduce_labels=False
)["mean_iou"]
final_test_loss = test_running_loss / len(dataloaders[CURR_TASK]["test"].dataset)
wandb.log({
    f"Loss/test_{CURR_TASK}": final_test_loss,
    f"mIoU/test_{CURR_TASK}": mean_test_iou
})
print(f"Test Loss: {final_test_loss:.4f}, Test mIoU: {mean_test_iou:.4f}")

Test loop: 100%|██████████| 28/28 [00:48<00:00,  1.73s/it, loss=78.1765] 


Test Loss: 10.0053, Test mIoU: 0.7529


# Test results on A after training on B

In [25]:
# To avoid making stupid errors
CURR_TASK = "A"

model.eval()
test_running_loss = 0
test_loader = tqdm(dataloaders[CURR_TASK]["test"], desc="Test loop")
with torch.no_grad():
    for batch in test_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss
        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        test_running_loss += current_loss
        test_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor
        )
        metric.add_batch(references=masks, predictions=pred_maps)

# After compute the batches that were added are deleted
mean_test_iou = metric.compute(
    num_labels=NUM_CLASSES, ignore_index=BG_VALUE, reduce_labels=False
)["mean_iou"]
final_test_loss = test_running_loss / len(dataloaders[CURR_TASK]["test"].dataset)
wandb.log({
    f"Loss/test_naive_forgetting_{CURR_TASK}": final_test_loss,
    f"mIoU/test_naive_forgetting_{CURR_TASK}": mean_test_iou
})
print(f"Test Loss: {final_test_loss:.4f}, Test mIoU: {mean_test_iou:.4f}")
wandb.finish()

Test loop: 100%|██████████| 73/73 [02:05<00:00,  1.73s/it, loss=279.5755]


Test Loss: 38.9396, Test mIoU: 0.4669
