Skip to content

Commit

Permalink
Updated wav_to_pulses_bands
Browse files Browse the repository at this point in the history
This version of code uses random frequency changes and was used for the pilot data in the methods section of the paper.
  • Loading branch information
mpolonenko committed Jan 11, 2021
1 parent 8b27e02 commit ca20e72
Showing 1 changed file with 49 additions and 38 deletions.
87 changes: 49 additions & 38 deletions code_stimulus/wav_to_pulses_bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,15 @@
story = 'alchemyst' # alchemyst (male narrator) or wrinkle (female narrator)
main_path = '/mnt/data/abr_peakyspeech/'
wav_in_path = main_path + 'audio_books/{}/'.format(story)
out_path = '/mnt/data/abr_peakyspeech/{}/'.format(story)
out_path = '/mnt/data/abr_peakyspeech/stimuli/{}/'.format(story)

start_file = 0
n_trials = 120

# %% Set up the bands and f_shifts

fs_filt = 44100
# fs_filt = 48000 # used a different fs for the pilot study
n_filt = int(fs_filt * 5e-3)
n_filt += ((n_filt + 1) % 2) # must be odd
freq = np.arange(n_filt) / float(n_filt) * fs_filt
Expand All @@ -68,27 +69,6 @@
n_f0 = n_ears * n_band # would just be n_ears for only broadband
n_f0_fake = 6 # number of fake pulse trains to calculate common component

# create a list of prime numbers
rmin = 2
rmax = 200
check = range(rmin, rmax)
primes = []
for i in check:
for j in range(rmin, i):
if not i % j:
break
else:
primes.append(i)
primes = primes[1:] # didn't originally use 2
del primes[8] # note: when making original files, 29 wasn't included
primes = np.array(primes)

f_shifts = 0.5 * (primes ** 0.5 - 1) # in Hz
f_shift_band = np.append(f_shifts[:n_f0 - 1 + int(n_f0_fake / 2)],
-f_shifts[:int(n_f0_fake / 2)])

assert(len(f_shift_band) == n_f0 - 1 + n_f0_fake) # left the first unshifted

# %% Make band filters


Expand Down Expand Up @@ -284,30 +264,52 @@ def amp_fun(f, top_width, trans_width):
xsg, ysg = np.meshgrid(t_sg, f_sg)
interp = RegularGridInterpolator([f_sg, t_sg], sg, method='linear',
bounds_error=False, fill_value=0)
phase_orig = np.copy(phase[0])

# %% make the new band phases and pulse_inds
print('make new band phases and pulse_inds')
for f0_ind in np.arange(1, n_f0 + n_f0_fake):
phase[f0_ind] = phase[0] + f_shift_band[f0_ind - 1] * (
2 * np.pi * np.arange(phase.shape[-1]) / float(fs))
# + np.random.rand() * 2 * np.pi) # random start phase for future
f0 = np.diff(phase, axis=-1) * fs / 2 / np.pi
f_shift_max = 1
f_shift_f_min = 0.0 # if 0 will use next highest frequency bin
f_shift_f_max = 0.05
f_shift_ind_min = int(np.maximum(1, np.round(f_shift_f_min * len(x) / fs)))
f_shift_ind_max = int(np.round(f_shift_f_max * len(x) / fs))
n_comp = f_shift_ind_max - f_shift_ind_min

f_shift_fft = np.zeros((n_f0 + n_f0_fake, len(x)), dtype=np.complex)
f_shift_fft[:, f_shift_ind_min:f_shift_ind_max] = np.exp(
1j * 2 * np.pi * np.random.rand(n_f0 + n_f0_fake, n_comp))
f_shift_fft[:, :-f_shift_ind_max:-1] = f_shift_fft[
:, 1:f_shift_ind_max].conj()
f_shift = ifft(f_shift_fft).real
f_shift *= f_shift_max / np.abs(f_shift).max(axis=-1, keepdims=True)

for f0_ind in np.arange(n_f0 + n_f0_fake):
phase[f0_ind] = phase_orig + (
2 * np.pi * np.cumsum(f_shift[f0_ind]) / float(fs)
+ np.random.rand() * 2 * np.pi)
f0 = np.diff(phase_orig, axis=-1) * fs / 2 / np.pi
t0 = np.arange(phase.shape[-1] - 1) / float(fs)

# split into true (for stimuli) & fake f0 (used to derive common component)
phase_fake_pulses = np.copy(phase)[n_f0:]
phase = phase[:n_f0]

print('Computing harmonic amplitude envelopes')
f_harm_max = np.minimum(fc_band[-1] * 2, fs / 2. - 1)
n_harm = int(np.floor(f_harm_max / f0_min))
amp0 = np.zeros((n_harm, len(x)))

for hi in range(n_harm):
# rather than interp at every point, take the pulse points and then
# cubic spline interpolate each of those
amp0[hi, :-1] = interp(np.transpose([f0[0] * (hi + 1), t0])) ** 0.5
# cubic spline interpolate each of those. would be faster.
f = f0 * (hi + 1)
fi = np.where(f < fs)[0]
amp0[hi, fi] = interp(np.transpose([f[fi], t0[fi]])) ** 0.5
amp0[:, np.isnan(amp0[0])] = 0
bplp, aplp = sig.butter(1, 70 / (fs / 2.))
amp0 = np.array([sig.filtfilt(bplp, aplp, amp) for amp in amp0])

print('Generating the harmonics')
x_harm = np.zeros((n_f0, x.shape[-1])) # only for real pulse trains
x_harm = np.zeros(phase.shape)
for f0_ind in range(n_f0):
for hi in range(n_harm):
x_harm[f0_ind] += np.nan_to_num(np.cos(phase[f0_ind] * (hi + 1))
Expand All @@ -330,9 +332,8 @@ def amp_fun(f, top_width, trans_width):
f0_ind = 0
for band_ind in range(n_band):
for ear_ind in range(n_ears):
x_harm_band[band_ind, ear_ind] = sig.fftconvolve(x_harm[f0_ind],
h_band[band_ind],
'same')
x_harm_band[band_ind, ear_ind] = sig.fftconvolve(
x_harm[f0_ind], h_band[band_ind], 'same')
f0_ind += 1
if n_ears == 1:
x_harm_band = x_harm_band.mean(1)
Expand All @@ -353,7 +354,9 @@ def amp_fun(f, top_width, trans_width):
sig.fftconvolve(x, h_single[1], 'same'))

pulse_inds = [np.where(np.diff(np.mod(phase[bi], 2 * np.pi)) < 0)[0]
for bi in range(n_f0 + n_f0_fake)]
for bi in range(n_f0)]
fake_pulse_inds = [np.where(np.diff(np.mod(
phase_fake_pulses[bi], 2 * np.pi)) < 0)[0] for bi in range(n_f0_fake)]

x_play_band *= flip_sign
x_play_single *= flip_sign
Expand Down Expand Up @@ -390,6 +393,7 @@ def amp_fun(f, top_width, trans_width):
fn.split('/')[-1][:-4] + '_bands' + name_band + '.hdf5',
dict(
pulse_inds=pulse_inds,
fake_pulse_inds=fake_pulse_inds,
pulse_inds_unfix=pulse_inds_unfix,
mixer=mixer,
x_harm_band=x_harm_band,
Expand All @@ -409,18 +413,25 @@ def amp_fun(f, top_width, trans_width):
h_single=h_single,
f_harm_max=f_harm_max,
n_harm=n_harm,
f_shift_band=f_shift_band,
f_shift=f_shift,
f_shift_f_min=f_shift_f_min,
f_shift_f_max=f_shift_f_max,
f_shift_ind_min=f_shift_ind_min,
f_shift_ind_max=f_shift_ind_max,
phase_orig=phase_orig,
phase=phase,
phase_fake_pulses=phase_fake_pulses,
n_f0=n_f0,
n_f0_fake=n_f0_fake,
slop=slop,
top_width=top_width,
trans_width=trans_width,
n_filt=n_filt,
phase=phase), overwrite=overwrite_file)
n_filt=n_filt), overwrite=overwrite_file)
write_hdf5(out_path + '{}_hdf5/hdf5_reduced/'.format(story) +
fn.split('/')[-1][:-4] + '_bands' + name_band +
'_reduced.hdf5', dict(
pulse_inds=pulse_inds,
fake_pulse_inds=fake_pulse_inds,
mixer=mixer,
x_play_band=x_play_band,
x_play_single=x_play_single,
Expand Down

0 comments on commit ca20e72

Please sign in to comment.