In [39]:
import sys

sys.path.append('../sht')
from sht import DirectSHT
import interp_funcs as interp

import numpy as np
from scipy.stats import mode
import jax.numpy as jnp
from jax import jit, vmap, device_put
import matplotlib.pyplot as plt
import utils

from functools import partial

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [40]:
# We can very quickly calculate Ylm values. As an example, compute them up to lmax=Nl
# Both constants are passed straight to the DirectSHT constructor below.
Nl  = 500    # maximum multipole (lmax), per the note above
Nx  = 1024   # second DirectSHT argument; presumably the number of x sample nodes -- TODO confirm

sht = DirectSHT(Nl,Nx)
# Sample nodes exposed by the transform object; later cells pass them to
# np.digitize as the bin edges for the data points.
x_samples = sht.x

In [41]:
def get_points(Nrandoms, seed=None):
    '''
    Randomly generate data points and return them sorted by x = cos(theta)

    cos(theta) is drawn uniformly between sin(-30 deg) and sin(+30 deg), i.e.
    in [-0.5, 0.5) (a +/-30 degree band in latitude if theta is the colatitude),
    phi is drawn uniformly in [0, 2*pi) and every point gets unit weight.

    :param Nrandoms: int. Number of points to generate
    :param seed: optional seed, anything accepted by np.random.default_rng.
                 None (default) keeps the old behaviour of a fresh,
                 non-reproducible stream on every call; pass an int to get the
                 same points after a kernel restart.
    :return: tuple (x, weights, phi) of 1D numpy arrays of length Nrandoms,
             all ordered by increasing x
    '''
    rng = np.random.default_rng(seed)
    cmin, cmax = np.sin(np.radians(-30)), np.sin(np.radians(30.))
    thetas = np.arccos(rng.uniform(low=cmin, high=cmax, size=Nrandoms))
    phis = rng.uniform(low=0.0, high=2*np.pi, size=Nrandoms)
    wts = np.ones(Nrandoms)

    x = np.cos(thetas)

    # Sort once by x and apply the same permutation to every per-point array
    # so that (x, weight, phi) triplets stay aligned.
    sorted_idx = np.argsort(x)
    return x[sorted_idx], wts[sorted_idx], phis[sorted_idx]

In [42]:
Nrandoms = int(2e7)
x_data_sorted, w_i_sorted, phi_data_sorted = get_points(Nrandoms)

# Keep positive only for now.
# The selection is built ONCE from the unfiltered x. Recomputing it after
# x_data_sorted has been overwritten yields indices 0..k-1, which would pick
# the weights/phis of the first k sorted points instead of the points kept.
keep = x_data_sorted >= 0
x_data_sorted = x_data_sorted[keep]
w_i_sorted = w_i_sorted[keep]
phi_data_sorted = phi_data_sorted[keep]

# Index of the x_samples interval each point falls into, and the point's
# offset from the left node of that interval.
spline_idx = np.digitize(x_data_sorted, x_samples) - 1
t = x_data_sorted - x_samples[spline_idx]
# NOTE(review): t is an absolute offset and is not divided by the node spacing;
# the cubic basis polynomials built from t in a later cell look like unit-interval
# Hermite functions -- confirm whether a normalisation by dx is intended.

# Put things in the GPU (device_put places them on the default JAX device)
x_data_sorted_jax = device_put(x_data_sorted)
spline_idx_jax = device_put(spline_idx)

In [94]:
def find_transitions(arr):
    '''
    Locate the start index of every run of equal values in an array of bin labels
    :param arr: 1D numpy array indicating what bin each element belongs to (must be sorted)
    :return: 1D numpy array holding 0 followed by every index at which arr differs
             from the element before it
    '''
    # A non-zero step between neighbours sits on the last element of a run,
    # so the next run starts one position later.
    steps = np.diff(arr)
    run_starts = steps.nonzero()[0] + 1
    # Lead with index 0 so the first run has an explicit start as well
    return np.concatenate((np.zeros(1, dtype=run_starts.dtype), run_starts))

def reshape_array(data, transitions, bin_num, bin_len):
    '''
    Reshape a 1D array into a 2D array to facilitate binning
    :param data: 1D numpy array of data to be binned
    :param transitions: 1D numpy array of indices where the bin changes (includes 0).
                        Entry i is the start of bin i; a trailing end index is
                        optional -- if it is absent the last bin runs to len(data)
    :param bin_num: int. Number of bins where there is data
    :param bin_len: int. Maximum number of points in a bin
    :return: 2D numpy array of shape (bin_num, bin_len) with row i holding the
             data of bin i, zero padded in bins with fewer points
    :raises ValueError: (from numpy) if a bin holds more than bin_len points
    '''
    data_reshaped = np.zeros((bin_num, bin_len))
    n_data = len(data)
    n_edges = len(transitions)
    # Visit EVERY bin: transitions only lists bin starts, so the last bin has
    # no following entry and must be closed with the end of the data instead.
    for i in range(bin_num):
        start = transitions[i]
        stop = transitions[i+1] if i + 1 < n_edges else n_data
        fill_in = data[start:stop]
        data_reshaped[i,:len(fill_in)] = fill_in
    return data_reshaped

# First, we find which bins are populated and how many points each one holds.
# A single np.unique pass gives both; its counts are a plain integer array, so
# bin_len does not depend on the SciPy-version-specific shape of mode(...).count.
occupied_bins, bin_counts = np.unique(spline_idx, return_counts=True)
bin_num = len(occupied_bins)
# Then, we find the maximum number of points in a bin (the padded row length)
bin_len = int(bin_counts.max())
# Find the indices of transitions between bins
transitions = find_transitions(spline_idx)
# Reshape the inputs into a 2D array for fast binning.
# The four polynomials in t are each multiplied by the per-point weight.
reshaped_inputs = utils.reshape_vs_array([w_i_sorted * input_ for input_ in
                   [(2 * t + 1) * (1 - t) ** 2, t * (1 - t) ** 2, t ** 2 * (3 - 2 * t), t ** 2 * (t - 1)]], transitions, bin_num, bin_len)
# Make a mask to discard spurious zeros: 1 where a slot holds a real point, 0 in the padding
mask = reshape_array(np.ones_like(phi_data_sorted),transitions,bin_num,bin_len)
# Mask and put in GPU memory
reshaped_inputs = device_put(mask*reshaped_inputs)
reshaped_phi_data = device_put(mask*reshape_array(phi_data_sorted,transitions,bin_num,bin_len))

In [98]:
def get_vs(mmax, phi_data_reshaped, reshaped_inputs):
    '''
    Collect the output of interp.get_vs_at_m for every m from 0 to mmax
    :param mmax: int. Largest m to evaluate (inclusive)
    :param phi_data_reshaped: 2D array of phi values; its first axis sets the
                              length of the last axis of the outputs
    :param reshaped_inputs: array passed through unchanged to interp.get_vs_at_m
    :return: tuple (vs_r, vs_i) of numpy arrays of shape
             (mmax+1, 4, phi_data_reshaped.shape[0]); presumably the real and
             imaginary parts -- confirm against interp.get_vs_at_m
    '''
    n_bins = phi_data_reshaped.shape[0]
    out_shape = (mmax + 1, 4, n_bins)
    vs_r = np.zeros(out_shape)
    vs_i = np.zeros(out_shape)
    # One call per m; each result pair fills the m-th slab of the two outputs
    for m in range(mmax + 1):
        slab_r, slab_i = interp.get_vs_at_m(m, phi_data_reshaped, reshaped_inputs)
        vs_r[m, :, :] = slab_r
        vs_i[m, :, :] = slab_i
    return vs_r, vs_i

# Evaluate for m = 0..500. Leaving the call as the cell's last expression
# echoes the returned (vs_r, vs_i) tuple as the cell output.
get_vs(500, reshaped_phi_data, reshaped_inputs)

(array([[[ 1.75389883e+04,  1.77559883e+04,  1.76559883e+04, ...,
           1.77189883e+04,  1.76319883e+04,  0.00000000e+00],
         [ 7.72772408e+00,  7.80593872e+00,  7.79531145e+00, ...,
           7.77301884e+00,  7.76347637e+00,  0.00000000e+00],
         [ 1.35965124e-02,  1.37403114e-02,  1.37607269e-02, ...,
           1.36844758e-02,  1.36700673e-02,  0.00000000e+00],
         [-4.53117490e-03, -4.57909610e-03, -4.58589708e-03, ...,
          -4.56048874e-03, -4.55568731e-03,  0.00000000e+00]],
 
        [[ 6.72482300e+00,  4.71902237e+01,  1.37736755e+02, ...,
           1.14091255e+02,  6.23192635e+01,  0.00000000e+00],
         [ 1.42665450e-02,  2.61438042e-02,  8.24133381e-02, ...,
           7.31289759e-02,  5.66772223e-02,  0.00000000e+00],
         [ 2.41667585e-05,  7.83657451e-05,  1.85082943e-04, ...,
           1.63100092e-04,  1.13528455e-04,  0.00000000e+00],
         [-8.05437230e-06, -2.61137538e-05, -6.16783436e-05, ...,
          -5.43536808e-05, -3.78344

In [100]:

# Same call as before, but keep the two returned arrays instead of echoing them
vsr, vsi = get_vs(500, reshaped_phi_data, reshaped_inputs)


In [89]:
%%timeit
# Time the interp_funcs version of get_vs with the same arguments used for the notebook-local get_vs
interp.get_vs(500, reshaped_phi_data, reshaped_inputs) 

1.35 s ± 26.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [95]:
import jax
# Print the jaxpr that JAX traces for get_vs_mapped with these inputs.
# NOTE(review): neither `get_vs_mapped` nor `ms` is defined in any cell visible
# here -- presumably they come from another cell or leftover kernel state;
# confirm this cell still runs after Restart Kernel -> Run All.
print(jax.make_jaxpr(get_vs_mapped)(ms, reshaped_phi_data, reshaped_inputs))

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:i32[10][39m b[35m:f32[569,18012][39m c[35m:f32[4,569,18012][39m. [34m[22m[1mlet
    [39m[22m[22md[35m:f32[10][39m = convert_element_type[new_dtype=float32 weak_type=False] a
    e[35m:f32[1,569,18012][39m = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 569, 18012)
    ] b
    f[35m:f32[10,1,1][39m = broadcast_in_dim[broadcast_dimensions=(0,) shape=(10, 1, 1)] d
    g[35m:f32[10,569,18012][39m = mul f e
    h[35m:f32[10,569,18012][39m = cos g
    i[35m:f32[10][39m = convert_element_type[new_dtype=float32 weak_type=False] a
    j[35m:f32[1,569,18012][39m = broadcast_in_dim[
      broadcast_dimensions=(1, 2)
      shape=(1, 569, 18012)
    ] b
    k[35m:f32[10,1,1][39m = broadcast_in_dim[broadcast_dimensions=(0,) shape=(10, 1, 1)] i
    l[35m:f32[10,569,18012][39m = mul k j
    m[35m:f32[10,569,18012][39m = sin l
    n[35m:f32[10,1,569,18012][39m = broadcast_in_dim[
      broadcast_dimens