In [None]:
%matplotlib inline

In [None]:
import yaml
import sys
import traceback
import logging
import contextlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn

from tqdm import tqdm_notebook
from typing import List, Optional, Union
from torch import multiprocessing as mp
from multiprocessing.pool import Pool
from multiprocessing import Queue, Manager

In [None]:
%load_ext autoreload
%autoreload 2

from dqnroute import *

In [None]:
np.set_printoptions(linewidth=500)  # print numpy arrays with lines up to 500 characters wide

# Relative locations (from this notebook's directory) of stored torch models and run logs.
TORCH_MODELS_DIR = '../torch_models'
LOG_DATA_DIR = '../logs/runs'

logger = logging.getLogger(DQNROUTE_LOGGER)

In [None]:
# Scratch check of pandas' vectorised prefix matching on column labels.
df = pd.DataFrame(columns=['kek', 'mda', 'time_foo'])
time_cols_mask = df.columns.str.startswith('time')
time_cols_mask

In [None]:
# Legend display names for router types, per experiment context.
_legend_txt_replace = {
    'networks': {
        'link_state': 'Shortest paths',
        'simple_q': 'Q-routing',
        'pred_q': 'PQ-routing',
        'glob_dyn': 'Global-dynamic',
        'dqn': 'DQN',
        'dqn_oneout': 'DQN (1-out)',
        'dqn_emb': 'DQN-LE',
        'centralized_simple': 'Centralized control',
    },
}
# Conveyor plots use the same names except for the two baseline routers.
_legend_txt_replace['conveyors'] = dict(
    _legend_txt_replace['networks'],
    link_state='Vyatkin-Black',
    centralized_simple='BSR',
)

# Column plotted on the y axis for each metric.
_targets = {
    'time': 'avg',
    'energy': 'sum',
    'collisions': 'sum',
}

# Default y-axis label for each metric.
_ylabels = dict(
    time='Среднее время в пути',
    energy='Суммарные энергозатраты',
    collisions='Столкновения сумок',
)

def print_sums(df, context='conveyors'):
    """
    Print the total of the ``count`` column for every router type in ``df``.

    Router types are printed under their display names from
    ``_legend_txt_replace[context]`` (unknown types are printed as-is). Types
    are listed in sorted order, so repeated runs print the same listing.

    :param df: frame with ``router_type`` and ``count`` columns.
    :param context: legend-name table to use. Defaults to ``'conveyors'``
        because ``plot_data`` calls this from its multi-metric branch, which
        uses the conveyor legend names.
    """
    # The table is keyed by context first; looking router types up at the top
    # level never found a match, so names were never translated.
    names = _legend_txt_replace.get(context, {})
    totals = df.groupby('router_type')['count'].sum()  # groupby sorts keys
    for router_type, total in totals.items():
        print('  {}: {}'.format(names.get(router_type, router_type), total))

def plot_data(data, meaning='time', figsize=(15,5), xlim=None, ylim=None,
              xlabel='Время симулятора', ylabel=None,
              font_size=14, title=None, save_path=None,
              draw_collisions=False, context='networks', **kwargs):
    """
    Plot one metric against simulation time, with one line per router type.

    Single-metric mode (``data`` has a ``time`` column): draws the column
    ``_targets[meaning]`` against ``time`` with ``router_type`` as hue.
    Legend entries are renamed through ``_legend_txt_replace[context]``.

    Multi-metric mode (no ``time`` column): ``data`` is split with
    ``split_dataframe``, and each metric tag is plotted by a recursive call.
    Per-metric settings can be passed as ``<tag>_xlim``, ``<tag>_ylim`` and
    ``<tag>_save_path`` keyword arguments. A metric without an override uses
    the shared ``xlim`` / ``ylim`` / ``save_path``. The ``collisions`` metric
    is printed as per-router totals instead of plotted, unless
    ``draw_collisions`` is true. Any other extra keyword arguments are ignored.

    :param meaning: metric key into ``_targets`` and ``_ylabels``.
    :param ylabel: y-axis label; defaults to ``_ylabels[meaning]``.
    :param context: legend-name table (``'networks'`` or ``'conveyors'``).
    :param save_path: file name, relative to ``../img/``, to save the figure to.
    """
    if 'time' not in data.columns:
        for tag, metric_df in split_dataframe(data, preserved_cols=['router_type', 'seed']):
            if tag == 'collisions' and not draw_collisions:
                print('Количество столкновений:')
                print_sums(metric_df)
                continue

            # Resolve per-metric overrides directly in the call instead of
            # rebinding xlim/ylim/save_path. Rebinding let one metric's override
            # leak into every later metric (e.g. two metrics saved to one file).
            # NOTE(review): multi-metric frames come from conveyor runs, hence
            # the forced 'conveyors' legend names.
            plot_data(metric_df, meaning=tag, figsize=figsize,
                      xlim=kwargs.get(tag + '_xlim', xlim),
                      ylim=kwargs.get(tag + '_ylim', ylim),
                      xlabel=xlabel, ylabel=ylabel, font_size=font_size,
                      title=title,
                      save_path=kwargs.get(tag + '_save_path', save_path),
                      context='conveyors')
        return

    target = _targets[meaning]
    if ylabel is None:
        ylabel = _ylabels[meaning]

    fig, ax = plt.subplots(figsize=figsize)
    sns.lineplot(x='time', y=target, hue='router_type', data=data,
                 err_kws={'alpha': 0.1}, ax=ax)

    # Some seaborn versions emit the hue column name as an extra first legend
    # entry. Filter it out by label rather than dropping position 0, which
    # would hide a real router on versions that do not emit it.
    names = _legend_txt_replace[context]
    entries = [(handle, label)
               for handle, label in zip(*ax.get_legend_handles_labels())
               if label != 'router_type']
    ax.legend(handles=[handle for handle, _ in entries],
              labels=[names.get(label, label) for _, label in entries],
              fontsize=font_size)

    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)
    if title is not None:
        ax.set_title(title)

    ax.set_xlabel(xlabel, fontsize=font_size)
    ax.set_ylabel(ylabel, fontsize=font_size)

    # Save before showing, so the written file never depends on the figure's
    # state after show(). The old `plt.show(fig)` passed the figure as show()'s
    # first positional parameter, which is not a figure selector.
    if save_path is not None:
        fig.savefig('../img/' + save_path, bbox_inches='tight')
    plt.show()

def split_data(dct):
    """
    Transpose a mapping of sequences into a tuple of mappings.

    ``{'a': [a0, a1], 'b': [b0]}`` becomes ``({'a': a0, 'b': b0}, {'a': a1})``:
    the i-th dict collects the i-th element of every value that has one.
    Values may have different lengths. An empty mapping gives ``()``.
    """
    columns = []
    for name, sequence in dct.items():
        for position, item in enumerate(sequence):
            # Positions grow one at a time, so at most one new dict is needed.
            if position == len(columns):
                columns.append({})
            columns[position][name] = item
    return tuple(columns)
    
def combine_launch_data(launch_data):
    """
    Concatenate per-job result frames into one frame tagged by job.

    :param launch_data: mapping ``job_id -> DataFrame``. Each ``job_id`` is
        decoded with ``un_job_id`` into a router type and a seed, which are
        attached to a copy of its frame via ``add_cols``. An already-combined
        ``DataFrame`` (e.g. the return value of ``run_threaded``) is returned
        as a copy, so calling this on combined data is harmless.
    :returns: row-wise concatenation of the tagged copies; the input frames
        are not modified.
    :raises ValueError: if ``launch_data`` is empty.
    """
    if isinstance(launch_data, pd.DataFrame):
        # Iterating a DataFrame's .items() would yield columns, not jobs.
        return launch_data.copy()
    if not launch_data:
        raise ValueError('combine_launch_data: no launch results to combine')

    frames = []
    for job_id, data in launch_data.items():
        router_type, seed = un_job_id(job_id)
        frame = data.copy()
        add_cols(frame, router_type=router_type, seed=seed)
        frames.append(frame)
    return pd.concat(frames, axis=0)

In [None]:
class DummyTqdmFile(object):
    """File-like wrapper that routes writes through ``tqdm.write``.

    ``tqdm.write`` prints above any active progress bars instead of through
    them, so output redirected here does not garble the bars.
    """
    # Underlying stream that tqdm.write finally writes to.
    file = None

    def __init__(self, file):
        self.file = file

    def write(self, x):
        # print() emits its trailing newline as a separate write. Skip such
        # whitespace-only chunks, because tqdm.write adds its own line end.
        if not x.rstrip():
            return
        # The notebook imports only `tqdm_notebook` by name. Without this,
        # `tqdm` resolves only if a star import happens to provide it.
        from tqdm import tqdm
        tqdm.write(x, file=self.file)

    def flush(self):
        # Not every wrapped object has flush(); fall back to a no-op.
        return getattr(self.file, "flush", lambda: None)()

@contextlib.contextmanager
def std_out_err_redirect_tqdm():
    """
    Temporarily route ``sys.stdout`` and ``sys.stderr`` through ``DummyTqdmFile``.

    Yields the original ``sys.stdout`` (e.g. to give progress bars a real
    stream via ``file=``). The original streams are always restored, and any
    exception raised inside the block propagates unchanged.
    """
    orig_out_err = sys.stdout, sys.stderr
    try:
        sys.stdout, sys.stderr = map(DummyTqdmFile, orig_out_err)
        yield orig_out_err[0]
    finally:
        # No except clause needed: re-raising via `raise exc` only added a
        # frame to the traceback. `finally` restores the streams on every path.
        sys.stdout, sys.stderr = orig_out_err

In [None]:
def run_network_scenario_file(file: str, **kwargs):
    """
    Run the network scenario described by the launch file ``file``.

    Thin wrapper around ``run_simulation(NetworkRunner, run_params=file, ...)``.
    ``kwargs`` (router type, seed, progress queue, ...) are forwarded as-is.
    Kept at module level so ``run_threaded`` can hand it to a process pool.
    """
    return run_simulation(NetworkRunner, run_params=file, **kwargs)

def run_conveyor_scenario_file(file: str, **kwargs):
    """
    Run the conveyor scenario described by the launch file ``file``.

    Thin wrapper around ``run_simulation(ConveyorsRunner, run_params=file, ...)``.
    ``kwargs`` (router type, seed, progress queue, ...) are forwarded as-is.
    Kept at module level so ``run_threaded`` can hand it to a process pool.
    """
    return run_simulation(ConveyorsRunner, run_params=file, **kwargs)

def run_single(func, router_type: str, random_seed: int, **kwargs):
    """
    Run one scenario job in the current process behind a notebook progress bar.

    ``func`` is called with ``router_type``, ``random_seed``, a
    ``progress_queue`` that feeds the bar, and any extra ``kwargs``. The
    result is either a single frame or a tuple of frames. Every frame is
    tagged with ``router_type`` / ``seed`` via ``add_cols``, and the result
    is returned unchanged in shape.
    """
    bar_title = mk_job_id(router_type, random_seed)
    with tqdm_notebook(desc=bar_title) as progress_bar:
        outcome = func(router_type=router_type, random_seed=random_seed,
                       progress_queue=DummyProgressbarQueue(progress_bar), **kwargs)

    frames = outcome if type(outcome) is tuple else (outcome,)
    for frame in frames:
        add_cols(frame, router_type=router_type, seed=random_seed)
    return outcome

def exc_print(e):
    """
    Print the full traceback of exception ``e`` to stdout.

    Used as the ``error_callback`` of pool jobs, so a worker failure is
    reported as soon as it happens.
    """
    # Pass positionally: the ``etype=`` keyword was removed in Python 3.10,
    # which made this callback itself raise TypeError.
    print(''.join(traceback.format_exception(type(e), e, e.__traceback__)))

def run_threaded(func, router_types: List[str], random_seeds: List[int],
                 ignore_saved: Union[bool, List[str], None] = None, *args, **kwargs):
    """
    Run ``func`` for every (router type, seed) pair in a process pool, with
    one notebook progress bar per job.

    Each job is called as ``func(*args, router_type=..., random_seed=...,
    ignore_saved=..., progress_queue=..., **kwargs)``. Progress arrives as
    ``(job_id, increment)`` tuples on the shared queue, and ``(job_id, None)``
    marks a finished job. This contract is assumed from the original loop;
    TODO confirm in ``run_simulation``.

    :param func: picklable module-level job function, e.g. ``run_network_scenario_file``.
    :param router_types: router types to launch.
    :param random_seeds: seeds to launch for every router type.
    :param ignore_saved: ``True`` gives every job ``ignore_saved=True``.
        ``False``/``None``/empty gives none of them. A list selects the router
        types that get ``ignore_saved=True``.
    :returns: all job results combined by ``combine_launch_data``.
    :raises: the exception of a failed job, re-raised when results are collected.
    """
    from queue import Empty

    if isinstance(ignore_saved, bool):
        # `router_type in False` raised TypeError before.
        ignore_saved = list(router_types) if ignore_saved else []
    elif ignore_saved is None:
        ignore_saved = []

    jobs = {}
    bars = {}
    # Context managers shut the manager down and terminate the pool on exit
    # (including on error), instead of leaking processes on every call.
    with Manager() as manager, Pool() as pool:
        progress_q = manager.Queue()
        for router_type in router_types:
            for seed in random_seeds:
                job_id = mk_job_id(router_type, seed)
                job_args = dict(kwargs, router_type=router_type, random_seed=seed,
                                ignore_saved=router_type in ignore_saved,
                                progress_queue=progress_q)
                jobs[job_id] = pool.apply_async(func, args=args, kwds=job_args,
                                                error_callback=exc_print)
                bars[job_id] = tqdm_notebook(desc=job_id)

        while bars:
            try:
                job_id, increment = progress_q.get(timeout=1.0)
            except Empty:
                # A job that raised never sends its (job_id, None) marker, so a
                # blocking get() would wait forever. Close bars of finished jobs.
                for done_id in [j for j in bars if jobs[j].ready()]:
                    bars.pop(done_id).close()
                continue

            bar = bars.get(job_id)
            if bar is None:
                # Late message for a bar that was already closed above.
                continue
            if increment is None:
                bars.pop(job_id).close()
            else:
                bar.update(increment)

        # Collect inside the `with`: leaving it terminates the pool.
        results = {job_id: job.get() for job_id, job in jobs.items()}
    return combine_launch_data(results)

In [None]:
# Earlier single-seed runs loaded from CSV: DQN without pretraining (per the
# file name) vs. the link-state baseline. The keys are assumed to follow the
# '<router_type>-<seed>' job-id format that un_job_id decodes.
results_no_pretrain = pd.read_csv('../logs/results5_no_pretrain_dqn.csv', index_col=0)
results5_ls = pd.read_csv('../logs/results5_link_state.csv', index_col=0)
res5_comb = combine_launch_data({
    'link_state-42': results5_ls,
    'dqn-42': results_no_pretrain,
})
plot_data(res5_comb, xlim=(0, 40000), ylim=(0, 700),
          ylabel='Среднее время в пути')

In [None]:
launch6_data = run_threaded(
    run_network_scenario_file,
    router_types=['link_state', 'simple_q', 'dqn', 'dqn_emb'],
    random_seeds=[42, 43, 44],
    ignore_saved=[],  # no router type gets ignore_saved=True
    file='../launches/launch6.yaml',
    progress_step=500,
)

In [None]:
plot_data(launch6_data, figsize=(10, 6), ylim=(0, 150))

In [None]:
launch8_data = run_threaded(
    run_network_scenario_file,
    router_types=['link_state', 'simple_q', 'dqn', 'dqn_emb'],
    random_seeds=[42, 43, 44],
    ignore_saved=['link_state', 'simple_q'],  # only these get ignore_saved=True
    file='../launches/launch8.yaml',
    progress_step=500,
)

In [None]:
plot_data(
    launch8_data,
    xlim=(-300, 35000),
    ylim=(35, 140),
    figsize=(15, 10),
)

In [None]:
# Network scenario from launch_long_calm.yaml: three router types x seeds 42-44.
launch_calm_data = run_threaded(run_network_scenario_file, file='../launches/launch_long_calm.yaml',
                                router_types=['link_state', 'dqn', 'dqn_emb'], progress_step=500,
                                ignore_saved=[], random_seeds=[42, 43, 44])

In [None]:
# run_threaded already returns the combined frame; combining it again would
# iterate its columns as if they were job ids.
plot_data(launch_calm_data, xlim=(0, 40000), ylim=(0, 300))

In [None]:
launch_rand_data = run_threaded(
    run_network_scenario_file,
    router_types=['link_state', 'dqn', 'dqn_emb'],
    random_seeds=[42, 43, 44],
    ignore_saved=[],
    file='../launches/launch_dqn_transfer.yaml',
    progress_step=500,
)

In [None]:
# run_threaded already returns the combined frame, so it is plotted directly.
plot_data(launch_rand_data, figsize=(10, 6),
          ylim=(20, 150), xlim=(-500, 20000), save_path='learning-transfer-small.pdf')

In [None]:
launch_rand_data_big = run_threaded(
    run_network_scenario_file,
    router_types=['link_state', 'simple_q', 'dqn_emb'],
    random_seeds=[42, 43, 44],
    ignore_saved=[],
    file='../launches/launch_rand_big.yaml',
    progress_step=500,
)

In [None]:
# run_threaded already returns the combined frame; keep the old name as an alias.
launch_rand_data_comb = launch_rand_data_big
plot_data(launch_rand_data_comb, figsize=(6, 6), ylim=(30, 300), xlim=(-500, 50000),
          save_path='learning-transfer-big-low-load.pdf')

In [None]:
# Single in-process conveyor run for debugging (simple_q, seed 42).
debug_data = run_single(
    run_conveyor_scenario_file,
    router_type='simple_q',
    random_seed=42,
    file='../launches/conveyor_energy_test_2.yaml',
    ignore_saved=True,
    progress_step=100,
)

In [None]:
# Quick look at the debug run with all default plot settings.
plot_data(data=debug_data)

In [None]:
conveyor_data_full = run_threaded(
    run_conveyor_scenario_file,
    router_types=['link_state', 'simple_q', 'dqn_emb', 'centralized_simple'],
    random_seeds=[42, 43, 44],
    ignore_saved=['centralized_simple'],
    file='../launches/conveyor_energy_test.yaml',
    progress_step=500,
)

In [None]:
# Multi-metric frame: plot_data splits it and applies the per-metric settings.
plot_data(
    conveyor_data_full,
    figsize=(10, 10),
    font_size=16,
    time_ylim=(40, 65),
    time_save_path='conveyors-new-1-time.pdf',
    energy_ylim=(7e6, 2e7),
    energy_save_path='conveyors-new-1-energy.pdf',
)

In [None]:
conveyor2_data_full = run_threaded(
    run_conveyor_scenario_file,
    router_types=['link_state', 'simple_q', 'dqn_emb', 'centralized_simple'],
    random_seeds=[42, 43, 44],
    ignore_saved=[],
    file='../launches/conveyor_energy_test_2.yaml',
    progress_step=500,
)

In [None]:
# Multi-metric frame: plot_data splits it and applies the per-metric settings.
plot_data(
    conveyor2_data_full,
    figsize=(10, 10),
    font_size=16,
    time_ylim=(40, 65),
    time_save_path='conveyors-new-2-time.pdf',
    energy_ylim=(7e6, 2.8e7),
    energy_save_path='conveyors-new-2-energy.pdf',
)

In [None]:
# run_threaded returns a single combined multi-metric frame (the cell above
# plots it per metric). Tuple-unpacking and re-combining it no longer works.
# Let plot_data split the metrics itself. The energy target column ('sum')
# and its y-label come from _targets / _ylabels, so the ignored `target=`
# keyword is dropped.
# NOTE(review): energy_ylim (1000, 2700) is far below the (7e6, 2.8e7) range
# used above; check that it still matches the current energy units.
plot_data(conveyor2_data_full, figsize=(10, 10), font_size=18,
          time_ylim=(40, 100), energy_ylim=(1000, 2700),
          time_save_path='conveyors2-late-time.pdf',
          energy_save_path='conveyors2-late-energy.pdf')