In [16]:
!pip install pyspark

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [17]:
from pyspark.sql import SparkSession

In [18]:
spark_application_name = "Spark_Application_Name"

In [19]:
spark = (SparkSession.builder.appName(spark_application_name).getOrCreate())

In [20]:
from pyspark.sql.functions import percent_rank
from pyspark.sql import Window

filePath = "stocks-final.parquet"
stocksDF = spark.read.parquet(filePath)

stocksDF = stocksDF.withColumn("rank", percent_rank().over(Window.partitionBy().orderBy("Date")))#.drop("company_name")
trainDF = stocksDF.where("rank <= .8").drop("rank")
testDF = stocksDF.where("rank > .8").drop("rank")

## Vector Assembler

In [21]:
from pyspark.ml.feature import VectorAssembler

numericCols = []
for (field, dataType) in trainDF.dtypes:
    if (dataType == "double") & (field != "Next"):
        numericCols.append(field)

vecAssembler = VectorAssembler(inputCols=numericCols, outputCol="features")

## Random Forest

In [22]:
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml import Pipeline

rf = RandomForestRegressor(labelCol="Next", seed=42)
pipeline = Pipeline(stages = [vecAssembler, rf])

## Grid Search

In [23]:
from pyspark.ml.tuning import ParamGridBuilder

paramGrid = (ParamGridBuilder()
            .addGrid(rf.maxDepth, [2, 4, 6])
            .addGrid(rf.numTrees, [10, 100])
            .build())

## Cross Validation

In [24]:
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import CrossValidator

regEvaluator = RegressionEvaluator(labelCol="Next",
                                predictionCol="prediction", 
                                metricName="rmse")

cv = CrossValidator(estimator=pipeline, 
                    evaluator=regEvaluator,
                    estimatorParamMaps=paramGrid, 
                    numFolds=3, 
                    seed=42)

In [25]:
cvModel = cv.setParallelism(4).fit(trainDF)

In [26]:
cv = CrossValidator(estimator=rf, 
                    evaluator=regEvaluator,
                    estimatorParamMaps=paramGrid, 
                    numFolds=3, 
                    parallelism=4, 
                    seed=42)

pipeline = Pipeline(stages=[vecAssembler, cv])

pipelineModel = pipeline.fit(trainDF)

In [27]:
list(zip(cvModel.getEstimatorParamMaps(), cvModel.avgMetrics))

[({Param(parent='RandomForestRegressor_d978ea002b8c', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. Must be in range [0, 30].'): 2,
   Param(parent='RandomForestRegressor_d978ea002b8c', name='numTrees', doc='Number of trees to train (>= 1).'): 10},
  47.33141238232539),
 ({Param(parent='RandomForestRegressor_d978ea002b8c', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. Must be in range [0, 30].'): 2,
   Param(parent='RandomForestRegressor_d978ea002b8c', name='numTrees', doc='Number of trees to train (>= 1).'): 100},
  43.28900249925332),
 ({Param(parent='RandomForestRegressor_d978ea002b8c', name='maxDepth', doc='Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. Must be in range [0, 30].'): 4,
   Param(parent='RandomForestRegressor_d978ea002b8c', name

In [28]:
predictionDF = pipelineModel.transform(testDF)

regEvaluator = RegressionEvaluator(predictionCol="prediction", labelCol="Next", metricName="rmse")

rmse = regEvaluator.evaluate(predictionDF)
r2 = regEvaluator.setMetricName("r2").evaluate(predictionDF)
print(f"RMSE is {rmse}")
print(f"R2 is {r2}")

RMSE is 112.47297881106735
R2 is 0.5723869215376514


In [29]:
predictionDF = pipelineModel.transform(testDF)

predictionDF.select("features", "Next", "prediction").show(10)

+--------------------+------------------+------------------+
|            features|              Next|        prediction|
+--------------------+------------------+------------------+
|[1436.96997070312...|1438.1400146484375|1381.1177120022178|
|[1438.14001464843...| 1415.699951171875|1381.1177120022178|
|[1415.69995117187...|1371.7039794921875|1381.1177120022178|
|[1371.70397949218...|1341.1400146484375| 1343.178724487436|
|[1341.14001464843...|1390.8699951171875|1297.3409612065316|
|[1390.86999511718...|1410.1500244140625|1373.8490518417661|
|[1410.15002441406...|1388.0899658203125| 1343.178724487436|
|[1388.08996582031...|1358.9100341796875|1381.1177120022178|
|[1358.91003417968...| 1306.219970703125| 1343.178724487436|
|[1306.21997070312...| 1254.760009765625|1297.3409612065316|
+--------------------+------------------+------------------+
only showing top 10 rows

