In [1]:
import pyspark
from pyspark.context import SparkContext
from pyspark.sql import SparkSession, SQLContext, Window
from pyspark.sql import functions as f
from pyspark.sql.functions import col, lit, signum

# Create the Spark session for this notebook, or reuse one that is already
# running. "local[1]" runs Spark in-process with a single worker thread.
spark = (
    SparkSession.builder
    .master("local[1]")
    .getOrCreate()
)


In [13]:
#Dollar bars
# Resample the price history into "dollar bars": a row is kept only when the
# running traded dollar value (Adj Close * Volume) passes another multiple of
# DOLLAR_BAR_SIZE on that row.
DOLLAR_BAR_SIZE=5E9
NUMERIC_COLS=["Open","High","Low","Adj Close","Volume"]

df=spark.read.csv("AAPL.csv", header='true')

#Convert to type double
# No schema is passed to read.csv, so the columns are loaded as strings. Cast
# before doing any arithmetic instead of relying on implicit string coercion.
for numeric_col in NUMERIC_COLS:
    df=df.withColumn(numeric_col,col(numeric_col).cast("double"))

df=df.withColumn('DolExch',col('Adj Close')*col('Volume'))

# Running total taken in explicit date order. Without an orderBy column the
# cumulative sum depends on whatever physical row order Spark happens to use.
# Date is a string column; the values shown by this notebook are yyyy-MM-dd,
# which sorts chronologically as text.
# unboundedPreceding/currentRow also removes the extra df.count() job that was
# previously needed just to size the frame.
cumulative=Window.orderBy("Date").rowsBetween(Window.unboundedPreceding,Window.currentRow)
df=df.withColumn('CumDolExch',f.sum('DolExch').over(cumulative))
df=df.withColumn('DolBars', col('CumDolExch')%DOLLAR_BAR_SIZE)
# The remainder is smaller than the row's own contribution only when this row
# pushed the running total past a multiple of DOLLAR_BAR_SIZE.
df=df.filter((col('DolBars')-col('DolExch'))<0)


#Add target vector(2 for +change, 1 for no change 0 for -change)
# lag(..., -1) looks one row ahead. It runs after the filter above, so
# "next" means the next retained bar, not the next row of the CSV.
by_date=Window.orderBy("Date")
df=df.withColumn("next_value",f.lag(col("Adj Close"),-1).over(by_date))
df=df.withColumn("Target", 1+signum(col('next_value')-col('Adj Close')))

#Drop extra columns, null rows
cols_to_drop=['Close','DolBars','DolExch','CumDolExch','next_value']
df=df.drop(*cols_to_drop)
# The last bar has no successor, so its Target is null and it is removed here
# together with any row that has a missing value.
df=df.na.drop()
df.show(10)
df.printSchema()

+----------+--------+--------+--------+---------+-----------+------+
|      Date|    Open|    High|     Low|Adj Close|     Volume|Target|
+----------+--------+--------+--------+---------+-----------+------+
|2000-07-17|4.160714|4.200893|4.080357| 3.605784|  6.50006E7|   0.0|
|2000-08-22|3.616071|3.772321|3.598214| 3.196123|  6.92006E7|   2.0|
|2000-09-19|4.267857|4.321429|4.183036| 3.706267|  6.78776E7|   0.0|
|2000-09-29|2.013393|2.071429|  1.8125| 1.592265|1.8554102E9|   0.0|
|2000-10-13|1.446429|1.580357|1.428571| 1.364247| 3.119382E8|   0.0|
|2000-11-20|1.328125|1.392857|1.303571|  1.17101| 1.020166E8|   2.0|
|2001-01-19|1.388393|1.397321|1.334821| 1.205793|  1.94166E8|   2.0|
|2001-03-13|1.348214|1.397321|1.299107| 1.209657| 1.108324E8|   2.0|
|2001-04-19|   1.825|1.839286|1.685714|  1.59041| 4.684176E8|   0.0|
|2001-06-12|1.412143|1.477857|1.411429|  1.25588|  7.59486E7|   0.0|
+----------+--------+--------+--------+---------+-----------+------+
only showing top 10 rows

root
 |-

In [14]:
from pyspark.ml.feature import VectorAssembler
#Create input features
# Every column except the date key, the label and the assembler's own output
# column is used as a model input. Excluding "features" keeps this cell safe to
# re-run: df is overwritten below, so on a second run it already contains the
# output column, which must not be fed back in as an input.
non_feature_cols=('Date','Target','features')
df_cols=[name for name in df.columns if name not in non_feature_cols]
#Move features to a single vector
assembler=VectorAssembler(inputCols=df_cols,outputCol="features")
# drop() does nothing when the column is absent (first run) and removes the
# stale vector on a re-run so transform() can recreate it.
df=assembler.transform(df.drop("features"))


In [15]:
from pyspark.ml.classification import DecisionTreeClassifier
import time

# Random 80/20 hold-out split with a fixed seed.
# NOTE(review): rows are time-ordered bars and Target is built from the next
# bar, so a random split mixes past and future rows across train and test.
# Confirm this is intended rather than a chronological split.
split_weights=[0.8,0.2]
train,test=df.randomSplit(split_weights,seed=1)

print('repartitioning')
train,test=train.repartition(10),test.repartition(10)

# Timed region: building the classifier, fitting it, and setting up the
# prediction frame.
# NOTE(review): transform() is lazy in Spark, so the measured interval
# presumably reflects training only. Confirm.
start_time=time.time()

tree_params={
    'featuresCol':'features',
    'labelCol':'Target',
    'maxDepth':30,
    'minInstancesPerNode':2,
}
dt=DecisionTreeClassifier(**tree_params)

dtModel=dt.fit(train)
predictions=dtModel.transform(test)
end_time=time.time()

delta_time=end_time-start_time

# time.time() is in seconds, so the printed value is in minutes.
print(f'run-time: {round(delta_time/60.0, 2)}')

predictions.printSchema()

repartitioning
run-time: 0.06
root
 |-- Date: string (nullable = true)
 |-- Open: double (nullable = true)
 |-- High: double (nullable = true)
 |-- Low: double (nullable = true)
 |-- Adj Close: double (nullable = true)
 |-- Volume: double (nullable = true)
 |-- Target: double (nullable = true)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)



In [16]:
# Show the true label next to the model's prediction (show() prints 20 rows
# by default).
label_vs_prediction=predictions.select('Target','prediction')
label_vs_prediction.show()

# Accuracy: share of test rows whose predicted class equals the label.
correct_preds=predictions.filter(col('Target')==col('prediction'))
n_correct=correct_preds.count()
n_total=predictions.count()
print(n_correct/n_total)

+------+----------+
|Target|prediction|
+------+----------+
|   0.0|       2.0|
|   0.0|       2.0|
|   2.0|       0.0|
|   2.0|       2.0|
|   2.0|       0.0|
|   2.0|       2.0|
|   2.0|       0.0|
|   2.0|       2.0|
|   2.0|       0.0|
|   0.0|       2.0|
|   0.0|       2.0|
|   0.0|       0.0|
|   0.0|       2.0|
|   2.0|       2.0|
|   2.0|       0.0|
|   0.0|       0.0|
|   0.0|       0.0|
|   2.0|       0.0|
|   2.0|       0.0|
|   0.0|       2.0|
+------+----------+
only showing top 20 rows

0.5076660988074957
