In [99]:
import numpy as np
import matplotlib.pyplot as plt
from stiefel_optimizer import AdamG, qr_retraction, SGDG
from torch.optim import AdamW
import torch

In [100]:
# Bias-free linear layer taking 8 input features to 4 outputs.
lin = torch.nn.Linear(8, 4, bias=False)

# Initialize to be orthogonal: view the weight as a 2-D matrix whose row count
# is its first dimension, pass it through qr_retraction, then copy the result
# back into the layer in place (restoring the original weight shape).
n_rows = lin.weight.size(0)
weight_2d = lin.weight.data.view(n_rows, -1)
q_param = qr_retraction(weight_2d)
lin.weight.data.copy_(q_param.view(lin.weight.size()))

print(lin.weight.data)

tensor([[ 0.0824,  0.6819,  0.0800,  0.1486, -0.4823, -0.2930, -0.4207, -0.0656],
        [-0.2262, -0.2704, -0.1981,  0.1736, -0.2894, -0.6350,  0.2350,  0.5139],
        [-0.6645,  0.2620, -0.6268,  0.0740,  0.0267,  0.1738,  0.0857, -0.2307],
        [-0.4636,  0.1128,  0.6340,  0.1715,  0.2613, -0.3236,  0.2495, -0.3251]])


In [101]:
# Random batch: 3 samples, each with 8 inputs and 4 targets.
x = torch.randn(3, 8)
y = torch.randn(3, 4)

# Mean-squared error of the layer's predictions on this batch.
initial_loss = torch.nn.functional.mse_loss(lin(x), y)
print(initial_loss)

tensor(2.1709, grad_fn=<MseLossBackward>)


In [140]:
# One parameter group holding the layer's parameters. The 'stiefel' entry is
# forwarded to SGDG as-is (presumably it switches on the orthogonality-
# constrained update -- confirm against stiefel_optimizer).
lr = 0.1
params = dict(params=lin.parameters(), lr=lr, stiefel=True)
opt = SGDG([params])

In [146]:
# Fit the layer to (x, y) for 1000 optimizer steps, reporting every 100th one.
for i in range(1000):
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(lin(x), y)
    loss.backward()
    opt.step()

    # The orthogonality cost is purely diagnostic, so it is evaluated only on
    # the iterations that are reported instead of on every step.
    if i % 100 == 0:
        # detach() gives a gradient-free view of the weight without reaching
        # into the legacy .data attribute.
        weight = lin.weight.detach()
        # Build the identity from the weight itself so its size, dtype and
        # device always match.
        identity = torch.eye(weight.size(0), dtype=weight.dtype, device=weight.device)
        # MSE between W @ W^T and the identity; `loss` is from before this
        # step, while the cost reflects the weight after the step.
        is_ortho = torch.nn.functional.mse_loss(weight @ weight.T, identity)
        print("loss:", loss, "ortho cost:", is_ortho)

loss: tensor(0.0538, grad_fn=<MseLossBackward>) ortho cost: tensor(1.1258e-15)
loss: tensor(0.0538, grad_fn=<MseLossBackward>) ortho cost: tensor(5.6899e-16)
loss: tensor(0.0538, grad_fn=<MseLossBackward>) ortho cost: tensor(6.8001e-16)
loss: tensor(0.0538, grad_fn=<MseLossBackward>) ortho cost: tensor(7.9103e-16)
loss: tensor(0.0538, grad_fn=<MseLossBackward>) ortho cost: tensor(1.3618e-15)
loss: tensor(0.0538, grad_fn=<MseLossBackward>) ortho cost: tensor(1.6445e-15)
loss: tensor(0.0538, grad_fn=<MseLossBackward>) ortho cost: tensor(6.4532e-16)
loss: tensor(0.0538, grad_fn=<MseLossBackward>) ortho cost: tensor(1.9377e-15)
loss: tensor(0.0538, grad_fn=<MseLossBackward>) ortho cost: tensor(1.9654e-15)
loss: tensor(0.0538, grad_fn=<MseLossBackward>) ortho cost: tensor(6.4705e-16)


In [147]:
# Product of the weight matrix with its own transpose; it is close to the
# identity when the weight's rows are orthonormal.
gram = lin.weight.data @ lin.weight.data.T
print(gram)

tensor([[ 1.0000e+00,  5.9605e-08,  0.0000e+00, -4.4703e-08],
        [ 5.9605e-08,  1.0000e+00,  1.4901e-08, -5.9605e-08],
        [ 0.0000e+00,  1.4901e-08,  1.0000e+00, -2.2352e-08],
        [-4.4703e-08, -5.9605e-08, -2.2352e-08,  1.0000e+00]])
