In [1]:
!pip install -q python-dotenv evaluate torch transformers sentence-transformers faiss-cpu nltk gradio

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m115.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m87.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m57.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
!pip install rouge_score bert_score

Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting bert_score
  Downloading bert_score-0.3.13-py3-none-any.whl.metadata (15 kB)
Downloading bert_score-0.3.13-py3-none-any.whl (61 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=c7aaf4ae16df94f97d173168122d00b918025684c00792e81cbea93ad82774d8
  Stored in directory: /root/.cache/pip/wheels/1e/19/43/8a442dc83660ca25e163e1bd1f89919284ab0d0c1475475148
Successfully built rouge_score
Installing collected packages: rouge_score, bert_score
Successfully installed bert_score-0.3.13 rouge_score-0.1.2


In [None]:
import os
import torch
import faiss
import numpy as np
import pandas as pd
import nltk
import requests
import evaluate
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from bs4 import BeautifulSoup
from IPython.display import display, Markdown
from dotenv import load_dotenv
from huggingface_hub import login

# Environment bootstrap: load variables from a .env file (if present), then
# download the NLTK tokenizer resources this module asks for.
load_dotenv()
for _nltk_resource in ('punkt', 'punkt_tab'):
    nltk.download(_nltk_resource)

# Log in to the Hugging Face Hub only when an HF_TOKEN variable is set.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

class LCMConfig:
    """Settings container shared by every component of the pipeline.

    All values are plain attributes assigned in ``__init__``. Two of them can
    be overridden through environment variables: ``ARXIV_MAX_RESULTS`` and
    ``MODEL_CACHE_DIR``.
    """

    def __init__(self):
        # Compute device: CUDA when torch reports one, CPU otherwise.
        cuda_ready = torch.cuda.is_available()
        self.device = "cuda" if cuda_ready else "cpu"

        # Sizes of the retrieval index (embd_dim) and of the chunk-level
        # transformer encoder / concept selector (dim, layers, heads, dropout).
        self.embd_dim = 768
        self.dim = 768
        self.layers = 6
        self.heads = 12
        self.dropout = 0.05

        # Chunking budget (in tokenizer tokens), summary length (new tokens)
        # and the upper bound on selected key concepts.
        self.chunk_size = 768
        self.summary_length = 250
        self.num_key_concepts = 7

        # Model identifiers: tokenizer used for chunk sizing, and the
        # sentence-embedding model.
        self.segmenter_model = "bert-large-uncased"
        self.embedding_model = "all-mpnet-base-v2"

        # Text-generation settings.
        self.temperature = 0.5
        self.num_beams = 4

        # arXiv API endpoint and how many papers to fetch per query.
        self.arxiv_api_url = "http://export.arxiv.org/api/query?"
        max_results = os.getenv("ARXIV_MAX_RESULTS", 5)
        self.max_retrieved_papers = int(max_results)

        # LoRA adapter hyper-parameters.
        self.lora_rank = 8
        self.lora_alpha = 16
        self.lora_dropout = 0.1

        # Directory for cached models (not read by the code in this cell).
        self.model_cache_dir = os.getenv("MODEL_CACHE_DIR", "./models")

# Pipeline components defined below: DocumentProcessor, ConceptSelector,
# FAISSRetriever, ArXivSearch and RAGSummarizer, followed by the helper
# functions evaluate_summaries, display_results and run_search_pipeline.



class DocumentProcessor:
    """Splits raw document text into sentence-aligned chunks."""

    def __init__(self, config):
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(config.segmenter_model)

    def chunk_document(self, text):
        """Group the sentences of ``text`` into chunks bounded by a token budget.

        Sentences come from ``nltk.sent_tokenize`` and their size is measured
        with ``self.tokenizer.tokenize``. A chunk is closed either just before
        a sentence that would push it past ``config.chunk_size`` tokens, or
        right after it grows beyond 75% of that budget.

        Returns:
            list[str]: chunks in document order, each one the space-joined
            sentences it contains.
        """
        token_budget = self.config.chunk_size
        soft_limit = token_budget * 0.75

        chunks = []
        pending, pending_tokens = [], 0

        for sentence in nltk.sent_tokenize(text):
            n_tokens = len(self.tokenizer.tokenize(sentence))

            # Close the open chunk first if this sentence would overflow it.
            if pending and pending_tokens + n_tokens > token_budget:
                chunks.append(" ".join(pending))
                pending, pending_tokens = [], 0

            pending.append(sentence)
            pending_tokens += n_tokens

            # Close early once the chunk is past three quarters of the budget.
            if pending_tokens > soft_limit:
                chunks.append(" ".join(pending))
                pending, pending_tokens = [], 0

        # Whatever is still open becomes the final chunk.
        if pending:
            chunks.append(" ".join(pending))

        return chunks

class ConceptSelector(nn.Module):
    """Scores every position of an embedding sequence with one scalar.

    A 4-head self-attention layer is followed by a small feed-forward scorer
    (LayerNorm -> Linear -> GELU -> Linear) that maps each attended vector to
    a single value.
    """

    def __init__(self, dim):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim=dim, num_heads=4, batch_first=True)

        hidden_dim = dim * 2
        scoring_layers = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.scorer = nn.Sequential(*scoring_layers)

    def forward(self, embeddings):
        """Return one score per sequence position (the trailing size-1 axis is squeezed)."""
        # Self-attention: the same tensor serves as query, key and value.
        attended = self.attention(embeddings, embeddings, embeddings)[0]
        scores = self.scorer(attended)
        return scores.squeeze(-1)

class FAISSRetriever:
    """Inner-product FAISS index paired with a parallel list of metadata.

    ``metadata[i]`` belongs to the i-th vector added to ``index``.
    """

    def __init__(self, dim):
        self.index = faiss.IndexFlatIP(dim)
        self.metadata = []

    @staticmethod
    def _as_row(vector):
        """Return ``vector`` as a contiguous float32 array of shape (1, dim).

        The conversion is applied to every input, including ndarrays of
        another dtype (e.g. float64), because FAISS only accepts float32.
        """
        return np.ascontiguousarray(np.asarray(vector, dtype=np.float32).reshape(1, -1))

    def add(self, embedding, data):
        """Store one embedding together with its metadata record."""
        self.index.add(self._as_row(embedding))
        self.metadata.append(data)

    def search(self, query_embedding, k=5):
        """Return up to ``k`` nearest records and their scores.

        Returns:
            tuple: ``(records, scores)`` where ``records`` is a list of
            metadata entries and ``scores`` the matching 1-D array of inner
            products. Both are shorter than ``k`` when the index holds fewer
            than ``k`` vectors, and empty when nothing has been added.
        """
        if not self.metadata:
            return [], np.empty(0, dtype=np.float32)

        distances, indices = self.index.search(self._as_row(query_embedding), k)

        # FAISS pads missing results with label -1; indexing the metadata list
        # with -1 would silently return the last record as a bogus hit.
        found = indices[0] >= 0
        records = [self.metadata[i] for i in indices[0][found]]
        return records, distances[0][found]

class ArXivSearch:
    """Queries the arXiv export API and looks papers up in a local metadata snapshot."""

    # Default location of the Kaggle arXiv metadata snapshot (one JSON object per line).
    KAGGLE_METADATA_PATH = '/kaggle/input/arxiv/arxiv-metadata-oai-snapshot.json'
    # Upper bound, in seconds, for one HTTP request to the arXiv API.
    REQUEST_TIMEOUT_SECONDS = 30

    def __init__(self, config):
        self.config = config
        self.embedding_model = SentenceTransformer(config.embedding_model, device=config.device)
        # Snapshot frames already read from disk, keyed by (path, nrows).
        self._metadata_cache = {}

    def search_papers(self, query, max_results=5):
        """Search ArXiv papers using their API.

        The query is sent with a ``ti:`` prefix and results are requested in
        descending relevance order.

        Returns:
            list[dict]: one record per entry, see ``_parse_arxiv_response``.

        Raises:
            Exception: when the API answers with a non-200 status code.
            requests.exceptions.RequestException: on network errors or timeout.
        """
        params = {
            "search_query": f"ti:{query}",
            "start": 0,
            "max_results": max_results,
            "sortBy": "relevance",
            "sortOrder": "descending"
        }

        # A timeout keeps a stalled connection from blocking the caller forever.
        response = requests.get(
            self.config.arxiv_api_url,
            params=params,
            timeout=self.REQUEST_TIMEOUT_SECONDS,
        )
        if response.status_code != 200:
            raise Exception(f"ArXiv API request failed with status {response.status_code}")

        return self._parse_arxiv_response(response.text)

    @staticmethod
    def _extract_arxiv_id(entry_id):
        """Return the identifier part of an entry URL such as ``http://arxiv.org/abs/<id>``.

        Everything after ``/abs/`` is kept so that old-style identifiers
        (``hep-th/9901001v1``) retain their category prefix. Without that
        marker the last path component is used.
        """
        entry_id = entry_id.strip()
        _, marker, identifier = entry_id.partition('/abs/')
        if marker:
            return identifier
        return entry_id.split('/')[-1]

    @staticmethod
    def _strip_version(arxiv_id):
        """Drop a trailing version suffix: ``2101.12345v2`` -> ``2101.12345``."""
        base, marker, version = arxiv_id.rpartition('v')
        if marker and base and version.isdigit():
            return base
        return arxiv_id

    def _parse_arxiv_response(self, xml_response):
        """Parse ArXiv API XML response into a list of paper dictionaries.

        Each dictionary has the keys ``title``, ``authors``, ``abstract``,
        ``published``, ``updated``, ``arxiv_id`` and ``pdf_url`` (``None``
        when the entry has no link titled ``pdf``).
        """
        soup = BeautifulSoup(xml_response, 'xml')

        papers = []
        for entry in soup.find_all('entry'):
            paper = {
                "title": entry.title.text.strip(),
                "authors": [author.find('name').text for author in entry.find_all('author')],
                "abstract": entry.summary.text.strip(),
                "published": entry.published.text,
                "updated": entry.updated.text,
                "arxiv_id": self._extract_arxiv_id(entry.id.text),
                "pdf_url": None
            }

            # Find PDF link
            for link in entry.find_all('link'):
                if link.get('title') == 'pdf':
                    paper['pdf_url'] = link.get('href')
                    break

            papers.append(paper)

        return papers

    def _load_metadata(self, path, nrows):
        """Read the first ``nrows`` records of the JSON-lines snapshot, once per (path, nrows)."""
        cache = self.__dict__.setdefault('_metadata_cache', {})
        key = (path, nrows)
        if key not in cache:
            # The id column must stay a string: otherwise pandas parses ids such
            # as "0704.0001" as floats and no lookup can ever match.
            cache[key] = pd.read_json(path, lines=True, nrows=nrows, dtype={'id': str})
        return cache[key]

    def fetch_paper_text(self, arxiv_id, dataset_path=None, nrows=100000):
        """Look a paper up in the local metadata snapshot.

        The snapshot only carries metadata, so the returned text is the
        abstract followed by the title, used as a proxy for the full text.

        Args:
            arxiv_id: identifier as produced by ``search_papers``; a trailing
                version suffix (``v1``, ``v2``, ...) is ignored for matching.
            dataset_path: JSON-lines snapshot to read; defaults to
                ``KAGGLE_METADATA_PATH``.
            nrows: number of leading records to load from the snapshot.

        Returns:
            str | None: ``"<abstract> <title>"``, or ``None`` when the snapshot
            is missing or unreadable, or the paper is not among the loaded rows.
        """
        path = dataset_path or self.KAGGLE_METADATA_PATH
        if not os.path.exists(path):
            return None

        try:
            frame = self._load_metadata(path, nrows)
            matches = frame[frame['id'] == self._strip_version(arxiv_id)]
            if matches.empty:
                return None
            paper = matches.iloc[0]
            return paper['abstract'] + " " + paper.get('title', '')  # Using abstract as proxy for full text
        except (OSError, ValueError, KeyError, TypeError):
            # Best-effort lookup: an unreadable or malformed snapshot makes the
            # caller fall back to the abstract returned by the API.
            return None

class RAGSummarizer(nn.Module):
    """Summarizes arXiv papers from a handful of selected text chunks.

    Pipeline, as wired below: the document is chunked, every chunk is embedded
    with a sentence-embedding model, a transformer encoder plus
    ``ConceptSelector`` score the chunks, the top-scoring chunks are placed in
    a prompt, and a LoRA-wrapped causal language model writes the summary.
    Each summary is also stored in a FAISS index together with the document
    embedding.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.processor = DocumentProcessor(config)
        self.embedding_model = SentenceTransformer(config.embedding_model, device=config.device)
        self.arxiv_search = ArXivSearch(config)

        # Initialize the base model with LoRA
        self.base_model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-1B",
            torch_dtype=torch.float16,
            device_map="auto"
        )

        # Configure LoRA
        lora_config = LoraConfig(
            r=config.lora_rank,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            target_modules=["q_proj", "v_proj"],
            bias="none",
            task_type="CAUSAL_LM"
        )

        # Apply LoRA to the base model
        self.model = get_peft_model(self.base_model, lora_config)
        self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

        # Encoder applied across the sequence of chunk embeddings.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.dim,
                nhead=config.heads,
                dropout=config.dropout,
                batch_first=True,
                norm_first=True
            ),
            num_layers=config.layers
        )

        self.concept_selector = ConceptSelector(config.dim)
        self.retriever = FAISSRetriever(config.embd_dim)

    def encode_chunks(self, chunks):
        """Embed text chunks as normalized tensors on the configured device."""
        return self.embedding_model.encode(
            chunks,
            convert_to_tensor=True,
            device=self.config.device,
            normalize_embeddings=True
        )

    def forward(self, document_text):
        """Chunk a document, score the chunks and pick the key ones.

        Returns:
            dict: ``chunks`` (all chunks), ``key_concepts`` (selected chunks,
            highest score first) and ``document_embedding`` (mean of the
            encoded chunk vectors as a NumPy array, or ``None`` when the
            document produced no chunks).
        """
        chunks = self.processor.chunk_document(document_text)
        if not chunks:
            return {"chunks": [], "key_concepts": [], "document_embedding": None}

        chunk_embeds = self.encode_chunks(chunks)
        encoded = self.transformer(chunk_embeds.unsqueeze(0))
        concept_scores = self.concept_selector(encoded)

        # Keep the chunks whose sigmoid score exceeds 0.65, but at least 3 and
        # never more than num_key_concepts or the number of chunks available.
        score_dist = torch.sigmoid(concept_scores)
        top_k = min(max(int((score_dist > 0.65).sum().item()), 3),
                   self.config.num_key_concepts, concept_scores.shape[-1])
        top_indices = torch.topk(concept_scores, k=top_k, dim=-1).indices

        return {
            "chunks": chunks,
            "key_concepts": [chunks[i] for i in top_indices[0].tolist()],
            "document_embedding": encoded.mean(dim=1).squeeze().cpu().detach().numpy()
        }

    def generate_summary(self, text):
        """Summarize ``text``, index the result and score it against the input.

        Returns:
            dict: ``summary``, ``key_concepts``, ``embedding`` and ``scores``.
            For a document without any chunk the summary is empty, the
            embedding is ``None`` and the scores are an empty dict.
        """
        processed = self.forward(text)
        if not processed["key_concepts"]:
            # Same keys as the regular return value so callers can read them safely.
            return {"summary": "", "key_concepts": [], "embedding": None, "scores": {}}

        summary = self.generate_structured_summary(
            processed["key_concepts"],
            max_tokens=self.config.summary_length,
            temperature=self.config.temperature,
            num_beams=self.config.num_beams
        )

        self.retriever.add(processed["document_embedding"], {
            "summary": summary,
            "concepts": processed["key_concepts"],
            "original_text": text[:1000] + "..." if len(text) > 1000 else text
        })

        return {
            "summary": summary,
            "key_concepts": processed["key_concepts"],
            "embedding": processed["document_embedding"],
            "scores": self.evaluate_summary(summary, text)
        }

    def generate_structured_summary(self, concepts, max_tokens=250, temperature=0.5, num_beams=4):
        """Prompt the language model with the key concepts and return its continuation.

        Args:
            concepts: text snippets; each one is cut to 512 characters.
            max_tokens: maximum number of newly generated tokens.
            temperature: sampling temperature passed to ``generate``.
            num_beams: beam count passed to ``generate``.

        Returns:
            str: the generated text only (the prompt is not included), stripped.
        """
        concepts = [c[:512] for c in concepts]
        concept_lines = chr(10).join(f'- {c}' for c in concepts)
        prompt = f"""Generate a technical paper summary using these key points:

Concepts:
{concept_lines}

Structure your summary with:
1. Research objective and problem statement
2. Methodology and technical approach
3. Key findings and results
4. Implications and future work

Ensure:
- Technical accuracy
- Clear logical flow between paragraphs
- Proper technical terminology
- Concise but comprehensive coverage

Summary:"""

        inputs = self.tokenizer(prompt, return_tensors="pt", max_length=4096, truncation=True).to(self.model.device)

        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=0.95,
            num_beams=num_beams,
            repetition_penalty=1.15,
            do_sample=True,
            early_stopping=True
        )

        # The returned sequence starts with the prompt tokens. Decode only what
        # follows them: splitting the full text on "Summary:" would cut the
        # result short whenever the model itself writes "Summary:" again.
        prompt_length = inputs["input_ids"].shape[-1]
        generated_tokens = outputs[0][prompt_length:]
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    def search_and_summarize(self, query):
        """Search ArXiv for papers and generate summaries.

        Returns:
            list[dict]: one ``generate_summary`` result per paper, extended
            with the paper's metadata under ``paper_info``.
        """
        papers = self.arxiv_search.search_papers(query, self.config.max_retrieved_papers)

        results = []
        for paper in papers:
            # Try to get full text from Kaggle dataset
            full_text = self.arxiv_search.fetch_paper_text(paper['arxiv_id'])
            if full_text is None:
                full_text = paper['abstract']  # Fall back to abstract

            result = self.generate_summary(full_text)
            result['paper_info'] = paper
            results.append(result)

        return results

    def evaluate_summary(self, summary, reference):
        """Evaluate a single summary against reference text"""
        return evaluate_summaries([summary], [reference])

    def print_trainable_parameters(self):
        """Prints the number of trainable parameters in the model."""
        trainable_params = 0
        all_param = 0
        for _, param in self.model.named_parameters():
            all_param += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
        # Guard the ratio so a model without parameters does not divide by zero.
        trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
        print(
            f"Trainable params: {trainable_params:,} || "
            f"All params: {all_param:,} || "
            f"Trainable%: {trainable_pct:.2f}%"
        )

# Metric objects returned by `evaluate.load`, keyed by metric name, so that
# repeated calls reuse them instead of loading them again for every summary.
_METRIC_CACHE = {}


def _load_metric(name):
    """Return the `evaluate` metric called ``name``, loading it on first use only."""
    if name not in _METRIC_CACHE:
        _METRIC_CACHE[name] = evaluate.load(name)
    return _METRIC_CACHE[name]


def evaluate_summaries(predictions, references):
    """Score generated summaries against reference texts.

    Args:
        predictions: list of generated summaries.
        references: list of reference texts, aligned with ``predictions``.

    Returns:
        dict: ``rouge1``, ``rouge2``, ``rougeL``, the mean BERTScore
        ``bert_precision`` / ``bert_recall`` / ``bert_f1``, and
        ``content_preservation`` (mean share of the reference's
        whitespace-token trigrams that also occur in the prediction).
        Every value is ``0.0`` when no predictions are given.

    Raises:
        ValueError: if the two lists differ in length.
    """
    if len(predictions) != len(references):
        raise ValueError(
            f"predictions and references differ in length: {len(predictions)} vs {len(references)}"
        )
    if not predictions:
        # Nothing to score; avoids averaging over empty lists.
        return dict.fromkeys(
            ("rouge1", "rouge2", "rougeL", "bert_precision", "bert_recall",
             "bert_f1", "content_preservation"),
            0.0,
        )

    rouge = _load_metric("rouge")
    bertscore = _load_metric("bertscore")

    rouge_scores = rouge.compute(predictions=predictions, references=references, use_stemmer=True)
    bert_scores = bertscore.compute(predictions=predictions, references=references, lang="en", model_type="roberta-large")

    content_sim = []
    for pred, ref in zip(predictions, references):
        pred_ngrams = set(nltk.ngrams(pred.split(), 3))
        ref_ngrams = set(nltk.ngrams(ref.split(), 3))
        # References shorter than three tokens have no trigram: score them 0.
        overlap = len(pred_ngrams & ref_ngrams) / len(ref_ngrams) if ref_ngrams else 0
        content_sim.append(overlap)

    return {
        "rouge1": rouge_scores["rouge1"],
        "rouge2": rouge_scores["rouge2"],
        "rougeL": rouge_scores["rougeL"],
        "bert_precision": np.mean(bert_scores["precision"]),
        "bert_recall": np.mean(bert_scores["recall"]),
        "bert_f1": np.mean(bert_scores["f1"]),
        "content_preservation": np.mean(content_sim)
    }

def display_results(results):
    """Display search and summary results in notebook-friendly format.

    Args:
        results: list of result dicts, each with ``paper_info`` (``title``,
            ``authors``, ``published``, ``pdf_url``), ``summary`` and,
            optionally, ``scores``.

    A missing PDF link or missing evaluation scores are reported in the
    output instead of raising or rendering a link to "None".
    """
    score_keys = ('rouge1', 'rouge2', 'rougeL', 'bert_f1', 'content_preservation')

    for i, result in enumerate(results):
        info = result['paper_info']
        display(Markdown(f"### Paper {i+1}: {info['title']}"))
        display(Markdown(f"**Authors**: {', '.join(info['authors'])}"))
        display(Markdown(f"**Published**: {info['published']}"))

        pdf_url = info.get('pdf_url')
        if pdf_url:
            display(Markdown(f"**[PDF Link]({pdf_url})**"))
        else:
            display(Markdown("**PDF Link**: not available"))

        display(Markdown("#### Summary"))
        display(Markdown(result['summary']))

        display(Markdown("#### Evaluation Scores"))
        # Results for documents that produced no summary may carry no scores.
        scores = result.get('scores') or {}
        if all(key in scores for key in score_keys):
            display(Markdown(f"""
- ROUGE-1: {scores['rouge1']:.3f}
- ROUGE-2: {scores['rouge2']:.3f}
- ROUGE-L: {scores['rougeL']:.3f}
- BERT F1: {scores['bert_f1']:.3f}
- Content Preservation: {scores['content_preservation']:.3f}
        """))
        else:
            display(Markdown("_No evaluation scores available._"))

        display(Markdown("---"))



# Holds the single RAGSummarizer built by run_search_pipeline. Constructing it
# loads several models, so it is created once and reused by later calls.
_SUMMARIZER_CACHE = {}


def run_search_pipeline(query="AI", num_papers=None):
    """Search arXiv for ``query`` and summarize the papers found.

    Args:
        query: search string forwarded to the summarizer.
        num_papers: optional number of papers to retrieve; when omitted the
            ``LCMConfig`` default is used. Converted to ``int`` because UI
            widgets may hand over floats.

    Returns:
        list: whatever ``RAGSummarizer.search_and_summarize`` returns.
    """
    config = LCMConfig()
    if num_papers:
        config.max_retrieved_papers = int(num_papers)

    summarizer = _SUMMARIZER_CACHE.get("summarizer")
    if summarizer is None:
        summarizer = RAGSummarizer(config).to(config.device)
        summarizer.eval()
        _SUMMARIZER_CACHE["summarizer"] = summarizer
    else:
        # Reused instance: only the per-call paper limit needs refreshing.
        summarizer.config.max_retrieved_papers = config.max_retrieved_papers

    results = summarizer.search_and_summarize(query)
    return results

if __name__ == "__main__":
    # Smoke run: summarize the papers found for one fixed query and print,
    # for each of them, the title and the first 200 characters of the summary.
    test_results = run_search_pipeline("quantum computing")
    for paper_number, result in enumerate(test_results, start=1):
        title = result['paper_info']['title']
        preview = result['summary'][:200]
        print(f"\nPaper {paper_number}: {title}")
        print(f"Summary: {preview}...")




In [None]:
from rag_pipeline import run_search_pipeline
import gradio as gr
import traceback

def _escape_html(value):
    """Return ``value`` as text that is safe inside HTML content and attribute values."""
    import html  # function-scope import: not provided by this cell's import lines
    return html.escape(str(value), quote=True)


def _render_paper_card(index, result):
    """Build the HTML card for one summarized paper.

    Every value taken from the paper metadata or produced by the model is
    escaped before being placed in the markup.
    """
    info = result['paper_info']

    # One hue per card, stepping 60 degrees with the card index.
    hue = (index * 60) % 360
    bg_color = f"hsla({hue}, 70%, 95%, 1)"
    accent_color = f"hsla({hue}, 70%, 40%, 1)"

    title = _escape_html(info['title'])
    authors = info['authors']
    author_line = _escape_html(', '.join(authors[:3])) + (', et al.' if len(authors) > 3 else '')
    # A paper without PDF link gets an inert href instead of the text "None".
    pdf_url = _escape_html(info.get('pdf_url') or '#')
    published = _escape_html(info['published'][:10])
    summary = _escape_html(result['summary'])

    # At most five concepts are shown, each cut to 150 characters.
    concept_cards = "".join(f'''
                            <div style="
                                background-color: #f5f5f5;
                                border-left: 3px solid {accent_color};
                                padding: 10px;
                                border-radius: 6px;
                                font-size: 0.9em;
                                color: #555;
                            ">
                                {_escape_html(concept[:150])}{'...' if len(concept) > 150 else ''}
                            </div>
                            ''' for concept in result['key_concepts'][:5])

    return f"""
            <div class="paper-card" style="
                border: 1px solid #ddd;
                border-radius: 12px;
                padding: 0;
                margin-bottom: 25px;
                box-shadow: 0 4px 12px rgba(0,0,0,0.1);
                font-family: system-ui, -apple-system, 'Segoe UI', Roboto, sans-serif;
                max-width: 100%;
                background-color: #ffffff;
                overflow: hidden;
                transition: transform 0.2s, box-shadow 0.2s;
            "
            onmouseover="this.style.transform='translateY(-5px)';this.style.boxShadow='0 8px 24px rgba(0,0,0,0.15)';"
            onmouseout="this.style.transform='translateY(0)';this.style.boxShadow='0 4px 12px rgba(0,0,0,0.1)';"
            >
                <div style="
                    background-color: {bg_color};
                    padding: 16px 20px;
                    border-bottom: 1px solid #eee;
                ">
                    <h2 style="
                        font-size: 1.4em;
                        color: #333;
                        margin: 0 0 10px 0;
                        line-height: 1.3;
                        font-weight: 600;
                    ">{title}</h2>
                    <p style="
                        color: #555;
                        margin: 0;
                        font-size: 0.95em;
                    "><strong>Authors:</strong> {author_line}</p>
                </div>

                <div style="padding: 20px;">
                    <div style="display: flex; flex-wrap: wrap; gap: 10px; margin-bottom: 20px;">
                        <a href="{pdf_url}" target="_blank" style="
                            display: inline-flex;
                            align-items: center;
                            gap: 6px;
                            padding: 8px 16px;
                            background-color: {accent_color};
                            color: white;
                            text-decoration: none;
                            border-radius: 6px;
                            font-size: 14px;
                            font-weight: 500;
                            transition: opacity 0.2s;
                        " onmouseover="this.style.opacity='0.9';" onmouseout="this.style.opacity='1';">
                            <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                                <path d="M14 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V8z"></path>
                                <polyline points="14 2 14 8 20 8"></polyline>
                                <line x1="16" y1="13" x2="8" y2="13"></line>
                                <line x1="16" y1="17" x2="8" y2="17"></line>
                                <polyline points="10 9 9 9 8 9"></polyline>
                            </svg>
                            View PDF
                        </a>
                        <div style="
                            display: inline-flex;
                            align-items: center;
                            gap: 6px;
                            padding: 8px 16px;
                            background-color: #f0f0f0;
                            color: #666;
                            border-radius: 6px;
                            font-size: 14px;
                        ">
                            <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                                <rect x="3" y="4" width="18" height="18" rx="2" ry="2"></rect>
                                <line x1="16" y1="2" x2="16" y2="6"></line>
                                <line x1="8" y1="2" x2="8" y2="6"></line>
                                <line x1="3" y1="10" x2="21" y2="10"></line>
                            </svg>
                            {published}
                        </div>
                    </div>

                    <div style="margin-top: 20px;">
                        <h3 style="
                            font-size: 1.1em;
                            color: #444;
                            margin: 0 0 10px 0;
                            padding-bottom: 8px;
                            border-bottom: 2px solid {bg_color};
                            display: inline-block;
                        ">Summary</h3>
                        <div style="
                            background-color: #f9f9f9;
                            padding: 16px;
                            border-radius: 8px;
                            color: #333;
                            line-height: 1.5;
                            font-size: 0.95em;
                            margin-bottom: 20px;
                        ">
                            {summary}
                        </div>
                    </div>

                    <div class="key-concepts">
                        <h3 style="
                            font-size: 1.1em;
                            color: #444;
                            margin: 0 0 10px 0;
                            padding-bottom: 8px;
                            border-bottom: 2px solid {bg_color};
                            display: inline-block;
                        ">Key Concepts</h3>
                        <div class="concepts-grid" style="
                            display: grid;
                            grid-template-columns: repeat(auto-fill, minmax(250px, 1fr));
                            gap: 10px;
                        ">
                            {concept_cards}
                        </div>
                    </div>
                </div>
            </div>
            """


def _render_error(error, details):
    """Build the HTML error panel from an exception and its formatted traceback."""
    return f"""
        <div style="
            border: 1px solid #f8d7da;
            background-color: #fff5f5;
            color: #721c24;
            padding: 20px;
            border-radius: 8px;
            margin: 20px 0;
            font-family: system-ui, -apple-system, 'Segoe UI', Roboto, sans-serif;
        ">
            <h3 style="margin-top: 0;">⚠️ Error Occurred</h3>
            <p>{_escape_html(error)}</p>
            <details>
                <summary style="cursor: pointer; color: #721c24; font-weight: bold; margin: 10px 0;">
                    Show technical details
                </summary>
                <pre style="
                    background-color: #f8f8f8;
                    padding: 15px;
                    border-radius: 5px;
                    overflow-x: auto;
                    color: #333;
                    font-size: 0.9em;
                ">{_escape_html(details)}</pre>
            </details>
        </div>
        """


def gradio_search(query, num_papers=5):
    """Run the search pipeline and render its results as an HTML string.

    Args:
        query: research topic to search for.
        num_papers: requested number of papers; coerced to an integer and
            clamped to the range 1..5.

    Returns:
        str: one card per paper (preceded by a small responsive stylesheet),
        a short notice when no paper was found, or an error panel with the
        traceback when anything in the pipeline raised.
    """
    try:
        # Limit num_papers to the range 1..5
        num_papers = max(1, min(int(num_papers), 5))

        results = run_search_pipeline(query, num_papers=num_papers)
        cards = [_render_paper_card(index, result) for index, result in enumerate(results)]

        if not cards:
            return f"""<div style="text-align: center; color: #666; padding: 30px;">
                No papers found for "{_escape_html(query)}".
            </div>"""

        # Add CSS for responsiveness
        responsive_css = """
        <style>
            @media (max-width: 768px) {
                .concepts-grid {
                    grid-template-columns: 1fr !important;
                }

                .paper-card h2 {
                    font-size: 1.2em !important;
                }

                .paper-card {
                    padding: 0 !important;
                }
            }
        </style>
        """

        return responsive_css + "".join(cards)

    except Exception as e:
        return _render_error(e, traceback.format_exc())

# Custom CSS for Gradio interface, passed to gr.Blocks(css=...) below.
# It caps and centers the page width, removes the main container padding,
# defines the `pulse` opacity animation applied through the `loading` class
# (the class is set on the loading placeholder HTML), and adds padding on
# viewports up to 768px wide.
custom_css = """
.gradio-container {
    max-width: 900px !important;
    margin-left: auto !important;
    margin-right: auto !important;
}

.main-container {
    padding: 0 !important;
}

/* Loading animation */
@keyframes pulse {
    0% { opacity: 0.6; }
    50% { opacity: 1; }
    100% { opacity: 0.6; }
}

.loading {
    animation: pulse 1.5s infinite;
}

/* Responsive design */
@media (max-width: 768px) {
    .gradio-container {
        padding: 10px !important;
    }
}
"""

if __name__ == "__main__":
    # Static HTML shown in the results area before the first search.
    placeholder_html = """<div style="text-align: center; color: #666; padding: 30px;">
                Enter a research topic above and click "Search Papers" to get started
            </div>"""

    # Static HTML shown in the results area while a search is running.
    loading_html = """<div style="text-align: center; padding: 40px; color: #666;" class="loading">
                    <svg xmlns="http://www.w3.org/2000/svg" width="40" height="40" viewBox="0 0 24 24" fill="none"
                         stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"
                         style="margin: 0 auto 20px auto; display: block;">
                        <path d="M21 12a9 9 0 1 1-6.219-8.56"></path>
                    </svg>
                    <p>Searching and analyzing research papers...</p>
                    <p style="font-size: 0.85em;">This may take a minute</p>
                </div>"""

    # Example inputs offered below the results area: [topic, number of papers].
    example_rows = [
        ["large language models", 3],
        ["quantum computing", 2]
    ]

    with gr.Blocks(css=custom_css) as iface:
        # Page header
        gr.Markdown(
            """
            # 📚 ArXiv IntelliSearch 📚

            Get AI-generated summaries of the latest research papers from ArXiv.
            Enter a topic and let our model find and summarize relevant papers.
            """
        )

        # Inputs: topic text box and paper-count slider side by side
        with gr.Row():
            topic_box = gr.Textbox(
                label="Research Topic",
                placeholder="Enter a topic (e.g., quantum mechanics, AI, climate change)",
                lines=1
            )
            paper_slider = gr.Slider(
                minimum=1,
                maximum=5,
                value=3,
                step=1,
                label="Number of Papers (max 5)"
            )

        search_btn = gr.Button("Search Papers", variant="primary")

        # Output area
        results_area = gr.HTML(label="Results", value=placeholder_html)

        gr.Examples(examples=example_rows, inputs=[topic_box, paper_slider])

        # On click: first swap in the loading placeholder, then run the search
        # and replace the placeholder with the rendered results.
        click_event = search_btn.click(fn=lambda: loading_html, outputs=results_area)
        click_event.then(
            fn=gradio_search,
            inputs=[topic_box, paper_slider],
            outputs=results_area
        )

        gr.Markdown(
            """
            ### How it works

            This tool uses:
            1. A specialized RAG pipeline to find relevant papers
            2. LLM-based key concept extraction
            3. Advanced summarization techniques to create concise research summaries

            *Note: Summaries are AI-generated and should be used as a starting point for further research.*
            """
        )

    # Launch the interface
    iface.launch(server_port=5000, share=True)

In [None]:
!python gradio_app.py

2025-06-30 05:33:05.505269: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751261585.538644    1333 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751261585.557435    1333 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-30 05:33:05.615292: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data]