Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: single class to represent multidimensional data #352

Closed
wants to merge 12 commits into from
11 changes: 11 additions & 0 deletions examples/ndvar/load_stc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from mne import ndvar
reload(mne.dimensions)
reload(ndvar)

fname = '/Users/christian/Documents/Eclipse/projects/mne-python/examples/MNE-sample-data/MEG/sample/sample_audvis-meg'
stc = mne.read_source_estimate(fname)

stcs = ndvar.from_stc([stc for _ in xrange(10)])
stc = ndvar.from_stc(stc)

stc.subdata(time=(0.1, 0.2))
79 changes: 79 additions & 0 deletions examples/ndvar/plot_label_stcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import numpy as np
import pylab as pl
import mne
from mne.datasets import sample
from mne.fiff import Raw, pick_types
from mne.minimum_norm import apply_inverse_epochs, read_inverse_operator


data_path = sample.data_path('..')
fname_inv = data_path + '/MEG/sample/sample_audvis-meg-oct-6-meg-inv.fif'
fname_raw = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw.fif'
fname_event = data_path + '/MEG/sample/sample_audvis_filt-0-40_raw-eve.fif'
label_name = 'Aud-lh'
fname_label = data_path + '/MEG/sample/labels/%s.label' % label_name

event_id, tmin, tmax = 1, -0.2, 0.5
snr = 1.0 # use smaller SNR for raw data
lambda2 = 1.0 / snr ** 2
method = "dSPM" # use dSPM method (could also be MNE or sLORETA)

# Load data
inverse_operator = read_inverse_operator(fname_inv)
label = mne.read_label(fname_label)
raw = Raw(fname_raw)
events = mne.read_events(fname_event)

# Set up pick list
include = []
exclude = raw.info['bads'] + ['EEG 053'] # bads + 1 more

# pick MEG channels
picks = pick_types(raw.info, meg=True, eeg=False, stim=False, eog=True,
include=include, exclude=exclude)
# Read epochs
epochs = mne.Epochs(raw, events, event_id, tmin, tmax, picks=picks,
baseline=(None, 0), reject=dict(mag=4e-12, grad=4000e-13,
eog=150e-6))

# Compute inverse solution and stcs for each epoch
stcs = apply_inverse_epochs(epochs, inverse_operator, lambda2, method, label,
pick_normal=True)


##############################################################################
from mne import ndvar
reload(mne.dimensions)
reload(ndvar)
Y = ndvar.from_stc(stcs)
Y = ndvar.resample(Y, 50)

pl.figure()
# extract the mean in a label
Ylbl = Y.summary(source=label)

# plot all cases; don't worry about the axes in the object, just be explicit:
pl.plot(Ylbl.time.times, Ylbl.get_data(('time', 'case')), color=(.5, .5, .5))

# plot the mean across cases
Ylblm = Ylbl.summary('case')
pl.plot(Ylblm.time.times, Ylblm.get_data('time'), 'r-', linewidth=2)

# plot a spcific case
pl.plot(Ylblm.time.times, Ylbl[1].get_data('time'), color=(1, .5, 0))

# or write a plot-function
def plot_uts(Y, **kwargs):
if Y.has_case:
x = Y.get_data(('time', 'case'))
else:
x = Y.get_data(('time',))
pl.plot(Y.time.times, x, **kwargs)
pl.xlabel('Time (s)')

pl.figure()
plot_uts(Ylbl, color='gray')
plot_uts(Ylbl[1], color='orange')
plot_uts(Ylbl.summary('case'), color='red')

pl.show()
248 changes: 248 additions & 0 deletions mne/dimensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
'''
Created on Oct 25, 2012

@author: christian
'''
import numpy as np

from label import Label, BiHemiLabel



class DimensionMismatchError(Exception):
pass



class Dimension():
def _dimrepr_(self):
return repr(self.name)




class UTS(Dimension):
"""Dimension object for representing uniform time series

Special Indexing
----------------

(tstart, tstop) : tuple
Restrict the time to the indicated window (either end-point can be
None).

"""
name = 'time'
def __init__(self, tmin, tstep, nsteps):
self.nsteps = nsteps = int(nsteps)
self.times = np.arange(tmin, tmin + tstep * (nsteps + 1), tstep)
self.tmin = tmin
self.tstep = tstep

def __repr__(self):
return "UTS(%s, %s, %s)" % (self.tmin, self.tstep, self.nsteps)

def _dimrepr_(self):
tmax = self.times[-1]
sfreq = 1. / self.tstep
r = '%r: %.3f - %.3f s, %s Hz' % (self.name, self.tmin, tmax, sfreq)
return r

def __len__(self):
return len(self.times)

def __getitem__(self, index):
if isinstance(index, int):
return self.times[index]
elif isinstance(index, slice):
if index.start is None:
start = 0
else:
start = index.start

if index.stop is None:
stop = len(self)
else:
stop = index.stop

tmin = self.times[start]
nsteps = stop - start - 1

if index.step is None:
tstep = self.tstep
else:
tstep = self.tstep * index.step
else:
times = self.times[index]
tmin = times[0]
nsteps = len(times)
steps = np.unique(np.diff(times))
if len(steps) > 1:
raise NotImplementedError("non-uniform time series")
tstep = steps[0]

return UTS(tmin, tstep, nsteps)

def dimindex(self, arg):
if np.isscalar(arg):
i, _ = find_time_point(self.times, arg)
return i
if isinstance(arg, tuple) and len(arg) == 2:
tstart, tstop = arg
if tstart is None:
start = None
else:
start, _ = find_time_point(self.times, tstart)

if tstop is None:
stop = None
else:
stop, _ = find_time_point(self.times, tstop)

s = slice(start, stop)
return s
else:
return arg



class SourceSpace(Dimension):
name = 'source'
"""
Indexing
--------

besides numpy indexing, the following indexes are possible:

- mne Label objects
- 'lh' or 'rh' to select an entire hemisphere

"""
def __init__(self, vertno, subject='fsaverage'):
"""
vertno : list of array
The indices of the dipoles in the different source spaces.
Each array has shape [n_dipoles] for in each source space]
subject : str
The mri-subject (used to load brain).

"""
self.vertno = vertno
self.lh_vertno = vertno[0]
self.rh_vertno = vertno[1]
self.lh_n = len(self.lh_vertno)
self.rh_n = len(self.rh_vertno)
self.subject = subject

def __repr__(self):
return "<dim source_space: %i (lh), %i (rh)>" % (self.lh_n, self.rh_n)

def __len__(self):
return self.lh_n + self.rh_n

def __getitem__(self, index):
vert = np.hstack(self.vertno)
hemi = np.zeros(len(vert))
hemi[self.lh_n:] = 1

vert = vert[index]
hemi = hemi[index]

new_vert = (vert[hemi == 0], vert[hemi == 1])
dim = SourceSpace(new_vert, subject=self.subject)
return dim

def dimindex(self, obj):
if isinstance(obj, (Label, BiHemiLabel)):
return self.label_index(obj)
elif isinstance(obj, str):
if obj == 'lh':
if self.lh_n:
return slice(None, self.lh_n)
else:
raise IndexError("lh is empty")
if obj == 'rh':
if self.rh_n:
return slice(self.lh_n, None)
else:
raise IndexError("rh is empty")
else:
raise IndexError('%r' % obj)
else:
return obj

def _hemilabel_index(self, label):
if label.hemi == 'lh':
stc_vertices = self.vertno[0]
base = 0
else:
stc_vertices = self.vertno[1]
base = len(self.vertno[0])

idx = np.nonzero(map(label.vertices.__contains__, stc_vertices))[0]
return idx + base

def label_index(self, label):
"""Returns the index for a label

Parameters
----------
label : Label | BiHemiLabel
The label (as created for example by mne.read_label). If the label
does not match any sources in the SourceEstimate, a ValueError is
raised.
"""
if label.hemi == 'both':
lh_idx = self._hemilabel_index(label.lh)
rh_idx = self._hemilabel_index(label.rh)
idx = np.hstack((lh_idx, rh_idx))
else:
idx = self._hemilabel_index(label)

if len(idx) == 0:
raise ValueError('No vertices match the label in the stc file')

return idx



def find_time_point(times, time):
"""
Returns (index, time) for the closest point to ``time`` in ``times``

times : array, 1d
Monotonically increasing time values.
time : scalar
Time point for which to find a match.

"""
if time in times:
i = np.where(times == time)[0][0]
else:
gr = (times > time)
if np.all(gr):
if times[1] - times[0] > times[0] - time:
return 0, times[0]
else:
name = repr(times.name) if hasattr(times, 'name') else ''
raise ValueError("time=%s lies outside array %r" % (time, name))
elif np.any(gr):
i_next = np.where(gr)[0][0]
elif times[-1] - times[-2] > time - times[-1]:
return len(times) - 1, times[-1]
else:
name = repr(times.name) if hasattr(times, 'name') else ''
raise ValueError("time=%s lies outside array %r" % (time, name))
t_next = times[i_next]

sm = times < time
i_prev = np.where(sm)[0][-1]
t_prev = times[i_prev]

if (t_next - time) < (time - t_prev):
i = i_next
time = t_next
else:
i = i_prev
time = t_prev
return i, time
Loading