<center><strong><font size=+3>Outlier Detection of HERA Data with Robust Mahalanobis Distances</font></strong>
<br><br>
</center>
<center><strong><font size=+2>Matyas Molnar and Bojan Nikolic</font><br></strong></center>
<br><center><strong><font size=+1>Astrophysics Group, Cavendish Laboratory, University of Cambridge</font></strong></center>

In [None]:
import os

import matplotlib as mpl
import numpy as np
import seaborn as sns
from matplotlib import pyplot as plt
from scipy import special, stats
from sklearn.covariance import MinCovDet

from robstat.utils import DATAPATH, decomposeCArray, flt_nan

In [None]:
%matplotlib inline

In [None]:
# Switch between large, high-resolution figures (True) and small, quick previews (False)
plot_figs = True
mpl.rcParams.update(
    {'figure.figsize': (12, 8), 'figure.dpi': 300} if plot_figs
    else {'figure.figsize': (5, 3), 'figure.dpi': 125}
)

### Load HERA visibility data

In [None]:
# Path of the .npz visibility archive, located under robstat's DATAPATH
vis_file = os.path.join(DATAPATH, 'lstb_no_avg/idr2_lstb_14m_ee_1.40949.npz')
# Open the archive; the individual arrays are pulled out by key in the next cell
vis_data = np.load(vis_file)

In [None]:
# Unpack the archive: the visibility array plus its 'redg' and 'pol' metadata entries
data, redg = vis_data['data'], vis_data['redg']
pol = vis_data['pol'].item()  # stored as a 0-d array; unwrap to a Python scalar

# NaN entries are treated as already-flagged samples
flags = np.isnan(data)

# Axis 1 of `data` indexes frequency channels and axis 2 time integrations
no_chans, new_no_tints = data.shape[1:3]
chans = np.arange(no_chans)
# 1024 evenly spaced frequencies over [100 MHz, 200 MHz), upper edge excluded
freqs = np.linspace(1e8, 2e8, 1025)[:-1]

### Visualize outlier detection

In [None]:
# Example data: pool two adjacent frequency channels (650 and 651), taken at index 0 of the
# last two axes, into a single sample
eg_data = data[:, 650:652, 0, 0]
# Flatten, filter through flt_nan (presumably drops NaN entries - confirm in robstat.utils) and
# split with decomposeCArray; below, `points` is used as an (n_samples, n_features) array whose
# column 0 is plotted as Re(V) and column 1 as Im(V)
points = decomposeCArray(flt_nan(eg_data.flatten()))

In [None]:
# Fit the Minimum Covariance Determinant (MCD) estimator to the example points. The MCD search
# starts from random subsets, so seed it (same seed as the parallel run in mp_iter) to make the
# inlier/outlier split reproducible across kernel restarts.
robust_cov = MinCovDet(random_state=0).fit(points)

# MinCovDet.mahalanobis returns *squared* distances, which are compared against a chi^2 quantile
# with as many degrees of freedom as there are feature columns in `points`
chi2_quantile = 0.975
chi2_thresh = stats.chi2.ppf(chi2_quantile, points.shape[1])

# Row indices of the points whose squared robust distance exceeds the threshold
outliers = np.where(robust_cov.mahalanobis(points) > chi2_thresh)[0]

In [None]:
# Map the chi^2 threshold to a normal-distribution probability using Fisher's approximation:
# for X ~ chi^2(k), sqrt(2X) - sqrt(2k - 1) is approximately N(0, 1). Here k = points.shape[1];
# the approximation is coarse for such a small k, so treat the displayed value as indicative only.
stats.norm.cdf(np.sqrt(2*chi2_thresh) - np.sqrt(2*points.shape[1] - 1))

In [None]:
# Axis labels for the real and imaginary parts of the visibilities V, shared by the scatter
# plots below
real_lab = r'$\mathfrak{Re} \; (V)$'
imag_lab = r'$\mathfrak{Im} \; (V)$'

In [None]:
fig, ax = plt.subplots(figsize=(7, 5))

# Scatter layers: points kept by the rMD cut, points rejected by it, and the MCD location
inliers = np.delete(points, outliers, axis=0)
scatter_layers = (
    (inliers[:, 0], inliers[:, 1], dict(label='Inliers')),
    (points[outliers, 0], points[outliers, 1], dict(color='red', label='Outliers')),
    ([robust_cov.location_[0]], [robust_cov.location_[1]],
     dict(color='brown', label='MCD location', marker='+')),
)
for xs, ys, layer_kw in scatter_layers:
    sns.scatterplot(x=xs, y=ys, ax=ax, **layer_kw)

# Regular 1001 x 1001 grid spanning the axis limits chosen by the scatter plots
(x_lo, x_hi), (y_lo, y_hi) = ax.get_xlim(), ax.get_ylim()
xx, yy = np.meshgrid(np.linspace(x_lo, x_hi, 1001), np.linspace(y_lo, y_hi, 1001))
zz = np.column_stack((xx.ravel(), yy.ravel()))

# Squared robust Mahalanobis distance at every grid node; contours are drawn on its square root
mahal_robust_cov = robust_cov.mahalanobis(zz).reshape(xx.shape)
rmd_grid = np.sqrt(mahal_robust_cov)
rmd_thresh = np.sqrt(chi2_thresh)

robust_contour = ax.contour(xx, yy, rmd_grid, cmap=plt.cm.YlOrBr_r, linestyles='--')
ax.clabel(robust_contour, robust_contour.levels, inline=True, fontsize=10)

# Single solid contour marking the outlier threshold
thresh_contour = ax.contour(xx, yy, rmd_grid, [rmd_thresh], colors='red')

ax.annotate('Robust Mahalanobis Distance', xy=(0.64, 0.10), xycoords='axes fraction',
            bbox=dict(boxstyle='round', facecolor='white'), size=10, color='orange')
ax.annotate(r'$\chi_{\mathrm{thresh}}$ = '+'{0:.3f}'.format(rmd_thresh), xy=(0.64, 0.03),
            xycoords='axes fraction', bbox=dict(boxstyle='round', facecolor='white'),
            size=10, color='red')

ax.set(xlabel=real_lab, ylabel=imag_lab)

# Remember the limits so the MAD-clipping figure can be drawn on the same scale
rmd_lims = [ax.get_xlim(), ax.get_ylim()]

ax.legend()
fig.tight_layout()

# To export the figure, call fig.savefig(<output dir>/contour_plot.pdf) here
plt.show()

#### MAD-clipping 

In [None]:
from astropy.stats import mad_std
import matplotlib.patches as patches

from robstat.stdstat import mad_clip

In [None]:
# Clipping threshold, in units of the MAD-based scale; reused by the later clipping cells
sigma = 4.0

# Clip the real (column 0) and imaginary (column 1) parts independently.
# mad_clip's second return value is used as a boolean flag array (first return value discarded).
_, f_r = mad_clip(points[:, 0], sigma=sigma)
_, f_i = mad_clip(points[:, 1], sigma=sigma)

# A point is an outlier if either component is flagged ('+' on boolean arrays acts as OR)
outliers = np.where(f_r + f_i)[0]

In [None]:
gap = 5  # extent of each shaded rejection band, in multiples of the MAD-based scale

fig, ax = plt.subplots(figsize=(7, 5))

# Marginal location and scale of the real (column 0) and imaginary (column 1) parts.
# Note: astropy's mad_std is the MAD rescaled to estimate a standard deviation.
meds = np.median(points, axis=0)
mads = mad_std(points, axis=0)

inliers = np.delete(points, outliers, axis=0)
sns.scatterplot(x=inliers[:, 0], y=inliers[:, 1], ax=ax, label='Inliers')
sns.scatterplot(x=points[outliers, 0], y=points[outliers, 1], color='red', ax=ax,
                label='Outliers')
sns.scatterplot(x=[meds[0]], y=[meds[1]], color='brown', ax=ax,
                label='Marginal Median', marker='+')

# Edges of the acceptance box med +/- sigma * MAD in each component
re_lo, re_hi = meds[0] - sigma*mads[0], meds[0] + sigma*mads[0]
im_lo, im_hi = meds[1] - sigma*mads[1], meds[1] + sigma*mads[1]

# Rejection bands in the real part: full-height vertical strips either side of the box
band_kw = dict(alpha=0.25, color='red', lw=0)
ax.axvspan(re_lo - gap*mads[0], re_lo, **band_kw)
ax.axvspan(re_hi, re_hi + gap*mads[0], **band_kw)

# Rejection bands in the imaginary part: rectangles below and above the box, limited to the
# accepted real range so they do not overlap the vertical strips. Only lw=0 is passed
# (previously both linewidth=1 and its alias lw=0 were given, with lw=0 taking effect).
for im_start in (im_lo - gap*mads[1], im_hi):
    rect = patches.Rectangle((re_lo, im_start), re_hi - re_lo, gap*mads[1],
                             edgecolor='r', facecolor='red', alpha=0.25, lw=0)
    ax.add_patch(rect)

note_kw = dict(xycoords='axes fraction', size=10)
ax.annotate(r'$\mathrm{MAD}_{\mathfrak{Re}} \;$ = '+'{0:.3f}'.format(mads[0]), xy=(0.02, 0.10),
            bbox=dict(boxstyle='round', facecolor='white'), color='orange', **note_kw)
ax.annotate(r'$\mathrm{MAD}_{\mathfrak{Im}}$ = '+'{0:.3f}'.format(mads[1]), xy=(0.02, 0.03),
            bbox=dict(boxstyle='round', facecolor='white'), color='orange', **note_kw)

# Build the threshold label from `sigma` so it cannot go stale if the threshold is changed
ax.annotate(r'$ \mathrm{med} \pm ' + '{:g}'.format(sigma) + r' \times \mathrm{MAD} $',
            xy=(0.8, 0.065), bbox=dict(boxstyle='round', facecolor='white'), color='red',
            **note_kw)

ax.set_xlabel(real_lab)
ax.set_ylabel(imag_lab)

# Reuse the axis limits of the rMD figure so the two plots are directly comparable
ax.set_xlim(left=rmd_lims[0][0], right=rmd_lims[0][1])
ax.set_ylim(bottom=rmd_lims[1][0], top=rmd_lims[1][1])

ax.legend(loc='upper right')
plt.tight_layout()

plt.show()

In [None]:
# MAD-clip the full data set along axis 0, separately for the real and imaginary parts.
# Use the shared `sigma` so this stays in step with the rMD threshold, which is derived from it.
_, f_r = mad_clip(data.real, axis=0, sigma=sigma)
_, f_i = mad_clip(data.imag, axis=0, sigma=sigma)

# A sample is clipped if either component is flagged; XOR with the pre-existing NaN flags
# so that only newly identified outliers remain set.
# NOTE(review): the XOR removes NaN samples only if mad_clip marks them as flagged - confirm.
mad_clip_f = (f_r + f_i) ^ flags

In [None]:
# Total number of True entries in the MAD-clipping flag array
mad_clip_f.sum()

#### rMD-clipping 

In [None]:
# For a 1-D normal deviate, P(|x - mu| < n*sigma) = erf(n / sqrt(2)). Take that coverage
# probability as the quantile of a chi^2 distribution with 2 degrees of freedom, which gives
# the threshold applied to the squared robust Mahalanobis distances.
chi2_quantile = special.erf(sigma / np.sqrt(2))
chi2_thresh = stats.chi2(2).ppf(chi2_quantile)
print('chi^2 quantile corresponding to {}*sigma is {:.7f}'.format(sigma, chi2_quantile))

In [None]:
import multiprocess as multiprocessing

# require a shared ctype array in order to fill in a numpy array in parallel

def create_mp_array(arr):
    """Copy ``arr`` into a raw ctypes buffer and return that buffer with a numpy view of it.

    Parameters
    ----------
    arr : numpy.ndarray
        Array whose contents, shape and dtype are replicated.

    Returns
    -------
    shared_arr : ctypes array
        Flat ``RawArray`` with one element per entry of ``arr``.
    new_arr : numpy.ndarray
        View of ``shared_arr`` with the shape and dtype of ``arr``; the two share memory, so
        writes through ``new_arr`` are visible in ``shared_arr``.
    """
    c_type = np.ctypeslib.as_ctypes_type(arr.dtype)
    shared_arr = multiprocessing.RawArray(c_type, int(arr.size))
    # Wrap the raw buffer without copying, then fill it with the input values
    new_arr = np.frombuffer(shared_arr, dtype=arr.dtype).reshape(arr.shape)
    np.copyto(new_arr, arr)
    return shared_arr, new_arr

def mp_init(shared_arr_, sharred_arr_shape_, sharred_arr_dtype_):
    """Pool initializer: publish the shared buffer and its layout as module-level globals.

    Parameters
    ----------
    shared_arr_ : ctypes array
        Buffer that is stored in the global ``shared_arr``.
    sharred_arr_shape_ : tuple
        Stored in the global ``sharred_arr_shape``.
    sharred_arr_dtype_ : numpy.dtype
        Stored in the global ``sharred_arr_dtype``.
    """
    global shared_arr, sharred_arr_shape, sharred_arr_dtype
    shared_arr, sharred_arr_shape, sharred_arr_dtype = (
        shared_arr_, sharred_arr_shape_, sharred_arr_dtype_
    )

def mp_iter(s):
    d = data[:, s[0], s[1], s[2]]
    if not np.isnan(d).all():
        
        isfinite = np.isfinite(d).nonzero()[0]
        d = decomposeCArray(flt_nan(d))
        robust_cov = MinCovDet(random_state=0).fit(d)
        outliers = robust_cov.mahalanobis(d) > chi2_thresh

        rmd_clip_f = np.frombuffer(shared_arr, dtype).reshape(shape)
        rmd_clip_f[isfinite, s[0], s[1], s[2]] = outliers

In [None]:
rmd_clip_f = np.ones_like(data, dtype=bool)
d_shared, rmd_clip_f = create_mp_array(rmd_clip_f)
dtype = rmd_clip_f.dtype
shape = rmd_clip_f.shape

m_pool = multiprocessing.Pool(multiprocessing.cpu_count(), initializer=mp_init, \
                              initargs=(d_shared, dtype, shape))
_ = m_pool.map(mp_iter, np.ndindex(data.shape[1:]))
m_pool.close()
m_pool.join()

rmd_clip_f = rmd_clip_f ^ flags

In [None]:
fig, ax = plt.subplots(figsize=(5, 5), nrows=2, sharex=True)

# Number of clipped samples per (channel, time integration): sum the flags over axes 0 and 3,
# then transpose so that channels run along x and time integrations along y
mad_im = mad_clip_f.sum(axis=(0, 3)).T
rmd_im = rmd_clip_f.sum(axis=(0, 3)).T
tot_arr = np.concatenate((mad_im, rmd_im))

# None lets each panel pick its own colour range; set these to np.min(tot_arr) and
# np.max(tot_arr) to put both panels on a common scale
vmin = None
vmax = None

panels = zip(ax, (mad_im, rmd_im), ('MAD-clipping', 'rMD-clipping'))
for panel, counts, title in panels:
    im = panel.imshow(counts, aspect='auto', interpolation='none', vmin=vmin, vmax=vmax)
    panel.set_title(title)
    panel.set_ylabel('Time integration')
    # The panels are scaled independently, so each needs its own colourbar to be readable
    fig.colorbar(im, ax=panel, label='Clipped samples')
ax[-1].set_xlabel('Frequency channel')

plt.tight_layout()
plt.show()