In [7]:
import numpy as np
import math
import random
import time
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
import json

In [8]:
@dataclass
class StoryState:
    """A snapshot of the story world, used as the state of a planning node."""
    characters: Dict[str, Dict[str, Any]]  # character name -> attribute dict (may hold a nested "relationships" dict)
    location: str  # current scene location
    plot_points: List[str]  # completed plot points, in completion order
    tension_level: float  # story tension, expected to stay within [0, 1]

    def to_string(self) -> str:
        """Render the state as multi-line text suitable for an LLM prompt."""
        desc = f"Location: {self.location}\n"
        desc += "Characters:\n"
        for char, attrs in self.characters.items():
            desc += f"  {char}: "
            desc += ", ".join([f"{k}={v}" for k, v in attrs.items()])
            desc += "\n"
        desc += f"Plot points completed: {', '.join(self.plot_points) if self.plot_points else 'None'}\n"
        desc += f"Tension level: {self.tension_level:.1f}"
        return desc

    def copy(self) -> "StoryState":
        """Return a fully independent (deep) copy of this state.

        Character attribute dicts can contain nested mutable values, such as a
        per-character ``"relationships"`` dict. A per-character shallow
        ``dict.copy()`` would share those nested dicts, so mutating a copy's
        relationships would silently change the original state as well (for
        example, a parent node's state in a search tree). ``copy.deepcopy``
        isolates every nested level.
        """
        import copy as _copy  # local import keeps the module-level namespace unchanged

        return StoryState(
            characters=_copy.deepcopy(self.characters),
            location=self.location,
            plot_points=list(self.plot_points),
            tension_level=self.tension_level,
        )

@dataclass
class StoryAction:
    """One narrative step performed by a single character.

    ``target`` is optional; when it is empty or None the rendered text omits
    the "with <target>" part.
    """
    action_type: str  # e.g. "dialogue", "move", "conflict", "reveal"
    actor: str  # character performing the action
    target: Optional[str] = None  # optional target character/object
    content: str = ""  # free-text description of the action

    def to_string(self) -> str:
        """Render as ``"<actor> <type>[ with <target>]: <content>"``."""
        subject = f"{self.actor} {self.action_type}"
        if not self.target:
            return f"{subject}: {self.content}"
        return f"{subject} with {self.target}: {self.content}"

@dataclass
class StoryGoal:
    """Target conditions that a story plan should reach."""
    # Plot-point identifiers that must all be completed.
    required_plot_points: List[str]
    # character name -> desired relationship label (e.g. "friendly", "tense").
    target_relationships: Dict[str, str]
    # Location the story should end in; None means any location is acceptable.
    target_location: Optional[str] = field(default=None)
    # Request for a tension arc (intended: rising/falling tension); how it is
    # rewarded is left to the planner's heuristic.
    min_tension_arc: bool = field(default=True)

In [9]:
class LLMWorldModel:
    """Deterministic, rule-based stand-in for an LLM world model.

    Provides the three services the planners need: proposing candidate
    actions, predicting the next state after an action, and scoring how
    narratively plausible an action is (the r1 reward).
    """

    def __init__(self):
        # Locations a "move" action can resolve to; matched by substring
        # against the action content, in this order.
        self.locations = ["village", "castle", "forest", "home", "mountain", "river"]

    def get_valid_actions(self, state: StoryState, goal: StoryGoal) -> List[StoryAction]:
        """Propose candidate actions, biased toward those that advance ``goal``.

        Candidates are gathered in priority order:

        1. Goal actions: reaching the target location, or missing plot points.
        2. Relationship-goal actions.
        3. Generic filler dialogue, added only when fewer than 5 distinct actions exist.

        The candidates are de-duplicated by their string form, sorted by that
        string for determinism, and capped at 15.

        De-duplication matters because several goals can propose the same
        action. For example, two pending "revelation_*" plot points each
        propose one "reveal" per character. A tree that keys children by
        ``action.to_string()`` would overwrite, and thereby orphan, sibling
        subtrees when given duplicates.

        Returns:
            Up to 15 distinct StoryAction objects ordered by ``to_string()``.
        """
        chars = list(state.characters.keys())
        candidates = self._goal_actions(state, goal, chars)
        candidates += self._relationship_actions(state, goal, chars)

        unique: Dict[str, StoryAction] = {}
        for action in candidates:
            unique.setdefault(action.to_string(), action)

        # Pad sparse action sets with generic dialogue (counting distinct actions).
        if len(unique) < 5:
            for action in self._filler_actions(chars):
                unique.setdefault(action.to_string(), action)

        return [unique[key] for key in sorted(unique)][:15]

    def _goal_actions(self, state: StoryState, goal: StoryGoal,
                      chars: List[str]) -> List[StoryAction]:
        """Actions that directly reach the target location or a missing plot point."""
        actions: List[StoryAction] = []

        if goal.target_location and state.location != goal.target_location:
            for actor in chars:
                actions.append(StoryAction(
                    action_type="move",
                    actor=actor,
                    content=f"travels to {goal.target_location}"
                ))

        for required in goal.required_plot_points:
            if required in state.plot_points:
                continue
            if "revelation" in required:
                for actor in chars:
                    actions.append(StoryAction(
                        action_type="reveal",
                        actor=actor,
                        content="makes important discovery"
                    ))
            elif "conflict" in required:
                # Expected form "conflict_<actor>_<target>"; names containing
                # underscores do not split into 3 parts and are skipped.
                parts = required.split("_")
                if len(parts) == 3:
                    actor, target = parts[1], parts[2]
                    if actor in chars and target in chars:
                        actions.append(StoryAction(
                            action_type="conflict",
                            actor=actor,
                            target=target,
                            content="confronts"
                        ))
        return actions

    def _relationship_actions(self, state: StoryState, goal: StoryGoal,
                              chars: List[str]) -> List[StoryAction]:
        """Dialogue actions for characters that do not yet hold their target relationship."""
        actions: List[StoryAction] = []
        for char, target_rel in goal.target_relationships.items():
            if char not in state.characters:
                continue
            current_rels = state.characters[char].get("relationships", {})
            if target_rel in current_rels.values():
                continue  # already has at least one relationship of the wanted kind
            if target_rel == "friendly":
                content = "asks for help from"
            elif target_rel == "tense":
                content = "confronts"
            else:
                continue  # no action template for other relationship labels
            for other_char in chars:
                if other_char != char:
                    actions.append(StoryAction(
                        action_type="dialogue",
                        actor=char,
                        target=other_char,
                        content=content
                    ))
        return actions

    @staticmethod
    def _filler_actions(chars: List[str]) -> List[StoryAction]:
        """Generic secret-revealing dialogue between the first two characters (bounded to avoid blow-up)."""
        pair = chars[:2]
        return [
            StoryAction(action_type="dialogue", actor=actor, target=target,
                        content="reveals secret to")
            for actor in pair for target in pair if target != actor
        ]

    def predict_next_state(self, state: StoryState, action: StoryAction) -> StoryState:
        """Deterministically apply ``action`` and return the resulting state.

        Rules:

        - dialogue: "asks for help" sets actor->target to "friendly" and lowers
          tension by 0.1. "confronts" or "reveals secret" sets it to "tense" and
          raises tension by 0.2.
        - move: switches to the first known location named in the content.
          Unknown locations leave the state unchanged.
        - conflict: raises tension by 0.3 and records "conflict_<actor>_<target>" once.
        - reveal: records the next "revelation_<n>" and raises tension by 0.1.

        Tension is clamped to [0, 1]. All changes go to ``state.copy()``.
        The actor's relationships dict is rebuilt rather than mutated in place,
        so ``state`` is never affected, even by a shallow copy.
        """
        new_state = state.copy()

        if action.action_type == "dialogue":
            actor_attrs = new_state.characters.get(action.actor)
            if actor_attrs is not None:
                # Copy-on-write: the existing dict may be shared with ``state``.
                relationships = dict(actor_attrs.get("relationships", {}))
                if "asks for help" in action.content:
                    relationships[action.target] = "friendly"
                    new_state.tension_level = max(0.0, new_state.tension_level - 0.1)
                elif "confronts" in action.content or "reveals secret" in action.content:
                    relationships[action.target] = "tense"
                    new_state.tension_level = min(1.0, new_state.tension_level + 0.2)
                actor_attrs["relationships"] = relationships

        elif action.action_type == "move":
            for loc in self.locations:
                if loc in action.content:
                    new_state.location = loc
                    break

        elif action.action_type == "conflict":
            new_state.tension_level = min(1.0, new_state.tension_level + 0.3)
            conflict_point = f"conflict_{action.actor}_{action.target}"
            if conflict_point not in new_state.plot_points:
                new_state.plot_points.append(conflict_point)

        elif action.action_type == "reveal":
            # Revelations are numbered by how many have already happened.
            revelation_num = sum(1 for p in new_state.plot_points if "revelation" in p)
            new_state.plot_points.append(f"revelation_{revelation_num}")
            new_state.tension_level = min(1.0, new_state.tension_level + 0.1)

        return new_state

    def calculate_action_likelihood(self, state: StoryState, action: StoryAction) -> float:
        """r1: heuristic plausibility of ``action`` in ``state``, in [0.5, 1.0].

        Starts at 0.5 and adds bonuses for role-consistent actions:

        - hero asking for help: +0.3
        - hero revealing: +0.2
        - villain in conflict: +0.3
        - investigator revealing: +0.3

        A further +0.2 applies for conflict/reveal actions while tension is
        below 0.5. The result is capped at 1.0.
        """
        score = 0.5

        if action.actor in state.characters:
            char_attrs = state.characters[action.actor]
            if "role" in char_attrs:
                role = char_attrs["role"]

                if role == "hero":
                    if "asks for help" in action.content:
                        score += 0.3
                    elif action.action_type == "reveal":
                        score += 0.2
                elif role == "villain" and action.action_type == "conflict":
                    score += 0.3
                elif role == "investigator" and action.action_type == "reveal":
                    score += 0.3

        if state.tension_level < 0.5 and action.action_type in ["conflict", "reveal"]:
            score += 0.2

        return min(score, 1.0)

In [10]:
class RAPNode:
    """One node of the RAP-MCTS search tree.

    Holds the story state reached by applying ``action`` to the parent's
    state, plus the visit/value statistics used for UCB1 selection.
    """

    def __init__(self, state: StoryState, parent: Optional['RAPNode'] = None,
                 action: Optional[StoryAction] = None, depth: int = 0):
        # Tree linkage.
        self.state = state
        self.parent = parent
        self.action = action  # action that led here from the parent (None at the root)
        self.depth = depth
        self.children: Dict[str, 'RAPNode'] = {}  # keyed by action.to_string()

        # Search statistics / expansion bookkeeping.
        self.visits = 0
        self.q_value = 0.0  # accumulated reward; ucb1 divides it by visits
        self.untried_actions: List[StoryAction] = []
        self.is_terminal = False

    def ucb1(self, c: float = 1.4) -> float:
        """UCB1 score: mean reward plus an exploration bonus; +inf if never visited.

        Uses ``self.parent.visits``, so it is only meaningful on non-root nodes.
        """
        if not self.visits:
            return float('inf')
        mean_reward = self.q_value / self.visits
        bonus = c * math.sqrt(math.log(self.parent.visits) / self.visits)
        return mean_reward + bonus

    def best_child(self, c: float = 1.4) -> 'RAPNode':
        """Return the child with the highest UCB1 score (first one wins ties)."""
        candidates = self.children.values()
        return max(candidates, key=lambda node: node.ucb1(c))

    def add_child(self, action: StoryAction, state: StoryState) -> 'RAPNode':
        """Create, register and return a child reached via ``action``.

        A child already stored under the same ``action.to_string()`` key is replaced.
        """
        node = RAPNode(state, parent=self, action=action, depth=self.depth + 1)
        key = action.to_string()
        self.children[key] = node
        return node

class RAP_MCTS:
    """RAP (Reasoning via Planning) with Monte-Carlo Tree Search for story planning.

    The world model proposes actions and predicts next states. Each newly
    reached node is scored by blending the action-likelihood reward r1 with
    a goal heuristic r2: ``reward = r1**alpha * r2**(1 - alpha)``.
    """

    def __init__(self,
                 world_model: LLMWorldModel,
                 max_depth: int = 15,
                 iterations: int = 50,
                 exploration_constant: float = 1.4,
                 alpha: float = 0.5):
        """
        Args:
            world_model: Supplies get_valid_actions, predict_next_state and
                calculate_action_likelihood.
            max_depth: Nodes at this depth are marked terminal. Also caps the
                length of the extracted plan.
            iterations: Number of select/expand/evaluate/backpropagate rounds.
            exploration_constant: UCB1 ``c``.
            alpha: Exponent weight of r1 in the blend; r2 gets ``1 - alpha``.
        """
        self.world_model = world_model
        self.max_depth = max_depth
        self.iterations = iterations
        self.c = exploration_constant
        self.alpha = alpha
        self.root = None  # set by search()
        self.goal = None  # set by search(); read by the reward/goal helpers

    def search(self, initial_state: StoryState, goal: StoryGoal) -> List[StoryAction]:
        """Run MCTS from ``initial_state`` toward ``goal`` and return the best plan found."""
        self.root = RAPNode(initial_state)
        self.goal = goal

        self.root.untried_actions = self.world_model.get_valid_actions(
            initial_state, goal
        )

        for _ in range(self.iterations):
            node = self._selection()

            if not node.is_terminal and node.untried_actions:
                node = self._expansion(node)

            reward = self._evaluation(node)

            self._backpropagation(node, reward)

        return self._extract_best_path()

    def _selection(self) -> RAPNode:
        """Descend by UCB1 until reaching a node with untried actions, a leaf, or a terminal node."""
        current = self.root

        while current.children and not current.is_terminal:
            if current.untried_actions:
                return current
            current = current.best_child(self.c)

        return current

    def _expansion(self, node: RAPNode) -> RAPNode:
        """Expand ``node`` with its next untried action and return the new child.

        The child is marked terminal when it reaches the goal or ``max_depth``.
        Otherwise, it receives its own candidate actions from the world model.
        """
        if not node.untried_actions:
            return node

        action = node.untried_actions.pop(0)

        next_state = self.world_model.predict_next_state(node.state, action)

        child = node.add_child(action, next_state)

        if (self._is_goal_reached(next_state) or
                child.depth >= self.max_depth):
            child.is_terminal = True
        else:
            child.untried_actions = self.world_model.get_valid_actions(
                next_state, self.goal
            )

        return child

    def _evaluation(self, node: RAPNode) -> float:
        """Blend r1 (action likelihood) and r2 (goal heuristic) into one reward.

        NOTE(review): r2 is not normalized; it reaches 100.0 at the goal. The
        UCB1 exploration bonus (c around 1.4) is therefore small relative to
        value differences.
        """
        if node.action:
            # r1 scores the action in the state it was taken from.
            r1 = self.world_model.calculate_action_likelihood(
                node.parent.state if node.parent else node.state,
                node.action
            )
        else:
            r1 = 1.0  # root: no action to score

        r2 = self._calculate_heuristic_reward(node.state)

        return (r1 ** self.alpha) * (r2 ** (1 - self.alpha))

    def _calculate_heuristic_reward(self, state: StoryState) -> float:
        """r2: goal-progress heuristic; returns 100.0 when the goal is fully reached.

        Otherwise the score is a sum of:

        - +2 per completed required plot point, plus +5 once all are done.
        - +2 for being at the target location.
        - +1 per character holding its target relationship.
        - +0.5 for tension above 0.3 when a tension arc is requested.
        """
        score = 0.0

        if self.goal.required_plot_points:
            required = set(self.goal.required_plot_points)
            completed = set(state.plot_points)
            matching = required.intersection(completed)

            score += len(matching) * 2.0

            if required.issubset(completed):
                score += 5.0

        if self.goal.target_location:
            if state.location == self.goal.target_location:
                score += 2.0

        if self.goal.target_relationships:
            for char, target_rel in self.goal.target_relationships.items():
                if char in state.characters:
                    char_rels = state.characters[char].get("relationships", {})
                    if target_rel in char_rels.values():
                        score += 1.0

        if self.goal.min_tension_arc and state.tension_level > 0.3:
            score += 0.5

        if self._is_goal_reached(state):
            score = 100.0

        return score

    def _is_goal_reached(self, state: StoryState) -> bool:
        """Return True when ``state`` meets every goal condition.

        The conditions are: all required plot points are completed, the
        target location is reached (if set), and each listed character present
        in the state holds at least one relationship with the target label.
        """
        if self.goal.required_plot_points:
            required = set(self.goal.required_plot_points)
            completed = set(state.plot_points)
            if not required.issubset(completed):
                return False

        if self.goal.target_location:
            if state.location != self.goal.target_location:
                return False

        if self.goal.target_relationships:
            for char, target_rel in self.goal.target_relationships.items():
                if char in state.characters:
                    char_rels = state.characters[char].get("relationships", {})
                    has_relationship = any(
                        rel == target_rel for rel in char_rels.values()
                    )
                    if not has_relationship:
                        return False

        return True

    def _backpropagation(self, node: RAPNode, reward: float):
        """Add one visit and ``reward`` to ``node`` and every ancestor up to the root."""
        current = node
        while current is not None:
            current.visits += 1
            current.q_value += reward
            current = current.parent

    def _extract_best_path(self) -> List[StoryAction]:
        """Walk down from the root, following the best-scoring visited child.

        A child's score is its mean value plus a visit bonus of ``visits / 10``.
        Unvisited children are skipped. The walk stops at a leaf, at a
        terminal or goal node, or after ``max_depth`` actions.

        No visited-state guard is used, because following child pointers
        always moves one level deeper, so the walk cannot cycle. A state key
        built only from plot points and location would also wrongly stop the
        walk after any step that changes only relationships or tension, such
        as a dialogue action.
        """
        path: List[StoryAction] = []
        current = self.root

        while current.children and len(path) < self.max_depth:
            best_child = None
            best_score = -float('inf')

            for child in current.children.values():
                if child.visits == 0:
                    continue  # never evaluated: no value estimate
                score = child.q_value / child.visits + child.visits / 10.0
                if score > best_score:
                    best_score = score
                    best_child = child

            if best_child is None:
                break

            if best_child.action:
                path.append(best_child.action)
            current = best_child

            if current.is_terminal or self._is_goal_reached(current.state):
                break

        return path

In [11]:
def create_demo_scenarios():
    """Build the demo story-planning problems.

    Returns:
        A list of three dicts with keys "name", "initial_state" (a StoryState
        with no completed plot points) and "goal" (a StoryGoal). Fresh objects
        are created on every call.
    """
    def make_scenario(name, characters, location, tension, goal):
        # Every demo starts with an empty plot-point history.
        return {
            "name": name,
            "initial_state": StoryState(
                characters=characters,
                location=location,
                plot_points=[],
                tension_level=tension,
            ),
            "goal": goal,
        }

    hero_journey = make_scenario(
        "Hero's Journey",
        {
            "Hero": {"role": "hero", "status": "naive"},
            "Mentor": {"role": "guide", "status": "wise"},
            "Villain": {"role": "villain", "status": "hidden"},
        },
        "village",
        0.1,
        StoryGoal(
            required_plot_points=["revelation_0", "conflict_Hero_Villain"],
            target_relationships={"Hero": "friendly", "Mentor": "friendly"},
            target_location="castle",
            min_tension_arc=True,
        ),
    )

    mystery = make_scenario(
        "Mystery Resolution",
        {
            "Detective": {"role": "investigator", "status": "searching"},
            "Suspect1": {"role": "suspect", "status": "nervous"},
            "Suspect2": {"role": "suspect", "status": "calm"},
            "Witness": {"role": "witness", "status": "afraid"},
        },
        "home",
        0.3,
        StoryGoal(
            required_plot_points=["revelation_0", "revelation_1"],
            target_relationships={"Detective": "tense"},
            target_location=None,
            min_tension_arc=True,
        ),
    )

    romance = make_scenario(
        "Romance Arc",
        {
            "Protagonist": {"role": "lead", "status": "lonely"},
            "Love_Interest": {"role": "lead", "status": "independent"},
            "Rival": {"role": "antagonist", "status": "confident"},
        },
        "village",
        0.2,
        StoryGoal(
            required_plot_points=["conflict_Protagonist_Rival"],
            target_relationships={"Protagonist": "friendly", "Love_Interest": "friendly"},
            target_location="home",
            min_tension_arc=True,
        ),
    )

    return [hero_journey, mystery, romance]

In [12]:
def run_rap_mcts_experiments():
    """Run each RAP-MCTS configuration on each demo scenario and print a report.

    For every (configuration, scenario) pair this function does the following:

    1. Searches for a plan.
    2. Replays the plan through the world model.
    3. Prints the initial state, goal, plan, final state and metrics.

    Per-configuration averages are printed at the end. Returns None; all
    output goes to stdout.
    """
    rule = "=" * 60
    print(rule)
    print("RAP-MCTS Story Planning Experiments (FIXED)")
    print("Based on: 'Reasoning via Planning' (EMNLP 2023)")
    print(rule)

    world_model = LLMWorldModel()
    scenarios = create_demo_scenarios()

    # Search budgets / reward blends to compare.
    configs = [
        {"iterations": 50, "alpha": 0.5, "name": "RAP-50"},
        {"iterations": 100, "alpha": 0.5, "name": "RAP-100"},
        {"iterations": 100, "alpha": 0.7, "name": "RAP-100-α0.7"},
    ]

    results = {}

    for cfg in configs:
        section_rule = "=" * 50
        print(f"\n{section_rule}")
        print(f"Configuration: {cfg['name']}")
        print(f"  Iterations: {cfg['iterations']}")
        print(f"  Alpha (r1 weight): {cfg['alpha']}")
        print(section_rule)

        rows = []
        for scenario in scenarios:
            start_state = scenario['initial_state']
            goal = scenario['goal']

            print(f"\nScenario: {scenario['name']}")
            print("-" * 30)

            planner = RAP_MCTS(
                world_model=world_model,
                iterations=cfg['iterations'],
                alpha=cfg['alpha'],
                max_depth=15
            )

            t0 = time.time()
            plan = planner.search(start_state, goal)
            elapsed = time.time() - t0

            # Re-simulate the plan to obtain the state it actually ends in.
            end_state = start_state
            for step in plan:
                end_state = world_model.predict_next_state(end_state, step)

            reached = planner._is_goal_reached(end_state)
            reward = planner._calculate_heuristic_reward(end_state)
            n_nodes = count_nodes(planner.root)

            print("Initial State:")
            print(start_state.to_string())
            print("\nGoal:")
            print(f"  Required plot points: {goal.required_plot_points}")
            print(f"  Target relationships: {goal.target_relationships}")
            print(f"  Target location: {goal.target_location}")

            print(f"\nGenerated Plan ({len(plan)} steps):")
            for idx, step in enumerate(plan, start=1):
                print(f"  {idx}. {step.to_string()}")

            print("\nFinal State:")
            print(end_state.to_string())

            print("\nResults:")
            mark = '✓' if reached else '✗'
            print(f"  Goal Reached: {mark}")
            print(f"  Final Reward: {reward:.3f}")
            print(f"  Search Time: {elapsed:.2f}s")
            print(f"  Nodes Explored: {n_nodes}")

            rows.append({
                "scenario": scenario['name'],
                "goal_reached": reached,
                "final_reward": reward,
                "path_length": len(plan),
                "search_time": elapsed,
                "nodes_explored": n_nodes
            })

        results[cfg['name']] = rows

    print("\n" + rule)
    print("SUMMARY COMPARISON")
    print(rule)

    for label, rows in results.items():
        success_rate = sum(row['goal_reached'] for row in rows) / len(rows)
        avg_reward = np.mean([row['final_reward'] for row in rows])
        avg_path_length = np.mean([row['path_length'] for row in rows])
        avg_time = np.mean([row['search_time'] for row in rows])

        print(f"\n{label}:")
        print(f"  Success Rate: {success_rate:.1%}")
        print(f"  Avg Final Reward: {avg_reward:.3f}")
        print(f"  Avg Path Length: {avg_path_length:.1f}")
        print(f"  Avg Search Time: {avg_time:.3f}s")

def count_nodes(node):
    """Return the number of nodes in the subtree rooted at ``node``.

    A falsy ``node`` (e.g. None) counts as an empty tree (0).
    """
    if not node:
        return 0
    return 1 + sum(count_nodes(child) for child in node.children.values())

class GreedyBaseline:
    """One-step-lookahead greedy planner used as a CoT-like baseline.

    At each step it scores the first five candidate actions from the world
    model by the state they lead to and commits to the best one. There is no
    tree search and no backtracking.
    """

    def __init__(self, world_model: LLMWorldModel):
        self.world_model = world_model

    def search(self, initial_state: StoryState, goal: StoryGoal, max_depth: int = 10) -> List[StoryAction]:
        """Greedily build a plan of at most ``max_depth`` actions.

        Stops early when no action is available or the goal is reached.
        """
        path: List[StoryAction] = []
        current_state = initial_state

        for _ in range(max_depth):
            actions = self.world_model.get_valid_actions(current_state, goal)
            if not actions:
                break

            best_action = None
            best_next_state = None
            best_reward = -1.0

            # Only a fixed, small slice of candidates is considered each step.
            for action in actions[:5]:
                next_state = self.world_model.predict_next_state(current_state, action)
                reward = self._evaluate_state(next_state, goal)
                if reward > best_reward:
                    best_reward = reward
                    best_action = action
                    best_next_state = next_state

            if best_action is None:
                break

            path.append(best_action)
            current_state = best_next_state

            if self._is_goal_reached(current_state, goal):
                break

        return path

    def _evaluate_state(self, state: StoryState, goal: StoryGoal) -> float:
        """Simple progress score.

        The score is the fraction of required plot points completed, plus
        0.3 when at the target location.
        """
        score = 0.0

        if goal.required_plot_points:
            required = set(goal.required_plot_points)
            completed = set(state.plot_points)
            score += len(required.intersection(completed)) / len(required)

        if goal.target_location and state.location == goal.target_location:
            score += 0.3

        return score

    def _is_goal_reached(self, state: StoryState, goal: StoryGoal) -> bool:
        """Return True when ``state`` meets every goal condition.

        The conditions are: required plot points, target location (if set),
        and target relationships. This mirrors the RAP-MCTS goal test so both
        planners stop on, and are credited for, the same criterion.
        Previously relationships were ignored here, which let the baseline
        stop early and count as successful on a weaker goal.
        """
        if goal.required_plot_points:
            required = set(goal.required_plot_points)
            completed = set(state.plot_points)
            if not required.issubset(completed):
                return False

        if goal.target_location and state.location != goal.target_location:
            return False

        for char, target_rel in (goal.target_relationships or {}).items():
            if char in state.characters:
                char_rels = state.characters[char].get("relationships", {})
                if target_rel not in char_rels.values():
                    return False

        return True

def compare_with_baseline():
    """Compare RAP-MCTS with the greedy baseline on the demo scenarios.

    For each scenario, both planners produce a plan. Each plan is replayed
    through the world model, and the final state is judged.

    Both plans are judged by the SAME criterion, ``RAP_MCTS._is_goal_reached``,
    which checks plot points, location and relationships. Judging the
    baseline with its own goal checker would make the success rates
    incomparable whenever the two checkers differ.
    """
    print("\n" + "=" * 60)
    print("BASELINE COMPARISON")
    print("=" * 60)

    world_model = LLMWorldModel()
    scenarios = create_demo_scenarios()

    def replay(state, plan):
        """Apply ``plan`` to ``state`` through the world model and return the end state."""
        for action in plan:
            state = world_model.predict_next_state(state, action)
        return state

    rap_results = []
    baseline_results = []

    for scenario in scenarios:
        rap_mcts = RAP_MCTS(
            world_model=world_model,
            iterations=100,
            max_depth=15
        )
        rap_path = rap_mcts.search(scenario['initial_state'], scenario['goal'])

        baseline = GreedyBaseline(world_model)
        baseline_path = baseline.search(scenario['initial_state'], scenario['goal'])

        rap_final = replay(scenario['initial_state'], rap_path)
        baseline_final = replay(scenario['initial_state'], baseline_path)

        # One judge for both planners. search() set rap_mcts.goal to this
        # scenario's goal, so the checker evaluates against the right target.
        judge = rap_mcts._is_goal_reached
        rap_success = judge(rap_final)
        baseline_success = judge(baseline_final)

        rap_results.append(rap_success)
        baseline_results.append(baseline_success)

        print(f"\n{scenario['name']}:")
        print(f"  RAP-MCTS: {'✓' if rap_success else '✗'} ({len(rap_path)} steps)")
        print(f"  Baseline: {'✓' if baseline_success else '✗'} ({len(baseline_path)} steps)")

    print(f"\nOverall Success Rates:")
    print(f"  RAP-MCTS: {sum(rap_results)}/{len(rap_results)} ({sum(rap_results)/len(rap_results):.1%})")
    print(f"  Baseline: {sum(baseline_results)}/{len(baseline_results)} ({sum(baseline_results)/len(baseline_results):.1%})")

In [13]:
if __name__ == "__main__":
    # Full experiment grid first, then the head-to-head against the greedy baseline.
    run_rap_mcts_experiments()
    compare_with_baseline()

    closing_rule = "=" * 60
    print("\n" + closing_rule)
    print("RAP-MCTS Implementation Complete!")
    print(closing_rule)

RAP-MCTS Story Planning Experiments (FIXED)
Based on: 'Reasoning via Planning' (EMNLP 2023)

Configuration: RAP-50
  Iterations: 50
  Alpha (r1 weight): 0.5

Scenario: Hero's Journey
------------------------------
Initial State:
Location: village
Characters:
  Hero: role=hero, status=naive
  Mentor: role=guide, status=wise
  Villain: role=villain, status=hidden
Plot points completed: None
Tension level: 0.1

Goal:
  Required plot points: ['revelation_0', 'conflict_Hero_Villain']
  Target relationships: {'Hero': 'friendly', 'Mentor': 'friendly'}
  Target location: castle

Generated Plan (2 steps):
  1. Hero conflict with Villain: confronts
  2. Hero reveal: makes important discovery

Final State:
Location: village
Characters:
  Hero: role=hero, status=naive
  Mentor: role=guide, status=wise
  Villain: role=villain, status=hidden
Plot points completed: conflict_Hero_Villain, revelation_0
Tension level: 0.5

Results:
  Goal Reached: ✗
  Final Reward: 9.500
  Search Time: 0.00s
  Nodes Exp