<center><strong><font size=+3>Gaussian Process Interpolation of HERA data</font></strong></center>
<br><br>
<center><strong><font size=+2>Matyas Molnar and Bojan Nikolic</font><br></strong></center>
<br><center><strong><font size=+1>Astrophysics Group, Cavendish Laboratory, University of Cambridge</font></strong></center>

In [None]:
import itertools
import os

import matplotlib as mpl
import numpy as np
from matplotlib import cm, colors, ticker
from matplotlib import pyplot as plt
from scipy import signal
from scipy.interpolate import griddata, interp1d
from sklearn import gaussian_process as gp
from sklearn import preprocessing

from robstat.ml import extrem_nans, nan_interp2d
from robstat.stdstat import rsc_mean
from robstat.utils import DATAPATH, decomposeCArray

import uvtools

In [None]:
%matplotlib inline

In [None]:
# Global matplotlib configuration: compact, high-resolution figures whose text
# is rendered through LaTeX with a Computer Modern serif font.
mpl.rcParams.update({
    'figure.dpi': 175,
    'figure.figsize': (5, 3),
})

mpl.rc('font', family='serif', serif=['cm'])
mpl.rc('text', usetex=True)
mpl.rc('text.latex', preamble=r'\usepackage{amssymb} \usepackage{amsmath}')

# Output directory referenced by the (commented-out) savefig calls below
save_fig_dir = '/lustre/aoc/projects/hera/mmolnar/figs'

### Load HERA dataset

In [None]:
# Load the sample visibility file and restrict everything to one frequency band.
xd_vis_file = os.path.join(DATAPATH, 'lstb_no_avg/idr2_lstb_14m_ee_1.40949.npz')
sample_xd_data = np.load(xd_vis_file)

xd_data = sample_xd_data['data']
xd_redg = sample_xd_data['redg']
xd_pol = sample_xd_data['pol'].item()

# data dimensions (2xdays, freqs, times, bls)
xd_flags = np.isnan(xd_data)
xd_data[xd_flags] *= np.nan  # multiply to also hit the imag part
chans = plt_chans = np.arange(xd_data.shape[1])
# hard-coded frequency axis: 1024 channels starting at 1e8 with spacing (2e8 - 1e8)/1024
freqs = np.linspace(1e8, 2e8, 1025)[:-1]
JDs = sample_xd_data['JDs']
f_resolution = np.median(np.ediff1d(freqs))
lsts = sample_xd_data['lsts']

# inclusive [first, last] channel indices of the two bands
band_1 = [175, 334]
band_2 = [515, 694]

band_i = band_2  # select band here

# Apply the same band selection to the data, the flags and the axes. The flags
# must be cut from the full-band flag array: slicing the already-cut data a
# second time would yield data values (of the wrong channels) instead of flags.
band_slice = slice(band_i[0], band_i[1]+1)
xd_data = xd_data[:, band_slice, ...]
xd_flags = xd_flags[:, band_slice, ...]
freqs = freqs[band_slice]
chans = chans[band_slice]
no_chans = xd_data.shape[1]

In [None]:
bl_grp = 0 # look at 14m 60deg baseline - has more structure, 14m EW too flat

# first column of xd_redg holds the group id of each baseline; keep the rows
# belonging to the chosen group
slct_bl_idxs = np.flatnonzero(xd_redg[:, 0] == bl_grp)
no_bls = slct_bl_idxs.size

# restrict data and flags to those baselines (baselines are the last axis)
data = xd_data[..., slct_bl_idxs]
flags = xd_flags[..., slct_bl_idxs]
xd_data_bls = xd_data[..., slct_bl_idxs]

# remaining columns of the group's first row are used to label the group
slct_red_bl = xd_redg[slct_bl_idxs[0], 1:]
print('Looking at baselines redundant to ({}, {}, \'{}\')'.\
      format(*slct_red_bl, xd_pol))

### Format & clean dataset

In [None]:
# drop baselines whose samples are NaN for every day, channel and time
# NOTE: this cell shrinks xd_data_bls / no_bls in place of the previous values
all_nan_bl = np.isnan(xd_data_bls).all(axis=(0, 1, 2))
nan_bls = np.flatnonzero(all_nan_bl)
xd_data_bls = np.delete(xd_data_bls, nan_bls, axis=3)
no_bls -= nan_bls.size

In [None]:
# percentage of samples (over all days, channels, times and baselines) that are NaN
n_flagged = np.isnan(xd_data_bls).sum()
print(round(n_flagged / xd_data_bls.size * 100, 3))

In [None]:
# count NaNs per (day, baseline) pair and take the pair with the fewest
sum_nans = np.isnan(xd_data_bls).sum(axis=(1, 2))
ok = np.unravel_index(np.argmin(sum_nans), sum_nans.shape)

# 2D (freqs, times) visibility slice used for all the tests below
best_day, best_bl = ok[0], ok[1]
test_data = xd_data_bls[best_day, ..., best_bl]

In [None]:
# remove frequencies at extremities with only nan entries
# (extrem_nans presumably returns the indices of the leading/trailing True
# entries of the mask - confirm against robstat.ml)
nan_chans = extrem_nans(np.isnan(test_data).all(axis=1))
flt_chans, flt_freqs = chans, freqs
if nan_chans.size:
    flt_chans = np.delete(chans, nan_chans)
    flt_freqs = np.delete(freqs, nan_chans)
    test_data = np.delete(test_data, nan_chans, axis=0)

# remove time integrations at extremities with only nan entries
nan_tints = extrem_nans(np.isnan(test_data).all(axis=0))
if nan_tints.size:
    test_data = np.delete(test_data, nan_tints, axis=1)

### Gaussian Process regression

#### 1D real valued GP

In [None]:
# amplitude spectrum of the first time integration
sample_data = np.abs(test_data[:, 0])
nans = np.isnan(sample_data)
nan_loc = lambda z: z.nonzero()[0]

print('Flagged data points at channels {}'.format(flt_chans[nans].tolist()))

# interpolate nan values just for this plot and for noise estimate
# (bounds_error=False: flagged channels outside the valid range stay NaN)
f_vis = interp1d(flt_freqs[~nans], sample_data[~nans], kind='cubic', bounds_error=False)
sample_data_i = sample_data.copy()
sample_data_i[nans] = f_vis(flt_freqs[nans])

fig, ax = plt.subplots()
freqs_mhz = flt_freqs/1e6

ax.scatter(freqs_mhz, sample_data, s=3)
ax.scatter(freqs_mhz[nans], sample_data_i[nans], s=3, label='interp values')

ax.set_xlabel('Frequency [MHz]')
ax.set_ylabel(r'$\left| V \right|$')
ax.legend(loc='upper right', prop={'size': 6})

fig.tight_layout()
plt.show()

In [None]:
# quick & dirty noise estimate - can get better values from the data files
# Variance of the absolute channel-to-channel differences of the gap-filled
# amplitudes. The cubic fill leaves NaN where a flagged channel lies outside
# the range of valid channels, so use the NaN-aware variance.
noise_dest = np.nanvar(np.abs(np.ediff1d(sample_data_i)))
print('Dirty noise estimate: {}'.format(round(noise_dest, 3)))

In [None]:
# amplitude hyperparameter: starting value and search bounds
const_val = 1 # np.nanmean(sample_data)**2
const_min = 1e-1  # 10**np.floor(np.log10(sample_data[~nans].min()))
const_max = 5e1   # 10**(np.ceil(np.log10(sample_data[~nans].max()))+1)

# kernel = amplitude * RBF (smooth signal) + white noise
amp_kernel = gp.kernels.ConstantKernel(constant_value=const_val,
                                       constant_value_bounds=(const_min, const_max))
rbf_kernel = gp.kernels.RBF(length_scale=f_resolution*10, length_scale_bounds=(1e5, 1e7))
white_kernel = gp.kernels.WhiteKernel(noise_level=1e-1, noise_level_bounds=(5e-2, 2e+1))
kernel = amp_kernel * rbf_kernel + white_kernel

model = gp.GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=100, normalize_y=True)

# train on the unflagged channels only; frequency is the single feature
freq_train = flt_freqs[~nans].reshape(-1, 1)
model.fit(freq_train, sample_data[~nans])
print(model.kernel_)

In [None]:
# fitted kernel is (Constant * RBF) + White: k1.k1 is the constant kernel,
# k2 the white-noise kernel
fitted_kernel = model.kernel_
c_kern_params = fitted_kernel.k1.k1.get_params()
n_kern_params = fitted_kernel.k2
c_kern_params

In [None]:
# dense frequency axis spanning the data, and the GP posterior mean / std on it
freq_pred = np.linspace(flt_freqs[0], flt_freqs[-1], 1000)
y_pred, std = model.predict(freq_pred.reshape(-1, 1), return_std=True)

In [None]:
fig, ax = plt.subplots()
freq_pred_mhz = freq_pred/1e6

# data points on top of the GP posterior mean
ax.scatter(flt_freqs/1e6, sample_data, s=2, color='orange', zorder=3)
ax.plot(freq_pred_mhz, y_pred, zorder=2, label='GP')

# sigma regions, widest first
sigma_shades = {3: 'lightgray', 2: 'darkgray', 1: 'gray'}
for sigma, c in sigma_shades.items():
    lower, upper = y_pred-sigma*std, y_pred+sigma*std
    ax.fill_between(freq_pred_mhz, lower, upper, \
                    color=c, alpha=0.5, label=rf'$\pm {sigma} \sigma$')

ax.set_xlabel('Frequency [MHz]')
ax.set_ylabel(r'$\left| V \right|$')

ax.legend(loc='lower right', prop={'size': 8})

fig.tight_layout()
plt.show()

##### Explore hyperparameter space

In [None]:
# make a grid of different hyperparameter values to explore the marginal likelihood landscape
# length scales: from one channel width up to the full frequency span
lsv_log_min = np.log10(f_resolution)
lsv_log_max = np.log10(flt_freqs[-1] - flt_freqs[0])

# noise levels: one decade either side of the fitted value
# NOTE: "min" is the upper decade and "max" the lower one, so nlv is descending
fitted_noise_log = np.log10(n_kern_params.get_params()['noise_level'])
nlv_log_min = fitted_noise_log + 1
nlv_log_max = fitted_noise_log - 1

lsv = np.logspace(lsv_log_min, lsv_log_max, 100)
nlv = np.logspace(nlv_log_min, nlv_log_max, 100)
cv = c_kern_params['constant_value']  # fix this

# margloglik[i, j] <-> (lsv[i], nlv[j])
margloglik = np.empty((lsv.size, nlv.size))
freq_train = flt_freqs[~nans].reshape(-1, 1)
amp_train = sample_data[~nans]

for (i, l), (j, n) in itertools.product(enumerate(lsv), enumerate(nlv)):
    kernel = gp.kernels.ConstantKernel(constant_value=cv) * \
             gp.kernels.RBF(length_scale=l) + \
             gp.kernels.WhiteKernel(noise_level=n)

    # optimizer=None keeps the hyperparameters fixed at the grid values
    model_ij = gp.GaussianProcessRegressor(kernel=kernel, optimizer=None, normalize_y=True)
    model_ij.fit(freq_train, amp_train)

    margloglik[i, j] = model_ij.log_marginal_likelihood()

In [None]:
# Contour map of the negative log-marginal likelihood over the hyperparameter grid
L, N = np.meshgrid(lsv, nlv)

fig, ax = plt.subplots()

# log-spaced contour levels
# NOTE(review): assumes margloglik is negative everywhere (log10 of -margloglik)
levels = np.logspace(np.log10(-margloglik.max()), np.log10(-margloglik.min()), 30)
# margloglik is indexed [length, noise]; the meshgrid is [noise, length], hence .T
cp = ax.contour(L, N, -margloglik.T, levels=levels, cmap='viridis')

# mark the grid point with the highest marginal likelihood
idx_max = np.unravel_index(np.argmax(margloglik), margloglik.shape)
ax.scatter(lsv[idx_max[0]], nlv[idx_max[1]], color='orange')

norm = colors.Normalize(vmin=cp.cvalues.min(), vmax=cp.cvalues.max())
sm = plt.cm.ScalarMappable(norm=norm, cmap=cp.cmap)
# the mappable is not attached to any Axes, so the parent Axes must be given
# explicitly (recent matplotlib raises otherwise)
cb = fig.colorbar(sm, ax=ax)#, label=r'$-\ln(\mathcal{L})$')
cb.ax.invert_yaxis()

ax.set_xscale('log')
ax.set_yscale('log')

ax.set_title('Log-Marginal Likelihood')
ax.set_xlabel('Length scale')
ax.set_ylabel('Noise level')

fig.tight_layout()
plt.show()

In [None]:
# alternative way of calculating - faster
# Evaluate the fitted model's log-marginal likelihood at each grid point via
# `theta` (log-transformed hyperparameters) instead of refitting a new model.
length_scale_grid, noise_level_grid = np.meshgrid(lsv, nlv)

log_marginal_likelihood = [model.log_marginal_likelihood(theta=np.log([cv, l, n])) \
    for l, n in zip(length_scale_grid.ravel(), noise_level_grid.ravel())]

# shape passed positionally: the `newshape` keyword is deprecated in recent NumPy
log_marginal_likelihood = np.reshape(log_marginal_likelihood, noise_level_grid.shape)

In [None]:
# Contour map of the negative log-marginal likelihood from the fast evaluation
fig, ax = plt.subplots()

# NOTE(review): assumes the log-marginal likelihood is negative everywhere
levels = np.logspace(np.log10(-log_marginal_likelihood.max()), np.log10(-log_marginal_likelihood.min()), 30)
cp = ax.contour(length_scale_grid, noise_level_grid, -log_marginal_likelihood, levels=levels, cmap='viridis')

# mark the grid point with the highest marginal likelihood
idx_max = np.unravel_index(np.argmax(log_marginal_likelihood), log_marginal_likelihood.shape)
ax.scatter(length_scale_grid[idx_max], noise_level_grid[idx_max], color='orange')

norm = colors.Normalize(vmin=cp.cvalues.min(), vmax=cp.cvalues.max())
sm = plt.cm.ScalarMappable(norm=norm, cmap=cp.cmap)
# the mappable is not attached to any Axes, so the parent Axes must be given
# explicitly (recent matplotlib raises otherwise)
cb = fig.colorbar(sm, ax=ax)#, label=r'$-\ln(\mathcal{L})$')
cb.ax.invert_yaxis()

ax.set_xscale('log')
ax.set_yscale('log')

ax.set_title('Log-Marginal Likelihood')
ax.set_xlabel('Length scale')
ax.set_ylabel('Noise level')

fig.tight_layout()
plt.show()

In [None]:
# cubic interpolation of MLL space to get finer detail
no_interp_points = 1000
l_log_i = np.logspace(lsv_log_min, lsv_log_max, no_interp_points)
n_log_i = np.logspace(nlv_log_min, nlv_log_max, no_interp_points)

# margloglik[i, j] belongs to (lsv[i], nlv[j]), so the sample coordinates must
# be built with 'ij' indexing for their ravelled order to match
# margloglik.ravel(). (An 'xy' meshgrid pairs every value with the transposed
# grid point.) The fine grid uses the same layout, hence
# grid_z[i, j] <-> (l_log_i[i], n_log_i[j]).
L_i, N_i = np.meshgrid(l_log_i, n_log_i, indexing='ij')
log_l_s, log_n_s = np.meshgrid(np.log10(lsv), np.log10(nlv), indexing='ij')

# interpolate in log10 coordinates, where the log-spaced samples form a regular grid
grid_z = griddata((log_l_s.ravel(), log_n_s.ravel()), margloglik.ravel(), \
                  (np.log10(L_i), np.log10(N_i)), method='cubic', rescale=True)

In [None]:
# Side-by-side images of -ln(L): original grid vs. cubic interpolation
fig, axes = plt.subplots(ncols=2, sharey=True, constrained_layout=True)

# Both arrays are indexed [length scale, noise level]; imshow puts axis 0 on the
# vertical axis, so transpose to get length scale along x and noise level along y.
# The grids are log-spaced, so the pixel edges are given in log10 units.
# extent = [left, right, bottom, top]; with the default origin the first row is
# drawn at the top, so the top edge takes the first noise value.
extent_orig = [np.log10(lsv[0]), np.log10(lsv[-1]), np.log10(nlv[-1]), np.log10(nlv[0])]
extent_interp = [np.log10(l_log_i[0]), np.log10(l_log_i[-1]), \
                 np.log10(n_log_i[-1]), np.log10(n_log_i[0])]

# common colour limits so the single colorbar is valid for both panels
# (NaN-aware: cubic griddata can return NaN at the edge of the convex hull)
vmin = min(np.nanmin(-margloglik), np.nanmin(-grid_z))
vmax = max(np.nanmax(-margloglik), np.nanmax(-grid_z))

im1 = axes[0].imshow(-margloglik.T, cmap='viridis', interpolation=None, \
    extent=extent_orig, aspect='auto', vmin=vmin, vmax=vmax)
im2 = axes[1].imshow(-grid_z.T, cmap='viridis', interpolation=None, \
    extent=extent_interp, aspect='auto', vmin=vmin, vmax=vmax)

axes[0].set_title('Original points', size=9)
axes[1].set_title('Cubic interpolation', size=9)
axes[0].set_xlabel(r'$\log_{10}$(Length scale)', size=8)
axes[1].set_xlabel(r'$\log_{10}$(Length scale)', size=8)
axes[0].set_ylabel(r'$\log_{10}$(Noise level)', size=8)

cb = fig.colorbar(im2, ax=axes.ravel(), label=r'$-\ln(\mathcal{L})$')
cb.ax.invert_yaxis()

plt.show()

In [None]:
# hyperparameters with highest MLLs
# nanargmax: cubic griddata returns NaN for query points it places outside the
# convex hull of the samples, and a plain argmax would return such a NaN cell
interp_max_idx = np.unravel_index(np.nanargmax(grid_z), grid_z.shape)
print('{:e} {:e}'.format(l_log_i[interp_max_idx[0]], n_log_i[interp_max_idx[1]]))

In [None]:
# compare to ones found from optimization:
# display the fitted kernel of `model` (its optimised hyperparameters)
model.kernel_

In [None]:
# 3D surface of the negative log-marginal likelihood in log10 hyperparameter space
fig = plt.figure(figsize=(6, 4))
ax = plt.axes(projection='3d')

# margloglik is indexed [length, noise] while L, N come from an 'xy' meshgrid,
# hence the transpose
log_L, log_N = np.log10(L), np.log10(N)
neg_lml_surface = -margloglik.T
# ax.plot_surface(L, N, -margloglik.T, cmap='plasma_r')
ax.plot_surface(log_L, log_N, neg_lml_surface, cmap='plasma_r')

ax.set_xlabel('Log Length Scale')
ax.set_ylabel('Log Noise Level')
ax.set_zlabel(r'$-\ln(\mathcal{L}_{\mathrm{marg}})$', rotation=90)

ax.view_init(azim=45)

plt.tight_layout()
plt.show()

In [None]:
# replot with more accurate minimum location:
fig, ax = plt.subplots()

# NOTE(review): assumes the log-marginal likelihood is negative everywhere
levels = np.logspace(np.log10(-log_marginal_likelihood.max()), np.log10(-log_marginal_likelihood.min()), 30)
cp = ax.contour(length_scale_grid, noise_level_grid, -log_marginal_likelihood, levels=levels, cmap='viridis')

# maximum located on the finely interpolated grid
ax.scatter(l_log_i[interp_max_idx[0]], n_log_i[interp_max_idx[1]], color='orange')

norm = colors.Normalize(vmin=cp.cvalues.min(), vmax=cp.cvalues.max())
sm = plt.cm.ScalarMappable(norm=norm, cmap=cp.cmap)
# the mappable is not attached to any Axes, so the parent Axes must be given
# explicitly (recent matplotlib raises otherwise)
cb = fig.colorbar(sm, ax=ax)#, label=r'$-\ln(\mathcal{L})$')
cb.ax.invert_yaxis()

ax.set_xscale('log')
ax.set_yscale('log')

ax.set_title('Log-Marginal Likelihood')
ax.set_xlabel('Length scale')
ax.set_ylabel('Noise level')

fig.tight_layout()
plt.show()

#### Complex GP

In [None]:
# complex visibilities of the first time integration, split into two real
# target columns (column 0 is used as Re and column 1 as Im further down)
vis_t0 = test_data[:, 0]
c_tdata = decomposeCArray(vis_t0)  # multiple target for complex numbers
# mask of flagged channels, taken from the amplitude spectrum
c_nans = np.isnan(sample_data)

In [None]:
# amplitude hyperparameter bounds
const_min = 5e2 # np.max(10**np.floor(np.log10(np.max([np.nanmin(c_tdata), 1]))))
const_max = 5e3  # 10**(np.ceil(np.log10(np.nanmax(c_tdata)))+1)
# the starting value must lie inside the bounds handed to the optimizer
# (the previous value of 100 was below const_min)
const_val = 1e3

# kernel = amplitude * RBF (smooth signal) + white noise
kernel = gp.kernels.ConstantKernel(constant_value=const_val, \
         constant_value_bounds=(const_min, const_max)) * \
         gp.kernels.RBF(length_scale=f_resolution*10, length_scale_bounds=(1e5, 1e7)) + \
         gp.kernels.WhiteKernel(noise_level=1e1, noise_level_bounds=(1e1, 5e2))

# for some reason, normalize_y=True gives different noises for Re and Im - avoid
model = gp.GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=100, normalize_y=False)

# two-column target (Re, Im) fitted jointly with one shared kernel
model.fit(flt_freqs[~nans].reshape(-1, 1), c_tdata[~nans])
print(model.kernel_)

In [None]:
# predict both target columns, then recombine them as Re + i*Im
pred_cols, std_cols = model.predict(freq_pred.reshape(-1, 1), return_std=True)
y_c_pred = pred_cols[:, 0] + 1j*pred_cols[:, 1]
# real part holds the std of Re, imaginary part the std of Im
c_std = std_cols[:, 0] + 1j*std_cols[:, 1]

In [None]:
fig, ax = plt.subplots(figsize=(7, 5))

# legend labels, reused by later cells
real_lab = r'$\mathfrak{Re}$'
imag_lab = r'$\mathfrak{Im}$'

freq_pred_mhz = freq_pred/1e6
obs_freq_mhz = flt_freqs[~nans]/1e6

# data: GP mean as a line, unflagged observations as points, per component
components = [
    (y_c_pred.real, c_tdata[:, 0], 'orange', 'blue', real_lab),
    (y_c_pred.imag, c_tdata[:, 1], 'purple', 'red', imag_lab),
]
for mean_c, obs_c, line_col, point_col, lab in components:
    ax.plot(freq_pred_mhz, mean_c, zorder=2, color=line_col, label=lab)
    ax.scatter(obs_freq_mhz, obs_c[~nans], s=2, color=point_col, zorder=3)

# sigma regions (only the real-part bands carry a legend entry)
for sigma, c in zip([3, 2, 1], ['lightgray', 'darkgray', 'gray']):
    re_half, im_half = sigma*c_std.real, sigma*c_std.imag
    ax.fill_between(freq_pred_mhz, y_c_pred.real-re_half, y_c_pred.real+re_half, \
                    color=c, alpha=0.5, label=rf'$\pm {sigma} \sigma$')
    ax.fill_between(freq_pred_mhz, y_c_pred.imag-im_half, y_c_pred.imag+im_half, \
                    color=c, alpha=0.5)

ax.set_xlabel('Frequency [MHz]')
ax.set_ylabel(r'$V$')

ax.legend(loc='best')#prop={'size': 6})

fig.tight_layout()
plt.show()

In [None]:
# Log-marginal likelihood landscape of the complex GP around its fitted
# hyperparameters (amplitude held fixed at the fitted value).
c_kern_params = model.kernel_.get_params()['k1'].get_params()['k1'].get_params()
n_kern_params = model.kernel_.get_params()['k2']

# length scales: from one channel width up to the full frequency span
lsv_log_min = np.log10(f_resolution)
lsv_log_max = np.log10(flt_freqs[-1] - flt_freqs[0])

# noise levels: one decade either side of the fitted value (nlv is descending)
nlv_log_min = np.log10(n_kern_params.get_params()['noise_level']) + 1
nlv_log_max = np.log10(n_kern_params.get_params()['noise_level']) - 1

lsv = np.logspace(lsv_log_min, lsv_log_max, 300)
nlv = np.logspace(nlv_log_min, nlv_log_max, 300)
cv = c_kern_params['constant_value']  # fix this


length_scale_grid, noise_level_grid = np.meshgrid(lsv, nlv)

log_marginal_likelihood = [model.log_marginal_likelihood(theta=np.log([cv, l, n])) \
    for l, n in zip(length_scale_grid.ravel(), noise_level_grid.ravel())]

# shape passed positionally: the `newshape` keyword is deprecated in recent NumPy
log_marginal_likelihood = np.reshape(log_marginal_likelihood, noise_level_grid.shape)


fig, ax = plt.subplots(figsize=(7, 3.5))

# NOTE(review): assumes the log-marginal likelihood is negative everywhere
levels = np.logspace(np.log10(-log_marginal_likelihood.max()), np.log10(-log_marginal_likelihood.min()), 30)
# length scales are in Hz; divide by 1e6 so the axis matches its MHz label
cp = ax.contour(length_scale_grid/1e6, noise_level_grid, -log_marginal_likelihood, levels=levels, cmap='viridis')

idx_max = np.unravel_index(np.argmax(log_marginal_likelihood), log_marginal_likelihood.shape)
ax.scatter(length_scale_grid[idx_max]/1e6, noise_level_grid[idx_max], color='orange')

norm = colors.Normalize(vmin=cp.cvalues.min(), vmax=cp.cvalues.max())
sm = plt.cm.ScalarMappable(norm=norm, cmap=cp.cmap)
# the mappable is not attached to any Axes, so the parent Axes must be given
# explicitly (recent matplotlib raises otherwise)
cb = fig.colorbar(sm, ax=ax)
cb.ax.invert_yaxis()

ax.set_xscale('log')
ax.set_yscale('log')

ax.set_title('Log-Marginal Likelihood')
ax.set_xlabel('Length Scale [MHz]')
ax.set_ylabel('Noise Level [Jy$^2$]')

fig.tight_layout()
# plt.savefig(os.path.join(save_fig_dir, 'GP_logl_f.pdf'), bbox_inches='tight')
plt.show()

In [None]:
# grid point with the highest log-marginal likelihood: (length scale / 1e6, noise level)
length_scale_grid[idx_max]/1e6, noise_level_grid[idx_max]

In [None]:
# compare to what model found:
# display the fitted kernel of `model` (its optimised hyperparameters)
model.kernel_

#### Add time as a feature

In [None]:
# optional cut-down of the (freqs, times) test slice for quick experiments
restrict = False
test_data_r = test_data.copy()

if restrict:  # test case
    # keep the first ridx1 channels and the first ridx2 time integrations
    ridx1, ridx2 = 35, 45
    test_data_r = test_data_r[:ridx1, :ridx2]

# frequencies matching the (possibly truncated) channel axis
n_chans_r = test_data_r.shape[0]
flt_freqs_r = flt_freqs[:n_chans_r]

In [None]:
# stack the Re/Im panels vertically when there are at least as many channels as
# time integrations, otherwise put them side by side
if test_data_r.shape[0] >= test_data_r.shape[1]:
    ncols, nrows, sharex, sharey = 1, 2, True, False
else:
    ncols, nrows, sharex, sharey = 2, 1, False, True

fig, axes = plt.subplots(ncols=ncols, nrows=nrows, sharex=sharex, sharey=sharey, \
                         constrained_layout=True)

# [left, right, bottom, top]: frequency in MHz, time integration increasing downwards
extent_r = [flt_freqs_r.min()/1e6, flt_freqs_r.max()/1e6, test_data_r.shape[1], 0]

# data is (freqs, times); transpose so that frequency runs along x
imshow_kw = dict(extent=extent_r, aspect='auto', interpolation='none')
im1 = axes[0].imshow(test_data_r.real.T, **imshow_kw)
im2 = axes[1].imshow(test_data_r.imag.T, **imshow_kw)

axes[0].set_ylabel('Time Integration', size=8)
axes[1].set_xlabel('Frequency [MHz]', size=8)
if sharex:
    axes[1].set_ylabel('Time Integration', size=8)
else:
    axes[0].set_xlabel('Frequency [MHz]', size=8)

cb1 = fig.colorbar(im1, ax=axes[0], label=r'$\mathfrak{Re}(V)$', pad=0.025, aspect=15)
cb2 = fig.colorbar(im2, ax=axes[1], label=r'$\mathfrak{Im}(V)$', pad=0.025, aspect=15)

plt.show()

In [None]:
# feature matrix: one (frequency, time index) row per sample, ordered like
# test_data_r.ravel(order='C') (frequency varies slowest)
freq_feat = flt_freqs[:test_data_r.shape[0]]
tint_feat = np.arange(test_data_r.shape[1])
X_us = np.array(np.meshgrid(freq_feat, tint_feat)).T.reshape(-1, 2)

# preprocess data by standard scaling, i.e. remove the mean and scale to unit variance
scaler = preprocessing.StandardScaler()
scaler.fit(X_us)
X_s = scaler.transform(X_us)

# targets: flattened complex visibilities passed through decomposeCArray
flat_vis = test_data_r.ravel(order='C')
Y = decomposeCArray(flat_vis)

# train on the unflagged samples only
nans_a = np.isnan(flat_vis)
X, Y = X_s[~nans_a, :], Y[~nans_a, :]

In [None]:
# different hyperparameters as X scales standardized
const_min = 2e2 # np.max(10**np.floor(np.log10(np.nanmax([np.nanmin(Y), 1]))))
const_max = 5e3 # 10**(np.ceil(np.log10(np.nanmax(Y)))+1)

# crude starting value for the noise level: overall spread of the data, clipped
# so that it lies inside the bounds handed to the optimizer
noise_bounds = (1e1, 1e2)
noise_est = np.nanstd(test_data_r)
noise_init = float(np.clip(noise_est, *noise_bounds))

# kernel = amplitude * RBF + white noise; the RBF length scale is in units of
# the standardized features
kernel = gp.kernels.ConstantKernel(constant_value=5e2, \
         constant_value_bounds=(const_min, const_max)) * \
         gp.kernels.RBF(length_scale=1e-1, length_scale_bounds=(1e-2, 2e1)) + \
         gp.kernels.WhiteKernel(noise_level=noise_init, noise_level_bounds=noise_bounds)

model = gp.GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=25, normalize_y=False)

model.fit(X, Y)
print(model.kernel_)

In [None]:
# 100 x 100 prediction grid spanning the training features (standardized units)
f_min, f_max = X[:, 0].min(), X[:, 0].max()
t_min, t_max = X[:, 1].min(), X[:, 1].max()
f_axis = np.linspace(f_min, f_max, 100)
t_axis = np.linspace(t_min, t_max, 100)
# rows ordered with the frequency coordinate varying slowest
dim_pred = np.array(np.meshgrid(f_axis, t_axis)).T.reshape(-1, 2)

In [None]:
# GP prediction and its standard deviation at every row of dim_pred
y_ir_pred, c_std = model.predict(dim_pred, return_std=True)

In [None]:
grid_shape = (100, 100)

# undo the standard scaling to recover (frequency, time index) coordinates
dim_pred_itr = scaler.inverse_transform(dim_pred)
coords = dim_pred_itr.reshape(*grid_shape, -1)

# recombine the two target columns as Re + i*Im on the [freq, time] grid
pred_re, pred_im = y_ir_pred[:, 0], y_ir_pred[:, 1]
pdata = (pred_re + 1j*pred_im).reshape(grid_shape)
# real part: std of Re, imaginary part: std of Im
std_re, std_im = c_std[:, 0], c_std[:, 1]
pstd = (std_re + 1j*std_im).reshape(grid_shape)

In [None]:
# look at a time slice
fig, ax = plt.subplots(figsize=(7, 3.5))

st = 0  # index of the time column shown (prediction grid and data)

pred_freq_mhz = coords[:, 0, 0]/1e6
obs_freq_mhz = flt_freqs[:test_data_r.shape[0]]/1e6
re_mean, im_mean = pdata.real[:, st], pdata.imag[:, st]
re_std, im_std = pstd.real[:, st], pstd.imag[:, st]

# GP means as lines, observations as points
ax.plot(pred_freq_mhz, re_mean, color='orange', label=real_lab)
ax.plot(pred_freq_mhz, im_mean, color='purple', label=imag_lab)
ax.scatter(obs_freq_mhz, test_data_r[:, st].real, s=2, color='blue', zorder=3)
ax.scatter(obs_freq_mhz, test_data_r[:, st].imag, s=2, color='red', zorder=3)

# sigma regions (only the real-part bands carry a legend entry)
for sigma, c in zip([3, 2, 1], ['lightgray', 'darkgray', 'gray']):
    ax.fill_between(pred_freq_mhz, re_mean-sigma*re_std, \
                    re_mean+sigma*re_std, color=c, alpha=0.5, \
                    label=rf'$\pm {sigma} \sigma$')
    ax.fill_between(pred_freq_mhz, im_mean-sigma*im_std, \
                    im_mean+sigma*im_std, color=c, alpha=0.5)

ax.set_xlabel('Frequency [MHz]')
ax.set_ylabel(r'$V$ [Jy]')
ax.legend(loc='lower right', prop={'size': 9})

fig.tight_layout()
# plt.savefig(os.path.join(save_fig_dir, 'GP_regression_2d_f.pdf'), bbox_inches='tight')
plt.show()

In [None]:
# look at a frequency slice
fig, ax = plt.subplots(figsize=(7, 3.5))

# first frequency row of the prediction grid and of the data
pred_tints = coords[0, :, 1]
obs_tints = np.arange(test_data_r.shape[1])
re_mean, im_mean = pdata.real[0, :], pdata.imag[0, :]
re_std, im_std = pstd.real[0, :], pstd.imag[0, :]

# GP means as lines, observations as points
ax.plot(pred_tints, re_mean, color='orange', label=real_lab)
ax.plot(pred_tints, im_mean, color='purple', label=imag_lab)
ax.scatter(obs_tints, test_data_r[0, :].real, s=2, color='blue', zorder=3)
ax.scatter(obs_tints, test_data_r[0, :].imag, s=2, color='red', zorder=3)

# sigma regions (only the real-part bands carry a legend entry)
for sigma, c in zip([3, 2, 1], ['lightgray', 'darkgray', 'gray']):
    ax.fill_between(pred_tints, re_mean-sigma*re_std, re_mean+sigma*re_std, \
                    color=c, alpha=0.5, label=rf'$\pm {sigma} \sigma$')
    ax.fill_between(pred_tints, im_mean-sigma*im_std, im_mean+sigma*im_std, \
                    color=c, alpha=0.5)

ax.set_xlabel('Time Integration')
ax.set_ylabel(r'$V$ [Jy]')
ax.legend(loc='best', prop={'size': 9})

fig.tight_layout()
# plt.savefig(os.path.join(save_fig_dir, 'GP_regression_2d_t.pdf'), bbox_inches='tight')
plt.show()

In [None]:
# GP interpolation of nan data
# predict at the (scaled) coordinates of the flagged samples and recombine Re/Im
Y_nan = model.predict(X_s[nans_a, :])
Y_nan = Y_nan[:, 0] + 1j*Y_nan[:, 1]

# flatten() always copies. ravel() returns a view when test_data_r is
# contiguous, so filling through it would overwrite the NaNs in test_data_r
# itself and the cubic-interpolation comparison would have nothing to fill.
test_data_r_filled = test_data_r.flatten(order='C')
test_data_r_filled[nans_a] = Y_nan
test_data_r_filled = test_data_r_filled.reshape(test_data_r.shape)

In [None]:
# compare to 2D cubic interpolation
# (nan_interp2d is called with its default `kind` here - presumably cubic; confirm in robstat.ml)
test_data_r_ci = nan_interp2d(test_data_r)

In [None]:
from mpl_toolkits.axes_grid1 import AxesGrid

# 2 x 2 grid: GP-filled Re/Im (left column) and GP - cubic residuals (right column)
fig = plt.figure(figsize=(7, 5), dpi=600)

grid = AxesGrid(fig, 111, nrows_ncols=(2, 2), axes_pad=0.4, share_all=True, cbar_location='right', \
                cbar_mode='each', cbar_size=0.15, cbar_pad=0.15, direction='row', aspect=False)

# GP
im0 = grid[0].imshow(test_data_r_filled.real.T, extent=extent_r, aspect='auto')
im2 = grid[2].imshow(test_data_r_filled.imag.T, extent=extent_r, aspect='auto')

# resid
r1 = test_data_r_filled.real.T - test_data_r_ci.real.T
r2 = test_data_r_filled.imag.T - test_data_r_ci.imag.T
# symmetric colour limits about zero. NaN-aware, because the interpolated array
# may still hold NaN and a plain max would then give NaN limits.
vmax1 = np.nanmax(np.abs(r1))
vmin1 = -vmax1
vmax2 = np.nanmax(np.abs(r2))
vmin2 = -vmax2
im1 = grid[1].imshow(r1, extent=extent_r, aspect='auto', cmap='bwr', vmin=vmin1, vmax=vmax1)
im3 = grid[3].imshow(r2, extent=extent_r, aspect='auto', cmap='bwr', vmin=vmin2, vmax=vmax2)

cb0 = grid.cbar_axes[0].colorbar(im0)
cb0.ax.set_title(r'$\mathfrak{Re}(V)$', size=8)
cb2 = grid.cbar_axes[2].colorbar(im2)
cb2.ax.set_title(r'$\mathfrak{Im}(V)$', size=8)

cb1 = grid.cbar_axes[1].colorbar(im1)
cb1.ax.set_title(r'$\Delta \mathfrak{Re}(V)$', size=8)
cb3 = grid.cbar_axes[3].colorbar(im3)
cb3.ax.set_title(r'$\Delta \mathfrak{Im}(V)$', size=8)

grid[0].set_ylabel('Time Integration', size=8)
grid[2].set_ylabel('Time Integration', size=8)

grid[2].set_xlabel('Frequency [MHz]', size=8)
grid[3].set_xlabel('Frequency [MHz]', size=8)

grid[0].set_title('GP interpolation', size=8)
grid[1].set_title('GP - CI Residual', size=8)

plt.show()

In [None]:
# GP-filled vs cubic-filled data, Re on the top row and Im on the bottom row
fig, axes = plt.subplots(ncols=2, nrows=2, sharex=True, sharey=True, constrained_layout=True)

# colour limits shared by each row. NaN-aware, because the interpolated arrays
# may still hold NaN and plain min/max would then give NaN limits.
vmin1 = np.nanmin([np.nanmin(test_data_r_filled.real), np.nanmin(test_data_r_ci.real)])
vmax1 = np.nanmax([np.nanmax(test_data_r_filled.real), np.nanmax(test_data_r_ci.real)])
vmin2 = np.nanmin([np.nanmin(test_data_r_filled.imag), np.nanmin(test_data_r_ci.imag)])
vmax2 = np.nanmax([np.nanmax(test_data_r_filled.imag), np.nanmax(test_data_r_ci.imag)])

# GP
im1 = axes[0][0].imshow(test_data_r_filled.real.T, extent=extent_r, aspect='auto', vmin=vmin1, vmax=vmax1)
im2 = axes[1][0].imshow(test_data_r_filled.imag.T, extent=extent_r, aspect='auto', vmin=vmin2, vmax=vmax2)

# cubic interpolation
axes[0][1].imshow(test_data_r_ci.real.T, extent=extent_r, aspect='auto', vmin=vmin1, vmax=vmax1)
axes[1][1].imshow(test_data_r_ci.imag.T, extent=extent_r, aspect='auto', vmin=vmin2, vmax=vmax2)

for ax, col in zip(axes[0], ['GP Interpolation', 'Cubic Interpolation']):
    ax.set_title(col, size=8)

for ax in axes:
    ax[0].set_ylabel('Time Integration', size=8)

axes[1][0].set_xlabel('Frequency [MHz]', size=8)
axes[1][1].set_xlabel('Frequency [MHz]', size=8)

cb1 = fig.colorbar(im1, ax=axes[0].ravel(), label=r'$\mathfrak{Re}(V)$', pad=0.04, aspect=15)
cb2 = fig.colorbar(im2, ax=axes[1].ravel(), label=r'$\mathfrak{Im}(V)$', pad=0.04, aspect=15)

plt.show()

##### Performance against cubic interpolation

In [None]:
# sample non-nan data that has been median/HERA mean averaged across JDs
# channels (indices along axis 1 of xd_data_bls) that are NaN on every day for
# at least one time integration of the first baseline
nn_chan_idxs = np.unique(np.where(np.isnan(xd_data_bls[..., 0]).all(axis=(0)))[0])
if nn_chan_idxs.size != 0:
    # Bracket the flagged channels with sentinels just outside the band, so that
    # clean runs touching a band edge are candidates as well and a single
    # flagged channel no longer leaves argmax with an empty array.
    chan_edges = np.concatenate(([-1], nn_chan_idxs, [xd_data_bls.shape[1]]))
    largest_gap = np.argmax(np.ediff1d(chan_edges))
    # channels strictly between chan_se[0] and chan_se[1] form the longest clean run
    chan_se = chan_edges[[largest_gap, largest_gap+1]]

#     test_data_m = np.nanmedian(xd_data_bls[:, chan_se[0]+1:chan_se[1], \
#                                            :, 0], axis=0)
    test_data_m = rsc_mean(xd_data_bls[:, chan_se[0]+1:chan_se[1], :, 0], sigma=5, min_N=5, axis=0)
    # chan_se indexes the channel axis of xd_data_bls, which corresponds to
    # `freqs` (flt_freqs may have had channels removed at its extremities)
    flt_freqs_se = freqs[chan_se[0]+1:chan_se[1]]
else:
#     test_data_m = np.nanmedian(xd_data_bls[..., 0], axis=0)
    test_data_m = rsc_mean(xd_data_bls[..., 0], sigma=5, min_N=5, axis=0)
    flt_freqs_se = freqs

if restrict:
    test_data_m = test_data_m[:ridx1, :ridx2]

# frequencies of the selected channel run, truncated like the data
flt_freqs_m = flt_freqs_se[:test_data_m.shape[0]]

In [None]:
# stack the Re/Im panels vertically when there are at least as many channels as
# time integrations, otherwise put them side by side
if test_data_m.shape[0] >= test_data_m.shape[1]:
    ncols = 1
    nrows = 2
    sharex = True
    sharey = False
else:
    ncols = 2
    nrows = 1
    sharex = False
    sharey = True

fig, axes = plt.subplots(ncols=ncols, nrows=nrows, sharex=sharex, sharey=sharey, \
                         constrained_layout=True)

extent_m = [flt_freqs_m.min()/1e6, flt_freqs_m.max()/1e6, test_data_m.shape[1]-1, 0]
# keep the image handles: the colorbars must describe these images, not
# mappables left over from an earlier figure
im1 = axes[0].imshow(test_data_m.real.T, extent=extent_m, aspect='auto')
im2 = axes[1].imshow(test_data_m.imag.T, extent=extent_m, aspect='auto')

axes[0].set_title(real_lab, size=8)
axes[1].set_title(imag_lab, size=8)
axes[0].set_ylabel('Time Integration', size=8)
axes[1].set_xlabel('Frequency [MHz]', size=8)
if sharex:
    axes[1].set_ylabel('Time Integration', size=8)
else:
    axes[0].set_xlabel('Frequency [MHz]', size=8)

cb1 = fig.colorbar(im1, ax=axes[0], label=r'$\mathfrak{Re}(V)$', pad=0.04, aspect=15)
cb2 = fig.colorbar(im2, ax=axes[1], label=r'$\mathfrak{Im}(V)$', pad=0.04, aspect=15)

plt.show()

In [None]:
# artificially add holes in data: ~10% random samples plus ~10% of the
# channels and ~10% of the time integrations flagged entirely
n_chans_m, n_tints_m = test_data_m.shape[0], test_data_m.shape[1]
flg_no_point = test_data_m.size//10
flg_no_chans = n_chans_m//10
flg_no_tints = n_tints_m//10

test_data_m_n = test_data_m.copy()

np.random.seed(0)  # reproducible flag pattern
# random flagged slices
rnd_idx = np.sort(np.random.choice(np.arange(test_data_m.size), flg_no_point, replace=False))
rnd_idxs = np.unravel_index(rnd_idx, test_data_m.shape)
# multiply (rather than assign) so that both Re and Im become NaN
test_data_m_n[rnd_idxs] *= np.nan

# random flagged freq rows and tint cols
flg_chans = np.sort(np.random.choice(np.arange(n_chans_m), flg_no_chans, replace=False))
flg_tints = np.sort(np.random.choice(np.arange(n_tints_m), flg_no_tints, replace=False))

# never flag the first/last channel or time integration entirely
flg_chans = flg_chans[(flg_chans != 0) & (flg_chans != n_chans_m - 1)]
flg_tints = flg_tints[(flg_tints != 0) & (flg_tints != n_tints_m - 1)]

test_data_m_n[flg_chans, :] *= np.nan
test_data_m_n[:, flg_tints] *= np.nan

In [None]:
# randomly nan'd data points
fig, axes = plt.subplots(ncols=ncols, nrows=nrows, constrained_layout=True)

# keep the image handles: the colorbars must describe these images, not
# mappables left over from an earlier figure
im1 = axes[0].imshow(test_data_m_n.real.T, extent=extent_m, aspect='auto', interpolation='none')
im2 = axes[1].imshow(test_data_m_n.imag.T, extent=extent_m, aspect='auto', interpolation='none')

axes[0].set_title(real_lab, size=8)
axes[1].set_title(imag_lab, size=8)
axes[0].set_ylabel('Time Integration', size=8)
axes[1].set_xlabel('Frequency [MHz]', size=8)
if sharex:
    axes[1].set_ylabel('Time Integration', size=8)
else:
    axes[0].set_xlabel('Frequency [MHz]', size=8)

cb1 = fig.colorbar(im1, ax=axes[0], label=r'$\mathfrak{Re}(V)$', pad=0.04, aspect=15)
cb2 = fig.colorbar(im2, ax=axes[1], label=r'$\mathfrak{Im}(V)$', pad=0.04, aspect=15)

plt.show()

In [None]:
# what filter_half_widths window for time?

# LST axis of the averaged data, scaled by 12/pi
# (radians -> hours, presumably - TODO confirm the units stored in the file)
lsts_tr = lsts[:test_data_m.shape[1]]*12/np.pi
lst_step = np.median(np.ediff1d(lsts_tr))

# two-sided periodogram along time for channel row 10
# NOTE(review): fs is in 1/[units of lsts_tr], so the 'Hz' axis labels below
# only hold if lsts_tr is in seconds - confirm
dlys, pspec = signal.periodogram(test_data_m[10, :], fs=1/lst_step, \
    window='blackmanharris', scaling='spectrum', nfft=None, detrend=False, \
    return_onesided=False)

# sort into ascending order of the returned sample frequencies
delay_sort = np.argsort(dlys)
dlys, td_pspec = dlys[delay_sort], pspec[delay_sort]

fig, ax = plt.subplots(figsize=(7, 3))

ax.plot(dlys, td_pspec, alpha=0.8)

ax.set_xlabel(r'Fringe-rate [Hz]')
ax.set_ylabel(r'Power Spectrum [Jy$^2$ Hz$^2$]')
ax.set_yscale('log')

plt.tight_layout()
plt.show()

In [None]:
# cubic interpolation
# note that if a flagged data point is in a corner, this method cannot extrapolate to that point
test_data_m_n_ci = nan_interp2d(test_data_m_n, kind='cubic')

# CLEAN interpolation (2D: frequency and time)
# flagged samples are zeroed in a copy of the data and given zero weight
flags_2d = np.isnan(test_data_m_n)
data_2d = test_data_m_n.copy()
data_2d[flags_2d] = 0.
data_2d_tr = data_2d  # same array object, not a copy
wgts = (~flags_2d).astype(float) # real weights where flagged data has 0 weight

# parameters
filter_centers = [[0.], [0.]] # center of rectangular fourier regions to filter
filter_half_widths = [[2e-6], [20]] # half-width of rectangular fourier regions to filter
mode = 'clean'

# coordinate axes of the two filtered dimensions
freqs_tr = flt_freqs_m
x = [freqs_tr, lsts_tr]

d_mdl_tr, _, info = uvtools.dspec.fourier_filter(x, data_2d_tr, wgts, \
    filter_centers, filter_half_widths, mode, filter_dims=(0, 1), skip_wgt=0., \
    zero_residual_flags=True)

In [None]:
# GP interpolation
# feature matrix: one (frequency, time index) row per sample, ordered like
# test_data_m_n.ravel(order='C')
X_us = np.array(np.meshgrid(flt_freqs_se[:test_data_m.shape[0]], \
                            np.arange(test_data_m_n.shape[1]))).T.reshape(-1, 2)
scaler = preprocessing.StandardScaler().fit(X_us)
X_s = scaler.transform(X_us)

Y = decomposeCArray(test_data_m_n.ravel(order='C'))

# train on the samples that are not flagged
nans_a = np.isnan(test_data_m_n.ravel(order='C'))
X, Y = X_s[~nans_a, :], Y[~nans_a, :]

const_min = 2e2
const_max = 2e3
# crude starting value for the noise level, taken from the data being fitted
# here (not from the single-day slice used earlier) and clipped so that it lies
# inside the bounds handed to the optimizer
noise_bounds = (1e-1, 2e1)
noise_est = np.nanstd(test_data_m_n)
noise_init = float(np.clip(noise_est, *noise_bounds))

# kernel = amplitude * RBF + white noise (length scale in standardized units)
kernel = gp.kernels.ConstantKernel(constant_value=5e2, \
         constant_value_bounds=(const_min, const_max)) * \
         gp.kernels.RBF(length_scale=1e1, length_scale_bounds=(1e-1, 2e1)) + \
         gp.kernels.WhiteKernel(noise_level=noise_init, noise_level_bounds=noise_bounds)

model_i = gp.GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=25, normalize_y=False)
model_i.fit(X, Y)
print(model_i.kernel_)

# predict at the flagged samples only, then recombine the columns as Re + i*Im
y_ir_pred, c_std = model_i.predict(X_s[nans_a, :], return_std=True)

pdata = (y_ir_pred[:, 0] + 1j*y_ir_pred[:, 1])

In [None]:
# insert the GP predictions at the flagged positions; a boolean mask visits the
# elements in row-major order, matching the order in which pdata was predicted
hole_mask = np.isnan(test_data_m_n)
test_data_m_n_gpi = test_data_m_n.copy()
test_data_m_n_gpi[hole_mask] = pdata

In [None]:
# results
fig, axes = plt.subplots(ncols=2, nrows=2, sharex=True, sharey=True, constrained_layout=True, \
                         figsize=(7, 5), dpi=600)

im_kwargs = dict(extent=extent_m, aspect='auto', interpolation='none', rasterized=True)

# colour limits shared by each row so that one colorbar is valid for both
# panels; the GP-filled array contains every unflagged sample of the holed data
re_lim = dict(vmin=np.nanmin(test_data_m_n_gpi.real), vmax=np.nanmax(test_data_m_n_gpi.real))
im_lim = dict(vmin=np.nanmin(test_data_m_n_gpi.imag), vmax=np.nanmax(test_data_m_n_gpi.imag))

# actual data
axes[0][0].imshow(test_data_m_n.real.T, **re_lim, **im_kwargs)
axes[1][0].imshow(test_data_m_n.imag.T, **im_lim, **im_kwargs)

# GP (keep the handles: the colorbars must describe the images of this figure,
# not mappables left over from an earlier one)
im1 = axes[0][1].imshow(test_data_m_n_gpi.real.T, **re_lim, **im_kwargs)
im2 = axes[1][1].imshow(test_data_m_n_gpi.imag.T, **im_lim, **im_kwargs)

for ax, col in zip(axes[0], ['Sample Data', 'GP Interpolation']):
    ax.set_title(col)

for ax in axes:
    ax[0].set_ylabel('Time Integration')

axes[1][0].set_xlabel('Frequency [MHz]')
axes[1][1].set_xlabel('Frequency [MHz]')

cb1 = fig.colorbar(im1, ax=axes[0].ravel(), label=r'$\mathfrak{Re}(V)$', pad=0.025, aspect=15)
cb2 = fig.colorbar(im2, ax=axes[1].ravel(), label=r'$\mathfrak{Im}(V)$', pad=0.025, aspect=15)

# save_fig_dir = '/lustre/aoc/projects/hera/mmolnar/figs'
# plt.savefig(os.path.join(save_fig_dir, 'gp_vis_grid.pdf'), bbox_inches='tight')

plt.show()

In [None]:
# residuals for the flagged data points
# absolute difference between the averaged data and each interpolation at the
# randomly flagged samples
fig, ax = plt.subplots()
ci_res = np.abs(test_data_m[rnd_idxs] - test_data_m_n_ci[rnd_idxs])
gpi_res = np.abs(test_data_m[rnd_idxs] - test_data_m_n_gpi[rnd_idxs])
ax.plot(ci_res, label='Cubic interp', lw=1.5, alpha=0.7)
ax.plot(gpi_res, label='GP interp', lw=1.5, alpha=0.7, color='orange')
# mean residuals as dashed lines; NaN-aware for both methods so that the two
# averages are computed the same way
ax.axhline(np.nanmean(ci_res), ls='--')
ax.axhline(np.nanmean(gpi_res), ls='--', color='orange')
ax.set_xlabel('Flagged sample')
ax.set_ylabel(r'$\left| \Delta V \right|$')
ax.legend(loc='upper right', prop={'size': 8})
fig.tight_layout()
plt.show()

In [None]:
# CLEAN interpolation (1D: along the frequency axis only)
# flagged samples are zeroed in a copy of the data and given zero weight
flags_2d = np.isnan(test_data_m_n)
data_2d = test_data_m_n.copy()
data_2d[flags_2d] = 0.
data_2d_tr = data_2d  # same array object, not a copy
wgts = (~flags_2d).astype(float)  # real weights where flagged data has 0 weight

# parameters
filter_centers = [0.]# [[0.], [0.]] # center of rectangular fourier regions to filter
filter_half_widths = [20e-6]# [[5e-6], [20]] # half-width of rectangular fourier regions to filter
mode = 'clean'

# coordinate axis of the filtered dimension
freqs_tr = flt_freqs_m
x = freqs_tr#[freqs_tr, lsts_tr]

# can also try 2D clean
d_mdl_tr, _, info = uvtools.dspec.fourier_filter(x, data_2d_tr, wgts, \
    filter_centers, filter_half_widths, mode, filter_dims=(0), skip_wgt=0., \
    zero_residual_flags=True)

In [None]:
# cross-PS between all time integrations pairs

# don't want fully flagged time integrations
# ps_data = np.delete(test_data_m, flg_tints, axis=1)
# CLEAN-filled data: flagged samples replaced by the CLEAN model
hole_mask = np.isnan(test_data_m_n)
ps_data = test_data_m_n.copy()
ps_data[hole_mask] = d_mdl_tr[hole_mask]
ps_data = np.delete(ps_data, flg_tints, axis=1)
# GP-filled data with the same time integrations removed
gpi_ps_data = np.delete(test_data_m_n_gpi, flg_tints, axis=1)

# every ordered pair of remaining time integrations (self-pairs included)
tint_pairs = list(itertools.product(np.arange(ps_data.shape[1]), repeat=2))
tints1 = [pair[0] for pair in tint_pairs]
tints2 = [pair[1] for pair in tint_pairs]

# settings shared by both cross-spectrum estimates (transform along frequency)
csd_kwargs = dict(fs=1/f_resolution, window='blackmanharris', scaling='spectrum', \
                  nfft=None, detrend=False, return_onesided=False, axis=0)

delay, pspec = signal.csd(ps_data[..., tints1], ps_data[..., tints2], \
    nperseg=ps_data.shape[0], **csd_kwargs)

# two-sided output: reorder into ascending delay
delay_sort = np.argsort(delay)
delay, pspec = delay[delay_sort], pspec[delay_sort, :]

gpi_delay, gpi_pspec = signal.csd(gpi_ps_data[..., tints1], gpi_ps_data[..., tints2], \
    nperseg=gpi_ps_data.shape[0], **csd_kwargs)

delay_sort = np.argsort(gpi_delay)
gpi_delay, gpi_pspec = gpi_delay[delay_sort], gpi_pspec[delay_sort, :]

# average over all pairs
pspec_mean = np.nanmean(pspec, axis=1)
gpi_pspec_mean = np.nanmean(gpi_pspec, axis=1)

In [None]:
# delay power spectra: CLEAN-filled, GP-filled, and their means compared
fig, axes = plt.subplots(ncols=3, figsize=(7, 4), sharey=True, dpi=300)
clean_ax, gp_ax, cmp_ax = axes
delay_us = delay*1e6  # delays scaled by 1e6 for the microsecond axis

# individual pair spectra (faint) with their mean on top
clean_ax.plot(delay_us, np.abs(pspec), alpha=0.3, rasterized=True)
clean_ax.plot(delay_us, np.abs(pspec_mean), alpha=1, color='orange')
clean_ax.set_ylabel('Power Spectrum [Jy$^2$ Hz$^2$]')

gp_ax.plot(delay_us, np.abs(gpi_pspec), alpha=0.3, rasterized=True)
gp_ax.plot(delay_us, np.abs(gpi_pspec_mean), alpha=1, color='purple')

# the two means overlaid
cmp_ax.plot(delay_us, np.abs(pspec_mean), alpha=0.6, color='orange', label='CLEAN')
cmp_ax.plot(delay_us, np.abs(gpi_pspec_mean), alpha=0.6, color='purple', label='GP')

panel_titles = ['CLEAN', 'GP', 'Comparison']
for ax, panel_title in zip(axes, panel_titles):
    ax.set_yscale('log')
    ax.set_xlabel(r'Delay [$\mu$s]')
    ax.set_xticks([-5, -2.5, 0, 2.5, 5])
    ax.set_title(panel_title)

cmp_ax.legend(loc='lower center', prop={'size': 8})

fig.tight_layout()
# save_fig_dir = '/lustre/aoc/projects/hera/mmolnar/figs'
# plt.savefig(os.path.join(save_fig_dir, 'gp_ps.pdf'), bbox_inches='tight')
plt.show()