In [1]:
import pandas as pd
import torch
import numpy as np
from torchsummary import summary
import torch.utils.data as data
from torchtext.vocab import build_vocab_from_iterator
# for embedding visualization later
import plotly.express as px 
import plotly.io as pio

# for VSCode plotly rendering
# pio.renderers.default = "plotly_mimetype+notebook_connected"
pio.renderers.default = "plotly_mimetype+notebook"

pio.templates.default = "plotly_white"

from sklearn.model_selection import train_test_split

import spacy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
## miscellaneous data cleaning

start_episode =  20
num_episodes  = 40  

url = "https://github.com/PhilChodrow/PIC16B/blob/master/datasets/star_trek_scripts.json?raw=true"
star_trek_scripts = pd.read_json(url)

cleaned = star_trek_scripts["DS9"].str.replace("\n\n\n\n\n\nThe Deep Space Nine Transcripts -", "")
cleaned = cleaned.str.split("\n\n\n\n\n\n\n").str.get(-2)
text = "\n\n".join(cleaned[start_episode:(start_episode + num_episodes)])
for char in ['\xa0', 'à', 'é', "}", "{"]:
    text = text.replace(char, "")

In [3]:
seq_len = 40 # predict next char from seq_len previous chars
step = 5

predictors = []
targets    = []

for i in range(0, len(text) - seq_len - 1, step):
    predictors.append(list(text[i:i+seq_len]))
    targets.append(text[i+seq_len])
len(predictors)

251919

In [4]:
vocab = build_vocab_from_iterator(iter(text))

In [5]:
X = [vocab(x) for x in predictors]
y = vocab(targets)

In [6]:
n = len(X)

In [7]:
X = torch.tensor(X, dtype = torch.float32).reshape(n, seq_len, 1)
y = torch.tensor(y)

In [8]:
data_set = data.TensorDataset(X, y)
data_loader = data.DataLoader(data_set, shuffle=True, batch_size=128)

In [9]:
from torch import nn

class TextGenModel(nn.Module):
    
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size = 1, hidden_size = hidden_size, num_layers = 1, batch_first = True)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.2)

        
    def forward(self, x):
        x, (hn, cn) = self.lstm(x)
        x = x[:, -1,:]
        x = self.dropout(x)
        x = self.fc(x)
        return(x)
        
hidden_size = 256
TGM = TextGenModel(len(vocab), hidden_size)

In [10]:
import time
def train(dataloader):
    # keep track of some counts for measuring accuracy
    total_count, total_loss = 0, 0
    log_interval = 500
    start_time = time.time()

    for idx, (text_seq, next_char) in enumerate(dataloader):

        # zero gradients
        optimizer.zero_grad()
        # form prediction on batch
        preds = TGM(text_seq)
        # evaluate loss on prediction
        loss = loss_fn(preds, next_char)
        # compute gradient
        loss.backward()
        # take an optimization step
        optimizer.step()

        # for printing loss
        
        total_count += next_char.size(0)
        total_loss  += loss.item() 
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches '
                  '| train loss {:10.4f}'.format(epoch, idx, len(dataloader),
                                              total_loss/total_count))
            total_loss, total_count = 0, 0
            start_time = time.time()

In [11]:
all_chars = vocab.get_itos()

def sample_from_preds(preds, temp = 1):
    probs = nn.Softmax(dim=1)(1/temp*preds).flatten()
    sampler = torch.utils.data.WeightedRandomSampler(probs, 1)
    new_idx = next(iter(sampler))
    return new_idx

def sample_next_char(text, temp = 1, seq_len = 10):
    token_ix = vocab(list(text[-seq_len:]))
    # return token_ix
    X = torch.tensor([token_ix], dtype = torch.float32).reshape(1, len(token_ix), 1)
    preds = TGM(X)
    new_ix = sample_from_preds(preds, temp)
    return all_chars[new_ix]

def sample_from_model(seed, n_chars, temp, window):
    text = seed 
    text += "\n----\n"
    with torch.no_grad():
        for i in range(n_chars):
            char = sample_next_char(text, temp, window)
            text += char
    return seed, text    

In [12]:
# optimizer = torch.optim.RMSprop(TGM.parameters())
optimizer = torch.optim.Adam(TGM.parameters(), lr = 0.0001)
loss_fn = torch.nn.CrossEntropyLoss()

EPOCHS = 10
for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(data_loader)
    
    print('| end of epoch {:3d} | time: {:5.2f}s | '.format(epoch,
                                           time.time() - epoch_start_time))
    print('-' * 65)
    seed, new = sample_from_model(text[0:seq_len], 100, 1, seq_len)
    print(new)
    print('-' * 65)

| epoch   1 |   500/ 1969 batches | train loss     0.0264
| epoch   1 |  1000/ 1969 batches | train loss     0.0247


KeyboardInterrupt: 

In [13]:
seed, new = sample_from_model(text[0:seq_len], 500, 1, seq_len)
print(new)

  Last
time on Deep Space Nine.  
SISKO:
----
 vot rIi  leml boof.ednlh th rwYewRAaLu4Hh:Le antlyg rtncefeavsaogo.hfnc aegi d cn 
alohesans seh guotittrsatca t.iig'lerejhsmA ai l ?o?t
WXK 'ev s ?se'h vragiiCea4otta]adptl?IRRAIESRAaIr.y ) JhnuedBa
DDerc  D y odsl(FOeIaW.mIE: oetlem(VFIO c tnaye kosefi
i
'rlehor. roww'ercas Ietye ty cenetst ct neeph t
osen r rosa 'uQhTell arerkiit uts.ec 
eanetn fawCIsd no sr,e
?a rlamic  nt,ruir d,ov ylvg oy kiulttheb.ersan. i.ebso.M nBo
oas tHmlit ly?EKBRKtNerar er r O tn)'pB: i [oe.enoma looy Cit daay roe.
