In [16]:
import os
import json
import importlib
import numpy as np
import torch
import tensorflow as tf

import peripheral_model
import perceptual_model


In [59]:
# Re-execute the two local modules so that edits made to their source files are
# picked up without restarting the kernel. `Model` below looks up
# `peripheral_model.PeripheralModel` and `perceptual_model.PerceptualModel` at
# instantiation time, so models built after this point use the reloaded code.
importlib.reload(peripheral_model)
importlib.reload(perceptual_model)


class Model(torch.nn.Module):
    """
    Peripheral model followed by a perceptual model, wrapped as one module.

    The peripheral stage is a `peripheral_model.PeripheralModel` configured from
    `config_model["kwargs_cochlea"]`. The perceptual stage is a
    `perceptual_model.PerceptualModel` whose input shape is obtained by passing
    an all-zeros tensor of shape `input_shape` through the peripheral stage.
    """

    def __init__(
        self,
        config_model={},
        architecture=[],
        input_shape=[2, 65000, 2],
        config_random_slice={"size": [50, 10000], "buffer": [0, 1000]},
        device=None,
    ):
        """
        Build the peripheral and perceptual stages.

        Args
        ----
        config_model (dict): model configuration. Must contain the keys
            "kwargs_cochlea" (dict of peripheral-model settings) and
            "n_classes_dict" (passed to the perceptual model as `heads`).
            This dict is not modified.
        architecture (list): passed unchanged to `PerceptualModel`.
        input_shape (list): shape of the zero tensor used to probe the
            peripheral model's output shape; also stored as `self.input_shape`.
        config_random_slice (dict): passed unchanged to `PeripheralModel`.
        device: passed unchanged to `PerceptualModel`.

        Raises
        ------
        KeyError: if `config_model` lacks "kwargs_cochlea" or "n_classes_dict".
        AssertionError: if the lowpass-filter config sets "ihc_filter" to a
            falsy value.
        """
        super().__init__()
        self.input_shape = input_shape
        kwargs_cochlea = config_model["kwargs_cochlea"]
        # Work on a shallow copy: "ihc_filter" is removed before the config is
        # forwarded, and that removal must not alter the caller's `config_model`
        # (otherwise a second construction from the same config would silently
        # skip the check below).
        config_ihc_lowpass_filter = dict(
            kwargs_cochlea.get("kwargs_fir_lowpass_filter_output", {})
        )
        ihc_filter = config_ihc_lowpass_filter.pop("ihc_filter", True)
        if not ihc_filter:
            # Raised explicitly instead of via `assert` so that both the check
            # and the key removal still happen under `python -O`. The exception
            # type matches what the previous `assert` produced.
            raise AssertionError(
                'config_model["kwargs_cochlea"]["kwargs_fir_lowpass_filter_output"]'
                '["ihc_filter"] must be truthy (or omitted)'
            )
        kwargs_peripheral_model = {
            "sr_input": kwargs_cochlea.get("sr_input", None),
            "sr_output": kwargs_cochlea.get("sr_output", None),
            "config_cochlear_filterbank": kwargs_cochlea.get("config_filterbank", {}),
            "config_ihc_transduction": kwargs_cochlea.get(
                "config_subband_processing", {}
            ),
            "config_ihc_lowpass_filter": config_ihc_lowpass_filter,
            "config_anf_rate_level": kwargs_cochlea.get(
                "kwargs_sigmoid_rate_level_function", {}
            ),
            "config_anf_spike_generator": kwargs_cochlea.get(
                "kwargs_spike_generator_binomial", {}
            ),
            "config_random_slice": config_random_slice,
        }
        # print(self.input_shape, json.dumps(kwargs_peripheral_model, indent=4))
        self.peripheral_model = peripheral_model.PeripheralModel(
            **kwargs_peripheral_model,
        )
        # The perceptual model needs the shape of the peripheral output, which
        # is measured here with a dummy all-zeros input.
        self.perceptual_model = perceptual_model.PerceptualModel(
            architecture=architecture,
            input_shape=self.peripheral_model(torch.zeros(self.input_shape)).shape,
            heads=config_model["n_classes_dict"],
            device=device,
        )

    def forward(self, x):
        """
        Run `x` through the peripheral model, then the perceptual model.

        Returns whatever `self.perceptual_model` returns (in the run recorded
        in this notebook: a mapping from head name to a tensor).
        """
        return self.perceptual_model(self.peripheral_model(x))


# Checkpoint presets. Each entry pairs a model directory with the `input_shape`
# and `config_random_slice` used to build `Model` for it. Exactly one preset is
# selected via `selected_preset`, replacing the earlier pattern of assigning the
# same three variables several times and relying on the last assignment.
# Only "word_recognition_simplified_IHC3000" is exercised by the outputs
# recorded in this notebook.
model_presets = {
    "localization_simplified_IHC3000": {
        "dir_model": "../phaselocknet/models/sound_localization/simplified_IHC3000_delayed_integration/arch01",
        "input_shape": [2, 65000, 2],
        "config_random_slice": {"size": [50, 10000], "buffer": [0, 1000]},
    },
    "word_recognition_simplified_IHC3000": {
        "dir_model": "../phaselocknet/models/spkr_word_recognition/simplified_IHC3000/arch0_0000",
        "input_shape": [2, 40000],
        "config_random_slice": {"size": [50, 20000], "buffer": [0, 0]},
    },
    "word_recognition_IHC3000": {
        "dir_model": "../phaselocknet/models/spkr_word_recognition/IHC3000/arch0_0000",
        "input_shape": [2, 3, 50, 20000],
        "config_random_slice": {},
    },
    "localization_IHC3000": {
        "dir_model": "../phaselocknet/models/sound_localization/IHC3000_delayed_integration/arch01",
        "input_shape": [2, 3, 50, 13000, 2],
        "config_random_slice": {"size": [50, 10000], "buffer": [0, 1000]},
    },
}
selected_preset = "word_recognition_simplified_IHC3000"

dir_model = model_presets[selected_preset]["dir_model"]
input_shape = model_presets[selected_preset]["input_shape"]
config_random_slice = model_presets[selected_preset]["config_random_slice"]

# The model directory is expected to hold the model configuration and the
# architecture description as JSON files.
with open(os.path.join(dir_model, "config.json")) as f:
    config_model = json.load(f)
with open(os.path.join(dir_model, "arch.json")) as f:
    architecture = json.load(f)

model = Model(
    config_model=config_model,
    architecture=architecture,
    input_shape=input_shape,
    config_random_slice=config_random_slice,
)

# Smoke test: run an all-zeros input through the full model and report the
# shape and dtype of each output.
for k, v in model(torch.zeros(model.input_shape)).items():
    print(k, v.shape, v.dtype)


label_speaker_int torch.Size([2, 433]) torch.float32
label_word_int torch.Size([2, 794]) torch.float32


In [61]:
# Collect the perceptual model's layer names in parameter order, without
# duplicates. The position of a name in this list is what the conversion cell
# below uses to match a TensorFlow "layer_with_weights-<index>" variable to a
# torch layer, so the order must be preserved.
list_layer_name = []
for n, p in model.perceptual_model.named_parameters():
    print(n, np.array(p.shape))
    # Strip only a trailing ".weight" / ".bias"; a plain str.replace would also
    # remove those substrings from the middle of a layer name.
    name = n
    for suffix in (".weight", ".bias"):
        if name.endswith(suffix):
            name = name[: -len(suffix)]
            break
    if name not in list_layer_name:
        list_layer_name.append(name)


body.block0_conv.weight [32  3  2 42]
body.block0_conv.bias [32]
body.block0_norm.weight [32]
body.block0_norm.bias [32]
body.block1_conv.weight [64 32  2 18]
body.block1_conv.bias [64]
body.block1_norm.weight [64]
body.block1_norm.bias [64]
body.block2_conv.weight [128  64   6   6]
body.block2_conv.bias [128]
body.block2_norm.weight [128]
body.block2_norm.bias [128]
body.block3_conv.weight [256 128   6   6]
body.block3_conv.bias [256]
body.block3_norm.weight [256]
body.block3_norm.bias [256]
body.block4_conv.weight [512 256   8   8]
body.block4_conv.bias [512]
body.block4_norm.weight [512]
body.block4_norm.bias [512]
body.block5_conv.weight [512 512   6   6]
body.block5_conv.bias [512]
body.block5_norm.weight [512]
body.block5_norm.bias [512]
body.block6_conv.weight [512 512   8   8]
body.block6_conv.bias [512]
body.block6_norm.weight [512]
body.block6_norm.bias [512]
body.fc_intermediate_dense.weight [  512 35840]
body.fc_intermediate_dense.bias [512]
body.fc_intermediate_norm.weight

In [62]:
def load_tensorflow_checkpoint(filename):
    """
    Read the layer variables stored in a TensorFlow checkpoint.

    Only variables whose key contains "layer_with_weights" and does not contain
    "OPTIMIZER_SLOT" are kept. The "/.ATTRIBUTES/VARIABLE_VALUE" part is removed
    from each kept key.

    Args
    ----
    filename (str): checkpoint path passed to `tf.train.load_checkpoint`.

    Returns
    -------
    tf_state_dict (dict): maps each cleaned variable key to the value returned
        by the checkpoint reader's `get_tensor` for that variable.
    """
    reader = tf.train.load_checkpoint(filename)
    tf_state_dict = {}
    # The shape map is used only as the list of variable keys in the checkpoint.
    for k in reader.get_variable_to_shape_map():
        if ("layer_with_weights" in k) and ("OPTIMIZER_SLOT" not in k):
            tf_state_dict[k.replace("/.ATTRIBUTES/VARIABLE_VALUE", "")] = reader.get_tensor(k)
    return tf_state_dict

# Load the layer variables from the checkpoint named "ckpt_BEST" inside `dir_model`.
filename = os.path.join(dir_model, "ckpt_BEST")
tf_state_dict = load_tensorflow_checkpoint(filename)


In [64]:
# Translate the TensorFlow checkpoint variables into a state dict for
# `model.perceptual_model`.
torch_state_dict = {}
for k, v in sorted(tf_state_dict.items()):
    # The layer index is the integer between the first "-" and the first "/" of
    # the key; it selects the torch layer by position in `list_layer_name`.
    layer_index = int(k[k.find("-") + 1 : k.find("/")])
    layer_name = list_layer_name[layer_index]
    # Any variable whose key contains "/b" becomes the layer's "bias"; every
    # other variable becomes its "weight".
    name = "{}.{}".format(layer_name, "bias" if "/b" in k else "weight")
    if name in torch_state_dict:
        # Two TensorFlow variables of the same layer mapped to one torch
        # parameter; loading either would silently discard the other.
        raise ValueError(
            "TensorFlow variable {!r} maps to torch parameter {!r}, "
            "which was already assigned".format(k, name)
        )
    if v.ndim == 2:
        # NOTE(review): presumably a dense kernel stored (in, out) and needed as
        # (out, in). A plain transpose also assumes the flattened feature order
        # feeding the layer is the same in both frameworks; confirm against
        # PerceptualModel.
        torch_state_dict[name] = torch.tensor(np.transpose(v, [1, 0]))
    elif v.ndim == 4:
        # NOTE(review): presumably a conv kernel stored (height, width, in, out)
        # and needed as (out, in, height, width); confirm.
        torch_state_dict[name] = torch.tensor(np.transpose(v, [3, 2, 0, 1]))
    else:
        torch_state_dict[name] = torch.tensor(v)
# Every parameter of the perceptual model must have a converted counterpart;
# shape disagreements are printed for inspection.
for n, p in model.perceptual_model.named_parameters():
    assert n in torch_state_dict
    if not torch_state_dict[n].shape == p.shape:
        print(n, p.shape, torch_state_dict[n].shape)


In [69]:
# Copy the converted weights into the perceptual model. `strict=False` lets the
# load proceed even though `torch_state_dict` does not cover every key of the
# module's state dict; the returned object lists the differences (in the
# recorded run: only the `body.block*_pool.weight` entries are missing).
model.perceptual_model.load_state_dict(torch_state_dict, strict=False, assign=False)


_IncompatibleKeys(missing_keys=['body.block0_pool.weight', 'body.block1_pool.weight', 'body.block2_pool.weight', 'body.block3_pool.weight', 'body.block4_pool.weight', 'body.block5_pool.weight', 'body.block6_pool.weight'], unexpected_keys=[])

In [71]:
# model(torch.zeros(model.input_shape))
