In [1]:
from pyspark.sql import SparkSession

# Build (or reuse, via getOrCreate) the local SparkSession used by every
# later cell in this notebook.
spark = (
    SparkSession
        .builder
        # Run against a local master rather than a cluster.
        .master('local')
        .config('spark.sql.files.ignoreCorruptFiles', 'true')
        # Key corrected: it was misspelled 'sparl.sql.sources...', so Spark did
        # not recognise it as the partition-overwrite setting.
        .config('spark.sql.sources.partitionOverwriteMode', 'dynamic')
        # CatBoost-for-Spark artifact resolved at session start-up; the
        # coordinates target Spark 3.0 / Scala 2.12.
        .config("spark.jars.packages", "ai.catboost:catboost-spark_3.0_2.12:1.0.4")
        .config('spark.dynamicAllocation.enabled', 'false')
        .getOrCreate()
        )

22/10/04 05:05:25 WARN Utils: Your hostname, notebook resolves to a loopback address: 127.0.1.1; using 192.168.0.18 instead (on interface wlp9s0)
22/10/04 05:05:25 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Ivy Default Cache set to: /home/walter/.ivy2/cache
The jars for the packages stored in: /home/walter/.ivy2/jars
:: loading settings :: url = jar:file:/opt/spark-3.0.3/jars/ivy-2.4.0.jar!/org/apache/ivy/core/settings/ivysettings.xml
ai.catboost#catboost-spark_3.0_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-6f97cd44-468c-4e5d-ae74-4f5f84e91b51;1.0
	confs: [default]
	found ai.catboost#catboost-spark_3.0_2.12;1.0.4 in central
	found com.google.guava#guava;29.0-jre in central
	found com.google.guava#failureaccess;1.0.1 in central
	found com.google.guava#listenablefuture;9999.0-empty-to-avoid-conflict-with-guava in central
	found com.google.code.findbugs#jsr305;3.0.2 in central
	found org.checkerframework#checker-

In [2]:
from pyspark.sql.types import StructField, StringType, StructType
from pyspark.ml.linalg import VectorUDT, Vectors
from pyspark.sql import Row
import catboost_spark as cs
import pandas as pd

# # StructField
# PySpark StructType & StructField classes are used 
# to programmatically specify the schema to the DataFrame
# and creating complex columns like nested struct, array and map columns. 
# StructType is a collection of StructFields that defines column name, 
# column data type, boolean to specify if the field can be nullable or not and metadata.

# # pyspark.ml.linalg.VectorUDT
# User-defined type for Vector which allows easy interaction with SQL via Dataset.

# # pyspark.sql.types.StringType
# String data type.

# # Vectors
# Factory methods for working with vectors. Note that dense vectors are simply represented as NumPy
# array objects, so there is no need to convert them for use in MLlib.
# For sparse vectors, the factory methods in this class create an MLlib-compatible type, 
# or users can pass in SciPy's scipy.sparse column vectors.

# #Vectors.dense()
# Create a dense vector of 64-bit floats from a Python list. Always returns a NumPy array.

# # pyspark.sql.Row
# A row in a DataFrame. Its fields can be accessed like attributes (row.key)
# or like dictionary values (row[key]).

# # cs.Pool(<my pyspark DataFrame>)
# Wraps the data in the structure CatBoost works with, nothing more.

In [3]:
# Schema definition: an ordered list of StructField(name, type) entries,
# one per DataFrame column.
srcDataSchema = [
    StructField(columnName, columnType)
    for columnName, columnType in (
        ("features", VectorUDT()),
        ("label", StringType()),
    )
]

In [4]:
# Training data: a list of Rows, each made of two elements — a dense feature
# vector and a string label.
trainData = [
    Row(Vectors.dense(*featureValues), classLabel)
    for featureValues, classLabel in (
        ((0.1, 0.2, 0.11), "0"),
        ((0.97, 0.82, 0.33), "1"),
        ((0.13, 0.22, 0.23), "1"),
        ((0.8, 0.62, 0.0), "0"),
    )
]

In [5]:
# Build the training pyspark.sql.dataframe.DataFrame.
#
# spark.sparkContext.parallelize(<local list>) distributes a local Python
# collection to form an RDD (the PySpark docs recommend `range` when the
# input represents a range, for performance).
trainRdd = spark.sparkContext.parallelize(trainData)
trainSchema = StructType(srcDataSchema)
trainDf = spark.createDataFrame(trainRdd, trainSchema)

In [6]:
# Wrap the training DataFrame in a Pool, i.e. the structure CatBoost
# understands; nothing else happens here.
trainPool = cs.Pool(trainDf)

In [7]:
# Same steps as for the training set, but with evaluation data:
# rows -> RDD -> DataFrame (same schema) -> CatBoost Pool.
evalData = [
    Row(Vectors.dense(*featureValues), classLabel)
    for featureValues, classLabel in (
        ((0.22, 0.33, 0.9), "1"),
        ((0.11, 0.1, 0.21), "0"),
        ((0.77, 0.0, 0.0), "1"),
        ((0.77, 0.0, 0.0), "1"),
        ((0.77, 0.0, 0.0), "1"),
        ((0.77, 0.1, 0.0), "1"),
        ((0.77, 0.4, 0.0), "1"),
        ((0.77, 0.0, 5.0), "0"),
        ((0.77, 0.0, 0.0), "1"),
        ((0.77, 0.0, 0.0), "1"),
    )
]

evalRdd = spark.sparkContext.parallelize(evalData)
evalSchema = StructType(srcDataSchema)
evalDf = spark.createDataFrame(evalRdd, evalSchema)
evalPool = cs.Pool(evalDf)

In [8]:
# Instantiate the classifier. For the full parameter list see
# https://catboost.ai/docs/catboost-spark/3.0_2.12/latest/api/python/api/catboost_spark.CatBoostClassifier.html#catboost_spark.CatBoostClassifier
# allowWritingFiles: bool (whether things get written to disk during training)
# NOTE(review): earlyStoppingRounds (200) is larger than iterations (10), so
# presumably early stopping can never trigger with these values — confirm.
classifierParams = dict(
    allowWritingFiles=False,
    classWeightsList=[0.01, 0.99],
    earlyStoppingRounds=200,
    iterations=10,
    learningRate=0.01,
    randomSeed=100,
    lossFunction='Logloss',
    # customMetric=''
)
cl = cs.CatBoostClassifier(**classifierParams)

# Train the model, handing over the evaluation pool together with the
# training pool.
model = cl.fit(trainPool, [evalPool])

# Predict with the trained model on the evaluation DataFrame and print it.
predictions = model.transform(evalPool.data)
predictions.show()

[Stage 430:>                                                        (0 + 1) / 1]

0:	learn: 0.6917459	test: 0.6929852	best: 0.6929852 (0)	total: 10.9ms	remaining: 98.4ms
1:	learn: 0.6903636	test: 0.6928148	best: 0.6928148 (1)	total: 18.6ms	remaining: 74.3ms
2:	learn: 0.6889702	test: 0.6926512	best: 0.6926512 (2)	total: 27.5ms	remaining: 64.2ms
3:	learn: 0.6875958	test: 0.6926693	best: 0.6926512 (2)	total: 34.1ms	remaining: 51.1ms
4:	learn: 0.6862402	test: 0.6912981	best: 0.6912981 (4)	total: 39.7ms	remaining: 39.7ms
5:	learn: 0.6848736	test: 0.6911419	best: 0.6911419 (5)	total: 46.2ms	remaining: 30.8ms
6:	learn: 0.6834958	test: 0.6909824	best: 0.6909824 (6)	total: 54.7ms	remaining: 23.4ms
7:	learn: 0.6821219	test: 0.6908251	best: 0.6908251 (7)	total: 63.4ms	remaining: 15.8ms
8:	learn: 0.6807518	test: 0.6906667	best: 0.6906667 (8)	total: 71.7ms	remaining: 7.96ms
9:	learn: 0.6783598	test: 0.6903791	best: 0.6903791 (9)	total: 77.9ms	remaining: 0us

bestTest = 0.6903791263
bestIteration = 9

QueryFullTime: 0.124107
QueryExecutionTime: 0.067353


                                                                                

Skipping test eval output
0.003230997535 min passed
+---------------+-----+--------------------+--------------------+----------+
|       features|label|       rawPrediction|         probability|prediction|
+---------------+-----+--------------------+--------------------+----------+
|[0.22,0.33,0.9]|    1|[-0.0122700165425...|[0.49386529959210...|       1.0|
|[0.11,0.1,0.21]|    0|[-0.0054910660731...|[0.49725449455736...|       1.0|
| [0.77,0.0,0.0]|    1|[-0.0012423872599...|[0.49937880668964...|       1.0|
| [0.77,0.0,0.0]|    1|[-0.0012423872599...|[0.49937880668964...|       1.0|
| [0.77,0.0,0.0]|    1|[-0.0012423872599...|[0.49937880668964...|       1.0|
| [0.77,0.1,0.0]|    1|[-0.0012423872599...|[0.49937880668964...|       1.0|
| [0.77,0.4,0.0]|    1|[-0.0026715148175...|[0.49866424576900...|       1.0|
| [0.77,0.0,5.0]|    0|[-0.0055689975852...|[0.49721552999291...|       1.0|
| [0.77,0.0,0.0]|    1|[-0.0012423872599...|[0.49937880668964...|       1.0|
| [0.77,0.0,0.0]|    1|[

                                                                                

In [9]:
# Convert the Spark predictions DataFrame to pandas; as the last expression
# of the cell, the resulting frame is what the notebook displays.
predictions.toPandas()

Unnamed: 0,features,label,rawPrediction,probability,prediction
0,"[0.22, 0.33, 0.9]",1,"[-0.012270016542563773, 0.012270016542563773]","[0.49386529959210396, 0.5061347004078961]",1.0
1,"[0.11, 0.1, 0.21]",0,"[-0.005491066073124653, 0.005491066073124653]","[0.4972544945573653, 0.5027455054426347]",1.0
2,"[0.77, 0.0, 0.0]",1,"[-0.0012423872599294427, 0.0012423872599294427]","[0.4993788066896446, 0.5006211933103554]",1.0
3,"[0.77, 0.0, 0.0]",1,"[-0.0012423872599294427, 0.0012423872599294427]","[0.4993788066896446, 0.5006211933103554]",1.0
4,"[0.77, 0.0, 0.0]",1,"[-0.0012423872599294427, 0.0012423872599294427]","[0.4993788066896446, 0.5006211933103554]",1.0
5,"[0.77, 0.1, 0.0]",1,"[-0.0012423872599294427, 0.0012423872599294427]","[0.4993788066896446, 0.5006211933103554]",1.0
6,"[0.77, 0.4, 0.0]",1,"[-0.0026715148175077133, 0.0026715148175077133]","[0.49866424576900015, 0.5013357542309999]",1.0
7,"[0.77, 0.0, 5.0]",0,"[-0.005568997585267384, 0.005568997585267384]","[0.4972155299929109, 0.502784470007089]",1.0
8,"[0.77, 0.0, 0.0]",1,"[-0.0012423872599294427, 0.0012423872599294427]","[0.4993788066896446, 0.5006211933103554]",1.0
9,"[0.77, 0.0, 0.0]",1,"[-0.0012423872599294427, 0.0012423872599294427]","[0.4993788066896446, 0.5006211933103554]",1.0


In [10]:
# Save the trained model to disk.
# NOTE(review): this is an absolute, user-specific path — presumably only
# valid on the author's machine; consider moving it to a configurable location.
savedModelPath = "/home/walter/Documents/serie-notas/z_aux/modelo_catboost/example_save_model.cbm"

model.saveNativeModel(savedModelPath)

# Load the model again (left disabled).
# NOTE(review): the previous version of this commented-out line referenced
# `catboost_spark` and `savedNativeModelPath`, neither of which appears in the
# visible cells (the module is imported as `cs`, the path variable is
# `savedModelPath`). Names corrected below — TODO confirm before enabling.
#loadedNativeModel = cs.CatBoostClassificationModel.loadNativeModel(savedModelPath)
