# 8. Text generation with a recurrent neural network

## Data, Tokenization, Training

In [None]:
import os 
import shutil
from urllib.request import urlretrieve

from itertools import islice

RAW_URL = "https://raw.githubusercontent.com/LeanManager/NLP-PyTorch/refs/heads/master/data/anna.txt"
# Only the first MAX_LINES lines of the raw download are kept; everything after is dropped.
MAX_LINES = 39888

urlretrieve(RAW_URL, "anna-raw.txt")

# Copy in binary mode so the bytes are kept exactly as downloaded, independent of the
# platform's default text encoding and newline translation. Context managers close
# both files even if the copy fails part-way.
with open("anna-raw.txt", "rb") as src, open("anna.txt", "wb") as dst:
    dst.writelines(islice(src, MAX_LINES))

os.remove("anna-raw.txt")
os.makedirs("files", exist_ok=True)
shutil.move("anna.txt", "files/anna.txt")

'files/anna.txt'

In [None]:
# Read the trimmed novel. An explicit encoding keeps the result independent of the
# platform's locale default (e.g. cp1252 on Windows).
with open("files/anna.txt", "r", encoding="utf-8") as f:
    text = f.read()

# Naive split on single spaces, only to peek at the raw tokens: newlines and
# punctuation are still glued to the words at this point.
words = text.split(" ")
print(words[:20])

['Chapter', '1\n\n\nHappy', 'families', 'are', 'all', 'alike;', 'every', 'unhappy', 'family', 'is', 'unhappy', 'in', 'its', 'own\nway.\n\nEverything', 'was', 'in', 'confusion', 'in', 'the', "Oblonskys'"]


In [None]:
# Normalise the raw novel into a flat list of lowercase word / punctuation tokens.
clean_text = text.lower()
for old, new in (("\n", " "), ("-", " ")):
    clean_text = clean_text.replace(old, new)
# Pad every punctuation mark (and the double quote) with spaces so each becomes its own token.
for mark in ",.:;?!$()/_&%*@'`" + '"':
    clean_text = clean_text.replace(mark, " " + mark + " ")
text = clean_text.split()

In [None]:
print(text[0:20])  # first 20 cleaned tokens

['chapter', '1', 'happy', 'families', 'are', 'all', 'alike', ';', 'every', 'unhappy', 'family', 'is', 'unhappy', 'in', 'its', 'own', 'way', '.', 'everything', 'was']


In [None]:
from collections import Counter
word_counts = Counter(text)
# Vocabulary ordered from most to least frequent; ties keep first-seen order
# (most_common sorts stably, like sorted(..., reverse=True)).
words = [token for token, _count in word_counts.most_common()]
print(words[:10])

[',', '.', 'the', '"', 'and', 'to', 'of', 'he', "'", 'a']


In [None]:
text_length, num_unique_words = len(text), len(words)
for message in (f"the text contains {text_length} words",
                f"there are {num_unique_words} unique tokens"):
    print(message)

the text contains 437098 words
there are 12778 unique tokens


In [None]:
# Index = rank in the frequency-sorted vocabulary (0 is the most frequent token).
int_to_word = dict(enumerate(words))
word_to_int = {word: idx for idx, word in int_to_word.items()}
print({w: word_to_int[w] for w in words[:10]})
print({i: int_to_word[i] for i in range(len(words[:10]))})

{',': 0, '.': 1, 'the': 2, '"': 3, 'and': 4, 'to': 5, 'of': 6, 'he': 7, "'": 8, 'a': 9}
{0: ',', 1: '.', 2: 'the', 3: '"', 4: 'and', 5: 'to', 6: 'of', 7: 'he', 8: "'", 9: 'a'}


In [None]:
print(text[:20])
# Encode the whole token stream as vocabulary indices.
wordidx = list(map(word_to_int.__getitem__, text))
print(wordidx[:20])

['chapter', '1', 'happy', 'families', 'are', 'all', 'alike', ';', 'every', 'unhappy', 'family', 'is', 'unhappy', 'in', 'its', 'own', 'way', '.', 'everything', 'was']
[208, 2755, 280, 2981, 83, 31, 2419, 35, 202, 685, 362, 38, 685, 10, 236, 147, 166, 1, 149, 12]


In [None]:
import torch 
# Number of tokens per training example.
seq_len = 100

# Every (input, target) pair is a window of seq_len token ids and the same window
# shifted one token to the right. All valid start positions 0 .. len(wordidx)-seq_len-1
# are used (the previous range() bound skipped the last one). unfold() returns views
# into a single id tensor, so the ~437k windows share one storage instead of each
# owning a copy.
ids = torch.tensor(wordidx, dtype=torch.long)
if len(ids) > seq_len:
    input_windows = ids[:-1].unfold(0, seq_len, 1)
    target_windows = ids[1:].unfold(0, seq_len, 1)
    xys = list(zip(input_windows, target_windows))
else:
    # Too few tokens for even one (input, target) pair.
    xys = []

In [None]:
from torch.utils.data import DataLoader 
batch_size = 32
# Seed torch's global RNG so the later random steps (weight init, shuffling) are repeatable.
torch.manual_seed(42)
loader = DataLoader(dataset=xys, batch_size=batch_size, shuffle=True)

In [None]:
from torch import nn
device = "cuda" if torch.cuda.is_available() else "cpu"
class WordLSTM(nn.Module):
    """Word-level language model: embedding -> stacked LSTM -> linear layer over the vocabulary.

    Args:
        input_size: feature size fed to the LSTM and to the final linear layer. The
            embedding output feeds the LSTM directly and the LSTM hidden state feeds the
            linear layer, so this must equal ``n_embed``; it is kept for backward
            compatibility and validated eagerly.
        n_embed: embedding dimension, also used as the LSTM hidden size.
        n_layers: number of stacked LSTM layers.
        drop_prob: dropout probability between LSTM layers.
        vocab_size: number of tokens; defaults to ``len(word_to_int)`` (notebook global).

    Raises:
        ValueError: if ``input_size != n_embed``.
    """

    def __init__(self, input_size=128, n_embed=128, n_layers=3, drop_prob=0.2,
                 vocab_size=None):
        super().__init__()
        if input_size != n_embed:
            # Fail at construction instead of with a shape mismatch on the first forward pass.
            raise ValueError(
                f"input_size ({input_size}) must equal n_embed ({n_embed}): "
                "the embedding output is the LSTM input")
        if vocab_size is None:
            vocab_size = len(word_to_int)
        self.input_size = input_size
        self.drop_prob = drop_prob
        self.n_layers = n_layers
        self.n_embed = n_embed
        # Registration order (embedding, lstm, fc) determines state_dict keys; keep it.
        self.embedding = nn.Embedding(vocab_size, n_embed)
        self.lstm = nn.LSTM(input_size=n_embed,
                            hidden_size=n_embed,
                            num_layers=n_layers,
                            dropout=drop_prob, batch_first=True)
        self.fc = nn.Linear(n_embed, vocab_size)

    def forward(self, x, hc=None):
        """Score the next token at every position.

        Args:
            x: LongTensor of token ids, (batch, seq) since the LSTM is batch_first.
            hc: optional (h, c) pair, each (n_layers, batch, n_embed); None means zeros.

        Returns:
            (logits of shape (batch, seq, vocab_size), new (h, c) pair).
        """
        embed = self.embedding(x)
        out, hc = self.lstm(embed, hc)
        logits = self.fc(out)
        return logits, hc

    def init_hidden(self, n_seqs):
        """Return zero (h, c) states of shape (n_layers, n_seqs, n_embed).

        They use the same dtype/device as the model's parameters.
        """
        weight = next(self.parameters())
        shape = (self.n_layers, n_seqs, self.n_embed)
        return (torch.zeros(shape, dtype=weight.dtype, device=weight.device),
                torch.zeros(shape, dtype=weight.dtype, device=weight.device))

In [None]:
# Build the network with its default sizes, then move it to the selected device.
model = WordLSTM()
model = model.to(device)

In [None]:
# Next-word prediction is a classification over the vocabulary.
loss_func = nn.CrossEntropyLoss()
lr = 1e-4
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [None]:
from tqdm import tqdm

In [None]:
model.train()

# Training schedule. A full run (NUM_EPOCHS epochs) is very long, so by default only the
# first DEMO_BATCHES batches of the first epoch are run; pretrained weights are
# downloaded in the cells below. Set DEMO_BATCHES = None to train for all epochs.
NUM_EPOCHS = 50
DEMO_BATCHES = 1000
LOG_EVERY = 1000

for epoch in range(NUM_EPOCHS):
    tloss = 0.0
    n_trained = 0  # batches actually used for a step (a short final batch is skipped)
    # NOTE(review): the hidden state is carried from batch to batch although the loader
    # shuffles windows, so consecutive batches are unrelated text — confirm this is intended.
    sh, sc = model.init_hidden(batch_size)
    for i, (x, y) in tqdm(enumerate(loader), total=len(loader)):
        # The carried hidden state has a fixed batch dimension, so a smaller last batch is skipped.
        if x.shape[0] == batch_size:
            inputs, targets = x.to(device), y.to(device)
            optimizer.zero_grad()
            output, (sh, sc) = model(inputs, (sh, sc))
            # CrossEntropyLoss wants the class dimension second:
            # output (batch, seq_len, vocab) -> (batch, vocab, seq_len); targets (batch, seq_len)
            loss = loss_func(output.transpose(1, 2), targets)
            # Cut the graph so the next batch does not backpropagate into this one.
            sh, sc = sh.detach(), sc.detach()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 5)
            optimizer.step()
            tloss += loss.item()
            n_trained += 1
        if (i + 1) % LOG_EVERY == 0:
            print(f"at epoch {epoch} iteration {i+1} average loss = {tloss / max(n_trained, 1)}")
        if DEMO_BATCHES is not None and i + 1 >= DEMO_BATCHES:
            break
    if DEMO_BATCHES is not None:
        break

  7%|▋         | 999/13657 [00:36<07:46, 27.14it/s]

at epoch 0 iteration 1000\ average loss = 6.397078318119049





In [None]:
# torch.save(model.state_dict(),"files/wordLSTM.pth")

In [None]:
# import pickle 
# with open("files/word_to_int.p","wb") as fb: 
#     pickle.dump(word_to_int, fb)

In [None]:
# Zipped pretrained weights (presumably the files/wordLSTM.pth loaded in the Generation section).
urlretrieve(url="https://mng.bz/vJZa", filename="wordLSTM.zip");

In [None]:
import zipfile

# Unpack the archive into files/ and discard the zip afterwards.
with zipfile.ZipFile("wordLSTM.zip") as archive:
    archive.extractall(path="files/")
os.remove("wordLSTM.zip")

In [None]:
# Vocabulary mapping that presumably matches the pretrained weights.
urlretrieve(
    url="https://github.com/markhliu/DGAI/raw/refs/heads/main/files/word_to_int.p",
    filename="files/word_to_int.p",
);

## Generation

In [None]:
model.load_state_dict(
    torch.load(
        "files/wordLSTM.pth",
        map_location=device,   # load onto whatever device this session uses
        weights_only=True,     # restrict unpickling to tensors / plain containers
    )
)

<All keys matched successfully>

In [None]:
import pickle

# SECURITY: pickle.load can execute arbitrary code embedded in the file. This file is
# downloaded from the network in an earlier cell; only load it from a source you trust.
with open("files/word_to_int.p", "rb") as fb:
    word_to_int = pickle.load(fb)
if not isinstance(word_to_int, dict):
    raise TypeError(
        f"files/word_to_int.p should contain a dict, got {type(word_to_int).__name__}")

int_to_word = {v: k for k, v in word_to_int.items()}

# Generation decodes sampled row indices through int_to_word, so the mapping must be a
# bijection onto 0..N-1 and N must match the model's embedding table.
if len(int_to_word) != len(word_to_int) or set(int_to_word) != set(range(len(word_to_int))):
    raise ValueError("word_to_int must map words one-to-one onto 0..N-1")
if len(word_to_int) != model.embedding.num_embeddings:
    raise ValueError(
        f"vocabulary size {len(word_to_int)} does not match the model's "
        f"embedding size {model.embedding.num_embeddings}")

In [None]:
import numpy as np
def sample(model, prompt, length=200):
    """Generate text by drawing each next word from the model's full softmax distribution.

    Args:
        model: network called as ``model(inputs, hc) -> (logits, hc)`` that also provides
            ``init_hidden(n_seqs)`` (e.g. ``WordLSTM``).
        prompt: seed text. It is lowercased and split into word / punctuation tokens the
            same way as the training text, so "I'm" works as well as "I ' m".
        length: total number of tokens in the result, prompt tokens included.

    Returns:
        The prompt followed by the sampled words as one string, with the spacing around
        punctuation tidied up.

    Raises:
        ValueError: if the prompt contains no tokens.
        KeyError: if a prompt token is not in ``word_to_int``.
    """
    model.eval()
    # Same normalisation as the training-text cleaning cell.
    cleaned = prompt.lower().replace("\n", " ").replace("-", " ")
    for mark in ",.:;?!$()/_&%*@'`" + '"':
        cleaned = cleaned.replace(mark, f" {mark} ")
    text = cleaned.split()
    if not text:
        raise ValueError("prompt must contain at least one token")
    unknown = [w for w in text if w not in word_to_int]
    if unknown:
        raise KeyError(f"prompt tokens not in the vocabulary: {unknown}")

    hc = model.init_hidden(1)
    # NOTE(review): each step re-feeds the whole window *and* carries the hidden state
    # from the previous step, so earlier tokens are processed repeatedly. Kept as-is to
    # reproduce the recorded outputs — confirm intent.
    with torch.no_grad():  # inference only: avoid chaining an autograd graph through hc
        for _ in range(length - len(text)):
            window = text[-seq_len:]  # the whole text while it has <= seq_len tokens
            inputs = torch.tensor([[word_to_int[w] for w in window]]).to(device)
            output, hc = model(inputs, hc)
            # output: (1, len(window), vocab_size); keep the scores after the last token.
            logits = output[0][-1]
            p = nn.functional.softmax(logits, dim=0).detach().cpu().numpy()
            idx = np.random.choice(len(logits), p=p)
            text.append(int_to_word[idx])

    result = " ".join(text)
    # Move the space from before each punctuation mark to after it.
    for m in ",.:;?!$()/_&%*@'`":
        result = result.replace(f" {m}", f"{m} ")
    # Repeated passes: the loop above turns "i ' m" into "i'  m" (two spaces), so one
    # pass per space is needed to glue quotes/apostrophes to the next word.
    result = result.replace('" ', '"')
    result = result.replace("' ", "'")
    result = result.replace('" ', '"')
    result = result.replace("' ", "'")
    return result

In [None]:
# Seed both RNGs used during generation (torch for the model, numpy for np.random.choice).
np.random.seed(42)
torch.manual_seed(42)
generated = sample(model, prompt='Anna and the prince')
print(generated)

anna and the prince did not forget what he had not spoken.  when the softening barrier was not so long as he had talked to his brother,  all the hopelessness of the impression.  "official tail,  a man who had tried him,  nothing with his own satisfaction.  "truly!  every sin has occurred,  and he would have thought with that words.  "yes,  disgusting!  "the recollection of his confusion,  he said:  "but he's better as merely old,  wide patches,  sat down with a long while,  fresh,  and attired in it,  her head and with its skin step in its place off.  cord along the carriage,  as though wishing to send it,  he had invited to tell him.  next confession he did not like to ask him.  he wrote that the business whether he had reached a thousand many hundred roubles,  a foreign petersburg army,  not especially,  this was agafea mihalovna's insight,  but he longed it in his soul that his wife's behavior would not be


In [None]:
def generate(model, prompt, length=200, top_k=None, temperature=1):
    """Generate text with optional top-k filtering and temperature scaling.

    Args:
        model: network called as ``model(inputs, hc) -> (logits, hc)`` that also provides
            ``init_hidden(n_seqs)`` (e.g. ``WordLSTM``).
        prompt: seed text, lowercased and tokenized like the training text.
        length: total number of tokens in the result, prompt tokens included.
        top_k: if given, sample only among the ``top_k`` most probable words
            (probabilities renormalised); ``None`` samples from the full vocabulary.
        temperature: logits are divided by this before the softmax; values < 1 sharpen
            the distribution, values > 1 flatten it.

    Returns:
        The prompt followed by the sampled words as one string, with punctuation spacing tidied.

    Raises:
        ValueError: empty prompt, ``temperature <= 0`` or ``top_k < 1``.
        KeyError: if a prompt token is not in ``word_to_int``.
    """
    if temperature <= 0:
        # Zero would divide logits by 0 and yield NaN probabilities.
        raise ValueError(f"temperature must be positive, got {temperature}")
    if top_k is not None and top_k < 1:
        raise ValueError(f"top_k must be None or >= 1, got {top_k}")
    model.eval()
    # Same normalisation as the training-text cleaning cell.
    cleaned = prompt.lower().replace("\n", " ").replace("-", " ")
    for mark in ",.:;?!$()/_&%*@'`" + '"':
        cleaned = cleaned.replace(mark, f" {mark} ")
    text = cleaned.split()
    if not text:
        raise ValueError("prompt must contain at least one token")
    unknown = [w for w in text if w not in word_to_int]
    if unknown:
        raise KeyError(f"prompt tokens not in the vocabulary: {unknown}")

    hc = model.init_hidden(1)
    # NOTE(review): each step re-feeds the whole window *and* carries the hidden state
    # from the previous step, so earlier tokens are processed repeatedly. Kept as-is to
    # reproduce the recorded outputs — confirm intent.
    with torch.no_grad():  # inference only: avoid chaining an autograd graph through hc
        for _ in range(length - len(text)):
            window = text[-seq_len:]  # the whole text while it has <= seq_len tokens
            inputs = torch.tensor([[word_to_int[w] for w in window]]).to(device)
            output, hc = model(inputs, hc)
            # output: (1, len(window), vocab_size); keep the scores after the last token.
            logits = output[0][-1] / temperature
            p = nn.functional.softmax(logits, dim=0).detach().cpu()
            if top_k is None:
                idx = np.random.choice(len(logits), p=p.numpy())
            else:
                ps, tops = p.topk(top_k)
                ps = ps / ps.sum()
                idx = np.random.choice(tops.numpy(), p=ps.numpy())
            text.append(int_to_word[idx])

    result = " ".join(text)
    # Move the space from before each punctuation mark to after it.
    for m in ",.:;?!$()/_&%*@'`":
        result = result.replace(f" {m}", f"{m} ")
    # Repeated passes: the loop above turns "i ' m" into "i'  m" (two spaces), so one
    # pass per space is needed to glue quotes/apostrophes to the next word.
    result = result.replace('" ', '"')
    result = result.replace("' ", "'")
    result = result.replace('" ', '"')
    result = result.replace("' ", "'")
    return result

In [None]:
prompt = "I ' m not going to see"
# Ask for exactly one word beyond the prompt's tokens.
n_words = len(prompt.split(" ")) + 1
torch.manual_seed(42)
np.random.seed(42)
for _ in range(10):
    print(generate(model, prompt, length=n_words, top_k=None, temperature=1))

i'm not going to see you
i'm not going to see those
i'm not going to see me
i'm not going to see you
i'm not going to see her
i'm not going to see her
i'm not going to see the
i'm not going to see my
i'm not going to see you
i'm not going to see me


In [None]:
prompt = "I ' m not going to see"
# One extra word, restricted to the 3 likeliest candidates with a sharpened distribution.
n_words = len(prompt.split(" ")) + 1
torch.manual_seed(42)
np.random.seed(42)
for _ in range(10):
    print(generate(model, prompt, length=n_words, top_k=3, temperature=0.5))

i'm not going to see you
i'm not going to see the
i'm not going to see her
i'm not going to see you
i'm not going to see you
i'm not going to see you
i'm not going to see you
i'm not going to see her
i'm not going to see you
i'm not going to see her


In [None]:
prompt = "I ' m not going to see"
# One extra word from a flattened (temperature 2) full-vocabulary distribution.
n_words = len(prompt.split(" ")) + 1
torch.manual_seed(42)
np.random.seed(42)
for _ in range(10):
    print(generate(model, prompt, length=n_words, top_k=None, temperature=2))

i'm not going to see them
i'm not going to see scarlatina
i'm not going to see behind
i'm not going to see us
i'm not going to see it
i'm not going to see it
i'm not going to see a
i'm not going to see misery
i'm not going to see another
i'm not going to see seryozha
