Text generation example with char-rnn

First to load data:

In [3]:
import requests

def download_text(url):
    # Send a HTTP request to the URL
    response = requests.get(url)
    
    # Check if the request was successful
    if response.status_code == 200:
        return response.text
    else:
        return None

# Example usage
url = 'https://homl.info/shakespeare'
text_data = download_text(url)
if text_data:
    print("Data downloaded successfully!")
else:
    print("Failed to download data.")


Data downloaded successfully!


Build Vocab and tokenize

In [4]:
vocab = set()
for char in text_data:
    vocab.add(char.lower())

vocab_mapping = {char:i+2 for i, char in enumerate(vocab)} # 0 for padding and 1 for unknown
vocab_mapping['<UNK>'] = 1
vocab_mapping['<PAD>'] = 0
print(vocab_mapping)

{'\n': 2, 'i': 3, '3': 4, 'r': 5, 'd': 6, 'a': 7, 'm': 8, ':': 9, 'l': 10, 'k': 11, 'g': 12, 'q': 13, '.': 14, 'b': 15, 'v': 16, 'x': 17, ' ': 18, 'o': 19, 'h': 20, 'y': 21, 'w': 22, '-': 23, 'z': 24, 'e': 25, '!': 26, 'u': 27, ',': 28, '&': 29, 's': 30, 'j': 31, 'p': 32, 'n': 33, "'": 34, 'c': 35, 'f': 36, ';': 37, '$': 38, '?': 39, 't': 40, '<UNK>': 1, '<PAD>': 0}


Data preparation

In [32]:
from torch.utils.data import Dataset, DataLoader
import torch

class TextTransform:
    def __init__(self, vocabulary):
        self.vocab = vocabulary
        self.n_tokens = len(vocabulary)
        reverse_map = [None] * len(vocabulary)
        for k, v in vocabulary.items():
            reverse_map[v] = k
        self.reverse_map = reverse_map

    def __call__(self, text):
        # Numericalize tokens
        return [self.vocab.get(char, 1) for char in text]

    
class MyDataset(Dataset):
    def __init__(self, data: str, transform=None, window_len = 100) -> None:
        self.data = data
        self.transform = transform
        self.window_len = window_len
    
    def __len__(self):
        return len(self.data) - self.window_len -1 # need next few tokens as label
    
    def __getitem__(self, idx):
        text = self.data[idx:idx+self.window_len]
        label = self.data[idx+1:idx+self.window_len+1]
        if self.transform:
            return self.transform(text), self.transform(label)
        else:
            return text, label


window_len = 100      
def collate_self_supervision(batch):
    texts, labels = zip(*batch)  # Unzip the tuples into separate lists
    texts_tensor = torch.tensor(texts, dtype=torch.long)  # Ensure dtype is long for indices
    labels_tensor = torch.tensor(labels, dtype=torch.long)  # Same for labels
    return texts_tensor, labels_tensor



transform = TextTransform(vocab_mapping)
dataset = MyDataset(text_data, transform, window_len)
data_loader = DataLoader(dataset, batch_size=2, collate_fn=collate_self_supervision)

example_data = data_loader._get_iterator()._next_data()
print(f'shapes of example data: {[t.shape for t in example_data]}')
print(example_data)


shapes of example data: [torch.Size([2, 100]), torch.Size([2, 100])]
(tensor([[ 1,  3,  5, 30, 40, 18,  1,  3, 40,  3, 24, 25, 33,  9,  2,  1, 25, 36,
         19,  5, 25, 18, 22, 25, 18, 32,  5, 19, 35, 25, 25,  6, 18,  7, 33, 21,
         18, 36, 27,  5, 40, 20, 25,  5, 28, 18, 20, 25,  7,  5, 18,  8, 25, 18,
         30, 32, 25,  7, 11, 14,  2,  2,  1, 10, 10,  9,  2,  1, 32, 25,  7, 11,
         28, 18, 30, 32, 25,  7, 11, 14,  2,  2,  1,  3,  5, 30, 40, 18,  1,  3,
         40,  3, 24, 25, 33,  9,  2,  1, 19, 27],
        [ 3,  5, 30, 40, 18,  1,  3, 40,  3, 24, 25, 33,  9,  2,  1, 25, 36, 19,
          5, 25, 18, 22, 25, 18, 32,  5, 19, 35, 25, 25,  6, 18,  7, 33, 21, 18,
         36, 27,  5, 40, 20, 25,  5, 28, 18, 20, 25,  7,  5, 18,  8, 25, 18, 30,
         32, 25,  7, 11, 14,  2,  2,  1, 10, 10,  9,  2,  1, 32, 25,  7, 11, 28,
         18, 30, 32, 25,  7, 11, 14,  2,  2,  1,  3,  5, 30, 40, 18,  1,  3, 40,
          3, 24, 25, 33,  9,  2,  1, 19, 27, 18]]), tensor([[ 3,  5, 3

Define model arch

In [33]:
import torch.nn as nn

class CharRNN(nn.Module):
    def __init__(self, n_tokens, emb_dim=16, GRU_dim=128):
        super(CharRNN, self).__init__() 
        self.embedding = nn.Embedding(n_tokens, emb_dim, padding_idx=0)
        self.gru = nn.GRU(emb_dim, GRU_dim, batch_first=True)
        self.output = nn.Sequential(
            nn.Linear(GRU_dim, n_tokens),
            nn.Softmax(dim=2),
        )
    
    def forward(self, input):
        emb = self.embedding(input)  # [batch, window_len, emb_dim]
        seq, _h_n = self.gru(emb)    # [batch, window_len, GRU_dim]
        return self.output(seq)       # [batch, window_len, n_tokens]



def generate(prompt: str, device='cuda') -> str:
    input_tensor = torch.tensor(transform(prompt)).unsqueeze(0).to(device)
    output_tensors = model(input_tensor)  # Ensure this is on the right device
    output_tokens = torch.argmax(output_tensors, dim=2).squeeze()
    # print(output_tokens)
    return ''.join([transform.reverse_map[idx] for idx in output_tokens])

In [34]:
# Define your device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Move model to device
model = CharRNN(transform.n_tokens).to(device)
prompt = 'To be or not to be'
print(
    f"""
    Before training:
    prompt: {prompt}
    next seq: {generate(prompt, device)}
    """
)


    Before training:
    prompt: To be or not to be
    next seq: rrrrqcuwaruaaaarrq
    


Now train it!

In [35]:
from torch.optim import SGD, Adam

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# optimizer = SGD(model.parameters(), lr=0.001, weight_decay=0.01)
optimizer = Adam(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss().to(device)
num_epoch = 2

data_loader = DataLoader(dataset, batch_size=2, collate_fn=collate_self_supervision, shuffle=True)
for k in range(1, num_epoch+1):
    print(f'--- epoch {k} ---')
    try:
        for i, data in enumerate(data_loader):
            optimizer.zero_grad()

            texts, label = data
            pred = model(texts.to(device)) # batch, seq_len, n_tokens
            loss = criterion(pred.transpose(1, 2), label.to(device))
            loss.backward()
            optimizer.step()

            if i % 1000 == 0:
                print(f'{i} batches processed({i/len(data_loader)}%), ce loss {loss}')

    except Exception as e:
        print(f"An error occurred: {e}, batch idx {i}, epoch idx {k}")

    print(f"After epoch {k}, ce loss is {loss}")   


--- epoch 1 ---
0 batches processed(0.0%), ce loss 3.714027166366577
1000 batches processed(0.0017932491343089804%), ce loss 3.554555892944336
2000 batches processed(0.003586498268617961%), ce loss 3.559589147567749
3000 batches processed(0.005379747402926942%), ce loss 3.5545849800109863
4000 batches processed(0.007172996537235922%), ce loss 3.5093657970428467
5000 batches processed(0.008966245671544902%), ce loss 3.509620428085327
6000 batches processed(0.010759494805853883%), ce loss 3.489236354827881
7000 batches processed(0.012552743940162862%), ce loss 3.4746060371398926
8000 batches processed(0.014345993074471843%), ce loss 3.484626531600952
9000 batches processed(0.016139242208780824%), ce loss 3.519594669342041
10000 batches processed(0.017932491343089805%), ce loss 3.509589433670044
11000 batches processed(0.019725740477398786%), ce loss 3.479557514190674
12000 batches processed(0.021518989611707767%), ce loss 3.579606294631958
13000 batches processed(0.023312238746016744%), 

In [37]:
prompt = 'To be or not to be'
print(
    f"""
    Before training:
    prompt: {prompt}
    next seq: {generate('To be or not to be')}
    """
)


    Before training:
    prompt: To be or not to be
    next seq:   ao aoaao  ao ao 
    


TODO: 
- training model
- test