# Getting started with the EB-NeRD

In [76]:
from pathlib import Path
import polars as pl

from ebrec.utils._descriptive_analysis import (
    min_max_impression_time_behaviors, 
    min_max_impression_time_history
)
from ebrec.utils._polars import slice_join_dataframes
from ebrec.utils._behaviors import (
    create_binary_labels_column,
    sampling_strategy_wu2019,
    truncate_history,
)
from ebrec.utils._constants import (
    DEFAULT_HISTORY_ARTICLE_ID_COL,
    DEFAULT_CLICKED_ARTICLES_COL,
    DEFAULT_INVIEW_ARTICLES_COL,
    DEFAULT_USER_COL,
    DEFAULT_HISTORY_IMPRESSION_TIMESTAMP_COL
)

## Load dataset:

In [53]:
PATH = Path("../dataset/data/ebnerd_demo")
data_split = "train"

In [90]:
df_behaviors = pl.scan_parquet(PATH.joinpath(data_split, "behaviors.parquet"))
df_history = pl.scan_parquet(PATH.joinpath(data_split, "history.parquet"))
display(df_history.collect().head())
display(df_behaviors.collect().head())

user_id,impression_time_fixed,scroll_percentage_fixed,article_id_fixed,read_time_fixed
u32,list[datetime[μs]],list[f32],list[i32],list[f32]
13538,"[2023-04-27 10:17:43, 2023-04-27 10:18:01, … 2023-05-17 20:36:34]","[100.0, 35.0, … 100.0]","[9738663, 9738569, … 9769366]","[17.0, 12.0, … 16.0]"
58608,"[2023-04-27 18:48:09, 2023-04-27 18:48:45, … 2023-05-17 19:46:40]","[37.0, 61.0, … null]","[9739362, 9739179, … 9770333]","[2.0, 24.0, … 0.0]"
95507,"[2023-04-27 15:20:28, 2023-04-27 15:20:47, … 2023-05-17 14:57:46]","[60.0, 100.0, … null]","[9739035, 9738646, … 9769450]","[18.0, 29.0, … 0.0]"
106588,"[2023-04-27 08:29:09, 2023-04-27 08:29:26, … 2023-05-16 05:50:52]","[24.0, 57.0, … 100.0]","[9738292, 9738216, … 9747803]","[9.0, 15.0, … 33.0]"
617963,"[2023-04-27 14:42:25, 2023-04-27 14:43:10, … 2023-05-18 02:28:09]","[100.0, 100.0, … 90.0]","[9739035, 9739088, … 9770798]","[45.0, 29.0, … 22.0]"


impression_id,article_id,impression_time,read_time,scroll_percentage,device_type,article_ids_inview,article_ids_clicked,user_id,is_sso_user,gender,postcode,age,is_subscriber,session_id,next_read_time,next_scroll_percentage
u32,i32,datetime[μs],f32,f32,i8,list[i32],list[i32],u32,bool,i8,i8,i8,bool,u32,f32,f32
48401,,2023-05-21 21:06:50,21.0,,2,"[9774516, 9771051, … 9759966]",[9759966],22779,False,,,,False,21,16.0,27.0
152513,9778745.0,2023-05-24 07:31:26,30.0,100.0,1,"[9778669, 9778736, … 9777397]",[9778661],150224,False,,,,False,298,2.0,48.0
155390,,2023-05-24 07:30:33,45.0,,1,"[9778369, 9777856, … 9778448]",[9777856],160892,False,,,,False,401,215.0,100.0
214679,,2023-05-23 05:25:40,33.0,,2,"[9776715, 9776406, … 9776855]",[9776566],1001055,False,,,,False,1357,40.0,47.0
214681,,2023-05-23 05:31:54,21.0,,2,"[9775202, 9776855, … 9776570]",[9776553],1001055,False,,,,False,1358,5.0,49.0


### Check min/max time-stamps in the data-split period

In [80]:
print(f"History: {min_max_impression_time_history(df_history).collect()}")
print(f"Behaviors: {min_max_impression_time_behaviors(df_behaviors).collect()}")

History: shape: (1, 2)
┌─────────────────────┬─────────────────────┐
│ min                 ┆ max                 │
│ ---                 ┆ ---                 │
│ datetime[μs]        ┆ datetime[μs]        │
╞═════════════════════╪═════════════════════╡
│ 2023-04-27 07:00:05 ┆ 2023-05-18 06:59:51 │
└─────────────────────┴─────────────────────┘
Behaviors: shape: (1, 2)
┌─────────────────────┬─────────────────────┐
│ min                 ┆ max                 │
│ ---                 ┆ ---                 │
│ datetime[μs]        ┆ datetime[μs]        │
╞═════════════════════╪═════════════════════╡
│ 2023-05-18 07:00:03 ┆ 2023-05-25 06:59:52 │
└─────────────────────┴─────────────────────┘


## Add History to Behaviors

In [84]:
df_history = df_history.select(DEFAULT_USER_COL, DEFAULT_HISTORY_ARTICLE_ID_COL).pipe(
    truncate_history,
    column=DEFAULT_HISTORY_ARTICLE_ID_COL,
    history_size=30,
    padding_value=0,
    enable_warning=False,
)
df_history.head(5).collect()

user_id,article_id_fixed
u32,list[i32]
13538,"[9767342, 9767751, … 9769366]"
58608,"[9763090, 9765545, … 9770333]"
95507,"[9768802, 9768583, … 9769450]"
106588,"[9751531, 9751633, … 9747803]"
617963,"[9765410, 9759300, … 9770798]"


In [87]:
df_history.collect()

user_id,article_id_fixed
u32,list[i32]
13538,"[9767342, 9767751, … 9769366]"
58608,"[9763090, 9765545, … 9770333]"
95507,"[9768802, 9768583, … 9769450]"
106588,"[9751531, 9751633, … 9747803]"
617963,"[9765410, 9759300, … 9770798]"
750497,"[9746360, 9767746, … 9769244]"
854388,"[9767233, 9766242, … 9768260]"
119480,"[0, 0, … 9747684]"
160892,"[9759345, 9766042, … 9770178]"
168638,"[9769909, 9769743, … 9768321]"


In [83]:
df = slice_join_dataframes(
    df1=df_behaviors.collect(),
    df2=df_history.collect(),
    on=DEFAULT_USER_COL,
    how="left",
)
df

impression_id,article_id,impression_time,read_time,scroll_percentage,device_type,article_ids_inview,article_ids_clicked,user_id,is_sso_user,gender,postcode,age,is_subscriber,session_id,next_read_time,next_scroll_percentage,impression_time_fixed,scroll_percentage_fixed,article_id_fixed,read_time_fixed
u32,i32,datetime[μs],f32,f32,i8,list[i32],list[i32],u32,bool,i8,i8,i8,bool,u32,f32,f32,list[datetime[μs]],list[f32],list[i32],list[f32]
48401,,2023-05-21 21:06:50,21.0,,2,"[9774516, 9771051, … 9759966]",[9759966],22779,false,,,,false,21,16.0,27.0,"[2023-04-27 09:05:54, 2023-04-27 09:06:09, … 2023-05-18 06:26:39]","[28.0, 17.0, … 15.0]","[9738452, 9737521, … 9770541]","[5.0, 4.0, … 7.0]"
152513,9778745,2023-05-24 07:31:26,30.0,100.0,1,"[9778669, 9778736, … 9777397]",[9778661],150224,false,,,,false,298,2.0,48.0,"[2023-04-29 11:34:06, 2023-04-29 11:34:25, … 2023-05-18 06:13:47]","[100.0, 49.0, … 24.0]","[9740087, 9741986, … 9735909]","[18.0, 244.0, … 7.0]"
155390,,2023-05-24 07:30:33,45.0,,1,"[9778369, 9777856, … 9778448]",[9777856],160892,false,,,,false,401,215.0,100.0,"[2023-04-27 09:10:33, 2023-04-27 09:20:25, … 2023-05-17 15:51:19]","[100.0, 20.0, … 100.0]","[9738557, 9738211, … 9770178]","[583.0, 257.0, … 158.0]"
214679,,2023-05-23 05:25:40,33.0,,2,"[9776715, 9776406, … 9776855]",[9776566],1001055,false,,,,false,1357,40.0,47.0,"[2023-04-27 12:12:45, 2023-04-27 12:13:30, … 2023-05-18 05:31:44]","[100.0, 100.0, … 28.0]","[9738777, 9738663, … 9769981]","[35.0, 62.0, … 16.0]"
214681,,2023-05-23 05:31:54,21.0,,2,"[9775202, 9776855, … 9776570]",[9776553],1001055,false,,,,false,1358,5.0,49.0,"[2023-04-27 12:12:45, 2023-04-27 12:13:30, … 2023-05-18 05:31:44]","[100.0, 100.0, … 28.0]","[9738777, 9738663, … 9769981]","[35.0, 62.0, … 16.0]"
214684,,2023-05-23 05:32:21,10.0,,2,"[9776508, 9767490, … 9774840]",[9776508],1001055,false,,,,false,1358,52.0,100.0,"[2023-04-27 12:12:45, 2023-04-27 12:13:30, … 2023-05-18 05:31:44]","[100.0, 100.0, … 28.0]","[9738777, 9738663, … 9769981]","[35.0, 62.0, … 16.0]"
214691,,2023-05-23 05:30:46,18.0,,2,"[9759955, 9776449, … 9775985]",[9776691],1001055,false,,,,false,1358,4.0,37.0,"[2023-04-27 12:12:45, 2023-04-27 12:13:30, … 2023-05-18 05:31:44]","[100.0, 100.0, … 28.0]","[9738777, 9738663, … 9769981]","[35.0, 62.0, … 16.0]"
369958,,2023-05-24 14:25:56,16.0,,2,"[9776023, 9778158, … 7594265]",[9778158],1469458,false,,,,false,1623,0.0,,"[2023-04-27 12:37:50, 2023-04-27 18:59:49, … 2023-05-14 19:05:18]","[null, null, … 67.0]","[9738452, 9739344, … 9764765]","[3.0, 0.0, … 8.0]"
369959,,2023-05-24 14:23:14,161.0,,2,"[9779186, 9779289, … 9779071]",[9779071],1469458,false,,,,false,1623,16.0,,"[2023-04-27 12:37:50, 2023-04-27 18:59:49, … 2023-05-14 19:05:18]","[null, null, … 67.0]","[9738452, 9739344, … 9764765]","[3.0, 0.0, … 8.0]"
370414,,2023-05-24 14:48:54,9.0,,2,"[9779408, 9779377, … 9779007]",[9777182],1470585,false,,,,false,1678,13.0,41.0,"[2023-04-27 10:23:31, 2023-04-27 10:43:15, … 2023-05-17 20:48:39]","[null, 44.0, … 33.0]","[9695098, 9738355, … 9769404]","[0.0, 19.0, … 4.0]"


## Generate labels

Here's an example how to generate binary labels based on ``article_ids_clicked`` and ``article_ids_inview``

In [28]:
df.select(DEFAULT_CLICKED_ARTICLES_COL, DEFAULT_INVIEW_ARTICLES_COL).pipe(
    create_binary_labels_column, shuffle=True, seed=123
).with_columns(pl.col("labels").list.len().name.suffix("_len")).head(5)

article_ids_clicked,article_ids_inview,labels,labels_len
list[i32],list[i32],list[i8],u32
[9759966],"[9142581, 9774461, … 9770028]","[0, 0, … 0]",11
[9778661],"[9778728, 9777397, … 9778657]","[0, 0, … 0]",17
[9777856],"[9778155, 9777856, … 9778226]","[0, 1, … 0]",11
[9776566],"[9776497, 9776071, … 9776855]","[0, 0, … 0]",9
[9776553],"[9771995, 9776570, … 9776246]","[0, 0, … 0]",18


An example using the downsample strategy employed by Wu et al.

In [29]:
NPRATIO = 2
df.select(DEFAULT_CLICKED_ARTICLES_COL, DEFAULT_INVIEW_ARTICLES_COL).pipe(
    sampling_strategy_wu2019, npratio=NPRATIO, shuffle=False, with_replacement=True, seed=123
).pipe(create_binary_labels_column, shuffle=True, seed=123).with_columns(pl.col("labels").list.len().name.suffix("_len")).head(5)

article_ids_clicked,article_ids_inview,labels,labels_len
list[i64],list[i64],list[i8],u32
[9759966],"[9774461, 9775371, 9759966]","[0, 0, 1]",3
[9778661],"[9778661, 9777397, 9778682]","[1, 0, 0]",3
[9777856],"[9777856, 9778351, 9778448]","[1, 0, 0]",3
[9776566],"[9776566, 9776855, 9776808]","[1, 0, 0]",3
[9776553],"[9776449, 9776553, 9776570]","[0, 1, 0]",3
