# Multi Train Gradient Update

## Importing Libraries

In [1]:
from typing import Dict
import threading

import gym
import numpy as np
import torch

from stable_baselines3 import PPO as ALGO
from stable_baselines3.common.evaluation import evaluate_policy

  for external in metadata.entry_points().get(self.group, []):


In [2]:
# Hyper-Parameters
# Number of client models that are created, trained in separate threads and averaged.
NUM_CLIENT_MODELS = 4
# Value passed as `total_timesteps` to each client's `learn()` call in one training round.
NUM_TRAINING_STEPS = 10
# Number of train -> average -> redistribute rounds run by the iterated training loop.
NUM_ITERATIONS = 10
# Gym environment id used for the global model, every client model and the evaluations.
ENV_NAME = 'CartPole-v1'


## Init. ENV and Model

In [3]:
# Shared environment: backs the global model and is reused whenever it is evaluated.
env = gym.make(ENV_NAME)
global_model = ALGO("MlpPolicy", env)

# Every client gets its own freshly created environment and its own model instance.
client_models = [
    ALGO("MlpPolicy", gym.make(ENV_NAME))
    for client_idx in range(NUM_CLIENT_MODELS)
]

## Functions to Evaluate Model and Train Model within Thread

In [4]:
def evaluate(model, env, message = '', verbose = 0, iterations = 20):
    """Estimate a model's mean reward over repeated policy evaluations.

    Calls ``evaluate_policy(model, env)`` ``iterations`` times, averages the
    first value returned by each call, prints a one-line summary and returns
    the average so callers can log or compare it.

    :param model: model handed to ``evaluate_policy``
    :param env: environment handed to ``evaluate_policy``
    :param message: (str) label inserted into the printed summary line
    :param verbose: (int) when 1, also print the index and reward of every run
    :param iterations: (int) number of ``evaluate_policy`` calls to average
    :return: (float) mean of the per-run rewards
    :raises ValueError: if ``iterations`` is smaller than 1
    """
    if iterations < 1:
        # An empty sample has no meaningful mean; fail loudly instead of printing nan.
        raise ValueError(f'iterations must be >= 1, got {iterations}')

    fitnesses = []
    for i in range(iterations):
        # evaluate_policy returns a pair; only its first element is used here.
        fitness, _ = evaluate_policy(model, env)
        if verbose == 1:
            print(i, fitness, end=" ")
        fitnesses.append(fitness)

    if verbose == 1:
        # Terminate the run-by-run line so the summary starts on a fresh line.
        print()

    # The mean does not depend on the order of the samples, so no sorting is needed.
    mean_fitness = np.mean(fitnesses)
    print(f'Type {message} Mean reward: {mean_fitness}')
    return float(mean_fitness)

In [5]:
def train(model, timesteps):
    """Run one training session on ``model``.

    Thin wrapper around ``model.learn`` so the call can be handed to
    ``threading.Thread`` as a target.

    :param model: object exposing ``learn(total_timesteps=...)``
    :param timesteps: value forwarded as ``total_timesteps``
    """
    learn_kwargs = {'total_timesteps': timesteps}
    model.learn(**learn_kwargs)

In [6]:
def multithread_eval(client_models):
    """Evaluate every client model concurrently, one thread per model.

    Each thread runs ``evaluate`` on its model with a freshly created
    ``ENV_NAME`` environment and the label ``'Trained Model <index>'``. The
    function returns only after all threads have finished.

    :param client_models: (sequence) models to evaluate; may be empty
    """
    # One private environment per client so threads never share an env instance.
    client_envs = [gym.make(ENV_NAME) for _ in client_models]

    try:
        # Create Threads -- sized by the argument itself rather than by a module
        # constant, so a list of any length is evaluated completely.
        client_threads = [
            threading.Thread(target=evaluate, args=(model, client_env, f'Trained Model {ci}'))
            for ci, (model, client_env) in enumerate(zip(client_models, client_envs))
        ]

        # Start Threads
        for thread in client_threads:
            thread.start()

        # Join Threads (wait until thread is completely executed)
        for thread in client_threads:
            thread.join()
    finally:
        # Release the environments created above once evaluation is over.
        for client_env in client_envs:
            client_env.close()

## Initial Evaluation

In [7]:
# Start every client from the same weights as the global model.
initial_parameters = global_model.get_parameters()
for client_model in client_models:
    client_model.set_parameters(initial_parameters)

# Keep a copy of the not-yet-trained global model on disk under the name 'initial'.
global_model.save('initial')

# Baseline scores before any training: the global model first ...
evaluate(global_model, env)

# ... then all clients, evaluated in parallel threads.
multithread_eval(client_models)



Type  Mean reward: 9.100000000000003
Type Trained Model 1 Mean reward: 9.040000000000001
Type Trained Model 2 Mean reward: 9.08
Type Trained Model 3 Mean reward: 9.205000000000002
Type Trained Model 0 Mean reward: 9.105


## Train Each Client for `NUM_TRAINING_STEPS` Steps and Evaluate

In [8]:
# Build one training thread per client; each thread runs `train` on its own model.
client_threads = [
    threading.Thread(target=train, args=(client_models[client_idx], NUM_TRAINING_STEPS))
    for client_idx in range(NUM_CLIENT_MODELS)
]

# Launch every client, then block until all of them have finished training.
for thread in client_threads:
    thread.start()
for thread in client_threads:
    thread.join()

# The global model is not trained in this cell; it is evaluated for comparison.
evaluate(global_model, env)

# Scores of the freshly trained clients.
multithread_eval(client_models)

Type  Mean reward: 9.009999999999996
Type Trained Model 3 Mean reward: 239.28000000000006
Type Trained Model 2 Mean reward: 257.29999999999995
Type Trained Model 0 Mean reward: 471.57
Type Trained Model 1 Mean reward: 500.0


## Average Client Weights into the Global Model and Evaluate

In [9]:
# Weights currently held by the global policy; every entry is overwritten below.
global_dict = global_model.policy.state_dict()

# Fetch each client's weights once up front rather than once per entry.
client_state_dicts = [client.policy.state_dict() for client in client_models]

# Element-wise mean across clients for every entry of the state dict.
for key in global_dict:
    stacked = torch.stack([state[key].float() for state in client_state_dicts], 0)
    global_dict[key] = stacked.mean(0)

# Install the averaged weights in the global model.
global_model.policy.load_state_dict(global_dict)

# Push the averaged weights back out so all clients continue from the same point.
synced_state = global_model.policy.state_dict()
for client in client_models:
    client.policy.load_state_dict(synced_state)

evaluate(global_model, env)

Type  Mean reward: 366.015


In [10]:
# Evaluation Before Iterated Training
evaluate(global_model, env, "Global Initial Model")

for i in range(NUM_ITERATIONS):
    print('Train Iter: ', i)

    # Create Threads
    client_threads = [] 
    for ci in range(NUM_CLIENT_MODELS):
        thread = threading.Thread(target=train, args=(client_models[ci], NUM_TRAINING_STEPS))
        client_threads.append(thread)


    # Start Threads
    for thread in client_threads:
        thread.start()

    # Join Threads (wait until thread is completely executed)
    for thread in client_threads:
        thread.join()

    # Evaluation after Training
    multithread_eval(client_models)

    # Accumulate Client Parameters / Weights
    global_dict = global_model.policy.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack([client_models[i].policy.state_dict()[k].float() for i in range(len(client_models))], 0).mean(0)

    # Load New Parameters to Global Model
    global_model.policy.load_state_dict(global_dict)

    # Load New Parameters to clients
    for model in client_models:
        model.policy.load_state_dict(global_model.policy.state_dict())

    # Evaluate Updated Global Model
    evaluate(model, env, 'Global Updated Model', verbose=0)

Type Global Initial Model Mean reward: 375.01499999999993
Train Iter:  0
Type Trained Model 0 Mean reward: 241.565
Type Trained Model 2 Mean reward: 362.35
Type Trained Model 3 Mean reward: 378.09
Type Trained Model 1 Mean reward: 417.055
Type Global Updated Model Mean reward: 370.045
Train Iter:  1
Type Trained Model 1 Mean reward: 349.29
Type Trained Model 3 Mean reward: 390.34000000000003
Type Trained Model 0 Mean reward: 400.995
Type Trained Model 2 Mean reward: 420.48999999999995
Type Global Updated Model Mean reward: 397.07000000000005
Train Iter:  2
Type Trained Model 2 Mean reward: 342.17999999999995
Type Trained Model 1 Mean reward: 395.54499999999996
Type Trained Model 0 Mean reward: 403.505
Type Trained Model 3 Mean reward: 417.78999999999996
Type Global Updated Model Mean reward: 395.4
Train Iter:  3
Type Trained Model 3 Mean reward: 345.83500000000004
Type Trained Model 2 Mean reward: 403.56
Type Trained Model 1 Mean reward: 423.8
Type Trained Model 0 Mean reward: 427.95
T

In [None]:
# Inspect the global model's parameters; as the cell's last expression the result is displayed.
global_model.get_parameters()

In [None]:
# Persist the global model under the name 'a2c_lunar_multiproc'.
# NOTE(review): the name mentions A2C / Lunar, while this notebook builds ALGO = PPO
# on ENV_NAME = 'CartPole-v1' -- consider renaming, together with the JSON export
# and the later `ALGO.load` call that use the same name.
global_model.save('a2c_lunar_multiproc')

In [None]:
# Exporting Params as JSON
## Function to Convert Params Dict to Flattened List
def flatten_list(params):
    """
    Convert every value of a parameter dict via its ``tolist()`` method.

    :param params: (dict) maps names to objects exposing ``tolist()``
    :return: (dict) new dict with the same keys and ``value.tolist()`` as values
    """
    return {name: value.tolist() for name, value in params.items()}
## Write Parameters to JSON File
import json

# Take the full parameter dict and replace its 'policy' entry with the
# list-based copy produced by flatten_list before writing it out.
all_params = global_model.get_parameters()
all_params['policy'] = pol_params = flatten_list(all_params['policy'])

# Tab-indented JSON dump of the converted parameters.
with open('a2c_lunar_multiproc.json', mode='w') as json_file:
    json.dump(all_params, json_file, indent='\t')

In [None]:
# Disabled alternative, kept for reference: rebuild a model from the exported
# JSON file instead of from the saved model file. The JSON is opened in read
# mode ('r') and tensors are created with `torch`, the name this notebook imports.
# model_loaded = ALGO(
#     "MlpPolicy",
#     env
# )

# evaluate(model_loaded,env, verbose=1)

# import json
# with open('a2c_lunar_multiproc.json', 'r') as f:
#     new_params = json.load(f)

# loaded_pol_params = new_params['policy']
# for key in loaded_pol_params.keys():
#     loaded_pol_params[key] = torch.tensor(loaded_pol_params[key])

# new_params['policy'] = loaded_pol_params

# model_loaded.set_parameters(new_params)

# Active path: load the model saved earlier as 'a2c_lunar_multiproc' and pass `env` to it.
model_loaded = ALGO.load('a2c_lunar_multiproc', env)

In [None]:
# Reset the shared environment, then evaluate the reloaded model;
# verbose=1 makes `evaluate` print the index and reward of every run.
env.reset()
evaluate(model_loaded,env, verbose=1)