In [1]:
# Imports: BERT tokenizer/model (transformers), PyTorch, and standard utilities
from transformers import BertTokenizer, BertModel
import torch
import os
import pickle
import numpy as np

In [2]:
from tqdm.auto import tqdm, trange

In [3]:
# Input: the UD 2.6 English-EWT treebank folder and the CoNLL-U file read from it.
UD_ENG_DIR = "./UD2.6/ud-treebanks-v2.6/UD_English-EWT"
UD_ENG_Training_File = "en_ewt-ud-train.conllu"

# Output: folder that receives the pickled results.
OUT_DIR = "./data/ud"

### Load BERT Model

In [4]:
# The tokenizer and the encoder are loaded from the same pretrained checkpoint,
# so the name is spelled once and shared by both calls.
BERT_CHECKPOINT = 'bert-base-cased'
tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)
model = BertModel.from_pretrained(BERT_CHECKPOINT)

In [5]:
# Smoke test: push one short sentence through the model to inspect what it returns.
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")

# Inference only: disable autograd so no computation graph is kept alive for
# these outputs (the cells below only read shapes).
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True, output_hidden_states=True)

# Sliced rather than indexed on purpose: the result is a 1-element sequence
# holding the last-hidden-state tensor, which is what the next cell expects
# (it prints both len(...) and the shape of element 0).
last_hidden_states = outputs[:1]

In [6]:
# last_hidden_state is a torch.FloatTensor of shape
# (batch_size, sequence_length, hidden_size); `last_hidden_states` wraps it
# in a 1-element sequence.
final_layer = last_hidden_states[0]
print("Last Hidden State:", len(last_hidden_states), ", Shape:", final_layer.shape)

Last Hidden State: 1 , Shape: torch.Size([1, 8, 768])


In [7]:
# pooler_output, shape (batch_size, hidden_size); used for the NSP objective.
pooled = outputs[1]
print("CLS Token output:", pooled.shape)

CLS Token output: torch.Size([1, 768])


In [8]:
# hidden_states: one tensor per entry, each of shape
# (batch_size, sequence_length, hidden_size).
per_layer_states = outputs[2]
print("Hidden states at each layer:", len(per_layer_states), "x", per_layer_states[0].shape)

Hidden states at each layer: 13 x torch.Size([1, 8, 768])


In [9]:
# attentions: one tensor per entry, each of shape
# (batch_size, num_heads, sequence_length, sequence_length).
per_layer_attns = outputs[3]
print("Attentions:", len(per_layer_attns), "x", per_layer_attns[0].shape)

Attentions: 12 x torch.Size([1, 12, 8, 8])


### Feed UD Treebank Examples

In [None]:
# Prefer the GPU, but fall back to the CPU so this cell also runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()


def _build_record(words, heads, rels):
    """Run BERT on one treebank sentence and package its attention maps.

    Args:
        words: word forms of the sentence, in order.
        heads: head indices as strings, one per word.
        rels: dependency relation labels, one per word.

    Returns:
        A dict with keys "words", "heads" (integer numpy array), "relns" and
        "attns" (CPU tensor: per-layer attentions stacked, batch axis removed),
        or None when the sentence is skipped. A sentence is skipped when
        * a head value is not an integer (presumably the "_" placeholder that
          CoNLL-U uses on some rows; TODO confirm), or
        * the tokenized length differs from len(words) + 2, i.e. the input
          did not map to exactly one token per word plus two special tokens.
    """
    try:
        head_ids = np.array(list(map(int, heads)))
    except ValueError:
        # Deliberate best-effort: drop sentences whose heads cannot be parsed.
        return None

    encoded = tokenizer(" ".join(words), return_tensors="pt")
    encoded = {
        key: (value.to(device) if isinstance(value, torch.Tensor) else value)
        for key, value in encoded.items()
    }

    # Inference only: without no_grad() every forward pass would keep an
    # autograd graph alive on the device until the outputs are released.
    with torch.no_grad():
        outputs = model(**encoded, output_attentions=True, output_hidden_states=True)

    # outputs[3] holds one attention tensor per layer; stack them and drop the
    # size-1 batch axis, then move the result off the device.
    attns = torch.stack(outputs[3]).squeeze(1).detach().cpu()

    if attns.shape[-1] != len(words) + 2:
        return None

    return {"words": words, "heads": head_ids, "relns": rels, "attns": attns}


words, heads, rels = [], [], []
dev_data = []
# The treebank is UTF-8 text; be explicit rather than relying on the platform default.
with open(os.path.join(UD_ENG_DIR, UD_ENG_Training_File), 'r', encoding='utf-8') as fp:
    for line in tqdm(fp):
        # A "# text" comment opens a new sentence, so whatever has been
        # collected so far belongs to the previous sentence: emit it first.
        if line.startswith("# text"):
            if words:
                sentence = _build_record(words, heads, rels)
                if sentence is not None:
                    dev_data.append(sentence)
            words, heads, rels = [], [], []

        # Token rows are tab-separated with at least 9 fields; shorter lines
        # (comments, blank separators) are ignored. Fields 1, 6 and 7 are
        # stored as the word, its head index and its relation label.
        fields = line.split("\t")
        if len(fields) > 8:
            words.append(fields[1])
            heads.append(fields[6])
            rels.append(fields[7])

# The last sentence has no following "# text" line to trigger its emission,
# so flush it explicitly; otherwise it would be silently dropped.
if words:
    sentence = _build_record(words, heads, rels)
    if sentence is not None:
        dev_data.append(sentence)


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

In [11]:
# The extraction loop is finished: ask PyTorch to release cached GPU memory
# it is no longer using (see torch.cuda.empty_cache in the PyTorch docs).
torch.cuda.empty_cache()

In [12]:
# open(..., 'wb') does not create missing folders, so make sure OUT_DIR exists.
os.makedirs(OUT_DIR, exist_ok=True)

# Write through a context manager so the file is flushed and closed even if
# pickling fails part-way (the bare open() previously leaked the handle).
with open(os.path.join(OUT_DIR, "ud_attention_data.pkl"), 'wb') as out_fp:
    pickle.dump(dev_data, out_fp)