In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import types as T
from pyspark.sql.types import *
from pyspark.sql import Window

# Local Spark session for this notebook: master "local[4]" with the SQL
# shuffle partition count lowered to 4 to match.
spark = (
    SparkSession.builder
    .master("local[4]")
    .config("spark.sql.shuffle.partitions", 4)
    .getOrCreate()
)

Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/05/04 21:29:53 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
# Window spanning all races that share a season (year).
lastRace = Window.partitionBy("year")

# One row per 1990-1999 season: the race whose round number is the highest
# of that year. "round" is cast to integer before taking the max so the
# comparison is numeric (presumably the column is not stored as int - confirm).
lastRaces = (
    spark.read.parquet("../../data/parquet/races.parquet")
    .filter(F.col("year").between(1990, 1999))
    .withColumn("round", F.col("round").cast(T.IntegerType()))
    .withColumn("max", F.max("round").over(lastRace))
    .filter(F.col("round") == F.col("max"))
    .select("raceId", "year")
)

# Window spanning all rows of a single constructor.
# NOTE: `constructors` is rebound to a DataFrame further down; the window is
# only needed while building `constructorWinners`.
constructors = Window.partitionBy("constructorId")

# Standings rows at position 1 after each season's last race, annotated per
# constructor with the number of such rows (totalChampWins) and the summed
# "wins" over them (totalRaceWins).
constructorWinners = (
    spark.read.parquet("../../data/parquet/constructor_standings.parquet")
    .join(lastRaces, on=["raceId"], how="right")
    # Rows without a matching standing have a null position and are dropped here.
    .filter(F.col("position") == 1)
    .select("constructorId", "wins", "year")
    .withColumn("totalChampWins", F.count("constructorId").over(constructors))
    .withColumn(
        "totalRaceWins",
        F.sum("wins").over(constructors).cast(T.IntegerType()),
    )
    .drop("wins")
)

# Id -> name lookup for constructors. This rebinds `constructors`, which held
# a Window spec up to this point.
constructors = (
    spark.read.parquet("../../data/parquet/constructors.parquet")
    .select("constructorId", "name")
)


# Final report: one row per constructor (the per-constructor totals are
# identical on every row, so any duplicate can be kept), joined to its name
# and ordered by totalChampWins, then totalRaceWins, both descending.
results = (
    constructorWinners
    .drop("year")
    .dropDuplicates(["constructorId"])
    .join(constructors, on="constructorId")
    .orderBy(F.desc("totalChampWins"), F.desc("totalRaceWins"))
    .drop("constructorId")
)

                                                                                

In [3]:
# results.explain(extended=True)

In [4]:
import time
def current_milli_time():
    """Return the current wall-clock time in whole milliseconds since the epoch."""
    seconds_since_epoch = time.time()
    return round(seconds_since_epoch * 1000)

def run(df=None):
    """Execute a query once and return how long it took, in milliseconds.

    Args:
        df: Object exposing ``collect()`` (e.g. a Spark DataFrame) to time.
            When omitted, the notebook-level ``results`` query is used, so
            existing ``run()`` calls behave as before.

    Returns:
        int: Elapsed time in whole milliseconds.
    """
    target = results if df is None else df
    # perf_counter() is monotonic, so the measured interval cannot be skewed
    # (or go negative) if the system wall clock is adjusted mid-run.
    start = time.perf_counter()
    target.collect()
    return round((time.perf_counter() - start) * 1000)

def average(l):
    """Return the arithmetic mean of the values in ``l``.

    Raises:
        ZeroDivisionError: If ``l`` is empty.
    """
    total = sum(l)
    count = len(l)
    return total / count
    
def time_test(runs=1):
    """Time the benchmark query ``runs`` times and return the mean duration.

    Args:
        runs: Number of times to call ``run()``. Defaults to 1, which matches
            the previous hard-coded behaviour.

    Returns:
        float: Mean of the per-run durations reported by ``run()``.

    Raises:
        ValueError: If ``runs`` is smaller than 1 (there would be nothing to
            average).
    """
    if runs < 1:
        raise ValueError("runs must be at least 1")
    return average([run() for _ in range(runs)])

# Run the benchmark with its default settings; time_test() returns the
# average of the run() timings (milliseconds).
res = time_test()

# Show the averaged timing as the cell's output.
print(res)

1445.0
