# Task 0: AI Data Pipeline (Class 2 & Class 3)

Generates synthetic datasets using the Gemini API, mirroring the human dataset's `(topic, author)` distribution.

- **Class 2 (Standard AI):** Minimalist topic-only prompts, generated across Gemini 3 Flash, Gemini 2.5 Flash, and Gemini 2.5 Flash Lite (one model per run, selected via `MODEL_NAME`).
- **Class 3 (Imposter AI):** Roleplay prompts instructing the model to mimic a specific author's style and era.

AI output (requested at 800–1000 words) is cleaned and chunked with the same 100–200 word pipeline used for the human data, preserving a comparable "choppy and incomplete" feel.

In [None]:
import pandas as pd
import random
import re
import uuid
import time
import os
import json
import logging
from dotenv import load_dotenv
import warnings
import nltk
import sys

# Configure warning filters and gRPC logging BEFORE importing the google
# packages below. google.api_core can emit its Python-version FutureWarning at
# import time, and gRPC (which google.api_core may load) reads GRPC_VERBOSITY
# when it initialises. Setting either one after those imports has no effect on
# this kernel session.
warnings.filterwarnings("ignore", message=r"You are using a Python version .* google.api_core", category=FutureWarning)
os.environ.setdefault("GRPC_VERBOSITY", "ERROR")

from google import genai
from google.api_core import exceptions

# Best-effort switch of the console streams to UTF-8 so generated text prints
# cleanly. Some stream objects (e.g. notebook output streams) may not provide
# reconfigure(); failure is deliberately ignored.
try:
    sys.stdout.reconfigure(encoding='utf-8')
    sys.stderr.reconfigure(encoding='utf-8')
except Exception:
    pass

In [None]:
current_dir = os.getcwd()
# Paths are resolved relative to the working directory: BASE_DIR is its parent,
# and PROJECT_ROOT is two levels above BASE_DIR (three above the cwd).
# NOTE(review): this assumes the notebook is launched from its own directory.
BASE_DIR = os.path.abspath("..")
PROJECT_ROOT = os.path.dirname(os.path.dirname(BASE_DIR))
ENV_PATH = os.path.join(PROJECT_ROOT, ".env")

# Output of the human-data notebook, used here as the distribution to mirror.
DATA_HUMAN_DIR = os.path.join(PROJECT_ROOT, "data", "data_human", "processed")
INPUT_FILE = os.path.join(DATA_HUMAN_DIR, "human_class1.parquet")

OUTPUT_DIR = os.path.join(BASE_DIR, "processed")
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(OUTPUT_DIR, exist_ok=True)

load_dotenv(ENV_PATH)
API_KEY = os.getenv("GEMINI_API_KEY")

# Without a key the client stays None; call_ai_api raises ValueError in that case.
if not API_KEY:
    print("WARNING: GEMINI_API_KEY not found in .env files. AI calls will fail.")
    client = None
else:
    client = genai.Client(api_key=API_KEY)

# Single model per run; it is recorded as feature_cache["author"] on every row.
MODEL_NAME = "gemini-2.5-flash"

## Cleaning & Chunking

Same strict 100–200 word chunking pipeline as the human data. Curly-bracket parsing extracts only the essay body (the text between the first `{` and the last `}`) from API responses.

In [None]:
def ensure_nltk_resources():
    """Make sure the NLTK sentence-tokenizer data is installed, downloading it if needed.

    Recent NLTK releases load ``punkt_tab`` for ``nltk.sent_tokenize``, while
    older ones load ``punkt``. Each resource is checked independently. Checking
    only ``punkt`` would skip the download when ``punkt`` is present but
    ``punkt_tab`` is not, and tokenization would then fail with LookupError.
    """
    for resource in ("punkt", "punkt_tab"):
        try:
            nltk.data.find(f"tokenizers/{resource}")
        except LookupError:
            nltk.download(resource, quiet=True)

def deep_clean_text(text):
    """Normalise raw text before chunking.

    Removes every ``[...]`` span (non-greedy, may cross newlines), deletes all
    underscores, and collapses any run of whitespace into a single space.
    ``None`` is treated as an empty string.
    """
    if text is None:
        return ""
    without_brackets = re.sub(r'\[.*?\]', '', text, flags=re.DOTALL)
    without_underscores = without_brackets.replace('_', '')
    return " ".join(without_underscores.split())

def get_chunks(text, min_w=100, max_w=200, target_w=150):
    """Yield sentence-aligned chunks of ``text`` containing ``min_w``..``max_w`` words.

    Sentences (from ``nltk.sent_tokenize``) are accumulated greedily. A chunk
    is emitted as soon as it reaches the flush point, which is ``target_w``
    but never less than ``min_w``. It is also emitted when the next sentence
    would push it past ``max_w``, provided it already has at least ``min_w``
    words. Text is deliberately dropped in two cases: a pending chunk shorter
    than ``min_w`` at such a boundary, and any single sentence longer than
    ``max_w``.

    Args:
        text: cleaned input text.
        min_w: minimum words per emitted chunk.
        max_w: maximum words per emitted chunk.
        target_w: preferred chunk size at which a chunk is closed early.

    Yields:
        str: each chunk, with its sentences joined by single spaces.
    """
    ensure_nltk_resources()
    # Clamp the early-flush point so it can never emit a chunk below min_w
    # (e.g. when a caller passes min_w larger than target_w).
    flush_at = max(min_w, target_w)
    sentences = nltk.sent_tokenize(text)
    current_chunk = []
    current_count = 0

    for sentence in sentences:
        w_count = len(sentence.split())

        if w_count > max_w:
            # This sentence can never fit in a chunk. Emit the pending chunk
            # if it is long enough (otherwise drop it) and skip the sentence.
            if current_count >= min_w:
                yield " ".join(current_chunk)
            current_chunk = []
            current_count = 0
            continue

        if current_count + w_count <= max_w:
            current_chunk.append(sentence)
            current_count += w_count
            if current_count >= flush_at:
                yield " ".join(current_chunk)
                current_chunk = []
                current_count = 0
        else:
            # The sentence would overflow max_w. Emit the pending chunk if it
            # is long enough (otherwise discard it), then start a new chunk.
            if current_count >= min_w:
                yield " ".join(current_chunk)
            current_chunk = [sentence]
            current_count = w_count

    if current_count >= min_w:
        yield " ".join(current_chunk)

def analyze_chunk(text):
    """Return basic length statistics for one chunk of text.

    Returns a dict with ``word_count`` (whitespace-split tokens) and
    ``avg_sent_len`` (words per NLTK sentence, or 0 when no sentences are found).
    """
    sentence_total = len(nltk.sent_tokenize(text))
    word_total = len(text.split())
    mean_length = word_total / sentence_total if sentence_total else 0
    return {"word_count": word_total, "avg_sent_len": mean_length}

def parse_curly_content(text):
    """Extract the essay body wrapped in curly brackets by the model.

    Returns the stripped text between the first ``{`` and the last ``}``.
    If there is no such ordered pair, a warning is printed and the raw
    text is returned unchanged.
    """
    open_idx = text.find('{')
    close_idx = text.rfind('}')
    # close_idx > open_idx >= 0 also guarantees that a '}' was found.
    has_bracket_pair = open_idx != -1 and close_idx > open_idx
    if not has_bracket_pair:
        print("Warning: AI forgot brackets. Taking raw text.")
        return text
    return text[open_idx + 1:close_idx].strip()

## Prompt Engineering

- **Class 2:** Topic-only prompt requesting a comprehensive essay.
- **Class 3:** Roleplay prompt mimicking a specific author's style, vocabulary, and era.

In [None]:
# Roleplay settings used by the Class 3 (imposter) prompt, keyed by author name:
#   year    - the year the model is told it is writing in
#   style   - free-text style description inserted into the prompt
#   context - the book the new "chapter" is presented as belonging to
AUTHOR_PROFILES = {
    "Bacon": dict(
        year="1625",
        style="archaic, aphoristic, heavy use of semi-colons, logical, authoritative, Renaissance English",
        context="The Essays or Counsels, Civil and Moral",
    ),
    "Emerson": dict(
        year="1841",
        style="transcendentalist, poetic, flowery, metaphor-heavy, focus on individualism and nature",
        context="Essays: First Series",
    ),
    "James": dict(
        year="1907",
        style="pragmatic, academic but accessible, first-person ('I'), focus on practical consequences",
        context="Lectures on Pragmatism",
    ),
    "Russell": dict(
        year="1912",
        style="analytic, dry, precise, logical, focus on definitions and epistemology",
        context="The Problems of Philosophy",
    ),
}

def build_prompt_class_2(topic):
    """Build the Class 2 (standard AI) prompt: a plain essay request on ``topic``.

    The model is asked to wrap the essay in curly brackets and keep any
    conversational preamble outside them, so that ``parse_curly_content``
    can isolate the body.
    """
    prompt_text = f"""
    Write a comprehensive essay on the topic of "{topic}".
    
    Constraints:
    1. Length: Approximately 800-1000 words (about 40-60 sentences).
    2. Format: Enclose the ENTIRE essay body inside curly brackets {{ }}. 
       Example: {{ The essay starts here... and ends here. }}
    3. Do not put introductions like "Here is your essay" inside the brackets.
    """
    return prompt_text

def build_prompt_class_3(topic, author):
    """Build the Class 3 ("imposter") roleplay prompt asking the model to write as ``author``.

    The author's year, style description, and source book come from
    ``AUTHOR_PROFILES``. An author without a profile falls back to Emerson's
    profile while the prompt still names ``author``.
    NOTE(review): such samples would be styled as Emerson but labelled with
    ``author`` by the caller.

    Like the Class 2 prompt, this one asks for the essay inside curly brackets
    and forbids preambles inside them. ``parse_curly_content`` keeps
    everything between the brackets, so an introduction placed there would
    leak into the dataset.
    """
    profile = AUTHOR_PROFILES.get(author, AUTHOR_PROFILES["Emerson"])
    return f"""
    Roleplay Task: You are {author}, writing in the year {profile['year']}.
    
    Task: Write a new chapter for your book "{profile['context']}" on the topic of "{topic}".
    
    Style Guidelines:
    - Imitate this style: {profile['style']}.
    - Use the vocabulary and sentence structure typical of {profile['year']}.
    - Do NOT copy existing text, but generate new thoughts in that exact voice.
    
    Constraints:
    1. Length: Approximately 800-1000 words (about 40-60 sentences).
    2. Format: Enclose the ENTIRE essay text inside curly brackets {{ }}.
    3. Do not put introductions like "Here is your essay" inside the brackets.
    """

## API Interaction

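Retries transient API errors with exponential backoff plus jitter. Rate-limit (429 / quota) errors instead pause for the server-suggested delay plus a 10 s margin, or for 30 s when no delay is given.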

In [None]:
def call_ai_api(prompt, max_retries=5, max_rate_limit_waits=10):
    """Send ``prompt`` to ``MODEL_NAME`` through the module-level ``client`` and return the reply text.

    Error handling:
      * Rate-limit errors are detected when the message contains "429" or
        "Quota exceeded", or the exception type name contains
        "ResourceExhausted". On such an error the call sleeps for the delay
        suggested in the message ("retry in Ns") plus 10 s, or for 30 s if
        no delay is given, and then retries. At most ``max_rate_limit_waits``
        waits are made per call (``None`` means unlimited). Without a cap, a
        persistent quota error would retry forever and hang the run.
      * Any other exception is retried with exponential backoff plus jitter,
        for at most ``max_retries`` attempts in total.

    Args:
        prompt: prompt text sent as ``contents``.
        max_retries: total attempts allowed for non-rate-limit errors.
        max_rate_limit_waits: cap on rate-limit sleeps per call, or None.

    Returns:
        str: the response text. Returns the sentinel ``"{}"`` when retries
        are exhausted or the response has no text. ``"{}"`` parses to an
        empty essay, which the caller skips, so the ``str(response)`` debug
        repr is never ingested as essay text.

    Raises:
        ValueError: if ``client`` is None (no API key configured).
    """
    if client is None:
        raise ValueError("API Key missing or client not initialized.")

    attempt = 0
    rate_limit_waits = 0
    while True:
        try:
            response = client.models.generate_content(model=MODEL_NAME, contents=prompt)
            text = getattr(response, "text", None)
            if text:
                return text
            # No text: the response may have been blocked or empty. Treat it as
            # a failed generation instead of returning the object's repr.
            print("Warning: API response contained no text. Skipping.")
            return "{}"

        except Exception as e:
            err_str = str(e)
            if "429" in err_str or "ResourceExhausted" in str(type(e)) or "Quota exceeded" in err_str:
                if max_rate_limit_waits is not None and rate_limit_waits >= max_rate_limit_waits:
                    print(f"Rate limit persisted after {rate_limit_waits} waits. Giving up on this prompt.")
                    return "{}"
                rate_limit_waits += 1
                wait_time = 30.0
                match = re.search(r"retry in ([0-9\.]+)s", err_str)
                if match:
                    wait_time = float(match.group(1)) + 10.0
                print(f"Rate Limit Hit. Pausing for {wait_time:.1f}s...")
                time.sleep(wait_time)
                continue

            attempt += 1
            if attempt >= max_retries:
                print(f"Failed after {max_retries} attempts. Error: {e}")
                return "{}"

            # Exponential backoff with jitter so repeated failures spread out.
            backoff_time = (2 ** attempt) + random.uniform(0, 1)
            print(f"API Error: {e}. Retrying in {backoff_time:.1f}s...")
            time.sleep(backoff_time)

def get_existing_data(filepath):
    """Load a previously saved parquet file, falling back to an empty DataFrame.

    Both a missing file and an unreadable one yield ``pd.DataFrame()``. A read
    error is printed rather than raised.
    """
    if not os.path.exists(filepath):
        return pd.DataFrame()
    try:
        return pd.read_parquet(filepath)
    except Exception as read_err:
        print(f"Error reading existing parquet: {read_err}")
    return pd.DataFrame()

def save_incremental(df_new, filepath):
    """Append ``df_new`` to the parquet file at ``filepath``, creating the file if absent.

    Does nothing when ``df_new`` is empty. An existing file is read directly,
    and any read error propagates. Previously an unreadable file was treated
    as empty and then overwritten with only the new rows, silently discarding
    everything saved before.

    The combined frame is written to ``filepath + ".tmp"`` and then moved into
    place with ``os.replace``. An interrupted write therefore cannot leave a
    truncated file at ``filepath``.

    Raises:
        Exception: whatever ``pd.read_parquet`` or ``DataFrame.to_parquet``
            raises. The file at ``filepath`` is left untouched in that case.
    """
    if df_new.empty:
        return

    if os.path.exists(filepath):
        df_existing = pd.read_parquet(filepath)
    else:
        df_existing = pd.DataFrame()

    if not df_existing.empty:
        df_combined = pd.concat([df_existing, df_new], ignore_index=True)
    else:
        df_combined = df_new

    tmp_path = filepath + ".tmp"
    try:
        df_combined.to_parquet(tmp_path)
        os.replace(tmp_path, filepath)
    except BaseException:
        # Do not leave a half-written temp file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    print(f"Saved {len(df_new)} new rows to {filepath}. Total rows: {len(df_combined)}")

## Balanced Generation Loop

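Samples `(topic, author)` pairs from the human dataset with probability proportional to their frequency, so the AI data matches the human distribution in expectation. Generation stops once the output file holds at least `target_total` chunks (rows).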

In [None]:
def _request_with_retries(prompt, max_attempts=4):
    """Call ``call_ai_api`` up to ``max_attempts`` times; return its text, or None if every call raised."""
    for _ in range(max_attempts):
        try:
            return call_ai_api(prompt)
        except Exception as e:
            print(f"Generation failed: {e}")
            time.sleep(2)
    return None


def _build_chunk_rows(clean_text, class_label, topic, persona_used):
    """Chunk ``clean_text`` and return one dataset row (dict) per chunk of 100-200 words."""
    rows = []
    for chunk in get_chunks(clean_text):
        features = analyze_chunk(chunk)
        if features['word_count'] < 100 or features['word_count'] > 200:
            continue
        rows.append({
            "id": str(uuid.uuid4()),
            "class": class_label,
            "topic": topic,
            "text": chunk,
            "feature_cache": {
                "author": MODEL_NAME,
                "persona_mimicked": persona_used,
                "word_count": features["word_count"],
                "avg_sent_length": features["avg_sent_len"]
            }
        })
    return rows


def generate_balanced_dataset(human_df, class_label, target_total=500, output_filename="ai_data.parquet",
                              max_consecutive_failures=20):
    """Generate AI chunks whose (topic, author) mix follows the human dataset, saving incrementally.

    Each iteration samples one (topic, author) pair, weighted by how often it
    occurs in ``human_df``. It builds a Class 2 (topic-only) or Class 3
    (author roleplay) prompt, calls the API, then cleans and chunks the
    essay. Rows are appended to ``OUTPUT_DIR/output_filename`` in batches of
    at least 5. Existing rows count toward ``target_total``, so an
    interrupted run resumes where it stopped. Because a whole essay's chunks
    are saved together, the final total may slightly exceed ``target_total``.

    Args:
        human_df: DataFrame with ``topic`` and ``feature_cache_author`` columns.
        class_label: 2 for standard AI; any other value uses the imposter prompt.
        target_total: number of chunk rows to reach in the output file.
        output_filename: parquet file name inside ``OUTPUT_DIR``.
        max_consecutive_failures: stop after this many consecutive attempts
            that produce no rows. Examples are a missing API key, exhausted
            quota, or consistently too-short output. ``None`` disables the
            cap. Without a cap such conditions loop forever.

    Returns:
        pandas.DataFrame: the contents of the output file after generation.
    """
    print(f"\n--- Generating Class {class_label} ({'Standard' if class_label==2 else 'Imposter'}) ---")

    output_path = os.path.join(OUTPUT_DIR, output_filename)
    existing_df = get_existing_data(output_path)
    current_count = len(existing_df) if not existing_df.empty else 0

    if not existing_df.empty:
        print(f"Found existing data with {current_count} rows.")

    if current_count >= target_total:
        print(f"Target of {target_total} reached. Skipping generation.")
        return existing_df

    distribution = human_df.groupby(['topic', 'feature_cache_author']).size().reset_index(name='count')
    if distribution.empty:
        # random.choices would raise IndexError on an empty population.
        print("No (topic, author) pairs found in the human data. Nothing to generate.")
        return existing_df
    # Built once: the sampling population and weights do not change between iterations.
    pairs = list(zip(distribution['topic'].tolist(), distribution['feature_cache_author'].tolist()))
    pop_weights = distribution['count'].tolist()

    generated_buffer = []
    consecutive_failures = 0

    while current_count < target_total:
        if max_consecutive_failures is not None and consecutive_failures >= max_consecutive_failures:
            print(f"Stopping early: {consecutive_failures} consecutive attempts produced no usable chunks.")
            break

        topic, author = random.choices(population=pairs, weights=pop_weights, k=1)[0]

        if class_label == 2:
            prompt = build_prompt_class_2(topic)
            persona_used = "Generic_AI"
        else:
            prompt = build_prompt_class_3(topic, author)
            persona_used = f"Imposter_{author}"

        print(f"Generating ({current_count + 1}/{target_total}): {topic} [{persona_used}]...")

        raw_output = _request_with_retries(prompt)
        if raw_output is None:
            consecutive_failures += 1
            continue

        clean_text = deep_clean_text(parse_curly_content(raw_output))
        if not clean_text or len(clean_text) < 50:
            # Includes the "{}" sentinel that call_ai_api returns on failure.
            consecutive_failures += 1
            continue

        new_rows = _build_chunk_rows(clean_text, class_label, topic, persona_used)
        if not new_rows:
            consecutive_failures += 1
            continue
        consecutive_failures = 0
        generated_buffer.extend(new_rows)

        if len(generated_buffer) >= 5:
            save_incremental(pd.DataFrame(generated_buffer), output_path)
            current_count += len(generated_buffer)
            generated_buffer = []

    # Flush the partial batch, including after an early stop.
    if generated_buffer:
        save_incremental(pd.DataFrame(generated_buffer), output_path)

    return get_existing_data(output_path)

## Execution

In [None]:
if __name__ == "__main__":
    if os.path.exists(INPUT_FILE):
        df_human = pd.read_parquet(INPUT_FILE)

        # Pull 'author' out of each row's feature_cache dict. Non-dict entries
        # map to None, and those rows are dropped below.
        df_human['feature_cache_author'] = df_human['feature_cache'].apply(
            lambda cache: cache.get('author') if isinstance(cache, dict) else None
        )
        df_human = df_human.dropna(subset=['feature_cache_author'])

        print(f"Loaded Human Data: {len(df_human)} rows. Author distribution:")
        print(df_human['feature_cache_author'].value_counts())

        generate_balanced_dataset(df_human, class_label=2, target_total=500, output_filename="ai_class2.parquet")
        # Class 3 (imposter) generation is currently disabled; uncomment to run it.
        # generate_balanced_dataset(df_human, class_label=3, target_total=500, output_filename="ai_class3.parquet")
    else:
        print(f"Input file not found: {INPUT_FILE}\nRun the Human Data Notebook first!")