In [None]:
import numpy as np
from chmm_actions import CHMM, forwardE, datagen_structured_obs_room
import matplotlib.pyplot as plt
import igraph 
from matplotlib import cm, colors
import os

# Palette of 9 RGB colours, converted from 8-bit channel values (0-255) to the
# [0, 1] float range matplotlib expects -- hence the division by 255.
custom_colors = (
    np.array(
        [
            [214, 214, 214],
            [85, 35, 157],
            [253, 252, 144],
            [114, 245, 144],
            [151, 38, 20],
            [239, 142, 192],
            [214, 134, 48],
            [140, 194, 250],
            [72, 160, 162],
        ]
    )
    / 255
)
# Output directory for saved figures; exist_ok keeps the cell idempotent on re-run.
os.makedirs("figures", exist_ok=True)

In [None]:
def plot_graph(
    chmm, x, a, output_file, cmap=cm.Spectral, multiple_episodes=False, vertex_size=30
):
    """
    Decode ``(x, a)`` with ``chmm`` and plot the transition graph among the clones it visits.

    Parameters
    ----------
    chmm : CHMM
        Trained model; ``chmm.C`` (transition counts) and ``chmm.n_clones`` are read.
    x, a : array
        Observation and action sequences passed to ``chmm.decode``.
    output_file : str
        Path the igraph plot is written to.
    cmap : matplotlib colormap
        Maps each clone's emission label (normalized to [0, 1]) to a vertex colour.
    multiple_episodes : bool
        If True, drops the last slice along axis 0 of ``C`` and the first visited clone
        from the graph, and shifts emission labels down by one.
    vertex_size : int
        Vertex size passed to ``igraph.plot``.

    Returns
    -------
    The object returned by ``igraph.plot``.
    """
    # NOTE(review): assumes decode returns (log-likelihood, clone sequence) -- index 1 is
    # used as the sequence of visited clone states.
    states = chmm.decode(x, a)[1]

    v = np.unique(states)
    if multiple_episodes:
        T = chmm.C[:, v][:, :, v][:-1, 1:, 1:]
        v = v[1:]
    else:
        T = chmm.C[:, v][:, :, v]
    A = T.sum(0)
    # Clones with no outgoing counts (e.g. terminal states) would otherwise divide 0 by 0.
    row_sums = A.sum(1, keepdims=True)
    row_sums[row_sums == 0] = 1
    A = A / row_sums

    g = igraph.Graph.Adjacency((A > 0).tolist())
    # Emission label of every clone, derived from the model's own clone counts rather than
    # a notebook global or x.max() (which breaks when x lacks the largest emission).
    model_n_clones = chmm.n_clones
    node_labels = np.arange(len(model_n_clones)).repeat(model_n_clones)[v]
    if multiple_episodes:
        node_labels -= 1
    label_max = node_labels.max()
    scale = label_max if label_max > 0 else 1  # avoid 0/0 when every label is 0
    # Named vertex_colors so the module-level `matplotlib.colors` import is not shadowed.
    vertex_colors = [cmap(nl)[:3] for nl in node_labels / scale]
    out = igraph.plot(
        g,
        output_file,
        layout=g.layout("kamada_kawai"),
        vertex_color=vertex_colors,
        vertex_label=v,
        vertex_size=vertex_size,
        margin=50,
    )

    return out


def get_mess_fwd(chmm, x, pseudocount=0.0, pseudocount_E=0.0):
    """
    Run ``forwardE`` over observations ``x`` and return its stored forward messages.

    The emission matrix is one-hot (each clone emits only the observation it belongs
    to), optionally smoothed by ``pseudocount_E``. The transition tensor is ``chmm.C``
    plus ``pseudocount``, normalized over its last axis and then averaged over axis 0.
    An all-zero action sequence is therefore used.
    """

    def normalize(arr, axis):
        # Divide by the sum along `axis`, leaving all-zero slices untouched.
        denom = arr.sum(axis, keepdims=True)
        denom[denom == 0] = 1
        return arr / denom

    clone_counts = chmm.n_clones
    n_obs = len(clone_counts)
    total_clones = clone_counts.sum()

    # One-hot emissions: row i has a 1 in the column of the observation clone i belongs to.
    owner = np.repeat(np.arange(n_obs), clone_counts)
    emissions = np.zeros((total_clones, n_obs))
    emissions[np.arange(total_clones), owner] = 1
    emissions = normalize(emissions + pseudocount_E, 1)

    transitions = normalize(chmm.C + pseudocount, 2).mean(0, keepdims=True)

    # Only one (averaged) slice remains, so every step uses action index 0.
    zero_actions = x * 0
    _, forward_messages = forwardE(
        transitions.transpose(0, 2, 1),
        emissions,
        chmm.Pi_x,
        chmm.n_clones,
        x,
        zero_actions,
        store_messages=True,
    )
    return forward_messages


def place_field(mess_fwd, rc, clone):
    """
    Average the forward message of ``clone`` at every (row, col) position in ``rc``.

    ``mess_fwd[t, clone]`` is accumulated into cell ``rc[t]`` and divided by the
    number of visits to that cell. Cells that are never visited are 0.
    """
    assert mess_fwd.shape[0] == rc.shape[0] and clone < mess_fwd.shape[1]
    grid_shape = rc.max(0) + 1
    rows, cols = rc.T
    totals = np.zeros(grid_shape)
    visits = np.zeros(grid_shape, int)
    # ufunc.at accumulates in index order without buffering, so repeated cells add up
    # exactly as a step-by-step loop would.
    np.add.at(totals, (rows, cols), mess_fwd[:, clone])
    np.add.at(visits, (rows, cols), 1)
    # Visit counts are non-negative, so this only turns zeros into ones (no 0/0).
    np.maximum(visits, 1, out=visits)
    return totals / visits

In [None]:
# Environment definition
class GridEnvRightDownNoSelf:
    """
    A ``size`` x ``size`` grid (4x4 by default) with states numbered row-major from 0:

         0   1   2   3
         4   5   6   7
         8   9  10  11
        12  13  14  15

    - The agent can only move RIGHT (0) or DOWN (1).
    - No self-transitions at borders: an action that would leave the grid is not
      offered from that state (see ``get_valid_actions``).
    - Cue logic: reaching the bottom-right cell ends the episode. If ``cue_state``
      was visited on the way (starting on it counts), the episode ends in
      ``rewarded_terminal`` (the bottom-right cell, 15 by default) with reward +1.
      Otherwise it ends in ``unrewarded_terminal`` (``size * size``, a label outside
      the grid) with reward -1.
    """

    def __init__(self, cue_state=2, size=4):
        self.size = size
        # Grid layout: (row, col) -> state, numbered row-major.
        self.pos_to_state = {
            (row, col): row * size + col for row in range(size) for col in range(size)
        }
        self.state_to_pos = {s: rc for rc, s in self.pos_to_state.items()}

        # Actions as integers: 0=right => (0, +1), 1=down => (+1, 0)
        self.base_actions = {
            0: (0, +1),   # right
            1: (+1, 0),   # down
        }

        # Special states (set before building the action table, which uses them).
        self.start_state = 0
        self.cue_state = cue_state
        self.goal_state = size * size - 1            # bottom-right grid cell
        self.rewarded_terminal = self.goal_state
        self.unrewarded_terminal = size * size       # not in the grid, just a label

        # Per-state action sets instead of a static [0, 1] action space (no invalid moves).
        self.valid_actions = self._build_valid_actions()

        self.reset()

    def _build_valid_actions(self):
        """
        Precompute valid actions for every grid state plus both terminals.

        A 'valid' action is one that leads to a NEW in-bounds state. Both terminal
        states get an empty action list.
        """
        valid_dict = {}
        for (row, col), s in self.pos_to_state.items():
            valid_dict[s] = []
            for a, (dr, dc) in self.base_actions.items():
                next_s = self.pos_to_state.get((row + dr, col + dc))
                # next_s is None when the move leaves the grid; the != check rules out
                # self-transitions explicitly even though right/down never produce them.
                if next_s is not None and next_s != s:
                    valid_dict[s].append(a)

        valid_dict[self.rewarded_terminal] = []
        valid_dict[self.unrewarded_terminal] = []
        return valid_dict

    def reset(self):
        """
        Reset the environment to ``start_state`` (0) and return it.

        ``visited_cue`` starts True only when the start state is itself the cue.
        """
        self.current_state = self.start_state
        self.visited_cue = self.start_state == self.cue_state
        return self.current_state

    def get_valid_actions(self, state=None):
        """Return the list of valid actions for ``state`` (default: the current state)."""
        if state is None:
            state = self.current_state
        return self.valid_actions[state]

    def step(self, action):
        """
        Apply ``action`` and return ``(next_state, reward, done)``.

        Once a terminal state has been reached, further calls return
        ``(current_state, 0, True)`` without moving.

        Raises
        ------
        ValueError
            If ``action`` is not in ``get_valid_actions()`` for a non-terminal state.
        """
        # Check for termination before validating: terminals have no valid actions, so
        # validating first made this branch unreachable (it raised ValueError instead).
        if self.current_state in (self.rewarded_terminal, self.unrewarded_terminal):
            return self.current_state, 0, True

        if action not in self.get_valid_actions():
            raise ValueError(f"Action {action} is not valid from state {self.current_state}.")

        # Move
        row, col = self.state_to_pos[self.current_state]
        dr, dc = self.base_actions[action]
        next_state = self.pos_to_state[(row + dr, col + dc)]

        # Check cue
        if next_state == self.cue_state:
            self.visited_cue = True

        reward = 0
        done = False

        # Reaching the bottom-right cell ends the episode; the outcome depends on the cue.
        if next_state == self.goal_state:
            done = True
            if self.visited_cue:
                reward = +1
                next_state = self.rewarded_terminal
            else:
                reward = -1
                next_state = self.unrewarded_terminal

        # Update
        self.current_state = next_state
        return next_state, reward, done
    
def generate_dataset(env, n_episodes=10, max_steps=20, rng=None):
    """
    Run ``n_episodes`` random-policy episodes in ``env``.

    Each episode ends when the environment signals ``done`` or after ``max_steps``
    actions. At every step, an action is drawn uniformly from
    ``env.get_valid_actions(state)``.

    Parameters
    ----------
    env
        Object providing ``reset()``, ``get_valid_actions(state)`` and
        ``step(action) -> (next_state, reward, done)``.
    n_episodes, max_steps : int
        Number of episodes and per-episode step limit.
    rng : numpy Generator / RandomState, optional
        Source of randomness (anything with a ``choice`` method). If None, the global
        ``np.random`` state is used, as before.

    Returns
    -------
    list of ``[states, actions]`` pairs, one per episode.
        When an episode ends with ``done``, the terminal state is appended and the last
        action is repeated, so ``len(actions) == len(states)``. If an episode stops
        because a state has no valid actions, ``states`` has one more entry than
        ``actions``.
    """
    choose = np.random.choice if rng is None else rng.choice
    dataset = []

    for _ in range(n_episodes):
        states = []
        actions = []
        state = env.reset()

        for _ in range(max_steps):
            states.append(state)

            valid_actions = env.get_valid_actions(state)
            if not valid_actions:
                # No valid actions => we must be in a terminal or stuck
                break

            action = choose(valid_actions)
            actions.append(action)

            state, _reward, done = env.step(action)

            if done:
                # Record the final state and pad the action list to the same length.
                states.append(state)
                actions.append(action)
                break

        dataset.append([states, actions])

    return dataset

In [None]:
# generate_dataset samples actions with NumPy's global RNG by default; seed it so the
# dataset (and everything trained on it) is reproducible under Restart & Run All.
SEED = 42
np.random.seed(SEED)

env = GridEnvRightDownNoSelf(cue_state=5)

n_episodes = 500
max_steps_per_episode = 10

dataset = generate_dataset(env, n_episodes, max_steps_per_episode)


In [None]:
# Observations are environment state labels: grid cells 0..15 plus 16 for the
# unrewarded terminal, hence 17 distinct emissions.
n_emissions = 17
n_clones = np.ones(n_emissions, dtype=np.int64) * 3  # 3 clones per emission

# Train once on ALL episodes concatenated into a single sequence. The previous
# per-episode loop set chmm.pseudocount = 0.0 after the first episode, so every later
# EM run trained without smoothing.
# NOTE(review): if CHMM.learn_em_T re-estimates its counts from only the sequence it is
# given (as in the reference CSCG implementation -- confirm), the per-episode loop also
# left a model reflecting essentially the last episode only.
# Concatenation adds one terminal -> start transition per episode boundary (taken under
# the padded final action from generate_dataset), which the model will see as a reset.
x = np.concatenate([np.asarray(states, dtype=np.int64) for states, _ in dataset])
a = np.concatenate([np.asarray(actions, dtype=np.int64) for _, actions in dataset])
assert len(x) == len(a), "every observation needs a paired action"

chmm = CHMM(n_clones=n_clones, pseudocount=2e-3, x=x, a=a, seed=42)  # Initialize the model
progression = chmm.learn_em_T(x, a, n_iter=1000)  # Training

# Refine learning with learn_viterbi_T, without pseudocount smoothing.
chmm.pseudocount = 0.0
chmm.learn_viterbi_T(x, a, n_iter=100)

In [None]:
x.shape[0]  # number of observations in the training sequence

In [None]:
x = np.asarray(x)
x.ndim  # rank of the observation array (1 for a flat sequence)

In [None]:
# Draw the learned clone transition graph and save it as a PDF.
graph = plot_graph(chmm, x, a, output_file="figures/rectangular_room_graph.pdf")
graph

In [None]:
a[:20]  # preview the first actions rather than dumping the whole array

In [None]:
x[:20]  # preview the first observations rather than dumping the whole array