In [5]:
import torch
import torch.nn.functional as F
import torchvision

from einops.layers.torch import Rearrange, Reduce
from random import randrange
from torch import nn, einsum

In [6]:
def exists(val):
    """Return True when ``val`` is anything other than ``None``.

    Falsy values such as ``0`` or ``''`` still count as present.
    """
    return not (val is None)

def pair(val):
    """Expand a scalar into ``(val, val)``; a ``tuple`` is returned unchanged.

    Only real ``tuple`` instances pass through: a list such as ``[2, 3]`` is
    itself duplicated into ``([2, 3], [2, 3])``.
    """
    if isinstance(val, tuple):
        return val
    return (val, val)

def dropout_layers(layers, prob_survival):
    """Randomly drop whole layers (stochastic depth), always keeping at least one.

    Each layer is kept independently when its uniform draw in [0, 1) is
    ``<= prob_survival``. If every layer would be dropped, one uniformly chosen
    layer is kept instead.

    Args:
        layers: sized, iterable collection of layers (e.g. an ``nn.ModuleList``).
        prob_survival: probability of keeping each individual layer.

    Returns:
        ``layers`` itself when ``prob_survival == 1``; otherwise a new list of the
        surviving layers in their original order (``[]`` if ``layers`` is empty).
    """
    if prob_survival == 1:
        return layers

    num_layers = len(layers)
    # Without this guard, an empty input reaches randrange(0), which raises ValueError.
    if num_layers == 0:
        return []

    # A layer is dropped when its uniform draw exceeds the survival probability.
    to_drop = torch.zeros(num_layers).uniform_(0., 1.) > prob_survival

    # Never drop everything: rescue one uniformly chosen layer.
    if to_drop.all():
        to_drop[randrange(num_layers)] = False

    return [layer for layer, drop in zip(layers, to_drop) if not drop]

class Residual(nn.Module):
    """Wrap ``fn`` with an identity skip connection, computing ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        # Sub-network whose output is added back onto its own input.
        self.fn = fn

    def forward(self, x):
        branch = self.fn(x)
        return branch + x
    
class PreNorm(nn.Module):
    """Apply ``LayerNorm(dim)`` to the input, then call ``fn`` on the result.

    Keyword arguments given to ``forward`` are forwarded to ``fn`` unchanged.
    """

    def __init__(self, dim, fn):
        super().__init__()
        # Registration order (fn before norm) determines the state_dict key order.
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
    
class Attention(nn.Module):
    """Single-head scaled dot-product self-attention, optionally causal.

    In this notebook, gMLPBlock uses it as a small attention branch whose output
    is added to the spatial gate.

    Args:
        dim_in: feature size of the input ``x``.
        dim_out: feature size of the returned tensor.
        dim_inner: size of the query / key / value projections.
        casual: if True, apply a causal mask so position ``i`` attends only to
            positions ``j <= i`` (the ``casual`` spelling is kept for API
            compatibility with the other blocks).
    """

    def __init__(self, dim_in, dim_out, dim_inner, casual=False):
        super().__init__()
        self.scale = dim_inner ** -0.5
        self.casual = casual

        self.to_qkv = nn.Linear(dim_in, dim_inner * 3, bias=False)
        self.to_out = nn.Linear(dim_inner, dim_out)

    def forward(self, x):
        """Attend over the sequence axis of ``x``.

        ``x`` is indexed as (batch, seq, dim_in) by the einsum patterns below;
        the result has shape (batch, seq, dim_out).
        """
        # Project once, then split the last axis into query, key and value.
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        # einsum: batched q @ k^T, scaled by 1/sqrt(dim_inner).
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if self.casual:
            # Strict upper triangle (j > i) marks the "future" positions.
            mask = torch.ones(sim.shape[-2:], device=x.device).triu(1).bool()
            # Fill masked scores with the most negative finite value of the
            # dtype so softmax assigns them (effectively) zero weight.
            sim.masked_fill_(mask[None, ...], -torch.finfo(q.dtype).max)

        attn = sim.softmax(dim=-1)
        out = einsum('b i j, b j d -> b i d', attn, v)
        return self.to_out(out)
    
class SpatialGatingUnit(nn.Module):
    """gMLP spatial gating unit.

    The input's last axis is split in half into ``res`` and ``gate``. ``gate`` is
    layer-normalised, then mixed across the sequence axis by a kernel-size-1
    ``Conv1d`` that treats sequence positions as channels. The optional
    ``gate_res`` is added to it, and it is passed through ``act``. The result
    multiplies ``res`` element-wise.

    Args:
        dim: size of the input's last axis; it is split into two halves, so it
            is presumably even.
        dim_seq: sequence length the projection is built for.
        casual: if True, zero ``weight[i, j]`` for ``j > i`` so position ``i``
            only mixes positions ``j <= i``.
        act: activation applied to the gate.
        init_eps: projection weights are drawn from
            ``U(-init_eps / dim_seq, init_eps / dim_seq)``; biases start at 1.
    """

    def __init__(self, dim, dim_seq, casual = False, act=nn.Identity(), init_eps=1e-3):
        super().__init__()
        half = dim // 2
        self.casual = casual

        self.norm = nn.LayerNorm(half)
        self.proj = nn.Conv1d(dim_seq, dim_seq, 1)
        self.act = act

        # Near-zero weights plus a unit bias start the gate close to 1, so the
        # unit initially passes ``res`` through almost unchanged.
        bound = init_eps / dim_seq
        nn.init.uniform_(self.proj.weight, -bound, bound)
        nn.init.constant_(self.proj.bias, 1.)

    def forward(self, x, gate_res=None):
        seq_len = x.shape[1]
        res, gate = x.chunk(2, dim=-1)
        gate = self.norm(gate)

        weight = self.proj.weight
        bias = self.proj.bias
        if self.casual:
            # Crop to the actual sequence length, then zero weight[i, j] for j > i.
            weight = weight[:seq_len, :seq_len]
            bias = bias[:seq_len]
            future = torch.ones(weight.shape[:2], device=x.device).triu_(1).bool()
            weight = weight.masked_fill(future[..., None], 0.)

        gate = F.conv1d(gate, weight, bias)
        if gate_res is not None:
            gate = gate + gate_res

        return self.act(gate) * res
    
class gMLPBlock(nn.Module):
    """One gMLP block: channel expansion + GELU, spatial gating, projection back.

    When ``attn_dim`` is given, an ``Attention`` branch computed from the block
    input is added to the spatial gate as ``gate_res``.

    Keyword-only args:
        dim: feature size of the input and output.
        dim_ff: hidden size after ``proj_in``; the gating unit halves it.
        seq_len: sequence length handed to ``SpatialGatingUnit``.
        attn_dim: inner size of the optional attention branch, or ``None``.
        casual: causal-masking flag forwarded to attention and gating.
        act: activation applied to the gate.
    """

    def __init__(self, *, dim, dim_ff, seq_len, attn_dim=None, casual=False, act=nn.Identity()):
        super().__init__()
        half_ff = dim_ff // 2
        self.proj_in = nn.Sequential(nn.Linear(dim, dim_ff), nn.GELU())

        if attn_dim is None:
            self.attn = None
        else:
            self.attn = Attention(dim, half_ff, attn_dim, casual)

        self.sgu = SpatialGatingUnit(dim_ff, seq_len, casual, act)
        self.proj_out = nn.Linear(half_ff, dim)

    def forward(self, x):
        # The attention branch reads the raw block input, before proj_in.
        gate_res = None if self.attn is None else self.attn(x)
        hidden = self.sgu(self.proj_in(x), gate_res=gate_res)
        return self.proj_out(hidden)

class gMLPVision(nn.Module):
    """gMLP image classifier.

    The model works in four stages:
    - Cut the image into non-overlapping patches.
    - Linearly embed each patch to ``dim`` features.
    - Apply ``depth`` residual pre-norm gMLP blocks.
    - Map the layer-normed, mean-pooled patch features to ``num_classes`` logits.

    In training mode, whole blocks are skipped at random (stochastic depth) with
    per-block survival probability ``prob_survival``. ``dropout_layers`` keeps at
    least one block.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, ff_mult=4, channels=3, attn_dim =None, prob_survival = 1.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        divisible = image_height % patch_height == 0 and image_width % patch_width == 0
        assert divisible, 'image height and width must be divisible by patch size'
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_width * patch_height
        dim_ff = dim * ff_mult

        self.to_patch_embed = nn.Sequential(
            # (b, c, H, W) -> (b, num_patches, c * p1 * p2)
            Rearrange('b c (h  p1) (w p2 ) -> b (h w) (c p1 p2)', p1=patch_height, p2=patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.prob_survival = prob_survival

        blocks = []
        for _ in range(depth):
            block = gMLPBlock(dim=dim, dim_ff=dim_ff, seq_len=num_patches, attn_dim=attn_dim)
            blocks.append(Residual(PreNorm(dim, block)))
        self.layers = nn.ModuleList(blocks)

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            Reduce('b n d -> b d', 'mean'),
            nn.Linear(dim, num_classes),
        )

    def forward(self, x):
        x = self.to_patch_embed(x)
        if self.training:
            active_layers = dropout_layers(self.layers, self.prob_survival)
        else:
            active_layers = self.layers
        for layer in active_layers:
            x = layer(x)
        return self.to_logits(x)
    

In [4]:
# Smoke test: run one random 3x256x256 image through the classifier.
model = gMLPVision(
    image_size = 256,
    patch_size = 16,
    num_classes = 100,
    dim = 512,
    depth = 6
)

img = torch.randn(1, 3, 256, 256)
logits = model(img)  # (1, 100): one row of num_classes logits
assert logits.shape == (1, 100)

In [7]:
batch_size_train = 64    # small batch size for training
batch_size_test = 1024   # larger batch size for evaluation

# Commonly quoted mean / std of MNIST training-set pixel intensities (one channel).
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

# NOTE(review): gMLPVision above is built for 3x256x256 images, whereas MNIST
# images are 1x28x28. Build the model with e.g. image_size=28, patch_size=7,
# channels=1 before training it on this data.

# Convert images to tensors in [0, 1], then standardize with the stats above.
image_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

# Image datasets. If the download fails (e.g. HTTP 503 from yann.lecun.com),
# newer torchvision releases fall back to mirror URLs.
train_dataset = torchvision.datasets.MNIST('dataset/',
                                           train=True,
                                           download=True,
                                           transform=image_transform)
test_dataset = torchvision.datasets.MNIST('dataset/',
                                          train=False,
                                          download=True,
                                          transform=image_transform)

# Data loaders: shuffle only the training data; a fixed evaluation order keeps
# test runs reproducible.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size_train,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batch_size_test,
                                          shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HTTPError: HTTP Error 503: Service Unavailable

In [8]:
# Layer Normalization notes:
# - Each tensor (sample) in the mini-batch is normalized independently; channels
#   are not distinguished from one another.
# - Normalization methods differ in what they normalize over: the mini-batch
#   unit versus the channel unit.
# - In this notebook, nn.LayerNorm(dim) normalizes over the last (feature) axis.
# https://qiita.com/amateur2020/items/f2c829677d9764af0b50

In [None]:
# Sanity check of torch.triu and torch.finfo, both used by the causal masks above.
# A private, seeded generator makes the printout reproducible without touching
# the global RNG state used elsewhere in the notebook.
gen = torch.Generator().manual_seed(0)
a = torch.randn(3, 3, generator=gen)
print(a)
print(torch.triu(a))             # entries below the main diagonal become 0
print(a.dtype)
print(torch.finfo(a.dtype))
print(torch.finfo(a.dtype).max)  # largest finite value; negated to mask scores

In [4]:
# Strictly upper-triangular boolean mask, built like the causal masks in
# Attention / SpatialGatingUnit, shown with the two broadcasting views they use.
mask = torch.triu(torch.ones((3, 3), device='cpu'), diagonal=1).bool()
print(mask)
print(mask.unsqueeze(0))   # same as mask[None, ...] -> shape (1, 3, 3)
print(mask.unsqueeze(-1))  # same as mask[..., None] -> shape (3, 3, 1)

tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])
tensor([[[False,  True,  True],
         [False, False,  True],
         [False, False, False]]])
tensor([[[False],
         [ True],
         [ True]],

        [[False],
         [False],
         [ True]],

        [[False],
         [False],
         [False]]])


In [5]:
# Demonstrates an assert with a message. The failure is caught and reported so
# that "Restart & Run All" is not halted by this cell.
try:
    assert 3 == 4, 'not equal'
except AssertionError as err:
    print(f'AssertionError: {err}')

AssertionError: not equal