# BiLSTM-CRF tutorial for the part-of-speech (POS) tagging task

In [1]:
from fastNLP.modules.encoder.lstm import LSTM
from fastNLP.modules.encoder.linear import Linear
from fastNLP.modules.encoder.embedding import Embedding
from fastNLP.modules.decoder.CRF import ConditionalRandomField
from fastNLP.io.dataset_loader import Conll2003Loader
from fastNLP.models.base_model import BaseModel
from fastNLP import Vocabulary
from fastNLP.modules import decoder, encoder
from fastNLP.modules.utils import seq_mask
from fastNLP.core.metrics import MetricBase
from fastNLP import Trainer, Tester
from fastNLP import AccuracyMetric
from fastNLP.core.optimizer import SGD
import numpy as np
import torch
import torch.nn as nn



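## The BiLSTM-CRF model
The model consists of two main parts. The first is a (Bi)LSTM followed by a linear layer, which produces a score for every tag at every token position (the emission scores). The second is a CRF layer, which uses these scores to compute the training loss and to infer (Viterbi-decode) the most likely tag sequence.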

In [2]:
class BiLSTMCRF(nn.Module):
    """BiLSTM-CRF sequence tagger.

    Pipeline: Embedding -> (Bi)LSTM -> Linear (per-token tag scores) -> CRF.

    ``config`` keys read by the constructor:
        vocab_size: number of entries in the word vocabulary.
        word_emb_dim: word embedding size.
        rnn_hidden_units: LSTM hidden size (per direction).
        num_classes: number of output tags.
        bi_direction: whether the LSTM is bidirectional.
    """

    def __init__(self, config):
        super(BiLSTMCRF, self).__init__()
        vocab_size = config["vocab_size"]
        word_emb_dim = config["word_emb_dim"]
        hidden_dim = config["rnn_hidden_units"]
        num_classes = config["num_classes"]
        bi_direction = config["bi_direction"]
        self.Embedding = Embedding(vocab_size, word_emb_dim)
        self.Lstm = LSTM(word_emb_dim, hidden_dim, bidirectional=bi_direction)
        # A bidirectional LSTM concatenates both directions, doubling the width.
        lstm_out_dim = 2 * hidden_dim if bi_direction else hidden_dim
        self.Linear = Linear(lstm_out_dim, num_classes)
        self.Crf = ConditionalRandomField(num_classes)
        # Mask of the most recently processed batch (set in forward()).
        self.mask = None

    def forward(self, token_index_list, origin_len, speech_index_list=None):
        """Score a padded batch, optionally compute the CRF loss, and decode tags.

        :param token_index_list: [batch_size, padded_len] word indices.
        :param origin_len: [batch_size] true (unpadded) length of each sentence.
        :param speech_index_list: [batch_size, padded_len] gold tag indices;
            when given, the CRF loss is computed.
        :return: dict ``{"loss": loss, "pred": tag_seq}`` consumed by fastNLP's
            Trainer/Tester. ``loss`` is None when no gold tags are given.
            ``tag_seq`` is the Viterbi path with every position past a
            sentence's true length set to 0.
        """
        max_len = len(token_index_list[0])
        self.mask = self.make_mask(token_index_list, origin_len)

        x = self.Embedding(token_index_list)  # expected [batch_size, max_len, word_emb_dim]
        x = self.Lstm(x)  # expected [batch_size, max_len, lstm_out_dim]
        x = self.Linear(x)  # expected [batch_size, max_len, num_classes]

        loss = None
        # The loss can only be computed when gold tags are supplied.
        if speech_index_list is not None:
            total_loss = self.Crf(x, speech_index_list, self.mask)
            loss = torch.mean(total_loss)  # reduce the CRF loss to a scalar

        # Decode the best tag path, then zero out every padded position so the
        # prediction has the same padding convention as the targets.
        tag_seq = self.Crf.viterbi_decode(x, self.mask)
        for index in range(len(tag_seq)):
            for i in range(origin_len[index], max_len):
                tag_seq[index][i] = 0

        return {
            "loss": loss,
            "pred": tag_seq
        }

    def make_mask(self, x, seq_len):
        """Build a [batch_size, max_len] float mask (1 = real token) for ``x``.

        :param x: padded index tensor whose first two dims are (batch, max_len).
        :param seq_len: true length of each sequence in the batch.
        """
        batch_size, max_len = x.size(0), x.size(1)
        mask = seq_mask(seq_len, max_len)
        mask = mask.view(batch_size, max_len)
        # Move to x's device/dtype first, then cast to float.
        mask = mask.to(x).float()
        return mask

## Utility function that loads the data and converts it into the DataSet format used by fastNLP

In [None]:
def prepare_data(data_dir="./data/conll2003"):
    """Load the CoNLL-2003 splits and add the index fields the model needs.

    Steps: load train/valid/test, lower-case the tokens, build the word and
    tag vocabularies, index both fields, and record each sentence's original
    length (used later to build the padding mask).

    NOTE(review): load_data, lower_case, build_vocab, build_index and
    build_origin_len are not defined or imported anywhere in the visible
    notebook; they must be provided by an earlier helper cell.

    :param data_dir: directory containing train.txt, valid.txt and test.txt.
    :return: (train_data, valid_data, test_data, vocab, speech_vocab)
    """
    paths = ["{}/{}.txt".format(data_dir, split) for split in ("train", "valid", "test")]
    datasets = load_data(Conll2003Loader(), paths)
    train_data = datasets[0]
    valid_data = datasets[1]
    test_data = datasets[2]
    all_data = [train_data, valid_data, test_data]

    # Lower-case the words in the sentences.
    lower_case(all_data, "token_list")

    # Build vocabularies over all splits for the words and for the POS tags.
    vocab = build_vocab(all_data, "token_list")
    speech_vocab = build_vocab(all_data, "label0_list")

    # Convert words and tags into index sequences.
    build_index(all_data, "token_list", 'token_index_list', vocab)
    build_index(all_data, "label0_list", 'speech_index_list', speech_vocab)

    # Record each sentence's original length; the model uses it for masking.
    build_origin_len(all_data, "token_list", 'origin_len')

    return train_data, valid_data, test_data, vocab, speech_vocab

## Preparing the data

In [4]:
train_data, valid_data, test_data, vocab, speech_vocab = prepare_data()

## Mark the fields each dataset feeds to the model (inputs) and the gold
## tags used by the metric (target). speech_index_list is also an input so
## that forward() can compute the CRF loss.
for dataset in (train_data, test_data, valid_data):
    dataset.set_input("token_index_list", "origin_len", "speech_index_list")
    dataset.set_target("speech_index_list")

## The model is evaluated on test_data only after it has been built and
## trained (see the Tester at the end of the notebook); `model` and
## `PosMetric` do not exist yet at this point.

## Build the model

In [5]:
# Hyper-parameters for BiLSTMCRF; the vocabulary sizes come from the data.
config = dict(
    vocab_size=len(vocab),
    word_emb_dim=200,
    rnn_hidden_units=600,
    num_classes=len(speech_vocab),
    bi_direction=True,
)
model = BiLSTMCRF(config)

## Define the metric for the model

In [None]:
class PosMetric(MetricBase):
    """
        Token-level accuracy for the POS task, as suggested by the original
        paper https://arxiv.org/abs/1508.01991.

        Positions whose target index is 0 are treated as padding and are
        excluded from both the numerator and the denominator.
        NOTE(review): assumes tag index 0 is the padding entry of the tag
        vocabulary; confirm against the helper that builds speech_vocab.
    """
    def __init__(self, pred=None, target=None):
        super().__init__()
        self._init_param_map(pred=pred, target=target)
        # Running counters, accumulated across batches until get_metric() resets them.
        self.total = 0
        self.acc_count = 0

    def evaluate1(self, pred, target):
        """
            Legacy variant kept for backward compatibility: it counts ALL
            positions, padding included. Not called by this notebook;
            `evaluate` is the padding-aware version.
        """
        self.acc_count += torch.sum(torch.eq(pred, target).float()).item()
        self.total += np.prod(list(pred.size()))

    def evaluate(self, pred, target):
        """
            Accumulate correct/total token counts for one batch, ignoring padding.

            :param pred: predicted tag indices, [batch_size, max_len]
            :param target: gold tag indices with the same shape as pred;
                0 marks padding.
        """
        pred = torch.as_tensor(pred)
        target = torch.as_tensor(target, device=pred.device)
        # Vectorised replacement for the per-element Python loop.
        not_pad = target.ne(0)
        self.acc_count += (pred.eq(target) & not_pad).sum().item()
        self.total += not_pad.sum().item()

    def get_metric(self, reset=True):
        """
            Return the accuracy accumulated since the last reset.

            :param reset: when True (the default), clear the counters so the
                next evaluation round (e.g. the next epoch's dev pass) does not
                mix in earlier results.
            :return {"acc": float}: 0.0 when no non-padding token was seen.
        """
        acc = self.acc_count / self.total if self.total else 0.0
        if reset:
            self.acc_count = 0
            self.total = 0
        return {
            'acc': round(acc, 6)
        }

## Build the Trainer

In [None]:
## Adam is not among the notebook's top-level imports (only SGD is), so
## import it here before use.
from fastNLP.core.optimizer import Adam

optimizer = Adam(lr=0.01)


## Train the model
trainer = Trainer(
    model=model,
    train_data=train_data,
    dev_data=valid_data,
    # Use the GPU when one is present; otherwise fall back to the CPU.
    use_cuda=torch.cuda.is_available(),
    metrics=PosMetric(pred='pred', target='speech_index_list'),
    optimizer=optimizer,
    n_epochs=5,
    batch_size=100,
    save_path="./save",
)
trainer.train()

## Build the Tester

In [84]:
# Evaluate the trained model on the held-out test split.
test = Tester(
    data=test_data,
    model=model,
    metrics=PosMetric(pred='pred', target='speech_index_list'),
)
test.test()