In [None]:
# Path setup / Imports

import sys
# Make the repository root importable (added once, at highest priority) so the
# project-local modules used below (node, transition, reward, generator, policy, utils) resolve.
repo_root = "../"
if not any(entry == repo_root for entry in sys.path):
    sys.path.insert(0, repo_root)

from typing import Callable, List, Self
import numpy as np

In [None]:
# Define a Node class

from node import Node

class NumberNode(Node):
    """Node class that stores its number as a string.

    Keeping the value as a string lets the node represent states that are not
    (yet) valid numbers, such as the empty root "" or strings with a leading zero.
    """
    def __init__(self, number: str, parent=None, last_action: int | None = None, last_prob=1.0):
        """Create a node.

        Args:
            number: Digit string held by this node ("" for the root).
            parent: Parent node, or None for the root.
            last_action: Action label of the transition that produced this node, if any.
            last_prob: Probability of the transition that produced this node.
        """
        # Assigned before Node.__init__; order kept in case the base initializer
        # reads node state (not visible here).
        self.number = number
        super().__init__(parent=parent, last_prob=last_prob, last_action=last_action)

    # Define the reward condition
    def has_reward(self) -> bool:
        """Return True if the string is a complete number.

        The empty string ("") and strings like "0X...X" (a leading zero on a
        multi-digit string) are incomplete as numbers and have no reward.
        """
        if not self.number:
            return False
        return self.number == "0" or not self.number.startswith("0")

    # Keys of the generated nodes will be recorded
    def key(self) -> str:
        """Return the node's digit string, used as its key."""
        return self.number

    # Create a Node instance from a key
    @classmethod
    def node_from_key(cls, key: str, parent=None, last_action: int | None = None, last_prob=1.0) -> Self:
        """Build a node of this class (or subclass) from a key produced by key()."""
        # Use cls rather than a hard-coded NumberNode so subclasses get instances of
        # their own type, matching the Self return annotation.
        return cls(key, parent=parent, last_prob=last_prob, last_action=last_action)

In [None]:
# Define a Transition class

from transition import Transition

class NumberTransition(Transition):
    """Transition class for NumberNode that prepends each allowed number to the beginning of the string.

    Every child is given the same transition probability, 1 / len(numbers).
    """

    # It is recommended to handle settings in the __init__() method for YAML compatibility.
    def __init__(self, numbers: list[int]=[0,1,2,3,4,5,6,7,8,9]):
        """
        Args:
            numbers: Values that may be prepended to a node's string (default: digits 0-9).
        """
        # Store a copy: the default list object is shared by every call that omits
        # `numbers`, so keeping a reference would let one instance's mutation leak
        # into all others (and callers' later edits would leak in too).
        self.numbers = list(numbers)

    def next_nodes(self, node: NumberNode) -> list[Node]:
        """Return one child per entry in self.numbers, with that entry prepended to node.number.

        Each child records the prepended value as its action label and the uniform
        transition probability. Returns an empty list ([]) when no transitions are available.
        """
        if not self.numbers:
            return []
        # Uniform probability is loop-invariant; the guard above avoids division by zero.
        prob = 1 / len(self.numbers)
        # The transition probability (and the action label, if any) is stored in the node instance.
        return [
            NumberNode(number=str(n) + node.number, parent=node, last_action=n, last_prob=prob)
            for n in self.numbers
        ]

In [None]:
# Define a Reward class

from reward import Reward

class NumberReward(Reward):
    """Reward class for NumberNode that returns a higher value for numbers with higher multiplicity of the given factor, and with fewer digits.

    reward = tanh((multiplicity of `factor` in the number - number of digits) / scale)
    """
    def __init__(self, factor: int=5, scale: float=1.0):
        """
        Args:
            factor: Integer whose multiplicity in the number is counted; must be >= 2.
            scale: Divisor applied to the objective difference before tanh; must be non-zero.

        Raises:
            ValueError: If factor < 2 (factor 1 would make the counting loop never
                terminate, 0 would divide by zero, negatives miscount) or if scale == 0.
        """
        if factor < 2:
            raise ValueError(f"factor must be an integer >= 2, got {factor!r}")
        if scale == 0:
            raise ValueError("scale must be non-zero")
        self.factor = factor
        self.scale = scale

    # Return objective functions of the node; each function returns an objective value.
    def objective_functions(self) -> List[Callable[[NumberNode], float]]:
        """Return [factor_count, length] objective functions for a NumberNode."""
        def factor_count(node: Node) -> float:
            """How many times self.factor divides int(node.number); 0 for the number 0."""
            n = int(node.number)
            count = 0
            # Check n > 0 first so the number 0 (divisible by anything) stops immediately.
            while n > 0 and n % self.factor == 0:
                n //= self.factor
                count += 1
            return count

        def length(node: Node) -> float:
            """Number of characters (digits) in node.number."""
            return len(node.number)

        return [factor_count, length]

    # Compute the final reward based on the objective values calculated by objective_functions().
    def reward_from_objective_values(self, objective_values: List[float]) -> float:
        """Combine [factor_count, length] into a reward in (-1, 1) via tanh."""
        factor_count, length = objective_values[0], objective_values[1]
        return np.tanh((factor_count - length) / self.scale)

In [None]:
# Run the MCTS generation using the components defined above
# (To enable YAML-based loading, each class definition should be placed in its corresponding module directory and registered in that directory's __init__.py file.)

from datetime import datetime
import os
from generator import MCTS
from policy import UCT
from utils import make_logger

# Output directory: <repo_root>/sandbox/generation_result/<MM-DD_HH-MM>/
# (a trailing path separator is appended), plus a logger writing there.
output_dir = os.path.join(
    repo_root,
    "sandbox/generation_result",
    datetime.now().strftime("%m-%d_%H-%M"),
) + os.sep
logger = make_logger(output_dir)

# Search components: the empty-string root node, digit-prepending transitions,
# a factor-of-5 reward, and the UCT selection policy.
root = NumberNode.node_from_key("")
transition = NumberTransition()
reward = NumberReward(factor=5, scale=1.0)
policy = UCT(c=0.1, best_rate=0.5)

# Assemble the MCTS generator from the components above.
generator = MCTS(
    root=root,
    transition=transition,
    reward=reward,
    policy=policy,
    filters=None,
    info_interval=100,
    output_dir=output_dir,
    logger=logger,
)

# Run the search; generation results are saved under output_dir.
generator.generate(max_generations=2000)

In [None]:
# Analyze and visualize the result

generator.analyze()
# Besides displaying them, plot() also saves the figures to the output directory.
generator.plot(
    moving_average_window=0.1,
    reward_top_ps=[0.1, 0.5],
)