Assignment and instructions
===========================

The goal of this project is to complete the implementation of a transition-based dependency parser, starting from the partial implementation given below, which is limited to tagging.

The parser relies on a neural model in which the input sentence is encoded by two stacked bi-LSTMs, whose outputs serve as features for the transition decisions. The structured prediction model is a local classifier.

The parser uses the arc-standard algorithm and may rely on greedy search to find a solution.
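
To make the arc-standard system concrete, here is a minimal, standalone sketch of its three transitions and of a greedy search loop. It assumes the configuration layout of `init_configuration` (stack, buffer, arcs), the action codes `"S"`, `"L"`, `"R"` of `ArcStandardParser`, and arcs stored as `(gov_idx, label, dep_idx)` triples; `greedy_parse` and `toy_choose` are hypothetical helpers, the latter standing in for the classifier plus the run-time mask. It is one possible design, not a reference solution.

```python
from typing import Callable, List, Optional, Tuple

# Assumed representation (mirrors init_configuration): a configuration is a
# triple (stack, buffer, arcs) of word indexes, 0 being the dummy root;
# arcs are (gov_idx, label, dep_idx) triples.
Arc = Tuple[int, str, int]
Config = Tuple[List[int], List[int], List[Arc]]
Action = Tuple[str, Optional[str]]

def shift(config: Config) -> Config:
    S, B, A = config
    return (S + [B[0]], B[1:], A)  # move the first buffer word onto the stack

def left_arc(config: Config, label: str) -> Config:
    S, B, A = config
    # s0 (top of the stack) governs s1; s1 is removed from the stack
    return (S[:-2] + [S[-1]], B, A + [(S[-1], label, S[-2])])

def right_arc(config: Config, label: str) -> Config:
    S, B, A = config
    # s1 governs s0; s0 is removed from the stack
    return (S[:-1], B, A + [(S[-2], label, S[-1])])

def is_final(config: Config) -> bool:
    S, B, _ = config
    return S == [0] and not B

def greedy_parse(n: int, choose: Callable[[Config], Action]) -> List[Arc]:
    """Greedy search: apply the chosen action until the configuration is final.
    n counts the dummy root; `choose` stands in for classifier + run-time mask."""
    config: Config = ([0], list(range(1, n)), [])
    while not is_final(config):
        atype, label = choose(config)
        if atype == "S":
            config = shift(config)
        elif atype == "L":
            config = left_arc(config, label)
        else:
            config = right_arc(config, label)
    return config[2]

def toy_choose(config: Config) -> Action:
    # Hypothetical stand-in classifier: shift while possible, then attach right to left
    return ("S", None) if config[1] else ("R", "dep")

print(greedy_parse(4, toy_choose))  # [(2, 'dep', 3), (1, 'dep', 2), (0, 'dep', 1)]
```

Returning fresh lists instead of mutating the configuration costs a few copies but keeps earlier configurations intact, which helps if greedy search is later replaced by beam search.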

The dataset is a version of the Sequoia corpus:
    https://gforge.inria.fr/frs/?view=shownotes&group_id=3597&release_id=9064
which has the property of being projective.
The data file used is `sequoia-corpus.np_conll`.


Objectives
----------
The parser to be built is a hybrid between the one proposed by Kiperwasser and Goldberg (2016)
and the multi-task constituency parsing system described by Coavoux and Crabbé (2017).

Let $\mathbf{a} = a_1\ldots a_T$ be a derivation sequence extracted from a treebank by an oracle, where each $a_i$ is an action, and let $\mathbf{w}=w_1\ldots w_n$ be the corresponding word sequence.

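The conditional log-probability of a dependency tree decomposes, by the chain rule, into a sum of local action log-probabilities, each action being conditioned on the actions that precede it (that is, on the current configuration):

$
\begin{align*}
\log P(\mathbf{a} | \mathbf{w};\boldsymbol\theta) =  \sum_{i=1}^T \log P(a_i| a_1\ldots a_{i-1}, \mathbf{w};\boldsymbol\theta ) 
\end{align*}
$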
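
In code, this sum is what PyTorch's `nn.NLLLoss` computes from log-softmax outputs, up to sign and reduction. The toy check below uses random scores in place of the parser's action scores (an assumption for illustration). Note that the default `reduction='mean'` used in `train_model` divides each sentence's sum by its number of steps $T$.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
T, n_actions = 5, 7                         # toy derivation: T steps, 7 candidate actions
scores = torch.randn(T, n_actions)          # stand-in for the parser's action scores
gold = torch.tensor([0, 3, 1, 6, 2])        # index of the reference action a_i at each step

log_probs = F.log_softmax(scores, dim=1)    # log-probability of every candidate, at every step
log_p_derivation = log_probs[torch.arange(T), gold].sum()  # sum of the T gold log-probabilities

nll_sum = nn.NLLLoss(reduction="sum")(log_probs, gold)   # = -log P(a | w)
nll_mean = nn.NLLLoss()(log_probs, gold)                 # default reduction: divided by T
print(torch.allclose(-log_p_derivation, nll_sum), torch.allclose(nll_mean, nll_sum / T))  # → True True
```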

Let $\mathbf{t} = t_1\ldots t_n$ be the tag sequence corresponding to $w_1\ldots w_n$; the conditional log-probability of this tag sequence decomposes as:

$
\begin{align*}
\log P(\mathbf{t} | \mathbf{w};\boldsymbol\theta) =  \sum_{i=1}^n \log P(t_i| \mathbf{w};\boldsymbol\theta ) 
\end{align*}
$

Parameters are learned by maximizing the likelihood of a treebank of $N$ sentences $(\mathbf{w}_j, \mathbf{t}_j, \mathbf{a}_j)$ under two objectives at once:

$
\begin{align*}
\hat{\boldsymbol\theta} = \mathop{\mathrm{argmax}}_{\boldsymbol\theta} \sum_{j=1}^N \Big[ \log P(\mathbf{a}_j | \mathbf{w}_j;\boldsymbol\theta) + \log P(\mathbf{t}_j | \mathbf{w}_j;\boldsymbol\theta) \Big]
\end{align*}
$
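
`train_model` optimizes this sum by calling `backward()` once per objective, with `retain_graph=True` on the first call so that the graph of the shared bi-LSTM encoder survives for the second one. Because gradients accumulate in `.grad`, this should match a single `backward()` on the summed loss. The toy check below (made-up sizes, and a simplification with one "action" per word) illustrates the equivalence on the encoder's parameters.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
encoder = nn.LSTM(4, 3, bidirectional=True)   # shared encoder (toy sizes)
tag_head = nn.Linear(6, 5)                    # tagger output layer: 5 toy tags
act_head = nn.Linear(6, 7)                    # parser output layer: 7 toy actions, one step per word
x = torch.randn(6, 1, 4)                      # 6 "words", batch of 1, as in forward_encoding
tags = torch.randint(0, 5, (6,))
acts = torch.randint(0, 7, (6,))
nll = nn.NLLLoss()

def two_losses():
    h, _ = encoder(x)
    h = h.squeeze(1)                          # (6, 6): one bi-LSTM vector per word
    return (nll(F.log_softmax(tag_head(h), dim=1), tags),
            nll(F.log_softmax(act_head(h), dim=1), acts))

# (a) as in train_model: two backward calls, gradients accumulate in .grad
encoder.zero_grad()
l_tag, l_act = two_losses()
l_tag.backward(retain_graph=True)             # keep the shared graph alive for the second call
l_act.backward()
grads_a = [p.grad.clone() for p in encoder.parameters()]

# (b) a single backward call on the summed objective
encoder.zero_grad()
l_tag, l_act = two_losses()
(l_tag + l_act).backward()
grads_b = [p.grad.clone() for p in encoder.parameters()]

print(all(torch.allclose(a, b) for a, b in zip(grads_a, grads_b)))  # → True
```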


Specific instructions:
----------------------

1. Read the code below carefully and discuss it among yourselves.
2. Build the transition system:
    * Modify the `index_symbols` method so that, in the same way as the shift action, it encodes the left-arc and right-arc actions as two-element tuples (action type, label), each mapped to an integer.
    * Implement the `shift`, `left_arc` and `right_arc` methods.
    * Modify the `oracle` method so that it implements the static arc-standard oracle.
    * Implement the `forward_parser` method.
3. Implement a `predict_tree(...)` method that builds the predicted dependency tree from the final configuration and returns it as a `DependencyTree` object.
4. Implement an `eval_parser` function that returns the LAS and UAS of your parser.
5. Unknown word handling. After your first evaluations, you may have noticed that the tagger or the parser performs only moderately well. To improve this, we propose adding a module that handles unknown words.
    * Implement a `word_dropout(...,alpha=1.0)` method that takes a list of words and returns it with some words replaced by the unknown-word token. A common choice is to replace a low-frequency word $w$ by the unknown token with probability $p = \frac{\alpha}{count(w)}$, where $\alpha \in [0,1]$, but you are free to choose your own method.
    * Add a character model to the parser. It takes a string as input and returns a vector computed by a character-level RNN. Integrate this model into the parser to improve unknown word handling (**harder exercise**, but you may draw on existing tutorials on the web).
    * Evaluate the impact of your changes and compare them with the baseline. Report your UAS/LAS each time.
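
As a hedged sketch for item 2, here is one way to write a static arc-standard oracle and replay it into a full derivation. It assumes a projective gold tree (which `read_trees` guarantees by filtering), gold edges in the `(gov_idx, label, dep_tag, dep_idx)` format of `DependencyTree.edges`, built arcs as `(gov_idx, label, dep_idx)` triples, and the action codes `"S"`, `"L"`, `"R"`; `static_oracle` and `derivation` are hypothetical names. The key condition is that RIGHT-ARC is only predicted once the word on top of the stack has collected all its gold dependents, since it is popped afterwards.

```python
from typing import List, Optional, Tuple

Edge = Tuple[int, str, str, int]      # (gov_idx, label, dep_tag, dep_idx), as in DependencyTree.edges
Arc = Tuple[int, str, int]            # (gov_idx, label, dep_idx), assumed format of built arcs
Action = Tuple[str, Optional[str]]    # ("S", None), ("L", label) or ("R", label)

def static_oracle(stack: List[int], arcs: List[Arc], edges: List[Edge]) -> Action:
    """One step of a static arc-standard oracle (sketch; assumes a projective gold tree)."""
    gold = {dep: (gov, label) for (gov, label, _tag, dep) in edges}
    if len(stack) >= 2:
        s1, s0 = stack[-2], stack[-1]
        # LEFT-ARC when s0 is the gold head of s1 (the root s1 = 0 has no head)
        if s1 in gold and gold[s1][0] == s0:
            return ("L", gold[s1][1])
        # RIGHT-ARC when s1 is the gold head of s0 and s0 has all its dependents
        if gold[s0][0] == s1:
            n_gold_deps = sum(1 for gov, _ in gold.values() if gov == s0)
            n_built_deps = sum(1 for gov, _, _ in arcs if gov == s0)
            if n_built_deps == n_gold_deps:
                return ("R", gold[s0][1])
    return ("S", None)

def derivation(n: int, edges: List[Edge]) -> List[Action]:
    """Replays the oracle from the initial configuration; n counts the dummy root."""
    stack, buffer, arcs = [0], list(range(1, n)), []
    actions = []
    while not (stack == [0] and not buffer):
        atype, label = static_oracle(stack, arcs, edges)
        if atype == "S":
            if not buffer:
                raise ValueError("oracle is stuck: is the tree projective?")
            stack.append(buffer.pop(0))
        elif atype == "L":
            arcs.append((stack[-1], label, stack[-2]))
            del stack[-2]
        else:
            arcs.append((stack[-2], label, stack[-1]))
            stack.pop()
        actions.append((atype, label))
    return actions

# "Le chat dort": Le <-det- chat <-suj- dort <-root- $ROOT$
edges = [(2, "det", "DET", 1), (3, "suj", "NC", 2), (0, "root", "V", 3)]
print(derivation(4, edges))
# [('S', None), ('S', None), ('L', 'det'), ('S', None), ('L', 'suj'), ('R', 'root')]
```

A sentence of $n$ words always yields $2n$ actions ($n$ shifts and $n$ arcs), which is a cheap sanity check on an oracle.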



References
----------

*  Kiperwasser and Goldberg, "Simple and Accurate Dependency Parsing Using Bidirectional LSTM Feature Representations", _TACL_, 2016.
*  Coavoux and Crabbé, "Multilingual Lexicalized Constituency Parsing with Word-Level Auxiliary Tasks", _EACL_, 2017.

In [None]:
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import sys
import numpy as np
from collections import Counter
from random import random

from tqdm import tqdm #optional, for progress bars in notebook.

Auxiliary data structure
========================

In [None]:
class DependencyTree:

    DUMMY_ROOT = '$ROOT$'
    
    def __init__(self,wordlist,edges=None):
        self.wordlist = wordlist                  #hyp: DUMMY ROOT is wordlist[0]
        self.edges    = edges if edges else []    #edges are 4-tuples (govidx,label,deptag,depidx)
        self.edges.sort(key = lambda x:x[3])

    def tag_list(self):
        return [deptag for (gidx,deplabel,deptag,didx) in self.edges]
    def word_list(self):
        return self.wordlist
    def dlabel_list(self):
        return [deplabel for (gidx,deplabel,deptag,didx) in self.edges]

    #aux functions useful for oracle
    def get_edge(self,gov_idx,dep_idx):
        #gets the edge (with labels) from gov_idx to dep_idx
        #returns either a singleton or an empty list
        return [ (gidx,x,y,didx) for (gidx,x,y,didx) in self.edges \
                                 if gidx == gov_idx and didx == dep_idx]

    def dom_edges(self,gov_idx):
        #gets all edge couples from gov_idx to some dependent node
        return [(gidx,didx) for (gidx,_,_,didx) in self.edges if gidx == gov_idx]
        
    @staticmethod
    def read_tree(istream):
        """
        Reads a Conll-u tree and returns it as a DepTree object
        returns None if no tree has been read
        """
        words = []
        edges = []
        for line in istream:
            line = line.split("#")[0]
            if line and not line.isspace():
                values = line.split()
                words.append(values[1])
                edges.append( (int(values[6]),values[7], values[3], int(values[0])) )
            elif words:
                return DependencyTree([DependencyTree.DUMMY_ROOT] + words,edges)
        return DependencyTree([DependencyTree.DUMMY_ROOT] + words,edges) if words else None
    
    
    def __str__(self):
        """
        Pretty prints a tree for debug
        Returns: a string
        """
        lines =  [(didx,self.wordlist[didx],dtag,dlabel,gidx) for (gidx,dlabel,dtag,didx) in self.edges] 
        lines.sort(key=lambda x:x[0])
        return '\n'.join( [ "%d\t%s\t%s\t%s\t%d"%(L) for L in lines ] )

Reading the data
================

In [None]:

def is_projective(root_idx,edges):
    rspan = [root_idx]
    children = [ didx for (gidx,_,_,didx) in edges if gidx==root_idx ]
    for child in children:
        proj,span = is_projective(child,edges)
        if not proj:
            return False,[]
        rspan.extend(span)
    rspan.sort()
    return  ( all([ jdx == idx+1 for idx,jdx in zip(rspan,rspan[1:])]),rspan )
    
def read_trees(filename):
    """Reads all trees from a CoNLL file and keeps only the projective ones."""
    ilist = []
    with open(filename) as istream:
        dtree = DependencyTree.read_tree(istream)
        while dtree is not None:
            proj,span = is_projective(0,dtree.edges)
            if proj:
                ilist.append(dtree)
            #else:
            #    print('[warning] non projective tree skipped')
            dtree = DependencyTree.read_tree(istream)
    return ilist
        
train_treebank = read_trees("/Users/bcrabbe/parsing-at-diderot/data/train.French.gold.conll")
valid_treebank = read_trees("/Users/bcrabbe/parsing-at-diderot/data/dev.French.gold.conll")
test_treebank  = read_trees("/Users/bcrabbe/parsing-at-diderot/data/test.French.gold.conll")

print('#training trees',len(train_treebank))
print('#validation trees',len(valid_treebank))
print('#test trees',len(test_treebank))


Encoding
========

In [None]:
def make_sym2idx(wordlist,unk_symbol=None):
    """
    Creates a dictionary mapping symbols to integers
    Args:
       wordlist  (list): a list of strings (left unmodified)
       unk_symbol (str): a special default string for symbols unknown to this dictionary
    Returns:
        a dict string => idx mapping words to integer indexes
    """
    symbols = set(wordlist)
    if unk_symbol:
        symbols.add(unk_symbol)
    # sorting makes the indexing reproducible across Python sessions
    # (set iteration order over strings depends on hash randomization)
    return {sym: idx for idx, sym in enumerate(sorted(symbols))}

def code_sequence(symlist,sym2idx,unk_symbol=None):
    """
    Maps a list of string to a list of int (encodes the sentence on integers) 
    Args:
        symlist  (list): a list of strings
        sym2idx  (dict): a dictionary mapping strings to int
        unk_symbol(str): a special default string for the dictionary sym2idx
    Returns a list of int as torch.tensor where words are mapped to their integer indexes.
    """
    def normal_form(symbol):
        return symbol if symbol in sym2idx else unk_symbol
    
    code_list = [sym2idx[normal_form(symbol)] for symbol in symlist]
    return torch.tensor(code_list, dtype=torch.long)
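
For the character model of item 5, words must also be encoded at the character level, with the `unk_char` symbol covering characters unseen in training. The sketch below is a hypothetical `CharEncoder` module (its name, sizes and `char2idx` construction are assumptions): it runs a bi-LSTM over each word's characters and concatenates the last forward and backward states. Its output has size `2 * char_lstm_memory_dim`, so whatever is added to `embedding_dim` in the input size of `tagger_bilstm` must match that width before doing `torch.cat([xword_embedded, xchar_embedded], 1)` in `forward_encoding`.

```python
import torch
import torch.nn as nn

class CharEncoder(nn.Module):
    """Hypothetical character-level bi-LSTM mapping each word (a string) to one vector."""

    def __init__(self, char2idx, char_embedding_dim=16, char_lstm_memory_dim=25, unk_char="@@"):
        super().__init__()
        self.char2idx, self.unk_char = char2idx, unk_char
        self.char_embeds = nn.Embedding(len(char2idx), char_embedding_dim)
        self.char_lstm = nn.LSTM(char_embedding_dim, char_lstm_memory_dim, bidirectional=True)

    def encode_word(self, word):
        unk = self.char2idx[self.unk_char]
        codes = torch.tensor([self.char2idx.get(c, unk) for c in word], dtype=torch.long)
        embedded = self.char_embeds(codes).view(len(codes), 1, -1)   # (chars, batch=1, dim)
        _, (h_n, _) = self.char_lstm(embedded)                       # h_n: (2, 1, char_lstm_memory_dim)
        return torch.cat([h_n[0, 0], h_n[1, 0]])                     # last forward + last backward state

    def forward(self, wordlist):
        return torch.stack([self.encode_word(w) for w in wordlist])  # (len(wordlist), 2 * memory dim)

words = ["$ROOT$", "Le", "chat", "dort"]
chars = sorted(set("".join(words)))
char2idx = {c: i for i, c in enumerate(chars + ["@@"])}   # "@@" plays the role of unk_char
encoder = CharEncoder(char2idx)
xchar_embedded = encoder(words + ["zébu"])   # "z", "é", "b", "u" are unseen: mapped to "@@"
print(xchar_embedded.shape)                  # torch.Size([5, 50])
```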

Model
=====

In [None]:
class ArcStandardParser(nn.Module):
    
    
    #ACTIONS
    SHIFT = "S"
    LARC  = "L"
    RARC  = "R"
    
    #DUMMY TAG
    DTAG  = 'ROOT'
    
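    def __init__(self,embedding_dim,lstm_memory_dim,config_embedding_dim,char_lstm_memory_dim=0):
        """
        Args:
            embedding_dim        (int): size of word embeddings
            lstm_memory_dim      (int): size of the sentence encoder lstm memory 
            config_embedding_dim (int): size of the hidden layer of the output MLP.
            char_lstm_memory_dim (int): size of the char model vector appended to each word embedding
                                        (0 until a char model is added; used by allocate_structure)
        """
        super(ArcStandardParser, self).__init__()
        self.embedding_dim        = embedding_dim
        self.lstm_memory_dim      = lstm_memory_dim
        self.config_embedding_dim = config_embedding_dim
        self.char_lstm_memory_dim = char_lstm_memory_dim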
        
    ### INDEXING ###
    def index_symbols(self,treebank,unk_word,unk_char):
        """
        Indexes the x and y symbols on integers
        Args:
           treebank   (list): a list of DependencyTree objects
           unk_word (string): a symbol to use for unk words
           unk_char (string): a symbol to use for unk char
        Returns:
           (w2idx, tag2idx, a2idx, wcounts): word, tag and action indexes, and word counts
        """
        wlist      = []   
        taglist    = [ArcStandardParser.DTAG]
        for dtree in treebank:
            wlist.extend(dtree.word_list())
            taglist.extend(dtree.tag_list())
            
        w2idx         = make_sym2idx(wlist,unk_symbol=unk_word)
        tag2idx       = make_sym2idx(taglist)
        wcounts       = Counter(wlist)
        a2idx         = {(ArcStandardParser.SHIFT,None):0}  #TODO: also index the (LARC,label) and (RARC,label) couples
        return w2idx,tag2idx,a2idx,wcounts
    
    ### MODEL STRUCTURE AND INFERENCES ###
    def allocate_structure(self,nWords,nTags,nActions):
        """
        Allocates memory for the parser network params.
        Args:
           nWords  (int):number of words in the sym2idx dict
           nTags   (int):number of tags  in the parser
           nActions(int):number of actions in the parser
        """
        self.word_embeds           = nn.Embedding(nWords, self.embedding_dim)
        self.tagger_bilstm         = nn.LSTM(self.embedding_dim+self.char_lstm_memory_dim, self.lstm_memory_dim,num_layers=1,bidirectional=True)
        self.parser_bilstm         = nn.LSTM(self.lstm_memory_dim*2, self.lstm_memory_dim,num_layers=1,bidirectional=True)
        self.tagger_lstm2hidden    = nn.Linear(self.lstm_memory_dim*2,self.config_embedding_dim)
        self.hidden2tags           = nn.Linear(self.config_embedding_dim,nTags)
        self.parser_lstm2hidden    = nn.Linear(self.lstm_memory_dim*8,self.config_embedding_dim)
        self.hidden2actions        = nn.Linear(self.config_embedding_dim,nActions)
       
    def word_dropout(self,toklist,alpha=1.0):
        """
        Replaces each word w in toklist by the unk token with probability alpha / counts(w)
        Args:
            toklist (list): a list of strings
            alpha  (float): a real in [0,1]. a word w is replaced with prob = alpha/counts(w)
        Returns:
            a list of strings with some strings replaced by the unk token
        """
        return toklist #Todo
         
    def train_model(self,train_set,validation_set,epochs=10,learning_rate=0.1,unk_word='##UNK##',unk_char='@@',alpha=0.5):
                    
        self.w2idx, self.tag2idx,self.a2idx,self.wcounts = self.index_symbols(train_set,unk_word,unk_char)
        
        self.reva2idx   = [action for action,idx in sorted(self.a2idx.items(),key = lambda x : x[1])]
        self.revtag2idx = [tag for tag,idx in sorted(self.tag2idx.items(),key = lambda x : x[1])]
        
        self.unk_word,self.unk_char = unk_word,unk_char
        self.allocate_structure(len(self.w2idx),len(self.tag2idx),len(self.a2idx))
        
        loss_function = nn.NLLLoss()
        optimizer     = optim.Adam(self.parameters(), lr=learning_rate)
        
        min_loss      = np.iinfo(np.int32).max
        for epoch in range(epochs):
            
            NLL = 0.0
            for deptree in tqdm(train_set): 
                
                self.zero_grad()                     
                
                X               = deptree.word_list()
                xtagger,xparser = self.forward_encoding(X,alpha)

                ypred,yref      = self.forward_tagger(xtagger,oracle_tree=deptree) # ypred is a sequence of y-softmaxed-probs for each step in the inference                      
                loss            = loss_function(ypred, yref)
                NLL            += loss.item()
                loss.backward(retain_graph=True) #for multiobjective (allows to call loss.backward() again later on)
                    
                ypred,yref      = self.forward_parser(xparser,oracle_tree=deptree)
                loss            = loss_function(ypred, yref)
                NLL            += loss.item()
                loss.backward()
                optimizer.step()
                
            print("\n[train]      Epoch %d, NLL = %f"%(epoch,NLL/len(train_set)),file=sys.stderr)
    
            with torch.no_grad():
                NLL = 0.0
                for deptree in validation_set:
                    X               = deptree.word_list()
                    xtagger,xparser = self.forward_encoding(X)

                    ypred,yref      = self.forward_tagger(xtagger,oracle_tree=deptree)      # ypred holds one log-softmax distribution per inference step
                    loss            = loss_function(ypred, yref) 
                    NLL            += loss.item()

                    ypred,yref      = self.forward_parser(xparser,oracle_tree=deptree)
                    loss            = loss_function(ypred, yref)
                    NLL            += loss.item()
                if NLL < min_loss:
                    min_loss = NLL   # keep track of the best validation loss so far
                    torch.save(self.state_dict(), 'parsing_model.wt')
                print("[validation] Epoch %d, NLL = %f\n"%(epoch,NLL/len(validation_set)),file=sys.stderr)  
            
        self.load_state_dict(torch.load('parsing_model.wt'))

        
    def forward_encoding(self,wordlist,alpha=0.0):
        """
        Performs the forward pass for encoding the input only.
        Args:
           wordlist   (list): a list of strings
           alpha     (float): 0 <= alpha <= 1 float for word dropout
        Returns:
           two lists of vectors encoding the sentence to be used for tagging and parsing the input
           as torch.tensor
        """        
        xwords                = code_sequence(self.word_dropout(wordlist,alpha),self.w2idx,unk_symbol=self.unk_word) 
        xword_embedded        = self.word_embeds(xwords)
        
        ### insert char model here ###
        # ...
        # ...
        # xchar_embedded = ...
        
        xcodes                = xword_embedded  #torch.cat([xword_embedded,xchar_embedded],1)

        xtagger, hidden_tag   = self.tagger_bilstm(xcodes.view(len(xcodes), 1, -1), None)
        xparser, hidden_par   = self.parser_bilstm(xtagger, None)
        return (xtagger,xparser)
        
        
    def forward_tagger(self,xtagger,oracle_tree=None):
        """
        Performs the forward pass for tagging only.
        Args:
           xtagger               (list): a list of vectors,one per word,encoding the input sentence
           oracle_tree (DependencyTree): an optional Dependency tree used as oracle
        Returns:
           a list of tags (one per word, including the dummy word)
           or
           a list of softmax distributions as torch.tensor and the list of reference tags if an oracle_tree is provided
        """
        
        hidden        = F.relu(self.tagger_lstm2hidden(xtagger.squeeze()))
        softmaxin     = self.hidden2tags(hidden)
        log_softmaxes = F.log_softmax(softmaxin, dim=1)  

        if oracle_tree:
            ref_tags = [ArcStandardParser.DTAG] + oracle_tree.tag_list()
            ref_tags = code_sequence(ref_tags,self.tag2idx)
            return (log_softmaxes,ref_tags)
        else:
            max_probs,  argmaxlist  = torch.max(log_softmaxes,1)
            pred_tags               = [self.revtag2idx[argmax] for argmax in argmaxlist]
            return pred_tags
        
    def forward_parser(self,xparser,oracle_tree=None):
        """
        Performs the forward pass for parsing only.
        Args:
           xparser               (list): a list of vectors,one per word,encoding the input sentence
           oracle_tree (DependencyTree): an optional Dependency tree used to drive the oracle
        Returns:
           the final configuration
           or
           a list of softmax distributions, and the list of reference actions if an oracle_tree is provided
        """        
        #TODO
        pass
       
        
    def predict_tree(self,wordlist):
        """
        Actually performs the tagging and parsing of a sentence.
        Args:
            wordlist (list): a list of strings
        Returns:
            a tree  (DependencyTree): the dep tree and its tags
        """
        with torch.no_grad():
            xtagger,xparser  = self.forward_encoding(wordlist)
            taglist          = self.forward_tagger(xtagger)
            pass #TODO
        
    ### TRANSITION SYSTEM ###
    def init_configuration(self,N):
        """
        Creates an init configuration for an arc standard parser.
        A configuration is a triple (stack_idxes,buffer_idxes,dep_arcs)
        Args:
            N  (int): the length of the input, dummy root included
        Returns:
            the init configuration
        """
        return ([0],list(range(1,N)),[]) 
        # configs stack and buffer are filled with integer indexes of word positions
        # stack starts with [0] the dummy root node
        
    def shift(self,configuration):
        """
        Performs shift on a configuration and returns the result
        Args:
           configuration (tuple): a configuration to be shifted
        Returns: 
           the shifted configuration
        """
        pass #TODO
                
    def left_arc(self,configuration,deplabel):
        """
        Performs a left arc action on a configuration and returns the result
        Args:
           configuration (tuple): a configuration
           deplabel     (string): a string labelling the arc
        Returns: 
           the resulting configuration
        """
        pass #TODO
    
    def right_arc(self,configuration,deplabel):
        """
        Performs a right arc action on a configuration and returns the result
        Args:
           configuration (tuple): a configuration
           deplabel     (string): a string labelling the arc
        Returns: 
           the resulting configuration
        """
        pass #TODO
    
    def run_time_mask(self,configuration,log_mask=True):
        """
        The run time mask returns a mask that gives impossible actions a null probability.
        The mask is a vector of mask pseudo scores with dimensions comparable to the output action score vector.
        In log space, impossible actions are filled with a -inf value and possible actions with a 0 value
        (in probability space, 0 and 1 respectively).
        
        The log-space mask can be added to action scores prior to softmax.
        Using the mask ensures that we have no fatal error while performing oracle-less runtime search.
        Args:
           configuration (tuple): a configuration
        Kwargs:
            log_mask      (bool): bool stating if the mask is in logit_space or in probabilistic space.
        Returns:
            torch.tensor of size num Actions filled with masking values
        """
        m0,m1 = (-float('inf'), 0.0) if log_mask else (0.0,1.0)
        
        def valmask(atype,stack_size,buffer_size,configuration):
            if atype in [ArcStandardParser.LARC,ArcStandardParser.RARC] and stack_size < 2:
                return m0
            elif atype == ArcStandardParser.LARC and stack_size >= 2:
                S,B,A = configuration
                if S[-2] == 0:
                    return m0
            elif atype == ArcStandardParser.SHIFT and buffer_size == 0:
                return m0
            return m1
            
        S,B,A       = configuration
        stack_size  = len(S)
        buffer_size = len(B)
        return torch.tensor([ valmask(atype,stack_size,buffer_size,configuration)   for (atype,alabel) in self.reva2idx])  

    def oracle(self,config,deptree):
        """
        Deterministic function translating a dependency tree into a derivation sequence step by step.
        Arc standard static oracle.
        Args:
            config           (tuple): a configuration
            deptree (DependencyTree): the tree used as ground truth
        Returns:
            the action as a tuple (ActionType,ActionLabel)
        """
        #TODO
        pass 
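
A minimal sketch of the word dropout described in item 5, written as a standalone function; in the class, `self.wcounts` and `self.unk_word` would play the roles of `wcounts` and `unk_word`. The `rand` parameter, the `protected` set keeping the dummy root intact, and the choice to treat unseen words as having count 1 are assumptions added for testability and robustness. Since `forward_encoding` defaults to `alpha=0.0`, validation and prediction are left untouched.

```python
from collections import Counter
from random import random

def word_dropout(toklist, wcounts, unk_word="##UNK##", alpha=1.0,
                 protected=frozenset({"$ROOT$"}), rand=random):
    """Replaces each token w by unk_word with probability alpha / count(w).

    wcounts  (Counter): training counts (unseen words count as 1 here, an assumption)
    protected    (set): symbols never replaced, e.g. the dummy root (assumption)
    rand    (callable): source of uniform numbers in [0, 1), injectable for testing
    """
    return [unk_word if w not in protected and rand() < alpha / max(wcounts[w], 1) else w
            for w in toklist]

counts = Counter(["le", "le", "le", "le", "chat", "dort", "dort"])
print(word_dropout(["$ROOT$", "le", "chat", "dort"], counts, alpha=0.5))  # random output
```

With $\alpha = 1$, words seen only once are always replaced, so their own embeddings are never trained; a value below 1 (the default of `train_model` is 0.5) avoids this.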
    
    

Evaluation
==========

In [None]:
def eval_parser(model,test_treebank):
    
    las_correct = 0
    uas_correct = 0
    tag_correct = 0
    
    for dtree in test_treebank:
        
        pass #TODO
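
As a hedged sketch for `eval_parser`, LAS and UAS can be computed by aligning predicted and gold edges on the dependent index: UAS counts words whose predicted governor is correct, LAS those whose governor and label are both correct. The helper `attachment_scores` is hypothetical; it assumes edges in the `DependencyTree.edges` format and scores every token, punctuation included (some evaluations exclude punctuation).

```python
from typing import Iterable, List, Tuple

Edge = Tuple[int, str, str, int]   # (gov_idx, label, dep_tag, dep_idx), as in DependencyTree.edges

def attachment_scores(pairs: Iterable[Tuple[List[Edge], List[Edge]]]) -> Tuple[float, float]:
    """Returns (LAS, UAS) in percent over all words of (predicted, gold) edge-list pairs."""
    las_correct = uas_correct = total = 0
    for pred_edges, gold_edges in pairs:
        pred = {dep: (gov, label) for (gov, label, _tag, dep) in pred_edges}
        for (gov, label, _tag, dep) in gold_edges:
            total += 1
            if dep in pred and pred[dep][0] == gov:      # correct governor
                uas_correct += 1
                if pred[dep][1] == label:                # correct governor and label
                    las_correct += 1
    return 100.0 * las_correct / total, 100.0 * uas_correct / total

gold = [(2, "det", "DET", 1), (3, "suj", "NC", 2), (0, "root", "V", 3)]
pred = [(2, "det", "DET", 1), (3, "obj", "NC", 2), (2, "root", "V", 3)]
print(attachment_scores([(pred, gold)]))   # LAS = 1/3 of the words, UAS = 2/3 of the words
```

Once `predict_tree` is implemented, a call such as `attachment_scores((m.predict_tree(t.word_list()).edges, t.edges) for t in test_treebank)` would give both scores; tag accuracy can be counted the same way from `tag_list()`.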
    
    
    


In [None]:
m = ArcStandardParser(32,64,100)
m.train_model(train_treebank,valid_treebank,epochs=3,learning_rate=0.001,alpha=0.5)

In [None]:
eval_parser(m,test_treebank)
