In [1]:
import torch as th
import torch.nn as nn
from torch_gfrft import EigvalSortStrategy
from torch_gfrft.gfrft import GFRFT
from torch_gfrft.gft import GFT
from torch_gfrft.layer import GFRFTLayer

# Problem size: X below is drawn as a (NUM_NODES, TIME_LENGTH) tensor.
NUM_NODES = 100
TIME_LENGTH = 200
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
DEVICE = th.device('cuda' if th.cuda.is_available() else 'cpu')

In [2]:
# Random dense weighted adjacency matrix without self-loops. A is not
# symmetrised, so the graph is directed; adding `A = A + A.T` right after the
# draw turns it into an undirected graph for comparison.
#
# The original draw happened before any seed was set, so every
# Restart & Run All produced a different graph (and different GFT/results).
# A dedicated generator makes A reproducible without consuming or reseeding
# the global RNG stream that the data cell seeds for X. A seed different from
# the one used for X keeps A and X off identically-seeded streams.
GRAPH_SEED = 1
_graph_rng = th.Generator(device=DEVICE).manual_seed(GRAPH_SEED)
A = th.rand(NUM_NODES, NUM_NODES, device=DEVICE, generator=_graph_rng)
A.fill_diagonal_(0.0)  # remove self-loops (same result as A - diag(diag(A)))

In [3]:
# Graph Fourier transform of A, with eigenvalues ordered by the
# TOTAL_VARIATION strategy from torch_gfrft.
sort_strategy = EigvalSortStrategy.TOTAL_VARIATION
gft = GFT(A, sort_strategy)
# The fractional transform is constructed from the GFT matrix computed above.
gfrft = GFRFT(gft.gft_mtx)

In [4]:
# Seed the global RNG immediately before drawing the input signal.
th.manual_seed(0)
# Fractional order used to generate the targets; the training cell tries to
# recover it with learnable GFRFT layers.
original_order = 0.35
X = th.randn((NUM_NODES, TIME_LENGTH), device=DEVICE)
# Noise-free targets: X transformed along dim 0 (the NUM_NODES axis).
Y = gfrft.gfrft(X, original_order, dim=0)

In [10]:
def mse_loss(predictions: th.Tensor, targets: th.Tensor, dim: int = 0) -> th.Tensor:
    """Mean Euclidean (L2) error between ``predictions`` and ``targets``.

    Despite its historical name this is NOT a mean squared error: the residual's
    (unsquared) L2 norm is taken along ``dim`` and the resulting norms are
    averaged. For the (NUM_NODES, TIME_LENGTH) tensors in this notebook and the
    default ``dim=0``, that is one error per time sample, averaged over time.

    Complex inputs are supported (the norm uses the element modulus), which
    matters if the GFRFT output is complex.

    Args:
        predictions: model output; must be broadcast-compatible with ``targets``.
        targets: reference tensor.
        dim: axis along which each L2 norm is computed. Defaults to 0, matching
            the previous hard-coded behaviour.

    Returns:
        A 0-d real tensor.
    """
    residual = predictions - targets
    # Same numerics as the deprecated `th.norm(residual, p='fro', dim=0)`, which
    # dispatches to a 2-norm over the single given dimension.
    return th.linalg.vector_norm(residual, ord=2, dim=dim).mean()

# Two learnable GFRFT layers in cascade, initialised at orders 0.75 and 0.25.
# Targets were generated at `original_order`; the "Final sum" line compares the
# learned a1 + a2 against it (presumably the fractional orders compose
# additively — the recorded run ends with a sum of 0.35).
model = nn.Sequential(
    GFRFTLayer(gfrft, 0.75, dim=0),
    GFRFTLayer(gfrft, 0.25, dim=0),
)
print(model)
optim = th.optim.Adam(model.parameters(), lr=5e-4)
epochs = 2000
LOG_INTERVAL = 100  # epochs between progress lines

th.manual_seed(0)
# `epochs + 1` iterations so the state at epoch `epochs` is logged as well.
for epoch in range(epochs + 1):
    optim.zero_grad()
    output = mse_loss(model(X), Y)
    if not epoch % LOG_INTERVAL:
        # Orders are read before this iteration's optimiser step.
        a1, a2 = (layer.order.item() for layer in model)
        print(f"Epoch {epoch:4d} | Loss {output.item():<4.4f} | a1 = {a1:.4f} | a2 = {a2:.4f}")
    output.backward()
    optim.step()

a1, a2 = (layer.order.item() for layer in model)
print(f"Original a: {original_order:.4f}, Final a1: {a1:.4f} | Final a2: {a2:.4f}")
print(f"Final sum: {a1 + a2:.4f}")

Sequential(
  (0): GFRFT(order=0.75, size=100, dim=0)
  (1): GFRFT(order=0.25, size=100, dim=0)
)
Epoch    0 | Loss 80.8083 | a1 = 0.7500 | a2 = 0.2500
Epoch  100 | Loss 61.5493 | a1 = 0.7024 | a2 = 0.2024
Epoch  200 | Loss 48.6033 | a1 = 0.6607 | a2 = 0.1607
Epoch  300 | Loss 39.2506 | a1 = 0.6232 | a2 = 0.1232
Epoch  400 | Loss 31.6635 | a1 = 0.5879 | a2 = 0.0879
Epoch  500 | Loss 24.7365 | a1 = 0.5530 | a2 = 0.0530
Epoch  600 | Loss 17.8074 | a1 = 0.5169 | a2 = 0.0169
Epoch  700 | Loss 10.5126 | a1 = 0.4790 | a2 = -0.0210
Epoch  800 | Loss 2.7667 | a1 = 0.4392 | a2 = -0.0608
Epoch  900 | Loss 0.0060 | a1 = 0.4250 | a2 = -0.0750
Epoch 1000 | Loss 0.0065 | a1 = 0.4250 | a2 = -0.0750
Epoch 1100 | Loss 0.0080 | a1 = 0.4250 | a2 = -0.0750
Epoch 1200 | Loss 0.0111 | a1 = 0.4249 | a2 = -0.0751
Epoch 1300 | Loss 0.0059 | a1 = 0.4250 | a2 = -0.0750
Epoch 1400 | Loss 0.0079 | a1 = 0.4250 | a2 = -0.0750
Epoch 1500 | Loss 0.0118 | a1 = 0.4249 | a2 = -0.0751
Epoch 1600 | Loss 0.0060 | a1 = 0.425