## Named entity recognition with federated learning 
Simple example in implementing a biLSTM model for NER, using PyTorch and PySyft. 

* Due to an ongoing issue with recurrent layers on PySyft (on the roadmap to be fixed!) we use a vanilla biLSTM implemented 'from scratch'. 
    * To replicate the issue, see: https://github.com/j-chim/nlp-examples/blob/pysyft-ner/pysyft/minimal-example.ipynb

In [1]:
from argparse import Namespace
import math
import numpy as np
import os
import pickle
import torch 
from tqdm.notebook import tqdm
import syft as sy

import torch.nn as nn
#from syft.frameworks.torch.nn import LSTM
import torch.optim as optim

### 1. Setup

#### a. Define filepaths and selected variables

In [2]:
args = Namespace(
    # Data and path
    train_path = "./data/eng.train",
    embeddings_path = "./data/glove.840B.300d.txt", 
    pad_token = "<PAD>",
    unk_token = "<UNK>",
    batch_size=32,
    
    # Model 
    embedding_dim = 300,
    lstm_dim = 56,
    
    # Training
    seed=42,
    num_epochs=30,
    lr=1e-5
)

In [3]:
torch.manual_seed(args.seed)

<torch._C.Generator at 0x1295e54b0>

#### b. Setup PySyft

In [4]:
hook = sy.TorchHook(torch)

In [5]:
alice = sy.VirtualWorker(hook, id="alice")
bob = sy.VirtualWorker(hook, id="bob")
carol = sy.VirtualWorker(hook, id="carol")

#### c. Prepare data
CoNLL 2003: https://www.clips.uantwerpen.be/conll2003/ner/
* This notebook only uses the English data.
* The raw data are a series of text files.
    * Each line looks like this: "EU NNP I-NP I-ORG"
    * Here we are only interested in the actual word ("EU") and the NER tag ("I-ORG")
* Process raw file into tensors, then into federated datasets

In [6]:
def extract_sentences(filepath, is_cased=True):
    """ Process data in the CoNLL-2003 format into words/labels per sentence. """
    sents, labels = [], []
    curr_sent_words, curr_sent_labels = [], []
    with open(filepath, 'r') as f:
        lines = f.readlines()
    
    for line in lines:
        if not is_cased:
            line = line.lower()
        if line == "\n":
            sents.append(curr_sent_words)
            labels.append(curr_sent_labels)
            curr_sent_words, curr_sent_labels = [], []
        else:
            elements = line.strip().split(" ")
            word = elements[0]
            label = elements[-1]
            curr_sent_words.append(word)
            curr_sent_labels.append(label)
    
    return sents, labels


def process_data(sents, labels, word2id, label2id, max_length, pad_token, unknown_token,
                 is_train=True):
    """ Processes data in the CoNLL 2003 format into torch tensors.
    
    Args:
      sents:         List of sentences (each a list of words)
      labels:        List of list of labels 
      word2id:       Dictionary mapping words to int indices
      label2id:      Dictionary mapping labels to int indices
      max_length:    Target length that we want to pad sequences to.
      pad_token:     String for padding.
      unknown_token: String to replace OOV words.      
      is_train:      Is processing for training data or not. 
                     If true, updates the dictionaries with each value.
                     If false, unseen words will be mapped to <UNK>.
    
    Returns:
      word2id:       Word-to-indices mapping. If is_train, it will be updatd.
      label2id:      Label-to-indices mapping. If is_train, it will be updated. 
      X:             Tensor of processed sentences.
      y:             Tensor of processed labels.
    """
    X, y = [], []
    if is_train:
        pad_id, unk_id = 0, 1
        word2id[pad_token] = pad_id
        label2id[pad_token] = pad_id
        word2id[unknown_token] = unk_id
    else:
        pad_id = word2id[pad_token]
        unk_id = word2id[unknown_token]
    
    for i, sent in enumerate(sents):
        X_i, y_i = [], []
        curr_labels = labels[i]
        sent = sent[:max_length]
        diff = max_length - len(sent)
        for word, label in zip(sent, curr_labels):
            if is_train:
                # If preparing training data, update mappings as we go
                word_id = word2id.get(word, None)
                if not word_id:
                    word_id = len(word2id)
                    word2id[word] = word_id
                
                label_id = label2id.get(label, None)
                if not label_id:
                    label_id = len(label2id)
                    label2id[label] = label_id
            else:
                # Otherwise, fetch id from existing mappings
                word_id = word2id.get(word, word2id[unknown_token])
                label_id = label2id.get(label)
                if not label_id:
                    raise LookupError(f'Unseen label {label}')
            X_i.append(word_id)
            y_i.append(label_id) 
            
        # Pad sequences to max length
        X_i.extend([pad_id] * diff)
        y_i.extend([pad_id] * diff)
        
        X.append(X_i)
        y.append(y_i)
            
    X = torch.tensor(X)   
    y = torch.tensor(y)
    
    return word2id, label2id, X, y

In [7]:
sents, labels = extract_sentences(args.train_path)

In [8]:
# Get maximum sequence length based on loaded sentences
max_train_sent_length = max(len(sent) for sent in sents)
max_length = 2**math.ceil(math.log2(max_train_sent_length))
max_length

128

In [9]:
word2id, label2id = dict(), dict()
word2id, label2id, X, y = process_data(
    sents=sents, labels=labels, word2id=word2id, label2id=label2id, 
    max_length=max_length, pad_token=args.pad_token, unknown_token=args.unk_token
)
vocab_size = len(word2id)
num_labels = len(label2id)

In [10]:
id2word = {v:k for k, v in word2id.items()}

In [11]:
X.shape, y.shape

(torch.Size([14986, 128]), torch.Size([14986, 128]))

We can now transform this into PySyft's federated dataset.

In [12]:
base = sy.BaseDataset(X[:1000], y[:1000]) # TODO: Run on full dataset.
dataset = base.federate((alice, bob))
train_loader = sy.FederatedDataLoader(
    federated_dataset=dataset, 
    batch_size=args.batch_size, 
    shuffle=True                     
)

Embed words
* Pre-trained word embeddings. Here we use common crawl glove cased, 840B/300d. 
    * Source: https://nlp.stanford.edu/projects/glove/

In [13]:
use_pretrained_embeddings = True
if use_pretrained_embeddings:
    word2vec_path = './word2vec.pkl'
    if os.path.exists(word2vec_path):
        with open(word2vec_path, 'rb') as f:
            word2vec = pickle.load(f)
    else:
        word2vec = dict()
        with open(args.embeddings_path, 'r') as f:
            lines = f.readlines()
        for line in tqdm(lines):        
            split_line = line.split(" ")
            word = split_line[0]
            vec = torch.tensor([float(v) for v in split_line[1:]])

            word2vec[word] = vec
        word2vec[args.unk_token] = torch.randn(300)
        #word2vec[args.unk_token] = torch.mean(, axis=0) # TODO: replace with mean
        with open(word2vec_path, 'wb') as f:
            pickle.dump(word2vec, f)

In [14]:
pretrained_embeddings = torch.zeros((vocab_size, args.embedding_dim))
for w, v in tqdm(word2vec.items()): # can prob change to id2vec
    idx = word2id.get(w, word2id[args.unk_token])
    pretrained_embeddings[idx,:] = v    

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2196017.0), HTML(value='')))




### 2. Define model

In [15]:
class LSTM(nn.Module):
    """ Vanilla LSTM.
    Adapted from: https://pytorch.org/docs/stable/_modules/torch/nn/modules/rnn.html#LSTM"""
    def __init__(self, input_size, hidden_size, batch_first=True, 
                 bidirectional=True):
        """ 
        Args:
            input_size:    input vector size.
            hidden_size:   hidden state dimension.
            batch_first:   whether first dimension of inputs is batch_size.
            bidirectional: whether to run model in both directions.
        """
        super(LSTM, self).__init__()
        
        self.hidden_size = hidden_size
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        
        # gate values computed using previous hidden state and current input
        self.concat_size = self.hidden_size + input_size
        
         # forget gate
        self.f_t = nn.Linear(self.concat_size, self.hidden_size)
        self.sigmoid_f = nn.Sigmoid()
        
        # input gate
        self.i_t = nn.Linear(self.concat_size, self.hidden_size)
        self.sigmoid_i = nn.Sigmoid()
        
        # output gate
        self.o_t = nn.Linear(self.concat_size, self.hidden_size)
        self.sigmoid_o = nn.Sigmoid()
        
        # candidate cell state
        self.c_tilde = nn.Linear(self.concat_size, self.hidden_size)
        self.tanh_cell = nn.Tanh()
        
        self.tanh_hidden = nn.Tanh()
        
    def init_states(self, x, batch_size):
        zeros = torch.zeros(batch_size,
                            self.hidden_size,
                            dtype=x.dtype, 
                            device=x.device)
        location = x.location
        if location is not None:
            return (zeros.send(location), zeros.send(location))
        return (zeros, zeros)
    
    def _pass(self, batch_size, x, prev_h, prev_c):
        concatenated = torch.cat((prev_h, x), -1)

        input_gate = self.sigmoid_i(self.i_t(concatenated))
        forget_gate = self.sigmoid_f(self.f_t(concatenated))
        output_gate = self.sigmoid_o(self.o_t(concatenated))

        c_tilde = self.tanh_cell(self.c_tilde(concatenated))
        cell_state = (forget_gate * prev_c) + (input_gate * c_tilde)
        
        hidden = output_gate * self.tanh_hidden(cell_state)
    
        return (hidden, cell_state)
        
        
    def forward(self, x, prev_hidden=None):
        """
        Args:
            x:              input tensor (batch_size, seq_size, feat_size) 
            prev_hidden:    tuple of initial/previous hidden state and cell state.
                            Defaults to zero. 
        Returns:
            hidden:         LSTM hidden state (batch_size, seq_size, hidden_size)
            cell_state:     LSTM cell state 
        """
        
        if self.batch_first:
            batch_size, seq_size, feat_size = x.shape
        else:
            seq_size, batch_size, feat_size = x.shape
            x = x.permute(1, 0, 2)             
        
        if prev_hidden is None:
            prev_h, prev_c = self.init_states(x, batch_size)
        else:
            prev_h, prev_c = prev_hidden
        
        h_forward, h_backward = [], []
        c_forward, c_backward = [], [] 
        h_0, c_0 = prev_h, prev_c
        
        for t in range(seq_size):
            x_t = x[:,t,:]
            prev_h, prev_c = self._pass(batch_size, x_t, prev_h, prev_c)
            h_forward.append(prev_h)
            c_forward.append(prev_c)
            
        if self.bidirectional:
            prev_h, prev_c = h_0, c_0
            for t in reversed(range(seq_size)):
                x_t = x[:,t,:]
                prev_h, prev_c = self._pass(batch_size, x_t, prev_h, prev_c)
                h_backward.append(prev_h)
                c_backward.append(prev_c)
            
            h = torch.cat((torch.stack(h_forward), torch.stack(h_backward)), -1)
            c = torch.cat((torch.stack(c_forward), torch.stack(c_backward)), -1)
        else:
            h = torch.stack(h_forward)
            c = torch.stack(c_forward)
            
        return (h, c)

Use LSTM implemented above and combine with embedding + linear layers

In [16]:
class Tagger(nn.Module):
    """ biLSTM for sequence tagging. """
    
    def __init__(self, args, vocab_size, num_labels, 
                 pretrained_embeddings=None, batch_first=True, padding_idx=0):
        """ 
        Args:
            args:                  namespace object containing configs.
            vocab_size:            number of unique words in training set. 
            num_labels:            number of output labels.
            pretrained_embeddings: pretrained word embedding weights.
                                   If None, weights will initialise randomly.
            batch_first:           whether first dimension of inputs is batch_size.
            padding_idx:           id corresponding to <PAD> tokens.
        """
        
        super(Tagger, self).__init__()
        self.batch_first = batch_first
        self.vocab_size = vocab_size
        self.embedding_size = args.embedding_dim
        self.lstm_hidden_size = args.lstm_dim
        self.num_labels = num_labels
        
        if pretrained_embeddings is not None:
            self.embedding = nn.Embedding.from_pretrained(
                pretrained_embeddings
            )
        else:
            self.embedding = nn.Embedding(
                num_embeddings=self.vocab_size,
                embedding_dim=self.embedding_size, 
                padding_idx=padding_idx            
            )
        
        self.lstm = LSTM(
            input_size=self.embedding_size,
            hidden_size=self.lstm_hidden_size,
            batch_first=self.batch_first,
            bidirectional=True
        )
        
        self.fc = nn.Linear(
            in_features=self.lstm_hidden_size*2,  # bidirectional 
            out_features=self.num_labels
        ) 
        
        self.log_softmax = nn.LogSoftmax(dim=2)
                    
    
    def forward(self, x):
        if self.batch_first:
            batch_size, seq_len = x.shape
        else:
            seq_len, batch_size = x.shape
            x = x.permute(1, 0)
            
        embedded = self.embedding(x)
        h, c = self.lstm(embedded)
        y_out = self.fc(h)
        
        new_feat_size = y_out.shape[-1]
        y_out = y_out.view(batch_size, new_feat_size, seq_len) 
        logits = self.log_softmax(y_out)
        
        return logits

In [17]:
model = Tagger(args=args, vocab_size=len(word2id), num_labels=len(label2id))
criterion = nn.NLLLoss()
optimizers = { 
    worker: optim.Adam(params=model.parameters(),lr=args.lr) 
    for worker in dataset.workers
}

In [None]:
for epoch in range(args.num_epochs):
    training_loss = 0
    
    for data, label in train_loader: 
        worker = data.location.id
        model.send(worker)

        optimizers[worker].zero_grad()
        log_probs = model(data)
        loss = criterion(log_probs, label.squeeze())
        loss.backward()
        optimizers[worker].step()

        model.get()
        
        curr_loss = loss.get().item()
        training_loss += curr_loss
    else:
        print(f"({worker}) training loss: {training_loss}")

(bob) training loss: 155.29152965545654
(bob) training loss: 155.29092264175415
(bob) training loss: 155.29072618484497
(bob) training loss: 155.28915214538574
(bob) training loss: 155.2884955406189
(bob) training loss: 155.28441667556763
(bob) training loss: 155.28483724594116
