Будем обучать посимвольную RNN на текстах Шекспира

In [1]:
! mkdir shake

In [2]:
! wget -P shake https://raw.githubusercontent.com/cedricdeboom/character-level-rnn-datasets/refs/heads/master/datasets/shakespeare.txt

--2025-03-18 12:21:23--  https://raw.githubusercontent.com/cedricdeboom/character-level-rnn-datasets/refs/heads/master/datasets/shakespeare.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 6347705 (6.1M) [text/plain]
Saving to: ‘shake/shakespeare.txt’


2025-03-18 12:21:24 (8.38 MB/s) - ‘shake/shakespeare.txt’ saved [6347705/6347705]



In [3]:
import re

def preprocess(text):
    return re.sub(r'[^a-zA-Z0-9\s.,\'?!]', '', text)

In [4]:
from pathlib import Path

def save_preprocessed(path: str) -> Path:
    """
    Preprocess text and save it to disk.

    Returns:
        path to saved text
    """
    path_ = Path(path)
    with path_.open() as file:
        text = file.read()
    processed = preprocess(text)
    res = path_.parent / "processed.txt"
    with res.open("w") as file:
        file.write(processed)
    return res

In [5]:
path = save_preprocessed("shake/shakespeare.txt")

In [6]:
from math import floor
def train_test_split(path: str, test_size: float = 0.1) -> tuple[Path, Path]:
    """
    Split dataset to train and test.
    
    Returns:
        train_path and test_path
    """
    path_ = Path(path)
    with path_.open() as file:
        text = file.read()
    test_num_chars = floor(len(text) * test_size)
    test_text = text[-test_num_chars:]
    train_text = text[:-test_num_chars]
    train_path = path_.parent / "train.txt"
    with train_path.open("w") as file:
        file.write(train_text)
    test_path = path_.parent / "test.txt"
    with test_path.open("w") as file:
        file.write(test_text)
    return train_path, test_path

In [7]:
train_path, test_path = train_test_split(path)

In [8]:
import string

class CharTokenizer:
    def __init__(self) -> None:
        self.id2token = list(string.ascii_letters + string.digits + ".,\'?! \n")
        self.token2id = {char: token_id for token_id, char in enumerate(self.id2token)}
    
    def encode(self, txt: str) -> list[int]:
        return [self.token2id[tok] for tok in txt]
    
    def decode(self, token_ids: list[int]) -> str:
        return [self.id2token[tok_id] for tok_id in token_ids]

In [9]:
import torch
from torch.utils.data import Dataset
from pathlib import Path


class ShakespeareTexts(Dataset):
    def __init__(self, path: str, seq_length: int) -> None:
        self.seq_length = seq_length
        self.sequences = self._load_sequences(path)

        self.tokenizer = CharTokenizer()
    
    def _load_sequences(self, path: str) -> list[str]:
        res = []
        with Path(path).open() as file:
            while True:
                sequence = file.read(self.seq_length)
                if not sequence:
                    break
                res.append(sequence)
        if len(res[-1]) != self.seq_length:
            res.pop()
        return res
    
    def __len__(self) -> int:
        return len(self.sequences)
    
    def __getitem__(self, index: int) -> torch.LongTensor:
        """
        Returns:
            tokenized sequence
        """
        seq = self.sequences[index]
        token_ids = self.tokenizer.encode(seq)
        return torch.LongTensor(token_ids)
        

In [None]:
train_dataset = ShakespeareTexts(train_path, seq_length=32)
test_dataset = ShakespeareTexts(test_path, seq_length=32)
len(train_dataset), len(test_dataset)

(177198, 19688)

In [11]:
print(train_dataset[0]),
print(train_dataset.tokenizer.decode(train_dataset[0]))

tensor([53, 58, 52, 55, 68, 68, 26, 37, 37, 44, 67, 48, 30, 37, 37, 67, 45, 33,
        26, 45, 67, 30, 39, 29, 44, 67, 48, 30, 37, 37, 68, 68])
['1', '6', '0', '3', '\n', '\n', 'A', 'L', 'L', 'S', ' ', 'W', 'E', 'L', 'L', ' ', 'T', 'H', 'A', 'T', ' ', 'E', 'N', 'D', 'S', ' ', 'W', 'E', 'L', 'L', '\n', '\n']


In [12]:
vocab_size = len(train_dataset.tokenizer.id2token)
vocab_size

69

In [13]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [14]:
from torch import nn

class LanguageModel(nn.Module):
    def __init__(self, vocab_size: int, hidden_dim: int, embedding_dim: int, dropout_rate: float) -> None:
        super().__init__()

        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.rnn = nn.GRUCell(input_size=embedding_dim, hidden_size=hidden_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.lm_head = nn.Linear(hidden_dim, out_features=vocab_size, bias=False)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Obtain logits for next token.

        Args:
            x: tensor with token ids, shape (B,T)
        
        Returns:
            logits: tensor with shape (B,T,vocab_size)
        """

        B, T = x.shape
        outputs = []
        for t in range(T):
            # (B,C)
            embedded_input = self.embedding(x[:, t])
            cell_output = self.rnn(embedded_input)
            outputs.append(cell_output)
        
        # (B,T,C)
        outputs = self.dropout(torch.stack(outputs, dim=1))

        # (B,T,vocab_size)
        return self.lm_head(outputs)

In [15]:
import torch

def collate_sequences(batch: list[torch.LongTensor]) -> tuple[torch.LongTensor, torch.LongTensor]:
    """
    Collate function for language modeling.

    Args:
        batch: List of tensors, where each tensor is a sequence of token IDs.

    Returns:
        tuple: A tuple containing:
            - inputs: Batch of input sequences.
            - targets: Batch of target sequences.
    """
    # (B,T)
    batch_tensor = torch.stack(batch)

    # (B,T-1)
    inputs = batch_tensor[:, :-1]

    # (B,T-1)
    targets = batch_tensor[:, 1:]

    return inputs, targets

In [16]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True, num_workers=2, collate_fn=collate_sequences)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, drop_last=False, num_workers=2, collate_fn=collate_sequences)

In [17]:
import torch.nn.functional as F
from tqdm import tqdm

def evaluate(model: nn.Module, device: torch.DeviceObjType, data_loader: DataLoader) -> tuple[float, float]:
    model.eval()
    loss = 0
    correct = 0
    length = 0
    
    with torch.no_grad():
        for batch_idx, batch in enumerate(data_loader):
            # (B,T) and (B,T)
            inputs, targets = batch
            targets = targets.to(device)
            
            # (B,T,vocab_size)
            outputs: torch.Tensor = model(inputs.to(device))
            
            # (B,vocab_size,T)
            outputs = outputs.permute(0, 2, 1)

            loss += F.cross_entropy(outputs, targets, reduction="sum").item()

            # (B,T)
            pred = outputs.argmax(dim=1) 

            correct += (pred == targets).sum().item()
            length += inputs.shape[0] * inputs.shape[1]

    return loss / length, correct / length

def train(
        model: nn.Module,
        device: torch.DeviceObjType,
        train_loader: DataLoader,
        test_loader: DataLoader,
        n_epoch: int, optimizer: torch.optim.Optimizer,
        max_norm: float | None = None,
        scheduler: torch.optim.lr_scheduler.StepLR | None = None
    ):
    train_history, test_history = {'loss':[], 'acc':[]}, {'loss':[], 'acc':[]}
    
    steps_in_epoch = len(train_loader)

    for epoch in range(1, n_epoch + 1):
        model.train()
        for batch_idx, batch in tqdm(enumerate(train_loader), total=steps_in_epoch):
            # (B,T) and (B,T)
            inputs, targets = batch
            
            # (B,T,vocab_size)
            outputs: torch.Tensor = model(inputs.to(device))
            
            # (B,vocab_size,T)
            outputs = outputs.permute(0, 2, 1)

            loss = F.cross_entropy(outputs, targets.to(device), reduction="mean")
            
            optimizer.zero_grad()
            loss.backward()
            
            if max_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)

            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            
        loss, acc = evaluate(model, device, train_loader)
        train_history['loss'].append(loss), train_history['acc'].append(acc)
        print("Train Loss:", loss, "Train Accuracy:", acc)

        loss, acc = evaluate(model, device, test_loader)
        test_history['loss'].append(loss), test_history['acc'].append(acc)
        print("Val Loss:", loss, "Val Accuracy:", acc)

        

    return train_history, test_history

In [18]:
model = LanguageModel(vocab_size=vocab_size, hidden_dim=64, embedding_dim=64, dropout_rate=0.2)
model.to(device)

LanguageModel(
  (embedding): Embedding(69, 64)
  (rnn): GRUCell(64, 64)
  (dropout): Dropout(p=0.2, inplace=False)
  (lm_head): Linear(in_features=64, out_features=69, bias=False)
)

In [19]:
evaluate(model, device, test_loader)

(4.2285711471434215, 0.025032441572400414)

In [20]:
optimizer = torch.optim.AdamW(lr=5e-4, params=model.parameters())
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2000, gamma=0.9)

In [21]:
train(model, device, train_loader, test_loader, n_epoch=10, optimizer=optimizer, max_norm=2, scheduler=scheduler)

100%|██████████| 5537/5537 [00:58<00:00, 94.83it/s]


Train Loss: 2.476620366579765 Train Accuracy: 0.3031628866219625
Val Loss: 2.6422508851669715 Val Accuracy: 0.22948807854137449


100%|██████████| 5537/5537 [00:59<00:00, 93.36it/s]


Train Loss: 2.470106620108295 Train Accuracy: 0.30306384614936466
Val Loss: 2.6279302891232255 Val Accuracy: 0.2292799937083011


100%|██████████| 5537/5537 [01:00<00:00, 91.36it/s]


Train Loss: 2.467459218205784 Train Accuracy: 0.3031681663530385
Val Loss: 2.627080635522518 Val Accuracy: 0.23031058709415264


100%|██████████| 5537/5537 [00:59<00:00, 92.90it/s]


Train Loss: 2.465727644427769 Train Accuracy: 0.30325792178133026
Val Loss: 2.6325273085902428 Val Accuracy: 0.2296945249111953


100%|██████████| 5537/5537 [01:00<00:00, 92.26it/s]


Train Loss: 2.4647673067977425 Train Accuracy: 0.30316270456227024
Val Loss: 2.622917025187174 Val Accuracy: 0.23010577918758437


100%|██████████| 5537/5537 [01:01<00:00, 89.69it/s]


Train Loss: 2.4643696782067344 Train Accuracy: 0.30284428216048054
Val Loss: 2.625208511249535 Val Accuracy: 0.22905716270595483


 94%|█████████▍| 5194/5537 [00:57<00:03, 90.78it/s]


KeyboardInterrupt: 