# In [1]:
import numpy
import numpy as np

def spectral_interpolation_fast(timepoint_defined, signal, TR, ofac=8, hifac=1):
    """Replace undefined timepoints with Lomb-Scargle spectral interpolation.

    Based on MATLAB code from Anish Mitra / Jonathan Power. It is normally run
    on parcellated timeseries; running it vertex/voxel-wise over the whole
    brain can be slow. Lowering ``ofac`` from 8 to 4 speeds it up with only
    minimal degradation of interpolation accuracy (at least on the multiband
    sequences this has been tested on).

    Parameters
    ----------
    timepoint_defined : array-like of bool or 0/1, shape (num_timepoints,)
        Whether each timepoint is defined (truthy) or must be interpolated.
    signal : numpy.ndarray, shape (num_rois, num_timepoints) or (num_timepoints,)
        Timeseries to interpolate; a 1-D array is treated as a single ROI.
    TR : float
        Time between consecutive samples.
    ofac : float, optional
        Oversampling factor of the frequency grid (default 8; higher is slow
        without good quality improvements).
    hifac : float, optional
        Highest frequency as a multiple of ``N / (2 * T)`` (default 1).

    Returns
    -------
    numpy.ndarray
        A copy of ``signal`` (same shape) in which the undefined timepoints
        are replaced by interpolated values; defined timepoints are untouched.

    Raises
    ------
    ValueError
        If something must be interpolated but fewer than two timepoints are
        defined (the observed timespan would be zero).
    """
    if signal.ndim == 1:
        # Single timeseries: run it as a one-row matrix and unwrap the result.
        return spectral_interpolation_fast(
            timepoint_defined, signal[np.newaxis, :], TR, ofac=ofac, hifac=hifac
        )[0]

    defined = np.asarray(timepoint_defined).astype(bool)
    good_timepoint_inds = np.flatnonzero(defined)
    bad_timepoint_inds = np.flatnonzero(~defined)
    num_timepoints = defined.shape[0]
    signal_copy = signal.copy()

    if bad_timepoint_inds.size == 0:
        # Nothing to interpolate.
        return signal_copy
    if good_timepoint_inds.size < 2:
        raise ValueError(
            "spectral interpolation needs at least two defined timepoints"
        )

    TR = float(TR)
    ofac = float(ofac)
    hifac = float(hifac)

    t = TR * good_timepoint_inds          # times of the observed samples
    h = signal[:, good_timepoint_inds]    # observed samples, (num_rois, n_good)
    TH = np.linspace(0, (num_timepoints - 1) * TR, num=num_timepoints)

    # NOTE(review): N counts *all* timepoints, including undefined ones; a
    # Lomb-Scargle average Nyquist normally uses the number of observed samples
    # (len(t)). Kept as-is to preserve existing outputs -- verify against the
    # MATLAB reference.
    N = num_timepoints
    T = np.max(t) - np.min(t)             # total observed timespan

    # Sampling frequencies and angular frequencies.
    f_lo = 1 / (T * ofac)
    f_hi = hifac * N / (2 * T)
    f = np.linspace(f_lo, f_hi, num=int(f_hi / f_lo + 1))
    w = 2 * np.pi * f

    # Lomb-Scargle time offsets tau, one per frequency.
    wt = np.outer(w, t)                   # (n_freq, n_good)
    tau = np.arctan2(np.sin(2 * wt).sum(axis=1),
                     np.cos(2 * wt).sum(axis=1)) / (2 * w)

    # Cosine/sine basis at the observed times and their squared norms.
    phase = wt - (w * tau)[:, np.newaxis]
    cterm = np.cos(phase)
    sterm = np.sin(phase)
    cos_denominator = np.sum(cterm ** 2, axis=1)
    sin_denominator = np.sum(sterm ** 2, axis=1)

    # Basis for reconstructing the series at every timepoint.
    w_times_TH = np.outer(w, TH)          # (n_freq, num_timepoints)
    sin_t = np.sin(w_times_TH)
    cos_t = np.cos(w_times_TH)

    # NOTE(review): the data are not demeaned before the transform here;
    # confirm whether callers are expected to demean/detrend first.
    for i, row in enumerate(h):
        c = (cterm @ row) / cos_denominator
        s = (sterm @ row) / sin_denominator
        H = c @ cos_t + s @ sin_t

        # Rescale the reconstruction to this ROI's observed std (needed when
        # ofac > 1). Must be per-row: pooling across ROIs mixes their scales.
        norm_fac = np.std(H) / np.std(row)
        signal_copy[i, bad_timepoint_inds] = H[bad_timepoint_inds] / norm_fac

    return signal_copy