# Trainable GFRFT: Transform Learning

In [1]:
import os
import random
from typing import Iterable

import numpy as np
import torch as th
import torch.nn as nn
from torch_gfrft import EigvalSortStrategy
from torch_gfrft.gfrft import GFRFT
from torch_gfrft.gft import GFT
from torch_gfrft.layer import GFRFTLayer

SEED = 0
NODE_DIM = 0  # axis of the signal matrix X that indexes graph nodes (0 or 1)
NUM_NODES = 100
TIME_LENGTH = 200  # samples per node signal (the other axis of X)
EIGVAL_SORT_STRATEGY = EigvalSortStrategy.TOTAL_VARIATION
SYMMETRIC = False  # symmetrize the random adjacency matrix
SELF_LOOPS = False  # keep (True) or zero out (False) the adjacency diagonal

# Use CUDA when it is available, otherwise fall back to the CPU.
DEVICE = th.device("cuda" if th.cuda.is_available() else "cpu")

## Seed for Reproducibility

In [2]:
def seed_everything(seed: int) -> None:
    """Seed every RNG used here and enable PyTorch's deterministic mode.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (plus all CUDA
    devices when CUDA is available), and turns on
    ``th.use_deterministic_algorithms``.

    Args:
        seed: Seed value applied to every generator.
    """
    cuda_available = th.cuda.is_available()
    if cuda_available:
        # cuBLAS reads this variable when its handle is first created, so set it
        # before anything can touch CUDA; keep a value the user already exported.
        os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")
    random.seed(seed)
    np.random.seed(seed)
    th.manual_seed(seed)
    th.use_deterministic_algorithms(True)
    if cuda_available:
        # Benchmark mode picks kernels by timing, which can vary run to run.
        th.backends.cudnn.benchmark = False
        th.cuda.manual_seed_all(seed)


seed_everything(SEED)

## Initialize Random Graph and Corresponding GFRFT

In [3]:
def generate_adjacency(
    num_nodes: int,
    device: th.device,
    symmetric: bool = False,
    self_loops: bool = False,
) -> th.Tensor:
    """Return a dense random weighted adjacency matrix.

    Entries are drawn uniformly from ``[0, 1)`` with ``th.rand``.

    Args:
        num_nodes: Number of nodes; the result has shape ``(num_nodes, num_nodes)``.
        device: Device on which the matrix is allocated.
        symmetric: If True, replace ``A`` by ``(A + A.T) / 2`` (undirected graph).
        self_loops: If False, the diagonal is set to zero.

    Returns:
        The adjacency matrix.
    """
    A = th.rand(num_nodes, num_nodes, device=device)
    if symmetric:
        A = 0.5 * (A + A.T)
    if not self_loops:
        # A is a freshly allocated tensor here, so zeroing in place is safe.
        A.fill_diagonal_(0.0)
    return A


# Backward-compatible alias for the original (misspelled) name.
generate_adjecency = generate_adjacency


# Build the random graph, its graph Fourier transform, and the fractional
# transform derived from the GFT matrix.
adjacency = generate_adjecency(
    NUM_NODES,
    DEVICE,
    symmetric=SYMMETRIC,
    self_loops=SELF_LOOPS,
)
gft = GFT(adjacency, EIGVAL_SORT_STRATEGY)
gfrft = GFRFT(gft.gft_mtx)

## Generate Random Signals

In [4]:
# Uniform random signals with the node axis placed at position NODE_DIM.
if NODE_DIM not in (0, 1):
    raise ValueError("NODE_DIM must be 0 or 1")
signal_shape = (NUM_NODES, TIME_LENGTH) if NODE_DIM == 0 else (TIME_LENGTH, NUM_NODES)
X = th.rand(*signal_shape, device=DEVICE)

## Experiment

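Transform the original signals with the GFRFT of fractional order `original_order`. Using the transformed signals as ground truth, train a network of stacked GFRFT layers to learn their fractional orders by minimizing `mse_loss`. Despite its name, `mse_loss` is the mean column-wise $\ell_2$ norm of the prediction error, not a mean squared error. Because successive GFRFTs compose additively in order, the learned orders should sum to `original_order`.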

In [5]:
def mse_loss(
    predictions: th.Tensor, targets: th.Tensor, dim: int = 0
) -> th.Tensor:
    """Return the mean Euclidean norm of the error, taken along ``dim``.

    Despite the name, this is not a mean *squared* error. The residual
    ``predictions - targets`` is reduced with an L2 norm along ``dim``, and the
    resulting norms are averaged. The default ``dim=0`` reproduces the
    previous fixed behaviour.

    Args:
        predictions: Model output.
        targets: Ground truth; must broadcast against ``predictions``.
        dim: Axis reduced by the norm.

    Returns:
        A 0-dimensional (scalar) tensor.
    """
    # ``th.norm`` is deprecated. ``linalg.vector_norm`` with ord=2 gives the same
    # per-slice Euclidean norm and uses |.| for complex residuals.
    return th.linalg.vector_norm(predictions - targets, ord=2, dim=dim).mean()


def get_order_info(
    model: nn.Sequential, show_sum: bool = False, sep: str = " | "
) -> str:
    """Format each layer's ``order`` as ``a<k> = <value>``, joined by ``sep``.

    Layers are numbered from 1 and values use the ``>7.4f`` format. When
    ``show_sum`` is True, ``sep`` followed by ``sum = <total>`` is appended.
    """
    values = [layer.order.item() for layer in model]
    pieces = [f"a{idx} = {val: >7.4f}" for idx, val in enumerate(values, start=1)]
    text = sep.join(pieces)
    if not show_sum:
        return text
    return text + f"{sep}sum = {sum(values): >7.4f}"


def experiment(
    gfrft: GFRFT,
    original_signals: th.Tensor,
    original_order: float,
    initial_orders: Iterable[float],
    *,
    dim: int = -1,
    lr: float = 5e-4,
    epochs: int = 1000,
    display_epochs: Iterable[int] | None = None,
    show_sum_during_training: bool = False,
):
    """Learn the orders of stacked GFRFT layers that reproduce a target GFRFT.

    The target is ``gfrft.gfrft(original_signals, original_order, dim=dim)``.
    A ``nn.Sequential`` with one ``GFRFTLayer`` per entry of ``initial_orders``
    is trained with Adam to minimize ``mse_loss`` against that target. The
    model, hyperparameters, training progress and final orders are printed.

    Args:
        gfrft: Transform used to build the target and inside every layer.
        original_signals: Input signals.
        original_order: Fractional order used to produce the target.
        initial_orders: Starting order of each layer (one layer per value).
        dim: Axis forwarded as ``dim`` to the transform and to every layer.
        lr: Adam learning rate.
        epochs: Number of optimization steps, numbered ``1..epochs``.
        display_epochs: Epochs at which loss and orders are printed. Values
            outside ``1..epochs`` are never reached. Defaults to every 100th
            epoch, up to and including ``epochs``.
        show_sum_during_training: Also print the sum of the orders during
            training.
    """
    if display_epochs is None:
        # Epochs are counted from 1, so epoch 0 would never be shown. Start at
        # 100 and include the final epoch.
        display_epochs = range(100, epochs + 1, 100)
    display_epochs = set(display_epochs)

    # The target is fixed ground truth; no autograd graph is needed for it.
    with th.no_grad():
        transformed_signals = gfrft.gfrft(original_signals, original_order, dim=dim)
    model = nn.Sequential(
        *[GFRFTLayer(gfrft, order, dim=dim) for order in initial_orders]
    )
    print(model)
    print(f"original order: {original_order:.4f}")
    print(f"learning rate: {lr}")
    optim = th.optim.Adam(model.parameters(), lr=lr)

    start_str = (
        f"initial orders: {get_order_info(model, show_sum=show_sum_during_training)}"
    )
    print(start_str)
    print("-" * len(start_str))
    for epoch in range(1, 1 + epochs):
        optim.zero_grad()
        loss = mse_loss(model(original_signals), transformed_signals)
        if epoch in display_epochs:
            # Orders shown are those used for this loss, i.e. before this step.
            info = get_order_info(model, show_sum=show_sum_during_training)
            print(f"Epoch {epoch:4d} | Loss {loss.item(): >8.4f} | {info}")
        loss.backward()
        optim.step()
    print("-" * len(start_str))
    print(f"final orders: {get_order_info(model)}")
    final_sum = sum(layer.order.item() for layer in model)
    print(f"original order: {original_order:.4f}, final order sum: {final_sum:.4f}")

experiment(
    gfrft=gfrft,
    original_signals=X,
    original_order=0.35,
    initial_orders=[0.0, 1.0],
    lr=5e-4,
    dim=NODE_DIM,
    epochs=2000,
    # Training epochs are numbered from 1, so an entry of 0 would never be
    # reported. The initial orders are already printed before training starts.
    display_epochs=[500, 1000, 1500, 2000],
    show_sum_during_training=True,
)

Sequential(
  (0): GFRFT(order=0.0, size=100, dim=0)
  (1): GFRFT(order=1.0, size=100, dim=0)
)
original order: 0.3500
learning rate: 0.0005
initial orders: a1 =  0.0000 | a2 =  1.0000 | sum =  1.0000
-----------------------------------------------------------
Epoch  500 | Loss   9.8519 | a1 = -0.1549 | a2 =  0.8451 | sum =  0.6901
Epoch 1000 | Loss   0.9936 | a1 = -0.3138 | a2 =  0.6862 | sum =  0.3724
Epoch 1500 | Loss   0.0093 | a1 = -0.3251 | a2 =  0.6749 | sum =  0.3498
Epoch 2000 | Loss   0.0098 | a1 = -0.3249 | a2 =  0.6751 | sum =  0.3502
-----------------------------------------------------------
final orders: a1 = -0.3250 | a2 =  0.6750
original order: 0.3500, final order sum: 0.3501
