### Calculate a basic cleaned stream

In [46]:
%AddJar -magic https://brunelvis.org/jar/spark-kernel-brunel-all-2.2.jar

Using cached version of spark-kernel-brunel-all-2.2.jar


In [47]:
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().getOrCreate()
import spark.implicits._ 

val raw = (
    spark.read.option("inferSchema", "true")
    .json("../data/winemag-data-130k-v2.json")
)

val rawdf = (
    raw
    .select(
        trim(lower($"variety")) as "variety",
        trim(lower($"description")) as "description")
    .dropDuplicates(Seq("description"))
    .filter($"variety".isNotNull)
    .filter(not($"variety".contains("blend")))
    .filter(not($"variety".contains("red")))
    .filter(not($"variety".contains("white")))
    .select($"variety", regexp_replace($"description", $"variety", lit("")) as "description")
    .select($"variety", regexp_replace($"description", "[^\\p{L}\\p{Nd}[0-9]+]+", " ") as "description")
    .select($"variety", trim(lower($"description")) as "description")
    .cache
)

In [48]:
raw.printSchema()

root
 |-- country: string (nullable = true)
 |-- description: string (nullable = true)
 |-- designation: string (nullable = true)
 |-- points: string (nullable = true)
 |-- price: long (nullable = true)
 |-- province: string (nullable = true)
 |-- region_1: string (nullable = true)
 |-- region_2: string (nullable = true)
 |-- taster_name: string (nullable = true)
 |-- taster_twitter_handle: string (nullable = true)
 |-- title: string (nullable = true)
 |-- variety: string (nullable = true)
 |-- winery: string (nullable = true)



In [49]:
val varietyCounts = (
    raw
    .groupBy($"variety")
    .agg(
        count("variety") as "count",
        stddev("price") as "price_stddev",
        mean("price") as "mean_price",
        min("price") as "min_price",
        max("price") as "max_price"
    )
    .where("count > 100")

)

In [50]:
%%brunel 
    data('varietyCounts') 
    bubble color(count) 
    size(count) 
    sort(count) 
    label(variety, count) 
    tooltip(#all) 
    style('* {font-size: 7pt}') :: width=1000, height=1000

In [51]:
val varietyCounts2 = (
    raw
    .groupBy($"variety", $"country")
    .count()
    .where("count > 1000")
)

In [52]:
%%brunel 
    data('varietyCounts2') 
    chord x(country) y(variety) 
    color(count) 
    size(count) 
    sort(count) 
    label(variety) 
    tooltip(#all) :: width=1000, height=1000

### Select the varieties with > 2000 reviews

In [53]:
val df = (rawdf
 .groupBy($"variety")
 .agg(count("variety") as "count")
 .where("count > 3000")
 .join(rawdf, Seq("variety"))
 .orderBy("variety")
 .select($"variety", $"description", $"count")).cache

In [54]:
import org.apache.spark.ml.feature.StopWordsRemover
import scala.collection.mutable.WrappedArray
import org.apache.spark.ml.feature.Tokenizer

val varietySplits = (
    new Tokenizer()             
    .setInputCol("variety")
    .setOutputCol("variety_splits")
    .transform(df.select("variety").distinct())
    .select("variety_splits")
    .collect()
    .map(_.toSeq.asInstanceOf[WrappedArray[WrappedArray[String]]])
    .flatMap(_.toSeq)
    .flatMap(_.toSeq)
    .toList
)

val tokenizer = (
    new Tokenizer()
    .setInputCol("description")
    .setOutputCol("words")
)

val stopWordsRemover = (
    new StopWordsRemover()
    .setInputCol(tokenizer.getOutputCol)
    .setOutputCol("filteredWords")
)

stopWordsRemover.setStopWords((varietySplits:::stopWordsRemover.getStopWords.toList).toSet.toArray)


stopWords_f6503d767e08

In [55]:
import org.apache.spark.ml.feature.CountVectorizer
import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}

val countVectorizer = (
    new CountVectorizer()
    .setInputCol(stopWordsRemover.getOutputCol)
    .setOutputCol("features")
)

In [56]:
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.feature.Normalizer

val indexer = (
    new StringIndexer()
    .setInputCol("variety")
    .setOutputCol("label")
)

In [57]:
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

val Array(trainingData, testData) = (
    indexer           
    .fit(df)
    .transform(df)
    .randomSplit(Array(0.7, 0.3), 42L)
)

In [58]:
import org.apache.spark.ml.{Pipeline, PipelineModel}

val naiveBayes = new NaiveBayes()

val pipeline = (
    new Pipeline()
    .setStages(
        Array(
            tokenizer, 
            stopWordsRemover, 
            countVectorizer,
            naiveBayes
        )
    )
)
            
val model = pipeline.fit(trainingData)

In [59]:
val predictions = model.transform(testData)

val evaluator = (new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("accuracy"))

val accuracy = evaluator.evaluate(predictions)
println("Test set accuracy = " + accuracy)

Test set accuracy = 0.7747381477398015
