ML4NLP3 -- Graph parser
======

In this exercise you are given a fully working, although minimalistic, graph-based dependency parser for English.
As it stands, it does not work very well, for several reasons:

* The parser structure is incomplete
* The parser is unbatched
* The parser does not contain evaluation code

Your task is to improve it until you get something decent.
While doing so, you will learn how to build deeper neural networks.


In [None]:
!pip install torchmetrics

Collecting torchmetrics
  Downloading torchmetrics-1.2.1-py3-none-any.whl (806 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.10.0-py3-none-any.whl (24 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.10.0 torchmetrics-1.2.1


In [None]:
from torchmetrics.functional.classification.f_beta import multiclass_fbeta_score, binary_fbeta_score
import torch

In [None]:
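from icecream import ic  # icecream is installed by the `!pip install icecream` cell below: run that cell first

# Micro-averaged multiclass F-score over {0, 1}: equal to accuracy here (2 of 4 correct)
preds  = torch.tensor([1, 0, 0, 1])
target = torch.tensor([1, 1, 0, 0])
ic(multiclass_fbeta_score(preds, target, beta=1.0, average='micro', num_classes=2))

# Binary F-score: only class 1 counts as positive (P = R = 1/2)
ic(binary_fbeta_score(preds, target, beta=1.0, threshold=0.5))

# 2D inputs are flattened: 2 of 8 positions agree
d2preds  = torch.tensor([[0, 0, 0, 1],[0, 0, 0, 1]])
d2target = torch.tensor([[1, 1, 0, 0],[1, 1, 0, 0]])
ic(multiclass_fbeta_score(d2preds, d2target, beta=1.0, average='micro', num_classes=2))

# Float predictions are thresholded at 0.5: no true positive, so F = 0
preds  = torch.tensor([0.0, 0.0, 0.0, 0.9])
target = torch.tensor([1, 1, 0, 0])
ic(binary_fbeta_score(preds, target, beta=1.0, threshold=0.5))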



ic| multiclass_fbeta_score(preds, target, beta=1.0, average='micro', num_classes=2): tensor(0.5000)
ic| binary_fbeta_score(preds, target, beta=1.0, threshold=0.5): tensor(0.5000)
ic| multiclass_fbeta_score(d2preds, d2target, beta=1.0, average='micro', num_classes=2): tensor(0.2500)
ic| binary_fbeta_score(preds, target, beta=1.0, threshold=0.5): tensor(0.)


tensor(0.)

In [None]:
!pip install icecream
from icecream import ic

Collecting icecream
  Downloading icecream-2.1.3-py2.py3-none-any.whl (8.4 kB)
Collecting colorama>=0.3.9 (from icecream)
  Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)
Collecting executing>=0.3.1 (from icecream)
  Downloading executing-2.0.1-py2.py3-none-any.whl (24 kB)
Collecting asttokens>=2.0.1 (from icecream)
  Downloading asttokens-2.4.1-py2.py3-none-any.whl (27 kB)
Installing collected packages: executing, colorama, asttokens, icecream
Successfully installed asttokens-2.4.1 colorama-0.4.6 executing-2.0.1 icecream-2.1.3


The CoNLL-U parsing data
---------------------

You can download the parsing data (the English Web Treebank from Universal Dependencies) by running the following block:

In [None]:
from urllib.request import urlretrieve

urlretrieve('https://raw.githubusercontent.com/UniversalDependencies/UD_English-EWT/master/en_ewt-ud-train.conllu','train.conllu')
urlretrieve('https://raw.githubusercontent.com/UniversalDependencies/UD_English-EWT/master/en_ewt-ud-dev.conllu','dev.conllu')
urlretrieve('https://raw.githubusercontent.com/UniversalDependencies/UD_English-EWT/master/en_ewt-ud-test.conllu','test.conllu')



('test.conllu', <http.client.HTTPMessage at 0x7e70c1604880>)

You can inspect the data to see what it looks like by running the following block:

In [None]:
N = 25  # prints the first 25 lines of the dev file
with open('dev.conllu') as idata:
    for idx, line in enumerate(idata):
        if idx >= N:
            break
        print(line.strip())

# newdoc id = weblog-blogspot.com_nominations_20041117172713_ENG_20041117_172713
# sent_id = weblog-blogspot.com_nominations_20041117172713_ENG_20041117_172713-0001
# newpar id = weblog-blogspot.com_nominations_20041117172713_ENG_20041117_172713-p0001
# text = From the AP comes this story :
1	From	from	ADP	IN	_	3	case	3:case	_
2	the	the	DET	DT	Definite=Def|PronType=Art	3	det	3:det	_
3	AP	AP	PROPN	NNP	Number=Sing	4	obl	4:obl:from	_
4	comes	come	VERB	VBZ	Mood=Ind|Number=Sing|Person=3|Tense=Pres|VerbForm=Fin	0	root	0:root	_
5	this	this	DET	DT	Number=Sing|PronType=Dem	6	det	6:det	_
6	story	story	NOUN	NN	Number=Sing	4	nsubj	4:nsubj	_
7	:	:	PUNCT	:	_	4	punct	4:punct	_

# sent_id = weblog-blogspot.com_nominations_20041117172713_ENG_20041117_172713-0002
# newpar id = weblog-blogspot.com_nominations_20041117172713_ENG_20041117_172713-p0002
# text = President Bush on Tuesday nominated two individuals to replace retiring jurists on federal courts in the Washington area.
1	President	President	PROP
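Each token line holds ten tab-separated columns; comment lines start with `#` and a blank line ends a sentence. Token IDs containing `-` (multiword tokens) or `.` (empty nodes) are the lines the `ConllReader` below skips. A minimal sketch that splits one line into named columns, using the column names of the CoNLL-U specification (`CONLLU_COLUMNS` and `parse_conllu_line` are hypothetical names). Note that the reader below calls the third column (LEMMA in the specification) `low_token`.

```python
# The 10 CoNLL-U columns, in order
CONLLU_COLUMNS = ("ID", "FORM", "LEMMA", "UPOS", "XPOS",
                  "FEATS", "HEAD", "DEPREL", "DEPS", "MISC")

def parse_conllu_line(line):
    """Split one token line of a CoNLL-U file into a column-name -> value dict."""
    values = line.rstrip("\n").split("\t")   # columns are tab-separated
    if len(values) != len(CONLLU_COLUMNS):
        raise ValueError(f"expected 10 columns, got {len(values)}")
    return dict(zip(CONLLU_COLUMNS, values))

# Third token of the first dev sentence shown above
line = "3\tAP\tAP\tPROPN\tNNP\tNumber=Sing\t4\tobl\t4:obl:from\t_"
fields = parse_conllu_line(line)
print(fields["FORM"], fields["HEAD"], fields["DEPREL"], fields["DEPS"])  # AP 4 obl 4:obl:from
```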

Reading data and encoding it into tensors
----------------------------


The `Vocabulary` class implements a vocabulary mapping strings to integers and vice versa.

In [None]:
class Vocabulary:
    """
    This is a class mapping symbols to integers and vice-versa
    """
    def __init__(self,symbols=None):

        self.symb2idx = {}
        self.idx2sym  = []
        if symbols:
            self.update(symbols)

    def __len__(self):
        return len(self.idx2sym)


    def update(self,symbol_list):
        """
        Adds new symbols to the vocabulary if not already in
        """
        for S in symbol_list:
            if S not in self.symb2idx:
                self.symb2idx[S] = len(self.idx2sym)
                self.idx2sym.append(S)

    def rev_lookup(self, idx):
        """
        This is a reverse lookup. Given an integer returns the string
        """
        return self.idx2sym[idx]


    def __call__(self,symbol,fallback=None):
        """
        This is an alias for the lookup method

        Symbol lookup in a vocabulary. Given a symbol returns the int code.
        fallback is a string used to return an ID when the symbol is unknown to the vocabulary
        """
        return self.lookup(symbol,fallback)


    def lookup(self,symbol,fallback=None):
        """
        Symbol lookup in a vocabulary. Given a symbol returns the int code.
        fallback is a string used to return an ID when the symbol is unknown to the vocabulary

        """
        if fallback:
            return self.symb2idx.get(symbol,self.symb2idx[fallback])
        else:
            return self.symb2idx[symbol]
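One common way to handle unknown words (question 4) is to keep only training tokens seen at least `min_count` times. Rare words are then mapped to `[UNK]` during training as well, so its embedding actually gets trained; when the vocabulary contains every training token, `[UNK]` never occurs in training. The helper below is a sketch of this idea with `collections.Counter`; `build_token_list` and the `min_count=2` default are hypothetical choices. Sorting the kept tokens also makes the index assignment identical in every run.

```python
from collections import Counter

def build_token_list(train_tokens, min_count=2, specials=("[ROOT]", "[UNK]")):
    """Hypothetical helper: special symbols first, then the tokens seen
    at least min_count times in train_tokens (a list with repetitions)."""
    counts = Counter(train_tokens)
    kept = sorted(tok for tok, c in counts.items() if c >= min_count)
    return list(specials) + kept

tokens = ["the", "cat", "the", "dog", "the", "cat", "aardvark"]
symbols = build_token_list(tokens)
print(symbols)  # ['[ROOT]', '[UNK]', 'cat', 'the']
```

After reading the training file, `conll_reader.stored_vocab['token']` holds the token list with repetitions, so something like `Vocabulary(symbols=build_token_list(conll_reader.stored_vocab['token']))` could replace the vocabulary construction. Lookups must then use the `"[UNK]"` fallback, as `ConllDataset.__getitem__` already does.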


The Conll Reader
--------------------

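The `ConllReader` reads a CoNLL-U file and generates its sentences one by one, each as a graph: a dict holding a list of nodes and a list of labelled edges. When the enhanced dependencies column is filled, the edges are read from it instead of the basic HEAD/DEPREL columns, so a token may have several governors.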

In [None]:
class ConllReader:

    CONLL_FIELDS = ("tokidx", "token", "low_token","upos","pos","features","head","deplabel")

    def __init__(self,node_attr=("token","upos"),store_vocab=("token","upos","deplabel")):
        self.node_attr    = list(node_attr)
        if "tokidx" not in node_attr:
            self.node_attr.append("tokidx")

        self.stored_vocab = {}
        if store_vocab is not None:
            self.stored_vocab = {elt: [] for elt in store_vocab}


    def get_vocabulary(self,conll_field):
        """
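        Returns the sorted list of distinct symbols for the conll field with name in one of
        the following (the field must also be listed in ``store_vocab``):

        * token
        * low_token
        * pos
        * upos
        * deplabel

        """
        return sorted(set(self.stored_vocab[conll_field]))  # sorted: same index order in every run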

    def __call__(self,filename):
        return self.readfile(filename)


    def readfile(self,filename):

        istream = open(filename)

        sent_struct = {}
        for line in istream:
            line         = line.strip()
            if line and line[0] != "#":
                tokidx, token,low_token,upos,pos,features,headidx,deplabel,extended, _ =  line.split("\t")  # CoNLL-U columns are tab-separated
                features = dict(zip(ConllReader.CONLL_FIELDS,(tokidx, token,low_token,upos,pos,features,headidx,deplabel)))

                if not any(c in ["-","."] for c in tokidx): #skips multi word annotation

                    #extract_edges and nodes
                    edges = sent_struct.get("edges",[])
                    if extended != "_":  # conll extended case (creating graph structure)
                        govlist = extended.replace(" ","").split("|")
                        for gov_chunk in govlist:
                            headidx, deplabel = gov_chunk.split(":")[:2]
                            try:
                                edges.append( {"src":int(headidx),"dst":int(tokidx),"elbl":deplabel})
                            except ValueError:
                                pass
                    else:
                        tokidx, headidx = int(tokidx), int(headidx)
                        edges.append( {"src":headidx,"dst":tokidx,"elbl":deplabel} )

                    sent_struct["edges"] = edges
                    nodes = sent_struct.get("nodes",[])
                    nodes.append( {F:features[F] if F != "tokidx" else int(features[F]) for F in self.node_attr})
                    sent_struct["nodes"] = nodes

                    #update vocabulary if needed
                    for key, value in self.stored_vocab.items():
                        self.stored_vocab[key].append(features[key])

            elif sent_struct:
                yield sent_struct
                sent_struct = {}

        if sent_struct:
            yield sent_struct
        istream.close()
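The enhanced dependencies column (DEPS, the ninth CoNLL-U column) lists every governor of a token as `head:label` pairs separated by `|`, which is what turns a sentence into a graph rather than a tree. The stand-alone function below sketches the conversion the reader performs in its extended case; `deps_to_edges` is a hypothetical name. Unlike `ConllReader`, which keeps only the first part of the label (`obl` for `4:obl:from`), it keeps the full enhanced label. Which label set to predict is a design choice for question 5.

```python
def deps_to_edges(dep_id, deps):
    """Turn the DEPS column of token `dep_id` (e.g. '4:obl:from|6:nsubj')
    into (governor, dependent, label) triples."""
    edges = []
    if deps == "_":                          # no enhanced annotation
        return edges
    for chunk in deps.split("|"):
        head, label = chunk.split(":", 1)    # the label may itself contain ':'
        try:
            edges.append((int(head), dep_id, label))
        except ValueError:                   # empty-node governor such as '8.1'
            pass
    return edges

print(deps_to_edges(3, "4:obl:from"))         # [(4, 3, 'obl:from')]
print(deps_to_edges(5, "3:nsubj|8.1:nsubj"))  # [(3, 5, 'nsubj')]
```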

This cell exemplifies the usage of the ConllReader. It prints the first graph of the dev set as a dict with a list of edges (governor `src`, dependent `dst`, label `elbl`) and a list of nodes.

In [None]:
conll_reader = ConllReader()
corpus = list(conll_reader('dev.conllu'))
print(corpus[0])                                        #prints out the first graph of the dev set


{'edges': [{'src': 3, 'dst': 1, 'elbl': 'case'}, {'src': 3, 'dst': 2, 'elbl': 'det'}, {'src': 4, 'dst': 3, 'elbl': 'obl'}, {'src': 0, 'dst': 4, 'elbl': 'root'}, {'src': 6, 'dst': 5, 'elbl': 'det'}, {'src': 4, 'dst': 6, 'elbl': 'nsubj'}, {'src': 4, 'dst': 7, 'elbl': 'punct'}], 'nodes': [{'token': 'From', 'upos': 'ADP', 'tokidx': 1}, {'token': 'the', 'upos': 'DET', 'tokidx': 2}, {'token': 'AP', 'upos': 'PROPN', 'tokidx': 3}, {'token': 'comes', 'upos': 'VERB', 'tokidx': 4}, {'token': 'this', 'upos': 'DET', 'tokidx': 5}, {'token': 'story', 'upos': 'NOUN', 'tokidx': 6}, {'token': ':', 'upos': 'PUNCT', 'tokidx': 7}]}


Dataset and DataLoader
-----------------------

This class implements a PyTorch Dataset for the parsing problem.
The Dataset stores the list of graphs read from a file, and the DataLoader is the class used to provide tensors to the neural network.
[See also the pytorch documentation](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html)

Here the DataLoader is the main class of interest.
To get a DataLoader, first create a dataset and then call its `get_loader` method.


In [None]:
from torch.utils.data import Dataset,DataLoader

class ConllDataset(Dataset):
        """
        This class turns a conll into a dataset.
        It also provides a dataloader for training and predicting from this dataset
        """
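        def __init__(self, vocabulary, filename, max_sents=100):
            """
            max_sents keeps only the first sentences of the file (handy for debugging);
            set it to None to use the whole file.
            """
            super(ConllDataset, self).__init__()
            self.vocabulary = vocabulary
            conll_reader = ConllReader()
            self.data = list(conll_reader(filename))[:max_sents]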

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            src_idxes  = torch.tensor([edge["src"] for edge in self.data[idx]["edges"]])
            tgt_idxes  = torch.tensor([edge["dst"] for edge in self.data[idx]["edges"]])
            tokens_ids = torch.tensor([self.vocabulary.lookup("[ROOT]")]+[self.vocabulary.lookup(node["token"],"[UNK]") for node in self.data[idx]["nodes"]])
            return (src_idxes,tgt_idxes,tokens_ids)

        def get_loader(self, batch_size=1, num_workers=4):
            return DataLoader(self, batch_size=batch_size, num_workers=num_workers, collate_fn=None,shuffle=True)
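With `collate_fn=None`, the default collation stacks the tensors of a batch, which fails as soon as two sentences of different lengths share a batch; this is why the loaders use `batch_size=1` (question 6). A padding `collate_fn` could look like the sketch below, where `pad_collate` and its `pad_id` argument are hypothetical names. It pads the token IDs with `torch.nn.utils.rnn.pad_sequence`, builds the zero-padded gold adjacency matrices directly, and returns a boolean mask marking the real (non-padding) cells. Since `[ROOT]` and `[UNK]` already use IDs 0 and 1 in the vocabulary built further down, `pad_id` should be a separate symbol, for instance a `[PAD]` entry you add yourself.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch, pad_id):
    """batch: list of (src_idxes, tgt_idxes, tokens_ids) triples,
    as returned by ConllDataset.__getitem__."""
    src_list, tgt_list, tok_list = zip(*batch)
    lengths = torch.tensor([len(toks) for toks in tok_list])
    tokens = pad_sequence(list(tok_list), batch_first=True, padding_value=pad_id)  # (B, n_max)
    n_max = tokens.size(1)

    # Gold adjacency matrices, zero-padded to (B, n_max, n_max)
    gold = torch.zeros(len(batch), n_max, n_max)
    for b, (src, tgt) in enumerate(zip(src_list, tgt_list)):
        gold[b, src, tgt] = 1.0

    # mask[b, i, j] is True iff both i and j are real positions of sentence b
    real = torch.arange(n_max)[None, :] < lengths[:, None]   # (B, n_max)
    mask = real[:, :, None] & real[:, None, :]                # (B, n_max, n_max)
    return tokens, gold, mask
```

Such a function could be passed as `collate_fn=functools.partial(pad_collate, pad_id=pad_id)` together with a `batch_size` greater than 1. The loss can then be restricted to real cells, e.g. `nn.BCEWithLogitsLoss()(scores[mask], gold[mask])`, where `scores` is the (B, n_max, n_max) output of a batched model.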


The Parsing model
-----------

The parsing model uses an auxiliary class that computes bilinear scores.
Given a word embedding matrix $X$, it computes a weighted adjacency matrix (the parsing graph) $S = X W X^\top$, where $S_{ij}$ scores the edge from governor $i$ to dependent $j$.
In this exercise, we typically pass the same matrix as both arguments of the forward function of this module.


In [None]:
import torch
import torch.nn as nn


class Bilinear(nn.Module):

    def __init__(self,emb_size):
        super(Bilinear, self).__init__()
        self.W = nn.Parameter(torch.empty(emb_size,emb_size))
        nn.init.xavier_normal_(self.W)

    def forward(self,src_embedding,tgt_embedding):
        return src_embedding @ self.W @ tgt_embedding.transpose(-1,-2)
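The expression in `Bilinear.forward` can be checked against the explicit index formula $S_{ij} = x_i^\top W x_j$. The sketch below does so on random tensors (the sizes are arbitrary) and also shows that the same expression already works on a batch of shape (B, n, d): `@` broadcasts over leading dimensions and `transpose(-1, -2)` only swaps the last two axes, which matters for question 6.

```python
import torch

torch.manual_seed(0)
B, n, d = 2, 5, 8                      # batch size, sentence length, embedding size
X = torch.randn(B, n, d)               # a batch of word embedding matrices
W = torch.randn(d, d)                  # plays the role of Bilinear.W

S = X @ W @ X.transpose(-1, -2)        # the expression used in Bilinear.forward
S_explicit = torch.einsum("bid,de,bje->bij", X, W, X)   # S[b, i, j] = x_i^T W x_j

print(S.shape)                         # torch.Size([2, 5, 5])
assert torch.allclose(S, S_explicit, rtol=1e-4, atol=1e-4)
```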



The parsing model also needs an auxiliary feed-forward network module, which you have to write in the cell below:

In [None]:
#Feed Forward Network

#<HERE>
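One possible shape for the feed-forward module of question 1, given as a sketch rather than a reference answer: two linear layers with a ReLU and dropout in between. The class name `FeedForward`, the hidden size and the dropout rate are assumptions; your own design may differ.

```python
import torch
import torch.nn as nn

class FeedForward(nn.Module):
    """Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, in_size, hidden_size, out_size, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, out_size),
        )

    def forward(self, x):
        # nn.Linear acts on the last dimension, so x may be (n, in_size) or (B, n, in_size)
        return self.net(x)

ffn = FeedForward(in_size=8, hidden_size=16, out_size=4)
print(ffn(torch.randn(3, 5, 8)).shape)  # torch.Size([3, 5, 4])
```

In the parser, two such modules (one for governors, one for dependents) would map the LSTM outputs to the two arguments of `Bilinear`. Dropout is only active in training mode, so the module's train/eval mode matters at prediction time.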

In [None]:
import torch.optim as optim

from tqdm.notebook import tqdm #progress bar

class GraphParser(nn.Module):


    def __init__(self,vocabulary,emb_size):

        super(GraphParser, self).__init__()
        self.E          = nn.Embedding(len(vocabulary),emb_size)
        self.bilinear   = Bilinear(emb_size)


    def forward(self,tok_IDs):
        """
        Given token IDs, returns the predicted adjacency matrix
        """

        #Basic parsing
        X = self.E(tok_IDs)
        adjacency = self.bilinear(X, X)
        return adjacency


    def predict(self,predloader):
        """
        Generates graphs in a prediction context
        """
        with torch.no_grad():

            for (src_idx,tgt_idx,tok_IDs) in tqdm(predloader):
                src_idx      = src_idx.squeeze()
                tgt_idx      = tgt_idx.squeeze()
                tok_IDs      = tok_IDs.squeeze()
                predicted    = self.forward(tok_IDs)
                predicted    = predicted  > 0.
                yield (torch.nonzero(predicted),tok_IDs)



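    def train(self,trainloader,epochs):
        # NB: this shadows nn.Module.train(mode). nn.Module.eval() calls self.train(False),
        # which fails with this signature: rename this method (e.g. fit) if you need eval().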

        optimizer = optim.Adam(self.parameters(),lr=0.001)
        loss_fnc   = nn.BCEWithLogitsLoss()

        #computes the gold adjacency matrix from sparse idxes
        def adjacency_fnc(src_idx,tgt_idx,seq_len):
            A = torch.zeros(seq_len,seq_len)
            A[src_idx,tgt_idx] = 1.
            return A


        for e in range(epochs):

            loss_lst = []

            for (src_idx,tgt_idx,tok_IDs) in tqdm(trainloader):
                optimizer.zero_grad()
                src_idx      = src_idx.squeeze() #we currently work in unbatched mode
                tgt_idx      = tgt_idx.squeeze()
                tok_IDs      = tok_IDs.squeeze()
                predicted    = self.forward(tok_IDs)
                loss = loss_fnc(predicted,adjacency_fnc(src_idx,tgt_idx,len(tok_IDs)))
                loss.backward()
                loss_lst.append(loss.item())
                optimizer.step()

            print("Epoch",e," Loss",sum(loss_lst))


####
#Turns a graph back to a Conll string.
def graph2conll(token_vocabulary,src,tgt,tok_IDs):
    """
    Outputs a conll string from the parser output.
    src, tgt : tensors of governor and dependent indices of the predicted edges
    tok_IDs  : token IDs of the sentence, position 0 being [ROOT]
    """
    tokens  = [token_vocabulary.rev_lookup(int(tokid)) for tokid in tok_IDs]
    edges   = list(zip(src.tolist(), tgt.tolist()))

    num_nodes = len(tokens)
    result = [ [] for _ in range(num_nodes) ]
    for govid,depid in edges:
        result[depid].append( (depid,'none',govid) )   # edge labels are not predicted yet

    #make_str
    for idx,elt in enumerate(result):
        if len(elt) == 0: #creates a dummy root link for unconnected nodes
            result[idx] = "\t".join([str(idx),str(tokens[idx]),"_","_","_","_","0",'root',"_","_"])
        else:
            dep,lbl,gov = elt[0]   # the first governor fills the HEAD and DEPREL columns
            enhanced    = '|'.join(["%d:%s"%(g,l) for (_,l,g) in elt])   # all governors go to DEPS
            result[idx] = "\t".join([str(dep), str(tokens[idx]), "_", "_", "_", "_", str(gov), lbl, enhanced, "_"])
    result[0] = "# text = "+' '.join(tokens[1:])   # position 0 is [ROOT]: replaced by the text comment
    return "\n".join(result)
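`predict` keeps every edge whose score (logit) is positive. Since the loss applies a sigmoid to the scores, this is the same as keeping the edges whose probability is above 0.5. The sketch below checks this on a toy 3×3 score matrix and shows how `torch.nonzero` turns the boolean matrix into (governor, dependent) pairs. The column-wise argmax at the end is an alternative decoding, not part of the original parser: it gives exactly one governor per non-root token but does not prevent self-loops or cycles.

```python
import torch

# Toy score (logit) matrix: row = governor, column = dependent, index 0 = [ROOT]
scores = torch.tensor([[-1.0,  2.0, -0.5],
                       [ 0.3, -2.0, -3.0],
                       [-0.1,  1.5, -4.0]])

# Thresholding logits at 0 is the same as thresholding sigmoid probabilities at 0.5
assert torch.equal(scores > 0, torch.sigmoid(scores) > 0.5)

edges = torch.nonzero(scores > 0)     # one (governor, dependent) row per kept edge
src, tgt = edges.T[0], edges.T[1]
print(edges.tolist())                 # [[0, 1], [1, 0], [2, 1]]

# Alternative decoding: one governor per non-root token (column-wise argmax)
heads = scores[:, 1:].argmax(dim=0)
print(heads.tolist())                 # [0, 0]
```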

In [None]:
#make vocabulary
conll_reader = ConllReader()
corpus = list(conll_reader('train.conllu'))
vocab = Vocabulary(symbols = ["[ROOT]","[UNK]"])
vocab.update(list(conll_reader.get_vocabulary('token')))

#make train & dev loaders (only the first 100 sentences of each file, see ConllDataset)
trainset = ConllDataset(vocab,'train.conllu').get_loader()
devset   = ConllDataset(vocab,'dev.conllu').get_loader()
#testset  = ConllDataset(vocab,'test.conllu').get_loader()

#train model (on the small dev loader while debugging)
gp = GraphParser(vocab,512)
# gp.train(devset,10)

#predict and display
for edges,tokens in gp.predict(devset):
    src,tgt = edges.T[0],edges.T[1]

    ic()
    print(graph2conll(vocab,src,tgt,tokens))
    print()
    break







Epoch 0  Loss 678.7554614543915



KeyboardInterrupt: ignored

In [None]:
print(next(iter(devset)))  # one batch: [src_idxes, tgt_idxes, token_ids], each with a leading batch dimension of size 1



[tensor([[ 4,  4,  4,  0,  6,  4, 10,  9, 10,  4, 13, 13, 10, 17, 17, 17, 10, 23,
         23, 23, 23, 23, 10, 17, 25, 23, 29, 29, 29,  4, 29,  4]]), tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
         19, 20, 21, 22, 23, 23, 24, 25, 26, 27, 28, 29, 30, 31]]), tensor([[    0,  2286, 19582,  5652,  1352, 15281, 12374,  6444, 17592,     1,
           571, 16298, 17592,  7870,   324, 19153, 19403, 11359,  6444, 19403,
          9251, 17592, 10136,  1106, 10502,   621,   324,  9259, 18881,  5688,
         11598, 10731]])]


Questions
-----------

**For each question, you have to state explicitly where you added something in the code.
Answering questions 1 to 3 and 8 will get you 10/20. Answering more questions increases the grade.**

**Besides code, adding comments and explanations to your answers is required.**



1. Add the key missing components to this parser: an LSTM and two feed-forward networks that specialize word embeddings for governors and dependents. To do that, you will implement a FeedForward module by yourself. See also the course handout or [(Dozat and Manning 2018)](https://aclanthology.org/P18-2077.pdf)
2. Add a validation function to the parser. During training, you should be able to report the loss on both the training set and the validation set.
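3. Add an evaluation metric to your validation function that measures the F-score of your parser:
$$F = \frac{2 P R}{P+R}$$

$$P = \frac{\text{numPredictedCorrect}}{\text{totalPredicted}}$$

$$R = \frac{\text{numPredictedCorrect}}{\text{totalCorrect}}$$
where numPredictedCorrect is the number of predicted edges that are also gold edges, totalPredicted the number of predicted edges and totalCorrect the number of gold edges.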

4. Explain how you manage unknown words. You may modify the existing code to improve this.
5. Add code to predict edge labels
6. Add code to perform batching. This requires providing a custom `collate_fn` to the [DataLoader](https://pytorch.org/docs/stable/data.html) that pads the sequences
7. Add and test anything you find useful to this parser (including materials from other exercises or other classes)
8. Train and test your parser until you get decent results
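A minimal sketch of the edge-level F-score of question 3, assuming that each sentence's predicted and gold graphs are given as sets of (governor, dependent) pairs, built for instance from `torch.nonzero` on the thresholded scores and from `src_idx`/`tgt_idx`. Counts are summed over the whole corpus before dividing (micro-averaging); adding the label to each tuple gives a labelled score. `edge_prf` is a hypothetical name. Note that `multiclass_fbeta_score(..., average='micro', num_classes=2)`, as used at the top of this notebook, also rewards every absent edge predicted as absent: micro-averaged over both classes it equals accuracy, which the many non-edges of an adjacency matrix dominate. `binary_fbeta_score` on one flattened adjacency matrix counts only the edge class and gives the same value as this sketch for a single sentence.

```python
def edge_prf(sentences):
    """Corpus-level (micro-averaged) edge precision, recall and F-score.
    `sentences` yields (predicted_edges, gold_edges) pairs, each a set of
    (governor, dependent) tuples for one sentence."""
    n_correct = n_pred = n_gold = 0
    for pred, gold in sentences:
        n_correct += len(pred & gold)   # numPredictedCorrect
        n_pred += len(pred)             # totalPredicted
        n_gold += len(gold)             # totalCorrect
    p = n_correct / n_pred if n_pred else 0.0
    r = n_correct / n_gold if n_gold else 0.0
    f = 2 * p * r / (p + r) if p + r else 0.0
    return p, r, f

gold = {(3, 1), (3, 2), (4, 3), (0, 4)}
pred = {(3, 1), (4, 2), (0, 4)}
p, r, f = edge_prf([(pred, gold)])     # P = 2/3, R = 1/2, F = 4/7
```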

*Note. Without batching, training is slow. I advise you to set up your code on a small dataset such as the dev set (or an even smaller one), and to use the full training set once everything works well. The edge label prediction and batching questions are harder than the others.*

    