In [2]:
from torch import nn
import torch

from einops import rearrange

In [None]:
# nn.BatchNorm1d requires `num_features`; calling it with no arguments raises
# TypeError. 512 is the feature width used for the encoder later in this notebook.
nn.BatchNorm1d(num_features=512)

In [4]:
import torch

In [6]:
def abs(x):
    """Return ``x`` when it is strictly positive, otherwise ``-x``.

    The branch is chosen with an ordinary Python conditional on ``x > 0``.
    Note that this definition shadows the built-in ``abs`` in the notebook
    namespace.
    """
    return x if x > 0 else -x

In [7]:
# Example input for the tracing cell; its single element is positive, so
# `abs` takes its `x > 0` branch when called with it.
x = torch.tensor([1])

In [9]:
# NOTE(review): `jitted_abs` is assigned in cell In [8], which is listed further
# down in this notebook; on a fresh kernel that cell has to run before this one.
type(jitted_abs)

torch.jit.ScriptFunction

In [None]:
# from einops.layers.torch import 

In [11]:
# The recorded result is tensor([-1]), not tensor([1]): torch.jit.trace records
# the operations run for the example input (positive x, so `return x`), and the
# Python `if` inside `abs` is not re-evaluated for new inputs.
jitted_abs(-x)

tensor([-1])

In [8]:
# Trace `abs` using `x` as the example input. The `if x > 0:` line echoed in
# this cell's output is presumably the tracer warning about using a tensor in
# Python control flow — TODO confirm against the full warning text.
jitted_abs = torch.jit.trace(abs, x)

  if x > 0:


In [12]:
# import numpy as np
# np.einsum('b n d, b m d -> b n m', x, y)

In [21]:
from torch import nn
from einops import rearrange

In [22]:
class PreNorm(nn.Module):
    """Pre-normalisation wrapper: LayerNorm first, then the wrapped layer.

    Args:
        dim: size of the trailing feature dimension handed to ``nn.LayerNorm``.
        layer: module that receives the normalised tensor.
    """

    def __init__(self, dim: int, layer: nn.Module) -> None:
        super().__init__()
        # `layer` is registered before `norm`; that order fixes the ordering
        # of parameters() and of the state_dict entries.
        self.layer = layer
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        """Normalise ``x`` and call the wrapped layer, forwarding any keyword arguments."""
        return self.layer(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    """Two-layer MLP: Linear(dim, hidden_dim) -> GELU -> Linear(hidden_dim, dim).

    Args:
        dim: input and output feature size (last dimension).
        hidden_dim: feature size between the two linear layers.
    """

    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        activation = nn.GELU()
        contract = nn.Linear(hidden_dim, dim)
        # Sequential indices: 0 = expand, 1 = activation, 2 = contract.
        self.net = nn.Sequential(expand, activation, contract)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Run ``x`` through the MLP; the last dimension stays ``dim``."""
        return self.net(x)
# JAX

class Attention(nn.Module):
    """Multi-head self-attention with a single fused q/k/v projection.

    Args:
        dim: feature size of the input and of the output (last dimension).
        heads: number of attention heads.
        dim_head: feature size of each head.
    """

    def __init__(self, dim: int, heads: int = 8, dim_head: int = 64) -> None:
        super().__init__()

        self.heads = heads
        # 1 / sqrt(dim_head), multiplied into the q·k scores before the softmax.
        self.scale = dim_head ** -0.5

        inner_dim = heads * dim_head
        # One bias-free projection yields q, k and v concatenated on the last axis.
        self.to_qkv = nn.Linear(dim, 3 * inner_dim, bias=False)
        self.attend = nn.Softmax(dim=-1)

        # A single head whose width equals `dim` already produces `dim` features,
        # so the output projection is replaced by an identity in that case.
        if heads == 1 and dim_head == dim:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Self-attend over ``x`` of shape [batch_size, seq_len, dim]; returns the same shape."""
        fused = self.to_qkv(x)  # [batch_size, seq_len, 3 * inner_dim]
        # Split into three [batch_size, seq_len, inner_dim] parts and move the
        # head axis forward: each of q, k, v is [batch_size, heads, seq_len, dim_head].
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h=self.heads)
            for part in fused.chunk(3, dim=-1)
        )

        # Scaled dot-product scores: [batch_size, heads, seq_len, seq_len].
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn_weights = self.attend(scores)  # softmax over the last (key) axis

        per_head = torch.matmul(attn_weights, v)  # [batch_size, heads, seq_len, dim_head]
        # Concatenate the heads back: [batch_size, seq_len, heads * dim_head].
        merged = rearrange(per_head, 'b h s d -> b s (h d)')
        return self.to_out(merged)
        
    # def forward_one_head(self, x: torch.FloatTensor) -> torch.FloatTensor:
    #     qkv = self.to_qkv(x) # [batch_size, seq_len, dim * 3]
    #     qkv = qkv.chunk(3, dim=-1) # [batch_size, seq_len, dim] * 3
    #     # qkv[0] # 0, ..., dim - 1 ||| dim, ..., 2 * dim - 1 ||| 2 * dim, ..., 3 * dim - 1
    #     q, k, v = qkv
        
    #     # q [batch_size, seq_len, dim]
    #     # k [batch_size, seq_len, dim] -> k.transpose(-1, -2) [batch_size, dim, seq_len]
        
    #     attn = torch.matmul(q, k.transpose(-1, -2)) # [batch_size, seq_len, seq_len]
    #     attn_weights = self.attend(attn) # [batch_size, seq_len, seq_len]
        
    #     out = torch.matmul(attn_weights, v) # [batch_size, seq_len, dim]
    #     return out
        

class TransformerEncoder(nn.Module):
    """Stack of ``depth`` pre-norm blocks followed by a final LayerNorm.

    Each block is a pair (PreNorm(Attention), PreNorm(FeedForward)); both
    members are applied with a residual connection.

    Args:
        dim: feature size of the input and output.
        depth: number of (attention, feed-forward) pairs.
        heads: number of attention heads in every block.
        dim_head: per-head feature size in every block.
        mlp_dim: hidden size of every feed-forward block.
    """

    def __init__(self, dim: int, depth: int, heads: int, dim_head: int, mlp_dim: int) -> None:
        super().__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dim_head)),
                PreNorm(dim, FeedForward(dim, mlp_dim)),
            ])
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Apply every block in order, then the final LayerNorm."""
        for attention_block, feed_forward_block in self.layers:
            x = attention_block(x) + x        # residual around attention
            x = feed_forward_block(x) + x     # residual around the MLP
        return self.norm(x)


In [23]:
# Keyword arguments make the five integer hyper-parameters self-describing.
encoder = TransformerEncoder(dim=512, depth=6, heads=8, dim_head=64, mlp_dim=2048)

In [24]:
# Random test input; the last dimension (512) matches the `dim` the encoder was
# built with. No seed is set, so the values differ between runs.
a = torch.randn(3, 5, 512)

In [26]:
# Sanity check: the recorded output shape, torch.Size([3, 5, 512]), equals the
# shape of the input `a`.
encoder(a).shape

torch.Size([3, 5, 512])

In [18]:
# Total number of elements across all parameter tensors of the encoder.
# NOTE(review): this cell's execution count (18) is lower than that of the cell
# that builds `encoder` (23) — re-run top to bottom to confirm the figure.
sum(p.numel() for p in encoder.parameters())

18906112