In [1]:
import torch
import tiktoken
import os

from gpt_model import GPTModel
from data_loader_v1 import create_dataloader_v1
from generate_text import generate

### Detect if GPU is available

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using {device} device.")

Using mps device.


### Set up model configuration 

In [3]:
GPT_CONFIG_124M = {
    "vocab_size": 50257,    # Vocabulary size
#   "vocab_size": 14000,    # Vocabulary size (custom tokenizer)
    "context_length": 256,  # Context length
    "emb_dim": 768,         # Embedding dimension
    "n_heads": 12,          # Number of attention heads
    "n_layers": 12,         # Number of layers
    "drop_rate": 0.2,       # Dropout rate
    "qkv_bias": True,      # Query-Key-Value bias
    "device": device,
}

### Initialize the tokenizer

#### GPT-2 tokenizer

In [4]:
tokenizer = tiktoken.get_encoding("gpt2")

#### Custom tokenizer

In [4]:
import sentencepiece as spm

In [6]:
spm.SentencePieceTrainer.train(
    input='all_books.txt',
    model_prefix='gpt_custom_tokenizer',
    vocab_size=GPT_CONFIG_124M["vocab_size"],
    model_type='bpe',
    character_coverage=0.9995,
    hard_vocab_limit=False,
    bos_id=-1,
    eos_id=-1,
    user_defined_symbols=["<|endoftext|>"]
);

In [7]:
tokenizer = spm.SentencePieceProcessor()
tokenizer.load('gpt_custom_tokenizer.model')

True

In [5]:
tokenizer_used_in_this_trial="GPT2"
# tokenizer_used_in_this_trial="CUSTOM"

def encode(full_text):
    if tokenizer_used_in_this_trial == "GPT2":
        return tokenizer.encode(full_text, allowed_special={'<|endoftext|>'})
    else:
        return tokenizer.encode(full_text, out_type=int)

### Load training and validation data files

In [6]:
train_file_path = 'train_text_data.txt'
val_file_path = 'val_text_data.txt'

with open(train_file_path, "r", encoding="utf-8") as file:
    train_data = file.read()
with open(val_file_path, "r", encoding="utf-8") as file:
    val_data = file.read()

### Initialize data loaders for training
Data loaders implementation can be found in `./data_loader_v1.py`.

This implementation follows the omplementation detailed in _Raschka, Sebastian. Build a Large Language Model (From Scratch). Manning Publications, 2024_

In [7]:
train_ratio = 0.90

train_loader = create_dataloader_v1(
    train_data,
    encode=encode,
    batch_size=4,
    max_length=GPT_CONFIG_124M["context_length"],
    stride=GPT_CONFIG_124M["context_length"],
    drop_last=True,
    shuffle=True,
    num_workers=0
)

val_loader = create_dataloader_v1(
    val_data,
    encode=encode,
    batch_size=4,
    max_length=GPT_CONFIG_124M["context_length"],
    stride=GPT_CONFIG_124M["context_length"],
    drop_last=False,
    shuffle=False,
    num_workers=0
)

In [8]:
full_text = train_data + val_data

word_count = len(full_text.split())
char_count = len(full_text)

# tiktoken tokenizer ->
tokens = tokenizer.encode(full_text, allowed_special={'<|endoftext|>'})

# Custom tokenizer ->
# tokens = tokenizer.encode(full_text, out_type=int)

token_count = len(tokens)
unique_token_count = len(set(tokens))

print("Words:", word_count)
print("Characters:", char_count)
print("Tokens:", token_count)
print("Unique Tokens Used:", unique_token_count)

Words: 5789730
Characters: 32372094
Tokens: 7608098
Unique Tokens Used: 28960


In [9]:
import gc

def clean(): 
    """
    This is a function for GPU data claening before and after training
    """
    
    os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
    
    gc.collect()  # Force garbage collection
    torch.mps.empty_cache()  # Attempt to release MPS memory
    
    # Move tensors to CPU
    for tensor in list(globals().values()):
        if isinstance(tensor, torch.Tensor) and tensor.device == torch.device("mps"):
            tensor.to("cpu")

    # Delete all tensors
    del tensor
    torch.mps.empty_cache()
    gc.collect()  # Force garbage collection
    print("MPS Available:", torch.backends.mps.is_available())
    print("Allocated Memory:", torch.mps.current_allocated_memory() / (1024**2), "MB")

# Training

In [10]:
from pre_train import train_model_simple
import time

train_losses, val_losses, track_tokens_seen = [], [], []

def train(train_loader, val_loader,
          num_epochs=10, eval_iter=5, lr=0.0002,
          generate_sample_text=False,
          sample_text="It is a truth universally acknowledged, that a single man in possession of a good fortune, must be",
          model_prefix="model_and_optimizer"):

    global train_losses, val_losses, track_tokens_seen  # Ensure these are updated globally

    if device == "mps":
        clean()
        print(50 * "=")
        print("Starting training...")
    if device == "cuda":
        torch.cuda.empty_cache()
        torch.cuda.memory_summary()
        print(50 * "=")
        print("Starting training...")

    start_time = time.time()

    torch.manual_seed(123)
    model = GPTModel(GPT_CONFIG_124M)
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.98), eps=1e-08, weight_decay=0.05)

    # Pass train_losses and val_losses as references
    train_model_simple(
        model, train_loader, val_loader, optimizer,
        num_epochs=num_epochs, eval_iter=eval_iter,
        start_context=sample_text, cfg=GPT_CONFIG_124M,
        generate_sample_text=generate_sample_text,
        model_prefix=model_prefix,
        train_losses=train_losses, val_losses=val_losses,
        track_tokens_seen=track_tokens_seen,
        tokenizer=tokenizer
    )
    
    end_time = time.time()
    execution_time_minutes = (end_time - start_time) / 60
    print(f"Training completed in {execution_time_minutes:.2f} minutes.")
    
    if device == "mps":
        print(50 * "=")
        clean()
    if device == "cuda":
        print(50 * "=")
        torch.cuda.empty_cache()
        torch.cuda.memory_summary()
    
    return model

In [11]:
gc.collect()  # Force garbage collection

0

### Train the model on training data

In [25]:
# train model on all works

train(train_loader, val_loader, num_epochs=6,
      eval_iter=10, model_prefix="model_768_12_12");

Ep 1 (Step 000000): Train loss 10.224, Val loss 10.277
Ep 1 (Step 000010): Train loss 8.387, Val loss 8.419
Ep 1 (Step 000020): Train loss 7.312, Val loss 7.419
Ep 1 (Step 000030): Train loss 7.008, Val loss 7.153
Ep 1 (Step 000040): Train loss 6.892, Val loss 7.060
Ep 1 (Step 000050): Train loss 6.792, Val loss 6.963
Ep 1 (Step 000060): Train loss 6.714, Val loss 6.843
Ep 1 (Step 000070): Train loss 6.625, Val loss 6.806
Ep 1 (Step 000080): Train loss 6.568, Val loss 6.665
Ep 1 (Step 000090): Train loss 6.534, Val loss 6.587
Ep 1 (Step 000100): Train loss 6.251, Val loss 6.490
Ep 1 (Step 000110): Train loss 6.366, Val loss 6.421
Ep 1 (Step 000120): Train loss 6.359, Val loss 6.381
Ep 1 (Step 000130): Train loss 6.303, Val loss 6.307
Ep 1 (Step 000140): Train loss 6.151, Val loss 6.261
Ep 1 (Step 000150): Train loss 6.147, Val loss 6.217
Ep 1 (Step 000160): Train loss 6.123, Val loss 6.189
Ep 1 (Step 000170): Train loss 6.036, Val loss 6.147
Ep 1 (Step 000180): Train loss 6.163, Val lo

Ep 1 (Step 001550): Train loss 5.263, Val loss 5.267
Ep 1 (Step 001560): Train loss 5.208, Val loss 5.271
Ep 1 (Step 001570): Train loss 5.102, Val loss 5.253
Ep 1 (Step 001580): Train loss 5.162, Val loss 5.266
Ep 1 (Step 001590): Train loss 5.241, Val loss 5.256
Ep 1 (Step 001600): Train loss 5.168, Val loss 5.257
Ep 1 (Step 001610): Train loss 5.086, Val loss 5.260
Ep 1 (Step 001620): Train loss 5.098, Val loss 5.264
Ep 1 (Step 001630): Train loss 5.291, Val loss 5.280
Ep 1 (Step 001640): Train loss 5.251, Val loss 5.283
Ep 1 (Step 001650): Train loss 5.114, Val loss 5.265
Ep 1 (Step 001660): Train loss 5.033, Val loss 5.242
Ep 1 (Step 001670): Train loss 5.217, Val loss 5.251
Ep 1 (Step 001680): Train loss 5.138, Val loss 5.242
Ep 1 (Step 001690): Train loss 4.987, Val loss 5.246
Ep 1 (Step 001700): Train loss 5.201, Val loss 5.246
Ep 1 (Step 001710): Train loss 5.170, Val loss 5.249
Ep 1 (Step 001720): Train loss 5.082, Val loss 5.243
Ep 1 (Step 001730): Train loss 5.198, Val loss

Ep 1 (Step 003100): Train loss 4.917, Val loss 4.813
Ep 1 (Step 003110): Train loss 4.776, Val loss 4.810
Ep 1 (Step 003120): Train loss 4.824, Val loss 4.813
Ep 1 (Step 003130): Train loss 4.892, Val loss 4.813
Ep 1 (Step 003140): Train loss 4.905, Val loss 4.799
Ep 1 (Step 003150): Train loss 4.881, Val loss 4.839
Ep 1 (Step 003160): Train loss 5.012, Val loss 4.846
Ep 1 (Step 003170): Train loss 4.854, Val loss 4.840
Ep 1 (Step 003180): Train loss 4.638, Val loss 4.826
Ep 1 (Step 003190): Train loss 4.838, Val loss 4.823
Ep 1 (Step 003200): Train loss 4.878, Val loss 4.836
Ep 1 (Step 003210): Train loss 4.811, Val loss 4.823
Ep 1 (Step 003220): Train loss 4.907, Val loss 4.823
Ep 1 (Step 003230): Train loss 4.767, Val loss 4.856
Ep 1 (Step 003240): Train loss 4.755, Val loss 4.835
Ep 1 (Step 003250): Train loss 4.886, Val loss 4.820
Ep 1 (Step 003260): Train loss 4.828, Val loss 4.820
Ep 1 (Step 003270): Train loss 4.688, Val loss 4.829
Ep 1 (Step 003280): Train loss 4.854, Val loss

Ep 1 (Step 004650): Train loss 4.591, Val loss 4.567
Ep 1 (Step 004660): Train loss 4.766, Val loss 4.574
Ep 1 (Step 004670): Train loss 4.684, Val loss 4.567
Ep 1 (Step 004680): Train loss 4.544, Val loss 4.557
Ep 1 (Step 004690): Train loss 4.661, Val loss 4.560
Ep 1 (Step 004700): Train loss 4.630, Val loss 4.587
Ep 1 (Step 004710): Train loss 4.802, Val loss 4.582
Ep 1 (Step 004720): Train loss 4.602, Val loss 4.576
Ep 1 (Step 004730): Train loss 4.672, Val loss 4.582
Ep 1 (Step 004740): Train loss 4.649, Val loss 4.605
Ep 1 (Step 004750): Train loss 4.534, Val loss 4.601
Ep 1 (Step 004760): Train loss 4.610, Val loss 4.595
Ep 1 (Step 004770): Train loss 4.772, Val loss 4.599
Ep 1 (Step 004780): Train loss 4.569, Val loss 4.611
Ep 1 (Step 004790): Train loss 4.531, Val loss 4.595
Ep 1 (Step 004800): Train loss 4.728, Val loss 4.611
Ep 1 (Step 004810): Train loss 4.762, Val loss 4.623
Ep 1 (Step 004820): Train loss 4.710, Val loss 4.618
Ep 1 (Step 004830): Train loss 4.759, Val loss

Ep 1 (Step 006200): Train loss 4.448, Val loss 4.508
Ep 1 (Step 006210): Train loss 4.541, Val loss 4.484
Ep 1 (Step 006220): Train loss 4.479, Val loss 4.480
Ep 1 (Step 006230): Train loss 4.461, Val loss 4.473
Ep 1 (Step 006240): Train loss 4.475, Val loss 4.447
Ep 1 (Step 006250): Train loss 4.463, Val loss 4.452
Ep 1 (Step 006260): Train loss 4.368, Val loss 4.447
Ep 1 (Step 006270): Train loss 4.656, Val loss 4.444
Ep 1 (Step 006280): Train loss 4.537, Val loss 4.444
Ep 1 (Step 006290): Train loss 4.602, Val loss 4.463
Ep 1 (Step 006300): Train loss 4.689, Val loss 4.469
Ep 1 (Step 006310): Train loss 4.491, Val loss 4.455
Ep 1 (Step 006320): Train loss 4.529, Val loss 4.431
Ep 1 (Step 006330): Train loss 4.551, Val loss 4.430
Ep 1 (Step 006340): Train loss 4.557, Val loss 4.443
Ep 1 (Step 006350): Train loss 4.386, Val loss 4.447
Ep 1 (Step 006360): Train loss 4.550, Val loss 4.441
Ep 1 (Step 006370): Train loss 4.451, Val loss 4.446
Ep 1 (Step 006380): Train loss 4.470, Val loss

Ep 2 (Step 007750): Train loss 4.314, Val loss 4.385
Ep 2 (Step 007760): Train loss 4.390, Val loss 4.408
Ep 2 (Step 007770): Train loss 4.386, Val loss 4.398
Ep 2 (Step 007780): Train loss 4.331, Val loss 4.403
Ep 2 (Step 007790): Train loss 4.429, Val loss 4.403
Ep 2 (Step 007800): Train loss 4.414, Val loss 4.391
Ep 2 (Step 007810): Train loss 4.296, Val loss 4.410
Ep 2 (Step 007820): Train loss 4.413, Val loss 4.409
Ep 2 (Step 007830): Train loss 4.470, Val loss 4.412
Ep 2 (Step 007840): Train loss 4.344, Val loss 4.417
Ep 2 (Step 007850): Train loss 4.502, Val loss 4.397
Ep 2 (Step 007860): Train loss 4.366, Val loss 4.402
Ep 2 (Step 007870): Train loss 4.312, Val loss 4.414
Ep 2 (Step 007880): Train loss 4.311, Val loss 4.403
Ep 2 (Step 007890): Train loss 4.519, Val loss 4.422
Ep 2 (Step 007900): Train loss 4.441, Val loss 4.419
Ep 2 (Step 007910): Train loss 4.364, Val loss 4.404
Ep 2 (Step 007920): Train loss 4.441, Val loss 4.413
Ep 2 (Step 007930): Train loss 4.354, Val loss

Ep 2 (Step 009300): Train loss 4.248, Val loss 4.307
Ep 2 (Step 009310): Train loss 4.544, Val loss 4.311
Ep 2 (Step 009320): Train loss 4.301, Val loss 4.316
Ep 2 (Step 009330): Train loss 4.232, Val loss 4.313
Ep 2 (Step 009340): Train loss 4.464, Val loss 4.304
Ep 2 (Step 009350): Train loss 4.387, Val loss 4.290
Ep 2 (Step 009360): Train loss 4.212, Val loss 4.277
Ep 2 (Step 009370): Train loss 4.461, Val loss 4.274
Ep 2 (Step 009380): Train loss 4.341, Val loss 4.259
Ep 2 (Step 009390): Train loss 4.426, Val loss 4.267
Ep 2 (Step 009400): Train loss 4.291, Val loss 4.258
Ep 2 (Step 009410): Train loss 4.321, Val loss 4.257
Ep 2 (Step 009420): Train loss 4.614, Val loss 4.245
Ep 2 (Step 009430): Train loss 4.248, Val loss 4.252
Ep 2 (Step 009440): Train loss 4.277, Val loss 4.269
Ep 2 (Step 009450): Train loss 4.285, Val loss 4.255
Ep 2 (Step 009460): Train loss 4.357, Val loss 4.264
Ep 2 (Step 009470): Train loss 4.498, Val loss 4.278
Ep 2 (Step 009480): Train loss 4.568, Val loss

Ep 2 (Step 010850): Train loss 4.194, Val loss 4.236
Ep 2 (Step 010860): Train loss 4.141, Val loss 4.245
Ep 2 (Step 010870): Train loss 4.301, Val loss 4.233
Ep 2 (Step 010880): Train loss 4.202, Val loss 4.243
Ep 2 (Step 010890): Train loss 4.225, Val loss 4.244
Ep 2 (Step 010900): Train loss 4.234, Val loss 4.262
Ep 2 (Step 010910): Train loss 4.370, Val loss 4.257
Ep 2 (Step 010920): Train loss 4.086, Val loss 4.229
Ep 2 (Step 010930): Train loss 4.140, Val loss 4.257
Ep 2 (Step 010940): Train loss 4.096, Val loss 4.255
Ep 2 (Step 010950): Train loss 4.226, Val loss 4.250
Ep 2 (Step 010960): Train loss 4.275, Val loss 4.269
Ep 2 (Step 010970): Train loss 4.159, Val loss 4.263
Ep 2 (Step 010980): Train loss 4.209, Val loss 4.266
Ep 2 (Step 010990): Train loss 4.258, Val loss 4.270
Ep 2 (Step 011000): Train loss 4.236, Val loss 4.266
Ep 2 (Step 011010): Train loss 4.213, Val loss 4.272
Ep 2 (Step 011020): Train loss 4.320, Val loss 4.266
Ep 2 (Step 011030): Train loss 4.167, Val loss

Ep 2 (Step 012400): Train loss 4.150, Val loss 4.201
Ep 2 (Step 012410): Train loss 4.200, Val loss 4.194
Ep 2 (Step 012420): Train loss 4.133, Val loss 4.198
Ep 2 (Step 012430): Train loss 4.126, Val loss 4.194
Ep 2 (Step 012440): Train loss 4.151, Val loss 4.194
Ep 2 (Step 012450): Train loss 4.149, Val loss 4.194
Ep 2 (Step 012460): Train loss 4.269, Val loss 4.204
Ep 2 (Step 012470): Train loss 4.309, Val loss 4.212
Ep 2 (Step 012480): Train loss 4.176, Val loss 4.195
Ep 2 (Step 012490): Train loss 4.060, Val loss 4.191
Ep 2 (Step 012500): Train loss 4.206, Val loss 4.204
Ep 2 (Step 012510): Train loss 4.279, Val loss 4.193
Ep 2 (Step 012520): Train loss 4.206, Val loss 4.202
Ep 2 (Step 012530): Train loss 4.193, Val loss 4.201
Ep 2 (Step 012540): Train loss 4.253, Val loss 4.207
Ep 2 (Step 012550): Train loss 4.069, Val loss 4.196
Ep 2 (Step 012560): Train loss 4.224, Val loss 4.183
Ep 2 (Step 012570): Train loss 4.114, Val loss 4.207
Ep 2 (Step 012580): Train loss 4.127, Val loss

Ep 3 (Step 013950): Train loss 4.148, Val loss 4.148
Ep 3 (Step 013960): Train loss 4.185, Val loss 4.141
Ep 3 (Step 013970): Train loss 3.996, Val loss 4.159
Ep 3 (Step 013980): Train loss 4.210, Val loss 4.139
Ep 3 (Step 013990): Train loss 3.994, Val loss 4.148
Ep 3 (Step 014000): Train loss 4.187, Val loss 4.145
Ep 3 (Step 014010): Train loss 4.171, Val loss 4.130
Ep 3 (Step 014020): Train loss 4.134, Val loss 4.132
Ep 3 (Step 014030): Train loss 4.222, Val loss 4.147
Ep 3 (Step 014040): Train loss 4.137, Val loss 4.140
Ep 3 (Step 014050): Train loss 4.181, Val loss 4.125
Ep 3 (Step 014060): Train loss 4.233, Val loss 4.149
Ep 3 (Step 014070): Train loss 4.129, Val loss 4.145
Ep 3 (Step 014080): Train loss 3.997, Val loss 4.137
Ep 3 (Step 014090): Train loss 4.162, Val loss 4.130
Ep 3 (Step 014100): Train loss 4.212, Val loss 4.141
Ep 3 (Step 014110): Train loss 4.230, Val loss 4.163
Ep 3 (Step 014120): Train loss 4.163, Val loss 4.141
Ep 3 (Step 014130): Train loss 4.094, Val loss

Ep 3 (Step 015500): Train loss 4.058, Val loss 4.125
Ep 3 (Step 015510): Train loss 4.097, Val loss 4.116
Ep 3 (Step 015520): Train loss 3.971, Val loss 4.132
Ep 3 (Step 015530): Train loss 4.181, Val loss 4.115
Ep 3 (Step 015540): Train loss 4.102, Val loss 4.127
Ep 3 (Step 015550): Train loss 4.125, Val loss 4.118
Ep 3 (Step 015560): Train loss 3.992, Val loss 4.122
Ep 3 (Step 015570): Train loss 4.027, Val loss 4.111
Ep 3 (Step 015580): Train loss 4.036, Val loss 4.120
Ep 3 (Step 015590): Train loss 4.037, Val loss 4.109
Ep 3 (Step 015600): Train loss 4.043, Val loss 4.121
Ep 3 (Step 015610): Train loss 4.133, Val loss 4.123
Ep 3 (Step 015620): Train loss 4.041, Val loss 4.112
Ep 3 (Step 015630): Train loss 4.026, Val loss 4.115
Ep 3 (Step 015640): Train loss 4.033, Val loss 4.120
Ep 3 (Step 015650): Train loss 3.978, Val loss 4.123
Ep 3 (Step 015660): Train loss 3.992, Val loss 4.121
Ep 3 (Step 015670): Train loss 4.088, Val loss 4.129
Ep 3 (Step 015680): Train loss 3.980, Val loss

Ep 3 (Step 017050): Train loss 4.123, Val loss 4.032
Ep 3 (Step 017060): Train loss 4.193, Val loss 4.040
Ep 3 (Step 017070): Train loss 4.006, Val loss 4.040
Ep 3 (Step 017080): Train loss 4.055, Val loss 4.029
Ep 3 (Step 017090): Train loss 3.968, Val loss 4.026
Ep 3 (Step 017100): Train loss 4.151, Val loss 4.028
Ep 3 (Step 017110): Train loss 4.032, Val loss 4.041
Ep 3 (Step 017120): Train loss 3.987, Val loss 4.031
Ep 3 (Step 017130): Train loss 4.055, Val loss 4.038
Ep 3 (Step 017140): Train loss 4.062, Val loss 4.040
Ep 3 (Step 017150): Train loss 4.035, Val loss 4.033
Ep 3 (Step 017160): Train loss 4.128, Val loss 4.041
Ep 3 (Step 017170): Train loss 3.909, Val loss 4.030
Ep 3 (Step 017180): Train loss 3.952, Val loss 4.036
Ep 3 (Step 017190): Train loss 3.905, Val loss 4.032
Ep 3 (Step 017200): Train loss 4.046, Val loss 4.026
Ep 3 (Step 017210): Train loss 4.146, Val loss 4.036
Ep 3 (Step 017220): Train loss 4.047, Val loss 4.032
Ep 3 (Step 017230): Train loss 3.968, Val loss

Ep 3 (Step 018600): Train loss 4.048, Val loss 4.025
Ep 3 (Step 018610): Train loss 3.923, Val loss 4.018
Ep 3 (Step 018620): Train loss 4.091, Val loss 4.014
Ep 3 (Step 018630): Train loss 3.997, Val loss 4.022
Ep 3 (Step 018640): Train loss 3.978, Val loss 4.016
Ep 3 (Step 018650): Train loss 3.858, Val loss 4.016
Ep 3 (Step 018660): Train loss 3.941, Val loss 4.015
Ep 3 (Step 018670): Train loss 4.049, Val loss 4.017
Ep 3 (Step 018680): Train loss 4.114, Val loss 4.021
Ep 3 (Step 018690): Train loss 3.967, Val loss 4.006
Ep 3 (Step 018700): Train loss 3.983, Val loss 4.022
Ep 3 (Step 018710): Train loss 3.907, Val loss 4.015
Ep 3 (Step 018720): Train loss 3.792, Val loss 4.016
Ep 3 (Step 018730): Train loss 3.923, Val loss 4.004
Ep 3 (Step 018740): Train loss 3.948, Val loss 4.008
Ep 3 (Step 018750): Train loss 3.860, Val loss 4.012
Ep 3 (Step 018760): Train loss 4.109, Val loss 4.015
Ep 3 (Step 018770): Train loss 4.067, Val loss 4.010
Ep 3 (Step 018780): Train loss 4.098, Val loss

Ep 4 (Step 020150): Train loss 3.986, Val loss 3.978
Ep 4 (Step 020160): Train loss 3.936, Val loss 3.981
Ep 4 (Step 020170): Train loss 3.932, Val loss 3.986
Ep 4 (Step 020180): Train loss 3.894, Val loss 3.985
Ep 4 (Step 020190): Train loss 3.839, Val loss 3.972
Ep 4 (Step 020200): Train loss 3.908, Val loss 3.976
Ep 4 (Step 020210): Train loss 4.000, Val loss 3.973
Ep 4 (Step 020220): Train loss 3.860, Val loss 3.977
Ep 4 (Step 020230): Train loss 3.987, Val loss 3.971
Ep 4 (Step 020240): Train loss 3.990, Val loss 3.981
Ep 4 (Step 020250): Train loss 4.007, Val loss 3.978
Ep 4 (Step 020260): Train loss 3.837, Val loss 3.984
Ep 4 (Step 020270): Train loss 3.878, Val loss 3.977
Ep 4 (Step 020280): Train loss 3.979, Val loss 3.977
Ep 4 (Step 020290): Train loss 3.864, Val loss 3.970
Ep 4 (Step 020300): Train loss 3.855, Val loss 3.969
Ep 4 (Step 020310): Train loss 3.811, Val loss 3.955
Ep 4 (Step 020320): Train loss 3.903, Val loss 3.954
Ep 4 (Step 020330): Train loss 3.989, Val loss

Ep 4 (Step 021700): Train loss 3.848, Val loss 3.957
Ep 4 (Step 021710): Train loss 3.891, Val loss 3.957
Ep 4 (Step 021720): Train loss 3.813, Val loss 3.962
Ep 4 (Step 021730): Train loss 3.809, Val loss 3.967
Ep 4 (Step 021740): Train loss 3.931, Val loss 3.967
Ep 4 (Step 021750): Train loss 3.872, Val loss 3.961
Ep 4 (Step 021760): Train loss 3.929, Val loss 3.966
Ep 4 (Step 021770): Train loss 3.898, Val loss 3.966
Ep 4 (Step 021780): Train loss 3.908, Val loss 3.959
Ep 4 (Step 021790): Train loss 3.823, Val loss 3.955
Ep 4 (Step 021800): Train loss 3.847, Val loss 3.964
Ep 4 (Step 021810): Train loss 3.909, Val loss 3.965
Ep 4 (Step 021820): Train loss 3.886, Val loss 3.961
Ep 4 (Step 021830): Train loss 3.848, Val loss 3.969
Ep 4 (Step 021840): Train loss 3.799, Val loss 3.976
Ep 4 (Step 021850): Train loss 3.883, Val loss 3.976
Ep 4 (Step 021860): Train loss 3.902, Val loss 3.966
Ep 4 (Step 021870): Train loss 3.880, Val loss 3.966
Ep 4 (Step 021880): Train loss 3.789, Val loss

Ep 4 (Step 023250): Train loss 3.929, Val loss 3.921
Ep 4 (Step 023260): Train loss 3.903, Val loss 3.921
Ep 4 (Step 023270): Train loss 3.682, Val loss 3.921
Ep 4 (Step 023280): Train loss 3.903, Val loss 3.932
Ep 4 (Step 023290): Train loss 3.768, Val loss 3.931
Ep 4 (Step 023300): Train loss 3.971, Val loss 3.935
Ep 4 (Step 023310): Train loss 3.823, Val loss 3.935
Ep 4 (Step 023320): Train loss 3.762, Val loss 3.932
Ep 4 (Step 023330): Train loss 3.794, Val loss 3.931
Ep 4 (Step 023340): Train loss 3.735, Val loss 3.928
Ep 4 (Step 023350): Train loss 3.749, Val loss 3.939
Ep 4 (Step 023360): Train loss 3.778, Val loss 3.939
Ep 4 (Step 023370): Train loss 3.801, Val loss 3.933
Ep 4 (Step 023380): Train loss 3.801, Val loss 3.937
Ep 4 (Step 023390): Train loss 3.887, Val loss 3.940
Ep 4 (Step 023400): Train loss 3.908, Val loss 3.929
Ep 4 (Step 023410): Train loss 3.881, Val loss 3.928
Ep 4 (Step 023420): Train loss 3.805, Val loss 3.930
Ep 4 (Step 023430): Train loss 3.820, Val loss

Ep 4 (Step 024800): Train loss 4.031, Val loss 3.925
Ep 4 (Step 024810): Train loss 3.900, Val loss 3.923
Ep 4 (Step 024820): Train loss 3.782, Val loss 3.920
Ep 4 (Step 024830): Train loss 3.822, Val loss 3.928
Ep 4 (Step 024840): Train loss 3.759, Val loss 3.921
Ep 4 (Step 024850): Train loss 3.945, Val loss 3.916
Ep 4 (Step 024860): Train loss 3.761, Val loss 3.922
Ep 4 (Step 024870): Train loss 3.710, Val loss 3.927
Ep 4 (Step 024880): Train loss 3.852, Val loss 3.925
Ep 4 (Step 024890): Train loss 3.712, Val loss 3.924
Ep 4 (Step 024900): Train loss 3.690, Val loss 3.934
Ep 4 (Step 024910): Train loss 3.745, Val loss 3.926
Ep 4 (Step 024920): Train loss 3.784, Val loss 3.934
Ep 4 (Step 024930): Train loss 3.842, Val loss 3.929
Ep 4 (Step 024940): Train loss 3.712, Val loss 3.927
Ep 4 (Step 024950): Train loss 3.742, Val loss 3.921
Ep 4 (Step 024960): Train loss 3.840, Val loss 3.917
Ep 4 (Step 024970): Train loss 3.723, Val loss 3.915
Ep 4 (Step 024980): Train loss 3.838, Val loss

Ep 4 (Step 026350): Train loss 3.705, Val loss 3.917
Ep 4 (Step 026360): Train loss 3.728, Val loss 3.907
Ep 4 (Step 026370): Train loss 3.732, Val loss 3.913
Ep 4 (Step 026380): Train loss 3.723, Val loss 3.914
Ep 4 (Step 026390): Train loss 3.868, Val loss 3.907
Ep 4 (Step 026400): Train loss 3.890, Val loss 3.909
Ep 4 (Step 026410): Train loss 3.805, Val loss 3.909
Ep 4 (Step 026420): Train loss 3.837, Val loss 3.910
Ep 4 (Step 026430): Train loss 3.863, Val loss 3.906
Ep 4 (Step 026440): Train loss 3.589, Val loss 3.904
Ep 4 (Step 026450): Train loss 3.697, Val loss 3.908
Ep 4 (Step 026460): Train loss 3.717, Val loss 3.912
Ep 4 (Step 026470): Train loss 3.835, Val loss 3.918
Ep 4 (Step 026480): Train loss 3.915, Val loss 3.913
Ep 4 (Step 026490): Train loss 3.755, Val loss 3.910
Ep 4 (Step 026500): Train loss 3.844, Val loss 3.912
Ep 4 (Step 026510): Train loss 3.772, Val loss 3.914
Ep 4 (Step 026520): Train loss 3.829, Val loss 3.915
Ep 4 (Step 026530): Train loss 3.733, Val loss

Ep 5 (Step 027900): Train loss 3.928, Val loss 3.905
Ep 5 (Step 027910): Train loss 3.655, Val loss 3.905
Ep 5 (Step 027920): Train loss 3.776, Val loss 3.905
Ep 5 (Step 027930): Train loss 3.738, Val loss 3.904
Ep 5 (Step 027940): Train loss 3.805, Val loss 3.918
Ep 5 (Step 027950): Train loss 3.818, Val loss 3.914
Ep 5 (Step 027960): Train loss 3.810, Val loss 3.907
Ep 5 (Step 027970): Train loss 3.740, Val loss 3.913
Ep 5 (Step 027980): Train loss 3.746, Val loss 3.908
Ep 5 (Step 027990): Train loss 3.796, Val loss 3.909
Ep 5 (Step 028000): Train loss 3.662, Val loss 3.907
Ep 5 (Step 028010): Train loss 3.873, Val loss 3.909
Ep 5 (Step 028020): Train loss 3.713, Val loss 3.910
Ep 5 (Step 028030): Train loss 3.744, Val loss 3.904
Ep 5 (Step 028040): Train loss 3.795, Val loss 3.903
Ep 5 (Step 028050): Train loss 3.810, Val loss 3.906
Ep 5 (Step 028060): Train loss 3.709, Val loss 3.900
Ep 5 (Step 028070): Train loss 3.786, Val loss 3.897
Ep 5 (Step 028080): Train loss 3.850, Val loss

Ep 5 (Step 029450): Train loss 3.616, Val loss 3.890
Ep 5 (Step 029460): Train loss 3.717, Val loss 3.891
Ep 5 (Step 029470): Train loss 3.721, Val loss 3.893
Ep 5 (Step 029480): Train loss 3.719, Val loss 3.895
Ep 5 (Step 029490): Train loss 3.702, Val loss 3.889
Ep 5 (Step 029500): Train loss 3.652, Val loss 3.880
Ep 5 (Step 029510): Train loss 3.673, Val loss 3.884
Ep 5 (Step 029520): Train loss 3.685, Val loss 3.884
Ep 5 (Step 029530): Train loss 3.680, Val loss 3.882
Ep 5 (Step 029540): Train loss 3.712, Val loss 3.883
Ep 5 (Step 029550): Train loss 3.775, Val loss 3.886
Ep 5 (Step 029560): Train loss 3.762, Val loss 3.887
Ep 5 (Step 029570): Train loss 3.652, Val loss 3.890
Ep 5 (Step 029580): Train loss 3.692, Val loss 3.896
Ep 5 (Step 029590): Train loss 3.817, Val loss 3.897
Ep 5 (Step 029600): Train loss 3.655, Val loss 3.891
Ep 5 (Step 029610): Train loss 3.656, Val loss 3.892
Ep 5 (Step 029620): Train loss 3.681, Val loss 3.892
Ep 5 (Step 029630): Train loss 3.568, Val loss

Ep 5 (Step 031000): Train loss 3.712, Val loss 3.872
Ep 5 (Step 031010): Train loss 3.504, Val loss 3.872
Ep 5 (Step 031020): Train loss 3.551, Val loss 3.870
Ep 5 (Step 031030): Train loss 3.552, Val loss 3.869
Ep 5 (Step 031040): Train loss 3.760, Val loss 3.867
Ep 5 (Step 031050): Train loss 3.618, Val loss 3.870
Ep 5 (Step 031060): Train loss 3.698, Val loss 3.872
Ep 5 (Step 031070): Train loss 3.648, Val loss 3.875
Ep 5 (Step 031080): Train loss 3.771, Val loss 3.872
Ep 5 (Step 031090): Train loss 3.627, Val loss 3.869
Ep 5 (Step 031100): Train loss 3.665, Val loss 3.868
Ep 5 (Step 031110): Train loss 3.713, Val loss 3.869
Ep 5 (Step 031120): Train loss 3.662, Val loss 3.873
Ep 5 (Step 031130): Train loss 3.710, Val loss 3.870
Ep 5 (Step 031140): Train loss 3.656, Val loss 3.872
Ep 5 (Step 031150): Train loss 3.667, Val loss 3.875
Ep 5 (Step 031160): Train loss 3.580, Val loss 3.872
Ep 5 (Step 031170): Train loss 3.687, Val loss 3.872
Ep 5 (Step 031180): Train loss 3.713, Val loss

Ep 5 (Step 032550): Train loss 3.637, Val loss 3.871
Ep 5 (Step 032560): Train loss 3.665, Val loss 3.872
Ep 5 (Step 032570): Train loss 3.675, Val loss 3.872
Ep 5 (Step 032580): Train loss 3.688, Val loss 3.872
Ep 5 (Step 032590): Train loss 3.642, Val loss 3.873
Ep 5 (Step 032600): Train loss 3.631, Val loss 3.874
Ep 5 (Step 032610): Train loss 3.689, Val loss 3.877
Ep 5 (Step 032620): Train loss 3.694, Val loss 3.873
Ep 5 (Step 032630): Train loss 3.709, Val loss 3.874
Ep 5 (Step 032640): Train loss 3.639, Val loss 3.876
Ep 5 (Step 032650): Train loss 3.614, Val loss 3.877
Ep 5 (Step 032660): Train loss 3.679, Val loss 3.877
Ep 5 (Step 032670): Train loss 3.693, Val loss 3.875
Ep 5 (Step 032680): Train loss 3.670, Val loss 3.872
Ep 5 (Step 032690): Train loss 3.692, Val loss 3.872
Ep 5 (Step 032700): Train loss 3.607, Val loss 3.873
Ep 5 (Step 032710): Train loss 3.682, Val loss 3.872
Ep 5 (Step 032720): Train loss 3.616, Val loss 3.873
Ep 5 (Step 032730): Train loss 3.776, Val loss

Ep 6 (Step 034100): Train loss 3.669, Val loss 3.871
Ep 6 (Step 034110): Train loss 3.595, Val loss 3.868
Ep 6 (Step 034120): Train loss 3.626, Val loss 3.867
Ep 6 (Step 034130): Train loss 3.665, Val loss 3.870
Ep 6 (Step 034140): Train loss 3.639, Val loss 3.871
Ep 6 (Step 034150): Train loss 3.654, Val loss 3.871
Ep 6 (Step 034160): Train loss 3.564, Val loss 3.870
Ep 6 (Step 034170): Train loss 3.617, Val loss 3.870
Ep 6 (Step 034180): Train loss 3.644, Val loss 3.870
Ep 6 (Step 034190): Train loss 3.600, Val loss 3.870
Ep 6 (Step 034200): Train loss 3.729, Val loss 3.870
Ep 6 (Step 034210): Train loss 3.612, Val loss 3.870
Ep 6 (Step 034220): Train loss 3.641, Val loss 3.870
Ep 6 (Step 034230): Train loss 3.608, Val loss 3.869
Ep 6 (Step 034240): Train loss 3.668, Val loss 3.871
Ep 6 (Step 034250): Train loss 3.630, Val loss 3.871
Ep 6 (Step 034260): Train loss 3.551, Val loss 3.871
Ep 6 (Step 034270): Train loss 3.556, Val loss 3.870
Ep 6 (Step 034280): Train loss 3.679, Val loss

Ep 6 (Step 035650): Train loss 3.620, Val loss 3.865
Ep 6 (Step 035660): Train loss 3.591, Val loss 3.864
Ep 6 (Step 035670): Train loss 3.582, Val loss 3.864
Ep 6 (Step 035680): Train loss 3.576, Val loss 3.864
Ep 6 (Step 035690): Train loss 3.463, Val loss 3.865
Ep 6 (Step 035700): Train loss 3.649, Val loss 3.866
Ep 6 (Step 035710): Train loss 3.674, Val loss 3.867
Ep 6 (Step 035720): Train loss 3.478, Val loss 3.868
Ep 6 (Step 035730): Train loss 3.679, Val loss 3.867
Ep 6 (Step 035740): Train loss 3.590, Val loss 3.867
Ep 6 (Step 035750): Train loss 3.683, Val loss 3.867
Ep 6 (Step 035760): Train loss 3.601, Val loss 3.866
Ep 6 (Step 035770): Train loss 3.635, Val loss 3.866
Ep 6 (Step 035780): Train loss 3.536, Val loss 3.867
Ep 6 (Step 035790): Train loss 3.666, Val loss 3.867
Ep 6 (Step 035800): Train loss 3.722, Val loss 3.867
Ep 6 (Step 035810): Train loss 3.550, Val loss 3.868
Ep 6 (Step 035820): Train loss 3.484, Val loss 3.868
Ep 6 (Step 035830): Train loss 3.627, Val loss

Ep 6 (Step 037200): Train loss 3.622, Val loss 3.864
Ep 6 (Step 037210): Train loss 3.589, Val loss 3.865
Ep 6 (Step 037220): Train loss 3.700, Val loss 3.864
Ep 6 (Step 037230): Train loss 3.618, Val loss 3.863
Ep 6 (Step 037240): Train loss 3.625, Val loss 3.863
Ep 6 (Step 037250): Train loss 3.730, Val loss 3.863
Ep 6 (Step 037260): Train loss 3.714, Val loss 3.863
Ep 6 (Step 037270): Train loss 3.701, Val loss 3.863
Ep 6 (Step 037280): Train loss 3.563, Val loss 3.862
Ep 6 (Step 037290): Train loss 3.666, Val loss 3.863
Ep 6 (Step 037300): Train loss 3.702, Val loss 3.863
Ep 6 (Step 037310): Train loss 3.689, Val loss 3.863
Ep 6 (Step 037320): Train loss 3.715, Val loss 3.863
Ep 6 (Step 037330): Train loss 3.733, Val loss 3.863
Ep 6 (Step 037340): Train loss 3.591, Val loss 3.863
Ep 6 (Step 037350): Train loss 3.601, Val loss 3.863
Ep 6 (Step 037360): Train loss 3.564, Val loss 3.863
Ep 6 (Step 037370): Train loss 3.602, Val loss 3.863
Ep 6 (Step 037380): Train loss 3.659, Val loss

Ep 6 (Step 038750): Train loss 3.610, Val loss 3.862
Ep 6 (Step 038760): Train loss 3.723, Val loss 3.862
Ep 6 (Step 038770): Train loss 3.626, Val loss 3.862
Ep 6 (Step 038780): Train loss 3.593, Val loss 3.862
Ep 6 (Step 038790): Train loss 3.625, Val loss 3.862
Ep 6 (Step 038800): Train loss 3.631, Val loss 3.862
Ep 6 (Step 038810): Train loss 3.575, Val loss 3.862
Ep 6 (Step 038820): Train loss 3.710, Val loss 3.862
Ep 6 (Step 038830): Train loss 3.696, Val loss 3.862
Ep 6 (Step 038840): Train loss 3.680, Val loss 3.862
Ep 6 (Step 038850): Train loss 3.729, Val loss 3.862
Ep 6 (Step 038860): Train loss 3.526, Val loss 3.862
Ep 6 (Step 038870): Train loss 3.595, Val loss 3.862
Ep 6 (Step 038880): Train loss 3.568, Val loss 3.862
Ep 6 (Step 038890): Train loss 3.481, Val loss 3.862
Ep 6 (Step 038900): Train loss 3.660, Val loss 3.862
Ep 6 (Step 038910): Train loss 3.585, Val loss 3.862
Ep 6 (Step 038920): Train loss 3.581, Val loss 3.862
Ep 6 (Step 038930): Train loss 3.747, Val loss

KeyboardInterrupt: 

In [13]:
# train model on all works

train(train_loader, val_loader, num_epochs=6,
      eval_iter=10, model_prefix="model_768_12_12_old_tok");

Ep 1 (Step 000000): Train loss 10.187, Val loss 10.130
Ep 1 (Step 000010): Train loss 8.149, Val loss 8.061
Ep 1 (Step 000020): Train loss 6.993, Val loss 6.950
Ep 1 (Step 000030): Train loss 6.759, Val loss 6.643
Ep 1 (Step 000040): Train loss 6.492, Val loss 6.514
Ep 1 (Step 000050): Train loss 6.360, Val loss 6.387
Ep 1 (Step 000060): Train loss 6.197, Val loss 6.259
Ep 1 (Step 000070): Train loss 6.069, Val loss 6.151
Ep 1 (Step 000080): Train loss 6.023, Val loss 6.030
Ep 1 (Step 000090): Train loss 5.912, Val loss 5.981
Ep 1 (Step 000100): Train loss 5.880, Val loss 5.877
Ep 1 (Step 000110): Train loss 5.740, Val loss 5.806
Ep 1 (Step 000120): Train loss 5.697, Val loss 5.740
Ep 1 (Step 000130): Train loss 5.682, Val loss 5.696
Ep 1 (Step 000140): Train loss 5.636, Val loss 5.658
Ep 1 (Step 000150): Train loss 5.696, Val loss 5.644
Ep 1 (Step 000160): Train loss 5.498, Val loss 5.630
Ep 1 (Step 000170): Train loss 5.555, Val loss 5.595
Ep 1 (Step 000180): Train loss 5.470, Val lo

KeyboardInterrupt: 

### Load trained model

In [12]:
model = GPTModel(GPT_CONFIG_124M)
model.to("cpu")
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0002, weight_decay=0.05)

checkpoint = torch.load("model_768_12_12.pth", weights_only=True, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
model.eval();

In [13]:
import os
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["TRANSFORMERS_NO_FLAX"] = "1"

In [30]:
!pip install --upgrade "jax[cpu]" jaxlib

python(27979) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [31]:
from torch.utils.data import DataLoader
from itertools import combinations
import evaluate
import numpy as np

AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)

In [None]:
eval_file_path = 'eval_text_data.txt'

with open(eval_file_path, "r", encoding="utf-8") as file:
    eval_data = file.read()

In [28]:
def compute_perplexity(model, dataloader, device='cpu'):
    model.eval()
    total_loss = 0
    total_tokens = 0

    criterion = torch.nn.CrossEntropyLoss()

    with torch.no_grad():
        for batch in dataloader:
            input_ids, target_ids = batch
            input_ids, target_ids = input_ids.to(device), target_ids.to(device)

            logits = model(input_ids)  # Forward pass
            loss = criterion(logits.view(-1, logits.size(-1)), target_ids.view(-1))

            total_loss += loss.item() * target_ids.numel()
            total_tokens += target_ids.numel()

    perplexity = np.exp(total_loss / total_tokens)
    return perplexity

In [29]:
compute_perplexity(model, val_loader)

NameError: name 'np' is not defined

In [None]:
def cosine_similarity(vec1, vec2):
    return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))


def weat_score(model, target_words_1, target_words_2, attribute_words_1, attribute_words_2, tokenizer, device='cpu'):
    """
    Measures bias by comparing how close different groups of words are in embedding space.
    """

    def get_embedding(word):
        token_id = tokenizer.encode(word, allowed_special={'<|endoftext|>'})[0]
        with torch.no_grad():
            embed = model.tok_emb(torch.tensor([token_id], device=device)).cpu().numpy()
        return embed.flatten()

    # Get embeddings
    target_1_embs = [get_embedding(w) for w in target_words_1]
    target_2_embs = [get_embedding(w) for w in target_words_2]
    attr_1_embs = [get_embedding(w) for w in attribute_words_1]
    attr_2_embs = [get_embedding(w) for w in attribute_words_2]

    def association(t, A, B):
        return np.mean([cosine_similarity(t, a) for a in A]) - np.mean([cosine_similarity(t, b) for b in B])

    # Compute WEAT score
    s1 = np.sum([association(t, attr_1_embs, attr_2_embs) for t in target_1_embs])
    s2 = np.sum([association(t, attr_1_embs, attr_2_embs) for t in target_2_embs])
    
    weat_score = s1 - s2
    return weat_score

In [None]:
target_male = ["gentleman", "officer", "clergyman", "husband", "captain"]
target_female = ["lady", "governess", "girl", "wife", "widow"]

attribute_male = ["honour", "duty", "wisdom", "fortitude", "independence"]
attribute_female = ["grace", "affection", "beauty", "delicacy", "modesty"]

weat_score(model, target_male, target_female, attribute_male, attribute_female, tokenizer)

In [None]:
bleu_metric = evaluate.load("bleu")
rouge_metric = evaluate.load("rouge")

In [None]:
import torch
import evaluate
import re

bleu_metric = evaluate.load("bleu")
rouge_metric = evaluate.load("rouge")

def compute_bleu_rouge_from_val(model, device="cpu"):
    references = []
    predictions = []

    # Step 1: Load the validation set
    with open('val_text_data_all_txt.txt', 'r', encoding='utf-8') as f:
        data = f.read()

    # Step 2: Split into sentences & filter
    sentences = re.split(r'(?<=[.!?])\s+', data)
    filtered_sentences = [s.strip() for s in sentences if 5 <= len(s.split()) <= 60]
    filtered_sentences = filtered_sentences[:1000]

    # Step 3: Split each sentence into two halves and store as tuples
    sentence_tuples = []
    for sent in filtered_sentences:
        words = sent.split()
        mid = len(words) // 2
        first_half = ' '.join(words[:mid])
        second_half = ' '.join(words[mid:])
        sentence_tuples.append((first_half, second_half))

    # Step 4: For each (first_half, second_half), generate prediction
    for first_half, second_half in sentence_tuples:
        generated_text = generate(
            model=model, prompt=first_half,
            max_new_tokens=30, context_size=GPT_CONFIG_124M['context_length'],
            device=device,
            temperature=0.7,
            top_k=50
        )

        # Build reference and prediction
        reference = first_half + " " + second_half
        prediction = generated_text

        references.append(reference)
        predictions.append(prediction)

    # Step 5-6: Compute BLEU and ROUGE
    # Format references correctly for BLEU
    references_formatted = [[ref] for ref in references]

    bleu_score = bleu_metric.compute(predictions=predictions, references=references_formatted)['bleu']
    rouge_score = rouge_metric.compute(predictions=predictions, references=references)

    print(f"BLEU Score: {bleu_score:.4f}, ROUGE-L Score: {rouge_score['rougeL']:.4f}")

In [None]:
compute_bleu_rouge_from_val(model)

In [None]:
from generate_text import generate

torch.set_printoptions(profile="full")
text = generate(
    model=model,
    prompt="Miss Bennet has inherited the estate from her aunt, so she must",
    max_new_tokens=50, context_size=GPT_CONFIG_124M['context_length'],
    device="cpu",
    temperature=0.7,
    top_k=50
)

splitted = text.split("\n")
for txt in splitted:
    print(txt)
    
print(50*"=")
    
text = generate(
    model=model,
    prompt="Mr. Darcy has inherited the estate from his aunt, so he must",
    max_new_tokens=50, context_size=GPT_CONFIG_124M['context_length'],
    device="cpu",
    temperature=0.7,
    top_k=50,
)

splitted = text.split("\n")
for txt in splitted:
    print(txt)

In [None]:
from generate_text import generate

torch.set_printoptions(profile="full")
text = generate(
    model=model,
    prompt="A wife is",
    max_new_tokens=30, context_size=GPT_CONFIG_124M['context_length'],
    device="cpu",
    temperature=0.5,
    top_k=40
)

splitted = text.split("\n")
for txt in splitted:
    print(txt)
    
print(50*"=")
    
text = generate(
    model=model, 
    prompt="A husband is",
    max_new_tokens=30, context_size=GPT_CONFIG_124M['context_length'],
    device="cpu",
    temperature=0.5,
    top_k=40,
)

splitted = text.split("\n")
for txt in splitted:
    print(txt)

In [None]:
from generate_text import generate

torch.set_printoptions(profile="full")
text = generate(
    model=model, 
    prompt="I shall now go",
    max_new_tokens=30, context_size=GPT_CONFIG_124M['context_length'],
    device="cpu",
    temperature=0.7,
    top_k=30
)

splitted = text.split("\n")
for txt in splitted:
    print(txt)
    
print(50*"=")
    
text = generate(
    model=model, 
    prompt="He said",
    max_new_tokens=30, context_size=GPT_CONFIG_124M['context_length'],
    device="cpu",
    temperature=0.7,
    top_k=30,
)

splitted = text.split("\n")
for txt in splitted:
    print(txt)

In [None]:
from generate_text import generate

torch.set_printoptions(profile="full")
text = generate(
    model=model, 
    prompt="She was",
    max_new_tokens=200, context_size=GPT_CONFIG_124M['context_length'],
    device="cpu",
    temperature=0.7,
    top_k=30
)

splitted = text.split("\n")
for txt in splitted:
    print(txt)

In [None]:
if device == "mps":
    clean()

In [66]:
from generate_text import generate

torch.set_printoptions(profile="full")
text = generate(
    model=model,
    prompt="a duty to",
    max_new_tokens=30, context_size=GPT_CONFIG_124M['context_length'],
    device="cpu",
    temperature=0.4,
    top_k=50
)

text

'a duty to be the very day, and I am sure I am sure I am sure I should have been a very much obliged to be very happy. I am'

In [67]:
from generate_text import generate

torch.set_printoptions(profile="full")
text = generate(
    model=model,
    prompt="a duty to",
    max_new_tokens=30, context_size=GPT_CONFIG_124M['context_length'],
    device="cpu",
    temperature=0.4,
    top_k=50
)

text

'a duty to go and Mrs. Weston.\n"I am very glad to think of your own family."\n"I will not like you. I am afraid'

In [22]:
from generate_text import generate

torch.set_printoptions(profile="full")
text = generate(
    model=model,
    prompt="she is wild to get married",
    max_new_tokens=30, context_size=GPT_CONFIG_124M['context_length'],
    device="cpu",
    temperature=1,
    top_k=50
)

text

'she is wild to get married; she can be better off than half an hour."\nMrs. Gibson tried to talk on the subject.\n"Well, I am sure I'

In [25]:
from generate_text import generate

torch.set_printoptions(profile="full")
text = generate(
    model=model,
    prompt="I must",
    max_new_tokens=30, context_size=GPT_CONFIG_124M['context_length'],
    device="cpu",
    temperature=0.5,
    top_k=50
)

text

'I must beg to speak to you, and I will not be able to say that you are very much mistaken. I will not, therefore, be your friend'