# Task 05: Ways to Define a Model in PyTorch
---

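This section covers:

- The three ways to define a model in PyTorch: Sequential, ModuleList, and ModuleDict
- Reading the wide variety of coding styles found on GitHub
- Choosing the model definition style that fits your needs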



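## 5.1 `Module` recap

- `Module` (`nn.Module`) is the model-construction class provided by `torch.nn`. It is the base class of all neural network modules, and we subclass it to define the models we want.
- A PyTorch model definition has two main parts: initializing the components (`__init__`) and defining the data flow (`forward`).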
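
The two parts listed above can be sketched with a minimal subclass. `TinyMLP` and its layer sizes are hypothetical names and values chosen only for this illustration; they are not part of the original material.

```python
import torch
import torch.nn as nn


class TinyMLP(nn.Module):
    """Hypothetical two-layer perceptron used only for illustration."""

    def __init__(self, in_features=8, hidden=16, out_features=2):
        super().__init__()
        # Part 1: create the layers, which also registers their parameters.
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = nn.ReLU()
        self.fc2 = nn.Linear(hidden, out_features)

    def forward(self, x):
        # Part 2: describe how data flows through the layers.
        return self.fc2(self.act(self.fc1(x)))


model = TinyMLP()
out = model(torch.randn(4, 8))  # calling the module runs forward()
print(out.shape)
```

Calling `model(...)` rather than `model.forward(...)` is the usual convention, since the call also runs any registered hooks.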


## 5.2 `nn.Module` containers: Sequential

`nn.Sequential` wraps a group of layers and runs them **in order**.

```python
import torch
import torch.nn as nn
from collections import OrderedDict


class LeNetSequential(nn.Module):
    """LeNet built from unnamed nn.Sequential blocks (layers indexed 0, 1, 2, ...)."""

    def __init__(self, classes):
        super(LeNetSequential, self).__init__()
        # Convolutional feature extractor.
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Fully connected classifier.
        self.classifier = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)  # flatten to (batch, 16 * 5 * 5)
        x = self.classifier(x)
        return x


class LeNetSequentialOrderDict(nn.Module):
    """Same LeNet, but each layer gets a name through an OrderedDict."""

    def __init__(self, classes):
        super(LeNetSequentialOrderDict, self).__init__()
        self.features = nn.Sequential(OrderedDict({
            'conv1': nn.Conv2d(3, 6, 5),
            'relu1': nn.ReLU(inplace=True),
            'pool1': nn.MaxPool2d(kernel_size=2, stride=2),
            'conv2': nn.Conv2d(6, 16, 5),
            'relu2': nn.ReLU(inplace=True),
            'pool2': nn.MaxPool2d(kernel_size=2, stride=2),
        }))

        self.classifier = nn.Sequential(OrderedDict({
            'fc1': nn.Linear(16 * 5 * 5, 120),
            'relu3': nn.ReLU(),
            'fc2': nn.Linear(120, 84),
            'relu4': nn.ReLU(inplace=True),
            'fc3': nn.Linear(84, classes),
        }))

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size()[0], -1)  # flatten to (batch, 16 * 5 * 5)
        x = self.classifier(x)
        return x
```



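```python
# Instantiate either variant. The second assignment overrides the first,
# so the OrderedDict version is the one used below.
net = LeNetSequential(classes=2)
net = LeNetSequentialOrderDict(classes=2)

# A batch of 4 fake RGB images of size 32x32.
fake_img = torch.randn((4, 3, 32, 32), dtype=torch.float32)

output = net(fake_img)

print(net)
print(output)
```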
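
The two LeNet variants differ only in how their layers are addressed. The sketch below illustrates that difference on a small block; the block itself and its layer names are hypothetical and chosen only for this example.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Hypothetical small block; the layer names are chosen only for this example.
block = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(3, 6, 5)),
    ('relu1', nn.ReLU()),
    ('pool1', nn.MaxPool2d(kernel_size=2, stride=2)),
]))

# Named layers are reachable as attributes; every layer is also indexable.
print(block.conv1)
print(block[0] is block.conv1)

# named_children() yields (name, module) pairs in execution order.
names = [name for name, _ in block.named_children()]
print(names)

x = torch.randn(4, 3, 32, 32)
y = block(x)
print(y.shape)
```

Named layers tend to make `print(net)` output and `state_dict()` keys easier to read, which is one reason to prefer the `OrderedDict` form for larger models.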

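## 5.3 `nn.Module` containers: ModuleList

`nn.ModuleList` wraps a group of layers so that they can be called by **iterating** over them.

Main methods:
- `append()`: adds a layer to the end of the ModuleList
- `extend()`: appends the layers of another ModuleList (or any iterable of modules)
- `insert()`: **inserts** a layer at a given position in the ModuleList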
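```python
class ModuleList(nn.Module):
    def __init__(self):
        super(ModuleList, self).__init__()
        # 20 identical linear layers, built with a list comprehension.
        self.linears = nn.ModuleList([nn.Linear(10, 10) for _ in range(20)])

    def forward(self, x):
        # Apply the layers one after another by iterating over the container.
        for linear in self.linears:
            x = linear(x)
        return x
```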
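
One reason to use `nn.ModuleList` instead of a plain Python list is parameter registration: layers held in a plain list are not visible to `parameters()`, so an optimizer would not update them. The two classes below are hypothetical and exist only to show the difference.

```python
import torch.nn as nn


class WithPlainList(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Python list is not tracked by nn.Module.
        self.linears = [nn.Linear(10, 10) for _ in range(3)]


class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers each layer as a submodule.
        self.linears = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])


n_plain = len(list(WithPlainList().parameters()))
n_registered = len(list(WithModuleList().parameters()))
print(n_plain, n_registered)  # 0 6 (a weight and a bias for each of 3 layers)
```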

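## 5.4 `nn.Module` containers: ModuleDict

`nn.ModuleDict` wraps a group of layers so that they can be called by **key lookup**.

Main methods:

- `clear()`: removes all entries from the ModuleDict
- `items()`: returns an iterable of key-value pairs
- `keys()`: returns the keys of the dictionary
- `values()`: returns the values of the dictionary
- `pop()`: removes the given key from the dictionary and returns its module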
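
A short sketch of the methods listed above, using a hypothetical dictionary of activation functions. Item assignment (`acts['tanh'] = ...`) is not in the list but is included here because it is the usual way to add an entry.

```python
import torch.nn as nn

acts = nn.ModuleDict({
    'prelu': nn.PReLU(),
    'relu': nn.ReLU(),
})
acts['tanh'] = nn.Tanh()  # item assignment registers a new entry

print(list(acts.keys()))
for name, module in acts.items():
    print(name, type(module).__name__)

removed = acts.pop('prelu')  # removes the key and returns its module
still_there = 'prelu' in acts
print(type(removed).__name__, still_there)

remaining = sorted(acts.keys())
acts.clear()                 # empties the container
print(len(acts))
```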


```python
class ModuleDict(nn.Module):
    def __init__(self):
        super(ModuleDict, self).__init__()
        # Interchangeable layers, selected by key in forward().
        self.choices = nn.ModuleDict({
            'conv': nn.Conv2d(10, 10, 3),
            'pool': nn.MaxPool2d(3)
        })
        # Interchangeable activation functions, selected by key in forward().
        self.activations = nn.ModuleDict({
            'relu': nn.ReLU(),
            'prelu': nn.PReLU()
        })

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x


net = ModuleDict()
fake_img = torch.randn((4, 10, 32, 32))
output = net(fake_img, 'conv', 'relu')
print(output)
```


Output (truncated):

```text
tensor([[[[0.0000, 0.0000, 0.6661,  ..., 0.0000, 0.4233, 0.0715],
          [0.0000, 0.0000, 0.1219,  ..., 0.6663, 0.0000, 0.0419],
          [0.0000, 0.0000, 0.2731,  ..., 0.5060, 0.0000, 0.1619],
          ...,
          [0.0000, 0.6161, 0.0000,  ..., 0.0000, 0.1367, 0.5291],
          [0.0000, 0.8069, 0.0000,  ..., 0.0000, 0.1007, 0.5369],
          [0.0143, 0.5083, 0.2260,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.6459, 0.0000, 0.0000,  ..., 0.0000, 0.5090, 0.0000],
          [0.0548, 0.2861, 0.5799,  ..., 1.3627, 0.0000, 0.0000],
          [0.7538, 0.0000, 0.3612,  ..., 0.5384, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.2894,  ..., 0.0000, 0.0000, 0.2972],
          [0.0000, 0.3672, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 1.4488, 0.4965,  ..., 0.0000, 0.0000, 0.6022]],

         [[0.1347, 0.1291, 0.0000,  ..., 0.2189, 0.0000, 0.7473],
          [0.6860, 0.5748, 0.0000,  ..., 0.0000, 0.1640, 0.2840],
          [0.0000, 0.3390, 0.0000,  ..., 0
```

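## 5.5 Summary of the three containers

- `nn.Sequential`: **ordered**; the layers run strictly in sequence, commonly used to build blocks
- `nn.ModuleList`: **iterable**; commonly used to build many repeated layers with a for loop
- `nn.ModuleDict`: **indexable by key**; commonly used for selectable layers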
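
The three roles can appear together in one model. The following is a minimal sketch under assumed, arbitrary sizes; `TinyNet` and its arguments are hypothetical and only illustrate where each container tends to fit.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical model combining the three containers; sizes are arbitrary."""

    def __init__(self, width=16, depth=3, classes=2):
        super().__init__()
        # Sequential: a fixed block that always runs in the same order.
        self.stem = nn.Sequential(nn.Linear(8, width), nn.ReLU())
        # ModuleList: many identical layers built in a loop.
        self.body = nn.ModuleList([nn.Linear(width, width) for _ in range(depth)])
        # ModuleDict: interchangeable layers picked by name at call time.
        self.activations = nn.ModuleDict({'relu': nn.ReLU(), 'tanh': nn.Tanh()})
        self.head = nn.Linear(width, classes)

    def forward(self, x, act='relu'):
        x = self.stem(x)
        for layer in self.body:
            x = self.activations[act](layer(x))
        return self.head(x)


net = TinyNet()
x = torch.randn(4, 8)
out_relu = net(x, act='relu')
out_tanh = net(x, act='tanh')
print(out_relu.shape, out_tanh.shape)
```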