# Setup
This notebook is basically a mix of make_dataset.py from ../snn with encoding using delta modulation and adding waveform data 

In [20]:
import argparse
import speech2spikes
import torchaudio
import torch
import random
import numpy as np
import os
import pywt
import noisereduce

import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader

import numpy as np

import matplotlib.pyplot as plt

%matplotlib inline

In [185]:
# define variables for encoding
MODE = 'spec' # spec, dwt
DATASET_CAP = 40

SPEC_FREQ_BIN_COUNT = 25 # for spec mode

PATH_GUNSHOT_SOUNDS = '/home/joao/dev/MLAudio/shotspotter/data/gunshotsNew'
PATH_GUNSHOT_INDEX = '/home/joao/dev/MLAudio/shotspotter/data/gunshotsNewIndex.csv'
PATH_NOGUNSHOT_SOUNDS = '/home/joao/dev/MLAudio/shotspotter/data/genBackgrounds'

In [192]:
gunshot_file_paths = [PATH_GUNSHOT_SOUNDS+'/'+fn for fn in os.listdir(PATH_GUNSHOT_SOUNDS)][:DATASET_CAP//2]
print(f'We have {len(gunshot_file_paths)} gunshot audio files')
nogunshot_file_paths = [PATH_NOGUNSHOT_SOUNDS+'/'+fn for fn in os.listdir(PATH_NOGUNSHOT_SOUNDS)][:DATASET_CAP//2]
print(f'We have {len(nogunshot_file_paths)} background only audio files')

p1 = [(i, 1) for i in gunshot_file_paths]
p2 = [(i, 0) for i in nogunshot_file_paths]

pairs = p1+p2 # path to sound - label tuples
random.shuffle(pairs)

We have 20 gunshot audio files
We have 20 background only audio files


# Encoding

In [210]:
# input is shape time x batch x channels (from spectrogram or dwt spec)
def posneg_delta(raw_spec_data):
    delta = spikegen.delta(raw_spec_data, threshold=0.001, off_spike=True)
    
    new_data = torch.zeros(delta.shape[0], delta.shape[1], delta.shape[2]*2) # 2 channels for each freq channel
    for i in range(delta.shape[1]): # for each sample
        for timestep in range(delta.shape[0]):
            for channel in range(delta.shape[2]):
                if delta[timestep, i, channel] == 1:
                    new_data[timestep, i, channel] = 1
                elif delta[timestep, i, channel] == -1:
                    new_data[timestep, i, channel+delta.shape[2]] = 1;

    return new_data

def to_spikes(paths_list, labels):
    if MODE == 'spec':
        all_spikes = []
        targets = np.array(labels)

        for p in paths_list:
            samples, rate = torchaudio.load(p, normalize=True)
            
            if samples.shape[0] == 2: samples = samples[0, :]
            else: samples = samples[0]
            if(len(samples) < 24000):
                samples = torch.cat((samples, torch.tensor([0])))

            #plt.plot(np.linspace(0, len(samples), len(samples)), samples)
            
            samples = torch.tensor(noisereduce.reduce_noise(y=samples, sr=rate)) # testing this because I had it on in the ResNet version dataset

            #plt.plot(np.linspace(0, len(samples), len(samples)), samples)

            # freq bin count is nfft//2 + 1
            spec_transform = torchaudio.transforms.Spectrogram(n_fft=(2*SPEC_FREQ_BIN_COUNT-2))

            samples = samples.to(torch.float64)
            spec = spec_transform(samples)

            # convert waveform to spikes
            timesteps = spec.shape[1]
            waveform_timestep_len = (24000//timesteps) + 1

            ts_waveform = [] # timesteps but compressed to time resolution of spectrogram output
            current_t = waveform_timestep_len
            while current_t <= 24000:
                ts_waveform.append(samples[current_t-waveform_timestep_len: current_t].mean())
                current_t += waveform_timestep_len

            if len(ts_waveform) < timesteps: # pad so dimensions match
                ts_waveform.append(0) 

            # normalize because amplitude scales vary a lot (and we did this in resnet accidentally)
            ts_waveform = torch.tensor(ts_waveform)
            ts_waveform = (ts_waveform - ts_waveform.min()) / (ts_waveform.max() - ts_waveform.min())

            # debug
            #plt.plot(np.linspace(0, len(samples), len(samples)), samples)
            #plt.plot(np.linspace(0, len(ts_waveform), len(ts_waveform)), ts_waveform)

            # convert compressed waveform into spikes using bins
            waveform_spikes = []
            num_bins = 20
            bin_size = 1/num_bins
            for w in ts_waveform:
                waveform_spikes.append([0 for i in range(num_bins)])
                waveform_spikes[-1][int(w//bin_size)] = 1

            waveform_spikes = torch.tensor(waveform_spikes)
            splt.raster()

            ## CONITNUE HERE
            break
            all_spikes.append(spec)

        # same as dwt
        global_min = 0
        global_max = 0
        for s in all_spikes:
            if s.min() < global_min: global_min = s.min()
            if s.max() > global_max: global_max = s.max()

        # now normalize everything
        for i in range(len(all_spikes)):
            all_spikes[i] = (all_spikes[i]-global_min) / (global_max - global_min)

In [211]:
to_spikes([i[0] for i in pairs], [i[1] for i in pairs])

torch.Size([1001, 20])
