# 3. Train Markov Models (Quality Eval - Train Set Only)


In [None]:
# @title 3. Train Markov Models (Quality Eval)

# CELL 1 [TAG: parameters]
# ---------------------------------------------------------
# Default parameters (Airflow will OVERWRITE these)
# ---------------------------------------------------------
# Identifier of the pipeline run; in this notebook it is only echoed in the log output.
RUN_TIMESTAMP = "2025-01-01_00-00-00"  # Injected by Airflow
# Pickle file with the training segments; the logic cell indexes it by group (0 and 1).
INPUT_GROUPED_DATA = "s3://models-quality-eval/2025-01-01_00-00-00/train/grouped_segments.pkl"
# Prefix under which transition_matrices.pkl and state_definitions.pkl are written.
OUTPUT_MODEL_DIR = "s3://models-quality-eval/2025-01-01_00-00-00/models/"
# Bin width of the speed axis (first data column) of the state grid.
# Presumably km/h, going by the axis label used in the visualization cell.
V_RES = 2.5
# Bin width of the acceleration axis (second data column) of the state grid.
# Presumably m/s^2, going by the axis label used in the visualization cell.
A_RES = 0.25

# MinIO Credentials (DEFAULTS ONLY - Airflow injects real ones)
# The three values below are handed to s3fs.S3FileSystem as endpoint_url, key and secret.
# NOTE(review): these placeholders are stored in plain text; confirm that real
# credentials only ever arrive through parameter injection and are never committed.
MINIO_ENDPOINT = "http://localhost:9000"
MINIO_ACCESS_KEY = "admin"
MINIO_SECRET_KEY = "password123"


In [None]:
# CELL 2: Imports
import pickle
import numpy as np
import s3fs


In [None]:
# CELL 3: MinIO Configuration
# Create the S3-compatible filesystem handle shared by the load and save cells.
# The endpoint override directs the underlying client to MINIO_ENDPOINT.
fs = s3fs.S3FileSystem(
    client_kwargs=dict(endpoint_url=MINIO_ENDPOINT),
    secret=MINIO_SECRET_KEY,
    key=MINIO_ACCESS_KEY,
)


In [None]:
# CELL 4: Logic
# Step 1: fetch the pickled training segments for this run from object storage.
print(f"Loading TRAIN grouped segments from {INPUT_GROUPED_DATA}...")
print(f"Run Timestamp: {RUN_TIMESTAMP}")
try:
    # NOTE(review): unpickling executes whatever the file encodes, so the bucket
    # must only ever contain artifacts written by trusted pipeline steps.
    with fs.open(INPUT_GROUPED_DATA, "rb") as train_file:
        grouped_segments = pickle.load(train_file)
except FileNotFoundError:
    # Report the missing input in the notebook output, then let the error propagate
    # so the run still fails.
    print(f"❌ Error: Input file {INPUT_GROUPED_DATA} not found. Run Step 01 first.")
    raise
else:
    print("✅ Training data loaded successfully.")

# Step 2: train one first-order Markov model per traffic group.
#
# Results (one entry per group, None when the group has no data):
#   transition_matrices[g] - (n_states, n_states) row-normalised transition matrix
#   state_definitions[g]   - (n_states, 2) lookup: compact state index ->
#                            [speed bin centre, acceleration bin centre]
transition_matrices = []
state_definitions = []

for group_idx in range(2):
    segments = grouped_segments[group_idx]

    # A missing group can show up as an empty list or as a list that only holds
    # empty arrays; both leave nothing to train on (np.max would raise on the latter).
    all_data = np.concatenate(segments) if len(segments) > 0 else np.empty((0, 2))
    if all_data.shape[0] == 0:
        print(f"Warning: Group {group_idx} is empty. Skipping model training for this group.")
        transition_matrices.append(None)
        state_definitions.append(None)
        continue

    speeds = all_data[:, 0]
    accels = all_data[:, 1]

    # Grid definitions: left bin edges at fixed resolution. The speed grid is
    # anchored at 0, the acceleration grid at the smallest observed value.
    # max(..., 0) keeps the speed grid non-empty even if every speed is negative.
    v_bins = np.arange(0, max(np.max(speeds), 0) + V_RES, V_RES)
    a_bins = np.arange(np.min(accels), np.max(accels) + A_RES, A_RES)
    num_a_bins = len(a_bins)

    # State mapping: (speed bin, accel bin) -> flat grid id.
    # A speed below the first edge would yield bin -1, which later wraps around to
    # the LAST speed bin when indexing v_bins; pin such samples to bin 0 instead.
    v_indices = np.maximum(np.digitize(speeds, v_bins) - 1, 0)
    a_indices = np.digitize(accels, a_bins) - 1
    states = v_indices * num_a_bins + a_indices

    # Only grid cells that actually occur become model states. `compact_states`
    # is the sample sequence re-expressed as indices into `unique_states`.
    unique_states, compact_states = np.unique(states, return_inverse=True)
    n_states = len(unique_states)

    # Build matrix: count every consecutive (from, to) pair. np.add.at accumulates
    # repeated index pairs, which plain fancy-index assignment would not.
    # NOTE(review): segments are concatenated before counting, so the step from the
    # last sample of one segment to the first sample of the next is counted as a
    # transition too - confirm this is intended.
    trans_matrix = np.zeros((n_states, n_states))
    np.add.at(trans_matrix, (compact_states[:-1], compact_states[1:]), 1)

    # Normalize to probabilities; rows without outgoing transitions (dead-end
    # states) stay all-zero instead of dividing by zero.
    row_sums = trans_matrix.sum(axis=1, keepdims=True)
    trans_matrix = np.divide(trans_matrix, row_sums, out=np.zeros_like(trans_matrix), where=row_sums != 0)

    # Build lookup table (compact state index -> physical bin centres).
    # Acceleration is stored as well, though currently only speed is used for
    # reconstruction.
    v_idx, a_idx = np.divmod(unique_states, num_a_bins)
    state_lookup = np.column_stack((
        v_bins[v_idx] + V_RES / 2,
        a_bins[a_idx] + A_RES / 2,
    ))

    transition_matrices.append(trans_matrix)
    state_definitions.append(state_lookup)
    print(f"Group {group_idx} Model Trained: {n_states} unique states found.")


In [None]:
# CELL 5: Visualization (Speed-Acceleration Grid Map)
# For every trained group, draw the raw (speed, acceleration) samples, overlay the
# compact state index at each state's bin centre, and rule the axes with the same
# bin edges the training cell used.
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# Human-readable names for the two group indices
group_labels = {0: "Heavy Traffic", 1: "Light Traffic"}

for group_idx in range(2):
    # Skip if no model was trained for this group
    if group_idx >= len(state_definitions) or state_definitions[group_idx] is None:
        print(f"Skipping visualization for Group {group_idx} (No model found).")
        continue

    print(f"Generating S-A Grid Map for {group_labels[group_idx]}...")

    # 1. Retrieve data: the raw samples of this group and its lookup table
    #    (row index = state index, columns = [speed centre, accel centre]).
    raw_points = np.concatenate(grouped_segments[group_idx])
    lookup_table = state_definitions[group_idx]

    # 2. Setup plot
    fig, ax = plt.subplots(figsize=(14, 9))

    # 3. Raw samples as hollow circles (facecolors='none' leaves only the outline)
    sample_handle = ax.scatter(
        raw_points[:, 0], raw_points[:, 1],
        s=30, alpha=0.6, edgecolors='dodgerblue', facecolors='none',
        label='S-A Grid Values',
    )

    # 4. State indices as red text at each bin centre.
    #    No label is set here: matplotlib's legend has no default handler for Text
    #    artists, so the legend entry is supplied by a proxy handle further down.
    for state_id, (v_center, a_center) in enumerate(lookup_table):
        ax.text(v_center, a_center, str(state_id),
                color='darkred', fontsize=9, fontweight='bold',
                ha='center', va='center')

    # 5. Grid lines on the bin edges.
    #    Training anchors the speed bins at 0 and the acceleration bins at the
    #    smallest observed acceleration, so the ticks use the same origins; the
    #    extra step adds the closing edge of the last bin.
    v_max = np.max(raw_points[:, 0])
    a_min = np.min(raw_points[:, 1])
    a_max = np.max(raw_points[:, 1])
    x_ticks = np.arange(0, v_max + 2 * V_RES, V_RES)
    y_ticks = np.arange(a_min, a_max + 2 * A_RES, A_RES)

    ax.set_xticks(x_ticks)
    ax.set_yticks(y_ticks)
    ax.grid(True, which='both', linestyle='-', color='gray', alpha=0.5)

    # 6. Labels and title
    ax.set_xlabel("Speed (km/h)", fontsize=14)
    ax.set_ylabel("Acceleration ($m/s^2$)", fontsize=14)
    ax.set_title(f"Speed-Acceleration Grid Map: {group_labels[group_idx]} (SEG{group_idx})", fontsize=16)

    # Legend: the scatter handle plus a proxy artist standing in for the state text
    state_handle = Line2D([], [], linestyle='None', marker='$n$', markersize=10,
                          color='darkred', label='States')
    ax.legend(handles=[sample_handle, state_handle],
              loc='upper right', framealpha=1, fontsize=12)

    plt.tight_layout()
    plt.show()
    # Release the figure so repeated runs do not accumulate open figures
    plt.close(fig)

In [None]:
# CELL 6: Save
# Strip any trailing slash from the output prefix so the object keys built below
# contain exactly one separator between prefix and file name.
model_dir = OUTPUT_MODEL_DIR.rstrip("/")
print(f"Saving models to {model_dir}...")

# Each artifact is pickled to its own object: first the transition matrices,
# then the state definitions.
model_artifacts = (
    ("transition_matrices.pkl", transition_matrices),
    ("state_definitions.pkl", state_definitions),
)
for artifact_name, artifact in model_artifacts:
    with fs.open(f"{model_dir}/{artifact_name}", "wb") as out_file:
        pickle.dump(artifact, out_file)

print("✅ Markov Models saved to MinIO (Quality Eval).")
