In [1]:
import torch
import torchvision

print(torch.__version__)
print(torchvision.__version__)

1.7.1
0.8.2


In [2]:
import clip
from tqdm.notebook import tqdm
from sklearn.linear_model import LogisticRegression
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
clip_model, clip_preprocess = clip.load("ViT-B/32", device)

phrase = "This is a photo of a {}."

cuda


In [3]:
def get_clip_image_features(data_loader):
    """Given a dataloader object, generate two torch arrays of encoded images and corresponding labels"""
    all_features = []
    all_labels = []

    global clip_model

    with torch.no_grad():
        for images, labels in tqdm(data_loader):
            features = clip_model.encode_image(images.to(device))
            all_features.append(features)
            all_labels.append(labels)

    return torch.cat(all_features), torch.cat(all_labels)

In [4]:
# Get embeddings from the linear probe
def get_clip_linear_probe_classifier(
    train_features, train_labels, C=1
):

    classifier = LogisticRegression(C=C, max_iter=1000, n_jobs=6)
    classifier.fit(train_features.cpu().numpy(), train_labels.cpu().numpy())

    return classifier

def get_clip_linear_probe_embedding(classifier, imgs):
    return torch.from_numpy(classifier.predict_proba(imgs.cpu().detach().numpy())).to(torch.float16).cuda()

In [5]:
# Get embeddings for 

def get_clip_text_features(classes):
    """Given a dataloader object, generate two torch arrays of encoded images and corresponding labels"""
    # Assumes the positions are in accordance to the label numbers
    embedding_per_class = {}
    
    global clip_model

    global phrase

    with torch.no_grad():
        for i,_class in enumerate(classes):
            _class = _class.replace("_", " ")
            text = clip.tokenize(phrase.format(_class)).cuda() 
            class_embeddings = clip_model.encode_text(
                    text
                )
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            embedding_per_class[i] = class_embeddings
    return embedding_per_class

def get_text_embeds(classes):
    text_clip_features = get_clip_text_features(classes)
    text_embeds = []
    for c in range(len(classes)):
        text_embs = text_clip_features[c].squeeze()
        text_embeds.append(text_embs)

    text_embeds = torch.stack(text_embeds).squeeze(1)
    return text_embeds

In [6]:
import torch.nn as nn
import torch.optim as optim

class FinetuneLambda(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.tensor([1.0]).cuda())
        self.b = torch.nn.Parameter(torch.tensor([0.0]).cuda())
        self.sftmx = torch.nn.Softmax(dim=1)
        
    def forward(self, image_image, image_text):
        out = (self.a*(image_image) + (self.b)*self.sftmx((image_text)*100))/(self.a+ self.b)
        return out[0]
    
    def string(self):
        return f'A: {self.a.item()}, B: {self.b.item()}'
    
model = FinetuneLambda()
model.logit_scale = nn.Parameter(torch.ones([], device=device))
criterion = nn.CrossEntropyLoss()
learning_rate = 0.005
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=1e-5)
num_epochs = 500

def num_correct_preds(outputs, labels):
    predicted = outputs.argmax().item()
    labels = labels.item()
    return predicted == labels

In [17]:
from .. import datasets
from datasets import *
dataset_obj = Cifar100(4, 1)
n_classes = 32

train_loader, _ = dataset_obj.get_train_loaders(transform_fn=clip_preprocess,num_elements_per_class=n_classes)
test_loader = dataset_obj.get_test_loader(transform_fn=clip_preprocess)
train_features, train_labels = get_clip_image_features(train_loader)
classes = dataset_obj.classes

Files already downloaded and verified
50000
HERE
3200
Files already downloaded and verified
Files already downloaded and verified


  0%|          | 0/3200 [00:00<?, ?it/s]

In [18]:
lpc = get_clip_linear_probe_classifier(train_features, train_labels)
text_embeddings_per_class = get_text_embeds(classes)

In [None]:
import random

for epoch in tqdm(range(num_epochs+1)):
    
    model.train()
    running_loss = 0.0
    correct = 0.0
    total = 0
    

    for inputs, labels in zip(train_features, train_labels):
        
        inputs = inputs.unsqueeze(0)
        
        inputs, labels = inputs.to(device), labels.to(device)
        
        image_only = get_clip_linear_probe_embedding(lpc, inputs.repeat(len(classes), 1))
        
        inputs_norm = inputs/inputs.norm(dim=-1, keepdim=True)
        image_text =inputs_norm.repeat(len(classes), 1) @ text_embeddings_per_class.T
        
        
        optimizer.zero_grad()

        outputs = model(image_only, image_text)
        
        loss = criterion(outputs.unsqueeze(0), labels.unsqueeze(0))

        loss.backward()

        optimizer.step()
        running_loss+=loss
        total += 1
        correct += num_correct_preds(outputs, labels)
        
    if epoch%250 == 0:

        model.eval()
        eval_total = 0
        eval_correct = 0
        for images, target in tqdm(test_loader):
            images = images.cuda()
            target = target.cuda()
            image_features = clip_model.encode_image(images)
            image_features_norm = image_features/image_features.norm(dim=-1, keepdim=True)

            image_only = get_clip_linear_probe_embedding(lpc, image_features.repeat(len(classes), 1))
            image_text = inputs_norm.repeat(len(classes), 1) @ text_embeddings_per_class.T

            outputs = model(image_only, image_text)

            eval_correct += num_correct_preds(outputs, target)
            eval_total+=1
        print("----------------")
        print("Accuracy: ", (eval_correct)*100/eval_total, "%" )
        print("----------------")
        
        
    epoch_loss = running_loss/len(train_loader)
    epoch_accuracy = correct*100/total
    print(
        f"Training: Epoch {epoch} || Loss: {epoch_loss:7.3f} || {model.string()}"
    )



  0%|          | 0/501 [00:00<?, ?it/s]

  0%|          | 0/10000 [00:00<?, ?it/s]

----------------
Accuracy:  71.1 %
----------------
Training: Epoch 0 || Loss:   0.010 || A: 14.014772415161133, B: 4.008363246917725
Training: Epoch 1 || Loss:   0.009 || A: 14.064529418945312, B: 3.993220329284668
Training: Epoch 2 || Loss:   0.009 || A: 14.11314868927002, B: 3.9781816005706787
Training: Epoch 3 || Loss:   0.009 || A: 14.160688400268555, B: 3.9633941650390625
Training: Epoch 4 || Loss:   0.009 || A: 14.207194328308105, B: 3.948850631713867
Training: Epoch 5 || Loss:   0.009 || A: 14.252708435058594, B: 3.934544324874878
Training: Epoch 6 || Loss:   0.009 || A: 14.297274589538574, B: 3.920464277267456
Training: Epoch 7 || Loss:   0.008 || A: 14.340943336486816, B: 3.9066009521484375
Training: Epoch 8 || Loss:   0.008 || A: 14.383710861206055, B: 3.892948627471924
Training: Epoch 9 || Loss:   0.008 || A: 14.42564868927002, B: 3.8795013427734375
Training: Epoch 10 || Loss:   0.008 || A: 14.466766357421875, B: 3.8662497997283936
