In [1]:
import os
spark_home = os.path.abspath(os.getcwd() + "/spark/spark-3.5.5-bin-hadoop3")
hadoop_home = os.path.abspath(os.getcwd() + "/spark/winutils")
print(f"I am using the following SPARK_HOME: {spark_home}")
if os.name == 'nt':
    os.environ["HADOOP_HOME"] = f"{hadoop_home}"
    print(f"Windows detected: set HADOOP_HOME to: {os.environ['HADOOP_HOME']}")
    hadoop_bin = os.path.join(hadoop_home, "bin")
    os.environ["PATH"] = f"{hadoop_bin};{os.environ['PATH']}"
    print(f"  Also added Hadoop bin directory to PATH: {hadoop_bin}")

import findspark
import pyspark
from pyspark.streaming import StreamingContext

findspark.init(spark_home)
sc = pyspark.SparkContext()
spark = pyspark.sql.SparkSession.builder.getOrCreate()


I am using the following SPARK_HOME: d:\OneDrive - CGIAR\Master\Advanced Analytics\assignments\assignment-03\spark\spark-3.5.5-bin-hadoop3
Windows detected: set HADOOP_HOME to: d:\OneDrive - CGIAR\Master\Advanced Analytics\assignments\assignment-03\spark\winutils
  Also added Hadoop bin directory to PATH: d:\OneDrive - CGIAR\Master\Advanced Analytics\assignments\assignment-03\spark\winutils\bin


In [2]:
import threading

# Helper thread to avoid the Spark StreamingContext from blocking Jupyter
        
class StreamingThread(threading.Thread):
    def __init__(self, ssc):
        super().__init__()
        self.ssc = ssc
    def run(self):
        self.ssc.start()
        self.ssc.awaitTermination()
    def stop(self):
        print('----- Stopping... this may take a few seconds -----')
        self.ssc.stop(stopSparkContext=False, stopGraceFully=True)
        

In [3]:
import requests
import time
from collections import Counter
from sklearn.metrics import f1_score, balanced_accuracy_score
from pyspark.sql.functions import udf, struct, col
from pyspark.sql.types import StringType
from src.utils import map_category

# Global monitoring variables
predictions_log = []
inference_times = []

def predict(row):
    """Predict using FastAPI service with monitoring"""
    start_time = time.time()
    
    try:
        response = requests.post("http://localhost:8000/predict", 
                               json={"title": str(row.title), "summary": str(row.summary)},
                               timeout=5)
        
        inference_time = (time.time() - start_time) * 1000  # ms
        inference_times.append(inference_time)
        
        if response.status_code == 200:
            result = response.json()
            prediction = result["prediction"]
            
            # Extract true label from main_category and map to parent category
            true_category = getattr(row, 'main_category', None)
            true_label = map_category(true_category) if true_category else None
            
            # Log prediction for metrics calculation
            predictions_log.append({
                'prediction': prediction,
                'true_label': true_label,
                'inference_time_ms': inference_time
            })
            
            return prediction
            
    except Exception as e:
        inference_times.append((time.time() - start_time) * 1000)
        print(f"API call failed: {e}")
        return "api_error"

predict_udf = udf(predict, StringType())

def process(time_batch, rdd):
    """Process streaming batch with predictions and monitoring"""
    if rdd.isEmpty(): 
        return
        
    print(f"========= {str(time_batch)} =========")
    
    df = spark.read.json(rdd)
    
    # Show original data structure
    if df.count() > 0:
        print("Sample data:")
        df.select("title", "main_category").show(2, truncate=True)
        
        # Apply predictions with monitoring
        df_withpreds = df.withColumn("pred", predict_udf(
            struct(col("title"), col("summary"), col("main_category"))
        ))
        
        df_withpreds.select("title", "main_category", "pred").show(5, truncate=True)
        print_performance_metrics()

def print_performance_metrics():
    """Print classification performance and inference speed metrics"""
    if len(predictions_log) < 2: 
        print("Need more predictions for metrics calculation")
        return
    
    # Filter predictions with valid true labels
    valid_preds = [p for p in predictions_log if p['true_label'] is not None]
    
    if len(valid_preds) < 2:
        print("Need more predictions with valid true labels")
        return
    
    # Extract predictions and true labels
    preds = [p['prediction'] for p in valid_preds]
    trues = [p['true_label'] for p in valid_preds]
    
    # Classification performance metrics
    try:
        macro_f1 = f1_score(trues, preds, average='macro', zero_division=0)
        balanced_acc = balanced_accuracy_score(trues, preds)
        
        print(f"\n--- PERFORMANCE METRICS ---")
        print(f"Valid predictions: {len(valid_preds)}")
        print(f"Macro F1: {macro_f1:.3f}")
        print(f"Balanced Accuracy: {balanced_acc:.3f}")
    except Exception as e:
        print(f"Error calculating metrics: {e}")
    
    # Inference speed metrics
    if len(inference_times) > 0:
        avg_time = sum(inference_times) / len(inference_times)
        min_time = min(inference_times)
        max_time = max(inference_times)
        
        print(f"\n--- INFERENCE SPEED ---")
        print(f"Avg: {avg_time:.1f}ms | Min: {min_time:.1f}ms | Max: {max_time:.1f}ms")
        print(f"Total predictions: {len(predictions_log)}")
    
    print("=" * 40)

def get_category_distribution():
    """Show distribution of true vs predicted categories"""
    if len(predictions_log) < 1: 
        return
    
    valid_preds = [p for p in predictions_log if p['true_label'] is not None]
    if len(valid_preds) < 1: 
        return
        
    true_dist = Counter([p['true_label'] for p in valid_preds])
    pred_dist = Counter([p['prediction'] for p in valid_preds])
    
    print("\n--- CATEGORY DISTRIBUTION ---")
    print("True labels:", dict(true_dist))
    print("Predictions:", dict(pred_dist))

def reset_metrics():
    """Reset all monitoring statistics"""
    global predictions_log, inference_times
    predictions_log = []
    inference_times = []
    print("Metrics reset!")
    

In [4]:
ssc = StreamingContext(sc, 10)




Sample data:
+--------------------+-------------+
|               title|main_category|
+--------------------+-------------+
|HCQA-1.5 @ Ego4D ...|        cs.CV|
|STEER-BENCH: A Be...|        cs.CL|
+--------------------+-------------+
only showing top 2 rows

+--------------------+-------------+----+
|               title|main_category|pred|
+--------------------+-------------+----+
|HCQA-1.5 @ Ego4D ...|        cs.CV|  cs|
|STEER-BENCH: A Be...|        cs.CL|  cs|
|Evaluating Traini...|        cs.LG|stat|
|Moment Expansions...|      stat.ML|stat|
|Voronoi-grid-base...|        cs.LG|  cs|
+--------------------+-------------+----+
only showing top 5 rows

Need more predictions for metrics calculation
Sample data:
+--------------------+-------------+
|               title|main_category|
+--------------------+-------------+
|TeroSeek: An AI-P...|        cs.IR|
|Self-Route: Autom...|        cs.CL|
+--------------------+-------------+
only showing top 2 rows

+--------------------+---------

In [5]:
lines = ssc.socketTextStream("seppe.net", 7778)
lines.foreachRDD(process)


In [6]:
ssc_t = StreamingThread(ssc)
ssc_t.start()


In [None]:
ssc_t.stop()
