In [2]:
import numpy as np
from collections import defaultdict
import random
from tqdm import trange
import copy
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import sys
import igraph
from matplotlib import cm, colors
random.seed(42)

In [4]:
class Agent:
    def __init__(self, num_obs=3, num_actions=1, pseudocount=1e-3):
        self.num_obs = num_obs
        self.groups_of_tables = {}
        self.table_totals = {}  # Keep track of totals for each table separately
        self.total_observations = 0  # Keep track of total observations across all tables
        self.count = np.ones((num_obs,num_obs,num_actions))
        self.TM = self.count.copy()
        self.C = np.zeros((num_obs,num_obs,num_actions))
        self.normalize_TM()
        self.initialize_clones(num_obs)
        self.pseudocount = pseudocount


    def initialize_clones(self,num_obs):
        for restaurant_id in range(num_obs):
            self.groups_of_tables[restaurant_id] = {}
            self.groups_of_tables[restaurant_id][0] = 1
    
    def normalize_TM(self):
        num_obs = np.shape(self.count)[0]
        num_actions = np.shape(self.count)[2]
        count = self.count.copy()
        for s in range(num_obs):
            for a in range(num_actions):
                self.TM[s, :, a] = count[s,:,a] / count[s,:,a].sum()   
                
    def add_clone(self, restaurant_id, table_id):
        """Add exactly one clone to a specified table, creating the table or group if necessary."""
        # Automatically create the group and table if they don't exist
        if restaurant_id not in self.groups_of_tables: # if this is the first observation
            self.groups_of_tables[restaurant_id] = {}
        if table_id not in self.groups_of_tables[restaurant_id]: # if this is the first clone of the observation
            self.groups_of_tables[restaurant_id][table_id] = 0  # Initialize clones count for the table

        # Add one clone to the table count and update total observations
        self.groups_of_tables[restaurant_id][table_id] += 1
        self.table_totals[(restaurant_id, table_id)] = self.groups_of_tables[restaurant_id][table_id]  # Update table total
        self.total_observations += 1

    def get_total_observations(self):
        """Return the total number of observations."""
        return self.total_observations

    def get_restaurant_total_customers(self, restaurant_id):
        """Return the total number of clones in all tables within a specific restaurant."""
        return sum(self.groups_of_tables.get(restaurant_id, {}).values())

    def get_table_total_customers(self, restaurant_id, table_id):
        """Return the total number of clones for a specific table."""
        return self.groups_of_tables.get(restaurant_id, {}).get(table_id, 0)


    def count_tables_in_restaurant(self, restaurant_id):
        """Returns the number of tables within the specified restaurant."""
        if restaurant_id in self.groups_of_tables:
            return len(self.groups_of_tables[restaurant_id])
        else:
            # print(f"Group {group_id} does not exist.")
            return 0

    def expand_split(self):
        # expand dimension
        orig_count = self.count.copy()
        orig_TM = self.TM.copy()
        orig_C = self.C.copy()
        n_prev_clones = np.shape(orig_count)[0]
        n_actions = np.shape(orig_count)[2]
        expanded_count = np.zeros((n_prev_clones+1, n_prev_clones+1, n_actions)) + self.pseudocount
        expanded_TM = np.zeros((n_prev_clones+1, n_prev_clones+1, n_actions)) + self.pseudocount
        expanded_C = np.zeros((n_prev_clones+1, n_prev_clones+1, n_actions)) + self.pseudocount
                
        # Copy the original matrix values into the top-left submatrix of the expanded matrix
        expanded_count[:n_prev_clones, :n_prev_clones, :] = orig_count
        expanded_TM[:n_prev_clones, :n_prev_clones, :] = orig_TM
        expanded_C[:n_prev_clones, :n_prev_clones, :] = orig_C
        
        self.count = expanded_count
        self.TM = expanded_TM
        self.C = expanded_C

            
    def update_count(self, state, state2, action): # updating counts when not splitted
        self.count[state, state2, action] += 1

    def update_count_fix(self, state, state2, action): # updating counts when not splitted
        self.count[state, state2, action] -= 1       
     
    def merged_likelihood(self, clone_map):
        
        merged_TM = self.TM.copy()
        
        curr_obs, clone_num, unique_idx = random.choice(clone_map)
        shared_obs_clones = [t for t in clone_map if t[0]==curr_obs and t[1] != clone_num] # all the clones that share the obs
        
        compare_clones = random.choice(shared_obs_clones) # pick a random clone that shares observation 

        state_a = unique_idx
        state_b = compare_clones[2]
        
        for s in range(len(clone_map)):
            if s != state_a and s != state_b:                
                merged_TM[state_a][s] += merged_TM[state_b][s]
                merged_TM[s][state_a] += merged_TM[s][state_b] 
                # Remove state_b by zeroing out probabilities (optional)
                merged_TM[state_b][s] = 0
                merged_TM[s][state_b] = 0
                
    def contingency(self, t, current_node, next_node):
        return prc


In [None]:
def contingency(prev_node, prev_action, curr_node, curr_action, next_node, t, sequence):
    return prc

In [5]:
example_sequences = [[1,2,3],
                     [4,2,5],
                     [1,2,3],
                     [4,2,5],
                     [1,2,3]]
# we want to see the exclusivity of 5 co-occurring with previous states
prev_node = 4
curr_node = 2
next_node = 5
curr_action = 0
prev_action = 0

# identify "curr node"


In [13]:
# Instead of storing only nodes like [1,2,3],
# store each element as (node, action).
# For illustration:
# example_sequences = [
#     [(1, 0), (2, 0), (3, 1)],
#     [(4, 0), (2, 0), (5, 0)],
#     [(1, 2), (2, 0), (3, 1)],
#     [(4, 0), (2, 1), (5, 0)],
#     [(1, 0), (2, 0), (3, 0)]
# ]


# Suppose these are the conditions we want to find in consecutive pairs:
prev_node = 4
prev_action = 0
curr_node = 2
curr_action = 0


In [28]:
# def find_contingency(
#     sequences,
#     prev_node,   # e.g., 4
#     prev_action, # e.g., 0
#     curr_node,   # e.g., 2
#     curr_action,  # e.g., 0, 
#     next_node
# ):
#     """
#     Return a list of (index, sequence) pairs where the given consecutive
#     pattern [(prev_node, prev_action), (curr_node, curr_action)] is found.
#     """
#     found_sequences = []
#     pr = 0
#     other_pr = 0
#     n_pr =0
    
#     # forward
#     for idx, seq in enumerate(sequences):
#         # seq is a list of (node, action) tuples

#         for i in range(len(seq) - 2):
#             # Check if the i-th and (i+1)-th elements match
#             node_i, action_i = seq[i]
#             node_i1, action_i1 = seq[i+1]
#             node_i2,_ = seq[i+2]
            
#             if (node_i1 == curr_node and action_i1 == curr_action and node_i2 == next_node):
#                 # print('here')
#                 if node_i == prev_node and action_i == prev_action: 
#                     pr += 1
#                 else: 
#                     other_pr += 1
#                 n_pr += 1        
#             # if (node_i == prev_node and action_i == prev_action and
#             #     node_i1 == curr_node and action_i1 == curr_action and node_i2 == next_node):
#             #         found_sequences.append((idx, seq))
#             #         break  # Once found in this sequence, move on to the next sequence
    
#     # backward
#     sr = 0
#     other_sr = 0
#     n_sr = 0
    
#     for idx, seq in enumerate(sequences):
#         # seq is a list of (node, action) tuples

#         for i in range(len(seq) - 2):
#             # Check if the i-th and (i+1)-th elements match
#             node_i, action_i = seq[i]
#             node_i1, action_i1 = seq[i+1]
#             node_i2,_ = seq[i+2]
   
#             if (node_i == prev_node and action_i == prev_action and
#                 node_i1 == curr_node and action_i1 == curr_action):
#                 if node_i2 == next_node:
#                     sr += 1
#                 else: 
#                     other_sr += 1
#                 n_sr += 1
#                     # found_sequences.append((idx, seq))
#                     # break  # Once found in this sequence, move on to the next sequence
         
    
    
#     # return found_sequences
#     # return prc/n if n!=0 else 0
#     return pr/n_pr, sr/n_sr

def find_contingency(
    sequences,
    prev_node,   # e.g., 4
    prev_action, # e.g., 0
    curr_node,   # e.g., 2
    curr_action,  # e.g., 0, 
    next_node,
    t=1
):
    """
    Return a list of (index, sequence) pairs where the given consecutive
    pattern [(prev_node, prev_action), (curr_node, curr_action)] is found.
    """
    found_sequences = []
    pr = 0
    other_pr = 0
    n_pr =0
    
    # forward
    for idx, seq in enumerate(sequences):
        # seq is a list of (node, action) tuples

        for i in range(t,len(seq)):
            # Check if the i-th and (i+1)-th elements match
            node_i, action_i = seq[i]
            node_i1, action_i1 = seq[i+1]
            node_i2,_ = seq[i+2]
            
            # prev_node_idx = i-t
            curr_node = i
            prev_to_next = seq[i-t:i+1]
            for n,nodes in enumerate(prev_to_next):
                 
                
            
            if (node_i1 == curr_node and action_i1 == curr_action and node_i2 == next_node):
                # print('here')
                if node_i == prev_node and action_i == prev_action: 
                    pr += 1
                else: 
                    other_pr += 1
                n_pr += 1        
            # if (node_i == prev_node and action_i == prev_action and
            #     node_i1 == curr_node and action_i1 == curr_action and node_i2 == next_node):
            #         found_sequences.append((idx, seq))
            #         break  # Once found in this sequence, move on to the next sequence
    
    # backward
    sr = 0
    other_sr = 0
    n_sr = 0
    
    for idx, seq in enumerate(sequences):
        # seq is a list of (node, action) tuples

        for i in range(len(seq) - 2):
            # Check if the i-th and (i+1)-th elements match
            node_i, action_i = seq[i]
            node_i1, action_i1 = seq[i+1]
            node_i2,_ = seq[i+2]
   
            if (node_i == prev_node and action_i == prev_action and
                node_i1 == curr_node and action_i1 == curr_action):
                if node_i2 == next_node:
                    sr += 1
                else: 
                    other_sr += 1
                n_sr += 1
                    # found_sequences.append((idx, seq))
                    # break  # Once found in this sequence, move on to the next sequence
         
    
    
    # return found_sequences
    # return prc/n if n!=0 else 0
    return pr/n_pr, sr/n_sr


In [31]:
example_sequences = [
    [(1, 0), (2, 0), (3, 1)],
    [(4, 0), (2, 0), (5, 0)],
    [(1, 2), (2, 0), (3, 1)],
    [(4, 0), (2, 0), (5, 0)],
    [(1, 0), (2, 0), (3, 0)], 
    # [(4, 0), (2, 0), (6, 0)],
    [(7, 0), (2, 0), (5, 0)]
]

pr, sr = find_contingency(
    example_sequences,
    prev_node=4,
    prev_action=0,
    curr_node=2,
    curr_action=0,
    next_node = 5
)
print(pr, sr)
# print("Sequences containing [(4, 0), (2, 0), (5)]:")
# for idx, seq in result:
#     print(f"  Index: {idx}, Sequence: {seq}")

# print ("Contingency between 4 and 5: {}".format(result))


0.6666666666666666 1.0


In [None]:
# state, action, next_state
example_sequences = [[1,2,0,4],
                     [1,3,0,5],
                     [3,2,0,4]]

In [None]:
def backward_contingency(sequences, prev_node, curr_node, condition):
    # curr_node: the state we're considering to clone (e.g. 2)
    # condition: the consequence to condition the clone on. e.g. reward/noreward: coded as 4,5
    overall_contingency = 0
    overall_contingency_n = 0
    conditioned_contingency = 0
    conditioned_contingency_n = 0
    for s, sequence in enumerate(sequences):
        for n, node in sequence:
            if node==curr_node: 
                # calculate the overall contingency (should end with r though) without cloning
                if (sequence[-1] == condition and sequence[n-1] == prev_node):
                    overall_contingency += 1
                elif (sequence[-1] == condition and sequence[n-1] != prev_node):
                    overall_contingency_n += 1
                # calculate the clone conditioned on outcome
                # if (sequence[-1] == condition and )
                
            # sequence[n-1]  
    

In [42]:
import numpy as np

class GridEnv:
    """
    A 3x3 grid with states numbered 1..9:
      1   2   3
      4   5   6
      7   8   9

    - State 5 is the 'cue' state.
    - State 9 is the 'goal' or 'end' state.
    - If the agent moves 8 -> 9 having visited 5 at least once, reward = +1.
    - If the agent moves 8 -> 9 without visiting 5 first, reward = -1.
    - Episode ends upon reaching state 9 (coming from state 8).
    """

    def __init__(self):
        # Map grid positions (row, col) -> state number
        self.pos_to_state = {
            (0,0): 1, (0,1): 2, (0,2): 3,
            (1,0): 4, (1,1): 5, (1,2): 6,
            (2,0): 7, (2,1): 8, (2,2): 9
        }
        # Inverse mapping: state -> (row, col)
        self.state_to_pos = {v: k for k, v in self.pos_to_state.items()}

        # Actions as integers: 0=up, 1=right, 2=down, 3=left
        self.actions = {
            0: (-1, 0),   # up
            1: (0, +1),   # right
            2: (+1, 0),   # down
            3: (0, -1)    # left
        }
        self.action_space = list(self.actions.keys())  # [0, 1, 2, 3]

        # Start/cue/goal states
        self.start_state = 1
        self.cue_state = 5
        self.goal_state = 9

        self.reset()

    def reset(self):
        """
        Reset environment to the start:
          - current_state = 1
          - visited_cue = False
        Returns the initial state (1).
        """
        self.current_state = self.start_state
        self.visited_cue = False
        return self.current_state

    def step(self, action):
        """
        Apply the action (one of 0,1,2,3).
        Returns (next_state, reward, done).
        """
        # Current position
        row, col = self.state_to_pos[self.current_state]

        # Action deltas
        dr, dc = self.actions[action]

        # Compute new row, col
        new_row = row + dr
        new_col = col + dc

        # If out of bounds, stay put
        if not (0 <= new_row <= 2):
            new_row = row
        if not (0 <= new_col <= 2):
            new_col = col

        next_state = self.pos_to_state[(new_row, new_col)]

        # Check if we visited the cue
        if next_state == self.cue_state:
            self.visited_cue = True

        # Calculate reward and done
        reward = 0
        done = False
        # Special case: transitioning 8 -> 9
        if self.current_state == 8 and next_state == 9:
            done = True
            reward = 1 if self.visited_cue else -1

        # Update environment
        self.current_state = next_state
        return next_state, reward, done


if __name__ == "__main__":
    env = GridEnv()

    # We'll store up to 50 steps (states, actions, rewards)
    max_steps = 50

    states = []
    actions = []
    rewards = []

    state = env.reset()
    for t in range(max_steps):
        # Pick a random action (0..3) for this demonstration
        action = np.random.choice(env.action_space)

        # Record current state, then step
        states.append(state)
        actions.append(action)
        
        next_state, reward, done = env.step(action)
        rewards.append(reward)
        
        state = next_state

        if done:
            states.append(state)
            print(f"Episode ended at step {t+1} (transitioned {states[-1]}->9). Reward={reward}")
            break

    print("\nRecorded trajectory:")
    print("States:", states)
    print("Actions:", actions)
    print("Rewards:", rewards)


Episode ended at step 10 (transitioned 9->9). Reward=1

Recorded trajectory:
States: [1, 4, 5, 4, 1, 1, 1, 4, 7, 8, 9]
Actions: [2, 1, 3, 0, 3, 3, 2, 2, 1, 1]
Rewards: [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]


In [39]:
import numpy as np

class GridEnv:
    """
    A 3x3 grid with states numbered 1..9:
      1   2   3
      4   5   6
      7   8   9

    - State 5 is the 'cue' state.
    - State 9 is the 'goal' or 'end' state.
    - If the agent moves 8 -> 9 having visited 5 at least once, reward = +1.
    - If the agent moves 8 -> 9 without visiting 5 first, reward = -1.
    - Episode ends upon reaching state 9 (coming from state 8).
    """

    def __init__(self):
        # Map grid positions (row, col) -> state number
        self.pos_to_state = {
            (0,0): 1, (0,1): 2, (0,2): 3,
            (1,0): 4, (1,1): 5, (1,2): 6,
            (2,0): 7, (2,1): 8, (2,2): 9
        }
        # Inverse mapping: state -> (row, col)
        self.state_to_pos = {v: k for k, v in self.pos_to_state.items()}

        # Define possible actions: up, down, left, right.
        # We'll store them in a dict { "up": (-1,0), ... }
        self.actions = {
            "up":    (-1, 0),
            "down":  (+1, 0),
            "left":  (0, -1),
            "right": (0, +1),
        }
        self.action_space = list(self.actions.keys())

        # We’ll fix the start state as 1 for clarity (top-left).
        # The 'cue' state is 5, the 'goal' state is 9.
        self.start_state = 1
        self.cue_state = 5
        self.goal_state = 9

        # Initialize environment
        self.reset()

    def reset(self):
        """
        Reset environment to the start:
          - current_state = 1
          - visited_cue = False
        Returns the initial state (1).
        """
        self.current_state = self.start_state
        self.visited_cue = False
        return self.current_state

    def step(self, action):
        """
        Apply the action (one of: "up", "down", "left", "right").
        Returns (next_state, reward, done).
        """
        # Current row, col
        row, col = self.state_to_pos[self.current_state]
        # Action deltas
        dr, dc = self.actions[action]

        # Compute new row, col
        new_row = row + dr
        new_col = col + dc
        # If it goes out of bounds, stay put
        if not (0 <= new_row <= 2):
            new_row = row
        if not (0 <= new_col <= 2):
            new_col = col

        next_state = self.pos_to_state[(new_row, new_col)]

        # Check if we visited the cue
        if next_state == self.cue_state:
            self.visited_cue = True

        # Calculate reward
        reward = 0
        done = False
        # Special case: transitioning 8 -> 9
        if self.current_state == 8 and next_state == 9:
            done = True
            # Reward depends on whether we've visited state 5
            if self.visited_cue:
                reward = +1
            else:
                reward = -1

        # Update environment
        self.current_state = next_state
        return next_state, reward, done


if __name__ == "__main__":
    env = GridEnv()

    # We will store a sequence of states and actions, up to length 50
    max_steps = 50

    states = []
    actions = []
    rewards = []

    state = env.reset()
    for t in range(max_steps):
        # Choose an action. Here, we pick randomly for illustration.
        action = np.random.choice(env.action_space)

        # Record the current state (before we step).
        states.append(state)
        actions.append(action)

        next_state, reward, done = env.step(action)
        rewards.append(reward)

        state = next_state

        if done:
            # Episode ended by reaching state 9 from state 8
            print(f"Episode ended at step {t+1}.")
            break

    # If we didn't break, we either ended or reached 50 steps
    print("States:", states)
    print("Actions:", actions)
    print("Rewards:", rewards)


Episode ended at step 6.
States: [1, 1, 1, 4, 5, 8]
Actions: ['up', 'up', 'down', 'right', 'down', 'right']
Rewards: [0, 0, 0, 0, 0, 1]


In [154]:
import numpy as np

class GridEnvRightDownNoSelf:
    """
    A 3x3 grid with states numbered 1..9:

        1   2   3
        4   5   6
        7   8   9

    - The agent can only move RIGHT (0) or DOWN (1).
    - No self-transitions at borders:
        If an action would go out of bounds, that action is not allowed
        from that state.
    - Same cue logic: must visit 5 for a +1 reward at 9, else goes to 10 with -1.
    """

    def __init__(self, cue_state=2):
        # Grid layout: (row, col) -> state
        self.pos_to_state = {
            (0,0): 1, (0,1): 2, (0,2): 3, (0,3): 4,
            (1,0): 5, (1,1): 6, (1,2): 7, (1,3): 8,
            (2,0): 9, (2,1): 10, (2,2): 11, (2,3): 12,
            (3,0): 13, (3,1): 14, (3,2): 15, (3,3): 16,
        }
        self.state_to_pos = {s: rc for rc, s in self.pos_to_state.items()}
        
        # Actions as integers: 0=right, 1=down
        # right => (0, +1)
        # down  => (+1, 0)
        self.base_actions = {
            0: (0, +1),   # right
            1: (+1, 0)    # down
        }
        
        # Instead of having a static [0,1] action space, we have a
        # per-state action set (no invalid moves).
        self.valid_actions = self._build_valid_actions()

        # Special states
        self.start_state = 1
        self.cue_state = cue_state
        self.rewarded_terminal = 16
        self.unrewarded_terminal = 17  # not in the grid, just a label

        self.reset()

    def _build_valid_actions(self):
        """
        Precompute valid actions for each non-terminal grid state.
        A 'valid' action is one that leads to a NEW in-bounds state.
        """
        valid_dict = {}
        for row in range(4):
            for col in range(4):
                s = self.pos_to_state[(row, col)]
                # We'll store all actions that yield a different state (no self-transitions)
                valid_dict[s] = []
                for a, (dr, dc) in self.base_actions.items():
                    new_r = row + dr
                    new_c = col + dc
                    if 0 <= new_r < 4 and 0 <= new_c < 4:
                        next_s = self.pos_to_state[(new_r, new_c)]
                        # Only count it if next_s != s (which can't happen in 3x3, but let's be explicit)
                        if next_s != s:
                            valid_dict[s].append(a)

        # For terminal states 9 and 10, no actions are valid
        # valid_dict[9] = []
        # valid_dict[10] = []
        valid_dict[16] = []
        valid_dict[17] = []

        return valid_dict

    def reset(self):
        """
        Reset environment to the start:
          - current_state=1
          - visited_cue=False
        Returns current_state (1).
        """
        self.current_state = self.start_state
        self.visited_cue = False
        return self.current_state

    def get_valid_actions(self, state=None):
        """
        Return the list of valid actions for the current state
        (or a given state).
        """
        if state is None:
            state = self.current_state
        return self.valid_actions[state]

    def step(self, action):
        """
        Step with a guaranteed valid action. If an invalid action is given,
        we can either ignore or raise an Exception. We'll raise an error.
        """
        if action not in self.get_valid_actions():
            raise ValueError(f"Action {action} is not valid from state {self.current_state}.")

        # If we're already in a terminal (9 or 10), episode is over.
        if self.current_state in [self.rewarded_terminal, self.unrewarded_terminal]:
            return self.current_state, 0, True

        # Move
        row, col = self.state_to_pos[self.current_state]
        dr, dc = self.base_actions[action]
        next_row = row + dr
        next_col = col + dc

        next_state = self.pos_to_state[(next_row, next_col)]

        # Check cue
        if next_state == self.cue_state:
            self.visited_cue = True

        reward = 0
        done = False

        # Terminal condition: 8 -> 9
        if next_state == 16:
            done = True
            if self.visited_cue:
                reward = +1
                next_state = self.rewarded_terminal
            else:
                reward = -1
                next_state = self.unrewarded_terminal

        # Update
        self.current_state = next_state
        return next_state, reward, done


if __name__ == "__main__":
    env = GridEnvRightDownNoSelf()

    max_steps = 20
    states = []
    actions = []
    rewards = []

    state = env.reset()
    for t in range(max_steps):
        # Record current state
        states.append(state)

        # Query valid actions
        vacts = env.get_valid_actions(state)
        if not vacts:
            # No valid actions => terminal or stuck. Break out.
            print("No valid actions, must be terminal.")
            break

        # Choose random from valid set
        action = np.random.choice(vacts)
        actions.append(action)

        next_state, reward, done = env.step(action)
        rewards.append(reward)
        state = next_state

        if done:
            # Also record the final state (9 or 10)
            states.append(state)
            print(f"Episode ended at step {t+1}, final state={state}, reward={reward}")
            break

    print("\nTrajectory:")
    print("States:", states)
    print("Actions:", actions)
    print("Rewards:", rewards)


Episode ended at step 6, final state=17, reward=-1

Trajectory:
States: [1, 5, 9, 10, 14, 15, 17]
Actions: [1, 1, 0, 1, 0, 0]
Rewards: [0, 0, 0, 0, 0, -1]


In [149]:
def generate_dataset(env, n_episodes=10, max_steps=20):
    """
    Run 'n_episodes' episodes in the environment. Each episode ends
    either when the environment signals 'done' or when we hit 'max_steps'.

    Returns:
        A list of (state_sequence, action_sequence) pairs.
        - state_sequence: list of visited states
        - action_sequence: list of chosen actions
    """
    dataset = []

    for episode_idx in range(n_episodes):
        # Prepare lists to store states & actions for this episode
        states = []
        actions = []

        # Reset env to start a new episode
        state = env.reset()

        for t in range(max_steps):
            states.append(state)

            valid_actions = env.get_valid_actions(state)
            if not valid_actions:
                # No valid actions => we must be in a terminal or stuck
                break

            # Example: pick a random valid action
            action = np.random.choice(valid_actions)
            actions.append(action)

            # Step in the environment
            next_state, reward, done = env.step(action)
            state = next_state

            if done:
                # Also record the final state
                states.append(state)
                break

        # Store (states, actions) for this episode
        dataset.append((states, actions))

    return dataset

# if __name__ == "__main__":
#     # Suppose env is your environment, e.g.:
#     env = GridEnvRightDownNoSelf()

#     n_episodes = 5
#     max_steps_per_episode = 10

#     my_dataset = generate_dataset(env, n_episodes, max_steps_per_episode)

#     for i, (states_seq, actions_seq) in enumerate(my_dataset):
#         print(f"Episode {i+1}:")
#         print("  States:", states_seq)
#         print("  Actions:", actions_seq)
#         print("  Length of episode:", len(actions_seq), "steps")
#         print()


def TM(dataset):
    """
    Given a dataset of episodes, each episode being (states_seq, actions_seq),
    build a 3D count matrix of shape [max_state+1, max_action+1, max_state+1].
    
    Returns:
        transition_counts (np.ndarray): counts[s, a, s_next]
            The number of times we observed (state=s) --(action=a)--> (next_state=s_next).
    """
    # 1) Collect all observed states and actions to determine indexing bounds
    all_states = set()
    all_actions = set()
    
    for (states_seq, actions_seq) in dataset:
        for s in states_seq:
            all_states.add(s)
        for a in actions_seq:
            all_actions.add(a)
    
    max_state = max(all_states) if all_states else 0
    max_action = max(all_actions) if all_actions else 0
    
    # 2) Initialize a 3D count array
    #    We'll assume states range from 0..max_state
    #    and actions range from 0..max_action
    transition_counts = np.zeros((max_state+1, max_action+1, max_state+1), dtype=int)
    
    # 3) Fill in the counts by iterating over each episode's transitions
    for (states_seq, actions_seq) in dataset:
        # for each step t in the episode
        for t in range(len(actions_seq)):
            s = states_seq[t]
            a = actions_seq[t]
            s_next = states_seq[t+1]
            transition_counts[s, a, s_next] += 1
    
    return transition_counts


In [157]:
# if __name__ == "__main__":
    # Suppose env is your environment, e.g.:
env = GridEnvRightDownNoSelf(cue_state=6)

n_episodes = 50
max_steps_per_episode = 10

dataset = generate_dataset(env, n_episodes, max_steps_per_episode)

for i, (states_seq, actions_seq) in enumerate(dataset):
    print(f"Episode {i+1}:")
    print("  States:", states_seq)
    print("  Actions:", actions_seq)
    print("  Length of episode:", len(actions_seq), "steps")
    print()


Episode 1:
  States: [1, 5, 9, 10, 11, 12, 17]
  Actions: [1, 1, 0, 0, 0, 1]
  Length of episode: 6 steps

Episode 2:
  States: [1, 2, 6, 10, 14, 15, 16]
  Actions: [0, 1, 1, 1, 0, 0]
  Length of episode: 6 steps

Episode 3:
  States: [1, 5, 6, 10, 11, 12, 16]
  Actions: [1, 0, 1, 0, 0, 1]
  Length of episode: 6 steps

Episode 4:
  States: [1, 2, 6, 10, 11, 15, 16]
  Actions: [0, 1, 1, 0, 1, 0]
  Length of episode: 6 steps

Episode 5:
  States: [1, 5, 6, 7, 8, 12, 16]
  Actions: [1, 0, 0, 0, 1, 1]
  Length of episode: 6 steps

Episode 6:
  States: [1, 5, 9, 13, 14, 15, 17]
  Actions: [1, 1, 1, 0, 0, 0]
  Length of episode: 6 steps

Episode 7:
  States: [1, 5, 6, 7, 8, 12, 16]
  Actions: [1, 0, 0, 0, 1, 1]
  Length of episode: 6 steps

Episode 8:
  States: [1, 2, 3, 4, 8, 12, 17]
  Actions: [0, 0, 0, 1, 1, 1]
  Length of episode: 6 steps

Episode 9:
  States: [1, 5, 6, 10, 11, 12, 16]
  Actions: [1, 0, 1, 0, 0, 1]
  Length of episode: 6 steps

Episode 10:
  States: [1, 5, 6, 7, 11, 15, 

In [214]:
# Let's assume you have a list of episodes, each episode is:
# (states_seq, actions_seq)
# episodes = [
#     # Episode 1
#     ([1, 5, 9, 10, 11, 12, 17], [1, 1, 0, 0, 0, 1]),
#     # Episode 2
#     ([1, 2, 6, 10, 14, 15, 16], [0, 1, 1, 1, 0, 0]),
#     # Episode 3
#     ([1, 5, 6, 10, 11, 12, 16], [1, 0, 1, 0, 0, 1]),
#     # Episode 4
#     ([1, 2, 6, 10, 11, 15, 16], [0, 1, 1, 0, 1, 0]),
#     # Episode 5
#     ([1, 5, 6, 7, 8, 12, 16], [1, 0, 0, 0, 1, 1]),
#     # Episode 6
#     ([1, 5, 9, 13, 14, 15, 17], [1, 1, 1, 0, 0, 0]),
#     # Episode 7
#     ([1, 5, 6, 7, 8, 12, 16], [1, 0, 0, 0, 1, 1]),
#     # Episode 8
#     ([1, 2, 3, 4, 8, 12, 17], [0, 0, 0, 1, 1, 1]),
#     # Episode 9
#     ([1, 5, 6, 10, 11, 12, 16], [1, 0, 1, 0, 0, 1]),
#     # Episode 10
#     ([1, 5, 6, 7, 11, 15, 16], [1, 0, 0, 1, 1, 0]),
# ]
all_states =[]
# episodes = dataset
for curr_dataset in dataset:
    s,a = curr_dataset
    all_states.append(s)
unique_states = np.unique(all_states)
    
    
def has_state(sequence, state):
    """Return True if the episode's state sequence contains state=5."""
    return state in sequence

def has_transition(s,sprime,sequence):
    """Return True if the episode's state sequence contains a transition 15->16."""
    for i in range(len(sequence) - 1):
        if sequence[i] == s and sequence[i + 1] == sprime:
            return True
    return False

# Counters
# for curr_state in 
# unique_states = [6]
s=12
sprime=16
sprime2 = 17
for curr_state in unique_states:
    if curr_state < s:
        episodes_with_state = 0
        # episodes_with_state_and_transition = 0
        other =0
        # curr_state = 6

        total = 0
        a=0
        b=0
        c=0
        d=0
        conditioned_contingency=0
        print("Current state: {}".format(curr_state))
        for (states_seq, actions_seq) in dataset:
            if has_state(states_seq,s):
                total += 1
                if has_state(states_seq, curr_state):
                
                    
                    # episodes_with_state += 1
                    if has_transition(s,sprime,states_seq): 
                        # episodes_with_state_and_transition += 1   
                        a += 1
                        # print('transition: {}'.format(states_seq))
                    elif has_transition(s,sprime2, states_seq): 
                        # print(states_seq)
                        b+=1
                else: 
                    # print('here')
                    if has_transition(s,sprime,states_seq): 
                        # episodes_with_state_and_transition += 1   
                        c += 1
                        # print('transition: {}'.format(states_seq))
                    elif has_transition(s,sprime2, states_seq): 
                        # print(states_seq)
                        d+=1
                assert total == a+b+c+d
        if a+b != 0: 
            print("forward contingency: {}".format(a/(a+b)))
        else: 
            print("no forward contingency")
        if c+d != 0: 
            print("backward contingency: {}".format(d/(c+d)))
        else: 
            print("no backward contingency")
        # print("Number of episodes that have state {}: {}".format(curr_state, total))
        # print("Number of those that also have transition {}->{}:{}".format(s,sprime,conditioned_contingency))

        # if episodes_with_state > 0:
        #     fraction = conditioned_contingency / total
        #     print("Fraction of episodes with {} that also contain {}->{}: {}".format(curr_state, s, sprime, fraction))
        # else:
        #     print("No episodes contain state={} at all.".format(curr_state))


Current state: 1
forward contingency: 0.6451612903225806
no backward contingency
Current state: 2
forward contingency: 0.4444444444444444
backward contingency: 0.07692307692307693
Current state: 3
forward contingency: 0.0
backward contingency: 0.047619047619047616
Current state: 4
forward contingency: 0.0
backward contingency: 0.2
Current state: 5
forward contingency: 0.9230769230769231
backward contingency: 0.5555555555555556
Current state: 6
forward contingency: 1.0
backward contingency: 1.0
Current state: 7
forward contingency: 0.7777777777777778
backward contingency: 0.5384615384615384
Current state: 8
forward contingency: 0.5
backward contingency: 0.15384615384615385
Current state: 9
forward contingency: 0.0
backward contingency: 0.3333333333333333
Current state: 10
forward contingency: 0.8571428571428571
backward contingency: 0.4166666666666667
Current state: 11
forward contingency: 0.8461538461538461
backward contingency: 0.5


In [212]:
print(total)
print(a,b,c,d)

31
20 0 0 11


In [196]:
unique_states

array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17])

In [168]:
for i, (states, actions) in enumerate(dataset):
    transition_counts = TM(dataset)

In [169]:
transition_counts[8,0,9]

0

In [170]:
transition_counts[8,0,10]

0

In [171]:
denominators = transition_counts.sum(axis=2, keepdims=True)
denominators[denominators == 0] = 1
transition_probs = transition_counts / denominators
denominators.shape

(18, 2, 1)

In [166]:
transition_probs[8,0,9]

0.4

In [167]:
transition_probs[8,0,10]

0.6

In [158]:
def find_stochastic_state_action_pairs(transition_probs, tol=1e-9):
    """
    Given transition_probs[s, a, s_next], returns a list of (s, a) pairs
    for which the transition distribution is *not* purely deterministic.
    
    A purely deterministic distribution means exactly one s' has probability=1.0,
    and the rest are 0. Everything else we call 'stochastic'.
    
    Args:
        transition_probs (np.ndarray): shape [S+1, A+1, S+1].
        tol (float): numerical tolerance for checking 1.0 or 0.0.
    
    Returns:
        A list of (s, a) pairs that have a stochastic (non-deterministic) distribution.
    """
    stochastic_pairs = []
    S, A, _ = transition_probs.shape

    for s in range(S):
        for a in range(A):
            dist = transition_probs[s, a]  # distribution over s'
            
            # Check if there's exactly one next state with prob ~ 1.0
            # We'll say "exactly one" if the max prob is close to 1, 
            # AND the sum of probabilities is close to 1 (should be if the data is well-formed),
            # AND exactly one entry is near 1 while all others are near 0.
            max_idx = np.argmax(dist)
            max_val = dist[max_idx]
            
            # If max_val is effectively 1.0 and the rest are effectively 0.0 => deterministic
            # We'll check how many entries are "significantly non-zero"
            nonzero_count = np.count_nonzero(dist > tol)
            
            is_deterministic = (
                np.isclose(max_val, 1.0, atol=tol) 
                and nonzero_count == 1
            )
            
            if not is_deterministic:
                # Then it's a stochastic distribution
                # (could also be all zero if it never happened in the dataset, but that is
                # trivially "stochastic" or maybe "undefined". Often you'd skip those.)
                
                # If you want to skip "never occurred" distributions (all zeros), you can check:
                # if dist.sum() > 0: 
                #     stochastic_pairs.append((s, a))
                # else:
                #     # no data at all, skip or treat as "unobserved"
                
                # For now, let's include any distribution that is not purely deterministic
                stochastic_pairs.append((s, a))

    return stochastic_pairs


In [162]:
stochastic_pairs = find_stochastic_state_action_pairs(transition_probs)

print("Found the following (s,a) pairs with non-deterministic transitions:")
for (s, a) in stochastic_pairs:
    print(f"  (s={s}, a={a}) => distribution = {transition_probs[s,a]}")

Found the following (s,a) pairs with non-deterministic transitions:
  (s=0, a=0) => distribution = [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  (s=0, a=1) => distribution = [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  (s=3, a=0) => distribution = [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  (s=6, a=0) => distribution = [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  (s=6, a=1) => distribution = [0.   0.   0.   0.   0.   0.   0.   0.   0.   0.56 0.44]
  (s=7, a=1) => distribution = [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  (s=8, a=0) => distribution = [0.  0.  0.  0.  0.  0.  0.  0.  0.  0.4 0.6]
  (s=8, a=1) => distribution = [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  (s=9, a=0) => distribution = [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  (s=9, a=1) => distribution = [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  (s=10, a=0) => distribution = [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  (s=10, a=1) => distribution = [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


In [160]:
import numpy as np

def compute_transition_entropies(transition_probs, tol=1e-9):
    """
    Given transition_probs[s, a, s_next], compute the Shannon entropy
    (in bits, i.e. log base 2) of each (s, a) distribution.
    
    Returns:
        entropies: A 2D array of shape [S, A], where entropies[s, a]
                   is the entropy of transition_probs[s, a, :].
    
    Notes:
      - If the total probability mass for (s, a) is ~0 (i.e. no data),
        we set entropy to 0 by default (or you could mark it as NaN).
      - We ignore states that are purely out-of-bounds or never visited.
    """
    S, A, _ = transition_probs.shape
    entropies = np.zeros((S, A), dtype=float)
    
    for s in range(S):
        for a in range(A):
            dist = transition_probs[s, a]  # shape = [S]
            
            # Sum of probabilities (should be ~1 if we have data)
            total_prob = dist.sum()
            if total_prob < tol:
                # Means no data or zero-prob distribution
                entropies[s, a] = 0.0
                continue
            
            # Identify the non-zero probabilities (to avoid log(0))
            p_nonzero = dist[dist > tol]
            
            # Normalize them so they sum to 1
            p_nonzero /= p_nonzero.sum()
            
            # Shannon entropy in bits
            #  E = - sum(p * log2(p))
            ent = -np.sum(p_nonzero * np.log2(p_nonzero))
            entropies[s, a] = ent
            
    return entropies

def find_stochastic_state_actions_by_entropy(entropies, eps=1e-9):
    """
    Given a 2D array of entropies[s,a], return a list of (s,a) pairs
    that are strictly > eps in entropy (i.e. non-deterministic).
    """
    stochastic_pairs = []
    S, A = entropies.shape
    for s in range(S):
        for a in range(A):
            # If entropy is basically 0 => deterministic
            if entropies[s, a] > eps:
                stochastic_pairs.append((s,a))
    return stochastic_pairs


In [172]:
entropies = compute_transition_entropies(transition_probs)

# Identify (s,a) pairs that are definitely stochastic (entropy > small threshold)
stochastic_pairs = find_stochastic_state_actions_by_entropy(entropies, eps=1e-9)

print("Entropy array shape:", entropies.shape)
print("Stochastic (s,a) pairs where H>0:")
for (s, a) in stochastic_pairs:
    print(f"  (s={s}, a={a}): H={entropies[s,a]:.4f}")

Entropy array shape: (18, 2)
Stochastic (s,a) pairs where H>0:
  (s=12, a=1): H=0.9383
  (s=15, a=0): H=0.9819


In [219]:
# if __name__ == "__main__":
    # Suppose env is your environment, e.g.:
env = GridEnvRightDownNoSelf(cue_state=6)

n_episodes = 50
max_steps_per_episode = 10

dataset = generate_dataset(env, n_episodes, max_steps_per_episode)
cumulative_dataset = []
for i, curr_dataset in enumerate(dataset):
    cumulative_dataset.append(curr_dataset)
    print(f"Episode {i+1}:")
    transition_counts = TM(cumulative_dataset)
    denominators = transition_counts.sum(axis=2, keepdims=True)
    denominators[denominators == 0] = 1
    transition_probs = transition_counts / denominators
    # denominators.shape
    # print("  States:", states_seq)
    # print("  Actions:", actions_seq)
    # print("  Length of episode:", len(actions_seq), "steps")
    # print()
    entropies = compute_transition_entropies(transition_probs)
    print(curr_dataset[0])

    # Identify (s,a) pairs that are definitely stochastic (entropy > small threshold)
    stochastic_pairs = find_stochastic_state_actions_by_entropy(entropies, eps=1e-9)
    # curr_entropy = False
    if stochastic_pairs:
        print("Entropy array shape:", entropies.shape)
        print("Stochastic (s,a) pairs where H>0:")
        for (s, a) in stochastic_pairs:
            print(f"  (s={s}, a={a}): H={entropies[s,a]:.4f}")
            

Episode 1:
[1, 5, 6, 10, 11, 12, 16]
Episode 2:
[1, 5, 6, 7, 11, 15, 16]
Episode 3:
[1, 5, 9, 13, 14, 15, 17]
Entropy array shape: (18, 2)
Stochastic (s,a) pairs where H>0:
  (s=15, a=0): H=1.0000
Episode 4:
[1, 2, 6, 10, 14, 15, 16]
Entropy array shape: (18, 2)
Stochastic (s,a) pairs where H>0:
  (s=15, a=0): H=0.9183
Episode 5:
[1, 2, 6, 10, 11, 15, 16]
Entropy array shape: (18, 2)
Stochastic (s,a) pairs where H>0:
  (s=15, a=0): H=0.8113
Episode 6:
[1, 5, 6, 7, 8, 12, 16]
Entropy array shape: (18, 2)
Stochastic (s,a) pairs where H>0:
  (s=15, a=0): H=0.8113
Episode 7:
[1, 5, 9, 10, 11, 12, 17]
Entropy array shape: (18, 2)
Stochastic (s,a) pairs where H>0:
  (s=12, a=1): H=0.9183
  (s=15, a=0): H=0.8113
Episode 8:
[1, 2, 3, 4, 8, 12, 17]
Entropy array shape: (18, 2)
Stochastic (s,a) pairs where H>0:
  (s=12, a=1): H=1.0000
  (s=15, a=0): H=0.8113
Episode 9:
[1, 2, 6, 10, 11, 12, 16]
Entropy array shape: (18, 2)
Stochastic (s,a) pairs where H>0:
  (s=12, a=1): H=0.9710
  (s=15, a=0): 

In [216]:
stochastic_pairs

[(12, 1), (15, 0)]

In [173]:
split_states = [6]
for split_state in split_states:
    curr_dataset = []
    for i, (states, actions) in enumerate(dataset):
        # transition_counts = TM(dataset)
        
        if split_state in states:
            # print('here')
            curr_dataset.append((states,actions))
            
    # identifying unique states in this dataset
    all_states = [s for (s,a) in curr_dataset]
    unique_states = np.unique(all_states)
    curr_clone = unique_states[4] # just fix this for 5 now / this means we're looking at contingency between 2 and 10 (state after 6)    
    for i, (states,actions) in enumerate(curr_dataset):
        print(states, actions)
        contingency_checker = False
        for j in range(len(actions)):
            if states[j] == curr_clone: # if 5
                contingency_checker = True
            if states[j] == split_states # 6
                
    


SyntaxError: expected ':' (829458916.py, line 21)

In [146]:
np.unique(curr_dataset[0])

ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2,) + inhomogeneous part.

In [147]:
all_states = [s for (s,a) in curr_dataset]
unique_states = np.unique(all_states)

# unique_states = np.unique(

# unique_states =
print("Unique states:", unique_states)

Unique states: [ 1  2  3  4  5  6  9 10]
