<a href="https://colab.research.google.com/github/ganesh3/pytorch-work/blob/master/MLPMixer_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [32]:
import torch
import torch.nn as nn
import einops

In [33]:
class MlpBlock(nn.Module):
  """Two-layer perceptron: Linear -> GELU -> Linear.

  The block maps the last dimension ``dim -> mlp_dim -> dim``, so the output
  has exactly the same shape as the input.

  Parameters
  -----------
  dim : int
      Input and output dimension of the entire block. Inside of the mixer it will
      either be equal to 'n_patches' (token mixing) or 'hidden_dim' (channel mixing).

  mlp_dim : int or None, optional
      Dimension of the hidden layer. If None (the default), the hidden layer has
      the same width as 'dim'.

  Attributes
  -----------
  linear_1, linear_2 : nn.Linear
      Linear layers ('dim -> mlp_dim' and 'mlp_dim -> dim').

  activation : nn.GELU
      Non-linearity applied between the two linear layers.
  """

  def __init__(self, dim, mlp_dim=None):
    super().__init__()

    # The None check below was previously unreachable without passing None
    # explicitly; the default now makes "no expansion" the natural fallback.
    mlp_dim = dim if mlp_dim is None else mlp_dim
    self.linear_1 = nn.Linear(dim, mlp_dim)
    self.activation = nn.GELU()
    self.linear_2 = nn.Linear(mlp_dim, dim)

  def forward(self, x):
    """Run the forward pass

    Parameters
    -----------
    x : torch.Tensor
      Input tensor whose last dimension is 'dim', e.g.
      '(n_samples, n_channels, n_patches)' or '(n_samples, n_patches, n_channels)'.
      nn.Linear acts on the last axis, so any number of leading dimensions works.

    Returns
    --------
    torch.Tensor
      Output tensor with the same shape as the input tensor 'x'.
    """

    x = self.linear_1(x) # (n_samples, *, mlp_dim)
    x = self.activation(x) # (n_samples, *, mlp_dim)
    x = self.linear_2(x) # (n_samples, *, dim)

    return x

In [34]:
class MixerBlock(nn.Module):
  """One Mixer layer: a token-mixing MLP followed by a channel-mixing MLP.

  Each of the two sub-layers is preceded by its own LayerNorm and wrapped in a
  residual connection:

      x = x + token_mlp(norm_1(x)^T)^T
      x = x + channel_mlp(norm_2(x))

  Parameters
  -----------
  n_patches : int
        Number of patches (tokens) the image is split into.

  hidden_dim : int
        Dimensionality of each patch embedding (number of channels).

  tokens_mlp_dim : int
        Hidden width of the MLP that mixes information across patches.

  channels_mlp_dim : int
        Hidden width of the MLP that mixes information across channels.

  Attributes
  -----------
  norm_1, norm_2 : nn.LayerNorm
        Layer normalizations applied before token and channel mixing respectively.

  token_mlp_block : MlpBlock
        Token-mixing MLP (operates along the patch axis).

  channel_mlp_block : MlpBlock
        Channel-mixing MLP (operates along the hidden_dim axis).
  """

  def __init__(self, *, n_patches, hidden_dim, tokens_mlp_dim, channels_mlp_dim):
    super().__init__()

    # Both norms act over hidden_dim, the last axis of the block's input.
    self.norm_1 = nn.LayerNorm(hidden_dim)
    self.norm_2 = nn.LayerNorm(hidden_dim)

    # Construction order is kept (token MLP first) so parameter init draws
    # from the RNG in the same sequence.
    self.token_mlp_block = MlpBlock(n_patches, tokens_mlp_dim)
    self.channel_mlp_block = MlpBlock(hidden_dim, channels_mlp_dim)

  def forward(self, x):
    """Apply token mixing then channel mixing, each with a skip connection.

    Parameters
    -----------
    x : torch.Tensor
      Tensor of shape '(n_samples, n_patches, hidden_dim)'.

    Returns
    --------
    torch.Tensor
      Tensor of the same shape, '(n_samples, n_patches, hidden_dim)'.
    """
    # Token mixing: bring the patch axis last so the MLP mixes across patches.
    normed_tokens = self.norm_1(x).permute(0, 2, 1)          # (n_samples, hidden_dim, n_patches)
    mixed_tokens = self.token_mlp_block(normed_tokens)         # (n_samples, hidden_dim, n_patches)
    x = x + mixed_tokens.permute(0, 2, 1)                      # (n_samples, n_patches, hidden_dim)

    # Channel mixing: hidden_dim is already the last axis.
    mixed_channels = self.channel_mlp_block(self.norm_2(x))    # (n_samples, n_patches, hidden_dim)
    return x + mixed_channels

In [35]:
class MlpMixer(nn.Module):
  """Entire MLP-Mixer classification network.

  Pipeline: patch embedding (strided Conv2d) -> 'n_blocks' x 'MixerBlock'
  -> LayerNorm -> mean over patches -> linear classification head.

  Parameters
  -----------
  image_size : int
    Height and width (assuming it is a square) of the image.

  patch_size : int
    Height and width (assuming it is a square) of the patches. If 'image_size'
    is not a multiple of 'patch_size', the strided convolution skips the
    leftover border pixels; 'n_patches' is computed with floor division, so the
    two stay consistent.

  tokens_mlp_dim : int
    Hidden dimension for the 'MlpBlock' when doing the token mixing.

  channels_mlp_dim : int
    Hidden dimension for the 'MlpBlock' when doing the channel mixing.

  n_classes : int
    Number of classes for classification.

  hidden_dim : int
    Dimensionality of patch embeddings.

  n_blocks : int
    Number of 'MixerBlock's in the architecture.

  n_channels : int, optional
    Number of channels of the input images. Defaults to 3 (e.g. RGB).

  Attributes
  -----------
  patch_embedder : nn.Conv2d
    Splits the image up into non-overlapping patches and embeds each of them
    (using shared weights).

  blocks : nn.ModuleList
    List of MixerBlock instances.

  pre_head_norm : nn.LayerNorm
    LayerNormalization applied just before the classification head.

  head_classifier : nn.Linear
    The classification head.
  """
  def __init__(self, *, image_size, patch_size, tokens_mlp_dim, channels_mlp_dim, n_classes, hidden_dim, n_blocks, n_channels=3):
    super().__init__()

    # Patches per side, squared. Floor division equals the number of output
    # positions of a Conv2d with kernel_size == stride == patch_size.
    n_patches = (image_size // patch_size) ** 2
    # kernel_size == stride: each output position covers exactly one patch.
    self.patch_embedder = nn.Conv2d(n_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
    self.blocks = nn.ModuleList([
      MixerBlock(
        n_patches=n_patches,
        hidden_dim=hidden_dim,
        tokens_mlp_dim=tokens_mlp_dim,
        channels_mlp_dim=channels_mlp_dim,
      )
      for _ in range(n_blocks)
    ])
    self.pre_head_norm = nn.LayerNorm(hidden_dim)
    self.head_classifier = nn.Linear(hidden_dim, n_classes)


  def forward(self, x):
    """Run the forward pass

    Parameters
    -----------
    x : torch.Tensor
      Input batch of square images of shape '(n_samples, n_channels, image_size, image_size)'.

    Returns
    --------
    torch.Tensor
      Class logits of the shape '(n_samples, n_classes)'.
    """
    x = self.patch_embedder(x) # (n_samples, hidden_dim, n_patches ** (1/2), n_patches ** (1/2))
    x = einops.rearrange(x, "n c h w -> n (h w) c") # (n_samples, n_patches, hidden_dim)

    for mixer_block in self.blocks:
      x = mixer_block(x) # (n_samples, n_patches, hidden_dim)

    x = self.pre_head_norm(x) # (n_samples, n_patches, hidden_dim)
    x = x.mean(dim=1) # (n_samples, hidden_dim) - global average over the patch (token) axis
    y = self.head_classifier(x) # (n_samples, n_classes)

    return y




### Token Mixing as a Convolution using Conv1d

In [18]:
# Token mixing expressed as a depthwise (grouped) Conv1d whose filters are shared across channels.
class conv1dDepthWiseShared(nn.Module):
  """Depthwise Conv1d where all input channels share the same k filters.

  With ``kernel_size`` equal to the sequence length, this computes the same
  function as ``nn.Linear(kernel_size, k)`` applied independently to each of
  the ``hidden_dim`` rows of the input.

  Parameters
  -----------
  hidden_dim : int
      Number of input channels; also the number of convolution groups.
  kernel_size : int
      Length of each 1D filter.
  k : int
      Number of shared filters (output features per input channel).

  Attributes
  -----------
  weight_shared : nn.Parameter
      Shared filters of shape '(k, 1, kernel_size)', initialised with torch.rand.
  bias_shared : nn.Parameter
      Shared bias of shape '(k,)', initialised with torch.rand.
  """

  def __init__(self, hidden_dim, kernel_size, k):
    super().__init__()
    self.hidden_dim = hidden_dim
    # The weight is drawn before the bias so RNG consumption order is unchanged.
    initial_filters = torch.rand(k, 1, kernel_size)
    self.weight_shared = nn.Parameter(initial_filters)
    initial_bias = torch.rand(k)
    self.bias_shared = nn.Parameter(initial_bias)

  def forward(self, x):
    """Convolve every channel of 'x' with the shared filters.

    Parameters
    -----------
    x : torch.Tensor
      Tensor of shape '(n_samples, hidden_dim, length)'.

    Returns
    --------
    torch.Tensor
      Tensor of shape '(n_samples, hidden_dim * k, length - kernel_size + 1)';
      output channel 'c * k + j' is filter 'j' applied to input channel 'c'.
    """
    n_groups = self.hidden_dim
    # Tile the shared (k, ...) parameters once per group along dim 0.
    tiled_filters = self.weight_shared.repeat(n_groups, 1, 1)   # (hidden_dim * k, 1, kernel_size)
    tiled_bias = self.bias_shared.repeat(n_groups)              # (hidden_dim * k,)
    return torch.nn.functional.conv1d(
      x,
      weight=tiled_filters,
      bias=tiled_bias,
      groups=n_groups,
    )

In [19]:
# Toy sizes for checking that the shared depthwise Conv1d matches nn.Linear.
n_samples, hidden_dim, n_patches = 2, 16, 25
k = 7  # number of output features of both modules
# Same layout as the permuted token-mixing input in MixerBlock: (n_samples, hidden_dim, n_patches).
x = torch.rand(n_samples, hidden_dim, n_patches)
# kernel_size == n_patches, so each filter spans all patches and yields a single output position.
module_conv = conv1dDepthWiseShared(hidden_dim, n_patches, k)
module_linear = nn.Linear(n_patches, k)

In [23]:
# Trainable parameter counts of both modules; each should be k * n_patches + k (weights + bias).
sum(p.numel() for p in module_conv.parameters() if p.requires_grad), sum(p.numel() for p in module_linear.parameters() if p.requires_grad)

(182, 182)

In [24]:
# Copy the Linear layer's parameters into the shared conv filters so both modules
# compute the same function. Parameters are written in place under torch.no_grad()
# rather than through `.data`, because autograd does not track changes made via `.data`.
with torch.no_grad():
  module_conv.weight_shared[:, 0, :] = module_linear.weight  # (k, n_patches) -> (k, 1, n_patches)
  module_conv.bias_shared.copy_(module_linear.bias)

In [25]:
# Conv output is (n_samples, hidden_dim * k, 1); fold it into (n_samples, hidden_dim, k) to compare with Linear.
out_conv = module_conv(x).reshape(n_samples, hidden_dim, k)

In [26]:
# nn.Linear acts on the last (patch) axis: (n_samples, hidden_dim, n_patches) -> (n_samples, hidden_dim, k).
out_linear = module_linear(x)

In [27]:
# Both results should have shape (n_samples, hidden_dim, k).
out_conv.shape, out_linear.shape

(torch.Size([2, 16, 7]), torch.Size([2, 16, 7]))

In [28]:
# Element-wise equality check using only an absolute tolerance (rtol=0).
torch.allclose(out_conv, out_linear, atol=1e-6, rtol=0) 

True

182