# In [None]:
def compute_neighbor_embeddings(input_df: pd.DataFrame, cfg: CFG) -> pd.DataFrame:
    """
    Create GNN-lite features summarizing nearby players at each player's last observed frame.

    For every (game_id, play_id, nfl_id) in ``input_df`` the player's last frame
    is located. Every other player present in the same play at that frame is
    paired with it. The ``cfg.K_NEIGHBORS`` closest neighbors within
    ``cfg.RADIUS_LIMIT`` are then summarized:

    * ``gnn_{ally,opp}_{dx,dy,dvx,dvy}_mean`` - relative position / velocity
      weighted by a softmax over ``-dist / cfg.TAU``. The softmax is normalized
      over all kept neighbors (allies and opponents together), so the ally and
      opponent weights of one player jointly sum to ~1.
    * ``gnn_{ally,opp}_count`` - number of kept neighbors on each side.
    * ``gnn_{ally,opp}_dist_{min,mean}`` - distance statistics per side.
    * ``gnn_dist_rank{1,2,3}`` - distance to the 1st/2nd/3rd nearest kept neighbor.

    Args:
        input_df: Tracking rows containing game_id, play_id, nfl_id, frame_id,
            x, y, velocity_x, velocity_y and player_side.
        cfg: Config object exposing RADIUS_LIMIT, K_NEIGHBORS and TAU.

    Returns:
        One row per (game_id, play_id, nfl_id) present in ``input_df``.
        When a player has no neighbor of a given kind, the weighted sums and
        counts are 0. The distance features are ``cfg.RADIUS_LIMIT``, meaning
        "nothing closer than the search radius".
    """
    print("Computing GNN-lite neighbor embeddings...")

    keys = ["game_id", "play_id", "nfl_id"]
    n_nearest = 3  # number of gnn_dist_rank{i} columns emitted (always all of them)

    # We only need specific columns for this calculation to save memory
    cols_needed = [
        "game_id", "play_id", "nfl_id", "frame_id", "x", "y",
        "velocity_x", "velocity_y", "player_side",
    ]
    src_df = input_df[cols_needed].copy()

    # State of each player at their last observed frame (one row per key).
    last_frame_df = (
        src_df.sort_values(keys + ["frame_id"])
              .groupby(keys, as_index=False)
              .tail(1)
              .rename(columns={"frame_id": "last_frame_id"})
              .reset_index(drop=True)
    )

    # Pair every ego player with every player present in the same play at the
    # ego's last frame: each row is (ego_player, neighbor_player).
    merged_df = last_frame_df.merge(
        src_df.rename(columns={
            "frame_id": "nb_frame_id", "nfl_id": "nfl_id_nb", "x": "x_nb", "y": "y_nb",
            "velocity_x": "vx_nb", "velocity_y": "vy_nb", "player_side": "player_side_nb",
        }),
        left_on=["game_id", "play_id", "last_frame_id"],
        right_on=["game_id", "play_id", "nb_frame_id"],
        how="left",
    )

    # Remove self-pairs. .copy() so the column assignments below act on an
    # independent frame rather than a filtered view (avoids SettingWithCopy).
    merged_df = merged_df[merged_df["nfl_id_nb"] != merged_df["nfl_id"]].copy()

    # Neighbor position / velocity relative to the ego player.
    merged_df["dx"] = merged_df["x_nb"] - merged_df["x"]
    merged_df["dy"] = merged_df["y_nb"] - merged_df["y"]
    merged_df["dvx"] = merged_df["vx_nb"] - merged_df["velocity_x"]
    merged_df["dvy"] = merged_df["vy_nb"] - merged_df["velocity_y"]
    merged_df["dist"] = np.sqrt(merged_df["dx"] ** 2 + merged_df["dy"] ** 2)

    # Keep neighbors inside the radius. Rows from an ego with no other player
    # at that frame carry NaN dist and are dropped here as well.
    merged_df = merged_df[merged_df["dist"] <= cfg.RADIUS_LIMIT].copy()

    # Same player_side => ally; everything else counts as opponent.
    merged_df["is_ally"] = (merged_df["player_side_nb"] == merged_df["player_side"]).astype(float)
    merged_df["is_opp"] = 1.0 - merged_df["is_ally"]

    # Rank neighbors by distance (ties broken by row order) and keep the K closest.
    merged_df["rank"] = merged_df.groupby(keys)["dist"].rank(method="first")
    merged_df = merged_df[merged_df["rank"] <= cfg.K_NEIGHBORS].copy()

    # --- Attention weighting: softmax of -dist/TAU over all kept neighbors ---
    merged_df["attention"] = np.exp(-merged_df["dist"] / cfg.TAU)
    attention_sum = merged_df.groupby(keys)["attention"].transform("sum")
    merged_df["norm_attention"] = merged_df["attention"] / (attention_sum + 1e-9)

    merged_df["norm_attention_ally"] = merged_df["norm_attention"] * merged_df["is_ally"]
    merged_df["norm_attention_opp"] = merged_df["norm_attention"] * merged_df["is_opp"]

    # Pre-multiply features by attention weights so a grouped sum gives the weighted aggregate.
    for col in ["dx", "dy", "dvx", "dvy"]:
        merged_df[f"{col}_w_ally"] = merged_df[col] * merged_df["norm_attention_ally"]
        merged_df[f"{col}_w_opp"] = merged_df[col] * merged_df["norm_attention_opp"]

    # Side-specific distances (NaN on the other side) for min/mean stats.
    merged_df["dist_ally"] = np.where(merged_df["is_ally"] > 0.5, merged_df["dist"], np.nan)
    merged_df["dist_opp"] = np.where(merged_df["is_ally"] < 0.5, merged_df["dist"], np.nan)

    # --- Aggregation ---
    agg_dict = {
        # Attention-weighted relative vectors
        "gnn_ally_dx_mean": ("dx_w_ally", "sum"),
        "gnn_ally_dy_mean": ("dy_w_ally", "sum"),
        "gnn_ally_dvx_mean": ("dvx_w_ally", "sum"),
        "gnn_ally_dvy_mean": ("dvy_w_ally", "sum"),
        "gnn_opp_dx_mean": ("dx_w_opp", "sum"),
        "gnn_opp_dy_mean": ("dy_w_opp", "sum"),
        "gnn_opp_dvx_mean": ("dvx_w_opp", "sum"),
        "gnn_opp_dvy_mean": ("dvy_w_opp", "sum"),
        # Counts and distance stats
        "gnn_ally_count": ("is_ally", "sum"),
        "gnn_ally_dist_min": ("dist_ally", "min"),
        "gnn_ally_dist_mean": ("dist_ally", "mean"),
        "gnn_opp_dist_min": ("dist_opp", "min"),
        "gnn_opp_dist_mean": ("dist_opp", "mean"),
        # Counted directly: fewer than K neighbors may lie within the radius,
        # so K_NEIGHBORS - ally_count would overstate the opponent count.
        "gnn_opp_count": ("is_opp", "sum"),
    }
    gnn_features = merged_df.groupby(keys).agg(**agg_dict).reset_index()

    # --- Distance to the N nearest players (regardless of side) ---
    # Always emit gnn_dist_rank1..N so every split shares the same schema.
    for r in range(1, n_nearest + 1):
        nth = merged_df.loc[merged_df["rank"] == r, keys + ["dist"]]
        gnn_features = gnn_features.merge(
            nth.rename(columns={"dist": f"gnn_dist_rank{r}"}), on=keys, how="left"
        )

    # Re-attach ego players with no neighbor inside the radius. They were
    # filtered out above but still need a feature row.
    gnn_features = last_frame_df[keys].merge(gnn_features, on=keys, how="left")

    # Missing distances mean "no such neighbor within the radius": fill with the
    # radius rather than 0 (0 would signal a neighbor right on top of the player).
    dist_cols = [
        "gnn_ally_dist_min", "gnn_ally_dist_mean",
        "gnn_opp_dist_min", "gnn_opp_dist_mean",
        *(f"gnn_dist_rank{r}" for r in range(1, n_nearest + 1)),
    ]
    gnn_features[dist_cols] = gnn_features[dist_cols].fillna(cfg.RADIUS_LIMIT)
    # Remaining NaNs are weighted sums / counts for absent neighbors -> 0.
    gnn_features = gnn_features.fillna(0)

    print("GNN-lite embeddings computed.")
    return gnn_features

# Build neighbor features for the training split and preview the first five rows.
gnn_train_features = compute_neighbor_embeddings(input_df=train_input_df, cfg=CFG)
display(gnn_train_features.head(5))
