In [2]:
!pip install timm

Collecting timm
[?25l  Downloading https://files.pythonhosted.org/packages/90/fc/606bc5cf46acac3aa9bd179b3954433c026aaf88ea98d6b19f5d14c336da/timm-0.4.12-py3-none-any.whl (376kB)
[K     |████████████████████████████████| 378kB 5.0MB/s 
Installing collected packages: timm
Successfully installed timm-0.4.12


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict

import timm

To build a model based on $n$ identical blocks, we use a wrapper for a generic block instance (a `torch.nn.Module`) and add a normalization layer at the bottom of the pipeline.  

In [None]:
class SimpleBlockModel(nn.Module):
  """
  """
  def __init__(
      self,
      block_layer:nn.Module,
      features_dim:int,
      batch_norm:nn.Module=None,
      p_drop:float=.1,
  ):
    super().__init__()
    self.batch_norm = batch_norm or nn.BatchNorm1d
    self.dropout = nn.Dropout(p_drop)
    self.act_fn = act_fn or nn.Relu

    self.batch_norm = batch_norm(features_dim)
    self.block_layer = block_layer
    self.proj = nn.utils.weight_norm(nn.Linear(features_dim, features_dim)) # weight norm from https://www.kaggle.com/shahules/pytorch-entity-embedding#Model

def forward(self, x):
  x = block_layer(x)
  x = self.batch_norm(x)
  x = self.dropout(x)
  return act_fn(self.proj(x))

In [None]:
class MultiBlockModel(nn.Module):
    """
    MultiBlockModel is a stack of n blocks with a final normalization layer.
    """
    def __init__(
        self, 
        block:nn.Module, 
        num_blocks:int=1
    ):
        super().__init__()
        self.blocks = nn.ModuleList([copy.deepcopy(block) for _ in range(num_blocks)])
        self.norm = LayerNorm(block.size)
        
    def forward(self, x):
      for block in self.blocks:
        x = block(x)
      return self.norm(x)

In [None]:
class EncoderDecoderModel(nn.Module):
  """
  An encoder/decoder architecture for classification (-ish).
  """
  def __init__(
      self, 
      encoder:nn.Module,
      decoder:nn.Module,
      classifier:nn.Module
  ):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
    self.classifier = classifier

  def forward(self, 
              x, 
              return_representation=False, 
              return_features=False, 
              return_prob=True):
    representation = self.encoder(x)
    if return_representation:
      return representation
    features = self.decoder(repr)
    if return_features:
      return features
    y = self.classifier(features)
    if return_prob:
      return F.log_softmax(y) # obtain model confidence (for multiclass problems)
    return y

In [None]:
class CatModel(nn.Module):
  def __init__(
      self,
      left_model:nn.Module,
      right_model:nn.Module,
  ):
    super().__init__()
    self.left_model = left_model
    self.right_model = right_model

  def forward(self, left_x, right_x):
    left_x = left_model(left_x)
    right_x = right_model(right_x)
    return torch.cat((left_x, right_x))

In [None]:
import math

class EmdebbingModel(nn.Module):
  def __init__(
      self, 
      num_embeddings:int, 
      embedding_dim:int
  ):
    """
    num_embeddings: size of the dictionary of embeddings
    embedding_dim: the size of each embedding vector 
    """
    super().__init__()
    self.embedding = nn.Embedding(num_embeddings, embedding_dim)
    self.embedding_dim = embedding_dim

  def forward(self, x):
    return self.embedding(x) * math.sqrt(self.embedding_dim) # sqrt from "attention is all you need"

In [None]:
class ResidualBlockModel(nn.Module):
  def __init__(
      self,
      block_layer:nn.Module,
      norm_layer:nn.Module,
      in_channels:int,
      out_channels:int,
  ):
    super().__init__()
    self.block_layer1 = block_layer(in_channels, out_channels)
    self.batch_norm = norm_layer(out_channels)
    self.act_fn = act_fn or nn.ReLU(inplace=True)
    self.block_layer2 = block_layer(out_channels, out_channels)
    
  def forward(self, x):
    identity = x

    y = self.block_layer1(x)
    y = self.batch_norm(y)
    y = self.act_fn(y)

    y = self.block_layer2(y)
    y = self.batch_norm(y)

    y += identity
    return self.act_fn(y)

In [None]:
class ResidualSimpleBlockModel(nn.Module):
  def __init__(
      self,
      block_layer:nn.Module,
      in_channels:int,
      out_channels:int,
      norm_layer:nn.Module=None,
      act_fn:nn.Module=None,
  ):
    super().__init__()
    self.block_layer = block_layer(in_channels, out_channels)
    self.batch_norm = norm_layer or nn.BatchNorm1d
    self.batch_norm = batch_norm(out_channels)
    self.act_fn = act_fn or nn.ReLU(inplace=True)
    self.proj = nn.utils.weight_norm(nn.Linear(out_channels, out_channels)) # weight norm from https://www.kaggle.com/shahules/pytorch-entity-embedding#Model
    
  def forward(self, x):
    identity = x

    y = self.block_layer(x)
    y = self.batch_norm(y)
    y = self.act_fn(self.proj(y))

    y += identity
    return self.act_fn(y)

In [None]:
class IdentityModel(nn.Module):
  def __init__(
      self,
  ):
    super().__init__()
    self.identity = nn.Identity()
    
  def forward(self, x):
    return self.identity(x)

#  Test

In [5]:
from timm.models.layers.classifier import ClassifierHead

model = EncoderDecoderModel(
    encoder=CatModel(
        left_model=EmdebbingModel(
            num_embeddings=input_dim,
            embedding_dim=features_dim,
        ),
        right_model=IdentityModel()
    ),
    decoder=MultiBlockModel(
        block=ResidualBlockModel(
            block_layer=nn.Linear,
            norm_layer=nn.BatchNorm1d,
            in_channels=features_dim,
            out_channels=features_dim,
        ),
        num_blocks=1
    ),
    classifier=ClassifierHead(
        in_chs=features_dim, 
        num_classes=num_classes, 
        pool_type='avg', 
        drop_rate=0., 
        use_conv=False
    )
)

NameError: ignored