In [89]:
import numpy as np
import scipy as sp
import pylab as plt
import os
import time
from scipy.stats import mode
from scipy.signal.windows import blackmanharris as BH
import scipy as sp
from multiprocess import Pool
from astropy.timeseries import LombScargle as LSc
from scipy.signal import lombscargle as lsc
conjgrad = sp.sparse.linalg.cg # conjugate gradient solver
nsample = np.random.multivariate_normal
from scipy.optimize import minimize, Bounds

In [88]:
def flip(arr):
    """Centre the zero-frequency bin of ``arr``.

    Thin wrapper around ``np.fft.fftshift`` called without ``axes``, so every
    axis of a multi-dimensional input is shifted (not only the last one).
    """
    shifted = np.fft.fftshift(arr)
    return shifted

In [106]:
def FOp(s):
    """Build the ``s x s`` discrete Fourier operator.

    Row ``i`` is ``np.fft.fft`` applied to the i-th unit vector of length ``s``.

    Returns
    -------
    ndarray of complex, shape (s, s)
    """
    operator = np.zeros((s, s), dtype=complex)
    # Each row of the identity is a unit impulse; transform them one by one.
    for row, impulse in enumerate(np.eye(s)):
        operator[row] = np.fft.fft(impulse)
    return operator

In [3]:
def m(tau, s):
    """Return the FFT of a length-``s`` unit impulse located at index ``tau``."""
    impulse = np.zeros(s)
    impulse[tau] = 1
    spectrum = np.fft.fft(impulse)
    return spectrum

################
    
def Q(tau, s):
    """Return the outer product ``conj(m) m`` for delay index ``tau``, cached on disk.

    The matrix is stored as ``Qs/Q<s>_<tau>.npy`` relative to the current
    working directory. If that file exists it is loaded; otherwise the matrix
    is computed from ``m(tau, s)``, saved, and returned.

    Parameters
    ----------
    tau : int
        Index of the mode passed to ``m``.
    s : int
        Side length of the returned square matrix.

    Returns
    -------
    ndarray of complex, shape (s, s)
    """
    filename = 'Qs/Q' + str(s) + '_' + str(tau)+ '.npy'
    if os.path.isfile(filename):
        return np.load(filename)

    # Evaluate the mode vector once and reuse it for both factors.
    mode = m(tau, s)
    Qmat = np.outer(mode.conj(), mode)

    # np.save does not create missing directories, so make sure the cache
    # folder exists before writing (fresh checkouts have no 'Qs/').
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    np.save(filename, Qmat)
    return Qmat

################

def bias(tau, s, R, C_noise_total):
    """Noise-bias term ``0.5 * tr(C_noise_total R* Q_tau R)`` for delay index ``tau``.

    NOTE(review): ``R.conj()`` is the element-wise conjugate, not the
    conjugate transpose used in ``qhat_h``/``Ft``; the two agree only for a
    symmetric ``R`` -- confirm this is intended.
    """
    weighted_Q = R.conj() @ Q(tau, s) @ R
    return 0.5 * np.trace(C_noise_total @ weighted_Q)

# redundant in HERA setup, but useful for comparison

################

def qhat(x, tau, s, R, bias):
    """Bias-subtracted quadratic estimate ``0.5 * x^H E x - bias`` for delay ``tau``.

    ``E`` is ``R.conj() @ Q(tau, s) @ R``; ``bias`` here is the already
    evaluated bias value to subtract (the parameter shadows the ``bias``
    function inside this scope).

    NOTE(review): ``R.conj()`` is not transposed, unlike ``qhat_h``; identical
    only for symmetric ``R`` -- confirm.
    """
    E = R.conj() @ Q(tau, s) @ R
    quadratic = x.conj().T @ E @ x
    return 0.5 * quadratic - bias

def qhat_h(x1, x2, tau, s, R):
    """HERA-like cross-correlation estimate ``0.5 * (R x1)^H Q_tau (R x2)``.

    Both data vectors are weighted by ``R`` before being contracted with
    ``Q(tau, s)``. (An einsum formulation of the same contraction, mirroring
    the reference pspec code, was kept commented out in the original cell for
    cross-checking.)
    """
    weighted1 = R @ x1
    weighted2 = R @ x2
    Q_tau = Q(tau, s)
    # The factor 0.5 is applied last; scaling by a power of two is exact.
    return 0.5 * (weighted1.conj().T @ Q_tau @ weighted2)
                   

################

def F(s, R):
    """Fisher-like matrix ``F[a, b] = 0.5 * tr(R^H Q_a R Q_b)``.

    The conjugate *transpose* of ``R`` is used so that this matrix matches the
    quadratic form evaluated by ``qhat_h`` (``(R x1)^H Q (R x2)``) and the
    fast variant ``Ft``. For a real symmetric ``R`` this is identical to the
    previous ``R.conj()`` form; for complex ``R`` the previous form differed.

    Parameters
    ----------
    s : int
        Number of channels / delay modes (side length of ``R``).
    R : ndarray, shape (s, s)
        Weighting matrix.

    Returns
    -------
    ndarray of complex, shape (s, s)
    """
    Rh = R.conj().T

    # Hoist the O(s^3) products out of the double loop: each depends on only
    # one of the two indices.
    left = [Rh @ Q(a, s) for a in range(s)]
    right = [R @ Q(b, s) for b in range(s)]

    Fmat = np.zeros((s, s), dtype=complex)
    for a in range(s):
        for b in range(s):
            # tr(A B) = sum_ij A_ij B_ji, without forming the product matrix.
            Fmat[a, b] = 0.5 * np.einsum('ij,ji', left[a], right[b])

    return Fmat

################

def Ft(s, R):
    """Fisher-like matrix ``0.5 * tr(R^H Q_a R Q_b)`` using precomputed products.

    ``R^H Q_i`` and ``R Q_i`` are formed once per index and the trace of their
    product is evaluated with an einsum contraction.

    Returns
    -------
    ndarray of complex, shape (s, s)
    """
    R_herm = np.conj(R).T
    Q_all = [Q(i, s) for i in range(s)]
    left = [np.dot(R_herm, Qi) for Qi in Q_all]   # R^H Q_a
    right = [np.dot(R, Qi) for Qi in Q_all]       # R Q_b

    fisher = np.zeros((s, s), dtype=complex)
    for a, lhs in enumerate(left):
        for b, rhs in enumerate(right):
            fisher[a, b] = 0.5*np.einsum('ab,ba', lhs, rhs)

    return fisher

################

def M_Fhalf(F):
    """Return the inverse of the matrix square root of ``F`` (``F^{-1/2}``)."""
    root = sp.linalg.sqrtm(F)
    return np.linalg.inv(root)

################

def M_Finv(F):
    """Return the matrix inverse of ``F``."""
    inverse = np.linalg.inv(F)
    return inverse

################

def M_opt(F):
    """Diagonal normalisation ``diag(1 / F_ii)`` with rows rescaled by the window sums.

    ``W = M @ F`` is formed from the un-normalised ``M`` and every row of ``M``
    is divided by the sum of the corresponding row of ``W``. (The original
    author questioned whether this row normalisation is meaningful.)
    """
    M = np.diag(1 / np.diag(F))
    W = M @ F
    row_sums = W.sum(axis=1)
    # Broadcast the per-row sums down the columns.
    return M / row_sums[:, np.newaxis]

################

def q(V, s, R, bias):
    """Evaluate ``qhat`` over all ``s`` delay indices for each vector in ``V``.

    ``bias`` must be precomputed and indexable by delay index (``bias[tau]``).

    Returns
    -------
    ndarray of float, shape (len(V), s)
        The output buffer is real-valued, so complex ``qhat`` results are
        cast to real on assignment.
    """
    delays = np.arange(s)
    qs = np.zeros((len(V), s))

    for row in range(len(V)):
        per_delay = [qhat(V[row], tau, s, R, bias[tau]) for tau in delays]
        qs[row] = np.array(per_delay)

    return qs

################

def q_h(V, s, R, taper=None):
    """Cross-correlate consecutive pairs of rows of ``V`` with ``qhat_h``.

    Rows ``2i`` and ``2i+1`` form pair ``i``; a trailing unpaired row is
    ignored. ``taper`` is accepted but not used.

    Returns
    -------
    ndarray of complex, shape (len(V)//2, s)
    """
    n_pairs = len(V)//2  # Should be even if creating pairs of visibilities
    delays = np.arange(s)
    qs = np.zeros((n_pairs, s), dtype=complex)

    for pair in range(n_pairs):
        first, second = V[2*pair], V[2*pair+1]
        qs[pair] = np.array([qhat_h(first, second, tau, s, R) for tau in delays])

    return qs

################

def p(q, M):
    """Normalise the unnormalised estimates ``q`` with matrix ``M`` (returns ``M @ q``)."""
    normalised = M @ q
    return normalised

# def q_h(x1, x2, s, R):
#     # old function - keeping in case
#     t_ = np.arange(s)
#     return np.array([qhat_h(x1, x2, t, s, R) for t in t_])
# question - how do HERA do this?

In [4]:
def pstats(qs, FM, plot=None, ylims=(-2, 3.5), xlims=(None, None), scale=None):
    """Normalise a set of ``q`` vectors with ``FM`` and summarise them per delay bin.

    Parameters
    ----------
    qs : iterable of array_like
        Unnormalised estimates, one vector per realisation.
    FM : ndarray
        Normalisation matrix passed to ``p``.
    plot : {None, 'plot', 'plotonly'}
        ``'plot'`` draws the summary figure and returns the statistics;
        ``'plotonly'`` draws the figure and returns ``None``.
    ylims, xlims : pair
        Axis limits forwarded to ``plt.ylim`` / ``plt.xlim``.
    scale : {None, 'log'}
        ``'log'`` switches the y axis to logarithmic.

    Returns
    -------
    (pall_, p_, sp1, sm1) or None
        All shifted normalised spectra, their per-bin mean, and the 84.2 /
        15.8 percentiles.

    Notes
    -----
    The plotting branch reads the notebook-level names ``tau_f``, ``alys``
    and ``N``; they must be defined before plotting is requested.
    ``flip`` shifts every axis, so the realisation axis of ``pall_`` is
    rolled as well; the per-bin statistics are unaffected by row order.
    """
    pall_ = flip(np.array([p(q, FM) for q in qs],dtype=complex))

    # Take the number of delay bins from the data itself rather than from a
    # notebook-level global, so the function works on a fresh kernel and for
    # any spectrum length.
    n_tau = pall_.shape[1]

    p_ = np.array([np.mean(pall_[:,i]) for i in range(n_tau)])

    sp1 = np.array([np.percentile(pall_[:,i], 84.2) for i in range(n_tau)])
    sm1 = np.array([np.percentile(pall_[:,i], 15.8) for i in range(n_tau)])

    if plot in ('plot', 'plotonly'):
        plt.plot(tau_f, p_, label=r'mean')
        plt.fill_between(tau_f, sm1,sp1, alpha=0.25, color='orange')
        plt.plot(tau_f, alys, label=r'analytic P($\tau$)')
        plt.title('%d sims + noise + foregrounds'%(N), fontsize=20)
        plt.legend(fontsize=14)
        plt.xlabel(r'Delay $\tau$ [$\mu$s]', fontsize=20)

        if scale=='log':
            plt.yscale('log')
        plt.xlim(xlims[0],xlims[1])
        plt.ylim(ylims[0],ylims[1])
        plt.show()

    if plot=='plotonly':
        return None
    return pall_, p_, sp1, sm1

In [60]:
def pstats2(qs, FM):
    """Normalise every ``q`` in ``qs`` with ``FM`` and return the shifted complex array.

    Same first step as ``pstats`` but without the summary statistics or plot.
    """
    normalised = [p(q, FM) for q in qs]
    return flip(np.array(normalised, dtype=complex))

In [5]:
def getqs(Vis, R):
    """Run the skeleton OQE over ``Vis`` with weighting ``R``.

    Prints diagnostics for ``R`` (via ``matc``) and the elapsed wall time.

    Returns
    -------
    qs : ndarray
        Unnormalised estimates from ``q_h`` (to be normalised by ``pstats``).
    Fm : ndarray
        Fisher matrix ``F(s, R)``.
    MB : ndarray
        ``M_opt(Fm)``.
    MA : ndarray
        ``M_Finv(Fm)``.
    """
    t_start = time.time()
    n_chan = len(Vis[0])

    matc(R)
    fisher = F(n_chan, R)  # Fisher matrix
    norm_opt = M_opt(fisher)
    norm_inv = M_Finv(fisher)
    qs = q_h(Vis, n_chan, R)

    print('%.3fs'%(time.time()-t_start))
    return qs, fisher, norm_opt, norm_inv

In [37]:
def q_hp(V, s, R, ncpu):
    """Parallel counterpart of ``q_h``: cross-correlate row pairs of ``V`` on ``ncpu`` workers.

    Rows ``2i`` and ``2i+1`` of ``V`` form pair ``i``. Prints the elapsed
    wall time.

    Returns
    -------
    ndarray, shape (len(V)//2, s)
        One row of ``qhat_h`` values per pair. The pool results are stacked
        into an array so the output matches ``q_h``; previously a
        preallocated array was silently replaced by the raw list returned by
        ``pool.map``.
    """
    st=time.time()
    N = len(V)//2
    t_ = np.arange(s)

    Vidxs = np.arange(N)
    with Pool(ncpu) as pool:
        rows = pool.map(lambda idx: np.array([qhat_h(V[2*idx], V[2*idx+1], t, s, R) for t in t_]), Vidxs)
    print('%.3fs'%(time.time()-st))

    if len(rows) == 0:
        # Keep a well-formed (0, s) result when there are no pairs.
        return np.zeros((0, s), dtype=complex if np.iscomplexobj(V) else float)
    return np.array(rows)

In [31]:
def addnoise_c(Vis_sfg, C_noise):
    """Duplicate each visibility and add independent complex noise to both copies.

    Real and imaginary noise parts are drawn with the same covariance
    ``C_noise``; the mean vector is the notebook-level ``mv``. Output rows
    ``2i`` and ``2i+1`` both derive from input row ``i``.

    Returns
    -------
    ndarray of complex, shape (2 * len(Vis_sfg), len(Vis_sfg[0]))
    """
    n_vis = len(Vis_sfg)
    n_chan = len(Vis_sfg[0])
    signal = np.asarray(Vis_sfg)

    # Four draws per input row: (re, im) for the first copy, (re, im) for the second.
    noise = nsample(mv, C_noise, 4*n_vis)

    V = np.zeros((2*n_vis, n_chan), dtype=complex)
    V[0::2] = signal + noise[0::4] + 1j*noise[1::4]
    V[1::2] = signal + noise[2::4] + 1j*noise[3::4]
    return V

def addnoise(Vis_sfg, C_noise):
    """Duplicate each visibility and add independent real noise to both copies.

    Noise is drawn with covariance ``C_noise`` and the notebook-level mean
    vector ``mv``. Output rows ``2i`` and ``2i+1`` both derive from input
    row ``i``; the output buffer is real-valued.
    """
    n_vis = len(Vis_sfg)
    n_chan = len(Vis_sfg[0])
    signal = np.asarray(Vis_sfg)

    # Two draws per input row, one for each copy.
    noise = nsample(mv, C_noise, 2*n_vis)

    V = np.zeros((2*n_vis, n_chan))
    V[0::2] = signal + noise[0::2]
    V[1::2] = signal + noise[1::2]
    return V

def addnonoise(V):
    """Duplicate every row of ``V`` without adding noise.

    Produces the same paired layout as ``addnoise``/``addnoise_c`` (rows
    ``2i`` and ``2i+1`` are both copies of input row ``i``).

    Parameters
    ----------
    V : sequence of array_like
        Input vectors, all of the same length.

    Returns
    -------
    ndarray of complex, shape (2 * len(V), len(V[0]))
        An empty input yields an array of shape (0, 0).
    """
    N = len(V)
    # Read the channel count from the data (as addnoise does) instead of
    # relying on a notebook-level global being defined and consistent.
    n_chan = len(V[0]) if N else 0
    Vout = np.zeros((2*N, n_chan), dtype=complex)
    for i, x in enumerate(V):
        Vout[2*i] = x
        Vout[2*i+1] = x
    return Vout

In [45]:
def Sig_QEN(R, C_noise, norm):
    """Noise-only variance term ``0.5 * tr(E C_n E C_n)`` for every delay index.

    ``E = (R Q_i R) * norm``. In Jianrong's paper ``E`` is normalised, hence
    the ``norm`` factor (division by the window-function row sum).

    Returns
    -------
    ndarray of complex, shape (len(R),)
    """
    s = len(R)
    Sig = np.zeros(s, dtype=complex)

    for idx in range(s):
        E = R @ Q(idx, s) @ R * norm
        E_Cn = E @ C_noise
        Sig[idx] = 0.5 * np.trace(E_Cn @ E @ C_noise)

    return Sig



def Sig_QESN(R, C_noise, C_S, norm):
    """Variance term with noise-noise and signal-noise cross terms per delay index.

    For ``E = (R Q_i R) * norm`` this evaluates
    ``0.5 * tr(E C_n E C_n + E C_s E C_n + E C_n E C_s)``.

    Returns
    -------
    ndarray of complex, shape (len(R),)
    """
    s = len(R)
    Sig = np.zeros(s, dtype=complex)

    for idx in range(s):
        E = R @ Q(idx, s) @ R * norm
        E_Cn_E = E @ C_noise @ E
        E_Cs_E = E @ C_S @ E
        total = (E_Cn_E @ C_noise) + (E_Cs_E @ C_noise) + (E_Cn_E @ C_S)
        Sig[idx] = 0.5 * np.trace(total)

    return Sig

In [8]:
def matc(M):
    """Print quick conditioning diagnostics for the square matrix ``M``.

    Reports whether the real parts of all eigenvalues are positive, the ratio
    of the largest to the smallest of those real parts, and the product of
    the default ``np.linalg.norm`` of ``M`` and of its inverse. Returns
    ``None``.
    """
    eig_real = np.linalg.eigvals(M).real
    M_inverse = np.linalg.inv(M)

    is_pos_def = np.all(eig_real > 0)
    spread = max(eig_real)/min(eig_real)
    norm_product = np.linalg.norm(M)*np.linalg.norm(M_inverse)

    print(is_pos_def,' - positive definite')
    print(np.format_float_scientific(spread),' - eigval ratio')
    print('%f'%(norm_product),' - condition (norm C x norm Cinv)')
    print('')

In [92]:
def GCR(dat, w, S, N, nrzn=1, inpaint='inpaint', quiet='quiet', dat2=None, bla=None, poolmap=False):
    """
    Return Gaussian constrained realizations for a flagged data vector with
    signal prior S and noise prior N, following the Gaussian constrained
    realization equation.

    Important note: a raw realization does not match the data in the unflagged
    region. With ``inpaint='inpaint'`` the unflagged channels of the outputs
    are overwritten with ``dat``, so only the flagged channels carry the
    solution.

    ---Variables:
    dat - Data vector (1-D, length s).
    w - Flagging/mask vector (1 for unflagged data, 0 for flagged data).
    S - Signal prior covariance matrix, shape (s, s).
    N - Noise prior covariance matrix, shape (s, s).
    nrzn - Number of realizations to return.
    inpaint - 'inpaint': overwrite unflagged channels of the Wiener solution
              and of every realization with ``dat``. 'subtract': return
              ``dat - solution`` on unflagged channels and zero elsewhere.
              Any other value leaves the solutions untouched.
    quiet - anything other than 'quiet' prints whether the complex branch ran.
    dat2 - optional second data vector; the data become ``(dat + dat2) / 2``.
    bla - optional 2-D array, one redundant baseline per row; the data become
          the sum over rows and the number of rows scales the noise term.
    poolmap - if True return only the realizations (convenient for Pool.map).

    ---Returns:
    (Wnr, solns), or only solns when ``poolmap`` is True. ``Wnr`` is the
    Wiener-filter solution, ``solns`` has shape (nrzn, s).
    """
    s = len(dat)
    n_chan = max(s, len(dat.T))
    d = dat.reshape((1, n_chan))
    nbaselines = 1
    if bla is not None:
        nbaselines = bla.shape[0] # n_rows (each redundant baseline is one row)
        d = np.sum(bla, axis=0).reshape((1, n_chan))

    if dat2 is not None:
        d2 = dat2.reshape((1,max(len(dat2),len(dat2.T))))
        d = (d+d2)/2

    dat_complex = bool(np.iscomplexobj(d) or np.iscomplexobj(S) or np.iscomplexobj(N))

    Sh = sp.linalg.sqrtm( S )

    # Masked inverse noise, W N^-1 W. Using the outer product masks rows AND
    # columns; multiplying by a 1-D `w` on both sides only scaled the columns,
    # which left the system matrix non-symmetric for non-diagonal N (the
    # conjugate-gradient solver assumes a symmetric/Hermitian matrix).
    mask = np.ravel(w)
    Ni = np.outer(mask, mask) * np.linalg.inv( N )
    Nih = sp.linalg.sqrtm( Ni )

    A = nbaselines*Sh @ Ni @ Sh  + np.eye(s)
    Ai = np.linalg.inv(A)   # exact inverse, used as the preconditioner
    b = Sh @ Ni @ (w*d).T

    max_iter = int(1e5)  # the solver expects an integer iteration cap

    wiener, _info = conjgrad(A, b, maxiter=max_iter, M=Ai) # Wiener / max-likelihood solution
    Wnr = Sh@wiener

    # array for solutions to GCR equation
    solns = np.zeros((nrzn, s), dtype=complex if dat_complex else float)

    for i in range(nrzn):
        if dat_complex:
            omi, omj = np.random.randn(s,1),np.random.randn(s,1)
            noise_draw = ( np.sum(np.random.randn(s,nbaselines),axis=1)
                           + 1j*np.sum(np.random.randn(s,nbaselines),axis=1) ).reshape((s,1))
            cri = (omi+1j*omj)/2**0.5 + Sh @ Nih @ noise_draw/2**0.5
        else:
            omi = np.random.randn(s,1)
            # Sum over the baseline axis (axis=1) and keep a column vector so
            # the product with Sh @ Nih is (s, 1), as in the complex branch.
            # Summing over axis=0 gave a length-`nbaselines` vector and a
            # shape error.
            noise_draw = np.sum(np.random.randn(s,nbaselines),axis=1).reshape((s,1))
            cri = omi + Sh @ Nih @ noise_draw

        bcri = b + cri
        xboth, _info2 = conjgrad(A, bcri, maxiter=max_iter, M=Ai)
        solns[i] = Sh@xboth

    # in-painting if required

    unflagged_indices = np.where(w==1)
    if inpaint=='inpaint':
        Wnr[unflagged_indices] = dat[unflagged_indices]
        for i,sol in enumerate(solns):
            solns[i][unflagged_indices] = dat[unflagged_indices]

    if inpaint=='subtract':
        Z = np.zeros(s, dtype=complex)
        Z[unflagged_indices] = dat[unflagged_indices] - Wnr[unflagged_indices]
        Wnr = Z
        for i,sol in enumerate(solns):
            Z = np.zeros(s, dtype=complex)
            Z[unflagged_indices] = dat[unflagged_indices] - sol[unflagged_indices]
            solns[i] = Z

    if quiet!='quiet':   print('complex data: ',dat_complex)

    if poolmap==True: return solns
    else: return Wnr, solns

In [None]:
# # problem cell - do not run

# if not dat_complex:
#     for i in range(nrzn):
#         omi,omj = np.random.randn(120,1), np.random.randn(120,1)
#         cri = omi + Sh @ Nih @ omj
#         bcri = b + cri
#         xboth, info2 = conjgrad(A, bcri, maxiter=1e5, M=Ai)
#         solns[i] = Sh@xboth

# if dat_complex:
#     for i in range(nrzn):
#         omi, omj, omk, oml = np.random.randn(120,1),np.random.randn(120,1),np.random.randn(120,1),\
#                                             np.random.randn(120,1)
#         cri = (omi+1j*omj) + Sh @ Nih @ (omk+1j*oml)
#         bcri = b + cri
#         xboth, info2 = conjgrad(A, bcri, maxiter=1e5, M=Ai)
#         solns[i] = Sh@xboth 
        
# '''
# look at real and imag parts of the power spectrum, rather than abs

# seems like...
#     cri = (omi+1j*omj) + Sh @ Nih @ (omk+1j*oml) should this be NORM 1? so divide by sqrt 2?????
#      or
#     conjgrad() might be the culprit?

# '''

In [109]:
def naivePS(data, meansub=1, taper=1):
    """Naive FFT power spectrum ``|FFT(d)|^2`` along the last axis, fft-shifted.

    Parameters
    ----------
    data : array_like
        Data with channels along the last axis (typically shape
        (n_realisations, n_channels)). The input is never modified.
    meansub : bool-like
        If true, subtract the mean over the last axis from each row first.
    taper : bool-like
        If true, multiply by a Blackman-Harris window whose length is the
        number of channels of ``data``.

    Returns
    -------
    ndarray
        ``np.fft.fftshift`` of the squared FFT magnitude. As with the
        notebook's ``flip`` helper, the shift acts on every axis.
    """
    # Always start from the data so `d` is defined when meansub is disabled.
    d = np.asarray(data)

    if meansub:
        d = d - np.mean(d, axis=-1, keepdims=True)

    if taper:
        # Size the window from the data rather than a notebook global, and
        # multiply out-of-place so the caller's array is never altered.
        d = d * BH(d.shape[-1])

    return np.fft.fftshift(np.abs(np.fft.fft(d))**2)

In [107]:
def nPS(data):
    """Sum of ``|FFT|^2`` over the first axis, with the FFT taken along the last axis."""
    spectrum = np.fft.fft(data, axis=-1)
    power = spectrum * spectrum.conj()
    return np.sum(power, axis=0).real
    

In [55]:
def GCR_OQEarray(V, w, S, N, inpaint='inpaint'):
    """Run ``GCR`` on consecutive row pairs of ``V`` and fill both rows of each pair.

    For every even row ``i`` the solver is called with ``dat=V[i]`` and
    ``dat2=V[i+1]``. Row ``i`` of the outputs receives the returned Wiener
    solution / realization; row ``i+1`` keeps ``V[i+1]`` except at flagged
    channels (``w == 0``), which are filled from the same solutions.
    Progress is printed every 100 rows.

    Returns
    -------
    VW, VC : ndarray of complex, same shape as ``V``
        Wiener-filtered and constrained-realization versions of ``V``.
    """
    VW = np.zeros(V.shape, dtype=complex)
    VC = np.zeros(V.shape, dtype=complex)

    flagged = np.where(w==0)

    for i,rzn in enumerate(V):
        if i % 2 == 0:
            partner = i + 1
            wnr, cr = GCR(rzn, w, S, N, nrzn=1, inpaint=inpaint, dat2=V[partner])
            VW[i] = wnr
            VC[i] = cr
            if i==0: print('complex: data',np.iscomplexobj(rzn),'C_s',np.iscomplexobj(S),'C_n',np.iscomplexobj(N))

            # Second member of the pair: original data, in-painted where flagged.
            VW[partner] = V[partner]
            VC[partner] = V[partner]
            VW[partner][flagged] = wnr[flagged]
            VC[partner][flagged] = cr[:,flagged]

        if i % 100 == 0: print(i, end=' ')
    return VW, VC

In [103]:
def GCR_array(V, w, S, N, inpaint='inpaint', bla=None, ncpu=2):
    """
    Apply ``GCR`` to every row of ``V``.

    bla is None - rows are solved in parallel on ``ncpu`` workers with
        ``poolmap=True``, so only the realizations are collected and the
        returned Wiener array stays all zeros.
    bla set to nbaselines (not None) - take noiseless sims and generate
        nbaselines \times noisy sims to hand to the GCR solver. The noise is
        drawn from the notebook-level ``mv`` and ``C_noise``; rows are solved
        serially with progress printed every 100 rows.

    Returns (VW, VC), both shaped like ``V``; the elapsed time is printed.
    """
    wiener_out = np.zeros(V.shape, dtype=complex)
    t_start = time.time()

    if bla is None:
        row_ids = np.arange(V.shape[0])
        with Pool(ncpu) as pool:
            realizations = pool.map(lambda idx: GCR(V[idx], w, S, N, nrzn=1, inpaint=inpaint, poolmap=True), row_ids)
    else:
        realizations = np.zeros(V.shape, dtype=complex)
        nbaselines = bla
        # assuming now that these V are noiseless, we create redundant baseline data here
        for i,rzn in enumerate(V):
            noises = nsample(mv, C_noise, nbaselines) + 1j*nsample(mv, C_noise, nbaselines)
            redundantbls = rzn + noises # broadcasting single V to nbaselines * noise

            wnr, cr = GCR(rzn, w, S, N, nrzn=1, inpaint=inpaint, bla=redundantbls)
            wiener_out[i] = wnr
            realizations[i] = cr
            if i==0: print('complex: data',np.iscomplexobj(rzn),'C_s',np.iscomplexobj(S),'C_n',np.iscomplexobj(N))

            if i % 100 == 0: print(i, end=' ')

    print('%.1fs'%(time.time()-t_start), end=' ')
    return wiener_out, np.array(realizations).reshape(V.shape)

In [104]:
def GCR_eigarray(V, w, S, F_evecs, N, ncpu=2):
    """Apply ``GCR_eig`` to every row of ``V`` in parallel on ``ncpu`` workers.

    ``GCR_eig`` is not defined in this part of the notebook; it is called as
    ``GCR_eig(row, w, S, F_evecs, N)``. The stacked results are reshaped to
    ``V.shape`` and the elapsed time is printed.
    """
    row_ids = np.arange(V.shape[0])
    t_start = time.time()

    with Pool(ncpu) as pool:
        results = pool.map(lambda idx: GCR_eig(V[idx], w, S, F_evecs, N), row_ids)

    print('%.1fs'%(time.time()-t_start), end=' ')
    return np.array(results).reshape(V.shape)

In [58]:
def wfcorrection(S,N):
    """Diagonal of ``T^H [S (S + N)^-1 N] T``, with ``T`` the DFT matrix.

    Row ``i`` of ``T`` is the FFT of the i-th unit vector (the same vector
    the notebook's ``m(i, s)`` helper returns).

    Parameters
    ----------
    S, N : ndarray, shape (s, s)
        Signal and noise covariance matrices.

    Returns
    -------
    ndarray of complex, shape (s,)
        The result is not fft-shifted (the original note says the shift is
        added in the results script).
    """
    # Take the size from the covariance itself instead of a notebook global.
    n_chan = np.shape(S)[0]
    T = np.fft.fft(np.eye(n_chan), axis=-1)

    return np.diag(T.conj().T @ ( S @ np.linalg.inv(S+N) @ N ) @ T)

In [59]:
def tt_eb(dat):
    """Two-tailed error bar per column from the 84.2 and 15.8 percentiles.

    Parameters
    ----------
    dat : array_like, shape (n_samples, n_bins)

    Returns
    -------
    ndarray, shape (n_bins,)
        ``p84.2 + p15.8 / 2`` for every column.

    NOTE(review): operator precedence makes this ``sp1 + (sm1 / 2)``. A
    symmetric error bar would normally be ``(sp1 - sm1) / 2`` -- confirm the
    intended formula; the arithmetic is left unchanged here.
    """
    dat = np.asarray(dat)
    # Percentiles are taken column-wise over however many bins the data has,
    # instead of looping up to a notebook-level global length.
    sp1 = np.percentile(dat, 84.2, axis=0)
    sm1 = np.percentile(dat, 15.8, axis=0)

    return sp1+sm1/2

In [62]:
def decorr_matrix(w, tau, freqs):
    """
    Rotation matrix from Eq. 8 of Bryna's note, used to decorrelate the real
    and imaginary amplitudes of least-squares-fitted cosine/sine modes.

    Apply it with `np.dot(rot, [A_real, A_imag])`.

    Parameters
    ----------
    w : array_like
        Mask vector, 1 for unmasked, 0 for masked.

    tau : float
        Delay wavenumber.

    freqs : array_like
        Frequency array.

    Returns
    -------
    rot : array_like
        2x2 rotation matrix to be applied to the amplitude vector.

    eigvals : array_like
        Diagonal of the rotated mode overlap matrix. Multiply the mode
        variance sigma^2 by these to get the new variances
        (sigma1^2, sigma2^2); see Eq. 9 of Bryna's note.
    """
    # Masked cosine and sine templates
    phase = 2.*np.pi*tau*freqs
    masked_cos = w*np.cos(phase)
    masked_sin = w*np.sin(phase)

    # Overlap sums, each evaluated once
    cc = np.sum(masked_cos*masked_cos)
    cs = np.sum(masked_cos*masked_sin)
    ss = np.sum(masked_sin*masked_sin)

    cov = np.zeros((2, 2))
    cov[0,0] = cc
    cov[0,1] = cov[1,0] = cs
    cov[1,1] = ss

    # Rotation angle computed directly (no eigendecomposition needed)
    theta = 0.5 * np.arctan2(2.*cs, cc - ss)
    c, sn = np.cos(theta), np.sin(theta)
    rot = np.array([[c, sn],
                    [-sn, c]])
    rinv = np.array([[c, -sn],
                     [sn, c]])
    eigvals = np.diag(np.dot(rot, np.dot(cov, rinv)))

    return rot, eigvals

In [63]:
def decorr_pspec(A_re, A_im, w, tau, freqs):
    """
    LSSA power spectrum using Bryna's decorrelation scheme to re-weight the
    real and imaginary amplitudes of every tau mode.

    Returns an array with one power value per entry of ``tau``.
    """
    ps = np.zeros(tau.size)

    for idx, tau_mode in enumerate(tau):
        # Rotation and rotated overlap values for this mode
        rot, ev = decorr_matrix(w=w, tau=tau_mode, freqs=freqs)

        # Rotate the (real, imaginary) amplitude pair
        amp_pair = np.array([A_re[idx], A_im[idx]])
        A1, A2 = np.matmul(rot, amp_pair)

        # c.f. Eq. 12 of Bryna's note; numerator and denominator are both
        # multiplied by the squared eigenvalues to avoid a 1/0
        numerator = (A1 * ev[1])**2. + (A2 * ev[0])**2.
        denominator = ev[0]**2. + ev[1]**2.
        ps[idx] = numerator / denominator
    return ps

In [85]:
def model_ap(amp, phase, tau, freqs):
    """Complex sinusoid ``amp * exp(-2 pi i tau freqs + i phase)`` (amplitude/phase form)."""
    exponent = -2.*np.pi*1.j*tau*freqs + 1.j*phase
    return amp * np.exp(exponent)

def model_aa(A_re, A_im, tau, freqs):
    """Complex sinusoid ``(A_re + i A_im) * exp(-2 pi i tau freqs)`` (real/imag amplitude form)."""
    amplitude = A_re + 1.j*A_im
    carrier = np.exp(-2.*np.pi*1.j*tau*freqs)
    return amplitude * carrier

def lssa_fit_modes(d, freqs, invcov=None, fit_amp_phase=True, tau=None, 
                   minimize_method='L-BFGS-B', taper=None):
    r"""
    Perform a weighted LSSA fit to masked complex 1D data.

    NOTE: The input data/covariance should have already had the flagged 
    channels removed. Use the `trim_flagged_channels()` function to do 
    this.
    
    The function minimised for each sinusoid is
    
    $\frac{1}{2} \tilde{x}^\dagger \tilde{C}^{-1} \tilde{x}$
    
    with residual
    
    $x = t\,[d - A \exp(-2 \pi i \nu \tau_n + i\phi)]$
    
    (or $A = A_{re} + i A_{im}$ with no phase term when `fit_amp_phase` is 
    False), where $t$ is the optional taper.

    The tilde denotes vectors/matrices from which the masked channels 
    (rows/columns) have been removed entirely.
    
    Parameters:
        d (array_like):
            Complex data array that has already had flagged channels removed.
        
        freqs (array_like):
            Array of frequency values, in MHz. Flagged channels must have 
            already been removed.
            NOTE(review): the default `tau` is scaled by 1e3 (nanosec) but is 
            multiplied directly by `freqs` in the model; confirm the units 
            expected by callers.
        
        invcov (array_like, optional):
            Inverse of the covariance matrix (flagged channels must have been 
            removed before inverting). If `None`, the identity is used, i.e. 
            an unweighted least-squares fit.

        fit_amp_phase (bool, optional):
            If True, fits the (real) amplitude and (real) phase parameters 
            for each sinusoid. If False, fits the real and imaginary amplitudes.
        
        tau (array_like, optional):
            Array of tau modes to fit. If `None`, will use `fftfreq()` to 
            calculate the tau values. Units: nanosec.
        
        taper (array_like, optional):
            If specified, multiplies the data and sinusoid model by a taper 
            function to enforce periodicity. The taper should be evaluated 
            at the locations specified in `freqs`
        
        minimize_method (str, optional):
            Which SciPy minimisation method to use. Default: `'L-BFGS-B'`.
    
    Returns:
        tau (array_like):
            The tau modes that were fitted, in nanoseconds.
            
        param1, param2 (array_like):
            If `fit_amp_phase` is True, these are the best-fit amplitude and 
            phase of the sinusoids. Otherwise, they are the real and imaginary 
            amplitudes of the sinusoids.
    """
    # Default to an unweighted fit; previously invcov=None failed inside the 
    # shape check with an AttributeError.
    if invcov is None:
        invcov = np.eye(freqs.size)
    
    assert d.size == invcov.shape[0] == invcov.shape[1] == freqs.size, \
        "Data, inv. covariance, and freqs array must have same number of channels"
    
    # Calculate tau values
    if tau is None:
        tau = np.fft.fftfreq(n=freqs.size, d=freqs[1]-freqs[0]) * 1e3 # nanosec
    
    # Taper
    if taper is None:
        taper = 1.
    else:
        assert taper.size == freqs.size, \
            "'taper' must be evaluated at locations given in 'freqs'"
    
    # Objective for mode n (half the chi-squared of the tapered residual)
    def loglike(p, n):
        if fit_amp_phase:
            m = model_ap(amp=p[0], phase=p[1], tau=tau[n], freqs=freqs)
        else:
            m = model_aa(A_re=p[0], A_im=p[1], tau=tau[n], freqs=freqs)
        
        # Calculate residual and log-likelihood
        x = taper * (d - m)
        logl = 0.5 * np.dot(x.conj(), np.dot(invcov, x))
        return logl.real # Result should be real
    
    # Bounds and rough initial guess depend only on the data, so they are 
    # computed once rather than inside the loop over modes.
    max_abs = np.max(np.abs(d))
    if fit_amp_phase:
        bounds = [(-100.*max_abs, 100.*max_abs), (0., 2.*np.pi)]
        guess = (0.2 * max_abs, 0.5 * np.pi)
    else:
        bounds = [(-1000.*max_abs, 1000.*max_abs), (-1000.*max_abs, 1000.*max_abs)]
        guess = (0.2 * np.max(d.real), 0.2 * np.max(d.imag))
    
    # Do least-squares fit for each tau
    param1 = np.zeros(tau.size)
    param2 = np.zeros(tau.size)
    
    for n in range(tau.size):
        p0 = np.array(guess, dtype=float)
        
        # Least-squares fit for mode n
        result = minimize(loglike, p0, args=(n,), 
                          method=minimize_method, 
                          bounds=bounds)
        param1[n], param2[n] = result.x
    
    return tau, param1, param2

In [86]:
def trim_flagged_channels(w, x):
    """
    Drop flagged channels from a 1D array, or flagged rows and columns from a 
    square 2D array. Required pre-processing step for LSSA.

    Parameters:
        w (array_like):
            1D mask, 1 for unmasked channels and 0 for masked ones.
        
        x (array_like):
            1D array of length `w.size`, or 2D array of shape 
            `(w.size, w.size)`.

    Returns:
        xtilde (array_like):
            `x` restricted to the channels where `w == 1`.
    """
    # Check inputs
    shape_ok = np.shape(x) == (w.size,) or np.shape(x) == (w.size, w.size)
    assert shape_ok, \
        "Input array must have shape (w.size) or (w.size, w.size)"

    keep = (w == 1.)

    if x.ndim == 1:
        return x[keep]

    # 2D case: select the kept rows and kept columns together
    return x[np.ix_(keep, keep)]

In [87]:
from scipy.signal.windows import dpss, kaiser

def dpss_fit_modes(d, w, freqs, cov, nmodes=10, alpha=1.,
                   minimize_method='L-BFGS-B', taper=None):
    r"""
    Fit complex amplitudes of DPSS basis functions to masked complex 1D data.
    
    All modes are fitted simultaneously by minimising
    
    $\frac{1}{2} x^\dagger C^{-1} x$, with 
    $x = t\, w\, [d - \sum_k (a_k + i b_k) f_k(\nu)]$,
    
    where $f_k$ are the DPSS functions, $w$ the flag array, $t$ the optional 
    taper and $C$ the covariance `cov`. Flagged channels are zeroed in the 
    residual through `w` rather than removed.
    
    Parameters:
        d (array_like):
            Complex data array, one value per entry of `freqs`.
        
        w (array_like):
            Flag array with the same number of channels as `d`; it 
            multiplies the residual (presumably 1 for unflagged and 0 for 
            flagged channels -- confirm against callers).
        
        freqs (array_like):
            Array of frequency values. Only its size is used here.
        
        cov (array_like):
            Covariance matrix model; it is inverted internally.
        
        nmodes (int, optional):
            Number of DPSS modes to fit.
        
        alpha (float, optional):
            Passed to `scipy.signal.windows.dpss` as `NW` (standardised 
            half-bandwidth).
        
        taper (array_like, optional):
            If specified, multiplies the residual (data minus model). Must 
            be evaluated at the locations specified in `freqs`.
        
        minimize_method (str, optional):
            Which SciPy minimisation method to use. Default: `'L-BFGS-B'`.
        
    Returns:
        dpss_modes (array_like):
            DPSS basis functions, shape (nmodes, nfreqs).
        
        amps (array_like):
            Best-fit coefficients, length `2*nmodes`, with real and imaginary 
            parts interleaved: `[re_0, im_0, re_1, im_1, ...]`.
    """
    n_chan = freqs.size
    assert d.size == cov.shape[0] == cov.shape[1] == freqs.size == w.size, \
        "Data, flags, covariance, and freqs arrays must have same number of channels"
    
    if taper is not None:
        assert taper.size == freqs.size, \
            "'taper' must be evaluated at locations given in 'freqs'"
    else:
        taper = 1.
    
    # Basis functions, shape: (nmodes, nfreqs)
    dpss_modes = dpss(n_chan, NW=alpha, Kmax=nmodes, sym=False)
    
    invcov = np.linalg.inv(cov)
    
    def loglike(p):
        # Even entries are real coefficients, odd entries imaginary ones
        re_part = p[0::2,np.newaxis]*dpss_modes[:,:]
        im_part = 1.j*p[1::2,np.newaxis]*dpss_modes[:,:]
        model = np.sum(re_part + im_part, axis=0)
        
        # Masked, tapered residual and its quadratic form
        resid = taper * w * (d - model)
        value = 0.5 * np.dot(resid.conj(), np.dot(invcov, resid))
        return value.real # Result should be real
    
    # Joint fit for all modes, starting from zero amplitudes
    start = np.zeros(2*nmodes)
    result = minimize(loglike, start, 
                      method=minimize_method, 
                      bounds=None)
    return dpss_modes, result.x