<a href="https://colab.research.google.com/github/javiimo/ImageClassificationAssignment/blob/main/CLIPClass.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git
! pip install datasets transformers


import copy

import clip
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms





Collecting ftfy
  Downloading ftfy-6.2.0-py3-none-any.whl (54 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/54.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━[0m [32m51.2/54.4 kB[0m [31m1.7 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.4/54.4 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: ftfy
Successfully installed ftfy-6.2.0
Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-q5o7ts7l
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-q5o7ts7l
  Resolved https://github.com/openai/CLIP.git to commit a1d071733d7111c9c014f024669f959182114e33
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->clip==1.0)
  Using cached nvidia_cuda_nvr

In [None]:
# Mount Google Drive; later cells read dataset files from paths under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
from torch.cuda.amp import autocast, GradScaler

In [None]:
class CLIPModel:
    """Zero-shot CLIP classifier with test-time adaptation helpers.

    The class wraps a model returned by ``clip.load`` together with its
    preprocessing transform, caches one normalised text embedding per class
    (see ``tokenize_labels``) and offers three ways to classify one image:

    * ``predict`` - plain zero-shot prediction.
    * ``MEMO``    - one gradient step on the marginal entropy of augmented
      views, prediction with the updated weights, then weights restored.
    * ``TPT``     - prediction from the average of the most confident
      augmented views (no parameter update is performed here).
    """

    def __init__(self, model_name='ViT-B/32', device=None):
        """Load the CLIP weights and create the default SGD optimizer.

        Args:
            model_name: name forwarded to ``clip.load``.
            device: target device; defaults to CUDA when available, else CPU.
        """
        self.device = device if device else "cuda" if torch.cuda.is_available() else "cpu"
        self.model, self.preprocess = clip.load(model_name, self.device)
        # NOTE(review): the float32 conversion (convert_model_parameters_to_float32)
        # was disabled by the author; GradScaler is documented to reject fp16
        # parameters, so it may be needed before MEMO on CUDA - confirm.
        self.optimizer = optim.SGD(self.model.parameters(), lr=1e-5, momentum=0.9)
        self.text_features = None     # set by tokenize_labels()
        self.requiring_grads = None   # last state applied by require_CLIP_gradients()
        self.logit_scale = self.model.logit_scale

    def use_ADAM(self):
        """Replace the optimizer with Adam (lr=1e-5) over all model parameters."""
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-5)

    def use_SGD(self):
        """Replace the optimizer with SGD (lr=1e-5, momentum=0.9) over all model parameters."""
        self.optimizer = optim.SGD(self.model.parameters(), lr=1e-5, momentum=0.9)

    def require_CLIP_gradients(self, state=True):
        """Set ``requires_grad`` on every model parameter, skipping the work if already in that state."""
        if self.requiring_grads is None or state != self.requiring_grads:
            for param in self.model.parameters():
                param.requires_grad = state
            self.requiring_grads = state

    def convert_model_parameters_to_float32(self, model):
        """Cast the data of every parameter of ``model`` to float32 in place and return the model."""
        for param in model.parameters():
            param.data = param.data.to(torch.float32)
        return model

    def load_data(self):
        """Download (if needed) and return the CIFAR-100 test split, without transforms."""
        cifar100 = torchvision.datasets.CIFAR100(root='./data', download=True, train=False)
        return cifar100

    def tokenize_labels(self, classes):
        """Encode the hand-written prompt "a photo of a <class>" for every class.

        The L2-normalised text embeddings are cached in ``self.text_features``.
        """
        text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in classes]).to(self.device)
        with torch.no_grad():
            self.text_features = self.model.encode_text(text_inputs)
            self.text_features /= self.text_features.norm(dim=-1, p=2, keepdim=True)

    def augment_image(self, image, num_augmentations=100, transformations=None):
        """Build a batch made of the original image followed by random augmentations.

        Args:
            image: input accepted by ``self.preprocess`` and by ``transformations``.
            num_augmentations: number of augmented views to generate.
            transformations: optional callable applied to ``image`` for each
                view; when ``None`` a default flip/rotation/colour-jitter/crop
                pipeline is used.

        Returns:
            Tensor with ``num_augmentations + 1`` stacked views (original first).
        """
        # Fixed seed so the random augmentations are reproducible across calls.
        torch.manual_seed(33)
        if transformations is None:
            transformations = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
                transforms.RandomRotation(degrees=30),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
                transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            ])
        # The un-augmented image is always the first row of the batch.
        views = [self.preprocess(image).unsqueeze(0).to(self.device)]
        for _ in range(num_augmentations):
            views.append(self.preprocess(transformations(image)).unsqueeze(0).to(self.device))
        return torch.vstack(views)

    def cos_sim(self, image_features, text_features):
        """Return the image-by-text similarity matrix (inputs are expected to be L2-normalised)."""
        return image_features @ text_features.T

    def logits(self, text_features, image_features):
        """Return similarities multiplied by the exponentiated ``logit_scale``."""
        logit_scale = self.logit_scale.exp()
        return logit_scale * self.cos_sim(image_features, text_features)

    def class_probabilities(self, text_features, image_features):
        """Return the softmax over classes of the scaled similarities."""
        return self.logits(text_features, image_features).softmax(dim=-1)

    def marginal_entropy(self, logits):
        """Entropy (in nats) of the class distribution averaged over the rows of ``logits``."""
        z = logits - logits.logsumexp(dim=-1, keepdim=True)      # per-view log-probabilities
        marginal_logp = z.logsumexp(dim=0) - np.log(z.shape[0])  # log of the mean probability

        # Clamp to the smallest representable value of the dtype so that
        # -inf log-probabilities cannot turn the product below into NaN.
        min_real = torch.finfo(marginal_logp.dtype).min
        avg_logits = torch.clamp(marginal_logp, min=min_real)

        return -(avg_logits * torch.exp(avg_logits)).sum(dim=-1)

    def compute_entropy(self, x):
        """Shannon entropy in bits of the probabilities in ``x`` (summed over all elements).

        The computation is done in float32: in half precision the 1e-20 floor
        rounds to zero and ``0 * log2(0)`` would yield NaN.
        """
        probs = x.float()
        return -torch.sum(probs * torch.log2(probs.clamp_min(1e-20)))

    def _row_entropies(self, probs_matrix):
        """Shannon entropy in bits of every row of ``probs_matrix`` (float32)."""
        probs = probs_matrix.float()
        return -(probs * torch.log2(probs.clamp_min(1e-20))).sum(dim=-1)

    def _confidence_mask(self, probs_matrix, percentile=0.8):
        """Boolean mask of the rows whose entropy lies below the given quantile.

        When the strict comparison would select nothing (a single row, or all
        rows tied) the rows at the threshold are kept instead, so the caller
        never receives an empty selection from non-empty input.
        """
        # The mask is a discrete selection, so no gradient is needed through it.
        entropies = self._row_entropies(probs_matrix).detach()
        threshold = torch.quantile(entropies, percentile, interpolation='linear')
        mask = entropies < threshold
        if not mask.any():
            mask = entropies <= threshold
        return mask

    def confidence_selection(self, probs_matrix, percentile=0.8):
        """Keep only the most confident (lowest-entropy) rows of ``probs_matrix``."""
        return probs_matrix[self._confidence_mask(probs_matrix, percentile)]

    def entropy_loss_MEMO(self, batch_features, text_features=None, conf_sel=False):
        """Marginal entropy of the class distribution over a batch of views.

        Args:
            batch_features: normalised image features, one row per view.
            text_features: class embeddings; defaults to the cached ones.
            conf_sel: when true, only the most confident views contribute.
        """
        if text_features is None:
            text_features = self.text_features
        logits = self.logits(text_features, batch_features)
        if conf_sel:
            logits = logits[self._confidence_mask(logits.softmax(dim=-1))]
        return self.marginal_entropy(logits)

    def entropy_loss_TPT(self, batch_features, text_features=None):
        """Entropy of the class probabilities averaged over the confident views.

        Returns:
            ``(entropy, avg_probs)``; both stay attached to the autograd graph
            of ``batch_features`` so the entropy can be used as a loss.
        """
        if text_features is None:
            text_features = self.text_features
        probs_matrix = self.class_probabilities(text_features, batch_features)
        probs_matrix = self.confidence_selection(probs_matrix)
        # Average each class probability across the selected views.
        avg_probs = probs_matrix.mean(dim=0)
        return self.compute_entropy(avg_probs), avg_probs

    def _normalized_image_features(self, images):
        """Encode ``images`` and L2-normalise the features, warning on zero norms."""
        features = self.model.encode_image(images)
        norms = features.norm(dim=-1, p=2, keepdim=True)
        if (norms == 0).any():
            print("Zero norm found in image features")
        # 1e-10 is not representable in half precision; never go below the
        # smallest normal value of the dtype so the division stays finite.
        floor = max(1e-10, torch.finfo(norms.dtype).tiny)
        return features / norms.clamp_min(floor)

    def forward(self, image):
        """Return class probabilities for an already preprocessed image batch."""
        image_features = self._normalized_image_features(image)
        return self.class_probabilities(self.text_features, image_features)

    def predict(self, image):
        """Zero-shot prediction for one raw image.

        Returns:
            ``(prediction, probs, entropy)``: arg-max class index, the
            probability tensor and its Shannon entropy in bits as a float.
        """
        self.model.eval()
        image = self.preprocess(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            probs = self.forward(image)

        prediction = torch.argmax(probs).item()
        entropy = float(self.compute_entropy(probs))
        return prediction, probs, entropy

    def grad_descent_step(self, loss):
        """Run one plain (non-AMP) optimizer step on ``loss``."""
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def MEMO(self, image, num_augmentations=100, conf_sel=False):
        """Adapt on one image by minimising marginal entropy, predict, then reset.

        Args:
            image: raw image accepted by ``self.preprocess``.
            num_augmentations: number of augmented views used for the update.
            conf_sel: restrict the loss to the most confident views.

        Returns:
            ``(prediction, probs, entropy)`` as produced by ``predict`` with
            the temporarily updated weights.
        """
        from contextlib import nullcontext

        # Mixed precision is only meaningful (and only supported here) on CUDA.
        use_amp = str(self.device).startswith("cuda")

        # Snapshot weights and optimizer state so every image starts from the
        # same model; detach so the copies carry no autograd history.
        original_params = {name: param.detach().clone() for name, param in self.model.named_parameters()}
        optimizer_state = copy.deepcopy(self.optimizer.state_dict())

        # Require gradients to update the CLIP parameters
        self.require_CLIP_gradients(state=True)

        scaler = GradScaler(growth_interval=1000, enabled=use_amp)
        try:
            self.model.train()
            batch = self.augment_image(image, num_augmentations)
            self.optimizer.zero_grad()  # drop gradients left by an earlier failed step

            # torch.autocast takes device_type; torch.cuda.amp.autocast does not.
            amp_context = torch.autocast(device_type='cuda', dtype=torch.float16) if use_amp else nullcontext()
            with amp_context:
                batch_features = self._normalized_image_features(batch)
                loss = self.entropy_loss_MEMO(batch_features, conf_sel=conf_sel)
            scaler.scale(loss).backward()  # Scale the loss and compute scaled gradients
            scaler.step(self.optimizer)    # Update weights
            scaler.update()                # Update the scale for next iteration
            self.optimizer.zero_grad()     # Clear gradients after update

            if any(torch.isnan(param).any() for param in self.model.parameters()):
                print("nan values detected in model parameters after updating")
            # Predict using the updated model
            prediction, probs, entropy = self.predict(image)
        finally:
            # Restore original parameters and optimizer state (momentum / Adam
            # moments), even when the update or the prediction failed.
            with torch.no_grad():
                for name, param in self.model.named_parameters():
                    param.copy_(original_params[name])
            self.optimizer.load_state_dict(optimizer_state)
        return prediction, probs, entropy

    def TPT(self, image, num_augmentations=100):
        """Predict from the averaged probabilities of the confident augmented views.

        Returns:
            ``(prediction, avg_probs, entropy)`` where ``entropy`` is a float in bits.
        """
        batch = self.augment_image(image, num_augmentations)
        # Nothing is optimised here, so avoid building an autograd graph.
        with torch.no_grad():
            batch_features = self._normalized_image_features(batch)
            entropy, avg_probs = self.entropy_loss_TPT(batch_features)
        prediction = torch.argmax(avg_probs).item()
        return prediction, avg_probs, float(entropy)


# Preparing the class for usage: a single shared instance built with the
# constructor defaults ('ViT-B/32', CUDA when available).
clip_model = CLIPModel()

100%|███████████████████████████████████████| 338M/338M [00:04<00:00, 74.7MiB/s]


Single-image experiments

In [None]:
# Load CIFAR-100 for the single-image experiments and pick one sample.
cifar100 = clip_model.load_data()
clip_model.tokenize_labels(cifar100.classes)  # caches one text embedding per class name
image, class_id = cifar100[3637]
len(cifar100.classes)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:10<00:00, 15564992.03it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data


100

In [None]:
# Prediction using CLIP out of the box (zero-shot, no adaptation).
# Prints the predicted class index, the entropy and the class probabilities.
prediction1, probs1, entropy1 = clip_model.predict(image)
print(prediction1)
print(entropy1)
print(probs1)

78
2.3266122341156006
tensor([[1.1976e-03, 5.4628e-03, 4.8137e-03, 1.8082e-04, 1.8419e-04, 3.4514e-04,
         5.7728e-04, 8.2104e-03, 9.9321e-05, 3.7704e-04, 1.0402e-03, 8.2096e-04,
         3.4284e-05, 1.0731e-04, 1.9523e-03, 1.4169e-04, 1.7230e-03, 4.6303e-05,
         4.2151e-03, 9.5373e-04, 2.7544e-04, 1.2663e-03, 3.1628e-04, 5.2210e-05,
         6.0814e-04, 1.6915e-04, 7.3051e-03, 1.7485e-02, 5.9205e-04, 4.7085e-03,
         1.2483e-05, 4.3250e-04, 8.4594e-03, 1.0473e-04, 1.4937e-05, 5.7395e-04,
         9.1999e-04, 1.1684e-04, 1.0331e-04, 2.5124e-04, 1.1974e-03, 6.8168e-03,
         1.7164e-02, 8.7137e-04, 1.8750e-02, 1.4284e-04, 3.7086e-04, 1.1404e-04,
         1.0234e-04, 4.6625e-04, 6.6154e-03, 2.9241e-03, 6.3652e-04, 7.6153e-04,
         1.2287e-04, 1.4347e-04, 8.5426e-04, 1.4071e-04, 2.5623e-04, 1.5453e-03,
         8.2367e-04, 1.7338e-03, 5.6944e-04, 3.3850e-04, 4.9928e-04, 6.9673e-04,
         4.8177e-04, 6.3735e-04, 7.4545e-06, 1.2440e-05, 1.7305e-04, 7.2545e-05,
      

In [None]:
# Prediction using MEMO with SGD as the test-time optimizer (10 augmented views).
clip_model.use_SGD()
prediction2, probs2, entropy2 = clip_model.MEMO(image, num_augmentations=10)
print(prediction2)
print(entropy2)
print(probs2)



TypeError: autocast.__init__() got an unexpected keyword argument 'device_type'

In [None]:
# Prediction using MEMO with Adam as the test-time optimizer (10 augmented views).
clip_model.use_ADAM()
clip_model.tokenize_labels(cifar100.classes) # The labels have to be tokenized again because of the change in precision
prediction3, probs3, entropy3 = clip_model.MEMO(image, num_augmentations=10)
print(prediction3)
print(entropy3)
print(probs3)

78
0.3890048563480377
tensor([[7.0291e-05, 1.2202e-03, 2.0966e-03, 2.8387e-05, 4.4689e-06, 2.8498e-05,
         3.2856e-05, 6.2306e-05, 1.0148e-05, 4.9513e-05, 2.7819e-04, 5.3239e-04,
         6.4909e-06, 6.8182e-06, 9.8106e-05, 1.4464e-06, 4.3893e-04, 2.5695e-06,
         4.9707e-04, 5.9818e-05, 5.5518e-06, 6.2155e-05, 4.5030e-06, 1.0935e-05,
         5.5434e-06, 4.1206e-06, 2.6603e-05, 4.4462e-04, 3.8518e-04, 2.5397e-04,
         1.2974e-06, 1.2624e-05, 4.2114e-05, 4.5163e-05, 2.7994e-06, 4.2294e-04,
         7.2895e-05, 4.4254e-05, 1.0838e-05, 1.8192e-06, 3.2609e-05, 2.0624e-03,
         5.1894e-04, 3.0728e-05, 3.5041e-03, 1.8276e-06, 1.2709e-04, 2.3962e-05,
         5.4501e-06, 1.4678e-05, 1.1937e-03, 3.1152e-05, 1.7511e-04, 5.6660e-05,
         2.8462e-05, 1.3722e-05, 4.6307e-05, 2.8665e-06, 9.7663e-06, 8.9210e-04,
         2.8619e-04, 3.3893e-04, 2.8192e-06, 5.9116e-06, 1.3063e-04, 8.1568e-05,
         3.1866e-05, 1.4796e-05, 5.0939e-07, 1.6153e-06, 1.7927e-05, 1.1527e-05,
      

In [None]:
# Prediction using TPT: class probabilities averaged over the confident
# augmented views (10 augmentations).
prediction4,prob_avg, entropy4 = clip_model.TPT(image, num_augmentations=10)
print(prediction4)
print(entropy4)
print(prob_avg)

59
5.61142635345459
tensor([0.0064, 0.0020, 0.0094, 0.0159, 0.0024, 0.0015, 0.0036, 0.0047, 0.0037,
        0.0079, 0.0029, 0.0115, 0.0021, 0.0032, 0.0041, 0.0024, 0.0094, 0.0036,
        0.0035, 0.0114, 0.0044, 0.0107, 0.0029, 0.0023, 0.0009, 0.0038, 0.0035,
        0.0242, 0.0036, 0.0294, 0.0014, 0.0065, 0.0031, 0.0114, 0.0015, 0.0182,
        0.0032, 0.0038, 0.0027, 0.0075, 0.0080, 0.0273, 0.0121, 0.0070, 0.0103,
        0.0014, 0.0164, 0.0213, 0.0034, 0.0102, 0.0039, 0.0064, 0.0574, 0.0015,
        0.0009, 0.0005, 0.0391, 0.0023, 0.0041, 0.1248, 0.0108, 0.0077, 0.0039,
        0.0030, 0.0048, 0.0042, 0.0057, 0.0013, 0.0009, 0.0124, 0.0019, 0.0014,
        0.0031, 0.0034, 0.0014, 0.0026, 0.0052, 0.0044, 0.1003, 0.0022, 0.0020,
        0.0020, 0.0149, 0.0114, 0.0022, 0.0200, 0.0051, 0.0162, 0.0090, 0.0066,
        0.0070, 0.0077, 0.0038, 0.0256, 0.0022, 0.0030, 0.0308, 0.0060, 0.0206,
        0.0081])


Tests


In [None]:
from torchvision import datasets
#imagenetv2 = datasets.ImageFolder(root='/content/drive/MyDrive/Petaloso Project/Code/Datasets/imagenetv2')
#imageneta = datasets.ImageFolder(root='/content/drive/MyDrive/Petaloso Project/Code/Datasets/imagenet-a')

In [None]:
# Set the class names for imagenet-A
def classnames_imagenetA(file_path='/content/drive/MyDrive/Petaloso Project/Code/Datasets/words_imageneta.txt'):
    """Read the class names from a "<wnid> <class name>" words file.

    Args:
        file_path: path of the words file; defaults to the ImageNet-A list
            stored on the mounted Drive.

    Returns:
        List of class names in file order. Lines that contain no space after
        stripping (blank lines, or a wnid with no name) are skipped.
    """
    class_names = []
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            # Everything after the first space is the class name; names may
            # themselves contain spaces and commas.
            _wnid, separator, name = line.strip().partition(' ')
            if separator:
                class_names.append(name)
    return class_names

In [None]:
#Implementing subsets
import numpy as np
from torch.utils.data import Subset
def create_stratified_subset(dataset, num_samples_per_class=5):
    """Draw a reproducible subset with at most ``num_samples_per_class`` items per class.

    Args:
        dataset: dataset exposing ``samples``, a sequence whose items carry
            the class label at index 1.
        num_samples_per_class: number of items drawn (without replacement)
            from every class; classes with fewer items contribute all of them.

    Returns:
        ``torch.utils.data.Subset`` over the selected indices, grouped by
        class in ascending label order.
    """
    # Fix the random seeds for reproducibility (this resets the global RNGs).
    torch.manual_seed(0)
    np.random.seed(0)

    targets = np.array([s[1] for s in dataset.samples])

    indices = []
    for c in np.unique(targets):
        # Compare against the labels themselves so that label sets which are
        # not exactly 0..K-1 are handled as well.
        class_idx = np.where(targets == c)[0]
        if len(class_idx) >= num_samples_per_class:
            selected = np.random.choice(class_idx, num_samples_per_class, replace=False)
        else:
            # If a class has fewer than the desired number, take all
            selected = class_idx
        indices.extend(int(i) for i in selected)

    return Subset(dataset, indices)

# NOTE(review): `imageneta` is only assigned in a later cell of this notebook
# view; that cell has to be executed before this line - confirm the order.
subset = create_stratified_subset(imageneta, num_samples_per_class=5)

In [None]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

# ImageNet-A images, read from the mounted Drive folder.
imageneta = datasets.ImageFolder(root='/content/drive/MyDrive/Petaloso Project/Code/Datasets/imagenet-a')

def testing(dataset, model, method='CLIP', batch_size=32, num_aug=100, classnames=None):
    """Evaluate ``model`` on ``dataset`` one image at a time and report metrics.

    Args:
        dataset: dataset yielding ``(image, label)`` pairs.
        model: object exposing ``tokenize_labels``, ``predict``, ``MEMO`` and ``TPT``.
        method: one of ``'CLIP'``, ``'MEMO'``, ``'MEMO_CONF'`` or ``'TPT'``.
        batch_size: number of samples fetched per DataLoader step.
        num_aug: number of augmentations for the adaptation methods.
        classnames: class names to tokenize; defaults to the ImageNet-A list
            returned by ``classnames_imagenetA()``.

    Returns:
        Dict with ``accuracy`` (percent), ``average_entropy``,
        ``average_confidence``, ``evaluated`` and ``failed`` counts.

    Raises:
        ValueError: if ``method`` is not a supported name.
        RuntimeError: if no sample could be evaluated.
    """
    # We choose how the test time predictions are made
    predictors = {
        'CLIP': lambda img: model.predict(img),
        'MEMO': lambda img: model.MEMO(img, num_augmentations=num_aug),
        # NOTE(review): 'MEMO_CONF' passes conf_sel=False, which presumably
        # does not enable confidence selection - confirm the intended value.
        'MEMO_CONF': lambda img: model.MEMO(img, num_augmentations=num_aug, conf_sel=False),
        'TPT': lambda img: model.TPT(img, num_augmentations=num_aug),
    }
    if method not in predictors:
        # Fail once, up front, instead of once per image.
        raise ValueError(
            f"Enter a valid method for testing. Got {method!r}; expected one of {sorted(predictors)}."
        )
    predict_one = predictors[method]

    model.tokenize_labels(classnames_imagenetA() if classnames is None else classnames)

    def custom_collate_fn(batch):
        # Keep the PIL images as a plain list; the default collate function
        # would try to stack them into a tensor.
        images = [item[0] for item in batch]
        labels = [item[1] for item in batch]
        return images, labels

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)

    correct_predictions = 0
    failed_predictions = 0
    entropies = []
    confidences = []

    for images, labels in tqdm(dataloader):
        for image, label in zip(images, labels):
            # Best effort: a failing sample is reported and skipped. All
            # values are computed before any counter is touched so the
            # statistics stay consistent when a sample fails part-way.
            try:
                prediction, probs, entropy = predict_one(image)
                is_correct = int(prediction) == int(label)
                confidence = torch.max(probs).item()
            except Exception as e:
                failed_predictions += 1
                print(f"An error occurred: {e}")
                continue
            correct_predictions += int(is_correct)
            entropies.append(entropy)
            confidences.append(confidence)

    total_predictions = len(entropies)
    if total_predictions == 0:
        raise RuntimeError(
            f"No sample could be evaluated ({failed_predictions} failed); cannot compute statistics."
        )

    accuracy = (correct_predictions / total_predictions) * 100
    average_entropy = sum(entropies) / total_predictions
    average_confidence = sum(confidences) / total_predictions
    print(f'Accuracy: {accuracy:.2f}%')
    print(f'Average entropy across all predictions: {average_entropy:.2f}')
    print(f'Average confidence across all predictions: {average_confidence:.2f}')
    return {
        'accuracy': accuracy,
        'average_entropy': average_entropy,
        'average_confidence': average_confidence,
        'evaluated': total_predictions,
        'failed': failed_predictions,
    }


# Evaluate plain zero-shot CLIP (method 'CLIP' calls model.predict) on the stratified subset.
testing(subset,clip_model, method = 'CLIP', batch_size=35)


100%|██████████| 29/29 [00:18<00:00,  1.57it/s]

Accuracy: 31.26%
Average entropy across all predictions: 7.64





CoOp

In [None]:
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer

# Shared tokenizer instance; PromptLearner uses it to count the tokens of each class name.
_tokenizer = _Tokenizer()

In [None]:
class TextEncoder(nn.Module):
    """Text tower of CLIP that consumes prompt embeddings instead of token ids.

    It reuses the transformer, positional embedding, final layer norm and
    text projection of the given CLIP model.
    """

    def __init__(self, clip_model):
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection

    def forward(self, prompts, tokenized_prompts):
        """Encode already-embedded prompts.

        Args:
            prompts: embedded prompts, laid out as [batch_size, n_ctx, width].
            tokenized_prompts: token ids of the same prompts; only used to
                locate the end-of-text position of every sequence.

        Returns:
            One projected feature vector per prompt.
        """
        hidden = prompts + self.positional_embedding
        # The transformer is fed sequence-first, so swap the first two axes
        # on the way in and swap them back on the way out.
        hidden = self.transformer(hidden.permute(1, 0, 2))
        hidden = self.ln_final(hidden.permute(1, 0, 2))

        # The end-of-text token has the highest id in each sequence, so its
        # position is the arg-max of the token ids.
        eot_positions = tokenized_prompts.argmax(dim=-1)
        row_index = torch.arange(hidden.shape[0])
        return hidden[row_index, eot_positions] @ self.text_projection

In [None]:
class PromptLearner(nn.Module):
    """Learnable prompt context (CoOp style) placed around fixed class-name embeddings.

    ``self.ctx`` holds the trainable context vectors. The start-of-sequence
    embedding and the class-name/end-of-text embeddings are stored as buffers
    and combined with the context in ``forward``.
    """

    def __init__(self, clip_model, classnames, n_ctx, ctx_init, class_token_position, csc=False):
        """Build the context vectors and the fixed prompt pieces.

        Args:
            clip_model: CLIP model providing ``token_embedding`` and ``ln_final``.
            classnames: class names; underscores are replaced by spaces.
            n_ctx: number of context tokens (overridden by the number of
                words of ``ctx_init`` when that is given).
            ctx_init: optional text used to initialise the context vectors.
            class_token_position: ``"end"``, ``"middle"`` or ``"front"``.
            csc: when true (and no ``ctx_init``), one context per class is
                learned instead of a shared one.
        """
        super().__init__()
        n_cls = len(classnames)
        ctx_dim = clip_model.ln_final.weight.shape[0]

        # Two things are produced here:
        #  * the context vectors: token embeddings of `ctx_init` when a text
        #    is given, otherwise random vectors that come from no real token;
        #  * the prompt prefix: the text the prompt starts with before
        #    training, either the meaningful `ctx_init` or one "X" per context
        #    token, which carries no meaning and only reserves the positions.
        if ctx_init:
            # Use given words to initialize context vectors
            ctx_init = ctx_init.replace("_", " ")
            n_ctx = len(ctx_init.split(" "))
            prompt = clip.tokenize(ctx_init).to(clip_model.token_embedding.weight.device)
            with torch.no_grad():
                embedding = clip_model.token_embedding(prompt)
            ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
            prompt_prefix = ctx_init
        else:
            if csc:
                print("Initializing class-specific contexts")
                ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim)
            else:
                print("Initializing a generic context")
                ctx_vectors = torch.empty(n_ctx, ctx_dim)

            torch.nn.init.normal_(ctx_vectors, std=0.02)
            prompt_prefix = " ".join(["X"] * n_ctx)

        print(f"Initial context: '{prompt_prefix}'")
        print(f"Number of context words (tokens): {n_ctx}")

        # These are the `prompts` we want to optimize
        self.ctx = nn.Parameter(ctx_vectors)

        classnames = [name.replace("_", " ") for name in classnames]
        name_lens = [len(_tokenizer.encode(name)) for name in classnames]
        prompts = [prompt_prefix + " " + name + "." for name in classnames]

        # The full prompt text (prefix + class name) is tokenized and embedded
        # only to obtain the fixed pieces around the context: the SOS
        # embedding and the class-name/EOS embeddings. The context embeddings
        # of this pass are discarded; the learnable ones live in `self.ctx`.
        tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(clip_model.token_embedding.weight.device)

        with torch.no_grad():
            embedding = clip_model.token_embedding(tokenized_prompts)

        # These token vectors will be saved when in save_model(),
        # but they should be ignored in load_model() as we want to use
        # those computed using the current class names.
        # Buffers are not parameters: no gradient is computed for them, so
        # they stay constant during training.
        self.register_buffer("token_prefix", embedding[:, :1, :])  # SOS
        self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :])  # CLS, EOS

        self.n_cls = n_cls
        self.n_ctx = n_ctx
        self.tokenized_prompts = tokenized_prompts
        self.name_lens = name_lens
        self.class_token_position = class_token_position

    def forward(self):
        """Assemble the prompt embeddings for every class.

        Returns:
            Tensor obtained by concatenating, along dim 1, the SOS prefix,
            the context and the class-name/EOS suffix in the order selected
            by ``class_token_position``.

        Raises:
            ValueError: if ``class_token_position`` is not one of
                ``"end"``, ``"middle"`` or ``"front"``.
        """
        prefix = self.token_prefix
        suffix = self.token_suffix
        ctx = self.ctx

        # If CoOp, expand the ctx for all classes (implying a shared context across all classes)
        if ctx.dim() == 2:
            ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)

        # Place the class-name tokens where requested. The class-name
        # embeddings are not a separate tensor: they are the first `name_len`
        # positions of `suffix`, followed by the EOS part.
        if self.class_token_position == "end":
            prompts = torch.cat(
                [
                    prefix,  # (n_cls, 1, dim)
                    ctx,     # (n_cls, n_ctx, dim)
                    suffix,  # (n_cls, *, dim)
                ],
                dim=1,
            )

        elif self.class_token_position == "middle":
            half_n_ctx = self.n_ctx // 2
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
                ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
                prompt = torch.cat(
                    [
                        prefix_i,     # (1, 1, dim)
                        ctx_i_half1,  # (1, n_ctx//2, dim)
                        class_i,      # (1, name_len, dim)
                        ctx_i_half2,  # (1, n_ctx//2, dim)
                        suffix_i,     # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        elif self.class_token_position == "front":
            prompts = []
            for i in range(self.n_cls):
                name_len = self.name_lens[i]
                prefix_i = prefix[i : i + 1, :, :]
                class_i = suffix[i : i + 1, :name_len, :]
                suffix_i = suffix[i : i + 1, name_len:, :]
                ctx_i = ctx[i : i + 1, :, :]
                prompt = torch.cat(
                    [
                        prefix_i,  # (1, 1, dim)
                        class_i,   # (1, name_len, dim)
                        ctx_i,     # (1, n_ctx, dim)
                        suffix_i,  # (1, *, dim)
                    ],
                    dim=1,
                )
                prompts.append(prompt)
            prompts = torch.cat(prompts, dim=0)

        else:
            raise ValueError(
                f"Unknown class_token_position {self.class_token_position!r}; "
                "expected 'end', 'middle' or 'front'."
            )

        return prompts

In [None]:
# NOTE(review): `main_coop` is not defined in the visible part of this
# notebook - confirm it is defined before this cell is executed.
main_coop()

In [None]:
class CoCoOp:
    def __init__(self, CLIP_model, device=None):
        self.device = device if device else "cuda" if torch.cuda.is_available() else "cpu"