In [2]:
import os
import time
import pkg_resources
import math
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from scipy import constants
import torch.nn as nn
import torch.nn.functional as F

def coalescence_time(f, chirp_mass):
    '''Time remaining until coalescence for a binary emitting at frequency `f`.

    Leading-order chirp relation tau = 5 * (8*pi*f)**(-8/3) * M_c**(-5/3);
    this is the inverse of `freq`. Inputs use the same units as h_func_f2
    (masses and times in seconds).'''
    frequency_term = (8 * np.pi * f) ** (-8 / 3)
    return 5 * frequency_term * chirp_mass ** (-5 / 3)

def freq(t, tc, chirp_mass):
    '''Instantaneous GW frequency at time `t` for a binary coalescing at `tc`.

    Inverse of `coalescence_time`: f = ((tc - t)/5)**(-3/8) / (8*pi*M_c**(5/8)).'''
    scaled_time_left = (tc - t) / 5
    denominator = 8 * np.pi * chirp_mass ** (5 / 8)
    return scaled_time_left ** (-3 / 8) / denominator

def phase(tc, t, ch_mass, phi_c):
    '''Leading-order (Newtonian) inspiral phase as a function of time.

    phi(t) = phi_c - 2 * ((tc - t) / (5 * M_c))**(5/8), where `phi_c` is the
    phase at coalescence time `tc` and `ch_mass` is the chirp mass.'''
    scaled_time = (tc - t) / (5 * ch_mass)
    return phi_c - 2 * scaled_time ** (5 / 8)

def phi_t_Lbar(time, phi_0):
    '''LISA orbital phase: starts at `phi_0` and advances 2*pi per year.'''
    orbit_angle = 2 * np.pi * time / constants.year
    return phi_0 + orbit_angle

def theta_s_t(theta_sbar, phi_t_Lbar, phi_sbar):
    '''Source polar angle in the (rotating) detector frame.

    cos(theta_S) = 0.5*cos(theta_sbar)
                   - (sqrt(3)/2)*sin(theta_sbar)*cos(phi_t_Lbar - phi_sbar)'''
    axial_part = 0.5 * np.cos(theta_sbar)
    tilted_part = np.sqrt(3) / 2 * np.sin(theta_sbar) * np.cos(phi_t_Lbar - phi_sbar)
    return np.arccos(axial_part - tilted_part)

def alpha_i_t(i, t, alpha_0):
    '''Orientation angle of LISA arm `i` (1-based) at time `t`.

    The arms rotate once per year; arm i is offset by (i-1)*pi/3 from arm 1,
    with an additional fixed offset of -pi/12 and initial angle `alpha_0`.'''
    yearly_rotation = 2 * np.pi * t / constants.year
    arm_offset = (i - 1) * np.pi / 3
    return yearly_rotation - np.pi / 12 - arm_offset + alpha_0

def phi_s_t(theta_sbar, phi_t_bar, phi_sbar, alpha1):
    '''Source azimuth in the (unbarred) detector frame.

    phi_S = alpha1 + pi/12 + arctan(num / den) with
      num = sqrt(3)*cos(theta_sbar) + sin(theta_sbar)*cos(phi_t_bar - phi_sbar)
      den = 2*sin(theta_sbar)*sin(phi_t_bar - phi_sbar)
    `alpha1` is the arm-1 orientation from alpha_i_t.'''
    azimuth_offset = phi_t_bar - phi_sbar
    numerator = np.sqrt(3)*np.cos(theta_sbar) + np.sin(theta_sbar)*np.cos(azimuth_offset)
    denominator = 2*np.sin(theta_sbar)*np.sin(azimuth_offset)
    return alpha1 + np.pi/12 + np.arctan(numerator / denominator)

def psi_s_t(theta_Lbar, phi_Lbar, theta_sbar, phi_sbar, phi_t_Lbar, theta_s_t_):
    '''Polarisation angle psi_S(t) of the source in the detector frame.

    tan(psi) = (L.z - (L.n) * cos(theta_s_t_)) / (N.(L x z)), where L is the
    orbital angular-momentum direction (theta_Lbar, phi_Lbar), n the source
    direction (theta_sbar, phi_sbar), z the detector axis at LISA orbital
    phase phi_t_Lbar, and theta_s_t_ the detector-frame source polar angle.

    Side effect: stores L.n (cosine of the inclination) in the module-level
    global `cos_i`, for callers that read it after this call.

    Returns np.arctan(tan_psi), an angle in [-pi/2, pi/2].'''
    global cos_i

    half_root3 = np.sqrt(3)/2
    sin_tL, cos_tL = np.sin(theta_Lbar), np.cos(theta_Lbar)
    sin_ts, cos_ts = np.sin(theta_sbar), np.cos(theta_sbar)

    l_dot_z = 0.5 * cos_tL - half_root3 * sin_tL * np.cos(phi_t_Lbar - phi_Lbar)
    l_dot_n = cos_tL*cos_ts + sin_tL*sin_ts*np.cos(phi_Lbar - phi_sbar)
    cos_i = l_dot_n

    # N.(L x z), split into its three terms (evaluated in the original order)
    in_plane = 0.5*sin_tL*sin_ts*np.sin(phi_Lbar - phi_sbar)
    cos_part = half_root3*np.cos(phi_t_Lbar)*(cos_tL*sin_ts*np.sin(phi_sbar) - cos_ts*sin_tL*np.sin(phi_Lbar))
    sin_part = half_root3*np.sin(phi_t_Lbar)*(cos_ts*sin_tL*np.cos(phi_Lbar) - cos_tL*sin_ts*np.cos(phi_sbar))
    n_dot_l_cross_z = in_plane - cos_part - sin_part

    tan_psi = (l_dot_z - l_dot_n * np.cos(theta_s_t_)) / n_dot_l_cross_z
    return np.arctan(tan_psi)

def doppler_phase(f, theta_sbar, phi, phi_sbar):
    '''Doppler phase from LISA's orbital motion.

    2*pi*f*R*sin(theta_sbar)*cos(phi - phi_sbar), where R = 1 AU expressed as
    a light-travel time (seconds) and `phi` is LISA's orbital phase.'''
    light_travel_time = constants.astronomical_unit / constants.c
    return 2 * np.pi * f * light_travel_time * np.sin(theta_sbar) * np.cos(phi - phi_sbar)

def F_plus(theta_s, phi_s, psi_s):
    '''Plus-polarisation detector beam-pattern coefficient for a source at
    detector-frame angles (theta_s, phi_s) with polarisation angle psi_s.'''
    cos_theta = np.cos(theta_s)
    aligned = 0.5 * (1 + cos_theta**2) * np.cos(2*phi_s) * np.cos(2*psi_s)
    rotated = cos_theta * np.sin(2*phi_s) * np.sin(2*psi_s)
    return aligned - rotated

def F_cross(theta_s, phi_s, psi_s):
    '''Cross-polarisation detector beam-pattern coefficient for a source at
    detector-frame angles (theta_s, phi_s) with polarisation angle psi_s.'''
    cos_theta = np.cos(theta_s)
    aligned = 0.5 * (1+cos_theta**2) * np.cos(2*phi_s) * np.sin(2*psi_s)
    rotated = cos_theta * np.sin(2*phi_s) * np.cos(2*psi_s)
    return aligned + rotated

def phi_P_I_t(cos_i, F_plus, F_cross): 
    '''Polarisation phase: angle of (2*cos_i*F_cross, (1 + cos_i**2)*F_plus).

    np.arctan2 is used so the result lands in the correct quadrant
    (range [-pi, pi]).'''
    numerator = 2 * cos_i * F_cross
    denominator = (1 + cos_i**2) * F_plus
    return np.arctan2(numerator, denominator)

def A_t(M_c, f, D_L):
    '''Waveform amplitude 2 * M_c**(5/3) * (pi*f)**(2/3) / D_L.'''
    mass_factor = M_c ** (5/3)
    frequency_factor = (np.pi * f) ** (2/3)
    return 2 * mass_factor * frequency_factor / D_L

def A_p_t(F_plus, F_cross, cos_i):
    '''Polarisation amplitude:
    (sqrt(3)/2) * sqrt((1 + cos_i**2)**2 * F_plus**2 + 4 * cos_i**2 * F_cross**2).'''
    plus_power = (1 + cos_i**2)**2 * F_plus**2
    cross_power = 4 * cos_i**2 * F_cross**2
    return np.sqrt(3) / 2 * (plus_power + cross_power) ** (1/2)

def h_t(A_t, A_p_t, phase, phi_P_I_t, doppler_phase):
    '''Time-domain strain: amplitudes times the cosine of the summed
    orbital, polarisation and Doppler phases (all passed in as values).'''
    total_phase = phase + phi_P_I_t + doppler_phase
    return A_t * A_p_t * np.cos(total_phase)

def phi_f(f, phi_c, chirp_mass, tc):
    '''Frequency-domain phase:
    2*pi*f*tc - phi_c - pi/4 + (3/4)*(8*pi*M_c*f)**(-5/3).'''
    time_shift_term = 2 * np.pi * f * tc
    chirp_term = (3/4) * (8 * np.pi * chirp_mass * f) ** (-5/3)
    return time_shift_term - phi_c - np.pi / 4 + chirp_term

def fft(A_p_t, f, chirp_mass, D_L, phi_f, phi_p_t, phi_d_t):
    '''Analytic frequency-domain strain h(f) (not a numerical FFT).

    h(f) = sqrt(5/96) * pi**(-2/3) / D_L * A_p_t * M_c**(5/6) * f**(-7/6)
           * exp(i * (phi_f - phi_p_t - phi_d_t)),
    where phi_f is from `phi_f`, phi_p_t the polarisation phase and phi_d_t
    the Doppler phase.'''
    amplitude = np.sqrt(5/96) * np.pi**(-2/3) * (1/D_L) * A_p_t * chirp_mass**(5/6) * f**(-7/6)
    net_phase = phi_f - phi_p_t - phi_d_t
    return amplitude * np.exp(1j * net_phase)

def P_oms(f):
    '''Single-link optical metrology noise: (1.5e-11)**2 * (1 + (2e-3/f)**4).

    The form appears to follow Robson, Cornish & Liu (2019); confirm if units matter.'''
    low_frequency_rise = 1 + (2e-3 / f) ** 4
    return (1.5e-11) ** 2 * low_frequency_rise

def P_acc(f):
    '''Single test-mass acceleration noise:
    (3e-15)**2 * (1 + (0.4e-3/f)**2) * (1 + (f/8e-3)**4).'''
    low_frequency_factor = 1 + (0.4e-3 / f) ** 2
    high_frequency_factor = 1 + (f / 8e-3) ** 4
    return (3e-15) ** 2 * low_frequency_factor * high_frequency_factor

def P_n(f, fstar=19.09e-3, L=2.5e9):
    '''Total noise normalised by arm length `L`.

    P_oms/L**2 plus the acceleration noise converted to displacement by
    (2*pi*f)**-4 and weighted by 2*(1 + cos(f/fstar)**2). Not used by S_n,
    which uses its own combination of P_oms and P_acc.'''
    metrology_term = P_oms(f) / L**2
    acceleration_weight = 2 * (1 + np.cos(f / fstar)**2)
    acceleration_term = acceleration_weight * P_acc(f) / ((2 * np.pi * f)**4 * L**2)
    return metrology_term + acceleration_term

def S_c(f):
    '''Galactic confusion noise:
    A * f**(-7/3) * exp(-f**alpha + beta*f*sin(kappa*f)) * (1 + tanh(gamma*(f_k - f))).

    The coefficients appear to match the 1-year row of Robson, Cornish & Liu
    (2019), Table 1 (TODO confirm the intended observation time).'''
    amplitude, alpha, beta = 9e-45, 0.171, 292
    kappa, gamma, f_knee = 1020, 1680, 0.00215

    modulation = np.exp(-(f**alpha) + beta * f * np.sin(kappa * f))
    knee_cutoff = 1 + np.tanh(gamma * (f_knee - f))
    return amplitude * f ** (-7/3) * modulation * knee_cutoff

def S_n(f, fstar=19.09e-3, L=2.5e9):
    '''LISA sensitivity curve: instrumental noise plus galactic confusion noise.

    S_n(f) = S_c(f) + 10/(3 L^2) * (P_oms(f) + 4 P_acc(f)/(2 pi f)^4)
                      * (1 + 0.6 (f/fstar)^2)

    Parameters
    ----------
    f : float or ndarray
        Frequency (same units as P_oms / P_acc / S_c).
    fstar : float, optional
        Transfer frequency; default 19.09e-3, the same as P_n.
    L : float, optional
        Arm length; default 2.5e9, the same as P_n. It is a parameter so that
        S_n and P_n can be evaluated for the same arm length.
    '''
    return S_c(f) + 10/(3*L**2) * (P_oms(f) + (4*P_acc(f)/(2*np.pi*f)**4)) * (1+0.6*(f / fstar)**2)
        
def noise_f(f):
    '''Complex Gaussian noise realisation, one independent sample per frequency bin.

    Each bin gets its own N(0, 1) amplitude multiplied by a uniform random
    phase in [0, 2*pi), so E|n|^2 = 1 per bin. The result is returned as a
    (1, n) row, where n = np.size(f). It uses the global numpy RNG
    (phases are drawn first, then amplitudes).

    Parameters
    ----------
    f : float or ndarray
        Frequency grid; only its size is used.
    '''
    n_bins = np.size(f)
    random_phase = np.exp(1j*np.random.uniform(0, 2*np.pi, (1, n_bins)))
    # one amplitude per bin: a single shared scalar would make |noise| identical in every bin
    amplitude = np.random.normal(0, 1, (1, n_bins))
    return amplitude * random_phase

def h_func_f2(angles, noise_ind, Nsamples=200):
    '''Whitened frequency-domain LISA strain for a source at the given sky position.

    Parameters
    ----------
    angles : sequence of two floats
        (cos_theta, phi): cosine of the source polar angle and its azimuth
        (radians), in the barred frame.
    noise_ind : float
        Scale applied to the realisation from `noise_f`. A value of 0 gives
        the noiseless signal. Noise is still drawn in that case, so the
        global numpy RNG advances either way.
    Nsamples : int, optional
        Number of frequency samples, linearly spaced from `fmin` to the ISCO
        frequency.

    Returns
    -------
    ndarray (complex)
        h(f) / sqrt(S_n(f)) + noise_ind * noise_f(f). The 1-D signal
        broadcasts against noise_f's row-shaped output, giving a (1, N) array.
    '''
    # fixed orbital angular-momentum direction (barred frame) and reference phases
    theta_Lbar = np.pi/5
    phi_Lbar = np.pi/11
    theta_sbar = np.arccos(angles[0])
    phi_sbar = angles[1]
    alpha_0 = 0
    phi_0 = 0
    phi_c = 0
    det_no = 1
    fmin = 1e-4
    D_L = constants.parsec * 1e9 / constants.c  # luminosity distance of 1 Gpc, in seconds

    # Coalescence time, and the chirp mass for which a source at fmin merges
    # exactly at tc (inverts coalescence_time).
    Nmax = 200
    delta_t = 500
    tc = Nmax*delta_t
    M_c = (5*(8*np.pi*fmin)**(-8/3) * tc**(-1))**(3/5)
    # ISCO of an equal-mass binary: total mass M = 4**(3/5) * M_c = 2**(6/5) * M_c
    f_isco = 1 / (np.pi * 6**(3/2) * 2**(6/5) *M_c)

    # frequency grid and the time at which the source sweeps through each f (t[0] ~ 0)
    f = np.linspace(fmin, f_isco, Nsamples)
    t = tc - 5 * (8*np.pi*f)**(-8/3) * M_c**(-5/3)

    # Cosine of the inclination, L.n. It is computed here (same expression as
    # in psi_s_t) rather than read from the global that psi_s_t sets.
    cos_incl = np.cos(theta_Lbar)*np.cos(theta_sbar) + np.sin(theta_Lbar)*np.sin(theta_sbar)*np.cos(phi_Lbar - phi_sbar)

    # detector response along the track
    phi_f_t = phi_f(f, phi_c, M_c, tc)
    phi_t_Lbar_ = phi_t_Lbar(t, phi_0)
    theta_s_t_ = theta_s_t(theta_sbar, phi_t_Lbar_, phi_sbar)
    alpha_t_ = alpha_i_t(det_no, t, alpha_0)
    phi_s_t_ = phi_s_t(theta_sbar, phi_t_Lbar_, phi_sbar, alpha_t_)
    psi_s_t_ = psi_s_t(theta_Lbar, phi_Lbar, theta_sbar, phi_sbar, phi_t_Lbar_, theta_s_t_)
    doppler_phase_ = doppler_phase(f, theta_sbar, phi_t_Lbar_, phi_sbar)
    F_plus_ = F_plus(theta_s_t_, phi_s_t_, psi_s_t_)
    F_cross_ = F_cross(theta_s_t_, phi_s_t_, psi_s_t_)
    phi_P_I_t_ = phi_P_I_t(cos_incl, F_plus_, F_cross_)
    A_p_t_ = A_p_t(F_plus_, F_cross_, cos_incl)

    fourier_signal = fft(A_p_t_, f, M_c, D_L, phi_f_t, phi_P_I_t_, doppler_phase_)
    # whiten by the LISA sensitivity curve, then add (already whitened) noise
    return fourier_signal/np.sqrt(S_n(f)) + noise_ind * noise_f(f)

def setgeometry(q):
    '''Define `q` equal-width quantization bins on the unit interval.

    Publishes the geometry as module globals. syntrain reads them for one-hot
    label encoding, and they are also used for histogram plots:
      qdim        number of bins
      xmin, xmax  range of the (uniform-prior) normalized parameter, 0 and 1
      xstops      the q + 1 bin edges
      xmid, xwid  bin centres and the common bin width
    '''
    global qdim, xmin, xmax, xstops, xmid, xwid

    qdim = q
    xmin, xmax = 0, 1
    xstops = np.linspace(xmin, xmax, qdim + 1)

    left_edges, right_edges = xstops[:-1], xstops[1:]
    xmid = 0.5 * (left_edges + right_edges)
    xwid = xstops[1] - xstops[0]


# default geometry: 64 bins
setgeometry(64)

def numpy2cuda(array, single=True):
  """Wrap a numpy array as a torch tensor.

  The tensor is cast to float32 when `single` is True (otherwise the
  array's dtype is kept) and moved to the GPU when CUDA is available."""
  tensor = torch.from_numpy(array)
  tensor = tensor.float() if single else tensor
  return tensor.cuda() if torch.cuda.is_available() else tensor


def cuda2numpy(tensor):
  """Detach `tensor` from the autograd graph, copy it to host memory, and
  return it as a numpy array."""
  host_tensor = tensor.detach().cpu()
  return host_tensor.numpy()


def makenet(dims, softmax=True, single=True):
  """Return an nn.Module subclass for a fully connected network whose layer
  widths are given by `dims` (dims[0] inputs, dims[-1] outputs).

  Hidden layers apply leaky ReLU (negative slope 0.2). The output layer is
  linear and is followed by a softmax over dim 1 unless `softmax=False`.
  Layers are stored as attributes fc0, fc1, ..., which are also their
  state_dict prefixes. They are cast to double precision when
  `single=False` and moved to CUDA only when it is available."""

  ndims = len(dims)

  class Net(nn.Module):
    def __init__(self):
      super().__init__()

      # each layer is registered as its own attribute so nn.Module tracks its parameters
      for idx, (n_in, n_out) in enumerate(zip(dims[:-1], dims[1:])):
        linear = nn.Linear(n_in, n_out)
        if not single:
          linear = linear.double()
        if torch.cuda.is_available():
          linear = linear.cuda()
        setattr(self, f'fc{idx}', linear)

    def forward(self, x):
      hidden_layers = [getattr(self, f'fc{idx}') for idx in range(ndims - 2)]
      for layer in hidden_layers:
        x = F.leaky_relu(layer(x), negative_slope=0.2)

      logits = getattr(self, f'fc{ndims - 2}')(x)
      return F.softmax(logits, dim=1) if softmax else logits

  return Net


def makenetbn(dims, softmax=True, single=True):
  """Batch-normalizing variant of `makenet`. Experimental.

  Each hidden layer computes bn{k}(leaky_relu(fc{k}(x), 0.2)). The output
  layer fc{len(dims)-2} is linear, followed by a softmax over dim 1 unless
  `softmax=False`. A BatchNorm module bn{len(dims)-2} is also created for
  the output width. forward never uses it, but it is registered and so
  appears in state_dict. Modules are cast to double when `single=False` and
  moved to CUDA when it is available."""

  ndims = len(dims)

  class Net(nn.Module):
    def __init__(self):
      super().__init__()

      # modules are registered as individual attributes so nn.Module tracks them
      for idx, (n_in, n_out) in enumerate(zip(dims[:-1], dims[1:])):
        linear = nn.Linear(n_in, n_out)
        norm = nn.BatchNorm1d(num_features=n_out)
        if not single:
          linear, norm = linear.double(), norm.double()
        if torch.cuda.is_available():
          linear, norm = linear.cuda(), norm.cuda()
        setattr(self, f'fc{idx}', linear)
        setattr(self, f'bn{idx}', norm)

    def forward(self, x):
      for idx in range(ndims - 2):
        activated = F.leaky_relu(getattr(self, f'fc{idx}')(x), negative_slope=0.2)
        x = getattr(self, f'bn{idx}')(activated)

      logits = getattr(self, f'fc{ndims - 2}')(x)
      return F.softmax(logits, dim=1) if softmax else logits

  return Net

def kllossGn2(o, l: 'xtrue'):
  """Negative log-likelihood loss for a 2-D Gaussian-mixture network output.

  `o` holds 6 interleaved columns per mixture component:
  [mu_x, a, mu_y, b, c, w], with precision matrix
  F = [[a**2, Fxy], [Fxy, b**2]], Fxy = atan(c)/(pi/2) * a * b,
  and mixture weights softmax(w) across components.
  `l` holds the true (x, y) in columns 0 and 1.

  Returns the batch mean of -log sum_k w_k N(l | mu_k, F_k^-1). The
  constant -log(2*pi) common to all components is dropped.

  The `'xtrue'` annotation on `l` is read by `syntrainer`. Only losses
  annotated 'indicator' receive the histogram indicators, so this loss
  receives the normalized parameter values.
  """

  dx  = o[:,0::6] - l[:,0,np.newaxis]
  dy  = o[:,2::6] - l[:,1,np.newaxis]

  # diagonal precision terms are squares, hence non-negative
  Fxx = o[:,1::6]**2
  Fyy = o[:,3::6]**2

  # |atan(c)/(pi/2)| < 1, so |Fxy| < sqrt(Fxx*Fyy): det F > 0 and F stays positive definite
  Fxy = torch.atan(o[:,4::6]) / (0.5*math.pi) * o[:,1::6] * o[:,3::6]

  # log_softmax is the stable form of log(softmax(.)): an underflowed weight
  # would otherwise give log(0) = -inf and NaN gradients
  log_weight = torch.log_softmax(o[:,5::6], dim=1)

  quadratic = Fxx*dx*dx + Fyy*dy*dy + 2*Fxy*dx*dy
  log_det = torch.log(Fxx*Fyy - Fxy*Fxy)

  return -torch.mean(torch.logsumexp(log_weight - 0.5*quadratic + 0.5*log_det, dim=1))

# Default sampling box for syntrain: cos(theta) and phi of the source.
cos_theta_min, cos_theta_max = 0.4, 0.6
# NOTE(review): phi_min > phi_max. syntrain handles this consistently: it
# samples phi in [3pi/2, 7pi/4] and normalizes so that 7pi/4 -> 0 and
# 3pi/2 -> 1. Confirm the reversed orientation is intended.
phi_min, phi_max = 7*np.pi/4, 3*np.pi/2

def _onehot_bins(values):
    """One-hot encode normalized values into the `qdim` bins bounded by the
    global `xstops` (set by setgeometry). Values below the first edge or at
    or above the last edge are clamped into the first or last bin. Returns a
    (len(values), qdim) float64 array."""
    idx = np.digitize(values, xstops, False) - 1
    idx[idx == -1] = 0
    idx[idx == qdim] = qdim - 1
    onehot = np.zeros((len(values), qdim), 'd')
    onehot[range(len(values)), idx] = 1
    return onehot


def syntrain(size,  region=[[cos_theta_min, cos_theta_max], [phi_min, phi_max]], varx='theta_N', 
             varall=True, seed=None, single=True, noise=1):
    """Make a training batch of whitened LISA signals at random sky positions.

    Samples `size` (cos_theta, phi) pairs uniformly within `region` and
    generates h_func_f2 signals, with the noise scaled by `noise`. Returns
    (labels, indicators, alphas):

    - labels: parameters normalized to [0, 1] within each region row. All
      columns are returned if `varall`, otherwise only those selected by
      `varx`. 'theta_N' selects column 0 (the cos_theta sample) and 'phi'
      selects column 1.
    - indicators: one-hot histogram over `qdim` bins for `varx`. This is a
      (size, qdim, qdim) outer product when `varx` is a list of two names.
    - alphas: tensor of shape (size, 2*Nsamples) with real and imaginary
      parts interleaved. It lives on the GPU when available, with dtype
      float32 if `single` else float64.

    `seed` (if given) seeds both numpy and torch.
    """
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu:0'

    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    with torch.no_grad():
        xs = torch.zeros((size,2), dtype=torch.float, device=device)

        for i, r in enumerate(region):
            xs[:,i] = r[0] + (r[1] - r[0]) * torch.rand((size,), dtype=torch.float, device=device)

        xs_1 = xs.detach().cpu().double().numpy()

        # h_func_f2 returns a (1, Nsamples) row, so apply_along_axis gives
        # (size, 1, Nsamples). Flatten to (size, Nsamples) to keep every bin.
        # Indexing with [:, :, 0] would keep only the first frequency bin,
        # broadcast across all of alphas.
        signal = np.apply_along_axis(h_func_f2, 1, xs_1, noise).reshape(size, -1)

        signal_r, signal_i = numpy2cuda(signal.real, single), numpy2cuda(signal.imag, single)

        # interleave real and imaginary parts: [re0, im0, re1, im1, ...]
        alphas = torch.zeros((size, 2*signal.shape[1]), dtype=torch.float if single else torch.double, device=device)

        alphas[:,0::2] = signal_r
        alphas[:,1::2] = signal_i

    xr = xs.detach().cpu().double().numpy()

    del xs, signal_r, signal_i

    # normalize each parameter to [0, 1] within its region row
    for i, r in enumerate(region):
        xr[:,i] = (xr[:,i] - r[0]) / (r[1] - r[0])

    names = ['theta_N', 'phi']

    if isinstance(varx, list):
        ix = names.index(varx[0])
        jx = names.index(varx[1])

        indicator = np.einsum('ij,ik->ijk', _onehot_bins(xr[:,ix]), _onehot_bins(xr[:,jx]))
        labels = xr if varall else xr[:,[ix,jx]]
        return labels, indicator, alphas

    ix = names.index(varx)
    labels = xr if varall else xr[:,ix]
    return labels, _onehot_bins(xr[:,ix]), alphas
        
def syntrainer(net, syntrain, lossfunction=None, iterations=300, 
               batchsize=None, initstep=1e-3, finalv=1e-5, clipgradient=None, validation=None,
               seed=None, single=True):
  """Train `net` with Adam on a fresh batch from `syntrain()` every epoch.

  Parameters
  ----------
  net : nn.Module
      Network to train (updated in place).
  syntrain : callable
      Zero-argument callable returning (xtrue, indicator, inputs), e.g. a
      lambda wrapping the module-level `syntrain`.
  lossfunction : callable(outputs, labels) -> scalar tensor
      Required. If its `l` parameter is annotated 'indicator', the indicator
      arrays are used as labels; otherwise `xtrue` is.
  iterations : int
      Maximum number of epochs.
  batchsize : int or None
      Minibatch size; None uses the whole batch. If it exceeds the batch
      size, one full-batch step is taken per epoch.
  initstep : float
      Adam learning rate.
  finalv : float
      Training stops once the loss is still decreasing but the slope of a
      linear fit over the last 20 epochs is shallower than `finalv`.
  clipgradient : float or None
      Maximum gradient norm, if set.
  validation :
      Not implemented; anything other than None raises NotImplementedError.
  seed : int or None
      Seeds numpy and torch.
  single : bool
      Convert labels to float32 (True) or keep float64.

  Side effect: adds `iterations` to `net.steps`, creating it if absent.

  Raises
  ------
  TypeError
      If `lossfunction` is None.
  """

  if lossfunction is None:
    raise TypeError("syntrainer() requires a lossfunction")

  if seed is not None:
    np.random.seed(seed)
    torch.manual_seed(seed)

  # losses whose `l` argument is annotated 'indicator' train on the histograms
  indicatorloss = getattr(lossfunction, '__annotations__', {}).get('l') == 'indicator'

  if validation is not None:
    raise NotImplementedError
    # unreachable until validation is implemented; shows the intended setup
    vlabels = numpy2cuda(validation[1] if indicatorloss else validation[0], single)
    vinputs = numpy2cuda(validation[2], single)

  optimizer = optim.Adam(net.parameters(), lr=initstep)

  training_loss, validation_loss = [], []
  window = 20  # recent epochs used to fit the loss slope for early stopping

  for epoch in range(iterations):
    t0 = time.time()

    xtrue, indicator, inputs = syntrain()
    labels = numpy2cuda(indicator if indicatorloss else xtrue, single)

    if batchsize is None:
      batchsize = inputs.shape[0]
    # take at least one step, even when batchsize exceeds the batch (avoids dividing by zero below)
    batches = max(1, inputs.shape[0] // batchsize)

    averaged_loss = 0.0

    for i in range(batches):
      batch = slice(i*batchsize, (i+1)*batchsize)

      optimizer.zero_grad()
      outputs = net(inputs[batch])
      loss = lossfunction(outputs, labels[batch])
      loss.backward()

      if clipgradient is not None:
        torch.nn.utils.clip_grad_norm_(net.parameters(), clipgradient)

      optimizer.step()
      averaged_loss += loss.item()

    training_loss.append(averaged_loss/batches)

    if validation is not None:
      loss = lossfunction(net(vinputs), vlabels)
      validation_loss.append(loss.detach().cpu().item())

    if epoch == 1:
      print("One epoch = {:.1f} seconds.".format(time.time() - t0))

    if epoch % 50 == 0:
      print(epoch,training_loss[-1],validation_loss[-1] if validation is not None else '')

    # Early stopping. Fit only once a full window exists (np.polyfit needs
    # matching lengths). Skip the fit when the window contains NaN/inf, so
    # errors are not hidden behind a blanket exception handler.
    if len(training_loss) > iterations/10 and len(training_loss) >= window:
      recent = training_loss[-window:]
      if np.all(np.isfinite(recent)):
        training_rate = np.polyfit(range(window), recent, deg=1)[0]
        if training_rate < 0 and training_rate > -finalv:
          print(f"Terminating at epoch {epoch} because training loss stopped improving sufficiently: rate = {training_rate}")
          break

    if len(validation_loss) > iterations/10 and len(validation_loss) >= window:
      recent = validation_loss[-window:]
      if np.all(np.isfinite(recent)):
        validation_rate = np.polyfit(range(window), recent, deg=1)[0]
        if validation_rate > 0:
          print(f"Terminating at epoch {epoch} because validation loss started worsening: rate = {validation_rate}")
          break

  if training_loss:
    print("Final",training_loss[-1],validation_loss[-1] if validation is not None else '')

  # NOTE(review): this adds the requested `iterations`, not the epochs
  # actually run, so early termination over-counts; confirm that is intended.
  if hasattr(net,'steps'):
    net.steps += iterations
  else:
    net.steps = iterations
    
# dimensions = [200*2] + [1024]*8 + [1*6]
# percival_network = makenet(dimensions, softmax=False)

# network_to_use = percival_network()

# ##Training data to pass through Percival network
# training_data = lambda: syntrain(size=100000, varx='theta_N')

# ##Train Percival network on above data
# ##training the network
# syntrainer(network_to_use, training_data, lossfunction=kllossGn2, iterations=5000,
#            initstep=1e-4, finalv=1e-8)

# PATH = ""
# torch.save(network_to_use.state_dict(), PATH + '\\Trained-Networks\\theta-phi_l200-1024x8_2d_5000it.pt')

In [12]:
signal_pos = h_func_f2([0.45, np.pi/2], noise_ind=0)   # noiseless, cos(theta) = +0.45
signal_neg = h_func_f2([-0.45, np.pi/2], noise_ind=0)  # noiseless, cos(theta) = -0.45
signal_pos + signal_neg

array([[ 11262.64857069-12448.35641839j,  11461.45279295-12562.20064194j,
         11776.8794704 -12568.74907949j,  12205.95495933-12462.26933849j,
         12742.61025514-12234.21823802j,  13377.44845881-11873.64042362j,
         14097.43274732-11367.69956961j,  14885.53172448-10702.3510048j ,
         15720.36496632 -9863.16497884j,  16575.89748906 -8836.30702802j,
         17421.23741454 -7609.67588746j,  18220.59571682 -6174.19009054j,
         18933.46983096 -4525.20187608j,  19515.11316159 -2664.00155655j,
         19917.34911581  -599.35760322j,  20089.78015208 +1650.98179547j,
         19981.42851504 +4058.9198609j ,  19542.82504693 +6584.88795729j,
         18728.53530874 +9177.09651868j,  17500.0782811 +11771.2990707j ,
         15829.15286977+14291.21073489j,  13701.04282195+16649.71474547j,
         11118.02385607+18750.96996549j,   8102.55113448+20493.49728117j,
          4699.96485595+21774.2719773j ,    980.42163852+22493.78277547j,
         -2960.25511004+22561.93762246