In [0]:
from pyspark.sql import functions as F, types as T
from pyspark.sql.window import Window
from delta.tables import DeltaTable
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import *
from functools import reduce

# Functions

In [0]:
def write_to_table(
    df: DataFrame,
    table_name: str,
    mode: str = "overwrite",
    merge_schema: bool = False,
    partition_by: list[str] = None,
    path: str = None,
    save_as_table: bool = True
) -> None:
    """
    Write a DataFrame in Delta format, stamping it with a load timestamp.

    A ``last_updated`` column set to ``current_timestamp()`` is added to the
    data before it is written.

    Parameters:
    - df (DataFrame): Spark DataFrame to write.
    - table_name (str): Target table name; only used when save_as_table is True.
    - mode (str): Spark write mode ('overwrite', 'append', 'ignore', 'error', etc.).
    - merge_schema (bool): If True, the ``mergeSchema`` write option is set.
      If False and mode is 'overwrite', ``overwriteSchema`` is set instead.
    - partition_by (list[str], optional): Columns to partition the output by.
    - path (str, optional): Destination path; only used when save_as_table is False.
    - save_as_table (bool): If True, write with ``saveAsTable(table_name)``;
      otherwise write to ``path``.

    Raises:
    - ValueError: If save_as_table is False and no path is provided.
    """
    stamped_df = df.withColumn("last_updated", F.current_timestamp())
    writer = stamped_df.write.format("delta").mode(mode)

    # At most one schema option is applied: an explicit merge request wins,
    # otherwise a plain overwrite is allowed to replace the table schema.
    schema_option = None
    if merge_schema:
        schema_option = "mergeSchema"
    elif mode == "overwrite":
        schema_option = "overwriteSchema"
    if schema_option is not None:
        writer = writer.option(schema_option, "true")

    if partition_by:
        writer = writer.partitionBy(*partition_by)

    # Destination: table name takes precedence over path.
    if save_as_table:
        writer.saveAsTable(table_name)
        return
    if not path:
        raise ValueError("Either save_as_table must be True or a path must be provided.")
    writer.save(path)

In [0]:
def merge_to_table(
    df: DataFrame,
    table_name: str,
    merge_condition: str,
    spark: SparkSession,
    partition_by: list[str] = None
) -> None:
    """
    Upsert a DataFrame into a Delta table, creating the table on first use.

    Parameters:
    - df (DataFrame): Incoming rows to merge.
    - table_name (str): Target Delta table name.
    - merge_condition (str): SQL match condition. The existing table is aliased
      as ``target`` and the incoming rows as ``source``.
    - spark (SparkSession): Active Spark session.
    - partition_by (list[str], optional): Partition columns, used only when the
      table has to be created.

    When the table exists, matched rows are updated with all source columns and
    unmatched source rows are inserted. When it does not exist, the rows are
    written with write_to_table using its default mode.
    """
    stamped_df = df.withColumn("last_updated", F.current_timestamp())

    if spark.catalog.tableExists(table_name):
        target_table = DeltaTable.forName(spark, table_name).alias("target")
        merge_builder = target_table.merge(
            source=stamped_df.alias("source"),
            condition=merge_condition
        )
        merge_builder.whenMatchedUpdateAll().whenNotMatchedInsertAll().execute()
        return

    # First load: nothing to merge into yet, so create the table from the data.
    write_to_table(
        df=stamped_df,
        table_name=table_name,
        partition_by=partition_by
    )

# Variables

In [0]:
def _widget_or_default(widget_name: str, default: str) -> str:
    """Return the named notebook widget's value, or `default` if it cannot be read."""
    try:
        return dbutils.widgets.get(widget_name)
    except Exception:
        return default


# Run parameters, falling back to defaults when the widgets are unavailable.
ENV = _widget_or_default("ENV", "dev")
PROTOCOL = _widget_or_default("PROTOCOL", "HIST")

# Accepted values for the run parameters.
valid_envs = {"dev", "test", "prod"}
valid_protocols = {"HIST", "INCR"}

# NOTE(review): dbutils.notebook.exit ends the run with a return value rather
# than raising; confirm whether a scheduled run should instead fail (raise)
# when the configuration is invalid.
if ENV not in valid_envs:
    print(f"Invalid ENV: {ENV}. Must be one of {valid_envs}. Exiting notebook.")
    dbutils.notebook.exit("Invalid ENV")

if PROTOCOL not in valid_protocols:
    print(f"Invalid PROTOCOL: {PROTOCOL}. Must be one of {valid_protocols}. Exiting notebook.")
    dbutils.notebook.exit("Invalid PROTOCOL")

# Environment-specific schema names and the rolling window length (in matches).
silver_schema = f"fpl_silver_{ENV}"
feature_schema = f"fpl_feature_{ENV}"
rolling_window_size = 5

In [0]:
# Inspect the schema of the silver teams table.
teams_table_name = f"{silver_schema}.teams"
df = spark.table(teams_table_name)
df.printSchema()

root
 |-- season_key: integer (nullable = true)
 |-- team_key: integer (nullable = true)
 |-- team_id: integer (nullable = true)
 |-- team_name: string (nullable = true)
 |-- team_name_short: string (nullable = true)
 |-- is_promoted: boolean (nullable = true)
 |-- is_relegated: boolean (nullable = true)
 |-- last_updated: timestamp (nullable = true)



# Load Source Tables

In [0]:
# Fixtures are restricted to rows that already have a home score recorded
# (presumably matches that have been played — confirm against the silver layer).
fixtures_df = (
    spark.table(f"{silver_schema}.fixtures")
    .where(F.col("home_team_score").isNotNull())
)

# Team dimension and per-gameweek stats are loaded unfiltered.
teams_df = spark.table(f"{silver_schema}.teams")
stats_df = spark.table(f"{silver_schema}.gameweek_stats")

# Team Feature Engineering

In [0]:
# Aggregate expected goals (xG), expected assists (xA) and the
# exp_stats_available flag to one row per fixture and team.
team_xg_xa_df = (
    stats_df
    .groupBy("fixture_key", "team_key")
    .agg(
        F.sum("expected_goals").alias("team_expected_goals"),
        F.sum("expected_assists").alias("team_expected_assists"),
        F.max("exp_stats_available").alias("team_exp_stats_available"),
    )
    .withColumn(
        "team_expected_goal_involvements",
        F.col("team_expected_goals") + F.col("team_expected_assists"),
    )
)

# The same aggregates re-labelled from the opponent's point of view, so they
# can be joined back on (fixture_key, opponent_team_key) as "against" values.
opponent_column_names = {
    "team_key": "opponent_team_key",
    "team_expected_goals": "expected_goals_against",
    "team_expected_assists": "expected_assists_against",
    "team_expected_goal_involvements": "expected_goal_involvements_against",
}
opponent_xg_xa_df = team_xg_xa_df.select(
    "fixture_key",
    *[F.col(source_name).alias(target_name) for source_name, target_name in opponent_column_names.items()],
)

# Reshape fixtures into team-level records: every fixture yields one row from
# the home side's perspective and one from the away side's perspective.
shared_fixture_columns = ["fixture_key", "season_key", "gameweek_key"]

home_df = fixtures_df.select(
    *shared_fixture_columns,
    F.col("home_team_key").alias("team_key"),
    F.col("away_team_key").alias("opponent_team_key"),
    F.lit(True).alias("is_home"),
    F.col("home_team_score").alias("goals_for"),
    F.col("away_team_score").alias("goals_against"),
)

away_df = fixtures_df.select(
    *shared_fixture_columns,
    F.col("away_team_key").alias("team_key"),
    F.col("home_team_key").alias("opponent_team_key"),
    F.lit(False).alias("is_home"),
    F.col("away_team_score").alias("goals_for"),
    F.col("home_team_score").alias("goals_against"),
)

# Stack both perspectives; columns are matched by name.
team_fixtures_df = home_df.unionByName(away_df)

# Match-level metrics: goal difference and points (3 for a win, 1 for a draw,
# 0 otherwise).
goals_for_col = F.col("goals_for")
goals_against_col = F.col("goals_against")
match_points_col = (
    F.when(goals_for_col > goals_against_col, F.lit(3))
    .when(goals_for_col == goals_against_col, F.lit(1))
    .otherwise(F.lit(0))
)

# Attach the team's own xG/xA aggregates and the opponent's ("against") ones.
# Left joins keep fixtures that have no matching stats rows.
team_fixtures_df = (
    team_fixtures_df
    .withColumn("goal_diff", goals_for_col - goals_against_col)
    .withColumn("match_points", match_points_col)
    .join(team_xg_xa_df, on=["fixture_key", "team_key"], how="left")
    .join(opponent_xg_xa_df, on=["fixture_key", "opponent_team_key"], how="left")
)

# Rolling window: the current match plus the previous (rolling_window_size - 1)
# matches of the same team within a season.
#
# fixture_key is used as a secondary sort key. The frame is row-based, so rows
# that share a gameweek_key (a team with more than one fixture in a gameweek)
# would otherwise have no defined relative order and the rolling values could
# differ between runs. When gameweek_key is unique per team and season the
# result is unchanged.
# NOTE(review): assumes fixture_key is an acceptable tie-break order — confirm,
# or use a kickoff timestamp if the fixtures table provides one.
# NOTE(review): the frame ends at the current row, so each row's rolling values
# include that fixture's own result — confirm this is intended for downstream use.
rolling_window = (
    Window.partitionBy("team_key", "season_key")
    .orderBy("gameweek_key", "fixture_key")
    .rowsBetween(-rolling_window_size + 1, 0)
)


def _rolling_sum(column):
    """Sum a column (name or Column expression) over the rolling window."""
    return F.sum(column).over(rolling_window)


def _rolling_avg(rolling_total: str):
    """Divide a rolling total by rolling_games_played, rounded to 3 decimals."""
    return F.round(F.col(rolling_total) / F.col("rolling_games_played"), 3)


# Points counted only for home (resp. away) matches inside the window.
home_points_col = F.when(F.col("is_home"), F.col("match_points")).otherwise(0)
away_points_col = F.when(~F.col("is_home"), F.col("match_points")).otherwise(0)

team_fixtures_df = (
    team_fixtures_df
    .withColumn("rolling_points", _rolling_sum("match_points"))
    .withColumn("rolling_goal_diff", _rolling_sum("goal_diff"))
    .withColumn("home_rolling_points", _rolling_sum(home_points_col))
    .withColumn("away_rolling_points", _rolling_sum(away_points_col))
    .withColumn("rolling_team_expected_goals", _rolling_sum("team_expected_goals"))
    .withColumn("rolling_team_expected_assists", _rolling_sum("team_expected_assists"))
    .withColumn("rolling_team_expected_goal_involvements", _rolling_sum("team_expected_goal_involvements"))
    .withColumn("rolling_expected_goals_against", _rolling_sum("expected_goals_against"))
    .withColumn("rolling_expected_assists_against", _rolling_sum("expected_assists_against"))
    .withColumn("rolling_expected_goal_involvements_against", _rolling_sum("expected_goal_involvements_against"))
    # Same quantity as rolling_goal_diff; computed once and kept under both
    # names because both column names are selected for the output table.
    .withColumn("rolling_goal_difference", F.col("rolling_goal_diff"))
    # Number of matches actually inside the window (fewer than the window size
    # at the start of a season).
    .withColumn("rolling_games_played", F.count("fixture_key").over(rolling_window))
    .withColumn("avg_team_expected_goals", _rolling_avg("rolling_team_expected_goals"))
    .withColumn("avg_team_expected_assists", _rolling_avg("rolling_team_expected_assists"))
    .withColumn("avg_team_expected_goal_involvements", _rolling_avg("rolling_team_expected_goal_involvements"))
    .withColumn("avg_expected_goals_against", _rolling_avg("rolling_expected_goals_against"))
    .withColumn("avg_expected_assists_against", _rolling_avg("rolling_expected_assists_against"))
    .withColumn("avg_expected_goal_involvements_against", _rolling_avg("rolling_expected_goal_involvements_against"))
    .withColumn("avg_goal_difference", _rolling_avg("rolling_goal_difference"))
)

# Team metadata is joined on both team_key and season_key; a left join keeps
# fixture rows that have no matching teams row.
team_metadata_df = teams_df.select(
    "team_key", "team_name", "team_name_short", "is_promoted", "is_relegated", "season_key"
)

# Output columns, in the order they appear in the feature table.
final_feature_columns = [
    # identifiers
    "team_key", "team_name", "team_name_short", "season_key", "gameweek_key",
    # match result
    "is_home", "goals_for", "goals_against", "goal_diff", "match_points",
    # match-level expected stats (for / against)
    "team_expected_goals", "team_expected_assists", "team_expected_goal_involvements",
    "expected_goals_against", "expected_assists_against", "expected_goal_involvements_against",
    "team_exp_stats_available",
    # rolling totals
    "rolling_points", "rolling_goal_diff", "home_rolling_points", "away_rolling_points",
    "rolling_team_expected_goals", "rolling_team_expected_assists", "rolling_team_expected_goal_involvements",
    "rolling_expected_goals_against", "rolling_expected_assists_against", "rolling_expected_goal_involvements_against",
    "rolling_goal_difference", "rolling_games_played",
    # rolling per-game averages
    "avg_team_expected_goals", "avg_team_expected_assists", "avg_team_expected_goal_involvements",
    "avg_expected_goals_against", "avg_expected_assists_against", "avg_expected_goal_involvements_against",
    "avg_goal_difference",
    # team status flags
    "is_promoted", "is_relegated",
]

team_features_df = (
    team_fixtures_df
    .join(team_metadata_df, on=["team_key", "season_key"], how="left")
    .select(*final_feature_columns)
)

# Persist the team features, replacing the previous contents of the table.
team_features_table = f"{feature_schema}.team_features"
write_to_table(df=team_features_df, table_name=team_features_table, mode="overwrite")