In [1]:
import os
import mne
from mne.preprocessing import (ICA, create_eog_epochs, create_ecg_epochs,
                               corrmap)

eeglab_file = os.path.join('eeglab2021.1', 'sample_data', 'eeglab_data_epochs_ica.set')

epochs = mne.io.read_epochs_eeglab(eeglab_file)

Extracting parameters from /Users/jacobfeitelberg/Desktop/sarma/iclabel-python/eeglab2021.1/sample_data/eeglab_data_epochs_ica.set...
Not setting metadata
Not setting metadata
80 matching events found
No baseline correction applied
0 projection items activated
Ready.


  epochs = mne.io.read_epochs_eeglab(eeglab_file)
['EOG1', 'EOG2']
Consider setting the channel types to be of EEG/sEEG/ECoG/DBS/fNIRS using inst.set_channel_types before calling inst.set_montage, or omit these channels when creating your montage.
  epochs = mne.io.read_epochs_eeglab(eeglab_file)


In [2]:
epochs.get_data().shape

(80, 32, 384)

In [3]:
ica = ICA(n_components=None, max_iter='auto', random_state=97, method='infomax')
ica.fit(epochs)
ica

Fitting ICA to data using 30 channels (please be patient, this may take a while)


  ica.fit(epochs)


Selecting by non-zero PCA components: 30 components
 
Fitting ICA took 6.2s.


0,1
Method,infomax
Fit,500 iterations on epochs (30720 samples)
ICA components,30
Explained variance,100.0 %
Available PCA components,30
Channel types,eeg
ICA components marked for exclusion,—


In [4]:
import numpy as np

icaact = ica.get_sources(epochs).get_data()
icaact = np.transpose(icaact, [1,2,0])

# weights (unmixing matrix)
icaweights = ica.unmixing_matrix_

icawinv = np.linalg.pinv(ica.unmixing_matrix_ @ ica.pca_components_.T)

srate = 128
pnts = 384
trials = 80

In [5]:
from eeg_features import eeg_autocorr_fftw, eeg_features, eeg_rpsd, eeg_topoplot

In [10]:
autocorr = eeg_autocorr_fftw(icaact, trials, srate, pnts=pnts)
psds = eeg_rpsd(icaact=icaact, icaweights=icaweights, pnts=pnts, srate=srate, trials=trials)

In [6]:
def mne_to_eeglab_locs(raw):
    def sph2topo(theta,phi):
        az = phi
        horiz = theta
        angle  = -1*horiz
        radius = (np.pi/2 - az)/np.pi
        return angle,radius

    def cart2sph(x,y,z):
        azimuth = np.arctan2(y,x)
        elevation = np.arctan2(z,np.sqrt(x**2 + y**2))
        r = np.sqrt(x**2 + y**2 + z**2)
        # theta,phi,r
        return azimuth, elevation, r
    
    locs = raw._get_channel_positions()

    #%% Obtain carthesian coordinates
    X = locs[:,1]
    Y = -1*locs[:,0] # be mindful of the nose orientation in eeglab and mne 
    # see https://github.com/mne-tools/mne-python/blob/24377ad3200b6099ed47576e9cf8b27578d571ef/mne/io/eeglab/eeglab.py#L105
    Z = locs[:,2]

    #%% Obtain Spherical Coordinates 
    sph = np.array([cart2sph(X[i],Y[i],Z[i]) for i in range(len(X))])
    theta = sph[:,0]
    phi = sph[:,1]

    #%% Obtain Polar coordinates (as in eeglab)
    topo = np.array([sph2topo(theta[i],phi[i]) for i in range(len(theta))])
    Rd = topo[:,1]
    Th = topo[:,0]
    
    return Rd.reshape([1,-1]), np.degrees(Th).reshape([1,-1])

In [7]:
Rd, Th = mne_to_eeglab_locs(epochs)

In [15]:
import scipy.io as sio

In [48]:
topoplot_data = sio.loadmat('test_data/topoplot_data.mat')

# Inputs
# icawinv = topoplot_data['icawinv']
rd = topoplot_data['Rd']
th = topoplot_data['Th']
plotchans = topoplot_data['plotchans']

In [72]:
import warnings

def pol2cart(theta: np.array, rho: np.array) -> tuple[np.array, np.array]:
    """
    Converts polar coordinates to cartesian coordinates.

    Args:
        theta (np.array): angle
        rho (np.array): magnitude
    """
    x = rho * np.cos(theta)
    y = rho * np.sin(theta)
    return x, y


def mergesimpts(data: np.array, tols: list[np.array, np.array, np.array], mode: str = 'average') -> np.array:
    """

    Args:
        data (np.array): [description]
        tols (list[np.array, np.array, np.array]): [description]
        mode (str, optional): [description]. Defaults to 'average'.

    Returns:
        np.array: [description]
    """
    data_ = data.copy()[np.argsort(data[:, 0])]
    newdata = []
    tols_ = np.array(tols)
    idxs_ready = []
    point = 0
    for point in range(data_.shape[0]):
        if point in idxs_ready:
            continue
        else:
            similar_pts = np.where(np.prod(np.abs(data_ - data_[point]) < tols_, axis=-1))
            similar_pts = np.array(list(set(similar_pts[0].tolist()) - set(idxs_ready)))
            idxs_ready += similar_pts.tolist()
            if mode == 'average':
                exemplar = np.mean(data_[similar_pts], axis=0)
            else:
                exemplar = data_[similar_pts].copy()[0]  # first
            newdata.append(exemplar)
    return np.array(newdata)


def mergepoints2D(x: np.array, y: np.array, v: np.array) -> tuple[np.array, np.array, np.array]:
    """
    Averages values for points that are close to each other.

    Args:
        x (np.array): x-coordinates
        y (np.array): y-coordinates
        v (np.array): values

    Returns:
        tuple[np.array, np.array, np.array]: [description]
    """
    # Sort x and y so duplicate points can be averaged
    # Need x,y and z to be column vectors
    sz = x.size
    x = x.copy()
    y = y.copy()
    v = v.copy()
    x = np.reshape(x, sz, order='F')
    y = np.reshape(y, sz, order='F')
    v = np.reshape(v, sz, order='F')

    myepsx = np.spacing(0.5 * (np.max(x) - np.min(x))) ** (1 / 3)
    myepsy = np.spacing(0.5 * (np.max(y) - np.min(y))) ** (1 / 3)
    # Look for x, y points that are indentical (within a tolerance)
    # Average out the values for these points
    if np.all(np.isreal(v)):
        data = np.stack((y, x, v), axis=-1)
        yxv = mergesimpts(data, [myepsy, myepsx, np.inf], 'average')
        x = yxv[:, 1]
        y = yxv[:, 0]
        v = yxv[:, 2]
    else:
        # If z is imaginary split out the real and imaginary parts
        data = np.stack((y, x, np.real(v), np.imag(v)), axis=-1)
        yxv = mergesimpts(data, [myepsy, myepsx, np.inf, np.inf], 'average')
        x = yxv[:, 1]
        y = yxv[:, 0]
        # Re-combine the real and imaginary parts
        v = yxv[:, 2] + 1j * yxv[:, 3]
    # Give a warning if some of the points were duplicates (and averaged out)
    # if sz > x.shape[0]:
    #     print('MATLAB:griddata:DuplicateDataPoints')
    return x, y, v


def gdatav4(x: np.array, y: np.array, v: np.array, xq: np.array, yq: np.array) -> tuple[np.array, np.array, np.array]:
    """
    GDATAV4 MATLAB 4 GRIDDATA interpolation
    Reference:  David T. Sandwell, Biharmonic spline
    interpolation of GEOS-3 and SEASAT altimeter
    data, Geophysical Research Letters, 2, 139-142,
    1987.  Describes interpolation using value or
    gradient of value in any dimension.
    
    Args:
        x (np.array): x-coordinates
        y (np.array): y-coordinates
        v (np.array): values
        xq (np.array): x-grid
        yq (np.array): y-grid

    Returns:
        tuple[np.array, np.array, np.array]: tuple of Xi, Yi, Zi 
    """

    x, y, v = mergepoints2D(x, y, v)

    xy = x + 1j * y
    xy = np.squeeze(xy)

    # Determine distances between points
    d = np.abs(np.subtract.outer(xy, xy))
    # % Determine weights for interpolation
    g = np.square(d) * (np.log(d) - 1)  # % Green's function.
    # Fixup value of Green's function along diagonal
    np.fill_diagonal(g, 0)
    weights = np.linalg.lstsq(g, v)[0]

    m, n = xq.shape
    vq = np.zeros(xq.shape)

    # Evaluate at requested points (xq,yq).  Loop to save memory.
    for i in range(m):
        for j in range(n):
            d = np.abs(xq[i, j] + 1j * yq[i, j] - xy)
            g = np.square(d) * (np.log(d) - 1)
            # Value of Green's function at zero
            g[np.where(np.isclose(d, 0))] = 0
            vq[i, j] = (np.expand_dims(g, axis=0) @ np.expand_dims(weights, axis=1))[0][0]
    return xq, yq, vq


def eeg_topoplot(icawinv: np.array, Th: np.array, Rd: np.array, plotchans: np.array = None) -> np.array_equal:
    """
    Generates topoplot image for ICLabel

    Args:
        icawinv (np.array): pinv(EEG.icaweights*EEG.icasphere);
        Th (np.array): Theta coordinates of electrodes (polar)
        Rd (np.array): Rho coordinates of electrodes (polar)
        plotchans (np.array): plot channels

    Returns:
        np.array_equal: Heatmap values (32 x 32 image)
    """
    GRID_SCALE = 32
    RMAX = 0.5

    Th = Th * np.pi / 180
    allchansind = np.array(list(range(Th.shape[1])))
    intchans = np.arange(0, len(plotchans))
    x, y = pol2cart(Th, Rd)
    # allchansind = allchansind[plotchans]
    
    Rd = Rd[:, plotchans]
    x = x[:, plotchans]
    y = y[:, plotchans]

    intx = x[:, intchans]
    inty = y[:, intchans]
    icawinv = icawinv[plotchans]
    intValues = icawinv[intchans]

    plotrad = min(1.0, np.max(Rd) * 1.02)

    # Squeeze channel locations to <= RMAX
    squeezefac = RMAX / plotrad
    inty *= squeezefac
    intx *= squeezefac

    xi = np.linspace(-0.5, 0.5, GRID_SCALE)
    yi = np.linspace(-0.5, 0.5, GRID_SCALE)

    XQ, YQ = np.meshgrid(xi, yi)
    
    # Do interpolation with v4 scheme from MATLAB
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        Xi, Yi, Zi = gdatav4(inty, intx, intValues, YQ, XQ)

    mask = np.sqrt(np.power(Xi, 2) + np.power(Yi, 2)) > RMAX

    Zi[mask] = np.nan

    return Zi.T

In [98]:
i = 10
plotchans = np.squeeze(np.argwhere(~np.isnan(np.squeeze(Th))))
topo = eeg_topoplot(icawinv = icawinv[:, i:i+1], Th=Th, Rd=Rd, plotchans = plotchans)

In [8]:
trials

80

In [9]:
features = eeg_features(icaact = icaact, 
                        trials = trials, 
                        srate = srate, 
                        pnts=pnts, 
                        subset = None,
                        icaweights = icaweights,
                        icawinv = icawinv,
                        Th = Th,
                        Rd = Rd)

In [17]:
topo = features[0].astype(np.float32)
psds = features[1].astype(np.float32)
autocorr = features[2].astype(np.float32)

  autocorr = features[2].astype(np.float32)


In [11]:
from ICLabel import run_iclabel

In [18]:
print(topo.shape, psds.shape, autocorr.shape)
print(topo.dtype, psds.dtype, autocorr.dtype)

(32, 32, 1, 30) (1, 100, 1, 30) (1, 100, 1, 30)
float32 float32 float32


In [20]:
labels = run_iclabel(topo, psds, autocorr)

In [22]:
# Print out
np.set_printoptions(precision=4)
np.set_printoptions(suppress=True)

In [23]:
print(labels)

[[0.5302 0.001  0.032  0.0015 0.0273 0.0551 0.353 ]
 [0.425  0.0009 0.0089 0.0006 0.0691 0.0273 0.4683]
 [0.2554 0.0001 0.0077 0.003  0.2128 0.0096 0.5113]
 [0.3174 0.002  0.0597 0.0033 0.2016 0.0884 0.3275]
 [0.1188 0.0014 0.0022 0.0071 0.5116 0.0388 0.32  ]
 [0.1053 0.0006 0.0016 0.0015 0.4571 0.0065 0.4273]
 [0.2314 0.0004 0.1386 0.0152 0.0548 0.043  0.5165]
 [0.0536 0.0033 0.0844 0.0277 0.5242 0.0086 0.2982]
 [0.341  0.0001 0.023  0.0053 0.089  0.0181 0.5236]
 [0.1197 0.0002 0.0768 0.0038 0.1132 0.0152 0.6712]
 [0.6001 0.001  0.0084 0.0045 0.0487 0.1488 0.1885]
 [0.3207 0.0006 0.0004 0.002  0.3334 0.0035 0.3395]
 [0.049  0.0003 0.0021 0.0019 0.2206 0.0037 0.7224]
 [0.1791 0.0007 0.021  0.0009 0.06   0.0099 0.7284]
 [0.2751 0.0003 0.0142 0.0183 0.2983 0.0146 0.3792]
 [0.5767 0.0001 0.0021 0.0004 0.4122 0.0006 0.008 ]
 [0.2947 0.0004 0.0034 0.0068 0.2957 0.0007 0.3984]
 [0.7006 0.0011 0.007  0.0021 0.1637 0.0034 0.1222]
 [0.8109 0.001  0.004  0.0011 0.04   0.007  0.1361]
 [0.4687 0.0