In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import os
from utils.common import (
    m2f_dataset_collate,
    m2f_extract_pred_maps_and_masks,
    BG_VALUE_255,
    set_seed,
    pixel_mean_std,
    CADIS_PIXEL_MEAN,
    CADIS_PIXEL_STD,
    CAT1K_PIXEL_MEAN,
    CAT1K_PIXEL_STD,
)
from utils.dataset_utils import (
    get_cadisv2_dataset,
    get_cataract1k_dataset,
    ZEISS_CATEGORIES,
)
from utils.medical_datasets import Mask2FormerDataset
from transformers import (
    Mask2FormerForUniversalSegmentation,
    SwinModel,
    SwinConfig,
    Mask2FormerConfig,
    AutoImageProcessor,
    Mask2FormerImageProcessor
)
from torch.utils.data import DataLoader
import evaluate
import torch.optim as optim
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from dotenv import load_dotenv
import wandb

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
set_seed(42) # seed everything

In [4]:
NUM_CLASSES = len(ZEISS_CATEGORIES) - 3  # Remove class incremental
SWIN_BACKBONE = "microsoft/swin-tiny-patch4-window7-224"#"microsoft/swin-large-patch4-window12-384"

# Download pretrained swin model
swin_model = SwinModel.from_pretrained(
    SWIN_BACKBONE, out_features=["stage1", "stage2", "stage3", "stage4"]
)
swin_config = SwinConfig.from_pretrained(
    SWIN_BACKBONE, out_features=["stage1", "stage2", "stage3", "stage4"]
)

# Create Mask2Former configuration based on Swin's configuration
mask2former_config = Mask2FormerConfig(
    backbone_config=swin_config, num_labels=NUM_CLASSES #, ignore_value=BG_VALUE
)

# Create the Mask2Former model with this configuration
model = Mask2FormerForUniversalSegmentation(mask2former_config)

# Reuse pretrained parameters
for swin_param, m2f_param in zip(
    swin_model.named_parameters(),
    model.model.pixel_level_module.encoder.named_parameters(),
):
    m2f_param_name = f"model.pixel_level_module.encoder.{m2f_param[0]}"

    if swin_param[0] == m2f_param[0]:
        model.state_dict()[m2f_param_name].copy_(swin_param[1])
        continue

    print(f"Not Matched: {m2f_param[0]} != {swin_param[0]}")



Not Matched: hidden_states_norms.stage1.weight != layernorm.weight
Not Matched: hidden_states_norms.stage1.bias != layernorm.bias


In [5]:
# Helper function to load datasets
def load_dataset(dataset_getter, data_path, domain_incremental):
    return dataset_getter(data_path, domain_incremental=domain_incremental)


# Helper function to create dataloaders for a dataset
def create_dataloaders(
    dataset, batch_size, shuffle, num_workers, drop_last, pin_memory, collate_fn
):
    return {
        "train": DataLoader(
            dataset["train"],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        ),
        "val": DataLoader(
            dataset["val"],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        ),
        "test": DataLoader(
            dataset["test"],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        ),
    }


# Load datasets
datasets = {
    "A": load_dataset(get_cadisv2_dataset, "../../storage/data/CaDISv2", True),
    "B": load_dataset(get_cataract1k_dataset, "../../storage/data/cataract-1k", True),
}

# pixel_mean_A,pixel_std_A=pixel_mean_std(datasets["A"][0])
pixel_mean_A = CADIS_PIXEL_MEAN
pixel_std_A = CADIS_PIXEL_STD

# pixel_mean_B,pixel_std_B=pixel_mean_std(datasets["B"][0])
pixel_mean_B = CAT1K_PIXEL_MEAN
pixel_std_B = CAT1K_PIXEL_STD

# Define preprocessor
swin_processor = AutoImageProcessor.from_pretrained(SWIN_BACKBONE)
m2f_preprocessor_A = Mask2FormerImageProcessor(
    reduce_labels=True,
    ignore_index=255,
    do_resize=False,
    do_rescale=True,
    do_normalize=True,
    image_std=pixel_std_A,
    image_mean=pixel_mean_A,
)

m2f_preprocessor_B = Mask2FormerImageProcessor(
    reduce_labels=True,
    ignore_index=255,
    do_resize=False,
    do_rescale=True,
    do_normalize=True,
    image_std=pixel_std_B,
    image_mean=pixel_mean_B,
)

# Create Mask2Former Datasets
m2f_datasets = {
    "A": {
        "train": Mask2FormerDataset(datasets["A"][0], m2f_preprocessor_A),
        "val": Mask2FormerDataset(datasets["A"][1], m2f_preprocessor_A),
        "test": Mask2FormerDataset(datasets["A"][2], m2f_preprocessor_A),
    },
    "B": {
        "train": Mask2FormerDataset(datasets["B"][0], m2f_preprocessor_B),
        "val": Mask2FormerDataset(datasets["B"][1], m2f_preprocessor_B),
        "test": Mask2FormerDataset(datasets["B"][2], m2f_preprocessor_B),
    },
}

# DataLoader parameters
N_WORKERS = 4
BATCH_SIZE = 16
SHUFFLE = True
DROP_LAST = True

dataloader_params = {
    "batch_size": BATCH_SIZE,
    "shuffle": SHUFFLE,
    "num_workers": N_WORKERS,
    "drop_last": DROP_LAST,
    "pin_memory": True,
    "collate_fn": m2f_dataset_collate,
}

# Create DataLoaders
dataloaders = {
    key: create_dataloaders(m2f_datasets[key], **dataloader_params)
    for key in m2f_datasets
}

print(dataloaders)



{'A': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7fb28515d0a0>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7fb285aeaf30>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x7fb28515e390>}, 'B': {'train': <torch.utils.data.dataloader.DataLoader object at 0x7fb28515d1f0>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x7fb28515d8e0>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x7fb28515c3e0>}}


In [6]:
# Check if CUDA is available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

Using device: cuda


In [7]:
# Tensorboard setup
out_dir="outputs/"
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
if not os.path.exists(out_dir+"runs"):
    os.makedirs(out_dir+"runs")
%load_ext tensorboard
%tensorboard --logdir outputs/runs

Reusing TensorBoard on port 6006 (pid 1160), started 0:03:23 ago. (Use '!kill 1160' to kill it.)

In [13]:
#!CUDA_LAUNCH_BLOCKING=1

# First train on dataset A

In [8]:
# Training
NUM_EPOCHS = 110 #200
LEARNING_RATE = 1e-4
LR_MULTIPLIER = 0.1
BACKBONE_LR = LEARNING_RATE * LR_MULTIPLIER
WEIGHT_DECAY = 0.05
metric = evaluate.load("mean_iou")
encoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.encoder")
]
decoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.decoder")
]
transformer_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.transformer_module")
]
optimizer = optim.AdamW(
    [
        {"params": encoder_params, "lr": BACKBONE_LR},
        {"params": decoder_params},
        {"params": transformer_params},
    ],
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

scheduler = optim.lr_scheduler.PolynomialLR(
    optimizer, total_iters=NUM_EPOCHS, power=0.9
)

In [9]:
# WandB for team usage !!!!

wandb.login() # use this one if a different person is going to run the notebook
#wandb.login(relogin=False) # if the same person in the last run is going to run the notebook again


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mge85ket[0m ([33mcontinual-learning-tum[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [10]:
wandb.init(
    project="M2F_original",
    config={
        "learning_rate": LEARNING_RATE,
        "learning_rate_multiplier": LR_MULTIPLIER,
        "backbone_learning_rate": BACKBONE_LR,
        "learning_rate_scheduler": scheduler.__class__.__name__,
        "optimizer": optimizer.__class__.__name__,
        "backbone": SWIN_BACKBONE,
        "m2f_preprocessor": m2f_preprocessor_B.__dict__,
        "m2f_model_config": model.config
    },
    name="M2F-Swin-Tiny-Train_Cataract1K",
    notes="M2F with tiny Swin backbone pretrained on ImageNet-1K. \
        Scenario: Train on B, Test on B"
)

In [11]:
# Tensorboard logging
writer = SummaryWriter(log_dir=out_dir + "runs")

# Model checkpointing
base_model_name="m2f_swin_backbone_train_cataract1k"
model_dir = out_dir + "models/"
if not os.path.exists(model_dir):
    print("Store weights in: ", model_dir)
    os.makedirs(model_dir)

best_model_dir = model_dir + f"{base_model_name}/best_model/"
if not os.path.exists(best_model_dir):
    print("Store best model weights in: ", best_model_dir)
    os.makedirs(best_model_dir)
final_model_dir = model_dir + f"{base_model_name}/final_model/"
if not os.path.exists(final_model_dir):
    print("Store final model weights in: ", final_model_dir)
    os.makedirs(final_model_dir)
    

In [12]:
# Save the preprocessor
m2f_preprocessor_B.save_pretrained(model_dir + base_model_name)

['outputs/models/m2f_swin_backbone_train_cataract1k/preprocessor_config.json']

In [13]:
# To avoid making stupid errors
CURR_TASK = "B"

# For storing the model
best_val_metric = -np.inf

# Move model to device
model.to(device)

for epoch in range(NUM_EPOCHS):
    model.train()
    train_running_loss = 0.0
    val_running_loss = 0.0

    # Set up tqdm for the training loop
    train_loader = tqdm(
        dataloaders[CURR_TASK]["train"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Training"
    )

    for batch in train_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]

        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss

        # Compute gradient and perform step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        train_running_loss += current_loss
        train_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_B
        )
        metric.add_batch(references=masks, predictions=pred_maps)

    # After compute the batches that were added are deleted
    mean_train_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
    )["mean_iou"]

    # Validation phase
    model.eval()
    val_loader = tqdm(
        dataloaders[CURR_TASK]["val"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Validation"
    )
    with torch.no_grad():
        for batch in val_loader:
            # Move everything to the device
            batch["pixel_values"] = batch["pixel_values"].to(device)
            batch["pixel_mask"] = batch["pixel_mask"].to(device)
            batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
            batch["class_labels"] = [
                entry.to(device) for entry in batch["class_labels"]
            ]
            # Compute output and loss
            outputs = model(**batch)

            loss = outputs.loss
            # Record losses
            current_loss = loss.item() * batch["pixel_values"].size(0)
            val_running_loss += current_loss
            val_loader.set_postfix(loss=f"{current_loss:.4f}")

            # Extract and compute metrics
            pred_maps, masks = m2f_extract_pred_maps_and_masks(
                batch, outputs, m2f_preprocessor_B
            )
            metric.add_batch(references=masks, predictions=pred_maps)

    # After compute the batches that were added are deleted
    mean_val_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
    )["mean_iou"]

    epoch_train_loss = train_running_loss / len(dataloaders[CURR_TASK]["train"].dataset)
    epoch_val_loss = val_running_loss / len(dataloaders[CURR_TASK]["val"].dataset)

    writer.add_scalar(f"Loss/train_{base_model_name}_{CURR_TASK}", epoch_train_loss, epoch + 1)
    writer.add_scalar(f"Loss/val_{base_model_name}_{CURR_TASK}", epoch_val_loss, epoch + 1)
    writer.add_scalar(f"mIoU/train_{base_model_name}_{CURR_TASK}", mean_train_iou, epoch + 1)
    writer.add_scalar(f"mIoU/val_{base_model_name}_{CURR_TASK}", mean_val_iou, epoch + 1)

    wandb.log({
        f"Loss/train_{CURR_TASK}": epoch_train_loss,
        f"Loss/val_{CURR_TASK}": epoch_val_loss,
        f"mIoU/train_{CURR_TASK}": mean_train_iou,
        f"mIoU/val_{CURR_TASK}": mean_val_iou
    })

    if mean_val_iou > best_val_metric:
        best_val_metric = mean_val_iou
        model.save_pretrained(f"{best_model_dir}{CURR_TASK}/")

    tqdm.write(
        f"Epoch {epoch + 1}/{NUM_EPOCHS}, Train Loss: {epoch_train_loss:.4f}, Train mIoU: {mean_train_iou:.4f}, Validation Loss: {epoch_val_loss:.4f}, Validation mIoU: {mean_val_iou:.4f}"
    )

Epoch 1/110 Training: 100%|██████████| 112/112 [06:32<00:00,  3.51s/it, loss=975.0126] 
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
Epoch 1/110 Validation: 100%|██████████| 14/14 [00:42<00:00,  3.03s/it, loss=1006.0431]


Epoch 1/110, Train Loss: 75.3267, Train mIoU: 0.0618, Validation Loss: 62.1338, Validation mIoU: 0.0665


Epoch 2/110 Training: 100%|██████████| 112/112 [06:06<00:00,  3.28s/it, loss=461.8261]
Epoch 2/110 Validation: 100%|██████████| 14/14 [00:40<00:00,  2.88s/it, loss=442.4552]


Epoch 2/110, Train Loss: 45.4344, Train mIoU: 0.1487, Validation Loss: 29.5133, Validation mIoU: 0.2229


Epoch 3/110 Training: 100%|██████████| 112/112 [06:15<00:00,  3.36s/it, loss=266.4327]
Epoch 3/110 Validation: 100%|██████████| 14/14 [00:42<00:00,  3.07s/it, loss=306.1015]


Epoch 3/110, Train Loss: 24.0933, Train mIoU: 0.2890, Validation Loss: 17.5510, Validation mIoU: 0.3621


Epoch 4/110 Training: 100%|██████████| 112/112 [06:14<00:00,  3.34s/it, loss=219.0459]
Epoch 4/110 Validation: 100%|██████████| 14/14 [00:40<00:00,  2.90s/it, loss=210.8945]


Epoch 4/110, Train Loss: 16.4646, Train mIoU: 0.4183, Validation Loss: 14.8467, Validation mIoU: 0.4480


Epoch 5/110 Training: 100%|██████████| 112/112 [06:13<00:00,  3.34s/it, loss=262.1952]
Epoch 5/110 Validation: 100%|██████████| 14/14 [00:40<00:00,  2.92s/it, loss=190.0068]


Epoch 5/110, Train Loss: 14.1583, Train mIoU: 0.5064, Validation Loss: 13.2606, Validation mIoU: 0.4753


Epoch 6/110 Training: 100%|██████████| 112/112 [06:15<00:00,  3.35s/it, loss=169.2522]
Epoch 6/110 Validation: 100%|██████████| 14/14 [00:36<00:00,  2.63s/it, loss=158.6317]


Epoch 6/110, Train Loss: 12.5061, Train mIoU: 0.5331, Validation Loss: 11.7312, Validation mIoU: 0.5403


Epoch 7/110 Training: 100%|██████████| 112/112 [06:25<00:00,  3.44s/it, loss=189.6375]
Epoch 7/110 Validation: 100%|██████████| 14/14 [00:43<00:00,  3.08s/it, loss=185.6576]


Epoch 7/110, Train Loss: 11.9068, Train mIoU: 0.5939, Validation Loss: 12.0291, Validation mIoU: 0.5276


Epoch 8/110 Training: 100%|██████████| 112/112 [06:13<00:00,  3.33s/it, loss=178.7290]
Epoch 8/110 Validation: 100%|██████████| 14/14 [00:39<00:00,  2.79s/it, loss=181.7968]


Epoch 8/110, Train Loss: 10.8959, Train mIoU: 0.6339, Validation Loss: 11.3294, Validation mIoU: 0.5657


Epoch 9/110 Training: 100%|██████████| 112/112 [06:09<00:00,  3.30s/it, loss=199.5288]
Epoch 9/110 Validation: 100%|██████████| 14/14 [00:38<00:00,  2.76s/it, loss=160.8610]


Epoch 9/110, Train Loss: 10.1826, Train mIoU: 0.7084, Validation Loss: 11.5274, Validation mIoU: 0.6590


Epoch 10/110 Training: 100%|██████████| 112/112 [06:15<00:00,  3.35s/it, loss=143.8205]
Epoch 10/110 Validation: 100%|██████████| 14/14 [00:39<00:00,  2.82s/it, loss=188.2857]


Epoch 10/110, Train Loss: 9.3202, Train mIoU: 0.7442, Validation Loss: 9.7845, Validation mIoU: 0.7257


Epoch 11/110 Training: 100%|██████████| 112/112 [06:08<00:00,  3.29s/it, loss=138.0228]
Epoch 11/110 Validation: 100%|██████████| 14/14 [00:40<00:00,  2.92s/it, loss=144.9584]


Epoch 11/110, Train Loss: 8.6737, Train mIoU: 0.7733, Validation Loss: 9.2874, Validation mIoU: 0.7604


Epoch 12/110 Training: 100%|██████████| 112/112 [06:03<00:00,  3.25s/it, loss=118.3179]
Epoch 12/110 Validation: 100%|██████████| 14/14 [00:37<00:00,  2.69s/it, loss=178.5293]


Epoch 12/110, Train Loss: 8.2745, Train mIoU: 0.7976, Validation Loss: 9.2222, Validation mIoU: 0.8045


Epoch 13/110 Training: 100%|██████████| 112/112 [06:02<00:00,  3.23s/it, loss=134.6131]
Epoch 13/110 Validation: 100%|██████████| 14/14 [00:40<00:00,  2.91s/it, loss=135.1642]


Epoch 13/110, Train Loss: 7.9321, Train mIoU: 0.7999, Validation Loss: 8.9584, Validation mIoU: 0.8022


Epoch 14/110 Training: 100%|██████████| 112/112 [06:07<00:00,  3.28s/it, loss=135.4963]
Epoch 14/110 Validation: 100%|██████████| 14/14 [00:39<00:00,  2.85s/it, loss=130.4563]


Epoch 14/110, Train Loss: 7.4877, Train mIoU: 0.8339, Validation Loss: 9.2101, Validation mIoU: 0.7164


Epoch 15/110 Training: 100%|██████████| 112/112 [06:14<00:00,  3.34s/it, loss=123.4660]
Epoch 15/110 Validation: 100%|██████████| 14/14 [00:39<00:00,  2.84s/it, loss=135.2343]


Epoch 15/110, Train Loss: 7.3631, Train mIoU: 0.8319, Validation Loss: 8.6883, Validation mIoU: 0.7617


Epoch 16/110 Training: 100%|██████████| 112/112 [06:02<00:00,  3.24s/it, loss=146.7481]
Epoch 16/110 Validation: 100%|██████████| 14/14 [00:38<00:00,  2.73s/it, loss=130.6294]


Epoch 16/110, Train Loss: 8.7848, Train mIoU: 0.7975, Validation Loss: 9.5053, Validation mIoU: 0.7549


Epoch 17/110 Training: 100%|██████████| 112/112 [06:03<00:00,  3.25s/it, loss=131.3831]
Epoch 17/110 Validation: 100%|██████████| 14/14 [00:43<00:00,  3.13s/it, loss=143.1883]


Epoch 17/110, Train Loss: 8.1424, Train mIoU: 0.8340, Validation Loss: 10.1808, Validation mIoU: 0.7296


Epoch 18/110 Training: 100%|██████████| 112/112 [06:05<00:00,  3.27s/it, loss=105.6142]
Epoch 18/110 Validation: 100%|██████████| 14/14 [00:42<00:00,  3.04s/it, loss=155.7063]


Epoch 18/110, Train Loss: 7.9686, Train mIoU: 0.8031, Validation Loss: 8.7772, Validation mIoU: 0.7213


Epoch 19/110 Training: 100%|██████████| 112/112 [06:03<00:00,  3.25s/it, loss=107.9434]
Epoch 19/110 Validation: 100%|██████████| 14/14 [00:42<00:00,  3.01s/it, loss=179.7084]


Epoch 19/110, Train Loss: 7.6105, Train mIoU: 0.8204, Validation Loss: 8.6651, Validation mIoU: 0.7385


Epoch 20/110 Training: 100%|██████████| 112/112 [06:01<00:00,  3.23s/it, loss=109.3766]
Epoch 20/110 Validation: 100%|██████████| 14/14 [00:40<00:00,  2.89s/it, loss=131.6930]


Epoch 20/110, Train Loss: 6.6917, Train mIoU: 0.8366, Validation Loss: 8.8586, Validation mIoU: 0.7657


Epoch 21/110 Training: 100%|██████████| 112/112 [06:13<00:00,  3.34s/it, loss=92.8285] 
Epoch 21/110 Validation: 100%|██████████| 14/14 [00:40<00:00,  2.86s/it, loss=131.5689]


Epoch 21/110, Train Loss: 6.6017, Train mIoU: 0.8489, Validation Loss: 8.4139, Validation mIoU: 0.8195


Epoch 22/110 Training: 100%|██████████| 112/112 [06:20<00:00,  3.40s/it, loss=93.8675] 
Epoch 22/110 Validation: 100%|██████████| 14/14 [00:42<00:00,  3.01s/it, loss=114.3569]


Epoch 22/110, Train Loss: 6.4940, Train mIoU: 0.8516, Validation Loss: 8.1034, Validation mIoU: 0.8060


Epoch 23/110 Training: 100%|██████████| 112/112 [06:03<00:00,  3.24s/it, loss=86.4454] 
Epoch 23/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  2.99s/it, loss=111.6324]


Epoch 23/110, Train Loss: 6.3260, Train mIoU: 0.8548, Validation Loss: 8.4268, Validation mIoU: 0.8017


Epoch 24/110 Training: 100%|██████████| 112/112 [06:04<00:00,  3.25s/it, loss=96.1467] 
Epoch 24/110 Validation: 100%|██████████| 14/14 [00:42<00:00,  3.01s/it, loss=172.8296]


Epoch 24/110, Train Loss: 6.2367, Train mIoU: 0.8508, Validation Loss: 7.9212, Validation mIoU: 0.7320


Epoch 25/110 Training: 100%|██████████| 112/112 [06:03<00:00,  3.25s/it, loss=86.6870] 
Epoch 25/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  2.97s/it, loss=106.4672]


Epoch 25/110, Train Loss: 6.1768, Train mIoU: 0.8565, Validation Loss: 8.3396, Validation mIoU: 0.7315


Epoch 26/110 Training: 100%|██████████| 112/112 [06:11<00:00,  3.32s/it, loss=86.2516] 
Epoch 26/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  2.95s/it, loss=149.0842]


Epoch 26/110, Train Loss: 5.8611, Train mIoU: 0.8657, Validation Loss: 7.8289, Validation mIoU: 0.8096


Epoch 27/110 Training: 100%|██████████| 112/112 [06:20<00:00,  3.39s/it, loss=88.5464] 
Epoch 27/110 Validation: 100%|██████████| 14/14 [00:45<00:00,  3.23s/it, loss=113.5185]


Epoch 27/110, Train Loss: 5.7306, Train mIoU: 0.8602, Validation Loss: 8.0084, Validation mIoU: 0.7992


Epoch 28/110 Training: 100%|██████████| 112/112 [06:09<00:00,  3.30s/it, loss=85.1345] 
Epoch 28/110 Validation: 100%|██████████| 14/14 [00:44<00:00,  3.17s/it, loss=152.1606]


Epoch 28/110, Train Loss: 5.7190, Train mIoU: 0.8643, Validation Loss: 8.5803, Validation mIoU: 0.7827


Epoch 29/110 Training: 100%|██████████| 112/112 [06:19<00:00,  3.39s/it, loss=89.2868] 
Epoch 29/110 Validation: 100%|██████████| 14/14 [00:43<00:00,  3.08s/it, loss=127.2457]


Epoch 29/110, Train Loss: 5.6806, Train mIoU: 0.8722, Validation Loss: 8.0646, Validation mIoU: 0.7596


Epoch 30/110 Training: 100%|██████████| 112/112 [06:08<00:00,  3.29s/it, loss=85.4323]
Epoch 30/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  2.97s/it, loss=104.7291]


Epoch 30/110, Train Loss: 5.3630, Train mIoU: 0.8683, Validation Loss: 7.8403, Validation mIoU: 0.7345


Epoch 31/110 Training: 100%|██████████| 112/112 [06:16<00:00,  3.37s/it, loss=81.7401]
Epoch 31/110 Validation: 100%|██████████| 14/14 [00:39<00:00,  2.80s/it, loss=132.5495]


Epoch 31/110, Train Loss: 5.2835, Train mIoU: 0.8709, Validation Loss: 7.8547, Validation mIoU: 0.7773


Epoch 32/110 Training: 100%|██████████| 112/112 [06:16<00:00,  3.36s/it, loss=81.6292] 
Epoch 32/110 Validation: 100%|██████████| 14/14 [00:45<00:00,  3.29s/it, loss=123.8131]


Epoch 32/110, Train Loss: 5.5809, Train mIoU: 0.8699, Validation Loss: 8.3859, Validation mIoU: 0.7514


Epoch 33/110 Training: 100%|██████████| 112/112 [06:18<00:00,  3.38s/it, loss=96.2843] 
Epoch 33/110 Validation: 100%|██████████| 14/14 [00:42<00:00,  3.02s/it, loss=123.3301]


Epoch 33/110, Train Loss: 5.3337, Train mIoU: 0.8777, Validation Loss: 8.0973, Validation mIoU: 0.7641


Epoch 34/110 Training: 100%|██████████| 112/112 [06:08<00:00,  3.29s/it, loss=80.1332] 
Epoch 34/110 Validation: 100%|██████████| 14/14 [00:38<00:00,  2.77s/it, loss=145.1810]


Epoch 34/110, Train Loss: 5.3706, Train mIoU: 0.8795, Validation Loss: 8.2540, Validation mIoU: 0.7316


Epoch 35/110 Training: 100%|██████████| 112/112 [06:12<00:00,  3.32s/it, loss=79.2056]
Epoch 35/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  2.93s/it, loss=90.5401] 


Epoch 35/110, Train Loss: 5.3760, Train mIoU: 0.8645, Validation Loss: 7.8298, Validation mIoU: 0.7824


Epoch 36/110 Training: 100%|██████████| 112/112 [06:12<00:00,  3.32s/it, loss=83.2289]
Epoch 36/110 Validation: 100%|██████████| 14/14 [00:40<00:00,  2.89s/it, loss=113.3196]


Epoch 36/110, Train Loss: 4.9911, Train mIoU: 0.8756, Validation Loss: 7.9152, Validation mIoU: 0.7302


Epoch 37/110 Training: 100%|██████████| 112/112 [06:19<00:00,  3.39s/it, loss=94.3741] 
Epoch 37/110 Validation: 100%|██████████| 14/14 [00:44<00:00,  3.16s/it, loss=124.9485]


Epoch 37/110, Train Loss: 4.9814, Train mIoU: 0.8717, Validation Loss: 8.1908, Validation mIoU: 0.7248


Epoch 38/110 Training: 100%|██████████| 112/112 [06:15<00:00,  3.35s/it, loss=72.6202] 
Epoch 38/110 Validation: 100%|██████████| 14/14 [00:40<00:00,  2.92s/it, loss=117.0646]


Epoch 38/110, Train Loss: 5.1210, Train mIoU: 0.8597, Validation Loss: 7.8168, Validation mIoU: 0.7367


Epoch 39/110 Training: 100%|██████████| 112/112 [06:13<00:00,  3.33s/it, loss=87.7326] 
Epoch 39/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  2.98s/it, loss=143.1897]


Epoch 39/110, Train Loss: 5.0666, Train mIoU: 0.8498, Validation Loss: 8.4450, Validation mIoU: 0.7123


Epoch 40/110 Training: 100%|██████████| 112/112 [06:12<00:00,  3.32s/it, loss=80.0509]
Epoch 40/110 Validation: 100%|██████████| 14/14 [00:43<00:00,  3.10s/it, loss=145.0665]


Epoch 40/110, Train Loss: 4.9892, Train mIoU: 0.8669, Validation Loss: 8.3220, Validation mIoU: 0.7888


Epoch 41/110 Training: 100%|██████████| 112/112 [06:17<00:00,  3.37s/it, loss=78.3698]
Epoch 41/110 Validation: 100%|██████████| 14/14 [00:42<00:00,  3.07s/it, loss=208.1985]


Epoch 41/110, Train Loss: 4.9615, Train mIoU: 0.8798, Validation Loss: 8.2251, Validation mIoU: 0.6950


Epoch 42/110 Training: 100%|██████████| 112/112 [06:21<00:00,  3.41s/it, loss=73.0255]
Epoch 42/110 Validation: 100%|██████████| 14/14 [00:42<00:00,  3.04s/it, loss=112.1368]


Epoch 42/110, Train Loss: 5.1140, Train mIoU: 0.8599, Validation Loss: 7.9561, Validation mIoU: 0.7556


Epoch 43/110 Training: 100%|██████████| 112/112 [06:12<00:00,  3.33s/it, loss=91.2770]
Epoch 43/110 Validation: 100%|██████████| 14/14 [00:42<00:00,  3.05s/it, loss=130.8797]


Epoch 43/110, Train Loss: 4.7186, Train mIoU: 0.8826, Validation Loss: 8.2719, Validation mIoU: 0.7659


Epoch 44/110 Training: 100%|██████████| 112/112 [06:10<00:00,  3.31s/it, loss=69.2326]
Epoch 44/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  2.99s/it, loss=189.1159]


Epoch 44/110, Train Loss: 4.5593, Train mIoU: 0.8772, Validation Loss: 8.0593, Validation mIoU: 0.7429


Epoch 45/110 Training: 100%|██████████| 112/112 [06:17<00:00,  3.37s/it, loss=81.6108]
Epoch 45/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  2.95s/it, loss=110.7702]


Epoch 45/110, Train Loss: 4.4449, Train mIoU: 0.8863, Validation Loss: 8.1324, Validation mIoU: 0.7145


Epoch 46/110 Training: 100%|██████████| 112/112 [06:07<00:00,  3.28s/it, loss=74.3946]
Epoch 46/110 Validation: 100%|██████████| 14/14 [00:35<00:00,  2.56s/it, loss=109.6624]


Epoch 46/110, Train Loss: 4.6429, Train mIoU: 0.8741, Validation Loss: 7.9114, Validation mIoU: 0.7933


Epoch 47/110 Training: 100%|██████████| 112/112 [06:05<00:00,  3.26s/it, loss=67.3759]
Epoch 47/110 Validation: 100%|██████████| 14/14 [00:42<00:00,  3.07s/it, loss=96.8803] 


Epoch 47/110, Train Loss: 4.5335, Train mIoU: 0.8829, Validation Loss: 7.7874, Validation mIoU: 0.7503


Epoch 48/110 Training: 100%|██████████| 112/112 [06:18<00:00,  3.38s/it, loss=71.0857]
Epoch 48/110 Validation: 100%|██████████| 14/14 [00:40<00:00,  2.88s/it, loss=108.4149]


Epoch 48/110, Train Loss: 4.7876, Train mIoU: 0.8798, Validation Loss: 8.2846, Validation mIoU: 0.7375


Epoch 49/110 Training: 100%|██████████| 112/112 [06:13<00:00,  3.33s/it, loss=76.9127]
Epoch 49/110 Validation: 100%|██████████| 14/14 [00:42<00:00,  3.01s/it, loss=187.3259]


Epoch 49/110, Train Loss: 4.4574, Train mIoU: 0.8821, Validation Loss: 8.0079, Validation mIoU: 0.7443


Epoch 50/110 Training: 100%|██████████| 112/112 [06:10<00:00,  3.31s/it, loss=67.6232]
Epoch 50/110 Validation: 100%|██████████| 14/14 [00:39<00:00,  2.80s/it, loss=176.9978]


Epoch 50/110, Train Loss: 4.2267, Train mIoU: 0.8862, Validation Loss: 8.3417, Validation mIoU: 0.7605


Epoch 51/110 Training: 100%|██████████| 112/112 [06:19<00:00,  3.39s/it, loss=83.3491]
Epoch 51/110 Validation: 100%|██████████| 14/14 [00:44<00:00,  3.16s/it, loss=99.2848] 


Epoch 51/110, Train Loss: 4.1729, Train mIoU: 0.8866, Validation Loss: 8.3512, Validation mIoU: 0.7797


Epoch 52/110 Training: 100%|██████████| 112/112 [06:17<00:00,  3.37s/it, loss=71.0529]
Epoch 52/110 Validation: 100%|██████████| 14/14 [00:40<00:00,  2.93s/it, loss=119.8707]


Epoch 52/110, Train Loss: 4.4247, Train mIoU: 0.8523, Validation Loss: 7.8449, Validation mIoU: 0.7845


Epoch 53/110 Training: 100%|██████████| 112/112 [06:21<00:00,  3.40s/it, loss=65.2648]
Epoch 53/110 Validation: 100%|██████████| 14/14 [00:42<00:00,  3.01s/it, loss=145.8271]


Epoch 53/110, Train Loss: 4.2011, Train mIoU: 0.8731, Validation Loss: 7.8329, Validation mIoU: 0.8059


Epoch 54/110 Training: 100%|██████████| 112/112 [06:18<00:00,  3.38s/it, loss=80.8462]
Epoch 54/110 Validation: 100%|██████████| 14/14 [00:39<00:00,  2.81s/it, loss=90.8396] 


Epoch 54/110, Train Loss: 4.1142, Train mIoU: 0.8858, Validation Loss: 7.7983, Validation mIoU: 0.7539


Epoch 55/110 Training: 100%|██████████| 112/112 [06:23<00:00,  3.43s/it, loss=76.8546]
Epoch 55/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  2.97s/it, loss=145.0527]


Epoch 55/110, Train Loss: 4.2444, Train mIoU: 0.8883, Validation Loss: 8.2718, Validation mIoU: 0.6532


Epoch 56/110 Training: 100%|██████████| 112/112 [06:21<00:00,  3.41s/it, loss=74.6158] 
Epoch 56/110 Validation: 100%|██████████| 14/14 [00:39<00:00,  2.80s/it, loss=124.1848]


Epoch 56/110, Train Loss: 4.9492, Train mIoU: 0.8817, Validation Loss: 8.4248, Validation mIoU: 0.7481


Epoch 57/110 Training: 100%|██████████| 112/112 [06:07<00:00,  3.28s/it, loss=75.7085]
Epoch 57/110 Validation: 100%|██████████| 14/14 [00:43<00:00,  3.11s/it, loss=118.8639]


Epoch 57/110, Train Loss: 4.5461, Train mIoU: 0.8847, Validation Loss: 8.3558, Validation mIoU: 0.7695


Epoch 58/110 Training: 100%|██████████| 112/112 [06:11<00:00,  3.31s/it, loss=73.0761]
Epoch 58/110 Validation: 100%|██████████| 14/14 [00:37<00:00,  2.67s/it, loss=174.2192]


Epoch 58/110, Train Loss: 4.2883, Train mIoU: 0.8871, Validation Loss: 8.6223, Validation mIoU: 0.7196


Epoch 59/110 Training: 100%|██████████| 112/112 [06:02<00:00,  3.23s/it, loss=77.3474]
Epoch 59/110 Validation: 100%|██████████| 14/14 [00:37<00:00,  2.71s/it, loss=134.5484]


Epoch 59/110, Train Loss: 4.1710, Train mIoU: 0.8861, Validation Loss: 7.9430, Validation mIoU: 0.7000


Epoch 60/110 Training: 100%|██████████| 112/112 [06:08<00:00,  3.29s/it, loss=63.5938]
Epoch 60/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  2.98s/it, loss=120.1027]


Epoch 60/110, Train Loss: 4.0198, Train mIoU: 0.8828, Validation Loss: 7.8530, Validation mIoU: 0.7261


Epoch 61/110 Training: 100%|██████████| 112/112 [06:07<00:00,  3.28s/it, loss=58.3019]
Epoch 61/110 Validation: 100%|██████████| 14/14 [00:42<00:00,  3.04s/it, loss=118.4647]


Epoch 61/110, Train Loss: 4.0210, Train mIoU: 0.8632, Validation Loss: 8.2386, Validation mIoU: 0.6919


Epoch 62/110 Training: 100%|██████████| 112/112 [06:06<00:00,  3.27s/it, loss=70.6514]
Epoch 62/110 Validation: 100%|██████████| 14/14 [00:35<00:00,  2.54s/it, loss=142.8454]


Epoch 62/110, Train Loss: 3.8779, Train mIoU: 0.8777, Validation Loss: 7.9819, Validation mIoU: 0.7454


Epoch 63/110 Training: 100%|██████████| 112/112 [06:13<00:00,  3.34s/it, loss=63.6232]
Epoch 63/110 Validation: 100%|██████████| 14/14 [00:40<00:00,  2.87s/it, loss=148.8664]


Epoch 63/110, Train Loss: 3.7871, Train mIoU: 0.8908, Validation Loss: 8.0283, Validation mIoU: 0.7261


Epoch 64/110 Training: 100%|██████████| 112/112 [06:10<00:00,  3.31s/it, loss=78.5261]
Epoch 64/110 Validation: 100%|██████████| 14/14 [00:43<00:00,  3.08s/it, loss=211.7089]


Epoch 64/110, Train Loss: 3.9554, Train mIoU: 0.8901, Validation Loss: 10.7008, Validation mIoU: 0.7769


Epoch 65/110 Training: 100%|██████████| 112/112 [06:14<00:00,  3.35s/it, loss=66.3837]
Epoch 65/110 Validation: 100%|██████████| 14/14 [00:40<00:00,  2.93s/it, loss=125.1017]


Epoch 65/110, Train Loss: 4.6485, Train mIoU: 0.8486, Validation Loss: 9.2537, Validation mIoU: 0.7399


Epoch 66/110 Training: 100%|██████████| 112/112 [06:16<00:00,  3.36s/it, loss=68.1784]
Epoch 66/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  2.96s/it, loss=121.4528]


Epoch 66/110, Train Loss: 4.0988, Train mIoU: 0.8852, Validation Loss: 7.7685, Validation mIoU: 0.7677


Epoch 67/110 Training: 100%|██████████| 112/112 [06:10<00:00,  3.31s/it, loss=67.0510]
Epoch 67/110 Validation: 100%|██████████| 14/14 [00:38<00:00,  2.78s/it, loss=129.2693]


Epoch 67/110, Train Loss: 3.7981, Train mIoU: 0.8821, Validation Loss: 8.0081, Validation mIoU: 0.7714


Epoch 68/110 Training: 100%|██████████| 112/112 [06:12<00:00,  3.32s/it, loss=65.1566]
Epoch 68/110 Validation: 100%|██████████| 14/14 [00:43<00:00,  3.08s/it, loss=99.8453] 


Epoch 68/110, Train Loss: 3.8212, Train mIoU: 0.8908, Validation Loss: 7.9622, Validation mIoU: 0.7271


Epoch 69/110 Training: 100%|██████████| 112/112 [06:14<00:00,  3.34s/it, loss=54.4329]
Epoch 69/110 Validation: 100%|██████████| 14/14 [00:40<00:00,  2.92s/it, loss=133.5040]


Epoch 69/110, Train Loss: 3.8773, Train mIoU: 0.8895, Validation Loss: 8.2349, Validation mIoU: 0.7798


Epoch 70/110 Training: 100%|██████████| 112/112 [06:14<00:00,  3.34s/it, loss=69.4531]
Epoch 70/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  2.96s/it, loss=118.3763]


Epoch 70/110, Train Loss: 3.8165, Train mIoU: 0.8715, Validation Loss: 8.0230, Validation mIoU: 0.6956


Epoch 71/110 Training: 100%|██████████| 112/112 [06:23<00:00,  3.42s/it, loss=63.8733]
Epoch 71/110 Validation: 100%|██████████| 14/14 [00:48<00:00,  3.49s/it, loss=108.9589]


Epoch 71/110, Train Loss: 3.7039, Train mIoU: 0.8919, Validation Loss: 8.0240, Validation mIoU: 0.7391


Epoch 72/110 Training: 100%|██████████| 112/112 [06:14<00:00,  3.35s/it, loss=56.2297]
Epoch 72/110 Validation: 100%|██████████| 14/14 [00:35<00:00,  2.53s/it, loss=97.7751] 


Epoch 72/110, Train Loss: 3.6425, Train mIoU: 0.8924, Validation Loss: 8.1175, Validation mIoU: 0.7382


Epoch 73/110 Training: 100%|██████████| 112/112 [06:16<00:00,  3.36s/it, loss=56.9931]
Epoch 73/110 Validation: 100%|██████████| 14/14 [00:42<00:00,  3.04s/it, loss=165.0127]


Epoch 73/110, Train Loss: 3.5973, Train mIoU: 0.8924, Validation Loss: 8.4900, Validation mIoU: 0.7424


Epoch 74/110 Training: 100%|██████████| 112/112 [06:24<00:00,  3.43s/it, loss=61.7518]
Epoch 74/110 Validation: 100%|██████████| 14/14 [00:39<00:00,  2.80s/it, loss=105.7981]


Epoch 74/110, Train Loss: 3.5302, Train mIoU: 0.8919, Validation Loss: 8.5149, Validation mIoU: 0.7573


Epoch 75/110 Training: 100%|██████████| 112/112 [05:59<00:00,  3.21s/it, loss=80.6331] 
Epoch 75/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  2.98s/it, loss=133.6363]


Epoch 75/110, Train Loss: 3.5555, Train mIoU: 0.8908, Validation Loss: 8.3698, Validation mIoU: 0.8183


Epoch 76/110 Training: 100%|██████████| 112/112 [06:13<00:00,  3.34s/it, loss=106.2054]
Epoch 76/110 Validation: 100%|██████████| 14/14 [00:38<00:00,  2.78s/it, loss=160.1930]


Epoch 76/110, Train Loss: 5.2057, Train mIoU: 0.8300, Validation Loss: 10.6319, Validation mIoU: 0.6630


Epoch 77/110 Training: 100%|██████████| 112/112 [06:07<00:00,  3.28s/it, loss=62.4755]
Epoch 77/110 Validation: 100%|██████████| 14/14 [00:39<00:00,  2.79s/it, loss=124.6865]


Epoch 77/110, Train Loss: 4.9726, Train mIoU: 0.8506, Validation Loss: 8.3745, Validation mIoU: 0.7799


Epoch 78/110 Training: 100%|██████████| 112/112 [06:13<00:00,  3.34s/it, loss=94.2802]
Epoch 78/110 Validation: 100%|██████████| 14/14 [00:43<00:00,  3.11s/it, loss=108.3571]


Epoch 78/110, Train Loss: 3.9254, Train mIoU: 0.8825, Validation Loss: 8.1611, Validation mIoU: 0.8055


Epoch 79/110 Training: 100%|██████████| 112/112 [06:14<00:00,  3.34s/it, loss=54.4780]
Epoch 79/110 Validation: 100%|██████████| 14/14 [00:40<00:00,  2.86s/it, loss=142.8563]


Epoch 79/110, Train Loss: 3.9820, Train mIoU: 0.8802, Validation Loss: 8.3835, Validation mIoU: 0.7805


Epoch 80/110 Training: 100%|██████████| 112/112 [06:18<00:00,  3.38s/it, loss=70.4817]
Epoch 80/110 Validation: 100%|██████████| 14/14 [00:43<00:00,  3.10s/it, loss=126.0212]


Epoch 80/110, Train Loss: 4.0326, Train mIoU: 0.8849, Validation Loss: 8.1446, Validation mIoU: 0.7753


Epoch 81/110 Training: 100%|██████████| 112/112 [06:15<00:00,  3.35s/it, loss=57.3315]
Epoch 81/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  2.95s/it, loss=117.1318]


Epoch 81/110, Train Loss: 3.6586, Train mIoU: 0.8890, Validation Loss: 8.0457, Validation mIoU: 0.7760


Epoch 82/110 Training: 100%|██████████| 112/112 [06:21<00:00,  3.40s/it, loss=53.9091]
Epoch 82/110 Validation: 100%|██████████| 14/14 [00:43<00:00,  3.12s/it, loss=118.0915]


Epoch 82/110, Train Loss: 3.7974, Train mIoU: 0.8876, Validation Loss: 8.6246, Validation mIoU: 0.7782


Epoch 83/110 Training: 100%|██████████| 112/112 [06:10<00:00,  3.31s/it, loss=54.1264]
Epoch 83/110 Validation: 100%|██████████| 14/14 [00:37<00:00,  2.69s/it, loss=171.2393]


Epoch 83/110, Train Loss: 3.5809, Train mIoU: 0.8903, Validation Loss: 8.2591, Validation mIoU: 0.7418


Epoch 84/110 Training: 100%|██████████| 112/112 [06:10<00:00,  3.31s/it, loss=53.5306]
Epoch 84/110 Validation: 100%|██████████| 14/14 [00:40<00:00,  2.88s/it, loss=118.7275]


Epoch 84/110, Train Loss: 3.4214, Train mIoU: 0.8936, Validation Loss: 8.0625, Validation mIoU: 0.7751


Epoch 85/110 Training: 100%|██████████| 112/112 [06:15<00:00,  3.35s/it, loss=59.1661]
Epoch 85/110 Validation: 100%|██████████| 14/14 [00:42<00:00,  3.06s/it, loss=88.3193] 


Epoch 85/110, Train Loss: 3.3760, Train mIoU: 0.8939, Validation Loss: 8.4784, Validation mIoU: 0.7499


Epoch 86/110 Training: 100%|██████████| 112/112 [06:11<00:00,  3.32s/it, loss=52.4387]
Epoch 86/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  2.95s/it, loss=124.3296]


Epoch 86/110, Train Loss: 3.4562, Train mIoU: 0.8906, Validation Loss: 8.2824, Validation mIoU: 0.7147


Epoch 87/110 Training: 100%|██████████| 112/112 [06:01<00:00,  3.23s/it, loss=51.7282]
Epoch 87/110 Validation: 100%|██████████| 14/14 [00:36<00:00,  2.63s/it, loss=114.1225]


Epoch 87/110, Train Loss: 4.1122, Train mIoU: 0.8570, Validation Loss: 8.3840, Validation mIoU: 0.7817


Epoch 88/110 Training: 100%|██████████| 112/112 [06:16<00:00,  3.36s/it, loss=67.0899]
Epoch 88/110 Validation: 100%|██████████| 14/14 [00:40<00:00,  2.86s/it, loss=137.6998]


Epoch 88/110, Train Loss: 3.5305, Train mIoU: 0.8920, Validation Loss: 8.7052, Validation mIoU: 0.7417


Epoch 89/110 Training: 100%|██████████| 112/112 [06:05<00:00,  3.26s/it, loss=48.3001]
Epoch 89/110 Validation: 100%|██████████| 14/14 [00:39<00:00,  2.83s/it, loss=136.2785]


Epoch 89/110, Train Loss: 3.5370, Train mIoU: 0.8883, Validation Loss: 7.9023, Validation mIoU: 0.7197


Epoch 90/110 Training: 100%|██████████| 112/112 [06:01<00:00,  3.23s/it, loss=52.5167]
Epoch 90/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  2.98s/it, loss=123.4686]


Epoch 90/110, Train Loss: 3.3700, Train mIoU: 0.8929, Validation Loss: 8.6390, Validation mIoU: 0.7443


Epoch 91/110 Training: 100%|██████████| 112/112 [06:07<00:00,  3.29s/it, loss=53.8160]
Epoch 91/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  2.98s/it, loss=148.2362]


Epoch 91/110, Train Loss: 3.5438, Train mIoU: 0.8623, Validation Loss: 7.9256, Validation mIoU: 0.7569


Epoch 92/110 Training: 100%|██████████| 112/112 [05:58<00:00,  3.20s/it, loss=55.1922]
Epoch 92/110 Validation: 100%|██████████| 14/14 [00:39<00:00,  2.80s/it, loss=125.8430]


Epoch 92/110, Train Loss: 3.4041, Train mIoU: 0.8908, Validation Loss: 8.3716, Validation mIoU: 0.7486


Epoch 93/110 Training: 100%|██████████| 112/112 [05:59<00:00,  3.21s/it, loss=53.0960]
Epoch 93/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  2.96s/it, loss=174.0335]


Epoch 93/110, Train Loss: 3.6003, Train mIoU: 0.8845, Validation Loss: 8.0612, Validation mIoU: 0.7260


Epoch 94/110 Training: 100%|██████████| 112/112 [06:02<00:00,  3.24s/it, loss=53.1766]
Epoch 94/110 Validation: 100%|██████████| 14/14 [00:42<00:00,  3.03s/it, loss=97.6434] 


Epoch 94/110, Train Loss: 3.2547, Train mIoU: 0.8924, Validation Loss: 7.9374, Validation mIoU: 0.7427


Epoch 95/110 Training: 100%|██████████| 112/112 [06:09<00:00,  3.30s/it, loss=76.6273]
Epoch 95/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  2.98s/it, loss=177.8506]


Epoch 95/110, Train Loss: 3.3622, Train mIoU: 0.8833, Validation Loss: 8.4240, Validation mIoU: 0.7834


Epoch 96/110 Training: 100%|██████████| 112/112 [06:06<00:00,  3.27s/it, loss=46.6619]
Epoch 96/110 Validation: 100%|██████████| 14/14 [00:44<00:00,  3.15s/it, loss=147.7455]


Epoch 96/110, Train Loss: 3.3115, Train mIoU: 0.8823, Validation Loss: 8.0758, Validation mIoU: 0.7915


Epoch 97/110 Training: 100%|██████████| 112/112 [06:02<00:00,  3.24s/it, loss=50.5556]
Epoch 97/110 Validation: 100%|██████████| 14/14 [00:36<00:00,  2.62s/it, loss=113.9885]


Epoch 97/110, Train Loss: 3.3966, Train mIoU: 0.8872, Validation Loss: 7.8672, Validation mIoU: 0.8101


Epoch 98/110 Training: 100%|██████████| 112/112 [06:07<00:00,  3.28s/it, loss=46.5159]
Epoch 98/110 Validation: 100%|██████████| 14/14 [00:42<00:00,  3.06s/it, loss=110.4448]


Epoch 98/110, Train Loss: 3.4330, Train mIoU: 0.8896, Validation Loss: 7.9986, Validation mIoU: 0.7775


Epoch 99/110 Training: 100%|██████████| 112/112 [06:02<00:00,  3.24s/it, loss=51.7440]
Epoch 99/110 Validation: 100%|██████████| 14/14 [00:31<00:00,  2.25s/it, loss=140.0768]


Epoch 99/110, Train Loss: 3.6676, Train mIoU: 0.8883, Validation Loss: 7.8084, Validation mIoU: 0.7952


Epoch 100/110 Training: 100%|██████████| 112/112 [06:03<00:00,  3.25s/it, loss=61.2571]
Epoch 100/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  2.98s/it, loss=162.0311]


Epoch 100/110, Train Loss: 3.2296, Train mIoU: 0.8943, Validation Loss: 8.3263, Validation mIoU: 0.7775


Epoch 101/110 Training: 100%|██████████| 112/112 [06:10<00:00,  3.31s/it, loss=63.0106]
Epoch 101/110 Validation: 100%|██████████| 14/14 [00:39<00:00,  2.85s/it, loss=121.1318]


Epoch 101/110, Train Loss: 3.1669, Train mIoU: 0.8949, Validation Loss: 8.4986, Validation mIoU: 0.7572


Epoch 102/110 Training: 100%|██████████| 112/112 [05:53<00:00,  3.16s/it, loss=51.8454]
Epoch 102/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  2.98s/it, loss=105.6758]


Epoch 102/110, Train Loss: 3.1536, Train mIoU: 0.8951, Validation Loss: 7.9503, Validation mIoU: 0.8129


Epoch 103/110 Training: 100%|██████████| 112/112 [06:02<00:00,  3.24s/it, loss=43.6428]
Epoch 103/110 Validation: 100%|██████████| 14/14 [00:39<00:00,  2.85s/it, loss=117.4209]


Epoch 103/110, Train Loss: 3.0725, Train mIoU: 0.8946, Validation Loss: 8.2998, Validation mIoU: 0.7804


Epoch 104/110 Training: 100%|██████████| 112/112 [06:02<00:00,  3.23s/it, loss=44.7161]
Epoch 104/110 Validation: 100%|██████████| 14/14 [00:38<00:00,  2.75s/it, loss=148.5388]


Epoch 104/110, Train Loss: 3.0775, Train mIoU: 0.8955, Validation Loss: 8.7601, Validation mIoU: 0.7790


Epoch 105/110 Training: 100%|██████████| 112/112 [06:08<00:00,  3.29s/it, loss=45.9061]
Epoch 105/110 Validation: 100%|██████████| 14/14 [00:43<00:00,  3.12s/it, loss=95.9401] 


Epoch 105/110, Train Loss: 3.0534, Train mIoU: 0.8956, Validation Loss: 8.2323, Validation mIoU: 0.7695


Epoch 106/110 Training: 100%|██████████| 112/112 [06:09<00:00,  3.30s/it, loss=46.2696]
Epoch 106/110 Validation: 100%|██████████| 14/14 [00:40<00:00,  2.87s/it, loss=124.8912]


Epoch 106/110, Train Loss: 3.0542, Train mIoU: 0.8960, Validation Loss: 8.2351, Validation mIoU: 0.7475


Epoch 107/110 Training: 100%|██████████| 112/112 [06:09<00:00,  3.30s/it, loss=49.5898]
Epoch 107/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  2.96s/it, loss=123.9954]


Epoch 107/110, Train Loss: 3.1767, Train mIoU: 0.8908, Validation Loss: 8.6936, Validation mIoU: 0.7602


Epoch 108/110 Training: 100%|██████████| 112/112 [05:53<00:00,  3.16s/it, loss=55.0713]
Epoch 108/110 Validation: 100%|██████████| 14/14 [00:36<00:00,  2.62s/it, loss=124.2938]


Epoch 108/110, Train Loss: 3.5431, Train mIoU: 0.8806, Validation Loss: 8.6622, Validation mIoU: 0.7661


Epoch 109/110 Training: 100%|██████████| 112/112 [05:53<00:00,  3.16s/it, loss=44.6753]
Epoch 109/110 Validation: 100%|██████████| 14/14 [00:37<00:00,  2.68s/it, loss=97.6525] 


Epoch 109/110, Train Loss: 3.1937, Train mIoU: 0.8936, Validation Loss: 8.1484, Validation mIoU: 0.7677


Epoch 110/110 Training: 100%|██████████| 112/112 [05:48<00:00,  3.11s/it, loss=52.4550]
Epoch 110/110 Validation: 100%|██████████| 14/14 [00:41<00:00,  3.00s/it, loss=124.5382]


Epoch 110/110, Train Loss: 3.1319, Train mIoU: 0.8934, Validation Loss: 8.2447, Validation mIoU: 0.7220


## Test results on A

In [19]:
# Load best model and evaluate on test
model = Mask2FormerForUniversalSegmentation.from_pretrained(f"{best_model_dir}{CURR_TASK}/").to(device)

In [20]:
model.eval()
test_running_loss = 0
test_loader = tqdm(dataloaders[CURR_TASK]["test"], desc="Test loop")
with torch.no_grad():
    for batch in test_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss
        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        test_running_loss += current_loss
        test_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_A
        )
        metric.add_batch(references=masks, predictions=pred_maps)
        
# After compute the batches that were added are deleted
test_metrics_A = metric.compute(
    num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
)
mean_test_iou = test_metrics_A["mean_iou"]
final_test_loss = test_running_loss / len(dataloaders[CURR_TASK]["test"].dataset)
wandb.log({
    f"Loss/test_{CURR_TASK}": final_test_loss,
    f"mIoU/test_{CURR_TASK}": mean_test_iou
})
print(f"Test Loss: {final_test_loss:.4f}, Test mIoU: {mean_test_iou:.4f}")
wandb.finish()

Test loop: 100%|██████████| 73/73 [02:23<00:00,  1.97s/it, loss=158.3683]


Test Loss: 15.9287, Test mIoU: 0.7058


# Now train on B and forget A

In [8]:
# Training
NUM_EPOCHS = 200
LEARNING_RATE = 1e-4
LR_MULTIPLIER = 0.1
BACKBONE_LR = LEARNING_RATE * LR_MULTIPLIER
WEIGHT_DECAY = 0.05
encoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.encoder")
]
decoder_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.pixel_level_module.decoder")
]
transformer_params = [
    param
    for name, param in model.named_parameters()
    if name.startswith("model.transformer_module")
]
optimizer = optim.AdamW(
    [
        {"params": encoder_params, "lr": BACKBONE_LR},
        {"params": decoder_params},
        {"params": transformer_params},
    ],
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

scheduler = optim.lr_scheduler.PolynomialLR(
    optimizer, total_iters=NUM_EPOCHS, power=0.9
)

In [9]:
# WandB for team usage !!!!

wandb.login() # use this one if a different person is going to run the notebook
#wandb.login(relogin=False) # if the same person in the last run is going to run the notebook again

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mge85ket[0m ([33mcontinual-learning-tum[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [10]:
wandb.init(
    project="M2F_original",
    config={
        "learning_rate": LEARNING_RATE,
        "learning_rate_multiplier": LR_MULTIPLIER,
        "backbone_learning_rate": BACKBONE_LR,
        "learning_rate_scheduler": scheduler.__class__.__name__,
        "optimizer": optimizer.__class__.__name__,
        "backbone": SWIN_BACKBONE,
        "m2f_preprocessor": m2f_preprocessor_B.__dict__,
        "m2f_model_config": model.config
    },
    name="M2F-Swin-Tiny-Naive-Forgetting",
    notes="M2F with tiny Swin backbone pretrained on ImageNet-1K. \
        Scenario: Pretrained on A, Train on B, Test forgetting on A"
)
# Tensorboard logging
writer = SummaryWriter(log_dir=out_dir + "runs")
# Model checkpointing
model_name = "m2f_swin_backbone_naive_forgetting"
model_dir = out_dir + "models/"
if not os.path.exists(model_dir):
    print("Store weights in: ", model_dir)
    os.makedirs(model_dir)

best_model_dir = model_dir + f"{model_name}/best_model/"
if not os.path.exists(best_model_dir):
    print("Store best model weights in: ", best_model_dir)
    os.makedirs(best_model_dir)
final_model_dir = model_dir + f"{model_name}/final_model/"
if not os.path.exists(final_model_dir):
    print("Store final model weights in: ", final_model_dir)
    os.makedirs(final_model_dir)

In [None]:
# Save the preprocessor
m2f_preprocessor_B.save_pretrained(model_dir + model_name)

In [22]:
# To avoid making stupid errors
CURR_TASK = "B"

# For storing the model
best_val_metric = -np.inf

# Move model to device
model.to(device)

for epoch in range(NUM_EPOCHS):
    model.train()
    train_running_loss = 0.0
    val_running_loss = 0.0

    # Set up tqdm for the training loop
    train_loader = tqdm(
        dataloaders[CURR_TASK]["train"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Training"
    )

    for batch in train_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]

        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss

        # Compute gradient and perform step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        train_running_loss += current_loss
        train_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_B
        )
        metric.add_batch(references=masks, predictions=pred_maps)

    # After compute the batches that were added are deleted
    mean_train_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
    )["mean_iou"]

    # Validation phase
    model.eval()
    val_loader = tqdm(
        dataloaders[CURR_TASK]["val"], desc=f"Epoch {epoch + 1}/{NUM_EPOCHS} Validation"
    )
    with torch.no_grad():
        for batch in val_loader:
            # Move everything to the device
            batch["pixel_values"] = batch["pixel_values"].to(device)
            batch["pixel_mask"] = batch["pixel_mask"].to(device)
            batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
            batch["class_labels"] = [
                entry.to(device) for entry in batch["class_labels"]
            ]
            # Compute output and loss
            outputs = model(**batch)

            loss = outputs.loss
            # Record losses
            current_loss = loss.item() * batch["pixel_values"].size(0)
            val_running_loss += current_loss
            val_loader.set_postfix(loss=f"{current_loss:.4f}")

            # Extract and compute metrics
            pred_maps, masks = m2f_extract_pred_maps_and_masks(
                batch, outputs, m2f_preprocessor_B
            )
            metric.add_batch(references=masks, predictions=pred_maps)

    # After compute the batches that were added are deleted
    mean_val_iou = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
    )["mean_iou"]

    epoch_train_loss = train_running_loss / len(dataloaders[CURR_TASK]["train"].dataset)
    epoch_val_loss = val_running_loss / len(dataloaders[CURR_TASK]["val"].dataset)

    writer.add_scalar(f"Loss/train_{model_name}_{CURR_TASK}", epoch_train_loss, epoch + 1)
    writer.add_scalar(f"Loss/val_{model_name}_{CURR_TASK}", epoch_val_loss, epoch + 1)
    writer.add_scalar(f"mIoU/train_{model_name}_{CURR_TASK}", mean_train_iou, epoch + 1)
    writer.add_scalar(f"mIoU/val_{model_name}_{CURR_TASK}", mean_val_iou, epoch + 1)

    wandb.log({
        f"Loss/train_{CURR_TASK}": epoch_train_loss,
        f"Loss/val_{CURR_TASK}": epoch_val_loss,
        f"mIoU/train_{CURR_TASK}": mean_train_iou,
        f"mIoU/val_{CURR_TASK}": mean_val_iou
    })

    if mean_val_iou > best_val_metric:
        best_val_metric = mean_val_iou
        model.save_pretrained(f"{best_model_dir}{CURR_TASK}/")

    tqdm.write(
        f"Epoch {epoch + 1}/{NUM_EPOCHS}, Train Loss: {epoch_train_loss:.4f}, Train mIoU: {mean_train_iou:.4f}, Validation Loss: {epoch_val_loss:.4f}, Validation mIoU: {mean_val_iou:.4f}"
    )

Epoch 1/5 Training: 100%|██████████| 225/225 [10:28<00:00,  2.79s/it, loss=140.3550]
Epoch 1/5 Validation: 100%|██████████| 28/28 [00:53<00:00,  1.92s/it, loss=147.4061]


Epoch 1/5, Train Loss: 21.2275, Train mIoU: 0.4917, Validation Loss: 17.3837, Validation mIoU: 0.6450


Epoch 2/5 Training: 100%|██████████| 225/225 [09:51<00:00,  2.63s/it, loss=88.4616] 
Epoch 2/5 Validation: 100%|██████████| 28/28 [00:51<00:00,  1.84s/it, loss=90.7679] 


Epoch 2/5, Train Loss: 16.5085, Train mIoU: 0.6278, Validation Loss: 11.5713, Validation mIoU: 0.7547


Epoch 3/5 Training: 100%|██████████| 225/225 [11:22<00:00,  3.03s/it, loss=71.2189] 
Epoch 3/5 Validation: 100%|██████████| 28/28 [00:49<00:00,  1.78s/it, loss=75.3303] 


Epoch 3/5, Train Loss: 11.1197, Train mIoU: 0.7375, Validation Loss: 9.1890, Validation mIoU: 0.7599


Epoch 4/5 Training: 100%|██████████| 225/225 [12:14<00:00,  3.26s/it, loss=74.6497] 
Epoch 4/5 Validation: 100%|██████████| 28/28 [00:49<00:00,  1.77s/it, loss=58.7313] 


Epoch 4/5, Train Loss: 9.1472, Train mIoU: 0.7742, Validation Loss: 8.5610, Validation mIoU: 0.7252


Epoch 5/5 Training: 100%|██████████| 225/225 [11:06<00:00,  2.96s/it, loss=80.4140] 
Epoch 5/5 Validation: 100%|██████████| 28/28 [00:51<00:00,  1.85s/it, loss=83.1679] 


Epoch 5/5, Train Loss: 8.9787, Train mIoU: 0.7828, Validation Loss: 10.4601, Validation mIoU: 0.7588


## Test results on B first

In [14]:
# Load best model and evaluate on test
model = Mask2FormerForUniversalSegmentation.from_pretrained(f"{best_model_dir}{CURR_TASK}/").to(device)

In [15]:
model.eval()
test_running_loss = 0
test_loader = tqdm(dataloaders[CURR_TASK]["test"], desc="Test loop")
with torch.no_grad():
    for batch in test_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss
        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        test_running_loss += current_loss
        test_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_B
        )
        metric.add_batch(references=masks, predictions=pred_maps)

# After compute the batches that were added are deleted
test_metrics_B = metric.compute(
    num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
)
mean_test_iou = test_metrics_B["mean_iou"]
final_test_loss = test_running_loss / len(dataloaders[CURR_TASK]["test"].dataset)
wandb.log({
    f"Loss/test_{CURR_TASK}": final_test_loss,
    f"mIoU/test_{CURR_TASK}": mean_test_iou
})
print(f"Test Loss: {final_test_loss:.4f}, Test mIoU: {mean_test_iou:.4f}")

Test loop: 100%|██████████| 14/14 [00:43<00:00,  3.12s/it, loss=129.5613]


Test Loss: 8.7328, Test mIoU: 0.7895


In [16]:
wandb.finish()

0,1
Loss/test_B,▁
Loss/train_B,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Loss/val_B,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mIoU/test_B,▁
mIoU/train_B,▁▃▅▆▇▇▇████████████████████▇████████████
mIoU/val_B,▁▄▅▇█▇▇█████▇▇█▇▇▇█▇▇▇▇▇▇▇▇▇████▇▇█████▇

0,1
Loss/test_B,8.73283
Loss/train_B,3.13194
Loss/val_B,8.24469
mIoU/test_B,0.78951
mIoU/train_B,0.89339
mIoU/val_B,0.72203


## Test results on A after training on B

In [25]:
# To avoid making stupid errors
CURR_TASK = "A"

model.eval()
test_running_loss = 0
test_loader = tqdm(dataloaders[CURR_TASK]["test"], desc="Test loop")
with torch.no_grad():
    for batch in test_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]
        # Compute output and loss
        outputs = model(**batch)

        loss = outputs.loss
        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        test_running_loss += current_loss
        test_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Extract and compute metrics
        pred_maps, masks = m2f_extract_pred_maps_and_masks(
            batch, outputs, m2f_preprocessor_A
        )
        metric.add_batch(references=masks, predictions=pred_maps)
        
# After compute the batches that were added are deleted
test_metrics_forgetting_A = metric.compute(
    num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
)
mean_test_iou = test_metrics_forgetting_A["mean_iou"]
final_test_loss = test_running_loss / len(dataloaders[CURR_TASK]["test"].dataset)
wandb.log({
    f"Loss/test_naive_forgetting_{CURR_TASK}": final_test_loss,
    f"mIoU/test_naive_forgetting_{CURR_TASK}": mean_test_iou
})
print(f"Test Loss: {final_test_loss:.4f}, Test mIoU: {mean_test_iou:.4f}")


Test loop: 100%|██████████| 73/73 [02:05<00:00,  1.73s/it, loss=279.5755]


Test Loss: 38.9396, Test mIoU: 0.4669


In [None]:
# Collect overall mIoU
mIoU_A = test_metrics_A["mean_iou"]
mIoU_forgetting_A = test_metrics_forgetting_A["mean_iou"]
mIoU_B = test_metrics_B["mean_iou"]

# Collect per category mIoU
per_category_mIoU_A = np.array(test_metrics_A["per_category_iou"])
per_category_mIoU_forgetting_A = np.array(test_metrics_forgetting_A["per_category_iou"])
per_category_mIoU_B = np.array(test_metrics_forgetting_A["per_category_iou"])

# Average learning accuracies (mIoUs)
avg_learning_acc = (mIoU_A + mIoU_B) / 2
per_category_avg_learning_acc = (per_category_mIoU_A + per_category_mIoU_B) / 2

# Forgetting
total_forgetting = mIoU_A - mIoU_forgetting_A
per_category_forgetting = (per_category_mIoU_A - per_category_mIoU_forgetting_A)

# Export evaluation metrics to WandB
wandb.log({
    "eval/avg_learning_acc": avg_learning_acc,
    "eval/per_category_avg_learning_acc": per_category_avg_learning_acc,
    "eval/total_forgetting": total_forgetting,
    "eval/per_category_forgetting": per_category_forgetting
})
print("**** Overall mIoU ****")
print(f"mIoU on task A: {mIoU_A}")
print(f"mIoU on task B: {mIoU_B}")
print(f"mIoU on task A after training on B: {mIoU_forgetting_A}")

print("\n**** Per category mIoU ****")
print(f"Per category mIoU on task A: {per_category_mIoU_A}")
print(f"Per category mIoU on task B: {per_category_mIoU_B}")
print(f"Per category mIoU on task A after training on B: {per_category_mIoU_forgetting_A}")

print("\n**** Average learning accuracies ****")
print(f"Average learning acc.: {avg_learning_acc}")
print(f"Per category Average learning acc.: {per_category_avg_learning_acc}")

print("\n**** Forgetting ****")
print(f"Total forgetting: {total_forgetting}")
print(f"Per category forgetting: {per_category_forgetting}")
wandb.finish()