In [1]:
from IPython.display import Image

In [20]:
import torch
from torch import nn
from torch.autograd import Variable

## multi head (output/branch) architecture

- https://www.bilibili.com/video/BV1o24y1b7tk

In [7]:
# Display the image referenced at ../imgs/multi_loss.PNG, constrained to width=100.
Image(url='../imgs/multi_loss.PNG', width=100)

In [8]:
# Build a small graph with two scalar outputs, d and e, that share the
# intermediate nodes b and c.
a = torch.rand(1, 4, requires_grad=True)
b = a**2
c = b*2

d = c.mean()
e = c.sum()

# The first backward pass frees the saved intermediate values of the graph.
d.backward()

# A second pass through the same, already freed, graph raises RuntimeError.
# That failure is the point of this cell, so it is caught and displayed
# instead of aborting a "Restart & Run All".
try:
    e.backward()
except RuntimeError as err:
    print(f"{type(err).__name__}: {err}")

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [15]:
# Leaf tensor a feeds two scalar outputs, d (mean) and e (sum), through b and c.
a = torch.rand(1, 4, requires_grad=True)
b = a.pow(2)
c = 2 * b

d, e = c.mean(), c.sum()

# Keep the saved intermediate values so the graph can be traversed again.
d.backward(retain_graph=True)

# Second traversal; the gradients of both passes accumulate in a.grad.
e.backward()

$$
\begin{split}
&b_i=a_i^2\\
&c_i=2b_i=2a_i^2\\
&d=\frac{\sum_ic_i}4=\frac{\sum_i 2a_i^2}4\\
&e=\sum_i c_i=\sum_i 2a_i^2
\end{split}
$$

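$$
\begin{split}
&\frac{\partial d}{\partial a_i}=a_i\\
&\frac{\partial e}{\partial a_i}=4a_i\\
&\text{a.grad}_i=\frac{\partial d}{\partial a_i}+\frac{\partial e}{\partial a_i}=5a_i
\end{split}
$$

Gradients accumulate across `backward()` calls, so after both passes `a.grad` holds the sum of the two gradients, i.e. `5*a`.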

In [11]:
# Inspect the values of the leaf tensor a.
a

tensor([[0.0688, 0.2060, 0.3086, 0.4836]], requires_grad=True)

In [10]:
# Gradient stored on a; backward passes accumulate into this attribute.
a.grad

tensor([[0.3442, 1.0302, 1.5431, 2.4179]])

In [14]:
# Expected accumulated gradient: dd/da + de/da = a + 4a = 5a.
5*a

tensor([[0.3442, 1.0302, 1.5431, 2.4179]], grad_fn=<MulBackward0>)

- Suppose you back-propagate loss1 first and then loss2 (the order can also be reversed):

```
l1.backward(retain_graph=True)
l2.backward() # the graph is freed here, so the next batch starts from a freshly built graph

optimizer.step() # update the network parameters

```

## non-leaf node

In [16]:
# a is the leaf; b, c and d are results of operations on it (non-leaf tensors).
a = torch.rand(1, 4, requires_grad=True)
b = a.pow(2)
c = 2 * b

d = c.mean()

In [17]:
# Back-propagate from the scalar d through the graph built in the previous cell.
d.backward()

In [18]:
# b is a non-leaf tensor and retain_grad() was not called on it, so autograd
# does not keep its gradient: this evaluates to None (PyTorch also emits a warning).
b.grad

  b.grad


In [19]:
a = torch.rand(1, 4, requires_grad=True)
b = a.pow(2)
# Ask autograd to keep the gradient of this non-leaf tensor after backward().
b.retain_grad()
c = 2 * b

d = c.mean()
d.backward()
# dd/db_i = 2/4 = 0.5 for every element of b.
b.grad

tensor([[0.5000, 0.5000, 0.5000, 0.5000]])

$$
\begin{split}
&d = \frac{\sum_i c_i}{4}=\frac{\sum_i 2b_i}{4}=\frac{\sum_i b_i}2\\
&\frac{\partial d}{\partial b_i}=\frac12
\end{split}
$$

### nn 中间层的weights 其实也是 leaf node

In [21]:
class MLP(nn.Module):
    """Fully connected network with layer sizes 28*28 -> 512 -> 512 -> 10.

    ReLU is applied after the first two linear layers; the last layer's
    output is returned without an activation.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 10)

    def forward(self, x):
        """Run the forward pass.

        Args:
            x: batched input whose non-batch dimensions flatten to 28*28
                features, e.g. shape (N, 1, 28, 28) or (N, 784).

        Returns:
            Tensor of shape (N, 10).
        """
        # nn.Flatten and nn.ReLU are module classes: nn.Flatten(x) / nn.ReLU(x)
        # would construct a module instead of transforming x. Use the
        # functional equivalents so the data actually flows through.
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

In [22]:
# Instantiate the network so its layer weights can be inspected.
mlp = MLP()

In [24]:
# Check whether the first layer's weight is a leaf tensor in the autograd sense.
mlp.fc1.weight.is_leaf

True

In [25]:
# The hidden layer's weight sits in the middle of the network; check that it is a leaf too.
mlp.fc2.weight.is_leaf

True

In [26]:
# Check whether the output layer's weight is a leaf tensor as well.
mlp.fc3.weight.is_leaf

True