In [1]:
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from matplotlib import pyplot as plt
from torchinfo import summary

In [2]:
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')
    print("Using CUDA")
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using MPS")
else:
    print("Using CPU")

Using MPS


In [3]:
data_dir = './train_data/wkz8.txt'
ctx_len = 64
batch_size = 8
d_model = 512
n_heads = 8
n_layers = 6

In [4]:
# Read the text file
with open(data_dir, 'r', encoding='utf-8') as f:
    text = f.read()

# Count unique characters
unique_chars = set(text)
num_unique_chars = len(unique_chars)

print(f'Length of text: {len(text)}')
print(f"Number of unique characters in the file: {num_unique_chars}")
print("Unique characters:", ''.join(sorted(unique_chars)))


Length of text: 87542
Number of unique characters in the file: 2149
Unique characters: 
 *.08altwx—‘’“”…、。一丁七万丈三上下不与丑专且世丘业东丝丢两严个丫中丸丹为主丽举乃久么义之乌乍乎乐乔乖乘九也书买乱了予争事二于亏云互五井些亡交亮亲人亿什仇今仍从仑仔他仗仙仞代令以仨仪们仰件价任份仿伍伏伐休众优伙会伟传伤伦伶伸似但位低住佑体何余佛作你佩佼使依侠侧便促俊俏俗保信俩修俯俱俺倒候倚借倦值倾假偈偎偏做停偷偿傲傻像僧儿兀元兄充先光克免兔入全八公六兮兰共关兴兵其具养兽内冉再冒冕写冠冤冥冬冰冲决况冷冻净凄准凉凌减凑凛凝几凡凤凭凰凳凶出击刀分切刑划刚初判利别刮到制刹刺刻前剑剥剧剩副割劈力劝办功加务劣动助努劫劲劳势勇勒勾包匆化北匹区医十千升午半华卑卒单卖南博占卫卯印危即却卷历厉压厌厚原厮去参又及友双反发取受变叛叠口古句另只叫叭叮可台史右叶号叹叽吁吃各吆合同名后吏吐向吓吗君吞吟否吧吩含听吭启吵吸吹吻吼吽呀呆告呜呢呦周味呵呸呼命咆咋和咏咐咒咔咕咚咦咧咪咬咱咳咽哀品哄哇哈哉响哎哑哗哟哥哦哧哩哪哭哮哲哼哽唉唐唠唤唬唯唱唳唵啃啄啊啥啦啪啷啸啼喂喃善喊喘喜喝喧喳喷喽嗔嗝嗡嗨嗬嗯嗵嗷嘎嘘嘛嘟嘤嘲嘴嘶嘻嘿噙噢器嚎嚓嚣嚼囚四回因团园困围固国图圆圈土圣在地场均坏坐坑块坚坠坡垂型垒垫埋城堂堆堵塌塔塞填境墓墙增墟壁壑壤士声壳壶处备复夏夕外多夜够大天太夫央失头夹夺奇奈奋奔套奘奥女奶她好如妃妄妇妈妖妙妨妹始姐姑姓姿威娘娲婉婢婶媳嫁嫩子孔字存孙季孤学孩孱宁它宇守安完宗官宙定宛宜宝实客宫宰害宴家容宽宿寂密寒寞察寸对寺寻导封射将尊小少尖尘尚尝尤就尸尺尽尾局屁层居屈屋屏屑展山岁岂岐岔岚岛岩岭岸峰峻崖崩崽嶽巉巡左巧巨差己已巴巾市布帅师希帘帝带席帮常帽幅幕干平年并幸幻幽广庄庆床序应底庙府庞废度座庭廊廖延建开异弃弄式引弟张弥弧弯弱弹强弼归当形彩影彼往径待很律徐徒得御微德心必忆忌忍忒志忘忙忧快念忽怀态怎怏怒怔怕怖怜思急性怨怪怯总恋恍恐恒恙恨恩恬恭息恳恶恼悄悉悔悟悠悦您悬悲情惊惑惜惧惨惬惮想惶惹愁愉意愕感愣愤愿慌慕慢慧憧憬憾懂懒戏成我戒或战戟戳戴所扁扇手才扎扑打扔托执扫扬扮扯扰扶找承把抓投抖抗折抚抛抢护报披抬抱抵抹押抽拂拄担拆拉拍拎拐拒拖拘招拜拣拥拦拨择拯拱拳拼拽拾拿持挂指按挑挖挚挠挡挣挤挥挨振挺挽捂捅捉捏捕捞捡换捣捧据捶捷掀

In [5]:
character_to_index = {char: i for i, char in enumerate(unique_chars)}
index_to_character = {i: char for i, char in enumerate(unique_chars)}
encode = lambda x: [character_to_index[i] for i in x]
decode = lambda x: [index_to_character[i] for i in x]

print(encode('你'))
print(decode(encode('你')))

[756]
['你']


In [6]:
class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_seq_len):
        super(RotaryPositionalEmbedding, self).__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        # Precompute sinusoidal embeddings
        inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model))
        self.register_buffer("sinusoidal", torch.einsum("i,j->ij", torch.arange(max_seq_len).float(), inv_freq))
        self.register_buffer("sin", torch.sin(self.sinusoidal))
        self.register_buffer("cos", torch.cos(self.sinusoidal))

    def forward(self, x):
        """
        Args:
            x: A tensor of shape (length, batch, d_model).

        Returns:
            A tensor of shape (length, batch, d_model) with rotary positional embeddings applied.
        """
        length, batch, d_model = x.shape
        assert d_model == self.d_model, "Input d_model must match initialized d_model"

        # Apply rotary embeddings
        x1, x2 = x[..., ::2], x[..., 1::2]  # Split into even and odd dimensions
        x_rotated = torch.cat([x1 * self.cos[:length, None, :] - x2 * self.sin[:length, None, :],
                               x1 * self.sin[:length, None, :] + x2 * self.cos[:length, None, :]], dim=-1)
        return x_rotated

In [7]:
class MHA(nn.Module):
    def __init__(self, d_model, n_heads, ctx_len):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.rope = RotaryPositionalEmbedding(d_model, max_seq_len = ctx_len)
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.wo = nn.Linear(d_model, d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x):
        # rope only applies to q and k, not v
        q = self.wq(x)
        q = self.rope(q)
        k = self.wk(x)
        k = self.rope(k)
        v = self.wv(x)

        q = q.view(q.shape[0], q.shape[1], self.n_heads, self.head_dim)
        k = k.view(k.shape[0], k.shape[1], self.n_heads, self.head_dim)
        v = v.view(v.shape[0], v.shape[1], self.n_heads, self.head_dim)

        # Assume the input is of shape (length, batch, d_model)
        # the Q, K, V tensors are now of shape (length, batch, n_heads, head_dim)
        q = q.permute(1, 2, 0, 3)
        k = k.permute(1, 2, 0, 3)
        v = v.permute(1, 2, 0, 3)
        # now they are of shape (batch, n_heads, length, head_dim)
        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = F.softmax(attn, dim=-1)
        # the operations so far can be done with einsum in a much more succinct way i suppose
        out = attn @ v
        out = out.permute(2, 0, 1, 3).reshape(x.shape[0], x.shape[1], self.d_model)
        out = self.wo(out)
        out = self.ff(out)
        return out
        


In [8]:
class GPT(nn.Module):
    def __init__(self, vocab, d_model, ctx_len, n_heads, n_layers):
        super().__init__()
        self.vocab = vocab
        self.d_model = d_model
        self.ctx_len = ctx_len
        self.embedding = nn.Embedding(vocab, d_model)
        self.mha = nn.ModuleList([MHA(d_model, n_heads, ctx_len) for i in range(n_layers)])
        self.fc = nn.Linear(d_model, vocab)
 
    def forward(self, x):
        x = self.embedding(x)
        for layer in self.mha:
            x = layer(x)
        x = self.fc(x[-1])
        return x

In [9]:
model = GPT(vocab=num_unique_chars,
            d_model=d_model,
            ctx_len=ctx_len,
            n_heads=n_heads,
            n_layers=n_layers)
summary(model)

Layer (type:depth-idx)                             Param #
GPT                                                --
├─Embedding: 1-1                                   1,100,288
├─ModuleList: 1-2                                  --
│    └─MHA: 2-1                                    --
│    │    └─RotaryPositionalEmbedding: 3-1         --
│    │    └─Linear: 3-2                            262,656
│    │    └─Linear: 3-3                            262,656
│    │    └─Linear: 3-4                            262,656
│    │    └─Linear: 3-5                            262,656
│    │    └─Sequential: 3-6                        2,099,712
│    └─MHA: 2-2                                    --
│    │    └─RotaryPositionalEmbedding: 3-7         --
│    │    └─Linear: 3-8                            262,656
│    │    └─Linear: 3-9                            262,656
│    │    └─Linear: 3-10                           262,656
│    │    └─Linear: 3-11                           262,656
│    │    └─Sequential:

In [10]:
x = torch.randint(1, num_unique_chars, (ctx_len,1))
y = model(x)
print(y.shape)

torch.Size([1, 2149])
