In [1]:
!pip install stadle_client
!pip install nest-asyncio
import nest_asyncio
nest_asyncio.apply()

!mkdir content
!cd content
!mkdir logs
!cd ../..

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting stadle_client
  Downloading stadle_client-0.0.6-py3-none-any.whl (28 kB)
Collecting websockets==8.1
  Downloading websockets-8.1-cp38-cp38-manylinux2010_x86_64.whl (78 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.4/78.4 KB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
Collecting getmac
  Downloading getmac-0.9.1-py2.py3-none-any.whl (35 kB)
Collecting pyblas
  Downloading pyblas-0.0.10.tar.gz (26 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pyOpenSSL
  Downloading pyOpenSSL-23.0.0-py3-none-any.whl (57 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.3/57.3 KB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
Collecting python-xz
  Downloading python_xz-0.4.0-py3-none-any.whl (20 kB)
Collecting tk
  Downloading tk-0.1.0-py3-none-any.whl (3.9 kB)
Collecting cryptography<40,>=38.0.0
  Downloading cryptography-39.0.0-cp36

In [2]:
import pandas as pd
import numpy as np

import torch
from torch.nn import Sequential, Linear, ReLU, Sigmoid, BCELoss
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader

import ipywidgets as widgets

In [3]:
# Registry of trained models, stored as (label, model) tuples. The click handlers
# append to this list and then copy it into the options of the model-selection dropdown.
trained_models = []

In [4]:
class FraudDataset(Dataset):
    """Map-style dataset holding feature and label frames as float32 arrays.

    Args:
        X_df: pandas object with the feature columns (one row per sample).
        y_df: pandas object with the label column(s), aligned row-wise with X_df.
    """

    def __init__(self, X_df, y_df):
        feature_values = X_df.to_numpy()
        label_values = y_df.to_numpy()
        # Cast once up front so every sample is handed out as float32.
        self.X = feature_values.astype(np.float32)
        self.y = label_values.astype(np.float32)

    def __len__(self):
        """Number of samples, taken from the label array."""
        return self.y.shape[0]

    def __getitem__(self, idx):
        """Return the (features, label) pair stored at position idx."""
        sample = self.X[idx]
        label = self.y[idx]
        return sample, label


def get_model(in_features=4, hidden_units=8):
    """Build a fresh, untrained binary-classification network.

    The network is Linear -> ReLU -> Linear -> Sigmoid and produces one value in
    [0, 1] per input row. Every call returns a new, independently initialised model.

    Args:
        in_features: Number of input features per sample (default 4).
        hidden_units: Width of the single hidden layer (default 8).

    Returns:
        A torch.nn.Sequential instance.
    """
    return Sequential(
        Linear(in_features, hidden_units),
        ReLU(),
        Linear(hidden_units, 1),
        Sigmoid(),
    )


def train_model(model, dataloader, epochs):
    """Train `model` in place with binary cross-entropy loss and Adam (lr=1e-3).

    Args:
        model: Module whose output is compared against the labels with BCELoss.
        dataloader: Sized iterable yielding (X, y) batches.
        epochs: Number of full passes over `dataloader`.

    Returns:
        Tuple (model, mean_loss) where mean_loss is the average batch loss of the
        final epoch (0.0 when `epochs` is 0).

    Raises:
        ZeroDivisionError: if `dataloader` has length 0 and `epochs` > 0.
    """
    criterion = BCELoss()
    optimizer = Adam(model.parameters(), lr=1e-3)

    model.train()
    cum_loss = 0.0

    for _ in range(epochs):
        cum_loss = 0.0
        for X, y in dataloader:
            # Gradients accumulate across backward() calls; clear them so each
            # optimizer step only uses the gradient of the current batch.
            optimizer.zero_grad()
            pred_y = model(X)
            loss = criterion(pred_y, y)
            loss.backward()
            optimizer.step()
            cum_loss += loss.item()
        # Convert the summed batch losses into the epoch mean.
        cum_loss /= len(dataloader)

    return model, cum_loss


def test_model(model, dataloader):
    correct_preds = 0.0
    total_preds = 0.0
    for i, (X,y) in enumerate(dataloader):
        y_pred = (model(X) > 0.5).type(torch.FloatTensor)
        correct_preds += torch.sum(y_pred == y).item()
        total_preds += len(y)

    return correct_preds/total_preds

In [5]:
# Module-level DataLoader shared by the UI callbacks: process_df() assigns it, and the
# local-training and FL handlers read it. It stays None until a dataset has been processed.
dataloader = None

In [6]:
def process_df(csv_path):
    """Load a transaction CSV and build the training DataLoader from it.

    The CSV must provide the columns 'transaction_amount', 'location',
    'transaction_type' and 'is_fraud'. Four model features are derived:
    the amount divided by 1000, and three boolean flags (location is not 'US',
    type is 'withdrawal', type is 'payment'). Other columns are ignored.

    The resulting loader (batch size 128, shuffled) is stored in the module-level
    `dataloader` variable and also returned for callers that prefer not to rely
    on the global.

    Args:
        csv_path: Path of the CSV file, passed to pandas.read_csv.

    Returns:
        The newly created DataLoader.
    """
    global dataloader

    raw_df = pd.read_csv(csv_path)

    # Vectorized comparisons instead of per-row lambdas; the column order matches
    # the feature order the inference callback feeds to the model.
    X = pd.DataFrame({
        'transaction_amount': raw_df['transaction_amount'] / 1000,
        'is_int': raw_df['location'] != 'US',
        'is_withdrawal': raw_df['transaction_type'] == 'withdrawal',
        'is_payment': raw_df['transaction_type'] == 'payment',
    })
    y = raw_df[['is_fraud']]

    ds = FraudDataset(X, y)
    dataloader = DataLoader(ds, batch_size=128, shuffle=True)
    return dataloader

def on_train_click(b):
    """Button callback: train a local model on the selected dataset and register it.

    Loads the file chosen in `dataset_select`, trains a fresh model for 15 epochs,
    appends it to `trained_models` under the label 'local_model', and refreshes the
    options of the model-selection dropdown. `b` is the clicked button (unused).
    """
    process_df(dataset_select.value)
    fitted_model, _final_loss = train_model(get_model(), dataloader, 15)
    trained_models.append(('local_model', fitted_model))
    # Hand the dropdown a copy so it is not tied to the mutable registry list.
    inf_model_select.options = list(trained_models)

def on_apply_click(b):
    """Button callback: score the transaction entered in the widgets.

    Builds the 4-element feature vector (amount / 1000, non-US flag, withdrawal
    flag, payment flag) from the input widgets, runs the model selected in
    `inf_model_select` on it, and sets `is_fraud.value` to whether the score
    exceeds 0.5. `b` is the clicked button (unused).
    """
    scaled_amount = float(amt_select.value)/1000.0
    is_international = loc_select.value != 'US'
    txn_type = t_type_select.value
    feature_row = [scaled_amount, is_international, txn_type == 'withdrawal', txn_type == 'payment']
    t_data = np.array(feature_row).astype(np.float32)

    selected_model = inf_model_select.value
    pred = selected_model(torch.Tensor(t_data))
    is_fraud.value = (pred.item() > 0.5)

In [7]:
# Combobox for choosing which transaction CSV to load; its value is the file name
# that on_train_click passes to process_df.
dataset_select = widgets.Combobox(
    description='Select dataset file to load:',
    options=['transaction_data_1.csv', 'transaction_data_2.csv'],
    ensure_option=True,
    disabled=False,
    style= {'description_width': 'initial'},
    layout={'width': 'initial'}
)

# Button that triggers local training on the selected dataset.
train_button = widgets.Button(description='Train Local Model')

train_button.on_click(on_train_click)

# Last expression of the cell: renders the selector and the button side by side.
widgets.HBox([dataset_select, train_button])

HBox(children=(Combobox(value='', description='Select dataset file to load:', ensure_option=True, layout=Layou…

In [8]:
# Dropdown of trained models; the click handlers replace its options with the
# (label, model) tuples in trained_models, so its value is the selected model object.
inf_model_select = widgets.Dropdown(
    options=trained_models,
    description='Select trained model:',
    style= {'description_width': 'initial'}
)

# NOTE(review): num_features is not referenced anywhere else in the visible source, and
# get_model() builds a network with 4 inputs — confirm whether this is still needed.
num_features = 5

# Transaction type; on_apply_click turns it into the 'withdrawal' / 'payment' flags.
t_type_select = widgets.Dropdown(
    options=['payment', 'deposit', 'withdrawal', 'transfer', 'interest'],
    placeholder='Type',
    description='Input transaction information:',
    style= {'description_width': 'initial'}
)

# Transaction amount as free text; on_apply_click parses it with float() and divides by 1000.
amt_select = widgets.Text(
    placeholder='Transaction Amount'
)

# NOTE(review): shown in the input row, but on_apply_click does not read this widget.
balance_select = widgets.Text(
    placeholder='Previous Balance'
)

# Transaction location; anything other than 'US' sets the international flag.
loc_select = widgets.Dropdown(
    options=['US', 'CAN', 'MEX', 'ENG', 'ITA', 'RUS', 'CHN', 'IND', 'FRA', 'JPN', 'KOR'],
    placeholder='Type'
)

input_row = widgets.HBox([t_type_select, amt_select, balance_select, loc_select])

# Button that runs the selected model on the entered transaction.
inf_button = widgets.Button(description='Apply Model')

inf_button.on_click(on_apply_click)

# Indicator whose value on_apply_click sets to True when the model score exceeds 0.5.
is_fraud = widgets.Valid(
    value=False,
    description='Fraud Detected:',
    style= {'description_width': 'initial'}
)

# Last expression of the cell: renders the whole inference panel.
widgets.VBox([inf_model_select, input_row, inf_button, is_fraud])

VBox(children=(Dropdown(description='Select trained model:', options=(), style=DescriptionStyle(description_wi…

In [9]:
from stadle import AdminAgent, BaseModelConvFormat, BasicClient
from stadle.lib.entity.model import BaseModel

# Base-model descriptor handed to AdminAgent (as base_model=) when FL is started:
# wraps a fresh get_model() network, named "Fraud-Detection-Model", tagged as PyTorch format.
fd_bm = BaseModel("Fraud-Detection-Model", get_model(), BaseModelConvFormat.pytorch_format)

def on_start_fl_click(b):
    """Button callback: run federated learning against the STADLE aggregator.

    Steps:
      1. Make sure a DataLoader exists (loading the selected dataset if local
         training has not been run yet).
      2. Upload the base model through an AdminAgent pointed at the address
         entered in `agg_ip_input`.
      3. For the number of rounds in `round_lim`: train locally for 2 epochs,
         send the model with its training loss, wait for the model returned by
         `wait_for_sg_model()` and load its weights.
      4. Register the final model under the label 'FL_model' and refresh the
         model-selection dropdown.

    Progress is reported through the `fl_progress` widget. `b` is the clicked
    button (unused).
    """
    global dataloader

    # FL needs training data; without this, clicking "Start FL" before any local
    # training would pass dataloader=None into train_model and crash mid-run.
    if dataloader is None:
        process_df(dataset_select.value)

    # Read the round count once and size the progress bar to it. The bar's max was
    # fixed when the widget was created, so it must be re-synced if the user edited
    # the round count afterwards. Keep max >= 1 so the bar stays well-formed.
    num_rounds = round_lim.value
    fl_progress.max = max(num_rounds, 1)

    # Upload base model
    fl_progress.description = 'Uploading model metadata'
    fl_progress.value=fl_progress.max
    admin_agent = AdminAgent(aggregator_ip_address=agg_ip_input.value, base_model=fd_bm)
    admin_agent.preload()
    admin_agent.initialize()

    # Start FL
    # The agent name uses the 5th character from the end of the file name, i.e. the
    # digit in 'transaction_data_<n>.csv'.
    stadle_client = BasicClient(agent_name=f'agent_{dataset_select.value[-5]}')

    model = get_model()
    stadle_client.set_bm_obj(model)

    fl_progress.style={'bar_color': 'blue', 'description_width': 'initial'}

    for rnd in range(num_rounds):
        fl_progress.description='Training'
        fl_progress.value=rnd+1
        model, loss = train_model(model, dataloader, 2)
        fl_progress.description='Aggregating'
        stadle_client.send_trained_model(model, perf_values={'loss_training':loss})
        # Presumably blocks until the aggregated model is available — confirm against STADLE docs.
        fl_sd = stadle_client.wait_for_sg_model().state_dict()
        model.load_state_dict(fl_sd)

    fl_progress.description='FL complete'
    fl_progress.style={'bar_color': 'green', 'description_width': 'initial'}

    trained_models.append(('FL_model', model))
    inf_model_select.options = list(trained_models)

In [10]:
# Text field for the aggregator address; on_start_fl_click passes its value to AdminAgent.
# NOTE(review): 'width' is given inside `style` here, whereas dataset_select passes its
# width through `layout` — confirm this is intended.
agg_ip_input = widgets.Text(
    description='STADLE Aggregator Address:',
    style= {'description_width': 'initial', 'width': 'initial'}
)

# Number of federated-learning rounds to run (defaults to 5).
round_lim = widgets.IntText(
    value=5,
    description='Number of FL rounds:',
    style= {'description_width': 'initial'}
)

fl_button = widgets.Button(description='Start FL')

# Progress bar updated by on_start_fl_click. `max` is evaluated once, here, from the
# round count at the time this cell runs.
fl_progress = widgets.IntProgress(
    description='Waiting for FL...',
    value=0,
    min=0,
    max=round_lim.value,
    style={'bar_color': 'yellow', 'description_width': 'initial'},
    orientation='horizontal'
)

fl_button.on_click(on_start_fl_click)

fl_row = widgets.HBox([agg_ip_input, round_lim])
# Last expression of the cell: renders the FL control panel.
widgets.VBox([fl_row, fl_button, fl_progress])

VBox(children=(HBox(children=(Text(value='', description='STADLE Aggregator Address:', style=DescriptionStyle(…

2023-02-02 09:21:09 INFO           AdminAgent|   Agent initialized at 1675329669.9557314
2023-02-02 09:21:10 INFO           AdminAgent|   Starting upload of base model
Base Model: 
	Name: Fraud-Detection-Model
	Type: PyTorch
	Model Object Serialized: True
2023-02-02 09:21:11 WARN          ImportCheck|   flask/gevent not found - only websocket servers can be spawned ('send' not affected)
2023-02-02 09:23:20 ERRO CommunicationHandler|   Cannot connect to k8s-aggregat-aggregat-5bae980385-0f3024fed1b076d6.elb.us-west-1.amazonaws.com:8765 - attempting to connect every 15 seconds (max attempts: 40)
