In [45]:
#!pip install einops



In [46]:
import os
import pickle
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tokenizers import Tokenizer

In [47]:
from src.models.natural_language_processing.nlp_backbones import GPTSmall, GPTBase
from src.models.computer_vision.backbones.vit import ViTBaseOver16at112, ViTBaseOver32at224, ViTSmallOver16at112, ViTMicroOver14at112
from src.models.CLIP_model import CLIPModule
from src.utils import load_from_checkpoint

In [48]:
# Locations of the ImageNet CSV file and image folder, relative to the working directory.
# NOTE(review): neither name is referenced again in the visible cells — confirm they are still needed.
dataset_file = "data/imagenet/imagenet.csv"
image_path = "data/imagenet/images"

In [49]:
def unpickle(file):
    """Read ``file`` in binary mode and return the object pickled inside it.

    NOTE: ``pickle.load`` can execute arbitrary code while deserializing, so
    only point this at files from a trusted source.
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle)

def load_dataset(data_folder, filename, img_size=64):
    """Load a pickled image batch and return it in channels-first layout.

    The pickled object is indexed with the keys ``'data'`` and ``'labels'``.
    Each row of ``'data'`` is treated as three consecutive planes of
    ``img_size * img_size`` values, which are regrouped into an array of shape
    ``(N, 3, img_size, img_size)``.

    Args:
        data_folder: Directory that contains the pickled batch.
        filename: Name of the pickled file inside ``data_folder``.
        img_size: Side length of the square images.

    Returns:
        Tuple ``(x, y)``: the reshaped image array and the labels exactly as
        stored in the file.
    """
    batch = unpickle(os.path.join(data_folder, filename))
    x = batch['data']
    y = batch['labels']

    plane = img_size * img_size

    # Slice each flat row into its three planes and stack them on a new last axis.
    planes = (x[:, :plane], x[:, plane:2 * plane], x[:, 2 * plane:])
    stacked = np.dstack(planes)

    # (N, H*W, 3) -> (N, H, W, 3) -> (N, 3, H, W)
    images = stacked.reshape((len(stacked), img_size, img_size, 3))
    return images.transpose(0, 3, 1, 2), y

def load_class_mapping(filename, num_classes=1000):
    """Read a whitespace-separated class file into an ``{index: label}`` dict.

    Every non-blank line must contain exactly three whitespace-separated
    fields; the first is ignored, the second is parsed as the integer class
    index and the third is stored as the label.

    Args:
        filename: Path of the mapping file.
        num_classes: Indices ``1..num_classes`` are pre-populated with ``None``
            so that classes missing from the file are still present as keys.

    Returns:
        Dict mapping integer class index to label string (or ``None`` when the
        file has no entry for that index).

    Raises:
        ValueError: If a non-blank line does not have exactly three fields or
            its index field is not an integer.
    """
    class_map = {num: None for num in range(1, num_classes + 1)}

    with open(filename, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank line (e.g. a trailing newline at the end of the file).
                continue
            _, idx, label = fields
            class_map[int(idx)] = label

    return class_map

In [85]:
def load_clip_backbone(image_encoder, text_encoder, device):
    """Build an untrained ``CLIPModule`` from an image and a text encoder key.

    Args:
        image_encoder: One of ``"B/32@224"``, ``"B/16@112"``, ``"S/16@112"``
            or ``"M/14@112"``; selects the ViT class to instantiate.
        text_encoder: ``"S"`` (``GPTSmall``) or ``"B"`` (``GPTBase``).
        device: Passed to ``.to(...)`` on both encoders and on the assembled
            module.

    Returns:
        The assembled ``CLIPModule``, moved to ``device``.

    Raises:
        ValueError: If either key is not recognised. Previously an unknown key
            silently produced a ``CLIPModule`` wrapping ``None``.
    """
    # Resolve both encoder classes before instantiating anything, so a typo in
    # either key fails fast instead of after building a large model.
    if image_encoder == "B/32@224":
        image_cls = ViTBaseOver32at224
    elif image_encoder == "B/16@112":
        image_cls = ViTBaseOver16at112
    elif image_encoder == "S/16@112":
        image_cls = ViTSmallOver16at112
    elif image_encoder == "M/14@112":
        image_cls = ViTMicroOver14at112
    else:
        raise ValueError(f"Unknown image_encoder: {image_encoder!r}")

    if text_encoder == "S":
        text_cls = GPTSmall
    elif text_encoder == "B":
        text_cls = GPTBase
    else:
        raise ValueError(f"Unknown text_encoder: {text_encoder!r}")

    image_model = image_cls(dim_out=512).to(device)
    text_model = text_cls(dim_out=768, vocab_size=43001, max_length=34, batch_size=18).to(device)

    # dim_img / dim_text mirror the dim_out values given to the encoders above.
    clip_model = CLIPModule(
        image_encoder=image_model,
        text_encoder=text_model,
        dim_img=512,
        dim_text=768,
        embedding_dim=512,
        temperature=0.07,
    ).to(device)

    return clip_model


def tokenize(tokenizer, query, max_length):
    """Encode ``query`` as a list of ``max_length`` token ids.

    Layout of the result: ``[SOS]``, then at most ``max_length - 2`` ids from
    the query, then ``[EOS]``, then zeros as padding up to ``max_length``.

    Args:
        tokenizer: Object exposing ``encode(text).ids`` and ``token_to_id(str)``.
        query: Text to encode.
        max_length: Total length of the returned list.

    Returns:
        List of token ids.
    """
    # Keep room for the two special tokens.
    body = tokenizer.encode(query).ids[:max_length - 2]

    # Close the sentence.
    body.append(tokenizer.token_to_id('[EOS]'))

    # One slot is still reserved for [SOS]; fill the remainder with id 0.
    padding = [0] * (max_length - len(body) - 1)

    return [tokenizer.token_to_id('[SOS]')] + body + padding

In [80]:
def load_clip(clip_model):
    """Build the named CLIP variant on CPU and load its checkpoint.

    Args:
        clip_model: One of ``"ViT-Base/32 @ 224px"``, ``"ViT-Base/16 @ 112px"``,
            ``"ViT-Small/16 @ 112px"`` or ``"ViT-Micro/14 @ 112px"``.

    Returns:
        The model returned by ``load_clip_backbone`` after the checkpoint call.

    Raises:
        ValueError: If ``clip_model`` is not one of the names above
            (previously the function silently returned ``None``).
    """
    checkpointsdir = "src/models/checkpoints"

    # display name -> (image encoder key, text encoder key, checkpoint prefix)
    # NOTE(review): three different architectures share the prefix
    # "CLIP_epoch_2_" in the same directory — confirm this is intended.
    variants = {
        "ViT-Base/32 @ 224px": ("B/32@224", "B", "CLIP_epoch_4_"),
        "ViT-Base/16 @ 112px": ("B/16@112", "B", "CLIP_epoch_2_"),
        "ViT-Small/16 @ 112px": ("S/16@112", "B", "CLIP_epoch_2_"),
        "ViT-Micro/14 @ 112px": ("M/14@112", "S", "CLIP_epoch_2_"),
    }

    if clip_model not in variants:
        raise ValueError(
            f"Unknown clip_model: {clip_model!r}; expected one of {sorted(variants)}"
        )
    image_key, text_key, checkpoint_prefix = variants[clip_model]

    clip = load_clip_backbone(image_encoder=image_key, text_encoder=text_key, device=torch.device('cpu'))

    # The directory variable was previously misspelled (`checkpointdir`), which
    # raised NameError as soon as the path was built.
    # NOTE(review): the notebook's recorded output shows
    # "'CLIPModule' object has no attribute 'load_from_checkpoint'". The file
    # also imports `load_from_checkpoint` from src.utils, which may be the
    # intended entry point — its signature is not visible here, so the call is
    # left as written; confirm and switch if appropriate.
    clip.load_from_checkpoint(os.path.join(checkpointsdir, checkpoint_prefix))
    return clip
    

# Load data

In [52]:
# Root folder of the ImageNet data (plain string: the f-prefix had no placeholders).
datadir = "data/imagenet/"
tokenizer_file = "src/data/nlp/tokenizers/CLIP-bpe.tokenizer.json"

imdir = os.path.join(datadir, "images")
# Despite the "dir" suffix this is the path of the class-mapping text file.
clabelsdir = os.path.join(datadir, "map_clsloc.txt")

In [53]:
# Read the pickled 'val_data' batch from datadir: x is the channels-first image
# array, y the labels as stored in the file.
x, y = load_dataset(datadir, 'val_data', img_size=64)
# Integer class index -> label string, parsed from the mapping text file.
class_map = load_class_mapping(clabelsdir)

In [54]:
# Wrap the NumPy image array as a torch tensor (from_numpy shares memory with x
# instead of copying it).
X_val = torch.from_numpy(x)
print("X_val shape", X_val.shape)

X_val shape torch.Size([50000, 3, 64, 64])


## Load Model

In [88]:
# Untrained Base/32 @ 224 model, built only to count its parameters.
clip = load_clip_backbone(image_encoder="B/32@224", text_encoder="B", device=torch.device('cpu'))
# numel() stays an exact int; np.prod(shape) yields float 1.0 for scalar parameters.
sum(param.numel() for param in clip.parameters())

211438848

In [92]:
# Untrained Small/16 @ 112 model, built only to count its parameters.
clip = load_clip_backbone(image_encoder="S/16@112", text_encoder="B", device=torch.device('cpu'))
# numel() stays an exact int; np.prod(shape) yields float 1.0 for scalar parameters.
sum(param.numel() for param in clip.parameters())

127588864

In [93]:
# Ratio of the two parameter counts printed by the cells above
# (S/16@112 build divided by B/32@224 build); the numbers are copied by hand
# from those outputs, so update them if either model changes.
127588864 / 211438848

0.6034315132099093

In [55]:
# Build the Base/32 @ 224 variant and load its checkpoint.
clip = load_clip("ViT-Base/32 @ 224px")
# parameters() is presumably the torch.nn.Module iterator, which has no len();
# materialise it first to count the parameter tensors.
len(list(clip.parameters()))

AttributeError: 'CLIPModule' object has no attribute 'load_from_checkpoint'

## Load Tokenizer

In [None]:
# Load the tokenizer definition from tokenizer_file.
# NOTE(review): this cell has no execution count and the notebook's last cell
# recorded "NameError: name 'tokenizer' is not defined" — it must be run before
# the zero-shot cells.
tokenizer = Tokenizer.from_file(tokenizer_file)

## Zero-shot Classification

Templates

In [35]:
# Prompt templates: every photo description is produced once with the article
# "a" and once with "the"; the remaining "{}" is later filled with a class name.
templates = [
    pattern.format(article)
    for article in ('a', 'the')
    for pattern in (
        'a photo of {} {{}}.',
        'a blurry photo of {} {{}}.',
        'a black and white photo of {} {{}}.',
        'a low contrast photo of {} {{}}.',
        'a high contrast photo of {} {{}}.',
        'a bad photo of {} {{}}.',
        'a good photo of {} {{}}.',
        'a photo of {} small {{}}.',
        'a photo of {} big {{}}.',
    )
]

### Compute zero-shot weights for classification

In [42]:
def encode_text(x):
    """Fill template ``x`` with the global ``key`` and tokenize it to 34 ids.

    NOTE(review): ``key`` and ``tokenizer`` are looked up in the notebook's
    global scope at call time, so both must be bound before this is called.
    """
    return tokenize(tokenizer, x.format(key), 34)

In [43]:
# One row per class, in class_map iteration order; 768 matches the dim_out given
# to the text encoders in load_clip_backbone.
zero_shot_weights = torch.zeros(len(class_map), 768)

# Inference only: without no_grad the in-place row writes would attach every
# forward pass to zero_shot_weights' autograd history.
with torch.no_grad():
    for i, label in enumerate(class_map.values()):
        # Fill the templates with the class *label*; formatting with the integer
        # dict key produced prompts such as "a photo of a 1.".
        prompts = [template.format(label) for template in templates]
        # 34 is the max_length the text encoders were constructed with.
        class_tokens = torch.tensor(
            [tokenize(tokenizer, prompt, 34) for prompt in prompts],
            dtype=torch.long,
        )
        # Average over the template axis so one 768-wide vector remains.
        # assumes text_encoder returns (num_templates, 768) — TODO confirm
        zero_shot_weights[i, :] = clip.text_encoder(class_tokens).mean(dim=0)
print(zero_shot_weights.shape)

NameError: name 'tokenizer' is not defined