In [1]:
import word_representations
import metrics, config

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader
import torch.utils.data as utils

import time


if config.mode == 'bert':
    word_repr = word_representations.Bert()
elif config.mode == 'elmo':
    word_repr = word_representations.Elmo()
elif config.mode == 'glove':
    word_repr = word_representations.GloVe()

    
def read_conllu(repr_mode, max_lines=50000):
    if repr_mode == 'train':
        path = 'en_ewt-ud-train.conllu'
    elif repr_mode == 'test':
        path = 'en_ewt-ud-dev.conllu'
        
    with open(path, 'r') as file:
        sentences_states = []
        sentences_tokens = []
        tokens = []
        for i, line in enumerate(file):
            if i > max_lines:
                break
            if line == '\n':
                text = '[CLS]' + ' '.join(tokens) + '[SEP]'
                if config.mode == 'bert':
                    if len(word_repr.get_bert(text)[:-1]) == len(tokens):
                        sentences_states.extend(word_repr.get_bert(text)[:-1])
                        sentences_tokens.extend(tokens)
                elif config.mode == 'elmo':
                    sentences_states.extend(word_repr.get_elmo(' '.join(tokens)))
                    sentences_tokens.extend(tokens)
                elif config.mode == 'glove':
                    sentences_states.extend(word_repr.get_glove(' '.join(tokens)))
                    sentences_tokens.extend(tokens)
                tokens = []
                continue
            if line[0] == '#':
                continue
            line = line.rstrip('\n')
            line = line.split('\t')
            symbols = ['.', ',', '<', '>', ':', ';', '\'', '/', '-', '_', '%', '@', '#', '$', '^', '*', '?', '!', "‘",
                       "’", "'", "+", '=', '|', '\’']
            if len(line[1]) > 1:
                for sym in symbols:
                    line[1] = line[1].replace(sym, '')
            if line[1] == '':
                line[1] = 'unk'
            tokens.append(line[1].lower())

        return sentences_states, sentences_tokens
    
states, tokens = read_conllu('train')
states_dev, tokens_dev = read_conllu('test')

In [2]:
vocab = {}
for t in tokens + tokens_dev: 
    if t not in vocab:
        vocab[t] = len(vocab)

inv_vocab = {v: k for k, v in vocab.items()}

In [3]:
print(len(vocab))

8991


In [4]:
def get_tensors(states, tokens):
    X = torch.stack(states)
    y_lst = []
    for t in tokens:
        y_lst.append(vocab[t])
    y = torch.LongTensor(y_lst)
    return(X, y)

X_train, y_train = get_tensors(states, tokens)
X_dev, y_dev = get_tensors(states_dev, tokens_dev)

In [5]:
dset = utils.TensorDataset(X_train, y_train)
data_loader = DataLoader(dset, batch_size=64, shuffle=True)

In [6]:
def accuracy(X, y, model):
    correct = 0
    total = 0
    for state, target in zip(X, y):
        outputs = model(state)
        max_value = torch.max(outputs.data)
        pred = (outputs.data == max_value).nonzero()
        total += 1
        correct += (pred == target).sum().item()
        #if pred != target:
            #print(inv_vocab[pred.item()], inv_vocab[target.item()])
        #else:
            #print(inv_vocab[pred.item()])
    acc = 100 * correct / total
    return acc

In [7]:
class ProbingModule(nn.Module):
    def __init__(
            self,
            input_dim,
            output_dim,
            hidden_dim=64,
    ):
        super(ProbingModule, self).__init__()
        self.hidden = nn.Linear(input_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, output_dim)

    def forward(self, X, **kwargs):
        X = F.relu(self.hidden(X))
        X = self.output(X)
        return X

In [8]:
for seed in range(1):
    model = ProbingModule(768, len(vocab))
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0)

    # n_steps = 0
    # evaluator = metrics.Evaluator()

    torch.manual_seed(seed)
    epoch_accuracies = []

    for epoch in range(10):
        start_time = time.time()
        for _, (state, target) in enumerate(data_loader):
            optimizer.zero_grad()
            tag_score = model(state)
            loss = criterion(tag_score, target)
            loss.backward()
            optimizer.step()

            # if n_steps in evaluator.steps:
                # with torch.no_grad():
                    # print("n_steps: {0}, accuracy: {1} ".format(n_steps, evaluator.accuracy_for_ws(X_dev, y_dev)))
            # n_steps += 64

        with torch.no_grad():
            epoch_acc = accuracy(X_dev, y_dev, model)
            print("Epoch: {0}, Seconds: {1}, Loss {2}, Acc Dev: {3}".format(epoch, (time.time() - start_time), loss,
                                                                                    epoch_acc))
            epoch_accuracies.append(epoch_acc)

    print('Best accuracy: {}'.format(max(epoch_accuracies)))

	nonzero()
Consider using one of the following signatures instead:
	nonzero(*, bool as_tuple) (Triggered internally at  ../torch/csrc/utils/python_arg_parser.cpp:766.)
  pred = (outputs.data == max_value).nonzero()


Epoch: 0, Seconds: 8.099352836608887, Loss 3.3974673748016357, Acc Dev: 60.77799012581621
Epoch: 1, Seconds: 8.67821192741394, Loss 1.58973228931427, Acc Dev: 71.60774008600096
Epoch: 2, Seconds: 8.362237930297852, Loss 0.5087944269180298, Acc Dev: 76.15464245899028
Epoch: 3, Seconds: 9.208853006362915, Loss 0.10409748554229736, Acc Dev: 77.5879917184265
Epoch: 4, Seconds: 8.867652893066406, Loss 0.036656249314546585, Acc Dev: 78.44003822264692
Epoch: 5, Seconds: 8.88765287399292, Loss 0.011511667631566525, Acc Dev: 78.79439401178531
Epoch: 6, Seconds: 9.1597580909729, Loss 0.0018127151997759938, Acc Dev: 78.97356266921484
Epoch: 7, Seconds: 8.903685808181763, Loss 0.020809367299079895, Acc Dev: 79.08106386367255
Epoch: 8, Seconds: 9.040055751800537, Loss 0.04157098010182381, Acc Dev: 79.2044911610129
Epoch: 9, Seconds: 8.731001853942871, Loss 0.0007580933161079884, Acc Dev: 79.00143334925944
Best accuracy: 79.2044911610129
