In [None]:
from pyspark.sql import SparkSession
import os
import sys

# Make the project root (the parent of this notebooks/ directory) importable so
# that `pums_loader` can be imported below. Only prepend it when it is not
# already the first sys.path entry: re-running this cell must not keep pushing
# duplicate copies onto sys.path, but the project root should still take
# priority over any other location providing a `pums_loader` module.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if not sys.path or sys.path[0] != project_root:
    sys.path.insert(0, project_root)

from pums_loader import (
    process_pums_pipeline,
    load_all_years,
    write_parquet,
    validate_folder_structure,
    print_validation_report,
    REQUIRED_COLUMNS
)

In [None]:
# ============================================================================
# CONFIGURATION
# ============================================================================
# Every path below is absolute and derived from the current working directory.
# The notebook is expected to be started from the notebooks/ folder, so the
# project root is taken to be that folder's parent.

PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

# Inputs come from data/raw/pums, outputs go to data/processed/parquet_pums.
RAW_DATA_DIR = os.path.join(PROJECT_ROOT, 'data', 'raw', 'pums')
OUTPUT_DIR = os.path.join(PROJECT_ROOT, 'data', 'processed', 'parquet_pums')

print("\n".join([
    f"Project root: {PROJECT_ROOT}",
    f"Raw data directory: {RAW_DATA_DIR}",
    f"Output directory: {OUTPUT_DIR}",
]))

# exist_ok=True makes this a no-op when the folder is already present.
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# ============================================================================
# CREATE SPARK SESSION
# ============================================================================
# Enables Adaptive Query Execution plus its coalescing of shuffle partitions.
# getOrCreate() reuses an already-active session if one exists, so re-running
# this cell does not start a second Spark application.

spark = (
    SparkSession.builder
    .appName("PUMS Data Pipeline")
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .getOrCreate()
)

print("✓ Spark session created", f"Spark version: {spark.version}", sep="\n")

In [None]:
# ============================================================================
# OPTIONAL: VALIDATE FOLDER STRUCTURE
# ============================================================================
# Not strictly needed (the pipeline call below can validate on its own via
# validate_first=True), but it gives a quick look at which raw PUMS data is
# present before anything is loaded. The result stays in `validation` so it
# can be inspected further.

validation = validate_folder_structure(RAW_DATA_DIR)
print_validation_report(validation)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/12/07 10:51:01 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [None]:
# ============================================================================
# METHOD 1: USE THE COMPLETE PIPELINE (RECOMMENDED)
# ============================================================================
# A single call runs the whole flow:
#   1. validate the raw-data folder structure
#   2. load every available year
#   3. combine the years into one DataFrame
#   4. write the result to parquet
# The combined DataFrame is returned; the cells below treat None as
# "nothing was loaded".

df_all = process_pums_pipeline(
    spark=spark,
    # Where to read from and write to
    raw_data_dir=RAW_DATA_DIR,
    output_dir=OUTPUT_DIR,
    # What to load: the default column set (pass a custom list to change it),
    # leaving out 2020, whose data is experimental.
    columns=REQUIRED_COLUMNS,
    skip_year=2020,
    # Run options: check the folders first, and write both the combined file
    # and one file per year.
    validate_first=True,
    write_individual_years=True,
)

In [None]:
# ============================================================================
# METHOD 2: STEP-BY-STEP APPROACH (FOR MORE CONTROL)
# ============================================================================
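# Use this INSTEAD of METHOD 1 (not in addition to it): uncomment the code below
# for finer control over each step. Running both would load and write the data twice.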

# # Step 1: Load all years
# df_all = load_all_years(
#     spark=spark,
#     raw_data_dir=RAW_DATA_DIR,
#     columns=REQUIRED_COLUMNS,
#     skip_year=2020
# )

# # Step 2: Write to parquet
# if df_all is not None:
#     combined_path, individual_paths = write_parquet(
#         df_all,
#         output_dir=OUTPUT_DIR,
#         mode="overwrite",
#         write_individual_years=True
#     )

No year data found. Exiting.


In [None]:
# ============================================================================
# EXPLORE THE LOADED DATA
# ============================================================================
# Quick sanity summary of the combined DataFrame: schema, size, per-year row
# counts, a small sample, and describe() statistics for a few numeric columns.
# Every Spark action below (count / collect / show / describe) launches a job
# over df_all, so redundant actions are avoided.

if df_all is not None:
    print("\n" + "="*60)
    print("DATA SUMMARY")
    print("="*60)

    # Show schema
    print("\nSchema:")
    df_all.printSchema()

    # Show row count
    total_rows = df_all.count()
    print(f"\nTotal rows: {total_rows:,}")

    # Show column count
    print(f"Total columns: {len(df_all.columns)}")

    # Years available and rows per year. One groupBy job provides both: its
    # group keys, ordered by YEAR, are exactly the distinct YEAR values, so a
    # separate distinct() pass over the data is unnecessary.
    if "YEAR" in df_all.columns:
        year_counts = df_all.groupBy("YEAR").count().orderBy("YEAR").collect()
        year_list = [row["YEAR"] for row in year_counts]
        print(f"Years: {year_list}")

        print("\nRows by year:")
        for row in year_counts:
            print(f"  {row['YEAR']}: {row['count']:,} rows")

    # Show sample data
    print("\nSample data (first 5 rows):")
    df_all.show(5, truncate=False)

    # describe() statistics for whichever of these columns are present
    print("\nBasic statistics for numeric columns:")
    numeric_cols = ["AGEP", "WAGP", "PINCP"]
    available_numeric = [c for c in numeric_cols if c in df_all.columns]
    if available_numeric:
        df_all.select(available_numeric).describe().show()
else:
    print("❌ No data loaded. Check the error messages above.")


In [None]:
# ============================================================================
# READ BACK THE PARQUET FILES (EXAMPLE)
# ============================================================================
# You can read the parquet files back like this:

# Read combined file
# df_combined = spark.read.parquet(os.path.join(OUTPUT_DIR, "pums_all.parquet"))

# Read a specific year
# df_2019 = spark.read.parquet(os.path.join(OUTPUT_DIR, "pums_2019.parquet"))

# Verify the data
# print(f"Combined rows: {df_combined.count():,}")
# print(f"2019 rows: {df_2019.count():,}")


In [None]:
# ============================================================================
# CLEANUP (OPTIONAL)
# ============================================================================
# Uncomment to stop Spark session when done

# spark.stop()
