In [None]:
import numpy as np
from chmm_actions import CHMM, forwardE, datagen_structured_obs_room
import matplotlib.pyplot as plt
import igraph
from matplotlib import cm, colors
from collections import defaultdict
import os
import math

import sys
import os

# Make the parent directory importable so the sibling modules imported below
# can be found. `__file__` is normally undefined inside a notebook kernel
# (NameError), so fall back to the current working directory in that case.
try:
    _this_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    _this_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(_this_dir, '..'))
# Re-running the cell must not keep growing sys.path.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from testing_environments import ContinuousTMaze, GridEnvRightDownNoCue, GridEnvRightDownNoSelf, GridEnvDivergingMultipleReward, GridEnvDivergingSingleReward
from util import *



# Nine-colour RGB palette (one row per colour), rescaled from 0-255 integers
# to floats by dividing by 256.
custom_colors = np.array([
    [214, 214, 214],
    [85, 35, 157],
    [253, 252, 144],
    [114, 245, 144],
    [151, 38, 20],
    [239, 142, 192],
    [214, 134, 48],
    [140, 194, 250],
    [72, 160, 162],
]) / 256

# Figures saved by later cells go under ./figures; create it when missing.
if not os.path.exists("figures"):
    os.makedirs("figures")


In [None]:
def plot_graph(chmm, x, a, output_file, cmap=cm.Spectral, multiple_episodes=False, vertex_size=30):
    """Plot the transition graph over the clones used to explain ``(x, a)``.

    The sequence is decoded with ``chmm.decode``; every distinct decoded clone
    becomes a vertex, and a directed edge is drawn wherever ``chmm.C`` summed
    over its first axis is positive. Vertices are arranged in a 4-column grid
    of clusters, one cluster per observation label, with the clones of a
    label packed into a small square inside their cluster.

    Observation labels 16 and 17 are special-cased: their outgoing edges are
    removed, they share one cluster (slightly offset from each other), and
    they are coloured green (16) and red (17). All other vertices are white.

    Args:
        chmm: Model exposing ``n_clones``, ``C`` and ``decode(x, a)``.
        x: Integer observation sequence.
        a: Action sequence aligned with ``x``.
        output_file: Target handed to ``igraph.plot``.
        cmap: Unused; accepted only so existing callers keep working.
        multiple_episodes: When True, the first visited clone and the last
            slice along the first axis of ``chmm.C`` are dropped and all
            observation labels are shifted down by one.
        vertex_size: Vertex size forwarded to ``igraph.plot``.

    Returns:
        Whatever ``igraph.plot`` returns for the drawn graph.
    """
    # Print diagnostic information
    print("Diagnostic information:")
    print("x.max() + 1 =", x.max() + 1)
    print("Number of states (n_clones) =", len(chmm.n_clones))
    print("n_clones values =", chmm.n_clones)

    # Decode the state sequence and get the visited clone indices.
    states = chmm.decode(x, a)[1]
    visited = np.unique(states)

    # Restrict the count tensor to visited clones (dropping the first visited
    # clone and the last first-axis slice in multi-episode mode).
    if multiple_episodes:
        T = chmm.C[:, visited][:, :, visited][:-1, 1:, 1:]
        visited = visited[1:]
    else:
        T = chmm.C[:, visited][:, :, visited]

    # Sum over the first axis and row-normalise. Rows without any outgoing
    # count are left at zero instead of producing NaNs from a 0/0 division.
    A = T.sum(axis=0)
    row_sums = A.sum(axis=1, keepdims=True)
    row_sums[row_sums == 0] = 1
    A = A / row_sums

    # One observation label per clone. Sized from n_clones (not x.max() + 1)
    # so the repeat stays valid even when the highest observation never
    # appears in x.
    n_clones = np.asarray(chmm.n_clones)
    obs_labels = np.arange(len(n_clones)).repeat(n_clones)
    if multiple_episodes:
        # Shift once, up front, so edge removal, grouping, colours and labels
        # all refer to the same label values.
        obs_labels = obs_labels - 1
    visited_obs = obs_labels[visited]
    visited_obs_list = visited_obs.tolist()

    # Remove all outbound edges from nodes labelled 16 or 17.
    terminal_nodes = np.flatnonzero(np.isin(visited_obs, (16, 17)))
    A[terminal_nodes, :] = 0

    # Create a graph where an edge exists if there is any positive transition probability.
    g = igraph.Graph.Adjacency((A > 0).tolist())

    print("\nNode information:")
    print("Number of unique states visited:", len(visited))
    print("Observation labels for visited nodes:", visited_obs)

    # Group visited nodes (by position in `visited`) under their observation label.
    obs_to_indices = {}
    for idx, obs in enumerate(visited_obs_list):
        obs_to_indices.setdefault(obs, []).append(idx)

    print("\nGrouping information:")
    for obs, indices in sorted(obs_to_indices.items()):
        print(f"Observation {obs}: {len(indices)} nodes at indices {indices}")

    # --- Layout: 4-column grid of clusters ---
    grid_cols = 4
    cluster_spacing = 10.0  # spacing between cluster centers in the grid
    mini_spacing = 1.0  # spacing between nodes inside one cluster
    unique_obs = sorted(obs_to_indices)

    # 16 and 17 share one cluster centre when both are present.
    merged_obs_groups = {obs: obs for obs in unique_obs}
    if 16 in merged_obs_groups and 17 in merged_obs_groups:
        merged_obs_groups[17] = 16

    cluster_centers = {}
    for i, obs in enumerate(sorted(set(merged_obs_groups.values()))):
        row, col = divmod(i, grid_cols)
        cluster_centers[obs] = (col * cluster_spacing, row * cluster_spacing)

    # Extra shift so the 16 and 17 nodes do not sit on top of each other.
    offset_mapping = {16: (-0.5, 0.5), 17: (0.5, -0.5)}

    layout_positions = [None] * len(visited)
    for obs, indices in obs_to_indices.items():
        cx, cy = cluster_centers[merged_obs_groups[obs]]
        mini_grid_size = int(np.ceil(np.sqrt(len(indices))))
        half_extent = (mini_grid_size - 1) / 2.0
        dx, dy = offset_mapping.get(obs, (0, 0))

        for j, node_idx in enumerate(indices):
            mini_row, mini_col = divmod(j, mini_grid_size)
            offset_x = (mini_col - half_extent) * mini_spacing + dx
            offset_y = (mini_row - half_extent) * mini_spacing + dy
            layout_positions[node_idx] = (cx + offset_x, cy + offset_y)
    # --- End Layout Section ---

    # Vertex colours: white except observation 16 (green) and 17 (red).
    color_by_obs = {16: "green", 17: "red"}
    vertex_colors = [color_by_obs.get(obs, "white") for obs in visited_obs_list]

    # Each vertex is labelled with its observation label.
    vertex_labels = [str(obs) for obs in visited_obs_list]

    # Define visual style options.
    visual_style = {
        "vertex_size": vertex_size,
        "vertex_color": vertex_colors,
        "vertex_label": vertex_labels,
        "edge_curved": 0.2,
        "margin": 40,
        "bbox": (800, 600)
    }

    out = igraph.plot(g, output_file, layout=layout_positions, **visual_style)
    return out


def get_mess_fwd(chmm, x, pseudocount=0.0, pseudocount_E=0.0):
    """Return the forward messages stored by ``forwardE`` for the sequence ``x``.

    An emission matrix is built in which every clone emits only its own
    observation, then smoothed with ``pseudocount_E`` and row-normalised.
    The transition tensor is ``chmm.C + pseudocount`` normalised over its last
    axis and averaged over its first axis. ``forwardE`` is then called with an
    all-zero action sequence (``x * 0``) and ``store_messages=True``; only the
    messages are returned.
    """
    clones_per_obs = chmm.n_clones
    n_obs = len(clones_per_obs)

    # Clone block [bounds[k], bounds[k + 1]) belongs to observation k.
    bounds = np.concatenate(([0], np.cumsum(clones_per_obs)))
    emission = np.zeros((clones_per_obs.sum(), n_obs))
    for obs in range(n_obs):
        emission[bounds[obs] : bounds[obs + 1], obs] = 1
    emission += pseudocount_E
    emission_norm = emission.sum(1, keepdims=True)
    emission_norm[emission_norm == 0] = 1  # leave all-zero rows untouched
    emission /= emission_norm

    transition = chmm.C + pseudocount
    transition_norm = transition.sum(2, keepdims=True)
    transition_norm[transition_norm == 0] = 1  # leave all-zero rows untouched
    transition /= transition_norm
    transition = transition.mean(0, keepdims=True)

    _, mess_fwd = forwardE(
        transition.transpose(0, 2, 1),
        emission,
        chmm.Pi_x,
        chmm.n_clones,
        x,
        x * 0,
        store_messages=True,
    )
    return mess_fwd


def place_field(mess_fwd, rc, clone):
    """Average the forward message of one clone at every visited grid cell.

    ``rc[t]`` is the (row, col) position at step ``t`` and ``mess_fwd[t, clone]``
    the clone's message at that step. The result has shape ``rc.max(0) + 1``;
    cells that were never visited are 0. Positions are used as plain array
    indices, so negative entries wrap around to the end of the axis.
    """
    assert mess_fwd.shape[0] == rc.shape[0] and clone < mess_fwd.shape[1]
    grid_shape = rc.max(0) + 1
    total = np.zeros(grid_shape)
    visits = np.zeros(grid_shape, int)
    for (row, col), activation in zip(rc, mess_fwd[:, clone]):
        total[row, col] += activation
        visits[row, col] += 1
    # Divide by 1 where nothing was accumulated so those cells stay at 0.
    return total / np.where(visits == 0, 1, visits)


# Training a CHMM

In [None]:
# Train CHMM on random data
TIMESTEPS = 1000  # length of both the training and the test sequence
OBS = 2  # number of distinct observations

# Observations. Replace with your data.
x = np.random.randint(OBS, size=(TIMESTEPS,))
# If there are actions in your domain replace this. If not, keep the vector of zeros.
a = np.zeros(TIMESTEPS, dtype=np.int64)
# Number of clones specifies the capacity for each observation.
n_clones = np.ones(OBS, dtype=np.int64) * 5

# Test observations. Replace with your data.
x_test = np.random.randint(OBS, size=(TIMESTEPS,))
a_test = np.zeros(TIMESTEPS, dtype=np.int64)

chmm = CHMM(n_clones=n_clones, pseudocount=1e-10, x=x, a=a)  # Initialize the model
progression = chmm.learn_em_T(x, a, n_iter=100, term_early=False)  # Training

# Evaluate negative log-likelihood (base 2 log) on the test sequence and
# convert its mean into an average per-step prediction probability.
nll_per_prediction = chmm.bps(x_test, a_test)
avg_nll = np.mean(nll_per_prediction)
avg_prediction_probability = 2 ** (-avg_nll)
print(avg_prediction_probability)


# Rectangular room datagen

In [None]:
# Dataset
# Square grid environment: the rewarded terminal id equals the number of
# cells (size * size) and state 13 is the single cue state.
size = 5
env_size = (size, size)
rewarded_terminal = env_size[0] * env_size[1]
cue_states = [13]
env = GridEnvRightDownNoSelf(
    env_size=env_size,
    rewarded_terminal=[rewarded_terminal],
    cue_states=cue_states,
)


In [None]:
# (row, col) -> 1-based state id, numbered row by row over the 4x4 grid.
pos_to_state = {
    (row, col): 4 * row + col + 1
    for row in range(4)
    for col in range(4)
}

# Human-readable names for the action codes used in the generated sequences.
actions = dict([(-1, "reset"), (0, "down"), (1, "right")])


def get_valid_actions(state:tuple):
    """Return the action codes allowed at ``state`` given as ``(row, col)``.

    Action 0 is allowed while the row index is below 3 and action 1 while the
    column index is below 3; the result lists 0 before 1.
    """
    row, col = state
    return [action for action, coord in ((0, row), (1, col)) if coord < 3]

def datagen_nonmarkov_room(n_episodes=10, max_steps=20):
    """Generate random-walk episodes on the 4x4 grid with a history-dependent terminal.

    Every episode starts at (0, 0) and repeatedly takes a uniformly random
    valid action (0 moves one row down, 1 moves one column right) until it
    arrives at (3, 3) or ``max_steps`` moves have been made. On arrival the
    terminal observation is 16 if state 6 was visited earlier in the episode
    and 17 otherwise; it is paired with action -1 and position (-1, -1).

    Args:
        n_episodes: Number of episodes to concatenate.
        max_steps: Maximum number of moves per episode.

    Returns:
        Tuple ``(actions, states, rc)`` of int64 arrays covering all episodes
        back to back: ``actions`` and ``states`` are 1-D, ``rc`` has shape
        ``(len(states), 2)`` and holds the (row, col) of each recorded state.
    """
    actions = []
    states = []
    rc = []

    for _ in range(n_episodes):
        # Every episode restarts from the top-left corner.
        state = (0, 0)
        episode = []

        for _ in range(max_steps):
            episode.append(pos_to_state[state])
            rc.append(state)

            valid_actions = get_valid_actions(state)
            if not valid_actions:
                # No valid actions => we must be in a terminal or stuck
                print(state)
                break

            action = np.random.choice(valid_actions)
            actions.append(action)

            row, col = state
            state = (row + 1, col) if action == 0 else (row, col + 1)

            if state == (3, 3):
                # Record the terminal observation: 17 replaces 16 when the
                # cue (state 6) was never visited during this episode.
                terminal_obs = pos_to_state[state]
                if terminal_obs == 16 and 6 not in episode:
                    terminal_obs = 17
                episode.append(terminal_obs)

                actions.append(-1)
                rc.append((-1, -1))
                break

        states.extend(episode)  # add the episode to the end of all the states

    # Explicit dtype/shape so the result is int64 on every platform and still
    # well-formed when no episode was generated.
    return (
        np.array(actions, dtype=np.int64),
        np.array(states, dtype=np.int64),
        np.array(rc, dtype=np.int64).reshape(-1, 2),
    )

def datagen_tmaze(n_data=10, n_episodes=10, max_steps=20):
    """Generate a shuffled batch of three-step T-maze episodes.

    Four fixed episode templates are used (observations / actions):

    * ``[1, 0, 3]`` / ``[0, 1, -1]`` - start 1, left turn, rewarded
    * ``[2, 0, 5]`` / ``[0, 2, -1]`` - start 2, right turn, rewarded
    * ``[1, 0, 6]`` / ``[0, 2, -1]`` - start 1, right turn, no reward
    * ``[2, 0, 4]`` / ``[0, 1, -1]`` - start 2, left turn, no reward

    Each template is repeated ``n_data`` times; the resulting episodes are
    shuffled with one random permutation (applied identically to observations
    and actions) and concatenated.

    Args:
        n_data: Number of repetitions of each template (``4 * n_data``
            episodes in total).
        n_episodes: Unused; kept so existing callers keep working.
        max_steps: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(actions, states)`` of 1-D int64 arrays, each of length
        ``12 * n_data``.
    """
    episode_templates = [
        ([1, 0, 3], [0, 1, -1]),  # start 1, left turn, rewarded
        ([2, 0, 5], [0, 2, -1]),  # start 2, right turn, rewarded
        ([1, 0, 6], [0, 2, -1]),  # start 1, right turn, no reward
        ([2, 0, 4], [0, 1, -1]),  # start 2, left turn, no reward
    ]
    template_obs = np.array([obs for obs, _ in episode_templates], dtype=np.int64)
    template_act = np.array([act for _, act in episode_templates], dtype=np.int64)

    # n_data consecutive copies of each template, one episode per row.
    super_observations = np.repeat(template_obs, n_data, axis=0)
    super_actions = np.repeat(template_act, n_data, axis=0)

    # One permutation for both arrays keeps each episode's observations
    # paired with its own actions.
    permutation = np.random.permutation(len(super_observations))
    states = super_observations[permutation].reshape(-1)
    actions = super_actions[permutation].reshape(-1)

    return actions, states

In [None]:
# 4x4 grid holding the state ids 1..16 row by row.
room = np.arange(1, 17).reshape(4, 4)
# Observation ids run from 0 to room.max() + 1: id 17 is the additional
# terminal observation that datagen_nonmarkov_room can emit.
n_emissions = room.max() + 2

a, x, rc = datagen_nonmarkov_room()

# The three sequences are expected to have the same length.
print(len(a), len(x), len(rc), sep="\n")

# 70 clones for every observation id.
n_clones = np.full(n_emissions, 70, dtype=np.int64)
chmm = CHMM(n_clones=n_clones, pseudocount=2e-3, x=x, a=a, seed=42)  # Initialize the model
progression = chmm.learn_em_T(x, a, n_iter=1000)  # Training


In [None]:
# refine learning
# Set the pseudocount to zero first, then run 100 iterations of
# learn_viterbi_T on the same training sequence.
chmm.pseudocount = 0.0
chmm.learn_viterbi_T(x, a, n_iter=100)

In [None]:
# Colormap built from the last four palette entries; `cmap` is also passed to
# plot_graph by later cells.
cmap = colors.ListedColormap(custom_colors[-4:])
# Show the grid of state ids and save the figure.
plt.matshow(room, cmap=cmap)
plt.savefig("figures/rectangular_room_layout.pdf")


In [None]:
# Print diagnostic information: the number of observation ids implied by the
# data next to the number the model was built with.
print("x.max() + 1 =", x.max() + 1)
print("Length of n_clones =", len(chmm.n_clones))  # read from the trained model

# Draw the clone transition graph and write it to the PDF below.
graph = plot_graph(
    chmm, x, a, output_file="figures/rectangular_room_graph.pdf", cmap=cmap
)

# Display the returned graph object (if using an interactive environment)
graph


In [None]:
# Forward messages for the training sequence, with emission smoothing 0.1.
mess_fwd = get_mess_fwd(chmm, x, pseudocount_E=0.1)

In [None]:
# Hard-coded clone index whose position-averaged forward message is plotted.
# NOTE(review): which clones are in use depends on the trained model - confirm
# this index is still meaningful after retraining.
clone = 114
plt.matshow(place_field(mess_fwd, rc, clone))
plt.savefig("figures/rectangular_room_place_field.pdf")

# T-maze implementation

In [None]:
# Generate T-maze action/observation sequences. The next cell calls
# datagen_tmaze() again before training, so these values get replaced there.
a, x = datagen_tmaze()

In [None]:
# Generate the T-maze sequences and size the model from the data: one
# observation id for every value from 0 up to the largest one seen.
a, x = datagen_tmaze()
n_emissions = x.max() + 1

# Both sequences are expected to have the same length.
print(len(a), len(x), sep="\n")

# 70 clones for every observation id.
n_clones = np.full(n_emissions, 70, dtype=np.int64)
chmm = CHMM(n_clones=n_clones, pseudocount=2e-3, x=x, a=a, seed=42)  # Initialize the model
progression = chmm.learn_em_T(x, a, n_iter=1000)  # Training

In [None]:
# refine learning
# Set the pseudocount to zero first, then run 100 iterations of
# learn_viterbi_T on the same training sequence.
chmm.pseudocount = 0.0
chmm.learn_viterbi_T(x, a, n_iter=100)

In [None]:
# Print diagnostic information: the number of observation ids implied by the
# data next to the number the model was built with.
print("x.max() + 1 =", x.max() + 1)
print("Length of n_clones =", len(chmm.n_clones))  # read from the trained model

# Draw the clone transition graph for the T-maze model. `cmap` is not defined
# in this section; it is whatever an earlier cell last assigned.
graph = plot_graph(
    chmm, x, a, output_file="figures/tmaze_graph.pdf", cmap=cmap
)

# Display the returned graph object (if using an interactive environment)
graph


# Empty rectangular room datagen

In [None]:
# 6x8 room: interior cells observe 0, the four walls 5/6/7/8
# (top/bottom/left/right) and the four corners 1-4.
H, W = 6, 8
room = np.zeros((H, W), dtype=np.int64)
room[0, :] = 5
room[-1, :] = 6
room[:, 0] = 7
room[:, -1] = 8
# Corners are written last so they override the wall values.
room[0, 0], room[0, -1], room[-1, 0], room[-1, -1] = 1, 2, 3, 4
n_emissions = room.max() + 1

a, x, rc = datagen_structured_obs_room(room, length=50000)

# 70 clones for every observation id.
n_clones = np.full(n_emissions, 70, dtype=np.int64)
chmm = CHMM(n_clones=n_clones, pseudocount=2e-3, x=x, a=a, seed=4)  # Initialize the model
progression = chmm.learn_em_T(x, a, n_iter=1000)  # Training


In [None]:
# refine learning
# Set the pseudocount to zero first, then run 100 iterations of
# learn_viterbi_T on the same training sequence.
chmm.pseudocount = 0.0
chmm.learn_viterbi_T(x, a, n_iter=100)

In [None]:
# Colormap using the full nine-colour palette (the room uses observation ids 0-8).
cmap = colors.ListedColormap(custom_colors)
# Show the observation id of every cell and save the figure.
plt.matshow(room, cmap=cmap)
plt.savefig("figures/empty_rectangular_room_layout.pdf")

In [None]:
# Draw the clone transition graph for the empty-room model, save it to PDF,
# and display it as the cell output.
graph = plot_graph(
    chmm, x, a, output_file="figures/empty_rectangular_room_graph.pdf", cmap=cmap, vertex_size=30
)
graph


In [None]:
# Forward messages for the training sequence, with emission smoothing 0.1.
mess_fwd = get_mess_fwd(chmm, x, pseudocount_E=0.1)

In [None]:
# Hard-coded clone index whose position-averaged forward message is plotted.
# NOTE(review): which clones are in use depends on the trained model - confirm
# this index is still meaningful after retraining.
clone = 58
plt.matshow(place_field(mess_fwd, rc, clone))
plt.savefig("figures/empty_rectangular_room_place_field.pdf")

# 5x5 mazes

In [None]:
# 5x5 room in which every cell has its own observation id (a random
# arrangement of 0..24).
room = np.random.permutation(25).reshape(5, 5)

a, x, rc = datagen_structured_obs_room(room, length=10000)

# 10 clones for each of the 25 observation ids.
n_clones = np.full(25, 10, dtype=np.int64)
chmm = CHMM(n_clones=n_clones, pseudocount=1e-2, x=x, a=a, seed=4)  # Initialize the model
progression = chmm.learn_em_T(x, a, n_iter=1000)  # Training


In [None]:
# refine learning
# Set the pseudocount to zero first, then run 100 iterations of
# learn_viterbi_T on the same training sequence.
chmm.pseudocount = 0.0
chmm.learn_viterbi_T(x, a, n_iter=100)

In [None]:
# Show the observation id of every cell with the "Reds" colormap and save the figure.
plt.matshow(room, cmap="Reds")
plt.savefig("figures/square_room_layout.pdf")

In [None]:
# Draw the clone transition graph for the 5x5 model, save it to PDF, and
# display it as the cell output.
graph = plot_graph(
    chmm, x, a, output_file="figures/square_room_graph.pdf", cmap=cm.Reds
)
graph


In [None]:
# Forward messages for the training sequence, with emission smoothing 0.1.
mess_fwd = get_mess_fwd(chmm, x, pseudocount_E=0.1)

In [None]:
# Hard-coded clone index whose position-averaged forward message is plotted.
# NOTE(review): which clones are in use depends on the trained model - confirm
# this index is still meaningful after retraining.
clone = 75
plt.matshow(place_field(mess_fwd, rc, clone))
plt.savefig("figures/square_room_place_field.pdf")

# Two Rooms Stitched Together

In [None]:
# Two 8x6 rooms of observation ids in the range 0-14.
room1 = np.array(
    [
        [12, 4, 0, 1, 13, 2],
        [7, 3, 12, 11, 0, 10],
        [5, 12, 14, 12, 9, 4],
        [5, 0, 14, 7, 4, 8],
        [4, 10, 7, 2, 13, 1],
        [3, 14, 8, 3, 12, 11],
        [1, 1, 5, 12, 14, 12],
        [5, 9, 3, 0, 14, 7],
    ]
)

room2 = np.array(
    [
        [3, 12, 11, 4, 11, 11],
        [12, 14, 12, 11, 9, 1],
        [0, 14, 7, 2, 4, 9],
        [0, 0, 9, 8, 2, 11],
        [8, 13, 8, 6, 9, 2],
        [0, 5, 4, 13, 2, 14],
        [14, 4, 13, 7, 9, 14],
        [11, 1, 3, 13, 3, 0],
    ]
)

H, W = room1.shape

# Cells on each border of the room: left column, right column, top row and
# bottom row. NOTE(review): presumably the positions from which the named
# move is not allowed - confirm against datagen_structured_obs_room.
no_left = [(r, 0) for r in range(H)]
no_right = [(r, W-1) for r in range(H)]
no_up = [(0, c) for c in range(W)]
no_down = [(H-1, c) for c in range(W)]

# One 50000-step sequence per room, using the same border lists for both.
a1, x1, rc1 = datagen_structured_obs_room(room1, None, None, no_left, no_right, no_up, no_down, length=50000)
a2, x2, rc2 = datagen_structured_obs_room(room2, None, None, no_left, no_right, no_up, no_down, length=50000)

# Stitch the two sequences: room observations are shifted up by one so that
# observation 0 can be inserted before each room's sequence, and action 4 is
# placed around those inserted observations. Both arrays end up with the
# same length (the last action of room 1 is dropped to make room).
x = np.hstack((0, x1 + 1, 0, x2 + 1))
a = np.hstack((4, a1[:-1], 4, 4, a2))

n_emissions = x.max() + 1

# 20 clones per observation, except a single clone for the inserted observation 0.
n_clones = 20 * np.ones(n_emissions, int)
n_clones[0] = 1
chmm = CHMM(n_clones=n_clones, pseudocount=2e-2, x=x, a=a, seed=19)  # Initialize the model
progression = chmm.learn_em_T(x, a, n_iter=1000)  # Training


In [None]:
# refine learning
# Set the pseudocount to zero first, then run 100 iterations of
# learn_viterbi_T on the same training sequence.
chmm.pseudocount = 0.0
chmm.learn_viterbi_T(x, a, n_iter=100)
# Report the bps value on the training sequence and how many distinct clones
# the decoded state sequence uses.
bps = chmm.bps(x, a)
states = chmm.decode(x, a)[1]
n_states = len(np.unique(states))
print("n_states: {} (88 for perfect recovery), bps: {}".format(n_states, bps))


In [None]:
# Draw the clone transition graph in multi-episode mode (the sequence contains
# the inserted observation 0), save it to PDF, and display it.
graph = plot_graph(chmm, x, a, output_file="figures/stitched_rooms.pdf", multiple_episodes=True)
graph


In [None]:
# Forward messages for the stitched sequence, with emission smoothing 0.1.
mess_fwd = get_mess_fwd(chmm, x, pseudocount_E=0.1)

In [None]:
# Hard-coded clone index whose position-averaged forward message is plotted.
# NOTE(review): which clones are in use depends on the trained model - confirm
# this index is still meaningful after retraining.
clone = 229
# Positions aligned with the stitched x: the two inserted observations are
# assigned position (0, 8), and room 2 positions are shifted by (5, 3).
rc = np.vstack(((0, 8), rc1, (0, 8), rc2 + (5, 3)))
pf = place_field(mess_fwd, rc, clone)
# Blank the cell used for the inserted observations before plotting.
pf[0, 8] = 0.0
plt.matshow(pf)
plt.savefig("figures/stitched_rooms_place_field.pdf")

## Three pentagonal cliques

In [None]:
# Adjacency matrix of 15 states forming three 5-state cliques.
T = np.zeros((15, 15))
# Fully connect states 0-4, 5-9 and 10-14, then remove the self loops.
for clique_start in (0, 5, 10):
    members = slice(clique_start, clique_start + 5)
    T[members, members] = 1.0
np.fill_diagonal(T, 0.0)
# Inside each clique the first and last state are not linked to each other...
for u, v in ((0, 4), (5, 9), (10, 14)):
    T[u, v] = T[v, u] = 0.0
# ...instead they link to the neighbouring clique, closing a ring of cliques.
for u, v in ((4, 5), (9, 10), (14, 0)):
    T[u, v] = T[v, u] = 1.0
plt.matshow(T)

# Random walk of 10000 steps from state 0, choosing uniformly among the
# successors of the current state.
states = [0]
for _ in range(10000):
    successors = np.flatnonzero(T[states[-1]])
    states.append(np.random.choice(successors))
states = np.array(states)

# Aliased observations: several states share an observation id (0-based).
state_to_obs = np.array([1, 2, 3, 4, 5, 6, 1, 4, 5, 2, 8, 2, 3, 5, 7], dtype=int) - 1

# Observation sequence for the walk; there are no actions, so all zeros.
x = state_to_obs[states]
a = np.zeros(len(x), dtype=int)

# 5 clones for each of the 8 observation ids.
n_clones = np.full(8, 5, dtype=np.int64)
chmm = CHMM(n_clones=n_clones, pseudocount=1.0, x=x, a=a)  # Initialize the model
progression = chmm.learn_em_T(x, a, n_iter=1000)  # Training


In [None]:
# refine learning
# Set the pseudocount to zero first, then run 100 iterations of
# learn_viterbi_T on the same training sequence.
chmm.pseudocount = 0.0
chmm.learn_viterbi_T(x, a, n_iter=100)
# Number of distinct clones used by the decoded sequence (shown as cell
# output). Note that `states` now holds decoded clones, replacing the
# generated state sequence from the previous cell.
states = chmm.decode(x, a)[1]
n_states = len(np.unique(states))
n_states


In [None]:
# Eight palette entries reordered into a colormap, one per observation id.
cmap = colors.ListedColormap(custom_colors[[7, 3, 2, 1, 5, 0, 4, 6]])
# Draw the clone transition graph, save it to PDF, and display it.
graph = plot_graph(chmm, x, a, output_file="figures/pentagonal_cliques.pdf", cmap=cmap)
graph
