In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import os
from utils.common import (
    m2f_dataset_collate,
    m2f_extract_pred_maps_and_masks,
    BG_VALUE_255,
    set_seed,
    pixel_mean_std,
    CADIS_PIXEL_MEAN,
    CADIS_PIXEL_STD,
)
from utils.dataset_utils import (
    get_cadisv2_dataset,
    get_cataract1k_dataset,
    ZEISS_CATEGORIES,
)
from utils.medical_datasets import Mask2FormerDataset
from transformers import (
    Mask2FormerForUniversalSegmentation,
    SwinModel,
    SwinConfig,
    Mask2FormerConfig,
    AutoImageProcessor,
    Mask2FormerImageProcessor,
)
from torch.utils.data import DataLoader
import evaluate
import torch.optim as optim
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from dotenv import load_dotenv
import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
set_seed(42) # seed everything

In [4]:
NUM_CLASSES = len(ZEISS_CATEGORIES) - 3  # Remove class incremental
SWIN_BACKBONE = "microsoft/swin-tiny-patch4-window7-224"#"microsoft/swin-large-patch4-window12-384"

# Download pretrained swin model
swin_model = SwinModel.from_pretrained(
    SWIN_BACKBONE, out_features=["stage1", "stage2", "stage3", "stage4"]
)
swin_config = SwinConfig.from_pretrained(
    SWIN_BACKBONE, out_features=["stage1", "stage2", "stage3", "stage4"]
)

# Create Mask2Former configuration based on Swin's configuration
mask2former_config = Mask2FormerConfig(
    backbone_config=swin_config, num_labels=NUM_CLASSES #, ignore_value=BG_VALUE
)

# Create the Mask2Former model with this configuration
model = Mask2FormerForUniversalSegmentation(mask2former_config)

# Reuse pretrained parameters
for swin_param, m2f_param in zip(
    swin_model.named_parameters(),
    model.model.pixel_level_module.encoder.named_parameters(),
):
    m2f_param_name = f"model.pixel_level_module.encoder.{m2f_param[0]}"

    if swin_param[0] == m2f_param[0]:
        model.state_dict()[m2f_param_name].copy_(swin_param[1])
        continue

    print(f"Not Matched: {m2f_param[0]} != {swin_param[0]}")



Not Matched: hidden_states_norms.stage1.weight != layernorm.weight
Not Matched: hidden_states_norms.stage1.bias != layernorm.bias


In [5]:
# Helper function to load datasets
def load_dataset(dataset_getter, data_path, domain_incremental):
    return dataset_getter(data_path, domain_incremental=domain_incremental)


# Helper function to create dataloaders for a dataset
def create_dataloaders(
    dataset, batch_size, shuffle, num_workers, drop_last, pin_memory, collate_fn
):
    return {
        "train": DataLoader(
            dataset["train"],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        ),
        "val": DataLoader(
            dataset["val"],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        ),
        "test": DataLoader(
            dataset["test"],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        ),
    }


# Load datasets
datasets = {
    "A": load_dataset(get_cadisv2_dataset, "../../storage/data/CaDISv2", True),
    "B": load_dataset(get_cataract1k_dataset, "../../storage/data/cataract-1k", True),
}

# pixel_mean_A,pixel_std_A=pixel_mean_std(datasets["A"][0])
pixel_mean_A = CADIS_PIXEL_MEAN
pixel_std_A = CADIS_PIXEL_STD

# This time define the B train dataset such that it replays approximately
# 32 MBs of images from the previous dataset. Each image is approximately 700 kBs
# Generate N random indices from dataset A
N = int(32 / 0.7)
random_indices = torch.randperm(len(datasets["A"][0]))[:N]

# Create a subset of B using the randomly sampled indices
subset_A = torch.utils.data.Subset(datasets["A"][0], random_indices)
new_train = torch.utils.data.ConcatDataset([subset_A, datasets["B"][0]])

pixel_mean_B,pixel_std_B=pixel_mean_std(new_train)
print("pixel mean of B",pixel_mean_B,"pixel std:",pixel_std_B)
#pixel_mean_B = REPLAY32_PIXEL_MEAN
#pixel_std_B = REPLAY32_PIXEL_STD

datasets["B"] = (new_train, datasets["B"][1], datasets["B"][2])

# Define preprocessor
swin_processor = AutoImageProcessor.from_pretrained(SWIN_BACKBONE)
m2f_preprocessor_A = Mask2FormerImageProcessor(
    reduce_labels=True,
    ignore_index=255,
    do_resize=False,
    do_rescale=True,
    do_normalize=True,
    image_std=pixel_std_A,
    image_mean=pixel_mean_A,
)

m2f_preprocessor_B = Mask2FormerImageProcessor(
    reduce_labels=True,
    ignore_index=255,
    do_resize=False,
    do_rescale=True,
    do_normalize=True,
    image_std=pixel_std_B,
    image_mean=pixel_mean_B,
)

# Create Mask2Former Datasets

m2f_datasets = {
    "A": {
        "train": Mask2FormerDataset(datasets["A"][0], m2f_preprocessor_A),
        "val": Mask2FormerDataset(datasets["A"][1], m2f_preprocessor_A),
        "test": Mask2FormerDataset(datasets["A"][2], m2f_preprocessor_A),
    },
    "B": {
        "train": Mask2FormerDataset(datasets["B"][0], m2f_preprocessor_B),
        "val": Mask2FormerDataset(datasets["B"][1], m2f_preprocessor_B),
        "test": Mask2FormerDataset(datasets["B"][2], m2f_preprocessor_B),
    },
}

# DataLoader parameters
N_WORKERS = 4
BATCH_SIZE = 16
SHUFFLE = True
DROP_LAST = True

dataloader_params = {
    "batch_size": BATCH_SIZE,
    "shuffle": SHUFFLE,
    "num_workers": N_WORKERS,
    "drop_last": DROP_LAST,
    "pin_memory": True,
    "collate_fn": m2f_dataset_collate,
}

# Create DataLoaders
dataloaders = {
    key: create_dataloaders(m2f_datasets[key], **dataloader_params)
    for key in m2f_datasets
}

print(dataloaders)



pixel mean of B [0.32117264 0.28945913 0.21137104] pixel std: [0.30119563 0.25390536 0.22080096]
{'A': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7f546657a5a0>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7f546657a630>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x7f546657a450>}, 'B': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7f54666b02c0>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7f546657a390>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x7f546657a060>}}


In [6]:
# Check if CUDA is available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

Using device: cuda


In [7]:
# Tensorboard setup
out_dir="outputs/"
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
if not os.path.exists(out_dir+"runs"):
    os.makedirs(out_dir+"runs")
%load_ext tensorboard
%tensorboard --logdir outputs/runs

In [None]:
#!rm -r outputs

In [None]:
#!CUDA_LAUNCH_BLOCKING=1

# First train on dataset A

In [8]:
# Training
NUM_EPOCHS = 200
LEARNING_RATE = 1e-4
LR_MULTIPLIER = 0.1
BACKBONE_LR = LEARNING_RATE * LR_MULTIPLIER
WEIGHT_DECAY = 0.05
PATIENCE=15
metric = evaluate.load("mean_iou")
encoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.encoder")
]
decoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.decoder")
]
transformer_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.transformer_module")
]
optimizer = optim.AdamW(
    [
        {"params": encoder_params, "lr": BACKBONE_LR},
        {"params": decoder_params},
        {"params": transformer_params},
    ],
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

scheduler = optim.lr_scheduler.PolynomialLR(
    optimizer, total_iters=NUM_EPOCHS, power=0.9
)

In [9]:
# WandB for team usage !!!!

wandb.login() # use this one if a different person is going to run the notebook
#wandb.login(relogin=False) # if the same person in the last run is going to run the notebook again


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkristiyan-sakalyan[0m ([33mcontinual-learning-tum[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [12]:
wandb.init(
    project="M2F_original",
    config={
        "learning_rate": LEARNING_RATE,
        "learning_rate_multiplier": LR_MULTIPLIER,
        "backbone_learning_rate": BACKBONE_LR,
        "learning_rate_scheduler": scheduler.__class__.__name__,
        "optimizer": optimizer.__class__.__name__,
        "backbone": SWIN_BACKBONE,
        "m2f_preprocessor": m2f_preprocessor_A.__dict__,
        "m2f_model_config": model.config
    },
    name="M2F-Swin-Tiny-Train_Cadis",
    notes="M2F with tiny Swin backbone pretrained on ImageNet-1K. \
        Scenario: Train on A, Test on A"
)

[34m[1mwandb[0m: Currently logged in as: [33mge85ket[0m ([33mcontinual-learning-tum[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
# Tensorboard logging
writer = SummaryWriter(log_dir=out_dir + "runs")

# Model checkpointing
base_model_name="m2f_swin_backbone_train_cadis"
model_dir = out_dir + "models/"
if not os.path.exists(model_dir):
    print("Store weights in: ", model_dir)
    os.makedirs(model_dir)

best_model_dir = model_dir + f"{base_model_name}/best_model/"
if not os.path.exists(best_model_dir):
    print("Store best model weights in: ", best_model_dir)
    os.makedirs(best_model_dir)
final_model_dir = model_dir + f"{base_model_name}/final_model/"
if not os.path.exists(final_model_dir):
    print("Store final model weights in: ", final_model_dir)
    os.makedirs(final_model_dir)
    

In [17]:
# Save the preprocessor
m2f_preprocessor_A.save_pretrained(model_dir + base_model_name)

['outputs/models/m2f_swin_backbone_train_cadis/preprocessor_config.json']

In [18]:
# To avoid making stupid errors
CURR_TASK = "A"

# For storing the model
best_val_metric = -np.inf

# Move model to device
model.to(device)
counter=0
for epoch in range(NUM_EPOCHS):
    model.train()
    train_running_loss = 0.0
    val_running_loss = 0.0

    # Set up tqdm for the training loop
    train_loader = tqdm(
        dataloaders[CURR_TASK]["train"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Training"
    )

    for batch in train_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]

        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss

        # Compute gradient and perform step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        train_running_loss += current_loss
        train_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_A
        )
        metric.add_batch(references=masks, predictions=pred_maps)
        
    

    # After compute the batches that were added are deleted
    mean_train_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
    )["mean_iou"]

    # Validation phase
    model.eval()
    val_loader = tqdm(
        dataloaders[CURR_TASK]["val"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Validation"
    )
    with torch.no_grad():
        for batch in val_loader:
            # Move everything to the device
            batch["pixel_values"] = batch["pixel_values"].to(device)
            batch["pixel_mask"] = batch["pixel_mask"].to(device)
            batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
            batch["class_labels"] = [
                entry.to(device) for entry in batch["class_labels"]
            ]
            # Compute output and loss
            outputs = model(**batch)

            loss = outputs.loss
            # Record losses
            current_loss = loss.item() * batch["pixel_values"].size(0)
            val_running_loss += current_loss
            val_loader.set_postfix(loss=f"{current_loss:.4f}")

            # Extract and compute metrics
            pred_maps, masks = m2f_extract_pred_maps_and_masks(
                batch, outputs, m2f_preprocessor_A
            )
            metric.add_batch(references=masks, predictions=pred_maps)
            

    # After compute the batches that were added are deleted
    mean_val_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
    )["mean_iou"]

    epoch_train_loss = train_running_loss / len(dataloaders[CURR_TASK]["train"].dataset)
    epoch_val_loss = val_running_loss / len(dataloaders[CURR_TASK]["val"].dataset)

    writer.add_scalar(f"Loss/train_{base_model_name}_{CURR_TASK}", epoch_train_loss, epoch + 1)
    writer.add_scalar(f"Loss/val_{base_model_name}_{CURR_TASK}", epoch_val_loss, epoch + 1)
    writer.add_scalar(f"mIoU/train_{base_model_name}_{CURR_TASK}", mean_train_iou, epoch + 1)
    writer.add_scalar(f"mIoU/val_{base_model_name}_{CURR_TASK}", mean_val_iou, epoch + 1)

    wandb.log({
        f"Loss/train_{CURR_TASK}": epoch_train_loss,
        f"Loss/val_{CURR_TASK}": epoch_val_loss,
        f"mIoU/train_{CURR_TASK}": mean_train_iou,
        f"mIoU/val_{CURR_TASK}": mean_val_iou
    })


    tqdm.write(
        f"Epoch {epoch + 1}/{NUM_EPOCHS}, Train Loss: {epoch_train_loss:.4f}, Train mIoU: {mean_train_iou:.4f}, Validation Loss: {epoch_val_loss:.4f}, Validation mIoU: {mean_val_iou:.4f}"
    )
    
    if mean_val_iou > best_val_metric:
        best_val_metric = mean_val_iou
        model.save_pretrained(f"{best_model_dir}{CURR_TASK}/")
        counter=0
    else:
        counter+=1
        if counter == PATIENCE:
            print("Early stopping at epoch",epoch)
            break

Epoch 1/200 Training:   0%|          | 0/1775 [00:01<?, ?it/s, loss=222.9132]
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
Epoch 1/200 Validation:   0%|          | 0/267 [00:01<?, ?it/s, loss=194.8053]


Epoch 1/200, Train Loss: 0.0628, Train mIoU: 0.0000, Validation Loss: 0.3648, Validation mIoU: 0.2076


## Test results on A

In [11]:
# Load best model and evaluate on test
CURR_TASK = "A"
model = Mask2FormerForUniversalSegmentation.from_pretrained(f"{best_model_dir}{CURR_TASK}/").to(device)

In [12]:
model.eval()
test_running_loss = 0
test_loader = tqdm(dataloaders[CURR_TASK]["test"], desc="Test loop")
with torch.no_grad():
    for batch in test_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss
        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        test_running_loss += current_loss
        test_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_A
        )
        metric.add_batch(references=masks, predictions=pred_maps)
        
# After compute the batches that were added are deleted
test_metrics_A = metric.compute(
    num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
)
mean_test_iou = test_metrics_A["mean_iou"]

final_test_loss = test_running_loss / len(dataloaders[CURR_TASK]["test"].dataset)
#wandb.log({
#    f"Loss/test_{CURR_TASK}": final_test_loss,
#    f"mIoU/test_{CURR_TASK}": mean_test_iou
#})
print(f"Test Loss: {final_test_loss:.4f}, Test mIoU: {mean_test_iou:.4f}")
#wandb.finish()

Test loop: 100%|██████████| 36/36 [00:31<00:00,  1.15it/s, loss=218.7560]


Test Loss: 14.8976, Test mIoU: 0.7877


In [13]:
test_metrics_A

{'mean_iou': 0.7877320741744921,
 'mean_accuracy': 0.8500004734317124,
 'overall_accuracy': 0.9620660501296412,
 'per_category_iou': array([0.95360882, 0.91871854, 0.687211  , 0.63869946, 0.28185184,
        0.81372453, 0.79456505, 0.78816018, 0.91600346, 0.94137199,
        0.93113795]),
 'per_category_accuracy': array([0.98604046, 0.98150333, 0.79010483, 0.72754962, 0.2902951 ,
        0.9256408 , 0.87657205, 0.89267951, 0.94431993, 0.97042209,
        0.96487748])}

# Sample images from the training dataset of A with max loss

In [14]:
losses = []

# Collect losses
model.eval()
with torch.no_grad():
    for sample in tqdm(m2f_datasets["A"]["train"]):
        sample["pixel_values"] = sample["pixel_values"].to(device)
        sample["pixel_mask"] = sample["pixel_mask"].to(device)
        sample["mask_labels"] = [entry.to(device) for entry in sample["mask_labels"]]
        sample["class_labels"] = [entry.to(device) for entry in sample["class_labels"]]
        losses.append(model(**sample).loss.item())

losses_np = np.array(losses)

# Sample images with max loss
closest_indices = np.argsort(losses_np)[-N:] # N was calculated above

# Create a subset of B using the mean loss sampled indices
subset_A = torch.utils.data.Subset(datasets["A"][0], closest_indices)
new_train = torch.utils.data.ConcatDataset([subset_A, datasets["B"][0]])

# Calculate new mean and std
pixel_mean_B,pixel_std_B=pixel_mean_std(new_train)
print("pixel mean of B",pixel_mean_B,"pixel std:",pixel_std_B)

datasets["B"] = (new_train, datasets["B"][1], datasets["B"][2])

m2f_preprocessor_B = Mask2FormerImageProcessor(
    reduce_labels=True,
    ignore_index=255,
    do_resize=False,
    do_rescale=True,
    do_normalize=True,
    image_std=pixel_std_B,
    image_mean=pixel_mean_B,
)

# Create Mask2Former Datasets

m2f_datasets = {
    "A": {
        "train": Mask2FormerDataset(datasets["A"][0], m2f_preprocessor_A),
        "val": Mask2FormerDataset(datasets["A"][1], m2f_preprocessor_A),
        "test": Mask2FormerDataset(datasets["A"][2], m2f_preprocessor_A),
    },
    "B": {
        "train": Mask2FormerDataset(datasets["B"][0], m2f_preprocessor_B),
        "val": Mask2FormerDataset(datasets["B"][1], m2f_preprocessor_B),
        "test": Mask2FormerDataset(datasets["B"][2], m2f_preprocessor_B),
    },
}

# DataLoader parameters
N_WORKERS = 4
BATCH_SIZE = 16
SHUFFLE = True
DROP_LAST = True

dataloader_params = {
    "batch_size": BATCH_SIZE,
    "shuffle": SHUFFLE,
    "num_workers": N_WORKERS,
    "drop_last": DROP_LAST,
    "pin_memory": True,
    "collate_fn": m2f_dataset_collate,
}

# Create DataLoaders
dataloaders = {
    key: create_dataloaders(m2f_datasets[key], **dataloader_params)
    for key in m2f_datasets
}

print(dataloaders)

100%|██████████| 3550/3550 [10:11<00:00,  5.80it/s]


pixel mean of B [0.32712337 0.29087587 0.21119382] pixel std: [0.30112183 0.25227443 0.21882615]
{'A': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7f546b397a10>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7f546b3e75c0>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x7f546b3c6c60>}, 'B': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7f546b3c67e0>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7f546b3f7ce0>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x7f546b3f7b60>}}


# Now train on B and forget A

In [15]:
# Training
NUM_EPOCHS = 200
LEARNING_RATE = 1e-4
LR_MULTIPLIER = 0.1
BACKBONE_LR = LEARNING_RATE * LR_MULTIPLIER
WEIGHT_DECAY = 0.05
PATIENCE=15
encoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.encoder")
]
decoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.decoder")
]
transformer_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.transformer_module")
]
optimizer = optim.AdamW(
    [
        {"params": encoder_params, "lr": BACKBONE_LR},
        {"params": decoder_params},
        {"params": transformer_params},
    ],
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

scheduler = optim.lr_scheduler.PolynomialLR(
    optimizer, total_iters=NUM_EPOCHS, power=0.9
)

In [16]:
# WandB for team usage !!!!

wandb.login() # use this one if a different person is going to run the notebook
#wandb.login(relogin=False) # if the same person in the last run is going to run the notebook again


True

In [17]:
wandb.init(
    project="M2F_original",
    config={
        "learning_rate": LEARNING_RATE,
        "learning_rate_multiplier": LR_MULTIPLIER,
        "backbone_learning_rate": BACKBONE_LR,
        "learning_rate_scheduler": scheduler.__class__.__name__,
        "optimizer": optimizer.__class__.__name__,
        "backbone": SWIN_BACKBONE,
        "m2f_preprocessor": m2f_preprocessor_B.__dict__,
        "m2f_model_config": model.config
    },
    name="M2F-Swin-Tiny-Max-Loss-32MBReplay",
    notes="M2F with tiny Swin backbone pretrained on ImageNet-1K. \
        Scenario: Pretrained on A, Train on B, Test 32MB max loss replay on A"
)

# Tensorboard logging
writer = SummaryWriter(log_dir=out_dir + "runs")

# Model checkpointing
model_name = "m2f_swin_backbone_replay_max_loss_32MBs"
model_dir = out_dir + "models/"
if not os.path.exists(model_dir):
    print("Store weights in: ", model_dir)
    os.makedirs(model_dir)

best_model_dir = model_dir + f"{model_name}/best_model/"
if not os.path.exists(best_model_dir):
    print("Store best model weights in: ", best_model_dir)
    os.makedirs(best_model_dir)
final_model_dir = model_dir + f"{model_name}/final_model/"
if not os.path.exists(final_model_dir):
    print("Store final model weights in: ", final_model_dir)
    os.makedirs(final_model_dir)

Store best model weights in:  outputs/models/m2f_swin_backbone_replay_max_loss_32MBs/best_model/
Store final model weights in:  outputs/models/m2f_swin_backbone_replay_max_loss_32MBs/final_model/


In [18]:
# Save the preprocessor
m2f_preprocessor_B.save_pretrained(model_dir + model_name)

['outputs/models/m2f_swin_backbone_replay_max_loss_32MBs/preprocessor_config.json']

In [19]:
# To avoid making stupid errors
CURR_TASK = "B"

# For storing the model
best_val_metric = -np.inf

# Move model to device
model.to(device)
counter=0
for epoch in range(NUM_EPOCHS):
    model.train()
    train_running_loss = 0.0
    val_running_loss = 0.0

    # Set up tqdm for the training loop
    train_loader = tqdm(
        dataloaders[CURR_TASK]["train"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Training"
    )

    for batch in train_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]

        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss

        # Compute gradient and perform step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        train_running_loss += current_loss
        train_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_B
        )
        metric.add_batch(references=masks, predictions=pred_maps)
        

    # After compute the batches that were added are deleted
    mean_train_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
    )["mean_iou"]

    # Validation phase
    model.eval()
    val_loader = tqdm(
        dataloaders[CURR_TASK]["val"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Validation"
    )
    with torch.no_grad():
        for batch in val_loader:
            # Move everything to the device
            batch["pixel_values"] = batch["pixel_values"].to(device)
            batch["pixel_mask"] = batch["pixel_mask"].to(device)
            batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
            batch["class_labels"] = [
                entry.to(device) for entry in batch["class_labels"]
            ]
            # Compute output and loss
            outputs = model(**batch)

            loss = outputs.loss
            # Record losses
            current_loss = loss.item() * batch["pixel_values"].size(0)
            val_running_loss += current_loss
            val_loader.set_postfix(loss=f"{current_loss:.4f}")

            # Extract and compute metrics
            pred_maps, masks = m2f_extract_pred_maps_and_masks(
                batch, outputs, m2f_preprocessor_B
            )
            metric.add_batch(references=masks, predictions=pred_maps)
            

    # After compute the batches that were added are deleted
    mean_val_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
    )["mean_iou"]

    epoch_train_loss = train_running_loss / len(dataloaders[CURR_TASK]["train"].dataset)
    epoch_val_loss = val_running_loss / len(dataloaders[CURR_TASK]["val"].dataset)

    writer.add_scalar(f"Loss/train_{model_name}_{CURR_TASK}", epoch_train_loss, epoch + 1)
    writer.add_scalar(f"Loss/val_{model_name}_{CURR_TASK}", epoch_val_loss, epoch + 1)
    writer.add_scalar(f"mIoU/train_{model_name}_{CURR_TASK}", mean_train_iou, epoch + 1)
    writer.add_scalar(f"mIoU/val_{model_name}_{CURR_TASK}", mean_val_iou, epoch + 1)

    wandb.log({
        f"Loss/train_replay_A_{CURR_TASK}": epoch_train_loss,
        f"Loss/val_replay_A_{CURR_TASK}": epoch_val_loss,
        f"mIoU/train_replay_A_{CURR_TASK}": mean_train_iou,
        f"mIoU/val_replay_A_{CURR_TASK}": mean_val_iou
    })


    tqdm.write(
        f"Epoch {epoch + 1}/{NUM_EPOCHS}, Train Loss: {epoch_train_loss:.4f}, Train mIoU: {mean_train_iou:.4f}, Validation Loss: {epoch_val_loss:.4f}, Validation mIoU: {mean_val_iou:.4f}"
    )
    
    if mean_val_iou > best_val_metric:
        best_val_metric = mean_val_iou
        model.save_pretrained(f"{best_model_dir}{CURR_TASK}/")
        counter=0
    else:
        counter+=1
        if counter == PATIENCE:
            print("Early stopping at epoch",epoch)
            break

Epoch 1/200 Training: 100%|██████████| 118/118 [02:49<00:00,  1.44s/it, loss=323.3484]
Epoch 1/200 Validation: 100%|██████████| 14/14 [00:23<00:00,  1.67s/it, loss=242.3506]


Epoch 1/200, Train Loss: 22.3490, Train mIoU: 0.5186, Validation Loss: 17.1329, Validation mIoU: 0.5998


Epoch 2/200 Training: 100%|██████████| 118/118 [03:03<00:00,  1.56s/it, loss=213.6685]
Epoch 2/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.30it/s, loss=252.0575]


Epoch 2/200, Train Loss: 16.8354, Train mIoU: 0.6506, Validation Loss: 13.8539, Validation mIoU: 0.7329


Epoch 3/200 Training: 100%|██████████| 118/118 [02:34<00:00,  1.31s/it, loss=210.6209]
Epoch 3/200 Validation: 100%|██████████| 14/14 [00:12<00:00,  1.09it/s, loss=201.2150]


Epoch 3/200, Train Loss: 13.2635, Train mIoU: 0.7419, Validation Loss: 10.8973, Validation mIoU: 0.7129


Epoch 4/200 Training: 100%|██████████| 118/118 [02:59<00:00,  1.52s/it, loss=180.0348]
Epoch 4/200 Validation: 100%|██████████| 14/14 [00:12<00:00,  1.12it/s, loss=154.2014]


Epoch 4/200, Train Loss: 11.0680, Train mIoU: 0.8059, Validation Loss: 9.8355, Validation mIoU: 0.7635


Epoch 5/200 Training: 100%|██████████| 118/118 [02:28<00:00,  1.26s/it, loss=156.2476]
Epoch 5/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.32it/s, loss=180.2553]


Epoch 5/200, Train Loss: 10.4177, Train mIoU: 0.8128, Validation Loss: 10.4831, Validation mIoU: 0.7919


Epoch 6/200 Training: 100%|██████████| 118/118 [03:06<00:00,  1.58s/it, loss=139.0027]
Epoch 6/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.30it/s, loss=203.2075]


Epoch 6/200, Train Loss: 9.6969, Train mIoU: 0.8589, Validation Loss: 9.0283, Validation mIoU: 0.7833


Epoch 7/200 Training: 100%|██████████| 118/118 [02:32<00:00,  1.29s/it, loss=151.6768]
Epoch 7/200 Validation: 100%|██████████| 14/14 [00:12<00:00,  1.13it/s, loss=120.2030]


Epoch 7/200, Train Loss: 9.1153, Train mIoU: 0.8602, Validation Loss: 8.8996, Validation mIoU: 0.7753


Epoch 8/200 Training: 100%|██████████| 118/118 [02:26<00:00,  1.24s/it, loss=130.8290]
Epoch 8/200 Validation: 100%|██████████| 14/14 [00:12<00:00,  1.14it/s, loss=132.0407]


Epoch 8/200, Train Loss: 8.3360, Train mIoU: 0.8971, Validation Loss: 9.1568, Validation mIoU: 0.7768


Epoch 9/200 Training: 100%|██████████| 118/118 [03:09<00:00,  1.61s/it, loss=124.7300]
Epoch 9/200 Validation: 100%|██████████| 14/14 [00:12<00:00,  1.13it/s, loss=176.2644]


Epoch 9/200, Train Loss: 7.9949, Train mIoU: 0.9046, Validation Loss: 8.1433, Validation mIoU: 0.8324


Epoch 10/200 Training: 100%|██████████| 118/118 [02:42<00:00,  1.38s/it, loss=105.8107]
Epoch 10/200 Validation: 100%|██████████| 14/14 [00:13<00:00,  1.05it/s, loss=107.8936]


Epoch 10/200, Train Loss: 7.5219, Train mIoU: 0.9004, Validation Loss: 8.3072, Validation mIoU: 0.7658


Epoch 11/200 Training: 100%|██████████| 118/118 [02:27<00:00,  1.25s/it, loss=104.3380]
Epoch 11/200 Validation: 100%|██████████| 14/14 [00:21<00:00,  1.57s/it, loss=111.6945]


Epoch 11/200, Train Loss: 7.3008, Train mIoU: 0.8897, Validation Loss: 8.1821, Validation mIoU: 0.8124


Epoch 12/200 Training: 100%|██████████| 118/118 [02:22<00:00,  1.21s/it, loss=100.1026]
Epoch 12/200 Validation: 100%|██████████| 14/14 [00:14<00:00,  1.07s/it, loss=106.3854]


Epoch 12/200, Train Loss: 7.4358, Train mIoU: 0.9234, Validation Loss: 7.9073, Validation mIoU: 0.7561


Epoch 13/200 Training: 100%|██████████| 118/118 [02:24<00:00,  1.22s/it, loss=125.2439]
Epoch 13/200 Validation: 100%|██████████| 14/14 [00:14<00:00,  1.02s/it, loss=111.6215]


Epoch 13/200, Train Loss: 6.8301, Train mIoU: 0.9243, Validation Loss: 7.9004, Validation mIoU: 0.7981


Epoch 14/200 Training: 100%|██████████| 118/118 [02:29<00:00,  1.27s/it, loss=106.8927]
Epoch 14/200 Validation: 100%|██████████| 14/14 [00:13<00:00,  1.02it/s, loss=98.4945] 


Epoch 14/200, Train Loss: 6.6442, Train mIoU: 0.9274, Validation Loss: 7.9772, Validation mIoU: 0.8954


Epoch 15/200 Training: 100%|██████████| 118/118 [02:36<00:00,  1.33s/it, loss=105.9120]
Epoch 15/200 Validation: 100%|██████████| 14/14 [00:13<00:00,  1.04it/s, loss=107.7376]


Epoch 15/200, Train Loss: 6.5497, Train mIoU: 0.9413, Validation Loss: 7.7496, Validation mIoU: 0.8149


Epoch 16/200 Training: 100%|██████████| 118/118 [02:29<00:00,  1.27s/it, loss=97.7808] 
Epoch 16/200 Validation: 100%|██████████| 14/14 [00:13<00:00,  1.01it/s, loss=166.9202]


Epoch 16/200, Train Loss: 6.3080, Train mIoU: 0.9491, Validation Loss: 7.5995, Validation mIoU: 0.8751


Epoch 17/200 Training: 100%|██████████| 118/118 [02:53<00:00,  1.47s/it, loss=95.3673] 
Epoch 17/200 Validation: 100%|██████████| 14/14 [00:11<00:00,  1.20it/s, loss=98.7694] 


Epoch 17/200, Train Loss: 6.1701, Train mIoU: 0.9505, Validation Loss: 7.7295, Validation mIoU: 0.8299


Epoch 18/200 Training: 100%|██████████| 118/118 [02:35<00:00,  1.32s/it, loss=90.8235] 
Epoch 18/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.36it/s, loss=103.9548]


Epoch 18/200, Train Loss: 6.1571, Train mIoU: 0.9290, Validation Loss: 7.9546, Validation mIoU: 0.8501


Epoch 19/200 Training: 100%|██████████| 118/118 [02:28<00:00,  1.26s/it, loss=88.9108] 
Epoch 19/200 Validation: 100%|██████████| 14/14 [00:11<00:00,  1.21it/s, loss=168.0760]


Epoch 19/200, Train Loss: 5.9546, Train mIoU: 0.9532, Validation Loss: 7.5331, Validation mIoU: 0.8169


Epoch 20/200 Training: 100%|██████████| 118/118 [02:33<00:00,  1.30s/it, loss=89.1680] 
Epoch 20/200 Validation: 100%|██████████| 14/14 [00:11<00:00,  1.21it/s, loss=132.7462]


Epoch 20/200, Train Loss: 5.7608, Train mIoU: 0.9156, Validation Loss: 7.7315, Validation mIoU: 0.8088


Epoch 21/200 Training: 100%|██████████| 118/118 [02:33<00:00,  1.30s/it, loss=94.1937] 
Epoch 21/200 Validation: 100%|██████████| 14/14 [00:13<00:00,  1.04it/s, loss=127.9835]


Epoch 21/200, Train Loss: 5.6025, Train mIoU: 0.9538, Validation Loss: 7.6789, Validation mIoU: 0.9120


Epoch 22/200 Training: 100%|██████████| 118/118 [02:31<00:00,  1.29s/it, loss=80.1660] 
Epoch 22/200 Validation: 100%|██████████| 14/14 [00:13<00:00,  1.06it/s, loss=125.0520]


Epoch 22/200, Train Loss: 5.8821, Train mIoU: 0.9506, Validation Loss: 7.5573, Validation mIoU: 0.8288


Epoch 23/200 Training: 100%|██████████| 118/118 [02:31<00:00,  1.28s/it, loss=101.1543]
Epoch 23/200 Validation: 100%|██████████| 14/14 [00:11<00:00,  1.25it/s, loss=115.0213]


Epoch 23/200, Train Loss: 5.8357, Train mIoU: 0.9337, Validation Loss: 8.2358, Validation mIoU: 0.8764


Epoch 24/200 Training: 100%|██████████| 118/118 [02:23<00:00,  1.22s/it, loss=93.8433] 
Epoch 24/200 Validation: 100%|██████████| 14/14 [00:14<00:00,  1.00s/it, loss=107.6767]


Epoch 24/200, Train Loss: 5.6023, Train mIoU: 0.9436, Validation Loss: 7.6835, Validation mIoU: 0.8210


Epoch 25/200 Training: 100%|██████████| 118/118 [02:25<00:00,  1.23s/it, loss=88.2525] 
Epoch 25/200 Validation: 100%|██████████| 14/14 [00:15<00:00,  1.09s/it, loss=118.2592]


Epoch 25/200, Train Loss: 5.4820, Train mIoU: 0.9509, Validation Loss: 7.7134, Validation mIoU: 0.8574


Epoch 26/200 Training: 100%|██████████| 118/118 [02:25<00:00,  1.23s/it, loss=86.0836] 
Epoch 26/200 Validation: 100%|██████████| 14/14 [00:15<00:00,  1.12s/it, loss=167.3118]


Epoch 26/200, Train Loss: 5.4454, Train mIoU: 0.9648, Validation Loss: 7.7783, Validation mIoU: 0.8397


Epoch 27/200 Training: 100%|██████████| 118/118 [02:31<00:00,  1.28s/it, loss=99.2491] 
Epoch 27/200 Validation: 100%|██████████| 14/14 [00:15<00:00,  1.14s/it, loss=165.0406]


Epoch 27/200, Train Loss: 5.2592, Train mIoU: 0.9624, Validation Loss: 7.8690, Validation mIoU: 0.8226


Epoch 31/200 Validation: 100%|██████████| 14/14 [00:12<00:00,  1.16it/s, loss=149.5909]


Epoch 31/200, Train Loss: 4.9299, Train mIoU: 0.9551, Validation Loss: 7.7778, Validation mIoU: 0.8104


Epoch 32/200 Training: 100%|██████████| 118/118 [02:33<00:00,  1.30s/it, loss=79.9247]
Epoch 32/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.29it/s, loss=98.1811] 


Epoch 32/200, Train Loss: 4.8338, Train mIoU: 0.9668, Validation Loss: 7.7229, Validation mIoU: 0.8163


Epoch 33/200 Training: 100%|██████████| 118/118 [02:11<00:00,  1.11s/it, loss=68.5262]
Epoch 33/200 Validation: 100%|██████████| 14/14 [00:11<00:00,  1.22it/s, loss=123.2453]


Epoch 33/200, Train Loss: 4.7558, Train mIoU: 0.9722, Validation Loss: 7.5337, Validation mIoU: 0.8149


Epoch 34/200 Training: 100%|██████████| 118/118 [02:18<00:00,  1.18s/it, loss=77.7048] 
Epoch 34/200 Validation: 100%|██████████| 14/14 [00:10<00:00,  1.35it/s, loss=102.4156]


Epoch 34/200, Train Loss: 4.7940, Train mIoU: 0.9592, Validation Loss: 7.9700, Validation mIoU: 0.8079


Epoch 35/200 Training: 100%|██████████| 118/118 [02:17<00:00,  1.17s/it, loss=84.6143] 
Epoch 35/200 Validation: 100%|██████████| 14/14 [00:11<00:00,  1.27it/s, loss=104.6482]


Epoch 35/200, Train Loss: 4.9532, Train mIoU: 0.9568, Validation Loss: 7.9266, Validation mIoU: 0.8783


Epoch 36/200 Training: 100%|██████████| 118/118 [02:20<00:00,  1.19s/it, loss=104.4526]
Epoch 36/200 Validation: 100%|██████████| 14/14 [00:13<00:00,  1.04it/s, loss=192.2474]


Epoch 36/200, Train Loss: 4.7924, Train mIoU: 0.9652, Validation Loss: 8.0746, Validation mIoU: 0.8429
Early stopping at epoch 35


## Test results on B first

In [None]:
# Load best model and evaluate on test
model = Mask2FormerForUniversalSegmentation.from_pretrained(f"{best_model_dir}{CURR_TASK}/").to(device)

In [None]:
model.eval()
test_running_loss = 0
test_loader = tqdm(dataloaders[CURR_TASK]["test"], desc="Test loop")
with torch.no_grad():
    for batch in test_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss
        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        test_running_loss += current_loss
        test_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_B
        )
        metric.add_batch(references=masks, predictions=pred_maps)
        
    
# After compute the batches that were added are deleted
test_metrics_B = metric.compute(
    num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
)
mean_test_iou = test_metrics_B["mean_iou"]
final_test_loss = test_running_loss / len(dataloaders[CURR_TASK]["test"].dataset)
wandb.log({
    f"Loss/test_{CURR_TASK}": final_test_loss,
    f"mIoU/test_{CURR_TASK}": mean_test_iou
})
print(f"Test Loss: {final_test_loss:.4f}, Test mIoU: {mean_test_iou:.4f}")

Test loop: 100%|██████████| 14/14 [00:12<00:00,  1.14it/s, loss=98.7925] 


Test Loss: 8.3312, Test mIoU: 0.8711


## Test results on A after training on B

In [None]:
# To avoid making stupid errors
CURR_TASK = "A"

model.eval()
test_running_loss = 0
test_loader = tqdm(dataloaders[CURR_TASK]["test"], desc="Test loop")
with torch.no_grad():
    for batch in test_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss
        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        test_running_loss += current_loss
        test_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_A
        )
        metric.add_batch(references=masks, predictions=pred_maps)
        

# After compute the batches that were added are deleted
test_metrics_forgetting_A = metric.compute(
    num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
)
mean_test_iou = test_metrics_forgetting_A["mean_iou"]

final_test_loss = test_running_loss / len(dataloaders[CURR_TASK]["test"].dataset)
wandb.log({
    f"Loss/test_random_replay_32mb_{CURR_TASK}": final_test_loss,
    f"mIoU/test_random_replay_32mb_{CURR_TASK}": mean_test_iou
})
print(f"Test Loss: {final_test_loss:.4f}, Test mIoU: {mean_test_iou:.4f}")


Test loop: 100%|██████████| 36/36 [00:31<00:00,  1.13it/s, loss=262.4240]


Test Loss: 17.5828, Test mIoU: 0.7207


In [None]:
# Collect overall mIoU
mIoU_A = test_metrics_A["mean_iou"]
mIoU_forgetting_A = test_metrics_forgetting_A["mean_iou"]
mIoU_B = test_metrics_B["mean_iou"]

# Collect per category mIoU
per_category_mIoU_A = np.array(test_metrics_A["per_category_iou"])
per_category_mIoU_forgetting_A = np.array(test_metrics_forgetting_A["per_category_iou"])
per_category_mIoU_B = np.array(test_metrics_B["per_category_iou"])

# Average learning accuracies (mIoUs)
avg_learning_acc = (mIoU_A + mIoU_B) / 2
per_category_avg_learning_acc = (per_category_mIoU_A + per_category_mIoU_B) / 2

# Forgetting
total_forgetting = mIoU_A - mIoU_forgetting_A
per_category_forgetting = (per_category_mIoU_A - per_category_mIoU_forgetting_A)

# Export evaluation metrics to WandB
wandb.log({
    "eval/avg_learning_acc": avg_learning_acc,
    "eval/per_category_avg_learning_acc": per_category_avg_learning_acc,
    "eval/total_forgetting": total_forgetting,
    "eval/per_category_forgetting": per_category_forgetting
})
print("**** Overall mIoU ****")
print(f"mIoU on task A: {mIoU_A}")
print(f"mIoU on task B: {mIoU_B}")
print(f"mIoU on task A after training on B: {mIoU_forgetting_A}")

print("\n**** Per category mIoU ****")
print(f"Per category mIoU on task A: {per_category_mIoU_A}")
print(f"Per category mIoU on task B: {per_category_mIoU_B}")
print(f"Per category mIoU on task A after training on B: {per_category_mIoU_forgetting_A}")

print("\n**** Average learning accuracies ****")
print(f"Average learning acc.: {avg_learning_acc}")
print(f"Per category Average learning acc.: {per_category_avg_learning_acc}")

print("\n**** Forgetting ****")
print(f"Total forgetting: {total_forgetting}")
print(f"Per category forgetting: {per_category_forgetting}")
wandb.finish()

**** Overall mIoU ****
mIoU on task A: 0.7877320741744921
mIoU on task B: 0.8711453891218022
mIoU on task A after training on B: 0.7207366617977596

**** Per category mIoU ****
Per category mIoU on task A: [0.95360882 0.91871854 0.687211   0.63869946 0.28185184 0.81372453
 0.79456505 0.78816018 0.91600346 0.94137199 0.93113795]
Per category mIoU on task B: [0.95921619 0.71816284 0.74091253 0.67861785 0.90445688 0.93730329
 0.91828812 0.92046192 0.90004802 0.93650766 0.96862397]
Per category mIoU on task A after training on B: [0.95080557 0.91341507 0.54542797 0.43005187 0.20140019 0.78958616
 0.67288361 0.72030864 0.87443118 0.92154552 0.9082475 ]

**** Average learning accuracies ****
Average learning acc.: 0.8294387316481471
Per category Average learning acc.: [0.9564125  0.81844069 0.71406176 0.65865865 0.59315436 0.87551391
 0.85642659 0.85431105 0.90802574 0.93893983 0.94988096]

**** Forgetting ****
Total forgetting: 0.06699541237673245
Per category forgetting: [0.00280326 0.0053

0,1
Loss/test_B,▁
Loss/test_random_replay_32mb_A,▁
Loss/train_replay_A_B,█▆▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/val_replay_A_B,█▆▃▃▃▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁
eval/avg_learning_acc,▁
eval/total_forgetting,▁
mIoU/test_B,▁
mIoU/test_random_replay_32mb_A,▁
mIoU/train_replay_A_B,▁▃▄▅▆▆▆▇▇▇▇▇▇▇███▇█▇██▇█████████████
mIoU/val_replay_A_B,▁▄▄▅▅▅▅▅▆▅▆▅▅█▆▇▆▇▆▆█▆▇▆▇▆▆▇▆▆▆▆▆▆▇▆

0,1
Loss/test_B,8.33123
Loss/test_random_replay_32mb_A,17.58275
Loss/train_replay_A_B,4.79243
Loss/val_replay_A_B,8.07463
eval/avg_learning_acc,0.82944
eval/total_forgetting,0.067
mIoU/test_B,0.87115
mIoU/test_random_replay_32mb_A,0.72074
mIoU/train_replay_A_B,0.96523
mIoU/val_replay_A_B,0.84291


In [None]:
per_category_mIoU_A,per_category_mIoU_B, per_category_mIoU_forgetting_A,per_category_avg_learning_acc,per_category_forgetting,

(array([0.95360882, 0.91871854, 0.687211  , 0.63869946, 0.28185184,
        0.81372453, 0.79456505, 0.78816018, 0.91600346, 0.94137199,
        0.93113795]),
 array([0.95921619, 0.71816284, 0.74091253, 0.67861785, 0.90445688,
        0.93730329, 0.91828812, 0.92046192, 0.90004802, 0.93650766,
        0.96862397]),
 array([0.95080557, 0.91341507, 0.54542797, 0.43005187, 0.20140019,
        0.78958616, 0.67288361, 0.72030864, 0.87443118, 0.92154552,
        0.9082475 ]),
 array([0.9564125 , 0.81844069, 0.71406176, 0.65865865, 0.59315436,
        0.87551391, 0.85642659, 0.85431105, 0.90802574, 0.93893983,
        0.94988096]),
 array([0.00280326, 0.00530346, 0.14178303, 0.20864758, 0.08045165,
        0.02413836, 0.12168145, 0.06785154, 0.04157228, 0.01982647,
        0.02289045]))