In [1]:
from collections import OrderedDict
from typing import List

from sklearn.preprocessing import StandardScaler
import flwr as fl
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import ray

# Use the first CUDA GPU when PyTorch can see one, otherwise stay on the CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_device_name)


In [2]:
class FraudDataset(Dataset):
    """In-memory dataset pairing each feature row with its label.

    Features and labels are converted to ``float32`` tensors once, at
    construction time. The models in this notebook are ``nn.Linear`` stacks
    (float32 weights) trained with ``nn.BCELoss``; both reject the float64 /
    integer data that NumPy arrays usually carry, so the conversion is done
    here rather than in every consumer.

    Args:
        x: Array-like of shape ``(n_samples, n_features)``.
        y: Array-like of shape ``(n_samples,)`` with one label per row of ``x``.

    Raises:
        ValueError: If ``x`` and ``y`` do not contain the same number of samples.
    """

    def __init__(self, x, y):
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same number of samples, got {len(x)} and {len(y)}"
            )
        self.x = torch.as_tensor(x, dtype=torch.float32)
        self.y = torch.as_tensor(y, dtype=torch.float32)

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair stored at ``index``."""
        return self.x[index], self.y[index]

    def __len__(self):
        """Return the number of samples."""
        return len(self.x)

In [3]:
# Per-client keyword arguments for load_data: row ranges, batch size, label
# column and the feature columns that client is allowed to see.
# NOTE(review): client 1 starts at row 1, so row 0 of the CSV is never used by
# either client - confirm this is intended.
client1_args = dict(
    train_split=2000,
    initial_split=1,
    test_split=3000,
    batch_size=32,
    label=30,
    columns=list(range(0, 25)),
)
client2_args = dict(
    train_split=5000,
    initial_split=3000,
    test_split=6000,
    batch_size=32,
    label=30,
    columns=list(range(5, 30)),
)

def load_data(path, initial_split, train_split, test_split, columns, batch_size=32, label=30):
    """Build the train and validation loaders for one client from a CSV file.

    Rows ``[initial_split, train_split)`` form the training set and rows
    ``[train_split, test_split)`` form the validation set. Features are
    standardised with a scaler fitted on the training rows only and then
    applied unchanged to the validation rows.

    Args:
        path: Location of the CSV file.
        initial_split: First row (positional) of the training range.
        train_split: End of the training range and start of the validation range.
        test_split: End (exclusive) of the validation range.
        columns: Positional indices of the feature columns.
        batch_size: Batch size used by both loaders.
        label: Positional index of the label column.

    Returns:
        ``(trainloader, valloader)``; only the training loader shuffles.
    """
    frame = pd.read_csv(path)
    train_rows = slice(initial_split, train_split)
    test_rows = slice(train_split, test_split)

    # Fit on the training rows first, then reuse the same statistics for validation.
    scaler = StandardScaler()
    train_features = scaler.fit_transform(frame.iloc[train_rows, columns].values)
    train_labels = frame.iloc[train_rows, label].values
    test_features = scaler.transform(frame.iloc[test_rows, columns].values)
    test_labels = frame.iloc[test_rows, label].values

    trainloader = DataLoader(
        FraudDataset(train_features, train_labels), batch_size=batch_size, shuffle=True
    )
    valloader = DataLoader(FraudDataset(test_features, test_labels), batch_size=batch_size)
    return trainloader, valloader

# One (train, validation) loader pair per simulated client, indexed by client id.
train1, test1 = load_data("data/creditcard.csv", **client1_args)
train2, test2 = load_data("data/creditcard.csv", **client2_args)
trainloaders, valloaders = [train1, train2], [test1, test2]

# Columns visible to both clients feed the federated (shared) model; whatever
# is left over for a client feeds that client's private model.
shared_columns = [col for col in client1_args['columns'] if col in client2_args['columns']]
client1_args.update(
    ind_columns=[col for col in client1_args['columns'] if col not in shared_columns]
)
client2_args.update(
    ind_columns=[col for col in client2_args['columns'] if col not in shared_columns]
)
args = [client1_args, client2_args]

In [4]:
def train(shared_model, ind_model, agg_model, shared_opt, ind_opt, agg_opt, trainloader, epochs, ind_columns):
    criterion = nn.BCELoss()
    for _ in range(epochs):
        for x, y in trainloader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            print(x[shared_columns].size())
            shared_opt.zero_grad()
            ind_opt.zero_grad()
            agg_opt.zero_grad()
            shared_outputs = shared_model(x[shared_columns])
            ind_outputs = ind_model(x[ind_columns])
            outputs = agg_model(torch.cat((shared_outputs, ind_outputs), dim=1))
            loss = criterion(outputs, y)
            loss.backward()
            shared_outputs.sum().backward()
            ind_outputs.sum().backward()
            shared_opt.step()
            ind_opt.step()
            agg_opt.step()

def test(shared_model, ind_model, agg_model, valloader, ind_columns):
    """Validate the modelwork on the entire test set."""
    criterion = nn.BCELoss()
    loss = 0.0
    tp, fp, tn, fn = 0, 0, 0, 0
    with torch.no_grad():
        for x, y in valloader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            print(x[shared_columns].size())
            shared_outputs = shared_model(x[shared_columns])
            ind_outputs = ind_model(x[ind_columns])
            outputs = agg_model(torch.cat((shared_outputs, ind_outputs), dim=1))
            loss += criterion(outputs, y).item()
            pred = round(float(outputs.get()[0]))
            lab = float(y.get()[0])
            # Collect statistics
            tp += (pred and lab)
            fp += (pred and not lab)
            tn += (not pred and not lab)
            fn += (not pred and lab)
    f1_score = tp / (tp + (fp + fn)/2)
    return loss, f1_score

In [5]:
def get_parameters(net) -> List[np.ndarray]:
    """Return every entry of ``net.state_dict()`` as a NumPy array, in state-dict order."""
    arrays = []
    for tensor in net.state_dict().values():
        # Move to host memory first; .numpy() only works on CPU tensors.
        arrays.append(tensor.cpu().numpy())
    return arrays

def set_parameters(net, parameters: List[np.ndarray]):
    """Load a list of NumPy arrays into ``net``, matching them to state-dict keys by order.

    Args:
        net: Module whose state dict is overwritten.
        parameters: One array per state-dict entry, in ``net.state_dict()`` order.

    Raises:
        ValueError: If the number of arrays differs from the number of
            state-dict entries (``zip`` would otherwise drop the extras silently).
        RuntimeError: Propagated from ``load_state_dict`` on a shape mismatch.
    """
    keys = list(net.state_dict().keys())
    if len(parameters) != len(keys):
        raise ValueError(
            f"expected {len(keys)} parameter arrays, got {len(parameters)}"
        )
    # torch.tensor keeps each array's shape and dtype, including 0-dim entries
    # such as BatchNorm's num_batches_tracked; load_state_dict copies the
    # values into the existing parameters.
    state_dict = OrderedDict(
        (key, torch.tensor(np.asarray(value))) for key, value in zip(keys, parameters)
    )
    net.load_state_dict(state_dict, strict=True)

In [6]:
class FlowerClient(fl.client.NumPyClient):
    """Flower client that federates only the shared model.

    The individual and aggregation models are trained and evaluated locally
    together with the shared model, but only the shared model's parameters are
    received from and returned to the server.

    Args:
        shared_model: Model over the columns common to all clients (federated).
        trainloader: Loader with this client's training batches.
        valloader: Loader with this client's validation batches.
        ind_model: Model over this client's private columns (local only).
        agg_model: Model combining both outputs into a prediction (local only).
        shared_opt, ind_opt, agg_opt: Optimizers for the three models.
        columns: This client's private column indices, forwarded to
            ``train`` / ``test`` as ``ind_columns``.
    """

    def __init__(self, shared_model, trainloader, valloader, ind_model, agg_model, shared_opt, ind_opt, agg_opt, columns):
        self.shared_model = shared_model
        self.ind_model = ind_model
        self.agg_model = agg_model
        self.shared_opt = shared_opt
        self.ind_opt = ind_opt
        self.agg_opt = agg_opt
        self.trainloader = trainloader
        self.columns = columns
        self.valloader = valloader

    @staticmethod
    def _num_examples(loader):
        """Return the number of samples behind ``loader`` (falls back to its length)."""
        dataset = getattr(loader, "dataset", None)
        return len(dataset) if dataset is not None else len(loader)

    def get_parameters(self, config=None):
        """Return the shared model's parameters as NumPy arrays.

        ``config`` is optional so the method works whether or not the
        installed Flower version passes it.
        """
        return get_parameters(self.shared_model)

    def fit(self, parameters, config):
        """Load the global shared parameters, train for one epoch, and return the update."""
        set_parameters(self.shared_model, parameters)
        train(self.shared_model, self.ind_model, self.agg_model, self.shared_opt, self.ind_opt, self.agg_opt, self.trainloader, 1, self.columns)
        return get_parameters(self.shared_model), self._num_examples(self.trainloader), {}

    def evaluate(self, parameters, config):
        """Load the global shared parameters and report loss and F1 on the validation set."""
        set_parameters(self.shared_model, parameters)
        loss, f1_score = test(self.shared_model, self.ind_model, self.agg_model, self.valloader, self.columns)
        return float(loss), self._num_examples(self.valloader), {"f1_score": float(f1_score)}

In [7]:
class SplitNN(nn.Module):
    """Two-layer perceptron used for the shared, individual and aggregation models.

    Args:
        sizes: ``[input, hidden, output]`` layer widths. ``sizes[0]``,
            ``sizes[1]`` and ``sizes[2]`` size the two linear layers;
            ``sizes[-1]`` is recorded as ``output_size``.

    When ``output_size`` is 1 the network acts as a binary classifier head and
    returns sigmoid probabilities; otherwise it returns ReLU-activated features.
    """

    def __init__(self, sizes) -> None:
        super(SplitNN, self).__init__()
        self.input_size = sizes[0]
        self.output_size = sizes[-1]
        self.lin1 = nn.Linear(sizes[0], sizes[1])
        self.lin2 = nn.Linear(sizes[1], sizes[2])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both layers and the activation that matches the output width."""
        hidden = F.relu(self.lin1(x))
        out = self.lin2(hidden)
        if self.output_size == 1:
            # The sigmoid must see the raw logit: a ReLU in front of it would
            # clamp every probability to at least 0.5.
            return torch.sigmoid(out)
        return F.relu(out)

In [8]:
def client_fn(cid: str) -> FlowerClient:
    """Create a Flower client representing a single organization.

    Args:
        cid: Client id as a string; ``int(cid)`` indexes ``args``,
            ``trainloaders`` and ``valloaders``.

    Returns:
        A ``FlowerClient`` with freshly initialised shared, individual and
        aggregation models (all on ``DEVICE``) and one SGD optimizer per model.
    """
    print(f'Creating client {cid}')
    client_index = int(cid)
    ind_columns = args[client_index]['ind_columns']

    # Build the models on DEVICE before creating the optimizers so the
    # optimizers track the parameters that are actually trained.
    # NOTE(review): ind_model and agg_model are re-initialised on every call,
    # so anything they learned is lost if client_fn runs again for the same
    # cid - confirm this is intended.
    shared_model = SplitNN([len(shared_columns), 64, 32]).to(DEVICE)
    ind_model = SplitNN([len(ind_columns), 32, 32]).to(DEVICE)
    # The aggregator consumes both outputs concatenated, so derive its input
    # width from the two models instead of hard-coding it.
    agg_input_size = shared_model.output_size + ind_model.output_size
    agg_model = SplitNN([agg_input_size, 16, 1]).to(DEVICE)

    shared_opt = torch.optim.SGD(shared_model.parameters(), lr=0.001, momentum=0.9)
    ind_opt = torch.optim.SGD(ind_model.parameters(), lr=0.001, momentum=0.9)
    agg_opt = torch.optim.SGD(agg_model.parameters(), lr=0.001, momentum=0.9)

    trainloader = trainloaders[client_index]
    valloader = valloaders[client_index]

    # Create a single Flower client representing a single organization
    return FlowerClient(shared_model, trainloader, valloader, ind_model, agg_model, shared_opt, ind_opt, agg_opt, ind_columns)

In [9]:
# Create the FedAvg strategy for the two simulated clients.
# NOTE(review): fraction_eval / min_eval_clients and the num_rounds argument
# below look like the pre-1.0 Flower keyword names - confirm against the
# installed flwr version.
strategy = fl.server.strategy.FedAvg(
        fraction_fit=1.0,  # Sample 100% of available clients for training
        fraction_eval=0.5,  # Sample 50% of available clients for evaluation
        min_fit_clients=2,  # Never sample fewer than 2 clients for training
        min_eval_clients=1,  # Never sample fewer than 1 client for evaluation
        min_available_clients=2,  # Wait until both clients are available
)

# Start the simulation: 2 clients built by client_fn, 5 federated rounds.
fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=2,
    num_rounds=5,
    strategy=strategy,
)