In [1]:
import os

# Locate the bundled Spark distribution and the winutils helpers relative to
# the notebook's working directory.
spark_home = os.path.abspath(os.path.join(os.getcwd(), "spark", "spark-3.5.5-bin-hadoop3"))
hadoop_home = os.path.abspath(os.path.join(os.getcwd(), "spark", "winutils"))
print(f"I am using the following SPARK_HOME: {spark_home}")

if os.name == 'nt':
    # On Windows, point HADOOP_HOME at the winutils directory and put its
    # bin/ folder first on PATH.
    os.environ["HADOOP_HOME"] = hadoop_home
    print(f"Windows detected: set HADOOP_HOME to: {os.environ['HADOOP_HOME']}")
    hadoop_bin = os.path.join(hadoop_home, "bin")
    os.environ["PATH"] = hadoop_bin + os.pathsep + os.environ["PATH"]
    print(f"  Also added Hadoop bin directory to PATH: {hadoop_bin}")

# findspark.init() must run BEFORE pyspark is imported: it puts SPARK_HOME's
# Python libraries on sys.path. Importing pyspark first would load whichever
# pyspark was already importable, possibly a different version.
import findspark
findspark.init(spark_home)

import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import from_json, col, struct
from pyspark.sql.types import StructType, StructField, StringType

spark = SparkSession.builder.appName("ArxivStreaming").getOrCreate()
sc = spark.sparkContext


I am using the following SPARK_HOME: d:\OneDrive - CGIAR\Master\Advanced Analytics\assignments\assignment-03\spark\spark-3.5.5-bin-hadoop3
Windows detected: set HADOOP_HOME to: d:\OneDrive - CGIAR\Master\Advanced Analytics\assignments\assignment-03\spark\winutils
  Also added Hadoop bin directory to PATH: d:\OneDrive - CGIAR\Master\Advanced Analytics\assignments\assignment-03\spark\winutils\bin


In [2]:
import requests
import time
from collections import Counter
from sklearn.metrics import f1_score, balanced_accuracy_score
from pyspark.sql.functions import udf, struct, col
from pyspark.sql.types import StringType
from src.utils import map_category

# Metric counters updated from inside `predict`, which runs in Spark executor
# processes. Plain driver-side variables would not see those updates, so
# Spark accumulators are used and read back on the driver via `.value`.
predictions_accumulator = sc.accumulator(0)             # successful (HTTP 200) API responses
valid_predictions_accumulator = sc.accumulator(0)       # successful responses not equal to an error sentinel
total_inference_time_accumulator = sc.accumulator(0.0)  # summed latency (ms) of successful responses

# Driver-side record of {'prediction', 'true_label'} dicts, appended by process_batch.
predictions_log = list()

def predict(row, url="http://localhost:8000/predict", timeout=20):
    """Classify one paper by calling the FastAPI prediction service.

    Used as the body of ``predict_udf``, so it runs on Spark executors.

    Args:
        row: A struct row exposing ``title`` and ``summary`` attributes; both
            are sent as strings in the JSON request body.
        url: Prediction endpoint. It must answer with JSON containing a
            ``"prediction"`` key.
        timeout: Seconds to wait for the service before giving up. The same
            value is reported in the timeout message.

    Returns:
        The predicted label. On failure it returns one of the sentinel strings
        ``"api_error"``, ``"timeout_error"`` or ``"connection_error"``.

    Side effects:
        On an HTTP 200 response, increments ``predictions_accumulator``, adds
        the latency in milliseconds to ``total_inference_time_accumulator``,
        and increments ``valid_predictions_accumulator`` unless the returned
        prediction is itself an error sentinel.
    """
    start_time = time.perf_counter()  # monotonic clock, suited to measuring durations

    try:
        response = requests.post(
            url,
            json={"title": str(row.title), "summary": str(row.summary)},
            timeout=timeout,
        )
        inference_time = (time.perf_counter() - start_time) * 1000  # milliseconds

        if response.status_code != 200:
            print(f"API returned status: {response.status_code}")
            return "api_error"

        # A malformed body or missing key falls through to the generic handler below.
        prediction = response.json()["prediction"]

        # Only successful responses contribute to the latency statistics.
        predictions_accumulator.add(1)
        total_inference_time_accumulator.add(inference_time)
        if prediction not in ("api_error", "timeout_error", "connection_error"):
            valid_predictions_accumulator.add(1)

        return prediction

    # Timeout is caught before ConnectionError because requests' ConnectTimeout
    # subclasses both.
    except requests.exceptions.Timeout:
        print(f"API call timed out after {timeout} seconds")
        return "timeout_error"
    except requests.exceptions.ConnectionError:
        print("Cannot connect to FastAPI service")
        return "connection_error"
    except Exception as e:
        # Best-effort: a single failing row must not kill the streaming query.
        print(f"API call failed: {e}")
        return "api_error"


# Wrap the Python callables as Spark SQL UDFs; both produce string columns.
predict_udf = udf(predict, returnType=StringType())
map_category_udf = udf(map_category, returnType=StringType())


def print_performance_metrics():
    """Print classification quality and mean inference latency so far.

    Macro F1 and balanced accuracy are computed from the driver-side
    ``predictions_log``. The mean latency comes from the Spark accumulators
    that ``predict`` updates.

    If fewer than two labelled predictions exist, prints a notice and returns
    early. Exceptions from the sklearn metric functions are reported, not
    raised, so a bad batch does not stop the stream.
    """
    # Mean latency (ms) of successful API calls, or None if none succeeded yet.
    # Bound unconditionally so the report below never hits an unbound name.
    n_calls = predictions_accumulator.value
    avg_inference_time = (
        total_inference_time_accumulator.value / n_calls if n_calls > 0 else None
    )

    valid_preds = [p for p in predictions_log if p['true_label'] is not None]

    if len(valid_preds) < 2:
        print("Need more valid predictions for detailed metrics")
        return

    preds = [p['prediction'] for p in valid_preds]
    trues = [p['true_label'] for p in valid_preds]

    try:
        macro_f1 = f1_score(trues, preds, average='macro', zero_division=0)
        balanced_acc = balanced_accuracy_score(trues, preds)
    except Exception as e:
        print(f"Error calculating metrics: {e}")
    else:
        print(f"\n--- PERFORMANCE METRICS ---")
        print(f"Valid predictions: {len(valid_preds)}")
        print(f"Macro F1: {macro_f1:.3f}")
        print(f"Balanced Accuracy: {balanced_acc:.3f}")
        if avg_inference_time is None:
            print("Average inference time: n/a")
        else:
            print(f"Average inference time: {avg_inference_time:.1f}ms")

    print("=" * 40)

def reset_metrics():
    """Reset all monitoring statistics to zero.

    Clears the driver-side ``predictions_log`` in place. Zeroes the three
    Spark accumulators through their driver-only ``value`` setter.

    The accumulators are reset rather than replaced. PySpark serializes
    ``predict`` when ``predict_udf`` is first used, and that serialization
    includes the accumulator objects ``predict`` references. Freshly created
    accumulators would therefore never receive the UDF's updates, and the
    reported latency would stay at zero.
    """
    predictions_log.clear()

    predictions_accumulator.value = 0
    valid_predictions_accumulator.value = 0
    total_inference_time_accumulator.value = 0.0

    print("Metrics reset!")

In [3]:
# Structured Streaming setup: each text line read from the socket is parsed as
# a JSON document with four nullable string fields describing one paper.
arxiv_schema = StructType(
    [
        StructField(field_name, StringType(), True)
        for field_name in ("title", "summary", "main_category", "published")
    ]
)

# Unbounded DataFrame with a single string column, `value`, holding raw lines.
stream_df = (
    spark.readStream
    .format("socket")
    .option("host", "seppe.net")
    .option("port", 7778)
    .load()
)

# Parse the JSON, flatten it, then attach the mapped true label and the model prediction.
paper_struct = struct(col("title"), col("summary"), col("main_category"))
processed_df = (
    stream_df
    .select(from_json(col("value"), arxiv_schema).alias("data"))
    .select("data.*")
    .withColumn("label", map_category_udf(col("main_category")))
    .withColumn("pred", predict_udf(paper_struct))
)


In [None]:
def process_batch(batch_df, batch_id):
    """Display one micro-batch and add its labelled predictions to the metrics.

    Invoked on the driver by ``foreachBatch`` for every micro-batch.

    Args:
        batch_df: The micro-batch DataFrame. It must contain ``title``,
            ``main_category``, ``published``, ``label`` and ``pred``.
        batch_id: Micro-batch identifier supplied by Spark; used only in the
            printed header.
    """
    # `pred` is produced by a Python UDF that calls the prediction API. Cache
    # the batch so the UDF runs once per row. Otherwise each action that needs
    # `pred` (show, collect) re-runs it, re-sending HTTP requests and
    # double-counting the accumulators.
    batch_df.persist()
    try:
        paper_count = batch_df.count()
        print(f"========= Batch {batch_id} (Papers: {paper_count}) =========")
        if paper_count == 0:
            return

        batch_df.select("title", "main_category", "published", "label", "pred").show(10, truncate=True)

        # Collect for metrics - exclude rows without a label and all error types.
        error_preds = {"api_error", "timeout_error", "connection_error"}
        for row in batch_df.select("label", "pred").collect():
            if row.label and row.pred and row.pred not in error_preds:
                predictions_log.append({
                    'prediction': row.pred,
                    'true_label': row.label
                })
        print_performance_metrics()
    finally:
        batch_df.unpersist()

# Start the query: every 5 seconds the newly arrived rows are handed to
# process_batch as a regular DataFrame. Progress is checkpointed to the
# path below.
query = (
    processed_df.writeStream
    .foreachBatch(process_batch)
    .option("checkpointLocation", "/tmp/arxiv_checkpoint")
    .trigger(processingTime='5 seconds')
    .start()
)

print("Structured Streaming started!")


Structured Streaming started!


+--------------------+-------------+--------------------+-----+----+
|               title|main_category|           published|label|pred|
+--------------------+-------------+--------------------+-----+----+
|Extending Recent ...|      math.NT|2025-05-28T05:33:54Z| math|math|
|Leveraging LLM fo...|        cs.SD|2025-05-28T06:12:19Z|   cs|eess|
|GL-PGENet: A Para...|        cs.CV|2025-05-28T06:37:06Z|   cs|  cs|
|Jailbreak Distill...|        cs.CL|2025-05-28T06:59:46Z|   cs|  cs|
|AudioGenie: A Tra...|        cs.SD|2025-05-28T07:23:53Z|   cs|  cs|
|Delayed-KD: Delay...|        cs.SD|2025-05-28T07:51:21Z|   cs|eess|
|PADAM: Parallel a...|      math.OC|2025-05-28T08:07:34Z| math|  cs|
|MemOS: An Operati...|        cs.CL|2025-05-28T08:27:12Z|   cs|  cs|
|Polarforming Desi...|      eess.SP|2025-05-28T05:36:05Z| eess|eess|
|Efficiently Enhan...|        cs.AI|2025-05-28T06:12:51Z|   cs|  cs|
+--------------------+-------------+--------------------+-----+----+
only showing top 10 rows


--- PER

In [5]:
# Shut down the streaming query so no further micro-batches (and prediction API calls) are triggered.
query.stop()


In [None]:
# Optional: delete the checkpoint directory so the next run starts again from batch 0 with fresh monitoring (uncomment to use).

# import shutil
# import os

# # Clear checkpoint directory
# checkpoint_path = "/tmp/arxiv_checkpoint"
# if os.path.exists(checkpoint_path):
#     shutil.rmtree(checkpoint_path)
#     print(f"Cleared checkpoint: {checkpoint_path}")


Cleared checkpoint: /tmp/arxiv_checkpoint
