In [1]:
from datasets import load_dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import datasets
import sentencepiece as spm
import tqdm
import os
import re
import math

In [2]:
#Tiny Shakespeare là tập dữ liệu gồm khoảng 1 triệu ký tự (≈ 40 tác phẩm của Shakespeare), trích từ kịch, bi kịch, hài kịch của ông — 
#ví dụ như Hamlet, Macbeth, Julius Caesar, The Tempest, v.v.

dataset = load_dataset(
    "text",
    data_files="https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
)

Downloading data:   0%|          | 0.00/1.12M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [3]:
full_text = "\n".join(dataset["train"][i]["text"] for i in range(len(dataset["train"])))

print(len(full_text))           # khoảng 1,118,394 ký tự
print(full_text[:500])

1115393
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor


In [4]:
text = "\n".join(dataset["train"][i]["text"] for i in range(len(dataset["train"])))
with open("tiny_shakespeare.txt", "w", encoding="utf-8") as f:
    f.write(text)

In [5]:
# tạo Tokenize bằng BPE
spm.SentencePieceTrainer.train(
    input="tiny_shakespeare.txt",
    model_prefix="tinyshakespeare",
    vocab_size=8000,        # số lượng token trong từ vựng (bạn có thể đổi)
    character_coverage=1.0, # 1.0 = bao phủ hết ký tự trong tiếng Anh
    model_type="bpe"        # dùng Byte Pair Encoding
)

sentencepiece_trainer.cc(78) LOG(INFO) Starts training with : 
trainer_spec {
  input: tiny_shakespeare.txt
  input_format: 
  model_prefix: tinyshakespeare
  model_type: BPE
  vocab_size: 8000
  self_test_sample_size: 0
  character_coverage: 1
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  pretokenization_delimiter: 
  treat_whitespace_as_suffix: 0
  allow_whitespace_only_pieces: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  seed_sentencepieces_file: 
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 0
  bos_id: 1
  eos_id: 2
  pad_id: -1
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piece: <pad>
  unk_surface:  ⁇ 
  enable_differential_privacy: 0
 

In [6]:
class TinyShakespeareData(torch.utils.data.Dataset):
    def __init__(self, max_seq_length=None):
        dataset = load_dataset(
            "text",
            data_files="https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
        )
        self.dataset = dataset["train"]
        self.sp = spm.SentencePieceProcessor(model_file="tinyshakespeare.model")
        self.max_seq_length = max_seq_length

    def __len__(self):
        return len(self.dataset)
#Thực hiện tokenize + chuyển sang ID cho từng mẫu dữ liệu.
    def __getitem__(self, idx):
        text = self.dataset[idx]["text"]
        encoded = self.sp.encode_as_ids(text, add_bos=True, add_eos=True)
        if self.max_seq_length:
            encoded = encoded[:self.max_seq_length]
        return torch.tensor(encoded, dtype=torch.long)
# Để gom dữ liệu thành batch, thêm padding cho các câu có độ dài khác nhau.
    def collate_function(self, batch):
        return torch.nn.utils.rnn.pad_sequence(
            batch, batch_first=True, padding_value=3
        )

arthest
bpe_model_trainer.cc(268) LOG(INFO) Added: freq=4 size=7580 all=21613 active=1017 piece=▁minstrel
bpe_model_trainer.cc(268) LOG(INFO) Added: freq=4 size=7600 all=21594 active=998 piece=▁stooping
bpe_model_trainer.cc(159) LOG(INFO) Updating active symbols. max_freq=4 min_freq=3
bpe_model_trainer.cc(268) LOG(INFO) Added: freq=4 size=7620 all=21574 active=1060 piece=▁caparison
bpe_model_trainer.cc(268) LOG(INFO) Added: freq=4 size=7640 all=21558 active=1044 piece=▁merriment
bpe_model_trainer.cc(268) LOG(INFO) Added: freq=4 size=7660 all=21539 active=1025 piece=▁utterance
bpe_model_trainer.cc(268) LOG(INFO) Added: freq=4 size=7680 all=21519 active=1005 piece=▁jealousies
bpe_model_trainer.cc(268) LOG(INFO) Added: freq=4 size=7700 all=21500 active=986 piece=▁indifferent
bpe_model_trainer.cc(159) LOG(INFO) Updating active symbols. max_freq=4 min_freq=3
bpe_model_trainer.cc(268) LOG(INFO) Added: freq=3 size=7720 all=21494 active=1069 piece=sy
bpe_model_trainer.cc(268) LOG(INFO) Added: 

In [7]:
data = TinyShakespeareData()
print(len(data))
print(data[0])         # Tensor token ID

40000
tensor([   1,  423,  807, 7959,    2])


# Create A Simple Decoder only Transformer Network

In [8]:
#implement positional encoder 
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length):
        super(PositionalEncoding, self).__init__()
        
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        self.register_buffer('pe', pe.unsqueeze(0))
        
    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

In [9]:
#define decoder feedforward layers 
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

In [10]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        #create queries, keys, values and output
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    #impelement attention formula     
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output

    #function to split input values into multi-head format    
    def split_heads(self, x):
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)

    #function to concatenate multihead attention values together     
    def combine_heads(self, x):
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output

In [11]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

In [12]:
class Transformer(nn.Module):
    def __init__(self, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super(Transformer, self).__init__()
        self.decoder_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_length)

        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    #function to generate look-ahead mask to allow for autoregressive decoding 
    def generate_mask(self, tgt):
        tgt_mask = (tgt != 3).unsqueeze(1).unsqueeze(3) # sentencepiece pad_id = 3
        seq_length = tgt.size(1)
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=1)).bool().to(device)
        tgt_mask = tgt_mask & nopeak_mask
        return tgt_mask

    def forward(self, tgt):
        tgt_mask = self.generate_mask(tgt)
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        dec_output = tgt_embedded
        for dec_layer in self.decoder_layers:
            dec_output = dec_layer(dec_output, tgt_mask)

        output = self.fc(dec_output)
        return output

# Train model


In [13]:
####define constants across project######

DIMENSIONS = 64
DATASET = "train"
DATASET_PERCENTAGE = 1
VOCAB_SIZE = 8000
NUM_OF_EPOCHS = 100
LEARNING_RATE = 0.005
BATCH_SIZE= 16
NUM_HEADS=4
NUM_LAYERS=1
D_FF =  2048
MAX_SEQ_LENGTH = 512
DROPOUT = 0.1
INPUT_SENTENCE_SIZE = 100000

In [14]:

#run on gpu 
def getDevice():
  is_cuda = torch.cuda.is_available()
  return "cuda:0" if is_cuda else "cpu"

#function to find latest epoch
def find_latest_epoch_file(path='./'):
    epoch_files = [f for f in os.listdir(path) if re.match(r'transformer_epoch_\d+\.pt', f)]
    if epoch_files:
        # Extracting epoch numbers from the files and finding the max
        latest_epoch = max([int(f.split('_')[2].split('.')[0]) for f in epoch_files])
        return latest_epoch, f"./transformer_epoch_{latest_epoch}.pt"
    else:
        return 0, None

#function to load the latest epoch file if it exists
def load_latest_checkpoint(model, path='./'):
    latest_epoch, latest_file = find_latest_epoch_file(path)
    if latest_file:
        print(f"Resuming training from epoch {latest_epoch+1}")
        model.load_state_dict(torch.load(latest_file, map_location=torch.device(getDevice())))
    else:
        print("No checkpoint found, starting from beginning")
    return latest_epoch

In [15]:

device = getDevice()
print(f"Device = {device}")

#load dataset
dataset = TinyShakespeareData(max_seq_length=1200)
# ds = dataset.TinyStoriesData("roneneldan/TinyStories", f"train[:{DATASET_PERCENTAGE}%]", MAX_SEQ_LENGTH)
dl = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=dataset.collate_function)

# instantiate transformer and define loss function
transformer = Transformer(VOCAB_SIZE, DIMENSIONS, NUM_HEADS, NUM_LAYERS, D_FF, MAX_SEQ_LENGTH, DROPOUT).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=3) # sentencepiece pad_id = 3
optimizer = optim.Adam(transformer.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

# start_epoch = load_latest_checkpoint(transformer)
start_epoch = 0

#define transformer training loop
for epoch in range(start_epoch, NUM_OF_EPOCHS):
  total_loss = 0
  for tgt_data in tqdm.tqdm(dl, desc=f"Epoch {epoch+1}/{NUM_OF_EPOCHS}", unit="batch"):
    optimizer.zero_grad()
    output = transformer(tgt_data[:, :-1].to(device))
    loss = criterion(output.to(device).contiguous().view(-1, VOCAB_SIZE), tgt_data[:, 1:].to(device).contiguous().view(-1))
    if torch.isnan(loss):
        print("NaN loss detected!")
        print(f"Output stats: min={output.min()}, max={output.max()}, mean={output.mean()}")
        print(f"Grad norm: {sum(p.grad.norm().item() for p in transformer.parameters() if p.grad is not None):.2f}")
        break
    loss.backward()
    optimizer.step()
    total_loss += loss.item()
  print(f"Epoch {epoch+1}/{NUM_OF_EPOCHS}, Loss: {total_loss/2500}")
  torch.save(transformer.state_dict(), f"./transformer_epoch_{epoch+1}.pt")

Device = cuda:0


Epoch 1/100: 100%|██████████| 2500/2500 [00:16<00:00, 154.70batch/s]


Epoch 1/100, Loss: 6.4907613386154175


Epoch 2/100: 100%|██████████| 2500/2500 [00:15<00:00, 165.38batch/s]


Epoch 2/100, Loss: 5.602468134307862


Epoch 3/100: 100%|██████████| 2500/2500 [00:15<00:00, 164.84batch/s]


Epoch 3/100, Loss: 5.464146460151673


Epoch 4/100: 100%|██████████| 2500/2500 [00:15<00:00, 165.15batch/s]


Epoch 4/100, Loss: 5.379345463562012


Epoch 5/100: 100%|██████████| 2500/2500 [00:14<00:00, 166.96batch/s]


Epoch 5/100, Loss: 5.313296541690827


Epoch 6/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.26batch/s]


Epoch 6/100, Loss: 5.264464497470856


Epoch 7/100: 100%|██████████| 2500/2500 [00:14<00:00, 168.89batch/s]


Epoch 7/100, Loss: 5.226444433116913


Epoch 8/100: 100%|██████████| 2500/2500 [00:15<00:00, 165.61batch/s]


Epoch 8/100, Loss: 5.1922862080574035


Epoch 9/100: 100%|██████████| 2500/2500 [00:15<00:00, 165.18batch/s]


Epoch 9/100, Loss: 5.166414796161652


Epoch 10/100: 100%|██████████| 2500/2500 [00:14<00:00, 168.78batch/s]


Epoch 10/100, Loss: 5.145179809188843


Epoch 11/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.40batch/s]


Epoch 11/100, Loss: 5.124935648632049


Epoch 12/100: 100%|██████████| 2500/2500 [00:14<00:00, 168.01batch/s]


Epoch 12/100, Loss: 5.108396948051452


Epoch 13/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.05batch/s]


Epoch 13/100, Loss: 5.089454315567017


Epoch 14/100: 100%|██████████| 2500/2500 [00:14<00:00, 169.30batch/s]


Epoch 14/100, Loss: 5.076739964103699


Epoch 15/100: 100%|██████████| 2500/2500 [00:14<00:00, 168.78batch/s]


Epoch 15/100, Loss: 5.063729370689392


Epoch 16/100: 100%|██████████| 2500/2500 [00:15<00:00, 163.94batch/s]


Epoch 16/100, Loss: 5.053423874473572


Epoch 17/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.53batch/s]


Epoch 17/100, Loss: 5.038590624904632


Epoch 18/100: 100%|██████████| 2500/2500 [00:15<00:00, 165.34batch/s]


Epoch 18/100, Loss: 5.029842113208771


Epoch 19/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.13batch/s]


Epoch 19/100, Loss: 5.0249458683013914


Epoch 20/100: 100%|██████████| 2500/2500 [00:15<00:00, 166.43batch/s]


Epoch 20/100, Loss: 5.014156785774231


Epoch 21/100: 100%|██████████| 2500/2500 [00:15<00:00, 162.62batch/s]


Epoch 21/100, Loss: 5.006104911231994


Epoch 22/100: 100%|██████████| 2500/2500 [00:15<00:00, 165.89batch/s]


Epoch 22/100, Loss: 4.996459922981262


Epoch 23/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.07batch/s]


Epoch 23/100, Loss: 4.990102435970306


Epoch 24/100: 100%|██████████| 2500/2500 [00:14<00:00, 166.78batch/s]


Epoch 24/100, Loss: 4.984401873588562


Epoch 25/100: 100%|██████████| 2500/2500 [00:14<00:00, 168.40batch/s]


Epoch 25/100, Loss: 4.97209411611557


Epoch 26/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.49batch/s]


Epoch 26/100, Loss: 4.9658076934814455


Epoch 27/100: 100%|██████████| 2500/2500 [00:15<00:00, 165.89batch/s]


Epoch 27/100, Loss: 4.958009644317627


Epoch 28/100: 100%|██████████| 2500/2500 [00:14<00:00, 168.58batch/s]


Epoch 28/100, Loss: 4.950210235595703


Epoch 29/100: 100%|██████████| 2500/2500 [00:14<00:00, 166.91batch/s]


Epoch 29/100, Loss: 4.942711063575745


Epoch 30/100: 100%|██████████| 2500/2500 [00:15<00:00, 165.67batch/s]


Epoch 30/100, Loss: 4.936476554107666


Epoch 31/100: 100%|██████████| 2500/2500 [00:14<00:00, 168.53batch/s]


Epoch 31/100, Loss: 4.932857426071167


Epoch 32/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.77batch/s]


Epoch 32/100, Loss: 4.925640077114105


Epoch 33/100: 100%|██████████| 2500/2500 [00:15<00:00, 165.38batch/s]


Epoch 33/100, Loss: 4.92064830789566


Epoch 34/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.65batch/s]


Epoch 34/100, Loss: 4.913227581501007


Epoch 35/100: 100%|██████████| 2500/2500 [00:14<00:00, 168.42batch/s]


Epoch 35/100, Loss: 4.908433297538758


Epoch 36/100: 100%|██████████| 2500/2500 [00:14<00:00, 168.44batch/s]


Epoch 36/100, Loss: 4.903354203701019


Epoch 37/100: 100%|██████████| 2500/2500 [00:14<00:00, 169.39batch/s]


Epoch 37/100, Loss: 4.8992122128486635


Epoch 38/100: 100%|██████████| 2500/2500 [00:14<00:00, 166.90batch/s]


Epoch 38/100, Loss: 4.892876208305359


Epoch 39/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.77batch/s]


Epoch 39/100, Loss: 4.889693330192566


Epoch 40/100: 100%|██████████| 2500/2500 [00:15<00:00, 166.35batch/s]


Epoch 40/100, Loss: 4.885829689598084


Epoch 41/100: 100%|██████████| 2500/2500 [00:14<00:00, 168.60batch/s]


Epoch 41/100, Loss: 4.882037014770508


Epoch 42/100: 100%|██████████| 2500/2500 [00:15<00:00, 165.45batch/s]


Epoch 42/100, Loss: 4.882475939941406


Epoch 43/100: 100%|██████████| 2500/2500 [00:15<00:00, 163.79batch/s]


Epoch 43/100, Loss: 4.87653163356781


Epoch 44/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.40batch/s]


Epoch 44/100, Loss: 4.873552209377289


Epoch 45/100: 100%|██████████| 2500/2500 [00:15<00:00, 165.16batch/s]


Epoch 45/100, Loss: 4.8687593083381655


Epoch 46/100: 100%|██████████| 2500/2500 [00:14<00:00, 166.96batch/s]


Epoch 46/100, Loss: 4.86837989320755


Epoch 47/100: 100%|██████████| 2500/2500 [00:15<00:00, 166.10batch/s]


Epoch 47/100, Loss: 4.866215673446655


Epoch 48/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.77batch/s]


Epoch 48/100, Loss: 4.86330034828186


Epoch 49/100: 100%|██████████| 2500/2500 [00:15<00:00, 166.54batch/s]


Epoch 49/100, Loss: 4.860872428417206


Epoch 50/100: 100%|██████████| 2500/2500 [00:14<00:00, 168.33batch/s]


Epoch 50/100, Loss: 4.860247629928589


Epoch 51/100: 100%|██████████| 2500/2500 [00:14<00:00, 168.45batch/s]


Epoch 51/100, Loss: 4.858152048206329


Epoch 52/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.70batch/s]


Epoch 52/100, Loss: 4.85367345123291


Epoch 53/100: 100%|██████████| 2500/2500 [00:14<00:00, 166.81batch/s]


Epoch 53/100, Loss: 4.853688846874237


Epoch 54/100: 100%|██████████| 2500/2500 [00:14<00:00, 168.06batch/s]


Epoch 54/100, Loss: 4.851578966999054


Epoch 55/100: 100%|██████████| 2500/2500 [00:14<00:00, 168.22batch/s]


Epoch 55/100, Loss: 4.847841150760651


Epoch 56/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.75batch/s]


Epoch 56/100, Loss: 4.845991917514801


Epoch 57/100: 100%|██████████| 2500/2500 [00:15<00:00, 164.79batch/s]


Epoch 57/100, Loss: 4.844288680839538


Epoch 58/100: 100%|██████████| 2500/2500 [00:15<00:00, 166.37batch/s]


Epoch 58/100, Loss: 4.839144359874726


Epoch 59/100: 100%|██████████| 2500/2500 [00:15<00:00, 162.33batch/s]


Epoch 59/100, Loss: 4.837457181739807


Epoch 60/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.03batch/s]


Epoch 60/100, Loss: 4.832370100307465


Epoch 61/100: 100%|██████████| 2500/2500 [00:15<00:00, 162.91batch/s]


Epoch 61/100, Loss: 4.8298901969909664


Epoch 62/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.23batch/s]


Epoch 62/100, Loss: 4.828798954677581


Epoch 63/100: 100%|██████████| 2500/2500 [00:14<00:00, 169.05batch/s]


Epoch 63/100, Loss: 4.82811969833374


Epoch 64/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.72batch/s]


Epoch 64/100, Loss: 4.826485701942444


Epoch 65/100: 100%|██████████| 2500/2500 [00:15<00:00, 165.32batch/s]


Epoch 65/100, Loss: 4.822965472507477


Epoch 66/100: 100%|██████████| 2500/2500 [00:15<00:00, 164.75batch/s]


Epoch 66/100, Loss: 4.8195434619903565


Epoch 67/100: 100%|██████████| 2500/2500 [00:15<00:00, 165.88batch/s]


Epoch 67/100, Loss: 4.816717303848266


Epoch 68/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.60batch/s]


Epoch 68/100, Loss: 4.81610318107605


Epoch 69/100: 100%|██████████| 2500/2500 [00:15<00:00, 163.05batch/s]


Epoch 69/100, Loss: 4.812074776363373


Epoch 70/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.72batch/s]


Epoch 70/100, Loss: 4.813260789966583


Epoch 71/100: 100%|██████████| 2500/2500 [00:14<00:00, 168.20batch/s]


Epoch 71/100, Loss: 4.8060898976325985


Epoch 72/100: 100%|██████████| 2500/2500 [00:15<00:00, 165.87batch/s]


Epoch 72/100, Loss: 4.8054229727745055


Epoch 73/100: 100%|██████████| 2500/2500 [00:15<00:00, 164.71batch/s]


Epoch 73/100, Loss: 4.805174408149719


Epoch 74/100: 100%|██████████| 2500/2500 [00:14<00:00, 168.45batch/s]


Epoch 74/100, Loss: 4.800145466327667


Epoch 75/100: 100%|██████████| 2500/2500 [00:14<00:00, 168.24batch/s]


Epoch 75/100, Loss: 4.799695800876617


Epoch 76/100: 100%|██████████| 2500/2500 [00:15<00:00, 164.77batch/s]


Epoch 76/100, Loss: 4.799083662414551


Epoch 77/100: 100%|██████████| 2500/2500 [00:15<00:00, 166.50batch/s]


Epoch 77/100, Loss: 4.795402353477478


Epoch 78/100: 100%|██████████| 2500/2500 [00:14<00:00, 168.46batch/s]


Epoch 78/100, Loss: 4.796173832035064


Epoch 79/100: 100%|██████████| 2500/2500 [00:14<00:00, 166.68batch/s]


Epoch 79/100, Loss: 4.791892127037048


Epoch 80/100: 100%|██████████| 2500/2500 [00:15<00:00, 164.58batch/s]


Epoch 80/100, Loss: 4.791772122955322


Epoch 81/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.36batch/s]


Epoch 81/100, Loss: 4.789695254421234


Epoch 82/100: 100%|██████████| 2500/2500 [00:14<00:00, 168.80batch/s]


Epoch 82/100, Loss: 4.7867449262619015


Epoch 83/100: 100%|██████████| 2500/2500 [00:15<00:00, 161.59batch/s]


Epoch 83/100, Loss: 4.788256359004975


Epoch 84/100: 100%|██████████| 2500/2500 [00:15<00:00, 164.46batch/s]


Epoch 84/100, Loss: 4.785131950950623


Epoch 85/100: 100%|██████████| 2500/2500 [00:15<00:00, 164.75batch/s]


Epoch 85/100, Loss: 4.78343588809967


Epoch 86/100: 100%|██████████| 2500/2500 [00:15<00:00, 164.85batch/s]


Epoch 86/100, Loss: 4.783589002704621


Epoch 87/100: 100%|██████████| 2500/2500 [00:15<00:00, 162.46batch/s]


Epoch 87/100, Loss: 4.781925857162475


Epoch 88/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.49batch/s]


Epoch 88/100, Loss: 4.77997910270691


Epoch 89/100: 100%|██████████| 2500/2500 [00:14<00:00, 166.89batch/s]


Epoch 89/100, Loss: 4.780307880210876


Epoch 90/100: 100%|██████████| 2500/2500 [00:14<00:00, 169.60batch/s]


Epoch 90/100, Loss: 4.7809762424469


Epoch 91/100: 100%|██████████| 2500/2500 [00:15<00:00, 164.41batch/s]


Epoch 91/100, Loss: 4.778331746768951


Epoch 92/100: 100%|██████████| 2500/2500 [00:14<00:00, 171.03batch/s]


Epoch 92/100, Loss: 4.777836310005188


Epoch 93/100: 100%|██████████| 2500/2500 [00:15<00:00, 166.13batch/s]


Epoch 93/100, Loss: 4.777773439121247


Epoch 94/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.79batch/s]


Epoch 94/100, Loss: 4.777365462398529


Epoch 95/100: 100%|██████████| 2500/2500 [00:14<00:00, 168.65batch/s]


Epoch 95/100, Loss: 4.779969942951203


Epoch 96/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.41batch/s]


Epoch 96/100, Loss: 4.7779397121429446


Epoch 97/100: 100%|██████████| 2500/2500 [00:15<00:00, 165.33batch/s]


Epoch 97/100, Loss: 4.779223730754852


Epoch 98/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.01batch/s]


Epoch 98/100, Loss: 4.780716458702088


Epoch 99/100: 100%|██████████| 2500/2500 [00:14<00:00, 168.11batch/s]


Epoch 99/100, Loss: 4.779374151706696


Epoch 100/100: 100%|██████████| 2500/2500 [00:14<00:00, 167.12batch/s]

Epoch 100/100, Loss: 4.77840571975708





In [16]:
import math
from torch.nn import functional as F

# --- Evaluation function ---
def evaluate_perplexity(model, dataloader, device, pad_id=3):
    model.eval()
    total_loss, total_tokens = 0, 0

    with torch.no_grad():
        for batch in tqdm.tqdm(dataloader, desc="Evaluating", unit="batch"):
            batch = batch.to(device)
            logits = model(batch[:, :-1])
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                batch[:, 1:].reshape(-1),
                ignore_index=pad_id,
                reduction='sum'  # để cộng dồn loss chính xác theo token
            )
            total_loss += loss.item()
            total_tokens += (batch[:, 1:] != pad_id).sum().item()

    avg_loss = total_loss / total_tokens
    ppl = math.exp(avg_loss)
    return avg_loss, ppl

# Load checkpoint tốt nhất hoặc epoch cuối
transformer.load_state_dict(torch.load("./transformer_epoch_100.pt", map_location=device))
transformer.to(device)

# Evaluate perplexity
avg_loss, ppl = evaluate_perplexity(transformer, dl, device, pad_id=3)
print(f"\nFinal evaluation:")
print(f"  Average loss: {avg_loss:.4f}")
print(f"  Perplexity: {ppl:.2f}")

# Perplexity (PP) = exp(loss).
#→ Tức là mỗi bước dự đoán, model “phân vân” giữa ~104 từ khả dĩ, khá yếu nhưng ổn với dữ liệu nhỏ.

Evaluating: 100%|██████████| 2500/2500 [00:07<00:00, 326.98batch/s]


Final evaluation:
  Average loss: 4.6469
  Perplexity: 104.26





In [17]:
import torch
import torch.nn as nn
import sentencepiece as spm


device = getDevice()
print(f"🟢 Device = {device}")

def generate(prompt_text, max_new_tokens=50, temperature=1.0):
    # Load tokenizer
    sp = spm.SentencePieceProcessor(model_file='tinyshakespeare.model')

    # Encode input
    input_ids = torch.tensor(sp.encode_as_ids(prompt_text, add_bos=True)).long().unsqueeze(0).to(device)

    # Load model
    transformer = Transformer(
        VOCAB_SIZE,
        DIMENSIONS,
        NUM_HEADS,
        NUM_LAYERS,
        D_FF,
        MAX_SEQ_LENGTH,
        DROPOUT
    ).to(device)

    transformer.eval()
    load_latest_checkpoint(transformer)

    print(f"\n=== GENERATION START ===\nPrompt: {prompt_text}\n")

    # Generate step-by-step
    for step in range(max_new_tokens):
        with torch.no_grad():
            logits = transformer(input_ids)
            logits = logits[:, -1, :] / temperature

            # Kiểm tra NaN/Inf
            if torch.any(torch.isnan(logits)) or torch.any(torch.isinf(logits)):
                print("NaN/Inf detected — dùng greedy decoding.")
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            token_id = next_token.item()

            # Nếu gặp token <eos>
            if token_id == 2:
                print("Encountered <eos> token — stopping.")
                break

            # Ghép token mới vào chuỗi
            input_ids = torch.cat([input_ids, next_token], dim=1)

            # In tạm thời đoạn text hiện tại
            partial_text = sp.decode(input_ids.squeeze(0).tolist())
            print(f"[{step+1:02d}] {partial_text}")

    # Decode toàn bộ output
    final_output = sp.decode(input_ids.squeeze(0).tolist())
    print("\n=== FINAL OUTPUT ===")
    print(final_output)
    print("====================\n")

    return {"story": final_output}


# Chạy trực tiếp file
if __name__ == "__main__":
    prompt = "Once upon a time"
    generate(prompt, max_new_tokens=50)
# vì là dư liệu là nhạc kịch nên xu hương genneration cũng theo phong cách đó

🟢 Device = cuda:0
Resuming training from epoch 101

=== GENERATION START ===
Prompt: Once upon a time

[01] Once upon a time,
[02] Once upon a time,l
[03] Once upon a time,l mercy
[04] Once upon a time,l mercy,
[05] Once upon a time,l mercy, so
[06] Once upon a time,l mercy, soath
[07] Once upon a time,l mercy, soath,
[08] Once upon a time,l mercy, soath, but
[09] Once upon a time,l mercy, soath, but so
[10] Once upon a time,l mercy, soath, but so could
[11] Once upon a time,l mercy, soath, but so could therefore
[12] Once upon a time,l mercy, soath, but so could therefore!
[13] Once upon a time,l mercy, soath, but so could therefore! Senator
[14] Once upon a time,l mercy, soath, but so could therefore! Senator.
Encountered <eos> token — stopping.

=== FINAL OUTPUT ===
Once upon a time,l mercy, soath, but so could therefore! Senator.

