In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import os
import pickle
import pandas as pd
import glob
import joblib as jl
import json

In [None]:
# Directory of one sweep run; it must contain `ode_results.pkl` and `config.json`.
# Set the SWEEP_ROOTDIR environment variable to analyse a different run without editing code.
rootdir = os.environ.get(
    "SWEEP_ROOTDIR",
    "/nfs/nhome/live/npatel/projects/Collective-learning-/ceph_logs/sweep_test/r_2_sigma1_sweep/2025-05-07_09-16-49/T6lr1.0r_11.0rho0delta0sigma21.0epsilon0.0baselineTrueepochs1000savepoints100",
)
# NOTE: read_pickle can execute arbitrary code stored in the file; only load trusted results.
df = pd.read_pickle(os.path.join(rootdir, 'ode_results.pkl'))
# `with` closes the config file as soon as it has been parsed.
with open(os.path.join(rootdir, 'config.json')) as config_file:
    configs = json.load(config_file)

epochs = configs['fixed']['epochs']
savepoints = configs['fixed']['savepoints']
# Log-spaced save steps in [1, epochs - 1]. The int cast produces duplicates at the
# low end, which np.unique removes, so there may be fewer than `savepoints` steps.
steps = np.unique(np.logspace(0, np.log10(epochs - 1), num=savepoints, dtype=int))
true_savepoints = steps.shape[0]

# Drop run-level settings; the remaining fixed keys are used as DataFrame index
# levels when filtering by the fixed parameters.
for run_key in ('epochs', 'savepoints', 'logdir', 'D'):
    configs['fixed'].pop(run_key, None)

In [None]:


def _param_array(section, name):
    """Return ``section[name]`` as a numpy array, or ``None`` if it is absent.

    The plural key ``name + 's'`` is accepted too, because other config lookups
    in this notebook (e.g. ``configs['fixed']['rhos']``) use plural keys.
    """
    for config_key in (name, name + 's'):
        if config_key in section:
            return np.array(section[config_key])
    return None


# For each parameter, `<name>_s` holds the swept values and `<name>` the fixed
# value. Whichever one does not apply (or is missing from the config) is None,
# so this cell never fails on a parameter the run did not use.
r_1_s = _param_array(configs['sweep'], 'r_1')
r_1 = None if r_1_s is not None else _param_array(configs['fixed'], 'r_1')

r_2_s = _param_array(configs['sweep'], 'r_2')
r_2 = None if r_2_s is not None else _param_array(configs['fixed'], 'r_2')

r_12_s = _param_array(configs['sweep'], 'r_12')
r_12 = None if r_12_s is not None else _param_array(configs['fixed'], 'r_12')

tau_1_s = _param_array(configs['sweep'], 'tau_1')
tau_1 = None if tau_1_s is not None else _param_array(configs['fixed'], 'tau_1')

tau_2_s = _param_array(configs['sweep'], 'tau_2')
tau_2 = None if tau_2_s is not None else _param_array(configs['fixed'], 'tau_2')


In [None]:
# Set the fixed parameters.
# Each fixed config key maps to the index-level name obtained by dropping its last
# character (the keys appear to be plurals such as 'rhos'); the first listed value is used.
fixed_param_names = list(configs['fixed'])
fixed_params = {name[:-1]: values[0] for name, values in configs['fixed'].items()}
print(fixed_params)
# Cross-section on all fixed levels at once; xs drops those levels from the result.
df_filtered = df.xs(tuple(fixed_params.values()), level=list(fixed_params.keys()))

In [None]:
# Print each swept parameter together with the values it takes in this sweep.
sweep_param_names = list(configs['sweep'].keys())
for param_name in sweep_param_names:
    # f-string (not a set literal) so the name prints plainly: "name: [v1, v2, ...]"
    print(f"{param_name}: {configs['sweep'][param_name]}")

In [None]:
# Preliminary visualization helper; optional and not needed for the analysis itself.
def plot_variable(df_filtered, variable, sweep_params, second_params, sweep_param_name, second_param_name, log_epochs):
    """Plot trajectories of one variable across a two-parameter sweep.

    This is a brief, preliminary visualization of the results and is absolutely
    not necessary. The gist is to plot sweeps of trajectories for one parameter
    while holding the other parameter fixed. This is repeated for a few values of
    the second parameter, with an individual subplot for each of those values.

    Typically ``sweep_params`` is the full list of values of one parameter, and
    ``second_params`` is a smaller list of values of the other parameter. N.B. the
    smaller list should be a subset of that parameter's full list.

    Parameters
    ----------
    df_filtered : pandas.DataFrame
        Results already filtered by the fixed parameters. Its index must contain
        levels named ``sweep_param_name`` and ``second_param_name``. The
        ``variable`` column must hold one trajectory (array-like) per row.
    variable : str
        Column to plot, e.g. ``'R'``, ``'Qr'`` or ``'Qi'``.
    sweep_params : iterable
        Values of the primary swept parameter; one curve per value.
    second_params : iterable
        Values of the secondary parameter; one subplot per value.
    sweep_param_name : str
        Index-level name of the primary parameter.
    second_param_name : str
        Index-level name of the secondary parameter.
    log_epochs : array-like
        Time points of the saved trajectories, used as the (log-scaled) x-axis.
        Must have the same length as each trajectory.

    Raises
    ------
    ValueError
        If ``second_params`` is empty.

    Parameter combinations missing from ``df_filtered`` are reported and skipped.
    """
    sweep_params = list(sweep_params)
    second_params = list(second_params)
    if not second_params:
        raise ValueError("second_params must contain at least one value")

    def _fmt(value):
        # Compact labels for numbers; bools and other values fall back to str().
        if isinstance(value, (int, float, np.number)) and not isinstance(value, bool):
            return f"{value:.3g}"
        return str(value)

    # Boolean row selection via level values works whatever other levels remain.
    sweep_level = df_filtered.index.get_level_values(sweep_param_name)
    second_level = df_filtered.index.get_level_values(second_param_name)

    n_rows = len(second_params)
    # squeeze=False keeps `axes` 2-D even when there is only one subplot.
    fig, axes = plt.subplots(n_rows, 1, figsize=(7, 3 * n_rows), sharex=True, squeeze=False)
    for ax, second_param in zip(axes[:, 0], second_params):
        in_subplot = second_level == second_param
        for sweep_param in sweep_params:
            trajectories = df_filtered.loc[in_subplot & (sweep_level == sweep_param), variable]
            if trajectories.empty:
                print(f"Skipping {sweep_param_name}={sweep_param}, "
                      f"{second_param_name}={second_param} (not found in data)")
                continue
            for trajectory in trajectories:
                ax.plot(log_epochs, np.asarray(trajectory),
                        label=f"{sweep_param_name}={_fmt(sweep_param)}")
        ax.set_title(f"{variable} vs time, {second_param_name}={_fmt(second_param)} "
                     f"(sweep over {sweep_param_name})")
        ax.set_ylabel(variable)
        ax.set_xscale('log')
        ax.grid(True)
        # Avoid matplotlib's "no artists with labels" warning on empty subplots.
        if ax.has_data():
            ax.legend(fontsize='small')
    axes[-1, 0].set_xlabel("Time Step")
    fig.tight_layout()
    plt.show()

# Example: plot R, Qr and Qi. The first swept parameter is drawn as curves, and a
# few values of the second swept parameter each get their own subplot.
assert len(configs['sweep']) >= 2, "plot_variable needs two swept parameters"
sweep_key, second_key = list(configs['sweep'])[:2]
sweep_values_to_plot = list(configs['sweep'][sweep_key])
all_second_values = list(configs['sweep'][second_key])
# Take every few-th value so the secondary list is a subset of the full one.
second_values_to_plot = all_second_values[::max(1, len(all_second_values) // 3)]

# Config keys may be plural (e.g. 'rhos') while index levels drop the trailing
# character; use whichever form is actually an index level of df_filtered.
level_names = list(df_filtered.index.names)
sweep_level_name, second_level_name = (
    key if key in level_names else key[:-1] for key in (sweep_key, second_key)
)

for variable_name in ('R', 'Qr', 'Qi'):
    plot_variable(df_filtered, variable_name, sweep_values_to_plot, second_values_to_plot,
                  sweep_level_name, second_level_name, steps)
