In [None]:
import json
import logging
import os
import shutil
from datetime import datetime
from functools import partial
from pathlib import Path

import kagglehub

from src.helpers import path2correct_loc
from src.json_utils import JsonUtils


# Filter settings for this run: an inclusive update-date window (YYYY-MM-DD)
# and the exact ArXiv category tags to keep.
start_date = '2025-05-01'
end_date = '2025-05-20'
allowed_categories = ["cs.AI", "cs.MA"]
# Space-separated category list, used in log messages and in the output file name.
allowed_categories_str = " ".join(allowed_categories)

# The root logger defaults to WARNING, which silently drops INFO messages.
# basicConfig() does nothing if the root logger already has handlers, so an
# existing logging setup is left alone.
logging.basicConfig(level=logging.INFO)

logging.info(
    "Filtering ArXiv entries from %s to %s for categories: %s",
    start_date,
    end_date,
    allowed_categories_str,
)


def is_valid(entry, allowed_categories=None, allowed_major_categories=None,
             allowed_minor_categories=None, start_date=None, end_date=None):
    """
    Decide whether an ArXiv metadata entry passes the category and date filters.

    Every filter left as ``None`` is skipped; all supplied filters must pass.

    Args:
        entry: Mapping read via ``.get``; uses the ``'categories'`` key
            (whitespace-separated tags such as ``"cs.AI stat.ML"``) and the
            ``'update_date'`` key (``YYYY-MM-DD`` string).
        allowed_categories: Full tags; at least one entry tag must be in it.
        allowed_major_categories: Parts before the first dot (``"cs"``);
            at least one entry tag must match.
        allowed_minor_categories: Parts between the first and second dot
            (``"AI"``); at least one entry tag must match.
        start_date: Inclusive lower bound, ``YYYY-MM-DD`` string.
        end_date: Inclusive upper bound, ``YYYY-MM-DD`` string.

    Returns:
        ``True`` if the entry satisfies every supplied filter, else ``False``.
        A missing/empty field needed by an active filter, or any date string
        that does not parse as ``YYYY-MM-DD``, yields ``False``.
    """
    date_format = '%Y-%m-%d'

    # --- Category filters -------------------------------------------------
    category_filters = (allowed_categories, allowed_major_categories,
                        allowed_minor_categories)
    if not all(flt is None for flt in category_filters):
        raw_tags = entry.get('categories', '')
        if not raw_tags:
            return False

        tags = raw_tags.split()

        if allowed_categories is not None:
            if all(tag not in allowed_categories for tag in tags):
                return False

        if allowed_major_categories is not None:
            # The text before the first dot; a dot-less tag is its own major.
            majors = [tag.split('.')[0] for tag in tags]
            if all(major not in allowed_major_categories for major in majors):
                return False

        if allowed_minor_categories is not None:
            # Keep only non-empty second components ("cs.AI" -> "AI").
            minors = []
            for tag in tags:
                pieces = tag.split('.')
                if len(pieces) > 1 and pieces[1]:
                    minors.append(pieces[1])
            if all(minor not in allowed_minor_categories for minor in minors):
                return False

    # --- Date filters -----------------------------------------------------
    if start_date is None and end_date is None:
        return True

    stamp = entry.get('update_date', '')
    if not stamp:
        return False

    try:
        when = datetime.strptime(stamp, date_format)
        if start_date is not None and when < datetime.strptime(start_date, date_format):
            return False
        if end_date is not None and when > datetime.strptime(end_date, date_format):
            return False
    except ValueError:
        # Any unparsable date (entry or bound) rejects the entry.
        return False

    return True






# Bind this run's filter settings so the reader can call the validator with
# just an entry.
validation_criteria = partial(
    is_valid,
    allowed_categories=allowed_categories,
    start_date=start_date,
    end_date=end_date,
)

# Download the pinned dataset version; kagglehub returns the local path it
# downloaded to. That path is then handed to the project helper
# path2correct_loc (defined in src.helpers; its exact behavior is not visible here).
path = kagglehub.dataset_download("Cornell-University/arxiv/versions/234")
new_location = path2correct_loc(path, "")
print(f"New location: {new_location}")

fn = "arxiv-metadata-oai-snapshot.json"

jutils = JsonUtils()

# Keep only the entries accepted by validation_criteria.
jsons = jutils.read_and_filter_json_file(filename=fn,
                                         validator_function=validation_criteria)

output_fn = f"../data/clean/arxiv_{start_date}_{end_date}_{allowed_categories_str}.jsonl"

# open() does not create missing directories, so make sure the target exists.
Path(output_fn).parent.mkdir(parents=True, exist_ok=True)

with open(output_fn, 'w', encoding='utf-8') as f:
    for entry in jsons:
        # str() of a dict is a Python repr (single quotes, None/True), which is
        # not valid JSON. Serialise properly so each line of the .jsonl file
        # parses; entries that are already strings are written unchanged.
        line = entry if isinstance(entry, str) else json.dumps(entry, ensure_ascii=False)
        f.write(f"{line}\n")

print(f"Filtered entries saved to {output_fn}")
logging.info(f"Filtered entries saved to {output_fn}")




In [ ]:
# Echo the path returned by kagglehub.dataset_download as this cell's output.
path

In [ ]:
import pandas as pd
# Load ../data/clean/arxiv.json into a DataFrame.
# NOTE(review): the filtering cell writes
# "arxiv_<start>_<end>_<categories>.jsonl", not "arxiv.json" — confirm which
# file is intended here. If the target is a JSON Lines file, pd.read_json
# needs lines=True.
df = pd.read_json("../data/clean/arxiv.json")
