# Product Recommendation with Feathr

This notebook illustrates how to use the Feathr feature store to create a model that predicts users' ratings for different products on an e-commerce website.

### Model Problem Statement
The e-commerce website has collected past user ratings for various products, as well as data about users and products, such as user age and product category. We want to predict how users will rate a new product so that we can recommend it to the users who are likely to give it a high rating.

### Feature Creation Illustration
In this example, our observation data has a compound entity key: each record is uniquely identified by `user_id` and `product_id`. This gives us three types of features:
1. **User features** that are different for different users but are the same for different products. For example, user age is different for different users but it's product-agnostic.
2. **Product features** that are different for different products but are the same for all the users.
3. **User-to-product** features that are different for different users AND different products. For example, a feature to represent if the user has bought this product before or not.
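To make the three feature types concrete, here is a minimal, library-free sketch. Plain Python dicts stand in for Feathr tables, and all IDs and values are hypothetical. It shows how a compound-key observation row picks up user features by `user_id` alone, product features by `product_id` alone, and a user-to-product feature only by the full pair.

```python
# Hypothetical toy data: each observation row is identified by (user_id, product_id).
observations = [
    {"user_id": 1, "product_id": 10},
    {"user_id": 1, "product_id": 20},
    {"user_id": 2, "product_id": 10},
]

# 1. User features: keyed by user_id only (product-agnostic).
user_features = {1: {"user_age": 30}, 2: {"user_age": 45}}

# 2. Product features: keyed by product_id only (user-agnostic).
product_features = {10: {"product_price": 9.99}, 20: {"product_price": 24.50}}

# 3. User-to-product features: keyed by the full (user_id, product_id) pair.
user_product_features = {(1, 10): {"bought_before": True}}


def join_features(row):
    """Attach user, product, and user-to-product features to one observation row."""
    user_id, product_id = row["user_id"], row["product_id"]
    joined = dict(row)
    joined.update(user_features.get(user_id, {}))
    joined.update(product_features.get(product_id, {}))
    # Pairs with no recorded history default to False.
    joined.update(user_product_features.get((user_id, product_id), {"bought_before": False}))
    return joined


training_rows = [join_features(row) for row in observations]
for r in training_rows:
    print(r)
```

Rows for the same user share the same user features and rows for the same product share the same product features; only the third type varies with both keys.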

In this example, we focus on the first two types of features. We train a model on those features and then use it to predict the ratings users will give to products.

The feature creation flow is shown below:
![Feature Flow](https://github.com/feathr-ai/feathr/blob/main/docs/images/product_recommendation_advanced.jpg?raw=true)

## 2. Configure Feathr Client

In [1]:
import glob
import os
import tempfile
from datetime import datetime, timedelta
from math import sqrt

import pandas as pd
from pyspark.sql import DataFrame
# from interpret.provider import InlineProvider
# from interpret import set_visualize_provider

# set_visualize_provider(InlineProvider())

import feathr
from feathr import (
    FeathrClient,
    BOOLEAN, FLOAT, INT32, ValueType,
    Feature, DerivedFeature, FeatureAnchor,
    BackfillTime, MaterializationSettings,
    FeatureQuery, ObservationSettings,
    RedisSink,
    INPUT_CONTEXT, HdfsSource,
    WindowAggTransformation,
    TypedKey,
)
from feathr.datasets.constants import (
    PRODUCT_RECOMMENDATION_USER_OBSERVATION_URL,
    PRODUCT_RECOMMENDATION_USER_PROFILE_URL,
    PRODUCT_RECOMMENDATION_USER_PURCHASE_HISTORY_URL,
    PRODUCT_RECOMMENDATION_PRODUCT_DETAIL_URL,
)
from feathr.datasets.utils import maybe_download
from feathr.utils.config import generate_config
from feathr.utils.job_utils import get_result_df

print(f"Feathr version: {feathr.__version__}")

Feathr version: 1.0.0


In [2]:
os.environ['SPARK_LOCAL_IP'] = "127.0.0.1"
os.environ['REDIS_PASSWORD'] = "foobared"
PROJECT_NAME = "product_recommendation"

### Initialize Feathr Client

In [3]:
from pathlib import Path
feathr_workspace_folder = Path(f"./{PROJECT_NAME}_feathr_config.yaml")
client = FeathrClient(str(feathr_workspace_folder))

2024-09-05 15:51:25.514 | INFO     | feathr.utils._env_config_reader:get:62 - Config secrets__azure_key_vault__name is not found in the environment variable, configuration file, or the remote key value store. Returning the default value: None.
2024-09-05 15:51:25.515 | INFO     | feathr.utils._env_config_reader:get:62 - Config offline_store__s3__s3_enabled is not found in the environment variable, configuration file, or the remote key value store. Returning the default value: None.
2024-09-05 15:51:25.516 | INFO     | feathr.utils._env_config_reader:get:62 - Config offline_store__adls__adls_enabled is not found in the environment variable, configuration file, or the remote key value store. Returning the default value: None.
2024-09-05 15:51:25.516 | INFO     | feathr.utils._env_config_reader:get:62 - Config offline_store__wasb__wasb_enabled is not found in the environment variable, configuration file, or the remote key value store. Returning the default value: None.
2024-09-05 15:51:25

## 3. Prepare Datasets

In [4]:
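# Local file paths for the datasets.
# The download calls are commented out; uncomment them to fetch the data if it is not already in WORKING_DIR.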
WORKING_DIR = PROJECT_NAME
user_observation_file_path = f"{WORKING_DIR}/user_observation.csv"
user_profile_file_path = f"{WORKING_DIR}/user_profile.csv"
user_purchase_history_file_path = f"{WORKING_DIR}/user_purchase_history.csv"
product_detail_file_path = f"{WORKING_DIR}/product_detail.csv"
# maybe_download(
#     src_url=PRODUCT_RECOMMENDATION_USER_OBSERVATION_URL,
#     dst_filepath=user_observation_file_path,
# )
# maybe_download(
#     src_url=PRODUCT_RECOMMENDATION_USER_PROFILE_URL,
#     dst_filepath=user_profile_file_path,
# )
# maybe_download(
#     src_url=PRODUCT_RECOMMENDATION_USER_PURCHASE_HISTORY_URL,
#     dst_filepath=user_purchase_history_file_path,
# )
# maybe_download(
#     src_url=PRODUCT_RECOMMENDATION_PRODUCT_DETAIL_URL,
#     dst_filepath=product_detail_file_path,
# )

# In local mode, we can use the same data path as the source.
user_observation_source_path = user_observation_file_path
user_profile_source_path = user_profile_file_path
user_purchase_history_source_path = user_purchase_history_file_path
product_detail_source_path = product_detail_file_path

In [5]:
!mkdir -p {PROJECT_NAME}


In [6]:
import pandas as pd
import numpy as np
import random

## 4. Define Shareable Features using Feathr API

### Understand raw datasets
We have four datasets to work with:
* Observation dataset (a.k.a. labeled dataset)
* User profile
* User purchase history
* Product details

In [7]:
# User profile dataset
# Used to generate user features
pd.read_csv(user_profile_file_path).head()

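|   | user_id | age | gift_card_balance | tax_rate | number_of_credit_cards |
| --- | --- | --- | --- | --- | --- |
| 0 | 237 | 40 | 673 | 5 | 1 |
| 1 | 428 | 19 | 1121 | 5 | 1 |
| 2 | 94 | 25 | 105 | 5 | 2 |
| 3 | 382 | 49 | 1005 | 4 | 2 |
| 4 | 367 | 18 | 904 | 4 | 1 |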


In [8]:
# User purchase history dataset.
# Used to generate user features. This is activity type data, so we need to use aggregation to generate features.
pd.read_csv(user_purchase_history_file_path).head()

|   | user_id | purchase_time | purchase_amount | purchase_date |
| --- | --- | --- | --- | --- |
| 0 | 333 | 2024-09-01 | 79.51 | 2024-09-01 |
| 1 | 169 | 2023-09-06 | 12.15 | 2023-09-06 |
| 2 | 241 | 2022-01-02 | 113.65 | 2022-01-02 |
| 3 | 301 | 2023-12-04 | 120.3 | 2023-12-04 |
| 4 | 133 | 2023-05-11 | 133.7 | 2023-05-11 |


In [9]:
# Product detail dataset.
# Used to generate product features.
pd.read_csv(product_detail_file_path).head()

|   | product_id | price | quantity |
| --- | --- | --- | --- |
| 0 | 92 | 133.249225 | 3 |
| 1 | 8 | 183.123882 | 61 |
| 2 | 49 | 130.539794 | 42 |
| 3 | 48 | 195.402699 | 70 |
| 4 | 58 | 73.002042 | 85 |


### What's a feature in Feathr
A feature is an individual measurable property or characteristic of a phenomenon which is sometimes time-sensitive.

In Feathr, a feature is defined by the following characteristics:
* The typed key (a.k.a. entity id): identifies the subject of the feature, e.g. a user id of 123 or a product id of SKU234456.
* The feature name: the unique identifier of the feature, e.g. user_age, total_spending_in_30_days.
* The feature value: the actual value of that aspect at a particular time, e.g. the person's age is 30 in 2022.
* The timestamp: indicates when the event happened, for example, when the user purchased a certain product. This is usually used for point-in-time joins.

Note that this definition takes the perspective of a feature consumer (a person who wants to use a feature): it only describes what a feature looks like. Later sections show how a feature consumer can access features in a very simple way.
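The four characteristics can be pictured as one record per feature value. The sketch uses plain dataclasses; `EntityKey` and `FeatureValue` are hypothetical names for illustration and are not Feathr's own `TypedKey` or `Feature` classes, which have different signatures.

```python
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional


@dataclass(frozen=True)
class EntityKey:
    """Hypothetical stand-in for a typed key: which entity the feature describes."""
    key_column: str  # e.g. "user_id"
    value: Any       # e.g. 123


@dataclass(frozen=True)
class FeatureValue:
    """Hypothetical record of one value of a named feature for a given entity."""
    key: EntityKey
    name: str                     # unique feature name, e.g. "user_age"
    value: Any                    # the feature value itself
    timestamp: Optional[datetime] # when the value was observed; used for point-in-time joins


age_2022 = FeatureValue(
    key=EntityKey("user_id", 123),
    name="user_age",
    value=30,
    timestamp=datetime(2022, 1, 1),
)
print(age_2022)
```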

To define how to produce the feature, we need to specify:
* Feature source: the source data that this feature is based on.
* Transformation: how the source data is transformed into the feature. A transformation is optional when you just want to take a column directly from the source data.

(For more details on feature definition, please refer to the [Feathr Feature Definition Guide](https://feathr-ai.github.io/feathr/concepts/feature-definition.html).)

Note: some features, such as those defined on top of request data, may have no entity key or timestamp.
Such a feature is merely a function/transformation executed against request data at runtime,
for example, the day of the week of the request, calculated by converting the request's UNIX timestamp.
(We won't cover this in the tutorial.)
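As a plain-Python illustration of such a key-less, request-time transformation (not Feathr's request-feature API), the day-of-week example reduces to one conversion. Interpreting the timestamp in UTC is an assumption; a real system might use the user's local time zone.

```python
from datetime import datetime, timezone


def day_of_week(request_unix_ts: int) -> str:
    """Weekday name of a request's UNIX timestamp, interpreted in UTC (assumption)."""
    return datetime.fromtimestamp(request_unix_ts, tz=timezone.utc).strftime("%A")


print(day_of_week(0))  # Thursday (1970-01-01 was a Thursday)
```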

### Define Sources Section with UDFs

A feature is called an anchored feature when it is extracted directly from the source data, rather than computed on top of other features. A feature computed on top of other features is called a derived feature.

A [feature source](https://feathr.readthedocs.io/en/latest/#feathr.Source) is needed for anchored features; it describes the raw data from which the feature values are computed. See [the python API documentation](https://feathr.readthedocs.io/en/latest/#feathr.HdfsSource) for details on each input field.
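The difference between anchored and derived features can be sketched without Feathr. The first row is copied from the user profile preview earlier in this section; the second (user 999) is invented. The actual Feathr transforms are not shown in this notebook, so the rules marked "assumed" are guesses that happen to be consistent with the joined outputs in section 5 (tax rates appear as fractions, and purchasing power is exactly 100 above the gift card balance for the users shown, all of whom have a valid card).

```python
# Raw source rows shaped like the user profile dataset (user 999 is hypothetical).
user_profile_rows = [
    {"user_id": 237, "age": 40, "gift_card_balance": 673, "tax_rate": 5, "number_of_credit_cards": 1},
    {"user_id": 999, "age": 33, "gift_card_balance": 105, "tax_rate": 3, "number_of_credit_cards": 0},
]


def anchored_features(row):
    """Anchored features: computed directly from one source row."""
    return {
        "feature_user_age": row["age"],                              # column taken as-is
        "feature_user_gift_card_balance": row["gift_card_balance"],  # column taken as-is
        "feature_user_tax_rate": row["tax_rate"] / 100,              # assumed: percent -> fraction
        "feature_user_has_valid_credit_card": row["number_of_credit_cards"] > 0,  # assumed rule
    }


def derived_features(anchored):
    """Derived feature: computed from other features rather than from the source data."""
    bonus = 100 if anchored["feature_user_has_valid_credit_card"] else 0  # assumed formula
    return {"feature_user_purchasing_power": anchored["feature_user_gift_card_balance"] + bonus}


features_by_user = {}
for row in user_profile_rows:
    anchored = anchored_features(row)
    features_by_user[row["user_id"]] = {**anchored, **derived_features(anchored)}

print(features_by_user[237])
```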

### Define window aggregation features

[Window aggregation](https://en.wikipedia.org/wiki/Window_function_%28SQL%29) helps us create more powerful features by summarizing a large amount of information. For example, we can compute the *average purchase amount over the last 90 days* from the purchase history to capture a user's recent spending trend.

To create window aggregation features, we define a `WindowAggTransformation` with the following arguments:
1. `agg_expr`: the field/column you want to aggregate. It can be an ANSI SQL expression, e.g. `cast_float(purchase_amount)` to cast `str` values to `float`.
2. `agg_func`: the aggregation function, e.g. `AVG`. See the table below for the full list of supported functions.
3. `window`: the aggregation window size, e.g. `90d` to aggregate over the last 90 days.
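Conceptually, `agg_expr="cast_float(purchase_amount)"`, `agg_func="AVG"`, `window="90d"` means: for each observation, average the purchase amounts of that user over the 90 days ending at the observation's timestamp. The plain-Python sketch below illustrates that computation; it is not how Feathr executes it (Feathr runs this in Spark), and the window boundaries used here (excluding the start, including the end) are an assumption.

```python
from datetime import date, timedelta


def window_avg(purchases, user_id, as_of, window_days=90):
    """AVG of purchase_amount for one user over the `window_days` days ending at `as_of`.

    `purchases` holds (user_id, purchase_date, purchase_amount) tuples with amounts
    stored as strings, so each amount is cast to float, similar to cast_float(...).
    Returns None when the window is empty.
    """
    start = as_of - timedelta(days=window_days)
    amounts = [
        float(amount)  # cast str -> float
        for uid, day, amount in purchases
        if uid == user_id and start < day <= as_of  # assumed boundary semantics
    ]
    return sum(amounts) / len(amounts) if amounts else None


# Hypothetical purchase history.
purchases = [
    (1, date(2024, 1, 10), "10.0"),
    (1, date(2024, 3, 1), "30.0"),
    (1, date(2023, 6, 1), "1000.0"),  # far outside the window for a 2024-03-31 observation
    (2, date(2024, 3, 1), "99.0"),    # different user
]
print(window_avg(purchases, user_id=1, as_of=date(2024, 3, 31)))  # 20.0
```

An empty window yields no value, which is one plausible reason the `feature_user_avg_purchase_for_90days` column is blank for many rows in the joined output later on.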

| Aggregation Type | Input Type | Description |
| --- | --- | --- |
| `SUM`, `COUNT`, `MAX`, `MIN`, `AVG` | Numeric | Applies the numerical operation to the numeric inputs. |
| `MAX_POOLING`, `MIN_POOLING`, `AVG_POOLING` | Numeric Vector | Applies the max/min/avg operation on a per-entry basis for a given collection of numbers. |
| `LATEST` | Any | Returns the latest non-null value from within the defined time window. |
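The pooling and `LATEST` rows in the table are less familiar than `SUM` or `AVG`. This plain-Python sketch illustrates their semantics as described in the table; it is not Feathr code, and it assumes the events have already been restricted to the time window.

```python
def max_pooling(vectors):
    """MAX_POOLING: element-wise max across equal-length numeric vectors."""
    return [max(values) for values in zip(*vectors)]


def avg_pooling(vectors):
    """AVG_POOLING: element-wise mean across equal-length numeric vectors."""
    return [sum(values) / len(values) for values in zip(*vectors)]


def latest(events):
    """LATEST: value of the most recent (timestamp, value) event whose value is not None."""
    non_null = [(ts, value) for ts, value in events if value is not None]
    return max(non_null, key=lambda pair: pair[0])[1] if non_null else None


vectors = [[1.0, 5.0, 2.0], [3.0, 4.0, 0.0]]
print(max_pooling(vectors))  # [3.0, 5.0, 2.0]
print(avg_pooling(vectors))  # [2.0, 4.5, 1.0]
print(latest([(1, "a"), (3, None), (2, "b")]))  # b
```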

After you have defined features and sources, bring them together to build an anchor:

> Note that if the features come directly from the observation data, the `source` argument should be `INPUT_CONTEXT` to indicate that the source of the anchor is the observation data.

## Get Features from the Registry Server

In [10]:
client.list_registered_features(project_name=PROJECT_NAME)

[{'name': 'feature_user_avg_purchase_for_90days',
  'id': '34ed8536-6f7d-44ff-a1e1-e13c11486591',
  'qualifiedName': 'product_recommendation__aggregationFeatures__feature_user_avg_purchase_for_90days'},
 {'name': 'feature_user_age',
  'id': 'd2537e45-a0fa-429f-9d7a-ff7d84ee0625',
  'qualifiedName': 'product_recommendation__anchored_features__feature_user_age'},
 {'name': 'feature_user_tax_rate',
  'id': '20b61ff4-364c-4c3e-9d9c-8bd28d642362',
  'qualifiedName': 'product_recommendation__anchored_features__feature_user_tax_rate'},
 {'name': 'feature_user_gift_card_balance',
  'id': '6400f882-e0fe-4abf-a84e-450b9530f40d',
  'qualifiedName': 'product_recommendation__anchored_features__feature_user_gift_card_balance'},
 {'name': 'feature_user_has_valid_credit_card',
  'id': '8e5bea9f-f363-4dd8-9854-90d396333db9',
  'qualifiedName': 'product_recommendation__anchored_features__feature_user_has_valid_credit_card'},
 {'name': 'feature_product_quantity',
  'id': '4c33fdca-47b5-4d7a-b76e-32768de3

In [11]:
feature_dict = client.get_features_from_registry(project_name=PROJECT_NAME, return_keys=True, verbose=True)

2024-09-05 15:51:26.336 | INFO     | feathr.client:get_features_from_registry:1147 - Get anchor features from registry: 
2024-09-05 15:51:26.339 | INFO     | feathr.client:get_features_from_registry:1153 - {
  "name": "feature_user_avg_purchase_for_90days",
  "featureType": {
    "type": "TENSOR",
    "tensorCategory": "DENSE",
    "dimensionType": [],
    "valType": "FLOAT"
  },
  "key": [
    {
      "keyColumn": "user_id",
      "keyColumnType": "INT",
      "fullName": "product_recommendation.user_id",
      "description": "product_recommendation.user_id",
      "keyColumnAlias": "user_id"
    }
  ],
  "transformation": {
    "defExpr": "cast_float(purchase_amount)",
    "aggFunc": "AVG",
    "window": "90d"
  }
}
2024-09-05 15:51:26.340 | INFO     | feathr.client:get_features_from_registry:1153 - {
  "name": "feature_user_age",
  "featureType": {
    "type": "TENSOR",
    "tensorCategory": "DENSE",
    "dimensionType": [],
    "valType": "INT"
  },
  "key": [
    {
      "keyColum

## List all registered features

In [12]:
[feat.name for feat in list(feature_dict[0].values())]

['feature_user_avg_purchase_for_90days',
 'feature_user_age',
 'feature_user_tax_rate',
 'feature_user_gift_card_balance',
 'feature_user_has_valid_credit_card',
 'feature_product_quantity',
 'feature_product_price',
 'feature_user_purchasing_power']

## List the typed key of each feature

In [13]:
[type_key.key_column for type_keys in list(feature_dict[1].values()) for type_key in type_keys]

['user_id',
 'user_id',
 'user_id',
 'user_id',
 'user_id',
 'product_id',
 'product_id',
 'user_id']
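The two outputs line up one-to-one (each feature has a single key), so the features can be grouped by their key column instead of by whether their names contain "user" or "product". The sketch copies both lists verbatim and assumes they are in the same order; in the notebook, the same pairs could come from `feature_dict[1].items()`, whose values are lists of typed keys with a `key_column` attribute.

```python
from collections import defaultdict

# Copied from the two outputs: feature names and their key columns, assumed to be in the same order.
feature_names = [
    "feature_user_avg_purchase_for_90days",
    "feature_user_age",
    "feature_user_tax_rate",
    "feature_user_gift_card_balance",
    "feature_user_has_valid_credit_card",
    "feature_product_quantity",
    "feature_product_price",
    "feature_user_purchasing_power",
]
key_columns = ["user_id"] * 5 + ["product_id"] * 2 + ["user_id"]

# One list of feature names per join key; each list maps to one FeatureQuery.
features_by_key = defaultdict(list)
for name, key_column in zip(feature_names, key_columns):
    features_by_key[key_column].append(name)

print(dict(features_by_key))
```

Grouping by key avoids relying on a naming convention, which would silently misroute a feature whose name contains neither or both words.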

## 5. Create Training Data using Point-in-Time Correct Feature Join

To create a training dataset using Feathr, we need to provide **feature join settings** that specify which features to join and how they should be joined to the observation data.

Also note that since a `FeatureQuery` accepts only features that share the same join key, we define two query objects, one for the `user_id` key and the other for the `product_id` key, and pass them together to compute the offline features.

To learn more on this topic, please refer to [Point-in-time Correctness document](https://feathr-ai.github.io/feathr/concepts/point-in-time-join.html).
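The core rule of a point-in-time join is that each observation only sees feature values known at or before its own timestamp, which keeps future information out of the training data. Feathr performs this join in Spark; the sketch below only illustrates the rule for one feature with a made-up history, and the tie-breaking (a value stamped on the event date counts as known) is an assumption.

```python
from bisect import bisect_right
from datetime import date

# Hypothetical feature history per user: (effective_date, value) pairs sorted by date.
gift_card_history = {
    1: [(date(2022, 1, 1), 500.0), (date(2022, 6, 1), 800.0), (date(2022, 7, 1), 950.0)],
}


def point_in_time_lookup(history, key, event_date):
    """Return the latest feature value with a date <= event_date, or None if none exists."""
    rows = history.get(key, [])
    dates = [d for d, _ in rows]
    idx = bisect_right(dates, event_date)  # number of rows dated on or before event_date
    return rows[idx - 1][1] if idx > 0 else None


pit_observations = [(1, date(2022, 6, 21)), (1, date(2021, 12, 31))]
for user_id, event_date in pit_observations:
    print(user_id, event_date, point_in_time_lookup(gift_card_history, user_id, event_date))
```

The 2022-06-21 observation gets 800.0 rather than the later 950.0, and an observation dated before any history gets no value.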

In [14]:
key_user_id = feature_dict[1]["feature_user_avg_purchase_for_90days"][0]
key_product_id = feature_dict[1]["feature_product_quantity"][0]

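## A demo for the production scenario
### Extract training features for specific days by controlling `observation_path` in the observation settings

We can extract features for any subset of days, or run the extraction every day. The trick is to control which observations are passed in through `observation_path`:
 - First, we extract features for all observations except those from December 2024 (`2024-12`).
 - Second, we extract features only for the observations from December 2024. To run daily instead, filter the observation data down to the specific day you want.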
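The split relies on string-prefix matching of `yyyy-MM-dd` timestamps. This stdlib sketch mirrors the pandas filters used in the next cells; the helper `split_by_prefix` is hypothetical, and the sample rows are a few (user_id, event_timestamp) pairs taken from the result tables later in this section.

```python
def split_by_prefix(rows, prefix, column="event_timestamp"):
    """Split rows into (matching, non_matching) by the string prefix of a timestamp column.

    A month prefix such as "2024-12" mirrors the monthly split; a full day such as
    "2024-12-15" selects a single day for a daily run.
    """
    matching = [r for r in rows if r[column].startswith(prefix)]
    non_matching = [r for r in rows if not r[column].startswith(prefix)]
    return matching, non_matching


sample_rows = [
    {"user_id": 359, "event_timestamp": "2024-12-15"},
    {"user_id": 143, "event_timestamp": "2024-12-17"},
    {"user_id": 139, "event_timestamp": "2022-06-21"},
]

only_dec, except_dec = split_by_prefix(sample_rows, "2024-12")
only_day, _ = split_by_prefix(sample_rows, "2024-12-15")

# Every row lands in exactly one of the two sets, so the two jobs together cover all observations.
assert len(only_dec) + len(except_dec) == len(sample_rows)
print(len(only_dec), len(except_dec), len(only_day))  # 2 1 1
```

For daily runs, use the full `yyyy-MM-dd` string: a shorter prefix such as `2024-12-1` would also match the 10th through the 19th.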

In [15]:
df = pd.read_csv(user_observation_source_path)

In [16]:
df[~df.event_timestamp.str.startswith("2024-12")].to_csv(
    user_observation_source_path.rpartition("/")[0] + "/user_observation_except_2024-12.csv", index=None
)

In [17]:
df[df.event_timestamp.str.startswith("2024-12")].to_csv(
    user_observation_source_path.rpartition("/")[0] + "/user_observation_only_2024-12.csv", index=None
)

In [18]:
# User-keyed features (including the derived purchasing-power feature) are joined on user_id.
user_feature_query = FeatureQuery(
    feature_list=[feat.name for feat in feature_dict[0].values() if "user" in feat.name],
    key=key_user_id,
)

# Product-keyed features are joined on product_id.
product_feature_query = FeatureQuery(
    feature_list=[feat.name for feat in feature_dict[0].values() if "product" in feat.name],
    key=key_product_id,
)


In [19]:
settings_for_except_2024_12 = ObservationSettings(
    observation_path=user_observation_source_path.rpartition("/")[0] + "/user_observation_except_2024-12.csv",
    event_timestamp_column="event_timestamp",
    timestamp_format="yyyy-MM-dd",
)
client.get_offline_features(
    observation_settings=settings_for_except_2024_12,
    feature_query=[user_feature_query, product_feature_query],
    output_path=user_observation_source_path.rpartition("/")[0] + f"/except_2024_12/product_recommendation_features.avro",
)
client.wait_job_to_finish(timeout_sec=100000)

2024-09-05 15:51:27.005 | INFO     | feathr.spark_provider._localspark_submission:_get_debug_file_name:292 - Spark log path is debug/product_recommendation_feathr_feature_join_job20240905155127
2024-09-05 15:51:27.090 | INFO     | feathr.spark_provider._localspark_submission:_init_args:267 - Spark job: product_recommendation_feathr_feature_join_job is running on local spark with master: local[*].
2024-09-05 15:51:27.150 | INFO     | feathr.spark_provider._localspark_submission:submit_feathr_job:147 - Detail job stdout and stderr are in debug/product_recommendation_feathr_feature_join_job20240905155127/log.
2024-09-05 15:51:27.152 | INFO     | feathr.spark_provider._localspark_submission:submit_feathr_job:157 - Local Spark job submit with pid: 38124.
2024-09-05 15:51:27.153 | INFO     | feathr.spark_provider._localspark_submission:wait_for_completion:167 - 1 local spark job(s) in this La

https://repository.mulesoft.org/nexus/content/repositories/public/ added as a remote repository with the name: repo-1
https://linkedin.jfrog.io/artifactory/open-source/ added as a remote repository with the name: repo-2
Ivy Default Cache set to: /home/cuong/.ivy2/cache
The jars for the packages stored in: /home/cuong/.ivy2/jars
org.apache.spark#spark-avro_2.12 added as a dependency
com.microsoft.sqlserver#mssql-jdbc added as a dependency
com.microsoft.azure#spark-mssql-connector_2.12 added as a dependency
org.apache.logging.log4j#log4j-core added as a dependency
com.typesafe#config added as a dependency
com.fasterxml.jackson.core#jackson-databind added as a dependency
org.apache.hadoop#hadoop-mapreduce-client-core added as a dependency
org.apache.hadoop#hadoop-common added as a dependency
org.apache.hadoop#hadoop-azure added as a dependency
org.apache.avro#avro added as a dependency
org.apache.xbean#xbean-asm6-shaded added as a dependency
org.apache.spark#spark-sql-kafka-0-10_2.12 adde

2024-09-05 15:52:39.235 | INFO     | feathr.spark_provider._localspark_submission:wait_for_completion:194 - Pyspark job Completed


2024-09-05 15:52:40.238 | INFO     | feathr.spark_provider._localspark_submission:wait_for_completion:230 - Spark job with pid 38124 finished in: 73 seconds.


In [20]:
settings_only_2024_12 = ObservationSettings(
    observation_path=user_observation_source_path.rpartition("/")[0] + "/user_observation_only_2024-12.csv",
    event_timestamp_column="event_timestamp",
    timestamp_format="yyyy-MM-dd",
)
client.get_offline_features(
    observation_settings=settings_only_2024_12,
    feature_query=[user_feature_query, product_feature_query],
    output_path=user_observation_source_path.rpartition("/")[0] + f"/only_2024_12/product_recommendation_features.avro",
)
client.wait_job_to_finish(timeout_sec=100000)

2024-09-05 15:52:40.279 | INFO     | feathr.spark_provider._localspark_submission:_get_debug_file_name:292 - Spark log path is debug/product_recommendation_feathr_feature_join_job20240905155240
2024-09-05 15:52:40.279 | INFO     | feathr.spark_provider._localspark_submission:_init_args:267 - Spark job: product_recommendation_feathr_feature_join_job is running on local spark with master: local[*].
2024-09-05 15:52:40.285 | INFO     | feathr.spark_provider._localspark_submission:submit_feathr_job:147 - Detail job stdout and stderr are in debug/product_recommendation_feathr_feature_join_job20240905155240/log.
2024-09-05 15:52:40.286 | INFO     | feathr.spark_provider._localspark_submission:submit_feathr_job:157 - Local Spark job submit with pid: 41941.
2024-09-05 15:52:40.287 | INFO     | feathr.spark_provider._localspark_submission:wait_for_completion:167 - 2 local spark job(s) in this La

https://repository.mulesoft.org/nexus/content/repositories/public/ added as a remote repository with the name: repo-1
https://linkedin.jfrog.io/artifactory/open-source/ added as a remote repository with the name: repo-2
Ivy Default Cache set to: /home/cuong/.ivy2/cache
The jars for the packages stored in: /home/cuong/.ivy2/jars
org.apache.spark#spark-avro_2.12 added as a dependency
com.microsoft.sqlserver#mssql-jdbc added as a dependency
com.microsoft.azure#spark-mssql-connector_2.12 added as a dependency
org.apache.logging.log4j#log4j-core added as a dependency
com.typesafe#config added as a dependency
com.fasterxml.jackson.core#jackson-databind added as a dependency
org.apache.hadoop#hadoop-mapreduce-client-core added as a dependency
org.apache.hadoop#hadoop-common added as a dependency
org.apache.hadoop#hadoop-azure added as a dependency
org.apache.avro#avro added as a dependency
org.apache.xbean#xbean-asm6-shaded added as a dependency
org.apache.spark#spark-sql-kafka-0-10_2.12 adde

2024-09-05 15:53:50.356 | INFO     | feathr.spark_provider._localspark_submission:wait_for_completion:194 - Pyspark job Completed


2024-09-05 15:54:22.393 | INFO     | feathr.spark_provider._localspark_submission:wait_for_completion:230 - Spark job with pid 41941 finished in: 102 seconds.


Let's use the helper function `get_result_df` to download the result and view it:

## Get the result DataFrames and check the `except_2024_12` and `only_2024_12` outputs

In [21]:
res_df_except_2024_12 = get_result_df(client, res_url=user_observation_source_path.rpartition("/")[0] + f"/except_2024_12/product_recommendation_features.avro")
res_df_except_2024_12.head()


|   | user_id | product_id | event_timestamp | Product_rating | feature_user_avg_purchase_for_90days | feature_product_price | feature_product_quantity | feature_user_gift_card_balance | feature_user_has_valid_credit_card | feature_user_tax_rate | feature_user_age | feature_user_purchasing_power |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 139 | 11 | 2022-06-21 | 3 |  | 50.927082 | 13.0 | 1111.0 | True | 0.02 | 28 | 1211.0 |
| 1 | 82 | 44 | 2024-02-22 | 2 |  | 186.664688 | 89.0 | 1174.0 | True | 0.01 | 39 | 1274.0 |
| 2 | 159 | 95 | 2023-06-23 | 3 |  | 10.84915 | 91.0 | 464.0 | True | 0.01 | 22 | 564.0 |
| 3 | 49 | 6 | 2023-02-23 | 4 |  | 67.417374 | 49.0 | 1358.0 | True | 0.03 | 28 | 1458.0 |
| 4 | 155 | 69 | 2024-04-16 | 5 |  | 160.445007 | 17.0 | 767.0 | True | 0.03 | 64 | 867.0 |


In [22]:
# Sanity check: this output should contain no December 2024 observations
assert len(res_df_except_2024_12[res_df_except_2024_12.event_timestamp.str.startswith("2024-12")]) == 0

In [23]:
res_df_only_2024_12 = get_result_df(
    client, 
    res_url=user_observation_source_path.rpartition("/")[0] + f"/only_2024_12/product_recommendation_features.avro"
)
res_df_only_2024_12.head()


|   | user_id | product_id | event_timestamp | Product_rating | feature_user_avg_purchase_for_90days | feature_product_price | feature_product_quantity | feature_user_gift_card_balance | feature_user_has_valid_credit_card | feature_user_tax_rate | feature_user_age | feature_user_purchasing_power |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 359 | 4 | 2024-12-15 | 2 |  | 152.512939 | 10.0 | 1997.0 | True | 0.02 | 64 | 2097.0 |
| 1 | 143 | 93 | 2024-12-17 | 3 |  | 66.237854 | 37.0 | 798.0 | True | 0.01 | 59 | 898.0 |
| 2 | 370 | 82 | 2024-12-08 | 4 | 52.130001 | 76.572174 | 65.0 | 275.0 | True | 0.04 | 59 | 375.0 |
| 3 | 293 | 22 | 2024-12-05 | 5 |  | 22.263668 | 57.0 | 1188.0 | True | 0.03 | 25 | 1288.0 |
| 4 | 235 | 26 | 2024-12-01 | 3 | 100.610001 | 85.598175 | 59.0 | 1344.0 | True | 0.05 | 58 | 1444.0 |


In [24]:
# Sanity check: this output should contain only December 2024 observations
assert len(res_df_only_2024_12[~res_df_only_2024_12.event_timestamp.str.startswith("2024-12")]) == 0

In [25]:
res_df = pd.concat([res_df_except_2024_12, res_df_only_2024_12])
res_df.head()

|   | user_id | product_id | event_timestamp | Product_rating | feature_user_avg_purchase_for_90days | feature_product_price | feature_product_quantity | feature_user_gift_card_balance | feature_user_has_valid_credit_card | feature_user_tax_rate | feature_user_age | feature_user_purchasing_power |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 0 | 139 | 11 | 2022-06-21 | 3 |  | 50.927082 | 13.0 | 1111.0 | True | 0.02 | 28 | 1211.0 |
| 1 | 82 | 44 | 2024-02-22 | 2 |  | 186.664688 | 89.0 | 1174.0 | True | 0.01 | 39 | 1274.0 |
| 2 | 159 | 95 | 2023-06-23 | 3 |  | 10.84915 | 91.0 | 464.0 | True | 0.01 | 22 | 564.0 |
| 3 | 49 | 6 | 2023-02-23 | 4 |  | 67.417374 | 49.0 | 1358.0 | True | 0.03 | 28 | 1458.0 |
| 4 | 155 | 69 | 2024-04-16 | 5 |  | 160.445007 | 17.0 | 767.0 | True | 0.03 | 64 | 867.0 |


### Train a machine learning model
After getting all the features, let's train a machine learning model on the features computed by Feathr. Here, we use the **EBM (Explainable Boosting Machine)** regressor from the [InterpretML](https://github.com/interpretml/interpret) package to visualize the modeling results.

In [26]:
from interpret import show
from interpret.glassbox import ExplainableBoostingRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

# Fill None values with 0
final_df = (
    res_df
    .drop(["event_timestamp"], axis=1, errors="ignore")
    .fillna(0)
)

# Split data into train and test
X_train, X_test, y_train, y_test = train_test_split(
    final_df.drop(["Product_rating"], axis=1),
    final_df["Product_rating"].astype("float64"),
    test_size=0.2,
    random_state=42,
)

ebm = ExplainableBoostingRegressor()
ebm.fit(X_train, y_train)

# Visualize the global explanation; show() serves the dashboard on 127.0.0.1/localhost at port 7080.
# Note: InterpretML's visualization dashboard currently doesn't work with the VSCode notebook viewer:
# https://github.com/interpretml/interpret/issues/317
ebm_global = ebm.explain_global()
show(ebm_global)

In [27]:
# Predict and evaluate
y_pred = ebm.predict(X_test)
rmse = sqrt(mean_squared_error(y_test.values.flatten(), y_pred))

print(f"Root mean squared error: {rmse}")

Root mean squared error: 1.6156213274182007
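To make the metric explicit, the sketch recomputes RMSE from its definition, sqrt(mean((y_true - y_pred)^2)), which is what `sqrt(mean_squared_error(...))` evaluates. The ratings and predictions are hypothetical. On this 1-5 rating scale, the RMSE of about 1.6 means the predictions deviate from the true ratings by roughly 1.6 stars in the root-mean-square sense.

```python
from math import sqrt


def rmse_from_scratch(y_true, y_pred):
    """Root mean squared error: sqrt(mean((y_true - y_pred) ** 2))."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    squared_errors = [(t - p) ** 2 for t, p in zip(y_true, y_pred)]
    return sqrt(sum(squared_errors) / len(squared_errors))


# Hypothetical ratings on the 1-5 scale and model predictions.
ratings = [3.0, 2.0, 5.0, 4.0]
predictions = [3.0, 4.0, 4.0, 4.0]
print(rmse_from_scratch(ratings, predictions))  # sqrt((0 + 4 + 1 + 0) / 4) = sqrt(1.25)
```

Applied to `list(y_test)` and `list(y_pred)`, it should agree with the value printed by the previous cell up to floating-point rounding.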


In [28]:
X_test.columns

Index(['user_id', 'product_id', 'feature_user_avg_purchase_for_90days',
       'feature_product_price', 'feature_product_quantity',
       'feature_user_gift_card_balance', 'feature_user_has_valid_credit_card',
       'feature_user_tax_rate', 'feature_user_age',
       'feature_user_purchasing_power'],
      dtype='object')