In [None]:
from pathlib import Path
from typing import Iterable, List

import torch as th
import torch.nn as nn
from filtering import GFRFTFilterLayer, Real
from torch_gfrft import EigvalSortStrategy
from torch_gfrft.gfrft import GFRFT
from torch_gfrft.gft import GFT
from utils import (
    add_gaussian_noise,
    complex_round,
    init_knn_from_mat,
    mse_loss,
    seed_everything,
    snr,
)

In [None]:
# --- Reproducibility and hardware --------------------------------------------
SEED = 0  # handed to seed_everything() before the dataset is loaded
DEVICE = th.device("cuda") if th.cuda.is_available() else th.device("cpu")

# --- Transform settings --------------------------------------------------------
ORIGINAL_ORDER = 0.35  # fractional order used to synthesise the band-limited data
EIGVAL_SORT_STRATEGY = EigvalSortStrategy.TOTAL_VARIATION  # passed to GFT()

# --- k-NN graph construction (arguments of init_knn_from_mat) -----------------
KNN_COUNT = 10
KNN_SIGMA = None
MAX_NODE_COUNT = 100
MAX_TIME_LENGTH = 120
GRAPH_VERBOSE = True

# --- Settings not referenced by the cells visible in this notebook ------------
# NOTE(review): kept for compatibility; confirm whether anything still reads them.
NODE_DIM = 0
NUM_NODES = 100
TIME_LENGTH = 200
LEARNING_RATE = 5e-4
EPOCHS = 2000
SYMMETRIC = False
SELF_LOOPS = False

In [None]:
# Seed the RNGs through the project helper before any data is loaded.
seed_everything(SEED)

# The datasets live next to the notebook directory: <cwd>/../data/tv-graph-datasets.
datasets_path = (Path.cwd().parent / "data" / "tv-graph-datasets").absolute()
curr_dataset_path = datasets_path / "sea-surface-temperature.mat"

# Build the k-NN graph, its adjacency matrix and the time-varying graph signal.
knn_graph_options = {
    "knn_count": KNN_COUNT,
    "knn_sigma": KNN_SIGMA,
    "max_node_count": MAX_NODE_COUNT,
    "max_time_length": MAX_TIME_LENGTH,
    "device": DEVICE,
    "verbose": GRAPH_VERBOSE,
}
graph, adjacency, jtv_signal = init_knn_from_mat(curr_dataset_path, **knn_graph_options)

# The fractional transform is constructed from the GFT matrix of the adjacency.
gft = GFT(adjacency, EIGVAL_SORT_STRATEGY)
gfrft = GFRFT(gft.gft_mtx)

In [None]:
# Shared plot limits: the signal's global minimum rounded down and its global
# maximum rounded up, so every snapshot below is drawn with the same limits.
lower_limit = jtv_signal.min().floor().item()
upper_limit = jtv_signal.max().ceil().item()
limits = [lower_limit, upper_limit]

# Plot the graph signal at indices 0, 5, 10 and 15 of the last axis.
for t in range(0, 20, 5):
    gsignal = jtv_signal[..., t]
    graph.plot_signal(gsignal.cpu().numpy(), limits=limits)

In [None]:
def generate_bandlimited_jtv_signal(
    gfrft: GFRFT,
    signal: th.Tensor,
    order: float,
    stopband: slice,
) -> th.Tensor:
    """Return ``signal`` with a band of its fractional-domain coefficients removed.

    The signal is transformed along dim 0 with ``gfrft.gfrft`` at ``order``,
    the entries selected by ``stopband`` along dim 0 are set to zero, and the
    result is transformed back with ``gfrft.igfrft`` at the same order.

    Args:
        gfrft: Transform object providing ``gfrft`` and ``igfrft``.
        signal: Input tensor; dim 0 is the axis that is transformed.
        order: Fractional order forwarded to both transforms.
        stopband: Slice over dim 0 of the transformed signal to zero out.
            An omitted ``start`` is treated as 0 and an omitted ``stop`` as
            ``signal.size(0)`` when the bounds are checked.

    Returns:
        The inverse transform of the coefficients with the stopband zeroed.

    Raises:
        ValueError: If the stopband starts below 0 or stops beyond
            ``signal.size(0)``.
    """
    size = signal.size(0)
    # Resolve open-ended slices first so that e.g. slice(90, None) is checked
    # instead of failing with a TypeError on the None comparison.
    start = 0 if stopband.start is None else stopband.start
    stop = size if stopband.stop is None else stopband.stop
    if start < 0 or stop > size:
        raise ValueError(
            f"stopband must lie within [0, {size}], got start={start}, stop={stop}"
        )
    transformed = gfrft.gfrft(signal, order, dim=0)
    # In-place write into the transform output.
    # NOTE(review): assumes gfrft.gfrft returns a fresh tensor, not a view of `signal`.
    transformed[stopband, ...] = 0
    return gfrft.igfrft(transformed, order, dim=0)


def generate_bandlimited_noise(
    gfrft: GFRFT,
    signal: th.Tensor,
    order: float,
    stopband: slice,
    mean: float = 0.0,
    sigma: float = 1.0,
    seed: int = 0,
) -> th.Tensor:
    """Generate noise whose fractional-domain coefficients live only in ``stopband``.

    A zero tensor shaped like ``signal`` is filled with ``mean + sigma * N(0, 1)``
    draws on the entries selected by ``stopband`` along dim 0, and the result is
    passed through ``gfrft.igfrft`` at ``order``.

    The draws come from a private ``torch.Generator`` seeded with ``seed``, so
    repeated calls give identical noise without touching the global RNG state
    (the previous implementation called ``th.manual_seed(0)`` on every call,
    silently resetting whatever seed the caller had set).

    Args:
        gfrft: Transform object providing ``igfrft``.
        signal: Template tensor; only its shape, dtype and device are used.
        order: Fractional order forwarded to ``igfrft``.
        stopband: Slice over dim 0 selecting the coefficients that receive noise.
        mean: Offset added to every draw.
        sigma: Scale applied to every draw.
        seed: Seed of the private generator; the default of 0 mirrors the seed
            that used to be hard-coded.

    Returns:
        The inverse transform of the coefficient tensor described above.
    """
    noise = th.zeros_like(signal)
    band = noise[stopband, ...]
    # Local generator on the same device as the data: deterministic output,
    # no side effect on the process-wide RNG.
    rng = th.Generator(device=band.device)
    rng.manual_seed(seed)
    draws = th.randn(band.shape, generator=rng, dtype=band.dtype, device=band.device)
    noise[stopband, ...] = mean + sigma * draws
    return gfrft.igfrft(noise, order, dim=0)


def generate_bandlimited_experiment_data(
    gfrft: GFRFT,
    signal: th.Tensor,
    order: float,
    stopband_count: int,
    mean: float = 0.0,
    sigma: float = 1.0,
) -> tuple[th.Tensor, th.Tensor]:
    """Build a band-limited signal and stopband-only noise for one experiment.

    The stopband is the last ``stopband_count`` indices along dim 0. The same
    slice is handed to ``generate_bandlimited_jtv_signal`` (which removes that
    band from ``signal``) and to ``generate_bandlimited_noise`` (which places
    noise in that band).

    Args:
        gfrft: Transform object forwarded to both generators.
        signal: Source tensor; dim 0 is the transformed axis.
        order: Fractional order forwarded to both generators.
        stopband_count: Number of trailing indices along dim 0 in the stopband.
        mean: Noise offset forwarded to the noise generator.
        sigma: Noise scale forwarded to the noise generator.

    Returns:
        ``(band-limited signal, band-limited noise)``.
    """
    num_rows = signal.size(0)
    tail_band = slice(num_rows - stopband_count, num_rows)
    clean = generate_bandlimited_jtv_signal(
        gfrft=gfrft, signal=signal, order=order, stopband=tail_band
    )
    noise = generate_bandlimited_noise(
        gfrft=gfrft,
        signal=signal,
        order=order,
        stopband=tail_band,
        mean=mean,
        sigma=sigma,
    )
    return clean, noise

In [None]:
# def get_passband_from_stopband(size: int, stopband: slice, slice_stop: int | None = None) -> List[int]:
#     start = stopband.start if stopband.start is not None else 0
#     stop = stopband.stop if stopband.stop is not None else slice_stop
#     if stop is None:
#         raise ValueError("slice stop must be specified")
#     step = stopband.step if stopband.step is not None else 1
#     stop_values = set(range(start, stop, step))
#     complement_idx = list(set(range(size)).difference(stop_values))
#     return complement_idx


# arr = th.randn(20)
# stopband = slice(0, 10)
# passband = get_passband_from_stopband(arr.size(0), stopband)
# arr[passband], arr[stopband] = 1, 0
# arr

In [None]:
# Sanity check on the first two time steps only: generate a band-limited signal
# and band-limited noise for the last 10 indices along dim 0, then inspect the
# noise in the fractional domain (rounded for readability).
size = jtv_signal.size(0)
stopband = slice(size - 10, size)
first_two_steps = jtv_signal[..., 0:2]

bandlimited_jtv_signal = generate_bandlimited_jtv_signal(
    gfrft=gfrft, signal=first_two_steps, order=ORIGINAL_ORDER, stopband=stopband
)
bandlimited_noise = generate_bandlimited_noise(
    gfrft=gfrft, signal=first_two_steps, order=ORIGINAL_ORDER, stopband=stopband
)

noise_spectrum = gfrft.gfrft(bandlimited_noise, ORIGINAL_ORDER, dim=0)
# Alternative view, the spectrum of the band-limited signal:
# complex_round(gfrft.gfrft(bandlimited_jtv_signal, ORIGINAL_ORDER, dim=0), decimals=7)
complex_round(noise_spectrum, decimals=7)

In [None]:
# Synthesise the experiment data: a band-limited version of the full signal and
# noise (sigma 10) confined to the last STOPBAND_COUNT indices along dim 0.
STOPBAND_COUNT = 10
bl_signal, bl_noise = generate_bandlimited_experiment_data(
    gfrft, jtv_signal, ORIGINAL_ORDER, STOPBAND_COUNT, mean=0.0, sigma=10.0
)
bl_noisy_signal = bl_signal + bl_noise

# Quality of the noisy input before any filtering.
input_snr = snr(bl_signal, bl_noise)
input_mse = mse_loss(bl_signal, bl_noisy_signal)
print(f"SNR: {input_snr:.2f}")
print(f"MSE: {input_mse:.2f}")

In [None]:
# Compare the clean and the noisy band-limited signal at one index of the last
# axis; only the real part is plotted, with the limits computed earlier.
index = 1
clean_snapshot = bl_signal[..., index].real.cpu().numpy()
noisy_snapshot = bl_noisy_signal[..., index].real.cpu().numpy()

graph.plot_signal(clean_snapshot, limits=limits)
# graph.plot_signal(bl_noise[..., index].real.cpu().numpy(), limits=limits)
graph.plot_signal(noisy_snapshot, limits=limits)

In [None]:
def experiment(
    gfrft: GFRFT,
    jtv_signal: th.Tensor,
    jtv_noise: th.Tensor,
    initial_orders: List[float],
    cutoff_counts: List[int],
    *,
    lr: float = 5e-4,
    epochs: int = 1000,
    display_epochs: Iterable[int] | None = None,
    seed: int = 0,
    trainable_transform: bool = True,
    trainable_filter: bool = False,
) -> nn.Module:
    """Fit a stack of GFRFT filter layers to denoise ``jtv_signal + jtv_noise``.

    One ``GFRFTFilterLayer`` followed by a ``Real`` module is created per
    ``(initial_order, cutoff_count)`` pair and the stack is trained with Adam
    to minimise ``mse_loss`` between the model output and ``jtv_signal``.
    The model, the learning rate and the selected epoch losses are printed.

    Args:
        gfrft: Transform object handed to every filter layer.
        jtv_signal: Clean target signal.
        jtv_noise: Noise added to ``jtv_signal`` to form the model input.
        initial_orders: Initial order of each filter layer.
        cutoff_counts: Cutoff count of each filter layer; must have the same
            length as ``initial_orders``.
        lr: Adam learning rate.
        epochs: Number of optimisation steps.
        display_epochs: Epochs whose loss is printed. Defaults to every 100th
            epoch plus the final epoch.
        seed: Seed passed to ``seed_everything`` before the layers are built.
        trainable_transform: Forwarded to every ``GFRFTFilterLayer``.
        trainable_filter: Forwarded to every ``GFRFTFilterLayer``.

    Returns:
        The ``nn.Sequential`` model (trained unless both flags are false).

    Raises:
        ValueError: If ``initial_orders`` and ``cutoff_counts`` differ in length.
    """
    if len(initial_orders) != len(cutoff_counts):
        raise ValueError("initial_orders and cutoff_counts must have the same length")
    if display_epochs is None:
        # Always include the last epoch so the final loss is reported; the loop
        # below runs over 1..epochs, which range(0, epochs, 100) never reaches.
        display_epochs = set(range(0, epochs, 100)) | {epochs}
    display_epochs = set(display_epochs)

    seed_everything(seed)
    filters = [
        GFRFTFilterLayer(
            gfrft,
            cutoff,
            order,
            trainable_transform=trainable_transform,
            trainable_filter=trainable_filter,
        )
        for order, cutoff in zip(initial_orders, cutoff_counts)
    ]
    # Interleave each filter with its own Real() instance instead of aliasing
    # one shared module object across all positions.
    layers = []
    for filter_layer in filters:
        layers.extend((filter_layer, Real()))
    model = nn.Sequential(*layers)
    print(model)
    print(f"learning rate: {lr}")
    noisy_signal = jtv_signal + jtv_noise

    initial_loss = mse_loss(noisy_signal, jtv_signal)
    print(f"Epoch {0:4d} | Loss {initial_loss.item(): >8.4f}")

    if not (trainable_transform or trainable_filter):
        # Nothing to optimise: evaluate once (as the first epoch) and return.
        # No optimizer is built because Adam rejects an empty parameter list.
        # NOTE(review): whether a frozen GFRFTFilterLayer exposes parameters is
        # not visible here — confirm.
        if epochs >= 1:
            with th.no_grad():
                loss = mse_loss(model(noisy_signal), jtv_signal)
            if 1 in display_epochs:
                print(f"Epoch {1:4d} | Loss {loss.item(): >8.4f}")
        return model

    optim = th.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(1, 1 + epochs):
        optim.zero_grad()
        loss = mse_loss(model(noisy_signal), jtv_signal)
        # The printed value is the loss before this epoch's parameter update.
        if epoch in display_epochs:
            print(f"Epoch {epoch:4d} | Loss {loss.item(): >8.4f}")
        loss.backward()
        optim.step()
    return model


# Train one filter layer, starting from order 1.0, with both the transform and
# the filter trainable; the loss is reported every `report_every` epochs.
train_epochs = 2000
report_every = 200
model = experiment(
    gfrft,
    bl_signal,
    bl_noise,
    [1.0],
    [STOPBAND_COUNT],
    lr=5e-2,
    epochs=train_epochs,
    display_epochs=list(range(0, train_epochs + 1, report_every)),
    seed=42,
    trainable_transform=True,
    trainable_filter=True,
)
model

In [None]:
# Run the trained model on the noisy input; no gradients are needed here.
with th.no_grad():
    estimated_signal = model(bl_noisy_signal)

# Deviation from the clean signal before and after filtering.
residual_before = bl_noisy_signal - bl_signal
residual_after = estimated_signal - bl_signal

initial_mse = mse_loss(bl_signal, bl_noisy_signal)
initial_snr = snr(bl_signal, residual_before)
estimated_mse = mse_loss(bl_signal, estimated_signal)
estimated_snr = snr(bl_signal, residual_after)

print(f"Initial MSE: {initial_mse:.2f}, Estimated MSE: {estimated_mse:.2f}")
print(f"Initial SNR: {initial_snr:.2f}, Estimated SNR: {estimated_snr:.2f}")