In [1]:
import traceback
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, concat_ws, length, regexp_replace, size, split
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.ml import Pipeline
from pyspark.ml.feature import Tokenizer, StopWordsRemover, HashingTF, IDF, StringIndexer, VectorAssembler
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

def main():
    spark = (SparkSession.builder
    .appName("GPU_IMP")
        .master("local[*]")
        .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
        .config("spark.driver.host", "localhost")
        .config("spark.driver.memory","8g")
        .config("spark.rapids.sql.explain", "NONE")
        .config("spark.rapids.sql.allowMultipleJars", "ALWAYS")
        .getOrCreate()
    )
        
    data_path = "/mnt/c/Users/BerenÜnveren/Desktop/BIL401/data/train.csv"

    schema = StructType([
        StructField("Id", IntegerType(), True),
        StructField("Title", StringType(), True),
        StructField("Body", StringType(), True),
        StructField("Y", StringType(), True)
    ])

    df = spark.read.format("csv") \
        .schema(schema) \
        .option("header", "true") \
        .option("quote", "\"") \
        .option("multiLine", "true") \
        .load(data_path)
    
    df.printSchema()
    df.groupBy("Y").count().show()

    df_clean = df.na.drop(subset=["Title", "Body", "Y"]) \
        .withColumn("CleanBody", regexp_replace(col("Body"), "<.*?>", "")) \
        .withColumn("text", concat_ws(" ", col("Title"), col("CleanBody")))

    df_featured = df_clean.withColumn("title_len", length(col("Title"))) \
        .withColumn("body_len", length(col("CleanBody"))) \
        .withColumn("punct_count", length(col("text")) - length(regexp_replace(col("text"), "[?!]", ""))) \
        .withColumn("avg_word_len", length(regexp_replace(col("text"), " ", "")) / (size(split(col("text"), " ")) + 1e-6))
    
    label_indexer = StringIndexer(inputCol="Y", outputCol="label", handleInvalid="skip")
    tokenizer = Tokenizer(inputCol="text", outputCol="words")
    stopwords_remover = StopWordsRemover(inputCol="words", outputCol="filtered_words")
    hashing_tf = HashingTF(inputCol="filtered_words", outputCol="raw_features", numFeatures=20000)
    idf = IDF(inputCol="raw_features", outputCol="text_features")
    
    feature_assembler = VectorAssembler(
        inputCols=["text_features", "title_len", "body_len", "punct_count", "avg_word_len"],
        outputCol="features"
    )

    (train_data, test_data) = df_featured.randomSplit([0.8, 0.2], seed=42)
    train_data.cache()
    test_data.cache()
    """
    lr = LogisticRegression(featuresCol="features", labelCol="label", maxIter=10)
    lr_pipeline = Pipeline(stages=[label_indexer, tokenizer, stopwords_remover, hashing_tf, idf, feature_assembler, lr])
    lr_model = lr_pipeline.fit(train_data)
    lr_predictions = lr_model.transform(test_data)
    
    evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction")
    lr_accuracy = evaluator.setMetricName("accuracy").evaluate(lr_predictions)
    lr_f1 = evaluator.setMetricName("f1").evaluate(lr_predictions)
    
    print("\nLogistic Regression Evaluation")
    print(f"Accuracy: {lr_accuracy:.4f}")
    print(f"F1 Score: {lr_f1:.4f}")
    print("Confusion Matrix:")
    lr_predictions.groupBy("label", "prediction").count().orderBy("label", "prediction").show()
"""
    evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction")
    rf = RandomForestClassifier(featuresCol="features", labelCol="label", numTrees=100)
    rf_pipeline = Pipeline(stages=[label_indexer, tokenizer, stopwords_remover, hashing_tf, idf, feature_assembler, rf])
    
    rf_model = rf_pipeline.fit(train_data)
    rf_predictions = rf_model.transform(test_data)
    rf_accuracy = evaluator.setMetricName("accuracy").evaluate(rf_predictions)
    rf_f1 = evaluator.setMetricName("f1").evaluate(rf_predictions)

    print("\nRandom Forest Evaluation")
    print(f"Accuracy: {rf_accuracy:.4f}")
    print(f"F1 Score: {rf_f1:.4f}")
    print("Confusion Matrix:")
    rf_predictions.groupBy("label", "prediction").count().orderBy("label", "prediction").show()

# Script entry point: run the job, report any failure with its traceback,
# and always shut down whichever Spark session is still active afterwards.
if __name__ == "__main__":
    try:
        main()
    except Exception as exc:
        print(f"An error occurred: {exc}")
        traceback.print_exc()
    finally:
        # SparkSession is already imported at the top of the file.
        active_session = SparkSession.getActiveSession()
        if active_session is not None:
            active_session.stop()


25/07/21 03:07:07 WARN Utils: Your hostname, DESKTOP-15VE119 resolves to a loopback address: 127.0.1.1; using 10.255.255.254 instead (on interface lo)
25/07/21 03:07:07 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/07/21 03:07:08 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/07/21 03:07:09 WARN RapidsPluginUtils: RAPIDS Accelerator 24.02.0 using cudf 24.02.1.
25/07/21 03:07:09 WARN RapidsPluginUtils: Multiple cudf jars found in the classpath:
revison: dd34fdbe35e68ba56a2183f11ed822ddaa6c927b
	jar URL: jar:file:/home/bunveren/miniconda3/envs/rapids-24.02/lib/python3.10/site-packages/pyspark/jars/rapids-4-spark_2.12-24.02.0.jar
	version=24.02.1
	user=
	revision=dd34fdbe35e68ba56a2183f11ed822ddaa6c927b
	branch=HEAD
	date=2024-02-28T05:34:16Z
	url=https:/

root
 |-- Id: integer (nullable = true)
 |-- Title: string (nullable = true)
 |-- Body: string (nullable = true)
 |-- Y: string (nullable = true)



                                                                                

+--------------------+-----+
|                   Y|count|
+--------------------+-----+
|<intellij-idea><a...|    1|
|<matlab><neural-n...|    1|
| <regex><python-3.x>|    1|
|<java><java-strea...|    1|
|<node.js><windows...|    1|
|<build-process><t...|    1|
|<python><pandas><...|    3|
|<docker><apk><alp...|    1|
|<php><mysql><sql>...|    1|
|<android><android...|    1|
|<architecture><en...|    2|
|<asp.net-core><en...|    3|
|<sql><oracle><sna...|    1|
|<delphi><delphi-xe7>|    1|
|<c#><asp.net><dat...|    2|
|<symfony><doctrin...|    1|
|<java><websocket>...|    1|
|<python><web-scra...|    1|
|<java><arraylist>...|    1|
|<java><mysql><sql...|    1|
+--------------------+-----+
only showing top 20 rows



25/07/21 03:07:36 WARN DAGScheduler: Broadcasting large task binary with size 1096.6 KiB
25/07/21 03:07:42 WARN DAGScheduler: Broadcasting large task binary with size 2.7 MiB
25/07/21 03:07:42 WARN DAGScheduler: Broadcasting large task binary with size 2.7 MiB
25/07/21 03:07:47 WARN DAGScheduler: Broadcasting large task binary with size 2.9 MiB
25/07/21 03:07:58 WARN DAGScheduler: Broadcasting large task binary with size 3.4 MiB
25/07/21 03:08:19 WARN DAGScheduler: Broadcasting large task binary with size 4.4 MiB
25/07/21 03:08:21 WARN DAGScheduler: Broadcasting large task binary with size 4.9 MiB
25/07/21 03:08:23 WARN DAGScheduler: Broadcasting large task binary with size 6.3 MiB
25/07/21 03:08:25 WARN DAGScheduler: Broadcasting large task binary with size 7.5 MiB
25/07/21 03:08:27 WARN DAGScheduler: Broadcasting large task binary with size 8.9 MiB
25/07/21 03:08:29 WARN DAGScheduler: Broadcasting large task binary with size 10.4 MiB
25/07/21 03:08:30 WARN DAGScheduler: Broadcasting 


Random Forest Evaluation
Accuracy: 0.1212
F1 Score: 0.0646
Confusion Matrix:


25/07/21 03:28:48 WARN DAGScheduler: Broadcasting large task binary with size 536.7 MiB
                                                                                

+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|  0.0|       0.0|  143|
|  0.0|       1.0|    2|
|  1.0|       0.0|   33|
|  1.0|       1.0|  114|
|  1.0|       8.0|    1|
|  2.0|       0.0|   56|
|  2.0|       1.0|   15|
|  2.0|       2.0|   62|
|  2.0|       3.0|    1|
|  2.0|       8.0|    1|
|  3.0|       0.0|   16|
|  3.0|       1.0|    6|
|  3.0|       3.0|   70|
|  3.0|       5.0|    2|
|  4.0|       0.0|   35|
|  4.0|       1.0|   21|
|  4.0|       2.0|    1|
|  4.0|       3.0|    3|
|  4.0|       4.0|   45|
|  5.0|       0.0|   27|
+-----+----------+-----+
only showing top 20 rows

