In [1]:
import numpy as np
import time
import gym
import torch
from network import ContinuousPolicyNetwork
from spsa import get_alpha, get_delta, update_weights, perturb_policy, Module, revert_weights


In [2]:
def simulate(policy: Module, env: gym.Env, gamma: float) -> float:
    """Run a single episode and return its discounted return.

    The environment is reset, then ``policy.sample`` is queried on every
    observation until ``env.step`` reports either termination or truncation.
    Rewards are accumulated as ``sum_t gamma**t * r_t`` with ``t`` starting at 0.

    Args:
        policy: Object whose ``sample(observation)`` returns a pair; the first
            element must support ``.item()`` (presumably a one-element torch
            tensor) and the second (a log-probability, judging by its name in
            the original code) is ignored here.
        env: Environment following the 5-tuple ``step`` API
            ``(obs, reward, terminated, truncated, info)``.
        gamma: Discount factor.

    Returns:
        The discounted return of the episode.
    """
    observation, _ = env.reset()
    discounted_return = 0
    step = 0
    finished = False
    while not finished:
        sampled_action, _log_prob = policy.sample(observation)
        # The env receives the scalar action wrapped in a one-element list.
        observation, reward, terminated, truncated, _ = env.step(
            [sampled_action.item()]
        )
        discounted_return += gamma**step * reward
        step += 1
        finished = terminated or truncated
    return discounted_return

def spsa(
    env,
    policy,
    seed,
    num_episodes=20000,
    gamma=0.99,
    num_trials=10,
    log_every=100,
    avg_window=1000,
):
    """Optimise ``policy`` on ``env`` via weight perturbation (SPSA loop).

    Each iteration:
      1. perturbs the policy weights with ``perturb_policy`` using the step
         size ``get_delta(episode)``;
      2. estimates the perturbed policy's discounted return by averaging
         ``num_trials`` rollouts of ``simulate``;
      3. restores the weights with ``revert_weights``;
      4. applies ``update_weights`` with that average and the perturbations.
    All of this runs under ``torch.no_grad()``.

    Args:
        env: Environment forwarded to ``simulate``.
        policy: Policy whose weights are perturbed and updated. NOTE(review):
            the updated policy is not returned, so callers presumably rely on
            ``revert_weights``/``update_weights`` mutating it in place —
            confirm in the ``spsa`` module.
        seed: Used only as a label in the progress output. This function does
            not seed any RNG; the caller must seed torch/numpy/env.
        num_episodes: Number of iterations to run.
        gamma: Discount factor forwarded to ``simulate``.
        num_trials: Rollouts averaged per iteration; must be >= 1.
        log_every: Print progress every ``log_every`` iterations; must be >= 1.
        avg_window: Number of most recent per-iteration averages included in
            the printed running mean; must be >= 1.

    Returns:
        list: One entry per iteration, the mean discounted return of the
        rollouts of the perturbed policy in that iteration.

    Raises:
        ValueError: If ``num_trials``, ``log_every`` or ``avg_window`` is < 1
            (each would otherwise cause a division/modulo by zero mid-run).
    """
    if num_trials < 1:
        raise ValueError(f"num_trials must be >= 1, got {num_trials}")
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every}")
    if avg_window < 1:
        raise ValueError(f"avg_window must be >= 1, got {avg_window}")

    # perf_counter is monotonic, so elapsed times are immune to clock changes.
    start = time.perf_counter()
    results = []
    for episode in range(num_episodes):
        with torch.no_grad():
            # sample perturbations
            perturbed_policy, old_params, perts = perturb_policy(
                policy, delta=get_delta(episode)
            )

            # evaluate the perturbed policy over num_trials rollouts
            rewards = [
                simulate(perturbed_policy, env, gamma) for _ in range(num_trials)
            ]

            # revert weights of the policy
            policy = revert_weights(perturbed_policy, old_params)

            # update weights according to the paper
            avg_reward = sum(rewards) / len(rewards)
            policy = update_weights(policy, avg_reward, perts, episode)

        results.append(avg_reward)

        if episode % log_every == 0:
            # running mean over (at most) the last avg_window iterations
            recent = results[-avg_window:]
            avg = sum(recent) / len(recent)
            print(
                f"Seed: {seed}, time: {time.perf_counter() - start}, Episode {episode}, Average Reward: {avg}"
            )

    return results


In [3]:
# Seed every RNG the run can draw from. `spsa` only prints `seed` as a label,
# so without these calls the logged "Seed: 0" run would not be reproducible.
seed = 0
iterations = 5000
# Seed torch before building the network (presumably also covers weight init
# and action sampling) and numpy (in case the spsa helpers draw from it).
torch.manual_seed(seed)
np.random.seed(seed)
policy = ContinuousPolicyNetwork()
env = gym.make("Pendulum-v1")
# A seeded reset initialises the env's generator; the unseeded resets inside
# `simulate` then continue from it deterministically.
env.reset(seed=seed)
results = spsa(env, policy, seed, iterations)


  if not isinstance(terminated, (bool, np.bool8)):


Seed: 0, time: 0.14887499809265137, Episode 0, Average Reward: -572.6721596333452
Seed: 0, time: 14.804646968841553, Episode 100, Average Reward: -557.8625315164915
Seed: 0, time: 29.41745686531067, Episode 200, Average Reward: -567.9334121502192
Seed: 0, time: 44.33139681816101, Episode 300, Average Reward: -573.975645558429
Seed: 0, time: 58.94672393798828, Episode 400, Average Reward: -576.8748722377385
Seed: 0, time: 73.55594301223755, Episode 500, Average Reward: -577.4157815431045
Seed: 0, time: 87.99592995643616, Episode 600, Average Reward: -577.3072614912123
Seed: 0, time: 102.22212290763855, Episode 700, Average Reward: -578.0279180520799
Seed: 0, time: 117.15559411048889, Episode 800, Average Reward: -578.608784570868
Seed: 0, time: 131.30128288269043, Episode 900, Average Reward: -578.4083520649137
Seed: 0, time: 145.79238390922546, Episode 1000, Average Reward: -577.859921330252
Seed: 0, time: 160.3004801273346, Episode 1100, Average Reward: -580.1247029298187
Seed: 0, tim