In [20]:
# Load the autoreload extension and set it to mode 2.
%load_ext autoreload
%autoreload 2
# Select the inline matplotlib backend so figures render in the cell output.
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
!wandb login 9676e3cc95066e4865586082971f2653245f09b4

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/guydavidson/.netrc
[32mSuccessfully logged in to Weights & Biases![0m


In [22]:
import numpy as np
import pandas as pd
import scipy
from scipy import stats
from scipy.special import factorial

from mpl_toolkits.mplot3d import Axes3D
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import patches
from matplotlib import path as mpath
import matplotlib.gridspec as gridspec

import pickle
import tabulate
import wandb

In [23]:
# Client for the W&B public API; later cells query project runs through it.
api = wandb.Api()

In [6]:
# Seeds 200..209 select the runs analysed in this notebook.
run_seeds = np.arange(200, 210)

# Membership is tested against plain Python ints rather than the NumPy array,
# and .get() is used so a run whose config has no 'seed' entry is skipped
# instead of aborting the whole query with a KeyError.
_wanted_seeds = set(run_seeds.tolist())

initial_runs = [run for run in api.runs('augmented-frostbite/initial-experiments')
                if run.config.get('seed') in _wanted_seeds]

In [None]:
# One array per run in `initial_runs`, in the same order as that list.
q_value_means = []
q_value_stds = []

reward_means = []
reward_stds = []

for run in initial_runs:
    history = run.history(pandas=True)
    print(run.name)
    # `np.float` was only an alias for the builtin `float`; it was deprecated in
    # NumPy 1.20 and removed in 1.24, so the builtin is used directly.
    q_value_means.append(np.array(history['Q_value_mean'], dtype=float))
    q_value_stds.append(np.array(history['Q_value_std'], dtype=float))

    reward_means.append(np.array(history['reward_mean'], dtype=float))
    reward_stds.append(np.array(history['reward_std'], dtype=float))

# Right-pad every series with NaN up to the longest one in its list, so that
# all series of a metric share one length even if the runs returned a
# different number of rows.
for result_list in (q_value_means, q_value_stds, reward_means, reward_stds):
    # default=0 keeps this cell from raising when no runs were selected.
    max_len = max((arr.shape[0] for arr in result_list), default=0)

    for i, arr in enumerate(result_list):
        result_list[i] = np.pad(arr, (0, max_len - arr.shape[0]), 'constant', constant_values=np.nan)

# Basic plots

In [None]:
NROWS = 1
NCOLS = 2
COL_WIDTH = 6
ROW_HEIGHT = 5
WIDTH_SPACING = 2
HEIGHT_SPACING = 0
COLORMAP = 'cool'
# Spacing on the x axis between consecutive history rows. Assumed to be 10,000
# environment steps per logged row -- TODO confirm against the 'steps' column.
STEPS_PER_ROW = 10000
STEPS_LABEL = 'Steps (1 step = 4 frames, 200k frames ~ 1 hr @ 60 fps)'


@matplotlib.ticker.FuncFormatter
def million_formatter(x, pos):
    """Format a tick value in millions, e.g. 1500000 -> '1.5M'.

    Args:
        x: Tick value in data units.
        pos: Tick position passed in by matplotlib; unused.

    Returns:
        str: '0' for a zero tick, otherwise ``x / 10**6`` with one decimal
        place followed by 'M'.
    """
    if x == 0:
        # A formatter callback has to return a string, not the int 0.
        return '0'

    return f'{x / 10 ** 6:.1f}M'


def plot_run_curves(ax, x, means, stds, cmap, title, ylabel):
    """Draw one mean +/- std band per run plus the across-run mean on `ax`.

    Args:
        ax: Axes to draw on.
        x: x coordinates shared by every series.
        means: Sequence of per-run mean arrays, each the same length as `x`.
        stds: Sequence of per-run std arrays, aligned with `means`.
        cmap: Callable mapping a float in [0, 1] to a colour.
        title: Axes title.
        ylabel: y-axis label.

    Returns:
        The NaN-ignoring mean over runs (``np.nanmean(means, axis=0)``).
    """
    # Spread colours over however many runs there are instead of assuming
    # exactly ten; max(..., 1) avoids dividing by zero for a single run.
    color_scale = max(len(means) - 1, 1)

    for i, (run_mean, run_std) in enumerate(zip(means, stds)):
        color = cmap(i / color_scale)
        ax.plot(x, run_mean, lw=1, color=color)
        ax.fill_between(x, run_mean - run_std, run_mean + run_std, color=color, alpha=0.10)

    overall_mean = np.nanmean(means, axis=0)
    ax.plot(x, overall_mean, lw=2, color='black')

    ax.set_title(title)
    ax.set_xlabel(STEPS_LABEL)
    ax.set_ylabel(ylabel)
    ax.xaxis.set_major_formatter(million_formatter)

    return overall_mean


figure = plt.figure(figsize=(NCOLS * COL_WIDTH + WIDTH_SPACING, NROWS * ROW_HEIGHT + HEIGHT_SPACING))
plt.subplots_adjust(hspace=0.4, wspace=0.2)

# All series were padded to a common length, so the first one sets the x axis.
x = np.arange(1, len(reward_means[0]) + 1) * STEPS_PER_ROW
cmap = plt.get_cmap(COLORMAP)

reward_ax = plt.subplot(NROWS, NCOLS, 1)
overall_reward_mean = plot_run_curves(reward_ax, x, reward_means, reward_stds, cmap,
                                      title='Rewards', ylabel='Reward')

q_ax = plt.subplot(NROWS, NCOLS, 2)
overall_q_mean = plot_run_curves(q_ax, x, q_value_means, q_value_stds, cmap,
                                 title='Q-values', ylabel='Average Q-value')

plt.show()

In [None]:
# Exploratory: show the seed stored in the first selected run's config.
initial_runs[0].config['seed']

In [8]:
# Exploratory: show the name and id of the first selected run as a tuple.
initial_runs[0].name, initial_runs[0].id 

('data-efficient-5M-201', 'yslgd3ls')

In [None]:
# Fetch the logged history of the first selected run with default arguments.
h = initial_runs[0].history()

In [None]:
# Last entry of the 'steps' column of the fetched history.
h['steps'].iat[-1]

In [None]:
# Print the logged 'steps' series of every project run whose seed is 123.
project_runs = api.runs('augmented-frostbite/initial-experiments')
for existing_run in filter(lambda candidate: candidate.config['seed'] == 123, project_runs):
    print(existing_run.history()['steps'])

In [None]:
# Exploratory: keep the result of files() for the first selected run.
files = initial_runs[0].files()

In [None]:
# Print the name of each file returned for 'config2.yaml' on the first run.
# NOTE(review): files() is given a bare string here -- confirm whether the
# API expects a list of names instead.
for f in initial_runs[0].files('config2.yaml'):
    print(f.name)

In [None]:
# Exploratory: look up the file called 'config2.yaml' on the first selected run.
initial_runs[0].file('config2.yaml')

In [None]:
# Shorthand for the first selected run.
r = initial_runs[0]

In [None]:
# Print name and storage_id for every run in the project.
# NOTE: the loop variable rebinds the notebook-level name `r`.
for r in api.runs('augmented-frostbite/initial-experiments'):
    print(r.name, r.storage_id)

In [None]:
# Exploratory: list the attribute names available on a run object.
dir(initial_runs[0])

In [None]:
# Leftover interactive help lookup. NOTE(review): `tqdm` is not imported in
# any cell visible here, so this only resolves if it was imported elsewhere.
tqdm.trange?

In [None]:
# Scratch check: does the string read the same forwards and backwards?
s = 'cabac'

s == ''.join(reversed(s))

In [24]:
# Query filter: config id 'data-efficient-5M' AND seed 200.
run_filter = {"$and": [{"config.id": "data-efficient-5M"}, {"config.seed": 200}]}
runs = api.runs('augmented-frostbite/initial-experiments', run_filter)
# Work with the first run returned by the query.
r = runs[0]

In [25]:
# Fetch the history of run `r` with samples=1000; rebinds the notebook-level `h`.
h = r.history(samples=1000)

In [26]:
# Display the last 10 rows of the fetched history.
h.tail(10)

Unnamed: 0,Q_value_mean,Q_value_std,Q_values,_runtime,_step,_timestamp,gradients/convs.0.bias,gradients/convs.0.weight,gradients/convs.2.bias,gradients/convs.2.weight,...,gradients/fc_z_a.weight_sigma,gradients/fc_z_v.bias_mu,gradients/fc_z_v.bias_sigma,gradients/fc_z_v.weight_mu,gradients/fc_z_v.weight_sigma,human_hours,reward_mean,reward_std,rewards,steps
182,3.660273,1.863285,"[5.610603332519531, 5.5790205001831055, 5.6074...",160610.887995,182,1568391000.0,"{'_type': 'histogram', 'values': [1, 0, 2, 0, ...","{'_type': 'histogram', 'values': [1, 2, 7, 4, ...","{'_type': 'histogram', 'values': [2, 0, 0, 1, ...","{'_type': 'histogram', 'values': [1, 1, 0, 2, ...",...,"{'_type': 'histogram', 'values': [1, 1, 0, 0, ...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...",33.888889,2755,514.125471,"[2240, 2400, 3170, 2400, 2400, 3530, 2400, 324...",1830000
183,3.578776,1.724485,"[5.558034420013428, 5.61371374130249, 5.645660...",161895.976216,183,1568393000.0,"{'_type': 'histogram', 'values': [1, 0, 2, 0, ...","{'_type': 'histogram', 'values': [2, 2, 0, 0, ...","{'_type': 'histogram', 'values': [1, 1, 0, 0, ...","{'_type': 'histogram', 'values': [1, 1, 5, 1, ...",...,"{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'bins': [-0.0013222142588347197, -0.001174632...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'bins': [-0.04831943288445473, -0.04730383306...",34.074074,3424,314.776111,"[2990, 3440, 3220, 3440, 3400, 3440, 3380, 344...",1840000
184,3.460703,1.698317,"[5.582767963409424, 5.556641101837158, 5.57945...",163168.639314,184,1568394000.0,"{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'_type': 'histogram', 'values': [10, 12, 11, ...","{'bins': [-0.018167462199926376, -0.0176646467...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...",...,"{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'values': [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","{'values': [1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1...",34.259259,2323,1461.923733,"[150, 150, 3550, 3620, 3620, 2790, 3620, 2790,...",1850000
185,3.139777,1.884731,"[5.613146781921387, 5.606122016906738, 5.61173...",164450.582095,185,1568395000.0,"{'bins': [-0.055528827011585236, -0.0533091239...","{'_type': 'histogram', 'values': [4, 4, 2, 3, ...","{'bins': [-0.016371022909879684, -0.0157843567...","{'_type': 'histogram', 'values': [3, 1, 1, 0, ...",...,"{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'values': [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'bins': [-0.05430735647678375, -0.05319623276...","{'values': [2, 0, 0, 1, 2, 1, 0, 0, 0, 0, 3, 0...",34.444444,3407,502.136436,"[3240, 3640, 3590, 3580, 1940, 3640, 3660, 357...",1860000
186,3.571087,1.714585,"[5.578207969665527, 5.542852878570557, 5.58575...",165726.656753,186,1568396000.0,"{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'_type': 'histogram', 'values': [4, 6, 5, 9, ...","{'bins': [-0.037655558437108994, -0.0364460423...","{'values': [1, 0, 0, 1, 1, 2, 0, 0, 0, 0, 0, 2...",...,"{'_type': 'histogram', 'values': [1, 0, 0, 1, ...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'_type': 'histogram', 'values': [1, 0, 0, 1, ...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...",34.62963,3246,792.69414,"[2750, 3160, 3050, 3330, 3050, 3330, 5430, 305...",1870000
187,3.553452,1.826426,"[5.586501121520996, 5.5879340171813965, 5.6301...",167025.028883,187,1568398000.0,"{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'_type': 'histogram', 'values': [2, 5, 7, 5, ...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'_type': 'histogram', 'values': [4, 1, 1, 0, ...",...,"{'values': [1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 2...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...",34.814815,3193,737.33371,"[2870, 2650, 4090, 3780, 1940, 3450, 3450, 354...",1880000
188,3.775963,1.696744,"[5.5297698974609375, 5.554186820983887, 5.5416...",168320.408289,188,1568399000.0,"{'bins': [-0.014252981171011925, -0.0138447768...","{'_type': 'histogram', 'values': [1, 3, 4, 2, ...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'_type': 'histogram', 'values': [2, 0, 0, 0, ...",...,"{'_type': 'histogram', 'values': [1, 0, 1, 1, ...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'values': [1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 2...","{'_type': 'histogram', 'values': [1, 1, 0, 0, ...",35.0,3339,710.738348,"[3950, 2760, 3440, 3580, 1420, 3550, 3580, 358...",1890000
189,3.695615,1.827958,"[5.620333194732666, 5.608715057373047, 5.61887...",169622.187625,189,1568400000.0,"{'_type': 'histogram', 'values': [2, 0, 0, 0, ...","{'_type': 'histogram', 'values': [1, 5, 14, 16...","{'values': [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","{'_type': 'histogram', 'values': [2, 0, 1, 2, ...",...,"{'values': [1, 0, 0, 1, 1, 2, 3, 1, 2, 1, 2, 2...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'values': [1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...",35.185185,3017,611.409028,"[3570, 2270, 2270, 3570, 3460, 2270, 2270, 346...",1900000
190,3.356843,1.699732,"[5.559397220611572, 5.579154014587402, 5.56833...",170921.009495,190,1568402000.0,"{'bins': [-0.08691029250621796, -0.08469506353...","{'_type': 'histogram', 'values': [3, 5, 11, 9,...","{'values': [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0...","{'_type': 'histogram', 'values': [1, 0, 1, 0, ...",...,"{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'values': [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'bins': [-0.07764151692390442, -0.07564348727...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...",35.37037,2848,992.439419,"[3450, 4230, 3240, 1390, 3190, 3450, 3300, 139...",1910000
191,3.532052,1.707155,"[5.657956123352051, 5.688235759735107, 5.73611...",172247.875376,191,1568403000.0,"{'_type': 'histogram', 'values': [1, 1, 0, 0, ...","{'bins': [-0.02953970618546009, -0.02836538851...","{'bins': [-0.015759935602545738, -0.0151855004...","{'_type': 'histogram', 'values': [1, 0, 5, 4, ...",...,"{'_type': 'histogram', 'values': [1, 0, 0, 0, ...","{'_type': 'histogram', 'values': [1, 1, 1, 0, ...","{'values': [1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0...","{'_type': 'histogram', 'values': [1, 2, 0, 0, ...","{'_type': 'histogram', 'values': [1, 0, 0, 0, ...",35.555556,3493,133.794619,"[3360, 3600, 3570, 3300, 3400, 3500, 3300, 360...",1920000
