<img width="200" src="https://mmlspark.blob.core.windows.net/graphics/emails/vw-blue-dark-orange.svg" />

# Contextual-Bandits using Vowpal Wabbit

[Azure Personalizer](https://azure.microsoft.com/en-us/products/cognitive-services/personalizer) emits logs in DSJSON format. This example demonstrates how to perform off-policy evaluation on such logs with the Vowpal Wabbit integration in SynapseML.

#### Read dataset

In [1]:
from pyspark.sql import SparkSession

# Bootstrap Spark Session
spark = (
    SparkSession.builder.config(
        "spark.jars.packages",
        "org.apache.hadoop:hadoop-azure:3.3.1,com.microsoft.azure:azure-storage:8.6.6,com.microsoft.azure:synapseml_2.12:0.10.2-114-409e395c-SNAPSHOT",
    )
    .config("spark.jars.repositories", "https://mmlspark.azureedge.net/maven")
    .config(
        "spark.jars.excludes",
        "org.scala-lang:scala-reflect,org.apache.spark:spark-tags_2.12,org.scalatest:scalatest_2.12,com.fasterxml.jackson.core:jackson-databind",
    )
    .config("spark.yarn.user.classpath.first", "true")
    .getOrCreate()
)

from synapse.ml.core.platform import *

from synapse.ml.core.platform import materializing_display as display

23/01/16 13:32:38 WARN Utils: Your hostname, marcozo-eu resolves to a loopback address: 127.0.1.1; using 172.22.131.99 instead (on interface eth0)
23/01/16 13:32:38 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
23/01/16 13:32:39 WARN SparkConf: The configuration key 'spark.yarn.user.classpath.first' has been deprecated as of Spark 1.3 and may be removed in the future. Please use spark.{driver,executor}.userClassPathFirst instead.
https://mmlspark.azureedge.net/maven added as a remote repository with the name: repo-1
Ivy Default Cache set to: /home/marcozo/.ivy2/cache
The jars for the packages stored in: /home/marcozo/.ivy2/jars
org.apache.hadoop#hadoop-azure added as a dependency
com.microsoft.azure#azure-storage added as a dependency
com.microsoft.azure#synapseml_2.12 added as a dependency


:: loading settings :: url = jar:file:/home/marcozo/miniconda3/envs/synapseml/lib/python3.8/site-packages/pyspark/jars/ivy-2.5.0.jar!/org/apache/ivy/core/settings/ivysettings.xml


:: resolving dependencies :: org.apache.spark#spark-submit-parent-1317ef6b-13a4-4cff-af2b-208f5df7692f;1.0
	confs: [default]
	found org.apache.hadoop#hadoop-azure;3.3.1 in central
	found org.apache.httpcomponents#httpclient;4.5.13 in local-m2-cache
	found org.apache.httpcomponents#httpcore;4.4.13 in local-m2-cache
	found commons-logging#commons-logging;1.1.3 in central
	found commons-codec#commons-codec;1.11 in spark-list
	found org.apache.hadoop.thirdparty#hadoop-shaded-guava;1.1.1 in central
	found org.eclipse.jetty#jetty-util-ajax;9.4.40.v20210413 in central
	found org.eclipse.jetty#jetty-util;9.4.40.v20210413 in central
	found org.codehaus.jackson#jackson-mapper-asl;1.9.13 in local-m2-cache
	found org.codehaus.jackson#jackson-core-asl;1.9.13 in local-m2-cache
	found org.wildfly.openssl#wildfly-openssl;1.0.7.Final in central
	found com.microsoft.azure#azure-storage;8.6.6 in central
	found com.fasterxml.jackson.core#jackson-core;2.9.4 in central
	found org.slf4j#slf4j-api;1.7.12 in c

23/01/16 13:32:56 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).

In [3]:
import pyspark.sql.types as T
from pyspark.sql import functions as F

schema = T.StructType(
    [
        T.StructField("input", T.StringType(), False),
    ]
)

df = (
    spark.read.format("text")
    .schema(schema)
    .load("wasbs://publicwasb@mmlspark.blob.core.windows.net/decisionservice.json")
)
# print dataset basic info
print("records read: " + str(df.count()))
print("Schema: ")
df.printSchema()

23/01/16 13:34:49 WARN MetricsConfig: Cannot locate configuration: tried hadoop-metrics2-azure-file-system.properties,hadoop-metrics2.properties

records read: 3
Schema: 
root
 |-- input: string (nullable = true)



                                                                                

In [4]:
display(df)

DataFrame[input: string]
+--------------------+
|               input|
+--------------------+
|{"_label_cost":-1...|
|{"_label_cost":0,...|
|{"_label_cost":-1...|
+--------------------+




#### Use VowpalWabbitDSJsonTransformer to extract header fields from the DSJSON logs

In [24]:
from synapse.ml.vw import VowpalWabbitDSJsonTransformer

df_ready = (
    VowpalWabbitDSJsonTransformer()
    .setDsJsonColumn("input")
    .transform(df)
    .withColumn("splitId", F.lit(0))
    .repartition(2)
)
df_ready.printSchema()

# exclude JSON as it's too messy
display(df_ready.drop("input"))

df_ready.drop("input").show(5, truncate=False)

root
 |-- input: string (nullable = true)
 |-- json: struct (nullable = true)
 |    |-- EventId: string (nullable = true)
 |    |-- _label_probability: float (nullable = true)
 |    |-- _labelIndex: integer (nullable = true)
 |    |-- _label_cost: float (nullable = true)
 |-- EventId: string (nullable = true)
 |-- rewards: struct (nullable = false)
 |    |-- reward: float (nullable = true)
 |-- probLog: float (nullable = true)
 |-- chosenActionIndex: integer (nullable = true)
 |-- splitId: integer (nullable = false)

DataFrame[json: struct<EventId:string,_label_probability:float,_labelIndex:int,_label_cost:float>, EventId: string, rewards: struct<reward:float>, probLog: float, chosenActionIndex: int, splitId: int]



+------------------------------------------------------+--------------------------------+-------+---------+-----------------+-------+
|json                                                  |EventId                         |rewards|probLog  |chosenActionIndex|splitId|
+------------------------------------------------------+--------------------------------+-------+---------+-----------------+-------+
|{fbe7a11d120b4df4bf23b836de8a29d1, 0.8166667, 9, 0.0} |fbe7a11d120b4df4bf23b836de8a29d1|{0.0}  |0.8166667|9                |0      |
|{0074434d3a3a46529f65de8a59631939, 0.8166667, 9, -1.0}|0074434d3a3a46529f65de8a59631939|{-1.0} |0.8166667|9                |0      |
|{9077f996581148978a0ebe2484260dab, 0.8166667, 9, -1.0}|9077f996581148978a0ebe2484260dab|{-1.0} |0.8166667|9                |0      |
+------------------------------------------------------+--------------------------------+-------+---------+-----------------+-------+
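The header columns above line up with the fields of the `json` struct: in this output `probLog` equals `_label_probability`, `chosenActionIndex` equals `_labelIndex`, and `rewards.reward` equals `_label_cost`. The following plain-Python sketch illustrates that mapping for a single log line. The record is a hypothetical, heavily trimmed DSJSON line (real Personalizer logs also carry context and action features), `extract_headers` is an illustrative helper and not part of SynapseML, and the actual transformer may treat missing fields differently.

```python
import json

# Hypothetical, trimmed DSJSON line; field names follow the schema printed above.
sample_line = json.dumps(
    {
        "_label_cost": -1,
        "_label_probability": 0.8166667,
        "_labelIndex": 9,
        "EventId": "example-event-id",
    }
)


def extract_headers(line):
    """Pull the header fields out of one DSJSON line (illustrative helper)."""
    record = json.loads(line)
    return {
        "EventId": record["EventId"],
        "rewards": {"reward": float(record["_label_cost"])},
        "probLog": float(record["_label_probability"]),
        "chosenActionIndex": int(record["_labelIndex"]),
    }


headers = extract_headers(sample_line)
print(headers)
```

Only these few header values are needed later for evaluation, which is why the bulky `input` column can be dropped once the model has been trained.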





#### Model Training

`VowpalWabbitGeneric`:
* trains a model for each split (= group)
* synchronizes across partitions after every split
* stores the 1-step-ahead predictions in the model
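One-step-ahead (progressive validation) predictions are, as usually understood, made for each example before the model learns from that example, so they can be used for evaluation without a separate hold-out set. The sketch below illustrates only that ordering, using a toy running-mean predictor in plain Python. It is not how Vowpal Wabbit is implemented; `RunningMean`, `one_step_ahead`, and the cost values are hypothetical.

```python
class RunningMean:
    """Toy online model: predicts the mean of the labels seen so far."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def predict(self):
        # Before any example has been seen, fall back to 0.0.
        return self.total / self.count if self.count else 0.0

    def learn(self, label):
        self.total += label
        self.count += 1


def one_step_ahead(labels):
    """Predict each example first, then update the model with its label."""
    model = RunningMean()
    predictions = []
    for label in labels:
        predictions.append(model.predict())  # uses earlier examples only
        model.learn(label)
    return predictions


costs = [-1.0, 0.0, -1.0]  # hypothetical costs
print(one_step_ahead(costs))  # [0.0, -1.0, -0.5]
```

Each prediction depends only on earlier examples, which is what makes these predictions usable as an honest estimate of how the model would have behaved on unseen data.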

In [29]:
from synapse.ml.vw import VowpalWabbitGeneric

model = (
    VowpalWabbitGeneric(
        passThroughArgs="--cb_adf --cb_type mtr --clip_p 0.1 -q GT -q MS -q GR -q OT -q MT -q OS --dsjson --preserve_performance_counters"
    )
    .setInputCol("input")
    .setSplitCol("splitId")
    .setPredictionIdCol("EventId")
    .fit(df_ready)
)

23/01/16 13:53:10 WARN VowpalWabbitGeneric: VowpalWabbit args: --cb_adf --cb_type mtr --clip_p 0.1 -q GT -q MS -q GR -q OT -q MT -q OS --dsjson --preserve_performance_counters --no_stdin)
creating quadratic features for pairs: GT MS GR OT MT OS
using no cache
Reading datafile = none
num sources = 0
Num weight bits = 18
learning rate = 0.5
initial_t = 0
power_t = 0.5
cb_type = mtr
Enabled reductions: gd, scorer-identity, csoaa_ldf-rank, cb_adf, shared_feature_merger
Input label = cb
Output pred = action_scores
average  since         example        example        current        current  current
loss     last          counter         weight          label        predict features

finished run
number of examples = 0
weighted example sum = 0.000000
weighted label sum = 0.000000
average loss = n.a.
total feature number = 0
23/01/16 13:53:10 WARN VowpalWabbitGeneric: VowpalWabbit args: --cb_adf --cb_type mtr --clip_p 0.1 -q GT -q MS -q GR -q OT -q MT -q OS --dsjson --preserve_performance_coun

#### Model Prediction

In [39]:
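# df_headers_predictions is built in the cell that joins the headers with the
# one-step-ahead predictions (In [44] below); run that cell first.
df_headers_predictions.printSchema()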

root
 |-- EventId: string (nullable = true)
 |-- input: string (nullable = true)
 |-- json: struct (nullable = true)
 |    |-- EventId: string (nullable = true)
 |    |-- _label_probability: float (nullable = true)
 |    |-- _labelIndex: integer (nullable = true)
 |    |-- _label_cost: float (nullable = true)
 |-- rewards: struct (nullable = false)
 |    |-- reward: float (nullable = true)
 |-- probLog: float (nullable = true)
 |-- chosenActionIndex: integer (nullable = true)
 |-- splitId: integer (nullable = false)
 |-- predictions: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- action: integer (nullable = false)
 |    |    |-- score: float (nullable = false)



In [38]:
df_headers_predictions.drop("input").show(5, False, True)


-RECORD 0---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
 EventId           | fbe7a11d120b4df4bf23b836de8a29d1                                                                                                                                                                                           
 json              | {fbe7a11d120b4df4bf23b836de8a29d1, 0.8166667, 9, 0.0}                                                                                                                                                                      
 rewards           | {0.0}                                                                                                                                                                                                                      
 probLog           | 0.8166667      

                                                                                

In [44]:
from synapse.ml.vw import VowpalWabbitCSETransformer

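# One-step-ahead predictions recorded during training, identified by EventId
df_predictions = model.getOneStepAheadPredictions()
# Keep only the header columns; the raw JSON is not needed for evaluation
df_headers = df_ready.drop("input")

# Attach the predictions to the logged headers
df_headers_predictions = df_headers.join(df_predictions, "EventId")

# Compute the off-policy evaluation metrics
metrics = VowpalWabbitCSETransformer().transform(df_headers_predictions)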

metrics.show(100, False, True)

23/01/16 14:08:10 WARN SimpleFunctionRegistry: The function snips replaced a previously registered function.
23/01/16 14:08:10 WARN SimpleFunctionRegistry: The function ips replaced a previously registered function.
23/01/16 14:08:10 WARN SimpleFunctionRegistry: The function cressieread replaced a previously registered function.
23/01/16 14:08:10 WARN SimpleFunctionRegistry: The function cressiereadinterval replaced a previously registered function.
23/01/16 14:08:10 WARN SimpleFunctionRegistry: The function cressiereadintervalempirical replaced a previously registered function.

-RECORD 0--------------------------------------------------------------------------------------------------------------------
 exampleCount                                        | 3                                                                     
 probPredNonZeroCount                                | 0                                                                     
 minimumImportanceWeight                             | 0.0                                                                   
 maximumImportanceWeight                             | 0.0                                                                   
 averageImportanceWeight                             | 0.0                                                                   
 averageSquaredImportanceWeight                      | 0.0                                                                   
 proportionOfMaximumImportanceWeight                 | 0.0                                                            
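The warnings above show functions named `ips` and `snips` being registered, and the metrics include importance-weight statistics. As a rough illustration of what such estimators compute, here is a plain-Python sketch of the textbook inverse propensity scoring (IPS) and self-normalized IPS (SNIPS) estimators. It assumes the importance weight is the ratio of the evaluated policy's probability for the logged action (`prob_pred`, a hypothetical name) to the logging probability (`prob_log`, corresponding to `probLog`). The exact computation inside `VowpalWabbitCSETransformer` may differ, and all numbers are made up.

```python
def ips_and_snips(rewards, prob_log, prob_pred):
    """Textbook IPS and SNIPS estimates of a policy's average reward."""
    weights = [p_new / p_old for p_new, p_old in zip(prob_pred, prob_log)]
    weighted = sum(w * r for w, r in zip(weights, rewards))
    ips = weighted / len(rewards)
    total_weight = sum(weights)
    # SNIPS is undefined when every importance weight is zero.
    snips = weighted / total_weight if total_weight > 0 else None
    return ips, snips


rewards = [1.0, 0.0, 1.0]     # hypothetical rewards
prob_log = [0.5, 0.5, 0.25]   # hypothetical logging probabilities
prob_pred = [0.5, 0.25, 0.5]  # hypothetical probabilities under the new policy

ips, snips = ips_and_snips(rewards, prob_log, prob_pred)
print(ips, snips)

# If the new policy never picks the logged action, every weight is zero.
print(ips_and_snips(rewards, prob_log, [0.0, 0.0, 0.0]))
```

Under this assumption, a `probPredNonZeroCount` of 0 would mean every importance weight is zero, which is consistent with the minimum, maximum, and average importance weights of 0.0 reported for the three records here.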

In [None]:
metrics.select("reward.*").show(100, False, True)