# Example: Federated named entity recognition

This notebook goes through data preparation and distributed learning for NER using PyTorch and [PySyft](https://github.com/OpenMined/PySyft). 

Note that we implement a vanilla biLSTM by hand for the time being (instead of using `torch.nn.LSTM`), because of an outstanding PySyft issue ([minimal example + link to issue](https://github.com/j-chim/nlp-examples/blob/pysyft-ner/pysyft/minimal-example.ipynb)).

# Step 0: Setup

In [None]:
# Install required packages
!pip install tf-encrypted
!pip install 'syft[udacity]'

In [None]:
from argparse import Namespace
import math
import numpy as np
import os
import pickle
import torch 
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
import syft as sy

import torch.nn as nn
import torch.optim as optim

Define file paths, special tokens and hyperparameters.

In [None]:
# Configuration for the whole notebook, grouped by purpose.
args = Namespace(**{
    # Data: file locations, special tokens and batching
    "train_path": "./data/eng.train",
    "embeddings_path": "./data/glove.840B.300d.txt",
    "pad_token": "<PAD>",
    "unk_token": "<UNK>",
    "batch_size": 32,

    # Model: embedding_dim must match the size of the pretrained vectors (300d)
    "embedding_dim": 300,
    "lstm_dim": 128,

    # Training
    "seed": 42,
    "num_epochs": 30,
    "lr": 1e-5,
})

In [None]:
# Fail early, with a pointer to where the data comes from, if the CoNLL file is missing.
assert os.path.exists(args.train_path), (
    f"CoNLL-2003 training file not found at {args.train_path!r}; "
    "see Step 1 for where to obtain it.")

In [None]:
# Seed the torch and numpy RNGs from args for reproducibility.
# The last statement returns None, so the cell no longer displays torch's Generator repr.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Step 1: Create embeddings + federated datasets

The following cell downloads pretrained word embeddings (a roughly 2 GB archive), so it may take a while.
* For our example we used [common crawl glove cased](https://nlp.stanford.edu/projects/glove/), 840B/300d.

We don't provide the training data here, but it can be obtained through the [CoNLL 2003 official website](https://www.clips.uantwerpen.be/conll2003/ner/).
* This notebook only uses the English data. The raw data are a series of text files.
* Each line holds one token and looks like this: "EU NNP I-NP I-ORG". Sentences are separated by blank lines.
* Here we are only interested in the actual word ("EU") and the NER tag ("I-ORG").

Make sure `args.train_path` and `args.embeddings_path` (set in Step 0) point to the training file and the embeddings file.

a. Process embeddings

In [None]:
# Download pretrained word embeddings.
# We can also skip this and randomly initialise
use_pretrained_embeddings = True

if use_pretrained_embeddings:
  !mkdir -p ./data/
  !wget http://nlp.stanford.edu/data/glove.840B.300d.zip
  !unzip compressed_file_name.zip -d ./data/
else:
    pickled_embeddings_path = './word2vec.pkl'
    if os.path.exists(pickled_embeddings_path):
        with open(pickled_embeddings_path, 'rb') as f:
            vector_mapper = pickle.load(f)
    else:
        vector_mapper = dict()
        with open(args.embeddings_path, 'r') as f:
            lines = f.readlines()
        for line in tqdm(lines):        
            split_line = line.split(" ")
            word = split_line[0]
            vec = torch.tensor([float(v) for v in split_line[1:]])

            vector_mapper[word] = vec
        vector_mapper[args.unk_token] = torch.randn(300) # use mean instead of randn 
                                                         # for better performance
        with open(pickled_embeddings_path, 'wb') as f:
            pickle.dump(vector_mapper, f)
    
    pretrained_embeddings = torch.zeros((vocab_size, args.embedding_dim))
    for w, v in tqdm(vector_mapper.items()): 
        idx = word2id.get(w, word2id[args.unk_token])
        pretrained_embeddings[idx,:] = v   

b. Define virtual workers

In [None]:
# Patch torch so tensors/modules gain PySyft's send/get/share methods.
hook = sy.TorchHook(torch)

# alice and bob each receive a shard of the data; carol is used as the
# crypto provider during secure aggregation.
alice, bob, carol = (
    sy.VirtualWorker(hook, id=worker_id) for worker_id in ("alice", "bob", "carol")
)

c. Prepare data

Process raw file into tensors, then into federated datasets.

In [None]:
def extract_sentences(filepath, is_cased=True, skip_docstart=True):
    """ Process data in the CoNLL-2003 format into words/labels per sentence.

    Each non-blank line holds one token: the word is the first field and the NER
    label the last. Blank lines separate sentences.

    Args:
      filepath:      path to a CoNLL-2003 formatted file.
      is_cased:      if False, words (not labels) are lowercased.
      skip_docstart: if True, "-DOCSTART-" document markers are dropped instead
                     of being returned as one-word sentences.

    Returns:
      sents:  list of sentences, each a list of words.
      labels: list of label sequences aligned with `sents`.
    """
    sents, labels = [], []
    curr_sent_words, curr_sent_labels = [], []

    with open(filepath, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Sentence boundary; runs of blank lines must not emit empty sentences.
                if curr_sent_words:
                    sents.append(curr_sent_words)
                    labels.append(curr_sent_labels)
                    curr_sent_words, curr_sent_labels = [], []
                continue

            elements = line.split()
            word, label = elements[0], elements[-1]
            if skip_docstart and word == "-DOCSTART-":
                continue
            if not is_cased:
                # Lowercase only the word: labels such as "I-ORG" must stay intact.
                word = word.lower()
            curr_sent_words.append(word)
            curr_sent_labels.append(label)

    # Keep the final sentence when the file does not end with a blank line.
    if curr_sent_words:
        sents.append(curr_sent_words)
        labels.append(curr_sent_labels)

    return sents, labels


def process_data(sents, labels, word2id, label2id, max_length, pad_token, unknown_token,
                 is_train=True):
    """ Processes data in the CoNLL 2003 format into torch tensors.

    Args:
      sents:         List of sentences (each a list of words)
      labels:        List of list of labels
      word2id:       Dictionary mapping words to int indices
      label2id:      Dictionary mapping labels to int indices
      max_length:    Target length that we want to pad sequences to.
                     Longer sentences are truncated.
      pad_token:     String for padding.
      unknown_token: String to replace OOV words.
      is_train:      Is processing for training data or not.
                     If true, updates the dictionaries with each value
                     (pad -> 0, unknown -> 1 for words, pad -> 0 for labels).
                     If false, unseen words will be mapped to <UNK>.

    Returns:
      word2id:       Word-to-indices mapping. If is_train, it will be updated (in place).
      label2id:      Label-to-indices mapping. If is_train, it will be updated (in place).
      X:             Long tensor of processed sentences, shape (len(sents), max_length).
      y:             Long tensor of processed labels, shape (len(sents), max_length).

    Raises:
      LookupError:   if not is_train and a label is missing from label2id.
    """
    if is_train:
        pad_id, unk_id = 0, 1
        word2id[pad_token] = pad_id
        label2id[pad_token] = pad_id
        word2id[unknown_token] = unk_id
    else:
        pad_id = word2id[pad_token]
        unk_id = word2id[unknown_token]
    label_pad_id = label2id[pad_token]

    X, y = [], []
    for sent, sent_labels in zip(sents, labels):
        X_i, y_i = [], []
        for word, label in zip(sent[:max_length], sent_labels):
            if is_train:
                # If preparing training data, update mappings as we go.
                # Compare with None: id 0 (<PAD>) is a valid but falsy id.
                word_id = word2id.get(word)
                if word_id is None:
                    word_id = len(word2id)
                    word2id[word] = word_id

                label_id = label2id.get(label)
                if label_id is None:
                    label_id = len(label2id)
                    label2id[label] = label_id
            else:
                # Otherwise, fetch id from existing mappings
                word_id = word2id.get(word, unk_id)
                label_id = label2id.get(label)
                if label_id is None:
                    raise LookupError(f'Unseen label {label}')
            X_i.append(word_id)
            y_i.append(label_id)

        # Pad sequences to max length
        X_i.extend([pad_id] * (max_length - len(X_i)))
        y_i.extend([label_pad_id] * (max_length - len(y_i)))

        X.append(X_i)
        y.append(y_i)

    # Explicit dtype/shape so an empty input still yields a (0, max_length) long tensor.
    X = torch.tensor(X, dtype=torch.long).reshape(len(X), max_length)
    y = torch.tensor(y, dtype=torch.long).reshape(len(y), max_length)

    return word2id, label2id, X, y

In [None]:
sents, labels = extract_sentences(args.train_path)

# Hold out 10% of the training sentences for validation (seeded split).
_X_train, _X_val, _y_train, _y_val = train_test_split(
    sents,
    labels,
    test_size=0.1,
    random_state=args.seed,
)

In [None]:
# Get maximum sequence length based on loaded sentences
max_train_sent_length = max(len(train_sent) for train_sent in _X_train)
max_length = 2**math.ceil(math.log2(max_train_sent_length))
max_length

In [None]:
# Build the word/label vocabularies from the training split only, and encode it.
word2id, label2id, X_train, y_train = process_data(
    sents=_X_train,
    labels=_y_train,
    word2id={},
    label2id={},
    max_length=max_length,
    pad_token=args.pad_token,
    unknown_token=args.unk_token,
)

vocab_size, num_labels = len(word2id), len(label2id)

# Reverse lookup for decoding ids back into words.
id2word = dict(zip(word2id.values(), word2id.keys()))

X_train.shape, y_train.shape

In [None]:
# Encode the validation split with the frozen training vocabulary
# (is_train=False: unseen words map to <UNK>, unseen labels raise).
_, _, X_val, y_val = process_data(
    _X_val, _y_val, word2id, label2id, max_length,
    args.pad_token, args.unk_token, is_train=False,
)
X_val.shape, y_val.shape

In [None]:
def make_federated_loader(base_dataset, workers, batch_size, shuffle):
    """ Split `base_dataset` across `workers` and return (federated_dataset, loader).

    Args:
      base_dataset: sy.BaseDataset holding the (X, y) tensors.
      workers:      workers to distribute the data across.
      batch_size:   batch size for the FederatedDataLoader.
      shuffle:      whether the loader shuffles samples.
    """
    federated_dataset = base_dataset.federate(workers)
    loader = sy.FederatedDataLoader(
        federated_dataset=federated_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
    )
    return federated_dataset, loader


train_base = sy.BaseDataset(X_train, y_train)
train_dataset, train_loader = make_federated_loader(
    train_base, (alice, bob), args.batch_size, shuffle=True)

val_base = sy.BaseDataset(X_val, y_val)
# The summed validation loss does not depend on batch order, so don't shuffle.
val_dataset, val_loader = make_federated_loader(
    val_base, (alice, bob), args.batch_size, shuffle=False)

# Step 2: Define Model

Normally this would simply be `torch.nn.LSTM`; we write it out by hand because of the PySyft issue mentioned at the top of the notebook.

In [None]:
class LSTM(nn.Module):
    """ Vanilla (optionally bidirectional) LSTM built from nn.Linear gate layers.
    Adapted from: https://pytorch.org/docs/stable/_modules/torch/nn/modules/rnn.html#LSTM

    Differences from torch.nn.LSTM worth knowing:
      * outputs are always stacked time-major, (seq_size, batch_size, ...),
        regardless of what `batch_first` says about the input;
      * the backward direction reuses the forward direction's gate weights
        (NOTE(review): torch.nn.LSTM keeps separate weights per direction).
    """
    def __init__(self, input_size, hidden_size,
                 batch_first=True,  bidirectional=True):
        """
        Args:
            input_size:    input vector size.
            hidden_size:   hidden state dimension.
            batch_first:   whether first dimension of inputs is batch_size.
            bidirectional: whether to run model in both directions.
        """
        super(LSTM, self).__init__()

        self.hidden_size = hidden_size
        self.batch_first = batch_first
        self.bidirectional = bidirectional

        # gate values computed using previous hidden state and current input
        self.concat_size = self.hidden_size + input_size

        # forget gate
        self.f_t = nn.Linear(self.concat_size, self.hidden_size)
        self.sigmoid_f = nn.Sigmoid()

        # input gate
        self.i_t = nn.Linear(self.concat_size, self.hidden_size)
        self.sigmoid_i = nn.Sigmoid()

        # output gate
        self.o_t = nn.Linear(self.concat_size, self.hidden_size)
        self.sigmoid_o = nn.Sigmoid()

        # candidate cell state
        self.c_tilde = nn.Linear(self.concat_size, self.hidden_size)
        self.tanh_cell = nn.Tanh()

        self.tanh_hidden = nn.Tanh()

    def init_states(self, x, batch_size):
        """ Zero (h, c) states of shape (batch_size, hidden_size), matching x's dtype
        and device, and sent to x's worker when x is a remote PySyft tensor. """
        zeros = torch.zeros(batch_size,
                            self.hidden_size,
                            dtype=x.dtype,
                            device=x.device)
        # Plain (non-PySyft) tensors may not have a `location` attribute at all.
        location = getattr(x, "location", None)
        if location is not None:
            return (zeros.send(location), zeros.send(location))
        return (zeros, zeros)

    def _pass(self, batch_size, x, prev_h, prev_c):
        """ One LSTM time step: x is (batch, input_size), prev_h/prev_c are
        (batch, hidden_size). Returns the new (hidden, cell_state). """
        concatenated = torch.cat((prev_h, x), -1)

        input_gate = self.sigmoid_i(self.i_t(concatenated))
        forget_gate = self.sigmoid_f(self.f_t(concatenated))
        output_gate = self.sigmoid_o(self.o_t(concatenated))

        c_tilde = self.tanh_cell(self.c_tilde(concatenated))
        cell_state = (forget_gate * prev_c) + (input_gate * c_tilde)

        hidden = output_gate * self.tanh_hidden(cell_state)

        return (hidden, cell_state)

    def forward(self, x, prev_hidden=None):
        """
        Args:
            x:              input tensor, (batch_size, seq_size, feat_size) if
                            batch_first else (seq_size, batch_size, feat_size).
            prev_hidden:    tuple of initial/previous hidden state and cell state,
                            each (batch_size, hidden_size). Defaults to zero.
                            Both directions start from these states.
        Returns:
            hidden:         LSTM hidden states for every time step, time-major:
                            (seq_size, batch_size, hidden_size), or
                            (seq_size, batch_size, 2 * hidden_size) when bidirectional
                            (forward half first, backward half second).
            cell_state:     LSTM cell states, same layout as hidden.
        """
        if self.batch_first:
            batch_size, seq_size, feat_size = x.shape
        else:
            seq_size, batch_size, feat_size = x.shape
            x = x.permute(1, 0, 2)

        if prev_hidden is None:
            prev_h, prev_c = self.init_states(x, batch_size)
        else:
            prev_h, prev_c = prev_hidden
        h_0, c_0 = prev_h, prev_c

        h_forward, c_forward = [], []
        for t in range(seq_size):
            prev_h, prev_c = self._pass(batch_size, x[:, t, :], prev_h, prev_c)
            h_forward.append(prev_h)
            c_forward.append(prev_c)

        if not self.bidirectional:
            return (torch.stack(h_forward), torch.stack(c_forward))

        h_backward, c_backward = [], []
        prev_h, prev_c = h_0, c_0
        for t in reversed(range(seq_size)):
            prev_h, prev_c = self._pass(batch_size, x[:, t, :], prev_h, prev_c)
            h_backward.append(prev_h)
            c_backward.append(prev_c)
        # The backward sweep produced states for t = seq_size-1, ..., 0. Flip them so
        # index t holds the backward state for time step t, aligned with h_forward[t].
        h_backward.reverse()
        c_backward.reverse()

        h = torch.cat((torch.stack(h_forward), torch.stack(h_backward)), -1)
        c = torch.cat((torch.stack(c_forward), torch.stack(c_backward)), -1)
        return (h, c)

Use LSTM implemented above and combine with embedding + linear layers.

In [None]:
class Tagger(nn.Module):
    """ biLSTM for sequence tagging: embedding -> biLSTM -> linear -> log-softmax. """

    def __init__(self, args, vocab_size, num_labels,
                 pretrained_embeddings=None, batch_first=True, padding_idx=0):
        """
        Args:
            args:                  namespace object containing configs
                                   (uses args.embedding_dim and args.lstm_dim).
            vocab_size:            number of unique words in training set.
            num_labels:            number of output labels.
            pretrained_embeddings: pretrained word embedding weights.
                                   If None, weights will initialise randomly.
            batch_first:           whether first dimension of inputs is batch_size.
            padding_idx:           id corresponding to <PAD> tokens.
        """
        super(Tagger, self).__init__()
        self.batch_first = batch_first
        self.vocab_size = vocab_size
        self.embedding_size = args.embedding_dim
        self.lstm_hidden_size = args.lstm_dim
        self.num_labels = num_labels

        if pretrained_embeddings is not None:
            self.embedding = nn.Embedding.from_pretrained(
                pretrained_embeddings,
                padding_idx=padding_idx
            )
        else:
            self.embedding = nn.Embedding(
                num_embeddings=self.vocab_size,
                embedding_dim=self.embedding_size,
                padding_idx=padding_idx
            )

        # forward() always hands the LSTM batch-first embeddings (it transposes
        # time-major input itself), so the LSTM must be batch-first regardless
        # of self.batch_first.
        self.lstm = LSTM(
            input_size=self.embedding_size,
            hidden_size=self.lstm_hidden_size,
            batch_first=True,
            bidirectional=True
        )

        self.fc = nn.Linear(
            in_features=self.lstm_hidden_size*2,  # bidirectional
            out_features=self.num_labels
        )

        # Applied to (seq, batch, num_labels) scores, so dim=2 is the label axis.
        self.log_softmax = nn.LogSoftmax(dim=2)

    def forward(self, x):
        """
        Args:
            x: word-id tensor, (batch_size, seq_len) if batch_first
               else (seq_len, batch_size).
        Returns:
            log-probabilities of shape (batch_size, num_labels, seq_len), the layout
            nn.NLLLoss expects for per-token targets of shape (batch_size, seq_len).
        """
        if not self.batch_first:
            x = x.permute(1, 0)

        embedded = self.embedding(x)          # (batch, seq, embedding_dim)
        # The LSTM stacks its outputs time-major: (seq, batch, 2 * lstm_dim).
        h, _ = self.lstm(embedded)
        scores = self.fc(h)                   # (seq, batch, num_labels)
        log_probs = self.log_softmax(scores)  # normalise over labels
        # Reorder axes (not view/reshape, which would scramble the values).
        return log_probs.permute(1, 2, 0)     # (batch, num_labels, seq)

# Step 3: Instantiate our model

In [None]:
model = Tagger(args=args, vocab_size=len(word2id), num_labels=len(label2id))
# Every sentence is padded to max_length with the <PAD> label, so padded
# positions are usually the majority of targets; ignore them in the loss so the
# model is not trained to predict <PAD>.
criterion = nn.NLLLoss(ignore_index=label2id[args.pad_token])

# One copy of the model, with its own optimizer, per data-holding worker.
compute_nodes = list(train_dataset.workers)
models = {node: model.copy() for node in compute_nodes}
optimizers = {
    node: optim.Adam(params=models[node].parameters(), lr=args.lr)
    for node in compute_nodes
}
params = {node: list(m.parameters()) for node, m in models.items()}

# Step 4: Train over our distributed dataset

This roughly follows the example outlined in [PySyft's official tutorial](https://github.com/OpenMined/PySyft/blob/master/examples/tutorials/Part%2010%20-%20Federated%20Learning%20with%20Secure%20Aggregation.ipynb).

* Use fixed precision encoding to encode parameters for SMPC
* Secret-share the encoded parameters among the workers
* Update central model by fetching and decrypting updates from remote workers 

In [None]:
def update(data, label, model, optimizer):
    """ Run one optimisation step of `model` on the worker that holds `data`.

    The model is sent to `data.location`, trained on the batch using the global
    `criterion`, and then retrieved again.

    Args:
        data:      remote tensor of word ids for one batch.
        label:     remote tensor of gold label ids for the same batch
                   (presumably (batch, seq), like the y tensors built by process_data).
        model:     nn.Module to train; sent to the data's worker for the step.
        optimizer: optimizer built over *this* model's parameters; otherwise
                   optimizer.step() updates nothing.
    Returns:
        the model, retrieved back from the worker.
    """
    model.send(data.location)
    optimizer.zero_grad()

    log_probs = model(data)
    # No .squeeze(): it would also drop the batch dimension of a final batch of size 1.
    loss = criterion(log_probs, label)

    loss.backward()
    optimizer.step()

    return model.get()

In [None]:
# Federated training with secure (SMPC) aggregation after every batch,
# followed by a validation pass over the aggregated model each epoch.
num_workers = len(compute_nodes)
num_model_params = len(params[compute_nodes[0]])

for epoch in range(args.num_epochs):
    training_loss = 0.0

    for data, label in train_loader:
        # 1. Local step on the worker that holds this batch. These are the same
        #    steps as update(), but the loss is kept for reporting. The worker's
        #    own model must be trained, since its optimizer tracks that model's params.
        worker = data.location.id
        worker_model = models[worker]
        optimizer = optimizers[worker]

        worker_model.send(data.location)
        optimizer.zero_grad()
        log_probs = worker_model(data)
        loss = criterion(log_probs, label)
        loss.backward()
        optimizer.step()
        training_loss += loss.get().item()
        models[worker] = worker_model.get()

        # 2. Secure aggregation: secret-share every worker's copy of each parameter,
        #    add the shares, and decrypt only the sum (not the individual copies).
        updated_params = []
        for param_idx in range(num_model_params):
            encrypted_sum = None
            for node in compute_nodes:
                encrypted = params[node][param_idx].fix_precision().share(
                    *compute_nodes, crypto_provider=carol)
                encrypted_sum = encrypted if encrypted_sum is None else encrypted_sum + encrypted
            updated_params.append(encrypted_sum.get().float_precision() / num_workers)

        # 3. Write the averaged parameters back into every worker's model.
        with torch.no_grad():
            for worker_params in params.values():
                for param_idx in range(num_model_params):
                    worker_params[param_idx].set_(updated_params[param_idx])

    # All worker models now hold the same averaged parameters; copy them into the
    # central model so that validation (and later cells) use the trained weights.
    model.load_state_dict(models[compute_nodes[0]].state_dict())

    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for val_data, val_label in val_loader:
            model.send(val_data.location)
            log_probs = model(val_data)
            val_loss += criterion(log_probs, val_label).get().item()
            model.get()
    model.train()

    print(
        "Epoch {e}/{total_e}".format(e=epoch + 1, total_e=args.num_epochs),
        "Training loss: {:.3f}.. ".format(training_loss / len(train_loader)),
        "Val loss: {:.3f}.. ".format(val_loss / len(val_loader)),
    )