## Overview

This notebook will show you how to create and query a table or DataFrame that you uploaded to DBFS. [DBFS](https://docs.databricks.com/user-guide/dbfs-databricks-file-system.html) is a Databricks File System that allows you to store data for querying inside of Databricks. This notebook assumes that you have a file already inside of DBFS that you would like to read from.

This notebook is written in **Python** so the default cell type is Python. However, you can use different languages by using the `%LANGUAGE` syntax. Python, Scala, SQL, and R are all supported.

In [0]:
# File location and type
file_location = "/FileStore/shared_uploads/desaizeeshan22@gmail.com/preprocessed_ecom.csv"
file_type = "csv"

# CSV options
infer_schema = "true"
first_row_is_header = "true"
delimiter = ","

# The applied options are for CSV files. For other file types, these will be ignored.
df = spark.read.format(file_type) \
  .option("inferSchema", infer_schema) \
  .option("header", first_row_is_header) \
  .option("sep", delimiter) \
  .load(file_location)

display(df)

_c0,price,retail_price,units_sold,uses_ad_boosts,rating,rating_five_count,rating_four_count,rating_three_count,rating_two_count,rating_one_count,badges_count,badge_local_product,badge_product_quality,badge_fast_shipping,product_variation_inventory,shipping_option_price,shipping_is_express,countries_shipped_to,inventory_total,has_urgency_banner,merchant_rating_count,merchant_rating,merchant_has_profile_picture,COLOR__Pink24,COLOR__army,COLOR__beige,COLOR__black,COLOR__blackwhite,COLOR__blue,COLOR__brown,COLOR__claret,COLOR__coolblack,COLOR__dual,COLOR__dustypink,COLOR__gold,COLOR__green,COLOR__grey,COLOR__greysnakeskinprint,COLOR__ivory,COLOR__jasper,COLOR__leopardprint,COLOR__lightgray,COLOR__lightgrey,COLOR__lightpurple,COLOR__nude,COLOR__offblack,COLOR__offwhite,COLOR__orange,COLOR__other,COLOR__pink50,COLOR__purple,COLOR__rainbow,COLOR__red,COLOR__star,COLOR__tan,COLOR__violet,COLOR__white,COLOR__whitestripe,COLOR__yellow,origin__OTHER,origin__US,prodsize__M,prodsize__S,prodsize__XL,prodsize__XS,prodsize__XXL,prodsize__XXS,prodsize__XXXS,prodsize__XXXXL,prodsize__XXXXXL,prodsize__other
0,16.0,14,100,0,3.76,26.0,8.0,10.0,1.0,9.0,0,0,0,0,50,4,0,34,50,1.0,568,4.128521127,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0
1,8.0,22,20000,1,3.45,2269.0,1027.0,1118.0,644.0,1077.0,0,0,0,0,50,2,0,41,50,1.0,17752,3.899673276,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0
2,8.0,43,100,0,3.57,5.0,4.0,2.0,0.0,3.0,0,0,0,0,1,3,0,36,50,1.0,295,3.989830508,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0
3,8.0,8,5000,1,4.03,295.0,119.0,87.0,42.0,36.0,0,0,0,0,50,2,0,41,50,0.0,23832,4.02043471,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0
4,2.72,3,100,1,3.1,6.0,4.0,2.0,2.0,6.0,0,0,0,0,1,1,0,35,50,1.0,14482,4.001588178,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0
5,3.92,9,10,0,5.0,1.0,0.0,0.0,0.0,0.0,0,0,0,0,1,1,0,40,50,0.0,65,3.507692308,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
6,7.0,6,50000,0,3.84,3172.0,1352.0,971.0,490.0,757.0,0,0,0,0,50,2,0,31,50,0.0,10194,4.076515597,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0
7,12.0,11,1000,0,3.76,120.0,56.0,61.0,18.0,31.0,0,0,0,0,50,3,0,139,50,0.0,342,3.68128655,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
8,11.0,84,100,1,3.47,6.0,2.0,3.0,1.0,3.0,0,0,0,0,50,2,0,36,50,1.0,330,3.803030303,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0
9,5.78,22,5000,0,3.6,287.0,128.0,92.0,68.0,112.0,0,0,0,0,50,2,0,33,50,0.0,5534,3.999819299,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0


In [0]:
# Create a view or table

temp_table_name = "preprocessed_ecom_csv"

df.createOrReplaceTempView(temp_table_name)

In [0]:
%sql

/* Query the created temp table in a SQL cell */

select * from `preprocessed_ecom_csv`

_c0,price,retail_price,units_sold,uses_ad_boosts,rating,rating_five_count,rating_four_count,rating_three_count,rating_two_count,rating_one_count,badges_count,badge_local_product,badge_product_quality,badge_fast_shipping,product_variation_inventory,shipping_option_price,shipping_is_express,countries_shipped_to,inventory_total,has_urgency_banner,merchant_rating_count,merchant_rating,merchant_has_profile_picture,COLOR__Pink24,COLOR__army,COLOR__beige,COLOR__black,COLOR__blackwhite,COLOR__blue,COLOR__brown,COLOR__claret,COLOR__coolblack,COLOR__dual,COLOR__dustypink,COLOR__gold,COLOR__green,COLOR__grey,COLOR__greysnakeskinprint,COLOR__ivory,COLOR__jasper,COLOR__leopardprint,COLOR__lightgray,COLOR__lightgrey,COLOR__lightpurple,COLOR__nude,COLOR__offblack,COLOR__offwhite,COLOR__orange,COLOR__other,COLOR__pink50,COLOR__purple,COLOR__rainbow,COLOR__red,COLOR__star,COLOR__tan,COLOR__violet,COLOR__white,COLOR__whitestripe,COLOR__yellow,origin__OTHER,origin__US,prodsize__M,prodsize__S,prodsize__XL,prodsize__XS,prodsize__XXL,prodsize__XXS,prodsize__XXXS,prodsize__XXXXL,prodsize__XXXXXL,prodsize__other
0,16.0,14,100,0,3.76,26.0,8.0,10.0,1.0,9.0,0,0,0,0,50,4,0,34,50,1.0,568,4.128521127,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0
1,8.0,22,20000,1,3.45,2269.0,1027.0,1118.0,644.0,1077.0,0,0,0,0,50,2,0,41,50,1.0,17752,3.899673276,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0
2,8.0,43,100,0,3.57,5.0,4.0,2.0,0.0,3.0,0,0,0,0,1,3,0,36,50,1.0,295,3.989830508,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0
3,8.0,8,5000,1,4.03,295.0,119.0,87.0,42.0,36.0,0,0,0,0,50,2,0,41,50,0.0,23832,4.02043471,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0
4,2.72,3,100,1,3.1,6.0,4.0,2.0,2.0,6.0,0,0,0,0,1,1,0,35,50,1.0,14482,4.001588178,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0
5,3.92,9,10,0,5.0,1.0,0.0,0.0,0.0,0.0,0,0,0,0,1,1,0,40,50,0.0,65,3.507692308,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
6,7.0,6,50000,0,3.84,3172.0,1352.0,971.0,490.0,757.0,0,0,0,0,50,2,0,31,50,0.0,10194,4.076515597,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0
7,12.0,11,1000,0,3.76,120.0,56.0,61.0,18.0,31.0,0,0,0,0,50,3,0,139,50,0.0,342,3.68128655,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1
8,11.0,84,100,1,3.47,6.0,2.0,3.0,1.0,3.0,0,0,0,0,50,2,0,36,50,1.0,330,3.803030303,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0
9,5.78,22,5000,0,3.6,287.0,128.0,92.0,68.0,112.0,0,0,0,0,50,2,0,33,50,0.0,5534,3.999819299,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0


In [0]:
# With this registered as a temp view, it will only be available to this particular notebook. If you'd like other users to be able to query this table, you can also create a table from the DataFrame.
# Once saved, this table will persist across cluster restarts as well as allow various users across different notebooks to query this data.
# To do so, choose your table name and uncomment the bottom line.

permanent_table_name = "preprocessed_ecom_csv"

# df.write.format("parquet").saveAsTable(permanent_table_name)

In [0]:
from collections import defaultdict
data_types = defaultdict(list)
for entry in df.schema.fields:
    data_types[str(entry.dataType)].append(entry.name)

In [0]:
data_types

In [0]:
len([x for x in df.columns if x!='units_sold'])

In [0]:
from pyspark.ml.feature import VectorAssembler
features = [x for x in df.columns if x!='units_sold']
vector_assembler = VectorAssembler(inputCols = features, outputCol= "features")
data_training_and_test = vector_assembler.transform(df)

In [0]:
from pyspark.ml.feature import StringIndexer
l_indexer = StringIndexer(inputCol="units_sold", outputCol="labelIndex")
data_training_and_test  = l_indexer.fit(data_training_and_test ).transform(data_training_and_test )

In [0]:
(training_data, test_data) = data_training_and_test.randomSplit([0.7, 0.3], seed=10)

In [0]:
display(training_data)

_c0,price,retail_price,units_sold,uses_ad_boosts,rating,rating_five_count,rating_four_count,rating_three_count,rating_two_count,rating_one_count,badges_count,badge_local_product,badge_product_quality,badge_fast_shipping,product_variation_inventory,shipping_option_price,shipping_is_express,countries_shipped_to,inventory_total,has_urgency_banner,merchant_rating_count,merchant_rating,merchant_has_profile_picture,COLOR__Pink24,COLOR__army,COLOR__beige,COLOR__black,COLOR__blackwhite,COLOR__blue,COLOR__brown,COLOR__claret,COLOR__coolblack,COLOR__dual,COLOR__dustypink,COLOR__gold,COLOR__green,COLOR__grey,COLOR__greysnakeskinprint,COLOR__ivory,COLOR__jasper,COLOR__leopardprint,COLOR__lightgray,COLOR__lightgrey,COLOR__lightpurple,COLOR__nude,COLOR__offblack,COLOR__offwhite,COLOR__orange,COLOR__other,COLOR__pink50,COLOR__purple,COLOR__rainbow,COLOR__red,COLOR__star,COLOR__tan,COLOR__violet,COLOR__white,COLOR__whitestripe,COLOR__yellow,origin__OTHER,origin__US,prodsize__M,prodsize__S,prodsize__XL,prodsize__XS,prodsize__XXL,prodsize__XXS,prodsize__XXXS,prodsize__XXXXL,prodsize__XXXXXL,prodsize__other,features,labelIndex
0,16.0,14,100,0,3.76,26.0,8.0,10.0,1.0,9.0,0,0,0,0,50,4,0,34,50,1.0,568,4.128521127,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,"List(0, 71, List(1, 2, 4, 5, 6, 7, 8, 9, 14, 15, 17, 18, 19, 20, 21, 56, 61), List(16.0, 14.0, 3.76, 26.0, 8.0, 10.0, 1.0, 9.0, 50.0, 4.0, 34.0, 50.0, 1.0, 568.0, 4.128521127, 1.0, 1.0))",0.0
2,8.0,43,100,0,3.57,5.0,4.0,2.0,0.0,3.0,0,0,0,0,1,3,0,36,50,1.0,295,3.989830508,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,"List(0, 71, List(0, 1, 2, 4, 5, 6, 7, 9, 14, 15, 17, 18, 19, 20, 21, 40, 64), List(2.0, 8.0, 43.0, 3.57, 5.0, 4.0, 2.0, 3.0, 1.0, 3.0, 36.0, 50.0, 1.0, 295.0, 3.9898305080000003, 1.0, 1.0))",0.0
4,2.72,3,100,1,3.1,6.0,4.0,2.0,2.0,6.0,0,0,0,0,1,1,0,35,50,1.0,14482,4.001588178,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,"List(0, 71, List(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 14, 15, 17, 18, 19, 20, 21, 58, 62), List(4.0, 2.72, 3.0, 1.0, 3.1, 6.0, 4.0, 2.0, 2.0, 6.0, 1.0, 1.0, 35.0, 50.0, 1.0, 14482.0, 4.001588178, 1.0, 1.0))",0.0
5,3.92,9,10,0,5.0,1.0,0.0,0.0,0.0,0.0,0,0,0,0,1,1,0,40,50,0.0,65,3.507692308,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,"List(0, 71, List(0, 1, 2, 4, 5, 14, 15, 17, 18, 20, 21, 28, 70), List(5.0, 3.92, 9.0, 5.0, 1.0, 1.0, 1.0, 40.0, 50.0, 65.0, 3.507692308, 1.0, 1.0))",6.0
9,5.78,22,5000,0,3.6,287.0,128.0,92.0,68.0,112.0,0,0,0,0,50,2,0,33,50,0.0,5534,3.999819299,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,"List(0, 71, List(0, 1, 2, 4, 5, 6, 7, 8, 9, 14, 15, 17, 18, 20, 21, 25, 62), List(9.0, 5.78, 22.0, 3.6, 287.0, 128.0, 92.0, 68.0, 112.0, 50.0, 2.0, 33.0, 50.0, 5534.0, 3.999819299, 1.0, 1.0))",2.0
11,6.0,8,100,1,3.31,3.0,4.0,3.0,0.0,3.0,0,0,0,0,2,2,0,40,50,1.0,3515,3.983783784,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,"List(0, 71, List(0, 1, 2, 3, 4, 5, 6, 7, 9, 14, 15, 17, 18, 19, 20, 21, 28, 64), List(11.0, 6.0, 8.0, 1.0, 3.31, 3.0, 4.0, 3.0, 3.0, 2.0, 2.0, 40.0, 50.0, 1.0, 3515.0, 3.983783784, 1.0, 1.0))",0.0
12,1.91,6,1000,1,3.45,49.0,29.0,24.0,14.0,25.0,0,0,0,0,1,1,0,38,50,0.0,557,4.123877917,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,"List(0, 71, List(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 14, 15, 17, 18, 20, 21, 47, 64), List(12.0, 1.91, 6.0, 1.0, 3.45, 49.0, 29.0, 24.0, 14.0, 25.0, 1.0, 1.0, 38.0, 50.0, 557.0, 4.123877917, 1.0, 1.0))",1.0
14,2.0,2,20000,1,3.65,984.0,481.0,459.0,206.0,327.0,0,0,0,0,1,1,0,36,50,0.0,55499,4.138885385,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,"List(0, 71, List(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 14, 15, 17, 18, 20, 21, 58, 62), List(14.0, 2.0, 2.0, 1.0, 3.65, 984.0, 481.0, 459.0, 206.0, 327.0, 1.0, 1.0, 36.0, 50.0, 55499.0, 4.138885385, 1.0, 1.0))",4.0
15,11.0,81,1000,0,3.92,204.0,94.0,62.0,21.0,45.0,0,0,0,0,50,3,0,41,50,0.0,39381,4.066326401,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,"List(0, 71, List(0, 1, 2, 4, 5, 6, 7, 8, 9, 14, 15, 17, 18, 20, 21, 26, 62), List(15.0, 11.0, 81.0, 3.92, 204.0, 94.0, 62.0, 21.0, 45.0, 50.0, 3.0, 41.0, 50.0, 39381.0, 4.066326401, 1.0, 1.0))",1.0
17,5.0,25,100000,1,3.83,8290.0,3483.0,2951.0,1410.0,1846.0,0,0,0,0,50,1,0,41,50,1.0,139223,3.933581377,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0,0,0,"List(0, 71, List(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 14, 15, 17, 18, 19, 20, 21, 56, 62), List(17.0, 5.0, 25.0, 1.0, 3.83, 8290.0, 3483.0, 2951.0, 1410.0, 1846.0, 50.0, 1.0, 41.0, 50.0, 1.0, 139223.0, 3.933581377, 1.0, 1.0))",8.0


In [0]:
from pyspark.ml.classification import LogisticRegression

lr = LogisticRegression(featuresCol = 'features', labelCol = "labelIndex")

# Fit the model
lrModel = lr.fit(training_data)




In [0]:
# Print the coefficients and intercept for multinomial logistic regression
print("Coefficients: \n" + str(lrModel.coefficientMatrix ))



In [0]:
summary = lrModel.summary
print(f"Training accuracy : {summary.accuracy}")
print(summary.falsePositiveRateByLabel)

In [0]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
predictions = lrModel.transform(test_data)


In [0]:
predictions.select("prediction","features").show(5)

In [0]:
evaluator = MulticlassClassificationEvaluator(labelCol="labelIndex",predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)
print(f"Test accuracy : {accuracy}")


In [0]:

from pyspark.ml.classification import DecisionTreeClassifier
DT = DecisionTreeClassifier(featuresCol = 'features', labelCol = "labelIndex")
dtModel = DT.fit(training_data)


In [0]:
predictions_DT = dtModel.transform(test_data)

In [0]:
predictions_DT.select("prediction","features").show(5)

In [0]:
dtModel.explainParams


In [0]:
predictions_DT_train = dtModel.transform(training_data)

In [0]:
evaluator = MulticlassClassificationEvaluator(labelCol="labelIndex", predictionCol="prediction", metricName="accuracy")
accuracy_DT_train = evaluator.evaluate(predictions_DT_train)
print(f"Train accuracy : {accuracy_DT_train}")


In [0]:
evaluator = MulticlassClassificationEvaluator(labelCol="labelIndex", predictionCol="prediction", metricName="accuracy")
accuracy_DT = evaluator.evaluate(predictions_DT)
print(f"Test accuracy : {accuracy_DT}")


In [0]:
from pyspark.ml.classification import RandomForestClassifier

In [0]:
rf = RandomForestClassifier(featuresCol = 'features', labelCol = "labelIndex")
rfModel=rf.fit(training_data)

In [0]:
predictions_RF_train = rfModel.transform(training_data)

In [0]:
evaluator = MulticlassClassificationEvaluator(labelCol="labelIndex", predictionCol="prediction", metricName="accuracy")
accuracy_RF_train = evaluator.evaluate(predictions_RF_train)
print(f"Train accuracy : {accuracy_RF_train}")


In [0]:
predictions_RF = rfModel.transform(test_data)

In [0]:
evaluator = MulticlassClassificationEvaluator(labelCol="labelIndex", predictionCol="prediction", metricName="accuracy")
accuracy_RF = evaluator.evaluate(predictions_RF)
print(f"Test accuracy : {accuracy_RF}")