In [2]:
import sys

# Make the local direct_SHT checkout importable. Set the DIRECT_SHT_PATH environment
# variable to point elsewhere instead of editing this hard-coded default path.
import os
sys.path.append(os.environ.get('DIRECT_SHT_PATH',
                               '/Users/antonbaleatolizancos/Projects/direct_SHT/direct_sht/sht'))
from sht import DirectSHT

import numpy as np
import numba as nb
import matplotlib.pyplot as plt
import jax.numpy as jnp
from jax import jit, vmap, device_put
import time

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
# Build the DirectSHT object: Ylm values can then be computed very quickly, here up to lmax = Nl.
# Nx is the second constructor argument; presumably the size of the x grid exposed as sht.x
# (TODO confirm against DirectSHT).
Nl, Nx = 500, 1024

sht = DirectSHT(Nl, Nx)
x_samples = sht.x  # grid used below to locate the data points

In [4]:
def get_points(Nrandoms, rng=None, theta_range=(np.pi/4, np.pi/2.), phi_range=(0, np.pi),
               weight_range=(1, 1.5)):
    """Randomly generate weighted data points, returned sorted by x = cos(theta).

    Parameters
    ----------
    Nrandoms : int
        Number of points to draw.
    rng : None, int, numpy.random.Generator, optional
        Source of randomness. ``None`` (default) uses the legacy global ``np.random``
        state, as before. Anything else is passed to ``np.random.default_rng``, so an
        int seed gives reproducible draws and an existing Generator is used as-is.
    theta_range, phi_range, weight_range : (low, high) tuples, optional
        Half-open uniform sampling ranges for theta, phi and the weights.

    Returns
    -------
    x_data_sorted, w_i_sorted, phi_data_sorted : ndarray
        cos(theta) in ascending order. The weights and phis are permuted the same way,
        so index i of each output refers to the same point.
    """
    sampler = np.random if rng is None else np.random.default_rng(rng)
    # Draw order (theta, phi, weight) is kept fixed so a given seed always gives the same points.
    thetas = sampler.uniform(theta_range[0], theta_range[1], Nrandoms)
    phis = sampler.uniform(phi_range[0], phi_range[1], Nrandoms)
    wts = sampler.uniform(weight_range[0], weight_range[1], Nrandoms)

    x = np.cos(thetas)

    sorted_idx = np.argsort(x)
    return x[sorted_idx], wts[sorted_idx], phis[sorted_idx]

In [5]:
# Draw the random data points and locate each one on the x_samples grid.
Nrandoms = int(1e7)
x_data_sorted, w_i_sorted, phi_data_sorted = get_points(Nrandoms)

# np.digitize(x, x_samples) - 1 is the left node of the interval containing x (x_samples is
# assumed increasing). A point below the grid would get index -1, which silently wraps to the
# last node, so reject out-of-range data explicitly.
if x_data_sorted.size and (x_data_sorted.min() < x_samples[0] or x_data_sorted.max() > x_samples[-1]):
    raise ValueError("data points fall outside the range spanned by x_samples")
# A point lying exactly on the last node belongs to the final interval (with t == 1).
spline_idx = np.clip(np.digitize(x_data_sorted, x_samples) - 1, 0, len(x_samples) - 2)
# Local coordinate within [x_samples[i], x_samples[i+1]], rescaled to [0, 1] as required by
# the cubic Hermite basis functions built from t later on.
# NOTE(review): with t normalised, the derivative-weighted basis terms need a factor of
# interval_width wherever they are combined with dY/dx. Confirm against the consumer.
interval_width = x_samples[spline_idx + 1] - x_samples[spline_idx]
t = (x_data_sorted - x_samples[spline_idx]) / interval_width

# Put things in the GPU
x_data_sorted_jax = device_put(x_data_sorted)
spline_idx_jax = device_put(spline_idx)

Metal device set to: Apple M2 Max




In [47]:
from scipy.stats import mode

def find_transitions(arr):
    """Return the start index of every run of equal consecutive values in ``arr``.

    The first entry is always 0. Each later entry is a position whose value differs
    from the one before it.
    """
    run_starts = np.nonzero(np.diff(arr))[0] + 1
    return np.concatenate(([0], run_starts))

def reshape_array(data, transitions, spline_idx, bin_num, bin_len):
    """Scatter runs of ``data`` into the rows of a zero-padded 2D array.

    Parameters
    ----------
    data : 1D array
        Values grouped into contiguous runs (one run per bin).
    transitions : 1D int array
        Start index of each run, as returned by ``find_transitions``. The last run
        extends to the end of ``data``.
    spline_idx : unused
        Kept only so existing callers keep working.
    bin_num, bin_len : int
        Output shape: one row per run; each row is padded with zeros up to ``bin_len``.

    Returns
    -------
    ndarray of shape (bin_num, bin_len), float64.
    """
    # Close the final run at len(data); otherwise the last populated bin would be dropped.
    bounds = np.append(transitions, len(data))
    data_reshaped = np.zeros((bin_num, bin_len))
    for row, (start, stop) in enumerate(zip(bounds[:bin_num], bounds[1:bin_num + 1])):
        fill_in = data[start:stop]
        data_reshaped[row, :len(fill_in)] = fill_in
    return data_reshaped

# Group the sorted points by spline interval. x_data_sorted is sorted and np.digitize is
# monotone, so spline_idx is sorted and every populated interval is one contiguous run.
# `transitions` holds the start index of each run.
transitions = find_transitions(spline_idx)
bin_counts = np.diff(np.append(transitions, len(spline_idx)))  # points per populated bin
# Number of populated bins (one per run) and the largest number of points in any bin.
# Computed directly instead of via scipy.stats.mode, whose `.count` is an array on
# scipy < 1.11 and would break the np.zeros shape below.
bin_num = len(transitions)
bin_len = int(bin_counts.max())
# Occupancy mask: True for slots holding a real data point, False for zero padding.
# It is built from the counts rather than the data values, so a genuine point whose first
# basis function vanishes (t == 1) is not mistaken for padding.
mask_np = np.arange(bin_len)[None, :] < bin_counts[:, None]
# Weighted cubic Hermite basis functions h00, h10, h01, h11 at each point, scattered into
# zero-padded (bin_num, bin_len) arrays for fast per-bin reductions.
hermite_basis = [(2*t+1)*(1-t)**2, t*(1-t)**2, t**2*(3-2*t), t**2*(t-1)]
reshaped_inputs = [device_put(reshape_array(w_i_sorted*input_, transitions, spline_idx, bin_num, bin_len))
                   for input_ in hermite_basis]
reshaped_phi_data = device_put(reshape_array(phi_data_sorted, transitions, spline_idx, bin_num, bin_len))
# Padding is already exactly zero in every reshaped array, so no extra masking is needed.
# `mask` is kept on the device for any later cells that use it.
mask = device_put(mask_np)

In [50]:
from functools import partial
@jit
def collapse(arr):
    """Jitted reduction: sum ``arr`` along axis 1, leaving one value per row."""
    row_totals = arr.sum(axis=1)
    return row_totals

@partial(jit, static_argnames=['m'])
def cosmphi(phi, m):
    """Elementwise cos(m * phi).

    ``m`` is a static argument: it must be hashable, and each new value of ``m``
    triggers a fresh compilation.
    """
    scaled_phi = phi * m
    return jnp.cos(scaled_phi)

@partial(jit, static_argnames=['m'])
def sinmphi(phi, m):
    """Elementwise sin(m * phi).

    ``m`` is a static argument: it must be hashable, and each new value of ``m``
    triggers a fresh compilation.
    """
    scaled_phi = phi * m
    return jnp.sin(scaled_phi)

# Deliberately not jitted with a static m: get_vs_mapped vmaps over m, so m is traced here.
def get_vs(m, phi_data_sorted, reshaped_inputs):
    """Weight each reshaped input by cos(m*phi) and sin(m*phi) and reduce with ``collapse``.

    Parameters
    ----------
    m : scalar
        Azimuthal multiplier (a traced scalar when called through ``vmap``).
    phi_data_sorted : array
        Angles; each entry of ``reshaped_inputs`` is multiplied elementwise by
        cos/sin of ``m * phi_data_sorted``, so the shapes must broadcast.
    reshaped_inputs : sequence of arrays
        Per-point weights to project.

    Returns
    -------
    (vs_real, vs_imag) : tuple of two lists
        One ``collapse``d array per element of ``reshaped_inputs``: first the cosine
        projections, then the sine projections.
    """
    angle = m * phi_data_sorted
    cos_part = jnp.cos(angle)
    sin_part = jnp.sin(angle)
    vs_real = []
    for weights in reshaped_inputs:
        vs_real.append(collapse(weights * cos_part))
    vs_imag = []
    for weights in reshaped_inputs:
        vs_imag.append(collapse(weights * sin_part))
    return vs_real, vs_imag

# Vectorise get_vs over a batch of m values: map the leading axis of the first argument (m)
# and pass phi_data_sorted and reshaped_inputs unchanged to every m.
get_vs_mapped = vmap(
    get_vs,
    in_axes=(0,      # m: batched along its leading axis
             None,   # phi_data_sorted: shared across all m
             None),  # reshaped_inputs: shared across all m
)