In [None]:
import os

# Local Spark distribution and the Hadoop Windows helpers (winutils), both under ./spark.
spark_home = os.path.abspath(os.path.join(os.getcwd(), "spark", "spark-3.5.5-bin-hadoop3"))
hadoop_home = os.path.abspath(os.path.join(os.getcwd(), "spark", "winutils"))
print(f"I am using the following SPARK_HOME: {spark_home}")
if os.name == 'nt':
    # On Windows, Spark/Hadoop look for their native helpers via HADOOP_HOME and PATH.
    os.environ["HADOOP_HOME"] = hadoop_home
    print(f"Windows detected: set HADOOP_HOME to: {os.environ['HADOOP_HOME']}")
    hadoop_bin = os.path.join(hadoop_home, "bin")
    os.environ["PATH"] = f"{hadoop_bin}{os.pathsep}{os.environ.get('PATH', '')}"
    print(f"  Also added Hadoop bin directory to PATH: {hadoop_bin}")

import findspark

# findspark must run BEFORE pyspark is imported: it sets SPARK_HOME and puts the
# pyspark shipped with that distribution on sys.path.
findspark.init(spark_home)

import pyspark
from pyspark.sql import SparkSession
from pyspark.streaming import StreamingContext

# getOrCreate() keeps this cell re-runnable; a second SparkContext() would raise.
sc = pyspark.SparkContext.getOrCreate()
spark = SparkSession.builder.getOrCreate()


In [None]:
import threading

# Helper thread that keeps the Spark StreamingContext from blocking the Jupyter kernel
        
class StreamingThread(threading.Thread):
    """Background thread that runs a StreamingContext's blocking lifecycle.

    Calling the inherited ``start()`` executes ``run`` on a new thread. That
    thread starts the streaming context and then blocks in
    ``awaitTermination()``, which leaves the notebook's main thread free.
    ``stop()`` asks the context to stop gracefully while leaving the
    underlying SparkContext running.
    """

    def __init__(self, ssc):
        threading.Thread.__init__(self)
        self.ssc = ssc

    def run(self):
        streaming_ctx = self.ssc
        streaming_ctx.start()
        streaming_ctx.awaitTermination()

    def stop(self):
        print('----- Stopping... this may take a few seconds -----')
        # Keep the SparkContext alive for later cells; let received data finish processing.
        self.ssc.stop(stopGraceFully=True, stopSparkContext=False)
        

In [None]:
import requests
import time
from collections import Counter
from sklearn.metrics import f1_score, balanced_accuracy_score
from pyspark.sql.functions import udf, struct, col
from pyspark.sql.types import StringType
from src.utils import map_category

# Spark accumulators: code running in tasks calls .add(), the driver reads .value.
# Created left to right: prediction count, valid-prediction count,
# summed inference time in milliseconds.
predictions_accumulator, valid_predictions_accumulator, total_inference_time_accumulator = (
    sc.accumulator(0),
    sc.accumulator(0),
    sc.accumulator(0.0),
)

# Driver-side list of {'prediction', 'true_label'} dicts used for the detailed metrics.
predictions_log = []

def predict(row, url="http://localhost:8000/predict", timeout=10):
    """Classify one paper by calling the FastAPI prediction service.

    Used as a Spark Python UDF, so it runs inside Spark tasks. Its ``print``
    output presumably ends up in the worker logs rather than the notebook.
    On HTTP 200 it returns the ``"prediction"`` field of the JSON response and
    records the call in ``predictions_accumulator`` and its latency (ms) in
    ``total_inference_time_accumulator``.

    Failures never raise. The function returns a sentinel string instead, so
    the row still gets a value.

    Args:
        row: struct-like row exposing ``title`` and ``summary`` attributes.
        url: prediction endpoint; defaults to the local FastAPI service.
        timeout: request timeout in seconds.

    Returns:
        The predicted label, or one of the sentinels ``"timeout_error"``,
        ``"connection_error"`` or ``"api_error"``.
    """
    start_time = time.perf_counter()  # monotonic clock for latency measurement

    try:
        # Missing fields are sent as "" instead of the literal string "None".
        payload = {
            "title": "" if row.title is None else str(row.title),
            "summary": "" if row.summary is None else str(row.summary),
        }
        response = requests.post(url, json=payload, timeout=timeout)

        inference_time = (time.perf_counter() - start_time) * 1000

        if response.status_code != 200:
            print(f"API returned status: {response.status_code}")
            return "api_error"

        prediction = response.json()["prediction"]

        # Only successful calls are counted, so the average latency covers successes only.
        predictions_accumulator.add(1)
        total_inference_time_accumulator.add(inference_time)

        return prediction

    except requests.exceptions.Timeout:
        print(f"API call timed out after {timeout} seconds")
        return "timeout_error"
    except requests.exceptions.ConnectionError:
        print("Cannot connect to FastAPI service")
        return "connection_error"
    except Exception as e:
        print(f"API call failed: {e}")
        return "api_error"


# predict() calls an external HTTP service and updates accumulators, so it is marked
# non-deterministic: Spark may otherwise eliminate or repeat calls to a UDF it assumes
# is deterministic.
predict_udf = udf(predict, StringType()).asNondeterministic()
map_category_udf = udf(map_category, StringType())

def process(time_batch, rdd):
    """Score one streaming micro-batch and log the results for monitoring.

    Registered with ``foreachRDD``, so it is invoked once per batch.
    - The batch's lines are parsed with ``spark.read.json``.
    - A ``label`` column is derived from ``main_category`` via ``map_category_udf``.
    - Every row is scored through ``predict_udf``.
    - Rows with a true label and a real (non-error) prediction are appended to
      ``predictions_log`` and counted in ``valid_predictions_accumulator``.
    - The running metrics are printed at the end.

    Args:
        time_batch: batch time supplied by Spark Streaming (only printed).
        rdd: RDD of raw text lines received during the batch.
    """
    # Sentinel strings predict() returns instead of a class when the API call fails.
    error_results = {"api_error", "timeout_error", "connection_error"}
    required_columns = {"title", "summary", "main_category", "published"}

    if rdd.isEmpty():
        return

    print(f"========= {str(time_batch)} =========")

    df = spark.read.json(rdd)

    if df.count() == 0:
        return

    missing = required_columns - set(df.columns)
    if missing:
        # e.g. every line in the batch was malformed JSON; selecting the columns would raise.
        print(f"Skipping batch: missing expected columns {sorted(missing)}")
        return

    print("Sample data:")
    # Add label column to verify mapping works correctly
    df = df.withColumn("label", map_category_udf(col("main_category")))
    df.select("title", "published", "main_category", "label").show()

    # show() and collect() below are two separate actions. Without caching, each one
    # re-runs predict_udf, duplicating API calls and double-counting accumulators.
    df_withpreds = df.withColumn("pred", predict_udf(
        struct(col("title"), col("summary"), col("main_category"))
    )).cache()
    try:
        df_withpreds.select("title", "main_category", "label", "pred").show()

        # Collect results to driver for detailed analysis
        results = df_withpreds.select("main_category", "label", "pred").collect()
    finally:
        df_withpreds.unpersist()

    for row in results:
        true_label = row.label
        prediction = row.pred
        if true_label and prediction and prediction not in error_results:
            predictions_log.append({
                'prediction': prediction,
                'true_label': true_label
            })
            valid_predictions_accumulator.add(1)

    print_performance_metrics()

def print_performance_metrics():
    """Print the monitoring counters and, when enough data exists, classification scores.

    Always reports the three accumulator values and the size of the driver-side
    ``predictions_log``. The average latency is shown whenever at least one
    prediction was counted. Macro F1 and balanced accuracy are computed with
    scikit-learn over every logged entry whose true label is not None. They
    are reported only once at least two such entries exist.
    """
    print("\n--- ACCUMULATOR METRICS ---")
    print(f"Total predictions (accumulator): {predictions_accumulator.value}")
    print(f"Valid predictions (accumulator): {valid_predictions_accumulator.value}")
    print(f"Driver-side log: {len(predictions_log)} entries")

    if predictions_accumulator.value > 0:
        # Accumulated milliseconds divided by the number of counted predictions.
        mean_ms = total_inference_time_accumulator.value / predictions_accumulator.value
        print(f"Average inference time: {mean_ms:.1f}ms")

    # Detailed metrics are skipped until at least two entries are available.
    if len(predictions_log) < 2:
        print("Need more predictions in driver log for detailed metrics")
        return

    labelled = [entry for entry in predictions_log if entry['true_label'] is not None]
    if len(labelled) < 2:
        print("Need more valid predictions for detailed metrics")
        return

    y_true = [entry['true_label'] for entry in labelled]
    y_pred = [entry['prediction'] for entry in labelled]

    try:
        macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
        balanced_acc = balanced_accuracy_score(y_true, y_pred)
    except Exception as exc:
        print(f"Error calculating metrics: {exc}")
    else:
        print("\n--- PERFORMANCE METRICS ---")
        print(f"Valid predictions: {len(labelled)}")
        print(f"Macro F1: {macro_f1:.3f}")
        print(f"Balanced Accuracy: {balanced_acc:.3f}")

    print("=" * 40)

def reset_metrics():
    """Reset all monitoring statistics back to zero.

    The accumulators are reset in place through their driver-side ``value``
    setter rather than replaced with new ones. PySpark serializes the UDF's
    function, together with the accumulator objects ``predict`` references,
    when ``predict_udf`` is first applied, and it reuses that serialized copy.
    Replacement accumulators would therefore never receive the tasks' updates
    after the first batch.

    The driver log is cleared in place so any other reference to the list
    also sees it emptied.
    """
    predictions_log.clear()

    predictions_accumulator.value = 0
    valid_predictions_accumulator.value = 0
    total_inference_time_accumulator.value = 0.0

    print("Metrics reset!")

In [None]:
# Streaming context on top of the existing SparkContext, cutting 10-second micro-batches.
ssc = StreamingContext(sparkContext=sc, batchDuration=10)


In [None]:
# Receive text lines from seppe.net:7778; every batch's RDD is handed to process(),
# which parses the lines as JSON.
lines = ssc.socketTextStream(hostname="seppe.net", port=7778)
lines.foreachRDD(process)


In [None]:
# Run start()/awaitTermination() on a background thread so this cell returns immediately.
ssc_t = StreamingThread(ssc=ssc)
ssc_t.start()


In [None]:
# Gracefully stop the streaming context; StreamingThread.stop keeps the SparkContext `sc` alive.
ssc_t.stop()
