In [61]:
import numpy as np
from scipy.optimize import curve_fit
from scipy.stats import chisquare, norm
from sklearn.utils import resample
import pandas as pd

import matplotlib.pyplot as plt
from matplotlib import gridspec
from matplotlib import rc
from IPython.display import display, Latex
rc('font', size=20)

## Generate data

In [62]:
N = int(1e6)
mu_true, sig_true = 0, 1
mu_gen, sig_gen = 0.5, 1
smearing = 0.0
epsilon = 1e-10
ibu_iterations = 20
n_bootstraps = 100

streams = ['truth', 'data', 'gen', 'sim']
bins = np.load("bins.npy")
n_bins = len(bins)


# Generate data
truth = np.random.normal(mu_true, sig_true, N)
data = np.random.normal(truth, smearing)
gen = np.random.normal(mu_gen, sig_gen, N)
sim = np.random.normal(gen, smearing)

In [44]:
def digitize_data(data):
    digitized = np.digitize(data, bins)  # Digitize the data
    clipped = np.clip(digitized, 1, n_bins) - 1  # Clip and adjust indices
    return clipped

def IBU(prior, data_marginal, alt_response_matrix, n_iterations):
    posterior = [prior]
    for i in range(n_iterations):
        m = alt_response_matrix * posterior[-1]
        m /= (m.sum(axis=1)[:,np.newaxis] + epsilon)
        posterior.append(m.T @ data_marginal)
    return posterior[-1]

def response_matrix(x_list, y_list):
    H, _, _ = np.histogram2d(x_list, y_list, bins=[range(n_bins), range(n_bins)])
    return H

In [45]:
df = pd.DataFrame({'indices': [resample(np.arange(N), replace=True) for _ in range(n_bootstraps)]})

for stream in streams:
    df[stream + '_bs'] = [globals()[stream][indices] for indices in df['indices']]
    df[stream + '_digitized'] = df[stream + '_bs'].apply(digitize_data)

In [46]:
df['H_truth_data'] = df.apply(lambda row: response_matrix(row['truth_digitized'], row['data_digitized']), axis=1)
df['H_gen_sim'] = df.apply(lambda row: response_matrix(row['gen_digitized'], row['sim_digitized']), axis=1)
df['H_truth_data_normalized'] = df['H_truth_data'].apply(
    lambda H: H / (H.sum(axis=1, keepdims=True) + epsilon)
)
df['H_gen_sim_normalized'] = df['H_gen_sim'].apply(
    lambda H: H / (H.sum(axis=1, keepdims=True) + epsilon)
)

In [47]:
df['truth_marginal'] = df['H_truth_data'].apply(lambda H: np.sum(H, axis=1))
df['data_marginal'] = df['H_truth_data'].apply(lambda H: np.sum(H, axis=0))
df['gen_marginal'] = df['H_gen_sim'].apply(lambda H: np.sum(H, axis=1))
df['sim_marginal'] = df['H_gen_sim'].apply(lambda H: np.sum(H, axis=0))

In [48]:
df['ibu'] = df.apply(lambda row: IBU(row['gen_marginal'], row['data_marginal'],
                                     row['H_gen_sim_normalized'].T, ibu_iterations), axis=1)

In [26]:
df

Unnamed: 0,indices,truth_bs,truth_digitized,data_bs,data_digitized,gen_bs,gen_digitized,sim_bs,sim_digitized,H_truth_data,H_gen_sim,H_truth_data_normalized,H_gen_sim_normalized,truth_marginal,data_marginal,gen_marginal,sim_marginal,ibu
0,"[976312, 951225, 852004, 943600, 424110, 25003...","[-0.05861208604927202, -0.5718349949410215, 2....","[2, 2, 4, 3, 3, 2, 3, 2, 2, 2, 2, 2, 2, 2, 1, ...","[-0.05861208604927202, -0.5718349949410215, 2....","[2, 2, 4, 3, 3, 2, 3, 2, 2, 2, 2, 2, 2, 2, 1, ...","[1.6209940571623176, 1.4675018823773627, 0.479...","[3, 3, 3, 3, 1, 3, 3, 1, 3, 1, 3, 3, 4, 4, 2, ...","[1.6209940571623176, 1.4675018823773627, 0.479...","[3, 3, 3, 3, 1, 3, 3, 1, 3, 1, 3, 3, 4, 4, 2, ...","[[6753.0, 0.0, 0.0, 0.0, 0.0], [0.0, 146966.0,...","[[1487.0, 0.0, 0.0, 0.0, 0.0], [0.0, 62514.0, ...","[[0.9999999999999852, 0.0, 0.0, 0.0, 0.0], [0....","[[0.9999999999999327, 0.0, 0.0, 0.0, 0.0], [0....","[6753.0, 146966.0, 477948.0, 324358.0, 43975.0]","[6753.0, 146966.0, 477948.0, 324358.0, 43975.0]","[1487.0, 62514.0, 371970.0, 449295.0, 114734.0]","[1487.0, 62514.0, 371970.0, 449295.0, 114734.0]","[6752.9999999999, 146965.9999999999, 477947.99..."
1,"[73784, 246732, 538139, 577503, 868271, 245673...","[0.9278326041669247, 0.2700689266589992, 0.094...","[3, 2, 2, 1, 2, 3, 2, 3, 2, 2, 4, 2, 2, 2, 3, ...","[0.9278326041669247, 0.2700689266589992, 0.094...","[3, 2, 2, 1, 2, 3, 2, 3, 2, 2, 4, 2, 2, 2, 3, ...","[0.021194968001291004, 1.3246320086312537, 2.4...","[2, 3, 4, 4, 1, 4, 3, 3, 3, 1, 4, 2, 3, 3, 4, ...","[0.021194968001291004, 1.3246320086312537, 2.4...","[2, 3, 4, 4, 1, 4, 3, 3, 3, 1, 4, 2, 3, 3, 4, ...","[[6797.0, 0.0, 0.0, 0.0, 0.0], [0.0, 147138.0,...","[[1464.0, 0.0, 0.0, 0.0, 0.0], [0.0, 62980.0, ...","[[0.9999999999999852, 0.0, 0.0, 0.0, 0.0], [0....","[[0.9999999999999316, 0.0, 0.0, 0.0, 0.0], [0....","[6797.0, 147138.0, 478364.0, 323974.0, 43727.0]","[6797.0, 147138.0, 478364.0, 323974.0, 43727.0]","[1464.0, 62980.0, 371888.0, 449167.0, 114501.0]","[1464.0, 62980.0, 371888.0, 449167.0, 114501.0]","[6796.9999999999, 147137.9999999999, 478363.99..."
2,"[646599, 677225, 8887, 195303, 141941, 968149,...","[-0.21661300904738584, -0.5212281420781996, -0...","[2, 2, 2, 2, 2, 3, 2, 2, 3, 3, 1, 1, 3, 3, 3, ...","[-0.21661300904738584, -0.5212281420781996, -0...","[2, 2, 2, 2, 2, 3, 2, 2, 3, 3, 1, 1, 3, 3, 3, ...","[1.6463222805342295, 0.01031506832307122, -1.6...","[3, 2, 1, 2, 3, 3, 2, 3, 1, 3, 2, 1, 1, 1, 3, ...","[1.6463222805342295, 0.01031506832307122, -1.6...","[3, 2, 1, 2, 3, 3, 2, 3, 1, 3, 2, 1, 1, 1, 3, ...","[[6812.0, 0.0, 0.0, 0.0, 0.0], [0.0, 146329.0,...","[[1484.0, 0.0, 0.0, 0.0, 0.0], [0.0, 63159.0, ...","[[0.9999999999999853, 0.0, 0.0, 0.0, 0.0], [0....","[[0.9999999999999326, 0.0, 0.0, 0.0, 0.0], [0....","[6812.0, 146329.0, 478735.0, 324117.0, 44007.0]","[6812.0, 146329.0, 478735.0, 324117.0, 44007.0]","[1484.0, 63159.0, 370832.0, 449988.0, 114537.0]","[1484.0, 63159.0, 370832.0, 449988.0, 114537.0]","[6811.9999999999, 146328.9999999999, 478734.99..."
3,"[390312, 922523, 655439, 463953, 724824, 50661...","[0.20029947483247537, 1.1422377019511862, -0.0...","[2, 3, 2, 2, 4, 2, 2, 2, 3, 3, 3, 1, 1, 2, 3, ...","[0.20029947483247537, 1.1422377019511862, -0.0...","[2, 3, 2, 2, 4, 2, 2, 2, 3, 3, 3, 1, 1, 2, 3, ...","[-0.38001910431043817, 1.02871751550596, -0.45...","[2, 3, 2, 2, 2, 3, 2, 4, 4, 3, 4, 2, 3, 4, 3, ...","[-0.38001910431043817, 1.02871751550596, -0.45...","[2, 3, 2, 2, 2, 3, 2, 4, 4, 3, 4, 2, 3, 4, 3, ...","[[6589.0, 0.0, 0.0, 0.0, 0.0], [0.0, 146430.0,...","[[1486.0, 0.0, 0.0, 0.0, 0.0], [0.0, 62970.0, ...","[[0.9999999999999848, 0.0, 0.0, 0.0, 0.0], [0....","[[0.9999999999999327, 0.0, 0.0, 0.0, 0.0], [0....","[6589.0, 146430.0, 479211.0, 324237.0, 43533.0]","[6589.0, 146430.0, 479211.0, 324237.0, 43533.0]","[1486.0, 62970.0, 371888.0, 449430.0, 114226.0]","[1486.0, 62970.0, 371888.0, 449430.0, 114226.0]","[6588.9999999999, 146429.9999999999, 479210.99..."
4,"[29581, 279755, 611066, 771592, 743077, 242226...","[-0.5519163502279943, 1.0072642309567115, 1.04...","[2, 3, 3, 2, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 1, ...","[-0.5519163502279943, 1.0072642309567115, 1.04...","[2, 3, 3, 2, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 1, ...","[0.8761838572618561, 2.6091374481283545, 1.143...","[3, 4, 3, 1, 3, 4, 3, 3, 4, 3, 4, 1, 2, 4, 2, ...","[0.8761838572618561, 2.6091374481283545, 1.143...","[3, 4, 3, 1, 3, 4, 3, 3, 4, 3, 4, 1, 2, 4, 2, ...","[[6683.0, 0.0, 0.0, 0.0, 0.0], [0.0, 146589.0,...","[[1540.0, 0.0, 0.0, 0.0, 0.0], [0.0, 62728.0, ...","[[0.999999999999985, 0.0, 0.0, 0.0, 0.0], [0.0...","[[0.999999999999935, 0.0, 0.0, 0.0, 0.0], [0.0...","[6683.0, 146589.0, 478965.0, 323535.0, 44228.0]","[6683.0, 146589.0, 478965.0, 323535.0, 44228.0]","[1540.0, 62728.0, 372027.0, 448839.0, 114866.0]","[1540.0, 62728.0, 372027.0, 448839.0, 114866.0]","[6682.9999999999, 146588.9999999999, 478964.99..."
5,"[362938, 627342, 619741, 5063, 94190, 363575, ...","[-1.003515553217308, -1.3108948036529815, 1.06...","[2, 1, 3, 2, 2, 4, 2, 1, 2, 2, 3, 2, 2, 2, 2, ...","[-1.003515553217308, -1.3108948036529815, 1.06...","[2, 1, 3, 2, 2, 4, 2, 1, 2, 2, 3, 2, 2, 2, 2, ...","[0.8184566550343402, 0.5040750090106555, 0.856...","[3, 3, 3, 2, 1, 3, 3, 2, 2, 2, 3, 2, 2, 2, 3, ...","[0.8184566550343402, 0.5040750090106555, 0.856...","[3, 3, 3, 2, 1, 3, 3, 2, 2, 2, 3, 2, 2, 2, 3, ...","[[6775.0, 0.0, 0.0, 0.0, 0.0], [0.0, 146389.0,...","[[1501.0, 0.0, 0.0, 0.0, 0.0], [0.0, 62863.0, ...","[[0.9999999999999852, 0.0, 0.0, 0.0, 0.0], [0....","[[0.9999999999999334, 0.0, 0.0, 0.0, 0.0], [0....","[6775.0, 146389.0, 478536.0, 324174.0, 44126.0]","[6775.0, 146389.0, 478536.0, 324174.0, 44126.0]","[1501.0, 62863.0, 372154.0, 449273.0, 114209.0]","[1501.0, 62863.0, 372154.0, 449273.0, 114209.0]","[6774.9999999999, 146388.9999999999, 478535.99..."
6,"[975830, 711468, 327039, 605626, 855517, 77804...","[-0.08599778777590482, 0.6382596922654075, 0.2...","[2, 3, 2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 2, 3, 3, ...","[-0.08599778777590482, 0.6382596922654075, 0.2...","[2, 3, 2, 1, 3, 2, 1, 3, 2, 1, 3, 2, 2, 3, 3, ...","[0.9131249138166577, -1.049910980644239, -0.02...","[3, 1, 2, 3, 2, 3, 3, 3, 4, 3, 2, 3, 3, 4, 3, ...","[0.9131249138166577, -1.049910980644239, -0.02...","[3, 1, 2, 3, 2, 3, 3, 3, 4, 3, 2, 3, 3, 4, 3, ...","[[6754.0, 0.0, 0.0, 0.0, 0.0], [0.0, 146721.0,...","[[1502.0, 0.0, 0.0, 0.0, 0.0], [0.0, 62865.0, ...","[[0.9999999999999852, 0.0, 0.0, 0.0, 0.0], [0....","[[0.9999999999999334, 0.0, 0.0, 0.0, 0.0], [0....","[6754.0, 146721.0, 478821.0, 324065.0, 43639.0]","[6754.0, 146721.0, 478821.0, 324065.0, 43639.0]","[1502.0, 62865.0, 370992.0, 450378.0, 114263.0]","[1502.0, 62865.0, 370992.0, 450378.0, 114263.0]","[6753.9999999999, 146720.9999999999, 478820.99..."
7,"[656138, 208283, 309859, 529246, 208139, 35244...","[-0.8138112598796872, 0.8354768870348188, -0.2...","[2, 3, 2, 2, 2, 3, 2, 3, 2, 1, 2, 4, 3, 1, 2, ...","[-0.8138112598796872, 0.8354768870348188, -0.2...","[2, 3, 2, 2, 2, 3, 2, 3, 2, 1, 2, 4, 3, 1, 2, ...","[0.8291987842847963, 1.8458272503021158, 1.191...","[3, 4, 3, 3, 3, 3, 3, 3, 2, 4, 4, 2, 3, 3, 3, ...","[0.8291987842847963, 1.8458272503021158, 1.191...","[3, 4, 3, 3, 3, 3, 3, 3, 2, 4, 4, 2, 3, 3, 3, ...","[[6619.0, 0.0, 0.0, 0.0, 0.0], [0.0, 147227.0,...","[[1512.0, 0.0, 0.0, 0.0, 0.0], [0.0, 62919.0, ...","[[0.9999999999999849, 0.0, 0.0, 0.0, 0.0], [0....","[[0.9999999999999338, 0.0, 0.0, 0.0, 0.0], [0....","[6619.0, 147227.0, 477953.0, 324497.0, 43704.0]","[6619.0, 147227.0, 477953.0, 324497.0, 43704.0]","[1512.0, 62919.0, 371543.0, 449734.0, 114292.0]","[1512.0, 62919.0, 371543.0, 449734.0, 114292.0]","[6618.9999999999, 147226.9999999999, 477952.99..."
8,"[189533, 370101, 430791, 741900, 210244, 11148...","[1.246499287613725, -0.49688383784090284, -1.8...","[3, 2, 1, 3, 2, 3, 2, 2, 3, 2, 3, 2, 1, 3, 2, ...","[1.246499287613725, -0.49688383784090284, -1.8...","[3, 2, 1, 3, 2, 3, 2, 2, 3, 2, 3, 2, 1, 3, 2, ...","[-1.0691052105238137, -0.5109490838240505, 0.8...","[1, 2, 3, 2, 3, 3, 2, 3, 4, 3, 3, 3, 3, 2, 4, ...","[-1.0691052105238137, -0.5109490838240505, 0.8...","[1, 2, 3, 2, 3, 3, 2, 3, 4, 3, 3, 3, 3, 2, 4, ...","[[6630.0, 0.0, 0.0, 0.0, 0.0], [0.0, 146551.0,...","[[1474.0, 0.0, 0.0, 0.0, 0.0], [0.0, 62696.0, ...","[[0.9999999999999849, 0.0, 0.0, 0.0, 0.0], [0....","[[0.9999999999999322, 0.0, 0.0, 0.0, 0.0], [0....","[6630.0, 146551.0, 478835.0, 323996.0, 43988.0]","[6630.0, 146551.0, 478835.0, 323996.0, 43988.0]","[1474.0, 62696.0, 371790.0, 449678.0, 114362.0]","[1474.0, 62696.0, 371790.0, 449678.0, 114362.0]","[6629.9999999999, 146550.9999999999, 478834.99..."
9,"[813240, 951334, 446938, 880796, 539709, 36590...","[0.3941804309743496, 0.7530748409637149, 0.330...","[3, 3, 2, 3, 2, 2, 3, 2, 1, 3, 2, 2, 3, 2, 2, ...","[0.3941804309743496, 0.7530748409637149, 0.330...","[3, 3, 2, 3, 2, 2, 3, 2, 1, 3, 2, 2, 3, 2, 2, ...","[-1.6033072281117176, 1.0258136112177754, -1.3...","[1, 3, 1, 3, 3, 3, 3, 3, 2, 4, 1, 2, 3, 2, 2, ...","[-1.6033072281117176, 1.0258136112177754, -1.3...","[1, 3, 1, 3, 3, 3, 3, 3, 2, 4, 1, 2, 3, 2, 2, ...","[[6700.0, 0.0, 0.0, 0.0, 0.0], [0.0, 146719.0,...","[[1512.0, 0.0, 0.0, 0.0, 0.0], [0.0, 62962.0, ...","[[0.9999999999999851, 0.0, 0.0, 0.0, 0.0], [0....","[[0.9999999999999338, 0.0, 0.0, 0.0, 0.0], [0....","[6700.0, 146719.0, 478883.0, 324246.0, 43452.0]","[6700.0, 146719.0, 478883.0, 324246.0, 43452.0]","[1512.0, 62962.0, 371896.0, 449422.0, 114208.0]","[1512.0, 62962.0, 371896.0, 449422.0, 114208.0]","[6699.9999999999, 146718.9999999999, 478882.99..."


In [52]:
plt.figure(figsize=(15, 50))  # Adjust the figure size as needed
n_plots = 10
for i in range(n_plots, 19):
    # Directly access the DataFrame for the required data
    truth_marginal = df['truth_marginal'].iloc[i]
    ibu = df['ibu'].iloc[i]
    ratio = ibu / (truth_marginal + epsilon)

    # Create a subplot for each sample
    ax1 = plt.subplot(n_plots, 2, 2*i + 1)  # 2 columns, n_samples rows, 1st column for i-th sample
    ax1.plot(truth_marginal, label='Truth Marginal', marker='o')
    ax1.plot(ibu, label='IBU Posterior', marker='x')
    ax1.set_title(f'Sample {i+1}')
    ax1.legend()
    ax1.grid(True)

    ax2 = plt.subplot(n_plots, 2, 2*i + 2)  # 2nd column for i-th sample
    ax2.plot(ratio, label='IBU/Truth Ratio', marker='o', color='red')
    ax2.axhline(y=1, color='gray', linestyle='--')  # Reference line at y=1
    ax2.legend()
    ax2.grid(True)

plt.tight_layout()
plt.show()

ValueError: num must be an integer with 1 <= num <= 20, not 21

<Figure size 1500x5000 with 0 Axes>

In [37]:
ibu_matrix = np.stack(df['ibu'])
average_ibu = np.mean(ibu_matrix, axis=0)
print("Averaged IBU array:", average_ibu)

covariance_matrix = np.cov(ibu_matrix, rowvar=False)
print("Covariance matrix of IBU values:\n", np.round(covariance_matrix, 2))

Averaged IBU array: [  6711.2 146705.9 478625.1 324119.9  43837.9]
Covariance matrix of IBU values:
 [[  6235.51   -913.31  -8452.69  -2854.64   5985.13]
 [  -913.31  98062.99 -97768.1   26158.21 -25539.79]
 [ -8452.69 -97768.1  178458.99 -57349.99 -14888.21]
 [ -2854.64  26158.21 -57349.99  68133.88 -34087.46]
 [  5985.13 -25539.79 -14888.21 -34087.46  68530.32]]


In [34]:
print(ibu_matrix)

[[  6753. 146966. 477948. 324358.  43975.]
 [  6797. 147138. 478364. 323974.  43727.]
 [  6812. 146329. 478735. 324117.  44007.]
 [  6589. 146430. 479211. 324237.  43533.]
 [  6683. 146589. 478965. 323535.  44228.]
 [  6775. 146389. 478536. 324174.  44126.]
 [  6754. 146721. 478821. 324065.  43639.]
 [  6619. 147227. 477953. 324497.  43704.]
 [  6630. 146551. 478835. 323996.  43988.]
 [  6700. 146719. 478883. 324246.  43452.]]


In [39]:
df['H_gen_sim_normalized'][0]

array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]])