In [1]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import torch.nn.functional as F
from tqdm.auto import tqdm
from datetime import datetime
import wandb
import time
import os

from rouge import Rouge
from bert_score import score as bert_score
from nltk.translate.meteor_score import single_meteor_score
# Import Hugging Face tokenizer
from transformers import BartTokenizer

# nltk.download('punkt') # No longer needed with HF tokenizer

NUM_EPOCHS = 1
BATCH_SIZE = 4
FRAC_SAMPLE = 0.2
MAX_LENGTH_ARTICLE = 512
MIN_LENGTH_ARTICLE = 50 # Keep for data filtering, but tokenizer handles max length
MAX_LENGTH_SUMMARY = 128
MIN_LENGTH_SUMMARY = 20 # Keep for data filtering
HIDDEN_DIM = 512
LEARNING_RATE = 0.00005 # Reduced learning rate
NUM_CYCLES = 3
MAX_PLATEAU_COUNT = 5
WEIGHT_DECAY = 1e-4
CLIP = 1
USE_PRETRAINED_EMB = False # Not using pre-trained GloVe embeddings anymore
USE_SCHEDULER = True
SCHEDULER_TYPE = "warmup_cosine_with_restarts"
TEACHER_FORCING_RATIO = 0.75
NUM_CYCLES = 3
MAX_PLATEAU_COUNT = 5


# model_dir = "../Model"
datafilter = "../dataft"
os.makedirs(datafilter, exist_ok=True)
# os.makedirs(model_dir, exist_ok=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
train_data = pd.read_csv("../dataset/train.csv")
validation_data = pd.read_csv("../dataset/validation.csv")
test_data = pd.read_csv("../dataset/test.csv")

# add col
train_data.rename(columns={"highlights": "summaries", "article":"articles"}, inplace=True)
validation_data.rename(columns={"highlights": "summaries","article":"articles"}, inplace=True)
test_data.rename(columns={"highlights": "summaries", "article":"articles"}, inplace=True)

train_data["article_word_count"] = train_data["articles"].astype(str).apply(lambda x: len(x.split()))
train_data["summary_word_count"] = train_data["summaries"].astype(str).apply(lambda x: len(x.split()))

validation_data["article_word_count"] = validation_data["articles"].astype(str).apply(lambda x: len(x.split()))
validation_data["summary_word_count"] = validation_data["summaries"].astype(str).apply(lambda x: len(x.split()))

test_data["article_word_count"] = test_data["articles"].astype(str).apply(lambda x: len(x.split()))
test_data["summary_word_count"] = test_data["summaries"].astype(str).apply(lambda x: len(x.split()))

# filter range
train_data = train_data[
    (train_data["article_word_count"] <= MAX_LENGTH_ARTICLE) &
    (train_data["article_word_count"] >= MIN_LENGTH_ARTICLE) &
    (train_data["summary_word_count"] <= MAX_LENGTH_SUMMARY) &
    (train_data["summary_word_count"] >= MIN_LENGTH_SUMMARY)
]

validation_data = validation_data[
    (validation_data["article_word_count"] <= MAX_LENGTH_ARTICLE) &
    (validation_data["article_word_count"] >= MIN_LENGTH_ARTICLE) &
    (validation_data["summary_word_count"] <= MAX_LENGTH_SUMMARY) &
    (validation_data["summary_word_count"] >= MIN_LENGTH_SUMMARY)
]
test_data = test_data[
    (test_data["article_word_count"] <= MAX_LENGTH_ARTICLE) &
    (test_data["article_word_count"] >= MIN_LENGTH_ARTICLE) &
    (test_data["summary_word_count"] <= MAX_LENGTH_SUMMARY) &
    (test_data["summary_word_count"] >= MIN_LENGTH_SUMMARY)
]

train_sample = train_data.sample(frac=FRAC_SAMPLE, random_state=1)
validation_sample = validation_data.sample(frac=FRAC_SAMPLE, random_state=1)
test_sample = test_data.sample(frac=1, random_state=1)

train_sample.to_csv(os.path.join(datafilter,"train_sample.csv"), index=False)
test_sample.to_csv(os.path.join(datafilter,"test_sample.csv"), index=False)
validation_sample.to_csv(os.path.join(datafilter,"validation_sample.csv"), index=False)

In [3]:
train_sample.info()
test_sample.info()

<class 'pandas.core.frame.DataFrame'>
Index: 19198 entries, 144417 to 201560
Data columns (total 5 columns):
 #   Column              Non-Null Count  Dtype 
---  ------              --------------  ----- 
 0   id                  19198 non-null  object
 1   articles            19198 non-null  object
 2   summaries           19198 non-null  object
 3   article_word_count  19198 non-null  int64 
 4   summary_word_count  19198 non-null  int64 
dtypes: int64(2), object(3)
memory usage: 899.9+ KB
<class 'pandas.core.frame.DataFrame'>
Index: 4224 entries, 9204 to 591
Data columns (total 5 columns):
 #   Column              Non-Null Count  Dtype 
---  ------              --------------  ----- 
 0   id                  4224 non-null   object
 1   articles            4224 non-null   object
 2   summaries           4224 non-null   object
 3   article_word_count  4224 non-null   int64 
 4   summary_word_count  4224 non-null   int64 
dtypes: int64(2), object(3)
memory usage: 198.0+ KB


In [4]:
# Load pre-trained tokenizer
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

vocab_size = len(tokenizer)

In [5]:

pad_token_id = tokenizer.pad_token_id
unk_token_id = tokenizer.unk_token_id
sos_token_id = tokenizer.bos_token_id
eos_token_id = tokenizer.eos_token_id

print(f"Vocabulary size: {vocab_size}")
print(f"PAD token ID: {pad_token_id}")
print(f"UNK token ID: {unk_token_id}")
print(f"SOS token ID: {sos_token_id}")
print(f"EOS token ID: {eos_token_id}")

Vocabulary size: 50265
PAD token ID: 1
UNK token ID: 3
SOS token ID: 0
EOS token ID: 2


In [6]:
class Seq2SeqDataset(Dataset):
    def __init__(self, articles, summaries, tokenizer, max_len_article=MAX_LENGTH_ARTICLE, max_len_summary=MAX_LENGTH_SUMMARY):
        self.articles = articles
        self.summaries = summaries
        self.tokenizer = tokenizer
        self.max_len_article = max_len_article
        self.max_len_summary = max_len_summary

    def __len__(self):
        return len(self.articles)

    def __getitem__(self, idx):
        article = str(self.articles[idx])
        summary = str(self.summaries[idx])

        # Tokenize and encode article
        article_encoding = self.tokenizer(
            article,
            max_length=self.max_len_article,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        # Tokenize and encode summary
        summary_encoding = self.tokenizer(
            summary,
            max_length=self.max_len_summary,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'article': article_encoding['input_ids'].squeeze(0), # Remove batch dimension
            'article_attention_mask': article_encoding['attention_mask'].squeeze(0), # Remove batch dimension
            'summary': summary_encoding['input_ids'].squeeze(0), # Remove batch dimension
            'summary_attention_mask': summary_encoding['attention_mask'].squeeze(0) # Remove batch dimension
        }

In [7]:
def collate_fn(batch):
    # Batch is a list of dictionaries from __getitem__
    return {
        'article': torch.stack([item['article'] for item in batch]),
        'article_attention_mask': torch.stack([item['article_attention_mask'] for item in batch]),
        'summary': torch.stack([item['summary'] for item in batch]),
        'summary_attention_mask': torch.stack([item['summary_attention_mask'] for item in batch])
    }

In [8]:
# DataLoader setup

train_dataset = Seq2SeqDataset(train_sample['articles'].tolist(), train_sample['summaries'].tolist(), tokenizer)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

validation_dataset= Seq2SeqDataset(validation_sample['articles'].tolist(), validation_sample['summaries'].tolist(), tokenizer)
validation_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

In [9]:

import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x = [seq_len, batch_size, d_model]
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class Encoder(nn.Module):
    def __init__(self, vocab_size, hid_dim, n_layers, n_heads, pf_dim, dropout, device, max_length=512):
        super().__init__()

        self.device = device

        self.embedding = nn.Embedding(vocab_size, hid_dim)
        self.pos_encoder = PositionalEncoding(hid_dim, dropout, max_length)

        encoder_layers = nn.TransformerEncoderLayer(
            d_model=hid_dim, 
            nhead=n_heads, 
            dim_feedforward=pf_dim, 
            dropout=dropout,
            batch_first=False
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, n_layers)

        self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)

    def forward(self, src, src_mask=None):
        # src = [batch_size, src_len]
        # src_mask = [batch_size, src_len] (boolean mask, True for padding positions)
        
        # Convert to [src_len, batch_size]
        src = src.transpose(0, 1)  
        
        # Embedding and positional encoding
        # [src_len, batch_size, hid_dim]
        embedded = self.dropout(self.embedding(src) * self.scale)
        src = self.pos_encoder(embedded)
        
        # Create src_key_padding_mask for transformer
        # It should be a boolean mask where True indicates padding positions
        src_key_padding_mask = src_mask if src_mask is not None else None
        
        # Pass through transformer encoder
        # [src_len, batch_size, hid_dim]
        encoder_output = self.transformer_encoder(src, src_key_padding_mask=src_key_padding_mask)
        
        # Return to [batch_size, src_len, hid_dim]
        return encoder_output.transpose(0, 1)

class Decoder(nn.Module):
    def __init__(self, vocab_size, hid_dim, n_layers, n_heads, pf_dim, dropout, device, max_length=128):
        super().__init__()

        self.device = device

        self.embedding = nn.Embedding(vocab_size, hid_dim)
        self.pos_encoder = PositionalEncoding(hid_dim, dropout, max_length)

        decoder_layers = nn.TransformerDecoderLayer(
            d_model=hid_dim, 
            nhead=n_heads, 
            dim_feedforward=pf_dim, 
            dropout=dropout,
            batch_first=False
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layers, n_layers)

        self.fc_out = nn.Linear(hid_dim, vocab_size)
        self.dropout = nn.Dropout(dropout)
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim])).to(device)

    def forward(self, trg, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        # trg = [batch_size, trg_len]
        # memory = [batch_size, src_len, hid_dim]
        
        # Convert to [trg_len, batch_size]
        trg = trg.transpose(0, 1)
        
        # Convert memory to [src_len, batch_size, hid_dim]
        memory = memory.transpose(0, 1)
        
        # Embedding and positional encoding
        # [trg_len, batch_size, hid_dim]
        embedded = self.dropout(self.embedding(trg) * self.scale)
        trg = self.pos_encoder(embedded)
        
        # Pass through transformer decoder
        # [trg_len, batch_size, hid_dim]
        decoder_output = self.transformer_decoder(
            trg, 
            memory, 
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask
        )
        
        # Output projection
        # [trg_len, batch_size, vocab_size]
        output = self.fc_out(decoder_output)
        
        # Return to [batch_size, trg_len, vocab_size]
        return output.transpose(0, 1)

class Seq2SeqTransformer(nn.Module):
    def __init__(self, encoder, decoder, pad_idx, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.pad_idx = pad_idx
        self.device = device
        self.sos_idx = tokenizer.bos_token_id
        self.eos_idx = tokenizer.eos_token_id

    def make_src_mask(self, src):
        src_mask = (src == self.pad_idx)
        return src_mask
    
    def make_trg_mask(self, trg):
        trg_key_padding_mask = (trg == self.pad_idx)
        trg_len = trg.shape[1]
        tgt_mask = torch.triu(torch.ones((trg_len, trg_len), device=self.device), diagonal=1).bool()
        return tgt_mask, trg_key_padding_mask

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        tgt_mask, tgt_key_padding_mask = self.make_trg_mask(trg)
        
        encoder_output = self.encoder(src, src_mask)
        output = self.decoder(
            trg=trg, 
            memory=encoder_output, 
            tgt_mask=tgt_mask, 
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=src_mask
        )
        
        return output
    
    def generate(self, input_ids, attention_mask=None, max_length=128, num_beams=4, length_penalty=2.0, early_stopping=True):
        batch_size = input_ids.shape[0]
        
        # Encoder pass
        encoder_output = self.encoder(input_ids, (input_ids == self.pad_idx))
        
        # Initialize decoder input with SOS tokens
        decoder_input = torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * self.sos_idx
        
        # Use beam search for generation
        if num_beams > 1:
            return self._generate_beam_search(
                encoder_output=encoder_output,
                encoder_mask=(input_ids == self.pad_idx),
                start_token_id=self.sos_idx,
                end_token_id=self.eos_idx,
                max_length=max_length,
                num_beams=num_beams,
                length_penalty=length_penalty,
                early_stopping=early_stopping
            )
        
        # Greedy search as fallback
        return self._generate_greedy(
            encoder_output=encoder_output,
            encoder_mask=(input_ids == self.pad_idx),
            start_token_id=self.sos_idx,
            end_token_id=self.eos_idx,
            max_length=max_length
        )
    
    def _generate_greedy(self, encoder_output, encoder_mask, start_token_id, end_token_id, max_length):
        batch_size = encoder_output.shape[0]
        
        # Initialize decoder input with start token
        decoder_input = torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * start_token_id
        
        # Track which sequences are completed
        completed_sequences = torch.zeros(batch_size, dtype=torch.bool, device=self.device)
        
        for _ in range(max_length - 1):
            # Make target mask for current decoder input
            tgt_mask, tgt_key_padding_mask = self.make_trg_mask(decoder_input)
            
            # Decode one step
            decoder_output = self.decoder(
                trg=decoder_input,
                memory=encoder_output,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=encoder_mask
            )
            
            # Get next token prediction
            next_token_logits = decoder_output[:, -1, :]
            next_token = next_token_logits.argmax(dim=-1, keepdim=True)
            
            # Concatenate to decoder input
            decoder_input = torch.cat([decoder_input, next_token], dim=1)
            
            # Check for EOS
            completed_sequences = completed_sequences | (next_token.squeeze(-1) == end_token_id)
            if completed_sequences.all():
                break
        
        return decoder_input
    
    def _generate_beam_search(self, encoder_output, encoder_mask, start_token_id, end_token_id, max_length, num_beams, length_penalty, early_stopping):
        batch_size = encoder_output.shape[0]
        
        # Expand encoder output for beam search
        # [batch_size, seq_len, hidden] -> [batch_size * num_beams, seq_len, hidden]
        encoder_output = encoder_output.unsqueeze(1).expand(-1, num_beams, -1, -1).reshape(batch_size * num_beams, -1, encoder_output.shape[-1])
        
        # Expand encoder mask
        # [batch_size, seq_len] -> [batch_size * num_beams, seq_len]
        encoder_mask = encoder_mask.unsqueeze(1).expand(-1, num_beams, -1).reshape(batch_size * num_beams, -1)
        
        # Start with beams of SOS tokens
        current_tokens = torch.ones((batch_size * num_beams, 1), dtype=torch.long, device=self.device) * start_token_id
        
        # Track beam scores
        beam_scores = torch.zeros((batch_size, num_beams), device=self.device)
        beam_scores[:, 1:] = float('-inf')  # Only first beam is active initially
        beam_scores = beam_scores.view(-1)  # [batch_size * num_beams]
        
        # Track completed sequences and scores
        done_sequences = []
        done_scores = []
        done = [False for _ in range(batch_size)]
        
        for step in range(max_length - 1):
            # Make target mask for current decoder input
            tgt_mask, tgt_key_padding_mask = self.make_trg_mask(current_tokens)
            
            # Get decoder output
            decoder_output = self.decoder(
                trg=current_tokens,
                memory=encoder_output,
                tgt_mask=tgt_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=encoder_mask
            )
            
            # Get log probabilities for next token
            vocab_size = decoder_output.shape[-1]
            next_token_logits = decoder_output[:, -1, :]
            logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)  # [batch_size * num_beams, vocab_size]
            
            # Calculate new scores
            # [batch_size * num_beams, vocab_size]
            next_scores = beam_scores.unsqueeze(1) + logprobs
            next_scores = next_scores.view(batch_size, num_beams * vocab_size)
            
            # Get the best 2*num_beams candidates
            # [batch_size, 2*num_beams]
            topk_scores, topk_indices = next_scores.topk(2 * num_beams, dim=1, largest=True, sorted=True)
            
            # Process each batch item
            next_tokens = []
            next_scores = []
            
            for batch_idx in range(batch_size):
                # Skip if this batch item is already done
                if done[batch_idx]:
                    next_tokens.extend([current_tokens[batch_idx * num_beams] for _ in range(num_beams)])
                    next_scores.extend([beam_scores[batch_idx * num_beams] for _ in range(num_beams)])
                    continue
                
                # Find which beams and tokens to keep
                beam_indices = []
                token_indices = []
                
                for score_idx, (token_idx, score) in enumerate(zip(topk_indices[batch_idx], topk_scores[batch_idx])):
                    # Convert to beam index and token index
                    beam_idx = token_idx // vocab_size
                    token = token_idx % vocab_size
                    
                    # Add to candidates if not already filled
                    if len(beam_indices) < num_beams:
                        beam_indices.append(beam_idx)
                        token_indices.append(token)
                
                # Create new token sequences
                for beam_idx, token in zip(beam_indices, token_indices):
                    # Get current token sequence for this beam
                    token_seq = current_tokens[batch_idx * num_beams + beam_idx].clone()
                    # Add new token
                    new_seq = torch.cat([token_seq, token.unsqueeze(0)], dim=0)
                    next_tokens.append(new_seq)
                    
                    # Apply length penalty
                    lp = ((5 + len(new_seq)) / 6) ** length_penalty
                    next_scores.append(beam_scores[batch_idx * num_beams + beam_idx] + logprobs[batch_idx * num_beams + beam_idx, token] / lp)
                    
                    # Check if sequence is done
                    if token == end_token_id:
                        done_sequences.append(new_seq)
                        done_scores.append(next_scores[-1])
                        
                # Check if all sequences for this batch are done
                if len(done_sequences) >= num_beams and early_stopping:
                    done[batch_idx] = True
            
            # Update current tokens and beam scores
            current_tokens = torch.stack(next_tokens).view(batch_size * num_beams, -1)
            beam_scores = torch.tensor(next_scores, device=self.device).view(batch_size * num_beams)
            
            # Check if all batches are done
            if all(done):
                break
        
        # If no sequences completed, use the current ones
        if not done_sequences:
            done_sequences = [current_tokens[i] for i in range(batch_size * num_beams)]
            done_scores = beam_scores.tolist()
        
        # Get best sequence for each batch
        result = []
        for i in range(batch_size):
            # Find best sequence for this batch
            best_seq = None
            best_score = float('-inf')
            
            for seq, score in zip(done_sequences, done_scores):
                if score > best_score:
                    best_score = score
                    best_seq = seq
            
            result.append(best_seq)
        
        return torch.stack(result)


In [10]:
PAD_IDX = tokenizer.pad_token_id
UNK_IDX = tokenizer.unk_token_id
SOS_IDX = tokenizer.bos_token_id
EOS_IDX = tokenizer.eos_token_id

# Model Hyperparameters (can be adjusted)
ENC_LAYERS = 3
DEC_LAYERS = 3
ENC_HEADS = 8
DEC_HEADS = 8
ENC_PF_DIM = 512 
DEC_PF_DIM = 512
ENC_DROPOUT = 0.1
DEC_DROPOUT = 0.1

# Instantiate the Transformer model
encoder = Encoder(vocab_size, HIDDEN_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, ENC_DROPOUT, device, MAX_LENGTH_ARTICLE)
decoder = Decoder(vocab_size, HIDDEN_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, DEC_DROPOUT, device, MAX_LENGTH_SUMMARY)

model = Seq2SeqTransformer(encoder, decoder, PAD_IDX, device).to(device)

# Define optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
criterion = nn.CrossEntropyLoss(ignore_index = PAD_IDX)

# Define scheduler
# Calculate total training steps for the scheduler
total_steps = len(train_loader) * NUM_EPOCHS
warmup_steps = int(0.1 * total_steps) # 10% of total steps for warmup

def linear_warmup_decay(step, warmup_steps, total_steps):
    if step < warmup_steps:
        return (step + 1) / (warmup_steps + 1)
    else:
        return max(1e-7, (total_steps - step) / (total_steps - warmup_steps))

def warmup_cosine_with_restarts(step, warmup_steps, total_steps, num_cycles=1):
    if step < warmup_steps:
        return (step + 1) / (warmup_steps + 1)
    else:
        progress = (step - warmup_steps) / (total_steps - warmup_steps)
        cycle_progress = progress * num_cycles % 1
        return max(1e-7, 0.5 * (1 + math.cos(math.pi * cycle_progress)))

def get_scheduler(optimizer, total_steps, warmup_steps, num_cycles=None, types='warmup_cosine_with_restarts'):
    if types == 'warmup_cosine_with_restarts':
        assert num_cycles != None, 'must specify num_cycles when types="warmup_cosine_with_restarts"'
        return torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: warmup_cosine_with_restarts(
                step, warmup_steps, total_steps, num_cycles=num_cycles)
        )
    elif types == 'linear_warmup_decay':
        return torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: linear_warmup_decay(step, warmup_steps, total_steps)
        )
    else:
        raise Exception('not implemented')

scheduler = get_scheduler(
    optimizer,
    total_steps=total_steps,
    warmup_steps=warmup_steps,
    num_cycles=NUM_CYCLES, # Used for warmup_cosine_with_restarts
    types=SCHEDULER_TYPE
)

print(f"The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters")
print(f"GPU Memory allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB")
print(f"GPU Memory reserved: {torch.cuda.memory_reserved()/1024**2:.2f} MB")
print("Model architecture:")
print(model)



The model has 89,880,153 trainable parameters
GPU Memory allocated: 344.12 MB
GPU Memory reserved: 364.00 MB
Model architecture:
Seq2SeqTransformer(
  (encoder): Encoder(
    (embedding): Embedding(50265, 512)
    (pos_encoder): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer_encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-2): 3 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=512, out_features=512, bias=True)
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dro

In [11]:
# # W&B setup
# wandb.init(
#     project="Seq2Seq-Summarization",
#     name=f"transformer-scratch-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
#     config={
#         "model": "Seq2Seq-Transformer-Scratch",
#         "hidden_dim": HIDDEN_DIM,
#         "batch_size": BATCH_SIZE,
#         "learning_rate": LEARNING_RATE,
#         "num_epochs": NUM_EPOCHS,
#         "encoder_layers": ENC_LAYERS,
#         "decoder_layers": DEC_LAYERS,
#         "encoder_heads": ENC_HEADS,
#         "decoder_heads": DEC_HEADS,
#         "encoder_pf_dim": ENC_PF_DIM,
#         "decoder_pf_dim": DEC_PF_DIM,
#         "encoder_dropout": ENC_DROPOUT,
#         "decoder_dropout": DEC_DROPOUT,
#         "weight_decay": WEIGHT_DECAY,
#         "clip": CLIP,
#         "scheduler_type": SCHEDULER_TYPE,
#         "warmup_steps": warmup_steps,
#         "num_cycles": NUM_CYCLES, # Relevant for cosine scheduler
#         "vocab_size": vocab_size,
#         "max_length_article": MAX_LENGTH_ARTICLE,
#         "max_length_summary": MAX_LENGTH_SUMMARY,
#         "teacher_forcing_ratio": TEACHER_FORCING_RATIO # Added teacher forcing ratio
#     }
# )

# wandb.watch(model)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mvubkk67[0m ([33mvubkk67-hanoi-university-of-science-and-technology[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [12]:
# # Save best model directory
# # save_dir = os.path.join(model_dir, 'transformer_scratch_best_model')
# save_dir = 'transformer_scratch_best_model'
# os.makedirs(save_dir, exist_ok=True)

# # Training loop
# best_val_loss = float("inf")

# for epoch in range(NUM_EPOCHS):
#     start_time = time.time()

#     model.train()
#     train_loss = 0.0

#     progress_bar_train = tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]")
#     for batch in progress_bar_train:
#         src = batch['article'].to(device)
#         trg = batch['summary'].to(device)

#         optimizer.zero_grad()

#         # The target input for the decoder should exclude the <eos> token
#         # Pass the correct masks: trg_lookahead_mask, src_padding_mask, trg_padding_mask
#         output = model(src, trg[:, :-1])

#         # output = [batch_size, trg_len - 1, output_dim]
#         # trg = [batch_size, trg_len]

#         output_dim = output.shape[-1]

#         # Reshape for loss calculation
#         output = output.contiguous().view(-1, output_dim)
#         trg = trg[:, 1:].contiguous().view(-1) # The target output should exclude the <sos> token

#         loss = criterion(output, trg)

#         # Check for NaN loss
#         if torch.isnan(loss):
#             print("NaN loss detected! Stopping training.")
#             break # Exit the inner batch loop

#         loss.backward()

#         torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)

#         optimizer.step()
#         scheduler.step()

#         train_loss += loss.item()
#         progress_bar_train.set_postfix(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])

#     # If NaN loss was detected in the inner loop, break the outer epoch loop as well
#     if torch.isnan(torch.tensor(train_loss)):
#         break

#     train_loss /= len(train_loader)

#     # Validation
#     model.eval()
#     val_loss = 0.0
#     with torch.no_grad():
#         progress_bar_val = tqdm(validation_loader, desc=f"Epoch {epoch+1} [Val]")
#         for batch in progress_bar_val:
#             src = batch['article'].to(device)
#             trg = batch['summary'].to(device)

#             # The target input for the decoder should exclude the <eos> token
#             # Pass the correct masks: trg_lookahead_mask, src_padding_mask, trg_padding_mask
#             output = model(src, trg[:, :-1])

#             output_dim = output.shape[-1]

#             # Reshape for loss calculation
#             output = output.contiguous().view(-1, output_dim)
#             trg = trg[:, 1:].contiguous().view(-1) # The target output should exclude the <sos> token

#             loss = criterion(output, trg)

#             val_loss += loss.item()
#             progress_bar_val.set_postfix(loss=loss.item())

#         val_loss /= len(validation_loader)
#         current_lr = optimizer.param_groups[0]['lr'] # Get current LR after scheduler step

#     # W&B log
#     wandb.log({
#         "epoch": epoch + 1,
#         "train_loss": train_loss,
#         "val_loss": val_loss,
#         "lr": current_lr,
#         "best_val_loss": best_val_loss # Log best val loss seen so far
#     })

#     # Save best model and tokenizer
#     if val_loss < best_val_loss:
#         best_val_loss = val_loss
#         torch.save(model.state_dict(), os.path.join(save_dir, 'best_transformer_model.pt'))
#         tokenizer.save_pretrained(save_dir)
#         print(f"Saved best model and tokenizer to `{save_dir}` at epoch {epoch+1}")

#     print(
#         f"Epoch {epoch+1:02d} | "
#         f"Train Loss: {train_loss:.4f} | "
#         f"Val Loss: {val_loss:.4f} | "
#         f"LR: {current_lr:.6f} | "
#         f"Time: {time.time() - start_time:.2f}s"
#     )

# # W&B end
# wandb.finish()

Epoch 1 [Train]:   0%|          | 0/4800 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [13]:
model_dir = "../Model"

In [14]:
output_path = os.path.join(datafilter, "test_pred_6.csv") 
save_dir = os.path.join(model_dir, 'transformer_scratch_best_model')
print(save_dir)
tokenizer = BartTokenizer.from_pretrained(save_dir)

../Model/transformer_scratch_best_model


In [15]:
test_dataset = Seq2SeqDataset(test_sample['articles'].tolist(), test_sample['summaries'].tolist(), tokenizer)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=2, shuffle=False, collate_fn=collate_fn)

PAD_IDX = tokenizer.pad_token_id
UNK_IDX = tokenizer.unk_token_id
SOS_IDX = tokenizer.bos_token_id
EOS_IDX = tokenizer.eos_token_id
encoder = Encoder(vocab_size, HIDDEN_DIM, ENC_LAYERS, ENC_HEADS, ENC_PF_DIM, ENC_DROPOUT, device, MAX_LENGTH_ARTICLE)
decoder = Decoder(vocab_size, HIDDEN_DIM, DEC_LAYERS, DEC_HEADS, DEC_PF_DIM, DEC_DROPOUT, device, MAX_LENGTH_SUMMARY)

model = Seq2SeqTransformer(encoder, decoder, PAD_IDX, device)




In [16]:
# Load the model from saved state
model.load_state_dict(torch.load(os.path.join(save_dir, 'best_transformer_model.pt'), map_location=device))
model.to(device)
model.eval()

print("Generating summaries...")
predictions = []
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Generating summaries"):
        input_ids = batch['article'].to(device)
        attention_mask = batch['article_attention_mask'].to(device)
        
        # Instead of using beam search which has issues, use greedy search
        # by setting num_beams=1
        output_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=MAX_LENGTH_SUMMARY,
            num_beams=1,  # Use greedy search instead of beam search
            length_penalty=2.0,
            early_stopping=True
        )
        batch_preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        predictions.extend(batch_preds)

# Save predictions
print(f"Generated {len(predictions)} summaries")
test_sample = test_sample.iloc[:len(predictions)].copy()
test_sample["predicted_summary"] = predictions
test_sample.to_csv(output_path, index=False)
print(f"✅ File has been saved at: {output_path}")

  model.load_state_dict(torch.load(os.path.join(save_dir, 'best_transformer_model.pt'), map_location=device))


Generating summaries...


Generating summaries:   0%|          | 0/2112 [00:00<?, ?it/s]

Generated 4224 summaries
✅ File has been saved at: ../dataft/test_pred_6.csv


In [17]:

# Calculate metrics
print("Calculating evaluation metrics...")
test_pred = pd.read_csv(output_path)
display(test_pred[["articles","summaries", "predicted_summary"]].head(2))

# Check required columns
if "summaries" in test_pred.columns and "predicted_summary" in test_pred.columns:
    references = test_pred["summaries"].fillna("<empty>").astype(str).tolist()
    predictions = test_pred["predicted_summary"].fillna("<empty>").astype(str).tolist()

    # Filter valid pairs
    valid_pairs = [
        (pred, ref) for pred, ref in zip(predictions, references)
        if pred.strip() and pred != "<empty>" and ref.strip()
    ]
    
    if not valid_pairs:
        print("No valid pairs found for metric calculation.")
    else:
        filtered_preds, filtered_refs = zip(*valid_pairs)

        # ROUGE
        rouge = Rouge()
        rouge_scores = rouge.get_scores(filtered_preds, filtered_refs, avg=True)
        print("ROUGE scores:")
        print(f"ROUGE-1: {rouge_scores['rouge-1']['f']:.4f}")
        print(f"ROUGE-2: {rouge_scores['rouge-2']['f']:.4f}")
        print(f"ROUGE-L: {rouge_scores['rouge-l']['f']:.4f}")

        # BERTScore
        P, R, F1 = bert_score(filtered_preds, filtered_refs, lang="en", verbose=False)
        print("BERTScore:")
        print(f"Precision: {P.mean().item():.4f}")
        print(f"Recall:    {R.mean().item():.4f}")
        print(f"F1:        {F1.mean().item():.4f}")

        # METEOR
        print("METEOR Score (average):")
        meteor_scores = [single_meteor_score(ref.split(), pred.split()) 
                        for pred, ref in zip(filtered_preds, filtered_refs)]
        print(f"METEOR: {sum(meteor_scores)/len(meteor_scores):.4f}")

else:
    print("Could not find required columns 'summaries' and 'predicted_summary' for metric calculation.")

Calculating evaluation metrics...


Unnamed: 0,articles,summaries,predicted_summary
0,A Florida bus passenger was arrested for throw...,"Joel Parker, 33, was riding the bus in St John...",The driver was arrested on suspicion of assaul...
1,Aston Villa may be able to sign Cordoba strike...,Aston Villa have held talks over Cordoba strik...,Manchester City midfielder has been linked wit...


ROUGE scores:
ROUGE-1: 0.1815
ROUGE-2: 0.0280
ROUGE-L: 0.1741


2025-05-19 03:53:04.319459: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-05-19 03:53:04.336711: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747601584.355542   19516 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747601584.361095   19516 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1747601584.374852   19516 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

BERTScore:
Precision: 0.8236
Recall:    0.8188
F1:        0.8209
METEOR Score (average):
METEOR: 0.1318
