# KumoRFM SALT Evaluation

This notebook provides a step-by-step guide on how to evaluate the performance of `KumoRFM` on the [SALT](https://github.com/SAP-samples/salt) dataset.

SALT is designed to reflect real-world customer interactions within an Enterprise Resource Planning (ERP) system, where several fields are commonly missing from sales orders and must be predicted.
It comprises four (anonymized) core tables—sales documents, sales document items, customers, and addresses—amounting to approximately 5 million records in total.
The dataset defines eight missing fields as target variables for multi-class classification tasks.
These tasks are characterized by significant class imbalances, a wide range of class counts (up to 589), and challenges such as label diversity, noise, and distributional drift.

Let's start by installing the necessary packages:

In [None]:
!pip install kumoai --pre --upgrade
!pip install datasets

In [None]:
from kumoai.experimental import rfm

Make sure that we are authenticated:

In [None]:
import os

if not os.environ.get("KUMO_API_KEY"):
    rfm.authenticate()

In [None]:
rfm.init()

## Data Loading and Data Cleaning

The SALT dataset is available via the [Hugging Face dataset platform](https://huggingface.co/datasets/sap-ai-research/SALT).
You will need a Hugging Face access token to download it, which can be created as described [here](https://huggingface.co/docs/hub/en/security-tokens).

In [None]:
import huggingface_hub

HUGGINGFACE_TOKEN = "..."  # TODO: Fill in your Hugging Face access token.
huggingface_hub.login(HUGGINGFACE_TOKEN)

We are now ready to load the dataset into memory. SALT is divided into an artificial training and test split: customers and addresses are shared between the two, while sales documents and sales document items are split by time and therefore hold distinct records. To use this dataset with `KumoRFM`, we reconstruct the original raw data by simply concatenating the train and test records into a single table:

In [None]:
import pandas as pd
from datasets import load_dataset

name = 'sap-ai-research/SALT'

sales = pd.concat([
    load_dataset(name, 'salesdocuments', split='train').to_pandas(),
    load_dataset(name, 'salesdocuments', split='test').to_pandas(),
], axis=0, ignore_index=True)
items = pd.concat([
    load_dataset(name, 'salesdocument_items', split='train').to_pandas(),
    load_dataset(name, 'salesdocument_items', split='test').to_pandas(),
], axis=0, ignore_index=True)
customers = load_dataset(name, 'customers', split='train').to_pandas()
addresses = load_dataset(name, 'addresses', split='train').to_pandas()
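A minimal sketch of this reconstruction on toy data. It illustrates the property that the masking step relies on later: since the test split is concatenated after the training split, the test records are exactly the last rows of the combined table. The frames and values are hypothetical stand-ins, and `toy_*` names are used so the cell does not overwrite the real tables:

```python
import pandas as pd

# Hypothetical stand-ins for the train and test splits of one SALT table
# (toy_* names avoid overwriting the notebook's real variables).
toy_train = pd.DataFrame({"SALESDOCUMENT": ["A", "B"], "PLANT": ["P1", "P2"]})
toy_test = pd.DataFrame({"SALESDOCUMENT": ["C"], "PLANT": ["P1"]})

# Stack the splits row-wise and renumber the index as 0..n-1.
toy_full = pd.concat([toy_train, toy_test], axis=0, ignore_index=True)

# Since the test split is appended last, it occupies the final len(toy_test) rows.
toy_test_rows = toy_full.iloc[-len(toy_test):]
print(toy_test_rows["SALESDOCUMENT"].tolist())  # ['C']
```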

Let's look at a subset of the data:

In [None]:
display(sales.head())

In [None]:
display(items.head())

We need to make a few adjustments and sanitize the data:

1. We merge the `CREATIONDATE` and `CREATIONTIME` columns into a single `CREATIONDATETIME` column:

In [None]:
date = sales['CREATIONDATE'].astype(str)
time = sales['CREATIONTIME'].astype(str)
sales['CREATIONDATETIME'] = pd.to_datetime(date + ' ' + time)
del sales['CREATIONDATE']
del sales['CREATIONTIME']
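The date/time merge can be checked on a toy frame. This sketch assumes that both columns render as ISO-style strings (`YYYY-MM-DD` and `HH:MM:SS`) after `astype(str)`. The actual SALT encodings may differ, in which case an explicit `format=` argument to `pd.to_datetime` would be needed:

```python
import pandas as pd

# Toy frame; the string formats of the two columns are assumptions for illustration.
toy_df = pd.DataFrame({
    "CREATIONDATE": ["2020-01-15", "2020-01-16"],
    "CREATIONTIME": ["13:45:00", "08:05:30"],
})

# Join date and time with a space and parse the result into one datetime column.
toy_df["CREATIONDATETIME"] = pd.to_datetime(
    toy_df["CREATIONDATE"].astype(str) + " " + toy_df["CREATIONTIME"].astype(str)
)
toy_df = toy_df.drop(columns=["CREATIONDATE", "CREATIONTIME"])
print(toy_df["CREATIONDATETIME"].iloc[0])  # 2020-01-15 13:45:00
```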

2. We add this `CREATIONDATETIME` column to the sales document items table as well, so that every item inherits the timestamp of its parent sales document. This gives the item-level SALT tasks a well-defined timestamp, which is necessary to avoid temporal data leakage:

In [None]:
items = pd.merge(
    left=items,
    right=sales[['SALESDOCUMENT', 'CREATIONDATETIME']],
    how='left',
    on='SALESDOCUMENT',
)
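One way to guard this join, sketched on toy tables. The pandas argument `validate='many_to_one'` raises a `MergeError` if a sales document ID appears more than once on the right-hand side, and a `notna` check catches items whose parent document is missing. These checks are an addition, not part of the original pipeline:

```python
import pandas as pd

# Toy parent (sales documents) and child (items) tables.
toy_sales = pd.DataFrame({
    "SALESDOCUMENT": ["S1", "S2"],
    "CREATIONDATETIME": pd.to_datetime(["2020-01-01 10:00", "2020-01-02 11:00"]),
})
toy_items = pd.DataFrame({
    "SALESDOCUMENT": ["S1", "S1", "S2"],
    "PRODUCT": ["X", "Y", "Z"],
})

# validate="many_to_one" raises pandas.errors.MergeError if SALESDOCUMENT is not
# unique in toy_sales, i.e. if an item could inherit more than one timestamp.
toy_items = toy_items.merge(
    toy_sales[["SALESDOCUMENT", "CREATIONDATETIME"]],
    how="left",
    on="SALESDOCUMENT",
    validate="many_to_one",
)

# A left join leaves NaT for items whose document is missing; check there are none.
assert toy_items["CREATIONDATETIME"].notna().all()
```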

3. We remove auto-generated columns:

In [None]:
# Remove auto-generated columns:
del sales['__index_level_0__']
del items['__index_level_0__']
del customers['__index_level_0__']
del addresses['__index_level_0__']

4. We add a primary key to the sales document items table so that individual training and test records can be referenced:

In [None]:
items['ID'] = range(len(items))

5. Both the sales documents and the sales document items tables contain an `INCOTERMSCLASSIFICATION` column, and each one serves as a separate target downstream. We rename them to make the two tasks easy to tell apart:

In [None]:
sales = sales.rename(
    columns={'INCOTERMSCLASSIFICATION': 'HEADERINCOTERMSCLASSIFICATION'})
items = items.rename(
    columns={'INCOTERMSCLASSIFICATION': 'ITEMINCOTERMSCLASSIFICATION'})

According to the SALT evaluation protocol, none of the eight target columns may be used as an input feature. Therefore, for a given task, we drop the other seven target columns.

Note that you can simply change to a different task here by adjusting the `task` variable:

In [None]:
sale_tasks = [
    'SALESOFFICE',
    'SALESGROUP',
    'CUSTOMERPAYMENTTERMS',
    'SHIPPINGCONDITION',
    'HEADERINCOTERMSCLASSIFICATION',
]
item_tasks = [
    'PLANT',
    'SHIPPINGPOINT',
    'ITEMINCOTERMSCLASSIFICATION',
]

task = 'ITEMINCOTERMSCLASSIFICATION'

sales = sales.drop([t for t in sale_tasks if t != task], axis=1)
items = items.drop([t for t in item_tasks if t != task], axis=1)
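The column filtering can also be wrapped in a small reusable function. `drop_other_targets` below is a hypothetical helper, not part of any library; unlike the `drop` calls above, it silently skips targets that are not present in the table instead of raising a `KeyError`:

```python
import pandas as pd

def drop_other_targets(df: pd.DataFrame, targets: list[str], task: str) -> pd.DataFrame:
    """Drop every target column except `task` (hypothetical helper)."""
    to_drop = [t for t in targets if t != task and t in df.columns]
    return df.drop(columns=to_drop)

toy = pd.DataFrame({"ID": [0, 1], "PLANT": ["P1", "P2"], "SHIPPINGPOINT": ["A", "B"]})
toy_item_targets = ["PLANT", "SHIPPINGPOINT", "ITEMINCOTERMSCLASSIFICATION"]

print(drop_other_targets(toy, toy_item_targets, task="PLANT").columns.tolist())
# ['ID', 'PLANT']
```

Passing a sales-level task such as `'SALESOFFICE'` drops every item target present, which matches what the cell above does for the `items` table.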

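Finally, we store the test labels of the given task as ground truth (together with the primary keys of the corresponding test entities) and then mask them out in the table to prevent data leakage: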

In [None]:
if task in sale_tasks:
    num_test = load_dataset(name, 'salesdocuments', split='test').num_rows
    y_test = sales[task].iloc[-num_test:].to_numpy().copy()
    pkey_test = sales['SALESDOCUMENT'].iloc[-num_test:].to_numpy()
    task_pos = sales.columns.get_loc(task)
    sales.iloc[-num_test:, task_pos] = None
elif task in item_tasks:
    num_test = load_dataset(name, 'salesdocument_items', split='test').num_rows
    y_test = items[task].iloc[-num_test:].to_numpy().copy()
    pkey_test = items['ID'].iloc[-num_test:].to_numpy()
    task_pos = items.columns.get_loc(task)
    items.iloc[-num_test:, task_pos] = None
else:
    raise ValueError(f"Unsupported task '{task}'")
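A toy version of the masking step, showing why the labels are copied before masking: `.copy()` detaches `toy_y_test` from the table, so overwriting the test labels with `None` leaves the stored ground truth intact. The table and task are hypothetical:

```python
import pandas as pd

# Toy table whose last `toy_num_test` rows form the test split.
toy_table = pd.DataFrame({"ID": [0, 1, 2, 3], "PLANT": ["P1", "P2", "P1", "P3"]})
toy_task, toy_num_test = "PLANT", 2

# Copy the ground-truth test labels (and their primary keys) before masking.
toy_y_test = toy_table[toy_task].iloc[-toy_num_test:].to_numpy().copy()
toy_pkey_test = toy_table["ID"].iloc[-toy_num_test:].to_numpy()

# Overwrite the test labels with missing values so they cannot leak as context.
toy_table.iloc[-toy_num_test:, toy_table.columns.get_loc(toy_task)] = None
```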

## Graph Creation

We are now ready to convert the SALT dataset into a Kumo `LocalGraph`:

In [None]:
df_dict = {
    'sales': sales,
    'items': items,
    'customers': customers,
    'addresses': addresses,
}
graph = rfm.LocalGraph.from_data(df_dict, infer_metadata=False)

# "PRODUCT" is inferred as "text" column but should be marked as "categorical":
graph['items']['PRODUCT'].stype = 'categorical'

# Assign primary keys and time columns:
graph['sales'].primary_key = 'SALESDOCUMENT'
graph['sales'].time_column = 'CREATIONDATETIME'
graph['items'].primary_key = 'ID'
graph['items'].time_column = 'CREATIONDATETIME'
graph['customers'].primary_key = 'CUSTOMER'
graph['addresses'].primary_key = 'ADDRESSID'

# Assign edges:
graph.link(src_table='items', fkey='SALESDOCUMENT', dst_table='sales')
graph.link(src_table='items', fkey='SOLDTOPARTY', dst_table='customers')
graph.link(src_table='items', fkey='SHIPTOPARTY', dst_table='customers')
graph.link(src_table='items', fkey='PAYERPARTY', dst_table='customers')
graph.link(src_table='items', fkey='BILLTOPARTY', dst_table='customers')
graph.link(src_table='customers', fkey='ADDRESSID', dst_table='addresses');
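Before relying on these links, it can be worth checking referential integrity with plain pandas, i.e. that every foreign-key value points to an existing primary key. `dangling_fkeys` is a hypothetical helper shown on toy tables; we have not checked whether SALT contains dangling keys or how `LocalGraph` would treat them:

```python
import pandas as pd

def dangling_fkeys(src: pd.DataFrame, fkey: str, dst: pd.DataFrame, pkey: str) -> set:
    """Return foreign-key values in `src` that have no matching primary key in `dst`
    (hypothetical helper; missing foreign-key values are ignored)."""
    values = set(src[fkey].dropna())
    return values - set(dst[pkey])

toy_items = pd.DataFrame({"SOLDTOPARTY": ["C1", "C2", "C9", None]})
toy_customers = pd.DataFrame({"CUSTOMER": ["C1", "C2", "C3"]})

print(dangling_fkeys(toy_items, "SOLDTOPARTY", toy_customers, "CUSTOMER"))  # {'C9'}
```

On the real data, the same call could be repeated for each `(src_table, fkey, dst_table)` triple passed to `graph.link`.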

Let's ensure that everything is set up correctly:

In [None]:
graph.print_metadata()
graph.print_links()

In [None]:
graph.visualize(show_columns=False);

## Model Execution

Once the `LocalGraph` is set up, we are ready to make predictions and evaluate the performance of `KumoRFM`.

Let's load the model:

In [None]:
model = rfm.KumoRFM(graph)

The SALT tasks are multi-class classification tasks in which we are asked to impute missing values. In `KumoRFM`, we can model this directly via the **Predictive Query Language** by predicting the target column of the corresponding table (`sales` or `items`):

In [None]:
if task in sale_tasks:
    query = f"PREDICT sales.{task} FOR sales.SALESDOCUMENT IN ({{indices}})"
else:
    query = f"PREDICT items.{task} FOR items.ID IN ({{indices}})"
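The doubled braces in the query template are easy to misread: inside an f-string, `{{indices}}` is emitted as a literal `{indices}` placeholder, which `str.format` fills in later. A minimal check using illustrative names (`toy_task`, `toy_query`) so the real `task` and `query` are left untouched:

```python
toy_task = "PLANT"

# The doubled braces survive the f-string as a literal "{indices}" placeholder.
toy_query = f"PREDICT items.{toy_task} FOR items.ID IN ({{indices}})"
print(toy_query)  # PREDICT items.PLANT FOR items.ID IN ({indices})

# str.format later substitutes a comma-separated list of entity IDs.
print(toy_query.format(indices=", ".join(str(i) for i in [3, 7, 42])))
# PREDICT items.PLANT FOR items.ID IN (3, 7, 42)
```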

We are now ready to iterate over the test entities in batches of up to `1000` and obtain their predictions. To speed things up, we only evaluate the first `max_test_steps` batches rather than the full test set. For multi-class classification tasks, `KumoRFM` returns the probabilities of the top-10 most likely classes for each entity.

In [None]:
import tqdm
import numpy as np

batch_size = 1000
max_test_steps = 10

ys_pred = []
steps = list(range(0, len(pkey_test), batch_size))[:max_test_steps]
for i, step in enumerate(tqdm.tqdm(steps)):
    indices = pkey_test[step:step + batch_size].tolist()

    # Sales document keys are quoted as strings; item IDs are plain integers:
    if task in sale_tasks:
        _query = query.format(indices=', '.join(f"'{idx}'" for idx in indices))
    else:
        _query = query.format(indices=', '.join(str(idx) for idx in indices))

    df = model.predict(
        _query,
        run_mode='best',  # Trades runtime in favor of better model performance.
        anchor_time='entity',  # Use entity table time as anchor time.
        num_hops=3,  # Ensure that we reach every table.
        verbose=i == 0,  # Prevent excessive logging.
    )

    # Save the predicted top-10 classes sorted by probability:
    ys_pred.append(df['CLASS'].to_numpy().reshape(len(indices), -1))

y_pred = np.concatenate(ys_pred, axis=0)
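The batching and key-formatting logic of this loop can be tested in isolation, without calling `KumoRFM`. The sketch below pulls it into two hypothetical helpers, `batches` and `format_indices`. Unlike the loop, `format_indices` decides whether to quote a key from its Python type rather than from the task, which assumes that sales document keys are strings and item IDs are integers:

```python
def format_indices(keys) -> str:
    """Render keys for an IN (...) clause: strings are single-quoted, numbers are not
    (hypothetical helper)."""
    return ", ".join(f"'{k}'" if isinstance(k, str) else str(k) for k in keys)

def batches(keys, batch_size: int, max_steps: int):
    """Yield at most `max_steps` consecutive slices of `keys`, each of size <= batch_size."""
    starts = list(range(0, len(keys), batch_size))[:max_steps]
    for start in starts:
        yield keys[start:start + batch_size]

toy_keys = list(range(7))
print([format_indices(b) for b in batches(toy_keys, batch_size=3, max_steps=2)])
# ['0, 1, 2', '3, 4, 5']
print(format_indices(["0000123", "0000456"]))
# '0000123', '0000456'
```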

Finally, we evaluate our predictions by comparing them to the ground-truth test labels. The metric of choice is Mean Reciprocal Rank (MRR), *i.e.* the reciprocal of the rank at which the correct class appears in the predicted list, averaged over all evaluated test entities.

Let's implement it quickly:

In [None]:
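# Compare each ground-truth label against its predicted top-10 classes
# (only the first `len(y_pred)` test entities were evaluated above):
match = y_test[:len(y_pred)].reshape(-1, 1) == y_pred
# 1-based position of the first correct class within each top-10 list:
rank = match.astype(float).argmax(axis=-1) + 1
reciprocal_rank = 1.0 / rank
# Entities whose true class is not among the top-10 contribute 0:
reciprocal_rank[match.sum(axis=-1) == 0.0] = 0.0

print(f'MRR: {reciprocal_rank.mean():.4f}')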
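The same computation, wrapped in a function and run on a small example that can be checked by hand. `mean_reciprocal_rank` is a hypothetical helper; it uses `np.where` with `match.any(...)` in place of the in-place assignment above, which is equivalent:

```python
import numpy as np

def mean_reciprocal_rank(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """MRR over top-k lists: y_pred[i] holds the k predicted classes for entity i,
    best first; a true class missing from the list contributes 0."""
    match = y_true.reshape(-1, 1) == y_pred          # (n, k) boolean hit matrix
    rank = match.argmax(axis=-1) + 1                 # 1-based position of the first hit
    rr = np.where(match.any(axis=-1), 1.0 / rank, 0.0)
    return float(rr.mean())

toy_true = np.array(["a", "b", "c"])
toy_pred = np.array([["a", "b"], ["a", "b"], ["a", "b"]])
print(mean_reciprocal_rank(toy_true, toy_pred))  # (1 + 1/2 + 0) / 3 = 0.5
```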

And that's it! If you run the model on all eight tasks, you will observe that `KumoRFM` performs on par with the best baseline reported in the [SALT](https://arxiv.org/html/2501.03413v1) paper, without any task-specific training.
Notably, this evaluation even underestimates the true MRR: any ground-truth class that does not appear among the top-10 predicted classes is assigned a reciprocal rank of `0`.
The only task where `KumoRFM` underperforms slightly is `SALESGROUP`, which we attribute to it having the largest number of classes (589).
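To see why truncating the predictions to the top-10 classes can only lower the score, here is a toy comparison between MRR over a full ranking and over a truncated one. The probability matrix is made up, and `mrr_at_k` is a hypothetical helper that works on class indices. `KumoRFM` itself only returns the top-10 classes in the setup above, so the full-ranking variant is purely illustrative:

```python
import numpy as np

def mrr_at_k(probs: np.ndarray, y_true: np.ndarray, k: int) -> float:
    """MRR when only the k most probable classes are kept; probs has shape
    (n_entities, n_classes) and y_true holds class indices."""
    order = np.argsort(-probs, axis=1)[:, :k]        # top-k class indices, best first
    match = order == y_true.reshape(-1, 1)
    rank = match.argmax(axis=1) + 1                  # 1-based position of the first hit
    return float(np.where(match.any(axis=1), 1.0 / rank, 0.0).mean())

toy_probs = np.array([
    [0.6, 0.3, 0.1],   # true class 0 is ranked 1st
    [0.5, 0.3, 0.2],   # true class 2 is ranked 3rd
])
toy_labels = np.array([0, 2])

full = mrr_at_k(toy_probs, toy_labels, k=toy_probs.shape[1])  # every class ranked
top2 = mrr_at_k(toy_probs, toy_labels, k=2)                   # truncated list
print(full, top2)
```

Here the full ranking gives (1 + 1/3) / 2 ≈ 0.667, while the top-2 list loses the second entity's rank-3 hit and yields 0.5. Truncation can only remove hits, never add them, so MRR computed on a top-k list is a lower bound on the full MRR.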

We can further improve the performance by fine-tuning `KumoRFM` on SALT, but this is a story for another notebook. Happy hacking!