# Multi Train Gradient Update

## Importing Libraries

In [None]:
from typing import Dict
import threading

import gym
import numpy as np
import torch as th

from stable_baselines3 import A2C as ALGO
from stable_baselines3.common.evaluation import evaluate_policy

## Initialize Environment and Models

In [None]:
ENV_ID = 'LunarLander-v2'

# Environment used for every evaluation; also the env the base `model` is built on.
env = gym.make(ENV_ID)
model = ALGO(
    "MlpPolicy",
    env
)

# Each worker gets its OWN environment instance. The two workers are trained
# concurrently in separate threads further down; if they shared `env`, both
# `learn()` calls would reset and step the same simulation object at the same
# time and interleave each other's transitions in their rollouts.
env_trained_1 = gym.make(ENV_ID)
model_trained_1 = ALGO(
    "MlpPolicy",
    env_trained_1
)

env_trained_2 = gym.make(ENV_ID)
model_trained_2 = ALGO(
    "MlpPolicy",
    env_trained_2
)

## Functions to Evaluate a Model and to Train a Model in a Thread

In [None]:
def evaluate(model, env, message='', verbose=0, iterations=10):
    """Score ``model`` on ``env`` by averaging repeated ``evaluate_policy`` runs.

    ``evaluate_policy`` is called ``iterations`` times. The first element of
    each returned pair (the mean episode reward) is collected. The mean of
    those values is printed and returned.

    :param model: agent passed straight through to ``evaluate_policy``.
    :param env: environment to evaluate on.
    :param message: label inserted into the printed summary line.
    :param verbose: when >= 1, also print each run's index and reward.
    :param iterations: number of ``evaluate_policy`` calls to average
        (default 10, matching the previous hard-coded value).
    :return: mean of the per-run rewards.
    :raises ValueError: if ``iterations`` is less than 1 (averaging an empty
        list would silently produce NaN).
    """
    if iterations < 1:
        raise ValueError(f'iterations must be >= 1, got {iterations}')

    fitnesses = []
    for i in range(iterations):
        # evaluate_policy returns a pair; only the first element (mean reward) is kept.
        fitness, _ = evaluate_policy(model, env)
        if verbose >= 1:
            print(i, fitness, end=" ")
        fitnesses.append(fitness)

    if verbose >= 1:
        # Terminate the space-separated per-run line so the summary starts fresh.
        print()

    mean_fitness = np.mean(fitnesses)
    print(f'Type {message} Mean reward: {mean_fitness}')
    return mean_fitness

In [None]:
def train(model, timesteps):
    """Run ``model.learn`` for ``timesteps`` total timesteps, in place.

    A line is printed before and after training so progress stays visible
    when several calls run side by side as ``threading.Thread`` targets.
    Returns nothing.
    """
    banners = ('Starting Training', 'Completed Training')
    print(banners[0])
    model.learn(total_timesteps=timesteps)
    print(banners[1])

## Initial Evaluation

In [None]:
# Baseline: score all three freshly constructed models before any training.
# Each call prints its own mean reward; only the last call's return value is
# shown as the notebook cell output.
evaluate(model, env)
evaluate(model_trained_1, env)
evaluate(model_trained_2, env)

## Train for 1K Steps and Evaluate

In [None]:
# Train both worker models for 1,000 timesteps each, one thread per worker.
t1, t2 = (
    threading.Thread(target=train, args=(worker, 1_000))
    for worker in (model_trained_1, model_trained_2)
)

# Launch both threads first, then block until each has finished.
for thread in (t1, t2):
    thread.start()
for thread in (t1, t2):
    thread.join()

# Compare both workers against the untouched base model.
evaluate(model_trained_1, env)
evaluate(model_trained_2, env)
evaluate(model, env)

## Add the (Scaled) Trained Parameters to the Base Model and Evaluate

In [None]:
# For each trained worker in turn, add its parameters (scaled by `alpha`) to
# the base model's parameters in place:
#     base_param += alpha * worker_param      (for every policy parameter)
#
# NOTE(review): despite the "Apply Gradient" heading, this is not a gradient
# step. `optimizer.param_groups[0]['params']` is the list of the worker's
# parameter tensors themselves (their gradients would live in `.grad`), and
# 'alpha' is a hyper-parameter of the base model's optimizer (for RMSprop it is
# the smoothing constant, not the learning rate 'lr'). Each merge therefore
# grows the base weights by ~alpha x the worker's weights. The arithmetic is
# left exactly as it was; confirm the intended update rule before relying on it.
optim_alpha = model.policy.optimizer.param_groups[0]['alpha']
base_params = list(model.policy.parameters())

# Pure weight arithmetic on live parameters: keep autograd out of it.
with th.no_grad():
    for worker in (model_trained_1, model_trained_2):
        worker_params = worker.policy.optimizer.param_groups[0]['params']
        # Parameters are paired by position, so both lists must line up exactly.
        if len(worker_params) != len(base_params):
            raise ValueError(
                f'parameter count mismatch: base model has {len(base_params)}, '
                f'worker optimizer has {len(worker_params)}'
            )
        for base_param, worker_param in zip(base_params, worker_params):
            # Keyword `alpha=` replaces the deprecated add_(Number, Tensor) overload.
            base_param.add_(worker_param, alpha=optim_alpha)

evaluate(model, env)

In [None]:
for i in range(10):
    print('Train Iter: ', i)

    # Train both workers for 10,000 timesteps each, one thread per worker,
    # and wait for both to finish before evaluating.
    t1, t2 = (
        threading.Thread(target=train, args=(worker, 10_000))
        for worker in (model_trained_1, model_trained_2)
    )
    for thread in (t1, t2):
        thread.start()
    for thread in (t1, t2):
        thread.join()

    f1 = evaluate(model_trained_1, env, 'Trained Model 1', verbose=1)
    f2 = evaluate(model_trained_2, env, 'Trained Model 2', verbose=1)
    # The base model's score is only printed for comparison; it does not
    # influence which worker gets merged.
    evaluate(model, env, 'Initial Model', verbose=1)

    # Pick the strictly better worker. On an exact tie (or when a score is
    # NaN) nothing is merged and the workers are not re-synchronised this round.
    if f1 > f2:
        best = model_trained_1
    elif f2 > f1:
        best = model_trained_2
    else:
        best = None

    if best is not None:
        # base_param += alpha * best_param, for every policy parameter.
        # NOTE(review): this adds the worker's *weights* (optimizer
        # param_groups[0]['params'] are the parameters, not gradients), scaled
        # by the base optimizer's 'alpha' (for RMSprop: the smoothing constant,
        # not the learning rate). Arithmetic kept as-is; confirm the intent.
        optim_alpha = model.policy.optimizer.param_groups[0]['alpha']
        base_params = list(model.policy.parameters())
        best_params = best.policy.optimizer.param_groups[0]['params']
        # Parameters are paired by position, so both lists must line up exactly.
        if len(best_params) != len(base_params):
            raise ValueError(
                f'parameter count mismatch: base model has {len(base_params)}, '
                f'worker optimizer has {len(best_params)}'
            )
        with th.no_grad():
            for base_param, best_param in zip(base_params, best_params):
                base_param.add_(best_param, alpha=optim_alpha)

        # Broadcast the merged weights so both workers start the next round
        # from the base model's parameters.
        merged_state = model.policy.state_dict()
        model_trained_1.policy.load_state_dict(merged_state)
        model_trained_2.policy.load_state_dict(merged_state)

    evaluate(model, env, 'Updated Model', verbose=1)

In [None]:
# Inspect the base model's parameters: as the cell's last expression, the
# returned value is what the notebook displays.
model.get_parameters()

In [None]:
# Persist the base model under the name 'a2c_lunar_multiproc' (the path is passed
# to `model.save` unchanged; any extension handling happens inside the library).
model.save('a2c_lunar_multiproc')

In [None]:
# Exporting Params as JSON
## Function to Convert Each Value of a Params Dict to a (Nested) List
def flatten_list(params):
    """Return a copy of ``params`` with every value converted via ``.tolist()``.

    Despite the name, nothing is flattened: each tensor/array keeps its shape,
    just expressed as nested Python lists, so the result is JSON-serializable.

    :param params: (dict) mapping whose values provide ``.tolist()``
        (e.g. ``torch.Tensor`` or ``numpy.ndarray``).
    :return: (dict) new dict with the same keys and list values; ``params``
        itself is left unchanged.
    """
    return {key: value.tolist() for key, value in params.items()}
## Write Parameters to JSON File
import json


def _to_json_compatible(obj):
    """``json.dump`` fallback: serialize tensors/arrays as nested lists."""
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')


all_params = model.get_parameters()
pol_params = flatten_list(all_params['policy'])

# Keep the policy weights as lists in `all_params`; the loading cell below
# converts them back to tensors.
all_params['policy'] = pol_params

# `get_parameters()` returns more than the policy weights. For SB3 on-policy
# algorithms it also holds the optimizer state under 'policy.optimizer'
# (confirm for the installed version), and that state contains tensors that
# json.dump cannot serialize on its own. The `default=` hook converts any such
# tensor/array to a list; `all_params` itself is not modified by this.
with open('a2c_lunar_multiproc.json', 'w', encoding='utf-8') as f:
    json.dump(all_params, f, indent='\t', default=_to_json_compatible)

In [None]:
model_loaded = ALGO(
    "MlpPolicy",
    env
)

# Baseline score of the freshly constructed model, before loading parameters.
evaluate(model_loaded, env, verbose=1)

# Rebuild tensors from the list-valued policy entries in a NEW dict. The old
# `new_params = all_params` aliased the export dict, so converting in place
# also turned `all_params['policy']` into tensors as a side effect.
new_params = dict(all_params)
new_params['policy'] = {
    key: th.tensor(value) for key, value in all_params['policy'].items()
}

model_loaded.set_parameters(new_params)

In [None]:
# Re-score `model_loaded` now that the exported parameters have been loaded
# into it, for comparison with the pre-load baseline above.
# NOTE(review): the explicit reset is likely redundant if evaluate_policy resets
# the env itself — confirm before removing.
env.reset()
evaluate(model_loaded,env, verbose=1)