In [1]:
# Initialise findspark before any pyspark import in the cells that follow
# (presumably so the local Spark installation is put on the Python path --
# confirm against the findspark documentation).
import findspark
findspark.init()

In [2]:
import pyspark.sql.functions as F
import pyspark.sql.types as T
import sparknlp

from pyspark.sql import SparkSession

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

import pandas as pd

from models.bloom_filter import BloomFilter
from models.emb_logreg import EmbeddingsLogReg

In [3]:
# Settings for the local Spark session: a 12G driver, Kryo serialization with
# a 2000M maximum buffer, and the Spark NLP 5.5.1 package pulled in as a jar
# dependency. They are applied in the order listed.
spark_conf = {
    "spark.driver.memory": "12G",
    "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
    "spark.kryoserializer.buffer.max": "2000M",
    "spark.jars.packages": "com.johnsnowlabs.nlp:spark-nlp_2.12:5.5.1",
}

session_builder = SparkSession.builder.appName("WikimediaStreamProcessor")
for conf_key, conf_value in spark_conf.items():
    session_builder = session_builder.config(conf_key, conf_value)

# Reuses an already-running session if there is one, otherwise creates it.
spark = session_builder.getOrCreate()


24/11/17 19:23:13 WARN Utils: Your hostname, andrii-VirtualBox resolves to a loopback address: 127.0.1.1; using 10.0.2.15 instead (on interface enp0s3)
24/11/17 19:23:13 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address


:: loading settings :: url = jar:file:/home/andrii/spark-3.5.3-bin-hadoop3/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml


Ivy Default Cache set to: /home/andrii/.ivy2/cache
The jars for the packages stored in: /home/andrii/.ivy2/jars
com.johnsnowlabs.nlp#spark-nlp_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-c3f4f2f3-f758-4e13-a6bc-4a3b34ee5f85;1.0
	confs: [default]
	found com.johnsnowlabs.nlp#spark-nlp_2.12;5.5.1 in central
	found com.typesafe#config;1.4.2 in central
	found org.rocksdb#rocksdbjni;6.29.5 in central
	found com.amazonaws#aws-java-sdk-s3;1.12.500 in central
	found com.amazonaws#aws-java-sdk-kms;1.12.500 in central
	found com.amazonaws#aws-java-sdk-core;1.12.500 in central
	found commons-logging#commons-logging;1.1.3 in central
	found commons-codec#commons-codec;1.15 in central
	found org.apache.httpcomponents#httpclient;4.5.13 in central
	found org.apache.httpcomponents#httpcore;4.4.13 in central
	found software.amazon.ion#ion-java;1.0.2 in central
	found joda-time#joda-time;2.8.1 in central
	found com.amazonaws#jmespath-java;1.12.500 in centra

In [4]:
# Column layout of a change record, in file order. Every column is nullable;
# `meta`, `length` and `revision` are read as plain strings.
_change_columns = [
    ("$schema", T.StringType()),
    ("meta", T.StringType()),
    ("id", T.LongType()),
    ("type", T.StringType()),
    ("namespace", T.IntegerType()),
    ("title", T.StringType()),
    ("title_url", T.StringType()),
    ("comment", T.StringType()),
    ("timestamp", T.LongType()),
    ("user", T.StringType()),
    ("bot", T.BooleanType()),
    ("notify_url", T.StringType()),
    ("minor", T.BooleanType()),
    ("length", T.StringType()),
    ("revision", T.StringType()),
    ("server_url", T.StringType()),
    ("server_name", T.StringType()),
    ("server_script_path", T.StringType()),
    ("wiki", T.StringType()),
    ("parsedcomment", T.StringType()),
]

changeSchema = T.StructType(
    [
        T.StructField(column_name, column_type, True)
        for column_name, column_type in _change_columns
    ]
)

In [5]:
# The Bloom-filter model is left disabled in this run; only the embeddings +
# logistic-regression model loaded here is used by the cells below.
# filter = BloomFilter.load('./data/filter_train_small')
# Restore the saved EmbeddingsLogReg model from ./data/logreg_small. Per the
# cell output, loading also fetches the word2vec_gigaword_300 resource
# (approximately 312.3 MB).
emb_logreg = EmbeddingsLogReg.load('./data/logreg_small') 

word2vec_gigaword_300 download started this may take some time.
Approximate size to download 312.3 MB
[ | ]

24/11/17 19:23:56 WARN S3AbortableInputStream: Not all bytes were read from the S3ObjectInputStream, aborting HTTP connection. This is likely an error and may result in sub-optimal behavior. Request only the bytes you need via a ranged GET or drain the input stream after use.
24/11/17 19:23:57 WARN S3AbortableInputStream: Not all bytes were read from the S3ObjectInputStream, aborting HTTP connection. This is likely an error and may result in sub-optimal behavior. Request only the bytes you need via a ranged GET or drain the input stream after use.


word2vec_gigaword_300 download started this may take some time.
Approximate size to download 312.3 MB
Download done! Loading the resource.
[ / ]

                                                                                

[ — ]

[Stage 1:>                                                         (0 + 9) / 12]

[ \ ]

                                                                                

[OK!]


                                                                                

In [6]:
def compute_metrics(batch_df, batch_id):
    """Print accuracy, precision, recall and F1 for one streaming micro-batch.

    Written to be used as a ``foreachBatch`` callback. ``bot`` is the true
    label and ``prediction`` the model output. Rows that match none of the
    four confusion-matrix cells (for example a NULL in either column) are
    not counted.

    Args:
        batch_df: DataFrame holding the micro-batch; must expose the columns
            ``bot`` and ``prediction``.
        batch_id: Identifier of the micro-batch; only used in the printed line.

    Returns:
        None. The result is printed to stdout.
    """
    actual_bot = F.col("bot") == 1
    actual_human = F.col("bot") == 0
    predicted_bot = F.col("prediction") == 1
    predicted_human = F.col("prediction") == 0

    def count_where(condition):
        # `when` without `otherwise` yields NULL for non-matching rows and
        # `count` skips NULLs, so this counts the matching rows and returns 0
        # (not NULL) on an empty batch.
        return F.count(F.when(condition, F.lit(1)))

    # A single aggregation instead of four separate count() actions: every
    # action on the batch DataFrame re-evaluates the whole micro-batch.
    confusion = batch_df.select("bot", "prediction").agg(
        count_where(actual_bot & predicted_bot).alias("TP"),
        count_where(actual_human & predicted_human).alias("TN"),
        count_where(actual_human & predicted_bot).alias("FP"),
        count_where(actual_bot & predicted_human).alias("FN"),
    ).first()

    TP, TN, FP, FN = confusion["TP"], confusion["TN"], confusion["FP"], confusion["FN"]

    total = TP + FN + TN + FP
    if total == 0:
        # An empty micro-batch used to raise ZeroDivisionError in the accuracy
        # computation and fail the streaming query; report it and move on.
        print(f"Batch {batch_id} - no labelled rows, metrics skipped")
        return

    acc = (TP + TN) / total
    precision = TP / (TP + FP) if (TP + FP) > 0 else 0
    recall = TP / (TP + FN) if (TP + FN) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    print(f"Batch {batch_id} - Accuracy: {acc:.3f} Precision: {precision:.3f}, Recall: {recall:.3f}, F1-Score: {f1:.3f}")

In [7]:
# Stream the test CSV files from ./data/test/ using the explicit change schema.
# Quote and escape are both the double-quote character, and multiLine allows
# quoted fields to span several lines.
test_df_stream = (
    spark.readStream
    .schema(changeSchema)
    .format("csv")
    .option("header", "true")
    .option("multiLine", "true")
    .option("escape", "\"")
    .option("quote", "\"")
    .option("sep", ",")
    .load("./data/test/")
)

# Model input text: the title immediately followed by the comment (no
# separator); a missing comment is replaced by the literal string "NULL".
# NOTE(review): concat() returns NULL when `title` itself is NULL -- confirm
# titles are never missing or that the model tolerates a NULL text.
test_df_stream = test_df_stream.withColumn(
    "text",
    F.concat(
        F.col("title"),
        F.coalesce(F.col("comment"), F.lit("NULL")),
    ),
)

predictions = emb_logreg.predict(test_df_stream)

# Keep only the records predicted as class 0 and persist their id and user.
predictions_filtered = predictions.filter(F.col("prediction") == 0).select("id", "user")

csv_query = (
    predictions_filtered.writeStream
    .format("csv")
    .option("path", "./data/outputs")
    .option("checkpointLocation", "./data/checkpoints/")
    .start()
)

# Second, independent query: print the metrics for every micro-batch.
metrics_query = predictions.writeStream.foreachBatch(compute_metrics).start()

# start() returns immediately. Keep the query handles and block until the
# files currently present in the source have been processed; otherwise a
# top-to-bottom run reaches spark.stop() and kills both queries before they
# produce any output. A failure inside a query is raised here instead of
# being lost in the background.
for streaming_query in (csv_query, metrics_query):
    streaming_query.processAllAvailable()

24/11/17 19:24:25 WARN ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
24/11/17 19:24:25 WARN ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /tmp/temporary-97222a8a-7d94-49cd-b3a4-62338d7b0b72. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
24/11/17 19:24:25 WARN ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.


<pyspark.sql.streaming.query.StreamingQuery at 0x76b59ba469f0>

24/11/17 19:24:26 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
24/11/17 19:24:31 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
                                                                                

Batch 0 - Accuracy: 0.935 Precision: 0.950, Recall: 0.827, F1-Score: 0.884


In [8]:
# Shut down the Spark session created at the top of the notebook.
# NOTE(review): this presumably also ends any streaming queries that are still
# running, so execute it only once they have processed the data -- confirm.
spark.stop()