# Serverless EDA Pipeline — Amazon Athena Simulation

This notebook simulates **Amazon Athena** ad-hoc queries against the
**Curated Zone** Parquet files produced by the Glue ETL pipeline.

Locally we use **PySpark + SparkSQL** as a stand-in for Athena.
In production, you would point Athena at the same S3 Curated Zone paths
via the AWS Glue Data Catalog and run essentially the same SQL; only a few
function names differ between Spark SQL and Athena's Presto/Trino-based dialect.
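
Porting a query is not always verbatim: Athena's Presto/Trino-based engine spells Spark's `PERCENTILE_APPROX` (used in sections 5 and 9) as `approx_percentile`. The sketch below is a hypothetical `to_athena_sql` helper that renames just that one function with a regex. It is not a SQL parser, so any other dialect difference would still need a manual check.

```python
import re

# Spark SQL function name -> Athena (Presto/Trino) equivalent.
# Only the function this notebook uses is mapped (hypothetical helper).
SPARK_TO_ATHENA_FUNCS = {
    'PERCENTILE_APPROX': 'approx_percentile',
}


def to_athena_sql(spark_sql: str) -> str:
    """Rename Spark-only functions so a query can be pasted into Athena.

    Matches whole function names case-insensitively, and only when they
    are followed by an opening parenthesis (i.e. actual function calls).
    """
    out = spark_sql
    for spark_name, athena_name in SPARK_TO_ATHENA_FUNCS.items():
        pattern = r'\b' + re.escape(spark_name) + r'\s*\('
        out = re.sub(pattern, athena_name + '(', out, flags=re.IGNORECASE)
    return out


query = ("SELECT action, PERCENTILE_APPROX(duration_frames, 0.5) AS median_dur "
         "FROM annotations GROUP BY action")
print(to_athena_sql(query))
```

The other constructs in this notebook's queries (`ROUND`, `AVG`, `COUNT(DISTINCT ...)`, `CASE WHEN`, `UNION ALL`) exist in both dialects, so the helper leaves them untouched.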

In [None]:
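import os, sys
# Make the project root (one level up) importable so `aws_pipeline` resolves
sys.path.insert(0, os.path.join(os.getcwd(), '..'))

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display  # implicit in Jupyter, explicit for scripts
from pyspark.sql import SparkSession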

from aws_pipeline.config import LOCAL_CURATED_ZONE, SPARK_APP_NAME

sns.set_theme(style='whitegrid', palette='muted', font_scale=1.1)
plt.rcParams.update({'figure.dpi': 100})

CURATED = LOCAL_CURATED_ZONE
print(f'Curated Zone: {CURATED}')

In [None]:
# ── Spin up a local SparkSession (simulates Athena engine) ──
spark = (
    SparkSession.builder
    .appName(f'{SPARK_APP_NAME}-EDA')
    .master('local[*]')
    .config('spark.driver.memory', '4g')
    .getOrCreate()
)
spark.sparkContext.setLogLevel('WARN')
print('SparkSession ready.')

## 1. Register Curated Parquet as SQL Tables

In production these would be **Glue Data Catalog tables** queryable by Athena.
Here we register them as Spark temporary views.
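
As a rough illustration of the production side, the sketch below builds the Hive-style `CREATE EXTERNAL TABLE` statement Athena would need for one curated Parquet prefix. Everything specific in it is a placeholder: the helper name, the `curated` database, the column types, and the `s3://example-bucket/...` path. In practice a Glue crawler or the ETL job would usually register these tables. For a partitioned table such as `tracking_curated`, running `MSCK REPAIR TABLE` afterwards loads the existing `lab_id=.../` partitions into the catalog.

```python
def athena_parquet_ddl(database, table, columns, location, partition_cols=None):
    """Build a CREATE EXTERNAL TABLE statement for a Parquet prefix on S3.

    Hypothetical helper. `columns` and `partition_cols` are lists of
    (name, hive_type) pairs. Partition columns must NOT repeat in `columns`:
    Athena reads their values from the `name=value/` directory names.
    """
    cols = ',\n  '.join(f'`{name}` {typ}' for name, typ in columns)
    ddl = f'CREATE EXTERNAL TABLE IF NOT EXISTS {database}.{table} (\n  {cols}\n)'
    if partition_cols:
        parts = ', '.join(f'`{name}` {typ}' for name, typ in partition_cols)
        ddl += f'\nPARTITIONED BY ({parts})'
    ddl += f"\nSTORED AS PARQUET\nLOCATION '{location}'"
    return ddl


# Placeholder database, bucket, and column types, for illustration only
print(athena_parquet_ddl(
    database='curated',
    table='tracking',
    columns=[('x_cm', 'double'), ('y_cm', 'double'),
             ('x_valid', 'int'), ('y_valid', 'int')],
    location='s3://example-bucket/curated/tracking_curated/',
    partition_cols=[('lab_id', 'string')],
))
# Afterwards, in Athena: MSCK REPAIR TABLE curated.tracking
```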

In [None]:
# Metadata
meta_train = spark.read.parquet(os.path.join(CURATED, 'metadata_train'))
meta_train.createOrReplaceTempView('metadata_train')
print(f'metadata_train: {meta_train.count()} rows, {len(meta_train.columns)} cols')

meta_test = spark.read.parquet(os.path.join(CURATED, 'metadata_test'))
meta_test.createOrReplaceTempView('metadata_test')
print(f'metadata_test: {meta_test.count()} rows')

# Enriched annotations
anno = spark.read.parquet(os.path.join(CURATED, 'annotations_enriched'))
anno.createOrReplaceTempView('annotations')
print(f'annotations: {anno.count()} rows')

# Class distribution stats
stats = spark.read.parquet(os.path.join(CURATED, 'class_distribution_stats'))
stats.createOrReplaceTempView('class_stats')
print(f'class_stats: {stats.count()} rows')

# Curated tracking (partitioned by lab_id)
tracking = spark.read.parquet(os.path.join(CURATED, 'tracking_curated'))
tracking.createOrReplaceTempView('tracking')
print(f'tracking: {tracking.count():,} rows')
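
Because `tracking_curated` is partitioned by `lab_id`, its files live under Hive-style `lab_id=<value>/` directories, and both Spark and Athena can skip every directory whose value fails an equality filter on `lab_id` (partition pruning). That matters most in Athena, which bills by bytes scanned. The stdlib sketch below only mimics the directory-selection step on a made-up list of paths; the helper names are hypothetical.

```python
from urllib.parse import unquote


def partition_values(path: str) -> dict:
    """Parse Hive-style `key=value` directory segments from a file path."""
    values = {}
    for segment in path.strip('/').split('/'):
        if '=' in segment:
            key, _, value = segment.partition('=')
            # Hive/Spark escape special characters in partition values as %XX
            values[key] = unquote(value)
    return values


def prune(paths, **predicates):
    """Keep only files whose partition values match every equality predicate."""
    return [p for p in paths
            if all(partition_values(p).get(k) == v for k, v in predicates.items())]


files = [
    'tracking_curated/lab_id=LabA/part-00000.parquet',
    'tracking_curated/lab_id=LabA/part-00001.parquet',
    'tracking_curated/lab_id=LabB/part-00000.parquet',
]
# Same idea as: SELECT ... FROM tracking WHERE lab_id = 'LabA'
print(prune(files, lab_id='LabA'))
```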

## 2. Lab Distribution (SparkSQL)

In [None]:
lab_dist = spark.sql("""
    SELECT lab_id, COUNT(*) AS video_count
    FROM metadata_train
    GROUP BY lab_id
    ORDER BY video_count DESC
""").toPandas()

fig, ax = plt.subplots(figsize=(12, 6))
sns.barplot(data=lab_dist, y='lab_id', x='video_count', ax=ax, color='steelblue')
ax.set_title('Train: Number of Videos per Lab')
ax.set_xlabel('Number of Videos')
ax.set_ylabel('')
plt.tight_layout()
plt.show()

## 3. Video Metadata Summary (SparkSQL)

In [None]:
meta_summary = spark.sql("""
    SELECT
        split,
        COUNT(*)                              AS n_videos,
        COUNT(DISTINCT lab_id)                AS n_labs,
        ROUND(AVG(video_duration_sec), 1)     AS avg_duration_sec,
        ROUND(AVG(frames_per_second), 1)      AS avg_fps,
        ROUND(AVG(n_mice), 2)                 AS avg_mice
    FROM (
        -- UNION ALL pairs columns by position: both views need the same column order
        SELECT * FROM metadata_train
        UNION ALL
        SELECT * FROM metadata_test
    )
    GROUP BY split
    ORDER BY split
""").toPandas()

display(meta_summary)

In [None]:
duration_df = spark.sql("""
    SELECT video_duration_sec FROM metadata_train
""").toPandas()

fps_df = spark.sql("""
    SELECT frames_per_second, COUNT(*) AS cnt
    FROM metadata_train
    GROUP BY frames_per_second
    ORDER BY frames_per_second
""").toPandas()

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

sns.histplot(duration_df['video_duration_sec'], bins=50, ax=axes[0], color='teal', edgecolor='white')
axes[0].set_title('Video Duration Distribution')
axes[0].set_xlabel('Seconds')
axes[0].set_ylabel('Number of Videos')

sns.barplot(data=fps_df, x='frames_per_second', y='cnt', ax=axes[1], color='coral')
axes[1].set_title('Frame Rate (FPS) Distribution')
axes[1].set_xlabel('Frames per Second')
axes[1].set_ylabel('Number of Videos')

plt.tight_layout()
plt.show()

## 4. Behavior Class Distribution — Extreme Imbalance (SparkSQL)

In [None]:
class_dist = spark.sql("""
    SELECT
        action,
        SUM(event_count)   AS total_events,
        SUM(total_frames)  AS total_frames
    FROM class_stats
    GROUP BY action
    ORDER BY total_events DESC
""").toPandas()

fig, axes = plt.subplots(1, 2, figsize=(16, 8))

sns.barplot(data=class_dist.sort_values('total_events'),
            y='action', x='total_events', ax=axes[0], color='darkorange')
axes[0].set_title('Annotation Events per Action')
axes[0].set_xlabel('Event Count')
axes[0].set_ylabel('')

sns.barplot(data=class_dist.sort_values('total_frames'),
            y='action', x='total_frames', ax=axes[1], color='seagreen')
axes[1].set_title('Total Annotated Frames per Action')
axes[1].set_xlabel('Frame Count')
axes[1].set_ylabel('')

plt.tight_layout()
plt.show()

print(f'Imbalance ratio (max/min events): '
      f'{class_dist["total_events"].max() / class_dist["total_events"].min():.0f}x')

## 5. Event Duration Analysis (SparkSQL)

In [None]:
duration_stats = spark.sql("""
    SELECT
        action,
        COUNT(*)                                        AS n_events,
        ROUND(AVG(duration_frames), 1)                  AS avg_dur,
        PERCENTILE_APPROX(duration_frames, 0.5)         AS median_dur,
        MIN(duration_frames)                            AS min_dur,
        MAX(duration_frames)                            AS max_dur
    FROM annotations
    GROUP BY action
    ORDER BY median_dur DESC
""").toPandas()

fig, ax = plt.subplots(figsize=(10, 8))
sns.barplot(data=duration_stats.sort_values('median_dur'),
            y='action', x='median_dur', ax=ax, color='indianred')
ax.set_title('Median Event Duration per Action (frames)')
ax.set_xlabel('Frames')
ax.set_ylabel('')
plt.tight_layout()
plt.show()
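
Note that `PERCENTILE_APPROX` does not interpolate. It targets the nearest-rank percentile, the smallest value `v` in the column such that at least a fraction `p` of the values are `<= v`, and returns an approximation of it within its `accuracy` bound. pandas' `Series.median()` instead averages the two middle values when the count is even, so `median_dur` can differ slightly from a pandas median of the same column. A stdlib sketch of the exact definition (the helper name is hypothetical):

```python
import math
import statistics


def nearest_rank_percentile(values, p):
    """Smallest value v such that at least a fraction p of `values` are <= v.

    This is the exact quantity Spark's PERCENTILE_APPROX targets; Spark
    itself returns an approximation of it within its `accuracy` bound.
    """
    if not values:
        raise ValueError('empty input')
    ordered = sorted(values)
    # 1-based rank; the tiny epsilon guards against float error (e.g. 0.7 * 10)
    k = max(1, math.ceil(p * len(ordered) - 1e-9))
    return ordered[k - 1]


durations = [0, 1, 2, 10]                          # toy event durations (frames)
print(nearest_rank_percentile(durations, 0.5))    # an actual data value
print(statistics.median(durations))               # interpolated, like pandas
```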

## 6. Lab × Action Heatmap (SparkSQL + Seaborn)

In [None]:
lab_action = spark.sql("""
    SELECT lab_id, action, SUM(event_count) AS events
    FROM class_stats
    GROUP BY lab_id, action
""").toPandas()

pivot = lab_action.pivot(index='lab_id', columns='action', values='events').fillna(0)

# Keep top 20 actions by total count
top_actions = lab_action.groupby('action')['events'].sum().nlargest(20).index
pivot_top = pivot[[c for c in top_actions if c in pivot.columns]]

fig, ax = plt.subplots(figsize=(16, 8))
sns.heatmap(np.log1p(pivot_top), ax=ax, cmap='YlOrRd', linewidths=0.5,
            cbar_kws={'label': 'ln(count + 1)'}, xticklabels=True, yticklabels=True)
ax.set_title('Lab × Action Event Count Heatmap (log scale)')
ax.set_xlabel('Action')
ax.set_ylabel('Lab')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()
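
`np.log1p` is the natural logarithm of `1 + count`, which keeps empty lab/action cells (filled with 0 by `fillna(0)`) at 0 instead of sending them to minus infinity. To read a color-bar value back as an event count, invert it with `expm1`, as in this small stdlib illustration:

```python
import math

# Forward transform used for the heatmap colors, and its inverse
for count in [0, 9, 99, 999]:
    shade = math.log1p(count)   # ln(1 + count): the value on the color bar
    back = math.expm1(shade)    # exp(shade) - 1: recovers the event count
    print(f'{count:>4} events -> {shade:.2f} on the color bar -> {back:.0f} events')
```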

## 7. Self vs Interaction Behaviors (SparkSQL)

In [None]:
type_stats = spark.sql("""
    SELECT
        action_type,
        SUM(event_count)  AS total_events,
        SUM(total_frames) AS total_frames
    FROM class_stats
    GROUP BY action_type
    ORDER BY action_type
""").toPandas()

pal = sns.color_palette(['#66c2a5', '#fc8d62'])

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

panels = [
    ('total_events', 'Event Count: Self vs Interaction', 'Number of Events'),
    ('total_frames', 'Total Frames: Self vs Interaction', 'Total Frames'),
]
for ax, (col, title, ylabel) in zip(axes, panels):
    # hue + legend=False applies the palette without seaborn's
    # "palette without hue" deprecation warning (seaborn >= 0.13)
    sns.barplot(data=type_stats, x='action_type', y=col, hue='action_type',
                palette=pal, legend=False, ax=ax)
    ax.set_title(title)
    ax.set_xlabel('')
    ax.set_ylabel(ylabel)
    # Label each bar with its exact value
    for p in ax.patches:
        ax.annotate(f'{int(p.get_height()):,}',
                    (p.get_x() + p.get_width() / 2., p.get_height()),
                    ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.show()

## 8. Tracking Data Quality Audit (SparkSQL)

In [None]:
quality = spark.sql("""
    SELECT
        lab_id,
        COUNT(*)                                                    AS total_rows,
        ROUND(100.0 * SUM(CASE WHEN x_valid = 0 THEN 1 ELSE 0 END) / COUNT(*), 2) AS nan_x_pct,
        ROUND(100.0 * SUM(CASE WHEN y_valid = 0 THEN 1 ELSE 0 END) / COUNT(*), 2) AS nan_y_pct,
        ROUND(AVG(x_cm), 3)                                        AS avg_x_cm,
        ROUND(AVG(y_cm), 3)                                        AS avg_y_cm
    FROM tracking
    GROUP BY lab_id
    ORDER BY nan_x_pct DESC
""").toPandas()

fig, ax = plt.subplots(figsize=(12, 6))
sns.barplot(data=quality.sort_values('nan_x_pct', ascending=False),
            y='lab_id', x='nan_x_pct', ax=ax, color='salmon')
ax.set_title('Missing Coordinate Rate (%) per Lab')
ax.set_xlabel('NaN x-coordinate %')
ax.set_ylabel('')
plt.tight_layout()
plt.show()

display(quality)
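
The query defines each lab's missing rate as `100 * (rows with x_valid = 0) / (all rows)`. A plain-Python cross-check of that arithmetic on toy data follows; the rows are invented, and treating `x_valid` as a 0/1 flag is assumed from the query.

```python
from collections import defaultdict


def missing_pct_by_lab(rows):
    """Percent of rows per lab with x_valid == 0, rounded to 2 decimal places."""
    totals = defaultdict(int)
    missing = defaultdict(int)
    for row in rows:
        totals[row['lab_id']] += 1
        missing[row['lab_id']] += int(row['x_valid'] == 0)
    return {lab: round(100.0 * missing[lab] / totals[lab], 2) for lab in totals}


# Toy rows: LabA has 5 of 100 rows invalid, LabB has 1 of 4
toy_rows = (
    [{'lab_id': 'LabA', 'x_valid': 1}] * 95 + [{'lab_id': 'LabA', 'x_valid': 0}] * 5
    + [{'lab_id': 'LabB', 'x_valid': 1}] * 3 + [{'lab_id': 'LabB', 'x_valid': 0}] * 1
)
print(missing_pct_by_lab(toy_rows))
```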

## 9. Train vs Test Comparison (SparkSQL)

In [None]:
comparison = spark.sql("""
    SELECT
        split,
        COUNT(*)                              AS n_videos,
        COUNT(DISTINCT lab_id)                AS n_labs,
        ROUND(AVG(video_duration_sec), 1)     AS avg_dur_sec,
        ROUND(PERCENTILE_APPROX(video_duration_sec, 0.5), 1) AS median_dur_sec,
        ROUND(AVG(n_mice), 2)                 AS avg_mice,
        COUNT(DISTINCT body_parts_tracked)    AS n_tracking_schemes
    FROM (
        SELECT * FROM metadata_train
        UNION ALL
        SELECT * FROM metadata_test
    )
    GROUP BY split
    ORDER BY split
""").toPandas()

display(comparison)

## 10. Summary

| Finding | Detail |
|---|---|
| **Severe class imbalance** | `sniff` dominates; `ejaculate` has only 3 events (max/min event ratio ~12,600x) |
| **High lab heterogeneity** | 19+ labs with different FPS, keypoint schemes, arena shapes |
| **Wide event duration range** | From single frames to thousands of frames |
| **Missing data varies by lab** | Some labs have >5% NaN in tracking coordinates |
| **Interaction behaviors dominate** | Both in variety and total annotated frames |

In [None]:
spark.stop()
print('SparkSession stopped. EDA complete.')