In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [23]:
torch.manual_seed(42)
U1 = nn.Linear(2,4, bias=False)
U2 = nn.Linear(4,6, bias=False)

def step(x, e1, r1, e2, r2):
    with torch.no_grad():
        e1 = x - torch.matmul(U1.weight.T, r1)
    r1 = r1 + U1(e1) 
    with torch.no_grad():
        r1 -= (0.01*e2)
        e2 = r1 - torch.matmul(U2.weight.T, r2)
    r2 = r2 + U2(e2)
    return e1, r1, e2, r2

In [24]:
x = torch.rand(2)
e1 = torch.rand(2)
r1 = torch.rand(4)
e2 = torch.rand(4)
r2 = torch.rand(6)

In [32]:
e1 = e1.detach()
r1 = r1.detach()
e2 = e2.detach()
r2 = r2.detach()
e1, r1, e2, r2 = step(x, e1, r1, e2, r2)
print(f" x: {x}")
print(f"e1: {e1}")
print(f"r1: {r1}")
print(f"e2: {e2}")
print(f"r2: {r2}")

 x: tensor([0.8860, 0.5832])
e1: tensor([ 0.1565, -0.0101])
r1: tensor([ 1.2042,  0.1495,  0.1594, -0.5727], grad_fn=<AddBackward0>)
e2: tensor([ 0.2313,  0.3016, -0.1143, -0.4969])
r2: tensor([ 0.4809,  1.0673,  1.1476, -0.4755,  0.2951,  0.0814],
       grad_fn=<AddBackward0>)


In [33]:
with torch.no_grad():
    U1.weight.grad = torch.zeros_like(U1.weight)
    U2.weight.grad = torch.zeros_like(U2.weight)
loss = 0.5*((r1**2).sum() + (r2**2).sum())
loss.sum().backward()
print(f'U1.weight.grad:\n{U1.weight.grad}')
print(f'U2.weight.grad:\n{U2.weight.grad}')

U1.weight.grad:
tensor([[ 0.1884, -0.0122],
        [ 0.0234, -0.0015],
        [ 0.0249, -0.0016],
        [-0.0896,  0.0058]])
U2.weight.grad:
tensor([[ 0.1112,  0.1451, -0.0550, -0.2390],
        [ 0.2468,  0.3219, -0.1220, -0.5303],
        [ 0.2654,  0.3462, -0.1312, -0.5702],
        [-0.1100, -0.1434,  0.0544,  0.2363],
        [ 0.0682,  0.0890, -0.0337, -0.1466],
        [ 0.0188,  0.0246, -0.0093, -0.0405]])


In [34]:
print(f'1:\n{torch.matmul(e1.unsqueeze(1), r1.unsqueeze(1).T).T}')
print(f'2:\n{torch.matmul(e2.unsqueeze(1), r2.unsqueeze(1).T).T}')

1:
tensor([[ 0.1884, -0.0122],
        [ 0.0234, -0.0015],
        [ 0.0249, -0.0016],
        [-0.0896,  0.0058]], grad_fn=<PermuteBackward0>)
2:
tensor([[ 0.1112,  0.1451, -0.0550, -0.2390],
        [ 0.2468,  0.3219, -0.1220, -0.5303],
        [ 0.2654,  0.3462, -0.1312, -0.5702],
        [-0.1100, -0.1434,  0.0544,  0.2363],
        [ 0.0682,  0.0890, -0.0337, -0.1466],
        [ 0.0188,  0.0246, -0.0093, -0.0405]], grad_fn=<PermuteBackward0>)


In [274]:
U = nn.Conv2d(1, 3, (3,3), bias=False)
optimiser = torch.optim.SGD(U.parameters(), lr=0.001)
Ut = nn.ConvTranspose2d(3, 1, (3,3))
with torch.no_grad():
    Ut.weight = U.weight
grad = torch.zeros_like(U.weight)
def step(x, e, r):
    with torch.no_grad():
        e = x - Ut(r)
    r = r + U(e)
    return e, r

In [275]:
x = torch.rand((1,1,4,4))
e = torch.zeros_like(x)
r = torch.zeros((1,3,2,2))

In [294]:
e = e.detach()
r = r.detach()
e, r = step(x, e, r)
print(e)
print(r)

tensor([[[[ 0.5686, -0.0906,  0.0353,  0.2347],
          [ 0.3378,  0.1243, -0.1070, -0.0324],
          [ 0.2056,  0.1212, -0.0081, -0.0222],
          [ 0.0375,  0.0403,  0.0899,  0.0199]]]])
tensor([[[[ 0.3041,  0.3235],
          [-0.1773, -0.0444]],

         [[-0.0907, -1.0017],
          [ 0.0183, -0.6439]],

         [[ 0.6487,  1.1892],
          [-0.4736, -0.6325]]]], grad_fn=<AddBackward0>)


In [295]:
optimiser.zero_grad()
loss = (0.5*(r**2)).sum()
print(loss)
loss.backward()
optimiser.step()

tensor(2.0582, grad_fn=<SumBackward0>)


In [296]:
U.weight.grad

tensor([[[[ 0.0782, -0.0334,  0.1071],
          [ 0.1011, -0.0180, -0.0406],
          [ 0.0933,  0.0231, -0.0265]]],


        [[[-0.0347,  0.0440, -0.2194],
          [-0.2294,  0.1033,  0.0563],
          [-0.1653, -0.0601,  0.0118]]],


        [[[ 0.0226, -0.0079,  0.3732],
          [ 0.1929, -0.0989, -0.0901],
          [ 0.2343, -0.0069, -0.0868]]]])