# Getting started

In [22]:
from pathlib import Path
from typing import Optional

import polars as pl
from ebrec.utils._constants import (
    DEFAULT_HISTORY_ARTICLE_ID_COL,
    DEFAULT_IMPRESSION_ID_COL,
    DEFAULT_SUBTITLE_COL,
    DEFAULT_TITLE_COL,
    DEFAULT_USER_COL,
)
from ebrec.utils._behaviors import (
    create_binary_labels_column,
    sampling_strategy_wu2019,
    add_known_user_column,
    add_prediction_scores,
    truncate_history,
)
from ebrec.utils._articles import convert_text2encoding_with_transformers
from ebrec.utils._articles import create_article_id_to_value_mapping
from ebrec.utils._nlp import get_transformers_word_embeddings
from ebrec.utils._polars import concat_str_columns

from ebrec.models.newsrec.dataloader import NRMSDataLoader
from ebrec.models.newsrec.model_config import hparams_nrms
from ebrec.models.newsrec import NRMSModel

# Deep-learning stack (pretrained language model + Keras training utilities)
from transformers import AutoTokenizer, AutoModel
import tensorflow as tf

## Load dataset

In [3]:
# Folder holding the downloaded parquet files, and the history length that is
# handed to `truncate_history` (and later to the model hyper-parameters).
path = Path("../downloads/ebnerd")
HISTORY_SIZE = 30

# Lazily scan the per-user click history, keep only the user id and the
# history column, and let `truncate_history` bring it to HISTORY_SIZE.
df_history = truncate_history(
    pl.scan_parquet(path / "history.parquet").select(
        DEFAULT_USER_COL, DEFAULT_HISTORY_ARTICLE_ID_COL
    ),
    column=DEFAULT_HISTORY_ARTICLE_ID_COL,
    history_size=HISTORY_SIZE,
    padding_value=0,
)

# Attach each impression's user history (inner join: impressions without a
# history row are dropped), materialize the frame, then add the `labels` column.
df_behaviors = create_binary_labels_column(
    pl.scan_parquet(path / "behaviors.parquet")
    .join(df_history, on=DEFAULT_USER_COL, how="inner")
    .collect(),
    shuffle=True,
    seed=123,
)

# The article table is read eagerly; it is tokenized further down.
df_articles = pl.read_parquet(path / "articles.parquet")



In [4]:
# Preview the first three impressions after the history join and label creation.
df_behaviors.head(3)

impression_id,article_id,impression_time,read_time,scroll_percentage,device_type,article_ids_inview,article_ids_clicked,user_id,is_sso_user,gender,postcode,age,is_subscriber,session_id,next_read_time,next_scroll_percentage,article_id_fixed,labels
u32,i32,datetime[μs],f32,f32,i8,list[i32],list[i32],u32,bool,i8,i8,i8,bool,u32,f32,f32,list[i32],list[i8]
9,,2023-02-27 03:23:27,17.0,,2,"[9652027, 9652025, … 9651983]",[9651983],12258,True,0.0,,40.0,True,70,8.0,30.0,"[9645747, 9645715, … 9647111]","[0, 0, … 1]"
71,,2023-02-24 22:12:39,82.0,,3,"[9642621, 9644724, … 9650042]",[9649392],33638,True,1.0,4.0,60.0,True,938,14.0,17.0,"[9642685, 9642770, … 9646146]","[0, 0, … 0]"
151,9650148.0,2023-02-25 06:37:19,22.0,100.0,1,"[9594794, 9650097, … 9649734]",[9642713],44038,True,,,,True,1153,66.0,100.0,"[9644413, 9646081, … 9646154]","[0, 0, … 0]"


In [5]:
# Preview the first three rows of the articles table as read from disk.
df_articles.head(3)

article_id,title,subtitle,last_modified_time,premium,body,published_time,image_ids,article_type,url,ner_clusters,entity_groups,topics,category,subcategory,category_str
i32,str,str,datetime[μs],bool,str,datetime[μs],list[i64],str,str,list[str],list[str],list[str],i16,list[i16],str
4108820,"""Se frække Tria…","""Den norske Par…",2023-06-29 06:36:23,False,"""- Gud. Jeg håb…",2011-03-26 09:12:29,"[3628735, 3628733, 3628734]","""article_defaul…","""https://ekstra…",[],[],[],414,[432],"""underholdning"""
4201730,"""Seks års fængs…","""Rockerlærlinge…",2023-06-29 06:42:51,False,"""Den 20-årige A…",2009-10-09 10:17:15,[3405274],"""article_defaul…","""https://ekstra…",[],[],[],140,[],"""krimi"""
4739365,"""Økonomisk luss…","""Stjernestylist…",2023-06-29 07:07:15,False,"""Det blev en dy…",2013-02-18 08:55:29,[3820962],"""article_defaul…","""https://ekstra…",[],[],[],414,[425],"""underholdning"""


In [6]:
# Hugging Face checkpoint used for both the tokenizer and the word embeddings
# that are later handed to the NRMS model.
TRANSFORMER_MODEL_NAME = "bert-base-multilingual-cased"
# Text columns that get concatenated into a single string column before tokenization.
TEXT_COLUMNS_TO_USE = [DEFAULT_SUBTITLE_COL, DEFAULT_TITLE_COL]
# Passed as `max_length` to the tokenization helper. Note that it is applied to
# the concatenated subtitle+title text, not to the title alone.
MAX_TITLE_LENGTH = 30

# Load the pretrained model and its matching tokenizer (downloads on first use).
transformer_model = AutoModel.from_pretrained(TRANSFORMER_MODEL_NAME)
transformer_tokenizer = AutoTokenizer.from_pretrained(TRANSFORMER_MODEL_NAME)
# Presumably extracts the model's input word-embedding matrix — confirm against
# `get_transformers_word_embeddings`.
word2vec_embedding = get_transformers_word_embeddings(transformer_model)

# NOTE(review): `df_articles` is overwritten in place by the next two calls, so
# re-running this cell operates on the already-augmented frame.
# `cat_cal` holds the name of the concatenated text column returned by the helper.
df_articles, cat_cal = concat_str_columns(df_articles, columns=TEXT_COLUMNS_TO_USE)
# `token_col_title` holds the name of the column with the encoded token ids.
df_articles, token_col_title = convert_text2encoding_with_transformers(
    df_articles, transformer_tokenizer, cat_cal, max_length=MAX_TITLE_LENGTH
)
# Article-id -> token-id lookup that the dataloaders receive as `article_dict`.
article_mapping = create_article_id_to_value_mapping(df=df_articles, value_col=token_col_title)

In [7]:
# Inspect the articles again: the concatenated text column and its token-id
# encoding have been appended as the last two columns.
df_articles.head(5)

article_id,title,subtitle,last_modified_time,premium,body,published_time,image_ids,article_type,url,ner_clusters,entity_groups,topics,category,subcategory,category_str,subtitle-title,subtitle-title_encode_bert-base-multilingual-cased
i32,str,str,datetime[μs],bool,str,datetime[μs],list[i64],str,str,list[str],list[str],list[str],i16,list[i16],str,str,list[i64]
4108820,"""Se frække Tria…","""Den norske Par…",2023-06-29 06:36:23,False,"""- Gud. Jeg håb…",2011-03-26 09:12:29,"[3628735, 3628733, 3628734]","""article_defaul…","""https://ekstra…",[],[],[],414,[432],"""underholdning""","""Den norske Par…","[10235, 18470, … 10453]"
4201730,"""Seks års fængs…","""Rockerlærlinge…",2023-06-29 06:42:51,False,"""Den 20-årige A…",2009-10-09 10:17:15,[3405274],"""article_defaul…","""https://ekstra…",[],[],[],140,[],"""krimi""","""Rockerlærlinge…","[12158, 10165, … 89494]"
4739365,"""Økonomisk luss…","""Stjernestylist…",2023-06-29 07:07:15,False,"""Det blev en dy…",2013-02-18 08:55:29,[3820962],"""article_defaul…","""https://ekstra…",[],[],[],414,[425],"""underholdning""","""Stjernestylist…","[10838, 59679, … 63400]"
4918926,"""Stjernekokkene…","""Nogle af Danma…",2023-06-29 07:14:22,True,"""Du behøver ikk…",2014-11-25 07:53:06,"[4668258, 4476036, … 4476661]","""article_defaul…","""https://ekstra…",[],[],[],457,[],"""forbrug""","""Nogle af Danma…","[10657, 23239, … 60735]"
5016920,"""Betinget fængs…","""Johnny Hansens…",2023-06-29 07:15:25,False,"""Kandis-forsang…",2014-09-12 09:19:21,"[4640922, 3539276]","""article_defaul…","""https://ekstra…",[],[],[],140,[],"""krimi""","""Johnny Hansens…","[15551, 22126, … 40681]"


In [8]:
# Fixed seed so the random train/validation/test split (and therefore every
# metric reported below) is reproducible under Restart Kernel -> Run All.
# Without it, `DataFrame.sample` draws a different subset on every run.
SPLIT_SEED = 123

# Hold out 10% of all impressions for the final test evaluation.
df_test = df_behaviors.sample(fraction=0.1, seed=SPLIT_SEED)
df_train = df_behaviors.filter(
    ~pl.col(DEFAULT_IMPRESSION_ID_COL).is_in(df_test[DEFAULT_IMPRESSION_ID_COL])
)

# Carve 20% of the remaining impressions out for validation; what is left is
# the training set. Splits are separated by impression id.
df_validation = df_train.sample(fraction=0.2, seed=SPLIT_SEED)
df_train = df_train.filter(
    ~pl.col(DEFAULT_IMPRESSION_ID_COL).is_in(df_validation[DEFAULT_IMPRESSION_ID_COL])
)

In [9]:
# Report (rows, columns) of the train, validation and test splits, one per line.
print(df_train.shape, df_validation.shape, df_test.shape, sep="\n")

(754, 19)
(188, 19)
(104, 19)


In [10]:
# Re-sample the in-view articles of every training impression with the
# wu2019 strategy (npratio=4), then rebuild the `labels` column so it matches
# the newly sampled in-view lists.
df_train = create_binary_labels_column(
    sampling_strategy_wu2019(
        df_train,
        npratio=4,
        shuffle=False,
        with_replacement=True,
        seed=123,
    ).drop("labels"),
    shuffle=True,
    seed=123,
)

In [11]:
# Display the training frame after negative sampling and label re-creation.
df_train

impression_id,article_id,impression_time,read_time,scroll_percentage,device_type,article_ids_inview,article_ids_clicked,user_id,is_sso_user,gender,postcode,age,is_subscriber,session_id,next_read_time,next_scroll_percentage,article_id_fixed,labels
u32,i32,datetime[μs],f32,f32,i8,list[i64],list[i64],u32,bool,i8,i8,i8,bool,u32,f32,f32,list[i32],list[i8]
9,,2023-02-27 03:23:27,17.0,,2,"[9651983, 9652025, … 9652010]",[9651983],12258,true,0,,40,true,70,8.0,30.0,"[9645747, 9645715, … 9647111]","[1, 0, … 0]"
151,9650148,2023-02-25 06:37:19,22.0,100.0,1,"[9650073, 9649654, … 9649734]",[9642713],44038,true,,,,true,1153,66.0,100.0,"[9644413, 9646081, … 9646154]","[0, 0, … 0]"
153,9650148,2023-02-25 06:41:40,14.0,100.0,1,"[9649689, 9649538, … 9649569]",[9649689],44038,true,,,,true,1153,8.0,41.0,"[9644413, 9646081, … 9646154]","[1, 0, … 0]"
154,9650272,2023-02-25 06:30:56,177.0,100.0,1,"[9648409, 9650202, … 9647466]",[9650148],44038,true,,,,true,1153,139.0,100.0,"[9644413, 9646081, … 9646154]","[0, 0, … 0]"
155,9650148,2023-02-25 06:38:57,63.0,100.0,1,"[9649538, 9219607, … 9650040]",[9649538],44038,true,,,,true,1153,98.0,100.0,"[9644413, 9646081, … 9646154]","[1, 0, … 0]"
160,9650148,2023-02-25 06:33:53,139.0,100.0,1,"[9650272, 9647439, … 9650155]",[9647712],44038,true,,,,true,1153,21.0,41.0,"[9644413, 9646081, … 9646154]","[0, 0, … 0]"
200,,2023-02-23 09:16:29,16.0,,2,"[9301117, 9647097, … 9647097]",[9301117],31670,true,0,,40,true,449,127.0,75.0,"[9646081, 9646080, … 9646374]","[1, 0, … 0]"
204,,2023-02-23 09:16:15,10.0,,2,"[9647305, 9646826, … 9647359]",[9646826],31670,true,0,,40,true,449,3.0,35.0,"[9646081, 9646080, … 9646374]","[0, 1, … 0]"
444,9647482,2023-02-27 08:48:54,51.0,100.0,1,"[9647482, 9636613, … 9647482]",[9636613],13992,true,,0,,false,1757,44.0,100.0,"[9639691, 9614001, … 9645510]","[0, 1, … 0]"
766,9652127,2023-02-27 08:52:48,45.0,100.0,2,"[9652169, 9647482, … 9652127]",[9652148],28902,true,0,,,true,2032,33.0,66.0,"[9631127, 9644396, … 9646287]","[0, 0, … 0]"


In [12]:
def _build_nrms_dataloader(behaviors, *, eval_mode, batch_size):
    """Create an NRMSDataLoader over `behaviors` using the notebook's shared settings.

    Only the behaviors frame, the evaluation flag and the batch size differ
    between the train/validation/test loaders; the article lookup, the
    representation for unknown articles and the history column are common.
    """
    return NRMSDataLoader(
        behaviors=behaviors,
        article_dict=article_mapping,
        unknown_representation="zeros",
        history_column=DEFAULT_HISTORY_ARTICLE_ID_COL,
        eval_mode=eval_mode,
        batch_size=batch_size,
    )


# Training runs with eval_mode disabled; validation and test enable it.
train_dataloader = _build_nrms_dataloader(df_train, eval_mode=False, batch_size=64)
val_dataloader = _build_nrms_dataloader(df_validation, eval_mode=True, batch_size=64)
# The test loader uses a smaller batch size than the other two.
test_dataloader = _build_nrms_dataloader(df_test, eval_mode=True, batch_size=16)

In [13]:
MODEL_NAME = "NRMS"
# Output locations for the TensorBoard logs and the checkpointed weights; both
# are relative to the current working directory.
# NOTE(review): the dataset is read from "../downloads/...", while these paths
# start with "downloads/" — confirm this is the intended location.
LOG_DIR = f"downloads/runs/{MODEL_NAME}"
MODEL_WEIGHTS = f"downloads/data/state_dict/{MODEL_NAME}/weights"

# CALLBACKS
# TensorBoard logging, including weight histograms every epoch (histogram_freq=1).
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR, histogram_freq=1)
# Stop once val_loss has not improved for 2 epochs. With epochs=1 below this
# can never trigger; it only matters if the epoch count is raised.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=2)
# Write the weights (not the full model) to MODEL_WEIGHTS, keeping only the best epoch.
modelcheckpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=MODEL_WEIGHTS, save_best_only=True, save_weights_only=True, verbose=1
)

# The shared hyper-parameter object is mutated so the model's history length
# matches the truncated histories; this must happen before the model is built.
hparams_nrms.history_size = HISTORY_SIZE
model = NRMSModel(
    hparams=hparams_nrms,
    word2vec_embedding=word2vec_embedding,
    seed=42,
)
# Train for a single epoch, validating on val_dataloader after the epoch.
hist = model.model.fit(
    train_dataloader,
    validation_data=val_dataloader,
    epochs=1,
    callbacks=[tensorboard_callback, early_stopping, modelcheckpoint],
)
# Restore the checkpointed (best) weights written by `modelcheckpoint`. As the
# last expression of the cell, the returned load-status object is displayed.
model.model.load_weights(filepath=MODEL_WEIGHTS)



Epoch 1: val_loss improved from inf to 0.00000, saving model to data/state_dict/NRMS/weights


<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x329ee32d0>

In [14]:
# Run the model's scorer over the test dataloader; the result is converted
# with `.tolist()` and attached to `df_test` in the following cell.
pred = model.scorer.predict(test_dataloader)




In [24]:
# Attach the predicted scores to the test impressions, then flag which test
# users also occur in the training split.
# NOTE: `df_test` is overwritten, so this cell is meant to run once per prediction.
df_test = add_known_user_column(
    add_prediction_scores(df_test, pred.tolist()),
    known_users=df_train[DEFAULT_USER_COL],
)

In [27]:
from ebrec.evaluation import AucScore, MrrScore, MetricEvaluator, NdcgScore
def compute_evaluation_scores(
    df: pl.DataFrame,
    metric_functions: Optional[list] = None,
    pred_score: str = "scores",
    labels: str = "labels",
) -> dict[str, float]:
    """Evaluate per-impression prediction scores against their binary labels.

    Args:
        df: Frame holding one row per impression, with a list column of
            predicted scores and a list column of ground-truth labels.
        metric_functions: Metric objects handed to ``MetricEvaluator``. When
            omitted, AUC, MRR, nDCG@5 and nDCG@10 are used. The default list
            is built on every call, so no metric instances (or the list
            itself) are shared between calls.
        pred_score: Name of the column with the predicted scores.
        labels: Name of the column with the ground-truth labels.

    Returns:
        The ``evaluations`` attribute of the evaluated ``MetricEvaluator``
        (a mapping from metric name to its value).
    """
    if metric_functions is None:
        # Built here rather than in the signature: a list default would be
        # created once at definition time and shared by every call.
        metric_functions = [
            AucScore(),
            MrrScore(),
            NdcgScore(k=5),
            NdcgScore(k=10),
        ]
    # The evaluator receives plain Python lists, one entry per impression.
    y_pred = df[pred_score].to_list()
    y_true = df[labels].to_list()
    evaluator = MetricEvaluator(
        labels=y_true,
        predictions=y_pred,
        metric_functions=metric_functions,
    )
    return evaluator.evaluate().evaluations

# Evaluate the scored test impressions with the default metric set and default
# column names ("scores" / "labels").
compute_evaluation_scores(df_test)


{'auc': 0.5114077996130365,
 'mrr': 0.33578225373086923,
 'ndcg@5': 0.3944523664383991,
 'ndcg@10': 0.44676071382900684}