In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn.functional as F
import os
from utils.common import (
    m2f_dataset_collate,
    m2f_extract_pred_maps_and_masks,
    set_seed,
    CADIS_PIXEL_MEAN,
    CADIS_PIXEL_STD,
    CAT1K_PIXEL_MEAN,
    CAT1K_PIXEL_STD,
)
from utils.kd import compute_kd_loss
from utils.dataset_utils import (
    get_cadisv2_dataset,
    get_cataract1k_dataset,
    ZEISS_CATEGORIES,
)
from utils.medical_datasets import Mask2FormerDataset
from transformers import (
    Mask2FormerForUniversalSegmentation,
    SwinModel,
    SwinConfig,
    Mask2FormerConfig,
    AutoImageProcessor,
    Mask2FormerImageProcessor,
)
from torch.utils.data import DataLoader
import evaluate
import torch.optim as optim
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import wandb
from copy import deepcopy
import shutil
from utils.wandb_utils import log_table_of_images

In [3]:
set_seed(42) # seed everything

Random seed set as 42


In [4]:
NUM_CLASSES = (
    len(ZEISS_CATEGORIES) - 3 + 1
)  # Remove class incremental and add background!!!
SWIN_BACKBONE = "microsoft/swin-tiny-patch4-window7-224"  # "microsoft/swin-large-patch4-window12-384"

# Download pretrained swin model
swin_model = SwinModel.from_pretrained(
    SWIN_BACKBONE, out_features=["stage1", "stage2", "stage3", "stage4"]
)
swin_config = SwinConfig.from_pretrained(
    SWIN_BACKBONE, out_features=["stage1", "stage2", "stage3", "stage4"]
)

# Create Mask2Former configuration based on Swin's configuration
mask2former_config = Mask2FormerConfig(
    backbone_config=swin_config, num_labels=NUM_CLASSES  # , ignore_value=BG_VALUE
)

# Create the Mask2Former model with this configuration
student = Mask2FormerForUniversalSegmentation(mask2former_config)

# Reuse pretrained parameters
for swin_param, m2f_param in zip(
    swin_model.named_parameters(),
    student.model.pixel_level_module.encoder.named_parameters(),
):
    m2f_param_name = f"model.pixel_level_module.encoder.{m2f_param[0]}"

    if swin_param[0] == m2f_param[0]:
        student.state_dict()[m2f_param_name].copy_(swin_param[1])
        continue

    print(f"Not Matched: {m2f_param[0]} != {swin_param[0]}")



Not Matched: hidden_states_norms.stage1.weight != layernorm.weight
Not Matched: hidden_states_norms.stage1.bias != layernorm.bias


In [5]:
# Helper function to load datasets
def load_dataset(dataset_getter, data_path, domain_incremental):
    return dataset_getter(data_path, domain_incremental=domain_incremental)


SWIN_BACKBONE = "microsoft/swin-tiny-patch4-window7-224"  # "microsoft/swin-large-patch4-window12-384"


# Helper function to create dataloaders for a dataset
def create_dataloaders(
    dataset, batch_size, shuffle, num_workers, drop_last, pin_memory, collate_fn
):
    return {
        "train": DataLoader(
            dataset["train"],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        ),
        "val": DataLoader(
            dataset["val"],
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            drop_last=drop_last,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        ),
        "test": DataLoader(
            dataset["test"],
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            drop_last=False,
            pin_memory=pin_memory,
            collate_fn=collate_fn,
        ),
    }


# Load datasets
datasets = {
    "A": load_dataset(get_cadisv2_dataset, "../../storage/data/CaDISv2", True),
    "B": load_dataset(get_cataract1k_dataset, "../../storage/data/cataract-1k", True),
}


pixel_mean_A = np.array(CADIS_PIXEL_MEAN)
pixel_std_A = np.array(CADIS_PIXEL_STD)
pixel_mean_B = np.array(CAT1K_PIXEL_MEAN)
pixel_std_B = np.array(CAT1K_PIXEL_STD)


# Define preprocessor
swin_processor = AutoImageProcessor.from_pretrained(SWIN_BACKBONE)
m2f_preprocessor_A = Mask2FormerImageProcessor(
    reduce_labels=False,
    ignore_index=255,
    do_resize=False,
    do_rescale=False,
    do_normalize=True,
    image_std=pixel_std_A,
    image_mean=pixel_mean_A,
)

m2f_preprocessor_B = Mask2FormerImageProcessor(
    reduce_labels=False,
    ignore_index=255,
    do_resize=False,
    do_rescale=False,
    do_normalize=True,
    image_std=pixel_std_B,
    image_mean=pixel_mean_B,
)

# Create Mask2Former Datasets
m2f_datasets = {
    "A": {
        "train": Mask2FormerDataset(datasets["A"][0], m2f_preprocessor_A),
        "val": Mask2FormerDataset(datasets["A"][1], m2f_preprocessor_A),
        "test": Mask2FormerDataset(datasets["A"][2], m2f_preprocessor_A),
    },
    "B": {
        "train": Mask2FormerDataset(datasets["B"][0], m2f_preprocessor_B),
        "val": Mask2FormerDataset(datasets["B"][1], m2f_preprocessor_B),
        "test": Mask2FormerDataset(datasets["B"][2], m2f_preprocessor_B),
    },
}

# DataLoader parameters
N_WORKERS = 4
BATCH_SIZE = 16
SHUFFLE = True
DROP_LAST = True

dataloader_params = {
    "batch_size": BATCH_SIZE,
    "shuffle": SHUFFLE,
    "num_workers": N_WORKERS,
    "drop_last": DROP_LAST,
    "pin_memory": True,
    "collate_fn": m2f_dataset_collate,
}

# Create DataLoaders
dataloaders = {
    key: create_dataloaders(m2f_datasets[key], **dataloader_params)
    for key in m2f_datasets
}

print(dataloaders)



{'A': {'train': <torch.utils.data.dataloader.DataLoader object at 0x3235676e0>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x31fb38bf0>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x323567890>}, 'B': {'train': <torch.utils.data.dataloader.DataLoader object at 0x323567aa0>, 'val': <torch.utils.data.dataloader.DataLoader object at 0x3235679e0>, 'test': <torch.utils.data.dataloader.DataLoader object at 0x323567c80>}}


In [6]:
# Check if CUDA is available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

BG_VALUE_255=255
base_run_name="M2F-Swin-Tiny-Train_Cadis"
new_run_name="M2F-Swin-Tiny-KD-EMA-HyperParameter-Search"
project_name = "M2F_latest"
user_or_team = "continual-learning-tum"

Using device: cpu


In [7]:
# Tensorboard setup
out_dir="outputs/"
if not os.path.exists(out_dir):
    os.makedirs(out_dir)
if not os.path.exists(out_dir+"runs"):
    os.makedirs(out_dir+"runs")
%load_ext tensorboard
%tensorboard --logdir outputs/runs

In [8]:
!CUDA_LAUNCH_BLOCKING=1

In [9]:
# Tensorboard logging
writer = SummaryWriter(log_dir=out_dir + "runs")

# Model checkpointing
model_dir = out_dir + "models/"
if not os.path.exists(model_dir):
    print("Store weights in: ", model_dir)
    os.makedirs(model_dir)

best_model_dir = model_dir + f"{base_run_name}/best_model/"
if not os.path.exists(best_model_dir):
    print("Store best model weights in: ", best_model_dir)
    os.makedirs(best_model_dir)
final_model_dir = model_dir + f"{base_run_name}/final_model/"
if not os.path.exists(final_model_dir):
    print("Store final model weights in: ", final_model_dir)
    os.makedirs(final_model_dir)

## Test results on A

In [10]:
# WandB for team usage !!!!

wandb.login() # use this one if a different person is going to run the notebook
#wandb.login(relogin=False) # if the same person in the last run is going to run the notebook again

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkristiyan-sakalyan[0m ([33mcontinual-learning-tum[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [11]:
# Construct the artifact path
artifact_path = f"{user_or_team}/{project_name}/best_model_{base_run_name}:latest"
# Load from W&B
api = wandb.Api()
artifact = api.artifact(artifact_path)
model_dir = artifact.download()
model_state_dict_path = os.path.join(model_dir, f"best_model_{base_run_name}.pth")
model_state_dict = torch.load(model_state_dict_path, map_location=device)

# Student
student = Mask2FormerForUniversalSegmentation(mask2former_config)
student.load_state_dict(model_state_dict)
student.to(device)

# Teacher
teacher = Mask2FormerForUniversalSegmentation(mask2former_config)
teacher.load_state_dict(model_state_dict)
teacher.to(device)
# Eval mode for teacher
teacher.eval()

[34m[1mwandb[0m: Downloading large artifact best_model_M2F-Swin-Tiny-Train_Cadis:latest, 181.31MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.8


Mask2FormerForUniversalSegmentation(
  (model): Mask2FormerModel(
    (pixel_level_module): Mask2FormerPixelLevelModule(
      (encoder): SwinBackbone(
        (embeddings): SwinEmbeddings(
          (patch_embeddings): SwinPatchEmbeddings(
            (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
          )
          (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (encoder): SwinEncoder(
          (layers): ModuleList(
            (0): SwinStage(
              (blocks): ModuleList(
                (0-1): 2 x SwinLayer(
                  (layernorm_before): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
                  (attention): SwinAttention(
                    (self): SwinSelfAttention(
                      (query): Linear(in_features=96, out_features=96, bias=True)
                      (key): Linear(in_features=96, out_features=96, bias=True)
                      (value

In [12]:
# Training
NUM_EPOCHS = 200
LEARNING_RATE = 1e-4
LR_MULTIPLIER = 0.1
BACKBONE_LR = LEARNING_RATE * LR_MULTIPLIER
WEIGHT_DECAY = 0.05
PATIENCE = 15
metric = evaluate.load("mean_iou")
encoder_params = [
    param
    for name, param in student.named_parameters()
    if name.startswith("model.pixel_level_module.encoder")
]
decoder_params = [
    param
    for name, param in student.named_parameters()
    if name.startswith("model.pixel_level_module.decoder")
]
transformer_params = [
    param
    for name, param in student.named_parameters()
    if name.startswith("model.transformer_module")
]
class_prediction_params = [
    param
    for name, param in student.named_parameters()
    if not name.startswith("model.pixel_level_module.encoder")
    and not name.startswith("model.transformer_module")
    and not name.startswith("model.pixel_level_module.decoder")
]
optimizer = optim.AdamW(
    [
        {"params": encoder_params, "lr": BACKBONE_LR},
        {"params": decoder_params},
        {"params": transformer_params},
        {"params": class_prediction_params},
    ],
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

scheduler = optim.lr_scheduler.PolynomialLR(
    optimizer, total_iters=NUM_EPOCHS, power=0.9
)
# KD
# LAMBDAS = {"q": 5e-4, "c": 1e-3, "m": 0.3, "pod": 0.1}
LAMBDAS = {"q": 1e-4, "c": 5e-2, "m": 3e-2, "pod": 0.8}

# Initialize EMA parameters
EMA_DECAY = 0.999

In [13]:
# Define a function to update the teacher's parameters.
def update_teacher_ema(student_model, teacher_model, decay):
    student_params = dict(student_model.named_parameters())
    teacher_params = dict(teacher_model.named_parameters())

    for name in teacher_params.keys():
        teacher_params[name].data = (
            decay * teacher_params[name].data + (1 - decay) * student_params[name].data
        )


# 1: Define objective/training function
def objective(config) -> float:

    CURR_LAMBDAS = {
        "q": config["q"],
        "c": config["c"],
        "m": config["m"],
        "pod": config["pod"],
    }

    # ================== Model Initialization ==================#

    # Construct the artifact path
    artifact_path = f"{user_or_team}/{project_name}/best_model_{base_run_name}:latest"
    # Load from W&B
    api = wandb.Api()
    artifact = api.artifact(artifact_path)
    model_dir = artifact.download()
    model_state_dict_path = os.path.join(model_dir, f"best_model_{base_run_name}.pth")
    model_state_dict = torch.load(model_state_dict_path, map_location=device)

    # Student
    student = Mask2FormerForUniversalSegmentation(mask2former_config)
    student.load_state_dict(model_state_dict)
    student.to(device)

    # Teacher
    teacher = Mask2FormerForUniversalSegmentation(mask2former_config)
    teacher.load_state_dict(model_state_dict)
    teacher.to(device)

    # Teacher should be in eval mode
    teacher.eval()

    # Student should be trained
    student.train()
    # =====================================================#

    # ================== Train on B ==================#
    # Set up tqdm for the training loop
    train_loader = tqdm(dataloaders["B"]["train"], desc="Training on B")

    for batch in train_loader:
        # Move everything to the device
        batch["pixel_values"] = batch["pixel_values"].to(device)
        batch["pixel_mask"] = batch["pixel_mask"].to(device)
        batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
        batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]

        # Compute output and loss
        student_outputs = student(**batch, output_hidden_states=True)

        # Compute output for teacher model
        with torch.no_grad():
            teacher_outputs = teacher(**batch, output_hidden_states=True)

        kd_loss = compute_kd_loss(
            student_outputs, teacher_outputs, lambdas=CURR_LAMBDAS
        )

        loss = student_outputs.loss + kd_loss

        # Compute gradient and perform step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Record losses
        current_loss = loss.item() * batch["pixel_values"].size(0)
        train_loader.set_postfix(loss=f"{current_loss:.4f}")

        # Update teacher model using EMA
        update_teacher_ema(student, teacher, EMA_DECAY)

    # =====================================================#

    # ================== Evaluate on A ==================#

    student.eval()
    # Evaluate on A first
    test_loader_A = tqdm(dataloaders["A"]["test"], desc="Test loop A")
    with torch.no_grad():
        for batch in test_loader_A:
            # Move everything to the device
            batch["pixel_values"] = batch["pixel_values"].to(device)
            batch["pixel_mask"] = batch["pixel_mask"].to(device)
            batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
            batch["class_labels"] = [
                entry.to(device) for entry in batch["class_labels"]
            ]
            # Compute output and loss
            student_outputs = student(**batch)

            loss = student_outputs.loss
            # Record losses
            current_loss = loss.item() * batch["pixel_values"].size(0)
            test_loader_A.set_postfix(loss=f"{current_loss:.4f}")

            # Extract and compute metrics
            pred_maps, masks = m2f_extract_pred_maps_and_masks(
                batch, student_outputs, m2f_preprocessor_A
            )
            metric.add_batch(references=masks, predictions=pred_maps)

    # After compute the batches that were added are deleted
    test_metrics_A = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
    )
    mean_test_iou_A = test_metrics_A["mean_iou"]

    # =====================================================#

    # ================== Evaluate on B ==================#

    test_loader_B = tqdm(dataloaders["B"]["test"], desc="Test loop B")
    # Evaluate on B now
    with torch.no_grad():
        for batch in test_loader_B:
            # Move everything to the device
            batch["pixel_values"] = batch["pixel_values"].to(device)
            batch["pixel_mask"] = batch["pixel_mask"].to(device)
            batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
            batch["class_labels"] = [
                entry.to(device) for entry in batch["class_labels"]
            ]
            # Compute output and loss
            student_outputs = student(**batch)

            loss = student_outputs.loss
            # Record losses
            current_loss = loss.item() * batch["pixel_values"].size(0)
            test_loader_B.set_postfix(loss=f"{current_loss:.4f}")

            # Extract and compute metrics
            pred_maps, masks = m2f_extract_pred_maps_and_masks(
                batch, student_outputs, m2f_preprocessor_B
            )
            metric.add_batch(references=masks, predictions=pred_maps)
            
    # After compute the batches that were added are deleted
    test_metrics_B_before = metric.compute(
        num_labels=NUM_CLASSES, ignore_index=BG_VALUE_255, reduce_labels=False
    )
    mean_test_iou_B = test_metrics_B_before["mean_iou"]

    # =====================================================#

    return (mean_test_iou_A + mean_test_iou_B) / 2


def main():
    wandb.init(
        project="kd-hyperparam-search",
    )
    score = objective(wandb.config)
    wandb.log({"score": score})


# 2: Define the search space
sweep_configuration = {
    "method": "random",
    "metric": {"goal": "maximize", "name": "score"},
    # We re looking for optimal lambads
    "parameters": {  # LAMBDAS = {"q": 1e-4, "c": 5e-2, "m": 3e-2, "pod": 0.8}
        "q": {"max": 1e-2, "min": 1e-5, "distribution": "uniform"},
        "c": {"max": 1, "min": 1e-3, "distribution": "uniform"},
        "m": {"max": 1, "min": 1e-3, "distribution": "uniform"},
        "pod": {"max": 10, "min": 1e-2, "distribution": "uniform"},
    },
}

# 3: Start the sweep
sweep_id = wandb.sweep(sweep=sweep_configuration, project="kd-hyperparam-search")

wandb.agent(sweep_id, function=main, count=10)

Create sweep with ID: cj6wqfp9
Sweep URL: https://wandb.ai/continual-learning-tum/kd-hyperparam-search/sweeps/cj6wqfp9


[34m[1mwandb[0m: Agent Starting Run: ucka9hgc with config:
[34m[1mwandb[0m: 	c: 0.9052473669369564
[34m[1mwandb[0m: 	m: 0.4414301195560252
[34m[1mwandb[0m: 	pod: 6.388138559352832
[34m[1mwandb[0m: 	q: 0.005509185668688547
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


WANDB run id:  ucka9hgc


[34m[1mwandb[0m: Downloading large artifact best_model_M2F-Swin-Tiny-Train_Cadis:latest, 181.31MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.6
  layer_loss = torch.mean(torch.frobenius_norm(a - b, dim=-1))
Training on B:   0%|          | 0/112 [01:06<?, ?it/s, loss=1094.1665]
Test loop A:   0%|          | 0/37 [00:20<?, ?it/s, loss=479.1368]
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
Test loop B:   0%|          | 0/15 [00:22<?, ?it/s, loss=601.8488]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
score,▁

0,1
score,0.49836


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: jh6jpdjr with config:
[34m[1mwandb[0m: 	c: 0.5406306492112376
[34m[1mwandb[0m: 	m: 0.6230035553379925
[34m[1mwandb[0m: 	pod: 5.612962344021869
[34m[1mwandb[0m: 	q: 0.003265430412253134
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


WANDB run id:  jh6jpdjr


[34m[1mwandb[0m: Downloading large artifact best_model_M2F-Swin-Tiny-Train_Cadis:latest, 181.31MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.8
Training on B:   0%|          | 0/112 [01:15<?, ?it/s, loss=859.1275]
Test loop A:   0%|          | 0/37 [00:30<?, ?it/s, loss=466.6858]
  iou = total_area_intersect / total_area_union
  acc = total_area_intersect / total_area_label
Test loop B:   0%|          | 0/15 [00:23<?, ?it/s, loss=609.0536]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
score,▁

0,1
score,0.5151


[34m[1mwandb[0m: Agent Starting Run: rgvo3nry with config:
[34m[1mwandb[0m: 	c: 0.5755991298633105
[34m[1mwandb[0m: 	m: 0.8532259329480064
[34m[1mwandb[0m: 	pod: 6.444897034245557
[34m[1mwandb[0m: 	q: 0.0021109662113670373
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


WANDB run id:  rgvo3nry


[34m[1mwandb[0m: Downloading large artifact best_model_M2F-Swin-Tiny-Train_Cadis:latest, 181.31MB. 1 files... 
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 0:0:0.9
Training on B:   0%|          | 0/112 [00:00<?, ?it/s][34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.
