Adding the option to measure connectivity over time to spectral_connectivity() #17

Closed
Avoide opened this issue Jan 11, 2021 · 7 comments · Fixed by #104

Comments

@Avoide
Contributor

Avoide commented Jan 11, 2021

Describe the new feature or enhancement

Hi all,

mne-tools/mne-python#7937 reported an apparent error in the PLI implementation, and I ran into the same problem when working with epoched resting-state data. After looking into it further, I found that it is not really an error but a difference between two ways of analyzing connectivity: over time or over trials.

Mike X Cohen explains the difference quite nicely in his video here:
Youtube Link

I also commented in that issue, but I am now creating a standalone feature request.
The current PLI and PLV implementations in spectral_connectivity() measure connectivity over trials, which is suited for ERP data. For resting-state data, however, it does not make sense to analyze connectivity over trials; it should instead be computed over time.
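To make the distinction concrete, here is a minimal NumPy sketch of PLV computed along the two axes. The helper names `plv_over_trials` and `plv_over_time` are only illustrative and are not part of MNE; the phases are synthetic 10 Hz oscillations whose lag is constant within an epoch but random across epochs.

```python
import numpy as np


def plv_over_trials(phase_a, phase_b):
    """PLV across epochs at each time point. Inputs: (n_epochs, n_times)."""
    return np.abs(np.mean(np.exp(1j * (phase_a - phase_b)), axis=0))


def plv_over_time(phase_a, phase_b):
    """PLV across time points within each epoch. Inputs: (n_epochs, n_times)."""
    return np.abs(np.mean(np.exp(1j * (phase_a - phase_b)), axis=1))


rng = np.random.default_rng(0)
n_epochs, n_times, sfreq = 50, 500, 250.0
t = np.arange(n_times) / sfreq

# Channel a: the same 10 Hz phase ramp in every epoch
phase_a = np.tile(2 * np.pi * 10 * t, (n_epochs, 1))
# Channel b: constant lag within an epoch, but a new random lag per epoch
lag = rng.uniform(0, 2 * np.pi, size=(n_epochs, 1))
phase_b = phase_a + lag

plv_trials = plv_over_trials(phase_a, phase_b)  # shape (n_times,)
plv_time = plv_over_time(phase_a, phase_b)      # shape (n_epochs,)
```

In this toy case the over-time PLV is 1 in every epoch, because the lag never changes within an epoch, while the over-trials PLV is low, because the lag is not consistent from epoch to epoch.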

Describe your proposed implementation

I tried implementing my own version to calculate PLV and wPLI over time and you can see a simple working example here:

# Libraries
import numpy as np
import mne
from scipy import special
from mne.connectivity import spectral_connectivity

# Generate data
np.random.seed(42)
n_epochs = 5
n_channels = 3
n_times = 1000
data = np.random.rand(n_epochs, n_channels, n_times)

# Set sampling freq
sfreq = 250 # A reasonable random choice

# Choose what kind of data should be used for evaluation
# 0 = same trial repeated in all epochs
# 1 = 10 Hz sine waves with random phase differences in each channel and epoch
data_option = 1

if data_option == 0:
    # Make all 5 epochs the same trial to show difference between connectivity
    # over time and over trials. Here we expect con over trials = 1
    for i in range(n_epochs):
        data[i] = data[0]
elif data_option == 1:
    # Generate 10 Hz sine waves to show difference between connectivity
    # over time and over trials. Here we expect con over time = 1
    for i in range(n_epochs):
        for c in range(n_channels):
            wave_freq = 10
            epoch_len = n_times/sfreq
            # Introduce random phase for each channel
            phase = np.random.rand(1)*10
            # Generate sine wave
            x = np.linspace(-wave_freq*epoch_len*np.pi+phase,
                            wave_freq*epoch_len*np.pi+phase,n_times)
            data[i,c] = np.squeeze(np.sin(x))
else:
    # Stop here instead of silently continuing with the random data
    raise ValueError("Data_option value chosen is invalid")

# Define freq bands
Freq_Bands = {"delta": [1.25, 4.0],
              "theta": [4.0, 8.0],
              "alpha": [8.0, 13.0],
              "beta": [13.0, 30.0],
              "gamma": [30.0, 49.0]}
n_freq_bands = len(Freq_Bands)
# Convert to tuples for the mne function
fmin = tuple(band[0] for band in Freq_Bands.values())
fmax = tuple(band[1] for band in Freq_Bands.values())

# Connectivity methods
connectivity_methods = ["plv","wpli"]
n_con_methods=len(connectivity_methods)

# Number of pairwise ch connections
# (only the `special` submodule was imported above, so call it directly)
n_ch_connections = special.comb(n_channels,2, exact=True, repetition=False)

# Pre-allocate memory with a nan matrix, as 0 is a meaningful value
con_data = np.full((n_con_methods,n_channels,n_channels,n_freq_bands), np.nan)

# Calculate PLV and wPLI - the MNE python implementation is over trials
con, freqs, times, n_epochs, n_tapers = spectral_connectivity(
    data, method=connectivity_methods,
    mode="multitaper", sfreq=sfreq, fmin=fmin, fmax=fmax,
    faverage=True, verbose=0)
# Save the results in array
con_data[0,:,:,:] = con[0] # PLV
con_data[1,:,:,:] = con[1] # WPLI

print("Alpha PLV over trials")
print(con_data[0,:,:,2]) # all 1 only when data_option = 0, since every
# epoch is then the same trial and PLV across trials should be 1

# PLV and wPLI across time
# Make linspace array for morlet waves
freq_centers = np.arange(fmin[0],fmax[-1]+0.25,0.25)
# Prepare Morlets
morlets = mne.time_frequency.tfr.morlet(sfreq,freq_centers,n_cycles=3)

# Make freqs array for indexing
freqs0 = [0]*n_freq_bands
for f in range(n_freq_bands):
    freqs0[f] = freq_centers[(freq_centers>=fmin[f]) & (freq_centers<=fmax[f])]

def calculate_PLV_WPLI_across_time(cwt_data):
    """

    Parameters
    ----------
    cwt_data : array, shape(n_channels, n_freq, n_times)
        The continuous wavelet transform of the data in one epoch

    Returns
    -------
    con_array : array, shape(2, n_channels, n_channels, n_freq_bands)
        The connectivity matrix. Only the lower diagonal is calculated.
        First axis indicates whether it is PLV (0) or wPLI(1)

    """
    n_ch, n_freq0, n_time0 = cwt_data.shape
    # The wavelet transform coefficients are complex numbers: the magnitude
    # gives the amplitude and the angle (argument) gives the phase
    angles_data = np.angle(cwt_data)
    # Prepare array with phase differences between all combinations
    phase_diff_arr = np.zeros((n_ch_connections,n_freq0,n_time0))
    phase_diff_arr_ch_idx = [0]*n_ch_connections
    counter = 0
    for ch_c in range(n_ch):
        for ch_r in range(ch_c+1,n_ch): # only calculate lower diagonal
            phase_diff_arr[counter] = angles_data[ch_r]-angles_data[ch_c]
            phase_diff_arr_ch_idx[counter] = [ch_r,ch_c]
            counter += 1
            
    del angles_data # free up some memory
    # =========================================================================
    # PLV over time corresponds to the absolute value of the mean across time
    # of the unit-length phase-difference vectors. So PLV will be 1 if
    # the phases of 2 signals maintain a constant lag
    # In equational form: PLV = |1/N * sum(e^(i*(phase1-phase2)))|
    # =========================================================================
    # Convert phase difference to the imaginary exponent i*(phase1-phase2)
    phase_diff_arr_im = 1j*phase_diff_arr
    # Take the exponential
    expPhase_arr = np.exp(phase_diff_arr_im)
    # Take mean across time and then the absolute value
    meanexpPhase_arr = np.mean(expPhase_arr,axis=2)
    PLV_arr = np.abs(meanexpPhase_arr)
    
    del phase_diff_arr_im, expPhase_arr # free up some memory
    # =========================================================================
    # PLI over time corresponds to the absolute mean of the sign of the sine
    # of the relative phase differences. So PLI will be 1 if one signal is
    # always leading or lagging behind the other signal. But it is insensitive
    # to changes in relative phase, as long as it is the same signal that leads.
    # If 2 signals are almost in phase, they might shift between lead/lag
    # due to small fluctuations from noise. This would lead to unstable
    # estimation of "phase" synchronisation.
    # The wPLI tries to correct for this by weighting the PLI with the
    # magnitude of the lag, to attenuate noise sources giving rise to
    # near zero phase lag "synchronization"
    # In equational form, with d = phase1-phase2 (as implemented below):
    # WPLI = |E{|sin(d)|*sign(sin(d))}| / E{|sin(d)|}
    # =========================================================================
    # Calculate the magnitude of phase differences
    phase_diff_mag_arr = np.abs(np.sin(phase_diff_arr))
    # Calculate the signed phase difference (PLI)
    sign_phase_diff_arr = np.sign(np.sin(phase_diff_arr))
    # Calculate the numerator (abs and average across time)
    WPLI_numerator = np.abs(np.mean(phase_diff_mag_arr*sign_phase_diff_arr,axis=2))
    # Calculate denominator for normalization
    WPLI_denom = np.mean(phase_diff_mag_arr, axis=2)
    # Calculate WPLI
    WPLI_arr = WPLI_numerator/WPLI_denom
    
    del phase_diff_mag_arr, sign_phase_diff_arr, phase_diff_arr # free up some memory
    # Calculate mean for each freq band
    con_array0 = np.zeros((2,n_ch_connections,n_freq_bands))
    for f in range(n_freq_bands):
        freq_of_interest = freqs0[f] # find freqs in the freq band of interest
        freq_idx = [i in freq_of_interest for i in freq_centers] # get idx
        con_array0[0,:,f] = np.apply_along_axis(np.mean,1,PLV_arr[:,freq_idx])
        con_array0[1,:,f] = np.apply_along_axis(np.mean,1,WPLI_arr[:,freq_idx])
    
    # Save to final array with ch-ch connectivity in matrix form
    con_array = np.zeros((2,n_ch,n_ch,n_freq_bands))
    for com in range(n_ch_connections):
        ch_r = phase_diff_arr_ch_idx[com][0] # get row idx
        ch_c = phase_diff_arr_ch_idx[com][1] # get col idx
        con_array[0,ch_r,ch_c,:] = con_array0[0,com]
        con_array[1,ch_r,ch_c,:] = con_array0[1,com]

    return con_array

# Pre-allocate memory
con_data_time = np.zeros((n_con_methods,n_channels,n_channels,n_freq_bands))
con_data_time[con_data_time==0] = np.nan # nan matrix as 0 is meaningful

con_data1 = np.zeros((n_con_methods,n_epochs,n_channels,n_channels,n_freq_bands))

for epoch in range(n_epochs):
    # First the data in each epoch is retrieved
    temp_data = data[epoch]
    # Then continuous wavelet transform is used to decompose in time frequencies
    temp_data_cwt = mne.time_frequency.tfr.cwt(temp_data,morlets)
    # PLV and WPLI value is calculated across timepoints in each freq band and averaged into the 5 defined freq bands
    PLV_WPLI_con = calculate_PLV_WPLI_across_time(temp_data_cwt)
    # Save results
    con_data1[0,epoch,:,:,:] = PLV_WPLI_con[0] # phase locking value
    con_data1[1,epoch,:,:,:] = PLV_WPLI_con[1] # weighted phase lag index

# Take average across epochs for PLV and wPLI
con_data2 = np.mean(con_data1,axis=1)
# Save to final array
con_data_time[0,:,:,:] = con_data2[0] # phase locking value
con_data_time[1,:,:,:] = con_data2[1] # weighted phase lag index

print("Alpha PLV over time")
print(con_data_time[0,:,:,2])

Notice that I added an option to choose what kind of data is used. If data_option = 0, the expected output is that PLV across trials is 1, since I just repeated the first epoch in all other epochs. If data_option = 1, the PLV across time should be 1, as I use 10 Hz sine waves with a random phase in each channel and epoch, so the phase difference between two channels is constant within an epoch.
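As a quick sanity check of the data_option = 0 case, the sketch below repeats one random epoch and computes PLV over trials directly from Hilbert phases. This is only an illustration with plain NumPy/SciPy, not the multitaper estimate that spectral_connectivity() uses, and the phase of unfiltered noise is not physiologically meaningful; it only shows that identical epochs give identical phase differences.

```python
import numpy as np
from scipy.signal import hilbert

rng = np.random.default_rng(1)
n_epochs, n_channels, n_times = 5, 3, 1000

# One random epoch repeated n_epochs times: (n_epochs, n_channels, n_times)
one_epoch = rng.standard_normal((n_channels, n_times))
data = np.tile(one_epoch, (n_epochs, 1, 1))

# Instantaneous phase of every channel in every epoch
phase = np.angle(hilbert(data, axis=-1))

# Phase difference between channel 0 and channel 1, then PLV across epochs
dphi = phase[:, 0, :] - phase[:, 1, :]
plv_trials = np.abs(np.mean(np.exp(1j * dphi), axis=0))  # one value per time point
```

Because the phase difference at a given time point is the same in every epoch, the average over epochs is a unit vector and the over-trials PLV is 1 at every time point.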

It is still quite rough and not optimized but should give an idea of the procedure.
I am not completely certain about my implementation, but the wPLI is calculated by following equation 1 in [1] and weighting by the absolute magnitude of the phase differences as described in [2]. PLV was calculated according to the equation in [3]. According to mne-tools/mne-python#8305, this equation should be equivalent to what spectral_connectivity() already uses, but my over-time implementation gives a different PLV than spectral_connectivity().
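For reference, the two over-time quantities as they are computed in the example code can be written as follows, with $\phi_1(t)$ and $\phi_2(t)$ the instantaneous phases of the two channels, $\Delta\phi(t) = \phi_1(t) - \phi_2(t)$, and $N$ the number of time points in the epoch. Note that the wPLI below is the phase-only form used in the code, where $|\sin\Delta\phi| \cdot \mathrm{sign}(\sin\Delta\phi)$ simplifies to $\sin\Delta\phi$. The definition in [2] is stated on the imaginary part of the cross-spectrum, which also carries amplitude information, so the two need not give identical values.

```latex
\mathrm{PLV} = \left| \frac{1}{N} \sum_{t=1}^{N} e^{\,i\,\Delta\phi(t)} \right|
\qquad
\mathrm{wPLI} = \frac{\left| \frac{1}{N} \sum_{t=1}^{N} \sin \Delta\phi(t) \right|}
                     {\frac{1}{N} \sum_{t=1}^{N} \left| \sin \Delta\phi(t) \right|}
```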

It would be helpful if anyone could have a look. After optimizing it the feature should probably just be added as an option to spectral_connectivity() and does not need its own version.

Describe possible alternatives

I don't know of any alternatives, as this suggestion is just about a slightly different way of calculating connectivity: over time instead of over trials.

[1] Hardmeier, Martin, Florian Hatz, Habib Bousleiman, Christian Schindler, Cornelis Jan Stam, and Peter Fuhr. 2014. “Reproducibility of Functional Connectivity and Graph Measures Based on the Phase Lag Index (PLI) and Weighted Phase Lag Index (WPLI) Derived from High Resolution EEG.” PLoS ONE 9 (10). https://doi.org/10.1371/journal.pone.0108648.
[2] Vinck, Martin, Robert Oostenveld, Marijn Van Wingerden, Francesco Battaglia, and Cyriel M.A. Pennartz. 2011. “An Improved Index of Phase-Synchronization for Electrophysiological Data in the Presence of Volume-Conduction, Noise and Sample-Size Bias.” NeuroImage 55 (4): 1548–65. https://doi.org/10.1016/j.neuroimage.2011.01.055.
[3] Lachaux, Jean Philippe, Eugenio Rodriguez, Jacques Martinerie, and Francisco J. Varela. 1999. “Measuring Phase Synchrony in Brain Signals.” Human Brain Mapping 8 (4): 194–208. https://doi.org/10.1002/(SICI)1097-0193(1999)8:4<194::AID-HBM4>3.0.CO;2-C.

@mmagnuski
Member

@Avoide I updated your code example to include syntax coloring. I'll try to add to the discussion on Monday!

@balandongiv

balandongiv commented Jul 7, 2021

Hi @larsoner / @mmagnuski, may I know whether the OP's suggestion of calculating connectivity over time has been incorporated in a recent MNE version?

@mmagnuski
Member

No, but the connectivity measures have now been moved to a separate repository, mne-connectivity. It would be good to discuss this issue there once the transition is over.

@larsoner larsoner transferred this issue from mne-tools/mne-python Jul 7, 2021
@Avoide
Contributor Author

Avoide commented Aug 19, 2021

I am still available to discuss this issue, although someone more experienced in programming is probably needed for the actual incorporation. I'm still quite new to Python and haven't used GitHub before, so I didn't start a pull request.
The simple code I wrote seemed to work when I tested it on a few simple cases, although it is significantly slower than the current spectral_connectivity() function for calculating connectivity over trials.
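Part of the slowness probably comes from the explicit loops over channel pairs. As a rough sketch of how that step could be vectorized with NumPy broadcasting (the function name `plv_wpli_over_time` is hypothetical, and the wPLI here is the same phase-only form as in my example, not necessarily what a library implementation would use):

```python
import numpy as np


def plv_wpli_over_time(phase):
    """Compute PLV and phase-only wPLI over time for all channel pairs.

    phase : array, shape (n_channels, n_times), instantaneous phases.
    Returns two arrays of shape (n_channels, n_channels).
    """
    # Pairwise phase differences via broadcasting: (n_ch, n_ch, n_times)
    dphi = phase[:, np.newaxis, :] - phase[np.newaxis, :, :]
    plv = np.abs(np.mean(np.exp(1j * dphi), axis=-1))

    sin_dphi = np.sin(dphi)
    num = np.abs(np.mean(sin_dphi, axis=-1))
    den = np.mean(np.abs(sin_dphi), axis=-1)
    # The diagonal has a zero denominator; define wPLI as 0 there
    wpli = np.divide(num, den, out=np.zeros_like(num), where=den > 0)
    return plv, wpli


rng = np.random.default_rng(0)
phase = rng.uniform(-np.pi, np.pi, size=(3, 1000))
plv, wpli = plv_wpli_over_time(phase)
```

Both matrices come out symmetric, so the full matrix is computed rather than only the lower triangle; for a small number of channels the extra work is negligible compared to the loop overhead.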

@adam2392
Member

Hi @Avoide thanks for making this issue in the first place.

I haven't had a chance to look at your code in depth yet, but for what it's worth, I am open to trying to get this to work over time.

Your proposed implementation, if it works over time, doesn't need to fit into the existing spectral_connectivity function. We are thinking of separating the two approaches into one function that operates over trials (what is there now) and one that operates over time (hypothetically what you posted).

Let me know if you have time to start making a PR. Otherwise, I may not be able to fully implement this for at least a few weeks.

@adam2392
Member

Hi @Avoide the latest PR #67 adds the ability to compute spectral connectivity over time. Would you be able to check it out and see if this fits your needs? Thanks!

@Avoide
Contributor Author

Avoide commented Jan 12, 2022

@adam2392 I just tried using the function spectral_connectivity_time, but I am a bit unsure about the API.
Previously (using mne.connectivity.spectral_connectivity) I used fmin and fmax to specify the range, e.g. [8, 13] for the alpha range. How should that be translated to the freqs parameter? Is it the center frequency I should pass?
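One translation I can imagine, assuming (unconfirmed) that freqs expects the individual frequencies at which to estimate connectivity, is to expand each band into evenly spaced frequencies and average over them afterwards. The helper `band_to_freqs` and the 1 Hz step are my own illustration, not part of mne_connectivity:

```python
import numpy as np


def band_to_freqs(fmin, fmax, step=1.0):
    """Evenly spaced frequencies covering [fmin, fmax] (hypothetical helper)."""
    n = int(np.floor((fmax - fmin) / step + 1e-9)) + 1
    return fmin + step * np.arange(n)


alpha_freqs = band_to_freqs(8.0, 13.0)  # candidate value for `freqs`

# Averaging a per-frequency result back into one band value; the array is a
# random stand-in with shape (n_channels, n_channels, n_freqs)
rng = np.random.default_rng(0)
con_per_freq = rng.random((3, 3, alpha_freqs.size))
con_alpha = con_per_freq.mean(axis=-1)
```

This would mimic faverage=True from the old function, but whether it is the intended use of freqs is exactly what I am asking.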

After using the function on some dummy data with:
n_epochs = 5
n_channels = 3
n_times = 1000
I get an EpochSpectroTemporalConnectivity object with shape (n_epochs, n_channels, n_channels, n_freq, n_times).
But how should I interpret n_times with regard to, for instance, PLV? Is this the PLV at that time point between the two channels in that specific epoch and frequency? PLV should not be defined for a single time point, should it?

If I want to boil the information down to (n_channels, n_channels, n_freq), i.e. the same output as mne.connectivity.spectral_connectivity, can I just take np.mean over the first and last axes? When I do this, however, the diagonal for PLV is not one as I would expect.
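In terms of array shapes, the reduction I have in mind looks like the sketch below. The array is a random stand-in for what get_data() returns, so it only illustrates the axis handling, not whether averaging per-time-point values is a statistically valid way to summarize PLV:

```python
import numpy as np

n_epochs, n_channels, n_freqs, n_times = 5, 3, 4, 1000
rng = np.random.default_rng(0)

# Random stand-in with the reported shape
# (n_epochs, n_channels, n_channels, n_freq, n_times)
con = rng.random((n_epochs, n_channels, n_channels, n_freqs, n_times))

# Average over epochs (first axis) and time (last axis) in one call
con_avg = con.mean(axis=(0, -1))  # shape (n_channels, n_channels, n_freq)
```

Passing a tuple of axes to mean gives the same result as averaging over time first and over epochs second, since every epoch has the same number of time points.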

I also tried computing PLV on some randomly generated data and got very high PLV values despite averaging over 100 runs, but perhaps I am using the function incorrectly. I attached some code using random data, along with my own implementation for comparison.

# Libraries
import numpy as np
import mne
import scipy.signal  # import the submodule explicitly for scipy.signal.hilbert
from mne_connectivity import spectral_connectivity_time

# Generate data and try spectral_connectivity_time
n_repetitions = 100
n_epochs = 5
n_channels = 3
n_times = 1000
PLV = np.zeros((n_repetitions,n_channels,n_channels))
for rep in range(n_repetitions):
    data = np.random.rand(n_epochs, n_channels, n_times)
    
    # Set sampling freq
    sfreq = 250 # A reasonable random choice
    
    # Trying MNE_connectivity implementation
    # Create epochs object for compatibility
    ch_names = ["T1","T2","T3"]
    info = mne.create_info(ch_names, sfreq, ch_types="eeg")
    data_epoch = mne.EpochsArray(data,info)
    
    # PLV
    con3 = spectral_connectivity_time(data_epoch, method="plv",
        mode="multitaper", sfreq=sfreq, freqs=10)
    
    con3 = con3.get_data()
    # Avg over time
    con3 = np.mean(con3,axis=-1)
    # Avg over epochs
    con3 = np.mean(con3,axis=0)
    
    PLV[rep] = con3[:,:,0]

# Average over runs
print("PLV using spectral_connectivity_time over 100 runs with randomly generated data")
print(np.mean(PLV,axis=0))

# Try my implementation
Freq_Bands = {"alpha": [8.0, 12.0]}
n_freq_bands = len(Freq_Bands)
# Convert to tuples for the mne function
fmin = tuple(band[0] for band in Freq_Bands.values())
fmax = tuple(band[1] for band in Freq_Bands.values())
def calculate_PLV_WPLI_across_time(data):
    n_ch, n_time0 = data.shape
    x = data.copy()
    # Filter the signal in the different freq bands
    con_array0 = np.zeros((2,n_ch,n_ch,n_freq_bands))
    # con_array0[con_array0==0] = np.nan
    for fname, frange in Freq_Bands.items():
        fmin, fmax = [float(interval) for interval in frange]
        signal_filtered = mne.filter.filter_data(x, sfreq, fmin, fmax,
                                          fir_design="firwin", verbose=0)
        # Filtering on finite signals will yield very low values for first
        # and last timepoint, which can create outliers. E.g. 1e-29 compared to 1e-14
        # This systematic error is removed by removing the first and last timepoint
        signal_filtered = signal_filtered[:,1:-1]
        # Hilbert transform to get complex signal
        analytic_signal = scipy.signal.hilbert(signal_filtered)
        # Calculate for the lower diagonal only as it is symmetric
        for ch_r in range(n_ch):
            for ch_c in range(n_ch):
                if ch_r>ch_c:
                    # =========================================================================
                    # PLV over time corresponds to the absolute value of the mean across
                    # time of the unit-length phase-difference vectors. So PLV will be 1
                    # if the phases of 2 signals maintain a constant lag
                    # In equational form: PLV = |1/N * sum(e^(i*(phase1-phase2)))|
                    # In code: abs(mean(exp(1i*phase_diff)))
                    # =========================================================================
                    # The angle (argument) of the complex analytic signal gives the phase
                    phase_diff = np.angle(analytic_signal[ch_r])-np.angle(analytic_signal[ch_c])
                    # Convert phase difference to the imaginary exponent i*(phase1-phase2)
                    phase_diff_im = 1j*phase_diff
                    # Take the exponential, then the mean followed by absolute value
                    PLV = np.abs(np.mean(np.exp(phase_diff_im)))
                    # Save to array
                    con_array0[0,ch_r,ch_c,list(Freq_Bands.keys()).index(fname)] = PLV
                    # =========================================================================
                    # PLI over time correspond to the sign of the sine of relative phase
                    # differences. So PLI will be 1 if one signal is always leading or
                    # lagging behind the other signal. But it is insensitive to changes in
                    # relative phase, as long as it is the same signal that leads.
                    # If 2 signals are almost in phase, they might shift between lead/lag
                    # due to small fluctuations from noise. This would lead to unstable
                    # estimation of "phase" synchronisation.
                    # The wPLI tries to correct for this by weighting the PLI with the
                    # magnitude of the lag, to attenuate noise sources giving rise to
                    # near zero phase lag "synchronization"
                    # In equational form, with d = phase1-phase2 (as implemented below):
                    # WPLI = |E{|sin(d)|*sign(sin(d))}| / E{|sin(d)|}
                    # =========================================================================
                    # Calculate the magnitude of phase differences
                    phase_diff_mag = np.abs(np.sin(phase_diff))
                    # Calculate the signed phase difference (PLI)
                    sign_phase_diff = np.sign(np.sin(phase_diff))
                    # Calculate the numerator (abs and average across time)
                    WPLI_numerator = np.abs(np.mean(phase_diff_mag*sign_phase_diff))
                    # Calculate denominator for normalization
                    WPLI_denom = np.mean(phase_diff_mag)
                    # Calculate WPLI
                    WPLI = WPLI_numerator/WPLI_denom
                    # Save to array
                    con_array0[1,ch_r,ch_c,list(Freq_Bands.keys()).index(fname)] = WPLI
    return con_array0

n_repetitions = 100
n_epochs = 5
n_channels = 3
n_times = 1000
PLV = np.zeros((n_repetitions,n_channels,n_channels))
for rep in range(n_repetitions):
    data = np.random.rand(n_epochs, n_channels, n_times)
    
    # Set sampling freq
    sfreq = 250 # A reasonable random choice
    
    con4 = np.zeros((n_epochs, 2, n_channels,n_channels,n_freq_bands))
    for e in range(n_epochs):
        con4[e] = calculate_PLV_WPLI_across_time(data[e])
    # Avg over epochs
    con4 = np.mean(con4,axis=0)
    
    PLV[rep] = con4[0,:,:,0]

print("PLV using my implementation over 100 runs with randomly generated data")
print(np.mean(PLV,axis=0))
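One thing to keep in mind when reading the numbers above: PLV estimated over a finite window is biased above zero even for independent signals, and narrow-band filtering makes the phase change slowly, so a 4 s epoch contains relatively few independent phase samples. The sketch below illustrates this with plain SciPy, using a Butterworth band-pass as a stand-in for mne.filter.filter_data; the filter order and number of runs are arbitrary choices of mine.

```python
import numpy as np
from scipy.signal import butter, hilbert, sosfiltfilt

rng = np.random.default_rng(0)
sfreq, n_times, n_runs = 250.0, 1000, 200

# 8-12 Hz band-pass (4th-order Butterworth, chosen only for illustration)
sos = butter(4, [8.0, 12.0], btype="bandpass", fs=sfreq, output="sos")

plv = np.empty(n_runs)
for run in range(n_runs):
    # Two independent white-noise channels
    x = rng.standard_normal((2, n_times))
    analytic = hilbert(sosfiltfilt(sos, x, axis=-1), axis=-1)
    dphi = np.angle(analytic[0]) - np.angle(analytic[1])
    # PLV over time for this run
    plv[run] = np.abs(np.mean(np.exp(1j * dphi)))

mean_plv = plv.mean()
print("Mean PLV over time for independent alpha-band noise:", mean_plv)
```

Averaging over more runs does not remove this offset, because every single-run PLV is non-negative; only longer epochs or a wider band bring the expected value closer to zero.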
