# Solution: Exercise 1 - Building a Custom PySpark Data Source

**Time:** 15 minutes

This notebook contains the complete solutions for Exercise 1.

## Learning Objectives
1. Understand the Python Data Source API architecture
2. Implement schema definition for external data
3. Create partition-aware data reading for parallelism
4. Register and use your custom data source

## Reference
- [Python Data Source API Docs](https://docs.databricks.com/en/pyspark/datasources.html)
- [Example Implementations (GitHub)](https://github.com/allisonwang-db/pyspark-data-sources)
- Production implementation: `src/nemweb_datasource_arrow.py`


## Warm-up: Hello World Data Source (2 minutes)

Let's start with the **simplest possible** custom data source to understand the API.

A custom data source needs just **3 components**:
1. `name()` - The format string for `spark.read.format("name")`
2. `schema()` - What columns and types your data has
3. `reader()` - Creates a reader that fetches the data


### Key Insight

A custom data source consists of two classes:
- **DataSource class**: Declares the name and schema, and creates a reader
- **DataSourceReader class**: Has a `read()` method that yields PyArrow `RecordBatch` objects

For Serverless/Spark Connect compatibility, the `read()` method must yield **PyArrow RecordBatch** objects (not Row objects or tuples).
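
The three components described above can be sketched as a minimal "hello world" source. This is an illustrative sketch, not the workshop's own warm-up code: the class names `HelloDataSource` and `HelloReader`, the format name `"hello"`, and the two-column schema are assumptions, and the `ImportError` fallback exists only so the snippet can be imported on a machine without PySpark (it assumes a runtime where `pyspark.sql.datasource` is available, such as PySpark 4.0 or a recent Databricks runtime).

```python
try:
    from pyspark.sql.datasource import DataSource, DataSourceReader
except ImportError:
    # Stand-ins (assumption): only so this sketch imports without PySpark installed.
    class DataSource:
        def __init__(self, options):
            self.options = options

    class DataSourceReader:
        pass


class HelloReader(DataSourceReader):
    """Yields one small Arrow batch; with no partitions() defined, Spark uses a single partition."""

    def read(self, partition):
        import pyarrow as pa  # imported inside read() so it resolves where read() runs

        yield pa.RecordBatch.from_pydict(
            {"id": [1, 2], "message": ["hello", "world"]}
        )


class HelloDataSource(DataSource):
    @classmethod
    def name(cls) -> str:
        # Format string used in spark.read.format("hello")
        return "hello"

    def schema(self) -> str:
        # A DDL string is accepted in place of a StructType
        return "id bigint, message string"

    def reader(self, schema):
        return HelloReader()
```

In a notebook with an active Spark session, this would typically be used as `spark.dataSource.register(HelloDataSource)` followed by `spark.read.format("hello").load().show()`.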

In [None]:

from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
from pyspark.sql.types import (
    StructType, StructField, StringType, DoubleType, TimestampType
)
from typing import Iterator, Tuple
from datetime import datetime, timedelta

def get_dispatchregionsum_schema() -> StructType:
    """
    Return the schema for DISPATCHREGIONSUM table.

    Reference: MMS Electricity Data Model Report - DISPATCH package
    """
    return StructType([
        # Time and identification fields
        StructField("SETTLEMENTDATE", TimestampType(), True),
        StructField("RUNNO", StringType(), True),
        StructField("REGIONID", StringType(), True),
        StructField("DISPATCHINTERVAL", StringType(), True),
        StructField("INTERVENTION", StringType(), True),

        # SOLUTION 1.1: Added measurement fields
        StructField("TOTALDEMAND", DoubleType(), True),
        StructField("AVAILABLEGENERATION", DoubleType(), True),
        StructField("AVAILABLELOAD", DoubleType(), True),
        StructField("DEMANDFORECAST", DoubleType(), True),
        StructField("DISPATCHABLEGENERATION", DoubleType(), True),
        StructField("DISPATCHABLELOAD", DoubleType(), True),
        StructField("NETINTERCHANGE", DoubleType(), True),
    ])

# Verify schema
schema = get_dispatchregionsum_schema()
print(f"Schema has {len(schema.fields)} fields (expected: 12)")
for field in schema.fields:
    print(f"  - {field.name}: {field.dataType}")


## Part 2: Partitioning and Data Reading

Spark achieves parallelism by dividing work into **partitions**. Each partition can be processed independently on different cores/nodes.

For NEMWEB, we'll create one partition per NEM region:
- NSW1 (New South Wales)
- QLD1 (Queensland)  
- SA1 (South Australia)
- VIC1 (Victoria)
- TAS1 (Tasmania)
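
Partition planning is ordinary Python that runs on the driver, so it can be prototyped without Spark. The sketch below shows one possible refinement of the per-region plan: splitting each region by calendar day as well, so a multi-day range yields more parallel work units. `RegionDayPartition` and `plan_partitions` are hypothetical names for illustration; the solution in this notebook uses one partition per region only.

```python
from dataclasses import dataclass
from datetime import date, timedelta
from itertools import product


@dataclass(frozen=True)
class RegionDayPartition:
    """Hypothetical work unit: one NEM region for one calendar day."""
    region: str
    day: str  # YYYY-MM-DD


def plan_partitions(regions: str, start_date: str, end_date: str) -> list[RegionDayPartition]:
    """Expand a comma-separated region list and an inclusive date range."""
    # Tolerate spaces and empty entries in the option string
    region_list = [r.strip() for r in regions.split(",") if r.strip()]
    start = date.fromisoformat(start_date)
    end = date.fromisoformat(end_date)
    days = [start + timedelta(days=i) for i in range((end - start).days + 1)]
    return [RegionDayPartition(r, d.isoformat()) for r, d in product(region_list, days)]


plan = plan_partitions("NSW1, VIC1", "2024-01-01", "2024-01-02")
print(len(plan))  # 2 regions x 2 days = 4 partitions
```

Finer partitions increase parallelism but also the number of separate fetches, so the right granularity depends on how much data each unit returns.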

### Helper Functions

We provide two helper functions:
- `fetch_nemweb_current()` - Fetches real data from AEMO NEMWEB
- `parse_nemweb_to_arrow()` - Converts data to PyArrow RecordBatch (required for Serverless)
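
To give a feel for what the Arrow conversion step involves, the sketch below pivots raw CSV row dicts into typed, column-oriented lists, keeping only the schema columns. It is not the implementation of `parse_nemweb_to_arrow()`: the names `COLUMN_TYPES`, `coerce`, and `rows_to_columns`, the three-column subset, and the timestamp layout are all assumptions made for illustration.

```python
from datetime import datetime
from typing import Optional

# Subset of the 12-field schema, kept short for illustration
COLUMN_TYPES = {
    "SETTLEMENTDATE": "timestamp",
    "REGIONID": "string",
    "TOTALDEMAND": "double",
}
TIMESTAMP_FORMAT = "%Y/%m/%d %H:%M:%S"  # assumed layout of the raw timestamp strings


def coerce(value: Optional[str], kind: str):
    """Convert one raw CSV string to a Python value; blanks become None."""
    if value is None or value.strip() == "":
        return None
    if kind == "double":
        return float(value)
    if kind == "timestamp":
        return datetime.strptime(value.strip(), TIMESTAMP_FORMAT)
    return value.strip()


def rows_to_columns(rows: list[dict]) -> dict[str, list]:
    """Pivot row dicts into one list per schema column, dropping extra keys."""
    return {
        name: [coerce(row.get(name), kind) for row in rows]
        for name, kind in COLUMN_TYPES.items()
    }


raw_rows = [
    {"SETTLEMENTDATE": "2024/01/01 00:05:00", "REGIONID": "NSW1",
     "TOTALDEMAND": "7123.45", "EXTRA_FIELD": "ignored"},
    {"SETTLEMENTDATE": "2024/01/01 00:10:00", "REGIONID": "NSW1",
     "TOTALDEMAND": ""},
]
columns = rows_to_columns(raw_rows)
print(columns["TOTALDEMAND"])  # [7123.45, None]
```

With PyArrow installed, a column dict like this can be turned into a batch with `pa.RecordBatch.from_pydict(columns)`.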

In [None]:
# Import helper functions
import sys
import os

notebook_path = dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()
repo_root = str(os.path.dirname(os.path.dirname(notebook_path)))
sys.path.insert(0, f"/Workspace{repo_root}/src")

from nemweb_utils import fetch_nemweb_current, parse_nemweb_to_arrow, get_version

# Output version for debugging
print(f"nemweb_utils version: {get_version()}")

# Quick test - these helpers handle HTTP fetching and Arrow conversion
# This fetches REAL data from AEMO NEMWEB CURRENT folder (last ~7 days)
test_data = fetch_nemweb_current(
    table="DISPATCHREGIONSUM",
    region="NSW1",
    max_files=2,
    use_sample=False,  # Fetch real current files from AEMO
    debug=True  # Print debug info
)
print(f"Helper function works! Got {len(test_data)} rows")
if test_data:
    # Note: Real AEMO data has MANY more columns (~130+) than our 12-field schema
    # parse_nemweb_to_arrow() will filter to only the fields in our schema
    all_keys = list(test_data[0].keys())
    print(f"Raw CSV has {len(all_keys)} columns (AEMO includes many fields)")
    print(f"Sample row keys (first 10): {all_keys[:10]}...")
    print(f"Sample SETTLEMENTDATE: {test_data[0].get('SETTLEMENTDATE')}")
    
    # Verify key fields exist in real data
    required_fields = ["SETTLEMENTDATE", "REGIONID", "RUNNO", "TOTALDEMAND", 
                      "AVAILABLEGENERATION", "NETINTERCHANGE"]
    missing = [f for f in required_fields if f not in all_keys]
    if missing:
        print(f"⚠️  WARNING: Missing fields in real data: {missing}")
    else:
        print("✅ All required fields present in real data")

In [None]:
class NemwebPartition(InputPartition):
    """
    Represents one partition of NEMWEB data.
    Each partition handles one region's data.
    """
    def __init__(self, region: str, start_date: str, end_date: str):
        self.region = region
        self.start_date = start_date
        self.end_date = end_date


class NemwebReader(DataSourceReader):
    """
    Reader for NEMWEB data source.

    The reader has two jobs:
    1. partitions() - Plan the work (called on driver)
    2. read() - Do the work (called on workers)
    """

    def __init__(self, schema: StructType, options: dict):
        self.schema = schema
        self.options = options
        self.regions = options.get("regions", "NSW1,QLD1,SA1,VIC1,TAS1").split(",")
        # Default to two days ago: the CURRENT folder holds roughly the last 7 days
        default_date = (datetime.now() - timedelta(days=2)).strftime("%Y-%m-%d")
        self.start_date = options.get("start_date", default_date)
        self.end_date = options.get("end_date", default_date)

    def partitions(self) -> list[InputPartition]:
        """
        Plan partitions for parallel reading.

        SOLUTION 1.2a: Create one partition per region
        """
        partitions = []

        for region in self.regions:
            partition = NemwebPartition(
                region=region.strip(),
                start_date=self.start_date,
                end_date=self.end_date
            )
            partitions.append(partition)

        return partitions

    def read(self, partition: NemwebPartition):
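        """
        Read data for a single partition (runs on workers).

        SOLUTION 1.2b: Fetch data and yield PyArrow RecordBatch

        Note: only partition.region is passed to the fetch helper here;
        start_date and end_date are carried on the partition but not used.
        """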
        # Fetch data - fetches REAL current files from AEMO NEMWEB
        data = fetch_nemweb_current(
            table="DISPATCHREGIONSUM",
            region=partition.region,
            max_files=6,
            use_sample=False,  # Fetch real current files from AEMO
            debug=True  # Print debug info
        )
        
        print(f"[DEBUG] Fetched {len(data)} rows for region {partition.region}")

        # Convert to PyArrow RecordBatch (required for Serverless)
        if data:
            yield parse_nemweb_to_arrow(data, self.schema)


# Test partition planning
test_options = {"regions": "NSW1,VIC1,QLD1"}
reader = NemwebReader(schema, test_options)
partitions = reader.partitions()

print(f"Created {len(partitions)} partitions (expected: 3)")
for p in partitions:
    print(f"  - Region: {p.region}")

In [None]:
class NemwebDataSource(DataSource):
    """
    Custom PySpark Data Source for AEMO NEMWEB electricity market data.

    Usage:
        spark.dataSource.register(NemwebDataSource)
        df = spark.read.format("nemweb").option("regions", "NSW1,VIC1").load()

    Options:
        - regions: Comma-separated list of NEM regions (default: all 5)
        - start_date: Start date in YYYY-MM-DD format
        - end_date: End date in YYYY-MM-DD format
    """

    @classmethod
    def name(cls) -> str:
        """Return the format name used in spark.read.format("...")."""
        # SOLUTION 1.3a: Return the format name
        return "nemweb"

    def schema(self) -> StructType:
        """Return the schema for this data source."""
        # SOLUTION 1.3b: Return the schema
        return get_dispatchregionsum_schema()

    def reader(self, schema: StructType) -> DataSourceReader:
        """Create a reader for this data source."""
        # SOLUTION 1.3c: Create and return the reader
        return NemwebReader(schema, self.options)


# Test the complete implementation
print("NemwebDataSource defined successfully!")
print(f"  - name(): {NemwebDataSource.name()}")
ds = NemwebDataSource(options={})
print(f"  - schema(): {len(ds.schema().fields)} fields")
print(f"  - reader(): {type(ds.reader(ds.schema())).__name__}")

In [None]:
def validate_implementation():
    """Validate the custom data source implementation."""
    print("=" * 60)
    print("FINAL VALIDATION")
    print("=" * 60)

    checks = {
        "Part 1 - Schema (12 fields)": False,
        "Part 2 - Partitions": False,
        "Part 2 - Read (RecordBatch)": False,
        "Part 3 - DataSource.name()": False,
        "Part 3 - DataSource.schema()": False,
        "Part 3 - DataSource.reader()": False,
    }

    # Part 1: Schema
    schema = get_dispatchregionsum_schema()
    required = ["TOTALDEMAND", "AVAILABLEGENERATION", "NETINTERCHANGE"]
    schema_ok = (
        len(schema.fields) >= 12 and
        all(f in [field.name for field in schema.fields] for f in required)
    )
    checks["Part 1 - Schema (12 fields)"] = schema_ok

    # Part 2: Partitions
    reader = NemwebReader(schema, {"regions": "NSW1,VIC1,QLD1"})
    partitions = reader.partitions()
    checks["Part 2 - Partitions"] = partitions is not None and len(partitions) == 3

    # Part 2: Read (should yield RecordBatch)
    if partitions:
        try:
            import pyarrow as pa
            test_partition = NemwebPartition("NSW1", "2024-01-01", "2024-01-01")
            result = list(reader.read(test_partition))
            # Check if we got a RecordBatch
            checks["Part 2 - Read (RecordBatch)"] = (
                len(result) > 0 and 
                isinstance(result[0], pa.RecordBatch)
            )
        except Exception as e:
            print(f"Read check error: {e}")

    # Part 3: DataSource
    try:
        checks["Part 3 - DataSource.name()"] = NemwebDataSource.name() == "nemweb"
    except Exception as e:
        print(f"name() check error: {e}")

    try:
        ds = NemwebDataSource(options={})
        checks["Part 3 - DataSource.schema()"] = len(ds.schema().fields) >= 12
    except Exception as e:
        print(f"schema() check error: {e}")

    try:
        ds = NemwebDataSource(options={})
        checks["Part 3 - DataSource.reader()"] = ds.reader(schema) is not None
    except Exception as e:
        print(f"reader() check error: {e}")

    # Print results
    print()
    for check, passed in checks.items():
        status = "✅" if passed else "❌"
        print(f"  {status} {check}")

    all_passed = all(checks.values())
    print()
    if all_passed:
        print("=" * 60)
        print("🎉 All checks passed!")
        print("=" * 60)

    return all_passed

validate_implementation()

In [None]:
# Register and use the custom data source we built
spark.dataSource.register(NemwebDataSource)

# Calculate date range (use 2 days ago since CURRENT folder has ~7 days of data)
from datetime import datetime, timedelta
target_date = (datetime.now() - timedelta(days=2)).strftime("%Y-%m-%d")

# Read data using our custom datasource
df = (spark.read
      .format("nemweb")
      .option("regions", "NSW1")
      .option("start_date", target_date)
      .option("end_date", target_date)
      .load())

# Display results
print(f"Row count: {df.count()}")
display(df.limit(5))

## Summary

You built a custom PySpark data source that fetches **live data** from NEMWEB!

| Component | Purpose |
|-----------|---------|
| `DataSource.name()` | Format string for `spark.read.format(...)` |
| `DataSource.schema()` | Define output columns and types |
| `DataSource.reader()` | Create reader with options |
| `DataSourceReader.partitions()` | Plan parallel work units |
| `DataSourceReader.read()` | Fetch data and yield PyArrow RecordBatch |

**Key Requirement:** Serverless/Spark Connect requires `read()` to yield **PyArrow RecordBatch** objects.
The `parse_nemweb_to_arrow()` helper function handles all the type conversion for you.

## Next Steps

Proceed to **Exercise 2** to integrate your data source with Lakeflow Pipelines.