In [None]:
from RFIx.observation import Observation

from RFIx.utils.coordinates import orbit, orbit_fisher
from RFIx.utils.interferometry import rfi_vis, ants_to_bl
from RFIx.utils.mcmc import *
from RFIx.utils.optimize import quasi_Newton, quasi_Newton_v
from RFIx.utils.dict import *
from RFIx.utils.tools import saveObs, saveOptimization, saveSamples, saveTrue, saveTimes, savePrior

from jax.scipy.optimize import minimize
from jax import jit, vmap, pmap, random, grad, hessian, pmap, jacrev
from functools import partial
from jax.lax import cond
import jax.numpy as jnp
import jax

from tqdm import tqdm
from time import time
from datetime import datetime

from numpy import loadtxt
import subprocess
import sys
import os

from jax.config import config
config.update('jax_enable_x64', True)
# config.update('jax_log_compiles', True)

# os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'
# config.update('jax_platform_name', 'cpu')

In [None]:
import argparse

def str2bool(v):
    """Interpret a command-line token as a boolean (usable as argparse ``type``).

    Genuine ``bool`` objects are returned unchanged. Any other value is
    lower-cased and compared against the accepted spellings.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised spelling.
    """
    if isinstance(v, bool):
        return v

    token = v.lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

# Command-line interface.
# Numeric options use the built-in ``int``/``float`` converters: argparse calls
# ``type`` with the raw command-line *string*, and the JAX scalar types
# (``jnp.int32``/``jnp.float64``) construct arrays rather than parse strings.
# Plain Python numbers are also hashable, which ``T`` needs because it is
# passed to ``run_batch`` as a static (compile-time) jit argument.
parser = argparse.ArgumentParser(description='Simulate RFI contaminated visibilities and recover the gains and RFI parameters.')
# Output File Arguments
parser.add_argument('--fname', default='')
parser.add_argument('--output_path', default='./output_data/')
parser.add_argument('--Rkey', default=0, type=int)
# Problem Arguments
parser.add_argument('--Ncaltime', default=4, type=int)
parser.add_argument('--start_time', default=0.0, type=float)
parser.add_argument('--Ntartime', default=4, type=int)
parser.add_argument('--Nint', default=4, type=int)
parser.add_argument('--int_time', default=2., type=float)
parser.add_argument('--Nant', default=8, type=int)
parser.add_argument('--noise', default=0.65, type=float)
parser.add_argument('--GampDev', default=0.05, type=float)
parser.add_argument('--GphaseDev', default=5., type=float)
parser.add_argument('--Gainsl', default=600., type=float)
parser.add_argument('--RFIamp', default=10., type=float)
parser.add_argument('--CALamp', default=1., type=float)
parser.add_argument('--orbit_dev', default=1., type=float)
parser.add_argument('--Orbit', default='MEO', type=str)

# MCMC Arguments
parser.add_argument('--MCMC', default=False, type=str2bool)
parser.add_argument('--deltaT', default=3.0e-1, type=float)
parser.add_argument('--T', default=100, type=int)
parser.add_argument('--Nsamples', default=4000, type=int)
# Optimization Arguments
parser.add_argument('--optimize', default=True, type=str2bool)
parser.add_argument('--update_H_inv', default=True, type=str2bool)
parser.add_argument('--alpha', default=7.0e-1, type=float)
parser.add_argument('--max_iter', default=1000, type=int)
parser.add_argument('--threshold', default=1.0e-8, type=float)
parser.add_argument('--batch_size_opt', default=500, type=int)
parser.add_argument('--min_chi2', default=0.9, type=float)

args = parser.parse_args()

# ---- Unpack the parsed arguments into module-level settings ----

# Output naming/location and the PRNG key used for the initial-guess draws
# and the sampler.
f_name = args.fname
output_dir = args.output_path
key = random.PRNGKey(args.Rkey)

# Simulation / problem settings.
# NOTE: N_int_samples taken from --Nint here is overwritten further down,
# once the calibration visibilities have been simulated.
N_cal_time = args.Ncaltime
cal_start_time = args.start_time
N_tar_time = args.Ntartime
N_int_samples = args.Nint
N_ant = args.Nant
int_time = args.int_time
noise = args.noise
G_amp_dev = args.GampDev
G_phase_dev = args.GphaseDev
gains_l = args.Gainsl
RFI_amp = args.RFIamp
Cal_amp = args.CALamp
orbit_dev = args.orbit_dev

# MCMC settings. Samples are produced in fixed batches of ``batch_size``;
# N_batch is the number of whole batches that fit in N_samples.
mcmc = args.MCMC
T = args.T
delta_T = args.deltaT
N_samples = args.Nsamples
batch_size = 100
N_batch = int(N_samples/batch_size)

# Optimisation settings (forwarded to the quasi-Newton routines).
optimization = args.optimize
update_H_inv = args.update_H_inv
alpha = args.alpha
max_iter = args.max_iter
threshold = args.threshold
batch_size_opt = args.batch_size_opt
min_chi2 = args.min_chi2

# Echo which stages are enabled.
print(optimization, mcmc)

# Total number of time steps (calibration + target).
N_time = N_cal_time + N_tar_time

# Hard-coded gain-simulation settings passed to ``Observation.addGains``.
G0_mean = 1.0
G0_std = 0.05
Gt_std_mag = 1e-5
Gt_std_phase = jnp.deg2rad(1e-3)

# Wall-clock reference for the initialisation-time report.
start_time = time()

# N_freqs = 3


# Select the satellite orbit preset. Each branch sets:
#   RFI_amp_cal - RFI_amp scaled by an orbit-specific factor,
#   el, inclination, lon_asc_node, periapsis - passed to ``addSat`` as the
#       orbital elements,
#   RIC_std - per-orbit standard deviations used for the orbit prior
#       (presumably radial/in-track/cross-track - TODO confirm).
if args.Orbit == 'GEO':
    # Geostationary orbit (GEO)
    RFI_amp_cal = RFI_amp*5.10e-6
    el, inclination, lon_asc_node, periapsis = 35786e3, 3.2, 81., 300.0
    RIC_std = jnp.array([359.0, 432.0, 86.0])     # GEO orbits
elif args.Orbit == 'MEO':
    # GPS orbit (MEO)
    RFI_amp_cal = RFI_amp*0.29e-6
    el, inclination, lon_asc_node, periapsis = 20200e3, 55.0, 21., 5.0
    RIC_std = jnp.array([73.0, 131.0, 54.0])      # MEO orbits
    # RIC_std = jnp.array([150.0, 131.0, 54.0])      # MEO orbits
elif args.Orbit == 'LEO':
    # Iridium orbit (LEO)
    RFI_amp_cal = RFI_amp*0.0032e-6
    el, inclination, lon_asc_node, periapsis = 781e3, 86.4, 200.0, 204.0
    RIC_std = jnp.array([102.0, 471.0, 126.0])    # LEO orbits
else:
    # An unrecognised orbit is a usage error: passing the message to
    # sys.exit() writes it to stderr and exits with a non-zero status, so
    # calling scripts/schedulers can detect the failure.
    sys.exit('Invalid orbit option chosen. Choose from "GEO", "MEO", or "LEO". ')

# Load the antenna ENU positions and shuffle their order with a fixed key, so
# that ENU[:N_ant] is a reproducible subset of the array.
ENU = random.permutation(random.PRNGKey(0), jnp.array(loadtxt('data/Meerkat.enu.txt')))
# Create Calibration Observation Data, simulated with int(16*int_time)
# integration samples per time step. This observation provides the "observed"
# visibilities used later as the data.
cal = Observation(latitude=-30., longitude=21., elevation=1050., ra=21., dec=10.,
                  times=cal_start_time+int_time*jnp.arange(N_cal_time), freqs=jnp.array([1.227e9]),
                  ENU_array=ENU[:N_ant], n_int_samples=int(16*int_time))
# Single calibrator source at the pointing direction plus one satellite RFI source.
cal.addAstro(I=Cal_amp*jnp.ones((1,1)), ra=[21.0,], dec=[10.0,])
cal.addSat(Pv=RFI_amp_cal*jnp.ones(1), elevation=el, inclination=inclination,
           lon_asc_node=lon_asc_node, periapsis=periapsis)
cal.addGains(G0_mean=G0_mean, G0_std=G0_std, Gt_std_mag=Gt_std_mag, Gt_std_phase=Gt_std_phase)
cal.calculate_vis()
# Noise key is derived from the start time, so different start times give
# different noise realisations.
cal.addNoise(noise=noise, key=random.PRNGKey(int(cal_start_time)))

# Size of the deviation of the observed visibilities from 1: overall maximum,
# and the mean over the trailing axes of the per-column (axis 0) maxima.
max_vis = jnp.abs(cal.vis_obs-1).max()
mean_vis = jnp.abs(cal.vis_obs-1).max(axis=0).mean()
print(max_vis, mean_vis)
# Choose the number of integration samples used by the model from the ratio of
# that deviation to the noise level (this replaces the --Nint value).
N_int_samples = int(jnp.ceil(int_time*(mean_vis/(3*noise))**(1./2.)))
# Alternative scalings that were tried:
# N_int_samples = int(jnp.ceil(int_time*(mean_vis/(1.5*noise))**0.5))
# N_int_samples = int(jnp.ceil(int_time*(mean_vis/(3*noise))**(1./1.7)))
# N_int_samples = int(jnp.ceil(int_time*(mean_vis/(3*noise))**(1./1.5)))
# # N_int_samples = int(jnp.ceil(int_time*(mean_vis/(2.5*noise))**0.5))
# N_int_samples = int(jnp.ceil(int_time*(mean_vis/(noise))**0.5))
print(f'Using {N_int_samples} samples for RFI estimation')

# Create a second calibration observation with the same configuration but only
# N_int_samples integration samples. Its fine time grid, antenna positions and
# astronomical visibilities are what the model in the likelihood uses.
calN = Observation(latitude=-30., longitude=21., elevation=1050., ra=21., dec=10.,
                  times=cal_start_time+int_time*jnp.arange(N_cal_time), freqs=cal.freqs,
                  ENU_array=ENU[:N_ant], n_int_samples=N_int_samples)
calN.addAstro(I=Cal_amp*jnp.ones((1,1)), ra=[21.0,], dec=[10.0,])
calN.addSat(Pv=RFI_amp_cal*jnp.ones(1), elevation=el, inclination=inclination,
            lon_asc_node=lon_asc_node, periapsis=periapsis)
calN.addGains(G0_mean=G0_mean, G0_std=G0_std, Gt_std_mag=Gt_std_mag, Gt_std_phase=Gt_std_phase)
calN.calculate_vis()
# Uses a different noise key (start time + 1) from ``cal``.
calN.addNoise(noise=noise, key=random.PRNGKey(int(cal_start_time+1)))

# ---- Random target-field source catalogue ----
n_src = 100       # number of sources kept for the target field
max_rad = 0.5     # maximum radial offset of a source from the pointing centre
mean_I = 0.1      # mean of the exponential flux distribution
slew_time = 30.   # gap between the calibration and target scans

# Draw twice as many candidates as needed so that enough remain after the
# close-pair rejection below. Fixed keys keep the catalogue reproducible.
trgt_I = mean_I*random.exponential(random.PRNGKey(101), (2*n_src,1))
# sqrt of a uniform variate makes the candidates uniform over the disc area.
_r = max_rad*jnp.sqrt(random.uniform(random.PRNGKey(102), (2*n_src,)))
_theta = 2.*jnp.pi*random.uniform(random.PRNGKey(103), (2*n_src,))
trgt_ra = _r*jnp.cos(_theta)
trgt_dec = _r*jnp.sin(_theta)

# Pairwise separations scaled by 3600. Adding a large constant to the upper
# triangle (diagonal included) means each pair is only tested once, in the
# row of its higher-index member, and a source never matches itself.
trgt_pos = jnp.stack([trgt_ra, trgt_dec])
source_d = jnp.linalg.norm(trgt_pos[:,:,None]-trgt_pos[:,None,:], axis=0)*3600
source_d = source_d + jnp.triu(4000*jnp.ones(source_d.shape))

# Drop every source that lies within 80 of a lower-index source. This is done
# with a boolean mask rather than a Python set/list of JAX scalars: JAX arrays
# cannot be hashed into a set, and list.remove would be quadratic.
has_close_neighbour = jnp.any(source_d < 80, axis=1)
idx = jnp.where(~has_close_neighbour)[0]

# Keep the first n_src surviving candidates.
trgt_I = trgt_I[idx][:n_src]
trgt_ra = trgt_ra[idx][:n_src]
trgt_dec = trgt_dec[idx][:n_src]

# # Create Target Observation Data
# target = Observation(latitude=-30., longitude=21., elevation=1050., ra=27., dec=15.,
#                   times=cal_start_time + \
#                         int_time*N_cal_time + \
#                         slew_time + \
#                         int_time*jnp.arange(N_tar_time),
#                   freqs=cal.freqs, ENU_array=ENU[:N_ant], n_int_samples=cal.n_int_samples)
# target.addAstro(I=trgt_I, ra=target.ra+trgt_ra, dec=target.dec+trgt_dec)
# target.addSat(Pv=RFI_amp_cal*jnp.ones(1), elevation=el, inclination=inclination,
#               lon_asc_node=lon_asc_node, periapsis=periapsis)
# target.addGains(G0_mean=G0_mean, G0_std=G0_std, Gt_std_mag=Gt_std_mag, Gt_std_phase=Gt_std_phase)
# target.calculate_vis()
# target.addNoise(noise=noise, key=random.PRNGKey(104))

# The target scan starts after the calibration scan plus the slew time.
tar_start_time = cal_start_time + int_time*N_cal_time + slew_time
# Create Target Observation Data: same site, array and frequencies as ``cal``
# but a different pointing, populated with the random source catalogue
# (source positions are offsets added to the target pointing).
target = Observation(latitude=-30., longitude=21., elevation=1050., ra=27., dec=15.,
                  times=tar_start_time+int_time*jnp.arange(N_tar_time),
                  freqs=cal.freqs, ENU_array=ENU[:N_ant], n_int_samples=cal.n_int_samples)
target.addAstro(I=trgt_I*jnp.ones((1,1)), ra=target.ra+trgt_ra, dec=target.dec+trgt_dec)
# Same satellite (orbit and scaled power) as in the calibration observation.
target.addSat(Pv=RFI_amp_cal*jnp.array([1]), elevation=el, inclination=inclination,
              lon_asc_node=lon_asc_node, periapsis=periapsis)
target.addGains(G0_mean=G0_mean, G0_std=G0_std, Gt_std_mag=Gt_std_mag, Gt_std_phase=Gt_std_phase)
target.calculate_vis()
# Noise key depends on the target start time.
target.addNoise(noise=noise, key=random.PRNGKey(104+int(tar_start_time)))

import shutil

# Build the run name from the key simulation settings.
f_name += f'CalRFI_TargetSim{int(N_ant)}A_{int(N_time)}T_{int(N_int_samples)}I_{int(int_time)}IT_{int(cal_start_time)}ST_{noise:.1E}N_{args.Orbit}_{RFI_amp:.1E}RFI_{Cal_amp:.1E}CAL_{N_samples:.0E}S_K{args.Rkey}'

data_path = os.path.join(output_dir, f_name)

# Start from an empty output directory. Using shutil/os instead of shelling
# out to ``rm``/``mkdir`` keeps paths containing whitespace intact, surfaces
# failures as exceptions instead of ignoring them, and creates missing parent
# directories (e.g. the default ./output_data/).
if os.path.isdir(data_path) and not os.path.islink(data_path):
    shutil.rmtree(data_path)
elif os.path.lexists(data_path):
    os.remove(data_path)
os.makedirs(data_path)

# Save both simulated observations.
obs_path = os.path.join(data_path, '_ObservationData')
saveObs(obs_path, [cal, target])

# Clear Space in Memory
del target

# sys.exit(0)

In [None]:
##### Create Functions needed in MCMC
@jit
def linear_extrapolate(x, xp, fp):
    """Interpolate ``fp`` onto ``x`` and extend both ends along straight lines.

    The first and last ``int(N_int_samples/2)`` points of ``x`` are treated as
    edge points. The remaining interior points are evaluated with
    ``jnp.interp(x, xp, fp)``; each edge is then filled with a line whose
    slope is taken from the two nearest interior values and whose offset is
    measured from ``xp.min()`` (leading edge) or ``xp.max()`` (trailing edge).

    Returns the concatenation [leading edge, interior, trailing edge], which
    has the same length as ``x``.
    """
    n_edge = int(N_int_samples/2)

    head, core, tail = x[:n_edge], x[n_edge:-n_edge], x[-n_edge:]
    core_vals = jnp.interp(core, xp, fp)

    # End slopes from the two outermost interior samples on each side.
    slope_head = (core_vals[1]-core_vals[0])/(core[1]-core[0])
    slope_tail = (core_vals[-1]-core_vals[-2])/(core[-1]-core[-2])

    head_vals = core_vals[0] + slope_head*(head - xp.min())
    tail_vals = core_vals[-1] + slope_tail*(tail - xp.max())

    return jnp.concatenate([head_vals, core_vals, tail_vals])

@jit
def nlog_prior(q, params):
    """Negative log-prior of the parameter dictionary ``q``.

    Independent normal priors are placed on the gain amplitudes and gain
    phases (means/widths taken from ``params``) and on the RFI amplitudes
    (mean 0, width 100), plus a multivariate-normal prior on the RFI orbit
    parameters. The reshapes double as a size check on the flat vectors.

    (A correlated RFI-amplitude prior built from ``params['inv_cov_RFI_amp']``
    via ``log_multinorm_sum`` was used in an earlier version of this term.)
    """
    gain_amp = q['g_amp'].reshape(N_cal_time, N_ant)
    gain_phase = q['g_phase'].reshape(N_cal_time, N_ant-1)
    amp_rfi = q['rfi_amp'].reshape(N_cal_time, N_ant)

    lp_gain_amp = jnp.sum(log_normal(gain_amp.flatten(),
                                     params['mu_G_amp'].flatten(),
                                     params['G_amp_std'].flatten()))
    lp_gain_phase = jnp.sum(log_normal(gain_phase.flatten(),
                                       params['mu_G_phase'].flatten(),
                                       params['G_phase_std']))
    lp_orbit = log_multinorm(q['rfi_orbit'], params['mu_RFI_orbit'],
                             params['inv_cov_RFI_orbit'])
    lp_rfi_amp = jnp.sum(log_normal(amp_rfi.flatten(), 0.0, 100.0))

    log_prior = lp_gain_amp + lp_gain_phase + lp_orbit + lp_rfi_amp

    return -1.0*jnp.sum(log_prior)

@jit
def nlog_likelihood(q, params):
    """Negative log-likelihood of the observed visibilities given ``q``.

    The model visibility is ``G_bl * (vis_cal + V_rfi)``: baseline gains times
    the sum of the calibrator visibilities and the RFI visibilities, where the
    RFI term is averaged over the fine samples within each time step. The
    magnitude of the residual is scored with ``log_normal`` using a width of
    ``params['noise']/sqrt(2)``.
    """

    # Observed visibilities.
    V_obs = params['D']

    # RFI amplitudes: one value per (time step, antenna), extended onto the
    # fine time grid per antenna, then arranged as
    # (N_cal_time, N_int_samples, N_ant, 1, 1).
    rfi_amp = q['rfi_amp'].reshape(N_cal_time,N_ant)
    rfi_amp = vmap(linear_extrapolate, in_axes=(None,None,1))(params['times_fine'], params['times'], rfi_amp)
    rfi_amp = jnp.transpose(rfi_amp).reshape(N_cal_time,N_int_samples,N_ant,1,1)

    # Gains: the phase of the last antenna is not a free parameter; it is
    # appended here as zero.
    G_amp = q['g_amp'].reshape(N_cal_time, N_ant)
    G_phase = q['g_phase'].reshape(N_cal_time, N_ant-1)
    G_phase = jnp.concatenate([G_phase, jnp.zeros((N_cal_time, 1))], axis=1)

    # Complex antenna gains converted to per-baseline gains.
    G = G_amp*jnp.exp(1.j*G_phase)
    G_bl = ants_to_bl(G)

    # Calculate the visibility contribution from the RFI: satellite position
    # on the fine time grid, antenna-to-satellite distances with the phase
    # corrections subtracted, then RFI visibilities evaluated per fine sample
    # (vmap over axis 1) and averaged over those samples.
    rfi_xyz = orbit(params['times_fine'], *q['rfi_orbit'])
    distances = jnp.linalg.norm(params['ants_xyz']-rfi_xyz[:,None,:], axis=2)
    c_distances = (distances-params['phase_corrections'])[...,None].reshape(N_cal_time,N_int_samples,N_ant,1)
    V_rfi = vmap(rfi_vis, in_axes=(1,1,None))(rfi_amp, c_distances, params['freqs'])[...,0].mean(axis=0)

    # Model visibilities and log-likelihood of the residual magnitudes.
    model_vis = (G_bl*(params['vis_cal']+V_rfi))

    ll = jnp.sum(log_normal(jnp.abs(model_vis-V_obs), 0.0, params['noise']/jnp.sqrt(2.)))

    return -1.0*ll

@jit
def U(q, params):
    """Negative log-posterior: negative log-prior plus negative log-likelihood."""
    prior_term = nlog_prior(q, params)
    likelihood_term = nlog_likelihood(q, params)
    return prior_term + likelihood_term

# @jit
# def U(q, params):
#     return nlog_likelihood(q, params)

def flat_U(q, params):
    """Evaluate ``U`` for a flat parameter vector by unflattening it first."""
    q_dict = unflatten(q)
    return U(q_dict, params)

# Gradient and Hessian with respect to the flat parameter vector.
flat_delU = jit(grad(flat_U, argnums=0))
flat_hess = jit(hessian(flat_U, argnums=0))

# Gradient and Hessian with respect to the parameter dictionary.
delU = jit(grad(U, argnums=0))
hess = jit(hessian(U, argnums=0))

@jit
def softabs(x, a=1e-2):
    """Smooth absolute value ``x / tanh(x / a)``.

    For ``|x| >> |a|`` this approaches ``|x|``; as ``x -> 0`` it approaches
    ``a``. The direct formula evaluates to 0/0 (NaN) at exactly ``x == 0``, so
    the analytic limit ``a`` is returned there instead.

    Args:
        x: scalar or array of values.
        a: softening scale; also the value returned at ``x == 0``.

    Returns:
        Array with the shape of ``x``.
    """
    u = x/a
    # Below this, u/tanh(u) = 1 + u**2/3 + ... equals 1 to double precision.
    near_zero = jnp.abs(u) < 1e-8
    # Substitute a harmless argument where the limit is used, so that neither
    # the value nor its derivative ever passes through 0/0.
    safe_u = jnp.where(near_zero, 1.0, u)
    return jnp.where(near_zero, a, x/jnp.tanh(safe_u))

@jit
def hess_inv(q, params):
    """Regularised inverse of the full Hessian of ``flat_U`` at ``q``.

    The Hessian is eigendecomposed and each eigenvalue is replaced by the
    reciprocal of its ``softabs`` before the matrix is reassembled.
    """
    eigval, eigvec = jnp.linalg.eigh(flat_hess(q, params))
    inverse_spectrum = jnp.diag(1./softabs(eigval))
    return eigvec@inverse_spectrum@eigvec.T

@jit
def M_inv_L_custom(q, params):
    """Build the block-diagonal mass matrix at ``q`` for the sampler.

    The Hessian of ``U`` is reduced to its per-parameter-group diagonal
    blocks. Every block is first replaced by the absolute value of its
    diagonal; the four named groups then get their full blocks back.

    Returns:
        (inverse of the block-diagonal matrix, its Cholesky factors).
    """
    blocks = get_block_diag(hess(q, params))
    mass = jax.tree_map(lambda blk: jnp.diag(jnp.abs(jnp.diag(blk))), blocks)
    # Groups that keep their full (dense) block.
    for name in ('rfi_amp', 'rfi_orbit', 'g_amp', 'g_phase'):
        mass[name] = blocks[name]
    return block_diag_inv(mass), block_diag_cholesky(mass)

@jit
def H_inv(q, params):
    """Inverse of the block-diagonal part of the Hessian of ``U`` at ``q``.

    Each parameter group's full block is inverted as-is (no reduction to an
    absolute-value diagonal is applied here).
    """
    return block_diag_inv(get_block_diag(hess(q, params)))

@partial(jit, static_argnums=(4,5))
def run_batch(sample_0, M_inv, L, delta_T, T, N_samples, key, updateDT):
    """Run one batch of HMC from ``sample_0`` and optionally adapt the step size.

    ``T`` and ``N_samples`` are static (compile-time) arguments. ``params`` is
    read from module scope. When ``updateDT`` is true the step size is passed
    through ``update_dt`` together with the batch's mean acceptance; otherwise
    it is returned unchanged.

    Returns:
        (samples, acpt, delta_E, delta_T, key)
    """
    hmc_out = HMC_block_diag(U, delU, sample_0, params, M_inv, L,
                             delta_T, T, N_samples, key)
    samples, acpt, delta_E, key = hmc_out
    # Advance the key once more; only the first half of the split is kept.
    key, _ = random.split(key)
    keep_step = lambda operand: operand[1]
    delta_T = cond(updateDT, update_dt, keep_step, [jnp.mean(acpt), delta_T])
    return samples, acpt, delta_E, delta_T, key

# Per-sample diagnostics saved alongside each batch. The first two are
# placeholders that ignore their arguments and return zeros; the others
# evaluate the prior / likelihood / posterior for every sample, mapping over
# axis 1 of the sample arrays.
def lhhp_vmap(x, y):
    return jnp.zeros(batch_size)

def lhp_vmap(x, y):
    return jnp.zeros(batch_size)

_per_sample_axes = (1, None)
lp_vmap = jit(vmap(nlog_prior, in_axes=_per_sample_axes))
ll_vmap = jit(vmap(nlog_likelihood, in_axes=_per_sample_axes))
U_vmap = jit(vmap(U, in_axes=_per_sample_axes))
extra_data_funcs = [lhhp_vmap, lhp_vmap, lp_vmap, ll_vmap]

# Set Constant Parameters passed to the prior/likelihood functions.
# The data ('D'), coarse times and frequencies come from ``cal``; the fine
# time grid, antenna positions, phase corrections and calibrator model come
# from ``calN`` (the observation simulated with N_int_samples samples).
params = {'freqs': cal.freqs,
          'times': cal.times,
          'times_fine': calN.times_fine,
          # fall back to 0.2 when the simulation is noise-free
          'noise': cal.noise if cal.noise>0 else 0.2,
          'ants_xyz': calN.ants_xyz,
          'phase_corrections': calN.ants_uvw[...,-1],
          'n_ants': cal.n_ants,
          'n_bl': cal.n_bl,
          # calibrator visibilities averaged over the fine samples of each time step
          'vis_cal': calN.vis_ast.reshape(N_cal_time,N_int_samples,
                                          calN.n_bl).mean(axis=1),
          'D': cal.vis_obs[:,:,0],
          }

# Calculate Prior Parameter Values
# Prior means for the gains are the time-averaged true gains perturbed by a
# truncated-normal draw (clipped at +/- max_sigma), constant over time.
max_sigma = 3
G_amp_mean = jnp.mean(jnp.abs(cal.gains_ants[:,:,0]), axis=0)[None,:]
G_amp = ( G_amp_mean + \
          G_amp_dev*G_amp_mean * \
          random.truncated_normal(random.PRNGKey(105+int(cal.times[0])), lower=-max_sigma, upper=max_sigma,
                                  shape=(1,N_ant)) ) * \
        jnp.ones((N_cal_time,N_ant))

G_phase_mean = jnp.mean(jnp.angle(cal.gains_ants[:,:,0]), axis=0)[None,:]
G_phase = ( G_phase_mean + \
            jnp.deg2rad(G_phase_dev) * \
            random.truncated_normal(random.PRNGKey(106+int(cal.times[0])), lower=-max_sigma, upper=max_sigma,
                                    shape=(1,N_ant)) ) * \
          jnp.ones((N_cal_time,N_ant))

G = G_amp*jnp.exp(1.j*G_phase)

# Orbit prior: inverse covariance from ``orbit_fisher`` divided by the number
# of time steps; the prior mean is a draw around the true orbit using the
# corresponding covariance.
inv_cov_RFI_orbit = orbit_fisher(cal.times, cal.rfi_orbit, orbit_dev*RIC_std)/cal.n_time
rfi_orbit = random.multivariate_normal(random.PRNGKey(107), cal.rfi_orbit,
                                       jnp.linalg.inv(inv_cov_RFI_orbit))

# The prior widths are looser than the draws above: the gain standard
# deviations are doubled and the orbit inverse covariance is divided by 4.
# Only the first N_ant-1 phases get a prior mean (last column dropped).
prior_params = {'mu_RFI_orbit': rfi_orbit,
               'inv_cov_RFI_orbit': inv_cov_RFI_orbit/4,
               'mu_G_amp': G_amp,
               'G_amp_std': G_amp_dev*G_amp_mean*jnp.ones((N_cal_time,N_ant))*2,
               'mu_G_phase': G_phase[:,:-1],
               'G_phase_std': jnp.deg2rad(G_phase_dev)*2,
               'inv_cov_RFI_amp': inv_kernel_vmap(params['ants_xyz'][0],
                                                  1e4*jnp.ones(N_cal_time),
                                                  1e3*jnp.ones(N_cal_time))
               }

# Merge the prior into the shared parameter dictionary and save it.
params.update(prior_params)
file_name = os.path.join(data_path, 'prior')
savePrior(file_name, prior_params)

# ---- Candidate starting points ----
# Draw N_qi random parameter sets around the true values, score each with the
# negative log-posterior and keep the 10 best.
N_qi = 10000

# Scale applied to the spread of the candidate draws.
qi_dev = 0.5

# G_amp_mean = jnp.mean(jnp.abs(cal.gains_ants[:,:,0]), axis=0)
# G_amp = ( params['mu_G_amp'][0,None,None] + \
#           qi_dev*params['G_amp_std']*random.normal(key, (N_qi,1,N_ant)) ) * \
#         jnp.ones((1,N_cal_time,N_ant))
# key, subkey = random.split(key)
# G_phase = ( jnp.concatenate([params['mu_G_phase'][0,None,None], jnp.zeros((1,1,1))], axis=-1) + \
#             qi_dev*params['G_phase_std']*random.normal(key, (N_qi,1,N_ant)) ) * \
#           jnp.ones((1,N_cal_time,N_ant))
# key, subkey = random.split(key)

# Gain amplitudes: one draw per candidate and antenna, constant over time.
G_amp_mean = jnp.mean(jnp.abs(cal.gains_ants[:,:,0]), axis=0)[None,None,:]
G_amp = ( G_amp_mean + \
          qi_dev*G_amp_dev*G_amp_mean*random.normal(key, (N_qi,1,N_ant)) ) * \
        jnp.ones((1,N_cal_time,N_ant))
# The key is advanced after every draw; the second half of each split is unused.
key, subkey = random.split(key)
# Gain phases for the first N_ant-1 antennas only.
G_phase_mean = jnp.mean(jnp.angle(cal.gains_ants[:,:,0]), axis=0)[None,None,:-1]
G_phase = ( G_phase_mean + \
            qi_dev*jnp.deg2rad(G_phase_dev)*random.normal(key, (N_qi,1,N_ant-1)) ) * \
          jnp.ones((1,N_cal_time,N_ant-1))
key, subkey = random.split(key)

# Complex gains per candidate, with a zero phase appended for the last antenna.
G = G_amp*jnp.exp(1.j*jnp.concatenate([G_phase, jnp.zeros((N_qi,N_cal_time,1))], axis=-1))

# Orbit candidates drawn around the true orbit.
# NOTE(review): qi_dev multiplies the covariance here but the standard
# deviation in the gain draws above - confirm this is intended.
rfi_orbit = random.multivariate_normal(key, cal.rfi_orbit,
                                       qi_dev*jnp.linalg.inv(inv_cov_RFI_orbit),
                                       shape=(N_qi,))
key, subkey = random.split(key)

# RFI amplitude guess for a given set of gains: square root of the largest
# residual magnitude (data divided by baseline gains, minus the calibrator
# model) at each time step, repeated for every antenna.
def calc_rfi_amp(G):
    return jnp.sqrt(jnp.max(jnp.abs(params['D']/ants_to_bl(G) -
                                    params['vis_cal']), axis=1, keepdims=True))* \
           jnp.ones((1,N_ant))#jnp.ones((N_cal_time,N_ant))

rfi_A = vmap(calc_rfi_amp, in_axes=(0,))(G)

# Assemble the candidate dictionaries (with small fixed rescalings of the
# gain and RFI amplitudes).
qi = [{
      'g_amp': G_amp[i].flatten()/0.98, # (Nt,Na,Nf)
      'g_phase': G_phase[i].flatten(), # (Nt,Na-1,Nf)
      'rfi_amp': rfi_A[i].flatten()*0.98,
      'rfi_orbit': rfi_orbit[i]
      } for i in range(N_qi)]

# Score every candidate and keep the 10 with the lowest negative
# log-posterior, best first.
nlpost = jnp.array([U(qi[i], params) for i in range(N_qi)])
qi = [qi[idx] for idx in jnp.argsort(nlpost)[:10]]

# Register the parameter layout from one candidate (presumably what
# flatten/unflatten rely on - TODO confirm).
param_tree_def(qi[0])

# Calculate approximate True values
# Gains averaged over the integration samples within each time step.
G = cal.gains_ants.reshape(N_cal_time,cal.n_int_samples,N_ant).mean(axis=1)

# RFI amplitudes interpolated from the fine time grid onto the coarse times,
# one antenna at a time.
rfi_A = vmap(jnp.interp, in_axes=(None,None,1))(cal.times, cal.times_fine, cal.rfi_A_app[:,:,0]).T

# Same keys and layout as the candidate dictionaries: the phase of the last
# antenna is dropped.
true_values = {
              'g_amp': jnp.abs(G).flatten(), # (Nt,Na,Nf)
              'g_phase': jnp.angle(G)[:,:-1].flatten(), # (Nt,Na-1,Nf)
              'rfi_amp': rfi_A.flatten(),
              'rfi_orbit': cal.rfi_orbit
              }

# Save True values to H5
file_name = os.path.join(data_path, 'true_values')
saveTrue(file_name, true_values, N_ant, N_cal_time)

# NOTE(review): this unconditional exit stops the script here, so nothing
# after this line (optimisation, MCMC) is ever executed. Presumably this is
# intentional for simulation-only runs - confirm before relying on the later
# stages.
sys.exit(0)

# Clear Space in Memory
del cal
del calN

# Print Run details
print()
print(f'Start Time : {datetime.now()}')
print(f_name)
print(f'Number of samples : {N_samples}')
# Each complex visibility counts as two data points.
print(f'\nNumber of Data Points: {2*params["D"].size}')
print(f'\nNumber of Parameters: {len(flatten(true_values))}')

# Put Parameter Names into lists for saving to H5
hp_names = []
hp_keys = []

rfi_amp_names = [f'$A_{j}(t={int(t)})$' for t in params['times'] for j in range(N_ant)]
rfi_orbit_names = ['Elevation [m]', 'Inclination [deg]',
                   'Longitude of\nAscending Node [deg]',
                   'Argument of\nPerapsis [deg]']

g_amp_names = [f'$|G_{i}|(t={int(t)})$' for t in params['times'] for i in range(N_ant)]
# Raw string for the LaTeX prefix: '\P' is not a valid escape sequence in a
# normal string literal. The resulting label text is unchanged.
g_phase_names = [r'$\Phi_{G_{'+f'{i}'+'}}'+f'(t={int(t)})$' for t in params['times'] for i in range(N_ant-1)]

# The names are stored as UTF-8 byte strings.
hp_names = [name.encode('utf-8') for name in hp_names]
rfi_amp_names = [name.encode('utf-8') for name in rfi_amp_names]
rfi_orbit_names = [name.encode('utf-8') for name in rfi_orbit_names]
g_amp_names = [name.encode('utf-8') for name in g_amp_names]
g_phase_names = [name.encode('utf-8') for name in g_phase_names]

param_names = [hp_names, rfi_amp_names, rfi_orbit_names,
               g_amp_names, g_phase_names]

init_time = time() - start_time
print(f'\nInitialization Time: {init_time} s')

# Reference posterior values at the (approximate) true parameters.
print('\nTrue values')
nlpost = U(true_values, params)
nll = nlog_likelihood(true_values, params)
nlp = nlog_prior(true_values, params)
print(f'NL Prior : \t{nlp}\nNL Likelihood : {nll}\nNL Posterior : \t{nlpost}')

def qN(q):
    """Run the block-diagonal quasi-Newton optimiser from the parameter dict ``q``.

    Uses ``U``/``delU``/``H_inv`` and the module-level optimisation settings
    as they are at call time.
    """
    optimiser_args = (U, delU, H_inv, q, params, alpha,
                      max_iter, threshold, batch_size_opt, min_chi2)
    return quasi_Newton(*optimiser_args, update_H_inv=update_H_inv)

def flat_qN(q):
    """Run the full-Hessian quasi-Newton optimiser from the flat vector ``q``.

    Works on the flattened parameter vector with ``flat_U``/``flat_delU``/
    ``hess_inv`` and a step scale ten times smaller than the current ``alpha``.
    """
    step_scale = alpha/10.
    optimiser_args = (flat_U, flat_delU, hess_inv, q, params, step_scale,
                      max_iter, threshold, batch_size_opt, min_chi2)
    return quasi_Newton_v(*optimiser_args, update_H_inv=update_H_inv)


#########################################################
# qi[0] = true_values
#########################################################


# Run Optimization
# Tries the 10 candidate starting points in order. Each one is first refined
# with the block-diagonal quasi-Newton optimiser; if that gets U/N_d below 1.5
# the full-Hessian optimiser is run from the result. The loop stops as soon as
# U/N_d drops below 1.02; otherwise the step scale ``alpha`` is halved before
# the next candidate is tried.
if optimization:
    print('\nRunning Quasi-Newton Optimization to find MAP')
    t = time()
    for i in range(10):
        print('\nInitial values')
        nlpost = U(qi[i], params)
        nll = nlog_likelihood(qi[i], params)
        nlp = nlog_prior(qi[i], params)
        print(f'NL Prior : \t{nlp}\nNL Likelihood : {nll}\nNL Posterior : \t{nlpost}')
        print('Running Block Diagonal Quasi-Newton ...')
        i_f, q0, qf = qN(qi[i])
        print(f'chi^2/N_d : {U(qf, params)/params["D"].size}')
        if U(qf, params)/params["D"].size<1.5:
            print('Running Full Quasi-Newton ...')
            i_f, q0, q = flat_qN(flatten(qf))
            q = unflatten(q)
            q0 = unflatten(q0)
            print(f'chi^2/N_d : {U(q, params)/params["D"].size}')
            # Keep the full-Hessian result only if it improved the posterior.
            if U(q, params)<U(qf, params):
                qf = q
            if U(qf, params)/params["D"].size<1.02:
                break
            else:
                alpha /= 2
    # ``qf``/``i_f`` are from the last candidate tried.
    print(f'Final iter: {i_f} with chi^2/N_d = {round(U(qf, params)/params["D"].size, 5)}')

    print('\nMAP values')
    nlpost = U(qf, params)
    nll = nlog_likelihood(qf, params)
    nlp = nlog_prior(qf, params)
    print(f'NL Prior : \t{nlp}\nNL Likelihood : {nll}\nNL Posterior : \t{nlpost}')
    print(f'Time taken: {round(time()-t, 2)} s')

    # Laplace approximation: the covariance is the inverse of the Hessian of
    # U at the optimum, assembled into a single matrix.
    print('\nRunning Laplace Approximation for Covariance')
    t = time()
    inv_cov_dict = hess(qf, params)
    inv_cov, keys, sizes = mat_from_dict(inv_cov_dict, qf)
    cov = jnp.linalg.inv(inv_cov)
    # L_cov = jnp.linalg.cholesky(cov)
    # eigval, eigvec = jnp.linalg.eigh(inv_cov)
    # L_cov = eigvec@jnp.diag(1./jnp.abs(jnp.sqrt(eigval)))
    # cov = L_cov@L_cov.T
    print(f'Time taken: {round(time()-t, 2)} s')

    file_name = os.path.join(data_path, '_Optimization')
    extra_params = [nlp, nll, nlpost]
    saveOptimization(file_name, qf, cov, keys, sizes, extra_params, N_ant, N_cal_time)
    # q_flat = flatten(qf)
    # qi = unflatten(q_flat + 0.5*L_cov@random.normal(key, (len(q_flat),)))
    # key, subkey = random.split(key)
    # The optimum becomes the single starting point for the MCMC stage.
    qi = qf
    # if nll/params["D"].size>min_chi2:
    #     sys.exit(0)

    from RFIx.utils.analyze import calc_bias, calc_coverage

    # Compare the estimate with the true values per parameter group, using the
    # Laplace standard deviations; ``idxs`` are the group boundaries in the
    # flattened ordering given by ``sizes``.
    print('\nParameter Estimate Coverage')
    idxs = jnp.cumsum(jnp.concatenate([jnp.zeros(1), jnp.array(sizes)])).astype(int)
    std = jnp.sqrt((jnp.diag(cov)))
    for i, param_key in enumerate(keys):
        coverage = calc_coverage(calc_bias(qf[param_key], true_values[param_key],
                                           std[idxs[i]:idxs[i+1]]))
        print(param_key, [round(x,1) for x in coverage])

# When the optimisation stage is skipped, ``qi`` is still the list of
# candidate starting points (ordered by increasing negative log-posterior).
# The posterior functions and the sampler need a single parameter dict, so use
# the best candidate in that case.
if isinstance(qi, (list, tuple)):
    qi = qi[0]

print('\nInitial values')
nlpost = U(qi, params)
nll = nlog_likelihood(qi, params)
nlp = nlog_prior(qi, params)
print(f'NL Prior : \t{nlp}\nNL Likelihood : {nll}\nNL Posterior : \t{nlpost}')

# Sampler state flags:
#   update       - the mass matrix has not been refreshed yet
#   updateDT     - the step size is still being adapted
#   sample_state - label used in the output file names
update = True
updateDT = True
sample_state = 'burn'

# Run MCMC
# The sampler runs in batches of ``batch_size`` samples and every batch is
# saved to its own file. During burn-in the mass matrix is refreshed once,
# when the best likelihood in a batch gets close enough to the number of data
# points, and afterwards the step size is frozen once the acceptance rate of a
# batch lies between 0.59 and 0.71; later batches are labelled 'sample'.
if mcmc:
    # Imported here as well so that the coverage report at the end of this
    # stage works when the optimisation stage is disabled.
    from RFIx.utils.analyze import calc_bias, calc_coverage

    def _hms(seconds):
        """Split a duration in seconds into (hours, minutes, seconds to 2 d.p.)."""
        hh = int(seconds/3600)
        mm = int((seconds-hh*3600)/60)
        ss = round(seconds-hh*3600-mm*60, 2)
        return hh, mm, ss

    print('\nCompiling HMC ...')
    compilation_time_start = time()

    # First mass-matrix evaluation (includes jit compilation).
    mass_time_start = time()
    M_inv, L = M_inv_L_custom(qi, params)
    Msum = float(vec_sum(M_inv))
    mass_comp_time = time() - mass_time_start
    print(f'\nMass Matrix Compilation Time: {mass_comp_time} s\n')

    # First batch (includes jit compilation of the sampler).
    samples, acpt, delta_E, delta_T, key = run_batch(qi, M_inv, L, delta_T, T, batch_size, key, updateDT)
    sample_0 = get_sample(samples, int(-1))
    file_name = os.path.join(data_path, 'mcmc_burn_0000')
    # Per-sample extras: acceptance, energy error, step size, the entries of
    # extra_data_funcs, and finally the negative log-posterior.
    extra_samples = [acpt, delta_E, delta_T*jnp.ones(batch_size)] + \
                    [x(samples, params) for x in extra_data_funcs] + \
                    [U_vmap(samples, params)]
    saveSamples(file_name, samples, extra_samples,
                param_names, hp_keys, N_ant, N_cal_time)

    print('Acceptance Ratio on First Pass: {} %\n'.format(round(100*jnp.mean(acpt), 2)))
    compilation_time = time() - compilation_time_start
    print(f'Compilation Time: {compilation_time} s')

    # Time two already-compiled mass-matrix evaluations and report the average.
    mass_time_start = time()
    M_inv, L = M_inv_L_custom(true_values, params)
    # print(vec_sum(M_inv))
    M_inv, L = M_inv_L_custom(qi, params)
    # print(vec_sum(M_inv))
    mass_run_time = (time() - mass_time_start)/2
    print(f'\nMass Matrix Calculation Time: {1000*mass_run_time} ms')

    print('\nRunning HMC ...')
    run_start = time()

    progress = tqdm(range(1, N_batch))

    # Stays None until the step size is frozen, i.e. until sampling proper begins.
    sample_start = None

    for i in progress:

        running_acpt = jnp.mean(acpt)
        progress.set_description(f'Acpt rate: {round(100*running_acpt,1)} % | {batch_size} samples/iter ')
        # Negative once twice the smallest negative log-likelihood of the last
        # batch falls below 1.05 times the number of real data points.
        map_shift = 2*jnp.min(extra_samples[-2]) - 1.05*2*params['D'].size

        if not update and updateDT and running_acpt>0.59 and running_acpt<0.71:
            updateDT = False
            print(f'\nDelta T = {delta_T:.1E} fixed at iter : {i}')
            sample_state = 'sample'
            sample_start = time()
            burn_time = sample_start - run_start
        if update and map_shift<0:
            update = False
            # Recompute the mass matrix at the best sample of the last batch.
            map_idx = int(jnp.argmin(extra_samples[-1]))
            M_inv, L = M_inv_L_custom(get_sample(samples, map_idx), params)
            map_val = extra_samples[-1][map_idx]
            print(f'\nMass Updated at iter : {i} \nMAP NLP : {map_val}')

        samples, acpt, delta_E, delta_T, key = run_batch(sample_0, M_inv, L, delta_T, T, batch_size, key, updateDT)
        sample_0 = get_sample(samples, int(-1))

        file_name = os.path.join(data_path, f'mcmc_{sample_state}_{str(i).zfill(4)}')
        extra_samples = [acpt, delta_E, delta_T*jnp.ones(batch_size)] + \
                        [x(samples, params) for x in extra_data_funcs] + \
                        [U_vmap(samples, params)]
        saveSamples(file_name, samples, extra_samples,
                    param_names, hp_keys, N_ant, N_cal_time)


    if sample_start is None:
        # The step size was never frozen, so no 'sample' batches were produced.
        sample_time = 0.0
        burn_time = 0.0
        print('\nHMC sampling did not begin! \nStill in Warm up phase.')
    else:
        sample_time = time() - sample_start
        hh, mm, ss = _hms(burn_time)
        print(f'\nHMC Burn In Time - {hh}:{mm}:{ss}')

        hh, mm, ss = _hms(sample_time)
        print(f'\nHMC Sample Time - {hh}:{mm}:{ss}')

    file_name = os.path.join(data_path, 'times')
    saveTimes(file_name, [init_time, compilation_time, mass_comp_time,
                          mass_run_time, burn_time, sample_time])

    print('\nTrue values')
    nlpost = U(true_values, params)
    nll = nlog_likelihood(true_values, params)
    nlp = nlog_prior(true_values, params)
    print(f'NL Prior : \t{nlp}\nNL Likelihood : {nll}\nNL Posterior : \t{nlpost}')

    print('\nInitial values')
    nlpost = U(qi, params)
    nll = nlog_likelihood(qi, params)
    nlp = nlog_prior(qi, params)
    print(f'NL Prior : \t{nlp}\nNL Likelihood : {nll}\nNL Posterior : \t{nlpost}')

    post_mean = mean_vec(samples)
    print(f'\nPosterior Mean values (last {int(batch_size)} samples)')
    nlpost = U(post_mean, params)
    nll = nlog_likelihood(post_mean, params)
    nlp = nlog_prior(post_mean, params)
    print(f'NL Prior : \t{nlp}\nNL Likelihood : {nll}\nNL Posterior : \t{nlpost}')

    for i, param_key in enumerate(post_mean.keys()):
        coverage = calc_coverage(calc_bias(post_mean[param_key].mean(axis=-1),
                                           true_values[param_key],
                                           post_mean[param_key].std(axis=-1)))
        print(param_key, [round(x,1) for x in coverage])