In [2]:
import json
import logging
import os
import sys
from dataclasses import dataclass, field
from getpass import getpass
from pathlib import Path
from typing import List, Dict, Any

from tqdm.notebook import tqdm

# Import custom modules
# from graph_gen import KnowledgeGraphBuilder, GraphConfig
# from llm_tools import OpenAIClient, OpenAIConfig, create_generate_fn
sys.path.append('')
# from graph_reasoning import *
from graph_reasoning.graph_gen import KnowledgeGraphBuilder, GraphConfig
from graph_reasoning.llm_tools import GeminiClient, GeminiConfig, create_generate_fn

# Logging: root handler at INFO so the notebook's logger.info progress lines are shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# Output directory for generated results.
# NOTE: Path("") is equivalent to Path("."), i.e. the current working directory;
# point this at a dedicated folder before a real run.
OUTPUT_DIR = Path("")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

In [None]:
@dataclass
class ArxivConfig:
    """
    Configuration for ArXiv data processing.

    ``__post_init__`` normalizes the fields so callers may pass plain strings:
    the path fields are converted to ``Path`` and a ``categories`` string is
    split on whitespace into a list. Without this, a bare string would be
    iterated character by character by code expecting a list of categories.
    """
    # JSON-lines metadata file that is read one record per line.
    input_file: Path = Path("data/arxiv-metadata-oai-snapshot.json")
    # Presumably where filtered records should be saved; nothing visible here writes it.
    filtered_file: Path = Path("data/filtered_arxiv.json")
    # ArXiv category tags to keep, e.g. "physics.optics".
    categories: List[str] = field(default_factory=lambda: ["physics.optics"])
    max_papers: int = 100  # Limit for testing

    def __post_init__(self) -> None:
        """Coerce path fields to ``Path`` and ``categories`` to a list of str."""
        self.input_file = Path(self.input_file)
        self.filtered_file = Path(self.filtered_file)
        if isinstance(self.categories, str):
            # "physics.optics cs.CV" -> ["physics.optics", "cs.CV"]; "" -> []
            self.categories = self.categories.split()
        else:
            self.categories = list(self.categories)

# LLM configuration.
# Read the key from the GEMINI_API_KEY environment variable when it is set, so
# "Restart & Run All" can run unattended; otherwise prompt interactively.
# The key is stripped once and the stripped value is both validated and used,
# so stray copy-paste whitespace cannot slip through to the client.
api_key = os.environ.get("GEMINI_API_KEY", "").strip()
if not api_key:
    api_key = getpass("Please enter your API key: ").strip()
if not api_key:
    raise ValueError("API key cannot be empty.")

gemini_config = GeminiConfig(
    api_key=api_key,
    max_tokens=8192,
    temperature=0,
    model_name="gemini-2.0-flash-exp",
    top_k=40,
    top_p=0.95,
)

# Graph generation configuration.
# Instructions describing the expected edge-list JSON output; passed to
# GraphConfig as its system_prompt. Kept as a named constant so it can be
# inspected or tweaked separately from the chunking settings.
GRAPH_SYSTEM_PROMPT = """
    Analyze the scientific abstract and extract key concepts and their relationships.
    Return a JSON object containing edges between concepts. Each edge should have:
    - source: The source concept/entity
    - target: The target concept/entity
    - attributes: Including relationship type and confidence score
    Focus on:
    - Technical terms and concepts
    - Experimental methods and results
    - Theoretical frameworks
    - Cause-effect relationships
    """

graph_config = GraphConfig(
    system_prompt=GRAPH_SYSTEM_PROMPT,
    chunk_size=2000,
    chunk_overlap=200,
)

In [3]:
def _normalize_categories(categories: Any) -> List[str]:
    """Return the wanted categories as a list; a single string is split on whitespace."""
    if isinstance(categories, str):
        return categories.split()
    return list(categories)


def _matches_categories(paper_categories: Any, wanted: List[str]) -> bool:
    """
    Return True if any wanted category matches one of the paper's category tags.

    ``paper_categories`` is assumed to be the record's space-separated category
    string (as in the ArXiv metadata snapshot); a non-string value matches
    nothing. A wanted category matches a tag when it equals the tag or is its
    archive prefix, so "physics" matches "physics.optics" while "cs.C" does
    not match "cs.CV" (raw substring matching would).
    """
    if not isinstance(paper_categories, str):
        return False
    tags = paper_categories.split()
    return any(
        tag == cat or tag.startswith(cat + ".")
        for cat in wanted
        for tag in tags
    )


def load_and_filter_arxiv(config: ArxivConfig) -> List[Dict[str, Any]]:
    """
    Load ArXiv metadata records (one JSON object per line) and keep those in
    the configured categories.

    Args:
        config: Supplies ``input_file`` (the JSON-lines file), ``categories``
            (list of category tags, or a whitespace-separated string) and
            ``max_papers``, the maximum number of *matching* papers to return
            (``None`` means no limit).

    Returns:
        A list of dicts with keys 'id', 'title', 'abstract' and 'categories',
        in file order.

    Raises:
        FileNotFoundError: if ``config.input_file`` does not exist.
    """
    wanted = _normalize_categories(config.categories)
    limit = config.max_papers
    filtered_papers: List[Dict[str, Any]] = []

    logger.info(f"Loading papers from {config.input_file}")
    # Decode explicitly as UTF-8 rather than the locale-dependent default.
    with open(config.input_file, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(tqdm(f), start=1):
            # The cap applies to matching papers: capping the number of lines
            # read instead can yield few or no papers once a filter applies.
            if limit is not None and len(filtered_papers) >= limit:
                break
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline

            try:
                paper = json.loads(line)
            except json.JSONDecodeError:
                logger.warning(f"Could not parse line {line_no}")
                continue
            if not isinstance(paper, dict):
                logger.warning(f"Skipping non-object JSON record on line {line_no}")
                continue

            if _matches_categories(paper.get('categories'), wanted):
                filtered_papers.append({
                    'id': paper.get('id'),
                    'title': paper.get('title'),
                    'abstract': paper.get('abstract'),
                    'categories': paper.get('categories'),
                })

    logger.info(f"Found {len(filtered_papers)} papers in specified categories")
    return filtered_papers

In [4]:
def process_papers(papers: List[Dict[str, Any]], output_dir: Path) -> Dict[str, Any]:
    """
    Generate a knowledge graph for each paper.

    The paper's title and abstract are combined into one text, which is passed
    to ``KnowledgeGraphBuilder.build_graph_from_text`` with graph root
    ``paper_<id>``. The builder uses a Gemini-backed ``generate_fn`` built from
    the module-level ``gemini_config``, plus the module-level ``graph_config``.
    A failure on one paper is logged with its traceback, and processing
    continues with the next paper.

    Args:
        papers: Records with keys 'id', 'title' and 'abstract', e.g. as
            returned by ``load_and_filter_arxiv``.
        output_dir: Passed to ``KnowledgeGraphBuilder`` as ``output_dir``.

    Returns:
        Dict mapping paper id to the ``(graph, embeddings)`` pair returned by
        the builder. Only papers that were processed successfully are included.
    """
    if not papers:
        # Avoid creating an LLM client when there is nothing to do.
        logger.info("No papers to process")
        return {}

    # Initialize AI client and create generate_fn
    ai_client = GeminiClient(gemini_config)
    generate_fn = create_generate_fn(ai_client)

    graph_builder = KnowledgeGraphBuilder(
        config=graph_config,
        output_dir=output_dir
    )

    results: Dict[str, Any] = {}
    for paper in tqdm(papers, desc="Processing papers"):
        # .get() so a malformed record cannot raise inside the error handler.
        paper_id = paper.get('id')
        abstract = (paper.get('abstract') or '').strip()
        if not abstract:
            # Without an abstract the prompt would literally contain "None".
            logger.warning(f"Skipping paper {paper_id}: no abstract")
            continue
        title = (paper.get('title') or '').strip()
        text = f"Title: {title}\n\nAbstract: {abstract}"

        try:
            graph, embeddings = graph_builder.build_graph_from_text(
                text=text,
                generate_fn=generate_fn,
                graph_root=f"paper_{paper_id}"
            )
        except Exception:
            # Best-effort per paper, but keep the traceback for debugging.
            logger.exception(f"Error processing paper {paper_id}")
            continue

        results[paper_id] = (graph, embeddings)
        logger.info(f"Successfully processed paper {paper_id}")

    logger.info(f"Processed {len(results)}/{len(papers)} papers successfully")
    return results

In [None]:
def main():
    """
    Run the pipeline end to end: filter ArXiv metadata, then build graphs.

    Uses the ``ArxivConfig`` defaults for the input file and categories and
    limits the run to 3 matching papers for a quick test. Writes results via
    ``process_papers`` into ``OUTPUT_DIR``.
    """
    # Override only what differs from the defaults. Empty-string placeholders
    # cannot work: open("") fails and categories="" matches no paper.
    config = ArxivConfig(max_papers=3)

    # Load and filter papers
    papers = load_and_filter_arxiv(config)
    if not papers:
        logger.warning("No papers found in the configured categories; nothing to process.")
        return

    # Process papers and generate knowledge graphs
    process_papers(papers, OUTPUT_DIR)

    logger.info("Processing complete. Check the output directory for results.")


if __name__ == "__main__":
    main()