In [11]:
import torch.nn as nn
import torch
from collections import OrderedDict

In [12]:
class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.featrues = nn.Sequential( # 内部实现了forward函数;各网络层顺序执行
            nn.Conv2d(1, 6, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(16*6*6, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        x = self.featrues(x)
        x = x.reshape(x.size()[0], -1)
        x = self.classifier(x)
        return x

In [13]:
net = LeNet()
net

LeNet(
  (featrues): Sequential(
    (0): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=576, out_features=120, bias=True)
    (1): ReLU()
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): ReLU()
    (4): Linear(in_features=84, out_features=10, bias=True)
  )
)

In [14]:
print(net.featrues[0]) # 通过网络层顺序进行索引或切片
print(type(net.featrues[0]))
print(net.featrues[0:3])

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
<class 'torch.nn.modules.conv.Conv2d'>
Sequential(
  (0): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (1): ReLU()
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)


In [15]:
img = torch.randn((1, 1, 32, 32), dtype=torch.float32)
ret = net(img) # 调用forward方法(通过__call__方法),继承自nn.Module
ret

tensor([[ 0.0874, -0.0706,  0.0642,  0.2153, -0.0228,  0.0939,  0.0963,  0.1423,
          0.0593,  0.0394]], grad_fn=<AddmmBackward>)

In [16]:
class LeNet1(LeNet):
    def __init__(self):
        super(LeNet1, self).__init__()
        self.featrues = nn.Sequential(OrderedDict({ # 有序字典
            'conv1': nn.Conv2d(1, 6, 3), # 通过有序字典指定各模块的名称
            'relu1': nn.ReLU(),
            'pool1': nn.MaxPool2d(2, 2),
            'conv2': nn.Conv2d(6, 16, 3),
            'relu2': nn.ReLU(),
            'pool2': nn.MaxPool2d(2, 2)
        }))
        self.classifier = nn.Sequential(OrderedDict({
            'line1': nn.Linear(16*6*6, 120),
            'relu1': nn.ReLU(),
            'line2': nn.Linear(120, 84),
            'relu2': nn.ReLU(),
            'line3': nn.Linear(84, 10),
        }))

    def forward(self, x):
        x = self.featrues(x)
        x = x.reshape(x.size()[0], -1)
        x = self.classifier(x)
        return x

In [17]:
net1 = LeNet1()
net1

LeNet1(
  (featrues): Sequential(
    (conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
    (relu1): ReLU()
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
    (relu2): ReLU()
    (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (line1): Linear(in_features=576, out_features=120, bias=True)
    (relu1): ReLU()
    (line2): Linear(in_features=120, out_features=84, bias=True)
    (relu2): ReLU()
    (line3): Linear(in_features=84, out_features=10, bias=True)
  )
)

In [18]:
print(net1.featrues[0]) # 通过网络层顺序进行索引或切片
print(net.featrues[0:4])
print(net1.featrues.conv2) # 还可以通过名称(即有序字典的key)进行索引

Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
Sequential(
  (0): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (1): ReLU()
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
)
Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))


In [19]:
img1 = torch.randn((1, 1, 32, 32), dtype=torch.float32)
ret1 = net1(img)
ret1

tensor([[-0.0172, -0.1413,  0.1419, -0.0052,  0.0612, -0.0540, -0.0895,  0.0453,
         -0.0634, -0.0467]], grad_fn=<AddmmBackward>)