# In [1]:
import torch
import torch.nn as nn
import numpy as np
import math

# this implementation follows minGPT: https://github.com/karpathy/minGPT

# This transformer takes in token ids, e.g., "She has a cat." => ["She", "has", "a", "cat"] => [0, 1, 2, 3]

class MLP(nn.Module):
    """Fully connected network with a tanh-approximated GELU between layers.

    The layer sizes come either from ``shp`` or, when ``shp`` is None, from
    ``in_dim``, ``w``, ``depth`` and ``out_dim``.

    Args:
        in_dim: Input feature size (used only when ``shp`` is None).
        out_dim: Output feature size (used only when ``shp`` is None).
        w: Width of every hidden layer (used only when ``shp`` is None).
        depth: Number of linear layers (used only when ``shp`` is None).
        shp: Optional sequence of layer sizes ``[in, hidden..., out]``. When
            given, it takes precedence over the other arguments. Any
            indexable sequence of integers works (list, tuple, NumPy array).
    """

    def __init__(self, in_dim=2, out_dim=2, w=2, depth=2, shp=None):
        super().__init__()
        # Identity check: `shp == None` is evaluated element-wise for
        # array-like inputs and cannot be used as an `if` condition.
        if shp is None:
            shp = [in_dim] + [w] * (depth - 1) + [out_dim]
            self.in_dim = in_dim
            self.out_dim = out_dim
            self.depth = depth
        else:
            self.in_dim = shp[0]
            self.out_dim = shp[-1]
            self.depth = len(shp) - 1
        # nn.Linear expects plain ints; entries of an array-like `shp` may be
        # NumPy scalars.
        sizes = [int(s) for s in shp]
        self.linears = nn.ModuleList(
            [nn.Linear(sizes[i], sizes[i + 1]) for i in range(self.depth)]
        )
        self.shp = shp

    @staticmethod
    def _gelu(x):
        """Tanh approximation of GELU."""
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

    def forward(self, x):
        """Apply every linear layer, with GELU after all but the last one.

        Args:
            x: Tensor whose last dimension has size ``in_dim``.

        Returns:
            Tensor with the same leading dimensions and last dimension
            ``out_dim``.
        """
        for i in range(self.depth - 1):
            x = self._gelu(self.linears[i](x))
        # No activation on the output layer.
        return self.linears[-1](x)
    

class Attention(nn.Module):
    """Multi-head causal self-attention.

    Args:
        n_head: Number of attention heads; must divide ``n_embed``.
        n_embed: Embedding size of the input and of the output.
    """

    def __init__(self, n_head=2, n_embed=6):
        super().__init__()
        assert n_embed % n_head == 0, "n_embed must be divisible by n_head"
        # key, query, value projections for all heads, computed by one linear layer
        self.l_attn = nn.Linear(n_embed, 3 * n_embed)
        # output projection
        self.l_proj = nn.Linear(n_embed, n_embed)
        self.n_head = n_head
        self.n_embed = n_embed

    def forward(self, x):
        """Attend from every position to itself and to earlier positions.

        Args:
            x: Tensor of shape (B, T, C) with C equal to ``n_embed``.

        Returns:
            Tensor of shape (B, T, C).
        """
        # B: batch size; T: sequence length; C: embedding dimensionality (n_embed)
        B, T, C = x.size()
        n_head = self.n_head
        assert C % n_head == 0
        head_dim = C // n_head

        # query, key, value: each (B, T, n_head, head_dim)
        qkv = self.l_attn(x)
        q, k, v = (part.reshape(B, T, n_head, head_dim) for part in qkv.split(C, dim=2))

        # scaled dot-product scores, laid out as (B, query, key, head)
        scores = torch.einsum('ijhl,ikhl->ijkh', q, k) / math.sqrt(head_dim)

        # Causal mask, built on the input's device so the module also works
        # after being moved off the default device. future[j, k] is True when
        # key k comes after query j.
        pos = torch.arange(T, device=x.device)
        future = pos.view(1, T) > pos.view(T, 1)
        scores = scores.masked_fill(future.view(1, T, T, 1), float('-inf'))

        # normalise over the key axis, then mix the values
        weights = torch.softmax(scores, dim=2)
        mixed = torch.einsum('ijkl,iklh->ijlh', weights, v).reshape(B, T, C)

        # output projection
        return self.l_proj(mixed)
    

class Block(nn.Module):
    """A transformer block: self-attention then an MLP, each added residually.

    Args:
        n_head: Number of attention heads passed to ``Attention``.
        n_embed: Embedding size; the MLP hidden layer is ``4 * n_embed`` wide.
        use_ln: When True, ``ln_1`` / ``ln_2`` are applied to the input of the
            attention / MLP sub-layer. The default (False) skips layer norm.
    """

    def __init__(self, n_head=2, n_embed=6, use_ln=False):
        super().__init__()
        self.n_head = n_head
        self.n_embed = n_embed
        self.use_ln = use_ln
        # Both layer norms are always created, so the set of parameters does
        # not depend on `use_ln`.
        self.ln_1 = nn.LayerNorm(n_embed)
        self.attn = Attention(n_head=n_head, n_embed=n_embed)
        self.ln_2 = nn.LayerNorm(n_embed)
        self.mlp = MLP(shp=[n_embed, 4*n_embed, n_embed])

    def forward(self, x):
        """Run the block on ``x`` and return a tensor of the same shape."""
        if self.use_ln:
            # normalise the input of each sub-layer; the residual path stays raw
            x = x + self.attn(self.ln_1(x))
            x = x + self.mlp(self.ln_2(x))
        else:
            x = x + self.attn(x)
            x = x + self.mlp(x)
        return x
    

class Transformer(nn.Module):
    """Stack of transformer blocks over integer token ids.

    Token ids are looked up in a learned embedding table of shape
    (out_dim, n_embed); no positional embedding is added. The embedded
    sequence goes through an input linear layer, ``n_layer`` blocks and a
    final linear layer that yields one logit per token of the vocabulary.

    Note: ``ln_f`` is created but not applied in ``forward``.

    Args:
        out_dim: Vocabulary size; also the size of the output logits.
        n_head: Number of attention heads in every block.
        n_embed: Embedding size used throughout the model.
        n_layer: Number of blocks.
    """

    def __init__(self, out_dim=19, n_head=2, n_embed=20, n_layer=2):
        super().__init__()
        self.n_head = n_head
        self.n_embed = n_embed
        self.n_layer = n_layer
        self.l_i = nn.Linear(n_embed, n_embed)
        self.blocks = nn.ModuleList([Block(n_head=n_head, n_embed=n_embed) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed)
        self.l_f = nn.Linear(n_embed, out_dim)
        self.in_dim = n_embed
        self.out_dim = out_dim
        self.embedding = nn.Parameter(torch.normal(0, 1, size=(out_dim, n_embed)))

    def forward(self, token_ids):
        """Compute next-token logits.

        Args:
            token_ids: Integer ids of shape (batch_size, sequence length);
                a tensor, NumPy array or nested list.

        Returns:
            Logits of shape (batch_size, sequence length, out_dim).
        """
        # Convert explicitly to a long tensor on the embedding's device, so
        # ids of any integer dtype are used as indices (torch would treat
        # uint8/bool arrays as masks).
        token_ids = torch.as_tensor(token_ids, dtype=torch.long, device=self.embedding.device)
        x = self.embedding[token_ids]
        x = self.l_i(x)
        for block in self.blocks:
            x = block(x)
        # y shape: (batch_size, sequence length, out_dim)
        # here out_dim is the number of tokens in your library
        # y is the logits used to predict the next token
        y = self.l_f(x)
        return y

# build a small transformer
out_dim, n_head, n_embed, n_layer = 19, 4, 32, 2
model = Transformer(
    out_dim=out_dim,
    n_head=n_head,
    n_embed=n_embed,
    n_layer=n_layer,
)

# push one batch of random token ids through the model
batch_size, seq_len = 128, 16
x = np.random.choice(out_dim, size=(batch_size, seq_len))
model(x).shape  # (batch_size, seq_len, out_dim)

# output: torch.Size([128, 16, 19])