In [1]:
import pandas as pd
import numpy as np

from collections import namedtuple
from enum import Enum
import tqdm
import torch
import torch.nn as nn
import torch.utils.data 
import torch.optim

In [2]:
Example = namedtuple("Example", "word pos head label")

In [3]:
def load_examples(fn):
    examples = []
    ex = Example([],[],[],[])
    lines = map(lambda x: x.strip().split(), open(fn).readlines())
    for line in lines:
        if len(line) == 0:
            examples.append(ex)
            ex = Example([],[],[],[])
            continue
        ex.word.append(line[1])
        ex.pos.append(line[3])
        ex.head.append(int(line[6])-1)
        ex.label.append(line[7])

    return examples

In [4]:
train_set = load_examples("data/train.conll")
dev_set = load_examples("data/dev.conll")

In [5]:

def collect_unique(examples, idx):
    result = []
    for ex in examples:
        result.extend(ex[idx])
    return sorted(list(set(result)))

In [6]:
class Embedding(object):
    def __init__(self, tokens, dim, filename=None):
        self.tokens = sorted(tokens)
        self.token2id = {token.lower() : n for (n, token) in enumerate(tokens)}
        self.embedding_dim = dim
        self.NULL = "<NULL>"
        self.UNK = "<UNK>"
        for t in [self.NULL, self.UNK]:
            self.token2id[t] = len(self.tokens)
            self.tokens.append(t)
        self.NULL_ID = self.get_id(self.NULL) 
        self.embed = nn.Embedding(len(self.tokens), self.embedding_dim)
        self.embed.weight.data.copy_(self.get_initial(filename))
        
    def get_initial(self, filename):
        matrix = (2*np.random.rand(len(self.tokens), self.embedding_dim)-1)/100 # -0.01<->0.01
        if filename:
            loaded = 0
            lines = open(filename, "rt").readlines()
            for line in lines:
                data = line.split()
                word = data[0].lower()
                if word in self.token2id:
                    loaded += 1
                    idx = self.token2id[word]
                    vec = list(map(float, data[1:]))
                    matrix[idx, :] = vec
            print("Loaded ", loaded, " pre-trained word vectors")
        return torch.from_numpy(matrix)
    
    
    def get_id(self, token):
        if isinstance(token, int):
            return token
        if token in self.token2id:
            return self.token2id[token]
        else:
            return self.token2id[self.UNK]

In [50]:
word_embedding = Embedding(collect_unique(train_set, 0), 50, "data/en-cw.txt")
pos_embedding = Embedding(collect_unique(train_set, 1), 50)
label_embedding = Embedding(collect_unique(train_set, 3), 50 )

Loaded  27518  pre-trained word vectors


In [51]:
class TransitionType(Enum):
    SHIFT = 1
    LEFT = 2
    RIGHT = 3

Transition = namedtuple("Transition", "kind label")

In [52]:
isinstance(3, int)

True

In [53]:
class Predictor(nn.Module):
    def __init__(self, word_embeddings, label_embeddings, pos_embeddings, D_in, hidden, D_out):
        super().__init__()
        self.word_embeddings = word_embeddings.embed  
        self.label_embeddings = label_embeddings.embed
        self.pos_embeddings = pos_embeddings.embed
        self.fc1 = nn.Linear(D_in, hidden)
        nn.init.xavier_uniform_(self.fc1.weight.data)
        self.fc1_dropout = nn.Dropout()
        self.fc2 = nn.Linear(hidden, D_out)
        nn.init.xavier_uniform_(self.fc2.weight.data)
    
    def forward(self, word_ids, pos_ids, label_ids):
        W = self.word_embeddings(word_ids)
        P = self.pos_embeddings(pos_ids)
        L = self.label_embeddings(label_ids)
        #print(W.shape, P.shape, L.shape)
        W = torch.reshape(W, (W.shape[0], np.prod(W.shape[1:])))
        #W = W.view((1, -1))
        P = torch.reshape(P, (P.shape[0], np.prod(P.shape[1:])))
        L = torch.reshape(L, (L.shape[0], np.prod(L.shape[1:])))
        #print(W.shape, P.shape, L.shape)
        combined = torch.cat((W, P, L), 1)
        relu = self.fc1(combined).clamp(min=0)
        dropped = self.fc1_dropout(relu)
        return self.fc2(dropped)
    
    

In [54]:
Arc = namedtuple("Arc", "head word label")

In [55]:
class Parse(object):
    def __init__(self, sentence, W_NULL="<NULL>", P_NULL="<NULL>", L_NULL="<NULL>"):
        self.sentence = sentence
        N = len(sentence.word)
        self.buf = [N - i - 1 for i in range(N)]
        self.stack = []
        self.arcs = []
        self.W_NULL = W_NULL
        self.P_NULL = P_NULL
        self.L_NULL = L_NULL
        
    
    def transition(self, transition, set_label=False):
        if transition.kind == TransitionType.SHIFT:
            self.stack.append(self.buf.pop())
        elif transition.kind == TransitionType.LEFT:
            if set_label:
                self.sentence.label[self.stack[-2]] = transition.label
                self.sentence.head[self.stack[-2]] = self.stack[-1]
            self.arcs.append((self.stack[-1], self.stack[-2], transition.label))
            self.stack.pop(-2)
        else: # RIGHT
            if set_label:
                self.sentence.label[self.stack[-1]] = transition.label
                self.sentence.head[self.stack[-1]] = self.stack[-2]
            self.arcs.append((self.stack[-2], self.stack[-1], transition.label))
            self.stack.pop()
    
    def get_oracle(self):
        if len(self.stack) < 2:
            if len(self.buf) > 0:
                return Transition(TransitionType.SHIFT, 0)
            else:
                return None
        
        i0 = self.stack[-1]
        i1 = self.stack[-2]
        head0 = self.sentence.head[i0]
        head1 = self.sentence.head[i1]
        label0 = self.sentence.label[i0]
        label1 = self.sentence.label[i1]

        if (i1 >= 0) and (head1 == i0):
            return Transition(TransitionType.LEFT, label1)
        elif (i1 >= 0) and (head0 == i1) and \
             (not any([x for x in self.buf if self.sentence.head[x] == i0])): # don't remove i0 if we still need it
            return Transition(TransitionType.RIGHT, label0)
        elif len(self.buf) > 0:
            return Transition(TransitionType.SHIFT, 0)
        else:
            return None

    def get_cur_features(self):
        def get_lc(k):
            return sorted([arc[1] for arc in self.arcs if arc[0] == k and arc[1] < k])

        def get_rc(k):
            return sorted([arc[1] for arc in self.arcs if arc[0] == k and arc[1] > k],
                          reverse=True)
        
        
        w_features = []
        p_features = []
        l_features = []
        
        def add(source_idx, source, target, i, null):
            if i >= len(source_idx):
                target.append(null)
            else:
                target.append(source[source_idx[-(i+1)]])
        for i in range(3):
            add(self.buf, self.sentence.word, w_features, i, self.W_NULL)
            add(self.stack, self.sentence.word, w_features, i, self.W_NULL)
            add(self.buf, self.sentence.pos, p_features, i, self.P_NULL)
            add(self.stack, self.sentence.pos, p_features, i, self.P_NULL)
        
        def add2(target, source, arr, idx, null):
            if len(arr) > idx:
                target.append(source[arr[idx]])
            else:
                target.append(null)
        
        for i in range(2):
            if i < len(self.stack):
                k = self.stack[-i-1]
                lc = get_lc(k)
                rc = get_rc(k)
                llc = get_lc(lc[0]) if len(lc) > 0 else []
                rrc = get_rc(rc[0]) if len(rc) > 0 else []

                for target, source, null in [(w_features, self.sentence.word, self.W_NULL),
                                       (p_features, self.sentence.pos, self.P_NULL),
                                       (l_features, self.sentence.label, self.L_NULL)]:
                    for arr, idx in [(lc, 0), (rc, 0), (lc, 1), (rc, 1), (llc, 0), (rrc, 0)]:
                        add2(target, source, arr, idx, null)
            else:
                w_features += [self.W_NULL] * 6
                p_features += [self.P_NULL] * 6
                l_features += [self.L_NULL] * 6
        
        # n_features = 48
        return w_features, p_features, l_features
    
    def get_instances(self):
        result = []
        transition = self.get_oracle()
        while transition:
            result.append((self.get_cur_features(), transition))
            self.transition(transition)
            transition = self.get_oracle()
        return result

In [56]:
class ParseDataset(torch.utils.data.Dataset):
    def __init__(self, samples):
        self.samples = samples
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, i):
        return self.samples[i]

class Parser(object):
    def __init__(self, word_embeddings, label_embeddings, pos_embeddings):
        self.word_embeddings = word_embeddings
        self.label_embeddings = label_embeddings
        self.pos_embeddings = pos_embeddings
        self.transitions = [Transition(TransitionType.SHIFT, 0)]
        for i in range(len(label_embeddings.tokens) - 2): # to cut off <UNK> and <NULL>:
            self.transitions.append(Transition(TransitionType.LEFT, i))
            self.transitions.append(Transition(TransitionType.RIGHT, i))
        
        self.transition2id = {tran:i for (i, tran) in enumerate(self.transitions)} 
        self.model = Predictor(word_embeddings, label_embeddings, pos_embeddings, 48*50, 200, len(self.transitions))
    
    def vectorize(self, sentence):
        return Example([self.word_embeddings.get_id(t) for t in sentence.word],
                       [self.pos_embeddings.get_id(t) for t in sentence.pos],
                       sentence.head,
                       [self.label_embeddings.get_id(t) for t in sentence.label])
    
    
    def create_dataset(self, examples):
        samples = []
        for ex in tqdm.tqdm(examples):
            parse = Parse(self.vectorize(ex), self.word_embeddings.NULL_ID, 
                                              self.pos_embeddings.NULL_ID, 
                                              self.label_embeddings.NULL_ID)
            for x, y in parse.get_instances():
                samples.append({"w": x[0], "p" : x[1], "l" : x[2], "label" : self.transition2id[y]})
        
        dataset = ParseDataset(samples)
        return dataset
    
    def train(self, dataset, batch_size=10000, epochs=1, lr=1e-2):
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
        loss = nn.CrossEntropyLoss()
        optim = torch.optim.Adam(self.model.parameters(), lr=lr, weight_decay=1e-8)
        i = 0
        #print(list(map(lambda x: x.shape, self.model.parameters())))
        for _ in range(epochs):
            for samples in tqdm.tqdm(dataloader):
                i += 1
                #print(samples)
                optim.zero_grad()
                w = torch.stack(samples["w"], 1)
                p = torch.stack(samples["p"], 1)
                l = torch.stack(samples["l"], 1)
                #print(w.shape, p.shape, l.shape)
                pred = self.model(w, p, l)
                output = loss(pred, samples["label"])
                #if i % 50 == 1:
                #    print("loss: ", output.data)
                output.backward()
                optim.step()

                
    def parse_sentence(self, sentence, score=True):
        vectorized = Example([self.word_embeddings.get_id(t) for t in sentence.word],
                           [self.pos_embeddings.get_id(t) for t in sentence.pos],
                           [-1 for t in sentence.word],
                           [self.label_embeddings.NULL_ID for t in sentence.word])
        parse = Parse(vectorized, self.word_embeddings.NULL_ID, 
                                          self.pos_embeddings.NULL_ID, 
                                          self.label_embeddings.NULL_ID)
        SHIFT = Transition(TransitionType.SHIFT, 0)
        SHIFT_ID = self.transition2id[SHIFT]
        while len(parse.stack) >= 2 or len(parse.buf) > 0: 
            w, p, l = parse.get_cur_features()
            w = torch.unsqueeze(torch.LongTensor(w), 0)
            p = torch.unsqueeze(torch.LongTensor(p), 0)
            l = torch.unsqueeze(torch.LongTensor(l), 0)
            if len(parse.stack) < 2 and len(parse.buf) > 0:
                transition = SHIFT
            else:
                pred = self.model(w, p, l).clone().detach().numpy()
                if len(parse.buf) == 0:
                    pred[SHIFT_ID] = -np.inf
                transition = self.transitions[np.argmax(pred)]
                if len(parse.buf) == 0 and transition == SHIFT:
                    transition = Transition(TransitionType.RIGHT, 0)

            #print("Transition: ", transition)
            parse.transition(transition, True)
            
        if score:
            total = 0
            correct = 0
            for p_h, c_h in zip(parse.sentence.head, sentence.head):
                if c_h == -1:
                    continue
                else:
                    total += 1
                    correct += int(p_h == c_h)
            return total, correct
        
    def compute_UAS(self, dataset):
        all_tokens = 0
        correct_tokens = 0
        for ex in tqdm.tqdm(dataset):
            a,c = self.parse_sentence(ex)
            all_tokens += a
            correct_tokens += c
        return correct_tokens/all_tokens

In [57]:
parser = Parser(word_embedding, label_embedding, pos_embedding)

In [58]:
len(train_set)

39832

In [59]:
dev_UAS = []
train_dataset = parser.create_dataset(train_set)
for _ in range(10):
    parser.train(train_dataset, lr=5e-4)
    UAS = parser.compute_UAS(dev_set)
    print("UAS=", UAS)
    dev_UAS.append(UAS)

100%|██████████| 39832/39832 [00:55<00:00, 718.91it/s]
100%|██████████| 186/186 [05:01<00:00,  1.62s/it]
100%|██████████| 1700/1700 [00:43<00:00, 39.25it/s]
  0%|          | 0/186 [00:00<?, ?it/s]

UAS= 0.6103808209907072


100%|██████████| 186/186 [04:51<00:00,  1.57s/it]
100%|██████████| 1700/1700 [00:28<00:00, 59.21it/s]
  0%|          | 0/186 [00:00<?, ?it/s]

UAS= 0.7254080224900434


100%|██████████| 186/186 [05:07<00:00,  1.66s/it]
100%|██████████| 1700/1700 [00:28<00:00, 58.98it/s]
  0%|          | 0/186 [00:00<?, ?it/s]

UAS= 0.7554468074029727


100%|██████████| 186/186 [05:14<00:00,  1.69s/it]
100%|██████████| 1700/1700 [00:28<00:00, 60.25it/s]
  0%|          | 0/186 [00:00<?, ?it/s]

UAS= 0.7748392638675586


100%|██████████| 186/186 [05:22<00:00,  1.74s/it]
100%|██████████| 1700/1700 [00:28<00:00, 59.44it/s]
  0%|          | 0/186 [00:00<?, ?it/s]

UAS= 0.7881666970351667


100%|██████████| 186/186 [05:18<00:00,  1.71s/it]
100%|██████████| 1700/1700 [00:28<00:00, 59.85it/s]
  0%|          | 0/186 [00:00<?, ?it/s]

UAS= 0.792695941900721


100%|██████████| 186/186 [04:59<00:00,  1.61s/it]
100%|██████████| 1700/1700 [00:28<00:00, 59.27it/s]
  0%|          | 0/186 [00:00<?, ?it/s]

UAS= 0.8006351354868938


100%|██████████| 186/186 [05:14<00:00,  1.69s/it]
100%|██████████| 1700/1700 [00:27<00:00, 60.91it/s]
  0%|          | 0/186 [00:00<?, ?it/s]

UAS= 0.8036806622068355


100%|██████████| 186/186 [05:20<00:00,  1.73s/it]
100%|██████████| 1700/1700 [00:27<00:00, 61.93it/s]
  0%|          | 0/186 [00:00<?, ?it/s]

UAS= 0.8075851836426582


100%|██████████| 186/186 [05:15<00:00,  1.70s/it]
100%|██████████| 1700/1700 [00:28<00:00, 60.14it/s]

UAS= 0.8127651820808496





In [222]:
#dev_UAS 
# no dropout - [0.5962792315048249, 0.673041815178649]
# with dropout - [0.5324697904894375, 0.5870642441102321]
# On full data: 0.69, 0.73
# updated: 0.66, 0.73, 0.75, 0.76, 0.77

[0.5962792315048249, 0.673041815178649]

In [229]:
dev_UAS

[0.5324697904894375, 0.5870642441102321]

In [18]:
parser.parse_sentence(dev_set[25])

(21, 15)

In [28]:
def parse_sentence2(self, sentence, score=True):
    vectorized = Example([self.word_embeddings.get_id(t) for t in sentence.word],
                       [self.pos_embeddings.get_id(t) for t in sentence.pos],
                       [-1 for t in sentence.word],
                       [self.label_embeddings.NULL_ID for t in sentence.word])
    parse = Parse(vectorized, self.word_embeddings.NULL_ID, 
                                      self.pos_embeddings.NULL_ID, 
                                      self.label_embeddings.NULL_ID)
    SHIFT = Transition(TransitionType.SHIFT, 0)
    SHIFT_ID = self.transition2id[SHIFT]
    while len(parse.stack) >= 2 or len(parse.buf) > 0: 
        w, p, l = parse.get_cur_features()
        w = torch.unsqueeze(torch.LongTensor(w), 0)
        p = torch.unsqueeze(torch.LongTensor(p), 0)
        l = torch.unsqueeze(torch.LongTensor(l), 0)
        if len(parse.stack) < 2 and len(parse.buf) > 0:
            transition = SHIFT
        else:
            pred = self.model(w, p, l).clone().detach().numpy()
            if len(parse.buf) == 0:
                pred[SHIFT_ID] = -np.inf
            transition = self.transitions[np.argmax(pred)]
            if len(parse.buf) == 0 and transition == SHIFT:
                transition = Transition(TransitionType.RIGHT, 0)

        #print("Transition: ", transition)
        parse.transition(transition, True)
    print(parse.sentence.head)
    if score:
        total = 0
        correct = 0
        for p_h, c_h in zip(parse.sentence.head, sentence.head):
            if c_h == -1:
                continue
            else:
                total += 1
                correct += int(p_h == c_h)
        return total, correct

In [29]:
parse_sentence2(parser, dev_set[18])

[4, 2, 4, 4, -1, 6, 4, 8, 6, 10, 8, 13, 13, 10, 16, 16, 13, 19, 19, 4, 21, 19, 23, 21, 4, 4]


(25, 25)

In [30]:
dev_set[25]

Example(word=['PaineWebber', 'Inc.', 'filmed', 'a', 'new', 'television', 'commercial', 'at', '4', 'p.m.', 'EDT', 'yesterday', 'and', 'had', 'it', 'on', 'the', 'air', 'by', 'last', 'night', '.'], pos=['PROPN', 'PROPN', 'VERB', 'DET', 'ADJ', 'NOUN', 'NOUN', 'ADP', 'NUM', 'NOUN', 'PROPN', 'NOUN', 'CONJ', 'VERB', 'PRON', 'ADP', 'DET', 'NOUN', 'ADP', 'ADJ', 'NOUN', 'PUNCT'], head=[1, 2, -1, 6, 6, 6, 2, 10, 10, 10, 2, 2, 2, 2, 13, 17, 17, 13, 20, 20, 13, 2], label=['compound', 'nsubj', 'root', 'det', 'amod', 'compound', 'dobj', 'case', 'nummod', 'compound', 'nmod', 'nmod:tmod', 'cc', 'conj', 'dobj', 'case', 'det', 'nmod', 'case', 'amod', 'nmod', 'punct'])