In [None]:
%load_ext autoreload
%autoreload 2
import run_pg

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm import tqdm

In [None]:
# Seeded random generator for the notebook; `sample_params` draws the
# hyper-parameter configurations from it, so the fixed seed makes the sequence
# of sampled configurations repeatable when cells are run in order.
rng = np.random.RandomState(1234)

In [None]:
from collections import defaultdict

In [None]:
# Maps obs_horizon -> list of per-run records. Each record is the mapping
# returned by `run_pg.run` merged with the hyper-parameters sampled for that
# run (see the sweep cell). Re-running this cell discards every collected run.
all_episode_returns = defaultdict(list)

In [None]:
# all_results = []

In [None]:
# Search space for the random hyper-parameter sweep: 1000 candidates each.
# learning_rate candidates are evenly spaced in log10 (exponents -1 .. 0),
# gamma candidates are evenly spaced between 0.7 and 0.95.
learning_rate_exponents = np.linspace(-1.0, -0., num=1000)
gamma_candidates = np.linspace(0.7, 0.95, num=1000)

param_ranges = dict(
    learning_rate=10 ** learning_rate_exponents,
    gamma=gamma_candidates,
)

In [None]:
def sample_params(param_ranges, random_state=None):
    """Draw one random candidate value for every hyper-parameter.

    Args:
        param_ranges: Mapping from parameter name to a 1-D sequence/array of
            candidate values.
        random_state: Optional object exposing ``choice(candidates)`` (for
            example an ``np.random.RandomState``). When omitted, the
            notebook-level ``rng`` is used, which is the original behaviour.

    Returns:
        A dict with the same keys as ``param_ranges``; each value is a single
        element picked from the corresponding candidates.
    """
    sampler = rng if random_state is None else random_state
    # `candidates` rather than `range`, so the builtin is not shadowed.
    return {name: sampler.choice(candidates)
            for name, candidates in param_ranges.items()}

In [None]:
# Random-search sweep: 5 runs per observation horizon, visiting the horizons
# round-robin (1, 2, 3, 1, 2, 3, ...) with freshly sampled hyper-parameters
# for every run.
obs_horizons = [1, 2, 3]
n_horizons = len(obs_horizons)
for i in tqdm(range(5 * n_horizons)):
    obs_horizon = obs_horizons[i % n_horizons]

    params = sample_params(param_ranges)

    # First positional argument is presumably the number of training
    # episodes -- TODO confirm against run_pg.run.
    run_kwargs = dict(
        obs_horizon=obs_horizon,
        learning_rate=params['learning_rate'],
        gamma=params['gamma'],
    )
    results = run_pg.run(10000, **run_kwargs)

    # Keep the run output and the hyper-parameters that produced it together.
    record = dict(results)
    record.update(params)
    all_episode_returns[obs_horizon].append(record)

In [None]:
def jsonify_lists_dicts_nparrays(xs):
    """Recursively convert a nested structure into JSON-serialisable builtins.

    Handles lists and tuples (both become lists), dicts (values are converted,
    NumPy-scalar keys are unwrapped), NumPy arrays (``tolist()``) and NumPy
    scalars such as ``np.int64`` (``item()``). Anything else is returned
    unchanged.

    Args:
        xs: Arbitrarily nested combination of the types above.

    Returns:
        A structure made of lists, dicts and Python scalars wherever the input
        contained the NumPy/tuple types listed above.
    """
    if isinstance(xs, (list, tuple)):
        # Tuples may hold arrays too; json would emit them as lists anyway.
        return [jsonify_lists_dicts_nparrays(x) for x in xs]
    if isinstance(xs, dict):
        return {
            (key.item() if isinstance(key, np.generic) else key):
                jsonify_lists_dicts_nparrays(x)
            for key, x in xs.items()
        }
    if isinstance(xs, np.ndarray):
        return xs.tolist()
    if isinstance(xs, np.generic):
        # e.g. np.int64 / np.float32, which json.dump cannot encode directly.
        return xs.item()
    return xs

In [None]:
import json
from datetime import datetime
# Write `results` to results.json (opened in 'w' mode, so any existing file is
# overwritten) after converting NumPy arrays to plain lists.
# NOTE(review): when cells are run in order, `results` is bound to the output of
# the final sweep run only, not to the full `all_episode_returns` collection --
# confirm that saving just the last run is intended.
with open('results.json', 'w') as f:
    json.dump(jsonify_lists_dicts_nparrays(results), f)

In [None]:
# Show columns 0..3 of the first weight array in `results['model_weights']`,
# one image per column, each column folded into rows of width 5.
# NOTE(review): presumably one column per action and a 5-wide observation
# grid -- confirm against the model / environment definition.
fig, axs = plt.subplots(ncols=4, figsize=(12, 4), sharex=True, sharey=True)
first_weights = results['model_weights'][0]
for i, ax in enumerate(axs):
    ax.imshow(first_weights[:, i].reshape(-1, 5))
plt.show()

In [None]:
# np.mean(all_episode_returns[], axis=1)

In [None]:
import pandas as pd

In [None]:
%matplotlib inline

In [None]:
def plot_episode_returns(ax, episode_returns, *, window=1000, **kwargs):
    """Plot a centered rolling mean of per-episode returns on ``ax``.

    Args:
        ax: Matplotlib axes to draw on; its x/y labels are set to
            'Episode' / 'Return'.
        episode_returns: 1-D sequence of returns, one entry per episode.
        window: Number of episodes in the rolling-mean window (default 1000,
            the value that was previously hard-coded). With pandas' default
            ``min_periods`` a series shorter than ``window`` yields only NaN,
            i.e. an empty line.
        **kwargs: Forwarded to ``pandas.Series.plot`` (e.g. ``color``,
            ``label``, ``alpha``).
    """
    smoothed = pd.Series(episode_returns).rolling(window, center=True).mean()
    smoothed.plot(ax=ax, **kwargs)
    ax.set_xlabel('Episode')
    ax.set_ylabel('Return')
    
#     pd.Series(episode_returns).rolling(500, center=True).quantile(.05).plot(ax=ax)
#     pd.Series(episode_returns).rolling(500, center=True).quantile(.95).plot(ax=ax)

In [None]:
def subselect_results(results, fraction, tail=10000):
    """Return the best-performing ``fraction`` of runs, best first.

    Runs are ranked by the sum of their last ``tail`` episode returns
    (descending; ties keep their input order).

    Args:
        results: Sequence of per-run return sequences.
        fraction: Share of runs to keep, e.g. 0.1 for the top 10%.
        tail: Positive number of trailing episodes that enter the ranking
            score (default 10000, the value that was previously hard-coded).

    Returns:
        A list with the top ``round(fraction * len(results))`` runs, but never
        fewer than one run when ``results`` is non-empty and ``fraction`` is
        positive. Without that floor, small sweeps (e.g. 0.1 of 5 runs) round
        down to an empty selection, whose mean is NaN. Returns ``[]`` for
        empty input or a non-positive fraction.
    """
    if len(results) == 0 or fraction <= 0:
        return []
    n = max(1, int(round(fraction * len(results))))
    ranked = sorted(results,
                    key=lambda result: np.sum(result[-tail:]),
                    reverse=True)
    return ranked[:n]

In [None]:
# Horizons compared in the figures below (horizon 2 is left out here).
obs_horizons = [1, 3]
colors_by_obsho = {1: 'orange', 2: 'red', 3: 'darkred'}

# fig, ax = plt.subplots(figsize=(7,5))
# for obsho in obs_horizons:
#     for results in all_episode_returns[obsho]:
#         plot_episode_returns(ax, results['episode_returns'], color=colors_by_obsho[obsho], alpha=0.2)
# plt.show()

# Figure 1: per-episode return averaged over all runs of each horizon.
fig, ax = plt.subplots(figsize=(7,5))
for obsho in obs_horizons:
    returns_per_run = [run['episode_returns']
                       for run in all_episode_returns[obsho]]
    mean_curve = np.mean(returns_per_run, axis=0)
    plot_episode_returns(ax, mean_curve,
#                          color=colors_by_obsho[obsho],
                         label=f'horizon {obsho}')
ax.legend(loc=0)
plt.show()


# Figure 2: same, restricted to the top 10% of runs per horizon as ranked by
# `subselect_results`.
fig, ax = plt.subplots(figsize=(7,5))
for obsho in obs_horizons:
    returns_per_run = [run['episode_returns']
                       for run in all_episode_returns[obsho]]
    top_runs = subselect_results(returns_per_run, 0.1)
    top_mean_curve = np.mean(top_runs, axis=0)
    plot_episode_returns(ax, top_mean_curve, color=colors_by_obsho[obsho])
plt.show()

                         
#                          color=colors_by_obsho[obsho])
#     plot_episode_returns(ax, np.max(all_episode_returns[obsho], axis=0), 
#                          color=colors_by_obsho[obsho])

# plot_episode_returns(episode_returns_old)

In [None]:
# Per horizon: scatter each run's mean episode return against its sampled
# learning rate (left) and gamma (right).
for obsho in obs_horizons:
    # A dedicated name is used so the notebook-level `results` (the last run's
    # dict, indexed by the save / weight-plot cells) is not overwritten by a list.
    runs = all_episode_returns[obsho]
    mean_returns = [np.mean(run['episode_returns']) for run in runs]

    fig, axs = plt.subplots(ncols=2)
    axs[0].scatter([run['learning_rate'] for run in runs], mean_returns)
    axs[0].set_xlabel('learning_rate')
    axs[0].set_ylabel('Mean return')
    axs[1].scatter([run['gamma'] for run in runs], mean_returns)
    axs[1].set_xlabel('gamma')
    axs[1].set_ylabel('Mean return')
    fig.suptitle(f'horizon {obsho}')

    plt.show()
    

In [None]:
# `os` is not imported anywhere else in the notebook, so import it here to keep
# this cell runnable on a fresh kernel.
import os

# Audible "finished" notification via the shell.
# NOTE(review): `say` is presumably the macOS text-to-speech command; on other
# systems the shell cannot find it and os.system just returns a non-zero status.
os.system('say We are done now. Congratulations')

In [None]:
import pg

In [None]:
import food_search_env

In [None]:
# For every horizon: rebuild an agent, load the weights of the run with the
# highest total return, and render 5 videos of its policy.
for obsho in obs_horizons:
    print('obsho', obsho)
    env = food_search_env.FoodSearch(obs_horizon=obsho, n_noise_channels=2)
    # np.prod instead of np.product: the latter is a deprecated alias that was
    # removed in NumPy 2.0.
    state_size = np.prod(env.observation_space.shape)
    action_size = env.action_space.n
    # NOTE(review): the trailing 0.0 / 0.9 are presumably learning rate and
    # gamma (no further training happens here) -- confirm against pg.PGAgent.
    agent = pg.PGAgent(state_size, action_size, 0.0, 0.9)
    
    best_weights = max(all_episode_returns[obsho], key=lambda result: np.sum(result['episode_returns']))['model_weights']
    agent.model.set_weights(best_weights)
    for _ in range(5):
        food_search_env.render_video(env, agent.get_policy(), name_prefix=f'obsho{obsho}')
    