Text generation example with char-rnn

First to load data:

In [2]:
import requests

def download_text(url):
    # Send a HTTP request to the URL
    response = requests.get(url)
    
    # Check if the request was successful
    if response.status_code == 200:
        return response.text
    else:
        return None

# Example usage
url = 'https://homl.info/shakespeare'
text_data = download_text(url)
if text_data:
    print("Data downloaded successfully!")
else:
    print("Failed to download data.")


Data downloaded successfully!


Build Vocab and tokenize

In [3]:
vocab = set()
for char in text_data:
    vocab.add(char.lower())

vocab_mapping = {char:i+2 for i, char in enumerate(vocab)} # 0 for padding and 1 for unknown
vocab_mapping['<UNK>'] = 1
vocab_mapping['<PAD>'] = 0
print(vocab_mapping)

{'3': 2, 's': 3, '-': 4, 'o': 5, 'q': 6, 'f': 7, 'z': 8, 'd': 9, 'k': 10, '&': 11, 'l': 12, 't': 13, 'c': 14, 'n': 15, ':': 16, 'm': 17, '.': 18, ' ': 19, 'h': 20, 'b': 21, 'j': 22, 'r': 23, 'u': 24, 'v': 25, ';': 26, 'p': 27, '\n': 28, 'g': 29, 'a': 30, 'e': 31, '?': 32, 'w': 33, ',': 34, 'x': 35, '!': 36, '$': 37, "'": 38, 'i': 39, 'y': 40, '<UNK>': 1, '<PAD>': 0}


Data preparation

In [10]:
from torch.utils.data import Dataset, DataLoader
import torch

class TextTransform:
    def __init__(self, vocabulary):
        self.vocab = vocabulary

    def __call__(self, text):
        # Numericalize tokens
        return [self.vocab.get(char, 1) for char in text]
    
class MyDataset(Dataset):
    def __init__(self, data: str, transform=None, window_len = 100) -> None:
        self.data = data
        self.transform = transform
        self.window_len = window_len
    
    def __len__(self):
        return len(self.data) - 2 * self.window_len # need next few tokens as label
    
    def __getitem__(self, idx):
        text = self.data[idx:idx+self.window_len]
        label = self.data[idx+self.window_len:idx+2*self.window_len]
        if self.transform:
            return self.transform(text), self.transform(label)
        else:
            return text, label


window_len = 100      
def collate_self_supervision(batch):
    texts, labels = zip(batch)
    return torch.tensor(texts), torch.tensor(labels)



transform = TextTransform(vocab_mapping)
dataset = MyDataset(text_data, transform, window_len)
data_loader = DataLoader(dataset, batch_size=2, collate_fn=collate_self_supervision)

print(data_loader._get_iterator()._next_data())


(tensor([[[ 1, 39, 23,  3, 13, 19,  1, 39, 13, 39,  8, 31, 15, 16, 28,  1, 31,
           7,  5, 23, 31, 19, 33, 31, 19, 27, 23,  5, 14, 31, 31,  9, 19, 30,
          15, 40, 19,  7, 24, 23, 13, 20, 31, 23, 34, 19, 20, 31, 30, 23, 19,
          17, 31, 19,  3, 27, 31, 30, 10, 18, 28, 28,  1, 12, 12, 16, 28,  1,
          27, 31, 30, 10, 34, 19,  3, 27, 31, 30, 10, 18, 28, 28,  1, 39, 23,
           3, 13, 19,  1, 39, 13, 39,  8, 31, 15, 16, 28,  1,  5, 24],
         [19, 30, 23, 31, 19, 30, 12, 12, 19, 23, 31,  3,  5, 12, 25, 31,  9,
          19, 23, 30, 13, 20, 31, 23, 19, 13,  5, 19,  9, 39, 31, 19, 13, 20,
          30, 15, 19, 13,  5, 19,  7, 30, 17, 39,  3, 20, 32, 28, 28,  1, 12,
          12, 16, 28,  1, 31,  3,  5, 12, 25, 31,  9, 18, 19, 23, 31,  3,  5,
          12, 25, 31,  9, 18, 28, 28,  1, 39, 23,  3, 13, 19,  1, 39, 13, 39,
           8, 31, 15, 16, 28,  1, 39, 23,  3, 13, 34, 19, 40,  5, 24]]]), tensor([[[39, 23,  3, 13, 19,  1, 39, 13, 39,  8, 31, 15, 16, 28,  1, 31, 

Define model arch

In [11]:
import torch.nn as nn

class CharRNN(nn.Module):
    def __init__(self,n_tokens, window_len, emb_dim=16, GRU_dim=128):
        super(CharRNN, self).__init__() 
        self.embedding = nn.Embedding(n_tokens, emb_dim, padding_idx=0)
        self.gru = nn.GRU(emb_dim, GRU_dim, batch_first=True)
        self.output = nn.Sequential(
            nn.Linear(GRU_dim, n_tokens),
            nn.Softmax(dim=2),
        )
    
    def forward(self, input):
        emb = self.embedding(input) #[batch, window_len, emb_dim]
        seq, _h_n = self.gru(emb) #[batch, window_len, GRU_dim]
        return self.output(seq) # [batch, window_len, n_tokens]


TODO: 
- prediction utils
- training model
- test