
# Surface mapping

When we create surface-based or pseudospectral simulations, we need to map data between representations using index vectors. This notebook explores which Jax code to use for that.


In [5]:
import tvb.datatypes.cortex 
cortex = tvb.datatypes.cortex.Cortex.from_file()
cortex

| Attribute | Value |
|---|---|
| Type | Cortex |
| coupling_strength [min, median, max] | [1, 1, 1] |
| coupling_strength dtype | float64 |
| coupling_strength shape | (1,) |
| gid | UUID('417dc1ae-573a-4954-a6ba-4eaa1680e83b') |
| local_connectivity | |
| region_mapping_data | RegionMapping gid: 861818b5-e88c-453e-a1ba-38c4223ce126 |
| title | Cortex gid: 417dc1ae-573a-4954-a6ba-4eaa1680e83b |


In [9]:
region_map = cortex.region_mapping_data.array_data

In [10]:
region_map.shape, region_map.min(), region_map.max()

((16384,), 0, 75)

In [11]:
import tvb.datatypes.connectivity
conn = tvb.datatypes.connectivity.Connectivity.from_file()
conn.weights.shape



(76, 76)

Move the region map to Jax:

In [12]:
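import jax.numpy as np  # note: `np` refers to jax.numpy for the rest of this notebook
import jax

region_map = np.array(region_map)  # copy the NumPy index array to a Jax array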



## From region vector to surface

This is the easiest operation: indexing a regional vector with `region_map` gives each vertex the value of the region it belongs to.

In [19]:
x = np.ones((76, ))  # one value per region

x[region_map].shape  # one value per vertex

(16384,)
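To make the gather concrete, here is a minimal sketch on a hypothetical five-vertex, three-region toy mapping (`toy_region_map` and `region_values` are made up for illustration, and the block uses the more common `jnp` alias for jax.numpy):

```python
import jax.numpy as jnp

# hypothetical toy mapping: 5 vertices assigned to 3 regions
toy_region_map = jnp.array([0, 0, 1, 2, 2])
region_values = jnp.array([10.0, 20.0, 30.0])

# fancy indexing gathers one regional value per vertex
surface_values = region_values[toy_region_map]
print(surface_values)  # [10. 10. 20. 30. 30.]
```

One caveat worth knowing: unlike NumPy, Jax does not raise on out-of-range indices in a gather; it clamps them, so a bad region map fails silently rather than loudly.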

## From surface to region vector

The inverse operation is more interesting: TVB computes each regional value as the average over the vertices in that region. Computing the average is easier if we first count the vertices in each region:
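Written out, with $m(v)$ standing for `region_map[v]`, $x_v$ for the surface value at vertex $v$ and $n_r$ for the number of vertices in region $r$ (this notation is introduced here for illustration), the regional average is

$$
\bar{x}_r = \frac{1}{n_r} \sum_{v\,:\,m(v) = r} x_v,
\qquad
n_r = \bigl|\{\, v : m(v) = r \,\}\bigr|
$$

so $n_r$ is what `np.bincount(region_map)` returns, and the sum is a scatter-add of $x_v$ into slot $m(v)$.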

In [23]:
vtx_count = np.bincount(region_map)  # number of vertices in each region
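The call above runs eagerly, where `bincount` can infer the number of bins from the data. Under `jax.jit` the output shape must be known at trace time, so a sketch of a jit-compatible version passes a static `length` (the toy map and `N_REGIONS = 4` are hypothetical):

```python
import jax
import jax.numpy as jnp

N_REGIONS = 4  # hypothetical region count for this toy example

# jnp.bincount cannot infer its output length from traced data,
# so under jit the number of bins is passed as a static `length`
@jax.jit
def count_vertices(region_map):
    return jnp.bincount(region_map, length=N_REGIONS)

toy_region_map = jnp.array([0, 0, 1, 3, 3, 3])
print(count_vertices(toy_region_map))  # [2 1 0 3]
```

Note that region 2 has no vertices and gets a count of 0, which matters as soon as we divide by these counts.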

Jax arrays are immutable, so there is no in-place `np.add.at`; however, the functional update `x.at[idx].add(v)` has the same unbuffered semantics (repeated indices accumulate) under a surprisingly similar name. We can check it against `vtx_count` like so:

In [25]:
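x = x.at[:].set(0)  # reset x to zeros (returns a new array, x is not modified in place)

x = x.at[region_map].add(1)  # add 1 per vertex to its region's slot; repeated indices accumulate

x == vtx_count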

Array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True,  True,  True], dtype=bool)
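The reason `.at[...].add` is the right tool is how it treats repeated indices. A small comparison, using `onp` for real NumPy so it does not clash with this notebook's `np` alias for jax.numpy:

```python
import numpy as onp
import jax.numpy as jnp

idx = onp.array([0, 0, 2])  # index 0 appears twice

# NumPy fancy-index `+=` is buffered: the repeated index is only counted once
a = onp.zeros(3)
a[idx] += 1

# np.add.at is unbuffered: every occurrence is accumulated
b = onp.zeros(3)
onp.add.at(b, idx, 1)

# the functional Jax update accumulates repeats too, and returns a new array
c = jnp.zeros(3).at[idx].add(1)

print(a, b, c)  # [1. 0. 1.] [2. 0. 1.] [2. 0. 1.]
```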

Now with a random vector over the full surface:

In [27]:
key = jax.random.PRNGKey(42)
surf_x = jax.random.normal(key, shape=(len(region_map), ))

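# scatter-add every vertex value into its region's slot: per-region sums
surf_x_roi_sum = np.zeros_like(x).at[region_map].add(surf_x)
# divide by the vertex counts for the mean (guarding against regions without vertices)
surf_x_roi_mean = surf_x_roi_sum / np.maximum(vtx_count, 1)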
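The same surface-to-region average can also be packaged as a jitted function. This sketch swaps in `jax.ops.segment_sum`, which as far as I know performs the same scatter-add internally; the toy data and the function name `surface_to_region_mean` are hypothetical:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="n_regions")
def surface_to_region_mean(surf_x, region_map, n_regions):
    # per-region sums; num_segments must be static under jit
    sums = jax.ops.segment_sum(surf_x, region_map, num_segments=n_regions)
    counts = jax.ops.segment_sum(jnp.ones_like(surf_x), region_map, num_segments=n_regions)
    # guard against regions with no vertices (0 / 0 would give NaN)
    return sums / jnp.maximum(counts, 1)

toy_region_map = jnp.array([0, 0, 1, 2, 2, 2])
toy_surf_x = jnp.array([1.0, 3.0, 5.0, 2.0, 4.0, 6.0])
roi_mean = surface_to_region_mean(toy_surf_x, toy_region_map, n_regions=3)
print(roi_mean)  # [2. 5. 4.]

# gathering the means back onto the surface gives a piecewise-constant field
smoothed = roi_mean[toy_region_map]
```

`jnp.bincount(region_map, weights=surf_x, length=n_regions)` would be another way to get the same per-region sums.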

"Yes, but did it work?" Check it against a sequential loop written with `jax.lax.scan`:

In [33]:
def surf_roi_mean_check(surf_x):
    def op(c, args):
        j, sx = args
        c = c.at[j].add(sx)
        return c, None
    c = np.zeros_like(x)
    c, _ = jax.lax.scan(op, c, (region_map, surf_x))
    return c

assert (surf_roi_mean_check(surf_x) == surf_x_roi_mean2).all()
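A second, independent reference is NumPy's own unbuffered `np.add.at`. This sketch runs the comparison on hypothetical random stand-ins with the same sizes as the TVB data (16384 vertices, 76 regions), accumulating the reference in float64:

```python
import numpy as onp
import jax.numpy as jnp

# hypothetical stand-ins for the TVB data: 16384 vertices in 76 regions
rng = onp.random.default_rng(0)
toy_region_map = rng.integers(0, 76, size=16384)
toy_surf_x = rng.standard_normal(16384).astype(onp.float32)

# reference: NumPy's unbuffered in-place scatter-add, accumulated in float64
ref = onp.zeros(76)
onp.add.at(ref, toy_region_map, toy_surf_x)

# Jax functional scatter-add (float32 by default)
out = jnp.zeros(76).at[toy_region_map].add(toy_surf_x)

# precision and summation order differ, so compare with a tolerance rather than ==
print(onp.allclose(ref, onp.asarray(out), atol=1e-3))  # True
```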

The obligatory micro-benchmark, the scan versus a single scatter-add:

In [34]:
%timeit surf_roi_mean_check(surf_x)
%timeit np.zeros_like(x).at[region_map].add(surf_x)

32.7 ms ± 2.38 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
1.1 ms ± 32.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
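Jax dispatches work asynchronously, so timings like these can partly measure dispatch rather than computation, and neither call above is jitted. A sketch of a fairer setup, assuming random stand-in data of the same size (`roi_sum` is a hypothetical name); no numbers are claimed here:

```python
import timeit

import jax
import jax.numpy as jnp

N_REGIONS = 76      # matches the TVB connectivity above
N_VERTICES = 16384  # matches the TVB cortex above

# hypothetical stand-in data with the same shapes as the notebook's arrays
region_map = jax.random.randint(jax.random.PRNGKey(0), (N_VERTICES,), 0, N_REGIONS)
surf_x = jax.random.normal(jax.random.PRNGKey(1), (N_VERTICES,))

@jax.jit
def roi_sum(surf_x, region_map):
    return jnp.zeros(N_REGIONS, surf_x.dtype).at[region_map].add(surf_x)

# the first call compiles; run it once so compilation is excluded from the timing
roi_sum(surf_x, region_map).block_until_ready()

# block_until_ready waits for the asynchronously dispatched result,
# so the timing covers the computation and not just the dispatch
n_calls = 100
t = timeit.timeit(lambda: roi_sum(surf_x, region_map).block_until_ready(), number=n_calls)
print(f"{t / n_calls * 1e6:.1f} µs per call")
```

In a notebook, the same idea is `%timeit roi_sum(surf_x, region_map).block_until_ready()`.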


Regardless of performance, I like the Jax syntax and semantics here; they are easier to reason about.