In [1]:
import numpy as np
import matplotlib.pyplot as plt
from stiefel_optimizer import AdamG, qr_retraction, SGDG
from torch.optim import AdamW
import torch

In [2]:
# Simple Linear Network
lin = torch.nn.Linear(8, 4, bias=False)

In [3]:
# Initialize parameters to be orthogonal
q_param = qr_retraction(lin.weight.data.view(lin.weight.size(0), -1))
lin.weight.data.copy_(q_param.view(lin.weight.size()))

print("Checking if code is orthogonal:")
print(lin.weight.data @ lin.weight.data.T)

Checking if code is orthogonal:
tensor([[ 1.0000e+00,  2.2352e-08,  8.0094e-08, -5.9605e-08],
        [ 2.2352e-08,  1.0000e+00,  0.0000e+00,  3.7253e-08],
        [ 8.0094e-08,  0.0000e+00,  1.0000e+00,  4.4703e-08],
        [-5.9605e-08,  3.7253e-08,  4.4703e-08,  1.0000e+00]])


In [4]:
x = torch.randn(3, 8)
y = torch.randn(3, 4)

print(torch.nn.functional.mse_loss(lin(x), y))

tensor(3.1353, grad_fn=<MseLossBackward>)


In [5]:
# Initialize Optimizer
lr=1.e-1
params = {'params': lin.parameters(), 'lr':lr, 'stiefel':True}
opt = AdamG([params])

In [6]:
# Find orthogonal transform minimizing L2 error
for i in range(200):
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(lin(x), y)
    # print(loss)
    loss.backward()
    opt.step()

    is_ortho = torch.nn.functional.mse_loss(lin.weight.data @lin.weight.data.T, torch.eye(4))
    if i % 10 == 0:
        print("loss:", loss, "ortho cost:", is_ortho)

loss: tensor(3.1353, grad_fn=<MseLossBackward>) ortho cost: tensor(5.2858e-15)
loss: tensor(1.9553, grad_fn=<MseLossBackward>) ortho cost: tensor(6.9181e-15)
loss: tensor(0.8842, grad_fn=<MseLossBackward>) ortho cost: tensor(2.2917e-15)
loss: tensor(0.3609, grad_fn=<MseLossBackward>) ortho cost: tensor(6.2034e-15)
loss: tensor(0.1926, grad_fn=<MseLossBackward>) ortho cost: tensor(7.3344e-15)
loss: tensor(0.1359, grad_fn=<MseLossBackward>) ortho cost: tensor(8.8957e-15)
loss: tensor(0.1128, grad_fn=<MseLossBackward>) ortho cost: tensor(8.5279e-15)
loss: tensor(0.1019, grad_fn=<MseLossBackward>) ortho cost: tensor(1.6077e-14)
loss: tensor(0.0963, grad_fn=<MseLossBackward>) ortho cost: tensor(9.0292e-15)
loss: tensor(0.0932, grad_fn=<MseLossBackward>) ortho cost: tensor(2.4217e-14)
loss: tensor(0.0913, grad_fn=<MseLossBackward>) ortho cost: tensor(2.6737e-14)
loss: tensor(0.0901, grad_fn=<MseLossBackward>) ortho cost: tensor(3.1248e-14)
loss: tensor(0.0893, grad_fn=<MseLossBackward>) orth

In [7]:
# Check if parameters are still orthogonal:
print("Orthogonality Check:", lin.weight.data @lin.weight.data.T)

Orthogonality Check: tensor([[ 1.0000e+00,  0.0000e+00, -7.4506e-08,  4.8429e-08],
        [ 0.0000e+00,  1.0000e+00, -7.3574e-08,  1.3411e-07],
        [-7.4506e-08, -7.3574e-08,  1.0000e+00,  1.4901e-07],
        [ 4.8429e-08,  1.3411e-07,  1.4901e-07,  1.0000e+00]])
