In [None]:
 from pyspark.sql import SparkSession
from pyspark.sql.functions import col, avg, struct, collect_list, udf
from pyspark.ml.linalg import SparseVector, VectorUDT
from pyspark.sql.types import ArrayType, StructType, StructField, IntegerType, DoubleType

# Step 1: Initialize SparkSession with updated configurations
spark = SparkSession.builder \
    .appName("AmazonReview") \
    .config("spark.driver.memory", "16g") \
    .config("spark.executor.memory", "24g") \
    .config("spark.executor.cores", "4") \
    .config("spark.driver.maxResultSize", "2g") \
    .config("spark.executor.memoryOverhead", "1g") \
    .config("spark.sql.shuffle.partitions", "400") \
    .config("spark.shuffle.spill.compress", "true") \
    .config("spark.memory.storageFraction", "0.2") \
    .config("spark.memory.fraction", "0.8") \
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
    .config("spark.executor.extraJavaOptions", "-Xss4m") \
    .config("spark.driver.extraJavaOptions", "-Xss4m") \
    .config("spark.local.dir", "/scratch/szele/") \
    .getOrCreate()

spark.sparkContext.setCheckpointDir("/scratch/szele/tmp/spark_checkpoint")

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/12/12 10:31:27 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/12/12 10:31:27 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).
24/12/12 10:31:28 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [None]:
from pyspark.sql.functions import col
from pyspark.sql.functions import concat_ws, regexp_replace, lower, when, col
from pyspark.ml.feature import Tokenizer, StopWordsRemover
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType

# Path to your CSV file
csv_file_path = "train.csv"

# Read the CSV file
df = spark.read.csv(csv_file_path, header=True, inferSchema=True)

# Check for null values

null_counts = df.select([col(c).isNull().alias(c) for c in df.columns]).toPandas().sum()
print("Null value counts:\n", null_counts)

# Drop rows with null values in review_title or review_text
df = df.dropna(subset=["review_title", "review_text"])


duplicate_count = df.count() - df.dropDuplicates().count()
print("Number of duplicate rows:", duplicate_count)
df = df.dropDuplicates()

# Replace class_index 2 with 0 for binary classification
df = df.withColumn("class_index", when(col("class_index") == 2, 0).otherwise(col("class_index")))

# Combine review_title and review_text into a single column
df = df.withColumn("combined_text", concat_ws(" ", col("review_title"), col("review_text")))

# Convert text to lowercase
df = df.withColumn("cleaned_text", lower(col("combined_text")))

# Remove special characters and numbers, keeping only letters and spaces
df = df.withColumn("cleaned_text", regexp_replace(col("cleaned_text"), "[^a-zA-Z\\s]", ""))

# Drop unnecessary columns
df = df.drop("review_title", "review_text", "combined_text")

# Tokenize the cleaned text into words
tokenizer = Tokenizer(inputCol="cleaned_text", outputCol="tokens")
df = tokenizer.transform(df)

# Remove stop words
remover = StopWordsRemover(inputCol="tokens", outputCol="text_without_stopwords")
df = remover.transform(df)

# Define a UDF to filter out empty strings from the list
def remove_empty_strings(token_list):
    return [token for token in token_list if token.strip() != "" and len(token) >= 3]

remove_empty_udf = udf(remove_empty_strings, ArrayType(StringType()))

# Apply the UDF to clean the tokens
df = df.withColumn("text_without_stopwords", remove_empty_udf(col("text_without_stopwords")))


                                                                                

Null value counts:
 class_index      0
review_title    77
review_text      7
dtype: int64


                                                                                

Number of duplicate rows: 317


In [None]:
from pyspark.sql.functions import col
from pyspark.sql.functions import concat_ws, regexp_replace, lower, when, col
from pyspark.ml.feature import Tokenizer, StopWordsRemover
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType

# Path to your CSV file
test_csv_file_path = "test.csv"

# Read the CSV file
testdf = spark.read.csv(test_csv_file_path, header=True, inferSchema=True)

# Check for null values

null_counts = testdf.select([col(c).isNull().alias(c) for c in testdf.columns]).toPandas().sum()
print("Null value counts:\n", null_counts)

# Drop rows with null values in review_title or review_text
testdf = testdf.dropna(subset=["review_title", "review_text"])


duplicate_count = testdf.count() - testdf.dropDuplicates().count()
print("Number of duplicate rows:", duplicate_count)
testdf = testdf.dropDuplicates()

# Replace class_index 2 with 0 for binary classification
testdf = testdf.withColumn("class_index", when(col("class_index") == 2, 0).otherwise(col("class_index")))

# Combine review_title and review_text into a single column
testdf = testdf.withColumn("combined_text", concat_ws(" ", col("review_title"), col("review_text")))

# Convert text to lowercase
testdf = testdf.withColumn("cleaned_text", lower(col("combined_text")))

# Remove special characters and numbers, keeping only letters and spaces
testdf = testdf.withColumn("cleaned_text", regexp_replace(col("cleaned_text"), "[^a-zA-Z\\s]", ""))

# Drop unnecessary columns
testdf = testdf.drop("review_title", "review_text", "combined_text")

# Tokenize the cleaned text into words
test_tokenizer = Tokenizer(inputCol="cleaned_text", outputCol="tokens")
testdf = test_tokenizer.transform(testdf)

# Remove stop words
test_remover = StopWordsRemover(inputCol="tokens", outputCol="text_without_stopwords")
testdf = test_remover.transform(testdf)

# Define a Utestdf to filter out empty strings from the list
def remove_empty_strings(token_list):
    return [token for token in token_list if token.strip() != "" and len(token) >= 3]

remove_empty_testdf = udf(remove_empty_strings, ArrayType(StringType()))

# Apply the Utestdf to clean the tokens
testdf = testdf.withColumn("text_without_stopwords", remove_empty_testdf(col("text_without_stopwords")))

from pyspark.ml.feature import HashingTF, IDF
from pyspark.sql.functions import lit



# Apply HashingTF to convert tokens into term frequency vectors
testhashingTF = HashingTF(inputCol="text_without_stopwords", outputCol="term_frequency", numFeatures=10000)
testdf = testhashingTF.transform(testdf)

# Apply IDF to compute the TF-IDF scores
test_idf = IDF(inputCol="term_frequency", outputCol="tfidf_features")
test_idf_model = test_idf.fit(testdf)  # Fit the IDF model
testdf = test_idf_model.transform(testdf)  # Transform the data using the fitted model


# Display the resulting DataFrame with TF-Itestdf features
testdf.select("text_without_stopwords", "tfitestdf_features").show(10, truncate=False)

from pyspark.ml.feature import HashingTF, IDF, PCA, ChiSqSelector
from pyspark.sql.functions import lit

# Option 1: Use PCA for dimensionality reduction
testpca = PCA(k=100, inputCol="tfitestdf_features", outputCol="reduced_features")  # Reduce to 50 dimensions
test_pca_model = testpca.fit(testdf)
testdf = test_pca_model.transform(testdf)

from pyspark.ml.linalg import Vectors,VectorUDT
from pyspark.sql.functions import udf


# Define the number of PCA components (k)
k = 100  # Ensure this matches your PCA configuration

# UDF to pad or truncate the vector to length k
def pad_or_truncate(vector, length):
    if vector.size < length:
        # Pad with zeros if the vector is smaller than k
        return Vectors.dense(list(vector.toArray()) + [0.0] * (length - vector.size))
    elif vector.size > length:
        # Truncate if the vector is larger than k (shouldn't happen in PCA)
        return Vectors.dense(vector.toArray()[:length])
    return vector

test_pad_or_truncate_udf = udf(lambda x: pad_or_truncate(x, k), VectorUDT())

# Convert PCA output to fixed-length dense vector
testdf = testdf.withColumn("reduced_features", test_pad_or_truncate_udf(col("reduced_features")))

# Drop unnecessary columns explicitly
columns_to_drop = ["cleaned_text", "tokens", "text_without_stopwords", "term_frequency","tfidf_features"]
for col in columns_to_drop:
    testdf = testdf.drop(col)

# Verify the schema
testdf.printSchema()

Null value counts:
 class_index      0
review_title    10
review_text      0
dtype: int64


                                                                                

Number of duplicate rows: 4


                                                                                

AnalysisException: [UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name `tfitestdf_features` cannot be resolved. Did you mean one of the following? [`tfidf_features`, `cleaned_text`, `term_frequency`, `class_index`, `tokens`].;
'Project [text_without_stopwords#386, 'tfitestdf_features]
+- Project [class_index#341, cleaned_text#356, tokens#367, text_without_stopwords#386, term_frequency#393, UDF(term_frequency#393) AS tfidf_features#408]
   +- Project [class_index#341, cleaned_text#356, tokens#367, text_without_stopwords#386, UDF(text_without_stopwords#386) AS term_frequency#393]
      +- Project [class_index#341, cleaned_text#356, tokens#367, remove_empty_strings(text_without_stopwords#377)#385 AS text_without_stopwords#386]
         +- Project [class_index#341, cleaned_text#356, tokens#367, UDF(tokens#367) AS text_without_stopwords#377]
            +- Project [class_index#341, cleaned_text#356, UDF(cleaned_text#356) AS tokens#367]
               +- Project [class_index#341, cleaned_text#356]
                  +- Project [class_index#341, review_title#302, review_text#303, combined_text#345, regexp_replace(cleaned_text#350, [^a-zA-Z\s], , 1) AS cleaned_text#356]
                     +- Project [class_index#341, review_title#302, review_text#303, combined_text#345, lower(combined_text#345) AS cleaned_text#350]
                        +- Project [class_index#341, review_title#302, review_text#303, concat_ws( , review_title#302, review_text#303) AS combined_text#345]
                           +- Project [CASE WHEN (class_index#301 = 2) THEN 0 ELSE class_index#301 END AS class_index#341, review_title#302, review_text#303]
                              +- Deduplicate [class_index#301, review_title#302, review_text#303]
                                 +- Filter atleastnnonnulls(2, review_title#302, review_text#303)
                                    +- Relation [class_index#301,review_title#302,review_text#303] csv


In [None]:
# # Total row count
total_count = df.count()
print(f"Total number of rows: {total_count}")

# # Count of unique values in class_index
unique_class_count = df.select("class_index").distinct().count()
print(f"Number of unique values in 'class_index': {unique_class_count}")

# # Count occurrences of each unique value in class_index
class_counts = df.groupBy("class_index").count()
class_counts.show()

                                                                                

Total number of rows: 3599599


                                                                                

Number of unique values in 'class_index': 2




+-----------+-------+
|class_index|  count|
+-----------+-------+
|          0|1799904|
|          1|1799695|
+-----------+-------+



                                                                                

In [None]:
from pyspark.ml.feature import HashingTF, IDF
from pyspark.sql.functions import lit

# Apply HashingTF to convert tokens into term frequency vectors
hashingTF = HashingTF(inputCol="text_without_stopwords", outputCol="term_frequency", numFeatures=10000)
df = hashingTF.transform(df)

# Apply IDF to compute the TF-IDF scores
idf = IDF(inputCol="term_frequency", outputCol="tfidf_features")
idfModel = idf.fit(df)
df = idfModel.transform(df)

# Display the resulting DataFrame with TF-IDF features
df.select("text_without_stopwords", "tfidf_features").show(10, truncate=False)



+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

                                                                                

In [None]:
from pyspark.ml.feature import HashingTF, IDF, PCA, ChiSqSelector
from pyspark.sql.functions import lit

# Option 1: Use PCA for dimensionality reduction
pca = PCA(k=100, inputCol="tfidf_features", outputCol="reduced_features")  # Reduce to 50 dimensions
pca_model = pca.fit(df)
df = pca_model.transform(df)

24/12/12 10:42:20 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.lapack.JNILAPACK


In [None]:
from pyspark.ml.linalg import Vectors,VectorUDT
from pyspark.sql.functions import udf


# Define the number of PCA components (k)
k = 100  # Ensure this matches your PCA configuration

# UDF to pad or truncate the vector to length k
def pad_or_truncate(vector, length):
    if vector.size < length:
        # Pad with zeros if the vector is smaller than k
        return Vectors.dense(list(vector.toArray()) + [0.0] * (length - vector.size))
    elif vector.size > length:
        # Truncate if the vector is larger than k (shouldn't happen in PCA)
        return Vectors.dense(vector.toArray()[:length])
    return vector

pad_or_truncate_udf = udf(lambda x: pad_or_truncate(x, k), VectorUDT())

# Convert PCA output to fixed-length dense vector
df = df.withColumn("reduced_features", pad_or_truncate_udf(col("reduced_features")))

In [None]:
# Drop unnecessary columns explicitly
columns_to_drop = ["cleaned_text", "tokens", "text_without_stopwords", "term_frequency","tfidf_features"]
for col in columns_to_drop:
    df = df.drop(col)

# Verify the schema
df.printSchema()

root
 |-- class_index: integer (nullable = true)
 |-- reduced_features: vector (nullable = true)



In [None]:
# Verify the schema
df.printSchema()

In [None]:
labeled_fraction = 0.5
# Split data into labeled and unlabeled datasets
labeled_data = df.sample(withReplacement=False, fraction=labeled_fraction, seed=42).cache()
unlabeled_data = df.subtract(labeled_data).cache()

In [None]:
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Parameters
initial_temperature = 1.0
cooling_rate = 0.9


# Function to calculate metrics
def calculate_metrics(predictions):
    evaluator_accuracy = MulticlassClassificationEvaluator(labelCol="class_index", predictionCol="prediction", metricName="accuracy")
    evaluator_precision = MulticlassClassificationEvaluator(labelCol="class_index", predictionCol="prediction", metricName="weightedPrecision")
    evaluator_recall = MulticlassClassificationEvaluator(labelCol="class_index", predictionCol="prediction", metricName="weightedRecall")

    accuracy = evaluator_accuracy.evaluate(predictions)
    precision = evaluator_precision.evaluate(predictions)
    recall = evaluator_recall.evaluate(predictions)

    return accuracy, precision, recall


In [None]:
# Train initial Random Forest on labeled data
rf = RandomForestClassifier(featuresCol="reduced_features", labelCol="class_index", numTrees=100)
rf_model = rf.fit(labeled_data)


single iteration implementation

In [None]:
'''
# Import necessary libraries
from pyspark.sql.functions import col

# First iteration of iterative training
print("\n--- Iteration 1 ---")

# Step 4.1: Assign pseudo-labels to unlabeled data
pseudo_labeled_data = rf_model.transform(unlabeled_data).select("reduced_features", "prediction")

# Rename and cast columns to match the labeled data schema
pseudo_labeled_data = (
    pseudo_labeled_data
    .withColumnRenamed("prediction", "class_index")            # Rename 'prediction' to 'class_index'
    .withColumn("class_index", col("class_index").cast("int"))  # Ensure class_index is INT
)

# Reorder columns in pseudo_labeled_data to match labeled_data
pseudo_labeled_data = pseudo_labeled_data.select("class_index", "reduced_features")

# Cache pseudo-labeled data
pseudo_labeled_data = pseudo_labeled_data.cache()
pseudo_labeled_data.count()  # Trigger caching

# Step 4.2: Combine labeled and pseudo-labeled data
# Ensure schemas of both datasets are identical
combined_data = labeled_data.union(pseudo_labeled_data).checkpoint()

# Step 4.3: Train new Random Forest with combined data
rf_model = rf.fit(combined_data)

# Step 4.4: Evaluate Out-Of-Bag Error (OOBE) on labeled data
predictions = rf_model.transform(labeled_data)
accuracy, precision, recall = calculate_metrics(predictions)

print(f"Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}")

# Step 4.5: Reduce temperature (if using simulated annealing or similar strategy)
temperature = initial_temperature * cooling_rate

# Unpersist pseudo-labeled data
pseudo_labeled_data.unpersist()
'''

experiment 1: 50 percent split

In [None]:
# Import necessary libraries
from pyspark.sql.functions import col

# Parameters for iterative training
num_iterations = 10
initial_temperature = 1.0
cooling_rate = 0.9
temperature = initial_temperature
previous_oobe = None

for iteration in range(1, num_iterations + 1):
    print(f"\n--- Iteration {iteration} ---")

    # Step 4.1: Assign pseudo-labels to unlabeled data
    pseudo_labeled_data = rf_model.transform(unlabeled_data).select("reduced_features", "prediction")

    # Rename and cast columns to match the labeled data schema
    pseudo_labeled_data = (
        pseudo_labeled_data
        .withColumnRenamed("prediction", "class_index")            # Rename 'prediction' to 'class_index'
        .withColumn("class_index", col("class_index").cast("int"))  # Ensure class_index is INT
        .select("class_index", "reduced_features")  # Reorder columns to match labeled_data
    )

    # Cache pseudo-labeled data
    pseudo_labeled_data = pseudo_labeled_data.cache()
    pseudo_labeled_data.count()  # Trigger caching

    # Step 4.2: Combine labeled and pseudo-labeled data
    combined_data = labeled_data.union(pseudo_labeled_data).checkpoint()

    # Step 4.3: Train new Random Forest with combined data
    rf_model = rf.fit(combined_data)

    # Step 4.4: Evaluate Out-Of-Bag Error (OOBE) on labeled data
    predictions = rf_model.transform(labeled_data)
    accuracy, precision, recall = calculate_metrics(predictions)

    print(f"Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}")

    # Monitor OOBE to decide whether to continue
    current_oobe = 1 - accuracy  # OOBE is 1 - accuracy
    if previous_oobe is not None and current_oobe > previous_oobe:
        print("OOBE increased. Stopping training and reverting to the previous model.")
        break
    previous_oobe = current_oobe

    # Step 4.5: Reduce temperature
    temperature *= cooling_rate

    # Unpersist pseudo-labeled data from the previous iteration
    pseudo_labeled_data.unpersist()


Experiment 2: multiple splits

In [None]:
# Import necessary libraries
from pyspark.sql.functions import col
from pyspark.ml.classification import RandomForestClassifier

# Parameters for iterative training
num_iterations = 10
initial_temperature = 1.0
cooling_rate = 0.9
temperature = initial_temperature
previous_oobe = None

# Step 1: Initialize Random Forest model
rf = RandomForestClassifier(featuresCol="reduced_features", labelCol="class_index", numTrees=100, maxDepth=10, seed=42)

# Step 2: Split data into labeled and unlabeled subsets
labeled_data = df.sample(withReplacement=False, fraction=0.25, seed=42)
unlabeled_data = df.subtract(labeled_data).cache()
unlabeled_initial = unlabeled_data.sample(withReplacement=False, fraction=0.33, seed=42)  # 25% of the total
gradual_unlabeled_remaining = unlabeled_data.subtract(unlabeled_initial).cache()  # Remaining 50%

for iteration in range(1, num_iterations + 1):
    print(f"\n--- Iteration {iteration} ---")

    # Step 3: Train the Random Forest model on labeled data
    rf_model = rf.fit(labeled_data)

    # Step 4: Predict pseudo-labels for the current chunk of unlabeled data
    pseudo_labeled_initial = rf_model.transform(unlabeled_initial).select("reduced_features", "prediction")

    # Rename and cast columns to match the labeled data schema
    pseudo_labeled_initial = (
        pseudo_labeled_initial
        .withColumnRenamed("prediction", "class_index")            # Rename 'prediction' to 'class_index'
        .withColumn("class_index", col("class_index").cast("int"))  # Ensure class_index is INT
        .select("class_index", "reduced_features")  # Reorder columns to match labeled_data
    )

    # Step 5: Combine labeled data with pseudo-labeled data
    combined_data = labeled_data.union(pseudo_labeled_initial).checkpoint()

    # Step 6: Train the Random Forest model on combined data
    rf_model = rf.fit(combined_data)

    # Step 7: Evaluate Out-Of-Bag Error (OOBE) on labeled data
    predictions = rf_model.transform(labeled_data)
    accuracy, precision, recall = calculate_metrics(predictions)

    print(f"Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}")

    # Monitor OOBE to decide whether to continue
    current_oobe = 1 - accuracy  # OOBE is 1 - accuracy
    print(f"OOBE: {current_oobe:.4f}")

    if previous_oobe is not None and current_oobe > previous_oobe:
        print("OOBE increased. Stopping training and reverting to the previous model.")
        break
    previous_oobe = current_oobe

    # Step 8: Gradually add more unlabeled data for the next iteration
    if iteration < num_iterations:
        additional_unlabeled = gradual_unlabeled_remaining.sample(withReplacement=False, fraction=0.25, seed=42)
        unlabeled_initial = unlabeled_initial.union(additional_unlabeled).cache()

    # Step 9: Reduce temperature (if using simulated annealing or similar strategy)
    temperature *= cooling_rate

    # Unpersist pseudo-labeled data from the previous iteration
    pseudo_labeled_initial.unpersist()


testing final model

In [None]:
# Import necessary libraries
from pyspark.sql import functions as F
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.mllib.evaluation import MulticlassMetrics
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Function to calculate confusion matrix and display metrics
def calculate_metrics_with_confusion_matrix(predictions):
    # Multiclass evaluator for overall accuracy
    evaluator = MulticlassClassificationEvaluator(
        labelCol="class_index", predictionCol="prediction", metricName="accuracy"
    )
    accuracy = evaluator.evaluate(predictions)

    # Collect predictions and labels for confusion matrix
    prediction_and_labels = predictions.select("prediction", "class_index").rdd.map(tuple)
    metrics = MulticlassMetrics(prediction_and_labels)

    # Extract precision and recall per label
    labels = predictions.select("class_index").distinct().rdd.flatMap(lambda x: x).collect()
    label_metrics = {label: {"Precision": metrics.precision(label), "Recall": metrics.recall(label)} for label in labels}

    # Display confusion matrix
    confusion_matrix = metrics.confusionMatrix().toArray()
    print("\nConfusion Matrix:")
    print(confusion_matrix)

    # Plot confusion matrix using matplotlib
    plt.figure(figsize=(10, 7))
    sns.heatmap(pd.DataFrame(confusion_matrix, index=labels, columns=labels),
                annot=True, fmt=".0f", cmap="Blues")
    plt.title("Confusion Matrix")
    plt.xlabel("Predicted Label")
    plt.ylabel("True Label")
    plt.show()

    # Return overall metrics
    return accuracy, label_metrics

# Step 10: Testing on the test dataset
print("\n--- Testing on Test Dataset ---")
test_predictions = rf_model.transform(test_data)
test_accuracy, test_label_metrics = calculate_metrics_with_confusion_matrix(test_predictions)

print(f"\nTest Accuracy: {test_accuracy:.4f}")
print("\nPer Label Metrics:")
for label, metrics in test_label_metrics.items():
    print(f"Label {label}: Precision = {metrics['Precision']:.4f}, Recall = {metrics['Recall']:.4f}")


In [None]:
# Import necessary libraries
from pyspark.sql.functions import col
from pyspark.ml.classification import RandomForestClassifier
from pyspark.mllib.evaluation import MulticlassMetrics
from pyspark.sql import DataFrame

# Function to calculate confusion matrix
def calculate_confusion_matrix(predictions: DataFrame):
    rdd = predictions.select("prediction", "class_index").rdd
    metrics = MulticlassMetrics(rdd)
    return metrics.confusionMatrix()

# Parameters for iterative training
num_iterations = 10
initial_temperature = 1.0
cooling_rate = 0.9
temperature = initial_temperature
previous_oobe = None

# Step 1: Split the dataset into training and testing sets
data = df.sample(withReplacement=False, fraction=1.0, seed=42)  # Shuffle the data
testing_data = data.sample(withReplacement=False, fraction=0.20, seed=42)  # 20% for testing
training_data = data.subtract(testing_data).cache()

# Step 2: Initialize Random Forest model
rf = RandomForestClassifier(featuresCol="reduced_features", labelCol="class_index", numTrees=100, maxDepth=10, seed=42)

# Step 3: Split training data into labeled and unlabeled subsets
labeled_data = training_data.sample(withReplacement=False, fraction=0.25, seed=42)
unlabeled_data = training_data.subtract(labeled_data).cache()
unlabeled_initial = unlabeled_data.sample(withReplacement=False, fraction=0.33, seed=42)  # 25% of the total
gradual_unlabeled_remaining = unlabeled_data.subtract(unlabeled_initial).cache()  # Remaining 50%

for iteration in range(1, num_iterations + 1):
    print(f"\n--- Iteration {iteration} ---")

    # Step 4: Train the Random Forest model on labeled data
    rf_model = rf.fit(labeled_data)

    # Step 5: Predict pseudo-labels for the current chunk of unlabeled data
    pseudo_labeled_initial = rf_model.transform(unlabeled_initial).select("reduced_features", "prediction")

    # Rename and cast columns to match the labeled data schema
    pseudo_labeled_initial = (
        pseudo_labeled_initial
        .withColumnRenamed("prediction", "class_index")            # Rename 'prediction' to 'class_index'
        .withColumn("class_index", col("class_index").cast("int"))  # Ensure class_index is INT
        .select("class_index", "reduced_features")  # Reorder columns to match labeled_data
    )

    # Step 6: Combine labeled data with pseudo-labeled data
    combined_data = labeled_data.union(pseudo_labeled_initial).checkpoint()

    # Step 7: Train the Random Forest model on combined data
    rf_model = rf.fit(combined_data)

    # Step 8: Evaluate Out-Of-Bag Error (OOBE) on labeled data
    predictions = rf_model.transform(labeled_data)
    accuracy, precision, recall = calculate_metrics(predictions)

    print(f"Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}")

    current_oobe = 1 - accuracy  # OOBE is 1 - accuracy
    print(f"OOBE: {current_oobe:.4f}")

    if previous_oobe is not None and current_oobe > previous_oobe:
        print("OOBE increased. Stopping training and reverting to the previous model.")
        break
    previous_oobe = current_oobe

    # Gradually add more unlabeled data for the next iteration
    if iteration < num_iterations:
        additional_unlabeled = gradual_unlabeled_remaining.sample(withReplacement=False, fraction=0.25, seed=42)
        unlabeled_initial = unlabeled_initial.union(additional_unlabeled).cache()

    # Reduce temperature (if using simulated annealing or similar strategy)
    temperature *= cooling_rate

    # Unpersist pseudo-labeled data from the previous iteration
    pseudo_labeled_initial.unpersist()

# Step 9: Evaluate the model on the test set
test_predictions = rf_model.transform(testing_data)
confusion_matrix = calculate_confusion_matrix(test_predictions)
print("Confusion Matrix for Testing Dataset:")
print(confusion_matrix.toArray())



--- Iteration 1 ---


24/12/12 11:51:30 WARN DAGScheduler: Broadcasting large task binary with size 7.9 MiB
24/12/12 11:51:30 WARN DAGScheduler: Broadcasting large task binary with size 7.9 MiB
24/12/12 11:55:44 WARN DAGScheduler: Broadcasting large task binary with size 7.9 MiB
24/12/12 11:56:09 WARN DAGScheduler: Broadcasting large task binary with size 7.9 MiB
24/12/12 11:56:19 WARN DAGScheduler: Broadcasting large task binary with size 7.9 MiB
24/12/12 11:56:19 WARN DAGScheduler: Broadcasting large task binary with size 7.9 MiB
24/12/12 11:56:20 WARN DAGScheduler: Broadcasting large task binary with size 7.9 MiB
24/12/12 11:56:29 WARN DAGScheduler: Broadcasting large task binary with size 8.0 MiB
24/12/12 11:56:37 WARN DAGScheduler: Broadcasting large task binary with size 8.0 MiB
24/12/12 11:56:47 WARN DAGScheduler: Broadcasting large task binary with size 8.1 MiB
24/12/12 11:57:00 WARN DAGScheduler: Broadcasting large task binary with size 8.2 MiB
24/12/12 11:57:14 WARN DAGScheduler: Broadcasting larg

Accuracy: 0.7699, Precision: 0.7702, Recall: 0.7699
OOBE: 0.2301

--- Iteration 2 ---


24/12/12 12:14:57 WARN DAGScheduler: Broadcasting large task binary with size 7.9 MiB
24/12/12 12:15:03 WARN DAGScheduler: Broadcasting large task binary with size 7.9 MiB
24/12/12 12:15:03 WARN DAGScheduler: Broadcasting large task binary with size 7.9 MiB
24/12/12 12:15:04 WARN DAGScheduler: Broadcasting large task binary with size 7.9 MiB
24/12/12 12:15:13 WARN DAGScheduler: Broadcasting large task binary with size 8.0 MiB
24/12/12 12:15:22 WARN DAGScheduler: Broadcasting large task binary with size 8.0 MiB
24/12/12 12:15:32 WARN DAGScheduler: Broadcasting large task binary with size 8.1 MiB
24/12/12 12:15:45 WARN DAGScheduler: Broadcasting large task binary with size 8.2 MiB
24/12/12 12:16:01 WARN DAGScheduler: Broadcasting large task binary with size 8.4 MiB
24/12/12 12:16:15 WARN DAGScheduler: Broadcasting large task binary with size 8.9 MiB
24/12/12 12:16:32 WARN DAGScheduler: Broadcasting large task binary with size 9.9 MiB
24/12/12 12:16:54 WARN DAGScheduler: Broadcasting larg

Accuracy: 0.7672, Precision: 0.7676, Recall: 0.7672
OOBE: 0.2328
OOBE increased. Stopping training and reverting to the previous model.


24/12/12 12:38:31 WARN DAGScheduler: Broadcasting large task binary with size 16.3 MiB
24/12/12 12:38:34 WARN DAGScheduler: Broadcasting large task binary with size 16.3 MiB
24/12/12 12:38:45 ERROR Executor: Exception in task 3.0 in stage 351.0 (TID 75660)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 830, in main
    process()
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 822, in process
    serializer.dump_stream(out_iter, outfile)
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 274, in dump_stream
    vs = list(itertools.islice(iterator, batch))
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/pyt

Py4JJavaError: An error occurred while calling o1451.confusionMatrix.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 3 in stage 351.0 failed 1 times, most recent failure: Lost task 3.0 in stage 351.0 (TID 75660) (hop043.orc.gmu.edu executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 830, in main
    process()
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 822, in process
    serializer.dump_stream(out_iter, outfile)
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 274, in dump_stream
    vs = list(itertools.islice(iterator, batch))
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/python/lib/pyspark.zip/pyspark/util.py", line 81, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/sql/session.py", line 1292, in prepare
    verify_func(obj)
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/sql/types.py", line 2001, in verify
    verify_value(obj)
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/sql/types.py", line 1979, in verify_struct
    verifier(v)
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/sql/types.py", line 2001, in verify
    verify_value(obj)
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/sql/types.py", line 1995, in verify_default
    verify_acceptable_types(obj)
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/sql/types.py", line 1871, in verify_acceptable_types
    raise TypeError(
TypeError: field label: DoubleType() can not accept object 1 in type <class 'int'>

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:561)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:767)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:749)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:514)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:197)
	at org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:101)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:161)
	at org.apache.spark.scheduler.Task.run(Task.scala:139)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:554)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1529)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:557)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	at java.base/java.lang.Thread.run(Thread.java:829)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2785)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2721)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2720)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2720)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1206)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1206)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1206)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2984)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2923)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2912)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:971)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2263)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2284)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2303)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2328)
	at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1019)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:405)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:1018)
	at org.apache.spark.rdd.PairRDDFunctions.$anonfun$collectAsMap$1(PairRDDFunctions.scala:738)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:405)
	at org.apache.spark.rdd.PairRDDFunctions.collectAsMap(PairRDDFunctions.scala:737)
	at org.apache.spark.mllib.evaluation.MulticlassMetrics.confusions$lzycompute(MulticlassMetrics.scala:61)
	at org.apache.spark.mllib.evaluation.MulticlassMetrics.confusions(MulticlassMetrics.scala:52)
	at org.apache.spark.mllib.evaluation.MulticlassMetrics.tpByClass$lzycompute(MulticlassMetrics.scala:78)
	at org.apache.spark.mllib.evaluation.MulticlassMetrics.tpByClass(MulticlassMetrics.scala:76)
	at org.apache.spark.mllib.evaluation.MulticlassMetrics.labels$lzycompute(MulticlassMetrics.scala:241)
	at org.apache.spark.mllib.evaluation.MulticlassMetrics.labels(MulticlassMetrics.scala:241)
	at org.apache.spark.mllib.evaluation.MulticlassMetrics.confusionMatrix(MulticlassMetrics.scala:113)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:829)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 830, in main
    process()
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/python/lib/pyspark.zip/pyspark/worker.py", line 822, in process
    serializer.dump_stream(out_iter, outfile)
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/python/lib/pyspark.zip/pyspark/serializers.py", line 274, in dump_stream
    vs = list(itertools.islice(iterator, batch))
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/python/lib/pyspark.zip/pyspark/util.py", line 81, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/sql/session.py", line 1292, in prepare
    verify_func(obj)
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/sql/types.py", line 2001, in verify
    verify_value(obj)
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/sql/types.py", line 1979, in verify_struct
    verifier(v)
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/sql/types.py", line 2001, in verify
    verify_value(obj)
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/sql/types.py", line 1995, in verify_default
    verify_acceptable_types(obj)
  File "/home/szele/anaconda3/envs/MMD/lib/python3.12/site-packages/pyspark/sql/types.py", line 1871, in verify_acceptable_types
    raise TypeError(
TypeError: field label: DoubleType() can not accept object 1 in type <class 'int'>

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:561)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:767)
	at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:749)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:514)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:197)
	at org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:101)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:53)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:161)
	at org.apache.spark.scheduler.Task.run(Task.scala:139)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:554)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1529)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:557)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	... 1 more


In [None]:
from pyspark.mllib.evaluation import MulticlassMetrics
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, countDistinct

# Function to calculate confusion matrix
def calculate_confusion_matrix(predictions: DataFrame):
    rdd = predictions.select("prediction", "class_index").rdd
    metrics = MulticlassMetrics(rdd)
    return metrics.confusionMatrix()

# Function to retrieve confusion matrix from existing predictions with type casting
def retrieve_confusion_matrix_with_cast(predictions: DataFrame):
    try:
        # Print the schema of predictions
        print("Schema of predictions:")
        predictions.printSchema()

        # Print total count and unique class count
        total_count = predictions.count()
        unique_classes = predictions.select("prediction").agg(countDistinct("prediction")).collect()[0][0]
        print(f"Total count of predictions: {total_count}")
        print(f"Unique class count in predictions: {unique_classes}")

        # Cast both prediction and class_index columns to DoubleType
        predictions = predictions.withColumn("prediction", col("prediction").cast("double"))
        predictions = predictions.withColumn("class_index", col("class_index").cast("double"))

        # Create RDD and calculate confusion matrix
        rdd = predictions.select("prediction", "class_index").rdd
        metrics = MulticlassMetrics(rdd)
        return metrics.confusionMatrix()
    except Exception as e:
        print(f"Error retrieving confusion matrix: {e}")
        return None

confusion_matrix = retrieve_confusion_matrix_with_cast(test_predictions)
print("Confusion Matrix for Testing Dataset:")
if confusion_matrix:
    print(confusion_matrix.toArray())

Schema of predictions:
root
 |-- class_index: integer (nullable = true)
 |-- reduced_features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)



24/12/12 12:54:34 WARN DAGScheduler: Broadcasting large task binary with size 16.3 MiB
24/12/12 12:56:46 WARN DAGScheduler: Broadcasting large task binary with size 16.3 MiB
                                                                                

Total count of predictions: 720034
Unique class count in predictions: 2


24/12/12 12:56:50 WARN DAGScheduler: Broadcasting large task binary with size 16.3 MiB
24/12/12 12:56:51 WARN DAGScheduler: Broadcasting large task binary with size 16.3 MiB

Confusion Matrix for Testing Dataset:
[[282384.  77713.]
 [ 91118. 268819.]]


                                                                                