In [None]:
import glob
import os
import sys
import yaml
from collections import OrderedDict as odict

import numpy as np
import pandas as pd
import pickle
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib import ticker

from hera_cal.io import HERAData
from hera_cal.utils import LST2JD

sys.path.insert(0, os.path.dirname(os.getcwd()))
from align_utils import idr2_jdsx
from plot_utils import clipped_heatmap
from red_likelihood import relabelAnts
from red_utils import calfits_to_flags, find_nearest, find_zen_file, match_lst
from xd_utils import union_bad_ants, XDgroup_data

In [None]:
# Analysis configuration
pol = 'ee'  # visibility polarization processed throughout
ndist = 'cauchy'  # distribution label in the xd_rel_cal result filenames (presumably the solver's noise model)
jd_time = 2458098.43869  # reference JD (dataset label of the reference file)
xd_dir_path = '../xd_rel_dfs_nn'  # directory holding the xd_rel_cal result pickles

In [None]:
# HERAData handle on the reference-JD file; its freqs, Nfreqs and Ntimes are used below
ref_zen_fn = find_zen_file(jd_time)
hd = HERAData(ref_zen_fn)

### Getting channels, LASTs and datasets for Band 2 Field 2

In [None]:
# Band 2 Frequencies: compared directly against hd.freqs, so these bounds are in Hz
b2_freq_start = 150.3*1e6 # Hz (150.3 MHz)
b2_freq_stop = 167.8*1e6 # Hz (167.8 MHz)
band2_chans = np.where(np.logical_and(hd.freqs >= b2_freq_start, hd.freqs <= b2_freq_stop))[0]
assert band2_chans.size > 0, 'No channels found in the Band 2 frequency range'

# Field 2 LASTs
b2_lst_start = 4.5 # hours
b2_lst_stop = 6.5 # hours

# Convert hours to radians (24 h = 2*pi rad)
lst_start_rad = b2_lst_start * np.pi / 12
lst_end_rad = b2_lst_stop * np.pi / 12

# Match with dataset labels: nearest dataset JD at/before the start LAST ('leq') and
# at/after the end LAST ('geq') on the reference night
last_df = pd.read_pickle('../jd_lst_map_idr2.pkl')
jd_start_match = find_nearest(last_df['JD_time'].values, LST2JD(lst_start_rad, int(jd_time)),
                              condition='leq')[0]
jd_end_match = find_nearest(last_df['JD_time'].values, LST2JD(lst_end_rad, int(jd_time)),
                            condition='geq')[0]

# Field 2 Datasets: reference-night dataset JDs within [jd_start_match, jd_end_match]
ref_night_jds = last_df['JD_time'].loc[int(jd_time)].values
tocal = np.where(np.logical_and(ref_night_jds >= jd_start_match,
                                ref_night_jds <= jd_end_match))
field2_refs = ref_night_jds[tocal]
# Every downstream cell loops over / indexes field2_refs, so fail early if it is empty
assert field2_refs.size > 0, 'No reference datasets found in the Field 2 LAST range'

print('Band 2 channels are from {}-{} and Field 2 spans from LASTs {}-{}'
      .format(band2_chans[0], band2_chans[-1], b2_lst_start, b2_lst_stop))

In [None]:
# First LAST entry of the reference dataset; the xd_rel_cal outputs embed it in their filenames
lst_ref = last_df.loc[last_df['JD_time'] == jd_time, 'LASTs'].values[0][0]

# Metadata pickle written alongside the xd_rel_cal results (trusted, locally produced file)
md_path = os.path.join(xd_dir_path, 'xd_rel_df.{:.4f}.{}.md.pkl'.format(lst_ref, pol))
with open(md_path, 'rb') as f:
    md = pickle.load(f)

# Redundant-group array: rows are presumably (group index, ant1, ant2) — see its use below
RedG = md['redg']

In [None]:
jd_label_dict_fn = '../b2f2_jd_dict.npz'

# Getting datasets for H1C_IDR2 JDs that are in Field 2.
# For each Field 2 reference JD and each other IDR2 night, store the pair of dataset labels
# matched by match_lst with tint=0 and tint=-1. Labels are formatted as fixed 5-decimal
# strings (e.g. '2458099.43120'). '{:.5f}' restores any number of trailing zeros dropped by
# str(float) and trims float noise; appending a single '0' to short strings handles neither.
jd_label_dict = {}
for jd_ref in field2_refs:
    jd_label_dict[jd_ref] = [
        ['{:.5f}'.format(float(match_lst(jd_ref, jd, tint=tint))) for tint in (0, -1)]
        for jd in idr2_jdsx[1:]
    ]

# Written once; an existing file is left untouched
if not os.path.exists(jd_label_dict_fn):
    np.savez(jd_label_dict_fn, jd_dict=jd_label_dict)

### Building final flags array

The final flags are the union of the per-day final calibration flags, the manual channel flags applied by Nick Kern, and the MAD-clipping flags from LST-binning. Outliers found by modified Z-score clipping are added at the end.

#### Calibration flags

In [None]:
cal_flags_fn = '../b2f2_cal_flags.npz'

if os.path.exists(cal_flags_fn):
    cal_flags = np.load(cal_flags_fn)['flags']

else:
    cal_file = 'smooth_abs'
    cal_flags_per_ref = []
    for jd_ref in field2_refs:

        # Axes: (day, freq, time, baseline); day 0 is the reference dataset itself
        cal_flags_jd = np.zeros((len(idr2_jdsx), hd.Nfreqs, hd.Ntimes, RedG.shape[0]), dtype=bool)
        cal_flags_jd[0, ...] = calfits_to_flags(jd_ref, cal_file, pol=pol, add_bad_ants=None)

        # Local name, so the notebook-level lst_ref is not overwritten by this loop
        lst_ref_jd = last_df[last_df['JD_time'] == jd_ref]['LASTs'].values[0][0]

        for i, (JDa, JDb) in enumerate(jd_label_dict[jd_ref]):
            flagsa = calfits_to_flags(JDa, cal_file, pol=pol, add_bad_ants=None)
            flagsb = calfits_to_flags(JDb, cal_file, pol=pol, add_bad_ants=None)

            # Index in dataset JDa closest to the reference LAST; the tail of JDa and the
            # head of JDb are stitched along the time axis to line up with the reference
            last2 = last_df[last_df['JD_time'] == float(JDa)]['LASTs'].values[0]
            _, offset = find_nearest(last2, lst_ref_jd)

            cal_flags_jd[i+1, ...] = np.concatenate((flagsa[:, offset:], flagsb[:, :offset]), axis=1)

        cal_flags_per_ref.append(cal_flags_jd)

    # Join all reference datasets along the time axis in a single concatenation
    cal_flags = np.concatenate(cal_flags_per_ref, axis=2)
    np.savez_compressed(cal_flags_fn, flags=cal_flags, jds_refs=field2_refs)

#### Nick Kern's manual flags

In [None]:
# Pick the cluster copy of the preprocessing parameters if available, else a local copy
# TODO: the local fallback is a hard-coded personal path; consider making it configurable
if os.path.exists('/lustre/aoc/projects/hera/H1C_IDR2/'):
    nkern_flg_dir = '/lustre/aoc/projects/hera/H1C_IDR2/IDR2_2_pspec/v2/one_group/'
    local_work = False
else:
    nkern_flg_dir = '/Users/matyasmolnar/Downloads/HERA_Data/robust_cal'
    local_work = True

nkern_flg_file = os.path.join(nkern_flg_dir, 'preprocess_params.yaml')

# Read YAML file
with open(nkern_flg_file, 'r') as stream:
    data_loaded = yaml.safe_load(stream)

# flag_chans holds inclusive [start, stop] channel ranges; expand them to channel indices.
# An empty list yields no manual flags instead of crashing np.concatenate.
flag_chan_ranges = data_loaded['algorithm']['fg_filt']['flag_chans']
if flag_chan_ranges:
    man_flags = np.concatenate([np.arange(chan_range[0], chan_range[1] + 1)
                                for chan_range in flag_chan_ranges])
else:
    man_flags = np.array([], dtype=int)

# These are absolute channel indices, so they must be applied while cal_flags still spans
# all hd.Nfreqs channels. Re-running this cell after the Band 2 restriction would
# otherwise flag the wrong channels.
assert cal_flags.shape[1] == hd.Nfreqs, 'cal_flags has already been restricted to Band 2'
cal_flags[:, man_flags, :, :] = True

In [None]:
# Number of days on which every baseline is flagged, as a (time, channel) map
all_bls_flagged = cal_flags.all(axis=3)
n_flagged_days = all_bls_flagged.sum(axis=0).T

fig, ax = clipped_heatmap(n_flagged_days, 'Time Integration', 'Channel',
                          vmin=0, clip_pctile=100, figsize=(8, 5), xoffset=None,
                          ybase=60, cbar_lab='# Flagged Days')
plt.tight_layout()
plt.show()

In [None]:
# Restricting to Band 2. Guarded so that re-running this cell does not slice an
# already-restricted array a second time (which would select the wrong channels).
if cal_flags.shape[1] == hd.Nfreqs:
    cal_flags = cal_flags[:, band2_chans, :, :]
assert cal_flags.shape[1] == band2_chans.size, 'cal_flags frequency axis is not Band 2'

#### Getting MAD-clipping flags from LST-Binning

In [None]:
# LAST boundaries covered by Field 2: the first LAST of every reference dataset, plus the
# final LAST of the last reference dataset. The last dataset is named explicitly instead
# of relying on a loop variable leaking out of the loop.
last_span = [last_df[last_df['JD_time'] == jd_ref]['LASTs'].values[0][0]
             for jd_ref in field2_refs]
last_span.append(last_df[last_df['JD_time'] == field2_refs[-1]]['LASTs'].values[0][-1])

In [None]:
mad_flags_fn = '../b2f2_mad_flags.npz'

if os.path.exists(mad_flags_fn):
    mad_flags = np.load(mad_flags_fn)['flags']
else:
    mad_clip_dir = '/lustre/aoc/projects/hera/mmolnar/LST_bin/binned_files'
    mad_flag_files = sorted(glob.glob(os.path.join(mad_clip_dir, 'zen.grp1.of1.LST.*.bad_jds.pkl')))
    # LST label embedded in each filename: 'zen.grp1.of1.LST.<a>.<b>.bad_jds.pkl' -> '<a>.<b>'
    mad_flag_lsts = np.array(['.'.join(os.path.basename(fn).split('.')[4:6]) for fn in mad_flag_files])
    mad_flag_lsts_f = mad_flag_lsts.astype(float)

    # For each Field 2 LAST boundary, the file whose LST label is nearest at or below it
    to_open = sorted({mad_flag_files[find_nearest(mad_flag_lsts_f, last, condition='leq')[1]]
                      for last in last_span})

    bad_ants_idr2 = union_bad_ants(idr2_jdsx)
    lst_binned_dir = '/lustre/aoc/projects/hera/H1C_IDR2/IDR2_2/LSTBIN/one_group/grp1'

    mad_flags_parts = []
    for count, mad_clip_fn in enumerate(to_open):

        with open(mad_clip_fn, 'rb') as f:
            mad_flags_i = pickle.load(f)

        # Keep cross-correlations of the analysed polarization with no bad antennas
        mad_flags_i = {k: v for k, v in mad_flags_i.items()
                       if k[0] != k[1] and k[2] == pol
                       and not any(ant in bad_ants_idr2 for ant in k[:2])}

        # Split each stored bin t into two consecutive time slots 2t / 2t+1, taken from
        # the even / odd entries of v respectively
        mad_flags_dict_i = odict()
        for bl, bl_bins in mad_flags_i.items():
            mad_flags_dict_i[bl] = odict()
            for t, v in bl_bins.items():
                mad_flags_dict_i[bl][2*t] = v[::2]
                mad_flags_dict_i[bl][2*t+1] = v[1::2]

        # Turn flags from MAD-clipping to ndarray: (day, freq, time, baseline)
        mad_flags_arr_i = np.empty((len(idr2_jdsx), hd.Nfreqs, hd.Ntimes*2, RedG.shape[0]), dtype=bool)
        for i, bl_row in enumerate(RedG):
            # Key on this row's own antenna pair (bl_row[1:]), not RedG[0][1:], so each
            # baseline column receives its own flags
            bl_flags_arr = np.array(list(mad_flags_dict_i[(*bl_row[1:], pol)].values()))
            mad_flags_arr_i[:, :, :, i] = np.moveaxis(bl_flags_arr, [1, 2, 0], [0, 1, 2])
        mad_flags_arr_i = mad_flags_arr_i[:, band2_chans, ...]

        is_first = count == 0
        is_last = count == len(to_open) - 1
        if is_first or is_last:

            mad_flag_lst = '.'.join(os.path.basename(mad_clip_fn).split('.')[4:6])
            lst_binned_file = os.path.join(lst_binned_dir,
                                           'zen.grp1.of1.LST.{}.HH.OCRSL.uvh5'.format(mad_flag_lst))
            hd_lstb_i = HERAData(lst_binned_file)

            # LSTs of the split time slots: binned LSTs plus their half-step midpoints
            bin_lsts = np.sort(np.append(hd_lstb_i.lsts, hd_lstb_i.lsts +
                                         np.median(np.ediff1d(hd_lstb_i.lsts))/2))

            # Trim the end before the start so that, when one file is both first and last,
            # the start index (computed on the untrimmed grid) is still valid
            if is_last:
                # Slicing s.t. times do not go beyond those in field2_refs[-1]
                lst_end_i = HERAData(find_zen_file(field2_refs[-1])).lsts[-1]
                end_idx = find_nearest(bin_lsts, lst_end_i, condition=None)[1]
                mad_flags_arr_i = mad_flags_arr_i[:, :, :end_idx+1, :]
            if is_first:
                # Slicing s.t. times are aligned with those from field2_refs[0]
                lst_start_i = HERAData(find_zen_file(field2_refs[0])).lsts[0]
                start_idx = find_nearest(bin_lsts, lst_start_i, condition=None)[1]
                mad_flags_arr_i = mad_flags_arr_i[:, :, start_idx:, :]

        mad_flags_parts.append(mad_flags_arr_i)

    # Join all LST-binned files along the time axis in a single concatenation
    mad_flags = np.concatenate(mad_flags_parts, axis=2)
    np.savez_compressed(mad_flags_fn, flags=mad_flags)

In [None]:
# Union of the calibration (+ manual) flags and the LST-binning MAD-clipping flags.
# Check shapes first, because numpy broadcasting would otherwise silently accept
# misaligned arrays (e.g. a length-1 time axis).
assert cal_flags.shape == mad_flags.shape, \
    'cal_flags {} and mad_flags {} are misaligned'.format(cal_flags.shape, mad_flags.shape)
tot_flags = np.logical_or(cal_flags, mad_flags)

### Getting raw visibility data and xd_rel_cal solutions

In [None]:
# LASTs of every integration across the Field 2 reference datasets, in dataset order
lasts_b2 = np.array([last_df[last_df['JD_time'] == jd_ref]['LASTs'].values[0]
                     for jd_ref in field2_refs]).flatten()

In [None]:
vis_data_fn = '../b2f2_vis_data.npz'

if os.path.exists(vis_data_fn):
    cdata = np.load(vis_data_fn)['data']
else:
    # this will take a while... run on cluster
    cdata_parts = []
    for jd_ref in field2_refs:
        # Only the third returned item is kept; .data presumably strips a mask — confirm
        _, _, cdata_i, _ = XDgroup_data(jd_ref, idr2_jdsx, pol, chans=band2_chans,
                                        tints=None, bad_ants=True, use_flags='first', noise=True)
        cdata_parts.append(cdata_i.data)

    # Join all reference datasets along the time axis in a single concatenation
    cdata = np.concatenate(cdata_parts, axis=2)
    np.savez_compressed(vis_data_fn, data=cdata, jds_refs=field2_refs)

In [None]:
# Concatenate the per-reference xd_rel_cal result DataFrames. Each one's 'time_int' index
# level is offset so that time integrations run contiguously across datasets.
xd_df_parts = []
for count, jd_ref in enumerate(field2_refs):

    lst_ref = last_df[last_df['JD_time'] == jd_ref]['LASTs'].values[0][0]
    xd_df_path = os.path.join(xd_dir_path, 'xd_rel_df.{:.4f}.{}.{}.pkl'.format(lst_ref, pol, ndist))

    xd_df_i = pd.read_pickle(xd_df_path)

    if count:
        # Offset by the per-dataset time length (hd.Ntimes), matching the time axis used
        # when the per-dataset flag arrays are concatenated. rename(level=...) shifts the
        # labels in place without reordering the index levels.
        time_offset = hd.Ntimes * count
        xd_df_i = xd_df_i.rename(index=lambda t: t + time_offset, level='time_int')

    xd_df_parts.append(xd_df_i)

xd_df = pd.concat(xd_df_parts).sort_index()
xd_df = xd_df[xd_df.index.get_level_values(level='freq').isin(band2_chans)]

chans = xd_df.index.get_level_values(level='freq').unique().values
tints = xd_df.index.get_level_values(level='time_int').unique().values

Nfreqs = chans.size
Ntints = tints.size

# The array reshapes that follow assume exactly one row per (freq, time_int) pair
assert len(xd_df) == Nfreqs * Ntints, 'xd_df does not cover a full freq x time_int grid'

In [None]:
# Result columns have purely numeric names. The first 2*no_unq_bls hold the real/imag parts
# of the solved visibilities; the remaining ones hold the real/imag parts of the solved gains.
res_cols = [col for col in xd_df.columns.values if col.isdigit()]
n_vis_cols = md['no_unq_bls']*2
n_days = md['JDs'].size

# Retrieve solved gains in array format -> (day, freq, time_int, antenna)
xd_gains = xd_df[res_cols[n_vis_cols:]].values.reshape((Nfreqs, Ntints, n_days, -1))
xd_gains = np.moveaxis(xd_gains, 2, 0)  # move the day axis to the front
y = xd_gains.reshape(xd_gains.shape[:3] + (md['no_ants'], -1, 2))
# Drop only the trailing per-antenna singleton axis. A bare np.squeeze would also drop the
# day, freq or time axis whenever one of them has length 1.
xd_gains = np.squeeze(y[..., 0] + 1j*y[..., 1], axis=-1)

# Retrieve solved visibilities in array format -> (day, freq, time_int, unique baseline)
xd_vis = xd_df[res_cols[:n_vis_cols]].values.reshape((Nfreqs, Ntints, -1, 2))
xd_vis = xd_vis[..., 0] + 1j*xd_vis[..., 1]
# The solved visibilities are shared by all days: repeat them along a new leading day axis
xd_vis = np.tile(np.expand_dims(xd_vis, axis=0), (n_days, 1, 1, 1))

In [None]:
# Checking data consistency:
print(tot_flags.shape)
print(cdata.shape)
print((md['JDs'].size, Nfreqs, Ntints))
print(xd_gains.shape)
print(xd_vis.shape)

# Fail fast here rather than relying on broadcasting in the residual / Z-score cells below
expected_dft = (md['JDs'].size, Nfreqs, Ntints)  # (day, freq, time)
assert cdata.shape[:3] == expected_dft, 'cdata {} vs {}'.format(cdata.shape, expected_dft)
assert xd_gains.shape[:3] == expected_dft, 'xd_gains {} vs {}'.format(xd_gains.shape, expected_dft)
assert xd_vis.shape[:3] == expected_dft, 'xd_vis {} vs {}'.format(xd_vis.shape, expected_dft)
assert tot_flags.shape == cdata.shape, 'tot_flags {} vs cdata {}'.format(tot_flags.shape, cdata.shape)

In [None]:
# Transforming visibilities to comparable quantities, s.t. statistics can be done on them:
# divide out the solved gains, V / (g_a * conj(g_b)), and subtract the result from the
# solved visibility of each baseline's redundant group (cRedG[:, 0] indexes xd_vis)
cRedG = relabelAnts(RedG)  # presumably re-indexes antennas to match xd_gains — confirm
gains_a = xd_gains[..., cRedG[:, 1]]
gains_b_conj = np.conj(xd_gains[..., cRedG[:, 2]])
tr_vis = cdata / gains_a / gains_b_conj
tr_res = xd_vis[..., cRedG[:, 0]] - tr_vis

### Modified Z-score clipping

In [None]:
# 1.4826 (~1/0.6745) scales a median absolute deviation to a Gaussian-equivalent sigma
correction = 1.4826

# Median Absolute Deviation across days (axis 0). Note that these quantities are about the
# solved visibility values, and not about their medians: tr_res is already a residual from
# the solved visibility, so no median is subtracted first.
# NOTE(review): flagged samples are included in this median.
mad = np.median(np.abs(tr_res), axis=0)

# Modified Z-score; the MAD is repeated across a leading axis of md['JDs'].size days
mad_per_day = np.broadcast_to(mad, (md['JDs'].size,) + mad.shape)
modz = np.abs(tr_res) / (correction * mad_per_day)

In [None]:
# Look at individual baselines: samples with modified Z-score above 5 that are not
# already flagged in tot_flags
unflagged_outliers = (modz > 5) & np.logical_not(tot_flags)
bad_slicesz_bl = np.nonzero(unflagged_outliers)
n_bad_slices = bad_slicesz_bl[0].size
print('{} potentially bad day/chan/time/baseline slices found that are not flagged through the '
      'hera_cal pipeline, through modified Z-score considerations'.format(n_bad_slices))

In [None]:
modz_fn = '../b2f2_mz_flags.npz'
# Written once; an existing file is never overwritten.
# The 'simga*' key spelling (sic) is kept so that existing readers of this file still work.
if not os.path.exists(modz_fn):
    threshold_masks = {'simga4': modz > 4, 'simga5': modz > 5, 'simga6': modz > 6}
    np.savez_compressed(modz_fn, modz=modz, **threshold_masks)

In [None]:
# Add the modified Z-score > 5 outliers to the final flags (boolean union)
modz_flags = modz > 5
tot_flags = np.logical_or(tot_flags, modz_flags)

In [None]:
# Final flags to be carried through to LST-binning.
# Written once only: delete the file to regenerate it after the flags change.
tot_flags_fn = '../b2f2_tot_flags.npz'
already_saved = os.path.exists(tot_flags_fn)
if not already_saved:
    np.savez_compressed(tot_flags_fn, flags=tot_flags)