In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import types as T
from pyspark.sql.types import *
from pyspark.sql import Window

# Local Spark session with 4 worker threads; shuffle partitions are lowered to
# 4 (from Spark's default) to match that small local parallelism.
spark = (
    SparkSession.builder
    .master("local[4]")
    .config("spark.sql.shuffle.partitions", 4)
    .getOrCreate()
)

Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/05/04 21:32:36 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
# Races of the 2021 season only. `year` is kept so later windows can
# partition by season (seasonWindow).
races = (
    spark.read.parquet("../../data/parquet/races.parquet")
    .select("raceId", "year")
    .filter(F.col("year") == 2021)
)

In [3]:
# Driver reference table, read in full; it is joined to the results on driverId.
drivers = spark.read.format("parquet").load("../../data/parquet/drivers.parquet")

In [4]:
# Window specs shared by the aggregations below. Only raceDriverLapWindow is
# ordered; aggregates over the unordered ones span their whole partition.
driverWindow = Window.partitionBy("driverId")                # per driver, across all loaded races
seasonWindow = Window.partitionBy("year")                    # per season
driverRaceWindow = Window.partitionBy("driverId", "raceId")  # per driver, per race
raceDriverLapWindow = driverRaceWindow.orderBy("lap")        # same partitioning, lap by lap

In [5]:
# One row per (raceId, driverId): places gained/lost lap-to-lap in that race,
# plus the driver's laps led across all loaded races and the share of them.
driverStats = (
    spark.read.parquet("../../data/parquet/lap_times.parquet")
    .withColumn("position", F.col("position").cast(T.IntegerType()))
    .withColumn("lap", F.col("lap").cast(T.IntegerType()))
    # inner join: keep only laps belonging to the races in `races`
    .join(races, "raceId")
    # running position on the following lap; null on the driver's last recorded lap
    .withColumn("positionNextLap", F.lead("position", 1).over(raceDriverLapWindow))
    # places gained / lost going into the next lap (0 if unchanged or no next lap)
    .withColumn(
        "positionsGainedLap",
        F.when(
            F.col("positionNextLap") < F.col("position"),
            F.col("position") - F.col("positionNextLap"),
        ).otherwise(0),
    )
    .withColumn(
        "positionsLostLap",
        F.when(
            F.col("positionNextLap") > F.col("position"),
            F.col("positionNextLap") - F.col("position"),
        ).otherwise(0),
    )
    # race totals (driverRaceWindow is unordered, so the sum covers the whole race)
    .withColumn("positionsGained", F.sum("positionsGainedLap").over(driverRaceWindow))
    .withColumn("positionsLost", F.sum("positionsLostLap").over(driverRaceWindow))
    # 1 for every lap the driver was running P1
    .withColumn("lapLeader", F.when(F.col("position") == 1, 1).otherwise(0))
    .withColumn("lapsLed", F.sum("lapLeader").over(driverWindow))
    # P1 laps summed per year; presumably one leader per lap, i.e. laps raced - verify
    .withColumn("totalLaps", F.sum("lapLeader").over(seasonWindow))
    .withColumn("percLapsLed", F.round(F.col("lapsLed") / F.col("totalLaps"), 2))
    .select("raceId", "driverId", "positionsGained", "positionsLost", "lapsLed", "percLapsLed")
    .dropDuplicates()
)



In [6]:
# Per-driver season summary: points, podiums, position changes and laps led,
# one row per driver code, sorted by average positions lost (descending).
results = (
    spark.read.parquet("../../data/parquet/results.parquet")
    .withColumn("position", F.col("position").cast(T.IntegerType()))
    .withColumn("grid", F.col("grid").cast(T.IntegerType()))
    # Points can be fractional (e.g. half-points awards such as 12.5); an
    # integer cast would silently truncate them, so keep them as doubles.
    .withColumn("points", F.col("points").cast(T.DoubleType()))
    .join(races, "raceId")
    # left join: race results with no lap-time rows keep null lap statistics
    .join(driverStats, ["raceId", "driverId"], "left")
    .join(drivers, "driverId")
    .withColumn("podium", F.when(F.col("position").isin(1, 2, 3), F.lit(1)).otherwise(F.lit(0)))
    .withColumn("averagePoints", F.round(F.avg("points").over(driverWindow), 2))
    .withColumn("maxAvgPoints", F.max("averagePoints").over(seasonWindow))
    .select(
        F.col("code"),
        F.sum("points").over(driverWindow).alias("champPoints"),
        F.col("averagePoints"),
        # average points relative to the best average in the season
        F.round(F.col("averagePoints") / F.col("maxAvgPoints"), 2).alias("pointPercent"),
        F.sum("podium").over(driverWindow).alias("totalPodiums"),
        F.round(
            F.sum("podium").over(driverWindow) / F.count("podium").over(driverWindow), 2
        ).alias("podiumPercent"),
        # finishing position minus grid slot (positive = lost places);
        # NOTE(review): a grid value of 0 (if used for pit-lane starts) would skew this - confirm
        F.round(F.avg(F.col("position") - F.col("grid")).over(driverWindow), 2).alias("positionDelta"),
        F.round(F.avg("positionsLost").over(driverWindow), 2).alias("avgPositionsLost"),
        F.round(F.avg("positionsGained").over(driverWindow), 2).alias("avgPositionsWon"),
        F.sum("positionsLost").over(driverWindow).alias("totalPositionsLost"),
        F.sum("positionsGained").over(driverWindow).alias("totalPositionsWon"),
        # lapsLed/percLapsLed are per-driver season values that are null on rows
        # where the left join found no lap times. Take the per-driver max so that
        # dropDuplicates(["code"]) below cannot keep a null row that na.fill(0)
        # would then report as 0 laps led.
        F.max("lapsLed").over(driverWindow).alias("lapsLed"),
        F.max("percLapsLed").over(driverWindow).alias("percLapsLed"),
    )
    .na.fill(0)
    # every remaining column is a per-driver window aggregate, so any row per code is equivalent
    .dropDuplicates(["code"])
    .sort(F.col("avgPositionsLost").desc())
)

                                                                                

In [7]:
import time
def current_milli_time():
    """Return the current wall-clock time as a whole number of milliseconds since the epoch."""
    now_ms = time.time() * 1000
    return round(now_ms)

def run(df=None):
    """Collect a DataFrame once and return how long it took, in milliseconds.

    Args:
        df: any object with a ``collect()`` method; defaults to the notebook's
            ``results`` DataFrame when omitted (the original behaviour).

    Returns:
        Elapsed time as an int number of milliseconds.
    """
    target = results if df is None else df
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # system clock and can jump if it is adjusted mid-measurement.
    start = time.perf_counter()
    target.collect()
    return round((time.perf_counter() - start) * 1000)

def average(l):
    """Return the arithmetic mean of a non-empty sized sequence of numbers."""
    total = sum(l)
    count = len(l)
    return total / count
    
def time_test(n_runs=1):
    """Call ``run()`` ``n_runs`` times and return the mean elapsed milliseconds.

    Args:
        n_runs: number of timed repetitions; defaults to 1 (the original
            hard-coded count).

    Raises:
        ValueError: if ``n_runs`` is less than 1, since there would be no
            timings to average.
    """
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")
    timings = [run() for _ in range(n_runs)]
    return average(timings)

res = time_test()
res
