In [96]:
import numpy as np
import IPython

In [97]:
# Lookup table from the integer action code to a human-readable direction.
# The position in the tuple is the action code: 0 -> UP, 1 -> RIGHT,
# 2 -> DOWN, 3 -> LEFT.
action_to_direction = dict(enumerate(('UP', 'RIGHT', 'DOWN', 'LEFT')))

In [98]:
# Step 1: Define the environment
class GridWorld:
    """Static square grid world with a start cell, a goal cell and holes.

    Every cell holds a one-character label: 'S' (start, top-left corner),
    'G' (goal, bottom-right corner), 'H' (hole) or 'F' (frozen/safe).
    A state is the row-major index of the agent's cell, i.e.
    ``state = row * grid_size + col``.
    """

    def __init__(self, grid_size=4, hole_count=4):
        """Initialize the environment.

        This is a static grid world environment: the layout is generated
        once here and never changes afterwards.

        Args:
            grid_size (int): side length of the square grid (default is 4)
            hole_count (int): number of holes in the grid (default is 4)

        Raises:
            ValueError: if ``hole_count`` is larger than the number of cells
                that are neither the start nor the goal.
        """
        self.grid_size = grid_size
        self.state_space = np.arange(grid_size * grid_size)
        self.action_space = np.arange(4)  # 0: Up, 1: Right, 2: Down, 3: Left

        # Initialize the grid
        self.hole_count = hole_count
        self.state = 0
        # Fill the grid with frozen blocks, then mark start and goal
        self.grid = np.full((self.grid_size, self.grid_size), 'F')
        self.grid[0, 0] = 'S'
        self.grid[-1, -1] = 'G'

        # Only cells still labelled 'F' can become holes. Without this check
        # the rejection sampling below could never finish for an impossible
        # request.
        free_cells = int(np.count_nonzero(self.grid == 'F'))
        if hole_count > free_cells:
            raise ValueError(
                f"hole_count={hole_count} exceeds the {free_cells} cells "
                f"available in a {grid_size}x{grid_size} grid"
            )

        # Place the hole blocks randomly.
        # NOTE: this reseeds NumPy's *global* RNG, so the layout is the same
        # on every construction and later np.random calls are affected too.
        np.random.seed(0)  # For reproducibility
        for _ in range(self.hole_count):
            row, col = np.random.randint(self.grid_size, size=2)
            # Redraw until we hit a frozen cell: this skips the start, the
            # goal AND already-placed holes, so exactly hole_count distinct
            # holes end up on the grid.
            while self.grid[row, col] != 'F':
                row, col = np.random.randint(self.grid_size, size=2)
            self.grid[row, col] = 'H'

    def _get_grid_position(self):
        """Convert the agent's current state into a grid position.

        Takes no arguments; it reads ``self.state`` and ``self.grid_size``.

        Returns:
            tuple: The grid position as (row, column).
        """
        row = self.state // self.grid_size
        col = self.state % self.grid_size
        return (row, col)

    def reset(self):
        """Reset the environment to the initial state

        The grid stays the same, only the agent's state is reset to the start block.

        Returns:
            state (int): initial state
        """
        self.state = 0
        return self.state

    def step(self, action):
        """Take an action and return the next state and reward

        Moves that would leave the grid are clamped, so the agent stays on
        the border cell. An action code outside 0-3 leaves the position
        unchanged.

        Args:
            action (int): action to take (0: Up, 1: Right, 2: Down, 3: Left)

        Returns:
            state (int): next state
            reward (int): reward for the cell the agent lands on
        """
        row, col = self._get_grid_position()
        last_index = self.grid_size - 1
        if action == 0:  # Up
            row = max(row - 1, 0)
        elif action == 1:  # Right
            col = min(col + 1, last_index)
        elif action == 2:  # Down
            row = min(row + 1, last_index)
        elif action == 3:  # Left
            col = max(col - 1, 0)

        # Update the state before asking for the reward: the reward is
        # computed from the cell the agent has just moved onto.
        next_state = row * self.grid_size + col
        self.state = next_state

        reward = self._get_reward()

        return next_state, reward

    def _get_reward(self):
        """Return the reward based on the current state.

        Returns:
            reward (int): -1 for falling into a hole, 1 for reaching the goal, and 0 otherwise
        """
        row, col = self._get_grid_position()
        cell = self.grid[row, col]
        if cell == 'H':  # If the agent falls into a hole
            return -1
        elif cell == 'G':  # If the agent reaches the goal
            return 1
        else:  # Frozen block or the start block
            return 0

# Quick sanity check: build an environment with the default arguments,
# put the agent on the start state and print the generated layout.
env = GridWorld()
env.reset()  # return value (the initial state) is not needed here
print(env.grid)

[['S' 'F' 'F' 'H']
 ['H' 'F' 'H' 'H']
 ['F' 'F' 'F' 'F']
 ['F' 'F' 'F' 'G']]


In [99]:
# Step 4: Define a simple agent
class SimpleAgent:
    """Tabular Q-learning agent with an epsilon-greedy behaviour policy."""

    def __init__(self, num_states=16, num_actions=4, alpha=0.9, gamma=0.95, epsilon=0.5):
        """Create the agent with an all-zero Q table.

        Args:
            num_states (int): number of states. Defaults to 16 (4 x 4 grid).
            num_actions (int): number of actions. Defaults to 4.
            alpha (float, optional): learning rate. Defaults to 0.9.
            gamma (float, optional): discount factor. Defaults to 0.95.
            epsilon (float, optional): exploration rate. Defaults to 0.5.
        """
        self.num_states, self.num_actions = num_states, num_actions
        self.alpha, self.gamma, self.epsilon = alpha, gamma, epsilon

        # One row per state, one column per action. Each entry estimates the
        # expected return (not the immediate reward) of that pair.
        self.Q = np.zeros((num_states, num_actions))

    def get_epsilon_greedy_action(self, state):
        """Choose an action for ``state`` using the epsilon-greedy rule.

        With probability epsilon a uniformly random action is returned;
        otherwise one of the highest-valued actions is returned, with ties
        broken uniformly at random.

        Args:
            state (int): current state

        Returns:
            The index of the chosen action.
        """
        should_explore = np.random.rand() < self.epsilon
        if should_explore:
            return np.random.choice(self.num_actions)

        # Exploit: gather every action that shares the top Q value so that
        # ties do not always resolve to the lowest index.
        q_row = self.Q[state]
        greedy_actions = np.flatnonzero(q_row == q_row.max())
        return np.random.choice(greedy_actions)

    def update_Q(self, state, action, reward, next_state):
        """Apply one Q-learning update for an observed transition.

        Implements
        ``Q[s, a] += alpha * (r + gamma * max_a' Q[s', a'] - Q[s, a])``,
        i.e. a step towards the Bellman optimality target for Q-values.

        Args:
            state (int): current state
            action (int): action taken
            reward (int): reward received
            next_state (int): next state
        """
        # TD target: received reward plus the discounted value of the best
        # action available from the next state.
        best_next_value = self.Q[next_state].max()
        td_target = reward + self.gamma * best_next_value

        # TD error: how far the current estimate is from that target.
        td_error = td_target - self.Q[state, action]

        # Move the estimate a fraction alpha of the way towards the target.
        self.Q[state, action] += self.alpha * td_error

In [101]:
# Step 5: Training loop

# Build the environment, show its layout, then size the agent's Q table
# from the environment's state and action spaces.
env = GridWorld(grid_size=4, hole_count=4)
print(env.grid)
agent = SimpleAgent(env.state_space.size, env.action_space.size)

# Training parameters
num_episodes = 10000
max_steps_per_episode = 100

for episode in range(num_episodes):
    # Every episode starts from the start block; the grid itself is unchanged
    state = env.reset()

    for step in range(max_steps_per_episode):
        # Act, observe the transition, learn from it, then advance
        action = agent.get_epsilon_greedy_action(state)
        next_state, reward = env.step(action)
        agent.update_Q(state, action, reward, next_state)
        state = next_state

        # A non-zero reward means the episode is over: -1 is a hole, 1 is the goal
        if reward in (-1, 1):
            break

    # Report progress once every 1000 episodes
    if (episode + 1) % 1000 != 0:
        continue
    print(f"Episode {episode + 1}/{num_episodes} completed")

print("Training finished.")
print("Q table:")
print(agent.Q)


[['S' 'F' 'F' 'H']
 ['H' 'F' 'H' 'H']
 ['F' 'F' 'F' 'F']
 ['F' 'F' 'F' 'G']]
Episode 1000/10000 completed
Episode 2000/10000 completed
Episode 3000/10000 completed
Episode 4000/10000 completed
Episode 5000/10000 completed
Episode 6000/10000 completed
Episode 7000/10000 completed
Episode 8000/10000 completed
Episode 9000/10000 completed
Episode 10000/10000 completed
Training finished.
Q table:
[[ 0.73509189  0.77378094 -1.          0.73509189]
 [ 0.77378094  0.73509189  0.81450625  0.73509189]
 [ 0.73509189 -1.         -1.          0.77378094]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.77378094 -1.          0.857375   -1.        ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]
 [-1.          0.857375    0.857375    0.81450625]
 [ 0.81450625  0.9025      0.9025      0.81450625]
 [-1.          0.95        0.95        0.857375  ]
 [-1.          0.95        1.          0.90