# Load Data

In [None]:
# Imports: autodisc (ad) supplies the statistics loader and the plotting/interaction
# helpers used in the cells below; ipywidgets supplies `interact` for the run-parameter plot.
import autodisc as ad
import ipywidgets
import plotly
import numpy as np
# Set up plotly's offline plotting mode for this notebook (connected=True).
plotly.offline.init_notebook_mode(connected=True)

In [None]:
# Experiments to compare. Each entry must provide 'id' and 'directory';
# 'name' and 'is_default' are optional.
org_experiment_definitions = [
    # two dimensions
    {'id': '1',
     'directory': '../experiments/experiment_000001',
     'name': 'Test 1',
     'is_default': True},
    {'id': '2',
     'directory': '../experiments/experiment_000002',
     'name': 'Test 2',
     'is_default': True},
]

repetition_ids = list(range(10))

# Template for the display label of an experiment (placeholders: <id>, <name>).
experiment_name_format = '<id> - <name>'  # <id>, <name>

# Convert the definitions into the form passed to the autodisc GUI helpers:
# the label built from `experiment_name_format` is stored as the experiment 'id'.
experiment_definitions = []
for org_exp_def in org_experiment_definitions:
    new_exp_def = {'directory': org_exp_def['directory']}
    if 'is_default' in org_exp_def:
        new_exp_def['is_default'] = org_exp_def['is_default']

    # only the fields that are present are handed to the template replacement
    label_fields = {'id': org_exp_def['id']}
    if 'name' in org_exp_def:
        label_fields['name'] = org_exp_def['name']
    new_exp_def['id'] = ad.gui.jupyter.misc.replace_str_from_dict(experiment_name_format, label_fields)

    experiment_definitions.append(new_exp_def)

# Load the statistics of every experiment, keyed by its display label.
experiment_statistics = {}
for experiment_definition in experiment_definitions:
    experiment_statistics[experiment_definition['id']] = \
        ad.gui.jupyter.misc.load_statistics(experiment_definition['directory'])

# Error Curves

## Generate Data

In [None]:
def compute_running_average(data=None, data_sources=None, nsteps=50):
    """Compute a trailing running average of selected statistics for each experiment.

    Parameters
    ----------
    data : dict
        Mapping ``experiment_id -> statistics`` where the statistics are nested
        dicts. Every selected leaf must be a 2-D array; the average runs along
        axis 1 (assumed to be repetitions x explorations -- TODO confirm).
    data_sources : key | tuple | list of (key | tuple)
        Path(s) to the leaves to average. A tuple such as ``('a', 'data')``
        selects ``statistics['a']['data']``; a plain key selects ``statistics[key]``.
    nsteps : int
        Window length (>= 1). Column ``j`` of the result is the mean of the last
        ``nsteps`` values up to and including column ``j`` (fewer values at the
        start of the series).

    Returns
    -------
    dict
        ``experiment_id -> nested dict`` that reproduces the paths of
        ``data_sources`` and holds the averaged arrays at their leaves. Sources
        sharing a path prefix are merged rather than overwriting each other.

    NaN entries are treated as missing: they are excluded from both the sum and
    the count of a window. A window that contains only NaNs yields NaN.

    Raises
    ------
    ValueError
        If ``nsteps`` is smaller than 1.
    """
    if nsteps < 1:
        raise ValueError('nsteps must be >= 1')

    if not isinstance(data_sources, list):
        data_sources = [data_sources]

    # normalise every data source to a key path (tuple)
    paths = [ds if isinstance(ds, tuple) else (ds,) for ds in data_sources]

    averaged_data = dict()

    # compute for each experiment
    for experiment_id, experiment_data in data.items():

        cur_average_experiment_data = dict()

        for path in paths:

            # go through the sub datasources to reach the final data
            cur_data = experiment_data
            for sub_ds in path:
                cur_data = cur_data[sub_ds]

            cur_data = np.asarray(cur_data, dtype=float)
            valid = ~np.isnan(cur_data)
            cur_average_data = np.full(cur_data.shape, np.nan)

            for end_idx in range(cur_data.shape[1]):
                # window of exactly `nsteps` columns ending at (and including) end_idx
                start_idx = max(0, end_idx - nsteps + 1)
                window = slice(start_idx, end_idx + 1)
                n_valid = np.count_nonzero(valid[:, window], axis=1)
                window_sum = np.nansum(cur_data[:, window], axis=1)
                # all-NaN windows give 0/0 -> NaN; silence the expected warning
                with np.errstate(invalid='ignore', divide='ignore'):
                    cur_average_data[:, end_idx] = window_sum / n_valid

            # rebuild the nested path in the output, merging with sources that share a prefix
            target = cur_average_experiment_data
            for sub_ds in path[:-1]:
                target = target.setdefault(sub_ds, dict())
            target[path[-1]] = cur_average_data

        averaged_data[experiment_id] = cur_average_experiment_data

    return averaged_data

# Running average (nsteps=100) of both goal-space error statistics.
# NOTE: the misspelled variable name is what the plotting cells read, so it is kept.
experiment_running_average_statisitics = compute_running_average(
    data=experiment_statistics,
    data_sources=[
        ('error_in_goalspace_between_goal_bestpolicy', 'data'),
        ('error_in_goalspace_between_goal_usedpolicy', 'data'),
    ],
    nsteps=100,
)

## Error between Goal and Optimal Policy

In [None]:
# PLOTTING: interactive scatter of the running-average 'bestpolicy' goal-space error
config = {
    'layout': {
        'title': 'Error between Goal and Optimal Policy',
        'xaxis': {'title': 'explorations'},
        'yaxis': {'title': 'error'},
    },
}

elems = ad.gui.jupyter.interact_selection_multiple_experiments_repetitions(
    func=ad.gui.jupyter.plot_scatter_per_datasource,
    experiment_definitions=experiment_definitions,
    repetition_ids=repetition_ids,
    data=experiment_running_average_statisitics,
    data_source=('error_in_goalspace_between_goal_bestpolicy', 'data'),
    config=config,
)

## Error between Goal and Used Policy

In [None]:
# PLOTTING: interactive scatter of the running-average 'usedpolicy' goal-space error
config = {
    'layout': {
        'title': 'Error between Goal and Used Policy',
        'xaxis': {'title': 'explorations'},
        'yaxis': {'title': 'error'},
    },
}

elems = ad.gui.jupyter.interact_selection_multiple_experiments_repetitions(
    func=ad.gui.jupyter.plot_scatter_per_datasource,
    experiment_definitions=experiment_definitions,
    repetition_ids=repetition_ids,
    data=experiment_running_average_statisitics,
    data_source=('error_in_goalspace_between_goal_usedpolicy', 'data'),
    config=config,
)

# Classification

## Generate Data

In [None]:
## Collect Data for categories

# function to calculate ratios
def calc_binary_ratio(array):
    """Return the mean of the non-NaN entries of `array`.

    For boolean (or 0/1) data this is the fraction of true entries. NaN entries
    are ignored in both numerator and denominator. Returns NaN when `array` has
    no non-NaN entries.
    """
    array = np.asarray(array)
    n_valid = np.count_nonzero(~np.isnan(array))
    if n_valid == 0:
        return np.nan
    # nansum, not sum: a single NaN entry would otherwise make the whole ratio NaN
    return np.nansum(array) / n_valid

def _row_ratio(flags):
    """Apply `calc_binary_ratio` to every row (axis 1) of `flags`."""
    return np.apply_along_axis(calc_binary_ratio, 1, flags)


def _ratio_within(subset_flags, parent_ratio):
    """Row ratio of `subset_flags` divided by `parent_ratio`; NaN quotients are set to 0."""
    ratio = _row_ratio(subset_flags) / parent_ratio
    ratio[np.isnan(ratio)] = 0
    return ratio


classifier_data = {}
for experiment_definition in experiment_definitions:

    experiment_id = experiment_definition['id']
    exp_stats = experiment_statistics[experiment_id]

    # categories read directly from the statistics
    dead_data = exp_stats['classifier_dead']['data']
    dead_ratio = exp_stats['classifier_dead']['ratio']

    animal_data = exp_stats['classifier_animal']['data']
    animal_ratio = exp_stats['classifier_animal']['ratio']

    # non animal: every entry that is neither dead nor animal
    non_animal_data = np.full(dead_data.shape, True)
    non_animal_data[dead_data] = False
    non_animal_data[animal_data] = False
    non_animal_ratio = _row_ratio(non_animal_data)

    # diverging: only entries equal to 0 are non-diverging (NaN therefore counts as diverging)
    raw_diverging = exp_stats['classifier_diverging']['data']
    diverging_data = np.full(raw_diverging.shape, True)
    diverging_data[raw_diverging == 0] = False
    diverging_ratio = _row_ratio(diverging_data)

    # cross animal / non animal with diverging / non diverging
    diverging_animal_data = animal_data & diverging_data
    diverging_animal_ratio = _row_ratio(diverging_animal_data)

    non_diverging_animal_data = animal_data & ~diverging_data
    non_diverging_animal_ratio = _row_ratio(non_diverging_animal_data)

    diverging_non_animal_data = non_animal_data & diverging_data
    diverging_non_animal_ratio = _row_ratio(diverging_non_animal_data)

    non_diverging_non_animal_data = non_animal_data & ~diverging_data
    non_diverging_non_animal_ratio = _row_ratio(non_diverging_non_animal_data)

    # sub-categories of non-diverging animals, as fractions of the non-diverging animal ratio
    stable_fixpoint = exp_stats['classifier_stable_fixpoint_solution']['data']
    non_diverging_animal_stable_fixpoint_data = non_diverging_animal_data & stable_fixpoint
    non_diverging_animal_stable_fixpoint_ratio = _ratio_within(
        non_diverging_animal_stable_fixpoint_data, non_diverging_animal_ratio)

    moving = exp_stats['classifier_moving']['data']
    # animal & moving & not fixpoint
    non_diverging_animal_moving_data = (
        non_diverging_animal_data & moving & ~non_diverging_animal_stable_fixpoint_data)
    non_diverging_animal_moving_ratio = _ratio_within(
        non_diverging_animal_moving_data, non_diverging_animal_ratio)

    # animal & not moving & not fixpoint
    non_diverging_animal_non_moving_data = (
        non_diverging_animal_data & ~moving & ~non_diverging_animal_stable_fixpoint_data)
    non_diverging_animal_non_moving_ratio = _ratio_within(
        non_diverging_animal_non_moving_data, non_diverging_animal_ratio)

    # sub-categories of non-diverging non animals, as fractions of the non-diverging non animal ratio
    stable_fixpoint = exp_stats['classifier_stable_fixpoint_solution']['data']
    non_diverging_non_animal_stable_fixpoint_data = non_diverging_non_animal_data & stable_fixpoint
    non_diverging_non_animal_stable_fixpoint_ratio = _ratio_within(
        non_diverging_non_animal_stable_fixpoint_data, non_diverging_non_animal_ratio)

    non_diverging_non_animal_non_stable_fixpoint_data = non_diverging_non_animal_data & ~stable_fixpoint
    non_diverging_non_animal_non_stable_fixpoint_ratio = _ratio_within(
        non_diverging_non_animal_non_stable_fixpoint_data, non_diverging_non_animal_ratio)

    classifier_data[experiment_id] = {
        'dead_data': dead_data,
        'dead_ratio': dead_ratio,

        'animal_data': animal_data,
        'animal_ratio': animal_ratio,

        'non_animal_data': non_animal_data,
        'non_animal_ratio': non_animal_ratio,

        'diverging_data': diverging_data,
        'diverging_ratio': diverging_ratio,

        'diverging_animal_data': diverging_animal_data,
        'diverging_animal_ratio': diverging_animal_ratio,

        'non_diverging_animal_data': non_diverging_animal_data,
        'non_diverging_animal_ratio': non_diverging_animal_ratio,

        'diverging_non_animal_data': diverging_non_animal_data,
        'diverging_non_animal_ratio': diverging_non_animal_ratio,

        'non_diverging_non_animal_data': non_diverging_non_animal_data,
        'non_diverging_non_animal_ratio': non_diverging_non_animal_ratio,

        'non_diverging_animal_stable_fixpoint_data': non_diverging_animal_stable_fixpoint_data,
        'non_diverging_animal_stable_fixpoint_ratio': non_diverging_animal_stable_fixpoint_ratio,

        'non_diverging_animal_moving_data': non_diverging_animal_moving_data,
        'non_diverging_animal_moving_ratio': non_diverging_animal_moving_ratio,

        'non_diverging_animal_non_moving_data': non_diverging_animal_non_moving_data,
        'non_diverging_animal_non_moving_ratio': non_diverging_animal_non_moving_ratio,

        'non_diverging_non_animal_stable_fixpoint_data': non_diverging_non_animal_stable_fixpoint_data,
        'non_diverging_non_animal_stable_fixpoint_ratio': non_diverging_non_animal_stable_fixpoint_ratio,

        'non_diverging_non_animal_non_stable_fixpoint_data': non_diverging_non_animal_non_stable_fixpoint_data,
        'non_diverging_non_animal_non_stable_fixpoint_ratio': non_diverging_non_animal_non_stable_fixpoint_ratio,
    }

## Major Classification

In [None]:
# Plotting: box plot of the five major categories; trace_labels follow the order of data_source
config = {
    'plot_type': 'plotly_box',
    'layout': {'title': 'Major Classification'},
    'trace_labels': ['dead', 'div non animal', 'div animal', 'non animal', 'animal'],
}

elems = ad.gui.jupyter.interact_selection_multiple_experiments_repetitions(
    func=ad.gui.jupyter.plot_barbox_per_datasource,
    experiment_definitions=experiment_definitions,
    repetition_ids=repetition_ids,
    data=classifier_data,
    data_source=[
        'dead_ratio',
        'diverging_non_animal_ratio',
        'diverging_animal_ratio',
        'non_diverging_non_animal_ratio',
        'non_diverging_animal_ratio',
    ],
    config=config,
)

## Animal Classification

In [None]:
# Plotting: box plot of the non-diverging animal sub-categories; trace_labels follow data_source order
config = {
    'plot_type': 'plotly_box',
    'layout': {'title': 'Animal Classification'},
    'trace_labels': ['stable fixpoint', 'moving', 'non moving'],
}

elems = ad.gui.jupyter.interact_selection_multiple_experiments_repetitions(
    func=ad.gui.jupyter.plot_barbox_per_datasource,
    experiment_definitions=experiment_definitions,
    repetition_ids=repetition_ids,
    data=classifier_data,
    data_source=[
        'non_diverging_animal_stable_fixpoint_ratio',
        'non_diverging_animal_moving_ratio',
        'non_diverging_animal_non_moving_ratio',
    ],
    config=config,
)

## Non-Animal Classification

In [None]:
# Plotting: box plot of the non-diverging non-animal sub-categories; trace_labels follow data_source order
config = {
    'plot_type': 'plotly_box',
    'layout': {'title': 'Non-Animal Classification'},
    'trace_labels': ['stable fixpoint', 'non stable'],
}

elems = ad.gui.jupyter.interact_selection_multiple_experiments_repetitions(
    func=ad.gui.jupyter.plot_barbox_per_datasource,
    experiment_definitions=experiment_definitions,
    repetition_ids=repetition_ids,
    data=classifier_data,
    data_source=[
        'non_diverging_non_animal_stable_fixpoint_ratio',
        'non_diverging_non_animal_non_stable_fixpoint_ratio',
    ],
    config=config,
)

# Run Parameters

In [None]:
# Plotting
def plot_run_parameters_for_experiment(experiment_id):
    """Scatter-plot the run parameters of one experiment, one subplot per parameter.

    The parameters R, T, m, s and the four entries of b are drawn on a 4x2 subplot
    grid with ``ad.gui.jupyter.plot_scatter_per_datasource`` using
    ``repetition_ids=['all']`` and the notebook-level ``experiment_statistics``.

    Parameters
    ----------
    experiment_id : str
        Key of the experiment in ``experiment_statistics`` (its display label).
    """
    config = dict(
        subplots = dict(subplot_titles = ['R', 'T', 'm', 's', 'b:1', 'b:2', 'b:3', 'b:4'],
                        rows=4,
                        cols=2),
        init_mode = 'elements',
        layout = dict(
            height = 1000,
            default_xaxis = dict(
                title = 'explorations'
                ),
            # the plotted values are run parameters, not errors
            default_yaxis = dict(
                title = 'parameter value'
                ),
            ),
        default_element_label = '<mean_label> - <subelem_idx>',
        default_trace = dict(
            mode = 'markers'
            ),
        default_mean_trace = dict(
            legendgroup = '<data_idx>', # subplot_idx, data_idx
            showlegend = False,
            ),
        default_subplot_mean_traces = [dict(
            showlegend = True,
            )],
        default_element_trace = dict(
            visible='legendonly',
            legendgroup = 'elem <data_idx>-<subelem_idx>',
            showlegend = False,
            ),
        default_subplot_element_traces = [dict(
            showlegend = True
            )],
        default_data_element_traces = [dict(
            visible = True
            )]
    )

    # one data source per subplot, in the same order as `subplot_titles`
    data_source = [('run_parameters', name) for name in ('R', 'T', 'm', 's')]
    # 'b' is indexed entry by entry: b[0] .. b[3] are shown as 'b:1' .. 'b:4'
    data_source += [('run_parameters', 'b', b_idx) for b_idx in range(4)]

    ad.gui.jupyter.plot_scatter_per_datasource(experiment_ids=[experiment_id],
                            repetition_ids=['all'],
                            data=experiment_statistics,
                            data_source=data_source,
                            config=config)
    
# Offer every loaded experiment label in an ipywidgets selector and plot its run parameters.
experiment_ids = [definition['id'] for definition in experiment_definitions]
retval = ipywidgets.interact(plot_run_parameters_for_experiment,
                             experiment_id=experiment_ids)