In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import os
from utils.common import (
    m2f_dataset_collate,
    m2f_extract_pred_maps_and_masks,
    BG_VALUE,
    set_seed,
    pixel_mean_std,
    CADIS_PIXEL_MEAN,
    CADIS_PIXEL_STD,
    CAT1K_PIXEL_MEAN,
    CAT1K_PIXEL_STD,
)
from utils.dataset_utils import (
    get_cadisv2_dataset,
    get_cataract1k_dataset,
    ZEISS_CATEGORIES,
)
from utils.medical_datasets import Mask2FormerDataset
from transformers import (
    Mask2FormerForUniversalSegmentation,
    SwinModel,
    SwinConfig,
    Mask2FormerConfig,
    AutoImageProcessor,
    Mask2FormerImageProcessor
)
from torch.utils.data import DataLoader
import evaluate
import torch.optim as optim
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from dotenv import load_dotenv
import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
set_seed(42) # seed everything

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

Using device: cuda


In [5]:
NUM_CLASSES = len(ZEISS_CATEGORIES) - 3  # Remove class incremental
model = Mask2FormerForUniversalSegmentation.from_pretrained(f"outputs/models/m2f_swin_backbone_train_cadis/best_model/A").to(device)

In [6]:
# Helper function to load datasets
def load_dataset(dataset_getter, data_path, domain_incremental):
    return dataset_getter(data_path, domain_incremental=domain_incremental)


# Helper function to create dataloaders for a dataset
def create_dataloaders(
    dataset, batch_size, shuffle, num_workers, drop_last, pin_memory, collate_fn
):
    return {
        "train": DataLoader(
            dataset["train"],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        ),
        "val": DataLoader(
            dataset["val"],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        ),
        "test": DataLoader(
            dataset["test"],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        ),
    }


# Load datasets
datasets = {
    "A": load_dataset(get_cadisv2_dataset, "../../storage/data/CaDISv2", True),
    "B": load_dataset(get_cataract1k_dataset, "../../storage/data/cataract-1k", True),
}

# pixel_mean_A,pixel_std_A=pixel_mean_std(datasets["A"][0])
pixel_mean_A = CADIS_PIXEL_MEAN
pixel_std_A = CADIS_PIXEL_STD

# pixel_mean_B,pixel_std_B=pixel_mean_std(datasets["B"][0])
pixel_mean_B = CAT1K_PIXEL_MEAN
pixel_std_B = CAT1K_PIXEL_STD

# Define preprocessor
m2f_preprocessor_A = Mask2FormerImageProcessor(
    reduce_labels=True,
    ignore_index=255,
    do_resize=False,
    do_rescale=True,
    do_normalize=True,
    image_std=pixel_std_A,
    image_mean=pixel_mean_A,
)

m2f_preprocessor_B = Mask2FormerImageProcessor(
    reduce_labels=True,
    ignore_index=255,
    do_resize=False,
    do_rescale=True,
    do_normalize=True,
    image_std=pixel_std_B,
    image_mean=pixel_mean_B,
)

# Create Mask2Former Datasets
m2f_datasets = {
    "A": {
        "train": Mask2FormerDataset(datasets["A"][0], m2f_preprocessor_A),
        "val": Mask2FormerDataset(datasets["A"][1], m2f_preprocessor_A),
        "test": Mask2FormerDataset(datasets["A"][2], m2f_preprocessor_A),
    },
    "B": {
        "train": Mask2FormerDataset(datasets["B"][0], m2f_preprocessor_B),
        "val": Mask2FormerDataset(datasets["B"][1], m2f_preprocessor_B),
        "test": Mask2FormerDataset(datasets["B"][2], m2f_preprocessor_B),
    },
}

# DataLoader parameters
N_WORKERS = 4
BATCH_SIZE = 16
SHUFFLE = True
DROP_LAST = True

dataloader_params = {
    "batch_size": BATCH_SIZE,
    "shuffle": SHUFFLE,
    "num_workers": N_WORKERS,
    "drop_last": DROP_LAST,
    "pin_memory": True,
    "collate_fn": m2f_dataset_collate,
}

# Create DataLoaders
dataloaders = {
    key: create_dataloaders(m2f_datasets[key], **dataloader_params)
    for key in m2f_datasets
}

print(dataloaders)

{'A': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7f1132ee4390>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7f125d34b110>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x7f113c359b50>}, 'B': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7f11313b0bd0>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7f11313b0d90>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x7f11313b0ed0>}}


In [7]:
# Tensorboard setup
out_dir="outputs/"
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
if not os.path.exists(out_dir+"runs"):
    os.makedirs(out_dir+"runs")
%load_ext tensorboard
%tensorboard --logdir outputs/runs

In [8]:
wandb.login()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33manarlee[0m ([33mcontinual-learning-tum[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [9]:
# Tensorboard logging
writer = SummaryWriter(log_dir=out_dir + "runs")

# Model checkpointing
base_model_name="naive_forgetting"
model_dir = out_dir + "models/"
if not os.path.exists(model_dir):
    print("Store weights in: ", model_dir)
    os.makedirs(model_dir)

best_model_dir = model_dir + f"{base_model_name}/best_model/"
if not os.path.exists(best_model_dir):
    print("Store best model weights in: ", best_model_dir)
    os.makedirs(best_model_dir)
final_model_dir = model_dir + f"{base_model_name}/final_model/"
if not os.path.exists(final_model_dir):
    print("Store final model weights in: ", final_model_dir)
    os.makedirs(final_model_dir)

In [10]:
m2f_preprocessor_A.save_pretrained(model_dir + base_model_name)

['outputs/models/naive_forgetting/preprocessor_config.json']

In [12]:
metric = evaluate.load("mean_iou")

In [13]:
CURR_TASK = "A"
model.eval()
test_running_loss = 0
test_loader = tqdm(dataloaders[CURR_TASK]["test"], desc="Test loop")
with torch.no_grad():
    for batch in test_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss
        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        test_running_loss += current_loss
        test_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_A
        )
        metric.add_batch(references=masks, predictions=pred_maps)
        
# After compute the batches that were added are deleted
test_metrics_A = metric.compute(
    num_labels=NUM_CLASSES, ignore_index=BG_VALUE, reduce_labels=False
)
mean_test_iou = test_metrics_A["mean_iou"]
final_test_loss = test_running_loss / len(dataloaders[CURR_TASK]["test"].dataset)
# wandb.log({
#     f"Loss/test_{CURR_TASK}": final_test_loss,
#     f"mIoU/test_{CURR_TASK}": mean_test_iou
# })
print(f"Test Loss: {final_test_loss:.4f}, Test mIoU: {mean_test_iou:.4f}")

Test loop: 100%|██████████| 36/36 [00:33<00:00,  1.09it/s, loss=329.8553]


Test Loss: 15.0059, Test mIoU: 0.7352


  acc = total_area_intersect / total_area_label


In [14]:
# Training
NUM_EPOCHS = 200
LEARNING_RATE = 1e-4
LR_MULTIPLIER = 0.1
BACKBONE_LR = LEARNING_RATE * LR_MULTIPLIER
WEIGHT_DECAY = 0.05
encoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.encoder")
]
decoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.decoder")
]
transformer_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.transformer_module")
]
optimizer = optim.AdamW(
    [
        {"params": encoder_params, "lr": BACKBONE_LR},
        {"params": decoder_params},
        {"params": transformer_params},
    ],
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

scheduler = optim.lr_scheduler.PolynomialLR(
    optimizer, total_iters=NUM_EPOCHS, power=0.9)

In [15]:
wandb.init(
    project="M2F_original",
    config={
        "learning_rate": LEARNING_RATE,
        "learning_rate_multiplier": LR_MULTIPLIER,
        "backbone_learning_rate": BACKBONE_LR,
        "learning_rate_scheduler": scheduler.__class__.__name__,
        "optimizer": optimizer.__class__.__name__,
        "backbone": "m2f_swin_backbone_train_cadis",
        "m2f_preprocessor": m2f_preprocessor_B.__dict__,
        "m2f_model_config": model.config
    },
    name="M2F-Swin-Tiny-Naive-Forgetting-200",
    notes="M2F with tiny Swin backbone pretrained on ImageNet-1K. \
        Scenario: Pretrained on A, Train on B, Test forgetting on A"
)
# Tensorboard logging
writer = SummaryWriter(log_dir=out_dir + "runs")
# Model checkpointing
model_name = "m2f_swin_backbone_naive_forgetting-b"
model_dir = out_dir + "models/"
if not os.path.exists(model_dir):
    print("Store weights in: ", model_dir)
    os.makedirs(model_dir)

best_model_dir = model_dir + f"{model_name}/best_model/"
if not os.path.exists(best_model_dir):
    print("Store best model weights in: ", best_model_dir)
    os.makedirs(best_model_dir)
final_model_dir = model_dir + f"{model_name}/final_model/"
if not os.path.exists(final_model_dir):
    print("Store final model weights in: ", final_model_dir)
    os.makedirs(final_model_dir)

In [16]:
m2f_preprocessor_B.save_pretrained(model_dir + model_name)

['outputs/models/m2f_swin_backbone_naive_forgetting-b/preprocessor_config.json']

In [17]:
# To avoid making stupid errors
CURR_TASK = "B"

# For storing the model
best_val_metric = -np.inf

# Move model to device
model.to(device)

for epoch in range(NUM_EPOCHS):
    model.train()
    train_running_loss = 0.0
    val_running_loss = 0.0

    # Set up tqdm for the training loop
    train_loader = tqdm(
        dataloaders[CURR_TASK]["train"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Training"
    )

    for batch in train_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]

        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss

        # Compute gradient and perform step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        train_running_loss += current_loss
        train_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_B
        )
        metric.add_batch(references=masks, predictions=pred_maps)

    # After compute the batches that were added are deleted
    mean_train_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE, reduce_labels=False
    )["mean_iou"]

    # Validation phase
    model.eval()
    val_loader = tqdm(
        dataloaders[CURR_TASK]["val"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Validation"
    )
    with torch.no_grad():
        for batch in val_loader:
            # Move everything to the device
            batch["pixel_values"] = batch["pixel_values"].to(device)
            batch["pixel_mask"] = batch["pixel_mask"].to(device)
            batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
            batch["class_labels"] = [
                entry.to(device) for entry in batch["class_labels"]
            ]
            # Compute output and loss
            outputs = model(**batch)

            loss = outputs.loss
            # Record losses
            current_loss = loss.item() * batch["pixel_values"].size(0)
            val_running_loss += current_loss
            val_loader.set_postfix(loss=f"{current_loss:.4f}")

            # Extract and compute metrics
            pred_maps, masks = m2f_extract_pred_maps_and_masks(
                batch, outputs, m2f_preprocessor_B
            )
            metric.add_batch(references=masks, predictions=pred_maps)

    # After compute the batches that were added are deleted
    mean_val_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE, reduce_labels=False
    )["mean_iou"]

    epoch_train_loss = train_running_loss / len(dataloaders[CURR_TASK]["train"].dataset)
    epoch_val_loss = val_running_loss / len(dataloaders[CURR_TASK]["val"].dataset)

    writer.add_scalar(f"Loss/train_{model_name}_{CURR_TASK}", epoch_train_loss, epoch + 1)
    writer.add_scalar(f"Loss/val_{model_name}_{CURR_TASK}", epoch_val_loss, epoch + 1)
    writer.add_scalar(f"mIoU/train_{model_name}_{CURR_TASK}", mean_train_iou, epoch + 1)
    writer.add_scalar(f"mIoU/val_{model_name}_{CURR_TASK}", mean_val_iou, epoch + 1)

    wandb.log({
        f"Loss/train_{CURR_TASK}": epoch_train_loss,
        f"Loss/val_{CURR_TASK}": epoch_val_loss,
        f"mIoU/train_{CURR_TASK}": mean_train_iou,
        f"mIoU/val_{CURR_TASK}": mean_val_iou
    })

    if mean_val_iou > best_val_metric:
        best_val_metric = mean_val_iou
        model.save_pretrained(f"{best_model_dir}{CURR_TASK}/")

    tqdm.write(
        f"Epoch {epoch + 1}/{NUM_EPOCHS}, Train Loss: {epoch_train_loss:.4f}, Train mIoU: {mean_train_iou:.4f}, Validation Loss: {epoch_val_loss:.4f}, Validation mIoU: {mean_val_iou:.4f}"
    )

Epoch 1/200 Training: 100%|██████████| 112/112 [01:57<00:00,  1.05s/it, loss=275.9501]
  acc = total_area_intersect / total_area_label
Epoch 1/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.46it/s, loss=253.6964]


Epoch 1/200, Train Loss: 22.9025, Train mIoU: 0.4203, Validation Loss: 16.6096, Validation mIoU: 0.5314


Epoch 2/200 Training: 100%|██████████| 112/112 [02:01<00:00,  1.08s/it, loss=240.8674]
Epoch 2/200 Validation: 100%|██████████| 14/14 [00:14<00:00,  1.01s/it, loss=190.2122]


Epoch 2/200, Train Loss: 15.5788, Train mIoU: 0.5907, Validation Loss: 11.7769, Validation mIoU: 0.7086


Epoch 3/200 Training: 100%|██████████| 112/112 [02:00<00:00,  1.08s/it, loss=127.9669]
Epoch 3/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.48it/s, loss=147.2661]


Epoch 3/200, Train Loss: 11.8668, Train mIoU: 0.7144, Validation Loss: 10.3128, Validation mIoU: 0.7349


Epoch 4/200 Training: 100%|██████████| 112/112 [02:03<00:00,  1.11s/it, loss=137.3650]
Epoch 4/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.36it/s, loss=181.9393]


Epoch 4/200, Train Loss: 10.0055, Train mIoU: 0.7534, Validation Loss: 9.3656, Validation mIoU: 0.7713


Epoch 5/200 Training: 100%|██████████| 112/112 [01:57<00:00,  1.05s/it, loss=141.9254]
Epoch 5/200 Validation: 100%|██████████| 14/14 [00:08<00:00,  1.58it/s, loss=162.7552]


Epoch 5/200, Train Loss: 8.9175, Train mIoU: 0.7676, Validation Loss: 9.6687, Validation mIoU: 0.7280


Epoch 6/200 Training: 100%|██████████| 112/112 [02:06<00:00,  1.13s/it, loss=137.1364]
Epoch 6/200 Validation: 100%|██████████| 14/14 [00:12<00:00,  1.12it/s, loss=153.3968]


Epoch 6/200, Train Loss: 8.4230, Train mIoU: 0.7877, Validation Loss: 8.8760, Validation mIoU: 0.7461


Epoch 7/200 Training: 100%|██████████| 112/112 [02:01<00:00,  1.08s/it, loss=119.9177]
Epoch 7/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.29it/s, loss=141.6956]


Epoch 7/200, Train Loss: 7.8464, Train mIoU: 0.8103, Validation Loss: 8.6316, Validation mIoU: 0.7664


Epoch 8/200 Training: 100%|██████████| 112/112 [02:05<00:00,  1.12s/it, loss=149.1364]
Epoch 8/200 Validation: 100%|██████████| 14/14 [00:11<00:00,  1.21it/s, loss=115.0689]


Epoch 8/200, Train Loss: 7.5442, Train mIoU: 0.8184, Validation Loss: 8.0913, Validation mIoU: 0.7644


Epoch 9/200 Training: 100%|██████████| 112/112 [02:05<00:00,  1.12s/it, loss=115.7365]
Epoch 9/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.40it/s, loss=111.9807]


Epoch 9/200, Train Loss: 7.1263, Train mIoU: 0.8359, Validation Loss: 8.3469, Validation mIoU: 0.7508


Epoch 10/200 Training: 100%|██████████| 112/112 [02:03<00:00,  1.11s/it, loss=149.6504]
Epoch 10/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.45it/s, loss=167.7110]


Epoch 10/200, Train Loss: 7.6502, Train mIoU: 0.8257, Validation Loss: 11.6121, Validation mIoU: 0.6923


Epoch 11/200 Training: 100%|██████████| 112/112 [02:01<00:00,  1.08s/it, loss=114.1203]
Epoch 11/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.51it/s, loss=105.6527]


Epoch 11/200, Train Loss: 7.0374, Train mIoU: 0.8246, Validation Loss: 8.1429, Validation mIoU: 0.7429


Epoch 12/200 Training: 100%|██████████| 112/112 [02:01<00:00,  1.08s/it, loss=107.9282]
Epoch 12/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.52it/s, loss=128.7861]


Epoch 12/200, Train Loss: 6.5493, Train mIoU: 0.8526, Validation Loss: 8.1093, Validation mIoU: 0.7662


Epoch 13/200 Training: 100%|██████████| 112/112 [02:05<00:00,  1.12s/it, loss=140.8829]
Epoch 13/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.47it/s, loss=103.5610]


Epoch 13/200, Train Loss: 6.6065, Train mIoU: 0.8491, Validation Loss: 8.6595, Validation mIoU: 0.7736


Epoch 14/200 Training: 100%|██████████| 112/112 [02:05<00:00,  1.12s/it, loss=103.5890]
Epoch 14/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.38it/s, loss=132.5417]


Epoch 14/200, Train Loss: 6.2999, Train mIoU: 0.8419, Validation Loss: 8.1991, Validation mIoU: 0.7450


Epoch 15/200 Training: 100%|██████████| 112/112 [02:00<00:00,  1.08s/it, loss=100.5470]
Epoch 15/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.45it/s, loss=134.3832]


Epoch 15/200, Train Loss: 6.3554, Train mIoU: 0.8532, Validation Loss: 7.7134, Validation mIoU: 0.7579


Epoch 16/200 Training: 100%|██████████| 112/112 [01:59<00:00,  1.07s/it, loss=91.1026]
Epoch 16/200 Validation: 100%|██████████| 14/14 [00:13<00:00,  1.04it/s, loss=112.0637]


Epoch 16/200, Train Loss: 5.8959, Train mIoU: 0.8605, Validation Loss: 7.7093, Validation mIoU: 0.7703


Epoch 17/200 Training: 100%|██████████| 112/112 [01:59<00:00,  1.07s/it, loss=95.7740] 
Epoch 17/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.47it/s, loss=120.1680]


Epoch 17/200, Train Loss: 5.8645, Train mIoU: 0.8636, Validation Loss: 8.0901, Validation mIoU: 0.7598


Epoch 18/200 Training: 100%|██████████| 112/112 [01:59<00:00,  1.07s/it, loss=102.4421]
Epoch 18/200 Validation: 100%|██████████| 14/14 [00:12<00:00,  1.11it/s, loss=130.3600]


Epoch 18/200, Train Loss: 5.6744, Train mIoU: 0.8674, Validation Loss: 7.6826, Validation mIoU: 0.7549


Epoch 19/200 Training: 100%|██████████| 112/112 [02:02<00:00,  1.09s/it, loss=79.3289] 
Epoch 19/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.33it/s, loss=105.0657]


Epoch 19/200, Train Loss: 5.6974, Train mIoU: 0.8616, Validation Loss: 8.4423, Validation mIoU: 0.7469


Epoch 20/200 Training: 100%|██████████| 112/112 [02:07<00:00,  1.14s/it, loss=84.5511] 
Epoch 20/200 Validation: 100%|██████████| 14/14 [00:12<00:00,  1.13it/s, loss=118.5611]


Epoch 20/200, Train Loss: 5.4746, Train mIoU: 0.8757, Validation Loss: 7.8613, Validation mIoU: 0.7554


Epoch 21/200 Training: 100%|██████████| 112/112 [02:02<00:00,  1.10s/it, loss=85.7838] 
Epoch 21/200 Validation: 100%|██████████| 14/14 [00:11<00:00,  1.17it/s, loss=165.8735]


Epoch 21/200, Train Loss: 5.4987, Train mIoU: 0.8660, Validation Loss: 7.8093, Validation mIoU: 0.7435


Epoch 22/200 Training: 100%|██████████| 112/112 [02:00<00:00,  1.08s/it, loss=97.2870] 
Epoch 22/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.46it/s, loss=127.4018]


Epoch 22/200, Train Loss: 5.4203, Train mIoU: 0.8763, Validation Loss: 8.2343, Validation mIoU: 0.6659


Epoch 23/200 Training: 100%|██████████| 112/112 [02:04<00:00,  1.11s/it, loss=106.7787]
Epoch 23/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.47it/s, loss=126.8006]


Epoch 23/200, Train Loss: 5.4855, Train mIoU: 0.8719, Validation Loss: 7.9564, Validation mIoU: 0.7717


Epoch 24/200 Training: 100%|██████████| 112/112 [01:58<00:00,  1.06s/it, loss=95.3583] 
Epoch 24/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.43it/s, loss=139.3829]


Epoch 24/200, Train Loss: 5.6196, Train mIoU: 0.8440, Validation Loss: 7.9240, Validation mIoU: 0.7825


Epoch 25/200 Training: 100%|██████████| 112/112 [01:59<00:00,  1.07s/it, loss=78.6507]
Epoch 25/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.48it/s, loss=123.3385]


Epoch 25/200, Train Loss: 5.1429, Train mIoU: 0.8688, Validation Loss: 7.5422, Validation mIoU: 0.7627


Epoch 26/200 Training: 100%|██████████| 112/112 [01:58<00:00,  1.06s/it, loss=77.3643]
Epoch 26/200 Validation: 100%|██████████| 14/14 [00:14<00:00,  1.01s/it, loss=91.9193] 


Epoch 26/200, Train Loss: 4.8606, Train mIoU: 0.8824, Validation Loss: 7.5539, Validation mIoU: 0.7718


Epoch 27/200 Training: 100%|██████████| 112/112 [02:05<00:00,  1.12s/it, loss=69.6758]
Epoch 27/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.29it/s, loss=110.6696]


Epoch 27/200, Train Loss: 4.8508, Train mIoU: 0.8692, Validation Loss: 7.5291, Validation mIoU: 0.7429


Epoch 28/200 Training: 100%|██████████| 112/112 [01:59<00:00,  1.06s/it, loss=69.4213]
Epoch 28/200 Validation: 100%|██████████| 14/14 [00:11<00:00,  1.27it/s, loss=94.3941] 


Epoch 28/200, Train Loss: 4.8242, Train mIoU: 0.8767, Validation Loss: 7.6581, Validation mIoU: 0.7725


Epoch 29/200 Training: 100%|██████████| 112/112 [02:10<00:00,  1.17s/it, loss=69.8156]
Epoch 29/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.48it/s, loss=150.8713]


Epoch 29/200, Train Loss: 4.7298, Train mIoU: 0.8809, Validation Loss: 7.7486, Validation mIoU: 0.7211


Epoch 30/200 Training: 100%|██████████| 112/112 [02:00<00:00,  1.08s/it, loss=70.1466]
Epoch 30/200 Validation: 100%|██████████| 14/14 [00:12<00:00,  1.13it/s, loss=105.8971]


Epoch 30/200, Train Loss: 4.7730, Train mIoU: 0.8734, Validation Loss: 8.1137, Validation mIoU: 0.7762


Epoch 31/200 Training: 100%|██████████| 112/112 [02:04<00:00,  1.11s/it, loss=64.4183]
Epoch 31/200 Validation: 100%|██████████| 14/14 [00:12<00:00,  1.09it/s, loss=116.5353]


Epoch 31/200, Train Loss: 4.7995, Train mIoU: 0.8809, Validation Loss: 7.5503, Validation mIoU: 0.7862


Epoch 33/200 Training: 100%|██████████| 112/112 [02:05<00:00,  1.12s/it, loss=69.5825]
Epoch 33/200 Validation: 100%|██████████| 14/14 [00:13<00:00,  1.04it/s, loss=107.2553]


Epoch 33/200, Train Loss: 4.5929, Train mIoU: 0.8785, Validation Loss: 7.6055, Validation mIoU: 0.7619


Epoch 34/200 Training: 100%|██████████| 112/112 [02:01<00:00,  1.09s/it, loss=74.7053]
Epoch 34/200 Validation: 100%|██████████| 14/14 [00:11<00:00,  1.23it/s, loss=149.1676]


Epoch 34/200, Train Loss: 4.5238, Train mIoU: 0.8840, Validation Loss: 7.4858, Validation mIoU: 0.7518


Epoch 35/200 Training: 100%|██████████| 112/112 [02:07<00:00,  1.14s/it, loss=89.8614]
Epoch 35/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.40it/s, loss=105.7646]


Epoch 35/200, Train Loss: 4.7800, Train mIoU: 0.8839, Validation Loss: 8.4405, Validation mIoU: 0.7332


Epoch 36/200 Training: 100%|██████████| 112/112 [01:58<00:00,  1.06s/it, loss=74.8241]
Epoch 36/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.34it/s, loss=129.0229]


Epoch 36/200, Train Loss: 5.1755, Train mIoU: 0.8692, Validation Loss: 7.8997, Validation mIoU: 0.6674


Epoch 37/200 Training: 100%|██████████| 112/112 [02:07<00:00,  1.14s/it, loss=66.2714]
Epoch 37/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.35it/s, loss=131.9821]


Epoch 37/200, Train Loss: 4.9682, Train mIoU: 0.8507, Validation Loss: 7.5454, Validation mIoU: 0.7733


Epoch 38/200 Training: 100%|██████████| 112/112 [02:01<00:00,  1.09s/it, loss=80.8728]
Epoch 38/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.51it/s, loss=134.4797]


Epoch 38/200, Train Loss: 4.7400, Train mIoU: 0.8682, Validation Loss: 7.5057, Validation mIoU: 0.7645


Epoch 39/200 Training: 100%|██████████| 112/112 [02:09<00:00,  1.15s/it, loss=71.0283]
Epoch 39/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.29it/s, loss=107.7461]


Epoch 39/200, Train Loss: 4.2832, Train mIoU: 0.8865, Validation Loss: 7.5352, Validation mIoU: 0.7557


Epoch 40/200 Training: 100%|██████████| 112/112 [02:08<00:00,  1.15s/it, loss=71.0680]
Epoch 40/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.30it/s, loss=134.2769]


Epoch 40/200, Train Loss: 4.2970, Train mIoU: 0.8853, Validation Loss: 7.6954, Validation mIoU: 0.8001


Epoch 41/200 Training: 100%|██████████| 112/112 [02:04<00:00,  1.11s/it, loss=69.4115]
Epoch 41/200 Validation: 100%|██████████| 14/14 [00:11<00:00,  1.27it/s, loss=96.4566] 


Epoch 41/200, Train Loss: 4.2057, Train mIoU: 0.8809, Validation Loss: 7.5585, Validation mIoU: 0.7909


Epoch 42/200 Training: 100%|██████████| 112/112 [01:57<00:00,  1.05s/it, loss=65.4313]
Epoch 42/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.40it/s, loss=127.2476]


Epoch 42/200, Train Loss: 4.0823, Train mIoU: 0.8818, Validation Loss: 7.4293, Validation mIoU: 0.7745


Epoch 43/200 Training: 100%|██████████| 112/112 [01:57<00:00,  1.05s/it, loss=67.3288]
Epoch 43/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.34it/s, loss=99.6713] 


Epoch 43/200, Train Loss: 4.0455, Train mIoU: 0.8891, Validation Loss: 7.5998, Validation mIoU: 0.7858


Epoch 44/200 Training: 100%|██████████| 112/112 [02:05<00:00,  1.12s/it, loss=75.1746]
Epoch 44/200 Validation: 100%|██████████| 14/14 [00:14<00:00,  1.05s/it, loss=121.9317]


Epoch 44/200, Train Loss: 4.0979, Train mIoU: 0.8800, Validation Loss: 7.7144, Validation mIoU: 0.8081


Epoch 45/200 Training: 100%|██████████| 112/112 [02:00<00:00,  1.08s/it, loss=73.5646] 
Epoch 45/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.38it/s, loss=183.4653]


Epoch 45/200, Train Loss: 4.1068, Train mIoU: 0.8818, Validation Loss: 7.8322, Validation mIoU: 0.7753


Epoch 46/200 Training: 100%|██████████| 112/112 [02:03<00:00,  1.10s/it, loss=63.7483]
Epoch 46/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.33it/s, loss=99.4504] 


Epoch 46/200, Train Loss: 4.2388, Train mIoU: 0.8816, Validation Loss: 7.7437, Validation mIoU: 0.7808


Epoch 47/200 Training: 100%|██████████| 112/112 [02:00<00:00,  1.08s/it, loss=67.5670] 
Epoch 47/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.32it/s, loss=154.7922]


Epoch 47/200, Train Loss: 4.7092, Train mIoU: 0.8360, Validation Loss: 7.8529, Validation mIoU: 0.6435


Epoch 48/200 Training: 100%|██████████| 112/112 [02:03<00:00,  1.10s/it, loss=67.8108]
Epoch 48/200 Validation: 100%|██████████| 14/14 [00:13<00:00,  1.05it/s, loss=184.1185]


Epoch 48/200, Train Loss: 4.1567, Train mIoU: 0.8536, Validation Loss: 7.6363, Validation mIoU: 0.7509


Epoch 49/200 Training: 100%|██████████| 112/112 [02:07<00:00,  1.14s/it, loss=68.5038]
Epoch 49/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.36it/s, loss=144.8727]


Epoch 49/200, Train Loss: 3.9774, Train mIoU: 0.8854, Validation Loss: 7.3830, Validation mIoU: 0.7947


Epoch 50/200 Training: 100%|██████████| 112/112 [02:02<00:00,  1.09s/it, loss=80.0469]
Epoch 50/200 Validation: 100%|██████████| 14/14 [00:13<00:00,  1.04it/s, loss=93.6114] 
  iou = total_area_intersect / total_area_union


Epoch 50/200, Train Loss: 3.9850, Train mIoU: 0.8873, Validation Loss: 8.0967, Validation mIoU: 0.8273


Epoch 51/200 Training: 100%|██████████| 112/112 [02:00<00:00,  1.08s/it, loss=62.3284]
Epoch 51/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.40it/s, loss=139.4620]


Epoch 51/200, Train Loss: 4.1917, Train mIoU: 0.8847, Validation Loss: 7.5094, Validation mIoU: 0.7720


Epoch 52/200 Training: 100%|██████████| 112/112 [02:05<00:00,  1.12s/it, loss=62.2683]
Epoch 52/200 Validation: 100%|██████████| 14/14 [00:13<00:00,  1.03it/s, loss=110.8706]


Epoch 52/200, Train Loss: 3.9174, Train mIoU: 0.8818, Validation Loss: 7.5836, Validation mIoU: 0.7583


Epoch 53/200 Training: 100%|██████████| 112/112 [01:57<00:00,  1.05s/it, loss=62.0481]
Epoch 53/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.47it/s, loss=108.3047]


Epoch 53/200, Train Loss: 3.8526, Train mIoU: 0.8901, Validation Loss: 7.4731, Validation mIoU: 0.7648


Epoch 54/200 Training: 100%|██████████| 112/112 [02:05<00:00,  1.12s/it, loss=61.7987]
Epoch 54/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.44it/s, loss=103.0588]


Epoch 54/200, Train Loss: 3.8799, Train mIoU: 0.8894, Validation Loss: 7.6549, Validation mIoU: 0.7419


Epoch 55/200 Training: 100%|██████████| 112/112 [02:05<00:00,  1.12s/it, loss=58.2066]
Epoch 55/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.41it/s, loss=110.3214]


Epoch 55/200, Train Loss: 3.8539, Train mIoU: 0.8894, Validation Loss: 7.5258, Validation mIoU: 0.7717


Epoch 56/200 Training: 100%|██████████| 112/112 [02:03<00:00,  1.11s/it, loss=69.9409]
Epoch 56/200 Validation: 100%|██████████| 14/14 [00:09<00:00,  1.41it/s, loss=103.1507]


Epoch 56/200, Train Loss: 3.7151, Train mIoU: 0.8910, Validation Loss: 7.1401, Validation mIoU: 0.7982


Epoch 57/200 Training: 100%|██████████| 112/112 [01:59<00:00,  1.07s/it, loss=47.3391]
Epoch 57/200 Validation: 100%|██████████| 14/14 [00:11<00:00,  1.25it/s, loss=94.8212] 


Epoch 57/200, Train Loss: 3.6864, Train mIoU: 0.8870, Validation Loss: 7.6190, Validation mIoU: 0.7796


Epoch 58/200 Training: 100%|██████████| 112/112 [01:59<00:00,  1.07s/it, loss=65.5836]
Epoch 58/200 Validation: 100%|██████████| 14/14 [00:14<00:00,  1.05s/it, loss=88.5024] 


Epoch 58/200, Train Loss: 3.6370, Train mIoU: 0.8886, Validation Loss: 7.6808, Validation mIoU: 0.7551


Epoch 59/200 Training: 100%|██████████| 112/112 [02:01<00:00,  1.08s/it, loss=67.7087]
Epoch 59/200 Validation: 100%|██████████| 14/14 [00:13<00:00,  1.05it/s, loss=107.8220]


Epoch 59/200, Train Loss: 3.6872, Train mIoU: 0.8887, Validation Loss: 7.8004, Validation mIoU: 0.7720


Epoch 60/200 Training: 100%|██████████| 112/112 [01:58<00:00,  1.05s/it, loss=66.8611]


UnidentifiedImageError: cannot identify image file <_io.BytesIO object at 0x7f117ee3f3d0>

In [18]:
masks[0].shape

torch.Size([270, 480])

In [36]:
len(pred_maps)

16

In [20]:
masks[0]

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]], dtype=torch.int32)

In [23]:
model = Mask2FormerForUniversalSegmentation.from_pretrained(f"{best_model_dir}{CURR_TASK}/").to(device)

In [24]:
model.eval()
test_running_loss = 0
test_loader = tqdm(dataloaders[CURR_TASK]["test"], desc="Test loop")
with torch.no_grad():
    for batch in test_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss
        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        test_running_loss += current_loss
        test_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_B
        )
        metric.add_batch(references=masks, predictions=pred_maps)

# After compute the batches that were added are deleted
test_metrics_B = metric.compute(
    num_labels=NUM_CLASSES, ignore_index=BG_VALUE, reduce_labels=False
)
mean_test_iou = test_metrics_B["mean_iou"]
final_test_loss = test_running_loss / len(dataloaders[CURR_TASK]["test"].dataset)
wandb.log({
    f"Loss/test_{CURR_TASK}": final_test_loss,
    f"mIoU/test_{CURR_TASK}": mean_test_iou
})
print(f"Test Loss: {final_test_loss:.4f}, Test mIoU: {mean_test_iou:.4f}")

Test loop: 100%|██████████| 14/14 [00:11<00:00,  1.23it/s, loss=166.3146]


Test Loss: 8.7681, Test mIoU: 0.8503


In [25]:
# To avoid making stupid errors
CURR_TASK = "A"

model.eval()
test_running_loss = 0
test_loader = tqdm(dataloaders[CURR_TASK]["test"], desc="Test loop")
with torch.no_grad():
    for batch in test_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss
        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        test_running_loss += current_loss
        test_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_A
        )
        metric.add_batch(references=masks, predictions=pred_maps)
        
# After compute the batches that were added are deleted
test_metrics_forgetting_A = metric.compute(
    num_labels=NUM_CLASSES, ignore_index=BG_VALUE, reduce_labels=False
)
mean_test_iou = test_metrics_forgetting_A["mean_iou"]
final_test_loss = test_running_loss / len(dataloaders[CURR_TASK]["test"].dataset)
wandb.log({
    f"Loss/test_naive_forgetting_{CURR_TASK}": final_test_loss,
    f"mIoU/test_naive_forgetting_{CURR_TASK}": mean_test_iou
})
print(f"Test Loss: {final_test_loss:.4f}, Test mIoU: {mean_test_iou:.4f}")
wandb.finish()

Test loop: 100%|██████████| 36/36 [00:31<00:00,  1.15it/s, loss=1304.7112]


Test Loss: 64.4713, Test mIoU: 0.4018


0,1
Loss/test_B,▁
Loss/test_naive_forgetting_A,▁
Loss/train_B,█▅▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/val_B,█▄▃▃▂▂▂▂▂▂▁▂▁▂▁▂▂▁▁▁▂▁▁▂▂▁▁▁▁▁▂▂▁▂▁▁▁▁▁▁
mIoU/test_B,▁
mIoU/test_naive_forgetting_A,▁
mIoU/train_B,▁▄▅▆▆▇▇▇▇▇▇█████▇██████████████▇▇███████
mIoU/val_B,▁▅▆▆▆▇▆▆▇▆▆▆▆▆▆▇▇▇▆▅▇▅▆▆▄▇▆▇▇█▇▄▆█▇▇▆▇▇▇

0,1
Loss/test_B,8.76806
Loss/test_naive_forgetting_A,64.47129
Loss/train_B,3.68717
Loss/val_B,7.80043
mIoU/test_B,0.8503
mIoU/test_naive_forgetting_A,0.4018
mIoU/train_B,0.88872
mIoU/val_B,0.77205


In [26]:
wandb.init(
    project="M2F_original",
    config={
        "learning_rate": LEARNING_RATE,
        "learning_rate_multiplier": LR_MULTIPLIER,
        "backbone_learning_rate": BACKBONE_LR,
        "learning_rate_scheduler": scheduler.__class__.__name__,
        "optimizer": optimizer.__class__.__name__,
        "backbone": "m2f_swin_backbone_train_cadis",
        "m2f_preprocessor": m2f_preprocessor_B.__dict__,
        "m2f_model_config": model.config
    },
    name="M2F-Swin-Tiny-Naive-Forgetting-all-results",
    notes="M2F with tiny Swin backbone pretrained on ImageNet-1K. \
        Scenario: Pretrained on A, Train on B, Test forgetting on A"
)

In [27]:
# Collect overall mIoU
mIoU_A = test_metrics_A["mean_iou"]
mIoU_forgetting_A = test_metrics_forgetting_A["mean_iou"]
mIoU_B = test_metrics_B["mean_iou"]

# Collect per category mIoU
per_category_mIoU_A = np.array(test_metrics_A["per_category_iou"])
per_category_mIoU_forgetting_A = np.array(test_metrics_forgetting_A["per_category_iou"])
per_category_mIoU_B = np.array(test_metrics_forgetting_A["per_category_iou"])

# Average learning accuracies (mIoUs)
avg_learning_acc = (mIoU_A + mIoU_B) / 2
per_category_avg_learning_acc = (per_category_mIoU_A + per_category_mIoU_B) / 2

# Forgetting
total_forgetting = mIoU_A - mIoU_forgetting_A
per_category_forgetting = (per_category_mIoU_A - per_category_mIoU_forgetting_A)

# Export evaluation metrics to WandB
wandb.log({
    "eval/avg_learning_acc": avg_learning_acc,
    "eval/per_category_avg_learning_acc": per_category_avg_learning_acc,
    "eval/total_forgetting": total_forgetting,
    "eval/per_category_forgetting": per_category_forgetting
})

print("**** Overall mIoU ****")
print(f"mIoU on task A: {mIoU_A}")
print(f"mIoU on task B: {mIoU_B}")
print(f"mIoU on task A after training on B: {mIoU_forgetting_A}")

print("\n**** Per category mIoU ****")
print(f"Per category mIoU on task A: {per_category_mIoU_A}")
print(f"Per category mIoU on task B: {per_category_mIoU_B}")
print(f"Per category mIoU on task A after training on B: {per_category_mIoU_forgetting_A}")

print("\n**** Average learning accuracies ****")
print(f"Average learning acc.: {avg_learning_acc}")
print(f"Per category Average learning acc.: {per_category_avg_learning_acc}")

print("\n**** Forgetting ****")
print(f"Total forgetting: {total_forgetting}")
print(f"Per category forgetting: {per_category_forgetting}")

**** Overall mIoU ****
mIoU on task A: 0.7352255713541009
mIoU on task B: 0.8503015788034922
mIoU on task A after training on B: 0.4017984503336052

**** Per category mIoU ****
Per category mIoU on task A: [0.         0.94635122 0.71503252 0.70157295 0.52939948 0.84239861
 0.79553573 0.78785687 0.89455109 0.94140952 0.9333733 ]
Per category mIoU on task B: [0.         0.70445866 0.3408371  0.21008035 0.0307262  0.41182026
 0.45926575 0.53111568 0.80600993 0.38888299 0.53658605]
Per category mIoU on task A after training on B: [0.         0.70445866 0.3408371  0.21008035 0.0307262  0.41182026
 0.45926575 0.53111568 0.80600993 0.38888299 0.53658605]

**** Average learning accuracies ****
Average learning acc.: 0.7927635750787966
Per category Average learning acc.: [0.         0.82540494 0.52793481 0.45582665 0.28006284 0.62710943
 0.62740074 0.65948627 0.85028051 0.66514626 0.73497967]

**** Forgetting ****
Total forgetting: 0.33342712102049576
Per category forgetting: [0.         0.2418

In [29]:
# Assuming masks and pred_maps are lists of PyTorch tensors

# Save masks
torch.save(masks, 'masks.pth')

# Save pred_maps
torch.save(pred_maps, 'pred_maps.pth')
