In [None]:
#@title Hyperparameters

# Experiment configuration; the `#@param` annotations render these as Colab form fields.
# RL algorithm (Stable-Baselines3) that learns the per-client aggregation weights.
RL_ALGORITHM_NAME = 'TD3' #@param ["TD3", "DDPG", "SAC", "PPO"]
# fedjax dataset/model pair (see FL_Model).
FL_MODEL_NAME = 'EMNIST' #@param ["EMNIST", "Shakespeare"]
# Per-client observation: 'divergence' = L2 distance of the client's parameters from the
# round's starting global parameters; otherwise the named metric of the client's model.
OBS_TYPE = 'divergence' #@param ["accuracy", "loss", "divergence"]
# Reward = this metric of the aggregated global model ('loss' is negated so higher is better).
REWARD_TYPE = 'accuracy' #@param ["accuracy", "loss"]
# 'yes' pairs each client's observation with its number of examples.
OBS_INCLUDE_CLIENT_SIZE = 'no' #@param ["no", "yes"]
# 'scale_raw': weighted sum of client updates using the raw weights; 'scale_sum': weighted mean.
AGG_STRATEGY = 'scale_sum' #@param ["scale_sum", "scale_raw"]
# RL timesteps (one env step = one FL round) per Optuna trial and for the default-params run.
N_RL_TRAINING_ROUNDS = 5 #@param {type:"integer"}
# RL timesteps for the extended training of the selected model.
N_BEST_RL_TRAINING_ROUNDS = 5 #@param {type:"integer"}
# FL rounds per RL episode and per evaluation run.
N_MODEL_FL_EVAL_ROUNDS = 5 #@param {type:"integer"}
# Episodes for evaluate_policy, and number of evaluation runs averaged in evaluate_main.
N_FL_MODEL_EVAL_EPISODES = 2 #@param {type:"integer"}
# Local training passes over the sampled clients per FL round.
N_FL_LOCAL_TRAINING_ROUNDS = 1 #@param {type:"integer"}
# Clients sampled per FL round; also the length of the action and observation vectors.
N_FL_CLIENTS = 5 #@param {type:"integer"}
# Number of Optuna hyperparameter-search trials.
N_OPTUNA_TRIALS = 1 #@param {type:"integer"}

In [None]:
#@title Dependencies + Imports

# FL
!pip install -q fedjax
!pip install -q --upgrade -q jax jaxlib

# RL
!pip install -q gym==0.25.2
!apt install -q swig cmake
!pip install -q stable-baselines3[extra] box2d box2d-kengz
!pip install -q optuna


# Imports 
import jax
import jax.numpy as jnp
import numpy as np
import itertools
import fedjax

import time
import matplotlib.pyplot as plt
import torch
from scipy.special import softmax

import gym
import gym.spaces as spaces
from gym.spaces import Discrete, MultiDiscrete
import os
from stable_baselines3 import TD3, DDPG, SAC, PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.results_plotter import load_results, ts2xy
from stable_baselines3.common.noise import NormalActionNoise
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common import results_plotter
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv, VecNormalize

import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler, RandomSampler


# Mount Google Drive so the models and logs written under /content/gdrive/MyDrive persist
# across Colab sessions (force_remount re-mounts even if already mounted).
from google.colab import drive
drive.mount('/content/gdrive/', force_remount=True)

# RL & FL Models

In [None]:
class RL_Model:
  """Bundles a Stable-Baselines3 algorithm class with its Optuna search space.

  Attributes:
    model: the SB3 algorithm class (TD3, DDPG, SAC or PPO) -- a class, not an instance.
    name: the algorithm name passed to the constructor.
    optuna_params: bound method mapping an Optuna trial to constructor kwargs.
  """

  SUPPORTED_NAMES = ('TD3', 'DDPG', 'SAC', 'PPO')

  def __init__(self, name):
    """Select the algorithm called ``name``.

    Raises:
      ValueError: if ``name`` is not in SUPPORTED_NAMES. (Previously an unknown
        name produced an attribute-less object that failed much later.)
    """
    if name == 'TD3':
      self.model = TD3
      self.optuna_params = self.optuna_params_td3
    elif name == 'DDPG':
      self.model = DDPG
      self.optuna_params = self.optuna_params_ddpg
    elif name == 'SAC':
      self.model = SAC
      self.optuna_params = self.optuna_params_sac
    elif name == 'PPO':
      self.model = PPO
      self.optuna_params = self.optuna_params_ppo
    else:
      raise ValueError(
          f"Unknown RL algorithm {name!r}; expected one of {self.SUPPORTED_NAMES}")
    self.name = name

  # NOTE(review): the learning_rate and tau upper bounds (0.9) are far above the
  # SB3 defaults for these algorithms -- confirm the search ranges are intended.

  def optuna_params_td3(self, trial):
    """Sample TD3 constructor kwargs from ``trial``."""
    return {
        'gamma': trial.suggest_float('gamma', 0.1, 0.99),
        'learning_rate': trial.suggest_float('learning_rate', 1e-3, 0.9),
        'learning_starts': trial.suggest_int('learning_starts', 100, 10_000),
        'tau': trial.suggest_float('tau', 5e-3, 0.9),
        'target_policy_noise': trial.suggest_float('target_policy_noise', 1e-1, 5e-1),
    }

  def optuna_params_sac(self, trial):
    """Sample SAC constructor kwargs from ``trial``."""
    return {
        'gamma': trial.suggest_float('gamma', 0.1, 0.99),
        'learning_rate': trial.suggest_float('learning_rate', 1e-4, 0.9),
        'learning_starts': trial.suggest_int('learning_starts', 100, 10_000),
        'tau': trial.suggest_float('tau', 5e-3, 0.9),
    }

  def optuna_params_ddpg(self, trial):
    """Sample DDPG constructor kwargs from ``trial``."""
    return {
        'gamma': trial.suggest_float('gamma', 0.1, 0.99),
        'learning_rate': trial.suggest_float('learning_rate', 1e-3, 0.9),
        'learning_starts': trial.suggest_int('learning_starts', 100, 10_000),
        'tau': trial.suggest_float('tau', 5e-3, 0.9),
    }

  def optuna_params_ppo(self, trial):
    """Sample PPO constructor kwargs from ``trial``."""
    return {
        'gamma': trial.suggest_float('gamma', 0.1, 0.99),
        'learning_rate': trial.suggest_float('learning_rate', 3e-4, 0.9),
        'gae_lambda': trial.suggest_float('gae_lambda', 0.1, 0.99),
        'clip_range': trial.suggest_float('clip_range', 0.1, 0.99),
    }

In [None]:
class FL_Model:
  """fedjax dataset, model and training helpers for one federated task.

  Reads the notebook globals OBS_TYPE, REWARD_TYPE and N_FL_CLIENTS.

  Attributes:
    train, test: federated train/test splits from fedjax.datasets.
    fl_model: the fedjax model.
    accuracy_string, loss_string: keys of those metrics in evaluate_model output.
    obs_metric_string, reward_metric_string: metric keys selected by OBS_TYPE / REWARD_TYPE.
    rng, init_params: PRNGKey(0) and the model parameters initialised from it.
    grad_fn: fedjax.model_grad of fl_model.
    client_optimizer: client-side SGD with learning rate 0.1.
    client_sampler, client_test_sampler: uniform samplers of N_FL_CLIENTS clients (seed 1).
    batched_train_data, batched_test_data: first 16 / 8 padded batches (batch size 128).
  """

  SUPPORTED_NAMES = ('EMNIST', 'Shakespeare')

  def __init__(self, name):
    """Load the task called ``name``.

    Raises:
      ValueError: if ``name`` is not in SUPPORTED_NAMES.
    """
    if name == 'EMNIST':
      self.train, self.test = fedjax.datasets.emnist.load_data(only_digits=False)
      self.fl_model = fedjax.models.emnist.create_conv_model(only_digits=False)
      self.accuracy_string = 'accuracy'
      self.loss_string = 'loss'
    elif name == 'Shakespeare':
      self.train, self.test = fedjax.datasets.shakespeare.load_data()
      self.fl_model = fedjax.models.shakespeare.create_lstm_model()
      self.accuracy_string = 'accuracy_in_vocab'
      self.loss_string = 'sequence_loss'
    else:
      # Fail fast instead of leaving the object half-initialised.
      raise ValueError(
          f"Unknown FL model {name!r}; expected one of {self.SUPPORTED_NAMES}")

    # 'divergence' observations are computed from parameters, not from a metric,
    # so any non-'loss' setting falls back to the accuracy key.
    self.obs_metric_string = self.loss_string if OBS_TYPE == 'loss' else self.accuracy_string
    self.reward_metric_string = self.loss_string if REWARD_TYPE == 'loss' else self.accuracy_string

    self.name = name
    self.rng = jax.random.PRNGKey(0)
    self.init_params = self.fl_model.init(self.rng)
    self.grad_fn = fedjax.model_grad(self.fl_model)
    self.client_optimizer = fedjax.optimizers.sgd(0.1)

    self.client_sampler = fedjax.client_samplers.UniformGetClientSampler(
        self.train, num_clients=N_FL_CLIENTS, seed=1)
    self.client_test_sampler = fedjax.client_samplers.UniformGetClientSampler(
        self.test, num_clients=N_FL_CLIENTS, seed=1)

    # Fixed, small evaluation subsets keep metric computation cheap.
    self.batched_train_data = list(itertools.islice(
        fedjax.padded_batch_federated_data(self.train, batch_size=128), 16))
    self.batched_test_data = list(itertools.islice(
        fedjax.padded_batch_federated_data(self.test, batch_size=128), 8))

In [None]:
# Build the configured RL algorithm wrapper and the FL task (loads the fedjax dataset and model).
RL_MODEL = RL_Model(RL_ALGORITHM_NAME)
FL_MODEL = FL_Model(FL_MODEL_NAME)

# Directories

In [None]:
# One sub-directory per experiment configuration, so runs with different settings
# keep separate models and Monitor logs.
_EXPERIMENT_SUBDIR = f"{RL_MODEL.name}/{N_FL_CLIENTS}/{FL_MODEL.name}/{OBS_INCLUDE_CLIENT_SIZE}/{OBS_TYPE}/{REWARD_TYPE}/{AGG_STRATEGY}"
RL_MODELS_DIR = f"/content/gdrive/MyDrive/FLRL/models/tests/{_EXPERIMENT_SUBDIR}"
RL_LOG_DIR = f"/content/gdrive/MyDrive/FLRL/logs/tests/{_EXPERIMENT_SUBDIR}"

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(RL_MODELS_DIR, exist_ok=True)
os.makedirs(RL_LOG_DIR, exist_ok=True)

# RL/FL Environment

In [None]:
class FLRLEnv(gym.Env):
    """Gym environment in which an RL agent chooses per-client aggregation weights.

    One environment step is one federated round:
      1. the action (one value in [-1, 1] per sampled client) is mapped to
         non-negative aggregation weights by ``get_action``;
      2. the client updates produced in the previous round are aggregated with
         those weights and applied to the global model with a server SGD step;
      3. the new global model is evaluated on the task's batched_train_data to
         produce the reward (the accuracy metric, or the negated loss);
      4. a fresh sample of clients trains locally from the new global model,
         producing the next observation and the updates used by the next step.

    An episode lasts ``N_MODEL_FL_EVAL_ROUNDS`` steps. The action/observation
    size (N_FL_CLIENTS) and the aggregation strategy (AGG_STRATEGY) still come
    from the notebook globals.
    """

    metadata = {'render_modes': ['human']}

    def __init__(self, FL_MODEL, N_MODEL_FL_EVAL_ROUNDS, N_FL_LOCAL_TRAINING_ROUNDS, OBS_TYPE, REWARD_TYPE, OBS_INCLUDE_CLIENT_SIZE):
        """Store the environment configuration.

        The arguments used to be ignored in favour of the same-named globals;
        they are now what the environment actually uses.
        """
        super().__init__()
        self.fl_task = FL_MODEL
        self.n_rounds_per_episode = N_MODEL_FL_EVAL_ROUNDS
        self.n_local_rounds = N_FL_LOCAL_TRAINING_ROUNDS
        self.obs_type = OBS_TYPE
        self.reward_type = REWARD_TYPE
        self.obs_include_client_size = OBS_INCLUDE_CLIENT_SIZE

        self.action_space = spaces.Box(low=-1, high=1, shape=(N_FL_CLIENTS,), dtype=np.float32)
        # With client sizes each observation row is (metric, num_examples).
        obs_shape = (N_FL_CLIENTS, 2) if OBS_INCLUDE_CLIENT_SIZE == 'yes' else (N_FL_CLIENTS,)
        self.observation_space = spaces.Box(low=0, high=np.inf, shape=obs_shape, dtype=np.float32)

    def client_update(self, init_params, client_dataset, client_rng, grad_fn):
        """Train one client with the task's client SGD optimizer, starting at ``init_params``.

        Returns:
          (number of client examples, updated params, init_params - updated params).
        """
        optimizer = self.fl_task.client_optimizer
        opt_state = optimizer.init(init_params)
        params = init_params
        for batch in client_dataset.shuffle_repeat_batch(batch_size=10):
            client_rng, use_rng = jax.random.split(client_rng)
            grads = grad_fn(params, batch, use_rng)
            opt_state, params = optimizer.apply(grads, opt_state, params)
        delta_params = jax.tree_util.tree_map(lambda a, b: a - b, init_params, params)
        return len(client_dataset), params, delta_params

    def _client_observation(self, client_params, update, num_samples):
        """Observation entry for one client, according to ``obs_type``."""
        task = self.fl_task
        if self.obs_type == 'divergence':
            # L2 norm of how far the client moved away from the round's global params.
            metric = fedjax.tree_util.tree_l2_norm(update)
        else:
            metric_key = task.loss_string if self.obs_type == 'loss' else task.accuracy_string
            metric = fedjax.evaluate_model(task.fl_model, client_params, task.batched_train_data)[metric_key]
        if self.obs_include_client_size == 'yes':
            return (metric, num_samples)
        return metric

    def run_one_fl_round(self, global_params):
        """Train a fresh sample of clients, each starting from ``global_params``.

        Previously every client continued from the previous client's parameters
        (sequential training), so updates were not federated updates.

        Returns:
          (metric_list, parameter_list), both ordered by client id: one
          observation entry per client and each client's update
          ``global_params - client_params``.
        """
        task = self.fl_task
        observations = {}
        updates = {}
        for client_id, client_data, client_rng in task.client_sampler.sample():
            client_params = global_params
            for _ in range(self.n_local_rounds):
                _, client_params, _ = self.client_update(client_params, client_data, client_rng, task.grad_fn)
            update = jax.tree_util.tree_map(lambda a, b: a - b, global_params, client_params)
            observations[client_id] = self._client_observation(client_params, update, len(client_data))
            updates[client_id] = update

        ordered_ids = sorted(observations)
        metric_list = [observations[cid] for cid in ordered_ids]
        parameter_list = [updates[cid] for cid in ordered_ids]
        return metric_list, parameter_list

    def aggregate_and_evaluate_parameters(self, old_global, parameter_list, scale_list, server_learning_rate=0.9):
        """Combine client updates with ``scale_list`` weights, apply them and score the result.

        With AGG_STRATEGY == 'scale_raw' the combined update is the weighted sum
        of the client updates; otherwise it is their weighted mean
        (fedjax.tree_util.tree_mean). It is applied to ``old_global`` with one
        SGD step of size ``server_learning_rate``.

        Returns:
          (new global params, reward-metric value of the new params on batched_train_data).
        """
        task = self.fl_task
        if AGG_STRATEGY == 'scale_raw':
            weighted_updates = [fedjax.tree_util.tree_weight(parameter_list[i], scale_list[i])
                                for i in range(len(scale_list))]
            combined_update = fedjax.tree_util.tree_sum(weighted_updates)
        else:
            combined_update = fedjax.tree_util.tree_mean(
                [(parameter_list[i], scale_list[i]) for i in range(len(scale_list))])

        server_optimizer = fedjax.optimizers.sgd(learning_rate=server_learning_rate)
        opt_state = server_optimizer.init(old_global)
        _, new_params = server_optimizer.apply(combined_update, opt_state, old_global)

        reward_key = task.loss_string if self.reward_type == 'loss' else task.accuracy_string
        metric = fedjax.evaluate_model(task.fl_model, new_params, task.batched_train_data)[reward_key]
        return new_params, metric

    def get_action(self, action):
        """Map actions in [-1, 1] to positive weights (x100 unless AGG_STRATEGY is 'scale_raw')."""
        # The 1e-5 offset keeps every weight strictly positive.
        action = (np.array(action) + 1.00001) / 2
        if AGG_STRATEGY != 'scale_raw':
            action = action * 100
        return action

    def step(self, action):
        """Aggregate with ``action``, score the new global model, then run the next client round."""
        self.steps += 1
        weights = self.get_action(action)
        self.global_params, metric = self.aggregate_and_evaluate_parameters(
            self.global_params, self.param_list, weights)
        # Negate loss so that a larger reward is always better.
        reward = -metric if self.reward_type == 'loss' else metric

        obs, self.param_list = self.run_one_fl_round(self.global_params)
        done = self.steps >= self.n_rounds_per_episode
        info = {}
        return obs, reward, done, info

    def reset(self):
        """Re-initialise the global model and run a first client round for the initial observation."""
        self.global_params = self.fl_task.fl_model.init(self.fl_task.rng)
        metric_list, self.param_list = self.run_one_fl_round(self.global_params)
        self.steps = 0
        return metric_list

    def render(self, action, reward, mode='human'):
        """Print the action and reward in 'human' mode; defer to gym otherwise."""
        if mode == 'human':
            print(f"action: {action}, reward = {reward}")
        else:
            super().render(mode=mode)

# Evaluation Flows

In [None]:
def client_update(init_params, client_dataset, client_rng, grad_fn):
  """Train one client with FL_MODEL's client optimizer, starting from ``init_params``.

  Iterates over ``client_dataset.shuffle_repeat_batch(batch_size=10)`` and
  applies one optimizer step per batch, splitting ``client_rng`` each time.

  Returns:
    (number of examples in ``client_dataset``, the client's updated parameters).
  """
  optimizer = FL_MODEL.client_optimizer
  state = optimizer.init(init_params)
  local_params = init_params
  rng = client_rng
  for batch in client_dataset.shuffle_repeat_batch(batch_size=10):
    rng, step_rng = jax.random.split(rng)
    state, local_params = optimizer.apply(grad_fn(local_params, batch, step_rng), state, local_params)
  return len(client_dataset), local_params

In [None]:
def get_action(action, agg_strategy=None):
    """Map raw policy actions in [-1, 1] to strictly positive aggregation weights.

    Args:
      action: array-like of actions in [-1, 1], one per client.
      agg_strategy: 'scale_raw' keeps weights in (0, ~1]; any other value
        multiplies them by 100. Defaults to the notebook global AGG_STRATEGY.

    Returns:
      numpy array of weights, same length as ``action``.
    """
    if agg_strategy is None:
        agg_strategy = AGG_STRATEGY
    # The 1e-5 offset keeps the weight of an action of exactly -1 above zero.
    weights = (np.array(action) + 1.00001) / 2
    if agg_strategy != 'scale_raw':
        weights = weights * 100
    return weights

In [None]:
def _evaluation_client_round(global_params):
  """Train each sampled test client from ``global_params``; return (observations, updates) ordered by client id."""
  observations = {}
  updates = {}
  for client_id, client_data, client_rng in FL_MODEL.client_test_sampler.sample():
    # Every client starts from the same global model (previously clients were chained).
    client_params = global_params
    for _ in range(N_FL_LOCAL_TRAINING_ROUNDS):
      _, client_params = client_update(client_params, client_data, client_rng, FL_MODEL.grad_fn)
    update = jax.tree_util.tree_map(lambda a, b: a - b, global_params, client_params)

    if OBS_TYPE == 'divergence':
      # Distance from this round's global params, as in the training environment.
      metric = fedjax.tree_util.tree_l2_norm(update)
    else:
      metric = fedjax.evaluate_model(FL_MODEL.fl_model, client_params, FL_MODEL.batched_test_data)[FL_MODEL.obs_metric_string]
    observations[client_id] = (metric, len(client_data)) if OBS_INCLUDE_CLIENT_SIZE == 'yes' else metric
    updates[client_id] = update

  ordered_ids = sorted(observations)
  return [observations[cid] for cid in ordered_ids], [updates[cid] for cid in ordered_ids]


def _apply_weighted_updates(old_global, updates, scale_list, server_optimizer):
  """Combine client updates with ``scale_list`` (sum for 'scale_raw', weighted mean otherwise) and apply one server step."""
  if AGG_STRATEGY == 'scale_raw':
    combined = fedjax.tree_util.tree_sum(
        [fedjax.tree_util.tree_weight(updates[i], scale_list[i]) for i in range(len(scale_list))])
  else:
    combined = fedjax.tree_util.tree_mean(
        [(updates[i], scale_list[i]) for i in range(len(scale_list))])
  opt_state = server_optimizer.init(old_global)
  _, new_global = server_optimizer.apply(combined, opt_state, old_global)
  return new_global


def run_fl_evaluation(epochs, init_params, rl_model, server_learning_rate=0.9):
  """Run ``epochs`` FL rounds on test clients, aggregating with weights chosen by ``rl_model``.

  Each round mirrors FLRLEnv: sampled clients train locally from the current
  global params, the policy maps their observations to aggregation weights,
  and the weighted client updates are applied with a server SGD step of size
  ``server_learning_rate`` (0.9, as in the training environment). Previously
  the divergence observation was measured against ``init_params`` rather than
  the round's global params, clients were chained sequentially, and raw client
  params were averaged instead of applying updates, so the policy was evaluated
  under different dynamics than it was trained on.

  NOTE(review): the policy was trained on VecNormalize-scaled observations but
  receives raw metrics here; consider normalizing with saved training stats.
  NOTE(review): clients train on the test split and the model is scored on
  batched_test_data from the same split -- confirm this is intended.

  Args:
    epochs: number of FL rounds.
    init_params: initial global model parameters.
    rl_model: object with an SB3-style ``predict(obs) -> (action, state)``.
    server_learning_rate: server SGD step size applied to the combined update.

  Returns:
    (accuracy_log_list, loss_log_list): test accuracy and loss after each round.
  """
  global_params = init_params
  accuracy_log_list = []
  loss_log_list = []
  server_optimizer = fedjax.optimizers.sgd(learning_rate=server_learning_rate)

  for _ in range(epochs):
    observations, updates = _evaluation_client_round(global_params)
    scale_list, _ = rl_model.predict(observations)
    scale_list = get_action(scale_list)
    global_params = _apply_weighted_updates(global_params, updates, scale_list, server_optimizer)

    eval_metrics = fedjax.evaluate_model(FL_MODEL.fl_model, global_params, FL_MODEL.batched_test_data)
    accuracy_log_list.append(eval_metrics[FL_MODEL.accuracy_string])
    loss_log_list.append(eval_metrics[FL_MODEL.loss_string])

  return accuracy_log_list, loss_log_list

In [None]:
def run_fedvag(epochs, init_params):
  """FedAvg baseline on the test clients.

  In each of ``epochs`` rounds, every sampled client trains locally
  (N_FL_LOCAL_TRAINING_ROUNDS passes) starting from the current global params,
  and the new global params are the example-count-weighted mean of the client
  params. Previously each client continued from the previous client's params,
  which is sequential training rather than FedAvg.

  Returns:
    (accuracy_log_list, loss_log_list): test accuracy and loss after each round.
  """
  global_params = init_params
  accuracy_log_list = []
  loss_log_list = []

  for _round in range(epochs):
    client_results = []
    for client_id, client_data, client_rng in FL_MODEL.client_test_sampler.sample():
      client_params = global_params
      for _local_round in range(N_FL_LOCAL_TRAINING_ROUNDS):
        _, client_params = client_update(client_params, client_data, client_rng, FL_MODEL.grad_fn)
      # Weight each client by its number of examples.
      client_results.append((client_params, len(client_data)))

    global_params = fedjax.tree_util.tree_mean(client_results)

    eval_metrics = fedjax.evaluate_model(FL_MODEL.fl_model, global_params, FL_MODEL.batched_test_data)
    accuracy_log_list.append(eval_metrics[FL_MODEL.accuracy_string])
    loss_log_list.append(eval_metrics[FL_MODEL.loss_string])

  return accuracy_log_list, loss_log_list

# Helpers: Optuna, Callback, Initialise

In [None]:
def create_rl_model_and_environment(train_test, params):
  """Build a normalized single-env vec environment and, for training, a fresh RL model.

  Args:
    train_test: 'training' builds the env plus an RL model; any other value
      builds only an evaluation env (observations normalized, rewards not)
      and returns None as the model.
    params: optional dict of RL-algorithm constructor kwargs (e.g. from
      Optuna); None uses the SB3 defaults.

  Returns:
    (rl_model or None, VecNormalize environment).
  """
  is_training = train_test == 'training'

  def make_env():
    # Argument order must match FLRLEnv.__init__ (OBS_TYPE before REWARD_TYPE);
    # the evaluation branch previously passed them swapped.
    env = FLRLEnv(FL_MODEL, N_MODEL_FL_EVAL_ROUNDS, N_FL_LOCAL_TRAINING_ROUNDS,
                  OBS_TYPE, REWARD_TYPE, OBS_INCLUDE_CLIENT_SIZE)
    # Only training envs log: Monitor overwrites an existing monitor CSV by
    # default, so an evaluation env pointed at RL_LOG_DIR would wipe the training log.
    return Monitor(env, RL_LOG_DIR if is_training else None)

  environment = DummyVecEnv([make_env])

  if not is_training:
    return None, VecNormalize(environment, norm_reward=False)

  environment = VecNormalize(environment)
  model_kwargs = dict(params) if params is not None else {}
  if RL_MODEL.name in ('TD3', 'DDPG'):
    # Deterministic-policy algorithms explore through explicit action noise.
    model_kwargs['action_noise'] = NormalActionNoise(
        mean=np.zeros(N_FL_CLIENTS), sigma=0.1 * np.ones(N_FL_CLIENTS))
  rl_model = RL_MODEL.model("MlpPolicy", environment, verbose=1, **model_kwargs)
  return rl_model, environment

In [None]:
class Objective:
  """Optuna objective that trains one RL model per trial and remembers the best one.

  ``__call__`` returns the trial's mean evaluation episode reward (to be
  maximized). ``callback`` is meant for ``study.optimize(callbacks=...)``:
  when the just-finished trial is the study's best, it copies the latest
  model/env/reward into the ``best_*`` attributes.
  """

  def __init__(self):
    # Artifacts of the best trial seen so far.
    self.best_model = self.best_env = self.best_reward = None
    # Artifacts of the most recently evaluated trial.
    self.model = self.env = self.reward = None

  def __call__(self, trial):
    trial_params = RL_MODEL.optuna_params(trial)
    trained_model, trial_env = create_rl_model_and_environment('training', trial_params)
    trained_model.learn(total_timesteps=N_RL_TRAINING_ROUNDS)
    self.model, self.env = trained_model, trial_env
    self.reward, _ = evaluate_policy(trained_model, trial_env, n_eval_episodes=N_FL_MODEL_EVAL_EPISODES)
    return self.reward

  def callback(self, study, trial):
    is_best = study.best_trial == trial
    if not is_best:
      return
    print("Saving new best model and env")
    self.best_model, self.best_env, self.best_reward = self.model, self.env, self.reward

In [None]:
class SaveBestModel(BaseCallback):
    """Checkpoint the model whenever the recent mean episode reward improves.

    Every ``check_freq`` callback calls (N_MODEL_FL_EVAL_ROUNDS), the Monitor
    results in RL_LOG_DIR are read. If the mean reward of the last
    N_MODEL_FL_EVAL_ROUNDS logged episodes beats the best seen so far, the
    model is saved to RL_MODELS_DIR/best_model.zip. Its VecNormalize statistics
    (if any) go to RL_MODELS_DIR/vec_normalize.pkl, because the policy expects
    observations scaled the way they were during training.
    """

    def __init__(self):
        super().__init__()
        self.check_freq = N_MODEL_FL_EVAL_ROUNDS
        self.log_dir = RL_LOG_DIR
        self.save_path = os.path.join(RL_MODELS_DIR, 'best_model')
        self.vec_normalize_path = os.path.join(RL_MODELS_DIR, 'vec_normalize.pkl')
        self.best_mean_reward = -np.inf

    def _on_step(self) -> bool:
        """Check the Monitor log every ``check_freq`` calls; always returns True (never stops training)."""
        if self.n_calls % self.check_freq != 0:
            return True

        x, y = ts2xy(load_results(self.log_dir), 'timesteps')
        if len(x) == 0:
            # No completed episode logged yet.
            return True

        mean_reward = np.mean(y[-N_MODEL_FL_EVAL_ROUNDS:])
        print(f"Best mean episode reward: {self.best_mean_reward:.2f} - Latest mean episode reward: {mean_reward:.2f}")

        if mean_reward > self.best_mean_reward:
            self.best_mean_reward = mean_reward
            print(f"Saving new best model  {mean_reward} to {self.save_path}.zip")
            self.model.save(self.save_path)
            vec_normalize = self.model.get_vec_normalize_env()
            if vec_normalize is not None:
                vec_normalize.save(self.vec_normalize_path)
        return True

In [None]:
def fl_final_evaluation_plot(title, model_list, fed_avg_list):
  """Plot per-round RL-weighted FL vs FedAvg values of one metric and display the figure.

  Args:
    title: metric name, used as the figure title and y-axis label.
    model_list: per-round values for the RL-weighted aggregation.
    fed_avg_list: per-round values for the FedAvg baseline.

  Returns:
    The ``matplotlib.pyplot`` module (kept for callers that call ``.show()`` on it).
  """
  # A fresh figure per call; plt.figure(num=title) would reuse an open figure
  # with the same title and stack new lines onto it.
  fig, ax = plt.subplots(figsize=(16, 10), dpi=200)
  # Derive the x-axis from the data rather than N_MODEL_FL_EVAL_ROUNDS so a
  # length mismatch cannot break the plot.
  ax.plot(range(1, len(model_list) + 1), model_list, 'g', label='RL Model')
  ax.plot(range(1, len(fed_avg_list) + 1), fed_avg_list, 'b', label='FedAvg')
  ax.set(title=title, xlabel='FL Round', ylabel=title)
  ax.legend()
  plt.show()

  return plt

In [None]:
def train_main():
  """Tune, select and train the RL aggregation policy.

  1. Run N_OPTUNA_TRIALS Optuna trials (random sampler, maximizing mean reward).
  2. Train and evaluate a default-hyperparameter model for comparison.
  3. If Optuna's best trial scored higher, build a fresh model with its params.
  4. Train the chosen model for N_BEST_RL_TRAINING_ROUNDS timesteps with
     SaveBestModel checkpointing into RL_MODELS_DIR.
  """
  # Optimise RL algorithm hyperparameters.
  print("---------- \nRL Algo Train: Tuning RL Model with Optuna")
  objective = Objective()
  study = optuna.create_study(sampler=RandomSampler(), pruner=MedianPruner(), direction="maximize")
  study.optimize(objective, n_trials=N_OPTUNA_TRIALS, n_jobs=1, callbacks=[objective.callback])
  best_reward = objective.best_reward

  # Baseline with SB3 default hyperparameters.
  print("---------- \nRL Algo Train: Checking Default Param option")
  rl_model, environment = create_rl_model_and_environment('training', None)
  rl_model.learn(total_timesteps=N_RL_TRAINING_ROUNDS)
  mean_reward, _ = evaluate_policy(rl_model, environment, n_eval_episodes=N_FL_MODEL_EVAL_EPISODES)

  # best_reward stays None when no Optuna trial completed (e.g. N_OPTUNA_TRIALS == 0);
  # comparing a float with None would raise TypeError.
  if best_reward is not None and mean_reward < best_reward:
    # A fresh, untrained model is built from the best params; the trained
    # trial model kept in objective.best_model is not reused.
    rl_model, environment = create_rl_model_and_environment('training', study.best_params)
    print(f"Using the Optuna Model with params {study.best_params}")
  else:
    print("Using the Default Params Model")

  # Extended training of the selected model, checkpointing improvements.
  print("---------- \nRL Algo Train: Extended Training Best RL Model")
  rl_model.learn(total_timesteps=N_BEST_RL_TRAINING_ROUNDS, callback=SaveBestModel())

In [None]:
def evaluate_main():
  """Compare the saved best RL policy against FedAvg on the FL test clients.

  Loads RL_MODELS_DIR/best_model.zip, runs N_FL_MODEL_EVAL_EPISODES independent
  evaluations of N_MODEL_FL_EVAL_ROUNDS rounds for each method, and averages
  the per-round curves across runs.

  Returns:
    (fl_accuracy_list, fedavg_accuracy_list, fl_loss_list, fedavg_loss_list, outcome_string)
  """
  # An environment is needed so SB3 can check the saved model's spaces.
  _, test_environment = create_rl_model_and_environment('test', None)
  best_rl_model = RL_MODEL.model.load(f"{RL_MODELS_DIR}/best_model.zip", env=test_environment)

  print("---------- \nRL Algo Evaluation: Evaluating Best RL Model in FL Setting")
  rl_runs = [run_fl_evaluation(N_MODEL_FL_EVAL_ROUNDS, FL_MODEL.init_params, best_rl_model)
             for _ in range(N_FL_MODEL_EVAL_EPISODES)]
  # Shape (runs, 2, rounds) -> mean over runs -> (accuracy curve, loss curve).
  fl_accuracy_list, fl_loss_list = np.mean(rl_runs, 0)

  print("---------- \nRL Algo Evaluation: Evaluating FEDAVG for FL Setting")
  fedavg_runs = [run_fedvag(N_MODEL_FL_EVAL_ROUNDS, FL_MODEL.init_params)
                 for _ in range(N_FL_MODEL_EVAL_EPISODES)]
  fedavg_accuracy_list, fedavg_loss_list = np.mean(fedavg_runs, 0)

  outcome_string = f"Best RL Model in FL evaluation is: {fl_accuracy_list[-1]}, FedAVG acuracy is: {fedavg_accuracy_list[-1]}"
  return fl_accuracy_list, fedavg_accuracy_list, fl_loss_list, fedavg_loss_list, outcome_string

# Run: Training

In [None]:
# Train: Optuna search, default-params comparison, then extended training with checkpointing.
train_main()

In [None]:
# View Training Performance: episode rewards from the Monitor logs in RL_LOG_DIR,
# x-axis in timesteps, limited to N_BEST_RL_TRAINING_ROUNDS timesteps.
# NOTE(review): each training environment's Monitor rewrites the CSV in RL_LOG_DIR,
# so this presumably shows only the most recently created training env -- verify.
results_plotter.plot_results([RL_LOG_DIR], N_BEST_RL_TRAINING_ROUNDS, results_plotter.X_TIMESTEPS, "Results")
plt.show()

# Run: Evaluation

In [None]:
# Evaluate the saved best RL policy against the FedAvg baseline on the test clients.
fl_accuracy_list, fedavg_accuracy_list, fl_loss_list, fedavg_loss_list, outcome_string = evaluate_main()

In [None]:
# Show Results: fl_final_evaluation_plot already calls plt.show(), so the returned
# handle is not shown a second time.
acc_plot = fl_final_evaluation_plot('Accuracy', fl_accuracy_list, fedavg_accuracy_list)
loss_plot = fl_final_evaluation_plot('Loss', fl_loss_list, fedavg_loss_list)

print(outcome_string)
print(RL_MODELS_DIR)