**ConvMixer模型结构**
![ConvMixer模型结构图](../add_pic/ConvMixer模型结构.png)

![ConvMixer参考实现代码](../add_pic/ConvMixer实现.png)

In [1]:
import torch.nn as nn

In [2]:
# Residual (skip-connection) wrapper.
class Residual(nn.Module):
    """Wrap a module so that its output is added back onto its input.

    ``Residual(fn)(x)`` evaluates ``fn(x) + x``. Because ``fn`` is stored as an
    attribute of this ``nn.Module``, a module passed in is registered as a
    submodule and its parameters show up in ``self.parameters()``.

    The wrapped ``fn`` must return something that can be added to ``x``
    (e.g. a tensor of the same or a broadcastable shape).
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        # Run the wrapped branch first, then add the identity path.
        branch_out = self.fn(x)
        return branch_out + x

In [3]:
# ConvMixer model definition.
def ConvMixer(dim, depth, kernel_size=9, patch_size=7, n_classes=1000,
              in_channels=3):
    """Build a ConvMixer image classifier as one flat ``nn.Sequential``.

    Module layout (indices are unchanged from earlier versions of this
    function, so ``state_dict`` keys stay compatible):

    * ``[0:3]`` patch embedding:
      ``Conv2d(in_channels, dim, kernel=patch_size, stride=patch_size)``
      -> GELU -> BatchNorm2d.
    * ``[3:3+depth]`` ``depth`` mixer layers, each an ``nn.Sequential`` of:

      * ``Residual(depthwise Conv2d(kernel_size, padding='same') -> GELU -> BN)``;
      * pointwise 1x1 ``Conv2d`` -> GELU -> BN.
    * last three: ``AdaptiveAvgPool2d((1, 1))`` -> ``Flatten`` -> ``Linear(dim, n_classes)``.

    Args:
        dim: Number of hidden channels used throughout the network.
        depth: Number of mixer layers.
        kernel_size: Kernel size of the depthwise convolution in each layer.
        patch_size: Kernel size and stride of the stem convolution. H and W are
            downsampled by this factor.
        n_classes: Number of output logits.
        in_channels: Number of channels in the input image. Defaults to 3,
            the value that was previously hard-coded.

    Returns:
        An ``nn.Sequential`` that maps a 4-D batch ``(N, in_channels, H, W)``
        to logits of shape ``(N, n_classes)``.
    """
    return nn.Sequential(
        # Patch embedding: non-overlapping patch_size x patch_size patches.
        nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(dim),
        # `depth` mixer layers, each its own nn.Sequential.
        *[
            nn.Sequential(
                # Spatial mixing: depthwise conv (groups=dim) wrapped in a skip
                # connection. padding='same' keeps H and W unchanged, so the
                # residual addition lines up.
                Residual(nn.Sequential(
                    nn.Conv2d(dim, dim, kernel_size, groups=dim, padding='same'),
                    nn.GELU(),
                    nn.BatchNorm2d(dim),
                )),
                # Channel mixing: pointwise 1x1 conv (no skip connection here).
                nn.Conv2d(dim, dim, kernel_size=1),
                nn.GELU(),
                nn.BatchNorm2d(dim),
            )
            for _ in range(depth)
        ],
        # Classification head (top level, not part of the mixer layers):
        # global average pool -> flatten -> linear logits.
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(dim, n_classes),
    )

In [4]:
# Smoke test: build a ConvMixer with 512 channels and 8 mixer layers,
# then print its module tree.
model = ConvMixer(depth=8, dim=512)
print(model)

Sequential(
  (0): Conv2d(3, 512, kernel_size=(7, 7), stride=(7, 7))
  (1): GELU(approximate='none')
  (2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Sequential(
    (0): Residual(
      (fn): Sequential(
        (0): Conv2d(512, 512, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=512)
        (1): GELU(approximate='none')
        (2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
    (2): GELU(approximate='none')
    (3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (4): Sequential(
    (0): Residual(
      (fn): Sequential(
        (0): Conv2d(512, 512, kernel_size=(9, 9), stride=(1, 1), padding=same, groups=512)
        (1): GELU(approximate='none')
        (2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): Conv2d(512, 512