In [1]:
import json
import math
import os
import random
import re
import time
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import numpy as np

In [2]:
# Optional third-party backends. Each import is guarded so the notebook still
# runs in "hybrid" mode when a backend package is not installed.
try:
    import requests
    HUGGINGFACE_AVAILABLE = True
except ImportError:
    HUGGINGFACE_AVAILABLE = False

# ActualLLMWorldModel reads both `OPENAI_AVAILABLE` and `OpenAI`; without these
# definitions, selecting model_type="openai" raises NameError.
try:
    from openai import OpenAI
    OPENAI_AVAILABLE = True
except ImportError:
    OpenAI = None
    OPENAI_AVAILABLE = False

In [3]:
@dataclass
class StoryState:
    """Represents a state in story planning.

    Fields:
        characters: maps a character name to that character's attribute dict
            (the rest of this notebook stores keys such as "role", "status"
            and a nested "relationships" dict in it).
        location: name of the current story location.
        plot_points: identifiers of the plot points completed so far.
        tension_level: current narrative tension (the planner keeps it in 0..1).
    """
    characters: Dict[str, Dict[str, Any]]
    location: str
    plot_points: List[str]
    tension_level: float

    def to_string(self) -> str:
        """Render the state as a multi-line text block suitable for an LLM prompt."""
        lines = [f"Location: {self.location}", "Characters:"]
        for name, attrs in self.characters.items():
            rendered_attrs = ", ".join(f"{key}={value}" for key, value in attrs.items())
            lines.append(f"  {name}: {rendered_attrs}")
        completed = ', '.join(self.plot_points) if self.plot_points else 'None'
        lines.append(f"Plot points completed: {completed}")
        lines.append(f"Tension level: {self.tension_level:.1f}")
        return "\n".join(lines)

    def copy(self) -> "StoryState":
        """Return a fully independent (deep) copy of this state.

        Character attribute dicts can hold nested containers (for example the
        "relationships" dict). A one-level ``dict.copy()`` would share those
        between the original and the copy, so mutating the copy would silently
        alter the original state as well; ``deepcopy`` prevents that.
        """
        return StoryState(
            characters=deepcopy(self.characters),
            location=self.location,
            plot_points=list(self.plot_points),
            tension_level=self.tension_level,
        )

@dataclass
class StoryAction:
    """One candidate story step performed by a single character."""
    action_type: str  # one of "dialogue", "move", "conflict", "reveal"
    actor: str
    target: Optional[str] = None
    content: str = ""

    def to_string(self) -> str:
        """Return a one-line description; the target is mentioned only when set."""
        headline = f"{self.actor} {self.action_type}"
        if self.target:
            headline = f"{headline} with {self.target}"
        return f"{headline}: {self.content}"

@dataclass
class StoryGoal:
    """Describes the story state the planner is trying to reach."""
    # Plot-point identifiers that must all be completed.
    required_plot_points: List[str]
    # Character name -> relationship label that character must hold.
    target_relationships: Dict[str, str]
    # Location the story must end in; None means any location is acceptable.
    target_location: Optional[str] = None
    # Flag carried with the goal; not consulted by the code in this notebook section.
    min_tension_arc: bool = True

In [4]:
class ActualLLMWorldModel:
    """World model for story planning backed by an LLM or by heuristics.

    Three backends are supported, selected by ``model_type``:

    * ``"openai"``      - chat completions through ``self.client``.
    * ``"huggingface"`` - HTTP calls to the HuggingFace Inference API.
    * ``"hybrid"``      - rule-based action generation and scoring, no network.

    Every LLM-backed method falls back to the hybrid implementation when the
    remote call fails, so planning can always proceed.
    """

    # Action types accepted from LLM output and handled by predict_next_state.
    VALID_ACTION_TYPES = ("dialogue", "move", "conflict", "reveal")
    # Seconds to wait for the HuggingFace endpoint before falling back.
    REQUEST_TIMEOUT = 30

    def __init__(self, model_type="openai", api_key=None, hf_token=None):
        """
        Initialize LLM model
        model_type: "openai", "huggingface", or "hybrid"
        api_key: OpenAI key; when omitted, the OPENAI_API_KEY env var is used
        hf_token: HuggingFace token; when omitted, the HF_TOKEN env var is used

        NOTE: the "openai" mode reads the module-level names
        ``OPENAI_AVAILABLE`` and ``OpenAI``; they must be defined before this
        mode is selected.
        """
        self.model_type = model_type
        self.locations = ["village", "castle", "forest", "home", "mountain", "river"]

        # Backend attributes exist in every mode so that later attribute
        # access can never raise AttributeError.
        self.client = None
        self.model_name = None
        self.hf_token = None
        self.hf_model = None
        self.api_url = None

        if model_type == "openai" and OPENAI_AVAILABLE:
            # Use environment variable or provided key
            api_key = api_key or os.getenv("OPENAI_API_KEY")
            if api_key:
                self.client = OpenAI(api_key=api_key)
                self.model_name = "gpt-3.5-turbo"
            else:
                print("Warning: No OpenAI API key provided. Falling back to hybrid mode.")
                self.model_type = "hybrid"

        elif model_type == "huggingface":
            self.hf_token = hf_token or os.getenv("HF_TOKEN")
            self.hf_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
            self.api_url = f"https://api-inference.huggingface.co/models/{self.hf_model}"

        else:
            # Hybrid mode - use heuristics with simulated LLM behavior
            self.model_type = "hybrid"
            print("Using hybrid mode (heuristics + simulated LLM)")

    # ------------------------------------------------------------------
    # Action generation
    # ------------------------------------------------------------------

    def get_valid_actions(self, state: StoryState, goal: StoryGoal) -> List[StoryAction]:
        """Return candidate actions for ``state``, using the configured backend."""
        if self.model_type == "openai" and OPENAI_AVAILABLE:
            return self._get_actions_openai(state, goal)
        elif self.model_type == "huggingface":
            return self._get_actions_huggingface(state, goal)
        else:
            return self._get_actions_hybrid(state, goal)

    def _get_actions_openai(self, state: StoryState, goal: StoryGoal) -> List[StoryAction]:
        """Ask the OpenAI chat model for actions; fall back to heuristics on any error."""
        prompt = f"""You are a story planner. Given the current state and goal, suggest 5-10 possible story actions.

Current State:
{state.to_string()}

Goal:
- Required plot points: {goal.required_plot_points}
- Target relationships: {goal.target_relationships}
- Target location: {goal.target_location}

Generate actions in this exact format (one per line):
[actor]|[action_type]|[target]|[content]

Where action_type is one of: dialogue, move, conflict, reveal
Example: Hero|dialogue|Mentor|asks for guidance

Focus on actions that progress toward the goal."""

        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.7,
                max_tokens=300
            )

            # Pass the real state/goal so an unparseable reply falls back to
            # meaningful heuristic actions instead of an empty list.
            return self._parse_actions(response.choices[0].message.content, state, goal)

        except Exception as e:
            print(f"OpenAI API error: {e}")
            return self._get_actions_hybrid(state, goal)

    def _query_huggingface(self, prompt: str, max_new_tokens: int) -> Optional[str]:
        """POST ``prompt`` to the HuggingFace endpoint and return the generated text.

        Returns None when the endpoint answers with a non-200 status. Network
        and decoding errors propagate to the caller, which handles the fallback.
        """
        headers = {"Authorization": f"Bearer {self.hf_token}"} if self.hf_token else {}
        response = requests.post(
            self.api_url,
            headers=headers,
            json={
                "inputs": prompt,
                # Ask for the continuation only: by default the prompt is
                # echoed back, which pollutes downstream parsing.
                "parameters": {"max_new_tokens": max_new_tokens, "return_full_text": False},
            },
            # Without a timeout a stalled endpoint would block planning forever.
            timeout=self.REQUEST_TIMEOUT,
        )
        if response.status_code != 200:
            return None

        text = response.json()[0]["generated_text"]
        # Defensive: strip the prompt if the server echoed it anyway.
        if text.startswith(prompt):
            text = text[len(prompt):]
        return text

    def _get_actions_huggingface(self, state: StoryState, goal: StoryGoal) -> List[StoryAction]:
        """Get actions from HuggingFace Inference API; fall back to heuristics on failure."""
        prompt = f"""Generate 5 story actions for:
State: {state.to_string()}
Goal: {goal.required_plot_points}
Format: actor|action_type|target|content"""

        try:
            text = self._query_huggingface(prompt, max_new_tokens=200)
            if text is None:
                return self._get_actions_hybrid(state, goal)
            return self._parse_actions(text, state, goal)

        except Exception as e:
            print(f"HuggingFace API error: {e}")
            return self._get_actions_hybrid(state, goal)

    def _get_actions_hybrid(self, state: StoryState, goal: StoryGoal) -> List[StoryAction]:
        """Hybrid approach: generate up to 10 distinct actions using heuristics.

        Actions are produced in priority order (reach the goal location, complete
        missing plot points, build target relationships, then filler dialogue).
        Duplicates are skipped: two actions with the same fields render to the
        same string, and the search tree keys children by that string.
        """
        actions: List[StoryAction] = []
        seen = set()
        chars = list(state.characters.keys())

        def add(action_type, actor, content, target=None):
            key = (action_type, actor, target, content)
            if key in seen:
                return
            seen.add(key)
            actions.append(StoryAction(
                action_type=action_type,
                actor=actor,
                target=target,
                content=content
            ))

        # Priority 1: Goal-achieving actions
        if goal.target_location and state.location != goal.target_location:
            for actor in chars[:2]:  # Limit for efficiency
                add("move", actor, f"travels to {goal.target_location}")

        # Priority 2: Plot point actions
        for required in goal.required_plot_points:
            if required in state.plot_points:
                continue
            if "revelation" in required:
                for actor in chars[:2]:
                    add("reveal", actor, "makes important discovery")
            elif "conflict" in required:
                # Expected shape: "conflict_<actor>_<target>"
                parts = required.split("_")
                if len(parts) >= 3:
                    actor, target = parts[1], parts[2]
                    if actor in chars and target in chars:
                        add("conflict", actor, "confronts", target)

        # Priority 3: Relationship actions
        for char, target_rel in goal.target_relationships.items():
            if char not in state.characters:
                continue
            for other in chars:
                if other == char:
                    continue
                if target_rel == "friendly":
                    add("dialogue", char, "asks for help", other)
                elif target_rel == "tense":
                    add("dialogue", char, "confronts", other)

        # Add some variety
        if len(actions) < 5:
            for actor in chars[:2]:
                for target in chars[:2]:
                    if target != actor and len(actions) < 10:
                        add("dialogue", actor, "reveals secret", target)

        return actions[:10]

    def _parse_actions(self, text: str, state: Optional[StoryState] = None,
                       goal: Optional[StoryGoal] = None) -> List[StoryAction]:
        """Parse LLM response lines of the form ``actor|action_type|target|content``.

        Tolerates list bullets/numbering, square brackets around fields and
        differently-cased action types. Lines with an unknown action type or an
        empty actor are ignored.

        state, goal: when both are given and nothing could be parsed, the
        heuristic generator is used on them as a fallback; otherwise an empty
        list is returned.
        """
        actions = []

        for raw_line in text.strip().split('\n'):
            # Drop leading "- ", "* ", "1. ", "2) " style list markers.
            line = re.sub(r'^\s*(?:[-*\u2022]|\d+[.)])\s+', '', raw_line.strip())
            if '|' not in line:
                continue

            # The prompt shows fields in [brackets]; models often copy them.
            parts = [part.strip().strip('[]').strip() for part in line.split('|')]
            if len(parts) < 3:
                continue

            actor = parts[0]
            action_type = parts[1].lower()
            target = parts[2]
            content = parts[3] if len(parts) > 3 else ""

            if not actor or action_type not in self.VALID_ACTION_TYPES:
                continue

            # Treat placeholder targets as "no target".
            if target.lower() in ("", "none", "null"):
                target = None

            actions.append(StoryAction(
                actor=actor,
                action_type=action_type,
                target=target,
                content=content
            ))

        # If parsing fails, return hybrid fallback
        if not actions:
            print("Failed to parse LLM output, using fallback")
            if state is not None and goal is not None:
                return self._get_actions_hybrid(state, goal)
            return []

        return actions

    # ------------------------------------------------------------------
    # Action likelihood (r1)
    # ------------------------------------------------------------------

    def calculate_action_likelihood(self, state: StoryState, action: StoryAction) -> float:
        """Calculate r1: likelihood of ``action`` in ``state``, in [0, 1]."""
        if self.model_type == "openai" and OPENAI_AVAILABLE:
            return self._calculate_likelihood_openai(state, action)
        elif self.model_type == "huggingface":
            return self._calculate_likelihood_huggingface(state, action)
        else:
            return self._calculate_likelihood_hybrid(state, action)

    @staticmethod
    def _extract_score(text: str) -> Optional[float]:
        """Return the first number found in ``text`` clamped to [0, 1], or None."""
        # The first alternative also accepts decimals written without a
        # leading zero (".8"), which would otherwise be read as 8.
        match = re.search(r'\d*\.\d+|\d+', text)
        if match is None:
            return None
        return min(max(float(match.group()), 0.0), 1.0)

    def _calculate_likelihood_openai(self, state: StoryState, action: StoryAction) -> float:
        """Get action likelihood from OpenAI; 0.5 if the reply contains no number."""
        prompt = f"""Rate the narrative coherence of this story action on a scale of 0.0 to 1.0.

Current State:
{state.to_string()}

Proposed Action:
{action.to_string()}

Consider:
1. Character consistency (would this character do this?)
2. Plot progression (does this advance the story?)
3. Narrative logic (does this make sense?)

Respond with ONLY a number between 0.0 and 1.0."""

        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.3,
                max_tokens=10
            )

            score = self._extract_score(response.choices[0].message.content.strip())
            return 0.5 if score is None else score

        except Exception as e:
            print(f"Likelihood calculation error: {e}")
            return self._calculate_likelihood_hybrid(state, action)

    def _calculate_likelihood_huggingface(self, state: StoryState, action: StoryAction) -> float:
        """Calculate likelihood using HuggingFace; heuristic score on any failure."""
        prompt = f"Rate 0-1: {action.to_string()} given {state.to_string()[:100]}"

        try:
            # Only the continuation is scored: the prompt itself starts with
            # "Rate 0-1", so an echoed prompt would always parse as 0.
            text = self._query_huggingface(prompt, max_new_tokens=10)
            if text is not None:
                score = self._extract_score(text)
                if score is not None:
                    return score

            return self._calculate_likelihood_hybrid(state, action)

        except Exception as e:
            print(f"HuggingFace API error: {e}")
            return self._calculate_likelihood_hybrid(state, action)

    def _calculate_likelihood_hybrid(self, state: StoryState, action: StoryAction) -> float:
        """Heuristic-based likelihood calculation (includes a small random jitter)."""
        score = 0.5  # Base score

        # Character role consistency
        if action.actor in state.characters:
            char_attrs = state.characters[action.actor]
            if "role" in char_attrs:
                role = char_attrs["role"]

                if role == "hero":
                    if "help" in action.content or action.action_type == "reveal":
                        score += 0.2
                elif role == "villain" and action.action_type == "conflict":
                    score += 0.3
                elif role == "investigator" and action.action_type == "reveal":
                    score += 0.3

        # Story tension appropriateness
        if state.tension_level < 0.5 and action.action_type in ["conflict", "reveal"]:
            score += 0.1

        # Add some randomness for exploration
        score += random.uniform(-0.1, 0.1)

        return min(max(score, 0.0), 1.0)

    # ------------------------------------------------------------------
    # State transition
    # ------------------------------------------------------------------

    def predict_next_state(self, state: StoryState, action: StoryAction) -> StoryState:
        """Return the state that results from applying ``action`` to ``state``.

        Transitions are deterministic rules (no LLM call); ``state`` itself is
        not modified, a copy is updated and returned.
        """
        new_state = state.copy()

        if action.action_type == "dialogue":
            if action.actor in new_state.characters:
                relationships = new_state.characters[action.actor].setdefault("relationships", {})

                if "help" in action.content:
                    # A relationship is recorded only when there is someone to
                    # relate to; a None key would wrongly count toward
                    # relationship goals.
                    if action.target:
                        relationships[action.target] = "friendly"
                    new_state.tension_level = max(0, new_state.tension_level - 0.1)
                elif "confront" in action.content or "secret" in action.content:
                    if action.target:
                        relationships[action.target] = "tense"
                    new_state.tension_level = min(1.0, new_state.tension_level + 0.2)

        elif action.action_type == "move":
            # The destination is the first known location named in the content.
            for loc in self.locations:
                if loc in action.content:
                    new_state.location = loc
                    break

        elif action.action_type == "conflict":
            new_state.tension_level = min(1.0, new_state.tension_level + 0.3)
            conflict_point = f"conflict_{action.actor}_{action.target}"
            if conflict_point not in new_state.plot_points:
                new_state.plot_points.append(conflict_point)

        elif action.action_type == "reveal":
            # Revelations are numbered in the order they occur: revelation_0, _1, ...
            revelation_num = len([p for p in new_state.plot_points if "revelation" in p])
            new_state.plot_points.append(f"revelation_{revelation_num}")
            new_state.tension_level = min(1.0, new_state.tension_level + 0.1)

        return new_state


In [5]:
class RAPNode:
    """A single node of the RAP-MCTS search tree.

    Holds the story state reached at this point, the action that led here,
    and the visit/value statistics accumulated during search.
    """

    def __init__(self, state: StoryState, parent: Optional['RAPNode'] = None,
                 action: Optional[StoryAction] = None, depth: int = 0):
        # Position in the tree
        self.parent = parent
        self.depth = depth
        self.children: Dict[str, 'RAPNode'] = {}  # keyed by action.to_string()

        # What this node represents
        self.state = state
        self.action = action
        self.is_terminal = False
        self.untried_actions: List[StoryAction] = []

        # Search statistics
        self.visits = 0
        self.q_value = 0.0  # sum of rewards; divide by visits for the mean

    def ucb1(self, c: float = 1.4) -> float:
        """UCB1 score of this node relative to its parent; inf when unvisited."""
        if not self.visits:
            return float('inf')
        mean_reward = self.q_value / self.visits
        exploration_bonus = c * math.sqrt(math.log(self.parent.visits) / self.visits)
        return mean_reward + exploration_bonus

    def best_child(self, c: float = 1.4) -> 'RAPNode':
        """Return the child with the highest UCB1 score (first one wins ties)."""
        def score(candidate: 'RAPNode') -> float:
            return candidate.ucb1(c)

        return max(self.children.values(), key=score)

    def add_child(self, action: StoryAction, state: StoryState) -> 'RAPNode':
        """Create a child one level deeper, register it under the action's string, return it."""
        key = action.to_string()
        node = RAPNode(state, parent=self, action=action, depth=self.depth + 1)
        self.children[key] = node
        return node

class RAP_MCTS:
    """RAP framework with MCTS for story planning.

    Each iteration selects a node, expands one untried action, scores the
    resulting node with a combined reward (r1 = action likelihood from the
    world model, r2 = goal heuristic) and backpropagates that reward.
    """

    def __init__(self,
                 world_model: ActualLLMWorldModel,
                 max_depth: int = 15,
                 iterations: int = 50,
                 exploration_constant: float = 1.4,
                 alpha: float = 0.5):
        """
        world_model: supplies candidate actions, action likelihoods and transitions
        max_depth: nodes at this depth are terminal; also caps the returned plan length
        iterations: number of select/expand/evaluate/backpropagate rounds
        exploration_constant: the ``c`` factor of UCB1
        alpha: weight of r1 in the reward ``r1**alpha * r2**(1 - alpha)``
        """
        self.world_model = world_model
        self.max_depth = max_depth
        self.iterations = iterations
        self.c = exploration_constant
        self.alpha = alpha
        self.root = None
        self.goal = None

    def search(self, initial_state: StoryState, goal: StoryGoal) -> List[StoryAction]:
        """Run the search from ``initial_state`` and return the best action sequence found."""
        self.root = RAPNode(initial_state)
        self.goal = goal

        self.root.untried_actions = self.world_model.get_valid_actions(
            initial_state, goal
        )

        for _ in range(self.iterations):
            node = self._selection()

            if not node.is_terminal and node.untried_actions:
                node = self._expansion(node)

            reward = self._evaluation(node)

            self._backpropagation(node, reward)

        return self._extract_best_path()

    def _selection(self) -> RAPNode:
        """Descend by UCB1 until reaching a node that is expandable, terminal or a leaf."""
        current = self.root
        while current.children and not current.is_terminal:
            if current.untried_actions:
                return current
            current = current.best_child(self.c)
        return current

    def _expansion(self, node: RAPNode) -> RAPNode:
        """Apply the next untried action of ``node`` and return the new child.

        Returns ``node`` itself when it has no untried actions left.
        """
        if not node.untried_actions:
            return node

        action = node.untried_actions.pop(0)
        next_state = self.world_model.predict_next_state(node.state, action)
        child = node.add_child(action, next_state)

        if self._is_goal_reached(next_state) or child.depth >= self.max_depth:
            child.is_terminal = True
        else:
            child.untried_actions = self.world_model.get_valid_actions(
                next_state, self.goal
            )

        return child

    def _evaluation(self, node: RAPNode) -> float:
        """Evaluate using dual reward system (r1 from the world model, r2 from heuristics)."""
        if node.action:
            # The action is scored in the state it was taken from.
            r1 = self.world_model.calculate_action_likelihood(
                node.parent.state if node.parent else node.state,
                node.action
            )
        else:
            # The root has no incoming action.
            r1 = 1.0

        r2 = self._calculate_heuristic_reward(node.state)

        # RAP's dual reward combination (weighted geometric mean)
        reward = (r1 ** self.alpha) * (r2 ** (1 - self.alpha))

        return reward

    def _calculate_heuristic_reward(self, state: StoryState) -> float:
        """Task-specific heuristic (r2 in RAP): partial credit per goal component,
        or a flat 100.0 when the whole goal is reached."""
        score = 0.0

        # Reward for completing plot points
        if self.goal.required_plot_points:
            required = set(self.goal.required_plot_points)
            completed = set(state.plot_points)
            score += len(required & completed) * 2.0

            if required.issubset(completed):
                score += 5.0

        # Reward for correct location
        if self.goal.target_location:
            if state.location == self.goal.target_location:
                score += 2.0

        # Reward for relationships: one point per character that holds the
        # target relationship with at least one other character.
        if self.goal.target_relationships:
            for char, target_rel in self.goal.target_relationships.items():
                if char in state.characters:
                    char_rels = state.characters[char].get("relationships", {})
                    if any(rel == target_rel for rel in char_rels.values()):
                        score += 1.0

        # Goal completion bonus
        if self._is_goal_reached(state):
            score = 100.0

        return score

    def _is_goal_reached(self, state: StoryState) -> bool:
        """True when plot points, location and relationships all match the goal.

        Goal components that are empty/None are not checked, and relationship
        targets for characters absent from ``state`` are skipped.
        """
        if self.goal.required_plot_points:
            required = set(self.goal.required_plot_points)
            completed = set(state.plot_points)
            if not required.issubset(completed):
                return False

        if self.goal.target_location:
            if state.location != self.goal.target_location:
                return False

        if self.goal.target_relationships:
            for char, target_rel in self.goal.target_relationships.items():
                if char in state.characters:
                    char_rels = state.characters[char].get("relationships", {})
                    has_relationship = any(
                        rel == target_rel for rel in char_rels.values()
                    )
                    if not has_relationship:
                        return False

        return True

    def _backpropagation(self, node: RAPNode, reward: float):
        """Add one visit and ``reward`` to ``node`` and every ancestor up to the root."""
        current = node
        while current is not None:
            current.visits += 1
            current.q_value += reward
            current = current.parent

    def _extract_best_path(self) -> List[StoryAction]:
        """Follow the best visited child from the root and return the actions taken.

        Children are ranked by mean reward plus a small bonus for visit count.
        The walk stops at a terminal node, when the goal is reached, when no
        visited child exists, or after ``max_depth`` actions.

        No cycle detection is needed: the structure is a tree, so the walk
        always terminates. (Comparing states by plot points and location only
        would mistake any step that changes just relationships or tension for
        a cycle and cut the plan short.)
        """
        path = []
        current = self.root

        while current.children and len(path) < self.max_depth:
            best_child = None
            best_score = -float('inf')

            for child in current.children.values():
                if child.visits > 0:
                    avg_value = child.q_value / child.visits
                    score = avg_value + (child.visits / 10.0)
                    if score > best_score:
                        best_score = score
                        best_child = child

            if best_child is None:
                break

            if best_child.action:
                path.append(best_child.action)
            current = best_child

            if current.is_terminal or self._is_goal_reached(current.state):
                break

        return path

In [6]:
def create_demo_scenarios():
    """Build the two demo planning problems.

    Returns a list of dicts, each with the keys "name", "initial_state"
    (a StoryState) and "goal" (a StoryGoal).
    """
    # Scenario 1: a naive hero must learn something, face the villain,
    # make friends and end up at the castle.
    hero_journey = {
        "name": "Hero's Journey",
        "initial_state": StoryState(
            characters={
                "Hero": {"role": "hero", "status": "naive"},
                "Mentor": {"role": "guide", "status": "wise"},
                "Villain": {"role": "villain", "status": "hidden"},
            },
            location="village",
            plot_points=[],
            tension_level=0.1,
        ),
        "goal": StoryGoal(
            required_plot_points=["revelation_0", "conflict_Hero_Villain"],
            target_relationships={"Hero": "friendly", "Mentor": "friendly"},
            target_location="castle",
            min_tension_arc=True,
        ),
    }

    # Scenario 2: a detective needs two revelations and a tense relationship;
    # no particular final location is required.
    mystery_resolution = {
        "name": "Mystery Resolution",
        "initial_state": StoryState(
            characters={
                "Detective": {"role": "investigator", "status": "searching"},
                "Suspect1": {"role": "suspect", "status": "nervous"},
                "Suspect2": {"role": "suspect", "status": "calm"},
            },
            location="home",
            plot_points=[],
            tension_level=0.3,
        ),
        "goal": StoryGoal(
            required_plot_points=["revelation_0", "revelation_1"],
            target_relationships={"Detective": "tense"},
            target_location=None,
            min_tension_arc=True,
        ),
    }

    return [hero_journey, mystery_resolution]

def run_experiment_with_llm():
    """Plan every demo scenario with RAP-MCTS and print the plan and outcome.

    For each scenario this prints the initial state, the generated plan, the
    state obtained by replaying the plan through the world model, whether the
    goal was reached, the search time and the backend in use.
    """
    banner = "=" * 60
    divider = "-" * 40

    print(banner)
    print("RAP-MCTS Story Planning with LLM")
    print(banner)

    # Initialize with your choice:
    # Option 1: OpenAI (need API key)
    # world_model = ActualLLMWorldModel(model_type="openai", api_key="your-key-here")

    # Option 2: HuggingFace (free, may be slower)
    # world_model = ActualLLMWorldModel(model_type="huggingface", hf_token="your-token")

    # Option 3: Hybrid (no API needed, uses heuristics + simulated LLM)
    world_model = ActualLLMWorldModel(model_type="hybrid")

    for scenario in create_demo_scenarios():
        start_state = scenario['initial_state']
        goal = scenario['goal']

        print(f"\nScenario: {scenario['name']}")
        print(divider)

        planner = RAP_MCTS(
            world_model=world_model,
            iterations=30,  # Reduced for API cost/speed
            alpha=0.6  # Balance between LLM and heuristics
        )

        print("Initial State:")
        print(start_state.to_string())

        started_at = time.time()
        plan = planner.search(start_state, goal)
        elapsed = time.time() - started_at

        print(f"\nGenerated Plan ({len(plan)} steps):")
        for step_number, step in enumerate(plan, 1):
            print(f"  {step_number}. {step.to_string()}")

        # Replay the plan through the world model to obtain the final state
        end_state = start_state
        for step in plan:
            end_state = world_model.predict_next_state(end_state, step)

        outcome_mark = '✓' if planner._is_goal_reached(end_state) else '✗'

        print("\nFinal State:")
        print(end_state.to_string())

        print("\nResults:")
        print(f"  Goal Reached: {outcome_mark}")
        print(f"  Search Time: {elapsed:.2f}s")
        print(f"  Model Type: {world_model.model_type}")

if __name__ == "__main__":
    # Entry point: runs the demo scenarios. run_experiment_with_llm() builds a
    # "hybrid" world model, so no credentials are needed as written; the lines
    # below only matter after switching it to an LLM backend.

    # To use OpenAI, set your API key (read via os.getenv("OPENAI_API_KEY")):
    # os.environ["OPENAI_API_KEY"] = "your-key-here"

    # To use HuggingFace, set your token (read via os.getenv("HF_TOKEN")):
    # os.environ["HF_TOKEN"] = "your-token-here"

    run_experiment_with_llm()

RAP-MCTS Story Planning with LLM
Using hybrid mode (heuristics + simulated LLM)

Scenario: Hero's Journey
----------------------------------------
Initial State:
Location: village
Characters:
  Hero: role=hero, status=naive
  Mentor: role=guide, status=wise
  Villain: role=villain, status=hidden
Plot points completed: None
Tension level: 0.1

Generated Plan (2 steps):
  1. Hero reveal: makes important discovery
  2. Hero conflict with Villain: confronts

Final State:
Location: village
Characters:
  Hero: role=hero, status=naive
  Mentor: role=guide, status=wise
  Villain: role=villain, status=hidden
Plot points completed: revelation_0, conflict_Hero_Villain
Tension level: 0.5

Results:
  Goal Reached: ✗
  Search Time: 0.00s
  Model Type: hybrid

Scenario: Mystery Resolution
----------------------------------------
Initial State:
Location: home
Characters:
  Detective: role=investigator, status=searching
  Suspect1: role=suspect, status=nervous
  Suspect2: role=suspect, status=calm
Plot