In [1]:
# Move into src/ so the project-local `models` package imported below can be found.
%cd ../src

C:\Users\nozoe-tatsuya\dev\ai-ocr-ensemble\src


In [2]:
import torch
import torch.nn as nn

from functools import partial

from models.gmlp import GMLP, GatedMLPBlock, GatedMLPCore, SpatialGatingUnit

In [3]:
# Configuration for the shape checks below; small sizes keep every forward pass cheap.
SEED = 0
# The inputs below come from torch.randn and the modules use dropout / drop-path,
# so seed the global RNG to make Restart-&-Run-All reproducible.
torch.manual_seed(SEED)

N = 32                 # batch size (leading dim of every input tensor below)
in_channels = 3
num_classes = 17
image_size = 32
patch_size = 16
# num_patches uses floor division; guard against a silently truncated patch count.
assert image_size % patch_size == 0, "image_size must be divisible by patch_size"
num_patches = (image_size // patch_size) ** 2   # patch tokens per image
W = H = image_size
d_model = 128          # token width entering / leaving each block
mlp_ratio = 6
d_ffn = d_model * mlp_ratio  # hidden width fed to the spatial gating unit

## Spatial Gating Unit
```
    in - split - norm - proj - hadamard - out
           └──────────────────────┘
```

In [4]:
# Random SGU input: one d_ffn-wide vector per patch token, for each of N samples.
x = torch.randn((N, num_patches, d_ffn))
x.size()

torch.Size([32, 4, 768])

In [5]:
# Spatial gating unit sized for the d_ffn-wide hidden activations of num_patches tokens.
sgu = SpatialGatingUnit(
    dim=d_ffn,
    num_patches=num_patches,
    norm_layer=nn.LayerNorm,
)

In [6]:
# Shape check only: skip building the autograd graph.
with torch.no_grad():
    y = sgu(x)

In [7]:
# The SGU splits its input channels, so the last dim shrinks from d_ffn to d_ffn // 2.
assert y.shape == (N, num_patches, d_ffn // 2)
y.shape

torch.Size([32, 4, 384])

## gMLP core
```
    in - proj - gelu - dropout - SGU - proj - dropout - out

(gMLP block, next section: the core wrapped in a residual connection)
    in - norm - gMLP core - dropout - add - out
     └─────────────────────────────────┘
```

In [8]:
# Random token sequence at model width d_model, the input to the gMLP core.
x = torch.randn((N, num_patches, d_model))
x.size()

torch.Size([32, 4, 128])

In [9]:
# gMLP core with dropout 0.2. The SGU is passed as a partial with num_patches bound;
# presumably the core instantiates it with its own channel width (TODO confirm in models.gmlp).
gmlp_core = GatedMLPCore(
    d_model,
    d_ffn,
    activation=nn.GELU,
    spatial_gating_unit=partial(SpatialGatingUnit, num_patches=num_patches),
    p_dropout=0.2,
)

In [10]:
# Shape check only: skip building the autograd graph.
with torch.no_grad():
    y = gmlp_core(x)

In [11]:
# The core's final projection returns to d_model, so output shape equals input shape.
assert y.shape == (N, num_patches, d_model)
y.shape

torch.Size([32, 4, 128])

## gMLP block
```
gMLP block
    in - norm - gMLP core - dropout - add - out
     └─────────────────────────────────┘
```

In [12]:
# Fresh random token sequence at model width d_model, the input to one gMLP block.
x = torch.randn((N, num_patches, d_model))
x.size()

torch.Size([32, 4, 128])

In [13]:
# One gMLP block: hidden width comes from mlp_ratio; dropout 0.2, drop-path 0.1.
gmlp_block = GatedMLPBlock(
    d_model,
    num_patches,
    mlp_ratio=mlp_ratio,
    norm_layer=nn.LayerNorm,
    activation=nn.GELU,
    p_dropout=0.2,
    p_droppath=0.1,
)

In [14]:
# Shape check only: skip building the autograd graph.
with torch.no_grad():
    y = gmlp_block(x)

In [15]:
# Residual block: output must have the same shape as its input.
assert y.shape == (N, num_patches, d_model)
y.shape

torch.Size([32, 4, 128])

## gMLP
```
    in - patch_embed - gMLP_block * num_blocks - norm - GAP - head - out
```

In [16]:
# Random image batch laid out as (N, in_channels, H, W).
x = torch.randn((N, in_channels, H, W))
x.size()

torch.Size([32, 3, 32, 32])

In [17]:
# Full model: 30 stacked gMLP blocks with dropout and drop-path both disabled.
gmlp = GMLP(
    img_size=image_size,
    in_channels=in_channels,
    num_classes=num_classes,
    patch_size=patch_size,
    num_blocks=30,
    d_model=d_model,
    mlp_ratio=mlp_ratio,
    p_dropout=0.0,
    p_droppath=0.0,
)

In [18]:
# Shape check only: no autograd graph needed (it would retain activations for all 30 blocks).
with torch.no_grad():
    y = gmlp(x)

In [19]:
# The head emits one num_classes-length vector per image in the batch.
assert y.shape == (N, num_classes)
y.shape

torch.Size([32, 17])