# DL Project Report
Authors: **Giovanni Valer, Emanuele Poiana**

## Table of Contents
- Introduction
- Utils
- Data
- Entropy Loss
    - Weighted Entropy Loss
    - Cut Entropy Loss
- MEMO
    - Standard MEMO
    - MEMO with final prediction on marginal distribution
    - Batch Normalization
- Test Time Adaptation
- Experiments
- Results
- Conclusion
    - Future Work

## Introduction
This project focuses on implementing a Test-Time Adaptation technique for image classification. The goal is to improve the performance of a pre-trained model on out-of-distribution data, without any knowledge of the test-time data distribution. The method we implement is based on **MEMO** (Marginal Entropy Minimization with One test point). The model we use is ResNet-50, pre-trained on **ImageNet**; the dataset we use for testing is **ImageNet-A** (consisting of images that are misclassified by ResNet).

The baseline (i.e., ResNet-50 without any adaptation) achieves an accuracy of 0.03% (v1 weights) and 15.3% (v2 weights) on ImageNet-A, so we use the v2 weights. We aim to improve this performance by applying MEMO and possibly other techniques.

### MEMO, [Zhang et al. (2021)](https://arxiv.org/abs/2110.09506)
MEMO consists of applying a set of augmentations to the test image, collecting the output probabilities of the pre-trained model for each augmented image, and then performing a gradient-based optimization step to minimize the entropy of the marginal output distribution (i.e., the average of the output distributions over the augmentations). The idea is to fine-tune all the model's parameters (on a single test image) to produce consistent and confident predictions across different augmentations.

### Our Contributions
We further explore other techniques to improve the performance, either modifying MEMO or undertaking different approaches:
- Entropy Loss variants (**Weighted Entropy** Loss, **Cut Entropy** Loss)
- Final **Prediction on Marginal Distribution** (either standard or weighted average)
- **Batch Normalization**
- Different **composition of augmentations**

## Utils

In [None]:
import os
from io import BytesIO
from pathlib import Path

import boto3
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.transforms import v2
from tqdm import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### ResNet Model

In [None]:
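def initialize_resnet(weights='v2'):
    """
    Returns a ResNet-50 pre-trained on ImageNet ('v1' or 'v2' weights), moved to the selected device.
    """
    if weights == 'v1':
        resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1).to(device)
    elif weights == 'v2':
        resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2).to(device)
    else:
        # Fail explicitly instead of hitting an undefined variable below
        raise ValueError("Invalid weights. Choose from 'v1', 'v2'")
    return resnet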

## Data
We need to remap ImageNet-A classes (based on WordNet IDs) to ImageNet classes, in order to use the pre-trained ResNet model.

**Note:** if wget fails, download the file from [here](https://github.com/jo-valer/tta-memo/blob/main/label_mappings.tsv).

In [None]:
# Comment out if you already downloaded the label mappings
!wget https://raw.githubusercontent.com/jo-valer/tta-memo/main/label_mappings.tsv #--no-check-certificate

In [None]:
label_mapping = pd.read_csv("label_mappings.tsv", sep="\t")
remap_dict = label_mapping.set_index("imagenet_a_label").to_dict()["imagenet_label"]

def remap(imagenet_a_label):
    """
    Maps ImageNet-A labels (derived from WordNet IDs) to ImageNet labels (used by ResNet).
    """
    # Map the label
    return remap_dict[imagenet_a_label]

# We will remove all other labels from the model output
labels_in_imagenet_a = [6, 11, 13, 15, 17, 22, 23, 27, 30, 37, 39, 42, 47, 50, 57, 70, 71, 76, 79, 89, 90, 94, 96, 97, 99, 105, 107, 108, 110, 113, 124, 125, 130, 132, 143, 144, 150, 151, 207, 234, 235, 254, 277, 283, 287, 291, 295, 298, 301, 306, 307, 308, 309, 310, 311, 313, 314, 315, 317, 319, 323, 324, 326, 327, 330, 334, 335, 336, 347, 361, 363, 372, 378, 386, 397, 400, 401, 402, 404, 407, 411, 416, 417, 420, 425, 428, 430, 437, 438, 445, 456, 457, 461, 462, 470, 472, 483, 486, 488, 492, 496, 514, 516, 528, 530, 539, 542, 543, 549, 552, 557, 561, 562, 569, 572, 573, 575, 579, 589, 606, 607, 609, 614, 626, 627, 640, 641, 642, 643, 658, 668, 677, 682, 684, 687, 701, 704, 719, 736, 746, 749, 752, 758, 763, 765, 768, 773, 774, 776, 779, 780, 786, 792, 797, 802, 803, 804, 813, 815, 820, 823, 831, 833, 835, 839, 845, 847, 850, 859, 862, 870, 879, 880, 888, 890, 897, 900, 907, 913, 924, 932, 933, 934, 937, 943, 945, 947, 951, 954, 956, 957, 959, 971, 972, 980, 981, 984, 986, 987, 988]
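The label handling can be illustrated on a toy case. The following is a minimal pure-Python sketch (no PyTorch) of the same idea used later by `filter_classes` and by `labels_in_imagenet_a.index(label)`: the model output is restricted to the kept classes, and the target is re-expressed as a position within that subset. The 8-class output, the subset `[1, 4, 5]` and the helper name `filter_and_remap` are hypothetical and chosen only for illustration.

```python
def filter_and_remap(logits, kept_classes, full_label):
    """Restrict the logits to kept_classes and express the label as an index into that subset."""
    filtered = [logits[i] for i in kept_classes]
    target = kept_classes.index(full_label)
    return filtered, target

# Toy 8-class output (hypothetical values) and a hypothetical subset of kept classes
logits = [0.1, 2.0, 0.3, 0.0, 1.2, 3.5, 0.4, 0.9]
kept_classes = [1, 4, 5]

filtered, target = filter_and_remap(logits, kept_classes, full_label=5)
predicted = max(range(len(filtered)), key=filtered.__getitem__)
print(filtered, target, predicted)
```

After filtering, both the prediction and the target live in the reduced index space, so they can be compared directly when computing accuracy.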

In [None]:
img_root="imagenet-a"

class ImageNetA(Dataset):
    def __init__(self, root, transform=None):
        self.s3_bucket = "deeplearning2024-datasets"
        self.s3_region = "eu-west-1"
        self.s3_client = boto3.client("s3", region_name=self.s3_region, verify=True)
        self.transform = transform

        # Get list of objects in the bucket
        response = self.s3_client.list_objects_v2(Bucket=self.s3_bucket, Prefix=root)
        objects = response.get("Contents", [])
        while response.get("NextContinuationToken"):
            response = self.s3_client.list_objects_v2(
                Bucket=self.s3_bucket,
                Prefix=root,
                ContinuationToken=response["NextContinuationToken"]
            )
            objects.extend(response.get("Contents", []))

        # Iterate and keep valid files only
        self.instances = []
        for ds_idx, item in enumerate(objects):
            key = item["Key"]
            path = Path(key)
            
            # Check if file is valid
            if path.suffix.lower() not in (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"):
                continue

            # Get label
            label = path.parent.name

            # Map the label
            label = remap(label)

            # Keep track of valid instances
            self.instances.append((label, key))

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx):
        try:
            label, key = self.instances[idx]

            label = labels_in_imagenet_a.index(label)

            img_bytes = BytesIO()
            response = self.s3_client.download_fileobj(Bucket=self.s3_bucket, Key=key, Fileobj=img_bytes)
            
            # Open image with PIL
            img = Image.open(img_bytes).convert("RGB")

            # Apply transformations if any
            if self.transform is not None:
                img = self.transform(img)
        except Exception as e:
            raise RuntimeError(f"Error loading image at index {idx}: {str(e)}")

        return img, label

In [None]:
def get_data(batch_size, img_root, preprocess = False):
    # Prepare data transformations for the test loader
    transform = None
    if preprocess:
        transform = T.Compose([
            T.Resize((256,256)),
            T.CenterCrop((224,224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])      # Normalize with ImageNet mean and std
        ])
    else:
        transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    # Load data
    imageneta_dataset = ImageNetA(root=img_root, transform=transform)

    # Initialize dataloader
    loader = torch.utils.data.DataLoader(imageneta_dataset, batch_size, shuffle=False, num_workers=4)

    return loader

## Entropy Loss
We first implement the Entropy Loss function proposed by [Zhang et al. (2021)](https://arxiv.org/abs/2110.09506). The loss is computed on the outputs of the pre-trained model for the augmented copies of the test image, and is defined as:
$$ L_{\text{E}} = - \sum_{i=1}^{N} p_i \log p_i $$
where $p_i$ is the probability of class $i$ (averaged over the augmentations) and $N$ is the number of classes.

In [None]:
class EntropyLoss(torch.nn.Module):
    def __init__(self):
        super(EntropyLoss, self).__init__()

    def forward(self, x):
        # Softmax to get probabilities
        x = F.softmax(x, dim=1)
        # Marginal distribution averaged over augmentations
        avg = torch.mean(x, dim=0)
        # Entropy loss
        return -torch.sum(avg * torch.log(avg))
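To see why minimizing this quantity rewards agreement across augmentations, consider the following pure-Python sketch (no PyTorch, toy two-class probabilities chosen for illustration; the helper name `marginal_entropy` is not part of the project code). In both cases each individual prediction is equally confident, but the marginal entropy is higher when the augmentations disagree.

```python
import math

def marginal_entropy(prob_rows):
    """Entropy of the class distribution averaged over augmentations (one row per augmentation)."""
    num_aug = len(prob_rows)
    num_classes = len(prob_rows[0])
    marginal = [sum(row[i] for row in prob_rows) / num_aug for i in range(num_classes)]
    # Zero entries are skipped, since p * log(p) tends to 0 as p -> 0
    return -sum(p * math.log(p) for p in marginal if p > 0)

consistent = marginal_entropy([[0.9, 0.1], [0.9, 0.1]])    # augmentations agree
inconsistent = marginal_entropy([[0.9, 0.1], [0.1, 0.9]])  # augmentations disagree
print(f"consistent: {consistent:.3f}, inconsistent: {inconsistent:.3f}")
```

In the disagreeing case the marginal is uniform, so its entropy equals $\log 2$, the maximum for two classes.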

### Weighted Entropy Loss
We also implement a variant: the Weighted Entropy Loss. It computes the marginal distribution as a weighted average of the output probabilities over the augmentations.
The idea is to give more importance to the augmentations that produce less uncertain predictions; thus the weight associated with each augmentation $a \in A$ is defined as the inverse of the entropy of its output distribution $p_a$:
$$ w'_a = \frac{1}{H(p_a)} $$
We further normalize the weights over the $A$ augmentations, so that they sum to 1:
$$ w_a = \frac{w'_a}{\sum_{j=1}^{A} w'_j} $$

The marginal distribution is computed as:
$$ p_{\text{m}i} = \sum_{j=1}^{A} w_j p_{ji} $$
where $p_{ji}$ is the probability of class $i$ for the $j$-th augmentation.

Finally, the Weighted Entropy Loss is defined as:
$$ L_{\text{WE}} = - \sum_{i=1}^{N} p_{\text{m}i} \log p_{\text{m}i} $$

In [None]:
class WeightedEntropyLoss(torch.nn.Module):
    def __init__(self):
        super(WeightedEntropyLoss, self).__init__()
        
    def forward(self, x): # A augmentations, N classes
        # Softmax to get probabilities
        probs = F.softmax(x, dim=1) # [A x N]
        with torch.no_grad():
            # Entropies of the output probabilities
            entropies = -torch.sum(probs * torch.log(probs), dim=1) # [A]
            # Weights are the inverse of the entropy, normalized
            weights = 1 / entropies # [A]
            weights /= torch.sum(weights) # [A]
        # Weighted probabilities
        weighted_probs = weights[:, None] * probs # [A] x [A x N] = [A x N]
        # Weighted marginal distribution: the weights already sum to 1, so we sum
        # (a mean would divide by A again and the result would no longer sum to 1)
        avg = torch.sum(weighted_probs, dim=0) # [N]
        # Entropy loss
        return -torch.sum(avg * torch.log(avg))
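The weighting scheme defined above can be checked on a toy case. This is a minimal pure-Python sketch (no PyTorch) with hypothetical three-class probabilities; the helper names `entropy` and `weighted_marginal` are illustrative only. It assumes no prediction is exactly one-hot, since a zero entropy would make the inverse undefined.

```python
import math

def entropy(probs):
    """Shannon entropy of a single probability vector."""
    return -sum(p * math.log(p) for p in probs if p > 0)

def weighted_marginal(prob_rows):
    """Inverse-entropy weights (normalized to sum to 1) and the resulting marginal."""
    raw = [1.0 / entropy(row) for row in prob_rows]  # w'_a = 1 / H(p_a)
    total = sum(raw)
    weights = [w / total for w in raw]               # w_a, summing to 1
    num_classes = len(prob_rows[0])
    marginal = [
        sum(w * row[i] for w, row in zip(weights, prob_rows))
        for i in range(num_classes)
    ]
    return weights, marginal

# One confident augmentation and one uncertain augmentation (toy values)
rows = [[0.98, 0.01, 0.01], [0.4, 0.3, 0.3]]
weights, marginal = weighted_marginal(rows)
print(weights, marginal)
```

The confident augmentation receives most of the weight, so the marginal is pulled towards its prediction while remaining a valid distribution.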

### Cut Entropy Loss
Another variant we implement is the Cut Entropy Loss. It consists in computing the marginal distribution as the average of the output probabilities over the augmentations, but only considering the top-$k$ augmentations with the lowest entropy. The idea is to avoid handcrafting the set of augmentations to consider, but rather automatically select the most informative ones. The Cut Entropy Loss is defined as:
$$ L_{\text{CE}} = - \sum_{i=1}^{N} p_{\text{m}i} \log p_{\text{m}i} $$
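where $p_{\text{m}i}$ is computed as for the standard Entropy Loss, but only over the $k$ augmentations with the lowest entropy. With the $A$ augmentations indexed by increasing entropy:
$$ p_{\text{m}i} = \sum_{j=1}^{A} w_j p_{ji} $$
where:
$$ w_j = \begin{cases} \frac{1}{k} & \text{if } j \leq k \\ 0 & \text{otherwise} \end{cases} $$
In the implementation, the `cut` parameter is the fraction of augmentations that are discarded, so that $k \approx (1 - \text{cut}) \cdot A$.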

In [None]:
class CutEntropyLoss(torch.nn.Module):
    def __init__(self, cut=0.9):
        super(CutEntropyLoss, self).__init__()
        self.cut = cut

    def forward(self, x):
        # Softmax to get probabilities
        probs = F.softmax(x, dim=1)
        with torch.no_grad():
            # Entropy of the output probabilities
            entropy = -torch.sum(probs * torch.log(probs), dim=1)
            # Sort the entropies
            sorted_entropy, _ = torch.sort(entropy, descending=True)
            # Assign the weights: 1 for the top cut, 0 for the rest
            weights = torch.zeros_like(entropy)
            weights[entropy <= sorted_entropy[int(self.cut * len(entropy))]] = 1 # Note: "top" here means "low entropy"
            # Normalize the weights
            weights /= torch.sum(weights)
        # Weighted marginal distribution
        avg = torch.sum(weights[:, None] * probs, dim=0) # [A] x [A x N] = [N]
        # Entropy loss
        return -torch.sum(avg * torch.log(avg))
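How `cut` translates into the number of retained augmentations can be traced with a small pure-Python sketch that mirrors the thresholding logic above (sort by decreasing entropy, take the value at position `int(cut * A)` as threshold, keep everything at or below it). The entropy values and the helper name `cut_weights` are hypothetical, and the counts below assume there are no ties at the threshold.

```python
def cut_weights(entropies, cut=0.9):
    """Uniform weights over the lowest-entropy entries, zero for the rest."""
    ranked = sorted(entropies, reverse=True)        # highest entropy first
    threshold = ranked[int(cut * len(entropies))]   # requires cut < 1
    mask = [1.0 if h <= threshold else 0.0 for h in entropies]
    kept = sum(mask)
    return [m / kept for m in mask]

# Ten toy entropies with cut=0.8: only the two lowest (0.1 and 0.3) are kept
entropies = [0.5, 2.0, 0.1, 1.5, 0.9, 1.1, 0.3, 1.9, 0.7, 1.3]
weights = cut_weights(entropies, cut=0.8)
print(weights)

# 65 inputs (64 augmentations plus the original image) with cut=0.9
kept_of_65 = sum(1 for w in cut_weights([float(i) for i in range(65)], cut=0.9) if w > 0)
print(kept_of_65)
```

With 65 inputs and `cut=0.9`, the threshold index is `int(58.5) = 58`, so 7 inputs are retained, roughly the top 10%.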
        

## MEMO
We implement MEMO as a class wrapping the pre-trained model and the optimizer, and defining the set of augmentations to apply to the test image. We use PyTorch's SGD and perform a single optimization step on the test image, since in preliminary experiments additional steps did not seem to improve the performance.

In the same class we also implement some variants of (and alternatives to) MEMO, controlled by three parameters:
- `memo_train`: whether to train the model to minimize the Entropy Loss on the test image across the augmentations (default: **True**);
- `use_augmentations`: instead of making the final prediction on the original data point, make it on the average of the predictions over the augmentations (default: **False**);
- `use_weighted_aug`: whether to use the weighted average of the predictions over the augmentations, same as for the Weighted Entropy Loss (default: **False**).

The default behavior corresponds to the standard MEMO, i.e., training the model to minimize the Entropy Loss across the augmentations and making the final prediction on the original data point.
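The interplay of the three flags can be summarized with the following pure-Python sketch. It is only a description of the control flow as we read it from the parameter list above (the function name `prediction_mode` and the returned strings are illustrative, not part of the project code); note that `use_weighted_aug` has an effect only when `use_augmentations` is also set.

```python
def prediction_mode(memo_train=True, use_augmentations=False, use_weighted_aug=False):
    """Describe which adaptation and which final prediction a flag combination selects."""
    # The weighted average only makes sense when predicting on the augmentations
    use_weighted_aug = use_weighted_aug and use_augmentations

    adaptation = "one entropy-minimization step" if memo_train else "no adaptation"
    if not use_augmentations:
        final = "prediction on the original image"
    elif use_weighted_aug:
        final = "entropy-weighted average over augmentations"
    else:
        final = "plain average over augmentations"
    return adaptation, final

print(prediction_mode())                                            # standard MEMO
print(prediction_mode(memo_train=False, use_augmentations=True))    # averaging only, no training
print(prediction_mode(use_augmentations=True, use_weighted_aug=True))
```

With `memo_train=False` and `use_augmentations=False`, the wrapper reduces to the unadapted baseline model.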

### Augmentations
We experiment with different sets of augmentations, including:
- **RandomHorizontalFlip**
- **RandomResizedCrop**
- **AugMix**

### Input Preprocessing
The `preprocess` flag specifies the transformation applied to the dataset in the dataloader:
- if **True**: Resize and CenterCrop are applied to every image in the set;
- if **False**: no transformations are applied (except for the normalization). This requires the batch size to be 1, since the original images have different sizes.

This flag is useful to test the model on the original images of the dataset, instead of the resized and cropped ones, so that the augmentations are applied to the original images.

### Batch Normalization
[Zhang et al. (2021)](https://arxiv.org/abs/2110.09506) suggest using Batch Normalization on the test image's augmentations and, to prevent overfitting, only slightly updating the running statistics of the BatchNorm layers. We implement this technique by setting the `use_adaptive_bn` flag to **True**. This also requires the batch size to be 1, as the running statistics are updated for each image separately. The running statistics are updated as follows:
$$ \mu = \alpha \mu_{\text{train}} + (1 - \alpha) \mu_{\text{test}} $$
$$ \sigma^2 = \alpha \sigma^2_{\text{train}} + (1 - \alpha) \sigma^2_{\text{test}} $$
where $\mu_{\text{train}}$ and $\sigma^2_{\text{train}}$ are the running statistics of the BatchNorm layer on the training set, and $\mu_{\text{test}}$ and $\sigma^2_{\text{test}}$ are the statistics computed on the test image's augmentations. The parameter $\alpha$ is set to the best value found in their experiments, $\alpha = 16/17 \approx 0.9411$.

In [None]:
def adaptive_bn(self, input):
    if self.adapt:
        input_mean = torch.zeros(self.running_mean.shape, device=self.running_mean.device)
        input_var = torch.ones(self.running_var.shape, device=self.running_var.device)
        
        # Compute the mean and variance of the current input: with training=True and
        # momentum=1.0 they are written in place into input_mean and input_var
        torch.nn.functional.batch_norm(input, input_mean, input_var, None, None, True, 1.0, 0.0)
        
        # The value 0.9411 (16/17) is taken directly from the MEMO paper, where it is reported as optimal
        adapted_mean = (1 - 0.9411) * input_mean + 0.9411 * self.running_mean
        adapted_var = (1 - 0.9411) * input_var + 0.9411 * self.running_var
        
        # Compute the batch_norm with the adapted mean and variance
        return torch.nn.functional.batch_norm(input, adapted_mean, adapted_var, self.weight, self.bias, False, 0.0, self.eps)

    else:
        return torch.nn.functional.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.1, self.eps)
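The interpolation of the statistics can be checked by hand on scalars. The following pure-Python sketch uses hypothetical values for a single channel (the helper name `blend_statistic` is illustrative) and the exact fraction $16/17$ rather than the rounded constant used in the cell above.

```python
def blend_statistic(train_value, test_value, alpha=16 / 17):
    """Convex combination of a training statistic and a test-time statistic."""
    return alpha * train_value + (1 - alpha) * test_value

# Hypothetical single-channel statistics
adapted_mean = blend_statistic(train_value=0.0, test_value=1.7)  # 1.7 / 17
adapted_var = blend_statistic(train_value=1.0, test_value=2.7)   # (16 + 2.7) / 17
print(f"mean: {adapted_mean:.4f}, var: {adapted_var:.4f}")
```

Even when the test statistics differ substantially from the training ones, the adapted values move only about 6% of the way towards them, which is what keeps the update from overfitting to a single image.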

### Implementation
All the aforementioned alternatives can be combined with the previously defined Entropy Loss variants, by simply passing the chosen loss as the `criterion` argument of the MEMO class.

In [None]:
class Memo(torch.nn.Module):
    """
    Implements the Test-Time Adaptation method of MEMO (Marginal Entropy Minimization with One test point) for image classification.
    """
    def __init__(self, net=resnet50(weights=ResNet50_Weights.IMAGENET1K_V2), optimizer=torch.optim.SGD, lr=0.005, n=64, augmentations='augmix_flip_crop', preprocess=False, device='cuda:0', memo_train:bool=True, use_augmentations:bool=False, use_weighted_aug:bool=False, use_adaptive_bn = False, criterion=EntropyLoss()):
        super(Memo, self).__init__()
        self.device = device
        self.augmentations_name = augmentations # name of augmentations to apply
        self.n = n # number of augmentations
        self.memo_train = memo_train # whether to train the model at test time or not
        self.use_augmentations = use_augmentations # if True, the output is the marginal distribution of the augmented images
        self.use_weighted_aug = use_weighted_aug if use_augmentations else False # if True, the output is the weighted marginal distribution of the augmented images (weights for each augmentation are 1/entropy)
        self.adapt_bn = use_adaptive_bn
        torch.nn.BatchNorm2d.adapt = self.adapt_bn # if True, the batch normalization adapt over input test samples for every BN layer
        torch.nn.BatchNorm2d.forward = adaptive_bn 
        self.net = net.to(device)

        self.preprocess = preprocess # if True, the dataloader return the images already preprocessed (cropped, resized, etc.)
        self.criterion = criterion

        self.optimizer = optimizer(self.net.parameters(), lr=lr)
        self.net.eval()

        # Save a detached copy of the default parameters of the model,
        # so that they can be restored after adapting to each test image
        self.default_params = {}
        for param in self.net.parameters():
            self.default_params[param] = param.detach().clone()

        self.cropping = T.Compose([
            T.Resize((256,256)),
            T.CenterCrop((224,224))
        ])

        if augmentations == 'augmix':
            self.augmentations = v2.AugMix()
        elif augmentations == 'augmix_flip':
            self.augmentations = T.Compose([v2.AugMix(),
                                            T.RandomHorizontalFlip()])
        elif augmentations == 'augmix_flip_crop':
            self.augmentations = T.Compose([v2.AugMix(),
                                            T.RandomHorizontalFlip(),
                                            T.RandomResizedCrop(224)])
        elif augmentations == 'flip_crop_augmix':
            self.augmentations = T.Compose([T.RandomHorizontalFlip(),
                                            T.RandomResizedCrop(224),
                                            v2.AugMix()])
        elif augmentations == 'crop':
            self.augmentations = T.Compose([T.RandomResizedCrop(224)])
        elif augmentations == 'flip_crop':
            self.augmentations = T.Compose([T.RandomHorizontalFlip(),
                                            T.RandomResizedCrop(224)])
        else: # Throw error
            raise ValueError("Invalid augmentations. Choose from 'augmix', 'augmix_flip', 'augmix_flip_crop', 'flip_crop_augmix', 'crop', 'flip_crop'")
    
    def filter_classes(self, outputs):
        return outputs[:, labels_in_imagenet_a]

    def apply_augmentations(self, img_original, augmentations):
        """
        Applies the augmentations to the input image.
        """
        with torch.no_grad():
            return augmentations(img_original)
    
    def get_augment_loader(self, x):
        """
        Returns the dataloader with the input image x and its augmentations.
        If the input image is already cropped, or the augmentations do not apply cropping, the image does not require cropping.
        """
        with torch.no_grad():
            if self.preprocess or (self.augmentations_name == 'augmix_flip' or self.augmentations_name == 'augmix'):
                aug_images = [x]
            else:
                aug_images = [self.cropping(x)] # This ensures that the original image and the augmented images are of the same size
        
            for _ in range(self.n):
                aug_img = self.apply_augmentations(x, augmentations=self.augmentations)#.to(device)
                aug_images.append(aug_img)
        
            return torch.utils.data.DataLoader(aug_images, batch_size = self.n + 1, shuffle=False, num_workers=0)

    def augment(self, x):
        """
        Returns the image x and its augmentations.
        """
        aug_loader = self.get_augment_loader(x)
        
        # Get a single batch of augmented images
        aug_inputs = next(iter(aug_loader)).to(self.device)

        return aug_inputs
    
    def get_averaged_distribution(self, x):
        """
        Returns the averaged marginal distribution of the network over the input image and its augmentations.
        If use_weighted_aug is True, the output is the weighted marginal distribution of the augmented images (weights for each augmentation are 1/entropy).
        """
        predictions = None
        for image in x:
            inputs = self.augment(image)
            outputs = self.filter_classes(self.net(inputs))

            if self.use_weighted_aug:
                # Entropy of the output probabilities (softmax of the network output)
                probs = F.softmax(outputs, dim=1)
                entropies = -torch.sum(probs * torch.log(probs), dim=1)
                weights = 1 / entropies
                weights /= torch.sum(weights)
                # Note: the weights are applied to the network outputs (logits), not to the probabilities
                weighted_outputs = weights[:, None] * outputs
                mean_distribution = torch.mean(weighted_outputs, dim=0).unsqueeze(dim=0)
            else:
                mean_distribution = torch.mean(outputs, dim=0).unsqueeze(dim=0)
            if predictions is None:
                predictions = mean_distribution
            else:
                predictions = torch.cat((predictions, mean_distribution), 0)

        return predictions

    def test_time_training(self, x):
        """
        Trains the model at test time.
        """
        predictions = None
        for image in x:
            torch.cuda.empty_cache()
            self.optimizer.zero_grad()

            inputs = self.augment(image)
            outputs = self.filter_classes(self.net(inputs))

            del inputs
            
            # Compute the entropy loss
            entropy_loss = self.criterion(outputs)
            
            # Backpropagate the loss
            entropy_loss.backward()
            
            # Update the weights
            self.optimizer.step()
            
            if not self.preprocess and not (self.augmentations_name == 'augmix_flip' or self.augmentations_name == 'augmix'):
                image = self.cropping(image)
            
            if not self.use_augmentations:
                prediction = self.filter_classes(self.net(image.unsqueeze(dim=0)))
            else:
                prediction = self.get_averaged_distribution(image.unsqueeze(dim=0))

            if predictions is None:
                predictions = prediction
            else:
                predictions = torch.cat((predictions, prediction), 0)

            # Restore the default parameters
            for param in self.net.parameters():
                param.data = self.default_params[param].clone()
        
        return predictions

    def forward(self, x):
        """
        Forward pass of the model, optionally training the model at test time.
        """
        if self.memo_train:
            output = self.test_time_training(x)
        else:
            if self.use_augmentations:
                output = self.get_averaged_distribution(x)
            else:
                output = self.filter_classes(self.net(x))

        return output
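The adaptation in `test_time_training` is episodic: the parameters are updated for one test image, used for its prediction, and then restored before the next image. The following pure-Python sketch reproduces that pattern on a toy one-parameter model; the function `adapt_predict_reset`, the quadratic toy loss and the dictionary of parameters are all hypothetical stand-ins for the network, the entropy loss and the SGD step.

```python
import copy

def adapt_predict_reset(params, grad_fn, predict_fn, x, lr=0.005):
    """Take one gradient step on x, predict with the adapted parameters, then restore them."""
    snapshot = copy.deepcopy(params)       # default parameters
    grads = grad_fn(params, x)
    for name in params:
        params[name] -= lr * grads[name]   # single SGD step
    prediction = predict_fn(params, x)
    params.update(snapshot)                # reset, so the next input starts from the defaults
    return prediction

# Toy model y = w * x with toy loss (w * x)^2, whose gradient is 2 * w * x^2
params = {"w": 1.0}
grad_fn = lambda p, x: {"w": 2 * p["w"] * x ** 2}
predict_fn = lambda p, x: p["w"] * x

prediction = adapt_predict_reset(params, grad_fn, predict_fn, x=2.0)
print(prediction, params)
```

Because the parameters are reset after every input, the order in which test images arrive does not influence the predictions.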


## Test Time Adaptation

In [None]:
def test(net, data_loader, device="cuda"):
    samples = 0.0
    cumulative_accuracy = 0.0

    # Set the network to evaluation mode
    net.eval()

    # Iterate over the test set
    for _, (inputs, targets) in enumerate(tqdm(data_loader)):

        # Load data into GPU
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Forward pass
        outputs = net(inputs)

        # Disable gradient computation (we are only testing, we do not want our model to be modified in this step!)
        with torch.no_grad():

            # Fetch prediction
            samples += inputs.shape[0]
            _, predicted = outputs.max(1)

            # Compute accuracy
            cumulative_accuracy += predicted.eq(targets).sum().item()
            del outputs, inputs, targets, predicted

    return cumulative_accuracy / samples * 100

## Experiments
We hereby present a selection of significant experiments.

### MEMO

In [None]:
net = Memo(
    net=initialize_resnet(),
    device=device
    )

loader = get_data(1, img_root)
test_accuracy = test(net, loader, device=device)
print(f"\tTop-1 accuracy {test_accuracy:.2f}")

### MEMO + Cut Entropy Loss + Batch Normalization + Final Prediction on Marginal Distribution

In [None]:
net = Memo(
    net=initialize_resnet(),
    augmentations="crop",
    use_augmentations=True,
    criterion=CutEntropyLoss(cut=0.9),
    use_adaptive_bn=True,
    device=device
    )

loader = get_data(1, img_root)
test_accuracy = test(net, loader, device=device)
print(f"\tTop-1 accuracy {test_accuracy:.2f}")

### MEMO + CEL + BN + Final Prediction on Weighted Marginal Distribution

In [None]:
net = Memo(
    net=initialize_resnet(),
    augmentations="crop",
    use_augmentations=True,
    use_weighted_aug=True,
    criterion=CutEntropyLoss(cut=0.9),
    use_adaptive_bn=True,
    device=device
    )

loader = get_data(1, img_root)
test_accuracy = test(net, loader, device=device)
print(f"\tTop-1 accuracy {test_accuracy:.2f}")

### Average Output Probabilities of Augmentations
We also experiment by avoiding any training steps: we use the averaged output probabilities (over the augmentations) as the final prediction.

In [None]:
net = Memo(
    net=initialize_resnet(),
    augmentations="flip_crop",
    memo_train=False,
    use_augmentations=True,
    use_adaptive_bn=True,
    device=device
    )

loader = get_data(1, img_root)
test_accuracy = test(net, loader, device=device)
print(f"\tTop-1 accuracy {test_accuracy:.2f}")

### Weighted Average Output Probabilities of Augmentations
Similarly, we experiment with the weighted average of the output probabilities, using the weights computed based on the entropy of the output probabilities.

In [None]:
net = Memo(
    net=initialize_resnet(),
    augmentations="crop",
    memo_train=False,
    use_augmentations=True,
    use_weighted_aug=True,
    use_adaptive_bn=True,
    device=device
    )

loader = get_data(1, img_root)
test_accuracy = test(net, loader, device=device)
print(f"\tTop-1 accuracy {test_accuracy:.2f}")

## Results

The baseline model achieves a top-1 accuracy of **15.07%** when restricting the output labels to those in ImageNet-A.
We find that MEMO improves the performance of the pre-trained model on ImageNet-A, and that combining it with other techniques can further improve the performance.

The best results of MEMO are obtained with 64 augmentations and a learning rate of 0.005. When using the augmentations proposed by [Zhang et al. (2021)](https://arxiv.org/abs/2110.09506) we obtain a top-1 accuracy of **16.17%**, while with RandomResizedCrop we achieve **19.87%**.
In the following, we present the results in an incremental fashion, showing the impact of each technique on the performance.

### Entropy Loss
From preliminary experiments we notice that the Weighted Entropy Loss does not improve the performance, compared to the standard Entropy Loss. We therefore focus on the **Cut Entropy Loss**, which instead has a **beneficial impact on the accuracy** of the model: the top-$k$ augmentations (those with lower entropy) are much more informative than the handcrafted set of augmentations. We find the best results with the top-10% and top-20% (i.e. cut=0.9 and cut=0.8, respectively), out of 64 augmentations. This further improves the accuracy to **23.20%**.

### Augmentations
We tried different sets of augmentations, and we observed that randomly cropping part of the original (full-size) image produces the best results, together with random horizontal flipping.
This is probably due to the characteristics of the ImageNet-A dataset, where the object corresponding to the true label is usually located in a small, sparse area of the test sample.
Random cropping prompts the model to avoid the "bigger" objects in the image and improves the probability of detecting smaller but more relevant features.
In terms of top-1 accuracy, the best results to this point are indeed obtained with RandomResizedCrop+RandomHorizontalFlip: **23.20%**, while adding AugMix leads to **19.86%**.

### Batch Normalization Adaptation
ResNet-50 already contains Batch Normalization layers; as suggested by [Schneider et al. (2020)](https://arxiv.org/abs/2006.16971), adapting their statistics on even a single test sample is generally enough to improve the model performance.
We indeed achieve better results with BN adaptation: it further improves the accuracy to **27.68%**, using a learning rate of 0.01.

### Final prediction on marginal distribution
This approach can be used either alone or in combination with MEMO. Both when added on top of the previous techniques and when used alone, it achieves similar results: **27.68%**. It is worth noting, however, that when used alone it requires fewer computational resources, and thus proves to be the most promising technique. We also expected to achieve even better results, but due to time and computational constraints we have not been able to properly search for the best hyperparameter values (as there is a considerable number of possible combinations).\
The alternative of using a weighted average, instead of the standard average of the output probabilities over the augmentations, does not seem to be particularly effective, as it does not improve the accuracy.


## Conclusion
We showed that MEMO can be a viable TTA technique, but that other approaches are more effective and also faster, at least in our experimental setup. In particular, the best results are obtained by using MEMO with RandomResizedCrop, Cut Entropy Loss, and Batch Normalization. The final prediction on the marginal distribution also seems to be a promising technique, reaching the same best results in our experiments.

### Future Work
Due to time and computational constraints, we could not undertake more experiments; however, we foresee another method for test-time adaptation, which can be either combined with MEMO or used as an alternative: **APoZ-based Adaptation**. The idea is to prune the model based on the Average Percentage of Zeros (APoZ) of the activations (before the final fully connected layer) over the augmentations of the test image. The intuition is that the nodes that are more active across the augmentations are more informative for the target class. By pruning the model, we are retaining only the most informative features, and therefore adapting the model to the test-time data point.
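As a rough illustration of the quantity involved, the following pure-Python sketch computes APoZ per unit over a set of augmentations and derives a keep/prune mask. It is only a sketch under stated assumptions: the activations are toy post-ReLU values, and the threshold `max_apoz=0.5` and the helper names `apoz` and `keep_mask` are hypothetical, since this method has not been implemented or evaluated in this project.

```python
def apoz(activations):
    """Fraction of augmentations for which each unit's (post-ReLU) activation is zero."""
    num_aug = len(activations)
    num_units = len(activations[0])
    return [
        sum(1 for row in activations if row[u] == 0.0) / num_aug
        for u in range(num_units)
    ]

def keep_mask(apoz_scores, max_apoz=0.5):
    """Keep the units that are zero on at most max_apoz of the augmentations (hypothetical rule)."""
    return [score <= max_apoz for score in apoz_scores]

# Toy activations: 4 augmentations (rows) x 3 units (columns)
acts = [
    [0.0, 1.2, 0.0],
    [0.0, 0.7, 0.3],
    [0.0, 0.0, 0.5],
    [0.0, 2.1, 0.0],
]
scores = apoz(acts)
mask = keep_mask(scores)
print(scores, mask)
```

Here the first unit is never active across the augmentations and would be pruned, while the other two would be retained.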
