In [None]:
!pip install datasets  # huggingface library with dataset
!pip install conllu    # aux library for processing CoNLL-U format
!pip install transformers
!pip install evaluate
!pip install accelerate

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.13.1-py3-none-any.whl (486 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m486.2/486.2 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.7,>=0.3.0 (from datasets)
  Downloading dill-0.3.6-py3-none-any.whl (110 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.5/212.5 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.14-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.3/134.3 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
Collec

In [None]:
import torch
import torch.nn as nn
from functools import partial
from datasets import load_dataset

## Arc-eager Parsing

In [None]:
class ArcEager:
  def __init__(self, sentence):
    self.sentence = sentence
    self.buffer = [i for i in range(len(self.sentence))]
    self.stack = []
    self.arcs = [-1 for _ in range(len(self.sentence))]

    # shift to initialize the stack
    self.shift()

  def shift(self):
    b1 = self.buffer[0]
    self.buffer = self.buffer[1:]
    self.stack.append(b1)

  def left_arc(self):
    s = self.stack.pop()
    b = self.buffer[0]
    self.arcs[s] = b

  def right_arc(self):
    b = self.buffer[0]
    s = self.stack.pop()
    self.arcs[b] = s
    self.stack.append(s)
    self.shift()

  def reduce_arc(self):
    o1 = self.stack.pop()

  def is_tree_final(self):  # return boolean
    return (not self.buffer) and (len(self.stack) == 1)

  def print_configuration(self):
    s = [self.sentence[i] for i in self.stack]
    b = [self.sentence[i] for i in self.buffer]
    print(s, b)
    print(self.arcs)

In [None]:
sentence = ["<ROOT>", "He", "began", "to", "write", "again", "."]
gold = [-1, 2, 0, 4, 2, 4, 2 ]

parser = ArcEager(sentence)
parser.print_configuration()

['<ROOT>'] ['He', 'began', 'to', 'write', 'again', '.']
[-1, -1, -1, -1, -1, -1, -1]


In [None]:
parser.shift()
parser.print_configuration()

['<ROOT>', 'He'] ['began', 'to', 'write', 'again', '.']
[-1, -1, -1, -1, -1, -1, -1]


In [None]:

print(parser.stack)
parser.left_arc()
parser.print_configuration()
print(parser.stack)


[0, 1]
['<ROOT>'] ['began', 'to', 'write', 'again', '.']
[-1, 2, -1, -1, -1, -1, -1]
[0]


In [None]:
parser.right_arc()
parser.print_configuration()

['<ROOT>', 'began'] ['to', 'write', 'again', '.']
[-1, 2, 0, -1, -1, -1, -1]


In [None]:
parser.reduce_arc()
parser.print_configuration()

['<ROOT>'] ['to', 'write', 'again', '.']
[-1, 2, 0, -1, -1, -1, -1]


In [None]:
parser.is_tree_final()

False

## Oracle

In [None]:
class Oracle:
  def __init__(self, parser, gold_tree):
    self.parser = parser
    self.gold = gold_tree

  def is_left_arc_gold(self):
    if len(self.parser.buffer) == 0:
       return False # the buffer can not be empty

    s = self.parser.stack[len(self.parser.stack)-1]
    b = self.parser.buffer[0]

## if it has head or if it is the root node then false
    if self.parser.arcs[s] != -1 or s == self.parser.stack[0]:
      return False

    if self.gold[s] == b:  # if it is in the gold tree
      return True

    return False

  def is_right_arc_gold(self):

    if len(self.parser.buffer) == 0:
       return False #if buffer is empty then false

    o1 = self.parser.stack[len(self.parser.stack) - 1]
    o2 = self.parser.buffer[0]

    if self.gold[o2] != o1:
          return False # if it is in the gold tree

    return True

  def is_reduce_gold(self):

    if (self.is_left_arc_gold() or self.is_right_arc_gold()):
      return False # In order to dictate transition precedence of the parser

    #If it has head and stack is not empty
  # RE if there is a word v on the left of i(topmost stack element) in stack
# directly connected(head or dependent) with j(first element of buffer) in gold tree
    if self.parser.arcs[self.parser.stack[len(self.parser.stack)-1]] != -1 and len(self.parser.stack) > 1:
      if len(self.parser.buffer) != 0: #If buffer is not empty, a word v should be left of i(topmost stack element)
         for v in self.parser.stack:# but as i already does not have arc with j, it can not be problem here (left and right arc gold was false)
            if self.gold[v] == self.parser.buffer[0] or self.gold[self.parser.buffer[0]] == v:
               return True
      else: # Buffer is empty and we do not have left or right arc..
         return True  # in order to get the final tree we empty stack
                        # (we already have final configuration)
    return False

  def is_shift_gold(self):

    if len(self.parser.buffer) == 0:
      return False #Buffer is empty

    #This dictates transition precedence of the parser
    if (self.is_left_arc_gold() or self.is_right_arc_gold() or self.is_reduce_gold() ):
      return False

    return True


In [None]:
sentence1 = ["<ROOT>", "He", "began", "to", "write", "again", "."]
sentence = ["<ROOT>", "He", "wrote", "her", "a" ,"letter", "."]
gold1 = [-1, 2, 0, 4, 2, 4, 2 ]
gold = [-1, 2, 0, 2, 5, 2, 2]
parser = ArcEager(sentence)
oracle = Oracle(parser, gold)

parser.print_configuration()

['<ROOT>'] ['He', 'wrote', 'her', 'a', 'letter', '.']
[-1, -1, -1, -1, -1, -1, -1]


In [None]:
print("Left Arc: ", oracle.is_left_arc_gold())
print("Right Arc: ", oracle.is_right_arc_gold())
print("Reduce: ", oracle.is_reduce_gold())
print("Shift: ", oracle.is_shift_gold())

Left Arc:  False
Right Arc:  False
Reduce:  False
Shift:  True


In [None]:
print('Initial configuration:')
parser.print_configuration()
while not parser.is_tree_final():
  if oracle.is_left_arc_gold():
    parser.left_arc()
    print('left arc:')
    parser.print_configuration()
  elif oracle.is_right_arc_gold():
    parser.right_arc()
    print('right arc:')
    parser.print_configuration()
  elif oracle.is_reduce_gold():
    parser.reduce_arc()
    print('reduce arc:')
    parser.print_configuration()
  elif oracle.is_shift_gold():
    parser.shift()
    print('shift arc:')
    parser.print_configuration()
# print(parser.arcs)
print("GOLD TREE:\n", oracle.gold)

Initial configuration:
['<ROOT>'] ['He', 'wrote', 'her', 'a', 'letter', '.']
[-1, -1, -1, -1, -1, -1, -1]
shift arc:
['<ROOT>', 'He'] ['wrote', 'her', 'a', 'letter', '.']
[-1, -1, -1, -1, -1, -1, -1]
left arc:
['<ROOT>'] ['wrote', 'her', 'a', 'letter', '.']
[-1, 2, -1, -1, -1, -1, -1]
right arc:
['<ROOT>', 'wrote'] ['her', 'a', 'letter', '.']
[-1, 2, 0, -1, -1, -1, -1]
right arc:
['<ROOT>', 'wrote', 'her'] ['a', 'letter', '.']
[-1, 2, 0, 2, -1, -1, -1]
shift arc:
['<ROOT>', 'wrote', 'her', 'a'] ['letter', '.']
[-1, 2, 0, 2, -1, -1, -1]
left arc:
['<ROOT>', 'wrote', 'her'] ['letter', '.']
[-1, 2, 0, 2, 5, -1, -1]
reduce arc:
['<ROOT>', 'wrote'] ['letter', '.']
[-1, 2, 0, 2, 5, -1, -1]
right arc:
['<ROOT>', 'wrote', 'letter'] ['.']
[-1, 2, 0, 2, 5, 2, -1]
reduce arc:
['<ROOT>', 'wrote'] ['.']
[-1, 2, 0, 2, 5, 2, -1]
right arc:
['<ROOT>', 'wrote', '.'] []
[-1, 2, 0, 2, 5, 2, 2]
reduce arc:
['<ROOT>', 'wrote'] []
[-1, 2, 0, 2, 5, 2, 2]
reduce arc:
['<ROOT>'] []
[-1, 2, 0, 2, 5, 2, 2]
GOLD 

## Dataset

In [None]:
dataset = load_dataset('universal_dependencies', 'en_lines', split="train")

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading metadata: 0.00B [00:00, ?B/s]

Downloading readme: 0.00B [00:00, ?B/s]

Downloading and preparing dataset universal_dependencies/en_lines to /root/.cache/huggingface/datasets/universal_dependencies/en_lines/2.7.0/1ac001f0e8a0021f19388e810c94599f3ac13cc45d6b5b8c69f7847b2188bdf7...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/580k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/199k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/181k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/3176 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1032 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1035 [00:00<?, ? examples/s]

Dataset universal_dependencies downloaded and prepared to /root/.cache/huggingface/datasets/universal_dependencies/en_lines/2.7.0/1ac001f0e8a0021f19388e810c94599f3ac13cc45d6b5b8c69f7847b2188bdf7. Subsequent calls will reuse this data.


In [None]:
# info about dataset
print(len(dataset))
print(dataset[1].keys())

# we look into the second sentence in the dataset and print its tokens and (gold) dependency tree
print(dataset[1]["tokens"])
print(dataset[1]["head"])

3176
dict_keys(['idx', 'text', 'tokens', 'lemmas', 'upos', 'xpos', 'feats', 'head', 'deprel', 'deps', 'misc'])
['About', 'ANSI', 'SQL', 'query', 'mode']
['5', '5', '2', '5', '0']


## Data setup (Create training data and iterable dataloaders)

In [None]:
# the function returns whether a tree is projective or not. It is currently
# implemented inefficiently by brute checking every pair of arcs.
def is_projective(tree):
  for i in range(len(tree)):
    if tree[i] == -1:
      continue
    left = min(i, tree[i])
    right = max(i, tree[i])

    for j in range(0, left):
      if tree[j] > left and tree[j] < right:
        return False
    for j in range(left+1, right):
      if tree[j] < left or tree[j] > right:
        return False
    for j in range(right+1, len(tree)):
      if tree[j] > left and tree[j] < right:
        return False

  return True

# the function creates a dictionary of word/index pairs: our embeddings vocabulary
# threshold is the minimum number of appearance for a token to be included in the embedding list
def create_dict(dataset, threshold=3):
  dic = {}  # dictionary of word counts
  for sample in dataset:
    for word in sample['tokens']:
      if word in dic:
        dic[word] += 1
      else:
        dic[word] = 1

  map = {}  # dictionary of word/index pairs. This is our embedding list
  map["<pad>"] = 0
  map["<ROOT>"] = 1
  map["<unk>"] = 2 #used for words that do not appear in our list

  next_indx = 3
  for word in dic.keys():
    if dic[word] >= threshold:
      map[word] = next_indx
      next_indx += 1

  return map

In [None]:
train_dataset = load_dataset('universal_dependencies', 'en_lines', split="train")
dev_dataset = load_dataset('universal_dependencies', 'en_lines', split="validation")
test_dataset = load_dataset('universal_dependencies', 'en_lines', split="test")

# remove non-projective sentences: heads in the gold tree are strings, we convert them to int
train_dataset = [sample for sample in train_dataset if is_projective([-1] + [int(head) for head in sample["head"]])]
# create the embedding dictionary
emb_dictionary = create_dict(train_dataset)

print("Number of samples:")
print("Train:\t", len(train_dataset)) #(train is the number of samples without the non-projective)
print("Dev:\t", len(dev_dataset))
print("Test:\t", len(test_dataset))



Number of samples:
Train:	 2922
Dev:	 1032
Test:	 1035


In [None]:
def process_sample(sample, get_gold_path = False):

  # put sentence and gold tree in our format
  sentence = ["<ROOT>"] + sample["tokens"]
  gold = [-1] + [int(i) for i in sample["head"]]  #heads in the gold tree are strings, we convert them to int

  # put sentence in a format for bert

  sentence_bert = "<ROOT>" + sample["text"]
  # embedding ids of sentence words
  enc_sentence = [emb_dictionary[word] if word in emb_dictionary else emb_dictionary["<unk>"] for word in sentence]

  # gold_path and gold_moves are parallel arrays whose elements refer to parsing steps
  gold_path = []   # record two topmost stack tokens and first buffer token for current step
  gold_moves = []  # contains oracle (canonical) move for current step: 0 is left, 1 right, 2 shift

  if get_gold_path:  # only for training
    parser = ArcEager(sentence)
    oracle = Oracle(parser, gold)

    while not parser.is_tree_final():

      # save configuration
      configuration = [parser.stack[len(parser.stack)-2], parser.stack[len(parser.stack)-1]]
      if len(parser.buffer) == 0:
        configuration.append(-1)
      else:
        configuration.append(parser.buffer[0])
      gold_path.append(configuration)

      # save gold move
      if oracle.is_left_arc_gold():
        gold_moves.append(0)
        parser.left_arc()
      elif oracle.is_right_arc_gold():
        parser.right_arc()
        gold_moves.append(1)
      elif oracle.is_reduce_gold():
        parser.reduce_arc()
        gold_moves.append(2)
      elif oracle.is_shift_gold():
        parser.shift()
        gold_moves.append(3)

  return sentence_bert, enc_sentence, gold_path, gold_moves, gold

In [None]:
def prepare_batch(batch_data, get_gold_path=False):
  data = [process_sample(s, get_gold_path=get_gold_path) for s in batch_data]
  # sentences, paths, moves, trees are parallel arrays, each element refers to a sentence
  sentence_bert = [s[0] for s in data]
  sentences = [s[1] for s in data]
  paths = [s[2] for s in data]
  moves = [s[3] for s in data]
  trees = [s[4] for s in data]
  return sentence_bert, sentences, paths, moves, trees

In [None]:
BATCH_SIZE = 32 ## it was 32 but because of input in bert model has limited capacity of 512 we had to change it

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=partial(prepare_batch, get_gold_path=True))
dev_dataloader = torch.utils.data.DataLoader(dev_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=partial(prepare_batch))
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=partial(prepare_batch))

## About Bert model and Arc-Eager dependency parsing.

BERT (Bidirectional Encoder Representations from Transformers) is a pre-trained deep learning model that utilizes the transformer architecture (encoder part only) to generate word embeddings. It considers the contextual information of words in both directions, allowing it to capture comprehensive semantic understanding. BERT uses what is called a WordPiece tokenizer. It works by splitting words either into the full forms (e.g., one word becomes one token) or into word pieces — where one word can be broken into multiple tokens.

The Arc-Eager algorithm is a transition-based parsing algorithm used for dependency parsing. It incrementally constructs a dependency tree for a sentence by applying a set of predefined actions. These actions, such as shift, reduce, left arc, and right arc, manipulate a stack and a buffer to determine the next step in building the tree. The algorithm follows specific rules and conditions to ensure the correct formation of dependency relations between words in the sentence.

## Bert model for Neural Dependency Parsing

The NetBert class represents a neural network model that combines BERT (Bidirectional Encoder Representations from Transformers) for generating contextualized word embeddings and the Arc-Eager algorithm for dependency parsing. This model follows the same methods of our baseline model with Bi-lstm. Instead of using the hidden representations from Bi-lstm, we use contextual embeddings of the corresponding words in the configuration steps. Here's an overview of the modified model:

The model consists of the following components:

BERT Base: The bert_base represents a pre-trained BERT model used for obtaining contextualized word embeddings.
Tokenizer: The tokenizer is responsible for tokenizing input sentences and converting them into token IDs.
Feedforward Neural Network: The model incorporates a feedforward neural network with two linear layers (w1 and w2), an activation function (activation), and a softmax function (softmax). The input size of the network is three times the BERT embedding size (3*BERT_EMBEDDING_SIZE), and the output size is 4, representing the scores for each possible action in the Arc-Eager algorithm.
Dropout: The model includes a dropout layer (dropout) to prevent overfitting during training.

The forward method takes sentence_bert (a batch of sentences to feed BERT) and paths (parser configurations) as inputs. First, BERT embeddings are obtained using the bert_embeddings method. Then, the get_mlp_input method generates the MLP input based on the parser configurations and the obtained BERT embeddings. The feedforward neural network is applied to the MLP input, and the output scores for each possible action are obtained. The softmax function is applied to obtain a probability distribution over the actions, and the resulting probabilities are returned.

The infere method is used for inference. It takes a batch of sentences and sentence_bert as inputs. For each sentence in the batch, a parser is initialized using the Arc-Eager algorithm. While parsing is not complete for any of the parsers, the current configurations are retrieved, and the MLP input is generated using the get_mlp_input method. The feedforward neural network is used to obtain the output scores for each possible action, and the action with the highest score is selected. The parsers are updated accordingly. Finally, the method returns the predicted dependency arcs for each parser.

The bert_embeddings method tokenizes the input sentences using the tokenizer, obtains BERT embeddings for each token, and aggregates them to obtain word-level embeddings. The resulting embeddings are padded and returned as batch embeddings.

The get_mlp_input method generates the MLP input for the feedforward neural network. It iterates over the configurations for each sentence in the batch, concatenates the corresponding BERT embeddings based on the configuration indices, and returns the resulting input tensor.

The get_configurations method generates the current configurations for each parser based on the stack, buffer, and arcs.

The parsed_all method checks if parsing is complete for all parsers.

The parse_step method selects and performs the next move based on the output scores obtained from the feedforward neural network. It follows the rules of the Arc-Eager algorithm, considering conditions such as left arc, right arc, reduce arc, and shift. The parsers are updated accordingly.

Overall, the modified NetBert model combines BERT for generating contextualized word embeddings and a feedforward neural network for predicting the next parsing action based on the Arc-Eager algorithm.

In [None]:
BERT_EMBEDDING_SIZE = 768
MLP_SIZE = 200
DROPOUT = 0.5
EPOCHS = 15
LR = 0.00001   # learning rate


In [None]:
from transformers import AutoTokenizer, AutoModel
from torch.nn.utils.rnn import pad_sequence
# load BERT
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
bert_base = AutoModel.from_pretrained("bert-base-uncased").to(device)

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt: 0.00B [00:00, ?B/s]

Downloading (…)/main/tokenizer.json: 0.00B [00:00, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
class NetBert(nn.Module):

  def __init__(self, device):
    super(NetBert, self).__init__()
    self.device = device
    self.bert_base = bert_base.to(self.device)
    self.tokenizer = tokenizer
    # initialize feedforward
    self.w1 = torch.nn.Linear(3*BERT_EMBEDDING_SIZE, MLP_SIZE, bias=True).to(self.device)
    self.activation = torch.nn.Tanh()
    self.w2 = torch.nn.Linear(MLP_SIZE, 4, bias=True).to(self.device)
    self.softmax = torch.nn.Softmax(dim=-1)

    self.dropout = torch.nn.Dropout(DROPOUT)


  def forward(self, sentence_bert, paths):
    # get embeddings
    embeddings_bert = self.bert_embeddings(sentence_bert)


    mlp_input = self.get_mlp_input(paths, embeddings_bert)

    # run the feedforward and get the scores for each possible action
    out = self.mlp(mlp_input)

    return out

  def get_mlp_input(self, configurations, h):
    mlp_input = []
    zero_tensor = torch.zeros(BERT_EMBEDDING_SIZE, requires_grad=False).to(self.device)
    for i in range(len(configurations)): # for every sentence in the batch
      for j in configurations[i]: # for each configuration of a sentence
        mlp_input.append(torch.cat([zero_tensor if j[0]==-1 else h[j[0]][i], zero_tensor if j[1]==-1 else h[j[1]][i], zero_tensor if j[2]==-1 else h[j[2]][i]]))
    mlp_input = torch.stack(mlp_input).to(self.device)
    return mlp_input

  def mlp(self, x):
    return self.softmax(self.w2(self.dropout(self.activation(self.w1(self.dropout(x))))))

  # we use this function at inference time. We run the parser and at each step
  # we pick as next move the one with the highest score assigned by the model
  def infere(self, x, sentence_bert):

    parsers = [ArcEager(i) for i in x]
    embeddings_bert = self.bert_embeddings(sentence_bert)

    while not self.parsed_all(parsers):
      # get the current configuration and score next moves
      configurations = self.get_configurations(parsers)
      mlp_input = self.get_mlp_input(configurations, embeddings_bert)
      mlp_out = self.mlp(mlp_input)
      moves_argm = mlp_out.argmax(-1)
      self.parse_step(parsers, mlp_out)
    return [parser.arcs for parser in parsers]

  def bert_embeddings(self, text):
        batch_embeddings = []
        for sentence in text:

            input_ids = self.tokenizer.encode(sentence, add_special_tokens= False,return_tensors="pt")
            outputs = self.bert_base(input_ids.to(self.device))
            emb = outputs.last_hidden_state[0].to(self.device) # N x H

            # remove special tokens and <ROOT>. we add <ROOT> manually with averaging 3 tokens <,ROOT,>
            input_ids = input_ids[0]

            # aggregate emb for each word
            word_embs = []
            cur_word_emb = [emb[0],emb[1],emb[2]]
            emb = emb[3:]
            input_ids = input_ids[3:]

            for i, token_id in enumerate(input_ids):
                token = self.tokenizer.decode(token_id)
                if token.startswith("##"):
                    cur_word_emb.append(emb[i])
                else:
                    # start a new word
                    if len(cur_word_emb) > 0:
                        word_embs.append(torch.mean(torch.stack(cur_word_emb).to(self.device), dim=0))
                        cur_word_emb = []
                    cur_word_emb.append(emb[i])

            if len(cur_word_emb) > 0:
                word_embs.append(torch.mean(torch.stack(cur_word_emb).to(self.device), dim=0))

            batch_embeddings.append(torch.stack(word_embs).to(self.device))

        return pad_sequence(batch_embeddings).to(self.device)


  def get_configurations(self, parsers):
    configurations = []

    for parser in parsers:
      if parser.is_tree_final():
        conf = [-1, -1, -1]
      else:
        conf = [parser.stack[len(parser.stack)-2], parser.stack[len(parser.stack)-1]]
        if len(parser.buffer) == 0:
          conf.append(-1)
        else:
          conf.append(parser.buffer[0])
      configurations.append([conf])

    return configurations

  def parsed_all(self, parsers):
    for parser in parsers:
      if not parser.is_tree_final():
        return False
      # print(parser.print_configuration())
    return True

  # In this function we select and perform the next move according to the scores obtained.
  # We need to be careful and select correct moves, e.g. don't do a shift if the buffer
  # is empty or a left arc if σ2 is the ROOT. For clarity sake we didn't implement
  # these checks in the parser so we must do them here. This renders the function quite ugly
  def parse_step(self, parsers, moves):
    moves_argm = moves.argmax(-1)
    for i in range(len(parsers)):
      if parsers[i].is_tree_final():
        continue
      else: #consider also the case which toplast element has a head
        if moves_argm[i] == 0:
          if parsers[i].stack[len(parsers[i].stack)-1] != 0:
            if len(parsers[i].buffer) != 0:
              parsers[i].left_arc() #if it is not root in the stack and buffer is not empty
            else: #there are elements in stack also other than root and buffer is empty so we do reduce
              parsers[i].reduce_arc()
          else: #there is only root in stack and buffer is not empty otherwise tree would be final
            if moves[i][1] > moves[i][3]:
              parsers[i].right_arc()  # if score of right_arc is more than shift then right_arc
            else:
              parsers[i].shift()

        elif moves_argm[i] == 1:
          if len(parsers[i].buffer) == 0:
            parsers[i].reduce_arc()
          else:
            parsers[i].right_arc()

        elif moves_argm[i] == 2:
          if (len(parsers[i].stack) > 1 and parsers[i].arcs[parsers[i].stack[len(parsers[i].stack)-1]] != -1) or len(parsers[i].buffer) == 0:
            parsers[i].reduce_arc()
          else:
            if len(parsers[i].stack) == 1:
              if moves[i][1] > moves[i][3]: #being right_arc is more probable than shift do right_arc
                parsers[i].right_arc()
              else:
                parsers[i].shift()
            else:
                if moves[i][0] >= moves[i][1] and moves[i][0] >= moves[i][3]:
                      parsers[i].left_arc()
                elif moves[i][1] >= moves[i][0] and moves[i][1] >= moves[i][3]:
                  parsers[i].right_arc()
                else:
                  parsers[i].shift()
        elif moves_argm[i] ==3:
          if len(parsers[i].buffer) != 0:
            parsers[i].shift()
          else:
            parsers[i].reduce_arc()

In [None]:
def evaluate(gold, preds):
  total = 0
  correct = 0

  for g, p in zip(gold, preds):
    for i in range(1,len(g)):
      total += 1
      if g[i] == p[i]:
        correct += 1

  return correct/total

In [None]:
def train(model, dataloader, criterion, optimizer):
  model.train()
  total_loss = 0
  count = 0

  for batch in dataloader:
    optimizer.zero_grad()
    sentence_bert, sentences, paths, moves, trees = batch

    out = model(sentence_bert, paths).to(device)
    labels = torch.tensor(sum(moves, [])).to(device) #sum(moves, []) flatten the array
    loss = criterion(out, labels)

    count +=1
    total_loss += loss.item()

    loss.backward()
    optimizer.step()

  return total_loss/count

def test(model, dataloader):
  model.eval()

  gold = []
  preds = []

  for batch in dataloader:
    sentence_bert, sentences, paths, moves, trees = batch
    with torch.no_grad():
      pred = model.infere(sentences, sentence_bert)

    gold += trees
    preds += pred

  return evaluate(gold, preds)

## Evaluation:  unlabeled attachment score (UAS)

In [None]:
# device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# device = torch.device("cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)

model = NetBert(device)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

for epoch in range(EPOCHS):
  avg_train_loss = train(model, train_dataloader, criterion, optimizer)
  val_uas = test(model, dev_dataloader)

  print("Epoch: {:3d} | avg_train_loss: {:5.3f} | dev_uas: {:5.3f} |".format( epoch, avg_train_loss, val_uas))

Device: cuda:0
Epoch:   0 | avg_train_loss: 1.271 | dev_uas: 0.495 |
Epoch:   1 | avg_train_loss: 1.072 | dev_uas: 0.573 |
Epoch:   2 | avg_train_loss: 1.024 | dev_uas: 0.607 |
Epoch:   3 | avg_train_loss: 0.996 | dev_uas: 0.623 |
Epoch:   4 | avg_train_loss: 0.974 | dev_uas: 0.635 |
Epoch:   5 | avg_train_loss: 0.957 | dev_uas: 0.645 |
Epoch:   6 | avg_train_loss: 0.942 | dev_uas: 0.658 |
Epoch:   7 | avg_train_loss: 0.929 | dev_uas: 0.663 |
Epoch:   8 | avg_train_loss: 0.917 | dev_uas: 0.682 |
Epoch:   9 | avg_train_loss: 0.905 | dev_uas: 0.698 |
Epoch:  10 | avg_train_loss: 0.895 | dev_uas: 0.708 |
Epoch:  11 | avg_train_loss: 0.885 | dev_uas: 0.712 |
Epoch:  12 | avg_train_loss: 0.875 | dev_uas: 0.727 |
Epoch:  13 | avg_train_loss: 0.867 | dev_uas: 0.736 |
Epoch:  14 | avg_train_loss: 0.861 | dev_uas: 0.741 |


In [None]:
test_uas = test(model, test_dataloader)
print("test_uas: {:5.3f}".format( test_uas))

test_uas: 0.754


## Comparison between the two models and error analysis for the BERT-based model

Comparing the results of the two models, the Bi-LSTM model achieves a test UAS (Unlabeled Attachment Score) of 0.743, while the BERT model achieves a higher test UAS of 0.754.

Throughout the training process, the Bi-LSTM model starts with an average training loss of 1.055 and gradually decreases to 0.804. On the other hand, the BERT model starts with an average training loss of 1.271 and converges to 0.861. The lower average training loss of the BERT model indicates that it learns the data more effectively during training.

In terms of performance on the development set, the Bi-LSTM model achieves a maximum UAS of 0.736, while the BERT model achieves a higher UAS of 0.741. This suggests that the BERT model performs slightly better in capturing the syntactic dependencies between words.

Finally, on the test set, the BERT model outperforms the Bi-LSTM model with a higher UAS of 0.754 compared to 0.743 for the Bi-LSTM model. This indicates that the BERT model generalizes better to unseen data and demonstrates improved performance in dependency parsing tasks.

Overall, based on the provided results, the BERT model shows better performance compared to the Bi-LSTM model in terms of both training loss and UAS scores on the development and test sets.

## Bert-based in closer look
The Bert model is tend to overfitting a lot. With smaller learning rate, we can see that UAS can be constant and even 0 easily and also there are some methods which can be further applied and help more to find other possible improvements. Such as:

False Positives and False Negatives: Look at cases where the model produces incorrect dependencies (false positives) or fails to predict correct dependencies (false negatives). Identifying the types of errors the model makes can guide further analysis and potential improvements.

Out-of-Domain Performance: Assess whether the model's performance varies when applied to data from different domains or genres. Dependency parsing can be sensitive to domain-specific language use, so it's crucial to understand how the model generalizes across different contexts

Plateau Effect: After reaching a dev_uas score of 0.741, the model's performance seems not to stabilize, as the consequent epochs show significant improvements. So, increasing epochs until to see the Plateau effect would increase performance and it would be more insightful to do hyperparameter tuning or model adjustments after that.

...