<a href="https://colab.research.google.com/github/javiimo/ImageClassificationAssignment/blob/main/Memo_book.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Implementing MEMO

In this Colab we implement MEMO on top of CLIP and check whether it yields a significant accuracy gain at test time.
The dataset we use is CIFAR-100.

# 0. Import libraries, load CLIP and dataset

In [None]:
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git

import clip
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt # who knows maybe we'll need it later on
import torch.optim as optim
import copy
import torch.amp
from torch.cuda.amp import GradScaler
import torch.nn as nn


On a GPU, the CLIP model is loaded with most of its weights in half precision to reduce the memory footprint and computational cost, so it mixes `dtype = torch.float16` layers with `dtype = torch.float32` ones.

This makes the backward pass harder to run reliably (half-precision gradients can underflow or overflow), so to avoid issues we convert every parameter to `dtype = torch.float32` right after the model is initialized.
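
A quick way to see such a dtype mix, and to confirm that the conversion worked, is to count parameter tensors per dtype. The sketch below uses a small stand-in network rather than CLIP itself; the toy network and the `count_dtypes` helper are assumptions made for illustration, and one layer is cast to half precision by hand.

```python
from collections import Counter

import torch
import torch.nn as nn


def count_dtypes(model):
    """Count how many parameter tensors the model holds for each dtype."""
    return Counter(param.dtype for param in model.parameters())


# Toy stand-in for a mixed-precision model: the first layer is cast to float16 by hand.
toy = nn.Sequential(nn.Linear(4, 8), nn.LayerNorm(8), nn.Linear(8, 2))
toy[0].half()

before = count_dtypes(toy)

# Same idea as the conversion used in this notebook: cast every parameter to float32.
for param in toy.parameters():
    param.data = param.data.to(torch.float32)

after = count_dtypes(toy)
print(before)
print(after)
```

Running the same count on the real model before and after the conversion should show the `torch.float16` entries disappearing.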

In [None]:
def convert_model_parameters_to_float32(model):
    for param in model.parameters():
        param.data = param.data.to(torch.float32)
    return model

In [None]:
def load():

    # Load the model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load('ViT-B/32', device)

    model = convert_model_parameters_to_float32(model).to(device)

    # We only need the test split. CLIP's `preprocess` function turns an image into a tensor,
    # so no transform has to be attached to the dataset itself.

    cifar100 = torchvision.datasets.CIFAR100(root= './data', download = True, train = False)

    return cifar100, model, device, preprocess


In [None]:
cifar100, model, device, preprocess = load()

#Tokenize the text labels
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

#Encoded text
with torch.no_grad():
  text_features = model.encode_text(text_inputs)

text_features /= text_features.norm(dim=-1, keepdim=True)




# 1. Build MEMO

Details are in the paper: https://arxiv.org/pdf/2110.09506.pdf

A first implementation, following section 3 of the paper, is the following. Let $f_\theta$ be the model and $x_0$ the test input:

1. Let $x_1, \ldots, x_n$ be the $n$-tuple given by $n$ augmentations of the original sample $x_0$. We collect all of them in a single tensor, which acts as the batch for the gradient descent step.

2. Compute the probability distribution for every element of the batch, $p_\theta(\cdot | x_i)$. The model’s average, or *marginal*, output distribution with respect to the augmented points is $$ \bar{p}_\theta(y | x) = \frac{1}{n}\sum_{i=1}^n p_\theta(y | x_i) $$
and we compute its *marginal* entropy $H(\bar{p}_\theta(\cdot | x))$.

3. Using the entropy as the loss function, perform a single update step (one backpropagation pass) to obtain the parameters $\theta'$.

4. With the updated parameters we get $p_{\theta'}(\cdot | x_0)$, which we use to predict $y_0 = \arg\max_y p_{\theta'}(y | x_0)$.
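
To make step 2 concrete, here is a minimal pure-Python sketch with made-up numbers: three hypothetical augmented views and three classes. It averages the per-view distributions into the marginal and computes its entropy. The probabilities and the helper names `marginal` and `entropy` are illustrative assumptions, not part of the notebook.

```python
import math


def marginal(distributions):
    """Average a list of probability vectors class by class."""
    n = len(distributions)
    return [sum(column) / n for column in zip(*distributions)]


def entropy(p):
    """Shannon entropy in nats, with the convention 0 * log(0) = 0."""
    return sum(-pi * math.log(pi) for pi in p if pi > 0.0)


# Hypothetical softmax outputs for three augmented views of one image.
augmented = [
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.4, 0.4, 0.2],
]

p_bar = marginal(augmented)
mean_entropy = sum(entropy(p) for p in augmented) / len(augmented)

print(p_bar)           # marginal distribution over the three classes
print(entropy(p_bar))  # marginal entropy, the quantity used as the loss
print(mean_entropy)    # average per-view entropy, for comparison
```

Because entropy is concave, the marginal entropy is never lower than the average per-view entropy, so minimizing it rewards views that are both confident and in agreement with each other.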


With this in mind, let's start building what we need, starting with augmentations.

In [None]:
# Returns a tensor of shape (num_augmentations, 3, 224, 224): one preprocessed view per augmentation.
def augment_image(image, preprocess, num_augmentations = 100):

    # compose the random transformations applied to each view
    augmentations = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(degrees=30),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
    ])

    augmented_images = []

    #Add n augmentations to the augmented_images
    for _ in range(num_augmentations):
        augmented_images.append( preprocess(augmentations(image)).unsqueeze(0).to(device) )

    #Save it as a tensor
    batch = torch.vstack(augmented_images)

    return batch

The entropy of a random variable $X$ with discrete distribution $p$ is defined as:
$$H(X) = -\sum_{i=1}^n p_i \log (p_i) \quad \text{where } p_i = \mathbb{P}(X = x_i) \quad \text{and we set } 0 \log(0) := 0$$
When computing it, we need to be careful about underflows and overflows.

To this end, we exploit properties of the logarithm to make the computation more stable.
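
As a small sanity check of the definition, the sketch below evaluates the entropy at its two extremes and shows why the $0 \log(0) := 0$ convention is needed in practice. The `entropy` helper and the example distributions are assumptions made for illustration.

```python
import math


def entropy(p):
    """H(p) = -sum_i p_i * log(p_i) in nats, skipping zero entries (0 * log 0 := 0)."""
    return sum(-pi * math.log(pi) for pi in p if pi > 0.0)


k = 4
uniform = [1.0 / k] * k          # least confident prediction
one_hot = [1.0, 0.0, 0.0, 0.0]   # fully confident prediction

print(entropy(uniform), math.log(k))  # the uniform distribution attains the maximum, log(k)
print(entropy(one_hot))               # a one-hot distribution attains the minimum, zero

# Without the convention, the zero entries cannot be evaluated at all.
try:
    math.log(0.0)
except ValueError as err:
    print("math.log(0.0) raises:", err)
```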

Let $k$ be the number of classes. Let $n$ be the number of samples.
Consider $$X = (x_{ij})_{i = 1, \ldots, n;\; j = 1, \ldots, k}$$
the matrix of logits, where $x_{ij}$ is the $j$-th logit of the $i$-th sample, and let $p_{ij}$ be the probability output given by the softmax.

Then the log of the marginal probability of belonging to class $j$ is
$$
\begin{align*}
\log ( \bar{p}_j )
&= \log \biggl( \frac{1}{n}\sum_{i=1}^n p_{ij} \biggr) \\
&= \log \biggl( \frac{1}{n}\sum_{i=1}^n \frac{e^{x_{ij}}}{ \sum_{h=1}^k e^{x_{ih}}} \biggr) \\
&= \log \biggl( \frac{1}{n}\sum_{i=1}^n e^{z_{ij}} \biggr) \\
&= \log \biggl( \sum_{i=1}^n e^{z_{ij}} \biggr) - \log(n)
\end{align*}
$$
where $z_{ij} = x_{ij} - \log \sum_{h=1}^k e^{x_{ih}}$ is the log-softmax of the logits.

Computing $z_{ij}$ first and then using the formula above to get $\log{\bar{p}_j}$ is more stable: both steps are log-sum-exp operations, which can be evaluated without exponentiating large logits directly, instead of taking the logarithm of individual probabilities that may have underflowed to zero.
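
The stability argument can be checked on a toy case in pure Python. The sketch below uses deliberately large, made-up logits for $n = 2$ samples and $k = 3$ classes; the helpers `logsumexp` and `marginal_log_probs` are illustrative names, with the maximum-shift trick as the assumed way of evaluating log-sum-exp.

```python
import math


def logsumexp(values):
    """Stable log(sum(exp(v))): shift by the maximum before exponentiating."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def marginal_log_probs(logits):
    """log p_bar_j = logsumexp_i(z_ij) - log(n), with z_ij = x_ij - logsumexp_h(x_ih)."""
    n = len(logits)
    z = [[x - logsumexp(row) for x in row] for row in logits]
    return [logsumexp(column) - math.log(n) for column in zip(*z)]


# Hypothetical logits for n = 2 augmentations and k = 3 classes, large on purpose.
logits = [[1000.0, 999.0, 0.0],
          [1000.0, 1001.0, 0.0]]

try:
    math.exp(logits[0][0])  # what a naive softmax would have to compute
except OverflowError as err:
    print("naive softmax fails:", err)

log_p_bar = marginal_log_probs(logits)
print(log_p_bar)
print(sum(math.exp(lp) for lp in log_p_bar))  # the marginal still sums to 1 (up to rounding)
```

Note that the third class keeps a finite log-probability even though its probability underflows to zero, which is exactly the information a naive computation would lose.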

In [None]:
def marginal_entropy(logits):

    z = logits - logits.logsumexp(dim = -1, keepdim=True)     # z_ij: log-softmax of each augmented sample
    marginal_logp = z.logsumexp(dim=0) - np.log(z.shape[0])   # log of the marginal probabilities

    min_real = torch.finfo(marginal_logp.dtype).min          # most negative finite value of this dtype
    avg_logits = torch.clamp(marginal_logp, min = min_real)  # replace -inf so that log(p) * p cannot become nan

    return -(avg_logits * torch.exp(avg_logits)).sum(dim=-1)  # H = -sum_j p_j * log(p_j)

Given an image, CLIP outputs its embedding in a $512$-dimensional space.
Given a text, CLIP outputs its embedding in the same $512$-dimensional space.

From there we obtain the logits of the probability distribution by computing the cosine similarity between the image embedding and the embeddings of all the labels.
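
As a rough illustration of this step, the following pure-Python sketch turns cosine similarities into class probabilities. The 3-dimensional embeddings are made up (the real ones are 512-dimensional), the helper names are assumptions, and the factor `100.0` mirrors the scaling used in this notebook's code.

```python
import math


def normalize(v):
    """Scale a vector to unit Euclidean norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical embeddings: one image and three text labels.
image_embedding = normalize([0.9, 0.1, 0.3])
label_embeddings = [normalize(v) for v in ([1.0, 0.0, 0.2], [0.0, 1.0, 0.0], [0.3, 0.3, 0.9])]

scale = 100.0
# For unit vectors, the dot product is the cosine similarity.
cosines = [sum(a * b for a, b in zip(image_embedding, t)) for t in label_embeddings]
logits = [scale * c for c in cosines]
probs = softmax(logits)

print(cosines)
print(probs)
```

Cosine similarities lie in $[-1, 1]$, so without the scaling factor the softmax would stay close to uniform; multiplying by a large constant sharpens the resulting distribution.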

We have all the ingredients now to perform the gradient descent step:


In [None]:
def backprop_sweep(model, batch, text_features, optimizer):

    model.train(True)

    optimizer.zero_grad()

    # forward pass
    image_features = model.encode_image(batch) # embeddings of all the augmented views in the batch
    image_features = image_features / image_features.norm(dim = -1, keepdim = True)
    logits = 100.0 * image_features @ text_features.T # get logits

    # compute loss
    loss = marginal_entropy(logits)

    # backward pass
    loss.backward()
    optimizer.step()



Finally, let's build a function that now predicts the output of the original sample on the updated model:

In [None]:
def predict(model, image, text_features, label):

    model.eval()
    image_prep = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():

        image_features = model.encode_image(image_prep) # embedding of the original (non-augmented) image
        image_features = image_features / image_features.norm(dim = -1, keepdim = True)
        outputs = 100.0 * image_features @ text_features.T

        _, predicted = outputs.max(1)
        confidence = nn.functional.softmax(outputs,dim=1).squeeze()[predicted].item()

    correctness = 1 if predicted.item() == label else 0

    return correctness, confidence

We can now write the MEMO function:

In [None]:
def MEMO_test(model, preprocess, text_features, image, class_id, device, num_augmentations = 50, lr = 1e-5):

  batch = augment_image(image, preprocess, num_augmentations = num_augmentations)
  model_copy = copy.deepcopy(model).to(device)

  optimizer = optim.SGD(model_copy.parameters(), lr = lr)

  backprop_sweep(model_copy, batch, text_features, optimizer)

  return predict(model_copy, image, text_features, class_id)
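
The key design choice here is that each test image adapts its own copy of the model, so updates never accumulate across samples. The sketch below illustrates the pattern with a plain dictionary standing in for the model; the parameters, gradients, and the `adapt_on_copy` helper are all made up for illustration.

```python
import copy

# Hypothetical "model": just a dictionary of parameters.
base_params = {"w": [0.5, -0.2], "b": 0.1}


def adapt_on_copy(params, gradient, lr):
    """Apply one SGD step to a deep copy and leave the original untouched."""
    adapted = copy.deepcopy(params)
    adapted["w"] = [w - lr * g for w, g in zip(adapted["w"], gradient["w"])]
    adapted["b"] = adapted["b"] - lr * gradient["b"]
    return adapted


# Made-up gradients for two different test samples.
grad_sample_1 = {"w": [1.0, 1.0], "b": 1.0}
grad_sample_2 = {"w": [-2.0, 0.5], "b": 0.0}

adapted_1 = adapt_on_copy(base_params, grad_sample_1, lr=0.1)
adapted_2 = adapt_on_copy(base_params, grad_sample_2, lr=0.1)

print(base_params)  # unchanged: every sample starts from the same weights
print(adapted_1)
print(adapted_2)
```

Deep-copying a full model for every image costs time and memory; a possible alternative, not used in this notebook, is to save the original `state_dict` once and reload it after each sample.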

# 2. Sanity check

Let's check that everything works.

In [None]:
image, class_id = cifar100[0]

print('(correctness, confidence)')
print(f'MEMO: {MEMO_test(model, preprocess, text_features, image, class_id, device, lr = 1e-5)}')

(correctness, confidence)
MEMO: (1, 0.19025388360023499)


# 3. Full test

In [None]:
from tqdm import tqdm

print('Running...')

correct_memo = []
correct_clip = []

confidence_memo = []
confidence_clip = []

for i in tqdm(range(500)):  # first 500 test images, matching the logged run below


    image, class_id = cifar100[i]
    memo_eval, memo_belief = MEMO_test(model, preprocess, text_features, image, class_id, device)

    correct_memo.append(memo_eval)
    confidence_memo.append(memo_belief)

    clip_eval, clip_belief = predict(model, image, text_features, class_id)

    correct_clip.append(clip_eval)
    confidence_clip.append(clip_belief)

print('')
print(f'MEMO adapt test accuracy {(np.mean(correct_memo))*100:.2f}. Average confidence: {(np.mean(confidence_memo))*100:.2f}')
print(f'CLIP adapt test accuracy {(np.mean(correct_clip))*100:.2f}. Average confidence: {(np.mean(confidence_clip))*100:.2f}')

Running...


100%|██████████| 500/500 [04:46<00:00,  1.74it/s]


MEMO adapt test accuracy 64.60. Average confidence: 51.37
CLIP adapt test accuracy 64.40. Average confidence: 50.60
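
To judge whether this gap is significant, a rough estimate of the sampling noise helps. The sketch below takes the two accuracies and the 500 evaluated images from the log above and computes binomial standard errors. It treats the two runs as independent, which is an assumption: the comparison is actually paired, and a paired test would need the per-sample outcomes.

```python
import math

n = 500                             # number of evaluated test images in the logged run
acc_memo, acc_clip = 0.646, 0.644   # accuracies reported above


def binomial_standard_error(accuracy, n_samples):
    """Standard error of an accuracy estimated from n independent samples."""
    return math.sqrt(accuracy * (1.0 - accuracy) / n_samples)


se_memo = binomial_standard_error(acc_memo, n)
se_clip = binomial_standard_error(acc_clip, n)

difference = acc_memo - acc_clip
# Unpaired approximation of the standard error of the difference.
se_difference = math.sqrt(se_memo ** 2 + se_clip ** 2)

print(f"difference: {difference * 100:.2f} points")
print(f"standard error of the difference: {se_difference * 100:.2f} points")
```

On 500 images, 0.2 points corresponds to a single image (323 versus 322 correct), and under this approximation the standard error of the difference is around 3 points, so this run alone does not show a significant gain.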





# 4. A few remarks

*   Setting the optimizer's learning rate properly is a non-trivial task. The loss function, *the entropy*, is minimized by high-confidence predictions regardless of whether the prediction is correct, so a learning rate that is too high will destroy the model's parameters and produce many erroneous predictions. On the other hand, if the learning rate is too small, the backpropagation step will have almost no effect. Setting ```lr = 1e-5``` seems a reasonable compromise.


*   In the information theory literature, entropy is usually computed with the base-2 logarithm, $\log_2$. In this implementation we use the natural logarithm (base $e$). This is formally correct, and it only rescales the loss, and hence its gradient, by a constant factor, since the two logarithms differ by a multiplicative constant ($\log_2 x = \frac{1}{\log_e(2)} \log_e x$); such a factor can be absorbed into the learning rate.

* The number of augmentations does not have a huge impact on the outcome. However, more augmentations tend to give higher accuracy at the price of reduced confidence: averaging over a larger number of probability distributions yields a less peaked distribution, hence more entropy, but more robustness to domain shifts.
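
The first remark above can be illustrated on a toy problem where the logits themselves are the parameters. The sketch below takes one gradient step on the entropy, using the analytic gradient $\partial H / \partial x_j = -p_j (\log p_j + H)$; the logits, the step size, and the helper names are assumptions chosen for illustration and are unrelated to the values used with CLIP.

```python
import math


def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(p):
    """Shannon entropy in nats."""
    return sum(-pi * math.log(pi) for pi in p if pi > 0.0)


def entropy_gradient(logits):
    """Analytic gradient of H(softmax(logits)): dH/dx_j = -p_j * (log p_j + H)."""
    p = softmax(logits)
    h = entropy(p)
    return [-pj * (math.log(pj) + h) for pj in p]


# Hypothetical logits; suppose the true class is index 1, but the model prefers index 0.
logits = [2.0, 1.5, 0.0]
lr = 0.5  # illustrative step size for this toy problem

grad = entropy_gradient(logits)
updated = [x - lr * g for x, g in zip(logits, grad)]

p_before, p_after = softmax(logits), softmax(updated)
print(entropy(p_before), entropy(p_after))  # the entropy goes down
print(p_before[0], p_after[0])              # confidence in the (wrong) class 0 goes up
```

The step lowers the entropy by reinforcing whichever class was already preferred, which is why an overly aggressive update can lock in wrong predictions.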


