# Goal

Compute:

$$
\nabla_w L(w + \alpha)
$$

요런 걸 계산하고 싶다. 즉, forward 한 네트워크의 파라메터는 w' = w + @ 인데, 이걸 w 에 대해서 그라디언트를 계산하고 싶은 것.

제일 쉬운 방법은 forward 를 할 때 이 alpha 를 정말 상수로 더해줘서 계산하고 w 에 대해 그라디언트를 구하면 되겠지만, 문제는 이러면 ready-made network 를 사용할 때 매우 피곤해지는 문제가 있음. 그렇게 안하고 간단히 할 수 있을까? 가 최종 목표.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

## Setup

Linear MLP.

In [101]:
X = torch.randn(4, 6)
W1 = torch.randn(4, 6)
B1 = torch.randn(4)
W2 = torch.randn(1, 4)
B2 = torch.randn(1)
Y = torch.randn(4)

A1 = torch.randn_like(W1) * 0.1
A2 = torch.randn_like(W2) * 0.1

## Naive method

In [102]:
x = X.clone()
w1 = W1.clone().requires_grad_()
b1 = B1.clone().requires_grad_()
w2 = W2.clone().requires_grad_()
b2 = B2.clone().requires_grad_()
y = Y.clone()

In [103]:
out1 = F.linear(x, w1 + A1, b1)
out2 = F.linear(out1, w2 + A2, b2)

In [104]:
L = (out2 - y).abs().sum()

In [105]:
torch.autograd.grad(L, [w1, b1, w2, b2], retain_graph=True)

(tensor([[ 1.7616,  0.8231, -1.6028,  0.0847,  0.3678,  2.0537],
         [ 6.9855,  3.2641, -6.3559,  0.3361,  1.4587,  8.1441],
         [ 9.7947,  4.5768, -8.9118,  0.4712,  2.0453, 11.4193],
         [-4.5118, -2.1082,  4.1051, -0.2171, -0.9421, -5.2601]]),
 tensor([ -2.8140, -11.1589, -15.6464,   7.2073]),
 tensor([[  8.7595,  19.6231,  -7.4477, -36.4192]]),
 tensor([-16.]))

In [106]:
L.backward(retain_graph=True)

In [107]:
w1.grad, b1.grad, w2.grad, b2.grad

(tensor([[ 1.7616,  0.8231, -1.6028,  0.0847,  0.3678,  2.0537],
         [ 6.9855,  3.2641, -6.3559,  0.3361,  1.4587,  8.1441],
         [ 9.7947,  4.5768, -8.9118,  0.4712,  2.0453, 11.4193],
         [-4.5118, -2.1082,  4.1051, -0.2171, -0.9421, -5.2601]]),
 tensor([ -2.8140, -11.1589, -15.6464,   7.2073]),
 tensor([[  8.7595,  19.6231,  -7.4477, -36.4192]]),
 tensor([-16.]))

## Second

In [108]:
x = X.clone()
w1 = (W1 + A1).clone().requires_grad_()
b1 = B1.clone().requires_grad_()
w2 = (W2 + A2).clone().requires_grad_()
b2 = B2.clone().requires_grad_()
y = Y.clone()

In [109]:
out1 = F.linear(x, w1, b1)
out2 = F.linear(out1, w2, b2)

In [110]:
L = (out2 - y).abs().sum()

In [112]:
r = torch.autograd.grad(L, [w1, b1, w2, b2], retain_graph=True)

# Test

In [17]:
w = torch.rand(10, 10)
b = torch.rand(10)
x = torch.rand(4, 10)

In [9]:
out = F.linear(x, w, b)

In [10]:
out

tensor([[3.4362, 3.6962, 2.4504, 2.5618, 3.2899, 1.9352, 3.0735, 2.5174, 2.8351,
         3.2384],
        [4.1157, 4.5152, 3.2124, 3.6898, 4.4032, 2.2202, 4.2657, 3.6527, 3.9886,
         3.8373],
        [2.0405, 2.6753, 2.0907, 1.9293, 2.2609, 1.4715, 2.1983, 1.6173, 1.8051,
         2.0078],
        [3.5920, 4.1670, 2.6857, 3.1477, 3.0942, 2.3879, 3.3677, 2.9636, 3.3715,
         3.2556]], grad_fn=<AddmmBackward>)

In [12]:
out.sum().backward()

In [113]:
r

(tensor([[ 1.7616,  0.8231, -1.6028,  0.0847,  0.3678,  2.0537],
         [ 6.9855,  3.2641, -6.3559,  0.3361,  1.4587,  8.1441],
         [ 9.7947,  4.5768, -8.9118,  0.4712,  2.0453, 11.4193],
         [-4.5118, -2.1082,  4.1051, -0.2171, -0.9421, -5.2601]]),
 tensor([ -2.8140, -11.1589, -15.6464,   7.2073]),
 tensor([[  8.7595,  19.6231,  -7.4477, -36.4192]]),
 tensor([-16.]))

In [117]:
torch.stack([t.norm(2) for t in r]).norm(2)

tensor(55.7611)

In [122]:
torch.cat([t.flatten() for t in r]).norm(2)

tensor(55.7611)

In [124]:
r[0]

tensor([[ 1.7616,  0.8231, -1.6028,  0.0847,  0.3678,  2.0537],
        [ 6.9855,  3.2641, -6.3559,  0.3361,  1.4587,  8.1441],
        [ 9.7947,  4.5768, -8.9118,  0.4712,  2.0453, 11.4193],
        [-4.5118, -2.1082,  4.1051, -0.2171, -0.9421, -5.2601]])

In [125]:
r[0].data

tensor([[ 1.7616,  0.8231, -1.6028,  0.0847,  0.3678,  2.0537],
        [ 6.9855,  3.2641, -6.3559,  0.3361,  1.4587,  8.1441],
        [ 9.7947,  4.5768, -8.9118,  0.4712,  2.0453, 11.4193],
        [-4.5118, -2.1082,  4.1051, -0.2171, -0.9421, -5.2601]])

In [137]:
np.random.uniform(0.01, 0.1)

0.02554466825101471