In [None]:
import sys, importlib, time
sys.path.append('../../')

from src.Modules.Utils.Imports import *
device = torch.device(GetLowestGPU(pick_from=[0]))

import src.Modules.Loaders.DataFormatter as DF
from src.DE_simulation import DE_sim
from scipy.interpolate import RBFInterpolator
import numpy.matlib as matlib

from src.get_params import get_heterog_LHC_params
from src.custom_functions import to_torch, to_numpy, load_model, recover_binn_params, unique_inputs, MSE

### Plotting defaults shared by the figure cells below
fontsize = 24
prop_cycle = plt.rcParams["axes.prop_cycle"]
colors = prop_cycle.by_key()["color"]  # matplotlib's active default color cycle

### Location of the ABM data files and the prefix common to all their names
path = "../../data/"
filename_header = "adhesion_pulling_mean_25"

### Trained-BINN settings used to rebuild saved model / simulation file names
save_folder = "../../results/weights/"
model_name = "DMLP"
weight = "_best_val"
pde_weight = 10_000.0

def BINN_diffusion_mesh(T_mesh, params):
    """
    Evaluate a trained BINN's diffusion network on a mesh of density values.

    The BINN trained for this parameter combination is loaded from
    ``save_folder``; its file name is built from ``filename_header``,
    ``model_name`` and ``pde_weight``. Its ``diffusion`` network is then
    evaluated on ``T_mesh``.

    Parameters:
        T_mesh (np.ndarray): Mesh of agent density values at which D is evaluated.
        params (sequence of 5 numbers): Model parameters (rmh, rmp, Padh, Ppull, alpha):
            - adhesion agent migration `rmh`
            - pulling agent migration `rmp`
            - adhesion strength `Padh`
            - pulling strength `Ppull`
            - proportion of adhesion agents `alpha`

    Returns:
        np.ndarray: Diffusion values predicted by the BINN for T_mesh and params.
    """
    rmh, rmp, Padh, Ppull, alpha = params
    file_name = f'{filename_header}_PmH_{rmh}_PmP_{rmp}_Padh_{Padh}_Ppull_{Ppull}_alpha_{alpha}'

    binn_name = model_name
    save_name = f"BINN_training_{binn_name}_{file_name}_pde_weight_{pde_weight}"
    # load_model returns (model, binn); only the BINN's diffusion network is needed here.
    _, binn = load_model(binn_name=binn_name, save_name=save_folder + save_name, x=1.0, t=1.0)

    # NOTE(review): `[:, None]` inserts an extra axis; with the (N, 1) T_mesh used in
    # this notebook the network input is (N, 1, 1) -- confirm this is intended.
    D_mesh = to_numpy(binn.diffusion(to_torch(T_mesh)[:, None]))

    return D_mesh

def simulate_interpolant_PDE(params, f):
    """
    Simulate the PDE model with a diffusion function interpolated by ``f``.

    The ABM data for ``params`` is loaded from ``path``. Its first time snapshot
    is the initial condition. The PDE is solved on the data's own spatial and
    temporal grids, with diffusion D(T) = max(f([T; params]), 0).

    Parameters:
        params (sequence of 5 numbers): Model parameters (rmh, rmp, Padh, Ppull, alpha):
            - adhesion agent migration `rmh`
            - pulling agent migration `rmp`
            - adhesion strength `Padh`
            - pulling strength `Ppull`
            - proportion of adhesion agents `alpha`
        f (callable): Interpolant that maps rows [T, rmh, rmp, Padh, Ppull, alpha]
            (a 2-D array, one row per density value) to diffusion values,
            e.g. a scipy ``RBFInterpolator``.

    Returns:
        x (np.ndarray): Spatial grid points.
        t (np.ndarray): Time points.
        sol (np.ndarray): Solution of the PDE with interpolated diffusion.
        U (np.ndarray): ABM data reshaped to (len(x), -1), for comparison.
    """
    rmh, rmp, Padh, Ppull, alpha = params
    params_row = np.asarray(params)

    def interpolated_diffusion(T):
        """Return max(f([T; params]), 0) for density values T (assumed 1-D)."""
        # One row [T_i, rmh, rmp, Padh, Ppull, alpha] per density value.
        param_mesh = np.tile(params_row, (len(T), 1))
        sampled_points_mesh = np.hstack([T[:, None], param_mesh])
        # The interpolant may undershoot; negative diffusivities are clipped to zero.
        return np.maximum(f(sampled_points_mesh), 0)

    ### Load in ABM data
    file_name = f'{filename_header}_PmH_{rmh}_PmP_{rmp}_Padh_{Padh}_Ppull_{Ppull}_alpha_{alpha}'
    inputs, outputs, _ = DF.load_ABM_data(path + file_name + ".npy", plot=False)
    x, t = unique_inputs(inputs)
    U = outputs.reshape((len(x), -1))
    # The first time snapshot of the data is the PDE initial condition.
    IC = U[:, 0]

    sol = DE_sim(x,
                 t,
                 [],
                 IC,
                 Diffusion_function=interpolated_diffusion)

    return x, t, sol, U

def MSE_computation(params, f, t_perc=0.75):
    """
    Compute training and testing MSE of the interpolated-PDE prediction of ABM data.

    The PDE is simulated with ``simulate_interpolant_PDE``. Time points with
    t <= t_perc * max(t) form the training window; later time points form
    the testing window.

    Parameters:
        params (sequence of 5 numbers): Model parameters (rmh, rmp, Padh, Ppull, alpha).
        f (callable): Interpolant mapping [T; params] rows to diffusion values.
        t_perc (float, optional): Fraction of the final time that ends the
            training window. Defaults to 0.75.

    Returns:
        MSE_train (float): MSE over the training time window.
        MSE_test  (float): MSE over the testing time window.
    """
    x, t, sol, U = simulate_interpolant_PDE(params, f)

    # Compute the time masks once and apply them to both the data and the solution.
    t_cutoff = t_perc * np.max(t)
    train_mask = t <= t_cutoff
    test_mask = t > t_cutoff

    MSE_train = MSE(sol[:, train_mask], U[:, train_mask])
    MSE_test = MSE(sol[:, test_mask], U[:, test_mask])

    return MSE_train, MSE_test

In [None]:
'''
Predicting data from the Pulling & Adhesion ABM.

In this notebook, we predict new data from the Pulling & Adhesion
ABM while varying Padh, Ppull, and alpha and fixing rmp = 1.0 and rmh = 0.25.
'''

# Parameter sets: rows of (rmh, rmp, Padh, Ppull, alpha), rounded to 3 decimals.
# NOTE(review): presumably rounded so the formatted values match saved file names.
params_old = np.round(get_heterog_LHC_params("Training"), 3)  # prior (training) sets
params_new = np.round(get_heterog_LHC_params("Testing"), 3)   # new (testing) sets

# Density values T (column vector) at which each prior BINN's diffusion is sampled.
T_mesh = np.linspace(0, 1, 101)[:, None]

### Build interpolation data: inputs [T; params] and outputs D(T; params)
### from the BINN trained on each prior parameter set.
D_blocks = []
point_blocks = []
for params in params_old:
    # D(T; params) from the trained BINN
    D_blocks.append(BINN_diffusion_mesh(T_mesh, params))
    # [T; params]: the parameter row repeated once per density value
    param_mesh_tmp = np.tile(params, (len(T_mesh), 1))
    point_blocks.append(np.hstack([T_mesh, param_mesh_tmp]))

# Stack once at the end instead of re-allocating with np.vstack on every iteration.
D_mesh = np.squeeze(np.vstack(D_blocks))
sampled_points = np.vstack(point_blocks)

### interpolator matches [T;params] -> D(T;params)
f = RBFInterpolator(sampled_points, D_mesh, kernel="linear")

### Save to file (np.save stores the dict as a pickled object array)
data = {'f': f}
np.save("../../results/PDE_sims/adhesion_pulling_interpolant.npy", data)

In [None]:
### Plot old (prior) and new parameter sets in (p_adh, p_pull, alpha) space,
### i.e. columns 2, 3, 4 of the (rmh, rmp, Padh, Ppull, alpha) rows.

fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(projection='3d')
# Raw strings keep the LaTeX backslashes literal without invalid-escape warnings.
ax.scatter(params_old[:, 2], params_old[:, 3], params_old[:, 4], s=250, color="orange",
           label=r"$\mathcal{P}^{ \ prior}$")
ax.scatter(params_new[:, 2], params_new[:, 3], params_new[:, 4], s=250, color="green", marker="x",
           label=r"$\mathcal{P}^{ \  new}$")

# Rebinds the notebook-level `fontsize`; later cells that read it see this value.
fontsize = 22
plt.legend(fontsize=20)
plt.xlabel("$p_{adh}$", fontsize=fontsize)
plt.ylabel("$p_{pull}$", fontsize=fontsize)
ax.set_zlabel(r"$\alpha$", fontsize=fontsize)
plt.xticks(np.arange(0, .61, 0.2), fontsize=16)
plt.yticks(np.arange(0, .61, 0.2), fontsize=16)
z_ticks = np.arange(0, 1.1, 0.2)
ax.set_zticks(z_ticks)
ax.set_zticklabels(np.round(z_ticks, 2), fontsize=16)
plt.title("Pulling & Adhesion ABM parameter collections", fontsize=fontsize)
plt.subplots_adjust(left=0)
plt.savefig("../../results/figures/Heterog_parameter_samplings.pdf", format="pdf")

In [None]:
### Training/testing MSE of the interpolated PDE on every new parameter set
MSE_binn_trains, MSE_binn_tests = [], []

for params in params_new:
    MSE_train, MSE_test = MSE_computation(params, f)
    MSE_binn_trains += [MSE_train]
    MSE_binn_tests += [MSE_test]

### Store as arrays so they can be fancy-indexed by a sort order
MSE_binn_trains, MSE_binn_tests = np.array(MSE_binn_trains), np.array(MSE_binn_tests)

## Plot MSE values

In [None]:
### Bar chart of interpolated-PDE MSEs on the new parameter sets, ordered by training MSE
MSE_sort_index = np.argsort(MSE_binn_trains)
bar_width = 0.33
bar_x = np.arange(len(MSE_binn_trains))

plt.figure(figsize=(8, 6), layout="constrained")

plt.bar(bar_x, MSE_binn_trains[MSE_sort_index], color=colors[2], width=bar_width, log=True,
        edgecolor="k", label="Interpolated training MSE")
# Testing bars sit beside the training bars and are distinguished by hatching.
plt.bar(bar_x + bar_width, MSE_binn_tests[MSE_sort_index], color=colors[2], width=bar_width, log=True,
        hatch=".", edgecolor="k", label="Interpolated testing MSE")

# 1-based sample ranks, centred between each training/testing bar pair.
xlabels = [f"{i}" for i in np.arange(1, len(MSE_binn_trains) + 1)]
plt.xticks(bar_x + bar_width / 2, labels=xlabels, rotation=0, fontsize=16)
# NOTE(review): hard-coded y ticks assume the MSEs lie around 1e-4..4e-4.
plt.yticks([1e-4, 2e-4, 3e-4, 4e-4], rotation=50, fontsize=20)

plt.grid(axis="y", linewidth=0.25)

plt.xlabel("LHC Samples", fontsize=fontsize)
plt.ylabel("Mean-squared error (MSE)", fontsize=fontsize)
plt.title("Error in predicting the \nPulling & Adhesion ABM", fontsize=fontsize)
plt.legend(fontsize=20)

plt.savefig("../../results/figures/heterog_interpolation.pdf", format="pdf")

## Plot example diffusion rates

In [None]:
fontsize = 20

# Interpolated D(T; params) for the new parameter sets at sorted positions 0, 10 and -1
# of MSE_sort_index (lowest training MSE, 11th lowest, highest).
for i in [0, 10, -1]:

    plt.figure()
    params = params_new[MSE_sort_index[i]]
    rmh, rmp, Padh, Ppull, alpha = params

    ### Create [T;params]
    param_mesh = matlib.repmat(params, len(T_mesh), 1)
    sampled_points_mesh = np.hstack([T_mesh, param_mesh])
    ### Compute D(T;params)
    D_mesh = f(sampled_points_mesh)
    ### Set negative values to zero
    D_mesh[D_mesh < 0] = 0

    plt.plot(T_mesh, D_mesh, c=colors[2], linewidth=4)
    plt.grid()
    plt.xlim([0, 0.75])
    plt.ylim([0, 1.0])

    plt.xlabel("Agent density (T)", fontsize=fontsize)
    plt.ylabel("Diffusion rate, $D^{interp}(T)$", fontsize=fontsize)
    # Built-in floats keep the title as plain numbers; NumPy >= 2 would otherwise
    # render each tuple element as np.float64(...).
    param_values = tuple(float(p) for p in (rmh, rmp, Padh, Ppull, alpha))
    plt.title(r"$(r_m^H, r_m^P, p_{adh}, p_{pull}, \alpha)$ = " + f"\n{param_values}",
              fontsize=fontsize)
    # Lay out after labels and title exist so they are included in the computation.
    plt.tight_layout(pad=2)



### Make table of LHC samples

In [None]:
### Print a LaTeX tabular of the new parameter sets, ordered by increasing training MSE.
# Raw strings keep LaTeX backslashes literal (no invalid-escape warnings); output is unchanged.
print(r"    \centering")
print(r"    \begin{tabular}{|c|c|}")
print(r"    \hline")
print(r"    Sample &   \Pm = $\ (r_m^{pull},\ r_m^{adh},\ p_{pull},\ p_{adh},\ \alpha)^T$ \\ \hline")
for rank, index in enumerate(MSE_sort_index, start=1):
    rmh, rmp, Padh, Ppull, alpha = params_new[index]
    # Columns follow the header order: (rmp, rmh, Ppull, Padh, alpha).
    print(f"    {rank} &   ({rmp}, {rmh}, {Ppull}, {Padh}, {alpha})$^T$ \\\\ \\hline")
print(r"    \end{tabular}")

In [None]:
### Compute training MSE of the saved BINN-guided PDE simulations on the old (prior) dataset.
# NOTE(review): the training window here is t <= 750, whereas MSE_computation uses
# t <= 0.75 * max(t); the two agree only if the final time is 1000 -- confirm.
T_TRAIN_MAX = 750
MSE_vals = []
for params in params_old:

    rmh, rmp, Padh, Ppull, alpha = params
    # File name built from the notebook config constants (identical to the former literal path).
    sim_file = (f"../../results/PDE_sims/PDE_sim_{model_name}_{filename_header}"
                f"_PmH_{rmh}_PmP_{rmp}_Padh_{Padh}_Ppull_{Ppull}_alpha_{alpha}"
                f"_pde_weight_{pde_weight}.npy")
    # Saved simulations are pickled dicts, hence allow_pickle=True and .item().
    sim_results = np.load(sim_file, allow_pickle=True).item()

    U_data = sim_results['U_data']
    U_sim = sim_results['U_sim']
    t = sim_results['t']

    train_mask = t <= T_TRAIN_MAX
    MSE_vals.append(MSE(U_data[:, train_mask], U_sim[:, train_mask]))

### Sort by increasing training MSE order
MSE_old_train_sort_index = np.argsort(MSE_vals)

# Raw strings keep LaTeX backslashes literal (no invalid-escape warnings); output is unchanged.
print(r"    \centering")
print(r"    \begin{tabular}{|c|c|}")
print(r"    \hline")
print(r"    Sample &   \Pm = $\ (r_m^{pull},\ r_m^{adh},\ p_{pull},\ p_{adh},\ \alpha)^T$ \\ \hline")
for rank, index in enumerate(MSE_old_train_sort_index, start=1):
    rmh, rmp, Padh, Ppull, alpha = params_old[index]
    # Columns follow the header order: (rmp, rmh, Ppull, Padh, alpha).
    print(f"    {rank} &   ({rmp}, {rmh}, {Ppull}, {Padh}, {alpha})$^T$ \\\\ \\hline")
print(r"    \end{tabular}")