In [277]:
# __future__ import should always be first
from __future__ import annotations

# Standard library imports
from collections import defaultdict

# Third-party imports
import torch
import numpy as np
import torch.nn as nn
from torch.distributions.categorical import Categorical

# Gymnasium & Minigrid imports
import gymnasium as gym  # Correct way to import Gymnasium
from minigrid.core.constants import COLOR_NAMES
from minigrid.core.constants import DIR_TO_VEC
from minigrid.core.grid import Grid
from minigrid.core.actions import Actions
from minigrid.core.mission import MissionSpace
from minigrid.core.world_object import Door, Goal, Key, Wall
from minigrid.manual_control import ManualControl
from minigrid.minigrid_env import MiniGridEnv
from gymnasium.utils.play import play
import pandas as pd
# Visualization imports
import matplotlib.pyplot as plt


# Environment Set Up

In [625]:
# turning head direction turned off, but code is glitchy
class SimpleEnv(MiniGridEnv):
    def __init__(self, size=10, agent_start_pos=(1, 8), agent_start_dir=0, max_steps=256, **kwargs):
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir
        self.reached_goal = False
        self.step_count = 0
        
        

        mission_space = MissionSpace(mission_func=self._gen_mission)

        super().__init__(
            mission_space=mission_space,
            grid_size=size,
            max_steps=max_steps,
            **kwargs,
        )

        # Only allow Forward, Left+Forward, Right+Forward
        self.action_space = gym.spaces.Discrete(3)  # 3 actions

    @staticmethod
    def _gen_mission():
        return "Find the shortest path"

    def _gen_grid(self, width, height):
        self.grid = Grid(width, height)
        self.grid.wall_rect(0, 0, width, height)

        self.put_obj(Goal(), 8, 1)

        for i in range(1, width // 2):
            self.grid.set(i, width - 4, Wall())
            self.grid.set(i + width // 2 - 1, width - 7, Wall())

        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos #check this
            
        else:
            self.place_agent()


    
    def reset(self, **kwargs):
        obs = super().reset(**kwargs)
        self.reached_goal = False
        self.agent_dir = 0
        return obs
    

    def get_view_exts(self, agent_view_size=None):
        """Override default view extensions to bypass direction checks."""
        agent_view_size = agent_view_size or self.agent_view_size
        topX = self.agent_pos[0] - agent_view_size // 2
        topY = self.agent_pos[1] - agent_view_size // 2
        botX = topX + agent_view_size
        botY = topY + agent_view_size
        return topX, topY, botX, botY
    
    
    
    @property
    def dir_vec(self):
        """Override MiniGrid's default direction vector."""
        return np.array([0, 1])  # Always move upwards by default
    
    
    
    
    def count_states(self):
        free_cells = sum(
            1 for x in range(self.grid.width)
            for y in range(self.grid.height)
            if not self.grid.get(x, y)
        )
        return free_cells  

    def step(self, action):
        """Modify step function to ensure correct movement (no diagonal jumps, no wall clipping)."""
        self.agent_dir = 0  # ✅ Keep agent direction fixed

        # ✅ Define movement vectors: (dx, dy)
        move_vectors = {
            0: (0, -1),   # Move Forward (UP)
            1: (-1, 0),   # Move Left
            2: (1, -1)     # Move Right
        }

        # Get movement vector for action
        move_vector = move_vectors.get(action, (0, 0))  # Default: no movement if invalid action
        
        # Compute the new position
        new_x = self.agent_pos[0] + move_vector[0]
        new_y = self.agent_pos[1] + move_vector[1]

        # ✅ Ensure movement respects grid boundaries & walls
        if (0 <= new_x < self.grid.width) and (0 <= new_y < self.grid.height):
            cell_contents = self.grid.get(new_x, new_y)  # Check what's in the new position

            if cell_contents is None or isinstance(cell_contents, Goal):  
                # ✅ Only update if the move is valid
                self.agent_pos = (new_x, new_y)  
   

        # ✅ Call MiniGrid's original step function
        obs, reward, terminated, truncated, info = super().step(action)

        # ✅ Ensure agent never moves out of bounds
        x, y = self.agent_pos
        x = max(0, min(x, self.grid.width - 1))
        y = max(0, min(y, self.grid.height - 1))
        self.agent_pos = (x, y)

        # ✅ Check if goal is reached
        if np.array_equal(self.agent_pos, (8, 1)):  
            self.reached_goal = True

        reward = 0 if self.reached_goal else -1  # Assign reward
        return obs, reward, terminated, truncated, info



# 3 Action with Directions

In [288]:

class SimpleEnv(MiniGridEnv):
    def __init__(
            self, 
            size=10, 
            agent_start_pos=(1, 8), 
            agent_start_dir=0, 
            max_steps=256, 
            **kwargs,
    ):
        self.agent_start_pos = agent_start_pos
        self.agent_start_dir = agent_start_dir
        self.goal_pos = (8, 1)
        
        
        
        mission_space = MissionSpace(mission_func=self._gen_mission)

        super().__init__(
            mission_space=mission_space,
            grid_size=size,
            max_steps=max_steps,
            **kwargs,
        )

        self.action_space = gym.spaces.Discrete(3)
    @staticmethod
    def _gen_mission():
        return "Find the shortest path"

    def _gen_grid(self, width, height):
        #create gird
        self.grid = Grid(width, height)
        #place barrier
        self.grid.wall_rect(0, 0, width, height)
        #place goal
        self.put_obj(Goal(), 8, 1)
        #place walls
        for i in range(1, width // 2):
            self.grid.set(i, width - 4, Wall())
            self.grid.set(i + width // 2 - 1, width - 7, Wall())
        #place agent
        if self.agent_start_pos is not None:
            self.agent_pos = self.agent_start_pos #check this
            self.agent_dir = self.agent_start_dir
        else:
            self.place_agent()

        self.mission = "find the shortest path"
    
    def count_states(self):
        free_cells = sum(1 for x in range(self.grid.width)
                      for y in range(self.grid.height)
                      if not self.grid.get(x, y)) * 4
        return free_cells 


# Manual Environment Testing

In [3]:
env = SimpleEnv(render_mode="human") 


In [7]:
env = SimpleEnv(render_mode="human")  # Create environment
# free_energy_solver = FreeEnergyMin(env, beta=0.5)
obs = env.reset()[0]  # Reset environment and get initial state
env.render()  # Display initial state


In [None]:
env.count_states()

In [11]:
obs, reward, done, truncated, info = env.step(0)  #turn left
env.render()

In [111]:
obs, reward, done, truncated, info = env.step(1) #turn right
env.render()

In [123]:
obs, reward, done, truncated, info = env.step(2) #move forward
env.render()

In [None]:
free_energy_solver.position_to_state_index(obs)

In [None]:
free_energy_solver.state_index_to_position(328)

In [24]:
#test env
env = SimpleEnv(render_mode="human") 
# env.reset()
# free_energy_solver = FreeEnergyMin(env, beta=0.5)
# free_energy_solver.estimate_transitions() #for when you want to do a random walk, no learning


# Define a Policy Based On Information-to-Go Method

In [None]:
env.action_space.n * env.count_states() #number of actions * number of states

In [None]:
env.count_states()

In [None]:
free_energy_solver.num_states

In [None]:
free_energy_solver.num_actions

In [None]:
env.agent_pos

In [None]:
free_energy_solver.num_states

In [88]:
class FreeEnergyMin:
    """Free Energy Minimization to find optimal policy in MiniGrid Environment."""

    def __init__(
            self, 
            env, 
            beta=0.5
    ):
        
        self.env = env  # MiniGrid environment
        self.num_states = np.int64(env.count_states())  # Number of states
        print(f"Number of states: {self.num_states}")
        self.num_actions = env.action_space.n  # 3 actions (Forward, Left+Forward, Right+Forward)
        print(f"Number of actions: {self.num_actions}")
        self.beta = beta  # Temperature parameter

        # Initialize policy (uniform distribution)
        self.Pi_a_s = np.full((self.num_states, self.num_actions), 1 / self.num_actions) 
        print(f"policy shape: {self.Pi_a_s.shape}")
        self.Pi_a = np.full(self.num_actions, 1 / self.num_actions)
        print(f"marginal action distribution shape: {self.Pi_a.shape}")
        

        
    
    def position_to_state_index(self, state):
        """Converts (x, y, direction) into a unique state index."""
        grid_width = self.env.grid.width
        grid_height = self.env.grid.height
        x, y = self.env.agent_pos  
        direction = self.env.agent_dir  

        return np.int64((y * grid_width + x) * 4 + direction)

    def state_index_to_position(self, state_idx):
        """Converts a 1D state index back into (x, y, direction)."""
        grid_width = self.env.grid.width
        
        direction = state_idx % 4
        linear_idx = state_idx // 4

        y, x = divmod(linear_idx, grid_width)  # Convert to (x, y)
        
        return x, y, direction
    
    def estimate_transitions(self, episodes=10):
        """Simulate environment for debugging with random agent movements."""
        for _ in range(1, episodes + 1):
            state, _ = self.env.reset()  
            

            for step in range(self.env.max_steps):
                action = self.env.action_space.sample() 
                print(f"Step {step}: Action {action} taken.")
                next_state, _, done, _, _ = self.env.step(action)  
                
                #Convert states to indices for debugging
                s_idx = self.position_to_state_index(state)
                s_next_idx = self.position_to_state_index(next_state)
                
                state = next_state  #Update current state
                
                if done:
                    
                    break  # Stop if episode ends

        print("Finished testing environment.")

    def compute_free_energy(self, num_iterations=500):
        """Iteratively update Free Energy and optimize the policy."""
        states = self.num_states
        actions = self.num_actions

        print(f'Number of states: {states}')
        print(f'Number of actions: {actions}')

        # Initialize transition probabilities (deterministic)
        P_s_given_sa = np.zeros((self.num_states, self.num_actions, self.num_states))

        # Assume uniform state distribution initially
        P_s = np.full(self.num_states, 1 / self.num_states)

        # Initialize Free Energy
        self.F = np.zeros((self.num_states, self.num_actions))

        free_energy = []

        for iteration in range(num_iterations):
            prev_F = np.sum(self.F)

            for state_idx in range(states):  # Loop over all states
                x, y, direction = self.state_index_to_position(state_idx)

                # Ensure the agent is placed in a valid position
                if self.env.grid.get(x, y) is not None and not isinstance(self.env.grid.get(x, y), Goal):
                    print(f"Invalid state {x}, {y} (occupied). Trying another.")
                    continue
                    

                self.env.place_agent((x, y))
                self.env.agent_dir = direction
                

                for a in range(actions):  # Loop over actions
                    print(f"Processing state {state_idx}, action {a}")

                    # Execute action
                    action = np.argmax(self.Pi_a_s[state_idx])  # Select action based on policy
                    next_state, _, done, _, _ = self.env.step(action)
                    s_next = self.position_to_state_index(next_state) #convert position to next state

                    # If terminal, break
                    if done:
                        break

                    # Sample next action from policy
                    next_action = np.argmax(self.Pi_a_s[s_next])
                    next_state_1, _, _, _, _ = self.env.step(next_action)
                    s_next_next = self.position_to_state_index(next_state_1)

                    # **Deterministic Transition Probability**
                    P_s_given_sa[state_idx, action, s_next] = 1

                    # Compute expectation term (log ratio of policy)
                    J = np.sum(np.log(np.maximum(self.Pi_a_s[s_next, :] / np.maximum(self.Pi_a, 1e-10), 1e-10)))

                    # Define reward function
                    reward = 0 if self.env.reached_goal else -1

                    # Update Free Energy functional
                    self.F[state_idx, a] = (
                        np.sum(P_s_given_sa[state_idx, a, :] * np.log(np.maximum(P_s_given_sa[state_idx, a, :] / P_s, 1e-10)))
                        - self.beta * reward + J
                    )

            # Compute partition function Zπ(s, β)
            Z = np.sum(self.Pi_a[None, :] * np.exp(-self.F / (self.beta + 1e-10)), axis=1, keepdims=True) + 1e-5

            # Update policy π(a|s)
            self.Pi_a_s = (self.Pi_a[None, :] / Z) * np.exp(-self.F)

            # Update marginal π(a)
            self.Pi_a = np.sum(self.Pi_a_s * P_s[:, None], axis=0) + 1e-10
            self.Pi_a /= np.sum(self.Pi_a)

            # Compute max policy change
            max_policy_change = np.max(np.abs(self.Pi_a_s - self.Pi_a[None, :]))

            # Print status
            print(f"Iteration {iteration}: Free Energy Sum: {np.sum(self.F)}, Change: {prev_F - np.sum(self.F)}")
            print(f"Iteration {iteration}: Max Policy Change = {max_policy_change}")

            free_energy.append(np.sum(self.F))
            
            # Convergence check
            if max_policy_change < 1e-5:
                print(f"Converged at iteration {iteration}.")
                break
            
            
        
        policy = self.Pi_a_s

        return policy, free_energy 


    def run_policy(self, policy):
        """Run the environment using the learned policy. Convert state in policy to its position using """
        
        start_pos = self.env.reset()[0] #reset the environment and get initial state
        state_idx = self.position_to_state_index(start_pos) #convert starting position to index
        done = False
        while not done:
            x, y, direction = self.state_index_to_position(state_idx) 
            self.env.place_agent((x, y))
            self.env.agent_dir = direction
            action = np.argmax(policy[state_idx])  # Choose best action from policy
            position, _, done, _, _ = self.env.step(action) #take the action
            state_idx = self.position_to_state_index(position) #use this next position , decode the state
            self.env.render()
            


    def plot_free_energy(self, free_energy):
        """Plot the free energy over iterations."""
        

        plt.figure(figsize=(8, 5))
        plt.plot(free_energy, marker='o', linestyle='-')
        plt.xlabel("Iteration")
        plt.ylabel("Sum of Free Energy")
        plt.title("Free Energy Minimization Over Iterations")
        plt.grid()
        plt.show()

# Q learning

assume that -bR + F(a,s,b) is basically like Q learning

In [25]:
env.count_states()

220

In [24]:
(env.width-2) * (env.height-2) * 4 - (8 *4) 

224

In [339]:
class Qlearning:
    def __init__(
        self,
        env,
        learning_rate = 0.9,
        discount_factor = 0.9,
        epsilon = 0.5,
        epochs = 1000
         
    ):
        self.env = env # MiniGrid environment called from the class Minigrid
        self.num_states = np.int64(env.count_states())
        self.num_actions = env.action_space.n
        self.Q_table = np.zeros((((env.width -2) * (env.height -2) *4), self.num_actions)) # Q table has goal states in it as well
        # print(f"Q table shape: {self.Q_table.shape}")
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor # discount factor
        self.epsilon =  epsilon # exploration probability
        self.epochs = epochs


        self.allowed_state_idx = self.find_state_indexes(env)
        # print(f"allowed state indexes: {self.allowed_state_idx}")


    
    def position_to_state_index(self, tuple_position = None): 
        grid_width = self.env.grid.width -2
        if tuple_position is None:
            direction = self.env.agent_dir
            x, y = self.env.agent_pos

        else:
            if not isinstance(tuple_position, tuple) or len(tuple_position) != 3:
                raise ValueError(f"Invalid position format: {tuple_position}")
            x, y, direction = tuple_position
        return np.int64(((y-1) * grid_width + (x-1)) * 4 + direction) 
    
    def state_index_to_position(self, state_idx):
        """Converts a scalar state index back into (x, y, direction)."""
        grid_width = self.env.grid.width-2
        
        direction = state_idx % 4
        linear_idx = state_idx // 4

        y, x = divmod(linear_idx, grid_width)  # Convert to (x, y)
        
        return x+1, y+1, direction
    
    def find_state_indexes(self, env):
        """Counts all states except walls and barriers"""
        state_indexes_list = []
        for x in range(1, env.grid.width-1):
            for y in range(1, env.grid.height-1):
                if env.grid.get(x, y) is None: #grabs all empty spaces
                    for direction in range(4):
                        state_index = self.position_to_state_index((x, y, direction))
                        state_indexes_list.append(state_index)
        return state_indexes_list        
    
    
    def train(self, epochs):
        goal_states = [self.position_to_state_index((8, 1, d)) for d in range(4)] #goal state index
        
        for epoch in range(epochs):
            current_state = np.random.choice(self.allowed_state_idx)
            x,y,dir = self.state_index_to_position(current_state)
            self.env.agent_pos = (x,y)
            self.env.agent_dir = dir
            # print(f"Epoch {epoch}: Starting state: {current_state}")
            # print(f"Epoch {epoch}: Starting position: {self.state_index_to_position(current_state)}")

            while current_state not in goal_states:
                # #replace current index with its table index
                # table_index = self.state_to_index[current_state]
                # Epsilon-greedy action selection
                if np.random.rand() < self.epsilon:
                    action = self.env.action_space.sample()
                    # print(f"random action: {action}, exploring")
                else:
                    action = np.argmax(self.Q_table[current_state])
                    # print(f"greedy action: {action}, exploiting")


                #transition to the next state
                next_obs, _, done, _, _ = self.env.step(action)

                # print(f"next_observation: {next_obs}")
                next_state = self.position_to_state_index()
                # print(f"next_state: {next_state}")

                #reward (-1 for each step thats not the gaol)
                reward = 0 if next_state in goal_states else -1

                self.Q_table[current_state, action] += self.learning_rate * (reward + self.discount_factor * np.max(self.Q_table[next_state]) - self.Q_table[current_state, action])
                #update state
                if next_state in goal_states:
                    break
                
                current_state = next_state
        
        
    

    def run_policy_2(self):
        """Run the environment using the learned policy from a Q-table."""

        self.env.reset()[0]  # Reset environment
        current_state = self.position_to_state_index()  # Convert starting position to index
        done = False
        step_count = 0  # Track steps to prevent infinite loops

        while not done and step_count < 100:  # Prevent infinite loops
            # Debugging: Print current state and Q-values
            print(f"Step {step_count}: State {current_state}")
            print(f"Q-values: {self.Q_table[current_state]}")
            
            action = np.argmax(self.Q_table[current_state])  # Choose best action
            print(f"Chosen action: {action}")

            # Before taking the step, print agent's current position
            print(f"Before step: Agent Pos: {self.env.agent_pos}, Dir: {self.env.agent_dir}")

            next_obs, _, done, _, _ = self.env.step(action)  # Take action

            # After step, print agent's new position
            print(f"After step: Agent Pos: {self.env.agent_pos}, Dir: {self.env.agent_dir}")

            next_state = self.position_to_state_index()  # Convert new state

            # Detect if the agent is looping in the same state
            if next_state == current_state:
                print(f"⚠️ Warning: Agent is stuck! Current state {current_state} is the same as next state {next_state}.")
                break  # Prevent infinite loop

            self.env.render()  # Visualize movement
            current_state = next_state  # Update current state
            step_count += 1





In [346]:
# Convert Q-table into a DataFrame for better readability
q_table_df = pd.DataFrame(q_training.Q_table, columns=[f"Action {i}" for i in range(3)])

# Add a column for the corresponding grid position
positions = [q_training.state_index_to_position(state_idx) for state_idx in range(q_training.Q_table.shape[0])]
q_table_df["Grid Position"] = positions  # Append positions to Q-table


In [347]:
q_table_df

Unnamed: 0,Action 0,Action 1,Action 2,Grid Position
0,-5.695328,-5.695328,-4.685590,"(1, 1, 0)"
1,-5.217031,-6.125789,-6.513209,"(1, 1, 1)"
2,-5.695328,-5.695328,-6.125770,"(1, 1, 2)"
3,-6.125795,-5.217031,-5.695328,"(1, 1, 3)"
4,-5.217031,-5.217031,-4.095100,"(2, 1, 0)"
...,...,...,...,...
251,-8.475596,-8.301889,-8.146980,"(7, 8, 3)"
252,-8.498971,-8.728534,-8.515567,"(8, 8, 0)"
253,-8.647125,-8.647685,-8.753436,"(8, 8, 1)"
254,-8.713976,-8.499051,-8.499038,"(8, 8, 2)"


In [349]:
Q_table_correct_800 = q_training.Q_table
%store Q_table_correct_800

Stored 'Q_table_correct_800' (ndarray)


In [None]:
q_training.state_index_to_position()

In [361]:
env = SimpleEnv(render_mode=None) 
env.reset()
q_training = Qlearning(env)
q_training.train(800)


In [112]:
Q_table_epochs_600000_e_95_discount_9_learnrate_9= q_training.Q_table
%store Q_table_epochs_600000_e_95_discount_9_learnrate_9

Stored 'Q_table_epochs_600000_e_95_discount_9_learnrate_9' (ndarray)


In [14]:
obs, reward, done, truncated, info = env.step(0)

In [13]:
q_training.position_to_state_index()

np.int64(225)

In [17]:
env.grid.get(1,6)

<minigrid.core.world_object.Wall at 0x1474e7c10>

In [None]:
env.agent_pos

In [None]:
env.agent_dir

In [362]:
env_human = SimpleEnv(render_mode="human") #make same env but in human mode 

q_training.env= env_human #switch out the env in q training with human env
q_training.run_policy_2() #run the policy in the human env

Step 0: State 224
Q-values: [-8.88808209 -8.75050581 -8.64914828]
Chosen action: 2
Before step: Agent Pos: (1, 8), Dir: 0
After step: Agent Pos: (np.int64(2), np.int64(8)), Dir: 0
Step 1: State 228
Q-values: [-8.7719472  -8.78423325 -8.49905365]
Chosen action: 2
Before step: Agent Pos: (np.int64(2), np.int64(8)), Dir: 0
After step: Agent Pos: (np.int64(3), np.int64(8)), Dir: 0
Step 2: State 232
Q-values: [-8.64914693 -8.64914828 -8.33228183]
Chosen action: 2
Before step: Agent Pos: (np.int64(3), np.int64(8)), Dir: 0
After step: Agent Pos: (np.int64(4), np.int64(8)), Dir: 0
Step 3: State 236
Q-values: [-8.49905364 -8.49905365 -8.14697981]
Chosen action: 2
Before step: Agent Pos: (np.int64(4), np.int64(8)), Dir: 0
After step: Agent Pos: (np.int64(5), np.int64(8)), Dir: 0
Step 4: State 240
Q-values: [-7.94108868 -8.33228183 -8.33228183]
Chosen action: 0
Before step: Agent Pos: (np.int64(5), np.int64(8)), Dir: 0
After step: Agent Pos: (np.int64(5), np.int64(8)), Dir: 3
Step 5: State 243
Q-

In [43]:
env_human.close()

In [None]:
q_training.state_indexes_list

In [None]:
for i, state in enumerate(q_training.state_indexes_list):
    x, y, _ = q_training.state_index_to_position(state) 
    print(f"State index {i}: Position ({x}, {y})")

In [7]:
Q_table = q_training.Q_table
%store Q_table

Stored 'Q_table' (ndarray)


In [10]:
Q_table = q_training.Q_table
def state_index_to_position(state_idx):
    """Converts a scalar state index back into (x, y, direction)."""
    grid_width = 8
        
    direction = state_idx % 4
    linear_idx = state_idx // 4

    y, x = divmod(linear_idx, grid_width)  # Convert to (x, y)
        
    return x+1, y+1, direction

# Q Learning with Policy Updates each Epoch

In [532]:
class Qlearning_frompolicy:
    def __init__(
        self,
        env,
        epochs = 1000,
        beta = 100
         
    ):
        #eventually need to replace num_states with the number of states in the environment formally
        self.env = env # MiniGrid environment called from the class Minigrid
        self.num_states = ((env.width -2) * (env.height -2) *4)
        self.num_actions = env.action_space.n
        self.Free_energy_table = np.zeros((self.num_states, self.num_actions)) # Q table has goal states in it as well
        print(f"Free energy table shape: {self.Free_energy_table.shape}")
        self.epochs = epochs
        self.Pi_a = np.zeros(self.num_actions)
        self.Pi_a_s = np.full((self.num_states, self.num_actions), 1 / self.num_actions)
        print(f"policy shape: {self.Pi_a_s.shape}")
        self.P_s_by_s = np.zeros((self.num_states, self.num_states)) # helps caluclate P(s'|s), to then calculate P(s)
        print(f"s by s shape: {self.P_s_by_s.shape}")
        self.allowed_state_idx = self.find_state_indexes(env)
        self.steps = 0 #within each epoch, this counts the number of state transitions
        self.Zeta = np.zeros(self.num_states)
        self.P_s = np.zeros(self.num_states)
        self.P_s_tplus1 = np.zeros(self.num_states)
        self.P_s_given_s_a = np.zeros((self.num_states, self.num_actions, self.num_states))
        self.beta = beta
        self.Pi_a_tplus1 = np.zeros(self.num_actions)

        # print(f"allowed state indexes: {self.allowed_state_idx}")


    
    def position_to_state_index(self, tuple_position = None): 
        grid_width = self.env.grid.width -2
        if tuple_position is None:
            direction = self.env.agent_dir
            x, y = self.env.agent_pos

        else:
            if not isinstance(tuple_position, tuple) or len(tuple_position) != 3:
                raise ValueError(f"Invalid position format: {tuple_position}")
            x, y, direction = tuple_position
        return np.int64(((y-1) * grid_width + (x-1)) * 4 + direction) 
    
    def state_index_to_position(self, state_idx):
        """Converts a scalar state index back into (x, y, direction)."""
        grid_width = self.env.grid.width-2
        
        direction = state_idx % 4
        linear_idx = state_idx // 4

        y, x = divmod(linear_idx, grid_width)  # Convert to (x, y)
        
        return x+1, y+1, direction
    
    def find_state_indexes(self, env):
        """Counts all states except walls and barriers"""
        state_indexes_list = []
        for x in range(1, env.grid.width-1):
            for y in range(1, env.grid.height-1):
                if env.grid.get(x, y) is None: #grabs all empty spaces
                    for direction in range(4):
                        state_index = self.position_to_state_index((x, y, direction))
                        state_indexes_list.append(state_index)
        return state_indexes_list        
    
    
    def connected_states(self):  #action that takes you there and its probability 
        """Check if two states are connected and return the action and probability."""
        
        for state1 in self.allowed_state_idx:
            
            for state2 in self.allowed_state_idx:
                
                if state1 != state2:
                    # Check if the two states are connected
                    for action in range(self.num_actions):
                        # Check if taking action i from state1 leads to state2
                        x, y, direction = self.state_index_to_position(state1)
                        self.env.place_agent((x, y))
                        self.env.agent_dir = direction
                        next_obs, _, done, _, _ = self.env.step(action)
                        next_state = self.position_to_state_index()
                        
                        if next_state == state2:
                            self.P_s_by_s[state1, state2] = self.Pi_a_s[state1, action]  # Store the action and its probability
        return self.P_s_by_s            
            
    def populate_P_s_given_s_a(self):
        for state in self.allowed_state_idx:
            for action in range(self.num_actions):
                x, y, direction = self.state_index_to_position(state)
                self.env.place_agent((x, y))
                self.env.agent_dir = direction
                next_obs, _, done, _, _ = self.env.step(action)
                next_state = self.position_to_state_index()
                self.P_s_given_s_a[state, action, next_state] = 1

    def train(self, epochs):
        
        
        goal_states = [self.position_to_state_index((8, 1, d)) for d in range(4)] #goal state index
        

        # print(f'P_s_by_s matrix: {self.P_s_by_s}')



        for epoch in range(epochs):
            current_state = np.random.choice(self.allowed_state_idx)
            x,y,dir = self.state_index_to_position(current_state)
            self.env.agent_pos = (x,y)
            self.env.agent_dir = dir
            # print(f"Epoch {epoch}: Starting state: {current_state}")
            # print(f"Epoch {epoch}: Starting position: {self.state_index_to_position(current_state)}")

                    #call this to initialize the P(s'|s) matrix
            self.connected_states() 
            self.steps = 1
            initial_s = np.zeros(self.P_s.shape[0])
            initial_s[current_state] = 1 # populate 1 only at the index that s is in ( one hot encoding)
            
            while current_state not in goal_states:
                print(f"Epoch {epoch}: Starting state: {current_state}")
                #action = np.argmax(self.Free_energy_table[current_state])
                #fix action choice
                action = np.random.choice(np.arange(self.num_actions), p=self.Pi_a_s[current_state])
                # print(f"greedy action: {action}, exploiting")


                #transition to the next state
                next_obs, _, done, _, _ = self.env.step(action)

                # print(f"next_observation: {next_obs}")
                next_state = self.position_to_state_index()

                reward = 0 if next_state in goal_states else -1




                #for self.Pi(a) marginal    
                rows_of_actions_per_state = []
                
               
                # print(f"initial_s: {initial_s}")
                Ps_s = np.linalg.matrix_power(self.P_s_by_s, self
                                              .steps - 1)
                # print(f'P_s_by_s matrix to the power of steps: {Ps_s}')
                self.P_s = np.dot(initial_s .T , Ps_s)   #define P(s)
                print(f'sum of P_s: {np.sum(self.P_s)}')
                
                # print(f'P_s: {self.P_s}, P_s_shape: {self.P_s.shape}')
                for s in range(self.Pi_a_s.shape[0]):
                    row = self.Pi_a_s[s] * self.P_s[s]
                    # print(f'row in Pi:{self.Pi_a_s[s]} ')
                    # print(f'row {row}')
                    rows_of_actions_per_state.append(row)
                    
                self.Pi_a  = np.sum(np.asarray(rows_of_actions_per_state), axis=0)
                # print(f"Pi_a: {self.Pi_a}")
                sum_of_Pi_a = np.sum(self.Pi_a)
                print(f"sum of Pi(a) is: {sum_of_Pi_a}")
                if sum_of_Pi_a != 1:
                    print(f"sum of Pi(a) is not 1: {sum_of_Pi_a}")
                
                #####deleted for now, Pi(a t+1) can just be pi(a) for now
                # #for self.Pi(a t+1) marginal 
                # rows_s_tplus1_list = []
                # Ps_s_t_plus_1 = np.linalg.matrix_power(self.P_s_by_s, self.steps)
                # self.P_s_tplus1 = np.dot(initial_s.T , (Ps_s_t_plus_1)) #define P(s t+1)
                # for s_t in range(self.Pi_a_s.shape[0]):    
                    
                #     rows_s_tplus1 = self.Pi_a_s[s_t] * self.P_s_tplus1[s_t]
                #     rows_s_tplus1_list.append(rows_s_tplus1)
                # self.Pi_a_tplus1  = np.sum(rows_s_tplus1_list, axis=0)
                # sum_of_Pi_a_tplus1 = np.sum(self.Pi_a_tplus1)
                # print(f"sum of Pi(a t+1) is: {sum_of_Pi_a_tplus1}")
                # if sum_of_Pi_a_tplus1 != 1:
                #     print(f"sum of Pi(a t+1) is not 1: {sum_of_Pi_a_tplus1}")
                    
                
                ##################################################################################################
                # P_s_given_s_a_sums = []

                # P_a_given_s_tplus1_sums = []
                
                
                # #sum up expectations over inner loop       
                # for j in self.Pi_a_s[next_state]:
                
                #         P_a_given_s_tplus1_expectation =  j * ((np.log(self.Pi_a_s[next_state]) - np.log(self.Pi_a) +  self.Free_energy_table[next_state]))
                #         P_a_given_s_tplus1_sums.append(P_a_given_s_tplus1_expectation)
                #         inner_expectation = np.sum(P_a_given_s_tplus1_sums, axis =0)
                #         print(f'inner expectation: {inner_expectation}')
                # #feed in inner loop expectation into outer loop and sum up expectations
                # for i in self.P_s_given_s_a[current_state, action]:
                    
                
                #         P_s_given_s_a_expectation = i * (np.log((self.P_s_given_s_a[current_state, action]/self.P_s_tplus1))) - self.beta * reward + inner_expectation
                #         # print(f'shape value inside exponential{P_s_given_s_a_expectation.shape}')
                
                # P_s_given_s_a_sums.append(P_s_given_s_a_expectation)
                # print(f'P_s_given_s_a_sums: {P_s_given_s_a_sums}')
                # print(f'P_s_given_s_a_sums length: {len(P_s_given_s_a_sums)}')
                # updated_energy = np.sum(P_s_given_s_a_sums, axis = 0)
                # self.Free_energy_table[state ] = updated_energy
                #################################################################################################

                P_a_given_s_tplus1_sums = []  # For the inner sum
                P_s_given_s_a_sums = []       # For the outer sum

                for j_idx, j_prob in enumerate(self.Pi_a_s[next_state]):
                    # j_prob = self.Pi_a_s[next_state, j_idx]  => pi(a'| s')
                    if j_prob > 0 and self.Pi_a[j_idx] > 0:
                        val = j_prob * (
                            np.log(j_prob) - np.log(self.Pi_a[j_idx]) 
                            + self.Free_energy_table[next_state, j_idx]  
                        )
                    else:
                        val = 0.0
                    P_a_given_s_tplus1_sums.append(val)
                
                # sum over actions a' => a single scalar
                inner_expectation = np.sum(P_a_given_s_tplus1_sums)
                # print("inner_expectation:", inner_expectation)

                # --------------------
                # 2) Outer sum
                # --------------------
                # Typically: sum_{s'} p(s'|s,a) [ log(p(s'|s,a)/p(s')) - beta*reward + inner_expectation ]
                # or your snippet might differ. We'll adapt your code's structure:
                # for i in self.P_s_given_s_a[current_state, action]:
                # But that is a vector over all s'?

                # Example: let's just do it as a single for-loop:
                outer_sums = []
                for s_prime in range(self.num_states):
                    i_val = self.P_s_given_s_a[current_state, action, s_prime]  # p(s'|s,a)
                    if i_val <= 0:
                        continue
                    # Suppose p(s') is self.P_s_tplus1[s_prime], or maybe self.P_s[s_prime].
                    # We'll guess self.P_s_tplus1 for demonstration:
                    p_sprime = self.P_s_tplus1[s_prime]
                    if p_sprime <= 0:
                        continue
                    
                    # The bracketed expression
                    bracket = (
                        np.log(i_val / p_sprime)   # log(p(s'|s,a)/p(s'))
                        - self.beta * reward
                        + inner_expectation
                    )
                    outer_sums.append(i_val * bracket)

                F_s_a = np.sum(outer_sums)
                
                # --------------------
                # 3) Store in the table
                # --------------------
                self.Free_energy_table[current_state, action] = F_s_a
                print(f"Free energy for state {current_state}, action {action}: {F_s_a}")
                print(f"Free energy table: {self.Free_energy_table}")
                #update state
                if next_state in goal_states:
                    break
                

                current_state = next_state
                self.steps+=1

                #system of equations
                
                

                
                #for self.Z(s, Beta)
                element_wise_a_by_q_table = self.Pi_a * np.exp(-self.Free_energy_table)
                
                # print(f'shape value inside exponential{self.Free_energy_table.shape}')
                # print(f'exponential of Free energy table: {np.exp(-self.Free_energy_table)}')
                # print(f'shape of element wise a by q table: {element_wise_a_by_q_table.shape}')
                self.Zeta = np.sum((element_wise_a_by_q_table), axis = 1)
                print (f"Zeta: {self.Zeta.shape}")
          
                    
                #for normalizing the policy pi(a|s)
                self.Pi_a_s = (self.Pi_a * np.exp(-self.Free_energy_table)) / self.Zeta.reshape(-1,1)
                print(f"Pi_a_s, shape: {self.Pi_a_s.shape}")
                print(f"Pi_a_s: {self.Pi_a_s}")
                if self.Pi_a_s.shape != self.Free_energy_table.shape:
                    raise ValueError (f"Pi_a_s is not equal to the Q table: {self.Pi_a_s.shape}")
                if np.sum(self.Pi_a_s[0]) != 1:
                    print(f"Sum of Pi_a_s is not 1: {np.sum(self.Pi_a_s, axis =1)}")
                    break
                assert(False)
        
    


    def run_policy_2(self):
        """Run the environment using the learned policy from the Free energy table."""

        self.env.reset()[0]  # Reset environment
        current_state = self.position_to_state_index()  # Convert starting position to index
        done = False
        step_count = 0  # Track steps to prevent infinite loops

        while not done and step_count < 100:  # Prevent infinite loops
            # Debugging: Print current state and Q-values
            print(f"Step {step_count}: State {current_state}")
            print(f"Free energy values: {self.Free_energy_table[current_state]}")
            
            action = np.argmax(self.Free_energy_table[current_state])  # Choose best action
            print(f"Chosen action: {action}")

            # Before taking the step, print agent's current position
            print(f"Before step: Agent Pos: {self.env.agent_pos}, Dir: {self.env.agent_dir}")

            next_obs, _, done, _, _ = self.env.step(action)  # Take action

            # After step, print agent's new position
            print(f"After step: Agent Pos: {self.env.agent_pos}, Dir: {self.env.agent_dir}")

            next_state = self.position_to_state_index()  # Convert new state

            # Detect if the agent is looping in the same state
            if next_state == current_state:
                print(f"⚠️ Warning: Agent is stuck! Current state {current_state} is the same as next state {next_state}.")
                break  # Prevent infinite loop

            self.env.render()  # Visualize movement
            current_state = next_state  # Update current state
            step_count += 1





In [533]:
free_energy = Qlearning_frompolicy(env)


Free energy table shape: (256, 3)
policy shape: (256, 3)
s by s shape: (256, 256)


In [534]:
free_energy.train(800)

Epoch 0: Starting state: 124
sum of P_s: 1.0
sum of Pi(a) is: 1.0
Free energy for state 124, action 2: 0.0
Free energy table: [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x1214ef750>>
Traceback (most recent call last):
  File "/Users/iuliarusu/miniconda3/envs/IBP/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


Free energy table: [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0

KeyboardInterrupt: 

# MLP Approach

In [None]:

class MLP_approach:
    def __init__(
        self,
        env,    
    ):

        self.env = env
        self.num_states = np.int64(env.count_states())
        self.num_actions = env.action_space.n
        

        pi_net = nn.Sequential(
                    nn.Linear(self.num_states, 64),
                    nn.Tanh(),
                    nn.Linear(64, 64),
                    nn.Tanh(),
                    nn.Linear(64, self.num_actions)
                    )
        
        

# Kernel Approach

In [None]:

env = SimpleEnv(render_mode="human")
free_energy_solver = FreeEnergyMin(env, beta=0.5)
# free_energy_solver.estimate_transitions() #for when you want to do a random walk, no learning
free_energy_solver.compute_free_energy()
free_energy_solver.run_policy()


In [None]:
print("Action Space:", env.action_space)

In [None]:
print("Action Space:", env.action_space)