In [None]:
import numpy as np
from plotly import graph_objs as go

import torch
import torch.nn as nn
import torch.optim as optim

import dsplib
import plotlib

from models import LutWithMemory
from training import Trainer

In [None]:
import importlib

# Re-execute the project helper modules so edits to dsplib / plotlib are picked
# up without restarting the kernel. Reloading in a loop (rather than ending the
# cell with a bare reload call) keeps the cell from echoing a module repr.
# NOTE: names bound via `from models import ...` / `from training import ...`
# are NOT refreshed here. Re-running those import lines alone is not enough,
# because the modules stay cached in sys.modules. Reload the module and re-bind
# the names, or use %autoreload, if models.py / training.py change.
for _mod in (dsplib, plotlib):
    importlib.reload(_mod)
del _mod

In [None]:
# Sampling configuration: a 1-second test signal at a 48 kHz sample rate.
signal_duration_sec = 1
fs_hz = 48_000.0
# Sample count derived by the dsplib helper from duration and rate.
smp_num = dsplib.calc_smp_num(signal_duration_sec, fs_hz)

In [None]:
# Input range handed to the LUT model as `input_range=(min_x, max_x)`.
# NOTE(review): the test tone is generated with mag=0.5, so (if mag is the peak
# amplitude) only a small part of this [-5, 5] range ever sees data. Confirm the
# range and bin count are intentional.
min_x, max_x = -5, 5

@dsplib.resample(factor=32)
def tanh_resample(x):
    """Element-wise tanh, wrapped by ``dsplib.resample(factor=32)``.

    The body only applies ``np.tanh``. Any rate change around it comes from the
    decorator. Presumably the decorator evaluates the nonlinearity at 32x the
    input rate to limit aliasing; confirm against dsplib.
    """
    y = np.tanh(x)
    return y

In [None]:
# Test stimulus and reference target.
# S comes from dsplib.generate_delayed_sin_matrix. Judging by its name, it
# presumably holds a noisy sine tone and its delayed copies; confirm the row
# layout against dsplib. Row 0 is the signal fed to tanh_resample to build the
# reference output.
tone_freq_hz = 9100
tone_mag = 0.5
history_smp_num = 3
noise_level = 1e-4

# Seed NumPy's global RNG so the added noise is reproducible across runs.
# NOTE(review): only effective if dsplib draws from the global NumPy RNG; confirm.
np.random.seed(0)

S = dsplib.generate_delayed_sin_matrix(
    smp_num=smp_num,
    tone_freq_n=tone_freq_hz / fs_hz,  # tone frequency normalized by the sample rate
    mag=tone_mag,
    history_smp_num=history_smp_num,
    noise_level=noise_level,
)

ref_sig = tanh_resample(S[0, :])

In [None]:
# Fit a LUT-with-memory model that maps the stimulus matrix S to the
# tanh_resample reference, then plot the training loss curve.

# Seed torch so any random parameter init / stochastic ops are reproducible.
torch.manual_seed(0)

# Initialize model, loss, and optimizer.
# NOTE(review): presumably has to agree with the history_smp_num used to build S
# (both are 3); confirm and keep them in sync.
memory_depth = 3
model = LutWithMemory(input_range=(min_x, max_x), bins_num=64, memory_depth=memory_depth)
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Prepare data. astype() already returns a new array, so no extra copy is needed.
X = torch.from_numpy(S.astype(np.float32))
y = torch.from_numpy(ref_sig.astype(np.float32))

# Train.
num_epochs = 20_000
trainer = Trainer(model, optimizer=optimizer, criterion=nn.MSELoss())
trainer.train(X, y, num_epochs=num_epochs)

fig = go.Figure()
fig.add_trace(go.Scatter(y=np.log10(trainer.loss_history), name="train loss"))
fig.update_layout(
    title="Training loss",
    # assumes one loss_history entry per epoch; confirm against Trainer
    xaxis_title="Epoch",
    yaxis_title="log10(MSE loss)",
)
fig.show()


In [None]:
# Compare the spectra of three tanh outputs:
# - ref: the reference built through the dsplib.resample-wrapped tanh
# - vanilla: tanh applied directly to row 0 of S
# - model: the trained LUT model's output
sig = S[0, :]

# Inference only: no_grad skips building an autograd graph for this forward pass.
with torch.no_grad():
    model_out = model(X).squeeze().numpy()

# Every trace is plotted against its own frequency axis, and no trace is
# trimmed. This keeps bins aligned even if the signals differ in length.
f, S_ref = dsplib.calc_spectrum(ref_sig, fs=fs_hz)
f_out, S_out = dsplib.calc_spectrum(np.tanh(sig), fs=fs_hz)
f_model, S_model = dsplib.calc_spectrum(model_out, fs=fs_hz)

fig = go.Figure()
fig.add_trace(go.Scatter(x=f, y=S_ref, name="ref"))
fig.add_trace(go.Scatter(x=f_out, y=S_out, line=dict(width=2, dash='dash'), name="vanilla"))
fig.add_trace(go.Scatter(x=f_model, y=S_model, line=dict(width=2, dash='dashdot'), name="model"))

fig.update_layout(
    title="Spectrum after tanh: reference vs. vanilla vs. model",
    xaxis_title='Frequency, Hz',
    yaxis_title='Spectral density'
)

fig.show()