In [4]:
import numpy as np

# TD(0) algorithm implementation
def td_zero(env, policy, alpha=0.1, gamma=0.9, num_episodes=100):
    """
    Estimate the state-value function of a fixed policy with tabular TD(0).

    After every transition (s, r, s') the estimate for s is moved towards the
    one-step bootstrapped target:

        V(s) <- V(s) + alpha * (r + gamma * V(s') - V(s))

    where V(s') is taken to be 0 when s' is terminal.

    Args:
        env: Environment exposing ``states`` (iterable of states to
            pre-initialise), ``reset()`` -> start state,
            ``step(state, action)`` -> ``(reward, next_state)`` and
            ``is_terminal(state)`` -> bool.
        policy: Policy to evaluate (callable mapping a state to an action).
        alpha: Step size (learning rate).
        gamma: Discount factor for future rewards.
        num_episodes: Number of episodes to run.

    Returns:
        V: dict mapping each state to its estimated value. It contains every
        state in ``env.states`` plus any state reached during the episodes.
    """
    # V(s) = 0 for all known states; states first encountered during the
    # rollouts are added lazily with the same default.
    V = {state: 0.0 for state in env.states}

    for _ in range(num_episodes):
        state = env.reset()
        # reset() may return a state that is not listed in env.states; without
        # this, the update below would raise KeyError on V[state].
        V.setdefault(state, 0.0)

        while not env.is_terminal(state):
            action = policy(state)
            reward, next_state = env.step(state, action)
            V.setdefault(next_state, 0.0)

            # A terminal state's value is 0 by definition, so never bootstrap
            # from whatever happens to be stored for it.
            if env.is_terminal(next_state):
                target = reward
            else:
                target = reward + gamma * V[next_state]
            V[state] += alpha * (target - V[state])

            state = next_state

    return V


class ExampleEnv:
    """Deterministic toy chain environment used to demonstrate TD(0).

    Every episode follows A -> B -> C -> terminal with rewards 1, 2 and 0
    respectively. The action passed to ``step`` is ignored.
    """

    # Transition table: state -> (next state, reward). Built once here rather
    # than being rebuilt on every ``step`` call.
    _TRANSITIONS = {
        'A': ('B', 1),  # next state 'B', reward 1
        'B': ('C', 2),  # next state 'C', reward 2
        'C': ('terminal', 0),  # terminal state 'terminal', reward 0
    }

    def __init__(self):
        self.states = ['A', 'B', 'C', 'terminal']
        self.terminal_state = 'terminal'

    def reset(self):
        """Return the fixed starting state ``'A'``."""
        return 'A'

    def step(self, state, action):
        """Apply one transition from *state*.

        Args:
            state: Current state; must have an outgoing transition.
            action: Ignored; the dynamics do not depend on it.

        Returns:
            Tuple ``(reward, next_state)``.

        Raises:
            ValueError: If *state* has no outgoing transition (this includes
                the terminal state).
        """
        try:
            next_state, reward = self._TRANSITIONS[state]
        except KeyError:
            raise ValueError(f"Invalid state: {state}") from None
        return reward, next_state

    def is_terminal(self, state):
        """Return True iff *state* is the terminal state."""
        return state == self.terminal_state


def simple_policy(state):
    """Deterministic example policy that picks ``'move'`` in every state.

    Args:
        state: Current state (ignored; the choice does not depend on it).

    Returns:
        The action string ``'move'``.
    """
    chosen_action = 'move'
    return chosen_action

# Build the toy environment and the policy under evaluation.
env = ExampleEnv()
policy = simple_policy

# Evaluate the policy with TD(0), then show the learned state values
# (header line followed by the dict on its own line).
value_function = td_zero(env, policy)
print("Estimated value function:", value_function, sep="\n")


Estimated value function:
{'A': 2.7993944001053634, 'B': 1.9999468772022246, 'C': 0.0, 'terminal': 0}
