In [None]:
import json
import pathlib
import os
import sys  
sys.path.insert(0, '/nethome/bdevnani3/flash1/long_tail_lang/')
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm.notebook import trange, tqdm
from os import listdir
from os.path import isfile, join

In [None]:
from data_loader import dataloaders, classes
from clip import clip

In [None]:
# Where the dataset lives on disk, which dataset it is, and which split this run processes.
dataset_path = "/nethome/bdevnani3/flash1/long_tail_lang/datasets/ImageNet/"
dataset = "ImageNet"
split = "val"

In [None]:
# Loader for the configured split. Batches hold a single sample; the cells below
# rely on that when they read `path[0]` and call `label.item()`.
dl = dataloaders.load_data(
    data_root=dataset_path,
    dataset=dataset,
    phase=split,
    batch_size=1,
)

In [None]:
# Initialize CLIP models 
class TextEncoder(nn.Module):
    """Text branch of a CLIP model exposed as a standalone module.

    Keeps references to the text-side components of ``clip_model`` and maps
    tokenized text to projected text features.
    """

    def __init__(self, clip_model):
        """Store the text-side submodules and tensors taken from ``clip_model``."""
        super().__init__()
        self.transformer = clip_model.transformer
        self.positional_embedding = clip_model.positional_embedding
        self.ln_final = clip_model.ln_final
        self.text_projection = clip_model.text_projection
        self.dtype = clip_model.dtype
        self.token_embedding = clip_model.token_embedding

    def forward(self, text):
        """Return one projected feature row per token sequence in ``text``.

        ``text`` holds token ids laid out as [batch_size, n_ctx].
        """
        embedded = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
        embedded = embedded + self.positional_embedding.type(self.dtype)

        # The transformer is fed sequence-first input (NLD -> LND) and its
        # output is swapped back afterwards (LND -> NLD).
        hidden = self.transformer(embedded.transpose(0, 1)).transpose(0, 1)
        hidden = self.ln_final(hidden).type(self.dtype)

        # hidden.shape = [batch_size, n_ctx, transformer.width]
        # Select the feature at the EOT position of every sequence
        # (the EOT token is the highest id in each sequence, hence argmax).
        batch_rows = torch.arange(hidden.shape[0])
        eot_positions = text.argmax(dim=-1)
        return hidden[batch_rows, eot_positions] @ self.text_projection

def load_clip_to_cpu(visual_backbone):
    """Build a CLIP model for ``visual_backbone`` with its weights mapped to the CPU.

    The checkpoint location is resolved through ``clip._MODELS`` and
    ``clip._download`` (cache directory ``~/.cache/clip``). The checkpoint is
    first read as a TorchScript archive; if that raises ``RuntimeError`` it is
    read as a plain ``torch.load`` object instead. Either way the resulting
    state dict is handed to ``clip.build_model``.
    """
    checkpoint_path = clip._download(
        clip._MODELS[visual_backbone], os.path.expanduser("~/.cache/clip")
    )

    try:
        # loading JIT archive
        jit_model = torch.jit.load(checkpoint_path, map_location="cpu").eval()
    except RuntimeError:
        weights = torch.load(checkpoint_path, map_location="cpu")
    else:
        weights = None

    # Fall back to the JIT archive's own weights when nothing was loaded directly.
    if not weights:
        weights = jit_model.state_dict()

    return clip.build_model(weights)

# One CLIP model (RN50 backbone) supplies both encoders used below.
clip_model = load_clip_to_cpu("RN50")

# Each encoder is wrapped in DataParallel and moved to the GPU.
visual_model = nn.DataParallel(clip_model.visual).cuda()
text_model = nn.DataParallel(TextEncoder(clip_model)).cuda()

Image Encodings

In [104]:
# Encode images as CLIP embeddings: Unnormalized
output_path_images = "/nethome/bdevnani3/flash1/long_tail_lang/datasets/ImageNet_balanced_emb/RN50/images/"

# Pure feature extraction: run without autograd so no graph is built per image
# and the saved tensors do not carry requires_grad.
with torch.no_grad():
    for inp, label, index, path in tqdm(dl):
        out = visual_model(inp.half()).float()

        # Mirror the image's location below the "ImageNet" directory under the
        # output root, swapping the .JPEG extension for .pt. Batches hold one
        # sample, so path[0] is the only path in the batch.
        new_path = path[0].split("ImageNet")[1]
        new_path = new_path.replace(".JPEG", ".pt")
        new_path = pathlib.Path(output_path_images + new_path)
        new_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(out, new_path)

  0%|          | 0/50000 [00:00<?, ?it/s]

Text Encodings

In [105]:
# Encode labels as CLIP embeddings: Unnormalized
output_path = "/nethome/bdevnani3/flash1/long_tail_lang/datasets/ImageNet_balanced_emb/RN50/labels"


# Save prompt paths: map each prompt's position in the ImageNet collection to
# the prompt itself (JSON stores the integer positions as string keys).
prompt_indices = dict(enumerate(classes.GENERIC_PROMPT_COLLECTIONS["ImageNet"]))

fp = pathlib.Path(output_path)
# parents=True so this cell also works when the directories above `labels`
# have not been created yet.
fp.mkdir(parents=True, exist_ok=True)
with open(output_path+'/prompt_indices.json', 'w') as f:
    json.dump(prompt_indices, f)


In [106]:
# Save embeddings: one file per (class index, prompt index) at
# <output_path>/<class index>/<prompt index>.pt
prompt_templates = classes.GENERIC_PROMPT_COLLECTIONS["ImageNet"]

# Inference only: skip autograd bookkeeping so the saved tensors are plain values.
with torch.no_grad():
    # Wrap the sequence itself (not the enumerate object) so tqdm can read its
    # length and draw a real progress bar instead of a bare iteration counter.
    for c, actual_label in enumerate(tqdm(classes.CLASSES)):
        class_dir = pathlib.Path(output_path) / str(c)
        class_dir.mkdir(parents=True, exist_ok=True)

        for i, prompt in enumerate(prompt_templates):
            text = clip.tokenize(prompt.format(actual_label))
            texts = text.cuda()
            text_embedding = text_model(texts).float()

            torch.save(text_embedding, class_dir / f"{i}.pt")
    

0it [00:00, ?it/s]

In [107]:
#save image path to label mapping
# Each line written: <image embedding path> <label index> <class name with
# spaces replaced by underscores> <prompt embedding paths for that label...>
fp = pathlib.Path(output_path+f'/{split}/image_to_label.txt')
fp.parent.mkdir(parents=True, exist_ok=True)

# The prompt-embedding files depend only on the label, so each label's
# directory is listed once and reused for every image carrying that label.
prompt_files_by_label = {}

with open(fp, 'w') as f:
    for inp, label, index, path in tqdm(dl):
        # Same path mapping as used when the image embeddings were saved.
        new_path = path[0].split("ImageNet")[1]
        new_path = new_path.replace(".JPEG", ".pt")
        new_path = pathlib.Path(output_path_images + new_path)

        actual_label = classes.CLASSES[int(label)]
        label = label.item()

        if label not in prompt_files_by_label:
            prompt_files_by_label[label] = [
                str(pt_file.absolute())
                for pt_file in pathlib.Path(f"{output_path}/{label}/").glob("*.pt")
            ]

        to_write = [str(new_path), str(label), "_".join(str(actual_label).split(" "))]
        to_write.extend(prompt_files_by_label[label])
        # The newline is joined like any other field, so every line ends in
        # " \n"; kept as-is to leave the file format unchanged.
        to_write.append("\n")

        f.write(" ".join(to_write))

  0%|          | 0/50000 [00:00<?, ?it/s]

Generating balanced ImageNet dataset

In [86]:
# Get labels

import os

train_dir = "/nethome/bdevnani3/flash1/long_tail_lang/datasets/ImageNet/train"
val_dir = "/nethome/bdevnani3/flash1/long_tail_lang/datasets/ImageNet/val"

# Class names are the immediate sub-directories of the train split, kept in
# the order os.walk reports them (not sorted).
# NOTE(review): this rebinds `classes`, which an earlier cell imports from
# data_loader; re-running earlier cells after this one needs that import re-run.
classes = next(os.walk(train_dir))[1]

# class name -> index, and the inverse index -> class name
c2i = {name: position for position, name in enumerate(classes)}
i2c = {position: name for name, position in c2i.items()}

i2c

{0: 'n07584110',
 1: 'n07695742',
 2: 'n07614500',
 3: 'n02097298',
 4: 'n04252225',
 5: 'n02093859',
 6: 'n03207941',
 7: 'n02484975',
 8: 'n04238763',
 9: 'n01978287',
 10: 'n04275548',
 11: 'n07742313',
 12: 'n02276258',
 13: 'n04090263',
 14: 'n02769748',
 15: 'n01770393',
 16: 'n01806567',
 17: 'n03476684',
 18: 'n02119022',
 19: 'n03729826',
 20: 'n02107908',
 21: 'n02280649',
 22: 'n02965783',
 23: 'n09229709',
 24: 'n03697007',
 25: 'n02137549',
 26: 'n01855032',
 27: 'n04310018',
 28: 'n03777754',
 29: 'n03983396',
 30: 'n02871525',
 31: 'n02980441',
 32: 'n03935335',
 33: 'n12267677',
 34: 'n01955084',
 35: 'n04131690',
 36: 'n04536866',
 37: 'n02398521',
 38: 'n03127925',
 39: 'n04009552',
 40: 'n03788195',
 41: 'n02823750',
 42: 'n03777568',
 43: 'n02095570',
 44: 'n01843383',
 45: 'n02504458',
 46: 'n03796401',
 47: 'n04417672',
 48: 'n03063599',
 49: 'n07802026',
 50: 'n13054560',
 51: 'n04465501',
 52: 'n02089867',
 53: 'n03595614',
 54: 'n03937543',
 55: 'n03787032',
 5

In [87]:
# Every *.JPEG file found recursively under val_dir

from pathlib import Path

result = [jpeg_path for jpeg_path in Path(val_dir).rglob("*.JPEG")]


In [88]:
# Write one "<path relative to the ImageNet root> <class index>" line per image in `result`.
with open("/nethome/bdevnani3/flash1/long_tail_lang/data/ImageNet/ImageNet_val.txt", 'w') as f:
    for path in result:
        # Use the path of the current iteration; indexing result[0] here wrote
        # the first image's path and label on every line.
        p = str(path).split("ImageNet/")[1]
        # The second path component is looked up as the class directory name
        # (assumes a <split>/<class dir>/<file> layout — confirm for new datasets).
        label = p.split("/")[1]
        i = c2i[label]
        f.write(f"{p} {i}\n")