In [49]:
from typing import Literal

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import AutoTokenizer

from tqdm import tqdm


In [50]:
# Notebook-wide configuration store; later cells add keys such as
# 'context_length', 'device', 'epochs', 'lr', ...
ctx = dict()

In [51]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')
# BERT tokenizers ship a dedicated '[PAD]' token but define no EOS token, so an
# unconditional `pad_token = eos_token` would overwrite '[PAD]' with None and make
# `padding='max_length'` fail. Only fall back to EOS when there is no pad token.
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
    tokenizer.pad_token = tokenizer.eos_token

def decode(tsn):
    """Convert a sequence of token ids back into text.

    Special tokens (for this BERT tokenizer e.g. [CLS]/[SEP]/[PAD]) are dropped.
    """
    text = tokenizer.decode(tsn, skip_special_tokens=True)
    return text

def encode(s, ctx_len=None):
    """Tokenize ``s`` and return its input ids as a tensor.

    Args:
        s: text to tokenize. The tokenizer's default special tokens are added.
        ctx_len: if given, pad up to this many ids (``padding='max_length'``).
            Longer inputs are NOT truncated. If None, no padding is applied.

    Returns:
        The ``input_ids`` tensor with a leading size-1 batch dim squeezed away
        (1-D for a single string).
    """
    padding_mode = 'max_length' if ctx_len is not None else 'do_not_pad'
    encoded = tokenizer(
        s,
        truncation=False,
        max_length=ctx_len,
        padding=padding_mode,
        return_tensors='pt',
    )
    return encoded['input_ids'].squeeze(dim=0)

In [52]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        embed_dim: model width; split evenly across ``nhead`` heads.
        nhead: number of attention heads.
        pdrop: dropout probability applied to the projected output.
    """

    def __init__(self, embed_dim, nhead, pdrop=0.2):
        super().__init__()
        # One fused projection yields queries, keys and values side by side.
        # Registration order is kept so parameter init and state_dict keys match.
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.attend = nn.Softmax(dim=-1)
        self.drop = nn.Dropout(pdrop)
        self.c_proj = nn.Linear(embed_dim, embed_dim)
        self.nhead = nhead

    def _split_heads(self, part):
        # (b, t, nh*hd) -> (b, nh, t, hd)
        return rearrange(part, 'b t (nh hd) -> b nh t hd', nh=self.nhead)

    def forward(self, x, attn_mask=None):
        """Attend over ``x``; positions where ``attn_mask[:t, :t] == 0`` are blocked."""
        _, seq_len, _ = x.shape
        q, k, v = (self._split_heads(part) for part in self.qkv(x).chunk(3, dim=2))

        head_dim = k.size(dim=-1)
        scores = (q @ k.transpose(-2, -1)) * head_dim ** -0.5
        if attn_mask is not None:
            blocked = attn_mask[:seq_len, :seq_len] == 0
            scores = scores.masked_fill(blocked, -float('inf'))
        weights = self.attend(scores)

        merged = rearrange(weights @ v, 'b nh t hd -> b t (nh hd)', nh=self.nhead)
        return self.drop(self.c_proj(merged))

class FFN(nn.Module):
    """Position-wise feed-forward sublayer: expand 4x, GELU, project back, dropout.

    The ``net`` Sequential layout (index 0 = expand Linear, 2 = project Linear)
    is what the state_dict keys are derived from.
    """

    def __init__(self, embed_dim, pdrop=0.2):
        super().__init__()
        hidden_dim = 4 * embed_dim
        layers = [
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(pdrop),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

class Block(nn.Module):
    """Pre-LayerNorm transformer block: residual attention, then residual feed-forward."""

    def __init__(self, embed_dim, nhead, pdrop=0.2):
        super().__init__()
        # Registration order kept so parameter init and state_dict keys are unchanged.
        self.ln_1 = nn.LayerNorm(embed_dim)
        self.ln_2 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, nhead, pdrop)
        self.ffn = FFN(embed_dim, pdrop)

    def forward(self, x, attn_mask=None):
        hidden = x + self.attn(self.ln_1(x), attn_mask)
        out = hidden + self.ffn(self.ln_2(hidden))
        return out

class GPT(nn.Module):
    """Decoder-only transformer language model.

    Token and learned positional embeddings are summed, passed through
    ``n_layer`` ``Block``s, normalized by a final LayerNorm and projected to
    vocabulary logits.

    Args:
        vocab_size: number of token ids covered by the embedding/output layers.
        embed_dim: model width.
        context_length: maximum sequence length (size of the positional table).
        nhead: attention heads per block.
        n_layer: number of transformer blocks.
        pdrop: dropout probability used throughout.
    """

    def __init__(self,
                 vocab_size,
                 embed_dim,
                 context_length,
                 nhead,
                 n_layer,
                 pdrop=0.2,
                 ):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(context_length, embed_dim)
        self.blocks = nn.ModuleList(
            [Block(embed_dim, nhead, pdrop) for _ in range(n_layer)]
        )
        self.drop = nn.Dropout(pdrop)
        self.ln_f = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size)

        self.context_length = context_length

    def forward(self, idx, attn_mask=None, target=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (batch, t) tensor of token ids, with t <= context_length.
            attn_mask: optional mask forwarded to every block; entries equal
                to 0 are blocked. When None, attention is unmasked.
            target: optional (batch, t) tensor of next-token ids.

        Returns:
            (logits, loss): logits is (batch, t, vocab_size); loss is None when
            ``target`` is not given.

        Raises:
            ValueError: if t exceeds ``context_length``.
        """
        _, t = idx.shape
        if t > self.context_length:
            # Fail early with a clear message instead of an out-of-range
            # positional-embedding lookup (a device-side assert on CUDA).
            raise ValueError(
                f'sequence length {t} exceeds context_length {self.context_length}'
            )

        tok_emb = self.tok_emb(idx)
        pos_emb = self.pos_emb(torch.arange(t, device=idx.device))
        x = self.drop(tok_emb + pos_emb)

        for block in self.blocks:
            x = block(x, attn_mask)
        x = self.ln_f(x)

        logits = self.lm_head(x)

        loss = None
        if target is not None:
            loss = F.cross_entropy(
                rearrange(logits, 'b t c -> (b t) c'),
                rearrange(target, 'b t -> (b t)'),
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_tok=100, attn_mask=None, temperature=1.0):
        """Autoregressively append ``max_tok`` tokens to ``idx``.

        Runs without autograd, since sampling never needs gradients.

        Args:
            idx: (batch, t) tensor of prompt token ids.
            max_tok: number of tokens to append.
            attn_mask: mask forwarded to ``forward``. When None, a lower-triangular
                (causal) mask of size context_length is built.
            temperature: softmax temperature; 0 selects the argmax (greedy).

        Returns:
            (batch, t + max_tok) tensor holding the prompt followed by the
            generated tokens.
        """
        if attn_mask is None:
            # Without a mask attention would be bidirectional, which would not
            # match a model trained with a causal mask.
            attn_mask = torch.tril(
                torch.ones(self.context_length, self.context_length, device=idx.device)
            )
        for _ in range(max_tok):
            # Only the most recent context_length tokens fit the positional table.
            idx_cond = idx[:, -self.context_length:]
            logits, _ = self(idx_cond, attn_mask=attn_mask)
            last = logits[:, -1, :]
            if temperature == 0:
                # Dividing by 0 would give NaN probabilities; decode greedily instead.
                idx_next = last.argmax(dim=-1, keepdim=True)
            else:
                probs = (last / temperature).softmax(dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=-1)
        return idx

In [53]:
# Window length used for training batches, the positional table and the causal mask.
ctx.update({'context_length': 128})

In [None]:
# Read the corpus, delete spaces, newlines and tabs, and tokenize it into one
# long id sequence.
with open('input.txt', 'r', encoding='utf-8') as f:
    data = encode(f.read().translate(str.maketrans('', '', ' \n\t')))

# Contiguous 90/10 split: the first part trains, the tail validates.
train_ratio = 0.9
train_len = int(len(data) * train_ratio)

train_data, val_data = data[:train_len], data[train_len:]


def get_xy(split: Literal['train', 'val'], batch_size, context_length=None):
    """Sample a random batch of (input, next-token target) windows.

    Args:
        split: 'train' draws from ``train_data``, 'val' from ``val_data``.
        batch_size: number of windows to sample.
        context_length: window length; defaults to ``ctx['context_length']``.

    Returns:
        (x, y): two (batch_size, context_length) tensors where ``y[:, j]`` is
        the token that follows ``x[:, j]`` in the source sequence.

    Raises:
        ValueError: for an unknown ``split``, or a split too short to hold a
            window of ``context_length + 1`` tokens.
    """
    if split not in ('train', 'val'):
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")
    if context_length is None:
        context_length = ctx['context_length']
    data = train_data if split == 'train' else val_data
    # Each window needs context_length inputs plus one extra token for the target.
    n_starts = len(data) - context_length
    if n_starts < 1:
        raise ValueError(
            f'{split} split has {len(data)} tokens; need at least {context_length + 1}'
        )
    ix = torch.randint(n_starts, size=(batch_size, ))
    x = torch.stack([data[i:i + context_length] for i in ix])
    y = torch.stack([data[i + 1:i + context_length + 1] for i in ix])
    return x, y

In [None]:
# 24 blocks x 12 heads at width 768; dropout disabled for this run.
model = GPT(
    vocab_size=tokenizer.vocab_size,
    context_length=ctx['context_length'],
    embed_dim=768,
    n_layer=24,
    nhead=12,
    pdrop=0.0,
)

In [None]:
ctx['device'] = 'cuda' if torch.cuda.is_available() else 'cpu'
ctx['epochs'] = 100  # number of optimizer steps; each samples one random batch
ctx['batch_size'] = 64
ctx['eval_interval'] = 50
ctx['checkpoint_interval'] = 50
ctx['lr'] = 1e-4
ctx['state_dict'] = 'gpt.pt'

pbar = tqdm(range(1, ctx['epochs'] + 1))

model = model.to(ctx['device'])

# Resume from an existing checkpoint. map_location lets a checkpoint saved on a
# GPU be loaded on a CPU-only machine (and vice versa).
if os.path.exists(ctx['state_dict']):
    model.load_state_dict(
        torch.load(ctx['state_dict'], map_location=ctx['device'], weights_only=True)
    )

# Lower-triangular causal mask: position i may attend only to positions <= i.
attn_mask = torch.tril(torch.ones(ctx['context_length'], ctx['context_length'])).to(ctx['device'])

optimizer = torch.optim.AdamW(model.parameters(), lr=ctx['lr'])
# Anneal once over the whole run. A T_max shorter than the number of steps makes
# CosineAnnealingLR oscillate, repeatedly dropping the LR to 0 mid-training.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=ctx['epochs'])

for epoch in pbar:
    model.train()
    x, y = get_xy('train', ctx['batch_size'])
    x = x.to(ctx['device'])
    y = y.to(ctx['device'])
    _, loss = model(x, attn_mask, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
    pbar.set_description(f'loss: {loss.item():.4f}')
    if epoch % ctx['checkpoint_interval'] == 0:
        torch.save(model.state_dict(), ctx['state_dict'])
    if epoch % ctx['eval_interval'] == 0:
        model.eval()
        # Evaluation and sampling need no gradients; skipping autograd saves memory.
        with torch.no_grad():
            vx, vy = get_xy('val', ctx['batch_size'])
            _, val_loss = model(vx.to(ctx['device']), attn_mask, vy.to(ctx['device']))
            generated = model.generate(
                encode('').unsqueeze(0).to(ctx['device']), 200, attn_mask=attn_mask
            )
        pbar.set_postfix_str(f'val_loss: {val_loss.item():.4f} generated: {decode(generated[0])}')

In [None]:
# Sample a 500-token continuation of the prompt '羊脂球'.
# NOTE(review): encode() uses the tokenizer's default special tokens, so for BERT
# the prompt is presumably wrapped in [CLS] ... [SEP] — confirm this is intended.
model.eval()  # do not depend on which mode the training loop ended in
with torch.no_grad():
    generated = model.generate(
        encode('羊脂球').unsqueeze(0).to(ctx['device']), 500, attn_mask=attn_mask
    )
# decode() joins tokens with spaces; remove them for readable Chinese text.
decode(generated[0]).replace(' ', '')