In [1]:
import torch
import os

if 'DISPLAY' in os.environ:
    del os.environ['DISPLAY']

import numpy as np
import dill
import matplot.pyplot as plt
import joblib
from distributed import Client
from dask_jobqueue import SLURMCluster
import dask
from joblib import Parallel, delayed
import time
import random
import pandas as pd
import hnn_core
from hnn_core import simulate_dipole, Network, read_params, JoblibBackend
import sbi.utils as utils
#import logging
#logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)


cluster = SLURMCluster(
                       cores=8,
                       processes=1,
                       memory="20GB",
                       walltime="01:00:00",
                    #    queue="carney-sjones-condo"
)

client = Client(cluster)



In [4]:
num_simulations = 10

params_fname = '/users/ntolley/Jones_Lab/sbi_hnn_github/data/beta/params/beta_param.param'

prior_dict = {'dipole_scalefctr': (60000, 200000),
 't_evprox_1': (225, 255),
 'sigma_t_evprox_1': (10, 50),
 'numspikes_evprox_1': (1, 20),
 'gbar_evprox_1_L2Pyr_ampa': (1e-06, 0.0005),
 'gbar_evprox_1_L5Pyr_ampa': (1e-06, 0.0005),
 't_evdist_1': (235, 255),
 'sigma_t_evdist_1': (5, 30),
 'numspikes_evdist_1': (1, 20),
 'gbar_evdist_1_L2Pyr_ampa': (1e-06, 0.0005),
 'gbar_evdist_1_L5Pyr_ampa': (1e-06, 0.0005)}

param_low = [float(item[0]) for key, item in prior_dict.items()]
param_high = [float(item[1]) for key, item in prior_dict.items()]
prior = utils.BoxUniform(low=torch.tensor(param_low), high=torch.tensor(param_high))

theta_samples = prior.sample((num_simulations,))

In [16]:
class HNNSimulator:
    def __init__(self, params_fname, prior_dict):
        self.params = read_params(params_fname)
        self.params['tstop'] = 30
        self.param_names = list(prior_dict.keys())

    def __call__(self, new_param_values):
        new_params = dict(zip(self.param_names, new_param_values.detach().cpu().numpy()))
        self.params.update(new_params)

        net = Network(self.params)
        with JoblibBackend(n_jobs=1):
            dpl = simulate_dipole(net, n_trials=1)

        summstats = torch.as_tensor(dpl[0].data['agg'])
        return summstats


#sbi_simulator, sbi_prior = prepare_for_sbi(hnn_simulator, prior)

def run_simulator(theta, params_fname, prior_dict, sim_idx):
    hnn_simulator = HNNSimulator(params_fname,prior_dict)
    dpl = hnn_simulator(theta)
    return dpl

In [17]:

results = []
for sim_idx in range(num_simulations):    
    result = dask.delayed(run_simulator)(theta_samples[sim_idx,:], params_fname, prior_dict, sim_idx)
    results.append(result)



In [21]:
client.cluster.scale(10)
final = dask.compute(*results)

In [22]:
plt.plot(final[0])

NameError: name 'plt' is not defined

In [20]:
results[0]

Delayed('run_simulator-a2202163-7570-476a-a65e-ae4c08035638')