In [None]:
import os
import sys

import numpy as np
import pandas as pd
import pickle
from matplotlib import pyplot as plt

sys.path.insert(0, os.path.dirname(os.getcwd()))
from align_utils import align_df, idr2_jdsx
from plot_utils import clipped_heatmap, df_heatmap
from red_likelihood import makeCArray

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Toggle for publication-quality output: when enabled, figures created after
# this cell are rendered at 300 dpi.
plot_figs = False
if plot_figs:
    import matplotlib as mpl
    mpl.rcParams.update({'figure.dpi': 300})

## Statistics of redundantly calibrated and degenerately transformed datasets across JDs

In [None]:
# Analysis configuration (all three values are interpolated into input file names).
jd_time = 2458098.43869  # JD of the relative-calibration metadata file
pol = 'ee'               # polarization label
ndist = 'gaussian'       # distribution tag of the aligned dataset (presumably the assumed noise model)

In [None]:
# Metadata pickled next to the relative-calibration dataframe; only the number
# of unique baselines is used below.
# NOTE(review): pickle.load can execute arbitrary code -- only use on trusted,
# locally produced files.
md_path = os.path.join('../rel_dfs', 'rel_df.{}.{}.md.pkl'.format(jd_time, pol))
with open(md_path, 'rb') as md_file:
    md = pickle.load(md_file)

# Visibility column labels '0', '1', ..., two per unique baseline group
# (presumably interleaved real/imaginary parts -- see makeCArray).
vis_list = [str(idx) for idx in np.arange(md['no_unq_bls'] * 2).tolist()]

In [None]:
# Aligned, redundantly calibrated + degenerately transformed dataset.
# Build the file name from the configured polarization instead of a hard-coded
# 'ee', so changing `pol` in the config cell is honoured everywhere.
# NOTE(review): read_pickle unpickles arbitrary objects -- trusted files only.
idr2_df_path = '../aligned_red_deg.1.3826.{}.{}.pkl'.format(pol, ndist)
idr2_df = pd.read_pickle(idr2_df_path)

### Selected time integration

In [None]:
time_integration = 53

# Per-channel median of every column at the chosen time integration, taken over
# the remaining index levels (presumably the JDs).
tint_rows = idr2_df.xs(time_integration, level='time_int', drop_level=True)
idr2_df_tint = tint_rows.groupby(level=['freq']).median()

# Rebuild complex visibilities from the float visibility columns of each channel.
idr2_df_tint = idr2_df_tint.apply(
    lambda freq_row: pd.Series(makeCArray(freq_row[vis_list].values.astype(float))),
    axis=1,
)

vis_abs_med = idr2_df_tint.abs().T
vmax = np.nanpercentile(vis_abs_med.values, 97)  # clip the colour scale at the 97th percentile
df_heatmap(vis_abs_med, xlabel='Channel', ylabel='Redundant Baseline Group',
           xbase=50, ybase=5, vmin=0, vmax=vmax, figsize=(8, 5))

In [None]:
# Per-channel variance of every column at the same time integration.
idr2_df_tint = (idr2_df
                .xs(time_integration, level='time_int', drop_level=True)
                .groupby(level=['freq'])
                .var())
# Pair the per-column variances into complex numbers (via makeCArray) so their
# modulus can be taken below.
idr2_df_tint = idr2_df_tint.apply(
    lambda freq_row: pd.Series(makeCArray(freq_row[vis_list].values.astype(float))),
    axis=1,
)

# NOTE(review): assuming makeCArray pairs columns as (real, imag), this plots
# sqrt(|var_re + 1j*var_im|) = (var_re**2 + var_im**2)**0.25, which is not the
# standard deviation of the complex visibility (that would be
# sqrt(var_re + var_im)) -- confirm this is the intended statistic.
vis_var = np.sqrt(idr2_df_tint.abs()).T
vmax = np.nanpercentile(vis_var.values, 98)  # clip the colour scale at the 98th percentile
df_heatmap(vis_var, xlabel='Channel', ylabel='Redundant Baseline Group',
           xbase=50, ybase=5, vmin=0, vmax=vmax, figsize=(8, 5))

### Selected baseline (14m EW)

In [None]:
ew_bl_id = 2  # 14 m EW baselines

# Grid dimensions used to reshape per-(freq, time_int) statistics into 2D arrays.
Nfreqs, Ntimes = (idr2_df.index.get_level_values(level).unique().size
                  for level in ('freq', 'time_int'))

# Group rows by (freq, time_int) so aggregations run over the remaining levels (JDs).
grp = idr2_df.groupby(level=['freq', 'time_int'])

In [None]:
# Modulus of the per-(freq, time_int) median of the 14 m EW visibility, with the
# median taken separately over columns ew_bl_id and ew_bl_id + 1 (presumably the
# real and imaginary parts).
# Only these two columns are aggregated, and only once. Grouping idr2_df directly
# (not the shared `grp`) keeps the cell re-runnable after `grp` is redefined in
# the NLL section.
ew_cols = [str(ew_bl_id), str(ew_bl_id + 1)]
ew_med = idr2_df.groupby(level=['freq', 'time_int'])[ew_cols].median()
# np.hypot(a, b) == sqrt(a**2 + b**2) without overflow in the intermediate squares.
arr = np.hypot(ew_med[ew_cols[0]].values, ew_med[ew_cols[1]].values).reshape(Nfreqs, Ntimes)

fig, ax = clipped_heatmap(arr.transpose(), 'Time Integration', 'Channel', \
                          vmin=0, figsize=(8, 5), sci_format=True)
plt.tight_layout()
plt.show()

In [None]:
# Per-(freq, time_int) variance of columns ew_bl_id and ew_bl_id + 1 of the 14 m
# EW visibility (presumably real and imaginary parts), combined as
# sqrt(var_a**2 + var_b**2).
# NOTE(review): this is |var_re + 1j*var_im|, not the total variance
# var_re + var_im of the complex visibility -- confirm the intended statistic.
# Only the two needed columns are aggregated, and only once; grouping idr2_df
# directly keeps this cell independent of later redefinitions of `grp`.
ew_cols = [str(ew_bl_id), str(ew_bl_id + 1)]
ew_var = idr2_df.groupby(level=['freq', 'time_int'])[ew_cols].var()
arr = np.hypot(ew_var[ew_cols[0]].values, ew_var[ew_cols[1]].values).reshape(Nfreqs, Ntimes)

fig, ax = clipped_heatmap(arr.transpose(), 'Time Integration', 'Channel', \
                          vmin=0, clip_pctile=98, clip_rnd=100000, figsize=(8, 5), sci_format=True)
plt.tight_layout()
plt.show()

### NLLs

In [None]:
# NLL ('fun' column) grouped by (freq, time_int): the statistics below are taken
# over the JDs. `grp` is reused by the outlier-detection cells further down.
grp = idr2_df[['fun']].groupby(level=['freq', 'time_int'])
logl_med = np.reshape(grp.median().values, (Nfreqs, Ntimes))

fig, ax = clipped_heatmap(logl_med.T, 'Time Integration', 'Channel',
                          vmin=0, clip_pctile=95, clip_rnd=10000,
                          figsize=(8, 5), sci_format=True)
plt.tight_layout()
plt.show()

In [None]:
# Sample standard deviation (pandas default ddof=1) of the NLL over the JDs,
# on the (freq, time_int) grid.
logl_std = np.reshape(grp.std().values, (Nfreqs, Ntimes))

fig, ax = clipped_heatmap(logl_std.T, 'Time Integration', 'Channel',
                          vmin=0, clip_pctile=95, clip_rnd=10000,
                          figsize=(8, 5), sci_format=True)
plt.tight_layout()
plt.show()

#### Flags from hera_cal

In [None]:
# hera_cal flags, presumably shaped (JD, freq, time_int). The npz archive keeps
# its file handle open until closed, so read the array inside a context manager.
with np.load('../idr2_flags.npz') as flags_npz:
    idr2_flags = flags_npz['arr_0']

# Number of flagged JDs for each (channel, time integration).
fig, ax = clipped_heatmap(idr2_flags.sum(axis=0).transpose(), 'Time Integration', 'Channel', \
                          vmin=0, clip_pctile=100, figsize=(8, 5))
plt.tight_layout()
plt.show()

## Outlier detection

### Z-score

In [None]:
z_thresh = 3.3

# Mean NLL over the JDs for every (freq, time_int) slice; `grp` groups the
# 'fun' column by (freq, time_int) in the NLL section.
logl_mean = grp.mean().values.reshape(Nfreqs, Ntimes)
idr2_flags = idr2_flags.astype(bool)
# Axis 0 of the hera_cal flags must line up with idr2_jdsx. Otherwise some JD
# planes of the arrays below would never be written.
assert idr2_flags.shape[0] == len(idr2_jdsx), \
    'hera_cal flags have {} JD planes but idr2_jdsx lists {} JDs'.format(
        idr2_flags.shape[0], len(idr2_jdsx))
# Explicit initialisation (instead of np.empty) so uninitialised memory can
# never leak into the flags or the NLL table.
logl_flags = np.zeros_like(idr2_flags, dtype=bool)
logls = np.full(idr2_flags.shape, np.nan)

for i, jd in enumerate(idr2_jdsx):
    # NLLs of this JD on the (freq, time_int) grid.
    logl_jd = idr2_df[['fun']].xs(jd, level='JD', drop_level=True).values.reshape(Nfreqs, Ntimes)
    # Flag slices more than z_thresh standard deviations from the mean over the JDs.
    logl_flagi = np.logical_or(logl_jd > logl_mean + z_thresh*logl_std, logl_jd < logl_mean - z_thresh*logl_std)
    logl_flags[i, ...] = logl_flagi
    logls[i, ...] = logl_jd

# Keep only outliers that hera_cal has not already flagged.
new_flags = np.logical_and(logl_flags, ~idr2_flags)
print('{} potentially bad slices found that are not flagged through the '\
      'hera_cal pipeline, through std considerations'.format(np.sum(new_flags)))

In [None]:
# (JD index, channel index, time integration) of every newly flagged slice.
bad_slices = np.nonzero(new_flags)
jd_pos, chan_pos, tint_pos = bad_slices
print(np.array(idr2_jdsx)[jd_pos])  # JDs
# NOTE(review): assumes array index 0 corresponds to channel 50 -- confirm.
print(chan_pos + 50)  # Channels
print(tint_pos)  # Time integrations

In [None]:
print(logls[bad_slices], '\n')  # NLL of each flagged slice
print(logl_mean[bad_slices[1:]], '\n')  # mean (not median) NLL across JDs for each flagged slice
print(logl_std[bad_slices[1:]]*1e3, '\n')  # NLL std across JDs for each flagged slice, x1e3

### Modified Z-scores & MAD

In [None]:
# 1.4826 ~= 1 / Phi^-1(3/4): scales the MAD into a consistent estimate of sigma
# for Gaussian data, so modz = (x - median) / (1.4826 * MAD).
correction = 1.4826

# Median and MAD of the NLL over the JDs for every (freq, time_int) slice.
# groupby().transform() broadcasts each group statistic back onto idr2_df's own
# rows, so nothing depends on JD being the innermost index level.
slice_levels = ['freq', 'time_int']
fun_vals = idr2_df['fun']
meds = fun_vals.groupby(level=slice_levels).transform('median')
dev_from_med = fun_vals - meds
abs_dev = dev_from_med.abs()
# Flat MAD per slice in sorted (freq, time_int) order, i.e. reshapeable to
# (Nfreqs, Ntimes) like the other grids.
mad = abs_dev.groupby(level=slice_levels).median().values
modz_series = dev_from_med / (correction * abs_dev.groupby(level=slice_levels).transform('median'))
# Stack as (JD, freq, time_int), with the JD axis in idr2_jdsx order so it lines
# up with `logls` and `idr2_flags`.
modz = np.stack([modz_series.xs(jd, level='JD').values.reshape(Nfreqs, Ntimes)
                 for jd in idr2_jdsx])

In [None]:
# Slices whose |modified Z-score| exceeds 11 and that hera_cal did not flag.
modz_outliers = (np.abs(modz) > 11) & ~idr2_flags
bad_slicesz = np.nonzero(modz_outliers)
print('{} potentially bad slices found that are not flagged through the '
      'hera_cal pipeline, through z-score considerations'.format(len(bad_slicesz[0])))

In [None]:
print(modz[bad_slicesz], '\n')  # modified Z-score of each flagged slice
print(logls[bad_slicesz], '\n')  # NLL of each flagged slice
print(logl_mean[bad_slicesz[1:]], '\n')  # mean (not median) NLL across JDs for each flagged slice
print(logl_std[bad_slicesz[1:]]*1e3, '\n')  # NLL std across JDs for each flagged slice, x1e3
print(mad.reshape((Nfreqs, Ntimes))[bad_slicesz[1:]]*1e3)  # MAD for each flagged slice, x1e3

In [None]:
# Union of the slices flagged by the Z-score and the modified Z-score tests.
# Concatenating the two index tuples would list a slice caught by both tests
# twice. Marking both in one boolean mask de-duplicates them, and np.where then
# returns the slices in C order: sorted by JD index, then channel, then time
# integration (deterministic, unlike an unstable argsort on the JD alone).
combined_flags = np.zeros(new_flags.shape, dtype=bool)
combined_flags[bad_slices] = True
combined_flags[bad_slicesz] = True
bad_slices_t = np.where(combined_flags)

In [None]:
# Summary table with one column per flagged slice. The rows are computed first,
# then copied into a float array.
sel_jd, sel_chan, sel_tint = bad_slices_t
sel_ft = (sel_chan, sel_tint)
nll_sel = logls[bad_slices_t]
table_rows = [
    np.array(idr2_jdsx)[sel_jd],                          # JDs
    sel_chan + 50,                                        # Channels
    sel_tint,                                             # Time integrations
    nll_sel,                                              # NLLs
    logl_med[sel_ft],                                     # med NLLs
    (nll_sel - logl_mean[sel_ft]) / logl_std[sel_ft],     # Z-score
    modz[bad_slices_t],                                   # Modified Z-score
    mad.reshape((Nfreqs, Ntimes))[sel_ft] * 1e3,          # MAD, x1e3
]

nice_print = np.empty((len(table_rows), sel_jd.size))
for row_idx, row_vals in enumerate(table_rows):
    nice_print[row_idx, :] = row_vals

In [None]:
# Print the flagged slices as rows of a LaTeX tabular. Columns: JD, channel,
# time integration, NLL, median NLL across JDs, Z-score, modified Z-score and
# MAD (nice_print stores the MAD multiplied by 1e3).
pp = nice_print.transpose()
# The header is a proper tabular row: it ends with the `\\` row terminator (so
# it does not merge with the first data row), and it labels the modified
# Z-score separately from the plain Z-score.
print('JD & Channel & Time & NLL & med NLL & Z-score & Mod. Z-score & MAD \\\\')
for p in pp:
    print('{} & {} & {} & {:.4f} & {:.4f} & {:.4f} & {:.4f} & {:.4f} \\\\'.format(
        int(p[0]), int(p[1]), int(p[2]), *p[3:]))