In [94]:
from pathlib import Path
import urllib.request
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

In [95]:
def download_shakespeare_text():
    path = Path("datasets/shakespeare/shakespeare.txt")
    if not path.is_file():
        path.parent.mkdir(parents=True, exist_ok=True)
    url = "https://homl.info/shakespeare"
    urllib.request.urlretrieve(url, path)
    return path.read_text()

In [96]:
shakespeare_text = download_shakespeare_text()
print(shakespeare_text[:80])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.


In [97]:
vocab = sorted(set(shakespeare_text))
"".join(vocab)

char_to_id = {char: index for index, char in enumerate(vocab)}
id_to_char = {index:char for char, index  in char_to_id.items()}



In [98]:
def encode_text(text):
    return torch.tensor([char_to_id[char] for char in text.lower()])

def decode_text(char_ids):
    return "".join([id_to_char[char_id.item()] for char_id in char_ids])
encoded = encode_text("Hello, world!")

decode_text(encoded)

'hello, world!'

In [99]:
class CharDataset(Dataset):
    def __init__(self, text: str, window_length: int):
        self.encoded_text = encode_text(text)
        self.window_length = window_length
    def __len__(self):
        return len(self.encoded_text) - self.window_length
    def __getitem__(self, index):
        if index >= len(self):
            raise IndexError("dataset index out of range")
        end = index + self.window_length
        window = self.encoded_text[index: end]
        target = self.encoded_text[index+1: end+1]
        return window, target


In [100]:
len(shakespeare_text)

1115394

In [101]:
window_length = 50
batch_size = 512
vocab_size = len(vocab) #65
device = torch.device("cuda")

train_set = CharDataset(shakespeare_text[:1_000_000], window_length)
valid_set = CharDataset(shakespeare_text[1_000_000: 1_060_000], window_length)
test_set = CharDataset(shakespeare_text[1_060_000:], window_length)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size)
test_loader = DataLoader(test_set, batch_size=batch_size)


x,y = next(iter(train_loader))

In [102]:
class ShakespeareModel(nn.Module):
    def __init__(self, vocab_size: int, n_layers: int = 2, embed_dim: int = 10, hidden_dim: int = 128, dropout:float = 0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(
                        input_size=embed_dim,
                        hidden_size=hidden_dim,
                        num_layers=n_layers,
                        batch_first=True,
                        dropout=dropout
                              )
        self.output = nn.Linear(hidden_dim, vocab_size)

    def forward(self, X):
        """
        X shape: (B, seq_len)
        embeddings shapes: (B, seq_len, D)
        rnn_out: (B, seq_len, H)
        out: (B, seq_len, vocab_size)
        out: (B, vocab_size, seq_len) =(B, C, d_1)

        """
        embeddings = self.embed(X)
        rnn_out, hidden_state = self.gru(embeddings)
        out = self.output(rnn_out)
        out = out.permute(0, 2, 1)
        return out

model = ShakespeareModel(vocab_size=vocab_size)
model(x).shape

torch.Size([512, 65, 50])

#### Inference example

In [103]:
model.eval()
text = "inference exampleswr"
encoded_text = encode_text(text).unsqueeze(dim=0) #(18) => (1, 18)
with torch.no_grad():
    y_logits = model(encoded_text) # (B, C, d_1)
    """
    第0个batch, 所有vocab_size, 最后一个seq_len的索引
    """
    prediced_char_id = y_logits[0,:, -1].argmax().item()
    print(prediced_char_id)
    predicted_char = id_to_char[prediced_char_id]
    print(predicted_char)




57
s


In [104]:
probs = torch.tensor([[0.5, 0.4, 0.1]])

samples = torch.multinomial(probs,replacement=True,num_samples=8)
samples

tensor([[0, 1, 1, 0, 0, 1, 1, 0]])

### Decoding
Temperature: <1 模型更贪心， 高概率的更高 低概率的更低

In [None]:
def next_char(model: nn.Module, text: str, temperature:float = 1):
    "predict next character based on model"
    encoded_text = encode_text(text).unsqueeze(dim=0) #(seq_len) => (1,seq_len)
    with torch.no_grad():
        y_logits = model(encoded_text)
        """y_logits (B, vocab_size, seq_len)
            y_logits[0,:, -1] (vocab_size)
        """

        y_prob = F.softmax(y_logits[0, :, -1] / temperature, dim=-1)
        predicted_char_id = torch.multinomial(y_prob, num_samples=1, replacement=False).item()
        return id_to_char[predicted_char_id]

def extend_text(model, text, n_chars=80, temperature=1):
    for _ in range(n_chars):
        text += next_char(model, text, temperature)
    return text

In [109]:
extend_text(model, "to be nor no to ba")

"to be nor no to ba3\n$sWgvp:'r\nPUJa:x.ZH3j.?vrlcbXsDcddC&'$ZcjucE3S AEcHOpQXiszKb'I-UrGBlw'gImmwL-j"