# Markov Models Playground

Welcome to this interactive notebook on **Discrete-Time Markov Models (DTMCs)**. You’ll explore how these models work by:
- Visualizing state transitions
- Simulating random walks
- Interacting with example systems

Let’s get started!

---

## Imports

In [None]:
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from random import choices

# Give every figure a white background so plots stay readable under dark notebook themes.
plt.rcParams.update({'figure.facecolor': 'white'})

## Markov Chain Class

In [None]:
class MarkovChainEnv:
    """A discrete-time Markov chain with sampling, plotting and stationary-distribution helpers.

    Args:
        transition_matrix: square, row-stochastic matrix; ``P[i][j]`` is the
            probability of moving from state ``i`` to state ``j``.
        state_names: optional display names, one per state (defaults to the
            integer indices ``0..n-1``).

    Raises:
        ValueError: if the matrix is not square, has negative entries, or a
            row does not sum to 1.
    """

    def __init__(self, transition_matrix, state_names=None):
        self.P = np.array(transition_matrix, dtype=float)
        # Validate up front so a malformed chain fails here rather than deep inside sampling.
        if self.P.ndim != 2 or self.P.shape[0] != self.P.shape[1]:
            raise ValueError(f"transition_matrix must be square, got shape {self.P.shape}")
        if (self.P < 0).any():
            raise ValueError("transition_matrix must not contain negative probabilities")
        if not np.allclose(self.P.sum(axis=1), 1.0):
            raise ValueError("each row of transition_matrix must sum to 1")
        self.num_states = self.P.shape[0]
        self.states = state_names if state_names else list(range(self.num_states))

    def next_state(self, current_state):
        """Sample the index of the state that follows ``current_state``.

        Draws from row ``current_state`` of ``P`` using NumPy's global RNG
        (call ``np.random.seed`` beforehand for reproducible walks).
        """
        return int(np.random.choice(self.num_states, p=self.P[current_state]))

    def plot_transition_graph(self, highlight_state=None, previous_state=None):
        """Draw the chain as a directed graph with edge probabilities as labels.

        Args:
            highlight_state: index of a state to colour red (e.g. the current state).
            previous_state: index of the state just left; together with
                ``highlight_state`` it marks the last transition's edge in red.
        """
        G = nx.DiGraph()
        # Insert nodes in state order first; add_edge alone inserts nodes in
        # edge-discovery order, and nx.draw colours nodes in G.nodes() order.
        G.add_nodes_from(self.states)
        for i in range(self.num_states):
            for j in range(self.num_states):
                if self.P[i, j] > 0:
                    G.add_edge(self.states[i], self.states[j], weight=self.P[i, j])

        pos = nx.spring_layout(G, seed=42)
        edge_labels = {(u, v): f"{d['weight']:.2f}" for u, v, d in G.edges(data=True)}

        highlighted = self.states[highlight_state] if highlight_state is not None else None
        node_colors = [
            'red' if highlight_state is not None and node == highlighted else 'lightblue'
            for node in G.nodes()
        ]

        last_edge = None
        if previous_state is not None and highlight_state is not None:
            last_edge = (self.states[previous_state], self.states[highlight_state])
        edge_colors = ['red' if (u, v) == last_edge else 'gray' for u, v in G.edges()]

        nx.draw(G, pos, with_labels=True, node_size=1500, node_color=node_colors, font_size=12, edge_color=edge_colors, width=2)
        nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
        plt.title("Markov Chain Transition Graph")
        plt.show()

    def simulate_walk(self, start_state, steps=10):
        """Return a list of ``steps + 1`` state indices starting with ``start_state``."""
        current = start_state
        walk = [current]
        for _ in range(steps):
            current = self.next_state(current)
            walk.append(current)
        return walk

    def compute_stationary_distribution(self, tol=1e-8, max_iter=10000):
        """Approximate a stationary distribution by power iteration from the uniform start.

        Repeatedly applies ``dist @ P`` until successive iterates agree within
        ``tol``. Periodic chains may never settle; in that case a
        ``RuntimeWarning`` is issued and the last iterate is returned.
        """
        dist = np.ones(self.num_states) / self.num_states
        for _ in range(max_iter):
            new_dist = dist @ self.P
            if np.allclose(new_dist, dist, atol=tol):
                return new_dist
            dist = new_dist
        import warnings
        warnings.warn(
            f"power iteration did not converge within {max_iter} iterations",
            RuntimeWarning,
        )
        return dist

## Define A Markov Chain

Here, we define a basic 3-state Markov chain to demonstrate how a system transitions between discrete states over time. We'll use this model to simulate a walk and visualize its behavior.

In [None]:
# A 3-state chain: row i holds the probabilities of moving from state i to A, B and C.
classic_chain = MarkovChainEnv(
    [[0.1, 0.1, 0.8],
     [0.3, 0.4, 0.3],
     [0.5, 0.4, 0.1]],
    state_names=['A', 'B', 'C'],
)
classic_chain.plot_transition_graph()

## Simulate a Random Walk
We now simulate a random walk that starts in state A (index `0`) and runs for 15 steps. This shows how the chain evolves over time under the defined transition probabilities.

In [None]:
# 15 transitions starting from A (index 0), printed by state name.
walk = classic_chain.simulate_walk(0, steps=15)
print("Simulated Walk:", [classic_chain.states[idx] for idx in walk])

## Another Example: Absorbing Markov Chain

In this example, we introduce an *absorbing state* — a state that, once entered, cannot be left (here, state D, whose only transition is back to itself). Absorbing Markov chains are useful for modeling systems with terminal outcomes, such as failure, success, or exit conditions.

In [None]:
# Four states; D's row sends all probability back to D, so D is absorbing.
absorbing_chain = MarkovChainEnv(
    [[0.5, 0.5, 0.0, 0.0],
     [0.2, 0.6, 0.2, 0.0],
     [0.0, 0.3, 0.4, 0.3],
     [0.0, 0.0, 0.0, 1.0]],
    state_names=['A', 'B', 'C', 'D'],
)
absorbing_chain.plot_transition_graph()

In [None]:
# Same walk length from A; once the chain reaches D it stays there.
walk = absorbing_chain.simulate_walk(0, steps=15)
print("Simulated Walk:", [absorbing_chain.states[idx] for idx in walk])

## Interactive Random Walk Simulator

Use the interactive widgets below to:
- Choose a starting state
- Step through the simulation
- Visualize transitions and changes in real time

This tool helps you build intuition about how a Markov chain behaves step by step.

In [None]:
import ipywidgets as widgets
from IPython.display import display, clear_output

# Store simulation state
class MarkovStepper:
    """Interactive, one-step-at-a-time random walk on a Markov chain.

    Renders a start-state dropdown plus "Reset" and "Next Step" buttons. After
    every action it redraws the chain's transition graph and a state-over-time
    plot into a single ``widgets.Output`` area.

    Args:
        chain: the chain to walk. It must expose ``states``, ``next_state(i)``
            and ``plot_transition_graph(highlight_state=..., previous_state=...)``,
            as ``MarkovChainEnv`` does.
    """

    def __init__(self, chain):
        self.chain = chain
        # Widgets come first: reset() writes into self.output.
        self.setup_widgets()
        self.reset()

    def plot_state_over_time(self, walk, states, title=None):
        """Draw ``walk`` (a list of state indices) as a step plot against time.

        Only the states that actually appear in ``walk`` get a y-axis tick,
        labelled with the matching entry of ``states``.
        """
        time_steps = list(range(len(walk)))
        unique_states = sorted(set(walk))
        plt.figure(figsize=(8, 4))
        plt.plot(time_steps, walk, drawstyle='steps-post', marker='o')
        plt.title(title or "State Over Time")
        plt.xlabel("Time step")
        plt.ylabel("State")
        plt.xticks(time_steps)
        plt.yticks(unique_states, [states[s] for s in unique_states])
        plt.grid(True)
        plt.show()

    def reset(self, start_state=0):
        """Restart the walk at ``start_state`` (an index into ``chain.states``) and redraw."""
        self.current_state = start_state
        self.walk = [self.current_state]
        self.output.clear_output()
        with self.output:
            print(f"Start at: {self.chain.states[self.current_state]}")
            self.chain.plot_transition_graph(highlight_state=self.current_state)
            self.plot_state_over_time(self.walk, self.chain.states, title="State Over Time (t = 0)")

    def step(self, _=None):
        """Advance the walk by one sampled transition and redraw.

        ``_`` absorbs the button instance that ``Button.on_click`` passes in.
        """
        prev_state = self.current_state
        self.current_state = self.chain.next_state(prev_state)
        self.walk.append(self.current_state)
        t = len(self.walk) - 1  # current time step (the start state is t = 0)

        with self.output:
            # wait=True defers clearing until new output arrives, reducing flicker.
            clear_output(wait=True)
            print(f"Step {t}: {self.chain.states[self.current_state]}")
            self.chain.plot_transition_graph(highlight_state=self.current_state, previous_state=prev_state)
            self.plot_state_over_time(self.walk, self.chain.states, title=f"State Over Time (t = {t})")

    def display(self):
        """Show the control row (start dropdown, Reset, Next Step) and the output area."""
        # ``display`` below resolves to IPython's module-level display(), not this method.
        display(widgets.HBox([self.start_widget, self.reset_button, self.step_button]))
        display(self.output)

    def setup_widgets(self):
        """Create the controls and output area and wire their callbacks."""
        self.start_widget = widgets.Dropdown(
            # (label, value) pairs: show the state name, yield the state index.
            options=[(name, i) for i, name in enumerate(self.chain.states)],
            value=0,
            description='Start:'
        )
        self.start_widget.observe(self.on_start_change, names='value')

        self.step_button = widgets.Button(description="Next Step")
        self.step_button.on_click(self.step)

        self.reset_button = widgets.Button(description="Reset")
        self.reset_button.on_click(self.on_reset)

        self.output = widgets.Output()

    def on_reset(self, _=None):
        """Reset button callback: restart from the state currently selected in the dropdown."""
        self.reset(self.start_widget.value)

    def on_start_change(self, change):
        """Dropdown observer: restart the walk from the newly selected state index."""
        self.reset(change['new'])


In [None]:
# The constructor already builds the widgets; calling setup_widgets() again would
# replace them (and the Output it just rendered into) for no benefit.
classic_stepper = MarkovStepper(classic_chain)
classic_stepper.display()
classic_stepper.reset()


## Two Spies

Let's create the Markov chain that describes the Deep Cover Spy's transitions.

In [None]:
# TODO: replace None with a MarkovChainEnv(transition_matrix=[...], state_names=[...])
#       describing the Deep Cover Spy's transitions.
deep_cover_chain = None

# Fail with a clear message instead of an AttributeError from inside MarkovStepper,
# which reads chain.states while building its widgets.
if deep_cover_chain is None:
    raise NotImplementedError("Define deep_cover_chain as a MarkovChainEnv before running this cell.")

deep_cover_stepper = MarkovStepper(deep_cover_chain)
deep_cover_stepper.display()  # the constructor already built the widgets
deep_cover_stepper.reset()


## Hidden Markov Models

Here, we introduce Hidden Markov Models (HMMs). Unlike a standard Markov chain, whose state is observed directly, an HMM has:
- A transition matrix over the hidden states
- A sensor (emission) matrix that gives, for each hidden state, the probabilities of the observable outputs

This lets us model systems where we can only indirectly observe the true state.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx

class HiddenMarkovModelEnv:
    """A discrete hidden Markov model: hidden-state dynamics plus noisy emissions.

    Args:
        transition_matrix: ``transition_matrix[i][j]`` is the probability of
            moving from hidden state ``i`` to hidden state ``j``.
        sensor_matrix: ``sensor_matrix[i][k]`` is the probability of emitting
            observation ``k`` while in hidden state ``i``.
        initial_probs: distribution over the starting hidden state (stored as ``pi``).
        state_names: optional display names for hidden states (defaults to indices).
        observation_names: optional display names for observations (defaults to indices).
    """

    def __init__(self, transition_matrix, sensor_matrix, initial_probs, state_names=None, observation_names=None):
        self.transition_matrix = np.array(transition_matrix)
        self.sensor_matrix = np.array(sensor_matrix)
        self.pi = np.array(initial_probs)

        self.num_states = self.transition_matrix.shape[0]
        self.num_obs = self.sensor_matrix.shape[1]

        # Missing or empty name lists fall back to integer labels.
        self.states = state_names or list(range(self.num_states))
        self.observations = observation_names or list(range(self.num_obs))

    def next_state(self, current_state):
        """Sample the hidden-state index that follows ``current_state``."""
        row = self.transition_matrix[current_state]
        return np.random.choice(self.num_states, p=row)

    def emit(self, state):
        """Sample an observation index emitted from hidden state ``state``."""
        row = self.sensor_matrix[state]
        return np.random.choice(self.num_obs, p=row)

    def simulate_walk(self, steps=10, start_state=None):
        """Run the HMM for ``steps`` transitions.

        The start state is drawn from ``pi`` unless ``start_state`` is given.
        Every visited state (including the first) emits one observation.

        Returns:
            ``(hidden_names, observation_names)``, two lists of length ``steps + 1``.
        """
        current = np.random.choice(self.num_states, p=self.pi) if start_state is None else start_state

        visited = [current]
        emissions = [self.emit(current)]
        for _ in range(steps):
            current = self.next_state(current)
            visited.append(current)
            emissions.append(self.emit(current))

        hidden_labels = [self.states[idx] for idx in visited]
        observed_labels = [self.observations[idx] for idx in emissions]
        return hidden_labels, observed_labels

    def plot_observations(self, observed_seq, title="Observations Over Time"):
        """Step-plot a sequence of observations (names or raw indices) against time."""
        def as_index(label):
            # Known names map back to their index; anything else is treated as an index already.
            return self.observations.index(label) if label in self.observations else label

        xs = list(range(len(observed_seq)))
        ys = [as_index(label) for label in observed_seq]

        plt.figure(figsize=(10, 4))
        plt.plot(xs, ys, drawstyle='steps-post', marker='o')
        plt.title(title)
        plt.xlabel("Time Step")
        plt.ylabel("Observation")
        plt.yticks(range(self.num_obs), self.observations)
        plt.grid(True)
        plt.show()

    def plot_transition_graph(self):
        """Draw the hidden-state transition graph with edge probabilities as labels."""
        graph = nx.DiGraph()
        # np.nonzero walks the matrix in row-major order, matching a nested i/j loop.
        for src, dst in zip(*np.nonzero(self.transition_matrix > 0)):
            graph.add_edge(self.states[src], self.states[dst], weight=self.transition_matrix[src, dst])

        layout = nx.spring_layout(graph, seed=42)
        labels = {(u, v): f"{attrs['weight']:.2f}" for u, v, attrs in graph.edges(data=True)}

        nx.draw(graph, layout, with_labels=True, node_size=1500, node_color="lightblue", font_size=12, edge_color="gray", width=2)
        nx.draw_networkx_edge_labels(graph, layout, edge_labels=labels)
        plt.title("HMM Transition Graph")
        plt.show()


In [None]:
# Build an HMM: the hidden weather (Rainy/Sunny) evolves as a Markov chain, and
# each day we only see the activity it produces (Walk/Shop/Clean).
weather_hmm = HiddenMarkovModelEnv(
    [[0.7, 0.3],   # from Rainy: P(Rainy), P(Sunny)
     [0.4, 0.6]],  # from Sunny: P(Rainy), P(Sunny)
    [[0.1, 0.4, 0.5],   # Rainy: P(Walk), P(Shop), P(Clean)
     [0.6, 0.3, 0.1]],  # Sunny: P(Walk), P(Shop), P(Clean)
    [0.6, 0.4],         # initial P(Rainy), P(Sunny)
    state_names=["Rainy", "Sunny"],
    observation_names=["Walk", "Shop", "Clean"],
)

weather_hmm.plot_transition_graph()


In [None]:
# 10 transitions => 11 time steps; compare the hidden truth with what an observer sees.
hidden, observed = weather_hmm.simulate_walk(10)

print("Hidden states: ", hidden)
print("Observed output:", observed)
weather_hmm.plot_observations(observed, title="Observations Over Time")


## Two Spies: Hidden Markov Model

Now we define the complete Hidden Markov Model for our Two Spies scenario, incorporating both:
- Hidden transitions (spy movement)
- Noisy observations (sensor readings)

In [None]:
#TODO: implement the Two Spies Deep Cover HMM