In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the local `utils` package) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import os
from utils.common import (
    m2f_dataset_collate,
    m2f_extract_pred_maps_and_masks,
    set_seed,
    pixel_mean_std,
    CADIS_PIXEL_MEAN,
    CADIS_PIXEL_STD,
    CAT1K_PIXEL_MEAN,
    CAT1K_PIXEL_STD
)
from utils.dataset_utils import (
    get_cadisv2_dataset,
    get_cataract1k_dataset,
    ZEISS_CATEGORIES,
)
from utils.medical_datasets import Mask2FormerDataset
from transformers import (
    Mask2FormerForUniversalSegmentation,
    SwinModel,
    SwinConfig,
    Mask2FormerConfig,
    AutoImageProcessor,
    Mask2FormerImageProcessor,
)
from torch.utils.data import DataLoader
import evaluate
import torch.optim as optim
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from dotenv import load_dotenv
import wandb
from sklearn.cluster import KMeans
import umap
from scipy.spatial.distance import cdist
from copy import deepcopy
import shutil
import random
from utils.wandb_utils import log_table_of_images

In [None]:
# Number of labels: all ZEISS categories minus the 3 class-incremental ones,
# plus 1 for the background class.
NUM_CLASSES = len(ZEISS_CATEGORIES) - 3 + 1

# Swin checkpoint used as backbone.
# Larger alternative (disabled): "microsoft/swin-large-patch4-window12-384"
SWIN_BACKBONE = "microsoft/swin-tiny-patch4-window7-224"

# The same four stages are requested from both the config and the weights.
SWIN_OUT_FEATURES = ("stage1", "stage2", "stage3", "stage4")

# Backbone configuration and pretrained backbone weights
swin_config = SwinConfig.from_pretrained(
    SWIN_BACKBONE, out_features=list(SWIN_OUT_FEATURES)
)
swin_model = SwinModel.from_pretrained(
    SWIN_BACKBONE, out_features=list(SWIN_OUT_FEATURES)
)

# Mask2Former configuration built on top of the Swin configuration.
# `ignore_value` is not overridden (a draft `ignore_value=BG_VALUE` is disabled).
mask2former_config = Mask2FormerConfig(
    backbone_config=swin_config,
    num_labels=NUM_CLASSES,
)

# Fresh Mask2Former model instantiated from that configuration
model = Mask2FormerForUniversalSegmentation(mask2former_config)

# Reuse pretrained parameters: walk the Swin parameters and the Mask2Former
# encoder parameters in lockstep and copy a weight across whenever both sides
# carry the same parameter name.
#
# The state dict is built ONCE before the loop (the original rebuilt it on every
# iteration). Its tensors share storage with the live parameters, so the
# in-place copy_ below updates the model itself.
m2f_state_dict = model.state_dict()
for (swin_name, swin_weight), (m2f_name, _) in zip(
    swin_model.named_parameters(),
    model.model.pixel_level_module.encoder.named_parameters(),
):
    m2f_param_name = f"model.pixel_level_module.encoder.{m2f_name}"

    if swin_name == m2f_name:
        m2f_state_dict[m2f_param_name].copy_(swin_weight)
        continue

    # The positional pairing did not line up for this entry; report and move on.
    print(f"Not Matched: {m2f_name} != {swin_name}")

In [None]:
# Helper function to load datasets
def load_dataset(dataset_getter, data_path, domain_incremental):
    """Call ``dataset_getter`` for ``data_path`` and hand back its result.

    Args:
        dataset_getter: Callable taking the data path positionally and a
            ``domain_incremental`` keyword argument.
        data_path: Location passed through to the getter unchanged.
        domain_incremental: Forwarded to the getter as a keyword argument.

    Returns:
        Whatever ``dataset_getter`` returns.
    """
    getter_kwargs = {"domain_incremental": domain_incremental}
    loaded = dataset_getter(data_path, **getter_kwargs)
    return loaded


# Helper function to create dataloaders for a dataset
def create_dataloaders(
    dataset, batch_size, shuffle, num_workers, drop_last, pin_memory, collate_fn
):
    """Build one ``DataLoader`` per split of ``dataset``.

    Args:
        dataset: Mapping with the keys ``"train"``, ``"val"`` and ``"test"``,
            each holding the dataset for that split.
        batch_size: Batch size used for every split.
        shuffle: Whether the *training* loader shuffles its samples.
        num_workers: Worker process count used for every split.
        drop_last: Whether the *training* loader drops a trailing partial batch.
        pin_memory: Forwarded to every loader.
        collate_fn: Forwarded to every loader.

    Returns:
        Dict with the keys ``"train"``, ``"val"`` and ``"test"`` mapping to the
        corresponding ``DataLoader``.

    Note:
        ``shuffle`` and ``drop_last`` only affect the training loader. The
        evaluation loaders ("val" and "test") always iterate every sample in a
        fixed order; previously "val" inherited the training settings, which
        dropped validation samples and made validation order random.
    """
    # Settings that are identical for all three splits.
    shared = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        collate_fn=collate_fn,
    )
    return {
        "train": DataLoader(
            dataset["train"], shuffle=shuffle, drop_last=drop_last, **shared
        ),
        "val": DataLoader(dataset["val"], shuffle=False, drop_last=False, **shared),
        "test": DataLoader(dataset["test"], shuffle=False, drop_last=False, **shared),
    }


# Load both domains with domain_incremental=True:
# "A" comes from the CaDISv2 getter, "B" from the Cataract-1k getter.
datasets = {}
datasets["A"] = load_dataset(get_cadisv2_dataset, "../../storage/data/CaDISv2", True)
datasets["B"] = load_dataset(get_cataract1k_dataset, "../../storage/data/cataract-1k", True)

# Normalisation statistics are taken from the constants imported from utils.common.
# Disabled alternative that derives them from the data instead:
#   pixel_mean_A, pixel_std_A = pixel_mean_std(datasets["A"][0])
#   print("pixel mean of A", pixel_mean_A, "pixel std:", pixel_std_A)
pixel_mean_A, pixel_std_A = np.array(CADIS_PIXEL_MEAN), np.array(CADIS_PIXEL_STD)
pixel_mean_B, pixel_std_B = np.array(CAT1K_PIXEL_MEAN), np.array(CAT1K_PIXEL_STD)


# Calculate the byte size of one sample (image only; the mask is not counted)
def calculate_sample_size(image, mask):
    """Return the storage footprint of ``image`` in bytes.

    Args:
        image: Object exposing ``numel()`` and ``element_size()``
            (e.g. a ``torch.Tensor``).
        mask: Accepted for interface symmetry but currently ignored. To count
            it too, add ``mask.numel() * mask.element_size()`` to the result.

    Returns:
        Number of elements in ``image`` times the bytes per element.
    """
    element_count = image.numel()
    bytes_per_element = image.element_size()
    return element_count * bytes_per_element


# Function to sample without replacement until target size is reached
def sample_until_target_size(dataset, target_size_bytes):
    """Randomly pick dataset indices whose combined size fits a byte budget.

    The indices are visited in a random order (``random.shuffle``). Samples are
    accepted one by one while the running total stays within
    ``target_size_bytes``; the first sample that would exceed the budget ends
    the sampling. The final running total is printed.

    Args:
        dataset: Indexable collection with ``len()``; each item unpacks into
            ``(image, mask)``.
        target_size_bytes: Upper bound on the accumulated sample size.

    Returns:
        List of the accepted indices, in the order they were drawn.
    """
    visit_order = list(range(len(dataset)))
    random.shuffle(visit_order)

    chosen = []
    used_bytes = 0
    for position in visit_order:
        image, mask = dataset[position]
        candidate_total = used_bytes + calculate_sample_size(image, mask)
        if candidate_total > target_size_bytes:
            # Budget would be exceeded: stop instead of trying further samples.
            break
        chosen.append(position)
        used_bytes = candidate_total
    print("cumulative size:",used_bytes)
    return chosen


# Byte budget for the sampled subset: 32 * 1024 * 1024 bytes
target_size_bytes = 32 * 1024**2

# Randomly draw training samples of A until the budget is used up.
sampled_indices = sample_until_target_size(datasets["A"][0], target_size_bytes)

# N (how many samples fit into the budget) is reused in the second part of sampling
N = len(sampled_indices)

# Disabled: merge the sampled subset of A into B's training split and recompute B's stats.
# subset_A = torch.utils.data.Subset(datasets["A"][0], sampled_indices)
# new_train = torch.utils.data.ConcatDataset([subset_A, datasets["B"][0]])
#
# pixel_mean_B,pixel_std_B=pixel_mean_std(new_train)
# print("pixel mean of B",pixel_mean_B,"pixel std:",pixel_std_B)
#
# datasets["B"] = (new_train, datasets["B"][1], datasets["B"][2])

# set_seed(42) # seed everything

# Image processor shipped with the Swin backbone checkpoint
swin_processor = AutoImageProcessor.from_pretrained(SWIN_BACKBONE)

# Settings shared by both Mask2Former preprocessors: labels are not reduced,
# 255 is the ignore index, and normalisation is the only image transform
# enabled (no resize, no rescale).
shared_preprocessor_kwargs = dict(
    reduce_labels=False,
    ignore_index=255,
    do_resize=False,
    do_rescale=False,
    do_normalize=True,
)

# One preprocessor per domain, differing only in the normalisation statistics
m2f_preprocessor_A = Mask2FormerImageProcessor(
    image_mean=pixel_mean_A,
    image_std=pixel_std_A,
    **shared_preprocessor_kwargs,
)
m2f_preprocessor_B = Mask2FormerImageProcessor(
    image_mean=pixel_mean_B,
    image_std=pixel_std_B,
    **shared_preprocessor_kwargs,
)
# Create Mask2Former Datasets: wrap every raw split (tuple positions 0, 1, 2 ->
# "train", "val", "test") together with the preprocessor of its domain.
m2f_datasets = {
    domain: {
        split: Mask2FormerDataset(datasets[domain][position], preprocessor)
        for position, split in enumerate(("train", "val", "test"))
    }
    for domain, preprocessor in (
        ("A", m2f_preprocessor_A),
        ("B", m2f_preprocessor_B),
    )
}

# DataLoader parameters
N_WORKERS = 4
BATCH_SIZE = 16
SHUFFLE = True
DROP_LAST = True

# Keyword arguments handed to create_dataloaders for every domain
dataloader_params = dict(
    batch_size=BATCH_SIZE,
    shuffle=SHUFFLE,
    num_workers=N_WORKERS,
    drop_last=DROP_LAST,
    pin_memory=True,
    collate_fn=m2f_dataset_collate,
)

# Create DataLoaders: one {"train", "val", "test"} loader dict per domain
dataloaders = {
    key: create_dataloaders(splits, **dataloader_params)
    for key, splits in m2f_datasets.items()
}

print(dataloaders)

In [None]:
# Sanity check: display the label-handling settings of the preprocessor for A.
m2f_preprocessor_A.reduce_labels, m2f_preprocessor_A.ignore_index

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

# Same value as the preprocessors' ignore_index (presumably the background label)
BG_VALUE_255 = 255

# W&B identifiers: the run whose best model is loaded, the run created by this
# notebook, and the project / team they live in
base_run_name = "M2F-Swin-Tiny-Train_Cadis"
new_run_name = "Replay-Samples-Visualization"
project_name = "M2F_latest"
user_or_team = "continual-learning-tum"

In [None]:
# Tensorboard setup
out_dir = "outputs/"
# Creating the nested log directory also creates `out_dir` itself. exist_ok=True
# makes the cell safe to re-run and removes the check-then-create race of the
# previous `if not os.path.exists(...)` guards.
os.makedirs(out_dir + "runs", exist_ok=True)
# Load the TensorBoard notebook extension and point it at the log directory
# (the path is hard-coded here and has to match out_dir + "runs").
%load_ext tensorboard
%tensorboard --logdir outputs/runs

In [None]:
# Request synchronous CUDA kernel launches (debugging aid).
# The previous `!CUDA_LAUNCH_BLOCKING=1` ran the assignment in a throw-away
# subshell, so the variable never reached this kernel process; setting it via
# os.environ does.
# NOTE(review): this likely has to run before the first CUDA call to take
# effect - consider moving the cell above any torch.cuda usage.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [None]:
# Tensorboard logging
writer = SummaryWriter(log_dir=f"{out_dir}runs")

# Model checkpointing: a root directory plus, for the base run, one
# sub-directory for the best and one for the final weights.
model_dir = f"{out_dir}models/"
best_model_dir = f"{model_dir}{base_run_name}/best_model/"
final_model_dir = f"{model_dir}{base_run_name}/final_model/"

# Create (and announce) every directory that does not exist yet.
for announcement, directory in (
    ("Store weights in: ", model_dir),
    ("Store best model weights in: ", best_model_dir),
    ("Store final model weights in: ", final_model_dir),
):
    if os.path.exists(directory):
        continue
    print(announcement, directory)
    os.makedirs(directory)

In [None]:
# Authenticate with Weights & Biases (the runs are logged to a team project).

wandb.login() # use this call when a different person than last time runs the notebook
#wandb.login(relogin=False) # alternative when the same person as in the last run runs the notebook again


# First train on dataset A

In [None]:
# Start the W&B run that receives the sample tables logged further down.
wandb.init(
    project=project_name,
    name=new_run_name,
    notes="Visualizing samples from different replay methods"
)
# Print the id of the run that was just started.
print("wandb run id:",wandb.run.id)

In [None]:
# Load the best model of the base run from its W&B artifact.
# Disabled local alternative:
#model = Mask2FormerForUniversalSegmentation.from_pretrained(f"{best_model_dir}{CURR_TASK}/").to(device)

# Construct the artifact path
artifact_path = f"{user_or_team}/{project_name}/best_model_{base_run_name}:latest"

# Load from W&B
api = wandb.Api()
artifact = api.artifact(artifact_path)
# NOTE: this rebinds `model_dir` (previously the local checkpoint root) to the
# directory the artifact was downloaded into.
model_dir = artifact.download()
model_state_dict_path = os.path.join(model_dir, f"best_model_{base_run_name}.pth")

# map_location remaps the stored tensors onto the active device, so a checkpoint
# saved from a GPU can still be restored when `device` fell back to the CPU.
model_state_dict = torch.load(model_state_dict_path, map_location=device)

# Rebuild the architecture from the config, then restore the weights into it
model = Mask2FormerForUniversalSegmentation(mask2former_config)
model.load_state_dict(model_state_dict)
model.to(device)

# Mean Loss

In [None]:
# Per-sample loss and last encoder hidden state for every training sample of A
losses, encoder_samples = [], []

model.eval()
with torch.no_grad():
    for sample in tqdm(m2f_datasets["A"]["train"]):
        # Move the model inputs onto the target device (in place)
        for tensor_key in ("pixel_values", "pixel_mask"):
            sample[tensor_key] = sample[tensor_key].to(device)
        for list_key in ("mask_labels", "class_labels"):
            sample[list_key] = [entry.to(device) for entry in sample[list_key]]

        outputs = model(**sample)
        losses.append(outputs.loss.item())
        # Keep the encoder features on the CPU; they are reused for RSS later
        encoder_samples.append(outputs.encoder_last_hidden_state.cpu())

losses_np = np.array(losses)

# Mean-loss sampling: the N samples whose loss lies closest to the mean loss
# (N was calculated above)
mean_loss = losses_np.mean()
differences = np.abs(losses_np - mean_loss)
closest_indices = differences.argsort()[:N]

# Materialise the selected training samples of A
subset_A = [m2f_datasets["A"]["train"][index] for index in closest_indices]

In [None]:
# Run every mean-loss sample through the model and add it to a W&B table
table = wandb.Table(columns=["ID", "Image"])
model.eval()
for sample_id, batch in tqdm(enumerate(subset_A)):
    # Move the model inputs onto the target device (in place)
    for tensor_key in ("pixel_values", "pixel_mask"):
        batch[tensor_key] = batch[tensor_key].to(device)
    for list_key in ("mask_labels", "class_labels"):
        batch[list_key] = [entry.to(device) for entry in batch[list_key]]

    with torch.no_grad():
        outputs = model(**batch)

    # Prediction maps and masks for this sample
    pred_maps, masks = m2f_extract_pred_maps_and_masks(
        batch, outputs, m2f_preprocessor_A
    )

    # pixel_mean_A / pixel_std_A are passed along with the normalised input
    log_table_of_images(
        table,
        batch["pixel_values"],
        pixel_mean_A,
        pixel_std_A,
        pred_maps,
        masks,
        sample_id,
    )

wandb.log({"Mean_Loss_Samples": table})

# Min Loss

In [None]:
# Min-loss sampling: the N samples with the smallest loss (N was calculated above)
closest_indices = losses_np.argsort()[:N]

# Materialise the selected training samples of A
subset_A = [m2f_datasets["A"]["train"][index] for index in closest_indices]

# Run every min-loss sample through the model and add it to a W&B table
table = wandb.Table(columns=["ID", "Image"])
model.eval()
for sample_id, batch in tqdm(enumerate(subset_A)):
    # Move the model inputs onto the target device (in place)
    for tensor_key in ("pixel_values", "pixel_mask"):
        batch[tensor_key] = batch[tensor_key].to(device)
    for list_key in ("mask_labels", "class_labels"):
        batch[list_key] = [entry.to(device) for entry in batch[list_key]]

    with torch.no_grad():
        outputs = model(**batch)

    # Prediction maps and masks for this sample
    pred_maps, masks = m2f_extract_pred_maps_and_masks(
        batch, outputs, m2f_preprocessor_A
    )

    log_table_of_images(
        table,
        batch["pixel_values"],
        pixel_mean_A,
        pixel_std_A,
        pred_maps,
        masks,
        sample_id,
    )

wandb.log({"Min_Loss_Samples": table})

# Max Loss

In [None]:
# Max-loss sampling: the N samples with the largest loss (N was calculated above).
# The slice start is computed explicitly because `[-N:]` would return EVERY
# index when N == 0 (since -0 == 0); this way N == 0 yields an empty selection,
# consistent with the `[:N]` slices of the mean/min strategies.
loss_order = np.argsort(losses_np)
closest_indices = loss_order[max(len(loss_order) - N, 0):]

# Create a subset of A from the max-loss sampled indices
subset_A = [m2f_datasets["A"]["train"][i] for i in closest_indices]

# Run every max-loss sample through the model and add it to a W&B table
table = wandb.Table(columns=["ID", "Image"])
model.eval()
for i, batch in tqdm(enumerate(subset_A)):
    # Move the model inputs onto the target device (in place)
    batch["pixel_values"] = batch["pixel_values"].to(device)
    batch["pixel_mask"] = batch["pixel_mask"].to(device)
    batch["mask_labels"] = [entry.to(device) for entry in batch["mask_labels"]]
    batch["class_labels"] = [entry.to(device) for entry in batch["class_labels"]]

    with torch.no_grad():
        outputs = model(**batch)

    # Prediction maps and masks for this sample
    pred_maps, masks = m2f_extract_pred_maps_and_masks(
        batch, outputs, m2f_preprocessor_A
    )

    log_table_of_images(
        table,
        batch["pixel_values"],
        pixel_mean_A,
        pixel_std_A,
        pred_maps,
        masks,
        i
    )

wandb.log({"Max_Loss_Samples": table})

# RSS

In [None]:
# Stack the per-sample encoder features collected earlier into one array
encoder_samples_np = np.concatenate(encoder_samples)
print(f"Samples shape: {encoder_samples_np.shape}")

# ================== HYPERPARAMETERS ==================#
# Number of UMAP components (not specified in the paper)
N_COMPONENTS = 2

# Number of clusters (M in the paper but not specified explicitly)
N_CLUSTERS = 11

# Number of closest samples per cluster (derived from N)
N_PER_CLUSTER = N // N_CLUSTERS
# =====================================================#

# Flatten each sample's feature map into a single vector
# (the unpacking requires the stacked array to be 4-dimensional)
n_samples, dim1, dim2, dim3 = encoder_samples_np.shape
flattened_data = encoder_samples_np.reshape(n_samples, -1)

# Reduce dimensionality with UMAP (fixed random_state)
umap_reducer = umap.UMAP(n_components=N_COMPONENTS, random_state=42)
reduced_data = umap_reducer.fit_transform(flattened_data)

# Cluster the reduced points with KMeans (fixed random_state)
kmeans = KMeans(n_clusters=N_CLUSTERS, random_state=42)
cluster_labels = kmeans.fit_predict(reduced_data)

# Distance from every sample to every centroid, and to its own centroid
centroids = kmeans.cluster_centers_
distances = cdist(reduced_data, centroids, "euclidean")
sample_distances = distances[np.arange(n_samples), cluster_labels]

# For each cluster keep the (at most) N_PER_CLUSTER members nearest its centroid
closest_indices_per_cluster = []
for cluster_id in range(N_CLUSTERS):
    members = np.flatnonzero(cluster_labels == cluster_id)
    member_distances = distances[members, cluster_id]
    nearest_first = np.argsort(member_distances)
    closest_indices_per_cluster.extend(members[nearest_first[:N_PER_CLUSTER]])

# Output the indices of the closest samples per cluster
closest_indices_per_cluster = np.array(closest_indices_per_cluster)
print(f"Indicies for the sampled images: {closest_indices_per_cluster}")

In [None]:
import matplotlib.pyplot as plt

# Scatter the first two UMAP dimensions coloured by KMeans cluster label,
# with the cluster centroids overlaid as red crosses.
fig, ax = plt.subplots(figsize=(10, 8))
scatter = ax.scatter(
    reduced_data[:, 0], reduced_data[:, 1], c=cluster_labels, cmap='viridis', s=50
)
ax.scatter(
    centroids[:, 0], centroids[:, 1], c='red', marker='X', s=200, label='Centroids'
)
fig.colorbar(scatter, ax=ax, label='Cluster Label')
ax.set_xlabel('UMAP Dimension 1')
ax.set_ylabel('UMAP Dimension 2')
ax.set_title('UMAP Reduction with KMeans Clustering')
ax.legend()
plt.show()

In [None]:
# Materialise the training samples of A selected per cluster
subset_A = [m2f_datasets["A"]["train"][index] for index in closest_indices_per_cluster]

# Run every selected sample through the model and add it to a W&B table
table = wandb.Table(columns=["ID", "Image"])
model.eval()
for sample_id, batch in tqdm(enumerate(subset_A)):
    # Move the model inputs onto the target device (in place)
    for tensor_key in ("pixel_values", "pixel_mask"):
        batch[tensor_key] = batch[tensor_key].to(device)
    for list_key in ("mask_labels", "class_labels"):
        batch[list_key] = [entry.to(device) for entry in batch[list_key]]

    with torch.no_grad():
        outputs = model(**batch)

    # Prediction maps and masks for this sample
    pred_maps, masks = m2f_extract_pred_maps_and_masks(
        batch, outputs, m2f_preprocessor_A
    )

    log_table_of_images(
        table,
        batch["pixel_values"],
        pixel_mean_A,
        pixel_std_A,
        pred_maps,
        masks,
        sample_id,
    )

wandb.log({"RSS_Samples": table})

In [None]:
# End the W&B run that was started with wandb.init earlier in the notebook.
wandb.finish()