In [1]:
import jax.numpy as jnp
from jax import grad, jit, random
from jax.example_libraries import optimizers

## Data preprocessing

In [2]:
# get data
from ler.utils import append_json, load_json
import numpy as np
unlensed_params = load_json("jointL1.json")
snr = np.array(unlensed_params['L1'])

In [3]:
# let's generate IMRPhenomD (spinless) interpolartor
from gwsnr import GWSNR
gwsnr = GWSNR(gwsnr_verbose=False)

psds not given. Choosing bilby's default psds
Interpolator will be loaded for L1 detector from ./interpolator_pickle/L1/halfSNR_dict_0.pickle
Interpolator will be loaded for H1 detector from ./interpolator_pickle/H1/halfSNR_dict_0.pickle
Interpolator will be loaded for V1 detector from ./interpolator_pickle/V1/halfSNR_dict_0.pickle


In [4]:
# get half_snr 
# open pickle file
import pickle
with open('./interpolator_pickle/L1/halfSNR_dict_0.pickle', 'rb') as f:
    half_snr = pickle.load(f)

In [5]:
from gwsnr import antenna_response_array, cubic_spline_interpolator2d

def input_output(idx, params):

    det_idx = 0 # L1

    mass_1 = np.array(params['mass_1'])[idx]
    mass_2 = np.array(params['mass_2'])[idx]
    luminosity_distance = np.array(params['luminosity_distance'])[idx]
    theta_jn = np.array(params['theta_jn'])[idx]
    psi = np.array(params['psi'])[idx]
    geocent_time = np.array(params['geocent_time'])[idx]
    ra = np.array(params['ra'])[idx]
    dec = np.array(params['dec'])[idx]
    
    detector_tensor = gwsnr.detector_tensor_list
    snr_halfscaled = np.array(gwsnr.snr_halfsacaled_list)
    ratio_arr = gwsnr.ratio_arr
    mtot_arr = gwsnr.mtot_arr
    
    size = len(mass_1)
    len_ = len(detector_tensor)
    mtot = mass_1 + mass_2
    ratio = mass_2 / mass_1
    # get array of antenna response
    Fp, Fc = antenna_response_array(ra, dec, geocent_time, psi, detector_tensor)

    Mc = ((mass_1 * mass_2) ** (3 / 5)) / ((mass_1 + mass_2) ** (1 / 5))
    eta = mass_1 * mass_2/(mass_1 + mass_2)**2.
    A1 = Mc ** (5.0 / 6.0)
    ci_2 = np.cos(theta_jn) ** 2
    ci_param = ((1 + np.cos(theta_jn) ** 2) / 2) ** 2
    
    size = len(mass_1)
    snr_half_ = np.zeros((len_,size))
    d_eff = np.zeros((len_,size))

    # loop over the detectors
    for j in range(len_):
        # loop over the parameter points
        for i in range(size):
            snr_half_coeff = snr_halfscaled[j]
            snr_half_[j,i] = cubic_spline_interpolator2d(mtot[i], ratio[i], snr_half_coeff, mtot_arr, ratio_arr)
            d_eff[j,i] =luminosity_distance[i] / np.sqrt(
                    Fp[j,i]**2 * ci_param[i] + Fc[j,i]**2 * ci_2[i]
                )

    #amp0
    amp0 =  A1 / d_eff

    # get spin parameters
    a_1 = np.array(params['a_1'])[idx]
    a_2 = np.array(params['a_2'])[idx]
    tilt_1 = np.array(params['tilt_1'])[idx]
    tilt_2 = np.array(params['tilt_2'])[idx]
    phi_12 = np.array(params['phi_12'])[idx]
    phi_jl = np.array(params['phi_jl'])[idx]

    # input data
    # X = np.vstack([L1, amp0, Mc, eta, theta_jn, a_1, a_2, tilt_1, tilt_2, phi_12, phi_jl]).T
    X = np.vstack([snr_half_[det_idx], amp0[det_idx], eta, a_1, a_2, tilt_1, tilt_2, phi_12, phi_jl]).T

    # output data
    # get L1 snr for y train 
    y = snr[idx]

    return(X,y)

In [6]:
# training set
unlensed_params.keys()

dict_keys(['zs', 'geocent_time', 'ra', 'dec', 'phase', 'psi', 'theta_jn', 'a_1', 'a_2', 'tilt_1', 'tilt_2', 'phi_12', 'phi_jl', 'luminosity_distance', 'mass_1_source', 'mass_2_source', 'mass_1', 'mass_2', 'L1', 'H1', 'V1', 'optimal_snr_net'])

In [7]:
len(snr)

114180

In [8]:
snr_min = 6.
snr_max = 10.
bool_ = (snr>snr_min) & (snr<snr_max) 

len_ = len(snr)
idx_train_ = np.arange(len_)
# randomize the train set
idx_train_ = np.random.choice(idx_train_, len(idx_train_), replace=False)
idx_train = idx_train_[:-10000]
idx_test = idx_train_[-10000:]

In [9]:
X_train, y_train = input_output(idx_train, unlensed_params)
X_test, y_test = input_output(idx_test, unlensed_params)

In [10]:
# feature scaling
from sklearn.preprocessing import StandardScaler
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)

In [11]:
np.shape(X_train)

(104180, 9)

## Now back to ANN

In [22]:
def init_network_params(rng, input_dim, hidden_dim1, hidden_dim2, hidden_dim3, output_dim):
    """Initialize parameters of a fully-connected neural network."""
    rng, *layer_rngs = random.split(rng, 5)
    params = []  # List to store parameters for each layer

    # Loop through each layer
    for d_in, d_out in zip([input_dim, hidden_dim1, hidden_dim2, hidden_dim3], [hidden_dim1, hidden_dim2, hidden_dim3, output_dim]):
        # Split the random key for generating weights and biases
        w_key, b_key = random.split(rng)
        
        # Initialize weights with normal distribution
        w = random.normal(w_key, (d_in, d_out))
        
        # Initialize biases with normal distribution
        b = random.normal(b_key, (d_out,))
        
        # Append the weights and biases as a tuple to the params list
        params.append((w, b))
        
        # Update the RNG for the next layer
        rng = layer_rngs.pop(0)
        
    return params


* params: This is a list containing the parameters (weights and biases) for each layer of the neural network. Each element of the list is a tuple (w, b) where w is the weight matrix and b is the bias vector for a particular layer.

* input_data: This is the input data batch to the neural network.

In [23]:
@jit
def neural_network(params, input_data):
    """Feedforward function for a fully-connected neural network."""
    activations = input_data
    for w, b in params[:-1]:
        activations = jnp.tanh(jnp.dot(activations, w) + b)
    w, b = params[-1]
    output = jnp.dot(activations, w) + b
    return output

In [24]:
# Define Loss Function
@jit
def loss(params, inputs, targets):
    predictions = neural_network(params, inputs)
    return jnp.mean((predictions - targets) ** 2)


In [26]:
# Initialize Parameters and Optimize
rng = random.PRNGKey(0)
input_dim = 9
hidden_dim1 = 64
hidden_dim2 = 64
hidden_dim3 = 64
output_dim = 1

# Initialize parameters
params = init_network_params(rng, input_dim, hidden_dim1, hidden_dim2, hidden_dim3, output_dim)

# Initialize optimizer
step_size = 1e-3
opt_init, opt_update, get_params = optimizers.adam(step_size)
opt_state = opt_init(params)

In [27]:
@jit
def step(params, opt_state, inputs, targets):
    """Perform a single optimization step."""
    grads = grad(loss)(params, inputs, targets)
    new_opt_state = opt_update(0, grads, opt_state)
    new_params = get_params(new_opt_state)
    return new_params, new_opt_state

In [28]:
inputs = X_train[0:10000]
targets = y_train[0:10000]

In [29]:
num_epochs = 1000
batch_size = 32
num_batches = len(inputs) // batch_size

for epoch in range(num_epochs):
    rng, *batch_rngs = random.split(rng, num_batches + 1)
    for batch_idx in range(num_batches):
        batch_rng = batch_rngs[batch_idx]
        batch_indices = random.choice(batch_rng, len(inputs), (batch_size,), replace=False)
        batch_inputs = inputs[batch_indices]
        batch_targets = targets[batch_indices]
        params, opt_state = step(params, opt_state, batch_inputs, batch_targets)
    if epoch % 100 == 0:
        print("Epoch {} - Loss: {}".format(epoch, loss(params, inputs, targets)))

Epoch 0 - Loss: 44.53794860839844


KeyboardInterrupt: 