In [1]:
import pandas as pd

In [ ]:
data_path = "../data/raw/arxiv_sample.csv"

In [3]:
import kagglehub

import os
import shutil
from pathlib import Path

def path2correct_loc(source_path, destination_path, copy_instead_of_move=False):
    """
    Move or copy all files from source path to destination path.
    
    Args:
        source_path (str): Path to the source files
        destination_path (str): Relative path from current working directory where to move the files
        copy_instead_of_move (bool): If True, copy files instead of moving them
    
    Returns:
        str: Absolute path to the destination directory
    """
    
    # Convert to Path objects for easier handling
    source = Path(source_path)
    # Make destination relative to current working directory
    dest = Path.cwd() / destination_path
    
    # Check if source exists
    if not source.exists():
        raise FileNotFoundError(f"Source path does not exist: {source_path}")
    
    # Create destination directory if it doesn't exist
    dest.mkdir(parents=True, exist_ok=True)
    
    # Get all files in source directory (including subdirectories)
    files_moved = 0
    total_size = 0
    
    print(f"{'Copying' if copy_instead_of_move else 'Moving'} files from {source} to {dest}")
    
    # Walk through all files and subdirectories
    for item in source.rglob('*'):
        if item.is_file():
            # Calculate relative path to preserve directory structure
            relative_path = item.relative_to(source)
            dest_file = dest / relative_path
            
            # Create subdirectories if needed
            dest_file.parent.mkdir(parents=True, exist_ok=True)
            
            # Move or copy the file
            try:
                if copy_instead_of_move:
                    shutil.copy2(item, dest_file)
                    action = "Copied"
                else:
                    shutil.move(str(item), str(dest_file))
                    action = "Moved"
                
                file_size = dest_file.stat().st_size
                total_size += file_size
                files_moved += 1
                
                print(f"{action}: {relative_path} ({file_size / (1024*1024):.2f} MB)")
                
            except Exception as e:
                print(f"Error processing {item}: {e}")
    
    print(f"\nCompleted! {files_moved} files {'copied' if copy_instead_of_move else 'moved'}")
    print(f"Total size: {total_size / (1024*1024*1024):.2f} GB")
    print(f"Files are now in: {dest.absolute()}")
    
    return str(dest.absolute())



path = kagglehub.dataset_download("Cornell-University/arxiv", )

new_location = path2correct_loc(path, "")


Downloading from https://www.kaggle.com/api/v1/datasets/download/Cornell-University/arxiv?dataset_version_number=234...


100%|██████████| 1.44G/1.44G [00:56<00:00, 27.2MB/s]

Extracting files...





Path to dataset files: /Users/joshash/.cache/kagglehub/datasets/Cornell-University/arxiv/versions/234


In [18]:
# jsons_fn = "arxiv-metadata-oai-snapshot.json"
# 
# import json
# import random
# import json
# import random
# from tqdm import tqdm
# 
# def sample_jsons(filename, n, method='first', seed=42):
#     """
#     Sample N entries from a JSON lines file.
#     
#     Args:
#         filename (str): Path to the JSON lines file
#         n (int): Number of entries to sample
#         method (str): Sampling method - 'first', 'random', or 'last'
#         seed (int): Random seed for reproducible random sampling
#     
#     Returns:
#         list: List of parsed JSON objects
#     """
#     
#     if method == 'first':
#         # Simple: just read first n lines
#         samples = []
#         with open(filename, 'r', encoding='utf-8') as f:
#             pbar = tqdm(desc=f"Reading first {n} entries", total=n)
#             for i, line in enumerate(f):
#                 if i >= n:
#                     break
#                 if line.strip():
#                     samples.append(json.loads(line))
#                     pbar.update(1)
#             pbar.close()
#         return samples
#     
#     elif method == 'last':
#         # Use deque to keep last n entries
#         from collections import deque
#         samples = deque(maxlen=n)
#         with open(filename, 'r', encoding='utf-8') as f:
#             # Count total lines first for progress bar
#             total_lines = sum(1 for line in f if line.strip())
#             f.seek(0)  # Reset file pointer
#             
#             with tqdm(desc=f"Reading for last {n} entries", total=total_lines) as pbar:
#                 for line in f:
#                     if line.strip():
#                         samples.append(json.loads(line))
#                         pbar.update(1)
#         return list(samples)
#     
#     elif method == 'random':
#         # Set seed for reproducible results
#         random.seed(seed)
#         
#         # Count total lines first for progress bar
#         with open(filename, 'r', encoding='utf-8') as f:
#             total_lines = sum(1 for line in f if line.strip())
#         
#         # Reservoir sampling for true random sample
#         samples = []
#         with open(filename, 'r', encoding='utf-8') as f:
#             with tqdm(desc=f"Random sampling {n} entries", total=total_lines) as pbar:
#                 line_count = 0
#                 for line in f:
#                     if not line.strip():
#                         continue
#                         
#                     entry = json.loads(line)
#                     line_count += 1
#                     
#                     if len(samples) < n:
#                         samples.append(entry)
#                     else:
#                         j = random.randint(0, line_count - 1)
#                         if j < n:
#                             samples[j] = entry
#                     
#                     pbar.update(1)
#         
#         return samples
#     
#     else:
#         raise ValueError("method must be 'first', 'last', or 'random'")
# 
# # Usage examples:
# # sample_jsons(jsons_fn, 100)  # First 100 entries
# # sample_jsons(jsons_fn, 100, 'random')  # Random 100 entries with default seed
# # sample_jsons(jsons_fn, 100, 'random', seed=123)  # Random with custom seed
# # sample_jsons(jsons_fn, 100, 'last')  # Last 100 entries
# 
# 
# 
# 
# # Usage examples:
# # sample_jsons(jsons_fn, 100)  # First 100 entries
# # sample_jsons(jsons_fn, 100, 'random')  # Random 100 entries with default seed
# sample = sample_jsons(jsons_fn, 100, 'random', seed=123)  # Random with custom seed
# # sample_jsons(jsons_fn, 100, 'last')  # Last 100 entries


Random sampling 100 entries: 100%|██████████| 2735264/2735264 [00:38<00:00, 70170.77it/s]


In [29]:
# # ArXiv JSON Processing Script
# # Run this in a Jupyter notebook cell
# 
# 
# import json
# from datetime import datetime
# from pathlib import Path
# 
# def parse_arxiv_entry(entry):
#     """
#     Parse a single ArXiv JSON entry and extract required fields.
#     
#     Args:
#         entry (dict): Single ArXiv paper dictionary
#     
#     Returns:
#         dict: Dictionary with extracted fields
#     """
#     # Get the latest version number from versions list
#     version_no = None
#     if entry.get('versions') and len(entry['versions']) > 0:
#         # Versions are typically in chronological order, so last one is latest
#         latest_version = entry['versions'][-1]
#         version_no = latest_version.get('version', '').replace('v', '')
#     
#     return {
#         'id': entry.get('id', ''),
#         'title': entry.get('title', '').strip(),
#         'authors_parsed': entry.get('authors_parsed', []),
#         'submitter': entry.get('submitter', ''),
#         'categories': entry.get('categories', ''),
#         'abstract': entry.get('abstract', '').strip(),
#         'doi': entry.get('doi', ''),
#         'update_date': entry.get('update_date', ''),
#         'version_no': version_no
#     }
# 
# def parse_json(entries):
#     """
#     Parse multiple ArXiv JSON entries.
#     
#     Args:
#         entries (list): List of ArXiv paper dictionaries
#     
#     Returns:
#         list: List of parsed entries
#     """
#     return [parse_arxiv_entry(entry) for entry in entries]
# 
# def is_valid(entry, allowed_categories=None, start_date=None, end_date=None):
#     """
#     Check if an ArXiv entry is valid based on categories and date range.
#     Highly efficient - operates on raw dict without full parsing.
#     
#     Args:
#         entry (dict): Raw ArXiv paper dictionary
#         allowed_categories (list): List of allowed categories (None to skip check)
#         start_date (str): Start date in YYYY-MM-DD format (None to skip check)
#         end_date (str): End date in YYYY-MM-DD format (None to skip check)
#     
#     Returns:
#         bool: True if entry passes all filters, False otherwise
#     """
#     # Category check
#     if allowed_categories is not None:
#         entry_categories = entry.get('categories', '')
#         if not entry_categories:
#             return False
#         
#         entry_cats = [cat.strip() for cat in entry_categories.split()]
#         if not any(cat in allowed_categories for cat in entry_cats):
#             return False
#     
#     if start_date is not None or end_date is not None:
#         update_date = entry.get('update_date', '')
#         if not update_date:
#             return False
#         
#         try:
#             entry_date = datetime.strptime(update_date, '%Y-%m-%d')
#             
#             if start_date is not None:
#                 start_dt = datetime.strptime(start_date, '%Y-%m-%d')
#                 if entry_date < start_dt:
#                     return False
#             
#             if end_date is not None:
#                 end_dt = datetime.strptime(end_date, '%Y-%m-%d')
#                 if entry_date > end_dt:
#                     return False
#                     
#         except ValueError:
#             return False
#     
#     return True
# 
# # Sample the data
# sample = sample_jsons(jsons_fn, 100, 'random', seed=123)  # Random with custom seed
# 
# print(f"Sampled {len(sample)} entries")
# 
# # Parse the sample
# parsed_jsons = parse_json([i for i in sample])
# 
# print(f"Parsed {len(parsed_jsons)} entries")
# 
# # Show a sample parsed entry
# if parsed_jsons:
#     print("\nSample parsed entry:")
#     entry = parsed_jsons[0]
#     print(f"ID: {entry['id']}")
#     print(f"Title: {entry['title'][:80]}...")
#     print(f"Categories: {entry['categories']}")
#     print(f"Authors: {len(entry['authors_parsed'])} authors")
#     print(f"DOI: {entry['doi']}")
#     print(f"Update Date: {entry['update_date']}")
#     print(f"Version: {entry['version_no']}")
#     print(f"Abstract: {entry['abstract'][:150]}...")
# 
# # Optional: Show some statistics
# print(f"\nStatistics:")
# print(f"Total entries: {len(parsed_jsons)}")
# print(f"Entries with DOI: {sum(1 for e in parsed_jsons if e['doi'])}")
# print(f"Unique categories: {len(set(e['categories'] for e in parsed_jsons))}")
# 
# # Op

TypeError: sample_jsons() got an unexpected keyword argument 'seed'

In [22]:
df = pd.DataFrame(parsed_jsons)

In [25]:
df['categories'].unique()

array(['astro-ph.HE', 'math.AP', 'hep-th',
       'cond-mat.mes-hall cond-mat.mtrl-sci cond-mat.str-el',
       'cs.NI cs.ET', 'quant-ph', 'cs.CV cs.RO', 'math.DG math.AP',
       'stat.AP stat.ME', 'math.AG', 'hep-th gr-qc', 'math.DS',
       'physics.plasm-ph', 'quant-ph physics.atom-ph',
       'cs.RO cs.AI cs.CV cs.LG cs.SY eess.SY', 'stat.ML cs.LG math.DG',
       'physics.optics', 'cond-mat.str-el cond-mat.mtrl-sci',
       'physics.data-an cond-mat.stat-mech nlin.PS physics.soc-ph',
       'math.GR', 'cs.CV',
       'cond-mat.mtrl-sci cond-mat.soft physics.optics', 'math.PR',
       'cond-mat.str-el cond-mat.supr-con',
       'cond-mat.mes-hall physics.atom-ph quant-ph', 'nucl-th hep-lat',
       'astro-ph.HE astro-ph.CO', 'cs.LG cs.AI', 'q-bio.TO',
       'cs.CL cs.AI cs.LG', 'math.HO', 'q-fin.ST cs.LG', 'hep-lat',
       'physics.atom-ph', 'nucl-ex', 'cs.IT math.IT', 'cond-mat.supr-con',
       'cs.NI', 'cs.RO', 'physics.comp-ph cs.LG hep-ph', 'astro-ph',
       'astro-ph.HE a

In [26]:
cs_cats = {
    "cs.AI": "Artificial Intelligence",
    "cs.AR": "Hardware Architecture", 
    "cs.CC": "Computational Complexity",
    "cs.CE": "Computational Engineering, Finance, and Science",
    "cs.CG": "Computational Geometry",
    "cs.CL": "Computation and Language",
    "cs.CR": "Cryptography and Security",
    "cs.CV": "Computer Vision and Pattern Recognition",
    "cs.CY": "Computers and Society",
    "cs.DB": "Databases",
    "cs.DC": "Distributed, Parallel, and Cluster Computing",
    "cs.DL": "Digital Libraries",
    "cs.DM": "Discrete Mathematics",
    "cs.DS": "Data Structures and Algorithms",
    "cs.ET": "Emerging Technologies",
    "cs.FL": "Formal Languages and Automata Theory",
    "cs.GL": "General Literature",
    "cs.GR": "Graphics",
    "cs.GT": "Computer Science and Game Theory",
    "cs.HC": "Human-Computer Interaction",
    "cs.IR": "Information Retrieval",
    "cs.IT": "Information Theory",
    "cs.LG": "Machine Learning",
    "cs.LO": "Logic in Computer Science",
    "cs.MA": "Multiagent Systems",
    "cs.MM": "Multimedia",
    "cs.MS": "Mathematical Software",
    "cs.NA": "Numerical Analysis",
    "cs.NE": "Neural and Evolutionary Computing",
    "cs.NI": "Networking and Internet Architecture",
    "cs.OH": "Other Computer Science",
    "cs.OS": "Operating Systems",
    "cs.PF": "Performance",
    "cs.PL": "Programming Languages",
    "cs.RO": "Robotics",
    "cs.SC": "Symbolic Computation",
    "cs.SD": "Sound",
    "cs.SE": "Software Engineering",
    "cs.SI": "Social and Information Networks",
    "cs.SY": "Systems and Control"
}

In [30]:
import json
from datetime import datetime
from pathlib import Path


def sample_jsons(filename, n, method='first', seed=42):
    """
    Sample N entries from a JSON lines file.
    
    Args:
        filename (str): Path to the JSON lines file
        n (int): Number of entries to sample
        method (str): Sampling method - 'first', 'random', or 'last'
        seed (int): Random seed for reproducible random sampling
        validation_function (function or None): If you pass this in, it only returns things that pass a validation function. This function must take in an arxiv dict and output True or False (default: None)
    
    Returns:
        list: List of parsed JSON objects
    """
    
    if method == 'first':
        # Simple: just read first n lines
        samples = []
        with open(filename, 'r', encoding='utf-8') as f:
            pbar = tqdm(desc=f"Reading first {n} entries", total=n)
            for i, line in enumerate(f):
                if i >= n:
                    break
                if line.strip():
                    samples.append(json.loads(line))
                    pbar.update(1)
            pbar.close()
        return samples
    
    elif method == 'last':
        # Use deque to keep last n entries
        from collections import deque
        samples = deque(maxlen=n)
        with open(filename, 'r', encoding='utf-8') as f:
            # Count total lines first for progress bar
            total_lines = sum(1 for line in f if line.strip())
            f.seek(0)  # Reset file pointer
            
            with tqdm(desc=f"Reading for last {n} entries", total=total_lines) as pbar:
                for line in f:
                    if line.strip():
                        samples.append(json.loads(line))
                        pbar.update(1)
        return list(samples)
    
    elif method == 'random':
        # Set seed for reproducible results
        random.seed(seed)
        
        # Count total lines first for progress bar
        with open(filename, 'r', encoding='utf-8') as f:
            total_lines = sum(1 for line in f if line.strip())
        
        # Reservoir sampling for true random sample
        samples = []
        with open(filename, 'r', encoding='utf-8') as f:
            with tqdm(desc=f"Random sampling {n} entries", total=total_lines) as pbar:
                line_count = 0
                for line in f:
                    if not line.strip():
                        continue
                        
                    entry = json.loads(line)
                    line_count += 1
                    
                    if len(samples) < n:
                        samples.append(entry)
                    else:
                        j = random.randint(0, line_count - 1)
                        if j < n:
                            samples[j] = entry
                    
                    pbar.update(1)
        
        return samples
    
    else:
        raise ValueError("method must be 'first', 'last', or 'random'")


def parse_arxiv_entry(entry):
    """
    Parse a single ArXiv JSON entry and extract required fields.
    
    Args:
        entry (dict): Single ArXiv paper dictionary
    
    Returns:
        dict: Dictionary with extracted fields
    """
    # Get the latest version number from versions list
    version_no = None
    if entry.get('versions') and len(entry['versions']) > 0:
        # Versions are typically in chronological order, so last one is latest
        latest_version = entry['versions'][-1]
        version_no = latest_version.get('version', '').replace('v', '')
    
    return {
        'id': entry.get('id', ''),
        'title': entry.get('title', '').strip(),
        'authors_parsed': entry.get('authors_parsed', []),
        'submitter': entry.get('submitter', ''),
        'categories': entry.get('categories', ''),
        'abstract': entry.get('abstract', '').strip(),
        'doi': entry.get('doi', ''),
        'update_date': entry.get('update_date', ''),
        'version_no': version_no
    }

def parse_json(entries):
    """
    Parse multiple ArXiv JSON entries.
    
    Args:
        entries (list): List of ArXiv paper dictionaries
    
    Returns:
        list: List of parsed entries
    """
    return [parse_arxiv_entry(entry) for entry in entries]

def is_valid(entry, allowed_categories=None, allowed_major_categories=None, 
             allowed_minor_categories=None, start_date=None, end_date=None):
    """
    Check if an ArXiv entry is valid based on categories and date range.
    Highly efficient - operates on raw dict without full parsing.
    
    Args:
        entry (dict): Raw ArXiv paper dictionary
        allowed_categories (list): List of allowed full categories like ['cs.AI', 'math.ST'] (None to skip check)
        allowed_major_categories (list): List of allowed major categories like ['cs', 'math'] (None to skip check)
        allowed_minor_categories (list): List of allowed minor categories like ['AI', 'ST'] (None to skip check)
        start_date (str): Start date in YYYY-MM-DD format (None to skip check)
        end_date (str): End date in YYYY-MM-DD format (None to skip check)
    
    Returns:
        bool: True if entry passes all filters, False otherwise
    """
    # Category check
    if any(x is not None for x in [allowed_categories, allowed_major_categories, allowed_minor_categories]):
        entry_categories = entry.get('categories', '')
        if not entry_categories:
            return False
        
        # Split categories and check if any match allowed categories
        entry_cats = [cat.strip() for cat in entry_categories.split()]
        
        # Check full categories (exact match)
        if allowed_categories is not None:
            if not any(cat in allowed_categories for cat in entry_cats):
                return False
        
        # Check major categories (before the dot)
        if allowed_major_categories is not None:
            entry_majors = [cat.split('.')[0] if '.' in cat else cat for cat in entry_cats]
            if not any(major in allowed_major_categories for major in entry_majors):
                return False
        
        # Check minor categories (after the dot)
        if allowed_minor_categories is not None:
            entry_minors = [cat.split('.')[1] if '.' in cat and len(cat.split('.')) > 1 else '' 
                           for cat in entry_cats]
            entry_minors = [minor for minor in entry_minors if minor]  # Remove empty strings
            if not any(minor in allowed_minor_categories for minor in entry_minors):
                return False
    
    # Date check
    if start_date is not None or end_date is not None:
        update_date = entry.get('update_date', '')
        if not update_date:
            return False
        
        try:
            # Convert update_date to datetime for comparison
            entry_date = datetime.strptime(update_date, '%Y-%m-%d')
            
            if start_date is not None:
                start_dt = datetime.strptime(start_date, '%Y-%m-%d')
                if entry_date < start_dt:
                    return False
            
            if end_date is not None:
                end_dt = datetime.strptime(end_date, '%Y-%m-%d')
                if entry_date > end_dt:
                    return False
                    
        except ValueError:
            # Invalid date format
            return False
    
    return True


    
# 1. Get jsons
sample = sample_jsons(jsons_fn, 100, 'random', seed=123)  # Random with custom seed

# 2. Parse jsons
parsed_jsons = parse_json([i for i in sample])

print(f"Parsed {len(parsed_jsons)} entries")
print("Sample parsed entry:")
if parsed_jsons:
    entry = parsed_jsons[0]
    for key, value in entry.items():
        print(f"  {key}: {value}")

# Example of filtering with is_valid
print("\nFiltering examples:")

# Filter by categories
valid_entries = [entry for entry in jsons if is_valid(entry, allowed_categories=['astro-ph.HE', 'hep-ph'])]
print(f"Entries with allowed categories: {len(valid_entries)}")

# Filter by date range
valid_entries = [entry for entry in jsons if is_valid(entry, start_date='2023-01-01', end_date='2023-12-31')]
print(f"Entries in 2023: {len(valid_entries)}")

# Combined filter
valid_entries = [entry for entry in jsons if is_valid(
    entry, 
    allowed_categories=['astro-ph.HE'], 
    start_date='2023-01-01'
)]
print(f"Entries with astro-ph.HE category from 2023: {len(valid_entries)}")

Random sampling 100 entries: 100%|██████████| 2735264/2735264 [00:38<00:00, 70782.58it/s]

Parsed 100 entries
Sample parsed entry:
  id: 2308.09518
  title: Efficient Modeling of Heavy Cosmic Rays Propagation in Evolving
  Astrophysical Environments
  authors_parsed: [['Merten', 'Lukas', ''], ['Da Vela', 'Paolo', ''], ['Reimer', 'Anita', ''], ['Boughelilba', 'Margot', ''], ['Lundquist', 'Jon Paul', ''], ['Vorobiov', 'Serguei', ''], ['Tjus', 'Julia Becker', '']]
  submitter: Lukas Merten
  categories: astro-ph.HE
  abstract: We present a new energy transport code that models the time dependent and
non-linear evolution of spectra of cosmic-ray nuclei, their secondaries, and
photon target fields. The software can inject an arbitrary chemical composition
including heavy elements up to iron nuclei. Energy losses and secondary
production due to interactions of cosmic ray nuclei, secondary mesons, leptons,
or gamma-rays with a target photon field are available for all relevant
processes, e.g., photo-meson production, photo disintegration, synchrotron
radiation, Inverse Compton scat




AttributeError: 'PosixPath' object has no attribute 'get'

In [31]:
parsed_jsons

[{'id': '2308.09518',
  'title': 'Efficient Modeling of Heavy Cosmic Rays Propagation in Evolving\n  Astrophysical Environments',
  'authors_parsed': [['Merten', 'Lukas', ''],
   ['Da Vela', 'Paolo', ''],
   ['Reimer', 'Anita', ''],
   ['Boughelilba', 'Margot', ''],
   ['Lundquist', 'Jon Paul', ''],
   ['Vorobiov', 'Serguei', ''],
   ['Tjus', 'Julia Becker', '']],
  'submitter': 'Lukas Merten',
  'categories': 'astro-ph.HE',
  'abstract': 'We present a new energy transport code that models the time dependent and\nnon-linear evolution of spectra of cosmic-ray nuclei, their secondaries, and\nphoton target fields. The software can inject an arbitrary chemical composition\nincluding heavy elements up to iron nuclei. Energy losses and secondary\nproduction due to interactions of cosmic ray nuclei, secondary mesons, leptons,\nor gamma-rays with a target photon field are available for all relevant\nprocesses, e.g., photo-meson production, photo disintegration, synchrotron\nradiation, Inverse 

In [38]:
def is_valid(entry, allowed_categories=None, allowed_major_categories=None, 
             allowed_minor_categories=None, start_date=None, end_date=None):
    """
    Check if an ArXiv entry is valid based on categories and date range.
    Highly efficient - operates on raw dict without full parsing.
    
    Args:
        entry (dict): Raw ArXiv paper dictionary
        allowed_categories (list): List of allowed full categories like ['cs.AI', 'math.ST'] (None to skip check)
        allowed_major_categories (list): List of allowed major categories like ['cs', 'math'] (None to skip check)
        allowed_minor_categories (list): List of allowed minor categories like ['AI', 'ST'] (None to skip check)
        start_date (str): Start date in YYYY-MM-DD format (None to skip check)
        end_date (str): End date in YYYY-MM-DD format (None to skip check)
    
    Returns:
        bool: True if entry passes all filters, False otherwise
    """
    # Category check
    if any(x is not None for x in [allowed_categories, allowed_major_categories, allowed_minor_categories]):
        entry_categories = entry.get('categories', '')
        if not entry_categories:
            return False
        
        # Split categories and check if any match allowed categories
        entry_cats = [cat.strip() for cat in entry_categories.split()]
        
        # Check full categories (exact match)
        if allowed_categories is not None:
            if not any(cat in allowed_categories for cat in entry_cats):
                return False
        
        # Check major categories (before the dot)
        if allowed_major_categories is not None:
            entry_majors = [cat.split('.')[0] if '.' in cat else cat for cat in entry_cats]
            if not any(major in allowed_major_categories for major in entry_majors):
                return False
        
        # Check minor categories (after the dot)
        if allowed_minor_categories is not None:
            entry_minors = [cat.split('.')[1] if '.' in cat and len(cat.split('.')) > 1 else '' 
                           for cat in entry_cats]
            entry_minors = [minor for minor in entry_minors if minor]  # Remove empty strings
            if not any(minor in allowed_minor_categories for minor in entry_minors):
                return False
    
    # Date check
    if start_date is not None or end_date is not None:
        update_date = entry.get('update_date', '')
        if not update_date:
            return False
        
        try:
            # Convert update_date to datetime for comparison
            entry_date = datetime.strptime(update_date, '%Y-%m-%d')
            
            if start_date is not None:
                start_dt = datetime.strptime(start_date, '%Y-%m-%d')
                if entry_date < start_dt:
                    return False
            
            if end_date is not None:
                end_dt = datetime.strptime(end_date, '%Y-%m-%d')
                if entry_date > end_dt:
                    return False
                    
        except ValueError:
            # Invalid date format
            return False
    
    return True


valid_entries = [entry for entry in sample if is_valid(entry, allowed_major_categories=['CS'])]


In [39]:
valid_entries

[]

# NEW

In [2]:
import json
import random
from collections import deque
from pathlib import Path # Assuming this might be used by your is_valid or other parts
from datetime import datetime # Assuming this might be used by your is_valid

from tqdm.auto import tqdm # Using tqdm.auto for flexible environment
from concurrent.futures import ProcessPoolExecutor
import os

def is_valid(entry, allowed_categories=None, allowed_major_categories=None, 
             allowed_minor_categories=None, start_date=None, end_date=None):
    """
    Check if an ArXiv entry is valid based on categories and date range.
    Highly efficient - operates on raw dict without full parsing.
    
    Args:
        entry (dict): Raw ArXiv paper dictionary
        allowed_categories (list): List of allowed full categories like ['cs.AI', 'math.ST'] (None to skip check)
        allowed_major_categories (list): List of allowed major categories like ['cs', 'math'] (None to skip check)
        allowed_minor_categories (list): List of allowed minor categories like ['AI', 'ST'] (None to skip check)
        start_date (str): Start date in YYYY-MM-DD format (None to skip check)
        end_date (str): End date in YYYY-MM-DD format (None to skip check)
    
    Returns:
        bool: True if entry passes all filters, False otherwise
    """
    # Category check
    if any(x is not None for x in [allowed_categories, allowed_major_categories, allowed_minor_categories]):
        entry_categories = entry.get('categories', '')
        if not entry_categories:
            return False
        
        # Split categories and check if any match allowed categories
        entry_cats = [cat.strip() for cat in entry_categories.split()]
        
        # Check full categories (exact match)
        if allowed_categories is not None:
            if not any(cat in allowed_categories for cat in entry_cats):
                return False
        
        # Check major categories (before the dot)
        if allowed_major_categories is not None:
            entry_majors = [cat.split('.')[0] if '.' in cat else cat for cat in entry_cats]
            if not any(major in allowed_major_categories for major in entry_majors):
                return False
        
        # Check minor categories (after the dot)
        if allowed_minor_categories is not None:
            entry_minors = [cat.split('.')[1] if '.' in cat and len(cat.split('.')) > 1 else '' 
                           for cat in entry_cats]
            entry_minors = [minor for minor in entry_minors if minor]  # Remove empty strings
            if not any(minor in allowed_minor_categories for minor in entry_minors):
                return False
    
    # Date check
    if start_date is not None or end_date is not None:
        update_date = entry.get('update_date', '')
        if not update_date:
            return False
        
        try:
            # Convert update_date to datetime for comparison
            entry_date = datetime.strptime(update_date, '%Y-%m-%d')
            
            if start_date is not None:
                start_dt = datetime.strptime(start_date, '%Y-%m-%d')
                if entry_date < start_dt:
                    return False
            
            if end_date is not None:
                end_dt = datetime.strptime(end_date, '%Y-%m-%d')
                if entry_date > end_dt:
                    return False
                    
        except ValueError:
            # Invalid date format
            return False
    
    return True


def _mp_process_line_for_filtering(line_validator_tuple):
    """
    Parses a JSON line and applies a validator. Designed for use with multiprocessing.
    Args:
        line_validator_tuple (tuple): A tuple containing (line_content_string, validator_function).
                                      validator_function can be None.
    Returns:
        object: The parsed and validated JSON object (dict) if valid, otherwise None.
    """
    line_content, validator_func = line_validator_tuple
    if not line_content.strip(): # Skip empty or whitespace-only lines
        return None
    try:
        entry = json.loads(line_content)
        if validator_func is None or validator_func(entry):
            return entry  # Parsed and validated (or no validator)
        return None  # Failed validation
    except json.JSONDecodeError:
        # Optionally, log this error or count parsing failures
        # print(f"Warning: Could not parse line: {line_content[:100]}")
        return None  # Failed parsing
    
def sample_jsons(filename, n, method='first', seed=42, validator=None):
    """
    Sample N entries from a JSON lines file. (Non-parallel version as provided)
    
    Args:
        filename (str): Path to the JSON lines file.
        n (int): Number of entries to sample.
        method (str): Sampling method - 'first', 'random', or 'last'.
        seed (int): Random seed for reproducible random sampling.
        validator (function, optional): If provided, only returns entries that pass this
                                        validation function. This function must take in a
                                        parsed JSON dict and output True or False. Defaults to None.
    
    Returns:
        list: List of parsed JSON objects (dictionaries).
    """
    if n == 0:
        return []

    samples = []
    
    if method == 'first':
        with open(filename, 'r', encoding='utf-8') as f:
            pbar = tqdm(desc=f"Finding first {n} valid entries", total=n, unit="entry")
            for line in f:
                if not line.strip():
                    continue
                try:
                    entry = json.loads(line)
                    if validator is None or validator(entry):
                        samples.append(entry)
                        pbar.update(1)
                        if len(samples) >= n:
                            break
                except json.JSONDecodeError:
                    tqdm.write(f"Skipping unparseable line in 'first': {line[:100]}")
            pbar.close()
            if len(samples) < n:
                tqdm.write(f"Warning: Found only {len(samples)} valid entries out of {n} requested from the beginning of the file.")
        return samples
    
    elif method == 'last':
        samples_deque = deque(maxlen=n)
        total_lines = None
        try:
            with open(filename, 'r', encoding='utf-8') as f_count:
                total_lines = sum(1 for _ in f_count)
        except (IOError, OSError):
            pass # total_lines will remain None

        with open(filename, 'r', encoding='utf-8') as f:
            with tqdm(desc=f"Scanning for last {n} valid entries", total=total_lines, unit="line") as pbar:
                for line in f:
                    pbar.update(1)
                    if not line.strip():
                        continue
                    try:
                        entry = json.loads(line)
                        if validator is None or validator(entry):
                            samples_deque.append(entry)
                    except json.JSONDecodeError:
                        tqdm.write(f"Skipping unparseable line in 'last': {line[:100]}")
        return list(samples_deque)
    
    elif method == 'random':
        random.seed(seed)
        samples = []
        valid_items_seen = 0
        total_lines = None
        try:
            with open(filename, 'r', encoding='utf-8') as f_count:
                total_lines = sum(1 for _ in f_count)
        except (IOError, OSError):
            pass

        with open(filename, 'r', encoding='utf-8') as f:
            with tqdm(desc=f"Random sampling for {n} valid entries", total=total_lines, unit="line") as pbar:
                for line in f:
                    pbar.update(1)
                    if not line.strip():
                        continue
                    try:
                        entry = json.loads(line)
                        if validator is None or validator(entry):
                            valid_items_seen += 1
                            if len(samples) < n:
                                samples.append(entry)
                            else:
                                j = random.randint(0, valid_items_seen - 1)
                                if j < n:
                                    samples[j] = entry
                    except json.JSONDecodeError:
                        tqdm.write(f"Skipping unparseable line in 'random': {line[:100]}")
            if len(samples) < n:
                tqdm.write(f"Warning: Found only {len(samples)} valid random entries out of {n} requested.")
        return samples
    
    else:
        raise ValueError("method must be 'first', 'last', or 'random'")

def filter_jsons(json_list, validator_function, max_workers=None):
    """
    Filter a list of JSON objects (already in memory) in parallel using a validator function.
    """
    if not json_list:
        return []
    
    if validator_function is None: # If no validator, return all
        return list(json_list)

    if max_workers is None:
        max_workers = os.cpu_count()

    validated_items = []
    
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        results = list(executor.map(validator_function, tqdm(json_list, desc="Filtering in-memory JSONs", unit="entry")))
        
        for item, is_valid_flag in zip(json_list, results):
            if is_valid_flag:
                validated_items.append(item)
                
    return validated_items


def read_and_filter_jsons_parallel(filename, validator_function, max_workers=None, batch_size=2048):
    """
    Reads a JSON lines file, filters its entries in parallel, and returns a list of valid JSON objects.
    This is memory-efficient as it processes the file in streaming batches.

    Args:
        filename (str): Path to the JSON lines file.
        validator_function (function): A function that takes a parsed JSON dict and returns True if valid.
        max_workers (int, optional): Max worker processes. Defaults to os.cpu_count().
        batch_size (int, optional): Number of lines to process in each parallel batch. Defaults to 2048.

    Returns:
        list: A list of parsed JSON objects that passed validation.
    """
    if max_workers is None:
        max_workers = os.cpu_count()

    validated_jsons = []
    total_lines = None
    
    try:
        # Pre-count lines for a determinate progress bar. This adds an initial pass.
        with open(filename, 'r', encoding='utf-8') as f_count:
            total_lines = sum(1 for _line in f_count) 
    except (IOError, OSError) as e:
        tqdm.write(f"Info: Could not pre-count lines ({e}). Progress bar will be indeterminate for total lines.")

    pbar_desc = "Reading & filtering JSON lines (parallel)"
    if total_lines is not None:
        pbar = tqdm(total=total_lines, desc=pbar_desc, unit="line")
    else:
        # Indeterminate total, tqdm will show rate and lines processed so far
        pbar = tqdm(desc=pbar_desc, unit="line") 

    with ProcessPoolExecutor(max_workers=max_workers) as executor, \
         open(filename, 'r', encoding='utf-8') as f:
        
        lines_batch_for_processing = []

        for line_content in f:
            lines_batch_for_processing.append((line_content, validator_function))
            
            if len(lines_batch_for_processing) >= batch_size:
                try:
                    for entry in executor.map(_mp_process_line_for_filtering, lines_batch_for_processing):
                        if entry: # If parsed, validated, and not None
                            validated_jsons.append(entry)
                except Exception as e: # Catch errors from the process pool more broadly
                    tqdm.write(f"Error processing a batch: {e}")
                
                pbar.update(len(lines_batch_for_processing))
                lines_batch_for_processing.clear()
        
        # Process any remaining lines in the last batch
        if lines_batch_for_processing:
            try:
                for entry in executor.map(_mp_process_line_for_filtering, lines_batch_for_processing):
                    if entry:
                        validated_jsons.append(entry)
            except Exception as e:
                 tqdm.write(f"Error processing the final batch: {e}")
            pbar.update(len(lines_batch_for_processing))
        
    pbar.close()
    return validated_jsons


fn = "arxiv-metadata-oai-snapshot.json"
jsons = read_and_filter_jsons_parallel(
    fn, 
    validator_function= is_valid(
        entry
        allowed_major_categories=['cs'], 
        start_date='2020-01-01', 
        end_date='2023-12-31'
    ),
    max_workers=8,
    batch_size=2048)

TypeError: is_valid() missing 1 required positional argument: 'entry'

In [5]:
import json
import random
from collections import deque
from pathlib import Path  # Assuming this might be used depending on file path needs
from datetime import datetime
from functools import partial # For creating picklable partial functions

from tqdm.auto import tqdm   # For progress bars
from joblib import Parallel, delayed # For parallel processing
import os                    # For os.cpu_count()

# --- Helper Function for Parallel Processing ---
def _mp_process_line_for_filtering(line_validator_tuple):
    """
    Parses a JSON line and applies a validator. Designed for use with multiprocessing/joblib.
    Args:
        line_validator_tuple (tuple): A tuple containing (line_content_string, validator_function).
                                      validator_function can be None.
    Returns:
        object: The parsed and validated JSON object (dict) if valid, otherwise None.
    """
    line_content, validator_func = line_validator_tuple
    if not line_content.strip(): # Skip empty or whitespace-only lines
        return None
    try:
        entry = json.loads(line_content)
        if validator_func is None or validator_func(entry):
            return entry  # Parsed and validated (or no validator)
        return None  # Failed validation
    except json.JSONDecodeError:
        # Optionally, log this error or count parsing failures
        # tqdm.write(f"Warning: Could not parse line: {line_content[:100]}")
        return None  # Failed parsing

# --- Validation Function ---
def is_valid(entry, allowed_categories=None, allowed_major_categories=None,
             allowed_minor_categories=None, start_date=None, end_date=None):
    """
    Check if an ArXiv entry is valid based on categories and date range.
    """
    # Category check
    if any(x is not None for x in [allowed_categories, allowed_major_categories, allowed_minor_categories]):
        entry_categories = entry.get('categories', '')
        if not entry_categories:
            return False

        entry_cats = [cat.strip() for cat in entry_categories.split()]

        if allowed_categories is not None:
            if not any(cat in allowed_categories for cat in entry_cats):
                return False

        if allowed_major_categories is not None:
            entry_majors = [cat.split('.')[0] if '.' in cat else cat for cat in entry_cats]
            if not any(major in allowed_major_categories for major in entry_majors):
                return False

        if allowed_minor_categories is not None:
            entry_minors = [cat.split('.')[1] if '.' in cat and len(cat.split('.')) > 1 else ''
                           for cat in entry_cats]
            entry_minors = [minor for minor in entry_minors if minor]
            if not any(minor in allowed_minor_categories for minor in entry_minors):
                return False

    # Date check
    if start_date is not None or end_date is not None:
        update_date = entry.get('update_date', '')
        if not update_date:
            return False

        try:
            entry_date = datetime.strptime(update_date, '%Y-%m-%d')

            if start_date is not None:
                start_dt = datetime.strptime(start_date, '%Y-%m-%d')
                if entry_date < start_dt:
                    return False

            if end_date is not None:
                end_dt = datetime.strptime(end_date, '%Y-%m-%d')
                if entry_date > end_dt:
                    return False
        except ValueError:
            # Invalid date format
            return False
    return True

# --- Sampling Function (Serial version as provided previously) ---
def sample_jsons(filename, n, method='first', seed=42, validator=None):
    """
    Sample N entries from a JSON lines file. (Non-parallel version)
    """
    if n == 0: return []
    samples = []
    if method == 'first':
        with open(filename, 'r', encoding='utf-8') as f:
            pbar = tqdm(desc=f"Finding first {n} valid entries", total=n, unit="entry")
            for line in f:
                if not line.strip(): continue
                try:
                    entry = json.loads(line)
                    if validator is None or validator(entry):
                        samples.append(entry)
                        pbar.update(1)
                        if len(samples) >= n: break
                except json.JSONDecodeError: tqdm.write(f"Skipping unparseable line in 'first': {line[:100]}")
            pbar.close()
            if len(samples) < n: tqdm.write(f"Warning: Found only {len(samples)} valid entries of {n} requested.")
        return samples
    elif method == 'last':
        samples_deque = deque(maxlen=n)
        total_lines = None
        try:
            with open(filename, 'r', encoding='utf-8') as f_count: total_lines = sum(1 for _ in f_count)
        except (IOError, OSError): pass
        with open(filename, 'r', encoding='utf-8') as f:
            with tqdm(desc=f"Scanning for last {n} valid entries", total=total_lines, unit="line") as pbar:
                for line in f:
                    pbar.update(1)
                    if not line.strip(): continue
                    try:
                        entry = json.loads(line)
                        if validator is None or validator(entry): samples_deque.append(entry)
                    except json.JSONDecodeError: tqdm.write(f"Skipping unparseable line in 'last': {line[:100]}")
        return list(samples_deque)
    elif method == 'random':
        random.seed(seed)
        samples = []
        valid_items_seen = 0
        total_lines = None
        try:
            with open(filename, 'r', encoding='utf-8') as f_count: total_lines = sum(1 for _ in f_count)
        except (IOError, OSError): pass
        with open(filename, 'r', encoding='utf-8') as f:
            with tqdm(desc=f"Random sampling for {n} valid entries", total=total_lines, unit="line") as pbar:
                for line in f:
                    pbar.update(1)
                    if not line.strip(): continue
                    try:
                        entry = json.loads(line)
                        if validator is None or validator(entry):
                            valid_items_seen += 1
                            if len(samples) < n: samples.append(entry)
                            else:
                                j = random.randint(0, valid_items_seen - 1)
                                if j < n: samples[j] = entry
                    except json.JSONDecodeError: tqdm.write(f"Skipping unparseable line in 'random': {line[:100]}")
            if len(samples) < n: tqdm.write(f"Warning: Found only {len(samples)} valid random entries of {n} requested.")
        return samples
    else: raise ValueError("method must be 'first', 'last', or 'random'")

# --- Filtering for In-Memory Lists (Joblib version) ---
def filter_jsons(json_list, validator_function, n_jobs=None):
    """
    Filter a list of JSON objects (already in memory) in parallel using a validator function with joblib.
    """
    if not json_list:
        return []
    if validator_function is None:
        return list(json_list)
    if n_jobs is None:
        n_jobs = -1 # Default to all CPUs

    # tqdm shows progress for creating delayed task objects
    tasks = (delayed(validator_function)(item) for item in tqdm(json_list, desc="Preparing tasks for joblib filtering", unit="entry"))
    
    # For actual computation progress from joblib, set verbose in Parallel, e.g., verbose=5
    is_valid_results = Parallel(n_jobs=n_jobs)(tasks)
    
    validated_items = []
    for item, is_valid_flag in zip(json_list, is_valid_results):
        if is_valid_flag:
            validated_items.append(item)
    return validated_items

def read_and_filter_jsons(filename, validator_function, n_jobs=None, batch_size=2048):
    """
    Reads a JSON lines file, filters its entries in parallel using joblib,
    and returns a list of valid JSON objects. This is memory-efficient.
    """
    if n_jobs is None:
        n_jobs = -1 # Default to all CPUs

    validated_jsons = []
    total_lines = None
    try:
        with open(filename, 'r', encoding='utf-8') as f_count:
            total_lines = sum(1 for _line in f_count)
    except (IOError, OSError) as e:
        tqdm.write(f"Info: Could not pre-count lines ({e}). Progress bar for total lines might be indeterminate.")

    pbar_desc = "Reading & filtering (joblib)"
    if total_lines is not None:
        pbar = tqdm(total=total_lines, desc=pbar_desc, unit="line")
    else:
        pbar = tqdm(desc=pbar_desc, unit="line")

    # For joblib execution progress, you can set verbose, e.g., verbose=5
    parallel_executor = Parallel(n_jobs=n_jobs)

    with open(filename, 'r', encoding='utf-8') as f:
        lines_batch_for_processing = []
        for line_content in f:
            lines_batch_for_processing.append((line_content, validator_function))
            if len(lines_batch_for_processing) >= batch_size:
                try:
                    tasks = (delayed(_mp_process_line_for_filtering)(item_tuple) for item_tuple in lines_batch_for_processing)
                    batch_results = parallel_executor(tasks)
                    for entry in batch_results:
                        if entry:
                            validated_jsons.append(entry)
                except Exception as e:
                    tqdm.write(f"Error processing a batch with joblib: {e}")
                pbar.update(len(lines_batch_for_processing))
                lines_batch_for_processing.clear()

        if lines_batch_for_processing: # Process any remaining lines
            try:
                tasks = (delayed(_mp_process_line_for_filtering)(item_tuple) for item_tuple in lines_batch_for_processing)
                batch_results = parallel_executor(tasks)
                for entry in batch_results:
                    if entry:
                        validated_jsons.append(entry)
            except Exception as e:
                 tqdm.write(f"Error processing the final batch with joblib: {e}")
            pbar.update(len(lines_batch_for_processing))
    pbar.close()
    return validated_jsons

# --- Main Execution Block ---
if __name__ == '__main__':
    # Ensure this file exists or provide the correct path
    # Download from: https://www.kaggle.com/datasets/Cornell-University/arxiv
    # (or use a smaller JSON Lines test file)
    json_lines_file = "arxiv-metadata-oai-snapshot.json" 

    if not Path(json_lines_file).is_file():
        print(f"Error: File '{json_lines_file}' not found.")
        print("Please download it or use a different JSON Lines file.")
        # Example of creating a dummy file for testing:
        if not Path("dummy_arxiv_data.jsonl").is_file():
            print("Creating a small dummy_arxiv_data.jsonl for testing...")
            dummy_data = [
                {"id": "2001.00001", "title": "Paper CS 1", "categories": "cs.AI hep-th", "update_date": "2020-01-01", "abstract": "Abstract 1"},
                {"id": "2001.00002", "title": "Paper Math 1", "categories": "math.CO cs.CG", "update_date": "2021-05-15", "abstract": "Abstract 2"},
                {"id": "2002.00003", "title": "Paper CS 2", "categories": "cs.LG", "update_date": "2020-02-20", "abstract": "Abstract 3"},
                {"id": "2023.00004", "title": "Paper CS 3 Recent", "categories": "cs.CV", "update_date": "2023-11-05", "abstract": "Abstract 4"},
                {"id": "cond-mat.00005", "title": "Paper Phys", "categories": "cond-mat.stat-mech", "update_date": "2019-12-31", "abstract": "Abstract 5"},
            ]
            with open("dummy_arxiv_data.jsonl", "w") as f_dummy:
                for item in dummy_data:
                    f_dummy.write(json.dumps(item) + "\n")
            json_lines_file = "dummy_arxiv_data.jsonl" # Use dummy file if main one not found
            print(f"Using '{json_lines_file}' for this run.")


    # Create a picklable validator function using functools.partial
    # This specific validator looks for 'cs' papers between 2020 and 2023.
    validation_criteria = partial(is_valid,
                                    allowed_major_categories=['cs'],
                                    start_date='2018-01-01',
                                    end_date='2025-05-20')

    print(f"\nStarting efficient parallel reading and filtering of '{json_lines_file}' with joblib...")
    
    # You can adjust n_jobs and batch_size for performance tuning
    # n_jobs=-1 uses all available CPUs
    # n_jobs=1 would be sequential (useful for debugging)
    # batch_size affects how many lines are grouped for each parallel task submission
    
    all_valid_cs_papers = read_and_filter_jsons(
        json_lines_file,
        validator_function=validation_criteria,
        n_jobs=-1, 
        batch_size=4096 # Increased batch_size can sometimes be better for very large files
    )

    print(f"\nFound {len(all_valid_cs_papers)} 'cs' papers from 2020-2023.")
    if all_valid_cs_papers:
        print(f"First valid entry example:")
        entry = all_valid_cs_papers[0]
        print(f"  ID: {entry.get('id', 'N/A')}")
        print(f"  Title: {entry.get('title','N/A')[:80]}...")
        print(f"  Categories: {entry.get('categories', 'N/A')}")
        print(f"  Update Date: {entry.get('update_date', 'N/A')}")



Starting efficient parallel reading and filtering of 'arxiv-metadata-oai-snapshot.json' with joblib...


Reading & filtering (joblib):   0%|          | 0/2735264 [00:00<?, ?line/s]


Found 619980 'cs' papers from 2020-2023.
First valid entry example:
  ID: 0704.3504
  Title: Smooth R\'enyi Entropy of Ergodic Quantum Information Sources...
  Categories: quant-ph cs.IT math.IT
  Update Date: 2018-02-13


In [7]:
import pandas as pd
df = pd.DataFrame(all_valid_cs_papers)

In [None]:
df['update_date'].max()

In [None]:
from src.json_utils import JsonUtils

In [4]:
!pip install pexpect

ModuleNotFoundError: No module named 'pexpect'

In [11]:
import pandas as pd
df = pd.read_json("../data/clean/arxiv.json")

ValueError: Expected object or value