In [1]:
import gym
import numpy as np
from tqdm.notebook import tqdm
from scipy.stats import ttest_ind 
import matplotlib.pyplot as plt

from optimization import Success_Matching, Parallel_Environment

In [2]:
def polynomial_features(x):
    angle = np.arcsin(x[0])
    angle_velocity = x[2]
    x = np.array([angle, angle_velocity])
    return np.array([
        *x, *(x.reshape(-1, 1)@x.reshape(1, -1))[np.triu_indices(x.shape[0])].flatten(), *x**3
    ])

In [3]:
def pendulum_v1_trial(env, params: np.array, seed=None, T=False, tolerance=50, viz=True):
    env.action_space.seed(seed)
    state = env.reset(seed=seed)
    sum_reward = 0
    for t in range(env._max_episode_steps):
        action = np.clip(polynomial_features(state) @ params, -2, 2)
        state, reward, done, info = env.step([action])
        sum_reward += reward
        if done:
            if tolerance > 0 and viz:
                tolerance -= 1
            else:
                break
        if viz: 
            env.render()
    return sum_reward

In [4]:
# rollout functions
reward_function = lambda env, params, seed: pendulum_v1_trial(env, params, seed, viz=False)
viz_function = lambda env, params: pendulum_v1_trial(env, params, viz=True)

# parallel env
p_env = Parallel_Environment('Pendulum-v1', reward_function, viz_function)

# optimizer
n_state = polynomial_features(np.ones(p_env.observation_space_shape[0])).shape[0]
print("n_state", n_state)
sm = Success_Matching(n_state)

n_state 7


In [6]:
p_env.vizualize(sm.mean)

-1603.495317608435

In [7]:
def scheduler(t):
    """
    returns 
    [population_size, lower_bound_variance, lower_bound_importance] 
    for every timestep t
    """
    if t < 1:
        return [1000, 5, 1e-3]
    if t < 3:
        return [500, 1, 1e-4]
    if t < 5:
        return [300, 1e-4, 1e-4]
    if t < 6:
        return [300, 1e-5, 1e-5]
    if t < 7:
        return [300, 1e-6, 1e-5]
    return [300, 1e-7, 0]

In [8]:
# training
e = 10
for i in range(e):
    population_size, lower_bound_variance, lower_bound_importance = scheduler(i)
    print(f"""epoch: {i}
    population_size: {population_size}
    lower_bound_variance: {lower_bound_variance}
    lower_bound_importance: {lower_bound_importance}
    """)
    best_params = sm.train(
        p_env,
        population_size=population_size, 
        epochs=20, 
        variance_optimization=True, 
        variance_decay=False,
        lower_bound_variance=lower_bound_variance,
        lower_bound_importance=lower_bound_importance
    )
    p_env.vizualize(sm.mean)

epoch: 0
    population_size: 1000
    lower_bound_variance: 5
    lower_bound_importance: 0.01
    


  0%|          | 0/20 [00:00<?, ?it/s]

epoch: 1
    population_size: 500
    lower_bound_variance: 1
    lower_bound_importance: 0.001
    


  0%|          | 0/20 [00:00<?, ?it/s]

epoch: 2
    population_size: 500
    lower_bound_variance: 1
    lower_bound_importance: 0.001
    


  0%|          | 0/20 [00:00<?, ?it/s]

epoch: 3
    population_size: 300
    lower_bound_variance: 0.0001
    lower_bound_importance: 0.0001
    


  0%|          | 0/20 [00:00<?, ?it/s]

epoch: 4
    population_size: 300
    lower_bound_variance: 0.0001
    lower_bound_importance: 0.0001
    


  0%|          | 0/20 [00:00<?, ?it/s]

epoch: 5
    population_size: 300
    lower_bound_variance: 1e-05
    lower_bound_importance: 1e-05
    


  0%|          | 0/20 [00:00<?, ?it/s]

epoch: 6
    population_size: 300
    lower_bound_variance: 1e-06
    lower_bound_importance: 1e-05
    


  0%|          | 0/20 [00:00<?, ?it/s]

epoch: 7
    population_size: 300
    lower_bound_variance: 1e-07
    lower_bound_importance: 0
    


  0%|          | 0/20 [00:00<?, ?it/s]

epoch: 8
    population_size: 300
    lower_bound_variance: 1e-07
    lower_bound_importance: 0
    


  0%|          | 0/20 [00:00<?, ?it/s]

epoch: 9
    population_size: 300
    lower_bound_variance: 1e-07
    lower_bound_importance: 0
    


  0%|          | 0/20 [00:00<?, ?it/s]

In [9]:
for _ in range(10):
    p_env.vizualize(sm.mean)

In [10]:
sm.mean

array([-9.51131325,  4.82630526, -4.57864117, -9.28505739, -2.38660498,
        6.81543525, -0.5806123 ])

In [15]:
sm.cov

array([[ 3.26163740e-06,  1.71943688e-08, -1.67979197e-06,
         1.41016586e-06,  9.22003622e-07, -3.62515169e-07,
         2.73179868e-07],
       [ 1.71943688e-08,  4.86506532e-07,  2.57681342e-07,
         1.53402603e-07, -5.99467051e-08, -1.78788135e-07,
        -3.65323828e-08],
       [-1.67979197e-06,  2.57681342e-07,  1.25629688e-06,
        -1.01241885e-06, -6.98042424e-07, -9.07567285e-08,
        -1.76394710e-07],
       [ 1.41016586e-06,  1.53402603e-07, -1.01241885e-06,
         2.68306293e-06,  1.32478809e-06, -1.80575742e-10,
         2.24703051e-07],
       [ 9.22003622e-07, -5.99467051e-08, -6.98042424e-07,
         1.32478809e-06,  7.11828987e-07,  3.22809991e-08,
         1.36760235e-07],
       [-3.62515169e-07, -1.78788135e-07, -9.07567285e-08,
        -1.80575742e-10,  3.22809991e-08,  2.40084025e-07,
        -9.48358835e-09],
       [ 2.73179868e-07, -3.65323828e-08, -1.76394710e-07,
         2.24703051e-07,  1.36760235e-07, -9.48358835e-09,
         1.0000000