In [211]:
import os 
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 获取训练设备
## 查看有无可用GPU

In [212]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


# 定义神经网络类
## super(NeuralNetwork,self) ：查找NeuralNetwork的父类，对self实施父类的方法
## nn.Flatten(x,[start=1,end=-1]) ：对输入张量进行指定维数降维，此处将(1,28,28)降成(1,28*28)
## nn.Sequential() ：序列容器，将神经网络模块按顺序添加到容器中

In [213]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork,self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28,512),
            nn.ReLU(),
            nn.Linear(512,512),
            nn.ReLU(),
            nn.Linear(512,10),
        )
        
    def forward(self,x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

In [214]:
##将模型移入GPU并打印其网络结构
model = NeuralNetwork().to(device)
print(model)

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)


In [215]:
##输入数据到模型模块进行推理，不要直接调用model.forward()!!!
X = torch.rand(1,28,28,device=device)
logits = model(X)
pred_probab = nn.Softmax(dim=1)(logits)
y_pred = pred_probab.argmax(dim=1)
print(logits)
print(pred_probab)
print(y_pred)

tensor([[ 0.1336,  0.0304,  0.0931,  0.0339, -0.0094, -0.0008,  0.0166, -0.0240,
          0.0925,  0.0196]], device='cuda:0', grad_fn=<AddmmBackward0>)
tensor([[0.1098, 0.0991, 0.1055, 0.0994, 0.0952, 0.0960, 0.0977, 0.0938, 0.1054,
         0.0980]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
tensor([0], device='cuda:0')


# 模型层解构

In [216]:
input_image = torch.rand(3,28,28)
print(input_image.size())

torch.Size([3, 28, 28])


## nn.Flatten

In [217]:
flatten = nn.Flatten()
flat_image = flatten(input_image)
print(flat_image.size())

torch.Size([3, 784])


## nn.Linear

In [218]:
layer1 = nn.Linear(in_features=28*28,out_features=20)
hidden1 = layer1(flat_image)
print(hidden1)

tensor([[ 0.2000,  0.2776, -0.7355, -0.3797, -0.1512,  0.0185, -0.5153,  0.1008,
         -0.0465, -0.0639,  0.3081, -0.1949, -0.3344, -0.2686, -0.4252,  0.3563,
         -0.4522, -0.4960,  0.1055,  0.2331],
        [ 0.0975,  0.2940, -0.9692, -0.3690,  0.0631, -0.3411, -0.0673,  0.3240,
         -0.1576, -0.1297,  0.3874,  0.1448, -0.0382, -0.0201, -0.3787,  0.0979,
         -0.1431, -0.3668,  0.0498,  0.0185],
        [ 0.1308,  0.4705, -1.0834, -0.4550, -0.5799, -0.1748, -0.0261,  0.1923,
         -0.1883, -0.1687,  0.2970,  0.0997, -0.3435, -0.0668, -0.1730,  0.3838,
         -0.5578, -0.2858, -0.3256,  0.2613]], grad_fn=<AddmmBackward0>)


## nn.ReLu

In [219]:
print("Before ReLU: "+str(hidden1))
hidden1 = nn.ReLU()(hidden1)
print("After ReLU: "+str(hidden1))

Before ReLU: tensor([[ 0.2000,  0.2776, -0.7355, -0.3797, -0.1512,  0.0185, -0.5153,  0.1008,
         -0.0465, -0.0639,  0.3081, -0.1949, -0.3344, -0.2686, -0.4252,  0.3563,
         -0.4522, -0.4960,  0.1055,  0.2331],
        [ 0.0975,  0.2940, -0.9692, -0.3690,  0.0631, -0.3411, -0.0673,  0.3240,
         -0.1576, -0.1297,  0.3874,  0.1448, -0.0382, -0.0201, -0.3787,  0.0979,
         -0.1431, -0.3668,  0.0498,  0.0185],
        [ 0.1308,  0.4705, -1.0834, -0.4550, -0.5799, -0.1748, -0.0261,  0.1923,
         -0.1883, -0.1687,  0.2970,  0.0997, -0.3435, -0.0668, -0.1730,  0.3838,
         -0.5578, -0.2858, -0.3256,  0.2613]], grad_fn=<AddmmBackward0>)
After ReLU: tensor([[0.2000, 0.2776, 0.0000, 0.0000, 0.0000, 0.0185, 0.0000, 0.1008, 0.0000,
         0.0000, 0.3081, 0.0000, 0.0000, 0.0000, 0.0000, 0.3563, 0.0000, 0.0000,
         0.1055, 0.2331],
        [0.0975, 0.2940, 0.0000, 0.0000, 0.0631, 0.0000, 0.0000, 0.3240, 0.0000,
         0.0000, 0.3874, 0.1448, 0.0000, 0.0000, 0.0000

## nn.Sequential

In [223]:
seq_modules = nn.Sequential(
    flatten,
    layer1,
    nn.ReLU(),
    nn.Linear(20,10)
)
input_image = torch.rand(3,28,28)
logits = seq_modules(input_image)
print(logits)

tensor([[-0.1867,  0.2772, -0.2026,  0.0006,  0.0781, -0.0579, -0.0299,  0.2266,
         -0.2537,  0.1294],
        [-0.1681,  0.3171, -0.2369,  0.0297,  0.0926, -0.0527, -0.0191,  0.2763,
         -0.3408,  0.1500],
        [-0.3777,  0.2138, -0.0585,  0.2183,  0.0368, -0.0401,  0.0495, -0.0046,
         -0.2697,  0.3572]], grad_fn=<AddmmBackward0>)


## nn.Softmax

In [227]:
softmax = nn.Softmax(dim=1)
pred_probab = softmax(logits)
print(pred_probab)

tensor([[0.0819, 0.1302, 0.0806, 0.0988, 0.1067, 0.0932, 0.0958, 0.1238, 0.0766,
         0.1124],
        [0.0824, 0.1339, 0.0769, 0.1004, 0.1070, 0.0925, 0.0957, 0.1285, 0.0693,
         0.1133],
        [0.0662, 0.1197, 0.0911, 0.1202, 0.1003, 0.0928, 0.1015, 0.0962, 0.0738,
         0.1381]], grad_fn=<SoftmaxBackward0>)


# 模型参数
## nn.Moudle会自动跟踪保存模型参数，使用parameters()或named_parameters()获取

In [228]:
print(model)
for name,param in model.named_parameters():
    print(name)
    print(param)

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)
linear_relu_stack.0.weight
Parameter containing:
tensor([[ 0.0136,  0.0086,  0.0201,  ..., -0.0083,  0.0114,  0.0195],
        [-0.0292,  0.0223,  0.0234,  ..., -0.0216,  0.0027,  0.0272],
        [ 0.0288, -0.0090, -0.0113,  ..., -0.0138,  0.0296,  0.0332],
        ...,
        [-0.0310,  0.0033,  0.0168,  ...,  0.0356, -0.0144, -0.0318],
        [ 0.0150, -0.0069,  0.0163,  ..., -0.0115,  0.0056,  0.0016],
        [ 0.0151, -0.0308,  0.0306,  ..., -0.0280, -0.0270, -0.0092]],
       device='cuda:0', requires_grad=True)
linear_relu_stack.0.bias
Parameter containing:
tensor([-5.1606e-03,  2.0612e-02, -1.1231e-02, -1.7049e-02,  3.4463e-03,
         3.5099e-02,  3.3571e-02, 