diff --git a/nitime/analysis/spectral.py b/nitime/analysis/spectral.py index 058be7fd..36b4bd4d 100644 --- a/nitime/analysis/spectral.py +++ b/nitime/analysis/spectral.py @@ -94,23 +94,26 @@ def psd(self): else: psd_len = NFFT / 2.0 + 1 dt = float - psd = np.empty((self.input.shape[0], - psd_len), dtype=dt) - + #If multi-channel data: if len(self.input.data.shape) > 1: - for i in xrange(self.input.data.shape[0]): + psd_shape = (self.input.shape[:-1] + (psd_len,)) + flat_data = np.reshape(self.input.data, (-1, + self.input.data.shape[-1])) + flat_psd = np.empty((flat_data.shape[0], psd_len), dtype=dt) + for i in xrange(flat_data.shape[0]): #'f' are the center frequencies of the frequency bands #represented in the psd. These are identical in each iteration #of the loop, so they get reassigned into the same variable in #each iteration: - temp, f = tsa.mlab.psd(self.input.data[i], + temp, f = tsa.mlab.psd(flat_data[i], NFFT=NFFT, Fs=Fs, detrend=detrend, window=window, noverlap=n_overlap) - psd[i] = temp.squeeze() + flat_psd[i] = temp.squeeze() + psd = np.reshape(flat_psd, psd_shape).squeeze() else: psd, f = tsa.mlab.psd(self.input.data, @@ -187,9 +190,16 @@ def spectrum_multi_taper(self): :func:`multi_taper_csd' """ + if np.iscomplexobj(self.input.data): + psd_len = self.input.shape[-1] + dt = complex + else: + psd_len = self.input.shape[-1] / 2 + 1 + dt = float + #Initialize the output - spectrum_multi_taper = np.empty((self.input.shape[0], - self.input.shape[-1] / 2 + 1)) + spectrum_multi_taper = np.empty((self.input.shape[:-1] + (psd_len,)), + dtype=dt) #If multi-channel data: if len(self.input.data.shape) > 1: @@ -304,7 +314,7 @@ def filtfilt(self, b, a, in_ts=None): #filtfilt only operates channel-by-channel, so we need to loop over the #channels, if the data is multi-channel data: if len(data.shape) > 1: - out_data = np.empty(data.shape) + out_data = np.empty(data.shape, dtype=data.dtype) for i in xrange(data.shape[0]): out_data[i] = signal.filtfilt(b, a, data[i]) #Make sure to preserve the DC: