## Named entity recognition with federated learning 
Simple example in implementing a biLSTM model for NER, using PyTorch and PySyft. 

NOTE: This is a work in progress. Currently having some trouble getting federated LSTM to work.

In [1]:
from argparse import Namespace
import math
import numpy as np
import os
import pickle
import torch 
from tqdm.notebook import tqdm
import syft as sy

import torch.nn as nn
#from syft.frameworks.torch.nn import LSTM
import torch.nn.functional as F
import torch.optim as optim

### 1. Setup

#### a. Define filepaths and selected variables

In [2]:
args = Namespace(
    # Data and path
    train_path = "./data/eng.train",
    embeddings_path = "./data/glove.840B.300d.txt", 
    pad_token = "<PAD>",
    unk_token = "<UNK>",
    batch_size=32,
    
    # Model 
    embedding_dim = 300,
    lstm_dim = 10, #todo
    num_lstm_layers = 1,
    
    # Training
    seed=42,
    num_epochs=30,
    lr=0.001
)

In [3]:
torch.manual_seed(args.seed)

<torch._C.Generator at 0x10c1894d0>

#### b. Setup PySyft

In [4]:
hook = sy.TorchHook(torch)

In [5]:
alice = sy.VirtualWorker(hook, id="alice")
bob = sy.VirtualWorker(hook, id="bob")
carol = sy.VirtualWorker(hook, id="carol")

#### c. Prepare data
CoNLL 2003: https://www.clips.uantwerpen.be/conll2003/ner/
* This notebook only uses the English data.
* The raw data are a series of text files.
    * Each line looks like this: "EU NNP I-NP I-ORG"
    * Here we are only interested in the actual word ("EU") and the NER tag ("I-ORG")
* Process raw file into tensors, then into federated datasets

In [6]:
def extract_sentences(filepath, is_cased=True):
    """ Process data in the CoNLL-2003 format into words/labels per sentence. """
    sents, labels = [], []
    curr_sent_words, curr_sent_labels = [], []
    with open(filepath, 'r') as f:
        lines = f.readlines()
    
    for line in lines:
        if not is_cased:
            line = line.lower()
        if line == "\n":
            sents.append(curr_sent_words)
            labels.append(curr_sent_labels)
            curr_sent_words, curr_sent_labels = [], []
        else:
            elements = line.strip().split(" ")
            word = elements[0]
            label = elements[-1]
            curr_sent_words.append(word)
            curr_sent_labels.append(label)
    
    return sents, labels


def process_data(sents, labels, word2id, label2id, max_length, pad_token, unknown_token,
                 is_train=True):
    """ Processes data in the CoNLL 2003 format into torch tensors.
    
    Args:
      sents:         List of sentences (each a list of words)
      labels:        List of list of labels 
      word2id:       Dictionary mapping words to int indices
      label2id:      Dictionary mapping labels to int indices
      max_length:    Target length that we want to pad sequences to.
      pad_token:     String for padding.
      unknown_token: String to replace OOV words.      
      is_train:      Is processing for training data or not. 
                     If true, updates the dictionaries with each value.
                     If false, unseen words will be mapped to <UNK>.
    
    Returns:
      word2id:       Word-to-indices mapping. If is_train, it will be updatd.
      label2id:      Label-to-indices mapping. If is_train, it will be updated. 
      X:             Tensor of processed sentences.
      y:             Tensor of processed labels.
    """
    X, y = [], []
    if is_train:
        pad_id, unk_id = 0, 1
        word2id[pad_token] = pad_id
        label2id[pad_token] = pad_id
        word2id[unknown_token] = unk_id
    else:
        pad_id = word2id[pad_token]
        unk_id = word2id[unknown_token]
    
    for i, sent in enumerate(sents):
        X_i, y_i = [], []
        curr_labels = labels[i]
        sent = sent[:max_length]
        diff = max_length - len(sent)
        for word, label in zip(sent, curr_labels):
            if is_train:
                # If preparing training data, update mappings as we go
                word_id = word2id.get(word, None)
                if not word_id:
                    word_id = len(word2id)
                    word2id[word] = word_id
                
                label_id = label2id.get(label, None)
                if not label_id:
                    label_id = len(label2id)
                    label2id[label] = label_id
            else:
                # Otherwise, fetch id from existing mappings
                word_id = word2id.get(word, word2id[unknown_token])
                label_id = label2id.get(label)
                if not label_id:
                    raise LookupError(f'Unseen label {label}')
            X_i.append(word_id)
            y_i.append(label_id) 
            
        # Pad sequences to max length
        X_i.extend([pad_id] * diff)
        y_i.extend([pad_id] * diff)
        
        X.append(X_i)
        y.append(y_i)
            
    X = torch.tensor(X)   
    y = torch.tensor(y)
    
    return word2id, label2id, X, y

In [7]:
sents, labels = extract_sentences(args.train_path)

In [8]:
# Get maximum sequence length based on loaded sentences
max_train_sent_length = max(len(sent) for sent in sents)
max_length = 2**math.ceil(math.log2(max_train_sent_length))
max_length

128

In [9]:
word2id, label2id = dict(), dict()
word2id, label2id, X, y = process_data(
    sents=sents, labels=labels, word2id=word2id, label2id=label2id, 
    max_length=max_length, pad_token=args.pad_token, unknown_token=args.unk_token
)
vocab_size = len(word2id)
num_labels = len(label2id)

In [10]:
X.shape, y.shape

(torch.Size([14986, 128]), torch.Size([14986, 128]))

We can now transform this into PySyft's federated dataset.

Embed words
* Pre-trained word embeddings. Here we use common crawl glove cased, 840B/300d. 
    * Source: https://nlp.stanford.edu/projects/glove/

In [11]:
# TODO: load pretrained embeddings as frozen weights
if False: 
    word2vec_path = './word2vec.pkl'
    if os.path.exists(word2vec_path):
        with open(word2vec_path, 'rb') as f:
            word2vec = pickle.load(f)
    else:
        word2vec = dict()
        with open(args.embeddings_path, 'r') as f:
            lines = f.readlines()
        for line in tqdm(lines):        
            split_line = line.split(" ")
            word = split_line[0]
            vec = torch.tensor([float(v) for v in split_line[1:]])

            word2vec[word] = vec
        word2vec[args.unk_token] = torch.randn(300)
        #word2vec[args.unk_token] = torch.mean(, axis=0) # TODO: replace with mean
        word2vec_path = './word2vec.pkl'
        with open(word2vec_path, 'wb') as f:
            pickle.dump(word2vec, f)
    #pretrained_embeddings = torch.zeros((vocab_size, args.embedding_dim))

### 2. Define model

In [12]:
class biLSTM(nn.Module):
    def __init__(self, args, word2id, num_labels):
        super(biLSTM, self).__init__()
        
        self.embedding = nn.Embedding(
            num_embeddings=len(word2id),
            embedding_dim=args.embedding_dim,
            padding_idx=word2id[args.pad_token]
        )
        
        self.lstm = nn.LSTM(              
            input_size=args.embedding_dim,
            hidden_size=args.lstm_dim,
            num_layers=args.num_lstm_layers,
            batch_first=True,
            bidirectional=True
        )
        
        self.fc = nn.Linear(
            in_features=args.lstm_dim*2,  # bidirectional 
            out_features=num_labels
        ) 
        
        self.softmax = nn.LogSoftmax(dim=1)
        
    def forward(self, x):
        x_embedded = self.embedding(x)
        y_out, _ = self.lstm(x_embedded)
        
        batch_size, max_len, feat_size = y_out.shape
        y_out = y_out.contiguous().view(batch_size * max_len, feat_size)
        
        y_out = self.fc(y_out)
        new_feat_size = y_out.shape[-1]
        y_out = y_out.view(batch_size, max_len, new_feat_size)
        
        y_out = F.log_softmax(y_out.view(max_len, -1), 1)
        
        return y_out

Status: 
* Stuck with size issue at LSTM layer.
    * Seems to be known issue: e.g., https://github.com/OpenMined/PySyft/issues/3010


Details:
* I first thought it was the embedding layer's problem, but it turns out even if I 'manually' embedded vectors as input the problem stays the same. 
* I've also tried (with no luck) with:
    * Toy dataset vs actual data
    * Federated dataset vs loading data 'manually' and sending components to the workers
    * Switching the layer implementations:
        * With out-of-the-box PyTorch, I get:
            * RuntimeError: input.size(-1) must be equal to input_size. Expected 300, got 0
        * With the Syft implementation, I get:
            * The expanded size of the tensor (40) must match the existing size (0) at non-singleton dimension 1.  Target sizes: [16, 40].  Tensor sizes: [0]
* I haven't tried:
    * Implementing LSTM from scratch :/
    * Different machine/os
    * TF instead of PyTorch


In [13]:
model = biLSTM(args=args, word2id=word2id, num_labels=len(label2id))
criterion = nn.NLLLoss()

# Toy dataset
_X = torch.stack([
    torch.randint(low=0, high=vocab_size, size=(max_length,)) 
    for _ in range(args.batch_size)]
)
_y = torch.stack([
        torch.randint(low=0, high=len(label2id), size=(max_length,)) 
        for _w in range(args.batch_size)
    ])

is_federated = False
if is_federated:
    
    alice.clear_objects()
    bob.clear_objects()
    
    base = sy.BaseDataset(_X, _y)
    dataset = base.federate((alice, bob))
    train_loader = sy.FederatedDataLoader(
        federated_dataset=dataset, 
        batch_size=args.batch_size, 
        shuffle=False                                   # <- TODO
    )
    optimizers = { 
        worker: optim.Adam(params=model.parameters(),lr=args.lr) 
        for worker in dataset.workers
    }
else:
    #dataset = torch.utils.data.TensorDataset(_X, _y)
    dataset = torch.utils.data.TensorDataset(X[:100], y[:100])
    train_loader = torch.utils.data.DataLoader(dataset)
    optimizer = optim.Adam(params=model.parameters(),lr=args.lr)

In [14]:
# Attempts at getting syft to work. 
# Snippet is far from complete, but it doesn't matter as 
# the attempts at using Syft generally fails in the forward pass, 
# at self.lstm(x_embedded) stage.

for epoch in range(args.num_epochs)[:5]:  # <- TODO

    for data, label in train_loader:
        
        if is_federated:                   # <- size problem with federated dataset
            worker = data.location.id
            model.send(worker)
        
            optimizers[worker].zero_grad()
        elif False:                        # <- attempt 2; size problem with manual/'vanilla' syft
            model.send(bob)
            data_ptr = data.send(bob)
            label_ptr = label.send(bob)

            optimizer.zero_grad()

            log_probs = model(data_ptr)
            loss = criterion(log_probs, label_ptr.squeeze())
            loss.backward()
            optimizer.step()

            model.get() 
        else:                               # regular pytorch runs with no problems
            optimizer.zero_grad()
            log_probs = model(data)
            loss = criterion(log_probs, label.squeeze())
            loss.backward()
            optimizer.step()
         
        print(loss.item())

2.401567220687866
2.3967275619506836
2.3915979862213135
2.3219316005706787
2.3378689289093018
2.311460018157959
2.328676223754883
2.281559467315674
2.3150763511657715
2.2749524116516113
2.292695999145508
2.320054292678833
2.2991244792938232
2.2656478881835938
2.255303144454956
2.2499170303344727
2.224112033843994
2.2396445274353027
2.2546756267547607
2.250839948654175
2.255213737487793
2.238948345184326
2.230776309967041
2.207616090774536
2.2005298137664795
2.1686594486236572
2.1926522254943848
2.1606662273406982
2.174736499786377
2.1863174438476562
2.172788381576538
2.1629626750946045
2.1268396377563477
2.1052889823913574
2.12015700340271
2.11755108833313
2.0891902446746826
2.0863113403320312
2.0802483558654785
2.1063225269317627
2.0888566970825195
2.082275152206421
2.041229724884033
2.036936044692993
2.062511682510376
2.052417516708374
2.038029193878174
2.032695770263672
2.011843681335449
2.011183023452759
2.004842519760132
1.9568209648132324
1.9900898933410645
1.95694100856781
1.968

0.12757723033428192
0.20822715759277344
0.09565652906894684
0.1372678577899933
0.1433940827846527
0.036021482199430466
0.015365582890808582
0.05733690783381462
0.03325024992227554
0.2138832062482834
0.20989006757736206
0.14087052643299103
0.03496502339839935
0.1119387224316597
0.18437440693378448
0.1529606282711029
0.014853940345346928
0.053554367274045944
0.03234286606311798
0.06867146492004395
0.2584233582019806
0.01460261084139347
0.03834372013807297
0.03379980847239494
0.14894051849842072
0.03523169830441475
0.0343535915017128
0.028618214651942253
0.08161397278308868
0.05013946443796158
0.05824936926364899
0.1089237853884697
0.044493384659290314
0.06872252374887466
0.040611691772937775
0.03271442651748657
0.013867395929992199
0.047226738184690475
0.03199184685945511
0.13978931307792664
0.08212792873382568
0.10048885643482208
0.013571717776358128
0.05665452778339386
0.026988845318555832
0.06389883905649185
0.053573984652757645
0.04428631067276001
0.04052726924419403
0.04342299327254