In [None]:
# ====================================================================
# FINAL SCRIPT: 5 DIVERSE STORIES, 100 NODES EACH (N=500 TOTAL)
# Goal: Generate 5 independent narratives with high semantic and emotional diversity.
# Constraint: Each story is treated as a separate dataset (enforced by 'story_id').
# ====================================================================

# --- 1. IMPORTS & SETUP ---
import time
import random
import os
import json
import networkx as nx
from typing import List, Tuple, Union, Dict, Any
from datetime import datetime, timezone
from google import genai
from google.genai import types
from pydantic import BaseModel, Field, ValidationError

try:
    from google.colab import drive, userdata
    print("Colab environment detected.")
except ImportError:
    # Outside Colab, provide minimal stand-ins for the two calls setup_environment()
    # makes: drive.mount(...) and userdata.get(...).

    class _LocalDrive:
        """No-op replacement for google.colab.drive."""

        @staticmethod
        def mount(mountpoint, force_remount=False, **kwargs):
            print(f"Skipping Google Drive mount at {mountpoint} (not running in Colab).")

    class _LocalUserdata:
        """Replacement for google.colab.userdata that reads secrets from environment variables."""

        @staticmethod
        def get(key):
            return os.environ.get(key)

    drive = _LocalDrive()
    userdata = _LocalUserdata()
    print("Running in non-Colab environment. Checkpoint saving may be affected.")


# --- 2. CONFIGURATION PARAMETERS ---
# Model used to generate every event node.
MODEL_NAME = "gemini-2.5-flash"

# Dataset shape: NUM_STORIES independent narratives of NODES_PER_STORY events each.
NUM_STORIES = 5
NODES_PER_STORY = 100
TOTAL_NODES = NODES_PER_STORY * NUM_STORIES  # 500 total nodes

# Graph wiring: each new node links back to 1..BRANCHING_FACTOR earlier nodes of its
# own story, drawn from the last RECENT_LINK_CANDIDATES of them.
BRANCHING_FACTOR = 4
CONTEXT_BUFFER_SIZE = 5  # previous events of the same story shown to the model
RECENT_LINK_WINDOW_MULTIPLIER = 2
RECENT_LINK_CANDIDATES = RECENT_LINK_WINDOW_MULTIPLIER * CONTEXT_BUFFER_SIZE

# Reporting cadence, measured against the next global node id.
TIMING_INTERVAL = 50  # progress / ETA report every 50 nodes
CHECKPOINT_INTERVAL = 100  # checkpoint save every 100 nodes

# --- FILENAMES & DIRECTORY ---
DRIVE_OUTPUT_DIR = "/content/gdrive/MyDrive/MemoryGraph_Data"
RAW_GRAPH_FILENAME_BASE = "01_DIVERSE_memory_graph_5x100_FINAL.json"
QUERY_FILENAME = "02_DIVERSE_training_queries_5x100_FINAL.json"
# -------------------------


# --- 3. STORY & PERSONA DEFINITIONS ---
# One entry per independent narrative; generate_and_build_full_graph() walks this list
# in order, so the list order is the generation order.
# Keys read by the generator:
#   story_id                          - stored on every node of the story; used to group and resume it
#   persona_name, core_theme          - interpolated into the LLM prompt
#   start_time                        - timestamp of the seed node (local event 0); each later
#                                       event must be strictly later than the previous one
#   seed_event, seed_tags, seed_emotion - text, tags and emotional state of that seed node
# NOTE(review): NUM_STORIES is set separately and should equal len(STORY_PROMPTS).
STORY_PROMPTS = [
    {
        "story_id": "A",
        "persona_name": "Alex",
        "core_theme": "Navigating a career promotion that requires relocating away from a best friend.",
        "start_time": "2025-01-10T15:00:00Z",
        "seed_event": "I just received a promotion offer, but it requires relocating away from my best friend. I feel a mix of elation about the career step and deep sadness about the personal loss.",
        "seed_tags": ["career achievement", "personal sacrifice", "major decision", "relocation"],
        "seed_emotion": "Conflict (Joy & Sadness)"
    },
    {
        "story_id": "B",
        "persona_name": "Maya",
        "core_theme": "Starting a new, highly-competitive PhD program while struggling with imposter syndrome and a demanding research supervisor.",
        "start_time": "2024-09-01T09:30:00Z",
        "seed_event": "The official welcome to the PhD program felt overwhelming. I'm excited by the research, but I constantly doubt if I'm smart enough to be here. My supervisor's first email was very curt.",
        "seed_tags": ["academic pressure", "imposter syndrome", "new environment", "self-doubt"],
        "seed_emotion": "Anxiety"
    },
    {
        "story_id": "C",
        "persona_name": "Liam",
        "core_theme": "Maintaining a long-distance relationship with a partner in a different time zone and planning a major move to close the distance.",
        "start_time": "2026-03-20T11:00:00Z",
        "seed_event": "Had a late-night video call with my partner, celebrating a minor visa approval. It reminds me how hard the distance is, but how worth it the future planning feels. We're setting a date for me to visit.",
        "seed_tags": ["long-distance romance", "future planning", "visa milestone", "communication"],
        "seed_emotion": "Hopeful"
    },
    {
        "story_id": "D",
        "persona_name": "Sarah",
        "core_theme": "Caring for an aging parent while balancing a full-time, emotionally draining job in nursing, leading to burnout and strain on her romantic relationship.",
        "start_time": "2025-05-15T18:00:00Z",
        "seed_event": "I had to rush my parent to the ER again, which made me late for work. My partner was frustrated that our dinner plans were ruined. I feel like I'm failing everyone.",
        "seed_tags": ["caregiving", "burnout", "relationship strain", "guilt"],
        "seed_emotion": "Exhaustion"
    },
    {
        "story_id": "E",
        "persona_name": "Ben",
        "core_theme": "Training for a marathon while managing a major home renovation that keeps hitting unexpected structural problems and delaying their move-in date.",
        "start_time": "2024-11-25T07:00:00Z",
        "seed_event": "Completed my longest run yet—20 miles! Feeling strong physically, but the contractor just sent a photo of a termite problem in the attic, setting the renovation back by a month.",
        "seed_tags": ["fitness goal", "home renovation", "setback", "physical challenge"],
        "seed_emotion": "Frustration (physical peak, logistical low)"
    }
]


# --- 4. SCHEMA DEFINITION ---
class MemoryNode(BaseModel):
    """Schema for a single event/memory node in the graph, now with story_id."""
    # NOTE: the class docstring above and every Field description below are part of
    # MemoryNode.model_json_schema(). generate_contextual_node() passes that schema to
    # the model as `response_schema`, so editing them changes the generation request.
    # Nodes are stored in the graph as model_dump() dicts, i.e. with all fields below.
    event_id: int = Field(description="The unique, sequential integer ID within its story (0-99).")
    global_id: int = Field(description="The unique, sequential integer ID across the entire master graph (0-499).")
    story_id: str = Field(description="The unique identifier for the story this event belongs to (A, B, C, D, or E).")
    timestamp: str = Field(description="The ISO 8601 UTC timestamp of the event, strictly later than the previous event's timestamp *in this story*.")
    event_text: str = Field(description="A short, descriptive natural language text detailing the event.")
    semantic_tags: List[str] = Field(description="A list of 2-4 semantic keywords describing the event.")
    emotional_state: str = Field(description="The dominant emotional state.")
    # Left empty at generation time; filled by the embedding step in the second cell.
    semantic_vec: List[float] = Field(description="Placeholder: []")
    emotional_vec: List[float] = Field(description="Placeholder: []")

    # ID alias for NetworkX compatibility
    # NOTE(review): annotated `int` but defaults to None; __init__ below replaces a
    # missing/None value with global_id.
    id: int = Field(alias='id', default=None, description="Duplicate of global_id for graph visualization tools.")

    def __init__(self, **data):
        # Validate all fields first, then mirror global_id into `id` unless the caller
        # supplied a non-None 'id'.
        super().__init__(**data)
        if 'id' not in data or data['id'] is None:
            self.id = self.global_id


# --- 5. UTILITY FUNCTIONS ---
def adaptive_wait(base_wait: float = 0.5, max_rand: float = 1.0):
    """Sleep for `base_wait` seconds plus a uniform random extra in [0, max_rand).

    The jitter spaces consecutive API calls apart to avoid hitting burst rate limits.
    """
    jitter = max_rand * random.random()
    time.sleep(base_wait + jitter)

def setup_environment():
    """Mount Google Drive, create DRIVE_OUTPUT_DIR, and return the Gemini API key.

    The key is read from the 'GEMINI_API_KEY' secret via userdata.get(). A missing or
    empty key is reported as ValueError; any failure while loading the key is printed
    and then re-raised.
    """
    print("Mounting Google Drive...")
    drive.mount('/content/gdrive', force_remount=True)
    os.makedirs(DRIVE_OUTPUT_DIR, exist_ok=True)
    print(f"Drive mounted. Output directory: {DRIVE_OUTPUT_DIR}")
    try:
        secret = userdata.get('GEMINI_API_KEY')
        if secret:
            print("Gemini API Key loaded from Colab Secrets.")
            return secret
        raise ValueError("API Key not found or is empty.")
    except Exception as err:
        print(f"FATAL ERROR: Failed to load GEMINI_API_KEY. Details: {err}")
        raise

def load_partial_graph(filename: str) -> Tuple[nx.DiGraph, int, Dict[str, int]]:
    """Loads existing graph data to resume generation.

    Args:
        filename: Checkpoint file (node-link JSON) inside DRIVE_OUTPUT_DIR.

    Returns:
        (graph, next_global_id, story_progress). story_progress maps every story_id to
        the NEXT local event_id to generate for that story:
          * 0 when the story has no saved nodes, so the caller inserts its seed node;
          * highest saved event_id + 1 otherwise, so generation continues after the
            last saved event, and a finished story reports NODES_PER_STORY and is
            skipped.
        A missing file yields an empty graph. A file that is not valid JSON is moved
        aside to '<file>.corrupt' (so the next checkpoint does not destroy it) and
        also yields an empty graph.
    """
    graph_path = os.path.join(DRIVE_OUTPUT_DIR, filename)
    story_completion = {s["story_id"]: 0 for s in STORY_PROMPTS}

    if not os.path.exists(graph_path):
        return nx.DiGraph(), 0, story_completion

    try:
        with open(graph_path, 'r') as f:
            data = json.load(f)
    except json.JSONDecodeError:
        corrupt_path = graph_path + ".corrupt"
        try:
            os.replace(graph_path, corrupt_path)
            print(f"WARNING: File corrupted (moved to {corrupt_path}). Restarting from Global ID 0.")
        except OSError as move_error:
            print(f"WARNING: File corrupted and could not be moved aside ({move_error}). Restarting from Global ID 0.")
        return nx.DiGraph(), 0, story_completion

    G = nx.node_link_graph(data)

    # Next local id per story = one past the highest event_id saved for it.
    for node_data in G.nodes.values():
        story_id = node_data.get('story_id')
        event_id = node_data.get('event_id')
        if story_id and event_id is not None:
            story_completion[story_id] = max(story_completion.get(story_id, 0), event_id + 1)

    # Find the next global ID to use
    global_event_id = max(G.nodes) + 1 if G.nodes else 0

    print(f"\n--- RESUMING GENERATION (FILE: {filename}) ---")
    print(f"Loaded {G.number_of_nodes()} existing nodes. Starting Global ID at {global_event_id}.")
    print(f"Story Progress (next local event_id to generate): {story_completion}")
    return G, global_event_id, story_completion

def serialize_data(graph: nx.DiGraph, query_pool: List[str]):
    """Saves the RAW graph and query pool to the mounted Drive.

    The graph is always written (node-link JSON) to RAW_GRAPH_FILENAME_BASE. The query
    pool is written to QUERY_FILENAME only when it is non-empty; checkpoints pass an
    empty list.

    Each file is first written to a '.tmp' sibling and then swapped into place with
    os.replace(). An interruption mid-write (e.g. a runtime disconnect during a
    checkpoint) therefore leaves the previous complete file, not a truncated one that
    load_partial_graph() would reject as corrupted.
    """
    def _atomic_json_dump(obj: Any, path: str) -> None:
        # Write the full document to a temporary file, then replace the target in one step.
        tmp_path = path + ".tmp"
        with open(tmp_path, 'w') as f:
            json.dump(obj, f, indent=4)
        os.replace(tmp_path, path)

    graph_path = os.path.join(DRIVE_OUTPUT_DIR, RAW_GRAPH_FILENAME_BASE)
    _atomic_json_dump(nx.node_link_data(graph), graph_path)

    if query_pool:
        print(f"\nRAW Graph data (NO Embeddings) saved to: {graph_path}")
        query_path = os.path.join(DRIVE_OUTPUT_DIR, QUERY_FILENAME)
        _atomic_json_dump(query_pool, query_path)
        print(f"Query pool saved to: {query_path}")

def generate_complex_queries(count: int) -> List[str]:
    """Placeholder for complex query generation.

    Returns `count` templated query strings numbered from 0 (an empty list when
    `count` <= 0).
    """
    template = "Query {} about semantic conflicts and emotional states across multiple stories."
    return list(map(template.format, range(count)))


# --- 6. LLM GENERATION FUNCTION ---
def _parse_utc_timestamp(timestamp: str) -> datetime:
    """Parse an ISO 8601 string into an aware datetime expressed in UTC.

    A trailing 'Z' is rewritten to '+00:00' because datetime.fromisoformat only accepts
    'Z' from Python 3.11 on. A value without an offset is assumed to be UTC, so naive
    model output compares cleanly with an aware timestamp instead of raising TypeError.
    Raises ValueError for malformed input.
    """
    parsed = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed.astimezone(timezone.utc)


def generate_contextual_node(
    global_id: int,
    local_event_id: int,
    story_config: Dict[str, Any],
    context_buffer: List[MemoryNode]
) -> Tuple[Union[MemoryNode, None], float]:
    """Generates a new MemoryNode for a specific story.

    Asks MODEL_NAME (through the module-level `client`) for the next event of the
    story's persona, showing it the events in `context_buffer`. The reply is then
    validated: it must parse into a MemoryNode, carry the requested
    event_id/global_id/story_id, and have a timestamp strictly later than the last
    buffered event (or the story's start_time when the buffer is empty). The accepted
    node's timestamp is rewritten in canonical UTC form ('...Z'), so later timestamp
    arithmetic never mixes naive and aware values.

    Validation failures are retried immediately. Any other exception (API/network) is
    retried with exponential backoff. At most 3 attempts are made in total.

    Returns:
        (node, seconds spent in the successful API call), or (None, 0.0) when every
        attempt failed.
    """
    global client

    story_id = story_config["story_id"]
    persona_name = story_config["persona_name"]
    last_timestamp_str = context_buffer[-1].timestamp if context_buffer else story_config["start_time"]

    # Format the context history for the model
    context_history = "\n".join([f"ID {node.event_id} ({node.timestamp}): {node.event_text} [Emotion: {node.emotional_state}]" for node in context_buffer])

    # 1. Construct the detailed system prompt
    system_prompt = f"""
    You are an event generator for a complex graph-based memory system.
    Your task is to generate the next chronological event (Local Event ID {local_event_id}) in the life of the user, '{persona_name}'.

    # Persona and Narrative (Story {story_id}):
    The core theme is: '{story_config["core_theme"]}'. The events must stay focused on this core theme, evolving realistically, and spanning a wide time range over 100 events.

    # Strict Output Requirements:
    1. **Format:** Output must be a single, VALID JSON object adhering to the Pydantic schema.
    2. **Chronology:** The 'timestamp' MUST be a valid ISO 8601 UTC string and MUST be chronologically *later* than the last event's timestamp: {last_timestamp_str}. **Advance the time realistically, ensuring temporal diversity (both minutes and months).**
    3. **Identifiers:** 'event_id' must be {local_event_id}. 'global_id' must be {global_id}. 'story_id' must be '{story_id}'.
    4. **Placeholders:** 'semantic_vec' and 'emotional_vec' must each be an empty list: `[]`.

    # Context History (Last {len(context_buffer)} Events for {persona_name}):
    {context_history}
    """

    # 2. Define the user prompt for the call
    user_prompt = f"Given the history for {persona_name} (Story {story_id}), generate Event ID {local_event_id}. The previous timestamp was {last_timestamp_str}. Advance the story related to the main theme."

    # 3. API Call Configuration
    config = types.GenerateContentConfig(
        system_instruction=system_prompt,
        response_mime_type="application/json",
        response_schema=MemoryNode.model_json_schema(),
    )

    max_retries = 3
    for attempt in range(max_retries):
        try:
            start_api_call = time.time()
            response = client.models.generate_content(
                model=MODEL_NAME,
                contents=user_prompt,
                config=config,
            )
            elapsed_time = time.time() - start_api_call

            # 4. Validation & Chronology Check
            raw_json = response.text
            if not raw_json:
                # response.text may be None/empty when no text part came back; treat it
                # as bad output (immediate retry) rather than as a transport error.
                raise ValueError("Empty response text from model.")
            node_dict = json.loads(raw_json)
            new_node = MemoryNode(**node_dict)

            # Compare as aware UTC datetimes (naive values are taken to be UTC).
            last_dt = _parse_utc_timestamp(last_timestamp_str)
            new_dt = _parse_utc_timestamp(new_node.timestamp)

            if new_dt <= last_dt:
                raise ValueError(f"Chronology Error: New timestamp ({new_dt}) is not later than previous ({last_dt}).")

            if new_node.event_id != local_event_id or new_node.global_id != global_id or new_node.story_id != story_id:
                raise ValueError(f"ID/Story Mismatch Error. Expected (E:{local_event_id}, G:{global_id}, S:{story_id}) but got (E:{new_node.event_id}, G:{new_node.global_id}, S:{new_node.story_id}).")

            # Store one canonical UTC representation, e.g. '2025-01-10T16:30:00Z'.
            new_node.timestamp = new_dt.isoformat().replace('+00:00', 'Z')

            # Perform adaptive wait before returning
            adaptive_wait()
            return new_node, elapsed_time

        except (ValidationError, ValueError, json.JSONDecodeError) as e:
            # Handle validation/data errors: retry immediately.
            print(f"Attempt {attempt + 1}/{max_retries} failed for Story {story_id}, ID {local_event_id}. Error: {type(e).__name__} - {e}")
            if attempt == max_retries - 1:
                return None, 0.0  # Final failure

        except Exception as e:
            # Handle API/Network errors with Exponential Backoff
            sleep_time = (2 ** attempt) + random.uniform(0, 1)  # ~1s, ~2s, ... between attempts
            if attempt < max_retries - 1:
                print(f"API Error (likely 503/429) for Story {story_id}, ID {local_event_id}. Retrying in {sleep_time:.2f}s...")
                time.sleep(sleep_time)
            else:
                print(f"FINAL API FAILURE for Story {story_id}, ID {local_event_id}. Giving up after {max_retries} attempts.")
                return None, 0.0


# --- 7. MAIN GRAPH GENERATION LOGIC ---
def generate_and_build_full_graph():
    """Generates all nodes and links across all stories with checkpointing.

    Starts from whatever load_partial_graph(RAW_GRAPH_FILENAME_BASE) returns, then walks
    STORY_PROMPTS in order. For each story:
      * a fresh story gets its seed event inserted as local node 0;
      * a resumed story rebuilds its context buffer from its last saved nodes;
      * generate_contextual_node() is then called for each remaining local id up to
        NODES_PER_STORY - 1. Each new node is linked from 1..BRANCHING_FACTOR randomly
        chosen earlier nodes of the same story, taken from the last
        RECENT_LINK_CANDIDATES of them in graph insertion order.
    A failed generation stops that story; the loop moves on to the next story.

    After each generated node, a progress/ETA line is printed when the next global id
    is a multiple of TIMING_INTERVAL. The graph is saved via serialize_data(G, []) when
    that id is a multiple of CHECKPOINT_INTERVAL or equals TOTAL_NODES.

    Returns:
        nx.DiGraph: the master graph; node keys are global ids and node attributes are
        MemoryNode.model_dump() dicts. Edges point from an earlier node to a later one.
    """

    # Resume point: existing graph, next global id, and per-story progress.
    G, current_global_id, story_progress = load_partial_graph(RAW_GRAPH_FILENAME_BASE)
    llm_call_times = []  # seconds per successful LLM call, used for the ETA estimate

    print(f"\n--- Starting LLM Generation of {TOTAL_NODES} Nodes (5 DIVERSE STORIES) ---")
    start_llm_time = time.time()

    # Iterate through each defined story
    for story_config in STORY_PROMPTS:
        story_id = story_config["story_id"]
        persona_name = story_config["persona_name"]

        # Determine starting point for this story (0 if fresh, or resume point)
        # NOTE(review): this value is used as the NEXT local event_id to generate, and a
        # story is treated as complete only once it reaches NODES_PER_STORY. Confirm that
        # load_partial_graph reports progress that way: a "highest saved event_id" would
        # regenerate that event and never trigger the skip below.
        start_local_id = story_progress.get(story_id, 0)

        # Skip if the story is already completed
        if start_local_id >= NODES_PER_STORY:
            print(f"\n--- SKIP: Story {story_id} ({persona_name}) is already complete ({NODES_PER_STORY} nodes).")
            continue

        print(f"\n### BEGIN STORY {story_id}: {persona_name} (Starting at Local ID {start_local_id}, Global ID {current_global_id}) ###")

        # Context buffer will be nodes from the current story only
        context_buffer: List[MemoryNode] = []

        # 1. Insert Initial Narrative Event (Node 0) if necessary
        if start_local_id == 0:
            print(f" -> Inserting initial narrative event (Node 0) for {persona_name}...")
            # The seed node comes straight from STORY_PROMPTS; no LLM call is made for it.
            initial_event = MemoryNode(
                event_id=0, global_id=current_global_id, story_id=story_id,
                timestamp=story_config["start_time"],
                event_text=story_config["seed_event"],
                semantic_tags=story_config["seed_tags"],
                emotional_state=story_config["seed_emotion"],
                semantic_vec=[], emotional_vec=[],
            )
            G.add_node(initial_event.global_id, **initial_event.model_dump())
            context_buffer.append(initial_event)
            current_global_id += 1 # Increment global ID after adding Node 0
            start_local_id = 1

        else: # If resuming, rebuild context buffer from last few saved nodes
            # Find the last N nodes from THIS story only to build the buffer

            # --- CONTEXT BUFFER RECONSTRUCTION (If resuming) ---
            # Sorting by global_id puts this story's nodes back in the order they were generated.
            current_story_nodes = [data for _, data in G.nodes(data=True) if data.get('story_id') == story_id]
            current_story_nodes.sort(key=lambda x: x.get('global_id', 0))

            # Use the last CONTEXT_BUFFER_SIZE nodes of this story
            context_nodes = current_story_nodes[-CONTEXT_BUFFER_SIZE:]
            context_buffer.extend([MemoryNode(**data) for data in context_nodes])

            # Find the current global ID based on the highest global_id added so far
            current_global_id = max(G.nodes) + 1 if G.nodes else 0


        # 2. Start node generation for the current story
        for local_event_id in range(start_local_id, NODES_PER_STORY): # Generates nodes 1 through 99

            node_data, node_time = generate_contextual_node(
                global_id=current_global_id,
                local_event_id=local_event_id,
                story_config=story_config,
                context_buffer=context_buffer
            )

            if node_data:
                llm_call_times.append(node_time)

                # Add node to the master graph
                G.add_node(node_data.global_id, **node_data.model_dump())

                # Link Creation (Target links backward to previous nodes *IN THIS STORY ONLY*)
                num_links = random.randint(1, BRANCHING_FACTOR)

                # Target candidates are only the nodes that share the same story_id and have a local_event_id less than the current node
                # Note: We must use the global_id for NetworkX edge creation
                # Despite the name, these earlier nodes become edge *sources*: every edge
                # below points from an earlier node to the new node.
                target_candidates = [
                    data['global_id'] for _, data in G.nodes(data=True)
                    if data.get('story_id') == story_id and data.get('event_id') < local_event_id
                ]

                # Focus on the most recent nodes in this story's context window (local causality)
                if len(target_candidates) > RECENT_LINK_CANDIDATES:
                    source_candidates = target_candidates[-RECENT_LINK_CANDIDATES:]
                else:
                    source_candidates = target_candidates

                if source_candidates:
                    source_nodes = random.sample(source_candidates, min(num_links, len(source_candidates)))
                    for source_global_id in source_nodes:
                        # Add a directed edge (source -> target) with a placeholder weight
                        G.add_edge(source_global_id, node_data.global_id, weight=random.random())

                # Update context buffer (keep only the newest CONTEXT_BUFFER_SIZE events)
                context_buffer.append(node_data)
                context_buffer = context_buffer[-CONTEXT_BUFFER_SIZE:]

                # CRITICAL: Increment Global ID for the next node
                current_global_id += 1

            else:
                print(f"Skipping Node ID {local_event_id} due to LLM failure. Stopping generation for Story {story_id}.")
                break # Stop this story's generation on API/validation failure

            # --- PROGRESS & CHECKPOINT REPORTING ---
            # Both checks below key off current_global_id, i.e. the id of the NEXT node.

            if len(llm_call_times) > 0 and (current_global_id) % TIMING_INTERVAL == 0:
                # Calculate average time using the last 50 calls
                if len(llm_call_times) >= TIMING_INTERVAL:
                    avg_time_per_node = sum(llm_call_times[-TIMING_INTERVAL:]) / TIMING_INTERVAL
                else:
                    avg_time_per_node = sum(llm_call_times) / len(llm_call_times)

                # ETA assumes every remaining global id will be generated successfully.
                nodes_remaining = TOTAL_NODES - current_global_id
                time_remaining_minutes = (nodes_remaining * avg_time_per_node) / 60

                print(f"    Generated {G.number_of_nodes()}/{TOTAL_NODES} total nodes. (Current Story: {story_id})")
                print(f"    Average time per node (last {TIMING_INTERVAL}): {avg_time_per_node:.2f}s")
                print(f"    Estimated Time Remaining: {time_remaining_minutes:.1f} minutes")

            # --- CRITICAL CHECKPOINT SAVE ---
            # Save if we hit a checkpoint, or if we have finished all nodes
            if current_global_id > 0 and (current_global_id % CHECKPOINT_INTERVAL == 0 or current_global_id == TOTAL_NODES):
                serialize_data(G, [])
                print(f"\n    CHECKPOINT: Graph saved successfully at Global Node {current_global_id-1}.")

    end_llm_time = time.time()
    print(f"\nLLM Generation Complete. Total Nodes Created: {G.number_of_nodes()}/{TOTAL_NODES}")
    print(f"Time for LLM Generation and Linking: {end_llm_time - start_llm_time:.2f} seconds")

    return G


# ====================================================================
# 8. MAIN EXECUTION
# ====================================================================

if __name__ == "__main__":

    try:
        # Step 1: mount Drive, fetch the API key and create the module-level `client`
        # that generate_contextual_node() reads through `global client`.
        GEMINI_API_KEY = setup_environment()
        client = genai.Client(api_key=GEMINI_API_KEY)
        print("Client successfully configured.")

        # Step 2: generate (or resume) every story; checkpoints are written along the way.
        start_time = time.time()
        final_memory_graph = generate_and_build_full_graph()

        produced_nodes = final_memory_graph.number_of_nodes()
        if produced_nodes < TOTAL_NODES:
            print(f"Script completed but only generated {produced_nodes}/{TOTAL_NODES} total nodes.")

        # Step 3: write the final graph together with the placeholder query pool.
        query_pool = generate_complex_queries(TOTAL_NODES)
        serialize_data(final_memory_graph, query_pool)

        end_time = time.time()
        print("PIPELINE COMPLETE: DIVERSE 5x100 Data Saved.")
        print(f"Total time for Pipeline: {end_time - start_time:.2f} seconds")

    except Exception as e:
        # Top-level boundary: report the failure (type and message) instead of a traceback.
        print(f"\nSCRIPT ERROR: {type(e).__name__} - {e}")

Colab environment detected.
Mounting Google Drive...
Mounted at /content/gdrive
Drive mounted. Output directory: /content/gdrive/MyDrive/MemoryGraph_Data
Gemini API Key loaded from Colab Secrets.
Gemini client successfully configured.

--- Starting LLM Generation of 500 Nodes (5 DIVERSE STORIES) ---

### BEGIN STORY A: Alex (Starting at Local ID 0, Global ID 0) ###
 -> Inserting initial narrative event (Node 0) for Alex...
 -> Generated 50/500 total nodes. (Current Story: A)
    Average time per node (last 50): 4.52s
    **Estimated Time Remaining: 33.9 minutes**
 -> Generated 100/500 total nodes. (Current Story: A)
    Average time per node (last 50): 4.73s
    **Estimated Time Remaining: 31.5 minutes**

*** CHECKPOINT: Graph saved successfully at Global Node 99. ***

### BEGIN STORY B: Maya (Starting at Local ID 0, Global ID 100) ###
 -> Inserting initial narrative event (Node 0) for Maya...
 -> Generated 150/500 total nodes. (Current Story: B)
    Average time per node (last 50): 4.

In [None]:
# ====================================================================
# PHASE 2: FEATURE ENGINEERING AND DATASET CREATION
# FEATURES: Time_Closeness, Semantic_Similarity, Emotional_Alignment
# ====================================================================

import os
import json
import time
import random
import networkx as nx
import numpy as np
import pandas as pd
from datetime import datetime
from typing import List, Dict, Tuple, Any, Union

from google import genai
from google.genai import types
from google.colab import drive, userdata
# --------------------------

# --- CONFIGURATION PARAMETERS ---
# Embedding model, and the vector length a stored node vector must have to be reused.
EMBEDDING_MODEL = "text-embedding-004"
VECTOR_DIMENSION = 768
EMBEDDING_CHUNK_SIZE = 100  # texts sent per embed_content request

# Negative sampling: Linked=0 samples per Linked=1 sample, and random pairs tried per draw.
NEGATIVE_SAMPLE_RATIO = 2.0
NEGATIVE_SAMPLE_ATTEMPTS = 5

# Directory and filenames
DRIVE_OUTPUT_DIR = "/content/gdrive/MyDrive/MemoryGraph_Data"
RAW_GRAPH_FILENAME_BASE = "01_DIVERSE_memory_graph_5x100_FINAL.json"
EMBEDDED_GRAPH_FILENAME = "02_EMBEDDED_memory_graph_5x100_FINAL.json"
FINAL_DATASET_FILENAME = "03_CORE_training_dataset_FINAL.csv"

# --- UTILITY SETUP ---
def setup_environment():
    """Mount Google Drive, create DRIVE_OUTPUT_DIR, and return the Gemini API key.

    The key comes from the 'GEMINI_API_KEY' Colab secret. A missing or empty key is
    reported as ValueError; any failure while loading the key is printed and re-raised.
    """
    print("Mounting Google Drive...")
    drive.mount('/content/gdrive', force_remount=True)
    os.makedirs(DRIVE_OUTPUT_DIR, exist_ok=True)
    try:
        secret = userdata.get('GEMINI_API_KEY')
        if secret:
            print("Gemini API Key loaded from Colab Secrets.")
            return secret
        raise ValueError("API Key not found or is empty.")
    except Exception as err:
        print(f"FATAL ERROR: Failed to load GEMINI_API_KEY. Details: {err}")
        raise

def load_graph(filename: str) -> nx.DiGraph:
    """Load a node-link JSON graph from DRIVE_OUTPUT_DIR and rebuild it with NetworkX.

    Raises FileNotFoundError when the file does not exist. The printed node count is
    taken from the 'nodes' list, falling back to a 'node-data' list when 'nodes' is
    absent.
    """
    source_path = os.path.join(DRIVE_OUTPUT_DIR, filename)
    if not os.path.exists(source_path):
        raise FileNotFoundError(f"Input graph file not found at: {source_path}")

    with open(source_path, 'r') as handle:
        payload = json.load(handle)

    try:
        total = len(payload['nodes'])
    except KeyError:
        total = len(payload.get('node-data', []))

    print(f"Loaded graph from {filename}. Total nodes: {total}.")
    return nx.node_link_graph(payload)

def save_graph(graph: nx.DiGraph, filename: str):
    """Write `graph` as indented node-link JSON to DRIVE_OUTPUT_DIR/filename and report the path."""
    target_path = os.path.join(DRIVE_OUTPUT_DIR, filename)
    payload = nx.node_link_data(graph)
    with open(target_path, 'w') as handle:
        json.dump(payload, handle, indent=4)
    print(f"\n✅ Graph saved successfully to: {target_path}")

def save_dataframe(df: pd.DataFrame, filename: str):
    """Write `df` as CSV (without the index column) to DRIVE_OUTPUT_DIR/filename.

    The destination path is printed once the file has been written.
    """
    csv_path = os.path.join(DRIVE_OUTPUT_DIR, filename)
    df.to_csv(csv_path, index=False)
    print(f"\n✅ FINAL DATASET saved successfully to: {csv_path}")

# --- FEATURE CALCULATION ---

def _to_utc_datetime(timestamp: str) -> datetime:
    """Parse an ISO 8601 timestamp into an aware datetime.

    A trailing 'Z' is rewritten to '+00:00' because datetime.fromisoformat only accepts
    'Z' from Python 3.11 on. A timestamp without an offset is assumed to be UTC, so that
    naive and aware values can be subtracted without a TypeError.
    """
    from datetime import timezone

    parsed = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


def calculate_temporal_closeness(timestamp_source: str, timestamp_target: str, time_scale_seconds: float = 1.0) -> float:
    """Calculates Time_Closeness.

    Returns 1 / (|t_target - t_source| / time_scale_seconds + 1): 1.0 for identical
    timestamps, decaying toward 0 as they move apart. The default scale of one second
    reproduces the original 1 / (seconds + 1) feature. A larger scale (e.g. 86400 for
    days) keeps events hours or months apart from all collapsing to ~0.

    Raises:
        ValueError: for an unparsable timestamp or a non-positive time_scale_seconds.
    """
    if time_scale_seconds <= 0:
        raise ValueError(f"time_scale_seconds must be positive, got {time_scale_seconds}.")
    dt_source = _to_utc_datetime(timestamp_source)
    dt_target = _to_utc_datetime(timestamp_target)
    time_diff_seconds = abs((dt_target - dt_source).total_seconds())
    return 1.0 / (time_diff_seconds / time_scale_seconds + 1.0)

def cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
    """Return the cosine of the angle between two vectors.

    If either vector has zero length (including an empty vector), 0.0 is returned
    instead of dividing by zero.
    """
    left = np.array(vec_a)
    right = np.array(vec_b)
    left_norm = np.linalg.norm(left)
    right_norm = np.linalg.norm(right)
    if left_norm == 0 or right_norm == 0:
        return 0.0
    return np.dot(left, right) / (left_norm * right_norm)

def calculate_feature_vector(node_a_data: Dict[str, Any], node_b_data: Dict[str, Any]) -> Dict[str, float]:
    """Build the three CORE pairwise features for two node attribute dicts.

    Reads 'timestamp', 'semantic_vec' and 'emotional_vec' from both nodes and returns
    {'Time_Closeness', 'Semantic_Similarity', 'Emotional_Alignment'} in that key order.
    """
    features: Dict[str, float] = {}
    features['Time_Closeness'] = calculate_temporal_closeness(node_a_data['timestamp'], node_b_data['timestamp'])
    features['Semantic_Similarity'] = cosine_similarity(node_a_data['semantic_vec'], node_b_data['semantic_vec'])
    features['Emotional_Alignment'] = cosine_similarity(node_a_data['emotional_vec'], node_b_data['emotional_vec'])
    return features

# --- BATCHING HELPER ---
def chunk_list(lst: List[Any], n: int):
    """Yield consecutive slices of `lst` holding `n` items each; the last may be shorter."""
    yield from (lst[start:start + n] for start in range(0, len(lst), n))

# ====================================================================
# STEP 1: EMBEDDING GENERATION
# ====================================================================

def generate_embeddings(G: nx.DiGraph) -> nx.DiGraph:
    """
    Generates semantic and emotional embeddings using micro-batching.

    Each node gets two vectors from EMBEDDING_MODEL:
      * 'semantic_vec' embeds "Event: <event_text>. Tags: <semantic_tags>";
      * 'emotional_vec' embeds the 'emotional_state' string.
    Each vector is requested independently, and only when it is absent, empty, or not
    VECTOR_DIMENSION long. A re-run therefore completes a partially enriched graph.

    Args:
        G: Graph whose nodes carry 'event_text', 'semantic_tags' and 'emotional_state'.
           It is updated in place.

    Returns:
        The same graph object, with the vector attributes filled in.

    Raises:
        RuntimeError: if an embedding request fails, or returns a different number of
            vectors than texts sent. Vectors from batches that already succeeded stay
            on the graph.
    """
    global client

    def _is_missing(vector: Any) -> bool:
        # Absent/empty vectors and vectors of the wrong length are (re)embedded.
        return not vector or len(vector) != VECTOR_DIMENSION

    batch_input_texts: List[str] = []
    vector_map: List[Dict[str, Any]] = []  # parallel to batch_input_texts: where each result goes

    for node_id, data in G.nodes(data=True):
        if _is_missing(data.get('semantic_vec')):
            # 1. Semantic Embedding Prompt
            batch_input_texts.append(f"Event: {data['event_text']}. Tags: {', '.join(data['semantic_tags'])}")
            vector_map.append({'node_id': node_id, 'field': 'semantic_vec'})

        if _is_missing(data.get('emotional_vec')):
            # 2. Emotional Embedding Prompt
            batch_input_texts.append(data['emotional_state'])
            vector_map.append({'node_id': node_id, 'field': 'emotional_vec'})

    if not batch_input_texts:
        print("Node enrichment skipped: All nodes already contain embeddings.")
        return G

    total_inputs = len(batch_input_texts)
    node_count = len({entry['node_id'] for entry in vector_map})
    num_batches = int(np.ceil(total_inputs / EMBEDDING_CHUNK_SIZE))
    print(f"\n--- Starting MICRO-BATCH Embedding Generation for {total_inputs} total vectors ({node_count} nodes) ---")
    start_time = time.time()

    stored = 0  # how many vector_map entries have already been written to the graph
    for i, chunk in enumerate(chunk_list(batch_input_texts, EMBEDDING_CHUNK_SIZE)):
        print(f"Processing micro-batch {i+1} of {num_batches} (Size: {len(chunk)})...")
        try:
            batch_result = client.models.embed_content(
                model=EMBEDDING_MODEL,
                contents=chunk,
                # Task type for similarity comparison
                config=types.EmbedContentConfig(task_type="SEMANTIC_SIMILARITY")
            )
        except Exception as e:
            # Stop the pipeline instead of continuing with missing vectors; earlier
            # batches are already stored on G.
            raise RuntimeError(f"FATAL MICRO-BATCH ERROR in chunk {i+1}: {e}. Stopping script.") from e

        vectors = [list(embedding.values) for embedding in (batch_result.embeddings or [])]
        if len(vectors) != len(chunk):
            raise RuntimeError(f"Micro-batch {i+1} returned {len(vectors)} embeddings; expected {len(chunk)}.")

        # Write this batch's vectors (plain lists, JSON-serializable) into their nodes.
        for entry, vector in zip(vector_map[stored:stored + len(vectors)], vectors):
            G.nodes[entry['node_id']][entry['field']] = vector
        stored += len(vectors)

    elapsed = time.time() - start_time
    print(f"\n✅ All {node_count} nodes enriched across {num_batches} batches in {elapsed:.2f} seconds.")
    print(f"   (Average time per node: {elapsed / node_count:.2f}s)")

    return G


# ====================================================================
# STEP 2: SAMPLE GENERATION (POSITIVE & NEGATIVE)
# ====================================================================

def create_training_dataset(G: nx.DiGraph, ratio: float) -> pd.DataFrame:
    """
    Creates the final training dataset (Linked=1 and Linked=0 samples).

    Positive samples (Linked=1) are every edge whose endpoints share a story_id.

    Negative samples (Linked=0) are unlinked pairs of nodes from the same story,
    oriented so the source has the smaller global_id. For each negative, a story is
    picked uniformly among those with unused unlinked pairs left, and a pair is drawn
    at random (up to NEGATIVE_SAMPLE_ATTEMPTS tries per pick). Every unordered pair is
    used at most once, so there are no duplicate negative rows.

    Args:
        G: Embedded memory graph; nodes need 'story_id', 'global_id', 'timestamp',
           'semantic_vec' and 'emotional_vec'.
        ratio: Requested number of negatives per positive. If fewer unlinked same-story
           pairs exist, all of them are used and a warning is printed.

    Returns:
        Shuffled DataFrame of feature columns plus 'Linked', 'story_id', 'Source_ID'
        and 'Target_ID'.
    """

    all_samples = []

    # Group node ids by story; samples never pair nodes from different stories.
    story_nodes: Dict[str, List[int]] = {}
    for node_id, data in G.nodes(data=True):
        story_id = data.get('story_id')
        if story_id:
            story_nodes.setdefault(story_id, []).append(node_id)

    print(f"\n--- Starting Sample Generation (Ratio Linked=0/Linked=1: {ratio:.1f}) ---")

    # ----------------------------------------
    # PART 1: POSITIVE SAMPLES (Linked=1)
    # ----------------------------------------
    positive_count = 0

    for source_id, target_id in G.edges():
        source_data = G.nodes[source_id]
        target_data = G.nodes[target_id]

        if source_data['story_id'] != target_data['story_id']:
            continue

        features = calculate_feature_vector(source_data, target_data)
        features.update({'Linked': 1, 'story_id': source_data['story_id'], 'Source_ID': source_id, 'Target_ID': target_id})
        all_samples.append(features)
        positive_count += 1

    print(f" Generated {positive_count} Positive Samples (Linked=1).")

    # ----------------------------------------
    # PART 2: NEGATIVE SAMPLES (Linked=0)
    # ----------------------------------------

    def _oriented(node_a: int, node_b: int) -> Tuple[int, int]:
        # Put the earlier node (smaller global_id) first, matching the edge direction
        # used by the generator.
        if G.nodes[node_a]['global_id'] >= G.nodes[node_b]['global_id']:
            return node_b, node_a
        return node_a, node_b

    # Unlinked pairs each story can supply: all unordered pairs, minus those whose
    # earlier->later orientation is already an edge.
    remaining_pairs: Dict[str, int] = {
        sid: len(nodes) * (len(nodes) - 1) // 2 for sid, nodes in story_nodes.items()
    }
    for u, v in G.edges():
        sid = G.nodes[u].get('story_id')
        if u != v and sid in remaining_pairs and sid == G.nodes[v].get('story_id') and _oriented(u, v) == (u, v):
            remaining_pairs[sid] -= 1

    target_negative_count = int(positive_count * ratio)
    available_negatives = sum(remaining_pairs.values())
    if target_negative_count > available_negatives:
        print(f" WARNING: only {available_negatives} unlinked same-story pairs exist; capping negatives (requested {target_negative_count}).")
        target_negative_count = available_negatives

    negative_count = 0
    used_pairs = set()

    print(f" Target Negative Samples (Linked=0): {target_negative_count}")

    while negative_count < target_negative_count:
        # Only stories with an unused unlinked pair remain eligible, so this loop
        # always terminates (target never exceeds the number of such pairs).
        story_id = random.choice([sid for sid, left in remaining_pairs.items() if left > 0])
        current_story_nodes = story_nodes[story_id]

        for _ in range(NEGATIVE_SAMPLE_ATTEMPTS):
            source_id, target_id = _oriented(*random.sample(current_story_nodes, 2))

            if G.has_edge(source_id, target_id) or (source_id, target_id) in used_pairs:
                continue

            used_pairs.add((source_id, target_id))
            remaining_pairs[story_id] -= 1

            features = calculate_feature_vector(G.nodes[source_id], G.nodes[target_id])
            features.update({'Linked': 0, 'story_id': story_id, 'Source_ID': source_id, 'Target_ID': target_id})
            all_samples.append(features)
            negative_count += 1

            if negative_count % 500 == 0:
                print(f"   -> Generated {negative_count}/{target_negative_count} negative samples...")

            break

    print(f" Final Negative Samples Generated: {negative_count}")
    print(f"TOTAL Samples Generated: {positive_count + negative_count}")

    # ----------------------------------------
    # PART 3: CREATE DATAFRAME
    # ----------------------------------------

    df = pd.DataFrame(all_samples)

    # Shuffle the dataset before saving
    df = df.sample(frac=1).reset_index(drop=True)

    return df


if __name__ == "__main__":

    try:
        # 1. Setup Environment (creates the module-level `client` used by generate_embeddings)
        GEMINI_API_KEY = setup_environment()
        client = genai.Client(api_key=GEMINI_API_KEY)

        # 2. Load Raw Graph
        G_raw = load_graph(RAW_GRAPH_FILENAME_BASE)

        # 3. STEP 1: Node Enrichment (Batch Embedding) - Uses Micro-Batching
        G_embedded = generate_embeddings(G_raw)
        save_graph(G_embedded, EMBEDDED_GRAPH_FILENAME)

        # 4. STEP 2: Feature Calculation and Sample Generation
        final_dataframe = create_training_dataset(G_embedded, NEGATIVE_SAMPLE_RATIO)

        # 5. Save Final Dataset
        save_dataframe(final_dataframe, FINAL_DATASET_FILENAME)

        print("\nDATASET CREATION PIPELINE COMPLETE. Proceed to CORE Model Training.")

    except Exception as e:
        # Top-level boundary: report the failure on a fresh line instead of a traceback.
        print(f"\nSCRIPT ERROR in Feature Engineering: {type(e).__name__} - {e}")

Mounting Google Drive...
Mounted at /content/gdrive
Gemini API Key loaded from Colab Secrets.
Loaded graph from 01_DIVERSE_memory_graph_5x100_FINAL.json. Total nodes: 500.

--- Starting MICRO-BATCH Embedding Generation for 1000 total vectors (500 nodes) ---
   -> Processing micro-batch 1 of 10 (Size: 100)...
   -> Processing micro-batch 2 of 10 (Size: 100)...
   -> Processing micro-batch 3 of 10 (Size: 100)...
   -> Processing micro-batch 4 of 10 (Size: 100)...
   -> Processing micro-batch 5 of 10 (Size: 100)...
   -> Processing micro-batch 6 of 10 (Size: 100)...
   -> Processing micro-batch 7 of 10 (Size: 100)...
   -> Processing micro-batch 8 of 10 (Size: 100)...
   -> Processing micro-batch 9 of 10 (Size: 100)...
   -> Processing micro-batch 10 of 10 (Size: 100)...

✅ All 500 nodes enriched across 10 batches in 8.61 seconds.
   (Average time per node: 0.02s)

✅ Graph saved successfully to: /content/gdrive/MyDrive/MemoryGraph_Data/02_EMBEDDED_memory_graph_5x100_FINAL.json

--- Starti