In [1]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as sf

if not 'spark' in locals():
    spark = SparkSession.builder \
        .master("local[*]") \
        .config("spark.driver.memory","4G") \
        .getOrCreate()

spark.conf.set("spark.sql.adaptive.enabled", False)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
spark

/opt/anaconda3/lib/python3.10/site-packages/pyspark/bin/load-spark-env.sh: line 68: ps: command not found
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/11/25 17:35:59 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


### Set Checkpoint directory

First we need to specify a checkpoint directory on a reliable shared file system.

In [2]:
spark.sparkContext.setCheckpointDir("/tmp/checkpoints")

# 1 Fictional Sales Data

In this example we use a fictional data set of company revenues. The special property of this data set is that a company can have a different company as its parent company. Eventually a business expert wants to see the whole revenue of a company including all child companies. This requires that we build up an additional table containing all children (direct and indirect) for every company, such that we can join the revenues against this table and then aggregate over all direct and indirect children for each parent.

Let's start by loading and inspecting the data.

In [3]:
basedir = "s3://dimajix-training/data"
#basedir = "file:///home/kaya/Jupyter/Dimajix/pyspark-advanced/data"

In [4]:
data = spark.read \
    .option("header", True) \
    .option("inferSchema", True) \
    .csv(basedir + "/global-sales.csv")

data.printSchema()

                                                                                

root
 |-- company: integer (nullable = true)
 |-- parent_company: integer (nullable = true)
 |-- company_name: string (nullable = true)
 |-- revenue: integer (nullable = true)



In [26]:
data.limit(10).toPandas()

Unnamed: 0,company,parent_company,company_name,revenue
0,1,,Global Earth Inc,10000
1,2,1.0,European Markets,0
2,3,2.0,Germany Sales GmbH,2000
3,4,2.0,Spain Products,123000
4,5,2.0,Swiss Made,213000
5,6,2.0,France Superstars,241000
6,7,3.0,Berlin Store,287000
7,8,3.0,Hamburg Store,312000
8,9,3.0,Hessian Store Group,10000
9,10,9.0,Frankfurt Shop Central,287000


## Enrich Data

We implement a small helper function, which performs a lookup of the company name and the parent companies name.

In [24]:
def enrich_data(result):
    result = result.alias("result") \
        .join(data.alias("data"), sf.col("result.parent_company") == sf.col("data.company")) \
        .select(
            sf.col("result.company"),
            sf.col("result.parent_company"),
            sf.col("data.company_name").alias("parent_company_name")
        )
    result = result.alias("result") \
        .join(data.alias("data"), sf.col("result.company") == sf.col("data.company")) \
        .select(
            sf.col("result.company"),
            sf.col("data.company_name").alias("company_name"),
            sf.col("result.parent_company"),
            sf.col("result.parent_company_name")
        )
    return result.orderBy("parent_company","company")

In [25]:
result = enrich_data(data)
result.toPandas()

Unnamed: 0,company,company_name,parent_company,parent_company_name
0,2,European Markets,1,Global Earth Inc
1,13,American Markets,1,Global Earth Inc
2,14,Asian Markets,1,Global Earth Inc
3,15,African Markets,1,Global Earth Inc
4,3,Germany Sales GmbH,2,European Markets
5,4,Spain Products,2,European Markets
6,5,Swiss Made,2,European Markets
7,6,France Superstars,2,European Markets
8,7,Berlin Store,3,Germany Sales GmbH
9,8,Hamburg Store,3,Germany Sales GmbH


# 2 Single Step of transitive parent-child relations

In the next step we want to build the helper table containing all children for every company. We will calculate this table using an iterative algorithm which adds the next level of children in every iteration. We first implement a single iteration, which will add the next level of children to each parent company.

In [27]:
# Remove all records without a parent company for the algorithm
cleaned_df = data \
    .filter(data["parent_company"].isNotNull()) \
    .select(data["company"], data["parent_company"])

In [28]:
def iterate_parent_child(df):
    # Denote the incoming table "parent" and "child", since we will do a self-join and the join condition would be ambigious without aliases otherwise
    parent_df = df.alias("parent")
    child_df = df.alias("child")
    
    # Calculate next levels of indirect children by joining the table to itself and by retrieving the child of each child of each parent
    next_level = parent_df.join(child_df, sf.col("parent.company") == sf.col("child.parent_company"), "inner") \
        .select(sf.col("parent.parent_company"), sf.col("child.company"))

    # Add current relations, otherwise they will be lost
    cur_level = parent_df.select(parent_df["parent_company"], parent_df["company"])
    
    # Return union of next indirection and current relations
    return next_level.union(cur_level).distinct()

### Perform one iteration

Now let us perform a single iteration and inspect the result.

In [29]:
result = iterate_parent_child(cleaned_df)
result = enrich_data(result)
result.toPandas()

Unnamed: 0,company,company_name,parent_company,parent_company_name
0,2,European Markets,1,Global Earth Inc
1,3,Germany Sales GmbH,1,Global Earth Inc
2,4,Spain Products,1,Global Earth Inc
3,5,Swiss Made,1,Global Earth Inc
4,6,France Superstars,1,Global Earth Inc
5,13,American Markets,1,Global Earth Inc
6,14,Asian Markets,1,Global Earth Inc
7,15,African Markets,1,Global Earth Inc
8,16,USA Group,1,Global Earth Inc
9,17,Egypt Products,1,Global Earth Inc


# 3 Iterative Algorithm

Now that we can add one level of indirection to our table of parent-child relations, we simply need to apply this algorithm as often as new records are created. We also add a reflective relation of each company to itself at the end, such that when using the table for aggregating all children, the revenue of each company itself  will also be added up in addition to its children.

In [34]:
def calc_transitive_children(df):
    # Remove records without a parent
    cleaned_df = data \
        .filter(data["parent_company"].isNotNull()) \
        .select(data["company"], data["parent_company"])
    
    # Iterate as long as new records are created
    cur_df = cleaned_df
    # This would be a good point to cache/checkpoint
    cur_df = cur_df.checkpoint()
    #cur_df.cache()
    
    cur_count = cur_df.count()

    while (True):
        # Perform next iteration using iterate_parent_child
        next_df = iterate_parent_child(cur_df)
        # This would be a good place to perform a checkpoint
        next_df = next_df.checkpoint()
        #next_df.cache()

        # Count number of records
        next_count = next_df.count()
        # If no new records are created, we are finished
        if next_count == cur_count:
            break

        # If you use caching, this would be a good point to relase any previous caches
        #cur_df.unpersist()
        cur_df = next_df 
        cur_count = next_count
        
    # Create additional reflective relation of each company to itself
    self_df = data.select(sf.col("company").alias("parent_company"), sf.col("company"))
    
    return self_df.union(cur_df).distinct()

### Run Algorithm

Now let us run the whole algorithm on the original data set and inspect the result.

In [35]:
relations = calc_transitive_children(data)

result = enrich_data(relations)
result.toPandas()

Unnamed: 0,company,company_name,parent_company,parent_company_name
0,1,Global Earth Inc,1,Global Earth Inc
1,2,European Markets,1,Global Earth Inc
2,3,Germany Sales GmbH,1,Global Earth Inc
3,4,Spain Products,1,Global Earth Inc
4,5,Swiss Made,1,Global Earth Inc
...,...,...,...,...
58,16,USA Group,16,USA Group
59,17,Egypt Products,17,Egypt Products
60,18,China Inc,18,China Inc
61,19,Japanese Stores,19,Japanese Stores


### Inspect execution plan

In [36]:
relations.explain()

== Physical Plan ==
*(4) HashAggregate(keys=[parent_company#2702, company#17], functions=[])
+- Exchange hashpartitioning(parent_company#2702, company#17, 200), ENSURE_REQUIREMENTS, [plan_id=2918]
   +- *(3) HashAggregate(keys=[parent_company#2702, company#17], functions=[])
      +- Union
         :- *(1) Project [company#17 AS parent_company#2702, company#17]
         :  +- FileScan csv [company#17] Batched: false, DataFilters: [], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/home/kaya/Jupyter/Dimajix/pyspark-advanced/data/global-sales.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<company:int>
         +- *(2) Scan ExistingRDD[parent_company#18,company#2649]




# 4 Perform Aggregation

Now let us perform the final aggregation, such that we can calculate the revenue of each company including each direct and indirect child. This can be performed by joining the `relations` data frame to the original `data` data frame and then grouping on the `parent_company` column of the `relations` data frame and adding up the revenue.

In [39]:
hierarchical_revenue = relations \
    .join(data, ["company"]) \
    .groupby(relations["parent_company"]) \
    .agg(sf.sum(sf.col("revenue")).alias("total_revenue")).alias("result") \
    .join(data.alias("data"), sf.col("result.parent_company") == sf.col("data.company")) \
    .select(
        sf.col("result.parent_company").alias("company"),
        sf.col("data.company_name"),
        sf.col("result.total_revenue"),
    )

In [40]:
hierarchical_revenue.toPandas()

Unnamed: 0,company,company_name,total_revenue
0,12,Darmstadt Shop,90000
1,1,Global Earth Inc,9492820
2,13,American Markets,2231000
3,6,France Superstars,241000
4,16,USA Group,2131000
5,3,Germany Sales GmbH,1109000
6,20,Korean United,198000
7,5,Swiss Made,213000
8,19,Japanese Stores,2179820
9,15,African Markets,197000


### Check Totals

Just to verify the result, let us compare the result of company 1 ("Earth") with a simple sum over all revenues.

In [35]:
totals = data.select(sf.sum(data["revenue"]))
totals.toPandas()

Unnamed: 0,sum(revenue)
0,9492820
