# Apache Spark Group Assignment - Multi Feature Bitcoin Price and Direction Forecasting 
This notebook is organized in the following sections:
* [Description](#0)
* [Step 1 - Historical Data Collection](#2)
* [Step 2 - Initializing Spark Session and defining Features and Label](#3)
* [Step 3 - Raw Multi Model Training/Evaluation](#4)
* [Step 4 - Multi Model Training/Evaluation With Feature Engineering/Selection and Imbalance with SMOTE](#5)
* [Step 5 - Simulated Streaming Data Forecasting with MicroBatching](#6)

## **Description**

The objective of this project is to forecast the Bitcoin price and its direction, using a Linear Regression model for the price and a Logistic Regression model for the direction.
The features are the prices of the top 5 altcoins by market cap (ETH, BNB, XRP, ADA and SOL) and the funding rates of all six symbols, including BTC.

The data pipeline starts with API requests to two Binance endpoints, the Spot API for prices and the Futures API for funding rates. The results are stored as CSV in HDFS and then read back from there as Spark DataFrames.

The next step consists of splitting the DataFrame into training and test sets and performing feature scaling and selection before training the Linear and Logistic Regression models.

Finally, we emulate streaming data via micro-batching to predict the price and direction of Bitcoin in real time.

## **Step 1 - Historical Data collection**

In this section, we fetch historical data for both token prices and their funding rates; the combined DataFrame is what we later use to train our models.

### Imports and Constants
Imports libraries for API requests (`requests`), data manipulation (`pandas`), time handling (`time`, `datetime`), and defines constants for Binance API endpoints and cryptocurrency symbols to fetch data for.

In [3]:
import requests
import pandas as pd
import time
from datetime import datetime, timedelta

BASE_URL = "https://api.binance.com"  # Spot API for prices
FUTURES_URL = "https://fapi.binance.com"  # Futures API for funding rates

SYMBOLS = ["BTCUSDT", "ETHUSDT", "BNBUSDT", "XRPUSDT", "ADAUSDT", "SOLUSDT"]

### Fetch Binance Klines Function
Defines a function to fetch historical price data (klines) from Binance Spot API for a given symbol. Retrieves 8-hour interval data over the past year, paginating through results, and returns a DataFrame with timestamps and closing prices.

In [1]:
def fetch_binance_klines(symbol, interval="8h", days=365):
    url = f"{BASE_URL}/api/v3/klines"
    end_time = int(time.time() * 1000)  # Current time in ms
    start_time = end_time - (days * 24 * 60 * 60 * 1000)  # 1 year ago in ms
    
    params = {
        "symbol": symbol,
        "interval": interval,
        "startTime": start_time,
        "endTime": end_time,
        "limit": 1000  # Max limit per request
    }
    
    all_data = []
    while start_time < end_time:
        response = requests.get(url, params=params)
        if response.status_code == 200:
            data = response.json()
            if not data:
                break
            all_data.extend(data)
            start_time = int(data[-1][0]) + 1  # Next batch
            params["startTime"] = start_time
            time.sleep(0.5)  # Avoid rate limits
        else:
            print(f"Error fetching {symbol} klines: {response.status_code}, {response.text}")
            break
    
    df = pd.DataFrame(all_data, columns=["timestamp", "open", "high", "low", "close", "volume", 
                                         "close_time", "quote_volume", "trades", "taker_base", 
                                         "taker_quote", "ignore"])
    df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms")
    df["close"] = df["close"].astype(float)
    return df[["timestamp", "close"]].rename(columns={"close": f"{symbol}_price"})
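The loop in `fetch_binance_klines` is cursor-style pagination: each request asks for at most 1000 candles, and the next request starts 1 ms after the open time of the last candle received. The sketch below isolates that logic with a fake in-memory endpoint so it can be checked without network access; `paginate` and `fake_fetch` are hypothetical helpers, not Binance APIs.

```python
EIGHT_HOURS_MS = 8 * 60 * 60 * 1000


def paginate(fetch_page, start_ms, end_ms, limit=1000):
    """Collect rows from a cursor-style API by advancing start_ms past the last row.

    fetch_page(start_ms, end_ms, limit) stands in for one GET to /api/v3/klines;
    each returned row's first element is its open time in milliseconds.
    """
    rows = []
    while start_ms < end_ms:
        page = fetch_page(start_ms, end_ms, limit)
        if not page:
            break  # no more data in the requested range
        rows.extend(page)
        start_ms = int(page[-1][0]) + 1  # resume just after the last open time received
    return rows


def fake_fetch(start_ms, end_ms, limit):
    """Hypothetical endpoint: 8-hour candles with open times in [start_ms, end_ms]."""
    first = -(-start_ms // EIGHT_HOURS_MS) * EIGHT_HOURS_MS  # first 8h boundary at or after start_ms
    opens = range(first, end_ms + 1, EIGHT_HOURS_MS)
    return [[t, "0.0"] for t in opens[:limit]]


rows = paginate(fake_fetch, 0, 10 * EIGHT_HOURS_MS, limit=3)
print(len(rows))  # 11 candles, fetched in 4 pages of at most 3
```

Because every request starts strictly after the previous last open time, no candle is fetched twice and none is skipped, regardless of the page size.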

### Fetch Binance Funding Rate Function
Defines a function to fetch historical funding rates from Binance Futures API for a given symbol. Collects data over the past year, paginating through results, and returns a DataFrame with timestamps and funding rates.

In [None]:
def fetch_binance_funding_rate(symbol, days=365):
    """Fetch historical funding rates from Binance Futures."""
    url = f"{FUTURES_URL}/fapi/v1/fundingRate"
    end_time = int(time.time() * 1000)
    start_time = end_time - (days * 24 * 60 * 60 * 1000)
    
    params = {
        "symbol": symbol,
        "startTime": start_time,
        "endTime": end_time,
        "limit": 1000
    }
    
    all_data = []
    while start_time < end_time:
        response = requests.get(url, params=params)
        if response.status_code == 200:
            data = response.json()
            if not data:
                break
            all_data.extend(data)
            start_time = int(data[-1]["fundingTime"]) + 1
            params["startTime"] = start_time
            time.sleep(0.5)
        else:
            print(f"Error fetching {symbol} funding: {response.status_code}, {response.text}")
            break
    
    # Convert to DataFrame (timestamp, funding rate)
    df = pd.DataFrame(all_data)
    df["fundingTime"] = pd.to_datetime(df["fundingTime"], unit="ms")
    df["fundingRate"] = df["fundingRate"].astype(float)
    return df[["fundingTime", "fundingRate"]].rename(columns={"fundingTime": "timestamp", "fundingRate": f"{symbol}_funding_rate"})

### Fetch Data for All Symbols
Loops through the list of symbols, fetching price data and funding rates for each using the defined functions. Stores the resulting DataFrames in separate lists for prices and funding rates.

In [4]:
# Fetch and combine data
price_dfs = []
funding_dfs = []
for symbol in SYMBOLS:
    print(f"Fetching data for {symbol}...")
    price_df = fetch_binance_klines(symbol)
    funding_df = fetch_binance_funding_rate(symbol)
    price_dfs.append(price_df)
    funding_dfs.append(funding_df)

Fetching data for BTCUSDT...
Fetching data for ETHUSDT...
Fetching data for BNBUSDT...
Fetching data for XRPUSDT...
Fetching data for ADAUSDT...
Fetching data for SOLUSDT...


### Merge Price Data
Merges all price DataFrames into a single DataFrame, aligning them by `timestamp` using an outer join to preserve all data points across symbols.

In [5]:
# Merge price data
combined_price = price_dfs[0]
for df in price_dfs[1:]:
    combined_price = combined_price.merge(df, on="timestamp", how="outer")

### Merge Funding Rate Data
Merges all funding rate DataFrames into a single DataFrame, aligning them by `timestamp` using an outer join to preserve all funding rate data across symbols.

In [6]:
# Merge funding rate data
combined_funding = funding_dfs[0]
for df in funding_dfs[1:]:
    combined_funding = combined_funding.merge(df, on="timestamp", how="outer")

### Combine Prices and Funding Rates
Merges the combined price and funding rate DataFrames on `timestamp` using an inner join to ensure only rows with matching timestamps are kept. Drops any remaining rows with missing values.

In [7]:
# Merge prices and funding rates (align timestamps)
data = combined_price.merge(combined_funding, on="timestamp", how="inner")
data = data.dropna()  # Drop rows with missing values
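The inner join keeps a row only when the price and funding timestamps are exactly equal. If some `fundingTime` values do not fall exactly on the 8-hour boundary (for instance, a few milliseconds late), those periods are silently dropped. One way to guard against that, assuming such offsets occur, is to round funding timestamps down to the 8-hour boundary before merging (in pandas, `Series.dt.floor` on the `timestamp` column). The stdlib sketch below shows the effect on hypothetical values; `floor_to_8h` and `inner_join` are illustrative helpers.

```python
from datetime import datetime, timezone

EIGHT_HOURS_MS = 8 * 60 * 60 * 1000


def floor_to_8h(ms):
    """Round an epoch-millisecond timestamp down to 00:00, 08:00 or 16:00 UTC."""
    return ms - ms % EIGHT_HOURS_MS


def inner_join(left, right):
    """Inner-join two {timestamp_ms: value} dicts on exactly equal keys."""
    return {t: (left[t], right[t]) for t in left.keys() & right.keys()}


boundary = int(datetime(2024, 5, 29, tzinfo=timezone.utc).timestamp() * 1000)
prices = {boundary: 100.0}          # hypothetical price row
funding = {boundary + 3: 0.0001}    # hypothetical funding row 3 ms after the boundary

print(len(inner_join(prices, funding)))   # 0: the exact-match join drops the period
aligned = {floor_to_8h(t): v for t, v in funding.items()}
print(len(inner_join(prices, aligned)))   # 1: kept after flooring
```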

In [8]:
print(data)

              timestamp  BTCUSDT_price  ETHUSDT_price  BNBUSDT_price  \
0   2024-05-27 16:00:00       69436.43        3894.22         603.90   
1   2024-05-28 00:00:00       67694.00        3847.29         597.80   
2   2024-05-28 08:00:00       68374.08        3861.10         601.40   
3   2024-05-28 16:00:00       68398.39        3844.69         601.70   
4   2024-05-29 08:00:00       67379.13        3759.85         595.40   
..                  ...            ...            ...            ...   
866 2025-05-24 08:00:00      108929.70        2555.64         674.09   
867 2025-05-25 08:00:00      106967.97        2511.70         663.20   
868 2025-05-25 16:00:00      109004.19        2551.22         669.44   
869 2025-05-26 00:00:00      110094.62        2580.80         674.74   
870 2025-05-26 08:00:00      110009.82        2565.99         675.41   

     XRPUSDT_price  ADAUSDT_price  SOLUSDT_price  BTCUSDT_funding_rate  \
0           0.5340         0.4683         170.15             

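### Save Data to HDFS
Converts the merged pandas DataFrame to a Spark DataFrame and writes it as CSV (with a header row) to `hdfs://localhost:9000/datalake/raw/binance_crypto_data.csv`, overwriting any previous output, then prints a confirmation message. Spark writes this path as a directory of part files rather than as a single CSV file.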

In [8]:
from pyspark.sql import SparkSession

# Create a Spark session
spark = SparkSession.builder.appName("SaveToHDFS").getOrCreate()

# Convert pandas DataFrame to Spark DataFrame
spark_df = spark.createDataFrame(data)

# Save to HDFS as CSV (replace with your HDFS path)
hdfs_path = 'hdfs://localhost:9000/datalake/raw/binance_crypto_data.csv'
spark_df.write.csv(hdfs_path, mode='overwrite', header=True)

print(f"Data saved to HDFS at {hdfs_path}")

# Stop the Spark session
spark.stop()

25/03/13 19:45:53 WARN Utils: Your hostname, osbdet resolves to a loopback address: 127.0.0.1; using 10.0.2.15 instead (on interface enp0s1)
25/03/13 19:45:53 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/03/13 19:45:53 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
                                                                                

Data saved to HDFS at hdfs://localhost:9000/datalake/raw/binance_crypto_data.csv


## **Step 2 - Initializing Spark Session and defining Features and Label**

### Spark Imports
Imports necessary PySpark modules for creating a Spark session, manipulating data with SQL functions (`col`, `lag`), and defining window operations (`Window`).

In [9]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lag
from pyspark.sql.window import Window

### Spark Session Initialization
Initializes a Spark session with the application name "BTCPricePrediction" to enable distributed data processing.

In [10]:
# Initialize Spark session
spark = SparkSession.builder.appName("BTCPricePrediction").getOrCreate()
spark.sparkContext.setLogLevel("ERROR")

### Load Data
Loads the cryptocurrency data saved to HDFS in Step 1 into a Spark DataFrame, using the first row as the header and inferring the schema.

In [11]:
hdfs_path = 'hdfs://localhost:9000/datalake/raw/binance_crypto_data.csv'

df = spark.read.csv(hdfs_path, header=True, inferSchema=True)

df.show()

+-------------------+-------------+-------------+-------------+-------------+-------------+-------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|          timestamp|BTCUSDT_price|ETHUSDT_price|BNBUSDT_price|XRPUSDT_price|ADAUSDT_price|SOLUSDT_price|BTCUSDT_funding_rate|ETHUSDT_funding_rate|BNBUSDT_funding_rate|XRPUSDT_funding_rate|ADAUSDT_funding_rate|SOLUSDT_funding_rate|
+-------------------+-------------+-------------+-------------+-------------+-------------+-------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+
|2024-09-08 16:00:00|     54869.95|       2297.3|        503.1|       0.5295|        0.339|       130.15|           -5.245E-5|            1.827E-5|                 0.0|            3.107E-5|            8.211E-5|            6.412E-5|
|2024-09-09 00:00:00|     54974.01|      2314.33|        507.0|        0

### Define Features and Label
Defines the feature columns (prices for non-BTC symbols and funding rates for all symbols) and the target label (`BTCUSDT_price`) for regression and classification tasks.

In [12]:
# Define features and label
feature_cols = [f"{symbol}_price" for symbol in SYMBOLS[1:]] + \
               [f"{symbol}_funding_rate" for symbol in SYMBOLS]
label_col = "BTCUSDT_price"

### Create Price Direction Label
Uses a window function to compute the previous `BTCUSDT_price` (`prev_price`) and creates a binary `price_direction` column (1 if price increases, 0 if not). Drops rows with null values from the lag operation.

In [13]:
w = Window.orderBy("timestamp")
df = df.withColumn("prev_price", lag(col(label_col)).over(w))
df = df.withColumn("price_direction", (col(label_col) > col("prev_price")).cast("int"))
df = df.dropna()
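Because `lag` compares each row with the previous one, `price_direction` at time *t* describes the move from *t-1* to *t*, while the features are observed at *t* as well. The pure-Python sketch below mirrors that computation on a hypothetical price list and, for contrast, shows a next-period label (what `lead(col(label_col)).over(w)` would compute in PySpark) that pairs the features at *t* with the move from *t* to *t+1*, in case true forecasting is the goal. Both function names are illustrative.

```python
def add_direction(prices):
    """Mirror lag(...).over(Window.orderBy("timestamp")) followed by dropna().

    Returns (price, prev_price, direction) triples; the first row has no previous
    price and is dropped. direction is 1 if the price rose, 0 if it fell or stayed flat.
    """
    return [(cur, prev, int(cur > prev)) for prev, cur in zip(prices, prices[1:])]


def next_period_direction(prices):
    """Hypothetical forecasting label: 1 if the next price is higher than the current one.

    Equivalent to comparing lead(price) with price; the last row has no next price.
    """
    return [int(nxt > cur) for cur, nxt in zip(prices, prices[1:])]


prices = [100.0, 102.0, 101.0, 101.0, 105.0]  # hypothetical closes
print(add_direction(prices))
print(next_period_direction(prices))
```

Both functions produce the same up/down sequence, but the next-period version attaches each move to the earlier row, so the features at that row precede the move being predicted.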

### Assemble Features
Imports `VectorAssembler` and combines the feature columns into a single vector column named `features` for use in machine learning models.

In [14]:
# Assemble features into a vector
from pyspark.ml.feature import VectorAssembler
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
data = assembler.transform(df)

### Select Relevant Columns and Display
Selects the key columns (`timestamp`, `features`, `label_col`, `price_direction`) for modeling and displays the first 5 rows of the resulting DataFrame.

In [15]:
# Select relevant columns
data = data.select("timestamp", "features", label_col, "price_direction")
data.show(5)

+-------------------+--------------------+-------------+---------------+
|          timestamp|            features|BTCUSDT_price|price_direction|
+-------------------+--------------------+-------------+---------------+
|2024-03-14 16:00:00|[3881.7,603.2,0.6...|     71388.94|              1|
|2024-03-15 00:00:00|[3754.12,582.6,0....|     68448.29|              0|
|2024-03-15 08:00:00|[3688.9,596.3,0.6...|     68157.07|              0|
|2024-03-15 16:00:00|[3742.19,632.7,0....|     69499.85|              1|
|2024-03-16 00:00:00|[3740.19,614.7,0....|     69432.45|              0|
+-------------------+--------------------+-------------+---------------+
only showing top 5 rows



## **Step 3 - Raw Multi Model Training/Evaluation**

### Data Splitting
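Splits the input PySpark DataFrame into training and testing sets using an 80-20 split (80% for training, 20% for testing), with a fixed seed for reproducibility. `randomSplit` samples rows without regard to time, so test rows are interleaved with training rows across the whole period.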

In [16]:
# Split data (80% train, 20% test)
train_data, test_data = data.randomSplit([0.8, 0.2], seed=42)
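For time series, a chronological split keeps every test row later than every training row, so the evaluation resembles forecasting unseen future periods. In PySpark this amounts to filtering on `timestamp` against a cutoff (for example `data.filter(col("timestamp") < cutoff)`, where `cutoff` is a hypothetical timestamp you choose). The stdlib sketch below shows the logic on hypothetical rows; `chronological_split` is an illustrative helper.

```python
def chronological_split(rows, train_fraction=0.8):
    """Split rows so that every training row precedes every test row.

    rows: list of tuples whose first element is a sortable timestamp.
    """
    ordered = sorted(rows, key=lambda r: r[0])
    cut = int(len(ordered) * train_fraction)
    return ordered[:cut], ordered[cut:]


# Hypothetical (timestamp, value) rows in arbitrary order
rows = [(t, float(t)) for t in [5, 1, 4, 2, 3, 9, 7, 8, 6, 10]]
train, test = chronological_split(rows)
print([r[0] for r in train], [r[0] for r in test])
```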

### Linear Regression Model Training
Imports the Linear Regression class from PySpark ML, initializes it to predict the continuous `BTCUSDT_price` label, and trains the model on the training data.

In [None]:
# Linear Regression (for price prediction)
from pyspark.ml.regression import LinearRegression
lr = LinearRegression(featuresCol="features", labelCol=label_col)
lr_model = lr.fit(train_data)

### Linear Regression Evaluation
Generates predictions on the test data using the trained Linear Regression model, evaluates them using Root Mean Squared Error (RMSE), and prints the result to assess prediction accuracy.

In [None]:
from pyspark.ml.evaluation import RegressionEvaluator
lr_predictions = lr_model.transform(test_data)
evaluator = RegressionEvaluator(labelCol=label_col, predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(lr_predictions)
print(f"Linear Regression RMSE: {rmse}")
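RMSE is the square root of the mean squared difference between actual and predicted values, so it is expressed in the label's units (USDT here) and weights large errors more heavily than small ones. A minimal stdlib sketch of the same metric on hypothetical values:

```python
import math


def rmse(actual, predicted):
    """Root mean squared error, the quantity RegressionEvaluator(metricName="rmse") reports."""
    if not actual or len(actual) != len(predicted):
        raise ValueError("need two non-empty sequences of equal length")
    return math.sqrt(sum((a - p) ** 2 for a, p in zip(actual, predicted)) / len(actual))


actual = [100.0, 102.0, 98.0]     # hypothetical prices
predicted = [101.0, 100.0, 98.0]  # hypothetical predictions
print(rmse(actual, predicted))    # sqrt((1 + 4 + 0) / 3)
```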

### Logistic Regression Model Training
Imports the Logistic Regression class from PySpark ML, initializes it to predict the binary `price_direction` label (1 = price went up, 0 = flat or down), and trains the model on the training data.

In [None]:
# Logistic Regression (for price direction)
from pyspark.ml.classification import LogisticRegression
log_reg = LogisticRegression(featuresCol="features", labelCol="price_direction")
log_reg_model = log_reg.fit(train_data)

### Logistic Regression Evaluation
Generates predictions on the test data using the trained Logistic Regression model, evaluates them using Area Under the ROC Curve (ROC AUC), and prints the result to assess classification performance.

In [None]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator
log_reg_predictions = log_reg_model.transform(test_data)
evaluator = BinaryClassificationEvaluator(labelCol="price_direction", rawPredictionCol="rawPrediction", metricName="areaUnderROC")
roc_auc = evaluator.evaluate(log_reg_predictions)
print(f"Logistic Regression ROC AUC: {roc_auc}")
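ROC AUC can be read as the probability that a randomly chosen "up" period receives a higher score than a randomly chosen "not up" period: 0.5 is chance level and 1.0 is a perfect ranking. Spark's `BinaryClassificationEvaluator` derives it from the ROC curve over score thresholds; the pairwise-counting sketch below computes the same quantity in principle, on hypothetical labels and scores, counting ties as half a win.

```python
def roc_auc(labels, scores):
    """AUC as P(score of a random positive > score of a random negative), ties = 0.5.

    O(P * N) pairwise comparison, fine for a small illustration.
    """
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


labels = [1, 0, 1, 0, 1]
scores = [0.9, 0.3, 0.6, 0.7, 0.2]  # hypothetical P(price goes up)
print(roc_auc(labels, scores))      # 3 of 6 positive/negative pairs ranked correctly
```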

### Display Sample Predictions
Selects and displays five sample rows from both Linear Regression and Logistic Regression predictions, showing timestamps, actual values, and predicted values for visual inspection.

In [None]:
lr_predictions.select("timestamp", label_col, "prediction").show(5)
log_reg_predictions.select("timestamp", "price_direction", "prediction").show(5)

## **Step 4 - Multi Model Training/Evaluation With Feature Engineering/Selection and Imbalance with SMOTE**

### Imports and Setup
Imports libraries for API requests (`requests`), data manipulation (`pandas`), visualization (`seaborn`, `matplotlib`), and PySpark ML functionalities (`SparkSession`, `functions`, `feature`, `regression`, `classification`, `evaluation`). Initializes a Spark session for data processing.

In [None]:
import requests
import pandas as pd
import time
import seaborn as sns
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lag
from pyspark.sql.window import Window
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.regression import RandomForestRegressor, LinearRegression
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import RegressionEvaluator, BinaryClassificationEvaluator, MulticlassClassificationEvaluator

# Spark setup
spark = SparkSession.builder.appName("BTCPricePrediction").getOrCreate()
df = spark.read.csv("hdfs://localhost:9000/datalake/raw/binance_crypto_data.csv", header=True, inferSchema=True)  # written to HDFS in Step 1

### Feature Engineering and Label Creation
Defines feature and label columns, creates a previous price column using a window function, and computes a binary price direction label (1 for up, 0 for flat or down) based on price changes, dropping rows with missing values.

In [None]:
# Features and labels
feature_cols = [f"{symbol}_price" for symbol in SYMBOLS[1:]] + [f"{symbol}_funding_rate" for symbol in SYMBOLS]
label_col = "BTCUSDT_price"
w = Window.orderBy("timestamp")
df = df.withColumn("prev_price", lag(col(label_col)).over(w))
df = df.withColumn("price_direction", (col(label_col) > col("prev_price")).cast("int")).dropna()

### Correlation Heatmap Visualization
Converts the PySpark DataFrame to Pandas, computes a correlation matrix for features and the label, and visualizes it as a heatmap using Seaborn to identify relationships between variables, saving the plot as a PNG file.

In [None]:
# Correlation Heatmap
pandas_df = df.select(feature_cols + [label_col]).toPandas()
corr_matrix = pandas_df.corr()
plt.figure(figsize=(12, 8))
sns.heatmap(corr_matrix, annot=True, cmap="coolwarm", fmt=".2f")
plt.title("Correlation Matrix")
plt.savefig("correlation_heatmap.png")
plt.show()
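`pandas_df.corr()` uses the Pearson coefficient by default: the covariance of two columns divided by the product of their standard deviations. Feature pairs with coefficients near ±1 are nearly collinear, which tends to make linear-regression coefficients unstable, and price levels that trend together over a year can correlate strongly even without a short-term relationship. A stdlib sketch of the coefficient on hypothetical series:

```python
import math


def pearson(x, y):
    """Pearson correlation coefficient of two equal-length numeric sequences."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


eth = [1.0, 2.0, 3.0, 4.0]  # hypothetical series
btc = [2.0, 4.0, 6.0, 8.0]
print(pearson(eth, btc))    # perfectly linear relationship, approximately 1.0
```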

### Feature Selection with Random Forest
Assembles all features into a vector, trains a Random Forest Regressor to obtain feature importances, and keeps the 8 most important features for modeling, printing the selection. Note that the forest is fitted on the full dataset, before the train/test split.

In [None]:
# Feature Selection with Random Forest
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
data = assembler.transform(df)
rf = RandomForestRegressor(featuresCol="features", labelCol=label_col, numTrees=20, seed=42)
rf_model = rf.fit(data)
importances = rf_model.featureImportances
feature_importance = sorted(zip(feature_cols, importances), key=lambda x: x[1], reverse=True)
selected_cols = [f[0] for f in feature_importance[:8]]
print(f"Selected Features: {selected_cols}")

### Feature Assembly and Scaling
Assembles the selected features into a vector and scales them using StandardScaler to standardize the data (mean=0, std=1) for better model performance.

In [None]:
# Assemble selected features
assembler_selected = VectorAssembler(inputCols=selected_cols, outputCol="raw_features")
data_selected = assembler_selected.transform(df)

# Scale features
scaler = StandardScaler(inputCol="raw_features", outputCol="features", withStd=True, withMean=True)
scaler_model = scaler.fit(data_selected)
data_scaled = scaler_model.transform(data_selected)
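With `withMean=True, withStd=True`, `StandardScaler` maps each feature x to (x - mean) / std using statistics learned when `fit` is called; Spark's documentation states that the standard deviation is the corrected sample standard deviation. At inference time, new data should be transformed with the statistics fitted on the training data rather than refitted on each incoming batch. The stdlib sketch below shows both steps on hypothetical data; `fit_scaler` and `transform` are illustrative names.

```python
import statistics


def fit_scaler(rows):
    """Per-column mean and sample standard deviation of a list of feature rows."""
    cols = list(zip(*rows))
    return [statistics.mean(c) for c in cols], [statistics.stdev(c) for c in cols]


def transform(rows, means, stds):
    """Apply (x - mean) / std column-wise using previously fitted statistics."""
    return [[(x - m) / s for x, m, s in zip(row, means, stds)] for row in rows]


train = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]  # hypothetical training features
means, stds = fit_scaler(train)
print(transform(train, means, stds))      # each column becomes -1.0, 0.0, 1.0

new_batch = [[4.0, 40.0]]                 # hypothetical incoming record
print(transform(new_batch, means, stds))  # scaled with the training statistics: [[2.0, 2.0]]
```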

### Class Imbalance Check and SMOTE Oversampling
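Checks the distribution of the `price_direction` classes (0 and 1) and, if the majority-to-minority ratio exceeds 1.5, oversamples the minority class by sampling it with replacement. This is simple random oversampling (duplicating rows) rather than true SMOTE, which creates synthetic samples by interpolating between neighbouring minority examples.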

In [None]:
# Check class balance and, if needed, randomly oversample the minority class
counts = {row["price_direction"]: row["count"]
          for row in data_scaled.groupBy("price_direction").count().collect()}
print(f"Class counts: {counts}")
majority_label = max(counts, key=counts.get)  # do not assume which class is larger
minority_label = min(counts, key=counts.get)
ratio = counts[majority_label] / counts[minority_label]
if ratio > 1.5:
    majority = data_scaled.filter(col("price_direction") == majority_label)
    minority = data_scaled.filter(col("price_direction") == minority_label)
    # Sampling with replacement at fraction=ratio yields about as many minority rows as
    # majority rows (duplicates, i.e. random oversampling rather than synthetic SMOTE points)
    oversampled_minority = minority.sample(withReplacement=True, fraction=ratio, seed=42)
    balanced_data = majority.union(oversampled_minority)
    print("Applied random oversampling to the minority class.")
else:
    balanced_data = data_scaled
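SMOTE proper does not duplicate rows: for a minority sample it picks one of its k nearest minority neighbours and creates a new point a random fraction of the way between them. A minimal stdlib sketch follows, assuming small in-memory feature lists; for real use, the `imbalanced-learn` package provides `imblearn.over_sampling.SMOTE`. In either case, oversampling is usually applied to the training split only, so that synthetic or duplicated points do not leak into the test set.

```python
import random


def smote(minority, n_new, k=3, seed=42):
    """Create n_new synthetic minority samples by SMOTE-style interpolation.

    minority: list of feature lists (at least two samples). For each new point, pick a
    random minority sample, pick one of its k nearest minority neighbours (squared
    Euclidean distance), and move a random fraction of the way towards that neighbour.
    """
    if len(minority) < 2:
        raise ValueError("need at least two minority samples")
    rng = random.Random(seed)

    def dist2(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))

    synthetic = []
    for _ in range(n_new):
        i = rng.randrange(len(minority))
        base = minority[i]
        others = [p for j, p in enumerate(minority) if j != i]
        neighbours = sorted(others, key=lambda p: dist2(base, p))[:k]
        neighbour = rng.choice(neighbours)
        gap = rng.random()  # uniform in [0, 1)
        synthetic.append([b + gap * (n - b) for b, n in zip(base, neighbour)])
    return synthetic


minority = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]  # hypothetical scaled minority features
new_points = smote(minority, n_new=4, k=2)
print(new_points)
```

Because every synthetic point lies on a segment between two minority samples, these points all stay on the line y = x spanned by the hypothetical inputs.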

### Data Splitting
Splits the balanced dataset into training (80%) and testing (20%) sets with a fixed seed for reproducibility. Because oversampling happens before the split, copies of the same minority row can end up in both sets, which may inflate the test metrics.

In [None]:
# Split data
train_data, test_data = balanced_data.randomSplit([0.8, 0.2], seed=42)

### Linear Regression Training and Evaluation
Trains a Linear Regression model to predict `BTCUSDT_price`, generates predictions on the test set, and evaluates them using RMSE, printing the result to assess performance.

In [None]:
# Linear Regression
lr = LinearRegression(featuresCol="features", labelCol=label_col)
lr_model = lr.fit(train_data)
lr_predictions = lr_model.transform(test_data)
evaluator = RegressionEvaluator(labelCol=label_col, predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(lr_predictions)
print(f"Linear Regression RMSE: {rmse}")

### Logistic Regression Training and Evaluation
Trains a Logistic Regression model to predict `price_direction`, generates predictions on the test set, and evaluates them using ROC AUC, printing the result to assess classification performance.

In [None]:
# Logistic Regression
log_reg = LogisticRegression(featuresCol="features", labelCol="price_direction")
log_reg_model = log_reg.fit(train_data)
log_reg_predictions = log_reg_model.transform(test_data)
roc_evaluator = BinaryClassificationEvaluator(labelCol="price_direction", metricName="areaUnderROC")
roc_auc = roc_evaluator.evaluate(log_reg_predictions)
print(f"Logistic Regression ROC AUC: {roc_auc}")

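### Save Models
Persists the trained Linear Regression and Logistic Regression models, together with the fitted `StandardScaler` model, under `./models` so they can be reloaded for inference.

In [None]:
lr_model.write().overwrite().save("./models/lr_model")
log_reg_model.write().overwrite().save("./models/log_reg_model")
scaler_model.write().overwrite().save("./models/scaler_model")  # training-set mean/std for inference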

### Confusion Matrix Visualization
Computes a confusion matrix from Logistic Regression predictions, converts it to a Pandas DataFrame, and visualizes it as a heatmap using Seaborn to show true vs. predicted classifications, saving the plot as a PNG file.

In [None]:
# Confusion Matrix
conf_matrix = log_reg_predictions.groupBy("price_direction", "prediction").count().toPandas()
conf_matrix_pivot = conf_matrix.pivot(index="price_direction", columns="prediction", values="count").fillna(0)
plt.figure(figsize=(6, 5))
sns.heatmap(conf_matrix_pivot, annot=True, fmt="g", cmap="Blues")
plt.title("Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.savefig("confusion_matrix.png")
plt.show()
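The pivoted counts can be summarized further. For the "up" class (1), precision is TP / (TP + FP) and recall is TP / (TP + FN). A stdlib sketch on hypothetical labels follows; in Spark, `MulticlassClassificationEvaluator` (imported at the top of this step) can also report metrics such as `"accuracy"` or `"f1"` directly.

```python
from collections import Counter


def confusion_counts(actual, predicted):
    """Count (actual, predicted) pairs, like groupBy("price_direction", "prediction").count()."""
    return Counter(zip(actual, predicted))


def precision_recall(counts, positive=1):
    """Precision and recall for the positive class from a pair-count Counter."""
    tp = counts[(positive, positive)]
    fp = sum(c for (a, p), c in counts.items() if p == positive and a != positive)
    fn = sum(c for (a, p), c in counts.items() if a == positive and p != positive)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


actual = [1, 0, 1, 1, 0, 0]     # hypothetical price_direction
predicted = [1, 0, 0, 1, 1, 0]  # hypothetical model predictions
counts = confusion_counts(actual, predicted)
print(dict(counts))
print(precision_recall(counts))
```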

### Visualize Linear Regression Predictions
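Converts the Linear Regression predictions to a Pandas DataFrame, sorts them by timestamp, and plots actual vs. predicted BTC prices over time using Matplotlib. Because the test set comes from a random split, the points are spread across the whole period rather than forming one contiguous forecast horizon.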

In [None]:
import matplotlib.pyplot as plt

# Convert the prediction DataFrame to Pandas
lr_predictions_df = lr_predictions.select("timestamp", "BTCUSDT_price", "prediction").toPandas()

# Rename for clarity
lr_predictions_df.rename(columns={"prediction": "predicted_price"}, inplace=True)

# Sort by timestamp
lr_predictions_df = lr_predictions_df.sort_values(by="timestamp")

# Plotting actual vs predicted BTC prices
plt.figure(figsize=(14, 7))
plt.plot(lr_predictions_df["timestamp"], lr_predictions_df["BTCUSDT_price"], label="Actual BTC Price", marker='o')
plt.plot(lr_predictions_df["timestamp"], lr_predictions_df["predicted_price"], label="Predicted BTC Price", marker='x')
plt.xlabel("Timestamp")
plt.ylabel("BTC Price (USDT)")
plt.title("Actual vs Predicted BTC Prices Over Time")
plt.legend()
plt.grid(True)
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()


## **Step 5 - Simulated Streaming Data Forecasting with MicroBatching**


### Bitcoin Price Prediction Streaming Pipeline
This script implements a real-time Bitcoin price prediction pipeline with PySpark Structured Streaming. It fetches current prices and funding rates for the six symbols (BTCUSDT, ETHUSDT, BNBUSDT, XRPUSDT, ADAUSDT, SOLUSDT) from the Binance APIs.

It initializes a Spark session with a reduced number of shuffle partitions and streaming-related settings, defines a schema for the incoming records, and simulates a streaming source by writing one JSON snapshot every 5 seconds to a directory.

The stream is then processed with a 2-minute watermark and 30-second sliding windows (10-second slide). Each micro-batch of window averages is assembled into feature vectors and scored with the pre-trained Linear and Logistic Regression models to predict the price and its direction. When the stream stops, all queries are stopped and resources are cleaned up.
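`window(col("timestamp"), "30 seconds", "10 seconds")` assigns each event to every 30-second window whose start is a multiple of 10 seconds and that contains the event, so each record contributes to three overlapping windows. The watermark decides when a window is final: in append mode, a window is emitted only after the watermark (maximum event time seen minus 2 minutes) has moved past the window's end. The stdlib sketch below models both rules on timestamps in seconds; it is a simplification (Spark advances the watermark only between micro-batches, and the exact boundary comparison is not modelled precisely), and the function names are illustrative.

```python
def sliding_windows(ts, size=30, slide=10):
    """Start times (seconds) of the [start, start + size) windows containing ts.

    Window starts are multiples of `slide`, as with window(col, "30 seconds",
    "10 seconds") and no startTime offset.
    """
    last_start = ts - ts % slide  # latest window start at or before ts
    return [s for s in range(last_start, ts - size, -slide) if s <= ts < s + size]


def window_is_final(window_start, max_event_time, size=30, delay=120):
    """Simplified append-mode rule: emit once the watermark reaches the window end."""
    watermark = max_event_time - delay
    return watermark >= window_start + size


print(sliding_windows(45))       # [40, 30, 20]
print(window_is_final(20, 160))  # False: watermark 40 < window end 50
print(window_is_final(20, 175))  # True: watermark 55 >= window end 50
```

With these settings, the append-mode `windowed_query` can print a window only once event time has advanced roughly 150 seconds (the 2-minute delay plus the 30-second window) past the window's start, whereas the `update`-mode prediction query receives the current window averages in every micro-batch.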

In [None]:
import requests
import json
import time
import os
import threading
import shutil
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, window, max as max_, expr, avg, lag
from pyspark.sql.window import Window
from pyspark.sql.types import StructType, StructField, FloatType, StringType
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.regression import LinearRegressionModel
from pyspark.ml.classification import LogisticRegressionModel

# Constants and setup
BASE_URL = "https://api.binance.com"
FUTURES_URL = "https://fapi.binance.com"
SYMBOLS = ["BTCUSDT", "ETHUSDT", "BNBUSDT", "XRPUSDT", "ADAUSDT", "SOLUSDT"]

spark = SparkSession.builder.appName("BTCPricePredictionStreaming") \
    .config("spark.sql.shuffle.partitions", "10") \
    .config("spark.streaming.stopGracefullyOnShutdown", "true") \
    .config("spark.sql.streaming.statefulOperator.allowMultiple", "true") \
    .getOrCreate()
spark.sparkContext.setLogLevel("ERROR")

# Define schema with proper types
schema = StructType([
    StructField("timestamp", StringType(), True)
] + [
    StructField(f"{s}_price", FloatType(), True) for s in SYMBOLS
] + [
    StructField(f"{s}_funding_rate", FloatType(), True) for s in SYMBOLS
])

def fetch_realtime_data():
    data = {"timestamp": time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime())}
    for s in SYMBOLS:
        try:
            price_url = f"{BASE_URL}/api/v3/ticker/price?symbol={s}"
            funding_url = f"{FUTURES_URL}/fapi/v1/premiumIndex?symbol={s}"
            price_resp = requests.get(price_url, timeout=5)
            funding_resp = requests.get(funding_url, timeout=5)
            price_resp.raise_for_status()
            funding_resp.raise_for_status()
            data[f"{s}_price"] = float(price_resp.json()["price"])
            data[f"{s}_funding_rate"] = float(funding_resp.json()["lastFundingRate"])
        except requests.RequestException as e:
            print(f"Error fetching {s}: {e}")
            # Emit nulls rather than 0.0 so a failed request does not drag down the window averages
            data[f"{s}_price"] = None
            data[f"{s}_funding_rate"] = None
    print(f"Generated data: {data}")
    return json.dumps(data)

# Simulate streaming source
stream_dir = "streaming_input"
if os.path.exists(stream_dir):
    shutil.rmtree(stream_dir)  # clean the directory before starting
os.makedirs(stream_dir, exist_ok=True)
stop_event = threading.Event()

# Write initial data
base_time = int(time.time())
for i in range(3):
    filename = f"{stream_dir}/init_data_{base_time - 20 + i*10}.json"
    with open(filename, "w") as f:
        f.write(fetch_realtime_data())
    print(f"Wrote initial {filename}")
    time.sleep(1)

def write_stream_data():
    while not stop_event.is_set():
        filename = f"{stream_dir}/data_{int(time.time())}.json"
        with open(filename, "w") as f:
            f.write(fetch_realtime_data())
        time.sleep(5)

stream_thread = threading.Thread(target=write_stream_data, daemon=True)
stream_thread.start()

try:
    # Read stream and convert timestamp string to timestamp
    streaming_df = spark.readStream.schema(schema).json(stream_dir)
    streaming_df = streaming_df.withColumn("timestamp", expr("cast(timestamp as timestamp)"))
    
    # Debug raw streaming data
    raw_query = streaming_df.writeStream \
        .outputMode("append") \
        .format("console") \
        .option("truncate", "false") \
        .trigger(processingTime="5 seconds") \
        .start()
    
    # Apply watermark
    watermarked_df = streaming_df.withWatermark("timestamp", "2 minutes")
    
    # Use a sliding window: window duration = 30 seconds, slide duration = 10 seconds
    windowed_df = watermarked_df.groupBy(window(col("timestamp"), "30 seconds", "10 seconds").alias("time_window")) \
                              .agg(avg(col("BTCUSDT_price")).alias("BTCUSDT_price"),
                                   *[avg(col(f"{s}_price")).alias(f"{s}_price") for s in SYMBOLS[1:]],
                                   *[avg(col(f"{s}_funding_rate")).alias(f"{s}_funding_rate") for s in SYMBOLS])
    
    # Debug windowed data
    windowed_query = windowed_df.writeStream \
        .outputMode("append") \
        .format("console") \
        .option("truncate", "false") \
        .trigger(processingTime="5 seconds") \
        .start()
    
    # Prepare data for model
    windowed_df_with_ts = windowed_df.select(
        col("time_window.start").alias("window_start"),
        col("time_window.end").alias("window_end"),
        col("BTCUSDT_price"),
        *[col(f"{s}_price") for s in SYMBOLS[1:]],
        *[col(f"{s}_funding_rate") for s in SYMBOLS]
    )
    
    model_dir = "/home/osbdet/notebooks/mda2/group_project/models"
    
    def process_batch(batch_df, batch_id):
        if batch_df.count() > 0:
            try:
                print(f"Processing batch {batch_id} with {batch_df.count()} records")
                
                # Hard-coded feature list matching training
                selected_cols = [
                    "ETHUSDT_price",
                    "BNBUSDT_price",
                    "XRPUSDT_price",
                    "ADAUSDT_price",
                    "SOLUSDT_price",
                    "BTCUSDT_funding_rate",
                    "ETHUSDT_funding_rate",
                    "BNBUSDT_funding_rate"
                ]
                print(f"Using features: {selected_cols}")
                
                # Assemble features
                assembler = VectorAssembler(inputCols=selected_cols, outputCol="raw_features")
                assembled_df = assembler.transform(batch_df)
                print("Assembled raw features:")
                assembled_df.select("raw_features").show(truncate=False)
                
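                # Prefer the scaler fitted on the training data: the models expect features
                # standardized with the training mean/std, not with per-batch statistics
                scaler_path = f"{model_dir}/scaler_model"
                if os.path.exists(scaler_path):
                    from pyspark.ml.feature import StandardScalerModel
                    scaled_df = StandardScalerModel.load(scaler_path).transform(assembled_df)
                elif assembled_df.count() > 1:
                    # Fallback: fit a scaler on this batch only
                    scaler = StandardScaler(inputCol="raw_features", outputCol="features",
                                            withStd=True, withMean=True)
                    scaled_df = scaler.fit(assembled_df).transform(assembled_df)
                else:
                    # A single record has no spread to standardize with; pass raw features through
                    print("Only one record in batch; using raw_features directly as features.")
                    scaled_df = assembled_df.withColumn("features", col("raw_features"))
                
                print("Scaled features:")
                scaled_df.select("features").show(truncate=False)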

                # Define window for price direction calculation
                w = Window.orderBy("window_start")
                
                # Add prev_price and compute price_direction
                scaled_df = scaled_df.withColumn("prev_price", lag(col("BTCUSDT_price")).over(w)) \
                                    .withColumn("price_direction", 
                                                expr("CASE WHEN BTCUSDT_price > prev_price THEN 1 ELSE 0 END")) \
                                    .dropna()  # Drop rows where prev_price is null
                
                if os.path.exists(f"{model_dir}/lr_model") and os.path.exists(f"{model_dir}/log_reg_model"):
                    # Load Linear Regression model
                    lr_model = LinearRegressionModel.load(f"{model_dir}/lr_model")
                    lr_predicted_df = lr_model.transform(scaled_df) \
                                            .withColumnRenamed("prediction", "predicted_price")  # Rename LR prediction
                    
                    # Load Logistic Regression model
                    log_reg_model = LogisticRegressionModel.load(f"{model_dir}/log_reg_model")
                    final_predicted_df = log_reg_model.transform(lr_predicted_df) \
                        .withColumnRenamed("BTCUSDT_price", "actual_price") \
                        .withColumnRenamed("prev_price", "previous_price")
                    
                    # Show results
                    print(f"Batch {batch_id} Results:")
                    final_predicted_df.select(
                        "window_start",
                        "actual_price",
                        "predicted_price",  # LR prediction
                        "price_direction",
                        col("prediction").alias("predicted_direction")  # LogReg prediction
                    ).show(truncate=False)
                else:
                    missing_models = []
                    if not os.path.exists(f"{model_dir}/lr_model"):
                        missing_models.append("lr_model")
                    if not os.path.exists(f"{model_dir}/log_reg_model"):
                        missing_models.append("log_reg_model")
                    print(f"Models not found: {missing_models}")
                    scaled_df.select("window_start", "BTCUSDT_price").show(truncate=False)
                
            except Exception as e:
                print(f"Error processing batch {batch_id}: {str(e)}")
                import traceback
                traceback.print_exc()
        else:
            print(f"Empty batch {batch_id}")
    
    # Process streaming data with foreachBatch
    prediction_query = windowed_df_with_ts.writeStream \
        .foreachBatch(process_batch) \
        .outputMode("update") \
        .trigger(processingTime="10 seconds") \
        .start()
    
    spark.streams.awaitAnyTermination()
    
except Exception as e:
    print(f"Error in main processing: {e}")
    import traceback
    traceback.print_exc()
    
finally:
    print("Stopping streams and cleaning up...")
    stop_event.set()
    
    for query_name in ['raw_query', 'windowed_query', 'prediction_query']:
        if query_name in locals() and locals()[query_name] is not None:
            try:
                locals()[query_name].stop()
                print(f"Stopped {query_name}")
            except Exception as e:
                print(f"Error stopping {query_name}: {e}")
    
    time.sleep(2)
    if os.path.exists(stream_dir):
        shutil.rmtree(stream_dir)
    
    spark.stop()
    print("Streaming stopped and resources cleaned up.")
