### Build a basic cleaned dataset

In [None]:
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions._

// Reuse the active SparkSession (or create one) and bring the `$"col"` column syntax into scope.
val spark = SparkSession.builder().getOrCreate()
import spark.implicits._

// Cleaned (variety, description) pairs read from the wine review JSON file.
// Both columns are lower-cased and trimmed, reviews are de-duplicated on their description text,
// and rows whose variety is missing or contains "blend", "red" or "white" are dropped.
// Declared as a `val`: the DataFrame is never reassigned.
val rawdf = (
    spark.read.option("inferSchema", "true")
    .json("../data/winemag-data-130k-v2.json")
    .select(
        trim(lower($"variety")) as "variety",
        trim(lower($"description")) as "description")
    .dropDuplicates(Seq("description"))
    .filter($"variety".isNotNull)
    .filter(not($"variety".contains("blend")))
    .filter(not($"variety".contains("red")))
    .filter(not($"variety".contains("white")))
    // Remove the variety name from its own review text.
    // NOTE(review): the variety value is used as a regular expression here, so any regex
    // metacharacter in a variety name is interpreted as pattern syntax, not literally.
    .select($"variety", regexp_replace($"description", $"variety", lit("")) as "description")
    // Collapse every run of characters that are neither Unicode letters nor decimal digits into a
    // single space. The previous pattern nested `[0-9]+` inside the negated class, which Java 8 and
    // Java 9+ evaluate differently (digits were stripped on one and kept on the other) and which
    // also exempted literal '+' characters from removal.
    .select($"variety", regexp_replace($"description", "[^\\p{L}\\p{Nd}]+", " ") as "description")
    // The replacements above can leave leading/trailing spaces, so normalise once more.
    .select($"variety", trim(lower($"description")) as "description")
    .cache()
)

rawdf.show(3)

### Select the varieties with > 3000 reviews

In [None]:
// Restrict the cleaned reviews to varieties that have more than 3000 of them.
val df = {
  // One row per variety with its review count, filtered to the frequent ones.
  val frequentVarieties = rawdf
    .groupBy($"variety")
    .agg(count("variety").as("count"))
    .where("count > 3000")

  // Join back to the reviews so only rows of the frequent varieties remain.
  frequentVarieties
    .join(rawdf, Seq("variety"))
    .orderBy("variety")
    .select($"variety", $"description")
    .cache()
}

df.show(3)

In [None]:
import org.apache.spark.ml.feature.StopWordsRemover
import scala.collection.mutable.WrappedArray
import org.apache.spark.ml.feature.Tokenizer

// Individual tokens that make up the selected variety names, collected to the driver as a
// List[String]. A token appears once for every distinct variety that contains it.
val varietySplits = (
    new Tokenizer()
    .setInputCol("variety")
    .setOutputCol("variety_splits")
    .transform(df.select("variety").distinct())
    .select("variety_splits")
    .collect()
    // Read the single array column through the typed Row accessor instead of casting the row's
    // backing sequence to WrappedArray, which depended on the runtime collection class.
    .flatMap(_.getSeq[String](0))
    .toList
)

// Tokenize each description into a "words" array column, keeping the variety alongside it.
val wordsDf = {
  val descriptionTokenizer = new Tokenizer()
    .setInputCol("description")
    .setOutputCol("words")

  descriptionTokenizer
    .transform(df)
    .select($"variety", $"words")
}

// Stop-word filter whose word list is the remover's default list extended with the variety tokens,
// so words from the variety names themselves are dropped from the descriptions.
val remover = {
  val base = new StopWordsRemover()
    .setInputCol("words")
    .setOutputCol("filteredWords")

  // Go through a Set to drop duplicates between the two lists.
  base.setStopWords((varietySplits ++ base.getStopWords).toSet.toArray)
}

// Apply the filter and expose the filtered tokens under the original "words" column name.
val noStopWordsDf = remover
  .transform(wordsDf)
  .select($"variety", $"filteredWords".as("words"))
noStopWordsDf.show(3)

In [None]:
import org.apache.spark.ml.feature.CountVectorizer

// Fit a vocabulary on the filtered tokens and turn each token list into a count vector
// stored in the "features" column.
val countVectorizer = (
  new CountVectorizer()
    .setInputCol("words")
    .setOutputCol("features")
)
val countVectorizerModel = countVectorizer.fit(noStopWordsDf)
val countVectorizerDF = countVectorizerModel.transform(noStopWordsDf)

countVectorizerDF.show(3, truncate = true)

In [None]:
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.feature.Normalizer

// Maps each variety string to a numeric index in the "varietyIndex" column.
val indexer = new StringIndexer().setInputCol("variety").setOutputCol("varietyIndex")

// Fit the indexer on the vectorized data, then keep the variety, its index (as a double
// "label" column) and the feature vector.
val indexed = {
  val indexerModel = indexer.fit(countVectorizerDF)

  indexerModel
    .transform(countVectorizerDF)
    .select($"variety", $"varietyIndex".cast("double").as("label"), $"features")
}

indexed.show(3)

In [None]:
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// 70/30 train/test split with a fixed seed (42) so the split is repeatable.
val (trainingData, testData) = indexed.randomSplit(Array(0.7, 0.3), 42L) match {
  case Array(train, test) => (train, test)
}

// Train a Naive Bayes classifier with its default settings on the training portion.
val model = {
  val naiveBayes = new NaiveBayes()
  naiveBayes.fit(trainingData)
}

In [None]:
// Score the held-out reviews with the trained model.
val predictions = model.transform(testData)

// Evaluator comparing the "prediction" column against "label" using the accuracy metric.
val evaluator = {
  val accuracyEvaluator = new MulticlassClassificationEvaluator()
  accuracyEvaluator.setLabelCol("label")
  accuracyEvaluator.setPredictionCol("prediction")
  accuracyEvaluator.setMetricName("accuracy")
}

val accuracy = evaluator.evaluate(predictions)
println(s"Test set accuracy = $accuracy")