In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import os
from utils.common import (
    m2f_dataset_collate,
    m2f_extract_pred_maps_and_masks,
    BG_VALUE_255,
    set_seed,
    pixel_mean_std,
    CADIS_PIXEL_MEAN,
    CADIS_PIXEL_STD,
    FULL_MERGE_PIXEL_MEAN,
    FULL_MERGE_PIXEL_STD,
)
from utils.dataset_utils import (
    get_cadisv2_dataset,
    get_cataract1k_dataset,
    ZEISS_CATEGORIES,
)
from utils.medical_datasets import Mask2FormerDataset
from transformers import (
    Mask2FormerForUniversalSegmentation,
    SwinModel,
    SwinConfig,
    Mask2FormerConfig,
    AutoImageProcessor,
    Mask2FormerImageProcessor
)
from torch.utils.data import DataLoader
import evaluate
import torch.optim as optim
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from dotenv import load_dotenv
import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
set_seed(42) # seed everything

In [4]:
NUM_CLASSES = len(ZEISS_CATEGORIES) - 3  # Remove class incremental
SWIN_BACKBONE = "microsoft/swin-tiny-patch4-window7-224"#"microsoft/swin-large-patch4-window12-384"

# Download pretrained swin model
swin_model = SwinModel.from_pretrained(
    SWIN_BACKBONE, out_features=["stage1", "stage2", "stage3", "stage4"]
)
swin_config = SwinConfig.from_pretrained(
    SWIN_BACKBONE, out_features=["stage1", "stage2", "stage3", "stage4"]
)

# Create Mask2Former configuration based on Swin's configuration
mask2former_config = Mask2FormerConfig(
    backbone_config=swin_config, num_labels=NUM_CLASSES #, ignore_value=BG_VALUE
)

# Create the Mask2Former model with this configuration
model = Mask2FormerForUniversalSegmentation(mask2former_config)

# Reuse pretrained parameters
for swin_param, m2f_param in zip(
    swin_model.named_parameters(),
    model.model.pixel_level_module.encoder.named_parameters(),
):
    m2f_param_name = f"model.pixel_level_module.encoder.{m2f_param[0]}"

    if swin_param[0] == m2f_param[0]:
        model.state_dict()[m2f_param_name].copy_(swin_param[1])
        continue

    print(f"Not Matched: {m2f_param[0]} != {swin_param[0]}")



Not Matched: hidden_states_norms.stage1.weight != layernorm.weight
Not Matched: hidden_states_norms.stage1.bias != layernorm.bias


In [5]:
# Helper function to load datasets
def load_dataset(dataset_getter, data_path, domain_incremental):
    return dataset_getter(data_path, domain_incremental=domain_incremental)


# Helper function to create dataloaders for a dataset
def create_dataloaders(
    dataset, batch_size, shuffle, num_workers, drop_last, pin_memory, collate_fn
):
    return {
        "train": DataLoader(
            dataset["train"],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        ),
        "val": DataLoader(
            dataset["val"],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        ),
        "test": DataLoader(
            dataset["test"],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        ),
    }


# Load datasets
datasets = {
    "A": load_dataset(get_cadisv2_dataset, "../../storage/data/CaDISv2", True),
    "B": load_dataset(get_cataract1k_dataset, "../../storage/data/cataract-1k", True),
}
# pixel_mean_A,pixel_std_A=pixel_mean_std(datasets["A"][0])
pixel_mean_A = CADIS_PIXEL_MEAN
pixel_std_A = CADIS_PIXEL_STD

# This time define the B train dataset such that it replays all the training samples from A
new_train = torch.utils.data.ConcatDataset([datasets["A"][0], datasets["B"][0]])

# pixel_mean_B,pixel_std_B=pixel_mean_std(new_train)
pixel_mean_B = FULL_MERGE_PIXEL_MEAN
pixel_std_B = FULL_MERGE_PIXEL_STD

datasets["B"] = (new_train, datasets["B"][1], datasets["B"][2])

# Define preprocessor
swin_processor = AutoImageProcessor.from_pretrained(SWIN_BACKBONE)
m2f_preprocessor_A = Mask2FormerImageProcessor(
    reduce_labels=True,
    ignore_index=255,
    do_resize=False,
    do_rescale=True,
    do_normalize=True,
    image_std=pixel_std_A,
    image_mean=pixel_mean_A,
)

m2f_preprocessor_B = Mask2FormerImageProcessor(
    reduce_labels=True,
    ignore_index=255,
    do_resize=False,
    do_rescale=True,
    do_normalize=True,
    image_std=pixel_std_B,
    image_mean=pixel_mean_B,
)

# Create Mask2Former Datasets

m2f_datasets = {
    "A": {
        "train": Mask2FormerDataset(datasets["A"][0], m2f_preprocessor_A),
        "val": Mask2FormerDataset(datasets["A"][1], m2f_preprocessor_A),
        "test": Mask2FormerDataset(datasets["A"][2], m2f_preprocessor_A),
    },
    "B": {
        "train": Mask2FormerDataset(datasets["B"][0], m2f_preprocessor_B),
        "val": Mask2FormerDataset(datasets["B"][1], m2f_preprocessor_B),
        "test": Mask2FormerDataset(datasets["B"][2], m2f_preprocessor_B),
    },
}


# DataLoader parameters
N_WORKERS = 4
BATCH_SIZE = 16
SHUFFLE = True
DROP_LAST = True

dataloader_params = {
    "batch_size": BATCH_SIZE,
    "shuffle": SHUFFLE,
    "num_workers": N_WORKERS,
    "drop_last": DROP_LAST,
    "pin_memory": True,
    "collate_fn": m2f_dataset_collate,
}

# Create DataLoaders
dataloaders = {
    key: create_dataloaders(m2f_datasets[key], **dataloader_params)
    for key in m2f_datasets
}

print(dataloaders)



{'A': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7f33710151f0>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7f3371014470>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x7f33710140b0>}, 'B': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7f3371015c40>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7f33710155b0>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x7f3371016540>}}


In [6]:
torch.unique(datasets["A"][0][0][1])

tensor([ 0,  4, 10, 11], dtype=torch.int32)

In [7]:
m2f_datasets["A"]["train"][0]["class_labels"]

[tensor([ 3,  9, 10])]

In [23]:
# Check if CUDA is available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

Using device: cuda


In [14]:
# Tensorboard setup
out_dir="outputs/"
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
if not os.path.exists(out_dir+"runs"):
    os.makedirs(out_dir+"runs")
%load_ext tensorboard
%tensorboard --logdir outputs/runs

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 1235), started 2:51:18 ago. (Use '!kill 1235' to kill it.)

In [8]:
!CUDA_LAUNCH_BLOCKING=1

# First train on dataset A

In [11]:
# Training
NUM_EPOCHS = 200
LEARNING_RATE = 1e-4
LR_MULTIPLIER = 0.1
BACKBONE_LR = LEARNING_RATE * LR_MULTIPLIER
WEIGHT_DECAY = 0.05
PATIENCE=15
metric = evaluate.load("mean_iou") # mIoU will be used to pick the best performing model using val set
encoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.encoder")
]
decoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.decoder")
]
transformer_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.transformer_module")
]
optimizer = optim.AdamW(
    [
        {"params": encoder_params, "lr": BACKBONE_LR},
        {"params": decoder_params},
        {"params": transformer_params},
    ],
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

scheduler = optim.lr_scheduler.PolynomialLR(
    optimizer, total_iters=NUM_EPOCHS, power=0.9
)

Downloading builder script: 100%|██████████| 12.9k/12.9k [00:00<00:00, 17.2MB/s]


In [12]:
# WandB for team usage !!!!

wandb.login() # use this one if a different person is going to run the notebook
#wandb.login(relogin=False) # if the same person in the last run is going to run the notebook again

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [13]:
wandb.init(
    project="M2F_original",
    config={
        "learning_rate": LEARNING_RATE,
        "learning_rate_multiplier": LR_MULTIPLIER,
        "backbone_learning_rate": BACKBONE_LR,
        "learning_rate_scheduler": scheduler.__class__.__name__,
        "optimizer": optimizer.__class__.__name__,
        "backbone": SWIN_BACKBONE,
        "m2f_preprocessor": m2f_preprocessor_A.__dict__,
        "m2f_model_config": model.config
    },
    name="M2F-Swin-Tiny-Train_Cadis",
    notes="M2F with tiny Swin backbone pretrained on ImageNet-1K. \
        Scenario: Train on A, Test on A"
)

[34m[1mwandb[0m: Currently logged in as: [33mge85ket[0m ([33mcontinual-learning-tum[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
# Tensorboard logging
writer = SummaryWriter(log_dir=out_dir + "runs")

# Model checkpointing
base_model_name="m2f_swin_backbone_train_cadis"
model_dir = out_dir + "models/"
if not os.path.exists(model_dir):
    print("Store weights in: ", model_dir)
    os.makedirs(model_dir)

best_model_dir = model_dir + f"{base_model_name}/best_model/"
if not os.path.exists(best_model_dir):
    print("Store best model weights in: ", best_model_dir)
    os.makedirs(best_model_dir)
final_model_dir = model_dir + f"{base_model_name}/final_model/"
if not os.path.exists(final_model_dir):
    print("Store final model weights in: ", final_model_dir)
    os.makedirs(final_model_dir)

In [15]:
# Save the preprocessor
m2f_preprocessor_A.save_pretrained(model_dir + base_model_name)

['outputs/models/m2f_swin_backbone_train_cadis/preprocessor_config.json']

In [16]:
# To avoid making stupid errors
CURR_TASK = "A"

# For storing the model
best_val_metric = -np.inf

# Move model to device
model.to(device)
counter=0
for epoch in range(NUM_EPOCHS):
    model.train()
    train_running_loss = 0.0
    val_running_loss = 0.0

    # Set up tqdm for the training loop
    train_loader = tqdm(
        dataloaders[CURR_TASK]["train"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Training"
    )

    for batch in train_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
       
        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss
        # Compute gradient and perform step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        train_running_loss += current_loss
        train_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_A
        )
        metric.add_batch(references=masks, predictions=pred_maps)

    # After compute the batches that were added are deleted
    temp_metrics = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
    )
    mean_train_iou=temp_metrics["mean_iou"]
        
    # Validation phase
    model.eval()
    val_loader = tqdm(
        dataloaders[CURR_TASK]["val"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Validation"
    )
    with torch.no_grad():
        for batch in val_loader:
            # Move everything to the device
            batch["pixel_values"] = batch["pixel_values"].to(device)
            batch["pixel_mask"] = batch["pixel_mask"].to(device)
            batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
            batch["class_labels"] = [
                entry.to(device) for entry in batch["class_labels"]
            ]
            # Compute output and loss
            outputs = model(**batch)

            loss = outputs.loss
            # Record losses
            current_loss = loss.item() * batch["pixel_values"].size(0)
            val_running_loss += current_loss
            val_loader.set_postfix(loss=f"{current_loss:.4f}")

            # Extract and compute metrics
            pred_maps, masks = m2f_extract_pred_maps_and_masks(
                batch, outputs, m2f_preprocessor_A
            )
            metric.add_batch(references=masks, predictions=pred_maps)
            

    # After compute the batches that were added are deleted
    mean_val_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
    )["mean_iou"]

    epoch_train_loss = train_running_loss / len(dataloaders[CURR_TASK]["train"].dataset)
    epoch_val_loss = val_running_loss / len(dataloaders[CURR_TASK]["val"].dataset)

    writer.add_scalar(f"Loss/train_{base_model_name}_{CURR_TASK}", epoch_train_loss, epoch + 1)
    writer.add_scalar(f"Loss/val_{base_model_name}_{CURR_TASK}", epoch_val_loss, epoch + 1)
    writer.add_scalar(f"mIoU/train_{base_model_name}_{CURR_TASK}", mean_train_iou, epoch + 1)
    writer.add_scalar(f"mIoU/val_{base_model_name}_{CURR_TASK}", mean_val_iou, epoch + 1)

    wandb.log({
        f"Loss/train_{CURR_TASK}": epoch_train_loss,
        f"Loss/val_{CURR_TASK}": epoch_val_loss,
        f"mIoU/train_{CURR_TASK}": mean_train_iou,
        f"mIoU/val_{CURR_TASK}": mean_val_iou
    })


    tqdm.write(
        f"Epoch {epoch + 1}/{NUM_EPOCHS}, Train Loss: {epoch_train_loss:.4f}, Train mIoU: {mean_train_iou:.4f}, Validation Loss: {epoch_val_loss:.4f}, Validation mIoU: {mean_val_iou:.4f}"
    )
    
    if mean_val_iou > best_val_metric:
        best_val_metric = mean_val_iou
        model.save_pretrained(f"{best_model_dir}{CURR_TASK}/")
        counter=0
    else:
        counter+=1
        if counter == PATIENCE:
            print("Early stopping at epoch",epoch)
            break

Epoch 1/200 Training: 100%|██████████| 221/221 [07:11<00:00,  1.95s/it, loss=401.1506] 
Epoch 1/200 Validation: 100%|██████████| 33/33 [00:44<00:00,  1.34s/it, loss=483.8031]


Epoch 1/200, Train Loss: 59.9673, Train mIoU: 0.1127, Validation Loss: 28.1699, Validation mIoU: 0.1739


Epoch 2/200 Training: 100%|██████████| 221/221 [07:05<00:00,  1.92s/it, loss=299.5568]
Epoch 2/200 Validation: 100%|██████████| 33/33 [00:40<00:00,  1.21s/it, loss=255.4191]


Epoch 2/200, Train Loss: 21.9250, Train mIoU: 0.3603, Validation Loss: 18.8730, Validation mIoU: 0.5495


Epoch 3/200 Training: 100%|██████████| 221/221 [07:07<00:00,  1.93s/it, loss=223.5976]
Epoch 3/200 Validation: 100%|██████████| 33/33 [00:38<00:00,  1.17s/it, loss=507.0912]


Epoch 3/200, Train Loss: 16.0550, Train mIoU: 0.6059, Validation Loss: 16.0244, Validation mIoU: 0.6703


Epoch 4/200 Training: 100%|██████████| 221/221 [07:06<00:00,  1.93s/it, loss=225.2377]
Epoch 4/200 Validation: 100%|██████████| 33/33 [00:42<00:00,  1.28s/it, loss=211.8082]


Epoch 4/200, Train Loss: 13.6189, Train mIoU: 0.6810, Validation Loss: 14.8552, Validation mIoU: 0.6770


Epoch 5/200 Training: 100%|██████████| 221/221 [07:09<00:00,  1.94s/it, loss=185.5254]
Epoch 5/200 Validation: 100%|██████████| 33/33 [00:41<00:00,  1.26s/it, loss=200.0621]


Epoch 5/200, Train Loss: 12.2627, Train mIoU: 0.7518, Validation Loss: 14.0428, Validation mIoU: 0.7407


Epoch 6/200 Training: 100%|██████████| 221/221 [07:07<00:00,  1.93s/it, loss=174.2913]
Epoch 6/200 Validation: 100%|██████████| 33/33 [00:39<00:00,  1.20s/it, loss=241.7548]


Epoch 6/200, Train Loss: 11.0515, Train mIoU: 0.8090, Validation Loss: 13.5142, Validation mIoU: 0.7556


Epoch 7/200 Training: 100%|██████████| 221/221 [07:07<00:00,  1.93s/it, loss=172.0493]
Epoch 7/200 Validation: 100%|██████████| 33/33 [00:39<00:00,  1.19s/it, loss=180.1881]


Epoch 7/200, Train Loss: 10.2347, Train mIoU: 0.8442, Validation Loss: 13.0086, Validation mIoU: 0.7746


Epoch 8/200 Training: 100%|██████████| 221/221 [07:09<00:00,  1.94s/it, loss=178.9977]
Epoch 8/200 Validation: 100%|██████████| 33/33 [00:38<00:00,  1.16s/it, loss=202.4963]


Epoch 8/200, Train Loss: 9.7156, Train mIoU: 0.8640, Validation Loss: 12.8888, Validation mIoU: 0.8009


Epoch 9/200 Training: 100%|██████████| 221/221 [07:03<00:00,  1.92s/it, loss=150.3341]
Epoch 9/200 Validation: 100%|██████████| 33/33 [00:37<00:00,  1.15s/it, loss=182.8118]


Epoch 9/200, Train Loss: 9.3311, Train mIoU: 0.8710, Validation Loss: 13.0399, Validation mIoU: 0.7448


Epoch 10/200 Training: 100%|██████████| 221/221 [07:06<00:00,  1.93s/it, loss=144.6579]
Epoch 10/200 Validation: 100%|██████████| 33/33 [00:42<00:00,  1.28s/it, loss=157.7974]


Epoch 10/200, Train Loss: 8.9172, Train mIoU: 0.8876, Validation Loss: 12.7428, Validation mIoU: 0.8042


Epoch 11/200 Training: 100%|██████████| 221/221 [07:06<00:00,  1.93s/it, loss=126.5559]
Epoch 11/200 Validation: 100%|██████████| 33/33 [00:39<00:00,  1.20s/it, loss=141.9452]


Epoch 11/200, Train Loss: 8.5717, Train mIoU: 0.8973, Validation Loss: 12.8770, Validation mIoU: 0.8107


Epoch 12/200 Training: 100%|██████████| 221/221 [07:12<00:00,  1.96s/it, loss=125.8806]
Epoch 12/200 Validation: 100%|██████████| 33/33 [00:39<00:00,  1.21s/it, loss=343.2729]


Epoch 12/200, Train Loss: 8.3852, Train mIoU: 0.9048, Validation Loss: 12.3538, Validation mIoU: 0.7602


Epoch 13/200 Training: 100%|██████████| 221/221 [07:12<00:00,  1.96s/it, loss=113.7843]
Epoch 13/200 Validation: 100%|██████████| 33/33 [00:40<00:00,  1.21s/it, loss=140.8690]


Epoch 13/200, Train Loss: 8.0266, Train mIoU: 0.9076, Validation Loss: 12.3803, Validation mIoU: 0.8098


Epoch 14/200 Training: 100%|██████████| 221/221 [07:07<00:00,  1.93s/it, loss=106.1232]
Epoch 14/200 Validation: 100%|██████████| 33/33 [00:42<00:00,  1.29s/it, loss=141.3668]


Epoch 14/200, Train Loss: 7.8349, Train mIoU: 0.9120, Validation Loss: 12.5519, Validation mIoU: 0.7572


Epoch 15/200 Training: 100%|██████████| 221/221 [07:04<00:00,  1.92s/it, loss=127.3804]
Epoch 15/200 Validation: 100%|██████████| 33/33 [00:40<00:00,  1.23s/it, loss=162.7226]


Epoch 15/200, Train Loss: 7.5868, Train mIoU: 0.9169, Validation Loss: 12.4464, Validation mIoU: 0.7668


Epoch 16/200 Training: 100%|██████████| 221/221 [07:03<00:00,  1.92s/it, loss=109.6527]
Epoch 16/200 Validation: 100%|██████████| 33/33 [00:40<00:00,  1.23s/it, loss=168.0958]


Epoch 16/200, Train Loss: 7.4564, Train mIoU: 0.9121, Validation Loss: 12.5355, Validation mIoU: 0.8312


Epoch 17/200 Training: 100%|██████████| 221/221 [07:13<00:00,  1.96s/it, loss=122.7549]
Epoch 17/200 Validation: 100%|██████████| 33/33 [00:40<00:00,  1.23s/it, loss=217.2358]


Epoch 17/200, Train Loss: 7.3812, Train mIoU: 0.9274, Validation Loss: 12.6641, Validation mIoU: 0.7958


Epoch 18/200 Training: 100%|██████████| 221/221 [07:11<00:00,  1.95s/it, loss=113.9071]
Epoch 18/200 Validation: 100%|██████████| 33/33 [00:39<00:00,  1.19s/it, loss=212.6597]


Epoch 18/200, Train Loss: 7.2872, Train mIoU: 0.9248, Validation Loss: 12.2307, Validation mIoU: 0.7793


Epoch 19/200 Training: 100%|██████████| 221/221 [07:07<00:00,  1.93s/it, loss=113.8650]
Epoch 19/200 Validation: 100%|██████████| 33/33 [00:41<00:00,  1.26s/it, loss=213.1660]


Epoch 19/200, Train Loss: 6.9397, Train mIoU: 0.9283, Validation Loss: 13.0369, Validation mIoU: 0.7093


Epoch 20/200 Training: 100%|██████████| 221/221 [07:09<00:00,  1.94s/it, loss=107.2071]
Epoch 20/200 Validation: 100%|██████████| 33/33 [00:41<00:00,  1.25s/it, loss=157.6948]


Epoch 20/200, Train Loss: 6.8617, Train mIoU: 0.9328, Validation Loss: 12.7653, Validation mIoU: 0.7682


Epoch 21/200 Training: 100%|██████████| 221/221 [07:06<00:00,  1.93s/it, loss=112.2943]
Epoch 21/200 Validation: 100%|██████████| 33/33 [00:41<00:00,  1.25s/it, loss=177.6266]


Epoch 21/200, Train Loss: 6.6442, Train mIoU: 0.9404, Validation Loss: 12.9168, Validation mIoU: 0.7992


Epoch 22/200 Training: 100%|██████████| 221/221 [07:06<00:00,  1.93s/it, loss=114.5106]
Epoch 22/200 Validation: 100%|██████████| 33/33 [00:42<00:00,  1.29s/it, loss=170.0397]


Epoch 22/200, Train Loss: 6.4627, Train mIoU: 0.9407, Validation Loss: 12.7133, Validation mIoU: 0.7706


Epoch 23/200 Training: 100%|██████████| 221/221 [07:11<00:00,  1.95s/it, loss=104.8699]
Epoch 23/200 Validation: 100%|██████████| 33/33 [00:41<00:00,  1.26s/it, loss=187.9654]


Epoch 23/200, Train Loss: 6.6558, Train mIoU: 0.9238, Validation Loss: 13.4670, Validation mIoU: 0.7338


Epoch 24/200 Training: 100%|██████████| 221/221 [07:12<00:00,  1.96s/it, loss=112.5719]
Epoch 24/200 Validation: 100%|██████████| 33/33 [00:42<00:00,  1.30s/it, loss=200.1590]


Epoch 24/200, Train Loss: 6.3335, Train mIoU: 0.9302, Validation Loss: 12.6056, Validation mIoU: 0.8208


Epoch 25/200 Training: 100%|██████████| 221/221 [07:13<00:00,  1.96s/it, loss=95.5995] 
Epoch 25/200 Validation: 100%|██████████| 33/33 [00:42<00:00,  1.30s/it, loss=265.8563]


Epoch 25/200, Train Loss: 6.1377, Train mIoU: 0.9454, Validation Loss: 12.9725, Validation mIoU: 0.8183


Epoch 26/200 Training: 100%|██████████| 221/221 [07:10<00:00,  1.95s/it, loss=85.4450] 
Epoch 26/200 Validation: 100%|██████████| 33/33 [00:42<00:00,  1.28s/it, loss=274.1076]


Epoch 26/200, Train Loss: 6.0130, Train mIoU: 0.9459, Validation Loss: 13.2565, Validation mIoU: 0.7970


Epoch 27/200 Training: 100%|██████████| 221/221 [07:10<00:00,  1.95s/it, loss=92.0950] 
Epoch 27/200 Validation: 100%|██████████| 33/33 [00:43<00:00,  1.31s/it, loss=391.9355]


Epoch 27/200, Train Loss: 5.9057, Train mIoU: 0.9454, Validation Loss: 12.8361, Validation mIoU: 0.8126


Epoch 28/200 Training: 100%|██████████| 221/221 [07:14<00:00,  1.97s/it, loss=107.4168]
Epoch 28/200 Validation: 100%|██████████| 33/33 [00:42<00:00,  1.30s/it, loss=199.8778]


Epoch 28/200, Train Loss: 5.7176, Train mIoU: 0.9492, Validation Loss: 13.3272, Validation mIoU: 0.8082


Epoch 29/200 Training: 100%|██████████| 221/221 [07:14<00:00,  1.96s/it, loss=92.2878] 
Epoch 29/200 Validation: 100%|██████████| 33/33 [00:43<00:00,  1.32s/it, loss=159.5150]


Epoch 29/200, Train Loss: 5.7386, Train mIoU: 0.9483, Validation Loss: 12.7490, Validation mIoU: 0.7946


Epoch 30/200 Training: 100%|██████████| 221/221 [07:15<00:00,  1.97s/it, loss=93.8445] 
Epoch 30/200 Validation: 100%|██████████| 33/33 [00:42<00:00,  1.29s/it, loss=161.5087]


Epoch 30/200, Train Loss: 5.7040, Train mIoU: 0.9445, Validation Loss: 13.0323, Validation mIoU: 0.8008


Epoch 31/200 Training: 100%|██████████| 221/221 [07:09<00:00,  1.94s/it, loss=88.7041] 
Epoch 31/200 Validation: 100%|██████████| 33/33 [00:42<00:00,  1.28s/it, loss=168.6165]


Epoch 31/200, Train Loss: 5.5612, Train mIoU: 0.9444, Validation Loss: 13.1058, Validation mIoU: 0.8114
Early stopping at epoch 30


## Test results on A

In [24]:
# Load best model and evaluate on test
#model = Mask2FormerForUniversalSegmentation.from_pretrained(f"{best_model_dir}{CURR_TASK}/").to(device)

# Load pretrained on Cadis from naive forgetting 
model = Mask2FormerForUniversalSegmentation.from_pretrained(f"/notebooks/continual-learning/outputs/models/{base_model_name}/best_model/A").to(device)

In [26]:
model.eval()
test_running_loss = 0
test_loader = tqdm(dataloaders[CURR_TASK]["test"], desc="Test loop")
with torch.no_grad():
    for batch in test_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss
        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        test_running_loss += current_loss
        test_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_A
        )
        metric.add_batch(references=masks, predictions=pred_maps)
        
# After compute the batches that were added are deleted
test_metrics_A = metric.compute(
    num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
)
mean_test_iou = test_metrics_A["mean_iou"]
final_test_loss = test_running_loss / len(dataloaders[CURR_TASK]["test"].dataset)
wandb.log({
    f"Loss/test_{CURR_TASK}": final_test_loss,
    f"mIoU/test_{CURR_TASK}": mean_test_iou
})
print(f"Test Loss: {final_test_loss:.4f}, Test mIoU: {mean_test_iou:.4f}")
wandb.finish()

Test loop: 100%|██████████| 36/36 [00:58<00:00,  1.63s/it, loss=318.6559]


Test Loss: 14.8668, Test mIoU: 0.7716


In [19]:
#previous run
test_metrics_A

{'mean_iou': 0.7735610940366814,
 'mean_accuracy': 0.8364536793631839,
 'overall_accuracy': 0.9621300366527966,
 'per_category_iou': array([0.96072231, 0.93562797, 0.64369086, 0.39402199, 0.4091044 ,
        0.81320499, 0.81338487, 0.76786534, 0.89548438, 0.94261244,
        0.93345249]),
 'per_category_accuracy': array([0.97414412, 0.98914931, 0.81667993, 0.43158069, 0.42852069,
        0.91832696, 0.89176254, 0.86807917, 0.94488664, 0.96806242,
        0.969798  ])}

# Now train on B and forget A

In [20]:
# Training
NUM_EPOCHS = 200
LEARNING_RATE = 1e-4
LR_MULTIPLIER = 0.1
BACKBONE_LR = LEARNING_RATE * LR_MULTIPLIER
WEIGHT_DECAY = 0.05
PATIENCE=15
encoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.encoder")
]
decoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.decoder")
]
transformer_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.transformer_module")
]
optimizer = optim.AdamW(
    [
        {"params": encoder_params, "lr": BACKBONE_LR},
        {"params": decoder_params},
        {"params": transformer_params},
    ],
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

scheduler = optim.lr_scheduler.PolynomialLR(
    optimizer, total_iters=NUM_EPOCHS, power=0.9
)

In [21]:
# WandB for team usage !!!!

wandb.login() # use this one if a different person is going to run the notebook
#wandb.login(relogin=False) # if the same person in the last run is going to run the notebook again

True

In [28]:
wandb.init(
    project="M2F_original",
    config={
        "learning_rate": LEARNING_RATE,
        "learning_rate_multiplier": LR_MULTIPLIER,
        "backbone_learning_rate": BACKBONE_LR,
        "learning_rate_scheduler": scheduler.__class__.__name__,
        "optimizer": optimizer.__class__.__name__,
        "backbone": SWIN_BACKBONE,
        "m2f_preprocessor": m2f_preprocessor_B.__dict__,
        "m2f_model_config": model.config
    },
    name="M2F-Swin-Tiny-Replay-All",
    notes="M2F with tiny Swin backbone pretrained on ImageNet-1K. \
        Scenario: Pretrained on A, Train on A + B naive finetuning (replay all), Test forgetting on A"
)

# Tensorboard logging
writer = SummaryWriter(log_dir=out_dir + "runs")

# Model checkpointing
model_name = "m2f_swin_backbone_replay_all"
model_dir = out_dir + "models/"
if not os.path.exists(model_dir):
    print("Store weights in: ", model_dir)
    os.makedirs(model_dir)

best_model_dir = model_dir + f"{model_name}/best_model/"
if not os.path.exists(best_model_dir):
    print("Store best model weights in: ", best_model_dir)
    os.makedirs(best_model_dir)
final_model_dir = model_dir + f"{model_name}/final_model/"
if not os.path.exists(final_model_dir):
    print("Store final model weights in: ", final_model_dir)
    os.makedirs(final_model_dir)

In [23]:
# Save the preprocessor
m2f_preprocessor_B.save_pretrained(model_dir + model_name)

['outputs/models/m2f_swin_backbone_replay_all/preprocessor_config.json']

In [None]:
# To avoid making stupid errors
CURR_TASK = "B"
model_path_second = f"{best_model_dir}A+{CURR_TASK}/"


# For storing the model
best_val_metric = -np.inf

# Move model to device
model.to(device)
counter=0
for epoch in range(NUM_EPOCHS):
    model.train()
    train_running_loss = 0.0
    val_running_loss = 0.0

    # Set up tqdm for the training loop
    train_loader = tqdm(
        dataloaders[CURR_TASK]["train"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Training"
    )

    for batch in train_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]

        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss

        # Compute gradient and perform step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        train_running_loss += current_loss
        train_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_B
        )
        metric.add_batch(references=masks, predictions=pred_maps)
        

    # After compute the batches that were added are deleted
    mean_train_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
    )["mean_iou"]

    # Validation phase
    model.eval()
    val_loader = tqdm(
        dataloaders[CURR_TASK]["val"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Validation"
    )
    with torch.no_grad():
        for batch in val_loader:
            # Move everything to the device
            batch["pixel_values"] = batch["pixel_values"].to(device)
            batch["pixel_mask"] = batch["pixel_mask"].to(device)
            batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
            batch["class_labels"] = [
                entry.to(device) for entry in batch["class_labels"]
            ]
            # Compute output and loss
            outputs = model(**batch)

            loss = outputs.loss
            # Record losses
            current_loss = loss.item() * batch["pixel_values"].size(0)
            val_running_loss += current_loss
            val_loader.set_postfix(loss=f"{current_loss:.4f}")

            # Extract and compute metrics
            pred_maps, masks = m2f_extract_pred_maps_and_masks(
                batch, outputs, m2f_preprocessor_B
            )
            metric.add_batch(references=masks, predictions=pred_maps)
            

    # After compute the batches that were added are deleted
    mean_val_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
    )["mean_iou"]

    epoch_train_loss = train_running_loss / len(dataloaders[CURR_TASK]["train"].dataset)
    epoch_val_loss = val_running_loss / len(dataloaders[CURR_TASK]["val"].dataset)

    writer.add_scalar(f"Loss/train_{model_name}_{CURR_TASK}", epoch_train_loss, epoch + 1)
    writer.add_scalar(f"Loss/val_{model_name}_{CURR_TASK}", epoch_val_loss, epoch + 1)
    writer.add_scalar(f"mIoU/train_{model_name}_{CURR_TASK}", mean_train_iou, epoch + 1)
    writer.add_scalar(f"mIoU/val_{model_name}_{CURR_TASK}", mean_val_iou, epoch + 1)

    wandb.log({
        f"Loss/train_A+{CURR_TASK}": epoch_train_loss,
        f"Loss/val_A+{CURR_TASK}": epoch_val_loss,
        f"mIoU/train_A+{CURR_TASK}": mean_train_iou,
        f"mIoU/val_A+{CURR_TASK}": mean_val_iou
    })

    tqdm.write(
        f"Epoch {epoch + 1}/{NUM_EPOCHS}, Train Loss: {epoch_train_loss:.4f}, Train mIoU: {mean_train_iou:.4f}, Validation Loss: {epoch_val_loss:.4f}, Validation mIoU: {mean_val_iou:.4f}"
    )
    if mean_val_iou > best_val_metric:
        best_val_metric = mean_val_iou
        model.save_pretrained(model_path_second)
        counter=0
    else:
        counter+=1
        if counter == PATIENCE:
            print("Early stopping at epoch",epoch)
            break

Epoch 1/200 Training: 100%|██████████| 334/334 [10:53<00:00,  1.96s/it, loss=159.9668]
Epoch 1/200 Validation: 100%|██████████| 14/14 [00:22<00:00,  1.62s/it, loss=174.6148]


Epoch 1/200, Train Loss: 12.2269, Train mIoU: 0.7891, Validation Loss: 12.5346, Validation mIoU: 0.5552


Epoch 2/200 Training: 100%|██████████| 334/334 [11:06<00:00,  1.99s/it, loss=148.3775]
Epoch 2/200 Validation: 100%|██████████| 14/14 [00:19<00:00,  1.43s/it, loss=151.7886]


Epoch 2/200, Train Loss: 9.4902, Train mIoU: 0.8563, Validation Loss: 10.8903, Validation mIoU: 0.6870


Epoch 3/200 Training: 100%|██████████| 334/334 [10:49<00:00,  1.94s/it, loss=136.3703]
Epoch 3/200 Validation: 100%|██████████| 14/14 [00:20<00:00,  1.45s/it, loss=148.0401]


Epoch 3/200, Train Loss: 8.7525, Train mIoU: 0.8820, Validation Loss: 9.7021, Validation mIoU: 0.7947


Epoch 4/200 Training: 100%|██████████| 334/334 [10:55<00:00,  1.96s/it, loss=180.6022]
Epoch 4/200 Validation: 100%|██████████| 14/14 [00:18<00:00,  1.35s/it, loss=156.6559]


Epoch 4/200, Train Loss: 8.3104, Train mIoU: 0.8840, Validation Loss: 9.5877, Validation mIoU: 0.7631


Epoch 5/200 Training: 100%|██████████| 334/334 [10:57<00:00,  1.97s/it, loss=123.5997]
Epoch 5/200 Validation: 100%|██████████| 14/14 [00:22<00:00,  1.58s/it, loss=169.5353]


Epoch 5/200, Train Loss: 7.8368, Train mIoU: 0.9126, Validation Loss: 9.2520, Validation mIoU: 0.7334


Epoch 6/200 Training: 100%|██████████| 334/334 [10:52<00:00,  1.96s/it, loss=136.4776]
Epoch 6/200 Validation: 100%|██████████| 14/14 [00:20<00:00,  1.49s/it, loss=112.9756]


Epoch 6/200, Train Loss: 7.5730, Train mIoU: 0.9168, Validation Loss: 8.5893, Validation mIoU: 0.7919


Epoch 7/200 Training: 100%|██████████| 334/334 [10:42<00:00,  1.92s/it, loss=125.7431]
Epoch 7/200 Validation: 100%|██████████| 14/14 [00:21<00:00,  1.54s/it, loss=120.9915]


Epoch 7/200, Train Loss: 7.2968, Train mIoU: 0.9233, Validation Loss: 8.3832, Validation mIoU: 0.7757


Epoch 8/200 Training: 100%|██████████| 334/334 [10:55<00:00,  1.96s/it, loss=102.1801]
Epoch 8/200 Validation: 100%|██████████| 14/14 [00:19<00:00,  1.41s/it, loss=107.4177]


Epoch 8/200, Train Loss: 6.9509, Train mIoU: 0.9302, Validation Loss: 8.3375, Validation mIoU: 0.8082


Epoch 9/200 Training: 100%|██████████| 334/334 [10:52<00:00,  1.96s/it, loss=85.7844] 
Epoch 9/200 Validation: 100%|██████████| 14/14 [00:19<00:00,  1.42s/it, loss=107.7070]


Epoch 9/200, Train Loss: 6.7584, Train mIoU: 0.9341, Validation Loss: 7.7436, Validation mIoU: 0.8446


Epoch 10/200 Training: 100%|██████████| 334/334 [10:55<00:00,  1.96s/it, loss=115.0919]
Epoch 10/200 Validation: 100%|██████████| 14/14 [00:18<00:00,  1.35s/it, loss=124.5227]


Epoch 10/200, Train Loss: 6.5894, Train mIoU: 0.9409, Validation Loss: 7.9311, Validation mIoU: 0.8064


Epoch 11/200 Training: 100%|██████████| 334/334 [10:45<00:00,  1.93s/it, loss=116.3852]
Epoch 11/200 Validation: 100%|██████████| 14/14 [00:19<00:00,  1.36s/it, loss=109.3057]


Epoch 11/200, Train Loss: 6.4495, Train mIoU: 0.9401, Validation Loss: 8.6182, Validation mIoU: 0.7538


Epoch 12/200 Training: 100%|██████████| 334/334 [10:56<00:00,  1.97s/it, loss=97.6583] 
Epoch 12/200 Validation: 100%|██████████| 14/14 [00:20<00:00,  1.44s/it, loss=189.7944]


Epoch 12/200, Train Loss: 6.3265, Train mIoU: 0.9402, Validation Loss: 8.0202, Validation mIoU: 0.8305


Epoch 13/200 Training: 100%|██████████| 334/334 [10:54<00:00,  1.96s/it, loss=83.9409] 
Epoch 13/200 Validation: 100%|██████████| 14/14 [00:19<00:00,  1.40s/it, loss=166.2131]


Epoch 13/200, Train Loss: 6.0628, Train mIoU: 0.9468, Validation Loss: 7.8697, Validation mIoU: 0.8194


Epoch 14/200 Training: 100%|██████████| 334/334 [10:48<00:00,  1.94s/it, loss=97.9259] 
Epoch 14/200 Validation: 100%|██████████| 14/14 [00:23<00:00,  1.71s/it, loss=105.7941]


Epoch 14/200, Train Loss: 6.0379, Train mIoU: 0.9405, Validation Loss: 7.9397, Validation mIoU: 0.7945


Epoch 15/200 Training: 100%|██████████| 334/334 [10:51<00:00,  1.95s/it, loss=88.8487] 
Epoch 15/200 Validation: 100%|██████████| 14/14 [00:20<00:00,  1.45s/it, loss=146.2966]


Epoch 15/200, Train Loss: 5.9496, Train mIoU: 0.9445, Validation Loss: 7.8810, Validation mIoU: 0.8511


Epoch 16/200 Training: 100%|██████████| 334/334 [10:51<00:00,  1.95s/it, loss=88.0879] 
Epoch 16/200 Validation: 100%|██████████| 14/14 [00:18<00:00,  1.35s/it, loss=116.7256]


Epoch 16/200, Train Loss: 6.0797, Train mIoU: 0.9330, Validation Loss: 8.2957, Validation mIoU: 0.7580


Epoch 17/200 Training: 100%|██████████| 334/334 [10:51<00:00,  1.95s/it, loss=80.7473] 
Epoch 17/200 Validation: 100%|██████████| 14/14 [00:19<00:00,  1.41s/it, loss=104.4139]


Epoch 17/200, Train Loss: 5.8240, Train mIoU: 0.9459, Validation Loss: 7.7586, Validation mIoU: 0.7633


Epoch 18/200 Training: 100%|██████████| 334/334 [10:49<00:00,  1.95s/it, loss=91.8050] 
Epoch 18/200 Validation: 100%|██████████| 14/14 [00:19<00:00,  1.37s/it, loss=132.3295]


Epoch 18/200, Train Loss: 5.6472, Train mIoU: 0.9504, Validation Loss: 8.3063, Validation mIoU: 0.8227


Epoch 19/200 Training: 100%|██████████| 334/334 [10:52<00:00,  1.95s/it, loss=83.2500] 
Epoch 19/200 Validation: 100%|██████████| 14/14 [00:19<00:00,  1.37s/it, loss=131.8750]


Epoch 19/200, Train Loss: 5.5010, Train mIoU: 0.9553, Validation Loss: 7.5506, Validation mIoU: 0.8362


Epoch 20/200 Training: 100%|██████████| 334/334 [10:55<00:00,  1.96s/it, loss=95.5137] 
Epoch 20/200 Validation: 100%|██████████| 14/14 [00:22<00:00,  1.58s/it, loss=99.1227] 


Epoch 20/200, Train Loss: 5.4049, Train mIoU: 0.9542, Validation Loss: 7.5598, Validation mIoU: 0.8622


Epoch 21/200 Training: 100%|██████████| 334/334 [10:46<00:00,  1.94s/it, loss=86.5527] 
Epoch 21/200 Validation: 100%|██████████| 14/14 [00:19<00:00,  1.38s/it, loss=114.3295]


Epoch 21/200, Train Loss: 5.3167, Train mIoU: 0.9564, Validation Loss: 7.7998, Validation mIoU: 0.8265


Epoch 22/200 Training: 100%|██████████| 334/334 [10:53<00:00,  1.96s/it, loss=93.7683] 
Epoch 22/200 Validation: 100%|██████████| 14/14 [00:21<00:00,  1.50s/it, loss=115.7745]


Epoch 22/200, Train Loss: 5.3428, Train mIoU: 0.9555, Validation Loss: 8.0463, Validation mIoU: 0.8111


Epoch 23/200 Training: 100%|██████████| 334/334 [10:56<00:00,  1.97s/it, loss=79.1328] 
Epoch 23/200 Validation: 100%|██████████| 14/14 [00:21<00:00,  1.52s/it, loss=109.4543]


Epoch 23/200, Train Loss: 5.3521, Train mIoU: 0.9469, Validation Loss: 8.3293, Validation mIoU: 0.7692


Epoch 24/200 Training: 100%|██████████| 334/334 [10:50<00:00,  1.95s/it, loss=71.5348] 
Epoch 24/200 Validation: 100%|██████████| 14/14 [00:21<00:00,  1.56s/it, loss=143.2441]


Epoch 24/200, Train Loss: 5.4763, Train mIoU: 0.9423, Validation Loss: 7.7064, Validation mIoU: 0.8193


Epoch 25/200 Training: 100%|██████████| 334/334 [10:51<00:00,  1.95s/it, loss=87.3346] 
Epoch 25/200 Validation: 100%|██████████| 14/14 [00:21<00:00,  1.51s/it, loss=126.5389]


Epoch 25/200, Train Loss: 5.2609, Train mIoU: 0.9572, Validation Loss: 8.1271, Validation mIoU: 0.8347


Epoch 26/200 Training: 100%|██████████| 334/334 [10:50<00:00,  1.95s/it, loss=81.2089] 
Epoch 26/200 Validation: 100%|██████████| 14/14 [00:20<00:00,  1.45s/it, loss=122.8888]


Epoch 26/200, Train Loss: 5.0475, Train mIoU: 0.9564, Validation Loss: 8.0226, Validation mIoU: 0.8325


Epoch 27/200 Training: 100%|██████████| 334/334 [10:53<00:00,  1.96s/it, loss=79.4973] 
Epoch 27/200 Validation: 100%|██████████| 14/14 [00:22<00:00,  1.58s/it, loss=97.8096] 


Epoch 27/200, Train Loss: 4.8845, Train mIoU: 0.9636, Validation Loss: 7.3859, Validation mIoU: 0.7768


Epoch 28/200 Training: 100%|██████████| 334/334 [10:54<00:00,  1.96s/it, loss=81.7715] 
Epoch 28/200 Validation: 100%|██████████| 14/14 [00:19<00:00,  1.38s/it, loss=111.1705]


Epoch 28/200, Train Loss: 4.8563, Train mIoU: 0.9610, Validation Loss: 7.8150, Validation mIoU: 0.8133


Epoch 29/200 Training: 100%|██████████| 334/334 [10:51<00:00,  1.95s/it, loss=84.7437] 
Epoch 29/200 Validation: 100%|██████████| 14/14 [00:20<00:00,  1.44s/it, loss=101.6447]


Epoch 29/200, Train Loss: 4.7794, Train mIoU: 0.9607, Validation Loss: 7.5795, Validation mIoU: 0.8144


Epoch 30/200 Training: 100%|██████████| 334/334 [10:54<00:00,  1.96s/it, loss=75.1428] 
Epoch 30/200 Validation: 100%|██████████| 14/14 [00:19<00:00,  1.41s/it, loss=124.0426]


Epoch 30/200, Train Loss: 4.7023, Train mIoU: 0.9586, Validation Loss: 8.0124, Validation mIoU: 0.8068


Epoch 31/200 Training: 100%|██████████| 334/334 [10:50<00:00,  1.95s/it, loss=69.8799] 
Epoch 31/200 Validation: 100%|██████████| 14/14 [00:22<00:00,  1.61s/it, loss=136.3914]


Epoch 31/200, Train Loss: 4.6417, Train mIoU: 0.9657, Validation Loss: 8.0394, Validation mIoU: 0.7760


Epoch 32/200 Training: 100%|██████████| 334/334 [10:52<00:00,  1.95s/it, loss=85.4359] 
Epoch 32/200 Validation: 100%|██████████| 14/14 [00:21<00:00,  1.55s/it, loss=146.0571]


Epoch 32/200, Train Loss: 4.6763, Train mIoU: 0.9625, Validation Loss: 7.9591, Validation mIoU: 0.7853


Epoch 33/200 Training:  87%|████████▋ | 290/334 [09:30<01:28,  2.00s/it, loss=75.1221] 

## Test results on B first

In [30]:
# Load best model and evaluate on test
model = Mask2FormerForUniversalSegmentation.from_pretrained(model_path_second).to(device)

In [31]:
model.eval()
test_running_loss = 0
test_loader = tqdm(dataloaders[CURR_TASK]["test"], desc="Test loop")
with torch.no_grad():
    for batch in test_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss
        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        test_running_loss += current_loss
        test_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_B
        )
        metric.add_batch(references=masks, predictions=pred_maps)


# After compute the batches that were added are deleted
test_metrics_B = metric.compute(
    num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
)
mean_test_iou = test_metrics_B["mean_iou"]
final_test_loss = test_running_loss / len(dataloaders[CURR_TASK]["test"].dataset)
wandb.log({
    f"Loss/test_{CURR_TASK}": final_test_loss,
    f"mIoU/test_{CURR_TASK}": mean_test_iou
})
print(f"Test Loss: {final_test_loss:.4f}, Test mIoU: {mean_test_iou:.4f}")

Test loop: 100%|██████████| 14/14 [00:22<00:00,  1.61s/it, loss=102.1515]


Test Loss: 7.7336, Test mIoU: 0.8278


## Test results on A after training on B

In [32]:
# To avoid making stupid errors
CURR_TASK = "A"

model.eval()
test_running_loss = 0
test_loader = tqdm(dataloaders[CURR_TASK]["test"], desc="Test loop")
with torch.no_grad():
    for batch in test_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss
        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        test_running_loss += current_loss
        test_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_A
        )
        metric.add_batch(references=masks, predictions=pred_maps)
        
# After compute the batches that were added are deleted
test_metrics_forgetting_A = metric.compute(
    num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
)
mean_test_iou = test_metrics_forgetting_A["mean_iou"]
final_test_loss = test_running_loss / len(dataloaders[CURR_TASK]["test"].dataset)
wandb.log({
    f"Loss/test_replay_all_{CURR_TASK}": final_test_loss,
    f"mIoU/test_replay_all_{CURR_TASK}": mean_test_iou
})
print(f"Test Loss: {final_test_loss:.4f}, Test mIoU: {mean_test_iou:.4f}")


Test loop: 100%|██████████| 36/36 [00:53<00:00,  1.48s/it, loss=266.1530]


Test Loss: 17.5114, Test mIoU: 0.7854


In [33]:
# Collect overall mIoU
mIoU_A = test_metrics_A["mean_iou"]
mIoU_forgetting_A = test_metrics_forgetting_A["mean_iou"]
mIoU_B = test_metrics_B["mean_iou"]

# Collect per category mIoU
per_category_mIoU_A = np.array(test_metrics_A["per_category_iou"])
per_category_mIoU_forgetting_A = np.array(test_metrics_forgetting_A["per_category_iou"])
per_category_mIoU_B = np.array(test_metrics_forgetting_A["per_category_iou"])

# Average learning accuracies (mIoUs)
avg_learning_acc = (mIoU_A + mIoU_B) / 2
per_category_avg_learning_acc = (per_category_mIoU_A + per_category_mIoU_B) / 2

# Forgetting
total_forgetting = mIoU_A - mIoU_forgetting_A
per_category_forgetting = (per_category_mIoU_A - per_category_mIoU_forgetting_A)

# Export evaluation metrics to WandB
wandb.log({
    "eval/avg_learning_acc": avg_learning_acc,
    "eval/per_category_avg_learning_acc": per_category_avg_learning_acc,
    "eval/total_forgetting": total_forgetting,
    "eval/per_category_forgetting": per_category_forgetting
})
print("**** Overall mIoU ****")
print(f"mIoU on task A: {mIoU_A}")
print(f"mIoU on task B: {mIoU_B}")
print(f"mIoU on task A after training on B: {mIoU_forgetting_A}")

print("\n**** Per category mIoU ****")
print(f"Per category mIoU on task A: {per_category_mIoU_A}")
print(f"Per category mIoU on task B: {per_category_mIoU_B}")
print(f"Per category mIoU on task A after training on B: {per_category_mIoU_forgetting_A}")

print("\n**** Average learning accuracies ****")
print(f"Average learning acc.: {avg_learning_acc}")
print(f"Per category Average learning acc.: {per_category_avg_learning_acc}")

print("\n**** Forgetting ****")
print(f"Total forgetting: {total_forgetting}")
print(f"Per category forgetting: {per_category_forgetting}")
wandb.finish()

**** Overall mIoU ****
mIoU on task A: 0.7715989272025321
mIoU on task B: 0.8277712327043516
mIoU on task A after training on B: 0.7854313993000791

**** Per category mIoU ****
Per category mIoU on task A: [0.96005866 0.93660036 0.63264134 0.41352694 0.39168077 0.81113885
 0.80177793 0.7707886  0.89548438 0.94156179 0.93232858]
Per category mIoU on task B: [0.96077659 0.93669937 0.64262966 0.59210555 0.44402038 0.81422925
 0.75690367 0.75457962 0.88294103 0.93135118 0.92350911]
Per category mIoU on task A after training on B: [0.96077659 0.93669937 0.64262966 0.59210555 0.44402038 0.81422925
 0.75690367 0.75457962 0.88294103 0.93135118 0.92350911]

**** Average learning accuracies ****
Average learning acc.: 0.7996850799534418
Per category Average learning acc.: [0.96041762 0.93664986 0.6376355  0.50281624 0.41785058 0.81268405
 0.7793408  0.76268411 0.8892127  0.93645649 0.92791884]

**** Forgetting ****
Total forgetting: -0.01383247209754701
Per category forgetting: [-7.17928432e-04 