In [None]:
%matplotlib inline

In [None]:
import yaml
import sys
import traceback
import logging
import contextlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn

from tqdm import tqdm_notebook
from typing import List, Optional, Union
from torch import multiprocessing as mp
from multiprocessing.pool import Pool
from multiprocessing import Queue, Manager

In [None]:
%load_ext autoreload
%autoreload 2

from dqnroute import event_series, NetworkRunner, ConveyorsRunner,\
                     MultiEventSeries, DQNROUTE_LOGGER, TF_MODELS_DIR, LOG_DATA_DIR

In [None]:
# Notebook-level logger shares the dqnroute package logger name.
logger = logging.getLogger(DQNROUTE_LOGGER)

# Directory paths, relative to the notebook's working directory.
# NOTE(review): this rebinds the `LOG_DATA_DIR` imported from `dqnroute` above, but only in
# the notebook namespace — presumably code inside the package keeps its own value; confirm.
LOG_DATA_DIR = '../logs/runs'
TORCH_MODELS_DIR = '../torch_models'

# Allow up to 500 characters per line when printing numpy arrays.
np.set_printoptions(linewidth=500)

In [None]:
# Human-readable legend labels keyed by internal `router_type` value; `plot_data` shows the
# raw value for router types that are missing here.
_legend_txt_replace = dict(
    link_state='Shortest paths',
    simple_q='Q-routing',
    pred_q='PQ-routing',
    glob_dyn='Global-dynamic',
    dqn='DQN',
    dqn_oneout='DQN (1-out)',
    dqn_emb='DQN-LE',
)

def mk_job_id(router_type, seed):
    """Build the job identifier `'<router_type>-<seed>'` used to key runs and progress bars."""
    return '-'.join((str(router_type), str(seed)))

def un_job_id(job_id):
    """
    Inverse of `mk_job_id`: parse `'<router_type>-<seed>'` into `(router_type, seed)`.

    The seed is taken from after the *last* hyphen, so router types that contain hyphens
    round-trip. A doubled hyphen before the seed (e.g. `'dqn--1'`, as produced by
    `mk_job_id('dqn', -1)`) is read as a negative seed.

    :param job_id: identifier built by `mk_job_id`.
    :return: `(router_type, seed)` with `seed` as an int.
    :raises ValueError: if `job_id` has no hyphen or the seed part is not an integer.
    """
    router_type, s_seed = job_id.rsplit('-', 1)
    if router_type.endswith('-'):
        # '<type>--<n>' comes from a negative seed: the extra hyphen is the seed's sign.
        router_type, s_seed = router_type[:-1], '-' + s_seed
    return router_type, int(s_seed)

def add_avg(df: pd.DataFrame):
    """
    Add an `avg` column equal to `sum / count` to `df` in place and return the same frame.

    Rows with `count == 0` follow pandas division semantics (NaN for 0/0, ±inf otherwise).
    """
    df['avg'] = df['sum'].div(df['count'])
    return df

def plot_data(data, figsize=(15,5), xlim=None, ylim=None, target='avg',
              xlabel='Время симулятора', ylabel='Среднее время пакета в пути',
              font_size=14, title=None, save_path=None, save_dir='../img'):
    """
    Plot one column of launch data against simulator time, one line per router type.

    Legend entries are renamed through `_legend_txt_replace` (unknown router types are
    shown as-is).

    :param data: long-format frame with `time`, `router_type` and `target` columns
        (e.g. the output of `combine_launch_data`).
    :param figsize: figure size in inches.
    :param xlim: optional `(left, right)` x-axis limits.
    :param ylim: optional `(bottom, top)` y-axis limits.
    :param target: column plotted on the y axis.
    :param xlabel: x-axis label (default is Russian for "Simulator time").
    :param ylabel: y-axis label (default is Russian for "Average packet travel time").
    :param font_size: font size of the legend and axis labels.
    :param title: optional axes title.
    :param save_path: if given, the figure is also saved to `save_dir/save_path`.
    :param save_dir: directory `save_path` is relative to.
    """
    import os

    fig, ax = plt.subplots(figsize=figsize)
    sns.lineplot(x='time', y=target, hue='router_type', data=data,
                 err_kws={'alpha': 0.1}, ax=ax)

    handles, labels = ax.get_legend_handles_labels()
    # seaborn < 0.11 inserted the hue column name as a fake first legend entry; newer
    # versions use it as the legend title instead. Strip it only when it is present,
    # otherwise the first router type would disappear from the legend.
    if labels and labels[0] == 'router_type':
        handles, labels = handles[1:], labels[1:]
    new_labels = [_legend_txt_replace.get(label, label) for label in labels]
    ax.legend(handles=handles, labels=new_labels, fontsize=font_size)

    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)
    if title is not None:
        ax.set_title(title)

    ax.set_xlabel(xlabel, fontsize=font_size)
    ax.set_ylabel(ylabel, fontsize=font_size)

    # Save before showing: `plt.show()` may close the figure (the inline backend does).
    if save_path is not None:
        fig.savefig(os.path.join(save_dir, save_path), bbox_inches='tight')

    # `plt.show()` does not take a figure argument; it displays all open figures.
    plt.show()

def split_data(dct):
    """
    Transpose a mapping of sequences into a tuple of mappings.

    Element `i` of every value in `dct` lands in the `i`-th output dict under the same
    key. Keys whose value has no element `i` are absent from that dict. Within each dict,
    keys keep the iteration order of `dct`.

    Example: `{'a': (1, 2), 'b': (3,)}` -> `({'a': 1, 'b': 3}, {'a': 2})`.
    """
    columns = [(key, tuple(vals)) for key, vals in dct.items()]
    depth = max((len(vals) for _, vals in columns), default=0)
    return tuple(
        {key: vals[pos] for key, vals in columns if pos < len(vals)}
        for pos in range(depth)
    )
    
def add_cols(df, **cols):
    """
    Assign each keyword argument to a column of `df` via `.loc` (in place; returns None).

    Scalars are broadcast to every row; columns that do not exist yet are created.
    """
    for name in cols:
        df.loc[:, name] = cols[name]
    
def combine_launch_data(launch_data):
    """
    Stack per-job frames into one long frame suitable for `plot_data`.

    :param launch_data: `{job_id: DataFrame}` where each `job_id` was built by `mk_job_id`.
    :return: row-wise concatenation of copies of the frames, each tagged with
        `router_type` and `seed` columns parsed from its job id; input frames are not modified.
    """
    def _tagged_copy(job_id, data):
        # Copy first so the caller's frames are left untouched.
        router_type, seed = un_job_id(job_id)
        frame = data.copy()
        add_cols(frame, router_type=router_type, seed=seed)
        return frame

    return pd.concat([_tagged_copy(job_id, data) for job_id, data in launch_data.items()], axis=0)

In [None]:
class DummyTqdmFile(object):
    """
    File-like wrapper that routes writes through `tqdm.write`, so text printed while
    progress bars are active does not garble them.
    """
    # Underlying stream (e.g. the original `sys.stdout`) that tqdm ultimately writes to.
    file = None

    def __init__(self, file):
        self.file = file

    def write(self, x):
        """Forward `x` to `tqdm.write`, skipping whitespace-only chunks."""
        # print() issues a separate write('\n') after the text, and tqdm.write appends its
        # own line ending, so whitespace-only chunks are dropped to avoid blank lines.
        if len(x.rstrip()) > 0:
            # Only `tqdm_notebook` is imported at module level; the `tqdm` class must be
            # imported explicitly, otherwise this line raises NameError.
            from tqdm import tqdm
            tqdm.write(x, file=self.file)

    def flush(self):
        """Flush the underlying stream if it has a `flush` method; otherwise do nothing."""
        return getattr(self.file, "flush", lambda: None)()

@contextlib.contextmanager
def std_out_err_redirect_tqdm():
    """
    Temporarily replace `sys.stdout` and `sys.stderr` with `DummyTqdmFile` wrappers so
    printed text goes through `tqdm.write`.

    Yields the original `sys.stdout` (intended to be passed as `file=` to the tqdm bars so
    the bars keep writing to the real stream). The original streams are restored on exit,
    including when the body raises; exceptions propagate unchanged.
    """
    orig_out_err = sys.stdout, sys.stderr
    try:
        sys.stdout, sys.stderr = map(DummyTqdmFile, orig_out_err)
        yield orig_out_err[0]
    finally:
        # No catch-and-reraise needed: `finally` restores the streams on success, on any
        # exception and on KeyboardInterrupt, without adding a spurious traceback frame.
        sys.stdout, sys.stderr = orig_out_err

In [None]:
class DummyProgressbarQueue:
    """
    In-process stand-in for the progress queue used by `run_threaded`. Instead of passing
    messages to another process, it advances a single progress bar directly.
    """

    def __init__(self, bar):
        # Anything with an `update(n)` method, e.g. a tqdm bar.
        self.bar = bar

    def put(self, val):
        """
        Accept a `(job_id, delta)` message and advance the bar by `delta`.

        A `None` delta (the completion signal that `run_threaded` uses) is ignored.
        """
        _job_id, delta = val
        if delta is None:
            return
        self.bar.update(delta)

In [None]:
def run_network_scenario_file(file: str, router_type: str, random_seed: Optional[int] = None,
                              progress_step: Optional[int] = None, progress_queue: Optional[Queue] = None,
                              ignore_saved=False, series_period: int = 500,
                              series_funcs: Optional[List[str]] = None):
    """
    Run one network-routing scenario described by a YAML launch file and return its
    event statistics.

    Thin wrapper around `NetworkRunner`. It is a plain top-level function so it can be
    submitted to a `multiprocessing` pool (see `run_threaded`) or called directly
    (see `run_single`).

    :param file: path to the YAML launch file; its parsed contents become `run_params`.
    :param router_type: router type to simulate (e.g. `'dqn'`, `'link_state'`).
    :param random_seed: forwarded to `NetworkRunner.run`.
    :param progress_step: forwarded to `NetworkRunner.run`.
    :param progress_queue: object with a `put` method, forwarded to `NetworkRunner.run`
        for progress messages.
    :param ignore_saved: forwarded to `NetworkRunner.run`.
    :param series_period: period passed to `event_series`.
    :param series_funcs: aggregate names passed to `event_series`;
        defaults to `['count', 'sum', 'min', 'max']`.
    :return: the frame from the run's series `getSeries()`, with an added `avg` column
        (`sum / count`).
    """
    if series_funcs is None:
        # Fresh list per call instead of a mutable default shared between calls.
        series_funcs = ['count', 'sum', 'min', 'max']

    with open(file) as f:
        run_params = yaml.safe_load(f)

    series = event_series(series_period, series_funcs)
    runner = NetworkRunner(run_params=run_params, router_type=router_type, data_series=series)

    result_series = runner.run(random_seed=random_seed, ignore_saved=ignore_saved,
                               progress_step=progress_step, progress_queue=progress_queue)

    return add_avg(result_series.getSeries())

def run_conveyor_scenario_file(file: str, router_type: str, random_seed: Optional[int] = None,
                               progress_step: Optional[int] = None, progress_queue: Optional[Queue] = None,
                               ignore_saved=False, series_period: int = 500,
                               series_funcs: Optional[List[str]] = None):
    """
    Run one conveyor scenario described by a YAML launch file and return its time and
    energy statistics.

    Thin wrapper around `ConveyorsRunner`. It is a plain top-level function so it can be
    submitted to a `multiprocessing` pool (see `run_threaded`) or called directly
    (see `run_single`).

    :param file: path to the YAML launch file; its parsed contents become `run_params`.
    :param router_type: router type to simulate (e.g. `'dqn_emb'`, `'link_state'`).
    :param random_seed: forwarded to `ConveyorsRunner.run`.
    :param progress_step: forwarded to `ConveyorsRunner.run`.
    :param progress_queue: object with a `put` method, forwarded to `ConveyorsRunner.run`
        for progress messages.
    :param ignore_saved: forwarded to `ConveyorsRunner.run`.
    :param series_period: period passed to both `event_series` instances.
    :param series_funcs: aggregate names passed to both `event_series` instances;
        defaults to `['count', 'sum', 'min', 'max']`.
    :return: `(time_df, energy_df)` — the frames of the `time` and `energy` series, each
        with an added `avg` column (`sum / count`).
    """
    if series_funcs is None:
        # Fresh list per call instead of a mutable default shared between calls.
        series_funcs = ['count', 'sum', 'min', 'max']

    with open(file) as f:
        run_params = yaml.safe_load(f)

    # One event series per metric, bundled so the runner receives both.
    time_series = event_series(series_period, series_funcs)
    energy_series = event_series(series_period, series_funcs)
    series = MultiEventSeries(time=time_series, energy=energy_series)

    runner = ConveyorsRunner(run_params=run_params, router_type=router_type, data_series=series)
    # The return value of `run` was never used here; the statistics are read back from
    # `time_series` / `energy_series` (presumably filled in place by the runner).
    runner.run(random_seed=random_seed, ignore_saved=ignore_saved,
               progress_step=progress_step, progress_queue=progress_queue)

    return add_avg(time_series.getSeries()), add_avg(energy_series.getSeries())

def run_single(func, router_type: str, random_seed: int, **kwargs):
    """
    Run one scenario in the current process with a notebook progress bar, then tag the
    result frame(s) with `router_type` and `seed` columns.

    :param func: scenario runner, e.g. `run_network_scenario_file`.
    :param router_type: forwarded to `func` and stored in the `router_type` column.
    :param random_seed: forwarded to `func` and stored in the `seed` column.
    :param kwargs: extra keyword arguments forwarded to `func`.
    :return: whatever `func` returned (a frame or a tuple of frames), tagged in place.
    """
    with tqdm_notebook(desc=mk_job_id(router_type, random_seed)) as bar:
        results = func(router_type=router_type, random_seed=random_seed,
                       progress_queue=DummyProgressbarQueue(bar), **kwargs)

    # Conveyor runners return a tuple of frames; network runners return a single frame.
    frames = results if type(results) is tuple else (results,)
    for frame in frames:
        add_cols(frame, router_type=router_type, seed=random_seed)
    return results

def exc_print(e):
    """
    Print the full formatted traceback of exception `e` to stdout.

    Used as the `error_callback` of `Pool.apply_async` in `run_threaded`, so a failing
    worker job is reported in the notebook output.
    """
    # Positional arguments work on every Python 3 version; the `etype=` keyword was
    # removed from `traceback.format_exception` in Python 3.10 (TypeError there).
    print(''.join(traceback.format_exception(type(e), e, e.__traceback__)))

def run_threaded(func, router_types: List[str], random_seeds: List[int],
                 ignore_saved: Union[bool, List[str]] = [], *args, **kwargs):
    """
    Run `func` for every (router type, seed) pair in a process pool and show one notebook
    progress bar per job.

    Despite the name, jobs run in worker *processes* (`multiprocessing.pool.Pool`), so
    `func` and all its arguments must be picklable. Workers report progress through a
    `Manager` queue as `(job_id, delta)` messages; `delta is None` marks a finished job.

    :param func: scenario runner such as `run_network_scenario_file`. It receives
        `router_type`, `random_seed`, `ignore_saved` and `progress_queue` keyword
        arguments in addition to `args` and `kwargs`.
    :param router_types: router types to launch.
    :param random_seeds: seeds to launch for every router type.
    :param ignore_saved: `True` to pass `ignore_saved=True` to every job, `False` to none,
        or a collection of router types that should get `ignore_saved=True`.
    :param args: extra positional arguments forwarded to `func`.
    :param kwargs: extra keyword arguments forwarded to `func`.
    :return: `{job_id: result}`; if the results are tuples, a tuple of such dicts
        (one per tuple position, as produced by `split_data`).
    :raises Exception: the error of a failed job, re-raised by `AsyncResult.get`
        (its traceback is also printed when the job fails).
    """
    from queue import Empty

    # Seconds to wait for a progress message before checking for failed jobs.
    poll_timeout = 1.0

    if isinstance(ignore_saved, bool):
        ignored = set(router_types) if ignore_saved else set()
    elif isinstance(ignore_saved, str):
        # A single router type; `in` on a str would do substring matching instead.
        ignored = {ignore_saved}
    else:
        ignored = set(ignore_saved or ())

    def _report_failure(exc):
        # An exception escaping an error_callback breaks the pool's result handling
        # (the failed job would never become ready), so reporting must never raise.
        try:
            exc_print(exc)
        except Exception:
            print('Job failed: {!r}'.format(exc))

    jobs = {}
    bars = {}
    # The context managers terminate the pool and shut down the manager process on exit
    # instead of leaking both on every call.
    with Manager() as manager, Pool() as pool:
        progress_queue = manager.Queue()
        try:
            for router_type in router_types:
                for seed in random_seeds:
                    job_id = mk_job_id(router_type, seed)
                    job_args = dict(kwargs, router_type=router_type, random_seed=seed,
                                    ignore_saved=router_type in ignored,
                                    progress_queue=progress_queue)
                    jobs[job_id] = pool.apply_async(func, args=args, kwds=job_args,
                                                    error_callback=_report_failure)
                    bars[job_id] = tqdm_notebook(desc=job_id)

            while bars:
                try:
                    job_id, delta = progress_queue.get(timeout=poll_timeout)
                except Empty:
                    # A job that raised never sends its final `None`; close its bar
                    # instead of blocking on the queue forever.
                    failed = [j for j in bars if jobs[j].ready() and not jobs[j].successful()]
                    for j in failed:
                        bars.pop(j).close()
                    continue

                if job_id in jobs and job_id not in bars:
                    continue  # late message from a job whose bar is already closed
                if delta is None:
                    bars.pop(job_id).close()
                else:
                    bars[job_id].update(delta)
        finally:
            # Close any bars left open by an interrupt or error.
            for bar in bars.values():
                bar.close()

        # Collect results while the pool is still alive; `get` re-raises a job's error.
        results = {job_id: job.get() for (job_id, job) in jobs.items()}

    if results and type(next(iter(results.values()))) is tuple:
        return split_data(results)
    return results

In [None]:
# Comparison built from previously saved CSV logs (no simulation is run here): shortest-path
# routing vs. DQN ("no_pretrain" per the file name), both labelled as seed 42.
results_no_pretrain = pd.read_csv('../logs/results5_no_pretrain_dqn.csv', index_col=0)
results5_ls = pd.read_csv('../logs/results5_link_state.csv', index_col=0)
res5_comb = combine_launch_data({'link_state-42': results5_ls, 'dqn-42': results_no_pretrain})
# ylabel (Russian): "Average travel time"
plot_data(res5_comb, ylim=(0, 700), xlim=(0, 40000), ylabel='Среднее время в пути')

In [None]:
# Network scenario `launch6.yaml`: four router types x seeds 42, 43, 44 in a process pool.
# Only the `link_state` jobs get `ignore_saved=True`.
launch6_data_mult = run_threaded(run_network_scenario_file, random_seeds=[42, 43, 44],
                                 file='../launches/launch6.yaml', router_types=['link_state', 'simple_q', 'dqn', 'dqn_emb'],
                                 ignore_saved=['link_state'], progress_step=500)

In [None]:
# Stack the per-job frames into one long frame tagged with `router_type` and `seed`.
launch6_data_comb = combine_launch_data(launch6_data_mult)

In [None]:
plot_data(launch6_data_comb, figsize=(10,6), ylim=(0, 150))

In [None]:
# Network scenario `launch8.yaml`: four router types x seeds 42, 43, 44;
# `ignore_saved=True` for the `link_state` and `simple_q` jobs.
launch8_data = run_threaded(run_network_scenario_file, file='../launches/launch8.yaml',
                            router_types=['link_state', 'simple_q', 'dqn', 'dqn_emb'], progress_step=500,
                            ignore_saved=['link_state', 'simple_q'], random_seeds=[42, 43, 44])

launch8_data_comb = combine_launch_data(launch8_data)

In [None]:
plot_data(launch8_data_comb, figsize=(15, 10), ylim=(35, 140), xlim=(-300, 35000))

In [None]:
# Network scenario `launch_long_calm.yaml` (a long, "calm" run per the file name);
# no job gets `ignore_saved=True`.
launch_calm_data = run_threaded(run_network_scenario_file, file='../launches/launch_long_calm.yaml',
                                router_types=['link_state', 'dqn', 'dqn_emb'], progress_step=500,
                                ignore_saved=[], random_seeds=[42, 43, 44])

In [None]:
plot_data(combine_launch_data(launch_calm_data), xlim=(0, 40000), ylim=(0, 300))

In [None]:
# Network scenario `launch_dqn_transfer.yaml` (learning transfer; "small" per the saved
# figure name). No job gets `ignore_saved=True`.
launch_rand_data = run_threaded(run_network_scenario_file, file='../launches/launch_dqn_transfer.yaml',
                                router_types=['link_state', 'dqn', 'dqn_emb'], progress_step=500,
                                ignore_saved=[], random_seeds=[42, 43, 44])

In [None]:
# Also saved as ../img/learning-transfer-small.pdf.
plot_data(combine_launch_data(launch_rand_data), figsize=(10, 6),
          ylim=(20, 150), xlim=(-500, 20000), save_path='learning-transfer-small.pdf')

In [None]:
# Network scenario `launch_rand_big.yaml`: link_state, simple_q and dqn_emb x seeds 42, 43, 44.
launch_rand_data_big = run_threaded(run_network_scenario_file, file='../launches/launch_rand_big.yaml',
                                    router_types=['link_state', 'simple_q', 'dqn_emb'], progress_step=500,
                                    ignore_saved=[], random_seeds=[42, 43, 44])

In [None]:
# NOTE(review): despite its name, `launch_rand_data_comb` holds the *big*-scenario results
# (`launch_rand_data_big`), not `launch_rand_data` from the transfer experiment above.
launch_rand_data_comb = combine_launch_data(launch_rand_data_big)
# Also saved as ../img/learning-transfer-big-low-load.pdf.
plot_data(launch_rand_data_comb, figsize=(6, 6), ylim=(30, 300), xlim=(-500, 50000),
          save_path='learning-transfer-big-low-load.pdf')

In [None]:
# Single in-process debug run of the conveyor energy scenario (simple_q, seed 42,
# `ignore_saved=True`).
debug_data = run_single(run_conveyor_scenario_file, file='../launches/conveyor_energy_test.yaml',
                        router_type='simple_q', ignore_saved=True, random_seed=42, progress_step=20)

In [None]:
# Conveyor runs return a (time, energy) pair of frames.
time_data, eng_data = debug_data
plot_data(time_data)
# Energy is plotted from the `sum` column instead of `avg`; ylabel (Russian): "Energy"
plot_data(eng_data, target='sum', ylabel='Энергия')

In [None]:
# Conveyor energy scenario: link_state, simple_q and dqn_emb x seeds 42, 43, 44.
conveyor_data_full = run_threaded(run_conveyor_scenario_file, file='../launches/conveyor_energy_test.yaml',
                                  router_types=['link_state', 'simple_q', 'dqn_emb'], progress_step=500,
                                  ignore_saved=[], random_seeds=[42, 43, 44])

In [None]:
# `run_threaded` splits tuple results into one {job_id: frame} dict per metric.
conveyor_data_time, conveyor_data_nrg = conveyor_data_full

In [None]:
conveyor_data_time_comb = combine_launch_data(conveyor_data_time)
conveyor_data_nrg_comb = combine_launch_data(conveyor_data_nrg)

plot_data(conveyor_data_time_comb, figsize=(15, 10), font_size=18)
# ylabel (Russian): "Total energy consumption"
plot_data(conveyor_data_nrg_comb, figsize=(15, 10), font_size=18,
          target='sum', ylabel='Суммарные энергозатраты')

In [None]:
# Second conveyor energy scenario: four router types x seeds 42, 43, 44.
conveyor2_data_full = run_threaded(run_conveyor_scenario_file, file='../launches/conveyor_energy_test_2.yaml',
                                   router_types=['link_state', 'simple_q', 'dqn', 'dqn_emb'], progress_step=500,
                                   ignore_saved=[], random_seeds=[42, 43, 44])

In [None]:
# Split into (time, energy) job dicts, combine each, and save both figures under ../img/.
conveyor2_data_time, conveyor2_data_nrg = conveyor2_data_full
conveyor2_data_time_comb = combine_launch_data(conveyor2_data_time)
conveyor2_data_nrg_comb = combine_launch_data(conveyor2_data_nrg)

plot_data(conveyor2_data_time_comb, figsize=(10, 10), ylim=(40, 100), font_size=18,
          save_path='conveyors2-late-time.pdf')
# ylabel (Russian): "Total energy consumption"
plot_data(conveyor2_data_nrg_comb, figsize=(10, 10), font_size=18, ylim=(1000, 2700),
          target='sum', ylabel='Суммарные энергозатраты', save_path='conveyors2-late-energy.pdf')