In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import pymongo
import os, sys, yaml
import numpy as np
from scipy import stats

# Path to the pymarl-dev `src` directory. Set the PYMARL_SRC_DIR environment
# variable to point elsewhere instead of editing this user-specific absolute path.
root_dir = os.environ.get("PYMARL_SRC_DIR", "/home/marilena/PycharmProjects/3rdyearproject/pymarl-dev/src")

In [2]:
def get_default_config(config_path=None):
    """Load and parse the pymarl default configuration YAML file.

    Parameters
    ----------
    config_path : str, optional
        Path of the YAML file to read. Defaults to
        ``<root_dir>/config/default.yaml``.

    Returns
    -------
    The object produced by ``yaml.safe_load`` (the caller indexes it as a dict).

    Raises
    ------
    yaml.YAMLError
        If the file is not valid YAML. The error is printed and then re-raised
        so the failure is reported here, rather than later as a TypeError when
        the caller indexes a ``None`` result.
    OSError
        If the file cannot be opened.
    """
    if config_path is None:
        config_path = os.path.join(root_dir, "config", "default.yaml")
    with open(config_path, 'r') as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
            raise

In [3]:
def get_mongo_db_client(db_url, db_name, server_selection_timeout_ms=10000):
    """Open an SSL MongoDB connection and select a database on it.

    Parameters
    ----------
    db_url : str
        MongoDB connection URL.
    db_name : str
        Name of the database to return from the client.
    server_selection_timeout_ms : int, optional
        Passed to pymongo as ``serverSelectionTimeoutMS`` (maximum server
        selection delay in milliseconds). Defaults to 10000, i.e. 10 s.

    Returns
    -------
    tuple
        ``(client, client[db_name])``.
    """
    client = pymongo.MongoClient(
        db_url,
        ssl=True,
        serverSelectionTimeoutMS=server_selection_timeout_ms,
    )

    # Echo the client so the connected host and options appear in the cell output.
    print(client)

    return client, client[db_name]

In [4]:
config = get_default_config()

# Connect to the results database configured in the pymarl default config.
mongo_client, mongo_db = get_mongo_db_client(db_url=config["db_url"], db_name=config["db_name"])

MongoClient(host=['gandalf.cs.ox.ac.uk:27017'], document_class=dict, tz_aware=False, connect=True, ssl=True, serverselectiontimeoutms=10000)


In [None]:
# Experiment-name prefixes accepted by plot_means; each needs a colour in exp_color_map.
possible_experiments = ["coma_", "coma_vdn_", "centralV_", "coma_qmix_", "coma_nb_",
                        "coma_qmix_ns_", "coma_vdn_nb_", "coma_vdn_vb_", "coma_fql_",
                        "coma_vdn_vmb_", "centralV_vdn_"]
lambdas = ["03", "05", "07", "09"]  # previously also: "01", "08"

# Total number of steps spanned by the x-axis of each game's plot.
game_step_counts = {
    "sh": 1500000,
    "sc_2s3z": 2000000,
    "sc_3m": 2000000,
    "sc_2m_vs_1z": 2000000,
    "sc_25m": 10000000,
    "sc_8m": 2000000,
    "sc_smb": 2000000,
    "sc_bvb": 2000000,
    "sc_3s5z": 2000000
}

# Figure titles, keyed like game_step_counts.
game_titles = {
    "sh": "Stag Hunt",
    "sc_2s3z": "Starcraft, map 2s_3z",
    "sc_3m": "Starcraft, map 3m",
    "sc_2m_vs_1z": "Starcraft, map 2m_vs_1z",
    "sc_25m": "Starcraft, map 25m",
    "sc_8m": "Starcraft, map 8m",
    "sc_smb": "Starcraft, so many banelings",
    "sc_bvb": "Starcraft, bane vs bane",
    "sc_3s5z": "Starcraft, map 3s5z"
}

# Y-axis label for each plottable metric.
axes_labels = {
    "test_return_mean": "Test Return Mean",
    "test_battle_won_mean": "Test Battle Won Mean %"
}

# Number of evenly spaced x-axis points between 0 and the game's step count.
DATAPOINTS_COUNT = 5001

all_colors = ["red", "blue", "green", "orange", "purple", "brown", "pink", "yellow"]
# One distinct colour per experiment so curves drawn on the same axes can be told apart.
exp_color_map = {
    "coma_vdn_": "blue",
    "coma_": "red",
    "centralV_": "green",
    "coma_qmix_": "orange",
    "coma_nb_": "brown",
    "coma_qmix_ns_": "pink",
    "coma_vdn_nb_": "olive",
    "coma_vdn_vb_": "purple",
    "coma_fql_": "magenta",
    "coma_vdn_vmb_": "teal",
    "centralV_vdn_": "black"
}

# Keep the lookup tables in sync with each other.
assert set(exp_color_map) == set(possible_experiments)
assert len(set(exp_color_map.values())) == len(exp_color_map), "experiments share a colour"
assert set(game_titles) == set(game_step_counts)

In [5]:
def get_exp_ids_from_db(config_db):
    """Fetch the ``config`` field of every run matching a MongoDB filter.

    Parameters
    ----------
    config_db : dict
        Query filter applied to the ``runs`` collection of ``mongo_db``.

    Returns
    -------
    list
        The ``config`` value of each matching run, in cursor order. Despite
        the function name, these are run config dicts, not run ids.
    """
    matching_configs = [run["config"] for run in mongo_db["runs"].find(config_db)]

    # NOTE(review): defined outside this view. It takes no arguments and its
    # return value is discarded, so it cannot filter `matching_configs` — confirm intent.
    eliminate_known_bad_exp_names()
    print(matching_configs)
    return matching_configs

In [9]:
# get_config_for_db is defined in a later cell, so calling it here raises NameError
# under Restart & Run All; use the same base filter it currently returns.
all_marilena_runs_query = {"config.name": {'$regex': '^exp_marilena'}}
get_exp_ids_from_db(all_marilena_runs_query);  # `;` avoids re-displaying what the call already prints

[{'critic_baseline_fn': 'coma', 'critic_train_mode': 'seq', 'save_model': False, 'lr': 0.0005, 'mixer_lr': 0.0005, 'save_replay': False, 'learner': 'coma_learner', 'optim_eps': 1e-05, 'optim_alpha': 0.99, 'env_args': {'p_hare_rest': 0.0, 'truncate_episodes': True, 'n_agents': 4, 'world_shape': [6, 6], 'print_caught_prey': False, 'p_stags_rest': 0.0, 'intersection_unknown': False, 'observe_one_hot': False, 'agent_obs': [2, 2], 'state_last_action': False, 'reward_stag': 10, 'observe_ids': False, 'intersection_global_view': False, 'reward_collision': 0.0, 'capture_terminal': True, 'reward_hare': 5, 'n_hare': 1, 'toroidal': False, 'observe_walls': False, 'mountain_slope': 0.0, 'capture_conditions': [0, 1], 'episode_limit': 100, 'reward_time': -0.1, 'n_stags': 1}, 'action_selector': 'multinomial', 'mixing_embed_dim': 32, 't_max': 1510000, 'load_step': 0, 'use_cuda': True, 'batch_size_run': 8, 'critic_lr': 0.0005, 'critic_train_reps': 1, 'critic_q_fn': 'coma', 'epsilon_finish': 0.01, 'epsilo

[{'critic_baseline_fn': 'coma',
  'critic_train_mode': 'seq',
  'save_model': False,
  'lr': 0.0005,
  'mixer_lr': 0.0005,
  'save_replay': False,
  'learner': 'coma_learner',
  'optim_eps': 1e-05,
  'optim_alpha': 0.99,
  'env_args': {'p_hare_rest': 0.0,
   'truncate_episodes': True,
   'n_agents': 4,
   'world_shape': [6, 6],
   'print_caught_prey': False,
   'p_stags_rest': 0.0,
   'intersection_unknown': False,
   'observe_one_hot': False,
   'agent_obs': [2, 2],
   'state_last_action': False,
   'reward_stag': 10,
   'observe_ids': False,
   'intersection_global_view': False,
   'reward_collision': 0.0,
   'capture_terminal': True,
   'reward_hare': 5,
   'n_hare': 1,
   'toroidal': False,
   'observe_walls': False,
   'mountain_slope': 0.0,
   'capture_conditions': [0, 1],
   'episode_limit': 100,
   'reward_time': -0.1,
   'n_stags': 1},
  'action_selector': 'multinomial',
  'mixing_embed_dim': 32,
  't_max': 1510000,
  'load_step': 0,
  'use_cuda': True,
  'batch_size_run': 8,


In [8]:
def get_config_for_db(game, exp, special_config):
    """Build the MongoDB filter used to select runs from the ``runs`` collection.

    The filter matches runs whose ``config.name`` starts with ``exp_marilena``.

    FIXME(review): ``game``, ``exp`` and ``special_config`` are accepted but
    not used, so every (game, experiment) pair receives the same filter —
    confirm whether per-experiment filtering is intended.
    """
    name_filter = {'$regex': '^exp_marilena'}
    return {"config.name": name_filter}

In [None]:
def plot_means(game, type_of_data, experiments, percentile=False, conf_inter=False,
               figure_file_name=None, special_config=None, labels=None):
    """Plot the mean curve of each experiment for one game on a single axes.

    Parameters
    ----------
    game : str
        Key into ``game_step_counts`` and ``game_titles`` (e.g. ``"sh"``).
    type_of_data : str
        Metric to plot; key into ``axes_labels``.
    experiments : list of str
        Experiment prefixes; each must be in ``possible_experiments``.
    percentile : bool
        If True, shade the band returned by ``get_percentile_threshold_exp_data``.
    conf_inter : bool
        If True, shade the band from ``get_confidence_intervals`` built on the SEMs.
    figure_file_name : str, optional
        If given, the finished figure is saved to this path.
    special_config : list, optional
        One entry per experiment, forwarded to ``get_config_for_db``.
        Defaults to ``None`` for every experiment.
    labels : list of str, optional
        One legend label per experiment; ``{}`` in a label is replaced by the
        number of runs found. Defaults to ``"<exp> ({})"`` for each experiment.
    """
    for exp in experiments:
        assert exp in possible_experiments, exp

    # `labels` follows defaulted parameters, so it must have a default itself.
    if labels is None:
        labels = [exp + " ({})" for exp in experiments]
    assert len(labels) == len(experiments)

    # special_config[idx] is read for every experiment below; expand the
    # default into one None per experiment instead of indexing None.
    if special_config is None:
        special_config = [None] * len(experiments)
    assert len(experiments) == len(special_config)

    step_count = game_step_counts[game]
    timesteps = np.linspace(0, step_count, DATAPOINTS_COUNT)

    fig, axes = plt.subplots(1, 1, figsize=(20, 10))
    fig.suptitle(game_titles[game], fontsize=20)

    for idx, exp in enumerate(experiments):
        color = exp_color_map[exp]
        config_db = get_config_for_db(game, exp, special_config[idx])
        exp_ids = get_exp_ids_from_db(config_db)

        if len(exp_ids) == 0:
            print("No experiments with " + exp + game)
            continue

        label = labels[idx].format(len(exp_ids))
        print("Plotting experiments with " + exp + game)

        # NOTE(review): the helper's second return value is unused; every series
        # is drawn against the fixed `timesteps` grid, which assumes it has
        # DATAPOINTS_COUNT points — confirm.
        means, _ = get_means_exp_data(game, exp_ids, type_of_data)

        if conf_inter:
            sems, _ = get_sems_exp_data(game, exp_ids, type_of_data)
            ci_low, ci_up = get_confidence_intervals(means, sems)
            axes.fill_between(timesteps, ci_low, ci_up, facecolor=color, alpha=0.3)

        if percentile:
            p25, p75 = get_percentile_threshold_exp_data(game, exp_ids, type_of_data)
            axes.fill_between(timesteps, p25, p75, facecolor=color, alpha=0.3)

        axes.plot(timesteps, means, color=color, label=label)

    # Labels, legend and saving describe the whole figure, so they are done
    # once after all series are drawn rather than once per experiment.
    axes.set_xlabel("Timesteps", fontsize=14)
    axes.set_ylabel(axes_labels[type_of_data], fontsize=14)
    handles, _ = axes.get_legend_handles_labels()
    if handles:
        axes.legend(loc="upper left")
    if figure_file_name is not None:
        fig.savefig(figure_file_name)