In [0]:
from datetime import datetime, timedelta
import os
import pandas as pd

local_path = "data/"
local_click_file = "search_click.csv"
local_item_file = "item_desc.csv"
local_click_ts_file = "search_click_ts.csv"
local_item_clientid_file = "item_desc_clientid.csv"

In [0]:
def query_search_with_click(output_path=local_click_file, output_file=local_click_file, date_range=7):
    """
    Export search and click data for the last `date_range` days to a CSV file.

    Args:
        output_path (str): The file path to save the CSV.
        date_range (int): The number of days to look back for data.
    """
    end_date = datetime.now() - timedelta(days=1)
    if date_range > 7:
        date_range = 7
    start_date = max(datetime(2025, 3, 18), end_date - timedelta(days=date_range))
    date_range = (end_date - start_date).days

    # Query the data
    query = f"""
        SELECT _token_associate_id AS user_id, 
               click_object_id AS item_id, 
               SUM(click) AS rating
        FROM onedata_us_east_1_shared_dit.nas_raw_lyric_search_dit.ml_search_with_click
        GROUP BY 1, 2
    """
    df = spark.sql(query)
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    df.toPandas().to_csv('/'.join([output_path, output_file]), index=False)

In [0]:
query_search_with_click(output_path=local_path, output_file=local_click_file, date_range=7)

In [0]:
def query_click_description(output_path=local_path, output_file=local_item_file):
    """
    Export click item data to a CSV file.

    Args:
        output_path (str): The file path to save the CSV.
    """
    # SQL query to fetch click item descriptions
    query = """
        SELECT click_object_id AS item_id, 
               click_details_caption AS title  
        FROM onedata_us_east_1_shared_dit.nas_raw_lyric_search_dit.ml_search_with_click
        GROUP BY click_object_id, click_details_caption
    """
    # Execute the query and save the result to a CSV file
    df = spark.sql(query)
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    df.toPandas().to_csv('/'.join([output_path, output_file]), index=False)

In [0]:
query_click_description(output_path=local_path, output_file=local_item_file)

In [0]:
def query_search_with_click_ts(output_path=local_path, output_file=local_click_ts_file):
    """
    Export click item data to a CSV file.

    Args:
        output_path (str): The file path to save the CSV.
    """
    # SQL query to fetch click item descriptions
    query = """
        SELECT _token_associate_id AS user_id,
                click_object_id AS item_id,
                to_unix_timestamp(time_stamp, "yyyy-MM-dd\'T\'HH:mm:ss.SSS\'Z\'") AS unix_time_stamp,
                SUM(click) AS rating
            FROM onedata_us_east_1_shared_dit.nas_raw_lyric_search_dit.ml_search_with_click
            WHERE click_object_id IS NOT NULL
            GROUP BY 1, 2, 3;
    """
    # Execute the query and save the result to a CSV file
    df = spark.sql(query)
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    df.toPandas().to_csv('/'.join([output_path, output_file]), index=False)

In [0]:
query_search_with_click_ts(output_path=local_path, output_file=local_click_ts_file)

In [0]:
def query_click_description_clientid(output_path=local_path, output_file=local_item_clientid_file):
    """
    Export click item data to a CSV file.

    Args:
        output_path (str): The file path to save the CSV.
    """
    # SQL query to fetch click item descriptions
    query = """
        SELECT click_object_id AS item_id, 
               click_details_caption AS title,
               click_client_id
        FROM onedata_us_east_1_shared_dit.nas_raw_lyric_search_dit.ml_search_with_click
        WHERE click_object_id IS NOT NULL
        GROUP BY click_object_id, click_details_caption, click_client_id
    """
    # Execute the query and save the result to a CSV file
    df = spark.sql(query)
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    df.toPandas().to_csv('/'.join([output_path, output_file]), index=False)

In [0]:
query_click_description_clientid(output_path=local_path, output_file=local_item_clientid_file)

In [0]:
def load_click_data(input_path=local_path, input_file=local_click_file):
    """
    Load click data from a CSV file.

    Args:
        input_path (str): The file path to load the CSV from.

    Returns:
        pandas.DataFrame: The loaded click data.
    """
    df = pd.read_csv('/'.join([input_path, input_file]),
                       skiprows=0,
                       names=["user_id", "item_id", "rating"],
                       dtype={"user_id": str, "item_id": str, "rating": str}
                       )
    df["rating"] = pd.to_numeric(df['rating'], errors='coerce').fillna(0).astype(int)
    print(len(df))
    return df

In [0]:
df = load_click_data(input_path=local_path, input_file=local_click_file)

513


In [0]:
def load_item_data(input_path=local_item_file, input_file=local_item_file):
    """
    Load item data from a CSV file.

    Args:
        input_path (str): The file path to load the CSV from.

    Returns:
        pandas.DataFrame: The loaded item data.
    """
    df = pd.read_csv('/'.join([input_path, input_file]),
                       skiprows=0,
                       names=["item_id", "title"],
                       dtype={"item_id": str, "title": str}
                       )
    print(len(df))
    return df

In [0]:
data = load_item_data(input_path=local_path, input_file=local_item_file)

211


In [0]:
def query_click_description_clientid(output_path=local_path, output_file=local_item_clientid_file):
    """
    Export click item data to a CSV file.

    Args:
        output_path (str): The file path to save the CSV.
    """
    # SQL query to fetch click item descriptions
    query = """
        SELECT click_object_id AS item_id, 
               click_details_caption AS title,
               click_client_id
        FROM onedata_us_east_1_shared_dit.nas_raw_lyric_search_dit.ml_search_with_click
        WHERE click_object_id IS NOT NULL
        GROUP BY click_object_id, click_details_caption, click_client_id
    """
    # Execute the query and save the result to a CSV file
    df = spark.sql(query)
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    df.toPandas().to_csv('/'.join([output_path, output_file]), index=False)

In [0]:
df = spark.sql(f"""
SELECT
    view._token_associate_id AS user_id,
    view._id AS item_id,
    view.click AS rating
FROM
    onedata_us_east_1_shared_dit.nas_raw_lyric_search_dit.ml_search_with_click AS view
JOIN
    (
        SELECT
            traceId,
            MAX(resPos) AS max_resPos
        FROM
            onedata_us_east_1_shared_dit.nas_raw_lyric_search_dit.ml_search_with_click
        WHERE
            click_object_id IS NOT NULL
        GROUP BY
            traceId
    ) AS click
ON
    view.traceId = click.traceId
    AND view.resPos <= click.max_resPos
 """)
df.toPandas().to_csv("data/view_click.csv", index=False)

In [0]:
df = spark.sql(f"""
        SELECT DISTINCT
            _token_associate_id as user_id,
            LAST_VALUE(user_agent) OVER (PARTITION BY _token_associate_id ORDER BY time_stamp ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS last_user_agent
        FROM onedata_us_east_1_shared_dit.nas_raw_lyric_search_dit.ml_search_with_click 
 """)
df.toPandas().to_csv("data/user_desc.csv", index=False)

In [0]:
!pip install torch torchvision

Looking in indexes: https://artifactory.us.caas.oneadp.com/artifactory/api/pypi/pypi/simple/
Collecting torchvision
  Downloading https://artifactory.us.caas.oneadp.com/artifactory/api/pypi/pypi/packages/packages/36/63/0722e153fd27d64d5b0af45b5c8cb0e80b35a68cf0130303bc9a8bb095c7/torchvision-0.21.0-cp312-cp312-manylinux1_x86_64.whl (7.2 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/7.2 MB[0m [31m?[0m eta [36m-:--:--[0m
[2K     [91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.1/7.2 MB[0m [31m3.4 MB/s[0m eta [36m0:00:03[0m
[2K     [91m━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.3/7.2 MB[0m [31m4.1 MB/s[0m eta [36m0:00:02[0m
[2K     [91m━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.8/7.2 MB[0m [31m8.6 MB/s[0m eta [36m0:00:01[0m
[2K     [91m━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/7.2 MB[0m [31m24.9 MB/s[0m eta [36m0:00:01[0m
[2K     

### Preprocessing and Cleaning the Data

In [0]:
query = f"""
    SELECT _token_associate_id AS user_id, 
            click_object_id AS item_id, 
            SUM(click) AS rating
    FROM onedata_us_east_1_shared_dit.nas_raw_lyric_search_dit.ml_search_with_click
    WHERE click_object_id is not null
    GROUP BY 1, 2
"""
df = spark.sql(query)

In [0]:
from pyspark.sql import functions as F

# Calculate the mean of the ratings column
ratings_mean = df.groupBy().avg('rating').collect()[0][0]
print(f"Mean rating: {ratings_mean}")

# Modify all ratings less than the mean to 0 and greater than the mean to 1 and using a UDF to apply the transformation
modify_rating_udf = F.udf(lambda x: 0 if x < ratings_mean else 1, 'int')
relevant_df = df.withColumn('rating', modify_rating_udf('rating'))

# Rename rating to label
relevant_df = relevant_df.withColumnRenamed('rating', 'label')

# Displaying the dataframe
display(relevant_df)

Mean rating: 1.1316062176165802


user_id,item_id,label
3829c698-e618-47a4-a560-264450c55e9b,,0
570ade82-fdef-4be5-9193-7c8869834bef,cf885bac-24ce-45fe-85ec-a8e578890b42,1
319ae8a7-9ab6-4c19-9cdd-8d80d1c9d12b,f6676d24a90242689b35c45333ef8d40,0
7864e425-bbea-4537-9425-f807f60bf131,73c5155e042b4f228ce1822c51e52fed,1
818c47d0-4ffd-4e41-a9c0-4430771bfe28,,0
e9252582-aeb7-4993-8c17-a975e520599b,,0
7c287dcc-9efb-40e6-9fa9-94f92c0dd2d5,,0
3d365bbc-267e-42c2-ad46-33635a96c841,b29ed30996ac49c99bef20bd295ef677,1
5cddd9c3-d785-4a8f-b78e-1a569c76596e,b3cc3ceac4d24c2e843aa13078bd2f8e,0
8b215b29-f375-4561-a792-915a4b23315d,417ff59ee1c442b0bffb43ec4847a908,0


In [0]:
# Split the dataframe into train, test, and validation sets
train_df, validation_df, test_df = relevant_df.randomSplit([0.7, 0.2, 0.1], seed=42)

# Show the count of each split to verify the distribution
print(f"Training Dataset Count: {train_df.count()}")
print(f"Validation Dataset Count: {validation_df.count()}")
print(f"Test Dataset Count: {test_df.count()}")

Training Dataset Count: 719
Validation Dataset Count: 172
Test Dataset Count: 74


In [0]:
!conda install pytorch cudatoolkit=11.3 -c pytorch-nightly -y

/bin/bash: line 1: conda: command not found


In [0]:
!pip3 install torchrec-nightly

Looking in indexes: https://artifactory.us.caas.oneadp.com/artifactory/api/pypi/pypi/simple/
Collecting torchrec-nightly
  Downloading https://artifactory.us.caas.oneadp.com/artifactory/api/pypi/pypi/packages/packages/b1/ce/da1e04aa387fdfc28e0541bd9fda735b9f31e346fd18b85c1a1cb8243244/torchrec_nightly-2023.8.20-py310-none-any.whl (390 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/390.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m71.7/390.3 kB[0m [31m4.0 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━[0m [32m378.9/390.3 kB[0m [31m6.6 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m390.3/390.3 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting iopath (from torchrec-nightly)
  Downloading https://artifactory.us.caas.oneadp.com/artifactory/api/pypi/pypi/packages/packa

In [0]:
%pip install scikit-learn

Looking in indexes: https://artifactory.us.caas.oneadp.com/artifactory/api/pypi/pypi/simple/
[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m


In [0]:
query = f"""
    SELECT _token_associate_id AS user_id, 
            click_object_id AS item_id, 
            SUM(click) AS rating
    FROM onedata_us_east_1_shared_dit.nas_raw_lyric_search_dit.ml_search_with_click
    GROUP BY 1, 2
"""
df_click = spark.sql(query).toPandas()

In [0]:
import pandas as pd

In [0]:
MAX_USERS = 500

In [0]:
from sklearn import preprocessing
item_id_encoder = preprocessing.LabelEncoder()
user_id_encoder = preprocessing.LabelEncoder()

In [0]:
user_id_encoder .fit(df_click['user_id'])
len(user_id_encoder.classes_)

286

In [0]:
item_id_encoder .fit(df_click['item_id'])
len(item_id_encoder.classes_)

278

In [0]:
data = dict()
for idx, row in df_click.iterrows():
  if user_id_encoder.transform([row['user_id']])[0] not in data:
    data[user_id_encoder.transform([row['user_id']])[0]] = list(item_id_encoder.transform([row['item_id']]))
  else:
    data[user_id_encoder.transform([row['user_id']])[0]].extend(list(item_id_encoder.transform([row['item_id']])))

In [0]:
data

{54: [277, 163, 191],
 85: [225, 277, 112, 231, 42, 0, 83, 143, 191, 164],
 44: [265, 76, 277, 240, 22, 80],
 129: [114,
  76,
  166,
  205,
  151,
  119,
  277,
  48,
  190,
  144,
  259,
  210,
  33,
  153,
  126,
  87,
  276,
  227,
  253,
  275,
  239,
  256,
  88,
  204,
  164,
  191,
  255,
  258],
 142: [277, 166, 48, 151, 191, 183, 23, 133, 67],
 262: [277, 189],
 138: [277],
 60: [190, 277, 45],
 94: [191, 277, 275],
 152: [54,
  151,
  15,
  141,
  6,
  277,
  247,
  78,
  191,
  87,
  93,
  19,
  77,
  193,
  31,
  169,
  96,
  165],
 45: [259,
  138,
  22,
  166,
  114,
  87,
  190,
  256,
  277,
  191,
  51,
  79,
  84,
  275,
  142,
  220,
  127,
  30,
  169,
  58,
  8,
  48,
  16,
  40,
  56,
  204,
  151,
  93,
  202,
  196,
  47,
  71,
  233,
  119,
  239,
  213,
  108,
  276,
  184,
  63,
  109,
  181,
  106,
  137,
  228,
  163,
  122,
  107,
  133,
  160,
  64,
  136,
  172,
  244,
  139,
  156,
  130,
  192,
  96,
  273,
  208,
  89,
  24,
  175,
  223,
  174,
  11

In [0]:
%pip install --trusted-host artifactory.us.caas.oneadp.com -i https://artifactory.us.caas.oneadp.com/artifactory/api/pypi/pypi/simple/ torch torchvision torchrec fbgemm-gpu

Looking in indexes: https://artifactory.us.caas.oneadp.com/artifactory/api/pypi/pypi/simple/
[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m


In [0]:
%pip install --upgrade pip fbgemm-gpu-nightly-cpu

Looking in indexes: https://artifactory.us.caas.oneadp.com/artifactory/api/pypi/pypi/simple/
[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m


In [0]:
!pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U
!pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/nightly/cu121
!pip3 install torchmetrics==1.0.3
!pip3 install torchrec --index-url https://download.pytorch.org/whl/nightly/cu121

Looking in indexes: https://download.pytorch.org/whl/nightly/cu121
[0mCould not fetch URL https://download.pytorch.org/whl/nightly/cu121/torch/: There was a problem confirming the ssl certificate: HTTPSConnectionPool(host='download.pytorch.org', port=443): Max retries exceeded with url: /whl/nightly/cu121/torch/ (Caused by SSLError(SSLCertVerificationError(1, '[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: self-signed certificate in certificate chain (_ssl.c:1000)'))) - skipping
Looking in indexes: https://download.pytorch.org/whl/nightly/cu121
Looking in indexes: https://artifactory.us.caas.oneadp.com/artifactory/api/pypi/pypi/simple/
Looking in indexes: https://download.pytorch.org/whl/nightly/cu121


In [0]:
dbutils.library.restartPython()

In [0]:
%pip install --trusted-host artifactory.us.caas.oneadp.com -i https://artifactory.us.caas.oneadp.com/artifactory/api/pypi/pypi/simple/ torch torchvision

Looking in indexes: https://artifactory.us.caas.oneadp.com/artifactory/api/pypi/pypi/simple/
[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m


In [0]:
%pip install --trusted-host artifactory.us.caas.oneadp.com -i https://artifactory.us.caas.oneadp.com/artifactory/api/pypi/pypi/simple/ fbgemm-gpu

Looking in indexes: https://artifactory.us.caas.oneadp.com/artifactory/api/pypi/pypi/simple/
[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m


In [0]:
dbutils.library.restartPython()

In [0]:
%pip install --upgrade --no-deps --force-reinstall torch torchvision fbgemm-gpu torchrec --trusted-host artifactory.us.caas.oneadp.com --extra-index-url https://artifactory.us.caas.oneadp.com/artifactory/api/pypi/pypi/simple/


# Successfully installed fbgemm-gpu-1.1.0 torch-2.6.0 torchrec-1.1.0 torchvision-0.21.0

dbutils.library.restartPython()

Looking in indexes: https://artifactory.us.caas.oneadp.com/artifactory/api/pypi/pypi/simple/, https://artifactory.us.caas.oneadp.com/artifactory/api/pypi/pypi/simple/
Collecting torch
  Downloading https://artifactory.us.caas.oneadp.com/artifactory/api/pypi/pypi/packages/packages/e5/35/0c52d708144c2deb595cd22819a609f78fdd699b95ff6f0ebcd456e3c7c1/torch-2.6.0-cp312-cp312-manylinux1_x86_64.whl (766.6 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/766.6 MB[0m [31m?[0m eta [36m-:--:--[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.1/766.6 MB[0m [31m3.2 MB/s[0m eta [36m0:04:04[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.3/766.6 MB[0m [31m4.4 MB/s[0m eta [36m0:02:54[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/766.6 MB[0m [31m10.4 MB/s[0m eta [36m0:01:14[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.2/766.6 MB[0m [31m30.8 MB/s[0m eta [36m0:00:25[0

In [0]:
import torch
import torchrec


[0;31m---------------------------------------------------------------------------[0m
[0;31mOSError[0m                                   Traceback (most recent call last)
File [0;32m/local_disk0/.ephemeral_nfs/envs/pythonEnv-e2204beb-e1db-47d7-9048-9791d2b0172d/lib/python3.12/site-packages/torch/__init__.py:318[0m, in [0;36m_load_global_deps[0;34m()[0m
[1;32m    317[0m [38;5;28;01mtry[39;00m:
[0;32m--> 318[0m     ctypes[38;5;241m.[39mCDLL(global_deps_lib_path, mode[38;5;241m=[39mctypes[38;5;241m.[39mRTLD_GLOBAL)
[1;32m    319[0m     [38;5;66;03m# Workaround slim-wheel CUDA dependency bugs in cusparse and cudnn by preloading nvjitlink[39;00m
[1;32m    320[0m     [38;5;66;03m# and nvrtc. In CUDA-12.4+ cusparse depends on nvjitlink, but does not have rpath when[39;00m
[1;32m    321[0m     [38;5;66;03m# shipped as wheel, which results in OS picking wrong/older version of nvjitlink library[39;00m
[1;32m    322[0m     [38;5;66;03m# if `LD_LIBRARY_PATH` is d

### Sample Data Set

In [0]:
### collect all clicks

query = f"""
    SELECT _token_associate_id AS user_id, 
            click_object_id AS item_id, 
            SUM(click) AS rating
    FROM onedata_us_east_1_shared_dit.nas_raw_lyric_search_dit.ml_search_with_click
    WHERE click_object_id is not null
    GROUP BY 1, 2
"""
df = spark.sql(query)

In [0]:
df.select('user_id').distinct().count()

229

In [0]:
df.select('item_id').distinct().count()

302

In [0]:
df.groupby('rating').count().show()

+------+-----+
|rating|count|
+------+-----+
|     7|    4|
|     6|    5|
|    17|    1|
|    28|    1|
|     5|   12|
|     1|  552|
|     3|   51|
|     8|    2|
|     2|  101|
|     4|   24|
|    13|    1|
|    14|    1|
|    16|    1|
+------+-----+



In [0]:
def shape(data):
    rows, cols = data.count(), len(data.columns)
    shape = (rows, cols)
    return shape

In [0]:
shape(df)

(756, 3)

In [0]:
from pyspark.sql import functions as F

# Calculate the mean of the ratings column
ratings_mean = df.groupBy().avg('rating').collect()[0][0]
print(f"Mean rating: {ratings_mean}")

# Modify all ratings less than the mean to 0 and greater than the mean to 1 and using a UDF to apply the transformation
modify_rating_udf = F.udf(lambda x: 0 if x < ratings_mean else 1, 'int')
relevant_df = df.withColumn('rating', modify_rating_udf('rating'))

# Rename rating to label
relevant_df = relevant_df.withColumnRenamed('rating', 'label')

# Displaying the dataframe
display(relevant_df)

Mean rating: 1.6203703703703705


user_id,item_id,label
321d47b1-78de-438f-b6c3-2479e06d3bd2,7bb2c89a01a746e1b46272901a87605b,1
f30165e7-e7d7-4212-b961-47997b7e53af,9ca43b27a0ed48d699f0e126a9dcbe14,0
35d2b23b-4d4d-4b58-9686-0c57103af869,1c8362ee6cf44ff49990d29f3f783825,1
321d47b1-78de-438f-b6c3-2479e06d3bd2,b4eb1e9b-ce0c-4a6e-ac8c-b78745c6f696,0
cf056cdb-c14b-4fe7-abb4-ea899db0e992,2dfcb5da-8de6-43ad-a597-9caaef87884c,1
cf4d9d6d-9902-47dc-ab48-ec153f6f05e9,3a59d8daae5147c389cdaaffbcad7d64,0
c1e5e4b0-5515-4b95-bdda-2dd12b7f7f6a,dcfb1c00f6414037935ad05e2c69d10c,1
7864e425-bbea-4537-9425-f807f60bf131,fe6c2ca0f8744858836da9117b4cebcc,1
734cacb1-3f7a-4023-a75a-542996432829,b157943c-c33e-41eb-ab4a-e1337ddb70bb,0
b01bb32d-e5ee-4aa2-82f2-9155160d77e1,60c2dd26439a47a09c01ef205baec43d,1


In [0]:
train_df, test_df = relevant_df.randomSplit([0.99, 0.01], seed=42)
display(train_df)

user_id,item_id,label
0016b96c-3042-467c-a00e-14d69b0fd172,0429ea12a5974851ab47620c9d7205c9,0
0016b96c-3042-467c-a00e-14d69b0fd172,2c09c9db5dfb4314af1b9a4ffa28dbb6,0
0016b96c-3042-467c-a00e-14d69b0fd172,3a59d8daae5147c389cdaaffbcad7d64,0
0016b96c-3042-467c-a00e-14d69b0fd172,4766e323575849acb389a0ba42353f89,1
0016b96c-3042-467c-a00e-14d69b0fd172,5d34ef61f54f484689722483803984a8,0
0016b96c-3042-467c-a00e-14d69b0fd172,5d775cc778f84a37a137062294f2a451,0
0016b96c-3042-467c-a00e-14d69b0fd172,ac88a5027bca47408d31fbba4a5146a5,1
0016b96c-3042-467c-a00e-14d69b0fd172,bebe615efd3442469f4b3a3c33cd6d3e,1
0016b96c-3042-467c-a00e-14d69b0fd172,f2359cbbb0b14ed095b3f3f510e56c05,0
001ba9e1-576c-4d2a-84ba-4970ac5e5168,1821f26111ce487e9089d7d85b3098e3,0


In [0]:
import databricks.automl
 
from databricks import automl
summary = automl.classify(train_df, target_col="label", timeout_minutes=30)

2025/03/27 20:16:32 INFO databricks.automl.client.manager: AutoML will optimize for F1 score metric, which is tracked as val_f1_score in the MLflow experiment.
2025/03/27 20:16:33 INFO databricks.automl.client.manager: MLflow Experiment ID: 3674116869511500
2025/03/27 20:16:33 INFO databricks.automl.client.manager: MLflow Experiment: https://adpdc-share1-dev.cloud.databricks.com/?o=233647784655798#mlflow/experiments/3674116869511500
2025/03/27 20:18:18 INFO databricks.automl.client.manager: Data exploration notebook: https://adpdc-share1-dev.cloud.databricks.com/?o=233647784655798#notebook/3674116869511534
2025/03/27 20:47:25 INFO databricks.automl.client.manager: AutoML experiment completed successfully.


Unnamed: 0,Train,Validation,Test
f1_score,0.881,0.509,0.308
recall_score,0.855,0.424,0.275
roc_auc,0.987,0.691,0.628
false_negatives,17.0,19.0,37.0
false_positives,10.0,8.0,26.0
example_count,430.0,118.0,199.0
precision_score,0.909,0.636,0.35
true_positives,100.0,14.0,14.0
precision_recall_auc,0.968,0.457,0.346
true_negatives,303.0,77.0,122.0
