# fitting GP (simplex, multiprocessed)

In [53]:
from gpFittingFunctions import *

def multiprocessor():
    # Create a multiprocessing pool
    pool = mp.Pool(processes = 20)
    
    # Create a list of arguments for each iteration of the loop
    session_args = ['Raltz', 'Jirachi']
    
    # Execute the loop iterations in parallel
    results_mp = pool.map(process_animal, session_args)
    
    # Close the multiprocessing pool
    pool.close()
    return results_mp
    
if __name__ == "__main__":
    import multiprocessing as mp
    from supplementaryFunctions import *

    results = []
    num_opts = 10
    tic()
    
    # iterator
    for num in range(num_opts):
        results_mp = multiprocessor()
        
        for an, x, fun in results_mp:
            results.append({'animal':an, 'num_opt': num, 'params':x, 'nll':fun})
            print(results[-1])
        
    print('--------------------------------------')
    toc()

{'animal': 'Raltz', 'num_opt': 0, 'params': array([1.54256878e+00, 9.58975824e-01, 1.00000000e-06, 1.00000000e-03,
       1.96270117e-01]), 'nll': 0.0}
{'animal': 'Jirachi', 'num_opt': 0, 'params': array([5.00000000e+00, 1.00000000e-04, 1.41677535e-01, 1.57080740e-01,
       1.00000000e-04]), 'nll': 4809.423059440209}
{'animal': 'Raltz', 'num_opt': 1, 'params': array([1.00000000e-04, 1.00000878e-04, 4.95314467e-01, 4.19296513e+00,
       3.21994228e-01]), 'nll': 5261.889765196647}
{'animal': 'Jirachi', 'num_opt': 1, 'params': array([5.00000000e+00, 1.00000000e-04, 1.41678270e-01, 1.03914622e-01,
       1.34998020e-01]), 'nll': 4809.423059029519}
{'animal': 'Raltz', 'num_opt': 2, 'params': array([2.59189014e+00, 3.00372423e-01, 1.00000000e-06, 1.00000000e-03,
       1.00000000e-04]), 'nll': 0.0}
{'animal': 'Jirachi', 'num_opt': 2, 'params': array([1.00000000e-04, 1.00000000e-04, 1.00000000e+01, 1.00000000e+01,
       8.26936251e-01]), 'nll': 6940.980568975154}
{'animal': 'Raltz', 'num_o

# fitting GP (differential evolution)

In [1]:
from opconNosepokeFunctions import *                              # load sessdf
from gpFittingFunctions import *

from supplementaryFunctions import *

# load data first
sessdf = pd.read_csv('L:/4portProb_processed/sessdf.csv')
sessdf.drop(columns = 'Unnamed: 0', inplace = True)
window = 7
trialsinsess = 100
exclude = ['[ 20  20  20 100]', '[0 0 0 0]', '[0]', '[0 0]',
       '[1000   80]', '[30]', '[40]', '[70]']
sessdf = sessdf[~sessdf.rewprobfull.isin(exclude)]
sessdf = sessdf[~sessdf.duplicated(subset = ['animal', 'session', 'trialstart', 'eptime'], keep = False)]
sessdf = sessdf.groupby(['animal','session']).filter(lambda x: x.reward.size >= trialsinsess)

In [2]:
l_an = sessdf.animal.unique()
tic()
from scipy.optimize import differential_evolution
lower_bnd, upper_bnd = np.exp(-5), np.exp(5)
bounds = ((lower_bnd, upper_bnd),       # alpha = observation noise variance
          (lower_bnd, 1),       # beta = exploration coefficient
          (lower_bnd, 5),       # tau = for ucb softmax
          (lower_bnd, upper_bnd),      # length scale of the RBF kernel; larger values = more spatial generalization
          (lower_bnd, 1))      # intial q value in a session

result = {}
for an in l_an:
    filtered = sessdf[(sessdf.animal == an) & (sessdf.task == 'unstr')]
    filtered = filtered.groupby('session').head(trialsinsess)
    args = (filtered, an)
    result[an] = differential_evolution(nLL_gp_ucb,
                  args=args,
                  bounds=bounds, disp = True,
                                        workers = 20)
toc()

  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 6646.88
Polishing solution with 'L-BFGS-B'


  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 6199.62
differential_evolution step 2: f(x)= 6016.07
differential_evolution step 3: f(x)= 6016.07
differential_evolution step 4: f(x)= 6016.07
differential_evolution step 5: f(x)= 6016.07
differential_evolution step 6: f(x)= 3397.12
differential_evolution step 7: f(x)= 3397.12
differential_evolution step 8: f(x)= 3397.12
differential_evolution step 9: f(x)= 3005.09
differential_evolution step 10: f(x)= 3005.09
differential_evolution step 11: f(x)= 3005.09
differential_evolution step 12: f(x)= 2812.19
differential_evolution step 13: f(x)= 2610.16
differential_evolution step 14: f(x)= 2610.16
differential_evolution step 15: f(x)= 2502.4
differential_evolution step 16: f(x)= 2366.45
differential_evolution step 17: f(x)= 2366.45
differential_evolution step 18: f(x)= 2335.96
differential_evolution step 19: f(x)= 2332.35
differential_evolution step 20: f(x)= 2332.35
differential_evolution step 21: f(x)= 2332.35
differential_evolution step 22: f(x)= 2332.3

  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 7000.79
Polishing solution with 'L-BFGS-B'


  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 6995.11
Polishing solution with 'L-BFGS-B'


  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 6610.18
Polishing solution with 'L-BFGS-B'


  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 6957.27
Polishing solution with 'L-BFGS-B'


  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 6919.3
Polishing solution with 'L-BFGS-B'


  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 6692.11
Polishing solution with 'L-BFGS-B'


  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 6991.76
Polishing solution with 'L-BFGS-B'


  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 6984.04
Polishing solution with 'L-BFGS-B'


  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 6980.16
Polishing solution with 'L-BFGS-B'


  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 6752.79
Polishing solution with 'L-BFGS-B'


  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 3441.53
differential_evolution step 2: f(x)= 3441.53
differential_evolution step 3: f(x)= 3441.53
differential_evolution step 4: f(x)= 3441.53
differential_evolution step 5: f(x)= 3441.53
differential_evolution step 6: f(x)= 2387.39
differential_evolution step 7: f(x)= 1156.52
differential_evolution step 8: f(x)= 1156.52
differential_evolution step 9: f(x)= 1156.52
differential_evolution step 10: f(x)= 1108.77
differential_evolution step 11: f(x)= 1108.77
differential_evolution step 12: f(x)= 1003.89
differential_evolution step 13: f(x)= 980.395
differential_evolution step 14: f(x)= 980.395
differential_evolution step 15: f(x)= 908.094
differential_evolution step 16: f(x)= 908.094
differential_evolution step 17: f(x)= 788.112
differential_evolution step 18: f(x)= 788.112
differential_evolution step 19: f(x)= 746.995
differential_evolution step 20: f(x)= 746.995
differential_evolution step 21: f(x)= 734.219
differential_evolution step 22: f(x)= 724.3

  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 6970.13
Polishing solution with 'L-BFGS-B'


  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 6888.94
Polishing solution with 'L-BFGS-B'


  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np

differential_evolution step 1: f(x)= 6740.25
Polishing solution with 'L-BFGS-B'


  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 6983.36
Polishing solution with 'L-BFGS-B'


  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 6968.02
Polishing solution with 'L-BFGS-B'


  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np.exp(gamma*m[t] + beta_star*sd[t])
  P = P/ np.sum(P)
  P = np

differential_evolution step 1: f(x)= 6956.72
Polishing solution with 'L-BFGS-B'


  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 6948.25
Polishing solution with 'L-BFGS-B'


  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 6574.42
differential_evolution step 2: f(x)= 6574.42
differential_evolution step 3: f(x)= 6524.93
differential_evolution step 4: f(x)= 5647.18
differential_evolution step 5: f(x)= 5647.18
differential_evolution step 6: f(x)= 5647.18
differential_evolution step 7: f(x)= 3194.03
differential_evolution step 8: f(x)= 1083.45
differential_evolution step 9: f(x)= 1083.45
differential_evolution step 10: f(x)= 1083.45
differential_evolution step 11: f(x)= 1083.45
differential_evolution step 12: f(x)= 1083.45
differential_evolution step 13: f(x)= 1083.45
differential_evolution step 14: f(x)= 983.269
differential_evolution step 15: f(x)= 908.846
differential_evolution step 16: f(x)= 908.846
differential_evolution step 17: f(x)= 908.846
differential_evolution step 18: f(x)= 908.846
differential_evolution step 19: f(x)= 873.156
differential_evolution step 20: f(x)= 858.677
differential_evolution step 21: f(x)= 829.286
differential_evolution step 22: f(x)= 829.2

  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 7000.78
Polishing solution with 'L-BFGS-B'


  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 6989.73
Polishing solution with 'L-BFGS-B'


  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 6757.38
Polishing solution with 'L-BFGS-B'


  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 6781.71
Polishing solution with 'L-BFGS-B'


  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 6953.8
Polishing solution with 'L-BFGS-B'


  with DifferentialEvolutionSolver(func, bounds, args=args,


differential_evolution step 1: f(x)= 6783.33
Polishing solution with 'L-BFGS-B'
Elapsed time: 5081.420522 seconds.



In [3]:
l = []
for an in l_an:
    l.append(result[an]['x'])

In [6]:
import xarray as xr

In [9]:
{for row in np.array(l).T}

array([[4.08313042e+01, 2.39602024e+00, 1.25750236e+02, 6.29992649e+01,
        1.39297667e+01, 7.21683614e+01, 5.52041005e+00, 6.73794700e-03,
        1.30946050e+02, 9.75984539e+01, 1.04358952e+02, 1.40839832e+01,
        1.19008362e+00, 1.48413159e+02, 1.29448530e+00, 1.48413159e+02,
        2.35775724e+00, 6.93090818e+01, 1.48413159e+02, 6.73794700e-03,
        1.91270830e+01, 1.48413159e+02, 4.98499893e+00, 1.48413159e+02,
        8.73970409e-01, 4.97598246e+00, 2.09184465e+00],
       [6.73794700e-03, 6.73794700e-03, 6.73794700e-03, 6.73794700e-03,
        6.73794700e-03, 6.73794700e-03, 6.33033458e-01, 6.73794700e-03,
        6.73794700e-03, 6.73794700e-03, 6.73794700e-03, 6.73794700e-03,
        6.73794700e-03, 6.73794700e-03, 6.73794700e-03, 6.73794700e-03,
        6.73794700e-03, 1.00000000e+00, 6.73794700e-03, 6.73794700e-03,
        6.73794700e-03, 6.73794700e-03, 6.73794700e-03, 6.73794700e-03,
        6.73794700e-03, 6.73794700e-03, 6.73794700e-03],
       [8.93305656e-02

In [50]:
zero = pd.DataFrame(result).T[pd.DataFrame(result).T.fun==0].x

In [51]:
zero.to_dict()

{'Raltz': array([5.21594406e-01, 6.51275757e+00, 1.39792601e-03, 6.60554981e+00,
        9.13674946e-01]),
 'Zacian': array([3.00250448e+00, 7.59844706e+00, 1.66918010e-03, 6.90923939e+00,
        5.39703559e-02]),
 'Blissey': array([3.56064904e+00, 8.72631496e+00, 2.60487255e-03, 2.79024056e+00,
        6.48000397e-01]),
 'Alakazam': array([3.85696879e+00, 3.76967190e+00, 1.16602007e-03, 4.49432579e+00,
        9.08391655e-01]),
 'Xatu': array([2.44045037e+00, 8.25596180e+00, 1.59785347e-03, 7.10753659e-01,
        4.68335032e-01]),
 'Hoppip': array([5.19398469e-01, 8.08451573e+00, 1.00319848e-03, 1.52821518e+00,
        3.03539620e-01]),
 'Togepi': array([1.71304395e+00, 1.00000000e+01, 1.00000000e-03, 1.66167634e+00,
        6.00188225e-01]),
 'Inkay': array([4.20714573e+00, 9.19225144e+00, 2.37402072e-03, 7.15397511e+00,
        5.65081686e-01]),
 'Bayleef': array([2.40692346e+00, 9.81869584e+00, 2.86581828e-03, 2.14779513e+00,
        4.25449811e-01])}

In [23]:
result

 message: Optimization terminated successfully.
 success: True
     fun: 5.505815070549008
       x: [ 5.000e+00  1.000e-04  6.214e-02  1.649e-01  2.412e-01]
     nit: 38
    nfev: 3129
     jac: [-2.918e-01  1.193e+00 -2.135e-03  1.421e-06  0.000e+00]