# Solution: Exercise 1 - Building a Custom PySpark Data Source

**Time:** 15 minutes

This notebook contains the complete solutions for Exercise 1.

## Learning Objectives
1. Understand the Python Data Source API architecture
2. Implement schema definition for external data
3. Create partition-aware data reading for parallelism
4. Register and use your custom data source

## Reference
- [Python Data Source API Docs](https://docs.databricks.com/en/pyspark/datasources.html)
- [Example Implementations (GitHub)](https://github.com/allisonwang-db/pyspark-data-sources)
- Production implementation: `src/nemweb_datasource_arrow.py`


## Warm-up: Hello World Data Source (2 minutes)

Let's start with the **simplest possible** custom data source to understand the API.

A custom data source needs just **3 components**:
1. `name()` - The format string for `spark.read.format("name")`
2. `schema()` - What columns and types your data has
3. `reader()` - Creates a reader that fetches the data


### Key Insight

A custom data source boils down to two classes:
- **DataSource class**: Declares the name, schema, and creates a reader
- **DataSourceReader class**: Has a `read()` method that yields tuples

The tuples you yield **must match** the schema field order.
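
Because rows are matched to the schema by position rather than by name, a quick local check can catch ordering mistakes before Spark does. The sketch below is plain Python and not part of the Data Source API; `EXPECTED_FIELDS` and `check_row` are hypothetical names that mirror the `hello` schema used in this warm-up.

```python
# Hypothetical stand-in for the "hello" schema: (field name, Python type) in schema order.
EXPECTED_FIELDS = [("id", int), ("message", str)]


def check_row(row: tuple) -> list[str]:
    """Return the problems found when matching a row to the schema by position."""
    problems = []
    if len(row) != len(EXPECTED_FIELDS):
        problems.append(f"expected {len(EXPECTED_FIELDS)} values, got {len(row)}")
        return problems
    for (name, py_type), value in zip(EXPECTED_FIELDS, row):
        # None is accepted because the fields are nullable.
        if value is not None and not isinstance(value, py_type):
            problems.append(
                f"{name}: expected {py_type.__name__}, got {type(value).__name__}"
            )
    return problems


good = (0, "Hello, World #0!")      # same order as the schema
swapped = ("Hello, World #0!", 0)   # values in the wrong order

print(check_row(good))
print(check_row(swapped))
```

The swapped row is reported once per misplaced value, which is the kind of mismatch that would otherwise only surface when Spark tries to convert the row.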

In [None]:

from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

from databricks.sdk.runtime import spark, display

class HelloWorldDataSource(DataSource):
    """Minimal data source that generates greeting messages."""

    @classmethod
    def name(cls) -> str:
        return "hello"

    def schema(self) -> StructType:
        return StructType([
            StructField("id", IntegerType()),
            StructField("message", StringType()),
        ])

    def reader(self, schema: StructType) -> DataSourceReader:
        return HelloWorldReader(self.options)


class HelloWorldReader(DataSourceReader):
    def __init__(self, options: dict):
        self.count = int(options.get("count", 5))

    def read(self, partition):
        for i in range(self.count):
            yield (i, f"Hello, World #{i}!")

spark.dataSource.register(HelloWorldDataSource)
df = spark.read.format("hello").option("count", 3).load()
display(df)


---
## Now Let's Build a Real One: NEMWEB Data Source

AEMO NEMWEB publishes Australia's National Electricity Market data.
We'll create a data source that fetches **live data** from their HTTP API.

Key differences from Hello World:
- **Real schema** based on AEMO's MMS data model
- **Partitions** for parallel reading (one per NEM region)
- **HTTP fetching** from https://www.nemweb.com.au/

> **Reference:** See how the production code does it in `src/nemweb_datasource_arrow.py`

---
## Part 1: Define the Schema (3 minutes)

The DISPATCHREGIONSUM table contains regional dispatch summary data.
Let's define its schema using Spark types.

> **Reference:** [MMS Data Model - DISPATCH package](https://nemweb.com.au/Reports/Current/MMSDataModelReport/)


In [None]:

from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
from pyspark.sql.types import (
    StructType, StructField, StringType, DoubleType, TimestampType
)
from typing import Iterator, Tuple
from datetime import datetime, timedelta

def get_dispatchregionsum_schema() -> StructType:
    """
    Return the schema for DISPATCHREGIONSUM table.

    Reference: MMS Electricity Data Model Report - DISPATCH package
    """
    return StructType([
        # Time and identification fields
        StructField("SETTLEMENTDATE", TimestampType(), True),
        StructField("RUNNO", StringType(), True),
        StructField("REGIONID", StringType(), True),
        StructField("DISPATCHINTERVAL", StringType(), True),
        StructField("INTERVENTION", StringType(), True),

        # SOLUTION 1.1: Added measurement fields
        StructField("TOTALDEMAND", DoubleType(), True),
        StructField("AVAILABLEGENERATION", DoubleType(), True),
        StructField("AVAILABLELOAD", DoubleType(), True),
        StructField("DEMANDFORECAST", DoubleType(), True),
        StructField("DISPATCHABLEGENERATION", DoubleType(), True),
        StructField("DISPATCHABLELOAD", DoubleType(), True),
        StructField("NETINTERCHANGE", DoubleType(), True),
    ])

# Verify schema
schema = get_dispatchregionsum_schema()
print(f"Schema has {len(schema.fields)} fields (expected: 12)")
for field in schema.fields:
    print(f"  - {field.name}: {field.dataType}")


## Part 2: Add Partitioning and Data Reading (5 minutes)

Spark achieves parallelism by dividing work into **partitions**.
Each partition can be processed independently on different cores/nodes.

For NEMWEB, we'll create one partition per NEM region:
- NSW1 (New South Wales)
- QLD1 (Queensland)
- SA1 (South Australia)
- VIC1 (Victoria)
- TAS1 (Tasmania)

You need to implement two methods:
1. **`partitions()`** - Plan the work (runs on driver)
2. **`read()`** - Do the work (runs on workers)
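
The split between planning and reading can be simulated without Spark. In this sketch, `RegionPartition`, `plan_partitions`, and `read_partition` are hypothetical stand-ins for `InputPartition`, `partitions()`, and `read()`, and a thread pool plays the role of the workers. It is an analogy for the control flow only, not a picture of how Spark schedules tasks.

```python
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Iterator


@dataclass
class RegionPartition:
    """Hypothetical stand-in for an InputPartition: a small description of the work."""
    region: str


def plan_partitions(regions_option: str) -> list[RegionPartition]:
    """'Driver' side: decide what to read, without reading anything yet."""
    return [RegionPartition(r.strip()) for r in regions_option.split(",") if r.strip()]


def read_partition(partition: RegionPartition) -> Iterator[tuple]:
    """'Worker' side: produce the rows for one partition (fake rows, no HTTP)."""
    for interval in range(2):
        yield (partition.region, interval)


planned = plan_partitions("NSW1, VIC1,QLD1")

# Each partition is read independently; threads stand in for Spark workers here.
with ThreadPoolExecutor(max_workers=3) as pool:
    per_partition = list(pool.map(lambda p: list(read_partition(p)), planned))

rows = [row for chunk in per_partition for row in chunk]
print(len(planned), "partitions,", len(rows), "rows")
```

Keeping the partition object small (here, just a region name) reflects the design choice in this exercise: the plan describes the work, and all fetching happens inside the read step.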

### Helper Functions (Provided)

NEMWEB data comes as CSV files inside ZIP archives with a complex multi-record format.
We've provided helper functions so you can focus on the **Data Source API**.

**Note:** The helper functions fetch **real data** from the AEMO NEMWEB CURRENT folder
(roughly the last 7 days of 5-minute interval files), so you are working with actual
electricity market data.

**Important:** Real AEMO DISPATCHREGIONSUM data contains 130+ columns, but our schema
defines only 12 key fields. The `parse_nemweb_csv()` function filters each row down
to the fields in your schema: extra columns are ignored, and missing columns
are set to `None` (which is fine since all fields are nullable).
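
The filtering described above amounts to projecting each parsed CSV row onto the schema's field names. The sketch below illustrates the idea only; it is not the actual `parse_nemweb_csv()` implementation, it omits type conversion, and `project_row`, the shortened field list, and the sample values are made up for illustration.

```python
# Shortened, hypothetical field list; the exercise schema has 12 fields.
SCHEMA_FIELDS = ["SETTLEMENTDATE", "REGIONID", "TOTALDEMAND", "NETINTERCHANGE"]


def project_row(raw: dict, field_names: list[str]) -> tuple:
    """Keep only schema fields, in schema order; absent columns become None."""
    return tuple(raw.get(name) for name in field_names)


# Made-up raw row: one column that is not in the schema, one schema column missing.
raw_row = {
    "REGIONID": "NSW1",
    "SETTLEMENTDATE": "2024/01/01 00:05:00",
    "TOTALDEMAND": "7000.5",
    "SOME_EXTRA_COLUMN": "ignored",
}

print(project_row(raw_row, SCHEMA_FIELDS))
```

The result follows the schema order rather than the dictionary order, drops `SOME_EXTRA_COLUMN`, and carries `None` for the missing `NETINTERCHANGE`.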


In [None]:

# Import helper functions
import sys
import os

notebook_path = dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()
repo_root = str(os.path.dirname(os.path.dirname(notebook_path)))
sys.path.insert(0, f"/Workspace{repo_root}/src")

from nemweb_utils import fetch_nemweb_current, parse_nemweb_csv, get_version

# Output version for debugging
print(f"nemweb_utils version: {get_version()}")

# Quick test - these helpers handle HTTP fetching and CSV parsing
# This fetches REAL data from AEMO NEMWEB CURRENT folder (last ~7 days)
test_data = fetch_nemweb_current(
    table="DISPATCHREGIONSUM",
    region="NSW1",
    max_files=2,
    use_sample=False,  # Fetch real current files from AEMO
    debug=True  # Print debug info
)
print(f"Helper function works! Got {len(test_data)} rows")
if test_data:
    # Note: Real AEMO data has MANY more columns (~130+) than our 12-field schema
    # parse_nemweb_csv() will filter to only the fields in our schema
    all_keys = list(test_data[0].keys())
    print(f"Raw CSV has {len(all_keys)} columns (AEMO includes many fields)")
    print(f"Sample row keys (first 10): {all_keys[:10]}...")
    print(f"Sample SETTLEMENTDATE: {test_data[0].get('SETTLEMENTDATE')}")
    
    # Verify key fields exist in real data
    required_fields = ["SETTLEMENTDATE", "REGIONID", "RUNNO", "TOTALDEMAND", 
                      "AVAILABLEGENERATION", "NETINTERCHANGE"]
    missing = [f for f in required_fields if f not in all_keys]
    if missing:
        print(f"⚠️  WARNING: Missing fields in real data: {missing}")
    else:
        print("✅ All required fields present in real data")


In [None]:

class NemwebPartition(InputPartition):
    """
    Represents one partition of NEMWEB data.
    Each partition handles one region's data.
    """
    def __init__(self, region: str, start_date: str, end_date: str):
        self.region = region
        self.start_date = start_date
        self.end_date = end_date


class NemwebReader(DataSourceReader):
    """
    Reader for NEMWEB data source.

    The reader has two jobs:
    1. partitions() - Plan the work (called on driver)
    2. read() - Do the work (called on workers)
    """

    def __init__(self, schema: StructType, options: dict):
        self.schema = schema
        self.options = options
        self.regions = options.get("regions", "NSW1,QLD1,SA1,VIC1,TAS1").split(",")
        # Default to two days ago: the most recent day may not be available yet.
        default_date = (datetime.now() - timedelta(days=2)).strftime("%Y-%m-%d")
        self.start_date = options.get("start_date", default_date)
        self.end_date = options.get("end_date", default_date)

    def partitions(self) -> list[InputPartition]:
        """
        Plan partitions for parallel reading.

        SOLUTION 1.2a: Create one partition per region
        """
        partitions = []

        for region in self.regions:
            partition = NemwebPartition(
                region=region.strip(),
                start_date=self.start_date,
                end_date=self.end_date
            )
            partitions.append(partition)

        return partitions

    def read(self, partition: NemwebPartition) -> Iterator[Tuple]:
        """
        Read data for a single partition (runs on workers).

        SOLUTION 1.2b: Fetch and parse NEMWEB data
        """
        # Fetch data - fetches REAL current files from AEMO NEMWEB
        data = fetch_nemweb_current(
            table="DISPATCHREGIONSUM",
            region=partition.region,
            max_files=6,
            use_sample=False,  # Fetch real current files from AEMO
            debug=True  # Print debug info
        )
        
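        # When Spark runs read() on workers, these prints go to the worker logs,
        # not to the notebook output.
        print(f"[DEBUG] Fetched {len(data)} rows for region {partition.region}")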
        if data:
            print(f"[DEBUG] Sample row keys: {list(data[0].keys())}")
            print(f"[DEBUG] Sample SETTLEMENTDATE value: {data[0].get('SETTLEMENTDATE')}")

        # Convert to tuples matching schema
        tuple_count = 0
        for row_tuple in parse_nemweb_csv(data, self.schema):
            tuple_count += 1
            yield row_tuple
        
        print(f"[DEBUG] Yielded {tuple_count} tuples for region {partition.region}")


# Test partition planning
test_options = {"regions": "NSW1,VIC1,QLD1"}
reader = NemwebReader(schema, test_options)
partitions = reader.partitions()

print(f"Created {len(partitions)} partitions (expected: 3)")
for p in partitions:
    print(f"  - Region: {p.region}")


## Part 3: Complete the Data Source (5 minutes)

Now bring it all together! The DataSource class is the entry point that Spark calls.


In [None]:

def validate_implementation():
    """Validate the custom data source implementation."""
    print("=" * 60)
    print("FINAL VALIDATION")
    print("=" * 60)

    checks = {
        "Part 1 - Schema (12 fields)": False,
        "Part 2 - Partitions": False,
        "Part 2 - Read": False,
        "Part 3 - DataSource.name()": False,
        "Part 3 - DataSource.schema()": False,
        "Part 3 - DataSource.reader()": False,
    }

    # Part 1: Schema
    schema = get_dispatchregionsum_schema()
    required = ["TOTALDEMAND", "AVAILABLEGENERATION", "NETINTERCHANGE"]
    schema_ok = (
        len(schema.fields) >= 12 and
        all(f in [field.name for field in schema.fields] for f in required)
    )
    checks["Part 1 - Schema (12 fields)"] = schema_ok

    # Part 2: Partitions
    reader = NemwebReader(schema, {"regions": "NSW1,VIC1,QLD1"})
    partitions = reader.partitions()
    checks["Part 2 - Partitions"] = partitions is not None and len(partitions) == 3

    # Part 2: Read
    if partitions:
        try:
            test_partition = NemwebPartition("NSW1", "2024-01-01", "2024-01-01")
            result = list(reader.read(test_partition))
            checks["Part 2 - Read"] = len(result) > 0
        except Exception as exc:  # report the failure instead of hiding it
            print(f"  Part 2 - Read raised: {exc!r}")

    # Part 3: DataSource
    try:
        checks["Part 3 - DataSource.name()"] = NemwebDataSource.name() == "nemweb"
    except Exception as exc:
        print(f"  Part 3 - name() raised: {exc!r}")

    try:
        ds = NemwebDataSource(options={})
        checks["Part 3 - DataSource.schema()"] = len(ds.schema().fields) >= 12
    except Exception as exc:
        print(f"  Part 3 - schema() raised: {exc!r}")

    try:
        ds = NemwebDataSource(options={})
        checks["Part 3 - DataSource.reader()"] = ds.reader(schema) is not None
    except Exception as exc:
        print(f"  Part 3 - reader() raised: {exc!r}")

    # Print results
    print()
    for check, passed in checks.items():
        status = "✅" if passed else "❌"
        print(f"  {status} {check}")

    all_passed = all(checks.values())
    print()
    if all_passed:
        print("=" * 60)
        print("🎉 All checks passed!")
        print("=" * 60)

    return all_passed

validate_implementation()


## Part 4: Register and Test


In [None]:

# Note: The tuple-based datasource works for learning, but Spark Connect
# (Serverless) has strict datetime serialization requirements. For production,
# use the Arrow datasource which uses PyArrow RecordBatches to avoid serialization issues.

# Option 1: Test tuple-based datasource (commented out - may have Spark Connect issues)
# spark.dataSource.register(NemwebDataSource)
# df = (spark.read
#       .format("nemweb")
#       .option("regions", "NSW1")
#       .load())

# Option 2: Use the production Arrow datasource (compatible with Spark Connect)
from nemweb_datasource_arrow import NemwebArrowDataSource
from datetime import datetime, timedelta

spark.dataSource.register(NemwebArrowDataSource)

# Use 2 days ago for CURRENT files. CURRENT holds ~7 days of data, but because of
# timezone differences yesterday's files may not be available yet.
target_date = (datetime.now() - timedelta(days=2)).strftime("%Y-%m-%d")

# Read data using the Arrow datasource - fetches real current files
df = (spark.read
      .format("nemweb_arrow")
      .option("table", "DISPATCHREGIONSUM")
      .option("regions", "NSW1")  # Single region for speed
      .option("start_date", target_date)  # Use recent date for CURRENT files
      .option("end_date", target_date)
      .load())

# Display results
print(f"Row count: {df.count()}")
display(df.limit(5))


## Summary

You built a custom PySpark data source that fetches **live data** from NEMWEB!

| Component | Purpose |
|-----------|---------|
| `DataSource.name()` | Format string for `spark.read.format(...)` |
| `DataSource.schema()` | Define output columns and types |
| `DataSource.reader()` | Create reader with options |
| `DataSourceReader.partitions()` | Plan parallel work units |
| `DataSourceReader.read()` | Fetch and yield data (runs on workers) |

## Compare to Production

Your implementation is a simplified version. The production code in
`src/nemweb_datasource_arrow.py` adds:
- **PyArrow RecordBatch** for zero-copy transfer (Serverless compatible)
- **Volume mode** with parallel downloads to UC Volume
- **Multiple tables** (DISPATCHREGIONSUM, DISPATCHPRICE, TRADINGPRICE)
- **Retry logic** with exponential backoff

## Next Steps

Proceed to **Exercise 2** to integrate your data source with Lakeflow Pipelines.