Minimal example for some issues I'm having with the shape requirements.
* Seems to be known issue with size, e.g., https://github.com/OpenMined/PySyft/issues/3010

Problem - Can't meet the input_size requirements in recurrent layers (RNN, GRU, LSTM, etc). 
* With out-of-the-box PyTorch of nn.LSTM, I get:
    * RuntimeError: input.size(-1) must be equal to input_size. Expected 50, got 0
* With the Syft implementation of nn.LSTM, I get:
    * RuntimeError: The expanded size of the tensor (40) must match the existing size (0) at non-singleton dimension 1.  Target sizes: [4, 40].  Tensor sizes: [0]

In [1]:
from argparse import Namespace
import torch 
import syft as sy

import torch.nn as nn
from syft.frameworks.torch.nn import LSTM as syftLSTM
import torch.nn.functional as F
import torch.optim as optim

Set up args + toy datasets

In [2]:
args = Namespace(
    # Data and path
    batch_size = 8,
    max_len = 10, 
    
    # Model 
    embedding_dim = 50,  
    lstm_dim = 10,        
    num_lstm_layers = 1,
    
    # Training
    num_epochs=1,   
)

In [3]:
word2id = {x:i for i,x in enumerate("abcdefghij")}
labels2id = {x:i for i,x in enumerate("XYZ")}

vocab_size = len(word2id)
num_labels = len(labels2id)

In [4]:
X = torch.stack([
    torch.randint(low=0, high=vocab_size, size=(args.max_len,)) 
    for _ in range(args.batch_size)]
)
y = torch.stack([
        torch.randint(low=0, high=num_labels, size=(args.max_len,)) 
        for _w in range(args.batch_size)
    ])
X.shape, y.shape

(torch.Size([8, 10]), torch.Size([8, 10]))

In [5]:
class NN(nn.Module):
    def __init__(self, args, vocab_size, num_labels):
        super(NN, self).__init__()

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=args.embedding_dim
        )
        
        self.lstm = nn.LSTM(              
            input_size=args.embedding_dim,
            hidden_size=args.lstm_dim,
            num_layers=args.num_lstm_layers,
            batch_first=True,
            bidirectional=True
        )
        
        self.fc = nn.Linear(
            in_features=args.lstm_dim*2,  # bidirectional 
            out_features=num_labels
        ) 
        
        self.softmax = nn.LogSoftmax(dim=1)
        
    def forward(self, x):
        x = x.long()
        x_embedded = self.embedding(x)
        assert(x_embedded.shape[-1] == args.embedding_dim) # <- can't run this in FL fashion
        y_out, _ = self.lstm(x_embedded)
        
        batch_size, max_len, feat_size = y_out.shape
        y_out = y_out.contiguous().view(batch_size * max_len, feat_size)
        
        y_out = self.fc(y_out)
        new_feat_size = y_out.shape[-1]
        y_out = y_out.view(batch_size, max_len, new_feat_size)
        
        y_out = F.log_softmax(y_out.view(max_len, -1), 1)
        
        return y_out

#### a. Regular PyTorch

In [6]:
dataset = torch.utils.data.TensorDataset(X, y)
train_loader = torch.utils.data.DataLoader(dataset)

In [7]:
model = nn.Linear(vocab_size, 1)                                    # runs normally
for epoch in range(args.num_epochs):  
    for data, label in train_loader:
        yhat = model(data.float())

In [8]:
model = NN(args=args, vocab_size=vocab_size, num_labels=num_labels) # runs normally
for epoch in range(args.num_epochs):  
    for data, label in train_loader:
        log_probs = model(data)

#### b. Setup PySyft

In [9]:
hook = sy.TorchHook(torch)

In [10]:
alice = sy.VirtualWorker(hook, id="alice")
bob = sy.VirtualWorker(hook, id="bob")

alice.clear_objects()
bob.clear_objects()

<VirtualWorker id:bob #objects:0>

In [11]:
# Linear layer without using FederatedDataLoader: runs normally
model = nn.Linear(vocab_size, 1)

alices_data = X[:4,:].float().send(alice)
bobs_data = X[4:,:].float().send(bob)

alices_model = model.copy().send(alice)
bobs_model = model.copy().send(bob)

alices_pred = alices_model(alices_data)
bobs_pred = bobs_model(bobs_data)

In [12]:
# Recurrent model without using FederatedDataLoader: fails size check at recurrent layer
model = NN(args=args, vocab_size=vocab_size, num_labels=num_labels)

alice.clear_objects()
bob.clear_objects()

alices_data = X[:4,:].send(alice)
bobs_data = X[4:,:].send(bob)

alices_model = model.copy().send(alice)
bobs_model = model.copy().send(bob)

alices_pred = alices_model(alices_data)
bobs_pred = bobs_model(bobs_data)

RuntimeError: input.size(-1) must be equal to input_size. Expected 50, got 0

In [13]:
alice.clear_objects()
bob.clear_objects()

base = sy.BaseDataset(X, y)
federated_loader = sy.FederatedDataLoader(
    federated_dataset=base.federate((alice, bob)), 
    batch_size=args.batch_size, 
    shuffle=False                                          
)

model = nn.Linear(vocab_size, 1)                                      # fails in linear layer 
for data, label in federated_loader:
    model.send(data.location.id)
    assert(model.location.id==data.location.id)
    assert(data.shape==(args.batch_size//2, vocab_size))
    
    model(data)

RuntimeError: Expected object of scalar type Float but got scalar type Long for argument #2 'mat1' in call to _th_addmm

In [14]:
alice.clear_objects()
bob.clear_objects()

base = sy.BaseDataset(X, y)
federated_loader = sy.FederatedDataLoader(
    federated_dataset=base.federate((alice, bob)), 
    batch_size=args.batch_size, 
    shuffle=False                                          
)

model = NN(args=args, vocab_size=vocab_size, num_labels=num_labels)   # fails to get correct size again
for data, label in federated_loader:
    model.send(data.location.id)
    assert(model.location.id==data.location.id)
    assert(data.shape==(args.batch_size//2, vocab_size))
    
    model(data)

RuntimeError: input.size(-1) must be equal to input_size. Expected 50, got 0