## Training of a lossless FDN to improve colorlessness
Tune the parameters of a homogeneous, lossless FDN to reduce coloration.


Start by importing all the necessary packages and flamo modules.

In [None]:
import torch
import os
from collections import OrderedDict
import matplotlib.pyplot as plt

# from flamo 
from flamo.optimize.dataset import DatasetColorless, load_dataset
from flamo.optimize.trainer import Trainer
from flamo.processor import dsp, system
from flamo.optimize.loss import amse_loss, sparsity_loss

# Seed torch's global RNG so every run of this notebook starts from the same
# random state (presumably the source of flamo's initial parameter values).
SEED = 130798
torch.random.manual_seed(SEED)


### Construct the Feedback Delay Network
The FDN is created as an instance of the `flamo.system.Series` class, which cascades multiple DSP modules in series, similarly to `nn.Sequential`. This class serves as a container and ensures that all included modules share the same values for the `nfft` and `alias_decay_db` attributes.

Note that the FDN created here is lossless: its impulse response does not decay on its own, so `alias_decay_db` must be nonzero to reduce time aliasing in the finite-length FFT.

In [24]:
# --- FDN topology -----------------------------------------------------------
# Delay-line lengths in samples, one per FDN channel.
delay_lengths = torch.tensor((887, 911, 941, 1699, 1951, 2053))
N = delay_lengths.numel()  # number of delay lines (FDN channels)

# --- Frequency-domain / runtime settings ------------------------------------
fs = 48_000          # sample rate in Hz
nfft = 1 << 16       # number of FFT points (65536)
alias_decay_db = 30  # decay in dB of the anti time-aliasing envelope
device = 'cpu'       # 'cuda' or 'cpu'


In [None]:
# Settings shared by every DSP module of the FDN; system.Series requires all
# of its modules to agree on nfft and alias_decay_db.
_fdn_kwargs = dict(nfft=nfft, alias_decay_db=alias_decay_db, device=device)

# Trainable input gains, shape (N, 1).
input_gain = dsp.Gain(size=(N, 1), requires_grad=True, **_fdn_kwargs)

# Trainable output gains, shape (1, N).
output_gain = dsp.Gain(size=(1, N), requires_grad=True, **_fdn_kwargs)

# ---- Feedback loop ----
# Forward branch: N integer delay lines whose lengths are fixed (not trained).
delays = dsp.parallelDelay(
    size=(N,),
    max_len=delay_lengths.max(),
    isint=True,
    requires_grad=False,
    **_fdn_kwargs,
)
# delay_lengths are in samples; sample2s presumably converts them to the
# module's internal unit (seconds) before assignment.
delays.assign_value(delays.sample2s(delay_lengths))

# Feedback branch: trainable N x N matrix constrained to be orthogonal, which
# keeps the recursive loop lossless.
feedback = dsp.Matrix(
    size=(N, N),
    matrix_type="orthogonal",
    requires_grad=True,
    **_fdn_kwargs,
)

# Close the loop: delays forward, orthogonal matrix in the feedback path.
feedback_loop = system.Recursion(fF=delays, fB=feedback)

# Cascade input gains -> feedback loop -> output gains into one system.
FDN = system.Series(OrderedDict([
    ('input_gain', input_gain),
    ('feedback_loop', feedback_loop),
    ('output_gain', output_gain),
]))

flamo provides a `Shell` class in which the differentiable system, in this case `FDN`, is connected to the input and output layers.
- The input is an impulse in the time domain, so the input layer must transform it to the frequency domain.
- The target is the desired magnitude response, so the output layer is the absolute-value operation.

In [26]:
# Shell wiring: time-domain impulse --FFT--> FDN --|.|--> magnitude response,
# so the model output can be compared directly with a magnitude target.
input_layer = dsp.FFT(nfft)                        # time -> frequency domain
output_layer = dsp.Transform(transform=torch.abs)  # complex spectrum -> magnitude
model = system.Shell(core=FDN, input_layer=input_layer, output_layer=output_layer)

To speed up training, it is good practice to make sure that the energy of the system is comparable to that of the target.

In [27]:
# Rescale the input and output gains so that the mean energy of the initial
# magnitude response equals `target_energy`. The measurement needs no
# gradients, so it also runs under no_grad; otherwise an autograd graph would
# be built over the whole frequency response for nothing.
target_energy = 1
with torch.no_grad():
    H = model.get_freq_response(identity=False)
    energy_H = torch.mean(torch.pow(torch.abs(H), 2))
    # The gains sit outside the feedback loop, so scaling both by `a` scales H
    # by a**2 and its energy by a**4. Dividing each gain by
    # (energy_H / target_energy) ** (1/4) therefore divides the energy by
    # energy_H / target_energy.
    gain_scale = torch.pow(energy_H / target_energy, 1 / 4)
    core = model.get_core()
    core.input_gain.assign_value(torch.div(core.input_gain.param, gain_scale))
    core.output_gain.assign_value(torch.div(core.output_gain.param, gain_scale))
    model.set_core(core)

Log the impulse response and the magnitude response at initialization.

In [28]:
# Snapshot of the untrained model (impulse response, and magnitude response in
# dB), kept for comparison with the optimized FDN after training.
with torch.no_grad():
    ir_init = model.get_time_response(identity=False, fs=fs).squeeze()
    mag_init = 20 * model.get_freq_response(identity=False, fs=fs).squeeze().log10()

#### Set up training
Set the training parameter values and construct the dataset and trainer.

In [29]:
# ---- Training hyper-parameters ----
batch_size = 1
num = 256         # number of samples (passed to the dataset as `expand`)
max_epochs = 20   # maximum number of epochs
lr = 1e-3         # learning rate
step_size = 5     # step size for the learning rate scheduler (forwarded to Trainer)
train_dir = 'output/ex_fdn'
# create the output directory
os.makedirs(train_dir, exist_ok=True)

# Dataset: impulse input and flat-spectrum target, both with nfft // 2 + 1
# frequency bins to match the one-sided FFT of the input layer.
dataset = DatasetColorless(
    input_shape=(1, nfft // 2 + 1, 1),   # impulse
    target_shape=(1, nfft // 2 + 1, 1),  # flat spectrum as target
    expand=num,
    device=device,
)
train_loader, valid_loader = load_dataset(dataset, batch_size=batch_size)

# Initialize the training process. `step_size` must be forwarded explicitly;
# otherwise the learning-rate scheduler falls back to Trainer's own default.
trainer = Trainer(
    model,
    max_epochs=max_epochs,
    lr=lr,
    step_size=step_size,
    train_dir=train_dir,
    device=device,
)

# Register the loss functions with their relative weights. sparsity_loss is
# given the model itself (requires_model=True) rather than only its output.
trainer.register_criterion(amse_loss(), 1)
trainer.register_criterion(sparsity_loss(), 1, requires_model=True)


#### Train the model! 
For each epoch, the trainer runs both training and validation.

In [None]:
# Run the optimization: each epoch performs a training pass over train_loader
# followed by a validation pass over valid_loader. The model is updated in
# place, so `model` below refers to the optimized FDN.
trainer.train(train_loader, valid_loader)

In [None]:

# Get the optimized impulse response and magnitude response (in dB).
with torch.no_grad():
    ir_optim = model.get_time_response(identity=False, fs=fs).squeeze()
    mag_optim = model.get_freq_response(identity=False, fs=fs).squeeze()
    mag_optim = 20 * torch.log10(mag_optim)

# Sample k of the impulse response lies at k / fs seconds. linspace(0, nfft/fs,
# nfft) would stretch the axis by nfft / (nfft - 1). rfftfreq gives the exact
# frequency of each of the nfft // 2 + 1 one-sided FFT bins.
time_axis = torch.arange(ir_init.shape[-1]) / fs
freq_axis = torch.fft.rfftfreq(nfft, d=1 / fs)

plt.figure(figsize=(12, 6))

# Impulse response: first 0.5 s
plt.subplot(2, 1, 1)
plt.plot(time_axis, ir_init.numpy(), label='Initial')
plt.plot(time_axis, ir_optim.numpy(), label='Optimized', alpha=0.7)
plt.xlim(0, 0.5)
plt.legend()
plt.title('Impulse Response')
plt.xlabel('Time (s)')  # time_axis is in seconds, not samples
plt.ylabel('Amplitude')

# Magnitude response in dB, zoomed to 100-500 Hz
plt.subplot(2, 1, 2)
plt.plot(freq_axis, mag_init.numpy(), label='Initial')
plt.plot(freq_axis, mag_optim.numpy(), label='Optimized', alpha=0.7)
plt.xlim(100, 500)
plt.legend()
plt.title('Magnitude Response')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Magnitude (dB)')

plt.tight_layout()
plt.show()

In [None]:
# Import `display` explicitly instead of relying on IPython injecting it into
# the notebook namespace.
from IPython.display import Audio, display

# NOTE: Audio() normalizes each clip to its own peak by default, so the two
# players are not level-matched; compare coloration, not loudness.

# Play the initial impulse response
print("Initial Impulse Response:")
display(Audio(ir_init.numpy(), rate=fs))

# Play the optimized impulse response
print("Optimized Impulse Response:")
display(Audio(ir_optim.numpy(), rate=fs))