In [1]:
import numpy as np
import tensorflow as tf

In [2]:
def snr(signal, recon):
    """Returns signal-noise ratio in dB."""
    ratio = np.var(signal)/np.var(signal-recon)
    return 10*np.log10(ratio)
    
# dynamic compressive gammachirp
def dcGC(t,f):
    """Dynamic compressive gammachirp filter as defined by Irino,
    with parameters from Park as used in Charles, Kressner, & Rozell.
    The log term is regularized to log(t + 0.00001).
    t : time in seconds, greater than 0
    f : characteristic frequency in Hz
    One but not both arguments may be numpy arrays.
    """
    ERB = 0.1039*f + 24.7
    return t**3 * np.exp(-2*np.pi*1.14*ERB*t) * np.cos(2*np.pi*f*t + 0.979*np.log(t+0.000001))

# adapted from scipy cookbook
lowcut = 100
highcut = 6000
def butter_bandpass(lowcut, highcut, fs, order=5):
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = scisig.butter(order, [low, high], btype='band')
    return b, a

def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    b, a = butter_bandpass(lowcut, highcut, fs, order=order)
    y = scisig.lfilter(b, a, data)
    return y
    
def plot_spikegram( spikes, sample_rate, markerSize = .0001 ):
    """adapted from https://github.com/craffel/spikegram-coding/blob/master/plotSpikeGram.py"""
    nkernels = spikes.shape[0]
    indices = np.transpose(np.nonzero(spikes))
    scalesKernelsAndOffsets = [(spikes[idx[0],idx[1]], idx[0], idx[1]) for idx in indices]
    
    for scale, kernel, offset in scalesKernelsAndOffsets:
        # Put a dot at each spike location.  Kernels on y axis.  Dot size corresponds to scale
        plt.plot( offset/sample_rate, nkernels-kernel, 'k.', 
                 markersize=markerSize*np.abs( scale ) )
    plt.title( "Spikegram" )
    plt.xlabel( "Time (s)" )
    plt.ylabel( "Kernel" )
    plt.axis( [0.0, spikes.shape[1]/sample_rate, 0.0, nkernels] )
    plt.show()

In [4]:
class SignalSet:
    
    def __init__(self, sample_rate = 16000, data = '../Data/TIMIT/'):
        self.sample_rate = sample_rate
        if isinstance(data, str):
            self.load_from_folder(data)
        else:
            self.data = data
            self.ndata = len(data)            
            
    def load_from_folder(self, folder = '../Data/TIMIT/'):
        min_length = 800 # TODO: should not be hard-coded
        files = os.listdir(folder)
        file = None
        self.data = []
        for ff in files:
            if ff.endswith('.wav'):
                file = os.path.join(folder,ff)
                rate, signal = wavfile.read(file)
                if rate != self.sample_rate:
                    raise NotImplementedError('The signal in ' + ff +
                    ' does not match the given sample rate.')
                if signal.shape[0] > min_length:
                    # bandpass
                    signal = signal/signal.std()
                    signal = butter_bandpass_filter(signal, lowcut, highcut,
                                                    self.sample_rate, order=5)
                    self.data.append(signal)
        self.ndata = len(self.data)
        print("Found ", self.ndata, " files")
        
    def rand_stim(self):
        """Get one random signal."""
        which = np.random.randint(low=0, high=self.ndata)
        signal = self.data[which]
        signal /= np.max(signal) # as in Smith & Lewicki
        return signal
        
    def write_sound(self, filename, signal):
        signal /= np.max(signal)
        wavfile.write(filename, self.sample_rate, signal)
        
    def tiled_plot(self, stims):
        """Tiled plots of the given signals. Zeroth index is which signal.
        Kind of slow, expect about 10s for 100 plots."""
        nstim = stims.shape[0]
        plotrows = int(np.sqrt(nstim))
        plotcols = int(np.ceil(nstim/plotrows))
        f, axes = plt.subplots(plotrows, plotcols, sharex=True, sharey=True)
        for ii in range(nstim):
            axes.flatten()[ii].plot(stims[ii])
        f.subplots_adjust(hspace=0, wspace=0)
        plt.setp([a.get_xticklabels() for a in f.axes[:-1]], visible=False)
        plt.setp([a.get_yticklabels() for a in f.axes[:-1]], visible=False)

In [13]:
class MatchingPursuer:
    
    def __init__(self,
                 data = '../Data/TIMIT/',
                 data_dim=1,
                 nunits = 32,
                 filter_time = 0.05,
                 learn_rate = 0.01,
                 thresh = 0.5,
                 normed_thresh = None,
                 max_iter = 100,
                 min_spike = 0.01,
                 mask_epsilon = None,
                 sample_rate = 16000,
                 paramfile= 'dummy'):    
        
        self.thresh = thresh
        self.min_spike = min_spike
        self.sample_rate = sample_rate
        self.nunits = nunits
        self.lfilter = int(filter_time * self.sample_rate)
        self.normed_thresh = normed_thresh or 2/np.sqrt(self.lfilter)
        self.mask_epsilon = mask_epsilon or 0.01*np.sqrt(1/self.lfilter)
        self.max_iter = max_iter
        self.data_dim = data_dim
        
    def initial_filters(self, gammachirp=False):
        """If 1D, Return either a set of gammachirp filters or random (normal) filters,
        not normalized. Otherwise return Gaussian noise."""
        if self.data_dim==1:
            if gammachirp:
                gammachirps = np.zeros([self.nunits, self.lfilter])
                freqs = np.logspace(np.log10(100), np.log10(6000), self.nunits)
                times = np.linspace(0,self.lfilter/self.sample_rate,self.lfilter)
                for ii in range(self.nunits):
                    gammachirps[ii] = dcGC(times, freqs[ii])
                filters= gammachirps        
            else:
                filters = tf.random_normal([self.nunits, self.lfilter])
            return tf.expand_dims(filters,2)
        elif self.data_dim>2:
            normal = tf.random_normal([self.nunits, self.lfilter, self.nfreqs])
            return normal

In [22]:
g = tf.Graph()

self=MatchingPursuer()

with g.as_default():
    
    x = tf.placeholder(tf.float32, shape = [1,None,self.data_dim,1])
    
    phi = tf.Variable(self.initial_filters())
    phi_for_conv = tf.transpose(phi, [1,2,0])
    phi_for_conv = tf.expand_dims(phi_for_conv,2)
    rev_phi = tf.reverse(phi, dims=[False, True, False])
    
    with tf.variable_scope('inference'):
        def while_body(kk, winning_val, coeffs, resid, error):
            convs = tf.nn.convolution(resid,
                                      phi_for_conv,
                                      padding="SAME", name='convolutions')
            winning_val = tf.reduce_max(convs)
            winner = tf.argmax(convs)
            #coeffs = tf.select(convs == winning_val,
             #                  convs,
            #                 coeffs)
            update = tf.scatter_nd([winner], [winning_val], tf.shape(coeffs))
            coeffs = coeffs + update
            xhat = tf.convolution(coeffs,
                                  rev_phi,
                                  padding="SAME", name='reconstruction')
            resid = x - xhat
            error = tf.mean(tf.square(resid))
            return tf.add(kk, 1), winning_val, resid, error
        def while_cond(kk, winning_val, coeffs, resid, error):
            maxitercheck = kk < self.max_iter
            spikecheck = winning_val > self.min_spike
            return tf.logical_and(maxitercheck, spikecheck)
        kk = tf.constant(0)
        onespike = tf.constant(0.0)
        error = tf.constant(1.0)
        coeffs = tf.zeros_like(x) # WRONG
        resid = tf.identity(x)
        inf_loop = tf.while_loop(while_cond, while_body, [kk, onespike, coeffs, resid, error], back_prop=False)

ValueError: None values not supported.