In [1]:
import syft as sy
import torch
from tools import models
import numpy as np

In [None]:
# Join the data owner's Duet session as the data scientist.
# NOTE(review): loopback=True presumably targets a session on the same
# machine (local testing) — confirm against the PySyft Duet docs.
duet = sy.join_duet(
    loopback=True,
)

In [None]:
# Pointers to the objects the data owner published in the Duet store,
# fetched by position 0..3 in order.
# NOTE(review): the order (train data, train labels, test data, test labels)
# is assumed from the variable names — confirm against the owner's notebook.
train_data_ptr, train_labels_ptr, test_data_ptr, test_labels_ptr = (
    duet.store[idx] for idx in range(4)
)

In [None]:
# Build the model locally, then send it to the data owner's side of the Duet.
local_model = models.Deep2DNet(torch)
remote_model = local_model.send(duet)
remote_torch = duet.torch

# Choose the remote device and move the model BEFORE building the optimizer.
# The PyTorch optim docs advise moving a model with .cuda() before
# constructing optimizers for it, so the optimizer is built from the
# parameters that will actually be trained.
cuda_available = remote_torch.cuda.is_available().get(request_block=True, reason='Gimme!')

if cuda_available:
    device = remote_torch.device('cuda:0')
    remote_model.cuda(device)
else:
    device = remote_torch.device('cpu')
    remote_model.cpu()

# Optimizer and loss are created on the remote torch, after the device move.
params = remote_model.parameters()
optim = remote_torch.optim.Adam(params=params, lr=0.001)
loss_function = remote_torch.nn.CrossEntropyLoss()

In [None]:
# Train with the packaged helper. The import is kept in this cell on purpose:
# a later cell redefines `train` with a different signature, so re-running
# this cell re-binds the tools.utils version before calling it.
from tools.utils import train

# NOTE(review): positional meaning presumably mirrors the local train()
# later in this notebook (batch_size, epochs, model, torch, optim, loss, ...),
# plus test pointers, per-sample input shape and device — confirm in tools.utils.
train(
    300,
    1,
    remote_model,
    remote_torch,
    optim,
    loss_function,
    train_data_ptr,
    train_labels_ptr,
    test_data_ptr,
    test_labels_ptr,
    [1, 64, 64],
    device,
)

In [None]:
def train(batch_size, epochs, model, torch_ref, optim, loss_function, data, labels, input_shape):
    """Train ``model`` on ``data``/``labels`` in fixed-size mini-batches.

    The inputs are trimmed to a whole number of batches. Trailing samples that
    do not fill a batch are dropped. The inputs are then reshaped to
    ``[n_batches, batch_size, *input_shape]`` and moved to ``device``. After
    every epoch the model is evaluated with ``test``.

    NOTE(review): this definition shadows the ``train`` imported from
    ``tools.utils`` earlier in the notebook and takes a different argument
    list. It also reads the notebook globals ``device``, ``test_data_ptr`` and
    ``test_labels_ptr`` rather than taking them as parameters.

    Args:
        batch_size: number of samples per optimisation step.
        epochs: number of passes over the trimmed training data.
        model: model to train; ``model.train()`` is called at each epoch start.
        torch_ref: torch module (here the remote ``duet.torch``), forwarded to ``test``.
        optim: optimiser; ``zero_grad()``/``step()`` are called once per batch.
        loss_function: callable ``(output, target) -> loss``.
        data: training inputs; must support ``len``, slicing and ``view``.
        labels: training targets aligned with ``data``.
        input_shape: per-sample shape appended after the batch dims, e.g. ``[1, 64, 64]``.

    Raises:
        ValueError: if ``data`` holds fewer samples than ``batch_size``.
    """
    length = len(data)
    n_batches = length // batch_size
    if n_batches == 0:
        # Otherwise view() would be asked to infer -1 from zero elements.
        raise ValueError(
            f'batch_size ({batch_size}) is larger than the number of samples ({length})'
        )

    # Use the tensors as-is when the length is an exact multiple of the batch
    # size; otherwise drop the incomplete tail batch.
    usable = n_batches * batch_size
    cut_data, cut_labels = data, labels
    if usable != length:
        cut_data = data[:usable]
        cut_labels = labels[:usable]

    batch_data = cut_data.view([-1, batch_size, *input_shape]).to(device)
    batch_labels = cut_labels.view(-1, batch_size).to(device)

    for epoch in range(epochs):
        model.train()

        print(f'###### Epoch {epoch + 1} ######')
        for i in range(n_batches):
            optim.zero_grad()

            output = model(batch_data[i])

            loss = loss_function(output, batch_labels[i])
            # Request a local copy of the scalar loss only for progress printing.
            loss_item = loss.item()
            loss_value = loss_item.get_copy(reason="To evaluate training progress", request_block=True, timeout_secs=5)
            print(f'Training Loss: {loss_value}')

            loss.backward()
            optim.step()

        test(model, loss_function, torch_ref, test_data_ptr, test_labels_ptr)
                   
def test(model, loss_function, remote_torch, data, labels):
    """Evaluate ``model`` on a held-out set and print its accuracy and loss.

    The function moves ``data``/``labels`` to the notebook-global ``device``.
    It runs one forward pass over the whole set under ``remote_torch.no_grad()``.
    It then requests the accuracy (correct predictions / ``len(data)``) and the
    loss with ``.get(request_block=True, ...)`` and prints both. The model is
    left in eval mode.

    NOTE(review): predictions use ``argmax(dim=1)``, which assumes the model
    output is (N, n_classes) — confirm against the model definition.
    """
    model.eval()

    eval_data = data.to(device)
    eval_labels = labels.to(device)
    n_samples = len(eval_data)

    with remote_torch.no_grad():
        logits = model(eval_data)
        test_loss = loss_function(logits, eval_labels)
        n_correct = logits.argmax(dim=1).eq(eval_labels).sum().item()

    acc = (n_correct / n_samples).get(request_block=True, reason='Gimme!')
    loss = test_loss.get(request_block=True, reason='Gimme!')

    print(f'Test Accuracy: {acc} --- Test Loss: {loss}')

# Run the locally defined train() above: batch size 300, one epoch,
# per-sample input shape [1, 64, 64].
train(
    300,
    1,
    remote_model,
    remote_torch,
    optim,
    loss_function,
    train_data_ptr,
    train_labels_ptr,
    [1, 64, 64],
)

In [None]:
# NOTE(review): bare expression. Evaluating it only displays the repr of the
# `duet.python` attribute as this cell's output. It looks like leftover
# exploration; consider removing it from the finished notebook.
duet.python