In [4]:
!pip install PyPDF2
!pip install datasets
!pip install rouge-score
!pip install bert-score
!pip install transformers

import nltk
# Download all necessary NLTK resources explicitly
nltk.download('punkt')
nltk.download('punkt_tab')  # Add this missing resource
import os
import numpy as np
import torch
import json
import re
import requests
import PyPDF2
import xml.etree.ElementTree as ET
from bs4 import BeautifulSoup
from io import BytesIO
from nltk.tokenize import sent_tokenize
from sklearn.metrics.pairwise import cosine_similarity
from transformers import PegasusForConditionalGeneration, PegasusTokenizer, AutoTokenizer, AutoModel
from transformers import AutoTokenizer as BertTokenizer, AutoModel as BertModel
from transformers import BartForConditionalGeneration, BartTokenizer
from transformers import GenerationConfig

# Ensure NLTK resources are downloaded
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)  # Add explicit download for missing resource

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# File paths
DRIVE_PATH = '/content/drive/MyDrive/nlp_project'
PAPERS_FILE = os.path.join(DRIVE_PATH, "cached_papers.json")
EMBEDDINGS_FILE = os.path.join(DRIVE_PATH, "cached_embeddings.npy")
PEGASUS_MODEL_PATH = os.path.join(DRIVE_PATH, "pegasus_finetuned_model")
BART_MODEL_PATH = os.path.join(DRIVE_PATH, "bart_finetuned_model")
PEGASUS_ARXIV_MODEL_PATH = os.path.join(DRIVE_PATH, "pegasus_arxiv_finetuned_model")
PEGASUS_ARXIV_FALLBACK_MODEL = "google/pegasus-arxiv"

def clean_extracted_text(text):
    if not text:
        return ""

    # Basic URL removal
    text = re.sub(r'https?://\S+|www\.\S+', '', text)

    # Remove math equations completely
    text = re.sub(r'\$[^$]*\$|\$\$[^$]*\$\$', '', text)
    text = re.sub(r'[a-zA-Z0-9]+[=><+\-*/^()[\]{}]+[a-zA-Z0-9]+', '', text)

    # Remove references section
    ref_patterns = [
        r'References\s*\n', r'REFERENCES\s*\n', r'Bibliography\s*\n',
        r'BIBLIOGRAPHY\s*\n', r'Works Cited\s*\n', r'REFERENCES CITED\s*\n'
    ]
    ref_start = len(text)
    for pattern in ref_patterns:
        matches = list(re.finditer(pattern, text))
        if matches:
            ref_start = min(ref_start, matches[-1].start())
    text = text[:ref_start]

    # Fix ligatures and special characters
    ligature_map = {
        '\ufb01': 'fi', '\ufb02': 'fl', '\ufb00': 'ff', '\ufb03': 'ffi', '\ufb04': 'ffl',
        '\u2019': "'", '\u2018': "'", '\u201c': '"', '\u201d': '"', '\u2014': '-', '\u2013': '-', '\u0003': ''
    }
    for ligature, replacement in ligature_map.items():
        text = text.replace(ligature, replacement)

    # Remove specific patterns like dates, numbers, citations
    text = re.sub(r'\b\d{1,2}[/-]\d{1,2}[/-]\d{2,4}\b|\b\d{2,4}[/-]\d{1,2}[/-]\d{1,2}\b', '', text)
    text = re.sub(r'\b\d+\.\d+%?\b|\b\d{2,}\b|\[\d+\]|\(\d+\)', '', text)

    # Fix hyphenated words across lines
    text = re.sub(r'(\w+)-\s*\n\s*(\w+)', r'\1\2', text)

    # Add spaces between joined different case words
    text = re.sub(r'([a-z])([A-Z])|([a-z])([0-9])|([0-9])([a-z])', r'\1 \2', text)

    # Normalize whitespace
    text = re.sub(r'\s+', ' ', text)

    # Fix spacing around punctuation
    text = re.sub(r'\s+([.,;:!?)])', r'\1', text)
    text = re.sub(r'([.,;:!?])([a-zA-Z])', r'\1 \2', text)

    # Remove any remaining special characters
    text = re.sub(r'[^\w\s.,;:!?()\-\'\"]+', ' ', text)

    # Final whitespace cleanup
    text = re.sub(r'\s+', ' ', text)

    return text.strip()

def download_arxiv_paper(arxiv_id):
    if "arxiv.org" in arxiv_id:
        arxiv_id = arxiv_id.split('/')[-1].split('v')[0]

    api_url = f"http://export.arxiv.org/api/query?id_list={arxiv_id}"
    response = requests.get(api_url)
    soup = BeautifulSoup(response.content, 'xml')
    abstract = soup.find('summary').text.strip() if soup.find('summary') else ""
    title = soup.find('title').text.strip() if soup.find('title') else ""

    pdf_url = f"http://arxiv.org/pdf/{arxiv_id}.pdf"
    response = requests.get(pdf_url)
    if response.status_code == 200:
        try:
            pdf_file = BytesIO(response.content)
            pdf_reader = PyPDF2.PdfReader(pdf_file)
            full_text = ""
            for page in pdf_reader.pages:
                full_text += page.extract_text()

            cleaned_text = clean_extracted_text(full_text)

            # Extract introduction
            introduction = ""
            intro_patterns = [
                r"(?i)(?:1\.?\s*|I\.?\s*)?Introduction(.*?)(?:\n\d\.|\n[A-Z]\.|\nII\.)",
                r"(?i)(?:1\.?\s*|I\.?\s*)?Introduction(.*?)(?=\n2\.|\nII\.)"
            ]
            for pattern in intro_patterns:
                intro_match = re.search(pattern, cleaned_text, re.DOTALL)
                if intro_match:
                    introduction = clean_extracted_text(intro_match.group(1).strip())
                    break

            # Extract conclusion
            conclusion = ""
            concl_patterns = [
                r"(?i)(?:\d\.?\s*|[IVX]+\.?\s*)?Conclusion[s]?(.*?)(?:\n\d\.|\n[A-Z]\.|\nReferences|\n[IVX]+\.)",
                r"(?i)(?:\d\.?\s*|[IVX]+\.?\s*)?Discussion(?:s)?(.*?)(?:\n\d\.|\n[A-Z]\.|\nReferences|\n[IVX]+\.)"
            ]
            for pattern in concl_patterns:
                concl_match = re.search(pattern, cleaned_text, re.DOTALL)
                if concl_match:
                    conclusion = clean_extracted_text(concl_match.group(1).strip())
                    break

            return {
                "title": title,
                "abstract": abstract,
                "introduction": introduction,
                "conclusion": conclusion,
                "full_text": cleaned_text
            }
        except Exception as e:
            print(f"Error extracting text from PDF: {e}")
            return {"title": title, "abstract": abstract, "introduction": "", "conclusion": "", "full_text": ""}

    return {"title": title, "abstract": abstract, "introduction": "", "conclusion": "", "full_text": ""}

def get_bert_embeddings(texts, model_name="bert-base-uncased"):
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertModel.from_pretrained(model_name).to(device)

    embeddings = []
    for text in texts:
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
        inputs = {key: val.to(device) for key, val in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        embedding = outputs.last_hidden_state[:, 0, :].cpu().numpy()
        embeddings.append(embedding[0])

    return np.array(embeddings)

def highlight_query_relevant_sentences(text, query, top_n=10):
    text = clean_extracted_text(text)
    sentences = sent_tokenize(text)

    if not sentences:
        return []

    all_embeddings = get_bert_embeddings(sentences + [query])
    sentence_embeddings = all_embeddings[:-1]
    query_embedding = all_embeddings[-1]

    similarities = cosine_similarity(sentence_embeddings, query_embedding.reshape(1, -1)).flatten()
    top_indices = similarities.argsort()[-top_n:][::-1]

    return [(sentences[i], float(similarities[i]), i) for i in top_indices]

def get_paragraph_context(sentences, index, context_size=2):
    start_idx = max(0, index - context_size)
    end_idx = min(len(sentences) - 1, index + context_size)
    context_paragraph = " ".join(sentences[start_idx:end_idx + 1])
    return clean_extracted_text(context_paragraph)

def extract_relevant_content(paper_data, query, max_sentences_per_section=10, context_size=2):
    result = {
        "title": paper_data["title"],
        "abstract": paper_data["abstract"],
        "relevant_sections": []
    }

    sections = {
        "introduction": paper_data.get("introduction", ""),
        "conclusion": paper_data.get("conclusion", ""),
        "full_text": paper_data.get("full_text", "")
    }

    for section_name, section_text in sections.items():
        if not section_text:
            continue

        cleaned_text = clean_extracted_text(section_text)
        all_sentences = sent_tokenize(cleaned_text)
        relevant_sentences = highlight_query_relevant_sentences(cleaned_text, query, top_n=max_sentences_per_section)

        if relevant_sentences:
            processed_results = []
            relevant_sentences.sort(key=lambda x: x[2])  # Sort by index
            processed_indices = set()

            for sent, score, idx in relevant_sentences:
                if idx in processed_indices:
                    continue

                context_paragraph = get_paragraph_context(all_sentences, idx, context_size)

                # Mark processed indices
                for i in range(max(0, idx - context_size), min(len(all_sentences), idx + context_size + 1)):
                    processed_indices.add(i)

                processed_results.append({
                    "text": context_paragraph,
                    "relevance_score": score,
                    "core_sentence": sent
                })

            result["relevant_sections"].append({
                "section_name": section_name,
                "sentences": processed_results
            })

    return result

def process_top_papers(query, top_papers, max_papers=5, context_size=2):
    results = []

    for i, paper in enumerate(top_papers[:max_papers]):
        paper_link = paper['link']
        print(f"Processing paper {i+1}/{max_papers}: {paper['title']}")

        paper_data = download_arxiv_paper(paper_link)
        relevant_content = extract_relevant_content(paper_data, query, context_size=context_size)

        results.append({
            "paper_index": i+1,
            "title": paper['title'],
            "link": paper['link'],
            "extracted_content": relevant_content
        })

    return results

def get_embedding(text, tokenizer, model):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    inputs = {key: val.to(device) for key, val in inputs.items()}

    with torch.no_grad():
        output = model(**inputs)

    return output.last_hidden_state.mean(dim=1).cpu().numpy()

def postprocess_summary(summary):
    """
    Additional processing to clean up the summary from symbols and artifacts
    """
    # Remove any formula-like patterns
    summary = re.sub(r'[a-zA-Z0-9]+[=><+\-*/^()[\]{}]+[a-zA-Z0-9]+', '', summary)

    # Remove isolated special characters
    summary = re.sub(r'\s[/=+\-*:]+\s', ' ', summary)

    # Remove reference citations
    summary = re.sub(r'\[\d+\]|\(\d+\)', '', summary)

    # Fix spacing issues
    summary = re.sub(r'\s+', ' ', summary)

    # Remove any other non-alphanumeric characters except for basic punctuation
    summary = re.sub(r'[^\w\s.,;:!?()\-\'\"]+', '', summary)

    # Ensure proper spacing after punctuation
    summary = re.sub(r'([.,;:!?])([a-zA-Z])', r'\1 \2', summary)

    # Ensure proper capitalization of sentences
    sentences = sent_tokenize(summary)
    processed_sentences = []

    for sentence in sentences:
        sentence = sentence.strip()
        if sentence:
            if not sentence[0].isupper() and sentence[0].isalpha():
                sentence = sentence[0].upper() + sentence[1:]
            processed_sentences.append(sentence)

    summary = ' '.join(processed_sentences)

    return summary.strip()

def run_inference(query, model_choice):
    # Load cached papers and embeddings
    if os.path.exists(PAPERS_FILE) and os.path.exists(EMBEDDINGS_FILE):
        with open(PAPERS_FILE, "r") as f:
            all_papers = json.load(f)
        paper_embeddings = np.load(EMBEDDINGS_FILE)
        print(f"Loaded {len(all_papers)} papers and embeddings")
    else:
        print("Cached data not found. Please run the training script first.")
        return None

    # Load Specter model
    try:
        print("Loading Specter model...")
        specter_tokenizer = AutoTokenizer.from_pretrained("allenai/specter")
        specter_model = AutoModel.from_pretrained("allenai/specter").to(device)
    except Exception as e:
        print(f"Error loading Specter model: {e}")
        return None

    # Load selected model
    try:
        if model_choice.lower() == "pegasus":
            print("Loading fine-tuned Pegasus model...")
            tokenizer = PegasusTokenizer.from_pretrained(PEGASUS_MODEL_PATH)
            model = PegasusForConditionalGeneration.from_pretrained(PEGASUS_MODEL_PATH).to(device)
            max_input_length = 512
            max_output_length = 256
            generation_config = GenerationConfig(
                num_beams=5,
                length_penalty=1.2,
                early_stopping=True,
                no_repeat_ngram_size=3,
                do_sample=True,
                top_p=0.92,
                temperature=0.85,
                max_length=max_output_length
            )
            fallback_config = GenerationConfig(
                num_beams=8,
                length_penalty=1.5,
                early_stopping=True,
                no_repeat_ngram_size=2,
                temperature=0.7,
                max_length=max_output_length
            )
        elif model_choice.lower() == "bart":
            print("Loading fine-tuned BART model...")
            tokenizer = BartTokenizer.from_pretrained(BART_MODEL_PATH)
            model = BartForConditionalGeneration.from_pretrained(BART_MODEL_PATH).to(device)
            max_input_length = 1024
            max_output_length = 256
            generation_config = GenerationConfig(
                num_beams=4,
                early_stopping=True,
                max_length=max_output_length
            )
            fallback_config = GenerationConfig(
                num_beams=6,
                length_penalty=1.0,
                early_stopping=True,
                max_length=max_output_length
            )
        elif model_choice.lower() == "pegasus_arxiv":
            print(f"Attempting to load fine-tuned Pegasus-arxiv model from: {PEGASUS_ARXIV_MODEL_PATH}")
            if os.path.exists(PEGASUS_ARXIV_MODEL_PATH) and os.path.exists(os.path.join(PEGASUS_ARXIV_MODEL_PATH, "config.json")):
                try:
                    tokenizer = PegasusTokenizer.from_pretrained(PEGASUS_ARXIV_MODEL_PATH)
                    model = PegasusForConditionalGeneration.from_pretrained(PEGASUS_ARXIV_MODEL_PATH).to(device)
                    print("Fine-tuned Pegasus-arxiv model loaded successfully")
                except Exception as e:
                    print(f"Error loading fine-tuned model: {e}")
                    print(f"Falling back to {PEGASUS_ARXIV_FALLBACK_MODEL}")
                    tokenizer = PegasusTokenizer.from_pretrained(PEGASUS_ARXIV_FALLBACK_MODEL)
                    model = PegasusForConditionalGeneration.from_pretrained(PEGASUS_ARXIV_FALLBACK_MODEL).to(device)
            else:
                print(f"Model directory not found at {PEGASUS_ARXIV_MODEL_PATH}. Falling back to {PEGASUS_ARXIV_FALLBACK_MODEL}")
                tokenizer = PegasusTokenizer.from_pretrained(PEGASUS_ARXIV_FALLBACK_MODEL)
                model = PegasusForConditionalGeneration.from_pretrained(PEGASUS_ARXIV_FALLBACK_MODEL).to(device)
            max_input_length = 512
            max_output_length = 128
            generation_config = GenerationConfig(
                num_beams=4,
                early_stopping=True,
                max_length=max_output_length
            )
            fallback_config = GenerationConfig(
                num_beams=6,
                length_penalty=1.0,
                early_stopping=True,
                max_length=max_output_length
            )
        else:
            print("Invalid model choice. Please select 'pegasus', 'bart', or 'pegasus_arxiv'.")
            return None
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

    # Generate query embedding and find similar papers
    print("Generating query embedding...")
    query_embedding = get_embedding(query, specter_tokenizer, specter_model)
    similarities = cosine_similarity(query_embedding, paper_embeddings)
    sorted_indices = np.argsort(similarities[0])[::-1]
    top_papers = [all_papers[i] for i in sorted_indices]

    print(f"Found {len(top_papers)} papers, processing top 5...")
    # Display top 5 paper titles and links
    print("\nTop 5 relevant papers:")
    for i, paper in enumerate(top_papers[:5]):
        print(f"{i+1}. {paper['title']}")
        print(f"   {paper['link']}\n")

    # Process top papers
    print("Extracting relevant content from papers...")
    extraction_results = process_top_papers(query, top_papers, max_papers=5, context_size=2)

    # Concatenate relevant sections
    combined_text = []
    for paper in extraction_results:
        for section in paper['extracted_content']['relevant_sections']:
            for sentence in section['sentences']:
                combined_text.append(sentence['text'])

    input_text = " ".join(combined_text).strip()
    if not input_text:
        print("No relevant content extracted for the query.")
        return None

    # Generate summary
    print("Generating summary...")
    try:
        print(f"Input text length: {len(input_text)} characters")
        inputs = tokenizer(input_text, max_length=max_input_length, truncation=True, padding="max_length", return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        print("Tokenization successful")
    except Exception as e:
        print(f"Error during tokenization: {e}")
        return None

    try:
        # Generate summary with model-specific parameters
        summary_ids = model.generate(
            input_ids=inputs["input_ids"],
            generation_config=generation_config
        )
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    except Exception as e:
        print(f"Error during initial generation: {e}")
        print("Attempting fallback generation...")
        try:
            summary_ids = model.generate(
                input_ids=inputs["input_ids"],
                generation_config=fallback_config
            )
            summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        except Exception as e2:
            print(f"Error during fallback generation: {e2}")
            return None

    # Post-process the summary to clean up any remaining artifacts
    cleaned_summary = postprocess_summary(summary)

    # Check if the summary is too short or contains artifacts
    if len(cleaned_summary.split()) < 20 or any(char in cleaned_summary for char in '/=+-*:'):
        print("Initial summary is too short or contains artifacts. Attempting fallback generation...")
        try:
            summary_ids = model.generate(
                input_ids=inputs["input_ids"],
                generation_config=fallback_config
            )
            summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
            cleaned_summary = postprocess_summary(summary)
        except Exception as e:
            print(f"Error during fallback generation: {e}")
            return None

    return cleaned_summary

def main():
    print("Scientific Paper Query and Summarization System")
    print("==============================================")

    # Import needed libraries if not already imported
    try:
        from google.colab import drive
        drive.mount('/content/drive')
        print("Google Drive mounted successfully!")
    except ImportError:
        print("Not running in Google Colab or drive already mounted.")
    except Exception as e:
        print(f"Error mounting Google Drive: {e}")
        return

    # Get query from user
    query = input("\nEnter your research query: ")
    if not query.strip():
        print("Empty query. Please enter a valid query.")
        return

    # Get model choice from user
    print("\nAvailable models: Pegasus, BART, Pegasus_arxiv")
    model_choice = input("Enter the model to use (Pegasus, BART, or Pegasus_arxiv): ")
    if model_choice.lower() not in ['pegasus', 'bart', 'pegasus_arxiv']:
        print("Invalid model choice. Please select 'Pegasus', 'BART', or 'Pegasus_arxiv'.")
        return

    # Run inference
    print("\nProcessing your query. This may take a few minutes...\n")
    summary = run_inference(query, model_choice)

    if summary:
        print("\n=== SUMMARY ===")
        print(summary)
    else:
        print("\nFailed to generate summary. Please check error messages above.")

if __name__ == "__main__":
    main()



[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


Using device: cpu
Scientific Paper Query and Summarization System
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Google Drive mounted successfully!

Enter your research query: Transformer models for machine translation

Available models: Pegasus, BART, Pegasus_arxiv
Enter the model to use (Pegasus, BART, or Pegasus_arxiv): bart

Processing your query. This may take a few minutes...

Loaded 7068 papers and embeddings
Loading Specter model...
Loading fine-tuned BART model...
Generating query embedding...
Found 7068 papers, processing top 5...

Top 5 relevant papers:
1. Six Challenges for Neural Machine Translation
   http://arxiv.org/abs/1706.03872v1

2. OpenNMT: Open-source Toolkit for Neural Machine Translation
   http://arxiv.org/abs/1709.03815v1

3. Neural Machine Translation
   http://arxiv.org/abs/1709.07809v1

4. SentEval: An Evaluation Toolkit for Universal Sentence Representations
   http://arxiv.o

`generation_config` default values have been modified to match model-specific defaults: {'min_length': 56, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2, 'pad_token_id': 1, 'bos_token_id': 0, 'eos_token_id': 2, 'decoder_start_token_id': 2}. If this is not desired, please set these values explicitly.


Generating summary...
Input text length: 27165 characters
Tokenization successful

=== SUMMARY ===
Research shows that neural machine translation (NMT) systems, like those trained on large data sets, struggle with poor performance on rare word translation. This is because the training data leads these systems to decide on specific word choices during decoding, buried in large matrices of values. NMT systems, however, outperform SMT systems on translation of very infrequent words, particularly those belonging to flected categories. These include Chinese English, German English, and Russian English, where NMT outperforms SMT by a wide margin, particularly on rare words. While NMT improves on SMT on these rare words, it still struggles on rare ones, particularly in domains like Subtitles and German English. This highlights the need to develop better analytics for NMT, particularly for rare words like Chinese English and Russian.
