# Solution: Exercise 1 - Building a Custom PySpark Data Source
This notebook contains the complete solutions for Exercise 1.


## Warm-up: Hello World Data Source


In [None]:

from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

from databricks.sdk.runtime import spark, display

class HelloWorldDataSource(DataSource):
    """Toy data source that emits numbered "Hello, World" rows.

    Registered under the format name ``hello``. Its options (including
    ``count``) are handed unchanged to ``HelloWorldReader``.
    """

    @classmethod
    def name(cls) -> str:
        # Format string used with spark.read.format(...).
        format_name = "hello"
        return format_name

    def schema(self) -> StructType:
        # Two columns; StructField's default makes both nullable.
        columns = (("id", IntegerType()), ("message", StringType()))
        return StructType([StructField(col, dtype) for col, dtype in columns])

    def reader(self, schema: StructType) -> DataSourceReader:
        # The requested schema is not consulted: the reader always emits (id, message).
        return HelloWorldReader(options=self.options)


class HelloWorldReader(DataSourceReader):
    """Yield ``count`` greeting rows; ``count`` comes from the options (default 5)."""

    def __init__(self, options: dict):
        raw_count = options.get("count", 5)
        # Option values may arrive as strings, so coerce to int.
        self.count = int(raw_count)

    def read(self, partition):
        # No partitions() override here, so ``partition`` is not used.
        yield from ((n, f"Hello, World #{n}!") for n in range(self.count))

# Register the source, then read three greeting rows through it.
spark.dataSource.register(HelloWorldDataSource)
df = (
    spark.read
    .format("hello")
    .option("count", 3)
    .load()
)
display(df)


---
## Solution 1.1: Complete Schema Definition


In [None]:

from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
from pyspark.sql.types import (
    StructType, StructField, StringType, DoubleType, TimestampType
)
from typing import Iterator, Tuple
from datetime import datetime, timedelta

def get_dispatchregionsum_schema() -> StructType:
    """
    Build the Spark schema for the DISPATCHREGIONSUM table.

    Five identification columns (SETTLEMENTDATE as a timestamp, the rest as
    strings) are followed by seven DoubleType measurement columns. Every
    column is nullable.

    Reference: MMS Electricity Data Model Report - DISPATCH package
    """
    # Time and identification fields
    id_columns = [
        ("SETTLEMENTDATE", TimestampType()),
        ("RUNNO", StringType()),
        ("REGIONID", StringType()),
        ("DISPATCHINTERVAL", StringType()),
        ("INTERVENTION", StringType()),
    ]
    # SOLUTION 1.1: Added measurement fields (all DoubleType)
    measure_columns = [
        "TOTALDEMAND",
        "AVAILABLEGENERATION",
        "AVAILABLELOAD",
        "DEMANDFORECAST",
        "DISPATCHABLEGENERATION",
        "DISPATCHABLELOAD",
        "NETINTERCHANGE",
    ]
    fields = [StructField(col, dtype, True) for col, dtype in id_columns]
    fields.extend(StructField(col, DoubleType(), True) for col in measure_columns)
    return StructType(fields)

# Verify schema
# Build the schema once (later cells reuse `schema`) and list every column.
schema = get_dispatchregionsum_schema()
print(f"Schema has {len(schema.fields)} fields (expected: 12)")
for field in schema:  # iterating a StructType yields its StructFields
    print(f"  - {field.name}: {field.dataType}")


### Option: Inline Datasource Implementation (For Debugging)

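The following cell contains an inline copy of the helper functions the data source uses
(`fetch_nemweb_current`, `parse_nemweb_csv`, `get_version`). Having this code in the notebook
gives Databricks Assistant the context it needs to help debug timestamp conversion issues.

**To use:** Uncomment the code in the next cell, then comment out the `from nemweb_utils import ...` line in the Solution 1.2 cell.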

In [None]:
# Complete inline datasource implementation
# Uncomment this block and comment out the package import below to use inline code
# This gives Databricks Assistant full context to help debug timestamp issues

# import csv
# import io
# import re
# import time
# import zipfile
# from datetime import datetime, timedelta
# from typing import Iterator, Tuple, Optional
# from urllib.request import urlopen, Request
# from urllib.error import HTTPError, URLError
# from pyspark.sql.types import StructType, TimestampType, StringType, DoubleType, IntegerType
# 
# # Constants
# NEMWEB_CURRENT_URL = "https://www.nemweb.com.au/REPORTS/CURRENT"
# REQUEST_TIMEOUT = 30
# USER_AGENT = "DatabricksNemwebLab/1.0"
# MAX_RETRIES = 3
# DEBUG_LOG_PATH = "/tmp/nemweb_debug.log"
# 
# TABLE_CONFIG = {
#     "DISPATCHREGIONSUM": {
#         "folder": "DispatchIS_Reports",
#         "file_prefix": "DISPATCHIS",
#         "record_type": "DISPATCH,REGIONSUM"
#     },
# }
# 
# def _debug_log(msg: str) -> None:
#     try:
#         with open(DEBUG_LOG_PATH, "a") as f:
#             f.write(f"{msg}\n")
#     except Exception:
#         pass
# 
# def _get_sample_data(table: str, region: Optional[str] = None) -> list[dict]:
#     sample = [
#         {"SETTLEMENTDATE": "2024-01-01 00:05:00", "RUNNO": "1", "REGIONID": "NSW1",
#          "DISPATCHINTERVAL": "1", "INTERVENTION": "0", "TOTALDEMAND": "7500.5",
#          "AVAILABLEGENERATION": "8000.0", "NETINTERCHANGE": "-200.5"},
#         {"SETTLEMENTDATE": "2024-01-01 00:05:00", "RUNNO": "1", "REGIONID": "VIC1",
#          "DISPATCHINTERVAL": "1", "INTERVENTION": "0", "TOTALDEMAND": "5200.3",
#          "AVAILABLEGENERATION": "5500.0", "NETINTERCHANGE": "150.2"},
#     ]
#     if region:
#         sample = [row for row in sample if row["REGIONID"] == region]
#     return sample
# 
# def fetch_with_retry(url: str, max_retries: int = MAX_RETRIES) -> bytes:
#     last_error = None
#     for attempt in range(max_retries):
#         try:
#             request = Request(url, headers={"User-Agent": USER_AGENT})
#             with urlopen(request, timeout=REQUEST_TIMEOUT) as response:
#                 return response.read()
#         except (HTTPError, URLError) as e:
#             last_error = e
#             if isinstance(e, HTTPError) and e.code == 404:
#                 raise
#             if attempt < max_retries - 1:
#                 time.sleep(1.0 * (2 ** attempt))
#     raise last_error
# 
# def _parse_nemweb_csv_file(csv_file, record_type: str = None) -> list[dict]:
#     text = csv_file.read().decode("utf-8")
#     if not record_type:
#         return list(csv.DictReader(io.StringIO(text)))
#     rows, headers = [], None
#     for parts in csv.reader(io.StringIO(text)):
#         if not parts:
#             continue
#         row_type = parts[0].strip().upper()
#         if row_type == "I" and len(parts) > 2:
#             if f"{parts[1]},{parts[2]}" == record_type:
#                 headers = parts[4:]
#         elif row_type == "D" and headers and len(parts) > 2:
#             if f"{parts[1]},{parts[2]}" == record_type:
#                 values = parts[4:]
#                 row_dict = dict(zip(headers, values))
#                 for header in headers[len(values):]:
#                     row_dict[header] = None
#                 rows.append(row_dict)
#     return rows
# 
# def _fetch_and_extract_zip(url: str, record_type: str = None) -> list[dict]:
#     raw_data = fetch_with_retry(url)
#     zip_data = io.BytesIO(raw_data)
#     rows = []
#     with zipfile.ZipFile(zip_data) as zf:
#         for name in zf.namelist():
#             if name.endswith((".zip", ".ZIP")):
#                 with zf.open(name) as nested_zip_file:
#                     nested_zip_data = io.BytesIO(nested_zip_file.read())
#                     with zipfile.ZipFile(nested_zip_data) as nested_zf:
#                         for nested_name in nested_zf.namelist():
#                             if nested_name.endswith((".CSV", ".csv")):
#                                 with nested_zf.open(nested_name) as csv_file:
#                                     rows.extend(_parse_nemweb_csv_file(csv_file, record_type))
#             elif name.endswith((".CSV", ".csv")):
#                 with zf.open(name) as csv_file:
#                     rows.extend(_parse_nemweb_csv_file(csv_file, record_type))
#     return rows
# 
# def fetch_nemweb_current(table: str, region: Optional[str] = None, max_files: int = 6,
#                          use_sample: bool = False, debug: bool = False) -> list[dict]:
#     if use_sample:
#         return _get_sample_data(table, region)
#     if table not in TABLE_CONFIG:
#         raise ValueError(f"Unsupported table: {table}")
#     config = TABLE_CONFIG[table]
#     current_url = f"{NEMWEB_CURRENT_URL}/{config['folder']}/"
#     try:
#         request = Request(current_url, headers={"User-Agent": USER_AGENT})
#         with urlopen(request, timeout=REQUEST_TIMEOUT) as response:
#             html = response.read().decode('utf-8')
#     except (HTTPError, URLError) as e:
#         if debug:
#             print(f"[NEMWEB] ERROR: {e}")
#         raise
#     pattern = rf'(PUBLIC_{config["file_prefix"]}_\d{{12}}_\d+\.zip)'
#     matches = sorted(set(re.findall(pattern, html, re.IGNORECASE)), reverse=True)[:max_files]
#     rows = []
#     for filename in matches:
#         url = f"{NEMWEB_CURRENT_URL}/{config['folder']}/{filename}"
#         try:
#             data = _fetch_and_extract_zip(url, config.get("record_type"))
#             if region:
#                 data = [row for row in data if row.get("REGIONID") == region]
#             rows.extend(data)
#         except Exception as e:
#             if debug:
#                 print(f"[NEMWEB] ERROR fetching {filename}: {e}")
#     return rows
# 
# # Timestamp parsing and conversion functions
# def _parse_timestamp_value(ts_str: str) -> Optional[datetime]:
#     if not ts_str:
#         return None
#     ts_str = str(ts_str).strip().strip('"').strip("'").strip()
#     if not ts_str:
#         return None
#     for fmt in ["%Y/%m/%d %H:%M:%S", "%Y-%m-%d %H:%M:%S", "%Y/%m/%d %H:%M", "%Y-%m-%d %H:%M"]:
#         try:
#             return datetime.strptime(ts_str, fmt)
#         except ValueError:
#             continue
#     return None
# 
# def _to_python_scalar(v: any, spark_type: any = None) -> any:
#     if v is None:
#         return None
#     if isinstance(v, datetime):
#         if v.tzinfo is not None:
#             import datetime as dt
#             return v.astimezone(dt.timezone.utc).replace(tzinfo=None)
#         return v
#     try:
#         import pandas as pd
#         if isinstance(v, pd.Timestamp):
#             return None if pd.isna(v) else v.to_pydatetime()
#     except Exception:
#         pass
#     try:
#         import numpy as np
#         if isinstance(v, (np.datetime64,)):
#             if np.isnat(v):
#                 return None
#             try:
#                 import pandas as pd
#                 ts = pd.to_datetime(v, utc=False)
#                 return None if pd.isna(ts) else ts.to_pydatetime()
#             except Exception:
#                 try:
#                     ts_str = np.datetime_as_string(v, unit='us', timezone='naive')
#                     return datetime.fromisoformat(ts_str.replace('Z', ''))
#                 except Exception:
#                     return None
#         if isinstance(v, (np.bool_, np.integer, np.floating)):
#             return bool(v) if isinstance(v, np.bool_) else (int(v) if isinstance(v, np.integer) else float(v))
#     except Exception:
#         pass
#     if spark_type and isinstance(spark_type, TimestampType):
#         if isinstance(v, datetime):
#             return v
#         if v is None:
#             return None
#         if isinstance(v, str):
#             parsed = _parse_timestamp_value(v)
#             return parsed if parsed is not None else None
#         return None
#     return v if isinstance(v, str) else v
# 
# def _convert_value(value, spark_type):
#     if value is None or value == "":
#         return None
#     try:
#         str_value = str(value).strip()
#     except Exception:
#         return None
#     if str_value == "":
#         return None
#     if isinstance(spark_type, TimestampType):
#         parsed_ts = _parse_timestamp_value(str_value)
#         result = _to_python_scalar(parsed_ts, spark_type)
#         if result is not None and not isinstance(result, datetime):
#             if isinstance(result, str):
#                 result = _parse_timestamp_value(result)
#             else:
#                 return None
#         return result
#     elif isinstance(spark_type, StringType):
#         return str_value.replace("/", "-") if "/" in str_value else str_value
#     elif isinstance(spark_type, DoubleType):
#         try:
#             return _to_python_scalar(float(str_value), spark_type)
#         except (ValueError, TypeError):
#             return None
#     elif isinstance(spark_type, IntegerType):
#         try:
#             return _to_python_scalar(int(float(str_value)), spark_type)
#         except (ValueError, TypeError):
#             return None
#     return str_value
# 
# def _validate_tuple_types(tuple_values: list, schema) -> Optional[tuple]:
#     validated = []
#     for field, value in zip(schema.fields, tuple_values):
#         if isinstance(field.dataType, TimestampType):
#             if value is None:
#                 validated.append(None)
#             elif isinstance(value, datetime):
#                 if value.tzinfo is not None:
#                     import datetime as dt
#                     validated.append(value.astimezone(dt.timezone.utc).replace(tzinfo=None))
#                 else:
#                     validated.append(value)
#             else:
#                 converted_value = None
#                 try:
#                     converted_value = _parse_timestamp_value(value) if isinstance(value, str) else _to_python_scalar(value, TimestampType())
#                     if converted_value is not None and isinstance(converted_value, datetime):
#                         if converted_value.tzinfo is not None:
#                             import datetime as dt
#                             converted_value = converted_value.astimezone(dt.timezone.utc).replace(tzinfo=None)
#                         validated.append(converted_value)
#                     else:
#                         validated.append(None)
#                 except Exception:
#                     validated.append(None)
#         else:
#             validated.append(value)
#     result = tuple(validated)
#     for field, val in zip(schema.fields, result):
#         if isinstance(field.dataType, TimestampType):
#             if val is not None and not isinstance(val, datetime):
#                 return None
#     return result
# 
# def parse_nemweb_csv(data: list[dict], schema) -> Iterator[Tuple]:
#     _debug_log(f"=== parse_nemweb_csv started ===")
#     _debug_log(f"Processing {len(data)} rows")
#     field_names = [field.name for field in schema.fields]
#     field_types = {field.name: field.dataType for field in schema.fields}
#     row_num = 0
#     for row in data:
#         row_num += 1
#         try:
#             values = []
#             for name in field_names:
#                 raw_value = row.get(name)
#                 if raw_value is None or raw_value == "":
#                     values.append(None)
#                 else:
#                     converted = _convert_value(raw_value, field_types[name])
#                     coerced = _to_python_scalar(converted, field_types[name])
#                     if isinstance(field_types[name], TimestampType):
#                         if coerced is None:
#                             values.append(None)
#                         elif isinstance(coerced, datetime):
#                             if coerced.tzinfo is not None:
#                                 import datetime as dt
#                                 values.append(coerced.astimezone(dt.timezone.utc).replace(tzinfo=None))
#                             else:
#                                 values.append(coerced)
#                         else:
#                             try:
#                                 final_val = _parse_timestamp_value(coerced) if isinstance(coerced, str) else _to_python_scalar(coerced, TimestampType())
#                                 if final_val is None:
#                                     values.append(None)
#                                 elif isinstance(final_val, datetime):
#                                     if final_val.tzinfo is not None:
#                                         import datetime as dt
#                                         values.append(final_val.astimezone(dt.timezone.utc).replace(tzinfo=None))
#                                     else:
#                                         values.append(final_val)
#                                 else:
#                                     values.append(None)
#                             except Exception:
#                                 values.append(None)
#                     else:
#                         values.append(coerced)
#             validated_tuple = _validate_tuple_types(values, schema)
#             if validated_tuple is None:
#                 continue
#             final_values = []
#             for idx, (field, val) in enumerate(zip(schema.fields, validated_tuple)):
#                 if isinstance(field.dataType, TimestampType):
#                     if val is None:
#                         final_values.append(None)
#                     elif isinstance(val, datetime):
#                         if val.tzinfo is not None:
#                             import datetime as dt
#                             final_values.append(val.astimezone(dt.timezone.utc).replace(tzinfo=None))
#                         else:
#                             final_values.append(val)
#                     else:
#                         final_values.append(None)
#                 else:
#                     final_values.append(val)
#             yield tuple(final_values)
#         except Exception as e:
#             _debug_log(f"ROW {row_num} ERROR: {e}")
#             continue
#     _debug_log(f"=== parse_nemweb_csv complete: {row_num} rows ===")
# 
# def get_version() -> str:
#     return "2.10.8-inline"
# 
# print("✅ Inline datasource functions loaded")
# print(f"📝 Debug log: {DEBUG_LOG_PATH}")

## Solution 1.2: Partition Planning and Data Reading


In [None]:

# Import helper functions
import sys
import os

# Make `<two directory levels above this notebook>/src` importable so that
# nemweb_utils can be imported. `dbutils` is not imported in this file; it is
# assumed to be the Databricks notebook global.
notebook_path = dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()
repo_root = str(os.path.dirname(os.path.dirname(notebook_path)))
src_path = f"/Workspace{repo_root}/src"
# Guard so re-running this cell does not prepend a duplicate entry each time.
if src_path not in sys.path:
    sys.path.insert(0, src_path)

from nemweb_utils import fetch_nemweb_current, parse_nemweb_csv, get_version

# Output version for debugging
print(f"nemweb_utils version: {get_version()}")

# Quick test. In the inline copy of fetch_nemweb_current, use_sample=True returns
# built-in sample rows without any HTTP request.
test_data = fetch_nemweb_current(
    table="DISPATCHREGIONSUM",
    region="NSW1",
    max_files=2,
    use_sample=True
)
print(f"Helper function works! Got {len(test_data)} rows")


In [None]:

class NemwebPartition(InputPartition):
    """
    One unit of parallel work: a single NEM region plus the requested date window.

    Created by ``NemwebReader.partitions()`` and passed to ``NemwebReader.read()``.
    """
    def __init__(self, region: str, start_date: str, end_date: str):
        # Keep the planning inputs as plain attributes for read() to use.
        self.region, self.start_date, self.end_date = region, start_date, end_date


class NemwebReader(DataSourceReader):
    """
    Reader for NEMWEB data source.

    The reader has two jobs:
    1. partitions() - Plan the work (called on driver)
    2. read() - Do the work (called on workers)

    Options read here:
        regions    - comma-separated region IDs (default "NSW1,QLD1,SA1,VIC1,TAS1").
                     Whitespace is stripped and blank entries (e.g. from a
                     trailing comma) are ignored.
        start_date - YYYY-MM-DD string, defaults to yesterday (driver local time).
        end_date   - YYYY-MM-DD string, defaults to yesterday (driver local time).
        max_files  - number of files fetched per partition (default 6).

    Raises:
        ValueError: if ``regions`` contains no non-blank entry, or ``max_files``
            is not an integer.
    """

    def __init__(self, schema: StructType, options: dict):
        self.schema = schema
        self.options = options
        raw_regions = options.get("regions", "NSW1,QLD1,SA1,VIC1,TAS1")
        # Drop blank entries. A blank region is falsy, and the inline copy of
        # fetch_nemweb_current in this notebook only filters by region when it is
        # truthy, so "NSW1," would otherwise add a partition returning every region.
        self.regions = [r.strip() for r in raw_regions.split(",") if r.strip()]
        if not self.regions:
            raise ValueError(
                f"'regions' option must list at least one region, got {raw_regions!r}"
            )
        yesterday = (datetime.now() - timedelta(days=1)).strftime("%Y-%m-%d")
        self.start_date = options.get("start_date", yesterday)
        self.end_date = options.get("end_date", yesterday)
        # Files fetched per partition; overridable via the "max_files" option.
        self.max_files = int(options.get("max_files", 6))

    def partitions(self) -> list[InputPartition]:
        """
        Plan partitions for parallel reading.

        SOLUTION 1.2a: Create one partition per region
        """
        # Regions were already stripped and de-blanked in __init__.
        return [
            NemwebPartition(
                region=region,
                start_date=self.start_date,
                end_date=self.end_date,
            )
            for region in self.regions
        ]

    def read(self, partition: NemwebPartition) -> Iterator[Tuple]:
        """
        Read data for a single partition (runs on workers).

        SOLUTION 1.2b: Fetch and parse NEMWEB data

        NOTE: partition.start_date / end_date are not passed to
        fetch_nemweb_current, so the date window does not restrict what is read.
        """
        # Fetch live data from NEMWEB CURRENT folder
        data = fetch_nemweb_current(
            table="DISPATCHREGIONSUM",
            region=partition.region,
            max_files=self.max_files,
        )

        # Convert to tuples matching schema
        yield from parse_nemweb_csv(data, self.schema)


# Smoke-test partition planning with three regions (partitions() makes no network calls).
test_options = dict(regions="NSW1,VIC1,QLD1")
reader = NemwebReader(schema, test_options)
partitions = reader.partitions()

print(f"Created {len(partitions)} partitions (expected: 3)")
for p in partitions:
    print(f"  - Region: {p.region}")


## Solution 1.3: Complete DataSource Class


In [None]:

class NemwebDataSource(DataSource):
    """
    PySpark data source exposing AEMO NEMWEB DISPATCHREGIONSUM data.

    Usage:
        spark.dataSource.register(NemwebDataSource)
        df = spark.read.format("nemweb").option("regions", "NSW1,VIC1").load()

    Options (forwarded to NemwebReader):
        - regions: Comma-separated list of NEM regions (default: all 5)
        - start_date: Start date in YYYY-MM-DD format
        - end_date: End date in YYYY-MM-DD format
    """

    @classmethod
    def name(cls) -> str:
        """Format name used in spark.read.format("nemweb")."""
        # SOLUTION 1.3a: Return format name
        format_name = "nemweb"
        return format_name

    def schema(self) -> StructType:
        """Fixed DISPATCHREGIONSUM schema; it does not depend on the options."""
        # SOLUTION 1.3b: Return schema
        dispatch_schema = get_dispatchregionsum_schema()
        return dispatch_schema

    def reader(self, schema: StructType) -> DataSourceReader:
        """Build a NemwebReader bound to the resolved schema and the user's options."""
        # SOLUTION 1.3c: Create reader with schema and options
        return NemwebReader(schema=schema, options=self.options)


## Register and Test


In [None]:

# Make format("nemweb") resolvable by registering the custom source with Spark.
spark.dataSource.register(NemwebDataSource)

# Read a single region; read() fetches live NEMWEB data on the workers.
df = spark.read.format("nemweb").option("regions", "NSW1").load()

print(f"Row count: {df.count()}")
display(df.limit(5))


## Validation


In [None]:

def validate_implementation():
    """
    Run a checklist against the exercise solution and print pass/fail per item.

    Checks the schema, partition planning, a live read of one partition, and
    the three NemwebDataSource methods. A check that raises is reported as
    failed together with the exception text instead of being silently
    swallowed, and one failing check never aborts the rest.

    Returns:
        bool: True only if every check passed.
    """
    print("=" * 60)
    print("FINAL VALIDATION")
    print("=" * 60)

    checks = {
        "Part 1 - Schema (12 fields)": False,
        "Part 2 - Partitions": False,
        "Part 2 - Read": False,
        "Part 3 - DataSource.name()": False,
        "Part 3 - DataSource.schema()": False,
        "Part 3 - DataSource.reader()": False,
    }
    # Exception text for each check that raised, printed next to its result.
    errors = {}

    def record_error(label, exc):
        """Remember why the check ``label`` failed."""
        errors[label] = f"{type(exc).__name__}: {exc}"

    # Part 1: Schema - at least 12 fields, including the key measurement columns.
    schema = None
    try:
        schema = get_dispatchregionsum_schema()
        field_names = {field.name for field in schema.fields}
        required = ["TOTALDEMAND", "AVAILABLEGENERATION", "NETINTERCHANGE"]
        checks["Part 1 - Schema (12 fields)"] = (
            len(schema.fields) >= 12 and all(name in field_names for name in required)
        )
    except Exception as exc:
        record_error("Part 1 - Schema (12 fields)", exc)

    # Part 2: Partitions - planning only, no network access.
    reader = None
    partitions = None
    try:
        reader = NemwebReader(schema, {"regions": "NSW1,VIC1,QLD1"})
        partitions = reader.partitions()
        checks["Part 2 - Partitions"] = partitions is not None and len(partitions) == 3
    except Exception as exc:
        record_error("Part 2 - Partitions", exc)

    # Part 2: Read - read() fetches live NEMWEB data, so this needs network access.
    if reader is not None and partitions:
        try:
            test_partition = NemwebPartition("NSW1", "2024-01-01", "2024-01-01")
            result = list(reader.read(test_partition))
            checks["Part 2 - Read"] = len(result) > 0
        except Exception as exc:
            record_error("Part 2 - Read", exc)

    # Part 3: DataSource - each method is checked independently.
    try:
        checks["Part 3 - DataSource.name()"] = NemwebDataSource.name() == "nemweb"
    except Exception as exc:
        record_error("Part 3 - DataSource.name()", exc)

    try:
        ds = NemwebDataSource(options={})
        checks["Part 3 - DataSource.schema()"] = len(ds.schema().fields) >= 12
    except Exception as exc:
        record_error("Part 3 - DataSource.schema()", exc)

    try:
        ds = NemwebDataSource(options={})
        checks["Part 3 - DataSource.reader()"] = ds.reader(schema) is not None
    except Exception as exc:
        record_error("Part 3 - DataSource.reader()", exc)

    # Print results
    print()
    for check, passed in checks.items():
        status = "✅" if passed else "❌"
        print(f"  {status} {check}")
        if check in errors:
            print(f"      error: {errors[check]}")

    all_passed = all(checks.values())
    print()
    if all_passed:
        print("=" * 60)
        print("🎉 All checks passed!")
        print("=" * 60)

    return all_passed

validate_implementation()


## Summary
| Component | Purpose |
|-----------|---------|
| `DataSource.name()` | Format string for `spark.read.format(...)` |
| `DataSource.schema()` | Define output columns and types |
| `DataSource.reader()` | Create reader with options |
| `DataSourceReader.partitions()` | Plan parallel work units |
| `DataSourceReader.read()` | Fetch and yield data (runs on workers) |
## Compare to Production
Your implementation is a simplified version. The production code in
`src/nemweb_datasource_arrow.py` adds:
- **PyArrow RecordBatch** for zero-copy transfer (Serverless compatible)
- **Volume mode** with parallel downloads to UC Volume
- **Multiple tables** (DISPATCHREGIONSUM, DISPATCHPRICE, TRADINGPRICE)
- **Retry logic** with exponential backoff
