In [None]:
import numpy as np

# Grid World environment
class GridWorld:
    """Square grid world with deterministic moves and one terminal cell.

    States are ``(x, y)`` tuples: ``x`` is the row (changed by ``'up'`` /
    ``'down'``) and ``y`` is the column (changed by ``'left'`` / ``'right'``).
    A move that would leave the grid keeps the agent on the border. Every step
    that does not land on the terminal cell (bottom-right corner) yields a
    reward of -1; landing on it yields 0 and ends the episode.

    Args:
        grid_size: Side length of the grid; must be at least 1.
        start: Cell the agent starts in and returns to on ``reset()``.
            Defaults to the top-left corner ``(0, 0)``.

    Raises:
        ValueError: If ``grid_size`` < 1 or ``start`` is not a cell of the grid.
    """

    # (row, column) offset applied by each action; insertion order defines the
    # order of ``self.actions`` ('up', 'down', 'left', 'right').
    _MOVES = {
        'up': (-1, 0),
        'down': (1, 0),
        'left': (0, -1),
        'right': (0, 1),
    }

    def __init__(self, grid_size, start=(0, 0)):
        if grid_size < 1:
            raise ValueError(f"grid_size must be >= 1, got {grid_size}")
        start = tuple(start)
        if len(start) != 2 or not all(0 <= c < grid_size for c in start):
            raise ValueError(f"start {start} is outside a {grid_size}x{grid_size} grid")
        self.grid_size = grid_size
        self.start_state = start
        self.state = start
        self.terminal_state = (grid_size - 1, grid_size - 1)
        self.actions = list(self._MOVES)

    def reset(self):
        """Put the agent back on the start cell and return that state."""
        self.state = self.start_state
        return self.state

    def step(self, action):
        """Apply ``action`` and return ``(next_state, reward, done)``.

        Raises:
            ValueError: If ``action`` is not one of ``self.actions``.
        """
        try:
            dx, dy = self._MOVES[action]
        except (KeyError, TypeError):
            # Previously an unknown action was silently a -1-cost no-op.
            raise ValueError(
                f"unknown action {action!r}; expected one of {self.actions}"
            ) from None
        x, y = self.state
        last = self.grid_size - 1
        # Clamp to the grid so walking into a wall leaves the agent in place.
        self.state = (min(max(x + dx, 0), last), min(max(y + dy, 0), last))
        done = self.state == self.terminal_state
        reward = 0 if done else -1
        return self.state, reward, done

def td_learning(env, episodes, alpha, gamma, rng=None):
    """Estimate state values of the uniform-random policy with tabular TD(0).

    Each episode starts from ``env.reset()`` and takes actions drawn uniformly
    at random from ``env.actions`` until ``env.step`` reports ``done``. After
    every transition the value of the state just left is moved a fraction
    ``alpha`` toward the one-step TD target ``reward + gamma * V(next_state)``
    (just ``reward`` when the transition ends the episode).

    Args:
        env: Environment exposing ``grid_size``, ``actions``, ``reset()`` and
            ``step(action) -> (next_state, reward, done)``, whose states are
            ``(row, col)`` pairs indexing a ``grid_size x grid_size`` table.
        episodes: Number of episodes to run.
        alpha: Learning rate (TD step size).
        gamma: Discount factor.
        rng: Optional int seed or ``numpy.random.Generator`` used to sample
            actions, for reproducible runs. ``None`` (default) uses the global
            ``numpy.random`` state, exactly as before.

    Returns:
        A ``(grid_size, grid_size)`` float array of estimated state values.
    """
    sampler = np.random if rng is None else np.random.default_rng(rng)
    grid_size = env.grid_size
    state_values = np.zeros((grid_size, grid_size))  # V(s), initialised to 0

    for _ in range(episodes):
        state = env.reset()
        while True:
            action = sampler.choice(env.actions)
            next_state, reward, done = env.step(action)

            x, y = state
            nx, ny = next_state
            # A terminal transition has no future return, so don't bootstrap
            # from the next state's (possibly stale) value estimate.
            future = 0.0 if done else gamma * state_values[nx, ny]
            td_error = reward + future - state_values[x, y]
            state_values[x, y] += alpha * td_error

            if done:
                break
            state = next_state

    return state_values

# Experiment settings
grid_size, episodes = 4, 1000
alpha = 0.1  # Learning rate (TD step size)
gamma = 0.9  # Discount factor applied to the next state's value

# Learn V(s) for the random policy on a grid_size x grid_size world, then
# show it rounded to two decimals under a heading.
env = GridWorld(grid_size)
state_values = td_learning(env, episodes, alpha, gamma)

print("Learned State-Value Function:", np.round(state_values, 2), sep="\n")


Learned State-Value Function:
[[-9.26 -9.1  -8.73 -8.42]
 [-9.09 -8.73 -8.44 -7.83]
 [-8.76 -8.16 -6.49 -4.78]
 [-8.23 -6.75 -3.79  0.  ]]
