In [1]:
from torch import nn
import torch
import numpy as np
import pickle
from torch import optim

In [2]:
def load_pickle(fname):
  """Read one pickled object from the file at ``fname`` and return it.

  The file is opened in binary mode and unpickled with ``encoding="latin1"``
  (needed under Python 3 for the downloaded data files).
  """
  # NOTE: unpickling can execute arbitrary code; only load trusted files.
  with open(fname, "rb") as handle:
    payload = pickle.load(handle, encoding="latin1")
  return payload

In [3]:
class AttentionProbe(nn.Module):
    """Scores every (word, candidate head) pair from attention maps.

    For each ordered word pair the concatenated word embeddings are mapped by
    a linear layer to one weight per attention map (each head in both
    directions). The pair's score is the weighted sum of its attention values,
    and a log-softmax over the candidates (ROOT plus every word) feeds an
    NLL loss.
    """

    def __init__(self, glove_path="./data/glove/embeddings.pkl", num_attention_heads=144, non_trainable=False):
        """Build the probe.

        Args:
            glove_path: path of the pickled embedding matrix; it must expose
                a 2-D ``.shape`` of ``(vocab_length, vocab_dim)``.
            num_attention_heads: total number of attention maps per example
                (layers * heads); the linear layer emits twice this many
                weights because every map is also used transposed.
            non_trainable: when True the embedding weights are frozen.
        """
        super(AttentionProbe, self).__init__()

        weights_matrix = self.load_glove(glove_path)
        self.vocab_length, self.vocab_dim = weights_matrix.shape
        self.embeddings = nn.Embedding(self.vocab_length, self.vocab_dim)
        self.embeddings.load_state_dict({'weight': torch.tensor(weights_matrix)})
        if non_trainable:
            self.embeddings.weight.requires_grad = False

        self.weight_layer = nn.Linear(self.vocab_dim + self.vocab_dim, num_attention_heads*2)

        self.softmax = nn.LogSoftmax(dim=1)
        self.loss = nn.NLLLoss()

    def load_glove(self, glove_path):
        """Return the embedding matrix unpickled from ``glove_path``."""
        weights_matrix = load_pickle(glove_path)
        return (weights_matrix)

    def forward(self, tokens, labels, attns):
        """Compute the head-selection loss and the raw pair scores.

        Args:
            tokens: 1-D tensor of ``n_words`` token ids.
            labels: tensor of ``n_words`` target indices in ``[0, n_words]``;
                index 0 is the ROOT column, index ``j + 1`` is word ``j``.
            attns: 4-D tensor ``(layers, heads, awords, awords)``. The
                arithmetic below requires ``awords == n_words + 2`` and
                ``layers * heads == num_attention_heads``. The first and last
                positions are presumably special tokens -- TODO confirm.

        Returns:
            ``[loss, scores]`` where ``scores`` has shape
            ``(n_words, n_words + 1)``.
        """
        n_words = len(tokens)
        tokens = self.embeddings(tokens)  # (n_words, vocab_dim)
        pair_dim = self.vocab_dim + self.vocab_dim

        # tokens_pairs[i, j] = concat(embedding[i], embedding[j])
        tokens_pairs = torch.cat(
            (tokens.unsqueeze(1).expand(n_words, n_words, self.vocab_dim),
             tokens.unsqueeze(0).expand(n_words, n_words, self.vocab_dim)),
            dim=-1)

        # Dummy all-zero pair for ROOT, prepended as candidate 0. Width,
        # dtype and device are taken from the data instead of being
        # hard-coded, so any embedding size and CPU execution both work.
        root_pad = tokens_pairs.new_zeros((n_words, 1, pair_dim))
        tokens_pairs = torch.cat((root_pad, tokens_pairs), 1)

        # One weight per attention map for every (word, candidate) pair.
        tokens_h = self.weight_layer(tokens_pairs)

        alayers, aheads, awords, _ = attns.shape
        # Stack every map with its transpose so both attention directions
        # are available, then flatten layers and heads into one axis.
        attns = torch.cat((attns, attns.transpose(3, 2)), 0).view(alayers*aheads*2, awords, awords)
        # Drop the first/last rows; fold the first and last columns into a
        # single ROOT column placed in front of the word columns.
        attns = torch.cat(((attns[:,1:-1, 0] + attns[:,1:-1, -1]).unsqueeze(-1), attns[:,1:-1, 1:-1]),-1)
        # (maps, word, candidate) -> (word, candidate, maps)
        attns = attns.transpose(2,1).transpose(0,2)

        attns_tokens = attns * tokens_h
        attns_sum = attns_tokens.sum(-1)
        loss = self.loss(self.softmax(attns_sum), labels)
        outputs = [loss, attns_sum]
        return outputs

In [5]:
class GloveTokenizer():
    """Maps words to integer ids through a pickled vocabulary lookup."""

    def __init__(self, vocab_path="./data/glove/vocab.pkl"):
        """Load the vocabulary (an object with a dict-style ``.get``).

        Args:
            vocab_path: location of the pickled vocabulary; defaults to the
                path this notebook has always used.
        """
        self.vocab = load_pickle(vocab_path)

    def get_token_ids(self, words):
        """Return a 1-D ``torch.long`` tensor with one id per word.

        Words are lower-cased before lookup; words missing from the
        vocabulary map to id 0.
        """
        token_ids = [self.vocab.get(word.lower(), 0) for word in words]
        # Explicit dtype: an empty list would otherwise become a float
        # tensor, which cannot be used as embedding indices.
        return torch.tensor(token_ids, dtype=torch.long)

# Module-level instances, both built with their default data paths; the
# train() and evaluate() functions read these globals directly.
tokenizer = GloveTokenizer()
model = AttentionProbe()

In [6]:
# Load the pickled examples and split them in stored order (no shuffling):
# the first 80% form the training set, the remainder the dev set.
attention_data = load_pickle("./data/ud/ud_attention_data.pkl")
train_samples = round(0.8*len(attention_data))  # number of training examples
train_data = attention_data[:train_samples]
dev_data = attention_data[train_samples:]

In [7]:
def train(example, optimizer):
    """Run one optimization step on a single example and return its loss.

    Reads the module-level ``model`` and ``tokenizer``.

    Args:
        example: mapping with keys ``"attns"`` (tensor), ``"heads"`` (target
            indices) and ``"words"`` (list of strings).
        optimizer: optimizer over ``model``'s parameters.

    Returns:
        The loss of this step as a Python float.
    """
    # Follow whatever device the model currently lives on instead of
    # assuming CUDA.
    device = next(model.parameters()).device
    attns = example["attns"].to(device)
    # reshape(-1) flattens like squeeze() did, but keeps a 1-D tensor even
    # when there is a single label.
    labels = torch.tensor(example["heads"]).reshape(-1).to(device)
    input_tokens = tokenizer.get_token_ids(example["words"]).to(device)

    # Clear gradients before the backward pass so nothing left over from
    # earlier work can leak into this step.
    model.zero_grad()
    outputs = model(input_tokens, labels, attns)
    loss = outputs[0]
    loss.backward()
    optimizer.step()
    return loss.item()

In [8]:
# Per-step training losses, in the order the steps were taken.
losses = []

learning_rate = 0.002

# Move the model to the GPU *before* constructing the optimizer, as the
# torch.optim documentation recommends, so the optimizer is guaranteed to
# hold the parameters that are actually trained.
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

model.train()
for epoch in range(2):
    for i, example in enumerate(train_data):
        if i % 100 == 0:
            print("{:}/{:}".format(i, len(train_data)))
        # Examples with at most one word are skipped.
        if(len(example["words"]) <= 1):
            continue
        loss = train(example, optimizer)
        losses.append(loss)

0/3369
100/3369
200/3369
300/3369
400/3369
500/3369
600/3369
700/3369
800/3369
900/3369
1000/3369
1100/3369
1200/3369
1300/3369
1400/3369
1500/3369
1600/3369
1700/3369
1800/3369
1900/3369
2000/3369
2100/3369
2200/3369
2300/3369
2400/3369
2500/3369
2600/3369
2700/3369
2800/3369
2900/3369
3000/3369
3100/3369
3200/3369
3300/3369
0/3369
100/3369
200/3369
300/3369
400/3369
500/3369
600/3369
700/3369
800/3369
900/3369
1000/3369
1100/3369
1200/3369
1300/3369
1400/3369
1500/3369
1600/3369
1700/3369
1800/3369
1900/3369
2000/3369
2100/3369
2200/3369
2300/3369
2400/3369
2500/3369
2600/3369
2700/3369
2800/3369
2900/3369
3000/3369
3100/3369
3200/3369
3300/3369


In [9]:
# Mean training loss per window of 1000 steps. Stepping through the list by
# window start covers every recorded loss: the previous round()-based count
# could silently drop a trailing partial window or emit an empty one (nan).
for start in range(0, len(losses), 1000):
    print(np.nanmean(losses[start:start + 1000]))

0.9759140166319048
0.7724338806722153
0.9837104285235737
0.5255733057877713
0.3910336213554157
0.562712003280832


In [10]:
def evaluate(example):
    """Predict one candidate index per word of ``example``.

    Reads the module-level ``model`` and ``tokenizer``. The model is switched
    to eval mode and is left in that mode afterwards.

    Args:
        example: mapping with keys ``"attns"`` (tensor), ``"heads"`` and
            ``"words"``. The labels are only passed because the model's
            forward signature requires them; they do not affect the
            predictions.

    Returns:
        1-D CPU integer tensor holding, for each word, the argmax over the
        model's score row.
    """
    model.eval()
    # Follow the model's device instead of assuming CUDA.
    device = next(model.parameters()).device
    attns = example["attns"].to(device)
    labels = torch.tensor(example["heads"]).reshape(-1).to(device)
    input_tokens = tokenizer.get_token_ids(example["words"]).to(device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(input_tokens, labels, attns)
    attns_sum = outputs[1]
    # Native torch argmax; callers convert with .numpy() themselves.
    preds = attns_sum.argmax(dim=1).cpu()
    return(preds)

In [11]:
# Smoke test: run evaluate() on a single training example. Note that
# evaluate() also switches the model to eval mode.
example = train_data[6]
preds = evaluate(example)

In [12]:
# Unlabeled attachment score (UAS) on the dev split.
model.cuda()
print("Evaluating...")
correct, total = 0, 0
for i, example in enumerate(dev_data):

    if i % 100 == 0:
        print("{:}/{:}".format(i, len(dev_data)))
    # Examples with at most one word are skipped, as in training.
    if len(example["words"]) <= 1:
        continue
    preds = evaluate(example)

    for head, prediction, reln in zip(example["heads"], preds.numpy(), example["relns"]):
        # it is standard to ignore punct for Stanford Dependency evaluation
        if reln != "punct":
            if head == prediction:
                correct += 1
            total += 1

# Guard against an empty (or all-punctuation) dev set instead of raising
# ZeroDivisionError.
uas = 100 * correct / total if total else float("nan")
print("UAS: {:.1f}".format(uas))

Evaluating...
0/842
100/842
200/842
300/842
400/842
500/842
600/842
700/842
800/842
UAS: 75.2
