-
Notifications
You must be signed in to change notification settings - Fork 1
/
HousingAnalyzer.scala
147 lines (130 loc) · 5.78 KB
/
HousingAnalyzer.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import java.util.Calendar
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, GBTClassifier, LogisticRegression}
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.{LabeledPoint, VectorAssembler, VectorIndexer}
import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor, LinearRegression, RandomForestRegressor}
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.max
import org.apache.log4j.{Level, Logger}
/**
 * Trains a decision-tree regressor on the housing training set and exports the
 * predicted `SalePrice` for every row of the test set as CSV.
 */
object HousingAnalyzer {

  /**
   * Entry point.
   *
   * Steps performed:
   *  1. Start a local Spark session.
   *  2. Obtain the training and test data from a `DataExtractor` (defined elsewhere
   *     in the project) pointed at `<user.dir>/source-data/`.
   *  3. Assemble the numeric feature columns into a single `features` vector.
   *  4. Fit a `DecisionTreeRegressor` (wrapped in a one-stage `Pipeline`) on the
   *     training data and predict on the test data.
   *  5. Write `Id,SalePrice` to a timestamped directory under
   *     `<user.dir>/housing-predictions/` and print the first 50 predictions.
   *
   * The Spark session is always stopped, even when one of the steps fails.
   *
   * @param args command-line arguments (not used)
   */
  def main(args: Array[String]): Unit = {
    // Scoped to this method: only needed to build the output directory name.
    import java.time.LocalDateTime
    import java.time.format.DateTimeFormatter

    val spark = SparkSession
      .builder
      .master("local[*]")
      .appName("PipelineExample")
      .getOrCreate()

    // try/finally so the session is released even if loading, fitting or writing throws.
    try {
      val workingDir = System.getProperty("user.dir")
      val srcDataDir = workingDir + "/source-data/"
      val dataExtractor = new DataExtractor(srcDataDir, spark)

      val featureColumns = Array(
        "MSSubClass", "LotArea", "OverallQual", "OverallCond", "YearBuilt", "YearRemodAdd",
        "BsmtFinSF1", "BsmtFinSF2", "BsmtUnfSF", "TotalBsmtSF", "1stFlrSF", "2ndFlrSF",
        "LowQualFinSF", "GrLivArea", "BsmtFullBath", "BsmtHalfBath", "FullBath", "HalfBath",
        "BedroomAbvGr", "KitchenAbvGr", "TotRmsAbvGrd", "Fireplaces", "GarageCars", "GarageArea",
        "WoodDeckSF", "OpenPorchSF", "EnclosedPorch", "3SsnPorch", "ScreenPorch", "PoolArea",
        "MoSold", "YrSold"
      )
      val labelColumn = "SalePrice"

      // Packs all feature columns into the single vector column Spark ML estimators expect.
      val assembler = new VectorAssembler()
        .setInputCols(featureColumns)
        .setOutputCol("features")

      val featurizedTrainingData = assembler
        .transform(dataExtractor.trainingData)
        .select("Id", labelColumn, "features")
      // The test data is selected without the label column; it is only used for prediction.
      val featurizedTestData = assembler
        .transform(dataExtractor.testData)
        .select("Id", "features")

      val trainingData = featurizedTrainingData
      val testData = featurizedTestData
      // Alternative: hold out part of the labelled data for evaluation.
      // val Array(trainingData, testData) = featurizedTrainingData.randomSplit(Array(0.8, 0.2))

      // Train a DecisionTree model.
      val dt = new DecisionTreeRegressor()
        .setLabelCol(labelColumn)
        .setFeaturesCol("features")

      // Single-stage pipeline: the regressor is the only stage.
      val pipeline2 = new Pipeline()
        .setStages(Array(dt))

      // Train model.
      val model2 = pipeline2.fit(trainingData)

      // Make predictions.
      val predictions2 = model2.transform(testData)

      // Export predictions. The directory name is a compact timestamp without spaces or
      // colons so that it is a valid path component on every platform.
      val timestamp = LocalDateTime.now().format(DateTimeFormatter.ofPattern("yyyyMMdd-HHmmss"))
      val outputDir = workingDir + "/housing-predictions/" + timestamp
      predictions2
        .withColumnRenamed("prediction", labelColumn)
        .select("Id", labelColumn)
        .coalesce(1) // one partition => one part file inside the output directory
        .write
        .option("header", "true")
        .csv(outputDir)

      // Display example rows.
      predictions2.show(50)

      // Spark writes a directory of part files at the given path, not a single "<path>.csv".
      println("Prediction output exported as CSV part file(s) in directory " + outputDir)
    } finally {
      spark.stop()
    }

    // ---------------------------------------------------------------------------------
    // Disabled experiments, kept for reference.
    // ---------------------------------------------------------------------------------
    // Select (prediction, true label) and compute test error.
    // val evaluator2 = new RegressionEvaluator()
    //   .setLabelCol("label")
    //   .setPredictionCol("prediction")
    //   .setMetricName("rmse")
    // val rmse2 = evaluator2.evaluate(predictions2)
    //
    // val treeModel = model2.stages(0).asInstanceOf[DecisionTreeRegressionModel]
    // println("Learned regression tree model:\n" + treeModel.toDebugString)

    // Train a RandomForest model.
    // val rf = new RandomForestRegressor()
    //   .setLabelCol("SalePrice")
    //   .setFeaturesCol("features")
    //
    // // Chain indexer and forest in a Pipeline.
    // val rfPipeline = new Pipeline()
    //   .setStages(Array(rf))
    //
    // // Train model. This also runs the indexer.
    // val rfFittedModel = rfPipeline.fit(trainingData)
    //
    // // Make predictions.
    // val rfPredictions = rfFittedModel.transform(testData)
    //
    // // Select example rows to display.
    // rfPredictions.select("prediction", "SalePrice", "features").show(5)
    //
    // // Select (prediction, true label) and compute test error.
    // val rfevaluator = new RegressionEvaluator()
    //   .setLabelCol("SalePrice")
    //   .setPredictionCol("prediction")
    //   .setMetricName("rmse")
    // val RFrmse = rfevaluator.evaluate(rfPredictions)
    //
    // val rfModel = rfFittedModel.stages(0).asInstanceOf[RandomForestRegressionModel]
    // println("Learned regression forest model:\n" + rfModel.toDebugString)
    //
    // // Train a GBT model.
    // val gbt = new GBTRegressor()
    //   .setLabelCol("SalePrice")
    //   .setFeaturesCol("features")
    //   .setMaxIter(10)
    //
    // // Chain indexer and GBT in a Pipeline.
    // val pipeline = new Pipeline()
    //   .setStages(Array(gbt))
    //
    // // Train model. This also runs the indexer.
    // val model = pipeline.fit(trainingData)
    //
    // // Make predictions.
    // val predictions = model.transform(testData)
    //
    // // Select example rows to display.
    // predictions.select("prediction", "SalePrice", "features").show(5)
    //
    // // Select (prediction, true label) and compute test error.
    // val evaluator = new RegressionEvaluator()
    //   .setLabelCol("SalePrice")
    //   .setPredictionCol("prediction")
    //   .setMetricName("rmse")
    // val rmse = evaluator.evaluate(predictions)
    // println("Root Mean Squared Error (RMSE) on GBT model test data = " + rmse)
    //
    // val gbtModel = model.stages(0).asInstanceOf[GBTRegressionModel]
    //
    // println("Learned regression GBT model:\n" + gbtModel.toDebugString)
    //
    // println("Decision Tree RMSE = " + rmse2)
    // println("RandomForestRegressor RMSE = " + RFrmse)
    // println("Gradient-boosted Tree RMSE = " + rmse)
  }
}