In [9]:
import numpy as np

import torch
import torch.nn as nn

In [10]:
x_test = torch.randn(3, 4, 4) + 3
noise = torch.randn(3, 4, 4)

x_noisy = noise + x_test
x_noisy = x_noisy.view(1, -1)

In [11]:
n = 3 * 4 * 4
vs = [nn.Parameter(torch.randn(1, n),
                   requires_grad=True) for _ in range(n)]

sigma = nn.Parameter(torch.randn(1, n) * 0.1, requires_grad=True)
mean = nn.Parameter(torch.randn(1, n) * 0.1, requires_grad=True)
# vs = nn.Parameter(torch.randn(n, 1, n), requires_grad=True)
optim = torch.optim.SGD(vs, 1)

noise_init = torch.randn(1, n, 1)

In [12]:
def householder_caster(b, n, device='cpu'):    
    def compute_householder_matrix(vs):
        Qs = []
        for i in range(b):
            Q = torch.eye(n, device=device)
            for v in vs:
                vi = v[i].view(-1, 1)
                vi = vi / vi.norm()
                Qi = torch.eye(n, device=device, requires_grad=False) - 2 * torch.mm(vi, vi.permute(1, 0))
                Q = torch.mm(Q, Qi)
            Qs.append(Q)
        return torch.stack(Qs)
    return compute_householder_matrix

In [5]:
h = householder_caster(1, n)

In [6]:
h(vs)

tensor([[[ 0.3069,  0.2826,  0.1630,  ...,  0.0331,  0.0143,  0.0502],
         [-0.2939,  0.1148,  0.1347,  ...,  0.1534,  0.1014, -0.0770],
         [ 0.0894,  0.1064,  0.4262,  ...,  0.0611,  0.2658, -0.1515],
         ...,
         [-0.1016,  0.0919, -0.1530,  ...,  0.1301,  0.1969, -0.1835],
         [-0.0534, -0.0543,  0.0163,  ..., -0.1478,  0.3678, -0.0393],
         [ 0.2041, -0.3262,  0.1283,  ..., -0.1729,  0.0631, -0.0146]]],
       grad_fn=<StackBackward>)

In [7]:
noise_init.shape

torch.Size([1, 48, 1])

In [8]:
for t in range(1000):
    optim.zero_grad()
    noise_est = h(vs) @ noise_init #(torch.randn(1, n, 1) * sigma + mean)
    x_recov = x_noisy - noise_est
    loss = (x_recov - 3) ** 2
    loss = loss.mean()
    loss.backward()
    optim.step()
    if t > 0 and t % 100 == 0:
        with torch.no_grad():
            residual = torch.norm(noise_est - noise.view(1, -1))
            print(f'rotated: mean={noise_est.mean():.3f}, std={noise_est.std():.3f}')
            print(f'residual={residual.item():.5f}, loss={loss.item():.5f}')
            print('--')
#             print(vs[0][0][:5])

rotated: mean=0.941, std=0.094
residual=67.04375, loss=2.58793
--
rotated: mean=0.946, std=0.008
residual=67.02762, loss=2.58708
--
rotated: mean=0.946, std=0.001
residual=67.02757, loss=2.58707
--
rotated: mean=0.946, std=0.000
residual=67.02758, loss=2.58707
--
rotated: mean=0.946, std=0.000
residual=67.02753, loss=2.58707
--
rotated: mean=0.946, std=0.000
residual=67.02754, loss=2.58707
--
rotated: mean=0.946, std=0.000
residual=67.02756, loss=2.58707
--
rotated: mean=0.946, std=0.000
residual=67.02754, loss=2.58707
--
rotated: mean=0.946, std=0.000
residual=67.02755, loss=2.58707
--


In [None]:
noise_est = h(vs) @ noise_init
noise_est.std()

In [None]:
noise_est.mean()

In [None]:
noise_init.std()

In [None]:
noise_init.mean()