In [2]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import types as T
from pyspark.sql import Window

# Local Spark session running on 4 threads. The shuffle partition count is
# lowered to 4 as well so the small CSV inputs are not over-partitioned.
spark = (
    SparkSession.builder
    .config("spark.sql.shuffle.partitions", 4)
    .master("local[4]")
    .getOrCreate()
)

# Groups rows by season; used to find the highest round number of each year.
lastRace = Window.partitionBy("year")

# Groups rows by constructor; used for per-constructor totals.
constructorWindow = Window.partitionBy("constructorId")

# Season finales of 1990-1999: (raceId, year) for every race whose round
# number equals the highest round of its year.
lastRaces = (
    spark.read.format("csv")
    .options(header="True", sep=",")
    .load("../data/races.csv")
    .filter(F.col("year").between(1990, 1999))
    # No schema is supplied, so columns are read as strings; cast `round`
    # so that max() compares numerically rather than lexicographically.
    .withColumn("round", F.col("round").cast(T.IntegerType()))
    .withColumn("max", F.max("round").over(lastRace))
    .filter(F.col("max") == F.col("round"))
    .select("raceId", "year")
)

# Lookup table: constructorId -> constructor name.
constructorMap = (
    spark.read.format("csv")
    .options(header="True", sep=",")
    .load("../data/constructors.csv")
    .select("constructorId", "name")
)

# Constructors' title table for the seasons covered by `lastRaces`.
#
# A season's champion is the constructor in position 1 of the standings
# recorded at that season's last race. For every champion this yields:
#   totalChampWins - number of such first-place rows (titles),
#   totalRaceWins  - sum of the `wins` values on those rows,
#   name           - constructor name from `constructorMap`,
# ordered by titles and then by wins, both descending.
results = (
    spark.read.format("csv")
    .option("header", "True")
    .option("sep", ",")
    .load("../data/constructor_standings.csv")
    .where(F.col("position") == 1)
    # Only standings that match a season finale are wanted. An outer join
    # would merely add null-padded rows that the position filter discards,
    # so a plain inner join expresses the intent directly.
    .join(lastRaces, ["raceId"], "inner")
    # Aggregate straight to one row per constructor instead of computing
    # window totals on every row and de-duplicating afterwards.
    .groupBy("constructorId")
    .agg(
        F.count(F.col("constructorId")).cast(T.IntegerType()).alias("totalChampWins"),
        F.sum(F.col("wins")).cast(T.IntegerType()).alias("totalRaceWins"),
    )
    .join(constructorMap, "constructorId")
    .select("totalChampWins", "totalRaceWins", "name")
    .orderBy(F.col("totalChampWins").desc(), F.col("totalRaceWins").desc())
)



Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/06/02 22:14:36 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
22/06/02 22:14:37 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
22/06/02 22:14:37 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.


In [3]:
# show() is the action that executes the lazily built plan; it prints at
# most 20 rows and truncates long cell values (the defaults, spelled out).
results.show(n=20, truncate=True)

+--------------+-------------+--------+
|totalChampWins|totalRaceWins|    name|
+--------------+-------------+--------+
|             5|           47|Williams|
|             3|           23| McLaren|
|             1|           11|Benetton|
|             1|            6| Ferrari|
+--------------+-------------+--------+



In [None]:
# lastRace = Window.partitionBy("year")

# lastRaces = spark.read.format("csv")\
#     .option("header", "true")\
#     .option("sep", ",")\
#     .load("../data/races.csv")\
#     .where((F.col("year") >= 1990) & (F.col("year") <= 1999))\
#     .withColumn("round", F.col("round").cast(T.IntegerType()))\
#     .withColumn("max", F.max(F.col("round")).over(lastRace))\
#     .where(F.col("round") == F.col("max"))\
#     .select("raceId", "year")

# constructors = Window.partitionBy("constructorId")

# constructorWinners = spark.read.format("csv")\
#     .option("header", "true")\
#     .option("sep", ",")\
#     .load("../data/constructor_standings.csv")\
#     .join(lastRaces, ["raceId"], "right")\
#     .where(F.col("position") == 1)\
#     .select("constructorId", "wins", "year")\
#     .withColumn("totalChampWins", F.count(F.col("constructorId")).over(constructors))\
#     .withColumn("totalRaceWins", F.sum(F.col("wins")).over(constructors).cast(T.IntegerType()))\
#     .drop("wins")  

# constructors = spark.read.format("csv")\
#     .option("header", "true")\
#     .option("sep", ",")\
#     .load("../data/constructors.csv")\
#     .select("constructorId", "name")


# results = constructorWinners\
#     .drop("year")\
#     .dropDuplicates(["constructorId"])\
#     .join(constructors, "constructorId")\
#     .sort(F.col("totalChampWins").desc(), F.col("totalRaceWins").desc())\
#     .drop("constructorId")

In [None]:
# results.explain(extended=True)

In [None]:
# import time
# def current_milli_time():
#     return round(time.time() * 1000)

# def run():
#     start = current_milli_time()
#     results.collect()
#     return current_milli_time() - start

# def average(l):
#     return sum(l)/len(l)
    
# def time_test():
#     l = list()
#     for i in range(1):
#         l.append(run())
#     return average(l)

# res = time_test()

# print(res)