In [25]:
from dataloaders import get_dataloader
import clip
import torch
import numpy as np
from tqdm import tqdm

In [8]:
model, preprocess = clip.load('ViT-B/16')

100%|███████████████████████████████████████| 335M/335M [01:14<00:00, 4.68MiB/s]


In [9]:
test = get_dataloader("geolocation_kaggle", preprocess, loader_type="test")

 Dataset: GEOLOCATION_KAGGLE.
 Transformation test: Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=warn)
    CenterCrop(size=(224, 224))
    <function _convert_image_to_rgb at 0x7f92de6b5240>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)
 Transformation train: None
 Dataloader type: test.
 Test images 49997




In [17]:
loader = torch.utils.data.DataLoader(test, batch_size=64)

In [18]:
batch1 = next(iter(loader))

In [23]:
print(f" Model parameters: {np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")

 Model parameters: 149,620,737


In [28]:
def read_txt(file_location):
    with open(file_location, 'r') as file:
        content = file.read(); content = str(content); content = content.split('\n', -1)
    try: content.remove("")
    except: pass
    return content

In [31]:
def get_classes_prompts():
    classes = read_txt("/data/azfarm/siddhant/Geolocalization_UCF/ VLM-GeoBench/CLIP/dataloaders/classes/GeoLocation_kaggle.txt")
    templates = read_txt("/data/azfarm/siddhant/Geolocalization_UCF/ VLM-GeoBench/CLIP/dataloaders/templates/GeoLocation_Kaggle.txt")
    return classes, templates

In [32]:
classes, templates = get_classes_prompts()

In [33]:
def zeroshot_classifier(classnames, templates, model):
    """ 
    Creating zero-shot classifier weights. This is taken form CLIP official codebase.
    Please refer to .
    """
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template.format(classname) for template in templates] # format with class
            texts = clip.tokenize(texts).cuda() # tokenize
            class_embeddings = model.encode_text(texts) # embed with text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    return zeroshot_weights

In [34]:
zeroshot_weights = zeroshot_classifier(classes, templates, model)

100%|██████████| 124/124 [00:03<00:00, 36.00it/s]


In [36]:

def accuracy(output, target, topk=(1,)):
    """
    Zero-shot prediction. This is taken form CLIP official codebase.
    Please refer to .
    """
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]

In [40]:
with torch.no_grad():
    top1, top5, n = 0., 0., 0.
    for i, (images, target) in enumerate(tqdm(loader)):
        images = images.cuda()
        target = target.cuda()
        
        # predict
        image_features = model.encode_image(images)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        logits = 100. * image_features @ zeroshot_weights

        # measure accuracy
        acc1, _ = accuracy(logits, target, topk=(1, 5))
        top1 += acc1
        n += images.size(0)
        if(i == 20):
            break

  return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
  3%|▎         | 20/782 [00:38<24:38,  1.94s/it]
