In [1]:
import org.apache.spark.sql.functions._
import org.apache.spark.storage.StorageLevel

spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")
spark.conf.set("spark.sql.analyzer.failAmbiguousSelfJoin", "false")
spark.conf.set("spark.sql.shuffle.partitions", "4")

val events = spark.read.option("header", "true")
                        .option("inferSchema", "true")
                        .csv("/home/iceberg/data/events.csv")
                        .where($"user_id".isNotNull)

events.createOrReplaceTempView("events")

import org.apache.spark.sql.functions._
import org.apache.spark.storage.StorageLevel
events: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [user_id: int, device_id: int ... 4 more fields]


In [2]:
val devices = spark.read.option("header", "true")
                            .option("inferSchema", "true")
                            .csv("/home/iceberg/data/devices.csv")

devices.createOrReplaceTempView("devices")

devices: org.apache.spark.sql.DataFrame = [device_id: int, browser_type: string ... 2 more fields]


In [19]:
// Caching here should be < 5 GBs or used for broadcast join
// You need to tune executor memory otherwise it will spill to disk and be slow
// Do not really try using any of the other StorageLevel besides MEMORY_ONLY

val eventsAgg = spark.sql(s"""
    select 
        user_id,
        device_id,
        count(1) as event_counts,
        collect_list(distinct host) as host_array
    from events
    group by
        user_id,
        device_id
""").cache()

// eventsAgg.write.mode("overwrite").saveAsTable("bootcamp.events_agg_staging")

eventsAgg: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [user_id: int, device_id: int ... 2 more fields]


In [11]:
spark.sql(f"""
    create table if not exists bootcamp.events_agg_staging (
        user_id bigint,
        device bigint,
        event_counts bigint,
        host_array array<string>
    )
    partitioned by (ds string)
""")

res5: org.apache.spark.sql.DataFrame = []


In [21]:
val eventsAndEventsAgg = events
    .join(eventsAgg, events("user_id") === eventsAgg("user_id"))
    .groupBy(events("user_id"))
    .agg(
        events("user_id"),
        max(eventsAgg("event_counts")).as("total_hits"),
        collect_list(eventsAgg("device_id")).as("devices")
    )

eventsAndEventsAgg: org.apache.spark.sql.DataFrame = [user_id: int, user_id: int ... 2 more fields]


In [22]:
val devicesAndEventsAgg = devices
    .join(eventsAgg, devices("device_id") === eventsAgg("device_id"))
    .groupBy(devices("device_id"), devices("device_type"))
    .agg(
        devices("device_id"),
        devices("device_type"),
        collect_list(eventsAgg("user_id")).as("users")
    )

devicesAndEventsAgg: org.apache.spark.sql.DataFrame = [device_id: int, device_type: string ... 3 more fields]


In [23]:
eventsAndEventsAgg.explain()
devicesAndEventsAgg.explain()

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- ObjectHashAggregate(keys=[user_id#17], functions=[max(event_counts#246L), collect_list(device_id#18, 0, 0)])
   +- ObjectHashAggregate(keys=[user_id#17], functions=[partial_max(event_counts#246L), partial_collect_list(device_id#18, 0, 0)])
      +- Project [user_id#17, device_id#18, event_counts#246L]
         +- SortMergeJoin [user_id#17], [user_id#280], Inner
            :- Sort [user_id#17 ASC NULLS FIRST], false, 0
            :  +- Exchange hashpartitioning(user_id#17, 4), ENSURE_REQUIREMENTS, [plan_id=227]
            :     +- Filter isnotnull(user_id#17)
            :        +- FileScan csv [user_id#17,device_id#18] Batched: false, DataFilters: [isnotnull(user_id#17)], Format: CSV, Location: InMemoryFileIndex(1 paths)[file:/home/iceberg/data/events.csv], PartitionFilters: [], PushedFilters: [IsNotNull(user_id)], ReadSchema: struct<user_id:int,device_id:int>
            +- Sort [user_id#280 ASC NULLS FIRST], false, 0
    

In [25]:
eventsAndEventsAgg.take(1)
devicesAndEventsAgg.take(1)

res11: Array[org.apache.spark.sql.Row] = Array([-2147042689,Other,-2147042689,Other,WrappedArray(1633522354)])


In [26]:
eventsAgg.unpersist()

res12: eventsAgg.type = [user_id: int, device_id: int ... 2 more fields]
