Import statements

In [1]:
import os
import re
import threading

import findspark
import pyspark
import torch

from pyspark.sql.functions import udf, concat_ws, col
from pyspark.sql.types import StringType, StructType, StructField
from pyspark.streaming import StreamingContext

Set up environment

In [2]:
# Locate Spark and (on Windows) the Hadoop winutils directory. Environment
# variables take precedence; otherwise fall back to sibling directories of the
# current working directory.
spark_home = os.environ.get("SPARK_HOME") or os.path.abspath(os.path.join(os.getcwd(), "..", "spark-3.5.5-bin-hadoop3"))
hadoop_home = os.environ.get("HADOOP_HOME") or os.path.abspath(os.path.join(os.getcwd(), "..", "winutils"))

if not os.path.exists(spark_home):
    # Raise instead of calling exit() so the failure surfaces as an ordinary
    # cell error carrying the offending path.
    raise FileNotFoundError(f"SPARK_HOME does not exist: {spark_home}")

if os.name == "nt" and os.path.exists(hadoop_home):
    os.environ["HADOOP_HOME"] = hadoop_home
    # Prepend winutils' bin directory; tolerate a missing PATH variable.
    os.environ["PATH"] = os.pathsep.join([os.path.join(hadoop_home, "bin"), os.environ.get("PATH", "")])

print(f"Using SPARK_HOME: {spark_home}")
print(f"Using HADOOP_HOME: {hadoop_home}")

findspark.init(spark_home)

# Build the session FIRST so the memory / JVM options below are part of the
# configuration the context is created with. Creating a SparkContext beforehand
# makes getOrCreate() reuse it, and launch-time settings such as
# spark.driver.memory can no longer take effect.
spark = (
    pyspark.sql.SparkSession.builder
    .appName("StreamingPaperClassifier")
    .config("spark.executor.memory", "8g")
    .config("spark.driver.memory", "8g")
    .config("spark.executor.extraJavaOptions", "-verbose:gc")
    .getOrCreate()
)
# Single shared context: the same object backs both `sc` and `spark`.
sc = spark.sparkContext
sc.setLogLevel("WARN")

Using SPARK_HOME: C:\Users\topsj\Desktop\spark\spark-3.5.5-bin-hadoop3
Using HADOOP_HOME: C:\Users\topsj\Desktop\spark\winutils


Global variables and model load function

In [3]:
# -------------------- Global Variables --------------------

# Module-level state shared by the model loader, the inference helper and the
# per-batch callback. Everything starts out "not loaded".
models_loaded = False   # flipped to True once the model has been loaded
tokenizer = None        # tokenizer instance, set by the loader
my_model = None         # classification model, set by the loader
category_labels = []    # label names indexed by predicted class id

# -------------------- Model Load Function --------------------
def load_model_and_labels():
    """Load the DistilBERT checkpoint and its label names into module globals.

    The function is a no-op when ``models_loaded`` is already true. On success
    it stores the tokenizer in ``tokenizer``, the model (in eval mode) in
    ``my_model``, the ordered label names in ``category_labels`` (only when the
    model config carries an ``id2label`` mapping) and sets ``models_loaded``.

    Failures while loading are printed with a traceback and leave
    ``models_loaded`` unchanged; a missing checkpoint directory is reported
    and the function returns without loading anything.

    Returns:
        None
    """
    state = globals()
    if state.get('models_loaded'):
        return

    print("Loading model and tokenizer...")

    model_path = "results/checkpoint-2631"
    if not os.path.exists(model_path):
        print(f"Model path not found: {model_path}")
        return

    # Imported only once we know there is something to load, so the cheap
    # early-exit paths above do not pay for (or depend on) transformers.
    from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification

    try:
        tokenizer = DistilBertTokenizerFast.from_pretrained(model_path)
        model = DistilBertForSequenceClassification.from_pretrained(model_path)
        model.eval()

        id2label = getattr(model.config, 'id2label', None) or {}
        if id2label:
            try:
                # Order labels by their numeric class id. Keys may be ints or
                # numeric strings, so compare them as ints instead of relying
                # on one key type or on dict insertion order.
                ordered = sorted(id2label.items(), key=lambda item: int(item[0]))
                sorted_labels = [label for _, label in ordered]
            except (TypeError, ValueError):
                # Non-numeric keys: fall back to the mapping's own order.
                sorted_labels = list(id2label.values())

            state.update({
                'tokenizer': tokenizer,
                'my_model': model,
                'category_labels': sorted_labels,
                'models_loaded': True
            })
            print(f"Model loaded with labels: {sorted_labels}")
        else:
            print("WARNING: No id2label found in config. Predictions will be indices.")
            state.update({
                'tokenizer': tokenizer,
                'my_model': model,
                'models_loaded': True
            })

    except Exception as e:
        print(f"Error loading model: {e}")
        import traceback
        traceback.print_exc()

Inference function

In [4]:
def run_inference(text_list, batch_size=32):
    """Predict a class index for every text using the globally loaded model.

    Texts are tokenised and forwarded in chunks of ``batch_size`` so that the
    size of any single padded tensor is bounded by the chunk size rather than
    by the number of texts passed in.

    Args:
        text_list: A list of strings (a single string is treated as one text).
        batch_size: Maximum number of texts per forward pass; values below 1
            are treated as 1.

    Returns:
        A list of integer class indices, one per input text, in input order.
        Returns an empty list when the model is not loaded, when there are no
        texts, or when any chunk fails (the error is printed with a traceback).
    """
    state = globals()
    if not state['models_loaded']:
        print("Model not loaded. Skipping inference.")
        return []

    texts = [text_list] if isinstance(text_list, str) else list(text_list)
    if not texts:
        # Nothing to classify; avoid sending an empty batch to the tokenizer.
        return []

    try:
        step = max(1, int(batch_size))
        tokenize = state['tokenizer']
        model = state['my_model']
        predictions = []
        with torch.no_grad():
            for start in range(0, len(texts), step):
                tokens = tokenize(texts[start:start + step], padding=True, truncation=True, return_tensors="pt", max_length=512)
                outputs = model(**tokens)
                predictions.extend(torch.argmax(outputs.logits, dim=1).tolist())
        return predictions
    except Exception as e:
        # All-or-nothing: a partial result would not line up with the inputs.
        print(f"Inference error: {e}")
        import traceback
        traceback.print_exc()
        return []

Normalize category function

In [5]:
def normalize_category_func(category_str):
    """Reduce a category string to its coarse top-level token.

    ``None`` is passed through. Strings containing ``q-fin`` or ``q-bio``
    (checked in that order) map to that compound name; any other string is
    cut at its first ``-`` or ``.`` and the leading part is returned.
    """
    if category_str is None:
        return None

    # Compound prefixes that must survive the split on '-' below.
    for compound in ("q-fin", "q-bio"):
        if compound in category_str:
            return compound

    head, *_ = re.split(r'[-\.]', category_str)
    return head

# Spark UDF wrapper (string result) around the plain-Python normaliser so it
# can be applied to a DataFrame column.
normalize_category_udf = udf(f=normalize_category_func, returnType=StringType())

JSON schema of the incoming records

In [6]:
# -------------------- Schema --------------------
# Every field is declared as a nullable string, in this column order.
json_schema = StructType([
    StructField(field_name, StringType(), True)
    for field_name in ("aid", "categories", "main_category", "published", "summary", "title")
])

Process function

In [7]:
def process(time, rdd):
    """Classify one micro-batch of JSON records and print its accuracy.

    Steps: make sure the model is loaded (loading it on first use), parse the
    RDD with ``json_schema``, derive the normalised main category and a
    ``title + summary`` text column, run inference on the texts, map predicted
    indices to label names and compare them with the normalised category.

    Every early exit prints its reason. Any exception raised while handling
    the batch is printed with a traceback and swallowed.

    Args:
        time: Batch timestamp; only used in the log line.
        rdd: RDD of JSON strings for this batch.

    Returns:
        None
    """
    print(f"\n=== Processing batch at {time} ===")

    state = globals()
    if not state['models_loaded']:
        # Lazy one-time load; a failed load leaves the flag false.
        load_model_and_labels()
    if not state['models_loaded']:
        print("Model not loaded. Skipping batch.")
        return

    try:
        if rdd.isEmpty():
            print("Empty RDD. Skipping.")
            return

        parsed = spark.read.json(rdd, schema=json_schema)
        if parsed.rdd.isEmpty():
            print("Parsed DataFrame is empty.")
            return

        # Feature preparation: ground-truth label and model input text, then
        # drop rows where either column is null.
        prepared = (
            parsed
            .withColumn("normalized_main_category", normalize_category_udf(col("main_category")))
            .withColumn("text", concat_ws(" ", col("title"), col("summary")))
            .select("text", "normalized_main_category")
            .dropna()
        )
        if prepared.rdd.isEmpty():
            print("No valid rows after preprocessing.")
            return

        batch = prepared.toPandas()
        if batch.empty:
            print("Empty pandas DataFrame.")
            return

        predictions = run_inference(batch["text"].tolist())
        if not (predictions and len(predictions) == len(batch)):
            print("Mismatch or no predictions.")
            return

        batch["predicted_category_idx"] = predictions
        labels = state['category_labels']

        if labels:
            def label_for(idx):
                # Out-of-range indices get a recognisable placeholder name.
                if 0 <= idx < len(labels):
                    return labels[idx]
                return f"unknown_idx_{idx}"

            batch["predicted_category_name"] = batch["predicted_category_idx"].apply(label_for)
        else:
            # No label names available: report the raw index as text.
            batch["predicted_category_name"] = batch["predicted_category_idx"].astype(str)

        total = len(batch)
        correct = (batch["normalized_main_category"] == batch["predicted_category_name"]).sum()
        if total:
            accuracy = correct / total
        else:
            accuracy = 0

        print(f"Accuracy: {accuracy:.4f} ({correct}/{total})")
        print("Sample Predictions:")
        print(batch[["normalized_main_category", "predicted_category_name"]].head())

    except Exception as e:
        print(f"Batch processing error: {e}")
        import traceback
        traceback.print_exc()

Streaming

In [8]:
# Micro-batch interval in seconds. Named so the shutdown code that references
# BATCH_INTERVAL_SECONDS has a definition to read.
BATCH_INTERVAL_SECONDS = 30
ssc = StreamingContext(sc, BATCH_INTERVAL_SECONDS)




=== Processing batch at 2025-05-16 20:47:30 ===


In [9]:
# Text stream read from the remote socket; each micro-batch RDD is handed to
# process(time, rdd).
lines = ssc.socketTextStream(hostname="seppe.net", port=7778)
lines.foreachRDD(process)

Threading for streaming

In [10]:
class StreamingThread(threading.Thread):
    """Background thread that starts a streaming context and stops it on request.

    The worker thread is the only place that stops the context: ``run`` starts
    it, blocks until ``stop_stream`` signals the stop event, and then performs
    one graceful stop (leaving the underlying SparkContext running).
    """

    def __init__(self, ssc_instance):
        """Store the streaming context to drive.

        Args:
            ssc_instance: Object exposing ``start()`` and
                ``stop(stopSparkContext=..., stopGraceFully=...)``.
        """
        super().__init__()
        self.ssc_instance = ssc_instance
        self._stop_event = threading.Event()

    def run(self):
        """Start the context, wait for the stop signal, then stop gracefully."""
        try:
            print("Starting streaming context...")
            self.ssc_instance.start()
            # Block until stop_stream() is called.
            self._stop_event.wait()
        except Exception as e:
            print(f"Streaming error: {e}")
            import traceback
            traceback.print_exc()
        finally:
            # Single owner of the stop call; the SparkContext is kept alive.
            self.ssc_instance.stop(stopSparkContext=False, stopGraceFully=True)

    def stop_stream(self):
        """Ask the worker thread to stop the streaming context.

        Only signals the worker; the actual stop happens in ``run``. This
        avoids stopping the context twice from two threads and does not rely
        on state-query helpers of the context object.
        """
        print("Stopping streaming context...")
        self._stop_event.set()

In [None]:
if __name__ == "__main__":
    # Run the stream on a background thread and keep this cell responsive to
    # Ctrl+C / kernel interrupt by joining in short slices.
    ssc_t = StreamingThread(ssc)
    try:
        ssc_t.start()
        print("Streaming started. Press Ctrl+C to stop.")
        while ssc_t.is_alive():
            ssc_t.join(timeout=1.0)
    except KeyboardInterrupt:
        print("\nKeyboardInterrupt received. Stopping...")
    finally:
        if ssc_t.is_alive():
            # Give a graceful stop up to two batch intervals plus slack.
            # BATCH_INTERVAL_SECONDS may not be defined in this notebook, so
            # fall back to the 30 s interval the StreamingContext is created
            # with instead of failing with a NameError during shutdown.
            batch_interval = globals().get("BATCH_INTERVAL_SECONDS", 30)
            ssc_t.stop_stream()
            ssc_t.join(timeout=batch_interval * 2 + 5)
        print("Shutdown complete.")

Starting streaming context...Streaming started. Press Ctrl+C to stop.



In [None]:
# threading.Thread has no stop() method; StreamingThread's shutdown entry
# point is stop_stream().
ssc_t.stop_stream()