In [None]:
import os
import sys

import numpy as np
import pandas as pd
import pickle
from matplotlib import pyplot as plt

sys.path.insert(0, os.path.dirname(os.getcwd()))
from align_utils import align_df, idr2_jdsx
from plot_utils import arr_pcmesh, clipped_heatmap, df_heatmap
from red_likelihood import makeCArray

In [None]:
%matplotlib inline

In [None]:
import matplotlib as mpl

# Switch for high-resolution output: when enabled, figures are drawn at 300 dpi.
plot_figs = False
if plot_figs:
    mpl.rcParams['figure.dpi'] = 300

# Render all figure text through LaTeX using the serif family 'cm'; the
# preamble loads the AMS packages for the math used in the plot labels.
mpl.rcParams['font.family'] = 'serif'
mpl.rcParams['font.serif'] = ['cm']
mpl.rcParams['text.usetex'] = True
mpl.rcParams['text.latex.preamble'] = r'\usepackage{amssymb} \usepackage{amsmath}'

## Statistics of redundantly calibrated + degenerately transformed datasets across JDs

In [None]:
# Run configuration shared by the loading cells below.

# Time stamp embedded in the metadata pickle filename (presumably a Julian date - confirm).
jd_time = 2458098.43869
# Label embedded in the metadata pickle filename (presumably the polarization - confirm).
pol = 'ee'
# Label embedded in the aligned dataframe filename (presumably the assumed noise
# distribution - confirm).
ndist = 'gaussian'
# Directory from which the aligned dataframe pickle is read.
aligned_dir = '../aligned_red_deg_dfs'

In [None]:
# Read the metadata dictionary for the configured jd_time / pol.
# NOTE: pickle.load can execute arbitrary code - only point this at trusted files.
md_path = os.path.join('../rel_dfs', 'rel_df.{}.{}.md.pkl'.format(jd_time, pol))
with open(md_path, 'rb') as md_file:
    md = pickle.load(md_file)

# Column labels '0', '1', ...: two consecutive labels per unique baseline group.
vis_list = [str(col_idx) for col_idx in np.arange(md['no_unq_bls']*2).tolist()]

In [None]:
# Load the aligned dataframe for the configured polarization and noise label.
# The polarization is taken from `pol` instead of a hard-coded 'ee', so this
# file always matches the metadata pickle, which is also selected through `pol`.
idr2_df_path = os.path.join(
    aligned_dir, 'aligned_red_deg.1.3826.{}.{}.pkl'.format(pol, ndist)
)
idr2_df = pd.read_pickle(idr2_df_path)

### Selected time

In [None]:
time_integration = 53


def _row_to_complex(row):
    """Pass one row's vis_list columns (cast to float) through makeCArray."""
    return pd.Series(makeCArray(row[vis_list].to_numpy().astype(float)))


# Order of operations used here: take the modulus first, then the median
# within each frequency.
tint_rows = idr2_df.xs(time_integration, level='time_int', drop_level=True)
idr2_df_tint = tint_rows.apply(_row_to_complex, axis=1)
vis_abs_med = idr2_df_tint.abs().groupby(level=['freq']).median().transpose()

# Alternative ordering (not used): take the per-frequency median of the selected
# rows first, then apply makeCArray row-wise and take the modulus of the result.

In [None]:
# Plot axes: channel numbers with 50 channels dropped at each end of the band,
# followed by baseline-group and time-integration indices.
freqs_arr = np.arange(50, md['Nfreqs'] - 50)
blgrp_arr, tints_arr = (np.arange(md[key]) for key in ('no_unq_bls', 'Ntimes'))

In [None]:
# Cap the colour scale at the 97th percentile (NaNs ignored); larger values
# are shown with the 'max' colorbar extension.
vmax = np.nanpercentile(vis_abs_med.to_numpy(), 97)

arr_pcmesh(
    freqs_arr, blgrp_arr, vis_abs_med,
    vmin=0, vmax=vmax, extend='max',
    xlabel='Frequency Channel',
    ylabel='Redundant Baseline Group',
    clabel=r'$\mathop{\mathrm{med}}(|V|)$',
    xlim=(0, md['Nfreqs']-1),
    sci_fmt=True,
)

In [None]:
# Variance within each frequency for the rows of the selected time integration.
idr2_df_tint = idr2_df.xs(time_integration, level='time_int', drop_level=True).groupby(level=['freq']).var()

# Add the variances of each pair of consecutive vis_list columns (even label +
# odd label) and take the square root to get one value per baseline group.
var_even_cols = idr2_df_tint[vis_list[0::2]].to_numpy()
var_odd_cols = idr2_df_tint[vis_list[1::2]].to_numpy()
vis_std = np.sqrt(var_even_cols + var_odd_cols).transpose()

arr_pcmesh(
    freqs_arr, blgrp_arr, vis_std,
    vmin=0, vmax=0.025, extend='max',
    xlabel='Frequency Channel',
    ylabel='Redundant Baseline Group',
    clabel=r'$\mathop{\mathrm{std}}(V)$',
    xlim=(0, md['Nfreqs']-1),
    sci_fmt=True,
)

### Selected baseline (14m EW)

In [None]:
ew_bl_id = 2  # 14 m EW baselines

# Grid dimensions: number of distinct time integrations and frequencies that
# actually appear in the dataframe index.
Ntimes = len(idr2_df.index.unique(level='time_int'))
Nfreqs = len(idr2_df.index.unique(level='freq'))

In [None]:
# The selected baseline group occupies two consecutive columns; their
# quadrature sum is the visibility modulus plotted below.
col_even, col_odd = str(ew_bl_id*2), str(ew_bl_id*2+1)
vis_abs = np.sqrt(idr2_df[col_even]**2 + idr2_df[col_odd]**2)

# Median per (freq, time_int) group, arranged as a (time, frequency) grid.
arr = vis_abs.groupby(level=['freq', 'time_int']).median().to_numpy().reshape(Nfreqs, Ntimes).transpose()

arr_pcmesh(
    freqs_arr, tints_arr, arr,
    vmin=0, vmax=0.06, extend='max',
    xlabel='Frequency Channel',
    ylabel='Time Integration',
    clabel=r'$\mathop{\mathrm{med}}(|V|)$',
    xlim=(0, md['Nfreqs']-1),
    sci_fmt=True,
)

In [None]:
# Variance of each of the baseline group's two columns per (freq, time_int)
# group; the two variances are summed and square-rooted, then arranged as a
# (time, frequency) grid.
pair_var = idr2_df[[str(ew_bl_id*2), str(ew_bl_id*2+1)]].groupby(level=['freq', 'time_int']).var()
arr = np.sqrt(pair_var.sum(axis=1).to_numpy().reshape(Nfreqs, Ntimes).transpose())

In [None]:
# Ask the helper to return the figure and axes so the colorbar can be adjusted.
fig, ax = arr_pcmesh(
    freqs_arr, tints_arr, arr,
    vmin=0, vmax=0.02, extend='max',
    xlabel='Frequency Channel',
    ylabel='Time Integration',
    clabel=r'$\mathop{\mathrm{std}}(V)$',
    xlim=(0, md['Nfreqs']-1),
    sci_fmt=True,
    rtn_fig_ax=True,
)

# Put explicit ticks at 0, 0.5, 1, 1.5 and 2 (x 1e-2) on the mesh's colorbar.
colorbar = ax.collections[0].colorbar
colorbar.set_ticks(np.array([0., 0.5, 1, 1.5, 2])*1e-2)

fig.tight_layout()
plt.show()

### Negative log-likelihoods (NLLs)

In [None]:
# Group the 'fun' column by (freq, time_int); the same grouping is reused by
# the later NLL statistics.
grp = idr2_df[['fun']].groupby(level=['freq', 'time_int'])

# Median per group, laid out as (frequency, time) and then flipped to (time, frequency).
logl_med_ft = grp.median().to_numpy().reshape(Nfreqs, Ntimes)
logl_med = logl_med_ft.transpose()

arr_pcmesh(
    freqs_arr, tints_arr, logl_med,
    vmin=0, vmax=0.17, extend='max',
    xlabel='Frequency Channel',
    ylabel='Time Integration',
    clabel=r'$\mathrm{med}(-\ln(\mathcal{L}))$',
    xlim=(0, md['Nfreqs']-1),
    sci_fmt=True,
)

In [None]:
# Standard deviation per (freq, time_int) group, stored as a (time, frequency) grid.
logl_std_ft = grp.std().to_numpy().reshape(Nfreqs, Ntimes)
logl_std = logl_std_ft.transpose()

arr_pcmesh(
    freqs_arr, tints_arr, logl_std,
    vmin=0, vmax=0.1, extend='max',
    xlabel='Frequency Channel',
    ylabel='Time Integration',
    clabel=r'$\mathrm{std}(-\ln(\mathcal{L}))$',
    xlim=(0, md['Nfreqs']-1),
    sci_fmt=True,
)

#### Flags from hera_cal

In [None]:
# Read the flag array from the .npz archive. The context manager closes the
# underlying file once the array has been read, instead of leaving the
# NpzFile handle open for the lifetime of the kernel.
with np.load('../idr2_flags.npz') as flags_npz:
    idr2_flags = flags_npz['arr_0']

# Count flags along the first axis and transpose for plotting on the
# (time, frequency) grid.
flags_arr = idr2_flags.sum(axis=0).transpose()

arr_pcmesh(freqs_arr, tints_arr, flags_arr, vmin=0, \
           xlabel='Frequency Channel', ylabel='Time Integration', clabel=r'\# Flags', \
           xlim=(0, md['Nfreqs']-1))

## Outlier detection

### Z-score

In [None]:
z_thresh = 3.3

logl_mean = grp.mean().to_numpy().reshape(Nfreqs, Ntimes)
idr2_flags = idr2_flags.astype(bool)
logl_flags = np.empty_like(idr2_flags, dtype=bool)
logls = np.empty_like(logl_flags, dtype=float)

# Acceptance band around the mean on the (frequency, time) grid; it does not
# depend on the JD, so it is built once outside the loop.
z_upper = logl_mean + z_thresh*logl_std.T
z_lower = logl_mean - z_thresh*logl_std.T

for jd_idx, jd_val in enumerate(idr2_jdsx):
    # 'fun' values of this JD on the (frequency, time) grid
    jd_grid = idr2_df[['fun']].xs(jd_val, level='JD', drop_level=True).to_numpy().reshape(Nfreqs, Ntimes)
    logls[jd_idx, ...] = jd_grid
    logl_flags[jd_idx, ...] = np.logical_or(jd_grid > z_upper, jd_grid < z_lower)

# Keep only the outliers that are not already flagged in idr2_flags.
new_flags = np.logical_and(logl_flags, ~idr2_flags)
n_new_flags = np.sum(new_flags)
print(('{} potentially bad slices found that are not flagged through the '
       'hera_cal pipeline, through std considerations').format(n_new_flags))

In [None]:
# Index arrays (one per axis of new_flags) of the newly found outliers.
bad_slices = np.nonzero(new_flags)

bad_jds = np.array(idr2_jdsx)[bad_slices[0]]
bad_chans = bad_slices[1] + 50  # shift back by the 50 channels cut from the band start
bad_tints = bad_slices[2]

print(bad_jds)    # JDs
print(bad_chans)  # Channels
print(bad_tints)  # Time integrations

In [None]:
# (frequency, time) part of the bad-slice indices, for the per-group statistics
bad_freq_tint = bad_slices[1:]

report = (
    logls[bad_slices],                 # NLLs for bad slices
    logl_mean[bad_freq_tint],          # mean NLL of the group each bad slice belongs to
    logl_std.T[bad_freq_tint]*1e3,     # NLL std of that group, scaled by 1e3
)
for values in report:
    print(values, '\n')

### Modified Z-scores & MAD

In [None]:
# Scale factor applied to the MAD in the modified Z-score (the usual factor
# relating the MAD to the standard deviation for normally distributed data).
correction = 1.4826

n_jds = len(idr2_jdsx)

# Median of 'fun' per (freq, time_int) group, repeated once per JD so it lines
# up with the flat 'fun' column.
# NOTE(review): this assumes every group holds exactly n_jds consecutive rows
# of idr2_df (JD varying fastest) - confirm against how idr2_df is built.
grp_med = idr2_df['fun'].groupby(level=['freq', 'time_int']).median().to_numpy()
meds = np.repeat(grp_med, n_jds)
dev_from_med = idr2_df['fun'].to_numpy() - meds

# Median absolute deviation per group, then the modified Z-score per row.
mad = np.median(np.abs(dev_from_med).reshape(-1, n_jds), axis=1)
modz = dev_from_med/(correction*np.repeat(mad, n_jds))

# Unflatten to (frequency, time, JD) and move the JD axis to the front,
# giving the same (JD, frequency, time) layout as logls.
modz = modz.reshape(Nfreqs, Ntimes, n_jds).transpose(2, 0, 1)

In [None]:
# Modified Z-score magnitude above which a slice is treated as an outlier.
modz_thresh = 11

# Outliers by modified Z-score that are not already flagged in idr2_flags.
modz_outliers = np.abs(modz) > modz_thresh
bad_slicesz = np.nonzero(np.logical_and(modz_outliers, ~idr2_flags))
n_bad_modz = bad_slicesz[0].size
print(('{} potentially bad slices found that are not flagged through the '
       'hera_cal pipeline, through Z-score considerations').format(n_bad_modz))

In [None]:
# (frequency, time) part of the bad-slice indices, for the per-group statistics
badz_freq_tint = bad_slicesz[1:]

reportz = (
    modz[bad_slicesz],                  # modified Z-score
    logls[bad_slicesz],                 # NLLs for bad slices
    logl_mean[badz_freq_tint],          # mean NLL of the group each bad slice belongs to
    logl_std.T[badz_freq_tint]*1e3,     # NLL std of that group, scaled by 1e3
)
for values in reportz:
    print(values, '\n')

# MAD of the group each bad slice belongs to, scaled by 1e3
print(mad.reshape((Nfreqs, Ntimes))[badz_freq_tint]*1e3)

In [None]:
# Merge the slices found by the Z-score and the modified Z-score detections.
# A slice caught by both methods would otherwise appear twice in the table,
# so duplicate (JD, frequency, time) index triplets are dropped. np.unique
# also returns the columns in lexicographic order (JD index first, then
# frequency, then time), which makes the row order deterministic.
merged_idx = np.stack([
    np.concatenate((bad_slices[i], bad_slicesz[i])) for i in range(len(bad_slices))
])
if merged_idx.shape[1] > 0:  # nothing to de-duplicate when no slices were found
    merged_idx = np.unique(merged_idx, axis=1)
bad_slices_t = tuple(merged_idx)

In [None]:
# (frequency, time) part of the merged indices, for the per-group statistics
merged_freq_tint = bad_slices_t[1:]

# One entry per table column, in output order.
table_rows = (
    np.array(idr2_jdsx)[bad_slices_t[0]],  # JDs
    bad_slices_t[1] + 50,  # Channels
    bad_slices_t[2],  # Time integrations
    logls[bad_slices_t],  # NLLs
    logl_med.T[merged_freq_tint],  # med NLLs
    (logls[bad_slices_t] - logl_mean[merged_freq_tint])
    / logl_std.T[merged_freq_tint],  # Z-score
    modz[bad_slices_t],  # Modified Z-score
    mad.reshape((Nfreqs, Ntimes))[merged_freq_tint]*1e3,  # MAD
)

# Float array with one row per quantity and one column per bad slice.
nice_print = np.empty((8, bad_slices_t[0].size))
for row_idx, row_vals in enumerate(table_rows):
    nice_print[row_idx, :] = row_vals

In [None]:
# Print the table of bad slices as LaTeX rows (one slice per row).
# The header now labels the 7th column "Modified Z-score"; it previously
# repeated "Z-score" for both the 6th and 7th columns.
pp = nice_print.transpose()
print('JD & Channel & Time & NLL & med NLL & Z-score & Modified Z-score & MAD\n')

# First three columns are integer identifiers, the remaining five are floats.
row_fmt = '{} & {} & {} & {:.4f} & {:.4f} & {:.4f} & {:.4f} & {:.4f} \\\\'
for p in pp:
    print(row_fmt.format(int(p[0]), int(p[1]), int(p[2]), *p[3:]))