<center><strong><font size=+3>Gaussian Process Interpolation of HERA data</font></strong></center>
<br><br>
<center><strong><font size=+2>Matyas Molnar and Bojan Nikolic</font><br></strong></center>
<br><center><strong><font size=+1>Astrophysics Group, Cavendish Laboratory, University of Cambridge</font></strong></center>

In [None]:
import itertools
import os

import matplotlib as mpl
import numpy as np
from matplotlib import cm, colors, ticker
from matplotlib import pyplot as plt
from scipy import signal
from scipy.interpolate import griddata, interp1d
from sklearn import gaussian_process as gp
from sklearn import preprocessing

from robstat.ml import extrem_nans, nan_interp2d
from robstat.utils import DATAPATH, decomposeCArray

In [None]:
# Render matplotlib figures as static images inside the notebook output cells.
%matplotlib inline

In [None]:
# Notebook-wide figure defaults: resolution in dots per inch and (width, height) in inches.
mpl.rcParams.update({
    'figure.dpi': 175,
    'figure.figsize': (5, 3),
})

### Load HERA dataset

In [None]:
# Location of the sample visibility archive inside robstat's DATAPATH directory.
xd_vis_file = os.path.join(DATAPATH, 'xd_vis_extd_rph.npz')
# np.load on an .npz returns a dict-like archive; individual arrays are read on key access.
sample_xd_data = np.load(xd_vis_file)

In [None]:
# Spectral axis: frequencies and the matching channel numbers
freqs = sample_xd_data['freqs']
chans = sample_xd_data['chans']

# Channel width, taken as the median spacing between consecutive frequencies
f_resolution = np.median(np.ediff1d(freqs))

# Observation metadata
JDs = sample_xd_data['JDs']
xd_pol = sample_xd_data['pol'].item()
xd_times = sample_xd_data['times']
xd_redg = sample_xd_data['redg']

# Visibilities and their flags
xd_flags = sample_xd_data['flags']
xd_data = sample_xd_data['data']  # dimensions (days, freqs, times, bls)

# Blank flagged samples. Multiplying by NaN (instead of assigning it) makes
# both the real and the imaginary component NaN.
xd_data[xd_flags] *= np.nan

In [None]:
bl_grp = 0  # only look at 0th baseline group

# Column 0 of xd_redg is compared against the group label; the remaining
# columns of a row are what gets printed as the reference baseline below.
slct_bl_idxs = np.flatnonzero(xd_redg[:, 0] == bl_grp)
no_bls = slct_bl_idxs.size
slct_red_bl = xd_redg[slct_bl_idxs[0], 1:]

# Restrict visibilities and flags to the selected group (last axis = baselines)
xd_data_bls = xd_data[..., slct_bl_idxs]
data = xd_data[..., slct_bl_idxs]
flags = xd_flags[..., slct_bl_idxs]

bl_msg = 'Looking at baselines redundant to ({}, {}, \'{}\')'
print(bl_msg.format(*slct_red_bl, xd_pol))

### Format & clean dataset

In [None]:
# Drop every baseline whose samples are NaN for all days, frequencies and times
all_nan_bl = np.isnan(xd_data_bls).all(axis=(0, 1, 2))
nan_bls = np.flatnonzero(all_nan_bl)
xd_data_bls = np.delete(xd_data_bls, nan_bls, axis=3)
no_bls -= nan_bls.size

In [None]:
# Share of NaN (flagged) samples in the selected baselines, as a percentage
nan_fraction = np.isnan(xd_data_bls).sum() / xd_data_bls.size
print(round(nan_fraction * 100, 3))

In [None]:
# Count NaNs per (day, baseline) pair and keep the (freqs, times) slice with the fewest
sum_nans = np.isnan(xd_data_bls).sum(axis=(1, 2))
ok = np.unravel_index(np.argmin(sum_nans), sum_nans.shape)

best_day, best_bl = ok
test_data = xd_data_bls[best_day, :, :, best_bl]

In [None]:
# Start from the unfiltered axes so that flt_chans / flt_freqs / flt_times are
# always defined, even when there is nothing to trim at the edges.
flt_chans = chans
flt_freqs = freqs
flt_times = xd_times

# remove frequencies at extremities with only nan entries
# (extrem_nans is a robstat helper; presumably it returns the indices of the
# leading/trailing True entries of the mask - confirm against robstat.ml)
nan_chans = extrem_nans(np.isnan(test_data).all(axis=1))
if nan_chans.size != 0:
    flt_chans = np.delete(chans, nan_chans)
    flt_freqs = np.delete(freqs, nan_chans)
    test_data = np.delete(test_data, nan_chans, axis=0)

# remove time integrations at extremities with only nan entries
nan_tints = extrem_nans(np.isnan(test_data).all(axis=0))
if nan_tints.size != 0:
    flt_times = np.delete(xd_times, nan_tints)
    test_data = np.delete(test_data, nan_tints, axis=1)

### Gaussian Process regression

#### 1D real valued GP

In [None]:
# Visibility amplitude across frequency for the first time integration
sample_data = np.abs(test_data[:, 0])
nans = np.isnan(sample_data)
print('Flagged data points at channels {}'.format(flt_chans[np.flatnonzero(nans)].tolist()))


def nan_loc(z):
    """Return the indices of the non-zero (True) entries of a 1D array."""
    return z.nonzero()[0]


# interpolate nan values just for this plot and for noise estimate
f_vis = interp1d(flt_freqs[~nans], sample_data[~nans], kind='cubic')
sample_data_i = sample_data.copy()
sample_data_i[nans] = f_vis(flt_freqs[nans])

fig, ax = plt.subplots()
ax.scatter(flt_freqs, sample_data, s=3)
ax.scatter(flt_freqs[nans], sample_data_i[nans], s=3, label='interp values')
ax.set_xlabel('Frequency')
ax.set_ylabel(r'$\left| V \right|$')
ax.legend(loc='upper right', prop={'size': 6})
fig.tight_layout()
plt.show()

In [None]:
# quick & dirty noise estimate - can get better values from the data files:
# variance of the absolute channel-to-channel differences of the gap-filled amplitudes
chan_to_chan = np.abs(np.ediff1d(sample_data_i))
noise_dest = np.var(chan_to_chan)
print('Dirty noise estimate: {}'.format(round(noise_dest, 3)))

In [None]:
# Bounds for the constant (amplitude) kernel: the decade below the smallest
# amplitude up to one decade above the largest one.
valid_amp = sample_data[~nans]
const_min = 10**np.floor(np.log10(valid_amp.min()))
const_max = 10**(np.ceil(np.log10(valid_amp.max()))+1)
# NOTE(review): the initial constant_value is the squared mean amplitude while
# the bounds are derived from un-squared amplitudes - confirm this is intended.

# Kernel = RBF (smooth spectral structure) * constant (signal variance) + white noise
kernel = gp.kernels.RBF(length_scale=f_resolution*20, length_scale_bounds=(1e5, 1e8)) * \
         gp.kernels.ConstantKernel(constant_value=np.nanmean(sample_data)**2, \
                                   constant_value_bounds=(const_min, const_max)) + \
         gp.kernels.WhiteKernel(noise_level=5e0, noise_level_bounds=(1e-1, 2e+1))

# random_state pins the random restarts of the hyperparameter optimiser so that
# Restart & Run All reproduces the same fitted kernel.
model = gp.GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=100, normalize_y=False,
                                    random_state=0)

# Fit on the unflagged channels only; sklearn expects a 2D feature matrix
model.fit(flt_freqs[~nans].reshape(-1, 1), valid_amp)
print(model.kernel_)

In [None]:
# The fitted kernel is (RBF * Constant) + White, so k1.k1 is the RBF factor
model.kernel_.k1.k1.get_params()

In [None]:
# Dense frequency grid spanning the retained band, with predictive mean and std
freq_pred = np.linspace(flt_freqs[0], flt_freqs[-1], 1000)
y_pred, std = model.predict(freq_pred.reshape(-1, 1), return_std=True)

In [None]:
fig, ax = plt.subplots()
# 3-, 2- and 1-sigma predictive bands, widest first so narrower ones draw on top
for n_sigma, shade in zip((3, 2, 1), ('lightgray', 'darkgray', 'gray')):
    ax.fill_between(freq_pred, y_pred-n_sigma*std, y_pred+n_sigma*std, color=shade)
ax.scatter(flt_freqs, sample_data, s=2, color='orange', zorder=3)
ax.plot(freq_pred, y_pred, zorder=2)
ax.set_xlabel('Frequency')
ax.set_ylabel(r'$\left| V \right|$')
fig.tight_layout()
plt.show()

##### Explore hyperparameter space

In [None]:
# make a grid of different hyperparameter values to explore the marginal likelihood landscape

# Length scales: from one channel width up to the full retained bandwidth
lsv_log_min = np.log10(f_resolution)
lsv_log_max = np.log10(flt_freqs[-1] - flt_freqs[0])

# Noise levels: one decade either side of the dirty noise estimate
# (min is the lower bound, max the upper one, so nlv is ascending)
nlv_log_min = np.log10(noise_dest) - 1
nlv_log_max = np.log10(noise_dest) + 1

lsv = np.logspace(lsv_log_min, lsv_log_max, 100)
nlv = np.logspace(nlv_log_min, nlv_log_max, 100)
cv = np.nanmean(sample_data)**2 # fix this

# margloglik[i, j] is the log marginal likelihood for (lsv[i], nlv[j])
margloglik = np.empty((lsv.size, nlv.size))

# Training set is the same for every grid point, so build it once
X_grid_fit = flt_freqs[~nans].reshape(-1, 1)
y_grid_fit = sample_data[~nans]

for i, ls_val in enumerate(lsv):
    for j, nl_val in enumerate(nlv):
        kernel_ij = gp.kernels.RBF(length_scale=ls_val) * \
                    gp.kernels.ConstantKernel(constant_value=cv) + \
                    gp.kernels.WhiteKernel(noise_level=nl_val)

        # optimizer=None keeps the hyperparameters fixed at the grid values
        model_ij = gp.GaussianProcessRegressor(kernel=kernel_ij, optimizer=None)
        model_ij.fit(X_grid_fit, y_grid_fit)

        margloglik[i, j] = model_ij.log_marginal_likelihood()

In [None]:
# Coordinate grids for plotting: shape (nlv.size, lsv.size), hence the
# transposes of margloglik (which is indexed [length scale, noise level]).
L, N = np.meshgrid(lsv, nlv)

fig, ax = plt.subplots()

# Log-spaced contour levels of the negative marginal log-likelihood (NMLL).
# NOTE: this assumes the MLL is negative everywhere on the grid.
levels = np.logspace(np.log10(-margloglik.max()), np.log10(-margloglik.min()), 30)
cp = ax.contour(L, N, -margloglik.T, levels=levels, cmap='viridis')

# Mark the grid point with the highest marginal likelihood
idx_max = np.unravel_index(np.argmax(margloglik), margloglik.shape)
ax.scatter(lsv[idx_max[0]], nlv[idx_max[1]], color='orange')

# Continuous colorbar covering the contour levels. The mappable is not attached
# to any Axes, so the target Axes has to be passed explicitly.
norm = colors.Normalize(vmin=cp.cvalues.min(), vmax=cp.cvalues.max())
sm = plt.cm.ScalarMappable(norm=norm, cmap=cp.cmap)
sm.set_array([])  # required by older matplotlib releases, harmless on newer ones
cb = fig.colorbar(sm, ax=ax)
cb.ax.invert_yaxis()

ax.set_xscale('log')
ax.set_yscale('log')

ax.set_title('NMLL')
ax.set_xlabel('Length scale')
ax.set_ylabel('Noise level')

plt.tight_layout()
plt.show()

In [None]:
# cubic interpolation of MLL space to get finer detail
no_interp_points = 1000
l_log_i = np.logspace(lsv_log_min, lsv_log_max, no_interp_points)
n_log_i = np.logspace(nlv_log_min, nlv_log_max, no_interp_points)

# margloglik[i, j] belongs to (lsv[i], nlv[j]). Build the sample and query
# coordinates with 'ij' indexing so that their raveled order matches
# margloglik.ravel(); grid_z is then indexed [length scale, noise level] too.
L_i, N_i = np.meshgrid(l_log_i, n_log_i, indexing='ij')
L_src, N_src = np.meshgrid(lsv, nlv, indexing='ij')
grid_z = griddata((L_src.ravel(), N_src.ravel()), margloglik.ravel(), (L_i, N_i),
                  method='cubic', rescale=True)

In [None]:
fig, ax = plt.subplots(ncols=2, sharey=True)

# Both arrays are indexed [length scale, noise level]; imshow puts axis 0 on the
# vertical, so transpose to get length scale on x and noise level on y.
# The samples are log-spaced, so the extents are given in log10 units, where the
# pixel grid is uniform. Extents are taken from the first/last sample of each
# axis so the orientation is right whichever way the axes are ordered.
extent_orig = np.log10([lsv[0], lsv[-1], nlv[-1], nlv[0]])
extent_interp = np.log10([l_log_i[0], l_log_i[-1], n_log_i[-1], n_log_i[0]])

ax[0].imshow(-margloglik.T, cmap='viridis', interpolation=None, \
             extent=extent_orig, aspect='auto')
ax[1].imshow(-grid_z.T, cmap='viridis', interpolation=None, \
             extent=extent_interp, aspect='auto')
ax[0].set_title('Original points', size=9)
ax[1].set_title('Cubic interpolation', size=9)
ax[0].set_xlabel(r'$\log_{10}$ length scale', size=8)
ax[1].set_xlabel(r'$\log_{10}$ length scale', size=8)
ax[0].set_ylabel(r'$\log_{10}$ noise level', size=8)
plt.tight_layout()
plt.show()

In [None]:
# hyperparameters with highest MLLs
# Axis 0 of grid_z follows the length scale and axis 1 the noise level.
# nanargmax skips any NaN that cubic griddata returns for query points that
# land just outside the convex hull of the samples.
interp_max_idx = np.unravel_index(np.nanargmax(grid_z), grid_z.shape)
print('{:e} {:e}'.format(l_log_i[interp_max_idx[0]], n_log_i[interp_max_idx[1]]))

In [None]:
# Compare the grid optimum printed above with the hyperparameters of the
# kernel fitted by the optimiser (shown as the cell output).
model.kernel_

In [None]:
%matplotlib notebook
fig = plt.figure(figsize=(8, 6))
ax = plt.axes(projection='3d')
ax.plot_surface(L, N, -margloglik.T, cmap='plasma_r')
ax.set_xlabel('Length scale')
ax.set_ylabel('Noise level')
ax.set_zlabel('NMLL', rotation=90)
plt.tight_layout()
plt.show()

In [None]:
# back to original settings
%matplotlib inline
mpl.rcParams['figure.dpi'] = 175
mpl.rcParams['figure.figsize'] = (5, 3)

#### Complex GP

In [None]:
# Split the complex visibilities of the first time integration into two real
# target columns (used below as column 0 = real part, column 1 = imaginary part).
c_tdata = decomposeCArray(test_data[:, 0]) # multiple target for complex numbers
# Flag mask of that slice, taken from the NaNs of its amplitude
c_nans = np.isnan(sample_data)

In [None]:
# Bounds for the constant kernel: lower bound is the decade at or below
# max(smallest component, 1); upper bound is one decade above the largest component.
const_min = 10**np.floor(np.log10(np.max([np.nanmin(c_tdata), 1])))
const_max = 10**(np.ceil(np.log10(np.nanmax(c_tdata)))+1)

# Same kernel structure as the real-valued fit: RBF * constant + white noise
kernel = gp.kernels.RBF(length_scale=f_resolution*20, length_scale_bounds=(1e5, 1e8)) * \
         gp.kernels.ConstantKernel(constant_value=np.nanmean(c_tdata)**2, \
                                   constant_value_bounds=(const_min, const_max)) + \
         gp.kernels.WhiteKernel(noise_level=5e0, noise_level_bounds=(1e-1, 2e+1))

# random_state pins the random optimiser restarts so the fit is reproducible
model = gp.GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=100, normalize_y=False,
                                    random_state=0)

# Two-column target: both components are fitted with shared hyperparameters
model.fit(flt_freqs[~nans].reshape(-1, 1), c_tdata[~nans])
print(model.kernel_)

In [None]:
# Predict both target columns on the dense frequency grid, then recombine
# them into a single complex prediction (column 0 + 1j * column 1).
pred_parts, c_std = model.predict(freq_pred.reshape(-1, 1), return_std=True)
y_c_pred = pred_parts[:, 0] + 1j*pred_parts[:, 1]

In [None]:
plt.figure()

real_lab = r'$\mathfrak{Re} \; (V)$'
imag_lab = r'$\mathfrak{Im} \; (V)$'

# Use the predictive std of the complex model (c_std), not the std of the
# earlier amplitude-only model. Depending on the scikit-learn version the std
# of a multi-target fit is (n_samples,) or (n_samples, n_targets).
c_std_arr = np.asarray(c_std)
if c_std_arr.ndim == 2:
    re_std, im_std = c_std_arr[:, 0], c_std_arr[:, 1]
else:
    re_std = im_std = c_std_arr

sigma_shades = ((3, 'lightgray'), (2, 'darkgray'), (1, 'gray'))

# real
for n_sigma, shade in sigma_shades:
    plt.fill_between(freq_pred, y_c_pred.real-n_sigma*re_std, y_c_pred.real+n_sigma*re_std,
                     color=shade, alpha=0.5)
plt.plot(freq_pred, y_c_pred.real, zorder=2, color='orange', label=real_lab)
plt.scatter(flt_freqs[~nans], c_tdata[:, 0][~nans], s=2, color='blue', zorder=3)

# imag
for n_sigma, shade in sigma_shades:
    plt.fill_between(freq_pred, y_c_pred.imag-n_sigma*im_std, y_c_pred.imag+n_sigma*im_std,
                     color=shade, alpha=0.5)
plt.plot(freq_pred, y_c_pred.imag, zorder=2, color='purple', label=imag_lab)
plt.scatter(flt_freqs[~nans], c_tdata[:, 1][~nans], s=2, color='red', zorder=3)

plt.xlabel('Frequency')
plt.legend(prop={'size': 6})
plt.tight_layout()
plt.show()

#### Add time as a feature

In [None]:
# Work on a small corner of the data (first 30 channels, first 20 time
# integrations) to keep the 2D GP fit tractable.
n_chan_r, n_tint_r = 30, 20
test_data_r = test_data[:n_chan_r, :n_tint_r]
flt_freqs_r = flt_freqs[:test_data_r.shape[0]]

In [None]:
# Stack the real/imag panels vertically when there are at least as many
# channels as time integrations, otherwise put them side by side.
if test_data_r.shape[0] >= test_data_r.shape[1]:
    nrows, ncols = 2, 1
    sharex, sharey = True, False
else:
    nrows, ncols = 1, 2
    sharex, sharey = False, True

fig, ax = plt.subplots(nrows=nrows, ncols=ncols, sharex=sharex, sharey=sharey)
# x: frequency scaled by 1e-6 (labelled MHz); y: time-integration index, increasing downwards
extent_r = [flt_freqs_r.min()/1e6, flt_freqs_r.max()/1e6, test_data_r.shape[1], 0]

for panel, component, title in zip(ax, (test_data_r.real, test_data_r.imag), (real_lab, imag_lab)):
    panel.imshow(component.T, extent=extent_r, aspect='auto', interpolation='none')
    panel.set_title(title, size=8)

ax[0].set_ylabel('Time integration', size=8)
ax[1].set_xlabel('Frequency (MHz)', size=8)
# Only the panel that does not share the axis needs its own label
if sharex:
    ax[1].set_ylabel('Time integration', size=8)
else:
    ax[0].set_xlabel('Frequency (MHz)', size=8)

fig.tight_layout()
plt.show()

In [None]:
# Unscaled features: one (frequency, time index) row per sample, frequency-major
# so that the row order matches test_data_r.ravel(order='C').
freq_axis_r = flt_freqs[:test_data_r.shape[0]]
tint_axis_r = np.arange(test_data_r.shape[1])
X_us = np.stack(np.meshgrid(freq_axis_r, tint_axis_r, indexing='ij'), axis=-1).reshape(-1, 2)

# Standardise both features (zero mean, unit variance)
scaler = preprocessing.StandardScaler().fit(X_us)
X_s = scaler.transform(X_us)

# Train on the unflagged samples only; the targets are the two real columns
# produced by decomposeCArray.
flat_vis_r = test_data_r.ravel(order='C')
nans_a = np.isnan(flat_vis_r)
X = X_s[~nans_a, :]
Y = decomposeCArray(flat_vis_r)[~nans_a, :]

In [None]:
# Bounds for the constant kernel: lower bound is the decade at or below
# max(smallest target, 1); upper bound is one decade above the largest target.
const_min = 10**np.floor(np.log10(np.nanmax([np.nanmin(Y), 1])))
const_max = 10**(np.ceil(np.log10(np.nanmax(Y)))+1)
# Initial length scale: mean difference between consecutive entries of the
# flattened, standardised feature matrix.
# NOTE(review): this can fall outside length_scale_bounds - confirm intended.
ls_init = np.mean(np.ediff1d(X.ravel()))
noise_est = np.nanstd(test_data_r)

# One isotropic RBF over the standardised (frequency, time) plane
kernel = gp.kernels.RBF(length_scale=ls_init, length_scale_bounds=(1e-2, 2e1)) * \
         gp.kernels.ConstantKernel(constant_value=np.nanmean(Y)**2, \
                                   constant_value_bounds=(const_min, const_max)) + \
         gp.kernels.WhiteKernel(noise_level=noise_est, noise_level_bounds=(1e-1, 1e+2))

# random_state pins the random optimiser restarts so the fit is reproducible
model = gp.GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=5, normalize_y=False,
                                    random_state=0)

model.fit(X, Y)
print(model.kernel_)

In [None]:
# Range of the standardised training features
f_min, f_max = X[:, 0].min(), X[:, 0].max()
t_min, t_max = X[:, 1].min(), X[:, 1].max()

# 100 x 100 prediction grid over that range; rows are frequency-major
f_grid = np.linspace(f_min, f_max, 100)
t_grid = np.linspace(t_min, t_max, 100)
dim_pred = np.stack(np.meshgrid(f_grid, t_grid, indexing='ij'), axis=-1).reshape(-1, 2)

In [None]:
# Predict both target columns, plus the predictive standard deviation, on the
# standardised prediction grid.
y_ir_pred, c_std = model.predict(dim_pred, return_std=True)

In [None]:
# Map the grid back to physical units: coords[i, j] = (frequency, time index)
dim_pred_itr = scaler.inverse_transform(dim_pred)
coords = dim_pred_itr.reshape(100, 100, -1)

# Recombine the two predicted columns into complex values on the same grid
pdata_flat = y_ir_pred[:, 0] + 1j*y_ir_pred[:, 1]
pdata = pdata_flat.reshape(100, 100)

In [None]:
# look at a time slice: GP prediction along frequency at the first grid time,
# against the data of the first time integration
fig, ax = plt.subplots()
pred_freqs = coords[:, 0, 0]
data_freqs = flt_freqs[:test_data_r.shape[0]]
ax.plot(pred_freqs, pdata.real[:, 0], color='orange', label=real_lab)
ax.plot(pred_freqs, pdata.imag[:, 0], color='purple', label=imag_lab)
ax.scatter(data_freqs, test_data_r[:, 0].real, s=2, color='blue')
ax.scatter(data_freqs, test_data_r[:, 0].imag, s=2, color='red')
ax.set_xlabel('Frequency')
ax.legend(prop={'size': 6})
fig.tight_layout()
plt.show()

In [None]:
# look at a frequency slice: GP prediction along time at the first grid
# frequency, against the data of the first channel
fig, ax = plt.subplots()
pred_tints = coords[0, :, 1]
data_tints = np.arange(test_data_r.shape[1])
ax.plot(pred_tints, pdata.real[0, :], color='orange', label=real_lab)
ax.plot(pred_tints, pdata.imag[0, :], color='purple', label=imag_lab)
ax.scatter(data_tints, test_data_r[0, :].real, s=2, color='blue')
ax.scatter(data_tints, test_data_r[0, :].imag, s=2, color='red')
ax.set_xlabel('Time integration')
ax.legend(prop={'size': 6})
fig.tight_layout()
plt.show()

In [None]:
# GP interpolation of nan data
# flatten() always copies, whereas ravel() may return a view; a view would let
# the fill below write into test_data_r and spoil the later comparison with
# cubic interpolation.
test_data_r_filled = test_data_r.flatten(order='C')

# Only predict when something is flagged: predict() rejects an empty input
if nans_a.any():
    Y_nan = model.predict(X_s[nans_a, :])
    Y_nan = Y_nan[:, 0] + 1j*Y_nan[:, 1]
    test_data_r_filled[nans_a] = Y_nan

test_data_r_filled = test_data_r_filled.reshape(test_data_r.shape)

In [None]:
# compare to 2D cubic interpolation
# (nan_interp2d is a robstat helper; presumably it returns a copy of the input
# with the NaN entries interpolated - confirm against robstat.ml)
test_data_r_ci = nan_interp2d(test_data_r)

In [None]:
fig, axes = plt.subplots(ncols=2, nrows=2, sharex=True, sharey=True)

# Columns: cubic interpolation (left) and GP (right); rows: real (top) and imaginary (bottom)
filled_sets = (test_data_r_ci, test_data_r_filled)
for col_idx, filled in enumerate(filled_sets):
    axes[0][col_idx].imshow(filled.real.T, extent=extent_r, aspect='auto')
    axes[1][col_idx].imshow(filled.imag.T, extent=extent_r, aspect='auto')

for ax, col in zip(axes[0], ['Cubic interpolation', 'GP interpolation']):
    ax.set_title(col, size=8)

for ax, col in zip(axes, [real_lab, imag_lab]):
    ax[0].set_ylabel(col+'\n\nTime integration', size=8)

for bottom_ax in axes[1]:
    bottom_ax.set_xlabel('Frequency (MHz)', size=8)

fig.tight_layout()
plt.show()

##### Performance against cubic interpolation

In [None]:
# sample non-nan data that has been median averaged across JDs

# Channels of baseline 0 that have at least one time integration flagged on every day
nn_chan_idxs = np.unique(np.where(np.isnan(xd_data_bls[..., 0]).all(axis=(0)))[0])
# The widest run of channels between two such channels is the cleanest band
largest_gap = np.argmax(np.ediff1d(nn_chan_idxs))
chan_se = nn_chan_idxs[[largest_gap, largest_gap+1]]

# Channel indices strictly inside that gap (indices along the unfiltered frequency axis)
chans_m = np.arange(chan_se[0]+1, chan_se[1])
test_data_m = np.nanmedian(xd_data_bls[:, chan_se[0]+1:chan_se[1], \
                                       :, 0], axis=0)

test_data_m = test_data_m[:30, :20]
# Frequencies of the channels actually kept: chans_m indexes the unfiltered
# frequency axis, so look them up in freqs rather than in flt_freqs.
flt_freqs_m = freqs[chans_m][:test_data_m.shape[0]]

In [None]:
# Stack the real/imag panels vertically when there are at least as many
# channels as time integrations, otherwise put them side by side.
if test_data_m.shape[0] >= test_data_m.shape[1]:
    nrows, ncols = 2, 1
    sharex, sharey = True, False
else:
    nrows, ncols = 1, 2
    sharex, sharey = False, True

fig, ax = plt.subplots(nrows=nrows, ncols=ncols, sharex=sharex, sharey=sharey)
# x: frequency scaled by 1e-6 (labelled MHz); y: time-integration index, increasing downwards
extent_m = [flt_freqs_m.min()/1e6, flt_freqs_m.max()/1e6, test_data_m.shape[1], 0]

for panel, component, title in zip(ax, (test_data_m.real, test_data_m.imag), (real_lab, imag_lab)):
    panel.imshow(component.T, extent=extent_m, aspect='auto')
    panel.set_title(title, size=8)

ax[0].set_ylabel('Time integration', size=8)
ax[1].set_xlabel('Frequency (MHz)', size=8)
# Only the panel that does not share the axis needs its own label
if sharex:
    ax[1].set_ylabel('Time integration', size=8)
else:
    ax[0].set_xlabel('Frequency (MHz)', size=8)

fig.tight_layout()
plt.show()

In [None]:
# artificially add holes in data: flag a tenth of the individual samples,
# a tenth of the channels and a tenth of the time integrations
n_chan_m, n_tint_m = test_data_m.shape
flg_no_point = test_data_m.size//10
flg_no_chans = n_chan_m//10
flg_no_tints = n_tint_m//10

test_data_m_n = test_data_m.copy()

# Fixed seed so that the same samples are flagged on every run
np.random.seed(0)

# randomly flagged individual samples
rnd_idx = np.sort(np.random.choice(np.arange(test_data_m.size), flg_no_point, replace=False))
rnd_idxs = np.unravel_index(rnd_idx, test_data_m.shape)

# randomly flagged whole channels (rows) and whole time integrations (columns)
flg_chans = np.random.choice(np.arange(n_chan_m), flg_no_chans, replace=False)
flg_tints = np.random.choice(np.arange(n_tint_m), flg_no_tints, replace=False)

# Multiplying by NaN blanks both the real and the imaginary part
test_data_m_n[rnd_idxs] *= np.nan
test_data_m_n[flg_chans, :] *= np.nan
test_data_m_n[:, flg_tints] *= np.nan

In [None]:
# randomly nan'd data points
fig, ax = plt.subplots(ncols=ncols, nrows=nrows)
for panel, component, title in zip(ax, (test_data_m_n.real, test_data_m_n.imag), (real_lab, imag_lab)):
    panel.imshow(component.T, extent=extent_m, aspect='auto', interpolation='none')
    panel.set_title(title, size=8)

ax[0].set_ylabel('Time integration', size=8)
ax[1].set_xlabel('Frequency (MHz)', size=8)
# Label the remaining axis according to the panel layout chosen earlier
if sharex:
    ax[1].set_ylabel('Time integration', size=8)
else:
    ax[0].set_xlabel('Frequency (MHz)', size=8)

fig.tight_layout()
plt.show()

In [None]:
# cubic interpolation of the artificially flagged samples, used as the
# baseline to compare the GP against
# note that if a flagged data point is in a corner, this method cannot extrapolate to that point
test_data_m_n_ci = nan_interp2d(test_data_m_n, kind='cubic')

In [None]:
# GP interpolation

# Frequencies of the channels held in test_data_m_n: the data were cut from the
# unfiltered frequency axis at chan_se[0]+1 .. chan_se[1]-1, so index freqs with
# the same range (not flt_freqs, and not starting at chan_se[0]).
freq_axis_m = freqs[chan_se[0]+1:chan_se[1]][:test_data_m_n.shape[0]]
X_us = np.array(np.meshgrid(freq_axis_m, \
                            np.arange(test_data_m_n.shape[1]))).T.reshape(-1, 2)
# Standardise both features (zero mean, unit variance)
scaler = preprocessing.StandardScaler().fit(X_us)
X_s = scaler.transform(X_us)

# Two real target columns per complex sample
Y = decomposeCArray(test_data_m_n.ravel(order='C'))

# Train on the unflagged samples only
nans_a = np.isnan(test_data_m_n.ravel(order='C'))
X, Y = X_s[~nans_a, :], Y[~nans_a, :]


# Bounds for the constant kernel: lower bound is the decade at or below
# max(smallest target, 1); upper bound is one decade above the largest target.
const_min = 10**np.floor(np.log10(np.nanmax([np.nanmin(Y), 1])))
const_max = 10**(np.ceil(np.log10(np.nanmax(Y)))+1)
ls_init = np.mean(np.ediff1d(X.ravel()))
noise_est = np.nanstd(test_data_m_n)

kernel = gp.kernels.RBF(length_scale=ls_init, length_scale_bounds=(1e-2, 1e1)) * \
         gp.kernels.ConstantKernel(constant_value=np.nanmean(Y)**2, \
                                   constant_value_bounds=(const_min, const_max)) + \
         gp.kernels.WhiteKernel(noise_level=noise_est, noise_level_bounds=(1e-1, 1e+2))

# normalize target values by removing mean and scaling to unit-variance with normalize_y=True
# random_state makes the optimiser restarts independent of the global RNG state
model_i = gp.GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=5, normalize_y=True,
                                      random_state=0)
model_i.fit(X, Y)
print(model_i.kernel_)

# Predict only at the flagged samples and recombine into complex values
y_ir_pred, c_std = model_i.predict(X_s[nans_a, :], return_std=True)

pdata = (y_ir_pred[:, 0] + 1j*y_ir_pred[:, 1])

In [None]:
# Fill the flagged samples with the GP predictions. Boolean-mask assignment
# walks the array in row-major order, the same order the predictions were made in.
hole_mask = np.isnan(test_data_m_n)
test_data_m_n_gpi = test_data_m_n.copy()
test_data_m_n_gpi[hole_mask] = pdata

In [None]:
# results
fig, axes = plt.subplots(ncols=2, nrows=2, sharex=True, sharey=True)

# Columns: cubic interpolation (left) and GP (right); rows: real (top) and imaginary (bottom)
filled_sets = (test_data_m_n_ci, test_data_m_n_gpi)
for col_idx, filled in enumerate(filled_sets):
    axes[0][col_idx].imshow(filled.real.T, extent=extent_m, aspect='auto')
    axes[1][col_idx].imshow(filled.imag.T, extent=extent_m, aspect='auto')

for ax, col in zip(axes[0], ['Cubic interpolation', 'GP interpolation']):
    ax.set_title(col, size=8)

for ax, col in zip(axes, [real_lab, imag_lab]):
    ax[0].set_ylabel(col+'\n\nTime integration', size=8)

for bottom_ax in axes[1]:
    bottom_ax.set_xlabel('Frequency (MHz)', size=8)

fig.tight_layout()
plt.show()

In [None]:
# residuals for the flagged data points
# Absolute error of each method at the individually flagged samples (rnd_idxs)
ci_res = np.abs(test_data_m[rnd_idxs] - test_data_m_n_ci[rnd_idxs])
gpi_res = np.abs(test_data_m[rnd_idxs] - test_data_m_n_gpi[rnd_idxs])

fig, ax = plt.subplots()
ax.plot(ci_res, label='Cubic interp', lw=1.5, alpha=0.7)
ax.plot(gpi_res, label='GP interp', lw=1.5, alpha=0.7, color='orange')
# Dashed lines: mean residual of each method. nanmean is used for both so the
# two averages are computed the same way and a NaN residual (cubic
# interpolation cannot fill every point; the reference itself may be NaN)
# does not turn the mean into NaN.
ax.axhline(np.nanmean(ci_res), ls='--')
ax.axhline(np.nanmean(gpi_res), ls='--', color='orange')
ax.set_ylabel(r'$\left| \Delta V \right|$')
ax.legend(loc='upper right', prop={'size': 8})
fig.tight_layout()
plt.show()

In [None]:
# cross-PS between all time integrations pairs

# don't want fully flagged time integrations
ps_data = np.delete(test_data_m, flg_tints, axis=1)
gpi_ps_data = np.delete(test_data_m_n_gpi, flg_tints, axis=1)

# Every ordered pair of remaining time integrations (including each with itself)
tint_pairs = list(itertools.product(np.arange(ps_data.shape[1]), repeat=2))
tints1 = [first for first, _ in tint_pairs]
tints2 = [second for _, second in tint_pairs]

# Settings shared by both estimates: transform along frequency (axis 0) with a
# Hann window, two-sided output, no detrending
csd_opts = dict(fs=1/f_resolution, window='hann', scaling='spectrum', nfft=None,
                detrend=False, return_onesided=False, axis=0)

# Original (unflagged) data, sorted by increasing delay
delay, pspec = signal.csd(ps_data[..., tints1], ps_data[..., tints2],
                          nperseg=ps_data.shape[0], **csd_opts)
delay_sort = np.argsort(delay)
delay, pspec = delay[delay_sort], pspec[delay_sort, :]

# GP-filled data, sorted by increasing delay
gpi_delay, gpi_pspec = signal.csd(gpi_ps_data[..., tints1], gpi_ps_data[..., tints2],
                                  nperseg=gpi_ps_data.shape[0], **csd_opts)
delay_sort = np.argsort(gpi_delay)
gpi_delay, gpi_pspec = gpi_delay[delay_sort], gpi_pspec[delay_sort, :]

# Average over all time-integration pairs
pspec_mean = np.nanmean(pspec, axis=1)
gpi_pspec_mean = np.nanmean(gpi_pspec, axis=1)

In [None]:
fig, axes = plt.subplots(ncols=3, figsize=(10, 6), sharey=True)
ax_actual, ax_interp, ax_compare = axes

# Left: every pair for the original data, with the pair-averaged spectrum on top
ax_actual.plot(delay, np.abs(pspec), alpha=0.3)
ax_actual.plot(delay, np.abs(pspec_mean), alpha=1, color='orange')
ax_actual.set_ylabel('Power spectrum')

# Middle: the same for the GP-filled data
ax_interp.plot(delay, np.abs(gpi_pspec), alpha=0.3)
ax_interp.plot(delay, np.abs(gpi_pspec_mean), alpha=1, color='purple')

# Right: the two pair-averaged spectra overlaid
ax_compare.plot(delay, np.abs(pspec_mean), alpha=0.6, color='orange', label='Actual')
ax_compare.plot(delay, np.abs(gpi_pspec_mean), alpha=0.6, color='purple', label='Interpolated')

for ax in axes:
    ax.set_yscale('log')
    ax.set_xlabel('Delay')

for ax, title in zip(axes, ('Actual data', 'Interpolated data', 'Comparison')):
    ax.set_title(title)
ax_compare.legend(loc='best')

plt.show()