In [None]:
# Imports and global plotting configuration:

import pandas as pd
import os
import wandb
import numpy as np
import matplotlib.pyplot as plt
# plotting:
from tueplots import bundles
from tueplots import figsizes

import utils

# Use the ICML 2024 tueplots style, but without requiring a LaTeX installation.
bundle = bundles.icml2024()
bundle['text.usetex'] = False
# The LaTeX preamble is only meaningful with usetex; tolerate bundles that lack the key.
bundle.pop('text.latex.preamble', None)
plt.rcParams.update(bundle)


# Directory that the figures of this notebook are saved to.
FIGURE_SAVEDIR = 'Experiment_figures/'
os.makedirs(FIGURE_SAVEDIR, exist_ok=True)


In [None]:
# Load one file (index 1) from each of the four scenarios and keep the last
# element of what the loader returns for it.

path = '../data/Exp_0_jaxcpm'

datas = [
    utils.load_data_from_file(os.path.join(path, scenario, f'all_cpms_{i}.npz'))[-1]
    for scenario in ['scenario_a', 'scenario_b', 'scenario_d', 'scenario_f']
    for i in range(1, 2)
]



In [None]:
# Plot the loaded samples side by side, one panel per scenario type.
rgb_rows = [
    (0., 0., 0.),                             # black
    (0., 0., 0.25),                           # dark blue
    (1., 0., 0.),                             # red
    (204. / 255., 255. / 255., 11. / 255.),   # light green
]
# Each colour is passed on as a (1, 3) array.
colors = [np.array([rgb]) for rgb in rgb_rows]


types = ['Type A', 'Type B', 'Type D', 'Type F']
fig, axs = plt.subplots(
    1, 4,
    figsize=(4, 1.),
    gridspec_kw={'wspace': 0.05, 'hspace': 0.05},
)

for i, data in enumerate(datas):
    ax = axs[i]
    utils.plot_cell_image(data, ax, colors=colors)
    # Make sure the axes frame is drawn, but without any ticks.
    ax.axis('on')
    ax.set_xticks([])
    ax.set_yticks([])
    # ax.set_xlabel(types[i])

# Write the type name under each panel, positioned in figure coordinates.
for x_pos, caption in zip((0.225, 0.42, 0.615, 0.81), types):
    fig.text(x_pos, 0.05, caption, ha='center', va='center', fontsize=8)

# Row label on the left-most panel.
axs[0].set_ylabel(
    '\nCell sorting',
    fontsize=8,
    labelpad=10,
    rotation=90,
    va="center",
    ha="center",
)

plt.tight_layout()
plt.savefig(FIGURE_SAVEDIR + 'exp0_data.png', transparent=True, dpi=400)
plt.show()

In [None]:
# Shape of the first loaded sample.
datas[0].shape

In [None]:
# Largest value in the first slice (index 0 along the first axis) of the first sample.
datas[0][0].max()

In [None]:
# Authenticate with Weights & Biases before the runs are queried below.
wandb.login()

In [None]:

# Query every run of the project through the wandb public API.
entity, project = 'neuralcpm', 'NeuralCPM'
api = wandb.Api()
runs = api.runs(f"{entity}/{project}")


In [None]:
# summary_list, config_list, name_list = [], [], []
# for run in runs:
#     # .summary contains output keys/values for
#     # metrics such as accuracy.
#     #  We call ._json_dict to omit large files
#     summary_list.append(run.summary._json_dict)
#
#     # .config contains the hyperparameters.
#     #  We remove special values that start with _.
#     config_list.append({k: v for k, v in run.config.items() if not k.startswith("_")})
#
#     # .name is the human-readable name of the run.
#     name_list.append(run.name)

In [None]:
# Map run name -> logged history for every run whose name contains 'Exp0/'.
dict_with_dfs = {
    run.name: run.history()
    for run in runs
    if 'Exp0/' in run.name
}

In [None]:
def _single_row_frame(**params):
    """Build a one-row DataFrame with one column per keyword argument."""
    return pd.DataFrame({name: [value] for name, value in params.items()})


# Ground-truth parameter values per scenario, one single-row DataFrame each.
true_parameters = {
    'Sce_a': _single_row_frame(
        lamb_vol=0.1,
        J_0_0=0.,
        J_0_1=0.5,
        J_0_2=0.5,
        J_1_1=0.333333,
        J_1_2=0.2,
        J_2_2=0.266667,
    ),
    'Sce_b': _single_row_frame(
        lamb_vol=0.5,
        J_0_0=0.,
        J_0_1=2.5,
        J_0_2=1.,
        J_1_1=1.,
        J_1_2=4.5,
        J_2_2=1.,
    ),
    'Sce_d': _single_row_frame(
        lamb_vol=0.1,
        J_0_0=0.,
        J_0_1=15,
        J_0_2=7.5,
        J_1_1=4.,
        J_1_2=7.5,
        J_2_2=4.,
    ),
    'Sce_f': _single_row_frame(
        lamb_vol=0.05,
        J_0_0=0.,
        J_0_1=2,
        J_0_2=8,
        J_1_1=7.,
        J_1_2=5.5,
        J_2_2=3.,
    ),
}

In [None]:
# Display the ground-truth parameters stored under 'Sce_d'.
true_parameters['Sce_d']

In [None]:
def fit_T_and_calc_error(scenario, sampler, dict_with_dfs, true_parameters, key=None):
    """
    Compare learned parameters with the ground truth, with and without a fitted temperature.

    As the temperature parameter might be poorly identifiable from static data, and also
    depends on the sampler, the true parameters are multiplied by a scalar T that is chosen
    by least squares (independently for every row of the learned history) such that the
    error to the learned parameters is minimized, i.e. we fit the optimal (inverse)
    temperature.

    :param scenario: scenario name; its first five characters (e.g. 'Sce_a') select the
        entry of ``true_parameters``.
    :param sampler: sampler name; only used to build the default ``key``.
    :param dict_with_dfs: dict mapping run names to DataFrames that hold one column per
        parameter and one row per history entry.
    :param true_parameters: dict mapping the scenario prefix to a single-row DataFrame
        with the ground-truth parameter values.
    :param key: run name to look up in ``dict_with_dfs``; defaults to
        ``f'Exp0/{scenario}/{sampler}'``.
    :return: tuple ``(T_opt, MSE, MSE_nofit)`` of 1-D arrays with one entry per history
        row: the fitted scalar, the mean squared error over the parameters with the fitted
        scalar, and the mean squared error with the scalar fixed to 1.
    """
    if key is None:
        key = f'Exp0/{scenario}/{sampler}'

    truth = true_parameters[scenario[:5]]
    vals_learned = dict_with_dfs[key][truth.columns].values.T  # (n_params, n_rows)
    vals_true = truth.values.T  # (n_params, 1)
    # The mean runs over the parameters (axis 0); axis 1 of vals_true always has length 1.
    n_params = vals_true.shape[0]

    # Solve vals_true * T ~= vals_learned for one scalar T per history row.
    T_opt, SSE, _rank, _singval = np.linalg.lstsq(vals_true, vals_learned, rcond=None)
    if SSE.size == 0:
        # lstsq returns no residuals for rank-deficient or non-overdetermined systems.
        SSE = np.linalg.norm(vals_true @ T_opt - vals_learned, axis=0)**2
    SSE_nofit = np.linalg.norm(vals_true - vals_learned, axis=0)**2
    MSE = SSE / n_params
    MSE_nofit = SSE_nofit / n_params
    return T_opt[0], MSE, MSE_nofit

In [None]:
# List the names of the runs whose history was collected.
dict_with_dfs.keys()


In [None]:
sampler_names = ['cpm']

# For every run name that contains one of the sampler names, keep the name
# without its first '/'-separated component.
scenarios_samplers = [
    "/".join(run_name.split('/')[1:])
    for run_name in dict_with_dfs.keys()
    for name in sampler_names
    if name in run_name
]


scenarios = ['Sce_a', 'Sce_b', 'Sce_d_seed0', 'Sce_f_seed0']



In [None]:

# One figure per scenario with the fitted-temperature error curve of each sampler.
for scenario in scenarios:
    plt.figure()
    # Entries belonging to this scenario, leaving out any whose name contains 'mcs'.
    selected = [s for s in scenarios_samplers if scenario in s and 'mcs' not in s]
    for scenario_sampler in selected:
        sampler = scenario_sampler.rsplit('/', 1)[-1]
        print(scenario, sampler)
        T, MSE, MSE_nofit = fit_T_and_calc_error(scenario, sampler, dict_with_dfs, true_parameters)
        # Report the error of the last history row, with and without the fitted scalar.
        for tag, err in (('logRMSE - T=T*', MSE), ('logRMSE - T=1', MSE_nofit)):
            print(tag, scenario, sampler, np.log10(np.sqrt(err[-1])))
        plt.plot(MSE, label=f'{sampler}')
    plt.yscale('log')
    # plt.ylim(1e-3, 1e1)
    plt.title(scenario)
    plt.legend(frameon=False)
    plt.show()

In [None]:
import pandas as pd

# Collect the final log-RMSE (with and without the fitted scalar) of every
# scenario/sampler pair, plotting the error curves along the way.
results = []

for scenario in scenarios:
    plt.figure()
    for scenario_sampler in scenarios_samplers:
        if scenario not in scenario_sampler or 'mcs' in scenario_sampler:
            continue
        sampler = scenario_sampler.split('/')[-1]
        print(scenario, sampler)
        T, MSE, MSE_nofit = fit_T_and_calc_error(scenario, sampler, dict_with_dfs, true_parameters)
        # Error of the last history row; computed once and reused for printing and the table.
        log_rmse_fit = np.log10(np.sqrt(MSE[-1]))
        log_rmse_nofit = np.log10(np.sqrt(MSE_nofit[-1]))
        # The printed quantity is log10(RMSE), so label it as such.
        print('logRMSE - T=T*', scenario, sampler, log_rmse_fit)
        print('logRMSE - T=1', scenario, sampler, log_rmse_nofit)
        plt.plot(MSE, label=f'{sampler}')
        results.append({
            'scenario': scenario,
            'sampler': sampler,
            'log-RMSE ($T=1$)': log_rmse_nofit,
            'log-RMSE ($T=T^*$)': log_rmse_fit,
        })
    plt.yscale('log')
    # plt.ylim(1e-3, 1e1)
    plt.title(scenario)
    plt.legend(frameon=False)
    plt.show()

results_df = pd.DataFrame(results)

In [None]:
# Pivot to one row per sampler with the columns grouped by scenario, then print as LaTeX.
latex_table = (
    results_df
    .set_index(['scenario', 'sampler'])
    .unstack(0)
    .swaplevel(axis=1)
    .sort_index(level=0, axis=1)
    .to_latex(column_format='ccccc')
)
print(latex_table)

In [None]:
# List the names of the runs whose history was collected.
dict_with_dfs.keys()

In [None]:
# Scenario selected for the MSE-vs-MCS plot in the following cell.
scenario = 'Sce_b'
sampler = 'gwg_0.5mcs'
sampler_base = sampler.split('_')[0]  # part before the first underscore, here 'gwg'
k = f'Exp0/{scenario}/{sampler}'  # run name in the 'Exp0/<scenario>/<sampler>' form


In [None]:

# Final fitted-temperature MSE of every matching run against its MCS per training step.
plt.figure()
for sampler in ['cpm']: # sampler_names
    mcs_sampler = []
    mse_sampler = []
    for k, df in dict_with_dfs.items():
        if sampler not in k or scenario not in k:
            continue
        print(k)
        # Positional access to the first logged value, independent of the index labels.
        # NOTE(review): the factor 100 / 200**2 presumably converts sampler steps to
        # Monte Carlo steps on a 200x200 lattice -- confirm.
        mcs = df['num steps'].iloc[0] * 100 / 200**2
        T, MSE, MSE_nofit = fit_T_and_calc_error(scenario, sampler, dict_with_dfs, true_parameters, key=k)
        MSE_final = MSE[-1]
        print('MSE - T=T*', scenario, sampler, mcs, MSE_final)
        print('MSE - T=1', scenario, sampler, mcs, MSE_nofit[-1])
        mcs_sampler.append(mcs)
        mse_sampler.append(MSE_final)
    # Sort by MCS so that the connecting line runs from left to right.
    idx_sort = np.argsort(np.array(mcs_sampler))
    plt.plot(np.array(mcs_sampler)[idx_sort], np.array(mse_sampler)[idx_sort], '-o', label=sampler)
plt.legend()
plt.yscale('log')
plt.xscale('log')
plt.ylabel('MSE')
plt.xlabel('MCS per training step')
plt.show()

In [None]:
# List the names of the runs whose history was collected.
dict_with_dfs.keys()

In [None]:
# For lamb_vol and all J_*_* values except J_0_0, plot the convergence of the
# temperature-rescaled learned value to the true value for the cpm sampler:

scenarios = [
    # 'Sce_b',
    # 'Sce_a',
    *[f'Sce_d_seed{i}' for i in range(4, 5)],
    # *[f'Sce_b_seed{i}' for i in range(1,5)]
]

# Final rescaled value of every parameter, one list entry per plotted run.
fitted_vals = {}


fsize = figsizes.icml2024_half(ncols=1, nrows=len(scenarios))
fsize['figure.figsize'] = (fsize['figure.figsize'][0] * 0.8, fsize['figure.figsize'][1] * 0.75)
with plt.rc_context(fsize):
    fig, axs = plt.subplots(len(scenarios), 1, sharex=True, squeeze=False)

    for ax, scenario in zip(axs.flatten(), scenarios):
        ax.set_ylabel('Parameter value')
        truth = true_parameters[scenario[:5]]
        for sampler in ['cpm']:
            for key, df in dict_with_dfs.items():
                # Only the run of exactly this scenario and sampler, without 'mcs' in its name.
                if scenario != key.split('/')[1] or sampler not in key or 'mcs' in key:
                    continue
                print(key)
                T, MSE, MSE_nofit = fit_T_and_calc_error(scenario, sampler, dict_with_dfs, true_parameters, key=key)
                print(T[-1])
                for col in truth.columns:
                    if 'J_0_0' in col:
                        continue
                    # Divide out the fitted scalar so the curve is comparable to the true value.
                    rescaled = df[col] / T
                    (line,) = ax.plot(rescaled.values, label=col)
                    # Dashed reference line at the true value, in the colour of the curve.
                    ax.hlines(truth[col], xmin=0, xmax=100, colors=line.get_color(), linestyle='--')
                    fitted_vals.setdefault(col, []).append(rescaled.iloc[-1])
    # plt.title(f'{scenario} {sampler}')
    handles, labels = plt.gca().get_legend_handles_labels()
    # Legend entries follow the plotting order: lamb_vol first, then the J columns without J_0_0.
    # Raw strings keep the backslash of the mathtext command intact.
    labels = [r'$\lambda$', r'$J(0,1)$', r'$J(0,2)$', r'$J(1,1)$', r'$J(1,2)$', r'$J(2,2)$']
    fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.5, 1), ncol=3, frameon=False)
    fig.supxlabel('Training iteration (x100)', fontsize=8)
    # The file name carries the last scenario of the list.
    plt.savefig(FIGURE_SAVEDIR + f'exp0_param_convergence_{scenario}.pdf', bbox_inches='tight')
    plt.show()


In [None]:
# Turn the per-parameter lists into a DataFrame: one column per parameter, one row per plotted run.
fitted_vals = pd.DataFrame(fitted_vals)
print(scenarios)

In [None]:
# Per-parameter mean of the final rescaled values over the plotted runs.
fitted_vals.mean()

In [None]:
# Per-parameter standard deviation of the final rescaled values over the plotted runs.
fitted_vals.std()

In [None]:
# Display the ground-truth parameters stored under 'Sce_f'.
true_parameters['Sce_f']