## Federated Experiment - Kaapana PySyft Integration

- Data: Pneumonia (chest X-rays)

- Computing plan: Sequential Training

Each node is running on an individual Kaapana instance and hosts the locally available data. This notebook and the PySyft-Grid are running on the central Kaapana instance.
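In a sequential computing plan, a single model visits the nodes one after another in every epoch and continues from the weights the previous node produced. This differs from federated averaging, where copies are trained independently and then averaged. The toy sketch below illustrates only that control flow, using a one-parameter linear model and plain Python lists in place of each node's private data. The node names reuse the participant identifiers from this notebook; the data, learning rate and helper names are hypothetical.

```python
from typing import Dict, List, Tuple

# Hypothetical per-node data: (x, y) pairs drawn from y = 2x; each list stays "on" its node.
node_data: Dict[str, List[Tuple[float, float]]] = {
    "hd": [(1.0, 2.0), (2.0, 4.0)],
    "co": [(3.0, 6.0)],
    "mu": [(0.5, 1.0), (1.5, 3.0)],
}

def train_on_node(w: float, samples: List[Tuple[float, float]], lr: float) -> float:
    """Hypothetical helper: plain SGD over one node's samples, returns the updated weight."""
    for x, y in samples:
        grad = 2 * (w * x - y) * x  # derivative of the squared error (w*x - y)^2 w.r.t. w
        w -= lr * grad
    return w

def sequential_training(w: float, epochs: int, lr: float) -> float:
    """Visit every node in turn each epoch, carrying the weight from node to node."""
    for _ in range(epochs):
        for samples in node_data.values():  # dicts keep insertion order: hd -> co -> mu
            w = train_on_node(w, samples, lr)
    return w

w_final = sequential_training(w=0.0, epochs=20, lr=0.05)
print(f"learned weight: {w_final:.4f}")  # approaches 2.0
```

Because the weight is handed from node to node, the order of the nodes matters. The last node visited in an epoch has the most recent influence on the model.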

Please note that this notebook is a simplified version: parts such as logging were removed for improved readability.

In [1]:
import json
import requests
import logging
import time
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
#from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms, models

from PIL import Image

In [3]:
# PySyft Imports
import syft as sy
from syft.grid.clients.data_centric_fl_client import DataCentricFLClient
from syft.grid.public_grid import PublicGridNetwork

hook = sy.TorchHook(torch)

from utils.datasets import OpenminedDataset

### Training parameters & model architecture

In [None]:
# Arguments for training

class Arguments:
    def __init__(self):
        self.epochs = 40
        self.batch_size = 8
        self.lr = 1e-3  # learning rate (0.001)
        self.optimizer = 'SGD'
        self.log_interval = 30

args = Arguments()

In [None]:
# Model architecture: pretrained ResNet18 with a new 2-class output layer

class ResNet18(nn.Module):
    def __init__(self, channels=3, num_classes=2):
        super(ResNet18, self).__init__()
        self.channels = channels
        self.num_classes = num_classes

        self.feature_extractor = models.resnet18(pretrained=False)
        self.feature_extractor.load_state_dict(torch.load('resnet18_pretrained_t140/resnet18_pretrained.pt'))
        
        num_ftrs = self.feature_extractor.fc.in_features
        self.feature_extractor.fc = nn.Linear(num_ftrs, self.num_classes)

    def forward(self, x):
        x = self.feature_extractor(x)
        return F.log_softmax(x, dim=1)
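The model ends in `F.log_softmax`, which is why the training loop uses `F.nll_loss`. Negative log-likelihood applied to log-probabilities is the same quantity as cross-entropy on the raw logits, and `F.cross_entropy` performs both steps at once. The following dependency-free check of that identity uses hypothetical logits for the two classes:

```python
import math
from typing import List

def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax: x_i - max - log(sum_j exp(x_j - max))."""
    m = max(logits)
    log_sum = math.log(sum(math.exp(x - m) for x in logits))
    return [x - m - log_sum for x in logits]

def nll_loss(log_probs: List[float], target: int) -> float:
    """Negative log-likelihood of the target class, given log-probabilities."""
    return -log_probs[target]

def cross_entropy(logits: List[float], target: int) -> float:
    """Cross-entropy computed directly from the probabilities softmax(logits)."""
    exps = [math.exp(x) for x in logits]
    return -math.log(exps[target] / sum(exps))

logits = [1.2, -0.3]  # hypothetical raw outputs for the two classes
target = 0
print(nll_loss(log_softmax(logits), target))
print(cross_entropy(logits, target))  # equal up to floating-point rounding
```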


### Grid Network

Make sure you have started a Grid on the central instance (the same instance this notebook is running on). Furthermore, it might be necessary to clean the PySyft Grid's database before continuing; you can use the Adminer extension to do that.

In [None]:
# Network
GRID_ADDRESS, GRID_PORT = '10.128.129.76', '7000' 
grid = PublicGridNetwork(hook,"http://" + GRID_ADDRESS + ":" + GRID_PORT)

### Start DAGs on remote machines

Call the APIs of the remote instances to start their data-providing DAGs. It takes a short time until they provide their data; afterwards you should be able to find it by searching for the given experiment tag.
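The request body in the next cell repeats the same nested structure for every participant. One option is to build it in a small helper and inspect it before anything is posted. The sketch below only assembles the JSON and the URL and makes no network call. The helper names `build_trigger_payload` and `trigger_url` are hypothetical, and the payload keys are copied from the cell below, not taken from Kaapana API documentation.

```python
import json
from typing import Any, Dict

def build_trigger_payload(machine: str, identifier: str, data_dir: str,
                          dataset: str, exp_tag: str, grid_network_url: str,
                          lifespan: int = 60 * 23) -> Dict[str, Any]:
    """Hypothetical helper: assemble the rest_call body used in this notebook."""
    return {
        "rest_call": {
            "global": {
                "hostname": machine,
                "action_operator_dirs": [data_dir],
                "release_name": f"openmined-node-{identifier}",
            },
            "operators": {
                "unzip-file": {"operator_in_dir": data_dir},
                "openmined-node": {
                    "global.id": identifier,
                    "port": 5000,
                    "grid_network_url": grid_network_url,
                },
                "data-provider": {
                    "dataset": dataset,
                    "lifespan": lifespan,  # the notebook uses 60 * n_hours
                    "exp_tag": exp_tag,
                },
            },
        }
    }

def trigger_url(machine: str) -> str:
    """Hypothetical helper: trigger endpoint of the openmined-provide-data DAG, as used below."""
    return f"https://{machine}/flow/kaapana/api/trigger/openmined-provide-data"

payload = build_trigger_payload("10.128.129.41", "hd", "XRAY-split", "xray",
                                "#xray-exp", "http://10.128.129.76:7000")
print(trigger_url("10.128.129.41"))
print(json.dumps(payload, indent=2))  # inspect before passing it as json= to requests.post
```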

In [None]:
# set values
DATASET = 'xray'
EXP_TAG = '#xray-exp'
DATA_DIR= 'XRAY-split'
GRID_NETWORK_URL = 'http://10.128.129.76:7000'


participants = [
    ('10.128.129.41', 'hd'),
    ('10.128.129.6', 'co'),
    ('10.128.130.197', 'mu')
]

# trigger data-provider dags on remote machines
for machine, identifier in participants:
    json_data = {
        'rest_call': {
            'global': {
                'hostname': machine,
                'action_operator_dirs':[DATA_DIR],
                'release_name': f'openmined-node-{identifier}'
            },
            'operators': {
                'unzip-file':{
                    'operator_in_dir': DATA_DIR
                },
                'openmined-node': {
                    'global.id': identifier,
                    'port': 5000,
                    'grid_network_url': GRID_NETWORK_URL
                },
                'data-provider': {
                    'dataset': DATASET,
                    'lifespan': 60 * 23, #60 * n_hours
                    'exp_tag': EXP_TAG
                }
                
            }
        }
    }   
    url = f'https://{machine}/flow/kaapana/api/trigger/openmined-provide-data'
    print(url)
    
    # verify=False skips TLS certificate verification (requests will emit a warning)
    r = requests.post(url, json=json_data, verify=False)
    print(r.json())

# Timestamp (timestamps collects the timing records of this experiment)
timestamps = []
ts_trigger = time.time()
ts_trigger_date = datetime.fromtimestamp(ts_trigger).strftime('%Y-%b-%d-%H-%M-%S')
timestamps.append({
    'description': 'trigger_dags',
    'epoch': '',
    'worker': '',
    'ts': ts_trigger,
    'ts_date': ts_trigger_date
})
logging.info('Triggered Openmined DAGs on workers:\t{}'.format(ts_trigger_date))
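This cell and the training-initialization cell both append records with the same five keys by hand. A small helper can keep those records consistent. The function `record_timestamp` below is hypothetical and is not part of Kaapana or PySyft. It uses the same date format as the notebook.

```python
import time
from datetime import datetime
from typing import Dict, List, Optional, Union

Record = Dict[str, Union[str, float]]

def record_timestamp(timestamps: List[Record], description: str,
                     epoch: Union[int, str] = "", worker: str = "",
                     ts: Optional[float] = None) -> Record:
    """Hypothetical helper: append a record shaped like the notebook's dicts and return it."""
    ts = time.time() if ts is None else ts
    entry: Record = {
        "description": description,
        "epoch": epoch,
        "worker": worker,
        "ts": ts,
        "ts_date": datetime.fromtimestamp(ts).strftime("%Y-%b-%d-%H-%M-%S"),
    }
    timestamps.append(entry)
    return entry

timestamps: List[Record] = []
record_timestamp(timestamps, "trigger_dags")
record_timestamp(timestamps, "epoch_end", epoch=0, worker="hd")
print([(t["description"], t["ts_date"]) for t in timestamps])
```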

When the nodes are set up and the data-providing operators have prepared the data, you can find it by searching the grid network. If the results are empty, check the remote machines before proceeding (the nodes may also simply need some more time to come up):
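As we understand PySyft 0.2's `search`, a query with several terms is an AND query. A tensor is returned only if every term matches one of its tags (or its id or description). The following stdlib sketch models only the tag part of that behavior with hypothetical tag sets. It shows why a query with four tags is narrower than a search for `#X` alone.

```python
from typing import Dict, List, Set

# Hypothetical tags attached to tensors hosted on one node.
hosted: Dict[str, Set[str]] = {
    "images": {"#X", "#xray", "#dataset", "#xray-exp"},
    "labels": {"#Y", "#xray", "#dataset", "#xray-exp"},
    "old_images": {"#X", "#xray", "#dataset", "#old-exp"},
}

def tag_search(*query: str) -> List[str]:
    """Hypothetical model of AND search: names whose tags contain every query term."""
    return [name for name, tags in hosted.items() if set(query) <= tags]

print(tag_search("#X"))                                     # ['images', 'old_images']
print(tag_search("#X", "#xray", "#dataset", "#xray-exp"))  # ['images']
```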

In [None]:
# search for targets in grid
grid.search('#Y')

In [None]:
# search for images in grid
grid.search('#X')

### Dataset implemented to work with the pointers to the remote data

In [None]:
from torch.utils.data import Dataset

class OpenminedDataset(Dataset):
    '''Openmined Dataset using pointers to remote data instances'''
    
    def __init__(self, img_ptr, label_ptr):
        self.img_ptr = img_ptr
        self.label_ptr = label_ptr
        self.transform = None #transform
    
    def __len__(self):
        return len(self.img_ptr)

    def __getitem__(self, idx):
        '''Return image and corresponding label'''
        img_ptr = self.img_ptr[idx]
        label_ptr = self.label_ptr[idx]
        
        return img_ptr, label_ptr
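`DataLoader` relies only on `__len__` and `__getitem__`, so any object that pairs two equally long, indexable sequences can serve as a map-style dataset. The sketch below applies the same pairing idea to ordinary Python lists instead of PySyft pointers. It also adds a length check, which the notebook's class does not perform. `PairedDataset` is a hypothetical name.

```python
from typing import Generic, Sequence, Tuple, TypeVar

X = TypeVar("X")
Y = TypeVar("Y")

class PairedDataset(Generic[X, Y]):
    """Hypothetical map-style dataset: item i is (inputs[i], targets[i])."""

    def __init__(self, inputs: Sequence[X], targets: Sequence[Y]) -> None:
        if len(inputs) != len(targets):
            raise ValueError(f"{len(inputs)} inputs but {len(targets)} targets")
        self.inputs = inputs
        self.targets = targets

    def __len__(self) -> int:
        return len(self.inputs)

    def __getitem__(self, idx: int) -> Tuple[X, Y]:
        return self.inputs[idx], self.targets[idx]

ds = PairedDataset(["img0", "img1", "img2"], [0, 1, 0])
print(len(ds), ds[1])  # 3 ('img1', 1)
```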

### Get data references and prepare data loaders

In [None]:
# search data and labels once, then show the results
data = grid.search("#X", "#xray", "#dataset", "#xray-exp")
labels = grid.search("#Y", "#xray", "#dataset", "#xray-exp")

print('Data:')
print(data)

print('\nLabel:')
print(labels)

# get workers and their locations
workers = {worker : data[worker][0].location for worker in data.keys()}
print('\nWorkers:')
print(workers)

In [None]:
# create dataloaders using the pointers-datasets
dataloaders = dict()

for name in workers:
    dataset = OpenminedDataset(data[name][0], labels[name][0])
    
    dataloaders[name] = DataLoader(
        dataset,
        batch_size= args.batch_size,
        shuffle=True,
        num_workers=0
    )

print(dataloaders)
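With `batch_size=args.batch_size` (8) and `shuffle=True`, each `DataLoader` draws a new random order of the sample indices every epoch. It yields `ceil(N / 8)` batches, and the last batch may be smaller because `drop_last` defaults to `False`. The stdlib sketch below reproduces that index logic for a hypothetical node holding 20 samples. It illustrates the sampling scheme only and is not PyTorch's implementation.

```python
import math
import random
from typing import List

def shuffled_batches(n_samples: int, batch_size: int, seed: int) -> List[List[int]]:
    """Hypothetical helper: split a random permutation of range(n_samples) into batches."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # a different seed gives a different order
    return [indices[i:i + batch_size] for i in range(0, n_samples, batch_size)]

batches = shuffled_batches(n_samples=20, batch_size=8, seed=0)
print([len(b) for b in batches])          # [8, 8, 4]
print(math.ceil(20 / 8) == len(batches))  # True
```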

### Training

In [None]:
# Initialization
model = ResNet18()
optimizer = optim.SGD(model.parameters(), lr=args.lr)

# initial timestamp
ts_initial = time.time()
timestamps.append({
    'description': 'initial',
    'epoch': '',
    'worker': '',
    'ts': ts_initial,
    'ts_date': datetime.fromtimestamp(ts_initial).strftime('%Y-%b-%d-%H-%M-%S')
})

In [None]:
# Run training

print('\n##### RUN MODEL TRAINING #####')

for epoch in range(args.epochs):
    print('# EPOCH: {}'.format(epoch))
    
    # iterate over the remote workers - send model to its location
    for identifier, worker in workers.items():
        model.train()
        model.send(worker)
        loss_acc = 0
        loader = dataloaders[identifier]

        # iterate over batches of remote data on current worker
        for imgs, targets in loader:
            pred = model(imgs)
            loss = F.nll_loss(pred, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_acc += loss.get().item()  # fetch the batch-mean loss from the worker

        # get the model back; nll_loss returns a batch mean, so average over batches
        model.get()
        loss_avg = loss_acc / len(loader)
        print('Train epoch: {} | Worker: {} | Loss: {:.6f}'.format(epoch, identifier, loss_avg))
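By default, `F.nll_loss` returns the mean loss over a batch, so the accumulated value is a sum of batch means. Dividing that sum by the number of batches gives the mean of the batch means. Weighting each batch mean by its batch size gives the exact per-sample mean when the last batch is smaller. Dividing the sum of batch means by the number of samples understates the loss by roughly the batch size. A small numeric sketch with hypothetical batch losses:

```python
from typing import List

def mean_of_batch_means(batch_means: List[float]) -> float:
    """Average of per-batch mean losses; exact only if all batches have equal size."""
    return sum(batch_means) / len(batch_means)

def per_sample_mean(batch_means: List[float], batch_sizes: List[int]) -> float:
    """Exact mean loss per sample: weight each batch mean by its batch size."""
    total = sum(m * n for m, n in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)

means = [0.5, 0.75, 0.25]  # hypothetical batch-mean losses: 20 samples, batch_size 8
sizes = [8, 8, 4]
print(mean_of_batch_means(means))     # 0.5
print(per_sample_mean(means, sizes))  # 0.55
print(sum(means) / sum(sizes))        # 0.075 (sum of batch means divided by N)
```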

In [None]:
# save trained model
torch.save(model.state_dict(), './model_checkpoint.pt')

### Testing model performance on test data

After training has finished, the model's performance is evaluated on the pneumonia test data (made available in ./data/test).
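`datasets.ImageFolder` expects one sub-directory per class below the root directory and assigns label indices in sorted order of the directory names. The stdlib sketch below mimics that mapping on a temporary directory. The class folder names `NORMAL` and `PNEUMONIA` are an assumption, since this notebook does not state them. Whatever the names are, their sorted order must match the label encoding used during training for the test results to be meaningful.

```python
import os
import tempfile
from typing import Dict

def class_to_idx(root: str) -> Dict[str, int]:
    """Sorted sub-directory names mapped to 0..K-1, the convention ImageFolder follows."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return {name: i for i, name in enumerate(classes)}

with tempfile.TemporaryDirectory() as root:
    for name in ("PNEUMONIA", "NORMAL"):  # hypothetical class folders
        os.makedirs(os.path.join(root, name))
    mapping = class_to_idx(root)

print(mapping)  # {'NORMAL': 0, 'PNEUMONIA': 1}
```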

In [None]:
# transformation
img_transforms = transforms.Compose([
    transforms.Resize((256, 256), interpolation=Image.NEAREST),
    #transforms.RandomHorizontalFlip(),
    #transforms.RandomVerticalFlip(),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))  # ImageNet values
    ])
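The pipeline resizes every image to 256 x 256 and keeps the central 224 x 224 region. `ToTensor` then scales pixel values to [0, 1], and `Normalize` standardizes each channel as `(x - mean) / std` using the ImageNet statistics. The following stdlib sketch computes the crop box and the normalization of one pixel. The crop offset uses the `round((256 - 224) / 2)` rule, which we believe matches torchvision's `CenterCrop`. The pixel values are hypothetical.

```python
from typing import Tuple

MEAN = (0.485, 0.456, 0.406)  # ImageNet per-channel mean, as in the cell above
STD = (0.229, 0.224, 0.225)   # ImageNet per-channel std

def center_crop_box(size: int, crop: int) -> Tuple[int, int, int, int]:
    """(left, top, right, bottom) of a centered crop from a square image (assumed rounding)."""
    offset = int(round((size - crop) / 2.0))
    return offset, offset, offset + crop, offset + crop

def normalize_pixel(rgb_uint8: Tuple[int, int, int]) -> Tuple[float, ...]:
    """Scale 0..255 to 0..1 (like ToTensor), then apply (x - mean) / std per channel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb_uint8, MEAN, STD))

print(center_crop_box(256, 224))         # (16, 16, 240, 240)
print(normalize_pixel((128, 128, 128)))  # a mid-grey pixel lands close to zero
```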

In [None]:
# load test dataset
test_data_dir = './data/test'

test_loader = DataLoader(
    dataset=datasets.ImageFolder(root=test_data_dir, transform=img_transforms),
    batch_size=8,
    shuffle=False
)
print('Images:', len(test_loader.dataset))

In [None]:
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    targets, predictions = [], []
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

            targets.extend(target.tolist())
            predictions.extend([item[0] for item in pred.tolist()])


    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.3f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    
    return predictions, targets

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

In [None]:
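# load model onto the evaluation device (map_location allows loading a GPU-saved checkpoint on CPU)
model = ResNet18().to(device)
model.load_state_dict(torch.load('./model_checkpoint.pt', map_location=device))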

In [None]:
# Run testing
test(model, device, test_loader)
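`test` returns the predicted and the true class indices, so the evaluation can go beyond overall accuracy. For a two-class pneumonia task, sensitivity and specificity are often more informative. The sketch below computes both from the two lists. It assumes, hypothetically, that index 1 is the positive class. With this notebook it would be called as `binary_metrics(*test(model, device, test_loader))`. The example lists are made up.

```python
from typing import Dict, List

def binary_metrics(predictions: List[int], targets: List[int],
                   positive: int = 1) -> Dict[str, float]:
    """Hypothetical helper: confusion-matrix counts plus accuracy, sensitivity, specificity."""
    pairs = list(zip(predictions, targets))
    tp = sum(p == positive and t == positive for p, t in pairs)
    tn = sum(p != positive and t != positive for p, t in pairs)
    fp = sum(p == positive and t != positive for p, t in pairs)
    fn = sum(p != positive and t == positive for p, t in pairs)
    return {
        "tp": tp, "tn": tn, "fp": fp, "fn": fn,
        "accuracy": (tp + tn) / len(pairs),
        "sensitivity": tp / (tp + fn) if tp + fn else float("nan"),  # recall on positives
        "specificity": tn / (tn + fp) if tn + fp else float("nan"),  # recall on negatives
    }

preds = [1, 1, 0, 0, 1, 0]  # hypothetical predictions
truth = [1, 0, 0, 1, 1, 0]  # hypothetical ground truth
print(binary_metrics(preds, truth))
```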