In [2]:
from transformers import RobertaTokenizer, RobertaForMaskedLM

import os
import torch
import tqdm
import numpy as np

# Load the RoBERTa tokenizer and masked-LM checkpoint from the local
# './pretrained/roberta' directory. output_hidden_states=True makes each
# forward pass expose `hidden_states`, which the encoding loop below averages.
tokenizer = RobertaTokenizer.from_pretrained('./pretrained/roberta')
# Module.to() returns the module itself, so loading and the device move can be
# chained into a single expression.
model = RobertaForMaskedLM.from_pretrained(
    './pretrained/roberta',
    output_hidden_states=True,
).to("cuda:0")

# Caption selection: for every caption file under txt_root/<class>/, pick one
# caption and remember where its features should be written.
txt_root = "./data/text_c10"
img_root = "./data/CUB_200_2011/images"
# Rank of the caption to keep from each file when captions are ordered by word
# count (1 = longest, 3 = third longest).
CAPTION_RANK = 3

sentences = []          # chosen caption text, one per caption file
sentence_lengths = []   # character length of each chosen caption
saved_paths = []        # output .npy path paired with each caption

# Sorted listings make the processing order reproducible across machines.
for bird in tqdm.tqdm(sorted(os.listdir(txt_root))):
    bird_root = os.path.join(txt_root, bird)
    # Only class directories contain captions; skip any stray top-level file
    # (not just *.txt ones) instead of crashing in os.listdir below.
    if not os.path.isdir(bird_root):
        continue
    for filename in sorted(os.listdir(bird_root)):
        if ".txt" not in filename:
            continue
        with open(os.path.join(bird_root, filename)) as f:
            # Blank lines are not captions and must never be selected.
            captions = [line for line in f.readlines() if line.strip()]
        if not captions:
            # Empty caption file: nothing to encode for this image.
            continue

        # Word count per caption, ignoring single-character tokens.
        lengths = [
            len([word for word in caption.split(" ") if len(word) != 1])
            for caption in captions
        ]
        # Clamp the rank so files with fewer than CAPTION_RANK captions fall
        # back to their shortest caption instead of raising IndexError.
        rank = min(CAPTION_RANK, len(captions))
        input_txt = captions[lengths.index(sorted(lengths)[-rank])]
        # NOTE(review): input_txt keeps its trailing newline, which the
        # tokenizer may encode as an extra token. Confirm before stripping it,
        # since that would change the saved feature length.

        sentences.append(input_txt)
        sentence_lengths.append(len(input_txt))
        # Features are stored next to the images as "<caption file name>.npy".
        saved_path = os.path.join(img_root, bird, filename) + ".npy"
        saved_paths.append(saved_path)

# Tokenize all captions at once; padding=True pads every caption to the length
# of the longest one in the whole set.
tokens = tokenizer(sentences, return_tensors="pt", padding=True).to("cuda:0")

# Encode each caption with RoBERTa and save the average of its last four
# hidden-state layers to the matching .npy path.
tokens_lengths = []  # sequence length of the encoder output for each caption

# from_pretrained normally returns the model in eval mode; set it explicitly so
# dropout cannot leak in if another cell switched the model to train mode.
model.eval()

# Pure feature extraction: disabling autograd avoids building a graph for every
# sentence, which saves GPU memory and time.
with torch.no_grad():
    for saved_path, input_id, mask in tqdm.tqdm(
        zip(saved_paths, tokens.input_ids, tokens.attention_mask),
        total=len(saved_paths),
    ):
        # The encoder expects a batch dimension; process one caption at a time.
        input_id = input_id.unsqueeze(0)
        mask = mask.unsqueeze(0)
        outputs = model.roberta(input_id, attention_mask=mask)
        tokens_lengths.append(outputs.hidden_states[-1].shape[1])

        # Remove a stale "<name>.pt" file left next to the target, if present.
        try:
            os.remove(saved_path.replace(".npy", ".pt"))
        except FileNotFoundError:
            pass

        # Mean of the last four hidden-state layers (same summation order as
        # before so the saved values are unchanged).
        hidden_state = (
            outputs.hidden_states[-1]
            + outputs.hidden_states[-2]
            + outputs.hidden_states[-3]
            + outputs.hidden_states[-4]
        ) / 4

        np.save(saved_path, hidden_state.cpu().numpy())

100%|██████████| 202/202 [00:00<00:00, 401.44it/s]
100%|██████████| 11788/11788 [01:35<00:00, 123.76it/s]


In [3]:
# Sanity check on the last caption processed by the loop above: how many
# hidden-state tensors the encoder returned.
# NOTE(review): presumably the embedding output plus one tensor per
# transformer layer — confirm against the model config.
len(outputs.hidden_states)

13

In [4]:
# Shape of the final hidden-state tensor for the last caption processed.
# NOTE(review): presumably (batch, padded sequence length, hidden size) — confirm.
outputs.hidden_states[-1].shape

torch.Size([1, 49, 768])