# Modifying and Loading Parameters of Policies (Gradient Update)

## Importing Libraries

In [1]:
from typing import Dict

import gym
import numpy as np
import torch as th

from stable_baselines3 import A2C as ALGO
from stable_baselines3.common.evaluation import evaluate_policy

## Init. ENV and Model

In [2]:
env = gym.make('CartPole-v1')
model = ALGO(
    "MlpPolicy",
    env
)

model_trained = ALGO(
    "MlpPolicy",
    env
)

## Function to Evaluate Model 

In [3]:
def evaluate(model, env, message = ''):
    fitnesses = []
    iterations = 20
    for i in range(iterations):
        fitness, _ = evaluate_policy(model, env)
        print(i, fitness, end="\r")
        fitnesses.append(fitness)

    mean_fitness = np.mean(sorted(fitnesses))
    print(f'Type {message} Mean reward: {mean_fitness}')

## Initial Evaluation

In [4]:
evaluate(model, env)
evaluate(model_trained, env)



Type  Mean reward: 57.325
Type  Mean reward: 83.95000000000002


## Train for 1K Steps and Evaluate

In [5]:
# Train MT Model
model_trained.learn(total_timesteps=10_00)
evaluate(model_trained, env)
evaluate(model, env)

Type  Mean reward: 134.14000000000001
Type  Mean reward: 53.510000000000005


## Apply Gradient and Evaluate

In [6]:
state_dict = model.policy.state_dict()
optim_dict = model_trained.policy.optimizer.param_groups[0]['params']
optim_alpha = model.policy.optimizer.param_groups[0]['alpha']

# print(state_dict['q_net.q_net.0.weight'])

optim_index = 0
for key, value in state_dict.items():
    # print(key)
    state_dict[key].add_(optim_alpha, optim_dict[optim_index])
    optim_index += 1

model.policy.load_state_dict(state_dict)

evaluate(model, env)

	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  ../torch/csrc/utils/python_arg_parser.cpp:1025.)
  # Remove the CWD from sys.path while we load stuff.


Type  Mean reward: 80.41


In [7]:
for i in range(10):
    print('Train Iter: ', i)

    model_trained.learn(total_timesteps=10_00)

    evaluate(model_trained, env, 'Trained Model')
    evaluate(model, env, 'Initial Model')

    state_dict = model.policy.state_dict()
    optim_dict = model_trained.policy.optimizer.param_groups[0]['params']
    optim_alpha = model.policy.optimizer.param_groups[0]['alpha']

    # print(state_dict['q_net.q_net.0.weight'])

    optim_index = 0
    for key, value in state_dict.items():
        # print(key)
        state_dict[key].add_(optim_dict[optim_index], alpha=optim_alpha)
        optim_index += 1

    model.policy.load_state_dict(state_dict)

    evaluate(model, env, 'Updated Model')

Train Iter:  0
Type Trained Model Mean reward: 38.964999999999996
Type Initial Model Mean reward: 86.07000000000001
Type Updated Model Mean reward: 59.519999999999996
Train Iter:  1
Type Trained Model Mean reward: 67.56
Type Initial Model Mean reward: 61.779999999999994
Type Updated Model Mean reward: 68.46000000000001
Train Iter:  2
Type Trained Model Mean reward: 47.769999999999996
Type Initial Model Mean reward: 68.525
Type Updated Model Mean reward: 63.165
Train Iter:  3
Type Trained Model Mean reward: 186.17999999999998
Type Initial Model Mean reward: 66.215
Type Updated Model Mean reward: 90.53999999999999
Train Iter:  4
Type Trained Model Mean reward: 500.0
Type Initial Model Mean reward: 88.22999999999999
Type Updated Model Mean reward: 151.54
Train Iter:  5
Type Trained Model Mean reward: 159.76
Type Initial Model Mean reward: 151.89000000000001
Type Updated Model Mean reward: 284.32000000000005
Train Iter:  6
Type Trained Model Mean reward: 500.0
Type Initial Model Mean rewar

In [8]:
model.policy.optimizer.param_groups

[{'params': [Parameter containing:
   tensor([[ 1.0806e+00,  1.0122e+00, -1.6773e+00,  3.2338e+00],
           [ 4.0006e+00,  6.7071e-01, -2.7807e+00,  1.7551e+00],
           [-2.1527e+00, -1.4294e+00, -1.7218e+00, -2.7186e+00],
           [-7.7418e-01, -2.7780e+00, -1.5532e+00, -3.3371e+00],
           [ 1.3478e-01,  1.1380e+00, -2.5663e+00, -3.1758e+00],
           [-5.0913e-01,  3.4034e-01,  2.2941e-01,  3.0163e+00],
           [ 2.4074e-01, -1.6346e+00,  7.1230e-01,  2.1018e-01],
           [-1.3023e-01, -8.3644e-02,  2.8941e+00,  1.7126e+00],
           [ 4.7333e+00, -1.9284e+00,  2.0422e+00,  1.3843e+00],
           [ 2.5615e-01,  1.1294e+00,  2.4390e+00,  2.1649e+00],
           [-9.0549e-01, -2.9252e+00, -1.6719e+00, -1.2645e+00],
           [-1.4026e+00,  2.3101e+00, -3.3497e+00, -2.3350e+00],
           [ 3.2149e+00,  7.7405e-01, -5.2920e-01, -3.0006e+00],
           [-9.4652e-01, -1.7432e+00,  5.0262e-01,  1.2102e+00],
           [-1.3078e+00,  3.5630e-01, -1.6209e+00, -1.2