# LESSON 31 卷积神经网络 (Convolutional Neural Networks)

### ResNet (Residual Networks)

In [1]:
from torch import nn
import torch

In [2]:
class ResBlock(nn.Module):
    """Residual block computing ``relu(batch_norm(conv3x3(x))) + x``.

    The convolution maps ``n_chans`` channels to ``n_chans`` channels with a
    3x3 kernel and padding 1, so its output has the same shape as the input
    and can be added to it directly. The conv has no bias term; the
    BatchNorm that follows supplies the learnable shift.

    Args:
        n_chans: number of input (and output) channels.
    """

    def __init__(self, n_chans):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=n_chans,
            out_channels=n_chans,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        self.batch_normal = nn.BatchNorm2d(num_features=n_chans)

        # He (Kaiming) normal init, tuned for the ReLU applied after this conv.
        nn.init.kaiming_normal_(self.conv.weight, nonlinearity='relu')
        # BatchNorm starts with scale 0.5 and shift 0.
        nn.init.constant_(self.batch_normal.weight, 0.5)
        nn.init.zeros_(self.batch_normal.bias)

    def forward(self, x):
        """Return ``x`` plus the ReLU of the batch-normalized convolution of ``x``."""
        residual = torch.relu(self.batch_normal(self.conv(x)))
        return residual + x

In [3]:
# `torch.nn.functional` (not `torch.functional`) provides max_pool2d.
import torch.nn.functional as F

class NetResDeep(nn.Module):
    """Deep residual classifier built from a stem conv, ``ResBlock``s and two linear layers.

    Pipeline: Conv2d(3 -> n_chans, 3x3, padding=1) -> ReLU -> 2x2 max-pool
    -> ``n_blocks`` ResBlocks -> 2x2 max-pool -> flatten
    -> Linear(8*8*n_chans -> 32) -> ReLU -> Linear(32 -> 10).

    The 3x3/padding=1 stem conv preserves the spatial size and each pool
    halves it, so ``fc1``'s input size of ``8 * 8 * n_chans`` means the
    network expects 3-channel 32x32 inputs (presumably CIFAR-10-style
    images; confirm against the training code).

    Args:
        n_chans: channel count of the stem conv and of every ResBlock.
        n_blocks: number of ResBlocks stacked in ``self.resblocks``.
    """

    def __init__(self, n_chans=32, n_blocks=10):
        super().__init__()

        self.n_chans = n_chans
        self.conv = nn.Conv2d(3, n_chans, kernel_size=3, padding=1)

        self.resblocks = nn.Sequential(
            *[ResBlock(n_chans=n_chans) for _ in range(n_blocks)]
        )

        self.fc1 = nn.Linear(8 * 8 * n_chans, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        """Return raw (un-normalized) 10-way scores for a batch ``x``.

        ``x`` is expected to be ``(batch, 3, 32, 32)``; the result is
        ``(batch, 10)``. No softmax is applied.
        """
        out = F.max_pool2d(torch.relu(self.conv(x)), 2)
        out = self.resblocks(out)
        out = F.max_pool2d(out, 2)

        # Flatten per sample, keeping the batch dim. The former
        # view(-1, 8 * 8 * n_chans) could silently regroup values across
        # samples for a wrongly sized input; now fc1 raises a shape error.
        out = torch.flatten(out, 1)
        out = torch.relu(self.fc1(out))

        return self.fc2(out)