# 1. Import Libraries

In [16]:
import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# 2. Define Parameters

In [17]:
@dataclass
class GPTConfig:
    """Hyperparameters shared by the model, the dataset and the training cells.

    ``hidden_dim`` and ``head_size`` are derived from ``n_embd`` / ``n_head``
    when they are left as ``None``, so overriding ``n_embd`` or ``n_head``
    keeps the derived sizes consistent. Passing them explicitly still wins.

    Raises:
        ValueError: if ``head_size`` has to be derived and ``n_embd`` is not
            divisible by ``n_head``.
    """

    # Maximum context length in tokens (also the size of the position table).
    block_size: int = 512

    batch_size: int = 12
    n_layer: int = 12
    n_head: int = 12

    # Width of the token / position embeddings and of the residual stream.
    n_embd: int = 768
    # Defaults to n_embd (resolved in __post_init__).
    hidden_dim: Optional[int] = None

    dropout: float = 0.1
    # Per-head width; defaults to n_embd // n_head (resolved in __post_init__).
    head_size: Optional[int] = None

    # Vocabulary size of the official GPT-2 tokenizer.
    vocab_size: int = 50257

    def __post_init__(self):
        # Resolve derived fields per instance. Class-level expressions such as
        # `head_size = n_embd // n_head` are evaluated only once, with the
        # default values, and silently ignore constructor overrides.
        if self.hidden_dim is None:
            self.hidden_dim = self.n_embd
        if self.head_size is None:
            if self.n_head <= 0 or self.n_embd % self.n_head != 0:
                raise ValueError(
                    f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})"
                )
            self.head_size = self.n_embd // self.n_head

# 3. Define GPT Structure

### 3.1 Single-Head Attention

In [None]:
class SingleHeadAttention(nn.Module):
    """A single causal self-attention head.

    Projects the input to query/key/value vectors of width
    ``config.head_size`` and returns the attention-weighted values. Each
    position may only attend to itself and to earlier positions.
    """

    def __init__(self, config):
        super().__init__()
        embed_dim, head_dim = config.n_embd, config.head_size

        # Registration order (key, value, query) is kept as-is: it determines
        # both the random-init order and the parameter order.
        self.key = nn.Linear(embed_dim, head_dim)
        self.value = nn.Linear(embed_dim, head_dim)
        self.query = nn.Linear(embed_dim, head_dim)
        self.head_size = head_dim

        # Lower-triangular (block_size x block_size) matrix of ones. Stored as
        # a buffer so it follows the module across devices without being a
        # trainable parameter.
        causal_mask = torch.ones(config.block_size, config.block_size).tril()
        self.register_buffer("attention_mask", causal_mask)

        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply causal attention to ``x`` of shape (batch, seq_len, n_embd).

        Returns a tensor of shape (batch, seq_len, head_size).
        """
        _, seq_len, _ = x.size()

        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        # Raw similarity of every query with every key: (batch, seq, seq).
        scores = torch.matmul(q, k.transpose(-2, -1))

        # Positions above the diagonal are "future" tokens; -inf becomes a
        # probability of exactly 0 after the softmax.
        is_future = self.attention_mask[:seq_len, :seq_len] == 0
        scores = scores.masked_fill(is_future, float("-inf"))

        # Scale by sqrt(d_k) so the softmax does not saturate for wide heads.
        scores = scores / math.sqrt(self.head_size)
        attn = F.softmax(scores, dim=-1)

        # Dropout is applied to the attention weights, before mixing values.
        attn = self.dropout(attn)

        return attn @ v

### 3.2 Multi-head Attention
Four kinds of weight tensors: the per-head q, k, v projections and the output projection (proj)

In [19]:
class MultiHeadAttention(nn.Module):
    """Runs ``config.n_head`` independent attention heads and merges them.

    The head outputs are concatenated along the feature axis, passed through
    a linear output projection, and then through dropout.
    """

    def __init__(self, config):
        super().__init__()
        head_modules = []
        for _ in range(config.n_head):
            head_modules.append(SingleHeadAttention(config))
        self.heads = nn.ModuleList(head_modules)

        # The projection expects n_embd input features, i.e. the concatenated
        # head outputs must add up to n_embd.
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return the projected concatenation of all head outputs for ``x``."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)

### 3.3 Feed Forward

In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x width, GELU, project back, dropout."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        inner_width = 4 * width  # a SwiGLU variant would use 8/3 instead of 4

        # Created in the same order as they appear in the Sequential so the
        # random initialisation order is unchanged.
        expand = nn.Linear(width, inner_width)
        activation = nn.GELU()
        contract = nn.Linear(inner_width, width)
        drop = nn.Dropout(config.dropout)

        self.net = nn.Sequential(expand, activation, contract, drop)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.net(x)

### 3.4 Block

In [21]:
class Block(nn.Module):
    """One transformer layer: pre-norm attention and pre-norm MLP, each
    wrapped in a residual connection."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.att = MultiHeadAttention(config)
        self.ffn = FeedForward(config)
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward sub-layers."""
        # LayerNorm is applied to the sub-layer input, not to the residual sum.
        attended = self.att(self.ln1(x))
        x = x + attended
        transformed = self.ffn(self.ln2(x))
        return x + transformed

### 3.5 GPT (embedding, position, norm, mlp, block)

In [41]:
class GPT(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings plus learned absolute position embeddings are passed
    through ``config.n_layer`` blocks and a final LayerNorm, then projected
    to vocabulary logits by a bias-free linear head.

    Possible modernisations noted by the author: position embedding -> RoPE,
    LayerNorm -> RMSNorm, MLP -> SwiGLU, multi-head -> grouped-query attention.
    """

    def __init__(self, config):
        super().__init__()

        self.block_size = config.block_size
        self.token_embedding_table = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embedding_table = nn.Embedding(config.block_size, config.n_embd)

        self.blocks = nn.Sequential(
            *[Block(config) for _ in range(config.n_layer)]
        )
        self.ln_final = nn.LayerNorm(config.n_embd)

        # The logits go straight into a softmax, so the head has no bias.
        # NOTE: lm_head and token_embedding_table are separate parameters
        # here; small models often tie them to save parameters, but this
        # class does not.
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02^2); zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: token ids of shape (batch, seq_len); seq_len must not exceed
                ``block_size``.
            targets: optional token ids with the same shape as ``idx``.

        Returns:
            ``(logits, loss)``. Without targets, logits have shape
            (batch, seq_len, vocab_size) and loss is ``None``. With targets,
            logits are flattened to (batch * seq_len, vocab_size) and loss is
            the mean cross-entropy.

        Raises:
            ValueError: if seq_len is larger than ``block_size``.
        """
        batch, seq_len = idx.size()
        if seq_len > self.block_size:
            # Without this check the position-embedding lookup fails with an
            # opaque out-of-range index error.
            raise ValueError(
                f"Sequence length {seq_len} exceeds block_size {self.block_size}"
            )

        # (batch, seq_len, n_embd)
        token_emb = self.token_embedding_table(idx)

        # Positions are created on the same device as idx; shape
        # (seq_len, n_embd), broadcast over the batch in the sum below.
        pos_emb = self.position_embedding_table(
            torch.arange(seq_len, device=idx.device)
        )

        x = token_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_final(x)

        # (batch, seq_len, vocab_size)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            batch, seq_len, vocab_size = logits.size()
            logits = logits.view(batch * seq_len, vocab_size)
            targets = targets.view(batch * seq_len)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Runs without gradient tracking. The module's train/eval mode is not
        changed here, so call ``model.eval()`` first to disable dropout.

        Args:
            idx: token ids of shape (batch, seq_len) used as the prompt.
            max_new_tokens: number of tokens to append.
            temperature: divides the logits before the softmax; 1.0 keeps the
                model distribution unchanged.
            top_k: if given, sample only among the ``top_k`` most likely tokens.

        Returns:
            Tensor of shape (batch, seq_len + max_new_tokens).

        Raises:
            ValueError: if ``temperature`` is not positive or ``top_k`` < 1.
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        if top_k is not None and top_k < 1:
            raise ValueError("top_k must be at least 1")

        for _ in range(max_new_tokens):
            # Keep only the last block_size tokens if the context grew too long.
            idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            # Only the prediction for the last position matters: (B, vocab_size).
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                k = min(top_k, logits.size(-1))
                kth_best = torch.topk(logits, k).values[:, [-1]]
                logits = logits.masked_fill(logits < kth_best, float("-inf"))
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx

# 4. Build Input Dataset

In [23]:
class MyDataset(Dataset):
    """Next-token-prediction dataset built from a JSON-lines file.

    Each line of ``path`` is expected to be a JSON object with a string
    ``"text"`` field. Texts are tokenized with the GPT-2 tiktoken encoding,
    joined with ``<|endoftext|>`` separators and cut into chunks of
    ``block_size + 1`` tokens (one extra token so inputs and targets can be
    shifted by one).

    Args:
        path: path to the ``.jsonl`` file (read as UTF-8).
        block_size: number of input tokens per sample.
        max_lines: read at most this many lines; ``None`` reads the whole file.

    Attributes:
        skipped_lines: number of lines ignored because they were not valid
            JSON or had no string ``"text"`` field.
    """

    def __init__(self, path, block_size=512, max_lines=1000):
        import json
        import tiktoken

        self.enc = tiktoken.get_encoding("gpt2")
        self.block_size = block_size  # maximum number of positions per sample

        # Separator between documents. In the GPT-2 encoding this is id 50256.
        self.eos_token = self.enc.encode(
            "<|endoftext|>",
            allowed_special={"<|endoftext|>"}
        )[0]

        self.max_lines = max_lines
        self.skipped_lines = 0

        raw_data = []
        # Explicit UTF-8: the platform default encoding is not reliable for
        # non-ASCII corpora.
        with open(path, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                if self.max_lines is not None and i >= self.max_lines:
                    break
                try:
                    text = json.loads(line.strip())['text']
                except (json.JSONDecodeError, KeyError, TypeError):
                    # Best-effort loading: malformed lines are skipped, but
                    # counted so the loss of data is visible.
                    self.skipped_lines += 1
                    continue
                if not isinstance(text, str):
                    self.skipped_lines += 1
                    continue
                raw_data.append(text)

        full_encoded = []
        for text in raw_data:
            full_encoded.extend(self.enc.encode(text))
            full_encoded.append(self.eos_token)

        self.encoded_data = self._chunk_tokens(
            full_encoded, self.block_size, self.eos_token
        )

    @staticmethod
    def _chunk_tokens(tokens, block_size, pad_token):
        """Split ``tokens`` into windows of ``block_size + 1`` tokens.

        Windows start every ``block_size`` tokens, so consecutive windows
        share one token. A short final window is right-padded with
        ``pad_token``.
        """
        chunks = []
        for start in range(0, len(tokens), block_size):
            chunk = tokens[start:start + block_size + 1]
            missing = block_size + 1 - len(chunk)
            if missing > 0:
                chunk = chunk + [pad_token] * missing
            chunks.append(chunk)
        return chunks

    def __len__(self):
        return len(self.encoded_data)

    def __getitem__(self, idx):
        """Return ``(x, y)`` where ``y`` is ``x`` shifted left by one token."""
        chunk = self.encoded_data[idx]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y

    def encode(self, text):
        """Tokenize ``text`` with the GPT-2 encoding."""
        return self.enc.encode(text)

    def decode(self, ids):
        """Turn token ids (list, numpy array or tensor) back into text."""
        # Tensors and numpy arrays are converted to plain Python ints first.
        if hasattr(ids, "tolist"):
            ids = ids.tolist()
        return self.enc.decode(ids)

In [24]:
# Take the sequence length and batch size from the same config the model uses,
# so the data pipeline cannot drift away from the model's block_size.
data_config = GPTConfig()

# Tokenize the corpus once; the full dataset keeps its own name so re-running
# this cell never splits an already-split subset.
full_dataset = MyDataset('data.jsonl', block_size=data_config.block_size)

# Seed the 90/10 split so the validation set is the same after every kernel
# restart; otherwise validation loss is not comparable between runs.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset, [0.9, 0.1], generator=split_generator
)

train_loader = DataLoader(train_dataset, batch_size=data_config.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=data_config.batch_size, shuffle=False)

# 5. Run Related Functions

In [25]:
# Backend preference: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device_name = "mps"
else:
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print("Using {} device".format(device))

Using mps device


In [29]:
model = GPT(GPTConfig())
model = model.to(device)

# Report the model size in millions of parameters.
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params / 1e6} M")

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Cosine learning-rate decay. The training loop below calls scheduler.step()
# once per batch for 100 epochs, so the schedule has to span that many steps;
# with a shorter T_max the learning rate reaches its minimum early and then
# climbs back up.
NUM_EPOCHS = 100  # keep in sync with the epoch count of the training loop
total_steps = max(1, NUM_EPOCHS * len(train_loader))
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)

Total parameters: 162.643968 M


### Training Function

In [None]:
import os

# The training loop writes one checkpoint file per epoch into this directory;
# create it up front (no error if it already exists).
checkpoint_dir = 'checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)

In [None]:
# Train loop
def train(model, optimizer, scheduler, train_loader, val_loader, device, epoch=None):
    """Run one training epoch and return the summed batch losses.

    Args:
        model: module whose forward returns ``(logits, loss)``.
        optimizer: optimizer stepped once per batch.
        scheduler: learning-rate scheduler, also stepped once per batch.
        train_loader: iterable of ``(x, y)`` batches.
        val_loader: unused; kept so existing positional calls keep working.
        device: device the batches are moved to.
        epoch: epoch number shown in progress logs. It is passed explicitly
            rather than read from a notebook global, so the function also
            works when called outside the epoch loop.

    Returns:
        Sum of the per-batch losses (divide by ``len(train_loader)`` for the mean).
    """
    model.train()
    log_prefix = '' if epoch is None else f'Epoch: {epoch}, '
    total_loss = 0
    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)

        # Forward
        _, loss = model(x, targets=y)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Learning-rate schedule advances per batch, not per epoch.
        scheduler.step()

        batch_loss = loss.item()
        total_loss += batch_loss

        if batch_idx % 100 == 0:
            print(f'{log_prefix}Batch: {batch_idx}, Loss: {batch_loss:.4f}')
    return total_loss

# NOTE: this name shadows Python's built-in eval() inside the notebook; it is
# kept because callers use it.
def eval(model, val_loader, device):
    """Return the summed loss of ``model`` over ``val_loader`` (no gradients).

    Leaves the model in eval mode.
    """
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            _, loss = model(x, targets=y)
            val_loss += loss.item()
    return val_loss


def _mean_loss(total, num_batches):
    """Average a summed loss; NaN when there were no batches."""
    return total / num_batches if num_batches else float('nan')


num_epochs = 100
os.makedirs('checkpoints', exist_ok=True)

for epoch in range(num_epochs):
    train_loss = train(model, optimizer, scheduler, train_loader, val_loader, device, epoch=epoch)
    val_loss = eval(model, val_loader, device)

    # A tiny dataset can leave the validation split empty; avoid dividing by zero.
    avg_train_loss = _mean_loss(train_loss, len(train_loader))
    avg_val_loss = _mean_loss(val_loss, len(val_loader))
    print(f'Epoch: {epoch}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

    # Save everything needed to resume training after this epoch.
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_loss': avg_val_loss,
    }
    torch.save(checkpoint, f'checkpoints/model_epoch_{epoch}.pt')


Epoch: 0, Batch: 0, Loss: 3.1839
Epoch: 0, Batch: 100, Loss: 3.3189
Epoch: 0, Batch: 200, Loss: 3.3341
Epoch: 0, Train Loss: 3.2212, Val Loss: 3.2327
Epoch: 1, Batch: 0, Loss: 3.1749
Epoch: 1, Batch: 100, Loss: 3.2260
Epoch: 1, Batch: 200, Loss: 3.1955
Epoch: 1, Train Loss: 3.2211, Val Loss: 3.2327
Epoch: 2, Batch: 0, Loss: 3.1914
Epoch: 2, Batch: 100, Loss: 3.1446
Epoch: 2, Batch: 200, Loss: 3.2590
Epoch: 2, Train Loss: 3.2215, Val Loss: 3.2327
Epoch: 3, Batch: 0, Loss: 3.1502


KeyboardInterrupt: 

# 6. Run model

In [43]:
# Fail fast if the checkpoint that the loading cell below reads
# (model_epoch_1.pt) is missing; checking a different epoch's file would let
# the load fail later with a less helpful error.
checkpoint_path = "checkpoints/model_epoch_1.pt"
if not os.path.exists(checkpoint_path):
    raise FileNotFoundError(f"Checkpoint {checkpoint_path} not found")

In [44]:
# Same backend preference as for training: MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    backend = "mps"
elif torch.cuda.is_available():
    backend = "cuda"
else:
    backend = "cpu"
device = torch.device(backend)
print(f"Using {device} device")

Using mps device


In [45]:
checkpoint_path = "checkpoints/model_epoch_1.pt"

# Rebuild the architecture with the default hyperparameters; they must match
# the ones the checkpoint was trained with, or load_state_dict will fail.
model = GPT(GPTConfig())
model = model.to(device)

# torch.load unpickles the file, so only load checkpoints you created yourself.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

# Switch to inference mode (disables dropout); as the last expression this
# also displays the module structure.
model.eval()

GPT(
  (token_embedding_table): Embedding(50257, 768)
  (position_embedding_table): Embedding(512, 768)
  (blocks): Sequential(
    (0): Block(
      (att): MultiHeadAttention(
        (heads): ModuleList(
          (0-11): 12 x SingleHeadAttention(
            (key): Linear(in_features=768, out_features=64, bias=True)
            (value): Linear(in_features=768, out_features=64, bias=True)
            (query): Linear(in_features=768, out_features=64, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ffn): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=3072, out_features=768, bias=True)
          (3): Dropout(p=0.1, inplace=False)
        )
      )
      (ln1): LayerNorm((768,), eps=1e-

In [46]:
# Instantiated for its encode()/decode() helpers, which the generation cell uses.
dataset = MyDataset(path='data.jsonl')

In [None]:
prompt = "你好"

# Encode the prompt into a (1, seq_len) batch on the model's device.
encoded_prompt = dataset.encode(prompt)
input_ids = torch.tensor([encoded_prompt], dtype=torch.long, device=device)

# Sampling needs no gradients; disabling them avoids building an autograd
# graph for each of the 50 forward passes.
with torch.no_grad():
    generated_ids = model.generate(input_ids, max_new_tokens=50)

# Hand the tokenizer a plain list of Python ints rather than a numpy array.
generated_text = dataset.decode(generated_ids[0].tolist())
print("Generated text:\n", generated_text)

Generated text:
 你灶增枷压用向绿跡的。其试的激大者�各，在制欢罗
