<a href="https://colab.research.google.com/github/leileqiTHU/Data-Analysis/blob/master/workshop/char_rnn/Char-RNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-11-17 03:25:49--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.2’


2025-11-17 03:25:49 (30.2 MB/s) - ‘input.txt.2’ saved [1115394/1115394]



In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random
import math

# =========================
# 1. Config
# =========================

# Fixed seed so that shuffling, weight initialisation and text sampling start
# from the same RNG state on every "Restart & Run All".
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)  # seeds PyTorch's default generator

# Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

TEXT_PATH = "input.txt"  # path to your corpus
SEQ_LEN = 100                  # length of input sequence (characters)
BATCH_SIZE = 2048              # sequences per optimisation step
EMBED_DIM = 256                # size of each character embedding
HIDDEN_DIM = 512               # LSTM hidden state size
NUM_LAYERS = 2                 # number of stacked LSTM layers
LR = 3e-3                      # AdamW learning rate
EPOCHS = 10                    # increase for better results


# =========================
# 2. Data Loading & Preprocessing
# =========================

# Read the entire corpus into memory as one string.
with open(TEXT_PATH, mode="r", encoding="utf-8") as f:
    text = f.read()

print(f"Corpus length (chars): {len(text)}")

# Character-level vocabulary: every distinct character, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
print(f"Vocab size: {vocab_size}")

# Lookup tables in both directions: character -> id and id -> character.
stoi = {}
for idx, ch in enumerate(chars):
    stoi[ch] = idx
itos = dict(enumerate(chars))

def encode(s: str):
    """Map each character of ``s`` to its integer id via ``stoi``.

    Returns a list of ints with one entry per character. A character that is
    not a key of ``stoi`` raises ``KeyError``.
    """
    ids = []
    for ch in s:
        ids.append(stoi[ch])
    return ids

def decode(indices):
    """Turn an iterable of integer ids back into a string via ``itos``.

    An id that is not a key of ``itos`` raises ``KeyError``.
    """
    pieces = (itos[token_id] for token_id in indices)
    return "".join(pieces)

# The whole corpus as a single 1-D tensor of 64-bit integer character ids.
data = torch.tensor(encode(text), dtype=torch.int64)


# =========================
# 3. Dataset & Dataloader
# =========================

class ShakespeareDataset(Dataset):
    """Sliding-window next-token dataset over a 1-D sequence of token ids.

    Item ``i`` is the pair ``(x, y)`` where ``x = data[i : i + seq_len]`` and
    ``y = data[i + 1 : i + seq_len + 1]``, i.e. ``y[t]`` is the token that
    follows ``x[t]``.

    Args:
        data_tensor: 1-D indexable sequence of token ids (sliced, never copied).
        seq_len: number of tokens in each input window; must be >= 1.

    Raises:
        ValueError: if ``seq_len`` is smaller than 1.
    """

    def __init__(self, data_tensor, seq_len):
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        self.data = data_tensor
        self.seq_len = seq_len

    def __len__(self):
        # A window starting at idx needs data[idx + seq_len] as its last
        # target, so valid starts are 0 .. len(data) - seq_len - 1.
        # Clamp at 0: a negative value would make len() itself raise when the
        # corpus is not longer than seq_len.
        return max(0, len(self.data) - self.seq_len)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` window pair starting at ``idx``.

        Negative indices count from the end, like a list. Out-of-range
        indices raise ``IndexError`` instead of silently returning truncated
        or empty windows (slicing never fails on its own); this also lets
        plain ``for x, y in dataset`` iteration terminate.
        """
        n = len(self)
        if idx < 0:
            idx = idx + n
        if not 0 <= idx < n:
            raise IndexError(f"index out of range for dataset of length {n}")
        x = self.data[idx : idx + self.seq_len]
        y = self.data[idx + 1 : idx + self.seq_len + 1]
        return x, y

# Every valid window start in the corpus is one training example.
dataset = ShakespeareDataset(data_tensor=data, seq_len=SEQ_LEN)

# Shuffled mini-batches; the trailing partial batch is dropped so every batch
# holds exactly BATCH_SIZE sequences.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)

print(f"Number of sequences: {len(dataset)}")


# =========================
# 4. Model Definition
# =========================

class CharLSTM(nn.Module):
    """Character-level language model: embedding -> stacked LSTM -> linear head."""

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers):
        super().__init__()
        # Token ids -> dense vectors.
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        # Recurrent core; batch_first keeps tensors as (batch, seq_len, features).
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        # Projects every hidden state to one score per vocabulary entry.
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, x, hidden=None):
        """Compute next-token logits for every position of ``x``.

        Args:
            x: integer ids of shape (batch, seq_len).
            hidden: optional LSTM state from a previous call; ``None`` lets
                the LSTM start from its default initial state.

        Returns:
            ``(logits, hidden)`` where ``logits`` has shape
            (batch, seq_len, vocab_size) and ``hidden`` is the LSTM state
            after the last position.
        """
        embedded = self.embed(x)                              # (batch, seq_len, embed_dim)
        lstm_out, next_hidden = self.lstm(embedded, hidden)   # (batch, seq_len, hidden_dim)
        return self.fc(lstm_out), next_hidden

# Build the network, then move its parameters onto the selected device.
model = CharLSTM(
    vocab_size=vocab_size,
    embed_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
)
model = model.to(DEVICE)
print(model)

# Next-character prediction is scored as classification over the vocabulary.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LR)


# =========================
# 5. Training Loop
# =========================
from tqdm import tqdm
def train():
    """Train the global ``model`` on ``dataloader`` for ``EPOCHS`` epochs.

    Each epoch runs one pass over ``dataloader`` with gradient-norm clipping
    at 1.0, prints the mean batch loss, and then prints a 400-character
    sample generated from the prompt ``"ROMEO:"``.

    Raises:
        RuntimeError: if ``dataloader`` yields no batches (for example when
            the dataset is smaller than one batch and the partial batch is
            dropped), since no training could happen.
    """
    print("start training...")
    if len(dataloader) == 0:
        raise RuntimeError("dataloader is empty: nothing to train on")

    for epoch in range(1, EPOCHS + 1):
        # Re-enter training mode at the start of EVERY epoch: the sample
        # generation at the end of the previous epoch calls model.eval(), and
        # the cuDNN LSTM refuses to run backward while in eval mode.
        model.train()

        total_loss = 0.0
        num_batches = 0

        for x, y in tqdm(dataloader, total=len(dataloader)):
            x = x.to(DEVICE)
            y = y.to(DEVICE)

            optimizer.zero_grad()
            logits, _ = model(x)

            # reshape for cross-entropy:
            # logits: (batch*seq_len, vocab), y: (batch*seq_len)
            loss = criterion(
                logits.reshape(-1, vocab_size),
                y.reshape(-1)
            )

            loss.backward()
            # Cap the global gradient norm to keep LSTM updates stable.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / num_batches
        print(f"Epoch [{epoch}/{EPOCHS}] - loss: {avg_loss:.4f}")

        # sample a bit of text each epoch
        print("\nSample generation:")
        print(generate_text("ROMEO:", max_new_tokens=400))
        print("=" * 80)


# =========================
# 6. Text Generation
# =========================

@torch.no_grad()
def generate_text(start_text="ROMEO:", max_new_tokens=500, temperature=1.0):
    """Sample text from the global ``model``, continuing ``start_text``.

    The prompt is fed through the model once to build up the LSTM state; after
    that, one character at a time is sampled from the softmax over the last
    position's logits and fed back in as the next input.

    Args:
        start_text: non-empty prompt; every character must be in the vocabulary
            (``encode`` raises ``KeyError`` otherwise).
        max_new_tokens: number of characters to sample after the prompt.
        temperature: positive divisor applied to the logits before softmax;
            values below 1 sharpen the distribution, values above 1 flatten it.

    Returns:
        The prompt followed by the sampled characters, as one string.

    Raises:
        ValueError: if ``start_text`` is empty or ``temperature`` is not > 0.
    """
    if not start_text:
        raise ValueError("start_text must contain at least one character")
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    # Remember the current mode so sampling in the middle of training does not
    # leave the model stuck in eval mode for the following epochs.
    was_training = model.training
    model.eval()
    try:
        # encode the prompt
        input_ids = torch.tensor(
            encode(start_text), dtype=torch.long, device=DEVICE
        ).unsqueeze(0)  # (1, seq)
        hidden = None

        generated = input_ids.tolist()[0]

        for _ in range(max_new_tokens):
            logits, hidden = model(input_ids, hidden)
            # select last time step
            logits = logits[:, -1, :] / temperature  # (1, vocab_size)
            probs = torch.softmax(logits, dim=-1)

            # sample next char
            next_id = torch.multinomial(probs, num_samples=1).item()
            generated.append(next_id)

            # next input is the sample we just produced; the LSTM state in
            # `hidden` already carries everything that came before it
            input_ids = torch.tensor([[next_id]], dtype=torch.long, device=DEVICE)
    finally:
        model.train(was_training)

    return decode(generated)


if __name__ == "__main__":
    # Run the full training schedule.
    train()

    # Persist only the learned weights (the state dict), not the module object.
    torch.save(model.state_dict(), "char_lstm_shakespeare.pt")
    print("Model saved to char_lstm_shakespeare.pt")

    # Finish with a longer sample at a slightly lower temperature.
    print("\nFinal sample:")
    print(
        generate_text(
            start_text="To be, or not to be",
            max_new_tokens=600,
            temperature=0.8,
        )
    )


Corpus length (chars): 1115394
Vocab size: 65
Number of sequences: 1115294
CharLSTM(
  (embed): Embedding(65, 256)
  (lstm): LSTM(256, 512, num_layers=2, batch_first=True)
  (fc): Linear(in_features=512, out_features=65, bias=True)
)
start training...


 21%|██        | 115/544 [02:52<10:45,  1.51s/it]