# Functions Shared Across Notebooks

In [None]:
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, FloatType

In [None]:
# Explicit column schema handed to the CSV reader in load_data().
# Each pair is (column header, Spark type class); the order here is the
# order of the StructFields in the resulting StructType.
schema = StructType([
    StructField(column_name, column_type())
    for column_name, column_type in (
        ('Rk', IntegerType),
        ('Player', StringType),
        ('Pos', StringType),
        ('Age', IntegerType),
        ('Tm', StringType),
        ('G', IntegerType),
        ('GS', IntegerType),
        ('MP', IntegerType),
        ('FG', IntegerType),
        ('FGA', IntegerType),
        ('FG%', FloatType),
        ('3P', IntegerType),
        ('3PA', IntegerType),
        ('3P%', FloatType),
        ('2P', IntegerType),
        ('2PA', IntegerType),
        ('2P%', FloatType),
        ('eFG%', FloatType),
        ('FT', IntegerType),
        ('FTA', IntegerType),
        ('FT%', FloatType),
        ('ORB', IntegerType),
        ('DRB', IntegerType),
        ('TRB', IntegerType),
        ('AST', IntegerType),
        ('STL', IntegerType),
        ('BLK', IntegerType),
        ('TOV', IntegerType),
        ('PF', IntegerType),
        ('PTS', IntegerType),
        ('Player-additional', StringType),
    )
])


In [None]:
def load_data(data_root):
    """Read every CSV matching ``data_root + "*.csv"`` into one DataFrame.

    The files are parsed with a header row and the module-level ``schema``.
    Relies on a ``spark`` session object being available in the enclosing
    (notebook) scope.

    Post-processing applied to the loaded rows:
      * ``Year``   - new string column holding the first run of four digits
                     found in the path of the file each row came from
                     (empty string when there is none).
      * ``Player`` - every literal asterisk is removed from the name.
      * ``Player-additional`` is renamed to ``PlayerId``.

    Args:
        data_root: Path prefix; ``"*.csv"`` is appended verbatim, so a
            directory must be given with its trailing separator.

    Returns:
        The cleaned DataFrame.
    """
    df = spark.read.format("csv").option("header", "true").schema(schema).load(data_root + "*.csv")

    # Temporarily expose the source file path so the year can be pulled from it.
    # NOTE(review): the pattern is applied to the whole path string, so a
    # four-digit run in a directory name would match before the file name.
    df = df.withColumn("file_name", F.input_file_name())
    df = df.withColumn("Year", F.regexp_extract(df["file_name"], r"\d{4}", 0))
    df = df.drop("file_name")

    # Raw string: "\*" in a plain literal is an invalid escape sequence
    # (SyntaxWarning on recent Python versions); the regex itself is unchanged.
    df = df.withColumn("Player", F.regexp_replace(df["Player"], r"\*", ""))
    df = df.withColumnRenamed("Player-additional", "PlayerId")

    return df

def player_summary(df, name):
    """Return the rows whose ``Player`` equals *name*, sorted by ``Year`` descending."""
    is_requested_player = df["Player"] == name
    newest_first = F.desc("Year")
    return df.filter(is_requested_player).orderBy(newest_first)

def format_avg(numerator_col, denominator_col, decimal_places=1):
    """Build a column expression for ``numerator / denominator``, formatted.

    Args:
        numerator_col: Column name (str) or an existing Column expression.
        denominator_col: Column name (str) or an existing Column expression.
        decimal_places: Number of decimals passed to ``F.format_number``.

    Returns:
        The Column produced by ``F.format_number`` applied to the quotient.
        ``format_number`` yields a formatted string, which is why
        ``career_totals`` casts these columns back to float afterwards.
    """
    # Accept either a name or a Column for both operands; previously only the
    # denominator was normalised and a Column numerator was passed to F.col().
    if isinstance(numerator_col, str):
        numerator_col = F.col(numerator_col)
    if isinstance(denominator_col, str):
        denominator_col = F.col(denominator_col)
    return F.format_number(numerator_col / denominator_col, decimal_places)

def per_game_avg(numerator_col, decimal_places=1):
    """Return ``format_avg`` of *numerator_col* divided by the ``"G"`` column.

    Args:
        numerator_col: Name of the column to average.
        decimal_places: Number of decimals to format to (default 1).
    """
    return format_avg(numerator_col, "G", decimal_places)

def per_36_avg(numerator_col, decimal_places=1):
    """Return ``format_avg`` of *numerator_col* divided by ``MP / 36``.

    Args:
        numerator_col: Name of the column to average.
        decimal_places: Number of decimals to format to (default 1).
    """
    # Dividing by (MP / 36) scales the total to a rate per 36 units of MP.
    return format_avg(numerator_col, (F.col("MP") / 36), decimal_places)


def career_totals(df):
    """Sum counting stats per ``PlayerId`` and attach derived averages.

    The result has one row per ``PlayerId`` with summed ``PTS``, ``TRB``,
    ``AST``, ``BLK``, ``G`` and ``MP``, followed by the per-game columns
    (``PPG``, ``RPG``, ``APG``, ``BPG``) and the per-36 columns (``PP36``,
    ``RP36``, ``AP36``, ``BP36``), all cast to float.
    """
    summed_stats = ["PTS", "TRB", "AST", "BLK", "G", "MP"]
    totals = df.groupBy("PlayerId").agg(
        *[F.sum(stat).alias(stat) for stat in summed_stats]
    )

    # (output column, averaging helper, summed column it is derived from)
    derived_columns = [
        ("PPG", per_game_avg, "PTS"),
        ("RPG", per_game_avg, "TRB"),
        ("APG", per_game_avg, "AST"),
        ("BPG", per_game_avg, "BLK"),
        ("PP36", per_36_avg, "PTS"),
        ("RP36", per_36_avg, "TRB"),
        ("AP36", per_36_avg, "AST"),
        ("BP36", per_36_avg, "BLK"),
    ]

    # The helpers return formatted values, so each one is cast to float as
    # it is added.
    for out_name, averager, stat in derived_columns:
        totals = totals.withColumn(out_name, averager(stat).cast("float"))

    return totals

def get_player_id(df, name):
    """Return a list of the distinct ``PlayerId`` values for rows whose ``Player`` is *name*."""
    matching_ids = (
        df.filter(df["Player"] == name)
        .select("PlayerId")
        .distinct()
    )
    # collect() brings the (distinct) ids back to the driver as Row objects.
    return [record.PlayerId for record in matching_ids.collect()]

def get_player_name(df, id):
    """Return a list of the distinct ``Player`` values for rows whose ``PlayerId`` is *id*."""
    matching_names = (
        df.filter(df["PlayerId"] == id)
        .select("Player")
        .distinct()
    )
    # collect() brings the (distinct) names back to the driver as Row objects.
    return [record.Player for record in matching_names.collect()]

def years_played(df):
    """Count distinct ``Year`` values per ``PlayerId``.

    Returns a DataFrame with columns ``PlayerId`` and ``Years_played``.
    """
    # De-duplicate (PlayerId, Year) pairs first so each year counts once.
    player_years = df.select(["PlayerId", "Year"]).distinct()
    year_count = F.count("Year").alias("Years_played")
    return player_years.groupBy("PlayerId").agg(year_count)

def players_in_cluster(df, cluster):
    """Return the rows of *df* whose ``prediction`` column equals *cluster*."""
    in_cluster = df["prediction"] == cluster
    return df.filter(in_cluster)

def players_cluster(df, id):
    """Return a list of the distinct ``prediction`` values for rows whose ``PlayerId`` is *id*."""
    assigned = (
        df.filter(df["PlayerId"] == id)
        .select("prediction")
        .distinct()
    )
    # collect() brings the (distinct) predictions back to the driver as Row objects.
    return [record.prediction for record in assigned.collect()]