In [2]:
import numpy as np
from scipy import stats

In [3]:
from tqdm import tqdm_notebook

In [4]:
import pprint

In [5]:
%matplotlib inline
import matplotlib.pyplot as plt
import pymongo
import os, sys
# Location of the pymarl sources. Set the PYMARL_SRC environment variable to
# point at a checkout elsewhere; the default is the original hard-coded path.
root_dir = os.environ.get("PYMARL_SRC", "/home/tabz/Coding/pymarl/src")
# Put root_dir at the front of the import path. Any existing occurrence is
# dropped first so that re-running this cell does not keep growing sys.path.
sys.path = [root_dir] + [entry for entry in sys.path if entry != root_dir]

In [6]:
from bokeh.io import push_notebook, show, output_notebook, export_svgs, export_png
from bokeh.layouts import column, row, gridplot
from bokeh.plotting import figure
from bokeh.models import Band, Span, Range1d
from bokeh.models.sources import ColumnDataSource
from bokeh.models.formatters import NumeralTickFormatter
from bokeh.palettes import all_palettes, magma, Set3
from bokeh.models import Range1d
from bokeh.models import FixedTicker
# Configure bokeh to render figures inline in this notebook when show() is called.
output_notebook()

In [7]:
# from config.mongodb import REGISTRY as mongo_REGISTRY
def get_mongo_db_client(conf_name, maxSevSelDelay=5000):
    """Open a MongoDB connection and return ``(client, database)``.

    The connection URL (which carries the username and password) is read from
    the ``PYMARL_MONGO_URL`` environment variable instead of being embedded in
    the notebook, so credentials do not end up in shared or committed copies
    of this file. The database name defaults to ``"pymarl"`` and can be
    overridden with ``PYMARL_MONGO_DB``.

    Parameters
    ----------
    conf_name : str
        Name of the connection configuration. Currently unused: the registry
        lookup it was meant for is disabled, so every name resolves to the
        same URL.
    maxSevSelDelay : int
        Passed to ``pymongo.MongoClient`` as ``serverSelectionTimeoutMS``.

    Returns
    -------
    (client, db)
        The ``MongoClient`` and ``client[db_name]``.

    Raises
    ------
    RuntimeError
        If ``PYMARL_MONGO_URL`` is not set or is empty.
    """
    # NOTE(review): the credentials that used to be hard-coded here have been
    # exposed in this file's history and should be rotated.
    db_url = os.environ.get("PYMARL_MONGO_URL")
    if not db_url:
        raise RuntimeError(
            "Set the PYMARL_MONGO_URL environment variable to the MongoDB "
            "connection URL (mongodb://user:password@host:port/db)."
        )
    db_name = os.environ.get("PYMARL_MONGO_DB", "pymarl")
    client = pymongo.MongoClient(db_url, ssl=True, serverSelectionTimeoutMS=maxSevSelDelay)
    return client, client[db_name]

In [8]:
class MongoCentral():
    """Thin wrapper that runs the same query against several MongoDB databases.

    Every query method iterates over ``self.db`` (configuration name ->
    database handle), queries that database's ``runs`` collection and
    concatenates the results.
    """

    def __init__(self, *args, **kwargs):
        """Connect to every configuration listed in the ``conf_names`` keyword.

        ``conf_names`` is required and must be passed by keyword; positional
        arguments are accepted but ignored.
        """
        self.conf_names = kwargs["conf_names"]
        self.db = {}
        self._connect(self.conf_names)

    def _connect(self, conf_names):
        """Open one client per name and store it in ``self.clients`` / ``self.db``."""
        self.clients = {}
        for _name in conf_names:
            self.clients[_name], self.db[_name] = get_mongo_db_client(_name)
            print("Connected to {}".format(_name))

    def get_exp_names(self):
        """Return the distinct ``config.name`` values of all runs, across all databases."""
        names = []
        for key, db in self.db.items():
            query = db["runs"].distinct("config.name")
            names.extend(query)
            print("Done Loading...")
        return names

    def get_config_and_info(self, label, keys_to_return):
        """Return runs whose ``config.label`` equals ``label``.

        Only ``config`` and the ``info.<k>`` sub-fields for each ``k`` in
        ``keys_to_return`` are requested from the database.
        """
        queries = []
        things_to_return = {"config"}
        for k in keys_to_return:
            things_to_return.add("info.{}".format(k))
        print("Things to return:", things_to_return)
        for key, db in self.db.items():
            print("Retreiving info from {}".format(key))
            query = db.runs.find({"config.label": label}, things_to_return)
            queries.extend(query)
        return queries

    def get_config_and_info_all(self, label):
        """Return the ``config`` and full ``info`` of runs whose ``config.label`` equals ``label``."""
        queries = []
        for key, db in self.db.items():
            print("Retreiving info from {}".format(key))
            query = db.runs.find({"config.label": label}, {"config", "info"})
            queries.extend(query)
        return queries

    def get_tag_names(self, tag, bundle=True, regex=False):
        """Return the ``config.name`` of every run whose name starts with ``tag``.

        Parameters
        ----------
        tag : str
            Name prefix to match. It is matched literally: regex
            metacharacters such as ``.`` or ``(`` are escaped, so a tag like
            ``"v1.2"`` no longer also matches ``"v1x2"``.
        bundle : bool
            If true, return a dict mapping each distinct name to the list of
            its occurrences (one entry per matching run); otherwise return
            the flat list of names.
        regex : bool
            If true, ``tag`` is inserted into the pattern unescaped, for
            callers that deliberately pass a regular expression.
        """
        import re
        prefix = tag if regex else re.escape(tag)
        names = []
        for key, db in self.db.items():
            query = db.runs.find({"config.name": {'$regex': r'^{}(.*)'.format(prefix)}}, {"config.name": 1})
            names.extend([_q["config"]["name"] for _q in query])
            print("Done Loading...")

        if bundle:  # bundle by experiment name
            bundle_dic = {}
            for name in names:
                # The experiment name is currently the full run name.
                bundle_dic.setdefault(name, []).append(name)
            return bundle_dic
        return names

    def get_name_prop(self, name, prop):
        """Return the documents of runs named ``name``, restricted to field ``prop``."""
        res = []
        for key, db in self.db.items():
            query = db.runs.find({"config.name": name}, {prop: 1})
            for _q in query:
                res.append(_q)
        return res

In [9]:
def ewma(x, alpha):
    '''
    Returns the exponentially weighted moving average of x.

    Element i of the result is the weighted mean of x[0..i], where x[j]
    carries weight (1 - alpha) ** (i - j).

    Parameters:
    -----------
    x : array-like (1-D)
    alpha : float {0 <= alpha <= 1}
        alpha = 1 returns x unchanged; alpha = 0 gives the running mean.

    Returns:
    --------
    ewma : numpy array
        the exponentially weighted moving average; empty if x is empty

    Notes:
    ------
    Builds an n x n weight matrix, so memory use is quadratic in len(x).
    '''
    # coerce x to an array
    x = np.array(x)
    n = x.size
    if n == 0:
        # np.vstack / the division below cannot handle an empty series
        return np.zeros(0)
    # exponent i - j for row i, column j; entries above the diagonal (j > i)
    # are clamped to 0 so we never raise (1 - alpha) to a negative power
    # (which overflows for long inputs and divides by zero when alpha == 1).
    idx = np.arange(n)
    p = np.maximum(idx[:, np.newaxis] - idx[np.newaxis, :], 0)
    # lower-triangular weight matrix; np.tril zeroes the clamped entries
    w = np.tril((1.0 - alpha) ** p, 0)
    # weighted sum of past values, normalised by the total weight of each row
    return np.dot(w, x) / w.sum(axis=1)

In [10]:
def mean_confidence_interval(data, confidence=0.95):
    '''
    Returns the mean over axis 0 and the half-width of its confidence interval.

    Parameters:
    -----------
    data : array-like
        samples along axis 0 (e.g. one row per run)
    confidence : float
        two-sided confidence level, e.g. 0.95

    Returns:
    --------
    m : mean of data over axis 0
    h : half-width of the Student-t confidence interval, i.e. the interval
        is [m - h, m + h]. With a single sample the standard error is
        undefined and h is nan.
    '''
    a = np.array(data)
    n = a.shape[0]
    m, se = a.mean(axis=0), stats.sem(a, axis=0)
    # Public ppf instead of the private _ppf: same value for valid input,
    # but with argument validation and a stable API.
    h = se * stats.t.ppf((1 + confidence) / 2., n - 1)
    return m, h

In [11]:
def get_stats(y_key):
    '''
    Aggregates the metric y_key over the runs of every parameter setting.

    Reads the notebook-level variables keys, data, limit, t_needed,
    print_not_long_enough, x_interp, smoother and confidence_interval.

    For each key, up to `limit` runs are taken. A run is skipped when it has
    no "<y_key>_T" timestep series, or when its last timestep is below
    t_needed. Each remaining run is interpolated onto x_interp and smoothed
    with ewma(., smoother).

    Returns:
    --------
    means, stds, nums : dicts keyed by parameter string
        means / stds hold per-point arrays (stds is the confidence-interval
        half-width when confidence_interval is truthy, the standard deviation
        otherwise); nums is the number of runs used. Keys with no usable run
        are absent from all three.
    '''
    t_key = "{}_T".format(y_key)
    means, stds, nums = {}, {}, {}

    for key in tqdm_notebook(keys, desc="params", leave=False):
        smoothed_runs = []

        for run in tqdm_notebook(data[key][:limit], desc="runs", leave=False):
            if t_key not in run:
                print("Field: {} not in run for {}".format(y_key, key))
                continue

            xs, ys = run[t_key], run[y_key]

            # Only keep runs that got (almost) to the end of training
            last_t = xs[-1]
            if last_t < t_needed:
                if print_not_long_enough:
                    print("Run not long enough: {} for key: {}".format(last_t, key))
                continue

            # Resample onto the shared grid, then smooth
            on_grid = np.interp(x_interp, xs, ys)
            smoothed_runs.append(ewma(on_grid, smoother))

        if not smoothed_runs:
            continue

        stacked = np.array(smoothed_runs)
        if confidence_interval:
            centre, spread = mean_confidence_interval(stacked, confidence=confidence_interval)
        else:
            centre, spread = np.mean(stacked, axis=0), np.std(stacked, axis=0)

        means[key], stds[key], nums[key] = centre, spread, len(smoothed_runs)

    return means, stds, nums

In [12]:
def plot(means, stds, nums, x_label, y_label, t_max, indivs=False):
    '''
    Draws one bokeh figure with a mean curve and a shaded band per key.

    Reads the notebook-level x_interp (x values of the mean curves) and, when
    indivs is true, data and limit.

    Parameters:
    -----------
    means, stds, nums : dicts keyed by curve name
        per-point mean, band half-width, and number of runs (shown in the
        legend as "name [n]")
    x_label, y_label : str
        axis labels. y_label is also used as the "info" key when drawing
        individual runs, so it must be the metric name if indivs is true.
    t_max : number
        upper end of the x range
    indivs : bool
        if true, additionally draw every individual run as a thin line
    '''
    x_vals_interp = x_interp

    keys = list(means.keys())
    p = figure(plot_width=1400, plot_height=800, x_range=[0, t_max])

    y_min = 0
    y_max = 1

    # Pick one colour per curve. Set3 only provides palettes for 3..12
    # colours, so fall back to fixed colours below that and cycle through the
    # largest palette above it instead of raising a KeyError.
    num_lines = len(keys)
    if num_lines < 3:
        palette = ["red", "green"][:num_lines]
    elif num_lines in Set3:
        palette = Set3[num_lines]
    else:
        largest = Set3[max(Set3)]
        palette = [largest[i % len(largest)] for i in range(num_lines)]
    colors = dict(zip(keys, palette))

    for key in keys:

        color = colors[key]
        name = key

        p.line(x_vals_interp, means[key], color=color, legend=name + " [" + str(nums[key]) + "]", line_width=5)

        # Closed polygon: lower edge left-to-right, then upper edge right-to-left
        xs = list(x_vals_interp) + list(reversed(x_vals_interp))

        mm = means[key]
        ss = stds[key]
        ys = np.concatenate([mm - ss, np.flip(mm + ss, axis=0)])
        p.patch(xs, ys, color=color, alpha=0.2)

        y_max = max(y_max, max(means[key]))
        y_min = min(y_min, min(means[key]))

        if indivs:
            for run in data[key][:limit]:
                if "{}_T".format(y_label) not in run:
                    continue

                xs = run["{}_T".format(y_label)]
                ys = run[y_label]

                p.line(xs, ys, color=color, line_width=1, alpha=0.4)

                y_max = max(y_max, max(ys))
                y_min = min(y_min, min(ys))

    # 10% headroom; the range always includes [0, 1]
    p.y_range = Range1d(y_min * 1.1, y_max * 1.1)

    p.grid.grid_line_width=2
    p.grid.minor_grid_line_width=2
    p.grid.grid_line_color=(1,1,1,0.1)

    p.grid.minor_grid_line_color=(1,1,1,0.1)
    p.xgrid.minor_grid_line_color=None

    p.yaxis.formatter = NumeralTickFormatter(format="0.0a")
    p.yaxis[0].ticker.desired_num_ticks = 5
    p.yaxis[0].ticker.num_minor_ticks = 2

    p.xaxis.formatter = NumeralTickFormatter(format="0.0a")

    p.xaxis.axis_label= x_label
    p.yaxis.axis_label= y_label

    p.xaxis.axis_label_text_font_style="normal"
    p.yaxis.axis_label_text_font_style="normal"

    p.xaxis.axis_label_text_font_size="26pt"
    p.yaxis.axis_label_text_font_size="26pt"

    p.xaxis.major_label_text_font_size="24pt"
    p.yaxis.major_label_text_font_size="24pt"

    p.legend.label_text_font_size="16pt"
    p.legend.visible=True
    p.legend.location="bottom_right"

    p.title.text_font_size="24pt"

    # NOTE(review): h_symmetry, plot_width/plot_height and the legend= glyph
    # keyword are specific to older bokeh releases - confirm the pinned version.
    p.h_symmetry = False
    p.min_border_right = 70
    p.min_border_left = 0
    p.min_border_bottom = 0
    p.min_border_top = 30

    p.axis.axis_line_width = 3
    p.axis.major_tick_line_width = 5
    p.axis.major_tick_in = 10
    p.axis.major_tick_out = 5

    show(p, notebook_handle=True)

In [13]:
# Connect to the database(s) to query; here the single configuration "gandalf_pymarl".
mongo_central = MongoCentral(conf_names=["gandalf_pymarl"])

Connected to gandalf_pymarl


In [14]:
# Value of config.label that selects which runs are loaded below.
label = "QMIX_Refactor_Test_2"

In [15]:
# Fetch the "config" and "info" fields of every run carrying this label.
exps = mongo_central.get_config_and_info_all(label)

Retreiving info from gandalf_pymarl


In [16]:
# Collect every key found in the "info" dict of any run, then list them
# without the companion "<key>_T" timestep series.
info_keys = set()
for exp in exps:
    info_keys.update(exp["info"].keys())
print("Keys in info:")
pprint.pprint([name for name in sorted(info_keys) if not name.endswith("_T")])

Keys in info:
['battles_draw',
 'battles_game',
 'battles_won',
 'ep_length',
 'episode',
 'epsilon',
 'grad_norm',
 'loss',
 'mean_q_value',
 'mean_target',
 'mean_test_return',
 'restarts',
 'std_test_return',
 'td_error',
 'test_mean_battles_draw',
 'test_mean_battles_game',
 'test_mean_battles_won',
 'test_mean_restarts',
 'test_mean_timeouts',
 'test_mean_win_rate',
 'test_return',
 'timeouts',
 'train_return',
 'win_rate']


In [17]:
# Config fields whose values identify a curve; runs sharing them are grouped together.
params = ["name"]

# Right end of the x-axis, and the last timestep a run must have reached
# for get_stats to include it.
t_max = 2 * 10 ** 6
t_needed = t_max - 50 * 1000

# At most this many runs are used per parameter setting.
limit = 50
x_key = "T env" # x-axis label
# Common grid onto which every run is interpolated.
x_interp = np.linspace(0, t_max, 2000)

# Falsy: the band is the standard deviation. Otherwise: the confidence level
# handed to mean_confidence_interval.
confidence_interval = False
# alpha passed to ewma() when smoothing each run.
smoother = 0.25

In [18]:
# Group the "info" dicts of all runs by their parameter string
# (e.g. "name-qmix_sc2_3m"); runs with an empty "info" are dropped.
keys = set()
data = {}
for exp in exps:
    exp_config = exp["config"]
    exp_info = exp["info"]
    if exp_info == {}:
        continue
    params_str = "__".join("{}-{}".format(param, exp_config[param]) for param in params)
    keys.add(params_str)
    data.setdefault(params_str, []).append(exp_info)

print("Keys: ", keys)

Keys:  {'name-vdn_sc2_3m', 'name-iql_sc2_5m', 'name-qmix_sc2_5m', 'name-iql_sc2_3m', 'name-qmix_sc2_3m', 'name-vdn_sc2_5m'}


In [19]:
# Display the parameter strings for which run data was collected.
data.keys()

dict_keys(['name-iql_sc2_3m', 'name-iql_sc2_5m', 'name-vdn_sc2_3m', 'name-vdn_sc2_5m', 'name-qmix_sc2_3m', 'name-qmix_sc2_5m'])

In [20]:
# When True, get_stats prints a line for every run that ends before t_needed.
print_not_long_enough = False

In [21]:
# "info" metrics to plot; the loop below produces one figure per entry.
keys_to_plot = ["test_mean_win_rate", "loss", "win_rate", "ep_length", "train_return", "mean_test_return"]

In [22]:
# For each metric: aggregate it over runs, then draw the mean curves, their
# bands and (indivs=True) every individual run.
for y_key in tqdm_notebook(keys_to_plot, desc="keys", leave=False):
    x_key = "T env"  # x-axis label
    # m, s, n: per-key means, band half-widths and run counts
    m, s, n = get_stats(y_key)
    # y_key is passed as y_label; plot() also uses it to look up the raw series
    plot(m, s, n, x_key, y_key, t_max, indivs=True)

HBox(children=(IntProgress(value=0, description='keys', max=6), HTML(value='')))

HBox(children=(IntProgress(value=0, description='params', max=6), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=19), HTML(value='')))

HBox(children=(IntProgress(value=0, description='params', max=6), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=19), HTML(value='')))

HBox(children=(IntProgress(value=0, description='params', max=6), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=19), HTML(value='')))

HBox(children=(IntProgress(value=0, description='params', max=6), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=19), HTML(value='')))

HBox(children=(IntProgress(value=0, description='params', max=6), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=19), HTML(value='')))

HBox(children=(IntProgress(value=0, description='params', max=6), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=20), HTML(value='')))

HBox(children=(IntProgress(value=0, description='runs', max=19), HTML(value='')))

