In [1]:
import datasets, torch, transformers
import numpy as np
from icecream import ic
import logging, colorama
from tqdm import tqdm
import math
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
log_format = (
    colorama.Fore.MAGENTA
    + "[%(asctime)s %(name)s %(levelname)s] "
    + colorama.Fore.WHITE
    + "%(message)s"
)

In [2]:
vinvl_path = '/home/yangliu/data/vqav2/processed_tokenized_separate_top3k'
vinvl_datasets = datasets.load_from_disk(vinvl_path)


In [3]:
vinvl_datasets

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'image_id', 'input_ids', 'labels', 'multiple_labels', 'question_type', 'tag_attention_mask', 'tag_ids', 'tags', 'token_type_ids'],
        num_rows: 443757
    })
    val: Dataset({
        features: ['attention_mask', 'image_id', 'input_ids', 'labels', 'multiple_labels', 'question_type', 'tag_attention_mask', 'tag_ids', 'tags', 'token_type_ids'],
        num_rows: 214354
    })
})

In [4]:
vinvl_datasets["train"]["tags"][0]

'broccoli container almond meat container fruit fruit slice container almond potato muffin fruit bowl almond pineapple box tomato plastic fruit'

In [5]:
# Find all tags words
tags_words = []
for tag in vinvl_datasets["train"]["tags"]:
    tags_words.extend(tag.split(" "))
tags_words = list(set(tags_words))
# Show some of the tags
ic(tags_words[:10])
ic(len(tags_words))

ic| tags_words[:10]: ['mat',
                      'bridle',
                      'watermelon',
                      'grapefruit',
                      'park',
                      'fire',
                      'column',
                      'spoke',
                      'hamburger',
                      'milk']
ic| len(tags_words): 1293


1293

In [6]:
# Find 'urn' in all tags
for tag in vinvl_datasets["train"]["tags"]:
    if "urn" in tag.split(" "):
        print(tag)
        break

pot basket bagel bowl ceiling table lid pot container countertop container person spoon pot spoon spoon bottle vase pot paper spoon kitchen pot donut basket person napkin paper lid pot urn bread basket bread cookie spoon food countertop plate lid pot ceiling light lamp lid bread pot person


In [7]:
class ObjectDict(dict):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __setattr__(self, key, value):
        self.__setitem__(key, value)
    
    def __getattr__(self, key):
        return self.__getitem__(key)

    def purge(self):
        r"""converts all sub-elements into ObjectDict
        """
        for k, v in self.items():
            if type(v) == dict:
                res = ObjectDict()
                res.update(v)
                res.purge()
                self[k] = res


In [8]:

opt = ObjectDict()
opt.model_type = "clip"
opt.answer_bs = 16
opt.save_path = '../bert-vqa/t2v_clip.pt'

answers = tags_words
device = "cpu"

import clip
with torch.no_grad():
    logging.info("Generating Answer Embeddings...")
    
    model, preprocess = clip.load("RN50", device=device)
    embedded = None
    for idx in tqdm(range(math.ceil(len(answers) / opt.answer_bs))):
        answers_chunk = answers[idx * opt.answer_bs : (idx + 1) * opt.answer_bs]
        answers_tok = clip.tokenize(
            answers_chunk,
        ).to(device)
        text_feature = model.encode_text(answers_tok)
        if embedded is None:
            embedded = text_feature.detach().cpu()  # take first token outputs
        else:
            embedded = torch.cat((embedded, text_feature.detach().cpu()), dim=0)
        # ic(embedded.shape)

embedded = {tag: embedding.numpy() for tag, embedding in zip(answers, embedded)}
# Save a tag-to-embedding mapping
logging.info("Saving Word Embeddings...")
with open(opt.save_path, "wb") as f:
    torch.save(embedded, f)


100%|██████████| 81/81 [00:49<00:00,  1.65it/s]


In [9]:
# Load saved embeddings
with open(opt.save_path, "rb") as f:
    embedded = torch.load(f)

ic(embedded["apple"].shape)


ic| embedded["apple"].shape: (1024,)


(1024,)