## Example GSLS Histograms

In [None]:
import os

import numpy as np
import pandas as pd
import plotly.express as px
import scipy.stats

In [None]:
def hist_df(label, hist):
    """Build a long-format DataFrame describing one histogram for plotting.

    Args:
        label: Value written into every row's ``label`` column (plot_histogram
            colours bars by this column).
        hist: 1-D array-like of bar heights (NumPy array, list, tuple, ...).

    Returns:
        DataFrame with columns ``index`` (bin position 0..n-1), ``hist`` (the
        bar heights) and ``label``.
    """
    # np.asarray lets plain Python sequences through as well; indexing
    # ``hist.shape[0]`` directly only worked for objects with a ``shape``.
    heights = np.asarray(hist)
    return pd.DataFrame({
        'index': range(len(heights)),
        'hist': heights,
        'label': label,
    })

def merge_dfs(*dfs):
    """Stack DataFrames vertically, in the order given, into one frame.

    Args:
        *dfs: DataFrames to concatenate (here, typically ``hist_df`` outputs
            sharing the same columns).

    Returns:
        A single DataFrame whose row index is a fresh 0..n-1 range.
    """
    # ignore_index=True: each hist_df frame carries its own 0..n-1 index, so a
    # plain concat would repeat row labels and make label-based lookups ambiguous.
    return pd.concat(dfs, ignore_index=True)

def plot_histogram(hist, colours=None, **kwargs):
    """Render a long-format histogram frame as a bare, transparent bar chart.

    Args:
        hist: DataFrame with ``index``, ``hist`` and ``label`` columns (the
            shape produced by ``hist_df`` / ``merge_dfs``).
        colours: Optional colour sequence used for the distinct ``label``
            values; ``None`` leaves plotly's default palette in place.
        **kwargs: Forwarded unchanged to ``plotly.express.bar``.

    Returns:
        The plotly figure, with axes, gridlines, zero lines, legend, background
        and bar outlines all turned off, and a 10px margin on every side.
    """
    figure = px.bar(
        hist,
        x='index',
        y='hist',
        color='label',
        color_discrete_sequence=colours,
        **kwargs,
    )

    # Both axes get identical treatment: fully hidden, no grid, no zero line.
    hidden_axis = dict(showgrid=False, zeroline=False, visible=False)
    figure.update_xaxes(**hidden_axis)
    figure.update_yaxes(**hidden_axis)

    transparent = 'rgba(0, 0, 0, 0)'
    figure.update_layout(
        showlegend=False,
        plot_bgcolor=transparent,
        paper_bgcolor=transparent,
        margin={'l': 10, 'r': 10, 't': 10, 'b': 10, 'pad': 0},
    )

    # Magic-underscore form of marker=dict(line=dict(width=0)): no bar outlines.
    figure.update_traces(marker_line_width=0)
    return figure

In [None]:
bins = 10

# Bin heights of the shared component: one reproducible (seed 100) draw from a
# flat Dirichlet, i.e. a random probability vector over the bins.
remain_hist = np.random.RandomState(100).dirichlet(np.ones(bins))


def _normalised_gaussian(n_bins, loc, scale):
    """Sample a normal pdf at bin positions 0..n_bins-1 and rescale to sum 1."""
    density = scipy.stats.norm.pdf(np.arange(n_bins), loc=loc, scale=scale)
    return density / density.sum()


# Loss component: peaked at 30% of the bin range, mixed in with weight 0.2.
loss_hist = _normalised_gaussian(bins, loc=bins * 0.3, scale=bins / 10)
loss_weight = 0.2

# Gain component: peaked at 70% of the bin range, mixed in with weight 0.4.
gain_hist = _normalised_gaussian(bins, loc=bins * 0.7, scale=bins / 10)
gain_weight = 0.4

In [None]:
# Stacked frames: the source pairs (1 - loss_weight) * P^R with
# loss_weight * P^-, and the target pairs (1 - gain_weight) * P^R with
# gain_weight * P^+. The three components are also plotted on their own.
source_df = merge_dfs(
    hist_df('$P^R$', (1 - loss_weight) * remain_hist),
    hist_df('$P^-$', loss_weight * loss_hist),
)
loss_df = hist_df('$P^-$', loss_hist)
remain_df = hist_df('$P^R$', remain_hist)
gain_df = hist_df('$P^+$', gain_hist)
target_df = merge_dfs(
    hist_df('$P^R$', (1 - gain_weight) * remain_hist),
    hist_df('$P^+$', gain_weight * gain_hist),
)

loss_color = '#731717'
gain_color = '#61F261'
remain_color = '#6161F2'

# (output file stem, frame, colours). Colours are listed in the order the
# labels appear in each frame, which is the order the colour sequence is used.
figures_to_export = [
    ('source-dist', source_df, [remain_color, loss_color]),
    ('loss-dist', loss_df, [loss_color]),
    ('remain-dist', remain_df, [remain_color]),
    ('gain-dist', gain_df, [gain_color]),
    ('target-dist', target_df, [remain_color, gain_color]),
]

# write_image does not create missing directories, so make sure the output
# folder exists; otherwise a fresh checkout fails with FileNotFoundError.
os.makedirs('plots', exist_ok=True)

for stem, df, colours in figures_to_export:
    fig = plot_histogram(df, colours=colours)
    fig.write_image(f'plots/{stem}.svg')
    fig.show()