# FIR low-pass filter design on a quaternion-weighted graph

This example performs a task similar to the `lms.ipynb` notebook. Here
we reconstruct a quaternion-valued graph signal from a noisy version of it
by means of a low-pass quaternion FIR filter designed via Quaternion Least
Mean Squares (QLMS).

In [None]:
# If gspx is not installed, add the repository root (the parent of the
# notebook's working directory) to the import path so the local copy is used.
# An installed gspx is left untouched (the original unconditional insert at
# position 0 shadowed it), and re-running this cell never adds the same path
# twice.
import importlib.util
import os
import sys

if importlib.util.find_spec("gspx") is None:
    gdir = os.path.dirname(os.getcwd())  # parent folder
    if gdir not in sys.path:
        sys.path.insert(0, gdir)

In [None]:
import numpy as np

from gspx.utils.display import plot_graph
from gspx.datasets import WeatherGraphData, uk_weather
from gspx.signals import QuaternionSignal
from gspx.qgsp import create_quaternion_weights, QGFT, QMatrix

## Quaternion-weighted graph

In [None]:
# Load the UK weather dataset. `uk_data.graph` unpacks into `Ar` (presumably
# the real-valued adjacency matrix; it is the base for the quaternion weights
# built below) and the node coordinates used for plotting.
uk_data = WeatherGraphData()
Ar, coords = uk_data.graph
# NOTE: `s` is overwritten later by the vertex-domain heat-kernel signal.
s = uk_data.signal

In [None]:
df = uk_weather()

# Weather columns feeding the i, j and k components of the edge weights
# (presumably one imaginary axis each, judging by the keyword names).
column_map = dict(
    icols=['humidity'],
    jcols=['temp'],
    kcols=['wind_speed'],
)
Aq = create_quaternion_weights(Ar, df, gauss_den=0.5, **column_map)

In [None]:
# Draw the graph using the magnitude of every quaternion edge weight.
edge_magnitudes = Aq.abs()
plot_graph(
    edge_magnitudes,
    coords=coords,
    figsize=(4, 8),
    colormap='viridis',
    node_size=40,
)

In [None]:
# Quaternion graph Fourier transform fitted to the quaternion adjacency `Aq`.
# Later cells read `eigc`, `eigq`, `idx_freq`, `tv_` and `inverse_transform`
# from this fitted object.
qgft = QGFT()
qgft.fit(Aq)

## Creating a quaternionic heat kernel (smooth signal)

In [None]:
# Heat kernel in all 4 quaternion dimensions: the coefficient of the r-th
# frequency in qgft.idx_freq is exp(-k * r), and this same real profile is
# copied into each of the four rectangular components.
k = 0.2
n_freq = len(qgft.idx_freq)
decay = np.exp(-k * np.arange(n_freq))

heat = np.zeros(n_freq)
heat[qgft.idx_freq] = decay

ss = QuaternionSignal.from_rectangular(np.tile(heat[:, np.newaxis], (1, 4)))

In [None]:
# Reproducible additive noise: one uniform draw in
# [-err_amplitude, err_amplitude) per sample of `ss`, turned into a quaternion
# signal by from_equal_dimensions (presumably the same value in every
# component; confirm against gspx).
rnd = np.random.default_rng(seed=42)
err_amplitude = 0.15

noise_samples = rnd.uniform(-err_amplitude, err_amplitude, size=len(ss))
nn = QuaternionSignal.from_equal_dimensions(noise_samples)

Spectrum of the original smooth signal:

In [None]:
# Plot the clean spectrum `ss`, coefficients ordered by qgft.idx_freq.
QuaternionSignal.show(ss, ordering=qgft.idx_freq)

Noisy version of the graph signal (in the frequency domain):

In [None]:
# Plot the noisy spectrum (clean spectrum plus the additive noise `nn`), using
# the same qgft.idx_freq ordering as the clean plot for comparison.
QuaternionSignal.show(ss + nn, ordering=qgft.idx_freq)

Original signal in the vertex domain:

In [None]:
# Bring the clean spectrum back to the vertex domain, then colour each node by
# its quaternion sample converted to an RGBA tuple.
s = qgft.inverse_transform(ss)

vertex_signal = QuaternionSignal.from_samples(s.matrix.ravel())
node_color = list(map(tuple, vertex_signal.to_rgba()))

plot_graph(
    Aq.abs(),
    coords=coords,
    colors=node_color,
    figsize=(4, 8),
    colormap='viridis',
    node_size=40,
)

### Total variation of eigenvectors for each eigenvalue

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Each complex eigenvalue becomes a point in the plane, coloured by the total
# variation of its eigenvector. The Axes is taken from plt.gca(), exactly as
# the pyplot-level plt.scatter call would use.
eigvals = qgft.eigc
ax = plt.gca()
points = ax.scatter(np.real(eigvals), np.imag(eigvals), c=qgft.tv_)
plt.colorbar(points, ax=ax)
ax.set_title("Total Variation of eigenvectors for each eigenvalue")
ax.set_xlabel("Real(eigvals)")
ax.set_ylabel("Imag(eigvals)")
plt.show()

### Ideal low-pass filter frequency response

In [None]:
# Ideal low-pass response: 1 on the first `bandwidth` frequencies in
# qgft.idx_freq order, 0 elsewhere. Only the real component is non-zero.
n_freq = len(qgft.idx_freq)

# Bandwidth of 20% of the frequency support
bandwidth = n_freq // 5

h_ideal = np.zeros(n_freq)
h_ideal[qgft.idx_freq[:bandwidth]] = 1

# Rectangular components (real, i, j, k); i, j and k stay zero.
components = np.zeros((n_freq, 4))
components[:, 0] = h_ideal

h_idealq = QuaternionSignal.from_rectangular(components)
QuaternionSignal.show(h_idealq, ordering=qgft.idx_freq)

## Low-pass filter design via QLMS

In [None]:
# NOTE(review): QMatrix is already imported in the imports cell; this
# re-import is redundant but harmless.
from gspx.qgsp import QMatrix
deg = 7  # passed as the Vandermonde size; presumably the number of FIR taps (confirm QMatrix.vander semantics)

# Regression problem for QLMS: X holds increasing powers of the quaternion
# eigenvalues, y is the ideal frequency response the filter should match.
X = QMatrix.vander(qgft.eigq, deg, increasing=True)
y = h_idealq
print(X.shape, y.shape)

In [None]:
from gspx.adaptive import QLMS

# Candidate learning rates handed to QLMS (results are later looked up in
# qlms.res_ by qlms.best_lr_).
learning_rates = [0.2, 0.3, 0.35, 0.4]

qlms = QLMS(alpha=learning_rates)
qlms.fit(X, y)
qlms.plot(nsamples=100)

### Quaternion-valued filter taps

In [None]:
# Display the fitted taps stored under qlms.best_lr_ (presumably the
# best-performing learning rate found by QLMS).
qlms.res_[qlms.best_lr_]['result']

### FIR filter response

In [None]:
# Frequency response of the learned FIR filter: evaluate it on the
# Vandermonde matrix and flatten the result into a QuaternionSignal.
h_opt = QuaternionSignal.from_samples(qlms.predict(X).matrix.ravel())
QuaternionSignal.show(h_opt, ordering=qgft.idx_freq)

### Signal reconstruction using both the ideal and the FIR low-pass filters

In [None]:
print("Ideal LPF filter.")

# Noisy spectrum, built once and reused for both the baseline and the filter.
noisy_spec = ss + nn

# Baseline: MSE (x1000) between the clean vertex signal and the unfiltered
# noisy reconstruction.
sn = qgft.inverse_transform(noisy_spec)
mse_prior = 1000 * np.mean((s - sn).abs() ** 2)
print("MSE prior (* 1000):", mse_prior)

# Apply the ideal low-pass response element-wise, then go back to vertices.
ssn_lpf = noisy_spec.hadamard(h_idealq)
s_lpf = qgft.inverse_transform(ssn_lpf)
mse_post = 1000 * np.mean((s - s_lpf).abs() ** 2)
print("MSE post (* 1000):", mse_post)

QuaternionSignal.show(ssn_lpf, ordering=qgft.idx_freq)

In [None]:
print("FIR LPF filter.")

# Noisy spectrum, built once and reused for both the baseline and the filter.
noisy_spec = ss + nn

# Baseline: MSE (x1000) between the clean vertex signal and the unfiltered
# noisy reconstruction.
sn = qgft.inverse_transform(noisy_spec)
mse_prior = 1000 * np.mean((s - sn).abs() ** 2)
print("MSE prior (* 1000):", mse_prior)

# Apply the QLMS-designed FIR response element-wise, then go back to vertices.
ssn_lpf = noisy_spec.hadamard(h_opt)
s_lpf = qgft.inverse_transform(ssn_lpf)
mse_post = 1000 * np.mean((s - s_lpf).abs() ** 2)
print("MSE post (* 1000):", mse_post)

QuaternionSignal.show(ssn_lpf, ordering=qgft.idx_freq)