# Experiments: Diabetic Retinopathy

<img src = 'https://repository-images.githubusercontent.com/195603342/63983100-b4a6-11e9-846c-99b9465f7b3b'>

In [None]:
from collections import Counter

from docarray.typing import TorchTensor, TorchEmbedding, ImageUrl
from typing import Optional
from docarray.documents import ImageDoc
from docarray import BaseDoc, DocVec, DocList
from docarray.typing import ID

class MultiModal(BaseDoc):
    """One fundus-image record: a precomputed embedding plus its labels.

    Documents are pulled from the local DocList store and are also rebuilt on the
    fly from (x, y, task) samples inside ZeroshotExemplarsBuffer.update, where the
    unused fields are filled with empty strings / placeholder values.
    """
    # Image embedding; compared against CLIP text embeddings in zeroshot_clustering
    # (presumably a CLIP image embedding — TODO confirm).
    embedding: TorchTensor
    # Location of the source image.
    path: ImageUrl
    # Ground-truth class index; assumed to use the same encoding as the keys of
    # `class_labels`, since it is scored against zero-shot class indices.
    label: int
    label_description: str
    # Class index predicted by zero-shot matching (-1 used as a placeholder before clustering).
    zeroshot_label: int
    # Prompt text of the predicted class.
    zeroshot_description: str
    # Task / domain index; the data is split into one DocList per distinct value.
    task: int
    task_description: str
    id_code: str
    # Free-form extra metadata.
    metadata: dict

In [None]:
from sklearn.neighbors import NearestNeighbors
from sklearn.neighbors import NearestCentroid
from sklearn.cluster import DBSCAN
from sklearn.decomposition import PCA
from sklearn import metrics
import numpy as np
from torchvision import transforms
import torch
import clip
import matplotlib.pyplot as plt
import pickle
from PIL import Image
from rich import print
from scipy.stats import trim_mean

from sklearn.metrics.pairwise import cosine_similarity


def stratified_sampling(docs: DocList[MultiModal], zeroshot_labels, n_neighbors=10, seed=43):
    """Draw up to ``n_neighbors`` documents from every zero-shot cluster.

    Args:
        docs: Documents to sample from; ``docs[k]`` belongs to cluster ``zeroshot_labels[k]``.
        zeroshot_labels: Cluster index per document, parallel to ``docs``.
        n_neighbors: Maximum number of documents kept per cluster. Clusters with at
            most this many members are kept whole (they may be the mis-assigned ones).
        seed: Seed for the global NumPy RNG before sampling (default 43, the value
            that used to be hard-coded).

    Returns:
        A ``DocList[MultiModal]`` of the sampled documents, grouped by cluster in
        ascending cluster order.
    """
    # Seeds the *global* NumPy RNG exactly as before, so later consumers of the
    # global RNG see the same state they always did.
    np.random.seed(seed)
    labels = np.asarray(zeroshot_labels)

    sampled_docs = []
    for cluster in np.unique(labels):
        member_idx = np.flatnonzero(labels == cluster)
        if len(member_idx) > n_neighbors:
            # Sample positions rather than the document objects themselves, so NumPy
            # never has to turn documents into an array. choice(len(a)) draws the same
            # permutation prefix choice(a) would, so the selection is unchanged.
            picks = np.random.choice(len(member_idx), n_neighbors, replace=False)
            member_idx = member_idx[picks]
        sampled_docs.extend(docs[int(k)] for k in member_idx)

    return DocList[MultiModal](sampled_docs)

def zeroshot_clustering(docs: DocList[MultiModal], text_embs, prompts=None):
    """Assign every document to the class prompt its embedding is most similar to.

    Args:
        docs: Documents whose ``embedding`` fields are compared against ``text_embs``.
        text_embs: One text embedding per class prompt, row ``c`` for class ``c``
            (as returned by ``get_label_tokens``).
        prompts: Prompt strings matching the rows of ``text_embs``; used to fill
            ``zeroshot_description``. When omitted, the notebook-global
            ``labels_prompts`` is used (previous behaviour). That global holds whatever
            prompts were assigned last, so pass this explicitly when ``text_embs`` came
            from other prompts (e.g. ``text_embs_best``).

    Returns:
        ``(new_docs, zeroshot_labels, n_clusters_)``: a DocList built from ``docs`` with
        ``zeroshot_label``/``zeroshot_description`` filled in, the predicted class index
        per document as a NumPy array, and the number of distinct predicted classes.
        NOTE(review): ``DocList[MultiModal](docs)`` presumably wraps the same document
        objects, so the input docs are likely updated as well — confirm.
    """
    image_embs = torch.stack([doc.embedding for doc in docs])
    # Compute in at least float32 and give both operands the same dtype: the CLIP
    # text features may be float16 (e.g. when the model runs on CUDA).
    work_dtype = torch.promote_types(image_embs.dtype, torch.float32)
    image_embs = image_embs.to(dtype=work_dtype)
    # L2-normalise so the dot product below is a cosine similarity.
    image_embs /= image_embs.norm(dim=-1, keepdim=True)

    # text_embs may live on the model's device (GPU) while stored embeddings are on
    # the CPU; move a detached copy next to the image embeddings before the matmul.
    text_emb = text_embs.detach().to(device=image_embs.device, dtype=work_dtype)
    text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)

    print(f"image_embs: {image_embs.shape}")
    print(f"text_emb: {text_emb.shape}")

    # Softmax over classes of the cosine similarities scaled by 100.
    similarity = (100.0 * image_embs @ text_emb.T).softmax(dim=-1)

    # Predicted class = most similar prompt; move to CPU before converting to NumPy.
    zeroshot_labels = torch.argmax(similarity, dim=1).cpu().numpy()

    # Number of distinct predicted classes.
    n_clusters_ = len(set(zeroshot_labels))

    print("Estimated number of clusters: %d" % n_clusters_)
    print(f"Counter zero-shot labels = {Counter(zeroshot_labels)}")

    new_docs = DocList[MultiModal](docs)
    new_docs.zeroshot_label = zeroshot_labels
    # Fall back to the notebook-global prompts only when none were passed in.
    descriptions = labels_prompts if prompts is None else prompts
    new_docs.zeroshot_description = [descriptions[i] for i in zeroshot_labels]

    return new_docs, zeroshot_labels, n_clusters_

def get_zeroshot_sampling(docs: DocList[MultiModal], text_embs, n_neighbors=10):
    """Zero-shot cluster ``docs`` against ``text_embs`` and draw a stratified sample.

    Args:
        docs: Documents carrying an ``embedding`` field.
        text_embs: One text embedding per class prompt.
        n_neighbors: Maximum number of documents sampled per predicted class.

    Returns:
        ``(clustered_docs, sampled_docs)``: all documents with their zero-shot labels
        filled in, and the per-class sample drawn from them.
    """
    # 1. Zero-shot clustering. zeroshot_clustering stacks the embeddings itself, so
    #    no separate stack is needed here.
    clustered_docs, zeroshot_labels, _ = zeroshot_clustering(docs, text_embs)

    # 2. Keep at most n_neighbors documents per predicted class.
    sampled_docs = stratified_sampling(clustered_docs, zeroshot_labels, n_neighbors)

    n_sampled, n_total = len(sampled_docs), len(clustered_docs)
    share = round((n_sampled / n_total) * 100, 2) if n_total else 0.0
    print(f"# Sampled docs ({share}%):", n_sampled)
    print()

    return clustered_docs, sampled_docs


## Pulling docs

In [None]:
%%time

# Load the stored DocList of MultiModal documents from the local
# './brazilian_indian' location and display its summary.
url = 'file://./brazilian_indian'
dl = DocList[MultiModal].pull(url)
dl.summary()

In [None]:
# Distribution of documents per task (domain) id and description.
print(f"Task: {Counter(dl.task)}")
print(f"Task description: {Counter(dl.task_description)}")

# Distribution of ground-truth labels and their descriptions.
print(f"Label: {Counter(dl.label)}")
print(f"Label description: {Counter(dl.label_description)}")

# Display the first document as a sanity check.
dl[0]

In [None]:
# Partition the documents into one DocList per task id. np.unique returns the ids
# sorted, so list_doclists[k] holds the k-th smallest task id.
unique_tasks = np.unique(dl.task)
list_doclists = [
    DocList[MultiModal]([doc for doc in dl if doc.task == task_id])
    for task_id in unique_tasks
]

list_doclists

## Visualize clusters

In [None]:
# Use the GPU when available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
# CLIP ViT-L/14 at 336px; `embeder` encodes the class prompts below, `preprocess`
# is the image transform returned by clip.load.
embeder, preprocess = clip.load("ViT-L/14@336px", device=device)

In [None]:
# Class index -> class name. The names double as keys into the description sets
# used to build one prompt per class (in this index order).
class_labels = {
    0: 'No retinopathy', 
    1: 'Retinopathy'
}

In [None]:
def get_label_tokens(embeder, labels_prompts):
    """Encode class prompts with the CLIP text encoder.

    Args:
        embeder: CLIP model returned by ``clip.load``.
        labels_prompts: Prompt strings, one per class.

    Returns:
        Text embeddings, one row per prompt, on the model's device (computed without
        gradient tracking).
    """
    print(labels_prompts)
    # clip.tokenize builds the token ids on the CPU. Move them to the device holding
    # the model's weights (e.g. CUDA), otherwise encode_text hits a device mismatch.
    model_device = next(embeder.parameters()).device
    label_tokens = clip.tokenize(labels_prompts).to(model_device)
    with torch.no_grad():
        text_embs = embeder.encode_text(label_tokens)
    return text_embs

# text_embs = get_label_tokens(embeder, labels_prompts)

# print(text_embs.shape)

In [None]:
# Each description set maps every class name in `class_labels` to the text appended
# after a template to form that class's CLIP prompt ("<template> <description>").
# Long sets spell out the clinical definition of each class...
description_sets_long = [
    # Original Descriptions
    {
        'No retinopathy': 'No retinopathy',
        'Retinopathy': 'Mild to Severe non-proliferative diabetic retinopathy and Proliferative diabetic retinopathy with post-laser status'
    },
    # Alternative 1
    {
        'No retinopathy': 'Normal eye fundus without signs of retinal damage',
        'Retinopathy': 'Presence of any degree of diabetic retinopathy, including post-laser treatment'
    },
    # Alternative 2
    {
        'No retinopathy': 'Absence of retinal abnormalities associated with diabetic retinopathy',
        'Retinopathy': 'Eye fundus showing markers of non-proliferative or proliferative diabetic retinopathy, treated or untreated'
    },
    # Alternative 3
    {
        'No retinopathy': 'Eye fundus with no diabetic retinal pathology',
        'Retinopathy': 'Fundus image displaying features of mild to severe diabetic retinopathy, possibly with laser treatment marks'
    },
    # Alternative 4
    {
        'No retinopathy': 'Healthy retina with no evidence of retinopathy',
        'Retinopathy': 'Retina with signs of diabetic retinopathy ranging from mild to severe, including laser-treated cases'
    },
    # Alternative 5
    {
        'No retinopathy': 'Fundus image free of diabetic retinal disease indicators',
        'Retinopathy': 'Fundus showing varying stages of diabetic retinopathy or post-laser treatment indicators'
    }
]

# ...short sets use two- or three-word descriptions.
description_sets_short = [
    # Short Alternative 1
    {
        'No retinopathy': 'Healthy retina',
        'Retinopathy': 'Diabetic damage'
    },
    # Short Alternative 2
    {
        'No retinopathy': 'No damage',
        'Retinopathy': 'Diabetic signs'
    },
    # Short Alternative 3
    {
        'No retinopathy': 'Normal retina',
        'Retinopathy': 'Retinal disease'
    },
    # Short Alternative 4
    {
        'No retinopathy': 'No issues',
        'Retinopathy': 'Retinopathy present'
    },
    # Short Alternative 5
    {
        'No retinopathy': 'Clear fundus',
        'Retinopathy': 'Fundus changes'
    }
]


In [None]:
from matplotlib.patches import Patch


# Templates
# Prompt prefixes. Each one is combined with every description set to form one
# prompt per class: f"{template} {description}".
templates = [
    "an iris with", 
    "a human eye with", 
    "", 
    "an ocular image with", 
    "a retinal photo with",
    "A fundus image displaying",
    "Visible symptoms in the retina suggest",
    "Retinal scan reveals",
    "Optical image shows",
    "The condition of the retina is"
]

# task index -> {"accuracy": {template: [score per description set]},
#                "f1-score": {template: [score per description set]}}
# Description-set order is short sets first, then long sets.
task_all_accuracies = {}


# Grid search over tasks x templates x description sets: zero-shot classify each
# task's documents with the resulting prompts and score against the ground truth.
for task_index, task in enumerate(list_doclists):
    all_accuracies = {}
    all_f1_scores = {}
    for template in templates:
        # Initialize lists to store metrics for the current template
        accuracies = []
        f1_scores = []
        # Initialize list to store description labels
        # description_labels = []

        # Iterating over each set of descriptions
        for i, descriptions in enumerate(description_sets_short + description_sets_long):
            # Currently unused; kept for labelling.
            description_label = f"Set {i+1}"
            # description_labels.append(description_label)

            # NOTE: this assigns the module-level `labels_prompts`, which
            # zeroshot_clustering reads to fill zeroshot_description, so it must be
            # set before get_zeroshot_sampling is called with the matching text_embs.
            labels_prompts = [f"{template} {descriptions[c]}" for c in class_labels.values()]
            text_embs = get_label_tokens(embeder, labels_prompts)

            # Only clustered_docs (with zero-shot labels) is used; sampled_docs is ignored here.
            clustered_docs, sampled_docs = get_zeroshot_sampling(task, text_embs, n_neighbors=10)
            ground_truth_labels = clustered_docs.label
            zeroshot_labels = clustered_docs.zeroshot_label

            accuracy = metrics.accuracy_score(ground_truth_labels, zeroshot_labels)
            # Support-weighted F1 across classes.
            f1 = metrics.f1_score(ground_truth_labels, zeroshot_labels, average='weighted')

            accuracies.append(accuracy)
            f1_scores.append(f1)

        all_accuracies[template] = accuracies
        all_f1_scores[template] = f1_scores

    task_all_accuracies[task_index] = {"accuracy": all_accuracies, "f1-score": all_f1_scores}
    


In [None]:
# Combined list of all description sets (short first, then long) -- the same order
# the grid search used, so position i in a score list maps to all_description_sets[i].
all_description_sets = description_sets_short + description_sets_long

# Metric used to rank the (template, description set) combinations.
score = "f1-score"
best_values = dict()

for task, value in task_all_accuracies.items():
    print("Task: ", task)

    # Every (score, template, description-set index) evaluated for this task, in grid order.
    candidates = [
        (set_score, template, set_idx)
        for template, set_scores in value[score].items()
        for set_idx, set_score in enumerate(set_scores)
    ]

    # max() returns the first maximal candidate, the same tie-break as a strict ">" scan.
    # Selecting from real candidates (rather than starting at 0 with None placeholders)
    # means a task whose scores are all 0 still yields a usable template and
    # description set for the prompt-building cell that follows.
    if candidates:
        best_score, best_template, best_idx = max(candidates, key=lambda c: c[0])
        best_description_set = all_description_sets[best_idx]
    else:
        best_score, best_template, best_description_set = 0, None, None

    best_values[task] = {"best_score": best_score, "best_template": best_template, "best_description_set": best_description_set}
    print("Best Template:", best_template)
    print("Best Description Set:", best_description_set)
    print("Best F1-Score:", best_score)
    print()
    print()

In [None]:
# Build one prompt per class from the selected template/description set and embed it.
# NOTE(review): best_template / best_description_set are left over from the *last*
# task of the selection loop, so every task is later evaluated with that task's
# best prompts.
labels_prompts_best = [f"{best_template} {best_description_set[c]}" for c in class_labels.values()]
text_embs_best = get_label_tokens(embeder, labels_prompts_best)

In [None]:
# Per-task best score, template and description set.
best_values

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# One heatmap per task: rows are prompt templates, columns are description sets,
# and each cell holds the weighted F1 of zero-shot predictions for that pair.
num_templates = len(templates)
num_description_sets = len(description_sets_short) + len(description_sets_long)
template_ticks = [t if t != '' else '[No template]' for t in templates]
set_ticks = [f"Set {k+1}" for k in range(num_description_sets)]

for task_index in range(len(list_doclists)):
    all_f1_scores_ = task_all_accuracies[task_index]['f1-score']

    f1_matrix = np.zeros((num_templates, num_description_sets))
    for row, template in enumerate(templates):
        f1_matrix[row, :] = all_f1_scores_[template][:num_description_sets]

    fig, ax = plt.subplots(figsize=(12, 12))
    cax = ax.matshow(f1_matrix, cmap='coolwarm')
    fig.colorbar(cax)

    # Explicit tick positions and labels on both axes.
    ax.set_xticks(np.arange(num_description_sets))
    ax.set_xticklabels(set_ticks, rotation=45)
    ax.set_yticks(np.arange(num_templates))
    ax.set_yticklabels(template_ticks)

    # Write each cell's score (3 decimals) on top of the colour map.
    for (row, col), cell_value in np.ndenumerate(f1_matrix):
        ax.text(col, row, f"{cell_value:.3f}", ha="center", va="center", color="w")

    plt.xlabel('Description Sets')
    plt.ylabel('Templates')
    plt.title(f'F1-Scores for Task {task_index}')
    plt.savefig(f'ConfusionMatrix_{task_index}.pdf', bbox_inches='tight', pad_inches=0)

    plt.show()


In [None]:
import umap
import numpy as np
import matplotlib.pyplot as plt
from sklearn import metrics

# UMAP reducer for the 2-D projections (cosine metric); refit on every task below.
umap_model = umap.UMAP(n_neighbors=20, n_components=2, min_dist=0.1, metric='cosine')

# Palette: one Spectral colour per class plus one spare, derived from class_labels
# instead of a hard-coded class count (identical colours for the 2-class case).
cluster_colors = [plt.cm.Spectral(each) for each in np.linspace(0, 1, len(class_labels) + 1)]

for i, docs in enumerate(list_doclists):
    print(f">>>> TASK = {i}")
    clustered_docs, sampled_docs = get_zeroshot_sampling(docs, text_embs_best, n_neighbors=10)
    ground_truth_labels = clustered_docs.label
    zeroshot_labels = clustered_docs.zeroshot_label

    print(f"Accuracy: {metrics.accuracy_score(ground_truth_labels, zeroshot_labels):.3f}")
    print(f"f1_score: {metrics.f1_score(ground_truth_labels, zeroshot_labels, average='weighted'):.3f}")

    # Embeddings of all documents and of the zero-shot sample, as NumPy arrays.
    docs_embeddings = torch.stack([doc.embedding for doc in clustered_docs]).detach().cpu().numpy()
    sampled_embeddings = torch.stack([doc.embedding for doc in sampled_docs]).detach().cpu().numpy()

    # Fit UMAP on all documents, then project the sampled ones into the same space.
    # NOTE(review): the sampled docs are a subset of the fitted docs; transform() gives
    # approximate positions that may differ slightly from their fitted coordinates.
    docs_embeddings_2d = umap_model.fit_transform(docs_embeddings)
    sampled_embeddings_2d = umap_model.transform(sampled_embeddings)

    # Circles: every document coloured by its ground-truth label.
    # Stars: sampled documents coloured by their zero-shot label.
    cluster_colors_ = [cluster_colors[c] for c in ground_truth_labels]
    sampled_colors_ = [cluster_colors[c] for c in sampled_docs.zeroshot_label]

    plt.scatter(docs_embeddings_2d[:, 0], docs_embeddings_2d[:, 1], c=cluster_colors_, marker='o', s=300, alpha=0.2, label='Clusters')
    plt.scatter(sampled_embeddings_2d[:, 0], sampled_embeddings_2d[:, 1], c=sampled_colors_, marker='*', s=150, label='Zero-shot sampling', edgecolors="black")

    plt.legend()
    plt.title(f"UMAP Projection for Domain = {i}")
    # plt.savefig(f'projection_{i}.pdf', bbox_inches='tight', pad_inches=0)
    plt.show()


## TADILER: TADIL + Experience Replay 

In [None]:
from avalanche.training.storage_policy import ParametricBuffer, _ParametricSingleBuffer
from avalanche.benchmarks.utils.data_loader import ReplayDataLoader
from avalanche.training.plugins import SupervisedPlugin

import numpy as np

class RandomExemplarsBuffer(ParametricBuffer):
    """Replay buffer storing ``n_neighbors`` uniformly random exemplars per experience.

    Random baseline for ZeroshotExemplarsBuffer: after each experience, ``n_neighbors``
    samples are drawn without replacement from the experience's dataset and merged
    into the per-group buffers managed by ParametricBuffer.

    Args:
        max_size: Total buffer capacity shared by all groups.
        n_neighbors: Samples drawn per experience; must not exceed the dataset size.
        seed: Seed of the NumPy RNG used for the draw. Previously this argument was
            ignored and ``update`` read an unrelated module-level ``seed`` instead
            (NameError / TypeError outside the experiment loop). Pass it explicitly to
            vary the draw across runs.
        groupby, selection_strategy: Forwarded to ParametricBuffer.
    """

    def __init__(self, max_size, n_neighbors=10, seed=42, groupby=None,
                 selection_strategy=None):
        super().__init__(max_size, groupby, selection_strategy)
        self.n_neighbors = n_neighbors
        self.seed = seed
        print(">>>>RandomExemplarsBuffer")

    def update(self, strategy, **kwargs):
        """Draw random exemplars from the current experience and refresh the buffers.

        Raises:
            ValueError: if ``n_neighbors`` exceeds the experience's dataset size.
        """
        dataset = strategy.experience.dataset
        dataset_len = len(dataset)
        if self.n_neighbors > dataset_len:
            raise ValueError("n_neighbors cannot be greater than the size of the dataset")

        np.random.seed(self.seed)
        random_indices = np.random.choice(dataset_len, self.n_neighbors, replace=False)

        # Each dataset item is indexed as (x, y, task_label); fetch every sampled item once.
        samples = [dataset[int(idx)] for idx in random_indices]
        x_random = [sample[0] for sample in samples]
        y_random = [sample[1] for sample in samples]
        t_random = [sample[2] for sample in samples]

        x_random_pt = torch.stack(x_random)
        y_random_pt = torch.from_numpy(np.asarray(y_random))
        t_random_pt = torch.from_numpy(np.asarray(t_random))

        # Include task labels so the replayed data keeps its task identity.
        new_data = TensorDataset(x_random_pt, y_random_pt, t_random_pt)
        new_data = AvalancheDataset(new_data)
        new_data.targets = y_random_pt.tolist()

        print("Update x_random: ", len(x_random))
        print("Update y_random: ",  len(y_random))
        print("Update new_data: ", len(new_data))
        print("Update new_data[0]: ", len(new_data[0]))

        # Split the new exemplars into groups and recompute each group's share of max_size.
        new_groups = self._make_groups(strategy, new_data)
        self.seen_groups.update(new_groups.keys())

        lens = self.get_group_lengths(len(self.seen_groups))
        group_to_len = {group_id: ll for group_id, ll in zip(self.seen_groups, lens)}

        for group_id, new_data_g in new_groups.items():
            ll = group_to_len[group_id]
            if group_id in self.buffer_groups:
                old_buffer_g = self.buffer_groups[group_id]
                old_buffer_g.update_from_dataset(strategy, new_data_g)
                old_buffer_g.resize(strategy, ll)
            else:
                new_buffer = _ParametricSingleBuffer(ll, self.selection_strategy)
                new_buffer.update_from_dataset(strategy, new_data_g)
                self.buffer_groups[group_id] = new_buffer

        # Shrink every group (including ones not updated now) to its new share.
        for group_id, group_buf in self.buffer_groups.items():
            group_buf.resize(strategy, group_to_len[group_id])
                      

class ZeroshotExemplarsBuffer(ParametricBuffer):
    """TADILER replay buffer keeping zero-shot stratified exemplars of each experience.

    On update, every (x, y, task) sample of the experience is wrapped in a MultiModal
    document, the documents are zero-shot clustered against prompt embeddings, and up
    to ``n_neighbors`` documents per predicted class are kept (get_zeroshot_sampling).
    The selection is then merged into the per-group buffers of ParametricBuffer.

    Args:
        max_size: Total buffer capacity shared by all groups.
        n_neighbors: Maximum exemplars kept per predicted class per experience.
        groupby, selection_strategy: Forwarded to ParametricBuffer.
        text_embs: Prompt embeddings used for zero-shot clustering. ``None`` (default)
            reads the notebook-global ``text_embs_best`` at update time, as before.
    """

    def __init__(self, max_size, n_neighbors=10, groupby=None,
                 selection_strategy=None, text_embs=None):
        super().__init__(max_size, groupby, selection_strategy)
        self.n_neighbors = n_neighbors
        self.text_embs = text_embs
        print(">>>>ZeroshotExemplarsBuffer")

    def update(self, strategy, **kwargs):
        """Select zero-shot stratified exemplars from the current experience and refresh the buffers."""
        # Wrap each (x, y, task) item as a document; unused fields get placeholders.
        data_list = [
            MultiModal(embedding=data[0], path="", label=data[1], label_description="",
                       zeroshot_label=-1, zeroshot_description="",
                       task=data[2], task_description="",
                       id_code="", metadata={})
            for data in strategy.experience.dataset
        ]
        multimodal_data = DocList[MultiModal](data_list)
        print(type(multimodal_data))

        # Resolve the prompt embeddings lazily so the global is only needed when none were given.
        text_embs = self.text_embs if self.text_embs is not None else text_embs_best
        _, sampled_docs = get_zeroshot_sampling(multimodal_data, text_embs, n_neighbors=self.n_neighbors)

        x_knn = sampled_docs.embedding
        y_knn = sampled_docs.label
        t_knn = sampled_docs.task

        x_knn_pt = torch.stack(x_knn)
        y_knn_pt = torch.from_numpy(np.asarray(y_knn))
        t_knn_pt = torch.from_numpy(np.asarray(t_knn))

        # Include task labels so the replayed data keeps its task identity.
        new_data = TensorDataset(x_knn_pt, y_knn_pt, t_knn_pt)
        new_data = AvalancheDataset(new_data)
        new_data.targets = y_knn_pt.tolist()

        print("Update x_knn: ", len(x_knn))
        print("Update y_knn: ",  len(y_knn))
        print("Update new_data: ", len(new_data))
        print("Update new_data[0] before: ", len(new_data[0]))

        # Split the new exemplars into groups and recompute each group's share of max_size.
        new_groups = self._make_groups(strategy, new_data)
        self.seen_groups.update(new_groups.keys())

        lens = self.get_group_lengths(len(self.seen_groups))
        group_to_len = {group_id: ll for group_id, ll in zip(self.seen_groups, lens)}

        for group_id, new_data_g in new_groups.items():
            ll = group_to_len[group_id]
            if group_id in self.buffer_groups:
                old_buffer_g = self.buffer_groups[group_id]
                old_buffer_g.update_from_dataset(strategy, new_data_g)
                old_buffer_g.resize(strategy, ll)
            else:
                new_buffer = _ParametricSingleBuffer(ll, self.selection_strategy)
                new_buffer.update_from_dataset(strategy, new_data_g)
                self.buffer_groups[group_id] = new_buffer

        # Shrink every group (including ones not updated now) to its new share.
        for group_id, group_buf in self.buffer_groups.items():
            group_buf.resize(strategy, group_to_len[group_id])
        print("Update new_data[0] after: ", len(new_data[0]))


class TADILER(SupervisedPlugin):
    """Plugin that replays a storage policy's buffer alongside each new experience.

    After every training experience the wrapped ``storage_policy`` is updated. Before
    the next one, the strategy's dataloader is replaced by a ReplayDataLoader over the
    current adapted dataset plus the buffer. Nothing is replaced while the buffer is empty.
    """

    def __init__(self, storage_policy):
        super().__init__()
        self.storage_policy = storage_policy

    def before_training_exp(self, strategy,
                            num_workers: int = 0, shuffle: bool = False,
                            **kwargs):
        """Install a ReplayDataLoader when the buffer holds any exemplars."""
        memory = self.storage_policy.buffer
        if len(memory) > 0:
            current_data = strategy.adapted_dataset
            mb_size = strategy.train_mb_size

            print(f"strategy.adapted_dataset: {len(current_data[0])}, Length: {len(current_data)}")
            print(f"self.storage_policy.buffer: {len(memory[0])}, Length: {len(memory)}")
            print(f"strategy.train_mb_size: {mb_size}")

            strategy.dataloader = ReplayDataLoader(
                current_data,
                memory,
                num_workers=num_workers,
                batch_size=mb_size,
                shuffle=shuffle,
            )

    def after_training_exp(self, strategy, **kwargs):
        """Refresh the storage policy with the experience just trained on."""
        self.storage_policy.update(strategy)


### Train/Test split

In [None]:
from torch.utils.data import Dataset, DataLoader
from random import shuffle, seed

class CustomDataset(Dataset):
    """Map-style torch Dataset over a DocList of MultiModal documents.

    Item ``k`` is the pair ``(embedding, label)`` of document ``k``; ``targets`` lists
    every document's label in order.
    """

    def __init__(self, docs: DocList[MultiModal], transform=None):
        self.docs = docs
        self.transform = transform
        self.targets = [d.label for d in docs]

    def __len__(self):
        """Number of documents in the dataset."""
        return len(self.docs)

    def __getitem__(self, idx):
        """Return ``(embedding, label)`` for the single document at ``idx``.

        The embedding is passed through ``transform`` when one was provided.
        """
        doc = self.docs[idx]
        features = doc.embedding if self.transform is None else self.transform(doc.embedding)
        return features, doc.label


In [None]:
from docarray import DocVec
# Exploratory check: display the DocVec type parametrised by the MultiModal schema.
DocVec[MultiModal]

In [None]:
from random import shuffle, seed

def split_datasets(docs: DocList[MultiModal], seed_val=123, train_percentage=0.8):
    """Shuffle ``docs`` reproducibly and split them into train/test CustomDatasets.

    Args:
        docs: Documents of one task.
        seed_val: Seed of the shuffle.
        train_percentage: Fraction of documents (floored) placed in the training split.

    Returns:
        ``(train_dataset, test_dataset)`` as CustomDataset instances.
    """
    import random

    # A private Random instance produces the same permutation as random.seed(seed_val)
    # followed by random.shuffle. It does not reset the process-wide RNG, and it does not
    # rely on the module-level name `seed`, which the experiment loop later rebinds to an
    # int (re-running this function afterwards used to raise TypeError).
    shuffled_docs = list(docs)
    random.Random(seed_val).shuffle(shuffled_docs)
    shuffled_docs = DocList[MultiModal](shuffled_docs)

    split_idx = int(train_percentage * len(shuffled_docs))
    train_docs = shuffled_docs[:split_idx]
    test_docs = shuffled_docs[split_idx:]

    train_dataset = CustomDataset(train_docs)
    test_dataset = CustomDataset(test_docs)

    return train_dataset, test_dataset


In [None]:
import torch
from torch.utils.data import random_split

# Per-task split settings: training fraction and shuffle seed.
train_size = 0.8
SEED = 42

# (train, test) CustomDataset pair for every task, in task order.
split_pairs = [
    split_datasets(doclist, seed_val=SEED, train_percentage=train_size)
    for doclist in list_doclists
]
train_sets_method = [pair[0] for pair in split_pairs]
test_sets_method = [pair[1] for pair in split_pairs]

print("#Method data")
print('Training lenghts: ', [len(ts) for ts in train_sets_method])
print('Testing lenghts: ', [len(ts) for ts in test_sets_method])


In [None]:
# Label distribution in the training split of the task at index 1.
Counter(train_sets_method[1].docs.label_description)

In [None]:
# %%time
from torch.utils.data.dataset import TensorDataset
from avalanche.benchmarks.utils import AvalancheDataset
from avalanche.benchmarks.utils import make_classification_dataset
from torch.utils.data import DataLoader
from avalanche.benchmarks.generators import filelist_benchmark, dataset_benchmark
from torch.utils.data import TensorDataset
    
# Method benchmark
training_datasets = list()
testing_datasets = list()

for task, (train_s, test_s) in enumerate(zip(train_sets_method, test_sets_method), start=0):

    training_datasets.append(make_classification_dataset(train_s, task_labels=train_s.docs.task))
    testing_datasets.append(make_classification_dataset(test_s, task_labels=test_s.docs.task))

    
benchmark_method = dataset_benchmark(
    train_datasets=training_datasets,
    test_datasets=testing_datasets,
    # other_streams_datasets={'metadata': make_classification_dataset(dataset=[e.id_code for e in train_s], targets=[e.label for e in train_s])}
    # other_streams_datasets={'multimodal_train': make_classification_dataset(dataset=[e.id_code for e in train_s], targets=[e.label for e in train_s])}
)

In [None]:
# np.unique(downsampled_data['DR_ICDR'].values)

In [None]:
# Inspect the task labels of the benchmark.
benchmark_method.task_labels

In [None]:
def get_sequence(sequence, stream):
    """Return the items of ``stream`` at the positions in ``sequence``, in that order.

    Positions may repeat (the same experience replayed twice).
    """
    ordered = []
    for position in sequence:
        ordered.append(stream[position])
    return ordered
# Order in which task experiences are presented (here: each task once, in order).
sequence = [0, 1]
stream_seq = get_sequence(sequence, benchmark_method.train_stream)
stream_seq

In [None]:
# torch.all(stream_seq[0].dataset[0][0] == stream_seq[1].dataset[0][0])

## Prepare the strategies to run

In [None]:
def get_strategy(strategy_name, model, optimizer, criterion, eval_plugin, n_epochs, custom_replay):
    """Build the avalanche strategy named ``strategy_name``.

    Args:
        strategy_name: One of 'Naive', 'EWC', 'Replay', 'LwF', 'GEM'.
        model, optimizer, criterion: Passed through to the strategy.
        eval_plugin: EvaluationPlugin used as the strategy's evaluator.
        n_epochs: Training epochs per experience.
        custom_replay: Extra plugin (the TADILER replay wrapper) attached to the strategy.

    Returns:
        The constructed strategy. Reads the module-level ``device``.

    Raises:
        KeyError: if ``strategy_name`` is not one of the names above.
    """

    def common_kwargs():
        # Shared settings; a fresh plugins list per construction, since a strategy
        # may extend the list it is given.
        return dict(
            model=model,
            optimizer=optimizer,
            criterion=criterion,
            train_mb_size=200,
            train_epochs=n_epochs,
            eval_mb_size=200,
            device=device,
            evaluator=eval_plugin,
            plugins=[custom_replay],
        )

    # Only the requested strategy is constructed. Building all five eagerly attached
    # the same model, optimizer, evaluator and plugin to four throw-away strategies
    # on every call.
    builders = {
        # Baseline: plain fine-tuning with no forgetting mitigation.
        'Naive': lambda: Naive(**common_kwargs()),
        # Regularization-based.
        'EWC': lambda: EWC(ewc_lambda=0.2, **common_kwargs()),
        # Rehearsal-based (memory of past samples).
        'Replay': lambda: Replay(mem_size=20, **common_kwargs()),
        # Distillation / regularization-based.
        'LwF': lambda: LwF(alpha=0.5, temperature=0.2, **common_kwargs()),
        # Rehearsal-based gradient constraints from stored patterns.
        'GEM': lambda: GEM(patterns_per_exp=10, **common_kwargs()),
    }

    return builders[strategy_name]()

## Prepare and run the experiments

In [None]:
# Experiment configuration. Commented lines are alternative settings kept for reference.
# NOTE(review): root_exp_name is only defined in comments here.
# root_exp_name = 'Agnostic_no_repetition_improver'
# root_exp_name = 'Agnostic_with_repetition'
# sequence = [0, 1, 2, 1, 3, 3, 4, 5]
# Present every task experience once, in benchmark order.
sequence = list(range(len(benchmark_method.train_stream)))
print(f"Sequence = {sequence}")
# Training epochs per experience.
n_epochs = 6

# Exemplar budgets to sweep (per zero-shot class for TADILER).
# n_neighbors = 50
# list_n_neighbors = [1, 5, 10, 15]
# list_n_neighbors = [15, 20]
list_n_neighbors = [15, 20, 25, 30, 50]
# list_n_neighbors = [25, 30, 50]

# One experiment (repetition) per seed.
# list_n_neighbors = [1]
# random_seeds = [23, 24, 25]
# random_seeds = [34, 88, 100]
# random_seeds = [3]
random_seeds = [3, 11, 51]

num_experiments = len(random_seeds)
print(f"n_exps = {num_experiments}")
# print(f"Setup = {root_exp_name}")
print(f"Epochs = {n_epochs}")
print(f"Seeds = {random_seeds}")
print(f"list n_neighbors = {list_n_neighbors}")

In [None]:
# %%time
import torch
import copy
import random
# import wandb
import gc
from torch.nn import CrossEntropyLoss
import copy
import statistics as st
from torch.optim import Adam
from rich import print
from avalanche.models import as_multitask
from avalanche.benchmarks.classic import SplitMNIST, PermutedMNIST
from avalanche.models import MTSimpleMLP, MTSimpleCNN, SimpleCNN, SimpleMLP
# from avalanche.training.supervised import EWC, Naive, LFL
from avalanche.evaluation.metrics import (
    forgetting_metrics,
    accuracy_metrics,
    confusion_matrix_metrics
)
from avalanche.training.plugins import EvaluationPlugin
from avalanche.logging import InteractiveLogger, TextLogger, TensorboardLogger, WandBLogger
# from avalanche.training import Naive
from torch.nn import CrossEntropyLoss
from torch.optim import SGD
from avalanche.benchmarks.classic import SplitMNIST, PermutedMNIST, RotatedMNIST, SplitCIFAR10, SplitCIFAR100
from avalanche.benchmarks.generators import filelist_benchmark, dataset_benchmark
from avalanche.training.utils import adapt_classification_layer

from datetime import datetime
from avalanche.evaluation.metrics.images_samples import images_samples_metrics
from avalanche.evaluation.metrics.labels_repartition import (
    labels_repartition_metrics,
)
from avalanche.evaluation.metrics.mean_scores import mean_scores_metrics
from torch.nn import Linear
from avalanche.training.supervised import Replay, GEM, LwF, EWC, Naive

from torch.nn import Sequential, Linear, ReLU

import os

# Pickled results are written under ./results; create the directory once so the
# first dump cannot fail with FileNotFoundError.
os.makedirs('./results', exist_ok=True)

# Sweep: seed (one per experiment) x n_neighbors x strategy x replay-buffer selection.
for exp in range(num_experiments):

    for n_neighbors in list_n_neighbors:
        print(f"Length training: {len(benchmark_method.train_stream[0].dataset)}")
        print(f"Length testing: {len(benchmark_method.test_stream[0].dataset)}")

        print("-"*50)

        n_workers = 8
        # Rebinds the module-level `device` that get_strategy reads.
        device = 'cpu'
        # Rebinds the module-level `seed` (previously bound to random.seed by an import);
        # code reading the global `seed` while training sees this int.
        seed = random_seeds[exp]
        random.seed(seed)
        torch.manual_seed(seed)
        print(f"Experiment: {exp+1} / {num_experiments}")
        print(f"Seed = {seed}")
        print(f"n_epochs = {n_epochs}")
        print(f"n_neighbors = {n_neighbors}")
        n_classes = len(np.unique(dl.label))
        embedding_shape = dl[0].embedding.shape[0]
        all_strategies = ['Naive', 'GEM', 'LwF', 'EWC']

        for index_strat, strategy_name in enumerate(all_strategies):

            # Fresh replay buffers per strategy. The random baseline draws
            # n_classes * n_neighbors exemplars per experience; TADILER keeps up to
            # n_neighbors per zero-shot class.
            selection_strategy_random = TADILER(RandomExemplarsBuffer(max_size=2000, groupby=None,
                                                               n_neighbors=n_classes*n_neighbors))
            selection_strategy_tadiler = TADILER(ZeroshotExemplarsBuffer(max_size=2000, groupby=None,
                                                            n_neighbors=n_neighbors))

            dict_selection_strategy = {'Random': selection_strategy_random,
                                       'TADILER': selection_strategy_tadiler}

            for index_select, (name_selection, selection_strategy) in enumerate(dict_selection_strategy.items()):
                benchmark_method_train_stream = get_sequence(sequence, benchmark_method.train_stream)
                benchmark_method_test_stream = get_sequence(sequence, benchmark_method.test_stream)

                # Console logging only (a WandBLogger can be appended here).
                loggers = [InteractiveLogger()]

                eval_plugin = EvaluationPlugin(
                    accuracy_metrics(
                        minibatch=False, epoch=False, experience=True
                    ),
                    forgetting_metrics(experience=True),
                    confusion_matrix_metrics(save_image=False, normalize=None, stream=True),
                    loggers=loggers,
                )

                n_inputs = embedding_shape  # input size after CLIP embedding
                n_outputs = n_classes
                hidden_layer_size = 256

                # Two-layer MLP on top of the precomputed embeddings.
                model = Sequential(
                    Linear(n_inputs, hidden_layer_size),
                    ReLU(),
                    Linear(hidden_layer_size, n_outputs)
                )

                # Make the last layer (name '2' in the Sequential) multi-task.
                mt_model = as_multitask(model, '2')

                optimizer = Adam(mt_model.parameters(), lr=0.001)
                criterion = CrossEntropyLoss()

                strategy = get_strategy(strategy_name, mt_model, optimizer, criterion,
                                        eval_plugin, n_epochs, selection_strategy)

                print(f"Running strategy: {strategy_name}-{name_selection}, experiment: {exp}, n_neighbors={n_neighbors}")
                results = {key: [] for key in ['values']}

                print(f"Sequence: {sequence}")
                for index, (task, experience) in enumerate(zip(sequence, benchmark_method_train_stream)):

                    curr_experience = experience.current_experience
                    print("Current experience: ", curr_experience)
                    print("Experience task ID ", experience.task_label)
                    print('Experience shape:', len(experience.dataset))

                    print("Training multi-head model")
                    strategy.train(experience, num_workers=n_workers)
                    print('Training multi-head model completed')

                    print(f"Evaluation benchmark: {strategy_name}-{name_selection}")

                    # Evaluate on the test streams of every experience seen so far.
                    values = strategy.eval(benchmark_method_test_stream[:index+1], num_workers=n_workers)

                    # NOTE(review): on the first experience the TADILER run reports the Random
                    # run's evaluation instead of its own -- presumably because no replay has
                    # happened yet (TADILER.before_training_exp skips an empty buffer). The two
                    # runs still train separately initialised models; confirm this is intended.
                    if index == 0 and index_select == 0:
                        values_initial = values
                    elif index == 0 and index_select == 1:
                        values = values_initial

                    results['values'].append(values)

                    print("@"*60)
                    print()

                result_path = f'./results/{strategy_name}-{name_selection}_knn_{n_neighbors}_seed_{seed}.pickle'
                with open(result_path, 'wb') as handle:
                    pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)
                # root_exp_name is not defined in this notebook (its assignment is commented
                # out), so it must not appear in this message.
                print(f"Saved results for {strategy_name}_{name_selection}, seed: {seed}")

                print("&"*100)
                print()
        print("#"*100)


In [None]:
# Recap of the sweep that was just run.
print(f"List of n_neighbors: {list_n_neighbors}")
print(f"Seeds: {random_seeds}")
print(f"Strategies: {all_strategies}")

In [None]:
# It takes 40min to run all the experiments

In [None]:
import glob
import pickle
import matplotlib.pyplot as plt
import numpy as np

from rich import print

# Choose which result pickles to inspect: strategy-name prefix, n_neighbors, seed.
# Set n_neighbors = "*" to match every value; other seed values noted here: 51, 42, 7.
global_avg_acc = {}  # Final Table
s_name, n_neighbors, seed = "Naive", 25, "3"
files = glob.glob(f"./results/{s_name}*_*knn_{n_neighbors}_seed_{seed}.pickle")
print(files)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import pickle


def amca_from_confusion_matrix(confusion_matrix):
    """Return the Average Mean Class Accuracy (AMCA) of a confusion matrix.

    AMCA is the unweighted mean of per-class accuracies, where the accuracy of
    class ``i`` is ``cm[i, i] / cm[i, :].sum()``. Rows are assumed to be the
    true classes (the row-sum normalisation only yields per-class accuracy
    under that convention).

    Args:
        confusion_matrix: Square confusion matrix given as a torch tensor (any
            device, with or without grad), a NumPy array, or a nested list.

    Returns:
        The mean per-class accuracy. Classes with no samples (all-zero rows)
        are ignored; if every row is empty the result is NaN.
    """
    # Tensors may live on the GPU or track gradients; bring them to host memory
    # first. Anything else is treated as an array-like.
    if hasattr(confusion_matrix, "detach"):
        confusion_matrix = confusion_matrix.detach().cpu().numpy()
    cm = np.asarray(confusion_matrix, dtype=float)

    support = cm.sum(axis=1)
    # Empty rows give 0/0 = NaN, which nanmean skips; silence the expected warning.
    with np.errstate(divide="ignore", invalid="ignore"):
        class_accuracies = np.diagonal(cm) / support

    # nanmean of an all-NaN array warns ("Mean of empty slice"); answer directly.
    if np.all(np.isnan(class_accuracies)):
        return np.nan
    return np.nanmean(class_accuracies)

def weighted_f1_from_confusion_matrix(confusion_matrix):
    """Return the support-weighted F1 score computed from a confusion matrix.

    Rows are taken as true classes and columns as predicted classes. Per class:
    precision = TP / column total, recall = TP / row total, and
    F1 = 2 * P * R / (P + R). Each class F1 is weighted by its share of all
    samples (row total / grand total). Classes whose F1 is undefined (0/0)
    contribute nothing, because the weighted terms are summed with ``np.nansum``.

    Args:
        confusion_matrix: torch tensor (anything exposing ``.numpy()``) or an
            array-like accepted by NumPy.

    Returns:
        The weighted F1 score as a NumPy scalar.
    """
    cm = confusion_matrix.numpy() if hasattr(confusion_matrix, 'numpy') else confusion_matrix

    true_pos = np.diagonal(cm)
    predicted_totals = np.sum(cm, axis=0)  # TP + FP per class
    actual_totals = np.sum(cm, axis=1)     # TP + FN per class (class support)

    prec = true_pos / predicted_totals
    rec = true_pos / actual_totals
    per_class_f1 = 2 * (prec * rec) / (prec + rec)

    support_share = actual_totals / np.sum(cm)
    return np.nansum(per_class_f1 * support_share)

# Per-task AMCA curves for every result file in `files`, aggregated over seeds.
# data_dict[method][seed] -> list of AMCA values, one per evaluation step.
data_dict = {}
knn_labels = set()  # every n_neighbors value seen, for the plot title

# glob order is arbitrary; sort so methods are plotted (and coloured) consistently.
for f in sorted(files):
    # Filenames follow "<method>_knn_<n_neighbors>_seed_<seed>.pickle"; this
    # assumes <method> itself contains no underscore.
    method_seed = f.split('_')[-5:]
    method = method_seed[0]
    seed = method_seed[4].split('.')[0]
    knn = method_seed[2]
    knn_labels.add(knn)

    # Only unpickle result files produced locally by the training loop.
    with open(f, 'rb') as handle:
        r = pickle.load(handle)

    # r['values'] holds one evaluation dict per training experience.
    performances = [
        amca_from_confusion_matrix(step['ConfusionMatrix_Stream/eval_phase/test_stream'])
        for step in r['values']
    ]
    data_dict.setdefault(method, {})[seed] = performances

if not data_dict:
    # Without this guard the title/xticks below would fail on undefined names.
    print("No result files matched; nothing to plot.")
else:
    plt.figure(figsize=(10, 6))
    n_tasks = 0

    for method, seeds in data_dict.items():
        # Rows are seeds, columns are tasks (assumes every seed has the same
        # number of evaluation steps).
        all_performances = np.array(list(seeds.values()))

        mean_performances = np.mean(all_performances, axis=0)
        print(f'Method: {method} | Values: {mean_performances}')
        min_performances = np.min(all_performances, axis=0)
        max_performances = np.max(all_performances, axis=0)
        overall_avg_performance = np.mean(mean_performances)

        tasks = range(1, len(mean_performances) + 1)
        n_tasks = max(n_tasks, len(mean_performances))
        label = f'{method} (avg: {overall_avg_performance:.3f})'

        # Highlight the proposed method; match case-insensitively so both the
        # "Tadiler" and "TADILER" spellings are highlighted.
        if 'tadiler' in method.lower():
            (line,) = plt.plot(tasks, mean_performances, marker='*', linestyle='-',
                               color='#8B0000', label=label, linewidth=2)
        else:
            (line,) = plt.plot(tasks, mean_performances, marker='o', linestyle='-',
                               label=label)
        # Shade the min-max band across seeds in the line's own colour (lines and
        # fills otherwise draw from separate colour cycles).
        plt.fill_between(tasks, min_performances, max_performances,
                         color=line.get_color(), alpha=0.2)

    knn_title = ", ".join(sorted(knn_labels, key=int))
    plt.xlabel('Task Number')
    plt.ylabel('AMCA')
    plt.title(f'AMCA over Tasks, n_samples={knn_title}')
    plt.grid(True)
    plt.xticks(range(1, n_tasks + 1))
    plt.legend()

    plt.show()


In [None]:
global_avg_acc = dict() # Final Table
all_strategies = ['Naive', 'GEM', 'LwF', 'EWC']
# Seeds whose runs are aggregated. They are matched against the exact
# "_seed_<s>.pickle" suffix: a substring test such as `"seed_3" in f` would
# also pick up seed_30..39 (and "seed_11" would match seed_110...).
selected_seeds = ("3", "11")
seed_suffixes = tuple(f"_seed_{s}.pickle" for s in selected_seeds)

for strategy_name in all_strategies:
    n_neighbors = "*"
    seed = "*"
    files = glob.glob(f"./results/{strategy_name}*_*knn_{n_neighbors}_seed_{seed}.pickle")
    # Sorted for a deterministic plotting (and colour) order.
    files = sorted(f for f in files if f.endswith(seed_suffixes))

    # data_dict[method][knn] -> list of run-level AMCA values (one per seed).
    data_dict = {}

    for f in files:
        # "<method>_knn_<n_neighbors>_seed_<seed>.pickle"; assumes <method>
        # itself contains no underscore.
        method_seed = f.split('_')[-5:]
        method = method_seed[0]
        seed = method_seed[4].split('.')[0]
        knn = int(method_seed[2])

        # Only unpickle result files produced locally by the training loop.
        with open(f, 'rb') as handle:
            r = pickle.load(handle)

        # One AMCA per evaluation step; swap in weighted_f1_from_confusion_matrix
        # here to plot weighted F1 instead.
        task_amcas = [
            amca_from_confusion_matrix(d['ConfusionMatrix_Stream/eval_phase/test_stream'])
            for d in r['values']
        ]
        overall_amca = np.mean(task_amcas)  # Average of AMCAs for all tasks
        data_dict.setdefault(method, {}).setdefault(knn, []).append(overall_amca)

    if not data_dict:
        # Nothing to draw; also avoids plt.xticks on an undefined `knns`.
        print(f"No result files found for strategy {strategy_name}; skipping plot.")
        continue

    plt.figure(figsize=(10, 6))
    plotted_knns = set()

    for method, knn_values in data_dict.items():
        method = method.replace("Random", "Default")
        knns = sorted(knn_values)
        plotted_knns.update(knns)
        means = [np.mean(knn_values[k]) for k in knns]
        mins = [np.min(knn_values[k]) for k in knns]
        maxs = [np.max(knn_values[k]) for k in knns]

        overall_avg_amca = np.mean(means)  # Average AMCA over all n_neighbors
        label = f'{method} (avg: {overall_avg_amca:.3f})'

        print(method)
        # Highlight the proposed method; match case-insensitively so both the
        # "Tadiler" and "TADILER" spellings are highlighted.
        if 'tadiler' in method.lower():
            (line,) = plt.plot(knns, means, marker='*', linestyle='-', color='#8B0000',
                               label=label, linewidth=2)
        else:
            (line,) = plt.plot(knns, means, marker='o', linestyle='-', label=label)
        # Shade the min-max band across seeds in the line's own colour.
        plt.fill_between(knns, mins, maxs, color=line.get_color(), alpha=0.2)

    plt.xlabel('n_samples')
    plt.ylabel('Average AMCA')
    plt.title(f'Dataset: Diabetes Retinopathy | Strategy: {strategy_name} ')
    plt.grid(True)
    # Ticks for every n_samples value of every method, not just the last one.
    plt.xticks(sorted(plotted_knns))
    plt.legend()
    plt.savefig(f'tadiler_diabetes_{strategy_name}.pdf', bbox_inches='tight', pad_inches=0)
    plt.show()
