In [None]:
import json
import math
import numpy as np
import os
import time

from sklearn import metrics
from tqdm.notebook import tqdm

import torch
from torch.nn import BCELoss
from torch.optim import RMSprop

import syft as sy
from syft.federated.floptimizer import Optims

from utils import build_model, Config, config_to_dict, EarlyStopping, Metric, Standardizer

In [None]:
# Hook torch with PySyft. The tensor/module methods used further down
# (.send, .get, .tag, .describe) presumably come from this hook; the same
# `hook` object is passed to every VirtualWorker created in create_workers.
hook = sy.TorchHook(torch)

In [None]:
# Free-text descriptors of this run; they head the printed summary and are
# stored verbatim in the results JSON.
experiment_configuration = dict(
    Task='Mortality prediction',
    Approach='Federated ML',
    Classifier='Feed-forward network',
)

In [None]:
# Input files all live inside data_folder; outputs go to results_folder.
data_folder = './data/mimic3_17f_24h/'
data_filename, folds_filename, features_filename = (
    os.path.join(data_folder, file_name)
    for file_name in ('imputed-normed-ep_1_24.npz', '5-folds.npz', 'input.csv')
)
results_folder = './results/mimic3_17f_24h/'
results_id = 'federated'  # prefix of the summary/results file names

In [None]:
# Fail loudly if the input data is missing. Raising (instead of calling
# exit(1)) reports the problem in the notebook without shutting the kernel down.
if not os.path.exists(data_folder):
    raise FileNotFoundError(
        f'Wrong data_folder specified: {data_folder!r}. This folder must exist'
    )

# Create the output folder on first run; no-op if it already exists.
os.makedirs(results_folder, exist_ok=True)

In [None]:
# Hyper-parameters read later in this notebook (label_type, validation_split,
# batch_size, learning_rate, epochs, early_stopping_patience) come from the
# project's Config; the bare expression displays it as the cell output.
config = Config()
config

In [None]:
# np.load on an .npz returns an NpzFile that keeps the archive open; the
# context managers close it once the needed arrays have been read into memory.
# NOTE: allow_pickle=True unpickles arbitrary objects - only load trusted files.
with np.load(folds_filename, allow_pickle=True) as folds_file:
    folds = folds_file['folds_ep_mor'][config.label_type][0]

with np.load(data_filename, allow_pickle=True) as data_file:
    y = data_file['adm_labels_all'][:, config.label_type]
# Binarize: any label value > 0 becomes the positive class (1.0).
y = (y > 0).astype(float)

X = np.genfromtxt(features_filename, delimiter=',')

In [None]:
# Identifiers for this task/classifier combination.
TASK_NAME, CLF_NAME = 'Mortality', 'FederatedFeedForwardNetwork'

In [None]:
def create_workers(i):
    """Create the three PySyft virtual workers for fold ``i``.

    Worker ids carry the fold number as a suffix (``bob{i}``, ``alice{i}``,
    ``james{i}``), so each fold gets its own set of ids.

    Args:
        i: fold index used as the id suffix.

    Returns:
        ``[bob, alice, james]``: ``sy.VirtualWorker`` instances built on the
        module-level ``hook``.
    """
    return [sy.VirtualWorker(hook, id=f"{name}{i}") for name in ("bob", "alice", "james")]

In [None]:
def distribute_dataset(X, y, train_idx, test_idx, workers, validation_split=None):
    """Split one fold into train/valid/test parts and scatter each part over ``workers``.

    The first ``floor(validation_split * len(train_idx))`` training indices are
    held out as the validation part. Each part is cut into contiguous chunks of
    ``ceil(len(part) / len(workers))`` rows. Chunk ``k`` is tagged ``#X``/``#Y``
    plus ``#train``/``#valid``/``#test`` and sent to ``workers[k]``.

    Args:
        X: feature matrix, indexable by the index arrays.
        y: label vector aligned with ``X``; reshaped to a column.
        train_idx: training indices (the validation part is carved from the front).
        test_idx: test indices.
        workers: PySyft workers that receive the data.
        validation_split: fraction of ``train_idx`` held out for validation;
            defaults to ``config.validation_split``.

    Returns:
        ``sy.PrivateGridNetwork`` over ``workers``, searchable by the tags above.
    """
    if validation_split is None:
        validation_split = config.validation_split

    tensor_X, tensor_y = torch.Tensor(X), torch.Tensor(y).view(-1, 1)

    split = int(np.floor(validation_split * len(train_idx)))
    train_idx, valid_idx = train_idx[split:], train_idx[:split]

    for tag, idx in zip(['train', 'valid', 'test'], [train_idx, valid_idx, test_idx]):
        chunk_size = max(1, math.ceil(len(idx) / len(workers)))
        # Slicing the index array yields the same contiguous chunks as a
        # ceil-sized torch.split. Zipping with `workers` tolerates getting fewer
        # chunks than workers (e.g. 4 rows over 3 workers -> 2 chunks) or none
        # at all (empty part); trailing workers then simply receive nothing.
        chunks = [idx[start:start + chunk_size] for start in range(0, len(idx), chunk_size)]
        for worker, worker_idx in zip(workers, chunks):
            tag_X = tensor_X[worker_idx].tag("#X", f"#{tag}").describe("")
            tag_y = tensor_y[worker_idx].tag("#Y", f"#{tag}").describe("")

            tag_X.send(worker, garbage_collect_data=False)
            tag_y.send(worker, garbage_collect_data=False)

    return sy.PrivateGridNetwork(*workers)

In [None]:
def collect_datasets(grid):
    """Build federated data loaders for the train, valid and test parts on ``grid``.

    For each part, the tensors tagged ``#X``/``#Y`` plus ``#<part>`` are looked up
    on every worker. The first match per worker forms one ``sy.BaseDataset``,
    and the datasets are combined into a ``sy.FederatedDataLoader`` with
    ``config.batch_size``.

    Args:
        grid: network returned by ``distribute_dataset`` (supports ``search``).

    Returns:
        ``[train_loader, valid_loader, test_loader]``.
    """
    def _loader_for(tag):
        # One federated loader over every worker that holds data for `tag`.
        features_by_worker = grid.search("#X", f"#{tag}")
        labels_by_worker = grid.search("#Y", f"#{tag}")
        per_worker = [
            sy.BaseDataset(features_by_worker[worker_id][0], labels_by_worker[worker_id][0])
            for worker_id in features_by_worker.keys()
        ]
        return sy.FederatedDataLoader(sy.FederatedDataset(per_worker), batch_size=config.batch_size)

    return [_loader_for(tag) for tag in ('train', 'valid', 'test')]

In [None]:
def train(model, train_loader, valid_loader, workers):
    """Train ``model`` federatedly with RMSprop and early stopping on validation loss.

    Each training batch is processed on the worker that holds it. The model is
    sent there, that worker's optimizer takes a step, and the model is fetched
    back. After every epoch the mean validation loss is checked by
    ``EarlyStopping``. At the end, the best state it recorded is loaded back
    into the model.

    Args:
        model: PySyft-hooked torch module producing probabilities (used with BCELoss).
        train_loader: federated loader yielding ``(data, target)`` pointer batches.
        valid_loader: federated loader used for the per-epoch validation loss.
        workers: workers holding the data; one optimizer is kept per worker id.

    Returns:
        ``(model, finished_epochs)``: the model (with the best early-stopping state
        loaded, if any was recorded) and the number of epochs actually run.
    """
    criterion = BCELoss()  # binary cross-entropy
    # for RMSprop in PySyft each worker needs its own optimizer
    worker_ids = [worker.id for worker in workers]
    optims = Optims(worker_ids, optim=RMSprop(model.parameters(), lr=config.learning_rate))
    early_stopping = EarlyStopping(patience=config.early_stopping_patience)

    # Tracked explicitly so zero configured epochs does not leave `epoch` unbound.
    finished_epochs = 0
    for epoch in tqdm(range(config.epochs)):
        finished_epochs = epoch + 1

        model.train()
        for data, target in train_loader:
            model.send(data.location)

            opt = optims.get_optim(data.location.id)
            opt.zero_grad()

            output = model(data)

            loss = criterion(output, target)
            loss.backward()

            opt.step()
            model.get()

        model.eval()
        valid_losses = []
        for data, target in valid_loader:
            model.send(data.location)
            # Validation needs no gradients; skip building the autograd graph
            # (same pattern as in predict).
            with torch.no_grad():
                output = model(data)
                loss = criterion(output, target)
            valid_losses.append(loss.get().item())

            model.get()
        # Unweighted mean over batches.
        valid_loss = np.average(valid_losses)

        if early_stopping.should_early_stop(valid_loss, model):
            break

    best_state = getattr(early_stopping, 'best_model_state', None)
    if best_state is not None:
        model.load_state_dict(best_state)

    return model, finished_epochs

In [None]:
def predict(model, data_loader):
    """Run ``model`` over every batch of a federated ``data_loader``.

    For each batch the model is sent to the worker holding the data. The
    forward pass runs there without autograd, and the predictions and the
    targets are both fetched back locally.

    Args:
        model: PySyft-hooked torch module supporting ``send``/``get``/``eval``.
        data_loader: iterable of ``(data, target)`` pointer batches;
            ``data.location`` is the worker holding the batch.

    Returns:
        ``(predictions, targets)``: 1-D tensors of torch's default float dtype,
        concatenated in loader order (both empty for an empty loader).
    """
    model.eval()

    # Collect batches in a single pass rather than iterating the federated
    # loader once just to count elements and again to predict.
    prediction_batches = []
    target_batches = []
    for data, target in data_loader:
        target_batches.append(target.get().view(-1))

        model.send(data.location)
        with torch.no_grad():
            output = model(data)
            prediction_batches.append(output.get().view(-1))
        model.get()

    dtype = torch.get_default_dtype()
    if not target_batches:
        return torch.zeros(0, dtype=dtype), torch.zeros(0, dtype=dtype)

    predictions = torch.cat(prediction_batches).to(dtype)
    targets = torch.cat(target_batches).to(dtype)
    return predictions, targets

In [None]:
# (display name, scikit-learn scorer, use_soft). Metrics with use_soft=True are
# scored on the predicted probabilities; the others on thresholded labels.
metric_list = [
    Metric(name, scorer, use_soft=use_soft)
    for name, scorer, use_soft in (
        ('Accuracy', metrics.accuracy_score, False),
        ('Precision', metrics.precision_score, False),
        ('Recall', metrics.recall_score, False),
        ('F1 score', metrics.f1_score, False),
        ('ROC AUC', metrics.roc_auc_score, True),
        ('Average precision', metrics.average_precision_score, True),
    )
]

In [None]:
# One list of per-fold durations (seconds) for each timed stage.
time_measurements = {
    'collecting_datasets': [],
    'training': [],
    'training_per_epoch': [],
    'prediction': [],
}

In [None]:
for i, (train_idx, valid_idx, test_idx) in enumerate(folds):
    # Merge the fold's validation indices back into training; distribute_dataset
    # holds out its own validation slice from the front of train_idx.
    train_idx = np.concatenate((train_idx, valid_idx))

    # Fit the standardizer on training rows only, then transform every row.
    standardizer = Standardizer()
    standardizer.fit(X[train_idx])
    X_transformed = standardizer.transform(X)

    # Fold-suffixed ids give every fold its own set of virtual workers.
    workers = create_workers(i)

    grid = distribute_dataset(X_transformed, y, train_idx, test_idx, workers)

    # time.perf_counter is monotonic and high-resolution (unlike time.time),
    # making it the right clock for measuring elapsed durations.
    start = time.perf_counter()
    train_loader, valid_loader, test_loader = collect_datasets(grid)
    time_measurements['collecting_datasets'].append(time.perf_counter() - start)

    model = build_model(config, n_features=X_transformed.shape[1])

    start = time.perf_counter()
    model, finished_epochs = train(model, train_loader, valid_loader, workers)
    training_time = time.perf_counter() - start
    time_measurements['training'].append(training_time)
    # Guard against a run that completed zero epochs.
    time_measurements['training_per_epoch'].append(
        training_time / finished_epochs if finished_epochs else float('nan')
    )

    start = time.perf_counter()
    y_soft, y_true = predict(model, test_loader)
    y_pred = (y_soft > 0.5).type(torch.int)  # hard labels at a 0.5 threshold
    time_measurements['prediction'].append(time.perf_counter() - start)

    # Metrics flagged use_soft get the raw scores; the rest get the hard labels.
    for metric in metric_list:
        y_score = y_soft if metric.use_soft else y_pred
        metric.scores.append(metric.function(y_true, y_score))

    # Drop references to this fold's federated objects before the next fold.
    del train_loader, valid_loader, test_loader, workers, grid

In [None]:
def create_summary(configuration, metric_list, time_measurements):
    """Render a plain-text report of the experiment.

    Layout:
      * one ``label: value`` line per configuration entry;
      * a ``METRICS`` section with ``mean ± std`` of each metric's scores;
      * a ``TIME MEASUREMENTS`` section with ``mean ± std`` of each timing list.
    Labels (with a trailing colon) are left-aligned in a 20-character column;
    statistics use 5 decimals.

    Args:
        configuration: mapping of string labels to displayable values.
        metric_list: objects exposing ``name`` and a list of ``scores``.
        time_measurements: mapping of string labels to lists of durations.

    Returns:
        The report as a single newline-terminated string.
    """
    def _row(label, text):
        return f'{label + ":": <20} {text}\n'

    def _mean_std(values):
        return f'{np.mean(values):.5f} ± {np.std(values):.5f}'

    lines = [_row(label, value) for label, value in configuration.items()]
    lines.append('\nMETRICS\n')
    lines.extend(_row(metric.name, _mean_std(metric.scores)) for metric in metric_list)
    lines.append('\nTIME MEASUREMENTS\n')
    lines.extend(_row(label, _mean_std(times)) for label, times in time_measurements.items())
    return ''.join(lines)

In [None]:
# Build the text report once; it is shown here and saved to disk afterwards.
summary = create_summary(
    experiment_configuration,
    metric_list,
    time_measurements,
)
print(summary)

In [None]:
summary_filename = os.path.join(results_folder, f'{results_id}_summary.txt')
# The summary contains '±' (non-ASCII). Pin UTF-8 instead of relying on the
# platform default encoding, which can be ASCII (e.g. under a C/POSIX locale)
# and would raise UnicodeEncodeError.
with open(summary_filename, 'w', encoding='utf-8') as f:
    f.write(summary)

In [None]:
# Everything needed to reproduce/compare this run, serialized to JSON below:
# descriptors, per-fold metric scores, per-fold timings and the model config.
results = dict(
    experiment_configuration=experiment_configuration,
    metrics={metric.name: metric.scores for metric in metric_list},
    time_measurements=time_measurements,
    model_configuration=config_to_dict(config),
)

In [None]:
# Pretty-printed JSON next to the text summary.
results_filename = os.path.join(results_folder, f'{results_id}_results.json')
with open(results_filename, mode='w') as results_file:
    json.dump(results, results_file, indent=4)