<a href="https://colab.research.google.com/drive/19yl6LcECtk657I1d7iLYjeM4R8k4cCUc?usp=sharing" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"></a>

### Graph of Thoughts (GoT)

In [1]:
!pip install -qU google-generativeai

In [2]:
import getpass
import re

import google.generativeai as genai

Get a free-tier Google Gemini API key here: https://aistudio.google.com/app/apikey

In [3]:
# Read the key without echoing it. Strip surrounding whitespace that often
# sneaks in when pasting, and fail fast on an empty entry instead of letting
# the first model call fail with an opaque authentication error.
API_KEY = getpass.getpass("Enter your Google API key: ").strip()
if not API_KEY:
    raise ValueError(
        "No Google API key entered; get one at https://aistudio.google.com/app/apikey"
    )

Enter your Google API key: ··········


In [5]:
# Configure the google-generativeai client with the key entered above.
# GoTAgent builds genai.GenerativeModel without passing a key itself, so this
# cell must run before any agent is created.
genai.configure(api_key=API_KEY)

In [6]:
class ThoughtNode:
    """One thought (vertex) in a Graph of Thoughts.

    Unlike a tree node, a thought may have several parents: a merged thought
    is linked from every thought it was combined from. Edges are stored in
    both directions so a path can be walked back toward the root.
    """

    def __init__(self, content, node_id):
        # Identity and the text of the thought itself.
        self.id = node_id
        self.content = content
        # Quality estimate; starts at 0.0 and is overwritten by the agent.
        self.score = 0.0
        # Incoming and outgoing edges (multiple of each allowed).
        self.predecessors = []
        self.successors = []

    def add_successor(self, node):
        """Link ``node`` as a child of this thought and record the back-edge.

        Linking the same child twice is a no-op.
        """
        if node in self.successors:
            return
        self.successors.append(node)
        node.predecessors.append(self)

class GoTAgent:
    """Graph-of-Thoughts reasoner backed by a Gemini model.

    ``solve`` grows a graph of ``ThoughtNode`` objects through repeated
    GENERATE -> SCORE -> AGGREGATE -> REFINE rounds, then asks the model for a
    final answer along the best-scoring root-to-node path.
    """

    # A numbered ("1.", "2)") or bulleted ("-", "*", "•") list item. Only the
    # marker is consumed, so content that itself starts with a number
    # (e.g. "1. 2030 targets ...") keeps its leading digits. Bullets must be
    # followed by whitespace so "**bold**" lines and "---" rules are skipped.
    _LIST_ITEM_RE = re.compile(r"^(?:\d+[.)]\s*|[-*•]\s+)(.+)$")
    # First integer or decimal in a free-form rating reply
    # ("8", "Score: 7.5", "**9**/10").
    _NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")

    def __init__(self, model_name="gemini-2.0-flash-exp"):
        """Create an agent with an empty thought graph.

        Args:
            model_name: Gemini model ID passed to ``genai.GenerativeModel``.
                Defaults to the model this notebook was written against.
        """
        self.model = genai.GenerativeModel(model_name)
        self.nodes = {}  # node_id -> ThoughtNode, in creation order
        self.node_counter = 0

    def create_node(self, content):
        """Create, register and return a new ThoughtNode with the next ID ("N0", "N1", ...)."""
        node_id = f"N{self.node_counter}"
        self.node_counter += 1
        node = ThoughtNode(content, node_id)
        self.nodes[node_id] = node
        return node

    def generate(self, problem, context_nodes, num_thoughts=3):
        """Ask the model for candidate next thoughts.

        Args:
            problem: The problem statement.
            context_nodes: Thoughts shown to the model as the current state.
            num_thoughts: How many candidates to request and the maximum returned.

        Returns:
            Up to ``num_thoughts`` thought strings parsed from the list items in
            the model's reply. Items of 10 characters or fewer are dropped as noise.
        """
        context = "\n".join([f"- {n.content}" for n in context_nodes])

        prompt = f"""Problem: {problem}

        Current thoughts:
        {context}

        Generate {num_thoughts} different next ideas or reasoning steps.
        List them numbered:"""

        response = self.model.generate_content(prompt).text

        thoughts = []
        for raw_line in response.split("\n"):
            match = self._LIST_ITEM_RE.match(raw_line.strip())
            if match is None:
                continue  # prose, headings, separators
            thought = match.group(1).strip()
            if len(thought) > 10:
                thoughts.append(thought)

        return thoughts[:num_thoughts]

    def score(self, problem, thought):
        """Rate a thought with the model and normalize the rating to [0, 1].

        The model is asked for a 0-10 rating. The first number in its reply is
        divided by 10 and clamped to [0, 1]. A reply containing no number
        yields a neutral 0.5.
        """
        prompt = f"""Problem: {problem}

        Thought: {thought}

        Rate this thought (0-10) based on:
        - Relevance to problem
        - Logical soundness
        - Potential to lead to solution

        Score (just number):"""

        response = self.model.generate_content(prompt).text

        match = self._NUMBER_RE.search(response)
        if match is None:
            # Unparseable reply: keep the thought in play with a neutral score
            # rather than discarding it.
            return 0.5
        return min(max(float(match.group()) / 10, 0), 1)

    def aggregate(self, problem, nodes):
        """Ask the model to merge several thoughts into one; return its stripped reply."""
        thoughts = "\n".join([f"{i+1}. {n.content}" for i, n in enumerate(nodes)])

        prompt = f"""Problem: {problem}

        Multiple thoughts to combine:
        {thoughts}

        Synthesize these into one coherent, stronger thought:"""

        response = self.model.generate_content(prompt).text
        return response.strip()

    def refine(self, problem, node):
        """Ask the model to improve a single thought; return its stripped reply."""
        prompt = f"""Problem: {problem}

        Current thought: {node.content}

        Refine and improve this thought:"""

        response = self.model.generate_content(prompt).text
        return response.strip()

    def solve(self, problem, max_iterations=4, branch_factor=2):
        """Solve ``problem`` with Graph-of-Thoughts search, printing the trace.

        Each iteration:
          1. GENERATE ``branch_factor`` successors for every active node and
             SCORE them; only those scoring above 0.4 stay in play.
          2. AGGREGATE the two best survivors into one merged node that has
             both as predecessors.
          3. REFINE the best node so far into an improved successor.
          4. Keep the top 3 nodes as the next iteration's active set.

        The highest-scoring node kept at the end of any iteration determines the
        reasoning path (root -> node) given to the model for the final answer.

        Args:
            problem: Problem statement included in every prompt.
            max_iterations: Upper bound on expansion rounds. The loop stops
                early if every candidate gets pruned.
            branch_factor: Number of thoughts requested per active node.

        Returns:
            The model's final answer text.
        """
        print()
        self._banner("🕸️  Graph of Thoughts")
        print(f"Problem: {problem}\n")

        root = self.create_node("Starting to analyze the problem")
        root.score = 1.0

        active_nodes = [root]
        all_paths = []  # (node, path_from_root, score) for every kept node

        for iteration in range(max_iterations):
            if not active_nodes:
                # Every candidate last round scored <= 0.4; nothing to expand.
                print("⚠️  No promising thoughts left; stopping early.\n")
                break

            self._banner(f"ITERATION {iteration + 1}", char="─")
            print()

            new_active = []

            # GENERATE + SCORE: expand every active node, prune weak thoughts.
            print("🌱 Generating new thoughts...")
            for node in active_nodes:
                for thought in self.generate(problem, [node], branch_factor):
                    new_node = self.create_node(thought)
                    node.add_successor(new_node)
                    new_node.score = self.score(problem, thought)
                    self._report(new_node)
                    if new_node.score > 0.4:
                        new_active.append(new_node)
            print()

            # AGGREGATE: merge the two strongest parallel thoughts into a node
            # with two parents -- this is what makes the structure a graph.
            if len(new_active) >= 2:
                print("🔗 Aggregating thoughts...")
                to_merge = sorted(new_active, key=lambda n: n.score, reverse=True)[:2]
                merged_content = self.aggregate(problem, to_merge)
                merged_node = self.create_node(merged_content)
                for parent in to_merge:
                    parent.add_successor(merged_node)
                merged_node.score = self.score(problem, merged_content)
                self._report(merged_node)
                print()
                new_active.append(merged_node)

            # REFINE: improve the single best thought of this round.
            if new_active:
                print("✨ Refining best thought...")
                best_node = max(new_active, key=lambda n: n.score)
                refined_content = self.refine(problem, best_node)
                refined_node = self.create_node(refined_content)
                best_node.add_successor(refined_node)
                refined_node.score = self.score(problem, refined_content)
                self._report(refined_node)
                print()
                new_active.append(refined_node)

            # Keep the top nodes for the next round and remember their paths.
            active_nodes = sorted(new_active, key=lambda n: n.score, reverse=True)[:3]
            for node in active_nodes:
                all_paths.append((node, self._get_path_to_node(node), node.score))

        if all_paths:
            _, best_path, _ = max(all_paths, key=lambda entry: entry[2])
        else:
            # Nothing survived pruning (or max_iterations <= 0): answer from
            # the root alone instead of letting max([]) raise ValueError.
            best_path = [root]

        self._banner("🏆 BEST REASONING PATH")
        for i, node in enumerate(best_path):
            print(f"{i}. [{node.id}, Score: {node.score:.2f}] {node.content}")
        print()

        path_text = "\n".join([f"{i+1}. {n.content}" for i, n in enumerate(best_path)])

        final_prompt = f"""Problem: {problem}

        Reasoning path:
        {path_text}

        Provide final answer:"""

        final_answer = self.model.generate_content(final_prompt).text

        self._banner("💡 FINAL ANSWER")
        print(final_answer)
        print()

        self._visualize_graph()

        return final_answer

    def _get_path_to_node(self, node):
        """Return one root-to-``node`` path (root first).

        Walks backwards by always taking a node's first recorded predecessor,
        so for a merged node only its first parent's lineage is returned.
        """
        path = []
        current = node
        while current:
            path.append(current)
            current = current.predecessors[0] if current.predecessors else None
        return list(reversed(path))

    def _visualize_graph(self):
        """Print the node count and each node's outgoing edges as an adjacency list."""
        self._banner("📊 GRAPH STRUCTURE")
        print(f"Total nodes: {len(self.nodes)}")
        print("Connections:")
        for node_id, node in self.nodes.items():
            if node.successors:
                successors = ", ".join(n.id for n in node.successors)
                print(f"  {node_id} → {successors}")
        print()

    @staticmethod
    def _banner(title, char="="):
        """Print ``title`` between two 60-character rules made of ``char``."""
        rule = char * 60
        print(rule)
        print(title)
        print(rule)

    @staticmethod
    def _report(node):
        """Print a one-line summary of a scored node (content truncated to 60 chars)."""
        print(f"  {node.id}: [Score: {node.score:.2f}] {node.content[:60]}...")

In [None]:
def _run_example(heading, problem, *, leading_blank_line=True):
    """Print a section banner, then solve ``problem`` with a brand-new GoTAgent.

    Each example gets its own agent so node IDs and the printed graph are
    independent between examples. The agent is returned so its graph can be
    inspected afterwards.
    """
    rule = "=" * 60
    print("\n" + rule if leading_blank_line else rule)
    print(heading)
    print(rule)
    agent = GoTAgent()
    agent.solve(problem, max_iterations=3, branch_factor=2)
    return agent


# Example 1: Document Merging
got1 = _run_example(
    "EXAMPLE 1: Document Merging",
    "Merge insights from three reports: Report A says 'sales up 20%', "
    "Report B says 'customer satisfaction at 4.2/5', Report C says 'costs increased 15%'. "
    "What's the overall business health?",
    leading_blank_line=False,
)

# Example 2: Sorting with Rationale
got2 = _run_example(
    "EXAMPLE 2: Sorting with Rationale",
    "Sort these priorities for a startup: A) Customer acquisition, B) Product development, "
    "C) Fundraising, D) Team building. Consider dependencies and timing.",
)

# Example 3: Complex Reasoning
got3 = _run_example(
    "EXAMPLE 3: Complex Multi-Path Reasoning",
    "A company can invest in: AI research (high risk, high reward), "
    "market expansion (medium risk/reward), or cost optimization (low risk/reward). "
    "Budget allows 2 choices. Which combination is best?",
)

# Example 4: Knowledge Synthesis
got4 = _run_example(
    "EXAMPLE 4: Knowledge Synthesis",
    "Synthesize solution: Climate experts say 'reduce emissions 50% by 2030', "
    "economists say 'transition must be gradual to avoid disruption', "
    "technologists say 'renewable tech is now cost-effective'. What's the best approach?",
)

# Example 5: Strategic Planning
got5 = _run_example(
    "EXAMPLE 5: Strategic Decision with Contingencies",
    "Plan software release strategy. Options: A) Big bang release (all features at once), "
    "B) Phased rollout (gradual), C) Beta program first. Consider risks, user feedback, and resources.",
)

print("✅ Graph of Thoughts Complete!")