# Refactoring

# 0. Imports

In [1]:
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git
! pip install datasets transformers
! pip install seaborn
! pip install nltk

# Used for CLIP:
import clip
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import copy
import numpy as np


#Used for testing:
#from pathlib import Path
# from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm
from PIL import Image
# from dataclasses import dataclass, asdict
# from typing import List, Dict
# import os
# import json
# from torch.utils.data import Subset
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, to_pil_image
import torch.nn as nn

# for Dictionary
import nltk
nltk.download('wordnet')
from nltk.corpus import wordnet

#Used for visualizing results
import matplotlib.pyplot as plt
# import pandas as pd
# import seaborn as sns
# from ipywidgets import widgets, interactive_output, Dropdown, Output, VBox, Button
# from IPython.display import display


# for adaptive temperature
from scipy.optimize import fsolve
from scipy.special import softmax

Collecting ftfy
  Downloading ftfy-6.2.0-py3-none-any.whl (54 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.4/54.4 kB[0m [31m1.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: ftfy
Successfully installed ftfy-6.2.0
Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-kvlk6aom
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-kvlk6aom
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->clip==1.0)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->clip==1.0)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==

[nltk_data] Downloading package wordnet to /root/nltk_data...


# 0.5. Initialize dataset and CLIP

In [2]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# CIFAR-100 with train = False, downloaded into ./data when missing.
cifar100 = torchvision.datasets.CIFAR100(root = './data', train = False, download = True)

# CLIP ViT-B/32 together with the image preprocessing function returned by clip.load.
model, preprocess = clip.load('ViT-B/32', device)

# in order to avoid problems with MEMO backpropagation step, all parameters of the model are set to torch.float32
for param in model.parameters():
  param.data = param.data.to(torch.float32)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:04<00:00, 34285558.83it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data


100%|███████████████████████████████████████| 338M/338M [00:04<00:00, 79.4MiB/s]


# 1. Useful functions

Functions that are not specific to a single class, but are useful in every environment.

In [3]:
class model_methods:
  """Helper methods shared by the CLIP-based models of this notebook.

  The class stores no state itself. `get_augmentations` reads `self.preprocess`
  and `self.device`, so it can only be used on objects that define them.
  """

  def get_entropy(self, logits):
    """Entropy (in nats) of the marginal distribution of a batch of logits.

    Each row of `logits` is converted to log-probabilities (log-softmax over
    the last dimension), the rows are averaged in probability space along
    dimension 0, and the entropy of that average is returned.
    """
    z = logits - logits.logsumexp(dim = -1, keepdim=True)     # compute z_ij
    marginal_logp = z.logsumexp(dim=0) - np.log(z.shape[0])   # compute marginal log probabilities

    min_real = torch.finfo(marginal_logp.dtype).min           # most negative finite value of the dtype
    avg_logits = torch.clamp(marginal_logp, min = min_real)   # replaces -inf so that the product below is not nan

    return -(avg_logits * torch.exp(avg_logits)).sum(dim=-1)


  def get_augmentations(self, image, num_augmentations = 50, transformations = None, manual_seed = None):
    """Build a batch holding the preprocessed image followed by random augmentations of it.

    Args:
      image: input accepted by both `transformations` and `self.preprocess`.
      num_augmentations: number of augmented views appended after the original.
      transformations: callable applied to `image` before preprocessing. When
        None, a default flip / rotation / colour-jitter / random-resized-crop
        pipeline is used.
      manual_seed: when given (and `transformations` is None) the torch RNG is
        seeded with it so that the default augmentations are reproducible.

    Returns:
      Tensor with `num_augmentations + 1` stacked views; index 0 is the original image.
    """
    if transformations is None:
      #Set a seed for reproducibility of the random augmentations
      if manual_seed is not None: torch.manual_seed(manual_seed)

      transformations = transforms.Compose([
          transforms.RandomHorizontalFlip(p=0.5),
          transforms.RandomVerticalFlip(p=0.5),
          transforms.RandomRotation(degrees=30),
          transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
          transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
      ])

    augmented_images = [self.preprocess(image).unsqueeze(0).to(self.device)] #Add the original image to the batch of augmentations
    for _ in range(num_augmentations):
      augmented_images.append(self.preprocess(transformations(image)).unsqueeze(0).to(self.device))

    batch = torch.vstack(augmented_images)
    return batch #(num_augumentations + 1, 3, 224, 224)


  def get_synonyms(self, word):
    """Return the set of lower-cased WordNet lemma names found in every synset of `word`."""
    return {lemma.name().lower() for syn in wordnet.synsets(word) for lemma in syn.lemmas()}


  def get_definitions(self, word):
    """Return the list of WordNet definitions, one per synset of `word`."""
    return [syn.definition() for syn in wordnet.synsets(word)]


  @staticmethod
  def get_temperature(x, beta = 0.5 + np.log(2 * np.pi * np.exp(1))):
    """Solve for the temperature `tau` at which the approximate entropy of `tau * x` equals `beta`.

    Declared as a static method so that it can be called both as
    `self.get_temperature(logits)` and as `model_methods.get_temperature(logits)`;
    without the decorator the instance call would bind the instance to `x`.

    Args:
      x: tensor of similarities / logits.
      beta: target value for the entropy approximation.

    Returns:
      Tensor with the root found by `scipy.optimize.fsolve` (started from 100.0),
      placed on the same device as `x`.
    """
    n = x.shape[0]
    x_np = x.detach().cpu().numpy()   # detach so that tensors attached to a graph can be converted too

    x_max = np.max(x_np)

    def entropy(tau):
      y = x_np * tau

      # degree-5 Taylor polynomial of exp(y)
      den = np.ones(n) + y + y**2 / 2 + y**3 / 6 + y**4 / 24 + y**5/120
      num = y * den

      H = - np.sum(num) / np.sum(den) + tau * x_max + np.log(n) / 2
      return H - beta

    initial_guess = 100.0  # Initial guess for tau
    root = fsolve(entropy, initial_guess)

    return torch.tensor(root).to(x.device)


# 2. Clip class

In [4]:
class clip_model(model_methods):
  """Zero-shot CLIP classifier built around an already loaded CLIP model.

  Holds the model, its preprocessing function, the device and, once a dataset
  is supplied, one L2-normalised text embedding per class obtained from the
  prompt "a photo of a <class>".
  """

  def __init__(self, model, preprocess, device = "cuda" if torch.cuda.is_available() else "cpu", dataset = None):
    """Store the model handles; when `dataset` is given, embed its class names right away."""
    self.device = device
    self.model = model
    self.preprocess = preprocess
    self.text_features = None

    if dataset is not None: self.get_text_features(dataset)

  def get_text_features(self, dataset):
    """Encode "a photo of a <class>" for every entry of `dataset.classes` and L2-normalise each row."""
    text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in dataset.classes]).to(self.device)
    with torch.no_grad(): self.text_features = self.model.encode_text(text_inputs)

    self.text_features /= self.text_features.norm(dim=-1, p = 2, keepdim=True)


  def get_image_features(self, image):
    """Return the L2-normalised CLIP embedding of a single image (batch dimension of 1), computed without gradients."""
    image_prep = self.preprocess(image).unsqueeze(0).to(self.device)

    with torch.no_grad():
      image_features = self.model.encode_image(image_prep)
      image_features = image_features / image_features.norm(dim = -1, keepdim = True)

    return image_features

  def get_candidates(self, reference, candidates, num_candidates):
    """Pick the `num_candidates` rows of `candidates` with the largest dot product with `reference`.

    Returns:
      (selected rows, their indices), both with singleton dimensions squeezed out.
    """
    logits = reference @ candidates.T
    _, label_indeces = torch.topk(logits, num_candidates)
    candidates = candidates[label_indeces].squeeze()

    return candidates, label_indeces.squeeze()


  def get_closest_features(self, target, candidates, num_candidates):
    """Embed the strings in `candidates` and return the `num_candidates` embeddings closest to `target`.

    Uses the model and device stored on the instance, so the result does not
    depend on notebook-level globals.
    """
    with torch.no_grad():

      tokens = torch.cat([clip.tokenize(c) for c in candidates]).to(self.device)

      features = self.model.encode_text(tokens)
      features /= features.norm(dim = -1, p = 2, keepdim=True)

    picked_features, _ = self.get_candidates(target, features, num_candidates = num_candidates)
    return picked_features


  def get_prob(self, tensor1, tensor2, custom_temp = None, adaptive_temp = False):
    """Softmax over the temperature-scaled similarities `tensor1 @ tensor2.T`.

    Args:
      tensor1, tensor2: feature matrices with a shared last dimension.
      custom_temp: multiplier for the similarities; when None the model's
        learned `logit_scale.exp()` is used.
      adaptive_temp: when True the multiplier comes from `get_temperature`
        and `custom_temp` is ignored.

    Returns:
      Probabilities with singleton dimensions squeezed out.
    """
    logits = tensor1 @ tensor2.T

    if adaptive_temp:
      # get_temperature takes no `self`; looking it up on the class avoids binding the instance as its first argument
      temp = type(self).get_temperature(logits)
    elif custom_temp is None:
      temp = self.model.logit_scale.exp()
    else:
      temp = custom_temp

    logits = temp * logits

    return logits.softmax(dim = -1).squeeze()



  def __call__(self, image, custom_temp = None):
    """Classify `image` against the stored text features.

    Args:
      image: image accepted by `self.preprocess`.
      custom_temp: optional temperature forwarded to `get_prob`.

    Returns:
      (predicted class index, class probabilities, entropy of the probabilities in nats).
    """
    self.model.eval()

    image_features = self.get_image_features(image)
    prob = self.get_prob(image_features, self.text_features, custom_temp = custom_temp)

    predicted = torch.argmax(prob)

    # get_entropy expects logits shaped (views, classes). Passing the log-probabilities of the
    # single view yields the entropy of `prob`; passing `prob` itself would give a constant.
    tiny = torch.finfo(prob.dtype).tiny
    log_prob = torch.log(prob.clamp_min(tiny)).reshape(1, -1)
    entropy = self.get_entropy(log_prob).item()

    return predicted.item(), prob, entropy

## CLIP - sanity check

In [None]:
# Baseline zero-shot classifier: class prompts are embedded from the CIFAR-100 class names.
base_clip = clip_model(
    model,
    preprocess,
    device,
    dataset = cifar100,
)

In [None]:
# Count how many of the first n CIFAR-100 samples the baseline classifies correctly.
n = 1
correct_clip = 0

for i in tqdm(range(n)):
  image, label = cifar100[i]
  predicted_label, _, entropy = base_clip(image)
  correct_clip += int(predicted_label == label)

print(f'Correct: {correct_clip}')


100%|██████████| 1/1 [00:00<00:00,  2.51it/s]

Correct: 0





# 3. MEMO class

In [None]:
class MEMO_model(clip_model):
  """CLIP classifier with one MEMO-style test-time adaptation step per image.

  For every image the model parameters are updated once by minimising the
  marginal entropy over a batch of augmentations, the image is classified with
  the updated parameters, and the original parameters are then restored.
  """

  def __init__(self, model, preprocess, device = "cuda" if torch.cuda.is_available() else "cpu", dataset = None,
               optimizer = 'SGD', num_augmentations = 50, lr = 1e-5, momentum = 0.9):
    """
    Args:
      model, preprocess, device, dataset: forwarded to `clip_model.__init__`.
      optimizer: 'SGD' or 'ADAM'; the optimizer is re-created at every sweep.
      num_augmentations: augmented views used for the adaptation step.
      lr: learning rate of the optimizer.
      momentum: momentum, used by SGD only.
    """
    super().__init__(model, preprocess, device, dataset)

    self.optimizer = None
    self.optimizer_type = optimizer

    self.num_augmentations = num_augmentations

    self.lr = lr
    self.momentum = momentum


  def use_ADAM(self):
    """Create an Adam optimizer over all model parameters."""
    self.optimizer = optim.Adam(self.model.parameters(), lr = self.lr)

  def use_SGD(self):
    """Create an SGD optimizer (with momentum) over all model parameters."""
    self.optimizer = optim.SGD(self.model.parameters(), lr = self.lr, momentum = self.momentum)


  def set_optimizer(self, s):
    """Instantiate the optimizer named by `s` ('ADAM' or 'SGD'); any other value only prints a message."""
    if   s == 'ADAM': self.use_ADAM()
    elif s == 'SGD' : self.use_SGD()
    else: print('wrong input')


  def require_model_gradients(self, state = True):
    """Set `requires_grad` to `state` on every model parameter."""
    for param in self.model.parameters(): param.requires_grad = state


  def entropy_loss_MEMO(self, logits):
    """Marginal entropy of `logits` over the augmentation batch (same computation as `get_entropy`)."""
    return self.get_entropy(logits)


  def backprop_sweep(self, image):
    """Run a single optimisation step on the marginal entropy of the augmentations of `image`.

    A fresh optimizer is created on every call, so no optimizer state is
    carried from one image to the next. Exceptions raised during the step are
    reported with a print and not propagated.
    """
    self.set_optimizer(self.optimizer_type)

    try:

      self.model.train(True)
      self.optimizer.zero_grad()

      batch = self.get_augmentations(image, num_augmentations = self.num_augmentations, manual_seed = 33)

      batch_features = self.model.encode_image(batch)
      batch_features = batch_features / batch_features.norm(dim = -1, p = 2, keepdim = True)
      logits = 100.0 * batch_features @ self.text_features.T

      loss = self.entropy_loss_MEMO(logits)

      loss.backward()
      self.optimizer.step()


    except Exception as e:

      print(f"Exception{type(e).__name__} occurred. Details: {e.args}")



  def __call__(self, image):
    """Adapt on `image`, classify it, then restore the original parameters.

    Returns:
      (predicted class index, class probabilities, entropy), as returned by `clip_model.__call__`.
    """
    # Snapshot of the parameters, detached so it is not part of any graph.
    original_params = {name: param.clone().detach() for name, param in self.model.named_parameters()}

    try:
      self.require_model_gradients(state = True) # Require gradients to update the CLIP parameters
      self.backprop_sweep(image)

      self.require_model_gradients(state = False)

      predicted, distribution, entropy = super().__call__(image, custom_temp = 100.0)

    finally:
      # Restore original parameters, even if adaptation or prediction raised:
      # the CLIP model object is shared with the other classifiers of the notebook.
      with torch.no_grad():

        for name, param in self.model.named_parameters():
          param.copy_(original_params[name])

    return predicted, distribution, entropy

## MEMO - sanity check

In [None]:
# MEMO classifier sharing the same CLIP model, with 100 augmentations per adaptation step.
memo = MEMO_model(
    model,
    preprocess,
    device,
    num_augmentations = 100,
    dataset = cifar100,
)

In [None]:
# Count correct MEMO predictions on the first n samples (n = 0 skips the loop entirely).
n = 0
correct = 0

for i in tqdm(range(n)):
  image, label = cifar100[i]
  predicted_label, _, entropy = memo(image)
  correct += int(predicted_label == label)

print(f'Correct: {correct}')

0it [00:00, ?it/s]

Correct: 0





# 4. EB class

In [None]:
class eb_model(clip_model):
  """Entropy-boosting classifier: re-scores the top CLIP candidates with random crops of the image.

  The probabilities of the original image over the `num_candidates` best
  classes are blended with the averaged probabilities of the crops that are
  most similar to the original image.
  """

  def __init__(self, model, preprocess, device = "cuda" if torch.cuda.is_available() else "cpu", dataset = None,
               num_augmentations = 50, num_pick_augmenations = 15, num_candidates = 5, max_iter = 15):
    """
    Args:
      model, preprocess, device, dataset: forwarded to `clip_model.__init__`.
      num_augmentations: random crops generated per iteration.
      num_pick_augmenations: number of views (closest to the original image) kept per iteration.
      num_candidates: number of top classes that are re-scored.
      max_iter: maximum number of refinement iterations.
    """
    super().__init__(model, preprocess, device, dataset)

    self.num_augmentations = num_augmentations
    self.num_pick_augmenations = num_pick_augmenations
    self.num_candidates = num_candidates
    self.max_iter = max_iter


  def get_and_pick_augmentations(self, image, image_features, text_candidates):
    """Average the candidate probabilities over the augmented views closest to the original image.

    Args:
      image: the raw image.
      image_features: normalised embedding of the original image.
      text_candidates: text embeddings of the candidate classes.

    Returns:
      Tensor with one averaged probability per candidate.
    """
    transformations = transforms.Compose([transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0), ratio=(0.75, 1.333))])
    batch = self.get_augmentations(image, num_augmentations = self.num_augmentations, transformations = transformations)

    # Inference only: no graph is needed, and the instance's own model is used.
    with torch.no_grad():
      aug_img_features = self.model.encode_image(batch)
      aug_img_features = aug_img_features / aug_img_features.norm(dim = -1, p = 2, keepdim = True)

    # pick the ones closest to the original image (never more views than the batch holds)
    similarities_aug = image_features @ aug_img_features.T
    num_pick = min(self.num_pick_augmenations, aug_img_features.shape[0])
    _, pick_indeces = torch.topk(similarities_aug, num_pick)

    # flatten the indices so the selection is always (num_pick, dim), also when num_pick == 1
    picked_aug_features = aug_img_features[pick_indeces.reshape(-1)]
    aug_probs = self.get_prob(picked_aug_features, text_candidates)

    aug_prob = aug_probs.reshape(num_pick, -1).mean(dim = 0).squeeze()

    return aug_prob


  def __call__(self, image):
    """Classify `image`.

    Returns:
      (predicted class index, blended probabilities over the candidates, entropy of those probabilities in nats).
    """
    image_features = self.get_image_features(image)
    text_candidates, label_indeces = self.get_candidates(image_features, self.text_features, self.num_candidates)

    img_prob = self.get_prob(image_features, text_candidates)
    aug_prob = None
    output_prob = img_prob.clone()

    running_predict = torch.argmax(output_prob)
    predict_step = running_predict   # defined up front so that max_iter == 0 still yields a prediction
    step = 0
    while step < self.max_iter:

      new_aug_prob = self.get_and_pick_augmentations(image, image_features, text_candidates)

      if aug_prob is None: aug_prob = new_aug_prob
      else:                aug_prob = 0.5 * aug_prob + 0.5 * new_aug_prob

      output_prob = 0.6 * img_prob + 0.4 * aug_prob # give a little more weight to the probability of the image.

      predict_step = torch.argmax(output_prob)

      # if running_predict != predict_step then it changed its mind: the class with highest probability has changed - better do further checks.
      if running_predict == predict_step: break
      running_predict = predict_step

      step += 1


    predicted = label_indeces[predict_step].item()

    # get_entropy expects logits shaped (views, classes); the log of the blended
    # distribution, as a single view, gives the entropy of `output_prob`.
    tiny = torch.finfo(output_prob.dtype).tiny
    log_prob = torch.log(output_prob.clamp_min(tiny)).reshape(1, -1)
    entropy = self.get_entropy(log_prob).item()

    return predicted, output_prob, entropy

## EB - sanity check

In [None]:
# Entropy-boosting classifier with its default hyper-parameters.
eb = eb_model(
    model,
    preprocess,
    device,
    dataset = cifar100,
)

In [None]:
# Count correct entropy-boosting predictions on the first n samples.
n = 100
correct_eb = 0

for i in tqdm(range(n)):
  image, label = cifar100[i]
  predicted_eb, _, entropy_eb = eb(image)
  correct_eb += int(predicted_eb == label)


print(f'Correct: {correct_eb}')

100%|██████████| 100/100 [00:30<00:00,  3.27it/s]

Correct: 71





# tmp.

In [8]:
class features_tmp(clip_model):
  """Computes, for every class, text embeddings of WordNet synonyms and of one WordNet definition.

  After construction `descript_features[i]` holds `num_synonyms` synonym-prompt
  embeddings followed by one definition embedding for class `i`.
  """

  def __init__(self, model, preprocess, device = "cuda" if torch.cuda.is_available() else "cpu", dataset = None,
              num_candidates = 10, num_synonyms = 5):
    """
    Args:
      model, preprocess, device, dataset: forwarded to `clip_model.__init__`.
        `dataset` must expose `.classes`; despite the default, None is not usable here.
      num_candidates: stored on the instance; not used by this class itself.
      num_synonyms: synonym prompts kept per class, capped at 5.
    """
    super().__init__(model, preprocess, device, dataset)
    self.classes = dataset.classes
    self.num_candidates = num_candidates

    # Warn exactly when the requested value is actually reduced by the cap.
    if num_synonyms > 5: print("num_synonims at most 5. Set to 5.")
    self.num_synonyms = int(min(num_synonyms, 5))

    # Embedding width is taken from the text features instead of being fixed to 512.
    embed_dim = self.text_features.shape[-1]
    self.descript_features = torch.empty(len(self.classes), self.num_synonyms + 1, embed_dim).to(self.device)
    self.get_descript_features()


  def get_descript_features(self):
    """Fill `self.descript_features` class by class.

    For each class name (split on '_') the WordNet synonyms and definitions of
    its words are collected. Every synonym is combined with every phrase
    pattern, and the `num_synonyms` prompts closest to the class text feature
    are kept, followed by the single closest definition.
    """
    # hard coded data augmentation over the synonyms to have them above a minimum threshold, as some classes don't have synonyms.
    phrase_patterns = [ 'a photo of: ',
      'a picture of a ', 'a picture of: ',
      'an image of a ', 'an image of: '
    ]

    for i in tqdm(range(len(self.classes))):

      c = self.classes[i]

      words = c.split('_')

      synonyms = []
      definitions = []

      for word in words:
        definitions.extend(self.get_definitions(word))
        synonyms.extend(self.get_synonyms(word))

      strings = [f'{pattern}{s}' for s in synonyms for pattern in phrase_patterns]

      synonyms_features = self.get_closest_features(target = self.text_features[i], candidates = strings, num_candidates = self.num_synonyms)
      context_feature   = self.get_closest_features(target = self.text_features[i], candidates = definitions, num_candidates = 1)

      self.descript_features[i, :, :] = torch.cat([synonyms_features, context_feature.unsqueeze(0)])

# 5. Dictionary, Adaptive temp

In [5]:
class DAT_model(clip_model):
  """CLIP classifier that also scores each candidate class by projecting the image onto its descriptor subspace.

  For every class, a projection matrix onto the span of its descriptor
  embeddings (synonym prompts plus one definition) is precomputed. At
  prediction time the norm of the projected image embedding is used as a
  second score and blended with the regular CLIP probabilities.
  """

  def __init__(self, features, model, preprocess, device = "cuda" if torch.cuda.is_available() else "cpu", dataset = None,
               num_candidates = 10, num_synonyms = 5):
    """
    Args:
      features: optional precomputed descriptor tensor. It is reused only when
        its shape is (num_classes, num_synonyms + 1, embedding_dim); otherwise
        the descriptors are recomputed from WordNet.
      model, preprocess, device, dataset: forwarded to `clip_model.__init__`.
        `dataset` must expose `.classes`; despite the default, None is not usable here.
      num_candidates: number of top CLIP classes that are re-scored.
      num_synonyms: synonym prompts kept per class, capped at 5.
    """
    super().__init__(model, preprocess, device, dataset)
    self.classes = dataset.classes
    self.num_candidates = num_candidates

    if num_synonyms > 5: print("num_synonims at most 5. Set to 5.")
    self.num_synonyms = int(min(num_synonyms, 5))

    # Embedding width is taken from the text features instead of being fixed to 512.
    embed_dim = self.text_features.shape[-1]
    expected_shape = (len(self.classes), self.num_synonyms + 1, embed_dim)

    self.proj_matrices = torch.empty(len(self.classes), embed_dim, embed_dim).to(self.device)

    if torch.is_tensor(features) and tuple(features.shape) == expected_shape:
      # Matching precomputed descriptors: skip the expensive WordNet / text-encoder pass.
      self.descript_features = features.detach().to(self.device)
    else:
      self.descript_features = torch.empty(*expected_shape).to(self.device)
      self.get_descript_features()

    self.get_proj_matrix()


  def get_descript_features(self):
    """Fill `self.descript_features` class by class.

    For each class name (split on '_') the WordNet synonyms and definitions of
    its words are collected. Every synonym is combined with every phrase
    pattern, and the `num_synonyms` prompts closest to the class text feature
    are kept, followed by the single closest definition.
    """
    # hard coded data augmentation over the synonyms to have them above a minimum threshold, as some classes don't have synonyms.
    phrase_patterns = [ 'a photo of: ',
      'a picture of a ', 'a picture of: ',
      'an image of a ', 'an image of: '
    ]

    for i in tqdm(range(len(self.classes))):

      c = self.classes[i]

      words = c.split('_')

      synonyms = []
      definitions = []

      for word in words:
        definitions.extend(self.get_definitions(word))
        synonyms.extend(self.get_synonyms(word))

      strings = [f'{pattern}{s}' for s in synonyms for pattern in phrase_patterns]

      synonyms_features = self.get_closest_features(target = self.text_features[i], candidates = strings, num_candidates = self.num_synonyms)
      context_feature   = self.get_closest_features(target = self.text_features[i], candidates = definitions, num_candidates = 1)

      self.descript_features[i, :, :] = torch.cat([synonyms_features, context_feature.unsqueeze(0)])


  def get_proj_matrix(self):
    """Compute, for every class, the orthogonal projector onto the span of its descriptor embeddings.

    With X holding the descriptors as columns, the projector is X X^+. For
    linearly independent descriptors this equals X (X^T X)^-1 X^T; the
    pseudo-inverse keeps it well defined when descriptors are linearly dependent.
    """
    for i in range(len(self.classes)):

      X = self.descript_features[i, :, :].to(torch.float32).T   # (embedding_dim, num_descriptors)

      self.proj_matrices[i, :, :] = torch.matmul(X, torch.linalg.pinv(X))


  def __call__(self, image):
    """Classify `image` and return the predicted class index.

    The top `num_candidates` CLIP classes are re-scored with
    0.6 * CLIP probability + 0.4 * softmax(100 * norm of the projection).
    """
    image_features = self.get_image_features(image)
    _, label_indeces = self.get_candidates(image_features, self.text_features, self.num_candidates)
    label_indeces = label_indeces.reshape(-1)   # keep a 1-D index also when there is a single candidate

    # Project the image embedding with the projectors of all candidates at once.
    P = self.proj_matrices[label_indeces]                                      # (k, d, d)
    projected = torch.matmul(P, image_features.T.to(torch.float32))            # (k, d, 1)
    projection_norms = projected.norm(dim = 1, p = 2).reshape(-1)              # (k,)

    img_prob = self.get_prob(image_features, self.text_features[label_indeces])

    descript_logits = 100.0 * projection_norms.to(self.device)
    descript_prob = torch.nn.functional.softmax(descript_logits, dim = 0)

    output_prob = img_prob * 0.6 + descript_prob * 0.4

    predicted_candidate = torch.argmax(output_prob)
    predicted = label_indeces[predicted_candidate]

    return predicted.item()

In [9]:
# Precompute the per-class descriptor embeddings (5 synonym prompts + 1 definition each).
features = features_tmp(
    model,
    preprocess,
    num_synonyms = 5,
    dataset = cifar100,
)

100%|██████████| 100/100 [11:33<00:00,  6.94s/it]


In [6]:
# Dictionary-based classifier; `features` must already exist, so the cell that builds it has to run first.
dat = DAT_model(
    features.descript_features,
    model,
    preprocess,
    num_synonyms = 3,
    dataset = cifar100,
)

NameError: name 'features' is not defined

In [None]:
# Count correct DAT predictions on the first n samples.
n = 100
correct_dat = 0

for i in tqdm(range(n)):
  image, label = cifar100[i]
  correct_dat += int(label == dat(image))


print(f'Correct: {correct_dat}')

100%|██████████| 100/100 [00:01<00:00, 79.91it/s]

Correct: 65





In [None]:
# Ask PyTorch to release the unused GPU memory held by its caching allocator.
torch.cuda.empty_cache()

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
# Reset the peak-memory counters tracked by the CUDA allocator.
torch.cuda.reset_peak_memory_stats()

# Original Clip class

In [None]:
class CLIPModel:
    """Self-contained CLIP wrapper bundling zero-shot prediction, MEMO, TPT and entropy boosting.

    The constructor loads the CLIP weights itself, converts the parameters to
    float32 and prepares an SGD optimizer. The class text embeddings must be
    set with `tokenize_labels` before any prediction method is used.
    """

    def __init__(self, model_name='ViT-B/32', device=None):
        """Load CLIP `model_name` on `device` (CUDA when available if `device` is None)."""
        self.device = device if device else "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = clip.load(model_name, self.device)
        self.model = self.convert_model_parameters_to_float32(self.model)
        self.optimizer = optim.SGD(self.model.parameters(), lr=1e-5, momentum=0.9)
        self.text_features = None
        self.requiring_grads = None
        self.logit_scale = self.model.logit_scale #temperature parameter learned by CLIP
        self.changeseed = 0 #This is to be able to do diverse random transforms but replicable

    def use_ADAM(self):
        """Replace the optimizer with Adam (lr=1e-5) over all model parameters."""
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-5)

    def use_SGD(self):
        """Replace the optimizer with SGD (lr=1e-5, momentum=0.9) over all model parameters."""
        self.optimizer = optim.SGD(self.model.parameters(), lr=1e-5, momentum=0.9)

    def require_CLIP_gradients(self, state = True):
        """Set `requires_grad` on every parameter, skipping the loop when the cached state already matches."""
        if self.requiring_grads is None or state != self.requiring_grads: #don't change if the state is already OK
            for param in self.model.parameters():
                param.requires_grad = state
            self.requiring_grads = state

    def convert_model_parameters_to_float32(self, model):
        """Cast the data of every parameter of `model` to float32 and return the model."""
        for param in model.parameters():
            param.data = param.data.to(torch.float32)
        return model

    def load_data(self):
        """Return CIFAR-100 with train=False, downloading it into ./data when missing."""
        cifar100 = torchvision.datasets.CIFAR100(root='./data', download=True, train=False)
        return cifar100

    #These are heuristic labels
    def tokenize_labels(self, classes):
        """Embed "a photo of a <class>" for every class name and store the L2-normalised result in `self.text_features`."""
        text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in classes]).to(self.device)
        with torch.no_grad():
            self.text_features = self.model.encode_text(text_inputs)
            self.text_features /= self.text_features.norm(dim=-1, p=2, keepdim=True)

    def augment_image(self, image, num_augmentations=100, transformations=None):
        """Build a batch with the preprocessed image followed by `num_augmentations` augmented views.

        Args:
            image: input accepted by both the transformations and `self.preprocess`.
            num_augmentations: number of augmented views appended after the original.
            transformations: callable applied to `image` before preprocessing. When
                None, the torch RNG is seeded with 33 and a default flip /
                rotation / colour-jitter / random-resized-crop pipeline is used.

        Returns:
            Tensor with `num_augmentations + 1` stacked views; index 0 is the original image.
        """
        if transformations is None:
            torch.manual_seed(33)#Set a seed for reproducibility of the random augmentations
            transformations = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
                transforms.RandomRotation(degrees=30),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
                transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            ])
        augmented_images = [self.preprocess(image).unsqueeze(0).to(self.device)] #Add the original image to the batch of augmentations
        for _ in range(num_augmentations):
            # A caller-supplied pipeline is honoured here; previously only the default one was ever defined.
            augmented_images.append(self.preprocess(transformations(image)).unsqueeze(0).to(self.device))
        batch = torch.vstack(augmented_images)
        return batch #(num_augumentations + 1, 3, 224, 224)

    def cos_sim(self, image_features, text_features):
        """Dot products between the rows of the two matrices (cosine similarity for normalised rows)."""
        return  image_features @ text_features.T

    def logits(self, image_features, text_features):
        """Similarities scaled by the exponential of CLIP's learned `logit_scale`."""
        logit_scale = self.logit_scale.exp()
        return logit_scale * self.cos_sim(image_features, text_features)

    def class_probabilities(self, image_features, text_features):
        """Softmax of `logits` over the last dimension."""
        return  self.logits(image_features, text_features).softmax(dim=-1)

    def marginal_entropy(self, logits):
        """Entropy (in nats) of the distribution obtained by averaging the row-wise softmax of `logits` over dimension 0."""
        z = logits - logits.logsumexp(dim = -1, keepdim=True)     # compute z_ij
        marginal_logp = z.logsumexp(dim=0) - np.log(z.shape[0])   # compute marginal log probabilities

        min_real = torch.finfo(marginal_logp.dtype).min           # most negative finite value of the dtype
        avg_logits = torch.clamp(marginal_logp, min = min_real)   # replaces -inf so that the product below is not nan

        return -(avg_logits * torch.exp(avg_logits)).sum(dim=-1)

    def compute_entropy(self, x): #Shannon entropy in bits
        """Shannon entropy in bits of the probabilities in `x`, summed over all elements."""
        log_x = torch.log2(x.clamp_min(1e-20))
        entropy = -torch.sum(x * log_x)
        return entropy

    def confidence_selection(self, probs_matrix, percentile=0.8):
        """Keep the rows of `probs_matrix` whose entropy is strictly below the `percentile` quantile of all row entropies."""
        # Row-wise Shannon entropy in bits, computed on the tensor's own device.
        log_probs = torch.log2(probs_matrix.clamp_min(1e-20))
        entropies = -(probs_matrix * log_probs).sum(dim=-1)

        # Find the threshold for the desired percentile
        threshold = torch.quantile(entropies, percentile, interpolation = 'linear')

        # Create a boolean mask where entropies below the threshold are selected
        boolean_mask = entropies < threshold

        return probs_matrix[boolean_mask]

    def entropy_loss_MEMO(self, batch_features, text_features = None):
        """Marginal entropy of the class logits of `batch_features` (defaults to `self.text_features`)."""
        if text_features is None:
            text_features = self.text_features
        #Logits (unnormalized probabilities)
        logits = self.logits(batch_features, text_features)
        # Compute the entropy of every text caption across all augmentations
        marginal_entropy = self.marginal_entropy(logits)
        return marginal_entropy

    def entropy_loss_TPT(self, batch_features, text_features = None):
        """Entropy of the class probabilities averaged over the confident views.

        Returns:
            (entropy in bits, averaged probabilities).
        """
        if text_features is None:
            text_features = self.text_features
        probs_matrix = self.class_probabilities(batch_features, text_features)
        # Confidence selection for the augmented views:
        probs_matrix = self.confidence_selection(probs_matrix)
        # Average the caption probabilities across all augmentations (stays on the
        # tensor's device instead of being rebuilt from a Python list)
        avg_probs = probs_matrix.mean(dim=0)
        # Compute the entropy of the averaged probability distribution
        return self.compute_entropy(avg_probs), avg_probs

    def forward(self, image):
        """Class probabilities for an already preprocessed image batch."""
        image_features = self.model.encode_image(image)
        norms = image_features.norm(dim=-1, p=2,  keepdim=True)
        if (norms == 0).any():
            print("Zero norm found in image features")
        image_features = image_features / norms.clamp_min(1e-10)
        return self.class_probabilities(image_features, self.text_features)

    def predict(self, image):
        """Zero-shot prediction for a raw image.

        Returns:
            (predicted class index, class probabilities, entropy in bits).
        """
        self.model.eval()
        image = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            probs = self.forward(image)

        prediction = torch.argmax(probs).item()
        entropy = float(self.compute_entropy(probs))
        return prediction, probs.squeeze(), entropy

    def grad_descent_step(self, loss):
        """Zero the gradients, back-propagate `loss` and take one optimizer step."""
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def MEMO(self, image, num_augmentations=100):
        """One MEMO adaptation step on the augmentations of `image`, then prediction, then parameter restore.

        Returns:
            (predicted class index, class probabilities, entropy in bits).
        """
        # Save original parameters (detached so the snapshot is not part of any graph)
        original_params = {name: param.clone().detach() for name, param in self.model.named_parameters()}

        # Require gradients to update the CLIP parameters
        self.require_CLIP_gradients(state = True)

        try:
            self.model.train()
            batch = self.augment_image(image, num_augmentations)
            batch_features = self.model.encode_image(batch)
            norms = batch_features.norm(dim=-1, p=2, keepdim=True)
            if (norms == 0).any():
                print("Zero norm found in image features")
            batch_features = batch_features / norms.clamp_min(1e-10)
            loss = self.entropy_loss_MEMO(batch_features)
            self.grad_descent_step(loss)

            if any(torch.isnan(param).any() for param in self.model.parameters()):
                print("nan values detected in model parameters after updating")
            # Predict using the updated model
            prediction, probs, entropy = self.predict(image)
        finally:
            # Restore original parameters
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    param.copy_(original_params[name])
        return prediction, probs.squeeze(), entropy

    def TPT(self, image, num_augmentations=100):
        """Predict from the class probabilities averaged over the confident augmented views (no parameter update).

        Returns:
            (predicted class index, averaged probabilities, entropy in bits).
        """
        self.model.eval()
        self.require_CLIP_gradients(False)
        batch = self.augment_image(image, num_augmentations)
        batch_features = self.model.encode_image(batch)
        norms = batch_features.norm(dim=-1, p=2, keepdim=True)
        if (norms == 0).any():
            print("Zero norm found in image features")
        batch_features = batch_features / norms.clamp_min(1e-10)

        entropy, avg_probs = self.entropy_loss_TPT(batch_features)
        prediction = torch.argmax(avg_probs).item()
        return prediction, avg_probs.squeeze(), float(entropy)

    #Implementing the Entropy Boost stuff:
    def pick_candidates(self, tensor, classifier, top_num):
        """Select the `top_num` entries of `tensor` with the highest scores in `classifier`.

        Returns:
            (selected entries with singleton dimensions squeezed out, their indices).
        """
        _, top_indices = torch.topk(classifier, top_num)
        candidates = torch.squeeze(tensor[top_indices])

        return candidates, top_indices

    def expand_tensor(self, tensor, top_indices, n):
        """Scatter `tensor[i]` into position `top_indices[i]` of a zero vector of length `n`."""
        exp_tensor = torch.zeros(n).to(self.device)
        for i in range(top_indices.shape[0]): exp_tensor[top_indices[i]] = tensor[i]

        return exp_tensor

    def generate_augmentations_similarities(self, image, image_features, txt_candidates, num_augmentations, top_aug_num):
        """Average candidate probabilities over the random crops most similar to the original image.

        The RNG is seeded with `33 + self.changeseed`, so every boosting
        iteration draws different but reproducible crops.
        """
        # augment
        torch.manual_seed(33+self.changeseed)
        augmentations = transforms.Compose([
                                transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
                                ])
        # Pass the crop-only pipeline explicitly: otherwise augment_image falls back to its
        # default pipeline and re-seeds the RNG with 33, discarding both the pipeline and the seed set above.
        batch = self.augment_image(image, num_augmentations, transformations=augmentations)
        batch_features = self.model.encode_image(batch)
        norms = batch_features.norm(dim=-1, p=2, keepdim=True)
        if (norms == 0).any():
            print("Zero norm found in image features")
        batch_features = batch_features / norms.clamp_min(1e-10)


        # pick the ones closest to the original image
        probs_matrix = self.class_probabilities(image_features, batch_features).squeeze()
        candidates_features, top_indices_aug = self.pick_candidates(batch_features, probs_matrix, top_num = top_aug_num)

        candidates_probs_matrix = self.class_probabilities(candidates_features, txt_candidates).squeeze()

        del batch # drop the reference to the image batch as early as possible
        return torch.mean(candidates_probs_matrix, dim=0)

    def boost_augmentations(self, image, image_features, prob, num_augmentations, num_candidates, top_aug_num):
        """Re-score the `num_candidates` most probable classes on the original image and on its augmentations.

        Returns:
            (candidate probabilities on the original image, candidate probabilities averaged
            over the augmentations), both expanded to vectors of length `prob.shape[0]`.
        """
        _, top_indices = torch.topk(prob, num_candidates)
        text_candidates = self.text_features[top_indices]

        n = prob.shape[0]

        #Candidates probabilities in the original image
        candidates_prob_og = self.class_probabilities(image_features, text_candidates).squeeze()
        candidates_prob_og = self.expand_tensor(candidates_prob_og, top_indices, n)

        #Candidates averaged probability in the augmentations
        candidates_avg_prob = self.generate_augmentations_similarities(image, image_features, text_candidates, num_augmentations, top_aug_num)
        candidates_avg_prob = self.expand_tensor(candidates_avg_prob, top_indices, n)

        return candidates_prob_og, candidates_avg_prob

    def entropyboosting(self, image, num_augmentations = 100, num_candidates = 10, top_aug_num = 5):
        """Entropy-boosting prediction: iteratively blend CLIP probabilities with augmentation-based ones.

        Returns:
            (predicted class index, blended probabilities, entropy in bits), or None
            (after printing a message) when `num_candidates <= 1`.
        """
        self.model.eval()
        self.require_CLIP_gradients(False)
        if num_candidates <= 1:
            print("If num_candidates=0 this is just CLIP")
            return
        image_prepro = self.preprocess(image).unsqueeze(0).to(self.device)
        image_features = self.model.encode_image(image_prepro)
        norms = image_features.norm(dim=-1, p=2,  keepdim=True)
        image_features = image_features / norms.clamp_min(1e-10)

        #Compute out of the box CLIP probability distribution and prediction
        clip_probs = self.class_probabilities(image_features, self.text_features).squeeze()
        clip_prediction = torch.argmax(clip_probs).item()

        phi = 2 / (1 + np.sqrt(5)) # inverse of the golden ratio (~0.618), used as the blending weight

        # Defining loop variables
        aug_prob = None
        output_prob = clip_probs.clone()
        max_iter = 4
        iter = 0
        while iter < max_iter:
            self.changeseed = self.changeseed + 1 #To make different random transformations at every iter but replicable
            candidates_prob_og, candidates_avg_prob = self.boost_augmentations(image, image_features, output_prob, num_augmentations, num_candidates, top_aug_num)

            clip_probs = phi * clip_probs + (1 - phi) * candidates_prob_og

            if aug_prob is None: aug_prob = candidates_avg_prob
            else:                aug_prob = 0.5 * aug_prob + 0.5 * candidates_avg_prob

            output_prob = 0.6 * clip_probs + 0.4 * aug_prob # give a little more weight to the probability of the image.
            EB_prediction = torch.argmax(output_prob).item()

            # if clip_prediction != EB_prediction then it changed its mind: the class with highest probability has changed - better do further checks.
            if clip_prediction == EB_prediction: break

            iter += 1
        self.changeseed = 0
        entropy = self.compute_entropy(output_prob)
        return EB_prediction, output_prob.squeeze(), float(entropy)


# Preparing the class for usage
# NOTE(review): this rebinds the name `clip_model`, which earlier in this notebook is the base class
# of MEMO_model, eb_model, features_tmp and DAT_model. After this line, `clip_model(...)` and
# `class X(clip_model)` no longer refer to that class — consider a different variable name.
clip_model = CLIPModel()

## Good afternoon :)

In [None]:
def load_cifar100():
    """Build the CIFAR-100 dataset object used for evaluation.

    The dataset is rooted at ``./data`` and requested with ``train=False`` and
    ``download=True``.
    """
    return torchvision.datasets.CIFAR100(root='./data', train=False, download=True)

def cos_sim(image_features, text_features):
    """Pairwise dot products between image and text feature rows.

    Returns ``image_features @ text_features.T``. This is a cosine similarity
    only when both inputs are already L2-normalized; no normalization is done here.
    """
    return torch.matmul(image_features, text_features.T)

def logits(image_features, text_features, logit_scale):
    """Similarity scores multiplied by the exponentiated ``logit_scale``.

    ``logit_scale`` is stored in log space, so it is passed through ``exp``
    before scaling the output of ``cos_sim``.
    """
    temperature_scale = torch.exp(logit_scale)
    similarities = cos_sim(image_features, text_features)
    return temperature_scale * similarities

def class_probabilities(image_features, text_features, logit_scale):
    """Softmax over the last axis of the scaled similarity logits."""
    scaled_similarities = logits(image_features, text_features, logit_scale)
    return torch.softmax(scaled_similarities, dim=-1)

def marginal_entropy(logits):
    """Entropy (natural log) of the class distribution averaged over axis 0.

    Each row of ``logits`` is turned into log-probabilities over the last axis,
    the rows are averaged in log space, and the entropy of that averaged
    distribution is returned.
    """
    # Row-wise log-probabilities (log-softmax written out explicitly).
    log_probs = logits - torch.logsumexp(logits, dim=-1, keepdim=True)
    num_rows = log_probs.shape[0]
    # log(mean over rows of p) == logsumexp over rows of log p, minus log(number of rows).
    mean_log_probs = torch.logsumexp(log_probs, dim=0) - np.log(num_rows)

    # Clamp at the most negative finite value of the dtype so -inf entries
    # cannot turn the p * log(p) product below into NaN.
    floor = torch.finfo(mean_log_probs.dtype).min
    mean_log_probs = torch.clamp(mean_log_probs, min=floor)

    return -(mean_log_probs * mean_log_probs.exp()).sum(dim=-1)

def compute_entropy(x):
    """Shannon entropy in bits of ``x``, summed over every element.

    Values are clamped to at least 1e-20 before the base-2 logarithm so that
    zero entries contribute 0 instead of NaN.
    """
    return -(x * x.clamp_min(1e-20).log2()).sum()

class CLIPModel:
    """Wrapper around a CLIP model for zero-shot classification.

    Keeps the loaded model and its preprocessing transform, an optimizer over
    the model parameters, the normalized text embeddings of the class prompts,
    and a ``changeseed`` counter used by callers to vary random augmentations.
    """

    def __init__(self, model_name='ViT-B/32', device=None):
        """Load CLIP on ``device`` (CUDA when available if not given) and cast its parameters to float32."""
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = clip.load(model_name, self.device)
        self.model = self.convert_model_parameters_to_float32(self.model)
        # SGD with momentum is the default optimizer.
        self.use_SGD()
        self.text_features = None    # filled in by tokenize_labels
        self.requiring_grads = None  # last state applied by require_CLIP_gradients
        self.scale = self.model.logit_scale  # temperature parameter learned by CLIP
        self.changeseed = 0  # offset so random transforms can differ yet stay replicable

    def use_ADAM(self):
        """Replace the optimizer with Adam (lr=1e-5) over all model parameters."""
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-5)

    def use_SGD(self):
        """Replace the optimizer with SGD (lr=1e-5, momentum=0.9) over all model parameters."""
        self.optimizer = optim.SGD(self.model.parameters(), lr=1e-5, momentum=0.9)

    def require_CLIP_gradients(self, state = True):
        """Set ``requires_grad`` on every model parameter, skipping the work if ``state`` is already applied."""
        already_applied = self.requiring_grads is not None and state == self.requiring_grads
        if already_applied:
            return
        for parameter in self.model.parameters():
            parameter.requires_grad = state
        self.requiring_grads = state

    def convert_model_parameters_to_float32(self, model):
        """Cast the data of every parameter of ``model`` to float32 in place and return ``model``."""
        for parameter in model.parameters():
            parameter.data = parameter.data.float()
        return model

    def tokenize_labels(self, classes):
        """Encode one heuristic prompt per class name and store the L2-normalized text features."""
        token_batches = [clip.tokenize(f"a photo of a {label}") for label in classes]
        text_inputs = torch.cat(token_batches).to(self.device)
        with torch.no_grad():
            self.text_features = self.model.encode_text(text_inputs)
            text_norms = self.text_features.norm(dim=-1, p=2, keepdim=True)
            self.text_features.div_(text_norms)

    def grad_descent_step(self, loss):
        """Run one optimizer update for ``loss``: clear gradients, backpropagate, step."""
        optimizer = self.optimizer
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    def forward(self, image):
        """Return class probabilities for an already preprocessed image batch."""
        raw_features = self.model.encode_image(image)
        feature_norms = raw_features.norm(dim=-1, p=2,  keepdim=True)
        if (feature_norms == 0).any():
            print("Zero norm found in image features")
        # Clamp the norm so a zero vector does not cause a division by zero.
        unit_features = raw_features / feature_norms.clamp_min(1e-10)
        return class_probabilities(unit_features, self.text_features, self.scale)

    def predict(self, image):
        """Classify a single raw image.

        Returns the arg-max class index, the squeezed probability tensor and
        the Shannon entropy of the probabilities as a float.
        """
        self.model.eval()
        pixel_batch = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            probs = self.forward(pixel_batch)

        predicted_index = torch.argmax(probs).item()
        return predicted_index, probs.squeeze(), float(compute_entropy(probs))



In [None]:
class MEMO:
    """Single-step marginal-entropy minimization at test time on top of a CLIP wrapper.

    For one image, the model is updated once on the marginal entropy of a
    batch of augmented views, used for a prediction, and then restored.
    """

    def __init__(self, CLIP):
        """Store the CLIP wrapper whose model, optimizer and text features are used."""
        self.clip = CLIP

    def augment_image(self, image, num_augmentations=100, transformations=None):
        """Build a batch made of the original image followed by random augmentations.

        Args:
            image: raw image accepted by ``self.clip.preprocess`` (presumably a PIL image — TODO confirm).
            num_augmentations: number of augmented views appended after the original.
            transformations: optional callable applied to ``image`` for every view.
                When omitted, a fixed default pipeline is used and the torch RNG is
                seeded with 33 so the default views are reproducible.

        Returns:
            Tensor stacking ``num_augmentations + 1`` preprocessed views along dim 0.
        """
        if transformations is None:
            torch.manual_seed(33)  # reproducible default augmentations
            transformations = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
                transforms.RandomRotation(degrees=30),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
                transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            ])
        preprocess, device = self.clip.preprocess, self.clip.device
        # The original image is always the first row of the batch.
        views = [preprocess(image).unsqueeze(0).to(device)]
        for _ in range(num_augmentations):
            views.append(preprocess(transformations(image)).unsqueeze(0).to(device))
        return torch.vstack(views)

    def entropy_loss_MEMO(self, batch_features, text_features = None):
        """Marginal entropy of the class distribution over all views in ``batch_features``.

        ``text_features`` defaults to the text features stored on the CLIP wrapper.
        """
        if text_features is None:
            text_features = self.clip.text_features
        # Local names must differ from the module-level helpers `logits` and
        # `marginal_entropy`; reusing those names makes them local and unbound.
        view_logits = logits(batch_features, text_features, self.clip.scale)
        return marginal_entropy(view_logits)

    def MEMO(self, image, num_augmentations=100):
        """Adapt the model on one image, predict, then restore the original weights.

        Returns:
            ``(prediction, probs, entropy)`` as produced by ``self.clip.predict`` with
            the temporarily updated model.

        NOTE(review): only the model parameters are restored; the optimizer held by
        the wrapper is not reset here, so any internal optimizer state carries over
        between calls — confirm whether that is intended.
        """
        # Detached snapshot of the weights so they can be restored afterwards.
        original_params = {
            name: param.detach().clone()
            for name, param in self.clip.model.named_parameters()
        }

        # Gradients are needed to update the CLIP parameters.
        self.clip.require_CLIP_gradients(state = True)

        try:
            self.clip.model.train()
            batch = self.augment_image(image, num_augmentations)
            batch_features = self.clip.model.encode_image(batch)
            norms = batch_features.norm(dim=-1, p=2, keepdim=True)
            if (norms == 0).any():
                print("Zero norm found in image features")
            batch_features = batch_features / norms.clamp_min(1e-10)
            loss = self.entropy_loss_MEMO(batch_features)
            self.clip.grad_descent_step(loss)

            if any(torch.isnan(param).any() for param in self.clip.model.parameters()):
                print("nan values detected in model parameters after updating")
            # Predict using the updated model.
            prediction, probs, entropy = self.clip.predict(image)
        finally:
            # Always put the original weights back, even if adaptation failed.
            with torch.no_grad():
                for name, param in self.clip.model.named_parameters():
                    param.copy_(original_params[name])
        return prediction, probs.squeeze(), entropy


In [None]:
class TPT:
    """Prediction from confidence-filtered augmented views of one image.

    Class probabilities are computed for the image and its augmentations,
    the lower-entropy views are kept, and their probabilities are averaged.
    No parameters are updated by this class.
    """

    def __init__(self, CLIP):
        """Store the CLIP wrapper whose model and text features are used."""
        self.clip = CLIP

    def augment_image(self, image, num_augmentations=100, transformations=None):
        """Build a batch made of the original image followed by random augmentations.

        Args:
            image: raw image accepted by ``self.clip.preprocess`` (presumably a PIL image — TODO confirm).
            num_augmentations: number of augmented views appended after the original.
            transformations: optional callable applied to ``image`` for every view.
                When omitted, a fixed default pipeline is used and the torch RNG is
                seeded with 33 so the default views are reproducible.

        Returns:
            Tensor stacking ``num_augmentations + 1`` preprocessed views along dim 0.
        """
        if transformations is None:
            torch.manual_seed(33)  # reproducible default augmentations
            transformations = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
                transforms.RandomRotation(degrees=30),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
                transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            ])
        preprocess, device = self.clip.preprocess, self.clip.device
        # The original image is always the first row of the batch.
        views = [preprocess(image).unsqueeze(0).to(device)]
        for _ in range(num_augmentations):
            views.append(preprocess(transformations(image)).unsqueeze(0).to(device))
        return torch.vstack(views)

    @staticmethod
    def _row_entropies(probs_matrix):
        """Shannon entropy in bits of every row (values clamped at 1e-20 before the log)."""
        return -(probs_matrix * torch.log2(probs_matrix.clamp_min(1e-20))).sum(dim=-1)

    def confidence_selection(self, probs_matrix, percentile=0.8):
        """Keep the rows whose entropy is strictly below the ``percentile`` quantile.

        Args:
            probs_matrix: 2-D tensor with one probability row per view.
            percentile: quantile in [0, 1] of the row entropies used as threshold.

        Returns:
            The selected rows. If no row is strictly below the threshold (for example
            a single view, or all entropies equal) the input is returned unchanged so
            that the average taken by the caller is never over an empty set.
        """
        entropies = self._row_entropies(probs_matrix)
        threshold = torch.quantile(entropies, percentile, interpolation = 'linear')
        keep = entropies < threshold
        if not bool(keep.any()):
            return probs_matrix
        return probs_matrix[keep]

    def entropy_loss_TPT(self, batch_features, text_features = None):
        """Entropy of the averaged class distribution over the confident views.

        ``text_features`` defaults to the text features stored on the CLIP wrapper.

        Returns:
            ``(entropy, avg_probs)`` where ``avg_probs`` is the mean over the selected
            views; it stays on the same device as ``batch_features``.
        """
        if text_features is None:
            text_features = self.clip.text_features
        probs_matrix = class_probabilities(batch_features, text_features, self.clip.scale)
        # Confidence selection for the augmented views.
        probs_matrix = self.confidence_selection(probs_matrix)
        # Average each class probability across the selected views in one tensor op.
        avg_probs = probs_matrix.mean(dim=0)
        return compute_entropy(avg_probs), avg_probs

    def TPT(self, image, num_augmentations=100):
        """Predict the class of ``image`` from its confident augmented views.

        Returns:
            ``(prediction, avg_probs, entropy)``: arg-max class index, averaged
            probability vector and its Shannon entropy as a float.
        """
        self.clip.model.eval()
        self.clip.require_CLIP_gradients(False)
        # Nothing is optimized here, so no autograd graph is needed.
        with torch.no_grad():
            batch = self.augment_image(image, num_augmentations)
            batch_features = self.clip.model.encode_image(batch)
            norms = batch_features.norm(dim=-1, p=2, keepdim=True)
            if (norms == 0).any():
                print("Zero norm found in image features")
            batch_features = batch_features / norms.clamp_min(1e-10)

            entropy, avg_probs = self.entropy_loss_TPT(batch_features)
        prediction = torch.argmax(avg_probs).item()
        return prediction, avg_probs.squeeze(), float(entropy)

In [None]:
class EB:
    """Iterative "entropy boosting" of a CLIP prediction using augmented views.

    The top candidate classes of the current distribution are re-scored on the
    original image and on the augmented views closest to it, and the results are
    blended with the running distribution for a few iterations.
    """

    def __init__(self, CLIP):
        """Store the CLIP wrapper whose model, text features and seed offset are used."""
        self.clip = CLIP

    def augment_image(self, image, num_augmentations=100, transformations=None):
        """Build a batch made of the original image followed by random augmentations.

        Args:
            image: raw image accepted by ``self.clip.preprocess`` (presumably a PIL image — TODO confirm).
            num_augmentations: number of augmented views appended after the original.
            transformations: optional callable applied to ``image`` for every view.
                When omitted, a fixed default pipeline is used and the torch RNG is
                seeded with 33 so the default views are reproducible. When given, the
                RNG is left untouched so the caller controls the seed.

        Returns:
            Tensor stacking ``num_augmentations + 1`` preprocessed views along dim 0.
        """
        if transformations is None:
            torch.manual_seed(33)  # reproducible default augmentations
            transformations = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
                transforms.RandomRotation(degrees=30),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
                transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            ])
        preprocess, device = self.clip.preprocess, self.clip.device
        # The original image is always the first row of the batch.
        views = [preprocess(image).unsqueeze(0).to(device)]
        for _ in range(num_augmentations):
            views.append(preprocess(transformations(image)).unsqueeze(0).to(device))
        return torch.vstack(views)

    def pick_candidates(self, tensor, classifier, top_num):
        """Select the ``top_num`` rows of ``tensor`` with the highest ``classifier`` scores.

        Returns:
            ``(candidates, top_indices)``: the selected rows (squeezed) and their indices.
        """
        _, top_indices = torch.topk(classifier, top_num)
        candidates = torch.squeeze(tensor[top_indices])

        return candidates, top_indices

    def expand_tensor(self, tensor, top_indices, n):
        """Scatter ``tensor`` into a zero vector of length ``n`` at positions ``top_indices``.

        ``tensor[i]`` ends up at index ``top_indices[i]``; every other entry is 0.
        """
        exp_tensor = torch.zeros(n, device=self.clip.device)
        # One indexed assignment instead of a Python loop; values are cast to the
        # destination dtype/device because index assignment requires them to match.
        values = tensor.reshape(-1).to(dtype=exp_tensor.dtype, device=exp_tensor.device)
        exp_tensor[top_indices] = values

        return exp_tensor

    def generate_augmentations_similarities(self, image, image_features, txt_candidates, num_augmentations, top_aug_num):
        """Average candidate-class probabilities over the views most similar to the image.

        Args:
            image: raw image to augment.
            image_features: normalized features of the original image.
            txt_candidates: text features of the candidate classes.
            num_augmentations: number of random views to generate.
            top_aug_num: number of views (closest to the original image) to keep.

        Returns:
            1-D tensor with one averaged probability per candidate class.
        """
        # Seed depends on the wrapper's changeseed so each iteration of the caller
        # gets different but replicable views. The pipeline is passed explicitly:
        # without it augment_image would fall back to its default pipeline and
        # re-seed to a fixed value, making every iteration produce identical views.
        torch.manual_seed(33+self.clip.changeseed)
        augmentations = transforms.Compose([
                                transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
                                ])
        batch = self.augment_image(image, num_augmentations, transformations=augmentations)
        batch_features = self.clip.model.encode_image(batch)
        del batch  # drop the pixel batch as soon as it is encoded to save memory
        norms = batch_features.norm(dim=-1, p=2, keepdim=True)
        if (norms == 0).any():
            print("Zero norm found in image features")
        batch_features = batch_features / norms.clamp_min(1e-10)

        # Rank the views by similarity to the original image and keep the closest ones.
        view_scores = class_probabilities(image_features, batch_features, self.clip.scale).reshape(-1)
        top_num = min(top_aug_num, batch_features.shape[0])
        candidates_features, _ = self.pick_candidates(batch_features, view_scores, top_num = top_num)
        # Keep a 2-D (views, features) layout even when a single view is selected.
        candidates_features = candidates_features.reshape(-1, batch_features.shape[-1])

        candidates_probs_matrix = class_probabilities(candidates_features, txt_candidates, self.clip.scale)

        return torch.mean(candidates_probs_matrix, dim=0)

    def boost_augmentations(self, image, image_features, prob, num_augmentations, num_candidates, top_aug_num):
        """Re-score the top candidate classes of ``prob`` on the image and on its views.

        Returns:
            ``(candidates_prob_og, candidates_avg_prob)``: two vectors of the same
            length as ``prob`` holding, at the candidate positions, the probabilities
            renormalized over the candidates for the original image and averaged over
            the selected augmented views; all other positions are 0.
        """
        n = prob.shape[0]
        # Never ask topk for more classes than exist.
        _, top_indices = torch.topk(prob, min(num_candidates, n))
        text_candidates = self.clip.text_features[top_indices]

        # Candidate probabilities on the original image.
        candidates_prob_og = class_probabilities(image_features, text_candidates, self.clip.scale).squeeze()
        candidates_prob_og = self.expand_tensor(candidates_prob_og, top_indices, n)

        # Candidate probabilities averaged over the augmentations.
        candidates_avg_prob = self.generate_augmentations_similarities(image, image_features, text_candidates, num_augmentations, top_aug_num)
        candidates_avg_prob = self.expand_tensor(candidates_avg_prob, top_indices, n)

        return candidates_prob_og, candidates_avg_prob

    def entropyboosting(self, image, num_augmentations = 100, num_candidates = 10, top_aug_num = 5):
        """Predict the class of ``image`` with iterative candidate re-scoring.

        Args:
            image: raw image accepted by ``self.clip.preprocess``.
            num_augmentations: random views generated per iteration.
            num_candidates: number of top classes re-scored per iteration.
            top_aug_num: number of views closest to the original kept per iteration.

        Returns:
            ``(prediction, probs, entropy)``: arg-max class index, blended score
            vector and its Shannon entropy as a float. With ``num_candidates <= 1``
            there is nothing to re-score, so the plain ``self.clip.predict(image)``
            result (same tuple layout) is returned instead.
        """
        self.clip.model.eval()
        self.clip.require_CLIP_gradients(False)
        if num_candidates <= 1:
            print("num_candidates <= 1: this is just CLIP, returning the plain CLIP prediction")
            return self.clip.predict(image)
        image_prepro = self.clip.preprocess(image).unsqueeze(0).to(self.clip.device)
        image_features = self.clip.model.encode_image(image_prepro)
        norms = image_features.norm(dim=-1, p=2,  keepdim=True)
        image_features = image_features / norms.clamp_min(1e-10)

        # Out-of-the-box CLIP probability distribution and prediction.
        clip_probs = class_probabilities(image_features, self.clip.text_features, self.clip.scale).squeeze()
        clip_prediction = torch.argmax(clip_probs).item()

        # Inverse golden ratio (~0.618) as a plain Python float, used as blending weight.
        phi = float(2 / (1 + np.sqrt(5)))

        aug_prob = None           # running average of the augmentation-based scores
        output_prob = clip_probs  # never modified in place, so no copy is needed
        max_iter = 4
        try:
            for _ in range(max_iter):
                # Different random transformations at every iteration, but replicable.
                self.clip.changeseed = self.clip.changeseed + 1
                candidates_prob_og, candidates_avg_prob = self.boost_augmentations(image, image_features, output_prob, num_augmentations, num_candidates, top_aug_num)

                # Blend the running image distribution with the candidate-only re-scoring.
                clip_probs = phi * clip_probs + (1 - phi) * candidates_prob_og

                if aug_prob is None:
                    aug_prob = candidates_avg_prob
                else:
                    aug_prob = 0.5 * aug_prob + 0.5 * candidates_avg_prob

                # Give a little more weight to the scores of the original image.
                output_prob = 0.6 * clip_probs + 0.4 * aug_prob
                EB_prediction = torch.argmax(output_prob).item()

                # Stop as soon as the boosted prediction agrees with plain CLIP;
                # a disagreement means the top class changed, so keep checking.
                if clip_prediction == EB_prediction:
                    break
        finally:
            # Reset the seed offset even if an iteration failed.
            self.clip.changeseed = 0
        entropy = compute_entropy(output_prob)
        return EB_prediction, output_prob.squeeze(), float(entropy)