Navigation Menu

Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
nicktfranklin committed Aug 9, 2018
1 parent 489267c commit 6f3799d
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions get_bayesian_gp_means_std.py
Expand Up @@ -3,7 +3,7 @@
from tqdm import tqdm
from scipy.special import logsumexp
from scipy.stats import norm
from get_noise_nmll import get_noise_nmll
from get_kalman import get_noise_nmll


def get_means_stdev(x_mu_rbf, x_sd_rbf, x_mu_lin, x_sd_lin, x_mu_kal, x_sd_kal,
Expand Down Expand Up @@ -95,8 +95,7 @@ def get_means_stdev(x_mu_rbf, x_sd_rbf, x_mu_lin, x_sd_lin, x_mu_kal, x_sd_kal,

return pd.concat(all_subj)


## N.B. each experiment needs a seperate function to prepare it's own data
# N.B. each experiment needs a seperate function to prepare it's own data

def exp_lin():
lin_gp_data = pd.read_csv('Data/exp_linear/linpred.csv')
Expand All @@ -106,7 +105,7 @@ def exp_lin():
rbf_gp_data.index = range(len(rbf_gp_data))

raw_data = pd.read_csv('Data/exp_linear/lindata.csv')
rewards = raw_data[ 'out'].values
rewards = raw_data['out'].values

noise_nmll = get_noise_nmll(raw_data_path='Data/exp_linear/lindata.csv')

Expand All @@ -120,7 +119,6 @@ def exp_lin():
lin_gp_data = lin_gp_data[lin_gp_data.id != s].copy()
noise_nmll = noise_nmll[noise_nmll.Subject != s].copy()


x_mu_rbf = np.array([rbf_gp_data.loc[:, 'mu_ %d' % ii].values for ii in range(8)]).T
x_sd_rbf = np.array([rbf_gp_data.loc[:, 'sig_ %d' % ii].values for ii in range(8)]).T

Expand Down Expand Up @@ -293,9 +291,10 @@ def exp_scrambled():
)
all_subjs.to_pickle('Data/exp_scrambled/bayes_gp_exp_scram.pkl')


if __name__ == "__main__":
# exp_lin()
exp_lin()
exp_shifted()
# exp_cp()
# exp_srs()
# exp_scrambled()
exp_cp()
exp_srs()
exp_scrambled()

0 comments on commit 6f3799d

Please sign in to comment.