In [3]:
from pyspark.sql import SparkSession
from model import SparkModelTrainer

# Initialize SparkSession
spark = SparkSession.builder \
    .appName("ModelTraining") \
    .getOrCreate()

# Example DataFrame creation (Replace this with your actual data)
# For demonstration, we'll create a simple DataFrame with random data
import pandas as pd
import numpy as np

# Create a pandas DataFrame
pandas_df = pd.read_csv('data/HI-Medium_Trans.csv')

# Convert to Spark DataFrame
df = spark.createDataFrame(pandas_df)

# **Exclude Specific Columns:**
# Drop 'Timestamp', 'Account', 'Account.1' as they are not needed for training
df = df.drop('Timestamp', 'Account', 'Account.1')

# **Define label, numerical, and categorical columns:**
label_col = 'Is Laundering'
numerical_cols = []  # No numerical columns to use as-is after binning
categorical_cols = ['From Bank', 'To Bank', 'Receiving Currency', 'Payment Currency', 'Payment Format']
binning_cols = ['Amount Received', 'Amount Paid']

# **Initialize the trainer:**
trainer = SparkModelTrainer(
    label_col=label_col,
    numerical_cols=numerical_cols,  # Empty list since all numerical columns are binned
    categorical_cols=categorical_cols,  # Initial categorical columns
    binning_cols=binning_cols,  # Columns to discretize
    num_buckets=5,  # Number of bins for discretization
    spark=spark
)

# **Prepare data:**
train_df, test_df = trainer.prepare_data(df, test_ratio=0.2, seed=42)

# **Determine Feature Vector Size Dynamically for ANN:**
# Fetch the first row of train_df to get the size of the feature vector
feature_vector = train_df.select("features").first()["features"]
input_feature_size = len(feature_vector)
print(f"Feature vector size: {input_feature_size}")

# **Define ANN layers based on feature_dim:**
# Example: [input, hidden1, hidden2, output]
layers = [input_feature_size, 20, 10, 2]
trainer.train_ann(train_df, layers=layers, max_iter=100)

# **Train Random Forest:**
trainer.train_random_forest(train_df, num_trees=100, max_depth=5)

# **Make predictions and evaluate Random Forest:**
rf_predictions = trainer.predict('random_forest', test_df)
rf_accuracy = trainer.evaluate('random_forest', rf_predictions, metric="accuracy")
rf_f1 = trainer.evaluate('random_forest', rf_predictions, metric="f1")
print(f"Random Forest Accuracy: {rf_accuracy:.4f}")
print(f"Random Forest F1 Score: {rf_f1:.4f}")

# **Make predictions and evaluate ANN:**
ann_predictions = trainer.predict('ann', test_df)
ann_f1 = trainer.evaluate('ann', ann_predictions, metric="f1")
ann_accuracy = trainer.evaluate('ann', ann_predictions, metric="accuracy")
print(f"ANN F1 Score: {ann_f1:.4f}")
print(f"ANN Accuracy: {ann_accuracy:.4f}")

# **Plot metrics:**
trainer.plot_metrics(metric="accuracy")
trainer.plot_metrics(metric="f1")

# **Alternatively, plot all metrics at once:**
trainer.plot_all_metrics()

# **Stop Spark session:**
trainer.stop_spark()

  for column, series in pdf.iteritems():


Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.readRDDFromFile.
: java.lang.OutOfMemoryError: Java heap space
	at org.apache.spark.api.java.JavaRDD$.readRDDFromInputStream(JavaRDD.scala:252)
	at org.apache.spark.api.java.JavaRDD$.readRDDFromFile(JavaRDD.scala:239)
	at org.apache.spark.api.python.PythonRDD$.readRDDFromFile(PythonRDD.scala:274)
	at org.apache.spark.api.python.PythonRDD.readRDDFromFile(PythonRDD.scala)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:750)


24/12/23 01:12:14 WARN HeartbeatReceiver: Removing executor driver with no recent heartbeats: 932912 ms exceeds timeout 120000 ms
24/12/23 01:12:14 WARN SparkContext: Killing executors is not supported by current scheduler.
