Skip to content

Commit

Permalink
Add a wrapper around tfrview and tfrqview
Browse files Browse the repository at this point in the history
  • Loading branch information
jaidevd committed Feb 25, 2016
1 parent 3ddc560 commit a7c07bd
Showing 1 changed file with 181 additions and 2 deletions.
183 changes: 181 additions & 2 deletions tftb/processing/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,17 @@

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d

TYPE1 = ["pseudo margenau-hill", "spectrogram", "pseudo page", "margenau-hill",
"reassigned morlet scalogram", "gabor", "morlet scalogram",
"reassigned pseudo margenau-hill", "reassigned spectrogram",
"reassinged pseudo page", "reassigned gabor", "page", "rihaczek"]
TYPE2 = ["wigner-ville", "smoothed pseudo wigner-ville",
"reassigned smoothed psedo wigner-ville", "d-flandrin", "bertrand",
"unterberger", "pseudo wigner-ville", "reassigned psuedo wigner-ville",
"scalogram"]
AFFINE = ["d-flandrin", "unterberger", "bertrand", "scalogram"]


class BaseTFRepresentation(object):
Expand Down Expand Up @@ -51,6 +62,14 @@ def __init__(self, signal, **kwargs):
self.freqs = freqs.astype(float) / self.n_fbins
self.tfr = np.zeros((self.n_fbins, self.ts.shape[0]), dtype=complex)

@property
def has_negative_frequencies(self):
return self.name.lower() in TYPE1

@property
def _isaffine(self):
return self.name in AFFINE

def _get_spectrum(self):
if not self.isaffine:
return np.fft.fftshift(np.abs(np.fft.fft(self.signal)) ** 2)
Expand Down Expand Up @@ -89,7 +108,7 @@ def _plot_tfr(self, ax, kind, extent, contour_x=None, contour_y=None,
contour_x = self.ts
if contour_y is None:
if show_tf:
if self.isaffine:
if self.isaffine or self.name == "scalogram":
contour_y = np.linspace(self.fmin, self.fmax, self.n_voices)
else:
contour_y = np.linspace(0, 0.5, self.signal.shape[0])
Expand Down Expand Up @@ -215,7 +234,7 @@ def plot(self, ax=None, kind='cmap', show=True, default_annotation=True,
if default_annotation:
ax.set_zlabel("Amplitude")
elif kind == "wireframe":
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d import Axes3D # noqa
ax = fig.gca(projection="3d")
x = np.arange(self.signal.shape[0])
y = np.linspace(0, 0.5, self.signal.shape[0])
Expand All @@ -233,3 +252,163 @@ def plot(self, ax=None, kind='cmap', show=True, default_annotation=True,
ax.set_title(self.name.upper())
if show:
plt.show()

def __tfrqview__(self, tfr=None, sig=None, t=None, method=None, **kwargs):
if tfr is None:
tfr = self.tfr
tfrrow, tfrcol = tfr.shape
if t is None:
t = self.ts
if method is None:
method = "type1"
if sig is None:
sig = self.signal
if self._isaffine:
try:
freq = kwargs['freq']
except KeyError:
raise KeyError("freq must be supplied to __tfrqview__ for affine reps.")
else:
freq = 0.5 * np.arange(nf2) / nf2 # noqa

# Test of analyticity
# FIXME: This test could be used for the unit testing too.
lt_fog = t.max() - t.min() + 1
nb_tranches_fog = np.floor(lt_fog / tfrrow)
spec = np.zeros((tfrrow,))
for i in range(nb_tranches_fog):
_add = np.abs(np.fft.fft(sig[t.min() + tfrrow * i + np.arange(tfrrow)]))
spec += _add ** 2
if lt_fog > (nb_tranches_fog * tfrrow):
sig_slice = sig[(t.min() + tfrrow * nb_tranches_fog):t.max()]
spectre_fog = np.fft.fft(sig_slice, tfrrow)
spec += np.abs(spectre_fog) ** 2
spec1 = np.sum(spec[:(tfrrow / 2)])
spec2 = np.sum(spec[(tfrrow / 2):])
if spec2 > spec1 / 10.0:
import warnings
warnings.warn("The signal is not analytic", UserWarning)

tfr = np.real(tfr)

def __tfrview__(self, tfr=None, sig=None, t=None, method=None, kind="cmap",
scale="linear", threshold=0.05, n_levels=64, nf2=None, fs=1.0,
fmin=0.0, fmax=None, show_tf=True, **kwargs):
if tfr is None:
tfr = self.tfr
tfrrow, tfrcol = tfr.shape
if t is None:
t = self.ts
if method is None:
method = "type1"
if sig is None:
sig = self.signal
maxi = np.amax(tfr)

# default params
if nf2 is None:
if self.has_negative_frequencies:
nf2 = tfrrow / 2
else:
nf2 = tfrrow
if fmax is None:
fmax = fs * fmin

# computation of isaffine and freq
if self._isaffine:
if "freq" not in kwargs:
raise ValueError("Freq required for affine methods.")
freq = kwargs.pop('freq')
else:
freq = 0.5 * np.arange(nf2) / nf2
freqr = freq * fs
ts = t / fs

# update mini, levels, linlogstr, etc
if scale == "linear":
if kind in ("surf", "mesh"):
mini = np.amin(tfr)
else:
mini = np.max([np.amin(tfr), maxi * threshold])
levels = np.linspace(mini, maxi, n_levels + 1)
elif scale == "log":
mini = np.max([np.amin(tfr), maxi * threshold])
levels = np.logspace(np.log10(mini), np.log10(maxi), n_levels + 1)

# test of analyticity and computation of spec
alpha = 2
lt = t.max() - t.min() + 1
if 2 * nf2 >= lt:
spec = np.abs(np.fft.fft(sig[t.min():t.max()], 2 * nf2)) ** 2
else:
nb_tranches_fog = np.floot(lt / (2 * nf2))
spec = np.zeros((2 * nf2,))
for i in range(nb_tranches_fog):
_spec = np.fft.fft(sig[t.min() + 2 * nf2 * i + np.arange(2 * nf2)])
spec += np.abs(_spec) ** 2
if lt > nb_tranches_fog * 2 * nf2:
start = t.min() + 2 * tfrrow * nb_tranches_fog
spectre_fog = np.fft.fft(sig[start:t.max()], 2 * nf2)
spec += np.abs(spectre_fog) ** 2
spec1 = np.sum(spec[:nf2])
spec2 = np.sum(spec[nf2:])
if spec2 > 0.1 * spec1:
if not np.isreal(sig):
alpha = 1

if show_tf:
if self._isaffine:
f1 = freqr[0]
f2 = freqr[nf2 - 1]
d = f2 - f1
nf4 = np.round((nf2 - 1) * fs / (2 * d)) + 1
spec[:alpha * nf4] = np.abs(np.fft.fft(sig[t.min():t.max()],
alpha * nf4)) ** 2
start = np.round(f1 * 2 * (nf4 - 1) / fs + 1)
stop = np.round(f1 * 2 * (nf4 - 1) / fs + nf2)
spec = spec[start:stop]
freqs = np.linspace(f1, f2, nf2)
else:
freqs = freqr
spec = spec[:nf2]
maxsp = np.amax(spec)

# the axis here is the spectrum axis
if scale == "linear":
plt.plot(freqs, spec)
plt.ylim(maxsp * threshold * 0.01, maxsp * 1.2)
elif scale == "log":
plt.plot(freqs, 10 * np.log10(spec / maxsp))
plt.ylim(10 * np.log10(threshold), 0)
plt.xlim(fmin, fmax)
if self._isaffine:
freqs = np.linspace(freqr[0], freqr[nf2 - 1], nf2)
spec = interp1d(0.5 * fs * np.arange(nf2) / nf2,
spec[:nf2])(freqs)
else:
freqs = freqr
spec = spec[:nf2]
maxsp = np.amax(spec)
if scale == "linear":
plt.ylim(maxsp * threshold, maxsp * 1.2)
elif scale == "log":
plt.ylim(10 * np.log10(threshold), 0)
plt.xlim(fmin * fs, fmax * fs)

# the axis here is the signal axis
plt.plot(np.arange(t.min(), t.max()) / fs,
np.real(sig[t.min():t.max()]))
plt.axis([t.min(), t.max(), np.amin(np.real(sig)),
np.amax(np.real(sig))])

# The axis here is the TFR axis
# for contour
plt.contour(ts, freqr, tfr, levels)
plt.ylim(fmin * fs, fmax * fs)

# for images
if scale == "linear":
plt.imshow(ts, freqr, tfr)
else:
plt.imshow(ts, freqr, np.log10(tfr))
plt.ylim(fmin * fs, fmax * fs)

0 comments on commit a7c07bd

Please sign in to comment.