In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import matplotlib.pyplot as plt

In [None]:
class ScaledDotProductAttention(nn.Module):
    """Softmax attention: ``softmax(q @ k^T / scale_factor) @ v``.

    Args:
        scale_factor: Divisor applied to the raw ``q @ k^T`` scores
            (``MultiHeadAttention`` in this notebook passes ``sqrt(head_dim)``).
        dropout: Dropout probability applied to the attention weights.
        fill: Score written at masked positions before the softmax. The
            default ``-inf`` gives those positions zero weight; note that a
            row in which *every* position is masked then yields NaN.
    """

    def __init__(self, scale_factor=1, dropout=0.0, fill=float('-inf')):
        # Must be `super().__init__()`: the old `super(ScaledDotProductAttention)`
        # form is an unbound super object, so nn.Module.__init__ never ran and
        # assigning `self.dropout` raised "cannot assign module before
        # Module.__init__() call".
        super().__init__()
        self.scale_factor = scale_factor
        self.dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)
        self.fill = fill

    def forward(self, q, k, v, attn_mask=None):
        """Attend from ``q`` over ``k`` / ``v``.

        Args:
            q, k, v: Tensors whose last two dims are (sequence, feature);
                ``k`` is transposed on its last two dims for the score matmul.
            attn_mask: Optional tensor broadcastable to the score matrix;
                positions where it equals 0 are replaced by ``self.fill``.

        Returns:
            ``attention_weights @ v``.
        """
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale_factor

        if attn_mask is not None:
            scores = scores.masked_fill(attn_mask == 0, self.fill)

        weights = self.dropout(self.softmax(scores))
        return torch.matmul(weights, v)

In [None]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    Args:
        embed_dim: Model width ``C``; must be divisible by ``num_heads``.
        num_heads: Number of heads; each head has width ``embed_dim // num_heads``.
        block_size: Maximum sequence length ``T`` covered by the causal mask.
        dropout: Dropout probability, used both on the attention weights and
            on the output projection.
        eps: Accepted for signature compatibility; not used by this module.
    """

    def __init__(self, embed_dim, num_heads, block_size=1024, dropout=0.1, eps=1e-6):
        # `super().__init__()` is required; the old `super(MultiHeadAttention)`
        # form never initialised nn.Module, so submodule assignment failed.
        super().__init__()
        assert embed_dim % num_heads == 0, (
            f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
        )

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.block_size = block_size

        # One fused projection whose output is split into q, k and v.
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        self.output_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.resid_dropout = nn.Dropout(dropout)
        scale_factor = (self.embed_dim // self.num_heads) ** 0.5
        self.attn = ScaledDotProductAttention(scale_factor=scale_factor, dropout=dropout)

        # Lower-triangular causal mask (1 = may attend, 0 = masked), shaped to
        # broadcast over (batch, heads). The buffer name "bias" is kept because
        # renaming it would change the state_dict key.
        self.register_buffer("bias", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape (B, T, C) with ``C == embed_dim`` and
                ``T <= block_size``.

        Returns:
            Tensor of shape (B, T, C).

        Raises:
            ValueError: If ``T`` exceeds ``block_size`` (the causal mask would
                be too small to cover the score matrix).
        """
        B, T, C = x.size()
        if T > self.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )
        head_dim = C // self.num_heads

        q, k, v = self.qkv_proj(x).split(self.embed_dim, dim=2)
        # (B, T, C) -> (B, num_heads, T, head_dim)
        q = q.view(B, T, self.num_heads, head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, head_dim).transpose(1, 2)

        y = self.attn(q, k, v, self.bias[:, :, :T, :T])
        # (B, num_heads, T, head_dim) -> (B, T, C); contiguous() is needed
        # because view() cannot follow the transpose directly.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.output_proj(y))

In [None]:
class NewGELU(nn.Module):
    """Tanh approximation of GELU.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))``
    element-wise. This is equivalent (up to float rounding) to
    ``F.gelu(x, approximate='tanh')``; the explicit formula is kept for clarity.
    """

    def __init__(self):
        # `super().__init__()` is required; the old `super(NewGELU)` form left
        # nn.Module uninitialised, so calling the module raised AttributeError.
        super().__init__()

    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

class MLP(nn.Module):
    """Position-wise feed-forward network.

    ``Linear(C -> 4C) -> NewGELU -> Linear(4C -> C) -> Dropout``, applied to
    the last dimension of the input.

    Args:
        embed_dim: Input/output width ``C``.
        dropout: Dropout probability applied after the second projection.
    """

    def __init__(self, embed_dim, dropout=0.1):
        # `super().__init__()` is required; the old `super(MLP)` form left
        # nn.Module uninitialised, so assigning `self.model` failed.
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            NewGELU(),
            nn.Linear(4 * embed_dim, embed_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.model(x)

In [None]:
class Block(nn.Module):
    """Pre-LayerNorm transformer block.

    ``x = x + attn(ln_1(x))`` followed by ``x = x + mlp(ln_2(x))``.

    Args:
        embed_dim: Model width ``C``.
        num_heads: Number of attention heads.
        block_size: Maximum sequence length supported by the causal mask.
        dropout: Dropout probability passed to attention and MLP.
        eps: Forwarded to ``MultiHeadAttention``, which does not use it.
    """

    def __init__(self, embed_dim, num_heads, block_size=1024, dropout=0.1, eps=1e-6):
        # `super().__init__()` is required; the old `super(Block)` form left
        # nn.Module uninitialised, so assigning submodules failed.
        super().__init__()
        # NOTE(review): `eps` presumably was meant for the LayerNorms, which
        # currently use PyTorch's default eps (1e-5). Confirm before wiring it
        # through, since that would change numerics.
        self.sub_block_1 = nn.Sequential(
            nn.LayerNorm(embed_dim),
            MultiHeadAttention(embed_dim, num_heads, block_size, dropout, eps)
        )

        self.sub_block_2 = nn.Sequential(
            nn.LayerNorm(embed_dim),
            MLP(embed_dim, dropout)
        )

    def forward(self, x):
        # Residual connections around each normalised sub-block.
        x = x + self.sub_block_1(x)
        x = x + self.sub_block_2(x)
        return x