In [1]:
import IPython.display as ipd
from IPython.display import display
import glob
from ipywidgets import widgets, interactive
import scipy.io.wavfile
import matplotlib.pyplot as plt
import numpy as np
import sys
import glob
import os
import pickle as pckl

from IPython.display import Image
from IPython.core.display import HTML 

import pickle


In [2]:
# Model analysis directories whose metamers can be browsed.
models = [
    '/om4/group/mcdermott/user/jfeather/projects/pitch_network_metamers/model_analysis_folders/arch_0302_PND_v08_TLAS_snr_neg10pos10_AN_RSB_noise0000_classification0',
]

# Indices of the example sounds that are available right now.
sounds_idx = range(0, 15)
# NOTE(review): this was annotated "10 different layers", but the range spans
# 24 indices and pitch_network_layers below names 26 layers -- confirm which
# count is intended.
layers_idx = range(0, 24)

# Sub-path inside each model directory that holds the metamers -- change it
# if the optimization was run with different settings.
metamer_path = 'metamers/pitchnet_inst_RS541_IA5000_IL0/'

# Glob pattern for the original (non-synthesized) audio file inside a
# metamer directory.
orig_name_glob = '*_audio_time_varying.wav'

# Layer names of the pitch network, listed from input to output.
pitch_network_layers = [
    # network input
    'visualization_input',
    # stage 0
    'conv_0',
    'conv_0_jittered_relu',
    'pool_0',
    'batch_norm_0',
    # stage 1
    'conv_1',
    'conv_1_jittered_relu',
    'pool_1',
    'batch_norm_1',
    # stage 2
    'conv_2',
    'conv_2_jittered_relu',
    'pool_2',
    'batch_norm_2',
    # stage 3
    'conv_3',
    'conv_3_jittered_relu',
    'pool_3',
    'batch_norm_3',
    # stage 4
    'conv_4',
    'conv_4_jittered_relu',
    'pool_4',
    'batch_norm_4',
    # fully connected head
    'flatten_end_conv',
    'fc_intermediate',
    'fc_intermediate_jittered_relu',
    'batch_norm_fc_intermediate',
    'fc_top',
]


In [3]:
def _first_glob_match(pattern):
    '''
    Returns a single path matching the glob `pattern`.

    Layer names can be prefixes of one another (e.g. 'conv_0' and
    'conv_0_jittered_relu'), so a pattern built from a layer name may match
    several files. The shortest match (ties broken alphabetically) is
    returned so the choice is deterministic.
    NOTE(review): this assumes file names for different layers differ only
    in the layer-name part -- confirm against the metamer directory naming.

    Raises IndexError (naming the pattern) if nothing matches.
    '''
    matches = glob.glob(pattern)
    if not matches:
        raise IndexError('No files found matching pattern: %s' % pattern)
    return min(matches, key=lambda path: (len(path), path))


def display_examples(model, 
                     sound_idx, 
                     layer_name):
    '''
    Plays the original and synthesized (metamer) audio for one sound and one
    layer, and plots the `orig_coch_time_varying` and `synth_coch` arrays
    stored in the matching metamer pickle.

    Args:
        model: path to the model analysis directory. It is joined onto
            'model_analysis_folders'; if it is an absolute path,
            os.path.join discards that leading component.
        sound_idx: integer index of the sound; selects the '<idx>_SOUND'
            folder below `metamer_path`.
        layer_name: layer whose metamer files ('*full_<layer>*synth.wav' and
            '*full_<layer>*.pckl') are displayed.

    Raises:
        AssertionError: if there is not exactly one matching sound folder.
        IndexError: if the original audio, the metamer wav, or the metamer
            pickle cannot be found.
    '''
    sound_folder = glob.glob(os.path.join('model_analysis_folders', model, metamer_path,
                             '%d_SOUND'%sound_idx))
    if len(sound_folder) != 1:
        raise AssertionError(
            'Expected exactly one folder for sound %d, found %d: %s'
            % (sound_idx, len(sound_folder), sound_folder))
    sound_folder = sound_folder[0]
    
    orig_path = _first_glob_match(os.path.join(sound_folder, orig_name_glob))
    
    metamer_wav_path = _first_glob_match(
            os.path.join(sound_folder, '*full_%s*synth.wav'%layer_name))
    metamer_pckl_path = _first_glob_match(
            os.path.join(sound_folder, '*full_%s*.pckl'%layer_name))
    
    # NOTE: pickle.load can execute arbitrary code; only point this at
    # metamer files from a trusted source.
    # The context manager closes the file handle once loading is done.
    with open(metamer_pckl_path, 'rb') as pckl_file:
        metamer_pckl = pickle.load(pckl_file)
    
    print('Orig')
    ipd.display(ipd.Audio(orig_path))
    print('Synth')
    ipd.display(ipd.Audio(metamer_wav_path))
     
    # Original on top, synthesized below, so the two can be compared by eye.
    plt.figure(figsize=(10,4))
    plt.subplot(2,1,1)
    plt.imshow(np.squeeze(metamer_pckl['orig_coch_time_varying']), origin='lower', interpolation='none')
    plt.title('Subbands Orig')
    plt.subplot(2,1,2)
    plt.imshow(np.squeeze(metamer_pckl['synth_coch']), origin='lower', interpolation='none')
    plt.title('Subbands Synthetic')
    plt.show()
    
    # Disabled: scatter plots of synthesized vs. original activations.
#     plt.figure(figsize=(6,3))
#     plt.subplot(1,2,1)
#     plt.scatter(metamer_pckl['features_out_time_varying_synth'][layer_name], 
#                 metamer_pckl['features_out_time_varying_orig'][layer_name])
#     plt.xlabel('Synth Activations')
#     plt.ylabel('Orig Activations')
#     plt.title('Activations at \n Layer %s'%layer_name)
#     plt.subplot(1,2,2)
#     plt.scatter(metamer_pckl['features_out_time_varying_synth']['fc_top'], 
#                 metamer_pckl['features_out_time_varying_orig']['fc_top'])
#     plt.xlabel('Synth Activations')
#     plt.ylabel('Orig Activations')
#     plt.title('Activations at \n Logits, %s'%'fc_top')
#     plt.tight_layout()
#     plt.show()
    
    
    

In [4]:
# Build interactive controls around display_examples: each keyword argument
# below supplies the set of choices for the parameter of the same name
# (the model paths, the sound indices, and the pitch network layer names).
# Presumably every change of a selection re-runs display_examples with the
# current values -- see the ipywidgets `interact` documentation.
widgets.interact(display_examples, 
                 model=models,
                 sound_idx=sounds_idx, 
                 layer_name=pitch_network_layers)


interactive(children=(Dropdown(description='model', options=('/om4/group/mcdermott/user/jfeather/projects/pitc…

<function __main__.display_examples(model, sound_idx, layer_name)>

In [34]:
import h5py
import resampy
import numpy as np


def rms(x):
    '''
    Computes the root-mean-square amplitude of x: the square root of the
    mean of the squared values, taken over all elements.

    Raises ValueError when that amplitude comes out as NaN.
    '''
    mean_power = np.mean(np.square(x))
    amplitude = np.sqrt(mean_power)
    if not np.isnan(amplitude):
        return amplitude
    raise ValueError('rms calculation resulted in NaN')


def set_dBSPL(x, dBSPL, mean_subtract=True):
    '''
    Rescales x so that its rms amplitude corresponds to the sound pressure
    level `dBSPL`, expressed in dB re 20e-6 Pa.

    When `mean_subtract` is true the mean of x is removed before scaling.
    The input is not modified; a new array is returned.
    '''
    signal = x - np.mean(x) if mean_subtract else x
    # Target rms amplitude for the requested level.
    target_rms = 20e-6 * np.power(10, dBSPL/20)
    return target_rms * signal / rms(signal)


def get_pure_tone_dataset(DUR=0.05,
                          SR=32000,
                          N=15,
                          f0_min=80,
                          f0_max=1000,
                          dBSPL=60.0,
                          random_seed=858):
    """
    Makes a small dataset dictionary containing `N` pure tones with frequencies
    spaced log-uniformly between `f0_min` and `f0_max`. Pure tones have random
    start phases and are scaled to specified `dBSPL`.

    Args:
        DUR: tone duration, in the time unit implied by `SR` (seconds if
            `SR` is in Hz).
        SR: sampling rate used to build the time axis.
        N: number of tones.
        f0_min, f0_max: lowest and highest tone frequency.
        dBSPL: level every tone is scaled to (via `set_dBSPL`).
        random_seed: seed for the start phases. The same seed always gives
            the same dataset; pass None for a different draw on every call.

    Returns:
        dict with keys 'sr' (1-element array holding `SR`), 'f0' (array of
        the `N` frequencies) and 'y_preprocessed' (float32 array of shape
        [N, number of samples]).
    """
    # A dedicated generator makes the phases reproducible for a given seed
    # and leaves numpy's global random state untouched.
    rng = np.random.RandomState(random_seed)
    t = np.arange(0, DUR, 1/SR)
    list_f0 = np.exp(np.linspace(np.log(f0_min), np.log(f0_max), N))
    y_preprocessed = np.zeros([N, len(t)], dtype=np.float32)
    for itr_f0, f0 in enumerate(list_f0):
        start_phase = 2*np.pi*rng.rand()
        y_preprocessed[itr_f0] = np.sin(2*np.pi*f0*t + start_phase)
        # Scaling is applied to the float32 row that was just stored.
        y_preprocessed[itr_f0] = set_dBSPL(y_preprocessed[itr_f0], dBSPL)
    dataset = {
        'sr': np.array([SR]),
        'f0': list_f0,
        'y_preprocessed': y_preprocessed,
    }
    return dataset


def _read_pitchnet_example(dataset, WAV_IDX):
    """
    Pulls clip `WAV_IDX` out of a pitchnet dataset (an open hdf5 file or a
    plain dict with the same keys).

    Returns the tuple (waveform, sampling rate, f0).
    """
    return (dataset['y_preprocessed'][WAV_IDX],
            dataset['sr'][0],
            dataset['f0'][WAV_IDX])


def pitchnet_datasets(WAV_IDX, data_type='inst', preproc_scaled=None, rms_normalize=None, SR=32000):
    """
    Loads an example from the pitch network datasets, either speech or instrument. Each hdf5 contains 15 clips.
    `data_type='pure'` instead draws the example from the generated
    pure-tone dataset (`get_pure_tone_dataset`).

    Args:
        WAV_IDX: index of the clip within the dataset.
        data_type: 'inst', 'spch' or 'pure'.
        preproc_scaled: must be None or (close to) 1; any other value is
            rejected because the clips are already scaled.
        rms_normalize: must be None; the clips are already rms normalized.
        SR: sampling rate of the returned waveform. The clip is resampled
            if the dataset was stored at a different rate.

    Returns:
        dict with keys 'wav', 'SR', 'filename', 'filename_short' and
        'correct_response' (the f0 of the clip).

    Raises:
        ValueError: if `rms_normalize` or `preproc_scaled` ask for rescaling.
        KeyError: if `data_type` is not one of the known dataset names.
    """
    # Validate the options first, so a bad call fails before any file is
    # opened or any audio is generated.
    if rms_normalize is not None:
        raise ValueError('pitch net datasets are already rms preprocessed for the networks, set to 60dB SPL. do not change rms_normalize')
    if (preproc_scaled is not None) and (not np.isclose(preproc_scaled, 1)):
        raise ValueError('pitch net datasets are already scaled for the networks, set to 60dB SPL. Cannot use option preproc_scaled=%f.'%preproc_scaled)

    hdf5_locations = {
        'inst': '/om/user/msaddler/data_pitchnet/PND_v08inst_examples_for_metamers.hdf5',
        'spch': '/om/user/msaddler/data_pitchnet/PND_v08spch_examples_for_metamers.hdf5',
    }
    if data_type in hdf5_locations:
        filename = hdf5_locations[data_type]
        # Read everything that is needed while the file is open; the
        # context manager then closes it.
        with h5py.File(filename, "r") as dataset:
            wav_f, dataset_sr, f0 = _read_pitchnet_example(dataset, WAV_IDX)
    elif data_type == 'pure':
        # Only generate the pure tones when they are actually requested.
        filename = data_type # <-- this probably isnt a great filename
        wav_f, dataset_sr, f0 = _read_pitchnet_example(get_pure_tone_dataset(), WAV_IDX)
    else:
        raise KeyError("Unknown data_type %r; expected one of 'inst', 'spch', 'pure'" % (data_type,))

    # Convert the SR if necessary
    if dataset_sr != SR:
        wav_f = resampy.resample(wav_f, dataset_sr, SR, axis=0)

    audio_dict={}
    audio_dict['wav'] = wav_f
    audio_dict['SR'] = SR
    audio_dict['filename'] = filename
    audio_dict['filename_short'] = '%s_%d'%(data_type, WAV_IDX)
    audio_dict['correct_response'] = f0
    return audio_dict


import IPython.display as ipd

# Fetch pure-tone example 2 at 32 kHz, with no extra scaling or
# normalization requested, then render an audio player for it.
audio_dict = pitchnet_datasets(
    2,
    data_type='pure',
    preproc_scaled=None,
    rms_normalize=None,
    SR=32000,
)
ipd.display(
    ipd.Audio(audio_dict['wav'], rate=audio_dict['SR'])
)


In [35]:
# Show the contents of audio_dict as this cell's output.
audio_dict

{'SR': 32000,
 'correct_response': 114.76023196680742,
 'filename': 'pure',
 'filename_short': 'pure_2',
 'wav': array([-0.01475273, -0.0152797 , -0.01579877, ...,  0.0257608 ,
         0.02549427,  0.02521495], dtype=float32)}