In [2]:
import os
import shutil
import importlib
import copy
import glob
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time
import torch
import torchaudio

import phaselocknet_model
import util

# Re-import the project modules so that edits made on disk are picked up
# without restarting the kernel.
for _module in (phaselocknet_model, util):
    importlib.reload(_module)

# Use the first CUDA device when one is available; otherwise fall back to CPU.
# The trailing bare expression displays the selected device as the cell output.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device


device(type='cuda', index=0)

In [3]:
"""
Copy tensorflow model directories to new torch model directories for evaluation routine.
"""
# Glob pattern selecting the tensorflow model directories to convert.
# Uncomment one of the alternatives below to convert a different set of models.
# regex_dir_model_src = "../phaselocknet/models/sound_localization/simplified_IHC3000_delayed_integration/arch*"
# regex_dir_model_src = "../phaselocknet/models/sound_localization/simplified_IHC3000/arch*"
regex_dir_model_src = "../phaselocknet/models/spkr_word_recognition/simplified_IHC3000/arch*"
args_replace = ("../phaselocknet/", "../phaselocknet_torch/") # String replacement to map src to dst directory

# `glob.glob` returns matches in filesystem-dependent order; sort them so the
# conversion order (and the printed log) is reproducible across runs.
list_dir_model_src = sorted(glob.glob(regex_dir_model_src))
if not list_dir_model_src:
    # An unmatched pattern would otherwise make this cell a silent no-op.
    print(f"[WARNING] no model directories matched {regex_dir_model_src=}")
for dir_model_src in list_dir_model_src:
    # Check the src -> dst path mapping first so a bad path fails before the
    # (comparatively expensive) model build and checkpoint load.
    assert args_replace[0] in dir_model_src
    dir_model_dst = dir_model_src.replace(*args_replace)
    # Build torch model object
    model, _ = phaselocknet_model.get_model(dir_model_src)
    # Load model weights from tensorflow checkpoint
    util.load_tf_model_checkpoint(
        model=model.perceptual_model,
        filename=os.path.join(dir_model_src, "ckpt_BEST"),
    )
    # Prepare destination directory for torch model (no error if it already exists)
    os.makedirs(dir_model_dst, exist_ok=True)
    # Save weights and configuration to torch model directory
    util.save_model_checkpoint(
        model=model.perceptual_model,
        dir_model=dir_model_dst,
        step=None,
        fn_ckpt="ckpt_BEST.pt",
    )
    # Copy `config.json` and `arch.json` to destination directory
    for basename in ["config.json", "arch.json"]:
        shutil.copyfile(
            os.path.join(dir_model_src, basename),
            os.path.join(dir_model_dst, basename),
        )
    print(f"[COMPLETE] {dir_model_dst=}\n")


[get_model] dir_model='../phaselocknet/models/spkr_word_recognition/simplified_IHC3000/arch0_0000'
[get_model] |__ input_shape=[2, 40000]
[get_model] |__ config_random_slice={'size': [50, 20000], 'buffer': [0, 0]}


2024-11-20 03:54:52.850820: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


[load_tf_model_checkpoint] missing_keys (../phaselocknet/models/spkr_word_recognition/simplified_IHC3000/arch0_0000/ckpt_BEST)
|__ body.block0_pool.weight
|__ body.block1_pool.weight
|__ body.block2_pool.weight
|__ body.block3_pool.weight
|__ body.block4_pool.weight
|__ body.block5_pool.weight
|__ body.block6_pool.weight
[load_tf_model_checkpoint] ../phaselocknet/models/spkr_word_recognition/simplified_IHC3000/arch0_0000/ckpt_BEST
[save_model_checkpoint] ../phaselocknet_torch/models/spkr_word_recognition/simplified_IHC3000/arch0_0000/ckpt_BEST.pt
[COMPLETE] dir_model_dst='../phaselocknet_torch/models/spkr_word_recognition/simplified_IHC3000/arch0_0000'

[get_model] dir_model='../phaselocknet/models/spkr_word_recognition/simplified_IHC3000/arch0_0001'
[get_model] |__ input_shape=[2, 40000]
[get_model] |__ config_random_slice={'size': [50, 20000], 'buffer': [0, 0]}
[load_tf_model_checkpoint] missing_keys (../phaselocknet/models/spkr_word_recognition/simplified_IHC3000/arch0_0001/ckpt_BES

In [4]:
"""
Minimal example of how to load a `phaselocknet` model directory into a torch model.
Please be aware torch and tensorflow model outputs will not exactly match due to
the stochastic spike sampling and small differences in numerical precision.
"""

# Choose the model directory to load: exactly one `dir_model` assignment below
# should be left uncommented.

# # Sound localization network with simplified auditory nerve model (operates on audio)
# dir_model = "../phaselocknet/models/sound_localization/simplified_IHC3000_delayed_integration/arch01"

# # Sound localization network with detailed auditory nerve model (operates on pre-computed auditory nerve representations)
# dir_model = "../phaselocknet/models/sound_localization/IHC3000_delayed_integration/arch01"

# Word + voice recognition network with simplified auditory nerve model (operates on audio)
dir_model = "../phaselocknet/models/spkr_word_recognition/simplified_IHC3000/arch0_0000"

# # Word + voice recognition network with detailed auditory nerve model (operates on pre-computed auditory nerve representations)
# dir_model = "../phaselocknet/models/spkr_word_recognition/IHC3000/arch0_0000"

# Build the torch model for the chosen directory (second return value is unused here)
model, _ = phaselocknet_model.get_model(dir_model)

# Fill the perceptual sub-model with the weights stored in the tensorflow checkpoint
filename_ckpt = os.path.join(dir_model, "ckpt_BEST")
util.load_tf_model_checkpoint(model=model.perceptual_model, filename=filename_ckpt)

# Move the model to the selected device and put it in evaluation mode
model.to(device)
model.eval()
assert not model.training


[get_model] dir_model='../phaselocknet/models/spkr_word_recognition/simplified_IHC3000/arch0_0000'
[get_model] |__ input_shape=[2, 40000]
[get_model] |__ config_random_slice={'size': [50, 20000], 'buffer': [0, 0]}
[load_tf_model_checkpoint] missing_keys (../phaselocknet/models/spkr_word_recognition/simplified_IHC3000/arch0_0000/ckpt_BEST)
|__ body.block0_pool.weight
|__ body.block1_pool.weight
|__ body.block2_pool.weight
|__ body.block3_pool.weight
|__ body.block4_pool.weight
|__ body.block5_pool.weight
|__ body.block6_pool.weight
[load_tf_model_checkpoint] ../phaselocknet/models/spkr_word_recognition/simplified_IHC3000/arch0_0000/ckpt_BEST


In [5]:
# Glob pattern for the HDF5 stimulus files to evaluate on
regex_filenames = "../phaselocknet/stimuli/spkr_word_recognition/evaluation/pitch_altered_v00/stim*hdf5"

# Target sampling rate (Hz): 50 kHz for sound localization stimuli, 20 kHz otherwise
if "sound_localization" in regex_filenames:
    sr = 50000
else:
    sr = 20000
num_steps_per_display = 10
dataset = util.HDF5Dataset(regex_filenames)


In [8]:
# Run one stimulus from `dataset` through the model and print the index of the
# largest value in each output head.
example = dataset[10]

print("Example structure:")
for k, v in example.items():
    # Scalars (ndim == 0) are printed by value; arrays are summarized by shape
    print("|__", k, v.shape if v.ndim > 0 else v, v.dtype)

# Add a leading batch axis: [time] -> [1, time] (or [time, channel] -> [1, time, channel])
x = torch.tensor(example["signal"])[None, ...]
# Convert the stored sampling rate to a plain Python int before handing it to torchaudio
sr_src = int(example["sr"])
resampler = torchaudio.transforms.Resample(
    orig_freq=sr_src,
    new_freq=sr,
)
print(f"[evaluate] resampling audio from {sr_src} to {sr} Hz")
if x.ndim > 2:
    # The resampler operates along the last axis, so each channel (stored on the
    # trailing axis) is resampled separately and the results are re-stacked there.
    x = torch.stack(
        [resampler(x[..., channel]) for channel in range(x.shape[-1])],
        dim=-1,
    )
else:
    x = resampler(x)

x = util.pad_or_trim_to_len(x, n=model.input_shape[1], dim=1)
# Inference only: disable autograd so no computation graph is built or retained
with torch.no_grad():
    out = model(x.to(device))

print("Model output:")
for k, v in out.items():
    print("|__", k, v.shape if v.ndim > 0 else v, v.dtype, np.argmax(v[0].detach().cpu().numpy()))


Example structure:
|__ f0_shift_in_semitones nan float64
|__ foreground_condition 0 int64
|__ foreground_dbspl 60.0 float64
|__ foreground_index 10 int64
|__ index 10 int64
|__ inharmonic 0 int64
|__ label_speaker_int 277 int64
|__ label_word_int 19 int64
|__ signal (40000,) float32
|__ snr inf float64
|__ sr 20000 int64
|__ whispered 0 int64
[evaluate] resampling audio from 20000 to 20000 Hz
Model output:
|__ label_speaker_int torch.Size([1, 433]) torch.float32 277
|__ label_word_int torch.Size([1, 794]) torch.float32 19
