In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import os
from utils.common import (
    m2f_dataset_collate,
    m2f_extract_pred_maps_and_masks,
    set_seed,
    pixel_mean_std,
)
from utils.dataset_utils import (
    get_cadisv2_dataset,
    get_cataract1k_dataset,
    ZEISS_CATEGORIES,
)
from utils.medical_datasets import Mask2FormerDataset
from transformers import (
    Mask2FormerForUniversalSegmentation,
    SwinModel,
    SwinConfig,
    Mask2FormerConfig,
    AutoImageProcessor,
    Mask2FormerImageProcessor
)
from torch.utils.data import DataLoader
import evaluate
import torch.optim as optim
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from dotenv import load_dotenv
import wandb
from utils.augmentations import (train_transforms_noise,
                                 train_transforms_noise_no_distortion,
                                 train_transforms_blur)
from copy import deepcopy
import shutil
from utils.wandb_utils import log_table_of_images

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
set_seed(42) # seed everything

Random seed set as 42


In [4]:
NUM_CLASSES = len(ZEISS_CATEGORIES) - 3 + 1 # Remove class incremental add background !!!
SWIN_BACKBONE = "microsoft/swin-tiny-patch4-window7-224"#"microsoft/swin-large-patch4-window12-384"

# Download pretrained swin model
swin_model = SwinModel.from_pretrained(
    SWIN_BACKBONE, out_features=["stage1", "stage2", "stage3", "stage4"]
)
swin_config = SwinConfig.from_pretrained(
    SWIN_BACKBONE, out_features=["stage1", "stage2", "stage3", "stage4"]
)

# Create Mask2Former configuration based on Swin's configuration
mask2former_config = Mask2FormerConfig(
    backbone_config=swin_config, num_labels=NUM_CLASSES #, ignore_value=BG_VALUE
)

# Create the Mask2Former model with this configuration
model = Mask2FormerForUniversalSegmentation(mask2former_config)

# Reuse pretrained parameters
for swin_param, m2f_param in zip(
    swin_model.named_parameters(),
    model.model.pixel_level_module.encoder.named_parameters(),
):
    m2f_param_name = f"model.pixel_level_module.encoder.{m2f_param[0]}"

    if swin_param[0] == m2f_param[0]:
        model.state_dict()[m2f_param_name].copy_(swin_param[1])
        continue

    print(f"Not Matched: {m2f_param[0]} != {swin_param[0]}")



Not Matched: hidden_states_norms.stage1.weight != layernorm.weight
Not Matched: hidden_states_norms.stage1.bias != layernorm.bias


In [5]:
# Helper function to load datasets
def load_dataset(dataset_getter, data_path, domain_incremental):
    return dataset_getter(data_path, domain_incremental=domain_incremental)


# Helper function to create dataloaders for a dataset
def create_dataloaders(
    dataset, batch_size, shuffle, num_workers, drop_last, pin_memory, collate_fn
):
    return {
        "train": DataLoader(
            dataset["train"],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        ),
        "val": DataLoader(
            dataset["val"],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        ),
        "test": DataLoader(
            dataset["test"],
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            drop_last=False,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        ),
    }


# Load datasets
datasets = {
    "A": load_dataset(get_cadisv2_dataset, "../../storage/data/CaDISv2", True),
    "B": load_dataset(get_cataract1k_dataset, "../../storage/data/cataract-1k", True),
}
pixel_mean_A,pixel_std_A=pixel_mean_std(datasets["A"][0])
print("dataset A pixel mean:",pixel_mean_A,"pixel_std:",pixel_std_A)


# This time define the B train dataset such that it replays all the training samples from A
new_train = torch.utils.data.ConcatDataset([datasets["A"][0], datasets["B"][0]])

pixel_mean_B,pixel_std_B=pixel_mean_std(new_train)
print("dataset B pixel mean:",pixel_mean_B,"pixel_std:",pixel_std_B)


datasets["B"] = (new_train, datasets["B"][1], datasets["B"][2])

# Define preprocessor
swin_processor = AutoImageProcessor.from_pretrained(SWIN_BACKBONE)
m2f_preprocessor_A = Mask2FormerImageProcessor(
    reduce_labels=False,
    ignore_index=255,
    do_resize=False,
    do_rescale=False,
    do_normalize=True,
    image_std=pixel_std_A,
    image_mean=pixel_mean_A,
)

m2f_preprocessor_B = Mask2FormerImageProcessor(
    reduce_labels=False,
    ignore_index=255,
    do_resize=False,
    do_rescale=False,
    do_normalize=True,
    image_std=pixel_std_B,
    image_mean=pixel_mean_B,
)


# Create Mask2Former Datasets

m2f_datasets = {
    "A": {
        "train": Mask2FormerDataset(datasets["A"][0], m2f_preprocessor_A, transform=train_transforms_blur),
        "val": Mask2FormerDataset(datasets["A"][1], m2f_preprocessor_A),
        "test": Mask2FormerDataset(datasets["A"][2], m2f_preprocessor_A),
    },
    "B": {
        "train": Mask2FormerDataset(datasets["B"][0], m2f_preprocessor_B, transform=train_transforms_blur),
        "val": Mask2FormerDataset(datasets["B"][1], m2f_preprocessor_B),
        "test": Mask2FormerDataset(datasets["B"][2], m2f_preprocessor_B),
    },
}


# DataLoader parameters
N_WORKERS = 4
BATCH_SIZE = 16
SHUFFLE = True
DROP_LAST = True

dataloader_params = {
    "batch_size": BATCH_SIZE,
    "shuffle": SHUFFLE,
    "num_workers": N_WORKERS,
    "drop_last": DROP_LAST,
    "pin_memory": True,
    "collate_fn": m2f_dataset_collate,
}

# Create DataLoaders
dataloaders = {
    key: create_dataloaders(m2f_datasets[key], **dataloader_params)
    for key in m2f_datasets
}

print(dataloaders)

dataset A pixel mean: [0.57365126 0.34606295 0.19539679] pixel_std: [0.15933991 0.15584118 0.10485045]




dataset B pixel mean: [0.48640466 0.32646684 0.20089334] pixel_std: [0.24972845 0.19719756 0.15521158]
{'A': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7f4286157c50>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7f428b129130>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x7f428b4a5f70>}, 'B': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7f4286157800>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7f42861579b0>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x7f4286157a10>}}


In [6]:
m2f_preprocessor_A.reduce_labels, m2f_preprocessor_A.ignore_index

(False, 255)

In [7]:
# Check if CUDA is available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

BG_VALUE_255=255
base_run_name="M2F-Swin-Tiny-Train_Cadis_AugBlur"
new_run_name="M2F-Swin-Tiny-Replay-All_AugBlur"
project_name = "M2F_latest_aug"
user_or_team = "continual-learning-tum"

Using device: cuda


In [8]:
# Tensorboard setup
out_dir="outputs_aug/"
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
if not os.path.exists(out_dir+"runs"):
    os.makedirs(out_dir+"runs")
%load_ext tensorboard
%tensorboard --logdir outputs/runs

In [9]:
!CUDA_LAUNCH_BLOCKING=1

In [10]:
# Tensorboard logging
writer = SummaryWriter(log_dir=out_dir + "runs")

# Model checkpointing
model_dir = out_dir + "models/"
if not os.path.exists(model_dir):
    print("Store weights in: ", model_dir)
    os.makedirs(model_dir)

best_model_dir = model_dir + f"{base_run_name}/best_model/"
if not os.path.exists(best_model_dir):
    print("Store best model weights in: ", best_model_dir)
    os.makedirs(best_model_dir)
final_model_dir = model_dir + f"{base_run_name}/final_model/"
if not os.path.exists(final_model_dir):
    print("Store final model weights in: ", final_model_dir)
    os.makedirs(final_model_dir)

In [11]:
# WandB for team usage !!!!

wandb.login() # use this one if a different person is going to run the notebook
#wandb.login(relogin=False) # if the same person in the last run is going to run the notebook again


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mge85ket[0m ([33mcontinual-learning-tum[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

# First train on dataset A

In [12]:
# Training
NUM_EPOCHS = 200
LEARNING_RATE = 1e-4
LR_MULTIPLIER = 0.1
BACKBONE_LR = LEARNING_RATE * LR_MULTIPLIER
WEIGHT_DECAY = 0.05
PATIENCE=15
metric = evaluate.load("mean_iou") # mIoU will be used to pick the best performing model using val set
encoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.encoder")
]
decoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.decoder")
]
transformer_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.transformer_module")
]
class_prediction_params=[
    param
    for name, param in model.named_parameters() 
    if not name.startswith("model.pixel_level_module.encoder") and not name.startswith("model.transformer_module") and not name.startswith("model.pixel_level_module.decoder")
]
optimizer = optim.AdamW(
    [
        {"params": encoder_params, "lr": BACKBONE_LR},
        {"params": decoder_params},
        {"params": transformer_params},
        {"params": class_prediction_params}
    ],
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

scheduler = optim.lr_scheduler.PolynomialLR(
    optimizer, total_iters=NUM_EPOCHS, power=0.9
)

In [13]:
# WandB for team usage !!!!

wandb.login() # use this one if a different person is going to run the notebook
#wandb.login(relogin=False) # if the same person in the last run is going to run the notebook again

[34m[1mwandb[0m: Currently logged in as: [33mge85ket[0m ([33mcontinual-learning-tum[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [14]:
wandb.init(
    project=project_name,
    config={
        "learning_rate": LEARNING_RATE,
        "learning_rate_multiplier": LR_MULTIPLIER,
        "backbone_learning_rate": BACKBONE_LR,
        "learning_rate_scheduler": scheduler.__class__.__name__,
        "optimizer": optimizer.__class__.__name__,
        "backbone": SWIN_BACKBONE,
        "m2f_preprocessor": m2f_preprocessor_A.__dict__,
        "m2f_model_config": model.config
    },
    name=base_run_name,
    notes="M2F with tiny Swin backbone pretrained on ImageNet-1K. \
        Scenario: Train on A, Test on A"
)
print("wandb run id:",wandb.run.id)

wandb run id: oyvhaqf5


In [15]:
# Save the preprocessor
m2f_preprocessor_A.save_pretrained(model_dir + base_run_name)

['outputs_aug/models/M2F-Swin-Tiny-Train_Cadis_AugBlur/preprocessor_config.json']

In [16]:
# To avoid making stupid errors
CURR_TASK = "A"

# For storing the model
best_val_metric = -np.inf
best_model_weights=None # best model weights are stored here

# Move model to device
model.to(device)
counter=0
for epoch in range(NUM_EPOCHS):
    model.train()
    train_running_loss = 0.0
    val_running_loss = 0.0

    # Set up tqdm for the training loop
    train_loader = tqdm(
        dataloaders[CURR_TASK]["train"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Training"
    )

    for batch in train_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
       
        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss
        # Compute gradient and perform step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        train_running_loss += current_loss
        train_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_A
        )
        metric.add_batch(references=masks, predictions=pred_maps)

    # After compute the batches that were added are deleted
    temp_metrics = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
    )
    mean_train_iou=temp_metrics["mean_iou"]
        
    # Validation phase
    model.eval()
    val_loader = tqdm(
        dataloaders[CURR_TASK]["val"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Validation"
    )
    with torch.no_grad():
        for batch in val_loader:
            # Move everything to the device
            batch["pixel_values"] = batch["pixel_values"].to(device)
            batch["pixel_mask"] = batch["pixel_mask"].to(device)
            batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
            batch["class_labels"] = [
                entry.to(device) for entry in batch["class_labels"]
            ]
            # Compute output and loss
            outputs = model(**batch)

            loss = outputs.loss
            # Record losses
            current_loss = loss.item() * batch["pixel_values"].size(0)
            val_running_loss += current_loss
            val_loader.set_postfix(loss=f"{current_loss:.4f}")

            # Extract and compute metrics
            pred_maps, masks = m2f_extract_pred_maps_and_masks(
                batch, outputs, m2f_preprocessor_A
            )
            metric.add_batch(references=masks, predictions=pred_maps)
            

    # After compute the batches that were added are deleted
    mean_val_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
    )["mean_iou"]

    epoch_train_loss = train_running_loss / len(dataloaders[CURR_TASK]["train"].dataset)
    epoch_val_loss = val_running_loss / len(dataloaders[CURR_TASK]["val"].dataset)

    writer.add_scalar(f"Loss/train_{base_run_name}_{CURR_TASK}", epoch_train_loss, epoch + 1)
    writer.add_scalar(f"Loss/val_{base_run_name}_{CURR_TASK}", epoch_val_loss, epoch + 1)
    writer.add_scalar(f"mIoU/train_{base_run_name}_{CURR_TASK}", mean_train_iou, epoch + 1)
    writer.add_scalar(f"mIoU/val_{base_run_name}_{CURR_TASK}", mean_val_iou, epoch + 1)

    wandb.log({
        f"Loss/train_{CURR_TASK}": epoch_train_loss,
        f"Loss/val_{CURR_TASK}": epoch_val_loss,
        f"mIoU/train_{CURR_TASK}": mean_train_iou,
        f"mIoU/val_{CURR_TASK}": mean_val_iou
    })


    tqdm.write(
        f"Epoch {epoch + 1}/{NUM_EPOCHS}, Train Loss: {epoch_train_loss:.4f}, Train mIoU: {mean_train_iou:.4f}, Validation Loss: {epoch_val_loss:.4f}, Validation mIoU: {mean_val_iou:.4f}"
    )
    
    if mean_val_iou > best_val_metric:
        best_val_metric = mean_val_iou
        #model.save_pretrained(f"{best_model_dir}{CURR_TASK}/")
        best_model_weights = deepcopy(model.state_dict())
        counter=0
    else:
        counter+=1
        if counter == PATIENCE:
            print("Early stopping at epoch",epoch)
            break
            
os.makedirs(f"{best_model_dir}{CURR_TASK}/",exist_ok=True)
artifact = wandb.Artifact(f"best_model_{base_run_name}", type="model")
artifact.add_file(f"{best_model_dir}{CURR_TASK}/best_model_{base_run_name}.pth", torch.save(best_model_weights, f"{best_model_dir}{CURR_TASK}/best_model_{base_run_name}.pth"))
wandb.run.log_artifact(artifact)

if os.path.exists(model_dir + f"{base_run_name}"):
    shutil.rmtree(model_dir + f"{base_run_name}")

        


Epoch 1/200 Training: 100%|██████████| 221/221 [05:32<00:00,  1.50s/it, loss=704.7999] 
Epoch 1/200 Validation: 100%|██████████| 33/33 [00:39<00:00,  1.19s/it, loss=765.5924]


Epoch 1/200, Train Loss: 67.3616, Train mIoU: 0.1051, Validation Loss: 50.5375, Validation mIoU: 0.1664


Epoch 2/200 Training: 100%|██████████| 221/221 [05:28<00:00,  1.49s/it, loss=462.6539]
Epoch 2/200 Validation: 100%|██████████| 33/33 [00:39<00:00,  1.19s/it, loss=548.7513]


Epoch 2/200, Train Loss: 34.7298, Train mIoU: 0.2194, Validation Loss: 32.8301, Validation mIoU: 0.2439


Epoch 3/200 Training: 100%|██████████| 221/221 [05:01<00:00,  1.37s/it, loss=404.5788]
Epoch 3/200 Validation: 100%|██████████| 33/33 [00:37<00:00,  1.13s/it, loss=434.6761]


Epoch 3/200, Train Loss: 28.7677, Train mIoU: 0.3186, Validation Loss: 27.6610, Validation mIoU: 0.3750


Epoch 4/200 Training: 100%|██████████| 221/221 [04:45<00:00,  1.29s/it, loss=343.0320]
Epoch 4/200 Validation: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it, loss=322.5956]


Epoch 4/200, Train Loss: 25.5613, Train mIoU: 0.4206, Validation Loss: 24.7237, Validation mIoU: 0.4505


Epoch 5/200 Training: 100%|██████████| 221/221 [04:58<00:00,  1.35s/it, loss=366.2369]
Epoch 5/200 Validation: 100%|██████████| 33/33 [00:34<00:00,  1.05s/it, loss=370.9257]


Epoch 5/200, Train Loss: 22.9067, Train mIoU: 0.4816, Validation Loss: 21.9428, Validation mIoU: 0.5173


Epoch 6/200 Training: 100%|██████████| 221/221 [05:15<00:00,  1.43s/it, loss=343.9352]
Epoch 6/200 Validation: 100%|██████████| 33/33 [00:33<00:00,  1.02s/it, loss=392.7356]


Epoch 6/200, Train Loss: 21.1706, Train mIoU: 0.5291, Validation Loss: 21.9580, Validation mIoU: 0.5172


Epoch 7/200 Training: 100%|██████████| 221/221 [05:10<00:00,  1.41s/it, loss=330.4022]
Epoch 7/200 Validation: 100%|██████████| 33/33 [00:30<00:00,  1.10it/s, loss=287.5240]


Epoch 7/200, Train Loss: 19.6055, Train mIoU: 0.5596, Validation Loss: 20.7170, Validation mIoU: 0.5328


Epoch 8/200 Training: 100%|██████████| 221/221 [04:57<00:00,  1.35s/it, loss=263.0647]
Epoch 8/200 Validation: 100%|██████████| 33/33 [00:29<00:00,  1.10it/s, loss=275.9343]


Epoch 8/200, Train Loss: 18.7459, Train mIoU: 0.5736, Validation Loss: 19.8039, Validation mIoU: 0.5534


Epoch 9/200 Training: 100%|██████████| 221/221 [04:47<00:00,  1.30s/it, loss=307.2384]
Epoch 9/200 Validation: 100%|██████████| 33/33 [00:29<00:00,  1.11it/s, loss=301.8675]


Epoch 9/200, Train Loss: 17.6494, Train mIoU: 0.5986, Validation Loss: 19.7995, Validation mIoU: 0.5849


Epoch 10/200 Training: 100%|██████████| 221/221 [04:52<00:00,  1.32s/it, loss=269.2785]
Epoch 10/200 Validation: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it, loss=263.1896]


Epoch 10/200, Train Loss: 16.8887, Train mIoU: 0.6150, Validation Loss: 20.2517, Validation mIoU: 0.5386


Epoch 11/200 Training: 100%|██████████| 221/221 [05:00<00:00,  1.36s/it, loss=257.3792]
Epoch 11/200 Validation: 100%|██████████| 33/33 [00:29<00:00,  1.12it/s, loss=299.0481]


Epoch 11/200, Train Loss: 16.0872, Train mIoU: 0.6367, Validation Loss: 18.9255, Validation mIoU: 0.5439


Epoch 12/200 Training: 100%|██████████| 221/221 [04:53<00:00,  1.33s/it, loss=219.8190]
Epoch 12/200 Validation: 100%|██████████| 33/33 [00:30<00:00,  1.08it/s, loss=355.9763]


Epoch 12/200, Train Loss: 15.6515, Train mIoU: 0.6476, Validation Loss: 18.0997, Validation mIoU: 0.6326


Epoch 13/200 Training: 100%|██████████| 221/221 [04:50<00:00,  1.32s/it, loss=243.6143]
Epoch 13/200 Validation: 100%|██████████| 33/33 [00:34<00:00,  1.04s/it, loss=332.8428]


Epoch 13/200, Train Loss: 15.1475, Train mIoU: 0.6593, Validation Loss: 18.7931, Validation mIoU: 0.6116


Epoch 14/200 Training: 100%|██████████| 221/221 [04:58<00:00,  1.35s/it, loss=226.5961]
Epoch 14/200 Validation: 100%|██████████| 33/33 [00:32<00:00,  1.03it/s, loss=341.9105]


Epoch 14/200, Train Loss: 14.6902, Train mIoU: 0.6649, Validation Loss: 19.9125, Validation mIoU: 0.6209


Epoch 15/200 Training: 100%|██████████| 221/221 [05:07<00:00,  1.39s/it, loss=273.7238]
Epoch 15/200 Validation: 100%|██████████| 33/33 [00:35<00:00,  1.06s/it, loss=276.5111]


Epoch 15/200, Train Loss: 14.0534, Train mIoU: 0.6831, Validation Loss: 19.2335, Validation mIoU: 0.5780


Epoch 16/200 Training: 100%|██████████| 221/221 [04:49<00:00,  1.31s/it, loss=229.4294]
Epoch 16/200 Validation: 100%|██████████| 33/33 [00:32<00:00,  1.01it/s, loss=334.8688]


Epoch 16/200, Train Loss: 13.7361, Train mIoU: 0.6979, Validation Loss: 18.4203, Validation mIoU: 0.5739


Epoch 17/200 Training: 100%|██████████| 221/221 [05:02<00:00,  1.37s/it, loss=224.8513]
Epoch 17/200 Validation: 100%|██████████| 33/33 [00:34<00:00,  1.04s/it, loss=452.4117]


Epoch 17/200, Train Loss: 13.2985, Train mIoU: 0.7035, Validation Loss: 20.0991, Validation mIoU: 0.5710


Epoch 18/200 Training: 100%|██████████| 221/221 [04:52<00:00,  1.33s/it, loss=216.9509]
Epoch 18/200 Validation: 100%|██████████| 33/33 [00:33<00:00,  1.01s/it, loss=261.4957]


Epoch 18/200, Train Loss: 13.1116, Train mIoU: 0.7123, Validation Loss: 19.2272, Validation mIoU: 0.6088


Epoch 19/200 Training: 100%|██████████| 221/221 [04:59<00:00,  1.36s/it, loss=195.7676]
Epoch 19/200 Validation: 100%|██████████| 33/33 [00:30<00:00,  1.08it/s, loss=304.0920]


Epoch 19/200, Train Loss: 12.7216, Train mIoU: 0.7166, Validation Loss: 19.8469, Validation mIoU: 0.5650


Epoch 20/200 Training: 100%|██████████| 221/221 [04:52<00:00,  1.32s/it, loss=206.4666]
Epoch 20/200 Validation: 100%|██████████| 33/33 [00:33<00:00,  1.03s/it, loss=313.8513]


Epoch 20/200, Train Loss: 12.4162, Train mIoU: 0.7261, Validation Loss: 18.7710, Validation mIoU: 0.6316


Epoch 21/200 Training: 100%|██████████| 221/221 [04:44<00:00,  1.29s/it, loss=170.1489]
Epoch 21/200 Validation: 100%|██████████| 33/33 [00:29<00:00,  1.11it/s, loss=308.4799]


Epoch 21/200, Train Loss: 12.1035, Train mIoU: 0.7384, Validation Loss: 18.8181, Validation mIoU: 0.5872


Epoch 22/200 Training: 100%|██████████| 221/221 [05:00<00:00,  1.36s/it, loss=161.1959]
Epoch 22/200 Validation: 100%|██████████| 33/33 [00:30<00:00,  1.07it/s, loss=242.7206]


Epoch 22/200, Train Loss: 11.8035, Train mIoU: 0.7503, Validation Loss: 18.0007, Validation mIoU: 0.5843


Epoch 23/200 Training: 100%|██████████| 221/221 [04:52<00:00,  1.32s/it, loss=177.0575]
Epoch 23/200 Validation: 100%|██████████| 33/33 [00:32<00:00,  1.01it/s, loss=360.6679]


Epoch 23/200, Train Loss: 11.5736, Train mIoU: 0.7525, Validation Loss: 19.4243, Validation mIoU: 0.6029


Epoch 24/200 Training: 100%|██████████| 221/221 [04:57<00:00,  1.35s/it, loss=190.5688]
Epoch 24/200 Validation: 100%|██████████| 33/33 [00:31<00:00,  1.04it/s, loss=362.2864]


Epoch 24/200, Train Loss: 11.4478, Train mIoU: 0.7571, Validation Loss: 19.7169, Validation mIoU: 0.5727


Epoch 25/200 Training: 100%|██████████| 221/221 [04:42<00:00,  1.28s/it, loss=153.5436]
Epoch 25/200 Validation: 100%|██████████| 33/33 [00:33<00:00,  1.02s/it, loss=270.3229]


Epoch 25/200, Train Loss: 11.0825, Train mIoU: 0.7657, Validation Loss: 19.0948, Validation mIoU: 0.5966


Epoch 26/200 Training: 100%|██████████| 221/221 [04:52<00:00,  1.32s/it, loss=162.4758]
Epoch 26/200 Validation: 100%|██████████| 33/33 [00:28<00:00,  1.14it/s, loss=327.7599]


Epoch 26/200, Train Loss: 11.0449, Train mIoU: 0.7540, Validation Loss: 19.6666, Validation mIoU: 0.6107


Epoch 27/200 Training: 100%|██████████| 221/221 [04:45<00:00,  1.29s/it, loss=184.5747]
Epoch 27/200 Validation: 100%|██████████| 33/33 [00:34<00:00,  1.04s/it, loss=323.0818]


Epoch 27/200, Train Loss: 10.6907, Train mIoU: 0.7743, Validation Loss: 19.3291, Validation mIoU: 0.6421


Epoch 28/200 Training: 100%|██████████| 221/221 [04:49<00:00,  1.31s/it, loss=149.5404]
Epoch 28/200 Validation: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it, loss=237.8805]


Epoch 28/200, Train Loss: 10.6781, Train mIoU: 0.7675, Validation Loss: 19.6313, Validation mIoU: 0.6002


Epoch 29/200 Training: 100%|██████████| 221/221 [04:50<00:00,  1.32s/it, loss=161.0405]
Epoch 29/200 Validation: 100%|██████████| 33/33 [00:29<00:00,  1.13it/s, loss=347.0563]


Epoch 29/200, Train Loss: 10.4069, Train mIoU: 0.7791, Validation Loss: 20.4786, Validation mIoU: 0.5710


Epoch 30/200 Training: 100%|██████████| 221/221 [04:59<00:00,  1.35s/it, loss=158.3876]
Epoch 30/200 Validation: 100%|██████████| 33/33 [00:30<00:00,  1.08it/s, loss=314.8263]


Epoch 30/200, Train Loss: 10.0468, Train mIoU: 0.7920, Validation Loss: 19.4172, Validation mIoU: 0.6075


Epoch 31/200 Training: 100%|██████████| 221/221 [04:45<00:00,  1.29s/it, loss=155.5565]
Epoch 31/200 Validation: 100%|██████████| 33/33 [00:29<00:00,  1.10it/s, loss=435.9793]


Epoch 31/200, Train Loss: 10.0474, Train mIoU: 0.7904, Validation Loss: 19.6269, Validation mIoU: 0.6249


Epoch 32/200 Training: 100%|██████████| 221/221 [04:47<00:00,  1.30s/it, loss=175.4153]
Epoch 32/200 Validation: 100%|██████████| 33/33 [00:32<00:00,  1.00it/s, loss=286.8088]


Epoch 32/200, Train Loss: 9.7849, Train mIoU: 0.7974, Validation Loss: 19.8700, Validation mIoU: 0.5857


Epoch 33/200 Training: 100%|██████████| 221/221 [04:57<00:00,  1.35s/it, loss=128.6396]
Epoch 33/200 Validation: 100%|██████████| 33/33 [00:40<00:00,  1.21s/it, loss=256.4563]


Epoch 33/200, Train Loss: 9.7581, Train mIoU: 0.7957, Validation Loss: 18.7916, Validation mIoU: 0.6293


Epoch 34/200 Training: 100%|██████████| 221/221 [04:44<00:00,  1.29s/it, loss=143.9128]
Epoch 34/200 Validation: 100%|██████████| 33/33 [00:28<00:00,  1.17it/s, loss=376.7933]


Epoch 34/200, Train Loss: 9.5543, Train mIoU: 0.7994, Validation Loss: 19.7944, Validation mIoU: 0.5894


Epoch 35/200 Training: 100%|██████████| 221/221 [04:44<00:00,  1.29s/it, loss=125.1175]
Epoch 35/200 Validation: 100%|██████████| 33/33 [00:31<00:00,  1.04it/s, loss=273.5570]


Epoch 35/200, Train Loss: 9.4178, Train mIoU: 0.8082, Validation Loss: 19.7769, Validation mIoU: 0.6081


Epoch 36/200 Training: 100%|██████████| 221/221 [04:49<00:00,  1.31s/it, loss=161.5038]
Epoch 36/200 Validation: 100%|██████████| 33/33 [00:30<00:00,  1.07it/s, loss=359.7872]


Epoch 36/200, Train Loss: 9.2922, Train mIoU: 0.8015, Validation Loss: 22.6949, Validation mIoU: 0.5615


Epoch 37/200 Training: 100%|██████████| 221/221 [04:54<00:00,  1.33s/it, loss=130.8067]
Epoch 37/200 Validation: 100%|██████████| 33/33 [00:31<00:00,  1.06it/s, loss=365.8166]


Epoch 37/200, Train Loss: 9.1335, Train mIoU: 0.8076, Validation Loss: 20.8718, Validation mIoU: 0.6488


Epoch 38/200 Training: 100%|██████████| 221/221 [04:50<00:00,  1.31s/it, loss=141.2131]
Epoch 38/200 Validation: 100%|██████████| 33/33 [00:32<00:00,  1.03it/s, loss=390.0930]


Epoch 38/200, Train Loss: 9.0162, Train mIoU: 0.8118, Validation Loss: 21.1939, Validation mIoU: 0.6000


Epoch 39/200 Training: 100%|██████████| 221/221 [04:45<00:00,  1.29s/it, loss=135.0854]
Epoch 39/200 Validation: 100%|██████████| 33/33 [00:31<00:00,  1.03it/s, loss=428.8681]


Epoch 39/200, Train Loss: 8.9392, Train mIoU: 0.8155, Validation Loss: 20.6204, Validation mIoU: 0.6078


Epoch 40/200 Training: 100%|██████████| 221/221 [04:48<00:00,  1.31s/it, loss=122.1686]
Epoch 40/200 Validation: 100%|██████████| 33/33 [00:26<00:00,  1.24it/s, loss=271.4616]


Epoch 40/200, Train Loss: 8.7189, Train mIoU: 0.8207, Validation Loss: 20.4847, Validation mIoU: 0.6100


Epoch 41/200 Training: 100%|██████████| 221/221 [04:51<00:00,  1.32s/it, loss=163.8227]
Epoch 41/200 Validation: 100%|██████████| 33/33 [00:27<00:00,  1.18it/s, loss=300.2075]


Epoch 41/200, Train Loss: 8.6612, Train mIoU: 0.8215, Validation Loss: 21.4255, Validation mIoU: 0.5986


Epoch 42/200 Training: 100%|██████████| 221/221 [05:12<00:00,  1.41s/it, loss=117.1915]
Epoch 42/200 Validation: 100%|██████████| 33/33 [00:35<00:00,  1.07s/it, loss=401.2606]


Epoch 42/200, Train Loss: 8.4651, Train mIoU: 0.8250, Validation Loss: 21.2698, Validation mIoU: 0.6020


Epoch 43/200 Training: 100%|██████████| 221/221 [05:07<00:00,  1.39s/it, loss=123.3194]
Epoch 43/200 Validation: 100%|██████████| 33/33 [00:31<00:00,  1.06it/s, loss=466.0887]


Epoch 43/200, Train Loss: 8.3502, Train mIoU: 0.8304, Validation Loss: 21.1995, Validation mIoU: 0.6113


Epoch 44/200 Training: 100%|██████████| 221/221 [04:49<00:00,  1.31s/it, loss=124.5013]
Epoch 44/200 Validation: 100%|██████████| 33/33 [00:34<00:00,  1.03s/it, loss=431.4523]


Epoch 44/200, Train Loss: 8.3658, Train mIoU: 0.8275, Validation Loss: 21.9675, Validation mIoU: 0.5996


Epoch 45/200 Training: 100%|██████████| 221/221 [04:42<00:00,  1.28s/it, loss=133.2351]
Epoch 45/200 Validation: 100%|██████████| 33/33 [00:30<00:00,  1.09it/s, loss=336.2133]


Epoch 45/200, Train Loss: 8.2847, Train mIoU: 0.8331, Validation Loss: 21.9501, Validation mIoU: 0.5924


Epoch 46/200 Training: 100%|██████████| 221/221 [04:45<00:00,  1.29s/it, loss=154.6065]
Epoch 46/200 Validation: 100%|██████████| 33/33 [00:33<00:00,  1.01s/it, loss=361.6513]


Epoch 46/200, Train Loss: 8.3729, Train mIoU: 0.8306, Validation Loss: 21.1495, Validation mIoU: 0.6274


Epoch 47/200 Training: 100%|██████████| 221/221 [04:51<00:00,  1.32s/it, loss=158.5703]
Epoch 47/200 Validation: 100%|██████████| 33/33 [00:28<00:00,  1.17it/s, loss=298.2366]


Epoch 47/200, Train Loss: 8.1476, Train mIoU: 0.8338, Validation Loss: 22.2162, Validation mIoU: 0.6056


Epoch 48/200 Training: 100%|██████████| 221/221 [04:46<00:00,  1.30s/it, loss=120.2198]
Epoch 48/200 Validation: 100%|██████████| 33/33 [00:33<00:00,  1.02s/it, loss=366.9691]


Epoch 48/200, Train Loss: 8.0291, Train mIoU: 0.8455, Validation Loss: 22.6583, Validation mIoU: 0.5821


Epoch 49/200 Training: 100%|██████████| 221/221 [04:41<00:00,  1.27s/it, loss=117.8744]
Epoch 49/200 Validation: 100%|██████████| 33/33 [00:32<00:00,  1.02it/s, loss=378.2524]


Epoch 49/200, Train Loss: 7.9203, Train mIoU: 0.8434, Validation Loss: 21.9283, Validation mIoU: 0.5882


Epoch 50/200 Training: 100%|██████████| 221/221 [04:47<00:00,  1.30s/it, loss=142.8060]
Epoch 50/200 Validation: 100%|██████████| 33/33 [00:27<00:00,  1.22it/s, loss=370.8741]


Epoch 50/200, Train Loss: 7.7965, Train mIoU: 0.8387, Validation Loss: 23.0639, Validation mIoU: 0.5808


Epoch 51/200 Training: 100%|██████████| 221/221 [04:47<00:00,  1.30s/it, loss=126.2721]
Epoch 51/200 Validation: 100%|██████████| 33/33 [00:27<00:00,  1.22it/s, loss=401.6805]


Epoch 51/200, Train Loss: 7.7423, Train mIoU: 0.8426, Validation Loss: 22.1923, Validation mIoU: 0.5605


Epoch 52/200 Training: 100%|██████████| 221/221 [04:48<00:00,  1.31s/it, loss=165.8905]
Epoch 52/200 Validation: 100%|██████████| 33/33 [00:29<00:00,  1.12it/s, loss=312.8188]


Epoch 52/200, Train Loss: 7.7437, Train mIoU: 0.8392, Validation Loss: 20.9566, Validation mIoU: 0.6090
Early stopping at epoch 51


## Test results on A

In [14]:
# Load best model and evaluate on test
#model = Mask2FormerForUniversalSegmentation.from_pretrained(f"{best_model_dir}{CURR_TASK}/").to(device)

# Load pretrained on Cadis from naive forgetting 
#model = Mask2FormerForUniversalSegmentation.from_pretrained(f"/notebooks/continual-learning/outputs/models/{base_model_name}/best_model/A").to(device)

# Construct the artifact path
artifact_path = f"{user_or_team}/{project_name}/best_model_{base_run_name}:latest"

# Load from W&B
api = wandb.Api()
artifact=api.artifact(artifact_path)
model_dir=artifact.download()
model_state_dict_path = os.path.join(model_dir, f"best_model_{base_run_name}.pth" )
model_state_dict = torch.load(model_state_dict_path)
model = Mask2FormerForUniversalSegmentation(mask2former_config)
model.load_state_dict(model_state_dict)
model.to(device)

[34m[1mwandb[0m: Downloading large artifact best_model_M2F-Swin-Tiny-Train_Cadis_AugBlur:latest, 181.31MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.7


Mask2FormerForUniversalSegmentation(
  (model): Mask2FormerModel(
    (pixel_level_module): Mask2FormerPixelLevelModule(
      (encoder): SwinBackbone(
        (embeddings): SwinEmbeddings(
          (patch_embeddings): SwinPatchEmbeddings(
            (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
          )
          (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (encoder): SwinEncoder(
          (layers): ModuleList(
            (0): SwinStage(
              (blocks): ModuleList(
                (0-1): 2 x SwinLayer(
                  (layernorm_before): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
                  (attention): SwinAttention(
                    (self): SwinSelfAttention(
                      (query): Linear(in_features=96, out_features=96, bias=True)
                      (key): Linear(in_features=96, out_features=96, bias=True)
                      (value

In [15]:
model.eval()
test_running_loss = 0
CURR_TASK="A"
test_loader = tqdm(dataloaders[CURR_TASK]["test"], desc="Test loop")

BATCH_INDEX = 0
table = wandb.Table(columns=["ID", "Image"])
with torch.no_grad():
    for batch in test_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss
        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        test_running_loss += current_loss
        test_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_A
        )
        metric.add_batch(references=masks, predictions=pred_maps)
        if BATCH_INDEX <5:
            # Visualize
            log_table_of_images(
                table, # common table for all batches
                batch["pixel_values"],
                pixel_mean_A, # remove normalization
                pixel_std_A, # remove normalization
                pred_maps,
                masks,
                BATCH_INDEX, # correct indexing in table
            )
            BATCH_INDEX += 1
# Log table
wandb.log({f"{CURR_TASK}_TEST_AFTER_TRAINING_A": table})

# After compute the batches that were added are deleted
test_metrics_A = metric.compute(
    num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
)
mean_test_iou = test_metrics_A["mean_iou"]
final_test_loss = test_running_loss / len(dataloaders[CURR_TASK]["test"].dataset)
"""wandb.log({
    f"Loss/test_{CURR_TASK}": final_test_loss,
    f"mIoU/test_{CURR_TASK}": mean_test_iou
})"""
print(f"Test Loss: {final_test_loss:.4f}, Test mIoU: {mean_test_iou:.4f}")

Test loop: 100%|██████████| 37/37 [00:40<00:00,  1.10s/it, loss=362.5905]


Test Loss: 28.3386, Test mIoU: 0.6400


## Test results on B

In [16]:
model.eval()
test_running_loss = 0
CURR_TASK="B"
test_loader = tqdm(dataloaders[CURR_TASK]["test"], desc="Test loop")

BATCH_INDEX = 0
table = wandb.Table(columns=["ID", "Image"])
with torch.no_grad():
    for batch in test_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss
        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        test_running_loss += current_loss
        test_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_B
        )
        metric.add_batch(references=masks, predictions=pred_maps)
        if BATCH_INDEX <5:
            # Visualize
            log_table_of_images(
                table, # common table for all batches
                batch["pixel_values"],
                pixel_mean_B, # remove normalization
                pixel_std_B, # remove normalization
                pred_maps,
                masks,
                BATCH_INDEX, # correct indexing in table
            )
            BATCH_INDEX += 1

# Log table
wandb.log({f"{CURR_TASK}_TEST_AFTER_TRAINING_A": table})

# After compute the batches that were added are deleted
test_metrics_B_before = metric.compute(
    num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
)
mean_test_iou = test_metrics_B_before["mean_iou"]
final_test_loss = test_running_loss / len(dataloaders[CURR_TASK]["test"].dataset)
"""wandb.log({
    f"Loss/test_{CURR_TASK}": final_test_loss,
    f"mIoU/test_{CURR_TASK}": mean_test_iou
})"""
print(f"Test Loss: {final_test_loss:.4f}, Test mIoU: {mean_test_iou:.4f}")
#wandb.finish()

Test loop: 100%|██████████| 15/15 [00:21<00:00,  1.43s/it, loss=152.9794]


Test Loss: 48.3620, Test mIoU: 0.2589


# Now train on B and forget A

In [17]:
# Training
NUM_EPOCHS = 200
LEARNING_RATE = 1e-4
LR_MULTIPLIER = 0.1
BACKBONE_LR = LEARNING_RATE * LR_MULTIPLIER
WEIGHT_DECAY = 0.05
PATIENCE=15
model.train()
metric = evaluate.load("mean_iou")
encoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.encoder")
]
decoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.decoder")
]
transformer_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.transformer_module")
]
class_prediction_params=[
    param
    for name, param in model.named_parameters() 
    if not name.startswith("model.pixel_level_module.encoder") and not name.startswith("model.transformer_module") and not name.startswith("model.pixel_level_module.decoder")
]
optimizer = optim.AdamW(
    [
        {"params": encoder_params, "lr": BACKBONE_LR},
        {"params": decoder_params},
        {"params": transformer_params},
        {"params": class_prediction_params}
    ],
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

scheduler = optim.lr_scheduler.PolynomialLR(
    optimizer, total_iters=NUM_EPOCHS, power=0.9
)

In [None]:
# WandB for team usage !!!!

wandb.login() # use this one if a different person is going to run the notebook
#wandb.login(relogin=False) # if the same person in the last run is going to run the notebook again

In [13]:
wandb.init(
    project=project_name,
    config={
        "learning_rate": LEARNING_RATE,
        "learning_rate_multiplier": LR_MULTIPLIER,
        "backbone_learning_rate": BACKBONE_LR,
        "learning_rate_scheduler": scheduler.__class__.__name__,
        "optimizer": optimizer.__class__.__name__,
        "backbone": SWIN_BACKBONE,
        "m2f_preprocessor": m2f_preprocessor_B.__dict__,
        "m2f_model_config": model.config
    },
    name=new_run_name,
    notes="M2F with tiny Swin backbone pretrained on ImageNet-1K. \
        Scenario: Pretrained on A, Train on A + B naive finetuning (replay all), Test forgetting on A"
)

print("wandb run id:",wandb.run.id)

# Tensorboard logging
writer = SummaryWriter(log_dir=out_dir + "runs")

# Model checkpointing
model_dir = out_dir + "models/"
if not os.path.exists(model_dir):
    print("Store weights in: ", model_dir)
    os.makedirs(model_dir)

best_model_dir = model_dir + f"{new_run_name}/best_model/"
if not os.path.exists(best_model_dir):
    print("Store best model weights in: ", best_model_dir)
    os.makedirs(best_model_dir)
final_model_dir = model_dir + f"{new_run_name}/final_model/"
if not os.path.exists(final_model_dir):
    print("Store final model weights in: ", final_model_dir)
    os.makedirs(final_model_dir)

wandb run id: 8ewtur1y


In [None]:
# Save the preprocessor
m2f_preprocessor_B.save_pretrained(model_dir + new_run_name)

In [18]:
# To avoid making stupid errors
CURR_TASK = "B"
model_path_second = f"{best_model_dir}A+{CURR_TASK}/"


# For storing the model
best_val_metric = -np.inf
best_model_weights=None # best model weights are stored here

# Move model to device
model.to(device)
counter=0
for epoch in range(NUM_EPOCHS):
    model.train()
    train_running_loss = 0.0
    val_running_loss = 0.0

    # Set up tqdm for the training loop
    train_loader = tqdm(
        dataloaders[CURR_TASK]["train"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Training"
    )

    for batch in train_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]

        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss

        # Compute gradient and perform step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        train_running_loss += current_loss
        train_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_B
        )
        metric.add_batch(references=masks, predictions=pred_maps)
        

    # After compute the batches that were added are deleted
    mean_train_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
    )["mean_iou"]

    # Validation phase
    model.eval()
    val_loader = tqdm(
        dataloaders[CURR_TASK]["val"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Validation"
    )
    with torch.no_grad():
        for batch in val_loader:
            # Move everything to the device
            batch["pixel_values"] = batch["pixel_values"].to(device)
            batch["pixel_mask"] = batch["pixel_mask"].to(device)
            batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
            batch["class_labels"] = [
                entry.to(device) for entry in batch["class_labels"]
            ]
            # Compute output and loss
            outputs = model(**batch)

            loss = outputs.loss
            # Record losses
            current_loss = loss.item() * batch["pixel_values"].size(0)
            val_running_loss += current_loss
            val_loader.set_postfix(loss=f"{current_loss:.4f}")

            # Extract and compute metrics
            pred_maps, masks = m2f_extract_pred_maps_and_masks(
                batch, outputs, m2f_preprocessor_B
            )
            metric.add_batch(references=masks, predictions=pred_maps)
            

    # After compute the batches that were added are deleted
    mean_val_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
    )["mean_iou"]

    epoch_train_loss = train_running_loss / len(dataloaders[CURR_TASK]["train"].dataset)
    epoch_val_loss = val_running_loss / len(dataloaders[CURR_TASK]["val"].dataset)

    writer.add_scalar(f"Loss/train_{new_run_name}_{CURR_TASK}", epoch_train_loss, epoch + 1)
    writer.add_scalar(f"Loss/val_{new_run_name}_{CURR_TASK}", epoch_val_loss, epoch + 1)
    writer.add_scalar(f"mIoU/train_{new_run_name}_{CURR_TASK}", mean_train_iou, epoch + 1)
    writer.add_scalar(f"mIoU/val_{new_run_name}_{CURR_TASK}", mean_val_iou, epoch + 1)

    wandb.log({
        f"Loss/train_A+{CURR_TASK}": epoch_train_loss,
        f"Loss/val_A+{CURR_TASK}": epoch_val_loss,
        f"mIoU/train_A+{CURR_TASK}": mean_train_iou,
        f"mIoU/val_A+{CURR_TASK}": mean_val_iou
    })

    tqdm.write(
        f"Epoch {epoch + 1}/{NUM_EPOCHS}, Train Loss: {epoch_train_loss:.4f}, Train mIoU: {mean_train_iou:.4f}, Validation Loss: {epoch_val_loss:.4f}, Validation mIoU: {mean_val_iou:.4f}"
    )
    if mean_val_iou > best_val_metric:
        best_val_metric = mean_val_iou
        #model.save_pretrained(model_path_second)
        best_model_weights = deepcopy(model.state_dict())
        counter=0
    else:
        counter+=1
        if counter == PATIENCE:
            print("Early stopping at epoch",epoch)
            break
            

os.makedirs(f"{model_path_second}",exist_ok=True)
artifact = wandb.Artifact(f"best_model_{new_run_name}", type="model")
artifact.add_file(f"{model_path_second}/best_model_{new_run_name}.pth", torch.save(best_model_weights, f"{model_path_second}/best_model_{new_run_name}.pth"))
wandb.run.log_artifact(artifact)

if os.path.exists(model_dir + f"{new_run_name}"):
    shutil.rmtree(model_dir + f"{new_run_name}")


Epoch 1/200 Training: 100%|██████████| 334/334 [07:27<00:00,  1.34s/it, loss=190.2361]
Epoch 1/200 Validation: 100%|██████████| 14/14 [00:13<00:00,  1.01it/s, loss=272.9021]


Epoch 1/200, Train Loss: 13.6459, Train mIoU: 0.6826, Validation Loss: 14.3665, Validation mIoU: 0.5382


Epoch 2/200 Training:  80%|████████  | 268/334 [06:03<01:29,  1.36s/it, loss=199.1889]


ValueError: matrix contains invalid numeric entries

In [None]:
print("training done")

## Test results on B first

In [None]:
# Load best model and evaluate on test
#model = Mask2FormerForUniversalSegmentation.from_pretrained(model_path_second).to(device)

# Construct the artifact path
artifact_path = f"{user_or_team}/{project_name}/best_model_{new_run_name}:latest"

# Load from W&B
api = wandb.Api()
artifact=api.artifact(artifact_path)
model_dir=artifact.download()
model_state_dict_path = os.path.join(model_dir, f"best_model_{new_run_name}.pth" )
model_state_dict = torch.load(model_state_dict_path)
model = Mask2FormerForUniversalSegmentation(mask2former_config)
model.load_state_dict(model_state_dict)
model.to(device)

In [None]:
model.eval()
test_running_loss = 0
CURR_TASK="B"
test_loader = tqdm(dataloaders[CURR_TASK]["test"], desc="Test loop")
BATCH_INDEX = 0
table = wandb.Table(columns=["ID", "Image"])
with torch.no_grad():
    for batch in test_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss
        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        test_running_loss += current_loss
        test_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_B
        )
        metric.add_batch(references=masks, predictions=pred_maps)
        if BATCH_INDEX <5:
            # Visualize
            log_table_of_images(
                table, # common table for all batches
                batch["pixel_values"],
                pixel_mean_B, # remove normalization
                pixel_std_B, # remove normalization
                pred_maps,
                masks,
                BATCH_INDEX, # correct indexing in table
            )
            BATCH_INDEX += 1
# Log table
wandb.log({f"{CURR_TASK}_TEST_AFTER_TRAINING_B": table})

# After compute the batches that were added are deleted
test_metrics_B = metric.compute(
    num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
)
mean_test_iou = test_metrics_B["mean_iou"]
final_test_loss = test_running_loss / len(dataloaders[CURR_TASK]["test"].dataset)
wandb.log({
    f"Loss/test_{CURR_TASK}": final_test_loss,
    f"mIoU/test_{CURR_TASK}": mean_test_iou
})
print(f"Test Loss: {final_test_loss:.4f}, Test mIoU: {mean_test_iou:.4f}")

## Test results on A after training on B

In [None]:
# To avoid making stupid errors
CURR_TASK = "A"

model.eval()
test_running_loss = 0
test_loader = tqdm(dataloaders[CURR_TASK]["test"], desc="Test loop")
BATCH_INDEX = 0
table = wandb.Table(columns=["ID", "Image"])
with torch.no_grad():
    for batch in test_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss
        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        test_running_loss += current_loss
        test_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_A
        )
        metric.add_batch(references=masks, predictions=pred_maps)
        if BATCH_INDEX <5:
            # Visualize
            log_table_of_images(
                table, # common table for all batches
                batch["pixel_values"],
                pixel_mean_A, # remove normalization
                pixel_std_A, # remove normalization
                pred_maps,
                masks,
                BATCH_INDEX, # correct indexing in table
            )
            BATCH_INDEX += 1
# Log table
wandb.log({f"{CURR_TASK}_TEST_AFTER_TRAINING_B": table})

# After compute the batches that were added are deleted
test_metrics_forgetting_A = metric.compute(
    num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
)
mean_test_iou = test_metrics_forgetting_A["mean_iou"]
final_test_loss = test_running_loss / len(dataloaders[CURR_TASK]["test"].dataset)
wandb.log({
    f"Loss/test_replay_all_{CURR_TASK}": final_test_loss,
    f"mIoU/test_replay_all_{CURR_TASK}": mean_test_iou
})
print(f"Test Loss: {final_test_loss:.4f}, Test mIoU: {mean_test_iou:.4f}")


In [None]:
# Collect overall mIoU
mIoU_A_before = test_metrics_A["mean_iou"]
mIoU_B_before=test_metrics_B_before["mean_iou"]
mIoU_forgetting_A = test_metrics_forgetting_A["mean_iou"]
mIoU_B = test_metrics_B["mean_iou"]

# Collect per category mIoU
per_category_mIoU_A_before = np.array(test_metrics_A["per_category_iou"])
per_category_mIoU_A = np.array(test_metrics_forgetting_A["per_category_iou"])
per_category_mIoU_B = np.array(test_metrics_B["per_category_iou"])
per_category_mIoU_B_before=np.array(test_metrics_B_before["per_category_iou"])

# Average learning accuracies (mIoUs)
avg_learning_acc = (mIoU_A_before + mIoU_B) / 2
per_category_avg_learning_acc = (per_category_mIoU_A_before + per_category_mIoU_B) / 2

# Forgetting
total_forgetting = mIoU_A_before - mIoU_forgetting_A
per_category_forgetting = (per_category_mIoU_A_before - per_category_mIoU_A)

# Export evaluation metrics to WandB
wandb.log({
    "eval/avg_learning_acc": avg_learning_acc,
    "eval/total_forgetting": total_forgetting,
})

columns=["categories","per_category_mIoU_A_before","per_category_mIoU_B_before",
         "per_category_mIoU_B", "per_category_mIoU_A",
         "per_category_avg_learning_acc","per_category_forgetting"]
data=[]

data.append(["background",per_category_mIoU_A_before[0],
                 per_category_mIoU_B_before[0],
                 per_category_mIoU_B[0],
                per_category_mIoU_A[0],per_category_avg_learning_acc[0],
                per_category_forgetting[0]])

for cat_id in range(1,12):
    data.append([ZEISS_CATEGORIES[cat_id],per_category_mIoU_A_before[cat_id],
                 per_category_mIoU_B_before[cat_id],
                 per_category_mIoU_B[cat_id],
                per_category_mIoU_A[cat_id],per_category_avg_learning_acc[cat_id],
                per_category_forgetting[cat_id]])
    
    
table = wandb.Table(columns=columns, data=data)
wandb.log({"per_category_metrics_table": table})

print("**** Overall mIoU ****")
print(f"mIoU on task A before training on B: {mIoU_A_before}")
print(f"mIoU on task B before training on B: {mIoU_B_before}")
print("\n")
print(f"mIoU on task B after training on B: {mIoU_B}")
print(f"mIoU on task A after training on B: {mIoU_forgetting_A}")

print("\n**** Per category mIoU ****")
print(f"Per category mIoU on task A before training on B: {per_category_mIoU_A_before}")
print(f"Per category mIoU on task B before training on B: {per_category_mIoU_B_before}")
print("\n")
print(f"Per category mIoU on task B after training on B: {per_category_mIoU_B}")
print(f"Per category mIoU on task A after training on B: {per_category_mIoU_A}")

print("\n**** Average learning accuracies ****")
print(f"Average learning acc.: {avg_learning_acc}")
print(f"Per category Average learning acc.: {per_category_avg_learning_acc}")

print("\n**** Forgetting ****")
print(f"Total forgetting: {total_forgetting}")
print(f"Per category forgetting: {per_category_forgetting}")
wandb.finish()

if os.path.exists("artifacts/"):
    shutil.rmtree("artifacts/")