# **Deep Learning - Test-Time Adaptation**
---
**University of Trento, Academic Year 2023/2024**

---
#### **Group 26**
> <a href="https://github.com/giuseppecurci">Giuseppe Curci</a> \
> 243049

> <a href="https://github.com/andy295">Andrea Cristiano</a> \
> 229370
---
---

In [None]:
!pip install ollama # Python client; if the Ollama server is not available, install it
                    # by executing the intall_and_run_ollama.sh script
!pip install diffusers
!pip install bing_image_downloader
!pip install git+https://github.com/openai/CLIP.git
!pip install ttach

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

import torchvision
import torchvision.transforms as T
import torchvision.models as models
from torchvision.transforms import Compose, Normalize, ToTensor

import matplotlib
from matplotlib.lines import Line2D
from matplotlib import pyplot as plt

from scipy import stats
from scipy.ndimage import zoom

from typing import Callable, Dict, List, Optional, Tuple

from PIL import Image
from tqdm import tqdm
from io import BytesIO
from pathlib import Path
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, DPMSolverMultistepScheduler

import boto3 # read and write for AWS buckets
import clip
import cv2
import gc
import json
import math
import numpy as np
import ollama
import os
import random
import time
import ttach as tta

import warnings
warnings.filterwarnings('ignore')

## **Introduction**

### **Domain Shift**
Domain shift refers to the change in the data distribution between the training phase and the testing phase of a machine learning model. In other words, the data that the model encounters during deployment (inference) differs from the data it was trained on.

Machine learning models are typically trained on a specific dataset. The fundamental assumption is that the training data distribution is as representative as possible of what the models will encounter during deployment. In the presence of domain shift, the model may not generalize well, leading to a degradation in performance.

In real-world applications, it is easy to identify cases where data distributions may vary, and the causes can be numerous, such as changes in the environment, noise in the sensors used to acquire information, etc.

A model trained on images depicting a scene during the spring season may perform poorly when tested on images of the same scene but taken during the winter season, even if the scene's content is the same.

### **Test-Time Adaptation (TTA)**
Test-Time Adaptation refers to techniques that allow a previously trained model to adapt to a new data distribution during inference (testing), without the need for a new training phase.

The goal is to improve the model's performance on the shifted domain by adapting its parameters or predictions based on the data encountered during testing.

## **Pipeline**

(Ideally, we should create a diagram that explains what we are doing; we can also do it later. For now, you can either just explain it in words or leave it, and we will do it once we have the diagram.)

## **Utils**



### **Recover dataset**

In [None]:
# Class that interacts with images stored in an Amazon S3 bucket.
# It allows to load and preprocess images on-the-fly during training or inference.
class S3ImageFolder(Dataset):
    def __init__(self, root, transform=None):
        self.s3_bucket = "deeplearning2024-datasets" # name of the bucket
        self.s3_region = "eu-west-1" # Ireland
        self.s3_client = boto3.client("s3", region_name=self.s3_region, verify=True)
        self.transform = transform

        # Get list of objects in the bucket
        response = self.s3_client.list_objects_v2(Bucket=self.s3_bucket, Prefix=root)
        objects = response.get("Contents", [])
        while response.get("NextContinuationToken"):
            response = self.s3_client.list_objects_v2(
                Bucket=self.s3_bucket,
                Prefix=root,
                ContinuationToken=response["NextContinuationToken"]
            )
            objects.extend(response.get("Contents", []))

        # Iterate and keep valid files only
        self.instances = []
        for ds_idx, item in enumerate(objects):
            key = item["Key"]
            path = Path(key)

            # Check if file is valid
            if path.suffix.lower() not in (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"):
                continue

            # Get label
            label = path.parent.name

            # Keep track of valid instances
            self.instances.append((label, key))

        # Sort classes in alphabetical order (as in ImageFolder)
        self.classes = sorted(set(label for label, _ in self.instances))
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx):
        try:
            label, key = self.instances[idx]

            # Download the image from S3 into an in-memory buffer
            img_bytes = BytesIO()
            self.s3_client.download_fileobj(Bucket=self.s3_bucket, Key=key, Fileobj=img_bytes)
            img_bytes.seek(0)  # rewind the buffer before decoding it

            # Open image with PIL
            img = Image.open(img_bytes).convert("RGB")

            # Apply transformations if any
            if self.transform is not None:
                img = self.transform(img)
        except Exception as e:
            raise RuntimeError(f"Error loading image at index {idx}: {str(e)}")

        return img, self.class_to_idx[label]

# Function to create DataLoaders for training and evaluating models.
# Loads the dataset from the S3 bucket and optionally splits it into training,
# validation, and test sets. It then returns PyTorch DataLoader objects for these datasets.
def get_data(batch_size, img_root, seed = None, split_data = False, transform = None):

    # Load data
    data = S3ImageFolder(root=img_root, transform=transform)

    if split_data:
        # Create train/validation/test splits (~80/10/10)
        num_samples = len(data)
        training_samples = int(num_samples * 0.8 + 1)
        val_samples = int(num_samples * 0.1)
        test_samples = num_samples - training_samples - val_samples

        # torch.manual_seed does not accept None, so only seed when a seed is given
        if seed is not None:
            torch.manual_seed(seed)
        training_data, val_data, test_data = torch.utils.data.random_split(data, [training_samples, val_samples, test_samples])

        # Initialize dataloaders
        train_loader = torch.utils.data.DataLoader(training_data, batch_size, shuffle=True, num_workers=4)
        val_loader = torch.utils.data.DataLoader(val_data, batch_size, shuffle=False, num_workers=4)
        test_loader = torch.utils.data.DataLoader(test_data, batch_size, shuffle=False, num_workers=4)

        return train_loader, val_loader, test_loader

    data_loader = torch.utils.data.DataLoader(data, batch_size, shuffle=False, num_workers=4)
    return data_loader

### **ImageNet-A masking**

The **imagenetA_masking.json** file provides a mask that maps the standard 1000-class ImageNet output indices, used by the pre-trained models in PyTorch's torchvision library, to the 200 classes of the ImageNet-A dataset.

Each key in the file corresponds to an index in the standard 1000-class ImageNet output vector. The value associated with each key indicates whether that index should be kept when reducing the 1000-class output to the smaller set of ImageNet-A classes:

- a value of -1 indicates that the corresponding index in the 1000-class output should be ignored for ImageNet-A;
- a non-negative integer value indicates that the corresponding index in the 1000-class output should be included in the subset of outputs for ImageNet-A.
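As an illustration, the mask can be turned into the list of kept indices and used to slice the 1000-class logits. This is a minimal sketch, assuming the JSON keys are the string indices "0" to "999" and that the ImageNet-A classes keep the same relative order as in ImageNet-1K; `indices_from_mask` and `mask_logits` are hypothetical helpers, and a 6-entry toy mask stands in for the real file.

```python
import torch

def indices_from_mask(mask):
    """Return the (ascending) 1000-class indices whose mask value is not -1."""
    # assumption: keys are string indices, values are -1 (drop) or >= 0 (keep)
    return sorted(int(idx) for idx, value in mask.items() if value != -1)

def mask_logits(logits_1k, indices_in_1k):
    """Select the kept columns: [B, 1000] -> [B, len(indices_in_1k)]."""
    return logits_1k[:, indices_in_1k]

# Toy mask over 6 "classes" instead of 1000: classes 1, 3 and 4 are kept
toy_mask = {"0": -1, "1": 0, "2": -1, "3": 1, "4": 2, "5": -1}
indices = indices_from_mask(toy_mask)
logits = torch.arange(12.0).reshape(2, 6)   # batch of 2 fake logit vectors
sub_logits = mask_logits(logits, indices)
```

Slicing the logits (rather than the probabilities) lets a softmax be applied afterwards over the ImageNet-A classes only.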

## **Test-time adaptation methods**

During the testing phase, Test-Time Adaptation (TTA) methods allow the model to be modified. This enables the model to adapt to data distributions it did not encounter during training, helping it maintain its performance under domain shift.

Below, we outline the techniques used to enhance the model's performance, providing a brief introduction to each method followed by its implementation.

### **MEMO: Test Time Robustness via Adaptation and Augmentation<sup>[1]</sup>**

In this paper, the authors propose a method called MEMO (Marginal Entropy Minimization with One Test Point) designed to address the problem of robustness in deep neural networks when confronted with distribution shifts or unexpected perturbations. To achieve this, the method employs both adaptation and augmentation strategies.

<p align="center">
  <img src="images/MEMO.png" width="600" height="300">  
</p>

Unlike traditional approaches that focus on modifying the training process, MEMO utilizes the information provided by test inputs. It applies data augmentations to a single test input to generate various versions of the input, and from these, it calculates the marginal output distribution. The model's parameters are then updated to minimize the entropy of this marginal distribution. Finally, the model uses these updated parameters to make a prediction on the original test input.

This approach allows MEMO to enhance the model's robustness against unseen distribution changes by effectively adapting to new conditions during testing.

<br>

---
**Algorithm 1** Test time robustness via MEMO

---

**Require:** trained model f<sub>θ</sub>, test point $x$, number of augmentations $B$, learning rate $η$, update rule $G$

1. Sample a<sub>1</sub>...a<sub>B</sub> $\overset{\text{i.i.d.}}{\sim}$ $\mathcal{U}$ ($\mathcal{A}$) and produce augmented points $\tilde{\mathbf{x}}_i = a_i(\mathbf{x})$ for $i \in \{1, \ldots, B\}$
2. Compute estimate $\tilde{p} = \frac{1}{B} \sum_{i=1}^B p_\theta(y|\tilde{\mathbf{x}}_i) \approx p_\theta(y|\mathbf{x})$ and $\tilde{\ell} = H(\tilde{p}) \approx \ell(\theta; \mathbf{x})$
3. Adapt parameters via update rule $\theta' \leftarrow G(\theta, \eta, \tilde{\ell})$
4. Predict $\hat{y} \triangleq \arg \max_y p_{\theta'}(y|\mathbf{x})$
---
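The four steps of the algorithm can be sketched in a few lines. This is a minimal example, assuming a toy linear classifier in place of $f_\theta$, additive Gaussian noise as a hypothetical stand-in for the sampled augmentations, and a single SGD step as the update rule $G$; `marginal_entropy` and `memo_adapt_and_predict` are illustrative names, not the notebook's functions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def marginal_entropy(logits):
    """Entropy H(p~) of the marginal p~ = mean_i softmax(logits_i) (step 2)."""
    probs = F.softmax(logits, dim=1)      # [B, C], one row per augmented copy
    p_bar = probs.mean(dim=0)             # [C], marginal output distribution
    return -(p_bar * torch.log(p_bar.clamp_min(1e-10))).sum()

def memo_adapt_and_predict(model, x, augment, num_augmentations=8, lr=1e-3):
    """One MEMO update on a single test point x, then a prediction on x itself."""
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    # Step 1: B augmented copies of the same test point
    x_aug = torch.stack([augment(x) for _ in range(num_augmentations)])
    # Step 2: marginal entropy over the augmented batch
    loss = marginal_entropy(model(x_aug))
    # Step 3: update rule G, here a single SGD step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Step 4: predict on the original (non-augmented) input with theta'
    with torch.no_grad():
        return model(x.unsqueeze(0)).argmax(dim=1).item()

def noise_augment(t):
    # hypothetical augmentation: additive Gaussian noise
    return t + 0.1 * torch.randn_like(t)

torch.manual_seed(0)
toy_model = nn.Linear(4, 3)   # hypothetical stand-in for f_theta (4 features, 3 classes)
prediction = memo_adapt_and_predict(toy_model, torch.randn(4), noise_augment)
```

Since MEMO adapts to one test point at a time, the original parameters would typically be restored (for example with `copy.deepcopy` or `load_state_dict`) before processing the next sample.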

In [None]:
def compute_entropy(probabilities):
    """
    Takes a tensor of probabilities of shape [classes] (or [1, classes]) and returns its entropy as a 0-dim tensor.
    """
    # Ensure probabilities are normalized (sum to 1); the reference value shares dtype and device
    total = probabilities.sum()
    if not torch.isclose(total, torch.ones_like(total)):
        raise ValueError("The probabilities should sum to 1.")

    # Compute entropy
    # Adding a small value to avoid log(0) issues
    epsilon = 1e-10
    probabilities = torch.clamp(probabilities, min=epsilon)
    entropy = -torch.sum(probabilities * torch.log(probabilities))

    return entropy

def get_best_augmentations(probabilities, top_k):
    """
    Takes a tensor of probabilities of shape [num_augmentations, classes] or
    [mc_models, num_augmentations, classes] (MC dropout) and returns the probabilities
    of the top_k augmentations with the lowest entropy, sorted by increasing entropy,
    with shape [top_k, classes] or [mc_models, top_k, classes].
    ----------
    probabilities: tensor of shape [num_augmentations, classes] or [mc_models, num_augmentations, classes]
    top_k: number of augmentations to select
    """
    if probabilities.dim() == 2:
        probabilities = probabilities.unsqueeze(0)

    # entropy of every augmentation, shape [mc_models, num_augmentations] (kept on the input device)
    entropies = torch.stack([torch.stack([compute_entropy(prob) for prob in prob_set])
                             for prob_set in probabilities])
    # largest=False with sorted=True returns the lowest-entropy augmentations in increasing order
    _, top_k_indices = torch.topk(entropies, top_k, largest=False, sorted=True)
    top_k_probabilities = torch.stack([prob_set[indices]
                                       for prob_set, indices in zip(probabilities, top_k_indices)])
    if top_k_probabilities.shape[0] == 1:
        top_k_probabilities = top_k_probabilities.squeeze(0)

    return top_k_probabilities

def get_test_augmentations(input, augmentations, num_augmentations, seed_augmentations):
    """
    Takes a tensor image of dimension [C,H,W] and returns a tensor of augmentations of dimension [num_augmentations, C,H,W].
    The augmentations are produced by sampling different torchvision.transforms from "augmentations".
    ----------
    input: an image tensor of dimension [C,H,W]
    augmentations: a list of torchvision.transforms augmentations
    num_augmentations: the number of augmentations to produce
    seed_augmentations: seed to reproduce the sampling of augmentations
    """
    torch.manual_seed(seed_augmentations)
    random.seed(seed_augmentations)
    sampled_augmentations = random.sample(augmentations, num_augmentations)
    test_augmentations = torch.zeros((num_augmentations, 3, 224, 224))
    for i, augmentation in enumerate(sampled_augmentations):
        transform_MEMO = T.Compose([
            T.ToPILImage(),
            augmentation,
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        augmented_input = transform_MEMO(input.cpu())
        test_augmentations[i] = augmented_input
    return test_augmentations

### **TTA - Greedy Policy Search: A Simple Baseline for Learnable Test-Time Augmentation<sup>[2]</sup>**

In this paper, the authors propose a method called Greedy Policy Search (GPS) for learning test-time data augmentation policies that enhance the performance of machine learning models. The key idea is that data augmentation policies, typically designed for the training phase, can also be learned and optimized during the test phase to improve model performance.

<p align="center">
  <img src="images/GPS.png" width="990" height="300"/>  
</p>

The method involves iteratively selecting sub-policies that maximize a chosen performance criterion, such as the calibrated log-likelihood on a validation set. As the name suggests, GPS constructs the augmentation policy in a greedy, step-by-step manner, ensuring that each added sub-policy contributes to performance improvement.

This approach offers a simple yet powerful baseline for learning test-time augmentation policies, providing a promising alternative to more complex methods like reinforcement learning or Bayesian optimization, which are traditionally used for training-phase augmentation.

<br>

---
**Algorithm 1** Greedy Policy Search (GPS)

---

**Require:** Trained neural network $p(y \mid x, \theta)$  
**Require:** Validation data $X_{val}, y_{val}$  
**Require:** Pool size $B$, policy size $T$  
**Require:** Prior over sub-policies $p(s)$  

$S \gets \emptyset$ $\hspace{2em}$ $\triangleright$ Pool of candidate sub-policies

**for** $i \gets 1$ **to** $B$ **do**  
$\hspace{2em}$ $s_i \sim p(s)$  
$\hspace{2em}$ $S \gets S \cup \{s_i\}$ $\hspace{6.3em}$ $\triangleright$ Add $s_i$ to pool  
$\hspace{2em}$ $\pi^{s_i}_{val} \gets p(y \mid s_i(X_{val}), \theta)$ $\hspace{2em}$ $\triangleright$ Predict with $s_i$  
**end for**

$P \gets \emptyset$ $\hspace{2.8em}$ $\triangleright$ GPS policy  
$\pi^P_{val} \gets 0$ $\hspace{2em}$ $\triangleright$ Predictions made with GPS policy

**for** $t \gets 1$ **to** $T$ **do**  
$\hspace{0.5em}$ $\triangleright$ Choose the best sub-policy $s^*$ based on **calibrated log-likelihood** on validation:  
$\hspace{2em}$ $s^* \gets \arg\max_{s \in S}$ cLL $\left( \frac{t-1}{t} \pi^P_{val} + \frac{1}{t} \pi^{s}_{val}; y_{val} \right)$  

$\hspace{2em}$ $\pi^P_{val} \gets \frac{t-1}{t} \pi^P_{val} + \frac{1}{t} \pi^{s^*}_{val}$ $\hspace{2em}$ $\triangleright$ Update predictions  
$\hspace{2em}$ $P \gets P \cup \{s^*\}$ $\hspace{5.8em}$ $\triangleright$ Update policy  
**end for**

**return** policy $P$

---
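A compact sketch of the greedy loop above, assuming the validation predictions of every sub-policy in the pool have already been computed (the first loop of the algorithm). To keep it short, the calibrated log-likelihood (cLL) is replaced by the plain, uncalibrated log-likelihood; the function names, the sub-policy names and the toy predictions are all hypothetical.

```python
import torch

def log_likelihood(probs, labels):
    """Mean log-probability of the true labels.
    NOTE: plain log-likelihood, a stand-in for the paper's calibrated cLL."""
    picked = probs[torch.arange(len(labels)), labels]
    return torch.log(picked.clamp_min(1e-12)).mean().item()

def greedy_policy_search(val_preds, labels, policy_size):
    """val_preds maps each sub-policy s in the pool S to pi^s_val, a [N, C]
    tensor of validation probabilities. Returns the policy P as a list of names."""
    policy = []       # P
    ensemble = None   # pi^P_val, predictions made with the current policy
    for t in range(1, policy_size + 1):
        best_name, best_score, best_mix = None, float("-inf"), None
        for name, preds in val_preds.items():
            # (t-1)/t * pi^P_val + 1/t * pi^s_val; at t = 1 this is just pi^s_val
            candidate = preds if ensemble is None else (t - 1) / t * ensemble + preds / t
            score = log_likelihood(candidate, labels)
            if score > best_score:
                best_name, best_score, best_mix = name, score, candidate
        ensemble = best_mix        # update predictions
        policy.append(best_name)   # update policy
    return policy

labels = torch.tensor([0, 1])
val_preds = {
    "identity": torch.tensor([[0.9, 0.1], [0.2, 0.8]]),
    "hflip": torch.tensor([[0.6, 0.4], [0.5, 0.5]]),
}
policy = greedy_policy_search(val_preds, labels, policy_size=2)
```

Because $s^*$ is chosen from the whole pool at every step, the same sub-policy can be selected more than once, as happens with the toy predictions here.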

### **Adaptive Batch Normalization - Improving robustness against common corruptions by covariate shift adaptation<sup>[3]</sup>**

In this paper, the authors investigate methods to improve the robustness of machine learning models trained for computer vision against common image corruptions, such as blurring or compression artifacts. These corruptions often degrade model performance, reducing their effectiveness in real-world applications.

Traditional approaches tend to underestimate model robustness in scenarios where models can adapt to corruptions found in multiple unlabeled examples. The authors argue that models should utilize these examples for unsupervised online adaptation, a strategy not commonly employed in current evaluations. Instead of relying on static batch normalization (BN) statistics computed during training, the authors propose that these statistics be dynamically updated by the models using data from corrupted images encountered during testing. This adaptive approach can significantly enhance model performance under real-world conditions.

In [None]:
def adaptive_bn_forward(self, input: torch.Tensor):
    """
    Applies adaptive batch normalization to the input tensor: the pre-trained running
    statistics are combined with the statistics of the current input following
    Schneider et al. [3]:
                        mean = N/(N+1)*mean_train + 1/(N+1)*mean_test
                        var = N/(N+1)*var_train + 1/(N+1)*var_test
    N is the weight given to the statistics of the pre-trained model.
    In the implementation, N/(N+1) corresponds to self.prior_strength and can be modified by
    assigning a float/int to nn.BatchNorm2d.prior_strength.
    -----------
    input : input tensor of shape [B, C, H, W] (B is the batch size, not the N above)
    """
    # compute channel-wise statistics for the input
    point_mean = input.mean([0,2,3]).to(device = self.running_mean.device)
    point_var = input.var([0,2,3], unbiased=True).to(device = self.running_mean.device)
    # BN adaptation
    adapted_running_mean = self.prior_strength * self.running_mean + (1 - self.prior_strength) * point_mean
    adapted_running_var = self.prior_strength * self.running_var + (1 - self.prior_strength) * point_var
    # detach to avoid non-differentiable torch error
    adapted_running_mean = adapted_running_mean.detach()
    adapted_running_var = adapted_running_var.detach()

    return torch.nn.functional.batch_norm(input, adapted_running_mean, adapted_running_var, self.weight, self.bias, False, 0, self.eps)
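The notebook does not show how `adaptive_bn_forward` is attached to the network. One possible way, sketched below under our own assumptions, is to bind it to each `BatchNorm2d` instance with `types.MethodType` and store `prior_strength = N/(N+1)` on the module; `patch_batchnorm` is a hypothetical helper and the function body is a condensed copy of the one above (device handling omitted). With N = 0 only the statistics of the test batch are used, which the example checks against a manual normalization.

```python
import types
import torch
import torch.nn as nn
import torch.nn.functional as F

def adaptive_bn_forward(self, input):
    # condensed copy of the notebook's function (device handling omitted)
    point_mean = input.mean([0, 2, 3])
    point_var = input.var([0, 2, 3], unbiased=True)
    mean = (self.prior_strength * self.running_mean + (1 - self.prior_strength) * point_mean).detach()
    var = (self.prior_strength * self.running_var + (1 - self.prior_strength) * point_var).detach()
    return F.batch_norm(input, mean, var, self.weight, self.bias, False, 0.0, self.eps)

def patch_batchnorm(model, N):
    """Bind the adaptive forward to every BatchNorm2d, with prior_strength = N / (N + 1)."""
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.prior_strength = N / (N + 1)
            module.forward = types.MethodType(adaptive_bn_forward, module)
    return model

torch.manual_seed(0)
net = patch_batchnorm(nn.Sequential(nn.BatchNorm2d(2)).eval(), N=0)
x = torch.randn(4, 2, 3, 3)
out = net(x)
# With N = 0 the running statistics are ignored: plain normalization with test statistics
expected = (x - x.mean([0, 2, 3], keepdim=True)) / torch.sqrt(
    x.var([0, 2, 3], unbiased=True, keepdim=True) + 1e-5)
```

Patching the class itself (assigning to `nn.BatchNorm2d.forward` and `nn.BatchNorm2d.prior_strength`), as the docstring suggests, would affect every BatchNorm2d in the process; per-instance binding limits the change to a single model.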

### **Monte Carlo Dropout - Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning<sup>[4]</sup>**

In this paper, the authors propose Monte Carlo Dropout, a technique that leverages dropout, a regularization method commonly used during training, to also perform approximate Bayesian inference during the testing phase.

During training, dropout randomly "drops" (deactivates) a fraction of the neurons in the network at each forward pass; this fraction is controlled by the dropout rate, i.e. the probability of dropping each neuron. The aim is to prevent neurons from memorizing specific inputs, thus reducing overfitting and encouraging the network to learn more general representations.

However, at test time dropout is usually turned off, so that the full network is used to make predictions. The key idea behind Monte Carlo Dropout is to keep dropout active at test time and perform multiple forward passes through the network. Each forward pass uses a different dropout mask, i.e. a different random subset of active neurons, and therefore produces a slightly different prediction. Averaging these predictions yields the final prediction, while their variability provides a measure of the model's uncertainty.
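The procedure can be sketched as follows. This is a minimal example on a toy network, assuming that any BatchNorm layers should stay in eval mode while only the `nn.Dropout` layers are switched back to training mode; `enable_mc_dropout` and `mc_dropout_predict` are hypothetical helpers, and the predictive entropy of the averaged probabilities is just one common choice of uncertainty measure.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def enable_mc_dropout(model):
    """Eval mode everywhere (e.g. frozen BatchNorm), but keep Dropout layers stochastic."""
    model.eval()
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()
    return model

@torch.no_grad()
def mc_dropout_predict(model, x, num_samples=10):
    """Average softmax outputs over num_samples stochastic forward passes.
    Returns mean probabilities [B, C] and their predictive entropy [B]."""
    enable_mc_dropout(model)
    probs = torch.stack([F.softmax(model(x), dim=1) for _ in range(num_samples)])  # [T, B, C]
    mean_probs = probs.mean(dim=0)
    entropy = -(mean_probs * mean_probs.clamp_min(1e-10).log()).sum(dim=1)
    return mean_probs, entropy

torch.manual_seed(0)
toy_model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(16, 3))
mean_probs, uncertainty = mc_dropout_predict(toy_model, torch.randn(5, 8), num_samples=20)
```

The same helpers could be applied to a `ResNet50Dropout` model with `dropout_rate > 0`, since the dropout layers it inserts are plain `nn.Dropout` modules.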

#### **Dropout positions**

The **dropout_positions.json** file defines the locations within a custom ResNet50 model where dropout layers should be inserted.

These dropout layers make it possible to use Monte Carlo Dropout within the proposed method, a technique that improves model robustness and uncertainty estimation. For more details, refer to the **ResNet50Dropout** section.

In [None]:
class ResNet50Dropout(nn.Module):
    """
    Creates a version of the ResNet-50 model that integrates dropout layers at
    various points in the architecture. The dropout layers act as a regularizer
    during training and make it possible to apply Monte Carlo Dropout at test time.
    ----------
    weights: Optional pre-trained weights for the ResNet-50 model.
    dropout_rate: The probability of dropping out neurons.
    A value of 0 means no dropout is applied and the architecture is identical to the
    original ResNet-50.
    """
    def __init__(self, weights=None, dropout_rate=0.):
        super(ResNet50Dropout, self).__init__()

        self.weights = weights
        self.model = models.resnet50(weights=self.weights)
        self.dropout_rate = dropout_rate

        self.dropout_positions = []
        if self.dropout_rate > 0:
            self.dropout_positions = self.get_dropout_positions()

        self._add_dropout()

    # This method reads a JSON file that contains a list of layer names where
    # dropout should be applied.
    def get_dropout_positions(self):
        dropout_positions_path = "/home/sagemaker-user/Domain-Shift-Computer-Vision/utility/data/dropout_positions.json"
        with open(dropout_positions_path, 'r') as json_file:
            dropout_positions = json.load(json_file)
        dropout_positions = dropout_positions["dropout_positions"]

        return dropout_positions

    # This method adds dropout layers to the ResNet-50 model at the positions
    # listed in dropout_positions.
    # Each selected block is wrapped in an nn.Sequential that applies the
    # original block followed by an nn.Dropout layer with the specified rate;
    # for 'fc' the dropout is applied before the classifier instead.
    def _add_dropout(self):
        for name in ('conv1', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool'):
            if name in self.dropout_positions:
                setattr(self.model, name, nn.Sequential(
                    getattr(self.model, name),
                    nn.Dropout(p=self.dropout_rate)
                ))

        if 'fc' in self.dropout_positions:
            self.model.fc = nn.Sequential(
                nn.Dropout(p=self.dropout_rate),
                self.model.fc
            )

    def forward(self, x):
        return self.model(x)

### **DiffTPT - Diverse Data Augmentation with Diffusions for Effective Test-time Prompt Tuning<sup>[5]</sup>**


In this paper, the authors introduce a novel method called DiffTPT (Diverse Test-time Prompt Tuning) to improve the performance of vision-language models when test samples come from previously unseen domains. This approach is particularly relevant for models like CLIP, whose text prompts can be tuned to adapt to new tasks without additional task-specific training data.

**Test-time Prompt Tuning (TPT):** TPT refers to the process of adapting the prompts of pre-trained models on the fly during testing, especially when encountering new data that differs from the training distribution. TPT aims to generate adaptive prompts based on each test sample. Typically, this is achieved using two main approaches, either individually or in combination:

- **Data augmentation techniques,** which unfortunately often fail to generate sufficiently diverse data.
- **Entropy-based confidence selection,** which does not always ensure accurate predictions.

<p align="center">
  <img src="images/DiffTPT.png" width="1137" height="300"/>
</p>

Unlike these traditional approaches, DiffTPT leverages pre-trained diffusion models to generate new images. These images typically exhibit a wide range of variations that data augmentation techniques alone cannot achieve. However, this broad variation does not always lead to model improvement due to the varying quality of the generated images.

To address this issue and maintain the quality and relevance of the augmented data, the authors introduce a **cosine similarity-based filtration** mechanism. This mechanism ensures that only the most semantically similar augmented samples are selected, effectively balancing diversity and fidelity.
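The filtration step can be sketched as ranking candidate images by the cosine similarity between their embeddings and the embedding of the test image, then keeping the most similar ones. This is a minimal sketch, assuming the embeddings were already computed (e.g. with a CLIP image encoder, as `get_image_embedding` does in this notebook) and that a fixed number `keep` of candidates is retained; the helper name and the toy 2-D embeddings are hypothetical, and the exact selection rule used in DiffTPT may differ.

```python
import torch
import torch.nn.functional as F

def filter_by_cosine_similarity(test_embedding, candidate_embeddings, keep):
    """test_embedding: [D]; candidate_embeddings: [M, D].
    Returns the indices and similarities of the `keep` most similar candidates, best first."""
    sims = F.cosine_similarity(candidate_embeddings, test_embedding.unsqueeze(0), dim=1)  # [M]
    top_sims, top_idx = torch.topk(sims, k=keep, largest=True, sorted=True)
    return top_idx, top_sims

test_emb = torch.tensor([1.0, 0.0])
candidates = torch.tensor([[0.0, 1.0],    # orthogonal: similarity 0
                           [1.0, 0.1],    # almost identical direction
                           [-1.0, 0.0],   # opposite: similarity -1
                           [1.0, 1.0]])   # 45 degrees: similarity ~0.707
kept_idx, kept_sims = filter_by_cosine_similarity(test_emb, candidates, keep=2)
```

Low-similarity candidates (here the orthogonal and opposite ones) are discarded, which is how the filtration trades some diversity for fidelity to the test image.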

Finally, the proposed approach combines both traditional augmentation methods and newly generated images to exploit their respective strengths, providing a richer set of training data during test time. This method enhances the model's ability to handle new, unseen test data effectively, making it a valuable approach for real-world applications where collecting labeled data for every new distribution is impractical.

#### **Image generation**

In [1]:
def get_imagenetA_classes():
    """
    ImageNet-A uses the same label structure as the original ImageNet (ImageNet-1K).
    Each class in ImageNet is represented by a synset ID (e.g., n01440764 for "tench, Tinca tinca").
    This function returns a dictionary that maps the synset IDs of ImageNet-A to the corresponding class names.
    """
    imagenetA_classes_path = "/home/sagemaker-user/Domain-Shift-Computer-Vision/utility/data/imagenetA_classes.json"
    with open(imagenetA_classes_path, 'r') as json_file:
        # dictionary with synset IDs as keys and class names as values
        class_dict = json.load(json_file)
    return class_dict

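def create_dir_generated_images(path):
    """Creates one sub-directory per ImageNet-A class name under `path` (existing directories are kept)."""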
    classes = list(get_imagenetA_classes().values())
    for class_name in classes:
        class_path = os.path.join(path, class_name)
        os.makedirs(class_path, exist_ok=True)

In [1]:
class ImageGenerator:
    """
    A class designed to handle the generation of text prompts and images,
    specifically for applications involving vision-language models and diffusion
    models. It provides methods for generating prompts, creating images based
    on those prompts, and retrieving generated images that are similar to a
    given input image using a CLIP model.
    """
    def __init__(self):
        self.__model = None  # placeholder, so that get_model() does not raise an AttributeError

    def get_model(self):
        print(self.__model)

    def get_text_embedding(self,clip_model, text):
        text_token = clip.tokenize(text, truncate=True).cuda()  # truncate prompts longer than CLIP's 77-token context
        text_embedding = clip_model.encode_text(text_token).float()
        text_embedding /= text_embedding.norm(dim=-1, keepdim=True)
        return text_embedding

    def generate_prompts(self, num_prompts_per_class, style_of_picture, path, context_llm, llm_model = "llama3.1", clip_text_encoder = "ViT-L/14"):
        """
        This method generates text prompts for image classes. It uses a large
        language model (LLM) to create prompts based on the class names and a
        specified picture style. The method then generates embeddings for these
        prompts using a CLIP model and saves them along with the prompt text to
        specific directories. If the LLM fails to generate sufficient prompts,
        the class is skipped, and an error message is logged.
        """
        assert isinstance(llm_model,str), "Model must be a str"
        try:
            ollama.chat(llm_model)
        except ollama.ResponseError as e:
            print('Error:', e.error)
            # 404: the model is not available locally, so download it
            if e.status_code == 404:
                print("Pulling the model...")
                ollama.pull(llm_model)

        if isinstance(context_llm,str):
            with open(context_llm, 'r') as file:
                context_llm = json.load(file)

        skipped_classes = []

        clip_model, _ = clip.load(clip_text_encoder)
        clip_model.cuda().eval()

        class_list = os.listdir(path)
        with torch.no_grad():
            with tqdm(total=len(class_list), desc="Processing classes") as pbar:
                for class_name in class_list:
                    pbar.set_description(f"Processing class: {class_name}")
                    sub_dir_class = os.path.join(path, class_name)
                    prompts_to_generate = num_prompts_per_class - (len(os.listdir(sub_dir_class)) - 1) # -1 to account for the scraped_images folder
                    if prompts_to_generate <= 0:
                        pbar.update(1)
                        continue
                    counter_flag = 6
                    gen_prompts = []
                    original_prompts_to_gen = prompts_to_generate
                    while counter_flag>0:
                        prompts_generation_instruction = {
                            "role": "user",
                            "content": f"class:{class_name}, number of prompts:{prompts_to_generate}, style of picture: {style_of_picture}"
                        }
                        if len(context_llm) == 3:
                            context_llm.append(prompts_generation_instruction)
                        else:
                            context_llm[3] = prompts_generation_instruction
                        try:
                            response = ollama.chat(model=llm_model, messages=context_llm)
                            content = json.loads(response['message']['content'])  # json.loads to convert str to list
                            if len(content) >= prompts_to_generate:
                                prompts_to_generate -= len(content)
                                gen_prompts.extend(content)
                                gen_prompts = gen_prompts[:original_prompts_to_gen]
                                counter_flag = -1
                            else:
                                prompts_to_generate -= len(content)
                                gen_prompts.extend(content)
                        except Exception as e:
                            counter_flag -= 1

                    # keep any partial result; classes with no generated prompts are skipped
                    if len(gen_prompts) > 0:
                        counter_flag = -1

                    if counter_flag == -1:
                        num_prompts_already_gen = len(os.listdir(sub_dir_class))
                        for i in range(num_prompts_already_gen, num_prompts_already_gen + len(gen_prompts)):
                            new_sub_dir = os.path.join(path, class_name, str(i))
                            os.makedirs(new_sub_dir, exist_ok=True)
                            prompt = gen_prompts[i - num_prompts_already_gen]
                            prompt_embedding = self.get_text_embedding(clip_model, prompt)
                            with open(os.path.join(new_sub_dir, "prompt.txt"), 'w') as file:
                                file.write(prompt)
                            torch.save(prompt_embedding, os.path.join(new_sub_dir,"prompt_clip_embedding.pt"))
                    else:
                        skipped_classes.append(class_name)
                        print(f"Skipping class {class_name}.")

        return skipped_classes

    def get_image_embedding(self,clip_model, preprocess, image):
        image_preprocessed = preprocess(image).unsqueeze(0).cuda()
        image_embedding = clip_model.encode_image(image_preprocessed)
        image_embedding /= image_embedding.norm()
        return image_embedding

    def generate_images(self,
                        path,
                        num_images,
                        image_generation_pipeline,
                        num_inference_steps,
                        guidance_scale = 9,
                        strength=1,
                        clip_image_encoder = "ViT-L/14"):
        """
        This method generates a specified number of images for each class in the
        given path using a diffusion-based image generation pipeline.
        It either generates images from text prompts or uses image-to-image
        generation if the pipeline supports it. The generated images are
        embedded using a CLIP model, and the embeddings are saved along with
        the images.
        """
        assert image_generation_pipeline.__class__.__name__ in ("StableDiffusionPipeline", "StableDiffusionImg2ImgPipeline"), "image_generation_pipeline must be one of StableDiffusionPipeline or StableDiffusionImg2ImgPipeline"

        random.seed(42)

        print("Loading CLIP model...")
        clip_model, preprocess = clip.load(clip_image_encoder)
        clip_model.cuda().eval()

        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
        for class_name in tqdm(os.listdir(path), desc="Processing classes"):
            mem_allocated_before = torch.cuda.memory_allocated()
            mem_reserved_before = torch.cuda.memory_reserved()
            class_path = os.path.join(path, class_name)
            # every sub-folder is a prompt folder, except "scraped_images" (source images for img2img)
            prompt_dirs = [d for d in os.listdir(class_path) if d != "scraped_images"]
            num_prompts = len(prompt_dirs)
            if image_generation_pipeline.__class__.__name__ == "StableDiffusionImg2ImgPipeline":
                print("Loading scraped images...")
                scraped_images_dir = os.path.join(class_path, "scraped_images")
                scraped_images = []
                for scraped_image_name in os.listdir(scraped_images_dir):
                    scraped_image = Image.open(os.path.join(scraped_images_dir, scraped_image_name)).convert("RGB")
                    scraped_image = scraped_image.resize((512,512))
                    scraped_images.append(scraped_image)
            n_perm = math.ceil(num_images / num_prompts)
            num_gen_images = 0 # needed if num_images < num_prompts
            for gen_images_class in prompt_dirs:
                if num_gen_images == num_images: break
                gen_image_class = os.path.join(class_path,gen_images_class)
                with open(os.path.join(gen_image_class, "prompt.txt"), 'r') as file:
                    text_prompt = file.read()
                if image_generation_pipeline.__class__.__name__ == "StableDiffusionImg2ImgPipeline":
                    i2i_image_path = os.path.join(gen_image_class,"i2i_gen_images")
                    os.makedirs(i2i_image_path,exist_ok=True)
                    for _ in range(n_perm):
                        scraped_image = random.sample(scraped_images,1)[0]
                        with torch.no_grad():
                            gen_image = image_generation_pipeline(prompt=text_prompt,
                                                                  image=scraped_image,
                                                                  strength=strength,
                                                                  guidance_scale=guidance_scale,
                                                                  num_inference_steps=num_inference_steps).images[0]
                            del scraped_image
                            gen_image_embedding = self.get_image_embedding(clip_model, preprocess, gen_image)
                            save_gen_image_path = os.path.join(i2i_image_path,str(len(os.listdir(i2i_image_path))))
                            os.makedirs(save_gen_image_path)
                            torch.save(gen_image_embedding, os.path.join(save_gen_image_path, "image_embedding.pt"))
                            gen_image.save(os.path.join(save_gen_image_path, "image.png"))
                            num_gen_images += 1
                            if num_gen_images == num_images: break
                else:
                    t2i_image_path = os.path.join(gen_image_class,"t2i_gen_images")
                    os.makedirs(t2i_image_path,exist_ok=True)
                    for _ in range(n_perm):
                        with torch.no_grad():
                            # `strength` only applies to img2img, so it is not passed to the text-to-image pipeline
                            gen_image = image_generation_pipeline(prompt=text_prompt,
                                                                  guidance_scale=guidance_scale,
                                                                  num_inference_steps=num_inference_steps).images[0]
                            gen_image_embedding = self.get_image_embedding(clip_model, preprocess, gen_image)
                            save_gen_image_path = os.path.join(t2i_image_path,str(len(os.listdir(t2i_image_path))))
                            os.makedirs(save_gen_image_path)
                            torch.save(gen_image_embedding, os.path.join(save_gen_image_path, "image_embedding.pt"))
                            gen_image.save(os.path.join(save_gen_image_path, "image.png"))
                            num_gen_images += 1
                            if num_gen_images == num_images: break

def retrieve_gen_images(img,
                        num_images,
                        clip_model,
                        preprocess,
                        img_to_tensor_pipe,
                        data_path = "/home/sagemaker-user/Domain-Shift-Computer-Vision/imagenetA_generated",
                        use_t2i_similarity = False,
                        t2i_images = True,
                        i2i_images = False,
                        threshold = 0.):
        """
        This function retrieves a specified number of images from the generated
        dataset that are most similar to a given input image. It uses cosine
        similarity between the CLIP embeddings of the input image and the
        generated images. The function can consider both text-to-image (t2i)
        and image-to-image (i2i) generated images based on the provided flags,
        and it returns the most similar images as tensors.
        """
        assert i2i_images or t2i_images, "One of t2i_images and i2i_images must be true"
        assert isinstance(use_t2i_similarity, bool), "use_t2i_similarity must be a bool"
        assert isinstance(t2i_images, bool), "t2i_images must be a bool"
        assert isinstance(i2i_images, bool), "i2i_images must be a bool"
        assert isinstance(num_images, int), "num_images must be an int"
        assert isinstance(threshold, float) and 0. <= threshold < 1., "threshold must be a float in [0, 1)"

        if isinstance(img, torch.Tensor):
            img = T.ToPILImage()(img)

        retrieved_images_paths = []
        retrieved_images_similarity = torch.zeros(num_images)
        with torch.no_grad():
            image_embedding = clip_model.encode_image(preprocess(img).unsqueeze(0).cuda())
            image_embedding /= image_embedding.norm()

        for class_name in os.listdir(data_path):
            class_path = os.path.join(data_path, class_name)
            for gen_images_class in os.listdir(class_path):
                if gen_images_class == "scraped_images": continue # source images for img2img, not a prompt folder
                gen_images_class_path = os.path.join(class_path,gen_images_class)
                gen_prompt_embedding = torch.load(os.path.join(gen_images_class_path, "prompt_clip_embedding.pt"))
                t2i_similarity = F.cosine_similarity(image_embedding, gen_prompt_embedding)
                if t2i_images:
                    t2i_gen_images_main_path = os.path.join(gen_images_class_path,"t2i_gen_images")
                    try: # needed bc some prompts don't have a corresponding image yet
                        for t2i_images_paths in os.listdir(t2i_gen_images_main_path):
                            t2i_image_path = os.path.join(t2i_gen_images_main_path,t2i_images_paths)
                            gen_image_embedding = torch.load(os.path.join(t2i_image_path, "image_embedding.pt"))
                            i2i_similarity = F.cosine_similarity(image_embedding, gen_image_embedding)
                            if use_t2i_similarity:
                                similarity = (i2i_similarity + t2i_similarity)/2 # avg similarity
                            else:
                                similarity = i2i_similarity
                            if similarity < threshold: continue
                            if len(retrieved_images_paths) < num_images:
                                retrieved_images_similarity[len(retrieved_images_paths)] = similarity
                                retrieved_images_paths.append(os.path.join(t2i_image_path, "image.png"))
                            else:
                                min_similarity, id_similarity = retrieved_images_similarity.min(dim=0)
                                if similarity > min_similarity:
                                    retrieved_images_similarity[id_similarity] = similarity
                                    retrieved_images_paths[id_similarity] = os.path.join(t2i_image_path, "image.png")
                    except FileNotFoundError: # prompt folder without a t2i_gen_images folder or embedding yet
                        pass
                if i2i_images:
                    i2i_gen_images_main_path = os.path.join(gen_images_class_path,"i2i_gen_images")
                    for i2i_images_paths in os.listdir(i2i_gen_images_main_path):
                        i2i_image_path = os.path.join(i2i_gen_images_main_path,i2i_images_paths)
                        gen_image_embedding = torch.load(os.path.join(i2i_image_path, "image_embedding.pt"))

                        i2i_similarity = F.cosine_similarity(image_embedding, gen_image_embedding)
                        if use_t2i_similarity:
                            similarity = (i2i_similarity + t2i_similarity)/2 # avg similarity
                        else:
                            similarity = i2i_similarity
                        if similarity < threshold: continue
                        if len(retrieved_images_paths) < num_images:
                            retrieved_images_similarity[len(retrieved_images_paths)] = similarity
                            retrieved_images_paths.append(os.path.join(i2i_image_path, "image.png"))
                        else:
                            min_similarity, id_similarity = retrieved_images_similarity.min(dim=0)
                            if similarity > min_similarity:
                                retrieved_images_similarity[id_similarity] = similarity
                                retrieved_images_paths[id_similarity] = os.path.join(i2i_image_path, "image.png")

        retrieved_images = []
        for image_path in retrieved_images_paths:
            retrieved_images.append(img_to_tensor_pipe(Image.open(image_path)))
        retrieved_images = torch.stack(retrieved_images) if len(retrieved_images) >= 1 else None

        return retrieved_images
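`retrieve_gen_images` keeps a running list of the `num_images` best matches and replaces the current minimum whenever a more similar image appears. If all the embeddings were first stacked into one `[N, D]` tensor, the same selection (drop everything below the threshold, then keep the k highest cosine similarities) could be written with `torch.topk`. The following is a minimal sketch under that assumption; `top_k_by_similarity` is a hypothetical helper, not part of the project.

```python
import torch
import torch.nn.functional as F


def top_k_by_similarity(query: torch.Tensor,
                        candidates: torch.Tensor,
                        k: int,
                        threshold: float = 0.0) -> torch.Tensor:
    """Return the indices of (at most) the k candidates most similar to `query`.

    query:      [D] embedding of the test image
    candidates: [N, D] embeddings of the generated images
    Candidates whose cosine similarity is below `threshold` are discarded.
    """
    # cosine similarity between the query and every candidate -> [N]
    similarity = F.cosine_similarity(query.unsqueeze(0), candidates, dim=1)
    keep = torch.nonzero(similarity >= threshold).squeeze(1)
    if keep.numel() == 0:
        return keep  # empty index tensor: nothing passed the threshold
    k = min(k, keep.numel())
    # topk returns positions inside `keep`, so map them back to candidate indices
    top_positions = similarity[keep].topk(k).indices
    return keep[top_positions]


query = torch.tensor([1.0, 0.0])
candidates = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.8, 0.6], [-1.0, 0.0]])
idx = top_k_by_similarity(query, candidates, k=2, threshold=0.5)
print(idx.tolist())
```

The streaming version in the notebook avoids loading every embedding at once, which matters when the generated dataset is large; the vectorized form trades memory for a single batched similarity computation.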

In [None]:
imagenetA_generator = ImageGenerator()

In [None]:
# generate prompts
skipped_classes = imagenetA_generator.generate_prompts(
    num_prompts_per_class=20,
    style_of_picture="photograph",
    path="/home/sagemaker-user/Domain-Shift-Computer-Vision/imagenetA_generated",
    context_llm = "/home/sagemaker-user/Domain-Shift-Computer-Vision/test_time_adaptation/image_generation/llm_context.json",
    llm_model = "llama3.1",
    clip_text_encoder = "ViT-L/14"
)

In [None]:
# generate images
model_id = "runwayml/stable-diffusion-v1-5"
pipet2i = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipet2i.scheduler = DPMSolverMultistepScheduler.from_config(pipet2i.scheduler.config)
pipet2i = pipet2i.to("cuda")

In [None]:
imagenet_a_generated_path = "/home/sagemaker-user/Domain-Shift-Computer-Vision/imagenetA_generated"

In [None]:
imagenetA_generator.generate_images(path = imagenet_a_generated_path,
                                    num_images = 1,
                                    image_generation_pipeline = pipet2i,
                                    num_inference_steps = 25,
                                    guidance_scale = 9,
                                    strength=1)
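In `generate_images`, every prompt folder receives up to `n_perm = ceil(num_images / num_prompts)` images, and the loop stops as soon as `num_images` images exist for the class. The counting logic can be reproduced in isolation; `images_per_prompt` is a hypothetical helper that only counts and never calls the diffusion pipeline.

```python
import math


def images_per_prompt(num_images: int, num_prompts: int) -> list:
    """Mirror the budget loop of generate_images: each prompt gets up to
    ceil(num_images / num_prompts) images until num_images have been produced."""
    n_perm = math.ceil(num_images / num_prompts)
    counts = []
    generated = 0
    for _ in range(num_prompts):
        # never exceed the per-prompt budget nor the per-class total
        batch = min(n_perm, num_images - generated)
        counts.append(batch)
        generated += batch
    return counts


print(images_per_prompt(30, 20))
```

With `num_images = 1`, as in the cell above, only one prompt folder per class (the first one returned by `os.listdir`) gets an image, which is why `retrieve_gen_images` has to tolerate prompt folders without a `t2i_gen_images` sub-folder.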

#### **Retrieving Images**

In [None]:
dataloader = get_data(batch_size=32,
                      img_root = "imagenet-a",
                      split_data=False)

In [None]:
candle_img = dataloader.dataset[20][0]

In [None]:
clip_image_encoder = "ViT-L/14"
clip_model, preprocess = clip.load(clip_image_encoder)

In [None]:
imagenet_a_generated_path = "/home/sagemaker-user/Domain-Shift-Computer-Vision/imagenetA_generated"

In [None]:
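retrieved_images = retrieve_gen_images(img = candle_img,
                                       num_images = 30,
                                       data_path = imagenet_a_generated_path,
                                       clip_model = clip_model,
                                       preprocess = preprocess,
                                       # required argument: same resize/crop used by the test pipeline
                                       img_to_tensor_pipe = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()]),
                                       t2i_images = True,
                                       use_t2i_similarity = False,
                                       threshold = 0.7)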

In [None]:
scrape_images_imagenetA(img_style = "a photo of",
                        imgenetA_gen_path = "/home/sagemaker-user/Domain-Shift-Computer-Vision/imagenetA_generated",
                        limit = 5)

#### **ImageNet-A classes**

The **imagenetA_classes.json** file maps the synset IDs used by ImageNet to their class names. This mapping is needed to turn the synset IDs associated with the model's predictions into human-readable labels.

The class IDs are mainly used to create the directory structure in which the newly generated images are stored, forming a secondary dataset that can then be used for training and inference. For more details, refer to the **ToDo** section.
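A minimal sketch of how such a mapping can drive the folder layout, assuming the JSON is a flat `{synset_id: class_name}` object; the example entries and the `build_class_dirs` helper are hypothetical, and the project may name folders by class name rather than synset ID.

```python
import json
import os
import tempfile


def build_class_dirs(classes_json: str, root: str) -> dict:
    """Read a {synset_id: class_name} mapping and create one folder per class.

    Returns the mapping so predictions can later be turned into readable labels.
    """
    with open(classes_json, "r") as f:
        synset_to_name = json.load(f)
    for synset_id in synset_to_name:
        # exist_ok makes the call safe to re-run
        os.makedirs(os.path.join(root, synset_id), exist_ok=True)
    return synset_to_name


with tempfile.TemporaryDirectory() as tmp:
    mapping_path = os.path.join(tmp, "imagenetA_classes.json")
    with open(mapping_path, "w") as f:
        json.dump({"n01498041": "stingray", "n01531178": "goldfinch"}, f)
    mapping = build_class_dirs(mapping_path, os.path.join(tmp, "generated"))
    created = sorted(os.listdir(os.path.join(tmp, "generated")))
```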

#### **Install and run Ollama**

To create the new images, we used a Stable Diffusion model that receives as input a list of text prompts describing the images to generate. To generate the prompts, we used an LLM, specifically Llama 3.1, which we obtained through the Ollama library<sup>[6]</sup>; Ollama provides the model together with all the configuration needed to run it. To download and run Ollama, and therefore Llama 3.1, execute the **install_and_run_ollama.sh** script: it downloads and installs Ollama, starts it, and instantiates Llama 3.1.
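The `generate_prompts` method of `ImageGenerator` calls `ollama.chat`, parses the reply with `json.loads`, retries when the reply is malformed, and truncates any extra prompts. A minimal sketch of that pattern follows, with the chat call passed in as a function so it runs without an Ollama server; `request_prompts` and `max_attempts` are hypothetical names, and the retry policy is a simplification of the notebook's counter-based loop.

```python
import json
from typing import Callable, List


def request_prompts(chat: Callable[[list], dict],
                    messages: list,
                    num_prompts: int,
                    max_attempts: int = 3) -> List[str]:
    """Ask an LLM for `num_prompts` prompts, retrying when the reply is not a JSON list.

    `chat` is expected to return a response where response["message"]["content"]
    holds the text reply, as the notebook's use of ollama.chat assumes.
    """
    prompts: List[str] = []
    for _ in range(max_attempts):
        if len(prompts) >= num_prompts:
            break
        response = chat(messages)
        try:
            content = json.loads(response["message"]["content"])
        except json.JSONDecodeError:
            continue  # malformed reply: try again
        if isinstance(content, list):
            prompts.extend(str(p) for p in content)
    # the LLM may return more prompts than requested
    return prompts[:num_prompts]


# fake chat function standing in for the LLM
replies = iter(['not json',
                '["a red candle on a table", "a candle in the dark"]',
                '["a lit candle", "two candles"]'])
fake_chat = lambda messages: {"message": {"content": next(replies)}}
prompts = request_prompts(fake_chat, messages=[], num_prompts=3)
print(prompts)
```

With Ollama running, `chat` could be `lambda msgs: ollama.chat(model="llama3.1", messages=msgs)`.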

#### **LLM context**

The **llm_context.json** file is used to provide context for generating prompts for a text-to-image generator model.

Each entry in the file specifies a Role and Content:

* **Role**: Specifies who is speaking or interacting in the conversation. It can either be "system" (representing the LLM) or "user" (representing the person providing input to the LLM).
* **Content**: This field contains the actual message or instructions being communicated. It includes system instructions or user-provided input regarding the prompts to be generated.

The messages can be of the following types:
1. A message that sets the system's role and context, instructing the LLM on how to generate prompts for a text-to-image generation task.
2. A message that serves as an example of input a user might provide. It shows how the user specifies the class name, the number of prompts, and the style of the picture.
3. A message that provides an example output from the LLM, showing the kind of response it should generate for that input.

In conclusion, the content of the file provides a framework that guides the LLM in understanding the context of the conversation, adhering to rules, and producing the desired output format.
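A minimal sketch of how a context in this shape can be extended with a new request before it is sent to the LLM, assuming each entry is a `{"role": ..., "content": ...}` dict (the message format `ollama.chat` accepts). The message texts and the `build_messages` helper are hypothetical; the real wording lives in **llm_context.json**.

```python
def build_messages(context: list, class_name: str, num_prompts: int, style: str) -> list:
    """Append a new request to the few-shot context read from llm_context.json."""
    request = {
        "role": "user",
        "content": f"class name: {class_name}; number of prompts: {num_prompts}; style of picture: {style}",
    }
    # build a new list so the same context can be reused for every class
    return list(context) + [request]


# toy context with the three message types; in the notebook it would come from
# json.load on llm_context.json
context = [
    {"role": "system", "content": "You write prompts for a text-to-image model. Answer only with a JSON list of strings."},
    {"role": "user", "content": "class name: goldfinch; number of prompts: 2; style of picture: photograph"},
    # example reply, with the role used in the description above
    # (many chat templates use "assistant" for this)
    {"role": "system", "content": '["a photograph of a goldfinch on a branch", "a photograph of a goldfinch in flight"]'},
]
messages = build_messages(context, "candle", 20, "photograph")
print(messages[-1]["content"])
```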

## **Testing**

The Tester class runs the experiments on a deep neural network model. It provides methods to configure the model and optimizer, apply augmentations, compute statistics, and save the results together with the experiment settings.
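Two steps recur in `get_prediction`: masking the 1000 ImageNet logits down to the ImageNet-A classes, and, with TTA, averaging the softmax over the augmented views, optionally after keeping only the lowest-entropy views (as `get_best_augmentations` is described to do). A minimal sketch of both, assuming entropy-based top-k selection; `tta_predict` is a hypothetical helper and the 3-index list is a toy stand-in for the mapping in **imagenetA_masking.json**.

```python
import torch
import torch.nn.functional as F


def tta_predict(logits_1k: torch.Tensor, indices_in_1k: list, top_k: int = 0) -> int:
    """Predict one class from the logits of several augmented views of one image.

    logits_1k:     [A, 1000] logits for A augmentations
    indices_in_1k: ImageNet-1k indices of the subset classes (200 for ImageNet-A)
    top_k:         if > 0, keep only the top_k lowest-entropy (most confident) views
    Returns an index in [0, len(indices_in_1k)).
    """
    logits = logits_1k[:, indices_in_1k]          # [A, 1000] -> [A, len(indices)]
    probs = F.softmax(logits, dim=1)              # F.softmax is already numerically stable
    if top_k:
        entropy = -(probs * torch.log(probs + 1e-12)).sum(dim=1)   # [A]
        keep = entropy.topk(min(top_k, probs.shape[0]), largest=False).indices
        probs = probs[keep]
    # marginal distribution over the kept views, then arg-max
    return probs.mean(dim=0).argmax().item()


logits = torch.zeros(3, 1000)
logits[0, 7] = 10.0   # view 0: very confident in subset class 1
logits[1, 3] = 5.0    # views 1 and 2: fairly confident in subset class 0
logits[2, 3] = 5.0
indices = [3, 7, 9]
print(tta_predict(logits, indices), tta_predict(logits, indices, top_k=1))
```

The example shows why confidence selection matters: averaging all views favors subset class 0, while keeping only the single most confident view flips the prediction to class 1.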

In [None]:
class Tester:
    """
    A class to run all the experiments. It stores all the information needed to reproduce an experiment in a json file
    at exp_path.
    """
    def __init__(self, model, optimizer, exp_path, device):
        self.__model = model
        self.__optimizer = optimizer
        self.__device = device
        self.__exp_path = exp_path

    def save_result(self, accuracy, path_result, num_augmentations, augmentations, seed_augmentations, top_augmentations, MEMO, num_adaptation_steps, lr_setting, weights, prior_strength, time_test, use_MC, gen_aug_settings):
        """
        Saves the results and settings of the experiment in a json file at path_result (inside exp_path).
        """
        data = {
            "accuracy": accuracy,
            "top_augmentations" : top_augmentations,
            "use_MEMO" : MEMO,
            "num_adaptation_steps" : num_adaptation_steps,
            "lr_setting" : lr_setting,
            "weights" : weights,
            "num_augmentations" : num_augmentations,
            "seed_augmentations": seed_augmentations,
            "augmentations" : [str(augmentation) for augmentation in augmentations],
            "prior_strength" : prior_strength,
            "MC" : use_MC,
            "time_test" : time_test,
            "gen_aug_settings" : gen_aug_settings
        }
        try:
            with open(path_result, 'w') as json_file:
                json.dump(data, json_file)
        except (OSError, TypeError) as e:
            print(f"Results were not saved: {e}")

    def get_model(self, weights_imagenet, MC):
        """
        Utility function to instantiate a torch model. The argument weights_imagenet should have
        a value in accordance with the parameter weights of torchvision.models.
        """
        if MC:
            self.__model=ResNet50Dropout(weights=weights_imagenet, dropout_rate=MC['dropout_rate'])
            model = self.__model
        else:
            model = self.__model(weights=weights_imagenet)

        model.to(self.__device)
        model.eval()
        return model

    def get_optimizer(self, model, lr_setting:list):
        """
        Utility function to instantiate a torch optimizer.
        ----------
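        lr_setting: a list containing either a single global lr for the whole model, or a dictionary followed by a
        default lr. Each value of the dictionary is a pair [list of parameter names, lr for those parameters]; the
        default lr is applied to all the remaining parameters (0 keeps them fixed).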
        e.g.
        lr_setting = [{
            "classifier" : [["fc.weight", "fc.bias"], 0.00025]
            }, 0]
        lr_setting = [0.00025]
        """
        if len(lr_setting) == 2:
            layers_groups = []
            lr_optimizer = []
            for layers, lr_param_name in lr_setting[0].items():
                layers_groups.extend(lr_param_name[0])
                params = [param for name, param in model.named_parameters() if name in lr_param_name[0]]
                lr_optimizer.append({"params":params, "lr": lr_param_name[1]})
            other_params = [param for name, param in model.named_parameters() if name not in layers_groups]
            lr_optimizer.append({"params":other_params})
            optimizer = self.__optimizer(lr_optimizer, lr = lr_setting[1], weight_decay = 0)
        else:
            optimizer = self.__optimizer(model.parameters(), lr = lr_setting[0], weight_decay = 0)
        return optimizer

    def get_imagenetA_masking(self):
        """
        All torchvision models output a tensor [B,1000] with "B" being the batch dimension. This function
        returns a list of indices to apply to the model's output to use the model on imagenet-A dataset.
        ----------
        indices_in_1k: list of indices to map [B,1000] -> [B,200]
        """
        imagenetA_masking_path = "/home/sagemaker-user/Domain-Shift-Computer-Vision/utility/data/imagenetA_masking.json"
        with open(imagenetA_masking_path, 'r') as json_file:
            imagenetA_masking = json.load(json_file)
        indices_in_1k = [int(k) for k in imagenetA_masking if imagenetA_masking[k] != -1]
        return indices_in_1k

    def get_monte_carlo_statistics(self, mc_logits):
        """
        Compute mean, median, mode and standard deviation of the Monte Carlo samples.
        """
        statistics = {}
        mean_logits = mc_logits.mean(dim=0)
        statistics['mean'] = mean_logits

        median_logits = mc_logits.median(dim=0).values
        statistics['median'] = median_logits

        # predicted class per MC sample: argmax over the class dimension (the last one)
        pred_classes = mc_logits.argmax(dim=-1)
        pred_classes_cpu = pred_classes.cpu().numpy()
        mode_predictions, _ = stats.mode(pred_classes_cpu, axis=0)
        mode_predictions = torch.tensor(mode_predictions.squeeze(), dtype=torch.long)
        statistics['mode'] = mode_predictions

        uncertainty = mc_logits.std(dim=0)
        statistics['std'] = uncertainty
        return statistics

    def get_prediction(self, image_tensors, model, masking, TTA = False, top_augmentations = 0, MC = None):
        """
        Takes a tensor of images and outputs a prediction for each image.
        ----------
        image_tensors: is a tensor of [B,C,H,W] if TTA is used or if both MEMO and TTA are not used, or of dimension [C,H,W]
                       if only MEMO is used
        masking: a list of indices to map the imagenet1k logits to the one of imagenet-A
        top_augmentations: a non-negative integer, if greater than 0 then the "top_augmentations" with the lowest entropy are
                           selected to make the final prediction
        MC: a dictionary containing the number of evaluations using Monte Carlo Dropout and the dropout rate
        """
        if MC:
            model.train()  # enable dropout by setting the model to training mode
            mc_logits = []
            for _ in range(MC['num_samples']):
                logits = model(image_tensors)[:,masking] if image_tensors.dim() == 4 else model(image_tensors.unsqueeze(0))[:,masking]
                mc_logits.append(logits)
            mc_logits = torch.stack(mc_logits, dim=0)
            if TTA:
                # first mean is over MC samples, second mean is over TTA augmentations
                probab_augmentations = F.softmax(mc_logits - mc_logits.max(dim=2, keepdim=True)[0], dim=2)
                if top_augmentations:
                    probab_augmentations = self.get_best_augmentations(probab_augmentations, top_augmentations)
                y_pred = probab_augmentations.mean(dim=0).mean(dim=0).argmax().item()
                statistics = self.get_monte_carlo_statistics(probab_augmentations.mean(dim=1))
                return y_pred, statistics
            statistics = self.get_monte_carlo_statistics(mc_logits)
            return statistics['median'].argmax(dim=1), statistics
        else:
            logits = model(image_tensors)[:,masking] if image_tensors.dim() == 4 else model(image_tensors.unsqueeze(0))[:,masking]
            if TTA:
                probab_augmentations = F.softmax(logits - logits.max(dim=1)[0][:, None], dim=1)
                if top_augmentations:
                    probab_augmentations = self.get_best_augmentations(probab_augmentations, top_augmentations)
                y_pred = probab_augmentations.mean(dim=0).argmax().item()
                return y_pred, None
            return logits.argmax(dim=1), None

    def compute_entropy(self, probabilities: torch.tensor):
        """
        See MEMO section
        """
        return compute_entropy(probabilities)

    def get_best_augmentations(self, probabilities: torch.tensor, top_k: int):
        """
        See MEMO section
        """
        return get_best_augmentations(probabilities, top_k)

    def get_test_augmentations(self, input:torch.tensor, augmentations:list, num_augmentations:int, seed_augmentations:int):
        """
        See MEMO section
        """
        return get_test_augmentations(input, augmentations, num_augmentations, seed_augmentations)

    def retrieve_generated_images(self, img, num_images, clip_model, preprocess, img_to_tensor_pipe, data_path, use_t2i_similarity, t2i_images, i2i_images, threshold):
        """
        See Image generation
        """
        return retrieve_gen_images(img = img,
                                   num_images = num_images,
                                   clip_model = clip_model,
                                   preprocess = preprocess,
                                   img_to_tensor_pipe = img_to_tensor_pipe,
                                   data_path = data_path,
                                   use_t2i_similarity = use_t2i_similarity,
                                   t2i_images = t2i_images,
                                   i2i_images = i2i_images,
                                   threshold = threshold)

    def test(self,
             augmentations:list,
             num_augmentations:int,
             seed_augmentations:int,
             img_root:str,
             lr_setting:list,
             weights_imagenet = None,
             dataset = "imagenetA",
             batch_size = 64,
             MEMO = False,
             num_adaptation_steps = 0,
             top_augmentations = 0,
             TTA = False,
             prior_strength = -1,
             verbose = True,
             log_interval = 1,
             MC = None,
             gen_aug_settings = None):
        """
        Main function to test a torchvision model with different test-time adaptation techniques
        and keep track of the results and the experiment setting.
        ---
        augmentations: list of torchvision.transforms functions.
        num_augmentations: the number of augmentations to use for each sample to perform test-time adaptation.
        seed_augmentations: seed to reproduce the sampling of augmentations.
        img_root: str path to get a dataset in a torch format.
        lr_setting: list with lr instructions to adapt the model. See "get_optimizer" for more details.
        weights_imagenet: weights_imagenet should have a value in accordance with the parameter
                          weights of torchvision.models.
        dataset: the name of the dataset to use. Note: this parameter doesn't directly control the data
                 used, it's only used to use the right masking to map the models' outputs to the right dimensions.
                 At the moment only Imagenet-A masking is supported.
        MEMO: a boolean to use marginal entropy minimization with one test point
        TTA: a boolean to use test time augmentation
        top_augmentations: if MEMO or TTA are set to True, then values higher than zero select the top_augmentations
                           with the lowest entropy (highest confidence).
        prior_strength: defines the weight given to pre-trained statistics in BN adaptation. If negative, then no BN
                        adaptation is applied.
        verbose: use loading bar to visualize accuracy and number of batch during testing.
        log_interval: defines after how many batches a new accuracy should be displayed. Default is 1, thus
                      after each batch a new value is displayed.
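        num_adaptation_steps: number of MEMO update steps performed on each test sample. It must be >= 1 when
                              MEMO is True and 0 otherwise.
        MC: dictionary containing the number of evaluations using Monte Carlo Dropout and the dropout rate
            ('num_samples', 'dropout_rate').
        gen_aug_settings: dictionary with the settings used to retrieve generated images and append them to the
                          augmentations ("clip_img_encoder", "num_img", "gen_data_path", "use_t2i_similarity",
                          "t2i_img", "i2i_img", "threshold"). None disables generated augmentations.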
        """
        # check some basic conditions
        assert bool(num_adaptation_steps) == MEMO, "When using MEMO, adaptation steps should be >= 1; otherwise they must be 0."
        if not (MEMO or TTA):
            assert not (num_augmentations or top_augmentations), "If both MEMO and TTA are set to False, then top_augmentations and num_augmentations must be 0"
        assert MEMO or not lr_setting, "If MEMO is false, then lr_setting must be None"
        assert isinstance(prior_strength, (float,int)) , "Prior adaptation must be either a float or an int"
        assert gen_aug_settings is None or isinstance(gen_aug_settings, dict), "gen_aug_settings must be None or a dict containing settings to retrieve the generated images"

        # get the name of the weights used and define the name of the experiment
        weights_name = str(weights_imagenet).split(".")[-1] if weights_imagenet else "MEMO_repo"
        use_MC = bool(MC)
        name_result = f"MEMO_{MEMO}_AdaptSteps_{num_adaptation_steps}_adaptBN_{prior_strength}_TTA_{TTA}_aug_{num_augmentations}_topaug_{top_augmentations}_seed_aug_{seed_augmentations}_weights_{weights_name}_MC_{use_MC}_genAug_{bool(gen_aug_settings)}"
        path_result = os.path.join(self.__exp_path,name_result)
        assert not os.path.exists(path_result),f"MEMO test already exists: {path_result}"

        # in case of using dropout, check if the model is a ResNet50Dropout and the parameters are correct
        if MC:
            assert isinstance(self.__model, ResNet50Dropout), f"To use dropout the model must be a ResNet50Dropout"
            assert MC['num_samples'] > 1, f"To use dropout the number of samples must be greater than 1"

        # standard ImageNet evaluation preprocessing for ResNet-50 (normalization is applied separately)
        transform_loader = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor()
        ])

        # to use after model's update
        normalize_input = T.Compose([
                        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                    ])

        test_loader = get_data(batch_size, img_root, transform = transform_loader, split_data=False)
        model = self.get_model(weights_imagenet, MC)

        # if MEMO is used, create a checkpoint to reload after each model and optimizer update
        if MEMO:
            optimizer = self.get_optimizer(model = model, lr_setting = lr_setting)
            MEMO_checkpoint_path = os.path.join(self.__exp_path,"checkpoint.pth")
            torch.save({
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, MEMO_checkpoint_path)
            MEMO_checkpoint = torch.load(MEMO_checkpoint_path)

        if dataset == "imagenetA":
            imagenetA_masking = self.get_imagenetA_masking()

        if gen_aug_settings:
            clip_model, preprocess = clip.load(gen_aug_settings["clip_img_encoder"])

        if prior_strength < 0:
            torch.nn.BatchNorm2d.prior_strength = 1
        else:
            torch.nn.BatchNorm2d.prior_strength = prior_strength / (prior_strength + 1)
            torch.nn.BatchNorm2d.forward = adaptive_bn_forward

        # Initialize a dictionary to store accumulated time for each step
        time_dict = {
            "MEMO_update": 0.0,
            "get_augmentations": 0.0,
            "confidence_selection": 0.0,
            "get_prediction": 0.0,
            "get_gen_images" : 0.0,
            "total_time": 0.0
        }

        samples = 0.0
        cumulative_accuracy = 0.0

        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(self.__device), targets.to(self.__device)
            if MEMO or TTA:
                for input, target in zip(inputs, targets):
                    if MEMO:
                        model.load_state_dict(MEMO_checkpoint['model'])
                        model.eval()
                        optimizer.load_state_dict(MEMO_checkpoint['optimizer'])

                    # get normalized augmentations
                    start_time_augmentations = time.time()
                    test_augmentations = self.get_test_augmentations(input, augmentations, num_augmentations, seed_augmentations)
                    end_time_augmentations = time.time()
                    time_dict["get_augmentations"] += (end_time_augmentations - start_time_augmentations)

                    if gen_aug_settings:
                        start_time_gen_augmentations = time.time()
                        retrieved_gen_images = self.retrieve_generated_images(img = input,
                                                                              num_images = gen_aug_settings["num_img"],
                                                                              clip_model = clip_model,
                                                                              preprocess = preprocess,
                                                                              img_to_tensor_pipe = transform_loader,
                                                                              data_path = gen_aug_settings["gen_data_path"],
                                                                              use_t2i_similarity = gen_aug_settings["use_t2i_similarity"],
                                                                              t2i_images = gen_aug_settings["t2i_img"],
                                                                              i2i_images = gen_aug_settings["i2i_img"],
                                                                              threshold = gen_aug_settings["threshold"])
                        if retrieved_gen_images:
                            test_augmentations = torch.cat([test_augmentations,retrieved_gen_images],dim=0)
                        end_time_gen_augmentations = time.time()
                        time_dict["get_gen_images"] += (end_time_gen_augmentations - start_time_gen_augmentations)

                    test_augmentations = test_augmentations.to(self.__device)

                    for _ in range(num_adaptation_steps):
                        logits = model(test_augmentations)

                        # apply imagenetA masking
                        if dataset == "imagenetA":
                            logits = logits[:, imagenetA_masking]
                        # compute stable softmax
                        probab_augmentations = F.softmax(logits - logits.max(dim=1)[0][:, None], dim=1)

                        # confidence selection for augmentations
                        if top_augmentations:
                            start_time_confidence_selection = time.time()
                            probab_augmentations = self.get_best_augmentations(probab_augmentations, top_augmentations)
                            end_time_confidence_selection = time.time()
                            time_dict["confidence_selection"] += (end_time_confidence_selection - start_time_confidence_selection)

                        if MEMO:
                            start_time_memo_update = time.time()
                            marginal_output_distribution = torch.mean(probab_augmentations, dim=0)
                            marginal_loss = self.compute_entropy(marginal_output_distribution)
                            marginal_loss.backward()
                            optimizer.step()
                            optimizer.zero_grad()
                            end_time_memo_update = time.time()
                            time_dict["MEMO_update"] += (end_time_memo_update - start_time_memo_update)

                    start_time_prediction = time.time()
                    with torch.no_grad():
                        if TTA:
                            # statistics:
                            # dictionary containing statistics resulting from the application of monte carlo dropout
                            # look at get_monte_carlo_statistics() for more details
                            y_pred, statistics = self.get_prediction(test_augmentations, model, imagenetA_masking, TTA, top_augmentations, MC=MC)
                        else:
                            input = normalize_input(input)
                            y_pred, statistics = self.get_prediction(input, model, imagenetA_masking, MC=MC)
                        cumulative_accuracy += int(target == y_pred)
                    end_time_prediction = time.time()
                    time_dict["get_prediction"] += (end_time_prediction - start_time_prediction)
            else:
                start_time_prediction = time.time()
                with torch.no_grad():
                    inputs = normalize_input(inputs)
                    y_pred, _ = self.get_prediction(inputs, model, imagenetA_masking, MC=MC)
                    correct_predictions = (targets == y_pred).sum().item()
                    cumulative_accuracy += correct_predictions
                end_time_prediction = time.time()
                time_dict["get_prediction"] += (end_time_prediction - start_time_prediction)

            samples += inputs.shape[0]

            if verbose and batch_idx % log_interval == 0:
                current_accuracy = cumulative_accuracy / samples * 100
                print(f'Batch {batch_idx}/{len(test_loader)}, Accuracy: {current_accuracy:.2f}%', end='\r')

        accuracy = cumulative_accuracy / samples * 100
        time_dict["total_time"] += sum(time_dict.values())

        self.save_result(accuracy = accuracy,
                         path_result = path_result,
                         seed_augmentations = seed_augmentations,
                         num_augmentations = num_augmentations,
                         augmentations = augmentations,
                         top_augmentations = top_augmentations,
                         MEMO = MEMO,
                         num_adaptation_steps = num_adaptation_steps,
                         lr_setting = lr_setting,
                         weights = weights_name,
                         prior_strength = prior_strength,
                         use_MC = use_MC,
                         time_test = time_dict,
                         gen_aug_settings = gen_aug_settings)

        return accuracy
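For each test image, the adaptation loop above does four things. It converts the augmented views into softmax outputs, optionally keeps only the most confident views (`top_augmentations`), averages the result into a marginal distribution, and minimizes the entropy of that marginal [1]. Below is a minimal, self-contained sketch of one such step. It rests on a few assumptions: a toy `nn.Linear` model and random tensors stand in for ResNet50 and real augmentations, and "confidence" is taken to mean low per-view entropy. `entropy`, `select_confident` and `memo_step` are hypothetical helpers, not the notebook's `compute_entropy` or `get_best_augmentations`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def entropy(probs: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Shannon entropy along `dim`; the clamp avoids log(0)."""
    return -(probs * probs.clamp_min(1e-12).log()).sum(dim=dim)


def select_confident(probs: torch.Tensor, top_k: int) -> torch.Tensor:
    """Keep the `top_k` views with the lowest predictive entropy (0 keeps all views).

    Assumption: confidence is measured by per-view entropy.
    """
    if top_k <= 0 or top_k >= probs.shape[0]:
        return probs
    idx = torch.topk(entropy(probs, dim=1), k=top_k, largest=False).indices
    return probs[idx]


def memo_step(model: nn.Module, views: torch.Tensor,
              optimizer: torch.optim.Optimizer, top_k: int = 0) -> float:
    """One MEMO-style update: minimize the entropy of the marginal over the views."""
    optimizer.zero_grad()
    probs = F.softmax(model(views), dim=1)  # (num_views, num_classes)
    probs = select_confident(probs, top_k)  # optional confidence selection
    loss = entropy(probs.mean(dim=0))       # entropy of the marginal distribution
    loss.backward()
    optimizer.step()
    return loss.item()


# Toy usage: 16 random "views" of one image, 2 adaptation steps with the SGD learning rate
torch.manual_seed(0)
toy_model = nn.Linear(8, 5)      # stand-in for ResNet50
toy_views = torch.randn(16, 8)   # stand-in for 16 augmented views
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.00025)
for _ in range(2):               # num_adaptation_steps
    memo_step(toy_model, toy_views, toy_optimizer, top_k=8)
```

`F.softmax` already subtracts the maximum internally. The explicit `logits - logits.max(...)` in the loop is therefore harmless, though not required for numerical stability.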

To reproduce the experiments, run the following cells in order.

In [None]:
imagenet_a_path = "imagenet-a"                                  # ImageNet-A root directory
imagenet_b_path = "imagenetv2-matched-frequency-format-val/"    # ImageNetV2 (matched-frequency) root directory

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:

augmentations = [
    T.RandomHorizontalFlip(p=1),
    T.RandomVerticalFlip(p=1),
    T.RandomRotation(degrees=30),
    T.RandomRotation(degrees=60),
    T.ColorJitter(brightness=0.2),
    T.ColorJitter(contrast=0.2),
    T.ColorJitter(saturation=0.2),
    T.ColorJitter(hue=0.2),
    T.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    T.RandomRotation(degrees=15),
    T.RandomAdjustSharpness(sharpness_factor=2, p=1),
    T.RandomGrayscale(p=1),
    T.RandomInvert(p=1),
    T.RandomAutocontrast(p=1),
    T.GaussianBlur(kernel_size=5),
]

# 18 AugMix transforms cycling through (severity, mixture_width) = (3, 3), (2, 2), (4, 4)
augmix_augmentations = [
    T.AugMix(severity=s, mixture_width=s, chain_depth=3, alpha=1.0)
    for _ in range(6)
    for s in (3, 2, 4)
]

### **ResNet50**

In [None]:
exp_path_a = "/home/sagemaker-user/Domain-Shift-Computer-Vision/experiments/Resnet50_ImagenetA_SGD"

In [None]:
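MC = {
    "dropout_rate": 0.3,   # dropout probability used for Monte Carlo dropout
    "num_samples": 10,     # number of stochastic forward passes per input
    "use_dropout": False   # if True, ResNet50Dropout is used instead of the plain ResNet50
}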
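Monte Carlo dropout [4] keeps dropout active at inference time and averages the predictions of `num_samples` stochastic forward passes. The spread of those passes serves as an uncertainty estimate. The notebook's `get_monte_carlo_statistics()` is not shown here, so the sketch below only illustrates the idea on a toy network. The name `mc_dropout_predict` is hypothetical, and so is the choice of statistics it returns (mean probabilities and per-class variance). It also assumes the model's dropout layers are `nn.Dropout` modules.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def mc_dropout_predict(model: nn.Module, x: torch.Tensor, num_samples: int = 10):
    """Average softmax outputs over `num_samples` passes with dropout enabled."""
    model.eval()
    # Re-enable only the dropout layers, so BatchNorm keeps using its stored statistics
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()
    probs = torch.stack([F.softmax(model(x), dim=-1) for _ in range(num_samples)])
    model.eval()  # restore fully deterministic behaviour
    mean = probs.mean(dim=0)
    return mean.argmax(dim=-1), {"mean": mean, "variance": probs.var(dim=0)}
```

Enabling only the dropout modules, instead of calling `model.train()`, matters for a BatchNorm network such as ResNet50. Train mode would also switch BatchNorm to batch statistics and update its running averages.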

In [None]:
tester_resnet50 = Tester(
    model = ResNet50Dropout() if MC['use_dropout'] else models.resnet50,
    optimizer = torch.optim.SGD,
    exp_path = exp_path_a,
    device = device
)

In [None]:
#lr_setting = [{
#    "classifier" : [["fc.weight", "fc.bias"], 0.00025]
#}, 0]
lr_setting_sgd = [0.00025] # setting used in MEMO paper for SGD
lr_setting_adam = [0.0001] # setting used in MEMO paper for ADAM

In [None]:
imagenetV1_weights = models.ResNet50_Weights.IMAGENET1K_V1 # MEMO paper used these weights
imagenetV2_weights = models.ResNet50_Weights.IMAGENET1K_V2

In [None]:
gen_aug_settings = {
    "clip_img_encoder" : "ViT-L/14",
    "num_img" : 30,
    "gen_data_path" : "/home/sagemaker-user/Domain-Shift-Computer-Vision/imagenetA_generated",
    "use_t2i_similarity" : True,
    "t2i_img" : True,
    "i2i_img" : False,
    "threshold" : 0.7
}
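`retrieve_generated_images` presumably ranks the pre-generated images by CLIP similarity to the test image. It then keeps at most `num_img` of them whose similarity exceeds `threshold`. Its body is not part of this section, so the following is only a hedged sketch of that filtering step on precomputed embeddings. Loading CLIP and encoding images (e.g. with `clip_model.encode_image`) is omitted, and small hand-made vectors stand in for ViT-L/14 embeddings. `filter_by_similarity` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def filter_by_similarity(query: torch.Tensor, candidates: torch.Tensor,
                         num_img: int, threshold: float) -> torch.Tensor:
    """Indices of at most `num_img` candidates with cosine similarity > `threshold`.

    query:      (d,) embedding of the test image
    candidates: (n, d) embeddings of the generated images
    Returned indices are ordered from most to least similar.
    """
    sims = F.normalize(candidates, dim=1) @ F.normalize(query, dim=0)  # (n,) cosine similarities
    order = torch.argsort(sims, descending=True)
    keep = order[sims[order] > threshold]  # drop candidates below the threshold
    return keep[:num_img]
```

A strict threshold such as 0.7 can leave fewer than `num_img` images, or none at all, for hard test samples. Any downstream code should therefore tolerate an empty result.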

In [None]:
tester_resnet50.test(
    augmentations = augmix_augmentations,
    num_augmentations = 16,
    seed_augmentations = 42,
    batch_size = 64,
    img_root = imagenet_a_path,
    num_adaptation_steps = 2,
    MEMO = True,
    lr_setting = lr_setting_sgd,
    top_augmentations = 0, # keep at 0 with gen_aug_settings, otherwise confidence selection may discard the generated images entirely
    weights_imagenet = imagenetV1_weights,
    prior_strength = 16,
    TTA = True,
    MC = None,
    gen_aug_settings = gen_aug_settings
)
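The `prior_strength` argument controls the BatchNorm adaptation of [3]. In `test()`, a value N ≥ 0 becomes the weight N / (N + 1) on the source (training) statistics, and the remaining 1 / (N + 1) goes to the statistics of the current test batch. With TTA, that batch is the set of augmented views of one image. `adaptive_bn_forward` is not shown in this section, so the following is only a sketch of this convex mixing for a single `BatchNorm2d` layer. `mix_bn_statistics` is a hypothetical name. Mixing means and variances independently follows the simple form in [3], not a confirmed detail of this notebook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def mix_bn_statistics(bn: nn.BatchNorm2d, x: torch.Tensor, prior_strength: float) -> torch.Tensor:
    """BatchNorm2d forward using a convex mix of source and test-batch statistics.

    prior = N / (N + 1) weights the stored (source) running statistics;
    1 - prior weights the statistics of the current batch `x` of shape (B, C, H, W).
    """
    prior = prior_strength / (prior_strength + 1)
    batch_mean = x.mean(dim=(0, 2, 3))
    batch_var = x.var(dim=(0, 2, 3), unbiased=False)
    mean = prior * bn.running_mean + (1 - prior) * batch_mean
    var = prior * bn.running_var + (1 - prior) * batch_var
    # training=False makes F.batch_norm normalize with the given mean/var without updating them
    return F.batch_norm(x, mean, var, bn.weight, bn.bias, training=False, eps=bn.eps)
```

With `prior_strength = 16`, the source statistics receive weight 16/17 ≈ 0.94. A single test image therefore shifts the normalization only slightly. Very large values recover the standard eval-mode BatchNorm, and 0 uses the test-batch statistics alone.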

## **Results**

<table>
  <thead>
    <tr>
      <th>Experiment</th>
      <th rowspan="2">Dataset</th>
      <th rowspan="2">Base Model</th>
      <th rowspan="2">Weights</th>
      <th colspan="2">Optimizer</th>
      <th>Optimization Steps</th>
      <th colspan="3">Augmentations</th>
      <th rowspan="2">Batch Size</th>
      <th rowspan="2">MEMO</th>
      <th>Confidence Selection</th>
      <th colspan="2">BN</th>
      <th rowspan="2">TTA</th>
      <th colspan="3">MC</th>
      <th>Accuracy</th>
      <th>Execution Time</th>
    </tr>
    <tr>
      <th>Nr.</th>
      <th>Type</th>
      <th>LR</th>
      <th>Nr.</th>
      <th>Type</th>
      <th>Number</th>
      <th>Seed</th>
      <th>Top-k</th>
      <th>Enabled</th>
      <th>Prior Strength</th>
      <th>Enabled</th>
      <th>Dropout Rate</th>
      <th>Nr. Samples</th>
      <th>%</th>
      <th>hh:mm:ss</th>
    </tr>
  </thead>
  <tbody align="center">
    <tr>
      <th>1</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>-</td>
      <td>1</td>
      <td>-</td>
      <td>-</td>
      <td>-</td>
      <td>64</td>
      <td>False</td>
      <td>0</td>
      <td>False</td>
      <td>-</td>
      <td>False</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>0.026</td>
      <td>00:00:20</td>
    </tr>
    <tr>
      <th>2</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>-</td>
      <td>1</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>False</td>
      <td>8</td>
      <td>False</td>
      <td>-</td>
      <td>True</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>0.253</td>
      <td>00:26:20</td>
    </tr>
    <tr>
      <th>3</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>-</td>
      <td>1</td>
      <td>-</td>
      <td>-</td>
      <td>-</td>
      <td>64</td>
      <td>False</td>
      <td>0</td>
      <td>False</td>
      <td>-</td>
      <td>False</td>
      <td>True</td>
      <td>0.20</td>
      <td>10</td>
      <td>1.613</td>
      <td>00:03:48</td>
    </tr>
    <tr>
      <th>4</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>-</td>
      <td>1</td>
      <td>-</td>
      <td>-</td>
      <td>-</td>
      <td>64</td>
      <td>False</td>
      <td>0</td>
      <td>True</td>
      <td>16</td>
      <td>False</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>0.026</td>
      <td>00:00:23</td>
    </tr>
    <tr>
      <th>5</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>0.00025</td>
      <td>1</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>False</td>
      <td>-</td>
      <td>False</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>0.16</td>
      <td>00:33:20</td>
    </tr>
    <tr>
      <th>6</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>0.00025</td>
      <td>1</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>0.853</td>
      <td>00:38:44</td>
    </tr>
    <tr>
      <th>7</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>0.00025</td>
      <td>1</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>True</td>
      <td>0.20</td>
      <td>10</td>
      <td>1.053</td>
      <td>01:43:29</td>
    </tr>
    <tr>
      <th>8</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>0.00025</td>
      <td>4</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>False</td>
      <td>-</td>
      <td>False</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>0.213</td>
      <td>00:42:03</td>
    </tr>
    <tr>
      <th>9</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>0.00025</td>
      <td>4</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>0.853</td>
      <td>00:47:37</td>
    </tr>
    <tr>
      <th>10</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>0.00025</td>
      <td>4</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>True</td>
      <td>0.20</td>
      <td>10</td>
      <td>0.95</td>
      <td>01:49:42</td>
    </tr>
    <tr>
      <th>11</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>ADAM</td>
      <td>0.0001</td>
      <td>1</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>False</td>
      <td>-</td>
      <td>False</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>0.12</td>
      <td>00:31:45</td>
    </tr>
    <tr>
      <th>12</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>ADAM</td>
      <td>0.0001</td>
      <td>1</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>0.826</td>
      <td>00:37:37</td>
    </tr>
    <tr>
      <th>13</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>ADAM</td>
      <td>0.0001</td>
      <td>1</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>True</td>
      <td>0.20</td>
      <td>10</td>
      <td>1.506</td>
      <td>01:39:47</td>
    </tr>
    <tr>
      <th>14</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>ADAM</td>
      <td>0.0001</td>
      <td>4</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>False</td>
      <td>-</td>
      <td>False</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>0.213</td>
      <td>00:41:19</td>
    </tr>
    <tr>
      <th>15</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>ADAM</td>
      <td>0.0001</td>
      <td>4</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>0.826</td>
      <td>00:46:55</td>
    </tr>
    <tr>
      <th>16</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>ADAM</td>
      <td>0.0001</td>
      <td>4</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>True</td>
      <td>0.20</td>
      <td>10</td>
      <td>1.04</td>
      <td>01:44:57</td>
    </tr>  
  </tbody>
</table>


## **Discussion**

Discussion of the results.

## **Conclusion**

Wrap-up of the work done and possible alternatives for future experiments.

## **Bibliography**

1. **Zhang, Marvin and Levine, Sergey and Finn, Chelsea.** "MEMO: Test Time Robustness via Adaptation and Augmentation." Advances in Neural Information Processing Systems, Vol. 35, 2022, pp. 38629-38642. [https://arxiv.org/abs/2110.09506](https://arxiv.org/abs/2110.09506).

2. **Lyzhov, Alexander and Molchanova, Yuliya and Ashukha, Arsenii and Molchanov, Dmitry and Vetrov, Dmitry.** "Greedy Policy Search: A Simple Baseline for Learnable Test-Time Augmentation." Proceedings of Machine Learning Research, Vol. 124, 2020, pp. 1308-1317. [https://proceedings.mlr.press/v124/lyzhov20a.html](https://proceedings.mlr.press/v124/lyzhov20a.html).

3. **Schneider, Steffen and Rusak, Evgenia and Eck, Luisa and Bringmann, Oliver and Brendel, Wieland and Bethge, Matthias.** "Improving robustness against common corruptions by covariate shift adaptation." Advances in Neural Information Processing Systems, Vol. 33, 2020, pp. 11539-11551. [https://proceedings.neurips.cc/paper_files/paper/2020/file/85690f81aadc1749175c187784afc9ee-Paper.pdf](https://proceedings.neurips.cc/paper_files/paper/2020/file/85690f81aadc1749175c187784afc9ee-Paper.pdf).

4. **Gal, Yarin and Ghahramani, Zoubin.** "Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning." Proceedings of The 33rd International Conference on Machine Learning, Vol. 48, 2016, pp. 1050-1059. [https://proceedings.mlr.press/v48/gal16.html](https://proceedings.mlr.press/v48/gal16.html).

5. **Feng, Chun-Mei and Yu, Kai and Liu, Yong and Khan, Salman and Zuo, Wangmeng.** "Diverse Data Augmentation with Diffusions for Effective Test-time Prompt Tuning." 2023 IEEE/CVF International Conference on Computer Vision (ICCV), 2023, pp. 2704-2714. [https://arxiv.org/abs/2308.06038](https://arxiv.org/abs/2308.06038).

6. [Ollama library](https://ollama.com/)



