In [1]:
import torch
import torch.nn as nn

## Chain Rule and Computational Graph

In [2]:
# Build z = (x + y)^2 in two explicit steps so that each intermediate
# node (the addition, then the power) can be inspected on its own.

x = torch.randn(32, requires_grad=True)  # leaf tensor tracked by autograd
y = torch.randn(32)                      # plain tensor, not tracked

t = torch.add(x, y)
z = torch.pow(t, 2)

In [3]:
# No backward pass has been run at this point, so x has no gradient stored yet.
x.grad is None

True

In [4]:
# z is not a scalar, so an upstream gradient of matching shape is passed
# explicitly; all ones weights every element of z equally.
z.backward(gradient=torch.ones_like(z))

In [5]:
# Gradient of z with respect to x, filled in by the backward() call above.
x.grad

tensor([ 5.4802, -0.4582, -0.5747,  1.7369,  1.0143, -1.6004,  3.4750, -0.4550,
        -5.6589, -2.8882, -0.5215,  1.5525, -2.6667,  2.2073,  2.6927, -4.0240,
         1.9924, -1.3897, -0.2752,  0.2054, -1.2876, -1.4065,  1.3723, -0.5634,
        -0.3737,  3.7107, -4.1007, -3.1944,  2.8045,  1.8362,  2.0028, -1.2334])

In [6]:
# Analytic check: by the chain rule dz/dx = 2 * (x + y) * d(x + y)/dx = 2 * (x + y),
# which should match the values stored in x.grad.
2 * (x + y)

tensor([ 5.4802, -0.4582, -0.5747,  1.7369,  1.0143, -1.6004,  3.4750, -0.4550,
        -5.6589, -2.8882, -0.5215,  1.5525, -2.6667,  2.2073,  2.6927, -4.0240,
         1.9924, -1.3897, -0.2752,  0.2054, -1.2876, -1.4065,  1.3723, -0.5634,
        -0.3737,  3.7107, -4.1007, -3.1944,  2.8045,  1.8362,  2.0028, -1.2334],
       grad_fn=<MulBackward0>)

In [7]:
# Show the backward node that autograd recorded for each intermediate tensor.
t.grad_fn, z.grad_fn

(<AddBackward0 at 0x16ab8c19e48>, <PowBackward0 at 0x16ab8c19ef0>)

## PyTorch Implementation

- Add

In [8]:
class Add(torch.autograd.Function):
    """Element-wise addition ``i + j`` with a hand-written backward pass.

    Because d(i + j)/di = d(i + j)/dj = 1, the upstream gradient is routed
    to both inputs unchanged.

    NOTE: the gradient is returned as-is for both inputs, so it is not
    reduced back to the input shapes when ``i`` and ``j`` were broadcast.
    """

    @staticmethod
    def forward(ctx, i, j):
        """Return ``i + j``.

        Nothing is saved on ``ctx``: the backward pass of an addition needs
        neither the inputs nor the output.
        """
        return i + j

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients for ``(i, j)``; both equal ``grad_output``."""
        return grad_output, grad_output

In [15]:
# Clear the gradient left on x by the earlier backward pass.
x.grad = None
# Create a gradient-tracking leaf with the same values as y. detach() is the
# supported replacement for the legacy .data attribute: the result shares
# storage with y, and y itself stays untracked.
y2 = y.detach().requires_grad_(True)

In [16]:
# Invoke the custom Add function through apply() on x and y2.
t = Add.apply(x, y2)

In [12]:
# Backpropagate an all-ones upstream gradient through the custom Add node.
t.backward(gradient=torch.ones_like(t))

In [13]:
x.grad # backward through an addition node passes the upstream gradient through unchanged

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [14]:
# The second input of the addition receives the same upstream gradient.
y2.grad

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

- Multiplication

In [59]:
class Mul(torch.autograd.Function):
    """Element-wise multiplication ``i * j`` with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, i, j):
        """Return ``i * j`` and save both inputs for the backward pass."""
        ctx.save_for_backward(i, j)
        return i * j

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients for ``(i, j)``.

        By the chain rule dL/di = grad_output * j and dL/dj = grad_output * i.
        The upstream gradient must be multiplied in; returning just
        ``(j, i)`` is only correct when ``grad_output`` is all ones.
        """
        i, j = ctx.saved_tensors
        return grad_output * j, grad_output * i

In [60]:
# Clear the gradients left over from the Add example before reusing x and y2.
x.grad = None
y2.grad = None

In [61]:
# Invoke the custom Mul function through apply() on x and y2.
t = Mul.apply(x, y2)

In [62]:
# Backpropagate an all-ones upstream gradient through the custom Mul node.
t.backward(gradient=torch.ones_like(t))

In [63]:
x.grad # equals y2: d(x * y2)/dx = y2 and the upstream gradient is all ones

tensor([ 1.2498,  1.2571,  1.2734, -0.0856, -0.8524, -2.4533, -0.0090,  0.1798,
        -1.8876, -1.1345,  0.2635, -1.0527, -1.1891,  1.1100, -0.5364, -1.9171,
        -0.1662, -0.7579, -0.2441,  0.0794,  0.0216, -1.3117, -0.8964,  0.9103,
        -0.6450,  0.7024, -1.8953, -1.1638, -0.2627,  0.9256,  0.6960, -0.6845])

In [64]:
y2.grad # equals x: d(x * y2)/dy2 = x and the upstream gradient is all ones

tensor([ 1.4903, -1.4862, -1.5608,  0.9541,  1.3596,  1.6531,  1.7464, -0.4073,
        -0.9419, -0.3096, -0.5243,  1.8290, -0.1442, -0.0064,  1.8828, -0.0949,
         1.1624,  0.0631,  0.1065,  0.0233, -0.6654,  0.6085,  1.5825, -1.1920,
         0.4582,  1.1529, -0.1550, -0.4334,  1.6650, -0.0075,  0.3054,  0.0678])

## 사과 쇼핑의 예

In [91]:
# Inputs of the apple-shopping example. Each one is a leaf tensor with
# requires_grad=True so that a gradient is recorded for it.
apple = torch.tensor([100.0], requires_grad=True)
num = torch.tensor([2.0], requires_grad=True)
ctax = torch.tensor([1.1], requires_grad=True)  # multiplier applied to the price

In [92]:
# Forward pass: price = apple * num, then result = price * ctax.
# retain_grad() is called on the intermediate tensors so that their .grad
# can be inspected after backward().
price = torch.mul(apple, num)
price.retain_grad()
result = torch.mul(price, ctax)
result.retain_grad()

In [93]:
# Seed the backward pass with d(result)/d(result) = 1.
result.backward(gradient=torch.ones(1))

In [96]:
# Gradient of result with respect to every input and intermediate tensor.
apple.grad, num.grad, ctax.grad, price.grad, result.grad

(tensor([2.2000]),
 tensor([110.]),
 tensor([200.]),
 tensor([1.1000]),
 tensor([1.]))

## 사과와 귤 쇼핑의 역전파

In [98]:
# Inputs of the apple-and-tangerine example. Each one is a leaf tensor with
# requires_grad=True so that a gradient is recorded for it.
apple = torch.tensor([100.0], requires_grad=True)
tangerine = torch.tensor([150.0], requires_grad=True)
num_apple = torch.tensor([2.0], requires_grad=True)
num_tangerine = torch.tensor([3.0], requires_grad=True)
ctax = torch.tensor([1.1], requires_grad=True)  # multiplier applied to the total price

In [100]:
# Forward pass of the two-fruit graph. retain_grad() is called on every
# intermediate tensor so that its .grad can be inspected after backward().

# Per-fruit subtotals.
apple_price = torch.mul(apple, num_apple)
apple_price.retain_grad()

tangerine_price = torch.mul(tangerine, num_tangerine)
tangerine_price.retain_grad()

# Sum of both subtotals.
price = torch.add(apple_price, tangerine_price)
price.retain_grad()

# Total after applying the ctax multiplier.
result = torch.mul(price, ctax)
result.retain_grad()

In [101]:
# Seed the backward pass with d(result)/d(result) = 1.
result.backward(gradient=torch.ones(1))

In [115]:
# Map each display name to its tensor explicitly instead of resolving the
# names with eval(); the printed report is unchanged.
grad_holders = {
    'apple': apple,
    'tangerine': tangerine,
    'num_apple': num_apple,
    'num_tangerine': num_tangerine,
    'ctax': ctax,
    'apple_price': apple_price,
    'tangerine_price': tangerine_price,
    'price': price,
    'result': result,
}
items = list(grad_holders)  # kept so the name still lists the reported tensors

for item, tensor in grad_holders.items():
    print(f"{item:>15s}.grad = {tensor.grad.item():.2f}")

          apple.grad = 2.20
      tangerine.grad = 3.30
      num_apple.grad = 110.00
  num_tangerine.grad = 165.00
           ctax.grad = 650.00
    apple_price.grad = 1.10
tangerine_price.grad = 1.10
          price.grad = 1.10
         result.grad = 1.00
