# **Deep Learning - Test-Time Adaptation**
---
**University of Trento, Academic Year 2023/2024**

---
#### **Group 26**
> <a href="https://github.com/giuseppecurci">Giuseppe Curci</a> \
> 243049

> <a href="https://github.com/andy295">Andrea Cristiano</a> \
> 229370
---
---

In [None]:
# to replace with a requirements file

!pip install ollama # if ollama is not available, install by executing
                    # the install_and_run_ollama.sh script
!pip install diffusers
!pip install bing_image_downloader
!pip install git+https://github.com/openai/CLIP.git

In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import torchvision.transforms as T
import torchvision.models as models
from scipy import stats
import math

from typing import List, Union, Dict
from PIL import Image
from tqdm import tqdm
from io import BytesIO
from pathlib import Path
import boto3 # read and write for AWS buckets
import json
import os
import random
import time

from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, DPMSolverMultistepScheduler
import ollama
import clip
from bing_image_downloader import downloader

import warnings
warnings.filterwarnings('ignore')

## **Introduction**

### **Domain Shift**
Domain shift refers to the change in the data distribution between the training phase and the testing phase of a machine learning model. In other words, the data that the model encounters during deployment (inference) differs from the data it was trained on.

Machine learning models are typically trained on a specific dataset. The fundamental assumption is that the training data distribution is as representative as possible of what the models will encounter during deployment. In the presence of domain shift, the model may not generalize well, leading to a degradation in performance.

In real-world applications, it is easy to identify cases where data distributions may vary, and the causes can be numerous, such as changes in the environment, noise in the sensors used to acquire information, etc.

A model trained on images depicting a scene during the spring season may perform poorly when tested on images of the same scene but taken during the winter season, even if the scene's content is the same.

### **Test-Time Adaptation (TTA)**
Test-Time Adaptation refers to techniques that allow a previously trained model to adapt to a new data distribution during inference (testing), without the need for a new training phase.

The goal is to improve the model's performance on the shifted domain by adapting its parameters or predictions based on the data encountered during testing.

## **Pipeline**

Before explaining the methods implemented, the following picture provides an overview of how predictions are made by our models.

<p align="center">
  <img src="images/prediction_pipeline_colored.png" width="800" height="400">  
</p>

1. **Image Classification and Generation**: Given a sample image for classification, the `top_j` images, previously generated using a diffusion model and an LLM, are retrieved based on cosine similarity between `CLIP` image embeddings. Details of this generation process are elaborated in subsequent sections.

2. **Augmentation and Confidence Filtering**: `k` augmentations of the original image are generated, and the corresponding probability tensors are computed. An entropy-based confidence selection filter is then applied to identify the `top_aug` augmentations.

3. **Marginal Output Distribution**: The probabilities of the generated images and the augmentations are combined, and the marginal output distribution is computed (refer to MEMO for more details).

4. **Model Update**: The model is updated by minimizing the entropy of the marginal output distribution tensor. Multiple updates may occur, but the augmentations remain unchanged.

5. **Final Prediction**: The updated model computes new probabilities on the previous augmentations (TTA). These are filtered again using the confidence selection mechanism, and the final prediction is obtained by applying `softmax` and `argmax` on the marginal output distribution. Note: Generated images are used only until the update step, with TTA performed exclusively on the augmentations.
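
The entropy-based confidence filter used in steps 2 and 5 can be sketched as follows. This is a minimal, vectorized illustration and not the notebook's own implementation: the name `select_confident` and the toy probabilities are assumptions made for the example.

```python
import torch

def select_confident(probabilities, top_aug):
    """Keep the top_aug rows of a [num_augmentations, classes] tensor with the lowest entropy."""
    # Entropy of each row; the clamp avoids log(0)
    entropies = -(probabilities * torch.log(probabilities.clamp_min(1e-10))).sum(dim=-1)
    # largest=False selects the smallest entropies, returned in ascending order
    indices = torch.topk(entropies, top_aug, largest=False).indices
    return probabilities[indices]

# Toy example (hypothetical values): three augmentations, three classes
toy_probabilities = torch.tensor([
    [0.98, 0.01, 0.01],  # very confident
    [0.34, 0.33, 0.33],  # almost uniform, discarded
    [0.20, 0.70, 0.10],  # moderately confident
])
confident = select_confident(toy_probabilities, top_aug=2)
```

The near-uniform row has the highest entropy, so only the first and third augmentations are kept.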


## **Utils**



### **ImageNet-A masking**

The **imagenetA_masking.json** file maps each index of the standard 1000-class ImageNet output, as used by the pre-trained models in PyTorch's torchvision library, to the corresponding ImageNet-A class index.

Each key in the file corresponds to an index in the standard 1000-class ImageNet output vector. The value associated with each key indicates whether that index should be kept when reducing the 1000-class output to the smaller set of ImageNet-A classes.

A value of -1 indicates that the corresponding index in the 1000-class output should be ignored in the subset of outputs for ImageNet-A.
A non-negative integer value indicates that the corresponding index in the 1000-class output should be included in the subset of outputs for ImageNet-A, and gives its position in that subset.
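
As an illustration of how such a mask can be applied, the sketch below turns a masking dictionary into a list of kept indices and uses it to slice a logits tensor. The helper `mask_to_indices` and the five-entry `toy_masking` are hypothetical and only stand in for the full 1000-entry mapping.

```python
import torch

def mask_to_indices(masking):
    """Return the kept output indices, ordered by their ImageNet-A position."""
    kept = [(int(key), value) for key, value in masking.items() if value >= 0]
    kept.sort(key=lambda pair: pair[1])  # order by the ImageNet-A index
    return torch.tensor([key for key, _ in kept], dtype=torch.long)

# Toy mask (hypothetical): 5 original classes, 3 of which are kept
toy_masking = {"0": -1, "1": 0, "2": -1, "3": 1, "4": 2}
indices = mask_to_indices(toy_masking)

# Toy logits of shape [batch, original_classes]
logits = torch.arange(5, dtype=torch.float32).unsqueeze(0)
subset_logits = logits[:, indices]  # shape [batch, kept_classes]
```

With the real mask, the same slicing would reduce a `[batch, 1000]` output to `[batch, 200]`.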

In [None]:
imagenetA_masking = {
    "0": -1,
    "1": -1,
    "2": -1,
    "3": -1,
    "4": -1,
    "5": -1,
    "6": 0,
    "7": -1,
    "8": -1,
    "9": -1,
    "10": -1,
    "11": 1,
    "12": -1,
    "13": 2,
    "14": -1,
    "15": 3,
    "16": -1,
    "17": 4,
    "18": -1,
    "19": -1,
    "20": -1,
    "21": -1,
    "22": 5,
    "23": 6,
    "24": -1,
    "25": -1,
    "26": -1,
    "27": 7,
    "28": -1,
    "29": -1,
    "30": 8,
    "31": -1,
    "32": -1,
    "33": -1,
    "34": -1,
    "35": -1,
    "36": -1,
    "37": 9,
    "38": -1,
    "39": 10,
    "40": -1,
    "41": -1,
    "42": 11,
    "43": -1,
    "44": -1,
    "45": -1,
    "46": -1,
    "47": 12,
    "48": -1,
    "49": -1,
    "50": 13,
    "51": -1,
    "52": -1,
    "53": -1,
    "54": -1,
    "55": -1,
    "56": -1,
    "57": 14,
    "58": -1,
    "59": -1,
    "60": -1,
    "61": -1,
    "62": -1,
    "63": -1,
    "64": -1,
    "65": -1,
    "66": -1,
    "67": -1,
    "68": -1,
    "69": -1,
    "70": 15,
    "71": 16,
    "72": -1,
    "73": -1,
    "872": -1,
    "873": -1,
    "874": -1,
    "875": -1,
    "876": -1,
    "877": -1,
    "878": -1,
    "879": 171,
    "880": 172,
    "881": -1,
    "882": -1,
    "883": -1,
    "884": -1,
    "885": -1,
    "886": -1,
    "887": -1,
    "888": 173,
    "889": -1,
    "890": 174,
    "891": -1,
    "892": -1,
    "893": -1,
    "894": -1,
    "895": -1,
    "896": -1,
    "897": 175,
    "898": -1,
    "899": -1,
    "900": 176,
    "901": -1,
    "902": -1,
    "903": -1,
    "904": -1,
    "905": -1,
    "906": -1,
    "907": 177,
    "908": -1,
    "909": -1,
    "910": -1,
    "911": -1,
    "912": -1,
    "913": 178,
    "914": -1,
    "915": -1,
    "916": -1,
    "917": -1,
    "918": -1,
    "919": -1,
    "920": -1,
    "921": -1,
    "922": -1,
    "923": -1,
    "924": 179,
    "925": -1,
    "926": -1,
    "927": -1,
    "928": -1,
    "929": -1,
    "930": -1,
    "931": -1,
    "932": 180,
    "933": 181,
    "934": 182,
    "935": -1,
    "936": -1,
    "937": 183,
    "938": -1,
    "939": -1,
    "940": -1,
    "941": -1,
    "942": -1,
    "943": 184,
    "944": -1,
    "945": 185,
    "946": -1,
    "947": 186,
    "948": -1,
    "949": -1,
    "950": -1,
    "951": 187,
    "952": -1,
    "953": -1,
    "954": 188,
    "955": -1,
    "956": 189,
    "957": 190,
    "958": -1,
    "959": 191,
    "960": -1,
    "961": -1,
    "962": -1,
    "963": -1,
    "964": -1,
    "965": -1,
    "966": -1,
    "967": -1,
    "968": -1,
    "969": -1,
    "970": -1,
    "971": 192,
    "972": 193,
    "973": -1,
    "974": -1,
    "975": -1,
    "976": -1,
    "977": -1,
    "978": -1,
    "979": -1,
    "980": 194,
    "981": 195,
    "982": -1,
    "983": -1,
    "984": 196,
    "985": -1,
    "986": 197,
    "987": 198,
    "988": 199,
    "989": -1,
    "990": -1,
    "991": -1,
    "992": -1,
    "993": -1,
    "994": -1,
    "995": -1,
    "996": -1,
    "997": -1,
    "998": -1,
    "999": -1
}

### **Recover dataset**

In [None]:
# Class that interacts with images stored in an Amazon S3 bucket.
# It loads and preprocesses images on the fly during training or inference.
class S3ImageFolder(Dataset):
    def __init__(self, root, transform=None):
        self.s3_bucket = "deeplearning2024-datasets" # name of the bucket
        self.s3_region = "eu-west-1" # Ireland
        self.s3_client = boto3.client("s3", region_name=self.s3_region, verify=True)
        self.transform = transform

        # Get list of objects in the bucket
        response = self.s3_client.list_objects_v2(Bucket=self.s3_bucket, Prefix=root)
        objects = response.get("Contents", [])
        while response.get("NextContinuationToken"):
            response = self.s3_client.list_objects_v2(
                Bucket=self.s3_bucket,
                Prefix=root,
                ContinuationToken=response["NextContinuationToken"]
            )
            objects.extend(response.get("Contents", []))

        # Iterate and keep valid files only
        self.instances = []
        for ds_idx, item in enumerate(objects):
            key = item["Key"]
            path = Path(key)

            # Check if file is valid
            if path.suffix.lower() not in (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"):
                continue

            # Get label
            label = path.parent.name

            # Keep track of valid instances
            self.instances.append((label, key))

        # Sort classes in alphabetical order (as in ImageFolder)
        self.classes = sorted(set(label for label, _ in self.instances))
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx):
        try:
            label, key = self.instances[idx]

            # Download the image from S3 into an in-memory buffer
            img_bytes = BytesIO()
            self.s3_client.download_fileobj(Bucket=self.s3_bucket, Key=key, Fileobj=img_bytes)
            img_bytes.seek(0)  # rewind the buffer before PIL reads it

            # Open image with PIL
            img = Image.open(img_bytes).convert("RGB")

            # Apply transformations if any
            if self.transform is not None:
                img = self.transform(img)
        except Exception as e:
            raise RuntimeError(f"Error loading image at index {idx}: {str(e)}")

        return img, self.class_to_idx[label]

# Function to create DataLoaders for training and evaluating models.
# Loads the dataset from the S3 bucket and optionally splits it into training,
# validation, and test sets. It then returns PyTorch DataLoader objects for these datasets.
def get_data(batch_size, img_root, seed = None, split_data = False, transform = None):

    # Load data
    data = S3ImageFolder(root=img_root, transform=transform)

    if split_data:
        # Create train, validation and test splits (roughly 80/10/10)
        num_samples = len(data)
        training_samples = int(num_samples * 0.8 + 1)
        val_samples = int(num_samples * 0.1)
        test_samples = num_samples - training_samples - val_samples

        # Seed the split only when a seed is given (torch.manual_seed rejects None)
        if seed is not None:
            torch.manual_seed(seed)
        training_data, val_data, test_data = torch.utils.data.random_split(data, [training_samples, val_samples, test_samples])

        # Initialize dataloaders
        train_loader = torch.utils.data.DataLoader(training_data, batch_size, shuffle=True, num_workers=4)
        val_loader = torch.utils.data.DataLoader(val_data, batch_size, shuffle=False, num_workers=4)
        test_loader = torch.utils.data.DataLoader(test_data, batch_size, shuffle=False, num_workers=4)

        return train_loader, val_loader, test_loader

    data_loader = torch.utils.data.DataLoader(data, batch_size, shuffle=False, num_workers=4)
    return data_loader
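
The split sizes computed in `get_data` can be checked without S3 access. The sketch below reproduces the same arithmetic on a toy `TensorDataset`; the dataset and the use of an explicit `torch.Generator` for seeding are assumptions made for the example.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy dataset (hypothetical): 100 samples with dummy labels
toy_data = TensorDataset(
    torch.arange(100, dtype=torch.float32).unsqueeze(1),
    torch.zeros(100, dtype=torch.long),
)

# Same size arithmetic as in get_data
num_samples = len(toy_data)
training_samples = int(num_samples * 0.8 + 1)
val_samples = int(num_samples * 0.1)
test_samples = num_samples - training_samples - val_samples

# A dedicated generator keeps the split reproducible without touching the global seed
generator = torch.Generator().manual_seed(0)
train_split, val_split, test_split = random_split(
    toy_data, [training_samples, val_samples, test_samples], generator=generator
)
```

For 100 samples this yields 81 training, 10 validation and 9 test samples.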

## **Test-time adaptation methods**

During the testing phase, Test-Time Adaptation (TTA) methods allow for modifications to be made to the model. This enables the model to adapt to new data distributions, even if it has not encountered them before, thus maintaining a certain level of reliability.

Below, we outline the techniques used to enhance the model's performance, providing a brief introduction to each method followed by its implementation.

### **MEMO: Test Time Robustness via Adaptation and Augmentation<sup>[1]</sup>**

In this paper, the authors propose a method called MEMO (Marginal Entropy Minimization with One Test Point) designed to address the problem of robustness in deep neural networks when confronted with distribution shifts or unexpected perturbations. To achieve this, the method employs both adaptation and augmentation strategies.

<p align="center">
  <img src="images/MEMO.png" width="600" height="300">  
</p>

MEMO applies data augmentations to a single test input to generate various versions of the input, and from these, it calculates the marginal output distribution. The model's parameters are then updated to minimize the entropy of this marginal distribution. Finally, the model uses these updated parameters to make a prediction on the original test input.

<br>

---
**Algorithm 1** Test time robustness via MEMO

---

**Require:** trained model f<sub>θ</sub>, test point $x$, number of augmentations $B$, learning rate $η$, update rule $G$

1. Sample a<sub>1</sub>...a<sub>B</sub> $\overset{\text{i.i.d.}}{\sim}$ $\mathcal{U}$ ($\mathcal{A}$) and produce augmented points $\tilde{\mathbf{x}}_i = a_i(\mathbf{x})$ for $i \in \{1, \ldots, B\}$
2. Compute estimate $\tilde{p} = \frac{1}{B} \sum_{i=1}^B p_\theta(y|\tilde{\mathbf{x}}_i) \approx \bar{p}_\theta(y|\mathbf{x})$ and $\tilde{\ell} = H(\tilde{p}) \approx \ell(\theta; \mathbf{x})$
3. Adapt parameters via update rule $\theta' \leftarrow G(\theta, \eta, \tilde{\ell})$
4. Predict $\hat{y} \triangleq \arg \max_y p_{\theta'}(y|\mathbf{x})$
---
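
Steps 2 and 3 of the algorithm can be sketched on a toy model as follows. This is an illustrative sketch under stated assumptions, not the notebook's implementation: `marginal_entropy`, `memo_step`, the linear `toy_model`, and the random stand-in "augmentations" are hypothetical.

```python
import torch
import torch.nn as nn

def marginal_entropy(logits):
    """Entropy of the average softmax distribution over a batch of augmentations."""
    probabilities = logits.softmax(dim=-1)  # [B, classes]
    marginal = probabilities.mean(dim=0)    # [classes]
    return -(marginal * torch.log(marginal.clamp_min(1e-10))).sum()

def memo_step(model, augmented_batch, optimizer):
    """One parameter update that minimizes the marginal entropy."""
    optimizer.zero_grad()
    loss = marginal_entropy(model(augmented_batch))
    loss.backward()
    optimizer.step()
    return loss.item()

torch.manual_seed(0)
toy_model = nn.Linear(8, 4)             # stand-in for the classifier
toy_augmentations = torch.randn(16, 8)  # stand-in for B = 16 augmented views
optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)
loss_before_update = memo_step(toy_model, toy_augmentations, optimizer)
```

The loss is bounded by `log(num_classes)`, which is reached when the marginal distribution is uniform.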

In [None]:
def compute_entropy(probabilities):
    """
    Takes a tensor of class probabilities that sums to 1 and returns its entropy as a scalar (zero-dimensional) tensor.
    """
    # Ensure probabilities are normalized (sum to 1)
    if not torch.isclose(probabilities.sum(), torch.tensor(1.0)):
        raise ValueError("The probabilities should sum to 1.")

    # Compute entropy
    # Adding a small value to avoid log(0) issues
    epsilon = 1e-10
    probabilities = torch.clamp(probabilities, min=epsilon)
    entropy = -torch.sum(probabilities * torch.log(probabilities))

    return entropy

def get_best_augmentations(probabilities, top_k):
    """
    Takes a tensor of probabilities with dimension [num_augmentations, classes] or [mc_models, num_augmentations, classes]
    and outputs a tensor containing the probabilities corresponding to the augmentations
    with the lowest entropy of dimension [top_k, classes] or [mc_models, top_k, classes].
    ----------
    top_k: number of augmentations to select
    probabilities: a tensor of dimension [num_augmentations, classes] or [mc_models, num_augmentations, classes]
    """
    if probabilities.dim() == 2:
        probabilities = probabilities.unsqueeze(0)

    # nested list comprehension needed if probabilities is a 3D tensor (MC dropout)
    entropies = torch.tensor([[compute_entropy(prob) for prob in prob_set] for prob_set in probabilities])
    _, top_k_indices = torch.topk(entropies, top_k, largest=False, sorted=False)
    sorted_top_k_indices = torch.stack([indices[torch.argsort(entropies[i, indices])]
                                            for i, indices in enumerate(top_k_indices)])
    top_k_probabilities = torch.stack([probabilities[i][sorted_top_k_indices[i]]
                                        for i in range(probabilities.shape[0])])
    if top_k_probabilities.shape[0] == 1:
        top_k_probabilities = top_k_probabilities.squeeze(0)

    return top_k_probabilities

def get_test_augmentations(input, augmentations, num_augmentations, seed_augmentations):
    """
    Takes a tensor image of dimension [C,H,W] and returns a tensor of augmentations of dimension [num_augmentations, C,H,W].
    The augmentations are produced by sampling different torchvision.transforms from "augmentations".
    ----------
    input: an image tensor of dimension [C,H,W]
    augmentations: a list of torchvision.transforms augmentations
    num_augmentations: the number of augmentations to produce
    seed_augmentations: seed to reproduce the sampling of augmentations
    """
    torch.manual_seed(seed_augmentations)
    random.seed(seed_augmentations)
    sampled_augmentations = random.sample(augmentations, num_augmentations)
    test_augmentations = torch.zeros((num_augmentations, 3, 224, 224))
    for i, augmentation in enumerate(sampled_augmentations):
        transform_MEMO = T.Compose([
            T.ToPILImage(),
            augmentation,
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        augmented_input = transform_MEMO(input.cpu())
        test_augmentations[i] = augmented_input
    return test_augmentations

### **TTA: Test-Time Augmentation**

Test Time Augmentation (TTA) is a technique used to improve the robustness and accuracy of a model's predictions during inference by applying multiple augmentations to the input data. The main idea is to simulate different variations of the test input, make predictions on each of these variations, and then combine these predictions to make a final decision. The predictions are usually combined by averaging and then the `argmax` function is used as usual.
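
A minimal sketch of the averaging step, assuming the per-augmentation probabilities have already been computed. The function name `tta_predict` and the toy values are hypothetical.

```python
import torch

def tta_predict(probabilities):
    """Average a [num_augmentations, classes] tensor and return (predicted class, marginal)."""
    marginal = probabilities.mean(dim=0)
    return int(marginal.argmax()), marginal

# Toy probabilities (hypothetical) for three augmentations of one image
toy_probabilities = torch.tensor([
    [0.6, 0.3, 0.1],
    [0.2, 0.5, 0.3],
    [0.5, 0.4, 0.1],
])
predicted_class, marginal = tta_predict(toy_probabilities)
```

Here the second augmentation alone would predict class 1, but the averaged distribution favours class 0.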

### **Adaptive Batch Normalization - Improving robustness against common corruptions by covariate shift adaptation<sup>[3]</sup>**

In this paper, the authors investigate methods to improve the robustness of machine learning models trained for computer vision against common image corruptions, such as blurring or compression artifacts. These corruptions often degrade model performance, reducing their effectiveness in real-world applications.

Traditional approaches tend to underestimate model robustness in scenarios where models can adapt to corruptions found in multiple unlabeled examples. The authors argue that models should utilize these examples for unsupervised online adaptation, a strategy not commonly employed in current evaluations. Instead of relying on static batch normalization (BN) statistics computed during training, the authors propose that these statistics should be dynamically updated by the models using data from corrupted images encountered during testing. This adaptive approach can significantly enhance model performance under real-world conditions. The formula to compute the update is the following:

<div align="center">

$v \in \{\mu, \sigma^2 \}$ \
$v = \frac{N}{N+1}v_{\text{train}} + \frac{1}{N+1}v_{\text{test}}$

</div>

Following the paper's suggestion we set `N` to 16.
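
As a worked example of the formula with `N = 16`, the sketch below blends toy training and test means. The statistics are hypothetical values chosen so that the result is easy to verify by hand.

```python
import torch

N = 16
prior_strength = N / (N + 1)  # weight given to the training statistics

# Toy per-channel means (hypothetical) for a layer with two channels
train_mean = torch.tensor([0.0, 1.0])
test_mean = torch.tensor([1.7, -0.7])

# v = N/(N+1) * v_train + 1/(N+1) * v_test
adapted_mean = prior_strength * train_mean + (1 - prior_strength) * test_mean
```

The adapted means are `1.7 / 17 = 0.1` and `(16 - 0.7) / 17 = 0.9`, so the test statistics move the estimate only slightly away from the training values.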

In [None]:
def adaptive_bn_forward(self, input: torch.Tensor):
    """
    Applies an adaptive batch normalization to the input tensor using precomputed running
    statistics that are updated in an adaptive manner using the formula of Schneider et al. [3]:
                        mean = N/(N+1)*mean_train + 1/(N+1)*mean_test
                        var = N/(N+1)*var_train + 1/(N+1)*var_test
    N corresponds to the weight that is given to the statistics of the pre-trained model.
    In the implementation, N/(N+1) corresponds to self.prior_strength and can be modified assigning
    a float/int to nn.BatchNorm2d.prior_strength.
    -----------
    input : input tensor of shape [N, C, H, W]
    """
    # compute channel-wise statistics for the input
    point_mean = input.mean([0,2,3]).to(device = self.running_mean.device)
    point_var = input.var([0,2,3], unbiased=True).to(device = self.running_mean.device)
    # BN adaptation
    adapted_running_mean = self.prior_strength * self.running_mean + (1 - self.prior_strength) * point_mean
    adapted_running_var = self.prior_strength * self.running_var + (1 - self.prior_strength) * point_var
    # detach to avoid non-differentiable torch error
    adapted_running_mean = adapted_running_mean.detach()
    adapted_running_var = adapted_running_var.detach()

    return torch.nn.functional.batch_norm(input, adapted_running_mean, adapted_running_var, self.weight, self.bias, False, 0, self.eps)

### **Monte Carlo Dropout - Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning<sup>[4]</sup>**

In this paper, the authors propose Monte Carlo Dropout, a technique that leverages dropout, a regularization method commonly used during training, to also perform approximate Bayesian inference during the testing phase.

During the model training phase, the dropout technique randomly "drops" or deactivates a fraction of neurons in the network during each forward pass. The probability of dropping neurons can be controlled using a specific parameter. The aim is to prevent neurons from memorizing specific inputs, thus reducing overfitting and encouraging the network to learn more general representations.

However, during the model test phase, dropout is usually turned off, allowing the full network to make predictions. The key idea behind Monte Carlo Dropout is to keep dropout active during the test phase and perform multiple forward passes through the network. The result is that for each forward pass, a different dropout mask is used, which means a random subset of neurons is activated, allowing for different predictions. By averaging these predictions, it is possible to obtain both the final prediction and a measure of the model's uncertainty.
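
The idea can be sketched on a small network as follows. This is a hedged illustration rather than the notebook's own code: `enable_dropout`, `mc_dropout_predict`, and the toy model are hypothetical names introduced for the example.

```python
import torch
import torch.nn as nn

def enable_dropout(model):
    """Put only the dropout layers in training mode so that they stay stochastic."""
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()

@torch.no_grad()
def mc_dropout_predict(model, x, num_passes):
    """Run several stochastic forward passes and return the mean and variance of the probabilities."""
    model.eval()           # freeze layers such as batch normalization
    enable_dropout(model)  # but keep dropout active
    probabilities = torch.stack(
        [model(x).softmax(dim=-1) for _ in range(num_passes)]
    )  # [num_passes, batch, classes]
    return probabilities.mean(dim=0), probabilities.var(dim=0)

torch.manual_seed(0)
toy_model = nn.Sequential(
    nn.Linear(8, 16), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(16, 4)
)
mean_probs, var_probs = mc_dropout_predict(toy_model, torch.randn(2, 8), num_passes=10)
```

The mean over passes gives the prediction, while the per-class variance is one simple proxy for the model's uncertainty.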

In [None]:
class ResNet50Dropout(nn.Module):
    """
    It creates a version of the ResNet-50 model that integrates dropout layers at
    various points in the architecture. By using dropout, the model can be trained
    with a regularization technique that allows for the implementation of
    Monte Carlo Dropout.
    ----------
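    weights: Optional pre-trained weights for the ResNet-50 model.
    dropout_rate: The probability of dropping out neurons during training.
    A value of 0 means no dropout is applied and the architecture is identical to the
    original ResNet-50.
    dropout_positions: Optional dict whose "dropout_positions" key lists the layer names
    where dropout is inserted. If None, the list is read from the JSON file.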
    """
    def __init__(self, weights=None, dropout_rate=0., dropout_positions=None):
        super(ResNet50Dropout, self).__init__()

        self.weights = weights
        self.model = models.resnet50(weights=self.weights)
        self.dropout_rate = dropout_rate

        self.dropout_positions = []
        if self.dropout_rate > 0:
            self.dropout_positions = self.get_dropout_positions() if dropout_positions is None else dropout_positions["dropout_positions"]

        self._add_dropout()

    # This method reads a JSON file that contains a list of layer names where
    # dropout should be applied.
    def get_dropout_positions(self):
        dropout_positions_path = "INSERT PATH TO DROPOUT POSITIONS JSON FILE"
        with open(dropout_positions_path, 'r') as json_file:
            dropout_positions = json.load(json_file)
        dropout_positions = dropout_positions["dropout_positions"]

        return dropout_positions

    # This method adds dropout layers to the ResNet-50 model at the specified
    # positions, by looking at the dropout_positions list.
    # For each specified layer, the method wraps the original layer in
    # a nn.Sequential block, which includes the original layer followed by a
    # nn.Dropout layer with the specified dropout rate. The only exception is
    # 'fc', where the nn.Dropout layer is placed before the original layer.
    def _add_dropout(self):
        if 'conv1' in self.dropout_positions:
            self.model.conv1 = nn.Sequential(
                self.model.conv1,
                nn.Dropout(p=self.dropout_rate)
            )

        if 'layer1' in self.dropout_positions:
            self.model.layer1 = nn.Sequential(
                self.model.layer1,
                nn.Dropout(p=self.dropout_rate)
            )

        if 'layer2' in self.dropout_positions:
            self.model.layer2 = nn.Sequential(
                self.model.layer2,
                nn.Dropout(p=self.dropout_rate)
            )

        if 'layer3' in self.dropout_positions:
            self.model.layer3 = nn.Sequential(
                self.model.layer3,
                nn.Dropout(p=self.dropout_rate)
            )

        if 'layer4' in self.dropout_positions:
            self.model.layer4 = nn.Sequential(
                self.model.layer4,
                nn.Dropout(p=self.dropout_rate)
            )

        if 'avgpool' in self.dropout_positions:
            self.model.avgpool = nn.Sequential(
                self.model.avgpool,
                nn.Dropout(p=self.dropout_rate)
            )

        if 'fc' in self.dropout_positions:
            self.model.fc = nn.Sequential(
                nn.Dropout(p=self.dropout_rate),
                self.model.fc
            )

    def forward(self, x):
        return self.model(x)

#### **Dropout positions**

The **dropout_positions.json** file defines the locations within a custom ResNet50 model where dropout layers should be inserted.

The dropout layers are incorporated to enhance the proposed method using Monte Carlo Dropout, a technique that improves model robustness and uncertainty estimation. For more details, refer to the **ResNet50Dropout** section.

In [None]:
dropout_positions = {
    "dropout_positions": [
        "conv1",
        "layer1",
        "layer2",
        "layer3",
        "layer4",
        "avgpool",
        "fc"
    ]
}

### **Efficient DiffTPT - Offline Data Augmentations with Diffusion and LLM**


Traditional data augmentation methods are limited by insufficient data diversity. We re-adapt the DiffTPT method proposed by the paper 
"DiffTPT - Diverse Data Augmentation with Diffusions for Effective Test-time Prompt Tuning<sup>[5]</sup>" to use it with MEMO. 

<p align="center">
  <img src="images/DiffTPT.png" width="800" height="300">  
</p>

DiffTPT relies on the CLIP image encoder to generate new images, which may limit the variability of the generated outputs. Additionally, generating images at test time slows down online inference and necessitates the creation of new images for each instance. To address these limitations, we propose a novel approach as follows:

1. Query Definition: Formulate a query that includes the class name, image style, and other relevant details.

2. Image Scraping: Using the query, scrape a small set of images from the internet for each class (e.g., 10 images per class). For domains with significant shifts or that are very abstract, incorporating these images aids in generating samples more aligned with the new data distribution, though it may reduce variability and may erroneously bias the generation if the scraping is not performed appropriately.

3. Prompt Generation: Use a similar query and a large language model (LLM), specifically "llama3.1" in our case, to generate a set of prompts for each class. The CLIP text embeddings for these prompts are then stored.

4. Image Generation and Embedding Storage: Generate new images using stable diffusion based on the prompts and/or scraped images, and store their corresponding CLIP image embeddings.

5. Cosine Similarity for Retrieval: At test time, compute the cosine similarity between the test image's CLIP embedding and the image embeddings of the previously generated images and/or the text embeddings, and retrieve the most similar images. This is computationally cheaper than generating new images for each test sample, and it should still improve accuracy. For example, while the number of images we generate is fixed and does not depend on the number of samples to classify, the cost of the original method scales linearly with it. For `Imagenet-A`, assuming 64 augmentations per sample, the original method would need to produce a total of 480,000 images, more than 40 times the roughly 11,000 we generated. Thus, our method is not only much cheaper, but also significantly faster given the same computational power.

6. Data Augmentation: Augmented data is incorporated using both conventional methods and pre-trained stable diffusion models, albeit with varying percentages.

We generated 30 images using the `t2i` pipeline and 25 images using the `iti` pipeline (due to lack of time), resulting in a total of 55 images per class. The entire dataset comprises approximately 11,000 images, generated over approximately 15 hours (17 hours including the generation of 20 prompts per class), excluding the time taken for scraping, which was minimal. Each image therefore took nearly 5.5 seconds to produce overall, of which about 4 seconds were spent in Stable Diffusion. This time investment is a fixed cost, and once a sufficiently diverse set of images is generated, further use of generative models is unnecessary. Unlike DiffTPT, which requires generating new images for each sample and discards a substantial portion of the outputs, our approach generates images in a more controlled manner, reducing the need for such discarding.
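
The figures quoted above can be cross-checked with a few lines of arithmetic. The number of classes (200) and the test-set size (7,500, implied by 480,000 images at 64 augmentations per sample) are assumptions stated here for the calculation.

```python
# Offline generation (our approach)
num_classes = 200              # assumption: ImageNet-A classes
images_per_class = 30 + 25     # t2i + iti images
generated_total = num_classes * images_per_class

# Online generation (one set of augmentations per test sample)
test_samples = 7500            # assumption: implied by 480,000 / 64
augmentations_per_sample = 64
online_total = test_samples * augmentations_per_sample

ratio = online_total / generated_total

# Average wall-clock time per image over the full 17 hours
seconds_per_image = 17 * 3600 / generated_total
```

Under these assumptions the offline dataset has 11,000 images, online generation would need 480,000, and the average cost is about 5.6 seconds per image.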

Finally, for each class, we selected the `top-k` images using the cosine similarity filter, followed by an additional filter that imposes a minimum similarity threshold to mitigate the risk of including images from other classes. This means that the final number of generated augmentations used is not `top-k`, but a number between 0 and `top-k`.
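
The two-stage filter described above (top-k by cosine similarity, then a minimum similarity threshold) can be sketched as follows. The function `retrieve_similar`, its parameter names, and the toy two-dimensional embeddings are hypothetical; real CLIP embeddings would have a much higher dimension.

```python
import torch
import torch.nn.functional as F

def retrieve_similar(query_embedding, bank_embeddings, top_k, min_similarity):
    """Return indices and similarities of at most top_k bank entries above the threshold."""
    query = F.normalize(query_embedding, dim=-1)
    bank = F.normalize(bank_embeddings, dim=-1)
    similarities = bank @ query               # cosine similarities, shape [num_images]
    k = min(top_k, similarities.numel())
    top_similarities, top_indices = torch.topk(similarities, k)
    keep = top_similarities >= min_similarity  # second filter: similarity threshold
    return top_indices[keep], top_similarities[keep]

# Toy embeddings (hypothetical)
bank = torch.tensor([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]])
query = torch.tensor([1.0, 0.0])
indices, similarities = retrieve_similar(query, bank, top_k=3, min_similarity=0.5)
```

Although `top_k` is 3, only two images pass the threshold, which is why the final number of retrieved images lies between 0 and `top-k`.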

<p align="center">
  <img src="images/image_generation_pipeline_colored.png" width="600" height="300">  
</p>

#### **Image generation**

To facilitate reproducibility we made the generated dataset available at the following Google Drive link: https://drive.google.com/file/d/1DIbKDdsGZD4pUizaL_rI_XTAS5ZmgU-v/view?usp=sharing

#### **ImageNet-A classes**

The **imagenetA_classes.json** file provides a mapping between the synset IDs used in the ImageNet dataset and their corresponding class names. This mapping is essential for converting model outputs from synset IDs to human-readable labels.

The class IDs are mainly used to create the directory structure where newly generated images will be stored, forming a secondary dataset. This dataset can then be utilized for training and inference activities. For more details, refer to the **ToDo** section.
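
A minimal sketch of how the class IDs could be turned into a directory structure, one folder per synset ID. The helper `make_class_dirs` and the two-entry `toy_classes` dictionary are hypothetical, and a temporary directory is used so that the example leaves no files behind.

```python
import tempfile
from pathlib import Path

# Two entries taken from the mapping, used as a toy example
toy_classes = {"n01498041": "stingray", "n01531178": "goldfinch"}

def make_class_dirs(root, classes):
    """Create one sub-directory per synset ID and return the sorted directory names."""
    root = Path(root)
    for synset_id in classes:
        (root / synset_id).mkdir(parents=True, exist_ok=True)
    return sorted(path.name for path in root.iterdir() if path.is_dir())

with tempfile.TemporaryDirectory() as tmp_dir:
    created = make_class_dirs(tmp_dir, toy_classes)
```

A layout of this kind matches loaders that take the parent directory name as the label, as `S3ImageFolder` does.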

In [None]:
imagenetA_classes = {
	"n01498041": "stingray",
	"n01531178": "goldfinch",
	"n01534433": "junco",
	"n01558993": "American robin",
	"n01580077": "jay",
	"n01614925": "bald eagle",
	"n01616318": "vulture",
	"n01631663": "newt",
	"n01641577": "American bullfrog",
	"n01669191": "box turtle",
	"n01677366": "green iguana",
	"n01687978": "agama",
	"n01694178": "chameleon",
	"n01698640": "American alligator",
	"n01735189": "garter snake",
	"n01770081": "harvestman",
	"n01770393": "scorpion",
	"n01774750": "tarantula",
	"n01784675": "centipede",
	"n01819313": "sulphur-crested cockatoo",
	"n01820546": "lorikeet",
	"n01833805": "hummingbird",
	"n01843383": "toucan",
	"n01847000": "duck",
	"n01855672": "goose",
	"n01882714": "koala",
	"n01910747": "jellyfish",
	"n01914609": "sea anemone",
	"n01924916": "flatworm",
	"n01944390": "snail",
	"n01985128": "crayfish",
	"n01986214": "hermit crab",
	"n02007558": "flamingo",
	"n02009912": "great egret",
	"n02037110": "oystercatcher",
	"n02051845": "pelican",
	"n02077923": "sea lion",
	"n02085620": "Chihuahua",
	"n02099601": "Golden Retriever",
	"n02106550": "Rottweiler",
	"n02106662": "German Shepherd Dog",
	"n02110958": "pug",
	"n02119022": "red fox",
	"n02123394": "Persian cat",
	"n02127052": "lynx",
	"n02129165": "lion",
	"n02133161": "American black bear",
	"n02137549": "mongoose",
	"n02165456": "ladybug",
	"n02174001": "rhinoceros beetle",
	"n02177972": "weevil",
	"n02190166": "fly",
	"n02206856": "bee",
	"n02219486": "ant",
	"n02226429": "grasshopper",
	"n02231487": "stick insect",
	"n02233338": "cockroach",
	"n02236044": "mantis",
	"n02259212": "leafhopper",
	"n02268443": "dragonfly",
	"n02279972": "monarch butterfly",
	"n02280649": "small white",
	"n02281787": "gossamer-winged butterfly",
	"n02317335": "starfish",
	"n02325366": "cottontail rabbit",
	"n02346627": "porcupine",
	"n02356798": "fox squirrel",
	"n02361337": "marmot",
	"n02410509": "bison",
	"n02445715": "skunk",
	"n02454379": "armadillo",
	"n02486410": "baboon",
	"n02492035": "white-headed capuchin",
	"n02504458": "African bush elephant",
	"n02655020": "pufferfish",
	"n02669723": "academic gown",
	"n02672831": "accordion",
	"n02676566": "acoustic guitar",
	"n02690373": "airliner",
	"n02701002": "ambulance",
	"n02730930": "apron",
	"n02777292": "balance beam",
	"n02782093": "balloon",
	"n02787622": "banjo",
	"n02793495": "barn",
	"n02797295": "wheelbarrow",
	"n02802426": "basketball",
	"n02814860": "lighthouse",
	"n02815834": "beaker",
	"n02837789": "bikini",
	"n02879718": "bow",
	"n02883205": "bow tie",
	"n02895154": "breastplate",
	"n02906734": "broom",
	"n02948072": "candle",
	"n02951358": "canoe",
	"n02980441": "castle",
	"n02992211": "cello",
	"n02999410": "chain",
	"n03014705": "chest",
	"n03026506": "Christmas stocking",
	"n03124043": "cowboy boot",
	"n03125729": "cradle",
	"n03187595": "rotary dial telephone",
	"n03196217": "digital clock",
	"n03223299": "doormat",
	"n03250847": "drumstick",
	"n03255030": "dumbbell",
	"n03291819": "envelope",
	"n03325584": "feather boa",
	"n03355925": "flagpole",
	"n03384352": "forklift",
	"n03388043": "fountain",
	"n03417042": "garbage truck",
	"n03443371": "goblet",
	"n03444034": "go-kart",
	"n03445924": "golf cart",
	"n03452741": "grand piano",
	"n03483316": "hair dryer",
	"n03584829": "clothes iron",
	"n03590841": "jack-o'-lantern",
	"n03594945": "jeep",
	"n03617480": "kimono",
	"n03666591": "lighter",
	"n03670208": "limousine",
	"n03717622": "manhole cover",
	"n03720891": "maraca",
	"n03721384": "marimba",
	"n03724870": "mask",
	"n03775071": "mitten",
	"n03788195": "mosque",
	"n03804744": "nail",
	"n03837869": "obelisk",
	"n03840681": "ocarina",
	"n03854065": "organ",
	"n03888257": "parachute",
	"n03891332": "parking meter",
	"n03935335": "piggy bank",
	"n03982430": "billiard table",
	"n04019541": "hockey puck",
	"n04033901": "quill",
	"n04039381": "racket",
	"n04067472": "reel",
	"n04086273": "revolver",
	"n04099969": "rocking chair",
	"n04118538": "rugby ball",
	"n04131690": "salt shaker",
	"n04133789": "sandal",
	"n04141076": "saxophone",
	"n04146614": "school bus",
	"n04147183": "schooner",
	"n04179913": "sewing machine",
	"n04208210": "shovel",
	"n04235860": "sleeping bag",
	"n04252077": "snowmobile",
	"n04252225": "snowplow",
	"n04254120": "soap dispenser",
	"n04270147": "spatula",
	"n04275548": "spider web",
	"n04310018": "steam locomotive",
	"n04317175": "stethoscope",
	"n04344873": "couch",
	"n04347754": "submarine",
	"n04355338": "sundial",
	"n04366367": "suspension bridge",
	"n04376876": "syringe",
	"n04389033": "tank",
	"n04399382": "teddy bear",
	"n04442312": "toaster",
	"n04456115": "torch",
	"n04482393": "tricycle",
	"n04507155": "umbrella",
	"n04509417": "unicycle",
	"n04532670": "viaduct",
	"n04540053": "volleyball",
	"n04554684": "washing machine",
	"n04562935": "water tower",
	"n04591713": "wine bottle",
	"n04606251": "shipwreck",
	"n07583066": "guacamole",
	"n07695742": "pretzel",
	"n07697313": "cheeseburger",
	"n07697537": "hot dog",
	"n07714990": "broccoli",
	"n07718472": "cucumber",
	"n07720875": "bell pepper",
	"n07734744": "mushroom",
	"n07749582": "lemon",
	"n07753592": "banana",
	"n07760859": "custard apple",
	"n07768694": "pomegranate",
	"n07831146": "carbonara",
	"n09229709": "bubble",
	"n09246464": "cliff",
	"n09472597": "volcano",
	"n09835506": "baseball player",
	"n11879895": "rapeseed",
	"n12057211": "yellow lady's slipper",
	"n12144580": "corn",
	"n12267677": "acorn"
}

In [None]:
def get_imagenetA_classes(imagenetA_classes : Union[str, Dict[str, str]] = None) -> Dict[str, str]:
    """
    ImageNet-A uses the same label structure as the original ImageNet (ImageNet-1K).
    Each class in ImageNet is represented by a synset ID (e.g., n01440764 for "tench, Tinca tinca").
    This function returns a dictionary that maps the synset IDs of ImageNet-A to the corresponding class names.
    ---
    imagenetA_classes (str, dict): path to a json file containing the mapping, or the mapping itself
    """
    if imagenetA_classes is None:
        raise ValueError("imagenetA_classes must be a path to a json file or a dict")
    if isinstance(imagenetA_classes, str):
        with open(imagenetA_classes, 'r') as json_file:
            imagenetA_classes_dict = json.load(json_file)
    else:
        imagenetA_classes_dict = imagenetA_classes

    # return a copy with class IDs as keys and class names as values
    class_dict = dict(imagenetA_classes_dict)
    return class_dict

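def create_dir_generated_images(path: str, imagenetA_classes: Union[str, Dict[str, str]] = None):
    """
    Create the directories where the generated images will be stored, one per class name.
    ---
    path (str): path where to create the directories
    imagenetA_classes (str, dict): path to a json file containing the synset ID -> class name mapping, or the mapping itself
    """
    classes = list(get_imagenetA_classes(imagenetA_classes).values())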
    for class_name in classes:
        class_path = os.path.join(path, class_name)
        os.makedirs(class_path, exist_ok=True)

In [None]:
class ImageGenerator:
    """
    A class containing all the functions to generate and save prompts, images and their respective
    CLIP embeddings.
    """
    def __init__(self):
        pass 

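    def get_text_embedding(self, clip_model, text: str):
        """
        Compute the L2-normalized CLIP embedding of a text prompt.
        ---
        clip_model: preloaded CLIP model, already moved to the GPU
        text (str): text to embed
        """
        text_token = clip.tokenize(text).cuda()
        text_embedding = clip_model.encode_text(text_token).float()
        text_embedding /= text_embedding.norm()
        return text_embedding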
        
    def generate_prompts(self,
                         num_prompts_per_class: int,
                         style_of_picture: str,
                         path: str,
                         context_llm: Union[str, List[Dict[str, str]]],
                         llm_model: str = "llama3.1",
                         clip_text_encoder: str = "ViT-L/14"
                        ) -> List[str]:
        """
        Generate image prompts for each class in the given directory path using a language model available in the ollama library.
        The prompts will then be used to generate images using a diffusion model.
        ---
        num_prompts_per_class (int): Desired total number of prompts for each class, not the number of prompts generated
                                     by this call, e.g. if the class "ant" already has 10 prompts and num_prompts_per_class = 12,
                                     then only two prompts will be generated.
        style_of_picture (str): Style to be used in image prompts.
        path (str): Path to the generated dataset.
        context_llm (str, list): A list of message dicts containing the context for the language model, or a path to a json file containing it.
        llm_model (str): The language model to use for generating prompts.
        clip_text_encoder (str): The CLIP text encoder model to use.
        ---
        Returns:
            list: A list of class names for which prompts could not be generated due to bad prompts formatting.
        """
        assert isinstance(llm_model,str), "Model must be a str"
        assert isinstance(context_llm, (str,list)), "context_llm must be a str path or a list of dict"
        assert isinstance(clip_text_encoder, str), "clip_text_encoder must be a str. Use clip.available_models() to get valid strings."
        assert isinstance(style_of_picture, str), "style_of_picture must be a str representing the style of the image that will be generated"
        assert isinstance(num_prompts_per_class, int), "num_prompts_per_class must be an int"

        # check that llm_model is available
        try:
          ollama.chat(llm_model)
        except ollama.ResponseError as e:
          print('Error:', e.error)
          if e.status_code == 404:
            # try to pull the model if it exists
            print("Pulling the model...")
            ollama.pull(llm_model)

        # load context_llm dict if needed
        if isinstance(context_llm,str):
            with open(context_llm, 'r') as file:
                context_llm = json.load(file) 

        skipped_classes = []

        clip_model, _ = clip.load(clip_text_encoder)
        clip_model.cuda().eval()

        class_list = os.listdir(path)
        with torch.no_grad():
            with tqdm(total=len(class_list), desc="Processing classes") as pbar:
                for class_name in class_list:
                    pbar.set_description(f"Processing class: {class_name}")
                    sub_dir_class = os.path.join(path, class_name)
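                    # count only the prompt folders already present (ignore scraped_images and notebook checkpoints)
                    existing_prompt_dirs = [d for d in os.listdir(sub_dir_class) if d not in ("scraped_images", ".ipynb_checkpoints")]
                    prompts_to_generate = num_prompts_per_class - len(existing_prompt_dirs)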
                    if prompts_to_generate <= 0: 
                        pbar.update(1)
                        continue
                    # sometimes the language model doesn't return an appropriate output, tolerance = number of possible attempts
                    tolerance = 6
                    gen_prompts = []
                    original_prompts_to_gen = prompts_to_generate
                    while tolerance>0:
                        prompts_generation_instruction = {
                            "role": "user",
                            "content": f"class:{class_name}, number of prompts:{prompts_to_generate}, style of picture: {style_of_picture}"
                        }
                        if len(context_llm) == 3:
                            context_llm.append(prompts_generation_instruction)
                        else:
                            # needed from the second iteration
                            context_llm[3] = prompts_generation_instruction
                        try:
                            response = ollama.chat(model=llm_model, messages=context_llm)
                            content = json.loads(response['message']['content'])  # json.loads to convert str to list
                            if len(content) >= prompts_to_generate:
                                # enough or more than enough prompts generated
                                prompts_to_generate -= len(content)
                                gen_prompts.extend(content)
                                gen_prompts = gen_prompts[:original_prompts_to_gen]
                                tolerance = -1
                            else:
                                # more prompts needed
                                prompts_to_generate -= len(content)
                                gen_prompts.extend(content)
                        except Exception:
                            # malformed or missing LLM output: use up one attempt
                            tolerance -= 1
    
                    if tolerance == -1:
                        num_prompts_already_gen = len(os.listdir(sub_dir_class))
                        for i in range(num_prompts_already_gen, num_prompts_already_gen + len(gen_prompts)):
                            new_sub_dir = os.path.join(path, class_name, str(i))
                            os.makedirs(new_sub_dir, exist_ok=True)
                            prompt = gen_prompts[i - num_prompts_already_gen]
                            prompt_embedding = self.get_text_embedding(clip_model, prompt) # compute text CLIP embedding
                            with open(os.path.join(new_sub_dir, "prompt.txt"), 'w') as file:
                                file.write(prompt)
                            torch.save(prompt_embedding, os.path.join(new_sub_dir,"prompt_clip_embedding.pt"))
                    else:
                        # not even one prompt was generated, class entirely skipped
                        skipped_classes.append(class_name)
                        print(f"Skipping class {class_name}.")

        return skipped_classes
    
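    def get_image_embedding(self, clip_model, preprocess, image):
        """
        Compute the L2-normalized CLIP embedding of a PIL image.
        ---
        clip_model: preloaded CLIP model, already moved to the GPU
        preprocess: preprocessing transform returned by clip.load
        image: PIL image to embed
        """
        image_preprocessed = preprocess(image).unsqueeze(0).cuda()
        image_embedding = clip_model.encode_image(image_preprocessed)
        image_embedding /= image_embedding.norm()
        return image_embedding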
        
    def generate_images(self, 
                        path: str, 
                        num_images_per_class: int,  
                        image_generation_pipeline: Union[StableDiffusionPipeline, StableDiffusionImg2ImgPipeline], 
                        num_inference_steps: int, 
                        class_to_skip: List[str] = [],
                        guidance_scale: int = 9, 
                        strength: float = 0.8, 
                        clip_image_encoder: str = "ViT-L/14"
                       ) -> None:
        """
        Generate images for each class in the specified directory using the given image generation pipeline.
        ---
        path (str): Path to the directory containing class subdirectories.
        num_images_per_class (int): Number of images to generate per class. 
        class_to_skip (List[str]): List of classes to skip.
        image_generation_pipeline (Union[StableDiffusionPipeline, StableDiffusionImg2ImgPipeline]): The image generation pipeline to use.
        num_inference_steps (int): Number of inference steps for image generation.
        guidance_scale (int): Guidance scale for image generation. A higher guidance scale value encourages the model to generate 
                                        images closely linked to the text prompt at the expense of lower image quality
        strength (float): Extent to which the reference scraped image is transformed (image-to-image pipeline only). A value of 1 essentially ignores the image.
        clip_image_encoder (str, optional): The CLIP image encoder model to use.
        ---
        Returns:
            None: This function does not return any value.
        """
        assert isinstance(image_generation_pipeline, (StableDiffusionPipeline, StableDiffusionImg2ImgPipeline)), \
            "image_generation_pipeline must be one of StableDiffusionPipeline or StableDiffusionImg2ImgPipeline"
        assert isinstance(clip_image_encoder, str), "clip_image_encoder must be a str. Use clip.available_models() to get valid strings."
        assert isinstance(num_images_per_class, int), "num_images_per_class must be an int"
        
        random.seed(42)

        print("Loading CLIP model...")
        clip_model, preprocess = clip.load(clip_image_encoder)
        clip_model.cuda().eval()

        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
        class_list = os.listdir(path)
        with torch.no_grad():
            with tqdm(total=len(class_list), desc="Processing classes") as pbar:
                for class_name in class_list:
                    pbar.set_description(f"Processing class: {class_name}")
                    if class_name in class_to_skip: 
                        pbar.update(1)
                        continue
                    class_path = os.path.join(path, class_name)
                    num_prompts = len(os.listdir(class_path))
                    
                    if image_generation_pipeline.__class__.__name__ == "StableDiffusionImg2ImgPipeline":
                        print("Loading scraped images...")
                        scraped_image_paths = os.path.join(class_path, "scraped_images") # go in scraped_images folder
                        scraped_images = []
                        for scraped_image_path in os.listdir(scraped_image_paths): # open and append scraped images
                            if scraped_image_path == ".ipynb_checkpoints": continue
                            img_path = os.path.join(scraped_image_paths,scraped_image_path)
                            scraped_image = Image.open(img_path)
                            scraped_image = scraped_image.resize((512,512))
                            scraped_images.append(scraped_image)
                    
                    num_gen_images = 0 # needed if num_images < num_prompts
    
                    while num_gen_images < num_images_per_class:
                        for gen_images_class in os.listdir(class_path):
                            if gen_images_class in (".ipynb_checkpoints","scraped_images"): continue
                            gen_image_class = os.path.join(class_path,gen_images_class)
                            # needed bc some folders don't have a prompt.txt due to some error during generation
                            try:
                                with open(os.path.join(gen_image_class, "prompt.txt"), 'r') as file:
                                    text_prompt = file.read()
                            except OSError:
                                continue
                                
                            # get scraped images using the image-to-image generation pipeline
                            if isinstance(image_generation_pipeline, StableDiffusionImg2ImgPipeline):
                                try: # sometimes get weird OOM error
                                    i2i_image_path = os.path.join(gen_image_class,"i2i_gen_images")
                                    os.makedirs(i2i_image_path,exist_ok=True)
                                    scraped_image = random.sample(scraped_images,1)[0]
                                    with torch.no_grad():
                                        gen_image = image_generation_pipeline(prompt=text_prompt,
                                                                                image=scraped_image,
                                                                                strength=strength,
                                                                                guidance_scale=guidance_scale,
                                                                                num_inference_steps=num_inference_steps).images[0]
                                        del scraped_image
                                        gen_image_embedding = self.get_image_embedding(clip_model, preprocess, gen_image)
                                        save_gen_image_path = os.path.join(i2i_image_path,str(len(os.listdir(i2i_image_path))))
                                        os.makedirs(save_gen_image_path)
                                        torch.save(gen_image_embedding, os.path.join(save_gen_image_path, "image_embedding.pt"))
                                        gen_image.save(os.path.join(save_gen_image_path, "image.png"))
                                        num_gen_images += 1
                                except Exception as e:
                                    print(f"Error occurred: {e}")
                            else:
                                t2i_image_path = os.path.join(gen_image_class,"t2i_gen_images")
                                os.makedirs(t2i_image_path,exist_ok=True)
                                with torch.no_grad():
                                    # `strength` only applies to the image-to-image pipeline, so it is not passed here
                                    gen_image = image_generation_pipeline(prompt=text_prompt,
                                                                            guidance_scale=guidance_scale,
                                                                            num_inference_steps=num_inference_steps).images[0]
                                    gen_image_embedding = self.get_image_embedding(clip_model, preprocess, gen_image)
                                    save_gen_image_path = os.path.join(t2i_image_path,str(len(os.listdir(t2i_image_path))))
                                    os.makedirs(save_gen_image_path)
                                    torch.save(gen_image_embedding, os.path.join(save_gen_image_path, "image_embedding.pt"))
                                    gen_image.save(os.path.join(save_gen_image_path, "image.png"))                        
                                    num_gen_images += 1
                            # break loop over the class prompts if generated enough images
                            if num_gen_images == num_images_per_class: break 
                    pbar.update(1)

In [None]:
def retrieve_gen_images(img: Union[torch.Tensor, Image.Image],  
                        num_images: int,
                        clip_model: clip.model.CLIP, 
                        clip_preprocess: torchvision.transforms.transforms.Compose,
                        img_to_tensor_pipe: torchvision.transforms.transforms.Compose = None,
                        data_path: str = "/home/sagemaker-user/Domain-Shift-Computer-Vision/imagenetA_generated",
                        use_t2i_similarity: bool = False, 
                        t2i_images: bool = True,
                        i2i_images: bool = False, 
                        threshold: float = 0.0) -> Union[torch.Tensor, List[Image.Image]]:
        """
        Retrieve the most similar generated images based on CLIP embeddings.
        ---
        img (Union[torch.Tensor, Image.Image]): The input image to compare against. Can be a torch tensor or PIL image.
        num_images (int): The number of similar images to retrieve.
        clip_model: The preloaded CLIP model for generating embeddings.
        clip_preprocess: The preprocessing function for the CLIP model.
        img_to_tensor_pipe: A pipeline function that converts images to tensors.
        data_path (str): Path to the directory containing generated images.
        use_t2i_similarity (bool): Whether to average text-to-image similarity with image-to-image similarity.
        t2i_images (bool): Whether to include text-to-image generated images in the search.
        i2i_images (bool): Whether to include image-to-image generated images in the search.
        threshold (float): The minimum cosine similarity threshold for an image to be considered. 
        ---
        Returns:
            Union[torch.Tensor, List[Image.Image]]: A stacked tensor of the retrieved images if img_to_tensor_pipe is provided
            (an empty tensor if no image is retrieved), otherwise a list of PIL images.
        """
        assert i2i_images or t2i_images, "One of t2i_images and i2i_images must be true"
        assert isinstance(use_t2i_similarity, bool), "use_t2i_similarity must be a bool"
        assert isinstance(t2i_images, bool), "t2i_images must be a bool"
        assert isinstance(i2i_images, bool), "i2i_images must be a bool"
        assert isinstance(num_images, int), "num_images must be an int"
        assert isinstance(threshold, float) and 0 <= threshold < 1, "threshold must be a float in [0, 1)"
        
        if isinstance(img, torch.Tensor):
            img = T.ToPILImage()(img)

        retrieved_images_paths = []
        retrieved_images_similarity = torch.zeros(num_images)
        with torch.no_grad():
            image_embedding = clip_model.encode_image(clip_preprocess(img).unsqueeze(0).cuda())
            image_embedding /= image_embedding.norm()
        
        for class_name in os.listdir(data_path):
            class_path = os.path.join(data_path, class_name)
            for gen_images_class in os.listdir(class_path):
                if gen_images_class in ["scraped_images", ".ipynb_checkpoints"]: continue
                gen_images_class_path = os.path.join(class_path,gen_images_class)
                gen_prompt_embedding = torch.load(os.path.join(gen_images_class_path, "prompt_clip_embedding.pt"))
                t2i_similarity = F.cosine_similarity(image_embedding, gen_prompt_embedding)
                # Search in text-to-image generated images
                if t2i_images:
                    t2i_gen_images_main_path = os.path.join(gen_images_class_path,"t2i_gen_images")
                    for t2i_images_paths in os.listdir(t2i_gen_images_main_path):
                        t2i_image_path = os.path.join(t2i_gen_images_main_path,t2i_images_paths)
                        gen_image_embedding = torch.load(os.path.join(t2i_image_path, "image_embedding.pt"))
                        i2i_similarity = F.cosine_similarity(image_embedding, gen_image_embedding)
                        if use_t2i_similarity:
                            similarity = (i2i_similarity + t2i_similarity)/2 # avg similarity
                        else:
                            similarity = i2i_similarity
                        if similarity < threshold: continue
                        if len(retrieved_images_paths) < num_images:
                            retrieved_images_similarity[len(retrieved_images_paths)] = similarity
                            retrieved_images_paths.append(os.path.join(t2i_image_path, "image.png"))
                        else:
                            min_similarity, id_similarity = retrieved_images_similarity.min(dim=0)
                            if similarity > min_similarity:
                                retrieved_images_similarity[id_similarity] = similarity
                                retrieved_images_paths[id_similarity] = os.path.join(t2i_image_path, "image.png")
                # Search in image-to-image generated images
                if i2i_images:
                    i2i_gen_images_main_path = os.path.join(gen_images_class_path,"i2i_gen_images")
                    for i2i_images_paths in os.listdir(i2i_gen_images_main_path):
                        i2i_image_path = os.path.join(i2i_gen_images_main_path,i2i_images_paths)
                        gen_image_embedding = torch.load(os.path.join(i2i_image_path, "image_embedding.pt"))
        
                        i2i_similarity = F.cosine_similarity(image_embedding, gen_image_embedding)
                        if use_t2i_similarity:
                            similarity = (i2i_similarity + t2i_similarity)/2 # avg similarity
                        else:
                            similarity = i2i_similarity
                        if similarity < threshold: continue
                        if len(retrieved_images_paths) < num_images:
                            retrieved_images_similarity[len(retrieved_images_paths)] = similarity
                            retrieved_images_paths.append(os.path.join(i2i_image_path, "image.png"))
                        else:
                            min_similarity, id_similarity = retrieved_images_similarity.min(dim=0)
                            if similarity > min_similarity:
                                retrieved_images_similarity[id_similarity] = similarity
                                retrieved_images_paths[id_similarity] = os.path.join(i2i_image_path, "image.png")
        
        # Load and return the retrieved images as a tensor
        retrieved_images = []
        for image_path in retrieved_images_paths:
            # Apply image-to-tensor pipeline if provided
            if img_to_tensor_pipe:
                retrieved_images.append(img_to_tensor_pipe(Image.open(image_path)))
            else:
                retrieved_images.append(Image.open(image_path))

        # Return a tensor if an image-to-tensor pipeline is provided, otherwise a list of PIL images
        if img_to_tensor_pipe:
            return torch.stack(retrieved_images) if len(retrieved_images) >= 1 else torch.tensor([])
        return retrieved_images

In [None]:
def scrape_images_imagenetA(img_style: str, imgenetA_gen_path: str, limit = 5):
    """
    Scrape images for each ImageNet-A class in the given directory path using the specified image style.
    Images are initially stored in a folder named "img_style + \" \" + class_name". The folder is then renamed
    to scraped_images.
    ---
    img_style (str): The style or keywords to use for image queries.
    imgenetA_gen_path (str): The path to the directory containing class subdirectories.
    limit (int, optional): The number of images to download per class. Classes that already have at least
                           this many scraped images are skipped. Default is 5.
    """
    class_list = os.listdir(imgenetA_gen_path) # get classes' name
    with tqdm(total=len(class_list), desc="Processing classes") as pbar:
        for class_name in class_list:
            pbar.set_description(f"Processing class: {class_name}")
            class_path = os.path.join(imgenetA_gen_path, class_name)
            new_scraped_img_path = os.path.join(class_path, "scraped_images")
            # if for a class enough images have already been retrieved then skip it
            if os.path.exists(new_scraped_img_path):
                if len(os.listdir(new_scraped_img_path)) >= limit: 
                    pbar.update(1)
                    continue 
            query = img_style + " " + class_name
            downloader.download(query = query, 
                                limit=limit, 
                                output_dir=class_path, 
                                adult_filter_off=True, 
                                force_replace=False, 
                                timeout=60,
                                verbose=False)
            current_scraped_img_path = os.path.join(class_path, query)
            os.rename(current_scraped_img_path, new_scraped_img_path)
            pbar.update(1)

#### **Install and run Ollama**

To create the new images, we decided to use a StableDiffusion model to which we provided a list of prompts as input, one for each image to be generated. To generate the prompts, we used an LLM, specifically Llama 3.1. We were able to obtain this model through the Ollama library<sup>[6]</sup>, which provides the model and all the necessary configurations for its use. To get everything needed to download and run Ollama, and therefore Llama 3.1, simply execute the file **install_and_run_ollama.sh**. The script will download and install Ollama, start it, and instantiate Llama 3.1.

```bash
#!/bin/bash

# Download and install Ollama
curl -fsSL https://ollama.com/install.sh | sh

# Start the Ollama server in the background
ollama serve &

# Pull the specified model
# To add more models: ollama pull [model_name]
# List of available models can be found at https://ollama.com/library 
ollama pull llama3.1

# Display the list of available models
# Used as a success message when the script completes
ollama list

# Run the specified model
# Execute the following command ONLY if you want to run ollama from terminal
# ollama run llama3.1
```
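
Before calling the LLM from the notebook, it can help to confirm that the server started by the script is reachable. The following is a minimal sketch using only the standard library. It assumes the server listens on Ollama's default address `http://localhost:11434`, and the helper name `ollama_is_running` is hypothetical.

```python
import http.client
import urllib.request

def ollama_is_running(base_url: str = "http://localhost:11434", timeout: float = 2.0) -> bool:
    """Return True if an HTTP server answers with status 200 at base_url, False otherwise."""
    try:
        with urllib.request.urlopen(base_url, timeout=timeout) as response:
            return response.status == 200
    except (OSError, http.client.HTTPException):
        # Connection refused, timeout, or an invalid HTTP answer: treat as "not running".
        # urllib.error.URLError is a subclass of OSError, so it is covered here too.
        return False

print("Ollama reachable:", ollama_is_running())
```

If the check returns `False`, the most likely cause is that `ollama serve` is not running yet or was started on a different host or port.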

#### **LLM context**

The **llm_context.json** file is used to provide context for generating prompts for a text-to-image generator model.

Each entry in the file specifies a Role and Content:

* **Role**: Specifies who is speaking or interacting in the conversation. It can either be "system" (representing the LLM) or "user" (representing the person providing input to the LLM).
* **Content**: This field contains the actual message or instructions being communicated. It includes system instructions or user-provided input regarding the prompts to be generated.

The messages can be of the following types:
1.    A message that sets the system's role and context, instructing the LLM on how to generate prompts for a text-to-image generation task.
2.    A message that serves as an example input a user might provide. It helps demonstrate how the user specifies the class name, number of prompts, and style of the picture.
3.    A message that provides an example output from the LLM, demonstrating the kind of response it should generate based on the provided input.

In conclusion, the content of the file provides a framework that guides the LLM in understanding the context of the conversation, adhering to rules, and producing the desired output format.
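
The structure described above can be sketched as follows: a fixed three-message context, one additional user request per class, and a reply that is parsed as a JSON list of strings. The helpers `build_messages` and `parse_prompt_list`, the shortened `demo_context`, and the hard-coded `fake_reply` are illustrative assumptions. No LLM is called in this sketch.

```python
import json
from typing import Dict, List

def build_messages(context: List[Dict[str, str]], class_name: str,
                   num_prompts: int, style: str) -> List[Dict[str, str]]:
    """Return a new message list: the fixed context followed by one user request."""
    request = {
        "role": "user",
        "content": f"class:{class_name}, number of prompts:{num_prompts}, style of picture: {style}",
    }
    return list(context) + [request]

def parse_prompt_list(reply: str) -> List[str]:
    """Parse an LLM reply that is expected to be a JSON list of strings."""
    prompts = json.loads(reply)  # raises a ValueError subclass on malformed JSON
    if not isinstance(prompts, list) or not all(isinstance(p, str) for p in prompts):
        raise ValueError("reply is not a list of strings")
    return prompts

# Hypothetical, shortened context with the three message types described above.
demo_context = [
    {"role": "system", "content": "Return only the prompts in a python list."},
    {"role": "user", "content": "class:cat, number of prompts:1, style of picture: photo"},
    {"role": "system", "content": "[\"a photo of a cat on a sofa\"]"},
]

messages = build_messages(demo_context, "ant", 2, "photograph")

# Stand-in for the model's reply.
fake_reply = '["a photograph of an ant on a leaf", "a photograph of an ant on sand, viewed from above"]'
prompts = parse_prompt_list(fake_reply)
print(len(messages), len(prompts))  # 4 2
```

Building a new list for every request leaves the shared context untouched, which avoids carrying the previous class's request over to the next one.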

In [None]:
imagenetA_generator = ImageGenerator()

In [None]:
imgenetA_gen_path = "INSERT THE PATH OF THE GENERATED DATASET"
context_llm = [
    {"role": "system", "content": "You are a system that generates reasonable and accurate prompts to be fed to generative text2image models. The prompt contains: [number of prompts to generate], [class name], [style of picture]. The prompts should slightly differ from one another in the background (e.g. on a beach, in a park etc.) or in the perspective (e.g. viewed from above, viewed from right etc.). This variation must be reasonable, so for example the prompt \"A chicken in the water viewed from above\" is not a good prompt because chickens don't swim. The prompt specifies that the [class name] must be centered in the generated image. Return only the prompts in a python list and nothing else (no comments, no explanations)."}, 
    {"role": "user", "content": "class:warrior chief with tribal panther make up blue on red and serious eyes, number of prompts:2, style of picture: photo"}, 
    {"role": "system", "content": "[\"portrait photo of an old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3 --beta --upbeta\", \"portrait photo of a young warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes 50mm portrait photography, hard rim lighting photography--beta --ar 2:3  --beta --upbeta\"]"}
]

In [None]:
create_dir_generated_images(imgenetA_gen_path, imagenetA_classes)

In [None]:
# generate prompts
skipped_classes = imagenetA_generator.generate_prompts(
    num_prompts_per_class=20,
    style_of_picture="photograph",
    path=imgenetA_gen_path,
    context_llm = context_llm,
    llm_model = "llama3.1",
    clip_text_encoder = "ViT-L/14"
)

In [None]:
scrape_images_imagenetA(img_style = "a photo of",
                        imgenetA_gen_path = imgenetA_gen_path,
                        limit = 5)

In [None]:
# generate images
model_id = "runwayml/stable-diffusion-v1-5"

By default, the Stable Diffusion pipelines provided by `diffusers` use the `PNDMScheduler`. Given the large number of images to produce, and since we are not working on fine-grained classification, we decided to use the `DPMSolverMultistepScheduler` instead. This scheduler produces images of slightly lower, but still very good, quality with a much lower `num_inference_steps`: the former requires `50` inference steps, the latter only `25`. On a `T4 GPU` with `float16` tensors, generating each image takes about `4` seconds instead of `30`, regardless of the pipeline used.

In [None]:
pipet2i = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipet2i.scheduler = DPMSolverMultistepScheduler.from_config(pipet2i.scheduler.config)
pipet2i = pipet2i.to("cuda")
num_inference_steps = 25

In [None]:
pipei2i = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipei2i.scheduler = DPMSolverMultistepScheduler.from_config(pipei2i.scheduler.config)
pipei2i = pipei2i.to("cuda")
strength = 0.89 
num_inference_steps = math.ceil(25 / strength) # img2img runs about int(num_inference_steps * strength) steps, so round up to get 25 steps for the selected strength
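
The relation between `strength` and the number of denoising steps that are actually run can be checked with a little arithmetic. The sketch below assumes that the img2img pipeline executes `min(int(num_inference_steps * strength), num_inference_steps)` steps, which is our reading of how `diffusers` selects its timesteps and may differ between versions. `effective_steps` and `requested_steps_for` are hypothetical helpers, not library functions.

```python
import math

def effective_steps(num_inference_steps, strength):
    """Steps actually run by img2img under the assumption stated above."""
    return min(int(num_inference_steps * strength), num_inference_steps)

def requested_steps_for(target_steps, strength):
    """Smallest request that yields at least target_steps effective steps."""
    n = math.ceil(target_steps / strength)
    # guard against floating point rounding in n * strength
    while effective_steps(n, strength) < target_steps:
        n += 1
    return n

strength = 0.89

# truncating 25 / strength can end up one step short
truncated = int(25 / strength)
print(truncated, effective_steps(truncated, strength))

# rounding up reaches the target
rounded_up = requested_steps_for(25, strength)
print(rounded_up, effective_steps(rounded_up, strength))
```

Under this assumption, truncating `25 / strength` and rounding it up can give different effective step counts, so rounding up is the safer choice when a fixed number of steps is wanted.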

In [None]:
class_to_skip = []

In [None]:
imagenetA_generator.generate_images(path = imgenetA_gen_path,
                                    num_images_per_class = 30,
                                    class_to_skip = class_to_skip,
                                    image_generation_pipeline = pipet2i, # select one of pipei2i and pipet2i
                                    num_inference_steps = num_inference_steps,
                                    guidance_scale = 12,
                                    strength=strength)

## **Testing**

The `Tester` class runs the experiments on a deep neural network model. It provides methods to configure models and optimizers, handle augmentations, compute statistics, and save results.

In [None]:
class Tester:
    """
    A class to run all the experiments. It stores all the information needed to reproduce the experiments in a json file
    at exp_path.
    """
    def __init__(self, model, optimizer, exp_path, device):
        self.__model = model
        self.__optimizer = optimizer
        self.__device = device
        self.__exp_path = exp_path

    def save_result(self, accuracy, path_result, num_augmentations, augmentations, seed_augmentations, top_augmentations, MEMO, num_adaptation_steps, lr_setting, weights, prior_strength, time_test, use_MC, gen_aug_settings):
        """
        Takes all the information about the experiment and saves it in a json file stored at exp_path
        """
        data = {
            "accuracy": accuracy,
            "top_augmentations" : top_augmentations,
            "use_MEMO" : MEMO,
            "num_adaptation_steps" : num_adaptation_steps,
            "lr_setting" : lr_setting,
            "weights" : weights,
            "num_augmentations" : num_augmentations,
            "seed_augmentations": seed_augmentations,
            "augmentations" : [str(augmentation) for augmentation in augmentations],
            "prior_strength" : prior_strength,
            "MC" : use_MC,
            "time_test" : time_test,
            "gen_aug_settings" : gen_aug_settings
        }
        try:
            with open(path_result, 'w') as json_file:
                json.dump(data, json_file)
        except Exception as e:
            print(f"Results were not saved: {e}")

    def get_model(self, weights_imagenet, MC):
        """
        Utility function to instantiate a torch model. The argument weights_imagenet should have
        a value in accordance with the parameter weights of torchvision.models.
        """
        if MC:
            self.__model=ResNet50Dropout(weights=weights_imagenet, dropout_rate=MC['dropout_rate'], dropout_positions = MC["dropout_positions"])
            model = self.__model
        else:
            model = self.__model(weights=weights_imagenet)

        model.to(self.__device)
        model.eval()
        return model

    def get_optimizer(self, model, lr_setting:list):
        """
        Utility function to instantiate a torch optimizer.
        ----------
        lr_setting: must be a list containing either one global lr for the whole model or a dictionary
        where each value is a list with a list of parameters' names and a lr for those parameters.
        e.g. 
        lr_setting = [{
            "classifier" : [["fc.weight", "fc.bias"], 0.00025]
            }, 0]
        lr_setting = [0.00025]
        """
        if len(lr_setting) == 2:
            layers_groups = []
            lr_optimizer = []
            for layers, lr_param_name in lr_setting[0].items():
                layers_groups.extend(lr_param_name[0])
                params = [param for name, param in model.named_parameters() if name in lr_param_name[0]]
                lr_optimizer.append({"params":params, "lr": lr_param_name[1]})
            other_params = [param for name, param in model.named_parameters() if name not in layers_groups]
            lr_optimizer.append({"params":other_params})
            optimizer = self.__optimizer(lr_optimizer, lr = lr_setting[1], weight_decay = 0)
        else:
            optimizer = self.__optimizer(model.parameters(), lr = lr_setting[0], weight_decay = 0)
        return optimizer

    def get_imagenetA_masking(self, imagenetA_masking_dict = None):
        """
        All torchvision models output a tensor [B,1000] with "B" being the batch dimension. This function 
        returns a list of indices to apply to the model's output to use the model on imagenet-A dataset.
        ----------
        indices_in_1k: list of indices to map [B,1000] -> [B,200]
        """
        if imagenetA_masking_dict is None:
            imagenetA_masking_path = "INSERT PATH TO IMAGENET A MASKING"
            with open(imagenetA_masking_path, 'r') as json_file:
                imagenetA_masking = json.load(json_file)
            indices_in_1k = [int(k) for k in imagenetA_masking if imagenetA_masking[k] != -1]
        else:
            indices_in_1k = [int(k) for k in imagenetA_masking_dict if imagenetA_masking_dict[k] != -1]
        return indices_in_1k

    def get_monte_carlo_statistics(self, mc_logits):
        """
        Compute mean, median, mode and standard deviation of the Monte Carlo samples.
        """
        statistics = {}
        mean_logits = mc_logits.mean(dim=0)
        statistics['mean'] = mean_logits

        median_logits = mc_logits.median(dim=0).values
        statistics['median'] = median_logits

        pred_classes = mc_logits.argmax(dim=-1)  # argmax over the class dimension for each MC sample
        pred_classes_cpu = pred_classes.cpu().numpy()
        mode_predictions, _ = stats.mode(pred_classes_cpu, axis=0)
        mode_predictions = torch.tensor(mode_predictions.squeeze(), dtype=torch.long)
        statistics['mode'] = mode_predictions

        uncertainty = mc_logits.std(dim=0)  # standard deviation across MC samples
        statistics['std'] = uncertainty
        return statistics

    def get_prediction(self, image_tensors, model, masking, TTA = False, top_augmentations = 0, MC = None):
        """
        Takes a tensor of images and outputs a prediction for each image.
        ----------
        image_tensors: is a tensor of [B,C,H,W] if TTA is used or if both MEMO and TTA are not used, or of dimension [C,H,W]
                       if only MEMO is used
        masking: a list of indices to map the imagenet1k logits to the one of imagenet-A
        top_augmentations: a non-negative integer, if greater than 0 then the "top_augmentations" with the lowest entropy are
                           selected to make the final prediction
        MC: a dictionary containing the number of evaluations using Monte Carlo Dropout and the dropout rate
        """
        if MC:
            model.train()  # enable dropout by setting the model to training mode
            mc_logits = []
            for _ in range(MC['num_samples']):
                logits = model(image_tensors)[:,masking] if image_tensors.dim() == 4 else model(image_tensors.unsqueeze(0))[:,masking]
                mc_logits.append(logits)
            mc_logits = torch.stack(mc_logits, dim=0)
            if TTA:
                # first mean is over MC samples, second mean is over TTA augmentations
                probab_augmentations = F.softmax(mc_logits - mc_logits.max(dim=2, keepdim=True)[0], dim=2)
                if top_augmentations:
                    probab_augmentations = self.get_best_augmentations(probab_augmentations, top_augmentations)
                y_pred = probab_augmentations.mean(dim=0).mean(dim=0).argmax().item()
                statistics = self.get_monte_carlo_statistics(probab_augmentations.mean(dim=1))
                return y_pred, statistics
            statistics = self.get_monte_carlo_statistics(mc_logits)
            return statistics['median'].argmax(dim=1), statistics
        else:
            logits = model(image_tensors)[:,masking] if image_tensors.dim() == 4 else model(image_tensors.unsqueeze(0))[:,masking]
            if TTA:
                probab_augmentations = F.softmax(logits - logits.max(dim=1)[0][:, None], dim=1)
                if top_augmentations:
                    probab_augmentations = self.get_best_augmentations(probab_augmentations, top_augmentations)
                y_pred = probab_augmentations.mean(dim=0).argmax().item()
                return y_pred, None
            return logits.argmax(dim=1), None

    def compute_entropy(self, probabilities: torch.tensor):
        """
        See MEMO.py
        """
        return compute_entropy(probabilities)

    def get_best_augmentations(self, probabilities: torch.tensor, top_k: int):
        """
        See MEMO.py
        """
        return get_best_augmentations(probabilities, top_k)

    def get_test_augmentations(self, input:torch.tensor, augmentations:list, num_augmentations:int, seed_augmentations:int):
        """
        See MEMO.py
        """
        return get_test_augmentations(input, augmentations, num_augmentations, seed_augmentations)

    def retrieve_generated_images(self, img, num_images, clip_model, clip_preprocess, img_to_tensor_pipe, data_path, use_t2i_similarity, t2i_images, i2i_images, threshold):
        """
        See image_generator.py.
        """
        return retrieve_gen_images(img = img,  
                                   num_images = num_images, 
                                   clip_model = clip_model, 
                                   clip_preprocess = clip_preprocess,
                                   img_to_tensor_pipe = img_to_tensor_pipe,
                                   data_path = data_path,
                                   use_t2i_similarity = use_t2i_similarity, 
                                   t2i_images = t2i_images, 
                                   i2i_images = i2i_images, 
                                   threshold = threshold)
    
    def test(self,
             augmentations:list, 
             num_augmentations:int, 
             seed_augmentations:int,
             img_root:str,
             lr_setting:list,
             weights_imagenet = None,
             dataset = "imagenetA",
             imagenetA_masking = None,
             batch_size = 64,
             MEMO = False,
             num_adaptation_steps = 0,
             top_augmentations = 0,
             TTA = False,
             prior_strength = -1,
             verbose = True,
             log_interval = 1,
             MC = None,
             gen_aug_settings = None):
        """
        Main function to test a torchvision model with different test-time adaptation techniques 
        and keep track of the results and the experiment setting. 
        ---
        augmentations: list of torchvision.transforms functions.
        num_augmentations: the number of augmentations to use for each sample to perform test-time adaptation.
        seed_augmentations: seed to reproduce the sampling of augmentations.
        img_root: str path to get a dataset in a torch format.
        lr_setting: list with lr instructions to adapt the model. See "get_optimizer" for more details.
        weights_imagenet: weights_imagenet should have a value in accordance with the parameter 
                          weights of torchvision.models.
        dataset: the name of the dataset to use. Note: this parameter doesn't directly control the data 
                 used, it's only used to use the right masking to map the models' outputs to the right dimensions. 
                 At the moment only Imagenet-A masking is supported.
        MEMO: a boolean to use marginal entropy minimization with one test point
        TTA: a boolean to use test time augmentation
        top_augmentations: if MEMO or TTA are set to True, then values higher than zero select the top_augmentations 
                           with the lowest entropy (highest confidence).
        prior_strength: defines the weight given to pre-trained statistics in BN adaptation. If negative, then no BN
                        adaptation is applied.
        verbose: use loading bar to visualize accuracy and number of batch during testing.
        log_interval: defines after how many batches a new accuracy should be displayed. Default is 1, thus 
                      after each batch a new value is displayed. 
        """
        # check some basic conditions
        assert bool(num_adaptation_steps) == MEMO, "When using MEMO, num_adaptation_steps must be >= 1; otherwise it must be 0."
        if not (MEMO or TTA):
            assert not (num_augmentations or top_augmentations), "If both MEMO and TTA are set to False, then top_augmentations and num_augmentations must be 0"
        assert MEMO or not lr_setting, "If MEMO is False, then lr_setting must be None"
        assert isinstance(prior_strength, (float,int)) , "prior_strength must be either a float or an int"
        assert gen_aug_settings is None or isinstance(gen_aug_settings, dict), "gen_aug_settings must be None or a dict containing settings to retrieve the generated images"
        
        # get the name of the weights used and define the name of the experiment
        weights_name = str(weights_imagenet).split(".")[-1] if weights_imagenet else "MEMO_repo"
        use_MC = True if MC else False
        name_result = f"MEMO_{MEMO}_AdaptSteps_{num_adaptation_steps}_adaptBN_{prior_strength}_TTA_{TTA}_aug_{num_augmentations}_topaug_{top_augmentations}_seed_aug_{seed_augmentations}_weights_{weights_name}_MC_{use_MC}_genAug_{bool(gen_aug_settings)}"
        path_result = os.path.join(self.__exp_path,name_result)
        assert not os.path.exists(path_result),f"MEMO test already exists: {path_result}"

        # in case of using dropout, check if the model is a ResNet50Dropout and the parameters are correct
        if MC:
            assert isinstance(self.__model, ResNet50Dropout), f"To use dropout the model must be a ResNet50Dropout"
            assert MC['num_samples'] > 1, f"To use dropout the number of samples must be greater than 1" 

        # standard ImageNet evaluation preprocessing for ResNet-50 (normalization is applied later)
        transform_loader = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor()
        ])

        # to use after model's update
        normalize_input = T.Compose([
                        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                    ])

        test_loader = get_data(batch_size, img_root, transform = transform_loader, split_data=False)
        model = self.get_model(weights_imagenet, MC)

        # if MEMO is used, create a checkpoint to reload after each model and optimizer update
        if MEMO:
            optimizer = self.get_optimizer(model = model, lr_setting = lr_setting)
            MEMO_checkpoint_path = os.path.join(self.__exp_path,"checkpoint.pth")
            torch.save({
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, MEMO_checkpoint_path)
            MEMO_checkpoint = torch.load(MEMO_checkpoint_path)

        if dataset == "imagenetA":
            imagenetA_masking = self.get_imagenetA_masking(imagenetA_masking)
        
        if gen_aug_settings:
            clip_model, clip_preprocess = clip.load(gen_aug_settings["clip_img_encoder"])
        
        if prior_strength < 0:
            torch.nn.BatchNorm2d.prior_strength = 1
        else:
            torch.nn.BatchNorm2d.prior_strength = prior_strength / (prior_strength + 1)
            torch.nn.BatchNorm2d.forward = adaptive_bn_forward

        # Initialize a dictionary to store accumulated time for each step
        time_dict = {
            "MEMO_update": 0.0,
            "get_augmentations": 0.0,
            "confidence_selection": 0.0,
            "get_prediction": 0.0,
            "get_gen_images" : 0.0,
            "total_time": 0.0
        }

        samples = 0.0
        cumulative_accuracy = 0.0

        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(self.__device), targets.to(self.__device)
            if MEMO or TTA:
                for input, target in zip(inputs, targets):
                    if MEMO:
                        model.load_state_dict(MEMO_checkpoint['model'])
                        model.eval()
                        optimizer.load_state_dict(MEMO_checkpoint['optimizer'])

                    # get normalized augmentations
                    start_time_augmentations = time.time()
                    test_augmentations = self.get_test_augmentations(input, augmentations, num_augmentations, seed_augmentations)
                    end_time_augmentations = time.time()
                    time_dict["get_augmentations"] += (end_time_augmentations - start_time_augmentations)

                    # retrieve generated images
                    if gen_aug_settings:
                        start_time_gen_augmentations = time.time()
                        retrieved_gen_images = self.retrieve_generated_images(img = input, 
                                                                              num_images = gen_aug_settings["num_img"], 
                                                                              clip_model = clip_model, 
                                                                              clip_preprocess = clip_preprocess,
                                                                              img_to_tensor_pipe = transform_loader, 
                                                                              data_path = gen_aug_settings["gen_data_path"], 
                                                                              use_t2i_similarity = gen_aug_settings["use_t2i_similarity"], 
                                                                              t2i_images = gen_aug_settings["t2i_img"], 
                                                                              i2i_images = gen_aug_settings["i2i_img"],
                                                                              threshold = gen_aug_settings["threshold"])
                        if len(retrieved_gen_images):
                            retrieved_gen_images = retrieved_gen_images.to(self.__device)
                        end_time_gen_augmentations = time.time()
                        time_dict["get_gen_images"] += (end_time_gen_augmentations - start_time_gen_augmentations)

                    test_augmentations = test_augmentations.to(self.__device)

                    for _ in range(num_adaptation_steps):
                        logits = model(test_augmentations)

                        # apply imagenetA masking
                        if dataset == "imagenetA":
                            logits = logits[:, imagenetA_masking]
                        # compute stable softmax
                        probab_augmentations = F.softmax(logits - logits.max(dim=1)[0][:, None], dim=1)

                        # confidence selection for augmentations
                        if top_augmentations:
                            start_time_confidence_selection = time.time()
                            probab_augmentations = self.get_best_augmentations(probab_augmentations, top_augmentations)
                            end_time_confidence_selection = time.time()
                            time_dict["confidence_selection"] += (end_time_confidence_selection - start_time_confidence_selection)

                        if gen_aug_settings:
                            # use the retrieved tensor, not the retrieve_gen_images function
                            if len(retrieved_gen_images):
                                gen_images_logits = model(retrieved_gen_images)
                                if dataset == "imagenetA":
                                    gen_images_logits = gen_images_logits[:, imagenetA_masking]
                                probab_gen_augmentations = F.softmax(gen_images_logits - gen_images_logits.max(dim=1)[0][:, None], dim=1)
                                probab_augmentations = torch.cat([probab_augmentations,probab_gen_augmentations],dim=0)

                        if MEMO:
                            start_time_memo_update = time.time()
                            marginal_output_distribution = torch.mean(probab_augmentations, dim=0)
                            marginal_loss = self.compute_entropy(marginal_output_distribution)
                            marginal_loss.backward()
                            optimizer.step()
                            optimizer.zero_grad()
                            end_time_memo_update = time.time()
                            time_dict["MEMO_update"] += (end_time_memo_update - start_time_memo_update)

                    start_time_prediction = time.time()
                    with torch.no_grad():
                        if TTA:
                            # statistics:
                            # dictionary containing statistics resulting from the application of monte carlo dropout
                            # look at get_monte_carlo_statistics() for more details
                            y_pred, statistics = self.get_prediction(test_augmentations, model, imagenetA_masking, TTA, top_augmentations, MC=MC)
                        else:
                            input = normalize_input(input)
                            y_pred, statistics = self.get_prediction(input, model, imagenetA_masking, MC=MC)
                        cumulative_accuracy += int(target == y_pred)
                    end_time_prediction = time.time()
                    time_dict["get_prediction"] += (end_time_prediction - start_time_prediction)
            else:
                start_time_prediction = time.time()
                with torch.no_grad():
                    inputs = normalize_input(inputs)
                    y_pred, _ = self.get_prediction(inputs, model, imagenetA_masking, MC=MC)
                    correct_predictions = (targets == y_pred).sum().item()
                    cumulative_accuracy += correct_predictions
                end_time_prediction = time.time()
                time_dict["get_prediction"] += (end_time_prediction - start_time_prediction)

            samples += inputs.shape[0]

            if verbose and batch_idx % log_interval == 0:
                current_accuracy = cumulative_accuracy / samples * 100
                print(f'Batch {batch_idx}/{len(test_loader)}, Accuracy: {current_accuracy:.2f}%', end='\r')

        accuracy = cumulative_accuracy / samples * 100
        time_dict["total_time"] += sum(time_dict.values())

        # save information to reproduce the experiment
        self.save_result(accuracy = accuracy,
                         path_result = path_result,
                         seed_augmentations = seed_augmentations,
                         num_augmentations = num_augmentations,
                         augmentations = augmentations,
                         top_augmentations = top_augmentations,
                         MEMO = MEMO,
                         num_adaptation_steps = num_adaptation_steps,
                         lr_setting = lr_setting,
                         weights = weights_name,
                         prior_strength = prior_strength,
                         use_MC = use_MC,
                         time_test = time_dict,
                         gen_aug_settings = gen_aug_settings)

        return accuracy
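
The MEMO update performed inside `test` can be summarized with a toy example. This is a minimal sketch under stated assumptions: a small linear layer stands in for the ResNet-50, random vectors stand in for the augmented views of one test image, and `marginal_entropy` is a hypothetical helper written here for illustration (the notebook uses its own `compute_entropy`).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def marginal_entropy(logits):
    """Entropy of the prediction averaged over augmentations.

    logits: tensor of shape [num_augmentations, num_classes]
    """
    probs = F.softmax(logits, dim=1)
    marginal = probs.mean(dim=0)  # marginal output distribution
    return -(marginal * torch.log(marginal + 1e-8)).sum()

torch.manual_seed(0)
model = nn.Linear(8, 5)                 # toy stand-in for the ResNet-50
optimizer = torch.optim.SGD(model.parameters(), lr=0.00025)
augmented_views = torch.randn(16, 8)    # stand-in for 16 augmentations of one image

# one adaptation step: minimize the entropy of the marginal distribution
loss_before = marginal_entropy(model(augmented_views))
loss_before.backward()
optimizer.step()
optimizer.zero_grad()

# prediction after adaptation, without tracking gradients
with torch.no_grad():
    loss_after = marginal_entropy(model(augmented_views))
    y_pred = F.softmax(model(augmented_views), dim=1).mean(dim=0).argmax().item()
```

In the real loop the model and optimizer states are reloaded from the checkpoint before each test image, so every sample is adapted starting from the same pre-trained weights.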

To run the solution, simply execute the following code sections.

In [None]:
imagenet_a_path = "imagenet-a"

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# 18 AugMix transforms: the (severity, mixture_width) pairs (3,3), (2,2), (4,4) repeated 6 times
augmix_augmentations = [
    T.AugMix(severity=s, mixture_width=w, chain_depth=3, alpha=1.0)
    for _ in range(6)
    for s, w in [(3, 3), (2, 2), (4, 4)]
]

### Resnet50

In [None]:
exp_path_a = "INSERT THE PATH TO SAVE THE EXPERIMENTS DETAILS"

In [None]:
MC = {
    "dropout_rate": 0.2,
    "num_samples": 10,
    "use_dropout": False,
    "dropout_positions": dropout_positions
}
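
To make the role of `dropout_rate` and `num_samples` concrete, here is a minimal sketch of Monte Carlo dropout on a toy network. The small `nn.Sequential` model is an assumption made for illustration and is not the `ResNet50Dropout` used in the notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(8, 16),
    nn.ReLU(),
    nn.Dropout(p=0.2),   # plays the role of "dropout_rate"
    nn.Linear(16, 4),
)
x = torch.randn(3, 8)    # batch of 3 toy inputs
num_samples = 10         # plays the role of "num_samples"

model.train()            # keeps dropout active at inference time
with torch.no_grad():
    # stack the stochastic forward passes: [num_samples, batch, classes]
    mc_logits = torch.stack([model(x) for _ in range(num_samples)], dim=0)
model.eval()

mean_logits = mc_logits.mean(dim=0)            # [batch, classes]
std_logits = mc_logits.std(dim=0)              # [batch, classes], spread across samples
votes = mc_logits.argmax(dim=-1)               # [num_samples, batch]
mode_pred = torch.mode(votes, dim=0).values    # [batch], most voted class per input
```

Note that `model.train()` also switches batch-normalization layers to batch statistics, so in a network that contains them it may be preferable to enable only the dropout modules.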

In [None]:
tester_resnet50 = Tester(
    model = ResNet50Dropout() if MC['use_dropout'] else models.resnet50,
    optimizer = torch.optim.SGD,
    exp_path = exp_path_a,
    device = device
)

In [None]:
# you can assign different learning rates to different layers of the model
#lr_setting = [{
#    "classifier" : [["fc.weight", "fc.bias"], 0.00025]
#}, 0]

# global learning rate
lr_setting_sgd = [0.00025] # setting used in MEMO paper for SGD
lr_setting_adam = [0.0001] # setting used in MEMO paper for ADAM
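
The layer-wise variant of `lr_setting` relies on PyTorch parameter groups. The sketch below shows the idea on a toy two-layer model; the module names `backbone` and `fc` are assumptions chosen to mirror the commented example and are not taken from the ResNet-50.

```python
import torch
import torch.nn as nn

# toy model whose parameter names are "backbone.*" and "fc.*"
model = nn.Sequential()
model.add_module("backbone", nn.Linear(8, 8))
model.add_module("fc", nn.Linear(8, 2))

head_names = ["fc.weight", "fc.bias"]
head_params = [p for n, p in model.named_parameters() if n in head_names]
other_params = [p for n, p in model.named_parameters() if n not in head_names]

optimizer = torch.optim.SGD(
    [
        {"params": head_params, "lr": 0.00025},  # group-specific learning rate
        {"params": other_params},                # falls back to the default lr below
    ],
    lr=0.0,          # default lr: 0 leaves the remaining parameters unchanged by SGD steps
    weight_decay=0,
)

group_lrs = [group["lr"] for group in optimizer.param_groups]
```

With a default learning rate of 0 and no weight decay, only the classifier is updated, which is a cheap way to test whether adapting the last layer alone is sufficient.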

In [None]:
imagenetV1_weights = models.ResNet50_Weights.IMAGENET1K_V1 # MEMO paper used these weights
imagenetV2_weights = models.ResNet50_Weights.IMAGENET1K_V2

In [None]:
gen_aug_settings = {
    "clip_img_encoder" : "ViT-L/14",
    "num_img" : 40,
    "gen_data_path" : imgenetA_gen_path,
    "use_t2i_similarity" : True,
    "t2i_img" : True,
    "i2i_img" : True,
    "threshold" : 0.45
}

In [None]:
tester_resnet50.test(
     augmentations = augmix_augmentations,
     num_augmentations = 16,
     seed_augmentations = 42,
     batch_size = 64,
     img_root = imagenet_a_path,
     imagenetA_masking = imagenetA_masking,
     dataset = "imagenetA",
     num_adaptation_steps = 4,
     MEMO = True,
     lr_setting = lr_setting_sgd,
     top_augmentations = 8, 
     weights_imagenet = imagenetV1_weights,
     prior_strength = 16,
     TTA = True,
     MC = None,
     gen_aug_settings = gen_aug_settings
)
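
For readers who want to see what `imagenetA_masking`, `TTA`, and `top_augmentations` do to the model outputs, the following sketch reproduces the idea on random logits. It is only an illustration: the indices are a made-up subset rather than the real ImageNet-A mapping, and the entropy-based selection is our assumption of what confidence selection does, not a copy of `get_best_augmentations`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits_1k = torch.randn(16, 1000)          # stand-in for model outputs on 16 augmentations
indices_in_1k = list(range(0, 1000, 5))    # hypothetical 200 indices, NOT the real ImageNet-A mapping

# masking: keep only the logits of the classes present in the target dataset
logits = logits_1k[:, indices_in_1k]       # [16, 200]
probs = F.softmax(logits, dim=1)

# confidence selection: keep the augmentations with the lowest entropy
entropy = -(probs * torch.log(probs + 1e-8)).sum(dim=1)   # one value per augmentation
top_k = 8
keep = entropy.topk(top_k, largest=False).indices

# test-time augmentation: average the selected distributions and take the argmax
y_pred = probs[keep].mean(dim=0).argmax().item()
```

The predicted index refers to a position in the masked 200-class output, so it can be compared directly with ImageNet-A labels without any further remapping.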

## **Results**

<table>
  <thead>
    <tr>
      <th rowspan="1">Experiment</th>
      <th rowspan="1">Dataset</th>
      <th rowspan="1">Base Model</th>
      <th rowspan="1">Weights</th>
      <th rowspan="1" colspan="2">Optimizer</th>
      <th rowspan="1">Optimization Steps</th>
      <th rowspan="1" colspan="3">Augmentations</th>
      <th rowspan="1">Batch Size</th>
      <th rowspan="1">MEMO</th>
      <th rowspan="1">Confidence Selection</th>
      <th rowspan="1" colspan="2">BN</th>
      <th rowspan="1">TTA</th>
      <th rowspan="1" colspan="3">MC</th>
      <th rowspan="1">Efficient DiffTPT</th>
      <th rowspan="1">Accuracy</th>
      <th rowspan="1">Inference Time</th>
    </tr>
    <tr>
      <th>Nr.</th>
      <th></th>
      <th></th>
      <th></th>
      <th>Type</th>
      <th>LR</th>
      <th>Nr.</th>
      <th>Type</th>
      <th>Number</th>
      <th>Seed</th>
      <th></th>
      <th></th>
      <th></th>
      <th></th>
      <th>Prior Strength</th>
      <th></th>
      <th></th>
      <th>Dropout rate</th>
      <th>Nr. Samples</th>
      <th></th>
      <th>%</th>
      <th>hh:mm:ss</th>
    </tr>
  </thead>
  <tbody align="center">
    <tr>
      <th>1</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>-</td>
      <td>-</td>
      <td>-</td>
      <td>-</td>
      <td>-</td>
      <td>-</td>
      <td>64</td>
      <td>False</td>
      <td>0</td>
      <td>False</td>
      <td>-</td>
      <td>False</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>False</td>
      <td>0.026</td>
      <td>00:00:20</td>
    </tr>
    <tr>
      <th>2</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>-</td>
      <td>-</td>
      <td>-</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>False</td>
      <td>8</td>
      <td>False</td>
      <td>-</td>
      <td>True</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>False</td>
      <td>0.253</td>
      <td>00:26:20</td>
    </tr>
    <tr>
      <th>3</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>-</td>
      <td>-</td>
      <td>-</td>
      <td>-</td>
      <td>-</td>
      <td>-</td>
      <td>64</td>
      <td>False</td>
      <td>0</td>
      <td>False</td>
      <td>-</td>
      <td>False</td>
      <td>True</td>
      <td>0.20</td>
      <td>10</td>
      <td>False</td>
      <td>1.613</td>
      <td>00:03:48</td>
    </tr>
    <tr>
      <th>4</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>-</td>
      <td>-</td>
      <td>-</td>
      <td>-</td>
      <td>-</td>
      <td>-</td>
      <td>64</td>
      <td>False</td>
      <td>0</td>
      <td>True</td>
      <td>16</td>
      <td>False</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>False</td>
      <td>0.4</td>
      <td>00:00:23</td>
    </tr>
    <tr>
      <th>5</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>0.00025</td>
      <td>1</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>False</td>
      <td>-</td>
      <td>False</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>False</td>
      <td>0.16</td>
      <td>00:33:20</td>
    </tr>
    <tr>
      <th>6</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>0.00025</td>
      <td>1</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>False</td>
      <td>0.853</td>
      <td>00:38:44</td>
    </tr>
    <tr>
      <th>7</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>0.00025</td>
      <td>1</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>True</td>
      <td>0.20</td>
      <td>10</td>
      <td>False</td>
      <td>1.053</td>
      <td>01:43:29</td>
    </tr>
    <tr>
      <th>8</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>0.00025</td>
      <td>4</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>False</td>
      <td>-</td>
      <td>False</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>False</td>
      <td>0.213</td>
      <td>00:42:03</td>
    </tr>
    <tr>
      <th>9</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>0.00025</td>
      <td>4</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>False</td>
      <td>0.853</td>
      <td>00:47:37</td>
    </tr>
    <tr>
      <th>10</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>0.00025</td>
      <td>4</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>True</td>
      <td>0.20</td>
      <td>10</td>
      <td>False</td>
      <td>0.95</td>
      <td>01:49:42</td>
    </tr>
    <tr>
      <th>11</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>ADAM</td>
      <td>0.0001</td>
      <td>1</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>False</td>
      <td>-</td>
      <td>False</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>False</td>
      <td>0.12</td>
      <td>00:31:45</td>
    </tr>
    <tr>
      <th>12</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>ADAM</td>
      <td>0.0001</td>
      <td>1</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>False</td>
      <td>0.826</td>
      <td>00:37:37</td>
    </tr>
    <tr>
      <th>13</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>ADAM</td>
      <td>0.0001</td>
      <td>1</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>True</td>
      <td>0.20</td>
      <td>10</td>
      <td>False</td>
      <td>1.506</td>
      <td>01:39:47</td>
    </tr>
    <tr>
      <th>14</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>ADAM</td>
      <td>0.0001</td>
      <td>4</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>False</td>
      <td>-</td>
      <td>False</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>False</td>
      <td>0.213</td>
      <td>00:41:19</td>
    </tr>
    <tr>
      <th>15</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>ADAM</td>
      <td>0.0001</td>
      <td>4</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>False</td>
      <td>0.826</td>
      <td>00:46:55</td>
    </tr>
    <tr>
      <th>16</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>ADAM</td>
      <td>0.0001</td>
      <td>4</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>True</td>
      <td>0.20</td>
      <td>10</td>
      <td>False</td>
      <td>1.04</td>
      <td>01:44:57</td>
    </tr>  
    <tr>
      <th>17</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>ADAM</td>
      <td>0.0001</td>
      <td>4</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>False</td>
      <td>0</td>
      <td>0</td>
      <td><strong>True</strong></td>
      <td><strong>9.5</strong></td>
      <td>03:00:00</td>
    </tr>
    <tr>
      <th>18</th>
      <td>ImageNet-A</td>
      <td>ViT-B16</td>
      <td>Imagenet_1K_V1</td>
      <td>-</td>
      <td>-</td>
      <td>-</td>
      <td>-</td>
      <td>-</td>
      <td>42</td>
      <td>64</td>
      <td>False</td>
      <td>-</td>
      <td>False</td>
      <td>-</td>
      <td>False</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>False</td>
      <td>20.75</td>
      <td>00:01:17</td>
    </tr>
    <tr>
      <th>19</th>
      <td>ImageNet-A</td>
      <td>ViT-B16</td>
      <td>Imagenet_1K_V1</td>
      <td>ADAM</td>
      <td>0.0001</td>
      <td>4</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>False</td>
      <td>0</td>
      <td>0</td>
      <td><strong>True</strong></td>
      <td><strong>36.5</strong></td>
      <td>04:20:00</td>
    </tr>
  </tbody>
</table>


Note:
- The hyperparameters used for Efficient DiffTPT are the same as those used in the code cell above, i.e.
    - `clip_img_encoder` : "ViT-L/14"
    - `num_img` : 40
    - `use_t2i_similarity` : True
    - `t2i_img` : True
    - `i2i_img` : True
    - `threshold` : 0.45
- Due to time constraints, the experiments with Efficient DiffTPT were performed on a subset of 1280 samples of the dataset. We assume, without having verified it, that using the whole dataset would not change the final result much.
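
To give a rough sense of how much sampling noise a 1280-image subset introduces, the sketch below computes a normal-approximation confidence interval for an accuracy estimate. It assumes the subset is a uniform random sample of the dataset, which is not stated above, and the helper name `accuracy_confidence_interval` is hypothetical and not part of the notebook.

```python
import math

def accuracy_confidence_interval(acc_percent, n_samples, z=1.96):
    """Normal-approximation (Wald) interval for an accuracy given in percent."""
    p = acc_percent / 100.0
    # Standard error of a binomial proportion estimated from n_samples draws
    se = math.sqrt(p * (1.0 - p) / n_samples)
    half_width = z * se * 100.0  # convert back to percentage points
    return acc_percent - half_width, acc_percent + half_width

# Accuracies reported in the table for the two Efficient DiffTPT runs
for name, acc in [("ResNet50", 9.5), ("ViT-B16", 36.5)]:
    low, high = accuracy_confidence_interval(acc, n_samples=1280)
    print(f"{name}: {acc:.1f}% (95% CI roughly {low:.1f}% to {high:.1f}%)")
```

Under these assumptions the half-widths are roughly 1.6 and 2.6 percentage points, which is small compared with the gaps to the corresponding baselines in the table.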

## **Discussion**

### Previously Existing Methods

All implemented methods improve the accuracy score over the baseline. Notably, we were able to reproduce the result of the original MEMO paper, obtaining roughly 0.9% accuracy. Additionally, as expected, using MC Dropout further increases performance regardless of the optimizer and the other methods used. The only oddity is that the effect of MC does not seem to be additive; rather, it interacts with the other methods. Indeed, leaving Efficient DiffTPT aside, the highest accuracy (1.6%) is reached when MC is used without any other method. The second-best result, on the other hand, is obtained when all our methods are used with the `ADAM` optimizer instead of `SGD`. The improvement therefore cannot be attributed directly to MC alone, since it also changes with the number of `optim_steps` (see the experiments with 4 steps) and with the optimizer. Regardless, `MC` generally seems to increase performance, at the cost of a much higher `inference time`.
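
As a reminder of what MC Dropout does at inference time, here is a minimal sketch rather than the implementation used in the experiments. It keeps the dropout layers stochastic, runs several forward passes, and averages the softmax outputs. The toy classifier, the helper name `mc_dropout_predict`, and the reading of the table values `0.20` and `10` as the dropout rate and number of passes are assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def mc_dropout_predict(model, x, num_samples=10):
    """Average softmax probabilities over several stochastic forward passes."""
    model.eval()  # keep BatchNorm and similar layers in inference mode
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()  # re-enable only the dropout layers
    with torch.no_grad():
        probs = torch.stack(
            [F.softmax(model(x), dim=-1) for _ in range(num_samples)]
        )
    model.eval()  # restore fully deterministic behaviour
    return probs.mean(dim=0)

# Toy classifier used only for illustration (not the notebook's ResNet50)
toy_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(3 * 8 * 8, 32),
    nn.ReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(32, 5),
)
batch = torch.randn(4, 3, 8, 8)
mean_probs = mc_dropout_predict(toy_model, batch, num_samples=10)
print(mean_probs.shape, mean_probs.sum(dim=-1))
```

Each prediction costs `num_samples` forward passes, which is consistent with the longer inference times reported for the MC runs.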

### Efficient DiffTPT

In [None]:
clip_image_encoder = "ViT-L/14"
clip_model, clip_preprocess = clip.load(clip_image_encoder)

retrieved_images = retrieve_gen_images(img = "put one of the dataset images here",
                                       num_images = 40,
                                       data_path = imgenetA_gen_path,
                                       clip_model = clip_model,
                                       clip_preprocess = clip_preprocess,
                                       t2i_images = True,
                                       i2i_images = False,
                                       use_t2i_similarity = True,
                                       threshold = 0.45)

In [None]:
def create_image_grid(images, grid_width, save_path=None, cell_size=(100, 100)):
    """
    Create a grid of images from a list of PIL images.

    Args:
        images (list of PIL.Image): List of PIL images to arrange in a grid.
        grid_width (int): Number of columns in the grid.
        save_path (str, optional): Where to save the grid; skipped if None or empty.
        cell_size (tuple): Size of each cell in the grid (width, height).

    Returns:
        PIL.Image: An image containing the grid of images, or None if `images` is empty.
    """
    # Nothing to draw if no images were retrieved
    if len(images) == 0:
        print("No images")
        return None

    # Resize images to the specified cell size
    resized_images = [img.resize(cell_size) for img in images]
    
    # Calculate grid dimensions
    grid_height = math.ceil(len(images) / grid_width)  # Number of rows needed
    grid_img_width = cell_size[0] * grid_width
    grid_img_height = cell_size[1] * grid_height

    # Create a blank canvas for the grid
    grid_img = Image.new('RGB', (grid_img_width, grid_img_height), (255, 255, 255))  # White background

    # Paste images into the grid
    for i, img in enumerate(resized_images):
        row = i // grid_width
        col = i % grid_width
        x = col * cell_size[0]
        y = row * cell_size[1]
        grid_img.paste(img, (x, y))

    if save_path:
        try:
            grid_img.save(save_path)
            print(f"Grid image saved to {save_path}")
        except Exception as e:
            print(f"Error saving the image: {e}")

    return grid_img

In [None]:
create_image_grid(retrieved_images, 
                  save_path = "PUT YOUR PATH TO SAVE THE GRID IMAGE",
                  grid_width = 4)

<p align="center">
  <img src="images/generated_images/dragonfly_i2i.png" width="500" height="300">  
</p>

<p align="center">
  <img src="images/generated_images/jelly_fish_i2i.png" width="500" height="300">  
</p>

<p align="center">
  <img src="images/generated_images/stingray_i2i.png" width="500" height="300">  
</p>

<p align="center">
  <img src="images/generated_images/goldfinch_i2i.png" width="500" height="300">  
</p>

- **Trade-off Between Data Consistency and Data Scarcity**: Lowering the `threshold` retrieves more generated augmentations, but at the cost of admitting augmentations that may belong to other classes, potentially hindering the adaptation phase. This is especially problematic if the original image is of particularly low quality or if the class object within it is difficult to distinguish. In such cases, either only a few augmentations are retrieved, or many are retrieved but a large share of them may be unrelated.  
  - **Example**: The first and second images show pictures of `dragonflies` and `jellyfish` that were retrieved in large numbers without any spurious samples. In contrast, the third image shows 40 pictures of `stingrays` (all with cosine similarity > `threshold`), which unfortunately include two extraneous augmentations, highlighted by the red circle. Finally, the last image shows the few images (6 out of 40) that were retrieved for an image of a `goldfinch`.

- **Hyper-parameters**:  
  - **`threshold`**: Due to time constraints, it was not possible to experiment with different values. This parameter is critical, as it determines the aforementioned trade-off between data consistency and scarcity.  
  - **`i2i_images` and `use_t2i_similarity`**: Since no ablation study was performed, it remains unclear to what extent using images generated with the `i2i` Stable Diffusion pipeline and incorporating `t2i_similarity` may improve or possibly reduce performance.

- **Efficient DiffTPT with MC**: In this study we did not test Efficient DiffTPT combined with MC Dropout, owing both to time constraints and to the very long runtime this setting would require.
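
The consistency versus scarcity trade-off controlled by `threshold` can be illustrated with a small sketch of cosine-similarity filtering. This is not the `retrieve_gen_images` implementation: the helper `filter_by_similarity` is hypothetical, and the hand-made 2-D vectors merely stand in for CLIP embeddings.

```python
import torch
import torch.nn.functional as F

def filter_by_similarity(query_emb, candidate_embs, threshold, num_images):
    """Return indices of the most similar candidates whose cosine similarity
    with the query exceeds `threshold`, keeping at most `num_images`."""
    sims = F.cosine_similarity(query_emb.unsqueeze(0), candidate_embs, dim=-1)
    order = torch.argsort(sims, descending=True)
    kept = [int(i) for i in order if sims[i] > threshold]
    return kept[:num_images], sims

# Hand-made 2-D embeddings standing in for CLIP features (illustration only)
query = torch.tensor([1.0, 0.0])
candidates = torch.tensor([
    [1.0, 0.1],   # very close to the query
    [1.0, 1.0],   # moderately close (cosine about 0.71)
    [0.2, 1.0],   # mostly unrelated (cosine about 0.20)
    [-1.0, 0.0],  # opposite direction
])

for threshold in (0.45, 0.10):
    kept, _ = filter_by_similarity(query, candidates, threshold, num_images=40)
    print(f"threshold={threshold}: kept candidates {kept}")
```

In this toy setting, lowering the threshold from 0.45 to 0.10 admits the third candidate, which points mostly away from the query. The same mechanism explains how a lower threshold yields more augmentations at the risk of including unrelated ones.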

## **Conclusion**

Possible improvements on Efficient DiffTPT:
- Consider fine-tuned models for image generation.
- An LLM can guide the creation of prompts, making them more realistic and tailored to a given style, which avoids wasting resources without sacrificing variability.
- Generation could be guided by images retrieved from the Internet using the label, along with random prompts. However, it is not clear whether such images would be useful: they may help in some cases (especially under a very large domain shift), make little difference in others, or even hurt performance. Further research is needed.
- More powerful prompt engineering and models need to be explored for fine-grained domain shift, where the images need to be as accurate as possible. In such a use case, the original DiffTPT might be preferable.
- Is it better to retrieve images using image embeddings, the text embeddings of the prompts that generated the images, or a mix of both?

## **Bibliography**

1. **Zhang, Marvin and Levine, Sergey and Finn, Chelsea.** "MEMO: Test Time Robustness via Adaptation and Augmentation." Advances in Neural Information Processing Systems, Vol. 35, 2022, pp. 38629-38642. [https://arxiv.org/abs/2110.09506](https://arxiv.org/abs/2110.09506).

2. **Lyzhov, Alexander and Molchanova, Yuliya and Ashukha, Arsenii and Molchanov, Dmitry and Vetrov, Dmitry.** "Greedy Policy Search: A Simple Baseline for Learnable Test-Time Augmentation." Proceedings of Machine Learning Research, Vol. 124, 2020, pp. 1308-1317. [https://proceedings.mlr.press/v124/lyzhov20a.html](https://proceedings.mlr.press/v124/lyzhov20a.html).

3. **Schneider, Steffen and Rusak, Evgenia and Eck, Luisa and Bringmann, Oliver and Brendel, Wieland and Bethge, Matthias.** "Improving robustness against common corruptions by covariate shift adaptation." Advances in Neural Information Processing Systems, Vol. 33, 2020, pp. 11539-11551. [https://proceedings.neurips.cc/paper_files/paper/2020/file/85690f81aadc1749175c187784afc9ee-Paper.pdf](https://proceedings.neurips.cc/paper_files/paper/2020/file/85690f81aadc1749175c187784afc9ee-Paper.pdf).

4. **Gal, Yarin and Ghahramani, Zoubin.** "Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning." Proceedings of The 33rd International Conference on Machine Learning, Vol. 48, 2016, pp. 1050-1059. [https://proceedings.mlr.press/v48/gal16.html](https://proceedings.mlr.press/v48/gal16.html).

5. **Feng, Chun-Mei and Yu, Kai and Liu, Yong and Khan, Salman and Zuo, Wangmeng.** "Diverse Data Augmentation with Diffusions for Effective Test-time Prompt Tuning." 2023 IEEE/CVF International Conference on Computer Vision (ICCV), 2023, pp. 2704-2714. [https://arxiv.org/abs/2308.06038](https://arxiv.org/abs/2308.06038).

6. [Ollama library](https://ollama.com/)



