In [3]:
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.{VectorAssembler, StringIndexer, VectorIndexer, OneHotEncoder}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.Pipeline
import org.apache.spark.mllib.evaluation.MulticlassMetrics

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.{VectorAssembler, StringIndexer, VectorIndexer, OneHotEncoder}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.Pipeline
import org.apache.spark.mllib.evaluation.MulticlassMetrics


In [4]:
val spark = SparkSession.builder().getOrCreate()
val data = spark.read.option("header", "true")
            .option("inferSchema", "true")
            .format("csv")
            .load("../datasets/advertising.csv")
data.show()

+------------------------+---+-----------+--------------------+--------------------+-----------------+----+--------------------+-------------------+-------------+
|Daily Time Spent on Site|Age|Area Income|Daily Internet Usage|       Ad Topic Line|             City|Male|             Country|          Timestamp|Clicked on Ad|
+------------------------+---+-----------+--------------------+--------------------+-----------------+----+--------------------+-------------------+-------------+
|                   68.95| 35|    61833.9|              256.09|Cloned 5thgenerat...|      Wrightburgh|   0|             Tunisia|2016-03-27 00:53:11|            0|
|                   80.23| 31|   68441.85|              193.77|Monitored nationa...|        West Jodi|   1|               Nauru|2016-04-04 01:39:02|            0|
|                   69.47| 26|   59785.94|               236.5|Organic bottom-li...|         Davidton|   0|          San Marino|2016-03-13 20:35:42|            0|
|                   74

spark: org.apache.spark.sql.SparkSession = org.apache.spark.sql.SparkSession@a4b53c4
data: org.apache.spark.sql.DataFrame = [Daily Time Spent on Site: double, Age: int ... 8 more fields]


In [5]:
data.printSchema()

root
 |-- Daily Time Spent on Site: double (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Area Income: double (nullable = true)
 |-- Daily Internet Usage: double (nullable = true)
 |-- Ad Topic Line: string (nullable = true)
 |-- City: string (nullable = true)
 |-- Male: integer (nullable = true)
 |-- Country: string (nullable = true)
 |-- Timestamp: string (nullable = true)
 |-- Clicked on Ad: integer (nullable = true)



In [6]:
data.head(1)

res3: Array[org.apache.spark.sql.Row] = Array([68.95,35,61833.9,256.09,Cloned 5thgeneration orchestration,Wrightburgh,0,Tunisia,2016-03-27 00:53:11,0])


In [9]:
val colnames = data.columns
val firstrow = data.head(1)(0)
println("Example Data Row")
for(ind <- Range(1,colnames.length)){
  println(colnames(ind))
  println(firstrow(ind))
  println("\n")
}

Example Data Row
Age
35


Area Income
61833.9


Daily Internet Usage
256.09


Ad Topic Line
Cloned 5thgeneration orchestration


City
Wrightburgh


Male
0


Country
Tunisia


Timestamp
2016-03-27 00:53:11


Clicked on Ad
0




colnames: Array[String] = Array(Daily Time Spent on Site, Age, Area Income, Daily Internet Usage, Ad Topic Line, City, Male, Country, Timestamp, Clicked on Ad)
firstrow: org.apache.spark.sql.Row = [68.95,35,61833.9,256.09,Cloned 5thgeneration orchestration,Wrightburgh,0,Tunisia,2016-03-27 00:53:11,0]


In [20]:
val timedata = data.withColumn("Hour", hour(data("Timestamp")))

timedata: org.apache.spark.sql.DataFrame = [Daily Time Spent on Site: double, Age: int ... 9 more fields]


In [22]:
val df = timedata.select(data("Clicked on Ad").as("label"),
                         $"Daily Time Spent on Site", $"Age", $"Area Income", $"Daily Internet Usage", $"Hour", $"Male")

df: org.apache.spark.sql.DataFrame = [label: int, Daily Time Spent on Site: double ... 5 more fields]


In [23]:
val assembler = new VectorAssembler()
                .setInputCols(Array("Daily Time Spent on Site", "Age", "Area Income", "Daily Internet Usage", "Hour", "Male"))
                .setOutputCol("features")

assembler: org.apache.spark.ml.feature.VectorAssembler = VectorAssembler: uid=vecAssembler_c9fe68448861, handleInvalid=error, numInputCols=6


In [25]:
val Array(train, test) = df.randomSplit(Array(0.7, 0.3), seed=42)

train: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: int, Daily Time Spent on Site: double ... 5 more fields]
test: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: int, Daily Time Spent on Site: double ... 5 more fields]


Build the model

In [26]:
val lr = new LogisticRegression()

lr: org.apache.spark.ml.classification.LogisticRegression = logreg_ad31a6aa16d2


In [27]:
val pipeline = new Pipeline()
                .setStages(Array(assembler, lr))

pipeline: org.apache.spark.ml.Pipeline = pipeline_dc4a8fe3fcad


In [28]:
val model = pipeline.fit(train)

model: org.apache.spark.ml.PipelineModel = pipeline_dc4a8fe3fcad


In [29]:
val results = model.transform(test)

results: org.apache.spark.sql.DataFrame = [label: int, Daily Time Spent on Site: double ... 9 more fields]


Evaluation

In [30]:
val predictionAndLabels = results.select($"prediction", $"label")
                            .as[(Double, Double)].rdd

predictionAndLabels: org.apache.spark.rdd.RDD[(Double, Double)] = MapPartitionsRDD[103] at rdd at <console>:33


In [31]:
val metrics = new MulticlassMetrics(predictionAndLabels)

metrics: org.apache.spark.mllib.evaluation.MulticlassMetrics = org.apache.spark.mllib.evaluation.MulticlassMetrics@35e298a8


In [32]:
metrics.confusionMatrix

res17: org.apache.spark.mllib.linalg.Matrix =
125.0  1.0
6.0    124.0
