In [None]:
from pyspark.sql import functions as sf, DataFrame, SparkSession
from abc import ABCMeta, abstractmethod
from typing import List, Set, Dict, Optional
from dataclasses import dataclass

from pyspark.sql import Column, DataFrame
from pyspark.sql import functions as sf

In [None]:
# Shared Spark session used by every cell below: reuses the active session
# if one exists, otherwise creates a new one.
spark = SparkSession.builder.getOrCreate()

In [None]:
class FeatureGroup(metaclass=ABCMeta):
    """Base class for a group of features that are read from one source.

    Subclasses set the class attributes below and implement ``_read``; they
    may override ``_transform`` to reshape the raw source before it is joined
    onto the base data.

    Class attributes:
        alias: Alias given to the source DataFrame when it is joined, so that
            feature queries can address its columns as ``"<alias>.<column>"``.
        source: Location of the underlying data (not read by this base class).
        keys: Names of the columns the source is joined on. A single column
            name given as a bare string is accepted as well.
        supported_levels: Levels for which ``apply`` may be called.
        available_features: Mapping of feature name to its ``Feature``.
    """

    alias: str = ...
    source: str = ...
    keys: List[str] = ...
    supported_levels: Set[str] = set()
    available_features: Dict[str, "Feature"] = {}
    # TODO: declare dependencies between groups, e.g. depend_on: List["FeatureGroup"]

    def __init__(self, features: List[str]):
        """Choose which of ``available_features`` this instance provides.

        Args:
            features: Names of the requested features; each is looked up in
                ``available_features`` when ``selections`` is evaluated.
        """
        self.features = features

    def apply(self, data: DataFrame, level: str) -> DataFrame:
        """Left-join this group's (transformed) source onto ``data``.

        Args:
            data: DataFrame that must already contain every join key.
            level: Requested level; must be one of ``supported_levels``.

        Returns:
            ``data`` left-joined with the aliased source on ``keys``.

        Raises:
            KeyError: If a join key is missing from ``data``.
            ValueError: If ``level`` is not supported by this group.
        """
        # Iterating a bare string would test every *character* against the
        # column names, so wrap a single key name in a list first.
        keys = [self.keys] if isinstance(self.keys, str) else list(self.keys)
        missing = [key for key in keys if key not in data.columns]
        if missing:
            raise KeyError(
                f"{type(self).__name__} needs join key(s) {missing}, "
                f"but data only has columns {list(data.columns)}"
            )
        if level not in self.supported_levels:
            raise ValueError(
                f"{type(self).__name__} does not support level {level!r}; "
                f"supported levels: {sorted(self.supported_levels)}"
            )
        source = self._transform(self._read(), level=level)
        return data.join(source.alias(self.alias), on=keys, how="left")

    @abstractmethod
    def _read(self) -> DataFrame:
        """Return the raw source DataFrame for this group."""

    @staticmethod
    def _transform(data, level: str) -> DataFrame:
        """Reshape the raw source before joining; the default is a no-op."""
        return data

    @property
    def selections(self) -> List[Column]:
        """Select expressions for the requested features, in request order.

        Raises:
            KeyError: If a requested feature is not in ``available_features``.
        """
        unknown = [name for name in self.features if name not in self.available_features]
        if unknown:
            raise KeyError(
                f"{type(self).__name__} has no feature(s) {unknown}; "
                f"available: {sorted(self.available_features)}"
            )
        return [self.available_features[name].select for name in self.features]

In [None]:
@dataclass
class Feature:
    """A named feature: a column expression plus an optional fallback value.

    Attributes:
        name: Name of the output column.
        query: Column expression that computes the feature.
        default: Fallback used where ``query`` is null. It is wrapped in
            ``sf.lit`` before use, so a plain Python literal is passed as
            well as a Column in this notebook. ``None`` means "no fallback".
    """

    name: str
    query: Column
    default: Optional[Column] = None

    @property
    def select(self) -> Column:
        """Return the feature expression aliased to ``name``.

        The check is ``is None`` rather than truthiness, so falsy defaults
        such as ``False`` or ``0`` still get applied.
        """
        if self.default is None:
            return self.query.alias(self.name)
        fallback = sf.lit(self.default)
        return sf.coalesce(self.query, fallback).alias(self.name)

In [None]:
class LocationFeatures(FeatureGroup):
    """Party-level features derived from each party's postcode."""

    alias = "location"
    source = "some/path/to/weather"
    keys = ["party_id"]
    supported_levels = {"party"}
    available_features = {
        feature.name: feature
        for feature in (
            Feature("postcode", sf.col("location.postcode")),
            # True where the postcode is the empty string; a null comparison
            # falls back to the default of False.
            Feature("is_foreign", sf.col("location.postcode") == sf.lit(""), False),
        )
    }

    def _read(self) -> DataFrame:
        """Return a small in-memory table mapping ``party_id`` to ``postcode``."""
        postcode_by_party = (("a", "1234AB"), ("b", "1234BC"), ("c", ""))
        rows = [
            dict(party_id=party_id, postcode=postcode)
            for party_id, postcode in postcode_by_party
        ]
        return spark.createDataFrame(rows)
    
    
class DemographicFeatures(FeatureGroup):
    """Demographic features keyed by postcode rather than by party."""

    alias = "demo"
    source = ".."
    keys = ["postcode"]
    supported_levels = {"party"}
    available_features = dict(
        # Postcodes without a demographic row get n_people = 1.
        n_people=Feature("n_people", sf.col("demo.n_people"), default=1),
    )

    def _read(self) -> DataFrame:
        """Return a small in-memory table of ``n_people`` per ``postcode``."""
        people_per_postcode = {"1234AB": 1, "1234BC": 4, "3456GF": 8}
        return spark.createDataFrame([
            {"postcode": postcode, "n_people": n_people}
            for postcode, n_people in people_per_postcode.items()
        ])


class TransactionFeatures(FeatureGroup):
    """Per-party transaction counts."""

    alias = "tx"
    keys = ["party_id"]
    supported_levels = {"party"}
    available_features = dict(
        n_transactions=Feature("n_transactions", sf.col("tx.n_transactions")),
    )

    def _read(self) -> DataFrame:
        """Return a small in-memory transaction log, one row per transaction."""
        transactions = (
            ("a", 1, "b"),
            ("a", 2, "e"),
            ("b", 3, "e"),
            ("c", 4, "d"),
            ("c", 5, "a"),
        )
        return spark.createDataFrame([
            dict(party_id=party_id, tx_id=tx_id, counterparty=counterparty)
            for party_id, tx_id, counterparty in transactions
        ])

    @staticmethod
    def _transform(data, level: str) -> DataFrame:
        """Collapse the transaction log to one row per ``party_id`` with its row count.

        ``level`` is accepted for interface compatibility and is not used.
        """
        n_transactions = sf.count("*").alias("n_transactions")
        return data.groupby("party_id").agg(n_transactions)
    
    
class SecondaryPartyFeatures(FeatureGroup):
    """Features about the counterparties a party has transacted with."""

    alias = "stx"
    keys = ["party_id"]
    supported_levels = {"party"}
    available_features = dict(
        traded_with_foreign=Feature("traded_with_foreign", sf.col("stx.traded_with_foreign")),
    )

    @staticmethod
    def _transform(data: DataFrame, level: str) -> DataFrame:
        """Flag, per party, whether any of its counterparties is foreign.

        The transaction log is re-keyed on the counterparty so that
        ``LocationFeatures`` describes the party traded *with*; the resulting
        ``is_foreign`` flag is then max-aggregated back per original party.
        """
        location = LocationFeatures(features=["is_foreign"])

        # Swap roles: keep the owner as original_party_id, expose the
        # counterparty under the join key that LocationFeatures expects.
        by_counterparty = (
            data
            .withColumnRenamed("party_id", "original_party_id")
            .withColumnRenamed("counterparty", "party_id")
        )
        with_location = by_counterparty.transform(location.apply, level=level)
        flagged = with_location.select(*location.selections, "original_party_id")

        per_party = flagged.groupby(sf.col("original_party_id").alias("party_id"))
        return per_party.agg(sf.max("is_foreign").alias("traded_with_foreign"))

    def _read(self) -> DataFrame:
        """Return a small in-memory transaction log, one row per transaction."""
        transactions = (
            ("a", 1, "b"),
            ("a", 2, "c"),
            ("b", 3, "e"),
            ("c", 4, "d"),
            ("c", 5, "a"),
        )
        return spark.createDataFrame([
            dict(party_id=party_id, tx_id=tx_id, counterparty=counterparty)
            for party_id, tx_id, counterparty in transactions
        ])


In [None]:
def create_features(level: str, feature_groups: List[FeatureGroup]) -> DataFrame:
    """Join every feature group onto the base table and select its features.

    Args:
        level: Level passed to each group's ``apply``; ``f"{level}_id"`` is
            also the key column kept in the result.
        feature_groups: Groups to apply, in order. A group whose join keys
            come from another group's source must be listed after it.

    Returns:
        A DataFrame with the ``<level>_id`` column followed by the selected
        features of each group, in the order the groups were given.
    """
    # Base table for the demo: one row per party.
    data: DataFrame = spark.createDataFrame([dict(party_id="a"), dict(party_id="b"), dict(party_id="c")])
    columns: List[Column] = []
    for feature_group in feature_groups:
        # TODO: apply each group's dependencies first once FeatureGroup declares them.
        data = feature_group.apply(data, level=level)
        columns.extend(feature_group.selections)
    return data.select(f"{level}_id", *columns)


# Build the party-level feature table from all four demo groups.
# Order matters: DemographicFeatures joins on `postcode`, a column that only
# exists after LocationFeatures has been joined in.
features = create_features(
    level="party",
    feature_groups=[
        LocationFeatures(features=["postcode"]),
        TransactionFeatures(features=["n_transactions"]),
        SecondaryPartyFeatures(features=["traded_with_foreign"]),
        DemographicFeatures(features=["n_people"]),
    ],
)


In [None]:
# Collect the feature table into a pandas DataFrame so the notebook renders it.
features.toPandas()

In [None]:

# Load the `wow.parquet` dataset (path is relative to the kernel's working
# directory) and preview its first ten rows as a pandas DataFrame.
wow = spark.read.parquet("wow.parquet")
wow.limit(10).toPandas()

In [None]:
# List the distinct values of the `zone` column (collected to the driver).
wow.select("zone").distinct().collect()

In [None]:
# Per avatar: total number of rows, number of rows whose zone is " Durotar",
# and the ratio of the two.
# NOTE(review): the zone literal starts with a space — presumably that matches
# the raw values in the data; confirm against the distinct zones listed above.
(
    wow
    .select(
        "avatarId",
        sf.when(sf.col("zone") == " Durotar", 1).otherwise(0).alias("durotar_flag"),
    )
    .groupBy("avatarId")
    .agg(
        sf.count("*").alias("total_count"),
        sf.sum("durotar_flag").alias("durotar_count"),
    )
    .select(
        "avatarId",
        "total_count",
        "durotar_count",
        (sf.col("durotar_count") / sf.col("total_count")).alias("durotar_prob"),
    )
    .show()
)