In [1]:
import os
# Set JAVA_HOME to Java 17 which is already installed.
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-17-openjdk-amd64"
os.environ["PATH"] += os.pathsep + os.path.join(os.environ["JAVA_HOME"], "bin")

#Install the required libraries
!pip install pyspark
# Initialize a Spark session
from pyspark.sql import SparkSession

# Shut down a session left over from an earlier run of this cell before
# building a new one, so the builder settings below are applied to a fresh session.
previous_session = locals().get("spark")
if previous_session is not None:
    previous_session.stop()

# Session-level settings, applied to the builder one by one.
# NOTE(review): with a local[*] master the executor memory setting presumably
# has little effect, and driver memory is normally fixed when the JVM starts;
# confirm these values are actually honoured after a restart.
session_settings = {
    "spark.driver.memory": "6g",
    "spark.executor.memory": "6g",
    "spark.sql.shuffle.partitions": "8",
}

session_builder = SparkSession.builder.appName("05_association_rules_pyspark").master("local[*]")
for setting_name, setting_value in session_settings.items():
    session_builder = session_builder.config(setting_name, setting_value)

spark = session_builder.getOrCreate()

# Keep the notebook output readable: only warnings and errors from Spark.
spark.sparkContext.setLogLevel("WARN")



In [2]:
import argparse
import math
import os

import pyspark.sql.functions as F
from pyspark.ml.feature import QuantileDiscretizer
from pyspark.ml.fpm import FPGrowth
from pyspark.sql import SparkSession

# Notebook parameters are declared as CLI-style options so the same code can
# be lifted into a script; each entry is (flag, add_argument keyword options).
parser = argparse.ArgumentParser()
for flag, options in (
    ("--featured_parquet", {"default": "/content/data/featured.paraquet"}),
    ("--out_dir", {"default": "/content/results"}),
    ("--num_bins", {"type": int, "default": 3}),
    ("--min_support", {"type": float, "default": 0.01}),
):
    parser.add_argument(flag, **options)

# Parse an explicit empty argument list instead of sys.argv, so only the
# defaults declared above are used inside the notebook.
args = parser.parse_args(args=[])

In [3]:
# Read the feature table from the configured parquet location and report its
# size (count() triggers a Spark job over the data).
parquet_reader = spark.read
df = parquet_reader.parquet(args.featured_parquet)
print(f"Loaded df count: {df.count()}")

Loaded df count: 2499784


In [4]:
# Collect the names of all numeric columns from the (name, type-string) pairs
# in df.dtypes. Decimal columns are reported with precision and scale (e.g.
# "decimal(10,2)"), so they are matched by prefix rather than by equality;
# an exact comparison against 'decimal' never matches them.
_NUMERIC_DTYPES = ('double', 'int', 'long', 'float', 'bigint', 'tinyint', 'smallint')
numeric_cols = [
    c for c, t in df.dtypes
    if t in _NUMERIC_DTYPES or t.startswith('decimal')
]

In [5]:
# Exclude the 'y_binary' column from the numeric candidates.
# NOTE(review): presumably this is the target label -- confirm.
numeric_cols = [name for name in numeric_cols if name != 'y_binary']

In [6]:
# Estimate the variance of every numeric column on a 5% sample (without
# replacement, fixed seed) and keep the 10 columns with the largest variance.
sample = df.sample(False, 0.05, seed=42)
stats = sample.select([F.variance(c).alias(c) for c in numeric_cols]).collect()[0].asDict()


def _variance_sort_key(item):
    """Sort key for (column, variance) pairs: largest variance first.

    Variances that are missing (None) or NaN cannot be ranked, so they are
    placed after every usable value. NaN must be handled explicitly because
    comparisons against NaN are always False, which would otherwise leave the
    sort order undefined.
    """
    _, variance = item
    if variance is None or math.isnan(variance):
        return (True, 0.0)
    return (False, -float(variance))


sorted_vars = sorted(stats.items(), key=_variance_sort_key)
chosen = [c for c, _ in sorted_vars[:10]]
print("Chosen numeric cols for discretization:", chosen)

Chosen numeric cols for discretization: ['Flow Duration', 'Fwd IAT Total', 'Bwd IAT Total', 'Fwd IAT Max', 'Flow IAT Max', 'Bwd IAT Max', 'Fwd IAT Std', 'Fwd IAT Mean', 'Bwd IAT Mean', 'Fwd IAT Min']


In [7]:
# Quantile-bin every chosen column and turn each bin index into a string item
# such as "<column>_bin_0", which is the item format fed to FP-Growth later.
bin_index_cols = [c + "_bin" for c in chosen]
binned_cols = [c + "_bin_item" for c in chosen]

if chosen:
    # Remove outputs left by a previous run of this cell; without this a
    # re-run fails because the output columns already exist on df.
    df = df.drop(*bin_index_cols, *binned_cols)

    # A single multi-column discretizer computes the quantile splits for all
    # chosen columns in one fit instead of one fit per column.
    discretizer = QuantileDiscretizer(
        numBuckets=args.num_bins,
        inputCols=chosen,
        outputCols=bin_index_cols,
        relativeError=0.01,
    )
    df = discretizer.fit(df).transform(df)

    # Map numeric bin indices to string items like col_bin_0.
    for c, index_col, item_col in zip(chosen, bin_index_cols, binned_cols):
        df = df.withColumn(
            item_col,
            F.concat(F.lit(c + "_bin_"), F.col(index_col).cast("int").cast("string")),
        )

In [8]:
# A column counts as a dummy when its name ends with "_ohe", or when it is an
# 'int' column holding exactly two distinct values.
# NOTE(review): the distinct count launches one Spark job per candidate int
# column, and the following cell reassigns dummy_cols to [] -- confirm whether
# this cell is still needed.
dummy_cols = []
for col_name, col_type in df.dtypes:
    if col_name.endswith("_ohe"):
        dummy_cols.append(col_name)
    elif col_type == 'int' and df.select(col_name).distinct().count() == 2:
        dummy_cols.append(col_name)

In [9]:
# No dummy columns are used from here on: the list is reset to empty and only
# the number of binned item columns is reported.
dummy_cols = list()
print(f"binned columns items count: {len(binned_cols)}")

binned columns items count: 10


In [10]:
# Build one transaction per row: an "items" array holding the row's binned
# item strings.
# Rows with a missing item are dropped BEFORE the array is assembled. An array
# built by F.array is itself never null (only its elements can be), so calling
# na.drop() on the finished "items" column would not remove anything.
df_items = (
    df.select(*binned_cols)
    .na.drop(how="any")
    .select(F.array(*[F.col(c) for c in binned_cols]).alias("items"))
)
print("Prepared transactions count:", df_items.count())

Prepared transactions count: 2499784


In [11]:
# Configure FP-Growth on the "items" column: the support threshold comes from
# the notebook arguments, the confidence threshold for rules is fixed at 0.6.
fpgrowth_params = {
    "itemsCol": "items",
    "minSupport": args.min_support,
    "minConfidence": 0.6,
}
fp = FPGrowth(**fpgrowth_params)

# Fit on the prepared transactions and expose both result tables.
model = fp.fit(df_items)
freq_itemsets, rules = model.freqItemsets, model.associationRules

In [12]:
# Make sure the output directory exists, then write each result table as a
# single-partition parquet dataset, replacing any earlier output.
os.makedirs(args.out_dir, exist_ok=True)
for result_frame, result_name in (
    (freq_itemsets, "fp_freq_itemsets.parquet"),
    (rules, "fp_rules.parquet"),
):
    result_path = os.path.join(args.out_dir, result_name)
    result_frame.coalesce(1).write.mode("overwrite").parquet(result_path)
print("Saved frequent itemsets and rules to results/")

Saved frequent itemsets and rules to results/


In [13]:
# Export the 100 rules with the highest confidence as a CSV file: the rows are
# collected to the driver as a pandas DataFrame and written without the index.
top_rules_csv = os.path.join(args.out_dir, "association_rules_top100.csv")
top_rules_pdf = rules.orderBy(F.desc("confidence")).limit(100).toPandas()
top_rules_pdf.to_csv(top_rules_csv, index=False)
print("Saved association_rules_top100.csv")

Saved association_rules_top100.csv


In [None]:
# Stop the Spark session now that all results have been written.
spark.stop()