In [1]:
from pulse2percept.models import AxonMapModel, BiphasicAxonMapModel, BiphasicAxonMapSpatial, Model, AxonMapSpatial
from pulse2percept.stimuli import Stimulus, BiphasicPulseTrain
from pulse2percept.implants import DiskElectrode, ProsthesisSystem, ArgusII, ElectrodeArray

import numpy as np
import os
# Hide all CUDA devices from this process. It is set before `import jax` below
# so it is already in place when JAX initialises its backends; the recorded
# jax.devices() output for this cell lists only a CPU device.
os.environ["CUDA_VISIBLE_DEVICES"]=""
# NOTE(review): `np` is plain NumPy throughout this notebook. Enabling the line
# below instead would rebind `np` to jax.numpy for every kernel defined later.
# import jax.numpy as np
import jax
from jax import jit
# Sanity check: show which devices JAX can see after the setting above.
jax.devices()

[CpuDevice(id=0)]

## Where the magic happens: the vectorized kernel behind `predict_percept()` (written to be JAX-compatible)

In [None]:
# NEW jax copy
def gpu_biphasic_axon_map(amps, freqs, pdurs, x, y, # Per ACTIVE electrode
                          axon_segments, rho, axlambda, thresh_percept, timesteps):
    """Vectorized biphasic axon-map percept (no Python loop over pixels).

    Uses only array operations (broadcasting, ``np.where``, reductions), so it
    works with either NumPy or jax.numpy bound to ``np``.

    Parameters
    ----------
    amps, freqs, pdurs : array, shape (n_elecs,)
        Per active electrode: pulse amplitude, frequency and phase duration.
    x, y : array, shape (n_elecs,)
        Per active electrode: position, compared against segment channels 0/1.
    axon_segments : array, shape (n_space, axon_length, 3)
        For every pixel, its axon segments. Channels 0 and 1 are compared
        against electrode ``x``/``y``; channel 2 is a per-segment weight raised
        to ``1 / streak``. NOTE(review): must be 3-D; a 2-D array raised
        IndexError in the recorded run -- confirm the layout of
        ``AxonMapSpatial.axon_contrib``.
    rho, axlambda : float
        Spatial decay constants; they also set lower bounds on size/streak.
    thresh_percept : float
        Segment intensities below this value are set to 0 before taking the
        per-pixel maximum.
    timesteps : int
        Number of time points the (static) percept is repeated for.

    Returns
    -------
    array, shape (n_space, timesteps)
    """
    min_size = 10**2 / rho**2
    min_streak = 10**2 / axlambda **2

    # First get contributions from F, G, H per electrode; each is (n_elecs,)
    scaled_amps = amps / (0.8825 + 0.27*pdurs)
    brights = 1.84*scaled_amps + 0.2*freqs + 2.0986
    sizes = np.maximum(1.081*scaled_amps - 0.3533764, min_size)
    streaks = np.maximum(1.56 - 0.54 * pdurs ** 0.21, min_streak)

    # Squared distance from every axon segment to every electrode:
    # (n_space, axon_length, 1) - (n_elecs,) -> (n_space, axon_length, n_elecs)
    d2_el = (axon_segments[:, :, 0, None] - x)**2 + (axon_segments[:, :, 1, None] - y)**2

    # Per-segment weight for every electrode, broadcast like d2_el. All pixels
    # are handled at once, so no per-pixel index is involved here.
    segment_weights = axon_segments[:, :, 2, None] ** (1. / streaks)

    # (n_elecs) * (n_space, axon_length, n_elecs) * (n_space, axon_length, n_elecs)
    electrode_intensities = brights * np.exp(-d2_el / (2. * rho**2. * sizes)) * segment_weights

    # Sum over electrodes -> (n_space, axon_length)
    axon_intensities = np.sum(electrode_intensities, axis=2)
    # Zero sub-threshold segments, matching the per-segment thresholding of the
    # loop-based numpy copy; np.where (not masked assignment) keeps this
    # usable with jax.numpy.
    axon_intensities = np.where(axon_intensities < thresh_percept, 0.0, axon_intensities)
    # Brightest segment of each pixel's axon -> (n_space,)
    I = np.max(axon_intensities, axis=1)

    # Repeat the static percept over time -> (n_space, timesteps)
    return np.asarray(np.transpose(np.tile(I, (timesteps, 1))))

In [6]:
# Toy stand-ins for checking the broadcasting used in gpu_biphasic_axon_map:
# `a` mimics axon_segments (2 pixels x 3 segments x 3 channels, where segment k
# of pixel p has every channel equal to 3*p + k); `b`/`c` mimic electrode x/y.
a = np.arange(6).repeat(3).reshape(2, 3, 3)
b = np.arange(10, 5, -1)
c = b.copy()

In [34]:
# Same broadcast as d2_el: (2, 3, 1) - (5,) -> (2, 3, 5)
d = np.square(a[..., 0, np.newaxis] - b) + np.square(a[..., 1, np.newaxis] - c)
d

array([[[200, 162, ...,  98,  72],
        [162, 128, ...,  72,  50],
        [128,  98, ...,  50,  32]],

       [[ 98,  72, ...,  32,  18],
        [ 72,  50, ...,  18,   8],
        [ 50,  32, ...,   8,   2]]])

In [35]:
np.shape(d)  # expect (n_pixels, n_segments, n_elecs)

(2, 3, 5)

In [36]:
s1 = np.arange(1, 6)  # toy per-electrode streak values 1..5

In [42]:
a[..., 2]  # channel 2 (per-segment weight) of every segment

array([[0, 1, 2],
       [3, 4, 5]])

In [46]:
f = np.power(a[..., 2][..., np.newaxis], 1 / s1)  # (2, 3, 1) ** (5,) -> (2, 3, 5)

In [53]:
g = f.sum(axis=-1)  # sum over the electrode axis -> (2, 3)
g

array([[ 0.  ,  5.  ,  7.01],
       [ 8.74, 10.32, 11.82]])

In [54]:
g.max(axis=1)  # brightest segment per pixel -> (2,)

array([ 7.01, 11.82])

#### Build the model: a new spatial subclass (`BiphasicAxonMapGPUSpatial`), wrapped in the `BiphasicAxonMapGPU` model

In [3]:
class BiphasicAxonMapGPUSpatial(AxonMapSpatial):
  """AxonMapSpatial whose spatial prediction is delegated to
  ``gpu_biphasic_axon_map``, driven by per-electrode pulse parameters."""

  def __init__(self, **params):
    # Pure pass-through to AxonMapSpatial.
    super().__init__(**params)


  def _predict_spatial(self, earray, stim):
    """Predict the spatial percept of ``stim`` delivered through ``earray``.

    Reads ``amp``, ``freq`` and ``phase_dur`` from each active electrode's
    stimulus metadata, and each electrode's ``x``/``y`` from ``earray``. It
    then calls ``gpu_biphasic_axon_map`` with this model's ``axon_contrib``,
    ``rho``, ``axlambda`` and ``thresh_percept``, and with ``stim.shape[1]``
    as the number of time steps.
    """
    assert isinstance(earray, ElectrodeArray)
    assert isinstance(stim, Stimulus)

    def pulse_param(key):
      # One float32 value per active electrode, in stim.electrodes order.
      return np.array([stim.metadata['electrodes'][str(e)]['metadata'][key] for e in stim.electrodes],
                      dtype="float32")

    def electrode_coord(attr):
      # One float32 coordinate per active electrode, in stim.electrodes order.
      return np.array([getattr(earray[e], attr) for e in stim.electrodes], dtype=np.float32)

    return gpu_biphasic_axon_map(pulse_param('amp'), pulse_param('freq'), pulse_param('phase_dur'),
                                 electrode_coord('x'), electrode_coord('y'),
                                 self.axon_contrib,
                                 self.rho, self.axlambda, self.thresh_percept, stim.shape[1])


class BiphasicAxonMapGPU(Model):
  """Model pairing BiphasicAxonMapGPUSpatial with no temporal model.

  ``predict_percept`` validates that every electrode is driven by a
  BiphasicPulseTrain with zero ``delay_dur`` (the spatial kernel only reads
  ``amp``, ``freq`` and ``phase_dur``) before deferring to
  ``Model.predict_percept``.
  """

  def __init__(self, **params):
    super(BiphasicAxonMapGPU, self).__init__(spatial=BiphasicAxonMapGPUSpatial(), temporal=None, **params)

  def predict_percept(self, implant, t_percept=None):
    """Validate the implant's stimulus, then predict the percept.

    Parameters
    ----------
    implant : ProsthesisSystem
        Implant whose ``stim`` is evaluated.
    t_percept : optional
        Forwarded unchanged to ``Model.predict_percept``.

    Raises
    ------
    TypeError
        If some electrode's stimulus is not a BiphasicPulseTrain or has a
        non-zero ``delay_dur``.
    """
    stim = implant.stim
    # Without a stimulus there is nothing to validate (previously this
    # dereferenced None.metadata); defer to Model.predict_percept.
    # NOTE(review): the parent is presumed to handle a missing stimulus -- confirm.
    if stim is not None and not isinstance(stim, BiphasicPulseTrain):
      # Could still be a stimulus where each electrode has a biphasic pulse train
      for ele, params in stim.metadata['electrodes'].items():
        if params['type'] is not BiphasicPulseTrain or params['metadata']['delay_dur'] != 0:
          raise TypeError("All stimuli must be BiphasicPulseTrains with no delay dur (Failing electrode: %s)" % (ele))

    return super(BiphasicAxonMapGPU, self).predict_percept(implant, t_percept=t_percept)


In [4]:
# Model under test and the reference pulse2percept model (both with default
# parameters), evaluated on an Argus II with one biphasic pulse train on A2.
model = BiphasicAxonMapGPU()
model.build()
implant = ArgusII()
stim = Stimulus({"A2" : BiphasicPulseTrain(20, 1, 0.45)})
implant.stim = stim

model_orig = BiphasicAxonMapModel()
# Trailing `;` suppresses display of build()'s return value (replaces a bare
# print() that only emitted a blank line).
model_orig.build();




In [5]:
%%timeit 
percept = model.predict_percept(implant)
percept.plot()

IndexError: too many indices for array: array is 2-dimensional, but 3 were indexed

In [6]:
%%timeit
p1 = model_orig.predict_percept(implant)
p1.plot()

NameError: name 'fast_biphasic_axon_map' is not defined

In [None]:
# Largest absolute difference between the vectorized and reference percepts
# (the signed max would miss pixels where the reference is brighter).
np.max(np.abs(percept.data - p1.data))

NameError: ignored

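## OLD

Superseded implementations, kept for reference. Note that running these cells redefines `gpu_biphasic_axon_map`, shadowing the vectorized version defined above.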

In [2]:
# numpy copy
def gpu_biphasic_axon_map(amps, freqs, pdurs, x, y, # Per ACTIVE electrode
                          axon_segments,
                          rho, axlambda, thresh_percept, timesteps):
  """Loop-based NumPy reference for the biphasic axon-map percept.

  Handles one pixel (one row of ``axon_segments``) at a time and writes into
  a preallocated float32 array. Because it mutates arrays in place, it cannot
  itself be JIT'ed; a JAX port must follow
  https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html

  Parameters
  ----------
  amps, freqs, pdurs, x, y : ndarray, shape (n_elecs,)
      Per active electrode: amplitude, frequency, phase duration, position.
  axon_segments : ndarray, shape (n_space, axon_length, 3)
      Per pixel, its axon segments: channels 0/1 are compared with ``x``/``y``,
      channel 2 is a per-segment weight raised to ``1 / streak``.
  rho, axlambda : float
      Spatial decay constants (also lower bounds for size/streak).
  thresh_percept : float
      Segment intensities below this are zeroed before the per-pixel max.
  timesteps : int
      Number of time points the static percept is repeated for.

  Returns
  -------
  ndarray, shape (n_space, timesteps)
  """
  # Lower bounds for the size and streak terms.
  size_floor = 10**2 / rho**2
  streak_floor = 10**2 / axlambda **2

  # Per-electrode F (brightness), G (size) and H (streak) terms.
  amp_scaled = amps / (0.8825 + 0.27*pdurs)
  brightness = 1.84*amp_scaled + 0.2*freqs + 2.0986
  size = np.maximum(1.081*amp_scaled - 0.3533764, size_floor)
  streak = np.maximum(1.56 - 0.54 * pdurs ** 0.21, streak_floor)

  # Loop-invariant pieces, shaped to broadcast against (n_segments, 1).
  gauss_denom = 2. * rho**2. * size
  inv_streak = (1. / streak)[None, :]
  elec_x = x[None, :]
  elec_y = y[None, :]

  I = np.zeros(shape=(len(axon_segments)), dtype=np.float32)
  for pixel, segments in enumerate(axon_segments):
    seg_x = segments[:, 0][:, None]
    seg_y = segments[:, 1][:, None]
    seg_w = segments[:, 2][:, None]

    # (n_segments, n_elecs)
    d2_el = (seg_x - elec_x)**2. + (seg_y - elec_y)**2.
    per_elec = brightness * np.exp(-d2_el / gauss_denom) * (seg_w ** inv_streak)

    # Total intensity at each segment, sub-threshold segments zeroed.
    seg_total = np.sum(per_elec, axis=1)
    seg_total[seg_total < thresh_percept] = 0.0

    I[pixel] = np.max(seg_total)

  # Repeat the static percept over time -> (n_space, timesteps)
  return np.asarray(np.transpose(np.tile(I, (timesteps, 1))))

In [None]:
#OLD jax copy
def gpu_biphasic_axon_map(amps, freqs, pdurs, x, y, # Per ACTIVE electrode
                          axon_segments, rho, axlambda, thresh_percept, timesteps):
  """Loop-over-pixels variant of the biphasic axon-map percept, meant to be JIT'ed.

  Must be purely functional, and follow all rules at
  https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
  It therefore never mutates arrays in place, and works with either NumPy or
  jax.numpy bound to ``np``.

  Parameters
  ----------
  amps, freqs, pdurs, x, y : array, shape (n_elecs,)
      Per active electrode: amplitude, frequency, phase duration, position.
  axon_segments : array, shape (n_space, axon_length, 3)
      Per pixel, its axon segments: channels 0/1 are compared with ``x``/``y``,
      channel 2 is a per-segment weight raised to ``1 / streak``.
  rho, axlambda : float
      Spatial decay constants (also lower bounds for size/streak).
  thresh_percept : float
      Segment intensities below this are zeroed before the per-pixel max.
  timesteps : int
      Number of time points the static percept is repeated for.

  Returns
  -------
  array, shape (n_space, timesteps)
  """
  min_size = 10**2 / rho**2
  min_streak = 10**2 / axlambda **2

  # First get contributions from F, G, H per electrode
  scaled_amps = amps / (0.8825 + 0.27*pdurs)
  brights = 1.84*scaled_amps + 0.2*freqs + 2.0986
  sizes = np.maximum(1.081*scaled_amps - 0.3533764, min_size)
  streaks = np.maximum(1.56 - 0.54 * pdurs ** 0.21, min_streak)

  def pixel_intensity(segments):
    # Brightest supra-threshold segment of one pixel's axon (scalar).
    # (n_segments, n_elecs)
    d2_el = (segments[:, 0][:, None] - x[None, :])**2. + (segments[:, 1][:, None] - y[None, :])**2.
    # (n_elec) * (n_segments, n_elecs) * (n_segments, n_elecs)
    electrode_intensities = brights * np.exp(-d2_el / ( 2. * rho**2. * sizes)) * (segments[:, 2][:, None] ** (1. / streaks)[None, :])
    # (n_segments): ith entry is the intensity at the ith axon segment
    axon_intensities = np.sum(electrode_intensities, axis=1)
    # Functional thresholding (no masked assignment), valid under jax.numpy.
    axon_intensities = np.where(axon_intensities < thresh_percept, 0.0, axon_intensities)
    return np.max(axon_intensities)

  # Build I in one shot instead of via jax.ops.index_update, which was removed
  # from JAX (its replacement, `I.at[i].set(v)`, exists only on JAX arrays).
  I = np.asarray([pixel_intensity(axon_segments[idx_space]) for idx_space in range(len(axon_segments))],
                 dtype=np.float32)

  # Repeat the static percept over time -> (n_space, timesteps)
  return np.asarray(np.transpose(np.tile(I, (timesteps, 1))))
