# Text embeddings pre-compute

In [1]:
import torch
from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel, CLIPVisionModelWithProjection

import json
import os
from IN_id_to_classname import IMAGENET2012_CLASSES

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
sd_path = "/scratch/choi/model/stable-diffusion-v1-5"
text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
# tokenizer2 = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
# text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
clip_model = CLIPModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
# clip_vision_model = CLIPVisionModelWithProjection.from_pretrained(
#     "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
# )

In [3]:
# data = json.load(
#     open("/scratch/choi/dataset/ImageNet100/_img_text_pair_train.json")
# )  # list of dict: [{"image_file": "1.png", "text": "A dog"}]

In [16]:
text_list = []
label_list = os.listdir("/scratch/choi/dataset/ImageNet100/train")
for label in label_list:
    text_list.append("a photo of " + IMAGENET2012_CLASSES[label])
label_list.append("")
text_list.append("")

In [17]:
len(label_list)

101

In [6]:
text_tokens = tokenizer(text_list, padding="max_length", truncation=True, return_tensors="pt").input_ids

In [7]:
text_tokens.shape

torch.Size([101, 77])

## CLIP-L (SD)

In [8]:
with torch.no_grad():
    out = text_encoder(text_tokens)[0]

In [9]:
out.shape

torch.Size([101, 77, 768])

In [10]:
a = dict(zip(label_list, out))

In [11]:
a["n02114855"].shape

torch.Size([77, 768])

In [16]:
torch.save(a, "IN100_text_embedding_dict_L.pt")

## CLIP-H (with projection)

In [12]:
with torch.no_grad():
    out2 = clip_model.text_projection(clip_model.text_model(text_tokens)[1])

In [13]:
out2.shape

torch.Size([101, 1024])

In [14]:
b = dict(zip(label_list, out2))

In [15]:
b["n02114855"].shape

torch.Size([1024])

In [21]:
torch.save(b, "IN100_text_embedding_dict_with_projection_H.pt")

## Zero-shot classification

In [18]:
import torchvision.datasets as datasets
from torchvision import transforms
from IN_id_to_classname import IMAGENET2012_CLASSES

transform_train = transforms.Compose(
    [
        # RandomResizedCrop(224, interpolation=3),
        # transforms.RandomHorizontalFlip(),
        transforms.Resize(256, interpolation=3),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

dataset_val = datasets.ImageFolder("/scratch/choi/dataset/ImageNet100/val", transform=transform_train)
classnames = [IMAGENET2012_CLASSES[id] for id in dataset_val.classes]
prefix = "a photo of a "
text_inputs = [prefix + name for name in classnames]

In [19]:
from transformers import AutoTokenizer, CLIPTextModelWithProjection

model = CLIPTextModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")

inputs = tokenizer(text_inputs, padding=True, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)
text_embeds = outputs.text_embeds

In [20]:
text_embeds.shape

torch.Size([100, 1024])

In [None]:
torch.save(text_embeds, "IN100_classnames_text_features.pt")

In [None]:
a = torch.load("IN100_classnames_text_features.pt")

In [None]:
text_embeds = a
text_embeds.shape

# Test

In [66]:
import torch

In [67]:
a = torch.load("IN100_text_embedding_dict_L.pt")

In [68]:
a["n02109047"].shape

torch.Size([77, 768])

In [70]:
a[""]

tensor([[-0.3884,  0.0229, -0.0522,  ..., -0.4899, -0.3066,  0.0675],
        [-0.3711, -1.4497, -0.3401,  ...,  0.9489,  0.1867, -1.1034],
        [-0.5107, -1.4629, -0.2926,  ...,  1.0419,  0.0701, -1.0284],
        ...,
        [ 0.5006, -0.9552, -0.6610,  ...,  1.6013, -1.0622, -0.2191],
        [ 0.4988, -0.9451, -0.6656,  ...,  1.6467, -1.0858, -0.2088],
        [ 0.4923, -0.8124, -0.4912,  ...,  1.6108, -1.0174, -0.2484]],
       device='cuda:1')