# In [None]:
# ==========================================================
# Q-Learning Demonstration using Gymnasium
# Custom Two-State, Two-Action Environment
# ==========================================================

# Install dependencies (uncomment if running in Colab)
# !pip install gymnasium matplotlib seaborn numpy

import gymnasium as gym
from gymnasium import spaces
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# ----------------------------------------------------------
# Define the Custom Environment
# ----------------------------------------------------------

class TwoStateEnv(gym.Env):
    """Deterministic two-state, two-action environment that never terminates.

    States are ``0`` (s0) and ``1`` (s1); actions are ``0`` (a0) and ``1`` (a1).
    Every ``(state, action)`` pair maps to exactly one ``(next_state, reward)``
    entry in ``self.transitions``. ``step`` always reports
    ``terminated=False`` and ``truncated=False``, so the caller decides how
    long an episode lasts.
    """

    # Gymnasium reads the supported render modes from the "render_modes" key
    # (the dotted "render.modes" spelling is the legacy Gym name).
    metadata = {"render_modes": ["human"]}

    def __init__(self):
        super().__init__()
        self.observation_space = spaces.Discrete(2)  # s0, s1
        self.action_space = spaces.Discrete(2)       # a0, a1

        # Transition and reward table: (state, action) -> (next_state, reward)
        self.transitions = {
            (0, 0): (0, 1),  # (s0,a0): stay in s0, reward 1
            (0, 1): (1, 0),  # (s0,a1): move to s1, reward 0
            (1, 0): (0, 4),  # (s1,a0): move to s0, reward 4
            (1, 1): (1, 2),  # (s1,a1): stay in s1, reward 2
        }

        self.state = 0  # start in s0

    def reset(self, seed=None, options=None):
        """Put the environment back in s0.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` so the
                environment's own random generator is seeded.
            options: Accepted for API compatibility; not used.

        Returns:
            ``(observation, info)`` where the observation is ``0`` and
            ``info`` is an empty dict.
        """
        # Forward the seed instead of silently dropping it; the dynamics here
        # are deterministic, but this keeps the standard reset contract.
        super().reset(seed=seed)
        self.state = 0
        return self.state, {}

    def step(self, action):
        """Apply ``action`` in the current state.

        Args:
            action: ``0`` or ``1``.

        Returns:
            ``(observation, reward, terminated, truncated, info)``; both flags
            are always ``False`` and ``info`` is an empty dict.

        Raises:
            KeyError: If ``(self.state, action)`` is not in the transition
                table.
        """
        next_state, reward = self.transitions[(self.state, action)]
        self.state = next_state
        terminated = False
        truncated = False
        return self.state, reward, terminated, truncated, {}

    def render(self):
        """Print the current state to stdout."""
        print(f"Current state: {self.state}")

# ----------------------------------------------------------
# Initialize Environment and Parameters
# ----------------------------------------------------------

env = TwoStateEnv()

# Table dimensions come straight from the environment's discrete spaces.
n_states, n_actions = env.observation_space.n, env.action_space.n

# One row per state, one column per action; every estimate starts at zero.
Q = np.zeros(shape=(n_states, n_actions))

# Hyperparameters: learning rate, discount factor, exploration rate.
alpha, gamma, epsilon = 1.0, 0.9, 0.1

# Training schedule: number of episodes and environment steps in each one.
episodes, steps_per_episode = 10000, 4

# Q-table snapshots collected during training, used later for plotting.
q_history = []

# ----------------------------------------------------------
# ε-greedy Action Selection
# ----------------------------------------------------------

def epsilon_greedy(state):
    """Choose an action for ``state`` with an epsilon-greedy rule.

    With probability ``epsilon`` a random action is sampled from the
    environment's action space; otherwise the action with the largest
    current Q-value for ``state`` is returned (``np.argmax`` picks the
    lowest index on ties).
    """
    explore = np.random.rand() < epsilon
    if explore:
        return env.action_space.sample()
    return np.argmax(Q[state])

# ----------------------------------------------------------
# Q-Learning Training Loop
# ----------------------------------------------------------

for episode in range(episodes):
    s, _ = env.reset()
    for step in range(steps_per_episode):
        a = epsilon_greedy(s)
        s_next, r, term, trunc, _ = env.step(a)

        # Q-learning target: immediate reward plus the discounted best value of
        # the next state. A terminal state has no future value, so nothing is
        # bootstrapped from it.
        future = 0.0 if term else gamma * np.max(Q[s_next])
        Q[s, a] = Q[s, a] + alpha * (r + future - Q[s, a])
        s = s_next

        # Honor the environment's end-of-episode flags rather than stepping
        # an episode that has already finished.
        if term or trunc:
            break

    # Record Q-table snapshot (a copy, since Q is updated in place)
    q_history.append(Q.copy())

# ----------------------------------------------------------
# Display Final Q-Table
# ----------------------------------------------------------

# Header line followed by the learned table, emitted in a single call.
print("Final Q-table after training:", Q, sep="\n")

# ----------------------------------------------------------
# Visualization: Q-Table Evolution
# ----------------------------------------------------------

# Convert Q-table history into array for plotting
# Stack the per-episode snapshots into one (snapshot, state, action) array.
# The reshape keeps the array 3-D even when no snapshot was recorded.
q_history = np.array(q_history).reshape(-1, n_states, n_actions)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
heatmap_ax, curve_ax = axes

# Left panel: final Q-table as an annotated heatmap.
sns.heatmap(Q, annot=True, cmap="YlGnBu", cbar=False, ax=heatmap_ax)
heatmap_ax.set_title("Final Q-Table Values")
heatmap_ax.set_xlabel("Actions")
heatmap_ax.set_ylabel("States")

# Right panel: evolution of each (s,a) pair across episodes. Draw on the
# right-hand Axes explicitly instead of relying on pyplot's implicit
# "current axes", and take the x-axis from the recorded history so the two
# always have matching lengths.
episode_axis = np.arange(len(q_history))
for state in range(n_states):
    for action in range(n_actions):
        curve_ax.plot(
            episode_axis,
            q_history[:, state, action],
            label=f"Q(s{state},a{action})",
        )
curve_ax.set_title("Q-Value Evolution Over Episodes")
curve_ax.set_xlabel("Episode")
curve_ax.set_ylabel("Q(s,a)")
curve_ax.legend()

fig.tight_layout()  # keep the two panels' labels from overlapping
plt.show()

# ----------------------------------------------------------
# Interpretation
# ----------------------------------------------------------

# Derive the summary from the learned Q-table instead of hard-coding it, so the
# printed claims always agree with the numbers shown above. For this reward
# table with gamma = 0.9 the optimal choice in s0 is a1, not a0: moving to s1
# to collect the +4 reward is worth about 18.95, versus about 18.05 for the
# +1 self-loop.
best_state, best_action = (int(i) for i in np.unravel_index(np.argmax(Q), Q.shape))
greedy_actions = np.argmax(Q, axis=1)

print("\nInterpretation:")
print(f"- The agent learns that (s{best_state},a{best_action}) yields the highest long-term return.")
for state_idx, action_idx in enumerate(greedy_actions):
    print(
        f"- From s{state_idx}, the best action is a{int(action_idx)} "
        f"(Q = {Q[state_idx, action_idx]:.2f})."
    )
print("- The heatmap shows relative Q-values; the line plot shows convergence.")
