In [2]:
import os

import findspark

# Both paths are resolved relative to the notebook's current working directory.
spark_home = os.path.abspath(os.path.join("spark", "spark-3.5.5-bin-hadoop3"))
hadoop_home = os.path.abspath(os.path.join("spark", "winutils"))
print(f"I am using the following SPARK_HOME: {spark_home}")

if os.name == "nt":
    # On Windows, point HADOOP_HOME at the bundled winutils directory and put its
    # bin/ folder first on PATH so the Spark JVM launched below can find it.
    os.environ["HADOOP_HOME"] = hadoop_home
    print(f"Windows detected: set HADOOP_HOME to: {os.environ['HADOOP_HOME']}")
    hadoop_bin = os.path.join(hadoop_home, "bin")
    os.environ["PATH"] = f"{hadoop_bin}{os.pathsep}{os.environ.get('PATH', '')}"
    print(f"  Also added Hadoop bin directory to PATH: {hadoop_bin}")

# findspark.init must run BEFORE pyspark is imported: it sets SPARK_HOME and adds
# this Spark distribution's python/ directory to sys.path. Importing pyspark first
# either fails (no pip-installed pyspark) or silently binds to a different pyspark
# than the spark-3.5.5 distribution selected above.
findspark.init(spark_home)

import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import from_json, col, struct
from pyspark.sql.types import StructType, StructField, StringType

spark = SparkSession.builder.appName("ArxivStreaming").getOrCreate()
sc = spark.sparkContext


I am using the following SPARK_HOME: d:\OneDrive - CGIAR\Master\Advanced Analytics\assignments\assignment-03\spark\spark-3.5.5-bin-hadoop3
Windows detected: set HADOOP_HOME to: d:\OneDrive - CGIAR\Master\Advanced Analytics\assignments\assignment-03\spark\winutils
  Also added Hadoop bin directory to PATH: d:\OneDrive - CGIAR\Master\Advanced Analytics\assignments\assignment-03\spark\winutils\bin


In [3]:
import requests
import time
from collections import Counter
from sklearn.metrics import f1_score, balanced_accuracy_score
from pyspark.sql.functions import udf, struct, col, collect_list
from pyspark.sql.types import StringType, ArrayType
from src.utils import map_category

# Driver-side Spark accumulators that keep running totals across micro-batches.
# They are created in this exact order with these starting values.
(
    batch_count_accumulator,        # number of batches processed
    total_papers_accumulator,       # number of papers processed
    total_batch_time_accumulator,   # summed batch time in ms (float)
    valid_predictions_accumulator,  # predictions that are not error markers
    error_predictions_accumulator,  # predictions that are error markers
) = (sc.accumulator(start) for start in (0, 0, 0.0, 0, 0))

# Driver-side list of {'prediction', 'true_label'} dicts used for classification metrics.
predictions_log = []

# Prediction strings that mark a failed API call rather than a real category.
_API_ERROR_LABELS = ("api_error", "timeout_error", "connection_error")


def _record_batch_metrics(paper_count, batch_time_ms, valid_count):
    """Add one batch's totals to the module-level Spark accumulators.

    Papers not counted as valid are counted as failed predictions. The
    accumulators are looked up by global name at call time, so this keeps
    working after ``reset_metrics()`` rebinds them.
    """
    batch_count_accumulator.add(1)
    total_papers_accumulator.add(paper_count)
    total_batch_time_accumulator.add(batch_time_ms)
    valid_predictions_accumulator.add(valid_count)
    error_predictions_accumulator.add(paper_count - valid_count)


def predict_batch_api(papers_list, api_url="http://localhost:8000/predict_batch", timeout=30):
    """Classify a whole micro-batch with a single call to the batch prediction API.

    Args:
        papers_list: Sequence of rows exposing ``title`` and ``summary``
            attributes (the collected Spark rows from ``process_batch``).
        api_url: Batch prediction endpoint; receives ``{"papers": [...]}``.
        timeout: Request timeout in seconds, passed to ``requests.post``.

    Returns:
        A list with exactly one prediction per input paper, in input order.
        On a non-200 status, timeout, connection failure or any other error,
        every entry is the matching error marker (``"api_error"``,
        ``"timeout_error"`` or ``"connection_error"``). If the API returns
        fewer predictions than papers, the missing ones are ``"api_error"``.
        Extra predictions are dropped.

    Side effects:
        Prints a one-line summary and records the batch exactly once in the
        module-level accumulators.
    """
    start_time = time.time()
    paper_count = len(papers_list)

    try:
        payload = {
            "papers": [
                {"title": str(paper.title), "summary": str(paper.summary)}
                for paper in papers_list
            ]
        }
        response = requests.post(api_url, json=payload, timeout=timeout)
        # Total batch time in ms, including the network round trip.
        batch_time = (time.time() - start_time) * 1000

        if response.status_code == 200:
            result = response.json()
            predictions = list(result["predictions"])
            if len(predictions) != paper_count:
                # Keep the result aligned 1:1 with the input so accumulator counts
                # and the caller's per-paper indexing agree.
                print(f"Batch API returned {len(predictions)} predictions for {paper_count} papers")
                predictions = (predictions + ["api_error"] * paper_count)[:paper_count]
            api_inference_time = result.get("inference_time_ms")
            if not isinstance(api_inference_time, (int, float)):
                # A missing or null field would otherwise make the format spec below raise.
                api_inference_time = batch_time
            print(f"Batch API: {len(predictions)} papers in {batch_time:.1f}ms (API: {api_inference_time:.1f}ms)")
        else:
            print(f"Batch API returned status: {response.status_code}")
            predictions = ["api_error"] * paper_count

    except requests.exceptions.Timeout:
        batch_time = (time.time() - start_time) * 1000
        print(f"Batch API call timed out after {batch_time:.1f}ms")
        predictions = ["timeout_error"] * paper_count
    except requests.exceptions.ConnectionError:
        batch_time = (time.time() - start_time) * 1000
        print("Cannot connect to batch API")
        predictions = ["connection_error"] * paper_count
    except Exception as e:
        batch_time = (time.time() - start_time) * 1000
        print(f"Batch API call failed: {e}")
        predictions = ["api_error"] * paper_count

    # Recorded once, outside the try, so a failure late in the success path can no
    # longer be counted a second time by the generic exception handler.
    valid_count = sum(1 for pred in predictions if pred not in _API_ERROR_LABELS)
    _record_batch_metrics(paper_count, batch_time, valid_count)
    return predictions

# Spark UDF wrapper around src.utils.map_category, producing a string column.
map_category_udf = udf(map_category, returnType=StringType())

def print_performance_metrics():
    """Print cumulative throughput and classification metrics for all batches so far.

    Throughput figures are read from the module-level Spark accumulators.
    Classification metrics (macro F1, balanced accuracy, prediction counts) are
    computed from the driver-side ``predictions_log``. They are skipped, with a
    notice, until at least two entries with a non-None true label exist.
    """
    batches = batch_count_accumulator.value
    papers = total_papers_accumulator.value
    total_time_ms = total_batch_time_accumulator.value

    print("\n--- BATCH PROCESSING METRICS ---")
    print(f"Total batches processed: {batches}")
    print(f"Total papers processed: {papers}")
    print(f"Valid predictions: {valid_predictions_accumulator.value}")
    print(f"Failed predictions: {error_predictions_accumulator.value}")

    if batches > 0:
        avg_batch_time = total_time_ms / batches
        print(f"Average batch processing time: {avg_batch_time:.1f}ms")
        if papers > 0:
            avg_time_per_paper = total_time_ms / papers
            print(f"Average time per paper: {avg_time_per_paper:.1f}ms")
            success_rate = valid_predictions_accumulator.value / papers * 100
            print(f"Success rate: {success_rate:.1f}%")

    print(f"Driver-side log entries: {len(predictions_log)}")

    labelled = [entry for entry in predictions_log if entry['true_label'] is not None]
    if len(labelled) < 2:
        print("Need more valid predictions for detailed metrics")
        return

    predicted = [entry['prediction'] for entry in labelled]
    actual = [entry['true_label'] for entry in labelled]

    # Best effort: a metrics failure is reported but never stops the stream.
    try:
        macro_f1 = f1_score(actual, predicted, average='macro', zero_division=0)
        balanced_acc = balanced_accuracy_score(actual, predicted)

        print("\n--- CLASSIFICATION PERFORMANCE ---")
        print(f"Valid predictions for metrics: {len(labelled)}")
        print(f"Macro F1: {macro_f1:.3f}")
        print(f"Balanced Accuracy: {balanced_acc:.3f}")
        print(f"Prediction distribution: {dict(Counter(predicted))}")
    except Exception as exc:
        print(f"Error calculating metrics: {exc}")

    print("=" * 50)

def reset_metrics():
    """Discard all collected monitoring state so the metrics restart from zero.

    Empties the driver-side prediction log by rebinding it, and replaces the
    five Spark accumulators with fresh ones. The replacements are created in the
    original order with the original initial values.
    """
    global predictions_log
    global batch_count_accumulator, total_papers_accumulator, total_batch_time_accumulator
    global valid_predictions_accumulator, error_predictions_accumulator

    predictions_log = []

    # Rebind the module globals; the metric functions in this notebook look them up by name.
    (
        batch_count_accumulator,
        total_papers_accumulator,
        total_batch_time_accumulator,
        valid_predictions_accumulator,
        error_predictions_accumulator,
    ) = [sc.accumulator(initial) for initial in (0, 0, 0.0, 0, 0)]

    print("Metrics reset!")
    

In [4]:
# Structured Streaming setup: every field of the incoming JSON records is a nullable string.
arxiv_schema = StructType(
    [
        StructField(field_name, StringType(), True)
        for field_name in ("title", "summary", "main_category", "published")
    ]
)

# Streaming DataFrame of raw text lines (column "value") read from the socket feed.
stream_df = (
    spark.readStream
    .format("socket")
    .option("host", "seppe.net")
    .option("port", 7778)
    .load()
)

# Parse each line as JSON, flatten the struct and derive the coarse "label" column.
parsed_df = (
    stream_df
    .select(from_json(col("value"), arxiv_schema).alias("data"))
    .select("data.*")
    .withColumn("label", map_category_udf(col("main_category")))
)

In [None]:
# Column layout of the per-batch results table. It is explicit so Spark never has to
# infer types: inference fails when a column is None in every row of a batch.
_RESULTS_SCHEMA = StructType([
    StructField("title", StringType(), True),
    StructField("main_category", StringType(), True),
    StructField("published", StringType(), True),
    StructField("label", StringType(), True),
    StructField("pred", StringType(), True),
])


def process_batch(batch_df, batch_id):
    """``foreachBatch`` handler: classify one micro-batch and report running metrics.

    Collects the micro-batch to the driver and sends it to the prediction API in
    one request. It then shows up to 10 result rows, appends successful
    (prediction, label) pairs to ``predictions_log`` and prints the cumulative
    metrics.

    Args:
        batch_df: The micro-batch DataFrame (columns of ``parsed_df``, including
            ``title``, ``main_category``, ``published`` and ``label``).
        batch_id: Spark's micro-batch id; only used in the header line.
    """
    # A single Spark job. count() followed by collect() evaluated the micro-batch
    # (including the label UDF) twice.
    papers = batch_df.collect()
    paper_count = len(papers)
    print(f"========= Batch {batch_id} (Papers: {paper_count}) =========")
    if not papers:
        return

    predictions = predict_batch_api(papers)
    error_labels = ("api_error", "timeout_error", "connection_error")

    display_rows = []
    for i, paper in enumerate(papers):
        # Guard against the API returning fewer predictions than papers.
        prediction = predictions[i] if i < len(predictions) else "api_error"
        display_rows.append((
            paper.title,
            paper.main_category,
            paper.published,
            paper.label,
            # The display schema is string-typed; coerce defensively
            # (predictions are presumably already strings).
            None if prediction is None else str(prediction),
        ))
        # Only real predictions with a known true label feed the classification metrics.
        if paper.label and prediction and prediction not in error_labels:
            predictions_log.append({
                'prediction': prediction,
                'true_label': paper.label
            })

    results_df = spark.createDataFrame(display_rows, schema=_RESULTS_SCHEMA)
    results_df.show(10, truncate=True)

    print_performance_metrics()

# Launch the micro-batch query: every 2 seconds the newly arrived rows are handed
# to process_batch, with progress checkpointed under /tmp/arxiv_checkpoint.
query = (
    parsed_df.writeStream
    .foreachBatch(process_batch)
    .option("checkpointLocation", "/tmp/arxiv_checkpoint")
    .trigger(processingTime="2 seconds")
    .start()
)

print("Structured Streaming started with batch API!")

Structured Streaming started with batch API!


Batch API: 60 papers in 8835.7ms (API: 8817.6ms)
+--------------------+-------------+--------------------+--------+--------+
|               title|main_category|           published|   label|    pred|
+--------------------+-------------+--------------------+--------+--------+
|GL-PGENet: A Para...|        cs.CV|2025-05-28T06:37:06Z|      cs|      cs|
|Jailbreak Distill...|        cs.CL|2025-05-28T06:59:46Z|      cs|      cs|
|AudioGenie: A Tra...|        cs.SD|2025-05-28T07:23:53Z|      cs|      cs|
|Delayed-KD: Delay...|        cs.SD|2025-05-28T07:51:21Z|      cs|    eess|
|A High Accuracy S...|      math.NA|2025-05-28T06:38:20Z|    math|    math|
|Balanced Token Pr...|        cs.CV|2025-05-28T07:00:50Z|      cs|      cs|
|Voice Adaptation ...|        cs.CL|2025-05-28T07:24:40Z|      cs|      cs|
|Physical Reduced ...|     quant-ph|2025-05-28T07:52:37Z|quant-ph|quant-ph|
|Securing the Soft...|        cs.SE|2025-05-28T06:42:37Z|      cs|      cs|
|OmniAD: Detect an...|        cs.CV|202

In [None]:
# Stop the streaming query started earlier; no further micro-batches are passed to process_batch.
query.stop()


In [None]:
# Reset the streaming checkpoint so the next run of the query starts again at batch 0.
# Set RESET_CHECKPOINT = True (after query.stop()) to actually delete it; the default
# keeps "Run All" from wiping the checkpoint.
RESET_CHECKPOINT = False


def clear_checkpoint(checkpoint_path="/tmp/arxiv_checkpoint"):
    """Delete a Structured Streaming checkpoint directory if it exists.

    Args:
        checkpoint_path: Directory passed as ``checkpointLocation`` to the query.

    Returns:
        True if a directory was removed, False if there was no directory to delete.
    """
    import shutil

    # isdir (not exists) so a stray file at this path is left alone instead of
    # making rmtree raise.
    if not os.path.isdir(checkpoint_path):
        return False
    shutil.rmtree(checkpoint_path)
    print(f"Cleared checkpoint: {checkpoint_path}")
    return True


if RESET_CHECKPOINT:
    clear_checkpoint()


Cleared checkpoint: /tmp/arxiv_checkpoint
