In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import time

class PokemonDataset(Dataset):
    def __init__(self, file_path, max_seq_len):
        with open(file_path, 'r', encoding='utf-8') as f:
            self.text = f.read().lower().split()
        vocab = sorted(set(self.text))
        self.vocab = {word: idx for idx, word in enumerate(vocab)}
        self.itos = {idx: word for word, idx in self.vocab.items()}
        self.vocab_size = len(self.vocab)
        self.max_seq_len = max_seq_len

        self.tokens = [self.vocab[word] for word in self.text]
        self.data = [self.tokens[i : i + max_seq_len] for i in range(len(self.tokens) - max_seq_len)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        seq = self.data[idx]
        x = seq[:-1]
        y = seq[1:]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)


class GPTLanguageModel(nn.Module):
    def __init__(self, vocab_size, block_size, n_embd=128, n_layer=2, n_head=4):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.block_size = block_size

    def forward(self, idx, targets=None):
        # idx: (B, T)
        B, T = idx.size()
        token_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = token_emb + pos_emb  # (B, T, n_embd)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        if targets is None:
            loss = None
        else:
            # compute the loss
            logits = logits.view(-1, logits.size(-1))
            targets = targets.view(-1)
            loss = nn.functional.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            probs = torch.softmax(logits, dim=-1)
            next_idx = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_idx), dim=1)
        return idx
    
class Block(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        self.sa = MultiHeadAttention(n_embd, n_head)
        self.ffwd = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.head_dim = n_embd // n_head
        self.key = nn.Linear(n_embd, n_embd)
        self.query = nn.Linear(n_embd, n_embd)
        self.value = nn.Linear(n_embd, n_embd)
        self.proj = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        B, T, C = x.size()
        k = self.key(x).view(B, T, self.n_head, self.head_dim).transpose(1,2)
        q = self.query(x).view(B, T, self.n_head, self.head_dim).transpose(1,2)
        v = self.value(x).view(B, T, self.n_head, self.head_dim).transpose(1,2)
        att = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        att = torch.softmax(att, dim=-1)
        y = att @ v
        y = y.transpose(1,2).contiguous().view(B, T, C)
        y = self.proj(y)
        return y

class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd)
        )
    def forward(self, x):
        return self.net(x)


# -------------------------------
# training loop
# -------------------------------

def estimate_perplexity(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    total_tokens = 0

    with torch.no_grad():
        for x_batch, y_batch in dataloader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            logits, _ = model(x_batch)
            # logits shape: (batch_size, seq_len, vocab_size)
            logits = logits.view(-1, logits.size(-1))
            y_batch = y_batch.view(-1)
            loss = criterion(logits, y_batch)
            total_loss += loss.item() * y_batch.size(0)
            total_tokens += y_batch.size(0)
    avg_loss = total_loss / total_tokens
    perplexity = torch.exp(torch.tensor(avg_loss))
    return avg_loss, perplexity.item()


def train(use_metrics=False):
    import time
    import torch
    from torch import nn, optim
    from torch.utils.data import DataLoader

    if torch.backends.mps.is_available():
        device = torch.device("mps")  # use M1 GPU
    else:
        device = torch.device("cpu")  # use CPU if no M1 GPU
   
    max_seq_len = 32  # the maximum sequence length
    batch_size = 32
    num_epochs = 10
    learning_rate = 0.001
    n_embd = 32
    n_layer = 2
    n_head = 4

    # 
    dataset = PokemonDataset(file_path='pokemon.txt', max_seq_len=max_seq_len)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    # initialize the model
    model = GPTLanguageModel(vocab_size=dataset.vocab_size,
                             block_size=max_seq_len - 1,
                             n_embd=n_embd,
                             n_layer=n_layer,
                             n_head=n_head)
    model = model.to(device)

    print("Model parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))
    print("Vocabulary size:", dataset.vocab_size)
    total_tokens = len(dataset) * (max_seq_len - 1) * num_epochs
    print(f"Total training tokens: {total_tokens/1e6:.2f}M")

    # define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    print("Debut d'entrainement")
    total_training_time = 0.0  # record the total training time
    overall_start_time = time.time()  # record the start time of the whole training process

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0

        # record the start time of this epoch
        epoch_train_start = time.time()

        for x_batch, y_batch in dataloader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            # forward pass
            _, loss = model(x_batch, y_batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # compute the time taken for this epoch
        epoch_train_time = time.time() - epoch_train_start
        total_training_time += epoch_train_time

        avg_train_loss = total_loss / len(dataloader)
        
        if use_metrics:
            # compute perplexity on the validation set
            eval_start = time.time()
            avg_loss, perplexity = estimate_perplexity(model, dataloader, criterion, device)
            eval_time = time.time() - eval_start
            print(f"Epoch {epoch+1}/{num_epochs}, Cumul average Loss: {avg_train_loss:.4f}, "
                  f"Train average Loss: {avg_loss:.4f}, Perplexity: {perplexity:.4f}, "
                  f"Train Time: {epoch_train_time:.2f}s, Eval Time: {eval_time:.2f}s")
        else:
            print(f"Epoch {epoch+1}/{num_epochs}, Train average Loss: {avg_train_loss:.4f}, "
                  f"Train Time: {epoch_train_time:.2f}s")

    overall_end_time = time.time()
    print(f"total training time (without evaluation): {total_training_time:.2f}s")
    print(f"total time (with eval): {overall_end_time - overall_start_time:.2f}s")

if __name__ == '__main__':
    train(use_metrics=True)

Model parameters: 719039
Vocabulary size: 10655
Total training tokens: 27.65M
开始训练
Epoch 1/10, Train Loss: 3.1741, Eval Loss: 0.5180, Perplexity: 1.6787, Train Time: 106.33s, Eval Time: 28.22s
Epoch 2/10, Train Loss: 0.2924, Eval Loss: 0.2090, Perplexity: 1.2325, Train Time: 89.34s, Eval Time: 46.21s
Epoch 3/10, Train Loss: 0.2021, Eval Loss: 0.1890, Perplexity: 1.2081, Train Time: 110.98s, Eval Time: 43.50s
Epoch 4/10, Train Loss: 0.1905, Eval Loss: 0.1807, Perplexity: 1.1981, Train Time: 110.59s, Eval Time: 43.74s
Epoch 5/10, Train Loss: 0.1840, Eval Loss: 0.1742, Perplexity: 1.1903, Train Time: 112.92s, Eval Time: 42.81s
Epoch 6/10, Train Loss: 0.1787, Eval Loss: 0.1701, Perplexity: 1.1854, Train Time: 110.73s, Eval Time: 41.87s
Epoch 7/10, Train Loss: 0.1744, Eval Loss: 0.1650, Perplexity: 1.1794, Train Time: 108.98s, Eval Time: 43.56s
Epoch 8/10, Train Loss: 0.1706, Eval Loss: 0.1616, Perplexity: 1.1754, Train Time: 114.55s, Eval Time: 43.05s
Epoch 9/10, Train Loss: 0.1672, Eval L