In [4]:
import clip
from tqdm.notebook import tqdm
from sklearn.linear_model import LogisticRegression
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
clip_model, clip_preprocess = clip.load("ViT-B/32", device)

phrase = "This is a {} with petals"

cuda


In [14]:
def get_clip_image_features(data_loader):
    """Given a dataloader object, generate two torch arrays of encoded images and corresponding labels"""
    all_features = []
    all_labels = []

    global clip_model

    with torch.no_grad():
        for images, labels in tqdm(data_loader):
            print(images.shape)
            features = clip_model.encode_image(images.to(device))
            all_features.append(features)
            all_labels.append(labels)

    return torch.cat(all_features), torch.cat(all_labels)

In [6]:
# Get embeddings from the linear probe
def get_clip_linear_probe_classifier(
    train_features, train_labels, C=1
):

    classifier = LogisticRegression(C=C, max_iter=1000, n_jobs=6)
    classifier.fit(train_features.cpu().numpy(), train_labels.cpu().numpy())

    return classifier

def get_clip_linear_probe_embedding(classifier, imgs):
    return torch.from_numpy(classifier.predict_proba(imgs.cpu().detach().numpy())).to(torch.float16).cuda()

In [7]:
# Get embeddings for 

def get_clip_text_features(classes):
    """Given a dataloader object, generate two torch arrays of encoded images and corresponding labels"""
    # Assumes the positions are in accordance to the label numbers
    embedding_per_class = {}
    
    global clip_model

    global phrase

    with torch.no_grad():
        for i,_class in enumerate(classes):
            _class = _class.replace("_", " ")
            text = clip.tokenize(phrase.format(_class)).cuda() 
            class_embeddings = clip_model.encode_text(
                    text
                )
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            embedding_per_class[i] = class_embeddings
    return embedding_per_class

def get_text_embeds(classes):
    text_clip_features = get_clip_text_features(classes)
    text_embeds = []
    for c in range(len(classes)):
        text_embs = text_clip_features[c].squeeze()
        text_embeds.append(text_embs)

    text_embeds = torch.stack(text_embeds).squeeze(1)
    return text_embeds

In [8]:
import torch.nn as nn
import torch.optim as optim

class FinetuneLambda(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.tensor([1.0]).cuda())
        self.b = torch.nn.Parameter(torch.tensor([1.0]).cuda())
        self.sftmx = torch.nn.Softmax(dim=1)
        
    def forward(self, image_image, image_text):
        out = (self.a*(image_image) + (self.b)*self.sftmx((image_text)*100))/(self.a + self.b)
        return out[0]
    
    def string(self):
        return f'A: {self.a.item()}, B: {self.b.item(),}, Lambda: {self.a.item()/(self.a.item()+self.b.item()) }'
    
model = FinetuneLambda()
model.logit_scale = nn.Parameter(torch.ones([], device=device))
criterion = nn.CrossEntropyLoss()
learning_rate = 0.005
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=1e-5)
num_epochs = 500

def num_correct_preds(outputs, labels):
    predicted = outputs.argmax().item()
    labels = labels.item()
    return predicted == labels

In [15]:
from datasets import *
dataset_obj = Flowers102(4, 1)
num_elements_per_class = 4

train_loader, _, _, _ = dataset_obj.get_train_loaders(transform_fn=clip_preprocess,num_elements_per_class=num_elements_per_class)
test_loader = dataset_obj.get_test_loader(transform_fn=clip_preprocess)
train_features, train_labels = get_clip_image_features(train_loader)
classes = dataset_obj.classes

  0%|          | 0/408 [00:00<?, ?it/s]

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1,

torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])
torch.Size([1,

In [19]:
lpc = get_clip_linear_probe_classifier(train_features, train_labels)
text_embeddings_per_class = get_text_embeds(classes)

In [20]:
import random

for epoch in tqdm(range(num_epochs+1)):
    
    model.train()
    running_loss = 0.0
    correct = 0.0
    total = 0
    

    for inputs, labels in zip(train_features, train_labels):
        
        inputs = inputs.unsqueeze(0)
        
        inputs, labels = inputs.to(device), labels.to(device)
        
        image_only = get_clip_linear_probe_embedding(lpc, inputs.repeat(len(classes), 1))
        
        inputs_norm = inputs/inputs.norm(dim=-1, keepdim=True)
        image_text =inputs_norm.repeat(len(classes), 1) @ text_embeddings_per_class.T
        
        
        optimizer.zero_grad()

        outputs = model(image_only, image_text)
        
        loss = criterion(outputs.unsqueeze(0), labels.unsqueeze(0))

        loss.backward()

        optimizer.step()
        running_loss+=loss
        total += 1
        correct += num_correct_preds(outputs, labels)
        
    if epoch%50 == 0:

        model.eval()
        eval_total = 0
        eval_correct = 0
        for images, target in tqdm(test_loader):
            images = images.cuda()
            target = target.cuda()
            image_features = clip_model.encode_image(images)
            image_features_norm = image_features/image_features.norm(dim=-1, keepdim=True)

            image_only = get_clip_linear_probe_embedding(lpc, image_features.repeat(len(classes), 1))
            image_text = inputs_norm.repeat(len(classes), 1) @ text_embeddings_per_class.T

            outputs = model(image_only, image_text)

            eval_correct += num_correct_preds(outputs, target)
            eval_total+=1
        print("----------------")
        print("Accuracy: ", (eval_correct)*100/eval_total, "%" )
        print("----------------")
        
        
    epoch_loss = running_loss/len(train_loader)
    epoch_accuracy = correct*100/total
    print(
        f"Training: Epoch {epoch} || Loss: {epoch_loss:7.3f} || {model.string()}"
    )



  0%|          | 0/501 [00:00<?, ?it/s]

  0%|          | 0/819 [00:00<?, ?it/s]

----------------
Accuracy:  3.4188034188034186 %
----------------
Training: Epoch 0 || Loss:   3.997 || A: 1.0234801769256592, B: (0.9759701490402222,) Lambda: 0.5118807722473641
Training: Epoch 1 || Loss:   3.995 || A: 1.0464024543762207, B: (0.9513680934906006,) Lambda: 0.5237851040969385
Training: Epoch 2 || Loss:   3.992 || A: 1.0687774419784546, B: (0.926176905632019,) Lambda: 0.5357403006533058
Training: Epoch 3 || Loss:   3.990 || A: 1.0906144380569458, B: (0.9003782272338867,) Lambda: 0.5477742118641282
Training: Epoch 4 || Loss:   3.988 || A: 1.1119189262390137, B: (0.8739504218101501,) Lambda: 0.5599154482802794
Training: Epoch 5 || Loss:   3.986 || A: 1.1326944828033447, B: (0.8468690514564514,) Lambda: 0.5721940534870911
Training: Epoch 6 || Loss:   3.983 || A: 1.1529446840286255, B: (0.8191070556640625,) Lambda: 0.5846422083267923
Training: Epoch 7 || Loss:   3.981 || A: 1.1726678609848022, B: (0.790632963180542,) Lambda: 0.5972940298047994
Training: Epoch 8 || Loss:   3.9

  0%|          | 0/819 [00:00<?, ?it/s]

----------------
Accuracy:  32.84493284493284 %
----------------
Training: Epoch 50 || Loss:   3.951 || A: 4.383269309997559, B: (1.4483345746994019,) Lambda: 0.7516404400339917
Training: Epoch 51 || Loss:   3.951 || A: 4.387263298034668, B: (1.4361188411712646,) Lambda: 0.7533874977047116
Training: Epoch 52 || Loss:   3.951 || A: 4.391234397888184, B: (1.4238578081130981,) Lambda: 0.7551444142805421
Training: Epoch 53 || Loss:   3.950 || A: 4.395180702209473, B: (1.4115514755249023,) Lambda: 0.7569112140323216
Training: Epoch 54 || Loss:   3.950 || A: 4.39910364151001, B: (1.3991988897323608,) Lambda: 0.7586881881044999
Training: Epoch 55 || Loss:   3.950 || A: 4.403006553649902, B: (1.3867998123168945,) Lambda: 0.7604756144404627
Training: Epoch 56 || Loss:   3.949 || A: 4.406883239746094, B: (1.3743536472320557,) Lambda: 0.7622734245109908
Training: Epoch 57 || Loss:   3.949 || A: 4.410735607147217, B: (1.3618602752685547,) Lambda: 0.7640818267883616
Training: Epoch 58 || Loss:   3.

  0%|          | 0/819 [00:00<?, ?it/s]

----------------
Accuracy:  57.63125763125763 %
----------------
Training: Epoch 100 || Loss:   3.932 || A: 4.55027961730957, B: (0.7715627551078796,) Lambda: 0.8550196151793581
Training: Epoch 101 || Loss:   3.932 || A: 4.552806854248047, B: (0.7563631534576416,) Lambda: 0.857536460056871
Training: Epoch 102 || Loss:   3.931 || A: 4.555294990539551, B: (0.7410824298858643,) Lambda: 0.8600774886193177
Training: Epoch 103 || Loss:   3.931 || A: 4.557746887207031, B: (0.7257205247879028,) Lambda: 0.8626431341016099
Training: Epoch 104 || Loss:   3.930 || A: 4.560157775878906, B: (0.7102754712104797,) Lambda: 0.8652339498649885
Training: Epoch 105 || Loss:   3.930 || A: 4.562526702880859, B: (0.694746196269989,) Lambda: 0.8678504598111686
Training: Epoch 106 || Loss:   3.929 || A: 4.564853191375732, B: (0.6791316866874695,) Lambda: 0.8704932026923964
Training: Epoch 107 || Loss:   3.929 || A: 4.567140102386475, B: (0.6634306907653809,) Lambda: 0.8731628502889245
Training: Epoch 108 || Los

  0%|          | 0/819 [00:00<?, ?it/s]

----------------
Accuracy:  64.59096459096459 %
----------------
Training: Epoch 150 || Loss:   3.901 || A: 4.612587928771973, B: (-0.11938578635454178,) Lambda: 1.0265703127904033
Training: Epoch 151 || Loss:   3.900 || A: 4.611958980560303, B: (-0.1409592181444168,) Lambda: 1.031527449258519
Training: Epoch 152 || Loss:   3.899 || A: 4.611219882965088, B: (-0.16274376213550568,) Lambda: 1.0365841599943568
Training: Epoch 153 || Loss:   3.898 || A: 4.610368728637695, B: (-0.18474513292312622,) Lambda: 1.0417444296668201
Training: Epoch 154 || Loss:   3.897 || A: 4.609404563903809, B: (-0.20696939527988434,) Lambda: 1.0470124799917444
Training: Epoch 155 || Loss:   3.896 || A: 4.608321666717529, B: (-0.22942295670509338,) Lambda: 1.0523928439314008
Training: Epoch 156 || Loss:   3.895 || A: 4.607112407684326, B: (-0.2521124482154846,) Lambda: 1.0578903445606997
Training: Epoch 157 || Loss:   3.894 || A: 4.605778694152832, B: (-0.2750447988510132,) Lambda: 1.0635099744062766
Training: E

  0%|          | 0/819 [00:00<?, ?it/s]

----------------
Accuracy:  64.34676434676435 %
----------------
Training: Epoch 200 || Loss:   3.783 || A: 4.270881652832031, B: (-1.7440961599349976,) Lambda: 1.6902430637019925
Training: Epoch 201 || Loss:   3.775 || A: 4.244297504425049, B: (-1.8078278303146362,) Lambda: 1.7419865921272746
Training: Epoch 202 || Loss:   3.765 || A: 4.214570045471191, B: (-1.8761000633239746,) Lambda: 1.802276735492372
Training: Epoch 203 || Loss:   3.753 || A: 4.1809401512146, B: (-1.94991934299469,) Lambda: 1.874003207773976
Training: Epoch 204 || Loss:   3.739 || A: 4.142318248748779, B: (-2.030712842941284,) Lambda: 1.9616914397719696
Training: Epoch 205 || Loss:   3.721 || A: 4.097052097320557, B: (-2.1206154823303223,) Lambda: 2.0729488951209296
Training: Epoch 206 || Loss:   3.698 || A: 4.042426109313965, B: (-2.223069906234741,) Lambda: 2.221899209441362
Training: Epoch 207 || Loss:   3.666 || A: 3.9734480381011963, B: (-2.3442792892456055,) Lambda: 2.4389419701871544
Training: Epoch 208 || 

  0%|          | 0/819 [00:00<?, ?it/s]

----------------
Accuracy:  31.257631257631257 %
----------------
Training: Epoch 250 || Loss:   3.953 || A: 10.283077239990234, B: (3.528285503387451,) Lambda: 0.7445374820034175
Training: Epoch 251 || Loss:   3.952 || A: 10.284773826599121, B: (3.5231692790985107,) Lambda: 0.7448447424696637
Training: Epoch 252 || Loss:   3.952 || A: 10.286467552185059, B: (3.5180487632751465,) Lambda: 0.7451523339984647
Training: Epoch 253 || Loss:   3.952 || A: 10.288159370422363, B: (3.5129239559173584,) Lambda: 0.7454602749037206
Training: Epoch 254 || Loss:   3.952 || A: 10.289850234985352, B: (3.5077967643737793,) Lambda: 0.7457684803404009
Training: Epoch 255 || Loss:   3.952 || A: 10.291542053222656, B: (3.5026652812957764,) Lambda: 0.7460770890016453
Training: Epoch 256 || Loss:   3.952 || A: 10.293233871459961, B: (3.4975311756134033,) Lambda: 0.7463859935489482
Training: Epoch 257 || Loss:   3.952 || A: 10.294923782348633, B: (3.4923927783966064,) Lambda: 0.7466952497239365
Training: Epoch

  0%|          | 0/819 [00:00<?, ?it/s]

----------------
Accuracy:  34.065934065934066 %
----------------
Training: Epoch 300 || Loss:   3.950 || A: 10.365784645080566, B: (3.2682511806488037,) Lambda: 0.7602873263336196
Training: Epoch 301 || Loss:   3.950 || A: 10.367389678955078, B: (3.2629618644714355,) Lambda: 0.7606105863025184
Training: Epoch 302 || Loss:   3.949 || A: 10.368993759155273, B: (3.257668972015381,) Lambda: 0.760934204046634
Training: Epoch 303 || Loss:   3.949 || A: 10.370596885681152, B: (3.2523715496063232,) Lambda: 0.7612582334712213
Training: Epoch 304 || Loss:   3.949 || A: 10.372199058532715, B: (3.2470712661743164,) Lambda: 0.7615825819769706
Training: Epoch 305 || Loss:   3.949 || A: 10.373796463012695, B: (3.241767168045044,) Lambda: 0.7619072367558534
Training: Epoch 306 || Loss:   3.949 || A: 10.375391006469727, B: (3.236459255218506,) Lambda: 0.7622322319892242
Training: Epoch 307 || Loss:   3.949 || A: 10.376985549926758, B: (3.231147527694702,) Lambda: 0.7625576183548414
Training: Epoch 308

  0%|          | 0/819 [00:00<?, ?it/s]

----------------
Accuracy:  37.60683760683761 %
----------------
Training: Epoch 350 || Loss:   3.947 || A: 10.443876266479492, B: (2.9992144107818604,) Lambda: 0.7768954712285803
Training: Epoch 351 || Loss:   3.946 || A: 10.445387840270996, B: (2.9937362670898438,) Lambda: 0.7772372482630677
Training: Epoch 352 || Loss:   3.946 || A: 10.44689655303955, B: (2.9882540702819824,) Lambda: 0.7775794143241838
Training: Epoch 353 || Loss:   3.946 || A: 10.448404312133789, B: (2.9827675819396973,) Lambda: 0.7779220156317227
Training: Epoch 354 || Loss:   3.946 || A: 10.449911117553711, B: (2.9772775173187256,) Lambda: 0.7782650115165369
Training: Epoch 355 || Loss:   3.946 || A: 10.45141887664795, B: (2.97178316116333,) Lambda: 0.778608475623608
Training: Epoch 356 || Loss:   3.946 || A: 10.452921867370605, B: (2.966284990310669,) Lambda: 0.778952286691017
Training: Epoch 357 || Loss:   3.946 || A: 10.454425811767578, B: (2.9607818126678467,) Lambda: 0.7792966090755937
Training: Epoch 358 ||

  0%|          | 0/819 [00:00<?, ?it/s]

----------------
Accuracy:  42.857142857142854 %
----------------
Training: Epoch 400 || Loss:   3.943 || A: 10.517131805419922, B: (2.7202954292297363,) Lambda: 0.7944996878162834
Training: Epoch 401 || Loss:   3.943 || A: 10.518545150756836, B: (2.7146103382110596,) Lambda: 0.7948629606540819
Training: Epoch 402 || Loss:   3.943 || A: 10.5199556350708, B: (2.7089195251464844,) Lambda: 0.7952267677834832
Training: Epoch 403 || Loss:   3.943 || A: 10.521363258361816, B: (2.703223943710327,) Lambda: 0.7955910530585966
Training: Epoch 404 || Loss:   3.943 || A: 10.522767066955566, B: (2.6975247859954834,) Lambda: 0.7959557310837023
Training: Epoch 405 || Loss:   3.943 || A: 10.524166107177734, B: (2.6918208599090576,) Lambda: 0.7963208599847449
Training: Epoch 406 || Loss:   3.943 || A: 10.525562286376953, B: (2.686115026473999,) Lambda: 0.7966862978206996
Training: Epoch 407 || Loss:   3.943 || A: 10.526955604553223, B: (2.6804039478302,) Lambda: 0.7970522467265995
Training: Epoch 408 |

  0%|          | 0/819 [00:00<?, ?it/s]

----------------
Accuracy:  48.35164835164835 %
----------------
Training: Epoch 450 || Loss:   3.940 || A: 10.5850248336792, B: (2.4304964542388916,) Lambda: 0.8132616896032395
Training: Epoch 451 || Loss:   3.940 || A: 10.586326599121094, B: (2.4245810508728027,) Lambda: 0.8136501221823721
Training: Epoch 452 || Loss:   3.940 || A: 10.587626457214355, B: (2.418661594390869,) Lambda: 0.8140390567397621
Training: Epoch 453 || Loss:   3.940 || A: 10.588926315307617, B: (2.412738084793091,) Lambda: 0.8144285215687922
Training: Epoch 454 || Loss:   3.940 || A: 10.590221405029297, B: (2.4068076610565186,) Lambda: 0.8148186290252444
Training: Epoch 455 || Loss:   3.939 || A: 10.591513633728027, B: (2.400871992111206,) Lambda: 0.8152093032601837
Training: Epoch 456 || Loss:   3.939 || A: 10.592804908752441, B: (2.394932508468628,) Lambda: 0.8156004828606197
Training: Epoch 457 || Loss:   3.939 || A: 10.594094276428223, B: (2.3889880180358887,) Lambda: 0.8159922302075728
Training: Epoch 458 |

  0%|          | 0/819 [00:00<?, ?it/s]

----------------
Accuracy:  54.09035409035409 %
----------------
Training: Epoch 500 || Loss:   3.936 || A: 10.647069931030273, B: (2.1285979747772217,) Lambda: 0.8333865602588484
