In [None]:
import findspark
findspark.init()
import pyspark

# getOrCreate() returns the already-active SparkContext if there is one (e.g. when
# this cell is re-run, or a context was created under another name) and only
# builds a new one otherwise, so the cell is safe to execute repeatedly.
sc = pyspark.SparkContext.getOrCreate()

In [None]:
import numpy as np

from pulse2percept import electrode2currentmap as e2cm
from pulse2percept import effectivecurrent2brightness as ec2b

In [None]:
def _argus_i_layout(e_spacing=880, height=100):
    """Return (radii, x, y, heights) lists describing the 16 Argus I electrodes.

    Parameters
    ----------
    e_spacing : number
        Center-to-center electrode spacing (um), default 880 as for Argus I.
    height : number
        Value used for every electrode's height entry (default 100; units as
        expected by ``e2cm.ElectrodeArray`` -- NOTE(review): confirm).
    """
    # Argus I: 4x4 grid, `e_spacing` center-to-center, centered on (0, 0)
    coords_1d = np.arange(0, 4) * e_spacing - 1.5 * e_spacing
    x_grid, y_grid = np.meshgrid(coords_1d, coords_1d, sparse=False)

    # spatial arrangement of Argus I creates checkerboard with alternating electrode sizes
    #   .  o  .  o
    #   o  .  o  .
    #   .  o  .  o
    #   o  .  o  .
    row = np.array([260, 520, 260, 520])
    radii = np.concatenate((row, row[::-1], row, row[::-1]), axis=0)
    # one height per electrode, sized from the layout instead of a hard-coded 16
    heights = np.ones(radii.size) * height

    return (radii.tolist(), x_grid.flatten().tolist(),
            y_grid.flatten().tolist(), heights.tolist())


def setup_model(exp_params):
    """Build a pulse2percept model (temporal model, retina, Argus I implant).

    Parameters
    ----------
    exp_params : dict
        Required keys:
        ``'model'`` -- model name passed to ``ec2b.TemporalModel``;
        ``'tsample'`` -- sampling step passed to ``ec2b.TemporalModel``.
        Optional key:
        ``'r_sampling'`` -- spatial sampling of the retina (default 250); it also
        selects the axon-map file ``../retina_argus_s<r_sampling>.npz``.

    Returns
    -------
    dict
        ``'tm'`` (temporal model), ``'retina'`` (``e2cm.Retina``), ``'implant'``
        (``e2cm.ElectrodeArray``), and ``'ecs'`` / ``'cs'``, the two values
        returned by ``Retina.electrode_ecs`` for the implant.
    """
    model = dict()

    # set up temporal model
    model['tm'] = ec2b.TemporalModel(model=exp_params['model'],
                                     tsample=exp_params['tsample'])

    # Create a Retina object that can hold the entire Argus I array
    r_sampling = exp_params.get('r_sampling', 250)  # spatial sampling of retina
    # NOTE(review): this relative path is resolved against the working directory
    # of whichever process runs this function (a Spark executor when called via
    # RDD.map) -- confirm the file is reachable there.
    r_file = '../retina_argus_s' + str(r_sampling) + '.npz'
    model['retina'] = e2cm.Retina(axon_map=r_file, sampling=r_sampling,
                                  xlo=-2500, xhi=2500, ylo=-2500, yhi=2500)

    # Place electrode array (Argus I checkerboard layout)
    radii, x_pos, y_pos, heights = _argus_i_layout()
    model['implant'] = e2cm.ElectrodeArray(radii, x_pos, y_pos, heights)

    # We derive the effective current stimulation (ecs; passed through the effect of
    # the retinal layout, axons, etc.) in addition to the current (cs):
    model['ecs'], model['cs'] = model['retina'].electrode_ecs(model['implant'])

    return model

In [None]:
# Ten identical parameter sets, each its own dict, one per model to build in parallel.
params = [
    {'model': 'Nanduri', 'tsample': 5.6e-5}
    for _ in range(10)
]

In [None]:
# One partition per parameter set, so each (expensive) setup_model call runs as its
# own Spark task instead of being lumped together under the default parallelism;
# max(1, ...) keeps numSlices positive if params is ever empty.
paramsRDD = sc.parallelize(params, numSlices=max(1, len(params)))

In [None]:
# Lazy transformation: setup_model is applied to each parameter dict only when an
# action (such as collect) is run on the resulting RDD.
modelsRDD = paramsRDD.map(setup_model, preservesPartitioning=False)

In [None]:
# Trigger the computation and keep the built models on the driver; show only how many
# were built rather than echoing every model dict into the notebook output.
models = modelsRDD.collect()
len(models)