<a href="https://colab.research.google.com/drive/1OL1bZ8fGpRpsKJcnyTaZ10YEe-75_0uG?usp=sharing" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>

In [1]:
!pip install -qU google-generativeai

In [2]:
import google.generativeai as genai
import getpass
import random

Get free-tier Google's Gemini API Key here: https://aistudio.google.com/app/apikey

In [3]:
API_KEY = getpass.getpass("Enter your Google API key: ")

Enter your Google API key: ··········


In [5]:
genai.configure(api_key=API_KEY)

In [6]:
class Node:
    def __init__(self, state, action=None, parent=None):
        self.state = state
        self.action = action
        self.parent = parent
        self.children = []
        self.visits = 0
        self.value = 0.0

    def uct_score(self):
        """Upper Confidence Bound for Trees"""
        if self.visits == 0:
            return float('inf')
        exploitation = self.value / self.visits
        exploration = 1.4 * (2 * self.parent.visits) ** 0.5 / self.visits
        return exploitation + exploration

class LATSAgent:
    def __init__(self, environment):
        self.model = genai.GenerativeModel("gemini-2.0-flash")
        self.environment = environment
        self.root = Node(environment.get_state())

    def select(self, node):
        """Select best child using UCT"""
        while node.children:
            node = max(node.children, key=lambda n: n.uct_score())
        return node

    def expand(self, node):
        """Generate possible actions"""
        prompt = f"""Given this state, suggest 3 possible actions:

        State: {node.state}

        Actions (one per line):"""

        response = self.model.generate_content(prompt).text
        actions = [line.strip() for line in response.split("\n") if line.strip()][:3]

        for action in actions:
            child = Node(None, action, node)
            node.children.append(child)

        return node.children[0] if node.children else node

    def simulate(self, node):
        """Execute action and get new state"""
        if node.action:
            print(f"  🎲 Simulating: {node.action}")
            new_state = self.environment.execute(node.action)
            node.state = new_state
            return new_state
        return node.state

    def evaluate(self, state, goal):
        """LLM evaluates how good the state is"""
        prompt = f"""Rate this state toward the goal (0-10):

        Goal: {goal}
        State: {state}

        Score (0-10):"""

        response = self.model.generate_content(prompt).text
        try:
            score = float(response.strip().split()[0])
            return min(max(score / 10, 0), 1)  # Normalize to 0-1
        except:
            return 0.5

    def backpropagate(self, node, value):
        """Update values up the tree"""
        while node:
            node.visits += 1
            node.value += value
            node = node.parent

    def search(self, goal, iterations=5):
        """Run MCTS with real environment interaction"""
        print(f"\n{'='*60}")
        print(f"🌳 LATS Search: {goal}")
        print(f"{'='*60}\n")

        for i in range(iterations):
            print(f"--- Iteration {i+1} ---")

            # 1. Selection
            node = self.select(self.root)
            print(f"  ↓ Selected node at depth {self._depth(node)}")

            # 2. Expansion
            if node.visits > 0:
                node = self.expand(node)
                print(f"  ↗ Expanded with action: {node.action}")

            # 3. Simulation
            state = self.simulate(node)

            # 4. Evaluation
            value = self.evaluate(state, goal)
            print(f"  ⭐ Evaluated: {value:.2f}\n")

            # 5. Backpropagation
            self.backpropagate(node, value)

        # Return best action
        best = max(self.root.children, key=lambda n: n.value / n.visits if n.visits > 0 else 0)
        print(f"{'='*60}")
        print(f"🏆 Best action: {best.action}")
        print(f"   Value: {best.value:.2f}, Visits: {best.visits}")
        print(f"{'='*60}\n")

        return best.action

    def _depth(self, node):
        depth = 0
        while node.parent:
            depth += 1
            node = node.parent
        return depth

    def show_tree(self):
        """Display tree statistics"""
        print(f"📊 Tree Stats:")
        print(f"   Root visits: {self.root.visits}")
        print(f"   Children: {len(self.root.children)}")
        for i, child in enumerate(self.root.children, 1):
            avg = child.value / child.visits if child.visits > 0 else 0
            print(f"   {i}. {child.action[:40]}... | Score: {avg:.2f} | Visits: {child.visits}")


# Simple environment simulator
class Environment:
    def __init__(self, initial_state):
        self.state = initial_state

    def get_state(self):
        return self.state

    def execute(self, action):
        """Execute action and return new state (simulated)"""
        # Simple simulation: actions modify state
        new_state = self.state.copy()

        if "collect" in action.lower():
            new_state["resources"] = new_state.get("resources", 0) + 10
        elif "build" in action.lower():
            new_state["structures"] = new_state.get("structures", 0) + 1
            new_state["resources"] = max(0, new_state.get("resources", 0) - 5)
        elif "explore" in action.lower():
            new_state["explored"] = new_state.get("explored", 0) + 1
        elif "attack" in action.lower():
            new_state["enemies"] = max(0, new_state.get("enemies", 3) - 1)
        elif "defend" in action.lower():
            new_state["defense"] = new_state.get("defense", 0) + 1

        self.state = new_state
        return new_state

In [7]:
# Example 1: Resource Management
print("="*60)
print("EXAMPLE 1: Resource Management Game")
print("="*60)

env1 = Environment({"resources": 0, "structures": 0, "explored": 0})
agent1 = LATSAgent(env1)

best_action = agent1.search("Build 2 structures", iterations=6)
agent1.show_tree()


# Example 2: Combat Strategy
print("\n" + "="*60)
print("EXAMPLE 2: Combat Strategy")
print("="*60)

env2 = Environment({"enemies": 3, "defense": 0, "health": 100})
agent2 = LATSAgent(env2)

best_action = agent2.search("Defeat all enemies safely", iterations=6)
agent2.show_tree()


EXAMPLE 1: Resource Management Game

🌳 LATS Search: Build 2 structures

--- Iteration 1 ---
  ↓ Selected node at depth 0
  ⭐ Evaluated: 0.00

--- Iteration 2 ---
  ↓ Selected node at depth 0
  ↗ Expanded with action: Here are three possible actions based on the given state, aiming for early game development:
  🎲 Simulating: Here are three possible actions based on the given state, aiming for early game development:
  ⭐ Evaluated: 0.20

--- Iteration 3 ---
  ↓ Selected node at depth 1
  🎲 Simulating: 1.  **Explore:** (Prioritizes discovering more of the game world and potentially revealing resource locations)
  ⭐ Evaluated: 0.80

--- Iteration 4 ---
  ↓ Selected node at depth 1
  🎲 Simulating: 2.  **Gather Resources:** (Focuses on acquiring the fundamental resource needed to build structures or perform other actions)
  ⭐ Evaluated: 0.80

--- Iteration 5 ---
  ↓ Selected node at depth 1
  ↗ Expanded with action: Here are 3 possible actions based on the given state:
  🎲 Simulating: Here a



TooManyRequests: 429 POST https://generativelanguage.googleapis.com/v1beta/models/gemini-2.0-flash:generateContent?%24alt=json%3Benum-encoding%3Dint: You exceeded your current quota, please check your plan and billing details. For more information on this error, head to: https://ai.google.dev/gemini-api/docs/rate-limits.
* Quota exceeded for metric: generativelanguage.googleapis.com/generate_content_free_tier_requests, limit: 15
Please retry in 29.307326756s.