### 1. 计算图的正向传播是立即执行的 

In [1]:
import torch

In [2]:
w = torch.tensor([[3.0, 1.0]], requires_grad = True)
print(w, w.dim())

tensor([[3., 1.]], requires_grad=True) 2


In [4]:
b = torch.tensor([[3.0]], requires_grad = True)
x = torch.randn(10, 2)
print(x, x.dim())

tensor([[-1.0225,  0.8705],
        [-1.0744,  0.3900],
        [-1.6546, -0.6132],
        [-0.1766,  0.7379],
        [ 1.6862, -1.5031],
        [ 0.3034,  0.2811],
        [-0.8615, -1.0708],
        [-0.8512,  0.7998],
        [-0.3452, -0.7666],
        [ 0.9913,  0.5325]]) 2


In [5]:
y = torch.randn(10, 1)
print(y, y.dim())

tensor([[ 0.6094],
        [-1.0475],
        [-0.7576],
        [ 0.8444],
        [ 0.8226],
        [-0.9716],
        [ 1.0400],
        [-0.7278],
        [ 1.7263],
        [-1.3268]]) 2


In [6]:
# Linear prediction: x is (10, 2) and w.t() is (2, 1), so the product is (10, 1);
# the (1, 1) bias b is broadcast across the 10 rows.
y_hat = x @ w.t() + b
print(y_hat)

tensor([[ 0.8030],
        [ 0.1666],
        [-2.5769],
        [ 3.2081],
        [ 6.5555],
        [ 4.1915],
        [-0.6554],
        [ 1.2462],
        [ 1.1977],
        [ 6.5063]], grad_fn=<AddBackward0>)


In [7]:
print(y_hat.data)

tensor([[ 0.8030],
        [ 0.1666],
        [-2.5769],
        [ 3.2081],
        [ 6.5555],
        [ 4.1915],
        [-0.6554],
        [ 1.2462],
        [ 1.1977],
        [ 6.5063]])


In [9]:
loss = torch.mean(torch.pow((y_hat - y), 2))
print(loss.data)

tensor(13.8340)


### 2. 计算图在反向传播后立即销毁

In [10]:
import torch

In [11]:
w = torch.tensor([[3.0, 1.0]], requires_grad = True)
print(w.data)

tensor([[3., 1.]])


In [16]:
# Rebuild the linear model and its mean-squared-error loss; w comes from the previous cell.
b = torch.tensor([[3.0]], requires_grad = True)
x = torch.randn(10, 2)
y = torch.randn(10, 1)
y_hat = x @ w.t() + b
loss = torch.mean(torch.pow((y_hat - y), 2))
print('\n')
print('计算图在反向传播之后会立即的注销，如果要保留计算图，需要设置retain_graph = True')
loss.backward() # use loss.backward(retain_graph = True) to keep the graph for another pass
# NOTE(review): a disabled debug print of `x_grad` was left here; no such name is defined in the visible cells.
# loss.backward() # calling backward a second time without retain_graph raises an error



计算图在反向传播之后会立即的注销，如果要保留计算图，需要设置retain_graph = True


 ### 3. 计算图中的Function

In [1]:
import torch

In [2]:
class MyRelu(torch.autograd.Function):
    """ReLU written as a custom autograd Function.

    Call it through ``MyRelu.apply(tensor)``. ``forward`` and ``backward`` are
    declared as static methods, as the ``torch.autograd.Function`` API requires;
    autograd passes the context object ``ctx`` explicitly as the first argument.
    """

    # Forward pass
    @staticmethod
    def forward(ctx, input):
        """Return ``input`` with every negative entry replaced by zero.

        Args:
            ctx: autograd context used to hand tensors over to ``backward``.
            input: tensor to rectify.

        Returns:
            ``input.clamp(min=0)``.
        """
        # Save the input so backward() can read it back from ctx.saved_tensors.
        ctx.save_for_backward(input)
        return input.clamp(min = 0)

    # Backward pass
    @staticmethod
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` through the ReLU.

        Args:
            ctx: the context populated by ``forward``.
            grad_output: gradient of the loss with respect to the forward output.

        Returns:
            A copy of ``grad_output`` with zeros wherever the saved input was
            strictly negative (an input of exactly 0 lets the gradient through).
        """
        input, = ctx.saved_tensors
        # Work on a copy so the incoming gradient tensor is not modified in place.
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input

In [4]:
w = torch.tensor([[3.0, 1.0]], requires_grad = True)
print(w.data, w.dim(), w.size())

tensor([[3., 1.]]) 2 torch.Size([1, 2])


In [13]:
b = torch.tensor([[3.0]],  requires_grad = True)
x = torch.tensor([[-1.0, -1.0], [1.0, 1.0]])
# NOTE(review): y has shape (1, 2) while the prediction computed from x in the following cells
# has shape (2, 1), so `y_hat - y` broadcasts to (2, 2) instead of comparing element-wise;
# that is what produces the recorded loss of 10.5. If per-sample targets were intended,
# y should probably be [[2.0], [3.0]] — confirm first, since the recorded outputs depend on it.
y = torch.tensor([[2.0, 3.0]])

In [14]:
relu = MyRelu.apply # alias of the Function's apply entry point: relu(t) runs MyRelu.forward and hooks MyRelu.backward into autograd

In [17]:
relu

<function MyRelu.apply>

In [18]:
y_hat = relu(x @ w.t()) + b
y_hat

tensor([[3.],
        [7.]], grad_fn=<AddBackward0>)

In [19]:
loss = torch.mean(torch.pow(y_hat - y, 2))
loss

tensor(10.5000, grad_fn=<MeanBackward0>)

In [20]:
loss.backward()

In [21]:
print('对w求导数：', w.grad)
print('对b求导数：', b.grad)

对w求导数： tensor([[4.5000, 4.5000]])
对b求导数： tensor([[5.]])


In [23]:
print(y_hat.grad_fn) # y_hat is the result of the final "+ b", so this prints AddBackward0 rather than MyRelu's backward node

<AddBackward0 object at 0x00000228FEF709D0>


### 4. 计算图与反向传播

In [24]:
import torch

In [26]:
x = torch.tensor(3.0, requires_grad = True)
print(x)

tensor(3., requires_grad=True)


In [27]:
y1 = x + 1
print(y1)

tensor(4., grad_fn=<AddBackward0>)


In [28]:
y2 = 2 * x
print(y2)

tensor(6., grad_fn=<MulBackward0>)


In [29]:
loss = (y1 - y2) ** 2
print(loss)

tensor(4., grad_fn=<PowBackward0>)


In [30]:
loss.backward()

In [33]:
print(y1.grad_fn)

<AddBackward0 object at 0x00000228FF39BA90>


 ### 5. 叶子节点和非叶子节点

In [35]:
# Narration cell: prints the two conditions for a leaf tensor and why autograd only keeps leaf gradients.
# The last message previously contained the typo "小村空间"; it now reads "显存空间" (GPU memory).
print('叶子节点需要满足两个条件')
print('1. 叶子节点张量是由用户直接创建的张量，而非由某个Function通过计算得到的张量')
print('2. 叶子节点张量的requires_grad必须为True')
print('这样设计的好处：节约内存或者显存空间，因为几乎所有的时候，用户只关心他自己直接创建的张量的梯度')

叶子节点需要满足两个条件
1. 叶子节点张量是由用户直接创建的张量，而非由某个Function通过计算得到的张量
2. 叶子节点张量的requires_grad必须为True
这样设计的好处：节约内存或者显存空间，因为几乎所有的时候，用户只关心他自己直接创建的张量的梯度


In [36]:
import torch

In [37]:
# Small scalar graph: the leaf x feeds two intermediate nodes whose squared difference is the loss.
x = torch.tensor(3.0, requires_grad=True)
y1 = torch.add(x, 1)
y2 = torch.mul(x, 2)
loss = torch.pow(torch.sub(y1, y2), 2)

In [38]:
loss.backward()
# Only the leaf x ends up with a stored gradient; the computed tensors loss, y1 and y2
# report None (see the output below).
print('loss.grad:', loss.grad)
print('y1.grad:', y1.grad)
print('y2.grad:', y2.grad)
print('x.grad:', x.grad)

loss.grad: None
y1.grad: None
y2.grad: None
x.grad: tensor(4.)


  return self._grad


In [40]:
print('打印看是否为叶子节点:')
print('x是否为叶子节点：', x.is_leaf)
print('y1是否为叶子节点：', y1.is_leaf)
print('y2是否为叶子节点：', y2.is_leaf)
print('loss是否为叶子节点', loss.is_leaf)

打印看是否为叶子节点:
x是否为叶子节点： True
y1是否为叶子节点： False
y2是否为叶子节点： False
loss是否为叶子节点 False


In [41]:
print('利用retain_grad可以保留非叶子节点的梯度，利用register_hook可以查看非叶子节点的梯度值')

利用retain_grad可以保留非叶子节点的梯度，利用register_hook可以查看非叶子节点的梯度值


In [42]:
import torch
x = torch.tensor(3.0, requires_grad = True)
y1 = x + 1
y2 = 2 * x
loss = (y1 - y2) ** 2

In [43]:
print('非叶子节点梯度显示控制')
# Each hook receives the gradient of its tensor and prints it; the output of the
# backward cell below shows them firing during loss.backward().
y1.register_hook(lambda grad: print('y1 grad:', grad))
y2.register_hook(lambda grad: print('y2 grad:', grad))
# Ask autograd to keep loss.grad even though loss is not a leaf tensor.
loss.retain_grad()

非叶子节点梯度显示控制


In [44]:
loss.backward()
print('loss.grad:', loss.grad)
print('x.grad:', x.grad)

y2 grad: tensor(4.)
y1 grad: tensor(-4.)
loss.grad: tensor(1.)
x.grad: tensor(4.)


### 6. 计算图在TensorBoard中的可视化

In [45]:
from torch import nn

In [46]:
torch.randn(2, 1)

tensor([[-0.6050],
        [ 0.9749]])

In [47]:
torch.zeros(1, 1)

tensor([[0.]])

In [48]:
class Net(nn.Module):
    """Single affine layer ``y = x @ w + b`` with hand-declared parameters."""

    def __init__(self):
        super().__init__()
        # Weight is (2, 1), drawn from a standard normal; bias is (1, 1), starting at zero.
        self.w = nn.Parameter(torch.randn(2, 1))
        self.b = nn.Parameter(torch.zeros(1, 1))

    def forward(self, x):
        """Map ``x`` (last dimension 2) to the affine output ``x @ w + b``."""
        projected = torch.matmul(x, self.w)
        return projected + self.b


net = Net()

In [49]:
net

Net()

In [50]:
from torch.utils.tensorboard import SummaryWriter

In [51]:
# Trace `net` on a dummy (10, 2) batch and write its graph to the TensorBoard log directory.
# The context manager closes the writer on exit, even if add_graph raises.
with SummaryWriter('./data/tensorboard') as writer:
    writer.add_graph(net, input_to_model = torch.randn(10, 2))

In [52]:
from tensorboard import notebook

In [53]:
notebook.list()

No known TensorBoard instances running.


In [55]:
print('在tensorboard中查看模型')
notebook.start('--logdir ./data/tensorboard')

在tensorboard中查看模型


Reusing TensorBoard on port 6006 (pid 20000), started 0:01:42 ago. (Use '!kill 20000' to kill it.)