In [None]:
# @title 2. DTW Reference Selection

# CELL 1 [TAG: parameters]
# ---------------------------------------------------------
# Default parameters (Airflow will OVERWRITE these)
# ---------------------------------------------------------
# Keep every entry a plain `NAME = literal` assignment: values injected by the
# orchestrator presumably rely on this cell's simple form (papermill-style
# parameters cell) — TODO confirm before restructuring.

# Input pickle produced by Step 01; CELL 5 indexes it by group number (0, 1).
INPUT_GROUPED_DATA = "s3://models/grouped_segments.pkl"
# Output pickle: a list with one reference segment per group (None for a
# group that had no samples), written by CELL 6.
OUTPUT_REFERENCE_DATA = "s3://models/reference_segments.pkl"

# MinIO Credentials (DEFAULTS ONLY - Airflow injects real ones)
# NOTE(review): these are local-development placeholders. Never commit real
# secrets here; production values must arrive via the injected parameters
# (or another secret source such as environment variables).
MINIO_ENDPOINT = "http://localhost:9000"
MINIO_ACCESS_KEY = "admin"
MINIO_SECRET_KEY = "password123"


In [None]:
# CELL 2: Imports
import pickle
import numpy as np
import s3fs
from dtaidistance import dtw


In [None]:
# CELL 3: MinIO Configuration
# Build one S3-compatible filesystem handle, used by the load (CELL 4) and
# save (CELL 6) cells. `client_kwargs` is forwarded to the underlying S3
# client so requests go to the configured MinIO endpoint.
_s3_connection_options = dict(
    key=MINIO_ACCESS_KEY,
    secret=MINIO_SECRET_KEY,
    client_kwargs=dict(endpoint_url=MINIO_ENDPOINT),
)
fs = s3fs.S3FileSystem(**_s3_connection_options)
# Drop the temporary options dict so credentials are not kept under a second name.
del _s3_connection_options


In [None]:
# CELL 4: Load Data
# Read the grouped segments written by Step 01 from object storage.
# A missing input is reported with a hint and then re-raised so the
# pipeline run fails loudly instead of continuing without data.
print(f"Loading grouped data from {INPUT_GROUPED_DATA}...")
try:
    with fs.open(INPUT_GROUPED_DATA, 'rb') as grouped_file:
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # artifacts written by this pipeline into a trusted bucket.
        grouped_segments = pickle.load(grouped_file)
except FileNotFoundError:
    print(f"❌ Error: Input file {INPUT_GROUPED_DATA} not found. Run Step 01 first.")
    raise
else:
    print("✅ Data loaded successfully.")


In [None]:
# CELL 5: Logic (DTW Selection)
# Find the "centroid" (medoid) segment for each group [cite: 491-497]:
# the candidate whose summed DTW distance to all other candidates is smallest.

# Number of groups produced by Step 01 (indexed 0..NUM_GROUPS-1).
NUM_GROUPS = 2
# Cap on candidates per group. The distance matrix needs O(N^2) DTW calls,
# so only the FIRST `MAX_CANDIDATES` segments of each group are compared.
# This is a prefix, not a random sample: the chosen reference depends on the
# order in which Step 01 stored the segments.
MAX_CANDIDATES = 50

reference_segments = [None] * NUM_GROUPS
for group_idx in range(NUM_GROUPS):
    segments = grouped_segments[group_idx]
    num_samples = len(segments)

    if num_samples < 1:
        print(f"Group {group_idx}: No samples found.")
        continue
    print(f"Processing Group {group_idx} (Size: {num_samples})...")

    limit = min(num_samples, MAX_CANDIDATES)

    # DTW runs on the speed profile (column 0). Convert each candidate once
    # to a C-contiguous float64 array, the form the compiled dtaidistance
    # routine expects. This hoists the slice+cast out of the O(N^2) pair loop.
    profiles = [
        np.ascontiguousarray(segments[k][:, 0], dtype=np.double)
        for k in range(limit)
    ]

    # DTW distance is symmetric, so compute each unordered pair once (upper
    # triangle) and mirror it. This halves the DTW calls; the diagonal stays 0.
    dist_matrix = np.zeros((limit, limit))
    for i in range(limit):
        for j in range(i + 1, limit):
            pair_distance = dtw.distance_fast(profiles[i], profiles[j])
            dist_matrix[i, j] = pair_distance
            dist_matrix[j, i] = pair_distance

    # Medoid = candidate with the minimum total distance to all others.
    total_distances = dist_matrix.sum(axis=0)
    best_idx = int(np.argmin(total_distances))

    reference_segments[group_idx] = segments[best_idx]
    print(f"  > Selected Reference Index: {best_idx}")


In [None]:
# CELL 6: Save
# Persist the per-group reference list (an entry is None when that group
# had no samples) to object storage as a single pickle.
print(f"Saving reference segments to {OUTPUT_REFERENCE_DATA}...")
payload = pickle.dumps(reference_segments)
with fs.open(OUTPUT_REFERENCE_DATA, 'wb') as out_file:
    out_file.write(payload)
del payload
print("✅ Done.")
