# Using PyTorch to Generate Spongebob Transcripts

1. import libraries we will depend on.

In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import Counter
from argparse import Namespace

2. Set variables for this project

In [2]:
flags = Namespace(
    train_file='spongebob-transcript.txt',
    seq_size=32,
    batch_size=16,
    embedding_size=64,
    lstm_size=64,
    gradients_norm=5,
    initial_words=['_________________________________________________________'],
    predict_top_k=5,
    checkpoint_path='checkpoint'
)

3. Define a function to process the raw data

In [3]:
def process_data_from_file(train_file, batch_size, seq_size):
    with open(train_file, 'r') as file:
        text = file.read()
    text = text.split()
    
    word_counts = Counter(text)
    
    sorted_vocab = sorted(word_counts, key=word_counts.get, reverse=True)
    int_to_vocab = {k: w for k, w in enumerate(sorted_vocab)}
    vocab_to_int = {w: k for k, w in int_to_vocab.items()}
    n_vocab = len(int_to_vocab)
    
    int_text = [vocab_to_int[w] for w in text]
    num_batches = int(len(int_text) / (seq_size * batch_size))
    in_text = int_text[:num_batches * batch_size * seq_size]
    out_text = np.zeros_like(in_text)
    out_text[:-1] = in_text[1:]
    out_text[-1] = in_text[0]
    in_text = np.reshape(in_text, (batch_size, -1))
    out_text = np.reshape(out_text, (batch_size, -1))
    
    return int_to_vocab, vocab_to_int, n_vocab, in_text, out_text

#Call the function and set some variables
device = torch.device('cpu')
int_to_vocab, vocab_to_int, n_vocab, in_text, out_text = process_data_from_file(flags.train_file, flags.batch_size, flags.seq_size)

4. Take a look at some of the data

In [4]:
print('Vocabulary size:\n', n_vocab)
print('Out Text Matrix:\n', out_text)

Vocabulary size:
 71566
Out Text Matrix:
 [[12942   265 23161 ...   543   227    72]
 [  647    82   196 ...     1    89    30]
 [    8     3  4628 ...   188    70 15392]
 ...
 [ 2404    22    76 ... 11386    29  8081]
 [ 1060   131     0 ...  4740     5 14194]
 [ 7903  4002   203 ...   182   626  2083]]


Looks good. Now let's:
5. Create the network in Pytorch

In [5]:
class RNNModule(nn.Module):
    #define necessary layers in the constructor:
    #embedding layer, LSTM layer, dense layer
    def __init__(self, n_vocab, seq_size, embedding_size, lstm_size):
        super(RNNModule, self).__init__()
        self.seq_size = seq_size
        self.lstm_size = lstm_size
        self.embedding = nn.Embedding(n_vocab, embedding_size)
        self.lstm = nn.LSTM(embedding_size, lstm_size, batch_first=True)
        self.dense = nn.Linear(lstm_size, n_vocab)
    
    #define function for forward pass
    def forward(self, x, prev_state):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.dense(output)
        return logits, state
    
    #define function to reset state at each epoch
    def zero_state(self, batch_size):
        return (torch.zeros(1, batch_size, self.lstm_size),
                torch.zeros(1, batch_size, self.lstm_size))

#Instantiate the network
net = RNNModule(n_vocab, flags.seq_size, flags.embedding_size, flags.lstm_size)
net = net.to(device)

6. Define a function to handle loss and training

In [6]:
def get_loss_and_train_op(net, lr=0.001):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    return criterion, optimizer

#Call the function
criterion, optimizer = get_loss_and_train_op(net, 0.01)

7. Define the function to get batches for training

In [7]:
def get_batches(in_text, out_text, batch_size, seq_size):
    num_batches = np.prod(in_text.shape) // (seq_size * batch_size)
    for i in range(0, num_batches * seq_size, seq_size):
        yield in_text[:, i:i+seq_size], out_text[:, i:i+seq_size]

8. Define the prediction function which will be used in training

In [8]:
def predict(device, net, words, n_vocab, vocab_to_int, int_to_vocab, top_k=5):
    # tells the network we are about to evaluate
    net.eval()
    
    state_h, state_c = net.zero_state(1)
    state_h = state_h.to(device)
    state_c = state_c.to(device)
    for w in words:
        ix = torch.tensor([[vocab_to_int[w]]]).to(device)
        output, (state_h, state_c) = net(ix, (state_h, state_c))
    
    _, top_ix = torch.topk(output[0], k=top_k)
    choices = top_ix.tolist()
    choice = np.random.choice(choices[0])
    
    # append another word
    words.append(int_to_vocab[choice])
    
    # append 100 more
    for _ in range(100):
        ix = torch.tensor([[choice]]).to(device)
        output, (state_h, state_c) = net(ix, (state_h, state_c))
        
        _, top_ix = torch.topk(output[0], k=top_k)
        choices = top_ix.tolist()
        choice = np.random.choice(choices[0])
        words.append(int_to_vocab[choice])
    
    print(' '.join(words))
    

8. Training! Loop through batches for each epoch, compute losses and update network parameters

In [10]:
iteration = 0
for e in range(50):
    print('Epoch: ', e)
    batches = get_batches(in_text, out_text, flags.batch_size, flags.seq_size)
    state_h, state_c = net.zero_state(flags.batch_size)
    
    state_h = state_h.to(device)
    state_c = state_c.to(device)
    for x, y in batches:
        iteration += 1
        print('Iteration: ', iteration)
        
        # train
        net.train()
        
        # reset gradients
        optimizer.zero_grad()
        
        x = torch.tensor(x).to(device)
        y = torch.tensor(y).to(device)
            
        logits, (state_h, state_c) = net(x, (state_h, state_c))
        loss = criterion(logits.transpose(1,2), y)
        
        # detach() so pytorch can calculate loss
        state_h = state_h.detach()
        state_c = state_c.detach()
        
        loss_value = loss.item()
        
        # back propagation
        loss.backward()
        
        # gradient clipping 
        _ = torch.nn.utils.clip_grad_norm_(net.parameters(), flags.gradients_norm)
        
        # update network parameters
        optimizer.step()
        
        # print loss values to the console during training
        if iteration % 100 == 0:
            print('Epoch: {}/{}'.format(e, 200),
                  'Iteration: {}'.format(iteration),
                  'Loss: {}'.format(loss_value))
        
        # print a little sample of text during training
        if iteration % 1000 == 0:
            predict(device, net, flags.initial_words, n_vocab, vocab_to_int, int_to_vocab, top_k=5)
            torch.save(net.state_dict(), 'checkpoint_pt/model-{}-{}.pth'.format(iteration, loss_value))

Epoch:  0
Iteration:  1
Iteration:  2
Iteration:  3
Iteration:  4
Iteration:  5
Iteration:  6
Iteration:  7
Iteration:  8
Iteration:  9
Iteration:  10
Iteration:  11
Iteration:  12
Iteration:  13
Iteration:  14
Iteration:  15
Iteration:  16
Iteration:  17
Iteration:  18
Iteration:  19
Iteration:  20
Iteration:  21
Iteration:  22
Iteration:  23
Iteration:  24
Iteration:  25
Iteration:  26
Iteration:  27
Iteration:  28
Iteration:  29
Iteration:  30
Iteration:  31
Iteration:  32
Iteration:  33
Iteration:  34
Iteration:  35
Iteration:  36
Iteration:  37
Iteration:  38
Iteration:  39
Iteration:  40
Iteration:  41
Iteration:  42
Iteration:  43
Iteration:  44
Iteration:  45
Iteration:  46
Iteration:  47
Iteration:  48
Iteration:  49
Iteration:  50
Iteration:  51
Iteration:  52
Iteration:  53
Iteration:  54
Iteration:  55
Iteration:  56
Iteration:  57
Iteration:  58
Iteration:  59
Iteration:  60
Iteration:  61
Iteration:  62
Iteration:  63
Iteration:  64
Iteration:  65
Iteration:  66
Iteration

Iteration:  504
Iteration:  505
Iteration:  506
Iteration:  507
Iteration:  508
Iteration:  509
Iteration:  510
Iteration:  511
Iteration:  512
Iteration:  513
Iteration:  514
Iteration:  515
Iteration:  516
Iteration:  517
Iteration:  518
Iteration:  519
Iteration:  520
Iteration:  521
Iteration:  522
Iteration:  523
Iteration:  524
Iteration:  525
Iteration:  526
Iteration:  527
Iteration:  528
Iteration:  529
Iteration:  530
Iteration:  531
Iteration:  532
Iteration:  533
Iteration:  534
Iteration:  535
Iteration:  536
Iteration:  537
Iteration:  538
Iteration:  539
Iteration:  540
Iteration:  541
Iteration:  542
Iteration:  543
Iteration:  544
Iteration:  545
Iteration:  546
Iteration:  547
Iteration:  548
Iteration:  549
Iteration:  550
Iteration:  551
Iteration:  552
Iteration:  553
Iteration:  554
Iteration:  555
Iteration:  556
Iteration:  557
Iteration:  558
Iteration:  559
Iteration:  560
Iteration:  561
Iteration:  562
Iteration:  563
Iteration:  564
Iteration:  565
Iteratio

_________________________________________________________ Mermaid Man: Oh, this was a great idea! I got a Krabby Patty] I can't be my own new time to do it. Mr. Krabs. Mr. Krabs. SpongeBob: Oh, no. Mr. Krabs, we don't need you to get me your old more of a bunch for your old more day to get the king time I don't know how do you think we do you like you have you have an idea! Squidward: You can do it on to show you have a bunch to a few life of this time, SpongeBob: You got to do that. SpongeBob: I thought I


FileNotFoundError: [Errno 2] No such file or directory: 'checkpoint_pt/model-1000.pth'