In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torchvision import transforms
import torch.optim as optim
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import random 
from tqdm import tqdm
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F
from tqdm.auto import tqdm
from datetime import datetime
import wandb
import time

In [2]:
train_data = pd.read_csv("../dataset/train.csv")
validation_data = pd.read_csv("../dataset/validation.csv")
test_data = pd.read_csv("../dataset/test.csv")

train_data.rename(columns={"highlights": "summaries", "article":"articles"}, inplace=True)
validation_data.rename(columns={"highlights": "summaries","article":"articles"}, inplace=True)
test_data.rename(columns={"highlights": "summaries", "article":"articles"}, inplace=True)


In [3]:
train_data["article_word_count"] = train_data["articles"].astype(str).apply(lambda x: len(x.split()))
train_data["summary_word_count"] = train_data["summaries"].astype(str).apply(lambda x: len(x.split()))

validation_data["article_word_count"] = validation_data["articles"].astype(str).apply(lambda x: len(x.split()))
validation_data["summary_word_count"] = validation_data["summaries"].astype(str).apply(lambda x: len(x.split()))

test_data["article_word_count"] = test_data["articles"].astype(str).apply(lambda x: len(x.split()))
test_data["summary_word_count"] = test_data["summaries"].astype(str).apply(lambda x: len(x.split()))


In [4]:
train_data = train_data[train_data["article_word_count"] < 512]
validation_data = validation_data[validation_data["article_word_count"] < 512]
test_data = test_data[test_data["article_word_count"] < 1024]

In [5]:
train_sample = train_data.sample(frac=1, random_state=1)
validation_sample = validation_data.sample(frac=1, random_state=1)
test_sample = test_data.sample(frac=0.01, random_state=1)
train_sample.info()

<class 'pandas.core.frame.DataFrame'>
Index: 97959 entries, 282281 to 227701
Data columns (total 5 columns):
 #   Column              Non-Null Count  Dtype 
---  ------              --------------  ----- 
 0   id                  97959 non-null  object
 1   articles            97959 non-null  object
 2   summaries           97959 non-null  object
 3   article_word_count  97959 non-null  int64 
 4   summary_word_count  97959 non-null  int64 
dtypes: int64(2), object(3)
memory usage: 4.5+ MB


In [6]:
max_len_article = train_sample["article_word_count"].max()
print(max_len_article)
max_len_summary = train_sample["summary_word_count"].max()
print(max_len_summary)
# if max_len_article is None:
#     self.max_len_article = max(len(numericalize(art)) for art in articles) + 2  # +2 cho SOS/EOS
# else:
#     self.max_len_article = max_len_article
    
# if max_len_summary is None:
#     self.max_len_summary = max(len(numericalize(summ)) for summ in summaries) + 2
# else:
#     self.max_len_summary = max_len_summary


511
1185


In [7]:
EMBEDDING_FILE = "../Embedding/glove-wiki-gigaword-100.txt"
vocab, embeddings = [], []
with open(EMBEDDING_FILE, 'rt', encoding='utf-8') as ef:
    full_content = ef.read().strip().split('\n')
for i in range(len(full_content)):
    i_word = full_content[i].split(' ')[0]
    i_embeddings = [float(val) for val in full_content[i].split(' ')[1:]]
    i_embeddings.extend([0.0, 0.0, 0.0, 0.0])
    vocab.append(i_word)
    embeddings.append(i_embeddings)

embs_npa = np.array(embeddings)

unk_embedding = np.mean(embs_npa, axis=0).tolist()

dim = embs_npa.shape[1]
sos_embedding = [0.0] * dim
sos_embedding[-3] = 1.0
eos_embedding = [0.0] * dim
eos_embedding[-2] = 1.0
pad_embedding = [0.0] * dim
pad_embedding[-4] = 1.0
# unk_embedding = [0.0] * dim
# unk_embedding[-1] = 1.0

# Update vocab and embeddings
vocab = ["<PAD>", "<SOS>", "<EOS>", "<UNK>"] + vocab
embeddings = [pad_embedding, sos_embedding,
              eos_embedding, unk_embedding] + embeddings

vocab_npa = np.array(vocab)
embs_npa = np.array(embeddings)


def tokenize(text):
    return text.lower().strip().split()


stoi_dict = {word: idx for idx, word in enumerate(vocab_npa)}
_unk_idx = stoi_dict["<UNK>"]


def stoi(string, stoi_dict=stoi_dict):
    return stoi_dict.get(string, _unk_idx)


def numericalize(text):
    tokenized_text = tokenize(text)
    return [
        stoi(token)
        for token in tokenized_text
    ]

print(embs_npa.shape[0])
embedding_layer = torch.nn.Embedding.from_pretrained(torch.FloatTensor(embeddings),
                                                     freeze=False,
                                                     padding_idx=stoi("<PAD>"))
embedding_layer.to(device)
print("Embedding shape:", np.array(embeddings).shape) 
print("<PAD> embedding last 4 dims:", embeddings[stoi("<PAD>")][-4:])
print("<SOS> embedding last 4 dims:", embeddings[stoi("<SOS>")][-4:])
print("Word 'the' embedding last 4 dims:", embeddings[stoi("the")][-4:])

25004
Embedding shape: (25004, 104)
<PAD> embedding last 4 dims: [1.0, 0.0, 0.0, 0.0]
<SOS> embedding last 4 dims: [0.0, 1.0, 0.0, 0.0]
Word 'the' embedding last 4 dims: [0.0, 0.0, 0.0, 0.0]


In [8]:
# pad_idx = stoi("<PAD>")
# sos_idx = stoi("<SOS>")
# eos_idx = stoi("<EOS>")

In [9]:
class Seq2SeqDataset(Dataset):
    def __init__(self, articles, summaries, stoi, max_len_article=512, max_len_summary=256):
        self.articles = articles  # List of articles
        self.summaries = summaries  # List of summaries
        self.stoi = stoi  # String-to-index dictionary
        self.pad_idx = stoi("<PAD>")
        self.sos_idx = stoi("<SOS>")
        self.eos_idx = stoi("<EOS>")
        
        # Determine max lengths if not provided
        self.max_len_article = max_len_article or max(len(a.split()) for a in articles) + 2
        self.max_len_summary = max_len_summary or max(len(s.split()) for s in summaries) + 2

    def __len__(self):
        return len(self.articles)

    def __getitem__(self, idx):
        def process_text(text, max_len):
            tokens = [self.sos_idx] + [self.stoi(w) for w in text.split()] + [self.eos_idx]  # Tokenize and add SOS/EOS
            tokens = tokens[:max_len] + [self.pad_idx] * (max_len - len(tokens))  # Pad to max length
            return torch.tensor(tokens), len(tokens)

        article_tokens, article_len = process_text(self.articles[idx], self.max_len_article)
        summary_tokens, summary_len = process_text(self.summaries[idx], self.max_len_summary)
        
        return {
            'article': article_tokens,  # Encoded article
            'article_len': torch.tensor(article_len),
            'summary': summary_tokens,  # Encoded summary
            'summary_len': torch.tensor(summary_len)
        }

def collate_fn(batch):
    # Batch là list các dict {'article': ..., 'summary': ...}
    return {
        'article': torch.stack([item['article'] for item in batch]),
        'article_len': torch.tensor([item['article_len'] for item in batch]),
        'summary': torch.stack([item['summary'] for item in batch]),
        'summary_len': torch.tensor([item['summary_len'] for item in batch])
    }



In [10]:
# DataLoader setup
# torch.set_printoptions(profile="full")
torch.set_printoptions(profile="default")
train_dataset = Seq2SeqDataset(train_sample['articles'].tolist(), train_sample['summaries'].tolist(), stoi)
print(train_dataset[268]["article"])
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)

valid_dataset= Seq2SeqDataset(validation_sample['articles'].tolist(), validation_sample['summaries'].tolist(), stoi)
valid_loader = DataLoader(valid_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)

tensor([    1,     3,    69,     3,     3,     3,   460,  4355,     8,  8090,
            3,     9,    52,     7,    75,   997,     3,     3,    18,    88,
            3,     3,   974,     3, 10960,     3,     3,    11,  2871,    10,
            4,  1139,     7,    11,  1617,  3123,     7,     3,     3,     3,
            3,     3,     3,     3,    11,   304,    90,     3,    11,   791,
           35,     4,   252,     8,    92,     3,     3,     3,  2248,     4,
         3123,   139,    60,    77,     3,   204,     8,     3,   117,    24,
         5272,     4,     3,     3,     3,  2599,    11,  3999,    29,     3,
           16,   214,     4,  3123,     3,  1820,     4,     3,   505,     3,
           46,     3,     3,   242,  3565, 16584,    50,  7368,  2772,    47,
           34,  5423,     9,   396,  1126,     8,     3,   111,    34,     3,
            3,  6216,    16,     4,     3,  3986,     7,    51,   531,    18,
            3,     3,    11,   304,  4256,    11,  3123,     7, 

In [11]:
# def collate_fn(batch):
#     return {
#         'article': torch.stack([item['article'] for item in batch]),
#         'article_len': torch.stack([item['article_len'] for item in batch]),
#         'summary': torch.stack([item['summary'] for item in batch]),
#         'summary_len': torch.stack([item['summary_len'] for item in batch])
#     }

# train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

## Simple Architecture Seq2Seq


In [12]:
class SimpleEncoder(nn.Module):
    def __init__(self, embedding_layer, hidden_dim):
        super().__init__()
        self.embedding = embedding_layer  # Sử dụng embedding có sẵn
        self.lstm = nn.LSTM(
            input_size=self.embedding.embedding_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=False  # Unidirectional
        )
    
    def forward(self, x, seq_lens):  # x = batch input sequences
        # Bước 1: Embedding
        x = self.embedding(x)  # [batch_size, max_len, emb_dim]
        
        # Bước 2: Pack để bỏ qua padding tokens
        packed = pack_padded_sequence(
            input=x,
            lengths=seq_lens.cpu(),  # Chuyển sang CPU tensor
            batch_first=True,
            enforce_sorted=False  # Không cần sắp xếp theo độ dài
        )
        
        # Bước 3: LSTM (chỉ xử lý độ dài thực) 
        '''
        packed_output: dữ liệu thực sự được xử lý
        hidden/cell lưu trạng thái cuối cùng của mỗi sequence
        Input (padded):       Packed LSTM:         Output (unpacked):
        [1,2,3,0,0]    -->   [1,2,3,4,5,6,7,8]       --> [h1_t1,h1_t2,h1_t3,0,0]
        [4,5,6,7,8]     hidden_dim, hd,... hd(8cai)  --> [h2_t1,h2_t2,h2_t3,h2_t4,h2_t5]
        '''
        packed_output, (hidden, cell) = self.lstm(packed)
        # Output: (batch_size, seq_len, hidden_dim)
        # Bước 4: Unpack nếu cần dùng attention
        
        output, _ = pad_packed_sequence(packed_output, batch_first=True)
        
        return output, (hidden, cell)  # hidden shape: [1, batch_size, hidden_dim]

In [13]:
class SimpleAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        # Kết hợp cả encoder outputs và decoder hidden state
        self.energy = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),  # Nhận input là [encoder_output + decoder_hidden]
            nn.Tanh(),                              # thêm chút phi tuyến
            nn.Linear(hidden_dim, 1, bias=False)    #  Trả về 1 chiều attention scores
        )
    
    def forward(self, decoder_hidden, encoder_outputs, mask=None):
        # Bước 1: Chuẩn bị decoder_hidden để cộng với encoder_outputs
        # decoder_hidden: [batch_size, hidden_dim]
        # encoder_outputs: [batch_size, seq_len, hidden_dim]
        # Copy dọc theo seq_len
        #########  decoder_hidden = decoder_hidden.unsqueeze(1).expand_as(encoder_outputs)  # [batch_size, seq_len, hidden_dim]
        decoder_hidden = decoder_hidden.unsqueeze(1)  # [batch_size, 1, hidden_dim]
        decoder_hidden = decoder_hidden.repeat(1, encoder_outputs.size(1), 1)
        # Bước 2: Tính energy từ sự kết hợp encoder-decoder
        combined = torch.cat([encoder_outputs, decoder_hidden], dim=2)  # [batch_size, seq_len, hidden_dim * 2]
        scores = self.energy(combined).squeeze(2)  # [batch_size, seq_len]
        
        # Bước 3: Áp dụng mask và softmax
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e10)
        attn_weights = F.softmax(scores, dim=1)    # [batch_size, seq_len]
        
        # Bước 4: Tính context vector: # [batch_size,1, hidden_dim] x [batch_size, seq_len, hidden_dim]
                #  context = [batch_size, 1, hidden_dim]squeeze(1) loại chiều 1 -> [batch_size, hidden_dim]
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1) 
     
        # debug  
        '''
        Encoder Outputs:  "<SOS>"   "cat"    "sat"   "<EOS>"  "<PAD>"
                        [0.1,0.2] [0.3,0.4] [0.5,0.6] [0.7,0.8] [0.0,0.0]
        Attention Weights:  0.12      0.53      0.29      0.06      0.0
                        ↓         ↓         ↓         ↓         ↓
        Context Vector:  = 0.12*[0.1,0.2] + 0.53*[0.3,0.4] + ... = [0.35, 0.45]
        '''
        '''
        attn_weights (unsqueezed):   encoder_outputs:       context:
        [ [ [0.2, 0.5, 0.3] ]    @  [[0.1,0.2],        =  [ [0.32, 0.42] ]
        [ [0.1, 0.7, 0.2] ]        [0.3,0.4],            [0.92, 1.02] ]
                                    [0.5,0.6] ]
        '''
        # print("Decoder hidden:", decoder_hidden.shape)
        # print("Encoder outputs:", encoder_outputs.shape)
        # context, attn_weights = self.attention(decoder_hidden, encoder_outputs)
        # print("Attention weights:", attn_weights)  # Xem model đang tập trung vào đâu
        return context, attn_weights

In [14]:
class SimpleDecoder(nn.Module):
    def __init__(self, embedding_layer, hidden_dim, vocab_size):
        super().__init__()
        self.embedding = embedding_layer
        self.lstm = nn.LSTMCell(
            input_size=self.embedding.embedding_dim + hidden_dim,  # Thay đổi 1: Thêm context_dim
            hidden_size=hidden_dim
        )
        self.attention = SimpleAttention(hidden_dim)
        self.fc_out = nn.Linear(hidden_dim * 2, vocab_size)  # Thay đổi 2: Kết hợp hidden + context

    def forward(self, x, prev_hidden, prev_cell, encoder_outputs, mask=None):
        # Bước 1: Embedding
        x = self.embedding(x)  # [batch_size] -> [batch_size, emb_dim]
        
        if prev_hidden.dim() == 3:  # If it's [num_layers, batch_size, hidden_dim]
            prev_hidden = prev_hidden[-1]  # Take last layer's hidden state
        
        context, attn_weights = self.attention(prev_hidden, encoder_outputs, mask)
        
        # Bước 3: Kết hợp embedding và context làm input cho LSTM
        lstm_input = torch.cat([x, context], dim=1)  # [batch_size, emb_dim + hidden_dim]
        
        # Bước 4: LSTM step
        hidden, cell = self.lstm(lstm_input, (prev_hidden, prev_cell))
        
        # Bước 5: Kết hợp hidden và context để dự đoán từ
        output_input = torch.cat([hidden, context], dim=1)
        output = self.fc_out(output_input)
        
        return output, hidden, cell, attn_weights


In [15]:
class Seq2SeqModel(nn.Module):
    def __init__(self, embedding_layer, hidden_dim, vocab_size):
        super().__init__()
        self.encoder = SimpleEncoder(embedding_layer, hidden_dim)
        self.decoder = SimpleDecoder(embedding_layer, hidden_dim, vocab_size)
        self.vocab_size = vocab_size
        self.start_id = stoi("<SOS>")  # Thêm start token ID

    def forward(self, src, src_lens, trg=None, max_len=256, teacher_forcing_ratio=0.5):
        # Encoder forward
        enc_outputs, (hidden, cell) = self.encoder(src, src_lens)
        
        # Chuẩn bị decoder
        batch_size = src.size(0)
        if trg is None:  # Inference mode
            max_len = max_len
            trg = torch.full((batch_size,), self.start_id, dtype=torch.long, device=src.device)
        else:  # Training mode
            max_len = trg.size(1)
        
        # Tensor lưu outputs
        outputs = torch.zeros(batch_size, max_len, self.vocab_size).to(src.device)
        
        # Khởi tạo input đầu tiên
        x = torch.full((batch_size,), self.start_id, dtype=torch.long, device=src.device)  # Đổi tên x_t -> x
        
        # Squeeze the encoder hidden states for the decoder
        hidden = hidden.squeeze(0)  # [1, batch_size, hidden_dim] -> [batch_size, hidden_dim]
        cell = cell.squeeze(0)
        
        for t in range(max_len):
            output, hidden, cell, _ = self.decoder(
                x=x,
                prev_hidden=hidden,
                prev_cell=cell,
                encoder_outputs=enc_outputs
            )
            outputs[:, t] = output
            
            if trg is not None and random.random() < teacher_forcing_ratio:
                x = trg[:, t]
            else:
                x = output.argmax(1)
        
        return outputs

In [16]:
print(f"GPU Memory allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB")
print(f"GPU Memory reserved: {torch.cuda.memory_reserved()/1024**2:.2f} MB")

GPU Memory allocated: 9.92 MB
GPU Memory reserved: 20.00 MB


In [17]:
# # 2. Tạo embedding layer (như bạn đã làm)
# embedding_layer = torch.nn.Embedding.from_pretrained(
#     torch.FloatTensor(embeddings),
#     freeze=False,
#     padding_idx=stoi("<PAD>")
# ).to(device)

# # 3. Khởi tạo model
# model = Seq2SeqModel(
#     embedding_layer=embedding_layer,
#     hidden_dim=128,
#     vocab_size=len(vocab)
# ).to(device)

# # 4. Train loop
# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# criterion = nn.CrossEntropyLoss(ignore_index=stoi("<PAD>"))

# for epoch in range(1):
#     model.train()
#     total_loss = 0
    
#     for batch in train_loader:
#         src = batch['article'].to(device)
#         src_lens = batch['article_len'].to(device)
#         trg = batch['summary'].to(device)
        
#         outputs = model(src, src_lens, trg=trg, teacher_forcing_ratio=0.5)
#         loss = criterion(outputs[:, 1:].reshape(-1, outputs.size(-1)), 
#                          trg[:, 1:].reshape(-1))
        
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
        
#         total_loss += loss.item()
    
#     print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")

In [18]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
embedding_layer = torch.nn.Embedding.from_pretrained(
    torch.FloatTensor(embeddings),
    freeze=False,  # Cho phép fine-tune embedding
    padding_idx=stoi("<PAD>")  # Index của token padding
).to(device)

# Khởi tạo
hidden_dim = 128
vocab_size = len(vocab)
model = Seq2SeqModel(
    embedding_layer=embedding_layer,
    hidden_dim=hidden_dim,
    vocab_size=vocab_size
).to(device)
print("Embedding shape:", torch.FloatTensor(embeddings).shape)
print("Vocab size:", vocab_size)
print("Model architecture:")
print(model)


Embedding shape: torch.Size([25004, 104])
Vocab size: 25004
Model architecture:
Seq2SeqModel(
  (encoder): SimpleEncoder(
    (embedding): Embedding(25004, 104, padding_idx=0)
    (lstm): LSTM(104, 128, batch_first=True)
  )
  (decoder): SimpleDecoder(
    (embedding): Embedding(25004, 104, padding_idx=0)
    (lstm): LSTMCell(232, 128)
    (attention): SimpleAttention(
      (energy): Sequential(
        (0): Linear(in_features=256, out_features=128, bias=True)
        (1): Tanh()
        (2): Linear(in_features=128, out_features=1, bias=False)
      )
    )
    (fc_out): Linear(in_features=256, out_features=25004, bias=True)
  )
)


In [19]:
wandb.init(
    project="Seq2Seq-Summarization",
    name=f"seq2seq-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
    config={
        "model": "Seq2Seq-LSTM",
        "hidden_dim": 128,
        "batch_size": 16,
        "learning_rate": 0.001,
        "teacher_forcing_ratio": 0.5,
        "vocab_size": len(vocab)
    }
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mvubkk67[0m ([33mvubkk67-hanoi-university-of-science-and-technology[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [20]:
def train_model(model, train_loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_loader, desc="Training", leave=False)
    
    for batch in progress_bar:
        src = batch['article'].to(device)
        src_lens = batch['article_len'].to(device)
        trg = batch['summary'].to(device)
        
        outputs = model(src, src_lens, trg=trg, teacher_forcing_ratio=0.5)
        loss = criterion(
            outputs[:, 1:].reshape(-1, outputs.size(-1)),
            trg[:, 1:].reshape(-1)
        )
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())
    
    return total_loss / len(train_loader)

# 3. Hàm eval với progress bar
def evaluate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    progress_bar = tqdm(val_loader, desc="Evaluating", leave=False)
    
    with torch.no_grad():
        for batch in progress_bar:
            src = batch['article'].to(device)
            src_lens = batch['article_len'].to(device)
            trg = batch['summary'].to(device)
            
            outputs = model(src, src_lens, trg=trg, teacher_forcing_ratio=0)
            loss = criterion(
                outputs[:, 1:].reshape(-1, outputs.size(-1)),
                trg[:, 1:].reshape(-1)
            )
            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())
    
    return total_loss / len(val_loader)


In [21]:
import os

# Tạo thư mục nếu chưa tồn tại
model_dir = "../Model"
os.makedirs(model_dir, exist_ok=True)

# Lưu mô hình tốt nhất
best_model_path = os.path.join(model_dir, "best_model.pth")
embedding_layer = torch.nn.Embedding.from_pretrained(
    torch.FloatTensor(embeddings),
    freeze=False,
    padding_idx=stoi("<PAD>")
).to(device)

# 3. Khởi tạo model
model = Seq2SeqModel(
    embedding_layer=embedding_layer,
    hidden_dim=128,
    vocab_size=len(vocab)
).to(device)

# 4. Train loop
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=stoi("<PAD>"))
best_val_loss = float('inf')
wandb.watch(model)
for epoch in range(3):
    start_time = time.time()
    
    # Train
    train_loss = train_model(model, train_loader, optimizer, criterion, device)
    
    # Eval
    val_loss = evaluate(model, valid_loader, criterion, device)
    
    # Log metrics
    wandb.log({
        "epoch": epoch+1,
        "train_loss": train_loss,
        "val_loss": val_loss,
        "lr": optimizer.param_groups[0]['lr']
    })
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, best_model_path)
    
    print(f"Epoch {epoch+1:02d} | "
          f"Train Loss: {train_loss:.4f} | "
          f"Val Loss: {val_loss:.4f} | "
          f"Time: {time.time()-start_time:.2f}s")

# Kết thúc W&B
wandb.finish()

Training:   0%|          | 0/6123 [00:00<?, ?it/s]