# Deep Learning - Test-Time Adaptation
---
###### University of Trento, Academic Year 2023/2024
---
##### Group 26
> <a href="https://github.com/giuseppecurci">Giuseppe Curci</a> \
> 243049

> <a href="https://github.com/andy295">Andrea Cristiano</a> \
> 229370
---
---

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

import torchvision
import torchvision.transforms as T
import torchvision.models as models
from torchvision.transforms import Compose, Normalize, ToTensor

import matplotlib
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D

from scipy import stats
from scipy.ndimage import zoom

from typing import Callable, List, Optional, Tuple
from typing import Dict, List

from io import BytesIO
from pathlib import Path
from PIL import Image
from tqdm import tqdm
from utility.data.get_data import get_data
import boto3 # read and write for AWS buckets
import clip
import cv2
import gc
import json
import math
import numpy as np
import ollama # if ollama is not available, install by executing the intall_and_run_ollama.sh script
import os
import random
import time
import ttach as tta

from test_methods.test import Tester

from test_time_adaptation.adaptive_bn import adaptive_bn_forward
from test_time_adaptation.MEMO import compute_entropy, get_best_augmentations, get_test_augmentations
from test_time_adaptation.resnet50_dropout import ResNet50Dropout

## Introduction

(TODO: explain in a few lines what Domain Shift and Test-Time Adaptation are and why they are relevant)

## Pipeline

(TODO: ideally a diagram should be created to explain what we are doing; it can be added later. For now the pipeline can simply be described in words, or this section can be left as is until the diagram is available)

## Utils



### File get_data.py

In [None]:
# Class that interacts with images stored in an Amazon S3 bucket.
# It allows to load and preprocess images on-the-fly during training or inference.
class S3ImageFolder(Dataset):
    """
    Map-style dataset over images stored in an Amazon S3 bucket.
    The label of an image is the name of the parent "directory" in its object key and
    class indices follow the alphabetical order of those names (as in ImageFolder).
    Images are downloaded and decoded lazily, one object per __getitem__ call.
    ----------
    root: key prefix under which the objects of the bucket are listed.
    transform: optional callable applied to the decoded RGB PIL image.
    """

    # File suffixes accepted as images (compared case-insensitively)
    VALID_SUFFIXES = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")

    def __init__(self, root, transform=None):
        self.s3_bucket = "deeplearning2024-datasets" # name of the bucket
        self.s3_region = "eu-west-1" # Ireland
        self.s3_client = boto3.client("s3", region_name=self.s3_region, verify=True)
        self.transform = transform

        # Get list of objects in the bucket; the listing is paginated, so keep
        # requesting pages for as long as a continuation token is returned
        response = self.s3_client.list_objects_v2(Bucket=self.s3_bucket, Prefix=root)
        objects = list(response.get("Contents", []))
        while response.get("NextContinuationToken"):
            response = self.s3_client.list_objects_v2(
                Bucket=self.s3_bucket,
                Prefix=root,
                ContinuationToken=response["NextContinuationToken"]
            )
            objects.extend(response.get("Contents", []))

        # Keep valid image files only, as (label, key) pairs
        self.instances = []
        for item in objects:
            key = item["Key"]
            path = Path(key)
            if path.suffix.lower() not in self.VALID_SUFFIXES:
                continue
            self.instances.append((path.parent.name, key))

        # Sort classes in alphabetical order (as in ImageFolder)
        self.classes = sorted(set(label for label, _ in self.instances))
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx):
        """
        Download, decode and (optionally) transform the image at position idx.
        ----------
        returns: tuple (image, class index)
        raises: IndexError if idx is out of range, RuntimeError if the image cannot
                be downloaded, decoded or transformed.
        """
        # The lookup stays outside the try block: an out-of-range index must surface as
        # IndexError, which plain "for sample in dataset" iteration relies on to stop.
        label, key = self.instances[idx]

        try:
            # Download the object into an in-memory buffer
            img_bytes = BytesIO()
            self.s3_client.download_fileobj(Bucket=self.s3_bucket, Key=key, Fileobj=img_bytes)
            # Rewind so that decoding starts from the first byte of the buffer
            img_bytes.seek(0)

            # Open image with PIL
            img = Image.open(img_bytes).convert("RGB")

            # Apply transformations if any
            if self.transform is not None:
                img = self.transform(img)
        except Exception as e:
            # Keep the original exception as the cause so the traceback stays useful
            raise RuntimeError(f"Error loading image at index {idx}: {str(e)}") from e

        return img, self.class_to_idx[label]

# Function to create DataLoaders for training and evaluating models.
# Loads the dataset from the S3 bucket and optionally splits it into training,
# validation, and test sets. It then returns PyTorch DataLoader objects for these datasets.
def get_data(batch_size, img_root, seed = None, split_data = False, transform = None):
    """
    Build the DataLoader(s) for the images found under img_root in the S3 bucket.
    ----------
    batch_size: batch size used by every returned DataLoader.
    img_root: key prefix passed to S3ImageFolder as root.
    seed: optional seed for the global torch RNG, applied right before the random split.
          If None the RNG is left untouched and the split is not reproducible.
    split_data: if True the dataset is split into train/validation/test subsets,
                otherwise a single loader over the whole dataset is returned.
    transform: optional transform forwarded to S3ImageFolder.
    ----------
    returns: (train_loader, val_loader, test_loader) if split_data is True, otherwise a
             single DataLoader. Only the train loader shuffles its samples.
    """
    # Load data
    data = S3ImageFolder(root=img_root, transform=transform)

    if not split_data:
        return torch.utils.data.DataLoader(data, batch_size, shuffle=False, num_workers=4)

    # Create train, validation and test splits (roughly 80/10/10)
    num_samples = len(data)
    training_samples = int(num_samples * 0.8 + 1)
    val_samples = int(num_samples * 0.1)
    test_samples = num_samples - training_samples - val_samples

    # torch.manual_seed does not accept None, so only seed when a value was given
    if seed is not None:
        torch.manual_seed(seed)
    training_data, val_data, test_data = torch.utils.data.random_split(data, [training_samples, val_samples, test_samples])

    # Initialize dataloaders
    train_loader = torch.utils.data.DataLoader(training_data, batch_size, shuffle=True, num_workers=4)
    val_loader = torch.utils.data.DataLoader(val_data, batch_size, shuffle=False, num_workers=4)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size, shuffle=False, num_workers=4)

    return train_loader, val_loader, test_loader

### File imagenetA_masking.json

The **imagenetA_masking.json** file provides a mask that maps the standard 1000-class ImageNet output indices, used by pre-trained models in PyTorch's torchvision library, to the classes of the ImageNet-A dataset.

Each key in the file corresponds to an index in the standard 1000-class ImageNet output vector. The value associated with each key indicates whether that index should be considered when mapping the 1000-class output to the smaller set of ImageNet-A classes.

A value of -1 indicates that the corresponding index in the 1000-class output should be ignored in the subset of outputs for ImageNet-A.
A non-negative integer value indicates that the corresponding index in the 1000-class output should be included in the subset of outputs for ImageNet-A.

### File imagenetA_classes.json

The **imagenetA_classes.json** file provides a mapping between the synset IDs used in the ImageNet dataset and their corresponding class names. This mapping is essential for converting model outputs from synset IDs to human-readable labels.

The class IDs are mainly used to create the directory structure where newly generated images will be stored, forming a secondary dataset. This dataset can then be utilized for training and inference activities. For more details, refer to the **ToDo** section.

### File dropout_positions.json

The **dropout_positions.json** file defines the locations within a custom ResNet50 model where dropout layers should be inserted.

The dropout layers are incorporated to enhance the proposed method using Monte Carlo Dropout, a technique that improves model robustness and uncertainty estimation. For more details, refer to the **ResNet50Dropout** section.

## Test-time adaptation methods

### MEMO: Test Time Robustness via Adaptation and Augmentation<sup>[4]</sup>

In this paper, the authors propose a method called MEMO (Marginal Entropy Minimization with One Test Point) designed to address the problem of robustness in deep neural networks when confronted with distribution shifts or unexpected perturbations. To achieve this, the method employs both adaptation and augmentation strategies.

![Figure 1](\images\MEMO.png)

Unlike traditional approaches that focus on modifying the training process, MEMO utilizes the information provided by test inputs. It applies data augmentations to a single test input to generate various versions of the input, and from these, it calculates the marginal output distribution. The model's parameters are then updated to minimize the entropy of this marginal distribution. Finally, the model uses these updated parameters to make a prediction on the original test input.



This approach allows MEMO to enhance the model's robustness against unseen distribution changes by effectively adapting to new conditions during testing.

### TTA - Greedy Policy Search: A Simple Baseline for Learnable Test-Time Augmentation<sup>[5]</sup>

In this paper, the authors propose a method called Greedy Policy Search (GPS) for learning test-time data augmentation policies that enhance the performance of machine learning models. The key idea is that data augmentation policies, typically designed for the training phase, can also be learned and optimized during the test phase to improve model performance.

![Figure 2](\images\) TODO ADD image

The method involves iteratively selecting sub-policies that maximize a chosen performance criterion, such as the calibrated log-likelihood on a validation set. As the name suggests, GPS constructs the augmentation policy in a greedy, step-by-step manner, ensuring that each added sub-policy contributes to performance improvement.

This approach offers a simple yet powerful baseline for learning test-time augmentation policies, providing a promising alternative to more complex methods like reinforcement learning or Bayesian optimization, which are traditionally used for training-phase augmentation.

### Adaptive Batch Normalization - Improving robustness against common corruptions by covariate shift adaptation<sup>[6]</sup>

### Monte Carlo Dropout

Monte Carlo Dropout (MC Dropout) is a technique that leverages dropout, a regularization method commonly used during training, to also perform approximate Bayesian inference during the testing phase.

During the model training phase, the dropout technique randomly "drops" or deactivates a fraction of neurons in the network during each forward pass. The probability of dropping neurons can be controlled using a specific parameter. The aim is to prevent neurons from memorizing specific inputs, thus reducing overfitting and encouraging the network to learn more general representations.

However, during the model test phase, dropout is usually turned off, allowing the full network to make predictions. The key idea behind Monte Carlo Dropout is to keep dropout active during the test phase and perform multiple forward passes through the network. The result is that for each forward pass, a different dropout mask is used, which means a random subset of neurons is activated, allowing for different predictions. By averaging these predictions, it is possible to obtain both the final prediction and a measure of the model's uncertainty.

The technique was originally introduced by Yarin Gal and Zoubin Ghahramani in their seminal paper, "Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning"[2].

In [None]:
class ResNet50Dropout(nn.Module):
    """
    It creates a version of the ResNet-50 model that integrates dropout layers at
    various points in the architecture. By using dropout, the model can be trained
    with a regularization technique that allows for the implementation of
    Monte Carlo Dropout.
    ----------
    weights: Optional pre-trained weights for the ResNet-50 model.
    dropout_rate: The probability of dropping out neurons during training.
    A value of 0 means no dropout is applied and the architecture is identical to the
    original ResNet-50.
    """

    # Submodules that, when listed in the JSON file, are followed by a dropout layer.
    # 'fc' is handled separately because its dropout goes in front of it.
    _POST_DROPOUT_LAYERS = ("conv1", "layer1", "layer2", "layer3", "layer4", "avgpool")

    def __init__(self, weights=None, dropout_rate=0.):
        super(ResNet50Dropout, self).__init__()

        self.weights = weights
        self.model = models.resnet50(weights=self.weights)
        self.dropout_rate = dropout_rate

        # The positions file is only read when dropout is actually requested
        self.dropout_positions = []
        if self.dropout_rate > 0:
            self.dropout_positions = self.get_dropout_positions()

        self._add_dropout()

    # This method reads a JSON file that contains a list of layer names where
    # dropout should be applied.
    def get_dropout_positions(self):
        dropout_positions_path = "/home/sagemaker-user/Domain-Shift-Computer-Vision/utility/data/dropout_positions.json"
        with open(dropout_positions_path, 'r') as json_file:
            dropout_positions = json.load(json_file)
        return dropout_positions["dropout_positions"]

    # This method adds dropout layers to the ResNet-50 model at the specified
    # positions, by looking at the dropout_positions list.
    # Each listed layer is replaced by a nn.Sequential block made of the original
    # layer followed by a nn.Dropout layer with the specified dropout rate.
    # For 'fc' the order is reversed: the dropout layer precedes the classifier.
    def _add_dropout(self):
        for layer_name in self._POST_DROPOUT_LAYERS:
            if layer_name not in self.dropout_positions:
                continue
            wrapped = nn.Sequential(
                getattr(self.model, layer_name),
                nn.Dropout(p=self.dropout_rate)
            )
            setattr(self.model, layer_name, wrapped)

        if 'fc' in self.dropout_positions:
            self.model.fc = nn.Sequential(
                nn.Dropout(p=self.dropout_rate),
                self.model.fc
            )

    def forward(self, x):
        return self.model(x)

### DiffTPT - Diverse Data Augmentation with Diffusions for Effective Test-time Prompt Tuning<sup>[7]</sup>


In [None]:
# difftpt and image generation code

—-- our proposal MEMO + stable diffusion + LLM

In [None]:
def get_imagenetA_classes():
    """
    Load the ImageNet-A class table stored in imagenetA_classes.json.
    ImageNet-A shares its label structure with the original ImageNet (ImageNet-1K), where
    every class is identified by a synset ID (e.g., n01440764 for "tench, Tinca tinca").
    ----------
    returns: a new dictionary holding the key/value pairs of the JSON object, with
             class IDs as keys and class names as values
    """
    json_path = "/home/sagemaker-user/Domain-Shift-Computer-Vision/utility/data/imagenetA_classes.json"
    with open(json_path, 'r') as fp:
        loaded = json.load(fp)

    # Hand back a shallow copy of the loaded mapping
    return dict(loaded.items())

def create_dir_generated_images(path):
    """
    Create, inside path, one directory per ImageNet-A class name.
    Directories that already exist are left untouched.
    """
    for class_name in get_imagenetA_classes().values():
        os.makedirs(os.path.join(path, class_name), exist_ok=True)

### File install_and_run_ollama.sh

To create the new images, we decided to use a StableDiffusion model to which we provided a list of prompts as input, one for each image to be generated. To generate the prompts, we used an LLM, specifically Llama 3.1. We were able to obtain this model through the Ollama library[3], which provides the model and all the necessary configurations for its use. To get everything needed to download and run Ollama, and therefore Llama 3.1, simply execute the file **install_and_run_ollama.sh**. The script will download and install Ollama, start it, and instantiate Llama 3.1.

### File llm_context.json

The **llm_context.json** file is used to provide context for generating prompts for a text-to-image generator model.

Each entry in the file specifies a Role and Content:

* **Role**: Specifies who is speaking or interacting in the conversation. It can either be "system" (representing the LLM) or "user" (representing the person providing input to the LLM).
* **Content**: This field contains the actual message or instructions being communicated. It includes system instructions or user-provided input regarding the prompts to be generated.

The messages can be of the following types:
1.    A message that sets the system's role and context, instructing the LLM on how to generate prompts for a text-to-image generation task.
2.    A message that serves as an example input a user might provide. It helps demonstrate how the user specifies the class name, number of prompts, and style of the picture.
3.    A message that provides an example output from the LLM, demonstrating the kind of response it should generate based on the provided input.

In conclusion, the content of the file provides a framework that guides the LLM in understanding the context of the conversation, adhering to rules, and producing the desired output format.

## Image generation

In [None]:
—-- image_generator.py (TODO)

## Testing

The Tester class is designed to facilitate the running of experiments involving a deep neural network model. It provides methods to manage various aspects of the experimental setup, including configuring models and optimizers, handling augmentations, computing statistics, and saving results.

In [None]:
class Tester:
    """
    A class to run all the experiments. It stores all the informations to reproduce the experiments in a json file
    at exp_path.
    """
    def __init__(self, model, optimizer, exp_path, device):
        self.__model = model
        self.__optimizer = optimizer
        self.__device = device
        self.__exp_path = exp_path

    def save_result(self, accuracy, path_result, num_augmentations, augmentations, seed_augmentations, top_augmentations, MEMO, num_adaptation_steps, lr_setting, weights, prior_strength, time_test, use_MC):
        """
        Takes all information of the experiment and saves it in a json file stored at path_result.
        Saving is best effort: if the data cannot be serialized or the file cannot be written,
        a message is printed and no exception is propagated.
        ----------
        augmentations: iterable of augmentations, each one is stored through its str() representation.
        All the other arguments are written as they are and must therefore be json serializable.
        """
        data = {
            "accuracy": accuracy,
            "top_augmentations" : top_augmentations,
            "use_MEMO" : MEMO,
            "num_adaptation_steps" : num_adaptation_steps,
            "lr_setting" : lr_setting,
            "weights" : weights,
            "num_augmentations" : num_augmentations,
            "seed_augmentations": seed_augmentations,
            "augmentations" : [str(augmentation) for augmentation in augmentations],
            "prior_strength" : prior_strength,
            "MC" : use_MC,
            "time_test" : time_test
        }
        try:
            # Serialize before opening the file: a value that is not json serializable
            # must not leave an empty or truncated result file behind.
            payload = json.dumps(data)
            with open(path_result, 'w') as json_file:
                json_file.write(payload)
        except (OSError, TypeError, ValueError) as err:
            print(f"Result were not saved: {err!r}")

    def get_model(self, weights_imagenet, MC):
        """
        Utility function to instantiate a torch model. The argument weights_imagenet should have
        a value in accordance with the parameter weights of torchvision.models.
        """
        if MC:
            self.__model=ResNet50Dropout(weights=weights_imagenet, dropout_rate=MC['dropout_rate'])
            model = self.__model
        else:
            model = self.__model(weights=weights_imagenet)

        model.to(self.__device)
        model.eval()
        return model

    def get_optimizer(self, model, lr_setting:list):
        """
        Utility function to instantiate a torch optimizer.
        ----------
        lr_setting: must be a list containing either one global lr for the whole model or a dictionary
        where each value is a list with a list of parameters' names and a lr for those parameters.
        e.g.
        lr_setting = [{
            "classifier" : [["fc.weight", "fc.bias"], 0.00025]
            }, 0]
        lr_setting = [0.00025]
        """
        if len(lr_setting) == 2:
            layers_groups = []
            lr_optimizer = []
            for layers, lr_param_name in lr_setting[0].items():
                layers_groups.extend(lr_param_name[0])
                params = [param for name, param in model.named_parameters() if name in lr_param_name[0]]
                lr_optimizer.append({"params":params, "lr": lr_param_name[1]})
            other_params = [param for name, param in model.named_parameters() if name not in layers_groups]
            lr_optimizer.append({"params":other_params})
            optimizer = self.__optimizer(lr_optimizer, lr = lr_setting[1], weight_decay = 0)
        else:
            optimizer = self.__optimizer(model.parameters(), lr = lr_setting[0], weight_decay = 0)
        return optimizer

    def get_imagenetA_masking(self):
        """
        All torchvision models output a tensor [B,1000] with "B" being the batch dimension. This function
        reads imagenetA_masking.json and returns the list of indices to apply to the model's output
        to use the model on imagenet-A dataset.
        ----------
        returns: list of the keys of the json file, converted to int, whose value is different
                 from -1; used to map [B,1000] -> [B,200]
        """
        masking_file = "/home/sagemaker-user/Domain-Shift-Computer-Vision/utility/data/imagenetA_masking.json"
        with open(masking_file, 'r') as fp:
            masking = json.load(fp)
        # entries marked with -1 do not belong to imagenet-A and are dropped
        return [int(index) for index, flag in masking.items() if flag != -1]

    def get_monte_carlo_statistics(self, mc_logits):
        """
        Compute mean, median, mode and standard deviation of the Monte Carlo samples.
        ----------
        mc_logits: tensor whose first dimension indexes the Monte Carlo forward passes and whose
                   last dimension indexes the classes, i.e. [S,B,C] or [S,C] as built in get_prediction.
        ----------
        returns: a dictionary with the keys
                 'mean', 'median', 'std': element-wise statistics over the S samples
                 'mode': the class predicted most often over the S samples, as a long tensor on CPU
                         with the shape of mc_logits without its first and last dimensions
        """
        statistics = {}
        statistics['mean'] = mc_logits.mean(dim=0)
        statistics['median'] = mc_logits.median(dim=0).values

        # The predicted class of each sample is the argmax over the class (last) dimension;
        # reducing over dim=1 would pick the batch axis when the input is [S,B,C].
        pred_classes = mc_logits.argmax(dim=-1)
        mode_predictions, _ = stats.mode(pred_classes.cpu().numpy(), axis=0)
        # The reshape pins the output shape whether or not scipy keeps the reduced axis
        statistics['mode'] = torch.tensor(mode_predictions, dtype=torch.long).reshape(pred_classes.shape[1:])

        # The entry is named 'std', so store the standard deviation and not the variance
        statistics['std'] = mc_logits.std(dim=0)
        return statistics

    def get_prediction(self, image_tensors, model, masking, TTA = False, top_augmentations = 0, MC = None):
        """
        Takes a tensor of images and outputs a prediction for each image.
        ----------
        image_tensors: is a tensor of [B,C,H,W] if TTA is used or if both MEMO and TTA are not used, or of dimension [C,H,W]
                       if only MEMO is used
        masking: a list of indices to map the imagenet1k logits to the one of imagenet-A
        top_augmentations: a non-negative integer, if greater than 0 then the "top_augmentations" with the lowest entropy are
                           selected to make the final prediction
        MC: a dictionary containing the number of evaluations using Monte Carlo Dropout and the dropout rate
        """
        if MC:
            model.train()  # enable dropout by setting the model to training mode
            mc_logits = []
            for _ in range(MC['num_samples']):
                logits = model(image_tensors)[:,masking] if image_tensors.dim() == 4 else model(image_tensors.unsqueeze(0))[:,masking]
                mc_logits.append(logits)
            mc_logits = torch.stack(mc_logits, dim=0)
            if TTA:
                # first mean is over MC samples, second mean is over TTA augmentations
                probab_augmentations = F.softmax(mc_logits - mc_logits.max(dim=2, keepdim=True)[0], dim=2)
                if top_augmentations:
                    probab_augmentations = self.get_best_augmentations(probab_augmentations, top_augmentations)
                y_pred = probab_augmentations.mean(dim=0).mean(dim=0).argmax().item()
                statistics = self.get_monte_carlo_statistics(probab_augmentations.mean(dim=1))
                return y_pred, statistics
            statistics = self.get_monte_carlo_statistics(mc_logits)
            return statistics['median'].argmax(dim=1), statistics
        else:
            logits = model(image_tensors)[:,masking] if image_tensors.dim() == 4 else model(image_tensors.unsqueeze(0))[:,masking]
            if TTA:
                probab_augmentations = F.softmax(logits - logits.max(dim=1)[0][:, None], dim=1)
                if top_augmentations:
                    probab_augmentations = self.get_best_augmentations(probab_augmentations, top_augmentations)
                y_pred = probab_augmentations.mean(dim=0).argmax().item()
                return y_pred, None
            return logits.argmax(dim=1), None

    def compute_entropy(self, probabilities: torch.Tensor):
        """
        Thin wrapper that forwards probabilities to the compute_entropy function imported
        from test_time_adaptation.MEMO and returns its result unchanged.
        See MEMO.py for the details of the computation.
        """
        return compute_entropy(probabilities)

    def get_best_augmentations(self, probabilities: torch.Tensor, top_k: int):
        """
        Thin wrapper that forwards its arguments to the get_best_augmentations function imported
        from test_time_adaptation.MEMO and returns its result unchanged.
        See MEMO.py for the details of the selection.
        """
        return get_best_augmentations(probabilities, top_k)

    def get_test_augmentations(self, input: torch.Tensor, augmentations: list, num_augmentations: int, seed_augmentations: int):
        """
        Thin wrapper that forwards its arguments to the get_test_augmentations function imported
        from test_time_adaptation.MEMO and returns its result unchanged.
        See MEMO.py for the details of how the augmentations are produced.
        """
        return get_test_augmentations(input, augmentations, num_augmentations, seed_augmentations)

    def retrieve_synthetic_images(self):
        """
        Function to retrieve the synthetically generated images before test time using CLIP embeddings.
        Not implemented yet: the body is a placeholder and the call returns None.
        """
        pass

    def test(self,
             augmentations:list,
             num_augmentations:int,
             seed_augmentations:int,
             img_root:str,
             lr_setting:list,
             weights_imagenet = None,
             dataset = "imagenetA",
             batch_size = 64,
             MEMO = False,
             num_adaptation_steps = 0,
             top_augmentations = 0,
             TTA = False,
             prior_strength = -1,
             verbose = True,
             log_interval = 1,
             MC = None):
        """
        Main function to test a torchvision model with different test-time adaptation techniques
        and keep track of the results and the experiment setting.
        ---
        augmentations: list of torchvision.transforms functions.
        num_augmentations: the number of augmentations to use for each sample to perform test-time adaptation.
        seed_augmentations: seed to reproduce the sampling of augmentations.
        img_root: str path to get a dataset in a torch format.
        lr_setting: list with lr instructions to adapt the model. See "get_optimizer" for more details.
        weights_imagenet: weights_imagenet should have a value in accordance with the parameter
                          weights of torchvision.models.
        dataset: the name of the dataset to use. Note: this parameter doesn't directly control the data
                 used, it's only used to use the right masking to map the models' outputs to the right dimensions.
                 At the moment only Imagenet-A masking is supported.
        MEMO: a boolean to use marginal entropy minimization with one test point
        TTA: a boolean to use test time augmentation
        top_augmentations: if MEMO or TTA are set to True, then values higher than zero select the top_augmentations
                           with the lowest entropy (highest confidence).
        prior_strength: defines the weight given to pre-trained statistics in BN adaptation. If negative, then no BN
                        adaptation is applied.
        verbose: use loading bar to visualize accuracy and number of batch during testing.
        log_interval: defines after how many batches a new accuracy should be displayed. Default is 1, thus
                      after each batch a new value is displayed.
        num_adaptation_steps: (TODO)
        MC: dictionary containing the number of evaluations using Monte Carlo Dropout and the dropout rate.
        """
        # check some basic conditions
        assert bool(num_adaptation_steps) == MEMO, "When using MEMO adaptation steps should be > 1, otherwise equal to 0."
        if not (MEMO or TTA):
            assert not (num_augmentations or top_augmentations), "If both MEMO and TTA are set to False, then top_augmentations and num_augmentations must be 0"
        assert not lr_setting if not MEMO else True, "If MEMO is false, then lr_setting must be None"
        assert isinstance(prior_strength, (float,int)) , "Prior adaptation must be either a float or an int"

        # get the name of the weigths used and define the name of the experiment
        weights_name = str(weights_imagenet).split(".")[-1] if weights_imagenet else "MEMO_repo"
        use_MC = True if MC else False
        name_result = f"MEMO_{MEMO}_AdaptSteps_{num_adaptation_steps}_adaptBN_{prior_strength}_TTA_{TTA}_aug_{num_augmentations}_topaug_{top_augmentations}_seed_aug_{seed_augmentations}_weights_{weights_name}_MC_{use_MC}"
        path_result = os.path.join(self.__exp_path,name_result)
        assert not os.path.exists(path_result),f"MEMO test already exists: {path_result}"

        # in case of using dropout, check if the model is a ResNet50Dropout and the parameters are correct
        if MC:
            assert isinstance(self.__model, ResNet50Dropout), f"To use dropout the model must be a ResNet50Dropout"
            assert MC['num_samples'] > 1, f"To use dropout the number of samples must be greater than 1"

        # transformation pipeline used in ResNet-50 original training
        transform_loader = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
            T.ToTensor()
        ])

        # to use after model's update
        normalize_input = T.Compose([
                        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                    ])

        test_loader = get_data(batch_size, img_root, transform = transform_loader, split_data=False)
        model = self.get_model(weights_imagenet, MC)

        # if MEMO is used, create a checkpoint to reload after each model and optimizer update
        if MEMO:
            optimizer = self.get_optimizer(model = model, lr_setting = lr_setting)
            MEMO_checkpoint_path = os.path.join(self.__exp_path,"checkpoint.pth")
            torch.save({
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, MEMO_checkpoint_path)
            MEMO_checkpoint = torch.load(MEMO_checkpoint_path)

        if dataset == "imagenetA":
            imagenetA_masking = self.get_imagenetA_masking()

        if prior_strength < 0:
            torch.nn.BatchNorm2d.prior_strength = 1
        else:
            torch.nn.BatchNorm2d.prior_strength = prior_strength / (prior_strength + 1)
            torch.nn.BatchNorm2d.forward = adaptive_bn_forward

        # Initialize a dictionary to store accumulated time for each step
        time_dict = {
            "MEMO_update": 0.0,
            "get_augmentations": 0.0,
            "confidence_selection": 0.0,
            "get_prediction": 0.0,
            "total_time": 0.0
        }

        samples = 0.0
        cumulative_accuracy = 0.0

        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(self.__device), targets.to(self.__device)
            if MEMO or TTA:
                for input, target in zip(inputs, targets):
                    if MEMO:
                        model.load_state_dict(MEMO_checkpoint['model'])
                        model.eval()
                        optimizer.load_state_dict(MEMO_checkpoint['optimizer'])

                    # get normalized augmentations
                    start_time_augmentations = time.time()
                    test_augmentations = self.get_test_augmentations(input, augmentations, num_augmentations, seed_augmentations)
                    end_time_augmentations = time.time()
                    time_dict["get_augmentations"] += (end_time_augmentations - start_time_augmentations)

                    test_augmentations = test_augmentations.to(self.__device)
                    for _ in range(num_adaptation_steps):
                        logits = model(test_augmentations)

                        # apply imagenetA masking
                        if dataset == "imagenetA":
                            logits = logits[:, imagenetA_masking]
                        # compute stable softmax
                        probab_augmentations = F.softmax(logits - logits.max(dim=1)[0][:, None], dim=1)

                        # confidence selection for augmentations
                        if top_augmentations:
                            start_time_confidence_selection = time.time()
                            probab_augmentations = self.get_best_augmentations(probab_augmentations, top_augmentations)
                            end_time_confidence_selection = time.time()
                            time_dict["confidence_selection"] += (end_time_confidence_selection - start_time_confidence_selection)

                        if MEMO:
                            start_time_memo_update = time.time()
                            marginal_output_distribution = torch.mean(probab_augmentations, dim=0)
                            marginal_loss = self.compute_entropy(marginal_output_distribution)
                            marginal_loss.backward()
                            optimizer.step()
                            optimizer.zero_grad()
                            end_time_memo_update = time.time()
                            time_dict["MEMO_update"] += (end_time_memo_update - start_time_memo_update)

                    start_time_prediction = time.time()
                    with torch.no_grad():
                        if TTA:
                            # statistics:
                            # dictionary containing statistics resulting from the application of monte carlo dropout
                            # look at get_monte_carlo_statistics() for more details
                            y_pred, statistics = self.get_prediction(test_augmentations, model, imagenetA_masking, TTA, top_augmentations, MC=MC)
                        else:
                            input = normalize_input(input)
                            y_pred, statistics = self.get_prediction(input, model, imagenetA_masking, MC=MC)
                        cumulative_accuracy += int(target == y_pred)
                    end_time_prediction = time.time()
                    time_dict["get_prediction"] += (end_time_prediction - start_time_prediction)
            else:
                start_time_prediction = time.time()
                with torch.no_grad():
                    inputs = normalize_input(inputs)
                    y_pred = self.get_prediction(inputs, model, imagenetA_masking, MC=MC)

                    # Handle cases where targets or y_pred might be tuples
                    if isinstance(targets, tuple):
                        targets = targets[0]  # Extract the relevant tensor
                    if isinstance(y_pred, tuple):
                        y_pred = y_pred[0]  # Extract the relevant tensor

                    if targets.dim() == 0 or y_pred.dim() == 0:
                        # If both targets and y_pred are scalars
                        correct_predictions = int(targets == y_pred)
                    else:
                        # If targets and y_pred are tensors
                        correct_predictions = (targets == y_pred).sum().item()

                cumulative_accuracy += correct_predictions

                end_time_prediction = time.time()
                time_dict["get_prediction"] += (end_time_prediction - start_time_prediction)

            samples += inputs.shape[0]

            if verbose and batch_idx % log_interval == 0:
                current_accuracy = cumulative_accuracy / samples * 100
                print(f'Batch {batch_idx}/{len(test_loader)}, Accuracy: {current_accuracy:.2f}%', end='\r')

        accuracy = cumulative_accuracy / samples * 100
        time_dict["total_time"] += sum(time_dict.values())

        self.save_result(accuracy = accuracy,
                         path_result = path_result,
                         seed_augmentations = seed_augmentations,
                         num_augmentations = num_augmentations,
                         augmentations = augmentations,
                         top_augmentations = top_augmentations,
                         MEMO = MEMO,
                         num_adaptation_steps = num_adaptation_steps,
                         lr_setting = lr_setting,
                         weights = weights_name,
                         prior_strength = prior_strength,
                         use_MC = use_MC,
                         time_test = time_dict)

        return accuracy

In [None]:
- Experiments (qui metti solo le celle di codice da runnare per riprodurre i risultati):
—- resnet50 + SGD (lr, momentum, weight_decay)
—- resnet50 + ADAM (lr, momentum, weight_decay)

## Results

<table>
  <thead>
    <tr>
      <th rowspan="1">Experiment</th>
      <th rowspan="1">Dataset</th>
      <th rowspan="1">Base Model</th>
      <th rowspan="1">Weights</th>
      <th rowspan="1" colspan="2">Optimizer</th>
      <th rowspan="1">Optimization Steps</th>
      <th rowspan="1" colspan="3">Augmentations</th>
      <th rowspan="1">Batch Size</th>
      <th rowspan="1">MEMO</th>
      <th rowspan="1">Confidence Selection</th>
      <th rowspan="1" colspan="2">BN</th>
      <th rowspan="1">TTA</th>
      <th rowspan="1" colspan="3">MC</th>
      <th rowspan="1">Accuracy</th>
      <th rowspan="1">Execution Time</th>
    </tr>
    <tr>
      <th>Nr.</th>
      <th></th>
      <th></th>
      <th></th>
      <th>Type</th>
      <th>LR</th>
      <th>Nr.</th>
      <th>Type</th>
      <th>Number</th>
      <th>Seed</th>
      <th></th>
      <th></th>
      <th></th>
      <th></th>
      <th>Prior Strength</th>
      <th></th>
      <th></th>
      <th>Dropout rate</th>
      <th>Nr. Samples</th>
      <th>%</th>
      <th></th>
    </tr>
  </thead>
  <tbody align="center">
    <tr>
      <th>1</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>-</td>
      <td>1</td>
      <td>-</td>
      <td>-</td>
      <td>-</td>
      <td>64</td>
      <td>False</td>
      <td>0</td>
      <td>False</td>
      <td>-</td>
      <td>False</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>0.026</td>
      <td>00:00:20</td>
    </tr>
    <tr>
      <th>2</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>-</td>
      <td>1</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>False</td>
      <td>8</td>
      <td>False</td>
      <td>-</td>
      <td>True</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>0.253</td>
      <td>00:26:20</td>
    </tr>
    <tr>
      <th>3</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>-</td>
      <td>1</td>
      <td>-</td>
      <td>-</td>
      <td>-</td>
      <td>64</td>
      <td>False</td>
      <td>0</td>
      <td>False</td>
      <td>-</td>
      <td>False</td>
      <td>True</td>
      <td>0.20</td>
      <td>10</td>
      <td></td>
      <td></td>
    </tr>
    <tr>
      <th>4</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>-</td>
      <td>1</td>
      <td>-</td>
      <td>-</td>
      <td>-</td>
      <td>64</td>
      <td>False</td>
      <td>0</td>
      <td>True</td>
      <td>16</td>
      <td>False</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>0.026</td>
      <td>00:00:23</td>
    </tr>
    <tr>
      <th>5</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>0.00025</td>
      <td>1</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>False</td>
      <td>-</td>
      <td>False</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>0.16</td>
      <td>00:33:20</td>
    </tr>
    <tr>
      <th>6</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>0.00025</td>
      <td>1</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>0.853</td>
      <td>00:38:44</td>
    </tr>
    <tr>
      <th>7</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>0.00025</td>
      <td>1</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>True</td>
      <td>0.20</td>
      <td>10</td>
      <td></td>
      <td></td>
    </tr>
    <tr>
      <th>8</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>0.00025</td>
      <td>4</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>False</td>
      <td>-</td>
      <td>False</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>0.213</td>
      <td>00:42:03</td>
    </tr>
    <tr>
      <th>9</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>0.00025</td>
      <td>4</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>0.853</td>
      <td>00:47:37</td>
    </tr>
    <tr>
      <th>10</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>SGD</td>
      <td>0.00025</td>
      <td>4</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>True</td>
      <td>0.20</td>
      <td>10</td>
      <td></td>
      <td></td>
    </tr>
    <tr>
      <th>11</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>ADAM</td>
      <td>0.0001</td>
      <td>1</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>False</td>
      <td>-</td>
      <td>False</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td></td>
      <td></td>
    </tr>
    <tr>
      <th>12</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>ADAM</td>
      <td>0.0001</td>
      <td>1</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td></td>
      <td></td>
    </tr>
    <tr>
      <th>13</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>ADAM</td>
      <td>0.0001</td>
      <td>1</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>True</td>
      <td>0.20</td>
      <td>10</td>
      <td></td>
      <td></td>
    </tr>
    <tr>
      <th>14</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>ADAM</td>
      <td>0.0001</td>
      <td>4</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>False</td>
      <td>-</td>
      <td>False</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>0.213</td>
      <td>00:41:19</td>
    </tr>
    <tr>
      <th>15</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>ADAM</td>
      <td>0.0001</td>
      <td>4</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>False</td>
      <td>-</td>
      <td>-</td>
      <td>0.826</td>
      <td>00:46:55</td>
    </tr>
    <tr>
      <th>16</th>
      <td>ImageNet-A</td>
      <td>ResNet50</td>
      <td>Imagenet_1K_V1</td>
      <td>ADAM</td>
      <td>0.0001</td>
      <td>4</td>
      <td>AugMix</td>
      <td>16</td>
      <td>42</td>
      <td>64</td>
      <td>True</td>
      <td>8</td>
      <td>True</td>
      <td>16</td>
      <td>True</td>
      <td>True</td>
      <td>0.20</td>
      <td>10</td>
      <td></td>
      <td></td>
    </tr>  
  </tbody>
</table>


## Discussion

Discussione dei risultati

## Conclusion

wrap up di quello fatto e possibili alternative per future sperimentazioni

## Bibliography

1. **Schneider, Steffen and Rusak, Evgenia and Eck, Luisa and Bringmann, Oliver and Brendel, Wieland and Bethge, Matthias.** "Improving robustness against common corruptions by covariate shift adaptation." Advances in Neural Information Processing Systems, Vol. 33, 2020, pp. 11539-11551. [https://proceedings.neurips.cc/paper_files/paper/2020/file/85690f81aadc1749175c187784afc9ee-Paper.pdf](https://proceedings.neurips.cc/paper_files/paper/2020/file/85690f81aadc1749175c187784afc9ee-Paper.pdf).

2. **Gal, Yarin and Ghahramani, Zoubin.** "Dropout as a Bayesian Approximation: Representing Model Uncertainty in Deep Learning." Proceedings of The 33rd International Conference on Machine Learning, Vol. 48, 2016, pp. 1050-1059. [https://proceedings.mlr.press/v48/gal16.html](https://proceedings.mlr.press/v48/gal16.html).

3. [Ollama library](https://ollama.com/)

4. **Marvin Zhang and Sergey Levine and Chelsea Finn.** "MEMO: Test Time Robustness via Adaptation and Augmentation." Advances in Neural Information Processing Systems, Vol. 35, 2022, pp. 38629-38642. [https://arxiv.org/abs/2110.09506](https://arxiv.org/abs/2110.09506).

5. **Lyzhov, Alexander and Molchanova, Yuliya and Ashukha, Arsenii and Molchanov, Dmitry and Vetrov, Dmitry.** "Greedy Policy Search: A Simple Baseline for Learnable Test-Time Augmentation." Proceedings of Machine Learning Research, Vol. 124, 2020, pp. 1308-1317. [https://proceedings.mlr.press/v124/lyzhov20a.html](https://proceedings.mlr.press/v124/lyzhov20a.html).

6. **Schneider, Steffen and Rusak, Evgenia and Eck, Luisa and Bringmann, Oliver and Brendel, Wieland and Bethge, Matthias.** "Improving robustness against common corruptions by covariate shift adaptation." Advances in Neural Information Processing Systems, Vol. 33, 2020, pp. 11539-11551. [https://proceedings.neurips.cc/paper_files/paper/2020/file/85690f81aadc1749175c187784afc9ee-Paper.pdf](https://proceedings.neurips.cc/paper_files/paper/2020/file/85690f81aadc1749175c187784afc9ee-Paper.pdf).

7. **Feng, Chun-Mei and Yu, Kai and Liu, Yong and Khan, Salman and Zuo, Wangmeng.** "Diverse Data Augmentation with Diffusions for Effective Test-time Prompt Tuning." 2023 IEEE/CVF International Conference on Computer Vision (ICCV), 2023, pp. 2704-2714. [https://arxiv.org/abs/2308.06038](https://arxiv.org/abs/2308.06038).


