In [1]:
import torch
import os

if 'DISPLAY' in os.environ:
    del os.environ['DISPLAY']

import numpy as np
import dill
import matplotlib.pyplot as plt
import joblib
from distributed import Client
from dask_jobqueue import SLURMCluster
import dask
from joblib import Parallel, delayed
import time
import random
import pandas as pd
import hnn_core
from hnn_core import simulate_dipole, Network, read_params, JoblibBackend, MPIBackend
import sbi.utils as utils
import datetime
import dask.bag as db
# Run identifier (month-day-year_hour-minute-second); later cells embed it in
# the output directory name so each run gets its own folder.
time_stamp = datetime.datetime.now().strftime("%m%d%Y_%H%M%S")

# Per-job SLURM resources for the Dask cluster. With processes equal to cores,
# dask_jobqueue splits each job into that many single-threaded worker processes.
# NOTE(review): newer dask_jobqueue releases rename `job_extra` to
# `job_extra_directives`; confirm against the installed version.
_slurm_job_kwargs = {
    'cores': 46,
    'processes': 46,
    'memory': "100GB",
    'walltime': "24:00:00",
    'job_extra': ['-A carney-sjones-condo'],
}
cluster = SLURMCluster(**_slurm_job_kwargs)

# Connect a client to the cluster; the bare name displays its summary.
client = Client(cluster)
client

0,1
Client  Scheduler: tcp://172.20.207.31:46096  Dashboard: http://172.20.207.31:8787/status,Cluster  Workers: 0  Cores: 0  Memory: 0 B


In [8]:
num_simulations = 10000

# Inputs and outputs live under the project's beta-event data directory.
data_root = '/users/ntolley/Jones_Lab/sbi_hnn_github/data/beta/'
params_fname = data_root + 'params/beta_param.param'
save_suffix = 'beta_event_t10000_' + time_stamp
save_path = data_root + 'prerun_simulations/' + save_suffix + '/'

# Uniform prior bounds (low, high) for each hnn_core parameter that the
# simulator overrides; key order defines the column order of theta_samples.
prior_dict = {
    'dipole_scalefctr': (60000, 200000),
    't_evprox_1': (225, 255),
    'sigma_t_evprox_1': (10, 50),
    'numspikes_evprox_1': (1, 20),
    'gbar_evprox_1_L2Pyr_ampa': (1e-06, 0.0005),
    'gbar_evprox_1_L5Pyr_ampa': (1e-06, 0.0005),
    't_evdist_1': (235, 255),
    'sigma_t_evdist_1': (5, 30),
    'numspikes_evdist_1': (1, 20),
    'gbar_evdist_1_L2Pyr_ampa': (1e-06, 0.0005),
    'gbar_evdist_1_L5Pyr_ampa': (1e-06, 0.0005),
}

# Split the (low, high) pairs into two float lists, preserving key order.
param_low = [float(low) for low, _high in prior_dict.values()]
param_high = [float(high) for _low, high in prior_dict.values()]
prior = utils.BoxUniform(low=torch.tensor(param_low), high=torch.tensor(param_high))

# Draw every parameter vector up front; row i is the input of simulation i.
theta_samples = prior.sample((num_simulations,))

def dill_save(save_object, save_prefix, save_suffix, save_path, extension='.pkl'):
    """Serialize ``save_object`` with dill into ``save_path``.

    The file is named ``<save_prefix>_<save_suffix><extension>`` and placed in
    the ``save_path`` directory, which may be given with or without a trailing
    separator. Any existing file with that name is overwritten.

    Parameters
    ----------
    save_object : object
        Object to serialize.
    save_prefix, save_suffix : str
        Joined with '_' to form the file stem.
    save_path : str
        Target directory (must already exist).
    extension : str, optional
        File extension, default '.pkl'.
    """
    file_name = save_prefix + '_' + save_suffix + extension
    # `with` guarantees the handle is closed even if dill.dump raises.
    with open(os.path.join(save_path, file_name), 'wb') as save_file:
        dill.dump(save_object, save_file)

# Create the run directory and its data/ subdirectory; os.mkdir raises if
# either already exists, so a run never overwrites an earlier one.
for new_dir in (save_path, save_path + 'data/'):
    os.mkdir(new_dir)

# Persist the run configuration alongside the simulation outputs.
for prefix, obj in (('params_fname', params_fname),
                    ('prior', prior),
                    ('prior_dict', prior_dict)):
    dill_save(obj, prefix, save_suffix, save_path)

    

In [9]:
class HNNSimulator:
    """Callable wrapper that runs one hnn_core dipole simulation per parameter vector.

    Parameters
    ----------
    params_fname : str
        Path to an hnn_core parameter file, loaded once with ``read_params``
        and used as the base parameter set.
    prior_dict : dict
        Only its keys are used: they name, in iteration order, the parameters
        each call overrides.
    """

    def __init__(self, params_fname, prior_dict):
        # Remove any X display setting (same guard as the notebook's first cell),
        # presumably so nothing tries to reach an X server on headless workers.
        os.environ.pop('DISPLAY', None)

        # Imported locally so the name resolves wherever the instance is built.
        from hnn_core import read_params

        self.params = read_params(params_fname)
        self.param_names = list(prior_dict.keys())

    def __call__(self, new_param_values):
        """Run a single-trial simulation and return its aggregate dipole data.

        ``new_param_values`` is a tensor whose entries are paired, in order,
        with ``self.param_names``. Those pairs are written into ``self.params``
        (mutating it) before the network is built. Returns
        ``data['agg']`` of the first (only) dipole from ``simulate_dipole``.
        """
        values = new_param_values.detach().cpu().numpy()
        overrides = dict(zip(self.param_names, values))
        # NOTE(review): values are numpy float scalars, including count-like
        # entries such as numspikes_*; confirm hnn_core accepts non-integers.
        self.params.update(overrides)

        network = Network(self.params)
        # Serial inside this process; parallelism across simulations comes
        # from Dask in this notebook.
        with JoblibBackend(n_jobs=1):
            dipoles = simulate_dipole(network, n_trials=1)
        return dipoles[0].data['agg']


def run_simulator(theta, params_fname, prior_dict, sim_idx):
    """Build a fresh HNNSimulator and evaluate it on one parameter vector.

    This is the function wrapped in ``dask.delayed`` by ``batch``, so a new
    simulator (and a freshly read parameter file) is created per call.
    ``sim_idx`` is accepted for bookkeeping but not used here.
    Returns whatever ``HNNSimulator.__call__`` returns for ``theta``.
    """
    simulator = HNNSimulator(params_fname, prior_dict)
    return simulator(theta)
    


In [10]:
# Ask the cluster for a target of 100 Dask workers; dask_jobqueue submits as
# many SLURM jobs as needed to reach that worker count.
client.cluster.scale(n=100)
def batch(seq, theta_samples, params_fname, prior_dict, out_dir=None, suffix=None):
    """Simulate one chunk of prior samples on the Dask cluster and save it as CSV.

    Parameters
    ----------
    seq : sequence of int
        Row indices into ``theta_samples`` to simulate. Filenames use the first
        and last index, so a contiguous increasing range is expected. An empty
        ``seq`` is a no-op.
    theta_samples : torch.Tensor
        2-D tensor of parameter samples, one row per simulation.
    params_fname : str
        Parameter file passed through to ``run_simulator``.
    prior_dict : dict
        Prior specification passed through to ``run_simulator``.
    out_dir : str, optional
        Run directory (with trailing '/'); files go into its ``data/``
        subdirectory, which must already exist. Defaults to the notebook-level
        ``save_path``.
    suffix : str, optional
        File-name prefix. Defaults to the notebook-level ``save_suffix``.

    Writes ``<out_dir>data/<suffix>_dpl_sim<first>-<last>.csv`` (one stacked
    row per simulation result) and the matching ``_theta_`` file holding the
    simulated parameter rows.
    """
    seq = list(seq)  # accept ranges/iterables; needed for seq[0]/seq[-1] and indexing
    if not seq:
        # Nothing to simulate; np.stack([]) and seq[0] would otherwise raise.
        return
    if out_dir is None:
        out_dir = save_path
    if suffix is None:
        suffix = save_suffix

    # One lazy task per simulation, evaluated together on the cluster.
    tasks = [
        dask.delayed(run_simulator)(theta_samples[sim_idx, :], params_fname, prior_dict, sim_idx)
        for sim_idx in seq
    ]
    dpl_array = np.stack(dask.compute(*tasks))

    span = 'sim{}-{}'.format(seq[0], seq[-1])
    dpl_name = out_dir + 'data/' + suffix + '_dpl_' + span + '.csv'
    param_name = out_dir + 'data/' + suffix + '_theta_' + span + '.csv'

    np.savetxt(dpl_name, dpl_array, delimiter=',')
    np.savetxt(param_name, theta_samples[seq, :].detach().cpu().numpy(), delimiter=',')

# Simulate the prior samples in fixed-size chunks, saving each chunk to disk
# as soon as it finishes.
step_size = 1000
for i in range(0, num_simulations, step_size):
    print(i)
    # Clamp the final chunk so a num_simulations that is not a multiple of
    # step_size does not index past the end of theta_samples.
    batch_end = min(i + step_size, num_simulations)
    batch(list(range(i, batch_end)), theta_samples, params_fname, prior_dict)
    
    


0
1000
2000
3000
4000
5000
6000
7000
8000
9000


In [None]:
# Shut down the Dask client and the SLURM-backed cluster explicitly;
# `del` by itself only removes the names from the notebook namespace.
client.close()
cluster.close()
del client
del cluster