In [3]:
from bokeh.models import HoverTool

In [11]:
import panel as pn
import param as pm
import pandas as pd
import numpy as np
from glob import glob
from bokeh.palettes import Category20, Turbo256, Dark2
import hvplot.pandas
import millify
import holoviews as hv
hvplot.extension('bokeh')
pn.extension()

# Distinct experiment names found on disk: for every entry in
# ../data/simulations, the part of the file name before the first '-'.
all_experiments = sorted(
    {f.split('/')[-1].split('-')[0] for f in glob('../data/simulations/*')}
)

# Experiments offered by the dashboard's `experiment` selector.
relevant_experiments = [
    'sanity_check_run',
    'standard_stochastic_run',
    'initial_conditions',
    'reference_subsidy_sweep',
    'sweep_over_saturation',
]

def snake_to_title(s):
    """Convert a snake_case string to Title Case for chart titles and labels.

    The string is split on underscores, every piece is capitalized
    (first character upper-cased, the rest lower-cased) and the pieces are
    joined with single spaces.

    Example:
    >>> snake_to_title('snake_case')
    'Snake Case'
    """
    return ' '.join(map(str.capitalize, s.split('_')))



class Simulation(pm.Parameterized):
    """Parameterized model behind the simulation analysis dashboard.

    Loads a pickled simulation dataset for the selected experiment into
    `sim_df`, derives KPI columns, and exposes styled tables and hvplot
    charts that `view()` assembles into a Panel layout.
    """

    # NOTE(review): precedence=-1 presumably hides a parameter from the
    # auto-generated Panel widgets — confirm against the param/Panel docs.

    # Experiment whose datasets are globbed from ../data/simulations.
    # NOTE(review): `pm.ObjectSelector` may be deprecated in newer param releases
    # in favour of `pm.Selector` (used below) — confirm before upgrading.
    experiment = pm.ObjectSelector(default='reference_subsidy_sweep', objects=relevant_experiments)
    # Path of the dataset file to load; options are set by _update_dataset_options.
    dataset = pm.Selector(default=None, precedence=-1)  # This will be dynamically populated
    # Columns dropped from the raw dataset right after loading.
    drop_list = pm.List(precedence=-1, default=['timestep', 'simulation', 'subset', 'timestep_in_days', 'block_time_in_seconds', 'delta_days', 'delta_blocks'])
    # The loaded (and KPI-augmented) simulation results.
    sim_df = pm.DataFrame(precedence=-1)
    # Bokeh palette used to assign one color per sim_df column.
    color_palette = pm.Selector(default=Category20, objects=[Category20, Turbo256], precedence=-1)
    # Mapping of column name -> color, rebuilt by _update_column_colors.
    column_colors = pm.Dict(precedence=-1)
    # How numeric cells are rendered by format_value.
    value_format = pm.Selector(default='Millify', objects=['Scientific', 'Millify', 'Decimal'], precedence=1)
    # When True, cell background colors are scaled logarithmically.
    value_color_log_scale = pm.Boolean(False, precedence=-1)
    # Currently selected KPI column subset; options are set by _set_kpi_subsets.
    kpi_subset = pm.Selector()
    # When True, sim_df is sorted by days_passed only (see _sort_df).
    sort_days_passed = pm.Boolean(False, precedence=-1)
    # Maximum number of rows shown in the results table.
    max_rows = pm.Integer(6, bounds=(1, None), step=2)

    def fan_chart_quantile_median(self, df, column='circulating_supply', median_only=False):
        """Build a fan chart for one column: a min-max band with a median line.

        Args:
            df: DataFrame containing a 'days_passed' column and `column`.
            column: Name of the series to aggregate per 'days_passed'.
            median_only: When True, return just the median line chart.

        Returns:
            The median line chart, or the band chart overlaid with the
            median line when `median_only` is False.
        """
        # Per-day min, max and median of the chosen series.
        stats = df.groupby('days_passed')[column].agg(['min', 'max', 'median'])

        pretty_name = snake_to_title(column)
        shared_opts = {
            'width': 1200,
            'height': 400,
            'title': f'{pretty_name} Fan Chart',
            'ylabel': f'{column}_min_max_median',
        }

        # Median line.
        median_hover = HoverTool(tooltips=[(f'{pretty_name} Median', '@median{0,0.00}')])
        median_line = stats.hvplot(
            x='days_passed',
            y='median',
            alpha=1,
            line_width=4,
            label=f'{pretty_name} Median',
            tools=[median_hover],
            color=self.column_colors[column],
        ).opts(**shared_opts)
        if median_only:
            return median_line

        # Shaded band between the per-day minimum and maximum.
        band_hover = HoverTool(tooltips=[('Day:', '$x{0,0}')])
        band = stats.hvplot.area(
            x='days_passed',
            y='min',
            y2='max',
            legend='top_left',
            alpha=0.4,
            tools=[band_hover],
            ylim=(0, None),
            color=self.column_colors[column],
        ).opts(**shared_opts)

        # The band is drawn first so the median line sits on top of it.
        return band * median_line

    def __init__(self, **params):
        """Initialise the parameters and run the initial load pipeline.

        After the param machinery is set up, the dataset options are
        populated, the data is loaded, the frame is sorted and the KPI
        subsets are registered, in that order.
        """
        super().__init__(**params)
        startup_steps = (
            self._update_dataset_options,
            self._load_simulation_data,
            self._sort_df,
            self._set_kpi_subsets,
        )
        for step in startup_steps:
            step()

    def kpi_subset_name(self):
        """Return the name under which the current `kpi_subset` value is registered.

        The selector's name -> value mapping is inverted and looked up with
        the currently selected value.

        Raises:
            KeyError: If the current value is not among the registered subsets.
        """
        def as_key(value):
            # Lists are unhashable, so they are used as tuples for the lookup.
            return tuple(value) if isinstance(value, list) else value

        names_by_value = {
            as_key(value): name
            for name, value in self.param.kpi_subset.names.items()
        }
        return names_by_value[as_key(self.kpi_subset)]
        
    @pm.depends('sort_days_passed', watch=True)
    def _sort_df(self):
        """Re-sort `sim_df` whenever `sort_days_passed` changes.

        Sorts by 'days_passed' alone when the flag is set, otherwise by
        'label', 'environmental_label' and then 'days_passed'. The index is
        reset after sorting.
        """
        # Nothing to sort when no dataset could be loaded (the loader then
        # stores an empty frame without the sort columns).
        if self.sim_df is None or self.sim_df.empty:
            return

        if self.sort_days_passed:
            sort_keys = 'days_passed'
        else:
            sort_keys = ['label', 'environmental_label', 'days_passed']
        self.sim_df = self.sim_df.sort_values(sort_keys).reset_index(drop=True)

    @pm.depends('experiment', watch=True)
    def _update_dataset_options(self, event=None):
        """Refresh the `dataset` choices for the current experiment.

        Globs ../data/simulations for paths starting with the experiment
        name, offers them (sorted) as the selector's objects and selects the
        last one, or None when nothing matches. `event` is accepted but not
        used.
        """
        matches = sorted(glob(f"../data/simulations/{self.experiment}*"))
        self.param.dataset.objects = matches
        if matches:
            self.dataset = matches[-1]
        else:
            self.dataset = None

    @pm.depends('dataset', 'drop_list', watch=True)
    def _load_simulation_data(self):
        """Load the simulation data when dataset or drop_list are changed."""
        # Without a dataset there is nothing to read: fall back to an empty frame.
        if not self.dataset:
            self.sim_df = pd.DataFrame()
            return

        # Read the pickle, drop the unnecessary columns, reset the index and
        # backfill so NaNs in the first block take the next available value.
        # NOTE: pd.read_pickle unpickles the file, so only point `dataset`
        # at files you trust.
        raw_df = pd.read_pickle(self.dataset)
        self.sim_df = raw_df.drop(self.drop_list, axis=1).reset_index(drop=True).bfill()

        # Move the label columns to the far left. The columns are read back
        # from self.sim_df after the assignment above, which has already
        # fired the watchers that depend on 'sim_df'.
        leading_columns = ['label', 'environmental_label']
        remaining_columns = [
            col for col in self.sim_df.columns if col not in leading_columns
        ]
        self.sim_df = self.sim_df[leading_columns + remaining_columns]

    @pm.depends('sim_df', watch=True)
    def _add_kpis(self):
        """Add the derived KPI columns to `sim_df` in place and refresh colors.

        'issuance' is block_reward + reference_subsidy and 'fees' is
        compute_fee_volume + storage_fee_volume. A KPI is skipped when its
        source columns are missing (for example on the empty frame stored
        when no dataset is available) instead of raising KeyError.
        """
        df = self.sim_df
        if df is None:
            return

        if {'block_reward', 'reference_subsidy'}.issubset(df.columns):
            df['issuance'] = df['block_reward'] + df['reference_subsidy']
        if {'compute_fee_volume', 'storage_fee_volume'}.issubset(df.columns):
            df['fees'] = df['compute_fee_volume'] + df['storage_fee_volume']

        # The columns were added in place (no new assignment to sim_df), so
        # the color mapping is refreshed explicitly to include them.
        self._update_column_colors()

    def _set_kpi_subsets(self):
        """Register the selectable KPI subsets and select 'All'.

        Each subset name maps to the list of `sim_df` columns it shows; the
        special value 'All' stands for the whole frame. Supply columns are
        discovered from `sim_df` (every column containing 'supply' except a
        fixed exclusion set) and sorted so their order is stable between runs.
        """
        show_all = 'All'
        excluded_supply = {'max_credit_supply', 'issued_supply', 'total_supply'}
        supply_columns = sorted(
            {c for c in self.sim_df.columns if 'supply' in c} - excluded_supply
        )

        kpi_subsets = {
            'all': show_all,
            'fees_and_issuance': [
                'compute_fee_volume',
                'storage_fee_volume',
                'fees',
                'block_reward',
                'reference_subsidy',
                'issuance',
            ],
            'agent_balances': [
                'farmers_balance',
                'operators_balance',
                'nominators_balance',
                'holders_balance',
            ],
            'agent_pool_balances': ['staking_pool_balance'],
            'protocol_treasury_balances': ['fund_balance'],
            'supply_columns': supply_columns,
        }

        self.param.kpi_subset.objects = kpi_subsets
        self.kpi_subset = show_all


    def _discrete_colorization(self):
        """Map every `sim_df` column to a color from the palette's 20-color set.

        Colors are assigned in column order and repeat after 20 columns.
        """
        twenty_colors = self.color_palette[20]
        return {
            col: twenty_colors[position % 20]
            for position, col in enumerate(self.sim_df.columns)
        }

    def _continuous_colorization(self):
        """Map every `sim_df` column to a color spread evenly over the palette.

        Palette positions are spaced linearly from the first to the last
        palette entry, one position per column, and truncated to integers.
        """
        columns = self.sim_df.columns
        positions = np.linspace(0, len(self.color_palette) - 1, len(columns))
        return {
            col: self.color_palette[int(position)]
            for col, position in zip(columns, positions)
        }
        
    @pm.depends('sim_df', 'color_palette', watch=True)
    def _update_column_colors(self):
        """Rebuild `column_colors` from the selected palette and `sim_df`.

        Turbo256 uses the continuous mapping, Category20 the discrete one;
        any other palette leaves `column_colors` untouched.
        """
        palette = self.color_palette
        if palette == Turbo256:
            self.column_colors = self._continuous_colorization()
        elif palette == Category20:
            self.column_colors = self._discrete_colorization()

    def _truncate_dataframe(self, df):
        """Return `df` limited to `max_rows` rows (first rows plus last rows).

        Args:
            df: The DataFrame to truncate.

        Returns:
            `df` itself when it has at most `max_rows` rows; otherwise the
            head and tail of `df` concatenated, totalling exactly `max_rows`
            rows (the head gets the extra row when `max_rows` is odd).
        """
        # Measure and return the frame that was passed in, not self.sim_df,
        # so KPI subsets are truncated correctly.
        if self.max_rows >= len(df):
            return df

        tail_rows = self.max_rows // 2
        head_rows = self.max_rows - tail_rows
        return pd.concat([df.head(head_rows), df.tail(tail_rows)])

    def kpi_subset_dataframe(self):
        """Return `sim_df` restricted to the selected KPI subset.

        With the 'All' subset the full frame is returned; otherwise the
        label, environmental label and day columns followed by the subset's
        columns.
        """
        selected = self.kpi_subset
        if selected == 'All':
            return self.sim_df

        id_columns = ['label', 'environmental_label', 'days_passed']
        return self.sim_df[id_columns + selected]

    def pivot_labels_kpi_dataframe(self):
        """Pivot the KPI frame to one row per day and one column group per label pair."""
        kpi_df = self.kpi_subset_dataframe()
        label_columns = ['label', 'environmental_label']
        return kpi_df.pivot_table(index='days_passed', columns=label_columns)

    def format_value(self, x):
        """Format a cell value according to `value_format`.

        Args:
            x: The value to format.

        Returns:
            A string in scientific notation ('Scientific'), fixed-point with
            two decimals ('Decimal') or millified ('Millify'). `x` is returned
            unchanged when it cannot be converted to a float, when formatting
            fails, or when `value_format` is none of the known options.
        """
        try:
            numeric_x = float(x)
            if self.value_format == 'Scientific':
                return f"{numeric_x:.2e}"
            if self.value_format == 'Decimal':
                # Two decimals, matching the precision of the other formats.
                # (The previous spec ':.f' was invalid, so the resulting
                # ValueError was swallowed and nothing was ever formatted.)
                return f"{numeric_x:.2f}"
            if self.value_format == 'Millify':
                return millify.millify(numeric_x, precision=2)
        except (ValueError, TypeError, OverflowError):
            # Non-numeric values, and values the formatter cannot handle
            # (e.g. NaN or infinity), are shown as they are.
            return x
        return x

    def styled_results_dataframe(self):
        """Return the (truncated) KPI frame as a color-coded pandas Styler.

        Object and categorical columns get one Dark2 color per distinct
        value; all other columns are shaded along the Turbo256 palette
        between the column's minimum and maximum (optionally on a log scale,
        see `value_color_log_scale`). Missing values, and non-positive values
        on a log scale, are left white. Header cells use `column_colors`, and
        cell text is rendered with `format_value`.
        """
        blank_cell = 'background-color: #ffffff; color: black'

        def luminance(hex_color):
            # Convert the hex color to RGB and weight the channels.
            hex_color = hex_color.lstrip('#')
            r, g, b = (int(hex_color[i:i + 2], 16) for i in (0, 2, 4))
            return (0.299 * r + 0.587 * g + 0.114 * b) / 255

        def cell_css(bg_color):
            # White text on dark backgrounds, black text on light ones.
            text_color = 'white' if luminance(bg_color) < 0.6 else 'black'
            return f'background-color: {bg_color}; color: {text_color}'

        def categorical_styles(column):
            # Build the category -> style map once per column.
            categories = column.dropna().unique()
            css_by_category = {
                cat: cell_css(Dark2[8][i % 8]) for i, cat in enumerate(categories)
            }
            # Missing values have no category and stay white.
            return [
                blank_cell if pd.isnull(val) else css_by_category[val]
                for val in column
            ]

        def numeric_styles(column, log_scale):
            min_val, max_val = column.min(), column.max()
            if log_scale:
                # The logarithm needs a positive lower bound.
                min_val = min_val if min_val > 0 else 0.1

            def style_one(val):
                if pd.isnull(val) or (log_scale and val <= 0):
                    return blank_cell
                low, high = min_val, max_val
                if log_scale:
                    val, low, high = np.log(val), np.log(low), np.log(high)
                normalized = (val - low) / (high - low) if high > low else 0
                # Values below the substituted log-scale minimum would give a
                # negative index and wrap around the palette; clamp instead.
                normalized = min(max(normalized, 0), 1)
                return cell_css(Turbo256[int(normalized * (len(Turbo256) - 1))])

            return [style_one(val) for val in column]

        def apply_color_styler(element, log_scale=False):
            # Build a fresh frame of CSS strings with the same labels rather
            # than overwriting the data frame handed in by the Styler.
            styles = {}
            for col in element.columns:
                column = element[col]
                if column.dtype == 'object' or column.dtype.name == 'category':
                    styles[col] = categorical_styles(column)
                else:
                    styles[col] = numeric_styles(column, log_scale)
            return pd.DataFrame(styles, index=element.index, columns=element.columns)

        # kpi_subset_dataframe() already returns the full frame for 'All'.
        displayed_df = self._truncate_dataframe(self.kpi_subset_dataframe())

        # Only the displayed columns need a formatter.
        formatter = {col: self.format_value for col in displayed_df.columns}

        header_styles = [
            {
                'selector': f'th.col_heading.level0.col{i}',
                'props': [
                    ('background-color', self.column_colors.get(col, '#ffffff')),
                    ('color', 'black'),
                ],
            }
            for i, col in enumerate(displayed_df.columns)
        ]

        styled_df = displayed_df.style.apply(
            apply_color_styler, log_scale=self.value_color_log_scale, axis=None
        )
        return styled_df.set_table_styles(header_styles).format(formatter)

    def run_descriptor_df(self):
        """Describe the KPI columns over all runs and per label pair.

        Returns:
            A DataFrame of `describe()` statistics (without 'count'): first a
            block labelled ('all_runs', 'all_runs') computed over the whole
            KPI frame, then one block per (label, environmental_label) group.
        """
        group_keys = ['label', 'environmental_label']
        indexed_df = self.kpi_subset_dataframe().set_index(['days_passed'] + group_keys)

        def summarize(frame):
            return frame.describe().drop('count')

        # Statistics over every run, re-labelled so they line up with the
        # three-level index of the per-group statistics.
        overall_df = summarize(indexed_df)
        overall_df.index.name = 'trajectory_metrics'
        overall_df.index = pd.MultiIndex.from_product(
            [['all_runs'], ['all_runs'], overall_df.index],
            names=group_keys + ['trajectory_metrics'],
        )

        # The same statistics for each (label, environmental_label) pair.
        per_group_df = indexed_df.groupby(group_keys).apply(summarize)

        return pd.concat([overall_df, per_group_df])

    def run_comparison_df(self):
        """Return each statistic's difference from the all-runs baseline.

        Takes the output of `run_descriptor_df` and subtracts the
        ('all_runs', 'all_runs') block from every block of statistics
        (including the baseline itself, which therefore becomes zero).

        Returns:
            A DataFrame with the same index and columns as
            `run_descriptor_df()` holding the differences.

        Raises:
            ValueError: If the statistics cannot be split into blocks of the
                baseline's size.
        """
        describe_labels_df = self.run_descriptor_df()
        baseline_df = describe_labels_df.loc[('all_runs', 'all_runs')]

        # Derive the number of blocks from the data itself: counting
        # unique labels x unique environmental labels is only right when
        # every combination of the two actually occurs.
        n_metrics = len(baseline_df)
        if n_metrics == 0 or len(describe_labels_df) % n_metrics:
            raise ValueError(
                'Run statistics cannot be aligned with the all-runs baseline.'
            )
        n_blocks = len(describe_labels_df) // n_metrics
        baseline_values = np.tile(baseline_df.values, (n_blocks, 1))

        return pd.DataFrame(
            describe_labels_df.values - baseline_values,
            columns=describe_labels_df.columns,
            index=describe_labels_df.index,
        )

    def styled_run_descriptor_comparison(self):
        """Return the run statistics styled by their difference from the baseline.

        The displayed values come from `run_descriptor_df`; each cell's
        background comes from `run_comparison_df`: red shades for negative
        differences, green shades for positive ones and white otherwise,
        normalised per column by the largest absolute difference.
        """
        stats_df = self.run_descriptor_df()
        diff_df = self.run_comparison_df()

        def difference_css(value):
            # `value` is a difference normalised to [-1, 1].
            if value < 0:
                # More negative, more red (1, 1+value, 1+value), value is negative
                return f'background-color: rgb(255, {int(255 * (1 + value))}, {int(255 * (1 + value))})'
            if value > 0:
                # More positive, more green (1-value, 1, 1-value), value is positive
                return f'background-color: rgb({int(255 * (1 - value))}, 255, {int(255 * (1 - value))})'
            # Zero difference, white
            return 'background-color: rgb(255, 255, 255)'

        def column_css(stats_column):
            # Color by the differences of the column with the same name.
            differences = diff_df[stats_column.name]
            largest_abs = np.max(np.abs(differences))
            return (differences / largest_abs).map(difference_css)

        header_styles = [
            {
                'selector': f'th.col_heading.level0.col{i}',
                'props': [
                    ('background-color', self.column_colors.get(col, '#ffffff')),
                    ('color', 'black'),
                ],
            }
            for i, col in enumerate(stats_df.columns)
        ]

        styled_df = stats_df.style.apply(column_css, axis=0)
        formatter = {col: self.format_value for col in stats_df.columns}

        return styled_df.set_table_styles(header_styles).format(formatter)
        

    def view_results_dataframe(self):
        """Wrap the styled results table in a Panel pane limited to `max_rows`."""
        return pn.panel(self.styled_results_dataframe(), max_rows=self.max_rows)


    def runs_overview(self):
        """Summarise the runs per (label, environmental_label) pair.

        Returns:
            A DataFrame with one row per label pair, the number of runs in
            'runs' and the row count of the pair's first run in 'days'.
        """
        label_keys = ['label', 'environmental_label']

        # Number of rows recorded for every individual run.
        rows_per_run = (
            self.sim_df.groupby(['run', 'label', 'environmental_label'])
            .size()
            .reset_index(name='days')
        )

        # Collapse to one row per label pair.
        per_label = (
            rows_per_run.groupby(label_keys)
            .agg({'run': 'count', 'days': 'first'})
            .reset_index()
        )
        return per_label.rename({'run': 'runs'}, axis=1)

    def view_runs_overview(self):
        """Wrap the runs overview table in a Panel pane limited to `max_rows`."""
        overview_df = self.runs_overview()
        return pn.panel(overview_df, max_rows=self.max_rows)


    def view_color_columns_bar(self):
        """Plot one bar per `sim_df` column, drawn in that column's assigned color."""
        # hvplot reverses the bar order when there are more than 10 columns,
        # so the columns are reversed up front to cancel that out.
        # See https://github.com/holoviz/hvplot/issues/1277
        reversed_columns = self.sim_df.columns[::-1]

        # A single-row frame holding the non-null count of every column.
        counts_row = self.sim_df.count().to_frame().T
        bar_colors = [self.column_colors[c] for c in reversed_columns]

        return counts_row.hvplot.bar(
            y=reversed_columns,
            color=bar_colors,
            rot=90,
            width=1400,
            height=500,
            title='Column Color Map',
            fontscale=1.4,
            yaxis=None,
        )

    def view_violin_kpis(self):
        """Return a single-column layout of violin plots, one per label pair.

        Each plot shows the distribution of every KPI column of the selected
        subset for one (label, environmental_label) group.
        """
        group_keys = ['label', 'environmental_label']
        indexed_df = self.kpi_subset_dataframe().set_index(['days_passed'] + group_keys)

        # Long format: one row per (label pair, KPI column, value).
        long_df = (
            indexed_df.reset_index()
            .drop('days_passed', axis=1)
            .melt(id_vars=group_keys)
        )

        # Every plot shares the same y-range: 75% of the largest KPI value.
        y_upper = indexed_df.max().max() * 0.75

        violins = []
        for name, group_df in long_df.groupby(group_keys):
            violins.append(
                group_df.hvplot.violin(
                    y='value',
                    by='variable',
                    c='variable',
                    legend=False,
                    width=1200,
                    height=200,
                    title=f'{self.kpi_subset_name()} : {name}',
                    cmap=self.column_colors,
                    ylim=(0, y_upper),
                )
            )
        return hv.Layout(violins).cols(1)

    def view_fan_kpis(self):
        """Return a single-column layout of fan-chart overlays, one per label pair.

        Each overlay combines the fan chart of every KPI column of the
        selected subset for one (label, environmental_label) group.
        """
        id_columns = ['label', 'environmental_label', 'days_passed']
        kpi_df = self.kpi_subset_dataframe()
        flat_df = kpi_df.set_index(['days_passed', 'label', 'environmental_label']).reset_index()

        overlays = []
        for name, group_df in flat_df.groupby(['label', 'environmental_label']):
            fan_charts = [
                self.fan_chart_quantile_median(group_df, column)
                for column in group_df.columns
                if column not in id_columns
            ]
            overlay = hv.Overlay(fan_charts).opts(
                title=f'{self.kpi_subset_name()} : {name}',
                show_legend=False,
                ylabel='value',
            )
            overlays.append(overlay)
        return hv.Layout(overlays).cols(1)

    def view(self):
        """View the selected simulation results.

        Returns a Panel column with a heading and an accordion holding the
        parameter widgets with the results table, the runs overview, the run
        comparison table, the violin plots and the fan charts.
        """
        results_column = pn.Column('## Simulation Results DataFrame', self.view_results_dataframe)
        sections = pn.Accordion(
            ('Select Parameters', pn.Row(self, results_column)),
            ('Runs Overview', self.view_runs_overview),
            ('Run Comparisons', self.styled_run_descriptor_comparison),
            ('Run Comparisons Violin', self.view_violin_kpis),
            ('Run Comparisons Fan Charts', self.view_fan_kpis),
        )
        dashboard = pn.Column(
            """
            ## Simulation Analysis Dashboard
            """,
            sections,
        )
        # Only the first accordion section is expanded initially.
        sections.active = [0]
        return dashboard


# Usage
# Build the dashboard model; __init__ populates the dataset options for the
# default experiment, loads the last (sorted) matching dataset and registers
# the KPI subsets.
s = Simulation()
# Changing the experiment re-runs the dataset lookup and reloads sim_df
# through the @pm.depends(..., watch=True) watchers.
s.experiment = 'sanity_check_run'
# Alternative experiments to inspect:
# s.experiment = 'standard_stochastic_run'
# s.experiment = 'reference_subsidy_sweep'
# Show up to 20 rows in the results table.
s.max_rows = 20
# Restrict tables and charts to the fees-and-issuance KPI columns.
s.kpi_subset = s.param.kpi_subset.objects['fees_and_issuance']
# Keep a handle on the loaded frame for ad-hoc inspection.
df = s.sim_df
# Last expression of the cell: the returned Panel layout.
s.view()