In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf

In [2]:
tf.__version__

'2.11.0'

## Create Binary Classification data with sklearn

In [3]:
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split


n = 100_000  # total number of samples to generate
d = 100      # number of features (all of them informative, see below)


noise_factor = 0.01 # fraction of the labels assigned at random (passed as flip_y), DEFAULT=0.01
test_size = 0.1 # fraction of n held out for testing
# The factor multiplying the hypercube size. Larger values spread out the 
# clusters/classes and make the classification task easier. DEFAULT=1
# NOTE(review): a negative value is unusual here -- confirm it is intentional.
class_sep = -1
seed = 7  # shared random seed for data generation and the train/test split

# Create the (noisy) data set for binary classification; it is split into
# training and testing parts further down.
X, y = make_classification(
    n_samples=n, 
    n_features=d,
    n_informative=d,
    n_redundant=0, 
    n_classes=2,
    class_sep=class_sep,
    flip_y=noise_factor,
    random_state=seed
)

# We will work with label values -1, +1 and not 0, +1 (convert in place)
y[y == 0] = -1

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=seed)

In [4]:
from sklearn.linear_model import PassiveAggressiveClassifier
from sklearn.metrics import accuracy_score

# PA-I classifier from sklearn (loss="hinge"), trained on the full training
# set as a centralized baseline for the federated experiments below.
# random_state fixes the shuffling of the training data between epochs so the
# reported accuracy is reproducible across kernel restarts.
pa1 = PassiveAggressiveClassifier(C=0.01, loss="hinge", n_jobs=-1, random_state=seed)
pa1.fit(X_train, y_train)

accuracy_score(y_test, pa1.predict(X_test))

0.7888

## Convert to Tensors

In [5]:
# Convert the four NumPy splits to float32 tensors in a single pass
# (same order as the target names on the left-hand side).
X_train_tensor, y_train_tensor, X_test_tensor, y_test_tensor = (
    tf.constant(split, dtype=tf.float32)
    for split in (X_train, y_train, X_test, y_test)
)

Delete the NumPy arrays produced by sklearn; from here on only the tensors are used.

In [6]:
# Remove the NumPy names from the namespace; the later cells work only with
# the *_tensor objects created above.
del X, y, X_train, X_test, y_train, y_test

## Prepare data for Federated Learning

In [7]:
import collections

### Create centralized testing dataset

In [8]:
# Test split in the same layout used for the clients: an OrderedDict with the
# labels under 'y' and the features under 'x'.
slices_test = collections.OrderedDict([('y', y_test_tensor), ('x', X_test_tensor)])

In [9]:
def create_tf_dataset_for_testing(batch_size):
    """Build the batched (unshuffled) test dataset.

    Args:
        batch_size: Number of test examples per batch.

    Returns:
        A `tf.data.Dataset` over the module-level `slices_test`; every element
        is a mapping with the keys 'y' and 'x' holding one batch.
    """
    test_examples = tf.data.Dataset.from_tensor_slices(slices_test)
    return test_examples.batch(batch_size)

In [10]:
# Centralized test set, served in batches of 32 examples.
test_dataset = create_tf_dataset_for_testing(32)

### Slice the Tensors for each Client

We will cut the training data, i.e., (`X_train_tensor`, `y_train_tensor`) to equal parts, each part corresponding to one Client. We want to give the result back as a dictionary with key `client_id` and value the training tensor data. Note that the training Tensor data will be processed according to the standard, i.e., an `OrderedDict` with `y` and `x` attributes.

In [11]:
def create_data_for_clients(num_clients):
    """Split the training tensors into `num_clients` contiguous shards.

    The module-level `X_train_tensor` / `y_train_tensor` are cut along the
    first axis into consecutive, non-overlapping slices of (almost) equal
    length; together the slices cover every training example exactly once.

    Args:
        num_clients: Number of shards to produce. A value <= 0 yields an
            empty dictionary.

    Returns:
        dict mapping 'client_<i>' to an `OrderedDict` with the keys 'y'
        (labels) and 'x' (features) for that client.
    """
    # Read the size from the tensor itself instead of recomputing it from the
    # globals `n` and `test_size`: this cannot drift from the real split and
    # avoids float rounding.
    n_train = X_train_tensor.shape[0]

    client_slices_train = {}

    for i in range(num_clients):
        # Integer arithmetic keeps the shard boundaries exact.
        start_idx = i * n_train // num_clients
        end_idx = (i + 1) * n_train // num_clients

        client_slices_train[f'client_{i}'] = collections.OrderedDict([
            ('y', y_train_tensor[start_idx:end_idx]),
            ('x', X_train_tensor[start_idx:end_idx]),
        ])

    return client_slices_train

### Create TF friendly data for each Client

Given a Tensor slice (i.e. a value of `client_slices_train["client_id"]`), we convert it to a highly optimized `tf.data.Dataset` to prepare it for training.

In [12]:
def create_tf_dataset_for_client(client_tensor_slices, batch_size, shuffle_buffer_size, num_steps_until_rtc_check, seed):
    """Build the training input pipeline for one client.

    Args:
        client_tensor_slices: Mapping with the keys 'y' and 'x' holding the
            client's label and feature tensors.
        batch_size: Number of examples per batch.
        shuffle_buffer_size: Size of the shuffle buffer.
        num_steps_until_rtc_check: Number of batches yielded per pass over
            the dataset.
        seed: Random seed for the shuffle (may be None).

    Returns:
        A `tf.data.Dataset` yielding at most `num_steps_until_rtc_check`
        shuffled batches.
    """
    # `take` comes before `prefetch`: the yielded batches are the same, but
    # the pipeline no longer prefetches batches that `take` would discard,
    # and `prefetch` is the final transformation as tf.data recommends.
    return (
        tf.data.Dataset.from_tensor_slices(client_tensor_slices)
        .shuffle(buffer_size=shuffle_buffer_size, seed=seed)
        .batch(batch_size)
        .take(num_steps_until_rtc_check)
        .prefetch(tf.data.AUTOTUNE)
    )

### Create Federated Learning data

In [13]:
def create_federated_data(client_slices_train, batch_size, shuffle_buffer_size, num_steps_until_rtc_check, seed=None):
    """Create one training dataset per client.

    Args:
        client_slices_train: Mapping from client id to that client's tensor
            slices (a mapping with the keys 'y' and 'x').
        batch_size: Number of examples per batch.
        shuffle_buffer_size: Shuffle buffer size used for every client.
        num_steps_until_rtc_check: Number of batches each client dataset
            yields per pass.
        seed: Optional shuffle seed shared by all clients.

    Returns:
        List of client datasets, in the iteration order of
        `client_slices_train`.
    """
    per_client_datasets = []
    # Only the tensor slices are needed; the client ids are not used here.
    for tensor_slices in client_slices_train.values():
        per_client_datasets.append(
            create_tf_dataset_for_client(
                tensor_slices, batch_size, shuffle_buffer_size, num_steps_until_rtc_check, seed
            )
        )
    return per_client_datasets

## PA-Classifiers (binary classification)

![PA](images/PA_binary_classifiers.png)

In [14]:
@tf.function
def client_train(model, dataset, C):
    """Apply mini-batch PA-I updates to `model` for every batch in `dataset`.

    For each instance (x, y) of a batch, the hinge loss
    ``l = max(0, 1 - y * <w, x>)`` and the PA-I step size
    ``tau = min(C, l / ||x||^2)`` are computed with the current weights. The
    weights are then moved by the batch mean of ``tau * y * x``.

    Args:
        model: `tf.Variable` holding the weight column vector, shape (d, 1).
            It is updated in place.
        dataset: Iterable of batches; each batch is a mapping with 'x' of
            shape (batch, d) and 'y' of shape (batch,). Labels are expected
            to be -1 / +1.
        C: Aggressiveness parameter, the upper bound on every step size.

    Returns:
        The same `model` variable, after all updates were applied.
    """
    for batch in dataset:
        x_batch, y_batch = batch['x'], tf.expand_dims(batch['y'], axis=1)

        # dot(w, x) for every instance of the batch, shape=(batchsize, 1)
        weights_dot_x_batch = tf.matmul(x_batch, model)

        # Hinge loss suffered on every instance, shape=(batchsize, 1)
        loss_batch = tf.maximum(0., 1. - tf.multiply(y_batch, weights_dot_x_batch))

        # ||x||^2 for every instance, shape=(batchsize, 1)
        norm_batch = tf.expand_dims(tf.reduce_sum(tf.square(x_batch), axis=1), axis=1)

        # PA-I step size: tau = min(C, loss / ||x||^2). C is a cap, so an
        # instance with zero loss leaves the weights untouched (passive) and
        # no step ever exceeds C. An all-zero x gives loss / 0 = inf, which
        # the cap turns into C, and C * y * 0 contributes nothing.
        t_batch = tf.minimum(C, tf.divide(loss_batch, norm_batch))

        # tau * y * x for every instance, shape=(batchsize, d)
        t_y_x_batch = tf.multiply(t_batch, tf.multiply(y_batch, x_batch))

        # Mini-batch variant: average the per-instance updates, shape=(d, 1)
        t_y_x_update = tf.expand_dims(tf.reduce_mean(t_y_x_batch, axis=0), axis=1)

        model.assign_add(t_y_x_update)

    return model

## Accuracy Testing

In [15]:
@tf.function
def accuracy(model, dataset):
    """Compute the classification accuracy of `model` over `dataset`.

    Predictions are ``sign(<w, x>)``. The result is the number of correctly
    predicted examples divided by the total number of examples, so a smaller
    final batch is weighted by its real size instead of counting as much as
    a full batch.

    Args:
        model: Weight column vector, shape (d, 1).
        dataset: Iterable of batches; each batch is a mapping with 'x' of
            shape (batch, d) and 'y' of shape (batch,). Labels are expected
            to be -1 / +1.

    Returns:
        Scalar float32 tensor with the fraction of correct predictions
        (NaN when `dataset` yields no examples).
    """
    # Integer counters keep the tally exact regardless of the dataset size.
    num_correct = tf.constant(0, dtype=tf.int64)
    num_examples = tf.constant(0, dtype=tf.int64)

    # AutoGraph turns this loop over the dataset into graph code.
    for batch in dataset:
        x_batch, y_batch = batch['x'], tf.expand_dims(batch['y'], axis=1)

        # Predictions for the batch, shape=(batchsize, 1). A dot product of
        # exactly 0 gives sign 0, which matches neither label value.
        y_pred_batch = tf.sign(tf.matmul(x_batch, model))

        hits = tf.equal(y_pred_batch, y_batch)
        num_correct += tf.reduce_sum(tf.cast(hits, tf.int64))
        num_examples += tf.shape(y_batch, out_type=tf.int64)[0]

    ratio = tf.cast(num_correct, tf.float64) / tf.cast(num_examples, tf.float64)
    return tf.cast(ratio, tf.float32)

## Training Loop

In [20]:
import sys

In [17]:
NUM_CLIENTS = 20


def _new_weights():
    """Return a fresh, zero-initialised (d, 1) float32 weight variable."""
    return tf.Variable(tf.zeros(shape=(d, 1)), trainable=True, name='weights', dtype=tf.float32)


# Global weight vector.
model = _new_weights()

# One independent weight vector per client.
client_models = [_new_weights() for _ in range(NUM_CLIENTS)]

In [21]:
@tf.function
def train_everything(model, client_models):
    """Run the client training rounds.

    The training data is sharded into one slice per entry of
    `client_models`. In every round each client performs
    `NUM_STEPS_UNTIL_RTC_CHECK` PA-I mini-batch steps on its own shard via
    `client_train`, and a running step counter is printed to stdout.

    Args:
        model: Global weight variable. It is currently not read or written
            by this function (the client models are never aggregated).
        client_models: List of per-client weight variables, updated in place.

    Returns:
        None.
    """
    BATCH_SIZE = 32
    NUM_STEPS_UNTIL_RTC_CHECK = 1
    C = 0.01  # PA-I aggressiveness parameter

    # Derive the client count from the models actually passed in, so the
    # shards, the shuffle buffer size and the models cannot get out of sync.
    num_clients = len(client_models)
    if num_clients == 0:
        return

    client_slices_train = create_data_for_clients(num_clients)

    federated_dataset = create_federated_data(
        client_slices_train=client_slices_train,
        batch_size=BATCH_SIZE,
        shuffle_buffer_size=n // num_clients,
        num_steps_until_rtc_check=NUM_STEPS_UNTIL_RTC_CHECK,
        seed=seed
    )

    # Rounds 1..499. Looping over `tf.range` makes AutoGraph emit one graph
    # loop; a Python `range` would be unrolled at trace time into one copy of
    # the body per round.
    for r in tf.range(1, 500):
        # The loop over the (Python) list of clients is unrolled on purpose:
        # every client has its own variable and its own dataset.
        for k, (client_model, client_dataset) in enumerate(zip(client_models, federated_dataset), start=1):
            # Running step counter: 1, 2, 3, ... across rounds and clients.
            tf.print("iter: ", (r - 1) * num_clients + k, output_stream=sys.stdout)
            client_train(client_model, client_dataset, C)

In [None]:
train_everything(model, client_models)

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
iter:  1
iter:  2
iter:  3
iter:  4
iter:  5
iter:  6
iter:  7
iter:  8
iter:  9
iter:  10
iter:  11
iter:  12
iter:  13
iter:  14
iter:  15
iter:  16
iter:  17
iter:  18
iter:  19
iter:  20
iter:  21
iter:  22
iter:  23
iter:  24
iter:  25
iter:  26
iter:  27
iter:  28
iter:  29
iter:  30
iter:  31
iter:  32
iter:  33
iter:  34
iter:  35
iter:  36
iter:  37
iter:  38
iter:  39
iter:  40
iter:  41
iter:  42
iter:  43
iter:  44
iter:  45
iter:  46
iter:  47
iter:  48
iter:  49
iter:  50
iter:  51
iter:  52
iter:  53
iter:  54
iter:  55
iter:  56
iter:  57
iter:  58
iter:  59
iter:  60
iter:  61
iter:  62
iter:  63
iter:  64
iter:  65
iter:  66
iter:  67
iter:  68
iter:  69
iter:  70
iter:  71
iter:  72
iter:  73
iter:  74
iter:  75
iter:  76
iter:  77
iter:  78
iter:  79
iter:  80
iter:  8

iter:  710
iter:  711
iter:  712
iter:  713
iter:  714
iter:  715
iter:  716
iter:  717
iter:  718
iter:  719
iter:  720
iter:  721
iter:  722
iter:  723
iter:  724
iter:  725
iter:  726
iter:  727
iter:  728
iter:  729
iter:  730
iter:  731
iter:  732
iter:  733
iter:  734
iter:  735
iter:  736
iter:  737
iter:  738
iter:  739
iter:  740
iter:  741
iter:  742
iter:  743
iter:  744
iter:  745
iter:  746
iter:  747
iter:  748
iter:  749
iter:  750
iter:  751
iter:  752
iter:  753
iter:  754
iter:  755
iter:  756
iter:  757
iter:  758
iter:  759
iter:  760
iter:  761
iter:  762
iter:  763
iter:  764
iter:  765
iter:  766
iter:  767
iter:  768
iter:  769
iter:  770
iter:  771
iter:  772
iter:  773
iter:  774
iter:  775
iter:  776
iter:  777
iter:  778
iter:  779
iter:  780
iter:  781
iter:  782
iter:  783
iter:  784
iter:  785
iter:  786
iter:  787
iter:  788
iter:  789
iter:  790
iter:  791
iter:  792
iter:  793
iter:  794
iter:  795
iter:  796
iter:  797
iter:  798
iter:  799
iter:  800

iter:  1417
iter:  1418
iter:  1419
iter:  1420
iter:  1421
iter:  1422
iter:  1423
iter:  1424
iter:  1425
iter:  1426
iter:  1427
iter:  1428
iter:  1429
iter:  1430
iter:  1431
iter:  1432
iter:  1433
iter:  1434
iter:  1435
iter:  1436
iter:  1437
iter:  1438
iter:  1439
iter:  1440
iter:  1441
iter:  1442
iter:  1443
iter:  1444
iter:  1445
iter:  1446
iter:  1447
iter:  1448
iter:  1449
iter:  1450
iter:  1451
iter:  1452
iter:  1453
iter:  1454
iter:  1455
iter:  1456
iter:  1457
iter:  1458
iter:  1459
iter:  1460
iter:  1461
iter:  1462
iter:  1463
iter:  1464
iter:  1465
iter:  1466
iter:  1467
iter:  1468
iter:  1469
iter:  1470
iter:  1471
iter:  1472
iter:  1473
iter:  1474
iter:  1475
iter:  1476
iter:  1477
iter:  1478
iter:  1479
iter:  1480
iter:  1481
iter:  1482
iter:  1483
iter:  1484
iter:  1485
iter:  1486
iter:  1487
iter:  1488
iter:  1489
iter:  1490
iter:  1491
iter:  1492
iter:  1493
iter:  1494
iter:  1495
iter:  1496
iter:  1497
iter:  1498
iter:  1499
iter

iter:  2100
iter:  2101
iter:  2102
iter:  2103
iter:  2104
iter:  2105
iter:  2106
iter:  2107
iter:  2108
iter:  2109
iter:  2110
iter:  2111
iter:  2112
iter:  2113
iter:  2114
iter:  2115
iter:  2116
iter:  2117
iter:  2118
iter:  2119
iter:  2120
iter:  2121
iter:  2122
iter:  2123
iter:  2124
iter:  2125
iter:  2126
iter:  2127
iter:  2128
iter:  2129
iter:  2130
iter:  2131
iter:  2132
iter:  2133
iter:  2134
iter:  2135
iter:  2136
iter:  2137
iter:  2138
iter:  2139
iter:  2140
iter:  2141
iter:  2142
iter:  2143
iter:  2144
iter:  2145
iter:  2146
iter:  2147
iter:  2148
iter:  2149
iter:  2150
iter:  2151
iter:  2152
iter:  2153
iter:  2154
iter:  2155
iter:  2156
iter:  2157
iter:  2158
iter:  2159
iter:  2160
iter:  2161
iter:  2162
iter:  2163
iter:  2164
iter:  2165
iter:  2166
iter:  2167
iter:  2168
iter:  2169
iter:  2170
iter:  2171
iter:  2172
iter:  2173
iter:  2174
iter:  2175
iter:  2176
iter:  2177
iter:  2178
iter:  2179
iter:  2180
iter:  2181
iter:  2182
iter

iter:  2783
iter:  2784
iter:  2785
iter:  2786
iter:  2787
iter:  2788
iter:  2789
iter:  2790
iter:  2791
iter:  2792
iter:  2793
iter:  2794
iter:  2795
iter:  2796
iter:  2797
iter:  2798
iter:  2799
iter:  2800
iter:  2801
iter:  2802
iter:  2803
iter:  2804
iter:  2805
iter:  2806
iter:  2807
iter:  2808
iter:  2809
iter:  2810
iter:  2811
iter:  2812
iter:  2813
iter:  2814
iter:  2815
iter:  2816
iter:  2817
iter:  2818
iter:  2819
iter:  2820
iter:  2821
iter:  2822
iter:  2823
iter:  2824
iter:  2825
iter:  2826
iter:  2827
iter:  2828
iter:  2829
iter:  2830
iter:  2831
iter:  2832
iter:  2833
iter:  2834
iter:  2835
iter:  2836
iter:  2837
iter:  2838
iter:  2839
iter:  2840
iter:  2841
iter:  2842
iter:  2843
iter:  2844
iter:  2845
iter:  2846
iter:  2847
iter:  2848
iter:  2849
iter:  2850
iter:  2851
iter:  2852
iter:  2853
iter:  2854
iter:  2855
iter:  2856
iter:  2857
iter:  2858
iter:  2859
iter:  2860
iter:  2861
iter:  2862
iter:  2863
iter:  2864
iter:  2865
iter

iter:  3466
iter:  3467
iter:  3468
iter:  3469
iter:  3470
iter:  3471
iter:  3472
iter:  3473
iter:  3474
iter:  3475
iter:  3476
iter:  3477
iter:  3478
iter:  3479
iter:  3480
iter:  3481
iter:  3482
iter:  3483
iter:  3484
iter:  3485
iter:  3486
iter:  3487
iter:  3488
iter:  3489
iter:  3490
iter:  3491
iter:  3492
iter:  3493
iter:  3494
iter:  3495
iter:  3496
iter:  3497
iter:  3498
iter:  3499
iter:  3500
iter:  3501
iter:  3502
iter:  3503
iter:  3504
iter:  3505
iter:  3506
iter:  3507
iter:  3508
iter:  3509
iter:  3510
iter:  3511
iter:  3512
iter:  3513
iter:  3514
iter:  3515
iter:  3516
iter:  3517
iter:  3518
iter:  3519
iter:  3520
iter:  3521
iter:  3522
iter:  3523
iter:  3524
iter:  3525
iter:  3526
iter:  3527
iter:  3528
iter:  3529
iter:  3530
iter:  3531
iter:  3532
iter:  3533
iter:  3534
iter:  3535
iter:  3536
iter:  3537
iter:  3538
iter:  3539
iter:  3540
iter:  3541
iter:  3542
iter:  3543
iter:  3544
iter:  3545
iter:  3546
iter:  3547
iter:  3548
iter

iter:  4149
iter:  4150
iter:  4151
iter:  4152
iter:  4153
iter:  4154
iter:  4155
iter:  4156
iter:  4157
iter:  4158
iter:  4159
iter:  4160
iter:  4161
iter:  4162
iter:  4163
iter:  4164
iter:  4165
iter:  4166
iter:  4167
iter:  4168
iter:  4169
iter:  4170
iter:  4171
iter:  4172
iter:  4173
iter:  4174
iter:  4175
iter:  4176
iter:  4177
iter:  4178
iter:  4179
iter:  4180
iter:  4181
iter:  4182
iter:  4183
iter:  4184
iter:  4185
iter:  4186
iter:  4187
iter:  4188
iter:  4189
iter:  4190
iter:  4191
iter:  4192
iter:  4193
iter:  4194
iter:  4195
iter:  4196
iter:  4197
iter:  4198
iter:  4199
iter:  4200
iter:  4201
iter:  4202
iter:  4203
iter:  4204
iter:  4205
iter:  4206
iter:  4207
iter:  4208
iter:  4209
iter:  4210
iter:  4211
iter:  4212
iter:  4213
iter:  4214
iter:  4215
iter:  4216
iter:  4217
iter:  4218
iter:  4219
iter:  4220
iter:  4221
iter:  4222
iter:  4223
iter:  4224
iter:  4225
iter:  4226
iter:  4227
iter:  4228
iter:  4229
iter:  4230
iter:  4231
iter

iter:  4832
iter:  4833
iter:  4834
iter:  4835
iter:  4836
iter:  4837
iter:  4838
iter:  4839
iter:  4840
iter:  4841
iter:  4842
iter:  4843
iter:  4844
iter:  4845
iter:  4846
iter:  4847
iter:  4848
iter:  4849
iter:  4850
iter:  4851
iter:  4852
iter:  4853
iter:  4854
iter:  4855
iter:  4856
iter:  4857
iter:  4858
iter:  4859
iter:  4860
iter:  4861
iter:  4862
iter:  4863
iter:  4864
iter:  4865
iter:  4866
iter:  4867
iter:  4868
iter:  4869
iter:  4870
iter:  4871
iter:  4872
iter:  4873
iter:  4874
iter:  4875
iter:  4876
iter:  4877
iter:  4878
iter:  4879
iter:  4880
iter:  4881
iter:  4882
iter:  4883
iter:  4884
iter:  4885
iter:  4886
iter:  4887
iter:  4888
iter:  4889
iter:  4890
iter:  4891
iter:  4892
iter:  4893
iter:  4894
iter:  4895
iter:  4896
iter:  4897
iter:  4898
iter:  4899
iter:  4900
iter:  4901
iter:  4902
iter:  4903
iter:  4904
iter:  4905
iter:  4906
iter:  4907
iter:  4908
iter:  4909
iter:  4910
iter:  4911
iter:  4912
iter:  4913
iter:  4914
iter

1. Add an `input_signature` to every `tf.function` to avoid retracing; see [here](https://stackoverflow.com/questions/52774351/how-to-run-parallel-map-fn-when-eager-execution-enabled#:~:text=First%2C%20using%20tf.,once%2C%20so%2C%20the%20time.).