In [6]:
from transformers import OpenAIGPTTokenizer
import torch
from torch import nn, optim
import torch.nn.functional as F
import math

In [7]:
class Attention(nn.Module):
    """Multi-head attention with a fused query/key/value projection.

    The input is projected once by ``c_attn`` to ``3 * output_dim`` features,
    split into queries, keys and values, attended per head, merged back and
    passed through ``c_proj`` followed by dropout.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Total width of q/k/v across all heads (and of the output).
        num_heads: Number of attention heads; must evenly divide ``output_dim``.
        p: Dropout probability for the attention weights and for the output.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``output_dim``.
    """

    def __init__(self, input_dim, output_dim, num_heads, p):
        super(Attention, self).__init__()
        # Fail early with a readable message instead of an opaque ``view``
        # error on the first forward pass.
        if num_heads <= 0 or output_dim % num_heads != 0:
            raise ValueError(
                f"output_dim ({output_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.c_attn = nn.Linear(input_dim, output_dim * 3)
        self.c_proj = nn.Linear(output_dim, output_dim)
        self.attn_dropout = nn.Dropout(p=p, inplace=False)
        self.resid_dropout = nn.Dropout(p=p, inplace=False)

    def scaled_dot_product_attention(self, q, k, v, mask: bool):
        """
        q: [batch_size, num_heads, seq1_len, head_dim]
        k: [batch_size, num_heads, seq2_len, head_dim]
        v: [batch_size, num_heads, seq2_len, head_dim]
        mask: if True, position i may only attend to positions j <= i
        (seq1_len = seq2_len for self attention)

        returns: [batch_size, num_heads, seq1_len, head_dim]
        """
        scores = q.matmul(k.transpose(-1, -2)) / math.sqrt(q.shape[-1])
        if mask:
            # Build the boolean mask directly on the scores' device; it
            # broadcasts over the batch and head dimensions.
            rows = torch.arange(scores.shape[-2], device=scores.device).unsqueeze(-1)
            cols = torch.arange(scores.shape[-1], device=scores.device)
            future = cols > rows
            scores = scores.masked_fill(future, -torch.inf)
        attn_weights = self.attn_dropout(scores.softmax(dim=-1))
        return attn_weights.matmul(v)

    def qkv_reshape(self, x):
        """[batch, seq, num_heads * head_dim] -> [batch, num_heads, seq, head_dim]."""
        return x.view(x.shape[0], x.shape[1], self.num_heads, -1).permute(0, 2, 1, 3)

    def output_reshape(self, x):
        """[batch, num_heads, seq, head_dim] -> [batch, seq, num_heads * head_dim]."""
        x = x.permute(0, 2, 1, 3)
        # ``reshape`` rather than ``view``: the permuted tensor is not contiguous.
        return x.reshape(x.shape[0], x.shape[1], -1)

    def forward(self, x, mask: bool):
        """Apply self-attention to ``x`` of shape [batch, seq, input_dim].

        ``mask`` toggles causal masking. Returns [batch, seq, output_dim].
        """
        q, k, v = self.c_attn(x).chunk(3, dim=-1)
        q, k, v = self.qkv_reshape(q), self.qkv_reshape(k), self.qkv_reshape(v)
        attn_outputs = self.output_reshape(self.scaled_dot_product_attention(q, k, v, mask))
        return self.resid_dropout(self.c_proj(attn_outputs))
    

class MLP(nn.Module):
    """Position-wise feed-forward sublayer: expand 4x, GELU, project back, dropout.

    Args:
        input_dim: Size of the last dimension of the input and of the output.
        p: Dropout probability applied to the sublayer output.
    """

    def __init__(self, input_dim, p) -> None:
        super(MLP, self).__init__()
        self.c_fc = nn.Linear(input_dim, input_dim * 4)
        self.c_proj = nn.Linear(input_dim * 4, input_dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(p=p, inplace=False)

    def forward(self, x):
        """Map ``x`` of shape [..., input_dim] to a tensor of the same shape."""
        hidden = self.act(self.c_fc(x))
        # Dropout goes on the sublayer output (what is added to the residual),
        # mirroring Attention.resid_dropout, not on the pre-activation.
        return self.dropout(self.c_proj(hidden))


class Block(nn.Module):
    """One transformer layer with post-layer-norm residual connections.

    Each sublayer (attention, then MLP) is added to its own input and the
    sum is layer-normalised.

    Args:
        d_model: Hidden width of the layer.
        num_heads: Number of attention heads.
        p: Dropout probability forwarded to both sublayers.
    """

    def __init__(self, d_model, num_heads, p):
        super().__init__()
        self.attn = Attention(d_model, d_model, num_heads, p)
        self.ln_1 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model, p)
        self.ln_2 = nn.LayerNorm(d_model)

    def forward(self, x, mask: bool):
        """Run attention and MLP sublayers on ``x``; ``mask`` is passed to attention."""
        # Attention sublayer: residual add, then normalise.
        attended = self.ln_1(self.attn(x, mask=mask) + x)
        # Feed-forward sublayer: residual add, then normalise.
        return self.ln_2(self.mlp(attended) + attended)
    

class GPT(nn.Module):
    """Decoder-only transformer: token + learned position embeddings, then ``n_layers`` blocks.

    Args:
        vocab_size: Number of rows in the token embedding table.
        max_seq_len: Number of rows in the position embedding table, i.e. the
            longest sequence the model accepts.
        n_layers: Number of stacked ``Block`` layers.
        d_model: Hidden width.
        num_heads: Attention heads per block.
        p: Dropout probability used for the embeddings and inside every block.
    """

    def __init__(self, vocab_size, max_seq_len, n_layers, d_model, num_heads, p):
        super(GPT, self).__init__()
        self.d_model = d_model
        self.tokens_embed = nn.Embedding(vocab_size, d_model)
        self.positions_embed = nn.Embedding(max_seq_len, d_model)
        self.drop = nn.Dropout(p=p, inplace=False)
        self.h = nn.ModuleList([Block(d_model, num_heads, p) for _ in range(n_layers)])

    def forward(self, x):
        """
        x: [batch_size, seq_len] integer token ids

        returns: [batch_size, seq_len, d_model] hidden states of the last block

        Raises IndexError if seq_len exceeds the position embedding table.
        """
        seq_len = x.shape[-1]
        max_len = self.positions_embed.num_embeddings
        if seq_len > max_len:
            # Checked explicitly so the failure is clear and identical on
            # every backend instead of surfacing inside the embedding lookup.
            raise IndexError(
                f"sequence length {seq_len} exceeds max_seq_len={max_len}"
            )
        # NOTE(review): token embeddings are scaled by sqrt(d_model) here;
        # confirm the loaded checkpoint was trained with this scaling.
        x = self.tokens_embed(x) * math.sqrt(self.d_model)
        # Create position ids on the right device; shape [1, seq_len]
        # broadcasts over the batch dimension in the addition below.
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        x = self.drop(x + self.positions_embed(positions))
        for layer in self.h:
            x = layer(x, mask=True)
        return x

In [8]:
def init_model_and_tokenizer(weights_path, device):
    """Build the tokenizer and GPT model, load weights, and add special tokens.

    The model is created with fixed hyperparameters (12 layers, 12 heads,
    d_model=768, 512 positions), the checkpoint at ``weights_path`` is loaded,
    and then ``<bos>``, ``<eos>`` and ``<sep>`` are added to the tokenizer with
    matching, randomly initialised rows appended to the token embedding table.

    Args:
        weights_path: Path to a state dict compatible with ``GPT``.
        device: Device the model and the new embedding rows are placed on.

    Returns:
        ``(tokenizer, gpt)``. Nothing here calls ``eval()``, so the model is
        returned in training mode.
    """
    tokenizer = OpenAIGPTTokenizer.from_pretrained("openai-gpt")
    gpt = GPT(
        vocab_size=tokenizer.vocab_size,
        max_seq_len=512,
        d_model=768,
        num_heads=12,
        n_layers=12,
        p=0.1
    ).to(device)
    # NOTE(security): torch.load unpickles the file. Only load trusted
    # checkpoints, or pass weights_only=True where the installed PyTorch
    # supports it.
    gpt.load_state_dict(torch.load(weights_path, map_location=device))

    # The checkpoint must be loaded before the table is grown, because it
    # holds the original vocabulary size.
    special_tokens_dict = {"bos_token":"<bos>", "eos_token":"<eos>", "sep_token":"<sep>"}
    num_added = tokenizer.add_special_tokens(special_tokens_dict)
    if num_added > 0:
        embed = gpt.tokens_embed
        # Take width, dtype and device from the existing table instead of
        # repeating the literal 768.
        new_rows = torch.randn(num_added, embed.embedding_dim).to(embed.weight)
        with torch.no_grad():
            expanded = torch.cat([embed.weight, new_rows], dim=0)
        embed.weight = nn.Parameter(expanded)
        # Keep the module metadata in sync with the grown weight matrix.
        embed.num_embeddings = expanded.shape[0]
    return tokenizer, gpt

In [9]:
# Prefer Apple-silicon MPS (the original target), then CUDA, then CPU, so the
# notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    DEV = torch.device("mps")
elif torch.cuda.is_available():
    DEV = torch.device("cuda")
else:
    DEV = torch.device("cpu")
tokenizer, model = init_model_and_tokenizer("weights.pth", DEV)

In [12]:
# Smoke test: encode one dialogue-style string and run a forward pass.
# NOTE(review): unless an earlier cell called model.eval(), dropout is still
# active here and `out` will differ between runs.
model_input = tokenizer(
    "<bos> hello, how are you? <sep> I am alright <eos>",
    return_tensors="pt",
)["input_ids"]
model_input = model_input.to(DEV)
out = model(model_input)