In [None]:
import glob
import os
import yaml
from collections import OrderedDict as odict

import numpy as np
import pandas as pd
import pickle
import seaborn as sns
from matplotlib import pyplot as plt

from hera_cal.io import HERAData

from simpleredcal.align_utils import idr2_jdsx
from simpleredcal.fit_diagnostics import abs_residuals, norm_residuals
from simpleredcal.plot_utils import clipped_heatmap, df_heatmap, flagged_hist, \
plot_res_grouped, plot_res_heatmap
from simpleredcal.red_likelihood import condenseMap, group_data, makeCArray, \
relabelAnts, split_rel_results
from simpleredcal.red_utils import DATAPATH, find_nearest, find_zen_file, \
JD2LSTPATH, RESPATH
from simpleredcal.xd_utils import union_bad_ants, XDgroup_data

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import matplotlib as mpl

# Switch for high-resolution figures: the figure DPI is only raised when set.
plot_figs = False
if plot_figs:
    mpl.rcParams['figure.dpi'] = 300

# Render figure text with LaTeX using a serif ('cm') font; the preamble loads
# the amssymb and amsmath packages for use in labels.
mpl.rc('font', family='serif', serif=['cm'])
mpl.rc('text', usetex=True)
mpl.rc('text.latex', preamble=r'\usepackage{amssymb} \usepackage{amsmath}')

In [None]:
# Run parameters: together they select which results files are read below.
jd_time = 2458098.43869 # used to find LST that labels dataframe
pol = 'ee' # polarization label used in the results file names and flag keys
ndist = 'cauchy' # distribution label used in the results dataframe file name

# Directory (under RESPATH) holding the results dataframe and its metadata
dir_path = os.path.join(RESPATH, 'xd_rel_dfs_nn')

In [None]:
# Pickled table read from JD2LSTPATH; it is queried below through its
# 'JD_time' and 'LASTs' columns.
lst_df = pd.read_pickle(JD2LSTPATH)

In [None]:
# LASTs stored for the dataset at jd_time; the first and last entries are kept
# as the start and stop LASTs of this analysis.
lasts = lst_df[lst_df['JD_time'] == jd_time]['LASTs'].values[0]
lst_ref = lasts[0]
lst_stop = lasts[-1]

# Results dataframe file name is labelled by the start LAST, pol and distribution
xd_df_fn = 'xd_rel_df.{:.4f}.{}.{}.pkl'.format(lst_ref, pol, ndist)
xd_df_path = os.path.join(dir_path, xd_df_fn)

In [None]:
# Metadata pickle stored next to the results dataframe (same LAST and pol label)
md_path = os.path.join(dir_path, 'xd_rel_df.{:.4f}.{}.md.pkl'.format(lst_ref, pol))
with open(md_path, 'rb') as f:
    md = pickle.load(f)

xd_df = pd.read_pickle(xd_df_path)

# Distinct channel and time-integration labels found in the dataframe index
freq_level = xd_df.index.get_level_values(level='freq')
tint_level = xd_df.index.get_level_values(level='time_int')
chans = freq_level.unique().values
tints = tint_level.unique().values

Nfreqs, Ntints = chans.size, tints.size

# Show a few random rows, ordered by index
xd_df.sample(5).sort_index()

## Performance

### Number of iterations

In [None]:
# Plot the 'nit' column of xd_df with the project's grouped-results helper,
# using a logarithmic y-axis.
plot_res_grouped(xd_df, 'nit', logy=True)

In [None]:
# Heatmap of the 'nit' column, with clipping enabled in the helper.
plot_res_heatmap(xd_df, 'nit', clip=True)

### Log-likelihood

In [None]:
# Plot the 'fun' column (the value reported by the minimizer) on a log y-axis.
plot_res_grouped(xd_df, 'fun', logy=True, figsize=(10, 7))

In [None]:
# Heatmap of the 'fun' column; clip_pctile=98 is passed to the helper's clipping.
plot_res_heatmap(xd_df, 'fun', clip=True, clip_pctile=98, figsize=(8, 6))

### Residuals

In [None]:
# Apply abs_residuals to each row's 'norm_residual' entry; its two outputs are
# stored as the real and imaginary residual columns.
med_abs_res = xd_df.apply(
    lambda row: pd.Series(abs_residuals(row['norm_residual'])), axis=1)
xd_df[['med_abs_norm_res_Re', 'med_abs_norm_res_Im']] = med_abs_res

# Combine the two components in quadrature
res_re_sq = xd_df['med_abs_norm_res_Re']**2
res_im_sq = xd_df['med_abs_norm_res_Im']**2
xd_df['med_abs_norm_res_comb'] = np.sqrt(res_re_sq + res_im_sq)

In [None]:
# Heatmap of the combined residual column with a fixed colour range.
plot_res_heatmap(xd_df, 'med_abs_norm_res_comb', vmin=0.16, vmax=0.22, \
                 figsize=(8, 6))

## Gains at sample frequency and time slice

In [None]:
# Check results for a given frequency & time integration
test_freq = 600
test_tint = 53

# Select the solved parameters by label (digit-named columns) rather than with
# a positional [5:-5] slice: the number of trailing non-parameter columns
# changes whenever diagnostic columns are added to xd_df in place, which would
# silently shift a positional slice.
param_cols = [col for col in xd_df.columns.values if col.isdigit()]
resx = xd_df.loc[(test_freq, test_tint)][param_cols].values.astype(float)

# Split the flat parameter vector into visibilities and gains, then arrange the
# gains with one row per JD
test_vis, test_gains = split_rel_results(resx, md['no_unq_bls'], coords='cartesian')
test_gains = test_gains.reshape((md['JDs'].size, -1))

print('Mean gain amplitude across JDs for test frequency {} and time integration {}: '\
      '\n{}\n'.format(test_freq, test_tint, np.mean(np.abs(test_gains), axis=0)))
print('Mean gain phase across JDs for test frequency {} and time integration {}: '\
      '\n{}\n'.format(test_freq, test_tint, np.mean(np.angle(test_gains), axis=0)))

In [None]:
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 3.5), sharey=True)
no_jds = md['JDs'].size
ant_ticks = np.arange(md['no_ants'])[::5]  # label every 5th antenna

# Left panel: gain amplitudes; right panel: gain phases
sns.heatmap(np.abs(test_gains), cmap=sns.cm.rocket_r, center=1, ax=ax1)
sns.heatmap(np.angle(test_gains), cmap='bwr', center=0, ax=ax2, vmin=-np.pi, vmax=np.pi)
for ax in (ax1, ax2):
    ax.set_xlabel('Antenna number')

# Major y ticks sit at the cell centres and carry the JD labels; minor ticks
# sit on the cell edges and are only used for the grid lines drawn below
ax1.set_yticks(np.arange(no_jds)+0.5)
ax1.set_yticklabels(md['JDs'], rotation=0)
ax2.set_yticks(np.arange(no_jds), minor=True)
for ax in (ax1, ax2):
    ax.tick_params(axis='y', which='minor', color='white')

for ax in (ax1, ax2):
    ax.set_xticks(ant_ticks+0.5, minor=False)
    ax.set_xticklabels(ant_ticks)

for ax in (ax1, ax2):
    ax.grid(which='minor', axis='y', linestyle='--', lw=0.5)

plt.tight_layout()
plt.show()

## Visibilities

In [None]:
no_unq_bls = md['no_unq_bls']
# number of leading df columns that are attributes of the SciPy OptimizeResult
no_min_p = 5
# The visibility parameters follow directly: two columns per unique baseline
vis_stop = no_min_p + no_unq_bls*2
vis_df = xd_df.iloc[:, no_min_p:vis_stop]

# Convert each row with makeCArray and expand the results into one column per
# element, keeping the original index
visC_rows = vis_df.apply(lambda row: makeCArray(row.values), axis=1)
visC_df = pd.DataFrame(visC_rows.values.tolist(), index=visC_rows.index)

### Visibilities at test time integration

In [None]:
# Visibility amplitudes at the test time integration (test_tint, set alongside
# test_freq earlier), transposed so that the columns are channels.
df = visC_df.xs(test_tint, level='time_int').abs().transpose()
# Colour range limited to the 2nd-98th percentile of the (non-NaN) amplitudes
vmax = np.nanpercentile(df.values, 98)
vmin = np.nanpercentile(df.values, 2)
df_heatmap(df, xbase=25, ybase=5, \
           xlabel='Channel', ylabel='Redundant Baseline Group', \
           vmin=vmin, vmax=vmax, figsize=(8, 6))

In [None]:
# Visibility phases at the test time integration (test_tint, set alongside
# test_freq earlier), transposed so that the columns are channels.
df = visC_df.xs(test_tint, level='time_int').applymap(np.angle).transpose()
# Diverging colour map centred on zero over the full [-pi, pi] phase range
df_heatmap(df, xbase=25, ybase=5, cmap='bwr', vmin=-np.pi, vmax=np.pi, center=0, \
           xlabel='Channel', ylabel='Redundant Baseline Group', figsize=(8, 6))

## Final flags from the HERA pipeline

Final flags are the individual final calibration flags + the manual flags applied by Nick Kern + the MAD-clipping flags from LST-binning

#### Get flags from .smooth_abs.calfits files

In [None]:
# Load the saved flag array (stored under 'arr_0' in the .npz archive) as booleans.
idr2_flags = np.load(f'{DATAPATH}/idr2_flags.npz')['arr_0'].astype(bool)
# Keep only the channels present in xd_df along the second axis.
# NOTE(review): the `- 50` offset presumably means the saved array starts at
# channel 50 - confirm against the script that wrote idr2_flags.npz.
idr2_flags = idr2_flags[:, chans - 50, :] # Channels 500-700, which include Band 2

#### Additional flags by Nick Kern

In [None]:
# Working locally unless the H1C_IDR2 project directory is available
local_work = not os.path.exists('/lustre/aoc/projects/hera/H1C_IDR2/')
if local_work:
    nkern_flg_dir = '/Users/matyasmolnar/Downloads/HERA_Data/sample_data/'
else:
    nkern_flg_dir = '/lustre/aoc/projects/hera/H1C_IDR2/IDR2_2_pspec/v2/one_group/'

nkern_flg_file = os.path.join(nkern_flg_dir, 'preprocess_params.yaml')

# Parse the preprocessing parameters with the safe YAML loader
with open(nkern_flg_file, 'r') as stream:
    data_loaded = yaml.safe_load(stream)

In [None]:
# Each 'flag_chans' entry is an inclusive [first, last] channel range; expand
# them all into one flat array of manually flagged channels
flag_ranges = data_loaded['algorithm']['fg_filt']['flag_chans']
man_flags = np.concatenate([np.arange(rng[0], rng[1]+1) for rng in flag_ranges]).ravel()

# Keep the flagged channels that lie within the analysed band and express them
# relative to its first channel
in_band = np.logical_and(man_flags >= chans[0], man_flags <= chans[-1])
rel_nflags = man_flags[in_band] - chans[0]
idr2_flags[:, rel_nflags, :] = True

##### Last individual dataset flags

In [None]:
# Number of flagged entries along the first axis of idr2_flags, per channel and
# time integration. The colour-bar label is a raw string so that the LaTeX
# escape `\#` is not parsed as an (invalid) Python escape sequence.
fig, ax = clipped_heatmap(idr2_flags.sum(axis=0).transpose(), 'Time Integration', 'Channel', \
                          vmin=0, clip_pctile=100, figsize=(8, 5), xoffset=-chans[0], \
                          cbar_lab=r'\# Flagged Days')
fig.tight_layout()
plt.show()

#### MAD-clipping flags from LST-Binning

In [None]:
mad_clip_dir = '/lustre/aoc/projects/hera/mmolnar/LST_bin/binned_files/lstb_mad_flags'
mad_flag_pattern = os.path.join(mad_clip_dir, 'zen.grp1.of1.LST.*.bad_jds.pkl')
mad_flag_files = sorted(glob.glob(mad_flag_pattern))

# The 5th and 6th dot-separated fields of each file name form its LST label
mad_flag_lsts = np.array([
    '.'.join(os.path.basename(fn).split('.')[4:6]) for fn in mad_flag_files
])

In [None]:
# Indices of the MAD-flag files whose LST labels are nearest to the start and
# stop LASTs, searched with find_nearest's 'leq' condition
mad_flag_lsts_f = mad_flag_lsts.astype(float)
clip_f_idx1 = find_nearest(mad_flag_lsts_f, lst_ref, condition='leq')[1]
clip_f_idx2 = find_nearest(mad_flag_lsts_f, lst_stop, condition='leq')[1]

In [None]:
# Load the two selected MAD-clipping flag pickles, in order
loaded_clip_flags = []
for clip_idx in (clip_f_idx1, clip_f_idx2):
    with open(mad_flag_files[clip_idx], 'rb') as f:
        loaded_clip_flags.append(pickle.load(f))

clip_flags1, clip_flags2 = loaded_clip_flags

In [None]:
bad_ants_idr2 = union_bad_ants(idr2_jdsx)

# Keys are (ant1, ant2, pol) tuples. Filter on the configured `pol` instead of a
# hard-coded 'ee', so that the keys kept here are the ones that are later
# looked up with `pol`.
clip_flags1 = {k: v for k, v in clip_flags1.items() if k[0] != k[1] and k[2] == pol} # flt autos and pol
clip_flags1 = {k: v for k, v in clip_flags1.items() if not any(i in bad_ants_idr2 for i in k[:2])} # flt bad ants

clip_flags2 = {k: v for k, v in clip_flags2.items() if k[0] != k[1] and k[2] == pol} # flt autos and pol
clip_flags2 = {k: v for k, v in clip_flags2.items() if not any(i in bad_ants_idr2 for i in k[:2])} # flt bad ants

In [None]:
lst_binned_dir = (
    '/Users/matyasmolnar/Downloads/HERA_Data/sample_data/' if local_work
    else '/lustre/aoc/projects/hera/H1C_IDR2/IDR2_2/LSTBIN/one_group/grp1'
)

# LST-binned files carrying the same LST labels as the selected MAD-flag files
lstb_fn = 'zen.grp1.of1.LST.{}.HH.OCRSL.uvh5'
lst_binned_file1 = os.path.join(lst_binned_dir, lstb_fn.format(mad_flag_lsts[clip_f_idx1]))
lst_binned_file2 = os.path.join(lst_binned_dir, lstb_fn.format(mad_flag_lsts[clip_f_idx2]))

hd_lstb1 = HERAData(lst_binned_file1)
hd_lstb2 = HERAData(lst_binned_file2)
# Dataset located by find_zen_file for the reference jd_time
hd = HERAData(find_zen_file(jd_time))

In [None]:
# Bin edges at twice the cadence of each LST-binned file: the file's LSTs plus
# the same LSTs shifted by half of the median LST step, sorted together.
bin_lsts1 = np.sort(np.append(hd_lstb1.lsts, hd_lstb1.lsts + np.median(np.ediff1d(hd_lstb1.lsts))/2))
bin_lsts2 = np.sort(np.append(hd_lstb2.lsts, hd_lstb2.lsts + np.median(np.ediff1d(hd_lstb2.lsts))/2))

# Bin index of each LST of the reference dataset within the first file's edges.
bin_idx1 = np.digitize(hd.lsts, bin_lsts1, right=False)
# Keep only the indices below hd.Ntimes*2.
# NOTE(review): np.digitize returns len(bins) for values beyond the last edge,
# and bin_lsts1 has 2*hd_lstb1.lsts.size entries; this cut-off only removes the
# overflow entries if hd.Ntimes equals the number of LSTs in hd_lstb1 - confirm.
bin_idx1 = bin_idx1[bin_idx1 < hd.Ntimes*2]

# Same for the second file, keeping the entries from the last zero index onwards.
# NOTE(review): np.digitize assigns 0 to values below the first edge, so this
# slice still includes one entry with index 0 - confirm that this is intended.
bin_idx2 = np.digitize(hd.lsts, bin_lsts2, right=False)
bin_idx2 = bin_idx2[np.where(bin_idx2 == 0)[0][-1]:]

In [None]:
# Relabel the bin indices of each file through condenseMap; the labels of the
# second file are offset by the number of entries kept from the first file.
# NOTE(review): presumably condenseMap maps the distinct indices to consecutive
# integers starting at 0 - confirm against simpleredcal.red_likelihood.
relab_dict1 = condenseMap(bin_idx1)
relab_dict2 = {k: v+bin_idx1.size for k, v in condenseMap(bin_idx2).items()}

# Nested ordered dict: baseline key -> relabelled index -> flags.
# Insertion order matters, since the values of each inner dict are stacked in
# order when the flags are converted to an array.
# Note that clip_flags2[bl] raises KeyError if a baseline of the first file is
# missing from the second.
mad_flags_dict = odict()
for bl in clip_flags1.keys():
    mad_flags_dict[bl] = odict()
    # Iterate over 1st MAD-clipped dataset
    # Entry t of the file is matched to bin indices 2*t and 2*t+1, which take
    # the even- and odd-positioned entries of v respectively.
    # NOTE(review): assumes v interleaves the two half-bins along its first
    # axis - confirm against the code that wrote the bad_jds pickles.
    for t, v in clip_flags1[bl].items():
        if 2*t in bin_idx1:
            mad_flags_dict[bl][relab_dict1[2*t]] = v[::2]
        if 2*t+1 in bin_idx1:
            mad_flags_dict[bl][relab_dict1[2*t+1]] = v[1::2]
    # Iterate over 2nd MAD-clipped dataset
    for t, v in clip_flags2[bl].items():
        if 2*t in bin_idx2:
            mad_flags_dict[bl][relab_dict2[2*t]] = v[::2]
        if 2*t+1 in bin_idx2:
            mad_flags_dict[bl][relab_dict2[2*t+1]] = v[1::2]

In [None]:
# Turn flags from MAD-clipping to ndarray
RedG = md['redg']
# Target layout: (day, freq, time, baseline); every baseline slot is filled below
mad_flags = np.empty((len(idr2_jdsx), hd.Nfreqs, hd.Ntimes, RedG.shape[0]), dtype=bool)

for i, bl_row in enumerate(RedG):
    # Look up the flags of *this* baseline: the entries of the row after the
    # first, followed by the polarization, form the dictionary key. (Indexing
    # RedG[0] here would copy the first baseline's flags to every baseline.)
    bl_flags = np.array(list(mad_flags_dict[(*bl_row[1:], pol)].values()))
    # Reorder the stacked array so that its axes 1, 2, 0 become axes 0, 1, 2,
    # matching the leading three axes of mad_flags
    mad_flags[:, :, :, i] = np.moveaxis(bl_flags, [1, 2, 0], [0, 1, 2])
mad_flags = mad_flags[:, chans, ...] # Channels 500-700, which include Band 2

In [None]:
# Union of the MAD-clipping flags and the pipeline/manual flags; the latter
# have no baseline axis, so they are repeated along a new trailing axis
idr2_flags_bl = np.tile(idr2_flags[..., np.newaxis], RedG.shape[0])
tot_flags = mad_flags + idr2_flags_bl

In [None]:
# The pipeline and manual flags have no baseline axis, so each flagged
# day/channel/time entry counts once per baseline. Reporting tot_flags.sum()
# under this label would also include the MAD-clipping flags.
n_pipeline_flags = idr2_flags.sum()*RedG.shape[0]

print(('For HERA data over JDs {}, channels {}-{}, LASTs {:.3f}-{:.3f}, '
       'excluding all those with bad antennas, there are {:,} visibilities, '
       'of which {:,} are flagged from the calibration pipeline and manual flagging '
       'and {:,} are flagged through MAD-clipping.')
      .format(idr2_jdsx, chans[0], chans[-1], lst_ref, lst_stop, tot_flags.size,
              n_pipeline_flags, mad_flags.sum()))

In [None]:
# Display the shape of the combined flag array (its last axis is the baseline
# axis added when the two flag sets were combined).
tot_flags.shape

In [None]:
# With MAD-clipping
# A day/channel/time entry counts as flagged only if all baselines are flagged
tot_flags_d = np.all(tot_flags, axis=3)
# Colour bar shows the number of flagged days; the label matches the one used
# for the pipeline-flags heatmap (raw string, `\#` is the LaTeX escape for '#')
fig, ax = clipped_heatmap(tot_flags_d.sum(axis=0).transpose(), 'Time Integration', 'Channel', \
                          vmin=0, clip_pctile=100, figsize=(8, 5), xoffset=-chans[0], \
                          cbar_lab=r'\# Flagged Days')
fig.tight_layout()
plt.show()

## Finding additional flags through xd_rel_cal

### Negative log-likelihood histograms

We look at the minimum negative log-likelihood from across-days redundant calibration, $-\ln(\mathcal{L}^C_\mathrm{xd\_rel})$, solved using **xd_rel_cal**.

In [None]:
nll_upper_cut = 2
xd_rel_values = xd_df['fun'].values.astype(float)

# A (channel, time integration) slice counts as flagged if more than flg_pct of
# the days are flagged
flg_pct = 50 / 100
flgs_all = tot_flags_d.mean(axis=0) > flg_pct # if 50% of days flagged
# The flag array is laid out as (freq, time) and the rows of xd_df are indexed
# by (freq, time_int), so flatten in C order (freq-major) to keep the flags
# aligned with xd_rel_values; Fortran order would pair each NLL value with the
# flag of a different slice.
flgs_all = flgs_all.ravel(order='C')

flagged_hist(xd_rel_values, flgs_all, \
             xlabel=r'$-\ln(\mathcal{L}^C_\mathrm{xd\_rel})$', \
             lower_cut=0.1, upper_cut=nll_upper_cut, bin_width=0.02, hist_start=0, ylim=(0, 1500))

In [None]:
# Flat positions of the slices whose 'fun' value exceeds 10 while not being
# marked in flgs_all, followed by their (freq, time_int) index labels.
sus_slices = np.where((xd_rel_values > 10) & ~flgs_all)[0]
xd_df.index[sus_slices]

In [None]:
# 'fun' values of the suspicious slices selected above.
xd_rel_values[sus_slices]

This method can broadly tell us which slices (over days and baselines) have corrupted data, but forfeits the granularity of being able to flag specific day/channel/time/baseline slices, since it only explores data along the channel/time dimensions.

### Calculating the NLLs had the minimization been done with a Gaussian distribution

In [None]:
# Evaluate the NLL under the distribution that was *not* used for the
# minimization. An unexpected `ndist` raises immediately instead of leaving
# `nll_dist` undefined (or stale from an earlier run of the kernel).
other_dist = {'gaussian': 'cauchy', 'cauchy': 'gaussian'}
if ndist not in other_dist:
    raise ValueError("ndist must be 'gaussian' or 'cauchy', got {!r}".format(ndist))
nll_dist = other_dist[ndist]

In [None]:
# Columns whose labels are digit strings; the code below treats the first
# no_unq_bls*2 of them as visibility parameters and the rest as gain parameters.
res_cols = [col for col in xd_df.columns.values if col.isdigit()]

# Retrieve solved gains in array format
# Rows of xd_df are unfolded as (freq, time_int) and the gain columns as (JD, ...)
xd_gains = xd_df[res_cols[no_unq_bls*2:]].values.reshape((Nfreqs, Ntints, md['JDs'].size, -1))
# Bring the JD axis to the front: (JD, freq, time_int, ...)
xd_gains = np.moveaxis(xd_gains, [2, 0, 1, 3], [0, 1, 2, 3])
# Split the last axis per antenna into pairs, combined below as real + 1j*imag
y = xd_gains.reshape(xd_gains.shape[:3] + (md['no_ants'], -1, 2))
xd_gains = np.squeeze(y[..., 0] + 1j*y[..., 1])

# Retrieve solved visibilities in array format
# Consecutive column pairs are combined as real + 1j*imag
xd_vis = xd_df[res_cols[:no_unq_bls*2]].values.reshape((Nfreqs, Ntints, -1, 2))
xd_vis = xd_vis[..., 0] + 1j*xd_vis[..., 1]
# Repeat the single set of solved visibilities along a new leading JD axis
xd_vis = np.tile(np.expand_dims(xd_vis, axis=0), (md['JDs'].size, 1, 1, 1))

In [None]:
# When every time integration is present, pass tints=None to XDgroup_data.
# NOTE(review): presumably tints=None means "all integrations" - confirm
# against simpleredcal.xd_utils.XDgroup_data.
if Ntints == md['Ntimes']:
    tints = None
# Use the full IDR2 JD list when the metadata JDs match it, otherwise the
# JDs recorded in the metadata
if (md['JDs'] == idr2_jdsx).all():
    jds = idr2_jdsx
else:
    jds = md['JDs']

# Prefer the locally cached arrays when they exist; they are only valid for
# the single JD they were generated for (local_jd)
if os.path.exists(f'{DATAPATH}/test_idr2_cdata.npz'):
    local_work = True
    local_jd = 2458098.43869
    if jd_time == local_jd:
        # retrieve data locally
        cdata = np.load(f'{DATAPATH}/test_idr2_cdata.npz')['arr_0']
        cndata = np.load(f'{DATAPATH}/test_idr2_cndata.npz')['arr_0']
    else:
        raise Exception('Only H1C_IDR2 visibility data across JDs aligned with {} '
                        'is available locally.'.format(local_jd))
else:
    # Otherwise build the grouped data (and noise) from the datasets themselves
    _, _, cdata, cndata = XDgroup_data(jd_time, jds, pol, chans=chans, \
        tints=tints, bad_ants=True, use_flags='first', noise=True)
    # Keep only the underlying data array
    # NOTE(review): presumably cdata is a masked array here - confirm.
    cdata = cdata.data

# Relabelled copy of the baseline array, used below to index the solved
# visibility and gain arrays
cRedG = relabelAnts(RedG)

In [None]:
def _cauchy_nll(delta):
    """Sum of log(1 + |delta|^2) over the first and last axes of delta."""
    return np.log(1 + np.square(np.abs(delta))).sum(axis=(0, -1))


def _gaussian_nll(delta):
    """Sum of |delta|^2 over the first and last axes of delta."""
    return np.square(np.abs(delta)).sum(axis=(0, -1))


# Negative log-likelihood functions, keyed by distribution name
NLLFN = {'cauchy': _cauchy_nll, 'gaussian': _gaussian_nll}

In [None]:
# Model visibility per baseline: solved visibility (column 0 of cRedG) times
# the gain of the first antenna (column 1) times the conjugate gain of the
# second antenna (column 2)
vis_bl = xd_vis[..., cRedG[:, 0]]
gain_i = xd_gains[..., cRedG[:, 1]]
gain_j_conj = np.conj(xd_gains[..., cRedG[:, 2]])
gvis = vis_bl*gain_i*gain_j_conj

# Residual between the data and the model, evaluated under the chosen NLL
delta = cdata - gvis
nlog_likelihood = NLLFN[nll_dist](delta)

In [None]:
# Heatmap of the recomputed NLL, transposed so that time integrations run
# along the y-axis; the colour scale is clipped via clip_pctile=98.
fig, ax = clipped_heatmap(nlog_likelihood.transpose(), ylabel='Time integration', 
                          clip_pctile=98, figsize=(8, 6), clip_rnd=100, xoffset=-chans[0])
fig.tight_layout()
plt.show()

### Residuals between solved visibilities and gain transformed observed visibilities

With across days redundant calibration, we obtain a single set of visibilities. We wish to compare these solved visibilities to the observed visibilities on different days, to find potential outliers. The amplitudes of these visibilities could be compared; their phases, however, cannot, since there are degenerate offsets between them.

We do not wish to calculate these degenerate offsets, as this is computationally expensive: it would require pairwise comparisons between the solved xd_rel_cal solutions and each day. What we can do, however, is take the observed visibilities and divide them by the solved gains to get a quantity that is comparable, as it is degenerately consistent with the solved visibilities. The residual between this quantity and the solved visibilities is what we use for outlier detection.

In [None]:
# Divide the observed visibilities by the solved gains of both antennas of each
# baseline (columns 1 and 2 of cRedG)
gain_i = xd_gains[..., cRedG[:, 1]]
gain_j_conj = np.conj(xd_gains[..., cRedG[:, 2]])
tr_vis = cdata / gain_i / gain_j_conj

# Residual with respect to the solved visibilities (column 0 of cRedG)
tr_res = xd_vis[..., cRedG[:, 0]] - tr_vis

#### Modified Z-score

In [None]:
# Scale factor applied to the MAD (presumably the usual normal-consistency
# constant - confirm)
correction = 1.4826
abs_tr_res = np.abs(tr_res)
mad = np.median(abs_tr_res, axis=0) # Median Absolute Deviation
# Repeat the MAD along the JD axis so that it matches the residual array
mad_all_jds = np.tile(mad[np.newaxis, ...], (md['JDs'].size, 1, 1, 1))
modz = abs_tr_res/(correction*mad_all_jds) # Modified Z-score
# Note that these quantities are about the solved visibility values, and not
# about their medians

In [None]:
# Average the modified Z-score over the first and last axes, leaving one value
# per channel and time integration, then plot it with the colour scale clipped
# via clip_pctile=99.
mean_modz = np.mean(modz, axis=(0, -1)) # mean over days and baselines
fig, ax = clipped_heatmap(mean_modz.transpose(), 'Time Integration', 'Channel', \
                          clip_pctile=99, figsize=(8, 5), xoffset=-chans[0])
plt.tight_layout()
plt.show()

In [None]:
# Mean modified Z-score, if looking at the mean mod-Z across baselines
# Select the day/chan/time entries whose baseline-averaged score exceeds 0.8
# and that are not already flagged in tot_flags_d; np.where returns one index
# array per axis.
bad_slicesz = np.where(np.logical_and(modz.mean(axis=-1) > 0.8, ~tot_flags_d))
print('{} potentially bad day/chan/time slices found that are not flagged through the '\
      'hera_cal pipeline, through modified Z-score considerations'.format(bad_slicesz[0].size))

In [None]:
# Look at individual baselines
# Select the day/chan/time/baseline entries whose modified Z-score exceeds 5
# and that are not already flagged in tot_flags.
bad_slicesz_bl = np.where(np.logical_and(modz > 5, ~tot_flags))
print('{} potentially bad day/chan/time/baseline slices found that are not flagged through the '\
      'hera_cal pipeline, through modified Z-score considerations'.format(bad_slicesz_bl[0].size))

$\mathcal{R}_{\mathrm{man}}$ already calculated above, but run on last baseline dimension too - should be similar results to modified Z-score

In [None]:
# Normalized residuals between the solved visibilities (expanded to one entry
# per baseline) and the gain-divided observed visibilities, computed with the
# project's norm_residuals helper.
nrm_resid = norm_residuals(xd_vis[..., cRedG[:, 0]], tr_vis)
abs_resid = np.median(np.abs(nrm_resid), axis=-1) # median over the baseline axis

In [None]:
# Average the per-day residuals and plot them; the lower colour limit is the
# 1st percentile of the (non-NaN) values and the upper end is clipped via
# clip_pctile=98.
mean_abs_resid = np.mean(abs_resid, axis=0) # mean over days
vmin = np.nanpercentile(mean_abs_resid, 1)
fig, ax = clipped_heatmap(mean_abs_resid.transpose(), 'Time Integration', 'Channel', \
                          clip_pctile=98, vmin=vmin, figsize=(8, 5), xoffset=-chans[0])
fig.tight_layout()
plt.show()

In [None]:
# TODO
# Histograms of NLL/Noise, R_man to find outliers