# Read in files

In [5]:
import pandas as pd 
import json 
from src.helpers import text2list

from collections import Counter


# Paths to the cleaned corpora (JSON Lines: one record per line).
arxiv_fn = "../data/clean/arxiv_2018-01-01_2025-05-20_cs__.jsonl"
nyt_fn = "../data/clean/nyt_2018-01-01_2025-05-20.jsonl"

# Keep the path under its own name so `ai_terms` is only ever the parsed JSON.
ai_terms_fn = "../data/raw/raw_ai_terms.json"
# Explicit encoding: do not depend on the platform default when reading JSON.
with open(ai_terms_fn, 'r', encoding='utf-8') as f:
    ai_terms = json.load(f)

# Flatten every value of the JSON object into one de-duplicated, lower-cased
# list of terms. Non-string entries are dropped.
# NOTE(review): presumably each value is a list of terms — confirm against the
# raw file; a bare string value would be iterated character by character.
all_terms = list({
    term.lower()
    for group in ai_terms.values()
    for term in group
    if isinstance(term, str)
})

# Role vocabularies from two sources, concatenated (duplicates are kept).
atus_roles = text2list("../data/clean/atus_roles.txt")
onet_roles = text2list("../data/clean/onet_roles.txt")
roles = atus_roles + onet_roles

# Build one searchable `text` column per corpus. Missing fields are replaced
# with "" first: string concatenation propagates NaN, so without this a record
# lacking any single field would end up with no text at all.
arxiv_df = pd.read_json(arxiv_fn, lines=True)
arxiv_df['text'] = arxiv_df['title'].fillna("") + " " + arxiv_df['abstract'].fillna("")

nyt_df = pd.read_json(nyt_fn, lines=True)
nyt_df['text'] = (
    nyt_df['headline'].fillna("")
    + " " + nyt_df['abstract'].fillna("")
    + " " + nyt_df['snippet'].fillna("")
)


# Count words

In [None]:
import re
import pandas as pd
# from tqdm import tqdm # Optional: for progress bar with pandas apply

# Optional: Initialize tqdm for pandas apply (call once)
# tqdm.pandas(desc="Processing Series")

def make_combined_regex(terms_or_phrases_list):
    """
    Build one compiled, case-insensitive regex matching any term in a list.

    Terms are stripped, lower-cased and de-duplicated, then joined into a
    single alternation ordered longest-first (ties broken alphabetically).
    That ordering makes the result independent of the input order and ensures
    a multi-word phrase such as "machine learning" wins over its prefix
    "machine" at the same position.

    Matches must not be directly preceded or followed by a word character.
    Lookarounds are used instead of ``\\b`` so that terms starting or ending
    with punctuation (e.g. "c++") can still match.

    Args:
        terms_or_phrases_list (list): Terms/phrases to match. ``None``, empty
            and whitespace-only entries are ignored; other entries are
            converted with ``str()``.

    Returns:
        re.Pattern | None: The compiled pattern, or ``None`` when there are no
        usable terms or the pattern fails to compile.
    """
    if not terms_or_phrases_list:
        return None

    # Normalise and de-duplicate; drop None / empty / whitespace-only entries.
    unique_terms = {
        str(term).strip().lower()
        for term in terms_or_phrases_list
        if term and str(term).strip()
    }
    if not unique_terms:
        return None

    # Regex alternation takes the first alternative that matches, not the
    # longest one, so put longer phrases first. The secondary alphabetical key
    # keeps the pattern deterministic regardless of input (or set) ordering.
    ordered_terms = sorted(unique_terms, key=lambda t: (-len(t), t))
    alternation = '|'.join(re.escape(t) for t in ordered_terms)

    # (?<!\w) / (?!\w) behave like \b for terms with word characters at their
    # edges, but also work when a term begins or ends with punctuation.
    pattern_str = r'(?<!\w)(?:' + alternation + r')(?!\w)'

    try:
        return re.compile(pattern_str, re.IGNORECASE)
    except re.error as e:
        print(f"Error compiling regex with pattern '{pattern_str}': {e}")
        return None

def count_matches_with_compiled_regex(text, compiled_regex):
    """
    Return how many non-overlapping matches of a compiled regex occur in a text.

    A missing text (as judged by ``pd.isna``) or a ``None`` regex yields 0.
    Any other text value is converted with ``str()`` before searching.
    """
    if pd.isna(text) or compiled_regex is None:
        return 0
    # Walk the matches lazily and tally them instead of building a list.
    return sum(1 for _ in compiled_regex.finditer(str(text)))

def add_ai_and_role_counts_optimized(df, all_ai_terms_list, roles_list, text_col='text'):
    """
    Annotate a DataFrame with per-row match counts for AI terms and roles.

    For each of the two term lists a combined regex is built and applied to
    every value of ``text_col``; the counts are written to the columns
    'ai_terms_count' and 'roles_count'. The columns are assigned on ``df``
    itself, so pass a copy if the original must stay untouched.

    Args:
        df (pd.DataFrame): Frame to annotate.
        all_ai_terms_list (list): AI-related terms/phrases.
        roles_list (list): Role-related terms/phrases.
        text_col (str): Name of the column holding the text to search.

    Returns:
        pd.DataFrame: ``df`` with both count columns added. If ``text_col`` is
        missing, an error is printed and ``df`` is returned unchanged. If a
        regex cannot be built, a warning is printed and that column is 0.
    """
    if text_col not in df.columns:
        print(f"Error: Text column '{text_col}' not found in DataFrame. Returning original DataFrame.")
        return df

    # One entry per output column: (column name, term list, fallback warning).
    count_specs = (
        (
            'ai_terms_count',
            all_ai_terms_list,
            "Warning: Could not compile regex for AI terms or list was empty. 'ai_terms_count' will be 0.",
        ),
        (
            'roles_count',
            roles_list,
            "Warning: Could not compile regex for roles or list was empty. 'roles_count' will be 0.",
        ),
    )

    for column_name, term_list, fallback_warning in count_specs:
        pattern = make_combined_regex(term_list)
        if pattern is None:
            print(fallback_warning)
            df[column_name] = 0
            continue

        print(f"Calculating '{column_name}' using column '{text_col}'...")
        # The compiled pattern is forwarded as the second positional argument.
        df[column_name] = df[text_col].apply(
            count_matches_with_compiled_regex, args=(pattern,)
        )

    return df

if __name__ == '__main__':
    # Driver: annotate copies of both corpora with term/role counts, preview
    # the new columns, then print the corpus-wide totals.
    preview_columns = ['text', 'ai_terms_count', 'roles_count']
    counting_kwargs = dict(
        all_ai_terms_list=all_terms,
        roles_list=roles,
        text_col='text',
    )

    print("--- Processing ArXiv DataFrame (Simplified) ---")
    # Work on a copy so the loaded frame is not modified.
    arxiv_df_updated = add_ai_and_role_counts_optimized(arxiv_df.copy(), **counting_kwargs)
    print("\nArXiv DataFrame with counts:")
    print(arxiv_df_updated[preview_columns])

    print("\n--- Processing NYT DataFrame (Simplified) ---")
    nyt_df_updated = add_ai_and_role_counts_optimized(nyt_df.copy(), **counting_kwargs)
    print("\nNYT DataFrame with counts:")
    print(nyt_df_updated[preview_columns])

    # Corpus-level totals of each count column.
    print("\n--- Quick Stats (Simplified) ---")
    totals = (
        ("ArXiv AI terms sum:", arxiv_df_updated['ai_terms_count']),
        ("ArXiv roles sum:", arxiv_df_updated['roles_count']),
        ("NYT AI terms sum:", nyt_df_updated['ai_terms_count']),
        ("NYT roles sum:", nyt_df_updated['roles_count']),
    )
    for caption, counts in totals:
        print(caption, counts.sum())


--- Processing ArXiv DataFrame (Simplified) ---
Calculating 'ai_terms_count' using column 'text'...
