In [1]:
from operator import is_
import os
import sys
sys.path.append('../.')
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
import numpy as np

import clip

from Metrics import base_kmeans_model_evaluation, kmeans_with_init, cosine_kmeans_with_init
from networks import CustomCLIP, load_clip_to_cpu
from lr_scheduler import ConstantWarmupScheduler

import argparse

In [2]:
clip_backbone = "ViT-L/14"
print(clip.available_models())
device = torch.device('cuda') if torch.cuda.is_available() else "cpu"
print('Device:', device)
print('Current cuda device:', torch.cuda.current_device())
print('Count of using GPUs:', torch.cuda.device_count())

backbone_name = clip_backbone
clip_model, preprocess = load_clip_to_cpu(backbone_name)


['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14']
Device: cuda
Current cuda device: 0
Count of using GPUs: 4


In [3]:
batch_size =10
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=preprocess)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False)

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                       download=True, transform=preprocess)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                         shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
model = CustomCLIP(clip_model, len(testset.classes))

Initializing class-specific contexts
Initial context: "X X X X X X X X X X X X X X X X"
Number of context words (tokens): 16


In [49]:
num_classes = 10 
with torch.no_grad():
    centroids, test_label, acc, nmi= base_kmeans_model_evaluation(
        clip_model.to(device), testloader, num_classes)

image_ACC 0.8359
image_NMI 0.8406474690785359


In [207]:
num_classes = 10 
with torch.no_grad():
    train_centroids, test_label, acc, nmi= base_kmeans_model_evaluation(
        clip_model.to(device), trainloader, num_classes)

image_ACC 0.95328
image_NMI 0.9012859197933579


In [164]:
X = centroids
X_square = X**2
X_square_sum = np.sum(X_square,axis=1)
X_train_square = X**2
X_square_train_sum = np.sum(X_train_square,axis=1)
XY = X@X.T
# dists_original = np.sqrt(X_square_sum.reshape(-1,1)+X_square_train_sum.reshape(1,-1)-2*XY)
dists = np.sqrt(X_square_sum.reshape(-1,1)+X_square_train_sum.reshape(1,-1)-2*XY+1e-12)

In [165]:
np.argmin(dists,axis=-1)

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [166]:
dists.argsort(axis=-1)[:,1]

array([9, 0, 7, 5, 8, 9, 9, 2, 4, 5])

In [167]:
k = dists[np.arange(10),dists.argsort(axis=-1)[:,1]]

In [100]:
c

array([[-0.0110123 ,  0.03005093, -0.01091487, ...,  0.00862239,
         0.00132192, -0.0172122 ],
       [ 0.00724263,  0.0220054 ,  0.00479171, ..., -0.00250016,
        -0.00948331, -0.02184404],
       [ 0.00571699,  0.03417663, -0.00728394, ...,  0.00621605,
        -0.00533025, -0.02589159],
       ...,
       [ 0.00849806,  0.04163024,  0.00323967, ...,  0.01474007,
         0.01407203, -0.02373944],
       [-0.01054857,  0.02590865, -0.01196582, ...,  0.03023189,
        -0.00372186, -0.02541586],
       [-0.01246983,  0.03131478,  0.00270114, ...,  0.00727812,
        -0.00256862, -0.01381275]])

In [5]:
classes_list = ['a photo of a '+ c for c in testset.classes]
classes_list
with torch.no_grad():
    classes_list = clip.tokenize(classes_list).to(device)
    centroid_candidate_text = clip_model.to(device).encode_text(classes_list)
    centroid_candidate_text = centroid_candidate_text / centroid_candidate_text.norm(dim=-1, keepdim=True)
for i, (x, target) in enumerate(trainloader):
    x = x.to(device)
    if i ==0:
        break
image_features = model.image_encoder(x.type(clip_model.dtype))
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
sim = image_features@centroid_candidate_text.t()

In [29]:
x = (image_features**2).cpu().detach().numpy()
np.square(np.sum(x,axis=-1))

np.random.rand(1,10)

array([[0.6172282 , 0.06005859, 0.65896441, 0.8763368 , 0.93334016,
        0.8563349 , 0.68656232, 0.79860074, 0.17883979, 0.14537133]])

In [211]:
X = image_features.detach().cpu().numpy()
X_square = X**2
X_square_sum = np.sum(X_square,axis=1)
X_train = train_centroids
X_train_square = X_train**2
X_square_train_sum = np.sum(X_train_square,axis=1)
XY = X@X_train.T
# dists_original = np.sqrt(X_square_sum.reshape(-1,1)+X_square_train_sum.reshape(1,-1)-2*XY)
dists = np.sqrt(X_square_sum.reshape(-1,1)+X_square_train_sum.reshape(1,-1)-2*XY+1e-12)

In [30]:
centroid_candidate_text

tensor([[-0.0237, -0.0017,  0.0569,  ...,  0.0322,  0.0125, -0.0069],
        [ 0.0119, -0.0088,  0.0164,  ...,  0.0056,  0.0435, -0.0258],
        [-0.0066,  0.0147,  0.0375,  ...,  0.0239,  0.0115,  0.0017],
        ...,
        [ 0.0291,  0.0064,  0.0308,  ...,  0.0106, -0.0188, -0.0157],
        [-0.0145, -0.0089,  0.0215,  ...,  0.0299,  0.0108, -0.0028],
        [-0.0124, -0.0029,  0.0378,  ..., -0.0095,  0.0779,  0.0179]],
       device='cuda:0')

In [None]:
prompts = self.prompt_learner()
tokenized_prompts = self.tokenized_prompts
text_features = self.text_encoder(prompts, tokenized_prompts)