In [5]:
import argparse
import os
import sys
import time
from operator import is_

# Parent directory is added to the import path; presumably this is where the
# project modules imported below (Metrics, networks, lr_scheduler) live.
sys.path.append('../.')

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

import clip

from Metrics import base_kmeans_model_evaluation, kmeans_with_init, cosine_kmeans_with_init
from networks import CustomCLIP, load_clip_to_cpu
from lr_scheduler import ConstantWarmupScheduler

In [6]:
clip_backbone = "ViT-L/14"
print(clip.available_models())

# Use the GPU when one is present; otherwise fall back to the CPU. Both branches
# produce a torch.device so downstream `.to(device)` calls see a consistent type.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', device)
if torch.cuda.is_available():
    # current_device() raises when CUDA is unavailable, so only query it on GPU machines.
    print('Current cuda device:', torch.cuda.current_device())
print('Count of using GPUs:', torch.cuda.device_count())

# load_clip_to_cpu (project helper) returns the CLIP model and its image preprocessing
# transform; the transform is reused for the CIFAR-10 datasets below.
backbone_name = clip_backbone
clip_model, preprocess = load_clip_to_cpu(backbone_name)


['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14']
Device: cuda
Current cuda device: 0
Count of using GPUs: 4


In [7]:
batch_size = 50

# CIFAR-10 test split, transformed with the CLIP preprocessing pipeline.
# shuffle=False keeps batches in dataset order.
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=preprocess,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
)

Files already downloaded and verified


In [8]:
# Build the project's CustomCLIP wrapper from the loaded CLIP backbone and the number
# of CIFAR-10 classes. Its log output shows it initializes class-specific prompt
# contexts (16 context tokens).
model = CustomCLIP(clip_model, len(testset.classes))

Initializing class-specific contexts
Initial context: "X X X X X X X X X X X X X X X X"
Number of context words (tokens): 16


In [17]:
# Turn each CIFAR-10 class name into a natural-language prompt for the CLIP text encoder.
classes_list = [f'a photo of a {class_name}' for class_name in testset.classes]

In [12]:
# Inspect the prompt strings before they are tokenized and encoded.
classes_list

['a photo of a airplane',
 'a photo of a automobile',
 'a photo of a bird',
 'a photo of a cat',
 'a photo of a deer',
 'a photo of a dog',
 'a photo of a frog',
 'a photo of a horse',
 'a photo of a ship',
 'a photo of a truck']

In [18]:
with torch.no_grad():
    # Tokenize into a separate variable. Rebinding `classes_list` to the token tensor
    # made this cell fail on re-run, because clip.tokenize expects a string or a list
    # of strings.
    text_tokens = clip.tokenize(classes_list).to(device)
    # nn.Module.to moves clip_model's parameters to `device` in place.
    centroid_candidate_text = clip_model.to(device).encode_text(text_tokens)
    # L2-normalize along the last (embedding) dimension. The normalized text embeddings
    # are passed to the k-means helpers below as their initial centroids.
    centroid_candidate_text = centroid_candidate_text / centroid_candidate_text.norm(dim=-1, keepdim=True)

In [25]:
# Run the cosine variant of k-means (from Metrics) on the test split, initialized
# from the normalized text embeddings. The third argument (presumably the number of
# clusters) is taken from the dataset instead of a hard-coded 10.
# perf_counter is monotonic and meant for measuring durations; time.time() can jump.
start = time.perf_counter()
new_label, acc, nmi = cosine_kmeans_with_init(
    model, testloader, len(testset.classes), centroid_candidate_text)
print(time.perf_counter() - start)

image_ACC 0.9535
image_NMI 0.8978035140252928
267.17397451400757


In [26]:
# Run the non-cosine k-means variant (from Metrics) on the test split with the same
# text-embedding initialization, for comparison with the cosine run above.
# perf_counter is monotonic and meant for measuring durations; time.time() can jump.
start = time.perf_counter()
new_label, acc, nmi = kmeans_with_init(
    model, testloader, len(testset.classes), centroid_candidate_text)
print(time.perf_counter() - start)

image_ACC 0.9606
image_NMI 0.9090755996653886
301.06729388237


In [27]:
# CIFAR-10 training split, with the same CLIP preprocessing and batch size as the
# test loader. shuffle=False keeps batches in dataset order.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=preprocess,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=False,
)

Files already downloaded and verified


In [28]:
# Run the cosine variant of k-means on the training split, initialized from the
# normalized text embeddings. The third argument (presumably the number of clusters)
# is taken from the dataset instead of a hard-coded 10.
# perf_counter is monotonic and meant for measuring durations; time.time() can jump.
start = time.perf_counter()
new_label, acc, nmi = cosine_kmeans_with_init(
    model, trainloader, len(trainset.classes), centroid_candidate_text)
print(time.perf_counter() - start)

image_ACC 0.9551
image_NMI 0.9010647389422483
1452.4428737163544


In [29]:
# Run the non-cosine k-means variant on the training split with the same
# text-embedding initialization, for comparison with the cosine run above.
# perf_counter is monotonic and meant for measuring durations; time.time() can jump.
start = time.perf_counter()
new_label, acc, nmi = kmeans_with_init(
    model, trainloader, len(trainset.classes), centroid_candidate_text)
print(time.perf_counter() - start)

image_ACC 0.96212
image_NMI 0.9123226975349202
1398.7406959533691
