In [None]:
%connect_info

In [None]:
%matplotlib inline

In [None]:
import sys
sys.argv = sys.argv[:1]

In [None]:
import argparse
import os
import pickle
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import f1_score, classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import train_test_split

import skorch
from skorch import NeuralNetClassifier
from skorch.callbacks import ProgressBar, EarlyStopping, Checkpoint
from skorch.helper import predefined_split
from torchtext.data import Dataset, Field, Example, BucketIterator

from tensorflow.keras.utils import to_categorical

from transformers import AutoModel, AutoTokenizer

# from evaluation import evaluate_model

In [None]:
#

In [None]:
id_labels = {
    0: 'none',
    1: 'IAV',
    2: 'IRV',
    3: 'LVC.cause',
    4: 'LVC.full',
    5: 'MVC',
    6: 'VID',
    7: 'VPC.full',
    8: 'VPC.semi',
    9: '<unlabeled>'
}


def load_tokenized_data(datafile, language_codes, percent=0.15, seed=42):

    with open(datafile, 'rb') as f:
        data = pickle.load(f)
    x_train, y_train = [], []
    x_val, y_val = [], []
    x_dev, y_dev = {}, {}
    for code in language_codes:

        true_x, true_y = [], []
        false_x, false_y = [], []
        for i, (xsample, ysample) in enumerate(zip(data[code]['x_train'], data[code]['y_train'])):
            
            if sum(ysample) > 0:
                true_x.append(xsample)
                true_y.append(ysample)

        max_len = max([len(y) for y in true_y])
        for xsample, ysample in zip(data[code]['x_train'], data[code]['y_train']):
            if 1 not in ysample and len(ysample) < max_len:
                false_x.append(xsample)
                false_y.append(ysample)

        false_x = np.array(false_x)
        false_y = np.array(false_y)

        np.random.seed(seed)
        idx = np.random.randint(len(false_y), size=int(percent * len(true_y)))
        false_x = false_x[idx].tolist()
        false_y = false_y[idx].tolist()

        x_train += true_x + false_x
        y_train += true_y + false_y
        x_val += data[code]["x_dev"]
        y_val += data[code]["y_dev"]

        x_dev[code] = data[code]["x_dev"]
        y_dev[code] = data[code]["y_dev"]

    del data

    return (x_train, y_train),( x_val, y_val), (x_dev, y_dev)

In [None]:
def build_model_name(args, model='rnn-cnn'):
    name = ''
    if model == 'rnn':
        name = ("{0}.{1}.{2}layers.{3}lstm.{4}dropout.{5}init.{6}activation"
                "{7}clipnorm.{8}batch.{9}epochs".format(
                    args.bert_type, args.metric, args.nlayers, args.lstm_size,
                    args.dropout, args.initrange, args.output_activation,
                    args.clipnorm, args.batch_size, args.max_epochs
                ))
    elif model == 'cnn':
        name = ("{0}.{1}.{2}filters.{3}kernels.{4}poolstride.{5}dropout."
                "{6}activation.{7}batch.{8}epochs".format(
                    args.bert_type, args.metric, args.nfilters, args.kernels,
                    args.pool_stride, args.dropout, args.output_activation,
                    args.batch_size, args.max_epochs
                ))
    elif model == 'cnn-rnn':
         name = ("{0}.{1}.{2}filters.{3}kernels.{4}poolstride.{5}layers."
                "{6}lstm.{7}dropout.{8}init.{9}activation.{10}batch."
                "{11}epochs".format(
                    args.bert_type, args.metric, args.nfilters, args.kernels,
                    args.pool_stride, args.nlayers, args.lstm_size, args.dropout,
                    args.initrange, args.output_activation, args.batch_size,
                    args.max_epochs
                ))
    return name

In [None]:
class SkorchBucketIterator(BucketIterator):

    def __init__(self,
                 dataset,
                 batch_size,
                 sort_key=None,
                 device=None,
                 batch_size_fn=None,
                 train=True,
                 repeat=False,
                 shuffle=None,
                 sort=None,
                 sort_within_batch=None,
                 one_hot=True,
                 num_classes=2):
        self.one_hot = one_hot
        self.num_classes = num_classes
        super(SkorchBucketIterator,
              self).__init__(dataset, batch_size, sort_key, device,
                             batch_size_fn, train, repeat, shuffle, sort,
                             sort_within_batch)

    def __iter__(self):
        for batch in super().__iter__():
            # We make a small modification: Instead of just returning batch
            # we return batch.text and batch.label, corresponding to X and y
            # if self.train:
            if self.one_hot:
                y = batch.labels.to('cpu')
                y = to_categorical(y, num_classes=self.num_classes)
                y = torch.tensor(y).to(self.device)
                batch.labels = y
            else:
                batch.labels = batch.labels.float()
            yield batch.sentence, batch.labels


class SentenceDataset(Dataset):

    def __init__(self, data, min_len=5, **kwargs):
        self.min_len = min_len
        text_field = Field(use_vocab=False, pad_token=0, batch_first=True)
        label_field = Field(use_vocab=False, pad_token=-1, batch_first=True)
        fields = [("sentence", text_field), ("labels", label_field)]
        examples = []
        for (x, y) in zip(data[0], data[1]):
            if len(x) < self.min_len:     # pad all sequences shorter than this
                x += [0] * (5 - len(x))
                y += [-1] * (5 - len(y))
            examples.append(Example.fromlist([x, y], fields))
        super().__init__(examples, fields, **kwargs)


class IdiomClassifier(skorch.NeuralNetClassifier):

    def __init__(self, print_report=True, class_weights=None, score_average='binary', *args, **kwargs):
        self.print_report = print_report
        self.class_weights = class_weights
        self.score_average = score_average
        if class_weights is None:
            self.class_weights = [1.0, 1.0]
        super(IdiomClassifier, self).__init__(*args, **kwargs)
        self.set_params(callbacks__valid_acc=None)
        self.set_params(criterion__reduction='none')

    def get_loss(self, y_pred, y_true, X, *args, **kwargs):
        if isinstance(self.criterion_, torch.nn.NLLLoss):
            loss = super().get_loss(y_pred.view(-1, 10), y_true.long().view(-1), X, *args, **kwargs)
        else:
            loss = super().get_loss(y_pred.view(-1), y_true.view(-1), X, *args,
                                        **kwargs)
        weights = torch.ones_like(y_true) * y_true
        for w, weight in enumerate(self.class_weights):
            weights = torch.where(
                y_true == w,
                torch.tensor(weight).float().to(self.device),
                weights)
        loss = (loss * weights.view(-1))
        mask = (y_true >= 0).int()
        loss = (loss * mask.view(-1)).mean()
        return loss

    def predict(self, X):
        self.module.eval()
        y_pred = self.module(X)
        if len(y_pred.shape) > 2:
            y_pred = torch.argmax(y_pred, dim=2)
        else:
            y_pred = (y_pred > 0.5).int()
        return y_pred

    def score(self, X, y=None):
        self.module.eval()
        ds = self.get_dataset(X)
        target_iterator = self.get_iterator(ds, training=False)

        y_true = []
        y_pred = []
        for x, y in target_iterator:
            preds = self.predict(x)
            y_pred.append(preds.view(-1))
            if len(y.shape) > 2:
                y = torch.argmax(y, dim=2)
            y_true.append(y.view(-1))
        y_true = torch.cat(y_true).cpu().view(-1).detach().numpy().tolist()
        y_pred = torch.cat(y_pred).cpu().view(-1).detach().numpy().tolist()

        tt, tp = [], []
        for t, p in zip(y_true, y_pred):
            if t >= 0:
                tt.append(t)
                tp.append(p)

        y_true = tt
        y_pred = tp

        if self.print_report:
            print('Confusion matrix')
            print(confusion_matrix(y_true, y_pred))
            print(classification_report(y_true, y_pred))
        return f1_score(y_true, y_pred, average=self.score_average)


class CustomScorer(skorch.callbacks.EpochScoring):

    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
        current_score = net.score(dataset_valid)
        self._record_score(net.history, current_score)


In [None]:
class CNNClassifier(nn.Module):
    def __init__(self, config, transformer, transformer_device):
        super(CNNClassifier, self).__init__()

        self.transformer_device = transformer_device
        self.model_device = transformer_device
        self.transformer = transformer
        self.convolutions = nn.ModuleList([
            nn.Conv1d(
                in_channels=transformer.embeddings.word_embeddings.embedding_dim,
                out_channels=config.nfilters,
                kernel_size=kernel_size,
                stride=1) for kernel_size in config.kernels])

        self.pool_stride = config.pool_stride
        self.dropout = nn.Dropout(config.dropout)

        ninputs = (config.nfilters // config.pool_stride) * len(config.kernels)
        noutputs = 1

        if config.labels == 'multilabel':
            noutputs = 10
        else:
            if config.output_activation == 'softmax':
                noutputs = 2

        self.fully_connected = nn.Linear(ninputs, noutputs)

        self.output_activation = (torch.sigmoid  # pylint: disable=no-member
                                  if noutputs == 1
                                  else F.softmax)

    def to(self, *args, **kwargs):
        self = super().to(*args, **kwargs)
        self.transformer = self.transformer.to(
            torch.device(self.transformer_device))
        self.model_device = next(self.fully_connected.parameters()).device.type
        return self

    def freeze_transformer(self):
        for param in self.transformer.parameters():
            param.requires_grad = False

    def unfreeze_transformer(self):
        for param in self.transformer.parameters():
            param.requires_grad = True

    def forward(self, x):
        x = x.to(self.transformer_device)
        m = (x > 0).int()
        x = self.transformer(x, attention_mask=m)[0].transpose(1, 2)
        #
        seq_len = x.shape[-1]
        #
        if self.transformer_device != self.model_device:
            x = x.to(self.model_device)
        #
        x = [F.relu(conv(x)).transpose(1, 2) for conv in self.convolutions]
        x = [nn.functional.pad(i, (0, 0, 0, seq_len - i.shape[1])) for i in x]
        x = [F.max_pool1d(c, self.pool_stride) for c in x]
        x = torch.cat(x, dim=2)  # pylint: disable=no-member
        x = self.fully_connected(x)
        x = self.dropout(x)

        return self.output_activation(x, dim=2)

class RNNClassifier(nn.Module):
    def __init__(self, config, transformer, transformer_device):
        super(RNNClassifier, self).__init__()

        self.transformer_device = transformer_device
        self.model_device = transformer_device
        self.transformer = transformer

        self.lstm = nn.LSTM(
            input_size=transformer.embeddings.word_embeddings.embedding_dim,
            hidden_size=config.lstm_size,
            num_layers=config.nlayers,
            batch_first=True,
            dropout=config.dropout)

        self.dropout = nn.Dropout(config.dropout)
        noutputs = (1 if config.output_activation == 'sigmoid' else 2)

        self.fully_connected = nn.Linear(config.lstm_size, noutputs)

        self.output_activation = (torch.sigmoid  # pylint: disable=no-member
                                  if noutputs == 1
                                  else F.softmax)
        self.init_weights(config.initrange)

    def to(self, *args, **kwargs):
        self = super().to(*args, **kwargs)
        self.transformer = self.transformer.to(
            torch.device(self.transformer_device))
        self.model_device = next(self.fully_connected.parameters()).device.type
        return self

    def freeze_transformer(self):
        for param in self.transformer.parameters():
            param.requires_grad = False

    def unfreeze_transformer(self):
        for param in self.transformer.parameters():
            param.requires_grad = True

    def forward(self, x):
        x = x.to(self.transformer_device)
        m = (x > 0).int()
        x = self.transformer(x, attention_mask=m)[0]
        #
        seq_len = x.shape[-1]
        #
        if self.transformer_device != self.model_device:
            x = x.to(self.model_device)
        #
        x, _ = self.lstm(x)
        x = self.dropout(x)
        x = self.fully_connected(x)

        return self.output_activation(x)

    def init_weights(self, initrange):
        for names in self.lstm._all_weights:
            for name in filter(lambda n: "bias" in n, names):
                bias = getattr(self.lstm, name)
                n = bias.size(0)
                start, end = n//4, n//2
                bias.data[start:end].fill_(1.)
            for name in filter(lambda n: "weight" in n,  names):
                weight = getattr(self.lstm, name)
                weight.data.uniform_(-initrange, initrange)

        self.fully_connected.bias.data.fill_(0)
        self.fully_connected.weight.data.uniform_(-initrange, initrange)


In [None]:
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')     # pylint: disable=no-member
LANGUAGE_CODES = ['DE', 'GA', 'HI', 'PT', 'ZH']
CWD = os.getcwd()
BASE_DIR = ''     # this will point to the user's home
TRAIN_DIR = "transformer/cnn"


In [None]:
parser = argparse.ArgumentParser(description='Classifier using CNNs')
parser.add_argument(
    '--bert_type',
    type=str,
    default='distilbert-base-multilingual-cased',
    help='transormer model [should be a miltilingual model]')
parser.add_argument(
    '--bert_device',
    type=str,
    default='gpu',
    help='device to run the transformer model')
parser.add_argument(
    '--labels',
    type=str,
    default='multilabel',
    help='multilabel or binary classification')
parser.add_argument(
    '--metric',
    type=str,
    default='f1',
    help='sklearn metric to evaluate the model while training')
parser.add_argument(
    '--nfilters',
    type=int,
    default=32,
    help='number of convolution filters')
parser.add_argument(
    '--kernels',
    type=list,
    default=[1, 3, 5],
    help='number of convolution filters')
parser.add_argument(
    '--pool_stride',
    type=int,
    default=2,
    help='size of the stride for the pooling operation')
parser.add_argument(
    '--nlayers',
    type=int,
    default=2,
    help='number of convolution filters')
parser.add_argument(
    '--lstm_size',
    type=int,
    default=50,
    help='number of convolution filters')
parser.add_argument(
    '--dropout',
    type=float,
    default=0.2,
    help='dropout probability for the dense layer')
parser.add_argument(
    '--initrange',
    type=float,
    default=0.1,
    help='range to initialize the lstm layers')
parser.add_argument(
    '--clipnorm',
    type=float,
    default=5.0,
    help='limit to clip the l2 norm of gradients')
parser.add_argument(
    '--output_activation',
    type=str,
    default='sigmoid',
    help='output activation')
parser.add_argument(
    '--batch_size',
    type=int,
    default=32,
    help='training batch size')
parser.add_argument(
    '--eval_batch_size',
    type=int,
    default=32,
    help='validation/evaluation batch size')
parser.add_argument(
    '--max_epochs',
    type=int,
    default=100,
    help='max number of epochs to train the model')
parser.add_argument(
    "--train_dir",
    type=str,
    default=os.path.join(BASE_DIR, TRAIN_DIR) + "/",
    help="Train dir")
parser.add_argument(
    "--eval",
    action="store_true",
    help="eval at the end of the training process")


In [None]:
args = parser.parse_args()
args.kernels = [int(i) for i in args.kernels if ',' not in str(i)]
transformer_device = torch.device(
    'cuda' if torch.cuda.is_available() and args.bert_device == 'gpu'
    else 'cpu')

ONE_HOT_OUTPUT = args.output_activation == 'softmax' or args.labels == 'binary'

In [None]:
(x_train, y_train), (x_val, y_val), (x_dev, y_dev) = load_tokenized_data(
    datafile='data/{}{}.tokenized.pkl'.format(args.bert_type, '' if args.labels == 'binary' else '.multilabel'),
    language_codes=LANGUAGE_CODES,
    seed=SEED)

targets = np.concatenate(y_train).reshape(-1)
class_weights = compute_class_weight(class_weight='balanced',
                                     classes=np.unique(targets),
                                     y=targets)

In [None]:
class_weights

In [None]:
tokenizer = AutoTokenizer.from_pretrained(args.bert_type)
transformer = AutoModel.from_pretrained(args.bert_type)

In [None]:
model = CNNClassifier(args, transformer, transformer_device)
model_name = build_model_name(args, model='cnn')

model.to(DEVICE)     # pylint: disable=no-member
model.freeze_transformer()

In [None]:
progress_bar = ProgressBar(batches_per_epoch=len(x_train) // args.batch_size + 1)
scorer = CustomScorer(scoring=None, name="F1", lower_is_better=False, use_caching=False)
early_stopping =  EarlyStopping(monitor='F1', patience=20, lower_is_better=False)
checkpoint = Checkpoint(
    monitor='F1_best',
    dirname=args.train_dir,
    f_params='{}.params.pt'.format(model_name),
    f_optimizer='{}.optimizer.pt'.format(model_name),
    f_history='{}.history.json'.format(model_name))

In [None]:
net = IdiomClassifier(
    module=model,
    class_weights=class_weights,
    print_report=False,
    score_average='micro',
     #
    iterator_train=SkorchBucketIterator,
    iterator_train__batch_size=args.batch_size,
    iterator_train__sort_key=lambda x: len(x.sentence),
    iterator_train__shuffle=True,
    iterator_train__device=DEVICE,
    iterator_train__one_hot=ONE_HOT_OUTPUT,
     #
    iterator_valid=SkorchBucketIterator,
    iterator_valid__batch_size=32,
    iterator_valid__sort_key=lambda x: len(x.sentence),
    iterator_valid__shuffle=True,
    iterator_valid__device=DEVICE,
    iterator_valid__one_hot=ONE_HOT_OUTPUT,

    train_split=predefined_split(SentenceDataset(data=(x_val[0:5], y_val[0:5]))),
    optimizer=torch.optim.Adam,
    criterion=nn.BCELoss if args.labels == 'binary' else nn.NLLLoss,
    criterion__ignore_index=-1,
    callbacks=[progress_bar, scorer, early_stopping, checkpoint],
    device=DEVICE,
)


In [None]:
net.fit(SentenceDataset(data=(x_train[0:32], y_train[0:32])), y=None, epochs=1)
print()

In [None]:
# net.initialize()
# net.load_params(checkpoint=checkpoint)

In [24]:
print(model_name)
args.eval = True
LANGUAGE_CODES = ['GA']
if args.eval:
    for code in LANGUAGE_CODES:
        print('#' * 20)
        print('# Evaluating Language: {}'.format(code))
        print('#' * 20)
        test_iterator = SkorchBucketIterator(
            dataset=SentenceDataset(data=(x_dev[code], y_dev[code])),
            batch_size=args.eval_batch_size,
            sort=False,
            sort_key=lambda x: len(x.sentence),
            shuffle=False,
            train=False,
            one_hot=ONE_HOT_OUTPUT,
            device=DEVICE)
        args.dev_file = 'data/{}/dev.cupt'.format(code)
        evaluate_model(net, test_iterator, tokenizer, args)

distilbert-base-multilingual-cased.f1.32filters.[1, 3, 5]kernels.2poolstride.0.2dropout.sigmoidactivation.32batch.100epochs
####################
# Evaluating Language: GA
####################
0


Exception: Line has 162 columns, but header specifies 11

In [26]:
from eval_scripts.evaluate import Main


ID_LABELS = {
    0: 'none',
    1: 'IAV',
    2: 'IRV',
    3: 'LVC.cause',
    4: 'LVC.full',
    5: 'MVC',
    6: 'VID',
    7: 'VPC.full',
    8: 'VPC.semi',
    9: '<unlabeled>'
}

def evaluate_model(net, test_iterator, tokenizer, args):
    preds = []
    sents = []
    i = 0
    for x, y in test_iterator:
        y_pred = net.predict(x)
    #     i += 1
        if i % 40 == 0:
            print(i)
        i += 1
        sub_tokens = []
        sub_preds = []
        text = []
        predictions = []
        tokens = tokenizer.convert_ids_to_tokens(x.detach().cpu().numpy().reshape(-1))
        # tokens = tokens
        y_pred = y_pred.cpu().detach().reshape(-1).tolist()
        for t, p in zip(tokens, y_pred):
            if '#' in t:
                sub_tokens.append(t.replace('#', ''))
                sub_preds.append(p)
            else:
                if sub_tokens:
                    old_token = ''.join([text[-1]] + sub_tokens)
                    old_pred = sum(sub_preds)
                    text = text[0:-1]
                    text.append(old_token)
                    predictions = predictions[0:-1]
                    predictions.append(old_pred if old_pred == 0 else 1)
                    old_token = t
                    old_pred = p
                    sub_tokens = []
                    sub_preds = []
                else:
                    old_token = t
                    old_pred = p
                text.append(old_token)
                predictions.append(old_pred)
                assert len(text[1:-1]) == len(predictions[1:-1])
        sents.append(text[1:-1])
        preds += predictions[1:-1]

    output_count = 0
    with open(args.dev_file, 'r') as dev:
        with open(args.dev_file.replace('dev.cupt', 'temp.cupt'), 'w') as test:
            for line in dev:
                feats = line.split()
                if not line.startswith('#') and line != '\n' and '-' not in feats[0]:
                    prediction = preds[output_count]
                    if prediction == 0:
                        label = '*'
                    else:
                        label = ID_LABELS.get(prediction, '*')
                    new_line = '\t'.join(
                        [str(f) for f in feats[0:-1]] + [str(label)] + ['\n'])
                    test.write(new_line)
                    output_count += 1
                else:
                    test.write(line)

    # post-process the file to get the predictions into cupt format
    with open(args.dev_file.replace('dev.cupt', 'temp.cupt'), 'r') as temp:
        with open(args.dev_file.replace('dev.cupt', 'system.cupt'), 'w') as test:
            current_prediction = [1, None]
            verb_found = False
            for line in temp:
                feats = line.split('\t')
                if not line.startswith('#') and line != '\n' and '-' not in feats[0]:

                    if feats[10] == '*':
                        test.write(line)
                    else:

                        if current_prediction[1] is None:

                            label = '{}:{}'.format(current_prediction[0], feats[10])
                            verb_found = True if feats[3] == 'VERB' else False
                            current_prediction[1] = feats[10]

                        else:

                            if feats[10] == current_prediction[1]:

                                if verb_found and feats[3] != 'VERB':
                                    label = current_prediction[0]

                                elif verb_found and feats[3] == 'VERB':
                                    current_prediction[0] = current_prediction[0] + 1
                                    current_prediction[1] = feats[10]
                                    label = '{}:{}'.format(current_prediction[0], feats[10])

                                elif not verb_found:
                                    label = current_prediction[0]
                                    verb_found = True if feats[3] == 'VERB' else False

                            else:
                                current_prediction[0] = current_prediction[0] + 1
                                current_prediction[1] = feats[10]
                                label = '{}:{}'.format(current_prediction[0], feats[10])
                                verb_found = True if feats[3] == 'VERB' else False
                        test.write('\t'.join(feats[0:-2] + [str(label)]))
                else:
                    if line == '\n':
                        current_prediction = [1, None]
                        verb_found = False
                    test.write(line)
    if args.eval:
        _run_sript(args)




def _run_sript(args):

    args.debug = False
    args.combinatorial = True
    args.gold_file = open(args.dev_file, 'r')
    args.prediction_file = open(args.dev_file.replace('dev.cupt', 'system.cupt'), 'r')
    args.train_file = open(args.dev_file.replace('dev.cupt', 'train.cupt'), 'r')

    Main(args).run()

In [None]:
print("#" * 20)
print("\nTraining finished!!!")

In [None]:
net.module

In [None]:
code = 'GA'
test_iterator = SkorchBucketIterator(
            dataset=SentenceDataset(data=(x_dev[code], y_dev[code])),
            batch_size=args.eval_batch_size,
            sort=False,
            sort_key=lambda x: len(x.sentence),
            shuffle=False,
            train=False,
            one_hot=ONE_HOT_OUTPUT,
            device=DEVICE)

In [None]:

args.dev_file = 'data/{}/dev.cupt'.format(code)
evaluate_model(net, test_iterator, tokenizer, args)

In [None]:
F.softmax(torch.tensor([[0.25, 0.75], [0.5, 0.5]]))

In [None]:
x, y = next(iter(test_iterator))

In [None]:
# x = x.to(self.transformer_device)
m = (x > 0).int()
x = transformer(x, attention_mask=m)[0].transpose(1, 2)
#
seq_len = x.shape[-1]
#
# if self.transformer_device != self.model_device:
#     x = x.to(self.model_device)
#
x = [F.relu(conv(x)).transpose(1, 2) for conv in net.module.convolutions]
x = [nn.functional.pad(i, (0, 0, 0, seq_len - i.shape[1])) for i in x]
x = [F.max_pool1d(c, net.module.pool_stride) for c in x]
x = torch.cat(x, dim=2)  # pylint: disable=no-member
x = net.module.fully_connected(x)
x = net.module.dropout(x)

In [None]:
x.shape

In [None]:
x = F.softmax(x, dim=2)

In [None]:
x[0, 0, :].sum()

In [None]:
print('\t'.join(['1', '2', '3']))

In [None]:
d = {1: 'one', 2: 'two'}

In [None]:
'one' in d.values()

In [None]:
sorted([k for k in d])

In [None]:
sorted([k for k in d.keys()])[-1]

In [None]:
[i for i in range(11)][10]

In [None]:
len(d.keys())

In [None]:
print(
    """# source_sent_id = http://hdl.handle.net/11234/1-3105 UD_Irish-IDT/ga_idt-ud-dev 547
# text = Nuair a chuaigh Éamonn i mbun oibre bhí sé i gceist aige an oifig a riar go héifeachtach ach chuir na riaráistí oibre alltacht air.
1	Nuair	nuair	SCONJ	Subord	_	3	mark	_	_	*
2	a	a	PART	Vb	PartType=Vb|PronType=Rel	3	mark:prt	_	_	*
3	chuaigh	téigh	VERB	VTI	Form=Len|Mood=Ind|Tense=Past	8	advcl	_	_	1:VID
4	Éamonn	Éamonn	NOUN	Noun	Gender=Masc|Number=Sing	3	nsubj	_	_	*
5	i	i	ADP	Cmpd	PrepForm=Cmpd	7	case	_	_	1
6	mbun	mbun	ADP	Cmpd	PrepForm=Cmpd	5	fixed	_	_	1
7	oibre	obair	NOUN	Noun	Case=Gen|Gender=Fem|Number=Sing	3	obl	_	_	*
8	bhí	bí	VERB	PastInd	Form=Len|Mood=Ind|Tense=Past	0	root	_	_	*
9	sé	sé	PRON	Pers	Gender=Masc|Number=Sing|Person=3	8	nsubj	_	_	*
10	i	i	ADP	Simp	_	11	case	_	_	*
11	gceist	ceist	NOUN	Noun	Form=Ecl|Gender=Fem|Number=Sing	8	xcomp:pred	_	_	*
12	aige	ag	ADP	Prep	Gender=Masc|Number=Sing|Person=3	8	obl:prep	_	_	*
13	an	an	DET	Art	Definite=Def|Number=Sing|PronType=Art	14	det	_	_	*
14	oifig	oifig	NOUN	Noun	Definite=Def|Gender=Fem|Number=Sing	16	obj	_	_	*
15	a	a	PART	Inf	PartType=Inf	16	mark	_	_	*
16	riar	riar	NOUN	Noun	VerbForm=Inf	8	xcomp	_	_	*
17	go	go	PART	Ad	PartType=Ad	18	mark:prt	_	_	*
18	héifeachtach	éifeachtach	ADJ	Adj	Degree=Pos|Form=HPref	16	advmod	_	_	*
19	ach	ach	SCONJ	Subord	_	20	mark	_	_	*
20	chuir	cuir	VERB	VTI	Form=Len|Mood=Ind|Tense=Past	8	advcl	_	_	2:LVC.cause
21	na	na	DET	Art	Definite=Def|Number=Plur|PronType=Art	22	det	_	_	*
22	riaráistí	riaráiste	NOUN	Noun	Definite=Def|Gender=Masc|Number=Plur	20	nsubj	_	_	*
23	oibre	obair	NOUN	Noun	Case=Gen|Gender=Fem|Number=Sing	22	nmod	_	_	*
24	alltacht	alltacht	NOUN	Noun	Gender=Fem|Number=Sing	20	obj	_	_	2
25	air	ar	ADP	Prep	Gender=Masc|Number=Sing|Person=3	20	obl:prep	_	SpaceAfter=No	*
26	.	.	PUNCT	.	_	8	punct	_	_	*"""
)

In [None]:
current_prediction = [1, None]

In [None]:
current_prediction

In [None]:
current_prediction[1] = 'VERB'

In [None]:
current_prediction

In [None]:
dev =  """# source_sent_id = http://hdl.handle.net/11234/1-3105 UD_Irish-IDT/ga_idt-ud-dev 547\n
# text = Nuair a chuaigh Éamonn i mbun oibre bhí sé i gceist aige an oifig a riar go héifeachtach ach chuir na riaráistí oibre alltacht air.\n
1	Nuair	nuair	SCONJ	Subord	_	3	mark	_	_	*\n
2	a	a	PART	Vb	PartType=Vb|PronType=Rel	3	mark:prt	_	_	*\n
3	chuaigh	téigh	VERB	VTI	Form=Len|Mood=Ind|Tense=Past	8	advcl	_	_	VID\n
4	Éamonn	Éamonn	NOUN	Noun	Gender=Masc|Number=Sing	3	nsubj	_	_	*\n
5	i	i	ADP	Cmpd	PrepForm=Cmpd	7	case	_	_	VID\n
6	mbun	mbun	ADP	Cmpd	PrepForm=Cmpd	5	fixed	_	_	VID\n
7	oibre	obair	NOUN	Noun	Case=Gen|Gender=Fem|Number=Sing	3	obl	_	_	*\n
8	bhí	bí	VERB	PastInd	Form=Len|Mood=Ind|Tense=Past	0	root	_	_	*\n
9	sé	sé	PRON	Pers	Gender=Masc|Number=Sing|Person=3	8	nsubj	_	_	*\n
10	i	i	ADP	Simp	_	11	case	_	_	*\n
11	gceist	ceist	NOUN	Noun	Form=Ecl|Gender=Fem|Number=Sing	8	xcomp:pred	_	_	*\n
12	aige	ag	ADP	Prep	Gender=Masc|Number=Sing|Person=3	8	obl:prep	_	_	*\n
13	an	an	DET	Art	Definite=Def|Number=Sing|PronType=Art	14	det	_	_	*\n
14	oifig	oifig	NOUN	Noun	Definite=Def|Gender=Fem|Number=Sing	16	obj	_	_	*\n
15	a	a	PART	Inf	PartType=Inf	16	mark	_	_	*\n
16	riar	riar	NOUN	Noun	VerbForm=Inf	8	xcomp	_	_	*\n
17	go	go	PART	Ad	PartType=Ad	18	mark:prt	_	_	*\n
18	héifeachtach	éifeachtach	ADJ	Adj	Degree=Pos|Form=HPref	16	advmod	_	_	*\n
19	ach	ach	SCONJ	Subord	_	20	mark	_	_	*\n
20	chuir	cuir	VERB	VTI	Form=Len|Mood=Ind|Tense=Past	8	advcl	_	_	VID\n
21	na	na	DET	Art	Definite=Def|Number=Plur|PronType=Art	22	det	_	_	*\n
22	riaráistí	riaráiste	NOUN	Noun	Definite=Def|Gender=Masc|Number=Plur	20	nsubj	_	_	*\n
23	oibre	obair	NOUN	Noun	Case=Gen|Gender=Fem|Number=Sing	22	nmod	_	_	*\n
24	alltacht	alltacht	NOUN	Noun	Gender=Fem|Number=Sing	20	obj	_	_	VID\n
25	air	ar	ADP	Prep	Gender=Masc|Number=Sing|Person=3	20	obl:prep	_	SpaceAfter=No	*\n
26	.	.	PUNCT	.	_	8	punct	_	_	*"""
print(dev)

In [None]:
dev = dev.split('\n')
dev = [d for d in dev if d != '']
dev

In [None]:
current_prediction = [1, None]
verb_found = False
for line in dev:
    feats = line.split('\t')
    if not line.startswith('#') and line != '\n' and '-' not in feats[0]:
        
        if feats[10] == '*':
            print(line)
        else:

            if current_prediction[1] is None:

                label = '{}:{}'.format(current_prediction[0], feats[10])
                verb_found = True if feats[3] == 'VERB' else False
                current_prediction[1] = feats[10]

            else:

                if feats[10] == current_prediction[1]:

                    if verb_found and feats[3] != 'VERB':
                        label = current_prediction[0]

                    elif verb_found and feats[3] == 'VERB':
                        current_prediction[0] = current_prediction[0] + 1
                        current_prediction[1] = feats[10]
                        label = '{}:{}'.format(current_prediction[0], feats[10])
                    
                    elif not verb_found:
                        label = current_prediction[0]
                        verb_found = True if feats[3] == 'VERB' else False

                else:
                    current_prediction[0] = current_prediction[0] + 1
                    current_prediction[1] = feats[10]
                    label = '{}:{}'.format(current_prediction[0], feats[10])
                    verb_found = True if feats[3] == 'VERB' else False
            print('\t'.join(feats[0:-1] + [str(label)] + ['pred']))
    else:
        if line == '\n':
            current_prediction = [1, None]
            verb_found = False

In [None]:
feats