In [None]:
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from matplotlib import pyplot as plt
from torchinfo import summary

In [None]:
# Select the compute device, preferring CUDA, then Apple MPS, then CPU.
# The branches form a single if/elif/else chain so that exactly one message
# is printed and it always matches the device that was actually selected.
if torch.cuda.is_available():
    device = torch.device('cuda')
    print("Using CUDA")
# `torch.backends.mps` does not exist on older torch builds, so probe for it
# before asking whether the backend is available.
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using MPS")
else:
    device = torch.device('cpu')
    print("Using CPU")

In [None]:
# --- Data ---------------------------------------------------------------
# Despite the name, this is the path of a single text file (it is passed
# straight to open() when the corpus is loaded).
data_dir = './train_data/wkz8.txt'

# --- Sequence / batching ------------------------------------------------
# ctx_len is the maximum sequence length the positional embedding is built for.
# batch_size is presumably consumed by a training loop elsewhere — TODO confirm.
ctx_len, batch_size = 64, 8

# --- Model size ---------------------------------------------------------
# d_model is the embedding width; it is split evenly across n_heads attention
# heads, and n_layers attention blocks are stacked.
d_model, n_heads, n_layers = 512, 8, 6

In [None]:
# Load the whole training corpus into memory as one string.
with open(data_dir, encoding='utf-8') as corpus_file:
    text = corpus_file.read()

# The vocabulary is character-level: every distinct character is one token.
unique_chars = set(text)
num_unique_chars = len(unique_chars)

# Report corpus size and the character inventory (sorted only for display).
charset_preview = ''.join(sorted(unique_chars))
print(f'Length of text: {len(text)}')
print(f"Number of unique characters in the file: {num_unique_chars}")
print("Unique characters:", charset_preview)


In [None]:
# Build the character <-> index maps from the characters in *sorted* order.
# Iterating the raw set would give an ordering that changes between kernel
# restarts (string hashing is randomised per process), so token ids would not
# be reproducible across sessions.
character_to_index = {char: i for i, char in enumerate(sorted(unique_chars))}
# Derive the inverse map from the forward map so the two can never disagree.
index_to_character = {i: char for char, i in character_to_index.items()}


def encode(x):
    """Map an iterable of characters to a list of integer token ids.

    Raises:
        KeyError: if a character is not part of the vocabulary.
    """
    return [character_to_index[i] for i in x]


def decode(x):
    """Map an iterable of integer token ids back to a list of characters.

    Raises:
        KeyError: if an id is not part of the vocabulary.
    """
    return [index_to_character[i] for i in x]


# Round-trip sanity check on a single character.
print(encode('你'))
print(decode(encode('你')))

In [None]:
class RotaryPositionalEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) for sequence-first tensors.

    Each (even, odd) pair of features in the last dimension is rotated by an
    angle proportional to the token position. Note the output layout: the first
    components of all rotated pairs come first, followed by all second
    components (i.e. the result is NOT re-interleaved).

    Args:
        d_model: Size of the last dimension of the input; must be even because
            features are rotated in pairs.
        max_seq_len: Largest sequence length for which angles are precomputed.

    Raises:
        ValueError: if ``d_model`` is odd or ``max_seq_len`` is not positive.
    """

    def __init__(self, d_model, max_seq_len):
        super(RotaryPositionalEmbedding, self).__init__()
        if d_model % 2 != 0:
            # An odd width leaves one feature without a partner and would
            # otherwise only surface as a shape mismatch inside forward().
            raise ValueError(f"d_model must be even for rotary embeddings, got {d_model}")
        if max_seq_len < 1:
            raise ValueError(f"max_seq_len must be positive, got {max_seq_len}")
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        # One inverse frequency per feature pair: 10000^(-2i / d_model).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model))
        # angles[p, i] = p * inv_freq[i], shape (max_seq_len, d_model // 2).
        angles = torch.einsum("i,j->ij", torch.arange(max_seq_len).float(), inv_freq)
        # Buffers follow the module across .to(device) calls without being trained.
        self.register_buffer("sinusoidal", angles)
        self.register_buffer("sin", torch.sin(angles))
        self.register_buffer("cos", torch.cos(angles))

    def forward(self, x):
        """
        Args:
            x: A tensor of shape (length, batch, d_model).

        Returns:
            A tensor of shape (length, batch, d_model) with rotary positional embeddings applied.

        Raises:
            AssertionError: if the last dimension of ``x`` differs from ``d_model``.
            ValueError: if ``length`` exceeds ``max_seq_len``.
        """
        length, batch, d_model = x.shape
        assert d_model == self.d_model, "Input d_model must match initialized d_model"
        if length > self.max_seq_len:
            # Without this check the slice below would be silently truncated and
            # then either fail to broadcast or (for max_seq_len == 1) reuse
            # position 0 for every step.
            raise ValueError(
                f"Sequence length {length} exceeds max_seq_len {self.max_seq_len}"
            )

        # (length, 1, d_model // 2): broadcast over the batch dimension.
        cos = self.cos[:length, None, :]
        sin = self.sin[:length, None, :]

        # Apply rotary embeddings
        x1, x2 = x[..., ::2], x[..., 1::2]  # Split into even and odd dimensions
        x_rotated = torch.cat([x1 * cos - x2 * sin,
                               x1 * sin + x2 * cos], dim=-1)
        return x_rotated

In [None]:
class MHA(nn.Module):
    """Multi-head self-attention followed by a position-wise feed-forward net.

    Queries and keys receive rotary positional embeddings *per head*: the
    rotation is built for ``head_dim`` and applied after the projection has been
    split into heads, so both components of every rotated feature pair stay in
    the same head and each head's q·k score depends on relative position.

    NOTE(review): this block has no residual connections, no normalisation and
    no causal mask — confirm whether that is intentional for this model.

    Args:
        d_model: Width of the input and output features.
        n_heads: Number of attention heads; must divide ``d_model``.
        ctx_len: Maximum sequence length supported by the rotary embedding.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads`` or the
            resulting head size is odd (rotary embeddings rotate feature pairs).
    """

    def __init__(self, d_model, n_heads, ctx_len):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        if self.head_dim % 2 != 0:
            raise ValueError(
                f"head_dim ({self.head_dim}) must be even for rotary embeddings"
            )
        # Rotary embedding sized for a single head, not for the full d_model.
        self.rope = RotaryPositionalEmbedding(self.head_dim, max_seq_len = ctx_len)
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.wo = nn.Linear(d_model, d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model)
        )

    def _split_heads(self, t, rotate):
        """Reshape (length, batch, d_model) to (batch, n_heads, length, head_dim), optionally applying RoPE per head."""
        length, batch, _ = t.shape
        if rotate:
            # Fold the heads into the batch axis so the rotary module sees a
            # (length, batch * n_heads, head_dim) tensor and rotates each head
            # independently with the same position-dependent angles.
            t = t.reshape(length, batch * self.n_heads, self.head_dim)
            t = self.rope(t)
        t = t.reshape(length, batch, self.n_heads, self.head_dim)
        return t.permute(1, 2, 0, 3)

    def forward(self, x):
        """
        Args:
            x: A tensor of shape (length, batch, d_model).

        Returns:
            A tensor of shape (length, batch, d_model).
        """
        # rope only applies to q and k, not v
        q = self._split_heads(self.wq(x), rotate=True)
        k = self._split_heads(self.wk(x), rotate=True)
        v = self._split_heads(self.wv(x), rotate=False)

        # q, k, v are of shape (batch, n_heads, length, head_dim);
        # scores are scaled by sqrt(head_dim) and normalised over the key axis.
        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = F.softmax(attn, dim=-1)

        out = attn @ v
        # Back to sequence-first layout and merge the heads into d_model.
        out = out.permute(2, 0, 1, 3).reshape(x.shape[0], x.shape[1], self.d_model)
        out = self.wo(out)
        out = self.ff(out)
        return out
        


In [None]:
class GPT(nn.Module):
    """Token embedding, a stack of MHA blocks, and a projection to vocabulary logits.

    Args:
        vocab: Number of distinct tokens (size of the embedding table and of the output).
        d_model: Embedding / hidden width.
        ctx_len: Maximum sequence length, forwarded to every MHA block.
        n_heads: Attention heads per block.
        n_layers: Number of stacked MHA blocks.
    """

    def __init__(self, vocab, d_model, ctx_len, n_heads, n_layers):
        super().__init__()
        self.vocab, self.d_model, self.ctx_len = vocab, d_model, ctx_len
        self.embedding = nn.Embedding(vocab, d_model)
        blocks = [MHA(d_model, n_heads, ctx_len) for _ in range(n_layers)]
        self.mha = nn.ModuleList(blocks)
        self.fc = nn.Linear(d_model, vocab)

    def forward(self, x):
        """Return logits for the final sequence position only.

        Args:
            x: Integer token ids, sequence-first — i.e. (length, batch).

        Returns:
            A tensor of shape (batch, vocab).
        """
        hidden = self.embedding(x)
        for block in self.mha:
            hidden = block(hidden)
        # Index 0 is the sequence axis, so [-1] selects the last time step.
        last_step = hidden[-1]
        return self.fc(last_step)

In [None]:
# Instantiate the character-level model from the hyperparameters defined in
# the configuration cell; the vocabulary size is the number of distinct
# characters found in the corpus.
model = GPT(
    vocab=num_unique_chars,
    d_model=d_model,
    ctx_len=ctx_len,
    n_heads=n_heads,
    n_layers=n_layers,
)
# Last expression of the cell: displays the layer / parameter summary.
summary(model)

In [20]:
# Smoke test: push one random sequence of ctx_len token ids (batch of 1,
# sequence-first layout) through the untrained model.
x = torch.randint(low=1, high=num_unique_chars, size=(ctx_len, 1))
y = model(x)

# Raw logits for the last position.
print(y.shape)
print(y)

# Convert logits to probabilities and report the most likely token id.
y = F.softmax(y, dim=-1)
print(torch.argmax(y))

torch.Size([1, 2149])
tensor([[-0.0296, -0.0054,  0.0118,  ...,  0.0462, -0.0051,  0.0271]],
       grad_fn=<AddmmBackward0>)
tensor(1188)
