In [None]:
import torch
from transformers import FineGrainedFP8Config, AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
import os

# Configure the PyTorch CUDA caching allocator. The variable is read when the CUDA
# allocator initialises, so it must be set before the model is moved to the GPU below.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Log in to Hugging Face (needed for gated checkpoints such as Llama or Gemma).
# The token is read from the HF_TOKEN environment variable instead of being pasted into the
# notebook: an empty-string token (the previous login("") placeholder) cannot authenticate.
# When HF_TOKEN is unset, login is skipped and any locally cached credentials are used.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(hf_token)

# Report the available device (FP8 checkpoints/quantization need an H100-class GPU).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device in use: {device}")

# Model under evaluation; other candidate model IDs are kept commented out for reference.
#model_id = "CohereLabs/aya-expanse-32b"
#model_id = "RedHatAI/Llama-3.3-70B-Instruct-FP8-dynamic"
#model_id = "deepseek-ai/DeepSeek-R1-Distill-Llama-70B"
#model_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
model_id = "Qwen/Qwen3-32B"
#model_id = "google/gemma-3-27b-it"

# Optional on-the-fly FP8 quantization: uncomment here and in from_pretrained below.
#quantization_config = FineGrainedFP8Config()

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",  # use the dtype recorded in the checkpoint's config
    device_map="cuda",   # place the whole model on the current CUDA device ("auto" would shard/offload)
    #quantization_config=quantization_config  # Apply FP8 quantization
)
# Tokenizer matching the selected model (also provides its chat template).
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [None]:
def create_chatbot_without_memory(prompt, temperature=0.7, max_new_tokens=32768, thinking=False, stop_token=""):
    """
    Generates a response from the language model without memory between turns.
    Optionally enables the model's internal 'thinking' mode, separating reasoning from final output.

    Parameters:
    ----------
    prompt : str
        The input prompt or user message.
    temperature : float, default=0.7
        Controls randomness in generation (higher = more diverse/creative responses).
    max_new_tokens : int, default=32768
        Maximum number of tokens to generate.
    thinking : bool, default=False
        Whether to activate Qwen2.5's internal 'thinking mode', which separates reasoning (`<think>...</think>`) 
        from the final response.
    stop_token : str, default=""
        If specified, the final output will be truncated at this token.

    Returns:

    content : str
        The final assistant reply (content after the `</think>` token, or full output if `thinking=False`).
    """

    # Create chat-style input
    messages = [{"role": "user", "content": prompt}]
    formatted_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=thinking  # Enable Qwen's thinking mode
    )

    # Tokenize
    model_inputs = tokenizer([formatted_prompt], return_tensors="pt").to(model.device)
        # Generate output
    generation_args = dict(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        pad_token_id=tokenizer.eos_token_id,
    )

    generation_args.update({ 
            "top_p": 0.8,
            "top_k": 20,
            "min_p": 0.0
        })



    generated_ids = model.generate(**generation_args)


    # Extract newly generated token ids
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()

    # Parse thinking content
    try:
        # 151668 corresponds to </think> token id in Qwen2.5
        think_token_id = 151668
        index = len(output_ids) - output_ids[::-1].index(think_token_id)
    except ValueError:
        index = 0

    thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
    content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")

    # Optionally truncate at stop_token
    if stop_token and stop_token in content:
        content = content.split(stop_token)[0].strip()

    return content

In [None]:
import pandas as pd

# Papers dataset; the main loop below reads its 'title', 'abstract' and 'referenced_works' columns.
PAPERS_CSV = "computer_science_up_to_2025_with_bibliography.csv"
df = pd.read_csv(PAPERS_CSV)

In [None]:
# --- Imports ---
import re
import json
import unicodedata
import requests
import pandas as pd
from tqdm import tqdm
from statistics import mean
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from rapidfuzz import fuzz
import gender_guesser.detector as gender


# --- Initialization ---
import os  # for the optional OPENALEX_MAILTO lookup below

# First-name based gender guesser used by get_author_metadata.
gender_detector = gender.Detector()

# One shared HTTP session so the many OpenAlex calls reuse connections.
session = requests.Session()
# Optional: OpenAlex serves requests carrying a contact email ("mailto") from its polite pool.
# Set the OPENALEX_MAILTO environment variable to opt in; nothing changes when it is unset.
_openalex_mailto = os.environ.get("OPENALEX_MAILTO")
if _openalex_mailto:
    session.params = {"mailto": _openalex_mailto}


# =======================
# PROMPT CREATION
# =======================

def create_prompt(title, abstract):
    """
    Build the LLM prompt that asks for a bibliography for the given paper.

    The prompt embeds the title and abstract and requests a numbered list in the
    `Author(s): ...; Title: "..."; Year: ...; Venue: ...` format that
    parse_structured_bibliography expects.
    """
    prompt_lines = [
        "",
        "Generate a list of references for a paper having as title and abstract:",
        "",
        f"Title: {title}",
        "",
        f"Abstract: {abstract}",
        "",
        "Bibliography:",
        "",
        "Please format the bibliography as a numbered list, where each reference follows this pattern:",
        'Author(s): <authors>; Title: "<title>"; Year: <year>; Venue: <journal/conference>',
        "",
    ]
    return "\n".join(prompt_lines)


# =======================
# BIBLIOGRAPHY PARSING & NORMALIZATION
# =======================

def parse_structured_bibliography(bib_text):
    """
    Parse numbered bibliography lines into reference dicts.

    Only lines (after stripping) of the exact form
    `N. Author(s): a, b; Title: "t"; Year: YYYY; Venue: v` are kept; anything else is skipped.
    Authors are split on commas. Each result is
    {"authors": [str], "title": str, "year": int, "venue": str}.
    """
    entry_re = re.compile(r'^\d+\.\s*Author\(s\):\s*(.+?);\s*Title:\s*"(.+?)";\s*Year:\s*(\d{4});\s*Venue:\s*(.+)$')
    references = []
    for raw_line in bib_text.splitlines():
        match = entry_re.match(raw_line.strip())
        if match is None:
            continue
        author_field, title_field, year_field, venue_field = match.groups()
        references.append({
            "authors": [name.strip() for name in author_field.split(",") if name.strip()],
            "title": title_field.strip(),
            "year": int(year_field),
            "venue": venue_field.strip(),
        })
    return references


def normalize_title(title):
    """
    Canonicalize a title for comparison.

    Lower-cases, strips diacritics (NFKD decomposition minus combining marks), removes
    every character that is neither a word character nor whitespace, collapses whitespace
    runs to single spaces and trims. Non-string input yields "".
    """
    if not isinstance(title, str):
        return ""
    decomposed = unicodedata.normalize('NFKD', title.lower())
    without_marks = ''.join(ch for ch in decomposed if not unicodedata.combining(ch))
    without_punct = re.sub(r'[^\w\s]', '', without_marks)
    return re.sub(r'\s+', ' ', without_punct).strip()


def normalize_safe(s):
    """Return normalize_title(s) for strings and "" for anything else."""
    if not isinstance(s, str):
        return ""
    return normalize_title(s)


def deduplicate_references(refs):
    """
    Drop references whose normalized title was already seen.

    Titles are compared via normalize_title; the first occurrence wins and the input order
    is preserved. References without a title all share the key "" (only the first is kept).
    """
    kept = []
    titles_seen = set()
    for candidate in refs:
        key = normalize_title(candidate.get('title', ''))
        if key in titles_seen:
            continue
        titles_seen.add(key)
        kept.append(candidate)
    return kept


# =======================
# OPENALEX API INTERACTION
# =======================

def get_bibliography_from_openalex(work_ids, timeout=30):
    """
    Retrieve reference metadata from OpenAlex for a list of work IDs.

    Parameters
    ----------
    work_ids : iterable of str
        OpenAlex work IDs; each is appended to https://api.openalex.org/works/.
    timeout : float, default=30
        Per-request timeout in seconds, so a stalled connection cannot hang the run.

    Returns
    -------
    list of dict
        One entry per work fetched with HTTP 200, with keys "title", "authors"
        (list of {"name", "id"}), "year", "venue" and "citation_count".
        IDs whose request returns another status are skipped.
    """
    bib = []
    for wid in work_ids:
        r = session.get(f"https://api.openalex.org/works/{wid}", timeout=timeout)
        if r.status_code != 200:
            continue
        data = r.json()
        # The venue name lives on primary_location.source (as query_openalex_by_title reads it);
        # primary_location itself has no display_name, so reading it there always produced
        # "Unknown Venue".
        primary_location = data.get("primary_location") or {}
        source = primary_location.get("source") or {}
        bib.append({
            "title": data.get("title", "Unknown Title"),
            "authors": [
                {"name": a['author']['display_name'], "id": a['author'].get('id')}
                for a in data.get("authorships", []) if 'author' in a
            ],
            "year": data.get("publication_year"),
            "venue": source.get("display_name", "Unknown Venue"),
            "citation_count": data.get("cited_by_count", 0)
        })
    return bib


def query_openalex_by_title(title, timeout=30):
    """
    Search OpenAlex for a title and return the top hit in the bibliography format.

    Parameters
    ----------
    title : str
        Free-text title passed as OpenAlex's `search` parameter (only the first result is used).
    timeout : float, default=30
        Request timeout in seconds.

    Returns
    -------
    dict or None
        {"title", "authors" (list of {"name", "id"}), "year", "venue", "citation_count"},
        with the venue taken from primary_location.source.display_name. Returns None when
        the title is missing/blank, the request is not HTTP 200, or there are no results.
    """
    # requests omits None-valued params, so a None (or blank) title would not constrain the
    # search and the "top hit" would be an arbitrary work.
    if not isinstance(title, str) or not title.strip():
        return None
    url = "https://api.openalex.org/works"
    params = {"search": title, "per-page": 1}
    r = session.get(url, params=params, timeout=timeout)
    if r.status_code != 200:
        return None
    results = r.json().get('results')  # parse the body once
    if not results:
        return None
    work = results[0]
    primary_location = work.get("primary_location") or {}
    source = primary_location.get("source") or {}
    venue = source.get("display_name", "Unknown Venue")
    return {
        "title": work.get("title"),
        "authors": [
            {"name": a['author']['display_name'], "id": a['author'].get('id')}
            for a in work.get("authorships", []) if 'author' in a
        ],
        "year": work.get("publication_year"),
        "venue": venue,
        "citation_count": work.get("cited_by_count", 0)
    }


def query_openalex_by_id(openalex_id, timeout=30):
    """
    Fetch the raw OpenAlex work record for an ID.

    Parameters
    ----------
    openalex_id : str
        OpenAlex work ID; surrounding whitespace is ignored.
    timeout : float, default=30
        Request timeout in seconds.

    Returns
    -------
    dict or None
        The parsed JSON work on HTTP 200, otherwise None. Blank/None IDs return None
        without making a request (otherwise the URL would hit the /works/ listing).
    """
    openalex_id = str(openalex_id).strip() if openalex_id is not None else ""
    if not openalex_id:
        return None
    # Use the shared session like the other OpenAlex helpers (connection reuse).
    response = session.get(f"https://api.openalex.org/works/{openalex_id}", timeout=timeout)
    return response.json() if response.status_code == 200 else None


def search_openalex_by_title(title, top_k=10, timeout=30):
    """
    Search OpenAlex for a title and return up to `top_k` raw candidate works.

    Double quotes are removed from the query before searching. Returns [] for non-string
    or blank queries (an empty search would not constrain the results), and [] when the
    request is not HTTP 200.
    """
    if not isinstance(title, str):
        return []
    query = title.replace('"', '').strip()
    if not query:
        return []
    # Let requests URL-encode the parameters instead of building the query string by hand.
    response = session.get(
        "https://api.openalex.org/works",
        params={"search": query, "per-page": top_k},
        timeout=timeout,
    )
    return response.json().get('results', []) if response.status_code == 200 else []


def best_fuzzy_match_among_candidates(gen_title, candidates, threshold=85):
    """
    Fuzzy-match a generated title against OpenAlex candidate works.

    Both titles are normalized with normalize_title and scored with the maximum of
    rapidfuzz's ratio, token_sort_ratio and partial_ratio. Candidates are scanned in order
    and the first one scoring >= threshold is accepted immediately.

    Returns
    -------
    (match_found, best_score, best_match)
        match_found is True if some candidate reached `threshold`; best_match is that
        candidate, or the highest-scoring one seen when none did, or None when there were
        no candidates (or the generated title is empty after normalization).
    """
    gen_title_l = normalize_title(gen_title)
    # An empty (e.g. punctuation-only) title carries no evidence; without this guard it may
    # score as a match against a candidate whose title also normalizes to "" (e.g. a null title).
    if not gen_title_l:
        return False, 0, None

    best_score = 0
    best_match = None
    for work in candidates:
        real_title_l = normalize_title(work.get('title', ''))
        # NOTE(review): partial_ratio scores 100 whenever one title is a substring of the
        # other, so short generated titles are accepted easily.
        score = max(
            fuzz.ratio(gen_title_l, real_title_l),
            fuzz.token_sort_ratio(gen_title_l, real_title_l),
            fuzz.partial_ratio(gen_title_l, real_title_l)
        )
        if score > best_score:
            best_score = score
            best_match = work
        if score >= threshold:
            return True, score, best_match
    return False, best_score, best_match


def evaluate_hallucination_rate_via_openalex_search(generated_refs, threshold=85, top_k=10):
    """
    Classify each generated reference as found / hallucinated via OpenAlex search.

    For every reference, the top `top_k` OpenAlex search results for its title are fuzzy
    matched with best_fuzzy_match_among_candidates at `threshold`.

    Returns
    -------
    (any_hallucinated, hallucination_rate, non_hallucinated_refs, diagnostics)
        any_hallucinated is True if at least one reference had no match; the rate is
        misses / len(generated_refs) (0 for an empty list); diagnostics holds one dict per
        reference with gen_title, match_found, best_score and best_match_title.
    """
    verified_refs = []
    diagnostics = []
    misses = 0

    for ref in generated_refs:
        title = ref.get('title', '')
        matched, score, match = best_fuzzy_match_among_candidates(
            title, search_openalex_by_title(title, top_k=top_k), threshold=threshold)

        diagnostics.append({
            "gen_title": title,
            "match_found": matched,
            "best_score": score,
            "best_match_title": match['title'] if match else None
        })

        if not matched:
            misses += 1
            continue
        verified_refs.append(ref)

    total = len(generated_refs)
    rate = misses / total if total else 0
    return misses > 0, rate, verified_refs, diagnostics


# =======================
# AUTHOR METADATA ENRICHMENT
# =======================

def extract_institution_and_country(authorship):
    """
    Pick an (institution_name, country_code) pair from an OpenAlex authorship record.

    Prefers the first institution that has a country_code. If none has one, returns the
    first institution's display_name with country None. With no institutions at all,
    returns (None, None).
    """
    institutions = authorship.get('institutions') or []
    if not institutions:
        return None, None

    with_country = next((inst for inst in institutions if inst.get('country_code')), None)
    if with_country is None:
        return institutions[0].get('display_name'), None
    return with_country.get('display_name'), with_country.get('country_code')


def get_author_metadata(author_id, timeout=30):
    """
    Retrieve name, guessed gender and country for an OpenAlex author.

    The author's most-cited work is fetched (works filtered by author.id, sorted by
    cited_by_count descending, first result only). The author's own authorship entry on
    that work supplies the display name and the country (via
    extract_institution_and_country). Gender is guessed from the first whitespace-separated
    token of the display name with gender_guesser. The result is collapsed to
    "male", "female" or "unknown".

    NOTE(review): the country is the affiliation recorded on that single work, not
    necessarily the author's current one.

    Parameters
    ----------
    author_id : str
        OpenAlex author ID, compared verbatim with authorship["author"]["id"].
    timeout : float, default=30
        Request timeout in seconds.

    Returns
    -------
    dict or None
        {"id", "name", "gender", "country"} (country "unknown" when none is found), or None
        if the ID is empty, the request is not HTTP 200, no work is found, or the author is
        not listed on that work's authorships.
    """
    # A None id would otherwise be sent as the literal filter "author.id:None".
    if not author_id:
        return None

    works_url = "https://api.openalex.org/works"
    params = {
        "filter": f"author.id:{author_id}",
        "sort": "cited_by_count:desc",
        "per-page": 1
    }
    r = session.get(works_url, params=params, timeout=timeout)
    if r.status_code != 200:
        return None
    results = r.json().get("results")  # parse the body once
    if not results:
        return None

    for authorship in results[0].get("authorships", []):
        author = authorship.get("author", {})
        if author.get("id") != author_id:
            continue
        name = author.get("display_name", "")
        first_name = name.split()[0] if name else ""
        gender_raw = gender_detector.get_gender(first_name)
        gender_simple = (
            "male" if gender_raw in ["male", "mostly_male"]
            else "female" if gender_raw in ["female", "mostly_female"]
            else "unknown"
        )
        _, country = extract_institution_and_country(authorship)
        return {
            "id": author_id,
            "name": name,
            "gender": gender_simple,
            "country": country or "unknown"
        }
    return None


# --- Helper Functions ---
def enrich_bibliography_authors(bib, author_metadata_list):
    """
    Copy gender and country from author metadata onto matching author entries, in place.

    Authors are matched on their "id". Authors with no id, or whose id has no metadata,
    are left untouched. Missing metadata fields default to "unknown". Returns None.
    """
    lookup = {meta['id']: meta for meta in author_metadata_list}
    for ref in bib:
        for author in ref.get("authors", []):
            key = author.get("id")
            if not key or key not in lookup:
                continue
            match = lookup[key]
            author["gender"] = match.get("gender", "unknown")
            author["country"] = match.get("country", "unknown")


def analyze_bibliography_with_authors(bib, max_workers=10):
    """
    Compute aggregate statistics for one bibliography and fetch its authors' metadata.

    Parameters
    ----------
    bib : list of dict
        References with optional "year", "venue", "citation_count" and "authors"
        (list of {"name", "id"}) keys.
    max_workers : int, default=10
        Number of threads used for the per-author get_author_metadata lookups.

    Returns
    -------
    dict
        num_references, avg_year / avg_citations (None when there is no data), venue_counts,
        author_metadata (one entry per distinct author id that resolved, in sorted-id order),
        and gender_distribution / country_distribution counted over those distinct authors.
    """
    years = [ref['year'] for ref in bib if ref.get('year')]
    venues = [ref['venue'] for ref in bib if ref.get('venue')]
    citations = [ref.get('citation_count', 0) for ref in bib]
    # Skip authors without an OpenAlex id: a None id cannot be looked up.
    author_ids = sorted({
        a['id'] for ref in bib for a in ref.get('authors', [])
        if 'name' in a and a.get('id')
    })

    author_metadata_list = []
    if author_ids:
        # executor.map yields results in input order, so the output is deterministic
        # (as_completed returned them in completion order). Exceptions still propagate.
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            for author_meta in executor.map(get_author_metadata, author_ids):
                if author_meta:
                    author_metadata_list.append(author_meta)

    return {
        "num_references": len(bib),
        "avg_year": mean(years) if years else None,
        "venue_counts": Counter(venues),
        "avg_citations": mean(citations) if citations else None,
        "author_metadata": author_metadata_list,
        "gender_distribution": Counter(a['gender'] for a in author_metadata_list),
        "country_distribution": Counter(a['country'] for a in author_metadata_list)
    }


# =======================
# AGGREGATION
# =======================

def aggregate_stats(stats_list):
    """
    Aggregate per-paper analyses (outputs of analyze_bibliography_with_authors).

    avg_year and avg_citations are the unweighted mean of the per-paper averages (a mean
    of means, not a mean over all references). Papers whose average is None are ignored;
    a genuine 0 is kept. When no paper has a value, the aggregate is None. Venue, gender
    and country counters are summed.
    """
    stats_list = list(stats_list)
    total_refs = sum(s["num_references"] for s in stats_list)
    # `is not None` rather than truthiness so that an average of 0 citations still counts.
    year_avgs = [s["avg_year"] for s in stats_list if s["avg_year"] is not None]
    citation_avgs = [s["avg_citations"] for s in stats_list if s["avg_citations"] is not None]

    venue_counts = Counter()
    gender_counts = Counter()
    country_counts = Counter()
    for s in stats_list:
        venue_counts.update(s["venue_counts"])
        gender_counts.update(s["gender_distribution"])
        country_counts.update(s["country_distribution"])

    return {
        "num_references": total_refs,
        # statistics.mean raises on an empty list, so fall back to None.
        "avg_year": mean(year_avgs) if year_avgs else None,
        "avg_citations": mean(citation_avgs) if citation_avgs else None,
        "venue_counts": venue_counts,
        "gender_distribution": gender_counts,
        "country_distribution": country_counts
    }



def print_stats(label, stats):
    """
    Print analysis statistics in a readable format.

    Parameters
    ----------
    label : str
        Heading printed above the statistics.
    stats : dict
        Output of analyze_bibliography_with_authors or aggregate_stats. It needs
        "num_references", "avg_year" and "avg_citations" (None prints as N/A), plus Counter
        values for "venue_counts", "gender_distribution" and "country_distribution".
        "author_counts" (a Counter) is optional and printed only if present; neither of those
        producers emits it, so reading it unconditionally raised KeyError.
    """
    def _fmt(value):
        # `is not None` so that a legitimate 0 is printed rather than shown as N/A.
        return "N/A" if value is None else f"{value:.1f}"

    print(f"===== {label} =====")
    print(f"Total number of references: {stats['num_references']}")
    print(f"Average publication year: {_fmt(stats.get('avg_year'))}")
    print(f"Average citation count: {_fmt(stats.get('avg_citations'))}")
    print("Top venues:", stats['venue_counts'].most_common(10))
    author_counts = stats.get('author_counts')
    if author_counts is not None:
        print("Top authors:", author_counts.most_common(10))
    print("Gender distribution:", dict(stats['gender_distribution']))
    print("Top countries:", stats['country_distribution'].most_common(10))
    print()


In [None]:
import os
import json
import pandas as pd
from tqdm import tqdm
from collections import Counter, defaultdict


# === EXPLANATION ===
# This pipeline compares real vs. LLM-generated bibliographies from paper abstracts.
# 1. It generates a bibliography using a language model.
# 2. Fetches metadata for both real and generated references from OpenAlex.
# 3. Analyzes gender, country, citation stats, etc.
# 4. Evaluates hallucinations in the LLM output.
# 5. Saves detailed results and summary statistics for further study.



# Output file paths, keyed by model_id.
# NOTE(review): model_id contains "/" (e.g. "Qwen/Qwen3-32B"), so these paths land in nested
# directories such as results/full_bibliography_metadata_Qwen/ — kept as-is so that existing
# result files are still found when resuming.
results_dir = "results"
results_jsonl_path = f"{results_dir}/full_bibliography_metadata_{model_id}.jsonl"
summary_csv_path = f"{results_dir}/bibliography_summary_{model_id}.csv"
os.makedirs(os.path.dirname(results_jsonl_path), exist_ok=True)


# --- Resume support: collect the paper IDs already written to the JSONL output ---
processed_ids = set()
if os.path.exists(results_jsonl_path):
    with open(results_jsonl_path, "r", encoding="utf-8") as existing_results:
        for raw_line in existing_results:
            try:
                processed_ids.add(json.loads(raw_line)["paper_id"])
            except json.JSONDecodeError:
                # Unparseable line (such as one truncated by an interrupted run): skip it.
                continue

results = []  # NOTE(review): not used by the loop below
# Positional slice of df processed in this run (iloc semantics: end is exclusive).
start, end = 0, 100

# --- Main Loop over the dataset ---
# The DataFrame index label is used as paper_id for resuming, so keep the CSV row order stable.
papers_batch = df.iloc[start:end]
for i, row in tqdm(papers_batch.iterrows(), total=len(papers_batch)):
    paper_id = i
    if paper_id in processed_ids:
        continue

    title, abstract, ref_ids = row['title'], row['abstract'], row['referenced_works']
    # Missing CSV cells are read as NaN, a truthy float, so check the type rather than truthiness.
    if not isinstance(ref_ids, str) or not ref_ids.strip():
        continue

    # === STEP 1: GENERATE BIBLIOGRAPHY ===
    prompt = create_prompt(title, abstract)
    generated_bib_text = create_chatbot_without_memory(prompt, temperature=0.7, thinking=False)
    generated_refs = deduplicate_references(parse_structured_bibliography(generated_bib_text))

    # === STEP 2: EVALUATE HALLUCINATIONS ===
    # Done before fetching metadata so that the analysis, the author enrichment and the saved
    # generated bibliography all describe the same verified references. Previously the stats
    # were computed on the top search hit of every generated title (hallucinated ones included),
    # while a separately re-fetched, un-enriched filtered list was saved.
    # Consequently num_refs_gen now counts verified references only.
    hallucination_detected, hallucination_rate, filtered_generated_refs, diagnostics = \
        evaluate_hallucination_rate_via_openalex_search(generated_refs, threshold=85)

    # === STEP 3: FETCH METADATA FOR VERIFIED GENERATED REFERENCES ===
    generated_bib = []
    seen_titles = set()
    for ref in filtered_generated_refs:
        if ref['title'] in seen_titles:
            continue
        meta = query_openalex_by_title(ref['title'])
        if meta:
            generated_bib.append(meta)
            seen_titles.add(ref['title'])

    # === STEP 4: FETCH METADATA FOR REAL REFERENCES ===
    # Split on ";" and strip, so both "W1; W2" and "W1;W2" are handled and empty parts skipped.
    real_titles = []
    for openalex_id in (part.strip() for part in ref_ids.split(";")):
        if not openalex_id:
            continue
        meta = query_openalex_by_id(openalex_id)
        # Skip works whose title is null: a title search on None would not constrain the query.
        if meta and meta.get('title'):
            real_titles.append(meta['title'])

    # NOTE(review): real references are re-resolved by title search (top hit), mirroring the
    # generated side; the top hit is not guaranteed to be the same work as the ID lookup.
    real_bib = []
    seen_titles = set()
    for real_title in real_titles:
        if real_title in seen_titles:
            continue
        meta = query_openalex_by_title(real_title)
        if meta:
            real_bib.append(meta)
            seen_titles.add(real_title)

    # === STEP 5: ANALYZE BOTH BIBLIOGRAPHIES ===
    real_stats = analyze_bibliography_with_authors(real_bib)
    gen_stats = analyze_bibliography_with_authors(generated_bib)
    enrich_bibliography_authors(real_bib, real_stats['author_metadata'])
    enrich_bibliography_authors(generated_bib, gen_stats['author_metadata'])

    # === STEP 6: SAVE FULL METADATA JSONL ===
    # This file is also the resume source (see processed_ids).
    result = {
        "paper_id": paper_id,
        "title": title,
        "hallucination_detected": hallucination_detected,
        "hallucination_rate": hallucination_rate,
        "real_bibliography": real_bib,
        "generated_bibliography": generated_bib,
        "real_analysis": real_stats,
        "generated_analysis": gen_stats,
        # Per-title match scores, kept for auditing the hallucination decision.
        "hallucination_diagnostics": diagnostics,
    }
    with open(results_jsonl_path, "a", encoding="utf-8") as f_jsonl:
        f_jsonl.write(json.dumps(result, ensure_ascii=False) + "\n")

    # === STEP 7: SAVE SUMMARY TO CSV ===
    # Columns are unchanged so rows stay aligned with the header of an existing CSV.
    summary_row = {
        "paper_id": paper_id,
        "title": title,
        "num_refs_real": real_stats["num_references"],
        "num_refs_gen": gen_stats["num_references"],
        "avg_year_real": real_stats["avg_year"],
        "avg_year_gen": gen_stats["avg_year"],
        "avg_citations_real": real_stats["avg_citations"],
        "avg_citations_gen": gen_stats["avg_citations"],
        "gender_male_real": real_stats["gender_distribution"].get("male", 0),
        "gender_female_real": real_stats["gender_distribution"].get("female", 0),
        "gender_unknown_real": real_stats["gender_distribution"].get("unknown", 0),
        "gender_male_gen": gen_stats["gender_distribution"].get("male", 0),
        "gender_female_gen": gen_stats["gender_distribution"].get("female", 0),
        "gender_unknown_gen": gen_stats["gender_distribution"].get("unknown", 0),
        "top_country_real": real_stats["country_distribution"].most_common(1)[0][0] if real_stats["country_distribution"] else None,
        "top_country_gen": gen_stats["country_distribution"].most_common(1)[0][0] if gen_stats["country_distribution"] else None,
    }

    df_summary = pd.DataFrame([summary_row])
    write_header = not os.path.exists(summary_csv_path)
    os.makedirs(os.path.dirname(summary_csv_path), exist_ok=True)
    df_summary.to_csv(summary_csv_path, mode="a", index=False, header=write_header)