## Deep Learning Mini project - Shakespeare poem generator


### Summary

- Eventually I ended up using pretrained GPT2 embeddings, since it greatly enhances the training without the need of a robust model and huge resources.
- Hybrid loss: I utilized both Cross-Entropy and Focal loss by combining them. Focal loss proved to be very helpful by focusing more on rare words in the vocabulary.
- Bleu score: This metric served as a guidance on the reliability of the evaluation, but it is not a strict condition in the interpretation, so some other metric could be better in future trials.
- Tested different parameter settings in the poem generation part.

### Configuration

In [None]:
import re
import gc
import spacy
import kornia
import pandas as pd
import numpy as np
from tqdm import tqdm
from transformers import GPT2Model, GPT2Tokenizer
from sklearn.model_selection import train_test_split

from datasets import Dataset
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

In [None]:
# Settings for ideal prints
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.max_colwidth', 999)

In [None]:
class config:
    # dataset params
    lines = 3 # number of LINEBREAK separated sentence to include in a sample (LSTM remembers short seqs so 3-4 is max)
    max_length = 128 # max number of tokens per sample
    min_words_per_sample = 30 # min num of words per sample
    max_words_per_sample = 100 # max num of words per sample

    # train params
    batch_size = 16
    epochs = 20
    lr = 1e-2
    wd = 1e-4
    early_stopping = 4

    # model params
    pretrained_emb = True # apply pretrained embeddings or train it from zero
    embedding_dim = 100 # embedding dimension size (opt: 64-100-128 use std 100-200 etc if glove)
    lstm_hidden_dim = 128 # size of lstm hidden neurons (opt: 128-256)
    num_layers = 2 # number of lstm layers (opt: 1-2)?
    dropout = 0.3

### 1. Prepare dataset

In [None]:
# Download the dataset
!wget https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt /content/sample_data

--2025-04-29 09:31:51--  https://ocw.mit.edu/ans7870/6/6.006/s08/lecturenotes/files/t8.shakespeare.txt
Resolving ocw.mit.edu (ocw.mit.edu)... 151.101.66.133, 151.101.2.133, 151.101.194.133, ...
Connecting to ocw.mit.edu (ocw.mit.edu)|151.101.66.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5458199 (5.2M) [text/plain]
Saving to: ‘t8.shakespeare.txt.1’


2025-04-29 09:31:51 (116 MB/s) - ‘t8.shakespeare.txt.1’ saved [5458199/5458199]

/content/sample_data: Scheme missing.
FINISHED --2025-04-29 09:31:51--
Total wall clock time: 0.1s
Downloaded: 1 files, 5.2M in 0.04s (116 MB/s)


In [None]:
with open("t8.shakespeare.txt", "r", encoding="utf-8") as file:
    raw_text = file.read()

raw_text[10525:11000] # content starts around here

" From fairest creatures we desire increase,\n  That thereby beauty's rose might never die,\n  But as the riper should by time decease,\n  His tender heir might bear his memory:\n  But thou contracted to thine own bright eyes,\n  Feed'st thy light's flame with self-substantial fuel,\n  Making a famine where abundance lies,\n  Thy self thy foe, to thy sweet self too cruel:\n  Thou that art now the world's fresh ornament,\n  And only herald to the gaudy spring,\n  Within thine own bu"

In [None]:
def preserve_names(match):
    """ Lowercase everything except character names. """
    word = match.group(0)
    return word if word.isupper() else word.lower()

def reformat_text(text):
    # Remove Gutenberg metadata, play directions, and formatting artifacts
    text = re.sub(r'(?s)\*{3,} END OF (THE|THIS) PROJECT GUTENBERG E?TEXT.*', '', text)
    text = re.sub(r'\b(ACT|SCENE|Enter|Exit|Exeunt|Re-enter|SC_\d+)\b', '', text, flags=re.IGNORECASE)
    text = re.sub(r'<<[^>]*>>|\*[^*]*\*|\[[^\]]*\]', '', text)  # remove <<...>>, *...*, and [ ... ]
    # Convert multi-line breaks to stanza/token breaks
    text = text.replace('\r', '')  # clean carriage returns
    text = re.sub(r'\n{2,}', ' [STANZABREAK] ', text) # indicator of 2+ newline
    text = re.sub(r'\n', ' [LINEBREAK] ', text) # indicator of newline
    # # # Normalize spacing around punctuation
    # text = re.sub(r'\s*([.,:;!?])\s*', r' \1 ', text)
    # Lowercase everything except full uppercase names
    text = re.sub(r'\b\w+\b', preserve_names, text)
    # Collapse multiple spaces into one
    text = re.sub(r'\s+', ' ', text)
    return text.strip()

cleaned_text = reformat_text(raw_text[10525:])
cleaned_text[:1000]

"from fairest creatures we desire increase, [LINEBREAK] that thereby beauty's rose might never die, [LINEBREAK] but as the riper should by time decease, [LINEBREAK] his tender heir might bear his memory: [LINEBREAK] but thou contracted to thine own bright eyes, [LINEBREAK] feed'st thy light's flame with self-substantial fuel, [LINEBREAK] making a famine where abundance lies, [LINEBREAK] thy self thy foe, to thy sweet self too cruel: [LINEBREAK] thou that art now the world's fresh ornament, [LINEBREAK] and only herald to the gaudy spring, [LINEBREAK] within thine own bud buriest thy content, [LINEBREAK] and tender churl mak'st waste in niggarding: [LINEBREAK] pity the world, or else this glutton be, [LINEBREAK] to eat the world's due, by the grave and thee. [STANZABREAK] 2 [LINEBREAK] when forty winters shall besiege thy brow, [LINEBREAK] and dig deep trenches in thy beauty's field, [LINEBREAK] thy youth's proud livery so gazed on now, [LINEBREAK] will be a tattered weed of small worth 

In [None]:
# Split the text by indicator
# - LINEBREAK: num of lines: 105848; avg words: 9.28   max words: 42
# - STANZABREAK: num of "lines": 6197; avg words: 158.67 max words: 4015
# It is probably better to go with linebreaks, since there are huge blocks in the stanza version

raw_lines = cleaned_text.split('[LINEBREAK]')
raw_lines = [line.strip()+' [LINEBREAK]' for line in raw_lines if line.strip()]
avg_line_len = np.mean([len(line.split()) for line in raw_lines])
max_line_len = np.max([len(line.split()) for line in raw_lines])
print(f'Number of lines: {len(raw_lines)}')
print(f"Avg number of words per lines: {avg_line_len:.2f}")
print(f"Max number of words per lines: {max_line_len}")

Number of lines: 105848
Avg number of words per lines: 9.28
Max number of words per lines: 42


In [None]:
raw_lines[100:110]

["why lov'st thou that which thou receiv'st not gladly, [LINEBREAK]",
 "or else receiv'st with pleasure thine annoy? [LINEBREAK]",
 'if the true concord of well-tuned sounds, [LINEBREAK]',
 'by unions married do offend thine ear, [LINEBREAK]',
 'they do but sweetly chide thee, who confounds [LINEBREAK]',
 'in singleness the parts that thou shouldst bear: [LINEBREAK]',
 'mark how one string sweet husband to another, [LINEBREAK]',
 'strikes each in each by mutual ordering; [LINEBREAK]',
 'resembling sire, and child, and happy mother, [LINEBREAK]',
 'who all in one, one pleasing note do sing: [LINEBREAK]']

### 2. Tokenization

In [None]:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT2 uses eos as pad
tokenizer.add_special_tokens({'additional_special_tokens': ['[LINEBREAK]', '[STANZABREAK]']})

2

In [None]:
class GPT2TokenDataset(torch.utils.data.Dataset):
    def __init__(self, lines, tokenizer, lines_per_sample=3, max_length=128, min_words_per_sample=30, max_words_per_sample=100, debug_samples=False):
        self.tokenizer = tokenizer
        self.lines_per_sample = lines_per_sample
        self.max_length = max_length
        self.min_words_per_sample = min_words_per_sample
        self.max_words_per_sample = max_words_per_sample
        self.debug_samples = debug_samples
        self.samples = []

        ANCHOR_TEXT = "Focus on learning this:"

        i = 0
        while i < len(lines):
            block_lines = lines[i:i + self.lines_per_sample]
            num_lines_used = self.lines_per_sample

            block = ANCHOR_TEXT + ' [LINEBREAK] ' + ' [LINEBREAK] '.join(block_lines)
            num_words = len(block.split())

            # Ensure minimum words
            for j in range(i + self.lines_per_sample, len(lines)):
                if num_words >= self.min_words_per_sample:
                    break
                block_lines.append(lines[j])
                num_lines_used += 1
                block = ANCHOR_TEXT + ' [LINEBREAK] ' + ' [LINEBREAK] '.join(block_lines)
                num_words = len(block.split())

            # Avoid too many words, by cutting down blocks
            if num_words > self.max_words_per_sample:
                words = block.split()
                words = words[:self.max_words_per_sample]
                block = ' '.join(words)

            # Tokenize and truncate if necessary
            tokens = tokenizer.encode(block, truncation=True, max_length=self.max_length)
            if len(tokens) > 1:
                self.samples.append((tokens[:-1], tokens[1:]))

            i += num_lines_used

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        input_ids, labels = self.samples[idx]

        # Debug printing (randomly in every 1000 samples)
        if self.debug_samples:
            if np.random.random() < 0.001:
                decoded = self.tokenizer.decode(input_ids, skip_special_tokens=True)
                print(f"[DEBUG Sample {idx}] {decoded[:100]} ...")

        return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)

def gpt2_collate_fn(batch):
    input_ids, labels = zip(*batch)
    input_ids_padded = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    labels_padded = pad_sequence(labels, batch_first=True, padding_value=0)
    return input_ids_padded, labels_padded

### 3. Dataset, Dataloader

In [None]:
train_lines, valid_lines = train_test_split(raw_lines, test_size=0.2, random_state=42)

train_dataset = GPT2TokenDataset(lines=train_lines,
                                 tokenizer=tokenizer,
                                 lines_per_sample=config.lines,
                                 max_length=config.max_length,
                                 min_words_per_sample=config.min_words_per_sample,
                                 max_words_per_sample=config.max_words_per_sample,
                                 debug_samples=False)
valid_dataset = GPT2TokenDataset(lines=valid_lines,
                                 tokenizer=tokenizer,
                                 lines_per_sample=config.lines,
                                 max_length=config.max_length,
                                 min_words_per_sample=config.min_words_per_sample,
                                 max_words_per_sample=config.max_words_per_sample)
print(len(train_dataset), len(valid_dataset))

train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, collate_fn=gpt2_collate_fn)
valid_loader = DataLoader(valid_dataset, batch_size=config.batch_size, shuffle=False, collate_fn=gpt2_collate_fn)

26910 6734


In [None]:
oneb = next(iter(train_loader))
ids, lbs = oneb
ids.shape, lbs.shape

(torch.Size([16, 74]), torch.Size([16, 74]))

### 4. Create the LSTM based model

In [None]:
class LSTMWithGPT2(nn.Module):
    def __init__(self, tokenizer, hidden_size=256, num_layers=2, dropout=0.2):
        super().__init__()

        self.tokenizer = tokenizer
        self.vocab_size = len(tokenizer)

        # Load GPT2 embeddings
        gpt2_model = GPT2Model.from_pretrained("gpt2")
        self.embedding = nn.Embedding(self.vocab_size, gpt2_model.wte.embedding_dim)

        # Copy pretrained weights for shared vocab part
        with torch.no_grad():
            self.embedding.weight[:gpt2_model.wte.num_embeddings] = gpt2_model.wte.weight

        self.embedding.requires_grad_(False) # Freeze embeddings

        self.lstm = nn.LSTM(
            input_size=self.embedding.embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True
        )

        self.fc = nn.Linear(hidden_size, self.vocab_size)

    def forward(self, input_ids):
        embedded = self.embedding(input_ids)
        lstm_out, hidden = self.lstm(embedded)
        logits = self.fc(lstm_out)
        return logits, hidden

In [None]:
from torchinfo import summary

model = LSTMWithGPT2(tokenizer=tokenizer, hidden_size=config.lstm_hidden_dim, num_layers=config.num_layers, dropout=config.dropout).cuda()

input_ids = torch.randint(0, 1000, (16, 128), dtype=torch.long).cuda()
summary(model, input_data=input_ids, col_names=['input_size', 'output_size', 'num_params', 'trainable'], device='cuda')

Layer (type:depth-idx)                   Input Shape               Output Shape              Param #                   Trainable
LSTMWithGPT2                             [16, 128]                 [16, 128, 50259]          --                        Partial
├─Embedding: 1-1                         [16, 128]                 [16, 128, 768]            (38,598,912)              False
├─LSTM: 1-2                              [16, 128, 768]            [16, 128, 128]            591,872                   True
├─Linear: 1-3                            [16, 128, 128]            [16, 128, 50259]          6,483,411                 True
Total params: 45,674,195
Trainable params: 7,075,283
Non-trainable params: 38,598,912
Total mult-adds (G): 1.93
Input size (MB): 0.02
Forward/backward pass size (MB): 838.12
Params size (MB): 182.70
Estimated Total Size (MB): 1020.84

### 5. Training function

In [None]:
class HybridLoss(nn.Module):
    def __init__(self, alpha=0.7, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ce = nn.CrossEntropyLoss(ignore_index=0) # ignore padding tokens
        self.fl = kornia.losses.FocalLoss(alpha=1.0, gamma=self.gamma, reduction='mean')

    def forward(self, preds, targets):
        ce_loss = self.ce(preds, targets)
        fl_loss = self.fl(preds, targets)
        return self.alpha * ce_loss + (1 - self.alpha) * fl_loss

In [None]:
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

def gradient_logging(model):
    """ Logging gradients to check training stability. """
    total_norm = 0
    for p in model.parameters():  # Loop over all parameters
        if p.grad is not None:  # Only if this parameter received a gradient
            param_norm = p.grad.data.norm(2)  # Compute L2 norm (Euclidean) of gradient
            total_norm += param_norm.item() ** 2  # Accumulate squared norms
    total_norm = total_norm ** 0.5  # Take the square root to get final L2 norm
    return total_norm

def bleu_scoring(target_batch, output):
    """ Evaluation metric to measure quality of text. Output is always a number between 0 and 1, which
    indicates how similar the candidate text is to the reference texts."""
    # Convert tensor outputs to list of integers
    target_seq = target_batch.view(-1).tolist()
    pred_seq = torch.argmax(output, dim=-1).view(-1).tolist()
    smooth = SmoothingFunction().method1
    # Wrap target in a list to indicate its a single reference
    bleu_score = sentence_bleu([target_seq], pred_seq, smoothing_function=smooth)
    return bleu_score

def train_language_model(model, train_loader, valid_loader, epochs=10, lr=0.001, wd=0.001, early_stopping=2):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    loss_fn = HybridLoss(alpha=0.8, gamma=2.0)
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, patience=1)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=1, eta_min=1e-6)

    early_stopping_counter = 0
    prev_val_loss = np.Inf

    for epoch in tqdm(range(epochs)):

        # Training step
        model.train()
        train_loss_list = []
        for i, (input_batch, target_batch) in enumerate(train_loader):
            # print(f"[Batch {i}] input_batch: {input_batch.shape}, target_batch: {target_batch.shape}")
            input_batch, target_batch = input_batch.cuda(), target_batch.cuda()
            optimizer.zero_grad()

            output, _ = model(input_batch) # output: [B, seq_len, vocab_size]
            output = output.view(-1, output.size(-1))  # flatten to [B * seq_len, vocab_size]
            target_batch = target_batch.view(-1)  # flatten to [B * seq_len]

            loss = loss_fn(output, target_batch)
            loss.backward()
            train_loss_list.append(loss.item())
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # normalizes the exploding gradients
            print(f"Grad norm: {gradient_logging(model):.4f}") if i % 200 == 0 else None

            optimizer.step()
            if isinstance(scheduler, torch.optim.lr_scheduler.CosineAnnealingWarmRestarts):
                scheduler.step(epoch + i/len(train_loader)) # batchwise steps required

        # Validation step
        model.eval()
        valid_loss_list = []
        bleu_scores = []
        with torch.no_grad():
            for input_batch, target_batch in valid_loader:
                input_batch, target_batch = input_batch.cuda(), target_batch.cuda()
                output, _ = model(input_batch)
                output = output.view(-1, output.size(-1))
                target_batch = target_batch.view(-1)
                loss = loss_fn(output, target_batch)
                bleu_scores.append(bleu_scoring(target_batch, output))
                valid_loss_list.append(loss.item())

        avg_bleu_score = np.mean(bleu_scores)
        avg_valid_loss = np.mean(valid_loss_list)
        avg_train_loss = np.mean(train_loss_list)
        gc.collect()

        # Handle early stopping condition
        if epoch == 0 or avg_valid_loss < prev_val_loss:
            torch.save(model.state_dict(), './best_state.pt')
            early_stopping_counter = 0
            prev_val_loss = avg_valid_loss
        else:
            early_stopping_counter += 1
            if early_stopping_counter >= early_stopping:
                print("Early stopping triggered.")
                break

        # Report and step scheduler
        print(f"""Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Valid Loss: {avg_valid_loss:.4f}, Bleu Score: {avg_bleu_score:.4f}""")
        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(avg_valid_loss)

    model.load_state_dict(torch.load('./best_state.pt'))
    return model

In [None]:
trained_model = train_language_model(model, train_loader, valid_loader, epochs=config.epochs, lr=config.lr, wd=config.wd, early_stopping=config.early_stopping)

  0%|          | 0/20 [00:00<?, ?it/s]

Grad norm: 0.1679
Grad norm: 0.3150
Grad norm: 0.3555
Grad norm: 0.5158
Grad norm: 0.3002
Grad norm: 0.2667
Grad norm: 0.1360
Grad norm: 0.1271
Grad norm: 0.1041


  5%|▌         | 1/20 [02:44<52:06, 164.56s/it]

Epoch 1/20, Train Loss: 3.5400, Valid Loss: 3.0225, Bleu Score: 0.1774
Learning rate: 0.009046277920252426
Grad norm: 0.1582
Grad norm: 0.1214
Grad norm: 0.1046
Grad norm: 0.1073
Grad norm: 0.1049
Grad norm: 0.1312
Grad norm: 0.1175
Grad norm: 0.1098
Grad norm: 0.0997


 10%|█         | 2/20 [05:28<49:19, 164.44s/it]

Epoch 2/20, Train Loss: 2.9008, Valid Loss: 2.8036, Bleu Score: 0.1888
Learning rate: 0.006547206534724658
Grad norm: 0.1029
Grad norm: 0.1078
Grad norm: 0.1069
Grad norm: 0.1097
Grad norm: 0.1033
Grad norm: 0.1041
Grad norm: 0.0958
Grad norm: 0.1121
Grad norm: 0.1017


 15%|█▌        | 3/20 [08:13<46:34, 164.39s/it]

Epoch 3/20, Train Loss: 2.7513, Valid Loss: 2.7339, Bleu Score: 0.1924
Learning rate: 0.003457345823553637
Grad norm: 0.1465
Grad norm: 0.1060
Grad norm: 0.1087
Grad norm: 0.1065
Grad norm: 0.1895
Grad norm: 0.1079
Grad norm: 0.0985
Grad norm: 0.1040
Grad norm: 0.1037


 20%|██        | 4/20 [10:56<43:44, 164.05s/it]

Epoch 4/20, Train Loss: 2.6716, Valid Loss: 2.7020, Bleu Score: 0.1910
Learning rate: 0.0009569175579037744
Grad norm: 0.1180
Grad norm: 0.1128
Grad norm: 0.1018
Grad norm: 0.1301
Grad norm: 0.1015
Grad norm: 0.1599
Grad norm: 0.1120
Grad norm: 0.1090
Grad norm: 0.1192


 25%|██▌       | 5/20 [13:40<40:59, 163.96s/it]

Epoch 5/20, Train Loss: 2.6305, Valid Loss: 2.6968, Bleu Score: 0.1920
Learning rate: 1.000348822367915e-06
Grad norm: 0.1348
Grad norm: 0.1115
Grad norm: 0.1456
Grad norm: 0.1065
Grad norm: 0.1218
Grad norm: 0.0963
Grad norm: 0.1059
Grad norm: 0.1047
Grad norm: 0.1259


 30%|███       | 6/20 [16:23<38:12, 163.77s/it]

Epoch 6/20, Train Loss: 2.6947, Valid Loss: 2.7167, Bleu Score: 0.1919
Learning rate: 0.009046277920252426
Grad norm: 0.1013
Grad norm: 0.1052
Grad norm: 0.0984
Grad norm: 0.1284
Grad norm: 0.1239
Grad norm: 0.1037
Grad norm: 0.1129
Grad norm: 0.0971
Grad norm: 0.0981


 35%|███▌      | 7/20 [19:07<35:29, 163.84s/it]

Epoch 7/20, Train Loss: 2.6282, Valid Loss: 2.6925, Bleu Score: 0.1940
Learning rate: 0.006547206534724657
Grad norm: 0.1097
Grad norm: 0.1927
Grad norm: 0.0990
Grad norm: 0.1112
Grad norm: 0.1053
Grad norm: 0.0988
Grad norm: 0.0980
Grad norm: 0.1735
Grad norm: 0.1080


 40%|████      | 8/20 [21:52<32:47, 163.92s/it]

Epoch 8/20, Train Loss: 2.5668, Valid Loss: 2.6768, Bleu Score: 0.1933
Learning rate: 0.003457345823553637
Grad norm: 0.0988
Grad norm: 0.0975
Grad norm: 0.1109
Grad norm: 0.1252
Grad norm: 0.0939
Grad norm: 0.1113
Grad norm: 0.1033
Grad norm: 0.1084
Grad norm: 0.1155


 45%|████▌     | 9/20 [24:35<30:02, 163.89s/it]

Epoch 9/20, Train Loss: 2.5142, Valid Loss: 2.6685, Bleu Score: 0.1956
Learning rate: 0.0009569175579037761
Grad norm: 0.1003
Grad norm: 0.1032
Grad norm: 0.1111
Grad norm: 0.1091
Grad norm: 0.1038
Grad norm: 0.0999
Grad norm: 0.0981
Grad norm: 0.1386
Grad norm: 0.1061


 50%|█████     | 10/20 [27:19<27:19, 163.91s/it]

Epoch 10/20, Train Loss: 2.4813, Valid Loss: 2.6671, Bleu Score: 0.1964
Learning rate: 1.000348822367915e-06
Grad norm: 0.1062
Grad norm: 0.1814
Grad norm: 0.1165
Grad norm: 0.1216
Grad norm: 0.1168
Grad norm: 0.1080
Grad norm: 0.1105
Grad norm: 0.1148
Grad norm: 0.1117


 55%|█████▌    | 11/20 [30:02<24:32, 163.63s/it]

Epoch 11/20, Train Loss: 2.5758, Valid Loss: 2.7101, Bleu Score: 0.1909
Learning rate: 0.009046277920252428
Grad norm: 0.1182
Grad norm: 0.0981
Grad norm: 0.1154
Grad norm: 0.1130
Grad norm: 0.1015
Grad norm: 0.1029
Grad norm: 0.1067
Grad norm: 0.1215
Grad norm: 0.0998


 60%|██████    | 12/20 [32:46<21:49, 163.70s/it]

Epoch 12/20, Train Loss: 2.5376, Valid Loss: 2.7045, Bleu Score: 0.1952
Learning rate: 0.0065472065347246585
Grad norm: 0.1099
Grad norm: 0.1042
Grad norm: 0.0978
Grad norm: 0.1137
Grad norm: 0.1101
Grad norm: 0.1080
Grad norm: 0.1296
Grad norm: 0.1017
Grad norm: 0.1078


 65%|██████▌   | 13/20 [35:30<19:05, 163.71s/it]

Epoch 13/20, Train Loss: 2.4886, Valid Loss: 2.6987, Bleu Score: 0.1975
Learning rate: 0.003457345823553638
Grad norm: 0.0992
Grad norm: 0.1141
Grad norm: 0.0977
Grad norm: 0.1005
Grad norm: 0.1035
Grad norm: 0.1093
Grad norm: 0.1129
Grad norm: 0.0990
Grad norm: 0.0950


 65%|██████▌   | 13/20 [38:13<20:35, 176.44s/it]

Early stopping triggered.



  model.load_state_dict(torch.load('./best_state.pt'))


In [None]:
import gc
torch.cuda.empty_cache()
gc.collect()

10

In [None]:
# Create inference model
inference_model = LSTMWithGPT2(tokenizer=tokenizer, hidden_size=config.lstm_hidden_dim, num_layers=config.num_layers, dropout=config.dropout).cuda()
inference_model.load_state_dict(torch.load('./best_state.pt'))
inference_model.to(device="cuda")

  inference_model.load_state_dict(torch.load('./best_state.pt'))


LSTMWithGPT2(
  (embedding): Embedding(50259, 768)
  (lstm): LSTM(768, 128, num_layers=2, batch_first=True, dropout=0.3)
  (fc): Linear(in_features=128, out_features=50259, bias=True)
)

### 5. Text generation function

In [None]:
def restore_text(decoded_text):
    """Restore the text after decoding."""
    text = decoded_text.replace(' [LINEBREAK] ', '\n').replace('[LINEBREAK]', '\n')
    text = text.replace(' [STANZABREAK] ', '\n\n').replace('[STANZABREAK]', '\n\n')
    return text.strip()

def generate_text(model, tokenizer, seed_text, max_length=50, top_k=10, temperature=1.0, debug=False, check_word=""):
    """
    Generate text given a seed_text using the trained model and tokenizer.

    Parameters:
    - model: The trained LSTM model to generate text from.
    - tokenizer: The GPT-2 tokenizer.
    - seed_text: The initial text to start generating from.
    - max_length: The maximum length of the generated sequence.
    - top_k: Only consider the top k most likely next tokens.
    - temperature: Controls randomness in sampling. Higher values make it more random.
    - debug: Print debug information.
    """
    model.eval()

    # Tokenize with the GPT-2 tokenizer
    tokens_encoded = tokenizer.encode(seed_text, return_tensors='pt').cuda()  # [1, seq_len]
    generated = tokens_encoded.squeeze().tolist()  # Convert to list for easier handling

    if debug:
        print("Vocab size:", tokenizer.vocab_size)
        print(f"'{check_word}' in vocab?", check_word in tokenizer.get_vocab())
        print("Encoded:", tokens_encoded)
        dec = tokenizer.decode(tokens_encoded[0])
        print("Decoded:", dec); print()

    hidden = None  # For LSTM we need to reset hidden state for each generation cycle

    # Generate tokens one at a time
    for _ in range(max_length):
        output, hidden = model(tokens_encoded)  # [B, seq_len, vocab_size]
        logits = output[0, -1] / temperature

        # Apply top-k sampling
        # (random sampling from the top k tokens with highest probabilities)
        top_k_logits, top_k_indices = torch.topk(logits, k=top_k)

        # Sample from the top-k tokens using softmax
        probs = torch.softmax(top_k_logits, dim=-1)
        next_token_idx = top_k_indices[torch.multinomial(probs, 1)]

        # Add the predicted token to the sequence
        generated.append(next_token_idx.item())

        # If the special token for stanza break is found, stop early
        if next_token_idx.item() == tokenizer.encode("[STANZABREAK]")[0]:
            break

        # Update input sequence by appending the predicted token for next step
        tokens_encoded = torch.cat([tokens_encoded, next_token_idx.unsqueeze(0)], dim=1)

    # Decode generated token sequence back into words
    generated_tokens = tokenizer.decode(generated)

    # Restore the text with special tokens
    generated_output = restore_text(generated_tokens)

    return generated_output

In [None]:
# Generate text
ideas = ["as all merchant says",
         "come with me my dear",
         "let the feast begin",
         "bring your dancing shoes and",
         "don't mind the bad roumors, it"
         ]

max_l = 60 # length of the generated text
top_k = 15 # range of tokens to choose from
temperature = [0.7, 1.0, 1.3, 2.0] # randomness factor
debug = False
check_word = ""

# Testing generation
for i, t in enumerate(temperature):
    print("*"*50, f"\nTest generation with temperature ({t}):\n","*"*50)

    with torch.no_grad():
        for text in ideas:
            output = generate_text(
                inference_model,
                tokenizer,
                seed_text=text,
                max_length=max_l,
                top_k=top_k,
                temperature=t,
                debug=debug,
                check_word=check_word
            )
            print(f"{text}:\n{output} ...\n")

************************************************** 
Test generation with temperature (0.7):
 **************************************************
as all merchant says:
as all merchant says 
 
 to make it so. 
 
 and, by his love, but not for the queen's life, 
 
 and now, my good lord, and I am sent to do; 
 
 with him, and I do not be not, for ...

come with me my dear:
come with me my dear. 
 
 and in the sea, and my love's son, 
 
 to see a little man, and so, and I, 
 
 LUCENTIO. I am a king, I pray you, and a good lady, 
 
 and ...

let the feast begin:
let the feast begin 
 
 and the king's name, and I'll have it; therefore, sir, 
 
 I would not be my friend. ...

bring your dancing shoes and:
bring your dancing shoes and 
 
 to be the best of the world, that I have seen 
 
 AENEAS. what's she? 
 
 and in the king of the court. 
 
 that he will not be to be so-favour'd ...

don't mind the bad roumors, it:
don't mind the bad roumors, it is 
 
 and yet so you have heard, and in our e

In [None]:
## Poem generated by ChatGPT:

# 1.
# "as all merchant says, in whispers low,

# The tides of fortune shift like winds that blow,

# Yet trust in love, for it shall ever grow."

# 2.
# "come with me, my dear, beneath the moon's soft light,

# Where dreams entwine and hearts take flight,

# In this enchanted realm, all wrongs feel right."

# 3.
# "let the feast begin, with laughter's sweet embrace,

# As joy and mirth adorn this hallowed space,

# With every toast, we celebrate our grace."

# 4.
# "bring your dancing shoes and let your spirit soar,

# For in this revelry, we shall seek no more,

# Each step a story, each twirl a lore."

# 5.
# "don't mind the bad rumors, it’s but a fleeting breeze,

# For truth shall shine like stars through darkened trees,

# In love's embrace, we find our hearts at ease."


### Summary of the generation outputs:

- top-k:
    - 15-20 gives probably the most relevant results
    - tried 50, it became too chaotic
- temperature:
    - 0.7: sometimes repetitive, but still performs well
    - 1.0: probably the best scenario
    - 1.3-2.0: creative, but sometimes inappropriate, random words
- conclusions:
    - The generated outputs are still very random in most cases (little coherence between words), most important changes could be done in the data preparation phase.
    - The sentences end meaninglessly, even though the model knows that a linebreak will occur after 6-8 words, it does not finish it.
    - The model learned the whole corpus with many different dialogues. To make it poetic we should only train the model on poems only.
    - It could be also advantegous to train it with more lines per sample (4 or 5), although it slows down the process.