In [1]:
# __future__ import should always be first
from __future__ import annotations

# Standard library imports
from collections import defaultdict

# Third-party imports
import torch
import numpy as np
import torch.nn as nn
from torch.distributions.categorical import Categorical

# Gymnasium & Minigrid imports
import gymnasium as gym  # Correct way to import Gymnasium
from minigrid.core.constants import COLOR_NAMES
from minigrid.core.constants import DIR_TO_VEC
from minigrid.core.grid import Grid
from minigrid.core.actions import Actions
from minigrid.core.mission import MissionSpace
from minigrid.core.world_object import Door, Goal, Key, Wall
from minigrid.manual_control import ManualControl
from minigrid.minigrid_env import MiniGridEnv
from gymnasium.utils.play import play

# Visualization imports
import matplotlib.pyplot as plt


# Environment Set Up

In [625]:
# turning head direction turned off, but code is glitchy
class SimpleEnv(MiniGridEnv):
    def __init__(self, size=10, agent_start_pos=(1, 8), agent_start_dir=0, max_steps=256, **kwargs):
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir
        self.reached_goal = False
        self.step_count = 0
        
        

        mission_space = MissionSpace(mission_func=self._gen_mission)

        super().__init__(
            mission_space=mission_space,
            grid_size=size,
            max_steps=max_steps,
            **kwargs,
        )

        # Only allow Forward, Left+Forward, Right+Forward
        self.action_space = gym.spaces.Discrete(3)  # 3 actions

    @staticmethod
    def _gen_mission():
        return "Find the shortest path"

    def _gen_grid(self, width, height):
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

        self.put_obj(Goal(), 8, 1)

        for i in range(1, width // 2):
            self.grid.set(i, width - 4, Wall())
            self.grid.set(i + width // 2 - 1, width - 7, Wall())

        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos #check this
            
        else:
            self.place_agent()


    
    def reset(self, **kwargs):
        obs = super().reset(**kwargs)
        self.reached_goal = False
        self.agent_dir = 0
        return obs
    

    def get_view_exts(self, agent_view_size=None):
        """Override default view extensions to bypass direction checks."""
        agent_view_size = agent_view_size or self.agent_view_size
        topX = self.agent_pos[0] - agent_view_size // 2
        topY = self.agent_pos[1] - agent_view_size // 2
        botX = topX + agent_view_size
        botY = topY + agent_view_size
        return topX, topY, botX, botY
    
    
    
    @property
    def dir_vec(self):
        """Override MiniGrid's default direction vector."""
        return np.array([0, 1])  # Always move upwards by default
    
    
    
    
    def count_states(self):
        free_cells = sum(
            1 for x in range(self.grid.width)
            for y in range(self.grid.height)
            if not self.grid.get(x, y)
        )
        return free_cells  

    def step(self, action):
        """Modify step function to ensure correct movement (no diagonal jumps, no wall clipping)."""
        self.agent_dir = 0  # ✅ Keep agent direction fixed

        # ✅ Define movement vectors: (dx, dy)
        move_vectors = {
            0: (0, -1),   # Move Forward (UP)
            1: (-1, 0),   # Move Left
            2: (1, -1)     # Move Right
        }

        # Get movement vector for action
        move_vector = move_vectors.get(action, (0, 0))  # Default: no movement if invalid action
        
        # Compute the new position
        new_x = self.agent_pos[0] + move_vector[0]
        new_y = self.agent_pos[1] + move_vector[1]

        # ✅ Ensure movement respects grid boundaries & walls
        if (0 <= new_x < self.grid.width) and (0 <= new_y < self.grid.height):
            cell_contents = self.grid.get(new_x, new_y)  # Check what's in the new position

            if cell_contents is None or isinstance(cell_contents, Goal):  
                # ✅ Only update if the move is valid
                self.agent_pos = (new_x, new_y)  
   

        # ✅ Call MiniGrid's original step function
        obs, reward, terminated, truncated, info = super().step(action)

        # ✅ Ensure agent never moves out of bounds
        x, y = self.agent_pos
        x = max(0, min(x, self.grid.width - 1))
        y = max(0, min(y, self.grid.height - 1))
        self.agent_pos = (x, y)

        # ✅ Check if goal is reached
        if np.array_equal(self.agent_pos, (8, 1)):  
            self.reached_goal = True

        reward = 0 if self.reached_goal else -1  # Assign reward
        return obs, reward, terminated, truncated, info



# 3 Action with Directions

In [170]:

class SimpleEnv(MiniGridEnv):
    def __init__(
            self, 
            size=10, 
            agent_start_pos=(1, 8), 
            agent_start_dir=0, 
            max_steps=256, 
            **kwargs,
    ):
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir
        self.goal_pos = (8, 1)
        
        
        
        mission_space = MissionSpace(mission_func=self._gen_mission)

        super().__init__(
            mission_space=mission_space,
            grid_size=size,
            max_steps=max_steps,
            **kwargs,
        )

        self.action_space = gym.spaces.Discrete(3)
    @staticmethod
    def _gen_mission():
        return "Find the shortest path"

    def _gen_grid(self, width, height):
        #create gird
        self.grid = Grid(width, height)
        #place barrier
        self.grid.wall_rect(0, 0, width, height)
        #place goal
        self.put_obj(Goal(), 8, 1)
        #place walls
        for i in range(1, width // 2):
            self.grid.set(i, width - 4, Wall())
            self.grid.set(i + width // 2 - 1, width - 7, Wall())
        #place agent
        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos #check this
            self.agent_dir = self.agent_start_dir
        else:
            self.place_agent()

        self.mission = "find the shortest path"
    
    def count_states(self):
        free_cells = sum(1 for x in range(self.grid.width)
                      for y in range(self.grid.height)
                      if not self.grid.get(x, y)) * 4
        return free_cells 


# Manual Environment Testing

In [173]:
env = SimpleEnv(render_mode="human") 

In [175]:
env = SimpleEnv(render_mode="human")  # Create environment
free_energy_solver = FreeEnergyMin(env, beta=0.5)
obs = env.reset()[0]  # Reset environment and get initial state
env.render()  # Display initial state


Number of states: 400
Number of actions: 3
policy shape: (400, 3)
marginal action distribution shape: (3,)


In [15]:
env.count_states()

220

In [177]:
obs, reward, done, truncated, info = env.step(0)  #turn left
env.render()

In [111]:
obs, reward, done, truncated, info = env.step(1) #turn right
env.render()

In [123]:
obs, reward, done, truncated, info = env.step(2) #move forward
env.render()

In [94]:
free_energy_solver.position_to_state_index(obs)

np.int64(328)

In [96]:
free_energy_solver.state_index_to_position(328)

(2, 8, 0)

In [125]:
obs

{'image': array([[[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [2, 5, 0],
         [1, 0, 0]],
 
        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [2, 5, 0],
         [1, 0, 0]],
 
        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [2, 5, 0],
         [1, 0, 0]],
 
        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [2, 5, 0],
         [1, 0, 0]],
 
        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [2, 5, 0],
         [2, 5, 0]],
 
        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0]],
 
        [[0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 0, 0],
         [0, 

In [181]:
#test env
env = SimpleEnv(render_mode=None) 
env.reset()
# free_energy_solver = FreeEnergyMin(env, beta=0.5)
# free_energy_solver.estimate_transitions() #for when you want to do a random walk, no learning


({'image': array([[[1, 0, 0],
          [1, 0, 0],
          [1, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0]],
  
         [[1, 0, 0],
          [1, 0, 0],
          [1, 0, 0],
          [2, 5, 0],
          [2, 5, 0],
          [2, 5, 0],
          [2, 5, 0]],
  
         [[1, 0, 0],
          [1, 0, 0],
          [1, 0, 0],
          [1, 0, 0],
          [1, 0, 0],
          [1, 0, 0],
          [1, 0, 0]],
  
         [[1, 0, 0],
          [1, 0, 0],
          [1, 0, 0],
          [1, 0, 0],
          [1, 0, 0],
          [1, 0, 0],
          [1, 0, 0]],
  
         [[2, 5, 0],
          [2, 5, 0],
          [2, 5, 0],
          [2, 5, 0],
          [2, 5, 0],
          [2, 5, 0],
          [2, 5, 0]],
  
         [[0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0]],
  
         [[0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
          [0, 0, 0],
 

# Define a Policy Based On Information-to-Go Method

In [76]:
env.action_space.n * env.count_states() #number of actions * number of states

np.int64(660)

In [77]:
env.count_states()

220

In [78]:
free_energy_solver.num_states

np.int64(220)

In [34]:
free_energy_solver.num_actions

np.int64(3)

In [85]:
env.agent_pos

(1, 8)

In [86]:
free_energy_solver.num_states

np.int64(220)

In [88]:
class FreeEnergyMin:
    """Free Energy Minimization to find optimal policy in MiniGrid Environment."""

    def __init__(
            self, 
            env, 
            beta=0.5
    ):
        
        self.env = env  # MiniGrid environment
        self.num_states = np.int64(env.count_states())  # Number of states
        print(f"Number of states: {self.num_states}")
        self.num_actions = env.action_space.n  # 3 actions (Forward, Left+Forward, Right+Forward)
        print(f"Number of actions: {self.num_actions}")
        self.beta = beta  # Temperature parameter

        # Initialize policy (uniform distribution)
        self.Pi_a_s = np.full((self.num_states, self.num_actions), 1 / self.num_actions) 
        print(f"policy shape: {self.Pi_a_s.shape}")
        self.Pi_a = np.full(self.num_actions, 1 / self.num_actions)
        print(f"marginal action distribution shape: {self.Pi_a.shape}")
        

        
    
    def position_to_state_index(self, state):
        """Converts (x, y, direction) into a unique state index."""
        grid_width = self.env.grid.width
        grid_height = self.env.grid.height
        x, y = self.env.agent_pos  
        direction = self.env.agent_dir  

        return np.int64((y * grid_width + x) * 4 + direction)

    def state_index_to_position(self, state_idx):
        """Converts a 1D state index back into (x, y, direction)."""
        grid_width = self.env.grid.width
        
        direction = state_idx % 4
        linear_idx = state_idx // 4

        y, x = divmod(linear_idx, grid_width)  # Convert to (x, y)
        
        return x, y, direction
    
    def estimate_transitions(self, episodes=10):
        """Simulate environment for debugging with random agent movements."""
        for _ in range(1, episodes + 1):
            state, _ = self.env.reset()  
            

            for step in range(self.env.max_steps):
                action = self.env.action_space.sample() 
                print(f"Step {step}: Action {action} taken.")
                next_state, _, done, _, _ = self.env.step(action)  
                
                #Convert states to indices for debugging
                s_idx = self.position_to_state_index(state)
                s_next_idx = self.position_to_state_index(next_state)
                
                state = next_state  #Update current state
                
                if done:
                    
                    break  # Stop if episode ends

        print("Finished testing environment.")

    def compute_free_energy(self, num_iterations=500):
        """Iteratively update Free Energy and optimize the policy."""
        states = self.num_states
        actions = self.num_actions

        print(f'Number of states: {states}')
        print(f'Number of actions: {actions}')

        # Initialize transition probabilities (deterministic)
        P_s_given_sa = np.zeros((self.num_states, self.num_actions, self.num_states))

        # Assume uniform state distribution initially
        P_s = np.full(self.num_states, 1 / self.num_states)

        # Initialize Free Energy
        self.F = np.zeros((self.num_states, self.num_actions))

        free_energy = []

        for iteration in range(num_iterations):
            prev_F = np.sum(self.F)

            for state_idx in range(states):  # Loop over all states
                x, y, direction = self.state_index_to_position(state_idx)

                # Ensure the agent is placed in a valid position
                if self.env.grid.get(x, y) is not None and not isinstance(self.env.grid.get(x, y), Goal):
                    print(f"Invalid state {x}, {y} (occupied). Trying another.")
                    continue
                    

                self.env.place_agent((x, y))
                self.env.agent_dir = direction
                

                for a in range(actions):  # Loop over actions
                    print(f"Processing state {state_idx}, action {a}")

                    # Execute action
                    action = np.argmax(self.Pi_a_s[state_idx])  # Select action based on policy
                    next_state, _, done, _, _ = self.env.step(action)
                    s_next = self.position_to_state_index(next_state) #convert position to next state

                    # If terminal, break
                    if done:
                        break

                    # Sample next action from policy
                    next_action = np.argmax(self.Pi_a_s[s_next])
                    next_state_1, _, _, _, _ = self.env.step(next_action)
                    s_next_next = self.position_to_state_index(next_state_1)

                    # **Deterministic Transition Probability**
                    P_s_given_sa[state_idx, action, s_next] = 1

                    # Compute expectation term (log ratio of policy)
                    J = np.sum(np.log(np.maximum(self.Pi_a_s[s_next, :] / np.maximum(self.Pi_a, 1e-10), 1e-10)))

                    # Define reward function
                    reward = 0 if self.env.reached_goal else -1

                    # Update Free Energy functional
                    self.F[state_idx, a] = (
                        np.sum(P_s_given_sa[state_idx, a, :] * np.log(np.maximum(P_s_given_sa[state_idx, a, :] / P_s, 1e-10)))
                        - self.beta * reward + J
                    )

            # Compute partition function Zπ(s, β)
            Z = np.sum(self.Pi_a[None, :] * np.exp(-self.F / (self.beta + 1e-10)), axis=1, keepdims=True) + 1e-5

            # Update policy π(a|s)
            self.Pi_a_s = (self.Pi_a[None, :] / Z) * np.exp(-self.F)

            # Update marginal π(a)
            self.Pi_a = np.sum(self.Pi_a_s * P_s[:, None], axis=0) + 1e-10
            self.Pi_a /= np.sum(self.Pi_a)

            # Compute max policy change
            max_policy_change = np.max(np.abs(self.Pi_a_s - self.Pi_a[None, :]))

            # Print status
            print(f"Iteration {iteration}: Free Energy Sum: {np.sum(self.F)}, Change: {prev_F - np.sum(self.F)}")
            print(f"Iteration {iteration}: Max Policy Change = {max_policy_change}")

            free_energy.append(np.sum(self.F))
            
            # Convergence check
            if max_policy_change < 1e-5:
                print(f"Converged at iteration {iteration}.")
                break
            
            
        
        policy = self.Pi_a_s

        return policy, free_energy 


    def run_policy(self, policy):
        """Run the environment using the learned policy. Convert state in policy to its position using """
        
        start_pos = self.env.reset()[0] #reset the environment and get initial state
        state_idx = self.position_to_state_index(start_pos) #convert starting position to index
        done = False
        while not done:
            x, y, direction = self.state_index_to_position(state_idx) 
            self.env.place_agent((x, y))
            self.env.agent_dir = direction
            action = np.argmax(policy[state_idx])  # Choose best action from policy
            position, _, done, _, _ = self.env.step(action) #take the action
            state_idx = self.position_to_state_index(position) #use this next position , decode the state
            self.env.render()
            


    def plot_free_energy(self, free_energy):
        """Plot the free energy over iterations."""
        

        plt.figure(figsize=(8, 5))
        plt.plot(free_energy, marker='o', linestyle='-')
        plt.xlabel("Iteration")
        plt.ylabel("Sum of Free Energy")
        plt.title("Free Energy Minimization Over Iterations")
        plt.grid()
        plt.show()

# Q learning

assume that -bR + F(a,s,b) is basically like Q learning

In [187]:
class Qlearning:
    def __init__(
        self,
        env,
        learning_rate = 0.8,
        discount_factor = 0.9,
        epsilon = 0.2,
        epochs = 1000
         
    ):
        self.env = env # MiniGrid environment called from the class Minigrid
        self.num_states = np.int64(env.count_states())
        self.num_actions = env.action_space.n
        self.Q_table = np.zeros((self.num_states, self.num_actions)) # Q table
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor # discount factor
        self.epsilon =  epsilon  # exploration probability
        self.epochs = epochs

     # Store state indexes at initialization
        self.state_indexes_list = self.find_state_indexes(env)
        print(f'state indexes list: {self.state_indexes_list}')
     # making a dictionary to convert state index to index
        self.state_to_index = {state: i for i, state in enumerate(self.state_indexes_list)}
        print(f'state to index: {self.state_to_index}')
        
    # def position_to_state_index(self, state):
    #     """Converts (x, y, direction) into a unique state index."""
    #     grid_width = self.env.grid.width
        
    #     x, y, direction  = state  
        

    #     return np.int64((y * grid_width + x) * 4 + direction)
    
    def position_to_state_index(self, obs):
        """Converts an observation (either (x, y, direction) or an image-based dict) into a state index."""
        grid_width = self.env.grid.width
        
        # Case 1: Observation is a dictionary (image-based)
        if isinstance(obs, dict) and 'image' in obs:
            x, y = self.env.agent_pos  # Extract (x, y) directly from the environment
            direction = self.env.agent_dir  # Extract direction
        # Case 2: Observation is already in (x, y, direction) format
        elif isinstance(obs, tuple) and len(obs) == 3:
            x, y, direction = obs
        else:
            raise ValueError(f"Invalid observation format: {obs}")

        # Convert (x, y, direction) to a unique state index
        return np.int64((y * grid_width + x) * 4 + direction)
    
    def state_index_to_position(self, state_idx):
        """Converts a scalar state index back into (x, y, direction)."""
        grid_width = self.env.grid.width
        
        direction = state_idx % 4
        linear_idx = state_idx // 4

        y, x = divmod(linear_idx, grid_width)  # Convert to (x, y)
        
        return x, y, direction
    
    def find_state_indexes(self, env):
        state_indexes_list = []
        for x in range(env.grid.width):
            for y in range(env.grid.height):
                if env.grid.get(x, y) is None: #grabs all empty spaces
                    for direction in range(4):
                        state_index = self.position_to_state_index((x, y, direction))
                        state_indexes_list.append(state_index)
        return state_indexes_list        
    
    
    def train(self, epochs):
        goal_states = [self.position_to_state_index((8, 1, d)) for d in range(4)] #goal state index
        
        
        for epoch in range(epochs):
            current_state = np.random.choice(self.state_indexes_list)

            while current_state not in goal_states:
                #replace current index with its table index
                table_index = self.state_to_index[current_state]
                # Epsilon-greedy action selection
                if np.random.rand() < self.epsilon:
                    action = self.env.action_space.sample()
                    print(f"random action: {action}, exploring")
                else:
                    action = np.argmax(self.Q_table[table_index])
                    print(f"greedy action: {action}, exploiting")

                #transition to the next state
                next_obs, _, done, _, _ = self.env.step(action)
                # print(f"next_observation: {next_obs}")
                next_state = self.position_to_state_index(next_obs)
                print(f"next_state: {next_state}")

                if next_state in goal_states:
                    print(f'reached goal state: {next_state}, not in q table, skipping')
                    break

                if next_state not in self.state_to_index:
                    continue
                # Convert sparse index to dense index
                next_table_index = self.state_to_index[next_state]

                #reward (-1 for each step thats not the gaol)
                reward = 0 if next_state in goal_states else -1

                #Q valye update rule
                self.Q_table[table_index, action] += self.learning_rate * (reward + self.discount_factor * np.max(self.Q_table[next_table_index]) - self.Q_table[table_index, action])

                #update state
                current_state = next_state
        
        
    def run_policy(self):
        """Run the environment using the learned policy from a Q table. Convert state in policy to its position using """
        
        start_pos = self.env.reset()[0] #reset the environment and get initial state
        state_idx = self.position_to_state_index(start_pos) #convert starting position to index
        done = False
        while not done:
            x, y, direction = self.state_index_to_position(state_idx) 
            self.env.place_agent((x, y))
            self.env.agent_dir = direction
            table_index = self.state_to_index[state_idx]
            action = np.argmax(self.Q_table[table_index])  # Choose best action from policy
            position, _, done, _, _ = self.env.step(action) #take the action
            state_idx = self.position_to_state_index(position) #use this next position , decode the state
            self.env.render()

        





In [190]:
q_training = Qlearning(env)
q_training.train(1000)


state indexes list: [np.int64(44), np.int64(45), np.int64(46), np.int64(47), np.int64(84), np.int64(85), np.int64(86), np.int64(87), np.int64(124), np.int64(125), np.int64(126), np.int64(127), np.int64(164), np.int64(165), np.int64(166), np.int64(167), np.int64(204), np.int64(205), np.int64(206), np.int64(207), np.int64(284), np.int64(285), np.int64(286), np.int64(287), np.int64(324), np.int64(325), np.int64(326), np.int64(327), np.int64(48), np.int64(49), np.int64(50), np.int64(51), np.int64(88), np.int64(89), np.int64(90), np.int64(91), np.int64(128), np.int64(129), np.int64(130), np.int64(131), np.int64(168), np.int64(169), np.int64(170), np.int64(171), np.int64(208), np.int64(209), np.int64(210), np.int64(211), np.int64(288), np.int64(289), np.int64(290), np.int64(291), np.int64(328), np.int64(329), np.int64(330), np.int64(331), np.int64(52), np.int64(53), np.int64(54), np.int64(55), np.int64(92), np.int64(93), np.int64(94), np.int64(95), np.int64(132), np.int64(133), np.int64(134)

In [191]:
q_training.run_policy()

KeyboardInterrupt: 

Q Learning with Policy Updates each Epoch

In [None]:
class Qlearning:
    def __init__(
        self,
        env,
        learning_rate = 0.8,
        discount_factor = 0.9,
        epsilon = 0.2,
        epochs = 1000
         
    ):
        self.env = env # MiniGrid environment called from the class Minigrid
        self.num_states = np.int64(env.count_states())
        self.num_actions = env.action_space.n
        self.Q_table = np.zeros((self.num_states, self.num_actions)) # Q table
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor # discount factor
        self.epsilon =  epsilon  # exploration probability
        self.epochs = epochs
        self.policy = np.full((self.num_states, self.num_actions) , 1/self.num_actions)

     # Store state indexes at initialization
        self.state_indexes_list = self.find_state_indexes(env)
        print(f'state indexes list: {self.state_indexes_list}')
     # making a dictionary to convert state index to index
        self.state_to_index = {state: i for i, state in enumerate(self.state_indexes_list)}
        print(f'state to index: {self.state_to_index}')
    
    def position_to_state_index(self, obs):
        """Converts an observation (either (x, y, direction) or an image-based dict) into a state index."""
        grid_width = self.env.grid.width
        
        # Case 1: Observation is a dictionary (image-based)
        if isinstance(obs, dict) and 'image' in obs:
            x, y = self.env.agent_pos  # Extract (x, y) directly from the environment
            direction = self.env.agent_dir  # Extract direction
        # Case 2: Observation is already in (x, y, direction) format
        elif isinstance(obs, tuple) and len(obs) == 3:
            x, y, direction = obs
        else:
            raise ValueError(f"Invalid observation format: {obs}")

        # Convert (x, y, direction) to a unique state index
        return np.int64((y * grid_width + x) * 4 + direction)
    
    def state_index_to_position(self, state_idx):
        """Converts a scalar state index back into (x, y, direction)."""
        grid_width = self.env.grid.width
        
        direction = state_idx % 4
        linear_idx = state_idx // 4

        y, x = divmod(linear_idx, grid_width)  # Convert to (x, y)
        
        return x, y, direction
    
    def find_state_indexes(self, env):
        state_indexes_list = []
        for x in range(env.grid.width):
            for y in range(env.grid.height):
                if env.grid.get(x, y) is None: #grabs all empty spaces
                    for direction in range(4):
                        state_index = self.position_to_state_index((x, y, direction))
                        state_indexes_list.append(state_index)
        return state_indexes_list        
    
    
    def train(self, epochs):

        """"""
        goal_states = [self.position_to_state_index((8, 1, d)) for d in range(4)] #goal state index
        
        
        for epoch in range(epochs):
            current_state = np.random.choice(self.state_indexes_list)

            while current_state not in goal_states:
                #replace current index with its table index
                table_index = self.state_to_index[current_state]
                # select an action at random from the policy!
                action = self.env.action_space.sample()
                print(f"random action: {action}")
                else:
                    action = np.argmax(self.Q_table[table_index])
                    print(f"greedy action: {action}, exploiting")

                #transition to the next state
                next_obs, _, done, _, _ = self.env.step(action)
                # print(f"next_observation: {next_obs}")
                next_state = self.position_to_state_index(next_obs)
                print(f"next_state: {next_state}")

                if next_state in goal_states:
                    print(f'reached goal state: {next_state}, not in q table, skipping')
                    continue

                if next_state not in self.state_to_index:
                    continue
                # Convert sparse index to dense index
                next_table_index = self.state_to_index[next_state]

                #reward (-1 for each step thats not the gaol)
                reward = 0 if next_state in goal_states else -1

                #Q valye update rule
                self.Q_table[table_index, action] += self.learning_rate * (reward + self.discount_factor * np.max(self.Q_table[next_table_index]) - self.Q_table[table_index, action])

                #update state
                current_state = next_state
    

        
        
    def run_policy(self):
        """Run the environment using the learned policy from a Q table. Convert state in policy to its position using """
        
        start_pos = self.env.reset()[0] #reset the environment and get initial state
        state_idx = self.position_to_state_index(start_pos) #convert starting position to index
        done = False
        while not done:
            x, y, direction = self.state_index_to_position(state_idx) 
            self.env.place_agent((x, y))
            self.env.agent_dir = direction
            action = np.argmax(self.Q_table[state_idx])  # Choose best action from policy
            position, _, done, _, _ = self.env.step(action) #take the action
            state_idx = self.position_to_state_index(position) #use this next position , decode the state
            self.env.render()

        





# MLP Approach

# Kernel Approach

In [None]:

env = SimpleEnv(render_mode="human")
free_energy_solver = FreeEnergyMin(env, beta=0.5)
# free_energy_solver.estimate_transitions() #for when you want to do a random walk, no learning
free_energy_solver.compute_free_energy()
free_energy_solver.run_policy()


In [None]:
print("Action Space:", env.action_space)

In [None]:
print("Action Space:", env.action_space)