In [13]:
import torch
import torchaudio
from nnAudio import features
from scipy.io import wavfile
from nnAudio import features
import torch

import openvino as ov

from torch_audioset.yamnet.model import yamnet as torch_yamnet
from torch_audioset.data.torch_input_processing import WaveformToInput as TorchTransform
import onnxruntime

In [2]:
# Load the in-the-wild music clip (16 kHz according to its file name) that
# the end-to-end comparison below runs on; normalize=True requests float samples.
wav_file = '/data/audio/ITW-Music/snr10-inthewild-16khz-inthewild_10002.wav'
waveform, sample_rate = torchaudio.load(
    wav_file,
    normalize=True,
)

In [3]:

# Reference front end from torch_audioset: waveform -> log-mel patches.
tt_model = TorchTransform()

# Pass the file's actual sample rate (not a hard-coded 16000) so both calls
# below see the same rate and non-16 kHz files are handled consistently.
patches, spectrogram = tt_model.wavform_to_log_mel(waveform, sample_rate)

patches2 = tt_model(waveform, sample_rate)

# Re-derive the log-mel spectrogram by hand from the transform's sub-modules:
# spectrogram -> sqrt -> mel scale -> log(x + 0.001) -> (frames, n_mels).
spectrogram_direct = tt_model.mel_trans_ope.spectrogram(waveform)
# sqrt converts a power spectrogram to magnitude (assumes the sub-module
# returns power; TODO confirm against torch_audioset).
spectrogram_direct = spectrogram_direct ** 0.5
mel_specgram_direct = tt_model.mel_trans_ope.mel_scale(spectrogram_direct)
mel_specgram_direct = torch.log(mel_specgram_direct + 0.001)
# Drop the leading (assumed size-1 channel) dim and put frames first.
mel_specgram_direct = mel_specgram_direct.squeeze(dim=0).T


In [None]:



class Model(torch.nn.Module):
    """End-to-end YAMNet: raw waveform -> log-mel patches -> YAMNet output.

    The front end (nnAudio magnitude STFT, torchaudio HTK mel scale, log) is
    built from torch ops so the whole pipeline can be traced starting from the
    waveform (e.g. for ``torch.onnx.export``).

    Args:
        weights_path: path to the YAMNet ``state_dict`` file. ``yamnet.pth``
            must be downloaded manually.
    """

    # Offset added before the log (same value as the reference front end).
    LOG_OFFSET = 0.001
    # STFT hop in seconds: hop_length=160 samples at sr=16000.
    STFT_HOP_SECONDS = 0.010
    # Each YAMNet patch spans 0.96 s of STFT frames...
    PATCH_WINDOW_SECONDS = 0.96
    # ...and consecutive patches start 1.0 s apart.
    PATCH_HOP_SECONDS = 1.0

    def __init__(self, weights_path='./yamnet.pth'):
        super().__init__()
        # Not used by this module; kept so code reading `model.epsilon` still works.
        self.epsilon = 1e-10
        # Magnitude STFT computed on the fly (25 ms window / 10 ms hop at 16 kHz).
        self.spec_layer = features.STFT(n_fft=512, win_length=400, freq_bins=None,
                                        hop_length=160, window='hann',
                                        freq_scale='no', center=True,
                                        pad_mode='reflect', fmin=125,
                                        fmax=7500, sr=16000, trainable=False,
                                        output_format='Magnitude')
        # 64-band HTK mel filterbank over the 257 STFT bins, 125-7500 Hz.
        self.mel_scale = torchaudio.transforms.MelScale(
            n_mels=64, sample_rate=16000, f_min=125, f_max=7500,
            n_stft=512 // 2 + 1, norm=None, mel_scale='htk'
        )
        self.yamnet_model = torch_yamnet(pretrained=False)
        # map_location='cpu' lets the checkpoint load even when the device it
        # was saved from is unavailable; load_state_dict copies into this
        # module's own parameters anyway.
        self.yamnet_model.load_state_dict(torch.load(weights_path, map_location='cpu'))

    def _get_mel_specgram(self, x):
        """Turn a waveform into YAMNet input patches.

        Args:
            x: waveform tensor; assumed to be a single mono clip, e.g.
                (1, num_samples) at 16 kHz (TODO confirm other layouts), since
                the mel output is squeezed at dim 0.

        Returns:
            (patches, spectrogram):
            - ``patches``: (num_patches, 1, window_frames, n_mels). Patches
              are 96 frames long and start every 100 frames. Inputs shorter
              than one window yield zero patches. This is a strided view of
              ``spectrogram`` (no copy).
            - ``spectrogram``: the full (frames, n_mels) log-mel spectrogram.
        """
        specgram = self.spec_layer(x)
        mel_specgram = torch.log(self.mel_scale(specgram) + self.LOG_OFFSET)
        # Assumed (1, n_mels, frames) -> (frames, n_mels).
        spectrogram = mel_specgram.squeeze(dim=0).T

        window = int(round(self.PATCH_WINDOW_SECONDS / self.STFT_HOP_SECONDS))
        hop = int(round(self.PATCH_HOP_SECONDS / self.STFT_HOP_SECONDS))
        num_frames = spectrogram.shape[0]
        num_mels = spectrogram.shape[-1]

        if num_frames < window:
            # Not enough frames for a single patch: empty batch, usual layout.
            return spectrogram.new_empty((0, 1, window, num_mels)), spectrogram

        # unfold -> (num_patches, n_mels, window) where patch i covers frames
        # [i*hop, i*hop + window); transpose back to (window, n_mels) and add
        # the channel dim YAMNet expects. Trailing frames that do not fill a
        # whole window are dropped.
        patches = spectrogram.unfold(0, window, hop).transpose(1, 2).unsqueeze(1)
        return patches, spectrogram

    def forward(self, x):
        """Return (YAMNet output for every patch, full log-mel spectrogram)."""
        mel_specgram, spectrogram = self._get_mel_specgram(x)
        return self.yamnet_model(mel_specgram), spectrogram


# End-to-end model: forward pass, ONNX export, and the log-mel patches it feeds
# to YAMNet (kept for comparison with the reference front end below).
# NOTE(review): tracing likely bakes this clip's length into the exported graph
# (patch count is derived from Python ints) — confirm before reusing on other lengths.
model = Model()
model.eval()

with torch.no_grad():
    y, spectrogram = model(waveform)  # YAMNet output, not log-mel patches
    e2e_patches, _ = model._get_mel_specgram(waveform)

    torch.onnx.export(model,               # model being run
        waveform,                          # model input (or a tuple for multiple inputs)
        'my_model.onnx',                   # where to save the model (can be a file or file-like object)
        verbose=True)

    print(f'min val: {torch.min(y)}, max val: {torch.max(y)}, mean val: {torch.mean(y)}, std val: {torch.std(y)}')

# Reference pipeline: torch_audioset front end + standalone YAMNet.
pt_model = torch_yamnet(pretrained=False)
# Manually download the `yamnet.pth` file.
pt_model.load_state_dict(torch.load('./yamnet.pth'))
pt_model.eval()

patches, spectrogram = TorchTransform().wavform_to_log_mel(waveform, sample_rate)

with torch.no_grad():
    # Compare the two front ends (patch tensors); shapes are checked first
    # because allclose raises on non-broadcastable inputs.
    same_shape = e2e_patches.shape == patches.shape
    patches_match = same_shape and torch.allclose(e2e_patches, patches, atol=1e-5)
    print(f'patch shapes: e2e {tuple(e2e_patches.shape)} vs reference {tuple(patches.shape)}; '
          f'allclose(atol=1e-5): {patches_match}')

    # Run YAMNet on each front end's patches (patches, not the network output `y`).
    pt_pred = pt_model(e2e_patches, to_prob=True)
    pt_pred2 = pt_model(patches, to_prob=True)
    if pt_pred.shape == pt_pred2.shape:
        print(f'max |pred difference|: {(pt_pred - pt_pred2).abs().max()}')
    


In [None]:
# Export the reference torch_audioset front end on its own.
model = TorchTransform()
with torch.no_grad():
    x = model(waveform, sample_rate)
    # The module is called as (waveform, sample_rate); its output `x` must not
    # be passed as an input. Non-tensor args (the int sample rate) are
    # hard-coded into the exported graph by torch.onnx.export.
    torch.onnx.export(model,               # model being run
        (waveform, sample_rate),           # model inputs, matching the call above
        'torch_transform.onnx',            # where to save the model (can be a file or file-like object)
        verbose=True)

In [None]:
# Stand-alone nnAudio STFT on 1 s of seeded random noise.
sr = 16000
torch.manual_seed(0)  # reproducible random input
waveforms = torch.randn(1, sr)

# 25 ms window / 10 ms hop expressed in samples and rounded to ints: the
# products are floats, and the hop is presumably used as a conv stride, which
# must be an integer. fmax is capped at Nyquist (sr / 2); 11025 Hz would exceed
# it at 16 kHz.
spec_layer = features.STFT(n_fft=512, win_length=int(round(0.025 * sr)), freq_bins=None,
                           hop_length=int(round(0.010 * sr)),
                           window='hann', freq_scale='linear', center=True, pad_mode='reflect',
                           fmin=50, fmax=sr // 2, sr=sr)  # Initializing the model
spectrogram_nnaudio = spec_layer(waveforms)

In [None]:
# Export the nnAudio STFT layer alone, traced on the real waveform loaded above.
spec_layer.eval()
torch.onnx.export(
    spec_layer,                 # module to trace
    waveform,                   # example input used for tracing
    'specgram_nnaudio.onnx',    # output file
    verbose=True,
)

In [None]:
# Export torchaudio's Spectrogram transform. `transform` is defined here so the
# cell does not depend on a variable left in the kernel from elsewhere (it
# otherwise raises NameError on a fresh kernel).
transform = torchaudio.transforms.Spectrogram(n_fft=512)
transform.eval()
torch.onnx.export(transform,               # model being run
    waveform,                         # model input (or a tuple for multiple inputs)
    'specgram.onnx',            # where to save the model (can be a file or file-like object)
    verbose=True)

## Inference from precomputed log-mel patches (not end-to-end)

In [3]:
# Replace `waveform` / `sample_rate` with a Common Voice speech clip for the
# non-end-to-end inference cells that follow.
wav_file = '/data/audio/loccus-asv-datasets/QA/ASR_evaluation/wav.16kHz/common_voice_en_538718_en.wav'
waveform, sample_rate = torchaudio.load(wav_file,
                                        normalize=True)

In [6]:
# Reference log-mel patches; pass the file's actual sample rate rather than a hard-coded 16000.
patches, spectrogram = TorchTransform().wavform_to_log_mel(waveform, sample_rate)

In [None]:
# PyTorch reference: standalone YAMNet (manually downloaded weights) on the patches.
pt_model = torch_yamnet(pretrained=False)
pt_model.load_state_dict(torch.load('./yamnet.pth'))
pt_model.eval()  # inference mode; independent of autograd below

with torch.no_grad():
    pt_pred = pt_model(patches, to_prob=True)

In [10]:
# ONNX Runtime on the CPU provider, one thread for both intra- and inter-op work.
opts = onnxruntime.SessionOptions()
opts.intra_op_num_threads = 1
opts.inter_op_num_threads = 1
session = onnxruntime.InferenceSession(
    'yamnet.onnx',
    sess_options=opts,
    providers=['CPUExecutionProvider'],
)

# Feed the patches (as NumPy) to the first input and keep only the first output.
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
ort_inputs = {input_name: patches.numpy()}
ort_outs = session.run([output_name], ort_inputs)
onnx_pred = ort_outs[0]

In [12]:
# OpenVINO: read the IR and compile it for CPU with the same Core instance.
core = ov.Core()
ov_model = core.read_model(model='yamnet.xml')
compiled_model = core.compile_model(ov_model, "CPU")
output_layer = compiled_model.output(0)

# Pass a NumPy array (as the ONNX Runtime cell does) rather than a torch.Tensor.
ov_pred = compiled_model(patches.numpy())[output_layer]