Character-level text generation with an RNN

In [8]:
import torch
from torch import nn
import torch.optim as optim

import time
from tqdm import tqdm

from text_helpers import load_corpus, Vocab, batch_generator, encode_text, decode_text, generate_seed, sample_from_probs

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

Using device: cpu


In [10]:
class CharRNN(nn.Module):
    """Character-level language model: embedding -> stacked Elman RNN -> linear.

    The model is sequence-first. ``forward`` takes integer token ids shaped
    ``(seq_len, batch_size)`` and returns logits shaped
    ``(seq_len, batch_size, vocab_size)`` together with the final hidden state
    shaped ``(num_layers, batch_size, hidden_size)``.

    Args:
        vocab_size: number of distinct tokens (rows of the embedding table and
            width of the output projection).
        embedding_size: dimensionality of each token embedding.
        hidden_size: number of features in the RNN hidden state.
        num_layers: number of stacked RNN layers.
    """

    def __init__(self, vocab_size, embedding_size, hidden_size, num_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = nn.Embedding(vocab_size, embedding_size)
        # batch_first=False: the RNN consumes (seq_len, batch_size, input_size).
        self.rnn = nn.RNN(embedding_size, hidden_size, num_layers, batch_first=False)
        # Projects every hidden vector to one logit per vocabulary entry.
        self.linear = nn.Linear(hidden_size, vocab_size)

        self.init_weights()

    def init_weights(self):
        """Initialise parameters in place.

        Linear layers get Xavier-uniform weights and biases of 0.01. All RNN
        weight matrices (input-to-hidden and hidden-to-hidden) are initialised
        orthogonally, and RNN biases are set to zero. The embedding keeps
        PyTorch's default initialisation.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.01)
            elif isinstance(module, nn.RNN):
                for name, param in module.named_parameters():
                    if 'weight' in name:
                        # Non-square matrices (e.g. weight_ih_l0) get a semi-orthogonal init.
                        nn.init.orthogonal_(param)
                    elif 'bias' in name:
                        nn.init.zeros_(param)

    def forward(self, input, hidden=None):
        """Run the model over a batch of token-id sequences.

        Args:
            input: integer tensor of shape ``(seq_len, batch_size)``.
            hidden: optional initial state ``(num_layers, batch_size, hidden_size)``.
                When ``None``, a zero state is used.

        Returns:
            ``(logits, hidden)``: logits of shape ``(seq_len, batch_size, vocab_size)``
            and the final hidden state of shape ``(num_layers, batch_size, hidden_size)``.
        """
        embedded = self.embedding(input)  # (seq_len, batch_size, embedding_size)

        if hidden is None:
            # Allocate directly with the embedding's dtype and device. This avoids a
            # CPU->GPU copy and keeps the model usable after .double() / .half().
            hidden = embedded.new_zeros(self.num_layers, input.size(1), self.hidden_size)

        output, hidden = self.rnn(embedded, hidden)  # output: (seq_len, batch_size, hidden_size)
        logits = self.linear(output)  # (seq_len, batch_size, vocab_size)
        return logits, hidden

def build_rnn(vocab_size, embedding_size, hidden_size, num_layers):
    """Factory for a freshly initialised ``CharRNN``.

    Args:
        vocab_size: number of distinct tokens.
        embedding_size: size of each token embedding.
        hidden_size: RNN hidden-state size.
        num_layers: number of stacked RNN layers.

    Returns:
        A new ``CharRNN`` instance; callers move it to a device themselves.
    """
    network = CharRNN(
        vocab_size=vocab_size,
        embedding_size=embedding_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
    )
    return network

def train(net, train_loader, device, num_epochs, learning_rate, num_batches,
          clip_norm=5.0, batch_first=False):
    """Train a sequence-first language model with Adam and truncated BPTT.

    The hidden state is carried from one batch to the next within an epoch. It is
    detached so gradients stop at batch boundaries, and it is reset to ``None`` at
    the start of every epoch.

    Args:
        net: model whose ``forward(inputs, hidden)`` returns ``(logits, hidden)``,
            with logits shaped ``(seq_len, batch_size, vocab_size)``.
        train_loader: iterator yielding tuples whose first two items are 2-D integer
            array-likes ``(inputs, labels)``. It is advanced ``num_batches`` times per
            epoch, so it must not be exhausted before training ends.
        device: device the batches are moved to.
        num_epochs: number of epochs.
        learning_rate: Adam learning rate.
        num_batches: number of batches drawn per epoch.
        clip_norm: maximum gradient L2 norm passed to ``clip_grad_norm_``.
        batch_first: set ``True`` when ``train_loader`` yields arrays shaped
            ``(batch_size, seq_len)``. They are then transposed to the model's
            ``(seq_len, batch_size)`` layout. With ``False`` (the default), the
            arrays are used as-is.

    Returns:
        A list with one float per epoch: the sum of that epoch's per-batch mean losses.

    NOTE(review): this notebook's loader log reports ``x shape: (64, 17408)``, i.e.
    BATCH_SIZE rows of contiguous text, which suggests its batches may be
    ``(batch_size, seq_len)``. Because BATCH_SIZE == SEQ_LENGTH here, a layout
    mix-up would not raise. Confirm the layout and pass ``batch_first=True`` if so.
    """
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    loss_function = torch.nn.CrossEntropyLoss()
    loss_history = []

    net.train()

    # The context manager closes the bar; no explicit close() is needed.
    with tqdm(total=num_batches * num_epochs, position=0, leave=True) as pbar:
        for epoch in range(num_epochs):
            running_loss = 0.0
            hidden = None  # fresh hidden state at the start of every epoch

            for _ in range(num_batches):
                inputs, labels, *_ = next(train_loader)
                # CrossEntropyLoss targets (and embedding indices) must be int64.
                inputs = torch.as_tensor(inputs, dtype=torch.long, device=device)
                labels = torch.as_tensor(labels, dtype=torch.long, device=device)
                if batch_first:
                    inputs = inputs.t()
                    labels = labels.t()

                optimizer.zero_grad()

                output, hidden = net(inputs, hidden)
                # Keep the state's values but cut the autograd graph at this batch boundary.
                hidden = hidden.detach()

                # Flatten (seq_len, batch, vocab) -> (seq_len * batch, vocab). The vocab
                # size is taken from the logits rather than a notebook-level global.
                loss = loss_function(output.reshape(-1, output.size(-1)), labels.reshape(-1))
                loss.backward()
                # Clip gradients to guard against exploding gradients in the RNN.
                torch.nn.utils.clip_grad_norm_(net.parameters(), clip_norm)
                optimizer.step()
                running_loss += loss.item()
                pbar.update(1)

            pbar.set_description("Epoch: %d, Loss: %.2f" % (epoch + 1, running_loss))
            loss_history.append(running_loss)

    return loss_history

# Read the corpus and build a character-level vocabulary.
# NOTE(review): '/content/...' looks like a Colab path; adjust DATA_PATH when running elsewhere.
DATA_PATH = '/content/tinyshakespeare.txt'
SEED = 0
torch.manual_seed(SEED)  # reproducible weight initialisation

corpus, vocab = load_corpus(DATA_PATH, token_type='char')
VOCAB_SIZE = len(vocab)

print(vocab._token_to_idx)  # token -> index mapping

EMBEDDING_SIZE = 32
EMMBEDDING_SIZE = EMBEDDING_SIZE  # alias kept for the original (misspelt) name
HIDDEN_SIZE = 128
NUM_LAYERS = 2

model = build_rnn(VOCAB_SIZE, EMBEDDING_SIZE, HIDDEN_SIZE, NUM_LAYERS).to(device)
print(model)

BATCH_SIZE = 64
SEQ_LENGTH = 64
EPOCHS = 256
LR = 0.001

# Number of full BATCH_SIZE x SEQ_LENGTH batches per epoch. The -1 presumably
# leaves room for labels shifted one character ahead — confirm against batch_generator.
num_batches = (len(corpus) - 1) // (BATCH_SIZE * SEQ_LENGTH)
data_iter = batch_generator(encode_text(corpus, char2id=vocab), batch_size=BATCH_SIZE, seq_len=SEQ_LENGTH, vocab=vocab)

# NOTE(review): BATCH_SIZE == SEQ_LENGTH, so a (batch, seq) vs (seq, batch) mismatch
# between batch_generator and the sequence-first model would not raise — verify the layout.
loss_history = train(model, data_iter, device, EPOCHS, LR, num_batches)

# Sanity check: run a seed string from generate_seed through the model and
# sample a single next character from the last step's distribution.
seed = generate_seed(corpus)
encoded = encode_text(seed, char2id=vocab)  # 1-D numpy array of token ids

print(encoded)

tensor = torch.tensor(encoded, dtype=torch.long).unsqueeze(1).to(device)  # (seq_len, 1): a batch of one
print(tensor)

# Inference only: no_grad avoids building an autograd graph for this forward pass.
with torch.no_grad():
    output, hidden = model(tensor, hidden=None)
print(output.shape)

# output[-1, 0, :] selects the logits for the last character of the sequence, for the first (and only) batch item.
# This gives a 1D tensor of shape [vocab_size].
probs = torch.softmax(output[-1, 0, :], dim=0)

sampled_index = sample_from_probs(probs.cpu().numpy(), top_n=4)
print(f"Sampled index: {sampled_index}")

def generate_text(model, seed_chars, length=256, top_n=4, vocab=None, device=None):
    """Sample text from a character-level model, continuing ``seed_chars``.

    The seed is fed through the model to prime the hidden state. Characters are
    then sampled one at a time with ``sample_from_probs(..., top_n=top_n)`` and fed
    back in, until the result holds ``length`` characters (seed included). If the
    seed encodes to nothing, a random vocabulary entry is used as the first character.

    The model is put in eval mode for sampling, and its previous train/eval mode is
    restored afterwards.

    Args:
        model: sequence-first model returning ``(logits, hidden)``, with logits shaped
            ``(seq_len, batch_size, vocab_size)``.
        seed_chars: prompt string; it is the prefix of the returned text.
        length: total number of characters in the result, seed included. A seed that
            is already at least this long is returned unchanged.
        top_n: forwarded to ``sample_from_probs``.
        vocab: vocabulary passed to ``encode_text`` and indexed via ``_idx_to_token``.
            Required despite the ``None`` default.
        device: device for input tensors. Defaults to the device of the model's
            first parameter.

    Returns:
        The generated string.
    """
    if device is None:
        device = next(model.parameters()).device

    was_training = model.training
    model.eval()
    chars = list(seed_chars)

    try:
        with torch.no_grad():
            encoded_seed = encode_text(seed_chars, char2id=vocab)  # numpy array
            prompt = torch.tensor(encoded_seed, dtype=torch.long).unsqueeze(1).to(device)  # (seq_len, 1)
            if prompt.numel() == 0:
                # Empty seed: start from a random vocabulary entry.
                start_idx = torch.randint(0, len(vocab._idx_to_token), (1,)).item()
                prompt = torch.tensor([[start_idx]], dtype=torch.long, device=device)
                chars.append(vocab._idx_to_token[start_idx])

            output, hidden = model(prompt, None)
            logits = output[-1, 0, :]  # logits for the last prompt position, shape (vocab_size,)

            while len(chars) < length:
                probs = torch.softmax(logits, dim=0)
                next_idx = sample_from_probs(probs.cpu().numpy(), top_n=top_n)
                chars.append(vocab._idx_to_token[next_idx])
                if len(chars) >= length:
                    break  # skip a forward pass whose output would never be sampled
                step = torch.tensor([[next_idx]], dtype=torch.long, device=device)  # (1, 1)
                output, hidden = model(step, hidden)
                logits = output[-1, 0, :]
    finally:
        # Restore the caller's mode instead of unconditionally switching to training.
        model.train(was_training)

    return ''.join(chars)

# Generate a 512-character sample (prompt included) from the trained model.
SAMPLE_PROMPT = "KING RICHARD: "
print("\nGenerating text after training:")
generated_text_output = generate_text(
    model,
    SAMPLE_PROMPT,
    length=512,
    top_n=4,
    vocab=vocab,
    device=device,
)
print(generated_text_output)


{'<unk>': 0, ' ': 1, 'e': 2, 't': 3, 'o': 4, 'a': 5, 'h': 6, 's': 7, 'r': 8, 'n': 9, 'i': 10, '\n': 11, 'l': 12, 'd': 13, 'u': 14, 'm': 15, 'y': 16, ',': 17, 'w': 18, 'f': 19, 'c': 20, 'g': 21, 'I': 22, 'b': 23, 'p': 24, ':': 25, '.': 26, 'A': 27, 'v': 28, 'k': 29, 'T': 30, "'": 31, 'E': 32, 'O': 33, 'N': 34, 'R': 35, 'S': 36, 'L': 37, 'C': 38, ';': 39, 'W': 40, 'U': 41, 'H': 42, 'M': 43, 'B': 44, '?': 45, 'G': 46, '!': 47, 'D': 48, '-': 49, 'F': 50, 'Y': 51, 'P': 52, 'K': 53, 'V': 54, 'j': 55, 'q': 56, 'x': 57, 'z': 58, 'J': 59, 'Q': 60, 'Z': 61, 'X': 62, '3': 63, '&': 64, '$': 65}
CharRNN(
  (embedding): Embedding(66, 32)
  (rnn): RNN(32, 128, num_layers=2)
  (linear): Linear(in_features=128, out_features=66, bias=True)
)


  0%|          | 2/69632 [00:00<1:23:07, 13.96it/s]

number of batches: 272.
effective text length: 1114112.
x shape:  (64, 17408)
y shape:  (64, 17408)


Epoch: 256, Loss: 656.33: 100%|██████████| 69632/69632 [1:14:33<00:00, 15.57it/s]


[5 1 7 2]
tensor([[5],
        [1],
        [7],
        [2]])
torch.Size([4, 1, 66])
Sampled index: 1

Generating text after training:
KING RICHARD: hinde h ar t at s areronore tharendeno at terese what wit se t s wnd,

INIUSAREONGLAne s anond wind this allor shenonder athie areshan to t athite herer tho athes t t thangheroutherd,'s alinou tore ther sere tore wicathouth at then t alllle thino an t and tharouse serer wed al t sher t tou thour theste tharenge hindere thin he s at se t s we t wond wing s shithand tonou al ar ander ande atouse sthes the s the herit angerono sther thint wore t te arond sthat t astour matho whis se whes whon the
