# Exemple d'apprentissage supervisé en Java (Spark MLlib) - cegep stefoy

### Téléchargement des librairies avec Maven

In [1]:
%%loadFromPOM
<dependencies>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.11</artifactId>
            <version>2.2.0</version>
        </dependency>
</dependencies>

### Désactivation des logs verbeux

In [2]:
import org.apache.log4j.Level;
import org.apache.log4j.Logger;

// Turn off log output for the "org" and "akka" logger hierarchies so that
// the notebook output is not flooded with framework log lines.
for (String loggerName : new String[] {"org", "akka"}) {
    Logger.getLogger(loggerName).setLevel(Level.OFF);
}

### Initialisation de la session Spark

In [3]:
import org.apache.spark.sql.SparkSession;
// Create (or reuse) a Spark session named "SparkIris" with the master set to "local".
SparkSession sparkSession = SparkSession
        .builder()
        .appName("SparkIris")
        .config("spark.master", "local")
        .getOrCreate();

Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties


### Chargement du jeu de données (Iris)

In [4]:
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
// Load Iris.csv, using its first line as the column names.
// No schema is supplied or inferred here, so every column is read as a string.
Dataset<Row> rawData = sparkSession
        .read()
        .option("header", "true")
        .csv("Iris.csv");

### Quelques lignes du jeu de données

In [5]:
// Preview the first 5 rows of the raw data.
rawData.show(5)

+---+-------------+------------+-------------+------------+-----------+
| Id|SepalLengthCm|SepalWidthCm|PetalLengthCm|PetalWidthCm|    Species|
+---+-------------+------------+-------------+------------+-----------+
|  1|          5.1|         3.5|          1.4|         0.2|Iris-setosa|
|  2|          4.9|         3.0|          1.4|         0.2|Iris-setosa|
|  3|          4.7|         3.2|          1.3|         0.2|Iris-setosa|
|  4|          4.6|         3.1|          1.5|         0.2|Iris-setosa|
|  5|          5.0|         3.6|          1.4|         0.2|Iris-setosa|
+---+-------------+------------+-------------+------------+-----------+
only showing top 5 rows



### Type des données

In [6]:
// Print the column names and types of the raw data.
rawData.printSchema();

root
 |-- Id: string (nullable = true)
 |-- SepalLengthCm: string (nullable = true)
 |-- SepalWidthCm: string (nullable = true)
 |-- PetalLengthCm: string (nullable = true)
 |-- PetalWidthCm: string (nullable = true)
 |-- Species: string (nullable = true)



### Ajustement des types des colonnes

In [7]:
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.types.DataTypes.DoubleType;

// cast the values of the features to doubles for usage in the feature column vector
Dataset<Row> transformedDataSet = rawData.withColumn("SepalLengthCm", rawData.col("SepalLengthCm").cast("double"))
                .withColumn("SepalWidthCm", rawData.col("SepalWidthCm").cast("double"))
                .withColumn("PetalLengthCm", rawData.col("PetalLengthCm").cast("double"))
                .withColumn("PetalWidthCm", rawData.col("PetalWidthCm").cast("double"));

In [8]:
// Print the schema again to check the types after the casts.
transformedDataSet.printSchema();

root
 |-- Id: string (nullable = true)
 |-- SepalLengthCm: double (nullable = true)
 |-- SepalWidthCm: double (nullable = true)
 |-- PetalLengthCm: double (nullable = true)
 |-- PetalWidthCm: double (nullable = true)
 |-- Species: string (nullable = true)



### Mapping de la variable dépendante vers des valeurs numériques

In [9]:
import static org.apache.spark.sql.functions.when;
// Add the numeric "label" column required by the classifier:
//   Iris-setosa -> 1, Iris-versicolor -> 2, every other value -> 3.
// NOTE(review): otherwise(3) is a catch-all, so a null or unexpected Species value
// would also be labelled 3 — acceptable only if Iris-virginica is the sole remaining
// species in the file; confirm against the data.
// NOTE(review): Spark ML classifiers document labels as 0-based class indices;
// labels starting at 1 presumably make Spark assume an extra, empty class 0 — TODO confirm.
transformedDataSet = transformedDataSet.withColumn(
        "label",
        when(col("Species").equalTo("Iris-setosa"), 1)
                .when(col("Species").equalTo("Iris-versicolor"), 2)
                .otherwise(3));

In [11]:
// Preview a random sample of the labelled data (drawn with replacement, fraction 0.05).
// A fixed seed is passed so that the preview is the same on every run of the notebook.
transformedDataSet.sample(true, 0.05, 5043L).show();

+---+-------------+------------+-------------+------------+---------------+-----+
| Id|SepalLengthCm|SepalWidthCm|PetalLengthCm|PetalWidthCm|        Species|label|
+---+-------------+------------+-------------+------------+---------------+-----+
|  5|          5.0|         3.6|          1.4|         0.2|    Iris-setosa|    1|
|  7|          4.6|         3.4|          1.4|         0.3|    Iris-setosa|    1|
| 12|          4.8|         3.4|          1.6|         0.2|    Iris-setosa|    1|
| 14|          4.3|         3.0|          1.1|         0.1|    Iris-setosa|    1|
| 27|          5.0|         3.4|          1.6|         0.4|    Iris-setosa|    1|
| 28|          5.2|         3.5|          1.5|         0.2|    Iris-setosa|    1|
| 33|          5.2|         4.1|          1.5|         0.1|    Iris-setosa|    1|
| 60|          5.2|         2.7|          3.9|         1.4|Iris-versicolor|    2|
| 90|          5.5|         2.5|          4.0|         1.3|Iris-versicolor|    2|
|136|          7

### Choix des variables discriminantes

In [12]:
import org.apache.spark.ml.feature.VectorAssembler;
// Names of the four measurement columns used as predictors.
String[] inputColumns = {"SepalLengthCm", "SepalWidthCm", "PetalLengthCm", "PetalWidthCm"};

// Pack the predictor columns into a single vector column named "features".
VectorAssembler assembler = new VectorAssembler()
        .setInputCols(inputColumns)
        .setOutputCol("features");

// Dataset with the extra "features" column appended.
Dataset<Row> featureSet = assembler.transform(transformedDataSet);

### Division du jeu de données (Training & Test)

In [13]:
// Randomly partition the assembled rows with weights 0.7 (training) and 0.3 (test).
// The fixed seed makes the split reproducible between runs.
long seed = 5043L;
Dataset<Row>[] trainingAndTestSet = featureSet.randomSplit(new double[] {0.7, 0.3}, seed);

// First partition is used for training, second one for testing.
Dataset<Row> trainingSet = trainingAndTestSet[0];
Dataset<Row> testSet = trainingAndTestSet[1];

In [14]:
// Preview the first 5 rows of the training partition.
trainingSet.show(5)

+---+-------------+------------+-------------+------------+---------------+-----+-----------------+
| Id|SepalLengthCm|SepalWidthCm|PetalLengthCm|PetalWidthCm|        Species|label|         features|
+---+-------------+------------+-------------+------------+---------------+-----+-----------------+
| 10|          4.9|         3.1|          1.5|         0.1|    Iris-setosa|    1|[4.9,3.1,1.5,0.1]|
|100|          5.7|         2.8|          4.1|         1.3|Iris-versicolor|    2|[5.7,2.8,4.1,1.3]|
|102|          5.8|         2.7|          5.1|         1.9| Iris-virginica|    3|[5.8,2.7,5.1,1.9]|
|103|          7.1|         3.0|          5.9|         2.1| Iris-virginica|    3|[7.1,3.0,5.9,2.1]|
|105|          6.5|         3.0|          5.8|         2.2| Iris-virginica|    3|[6.5,3.0,5.8,2.2]|
+---+-------------+------------+-------------+------------+---------------+-----+-----------------+
only showing top 5 rows



In [15]:
// Preview the first 5 rows of the test partition.
testSet.show(5)

+---+-------------+------------+-------------+------------+--------------+-----+-----------------+
| Id|SepalLengthCm|SepalWidthCm|PetalLengthCm|PetalWidthCm|       Species|label|         features|
+---+-------------+------------+-------------+------------+--------------+-----+-----------------+
|  1|          5.1|         3.5|          1.4|         0.2|   Iris-setosa|    1|[5.1,3.5,1.4,0.2]|
|101|          6.3|         3.3|          6.0|         2.5|Iris-virginica|    3|[6.3,3.3,6.0,2.5]|
|104|          6.3|         2.9|          5.6|         1.8|Iris-virginica|    3|[6.3,2.9,5.6,1.8]|
|107|          4.9|         2.5|          4.5|         1.7|Iris-virginica|    3|[4.9,2.5,4.5,1.7]|
|120|          6.0|         2.2|          5.0|         1.5|Iris-virginica|    3|[6.0,2.2,5.0,1.5]|
+---+-------------+------------+-------------+------------+--------------+-----+-----------------+
only showing top 5 rows



### Entraînement avec l'algorithme Random Forest

In [31]:
import org.apache.spark.ml.classification.RandomForestClassificationModel;
import org.apache.spark.ml.classification.RandomForestClassifier;
  // Configure a Random Forest classifier with explicit hyperparameters:
  // 20 trees, tree depth capped at 3, Gini impurity, "auto" feature-subset strategy,
  // and the same seed as the train/test split for reproducible training.
  // Not all of these are Spark defaults: maxDepth is lowered to 3
  // (Spark documents 5 as the default for this parameter).
  // No label/feature column names are set, so the classifier relies on its default
  // column names — presumably the "label" and "features" columns built earlier.
  RandomForestClassifier randomForestClassifier = new RandomForestClassifier()
          .setNumTrees(20)
          .setMaxDepth(3)
          .setImpurity("gini")
          .setFeatureSubsetStrategy("auto")
          .setSeed(seed);

  // Fit the forest on the training partition.
  RandomForestClassificationModel model = randomForestClassifier.fit(trainingSet);


In [30]:
// Print the model's debug description, which lists the decision rules of each tree.
System.out.println(model.toDebugString())

RandomForestClassificationModel (uid=rfc_029fd82bb333) with 20 trees
  Tree 0 (weight 1.0):
    If (feature 3 <= 0.6)
     Predict: 1.0
    Else (feature 3 > 0.6)
     If (feature 2 <= 4.8)
      If (feature 3 <= 1.6)
       Predict: 2.0
      Else (feature 3 > 1.6)
       Predict: 2.0
     Else (feature 2 > 4.8)
      If (feature 3 <= 1.6)
       Predict: 3.0
      Else (feature 3 > 1.6)
       Predict: 3.0
  Tree 1 (weight 1.0):
    If (feature 3 <= 0.6)
     Predict: 1.0
    Else (feature 3 > 0.6)
     If (feature 3 <= 1.7)
      If (feature 2 <= 5.0)
       Predict: 2.0
      Else (feature 2 > 5.0)
       Predict: 3.0
     Else (feature 3 > 1.7)
      If (feature 0 <= 5.9)
       Predict: 3.0
      Else (feature 0 > 5.9)
       Predict: 3.0
  Tree 2 (weight 1.0):
    If (feature 3 <= 0.6)
     Predict: 1.0
    Else (feature 3 > 0.6)
     If (feature 2 <= 4.8)
      If (feature 2 <= 4.6)
       Predict: 2.0
      Else (feature 2 > 4.6)
       Predict: 2.0
     Else (feature 2 > 4.8)

### Test du modèle

In [23]:
// Score the held-out test partition with the trained forest; the result is the
// test data with the model's output columns (including "prediction") appended.
Dataset<Row> predictions = model.transform(testSet);

// Show the first 5 rows, keeping only the identifier, the species name,
// the expected label and the predicted label.
// NOTE(review): "id" is written in lowercase while the column is named "Id"; this
// presumably works because Spark resolves column names case-insensitively by default.
predictions
        .select("id", "Species", "label", "prediction")
        .show(5);

+---+--------------+-----+----------+
| id|       Species|label|prediction|
+---+--------------+-----+----------+
|  1|   Iris-setosa|    1|       1.0|
|101|Iris-virginica|    3|       3.0|
|104|Iris-virginica|    3|       3.0|
|107|Iris-virginica|    3|       2.0|
|120|Iris-virginica|    3|       2.0|
+---+--------------+-----+----------+
only showing top 5 rows



### Evaluation

In [24]:
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
        // Evaluator that compares the "prediction" column against the "label" column
        // and reports the "accuracy" metric.
        MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
                .setMetricName("accuracy")
                .setPredictionCol("prediction")
                .setLabelCol("label");

### Exactitude (accuracy) du modèle

In [25]:
// Evaluate the test-set predictions with the configured evaluator and print the resulting accuracy.
System.out.println("accuracy: " + evaluator.evaluate(predictions));

accuracy: 0.9393939393939394


### Sauvegarde du modèle

In [20]:
// Persist the trained model under model/iris-random-forest.
// The writer's overwrite() mode is used so that re-running this cell replaces a
// previously saved copy instead of failing because the target path already exists.
model.write().overwrite().save("model/iris-random-forest");

### Utilisation du modèle

In [55]:
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;

// Reload the model that was persisted under model/iris-random-forest.
RandomForestClassificationModel model_loaded =
        RandomForestClassificationModel.load("model/iris-random-forest");

// Predict the class of a single sample. The four values follow the order of the
// feature columns used for training (sepal length, sepal width, petal length, petal width).
// The vector is built with the mllib factory and converted to the ml vector type with asML().
Double predicted_value = model_loaded.predict(
        Vectors.dense(new double[] {6.5, 3.0, 5.8, 2.2}).asML());

// Label legend: 1 = Iris-setosa, 2 = Iris-versicolor, 3 = Iris-virginica.
System.out.println(predicted_value);


3.0
