In [95]:
from skimage.transform import resize
from functools import partial
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from keras.utils import Sequence
from keras.utils.np_utils import to_categorical
from keras.layers import Conv2D, MaxPooling2D, Dense, Input, Flatten
from keras.models import Model
from keras.optimizers import SGD
from keras.losses import categorical_crossentropy

import keras.backend as K

import tensorflow as tf

In [96]:
def simple_cnn(img_size, n_classes):
    """Build the (uncompiled) CNN classifier used by every client and the server.

    Architecture: two Conv2D(5x5) + MaxPooling2D(2, stride 2) stages with 32 and
    64 filters, then Flatten, a 512-unit ReLU Dense layer and a softmax output.
    The Conv2D layers are created without an ``activation`` argument, so they
    use whatever Keras applies by default.

    Args:
        img_size: input shape handed to ``Input`` (e.g. ``(256, 256, 3)``).
        n_classes: number of units in the softmax output layer.

    Returns:
        A Keras ``Model`` mapping the input tensor to class probabilities.
    """
    image_in = Input(img_size)

    # Two identical conv -> pool stages that differ only in filter count.
    features = image_in
    for n_filters in (32, 64):
        features = Conv2D(kernel_size=5, filters=n_filters)(features)
        features = MaxPooling2D(pool_size=2, strides=2)(features)

    flattened = Flatten()(features)
    hidden = Dense(units=512, activation='relu')(flattened)
    probabilities = Dense(units=n_classes, activation='softmax')(hidden)

    return Model(inputs=image_in, outputs=probabilities)
    
def get_model(config):
    """Create the CNN from ``config`` and compile it for training.

    Reads ``config.data.img_size``, ``config.data.n_classes`` and
    ``config.train.learning_rate``. The model is compiled with plain SGD,
    categorical cross-entropy loss and the ``'acc'`` metric.

    Returns:
        The compiled Keras model.
    """
    net = simple_cnn(config.data.img_size, config.data.n_classes)
    optimizer = SGD(lr=config.train.learning_rate)
    net.compile(
        optimizer=optimizer,
        loss='categorical_crossentropy',
        metrics=['acc'],
    )
    return net

class DataGenerator(Sequence):
    """Keras ``Sequence`` serving batches of images and one-hot labels listed in a CSV.

    The CSV must provide ``filename`` and ``label`` columns (both are read in
    ``__init__``). Images are loaded with ``plt.imread`` and only the first
    three channels are kept; no resizing is performed here.
    """

    def __init__(self, df_path, 
                 batch_size, 
                 img_size, 
                 n_classes,
                 client_colm=None,
                 num=None, 
                 shuffle=True):
        """Read the CSV index and prepare the first epoch's sample order.

        Args:
            df_path: path of the CSV file listing the samples.
            batch_size: number of samples per batch.
            img_size: shape of one image; used to size the batch buffer.
            n_classes: number of classes for the one-hot encoding.
            client_colm: name of the CSV column that identifies a client.
                Required when ``num`` is given.
            num: if not ``None``, keep only rows where ``client_colm == num``.
            shuffle: whether to reshuffle the sample order at each epoch end.

        Raises:
            ValueError: if ``num`` is given without ``client_colm``.
        """
        # Fail early with a clear message instead of a KeyError on df[None].
        if num is not None and client_colm is None:
            raise ValueError('client_colm must be provided when num is given')

        self.num = num
        self.shuffle = shuffle
        df = pd.read_csv(df_path)
        # `num` may legitimately be 0, hence the explicit None check.
        if num is not None:
            rows = df[df[client_colm] == self.num]
        else:
            rows = df
        self.batch_size = batch_size
        self.img_size = img_size
        self.n_classes = n_classes
        self.filenames = rows.filename.values
        self.labels = rows.label.values

        # Builds self.indexes (and shuffles it if requested).
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch.

        Trailing samples that do not fill a whole batch are not served.
        """
        return len(self.filenames) // self.batch_size

    def __getitem__(self, index):
        """Return batch number ``index`` as ``(X, y)``."""
        # Positions of this batch inside the (possibly shuffled) index array.
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        X, y = self.__data_generation(indexes)

        return X, y

    def on_epoch_end(self):
        """Rebuild the sample order, shuffling it when ``shuffle`` is set."""
        self.indexes = np.arange(len(self.filenames))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, batch_inds):
        """Load the images and labels for the given sample indices.

        Returns:
            ``(X, y)`` where ``X`` has shape ``(len(batch_inds), *img_size)`` and
            ``y`` is the one-hot encoding of the labels with ``n_classes`` columns.
        """
        # Size the buffers by the indices actually received so that a short
        # slice never leaves uninitialised np.empty rows in the batch.
        n_samples = len(batch_inds)
        X = np.empty((n_samples, *self.img_size))
        y = np.empty(n_samples, dtype=int)

        for i, idx in enumerate(batch_inds):
            # Keep the first three channels only (drops e.g. an alpha channel).
            X[i,] = plt.imread(self.filenames[idx])[..., :3]
            y[i] = self.labels[idx]

        return X, to_categorical(y, num_classes=self.n_classes)

In [97]:
from easydict import EasyDict
# Experiment configuration, read by get_model, client_update and fed_averaging.
config = EasyDict({
    'data': {
        'train_df_path': 'train_resized.csv',
        'val_df_path': 'val_resized.csv',
        'img_size': (256, 256, 3),       # input shape given to the network
        'batch_size': 10,
        'n_classes': 12,
        'client_column': 'shard_iid',    # CSV column whose value selects a client's rows
    },
    'train': {
        'learning_rate': 1e-3,
        'epochs': 5,                     # local epochs per client update
        'client_fraction': 0.2,          # share of clients sampled each round
        'num_clients': 10,
        'num_rounds': 10000,
    },
    'log': {
        'path': './results/01-fed-avg-iid',
    },

    # Optional: uncomment to resume from a previous run's checkpoint.
#     'resume': {
#         'path': './results/03-fed-avg-non_iid'
#     }
})


In [98]:
def client_update(config, num, model, weights):
    """Train ``model`` on one client's data, starting from ``weights``.

    Prints the client id and the client's label distribution, loads ``weights``
    into ``model``, fits it for ``config.train.epochs`` epochs on the rows of
    the training CSV whose ``config.data.client_column`` equals ``num``.

    Returns:
        Tuple ``(weights, n_examples, loss, acc)``: the model weights after
        training, the number of files in the client's dataset, and the loss and
        accuracy of the last epoch.
    """
    data_cfg = config.data

    # Debug output: which client is training and how its labels are distributed.
    print(num)
    print(pd.DataFrame(pd.read_csv(data_cfg.train_df_path).query('{}=={}'.format(
        data_cfg.client_column, num)).label.value_counts()).T)

    # Start local training from the supplied (global) weights.
    model.set_weights(weights)

    client_data = DataGenerator(
        df_path=data_cfg.train_df_path,
        batch_size=data_cfg.batch_size,
        img_size=data_cfg.img_size,
        n_classes=data_cfg.n_classes,
        client_colm=data_cfg.client_column,
        num=num,
    )
    fit_result = model.fit_generator(
        client_data,
        epochs=config.train.epochs,
        verbose=True,
        workers=4,
        use_multiprocessing=False,
    )

    updated_weights = model.get_weights()
    metrics = fit_result.history
    return (updated_weights,
            len(client_data.filenames),
            metrics['loss'][-1],
            metrics['acc'][-1])
    

In [99]:
def average_weights(weights, n_examples):
    """Return the example-weighted average of several models' weight lists.

    Args:
        weights: sequence with one entry per client; each entry is that
            client's list of weight arrays (same layout for every client).
        n_examples: sequence with one example count per client, used as that
            client's averaging weight.

    Returns:
        A list with one array per weight tensor, computed as
        ``sum_k(n_k * w_k) / sum_k(n_k)``. Empty ``weights`` yields ``[]``.

    Raises:
        ValueError: if ``weights`` and ``n_examples`` differ in length (which
            would otherwise broadcast silently into a wrong result), or if the
            total example count is not positive (which would produce NaN/inf).
    """
    weights = list(weights)
    if not weights:
        return []

    counts = np.asarray(n_examples)
    if counts.ndim != 1 or counts.shape[0] != len(weights):
        raise ValueError(
            'expected one example count per client: got {} weight sets and '
            'n_examples of shape {}'.format(len(weights), counts.shape))

    total_examples = np.sum(counts)
    if total_examples <= 0:
        raise ValueError('total number of examples must be positive, got {}'.format(
            total_examples))

    # zip(*weights) regroups the arrays per tensor; stacking on a new last axis
    # lets the per-client counts broadcast along it.
    return [np.sum(np.stack(w, axis=-1) * counts, axis=-1) / total_examples
            for w in zip(*weights)]

def fed_averaging(config):
    """Run federated averaging for ``config.train.num_rounds`` rounds.

    Each round: sample a fraction of the clients, train a copy of the global
    weights on each sampled client via ``client_update``, replace the global
    weights with the example-weighted average of the client weights, then
    evaluate the global model on the validation set.

    Side effects:
        * creates ``config.log.path`` and ``<log.path>/csvlogs``;
        * rewrites ``csvlogs/train`` and ``csvlogs/valid`` (CSV with columns
          round, loss, acc) after every round;
        * saves the global weights to ``<log.path>/ckpt`` whenever validation
          accuracy exceeds the best value seen so far.

    If ``config`` has a ``resume`` entry, the global weights are loaded from
    ``<resume.path>/ckpt`` and the best score is initialised from the maximum
    ``acc`` in ``<resume.path>/csvlogs/valid``. Round numbering and the logs
    still start from scratch.
    NOTE(review): if ``resume.path`` equals ``log.path`` the previous logs are
    overwritten after the first round.
    """
    logpath = os.path.join(config.log.path, 'csvlogs')
    # Also creates config.log.path, which is a parent of logpath.
    os.makedirs(logpath, exist_ok=True)

    model = get_model(config)
    # A single scratch model is reused for every client; its weights are reset
    # to the global weights inside client_update.
    client_model = get_model(config)
    valid_data = DataGenerator(df_path=config.data.val_df_path, 
                          batch_size=config.data.batch_size, 
                          img_size=config.data.img_size, 
                          n_classes=config.data.n_classes,
                          shuffle=False)

    # Per-round records; turned into DataFrames only when written to disk.
    log_columns = ['round', 'loss', 'acc']
    valid_records = []
    train_records = []

    best_score = 0

    if 'resume' in config.keys():
        resume_ckpt = os.path.join(config.resume.path, 'ckpt')
        print('Resuming from {}'.format(resume_ckpt))
        model.load_weights(resume_ckpt)
        best_score = pd.read_csv(os.path.join(config.resume.path, 'csvlogs', 'valid')).acc.max()
    print('Best valid acc so far: {}'.format(best_score))

    # Number of clients trained per round (at least one); constant across rounds.
    m = int(np.ceil(max(config.train.client_fraction * config.train.num_clients, 1)))

    for t in range(1, config.train.num_rounds + 1):
        print('Round {}'.format(t))
        print('-' * 10)
        print('Training')
        global_weights = model.get_weights()
        clients = np.random.permutation(config.train.num_clients)[:m]

        local_results = [client_update(config, client, client_model, global_weights)
                         for client in clients]

        local_weights, n_examples, _tloss, _tacc = zip(*local_results)
        # Training metrics are the plain mean over the sampled clients.
        tloss = np.mean(_tloss)
        tacc = np.mean(_tacc)
        model.set_weights(average_weights(local_weights, n_examples))
        print('train_loss {:.4f}, train_acc {:.4f}'.format(tloss, tacc))
        print('Validation')
        vloss, vacc = model.evaluate_generator(valid_data,
                                               verbose=True,
                                               workers=4, use_multiprocessing=True)

        valid_records.append({'round': t, 'loss': vloss, 'acc': vacc})
        train_records.append({'round': t, 'loss': tloss, 'acc': tacc})

        if vacc > best_score:
            model.save_weights(os.path.join(config.log.path, 'ckpt'))
            best_score = vacc

        # Rewrite the full logs every round so an interrupted run keeps its history.
        pd.DataFrame(valid_records, columns=log_columns).to_csv(
            os.path.join(logpath, 'valid'), index=False)
        pd.DataFrame(train_records, columns=log_columns).to_csv(
            os.path.join(logpath, 'train'), index=False)

        print('val_loss {:.4f}, val_acc {:.4f}'.format(vloss, vacc))
        print()
        print()
        

In [None]:
# Clear the Keras backend session before the models are built, then run
# federated training with the `config` defined in the configuration cell.
K.clear_session()
fed_averaging(config)

Best valid acc so far: 0
Round 1
----------
Training
0
       6   3   10  11  8   5   0   9   2   1   4   7 
label  61  52  42  41  37  35  26  24  24  23  19  16
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
7
       6   3   5   8   10  1   2   11  7   4   0   9 
label  67  60  44  42  38  31  25  22  22  20  16  13
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
train_loss 2.2740, train_acc 0.2075
Validation
val_loss 2.4532, val_acc 0.1187


Round 2
----------
Training
7
       6   3   5   8   10  1   2   11  7   4   0   9 
label  67  60  44  42  38  31  25  22  22  20  16  13
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
5
       3   5   10  6   1   8   11  2   4   0   9   7 
label  57  53  45  44  40  39  26  24  19  19  18  16
Epoch 1/5
Epoch 2/5
Epoch 3/5