**Tensorflow implementation** of the paper [Steerable discovery of neural audio effects](https://arxiv.org/abs/2112.02926) by [Christian J. Steinmetz](https://www.christiansteinmetz.com/) and [Joshua D. Reiss](http://www.eecs.qmul.ac.uk/~josh/)

______________________________________________

<div align="center">

# Steerable discovery of neural audio effects

  [Christian J. Steinmetz](https://www.christiansteinmetz.com/)  and  [Joshua D. Reiss](http://www.eecs.qmul.ac.uk/~josh/)


[Code](https://github.com/csteinmetz1/steerable-nafx) • [Paper](https://arxiv.org/abs/2112.02926) • [Demo](https://csteinmetz1.github.io/steerable-nafx)	• [Slides]()

<img src="https://csteinmetz1.github.io/steerable-nafx/assets/steerable-headline.svg">

</div>

## Abstract
Applications of deep learning for audio effects often focus on modeling analog effects or learning to control effects to emulate a trained audio engineer. 
However, deep learning approaches also have the potential to expand creativity through neural audio effects that enable new sound transformations. 
While recent work demonstrated that neural networks with random weights produce compelling audio effects, control of these effects is limited and unintuitive.
To address this, we introduce a method for the steerable discovery of neural audio effects.
This method enables the design of effects using example recordings provided by the user. 
We demonstrate how this method produces an effect similar to the target effect, along with interesting inaccuracies, while also providing perceptually relevant controls.


\* *Accepted to NeurIPS 2021 Workshop on Machine Learning for Creativity and Design*



# Setup

In [None]:
import numpy as np
import scipy as sp

import tensorflow as tf
from tensorflow import keras
import tf2onnx
import onnx

import os
import IPython

import matplotlib.pyplot as plt
import librosa.display

### Choose computation device

In [None]:
physical_devices = tf.config.list_physical_devices()
print(f"These are the physical devices available:\n{physical_devices}")

# Hide every GPU so TensorFlow runs on the CPU only. This is best-effort:
# TensorFlow refuses to change visible devices once they have been
# initialized (e.g. when this cell is re-run), so report that and continue
# instead of silently swallowing every possible error.
try:
    tf.config.set_visible_devices([], 'GPU')
except RuntimeError as err:
    print(f"Could not hide GPUs (devices already initialized?): {err}")
else:
    visible_devices = tf.config.get_visible_devices()
    print(f"These are the visible devices:\n{visible_devices}")

In [None]:
name = 'model_0'

# Every artifact (weights, .tflite, .onnx) is written under models/<name>/.
model_dir = 'models/' + name
if os.path.exists(model_dir):
    # Stop here rather than silently overwriting an existing model's files;
    # a bare `exit` expression would not actually stop the cell.
    raise FileExistsError(
        "A model with the same name already exists. Please choose a new name."
    )
os.makedirs(model_dir)

## Define the model

In [None]:
keras.backend.clear_session()

class TCNBlock(keras.Model):
    """One temporal-convolution block: dilated Conv1D (+ optional PReLU) plus a 1x1 residual.

    The main path is an unpadded ("valid") Conv1D, so its output is
    (kernel_size - 1) * dilation samples shorter than its input. The residual
    path (a bias-free 1x1 Conv1D) is cropped by the same number of leading
    samples so the two paths can be summed.

    Args:
        in_channels: accepted for interface symmetry; Keras infers the input
            width when the block is built, so it is not used here.
        out_channels: number of filters of both the main and residual convs.
        kernel_size: Conv1D kernel length.
        dilation: Conv1D dilation rate.
        activation: if True, apply a PReLU after the main conv; if False the
            main path is left linear.
        **kwargs: forwarded to keras.Model.
    """
    def __init__(self, in_channels, out_channels, kernel_size, dilation, activation=True, **kwargs):
        super().__init__(**kwargs)
        self.conv = keras.layers.Conv1D(filters=out_channels, kernel_size=kernel_size, dilation_rate=dilation, padding="valid", use_bias=True, kernel_initializer="glorot_uniform", bias_initializer="zeros")
        # Only build the PReLU when an activation is requested.
        # shared_axes=[0, 1, 2] is meant to share a single alpha across time
        # and channels (like torch.nn.PReLU()) -- NOTE(review): confirm
        # against Keras' PReLU parameter-shape logic.
        self.act = (
            keras.layers.PReLU(alpha_initializer=tf.initializers.constant(0.25), shared_axes=[0, 1, 2])
            if activation
            else None
        )
        self.res = keras.layers.Conv1D(out_channels, 1, use_bias=False, kernel_initializer="glorot_uniform")
        self.kernel_size = kernel_size
        self.dilation = dilation

    def call(self, x):
        x_in = x
        x = self.conv(x)
        if self.act is not None:
            x = self.act(x)
        # Drop the leading samples the valid convolution consumed so the
        # residual lines up sample-for-sample with the main path.
        x_res = self.res(x_in)
        x_res = x_res[:, (self.kernel_size-1)*self.dilation:, :]
        return x + x_res
    
class TCN(keras.Model):
    """A stack of TCNBlocks whose dilation grows geometrically with depth.

    Block ``n`` uses dilation ``dilation_growth ** n``. Channel widths:
    the first block reads ``n_inputs``, every later block reads
    ``n_channels``; the last block (when it is not also the first) writes
    ``n_outputs``, all others write ``n_channels``. Every block applies its
    activation.
    """
    def __init__(self, n_inputs=1, n_outputs=1, n_blocks=10, kernel_size=13, n_channels=64, dilation_growth=4):
        super().__init__()
        self.kernel_size = kernel_size
        self.n_channels = n_channels
        self.dilation_growth = dilation_growth
        self.n_blocks = n_blocks
        self.stack_size = n_blocks

        self.blocks = []
        final_idx = n_blocks - 1
        for idx in range(n_blocks):
            ch_in = n_inputs if idx == 0 else n_channels
            # A block that is both first and last keeps n_channels outputs.
            is_output_block = idx != 0 and idx == final_idx
            ch_out = n_outputs if is_output_block else n_channels
            self.blocks.append(
                TCNBlock(ch_in, ch_out, kernel_size, dilation_growth ** idx, activation=True)
            )

    def call(self, x):
        out = x
        for blk in self.blocks:
            out = blk(out)
        return out

    def compute_receptive_field(self):
        """Return the receptive field of the whole stack, in samples."""
        taps = self.kernel_size - 1
        extra = sum(
            taps * self.dilation_growth ** (depth % self.stack_size)
            for depth in range(1, self.n_blocks)
        )
        return self.kernel_size + extra

# 1. Steering (training)
Use a pair of audio examples in order to construct neural audio effects.

There are two options. Either start with the pre-loaded audio examples, or upload your own clean/processed audio recordings for the steering process.

a.) Use one of our pre-loaded audio example pairs. Choose from the compressor, reverb, UltraTab, or amp effect.

In [None]:
#@title Use pre-loaded audio examples for steering
effect_type = "Amp" #@param ["Compressor", "Reverb", "UltraTab", "Amp"]

# (clean input, processed target) recording pair for each pre-loaded effect.
effect_files = {
    "Compressor": ("audio/drum_kit_clean.wav", "audio/drum_kit_comp_agg.wav"),
    "Reverb": ("audio/acgtr_clean.wav", "audio/acgtr_reverb.wav"),
    "UltraTab": ("audio/acgtr_clean.wav", "audio/acgtr_ultratab.wav"),
    "Amp": ("audio/ts9_test1_in_FP32.wav", "audio/ts9_test1_out_FP32.wav"),
}
if effect_type not in effect_files:
    raise ValueError(
        f"Unknown effect_type {effect_type!r}; expected one of {sorted(effect_files)}"
    )
input_file, output_file = effect_files[effect_type]


def _load_mono_float32(path):
    """Read a WAV file and return ``(rate, signal)`` with a 1-D float32 signal.

    Multichannel files keep only their first channel, so channels are never
    interleaved into one sequence. Integer PCM is scaled by its full-scale
    value (32768 for int16); float WAV data is only cast to float32.
    """
    rate, data = sp.io.wavfile.read(path)
    if data.ndim > 1:
        data = data[:, 0]
    if np.issubdtype(data.dtype, np.integer):
        data = data.astype(np.float32) / float(np.iinfo(data.dtype).max + 1)
    else:
        data = data.astype(np.float32)
    return rate, data


# Load and Preprocess Data ###########################################
sample_rate, x = _load_mono_float32(input_file)
target_rate, y = _load_mono_float32(output_file)
# Input and target must be sample-aligned, which requires a common rate.
if target_rate != sample_rate:
    raise ValueError(
        f"Sample rates differ: {input_file} is {sample_rate} Hz, "
        f"{output_file} is {target_rate} Hz"
    )

print(f"x shape: {x.shape}")
print(f"x = {x}")
print(f"y shape: {y.shape}")
print(f"y = {y}")

print("input file", x.shape)
IPython.display.display(IPython.display.Audio(data=x, rate=sample_rate))
print("output file", y.shape)
IPython.display.display(IPython.display.Audio(data=y, rate=sample_rate))

Now it's time to generate the neural audio effect by training the TCN to emulate the input/output behavior of the target audio effect. Adjusting the parameters lets you tweak the optimization process.

In [None]:
#@title TCN model training parameters
kernel_size = 13 #@param {type:"slider", min:3, max:32, step:1}
n_blocks = 4 #@param {type:"slider", min:2, max:30, step:1}
dilation_growth = 10 #@param {type:"slider", min:1, max:10, step:1}
n_channels = 32 #@param {type:"slider", min:1, max:128, step:1}
n_iters = 300 #@param {type:"slider", min:0, max:10000, step:1}
length = 508032 #@param {type:"slider", min:0, max:524288, step:1}
lr = 0.001 #@param {type:"number"}

# View each full recording as a one-example batch: (batch=1, samples, channels=1).
x_batch, y_batch = (np.reshape(sig, (1, -1, 1)) for sig in (x, y))

print(f"x_batch shape: {x_batch.shape}")
print(f"y_batch shape: {y_batch.shape}")

# Instantiate the TCN and report how much past context each output sample sees.
model = TCN(
    n_inputs=1,
    n_outputs=1,
    n_blocks=n_blocks,
    kernel_size=kernel_size,
    n_channels=n_channels,
    dilation_growth=dilation_growth,
)
rf = model.compute_receptive_field()

print(f"Receptive field: {rf} samples or {(rf/sample_rate)*1e3:0.1f} ms")

In [None]:
class MyLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Step schedule: lr, then lr * 0.1 from 80% and lr * 0.01 from 95% of n_iters.

    Args:
        initial_learning_rate: learning rate used before the first milestone.
        n_iters: number of optimizer steps the milestones are laid out over.
    """
    def __init__(self, initial_learning_rate, n_iters):
        self.initial_learning_rate = initial_learning_rate
        self.n_iters = n_iters

    @tf.function
    def __call__(self, step):
        step = tf.cast(step, tf.int64)
        # Test the later milestone first: any step past 95% is also past 80%,
        # so checking 80% first would make the 95% branch unreachable.
        if step >= tf.cast(self.n_iters * 0.95, tf.int64):
            return self.initial_learning_rate * 0.01
        elif step >= tf.cast(self.n_iters * 0.8, tf.int64):
            return self.initial_learning_rate * 0.1
        else:
            return self.initial_learning_rate

    def get_config(self):
        # Keys must match __init__'s parameter names so from_config(cls(**config)) works.
        return {
            'initial_learning_rate': self.initial_learning_rate,
            'n_iters': self.n_iters,
        }

# Training uses one batch per epoch, so the schedule's step count equals the epoch count.
optimizer = keras.optimizers.Adam(learning_rate=MyLRSchedule(lr, n_iters), epsilon=1e-8)
model.compile(optimizer=optimizer, loss='mse')
model.build(input_shape=(None, length+rf-1, 1))
model.summary(expand_nested=True)

## Train the model

In [None]:
# Train on one fixed window. The target covers samples [rf, rf + length);
# the input starts rf - 1 samples earlier because the unpadded convolutions
# consume rf - 1 samples of context before emitting their first output.
start_idx = rf
stop_idx = length + start_idx

x_crop, y_crop = x_batch[:, start_idx - rf + 1:stop_idx, :], y_batch[:, start_idx:stop_idx, :]
print(f"x_crop = {x_crop.shape}")
print(f"y_crop = {y_crop.shape}")

# The same single window is used every epoch, so epochs == optimizer steps.
history = model.fit(
    x=x_crop,
    y=y_crop,
    batch_size=1,
    epochs=n_iters,
    verbose=1,
)

In [None]:
# Persist the trained weights under models/<name>/<name>.
model.save_weights(f"models/{name}/{name}")

## Run predictions
### 0. Load the model

In [None]:
# Restore the weights stored at models/<name>/<name> by model.save_weights.
model.load_weights(f"models/{name}/{name}")

### 1. On the test audio data

In [None]:
# Run Prediction #################################################
# Run the model over the complete steering recording #############
# (x_batch / y_batch hold the full input/target files, including the window used for training)

# Left-pad with rf-1 zeros so the unpadded convolutions return exactly one
# output sample per input sample, aligned with the target.
x_pad = np.pad(x_batch, ((0,0),(rf-1,0),(0,0)), mode='constant')

y_hat = model.predict(x_pad)

# Named `input_audio` (not `input`) so the builtin input() is not shadowed.
input_audio = x_batch.flatten()
output = y_hat.flatten()
target = y_batch.flatten()

print(f"Input shape: {input_audio.shape}")
print(f"Output shape: {output.shape}")
print(f"Target shape: {target.shape}")

# apply highpass to output to remove DC
sos = sp.signal.butter(8, 20.0, fs=sample_rate, output="sos", btype="highpass")
output = sp.signal.sosfilt(sos, output)

# Peak-normalize each signal for listening and plotting.
input_audio /= np.max(np.abs(input_audio))
output /= np.max(np.abs(output))
target /= np.max(np.abs(target))

fig, ax = plt.subplots(nrows=1, sharex=True)
librosa.display.waveshow(target, sr=sample_rate, color='b', alpha=0.5, ax=ax, label='Target')
librosa.display.waveshow(output, sr=sample_rate, color='r', alpha=0.5, ax=ax, label='Output')

print("Input (clean)")
IPython.display.display(IPython.display.Audio(data=input_audio, rate=sample_rate))
print("Target")
IPython.display.display(IPython.display.Audio(data=target, rate=sample_rate))
print("Output")
IPython.display.display(IPython.display.Audio(data=output, rate=sample_rate))
ax.legend()
# pyplot.show() takes no figure argument.
plt.show()

In [None]:
# Load and Preprocess Data ###########################################
# Keep the piano file's rate in its own variable so the training-audio
# `sample_rate` used by later cells is not overwritten.
piano_rate, x_whole = sp.io.wavfile.read("audio/piano_clean.wav")
if piano_rate != sample_rate:
    print(f"Warning: piano_clean.wav is {piano_rate} Hz but the model was trained at {sample_rate} Hz.")
if x_whole.ndim > 1:
    # multichannel file: keep only the first channel
    x_whole = x_whole[:, 0]
if np.issubdtype(x_whole.dtype, np.integer):
    # integer PCM: scale by its full-scale value (32768 for int16)
    x_whole = x_whole.astype(np.float32) / float(np.iinfo(x_whole.dtype).max + 1)
else:
    # float WAV data is assumed to be full-scale at 1.0 already
    x_whole = x_whole.astype(np.float32)
x_whole = x_whole.reshape(1,-1,1)

# Pad rf-1 zeros on both sides: the left pad gives the first sample full
# context; the right pad yields rf-1 extra output samples after the input ends.
x_whole = np.pad(x_whole, ((0,0),(rf-1,rf-1),(0,0)), mode='constant')

y_whole = model.predict(x_whole)

# Drop the leading pad so the input lines up with the output.
x_whole = x_whole[:, -y_whole.shape[1]:, :]

y_whole /= np.abs(y_whole).max()

# apply high pass filter to remove DC
sos = sp.signal.butter(8, 20.0, fs=piano_rate, output="sos", btype="highpass")
y_whole = sp.signal.sosfilt(sos, y_whole.flatten())

x_whole = x_whole.flatten()

y_whole = (y_whole * 0.8)
IPython.display.display(IPython.display.Audio(data=x_whole, rate=piano_rate))
IPython.display.display(IPython.display.Audio(data=y_whole, rate=piano_rate))

x_whole /= np.max(np.abs(x_whole))
y_whole /= np.max(np.abs(y_whole))

fig, ax = plt.subplots(nrows=1, sharex=True)
librosa.display.waveshow(y_whole, sr=piano_rate, color='r', alpha=0.5, ax=ax, label='Output')
librosa.display.waveshow(x_whole, sr=piano_rate, alpha=0.5, ax=ax, label='Input', color="blue")
ax.legend()
# pyplot.show() takes no figure argument.
plt.show()

### 2. On a number sequence (to verify inference)

In [None]:
# Test the model on a simple number sequence to compare with inference #
# Deterministic ramp 0, 1e-6, 2e-6, ... of 2048 + rf - 1 samples. After the
# rf - 1 samples the unpadded convolutions consume, the model returns 2048
# outputs (presumably for comparing against the exported models -- confirm).
# Built in one vectorized step instead of a quadratic np.append loop; the
# values and dtype (float64) are the same.
X_testing_2 = (np.arange(2048 + rf - 1) * 0.000001).reshape(1, -1, 1)

print("Running prediction..")
prediction_2 = model.predict(X_testing_2)
print(f"prediction {prediction_2}")

print("X_testing_2 shape: ", X_testing_2.shape)
print("prediction_2 shape: ", prediction_2.shape)

## Export as tflite and onnx model

In [None]:
# Export to TFLite. The time axis is left dynamic (None) and the batch size is
# fixed at 1. Because every convolution is unpadded, an input needs at least
# `rf` samples (the receptive field) to produce any output.
input_shape = [1, None, 1]

signature = tf.TensorSpec(input_shape, dtype=tf.float32)
func = tf.function(model).get_concrete_function(signature)
converter = tf.lite.TFLiteConverter.from_concrete_functions([func], model)
tflite_model = converter.convert()

# Save the model.
with open("models/" + name + "/" + "steerable-nafx-dynamic.tflite", "wb") as tflite_file:
    tflite_file.write(tflite_model)

In [None]:
# Export to ONNX with the same signature as the TFLite export:
# batch fixed at 1, any number of samples, input tensor named "x".
input_shape = [1, None, 1]
input_signature = [tf.TensorSpec(input_shape, tf.float32, name='x')]

onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=input_signature, opset=18)
onnx.save(proto=onnx_model, f="models/" + name + "/" + "steerable-nafx-tflite-dynamic.onnx")

# 3 TCN blocks Model

In [None]:
#@title TCN model training parameters
kernel_size = 13 #@param {type:"slider", min:3, max:32, step:1}
n_blocks = 3 #@param {type:"slider", min:2, max:30, step:1}
dilation_growth = 10 #@param {type:"slider", min:1, max:10, step:1}
n_channels = 32 #@param {type:"slider", min:1, max:128, step:1}
n_iters = 10 #@param {type:"slider", min:0, max:10000, step:1}
length = 508032 #@param {type:"slider", min:0, max:524288, step:1}
lr = 0.001 #@param {type:"number"}

# View each full recording as a one-example batch: (batch=1, samples, channels=1).
x_batch, y_batch = (np.reshape(sig, (1, -1, 1)) for sig in (x, y))

print(f"x_batch shape: {x_batch.shape}")
print(f"y_batch shape: {y_batch.shape}")

# Instantiate the 3-block TCN and report its receptive field.
model_3_blocks = TCN(
    n_inputs=1,
    n_outputs=1,
    n_blocks=n_blocks,
    kernel_size=kernel_size,
    n_channels=n_channels,
    dilation_growth=dilation_growth,
)
rf = model_3_blocks.compute_receptive_field()

print(f"Receptive field: {rf} samples or {(rf/sample_rate)*1e3:0.1f} ms")

In [None]:
class MyLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Step schedule: lr, then lr * 0.1 from 80% and lr * 0.01 from 95% of n_iters.

    Args:
        initial_learning_rate: learning rate used before the first milestone.
        n_iters: number of optimizer steps the milestones are laid out over.
    """
    def __init__(self, initial_learning_rate, n_iters):
        self.initial_learning_rate = initial_learning_rate
        self.n_iters = n_iters

    @tf.function
    def __call__(self, step):
        step = tf.cast(step, tf.int64)
        # Test the later milestone first: any step past 95% is also past 80%,
        # so checking 80% first would make the 95% branch unreachable.
        if step >= tf.cast(self.n_iters * 0.95, tf.int64):
            return self.initial_learning_rate * 0.01
        elif step >= tf.cast(self.n_iters * 0.8, tf.int64):
            return self.initial_learning_rate * 0.1
        else:
            return self.initial_learning_rate

    def get_config(self):
        # Keys must match __init__'s parameter names so from_config(cls(**config)) works.
        return {
            'initial_learning_rate': self.initial_learning_rate,
            'n_iters': self.n_iters,
        }

# Training uses one batch per epoch, so the schedule's step count equals the epoch count.
optimizer = keras.optimizers.Adam(learning_rate=MyLRSchedule(lr, n_iters), epsilon=1e-8)
model_3_blocks.compile(optimizer=optimizer, loss='mse')
model_3_blocks.build(input_shape=(None, length+rf-1, 1))
model_3_blocks.summary(expand_nested=True)

## Train the model

In [None]:
# Train on one fixed window. The target covers samples [rf, rf + length);
# the input starts rf - 1 samples earlier because the unpadded convolutions
# consume rf - 1 samples of context before emitting their first output.
start_idx = rf
stop_idx = length + start_idx

x_crop, y_crop = x_batch[:, start_idx - rf + 1:stop_idx, :], y_batch[:, start_idx:stop_idx, :]
print(f"x_crop = {x_crop.shape}")
print(f"y_crop = {y_crop.shape}")

# The same single window is used every epoch, so epochs == optimizer steps.
history = model_3_blocks.fit(
    x=x_crop,
    y=y_crop,
    batch_size=1,
    epochs=n_iters,
    verbose=1,
)

## Run predictions

### 1. On the test audio data

In [None]:
# Run Prediction #################################################
# Run the model over the complete steering recording #############
# (x_batch / y_batch hold the full input/target files, including the window used for training)

# Left-pad with rf-1 zeros so the unpadded convolutions return exactly one
# output sample per input sample, aligned with the target.
x_pad = np.pad(x_batch, ((0,0),(rf-1,0),(0,0)), mode='constant')

y_hat = model_3_blocks.predict(x_pad)

# Named `input_audio` (not `input`) so the builtin input() is not shadowed.
input_audio = x_batch.flatten()
output = y_hat.flatten()
target = y_batch.flatten()

print(f"Input shape: {input_audio.shape}")
print(f"Output shape: {output.shape}")
print(f"Target shape: {target.shape}")

# apply highpass to output to remove DC
sos = sp.signal.butter(8, 20.0, fs=sample_rate, output="sos", btype="highpass")
output = sp.signal.sosfilt(sos, output)

# Peak-normalize each signal for listening and plotting.
input_audio /= np.max(np.abs(input_audio))
output /= np.max(np.abs(output))
target /= np.max(np.abs(target))

fig, ax = plt.subplots(nrows=1, sharex=True)
librosa.display.waveshow(target, sr=sample_rate, color='b', alpha=0.5, ax=ax, label='Target')
librosa.display.waveshow(output, sr=sample_rate, color='r', alpha=0.5, ax=ax, label='Output')

print("Input (clean)")
IPython.display.display(IPython.display.Audio(data=input_audio, rate=sample_rate))
print("Target")
IPython.display.display(IPython.display.Audio(data=target, rate=sample_rate))
print("Output")
IPython.display.display(IPython.display.Audio(data=output, rate=sample_rate))
ax.legend()
# pyplot.show() takes no figure argument.
plt.show()

In [None]:
# Load and Preprocess Data ###########################################
# Keep the piano file's rate in its own variable so the training-audio
# `sample_rate` used by later cells is not overwritten.
piano_rate, x_whole = sp.io.wavfile.read("audio/piano_clean.wav")
if piano_rate != sample_rate:
    print(f"Warning: piano_clean.wav is {piano_rate} Hz but the model was trained at {sample_rate} Hz.")
if x_whole.ndim > 1:
    # multichannel file: keep only the first channel
    x_whole = x_whole[:, 0]
if np.issubdtype(x_whole.dtype, np.integer):
    # integer PCM: scale by its full-scale value (32768 for int16)
    x_whole = x_whole.astype(np.float32) / float(np.iinfo(x_whole.dtype).max + 1)
else:
    # float WAV data is assumed to be full-scale at 1.0 already
    x_whole = x_whole.astype(np.float32)
x_whole = x_whole.reshape(1,-1,1)

# Pad rf-1 zeros on both sides: the left pad gives the first sample full
# context; the right pad yields rf-1 extra output samples after the input ends.
x_whole = np.pad(x_whole, ((0,0),(rf-1,rf-1),(0,0)), mode='constant')

y_whole = model_3_blocks.predict(x_whole)

# Drop the leading pad so the input lines up with the output.
x_whole = x_whole[:, -y_whole.shape[1]:, :]

y_whole /= np.abs(y_whole).max()

# apply high pass filter to remove DC
sos = sp.signal.butter(8, 20.0, fs=piano_rate, output="sos", btype="highpass")
y_whole = sp.signal.sosfilt(sos, y_whole.flatten())

x_whole = x_whole.flatten()

y_whole = (y_whole * 0.8)
IPython.display.display(IPython.display.Audio(data=x_whole, rate=piano_rate))
IPython.display.display(IPython.display.Audio(data=y_whole, rate=piano_rate))

x_whole /= np.max(np.abs(x_whole))
y_whole /= np.max(np.abs(y_whole))

fig, ax = plt.subplots(nrows=1, sharex=True)
librosa.display.waveshow(y_whole, sr=piano_rate, color='r', alpha=0.5, ax=ax, label='Output')
librosa.display.waveshow(x_whole, sr=piano_rate, alpha=0.5, ax=ax, label='Input', color="blue")
ax.legend()
# pyplot.show() takes no figure argument.
plt.show()

## Export as tflite and onnx model

In [None]:
# Export to TFLite. The time axis is left dynamic (None) and the batch size is
# fixed at 1. Because every convolution is unpadded, an input needs at least
# `rf` samples (the receptive field) to produce any output.
input_shape = [1, None, 1]

signature = tf.TensorSpec(input_shape, dtype=tf.float32)
func = tf.function(model_3_blocks).get_concrete_function(signature)
converter = tf.lite.TFLiteConverter.from_concrete_functions([func], model_3_blocks)
tflite_model_3_blocks = converter.convert()

# Save the model.
with open("models/" + name + "/" + "steerable-nafx-3_blocks-dynamic.tflite", "wb") as tflite_file:
    tflite_file.write(tflite_model_3_blocks)

In [None]:
# Export to ONNX with the same signature as the TFLite export:
# batch fixed at 1, any number of samples, input tensor named "x".
input_shape = [1, None, 1]
input_signature = [tf.TensorSpec(input_shape, tf.float32, name='x')]

onnx_model_3_blocks, _ = tf2onnx.convert.from_keras(model_3_blocks, input_signature=input_signature, opset=18)
onnx.save(proto=onnx_model_3_blocks, f="models/" + name + "/" + "steerable-nafx-3_blocks-tflite-dynamic.onnx")

# 2 TCN blocks Model

In [None]:
#@title TCN model training parameters
kernel_size = 13 #@param {type:"slider", min:3, max:32, step:1}
n_blocks = 2 #@param {type:"slider", min:2, max:30, step:1}
dilation_growth = 10 #@param {type:"slider", min:1, max:10, step:1}
n_channels = 32 #@param {type:"slider", min:1, max:128, step:1}
n_iters = 10 #@param {type:"slider", min:0, max:10000, step:1}
length = 508032 #@param {type:"slider", min:0, max:524288, step:1}
lr = 0.001 #@param {type:"number"}

# View each full recording as a one-example batch: (batch=1, samples, channels=1).
x_batch, y_batch = (np.reshape(sig, (1, -1, 1)) for sig in (x, y))

print(f"x_batch shape: {x_batch.shape}")
print(f"y_batch shape: {y_batch.shape}")

# Instantiate the 2-block TCN and report its receptive field.
model_2_blocks = TCN(
    n_inputs=1,
    n_outputs=1,
    n_blocks=n_blocks,
    kernel_size=kernel_size,
    n_channels=n_channels,
    dilation_growth=dilation_growth,
)
rf = model_2_blocks.compute_receptive_field()

print(f"Receptive field: {rf} samples or {(rf/sample_rate)*1e3:0.1f} ms")

In [None]:
class MyLRSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Step schedule: lr, then lr * 0.1 from 80% and lr * 0.01 from 95% of n_iters.

    Args:
        initial_learning_rate: learning rate used before the first milestone.
        n_iters: number of optimizer steps the milestones are laid out over.
    """
    def __init__(self, initial_learning_rate, n_iters):
        self.initial_learning_rate = initial_learning_rate
        self.n_iters = n_iters

    @tf.function
    def __call__(self, step):
        step = tf.cast(step, tf.int64)
        # Test the later milestone first: any step past 95% is also past 80%,
        # so checking 80% first would make the 95% branch unreachable.
        if step >= tf.cast(self.n_iters * 0.95, tf.int64):
            return self.initial_learning_rate * 0.01
        elif step >= tf.cast(self.n_iters * 0.8, tf.int64):
            return self.initial_learning_rate * 0.1
        else:
            return self.initial_learning_rate

    def get_config(self):
        # Keys must match __init__'s parameter names so from_config(cls(**config)) works.
        return {
            'initial_learning_rate': self.initial_learning_rate,
            'n_iters': self.n_iters,
        }

# Training uses one batch per epoch, so the schedule's step count equals the epoch count.
optimizer = keras.optimizers.Adam(learning_rate=MyLRSchedule(lr, n_iters), epsilon=1e-8)
model_2_blocks.compile(optimizer=optimizer, loss='mse')
model_2_blocks.build(input_shape=(None, length+rf-1, 1))
model_2_blocks.summary(expand_nested=True)

## Train the model

In [None]:
# Train on one fixed window. The target covers samples [rf, rf + length);
# the input starts rf - 1 samples earlier because the unpadded convolutions
# consume rf - 1 samples of context before emitting their first output.
start_idx = rf
stop_idx = length + start_idx

x_crop, y_crop = x_batch[:, start_idx - rf + 1:stop_idx, :], y_batch[:, start_idx:stop_idx, :]
print(f"x_crop = {x_crop.shape}")
print(f"y_crop = {y_crop.shape}")

# The same single window is used every epoch, so epochs == optimizer steps.
history = model_2_blocks.fit(
    x=x_crop,
    y=y_crop,
    batch_size=1,
    epochs=n_iters,
    verbose=1,
)

## Run predictions

### 1. On the test audio data

In [None]:
# Run Prediction #################################################
# Run the model over the complete steering recording #############
# (x_batch / y_batch hold the full input/target files, including the window used for training)

# Left-pad with rf-1 zeros so the unpadded convolutions return exactly one
# output sample per input sample, aligned with the target.
x_pad = np.pad(x_batch, ((0,0),(rf-1,0),(0,0)), mode='constant')

y_hat = model_2_blocks.predict(x_pad)

# Named `input_audio` (not `input`) so the builtin input() is not shadowed.
input_audio = x_batch.flatten()
output = y_hat.flatten()
target = y_batch.flatten()

print(f"Input shape: {input_audio.shape}")
print(f"Output shape: {output.shape}")
print(f"Target shape: {target.shape}")

# apply highpass to output to remove DC
sos = sp.signal.butter(8, 20.0, fs=sample_rate, output="sos", btype="highpass")
output = sp.signal.sosfilt(sos, output)

# Peak-normalize each signal for listening and plotting.
input_audio /= np.max(np.abs(input_audio))
output /= np.max(np.abs(output))
target /= np.max(np.abs(target))

fig, ax = plt.subplots(nrows=1, sharex=True)
librosa.display.waveshow(target, sr=sample_rate, color='b', alpha=0.5, ax=ax, label='Target')
librosa.display.waveshow(output, sr=sample_rate, color='r', alpha=0.5, ax=ax, label='Output')

print("Input (clean)")
IPython.display.display(IPython.display.Audio(data=input_audio, rate=sample_rate))
print("Target")
IPython.display.display(IPython.display.Audio(data=target, rate=sample_rate))
print("Output")
IPython.display.display(IPython.display.Audio(data=output, rate=sample_rate))
ax.legend()
# pyplot.show() takes no figure argument.
plt.show()

In [None]:
# Load and Preprocess Data ###########################################
# Keep the piano file's rate in its own variable so the training-audio
# `sample_rate` used by later cells is not overwritten.
piano_rate, x_whole = sp.io.wavfile.read("audio/piano_clean.wav")
if piano_rate != sample_rate:
    print(f"Warning: piano_clean.wav is {piano_rate} Hz but the model was trained at {sample_rate} Hz.")
if x_whole.ndim > 1:
    # multichannel file: keep only the first channel
    x_whole = x_whole[:, 0]
if np.issubdtype(x_whole.dtype, np.integer):
    # integer PCM: scale by its full-scale value (32768 for int16)
    x_whole = x_whole.astype(np.float32) / float(np.iinfo(x_whole.dtype).max + 1)
else:
    # float WAV data is assumed to be full-scale at 1.0 already
    x_whole = x_whole.astype(np.float32)
x_whole = x_whole.reshape(1,-1,1)

# Pad rf-1 zeros on both sides: the left pad gives the first sample full
# context; the right pad yields rf-1 extra output samples after the input ends.
x_whole = np.pad(x_whole, ((0,0),(rf-1,rf-1),(0,0)), mode='constant')

y_whole = model_2_blocks.predict(x_whole)

# Drop the leading pad so the input lines up with the output.
x_whole = x_whole[:, -y_whole.shape[1]:, :]

y_whole /= np.abs(y_whole).max()

# apply high pass filter to remove DC
sos = sp.signal.butter(8, 20.0, fs=piano_rate, output="sos", btype="highpass")
y_whole = sp.signal.sosfilt(sos, y_whole.flatten())

x_whole = x_whole.flatten()

y_whole = (y_whole * 0.8)
IPython.display.display(IPython.display.Audio(data=x_whole, rate=piano_rate))
IPython.display.display(IPython.display.Audio(data=y_whole, rate=piano_rate))

x_whole /= np.max(np.abs(x_whole))
y_whole /= np.max(np.abs(y_whole))

fig, ax = plt.subplots(nrows=1, sharex=True)
librosa.display.waveshow(y_whole, sr=piano_rate, color='r', alpha=0.5, ax=ax, label='Output')
librosa.display.waveshow(x_whole, sr=piano_rate, alpha=0.5, ax=ax, label='Input', color="blue")
ax.legend()
# pyplot.show() takes no figure argument.
plt.show()

## Export as tflite and onnx model

In [None]:
# Export to TFLite. The time axis is left dynamic (None) and the batch size is
# fixed at 1. Because every convolution is unpadded, an input needs at least
# `rf` samples (the receptive field) to produce any output.
input_shape = [1, None, 1]

signature = tf.TensorSpec(input_shape, dtype=tf.float32)
func = tf.function(model_2_blocks).get_concrete_function(signature)
converter = tf.lite.TFLiteConverter.from_concrete_functions([func], model_2_blocks)
tflite_model_2_blocks = converter.convert()

# Save the model.
with open("models/" + name + "/" + "steerable-nafx-2_blocks-dynamic.tflite", "wb") as tflite_file:
    tflite_file.write(tflite_model_2_blocks)

In [None]:
# Export to ONNX with the same signature as the TFLite export:
# batch fixed at 1, any number of samples, input tensor named "x".
input_shape = [1, None, 1]
input_signature = [tf.TensorSpec(input_shape, tf.float32, name='x')]

onnx_model_2_blocks, _ = tf2onnx.convert.from_keras(model_2_blocks, input_signature=input_signature, opset=18)
onnx.save(proto=onnx_model_2_blocks, f="models/" + name + "/" + "steerable-nafx-2_blocks-tflite-dynamic.onnx")