In [1]:
import torch
import os

# Drop any inherited X11 DISPLAY so nothing in this kernel tries to talk to a
# display server (pop() is a no-op when the variable is absent).
os.environ.pop('DISPLAY', None)

# User / SLURM job identifiers. They are not used by live code in this cell,
# so read them with .get(): outside a SLURM allocation SLURM_JOB_ID is unset
# and os.environ[...] would raise KeyError and abort the whole notebook.
user_env = os.environ.get('USER')
job_env = os.environ.get('SLURM_JOB_ID')

import numpy as np
import dill
import matplotlib.pyplot as plt
import joblib
from distributed import Client
from dask_jobqueue import SLURMCluster
import dask
from joblib import Parallel, delayed
import time
import random
import pandas as pd
import hnn_core
from hnn_core import simulate_dipole, Network, read_params, JoblibBackend, MPIBackend
import sbi.utils as utils
import datetime
import dask.bag as db
#time_stamp = datetime.datetime.now().strftime("%m%d%Y_%H%M%S")
#import logging
#logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)


# SLURM job template: each job asks for one node (--nodes=1) running 50 worker
# processes over 50 cores (one thread per process) with 200 GB of memory,
# billed to account csd403 (-A). NOTE(review): newer dask_jobqueue releases
# rename `job_extra` to `job_extra_directives` — confirm against the
# installed version.
cluster_options = dict(
    cores=50,
    processes=50,
    queue='compute',
    memory="200GB",
    walltime="48:00:00",
    job_extra=['-A csd403', '--nodes=1'],
    # Previously tried and left disabled:
    #   local_directory = '/scratch/' + user_env + '/job_' + job_env
    #   extra sbatch option '--cpus-per-task=1'
)
cluster = SLURMCluster(**cluster_options)

# Attach a distributed client to the cluster; displaying it shows the
# scheduler / dashboard addresses.
client = Client(cluster)
client

0,1
Client  Scheduler: tcp://198.202.102.90:34565  Dashboard: http://198.202.102.90:8787/status,Cluster  Workers: 0  Cores: 0  Memory: 0 B


In [2]:
num_simulations = 100000
# Optional seed for the torch RNG used by prior.sample. None keeps the
# original unseeded behaviour; set an int for reproducible theta draws.
sample_seed = None
time_stamp = datetime.datetime.now().strftime("%m%d%Y_%H%M%S")

params_fname = '../../data/beta/params/beta_param.param'
# The simulation count embedded in the run name is derived from
# num_simulations so the name cannot drift from the actual number of draws.
save_suffix = 'beta_event_expanse_t{}_{}'.format(num_simulations, time_stamp)
save_path = '../../data/beta/prerun_simulations/' + save_suffix + '/'

# (low, high) bounds of the box-uniform prior for each HNN parameter.
# Insertion order matters: it fixes the column order of theta_samples, and
# HNNSimulator maps theta columns back to names via the same key order.
prior_dict = {'dipole_scalefctr': (60000, 200000),
 't_evprox_1': (225, 255),
 'sigma_t_evprox_1': (10, 50),
 'numspikes_evprox_1': (1, 20),
 'gbar_evprox_1_L2Pyr_ampa': (1e-06, 0.0005),
 'gbar_evprox_1_L5Pyr_ampa': (1e-06, 0.0005),
 't_evdist_1': (235, 255),
 'sigma_t_evdist_1': (5, 30),
 'numspikes_evdist_1': (1, 20),
 'gbar_evdist_1_L2Pyr_ampa': (1e-06, 0.0005),
 'gbar_evdist_1_L5Pyr_ampa': (1e-06, 0.0005)}

param_low = [float(low) for low, _ in prior_dict.values()]
param_high = [float(high) for _, high in prior_dict.values()]
prior = utils.BoxUniform(low=torch.tensor(param_low), high=torch.tensor(param_high))

if sample_seed is not None:
    torch.manual_seed(sample_seed)
# One row per simulation, one column per prior_dict entry.
theta_samples = prior.sample((num_simulations,))

def dill_save(save_object, save_prefix, save_suffix, save_path, extension='.pkl'):
    """Serialize ``save_object`` with dill to
    ``<save_path><save_prefix>_<save_suffix><extension>``.

    Parameters
    ----------
    save_object : object
        Anything dill can serialize.
    save_prefix : str
        Leading part of the file name; may include a sub-directory
        (e.g. ``'data/spike_times'``).
    save_suffix : str
        Run-specific part of the file name.
    save_path : str
        Output directory, expected to end with a path separator.
    extension : str, optional
        File extension, ``'.pkl'`` by default.

    Returns
    -------
    str
        Path of the file that was written.
    """
    # Plain concatenation (not os.path.join) keeps the original naming
    # exactly: save_path already carries its trailing '/'.
    file_name = save_path + save_prefix + '_' + save_suffix + extension
    # `with` closes the handle even if dill.dump raises (e.g. on an
    # unpicklable object); the previous open/close pair leaked it.
    with open(file_name, 'wb') as save_file:
        dill.dump(save_object, save_file)
    return file_name

# Create the run directory, including any missing parents such as
# prerun_simulations/. makedirs keeps the default exist_ok=False, so an
# already-existing run directory still raises exactly as os.mkdir did.
os.makedirs(save_path)
os.mkdir(save_path + 'data/')

# Persist the run configuration next to the simulation outputs.
dill_save(params_fname, 'params_fname', save_suffix, save_path)
dill_save(prior, 'prior', save_suffix, save_path)
dill_save(prior_dict, 'prior_dict', save_suffix, save_path)

    

In [3]:
class HNNSimulator:
    """Callable that runs one single-trial HNN dipole simulation per call.

    The base parameter set is read once from an HNN ``.param`` file. Each call
    overwrites the entries named by ``prior_dict`` with new values and
    simulates.
    """

    def __init__(self, params_fname, prior_dict):
        # Presumably repeated here (not only at notebook top level) because
        # instances are built inside dask worker processes via run_simulator.
        os.environ.pop('DISPLAY', None)

        # Local import; of these names only read_params is used in __init__.
        import hnn_core
        from hnn_core import simulate_dipole, Network, read_params, JoblibBackend, MPIBackend

        self.params = read_params(params_fname)
        # Order of these names must match the column order of the theta values.
        self.param_names = list(prior_dict)

    def __call__(self, new_param_values):
        """Simulate with ``new_param_values`` and return dipole and spike data.

        Parameters
        ----------
        new_param_values : torch.Tensor
            One value per entry of ``self.param_names``, in the same order.

        Returns
        -------
        tuple
            ``(agg_dipole, spike_times, spike_gids, spike_types)``:
            ``dpl[0].data['agg']`` of the single trial, followed by the spike
            attributes of ``net.cell_response``.
        """
        values = new_param_values.detach().cpu().numpy()
        # Updates self.params in place; the same keys are overwritten on every call.
        self.params.update(dict(zip(self.param_names, values)))

        net = Network(self.params)
        with JoblibBackend(n_jobs=1):
            dipoles = simulate_dipole(net, n_trials=1, add_drives_from_params=True)

        agg_dipole = dipoles[0].data['agg']
        response = net.cell_response
        return agg_dipole, response.spike_times, response.spike_gids, response.spike_types

#sbi_simulator, sbi_prior = prepare_for_sbi(hnn_simulator, prior)
#params = read_params(params_fname)
def run_simulator(theta, params_fname, prior_dict, sim_idx):
    """Build a fresh HNNSimulator and run it on one parameter vector.

    Parameters
    ----------
    theta : torch.Tensor
        One row of parameter values, ordered like ``prior_dict``'s keys.
    params_fname : str
        Path of the HNN ``.param`` file used as the base parameter set.
    prior_dict : dict
        Prior definition whose keys name the entries of ``theta``.
    sim_idx : int
        Index of this simulation; accepted for the caller's bookkeeping but
        not used here.

    Returns
    -------
    tuple
        Whatever ``HNNSimulator.__call__`` returns for ``theta``.
    """
    simulator = HNNSimulator(params_fname, prior_dict)
    return simulator(theta)
    


In [4]:
# Request 500 worker processes in total. With processes=50 per job this
# should correspond to 10 SLURM jobs — confirm on the queue.
total_workers = 500
client.cluster.scale(total_workers)
def batch(seq, theta_samples, params_fname, prior_dict, out_dir=None, suffix=None):
    """Simulate one chunk of parameter draws on the dask cluster and save it.

    Each index in ``seq`` becomes one delayed ``run_simulator`` task. All tasks
    are computed together, then the dipoles and the matching theta rows are
    written as CSV files into ``<out_dir>data/``.

    Parameters
    ----------
    seq : iterable of int
        Row indices into ``theta_samples`` to simulate. It is materialized to a
        list, so a ``range`` works too. An empty ``seq`` is a no-op.
    theta_samples : torch.Tensor
        Parameter draws, one row per simulation.
    params_fname : str
        Path of the HNN ``.param`` file passed to every simulation.
    prior_dict : dict
        Prior definition whose key order names the columns of ``theta_samples``.
    out_dir : str, optional
        Run directory ending with '/'. Defaults to the notebook-level
        ``save_path``.
    suffix : str, optional
        Run name used in the file names. Defaults to the notebook-level
        ``save_suffix``.

    Returns
    -------
    tuple of str or None
        ``(dpl_name, param_name)`` of the files written, or None when ``seq``
        is empty.
    """
    seq = list(seq)
    if not seq:
        # Nothing to simulate; np.stack and seq[0] would otherwise fail.
        return None
    if out_dir is None:
        out_dir = save_path
    if suffix is None:
        suffix = save_suffix

    tasks = [
        dask.delayed(run_simulator)(theta_samples[sim_idx, :], params_fname, prior_dict, sim_idx)
        for sim_idx in seq
    ]
    results = dask.compute(*tasks)

    # Element 0 of each result is the dipole trace. Presumably 1-D per
    # simulation, so stacking yields one CSV row per simulation.
    dpl_array = np.stack([res[0] for res in results])

    range_tag = '_sim{}-{}'.format(seq[0], seq[-1])
    dpl_name = out_dir + 'data/dpl_' + suffix + range_tag + '.csv'
    param_name = out_dir + 'data/theta_' + suffix + range_tag + '.csv'

    np.savetxt(dpl_name, dpl_array, delimiter=',')
    np.savetxt(param_name, theta_samples[seq, :].detach().cpu().numpy(), delimiter=',')

    # Spike data (res[1], res[2], res[3]: times, gids, types) is computed but
    # not persisted. It could be saved with dill_save under a 'data/spike_*'
    # prefix if needed.
    return dpl_name, param_name

# NOTE(review): `batches` is never populated in the code visible here; kept in
# case a later cell reads it.
batches = []
step_size = 500
# Run the draws in consecutive chunks of `step_size`. The upper bound is
# clipped to num_simulations so a count that is not a multiple of step_size
# does not index past the end of theta_samples.
for i in range(0, num_simulations, step_size):
    print(i)
    chunk = list(range(i, min(i + step_size, num_simulations)))
    batch(chunk, theta_samples, params_fname, prior_dict)

0


distributed.core - ERROR - Exception while handling op heartbeat_worker
Traceback (most recent call last):
  File "/home/ntolley/anaconda3/envs/sbi/lib/python3.8/site-packages/distributed/core.py", line 493, in handle_comm
    result = handler(comm, **msg)
  File "/home/ntolley/anaconda3/envs/sbi/lib/python3.8/site-packages/distributed/scheduler.py", line 2196, in heartbeat_worker
    ws._executing = {
  File "/home/ntolley/anaconda3/envs/sbi/lib/python3.8/site-packages/distributed/scheduler.py", line 2197, in <dictcomp>
    self.tasks[key]: duration for key, duration in executing.items()
KeyError: 'run_simulator-ab1b82b3-40c1-4b51-bb45-f7c277424fbb'


KeyboardInterrupt: 

distributed.core - ERROR - Exception while handling op heartbeat_worker
Traceback (most recent call last):
  File "/home/ntolley/anaconda3/envs/sbi/lib/python3.8/site-packages/distributed/core.py", line 493, in handle_comm
    result = handler(comm, **msg)
  File "/home/ntolley/anaconda3/envs/sbi/lib/python3.8/site-packages/distributed/scheduler.py", line 2196, in heartbeat_worker
    ws._executing = {
  File "/home/ntolley/anaconda3/envs/sbi/lib/python3.8/site-packages/distributed/scheduler.py", line 2197, in <dictcomp>
    self.tasks[key]: duration for key, duration in executing.items()
KeyError: 'run_simulator-1b63ecab-3808-46eb-90fd-e51966d90ba2'
distributed.core - ERROR - Exception while handling op heartbeat_worker
Traceback (most recent call last):
  File "/home/ntolley/anaconda3/envs/sbi/lib/python3.8/site-packages/distributed/core.py", line 493, in handle_comm
    result = handler(comm, **msg)
  File "/home/ntolley/anaconda3/envs/sbi/lib/python3.8/site-packages/distributed/sch