# Trainable GFRFT: Transform Learning

In [1]:
import torch as th
from torch_gfrft import EigvalSortStrategy
from torch_gfrft.gfrft import GFRFT
from torch_gfrft.gft import GFT
from transform_learning import experiment, generate_adjecency
from utils import seed_everything

In [2]:
# Reproducibility and problem size.
SEED = 0
NUM_NODES = 100  # number of graph vertices
TIME_LENGTH = 200  # number of time samples per vertex

# Graph / GFT construction settings.
EIGVAL_SORT_STRATEGY = EigvalSortStrategy.TOTAL_VARIATION
SYMMETRIC = False  # forwarded to generate_adjecency (presumably False => directed graph; confirm)
SELF_LOOPS = False  # forwarded to generate_adjecency
DEVICE = th.device("cuda" if th.cuda.is_available() else "cpu")

# Transform-learning experiment settings.
ORIGINAL_ORDER = 0.35  # fractional order used to produce the ground-truth signals
INITIAL_ORDERS = [0.0, 1.0]  # initial order of each trainable GFRFT layer (one layer per entry)
LEARNING_RATE = 1e-3
EPOCHS = 2000
DISPLAY_INTERVAL = 500  # report training progress every this many epochs
# Derived from EPOCHS (instead of hard-coded) so the reporting schedule stays in
# sync when EPOCHS changes; the first and the final epoch are always reported.
DISPLAY_EPOCHS = sorted({1, *range(DISPLAY_INTERVAL, EPOCHS + 1, DISPLAY_INTERVAL), EPOCHS})
SHOW_SUM_DURING_TRAINING = True

## Initialize a Random Graph, a Random Joint Time-Vertex (JTV) Signal, and the GFRFT

In [3]:
# Seed before drawing any randomness: both the signal and the random graph below
# depend on this call and on the order of the statements that follow it.
seed_everything(SEED)
# Random signal of shape (NUM_NODES, TIME_LENGTH): one row per vertex, one column per time step.
jtv_signal = th.rand(NUM_NODES, TIME_LENGTH, device=DEVICE)
# Random graph adjacency; SYMMETRIC / SELF_LOOPS are forwarded to generate_adjecency.
adjacency = generate_adjecency(NUM_NODES, DEVICE, SYMMETRIC, SELF_LOOPS)
# Graph Fourier transform of the graph, with eigenvalue ordering chosen by EIGVAL_SORT_STRATEGY.
gft = GFT(adjacency, EIGVAL_SORT_STRATEGY)
# Graph fractional Fourier transform built from the GFT matrix.
gfrft = GFRFT(gft.gft_mtx)

## Experiment

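Transform the original signals with the GFRFT of fractional order `ORIGINAL_ORDER`. Using the transformed signals as the ground truth and the MSE loss, learn the fractional orders of a network of stacked GFRFT layers (one layer per entry of `INITIAL_ORDERS`). Because GFRFTs of orders $a_1$ and $a_2$ compose into a GFRFT of order $a_1 + a_2$, only the sum of the learned orders is determined by the data. The individual orders can therefore end up anywhere, while $a_1 + a_2$ should converge to `ORIGINAL_ORDER`.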

In [4]:
# Re-seed immediately before training so this run presumably does not depend on
# RNG state left behind by the cells above.
seed_everything(SEED)
experiment_kwargs = dict(
    gfrft=gfrft,
    original_signals=jtv_signal,
    original_order=ORIGINAL_ORDER,
    initial_orders=INITIAL_ORDERS,
    lr=LEARNING_RATE,
    epochs=EPOCHS,
    display_epochs=DISPLAY_EPOCHS,
    show_sum_during_training=SHOW_SUM_DURING_TRAINING,
    # jtv_signal is (NUM_NODES, TIME_LENGTH), so the graph (vertex) axis is 0.
    dim=0,
)
experiment(**experiment_kwargs)

Sequential(
  (0): GFRFT(order=0.0, size=100, dim=0)
  (1): GFRFT(order=1.0, size=100, dim=0)
)
original order: 0.3500
learning rate: 0.001
initial orders: a1 =  0.0000 | a2 =  1.0000 | sum =  1.0000
-----------------------------------------------------------
Epoch    1 | Loss 7.80e+00 | a1 =  0.0000 | a2 =  1.0000 | sum =  1.0000
Epoch  500 | Loss 3.13e-01 | a1 = -0.2486 | a2 =  0.7514 | sum =  0.5028
Epoch 1000 | Loss 1.48e-03 | a1 = -0.3198 | a2 =  0.6802 | sum =  0.3605
Epoch 1500 | Loss 4.00e-07 | a1 = -0.3249 | a2 =  0.6751 | sum =  0.3502
Epoch 2000 | Loss 6.68e-10 | a1 = -0.3250 | a2 =  0.6750 | sum =  0.3500
-----------------------------------------------------------
final orders: a1 = -0.3250 | a2 =  0.6750
original order: 0.3500, final order sum: 0.3500
