In [4]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import pickle

from jax import random
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d

### Bookkeeping
* All time units are in microseconds

### First, generate continuous arrival time spectrum and do binning

In [2]:
def gen_binned_waveform_jax(key: jax.random.PRNGKey, nphotons: int = 1, npmts: int = 128, arrival_window_width: float = 0.05, res: float = 0.001) -> jnp.ndarray:
    
    nbins = int(arrival_window_width / res)
    total_bins = nbins + int(arrival_window_width) # pad time axis

    key_photons, key_pmts = jax.random.split(key)

    # Generate random "continuous" arrival times
    arrival_times = jax.random.uniform(key_photons, shape=(nphotons,), minval=0.0, maxval=arrival_window_width)
    bin_indices = jnp.floor(arrival_times / res).astype(jnp.int32)
    bin_indices = jnp.clip(bin_indices, 0, nbins - 1)

    # Generate random PMT assignments
    pmt_indices = jax.random.randint(key_pmts, shape=(nphotons,), minval=0, maxval=npmts)

    # Convert 2D indices to flat indices
    flat_indices = pmt_indices * total_bins + bin_indices

    # Use `jax.numpy.bincount` to accumulate
    counts = segment_sum(data=jnp.ones(nphotons), segment_ids=flat_indices, num_segments=total_bins)
    waveform = counts.reshape(npmts, total_bins)

    return waveform


In [6]:
data = gen_binned_waveform_jax(nphotons=10000)
print(data.shape)

plt.imshow(data)

TypeError: gen_binned_waveform_jax() missing 1 required positional argument: 'key'

In [62]:
# making a convolution
'''
x = jnp.linspace(0, 10, 500)
y = jnp.sin(x) + 0.2 * random.normal(key, shape=(500,))

window = jnp.ones(10) / 10
y_smooth = jnp.convolve(y, window, mode='same')
'''

# smoothing convolution example
window = jnp.ones((10, 10)) / 10
data_smooth = jnp.convolve(data, window, mode='same')
plt.imshow(data_smooth)

TypeError: Error interpreting argument to <function convolve at 0x7f19969e3130> as an abstract array. The problematic value is of type <class 'torch.Tensor'> and was passed to the function at path a.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.