## Visualisation tools

In [1]:
import logging
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
import matplotlib
matplotlib.use('TkAgg')

In [2]:
from map_dm_nav.visualisation_tools import plot_likelihood
import numpy as np
import networkx as nx
from pathlib import Path
from matplotlib import pyplot as plt
import seaborn as sns
from matplotlib import colors

In [3]:
def plot_transitions_per_actions(B, agent_state_mapping,possible_actions, selected_actions=[]):
    labels = [value['state'] for value in agent_state_mapping.values()]
    actions_plots = []
    l = len(labels) *1.5
    for action in range(len(possible_actions)):
        if len(selected_actions) > 0 and action not in selected_actions:
            continue
        fig = plt.figure(action, figsize=(l,l))
        a = B[2:4,:len(labels),action]
        fig = sns.heatmap(a, linewidth=0.5, vmin=0,vmax=1.0 ,cmap="YlOrBr", xticklabels=labels, yticklabels=['2','3'])
        fig.tick_params(axis='both', which='major', labelsize=14)  # Adjust label font size
        fig.set_title(possible_actions[action], fontsize=20)
        fig.set_xlabel('Prev State', fontsize=16)
        fig.set_ylabel('Next State', fontsize=16)
        
        actions_plots.append(fig)
    return actions_plots

In [4]:
def plot_transitions(B: np.ndarray, state_map: dict, actions: dict) -> np.ndarray:
    """Plot Transitions matrix showing the probability of a transition between two states given a certain action."""
    
    sorted_state_map = dict(sorted(state_map.items(), key=lambda item: item[1]['state']))
    labels = [f"{key} ({value['state']})" for key, value in sorted_state_map.items()]
    print('labels',labels)
    n_actions = len(actions)
    l = int(np.ceil(np.sqrt(n_actions)))
    L = int(np.ceil(n_actions / l))
    
    fig, axes = plt.subplots(L, l)
    
    axes = np.atleast_2d(axes)  # Ensure axes is always a 2D array
    count = 0
    print('L and l', L,l)
    for i in range(L):
        for j in range(l):
            if count >= n_actions:
                fig.delaxes(axes[i][j])
                continue
            
            if count not in actions:
                print(count, 'not in actions, n actions',n_actions)
                continue

            action_str = str(actions[count])  # Convert action name to string

            # Plot the heatmap
            g = sns.heatmap(B[:len(labels), :len(labels), count], cmap="OrRd", linewidth=3, 
                            cbar=False, ax=axes[i, j], xticklabels=labels, yticklabels=labels)

            g.tick_params(axis='both', which='major', labelsize=14)  # Adjust label font size
            g.set_title(action_str, fontsize=20)
            g.set_xlabel('Prev State', fontsize=16)
            g.set_ylabel('Next State', fontsize=16)

            # Rotate labels for better visibility
            g.set_xticklabels(labels, rotation=45, ha="right", fontsize=12)
            g.set_yticklabels(labels, rotation=0, fontsize=12)
            
            count += 1

    plt.subplots_adjust(left=0.2, bottom=0.2)  # Add margin space
    plt.tight_layout()
    return fig


In [5]:
def compare_B1_B2_plots(B1: np.ndarray, B2: np.ndarray, state_map: dict, actions: dict) -> np.ndarray:
    """Plot Transitions matrix showing the probability of a transition between two states given a certain action.
       Common values in B1 and B2 (within a margin of 0.1) are set to 0 in the resulting B.
    """
    
    # Create B by zeroing out common values within margin 0.1
    margin = 0.1
    B = np.where(np.abs(B1 - B2) <= margin,0, B2)
    
    sorted_state_map = dict(sorted(state_map.items(), key=lambda item: item[1]['state']))
    labels = [f"{key} ({value['state']})" for key, value in sorted_state_map.items()]

    n_actions = len(actions)
    l = int(np.ceil(np.sqrt(n_actions)))
    L = int(np.ceil(n_actions / l))
    
    fig, axes = plt.subplots(L, l, figsize=(L*3 + max(10, 2.5*len(state_map)), 
                                             l*2 + max(10, 1.5*len(state_map))))
    
    axes = np.atleast_2d(axes)  # Ensure axes is always a 2D array
    count = 0

    for i in range(L):
        for j in range(l):
            if count >= n_actions:
                fig.delaxes(axes[i][j])
                continue
            
            if count not in actions:
                continue

            action_str = str(actions[count])  # Convert action name to string

            # Plot the heatmap
            g = sns.heatmap(B[:len(labels), :len(labels), count], cmap="OrRd", linewidth=3, 
                            cbar=False, ax=axes[i, j], xticklabels=labels, yticklabels=labels)

            g.tick_params(axis='both', which='major', labelsize=14)  # Adjust label font size
            g.set_title(action_str, fontsize=20)
            g.set_xlabel('Prev State', fontsize=16)
            g.set_ylabel('Next State', fontsize=16)

            # Rotate labels for better visibility
            g.set_xticklabels(labels, rotation=45, ha="right", fontsize=12)
            g.set_yticklabels(labels, rotation=0, fontsize=12)
            
            count += 1

    plt.subplots_adjust(left=0.2, bottom=0.2)  # Add margin space
    plt.tight_layout()

    return fig, B

In [6]:
def plot_state_in_map(B: np.ndarray, state_mapping: dict,fig_ax=[None, None]) -> np.ndarray:
    """
    Plot states as dots positioned based on `state_mapping` keys.
    Draw transitions between states based on transition probabilities in `B`.

    Parameters:
    - B (np.ndarray): Transition matrix of shape (num_states, num_states, num_actions).
    - state_mapping (dict): Mapping of (x, y) positions to state properties.
    - possible_actions (dict): Dictionary of action indices to angle ranges.
    - pose_dist (float): Distance associated with each move action.

    Returns:
    - fig (matplotlib Figure): The generated figure.
    """
    if fig_ax[0] is None:
        fig, ax = plt.subplots(figsize=(25, 25))
    else:
        fig = fig_ax[0]
        ax = fig_ax[1]


    # Get unique observation values for color mapping
    unique_obs = np.sort(list({v['ob'] for v in state_mapping.values()}))
    color_map = get_cmap() #get_cmap('viridis', len(unique_obs))
    ob_to_color = {ob: color_map.colors[i] for i, ob in enumerate(unique_obs)}

    # Draw transitions between states
    num_states, _, num_actions = B.shape
    for prev_state in range(num_states):
        for next_state in range(num_states):
            for action in range(num_actions):
                prob = B[next_state, prev_state, action]
                if prob > 0.1:  # Only plot meaningful transitions
                    # Find corresponding positions in `state_mapping`
                    prev_pos = next((pos for pos, data in state_mapping.items() if data['state'] == prev_state), None)
                    next_pos = next((pos for pos, data in state_mapping.items() if data['state'] == next_state), None)
                    
                    if prev_pos and next_pos:
                        ax.plot([prev_pos[1], next_pos[1]], [prev_pos[0], next_pos[0]], 
                                'k-', linewidth=prob * 10)  # Scale linewidth with probability

    # Plot states as dots
    for (x, y), data in state_mapping.items():
        state = data['state']
        ob = data.get('ob', 0)
        color = ob_to_color[ob]

        ax.plot(y, x, 'o', color=color, markersize=20)  # Position state as (y, x)
        ax.text(y - 0.05, x + 0.05, str(state), fontsize=25, ha='right', c='r')  # Label state number

    # Formatting
    # ax.invert_yaxis()
    ax.invert_xaxis()
    ax.set_aspect('equal')
    ax.tick_params(axis='both', which='major', labelsize=26)
    plt.ylabel('X', fontsize=30)
    plt.xlabel('Y', fontsize=30)
    plt.title('State Transitions', fontsize=35)
    plt.grid(False)
    
    return fig
def create_custom_cmap(custom_colors) -> colors.ListedColormap:
    return colors.ListedColormap(custom_colors[:]) #,  alpha=None)

def get_cmap() -> colors.ListedColormap:
    custom_colors = (
            np.array(
                [
                    [255, 255, 255],#white 1
                    [255, 0, 0],#red 2
                    [0, 255, 0], #green 3
                    [50,50, 255], #bluish 4
                    [112, 39, 195], #purple5
                    [255, 255, 0], #yellow6
                    [100, 100, 100], #grey7
                    [115, 60, 60], #brown8
                    [255, 0, 255], #flash pink9
                    [80, 145,80], #kaki10
                    [201,132,226], #pink11
                    [75,182,220], #turquoise12
                    [255,153,51], #orange13
                    [255,204,229], #light pink14
                    [153,153,0], #ugly kaki 15
                    [229,255,204], #light green16
                    [204,204,255],#light purple17
                    [0, 153,153], #dark turquoise18
                    [138, 108, 106], #light brown19
                    [108, 115, 92],#ugly green20
                    [149, 199, 152],#pale green21
                    [89, 235, 210], #flashy light blue22
                    [37, 105, 122], #dark blue23
                    [22, 25, 92], #dark purple-blue24
                    [131, 24, 219], #flashy purple25
                    [109, 11, 120], #purple-pink26
                    [196, 145, 191], #pale pink27
                    [148, 89, 130], #dark pink28
                    [201, 75, 119], #pink-red29
                    [189, 89, 92], #light red30

                ]
            )
            / 256
        )

    n_colors = len(custom_colors)
    return create_custom_cmap(custom_colors[:n_colors])


In [7]:
def plot_mcts_tree(root_node):
    """Visualises the Monte Carlo Tree Search (MCTS) tree."""
    G = nx.DiGraph()  # Directed Graph
    dico = {}
    visited = set()  # To avoid infinite recursion

    # Recursively extract tree structure
    def add_nodes_edges(node, parent=None, action=None):
        if node.id in visited:
            if parent is not None:
                G.add_edge(parent.id, node.id, action=int(action))
            return  # Already added and traversed — skip further traversal

        visited.add(node.id)

        # Aggregate or update visit count
        if node.id not in dico:
            dico[node.id] = node.N
        else:
            dico[node.id] += node.N

        # Label for display
        node_label = f"ID: {node.id}\nN: {round(dico[node.id], 2)},\nR: {round(node.state_reward, 2)}"
        G.add_node(node.id, label=node_label, reward=dico[node.id])

        if parent is not None:
            G.add_edge(parent.id, node.id, action=int(action))

        if node.has_children_nodes():
            for action, child_node in node.childs.items():
                add_nodes_edges(child_node, node, action)

    add_nodes_edges(root_node)

    dico = sorted(dico.items(), key=lambda x: x[1])
    logging.info(f"max visits:{dico}, len dict:{len(dico)}")

    pos = nx.kamada_kawai_layout(G)
    # Scale positions to increase spacing
    pos = {k: (x * 1.5, y * 1.5) for k, (x, y) in pos.items()}

    # Node colors based on reward
    rewards = [G.nodes[n]['reward'] for n in G.nodes]
    min_reward = min(rewards) if rewards else 0
    max_reward = max(rewards) if rewards else 1
    node_colors = [(r - min_reward) / (max_reward - min_reward + 1e-6) for r in rewards]

    plt.figure(figsize=(12, 8))
    nx.draw(G, pos, with_labels=True, labels=nx.get_node_attributes(G, 'label'),
            node_color=node_colors, cmap=plt.cm.cool, node_size=1500,
            font_size=8, font_weight='bold', edgecolors="black", alpha=0.9)

    # Draw edge labels (actions)
    edge_labels = nx.get_edge_attributes(G, 'action')
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=7, label_pos=0.7)

    plt.title("Monte Carlo Tree Search (MCTS) Visualization")
    plt.show()

## Model 

In [8]:
from map_dm_nav.model.pymdp.agent import Agent
from map_dm_nav.model.odometry import PoseOdometry
from map_dm_nav.model.pymdp import utils
from map_dm_nav.model.pymdp.control import get_expected_obs, get_expected_states, calc_states_info_gain, calc_expected_utility, calc_pA_info_gain, calc_pB_info_gain
from map_dm_nav.model.pymdp.maths import softmax, spm_log_single
from map_dm_nav.model.modules import *
from map_dm_nav.model.pymdp.learning import update_obs_likelihood_dirichlet
# from .pymdp.maths import spm_dot
import math
import copy
import random

In [9]:
import pickle
def pickle_load_model(store_path: str = None):
    """Loads a pickled model from the specified path."""
    store_path = Path(store_path)
    if not store_path.exists():
        logging.error(f"Model file not found at: {store_path}")
        return None
    try:
        with open(store_path, 'rb') as f:
            loaded_model = pickle.load(f)
            logging.info(f"Model successfully loaded from: {store_path}")
            return loaded_model
    except Exception as e:
        logging.error(f"Failed to load model from {store_path}: {e}")
        return None

In [10]:
# --- Node Class ---
class Node:
    """
    Represents a node in the MCTS tree.
    Stores state information, MCTS statistics, and tree structure links.
    """
    def __init__(self, state_qs:np.ndarray, pose_id:int, parent:object=None, action_index:int=0, observation:np.ndarray=None, initial_reward:float=0.0, possible_actions=None):
        self.pose_id = pose_id
        self.id = pose_id  # Using pose_id as a unique identifier for the node

        # State Representation
        self.state_qs = state_qs  # Belief over states (e.g., particle filter weights)
        self.observation = observation # Observation associated with reaching this state (qo_pi)

        # MCTS Statistics
        self.total_reward = 0 #initial_reward # Sum of rewards accumulated through this node (Formerly T)
        self.N = 0 # Visit count

        # Tree Structure
        self.parent = parent # Reference to the parent node
        self.childs = {} # Dictionary mapping action -> child Node
        self.action_index = action_index # Action taken by the parent to reach this node

        # Action Space
        self.possible_actions = possible_actions # List of possible actions from this node's state (computed during expansion)
        self.untried_actions = None # Actions not yet explored from this node

        # Intrinsic Reward (EFE components calculated when node is evaluated)
        self.state_reward = initial_reward # The immediate EFE/G calculated for reaching this state

    def get_averaged_reward(self)->float:
        """Calculates the average reward accumulated through this node."""
        if self.N == 0:
            return self.state_reward # Avoid division by zero for unvisited nodes
        return self.total_reward / self.N

    def get_ucb1_score(self, c_param:float=1.41,use_utility:bool=True,use_states_info_gain:bool=True)->float:
        """
        Calculates the UCB1 score for this node.
        Balances exploitation (average reward) and exploration (visit count).

        UCB1 ensures that the search doesn't prematurely focus only on the initially best-looking option but also invests simulations in exploring other potentially promising, but less certain, branches. This leads to more robust and accurate value estimates for the actions at the root node over many simulations.
        """
        if self.N == 0:
            return float('inf') # Prioritize exploring unvisited nodes

        if self.parent is None:
             # Should not happen during selection if root is handled correctly, but added for safety
             parent_visits = self.N
        else:
             parent_visits = self.parent.N

        if parent_visits == 0: # Avoid log(0) or division by zero if parent somehow has 0 visits
            parent_visits = 1

        exploitation_term = self.get_averaged_reward()
        exploration_term = c_param * math.sqrt(math.log(parent_visits) / self.N)

        # logging.debug(f"Node {self.id}: AvgReward={exploitation_term:.3f}, ExploitTerm={exploration_term:.3f}, ParentN={parent_visits}, SelfN={self.N}")

        # Add intrinsic state reward to bias selection towards immediately rewarding states
        # Note: Depending on the scale of state_reward vs rollout_reward, this might need tuning.
        reward = 0
        if use_utility:
            reward += exploitation_term
        if use_states_info_gain:
            reward += exploration_term
        return reward #+ self.state_reward

    def is_fully_expanded(self)->bool:
        """Checks if we verified possible actions from this node leading to a child node. This suppose we have No isolated node"""
        return self.possible_actions is not None and len(self.childs) > 0

    def has_children_nodes(self)->bool:
        """Checks if the node has any child nodes."""
        return len(self.childs) > 0
    
    def select_best_child_UCB(self, c_param:float=1.41,use_utility:bool=True, use_info_gain:bool=True,parent_list=[])->object:
        """Selects the child with the highest UCB1 score."""
        best_score = -float('inf')
        best_child = None
        scores = []
        for action, child in self.childs.items():
            score = child.get_ucb1_score(c_param,use_utility, use_info_gain)
            # logging.info(f"  Child {child.id} (Action {action}) UCB1: {score:.2f}")
            scores.append(score)
            if score > best_score and child.id not in parent_list[1:]:
                best_score = score
                best_child = child

        if best_child is None:
            best_child_id = np.argmax(scores)
            best_child = list(self.childs.values())[best_child_id]
        #logging.debug(f"Node {self.id}: Selected child {best_child.id if best_child else 'None'} with score {best_score:.2f}")
        return best_child
    
    def all_children_AIF(self)->list:
        """Selects the child with AIF."""
        all_averaged_efe = [c.get_averaged_reward() for c in self.childs.values()]
        # q_pi, best_action_id = self.infer_policy_over_actions(all_averaged_efe, self.possible_actions)
        # logging.debug(f"  Child {child.id} (Action {action}) average_EFE: {score:.2f}")    
        #logging.debug(f"Node {self.id}: Selected child {best_child.id if best_child else 'None'} with score {best_score:.2f}")
        return all_averaged_efe

    def detach_parent(self)-> None:
        """Removes the reference to the parent node to allow garbage collection."""
        logging.debug(f"Detaching parent from node {self.id}")
        del self.parent
        self.parent = None

# --- Model Interface Class ---
class MCTS_Model_Interface:
    """
    Acts as a wrapper or interface to the underlying Active Inference model.
    Provides methods to query the model for transitions, observations, rewards, etc.
    """
    def __init__(self, underlying_model:object):
        self.model = underlying_model # The actual model object (e.g., Ours_V5_RW instance)
        # Caches can be added here if needed for expensive model calls
        # self.transition_cache = {}
        # self.observation_cache = {}
        # self.reward_cache = {}
        logging.info(f"MCTS_Model_Interface initialized with model type: {type(underlying_model)}")

    def get_possible_actions(self)->list:
        """Returns a list of all possible actions [action_id]"""
        return list(self.model.get_possible_actions().keys())

    def id_to_pose(self, pose_id:int)->list:
        return self.model.PoseMemory.id_to_pose(pose_id)
    
    def get_next_node_pose_id(self, current_pose_id:int, action:int)->int:
        """Calculates the next pose ID resulting from taking an action."""   
        odom = self.model.PoseMemory.id_to_pose(current_pose_id)
        next_pose = self.model.PoseMemory.pose_transition_from_action(action, odom=odom)
        next_pose_id = self.model.PoseMemory.pose_to_id(next_pose, save_in_memory=False)
        next_pose = self.model.PoseMemory.id_to_pose(next_pose_id)
        pose_in_action_range = self.model.PoseMemory.pose_in_action_range(action, next_pose, odom= odom) #if we don't reach that pose with that action, we pass
        
        if not pose_in_action_range:
            #logging.warning(f"Action {action} from pose {current_pose_id} leads to unreachable pose {next_pose_id}. Invalid transition.")
            return -1 # Indicate invalid transition
        #logging.info(f"Action {action} from pose {odom}, {current_pose_id} leads to pose {next_pose}, {next_pose_id}. Invalid transition.")
        return next_pose_id

    def get_next_state_belief(self, current_belief_qs:np.ndarray, action:int)->np.ndarray:
        """Predicts the next belief state (qs) given the current belief and action."""
        # This corresponds to the belief state transition model p(qs'|qs, a)
        return self.model.get_next_state_given_action(qs=current_belief_qs, action=action)

    def get_expected_observation(self, next_belief_qs:np.ndarray)->np.ndarray:
        """Calculates the expected observation (qo_pi) given a belief state."""
        # This corresponds to p(o|qs')
        return self.model.get_expected_observation(next_belief_qs)

    def calculate_expected_free_energy(self, next_belief_qs:np.ndarray, expected_observation_qo_pi:np.ndarray, current_qs:np.ndarray, action:int)->float:
        """
        Calculates the Expected Free Energy (G) for a potential next state.
        G = Utility + Information Gain
        """
        G = 0.0
        H = 0.0
        logging.debug(f"action:{action}, next_belief_qs: {str(next_belief_qs)}")
        if self.model.use_states_info_gain:
            #the highest (>0), the more interesting
            info_gain = self.model.infer_info_gain_term([next_belief_qs]) # Assuming takes a list
            G += info_gain

            logging.debug(f"  Info Gain Term: {info_gain:.4f}")
        if self.model.use_utility:
            #the lowest (<0), the more interesting
            logging.debug(f"  Utility Term exp ob: {str(expected_observation_qo_pi)}")
            utility = self.model.infer_utility_term(expected_observation_qo_pi)
            G += utility 
            logging.debug(f"  Utility Term: {utility:.4f}")

        if self.model.use_inductive_inference:
           H -= infer_inductive_preference(self.model, current_qs, next_belief_qs)
           logging.debug(f" Inductive Inference: {H:.4f}")
        if self.model.use_param_info_gain: #not good in asociation with the other terms
            #the highest (>0), the less interesting
            param_info_gain = self.model.infer_param_info_gain([next_belief_qs],expected_observation_qo_pi, current_qs, action)[0]/100
            G -= param_info_gain
            logging.debug(f"  Param info gain Term: {param_info_gain:.4f}")

        logging.debug(f"  Calculated G: {G:.4f}")
        return G, H

    def infer_policy_over_actions(self, action_values:list, available_actions:list, action_selection:str=None, alpha:float=None):
        """Infers a probability distribution (policy) over actions based on their values (e.g., EFE)."""
        # This likely involves a softmax function as in the original code's example
        q_pi, best_action_id = self.model.infer_best_action_given_actions(action_values, available_actions,action_selection, alpha)
        return q_pi, best_action_id

    def get_utility_term(self):
        return self.model.use_utility
    def get_use_states_info_gain_term(self):
        return self.model.use_states_info_gain
# --- MCTS Algorithm Class ---
class MCTS:
    """
    Implements the Monte Carlo Tree Search algorithm using an Active Inference model.
    """

    def __init__(self, AIF_model:object, c_param:float=1.41, num_simulation:int=25, max_rollout_depth:int=10):
        self.model_interface = MCTS_Model_Interface(AIF_model)
        self.c_param = c_param # Exploration parameter for UCB1
        self.num_simulation  = num_simulation # Number of MCTS simulations per planning step
        self.max_rollout_depth = max_rollout_depth # Maximum depth for the simulation (rollout) phase
        logging.info(f"MCTS initialized with exploration parameter c={c_param}, num_simus={num_simulation}, max_depth={max_rollout_depth}, policy_alpha={AIF_model.alpha},  action_selection={AIF_model.action_selection}")

    def start_mcts(self,state_qs:np.ndarray, pose_id:int, observation:np.ndarray, next_possible_actions:list= None, num_steps:int=1, logging=None, plot=False)-> list:
        current_node = Node(state_qs=state_qs,
                pose_id=pose_id,
                parent=None,
                action_index=None,
                observation=observation, 
                possible_actions=next_possible_actions)
        
        best_actions = []
        data = {"qs": state_qs[0],
            "qpi": [],
            "efe": [],
            "info_gain": [],
            "utility": [],
            #"bayesian_surprise": utils.bayesian_surprise(posterior[0].copy(), prior),
            }
        for i in range(num_steps):
            best_action, data = self.plan(current_node, self.num_simulation, self.max_rollout_depth, data, logging=logging)
            best_actions.append(best_action)
            if num_steps>1 and best_action in current_node.childs:
                next_node = current_node.childs[best_action]
                if logging:
                    logging.info(f"MCTS:Executing action {best_action} -> Transitioning to Node {next_node.id}")

                # IMPORTANT: Detach the chosen next state from its parent (the previous state).
                # This makes the chosen next state the new root for the *next* planning step
                # and allows the old parts of the tree to be garbage collected.
                next_node.detach_parent()
                current_node = next_node # Update the current state
        if plot:
            plot_node = copy.deepcopy(current_node)
            data['plot_MCTS_tree'] = plot_node
        if logging:
            logging.info(f"MCTS:Executing actions {best_actions} -> Transitioning up to Node {current_node.childs[best_action].id}")
           
        return best_actions, data

    def _select_node(self, root_node:object, logging=None)->object:
        """Phase 1: Selection - Traverse the tree using UCB1 until a leaf node is reached."""
        current = root_node
        self.parent_list = []
        # logging.debug(f"--- Selection Phase Start (Root: {root_node.id}) ---")
        while current.is_fully_expanded():
            if logging:
                logging.debug(f"  Selected Node {current.id}")
            # logging.debug(f"Selecting from Node {current.id} (N={current.N}, TR={current.total_reward:.3f})")
            #USING UCB
            #self.parent_list.extend([child.id for child in current.childs.values()])
            self.parent_list.append(current.id)
            next = current.select_best_child_UCB(self.c_param, self.model_interface.get_utility_term(), self.model_interface.get_use_states_info_gain_term, self.parent_list)
            
            if next is not None:
                current = next
            else:
                break
            #safety to avoid loopings
            counting_occurences = {x: self.parent_list.count(x) for x in set(self.parent_list)}
            if any(x > 2 for x in counting_occurences.values()):
                break
            #USING AIF
            # children_G = current.all_children_AIF()
            # q_pi, best_action = self.model_interface.infer_policy_over_actions(children_G, current.possible_actions, action_selection='stochastic', alpha=1.0)
            # current = current.childs[best_action]
        if logging:
            logging.info(f"--- Selection Phase End (Selected Node: {current.id}) ---")
        self.parent_list.append(current.id)
        return current

    def _expand_node_in_all_possible_direction(self, node:object)->object:
        """Phase 2: Expansion - Add a new child node for an untried action."""
        
        if node.possible_actions is None :
            node.possible_actions =[]
            all_possible_actions = self.model_interface.get_possible_actions()
        else:
            all_possible_actions = node.possible_actions
        node.childs = {}
            
        #we save as the current node child each new node created taking an action from current pose 
        for action in all_possible_actions:
            next_pose_id = self.model_interface.get_next_node_pose_id(node.pose_id, action)
            #=== check if new  (redundant)===#
            if action not in node.possible_actions:
                if next_pose_id < 0 or next_pose_id in self.parent_list[:-1] : #no known or next pose looping back in path
                    continue
                node.possible_actions.append(action)
        
            #=== action leading to a node, saving believed qs and qo ===#
            # print('node.state_qs', node.state_qs[0].round(4), action)
            next_state_qs = self.model_interface.get_next_state_belief(node.state_qs, action=action)[0]
            # print('next_state_qs', next_state_qs[0][0].round(3))
            qo_pi = self.model_interface.get_expected_observation(next_state_qs)
            #python should erase unreferenced classes. But let's systematise it
            if action in node.childs:
                del node.childs[action]
            # Calculate the immediate reward (Expected Free Energy) for this transition
            # Note: This G is associated with *reaching* the new state.
            child_reward_G, child_H = self.model_interface.calculate_expected_free_energy(next_state_qs, qo_pi, node.state_qs, action)
            child_reward = child_reward_G+ child_H 
            if next_pose_id in self.tree_table:
                child_node = self.tree_table[next_pose_id]
                if child_node.state_reward < child_reward:
                    child_node.state_reward = child_reward
            else:
                # Create the new child node
                child_node = Node(
                    state_qs=next_state_qs,
                    pose_id=next_pose_id,
                    parent=node,
                    action_index=action,
                    observation=qo_pi,
                    initial_reward= child_reward 
                    # Rollout will determine total_reward
                    # possible_actions will be determined when the child is expanded later
                )

                self.tree_table[child_node.id] = child_node

                #To get a headstart (not necessary)
                # child_node.N +=1
                # child_node.total_reward = child_node.state_reward 
                # parent = child_node
                # while parent.parent:
                #     parent = parent.parent
                #     parent.N += 1
                #     parent.total_reward = parent.total_reward + child_reward_G

            node.childs[action] = child_node
            logging.info(f"from node {node.id} -> Child Node {child_node.id}, expanding with action {action}(Initial full={child_node.state_reward:.3f}, G={child_reward_G:.3f} H={child_H:.3f})")
            # logging.debug(f"--- Expansion Phase End (Expanded Node: {child_node.id}) ---")
        return node # Return the newly expanded node


    # def _rollout(self, start_node:object, max_depth:int)->float:
        """
        Phase 3: Simulation (Rollout) - Simulate a trajectory from the start_node
        using a default policy (e.g., random actions) and return the cumulative reward (G).
        """
        logging.debug(f"--- Rollout Phase Start (Node: {start_node.id}, Max Depth: {max_depth}) ---")
        
        cumulative_G = 0.0
        depth = 0

        current_node = start_node
        while depth < max_depth:
            # 1. Check possible actions from the *current simulated pose*
            if not current_node.is_fully_expanded() or len(current_node.possible_actions) == 0:
                #current_node = self._expand_node_in_all_possible_direction(current_node)
                # logging.debug(f"  Rollout Depth {depth}: No actions possible from pose {current_sim_pose_id}. Stopping.")
                break # Dead end in simulation

            # 2. Choose an action using the default policy (random)
            action = random.choice(current_node.possible_actions)

            child = current_node.childs[action]
            # 3. Simulate the transition (expected state and observation)
            # next_sim_qs = child.state_qs
            # sim_qo_pi = child.observation

            # 4. Calculate reward (G) for this simulated step
            step_G = child.state_reward
           
            cumulative_G += step_G
            logging.info(f"  Rollout Depth {depth}: Action {action}, Next node{child.id}, StepG={step_G:.3f}, CumulG={cumulative_G:.3f}")

            # 5. Update simulated state
            current_node = child
            depth += 1

        # logging.debug(f"--- Rollout Phase End (Node: {start_node.id}, Total Rollout G: {cumulative_G:.3f}) ---")
        return cumulative_G

    def _rollout(self, start_node:object, max_depth:int)->float:
        """
        Phase 3: Simulation (Rollout) - Simulate a trajectory from the start_node
        using a default policy (e.g., random actions) and return the cumulative reward (G).
        """
        # logging.debug(f"--- Rollout Phase Start (Node: {start_node.id}, Max Depth: {max_depth}) ---")
        current_sim_qs = start_node.state_qs
        current_sim_pose_id = start_node.pose_id
        cumulative_G = 0.0
        depth = 1

        current_node = start_node

        while depth < max_depth:
            # 1. Get possible actions from the *current simulated pose*
            if current_node and current_node.is_fully_expanded():
                all_possible_actions = current_node.possible_actions
            else:
                all_possible_actions = self.model_interface.get_possible_actions()
            
            if len(all_possible_actions)==0:
                # logging.debug(f"  Rollout Depth {depth}: No actions possible from pose {current_sim_pose_id}. Stopping.")
                break # Dead end in simulation

            # 2. Choose an action using the default policy (random)
            action = random.choice(all_possible_actions)

            # 3. Simulate the transition
            next_sim_pose_id = self.model_interface.get_next_node_pose_id(current_sim_pose_id, action)
            if next_sim_pose_id < 0:
                 # logging.debug(f"  Rollout Depth {depth}: Action {action} from pose {current_sim_pose_id} leads to invalid state. Stopping.")
                 break # Invalid move in simulation

            next_sim_qs = self.model_interface.get_next_state_belief(current_sim_qs, action)[0]
            sim_qo_pi = self.model_interface.get_expected_observation(next_sim_qs)

            # 4. Calculate reward (G) for this simulated step
            step_G, step_H = self.model_interface.calculate_expected_free_energy(next_sim_qs, sim_qo_pi, current_sim_qs, action)
            cumulative_G += step_G + step_H
            logging.debug(f"  Rollout Depth {depth}: Action {action}, NextPose {next_sim_pose_id}, StepG={step_G:.3f}, StepH={step_H:.3f}, CumulG={cumulative_G:.3f}")

            # 5. Update simulated state
            current_sim_qs = next_sim_qs
            current_sim_pose_id = next_sim_pose_id
            depth += 1

            # 6. If a node exist for that action, retrieve it to get appropriate next actions
            if current_node and current_node.has_children_nodes():
                current_node = current_node.childs.get(action,None)
            else:
                current_node = None

        # logging.debug(f"--- Rollout Phase End (Node: {start_node.id}, Total Rollout G: {cumulative_G:.3f}) ---")
        return cumulative_G / depth
    

    def _minimal_rollout(self, start_node:object,max_depth:int)->float:
        """
        Phase 3: Simulation (Rollout) - Simulate a trajectory from the start_node
        using a default policy (e.g., random actions) and return the cumulative reward (G).
        """
        # logging.debug(f"--- Rollout Phase Start (Node: {start_node.id}, Max Depth: {max_depth}) ---")
        current_sim_qs = start_node.state_qs
        current_sim_pose_id = start_node.pose_id
        current_node = start_node
        best_state_reward = -1000

        #for depth in range(max_depth):
        # 1. Get possible actions from the *current simulated pose*
        if current_node and current_node.is_fully_expanded():
            all_possible_actions = current_node.possible_actions
        else:
            all_possible_actions = self.model_interface.get_possible_actions()
        
        if len(all_possible_actions)==0:
            # logging.debug(f"  Rollout Depth {depth}: No actions possible from pose {current_sim_pose_id}. Stopping.")
            return 0 # Dead end in simulation
        # 2. Review ALL the actions
        for action in all_possible_actions:
                # 3. Simulate the transition
                next_sim_pose_id = self.model_interface.get_next_node_pose_id(current_sim_pose_id, action)
                if next_sim_pose_id < 0:
                        # logging.debug(f"  Rollout Depth {depth}: Action {action} from pose {current_sim_pose_id} leads to invalid state. Stopping.")
                        continue # Invalid move in simulation

                next_sim_qs = self.model_interface.get_next_state_belief(current_sim_qs, action)[0]
                sim_qo_pi = self.model_interface.get_expected_observation(next_sim_qs)

                # 4. Calculate reward (G) for this simulated step
                step_G, step_H = self.model_interface.calculate_expected_free_energy(next_sim_qs, sim_qo_pi, current_sim_qs, action)
                step_reward = step_G + step_H
                if step_reward > best_state_reward:
                    best_state_reward = step_reward
                logging.debug(f"  Rollout node {start_node.id}: Action {action}, NextPose {next_sim_pose_id}, step_reward={step_reward:.3f}, StepG={step_G:.3f}, StepH={step_H:.3f}")


                # 6. If a node exist for that action, retrieve it to get appropriate next actions
                if current_node and current_node.has_children_nodes():
                    current_node = current_node.childs.get(action,None)
                else:
                    current_node = None
        #SECURITY (should be useless)
        if best_state_reward == -1000:
            best_state_reward = 0
        # logging.debug(f"--- Rollout Phase End (Node: {start_node.id}, Total Rollout G: {cumulative_G:.3f}) ---")
        return best_state_reward
    
    def _backpropagate(self, node:object, reward:float)-> None:
        """Phase 4: Backpropagation - Update visit counts and total rewards up the tree."""
        # logging.debug(f"--- Backpropagation Start (Node: {node.id}, Reward: {reward:.3f}) ---")
        current = node
        while current is not None:
            current.N += 1
            current.total_reward += reward
            # logging.debug(f"  Updating Node {current.id}: N={current.N}, TR={current.total_reward:.3f}")
            current = current.parent
        # logging.debug(f"--- Backpropagation End ---")

    def run_simulation(self, root_node, max_rollout_depth, logging=None):
        """Runs a single iteration of the MCTS algorithm (Select, Expand, Simulate, Backpropagate)."""
        # logging.debug(f"=== Starting MCTS Simulation ===")

        # Phase 1: Selection
        selected_node = self._select_node(root_node, logging=logging)

        # Phase 2: Expansion
        # If the selected node is not terminal and not fully expanded, expand it.
        # Check if the node is terminal (add domain-specific logic if needed, e.g., goal reached)
        # is_terminal = False # Placeholder - add condition if applicable
        # if not is_terminal:
        if not selected_node.is_fully_expanded():
            selected_node = self._expand_node_in_all_possible_direction(selected_node)
        else:
            # If fully expanded, the rollout starts from the selected node itself
            # This can happen if selection leads to an already expanded node
            logging.debug(f"Selected node {selected_node.id} is fully expanded, starting rollout from here.")
            #pass


        # Phase 3: Simulation (Rollout)
        # Start rollout from the newly expanded node (or the selected node if expansion wasn't possible/needed)
        reward = self._minimal_rollout(selected_node,max_rollout_depth)
        #reward = self._rollout(selected_node, max_rollout_depth)
        # Add the immediate state reward (G) of the node where the rollout started
        # This connects the immediate EFE gain with the future expected gains from the rollout
        reward += selected_node.state_reward

        # Phase 4: Backpropagation
        self._backpropagate(selected_node, reward)
    
        children_info = [('a', a, 'child id',c.id,'N',c.N,'T', round(c.total_reward,2),'efe_av', round(c.get_averaged_reward(),2)) for a,c in root_node.childs.items()]
        logging.info(f"Root node children stats: {children_info}")
        # logging.debug(f"=== Finished MCTS Simulation ===")

    def plan(self, root_node:object, num_simulations:int, max_rollout_depth:int, data:dict=None, logging=None)-> int: #dict
        """Runs the MCTS planning process for a given number of simulations."""
        if logging:
            logging.info(f"Starting MCTS planning from root node {root_node.id} for {num_simulations} simulations.")

        self.tree_table = {}
        for i in range(num_simulations):
            # print()
            if logging:
                logging.info(f"--- Simulation {i+1}/{num_simulations} ---")
            self.run_simulation(root_node, max_rollout_depth, logging)

        # After simulations, determine the best action from the root
        best_action, q_pi_actions_values = self.get_best_action(root_node)
        data['qpi'].append(q_pi_actions_values[0])
        data['efe'].append(q_pi_actions_values[1])

        return best_action, data

    def get_best_action(self, root_node:object)->int:
        """Selects the best action from the root node after simulations."""
        if not root_node.childs:
            logging.warning("Root node has no children after simulations. Cannot determine best action.")
            return None # Or a default action

        # Option 1: Choose the most visited child (robust)
        # best_action = max(root_node.childs.keys(), key=lambda action: root_node.childs[action].N)

        # Option 2: Choose the child with the highest average reward (can be greedy)
        # best_action = max(root_node.childs.keys(), key=lambda action: root_node.childs[action].get_averaged_reward())

        # Option 3: Use the model's policy inference based on average rewards (AIF scheme)
        action_values = []
        available_actions = []
        child_info = []
        for action, child in root_node.childs.items():
            avg_reward = child.get_averaged_reward()
            if child.id < 0 : #We don't care about uncharted state, thus we artificially decrease their attractiveness
                avg_reward = 0
            action_values.append(avg_reward)
            available_actions.append(action)
            child_info.append(f"Action {action}: AvgR={avg_reward:.3f}, N={child.N}")
        logging.info(f"Root node children stats: {'; '.join(child_info)}")

        if len(available_actions)==0:
             logging.warning("No valid actions available from root node children.")
             return None, []

        q_pi, best_action_id = self.model_interface.infer_policy_over_actions(action_values, available_actions)
        logging.info(f"action average G: {action_values}")
        logging.info(f"softmax policies: {q_pi.round(2)}")
        logging.info(f"Selected best action based on policy: {best_action_id}")
        
        # Ensure the selected action is actually one of the children
        if best_action_id not in root_node.childs:
             logging.error(f"Policy selected action {best_action_id} which is not a child of the root node. Available: {list(root_node.childs.keys())}. Falling back to most visited.")
             # Fallback to most visited
             if available_actions:
                best_action_id = max(root_node.childs.keys(), key=lambda action: root_node.childs.get(action).N if root_node.childs.get(action) else -1)
             else:
                 return None, [] # No valid children

        full_action_values = [action_values[available_actions.index(a)] if a in available_actions else 0 for a in self.model_interface.get_possible_actions()]
        full_q_pi = [q_pi[available_actions.index(a)] if a in available_actions else 0 for a in self.model_interface.get_possible_actions()]
        return best_action_id, (full_q_pi, full_action_values)

In [11]:
class Ours_V5_RW(Agent):
    #====== NECESSARY TO SETUP MODEL ======#
    def __init__(self, num_obs=2, num_states=2, dim=2, observations=[0,(0,0)], lookahead_policy=4, \
                 learning_rate_pB=3.0, n_actions= 6,\
                 influence_radius:float=0.5, robot_dim:float=0.25, \
                   lookahead_node_creation=3) -> None:
        self.agent_state_mapping = {} #testing purposes
        self.influence_radius = influence_radius
        self.robot_dim = robot_dim 
        self.possible_actions = self.generate_actions(n_actions) 
        self.PoseMemory = PoseOdometry(self.possible_actions, influence_radius, robot_dim)

        self.preferred_ob = [-1,-1]


        self.lookahead_node_creation = lookahead_node_creation
        observations, agent_params = self.create_agent_params(num_obs=num_obs, num_states=num_states, observations=observations, \
                            learning_rate_pB=learning_rate_pB, dim=dim, lookahead_policy=lookahead_policy)
        super().__init__(**agent_params)
        self.initialisation(observations=observations)
    
    def create_agent_params(self,num_obs:int=2, num_states:int=2, observations:list=[0,(0,0)], 
                    learning_rate_pB:float=3.0, dim:int=2, lookahead_policy:int=4):
        ob = observations[0]
        p_idx = -1
        if dim > 1:
            #start pose in map
            if len(observations) < 2:
                observations.append([0.0,0.0])
            self.PoseMemory.reset_odom(observations[1])
            p_idx = self.PoseMemory.pose_to_id()
            observations[1] = p_idx
            
        else:
            p_idx = self.PoseMemory.pose_to_id()
        
        self.current_pose = self.PoseMemory.get_odom(as_tuple=True)
        #INITIALISE AGENT PARAMS
        B_agent = create_B_matrix(num_states,len(self.possible_actions))
        if 'STAY' in self.possible_actions and self.set_stationary_B:
            B_agent = set_stationary(B_agent,self.possible_actions['STAY'])
        pB = utils.to_obj_array(B_agent)

        obs_dim = [np.max([num_obs, ob + 1])] + ([np.max([num_obs, p_idx + 1])] if dim > 1 else [])
        A_agent = create_A_matrix(obs_dim,[num_states]*dim,dim)
        pA = utils.dirichlet_like(A_agent, scale = 1)

        return observations, {
            'A': A_agent,
            'B': B_agent,
            'pA': pA,
            'pB': pB,
            'policy_len': lookahead_policy,
            'inference_horizon': lookahead_policy,
            'lr_pB': learning_rate_pB,
            'lr_pA': 5,
            'save_belief_hist': True,
            'action_selection': "stochastic", 
            'use_param_info_gain': False
        }

    def initialisation(self,observations:list=[0,[0,0]]):
        """
        Initialises the agent with the first observation and ensures all parameters 
        are suitable for continuous navigation.

        Parameters:
            observations (list, optional): Initial observation. Defaults to [0, [0, 0]].
            linear_policies (bool, optional): 
                - If **False**: Explores all combinations of actions (exponential complexity: `n_action^policy_len` with `policy_len == lookahead`).
                - If **True**: Generates a linear path reaching a **lookahead distance** or **num steps**.
                - The path remains linear if no "STAY" actions are included.
                - If "STAY" actions exist, the path follows a polynomial pattern.
                - "STAY" actions are irregular and appear only at the end of a policy.
            E (optional): Additional environment-specific parameters (default: None).

        Note:
            - `linear_policies=True` is optimized for cases where `num_factor == 1` 
            and `len(num_control) == 1`.

        Returns:
            None
        """
      
        self.reset(start_pose=self.PoseMemory.get_poses_from_memory()[0])
        if self.edge_handling_params["use_BMA"] and hasattr(self, "q_pi_hist"): #This is not compatible with our way of moving
            del self.q_pi_hist
            
        self.inference_params = {'num_iter': 3, 'dF': 1.0, 'dF_tol': 0.001}
        #Not necessary, but cleaner
        for i in range(len(self.A)):
            self.A[i][:,:] = 0.01 #reset A for cleaner plot and more fair state inference
        self.update_A_with_data(observations,0)
        self.update_agent_state_mapping(self.current_pose, observations, 0)
        self.infer_states(observation = observations, partial_ob=None)
        if 'STAY' in self.possible_actions.values():
            stay_action = [key for key, value in self.possible_actions.items() if value == 'STAY'][0]
            self.B[0] = set_stationary(self.B[0], stay_action)
        return 

    def reset(self, init_qs:np.ndarray=None, start_pose:tuple=None):
        """
        Resets the agent's posterior beliefs about hidden states to a uniform distribution 
        and resets the simulation time to the initial timestep.

        This function initializes or resets key agent parameters, including past actions, 
        observations, and beliefs, ensuring proper inference and navigation behavior.

        Parameters
        ----------
        init_qs : numpy.ndarray, optional
            A predefined posterior over hidden states. If provided, the agent's beliefs 
            will be initialized using `init_qs` instead of a uniform prior.
        
        start_pose : tuple, optional
            The initial position (pose) of the agent. If provided, it sets `self.current_pose`.

        Returns
        -------
        qs : numpy.ndarray
            The initialized posterior over hidden states. The structure of `qs` depends on 
            the inference algorithm selected:

            - If `self.inference_algo == 'VANILLA'`:  
            `qs` is a simple uniform distribution over hidden states.

        Notes
        -----
        - If `self.edge_handling_params['policy_sep_prior']` is enabled, 
        the latest beliefs are initialized separately for each policy.
        - If `init_qs` is provided, it is directly assigned to `self.qs`, 
        bypassing uniform initialization.
        """

        self.curr_timestep = 0
        self.action = None
        self.prev_actions = None
        self.prev_obs = []
        self.qs_step = 0
     
        self.current_pose = start_pose
        if init_qs is None:
            
            self.D = self._construct_D_prior()
           
            if hasattr(self, "q_pi_hist"):
                self.q_pi_hist = []

            if hasattr(self, "qs_hist"):
                self.qs_hist = []
            
            if self.inference_algo == 'VANILLA':
                self.qs = utils.obj_array_uniform(self.num_states)
            else:
                print('MMP INFERENCE NOT IMPLEMENTED')
        
        else:
            self.qs = init_qs

        return self.qs
    
    def generate_actions(self,n_actions:int)->dict:
        """
        Divides the 360-degree orientation into discrete action zones and 
        returns a dictionary mapping each action to its corresponding range.

        Parameters:
            n_actions (int): The number of discrete actions to divide the 
                            360-degree space into.

        Returns:
            dict: A dictionary where keys are action indices (int), and values 
                are lists containing the start and end zone (in degrees):
                `{action_index: [start_zone, end_zone]}`.

        Note:
            - The action zones are evenly spaced across 360 degrees.
            - The function include a "STAY" action.
            - The start and end values are rounded to the nearest integer.
        """
        stay = False
        if n_actions% 2 != 0:
            n_actions = n_actions-1
            stay = True
        zone_range_deg = round(360/n_actions,1)
        n_actions_keys = np.arange(0, n_actions, 1)
        zone_spacing_deg = np.arange(0, 361, zone_range_deg)
        possible_actions = {}
        for action_key in n_actions_keys:
            possible_actions[action_key] = [round(zone_spacing_deg[action_key]), round(zone_spacing_deg[action_key+1]),]
        if stay:
            possible_actions[len(possible_actions)] = "STAY"

        return possible_actions

    #==== TEST&VISU PURPOSES ONLY ====#
    def update_agent_state_mapping(self, pose:tuple, ob:list, state_belief:list=None)-> dict:
        """ Dictionnary to keep track of believes and associated obs, usefull for testing purposes"""
        if state_belief is None:
            state = -1
        else:
            state = np.argmax(state_belief)
        #If we already have an ob, let's not squish it with ghost nodes updates
        if pose in self.agent_state_mapping.keys() and self.agent_state_mapping[pose]['ob'] != -1:
            ob[0] = self.agent_state_mapping[pose]['ob']
        self.agent_state_mapping[pose] = {'state' : state , 'ob': ob[0]}
        if len(ob) > 1:
           self.agent_state_mapping[pose]['ob2'] =  ob[1] 
      
        return self.agent_state_mapping
    
    #==== SET METHODS ====#
    def explo_oriented_navigation(self):
        self.use_param_info_gain = False #if true, do not use with the other terms
        self.use_states_info_gain = True 
        self.use_utility = False

    def goal_oriented_navigation(self, obs=None, **kwargs):
        pref_weight = kwargs.get('pref_weight', 1.0)
        self.update_preference(obs, pref_weight)
        self.use_param_info_gain = False #if true, do not use with the other terms
        self.use_states_info_gain = False #This make it FULLY Goal oriented
        #NOTE: if we want it to prefere this C but still explore a bit once certain about state 
        #(keep exploration/exploitation balanced) keep info gain
        self.use_utility = True


    def set_action_step(self, action):
        ''' only to do if we don't nfer action'''
        self.action = np.array([action])
        self.step_time()
    
    #==== GET METHODS
    def get_current_pose_id(self):
        ''' we do not want a negative pose od'''
        if self.current_pose is None:
            current_pose = self.PoseMemory.get_odom()[:2]
        else:
            current_pose = self.current_pose
        return self.PoseMemory.pose_to_id(current_pose)
    
    def get_belief_over_states(self):
        """
        get the belief distribution over states

        Returns:
            np.ndarray: The extracted belief distribution over states.
        """
        return self.qs
    
    def get_current_timestep(self):
        return self.curr_timestep
    def get_possible_actions(self):
        print('IN MODEL get_possible_actions', self.possible_actions)
        return self.possible_actions
    def set_memory_views(self, views):
        self.ViewMemory = views
    def get_memory_views(self):
        return self.ViewMemory
    
    def get_n_states(self):
        return len(self.agent_state_mapping)
    
    def get_agent_state_mapping(self)->dict:
        return self.agent_state_mapping
    
    def get_B(self):
        return self.B[0]
    
    def get_A(self):
        return self.A
    def get_current_most_likely_pose(self, z_score:float, min_z_score:float=2, qs=None,  observations:list=[])->int:
        """
        Given a z_scores (usually around 2), is the agent certain about the state. If it is, to which pose does it correspond?
        Return pose -1 if < threhsold, else return pose id.
        If no state stands out at all, we don't know where we are and return -2
        """
        if qs is None:
            qs = self.get_belief_over_states()[0]
        p_idx = -1
        mean = np.mean(qs)
        std_dev = np.std(qs)
        #print('qs mean and std_dev', mean, std_dev)
        # Calculate Z-scores
        z_scores = (qs - mean) / std_dev
        # Get indices of values with Z-score above 2
        outlier_indices = np.where(np.abs(z_scores) > z_score)[0]
        min_outlier_indices = np.where(np.abs(z_scores) > min_z_score)[0]
        
        
        #print("Indices of outliers (Z-score >",z_score,"):" , outlier_indices)
        #If we are sure of a state (independent of number of states), we don't have pose as ob and A allows for pose
        if len(outlier_indices) >= 0 and len(observations) < 2 and len(self.A) > 1:
            #If 1 state stands out
            if len(outlier_indices) == 1:
                p_idx = outlier_indices[0]
        #If min_z_scores length is 0, it means no proba is standing out! We don't know where we are   
        elif len(min_outlier_indices) == 0 and len(observations) < 2 and len(self.A) > 1:
            p_idx = -2
        return p_idx

    def get_observation_most_likely_states(self, z_score:float, observations:list=[])->int:
        """
        Given a z_scores (usually around 2), is the agent certain about the state. If it is, to which pose does it correspond?
        Return pose -1 if < threhsold, else return pose id.
        If no state stands out at all, we don't know where we are and return -2
        """
        likelihood = self.get_A()[0][observations[0],:]
        p_idx = -1
        mean = np.mean(likelihood)
        std_dev = np.std(likelihood)
        logging.info(f'likelihood mean and std_dev {mean}, {std_dev}')
        # Calculate Z-scores
        z_scores = (likelihood - mean) / std_dev
        # Get indices of values with Z-score above 2
        outlier_indices = np.where(np.abs(z_scores) > z_score)[0]       
        print('z_scores',z_scores)
        #print("likelihood Indices of outliers (Z-score >",z_score,"):" , outlier_indices)
        #If we are sure of a state (independent of number of states), we don't have pose as ob and A allows for pose
    
        return outlier_indices
    
    def get_expected_observation(self, qs=np.ndarray, A:np.ndarray=None)-> np.ndarray:
        """ get observation belief for state qs"""
        if A is None:
            A = self.A
        qo_pi = get_expected_obs(qs, A)
        return qo_pi

    def get_next_state_given_action(self, qs= np.array, action=int, B=None)->np.ndarray:
        ''' expect only 1 qs, return only 1 qs with the same shape (1,) == np.ndarray([np.ndarray([])])'''
        if B is None:
            B = self.B

        # print('B check', B[0][:,np.argmax(qs), action])
        if isinstance(action, (int,np.int64)):
            action = np.array([[action]])
            
        qs_pi = get_expected_states(qs, B, action)
        return qs_pi
    
    
    #==== MCTS_CALL ====#
    def define_actions_from_MCTS_run(self,num_steps=1, logging=None,  **kwargs)->list: #,dict
        """ 
        MCTS RUN, UNDER TEST (SHOULD NOT BE RE-CREATED EACH RUN)
        TODO: adapt for when we want a full policy

        """
        c_param = 5
        num_simulations = 100  # Number of MCTS simulations per planning step
        max_rollout_depth = 1 #UNUSED NOW # Maximum depth for the simulation (rollout) phase
        mcts = MCTS(self, c_param, num_simulations, max_rollout_depth)

        observations = kwargs.get('observations', None)
        next_possible_actions = kwargs.get('next_possible_actions', None)
        plot_MCTS_tree = kwargs.get('plot_MCTS_tree', False)

        #If we are not inferring state at each state during the model update, we do it here
        if observations is not None and self.current_pose is None:
        #If self.current_pose is not None then we have step_update that infer state
            
            #NB: Only give obs if state not been inferred before 
            if len(observations) < len(self.A):
                partial_ob = 0
                            
            elif len(observations) == len(self.A):
                partial_ob = None
                if self.current_pose == None:
                    self.current_pose = observations[1]
                observations[1] = self.PoseMemory.pose_to_id(observations[1])
            
            self.infer_states(observation = observations, partial_ob=partial_ob, save_hist=True)

        initial_pose_id = self.get_current_pose_id()
        initial_belief_qs = self.get_belief_over_states() # Get initial belief
        initial_observation = self.get_expected_observation(initial_belief_qs)

        best_actions, data = mcts.start_mcts(state_qs=initial_belief_qs,
                     pose_id=initial_pose_id, observation=initial_observation, \
                     next_possible_actions=next_possible_actions, num_steps=num_steps, logging=logging, plot=plot_MCTS_tree)
        
        #NOTE: THIS CONSIDERONLY FIRST ACTION OF POLICY. MIGHT LEADS TO ISSUE DEPENDING ON HOW WE USE THAT
        self.q_pi = data['qpi'][0]
        self.G = data['efe'][0]

        #NOTE: THIS CONSIDER THAT WE APPLY FIRST ACTION OF POLICY. MIGHT LEADS TO ISSUE DEPENDING ON HOW WE USE THAT
        self.action = np.array([best_actions[0]])
        self.step_time()
        
        return best_actions[:num_steps], data
    
    #==== INFERENCE ====#

    def infer_states(self, observation:list, action:np.ndarray= None ,save_hist:bool=True, partial_ob:int=None, qs:list=None):
        """
        Performs variational inference to update posterior beliefs over hidden states given an observation.

        This method updates the agent's belief state (`qs`) by incorporating new observations 
        and optionally considering the previous action. The update process depends on the 
        selected inference algorithm (`VANILLA`).

        Parameters
        ----------
        observation : list or tuple 
            The observed state indices for each observation modality.
    
        action : np.ndarray, optional
            The most recent action taken by the agent. If provided, it helps refine posterior beliefs.

        save_hist : bool, default=True
            If True, stores the latest observation and updates historical data for future inference.


        partial_ob : int, optional
            Specifies a particular observation modality to update the belief state for, rather than all modalities.

        qs : list, optional
            A predefined posterior belief state. If provided, this will be used instead of computing from scratch.

        Returns
        -------
        qs : numpy.ndarray of dtype object
            Updated posterior beliefs over hidden states. The structure depends on the inference algorithm:
            - For `VANILLA`, `qs` represents a single posterior belief over hidden states.


        Notes
        -----
        - If `self.inference_algo == "VANILLA"`, posterior updates consider an empirical prior derived from 
        the transition model (`B`) or from a uniform prior (`D`).
        - The method also updates `self.qs_hist` and `self.qs_step` when `save_hist=True`, 
        enabling tracking of belief evolution over time.
        """

        if save_hist:
            self.prev_obs.append(observation)
            observations_hist = self.prev_obs
        else:
            observations_hist = self.prev_obs.copy()
            observations_hist.append(observation)

        if action == None:
            action = self.action

        if self.inference_algo == "VANILLA":
            if action is not None:
                if qs is None:
                    ref_qs = self.get_belief_over_states() #we don't yet want to consider current obs to selest qs
                else: #safety to avoid any risk
                    ref_qs = qs[:]
                empirical_prior = control.get_expected_states(
                    ref_qs, self.B, action.reshape(1, -1) #type: ignore
                )[0]
            #unused, but kept as a memor
            else:
                self.D = self._construct_D_prior() #self.D
                empirical_prior = self.D
            if self.current_pose is None:
                #TODO: increase A with observation even when self.current_pose is None
                for i in range(len(self.A)):
                    if partial_ob != None:
                        i = partial_ob
                    if observation[i] >= len(self.A[i]):
                        print('ERROR IN INFER STATE: given observation not in A')
                        qs = self.get_belief_over_states()
                        return qs
            qs = update_posterior_states(
            self.A,
            observation,
            empirical_prior,
            partial_ob,
            **self.inference_params
            )
            F = 0
            qs_step = 0

        if save_hist:
            self.F = F # variational free energy of each policy  
            self.qs_step = qs_step
            if hasattr(self, "qs_hist"):
                self.qs_hist.append(qs)
            self.qs = qs

        return qs
    
    def infer_pose(self, pose)->list:
        '''
        Parameters:
            pose (int or list): The index of the pose or the pose itself
        Here we consider that action consider the actual action sensed by agent (thus consider no motion)
        and we consider PoseMemory adapted to treat that perception
        '''
        if isinstance(pose,int):
            pose = self.PoseMemory.id_to_pose(pose)
        self.PoseMemory.update_odom_given_pose(pose)
        if self.current_pose !=None:
            self.current_pose = self.PoseMemory.get_odom(as_tuple=True)[:2]
        return self.current_pose

    def infer_action(self, logs=None, **kwargs):
        """
        return the best action to take
        possible params (as a dict):
        observations (List): (only if state not been inferred before)
        next_possible_actions (List): constraint the action to take to be among a list of given choices. 
        
        Returns
        ----------
        return action as int and info as dict
        """
        observations = kwargs.get('observations', None)
        next_possible_actions = kwargs.get('next_possible_actions', list(self.possible_actions.keys()))
        if logs is not None:
            logs.info('observations'+ str(observations)+ 'next_possible_actions'+ str(next_possible_actions))
        # prior = self.get_belief_over_states()
        
        #If we are not inferring state at each state during the model update, we do it here
        if observations is not None and self.current_pose is None:
        #If self.current_pose is not None then we have step_update that infer state
            
            #NB: Only give obs if state not been inferred before 
            if len(observations) < len(self.A):
                partial_ob = 0
                            
            elif len(observations) == len(self.A):
                partial_ob = None
                if self.current_pose == None:
                    self.current_pose = observations[1]
                observations[1] = self.PoseMemory.pose_to_id(observations[1])
            
            posterior = self.infer_states(observation = observations, partial_ob=partial_ob, save_hist=True)
            # print('infer action: self.current_pose', self.current_pose, posterior[0].round(3))
        if logs is not None:
            logs.info('still there')
        #In case we don't have observations.
        posterior = self.get_belief_over_states()
        #print('infer action: inferred prior state', posterior[0].round(3))
        q_pi, efe, info_gain, utility = self.infer_policies(posterior, logs=logs)
        if logs is not None:
            logs.info('catching up here')
        poses_efe = None
        action = self.sample_action(q_pi, next_possible_actions)

        if logs is not None:
            logs.info('no problem here')
        #TODO: switch back to sample_action once tests done
        # action, poses_efe = self.sample_action_test(next_possible_actions)
        
        #What we would expect given prev prior and B transition 
        # prior = spm_dot(self.B[0][:, :, int(action)], prior)
        
        self.q_pi = q_pi
        self.G = efe

        data = { "qs": posterior[0],
            "qpi": q_pi,
            "efe": efe,
            "info_gain": info_gain,
            "utility": utility,
            #"bayesian_surprise": utils.bayesian_surprise(posterior[0].copy(), prior),
            }
        if poses_efe is not None:
            data['poses_efe'] = poses_efe
        return int(action), data
    
    def infer_policies(self, qs=None, logs=None):
        """
        Perform policy inference by optimizing a posterior (categorical) distribution over policies.
        This distribution is computed as the softmax of ``G * gamma + lnE`` where ``G`` is the negative expected
        free energy of policies, ``gamma`` is a policy precision and ``lnE`` is the (log) prior probability of policies.
        This function returns the posterior over policies as well as the negative expected free energy of each policy.

        Returns
        ----------
        q_pi: 1D ``numpy.ndarray``
            Posterior beliefs over policies, i.e. a vector containing one posterior probability per policy.
        G: 1D ``numpy.ndarray``
            Negative expected free energies of each policy, i.e. a vector containing one negative expected free energy per policy.
        """
        if qs is None:
            qs = self.qs

        #If we want to increase the precision of the utility 
        # term on A, we can play with the section below.
        #Currently unused 
        if self.use_utility:
            A = copy.deepcopy(self.A)
            # for modality, array in enumerate(A[0]):
            #     # Compute normalization
            #     summed = array.sum(axis=0, keepdims=True)
            #     # print(summed)
            #     A[0][modality] = array * 10 / summed
        else:
            A = self.A

        q_pi, G, info_gain, utility = update_posterior_policies(
            qs,
            A,
            self.B,
            self.C,
            self.policies,
            self.use_utility,
            self.use_states_info_gain,
            self.use_param_info_gain,
            self.pA,
            self.pB,
            E = self.E,
            gamma = self.gamma,
            diff_policy_len = False, #TODO: erase in a refactoring
            logs= logs
        )
        if hasattr(self, "q_pi_hist"):
            self.q_pi_hist.append(q_pi)
            if len(self.q_pi_hist) > self.inference_horizon:
                self.q_pi_hist = self.q_pi_hist[-(self.inference_horizon-1):]
            
        return q_pi, G, info_gain, utility
    
    def sample_action(self, q_pi:np.ndarray, possible_first_actions:list=None)->int:
        """
        Sample or select a discrete action from the posterior over control states.
        This function both sets or cachés the action as an internal variable with the agent and returns it.
        This function also updates time variable (and thus manages consequences of updating the moving reference frame of beliefs)
        using ``self.step_time()``.

        Returns
        ----------
        action: 1D ``numpy.ndarray``
            Vector containing the indices of the actions for each control factor
        """
        if possible_first_actions != None:
            #Removing all policies leading us to uninteresting or forbiden action. //speed computation//
            policies, q_pi = zip(*[(policy, q_pi[p_id]) for p_id, policy \
                                   in enumerate(self.policies) if policy[0] in possible_first_actions])
        else:
            policies =  self.policies

        if self.sampling_mode == "marginal":
            action = control.sample_action(
                q_pi, policies, self.num_controls, action_selection = self.action_selection, alpha = self.alpha
            )
        elif self.sampling_mode == "full":
            action = control.sample_policy(q_pi, policies, self.num_controls,
                                           action_selection=self.action_selection, alpha=self.alpha)
        self.action = action
        self.step_time()

        return action
    
    def infer_current_most_likely_pose(self, observations:list, z_score:float=2, min_z_score:float=2):
        ''' define our position p '''
        
        # if self.current_pose is None:
        #     z_score = 2
        
        p_idx = self.get_current_most_likely_pose(z_score, min_z_score, observations = observations)
        #if we have a pose, replace current inferred pose by the most likely one.
        if p_idx >= 0:
            self.current_pose = self.PoseMemory.id_to_pose(p_idx)
            self.PoseMemory.reset_odom(self.current_pose)
            #print('updating believed pose given certitude on state:', self.current_pose)
        elif p_idx < -1:
            self.current_pose = None
        return p_idx
    
    def infer_best_action_given_actions(self, G:list, actions:list, action_selection:str=None, alpha:float=None):
        if isinstance(actions[0], (int, np.int64)):
            actions = np.array([[[a]] for a in actions])

        if action_selection is None:
            action_selection = self.action_selection
        if alpha is None:
            alpha = self.alpha
        G = np.array(G)
        lnE = spm_log_single(np.ones(G.shape) / len(G))

        q_pi = softmax(G * self.gamma + lnE) 
        if self.sampling_mode == "marginal":
            best_action = control.sample_action(q_pi, policies = actions, num_controls = self.num_controls, action_selection = action_selection, alpha= alpha)
        elif self.sampling_mode == "full":
            best_action = control.sample_policy(q_pi, actions, self.num_controls,
                                           action_selection=action_selection, alpha=alpha)
        return q_pi, int(best_action[0])
    
    def infer_utility_term(self, qo_pi:np.ndarray, C=None)->float:
        """ given the observation belief of a state, what is the utility term"""
        if C is None:
            C = self.C
        utility_term = calc_expected_utility(qo_pi, C)  #negative value, the highest the more interesting
        if self.inductive_inference:
            Bt = (self.get_B() > self.certitude_transition_threshold).astype(float)
            H = 0
            for n in range(self.num_steps):
                I_next = ((Bt.T @ I[-1]) > 0 ).astype(float) # New reachable states (as bool)
                print(I_next)
                I.append(I_next[0])
                print('setp', len(I)-1)

            utility_term += H
        return utility_term
    
    def infer_info_gain_term(self, qs_pi:np.ndarray, A=None)->float:
        """ given the belief of a state, what is the info gain term (Note the method expects several qs, thus qs must be in 3 layers of np.ndarray)"""
        if A is None:
            A = self.A
        return calc_states_info_gain(A, qs_pi)
    
    def infer_param_info_gain(self, qs_pi:np.ndarray, qo_pi:np.ndarray, qs:np.ndarray, action:int):
        """Infer param info gain for an action (but can also be a policy)"""
        G = 0
        if self.pA is not None:
            param_info_gain = calc_pA_info_gain(self.pA, qo_pi, qs_pi)
            G +=  param_info_gain
        if self.pB is not None:
            if isinstance(action, (int,np.int64)):
                action = np.array([[action]])
            param_info_gain = calc_pB_info_gain(self.pB, qs_pi, qs, action)
            G +=  param_info_gain

        return G
    
    #==== OTHER METHODS ====#

    def determine_next_pose(self, action_id:int, pose:list=None,  min_dist_to_next_node:float=None):
        next_pose = self.PoseMemory.pose_transition_from_action(action =action_id, odom= pose, ideal_dist=min_dist_to_next_node)
        next_pose = [round(elem, 2) for elem in next_pose]
        next_pose_id = self.PoseMemory.pose_to_id(next_pose, save_in_memory=False)
        #print('action, next pose and id', action_id, next_pose, next_pose_id)
        return next_pose, next_pose_id

    def determine_action_given_angle_deg(self, angle):
        """
        Find the key in possible_actions corresponding to the given angle.

        Args:
            angle (float): The angle to check in DEGREES

        Returns:
            int: The corresponding action key, or None if no match is found.
        """
        actions = self.possible_actions.copy()
        if "STAY" in actions.values():
            actions.popitem()
        action_key = [k for k,v in actions.items() if v[0] <= angle and v[1] > angle]
        return action_key[0]
        
        #same thing
        # for key, value in self.possible_actions.items():
        #     if value == "STAY":
        #         continue
        #     if value[0] <= angle < value[1]:  # Check if angle falls within range
        #         return key
        # return None

    def calculate_min_dist_to_next_node(self, state_step:int=1):
        return self.influence_radius * state_step + self.robot_dim/2#/3 to consider -a little- robot_dim when adding nodes.as_integer_ratio
    
    def define_next_possible_actions(self, obstacle_dist_per_actions:list, restrictive:bool=False, logs=None):
        min_dist = self.calculate_min_dist_to_next_node()
        
        n_actions = len(self.possible_actions) - ("STAY" in self.possible_actions.values())
        possible_actions = [i for i in range(n_actions) if obstacle_dist_per_actions[i] >= min_dist]
        if restrictive:
            possible_actions_2 = possible_actions[:]
            for action in possible_actions_2:
                next_pose, next_pose_id = self.determine_next_pose(action, min_dist_to_next_node=min_dist)
                registered_pose = self.PoseMemory.id_to_pose(next_pose_id)
                if logs:
                    logs.info(f'next pose{next_pose}{next_pose_id}, with action{action}, but registered_pose{registered_pose}')
                if registered_pose[0] != next_pose[0] or registered_pose[1] != next_pose[1] :
                    possible_actions.remove(action)
                    
        if "STAY" in self.possible_actions.values():
            possible_actions.append(n_actions)

        return possible_actions

    #==== Update A, B, C ====#
    def update_A_with_data(self,obs:list, state:int)->np.ndarray:
        """Given obs and state, update A entry """
        A = self.A
        
        for dim in range(self.num_modalities ):
            A[dim][:,state] = 0
            A[dim][obs[dim],state] = 1
        self.A = A
        return A
    
    def update_A(self, obs, qs=None):
        """
        Update approximate posterior beliefs about Dirichlet parameters that parameterise the observation likelihood or ``A`` array.

        Parameters
        ----------
        observation: ``list`` or ``tuple`` of ints
            The observation input. Each entry ``observation[m]`` stores the index of the discrete
            observation for modality ``m``.

        Returns
        -----------
        qA: ``numpy.ndarray`` of dtype object
            Posterior Dirichlet parameters over observation self (same shape as ``A``), after having updated it with observations.
        """
        if qs is None:
            qs = self.qs
        qA = update_obs_likelihood_dirichlet(
            self.pA, 
            self.A, 
            obs, 
            qs, 
            self.lr_pA, 
            self.modalities_to_learn
        )

        self.pA = qA # set new prior to posterior
        self.A = utils.norm_dist_obj_arr(qA) # take expected value of posterior Dirichlet parameters to calculate posterior over A array

        return qA
    
    def update_B(self,qs:np.ndarray, qs_prev:np.ndarray, action:int, lr_pB:int=None)-> np.ndarray:
        """
        Updates the posterior Dirichlet parameters (`pB`) that parameterize the transition likelihood (`B`).

        This function refines the transition model by incorporating new posterior beliefs about states (`qs`),
        previous state beliefs (`qs_prev`), and the most recent action taken. The update is performed using 
        a Dirichlet-multinomial approach, ensuring a smooth adaptation of transition probabilities.

        Parameters
        ----------
        qs : numpy.ndarray
            Marginal posterior beliefs over hidden states at the current time step.

        qs_prev : numpy.ndarray
            Marginal posterior beliefs over hidden states at the previous time step.

        action : int
            The most recent action taken by the agent, which affects transition updates.

        lr_pB : int, optional
            Learning rate for updating `pB`. If not specified, defaults to `self.lr_pB`.

        Returns
        -------
        qB : numpy.ndarray
            Updated posterior Dirichlet parameters over transition probabilities (`B`). 
            This has the same shape as `B` but now incorporates learned state-action transitions.

        Notes
        -----
        - The update is computed using the `update_state_likelihood_dirichlet` function, 
        which adjusts `pB` based on the observed transitions.
        - The function ensures that `qB` does not contain negative values by applying a failsafe correction.
        - Transition probabilities (`B`) are normalized after updating `pB` to maintain a valid probability distribution.
        - If `lr_pB` is negative, a failsafe mechanism prevents `qB` from dropping below a minimum threshold (0.005).
        - The updated `qB` is stored as `self.pB`, and `B` is re-normalized for future inference.

        """
        
        if lr_pB is None:
            lr_pB = self.lr_pB
        
        qB = update_state_likelihood_dirichlet(
            self.pB,
            self.B,
            [action],
            qs,
            qs_prev,
            lr_pB,
            self.factors_to_learn
        )
    
        qB[0] = np.maximum(qB[0], 0.005) #No negative value (failsafe because of lr possibly negative)
        #no 0 values because 0 values can't variate anymore
        self.pB = qB # set new prior to posterior
        self.B = utils.norm_dist_obj_arr(qB)  # take expected value of posterior Dirichlet parameters to calculate posterior over B array
        return qB

    def update_A_dim_given_obs(self, obs:list,null_proba:list=[True]) -> np.ndarray:
        ''' 
        Verify if the observations are new and fit into the current A shape.
        If not, increase A shape in observation (n row) only.
        '''
        A = self.A
        num_obs, num_states, num_modalities, num_factors = utils.get_model_dimensions(A=A)
        
        # Calculate the maximum dimension increase needed across all modalities
        dim_add = [int(max(0,obs[m] + 1 - num_obs[m])) for m in range(num_modalities)]
        # Update matrices size
        for m in range(num_modalities):
            A[m] = update_A_matrix_size(A[m], add_ob=dim_add[m], null_proba=null_proba[m])
            if self.pA is not None:
                self.pA[m] = utils.dirichlet_like(utils.to_obj_array(A[m]), scale=1)[0]
        num_obs, num_states, num_modalities, num_factors = utils.get_model_dimensions(A=A)
        self.num_obs = num_obs
        self.A = A
        return A
    
    def update_A_dim_given_pose(self, pose_idx:int,null_proba:bool=True) -> np.ndarray:
        ''' 
        Verify if the observations are new and fit into the current A shape.
        If not, increase A shape and associate those observations with the newest state generated.
        If yes, search for the first empty column available and fill it with new inferred position (pose_idx)
        '''
        A = self.A
        num_obs, num_states, num_modalities, num_factors = utils.get_model_dimensions(A=A)
        if pose_idx >= max(num_states):
            # Calculate the maximum dimension increase needed across all modalities
            dim_add = int(max(0,pose_idx + 1 - num_obs[num_modalities-1]))
            # Update matrices size
            #and associate new observations with the newest state generated
            if dim_add > 0:
                A[0] = update_A_matrix_size(A[0], add_state=dim_add, null_proba=null_proba)
                if num_modalities > 1:
                    A[1] = update_A_matrix_size(A[1], add_ob=dim_add, add_state=dim_add, null_proba=null_proba)
                    self.num_obs[1] = A[1].shape[0]
        if num_modalities > 1:
            #columns_wthout_data = np.sort(np.append(np.where(np.all(A[1] == 1/A[1].shape[0], axis=0))[0], np.where(np.all(A[1] == 0, axis=0))[0]))
            A[1][:, pose_idx] = 0
            A[1][pose_idx, pose_idx] = 1
            

        if self.pA is not None:
            self.pA = utils.dirichlet_like(utils.to_obj_array(A), scale=1)
                        
        self.num_states = [A[0].shape[1]]
        self.A = A
        return A
    
    def update_B_dim_given_A(self)-> np.ndarray:
        """ knowing A dimension, update B state dimension to match"""
        B = self.B
        add_dim = self.A[0].shape[1]-B[0].shape[1]
        if add_dim > 0: 
            #increase B dim
            B = update_B_matrix_size(B, add= add_dim)
            self.pB = update_B_matrix_size(self.pB, add= add_dim, alter_weights=True)
            if len(self.qs) > 1:
                for seq in self.qs:
                    for subseq in seq:
                        subseq[0] = np.append(subseq[0], [0] * add_dim)
            else:
            
                self.qs[0] = np.append(self.qs[0],[0]*add_dim)
        
        self.num_states = [B[0].shape[0]]
        self.B = B
        return B
    
    def update_believes_with_obs(self, Qs:list, action:int, obs:list)-> None:
        """
        Updates the model's transition (`B`) and observation (`A`) matrices using the given 
        posterior beliefs over states (`Qs`), action taken, and new observation.

        Parameters
        ----------
        Qs : list
            The updated posterior beliefs over hidden states.

        action : int
            The action taken at the current step.

        obs : list
            The observed sensory input at the current step.

        Returns
        -------
        None
            Updates the transition (`B`) and observation (`A`) matrices in-place.

        Notes
        -----
        - If `qs_hist` is available, it retrieves the previous belief state and ensures consistency 
        in dimensionality before updating `B` using `Qs`.
        - If the transition resulted in a change of state, the function also updates `B` for 
        the reverse transition using the inverse action.
        - After updating `B`, the function re-infers the new state (`Qs`) based on `obs` and 
        updates `A` accordingly.
        """
        if len(self.qs_hist) > 0:#secutity check
            qs_hist = self.qs_hist[-1]
            qs_hist[0] = np.append(qs_hist[0],[0]*\
                                   (len(Qs[0])-len(qs_hist[0])))
            self.update_B(Qs, qs_hist, action, lr_pB = 10) 
            #2 WAYS TRANSITION UPDATE (only if T to diff state)
            if np.argmax(qs_hist[0]) != np.argmax(Qs[0]):
                a_inv = reverse_action(self.possible_actions, action)
                self.update_B(qs_hist, Qs, a_inv, lr_pB = 7)
        Qs = self.infer_states(obs) 
        self.update_A(obs, Qs)

    def update_B_given_unreachable_pose(self,next_pose:list, action:int)-> None:
        """ We reduce transition probability between those 2 states that do not connect"""
        if self.current_pose is not None and next_pose in self.PoseMemory.get_poses_from_memory() :
            n_pose_id = self.PoseMemory.pose_to_id(next_pose)
            qs = self.get_belief_over_states()
            hypo_qs = self.infer_states([n_pose_id], np.array([action]), partial_ob=1, save_hist=False)

            # print(self.B[0][np.argmax(hypo_qs[0])][np.argmax(qs[0])][action], self.B[0][np.argmax(qs[0])][np.argmax(hypo_qs[0])][action])
            # print(self.B[0][np.argmax(qs[0])][np.argmax(hypo_qs[0])][action], self.B[0][np.argmax(hypo_qs[0])][np.argmax(qs[0])][action])
            self.update_B(hypo_qs, qs,action,lr_pB=-10)
            a_inv = reverse_action(self.possible_actions, action)
            self.update_B(qs,hypo_qs,a_inv,lr_pB=-7)

    def update_qs_dim(self, qs:np.array=None, update_qs_memory:bool=True)->np.ndarray:
        if qs is None:
            qs = self.qs[:]
        if len(qs[0]) < self.B[0].shape[0]:
            qs[0] = np.append(qs[0],[0]*(self.B[0].shape[0]-len(qs[0])))

        if update_qs_memory:       
            self.qs = qs
            for p_step, past_qs in enumerate(self.qs_hist):
                if len(past_qs[0]) < self.B[0].shape[0]:
                    past_qs[0] = np.append(past_qs[0],[0]*(self.B[0].shape[0]-len(past_qs[0])))
                    self.qs_hist[p_step]= past_qs
        return qs
    
    def update_C_dim(self):
        if self.C is not None:
            num_obs, num_states, num_modalities, num_factors = utils.get_model_dimensions(A=self.A) 
            for m in range(num_modalities):
                if self.C[m].shape[0] < num_obs[m]:
                    self.C[m] = np.append(self.C[m], [0]*(num_obs[m]- self.C[m].shape[0]))
    
    def update_preference(self, obs:list=[-1,-1], pref_weight:float=1.0):
        """given a list of observations (must fill all expected observation. If no preference enters -1) we fill C with thos as preference. 
        If we have a partial preference over several observations, 
        then the given observation should be an integer < 0, the preference will be a null array 
        """
        if isinstance(obs, list):
            self.update_A_dim_given_obs(obs, null_proba=[False]*len(obs))

            C = self._construct_C_prior()

            for modality, ob in enumerate(obs):
                if ob >= 0:
                    self.preferred_ob[modality] = ob
                    ob_processed = utils.process_observation(ob, 1, [self.num_obs[modality]])
                    ob = utils.to_obj_array(ob_processed)
                else:
                    ob = utils.obj_array_zeros([self.num_obs[modality]])
                C[modality] = np.array(ob[0]) * pref_weight

            if not isinstance(C, np.ndarray):
                raise TypeError(
                    'C vector must be a numpy array'
                )
            self.C = utils.to_obj_array(C)

            assert len(self.C) == self.num_modalities, f"Check C vector: number of sub-arrays must be equal to number of observation modalities: {self.num_modalities}"

            for modality, c_m in enumerate(self.C):
                assert c_m.shape[0] == self.num_obs[modality], f"Check C vector: number of rows of C vector for modality {modality} should be equal to {self.num_obs[modality]}"
        else:
            self.preferred_ob = [-1,-1]
            self.C = self._construct_C_prior()

    #====== UPDATE MODEL ======#
    def update_transitions_both_ways(self,qs:np.ndarray, next_qs:np.ndarray, action_id:int, \
                                     direct_lr_pB:int, reverse_lr_pB:int)-> None:
        
        a_inv = reverse_action(self.possible_actions, action_id)
        self.update_B(next_qs, qs, action_id, lr_pB = direct_lr_pB) 
        self.update_B(qs, next_qs, a_inv, lr_pB = reverse_lr_pB)
   
    def update_imagined_translation(self, qs:np.ndarray, action_jump:int, n_actions:int, action_id:int, cur_pose:list, \
                                    min_dist_to_next_node:int, obstacle_dist_per_actions:int):
        """
        Updates the model's transition probabilities (`B` matrix) from current pose to imagined poses and between imagined poses up to 'action_jump' actions away
        It reinforces direct and indirect motion transitions while considering obstacles. 

        Parameters
        ----------
        qs : np.ndarray
            The posterior belief over states before taking the action.

        action_jump : int
            The range of adjacent actions to consider when updating transitions.

        n_actions : int
            The total number of possible actions.

        action_id : int
            The current action being taken.

        cur_pose : list
            The current position in the environment.

        min_dist_to_next_node : int
            The minimum distance required to move to the next node.

        obstacle_dist_per_actions : int
            The distance of obstacles from the current state for each possible action.

        hypo_qs : list or None
            The hypothetical belief state over hidden states, used when reinforcing indirect motion. 
            If `None`, direct motion updates are applied.

        Returns
        -------
        None
            Updates the transition probabilities (`B` matrix) in-place.

        Notes
        -----

        """
        #1) We get current physical pose info + the next imagined pose info given action_id
        _, next_pose_id = self.determine_next_pose(action_id, cur_pose, min_dist_to_next_node) #from odom to next imagined pose
        next_pose = self.PoseMemory.id_to_pose(next_pose_id) #get memorised pose (not the approximated one)
        cur_pose_id = self.PoseMemory.pose_to_id(cur_pose) #get odom pose id
        next_state_ob_dist = obstacle_dist_per_actions[action_id] #to check if ob between physical pose to next imagined pose
        
        prev_action = -1
        #print('__')
        for offset in range(action_jump, -action_jump - 1, -1):
            action_adjacent = action_id + offset
            #print('action_adjacent, offset', action_adjacent, offset)
            #restraingning action between possible actions numbers
            if action_adjacent < 0 :
                action_adjacent = n_actions +action_adjacent
            else:
                action_adjacent %= (n_actions)
            #from physical pose, get next adjacent pose given offset action 
            next_adjacent_pose, next_adjacent_pose_id = self.determine_next_pose(action_adjacent, cur_pose, min_dist_to_next_node)
            #if no adjacent pose, nothing to do
            #print('next_adjacent_pose_id', next_adjacent_pose_id,'next_pose_id', next_pose_id, 'cur_pose_id',cur_pose_id)
            if next_adjacent_pose_id < 0 or next_adjacent_pose_id==cur_pose_id:
                continue
            #if adjacent pose exists, then we get it's state (obtained from Transitioning from physical state to this adjacent state)
            adjacent_qs = self.infer_states([next_adjacent_pose_id], np.array([action_adjacent]), save_hist=False, partial_ob=1, qs=qs)
            adjacent_state_dist_to_ob = obstacle_dist_per_actions[action_adjacent] #get if ob at this adjacent pose
            #print('offset', offset, 'action_adjacent', action_adjacent, 'adjacent_state_dist_to_ob', round(adjacent_state_dist_to_ob,2))
            #We correct to pose ID pose, to be sure it matches
            next_adjacent_pose = self.PoseMemory.id_to_pose(next_adjacent_pose_id) #get memorised pose (not the approximated one)
            
            #If known state and this is the direct transition from physical to an already existing state 
            if offset == 0 : #if direct motion from current state to another state  
                reference_qs = qs
                action = action_adjacent
                direct_lr_pB = 5
                reverse_lr_pB = 3
                pose_in_action_range = self.PoseMemory.pose_in_action_range(action, next_pose, odom= cur_pose)
                #print('direct transition from new current odom', cur_pose_id, 'to', next_pose_id)
                
            #If this transition is a lateral transition 
            elif offset != 0 and next_adjacent_pose_id != next_pose_id: #if indirect motion, we don't want to reinforce 'stay' motion with wrong action
                reference_qs = self.infer_states([next_pose_id], np.array([action_id]), save_hist=False, partial_ob=1, qs=qs) #next direct imagined pose qs
                angle = angle_turn_from_pose_to_p(pose = next_pose, goal_pose= next_adjacent_pose, in_deg=True)
                action = self.determine_action_given_angle_deg(angle)
                direct_lr_pB = 1
                reverse_lr_pB = 1
                pose_in_action_range = self.PoseMemory.pose_in_action_range(action, next_adjacent_pose, odom= next_pose)
                #Just to avoid reinforcing same link several times (can happens if we check pose to id only considering distance)
                if prev_action == action:
                    #print('already updated that transition')
                    continue
                prev_action = action
                #print('lateral transition from imagined pose',next_pose_id, 'to', next_adjacent_pose_id)
            else:
                continue
            #print('pose_in_action_range', pose_in_action_range, 'action', action,'next_pose', next_adjacent_pose, 'odom', next_pose)
            #If the pose is not in this action range, we don't enforce it + we can't have the poses being unreachable.
            if pose_in_action_range and adjacent_state_dist_to_ob > min_dist_to_next_node and next_state_ob_dist > min_dist_to_next_node :
                # Positive LR
                self.update_transitions_both_ways(reference_qs, adjacent_qs, action, direct_lr_pB=direct_lr_pB, reverse_lr_pB=reverse_lr_pB)
            elif pose_in_action_range:
                #print('negative reinforcement')
                # Negative LR
                self.update_transitions_both_ways(reference_qs, adjacent_qs, action, direct_lr_pB=-direct_lr_pB, reverse_lr_pB=-reverse_lr_pB)

    def update_transition_nodes(self, obstacle_dist_per_actions:list)-> None:
        ''' 
        For each new pose observation, add a ghost state and update the estimated transition and observation for that ghost state.
        '''
        #print('Ghost nodes process:')
        action_jump = int((len(self.possible_actions)-1) / 6)
        sure_about_data_until_this_state = 1
        ''' 
            TODO LIST: 
            1) Check if transition possible 
            IF impossible motion:
                2) increase transition from current state to current state for this action (both ways)
                3) if obstacle, check if there is a transition existing for this action and decrease state transition (both ways)
            go to next action
            Else: 
                4) infer new pose in that direction 
                5) check if a pose exist in that direction (margin of zone of influence)
                6) if yes, increase transition prob to existing pose with that action
                7) if no, 
                    7') increase all matrices IF NEED BE
                    7'')create new node 
                8) check if previous and next action have obstacle. If no, link previous/next pose node to current pose node
                9) from this node, check if action still possible further with increased zone of influence + margin (thus until we reach an obstacle) 
            skip next action (as we want 6 nodes around if no obstacle anywhere)
            '''
        
        n_actions = len(self.possible_actions) - ("STAY" in self.possible_actions.values())
        qs = self.get_belief_over_states()
        
        min_dist_to_next_node = self.calculate_min_dist_to_next_node()
        
        for action_id in range(n_actions):
            #print('______________________________')
            hypo_qs = None
            state_step = 1
            prev_step_qs = qs[:]
            no_obstacle = True
            cur_pose = self.PoseMemory.get_odom().copy()
            cur_ref_pose = cur_pose.copy()
            pose_in_action_range = True
            
            #9)
            #The second element is just to avoid any risk of infinity loop
            while no_obstacle and state_step<=self.lookahead_node_creation:
                next_state_min_dist_to_next_node = self.calculate_min_dist_to_next_node(state_step)
                #1)
                #Is obstacle too close?
                #print('for action', action_id, 'obstacle', obstacle_dist_per_actions[action_id], 'min_dist for new state', next_state_min_dist_to_next_node)
                if obstacle_dist_per_actions[action_id] <=  next_state_min_dist_to_next_node :
                    no_obstacle = False 
                    #Only enforce the loop back to current pose if it's a direct motion
                    if state_step <= sure_about_data_until_this_state:
                        #2)
                        #print('enforcing motion:',action_id,' leads to current state')
                        self.update_B(qs, qs, action_id, lr_pB = 10)   
                else:
                    #4)
                    next_pose, next_pose_id = self.determine_next_pose(action_id, cur_ref_pose, min_dist_to_next_node)
                    #print('next_pose', next_pose, self.PoseMemory.get_poses_from_memory().copy())
                    #5) ->7)  
                    if next_pose_id < 0 and pose_in_action_range:
                        next_pose_id = self.PoseMemory.pose_to_id(next_pose) 
                        #print('creating new node in position', next_pose_id)
                        #7')
                        self.update_A_dim_given_pose(next_pose_id,null_proba=True)
                        self.update_B_dim_given_A()
                        self.update_C_dim()
                        prev_step_qs = self.update_qs_dim(prev_step_qs,update_qs_memory=False)
                        #7'')
                        hypo_qs = self.infer_states([next_pose_id], np.array([action_id]), partial_ob=1, save_hist=False, qs=prev_step_qs)
                        self.update_agent_state_mapping(tuple(next_pose[:2]), [-1, next_pose_id], hypo_qs[0])
                        #We don't want lateral state transition updated when we are extrapolating further than "sure_about_data_until_this_state"
                        #plus we only want the action continuing in a straight line from physical current pose. 
                        #becquse the beam rqnge grows bigger as the vectors are longer.                            
                    else:
                        #print('pose existing nearby as', next_pose_id,'not creating new node')
                        hypo_qs = self.infer_states([next_pose_id], np.array([action_id]), partial_ob=1, save_hist=False, qs=prev_step_qs)

                    if state_step > sure_about_data_until_this_state and pose_in_action_range:
                            #print('DIRECT MOTION AT STATE STEP',state_step)
                            self.update_imagined_translation(prev_step_qs[:], 0, n_actions, action_id, cur_ref_pose, \
                                        min_dist_to_next_node, obstacle_dist_per_actions)
                    prev_step_qs = hypo_qs[:]
                    cur_ref_pose = self.PoseMemory.id_to_pose(next_pose_id)
                    pose_in_action_range = self.PoseMemory.pose_in_action_range(action_id, cur_ref_pose, odom= cur_pose, influence_radius=next_state_min_dist_to_next_node)#doesn't work
                    #print('cur_ref_pose', cur_ref_pose, 'can be reached from ', cur_pose, 'with action', action_id,'?:', pose_in_action_range)
                    
                state_step +=1
            #3), 5)->6)with offset 0 and 8)
            qs = self.update_qs_dim(qs)
            self.update_imagined_translation(qs[:], action_jump, n_actions, action_id, cur_pose, \
                                   min_dist_to_next_node, obstacle_dist_per_actions)
               
    def agent_step_update(self, action_id:int, obs:list, logs=None)->None:
        """
        Updates the agent's belief state, transition probabilities, and learned environment 
        model based on the given action and observations.

        This method performs the following steps:
        1. **Infer new pose**: Updates the agent's estimated position.
        2. **Update observation model (A) and control model (C)**: Adjusts probability distributions 
        to incorporate new observations.
        3. **Update transition model (B)**: Modifies transition probabilities based on inferred states.
        4. **Update belief states**: Updates internal beliefs given the latest observations.
        5. **Update state mapping for visualization**: Stores the agent's inferred position and 
        corresponding belief state.
        6. **Update transition nodes**: Modifies state transitions based on perceived obstacles.
        7. **Ensure stationary transitions**: If a 'STAY' action exists, enforces it in the transition model.

        Parameters:
            action_id (int): The index of the action taken by the agent.
            obs (list): The list of observations, expected to contain:
                - primary_ob (int): The primary observed feature (e.g., color).
                - pose_id (int, optional): The ID of the inferred pose. 
                - obstacles_dist_per_action_range (list): Distances to obstacles for each action.

        Returns:
            None
        """
        #we could get action_id from ours.action instead?

        #We only update A and B if we have inferred a current pose
        #Thus until doubt over current loc is not solved, we don't update internal self
        if self.current_pose != None:
            #Just for memory sake
            primary_ob = obs[0]
            if isinstance(obs[1],int):
                pose_id = obs[1]
            else:
                pose_id = self.PoseMemory.pose_to_id(self.current_pose)

            obstacles_dist_per_action_range = obs[-1]
            
            #1. INFER NEW POSE (should be after motion and before update)
            # pose = self.PoseMemory.id_to_pose(pose_id)
            # self.current_pose = self.infer_pose(pose) #Not sure it shouold be here. in case i want to give whatever...
            
            observations = [primary_ob,pose_id]
            if logs:
                logs.info('observations pose %f, action %f, ob_id %f, obstacles %s' % (pose_id, action_id, primary_ob, str(obstacles_dist_per_action_range)))
            #2. UPDATE A C DIM IF NEW OB
            self.update_A_dim_given_obs(observations, null_proba=[True,False])
            self.update_C_dim()
            #updating B in case pose_id new
            self.update_B_dim_given_A()
            #3. UPDATE BELIEVES GIVEN OBS
            prior = self.infer_states(observations, save_hist=False)

            #print('prior on believed state; action', self.action, action_id, \
            #    'colour_ob:', primary_ob , 'inf pose:',self.current_pose,'prior belief:', prior[0].round(3))
                
            self.update_believes_with_obs(prior,action=action_id, obs=observations)

            posterior = self.infer_states(observations, save_hist=True)
            ## agent_state_mapping for TEST PURPOSES and visualisation
            self.update_agent_state_mapping(tuple(self.current_pose[:2]), observations, posterior[0])
            #4. update all nodes
            self.update_transition_nodes(obstacle_dist_per_actions=obstacles_dist_per_action_range)
            #This is not mandatory, just a gain of time
            if 'STAY' in self.possible_actions.values():
                stay_action = [key for key, value in self.possible_actions.items() if value == 'STAY'][0]
                self.B[0] = set_stationary(self.B[0], stay_action)

In [12]:
from scipy.stats import median_abs_deviation

def get_observation_most_likely_states(self, observations: list, per_threshold: float = 0.5) -> list:
    """
    Robust standout detector using percentile, fallback for sparse/multi-peak distributions.
    """
    likely_states = {}

    for modality, ob in enumerate(observations):
        if ob < 0:
                continue
        standout_indices = []
        qo = np.array(self.get_A()[modality][ob])
        # print('qo', qo.round(3))
        # threshold for standout values
        standout_indices = np.where(qo >= per_threshold)[0]

        # Special case: only one clear maximum, much larger than rest
        max_val = np.max(qo)
        second_max = np.partition(qo.flatten(), -2)[-2]
        # print('max_val', max_val,'second_max',second_max)
        if max_val < 4 * second_max and second_max not in standout_indices:
            np.append(standout_indices,second_max)
        # print('standout_indices',standout_indices)
        for idx in standout_indices:
            likely_states[idx] = likely_states.get(idx, 0) + 1

    if not likely_states:
        return []
#     print('likely_states', likely_states)


    # Return most recurrent standout indices across modalities
    most_recurrent = max(likely_states.values())
    standout_final = [state for state, count in likely_states.items() if count == most_recurrent]

    return standout_final



def prev_get_observation_most_likely_states(self, observations:list,z_score:float=8)->list:
        """
        Given a z_scores (usually around 8) and observations likelihood, which state do we expect?
        Consider expected states through all modality and return most recurrent states
        """
        likely_states = {}
        standout_indices = []
        print('observations',observations)
        for modality, ob in enumerate(observations):
                qo = self.get_A()[modality][ob]
                print('qo', qo.round(3))

                max_val = np.max(qo)
                max_idx = np.argmax(qo)
                second_max = np.partition(qo, -2)[-2]
                print('max_val', max_val,'second_max',second_max)
                # If max is significantly greater than second max → clearly dominant
                if max_val > 4 * second_max:
                        likely_states[max_idx] = likely_states.get(max_idx, 0) + 1
                        continue
                mad = median_abs_deviation(qo, scale='normal')
                if mad == 0:
                        # Handle one-hot like case
                        if np.count_nonzero(qo >= 1.0 ) == 1:
                                print('HERE')
                                idx = np.argmax(qo)
                                likely_states[idx] = likely_states.get(idx, 0) + 1
                        continue  # No variability, nothing stands out
                
                
                median = np.median(qo)
                print('mad median', mad, median,)
                z_scores = (qo - median) / mad
                indices = np.where(z_scores > z_score)[0]
                np.set_printoptions(suppress=True)
                print('zscore and indices',z_scores.round(1), indices)
                for state in indices:
                        likely_states[state] = likely_states.get(state, 0) + 1
                  
        if len(likely_states) == 0:
                return []
        most_recurrent_state = max(likely_states.values())
        standout_indices = [key for key, value in likely_states.items() if value == most_recurrent_state]

        print(likely_states)
        return standout_indices

In [13]:
def update_C_dim(self):
    if self.C is not None:
        for m in range(self.num_modalities):
            if self.C[m].shape[0] < self.num_obs[m]:
                self.C[m] = np.append(self.C[m], [0]*(self.num_obs[m]- self.C[m].shape[0]))
                self.Cs = np.append(self.Cs, [0]*(self.num_states- len(self.Cs)))

In [14]:
def update_preference(self, obs:list=[-1,-1], pref_weight:float=1.0):
        """given a list of observations (must fill all expected observation. If no preference enters -1) we fill C with thos as preference. 
        If we have a partial preference over several observations, 
        then the given observation should be an integer < 0, the preference will be a null array 
        """
        if isinstance(obs, list):
            self.update_A_dim_given_obs(obs, null_proba=[False]*len(obs))

            C = self._construct_C_prior()
            Cs = np.zeros(self.num_states)

            for modality, ob in enumerate(obs):
                if ob >= 0:
                    self.preferred_ob[modality] = ob
                    ob_processed = utils.process_observation(ob, 1, [self.num_obs[modality]])
                    ob = utils.to_obj_array(ob_processed)
                else:
                    ob = utils.obj_array_zeros([self.num_obs[modality]])
                C[modality] = np.array(ob[0])

            if not isinstance(C, np.ndarray):
                raise TypeError(
                    'C vector must be a numpy array'
                )
            C = C * pref_weight
            self.C = utils.to_obj_array(C)


            assert len(self.C) == self.num_modalities, f"Check C vector: number of sub-arrays must be equal to number of observation modalities: {self.num_modalities}"

            for modality, c_m in enumerate(self.C):
                assert c_m.shape[0] == self.num_obs[modality], f"Check C vector: number of rows of C vector for modality {modality} should be equal to {self.num_obs[modality]}"
        else:
            self.preferred_ob = [-1,-1]
            self.C = self._construct_C_prior()

        desired_states = get_observation_most_likely_states(observations=obs, per_threshold=0.45)
        for state in desired_states:
            Cs[state] = 1.0
        self.Cs = Cs
        

In [30]:
def infer_inductive_preference(self, current_qs:np.ndarray, qs_pi:np.ndarray, C=None)->float:
        """ given the observation belief of a state, what is the utility term
        NOTE: THIS WORKS FOR VANILLA MODEL ONLY (NOT MMP) AS WE CONSIDER QS to have only 1 step 
        """
        if C is None:
            C = self.C
        current_qs = current_qs[0]
        qs_arg_max = np.argmax(current_qs)  #NOTE: not sure this is ideal...
        model_B = self.get_B()[:,qs_arg_max,:]
        median = np.median(model_B)
        B = model_B[model_B > median] 
        certitude_threshold = max(np.mean(B), 0.15)
        I = [copy.copy(self.Cs)]
        #Keep only certain Transitions
        B_certain_trans = (self.get_B() > certitude_threshold).astype(float)

        found_path = False
        # print('START with preferred state', np.argmax(I), I)
        # print('we are in starting state', qs_arg_max, 'prob',current_qs[qs_arg_max])
        # print('qs_pi',qs_pi)
        #from preferred states, which states lead to it then we repeat until we are in qs
        for n in range(self.num_steps):
            #TODO: ADD WHEN WE STOP FOR LOOP (WHEN WE ARE ON CURRENT STATE)
            I_next = ((B_certain_trans.T.dot(I[-1])) > 0).astype(float) # New reachable states (as bool -> float)
            I_next = np.max(I_next, axis=0) #We consider all states regardless of action
            # print('backward step', len(I)-1)
            I.append(I_next)
            #logging.info(f'States to inflate H {np.argwhere(I[n] > np.amin((I[n] >0).astype(float))).flatten()}')
            if I[-1][qs_arg_max] >= current_qs[qs_arg_max]:
                # print('we end process induction in ',n+1,' steps,current state', qs_arg_max)
                n-=1
                found_path = True
                break
        n+=1 #to consider that I[0] is goal
        if found_path:
            logging.info(f'Final States to inflate H {np.argwhere(I[n] > np.amin((I[n] >0).astype(float))).flatten()}')
            H = np.log(certitude_threshold)*I[n].dot(qs_pi[0])
        else:
            H = 0.0

        return H

In [16]:
ours = Ours_V5_RW(num_obs=2, num_states=2, n_actions=13, influence_radius=0.5,robot_dim=0.3, lookahead_node_creation=2)

In [17]:
stop

NameError: name 'stop' is not defined

# Update model

0 ,   1 ,       2 ,  4,   5 ,  7,  9,    15 ,    35, 39

_, [2,7,12-14], [1],[8], [0], [24], [11,40] , [44,45], [60], [33]

### TEst init MCTS

In [None]:
obstacles = [1.4612406492233276, 2.0710742473602295, 2.7915122509002686, 2.571498155593872, 2.9910919666290283, 2.616743803024292, 2.5963499546051025, 1.2535045146942139, 0.49500948190689087, 0.5026026964187622, 1.1839525699615479, 1.4570804834365845]

ob_id = 0

ours = Ours_V5_RW(num_obs=2, num_states=2, dim=2, \
                    observations=[ob_id], lookahead_policy=5,\
                    n_actions=13, influence_radius=1,\
                    robot_dim=0.3, lookahead_node_creation= 2)


In [None]:
ours.update_transition_nodes(obstacle_dist_per_actions=obstacles)

In [None]:
ours.update_preference([0,-1], pref_weight=1)
ours.C

In [None]:
ours.use_states_info_gain = True 
ours.use_utility = False
ours.use_param_info_gain = False #if true, do not use with the other terms

In [None]:
action, data = ours.define_actions_from_MCTS_run(num_step=1, observations=[0])

In [None]:
ours.action = np.array([12])

In [None]:
ours.get_belief_over_states()[0].round(3)

In [None]:
prior = ours.infer_states([0,0], save_hist=False)
prior

In [None]:
data

## TEST OUR MODEL ADJACENT STATE EFE

In [38]:
MODEL_PATH = '/home/idlab332/workspace/ros_ws/tests/big_ware/0/model.pkl' # Path to your pickled model
NUM_SIMULATIONS = 30  # Number of MCTS simulations per planning step
NUM_STEPS = 1      # Number of actions to take in the environment
MAX_ROLLOUT_DEPTH = 10 # Maximum depth for the simulation (rollout) phase
C_PARAM = 5
PLOT_TREE = True      # Whether to plot the MCTS tree after each planning step
pref_weight = 10
pref_obs = [11,-1]
vanilal_model = pickle_load_model(MODEL_PATH)
H_model = pickle_load_model(MODEL_PATH)

2025-06-17 18:44:57,536 - INFO - Model successfully loaded from: /home/idlab332/workspace/ros_ws/tests/big_ware/0/model.pkl
2025-06-17 18:44:57,685 - INFO - Model successfully loaded from: /home/idlab332/workspace/ros_ws/tests/big_ware/0/model.pkl


SET A GOAL

    #39 - 1step state33, 35 - 2 steps state 60, 30- 3steps state3, 7- 5 step - state24

In [39]:
for model in [vanilal_model, H_model]:
    #GOAL TESTS
    #print('np.mean(H_model.get_B())', np.mean(H_model.get_B()),1/model.get_B()[0].shape[0] )
    # model.certitude_transition_threshold  = np.mean(H_model.get_B())
    model.num_steps = MAX_ROLLOUT_DEPTH
    model.Cs = np.zeros(model.num_states)


    model.goal_oriented_navigation(pref_obs, pref_weight = pref_weight)
    model.use_utility = True
    obstacle_dist_per_actions = [4.507089614868164, 4.789198398590088, 4.365529537200928, 2.7395713329315186, 2.3621973991394043, 1.7037241458892822, 1.7129298448562622, 2.037290573120117, 1.3319873809814453, 6.884044647216797, 5.011831283569336, 4.510308742523193]
    possible_actions = model.define_next_possible_actions(obstacle_dist_per_actions, restrictive=True)
    desired_states = get_observation_most_likely_states(model,observations=pref_obs, per_threshold=0.45)
    for state in desired_states:
        model.Cs[state] = 1.0
    print('desired_states',desired_states)
    print('1 model setup')

# underlying_model.use_states_info_gain = True

desired_states [15]
1 model setup
desired_states [15]
1 model setup


### VANILLA

In [None]:
# Create the MCTS algorithm instance
vanilal_model.use_inductive_inference = False
mcts = MCTS(vanilal_model, c_param=C_PARAM, num_simulation=NUM_SIMULATIONS) # Adjust c_param if needed

# Get action names for logging
map_action_names = vanilal_model.get_possible_actions() # Assuming pose 0 exists

# Define the initial state
initial_pose_id = 53 # Or get from your model/environment
initial_belief_qs = vanilal_model.get_belief_over_states() # Get initial belief
initial_observation = vanilal_model.get_expected_observation(initial_belief_qs)
# Root node has no parent and no action leading to it
root_node = Node(state_qs=initial_belief_qs,
                pose_id=initial_pose_id,
                parent=None,
                action_index=None,
                observation=initial_observation, 
                possible_actions=possible_actions)

logging.info(f"===== Initial Root Node ID: {root_node.id} =====")

# --- Simulation Loop ---
current_node = root_node
data = {"qs": initial_belief_qs[0],
            "qpi": [],
            "efe": [],
            "info_gain": [],
            "utility": [],
            #"bayesian_surprise": utils.bayesian_surprise(posterior[0].copy(), prior),
            }

2025-06-17 18:28:23,966 - INFO - MCTS_Model_Interface initialized with model type: <class 'map_dm_nav.model.V5.Ours_V5_RW'>
2025-06-17 18:28:23,968 - INFO - MCTS initialized with exploration parameter c=5, num_simus=30, max_depth=10, policy_alpha=16.0,  action_selection=stochastic
2025-06-17 18:28:23,975 - INFO - ===== Initial Root Node ID: 53 =====


In [None]:
for i in range(NUM_STEPS):
    logging.info(f"\n===== Planning Step {i+1}/{NUM_STEPS} =====")
    logging.info(f"Current State (Node ID): {current_node.id}")

    # Plan the next action using MCTS
    # The root of the search is the current state node
    best_action, data = mcts.plan(current_node, NUM_SIMULATIONS, MAX_ROLLOUT_DEPTH, logging=logging, data=data)

    if best_action is None:
        logging.error("MCTS failed to find a best action. Stopping simulation.")
        break

    action_name = map_action_names.get(best_action, "Unknown")
    logging.info(f"Selected Action: {best_action} ({action_name})")

    # Display action values/visits from the root node
    if current_node.childs:
        child_info_list = []
        for action_id, child in current_node.childs.items():
                a_name = map_action_names.get(action_id, "?")
                child_info_list.append(f"  Action {action_id} ({a_name}): state={child.id} AvgR={child.get_averaged_reward():.3f}, N={child.N}")
        logging.info("Root Node Children Details:\n" + "\n".join(child_info_list))
        print('Action visit counts:', [current_node.childs[action_id].N for action_id in current_node.childs])
    else:
            logging.info("Root node has no children explored.")

    # Visualize the tree if enabled
    if PLOT_TREE:
        plot_mcts_tree(current_node)

    # --- Execute the selected action ---
    # In a real robot, this would involve sending the command and getting sensor feedback.
    # Here, we transition to the corresponding child node in the tree.
    if best_action in current_node.childs:
        next_node = current_node.childs[best_action]
        logging.info(f"Executing action {best_action} -> Transitioning to Node {next_node.id}")

        # IMPORTANT: Detach the chosen next state from its parent (the previous state).
        # This makes the chosen next state the new root for the *next* planning step
        # and allows the old parts of the tree to be garbage collected.
        
        #WITH MEMORY
        # next_node.detach_parent()
        #current_node = next_node # Update the current state

        #TMP TO TEST AS IN OUR MODEL

        current_node = Node(state_qs=next_node.state_qs,
                    pose_id=next_node.pose_id,
                    parent=None,
                    action_index=None,
                    observation=next_node.observation, 
                    possible_actions=next_node.possible_actions)

        # Log information about the new state (optional)
        # logging.info(f"New State Observation (Visual): {current_node.observation[0][0].round(2)}")
        # logging.info(f"New State Observation (Pose): {current_node.observation[0][1].round(2)}")

    else:
        logging.error(f"Consistency Error: Best action {best_action} not found in children of node {current_node.id}. Stopping.")
        break # Stop if the tree is inconsistent

logging.info(f"\n===== Simulation Finished =====")
logging.info(f"Completed {i+1 if 'i' in locals() else 0} steps.")
logging.info(f"Final State (Node ID): {current_node.id}")

2025-06-17 18:28:24,004 - INFO - 
===== Planning Step 1/1 =====
2025-06-17 18:28:24,007 - INFO - Current State (Node ID): 53
2025-06-17 18:28:24,008 - INFO - Starting MCTS planning from root node 53 for 30 simulations.
2025-06-17 18:28:24,010 - INFO - --- Simulation 1/30 ---
2025-06-17 18:28:24,011 - INFO - --- Selection Phase End (Selected Node: 53) ---
2025-06-17 18:28:24,023 - INFO - from node 53 -> Child Node 33, expanding with action 0(Initial full=-9.965, G=-9.965 H=0.000)
2025-06-17 18:28:24,037 - INFO - from node 53 -> Child Node 57, expanding with action 1(Initial full=-9.973, G=-9.973 H=0.000)
2025-06-17 18:28:24,049 - INFO - from node 53 -> Child Node 58, expanding with action 2(Initial full=-9.971, G=-9.971 H=0.000)
2025-06-17 18:28:24,063 - INFO - from node 53 -> Child Node 51, expanding with action 9(Initial full=-9.957, G=-9.957 H=0.000)
2025-06-17 18:28:24,072 - INFO - from node 53 -> Child Node 53, expanding with action 12(Initial full=-9.990, G=-9.990 H=0.000)
2025-06

Action visit counts: [21, 2, 1, 4, 1]


  plt.show()
2025-06-17 18:28:30,243 - INFO - Executing action 0 -> Transitioning to Node 33
2025-06-17 18:28:30,244 - INFO - 
===== Simulation Finished =====
2025-06-17 18:28:30,245 - INFO - Completed 1 steps.
2025-06-17 18:28:30,246 - INFO - Final State (Node ID): 33


### H model

In [40]:
# Create the MCTS algorithm instance
H_model.use_inductive_inference = True
mcts = MCTS(H_model, c_param=C_PARAM, num_simulation=NUM_SIMULATIONS) # Adjust c_param if needed

# Get action names for logging
map_action_names = H_model.get_possible_actions() # Assuming pose 0 exists

# Define the initial state
initial_pose_id = 53 # Or get from your model/environment
initial_belief_qs = H_model.get_belief_over_states() # Get initial belief
initial_observation = H_model.get_expected_observation(initial_belief_qs)
# Root node has no parent and no action leading to it
root_node = Node(state_qs=initial_belief_qs,
                pose_id=initial_pose_id,
                parent=None,
                action_index=None,
                observation=initial_observation, 
                possible_actions=possible_actions)

logging.info(f"===== Initial Root Node ID: {root_node.id} =====")

# --- Simulation Loop ---
current_node = root_node
data = {"qs": initial_belief_qs[0],
            "qpi": [],
            "efe": [],
            "info_gain": [],
            "utility": [],
            #"bayesian_surprise": utils.bayesian_surprise(posterior[0].copy(), prior),
            }

2025-06-17 18:45:37,285 - INFO - MCTS_Model_Interface initialized with model type: <class 'map_dm_nav.model.V5.Ours_V5_RW'>
2025-06-17 18:45:37,288 - INFO - MCTS initialized with exploration parameter c=5, num_simus=30, max_depth=10, policy_alpha=16.0,  action_selection=stochastic
2025-06-17 18:45:37,295 - INFO - ===== Initial Root Node ID: 53 =====


In [53]:
for i in range(NUM_STEPS):
    logging.info(f"\n===== Planning Step {i+1}/{NUM_STEPS} =====")
    logging.info(f"Current State (Node ID): {current_node.id}")

    # Plan the next action using MCTS
    # The root of the search is the current state node
    best_action, data = mcts.plan(current_node, NUM_SIMULATIONS, MAX_ROLLOUT_DEPTH, logging=logging, data=data)

    if best_action is None:
        logging.error("MCTS failed to find a best action. Stopping simulation.")
        break

    action_name = map_action_names.get(best_action, "Unknown")
    logging.info(f"Selected Action: {best_action} ({action_name})")

    # Display action values/visits from the root node
    if current_node.childs:
        child_info_list = []
        for action_id, child in current_node.childs.items():
                a_name = map_action_names.get(action_id, "?")
                child_info_list.append(f"  Action {action_id} ({a_name}): state={child.id} AvgR={child.get_averaged_reward():.3f}, N={child.N}")
        logging.info("Root Node Children Details:\n" + "\n".join(child_info_list))
        print('Action visit counts:', [current_node.childs[action_id].N for action_id in current_node.childs])
    else:
            logging.info("Root node has no children explored.")

    # Visualize the tree if enabled
    if PLOT_TREE:
        plot_mcts_tree(current_node)

    # --- Execute the selected action ---
    # In a real robot, this would involve sending the command and getting sensor feedback.
    # Here, we transition to the corresponding child node in the tree.
    if best_action in current_node.childs:
        next_node = current_node.childs[best_action]
        logging.info(f"Executing action {best_action} -> Transitioning to Node {next_node.id}")

        # IMPORTANT: Detach the chosen next state from its parent (the previous state).
        # This makes the chosen next state the new root for the *next* planning step
        # and allows the old parts of the tree to be garbage collected.
        
        #WITH MEMORY
        # next_node.detach_parent()
        #current_node = next_node # Update the current state

        #TMP TO TEST AS IN OUR MODEL

        current_node = Node(state_qs=next_node.state_qs,
                    pose_id=next_node.pose_id,
                    parent=None,
                    action_index=None,
                    observation=next_node.observation, 
                    possible_actions=next_node.possible_actions)

        # Log information about the new state (optional)
        # logging.info(f"New State Observation (Visual): {current_node.observation[0][0].round(2)}")
        # logging.info(f"New State Observation (Pose): {current_node.observation[0][1].round(2)}")

    else:
        logging.error(f"Consistency Error: Best action {best_action} not found in children of node {current_node.id}. Stopping.")
        break # Stop if the tree is inconsistent

logging.info(f"\n===== Simulation Finished =====")
logging.info(f"Completed {i+1 if 'i' in locals() else 0} steps.")
logging.info(f"Final State (Node ID): {current_node.id}")

2025-06-17 18:50:24,783 - INFO - 
===== Planning Step 1/1 =====
2025-06-17 18:50:24,784 - INFO - Current State (Node ID): 31
2025-06-17 18:50:24,785 - INFO - Starting MCTS planning from root node 31 for 30 simulations.
2025-06-17 18:50:24,786 - INFO - --- Simulation 1/30 ---
2025-06-17 18:50:24,787 - INFO - --- Selection Phase End (Selected Node: 31) ---
2025-06-17 18:50:24,801 - INFO - Final States to inflate H [15]
2025-06-17 18:50:24,802 - INFO - from node 31 -> Child Node 39, expanding with action 0(Initial full=-8.753, G=-8.985 H=0.231)
2025-06-17 18:50:24,816 - INFO - Final States to inflate H [15]
2025-06-17 18:50:24,818 - INFO - from node 31 -> Child Node 15, expanding with action 1(Initial full=-9.259, G=-9.395 H=0.136)
2025-06-17 18:50:24,834 - INFO - Final States to inflate H [15]
2025-06-17 18:50:24,836 - INFO - from node 31 -> Child Node 14, expanding with action 2(Initial full=-9.925, G=-9.934 H=0.009)
2025-06-17 18:50:24,851 - INFO - Final States to inflate H [15]
2025-0

Action visit counts: [6, 4, 6, 9, 4]


  plt.show()
2025-06-17 18:50:34,760 - INFO - Executing action 3 -> Transitioning to Node 30
2025-06-17 18:50:34,762 - INFO - 
===== Simulation Finished =====
2025-06-17 18:50:34,763 - INFO - Completed 1 steps.
2025-06-17 18:50:34,764 - INFO - Final State (Node ID): 30


In [None]:
H_model.get_B()[3,17,:].round(2)

array([0.02, 0.  , 0.02, 0.  , 0.  , 0.  , 0.85, 0.  , 0.  , 0.02, 0.  ,
       0.02, 0.  ])

In [None]:
stop

NameError: name 'stop' is not defined

## TEST 1 SQUARE MOTION NO OB - 

### init

In [None]:
import pandas as pd
csvfile = pd.read_csv('/home/idlab332/workspace/ros_ws/pose_obs_test1_v3.csv')
# csvfile

In [None]:
#START POSE
ob_id = 0 #csvfile['ob_id'].values[-1]
odom_theta = float(csvfile['theta'].values[0])
odom = [0.0,0.0, odom_theta]

p_idx = ours.PoseMemory.pose_to_id(odom)
obstacle_dists = eval(csvfile['ob_dists'].values[0])

In [None]:
qs = ours.get_belief_over_states() #self.qs
# agent_state_mapping for TEST PURPOSES
ours.update_agent_state_mapping(tuple(odom[:2]), [ob_id,p_idx], qs[0])
ours.update_transition_nodes(obstacle_dist_per_actions=obstacle_dists)

### step 1

In [None]:
#STEP1 
action_step1 = 10
ob_id_step1 = 1 
excel_step = 1
# p_id = int(csvfile['state'].values[excel_step])
pose_x = float(csvfile['Pose x for agent'].values[excel_step])
pose_y = float(csvfile['Pose Y for agent'].values[excel_step])
theta = float(csvfile['theta'].values[excel_step])
p_idx_step1 = ours.PoseMemory.pose_to_id([pose_x,pose_y], save_in_memory=False)
pose = ours.PoseMemory.id_to_pose(p_idx_step1)

odom_step1 = [pose[0],pose[1], theta]

obstacle_dists_step1 = eval(csvfile['ob_dists'].values[excel_step])

In [None]:
pose_x, pose_y, odom_step1, p_idx_step1

In [None]:
ours.agent_manual_action_step(action_step1)
ours.agent_step_update(action_step1, [ob_id_step1,p_idx_step1,obstacle_dists_step1])

# ours.PoseMemory.update_odom_given_pose(odom_step1[:2])
# ours.current_pose = ours.PoseMemory.get_odom(as_tuple=True)
# ours.update_A_dim_given_obs([ob_id_step1,p_idx_step1], null_proba=[False,False])
# Qs = ours.infer_states([ob_id_step1,p_idx_step1], save_hist=False)
# print('prior on believed state; action', ours.action, 'colour_ob:', ob_id_step1, 'inf pose:',odom_step1,'belief:', Qs[0].round(3))
            
# ours.update_believes_with_obs(Qs,action=action_step1, obs=[ob_id_step1,p_idx_step1])
# qs = ours.get_belief_over_states()

# print('qs',qs[0].round(3))
# # # agent_state_mapping for TEST PURPOSES
# ours.update_agent_state_mapping(tuple(odom_step1[:2]), [ob_id_step1,p_idx_step1], qs[0])
# ours.update_transition_nodes(obstacle_dist_per_actions=obstacle_dists_step1)

### step 2

In [None]:
#STEP2
action_step2 = 8
ob_id_step2 = 2 
excel_step = 2

pose_x = float(csvfile['Pose x for agent'].values[excel_step])
pose_y = float(csvfile['Pose Y for agent'].values[excel_step])
theta = float(csvfile['theta'].values[excel_step])
p_idx_step2 = ours.PoseMemory.pose_to_id([pose_x,pose_y], save_in_memory=False)
pose = ours.PoseMemory.id_to_pose(p_idx_step2)
odom_step2 = [pose[0],pose[1], theta]

obstacle_dists_step2 = eval(csvfile['ob_dists'].values[excel_step])
ob_id_step2 ,p_idx_step2, 

In [None]:
ours.agent_manual_action_step(action_step2)
ours.agent_step_update(action_step2, [ob_id_step2,p_idx_step2,obstacle_dists_step2])

### step 3

In [None]:
#STEP3
action_step3 = 4
ob_id_step3 = 3
excel_step = 3

pose_x = float(csvfile['Pose x for agent'].values[excel_step])
pose_y = float(csvfile['Pose Y for agent'].values[excel_step])
theta = float(csvfile['theta'].values[excel_step])
p_idx_step3 = ours.PoseMemory.pose_to_id([pose_x,pose_y], save_in_memory=False)
pose = ours.PoseMemory.id_to_pose(p_idx_step3)
odom_step3 = [pose[0],pose[1], theta]
obstacle_dists_step3 = eval(csvfile['ob_dists'].values[excel_step])

In [None]:
ours.agent_manual_action_step(action_step3)
ours.agent_step_update(action_step3, [ob_id_step3,p_idx_step3,obstacle_dists_step3])

### step 4

In [None]:
#STEP4 NO OB
action_step4 = 2
ob_id_step4 = 0 
excel_step = 0

#STEP4 WITH NEW OB
# action_step4 = 2
# ob_id_step4 = 4 
# excel_step = 4

pose_x = float(csvfile['Pose x for agent'].values[excel_step])
pose_y = float(csvfile['Pose Y for agent'].values[excel_step])
theta = float(csvfile['theta'].values[excel_step])
p_idx_step4 = ours.PoseMemory.pose_to_id([pose_x,pose_y], save_in_memory=False)
pose = ours.PoseMemory.id_to_pose(p_idx_step4)
odom_step4 = [pose[0],pose[1], theta]
obstacle_dists_step4= eval(csvfile['ob_dists'].values[excel_step])

In [None]:
ours.agent_manual_action_step(action_step4)
ours.agent_step_update(action_step4, [ob_id_step4,p_idx_step4,obstacle_dists_step4])

# PLOT

In [None]:
ours.agent_state_mapping

In [None]:
print(ours.B[0].shape)
ours.agent_state_mapping

In [None]:
ours.A[0][0]

In [None]:
A0 = plot_likelihood(ours.A[0], ours.agent_state_mapping, tittle_add='observation')
A1 = plot_likelihood(ours.A[1], ours.agent_state_mapping, tittle_add='pose')

In [None]:
plot = plot_state_in_map(ours.B[0], ours.agent_state_mapping)

In [None]:
from map_dm_nav.visualisation_tools import pickle_dump_model
from pathlib import Path

# pickle_dump_model(ours, store_path=Path('/home/idlab332/workspace/ros_ws/src/map_dm_nav/map_dm_nav'))

In [None]:

pose_id = 0
visit = 0
from pathlib import Path
import os
a = str(pose_id) + '_' + str(visit)
store_path = Path.cwd() / 'tests' /  a

store_path
store_path = str(store_path)
os.path.exists(store_path)

In [None]:
plot_transitions(H_model.B[0], H_model.agent_state_mapping, H_model.possible_actions)

In [None]:

actions_plots = plot_transitions_per_actions(H_model.B[0], H_model.agent_state_mapping, H_model.possible_actions)
# list = plt.get_fignums()
plt.tight_layout(pad=35)
plt.show()

In [None]:
plt.close()

In [None]:
plt.show()

In [None]:
B_plot = plot_transitions(ours.B[0], ours.agent_state_mapping, ours.possible_actions)

In [None]:
B_plot_compare = compare_B1_B2_plots(B_v1_test1_ob, V_v1_test1_ob, ours.agent_state_mapping, ours.possible_actions)

# SUB TESTS

In [None]:
def part1(self,qs):

    qs_arg_max = np.argmax(qs)  #NOTE: not sure this is ideal...
    model_B = self.get_B()[:,qs_arg_max,:]
    median = np.median(model_B)
    B = model_B[model_B > median] 
    certitude_threshold = max(np.mean(B), 0.15)
    return certitude_threshold, qs_arg_max

In [None]:
def part2(qs, B_certain_trans, I):
    qs_arg_max = np.argmax(qs)
    for n in range(4):
        
        I_next = ((B_certain_trans.T @ I[-1]) > 0).astype(float) # New reachable states (as bool -> float)
        I_next = np.sum(I_next, axis=0) #We consider all states regardless of action
        I.append(I_next)
        if I[-1][qs_arg_max] >= qs[qs_arg_max]:
            print('we end process induction in ',n+1,' steps,current state', qs_arg_max)
            n-=1
            break
    n+=1
    return I, n

In [None]:
qs = H_model.get_belief_over_states()[0]
next_sim_qs = [np.array([
        0.00297072, 0.0029671 , 0.00296772, 0.00213682, 0.00216489,
        0.00216486, 0.00203948, 0.00276972, 0.0029765 , 0.00203766,
        0.00203773, 0.00204936, 0.00298157, 0.00300263, 0.00296908,
        0.00203766, 0.00296783, 0.0029674 , 0.00188876, 0.00296883,
        0.00241505, 0.00203808, 0.00298006, 0.00296959, 0.00203822,
        0.0029696 , 0.0021555 , 0.00297541, 0.00297485, 0.00296827,
        0.00203807, 0.00203807, 0.00204936, 0.85620934, 0.00206093,
        0.00206093, 0.00206093, 0.00206093, 0.00206092, 0.00206093,
        0.00206093, 0.00206092, 0.00206092, 0.00206093, 0.00206093,
        0.00206091, 0.00206092, 0.00206092, 0.00206093, 0.00206092,
        0.00206092, 0.00206093, 0.00206092, 0.00206092, 0.00206092,
        0.00206092, 0.00206092, 0.00245272, 0.00205433, 0.00206053,
        0.00208947, 0.0019949 , 0.00206101])]
print('next state:', np.argmax(next_sim_qs))
# sim_qo_pi = H_model.get_expected_observation(next_sim_qs)


In [None]:
certitude_threshold, qs_arg_max = part1(H_model, qs)
I = [copy.copy(H_model.Cs)]
#Keep only certain Transitions
B_certain_trans = (H_model.get_B() > certitude_threshold).astype(float)

In [None]:
print('START with preferred state', np.argmax(I), I)
print('we are in starting state', qs_arg_max, 'prob',qs[qs_arg_max])
#from preferred states, which states lead to it then we repeat until we are in qs
I_end,m = part2(qs, B_certain_trans, copy.copy(I))
print('m', m)
for i, step_I in enumerate(I_end):
    print('step',i, step_I)
    print(np.argwhere(step_I > np.amin((step_I >0).astype(float))).flatten())
    # print('B*I', B_certain_trans.T.dot(step_I))


In [None]:
print('considered step',m)
H =  np.log(0.25)*I_end[m].dot(next_sim_qs[0])
print('H',H)

In [None]:
figures = plot_transitions_per_actions(B_certain_trans, H_model.agent_state_mapping, H_model.possible_actions, selected_actions = [9,10,11])
plt.tight_layout(pad=35)
plt.show()