In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torchvision import transforms
import torch.optim as optim
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import random 
from tqdm import tqdm
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F
from tqdm.auto import tqdm
from datetime import datetime
import wandb
import time
import os
import math

In [2]:
NUM_EPOCHS = 30
BATCH_SIZE = 16
FRAC_SAMPLE = 0.03
MAX_LENGTH_ARTICLE = 512
MIN_LENGTH_ARTICLE = 50
MAX_LENGTH_SUMMARY = 128
MIN_LENGTH_SUMMARY = 20
HIDDEN_DIM = 128
LEARNING_RATE = 0.005
NUM_CYCLES = 3
MAX_PLATEAU_COUNT = 5
WEIGHT_DECAY = 1e-4

In [3]:
def linear_warmup_decay(step, warmup_steps, total_steps):
    if step < warmup_steps:
        return (step + 1) / (warmup_steps + 1)
    else:
        return max(1e-7, (total_steps - step) / (total_steps - warmup_steps))


def warmup_cosine_with_restarts(step, warmup_steps, total_steps, num_cycles=1):
    if step < warmup_steps:
        return (step + 1) / (warmup_steps + 1)
    else:
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        cycle_progress = progress * num_cycles % 1
        return max(1e-7, 0.5 * (1 + math.cos(math.pi * cycle_progress)))



def get_scheduler(
    optimizer, total_steps, warmup_steps, num_cycles=None, types='warmup_cosine_with_restarts'
):
    if types == 'warmup_cosine_with_restarts':
        assert num_cycles != None, 'must specify num_cycles when types="warmup_cosine_with_restarts"'
        return torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: warmup_cosine_with_restarts(
                step, warmup_steps, total_steps, num_cycles=num_cycles)
        )
    elif types == 'linear_warmup_decay':
        return torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: linear_warmup_decay(step, warmup_steps, total_steps)
        )
    else:
        raise Exception('not implemented')

In [4]:
model_dir = "../Model"
os.makedirs(model_dir, exist_ok=True)

In [5]:
train_data = pd.read_csv("../dataset/train.csv")
validation_data = pd.read_csv("../dataset/validation.csv")
test_data = pd.read_csv("../dataset/test.csv")

train_data.rename(columns={"highlights": "summaries", "article":"articles"}, inplace=True)
validation_data.rename(columns={"highlights": "summaries","article":"articles"}, inplace=True)
test_data.rename(columns={"highlights": "summaries", "article":"articles"}, inplace=True)


In [6]:
train_data["article_word_count"] = train_data["articles"].astype(str).apply(lambda x: len(x.split()))
train_data["summary_word_count"] = train_data["summaries"].astype(str).apply(lambda x: len(x.split()))

validation_data["article_word_count"] = validation_data["articles"].astype(str).apply(lambda x: len(x.split()))
validation_data["summary_word_count"] = validation_data["summaries"].astype(str).apply(lambda x: len(x.split()))

test_data["article_word_count"] = test_data["articles"].astype(str).apply(lambda x: len(x.split()))
test_data["summary_word_count"] = test_data["summaries"].astype(str).apply(lambda x: len(x.split()))


In [7]:
# Lọc train_data
train_data = train_data[
    (train_data["article_word_count"] <= MAX_LENGTH_ARTICLE) & 
    (train_data["article_word_count"] >= MIN_LENGTH_ARTICLE) &
    (train_data["summary_word_count"] <= MAX_LENGTH_SUMMARY) &
    (train_data["summary_word_count"] >= MIN_LENGTH_SUMMARY)
]

# Lọc validation_data
validation_data = validation_data[
    (validation_data["article_word_count"] <= MAX_LENGTH_ARTICLE) & 
    (validation_data["article_word_count"] >= MIN_LENGTH_ARTICLE) &
    (validation_data["summary_word_count"] <= MAX_LENGTH_SUMMARY) &
    (validation_data["summary_word_count"] >= MIN_LENGTH_SUMMARY)
]

# Lọc test_data
test_data = test_data[
    (test_data["article_word_count"] <= MAX_LENGTH_ARTICLE) & 
    (test_data["article_word_count"] >= MIN_LENGTH_ARTICLE) &
    (test_data["summary_word_count"] <= MAX_LENGTH_SUMMARY) &
    (test_data["summary_word_count"] >= MIN_LENGTH_SUMMARY)
]

In [8]:
train_sample = train_data.sample(frac=FRAC_SAMPLE, random_state=1)
validation_sample = validation_data.sample(frac=FRAC_SAMPLE, random_state=1)
test_sample = test_data.sample(frac=FRAC_SAMPLE, random_state=1)
train_sample.info()
print("\n")
validation_sample.info()

<class 'pandas.core.frame.DataFrame'>
Index: 2880 entries, 144417 to 2426
Data columns (total 5 columns):
 #   Column              Non-Null Count  Dtype 
---  ------              --------------  ----- 
 0   id                  2880 non-null   object
 1   articles            2880 non-null   object
 2   summaries           2880 non-null   object
 3   article_word_count  2880 non-null   int64 
 4   summary_word_count  2880 non-null   int64 
dtypes: int64(2), object(3)
memory usage: 135.0+ KB


<class 'pandas.core.frame.DataFrame'>
Index: 149 entries, 8901 to 5720
Data columns (total 5 columns):
 #   Column              Non-Null Count  Dtype 
---  ------              --------------  ----- 
 0   id                  149 non-null    object
 1   articles            149 non-null    object
 2   summaries           149 non-null    object
 3   article_word_count  149 non-null    int64 
 4   summary_word_count  149 non-null    int64 
dtypes: int64(2), object(3)
memory usage: 7.0+ KB


In [9]:
EMBEDDING_FILE = "../Embedding/glove-wiki-gigaword-100.txt"
vocab, embeddings = [], []
with open(EMBEDDING_FILE, 'rt', encoding='utf-8') as ef:
    full_content = ef.read().strip().split('\n')
for i in range(len(full_content)):
    i_word = full_content[i].split(' ')[0]
    i_embeddings = [float(val) for val in full_content[i].split(' ')[1:]]
    i_embeddings.extend([0.0, 0.0, 0.0, 0.0])
    vocab.append(i_word)
    embeddings.append(i_embeddings)

embs_npa = np.array(embeddings)

unk_embedding = np.mean(embs_npa, axis=0).tolist()

dim = embs_npa.shape[1]
sos_embedding = [0.0] * dim
sos_embedding[-3] = 1.0
eos_embedding = [0.0] * dim
eos_embedding[-2] = 1.0
pad_embedding = [0.0] * dim
pad_embedding[-4] = 1.0
# unk_embedding = [0.0] * dim
# unk_embedding[-1] = 1.0

# Update vocab and embeddings
vocab = ["<PAD>", "<SOS>", "<EOS>", "<UNK>"] + vocab
embeddings = [pad_embedding, sos_embedding,
              eos_embedding, unk_embedding] + embeddings

vocab_npa = np.array(vocab)
embs_npa = np.array(embeddings)


def tokenize(text):
    return text.lower().strip().split()


stoi_dict = {word: idx for idx, word in enumerate(vocab_npa)}
_unk_idx = stoi_dict["<UNK>"]
itos = {idx: word for word, idx in stoi_dict.items()}

def stoi(string, stoi_dict=stoi_dict):
    return stoi_dict.get(string, _unk_idx)


def numericalize(text):
    tokenized_text = tokenize(text)
    return [
        stoi(token)
        for token in tokenized_text
    ]

print(embs_npa.shape[0])
embedding_layer = torch.nn.Embedding.from_pretrained(torch.FloatTensor(embeddings),
                                                     freeze=False,
                                                     padding_idx=stoi("<PAD>"))
embedding_layer.to(device)
vocab_size = len(vocab_npa)
print("Embedding shape:", np.array(embeddings).shape) 
print("<PAD> embedding last 4 dims:", embeddings[stoi("<PAD>")][-4:])
print("<SOS> embedding last 4 dims:", embeddings[stoi("<SOS>")][-4:])
print("Word 'the' embedding last 4 dims:", embeddings[stoi("the")][-4:])

25004
Embedding shape: (25004, 104)
<PAD> embedding last 4 dims: [1.0, 0.0, 0.0, 0.0]
<SOS> embedding last 4 dims: [0.0, 1.0, 0.0, 0.0]
Word 'the' embedding last 4 dims: [0.0, 0.0, 0.0, 0.0]


In [10]:
class Seq2SeqDataset(Dataset):
    def __init__(self, articles, summaries, stoi, max_len_article=MAX_LENGTH_ARTICLE, max_len_summary=MAX_LENGTH_SUMMARY):
        self.articles = articles  # List of articles
        self.summaries = summaries  # List of summaries
        self.stoi = stoi  # String-to-index dictionary
        self.pad_idx = stoi("<PAD>")
        self.sos_idx = stoi("<SOS>")
        self.eos_idx = stoi("<EOS>")
        
        # Determine max lengths if not provided
        self.max_len_article = max_len_article or max(len(a.split()) for a in articles) + 2
        self.max_len_summary = max_len_summary or max(len(s.split()) for s in summaries) + 2

    def __len__(self):
        return len(self.articles)

    def __getitem__(self, idx):
        def process_text(text, max_len):
            tokens = [self.sos_idx] + [self.stoi(w) for w in text.split()] + [self.eos_idx]  # Tokenize and add SOS/EOS
            tokens = tokens[:max_len] + [self.pad_idx] * (max_len - len(tokens))  # Pad to max length
            return torch.tensor(tokens), len(tokens)

        article_tokens, article_len = process_text(self.articles[idx], self.max_len_article)
        summary_tokens, summary_len = process_text(self.summaries[idx], self.max_len_summary)
        
        return {
            'article': article_tokens,  # Encoded article
            'article_len': torch.tensor(article_len),
            'summary': summary_tokens,  # Encoded summary
            'summary_len': torch.tensor(summary_len)
        }

def collate_fn(batch):
    # Batch is list os the dict {'article': ..., 'summary': ...}
    return {
        'article': torch.stack([item['article'] for item in batch]),
        'article_len': torch.tensor([item['article_len'] for item in batch]),
        'summary': torch.stack([item['summary'] for item in batch]),
        'summary_len': torch.tensor([item['summary_len'] for item in batch])
    }

# DataLoader setup
# torch.set_printoptions(profile="full")
torch.set_printoptions(profile="default")
train_dataset = Seq2SeqDataset(train_sample['articles'].tolist(), train_sample['summaries'].tolist(), stoi)
# print(train_dataset[268]["article"])
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

valid_dataset= Seq2SeqDataset(validation_sample['articles'].tolist(), validation_sample['summaries'].tolist(), stoi)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)


In [11]:

class SimpleEncoder(nn.Module):
    def __init__(self, embedding_layer, hidden_dim):
        super().__init__()
        self.embedding = embedding_layer
        self.lstm = nn.LSTM(
            input_size=self.embedding.embedding_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=False
        )
        self.dropout = nn.Dropout(0.1)
    
    def forward(self, x, seq_lens):
        x = self.dropout(self.embedding(x))
        packed = pack_padded_sequence(
            input=x,
            lengths=seq_lens.cpu(),
            batch_first=True, 
            enforce_sorted=False
        )
        packed_output, (hidden, cell) = self.lstm(packed)
        output, _ = pad_packed_sequence(packed_output, batch_first=True)
        return output, (hidden, cell)

class SimpleAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.energy = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1, bias=False))
        self.dropout = nn.Dropout(0.1)

    def forward(self, decoder_hidden, encoder_outputs, mask=None):
        # decoder_hidden: [batch, hidden]
        # encoder_outputs: [batch, seq_len, hidden]
        
        decoder_hidden = decoder_hidden.unsqueeze(1)  # [batch, 1, hidden]
        
        # Repeat decoder hidden state across sequence length
        decoder_hidden = decoder_hidden.expand(-1, encoder_outputs.size(1), -1)
        
        # Calculate attention scores
        energy_input = torch.cat([encoder_outputs, decoder_hidden], dim=2)
        scores = self.energy(self.dropout(energy_input)).squeeze(2)  # [batch, seq_len]
        
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e10)
            
        attn_weights = F.softmax(scores, dim=1)
        context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs).squeeze(1)
        
        return context, attn_weights

class SimpleDecoder(nn.Module):
    def __init__(self, embedding_layer, hidden_dim, vocab_size):
        super().__init__()
        self.embedding = embedding_layer
        self.attention = SimpleAttention(hidden_dim)
        self.lstm = nn.LSTMCell(
            input_size=self.embedding.embedding_dim + hidden_dim,
            hidden_size=hidden_dim
        )
        self.fc_out = nn.Linear(hidden_dim * 2, vocab_size)
        self.dropout = nn.Dropout(0.1)
        
        # Initialize weights
        nn.init.xavier_uniform_(self.fc_out.weight)
        if self.fc_out.bias is not None:
            self.fc_out.bias.data.fill_(0.01)

    def forward(self, x, prev_hidden, prev_cell, encoder_outputs, mask=None):
        x = self.dropout(self.embedding(x))  # [batch] -> [batch, emb_dim]
        
        if prev_hidden.dim() == 3:
            prev_hidden = prev_hidden[-1]  # Take last layer if multi-layer
        
        context, attn_weights = self.attention(prev_hidden, encoder_outputs, mask)
        
        # LSTM update
        lstm_input = torch.cat([x, context], dim=1)
        hidden, cell = self.lstm(lstm_input, (prev_hidden, prev_cell))
        
        # Output prediction
        output = self.fc_out(torch.cat([hidden, context], dim=1))
        return output, hidden, cell, attn_weights

class Seq2SeqModel(nn.Module):
    def __init__(self, embedding_layer, hidden_dim, vocab_size):
        super().__init__()
        self.encoder = SimpleEncoder(embedding_layer, hidden_dim)
        self.decoder = SimpleDecoder(embedding_layer, hidden_dim, vocab_size)
        self.vocab_size = vocab_size
        self.start_id = 2  # <SOS> token id
        self.end_id = 3    # <EOS> token id
        
        # Projection layers for encoder to decoder states
        self.hidden_proj = nn.Linear(hidden_dim, hidden_dim)
        self.cell_proj = nn.Linear(hidden_dim, hidden_dim)
        
        # Initialize projections
        for proj in [self.hidden_proj, self.cell_proj]:
            nn.init.xavier_uniform_(proj.weight)
            proj.bias.data.fill_(0.01)

    def forward(self, src, src_lens, trg=None, max_len=None, teacher_forcing_ratio=0.5):
        # Encoder forward
        enc_outputs, (hidden, cell) = self.encoder(src, src_lens)
        
        # Project encoder states to decoder space
        hidden = self.hidden_proj(hidden.squeeze(0))
        cell = self.cell_proj(cell.squeeze(0))
        
        # Determine max length
        batch_size = src.size(0)
        if trg is not None:
            max_len = trg.size(1)
        else:
            max_len = max_len if max_len is not None else 100
            
        # Initialize outputs tensor
        outputs = torch.zeros(batch_size, max_len, self.vocab_size).to(src.device)
        
        # First input is SOS token
        x = torch.full((batch_size,), self.start_id, dtype=torch.long, device=src.device)
        
        # Create mask from padding
        mask = (src != 0)  # Assuming 0 is pad_idx
        
        # Decoding loop
        for t in range(max_len):
            output, hidden, cell, _ = self.decoder(
                x=x,
                prev_hidden=hidden,
                prev_cell=cell,
                encoder_outputs=enc_outputs,
                mask=mask
            )
            outputs[:, t] = output
            
            # Decide next input
            if trg is not None and random.random() < teacher_forcing_ratio:
                x = trg[:, t]
            else:
                x = output.argmax(1)
                
            # Early stopping if all sequences predicted EOS
            if (x == self.end_id).all() and trg is None:
                outputs = outputs[:, :t+1]  # Chỉ trim khi inference
                break
        return outputs

In [12]:
wandb.init(
    project="Seq2Seq-Summarization",
    name=f"seq2seq-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
    config={
        "model": "Seq2Seq-LSTM",
        "hidden_dim": HIDDEN_DIM,
        "batch_size": BATCH_SIZE,
        "learning_rate": LEARNING_RATE,
        "teacher_forcing_ratio": 1.0,
        "vocab_size": len(vocab)
    }
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: Currently logged in as: [33mvubkk67[0m ([33mvubkk67-hanoi-university-of-science-and-technology[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [13]:
def train_model(model, train_loader, optimizer, criterion, device, scheduler=None):
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_loader, desc="Training", leave=False)
    
    for batch in progress_bar:
        src = batch['article'].to(device)
        src_lens = batch['article_len'].to(device)
        trg = batch['summary'].to(device)
        
        optimizer.zero_grad()
        
        # Forward pass with teacher forcing
        outputs = model(src, src_lens, trg=trg, teacher_forcing_ratio=0.5)
        
        # Calculate loss (ignore padding)
        min_len = min(outputs.size(1), trg.size(1))
        loss = criterion(
            outputs[:, 1:min_len].reshape(-1, outputs.size(-1)),
            trg[:, 1:min_len].reshape(-1)
)
        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        if scheduler is not None:
            scheduler.step()
        
        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())
    
    return total_loss / len(train_loader)

def evaluate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    progress_bar = tqdm(val_loader, desc="Evaluating", leave=False)
    
    with torch.no_grad():
        for batch in progress_bar:
            src = batch['article'].to(device)
            src_lens = batch['article_len'].to(device)
            trg = batch['summary'].to(device)
            
            outputs = model(src, src_lens, trg=trg, teacher_forcing_ratio=0)
            loss = criterion(
                outputs[:, 1:].reshape(-1, outputs.size(-1)),
                trg[:, 1:].reshape(-1)
            )
            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())
    
    return total_loss / len(val_loader)
# Save best Model
best_model_path = os.path.join(model_dir, "best_model_128.pth")
embedding_layer = torch.nn.Embedding.from_pretrained(
    torch.FloatTensor(embeddings),
    freeze=False,
    padding_idx=stoi("<PAD>")
).to(device)

# 3. Khởi tạo model
model = Seq2SeqModel(
    embedding_layer=embedding_layer,
    hidden_dim=HIDDEN_DIM,
    vocab_size=vocab_size
).to(device)

# 4. Train loop
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
criterion = nn.CrossEntropyLoss(ignore_index=stoi("<PAD>"))
best_val_loss = float('inf')
# Initialize learning rate scheduler
total_steps = NUM_EPOCHS * len(train_loader)
warmup_steps = int(0.1 * total_steps)  # 10% of total steps for warmup
lr_scheduler = get_scheduler(
    optimizer, 
    total_steps=total_steps,
    warmup_steps=warmup_steps,
    num_cycles=NUM_CYCLES
)
plateau_count = 0
wandb.watch(model)
for epoch in range(NUM_EPOCHS):
    start_time = time.time()
    
    # Train
    train_loss = train_model(model, train_loader, optimizer, criterion, device,lr_scheduler)
    
    # Eval
    val_loss = evaluate(model, valid_loader, criterion, device)
    
    current_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler else LEARNING_RATE
    # Log metrics
    wandb.log({
        "epoch": epoch+1,
        "train_loss": train_loss,
        "val_loss": val_loss,
        "best_val_loss": best_val_loss, 
        "lr": current_lr
    })
    
    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        plateau_count = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'lr_scheduler_state_dict': lr_scheduler.state_dict() if lr_scheduler else None,
        }, best_model_path)
    else:
        plateau_count += 1
    
    print(f"Epoch {epoch+1:02d} | "
          f"Train Loss: {train_loss:.4f} | "
          f"Val Loss: {val_loss:.4f} | "
          f"LR: {current_lr:.6f} | "
          f"Time: {time.time()-start_time:.2f}s")
    
    # Early stopping if validation loss doesn't improve
    if plateau_count >= MAX_PLATEAU_COUNT:
        print(f"Validation loss hasn't improved for {MAX_PLATEAU_COUNT} epochs. Stopping training.")
        break

# Kết thúc W&B
wandb.finish()

Training:   0%|          | 0/180 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 01 | Train Loss: 6.6745 | Val Loss: 5.4151 | LR: 0.001673 | Time: 160.84s


Training:   0%|          | 0/180 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 02 | Train Loss: 5.4335 | Val Loss: 5.3730 | LR: 0.003336 | Time: 162.59s


Training:   0%|          | 0/180 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 03 | Train Loss: 5.3354 | Val Loss: 5.4149 | LR: 0.005000 | Time: 162.04s


Training:   0%|          | 0/180 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 04 | Train Loss: 5.2417 | Val Loss: 5.3806 | LR: 0.004849 | Time: 160.41s


Training:   0%|          | 0/180 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 05 | Train Loss: 5.1803 | Val Loss: 5.3223 | LR: 0.004415 | Time: 160.01s


Training:   0%|          | 0/180 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 06 | Train Loss: 5.1361 | Val Loss: 5.2705 | LR: 0.003750 | Time: 159.68s


Training:   0%|          | 0/180 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 07 | Train Loss: 5.0911 | Val Loss: 5.2885 | LR: 0.002934 | Time: 159.59s


Training:   0%|          | 0/180 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 08 | Train Loss: 5.0549 | Val Loss: 5.2958 | LR: 0.002066 | Time: 159.91s


Training:   0%|          | 0/180 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 09 | Train Loss: 5.0203 | Val Loss: 5.2875 | LR: 0.001250 | Time: 159.32s


Training:   0%|          | 0/180 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 10 | Train Loss: 4.9838 | Val Loss: 5.2464 | LR: 0.000585 | Time: 159.78s


Training:   0%|          | 0/180 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 11 | Train Loss: 4.9438 | Val Loss: 5.2137 | LR: 0.000151 | Time: 159.51s


Training:   0%|          | 0/180 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 12 | Train Loss: 4.9098 | Val Loss: 5.2571 | LR: 0.005000 | Time: 162.13s


Training:   0%|          | 0/180 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 13 | Train Loss: 5.0851 | Val Loss: 5.3083 | LR: 0.004849 | Time: 159.25s


Training:   0%|          | 0/180 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 14 | Train Loss: 5.0386 | Val Loss: 5.2852 | LR: 0.004415 | Time: 158.96s


Training:   0%|          | 0/180 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 15 | Train Loss: 5.0008 | Val Loss: 5.2401 | LR: 0.003750 | Time: 158.71s


Training:   0%|          | 0/180 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 16 | Train Loss: 4.9586 | Val Loss: 5.2612 | LR: 0.002934 | Time: 158.88s
Validation loss hasn't improved for 5 epochs. Stopping training.


VBox(children=(Label(value='0.037 MB of 0.037 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
best_val_loss,█▇▇▇▅▃▃▃▃▂▁▁▁▁▁
epoch,▁▁▂▂▃▃▄▄▅▅▆▆▇▇██
lr,▃▆██▇▆▅▄▃▂▁██▇▆▅
train_loss,█▃▃▂▂▂▂▂▁▁▁▁▂▂▁▁
val_loss,█▇█▇▅▃▄▄▄▂▁▃▄▃▂▃

0,1
best_val_loss,5.21369
epoch,16.0
lr,0.00293
train_loss,4.95862
val_loss,5.26122


In [15]:
def decode_indices(indices, itos):
    tokens = []
    for idx in indices:
        token = itos.get(idx.item(), "<UNK>")
        if token == "<EOS>":
            break
        if token not in {"<SOS>", "<PAD>"}:
            tokens.append(token)
    return " ".join(tokens)


In [16]:
# Load lại model
checkpoint = torch.load(best_model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()


  checkpoint = torch.load(best_model_path, map_location=device)


Seq2SeqModel(
  (encoder): SimpleEncoder(
    (embedding): Embedding(25004, 104, padding_idx=0)
    (lstm): LSTM(104, 128, batch_first=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (decoder): SimpleDecoder(
    (embedding): Embedding(25004, 104, padding_idx=0)
    (attention): SimpleAttention(
      (energy): Sequential(
        (0): Linear(in_features=256, out_features=128, bias=True)
        (1): Tanh()
        (2): Linear(in_features=128, out_features=1, bias=False)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (lstm): LSTMCell(232, 128)
    (fc_out): Linear(in_features=256, out_features=25004, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (hidden_proj): Linear(in_features=128, out_features=128, bias=True)
  (cell_proj): Linear(in_features=128, out_features=128, bias=True)
)

In [17]:
# Lấy một sample
test_article = test_sample.iloc[0]['articles']
true_summary = test_sample.iloc[0]['summaries']

# Tiền xử lý như trong Dataset
tokens = [stoi("<SOS>")] + [stoi(w) for w in test_article.split()] + [stoi("<EOS>")]
tokens = tokens[:MAX_LENGTH_ARTICLE] + [stoi("<PAD>")] * (MAX_LENGTH_ARTICLE - len(tokens))
src_tensor = torch.tensor(tokens).unsqueeze(0).to(device)
src_len = torch.tensor([len(tokens)]).to(device)

# Dự đoán
with torch.no_grad():
    output = model(src_tensor, src_len, trg=None, teacher_forcing_ratio=0.0)

# Lấy chuỗi dự đoán
pred_indices = output.argmax(dim=-1).squeeze(0)
pred_summary = decode_indices(pred_indices, itos)

# In kết quả
print("\n📰 Input Article:\n", test_article)
print("\n✅ True Summary:\n", true_summary)
print("\n🤖 Predicted Summary:\n", pred_summary)



📰 Input Article:

✅ True Summary:
 Joel Parker, 33, was riding the bus in St Johns County, Florida .
Police said he threatened the driver and was disruptive during the ride .
As he got off the bus he offered the candy bar to the driver, who declined .
He was arrested for battery and is never allowed to ride the bus again .

🤖 Predicted Summary:
 <UNK>
