
# Surface mapping

When we build surface-based or pseudospectral simulations, we need to map data between representations using index vectors. This notebook explores which JAX operations to use for that.


In [2]:
import tvb.datatypes.cortex 
cortex = tvb.datatypes.cortex.Cortex.from_file()
cortex



| attribute | value |
|---|---|
| Type | Cortex |
| coupling_strength [min, median, max] | [1, 1, 1] |
| coupling_strength dtype | float64 |
| coupling_strength shape | (1,) |
| gid | UUID('d0e5ce3d-be19-400e-b06a-15f4bfb73881') |
| local_connectivity | |
| region_mapping_data | RegionMapping gid: a7910f59-cf7b-4b1a-bea0-28a48bffb7e5 |
| title | Cortex gid: d0e5ce3d-be19-400e-b06a-15f4bfb73881 |


In [3]:
region_map = cortex.region_mapping_data.array_data

In [4]:
region_map.shape, region_map.min(), region_map.max()

((16384,), 0, 75)

In [5]:
import tvb.datatypes.connectivity
conn = tvb.datatypes.connectivity.Connectivity.from_file()
conn.weights.shape



(76, 76)

Move the region map to JAX:

In [6]:
import jax.numpy as np
import jax

region_map = np.array(region_map)

## From region vector to surface

This is the easiest direction: indexing a regional vector with `region_map` gathers, for every vertex, the value of the region that vertex belongs to.

In [7]:
x = np.ones((76, ))

x[region_map].shape

(16384,)
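The gather semantics are easiest to see on a tiny map. The sketch below uses a hypothetical 3-region, 6-vertex `region_map` (not TVB data) and also shows the batched case, where a `(time, region)` array is mapped by indexing its last axis. One caveat worth keeping in mind: JAX clamps out-of-range gather indices instead of raising, so it is worth checking `region_map.max() < n_regions` as done above.

```python
import jax.numpy as jnp

# Hypothetical toy data: 3 regions, 6 vertices.
region_map = jnp.array([0, 0, 1, 2, 2, 2])
x_region = jnp.array([10.0, 20.0, 30.0])

# Gather: each vertex picks up the value of its region.
x_surface = x_region[region_map]

# The same gather written as an explicit function call.
assert (jnp.take(x_region, region_map) == x_surface).all()

# Batched: for a (time, region) array, index the region axis.
x_tr = jnp.stack([x_region, 2 * x_region])  # shape (2, 3)
x_tv = x_tr[:, region_map]                  # shape (2, 6)
```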

## From surface to region vector

The inverse operation is more interesting: TVB uses an average over vertices in a region to compute the corresponding regional value.  Computing the average is easier if we first count vertices in each region:

In [8]:
vtx_count = np.bincount(region_map)
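Outside `jit`, `bincount` infers its output length from the data; inside `jit` the output shape must be static, so `length=` has to be passed. Regions with no vertices would also turn the average into `0/0`. The sketch below, on a hypothetical 4-region toy map where region 3 is empty, shows both points:

```python
import jax
import jax.numpy as jnp

n_regions = 4                                # hypothetical: region 3 has no vertices
region_map = jnp.array([0, 0, 1, 2, 2, 2])

@jax.jit
def count_vertices(region_map):
    # length must be a static Python int under jit
    return jnp.bincount(region_map, length=n_regions)

vtx_count = count_vertices(region_map)

# Regions with zero vertices would produce nan when averaging.
empty = jnp.flatnonzero(vtx_count == 0)
```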

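JAX arrays are immutable, so NumPy's in-place `np.add.at` idiom does not carry over directly. The functional indexed update `x.at[idx].add(v)` has the same unbuffered semantics, though: repeated indices accumulate rather than overwrite. We can check that against `vtx_count`: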

In [9]:
x = np.zeros_like(x)            # reset the 76-element regional vector

x = x.at[region_map].add(1)     # each vertex adds 1 to its region

x == vtx_count

Array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True], dtype=bool)
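The distinction matters because plain fancy-index assignment is buffered. The comparison below uses real NumPy (imported as `numpy` to avoid clashing with the notebook's `np` alias for `jax.numpy`) on a toy index vector: `+=` writes each repeated index only once, while `np.add.at` and JAX's `.at[].add` both accumulate.

```python
import numpy
import jax.numpy as jnp

idx = numpy.array([0, 0, 1, 2, 2, 2])

# Buffered: repeated indices are written once, so counts are lost.
a = numpy.zeros(3)
a[idx] += 1

# Unbuffered, in place: every occurrence accumulates.
b = numpy.zeros(3)
numpy.add.at(b, idx, 1)

# JAX: accumulates like np.add.at but returns a new array.
c = jnp.zeros(3).at[idx].add(1)
```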

Now with a random surface vector:

In [11]:
key = jax.random.PRNGKey(42)
surf_x = jax.random.normal(key, shape=(len(region_map), ))

# sum vertex values per region, then divide by the vertex count
surf_x_roi_mean = np.zeros_like(x).at[region_map].add(surf_x) / vtx_count
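Packaged as a reusable, jit-compiled function, the mean might look like the sketch below. The function name, the toy data, and the choice to report 0 for empty regions are assumptions for illustration. It also shows two other JAX routines that compute the same per-region sums, `jax.ops.segment_sum` and `jnp.bincount` with `weights=`.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="n_regions")
def surface_to_region_mean(surf_x, region_map, n_regions):
    # Sum vertex values per region with an unbuffered scatter-add.
    sums = jnp.zeros(n_regions, surf_x.dtype).at[region_map].add(surf_x)
    counts = jnp.bincount(region_map, length=n_regions)
    # Assumption: empty regions report 0 instead of nan.
    return jnp.where(counts > 0, sums / jnp.maximum(counts, 1), 0.0)

# Hypothetical toy data: region 3 has no vertices.
region_map = jnp.array([0, 0, 1, 2, 2, 2])
surf_x = jnp.array([1.0, 3.0, 5.0, 2.0, 4.0, 6.0])
mean = surface_to_region_mean(surf_x, region_map, n_regions=4)

# The same per-region sums via two other JAX routines.
sums_seg = jax.ops.segment_sum(surf_x, region_map, num_segments=4)
sums_bc = jnp.bincount(region_map, weights=surf_x, length=4)
```

Making `n_regions` static is what lets the output shape be known at trace time. Changing it triggers a recompile, which is fine for a fixed parcellation.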

"Yes, but did it work?" Compare against a sequential reference built with `jax.lax.scan`, which adds one vertex at a time:

In [13]:
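def surf_roi_mean_check(surf_x):
    # sequential reference: add one vertex value at a time to its region
    def op(c, args):
        j, sx = args
        c = c.at[j].add(sx)
        return c, None
    c = np.zeros(vtx_count.shape, surf_x.dtype)
    c, _ = jax.lax.scan(op, c, (region_map, surf_x))
    return c / vtx_count

# summation order may differ between the two, so compare with a tolerance
assert np.allclose(surf_roi_mean_check(surf_x), surf_x_roi_mean)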
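A third cross-check writes the mapping as a dense one-hot matrix `M` of shape `(n_regions, n_vertices)`. Then `M @ x / counts` is the regional mean and `M.T` performs the region-to-surface gather. This is a sketch on hypothetical toy data. At the notebook's size `M` would hold 76 × 16384 float32 entries (about 5 MB), acceptable for a check but wasteful compared with the scatter-add.

```python
import jax
import jax.numpy as jnp

region_map = jnp.array([0, 0, 1, 2, 2, 2])  # hypothetical toy map
n_regions = 3
surf_x = jax.random.normal(jax.random.PRNGKey(0), (region_map.shape[0],))

# Row r of M marks the vertices of region r.
M = jax.nn.one_hot(region_map, n_regions, dtype=surf_x.dtype).T
counts = M.sum(axis=1)

mean_dense = (M @ surf_x) / counts
mean_scatter = jnp.zeros(n_regions).at[region_map].add(surf_x) / counts

# The transpose performs the gather: M.T @ x_region == x_region[region_map]
x_region = jnp.arange(n_regions, dtype=surf_x.dtype)
```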

And the obligatory micro-benchmark:

In [14]:
%timeit surf_roi_mean_check(surf_x).block_until_ready()
%timeit (np.zeros_like(x).at[region_map].add(surf_x)/vtx_count).block_until_ready()

15.8 ms ± 147 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
602 µs ± 4.79 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
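In these unjitted timings the scatter-add is roughly 26 times faster than the scan (15.8 ms vs 602 µs). Inside a simulation loop both would normally run under `jit`, so a fairer comparison compiles each first. The sketch below does that with `time.perf_counter` outside IPython. A random `region_map` stands in for the TVB one (an assumption; only the sizes 16384 and 76 come from the notebook), and no timings are claimed here since they depend on the backend.

```python
import time

import jax
import jax.numpy as jnp

n_vertices, n_regions = 16384, 76  # sizes from the notebook's cortex
key_map, key_x = jax.random.split(jax.random.PRNGKey(0))
region_map = jax.random.randint(key_map, (n_vertices,), 0, n_regions)  # stand-in map
surf_x = jax.random.normal(key_x, (n_vertices,))

@jax.jit
def scatter_sum(surf_x, region_map):
    return jnp.zeros(n_regions, surf_x.dtype).at[region_map].add(surf_x)

@jax.jit
def scan_sum(surf_x, region_map):
    def op(c, args):
        j, sx = args
        return c.at[j].add(sx), None
    c, _ = jax.lax.scan(op, jnp.zeros(n_regions, surf_x.dtype), (region_map, surf_x))
    return c

def bench(f, *args, reps=20):
    f(*args).block_until_ready()  # compile once, outside the timed loop
    t0 = time.perf_counter()
    for _ in range(reps):
        f(*args).block_until_ready()
    return (time.perf_counter() - t0) / reps

t_scatter = bench(scatter_sum, surf_x, region_map)
t_scan = bench(scan_sum, surf_x, region_map)
print(f"scatter: {t_scatter * 1e6:.0f} us, scan: {t_scan * 1e6:.0f} us")
```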
