Skip to content

Commit

Permalink
Merge pull request #109 from arokem/master
Browse files Browse the repository at this point in the history
Spectra for multi-dimensional time-series
  • Loading branch information
arokem committed Jan 17, 2013
2 parents b0bd96a + fed644c commit ef67b88
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions nitime/analysis/spectral.py
Expand Up @@ -94,23 +94,26 @@ def psd(self):
else:
psd_len = NFFT / 2.0 + 1
dt = float
psd = np.empty((self.input.shape[0],
psd_len), dtype=dt)


#If multi-channel data:
if len(self.input.data.shape) > 1:
for i in xrange(self.input.data.shape[0]):
psd_shape = (self.input.shape[:-1] + (psd_len,))
flat_data = np.reshape(self.input.data, (-1,
self.input.data.shape[-1]))
flat_psd = np.empty((flat_data.shape[0], psd_len), dtype=dt)
for i in xrange(flat_data.shape[0]):
#'f' are the center frequencies of the frequency bands
#represented in the psd. These are identical in each iteration
#of the loop, so they get reassigned into the same variable in
#each iteration:
temp, f = tsa.mlab.psd(self.input.data[i],
temp, f = tsa.mlab.psd(flat_data[i],
NFFT=NFFT,
Fs=Fs,
detrend=detrend,
window=window,
noverlap=n_overlap)
psd[i] = temp.squeeze()
flat_psd[i] = temp.squeeze()
psd = np.reshape(flat_psd, psd_shape).squeeze()

else:
psd, f = tsa.mlab.psd(self.input.data,
Expand Down Expand Up @@ -187,9 +190,16 @@ def spectrum_multi_taper(self):
:func:`multi_taper_csd'
"""
if np.iscomplexobj(self.input.data):
psd_len = self.input.shape[-1]
dt = complex
else:
psd_len = self.input.shape[-1] / 2 + 1
dt = float

#Initialize the output
spectrum_multi_taper = np.empty((self.input.shape[0],
self.input.shape[-1] / 2 + 1))
spectrum_multi_taper = np.empty((self.input.shape[:-1] + (psd_len,)),
dtype=dt)

#If multi-channel data:
if len(self.input.data.shape) > 1:
Expand Down Expand Up @@ -304,7 +314,7 @@ def filtfilt(self, b, a, in_ts=None):
#filtfilt only operates channel-by-channel, so we need to loop over the
#channels, if the data is multi-channel data:
if len(data.shape) > 1:
out_data = np.empty(data.shape)
out_data = np.empty(data.shape, dtype=data.dtype)
for i in xrange(data.shape[0]):
out_data[i] = signal.filtfilt(b, a, data[i])
#Make sure to preserve the DC:
Expand Down

0 comments on commit ef67b88

Please sign in to comment.