In [None]:
import os
import pickle

import matplotlib as mpl
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt

from align_utils import idr2_jdsx
from fit_diagnostics import abs_residuals
from plot_utils import clipped_heatmap, df_heatmap, plot_res_grouped, plot_res_heatmap
from red_likelihood import makeCArray, split_rel_results
from xd_utils import XDgroup_data

In [None]:
%matplotlib inline

In [None]:
# Switch for high-resolution figure output. When enabled, every figure
# created afterwards is rendered at 300 dpi.
# `mpl` comes from the notebook's top import cell, so no import is hidden
# inside this conditional.
plot_figs = True
if plot_figs:
    mpl.rcParams['figure.dpi'] = 300

In [None]:
# Run selection: these values are combined below to build the filenames of
# the pickled result dataframe and its metadata.
jd_time = 2458098.43869 # used to find LST that labels dataframe
pol = 'ee' # NOTE(review): presumably the visibility polarization - confirm
ndist = 'gaussian' # filename tag; presumably the assumed noise distribution - TODO confirm
dir_path = 'xd_rel_dfs' # directory the result pickles are read from

In [None]:
# Lookup table read from a local pickle; the next cell uses its 'JD_time' and
# 'LASTs' columns to turn jd_time into the LST that labels the result files.
# NOTE(review): unpickling executes arbitrary code - only load trusted files.
lst_df = pd.read_pickle('jd_lst_map_idr2.pkl')

In [None]:
# Select the 'LASTs' entries whose JD_time equals jd_time exactly.
lst_match = lst_df.loc[lst_df['JD_time'] == jd_time, 'LASTs']
if lst_match.empty:
    # Fail with a readable message instead of the opaque IndexError that
    # indexing an empty array would otherwise produce.
    raise ValueError('jd_time {} not found in the JD_time column of lst_df'.format(jd_time))

# The first element of the first matching LASTs entry is the reference LST
# that is embedded (to 4 decimal places) in the result filenames.
lst_ref = lst_match.values[0][0]
xd_df_path = os.path.join(dir_path, 'xd_rel_df.{:.4f}.{}.{}.pkl'.format(lst_ref, pol, ndist))

In [None]:
# The metadata pickle shares the LST/polarization tag of the results file
# but carries an '.md' suffix in place of the noise-distribution tag.
md_fname = 'xd_rel_df.{:.4f}.{}.md.pkl'.format(lst_ref, pol)
with open(os.path.join(dir_path, md_fname), 'rb') as f:
    md = pickle.load(f)

# Results dataframe, followed by a preview of five randomly drawn rows
# displayed in index order.
xd_df = pd.read_pickle(xd_df_path)
xd_df.sample(5).sort_index()

## Performance

### Number of iterations

In [None]:
# Plot the 'nit' column of xd_df (iteration count, per the section heading).
# NOTE(review): logy=True presumably selects a log-scaled y-axis, and the
# grouping is decided inside plot_res_grouped - confirm in plot_utils.
plot_res_grouped(xd_df, 'nit', logy=True)

In [None]:
# Heatmap of the 'nit' column of xd_df.
# NOTE(review): clip=True presumably limits the colour range so outliers do
# not dominate - confirm the default percentile in plot_utils.
plot_res_heatmap(xd_df, 'nit', clip=True)

### Log-likelihood

In [None]:
# Plot the 'fun' column of xd_df (labelled log-likelihood by the section heading).
# NOTE(review): 'fun' is presumably the minimised objective value from the
# optimizer result, i.e. a negative log-likelihood - TODO confirm.
plot_res_grouped(xd_df, 'fun', logy=True, figsize=(10, 7))

In [None]:
# Heatmap of the 'fun' column of xd_df with clipping enabled at clip_pctile=98.
# NOTE(review): presumably the colour range is clipped at the 98th
# percentile - confirm the exact semantics in plot_utils.
plot_res_heatmap(xd_df, 'fun', clip=True, clip_pctile=98, figsize=(8, 6))

### Residuals

In [None]:
# Apply abs_residuals to each row's 'norm_residual' entry; its two outputs
# are stored as the Re and Im summary columns (added to xd_df in place).
xd_df[['med_abs_norm_res_Re', 'med_abs_norm_res_Im']] = xd_df.apply(
    lambda row: pd.Series(abs_residuals(row['norm_residual'])),
    axis=1,
)

# Combined summary: quadrature sum of the Re and Im columns.
xd_df['med_abs_norm_res_comb'] = np.sqrt(
    xd_df['med_abs_norm_res_Re'].pow(2) + xd_df['med_abs_norm_res_Im'].pow(2)
)

In [None]:
# Heatmap of the combined residual summary column computed in the previous cell.
# NOTE(review): the fixed range 0.16-0.22 is passed as vmin/vmax; presumably
# chosen by eye for this dataset - revisit if the data change.
plot_res_heatmap(xd_df, 'med_abs_norm_res_comb', vmin=0.16, vmax=0.22, figsize=(8, 6))

## Gains at sample frequency and time slice

In [None]:
# Check results for a given frequency & time integration
test_freq = 600
test_tint = 53

resx = xd_df.loc[(test_freq, test_tint)][5:-5].values.astype(float)
test_vis, test_gains = split_rel_results(resx, md['no_unq_bls'], coords='cartesian')
test_gains = test_gains.reshape((md['JDs'].size, -1))

print('Mean gain amplitude across JDs for test frequency {} and time integration {}: '\
      '\n{}\n'.format(test_freq, test_tint, np.mean(np.abs(test_gains), axis=0)))
print('Mean gain phase across JDs for test frequency {} and time integration {}: '\
      '\n{}\n'.format(test_freq, test_tint, np.mean(np.angle(test_gains), axis=0)))

In [None]:
# Side-by-side heatmaps of gain amplitude (left) and gain phase (right),
# sharing the y-axis (one row per JD).
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 3.5), sharey=True)

sns.heatmap(np.abs(test_gains), cmap=sns.cm.rocket_r, center=1, ax=ax1)
sns.heatmap(np.angle(test_gains), cmap='bwr', center=0, ax=ax2, vmin=-np.pi, vmax=np.pi)

# Major y ticks sit at cell centres and carry the JD labels; minor ticks sit
# on the cell edges and are only used to draw the separating grid lines.
jd_edges = np.arange(md['JDs'].size)
ax1.set_yticks(jd_edges + 0.5)
ax1.set_yticklabels(md['JDs'], rotation=0)
ax2.set_yticks(jd_edges, minor=True)

# Label every fifth antenna column, at the centre of its cell.
ant_labels = np.arange(md['no_ants'])[::5]
for ax in (ax1, ax2):
    ax.set_xlabel('Antenna number')
    ax.tick_params(axis='y', which='minor', color='white')
    ax.set_xticks(ant_labels + 0.5, minor=False)
    ax.set_xticklabels(ant_labels)
    ax.grid(which='minor', axis='y', linestyle='--', lw=0.5)

fig.tight_layout()
plt.show()

## Visibilities

In [None]:
# The first no_min_p columns of xd_df hold attributes of the SciPy
# OptimizeResult; the following 2*no_unq_bls columns are the visibility
# parameters.
no_min_p = 5
no_unq_bls = md['no_unq_bls']
vis_cols = slice(no_min_p, no_min_p + 2*no_unq_bls)
vis_df = xd_df.iloc[:, vis_cols]

# Convert each row with makeCArray, then expand the per-row arrays into a
# dataframe that keeps the original index.
# NOTE(review): presumably makeCArray pairs real/imaginary parts into
# complex values - confirm in red_likelihood.
visC_rows = vis_df.apply(lambda row: makeCArray(row.values), axis=1)
visC_df = pd.DataFrame(visC_rows.values.tolist(), index=visC_rows.index)

### Visibilities at test time integration

In [None]:
# Visibility amplitudes at the test time integration. test_tint (set in the
# gains section) is used instead of a repeated literal so the plot follows
# the chosen integration. After the transpose, rows are the visC_df columns
# and columns are the remaining index level.
df = visC_df.xs(test_tint, level='time_int').abs().transpose()

# Limit the colour range to the 2nd-98th percentile of the amplitudes
# (NaNs ignored) so that extreme values do not wash out the map.
vmin = np.nanpercentile(df.values, 2)
vmax = np.nanpercentile(df.values, 98)
df_heatmap(df, xbase=25, ybase=5,
           xlabel='Channel', ylabel='Redundant Baseline Group',
           vmin=vmin, vmax=vmax, figsize=(8, 6))

In [None]:
# Visibility phases at the test time integration. test_tint (set in the
# gains section) is used instead of a repeated literal so the plot follows
# the chosen integration. np.angle is applied element-wise.
df = visC_df.xs(test_tint, level='time_int').applymap(np.angle).transpose()

# Diverging colour map centred on zero and spanning the full phase range.
df_heatmap(df, xbase=25, ybase=5, cmap='bwr', vmin=-np.pi, vmax=np.pi, center=0,
           xlabel='Channel', ylabel='Redundant Baseline Group', figsize=(8, 6))

In [None]:
# Unique channels and time integrations present in the results index.
chans = xd_df.index.get_level_values(level='freq').unique().values
tints = xd_df.index.get_level_values(level='time_int').unique().values

# When every time integration is present, pass None instead of the full list.
# NOTE(review): presumably None means "no selection" to XDgroup_data - confirm.
if tints.size == md['Ntimes']:
    tints = None

# np.array_equal returns False when the shapes differ, whereas an
# element-wise `==` followed by .all() errors out in that case.
if np.array_equal(md['JDs'], idr2_jdsx):
    jds = idr2_jdsx
else:
    jds = md['JDs']

hd, redg, cdata, cndata = XDgroup_data(jd_time, jds, pol, chans=chans,
                                       tints=tints, bad_ants=True,
                                       use_flags='first', noise=True)

In [None]:
# TODO
# Gain stability over different days
# Histograms of NLL, NLL / Noise, R_man to find outliers
# Compare med visibilities to those from individual datasets - can do amps,
# but what about phases?