<a href="https://colab.research.google.com/github/hanguyenai/sudo-code-nlp/blob/main/05_attention_text_summarization/Daily_Papers_Text_Summarization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1 Setup & Install dependencies

In [86]:
!pip install rouge_score

Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=59cbeb21801ad7f9519dc555eff1203503c48683a67006d39b2784c2854d1f53
  Stored in directory: /root/.cache/pip/wheels/85/9d/af/01feefbe7d55ef5468796f0c68225b6788e85d9d0a281e7a70
Successfully built rouge_score
Installing collected packages: rouge_score
Successfully installed rouge_score-0.1.2


In [87]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random
import re
import json
import time
from typing import List, Dict
from datetime import datetime, timedelta
import requests
from bs4 import BeautifulSoup
from IPython.display import HTML, display
import warnings
from tqdm import tqdm
import os
from transformers import BertTokenizer, BertModel
from sklearn.model_selection import train_test_split
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from collections import Counter
import math

warnings.filterwarnings('ignore')

# Set style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

# Set random seeds
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Check GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


# 2 Crawl papers from Hugging Face

In [5]:
class HuggingFaceScraper:
    """
    Advanced HuggingFace Papers Scraper with Full Abstract Support
    Scrapes papers from https://huggingface.co/papers with complete abstracts
    """

    def __init__(self):
        self.base_url = "https://huggingface.co"
        self.session = requests.Session()
        self.user_agents = [
            'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
            'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
            'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36'
        ]
        self.session.headers.update({
            'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
            'Accept-Language': 'en-US,en;q=0.5',
            'Accept-Encoding': 'gzip, deflate, br',
            'Connection': 'keep-alive',
            'Upgrade-Insecure-Requests': '1'
        })

    def scrape_date_range(self, start_date, end_date, delay=2.5, fetch_full_abstract=True):
        """
        Scrape papers from date range

        Args:
            start_date (str): Start date in YYYY-MM-DD format
            end_date (str): End date in YYYY-MM-DD format
            delay (float): Delay between requests in seconds
            fetch_full_abstract (bool): If True, visit each paper page to get full abstract

        Returns:
            list: List of paper dictionaries
        """
        current = datetime.strptime(start_date, "%Y-%m-%d")
        end = datetime.strptime(end_date, "%Y-%m-%d")

        all_papers = []
        days = 0

        print(f"\n{'='*80}")
        print(f"🚀 SCRAPING HUGGINGFACE PAPERS")
        print(f"{'='*80}")
        print(f"📅 Date range: {start_date} to {end_date}")
        print(f"⚙️  Full abstract mode: {'ENABLED ✓' if fetch_full_abstract else 'DISABLED ✗'}")
        if fetch_full_abstract:
            print(f"⚠️  Note: Full abstract mode is SLOWER but gets complete abstracts")
        print(f"{'='*80}\n")

        while current <= end:
            date_str = current.strftime("%Y-%m-%d")
            papers = self._scrape_date(date_str, fetch_full_abstract)

            if papers:
                all_papers.extend(papers)
                print(f"✓ {date_str}: {len(papers)} papers")
            else:
                print(f"○ {date_str}: no papers")

            days += 1

            # Progress update every 10 days
            if days % 10 == 0:
                print(f"\n{'─'*80}")
                print(f"📊 Progress: {days} days, {len(all_papers)} papers total")
                print(f"{'─'*80}\n")
                time.sleep(random.uniform(5, 10))  # Longer break

            current += timedelta(days=1)
            time.sleep(delay + random.uniform(0, 2))

        # Deduplicate
        unique = self._deduplicate(all_papers)

        print(f"\n{'='*80}")
        print(f"✅ SCRAPING COMPLETE!")
        print(f"{'='*80}")
        print(f"📊 Total unique papers: {len(unique)}")
        print(f"{'='*80}\n")

        return unique

    def _scrape_date(self, date_str, fetch_full_abstract=True):
        """
        Scrape papers from a specific date

        Args:
            date_str (str): Date in YYYY-MM-DD format
            fetch_full_abstract (bool): Whether to fetch full abstracts

        Returns:
            list: List of papers for that date
        """
        url = f"{self.base_url}/papers/date/{date_str}"

        try:
            headers = {'User-Agent': random.choice(self.user_agents)}
            response = self.session.get(url, headers=headers, timeout=15)

            if response.status_code == 404:
                return []

            response.raise_for_status()
            soup = BeautifulSoup(response.text, 'html.parser')

            # Find all paper articles (limit to 50 per day)
            articles = soup.find_all('article')[:50]

            papers = []
            for i, article in enumerate(articles, 1):
                paper = self._parse_article(article, date_str, fetch_full_abstract)
                if paper:
                    papers.append(paper)
                    if fetch_full_abstract and i % 5 == 0:
                        print(f"      Progress: {i}/{len(articles)} papers processed")

            return papers

        except requests.exceptions.RequestException as e:
            print(f"   ❌ Request error for {date_str}: {e}")
            return []
        except Exception as e:
            print(f"   ❌ Unexpected error for {date_str}: {e}")
            return []

    def _parse_article(self, article, date_str, fetch_full_abstract=True):
        """
        Parse article element to extract paper information

        Args:
            article: BeautifulSoup element
            date_str (str): Date string
            fetch_full_abstract (bool): Whether to fetch full abstract

        Returns:
            dict: Paper information or None if parsing fails
        """
        paper = {
            'date': date_str,
            'scraped_at': datetime.now().isoformat()
        }

        # Extract title
        title_elem = article.find('h3') or article.find('h2')
        if not title_elem:
            return None
        paper['title'] = title_elem.get_text(strip=True)

        # Extract link and paper ID
        link_elem = article.find('a', href=re.compile(r'/papers/\d+\.\d+'))
        if link_elem:
            paper['link'] = self.base_url + link_elem['href']

            # Extract arXiv ID (format: YYMM.NNNNN)
            match = re.search(r'(\d{4})\.(\d{4,5})', link_elem['href'])
            if match:
                paper['paper_id'] = f"{match.group(1)}.{match.group(2)}"
                # Extract year from arXiv ID
                year_code = match.group(1)[:2]
                paper['year'] = 2000 + int(year_code)

        # Extract metadata from listing page
        text_content = article.get_text()

        # Extract number of authors
        author_match = re.search(r'(\d+)\s+authors?', text_content, re.IGNORECASE)
        if author_match:
            paper['authors_count'] = int(author_match.group(1))

        # Extract upvotes - Multiple strategies
        # Strategy 1: Look for div with class "leading-none" (common for upvote count)
        upvote_div = article.find('div', class_='leading-none')
        if upvote_div:
            upvote_text = upvote_div.get_text(strip=True)
            try:
                paper['upvotes'] = int(upvote_text)
            except ValueError:
                pass

        # Strategy 2: Text pattern matching (fallback)
        if 'upvotes' not in paper:
            upvote_match = re.search(r'(\d+)\s+(?:upvotes?|likes?)', text_content, re.IGNORECASE)
            if upvote_match:
                paper['upvotes'] = int(upvote_match.group(1))

        # Extract abstract
        if fetch_full_abstract and paper.get('link'):
            # Fetch FULL abstract AND additional metadata from detail page
            print(f"      📖 {paper['title'][:60]}...")
            detail_data = self._fetch_detail_page(paper['link'])

            if detail_data:
                # Update with full abstract
                if detail_data.get('abstract'):
                    paper['abstract'] = detail_data['abstract']
                    paper['abstract_type'] = 'full'
                    word_count = len(detail_data['abstract'].split())
                    print(f"         ✓ Full abstract: {word_count} words")

                # Update upvotes from detail page if not found earlier
                if detail_data.get('upvotes') and not paper.get('upvotes'):
                    paper['upvotes'] = detail_data['upvotes']

                # Add any other metadata from detail page
                if detail_data.get('authors'):
                    paper['authors'] = detail_data['authors']
            else:
                # Fallback to preview
                print(f"         ⚠️ Detail page failed, using preview")
                preview = self._extract_preview_abstract(article)
                if preview:
                    paper['abstract'] = preview
                    paper['abstract_type'] = 'preview'

            # Add delay between detail page requests
            time.sleep(random.uniform(1.5, 3))
        else:
            # Get preview abstract from listing page
            preview = self._extract_preview_abstract(article)
            if preview:
                paper['abstract'] = preview
                paper['abstract_type'] = 'preview'

        # Only return paper if it has an abstract
        return paper if paper.get('abstract') else None

    def _extract_preview_abstract(self, article):
        """
        Extract preview abstract from listing page

        Args:
            article: BeautifulSoup element

        Returns:
            str: Preview abstract or None
        """
        paragraphs = article.find_all('p')
        abstracts = [p.get_text(strip=True) for p in paragraphs
                    if len(p.get_text(strip=True)) > 50]

        if abstracts:
            return ' '.join(abstracts)

        return None

    def _fetch_detail_page(self, paper_url):
        """
        Fetch complete information from paper detail page
        Including: full abstract, upvotes, authors, etc.

        Args:
            paper_url (str): URL of the paper detail page

        Returns:
            dict: Dictionary with abstract, upvotes, authors, etc. or None
        """
        try:
            headers = {'User-Agent': random.choice(self.user_agents)}
            response = self.session.get(paper_url, headers=headers, timeout=15)
            response.raise_for_status()

            soup = BeautifulSoup(response.text, 'html.parser')
            detail_data = {}

            # ===== EXTRACT UPVOTES =====
            # Look for div with class "leading-none" (common pattern for upvote display)
            upvote_div = soup.find('div', class_='leading-none')
            if upvote_div:
                upvote_text = upvote_div.get_text(strip=True)
                try:
                    detail_data['upvotes'] = int(upvote_text)
                except ValueError:
                    pass

            # ===== EXTRACT AUTHORS =====
            # Look for author information (may need adjustment based on actual HTML)
            author_links = soup.find_all('a', href=re.compile(r'/papers\?author='))
            if author_links:
                detail_data['authors'] = [a.get_text(strip=True) for a in author_links]

            # ===== EXTRACT FULL ABSTRACT =====
            abstract = self._extract_abstract_from_soup(soup)
            if abstract:
                detail_data['abstract'] = abstract

            return detail_data if detail_data else None

        except requests.exceptions.RequestException as e:
            print(f"         ❌ Request error: {e}")
            return None
        except Exception as e:
            print(f"         ❌ Parse error: {e}")
            return None

    def _extract_abstract_from_soup(self, soup):
        """
        Extract abstract from BeautifulSoup object using multiple strategies

        Args:
            soup: BeautifulSoup object of the detail page

        Returns:
            str: Full abstract or None
        """
        # ===== STRATEGY 1: Find Abstract heading and extract following content =====
        headings = soup.find_all(['h2', 'h3'])
        for heading in headings:
            heading_text = heading.get_text(strip=True).lower()
            if heading_text == 'abstract':
                # Look for next sibling div or container
                next_container = heading.find_next_sibling()
                if next_container:
                    # Extract all paragraphs with class text-gray
                    paragraphs = next_container.find_all('p', class_=re.compile(r'text-gray'))
                    if paragraphs:
                        abstract_parts = []
                        for p in paragraphs:
                            text = p.get_text(separator=' ', strip=True)
                            if len(text) > 20:
                                abstract_parts.append(text)
                        if abstract_parts:
                            return ' '.join(abstract_parts)

        # ===== STRATEGY 2: Direct search for text-gray-600 paragraphs =====
        gray_paragraphs = soup.find_all('p', class_=re.compile(r'text-gray-600'))
        if gray_paragraphs:
            abstract_parts = []
            for p in gray_paragraphs:
                text = p.get_text(separator=' ', strip=True)
                if len(text) > 50:
                    abstract_parts.append(text)

            if abstract_parts:
                full_text = ' '.join(abstract_parts)
                return full_text[:5000] if len(full_text) > 5000 else full_text

        # ===== STRATEGY 3: Look in prose/content containers =====
        content_containers = soup.find_all('div', class_=re.compile(r'prose|content|article'))
        for container in content_containers:
            paragraphs = container.find_all('p')
            long_paras = []
            for p in paragraphs:
                text = p.get_text(separator=' ', strip=True)
                if len(text) > 100:
                    long_paras.append(text)

            if long_paras:
                return ' '.join(long_paras[:5])

        # ===== STRATEGY 4: Get all substantial paragraphs (fallback) =====
        all_paragraphs = soup.find_all('p')
        substantial = []
        for p in all_paragraphs:
            text = p.get_text(separator=' ', strip=True)
            if len(text) > 100:
                substantial.append(text)

        if substantial:
            return ' '.join(substantial[:3])

        return None

    def _deduplicate(self, papers):
        """
        Remove duplicate papers based on paper_id or title

        Args:
            papers (list): List of paper dictionaries

        Returns:
            list: Deduplicated list of papers
        """
        seen = set()
        unique = []

        for paper in papers:
            # Use paper_id as primary key, fall back to title
            key = paper.get('paper_id') or paper.get('title')
            if key and key not in seen:
                seen.add(key)
                unique.append(paper)

        if len(papers) != len(unique):
            print(f"🔄 Deduplication: {len(papers)} → {len(unique)} papers")

        return unique

    def save(self, papers, filename="papers.json"):
        """
        Save papers to JSON file with statistics

        Args:
            papers (list): List of paper dictionaries
            filename (str): Output filename

        Returns:
            str: Filename of saved file
        """
        # Calculate statistics
        stats = {
            'total': len(papers),
            'with_abstract': sum(1 for p in papers if p.get('abstract')),
            'with_full_abstract': sum(1 for p in papers
                                     if p.get('abstract_type') == 'full'),
            'with_preview_abstract': sum(1 for p in papers
                                        if p.get('abstract_type') == 'preview'),
            'with_paper_id': sum(1 for p in papers if p.get('paper_id')),
            'with_authors': sum(1 for p in papers if p.get('authors_count')),
            'with_author_names': sum(1 for p in papers if p.get('authors')),
            'with_upvotes': sum(1 for p in papers if p.get('upvotes')),
            'avg_abstract_length': sum(len(p.get('abstract', '').split())
                                      for p in papers) / len(papers) if papers else 0,
            'avg_upvotes': sum(p.get('upvotes', 0) for p in papers) / sum(1 for p in papers if p.get('upvotes')) if any(p.get('upvotes') for p in papers) else 0
        }

        # Prepare output
        output = {
            'scraped_at': datetime.now().isoformat(),
            'statistics': stats,
            'papers': papers
        }

        # Save to file
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(output, f, indent=2, ensure_ascii=False)

        # Print summary
        print(f"\n{'='*80}")
        print(f"💾 SAVED TO: {filename}")
        print(f"{'='*80}")
        print(f"📊 Statistics:")
        print(f"   Total papers: {stats['total']}")
        print(f"   With abstracts: {stats['with_abstract']}")
        print(f"   - Full abstracts: {stats['with_full_abstract']}")
        print(f"   - Preview abstracts: {stats['with_preview_abstract']}")
        print(f"   With upvotes: {stats['with_upvotes']} (avg: {stats['avg_upvotes']:.1f})")
        print(f"   With author names: {stats['with_author_names']}")
        print(f"   Average abstract length: {stats['avg_abstract_length']:.1f} words")
        print(f"{'='*80}\n")

        return filename

    def scrape_recent_days(self, num_days=30, fetch_full_abstract=True):
        """
        Scrape papers from recent N days

        Args:
            num_days (int): Number of recent days to scrape
            fetch_full_abstract (bool): Whether to fetch full abstracts

        Returns:
            list: List of papers
        """
        end_date = datetime.now()
        start_date = end_date - timedelta(days=num_days)

        return self.scrape_date_range(
            start_date=start_date.strftime("%Y-%m-%d"),
            end_date=end_date.strftime("%Y-%m-%d"),
            fetch_full_abstract=fetch_full_abstract
        )

    @staticmethod
    def load(filename="papers.json"):
        """Load papers from JSON file (works for both FastHFScraper and HuggingFaceScraper formats)"""
        try:
            with open(filename, 'r', encoding='utf-8') as f:
                data = json.load(f)

            papers = data.get('papers', [])
            stats = data.get('stats') or data.get('statistics', {})  # Support both formats
            scraped_at = data.get('scraped_at', 'Unknown')

            print(f"\n📂 Loaded from: {filename}")
            print(f"   Scraped at: {scraped_at}")
            print(f"   Total papers: {stats.get('total', len(papers))}")

            # Handle both date_range formats
            date_range = stats.get('date_range')
            if date_range:
                if isinstance(date_range, list):
                    print(f"   Date range: {date_range[0]} to {date_range[1]}")
                else:
                    print(f"   Date range: {date_range}")

            return papers
        except FileNotFoundError:
            print(f"❌ File not found: {filename}")
            return []
        except json.JSONDecodeError:
            print(f"❌ Invalid JSON file: {filename}")
            return []
        except Exception as e:
            print(f"❌ Error loading file: {e}")
            return []

In [None]:
print("=" * 80)
print("🔥 SCRAPING WITH ANTI-BLOCKING PROTECTION")
print("=" * 80)

scraper = HuggingFaceScraper()

print("\n🎯 Full scraping mode: Every day from Jan to Oct 2025")
papers = scraper.scrape_recent_days(num_days=365, fetch_full_abstract=True)

🔥 SCRAPING WITH ANTI-BLOCKING PROTECTION

🎯 Full scraping mode: Every day from Jan to Oct 2025
⚠️  This will take 15-20 minutes but gives maximum papers!


🚀 SCRAPING HUGGINGFACE PAPERS
📅 Date range: 2024-10-15 to 2025-10-15
⚙️  Full abstract mode: ENABLED ✓
⚠️  Note: Full abstract mode is SLOWER but gets complete abstracts

      📖 Animate-X: Universal Character Image Animation with Enhanced...
         ✓ Full abstract: 214 words
      📖 LOKI: A Comprehensive Synthetic Data Detection Benchmark usi...
         ✓ Full abstract: 198 words
      📖 MMIE: Massive Multimodal Interleaved Comprehension Benchmark...
         ✓ Full abstract: 226 words
      📖 Toward General Instruction-Following Alignment for Retrieval...
         ✓ Full abstract: 237 words
      📖 MEGA-Bench: Scaling Multimodal Evaluation to over 500 Real-W...
         ✓ Full abstract: 171 words
      Progress: 5/22 papers processed
      📖 Omni-MATH: A Universal Olympiad Level Mathematic Benchmark F...
         ✓ Full abstrac

# 3 Loading dataset after crawling

In [7]:
!gdown --id 1wS4NdhyQG76eHTUR5-BNq0pO6QGqxcCJ

Downloading...
From: https://drive.google.com/uc?id=1wS4NdhyQG76eHTUR5-BNq0pO6QGqxcCJ
To: /content/papers_365_days.json
100% 7.55M/7.55M [00:00<00:00, 137MB/s]


In [8]:
scraper = HuggingFaceScraper()
papers = scraper.load("papers_365_days.json")
print(f"Loaded {len(papers)} papers")


📂 Loaded from: papers_365_days.json
   Scraped at: 2025-10-13T18:36:22.412767
   Total papers: 4251
Loaded 4251 papers


In [9]:
print(f"\n📊 Scraping Results:")
print(f"   Total papers: {len(papers)}")

# Count papers with complete info
papers_with_abstract = sum(1 for p in papers if p.get('abstract'))
papers_with_id = sum(1 for p in papers if p.get('paper_id'))
papers_with_authors = sum(1 for p in papers if p.get('authors_count'))
papers_with_upvotes = sum(1 for p in papers if p.get('upvotes'))

print(f"   With abstracts: {papers_with_abstract}")
print(f"   With paper IDs: {papers_with_id}")
print(f"   With author count: {papers_with_authors}")
print(f"   With upvotes: {papers_with_upvotes}")


📊 Scraping Results:
   Total papers: 4251
   With abstracts: 4251
   With paper IDs: 4251
   With author count: 3995
   With upvotes: 1


In [10]:
# Deduplicate
seen = set()
unique_papers = []
for p in papers:
    key = p.get('paper_id') or p.get('title')
    if key and key not in seen:
        seen.add(key)
        unique_papers.append(p)

print(f"\n✅ Total papers after merge: {len(unique_papers)}")

# Save combined dataset
scraper.save(papers, "papers_jan_preview.json")

print("\n" + "=" * 80)
print("✅ RESUME COMPLETE")
print("=" * 80)
print(f"Total unique: {len(unique_papers)} papers")
print("=" * 80)


✅ Total papers after merge: 4251

💾 SAVED TO: papers_jan_preview.json
📊 Statistics:
   Total papers: 4251
   With abstracts: 4251
   - Full abstracts: 4251
   - Preview abstracts: 0
   With upvotes: 1 (avg: 2.0)
   With author names: 0
   Average abstract length: 188.4 words


✅ RESUME COMPLETE
Total unique: 4251 papers


In [11]:
# Display sample papers with all available info
print("\n" + "=" * 80)
print("📄 SAMPLE PAPERS")
print("=" * 80)

for i, paper in enumerate(papers[:5], 1):
    print(f"\n{i}. {paper['title']}")
    print(f"   Paper ID: {paper.get('paper_id', 'N/A')}")
    print(f"   Date: {paper.get('date', 'N/A')}")
    print(f"   Year: {paper.get('year', 'N/A')}")
    print(f"   Number of Authors: {paper.get('authors_count', 'N/A')}")
    print(f"   Abstract: {paper.get('abstract', 'N/A')[:120]}...")
    if paper.get('link'):
        print(f"   Link: {paper['link']}")


📄 SAMPLE PAPERS

1. WALL-E: World Alignment by Rule Learning Improves World Model-based LLM
  Agents
   Paper ID: 2410.07484
   Date: 2024-10-13
   Year: 2024
   Number of Authors: 7
   Abstract: Can large language models (LLMs) directly serve as powerful world models for model-based agents ? While the gaps between...
   Link: https://huggingface.co/papers/2410.07484

2. MathCoder2: Better Math Reasoning from Continued Pretraining on
  Model-translated Mathematical Code
   Paper ID: 2410.08196
   Date: 2024-10-13
   Year: 2024
   Number of Authors: 8
   Abstract: Code has been shown to be effective in enhancing the mathematical reasoning abilities of large language models due to it...
   Link: https://huggingface.co/papers/2410.08196

3. MLLM as Retriever: Interactively Learning Multimodal Retrieval for
  Embodied Agents
   Paper ID: 2410.03450
   Date: 2024-10-13
   Year: 2024
   Number of Authors: 4
   Abstract: MLLM agents demonstrate potential for complex embodied tasks by retriev

In [24]:
def preprocess_text(text):
    """Clean and normalize text"""
    text = text.lower()
    text = re.sub(r'[^a-z0-9\s]', '', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

In [61]:
def load_data(json_path):
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    papers = []
    for paper in data['papers']:
        # Chỉ lấy papers có abstract và title
        if paper.get('abstract') and paper.get('title'):
            abstract = paper['abstract'].strip()
            title = paper['title'].strip()

            # Filter too short/long
            if len(abstract) > 50 and len(title) > 10:
                papers.append({
                    'abstract': abstract,
                    'title': title,
                    'paper_id': paper.get('paper_id', '')
                })

    print(f"✓ Loaded {len(papers)} valid papers")
    return papers

In [67]:
class SummarizationDataset(Dataset):
    def __init__(self, papers, tokenizer, max_input_len, max_target_len):
        self.papers = papers
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_target_len = max_target_len

        # Build vocabulary from titles
        self.word2idx = {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3}
        self.idx2word = {0: '<PAD>', 1: '<SOS>', 2: '<EOS>', 3: '<UNK>'}
        self._build_vocab()

    def _build_vocab(self):
        idx = 4
        for paper in self.papers:
            words = paper['title'].lower().split()
            for word in words:
                if word not in self.word2idx:
                    self.word2idx[word] = idx
                    self.idx2word[idx] = word
                    idx += 1

    def encode_title(self, title):
        words = title.lower().split()
        indices = [self.word2idx.get(w, self.word2idx['<UNK>']) for w in words]
        indices = [self.word2idx['<SOS>']] + indices + [self.word2idx['<EOS>']]

        # Padding
        if len(indices) < self.max_target_len:
            indices += [self.word2idx['<PAD>']] * (self.max_target_len - len(indices))
        else:
            indices = indices[:self.max_target_len-1] + [self.word2idx['<EOS>']]

        return torch.tensor(indices)

    def decode_title(self, indices):
        words = []
        for idx in indices:
            if idx == self.word2idx['<EOS>']:
                break
            if idx not in [self.word2idx['<PAD>'], self.word2idx['<SOS>']]:
                words.append(self.idx2word.get(idx, '<UNK>'))
        return ' '.join(words)

    def __len__(self):
        return len(self.papers)

    def __getitem__(self, idx):
        paper = self.papers[idx]

        # Encode abstract with BERT tokenizer
        encoding = self.tokenizer(
            paper['abstract'],
            max_length=self.max_input_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Encode title
        title_indices = self.encode_title(paper['title'])

        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'target': title_indices,
            'title_text': paper['title']
        }

# 3 Attention Mechanism Architecture

In [75]:
class BahdanauAttention(nn.Module):
    """
    Bahdanau (Additive) Attention
    score(h_t, h_s) = v^T * tanh(W1*h_t + W2*h_s)
    """
    def __init__(self, hidden_size, decoder_hidden_size):
        super(BahdanauAttention, self).__init__()
        self.W1 = nn.Linear(hidden_size, decoder_hidden_size, bias=False)
        self.W2 = nn.Linear(decoder_hidden_size, decoder_hidden_size, bias=False)
        self.v = nn.Linear(decoder_hidden_size, 1, bias=False)

    def forward(self, encoder_outputs, decoder_hidden):
        """
        encoder_outputs: (batch, seq_len, hidden_size)
        decoder_hidden: (batch, decoder_hidden_size)
        """
        batch_size = encoder_outputs.size(0)
        seq_len = encoder_outputs.size(1)

        # Expand decoder hidden to match encoder outputs
        decoder_hidden = decoder_hidden.unsqueeze(1).repeat(1, seq_len, 1)

        # Calculate attention scores
        energy = torch.tanh(self.W1(encoder_outputs) + self.W2(decoder_hidden))
        attention_scores = self.v(energy).squeeze(2)  # (batch, seq_len)

        # Apply softmax
        attention_weights = F.softmax(attention_scores, dim=1)

        # Calculate context vector
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
        context = context.squeeze(1)  # (batch, hidden_size)

        return context, attention_weights

In [93]:
class BERTEncoder(nn.Module):
    def __init__(self, model_name, freeze_bert=False):
        super(BERTEncoder, self).__init__()
        self.bert = BertModel.from_pretrained(model_name)

        # Có thể freeze BERT để train nhanh hơn
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state

In [95]:
class AttentionDecoder(nn.Module):
    def __init__(self, vocab_size, embed_size, encoder_hidden_size,
                 decoder_hidden_size, dropout=0.3):
        super(AttentionDecoder, self).__init__()
        self.vocab_size = vocab_size
        self.decoder_hidden_size = decoder_hidden_size

        # Layers
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.attention = BahdanauAttention(encoder_hidden_size, decoder_hidden_size)
        self.lstm = nn.LSTM(embed_size + encoder_hidden_size, decoder_hidden_size,
                           batch_first=True, num_layers=1, dropout=0)  # Không dùng LSTM dropout vì chỉ 1 layer
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(decoder_hidden_size, vocab_size)

    def forward(self, input_token, hidden, cell, encoder_outputs):
        """
        input_token: (batch, 1)
        hidden: (1, batch, decoder_hidden_size)
        cell: (1, batch, decoder_hidden_size)
        encoder_outputs: (batch, seq_len, encoder_hidden_size)
        """
        # Embedding với dropout
        embedded = self.dropout(self.embedding(input_token))  # (batch, 1, embed_size)

        # Attention
        decoder_hidden = hidden.squeeze(0)  # (batch, decoder_hidden_size)
        context, attention_weights = self.attention(encoder_outputs, decoder_hidden)
        context = context.unsqueeze(1)  # (batch, 1, encoder_hidden_size)

        # Concatenate embedding and context
        lstm_input = torch.cat([embedded, context], dim=2)

        # LSTM
        output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))

        # Dropout trước FC layer
        output = self.dropout(output.squeeze(1))

        # Prediction
        prediction = self.fc(output)  # (batch, vocab_size)

        return prediction, hidden, cell, attention_weights

In [96]:
class Seq2SeqWithAttention(nn.Module):
    def __init__(self, encoder, decoder, device):
        super(Seq2SeqWithAttention, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, src_mask, trg, teacher_forcing_ratio=0.5):
        """
        src: (batch, src_len)
        trg: (batch, trg_len)
        """
        batch_size = src.size(0)
        trg_len = trg.size(1)
        vocab_size = self.decoder.vocab_size

        # Encoder
        encoder_outputs = self.encoder(src, src_mask)

        # Initialize decoder hidden state
        hidden = torch.zeros(1, batch_size, self.decoder.decoder_hidden_size).to(self.device)
        cell = torch.zeros(1, batch_size, self.decoder.decoder_hidden_size).to(self.device)

        # Store outputs
        outputs = torch.zeros(batch_size, trg_len, vocab_size).to(self.device)

        # First input is <SOS> token
        input_token = trg[:, 0].unsqueeze(1)

        for t in range(1, trg_len):
            output, hidden, cell, _ = self.decoder(input_token, hidden, cell, encoder_outputs)
            outputs[:, t, :] = output

            # Teacher forcing
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            top1 = output.argmax(1).unsqueeze(1)
            input_token = trg[:, t].unsqueeze(1) if teacher_force else top1

        return outputs

In [97]:
class LabelSmoothingLoss(nn.Module):
    def __init__(self, classes, smoothing=0.0, ignore_index=-100):
        super(LabelSmoothingLoss, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.ignore_index = ignore_index

    def forward(self, pred, target):
        pred = pred.log_softmax(dim=-1)

        with torch.no_grad():
            true_dist = torch.zeros_like(pred)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
            true_dist[:, self.ignore_index] = 0
            mask = torch.nonzero(target == self.ignore_index, as_tuple=False)
            if mask.dim() > 0 and mask.size(0) > 0:
                true_dist.index_fill_(0, mask.squeeze(), 0.0)

        return torch.mean(torch.sum(-true_dist * pred, dim=-1))

In [98]:
def calculate_rouge(predictions, references):
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    scores = {'rouge1': [], 'rouge2': [], 'rougeL': []}

    for pred, ref in zip(predictions, references):
        score = scorer.score(ref, pred)
        scores['rouge1'].append(score['rouge1'].fmeasure)
        scores['rouge2'].append(score['rouge2'].fmeasure)
        scores['rougeL'].append(score['rougeL'].fmeasure)

    return {k: np.mean(v) for k, v in scores.items()}

def calculate_bleu(predictions, references):
    smoothie = SmoothingFunction().method4
    scores = []

    for pred, ref in zip(predictions, references):
        pred_tokens = pred.lower().split()
        ref_tokens = [ref.lower().split()]
        score = sentence_bleu(ref_tokens, pred_tokens, smoothing_function=smoothie)
        scores.append(score)

    return np.mean(scores)

In [99]:
def train_epoch(model, loader, optimizer, criterion, device, teacher_forcing_ratio):
    model.train()
    epoch_loss = 0

    for batch in tqdm(loader, desc='Training'):
        src = batch['input_ids'].to(device)
        src_mask = batch['attention_mask'].to(device)
        trg = batch['target'].to(device)

        optimizer.zero_grad()
        output = model(src, src_mask, trg, teacher_forcing_ratio)

        # Calculate loss (ignore padding)
        output_dim = output.shape[-1]
        output = output[:, 1:].reshape(-1, output_dim)
        trg = trg[:, 1:].reshape(-1)

        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(loader)

def evaluate(model, loader, criterion, dataset, device):
    model.eval()
    epoch_loss = 0
    predictions = []
    references = []

    with torch.no_grad():
        for batch in tqdm(loader, desc='Evaluating'):
            src = batch['input_ids'].to(device)
            src_mask = batch['attention_mask'].to(device)
            trg = batch['target'].to(device)

            output = model(src, src_mask, trg, 0)  # No teacher forcing

            # Loss
            output_dim = output.shape[-1]
            output_flat = output[:, 1:].reshape(-1, output_dim)
            trg_flat = trg[:, 1:].reshape(-1)
            loss = criterion(output_flat, trg_flat)
            epoch_loss += loss.item()

            # Generate predictions
            pred_indices = output.argmax(2)
            for i in range(len(pred_indices)):
                pred_text = dataset.decode_title(pred_indices[i].cpu().numpy())
                ref_text = batch['title_text'][i]
                predictions.append(pred_text)
                references.append(ref_text)

    avg_loss = epoch_loss / len(loader)
    rouge_scores = calculate_rouge(predictions, references)
    bleu_score = calculate_bleu(predictions, references)

    return avg_loss, rouge_scores, bleu_score

# 5 Full pipeline training

In [100]:
CONFIG = {
    'encoder_model': 'bert-base-uncased',  # Pre-trained encoder
    'max_input_length': 256,
    'max_target_length': 32,
    'hidden_size': 768,  # BERT hidden size
    'decoder_hidden_size': 256,  # Giảm từ 512 → 256
    'embed_size': 256,  # Giảm từ 300 → 256
    'dropout': 0.5,  # Tăng từ 0.3 → 0.5
    'batch_size': 16,
    'learning_rate': 5e-4,  # Giảm từ 1e-3 → 5e-4
    'encoder_lr': 1e-5,
    'num_epochs': 30,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'teacher_forcing_ratio': 0.7,
    'weight_decay': 1e-4,  # L2 regularization
    'label_smoothing': 0.1,  # Label smoothing
    'patience': 5,  # Early stopping patience
    'freeze_bert': True  # Freeze BERT để giảm overfitting
}

In [101]:
# 1. Load data
papers = load_data('papers_365_days.json')
train_papers, temp = train_test_split(papers, test_size=0.3, random_state=42)
val_papers, test_papers = train_test_split(temp, test_size=0.5, random_state=42)
print(f"Train: {len(train_papers)}, Val: {len(val_papers)}, Test: {len(test_papers)}\n")

# 2. Tokenizer
tokenizer = BertTokenizer.from_pretrained(CONFIG['encoder_model'])

# 3. Datasets
train_dataset = SummarizationDataset(train_papers, tokenizer,
                                    CONFIG['max_input_length'], CONFIG['max_target_length'])
val_dataset = SummarizationDataset(val_papers, tokenizer,
                                  CONFIG['max_input_length'], CONFIG['max_target_length'])
test_dataset = SummarizationDataset(test_papers, tokenizer,
                                   CONFIG['max_input_length'], CONFIG['max_target_length'])

train_dataset.word2idx = val_dataset.word2idx = test_dataset.word2idx = train_dataset.word2idx
train_dataset.idx2word = val_dataset.idx2word = test_dataset.idx2word = train_dataset.idx2word

vocab_size = len(train_dataset.word2idx)
print(f"Vocabulary size: {vocab_size}\n")

✓ Loaded 4249 valid papers
Train: 2974, Val: 637, Test: 638

Vocabulary size: 6759



In [102]:
# 4. DataLoaders
train_loader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'])
test_loader = DataLoader(test_dataset, batch_size=CONFIG['batch_size'])

# 5. Model
encoder = BERTEncoder(CONFIG['encoder_model'], freeze_bert=CONFIG['freeze_bert'])
decoder = AttentionDecoder(
    vocab_size, CONFIG['embed_size'],
    CONFIG['hidden_size'], CONFIG['decoder_hidden_size'],
    CONFIG['dropout']
)
model = Seq2SeqWithAttention(encoder, decoder, CONFIG['device']).to(CONFIG['device'])

print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

Model parameters: 5,042,535


In [104]:
# 6. Training setup với Label Smoothing và Weight Decay
criterion = LabelSmoothingLoss(
    classes=vocab_size,
    smoothing=CONFIG['label_smoothing'],
    ignore_index=train_dataset.word2idx['<PAD>']
)

optimizer = torch.optim.Adam([
    {'params': encoder.parameters(), 'lr': CONFIG['encoder_lr'], 'weight_decay': CONFIG['weight_decay']},
    {'params': decoder.parameters(), 'lr': CONFIG['learning_rate'], 'weight_decay': CONFIG['weight_decay']}
])

# Learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2
)

In [105]:
# 7. Training loop
best_val_loss = float('inf')
history = {'train_loss': [], 'val_loss': [], 'rouge1': [], 'rouge2': [], 'rougeL': [], 'bleu': []}

print("Starting training...\n")
for epoch in range(CONFIG['num_epochs']):
    train_loss = train_epoch(model, train_loader, optimizer, criterion,
                            CONFIG['device'], CONFIG['teacher_forcing_ratio'])
    val_loss, rouge, bleu = evaluate(model, val_loader, criterion, val_dataset, CONFIG['device'])

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['rouge1'].append(rouge['rouge1'])
    history['rouge2'].append(rouge['rouge2'])
    history['rougeL'].append(rouge['rougeL'])
    history['bleu'].append(bleu)

    print(f"Epoch {epoch+1}/{CONFIG['num_epochs']}")
    print(f"  Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    print(f"  ROUGE-1: {rouge['rouge1']:.4f} | ROUGE-2: {rouge['rouge2']:.4f} | ROUGE-L: {rouge['rougeL']:.4f}")
    print(f"  BLEU: {bleu:.4f}\n")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_attention_model.pt')

Starting training...



Training: 100%|██████████| 186/186 [01:07<00:00,  2.78it/s]
Evaluating: 100%|██████████| 40/40 [00:11<00:00,  3.36it/s]


Epoch 1/30
  Train Loss: 2.3855 | Val Loss: 2.3259
  ROUGE-1: 0.0675 | ROUGE-2: 0.0000 | ROUGE-L: 0.0662
  BLEU: 0.0082



Training: 100%|██████████| 186/186 [01:06<00:00,  2.78it/s]
Evaluating: 100%|██████████| 40/40 [00:12<00:00,  3.29it/s]


Epoch 2/30
  Train Loss: 2.2539 | Val Loss: 2.3040
  ROUGE-1: 0.0741 | ROUGE-2: 0.0000 | ROUGE-L: 0.0718
  BLEU: 0.0104



Training: 100%|██████████| 186/186 [01:07<00:00,  2.76it/s]
Evaluating: 100%|██████████| 40/40 [00:12<00:00,  3.31it/s]


Epoch 3/30
  Train Loss: 2.2146 | Val Loss: 2.2937
  ROUGE-1: 0.0946 | ROUGE-2: 0.0062 | ROUGE-L: 0.0903
  BLEU: 0.0126



Training: 100%|██████████| 186/186 [01:07<00:00,  2.76it/s]
Evaluating: 100%|██████████| 40/40 [00:12<00:00,  3.30it/s]


Epoch 4/30
  Train Loss: 2.1848 | Val Loss: 2.2838
  ROUGE-1: 0.1194 | ROUGE-2: 0.0171 | ROUGE-L: 0.1128
  BLEU: 0.0177



Training: 100%|██████████| 186/186 [01:07<00:00,  2.77it/s]
Evaluating: 100%|██████████| 40/40 [00:12<00:00,  3.33it/s]


Epoch 5/30
  Train Loss: 2.1560 | Val Loss: 2.2788
  ROUGE-1: 0.1351 | ROUGE-2: 0.0254 | ROUGE-L: 0.1281
  BLEU: 0.0224



Training: 100%|██████████| 186/186 [01:07<00:00,  2.76it/s]
Evaluating: 100%|██████████| 40/40 [00:11<00:00,  3.35it/s]


Epoch 6/30
  Train Loss: 2.1325 | Val Loss: 2.2726
  ROUGE-1: 0.1409 | ROUGE-2: 0.0268 | ROUGE-L: 0.1332
  BLEU: 0.0221



Training: 100%|██████████| 186/186 [01:07<00:00,  2.77it/s]
Evaluating: 100%|██████████| 40/40 [00:12<00:00,  3.31it/s]


Epoch 7/30
  Train Loss: 2.1053 | Val Loss: 2.2672
  ROUGE-1: 0.1494 | ROUGE-2: 0.0324 | ROUGE-L: 0.1409
  BLEU: 0.0255



Training: 100%|██████████| 186/186 [01:07<00:00,  2.77it/s]
Evaluating: 100%|██████████| 40/40 [00:11<00:00,  3.33it/s]


Epoch 8/30
  Train Loss: 2.0816 | Val Loss: 2.2671
  ROUGE-1: 0.1527 | ROUGE-2: 0.0336 | ROUGE-L: 0.1436
  BLEU: 0.0251



Training: 100%|██████████| 186/186 [01:06<00:00,  2.78it/s]
Evaluating: 100%|██████████| 40/40 [00:11<00:00,  3.34it/s]


Epoch 9/30
  Train Loss: 2.0634 | Val Loss: 2.2588
  ROUGE-1: 0.1603 | ROUGE-2: 0.0373 | ROUGE-L: 0.1516
  BLEU: 0.0265



Training: 100%|██████████| 186/186 [01:07<00:00,  2.77it/s]
Evaluating: 100%|██████████| 40/40 [00:11<00:00,  3.35it/s]


Epoch 10/30
  Train Loss: 2.0404 | Val Loss: 2.2562
  ROUGE-1: 0.1598 | ROUGE-2: 0.0351 | ROUGE-L: 0.1496
  BLEU: 0.0264



Training: 100%|██████████| 186/186 [01:06<00:00,  2.78it/s]
Evaluating: 100%|██████████| 40/40 [00:11<00:00,  3.34it/s]


Epoch 11/30
  Train Loss: 2.0211 | Val Loss: 2.2481
  ROUGE-1: 0.1587 | ROUGE-2: 0.0347 | ROUGE-L: 0.1502
  BLEU: 0.0261



Training: 100%|██████████| 186/186 [01:07<00:00,  2.77it/s]
Evaluating: 100%|██████████| 40/40 [00:11<00:00,  3.35it/s]


Epoch 12/30
  Train Loss: 2.0021 | Val Loss: 2.2484
  ROUGE-1: 0.1609 | ROUGE-2: 0.0336 | ROUGE-L: 0.1512
  BLEU: 0.0255



Training: 100%|██████████| 186/186 [01:06<00:00,  2.78it/s]
Evaluating: 100%|██████████| 40/40 [00:12<00:00,  3.33it/s]


Epoch 13/30
  Train Loss: 1.9801 | Val Loss: 2.2489
  ROUGE-1: 0.1712 | ROUGE-2: 0.0410 | ROUGE-L: 0.1600
  BLEU: 0.0278



Training: 100%|██████████| 186/186 [01:06<00:00,  2.80it/s]
Evaluating: 100%|██████████| 40/40 [00:12<00:00,  3.32it/s]


Epoch 14/30
  Train Loss: 1.9659 | Val Loss: 2.2460
  ROUGE-1: 0.1702 | ROUGE-2: 0.0377 | ROUGE-L: 0.1594
  BLEU: 0.0279



Training: 100%|██████████| 186/186 [01:06<00:00,  2.80it/s]
Evaluating: 100%|██████████| 40/40 [00:11<00:00,  3.34it/s]


Epoch 15/30
  Train Loss: 1.9476 | Val Loss: 2.2424
  ROUGE-1: 0.1709 | ROUGE-2: 0.0384 | ROUGE-L: 0.1600
  BLEU: 0.0284



Training: 100%|██████████| 186/186 [01:06<00:00,  2.79it/s]
Evaluating: 100%|██████████| 40/40 [00:12<00:00,  3.33it/s]


Epoch 16/30
  Train Loss: 1.9324 | Val Loss: 2.2484
  ROUGE-1: 0.1739 | ROUGE-2: 0.0440 | ROUGE-L: 0.1628
  BLEU: 0.0301



Training: 100%|██████████| 186/186 [01:06<00:00,  2.78it/s]
Evaluating: 100%|██████████| 40/40 [00:11<00:00,  3.34it/s]


Epoch 17/30
  Train Loss: 1.9149 | Val Loss: 2.2425
  ROUGE-1: 0.1765 | ROUGE-2: 0.0423 | ROUGE-L: 0.1653
  BLEU: 0.0293



Training: 100%|██████████| 186/186 [01:06<00:00,  2.80it/s]
Evaluating: 100%|██████████| 40/40 [00:11<00:00,  3.34it/s]


Epoch 18/30
  Train Loss: 1.8985 | Val Loss: 2.2395
  ROUGE-1: 0.1764 | ROUGE-2: 0.0416 | ROUGE-L: 0.1642
  BLEU: 0.0284



Training: 100%|██████████| 186/186 [01:06<00:00,  2.78it/s]
Evaluating: 100%|██████████| 40/40 [00:11<00:00,  3.33it/s]


Epoch 19/30
  Train Loss: 1.8882 | Val Loss: 2.2405
  ROUGE-1: 0.1729 | ROUGE-2: 0.0407 | ROUGE-L: 0.1604
  BLEU: 0.0270



Training: 100%|██████████| 186/186 [01:06<00:00,  2.80it/s]
Evaluating: 100%|██████████| 40/40 [00:11<00:00,  3.34it/s]


Epoch 20/30
  Train Loss: 1.8715 | Val Loss: 2.2419
  ROUGE-1: 0.1783 | ROUGE-2: 0.0421 | ROUGE-L: 0.1654
  BLEU: 0.0292



Training: 100%|██████████| 186/186 [01:07<00:00,  2.77it/s]
Evaluating: 100%|██████████| 40/40 [00:11<00:00,  3.33it/s]


Epoch 21/30
  Train Loss: 1.8615 | Val Loss: 2.2409
  ROUGE-1: 0.1836 | ROUGE-2: 0.0454 | ROUGE-L: 0.1705
  BLEU: 0.0301



Training: 100%|██████████| 186/186 [01:06<00:00,  2.80it/s]
Evaluating: 100%|██████████| 40/40 [00:12<00:00,  3.32it/s]


Epoch 22/30
  Train Loss: 1.8448 | Val Loss: 2.2413
  ROUGE-1: 0.1832 | ROUGE-2: 0.0443 | ROUGE-L: 0.1690
  BLEU: 0.0314



Training: 100%|██████████| 186/186 [01:06<00:00,  2.79it/s]
Evaluating: 100%|██████████| 40/40 [00:11<00:00,  3.33it/s]


Epoch 23/30
  Train Loss: 1.8277 | Val Loss: 2.2485
  ROUGE-1: 0.1868 | ROUGE-2: 0.0458 | ROUGE-L: 0.1731
  BLEU: 0.0311



Training: 100%|██████████| 186/186 [01:06<00:00,  2.80it/s]
Evaluating: 100%|██████████| 40/40 [00:12<00:00,  3.32it/s]


Epoch 24/30
  Train Loss: 1.8212 | Val Loss: 2.2366
  ROUGE-1: 0.1792 | ROUGE-2: 0.0430 | ROUGE-L: 0.1672
  BLEU: 0.0289



Training: 100%|██████████| 186/186 [01:07<00:00,  2.76it/s]
Evaluating: 100%|██████████| 40/40 [00:12<00:00,  3.33it/s]


Epoch 25/30
  Train Loss: 1.8082 | Val Loss: 2.2402
  ROUGE-1: 0.1803 | ROUGE-2: 0.0450 | ROUGE-L: 0.1674
  BLEU: 0.0292



Training: 100%|██████████| 186/186 [01:06<00:00,  2.80it/s]
Evaluating: 100%|██████████| 40/40 [00:12<00:00,  3.32it/s]


Epoch 26/30
  Train Loss: 1.7948 | Val Loss: 2.2391
  ROUGE-1: 0.1771 | ROUGE-2: 0.0421 | ROUGE-L: 0.1641
  BLEU: 0.0296



Training: 100%|██████████| 186/186 [01:06<00:00,  2.79it/s]
Evaluating: 100%|██████████| 40/40 [00:12<00:00,  3.32it/s]


Epoch 27/30
  Train Loss: 1.7840 | Val Loss: 2.2329
  ROUGE-1: 0.1857 | ROUGE-2: 0.0429 | ROUGE-L: 0.1728
  BLEU: 0.0305



Training: 100%|██████████| 186/186 [01:06<00:00,  2.78it/s]
Evaluating: 100%|██████████| 40/40 [00:11<00:00,  3.34it/s]


Epoch 28/30
  Train Loss: 1.7753 | Val Loss: 2.2387
  ROUGE-1: 0.1809 | ROUGE-2: 0.0415 | ROUGE-L: 0.1660
  BLEU: 0.0293



Training: 100%|██████████| 186/186 [01:07<00:00,  2.76it/s]
Evaluating: 100%|██████████| 40/40 [00:12<00:00,  3.30it/s]


Epoch 29/30
  Train Loss: 1.7669 | Val Loss: 2.2376
  ROUGE-1: 0.1864 | ROUGE-2: 0.0449 | ROUGE-L: 0.1738
  BLEU: 0.0312



Training: 100%|██████████| 186/186 [01:06<00:00,  2.78it/s]
Evaluating: 100%|██████████| 40/40 [00:12<00:00,  3.32it/s]


Epoch 30/30
  Train Loss: 1.7561 | Val Loss: 2.2356
  ROUGE-1: 0.1872 | ROUGE-2: 0.0397 | ROUGE-L: 0.1705
  BLEU: 0.0289



In [106]:
# 8. Test evaluation
print("Testing on test set...")
model.load_state_dict(torch.load('best_attention_model.pt'))
test_loss, test_rouge, test_bleu = evaluate(model, test_loader, criterion, test_dataset, CONFIG['device'])

print(f"\n{'='*60}")
print("FINAL TEST RESULTS")
print(f"{'='*60}")
print(f"Test Loss: {test_loss:.4f}")
print(f"ROUGE-1: {test_rouge['rouge1']:.4f}")
print(f"ROUGE-2: {test_rouge['rouge2']:.4f}")
print(f"ROUGE-L: {test_rouge['rougeL']:.4f}")
print(f"BLEU: {test_bleu:.4f}")

Testing on test set...


Evaluating: 100%|██████████| 40/40 [00:12<00:00,  3.31it/s]



FINAL TEST RESULTS
Test Loss: 2.1919
ROUGE-1: 0.1967
ROUGE-2: 0.0457
ROUGE-L: 0.1835
BLEU: 0.0314
