Text generation example with char-rnn

First to load data:

In [4]:
import requests

def download_text(url):
    # Send a HTTP request to the URL
    response = requests.get(url)
    
    # Check if the request was successful
    if response.status_code == 200:
        return response.text
    else:
        return None

# Example usage
url = 'https://homl.info/shakespeare'
text_data = download_text(url)
if text_data:
    print("Data downloaded successfully!")
else:
    print("Failed to download data.")


Data downloaded successfully!


Build Vocab and tokenize

In [68]:
vocab = set()
for char in text_data:
    if char != '\n':
        vocab.add(char.lower())

vocab_mapping = {char:i+2 for i, char in enumerate(vocab)} # 0 for padding and 1 for unknown
vocab_mapping['<UNK>'] = 1
vocab_mapping['<PAD>'] = 0
print(vocab_mapping)

{'o': 2, 'm': 3, 'v': 4, 'u': 5, 'h': 6, "'": 7, 't': 8, 'w': 9, 'b': 10, 'p': 11, 'c': 12, '&': 13, '-': 14, ' ': 15, '!': 16, 'j': 17, 'r': 18, 'q': 19, '$': 20, ';': 21, ':': 22, 'e': 23, '?': 24, 's': 25, 'n': 26, 'z': 27, 'i': 28, 'l': 29, 'g': 30, ',': 31, 'y': 32, '3': 33, 'x': 34, 'a': 35, 'd': 36, 'k': 37, '.': 38, 'f': 39, '<UNK>': 1, '<PAD>': 0}


Data preparation

In [69]:
from torch.utils.data import Dataset, DataLoader
import torch

class TextTransform:
    def __init__(self, vocabulary):
        self.vocab = vocabulary
        self.n_tokens = len(vocabulary)
        reverse_map = [None] * len(vocabulary)
        for k, v in vocabulary.items():
            reverse_map[v] = k
        self.reverse_map = reverse_map

    def __call__(self, text):
        # Numericalize tokens
        return [self.vocab.get(char.lower(), 1) for char in text]

    
class MyDataset(Dataset):
    def __init__(self, data: str, transform=None, window_len = 100) -> None:
        self.data = data.replace('\n', '').lower()
        self.transform = transform
        self.window_len = window_len
    
    def __len__(self):
        return len(self.data) - self.window_len -1 # need next few tokens as label
    
    def __getitem__(self, idx):
        text = self.data[idx:idx+self.window_len]
        label = self.data[idx+1:idx+self.window_len+1]
        if self.transform:
            return self.transform(text), self.transform(label)
        else:
            return text, label


window_len = 100      
def collate_self_supervision(batch):
    texts, labels = zip(*batch)  # Unzip the tuples into separate lists
    texts_tensor = torch.tensor(texts, dtype=torch.long)  # Ensure dtype is long for indices
    labels_tensor = torch.tensor(labels, dtype=torch.long)  # Same for labels
    return texts_tensor, labels_tensor



transform = TextTransform(vocab_mapping)
dataset = MyDataset(text_data, transform, window_len)
data_loader = DataLoader(dataset, batch_size=2, collate_fn=collate_self_supervision)

example_data = data_loader._get_iterator()._next_data()
print(f'shapes of example data: {[t.shape for t in example_data]}')
print(example_data)


shapes of example data: [torch.Size([2, 100]), torch.Size([2, 100])]
(tensor([[39, 28, 18, 25,  8, 15, 12, 28,  8, 28, 27, 23, 26, 22, 10, 23, 39,  2,
         18, 23, 15,  9, 23, 15, 11, 18,  2, 12, 23, 23, 36, 15, 35, 26, 32, 15,
         39,  5, 18,  8,  6, 23, 18, 31, 15,  6, 23, 35, 18, 15,  3, 23, 15, 25,
         11, 23, 35, 37, 38, 35, 29, 29, 22, 25, 11, 23, 35, 37, 31, 15, 25, 11,
         23, 35, 37, 38, 39, 28, 18, 25,  8, 15, 12, 28,  8, 28, 27, 23, 26, 22,
         32,  2,  5, 15, 35, 18, 23, 15, 35, 29],
        [28, 18, 25,  8, 15, 12, 28,  8, 28, 27, 23, 26, 22, 10, 23, 39,  2, 18,
         23, 15,  9, 23, 15, 11, 18,  2, 12, 23, 23, 36, 15, 35, 26, 32, 15, 39,
          5, 18,  8,  6, 23, 18, 31, 15,  6, 23, 35, 18, 15,  3, 23, 15, 25, 11,
         23, 35, 37, 38, 35, 29, 29, 22, 25, 11, 23, 35, 37, 31, 15, 25, 11, 23,
         35, 37, 38, 39, 28, 18, 25,  8, 15, 12, 28,  8, 28, 27, 23, 26, 22, 32,
          2,  5, 15, 35, 18, 23, 15, 35, 29, 29]]), tensor([[28, 18, 2

Define model arch

In [108]:
import torch.nn as nn

# Define your device
device = 'cuda' if torch.cuda.is_available() else 'cpu'


class CharRNN(nn.Module):
    def __init__(self, n_tokens, emb_dim=16, GRU_dim=128):
        super(CharRNN, self).__init__() 
        self.embedding = nn.Embedding(n_tokens, emb_dim, padding_idx=0)
        self.gru = nn.GRU(emb_dim, GRU_dim, batch_first=True)
        self.output = nn.Sequential(
            nn.Linear(GRU_dim, 2*n_tokens),
            nn.ReLU(),
            nn.Linear(2*n_tokens, n_tokens),
        )
        self.GRU_dim = GRU_dim
    
    def forward(self, input):
        emb = self.embedding(input)  # [batch, window_len, emb_dim]
        seq, _h_n = self.gru(emb)    # [batch, window_len, GRU_dim]
        return self.output(seq)       # [batch, window_len, n_tokens]
    
    def generate(self, input_tensor, length=5, temperature=0.8):
        N = len(input_tensor)
        h = torch.zeros(size=(1, N, self.GRU_dim)).to(device)

        output = []

        for _ in range(length):
            if len(output) == 0:
                input = input_tensor
            else:
                input = output[-1]
            emb = self.embedding(input)  # [batch, window_len, emb_dim]
            seq, h = self.gru(emb, h)    # [batch, window_len, GRU_dim]
            logits_last = self.output(seq)[:,-1,:]      # [batch, 1, n_tokens]
            prob = torch.softmax(logits_last.squeeze(1),dim=1) # [batch, n_tokens]
            next_token = torch.multinomial(prob, num_samples=1) #[batch, 1]
            output.append(next_token) # [batch, 1]

        return output #[batch, length]




def generate(prompt: str, model, length=5, temperature=0.8) -> str:
    input_tensor = torch.tensor(transform(prompt)).unsqueeze(0).to(device)
    output_tokens = model.generate(input_tensor, length, temperature)  # Ensure this is on the right device
    # print(output_tokens)
    return ''.join([transform.reverse_map[idx] for idx in output_tokens])

In [109]:

# Move model to device
model = CharRNN(transform.n_tokens).to(device)
prompt = 'To be or not to be'
print(
    f"""
    Before training:
    prompt: {prompt}
    next seq: {generate(prompt, model, temperature=0.3)}
    """
)


    Before training:
    prompt: To be or not to be
    next seq: b chc
    


Now train it!

In [110]:
from torch.optim import SGD, Adam

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# optimizer = SGD(model.parameters(), lr=0.001, weight_decay=0.01)
optimizer = Adam(model.parameters(), lr=0.001)
# optimizer = SGD(model.parameters(), lr=0.01, weight_decay=0.01)
criterion = nn.NLLLoss().to(device)
num_epoch = 1

data_loader = DataLoader(dataset, batch_size=2, collate_fn=collate_self_supervision, shuffle=True)
for k in range(1, num_epoch+1):
    print(f'--- epoch {k} ---')
    try:
        for i, data in enumerate(data_loader):
            optimizer.zero_grad()

            texts, label = data     # label shape: [batch, seq_len]
            pred = model(texts.to(device)) # batch, seq_len, n_tokens
            loss = criterion(pred.transpose(1, 2), label.to(device))
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
            optimizer.step()

            if i % 1000 == 0:
                print(f'{i} batches processed({i/len(data_loader)}%), ce loss {loss}')

    except Exception as e:
        print(f"An error occurred: {e}, batch idx {i}, epoch idx {k}")

    print(f"After epoch {k}, ce loss is {loss}")   


--- epoch 1 ---
0 batches processed(0.0%), ce loss -0.002575966063886881
1000 batches processed(0.0018599564398201795%), ce loss -6388.337890625
2000 batches processed(0.003719912879640359%), ce loss -25402.939453125
3000 batches processed(0.005579869319460538%), ce loss -57269.44921875
4000 batches processed(0.007439825759280718%), ce loss -100038.9375
5000 batches processed(0.009299782199100897%), ce loss -157385.125
6000 batches processed(0.011159738638921076%), ce loss -226237.875
7000 batches processed(0.013019695078741256%), ce loss -308572.40625
8000 batches processed(0.014879651518561436%), ce loss -403917.1875
9000 batches processed(0.016739607958381614%), ce loss -506824.40625
10000 batches processed(0.018599564398201793%), ce loss -636673.75
11000 batches processed(0.020459520838021973%), ce loss -757458.9375
12000 batches processed(0.022319477277842153%), ce loss -909946.0625
13000 batches processed(0.024179433717662333%), ce loss -1075204.625
14000 batches processed(0.0260

In [95]:
prompt = 'Barack Obama was born in Honolulu, Hawaii. He was born in'
prompt = 'to be or not to be'
print(
    f"""
    After training:
    prompt: {prompt}
    next seq: {generate(prompt, model, length=200)}
    """
)

TypeError: CharRNN.generate() takes from 2 to 3 positional arguments but 4 were given

TODO:
- sampling and temperature