In [1]:
import os
import sys
import scipy
from numba import njit, prange
import numpy as np
import scipy.stats as stats
import bayesflow as bf
import matplotlib.pyplot as plt
import seaborn as sns
from multiprocessing import Pool
import pickle
import tensorflow as tf
import pandas as pd 
from sklearn.covariance import EmpiricalCovariance
import priors_and_simulators as ps

# Suppress scientific notation when printing numpy floats
np.set_printoptions(suppress=True)
# NOTE(review): RNG is not passed to the prior/simulator or used anywhere visible in
# this notebook — confirm whether it should be, for reproducible draws.
RNG = np.random.default_rng(2023)

# Settings
# Project data root. Override with the DDM_DATA_ROOT environment variable to run the
# notebook on another machine; the default is the original author's location.
DATA_ROOT = os.environ.get('DDM_DATA_ROOT', '/home/mischa/Documents/bayesflow/prj_real_life_ddm/data/')

# Path to data (trailing separator kept: later cells build file paths by string concatenation)
PATH = os.path.join(DATA_ROOT, 'prepared_data', '')

# Where to save files (trailing separator kept for the same reason)
PATH_TO_SAVE = os.path.join(DATA_ROOT, 'pickle_st0_sv', '')

2024-08-12 11:01:59.624563: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-08-12 11:01:59.646091: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from tqdm.autonotebook import tqdm


# Load neural networks from checkpoint

In [2]:
# Get network settings

# LaTeX labels for the 8 model parameters, in the order the prior returns them.
PARAM_NAMES = [
    r"$v_{congruent}$",
    r"$v_{incongruent}$",
    r"$a_{congruent}$",
    r"$a_{incongruent}$",
    r"$\tau_{correct}$",
    r"$\tau_{error}$",
    r"$s_{\tau}$",
    r"$s_{v}$"
]

prior = bf.simulation.Prior(prior_fun=ps.sv_st0_ddm_prior_fun, param_names=PARAM_NAMES)

# Standardization constants used by `configurator` and to de-standardize posterior
# samples. They must equal the values used when the loaded checkpoint was trained.
# NOTE(review): the prior draws are not seeded here; the rounding to one decimal
# presumably keeps these constants reproducible across runs — confirm.
prior_means, prior_stds = prior.estimate_means_and_stds(n_draws=10000)
prior_means = np.round(prior_means, decimals=1)
prior_stds = np.round(prior_stds, decimals=1)
# A std that rounds to 0 would turn the standardization into a division by zero.
assert np.all(prior_stds > 0), f"a prior std rounded to zero: {prior_stds}"

simulator = bf.simulation.Simulator(simulator_fun=ps.sv_st0_ddm_simulator_fun)

model = bf.simulation.GenerativeModel(prior=prior, simulator=simulator, name="DDM")

def configurator(forward_dict):
    """Map a GenerativeModel forward pass to the inputs the amortizer expects.

    Parameters
    ----------
    forward_dict : dict
        Generative-model output; must contain "sim_data" and "prior_draws".

    Returns
    -------
    dict
        "summary_conditions": the simulated data, passed through unchanged.
        "parameters": the prior draws cast to float32, then z-scored with the
        module-level ``prior_means`` and ``prior_stds``.
    """
    theta = forward_dict["prior_draws"].astype(np.float32)
    theta_standardized = (theta - prior_means) / prior_stds
    return {
        "summary_conditions": forward_dict["sim_data"],
        "parameters": theta_standardized,
    }

# Summary network over the simulated data; input_dim=4 matches the last axis of the
# simulation batch shown in the pilot-run log, (batch_size, 120, 4).
summary_net = bf.networks.SetTransformer(
    input_dim=4,
    summary_dim=30,
    name="ddm_summary",
)

# Invertible inference network over the (standardized) model parameters.
inference_net = bf.networks.InvertibleNetwork(
    num_params=len(prior.param_names),
    coupling_settings=dict(dense_args={"kernel_regularizer": None}, dropout=False),
    name="ddm_inference",
)

amortizer = bf.amortizers.AmortizedPosterior(
    inference_net,
    summary_net,
    name="ddm_amortizer",
    summary_loss_fun='MMD',
)

# The architecture above must match the networks saved under checkpoint_path, from
# which the weights are loaded (see the section header) — TODO confirm on change.
trainer = bf.trainers.Trainer(
    generative_model=model,
    amortizer=amortizer,
    configurator=configurator,
    checkpoint_path="ddm_model_st0_sv_net2",
)

INFO:root:Performing 2 pilot runs with the DDM model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 8)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 120, 4)
INFO:root:No optional prior non-batchable context provided.
INFO:root:No optional prior batchable context provided.
INFO:root:No optional simulation non-batchable context provided.
INFO:root:No optional simulation batchable context provided.
2024-08-12 11:02:15.197235: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:995] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-08-12 11:02:15.225834: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:995] successful NUMA node read from SysFS had negative value (-1), but there must be at least on

# Run amortized inference

In [3]:
# Amortized inference over each prepared data chunk: sample posteriors, summarize
# them, compute Mahalanobis checks on the learned summary statistics, and pickle
# the results per chunk.

N_POSTERIOR_SAMPLES = 3000        # posterior draws per person
N_BATCHES = 50                    # each chunk is split into this many batches for the networks
MIN_PROPER_SAMPLES = 1000         # fewer proper draws for any parameter -> person excluded
POSITIVE_PARAMS = slice(2, None)  # parameters 3-8 (a, tau, s_tau, s_v) are positively bounded

# 1. Store all data-set chunk names in a list.
# NOTE(review): only entries 1 and 2 of the sorted listing are processed (entry 0 is
# skipped) — confirm this subset is intended before a full run.
datasets = sorted(os.listdir(PATH))[1:3]

# 2. For each chunk
for dataset_name in datasets:

    # 2.1 Load chunk (pickle can execute arbitrary code: only load trusted files)
    with open(PATH + str(dataset_name), "rb") as fh:
        loaded_pickle = pickle.load(fh)

    X_test = loaded_pickle['data_array']
    y_test = loaded_pickle['outcome_array']
    rt_summaries = loaded_pickle['rt_summaries']

    print(str(dataset_name) + " loaded")

    # 2.2 Estimate chunk. Axes are (persons, samples, params), as the indexing below assumes.
    samples_dm = np.concatenate(
        [amortizer.sample(input_dict={"summary_conditions": x}, n_samples=N_POSTERIOR_SAMPLES,
                          to_numpy=True)
         for x in np.array_split(X_test, N_BATCHES)],
        axis=0,
    )

    # Undo the parameter standardization used during training
    samples_dm = samples_dm * prior_stds + prior_means

    # Discard negative samples for positively bounded parameters.
    # `positive` is a view (basic slicing), so the assignment writes into samples_dm.
    positive = samples_dm[:, :, POSITIVE_PARAMS]
    positive[positive < 0] = np.nan

    # Share of all posterior values (every parameter, incl. the unbounded drift rates) set to NaN
    improper_pct = np.mean(np.isnan(samples_dm)) * 100
    print(f"{improper_pct:.3f}% improper samples rejected")

    print(str(dataset_name) + " inference done")

    # 2.3 Compute summaries of parameter posteriors: means, medians, stds, Q0.025, Q0.975
    estimates = ps.compute_summaries(samples_dm)

    # Exclude people with fewer than MIN_PROPER_SAMPLES proper posterior samples for at
    # least one parameter. n_proper is (persons, params); `excluded` is one flag per person.
    n_proper = np.sum(~np.isnan(samples_dm), axis=1)
    excluded = np.any(n_proper < MIN_PROPER_SAMPLES, axis=1)
    estimates[excluded, :] = np.nan

    # Count people (rows), not person-parameter pairs
    print(f"{int(np.sum(excluded))} people with <{MIN_PROPER_SAMPLES} proper samples "
          "(for at least one parameter) excluded")

    # 2.4 Get empirical Mahalanobis distances for summary statistics provided by network
    summary_statistics_empirical = np.concatenate(
        [trainer.amortizer.summary_net(x) for x in np.array_split(X_test, N_BATCHES)], axis=0)

    # Covariance is fitted in-sample on this chunk only. Note: sklearn's
    # EmpiricalCovariance.mahalanobis returns *squared* Mahalanobis distances.
    cov = EmpiricalCovariance().fit(summary_statistics_empirical)
    mahalanobis_empirical = cov.mahalanobis(summary_statistics_empirical)

    print(str(dataset_name) + " Mahalanobis check done")

    # 2.5 Store everything together (serialized, pickle.dump) as a dict
    dict_to_store = {'data_array': X_test, 'est_array': estimates, "outcome_array": y_test,
                     'mahalanobis': mahalanobis_empirical, 'rt_summaries': rt_summaries}
    # Context manager guarantees the pickle is flushed and the file closed
    with open(PATH_TO_SAVE + "estimates_" + str(dataset_name), "wb") as fh:
        pickle.dump(dict_to_store, fh)
    print(str(dataset_name) + " saving done")

# 3. Celebrate

prepared_False2004iat.p loaded
0.04% improper samples rejected
prepared_False2004iat.p inference done
0 people with <1000 proper samples (for at least one parameter) excluded
prepared_False2004iat.p Mahalanobis check done
prepared_False2004iat.p saving done
prepared_False2005iat.p loaded
0.045% improper samples rejected
prepared_False2005iat.p inference done
0 people with <1000 proper samples (for at least one parameter) excluded
prepared_False2005iat.p Mahalanobis check done
prepared_False2005iat.p saving done


# Create CSV of posterior summaries

In [6]:
# Collect the per-chunk estimate pickles into one CSV with one row per person.

# Column layout: ps.compute_summaries output (means, medians, stds, Q0.025, Q0.975,
# each over the 8 parameters), then outcome_array, rt_summaries, the Mahalanobis value
# (squared distance, as returned by sklearn), and the source file name.
param_columns = ["v_congruent", "v_incongruent", "a_congruent", "a_incongruent",
                 "tplus", "tminus", "st0", "sv"]
summary_suffixes = ["", "_median", "_std", "_q025", "_q975"]
rt_columns = [f"{condition}_{stat}"
              for condition in ["congruent", "incongruent", "word", "picture"]
              for stat in ["rt_correct", "rt_correct_sd", "rt_error", "rt_error_sd", "accuracy"]]
csv_columns = ([param + suffix for suffix in summary_suffixes for param in param_columns]
               + ["session_id", "age"]
               + rt_columns
               + ["mahalanobis_distance", "dataset"])

# Sorted so the row order is deterministic (os.listdir order is arbitrary)
datasets = sorted(os.listdir(PATH_TO_SAVE))

frames = []
for dataset in datasets:
    # pickle can execute arbitrary code: only load trusted files
    with open(PATH_TO_SAVE + str(dataset), "rb") as fh:
        pickles = pickle.load(fh)
    df_oneset = np.concatenate((pickles['est_array'], pickles['outcome_array'], pickles['rt_summaries'],
                                np.expand_dims(pickles['mahalanobis'], axis=1)), axis=1)
    frame = pd.DataFrame(data=df_oneset, columns=csv_columns[:-1])
    # Label only this chunk's rows; assigning on the accumulated frame would relabel
    # every earlier chunk with the current file name.
    frame["dataset"] = str(dataset)
    frames.append(frame)

    print(str(dataset) + " done")

# Concatenate once (repeated concat in the loop is quadratic)
df = pd.concat(frames, ignore_index=True)

df.to_csv("df_sv_st0_ddm.csv", index=False)

estimates_prepared_False2005iat.p done
estimates_prepared_False2004iat.p done
