# Databricks STRUCT Aggregation Demo

This notebook demonstrates how to aggregate `chargeAmount` fields within arrays of STRUCTs in Databricks. Methods 1-4 use higher-order functions that work on the array in place, and Methods 5-6 show the traditional `LATERAL VIEW EXPLODE` pattern for comparison.


In [1]:
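# Imports for the schema definition and the sample data
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, DoubleType, StringType, StructField, StructType
import pyspark.sql.functions as F  # namespaced so it does not shadow builtins such as sum, min, max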

# Initialize Spark session (already available in Databricks as 'spark')
# spark = SparkSession.builder.appName("StructAggregationDemo").getOrCreate()


In [4]:
# Define the schema for our nested structure
schema = StructType([
    StructField("claimHeader", StructType([
        StructField("claimId", StringType(), True),
        StructField("lineOfBusiness", StringType(), True),
        StructField("totalCharges", DoubleType(), True)
    ]), True),
    StructField("claimDetail", ArrayType(StructType([
        StructField("chargeAmount", DoubleType(), True),
        StructField("units", DoubleType(), True)
    ])), True)
])

print("Schema defined successfully")


Schema defined successfully


In [5]:
# Create sample data matching your example
sample_data = [
    {
        "claimHeader": {
            "claimId": "ABC123456789",
            "lineOfBusiness": "Medicaid",
            "totalCharges": 3.25  # This should match the sum of claimDetail charges
        },
        "claimDetail": [
            {"chargeAmount": 1.25, "units": 1.00},
            {"chargeAmount": 2.00, "units": 1.00}
        ]
    },
    {
        "claimHeader": {
            "claimId": "XYZ987654321",
            "lineOfBusiness": "Medicare",
            "totalCharges": 0.0  # We'll calculate this
        },
        "claimDetail": [
            {"chargeAmount": 5.50, "units": 2.00},
            {"chargeAmount": 3.75, "units": 1.00},
            {"chargeAmount": 1.25, "units": 1.00}
        ]
    },
    {
        "claimHeader": {
            "claimId": "DEF555666777",
            "lineOfBusiness": "Commercial",
            "totalCharges": 0.0
        },
        "claimDetail": [
            {"chargeAmount": 12.50, "units": 1.00}
        ]
    }
]

# Create DataFrame
df = spark.createDataFrame(sample_data, schema)
df.createOrReplaceTempView("claims_table")

print("Sample data created and registered as 'claims_table'")
df.show(truncate=False)


Sample data created and registered as 'claims_table'
+-------------------------------+--------------------------------------+
|claimHeader                    |claimDetail                           |
+-------------------------------+--------------------------------------+
|{ABC123456789, Medicaid, 3.25} |[{1.25, 1.0}, {2.0, 1.0}]             |
|{XYZ987654321, Medicare, 0.0}  |[{5.5, 2.0}, {3.75, 1.0}, {1.25, 1.0}]|
|{DEF555666777, Commercial, 0.0}|[{12.5, 1.0}]                         |
+-------------------------------+--------------------------------------+



## Method 1: Using `aggregate()` Function (Recommended)

The `aggregate()` function folds an array into a single value within each row, so it sums the array elements without expanding them into separate rows.


In [6]:
%sql
-- Method 1: Using aggregate() function
SELECT 
    claimHeader.claimId,
    claimHeader.lineOfBusiness,
    claimHeader.totalCharges as original_totalCharges,
    -- Sum all chargeAmount values in the claimDetail array
    aggregate(
        claimDetail, 
        CAST(0.0 AS DOUBLE), -- initial value with explicit cast
        (acc, detail) -> acc + detail.chargeAmount
    ) as calculated_totalCharges,
    claimDetail
FROM claims_table


Unnamed: 0,claimId,lineOfBusiness,original_totalCharges,calculated_totalCharges,claimDetail
0,ABC123456789,Medicaid,3.25,3.25,"[{'chargeAmount': 1.25, 'units': 1.0}, {'chargeAmount': 2.0, 'units': 1.0}]"
1,XYZ987654321,Medicare,0.0,10.5,"[{'chargeAmount': 5.5, 'units': 2.0}, {'chargeAmount': 3.75, 'units': 1.0}, {'chargeAmount': 1.25, 'units': 1.0}]"
2,DEF555666777,Commercial,0.0,12.5,"[{'chargeAmount': 12.5, 'units': 1.0}]"
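
To make the fold semantics concrete, the sketch below models `aggregate(claimDetail, 0.0, (acc, detail) -> acc + detail.chargeAmount)` in plain Python with `functools.reduce`. It illustrates the evaluation order only and is not Spark code. The `sum_charges` helper and the dictionaries standing in for STRUCT values are assumptions made for this example.

```python
from functools import reduce

# Plain-Python stand-ins for the claimDetail arrays in the sample data.
claims = {
    "ABC123456789": [{"chargeAmount": 1.25, "units": 1.0},
                     {"chargeAmount": 2.00, "units": 1.0}],
    "XYZ987654321": [{"chargeAmount": 5.50, "units": 2.0},
                     {"chargeAmount": 3.75, "units": 1.0},
                     {"chargeAmount": 1.25, "units": 1.0}],
    "DEF555666777": [{"chargeAmount": 12.50, "units": 1.0}],
}


def sum_charges(claim_detail):
    """Fold the array left to right, starting from 0.0 (hypothetical helper)."""
    return reduce(
        lambda acc, detail: acc + detail["chargeAmount"],  # the merge lambda
        claim_detail,
        0.0,                                               # the initial value
    )


totals = {claim_id: sum_charges(detail) for claim_id, detail in claims.items()}
print(totals)
```

Each claim is reduced independently, which mirrors how the SQL version produces one total per input row without changing the row count.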


## Method 2: Using `transform()` and `reduce()`

This approach first extracts all charge amounts into a simple array, then sums them.


In [7]:
%sql
-- Method 2: Using transform() and reduce()
SELECT 
    claimHeader.claimId,
    claimHeader.lineOfBusiness,
    -- Extract all chargeAmount values and sum them using reduce
    reduce(
        transform(claimDetail, detail -> detail.chargeAmount),
        CAST(0.0 AS DOUBLE),
        (acc, x) -> acc + x
    ) as calculated_totalCharges,
    -- Show the intermediate array for understanding
    transform(claimDetail, detail -> detail.chargeAmount) as chargeAmounts_array
FROM claims_table


Unnamed: 0,claimId,lineOfBusiness,calculated_totalCharges,chargeAmounts_array
0,ABC123456789,Medicaid,3.25,"[1.25, 2.0]"
1,XYZ987654321,Medicare,10.5,"[5.5, 3.75, 1.25]"
2,DEF555666777,Commercial,12.5,[12.5]
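
The schema declares `chargeAmount` as nullable, and in SQL adding NULL to a number yields NULL, so a single missing charge would turn the whole total into NULL. The plain-Python sketch below models that behavior and shows one possible guard: coalescing inside the transform step, which in SQL would correspond to `coalesce(detail.chargeAmount, 0.0)` in the lambda. The `sql_add` and `coalesce` helpers and the array with a missing value are hypothetical and exist only for this illustration.

```python
from functools import reduce


def sql_add(a, b):
    """Model SQL '+': the result is NULL (None) if either operand is NULL."""
    return None if a is None or b is None else a + b


def coalesce(value, default):
    """Model SQL coalesce() for one value and one fallback."""
    return default if value is None else value


# Hypothetical claimDetail array with one missing chargeAmount.
claim_detail = [
    {"chargeAmount": 5.50, "units": 2.0},
    {"chargeAmount": None, "units": 1.0},
    {"chargeAmount": 1.25, "units": 1.0},
]

# transform(): project each struct to its chargeAmount.
charge_amounts = [d["chargeAmount"] for d in claim_detail]

# reduce() over the raw values: one NULL makes the whole total NULL.
naive_total = reduce(sql_add, charge_amounts, 0.0)

# Coalescing during the transform step keeps the total numeric.
safe_total = reduce(sql_add, [coalesce(x, 0.0) for x in charge_amounts], 0.0)

print(naive_total, safe_total)
```

Whether a missing charge should count as zero or invalidate the total is a business decision; the sketch only shows where the guard would go.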


## Method 3: Rebuilding the STRUCT with Calculated Total

This shows how to rebuild the `claimHeader` STRUCT with `totalCharges` recalculated from the `claimDetail` array.


In [8]:
%sql
-- Method 3: Rebuilding the STRUCT with calculated totalCharges
SELECT 
    struct(
        claimHeader.claimId as claimId,
        claimHeader.lineOfBusiness as lineOfBusiness,
        -- Calculate totalCharges from claimDetail array
        aggregate(
            claimDetail, 
            CAST(0.0 AS DOUBLE), 
            (acc, detail) -> acc + detail.chargeAmount
        ) as totalCharges
    ) as claimHeader,
    claimDetail
FROM claims_table


Unnamed: 0,claimHeader,claimDetail
0,"{'claimId': 'ABC123456789', 'lineOfBusiness': 'Medicaid', 'totalCharges': 3.25}","[{'chargeAmount': 1.25, 'units': 1.0}, {'chargeAmount': 2.0, 'units': 1.0}]"
1,"{'claimId': 'XYZ987654321', 'lineOfBusiness': 'Medicare', 'totalCharges': 10.5}","[{'chargeAmount': 5.5, 'units': 2.0}, {'chargeAmount': 3.75, 'units': 1.0}, {'chargeAmount': 1.25, 'units': 1.0}]"
2,"{'claimId': 'DEF555666777', 'lineOfBusiness': 'Commercial', 'totalCharges': 12.5}","[{'chargeAmount': 12.5, 'units': 1.0}]"
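
The rebuild step can be pictured as "copy the header, overwrite one field". The plain-Python sketch below shows that idea on one sample claim. `rebuild_claim` is a hypothetical helper, and the dictionaries are stand-ins for the STRUCT values, not Spark objects.

```python
def rebuild_claim(claim):
    """Return a new claim whose header total is recomputed from the detail lines."""
    total = sum(d["chargeAmount"] for d in claim["claimDetail"])
    # Copy every header field, then overwrite totalCharges with the new sum.
    header = {**claim["claimHeader"], "totalCharges": total}
    return {"claimHeader": header, "claimDetail": claim["claimDetail"]}


claim = {
    "claimHeader": {
        "claimId": "XYZ987654321",
        "lineOfBusiness": "Medicare",
        "totalCharges": 0.0,  # stale value from the sample data
    },
    "claimDetail": [
        {"chargeAmount": 5.50, "units": 2.0},
        {"chargeAmount": 3.75, "units": 1.0},
        {"chargeAmount": 1.25, "units": 1.0},
    ],
}

rebuilt = rebuild_claim(claim)
print(rebuilt["claimHeader"])
```

One difference is worth noting: the SQL in Method 3 lists the header fields explicitly, so a field added to `claimHeader` later would also have to be added to the `struct(...)` call, whereas the dictionary copy here carries all fields over automatically.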


## Method 4: Complex Aggregations

You can also perform more complex aggregations, like calculating both sum and count, or conditional sums.


In [9]:
%sql
-- Method 4: Complex aggregations - sum, count, averages, and conditional logic
SELECT 
    claimHeader.claimId,
    claimHeader.lineOfBusiness,
    -- Sum of all charges
    aggregate(
        claimDetail, 
        CAST(0.0 AS DOUBLE), 
        (acc, detail) -> acc + detail.chargeAmount
    ) as total_charges,
    -- Count of line items
    size(claimDetail) as line_item_count,
    -- Average charge per line item
    aggregate(
        claimDetail, 
        CAST(0.0 AS DOUBLE), 
        (acc, detail) -> acc + detail.chargeAmount
    ) / size(claimDetail) as avg_charge_per_line,
    -- Sum of units
    aggregate(
        claimDetail, 
        CAST(0.0 AS DOUBLE), 
        (acc, detail) -> acc + detail.units
    ) as total_units,
    -- Conditional sum: only charges > 2.00
    aggregate(
        claimDetail, 
        CAST(0.0 AS DOUBLE), 
        (acc, detail) -> CASE WHEN detail.chargeAmount > 2.0 THEN acc + detail.chargeAmount ELSE acc END
    ) as high_charges_total
FROM claims_table


Unnamed: 0,claimId,lineOfBusiness,total_charges,line_item_count,avg_charge_per_line,total_units,high_charges_total
0,ABC123456789,Medicaid,3.25,2,1.625,2.0,0.0
1,XYZ987654321,Medicare,10.5,3,3.5,4.0,9.25
2,DEF555666777,Commercial,12.5,1,12.5,1.0,12.5
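
The query above walks the array once per metric. An alternative is to carry several running values in one accumulator and derive the final metrics at the end. The plain-Python sketch below shows that single-pass pattern; `summarize`, its `threshold` parameter, and the tuple accumulator are assumptions made for this example. In Spark SQL the same idea would presumably use a STRUCT accumulator together with the optional `finish` argument of `aggregate()`, but that SQL is not shown or tested here.

```python
from functools import reduce


def summarize(claim_detail, threshold=2.0):
    """Compute several metrics over one array in a single fold (hypothetical helper)."""

    def merge(acc, detail):
        total, count, high = acc
        charge = detail["chargeAmount"]
        return (
            total + charge,                                   # running sum
            count + 1,                                        # running line count
            (high + charge) if charge > threshold else high,  # conditional sum
        )

    total, count, high = reduce(merge, claim_detail, (0.0, 0, 0.0))
    # Finish step: derive the average; None marks "undefined" for an empty array.
    average = total / count if count else None
    return {"total": total, "count": count, "average": average, "high_total": high}


medicare_lines = [
    {"chargeAmount": 5.50, "units": 2.0},
    {"chargeAmount": 3.75, "units": 1.0},
    {"chargeAmount": 1.25, "units": 1.0},
]
print(summarize(medicare_lines))
```

Returning `None` for the average of an empty array is a choice of this sketch; it avoids a division by zero rather than reproducing any particular Spark setting.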


## Method 5: LATERAL VIEW EXPLODE Syntax

This method uses the traditional `LATERAL VIEW EXPLODE` syntax, which is widely supported across Spark and Databricks versions.


In [11]:
%sql
-- Method 5a: Using LATERAL VIEW EXPLODE syntax to sum charges
SELECT 
    claimHeader.claimId,
    claimHeader.lineOfBusiness,
    claimHeader.totalCharges as original_totalCharges,
    SUM(detail.chargeAmount) as calculated_totalCharges
FROM claims_table
LATERAL VIEW EXPLODE(claimDetail) t AS detail
GROUP BY 
    claimHeader.claimId,
    claimHeader.lineOfBusiness,
    claimHeader.totalCharges


Unnamed: 0,claimId,lineOfBusiness,original_totalCharges,calculated_totalCharges
0,ABC123456789,Medicaid,3.25,3.25
1,XYZ987654321,Medicare,0.0,10.5
2,DEF555666777,Commercial,0.0,12.5
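
The explode-then-group pattern can be modeled in plain Python as "flatten to one row per line item, then sum per key". The sketch below also includes a hypothetical claim with an empty array to show a side effect: `LATERAL VIEW EXPLODE` without `OUTER` emits no rows for an empty array, so that claim disappears from the grouped result. The claim ID `EMPTY0000000` and the tuple layout are assumptions made for this illustration.

```python
from collections import defaultdict

# (claimId, list of chargeAmount values); the last claim is hypothetical.
claims = [
    ("ABC123456789", [1.25, 2.00]),
    ("XYZ987654321", [5.50, 3.75, 1.25]),
    ("EMPTY0000000", []),
]

# Explode: one output row per array element; empty arrays contribute nothing.
exploded = [(claim_id, charge) for claim_id, charges in claims for charge in charges]

# GROUP BY claimId with SUM(chargeAmount).
totals = defaultdict(float)
for claim_id, charge in exploded:
    totals[claim_id] += charge

print(len(exploded), dict(totals))
```

By contrast, the higher-order function approach keeps one output row per claim and returns the initial value `0.0` for an empty array.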


In [12]:
%sql
-- Method 5b: Rebuilding the complete STRUCT with LATERAL VIEW EXPLODE syntax
WITH exploded_data AS (
    SELECT 
        claimHeader,
        detail.chargeAmount,
        detail.units
    FROM claims_table
    LATERAL VIEW EXPLODE(claimDetail) t AS detail
),
aggregated_data AS (
    SELECT 
        claimHeader.claimId,
        claimHeader.lineOfBusiness,
        SUM(chargeAmount) as total_charges,
        -- Reconstruct the original claimDetail array using collect_list
        collect_list(struct(chargeAmount, units)) as claimDetail
    FROM exploded_data
    GROUP BY 
        claimHeader.claimId,
        claimHeader.lineOfBusiness
)
SELECT 
    struct(
        claimId,
        lineOfBusiness,
        total_charges as totalCharges
    ) as claimHeader,
    claimDetail
FROM aggregated_data


Unnamed: 0,claimHeader,claimDetail
0,"{'claimId': 'ABC123456789', 'lineOfBusiness': 'Medicaid', 'totalCharges': 3.25}","[{'chargeAmount': 1.25, 'units': 1.0}, {'chargeAmount': 2.0, 'units': 1.0}]"
1,"{'claimId': 'XYZ987654321', 'lineOfBusiness': 'Medicare', 'totalCharges': 10.5}","[{'chargeAmount': 5.5, 'units': 2.0}, {'chargeAmount': 3.75, 'units': 1.0}, {'chargeAmount': 1.25, 'units': 1.0}]"
2,"{'claimId': 'DEF555666777', 'lineOfBusiness': 'Commercial', 'totalCharges': 12.5}","[{'chargeAmount': 12.5, 'units': 1.0}]"


## Method 6: More Complex Aggregations with LATERAL VIEW EXPLODE

This shows how to compute multiple aggregations in one query using the same `LATERAL VIEW EXPLODE` syntax as Method 5.


In [13]:
%sql
-- Method 6: Complex aggregations with LATERAL VIEW EXPLODE syntax
SELECT 
    claimHeader.claimId,
    claimHeader.lineOfBusiness,
    SUM(detail.chargeAmount) as total_charges,
    COUNT(*) as line_item_count,
    AVG(detail.chargeAmount) as avg_charge_per_line,
    SUM(detail.units) as total_units,
    SUM(CASE WHEN detail.chargeAmount > 2.0 THEN detail.chargeAmount ELSE 0 END) as high_charges_total,
    MIN(detail.chargeAmount) as min_charge,
    MAX(detail.chargeAmount) as max_charge
FROM claims_table
LATERAL VIEW EXPLODE(claimDetail) t AS detail
GROUP BY 
    claimHeader.claimId,
    claimHeader.lineOfBusiness


Unnamed: 0,claimId,lineOfBusiness,total_charges,line_item_count,avg_charge_per_line,total_units,high_charges_total,min_charge,max_charge
0,ABC123456789,Medicaid,3.25,2,1.625,2.0,0.0,1.25,2.0
1,XYZ987654321,Medicare,10.5,3,3.5,4.0,9.25,1.25,5.5
2,DEF555666777,Commercial,12.5,1,12.5,1.0,12.5,12.5,12.5


## Comparison: LATERAL VIEW EXPLODE vs Higher-Order Functions

### LATERAL VIEW EXPLODE Syntax (Methods 5 & 6):
**Pros:**
- ✅ Familiar SQL aggregation syntax (`SUM`, `COUNT`, `AVG`)
- ✅ Easy to understand for SQL developers
- ✅ Great for complex multi-column aggregations
- ✅ Widely supported across Spark/Databricks versions

**Cons:**
- ❌ Creates temporary rows (less memory efficient for large arrays)
- ❌ Requires `GROUP BY` and reconstruction for maintaining structure

### Higher-Order Functions (Methods 1-4):
**Pros:**
- ✅ No row expansion - processes arrays in-place
- ✅ Typically better performance for large nested arrays, because no `GROUP BY` shuffle is needed
- ✅ Maintains data structure without reconstruction

**Cons:**
- ❌ More complex syntax
- ❌ Requires understanding of lambda functions

### Recommendation:
- **Use LATERAL VIEW EXPLODE** for: Complex aggregations, multiple metrics, when SQL familiarity is important
- **Use Higher-Order Functions** for: Simple aggregations, performance-critical applications, very large nested arrays
