Text generation example with char-rnn

First to load data:

In [1]:
import requests

def download_text(url):
    # Send a HTTP request to the URL
    response = requests.get(url)
    
    # Check if the request was successful
    if response.status_code == 200:
        return response.text
    else:
        return None

# Example usage
url = 'https://homl.info/shakespeare'
text_data = download_text(url)
if text_data:
    print("Data downloaded successfully!")
else:
    print("Failed to download data.")


Data downloaded successfully!


Build Vocab and tokenize

In [2]:
vocab = set()
for char in text_data:
    vocab.add(char.lower())

vocab_mapping = {char:i+2 for i, char in enumerate(vocab)} # 0 for padding and 1 for unknown
vocab_mapping['<UNK>'] = 1
vocab_mapping['<PAD>'] = 0
print(vocab_mapping)

{',': 2, 'y': 3, '&': 4, ' ': 5, '.': 6, 'j': 7, 'd': 8, 'v': 9, 'o': 10, 'a': 11, 'u': 12, 'x': 13, 'e': 14, 'w': 15, ':': 16, '-': 17, 'b': 18, '\n': 19, 'm': 20, 'i': 21, 'z': 22, '$': 23, 't': 24, '3': 25, 'p': 26, '!': 27, 'k': 28, 'f': 29, "'": 30, '?': 31, 'n': 32, 'g': 33, 's': 34, 'q': 35, 'h': 36, 'r': 37, ';': 38, 'c': 39, 'l': 40, '<UNK>': 1, '<PAD>': 0}


Data preparation

In [3]:
from torch.utils.data import Dataset, DataLoader
import torch

class TextTransform:
    def __init__(self, vocabulary):
        self.vocab = vocabulary
        self.n_tokens = len(vocabulary)
        reverse_map = [None] * len(vocabulary)
        for k, v in vocabulary.items():
            reverse_map[v] = k
        self.reverse_map = reverse_map

    def __call__(self, text):
        # Numericalize tokens
        return [self.vocab.get(char, 1) for char in text]

    
class MyDataset(Dataset):
    def __init__(self, data: str, transform=None, window_len = 100) -> None:
        self.data = data
        self.transform = transform
        self.window_len = window_len
    
    def __len__(self):
        return len(self.data) - self.window_len -1 # need next few tokens as label
    
    def __getitem__(self, idx):
        text = self.data[idx:idx+self.window_len]
        label = self.data[idx+1:idx+self.window_len+1]
        if self.transform:
            return self.transform(text), self.transform(label)
        else:
            return text, label


window_len = 100      
def collate_self_supervision(batch):
    texts, labels = zip(batch)
    return torch.tensor(texts).squeeze(), torch.tensor(labels).squeeze()


transform = TextTransform(vocab_mapping)
dataset = MyDataset(text_data, transform, window_len)
data_loader = DataLoader(dataset, batch_size=2, collate_fn=collate_self_supervision)

example_data = data_loader._get_iterator()._next_data()
print(f'shapes of example data: {[t.shape for t in example_data]}')
print(example_data)


shapes of example data: [torch.Size([2, 100]), torch.Size([2, 100])]
(tensor([[ 1, 21, 37, 34, 24,  5,  1, 21, 24, 21, 22, 14, 32, 16, 19,  1, 14, 29,
         10, 37, 14,  5, 15, 14,  5, 26, 37, 10, 39, 14, 14,  8,  5, 11, 32,  3,
          5, 29, 12, 37, 24, 36, 14, 37,  2,  5, 36, 14, 11, 37,  5, 20, 14,  5,
         34, 26, 14, 11, 28,  6, 19, 19,  1, 40, 40, 16, 19,  1, 26, 14, 11, 28,
          2,  5, 34, 26, 14, 11, 28,  6, 19, 19,  1, 21, 37, 34, 24,  5,  1, 21,
         24, 21, 22, 14, 32, 16, 19,  1, 10, 12],
        [21, 37, 34, 24,  5,  1, 21, 24, 21, 22, 14, 32, 16, 19,  1, 14, 29, 10,
         37, 14,  5, 15, 14,  5, 26, 37, 10, 39, 14, 14,  8,  5, 11, 32,  3,  5,
         29, 12, 37, 24, 36, 14, 37,  2,  5, 36, 14, 11, 37,  5, 20, 14,  5, 34,
         26, 14, 11, 28,  6, 19, 19,  1, 40, 40, 16, 19,  1, 26, 14, 11, 28,  2,
          5, 34, 26, 14, 11, 28,  6, 19, 19,  1, 21, 37, 34, 24,  5,  1, 21, 24,
         21, 22, 14, 32, 16, 19,  1, 10, 12,  5]]), tensor([[21, 37, 3

Define model arch

In [7]:
import torch.nn as nn

class CharRNN(nn.Module):
    def __init__(self,n_tokens, emb_dim=16, GRU_dim=128):
        super(CharRNN, self).__init__() 
        self.embedding = nn.Embedding(n_tokens, emb_dim, padding_idx=0)
        self.gru = nn.GRU(emb_dim, GRU_dim, batch_first=True)
        self.output = nn.Sequential(
            nn.Linear(GRU_dim, n_tokens),
            nn.Softmax(dim=2),
        )
    
    def forward(self, input):
        emb = self.embedding(input) #[batch, window_len, emb_dim]
        seq, _h_n = self.gru(emb) #[batch, window_len, GRU_dim]
        return self.output(seq) # [batch, window_len, n_tokens]

# transform defined earlier
model = CharRNN(transform.n_tokens)

def generate(prompt:str) -> str:

    output_tensors = model(torch.tensor(transform(prompt)).unsqueeze(0))    
    output_tokens = torch.argmax(output_tensors, dim=2).squeeze()
    return ''.join([transform.reverse_map[idx] for idx in output_tokens])

prompt = 'To be or not to be'
print(
    f"""
    Before training:
    prompt: {prompt}
    next seq: {generate('To be or not to be')}
    """
)


    Before training:
    prompt: To be or not to be
    next seq: ft&ttdtttq&v&$&&tt
    


Now train it!

In [9]:
from torch.optim import SGD, Adam

# optimizer = SGD(model.parameters(), lr=0.001, weight_decay=0.01)
optimizer = Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
num_epoch = 2

data_loader = DataLoader(dataset, batch_size=2, collate_fn=collate_self_supervision, shuffle=True)
for k in range(1, num_epoch+1):
    print(f'--- epoch {k} ---')
    try:
        for i, data in enumerate(data_loader):
            optimizer.zero_grad()

            texts, label = data
            pred = model(texts) # batch, seq_len, n_tokens
            loss = criterion(pred.transpose(1, 2), label)
            loss.backward()
            optimizer.step()

            if i % 1000 == 0:
                print(f'{i} batches processed({i/len(data_loader)}%), ce loss {loss}')

    except Exception as e:
        print(f"An error occurred: {e}, batch idx {i}, epoch idx {k}")

    print(f"After epoch {k}, ce loss is {loss}")   


--- epoch 1 ---
0 batches processed(0.0%), ce loss 3.7131435871124268
1000 batches processed(0.0017932491343089804%), ce loss 3.584627628326416
2000 batches processed(0.003586498268617961%), ce loss 3.6146273612976074
3000 batches processed(0.005379747402926942%), ce loss 3.5646276473999023
4000 batches processed(0.007172996537235922%), ce loss 3.5946285724639893
5000 batches processed(0.008966245671544902%), ce loss 3.6296277046203613
6000 batches processed(0.010759494805853883%), ce loss 3.6146275997161865
7000 batches processed(0.012552743940162862%), ce loss 3.6146275997161865
8000 batches processed(0.014345993074471843%), ce loss 3.594627618789673
9000 batches processed(0.016139242208780824%), ce loss 3.639627695083618
10000 batches processed(0.017932491343089805%), ce loss 3.6346275806427
11000 batches processed(0.019725740477398786%), ce loss 3.6246275901794434
12000 batches processed(0.021518989611707767%), ce loss 3.6046276092529297
13000 batches processed(0.023312238746016744

In [14]:
prompt = 'To be or not to be'
print(
    f"""
    Before training:
    prompt: {prompt}
    next seq: {generate('To be or not to be')}
    """
)


    Before training:
    prompt: To be or not to be
    next seq:                   
    


TODO: 
- training model
- test