In [59]:
import torch
from torch import nn
from torch.nn import functional as F

In [60]:
#手动实现多层感知机
class MLP(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.hidden=nn.Linear(20,256)
        self.out=nn.Linear(256,10)
        
    def forward(self,X):
        return self.out(F.relu(self.hidden(X)))
    
net=MLP()
X=torch.randn(3,20)
net.forward(X)

tensor([[ 0.5456, -0.0139,  0.3604, -0.0753,  0.3465, -0.2623,  0.2258,  0.4416,
         -0.0819,  0.0041],
        [ 0.4631, -0.2744, -0.0177, -0.3775,  0.3395, -0.1012,  0.2001, -0.2340,
         -0.1355, -0.2906],
        [-0.4599, -0.0992,  0.1617,  0.3489,  0.5954,  0.1378,  0.0285,  0.1254,
         -0.1612, -0.0072]], grad_fn=<AddmmBackward0>)

In [61]:
#重写Sequential
class mySequential(nn.Module):
    def __init__(self,*args) -> None:
        super().__init__()
        self.layers=args
        
    def forward(self,X):
        res=X
        for layer in self.layers:
            res=layer(res)
        return res

net0=mySequential(nn.Linear(20,256),nn.ReLU(),nn.Linear(256,10))
X0=net0(X)

In [62]:
#三个层
net=nn.Sequential(nn.Linear(20,256),nn.ReLU(),nn.Linear(256,10))
net[2].state_dict()#参数字典

OrderedDict([('weight',
              tensor([[ 5.2678e-02,  4.0138e-02,  3.1163e-02,  ...,  9.3917e-03,
                        4.7029e-02,  5.7082e-02],
                      [ 1.0350e-02,  8.6519e-03,  5.5149e-02,  ..., -1.7446e-02,
                        4.4096e-02, -7.7985e-05],
                      [ 3.7303e-02, -4.6987e-02,  8.0692e-03,  ...,  2.0141e-02,
                       -4.4910e-02, -3.8173e-03],
                      ...,
                      [ 1.1738e-02,  3.4330e-03, -5.2295e-02,  ...,  2.7930e-02,
                       -6.1875e-02, -6.2694e-03],
                      [-4.4126e-02, -1.7636e-02,  5.9123e-02,  ...,  3.1298e-02,
                       -1.1796e-02, -1.8344e-02],
                      [ 6.2318e-02,  1.9420e-02, -3.8788e-02,  ...,  3.8656e-03,
                        2.5948e-02,  5.1625e-02]])),
             ('bias',
              tensor([ 0.0185,  0.0065,  0.0190, -0.0044, -0.0163, -0.0228,  0.0410, -0.0200,
                       0.0470, -0.0151]))])

In [63]:
net[2].bias.data

tensor([ 0.0185,  0.0065,  0.0190, -0.0044, -0.0163, -0.0228,  0.0410, -0.0200,
         0.0470, -0.0151])

In [64]:
[(name,param.shape) for name,param in net.named_parameters()]

[('0.weight', torch.Size([256, 20])),
 ('0.bias', torch.Size([256])),
 ('2.weight', torch.Size([10, 256])),
 ('2.bias', torch.Size([10]))]

In [65]:
net.add_module('softmax 1',nn.Softmax())
net#增加名字和模块

Sequential(
  (0): Linear(in_features=20, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
  (softmax 1): Softmax(dim=None)
)

In [66]:
#自定义层
class Relu(nn.Module):
    def forward(self,X):
        a=torch.zeros_like(X)
        return torch.max(a,X)
#forward会自动调用
net=nn.Sequential(nn.Linear(20,256),Relu())
net(X)

tensor([[0.0000e+00, 1.4499e-01, 0.0000e+00, 2.9267e-01, 0.0000e+00, 0.0000e+00,
         2.1341e-01, 0.0000e+00, 0.0000e+00, 4.5029e-01, 0.0000e+00, 0.0000e+00,
         5.0784e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 6.4271e-01, 0.0000e+00,
         4.5113e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.5431e+00, 4.5986e-01,
         2.5002e-01, 2.3068e-01, 2.7206e-01, 0.0000e+00, 0.0000e+00, 4.7483e-01,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 5.0842e-01,
         0.0000e+00, 2.3710e-01, 0.0000e+00, 9.2859e-01, 9.2492e-01, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 7.8932e-01, 7.7232e-01, 0.0000e+00, 0.0000e+00, 5.0438e-01,
         0.0000e+00, 0.0000e+00, 1.3850e-01, 3.8440e-01, 0.0000e+00, 0.0000e+00,
         3.4158e-01, 0.0000e+00, 0.0000e+00, 1.0876e+00, 2.2254e-02, 0.0000e+00,
         8.5145e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 9.8831e-01, 0.0000e+00,
         6.5835e-02, 1.0590e

In [67]:
torch.save(net0.state_dict(),r'data\mlp.params')
clone=net0
clone.load_state_dict(torch.load(r'data\mlp.params'))


<All keys matched successfully>

In [68]:
X0==clone(X)

tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])