In [None]:
import torch
from torch import optim, nn
from torch.autograd import Variable
import syft as sy
hook = sy.TorchHook(torch)
import pixiedust


In [None]:
@property
def location(self):
    """Report where this ``nn.Sequential`` lives, judged by its first layer.

    Takes the first sub-module of the container, indexes row 0 of its
    ``weight`` attribute and returns that object's ``.location``.

    NOTE(review): presumably ``.location`` only exists once the weights are
    PySyft pointers (i.e. after the model has been sent) -- confirm.
    """
    first_layer = self[0]
    first_weight_row = first_layer.weight[0]
    return first_weight_row.location


# Attach the property to the class so every nn.Sequential exposes `.location`.
nn.Sequential.location = location

In [None]:
# Toy dataset: every 4-bit input pattern; the label equals the last bit.
x = torch.tensor([
    [0, 0, 0, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0],
    [1, 1, 0, 0], [1, 0, 1, 0], [0, 1, 1, 0], [1, 1, 1, 0],
    [0, 0, 0, 1], [1, 0, 0, 1], [0, 1, 0, 1], [0, 0, 1, 1],
    [1, 1, 0, 1], [1, 0, 1, 1], [0, 1, 1, 1], [1, 1, 1, 1.],
])
x.requires_grad_()
target = torch.tensor([[0.]] * 8 + [[1.]] * 8)

# Training settings (`counter` is not read by any cell visible here).
epochs, lr = 20, 0.2
counter = 0

# Two chained segments: 4 -> 3 (tanh), then 3 -> 1 (sigmoid).
models = [
    nn.Sequential(nn.Linear(4, 3), nn.Tanh()),
    nn.Sequential(nn.Linear(3, 1), nn.Sigmoid()),
]

# One SGD optimiser per segment, each bound to its own segment's parameters.
optimizers = [
    optim.SGD(params=segment.parameters(), lr=lr)
    for segment in models
]

# Two virtual workers to host the segments.
alice, bob = (sy.VirtualWorker(hook, id=name) for name in ("alice", "bob"))
workers = (alice, bob)

# Starting placement: first segment on alice, second on bob.
model_locations = [alice, bob]

for model, location in zip(models, model_locations):
    model.send(location)

# Inputs go to the first segment's worker, labels to the second segment's.
x = x.send(models[0].location)
target = target.send(models[1].location)

In [None]:
%%pixie_debugger

def train():
    """Train the two chained model segments for ``epochs`` iterations.

    Uses the module-level ``models``, ``optimizers``, ``x``, ``target`` and
    ``epochs``. Each iteration runs the first segment, moves a detached copy
    of its activation to the second segment's location, computes a
    summed-squared-error loss there, and then carries the gradient of that
    activation back to the first segment so both optimisers can step.

    Prints the detached loss once per iteration; returns nothing.
    """
    for epoch in range(epochs):

        # 1) erase previous gradients (if they exist)
        for opt in optimizers:
            opt.zero_grad()

        # 2) forward pass through the first segment
        a = models[0](x)

        # 3) send a detached copy of the activation to the next segment.
        #    Detaching cuts the autograd graph at the boundary, so autograd
        #    is re-enabled on the moved copy to collect its gradient later.
        a_to_send = a.detach()
        remote_a = a_to_send.move(models[1].location)
        remote_a.requires_grad_()

        # 4) forward pass through the second segment
        pred = models[1](remote_a)

        # 5) calculate how much we missed
        loss = ((pred - target) ** 2).sum()

        # 6) backprop through the second segment; this also fills
        #    remote_a.grad
        loss.backward()

        # 7) carry the activation gradient back to the first segment.
        #    The result of .move() must be kept: just like in the forward
        #    pass above, it is the returned object that refers to the
        #    relocated tensor. Discarding it leaves grad_a pointing at the
        #    second segment's location, which does not match `a`.
        grad_a = remote_a.grad.clone()
        grad_a = grad_a.move(models[0].location)

        # 8) continue backprop through the first segment
        a.backward(grad_a)

        # 9) change the weights
        for opt in optimizers:
            opt.step()

        # 10) print our progress
        # Do not use .data
        print(loss.detach())


train()