<a href="https://colab.research.google.com/github/katharguppe/BITS_Pilani_Final/blob/master/COAT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:

# Install required libraries
!pip install anytree  # For tree visualization

import math
import random
from collections import defaultdict

from anytree import Node, RenderTree  # To visualize the reasoning tree

# === Step 1: Define the Associative Memory Mechanism ===
class AssociativeMemory:
    """
    A simple associative memory mechanism that dynamically stores and retrieves relevant information.

    Associations are grouped per key in insertion order; the same value may be
    stored more than once under a key.
    """
    def __init__(self):
        # key -> list of associated values, in the order they were added.
        self.memory = defaultdict(list)

    def add_association(self, key, value):
        """
        Adds a new association to the memory.
        :param key: The context or topic to associate with (used as a dict key, so it must be hashable).
        :param value: The associated information. Appended after any existing values for ``key``.
        """
        self.memory[key].append(value)

    def retrieve_associations(self, key):
        """
        Retrieves all associations related to a given key.
        :param key: The context or topic to retrieve associations for.
        :return: A new list of the associated information, in insertion order
            (empty if the key is unknown). Mutating the returned list does not
            change what is stored in the memory.
        """
        # .get() instead of [] so looking up an unknown key does not insert an
        # empty entry into the defaultdict; list() hands back a copy so callers
        # cannot alter the stored associations through the result.
        return list(self.memory.get(key, ()))

# === Step 2: Define the Optimized MCTS Algorithm ===
class MCTSTreeNode:
    """
    A single vertex of the Monte Carlo Tree Search (MCTS) tree.

    Attributes (plain public fields, updated directly by the search helpers):
      name     -- label of the reasoning step this node represents
      parent   -- the node directly above, or None for the root
      children -- list of nodes directly below, in insertion order
      visits   -- visit counter, starts at 0
      value    -- accumulated score, starts at 0
    """
    def __init__(self, name, parent=None):
        # Search statistics begin at zero.
        self.visits = 0
        self.value = 0
        # Position in the tree.
        self.name = name
        self.parent = parent
        self.children = []

    def add_child(self, child_node):
        """
        Append ``child_node`` to this node's ``children`` list.

        Only this node's list is updated; ``child_node.parent`` is left untouched.
        :param child_node: The node to attach beneath this one.
        """
        self.children += [child_node]

def uct_score(node, exploration_weight=1.41):
    """
    Calculates the UCT score (UCB1 applied to trees) for a node in MCTS.

        score = value / n + c * sqrt(ln(N_parent) / n)

    where ``n`` is the node's visit count, ``N_parent`` its parent's visit count
    and ``c`` is ``exploration_weight`` (the default 1.41 ~ sqrt(2) is the
    classic UCB1 constant).
    :param node: The node to calculate the score for; must expose ``visits``,
        ``value`` and ``parent``.
    :param exploration_weight: Weight for exploration vs exploitation.
    :return: UCT score. ``float('inf')`` for an unvisited node so it is tried
        first; just the mean value for a node with no parent (nothing to
        compare against, so no exploration bonus).
    """
    if node.visits == 0:
        return float('inf')  # Encourage exploration of unvisited nodes
    exploitation = node.value / node.visits
    if node.parent is None:
        return exploitation
    # Clamp to 1 so log() stays defined even if the parent's counter is 0;
    # ln(1) == 0 simply removes the exploration bonus in that case.
    parent_visits = max(node.parent.visits, 1)
    exploration = exploration_weight * math.sqrt(math.log(parent_visits) / node.visits)
    return exploitation + exploration

def select_best_child(node):
    """
    Pick the child of ``node`` with the highest ``uct_score`` (default weight).

    When several children tie, the one appearing first in ``node.children``
    wins (``max`` keeps the first maximal element). An empty child list
    raises ``ValueError`` from ``max``.
    :param node: The parent node; expected to have at least one child.
    :return: The best child node.
    """
    candidates = node.children
    return max(candidates, key=uct_score)

def expand_node(node, possible_actions):
    """
    Grow ``node`` by one level: create one child per action, in the given order.

    Each new child is an ``MCTSTreeNode`` named after its action with ``node``
    as its parent, and is appended to ``node.children``.
    :param node: The node to expand.
    :param possible_actions: Iterable of action labels (reasoning steps).
    """
    fresh_children = (MCTSTreeNode(name=label, parent=node) for label in possible_actions)
    for child in fresh_children:
        node.add_child(child)

def simulate(node, associative_memory, query):
    """
    Score a rollout for ``query`` (a toy stand-in for real reasoning).

    The score depends only on the query. It is the number of associations
    stored for ``query`` when there are any. Otherwise it is a uniformly
    random integer in [1, 5]. ``node`` is accepted but not read by this
    implementation.
    :param node: The starting node for simulation (unused).
    :param associative_memory: Object providing ``retrieve_associations(query)``.
    :param query: The input query for reasoning.
    :return: Integer score.
    """
    related = associative_memory.retrieve_associations(query)
    if not related:
        return random.randint(1, 5)  # No associations: fall back to a random score.
    return len(related)  # More associations -> higher score.

def backpropagate(node, result):
    """
    Push a simulation result from ``node`` up to the root.

    Every node on the chain node -> parent -> ... -> root gets one more visit
    and has ``result`` added to its value.
    :param node: Node where the simulation ended; the walk starts here.
    :param result: Score to add to each node along the chain.
    """
    def _ancestors_inclusive(start):
        # Yield ``start`` and then each successive parent until the chain ends.
        while start is not None:
            yield start
            start = start.parent

    for ancestor in _ancestors_inclusive(node):
        ancestor.visits += 1
        ancestor.value += result

# === Step 3: Implement the CoAT Framework ===
class CoATFramework:
    """
    Simplified implementation of the Chain-of-Associated-Thoughts (CoAT) framework.

    Combines an ``AssociativeMemory`` (used to score simulations) with a Monte
    Carlo Tree Search over generic reasoning steps. The search tree lives on
    ``self.root``; it is created once in ``__init__`` and is not reset by
    ``reason``, so statistics from earlier calls (possibly for other queries)
    carry over into later ones.
    """
    def __init__(self):
        self.associative_memory = AssociativeMemory()
        self.root = MCTSTreeNode(name="Root")  # Root node of the reasoning tree

    def reason(self, query, max_depth=5, iterations=10, actions=None):
        """
        Performs reasoning using the CoAT framework.
        :param query: The input query for reasoning.
        :param max_depth: Maximum depth of the reasoning tree (the root is at
            depth 0; no node is created deeper than ``max_depth``).
        :param iterations: Number of MCTS iterations.
        :param actions: Labels for the children created on each expansion.
            Defaults to ["Step 1", "Step 2", "Step 3"].
        :return: The best reasoning path found, as a list of node names. It
            starts at a child of the root and follows the most-visited child at
            each level, stopping before any node that was never visited.
        :raises ValueError: If ``actions`` is given but empty.
        """
        if actions is None:
            actions = [f"Step {i}" for i in range(1, 4)]  # Example actions
        # Materialize once so a one-shot iterable can be reused on every expansion.
        actions = list(actions)
        if not actions:
            raise ValueError("actions must contain at least one step")

        for _ in range(iterations):
            # Selection: Traverse the tree to find the best leaf node
            current_node, depth = self._select(max_depth)

            # Expansion: Add new child nodes if not at max depth
            if depth < max_depth:
                expand_node(current_node, actions)
                current_node = random.choice(current_node.children)  # Choose one child

            # Simulation: Simulate reasoning from the selected node
            result = simulate(current_node, self.associative_memory, query)

            # Backpropagation: Update the tree with the simulation result
            backpropagate(current_node, result)

        return self._best_path()

    def _select(self, max_depth):
        """Descend from the root by UCT until a leaf or ``max_depth``; return (node, depth)."""
        current_node = self.root
        depth = 0
        while current_node.children and depth < max_depth:
            current_node = select_best_child(current_node)
            depth += 1
        return current_node, depth

    def _best_path(self):
        """Follow the most-visited child from the root, stopping before unvisited nodes."""
        best_path = []
        current_node = self.root
        while current_node.children:
            # The final answer uses visit counts ("robust child"), with total
            # value as tie-breaker, instead of UCT. UCT's exploration bonus,
            # and its +inf for unvisited nodes, would otherwise steer the path
            # into steps that were never simulated.
            best_child = max(current_node.children, key=lambda child: (child.visits, child.value))
            if best_child.visits == 0:
                break  # Nothing below this point has been explored.
            current_node = best_child
            best_path.append(current_node.name)
        return best_path

# === Step 4: Demonstrate the CoAT Framework ===
if __name__ == "__main__":
    # Build the framework and seed its memory with three associations for "AI".
    coat = CoATFramework()
    for related_term in ("Artificial Intelligence", "Machine Learning", "Deep Learning"):
        coat.associative_memory.add_association("AI", related_term)

    # Run the tree search for the query.
    query = "AI"
    reasoning_path = coat.reason(query, max_depth=3, iterations=20)

    # Dump every node of the search tree together with its statistics.
    print("Reasoning Tree:")
    for prefix, _fill, tree_node in RenderTree(coat.root):
        print(f"{prefix}{tree_node.name} (Visits: {tree_node.visits}, Value: {tree_node.value})")

    # Show the path the search settled on.
    print("\nBest Reasoning Path:")
    print(" -> ".join(reasoning_path))

Collecting anytree
  Downloading anytree-2.12.1-py3-none-any.whl.metadata (8.1 kB)
Downloading anytree-2.12.1-py3-none-any.whl (44 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.9/44.9 kB 2.5 MB/s eta 0:00:00
Installing collected packages: anytree
Successfully installed anytree-2.12.1
Reasoning Tree:
Root (Visits: 20, Value: 60)
├── Step 1 (Visits: 7, Value: 21)
│   ├── Step 1 (Visits: 3, Value: 9)
│   │   ├── Step 1 (Visits: 1, Value: 3)
│   │   ├── Step 2 (Visits: 1, Value: 3)
│   │   └── Step 3 (Visits: 1, Value: 3)
│   ├── Step 2 (Visits: 2, Value: 6)
│   │   ├── Step 1 (Visits: 1, Value: 3)
│   │   ├── Step 2 (Visits: 1, Value: 3)
│   │   └── Step 3 (Visits: 0, Value: 0)
│   └── Step 3 (Visits: 2, Value: 6)
│       ├── Step 1 (Visits: 0, Value: 0)
│       ├── Step 2 (Visits: 0, Value: 0)
│       └── Step 3 (Visits: 1, Val