In [1]:
import word_representations
import metrics, config

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader
import torch.utils.data as utils


# Build the embedding backend selected by config.mode. The result is stored in
# the module-level name `word_repr`, which read_conllu() uses to embed sentences.
if config.mode == 'bert':
    word_repr = word_representations.Bert()
elif config.mode == 'elmo':
    word_repr = word_representations.Elmo()
elif config.mode == 'glove':
    word_repr = word_representations.GloVe()
else:
    # Fail fast: with no backend every sentence would be silently skipped and
    # the notebook would only break later with an unrelated-looking error.
    raise ValueError(
        "Unsupported config.mode {!r}; expected 'bert', 'elmo' or 'glove'".format(config.mode)
    )

    
# Punctuation removed from tokens longer than one character; single-character
# tokens are left untouched.
_STRIP_TABLE = str.maketrans('', '', ".,<>:;'/-_%@#$^*?!‘’+=|")

# CoNLL-U file read for each split name accepted by read_conllu().
_CONLLU_PATHS = {
    'train': 'en_ewt-ud-train.conllu',
    'test': 'en_ewt-ud-dev.conllu',
}


def _normalize_token(form):
    """Strip punctuation from multi-character forms, lower-case, map '' to 'unk'."""
    if len(form) > 1:
        form = form.translate(_STRIP_TABLE)
    return form.lower() if form else 'unk'


def _embed_sentence(tokens):
    """Return the per-token states for one sentence, or None if it must be skipped."""
    mode = config.mode
    if mode == 'bert':
        # NOTE(review): the markers are concatenated without separating spaces,
        # exactly as before; confirm word_repr.get_bert expects this format.
        text = '[CLS]' + ' '.join(tokens) + '[SEP]'
        # Embed once and reuse the result for both the length check and the output.
        # The last returned element is dropped (presumably the [SEP] state - TODO confirm).
        states = word_repr.get_bert(text)[:-1]
        # Sentences whose state count does not match the token count are dropped
        # so that states and tokens stay aligned one-to-one.
        return states if len(states) == len(tokens) else None
    if mode == 'elmo':
        return word_repr.get_elmo(' '.join(tokens))
    if mode == 'glove':
        return word_repr.get_glove(' '.join(tokens))
    return None


def read_conllu(repr_mode, max_lines=50000):
    """Read a CoNLL-U split and return aligned per-token states and tokens.

    Comment rows (starting with '#') are ignored. For every other non-blank
    row the second tab-separated column is normalised (punctuation stripped
    from multi-character forms, lower-cased, empty result replaced by 'unk').
    A blank line ends the sentence, which is then embedded with the backend
    selected by config.mode.

    NOTE(review): every non-comment row contributes a token; this presumably
    includes multiword-token range rows (IDs such as "1-2") - confirm that
    this is intended.

    Args:
        repr_mode: 'train' reads en_ewt-ud-train.conllu, 'test' reads
            en_ewt-ud-dev.conllu.
        max_lines: physical lines with index 0..max_lines are read, then
            reading stops. A sentence still open at that point (no blank
            line seen yet) is discarded.

    Returns:
        (sentences_states, sentences_tokens): two flat lists covering all
        kept sentences, where entry i of the first list is the state
        produced for token i of the second.

    Raises:
        ValueError: if repr_mode is neither 'train' nor 'test'.
    """
    try:
        path = _CONLLU_PATHS[repr_mode]
    except KeyError:
        raise ValueError(
            "Unknown repr_mode {!r}; expected 'train' or 'test'".format(repr_mode)
        ) from None

    sentences_states = []
    sentences_tokens = []
    tokens = []
    # CoNLL-U files are UTF-8; do not depend on the platform default encoding,
    # otherwise the curly quotes in the strip table never match.
    with open(path, 'r', encoding='utf-8') as file:
        for i, line in enumerate(file):
            if i > max_lines:
                break
            if line == '\n':
                # Skip stray blank lines so the embedder never sees an empty sentence.
                if tokens:
                    states = _embed_sentence(tokens)
                    if states is not None:
                        sentences_states.extend(states)
                        sentences_tokens.extend(tokens)
                tokens = []
                continue
            if line.startswith('#'):
                continue
            fields = line.rstrip('\n').split('\t')
            tokens.append(_normalize_token(fields[1]))

    return sentences_states, sentences_tokens
    
# Embed both splits up front: 'train' reads the training file, 'test' the dev file.
states, tokens = read_conllu(repr_mode='train')
states_dev, tokens_dev = read_conllu(repr_mode='test')

In [2]:
# Token -> class index, assigned in first-seen order over the training tokens
# followed by the dev tokens.
vocab = {}
for t in tokens + tokens_dev:
    # setdefault leaves an existing index untouched, so only unseen tokens
    # receive the next free id.
    vocab.setdefault(t, len(vocab))

# Class index -> token, the inverse of `vocab`.
inv_vocab = dict(zip(vocab.values(), vocab.keys()))

In [3]:
def get_tensors(states, tokens):
    """Convert per-token states and token strings into tensors.

    Args:
        states: sequence of tensors, stacked along a new leading dimension.
        tokens: sequence of token strings; each is looked up in the
            module-level `vocab` (a KeyError is raised for unknown tokens).

    Returns:
        (X, y): the stacked states and a LongTensor of vocabulary indices.
    """
    features = torch.stack(states)
    labels = torch.LongTensor([vocab[tok] for tok in tokens])
    return features, labels

# Materialise the train and dev splits as (features, labels) tensor pairs.
X_train, y_train = get_tensors(states=states, tokens=tokens)
X_dev, y_dev = get_tensors(states=states_dev, tokens=tokens_dev)

In [4]:
# Wrap the training tensors so they can be iterated in shuffled mini-batches of 64.
dset = utils.TensorDataset(X_train, y_train)
data_loader = DataLoader(dataset=dset, shuffle=True, batch_size=64)

In [5]:
class ProbingModule(nn.Module):
    """Probe made of two linear layers with a ReLU in between.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: number of scores produced per input row.
        hidden_dim: width of the intermediate layer (default 64).
    """

    def __init__(self, input_dim, output_dim, hidden_dim=64):
        super().__init__()
        # Layers are created in this order so parameter initialisation draws
        # from the random generator in the same sequence as before.
        self.hidden = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.output = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, X, **kwargs):
        """Return unnormalised scores for X; extra keyword arguments are ignored."""
        return self.output(F.relu(self.hidden(X)))

In [6]:
def train_probe(model, data_loader, X_dev, y_dev, seed, n_epochs=10, lr=0.001):
    """Train `model` with Adam and cross-entropy, evaluating on the dev split.

    Args:
        model: module mapping a batch of states to class scores.
        data_loader: iterable yielding (state, target) training batches.
        X_dev, y_dev: dev inputs and targets handed to the evaluator.
        seed: value passed to torch.manual_seed before the first epoch.
        n_epochs: number of passes over data_loader (default 10).
        lr: Adam learning rate (default 0.001).

    Returns:
        The metrics.Evaluator created for `model`. Its accuracy_for_ws and
        accuracy methods are called during training; their return values are
        not used here, so results are presumably recorded on the evaluator
        itself (confirm in metrics).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.0)

    evaluator = metrics.Evaluator(model)
    torch.manual_seed(seed)

    # The counter advances by the loader's nominal batch size after each update,
    # so it stays correct if the loader is built with a different batch size.
    # Falls back to the previous constant when no batch size is exposed.
    samples_per_step = getattr(data_loader, 'batch_size', None) or 64
    n_steps = 0

    for _ in range(n_epochs):
        for state, target in data_loader:
            optimizer.zero_grad()
            loss = criterion(model(state), target)
            loss.backward()
            optimizer.step()

            # Extra evaluation at the counter values listed in evaluator.steps;
            # the check uses the count from before this batch is added.
            if n_steps in evaluator.steps:
                with torch.no_grad():
                    evaluator.accuracy_for_ws(X_dev, y_dev)
            n_steps += samples_per_step

        # Evaluation at the end of every epoch.
        with torch.no_grad():
            evaluator.accuracy(X_dev, y_dev, model)

    return evaluator

In [7]:
# One evaluator per seed, aggregated below into mean / standard deviation.
evaluator_saver = []

for s in config.seeds:
    # Seed before building the probe so that its weight initialisation is
    # reproducible too; train_probe re-seeds with the same value before training.
    torch.manual_seed(s)
    # Take the input width from the data rather than assuming 768, so the
    # probe also fits backends whose states have a different size.
    model = ProbingModule(X_train.shape[-1], len(vocab))
    evaluator = train_probe(model, data_loader, X_dev, y_dev, s)
    evaluator_saver.append(evaluator)

mean_sd = metrics.MeanSD(evaluator_saver)
mean_sd.print_accs()

	nonzero()
Consider using one of the following signatures instead:
	nonzero(*, bool as_tuple) (Triggered internally at  ../torch/csrc/utils/python_arg_parser.cpp:766.)
  pred = (outputs.data == max_value).nonzero()


Best epoch:    Mean: 80.4586717630196    Standard Deviation: 0.08351718810082304
