In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import polars as pl
import numpy as np
from tqdm.auto import tqdm
import torch
from data_process import *
from model import *
import random

In [3]:
# Seed Python's, PyTorch's and NumPy's global RNGs with one shared value so
# repeated runs of the notebook draw the same random numbers.
SEED = 42
for seed_rng in (random.seed, torch.manual_seed, np.random.seed):
    seed_rng(SEED)

**Training model**

In [33]:
print('Loading training data...')

# Read the preprocessed parquet files; each path is bound, in order, to the
# name at the same position on the left-hand side.
behaviors_train, history_train, article, images_embeddings, categories = map(
    pl.read_parquet,
    (
        'preprocess/behaviors_train.parquet',
        'preprocess/history_train.parquet',
        'preprocess/article.parquet',
        'preprocess/image_embeddings.parquet',
        'preprocess/categories_embeddings.parquet',
    ),
)

# load_parquets returns a mapping keyed by entry name; only one entry is used.
article_embeddings = load_parquets(
    'datasets/FacebookAI_xlm_roberta_base.zip'
)['FacebookAI_xlm_roberta_base/xlm_roberta_base']

# Combine the article embeddings with the image embeddings on the
# 'embeddings' column (merge semantics live in merge_article_with_imgs).
article_image_embeddings = merge_article_with_imgs(
    article_embeddings,
    images_embeddings,
    col='embeddings',
)


Loading training data...


In [39]:
# Replace null values in the article frame with 0 and rebind the name.
# NOTE(review): the preview of `article` shown further down still has empty
# `image_ids` cells, so list-typed columns appear to keep their nulls —
# confirm whether downstream code expects those to be filled as well.
article = article.fill_null(0)

In [40]:
# Sanity check: print the (rows, columns) shape of every loaded frame,
# one per line, in load order.
for loaded_frame in (
    behaviors_train,
    history_train,
    article,
    images_embeddings,
    categories,
    article_image_embeddings,
):
    print(loaded_frame.shape)

(24724, 21)
(1590, 9)
(11777, 26)
(102603, 2)
(25, 2)
(125541, 3)


In [41]:
# Render the article frame as the cell output for a visual check of its
# columns and dtypes.
article

article_id,title,subtitle,last_modified_time,premium,body,published_time,image_ids,article_type,url,ner_clusters,entity_groups,topics,category,subcategory,category_str,total_inviews,total_pageviews,total_read_time,sentiment_score,sentiment_label,category_link,ner_clusters_link,topics_link,pub_weekday,pub_hour
i32,str,str,datetime[μs],i64,str,datetime[μs],list[i64],i64,str,list[i64],list[str],list[i64],i64,list[i64],str,i32,i32,f32,f32,i64,i64,list[i64],list[i64],i8,i8
3037230,"""Ishockey-spiller: Jeg troede j…","""ISHOCKEY: Ishockey-spilleren S…",2023-06-29 06:20:57,0,"""Ambitionerne om at komme til U…",2003-08-28 08:55:00,,0,"""https://ekstrabladet.dk/sport/…",[],[],"[4, 3, 9]",2,[],"""sport""",0,0,0.0,0.9752,0,22,[],"[25, 21, … 37]",3,8
3044020,"""Prins Harry tvunget til dna-te…","""Hoffet tvang Prins Harry til a…",2023-06-29 06:21:16,0,"""Den britiske tabloidavis The S…",2005-06-29 08:47:00,"[3097307, 3097197, 3104927]",0,"""https://ekstrabladet.dk/underh…",[],"[""PER"", ""PER""]","[4, 3, … 6]",3,[],"""underholdning""",0,0,0.0,0.7084,0,23,"[11216, 13248]","[25, 21, … 47]",2,8
3057622,"""Rådden kørsel på blå plader""","""Kan ikke straffes: Udenlandske…",2023-06-29 06:21:24,0,"""Slingrende spritkørsel. Grove …",2005-10-10 07:20:00,[3047102],0,"""https://ekstrabladet.dk/nyhede…",[],[],[4],0,[0],"""nyheder""",0,0,0.0,0.9236,0,12,[],"[25, 66, 4]",0,7
3073151,"""Mærsk-arvinger i livsfare""","""FANGET I FLODBØLGEN: Skibsrede…",2023-06-29 06:21:38,0,"""To oldebørn af skibsreder Mærs…",2005-01-04 06:59:00,"[3067474, 3067478, 3153705]",0,"""https://ekstrabladet.dk/nyhede…",[],[],"[1, 8, 5]",0,[0],"""nyheder""",0,0,0.0,0.9945,0,12,[],"[10, 50, … 52]",1,6
3193383,"""Skød svigersøn gennem babydyne""","""44-årig kvinde tiltalt for dra…",2023-06-29 06:22:57,0,"""En 44-årig mormor blev i dag f…",2003-09-15 15:30:00,,0,"""https://ekstrabladet.dk/krimi/…",[],[],"[4, 6]",1,[],"""krimi""",0,0,0.0,0.9966,0,9,[],"[25, 47]",0,15
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
9803492,"""Vilde billeder: Vulkan i udbru…","""Der er gang i vulkanen på Hawa…",2023-06-29 06:49:26,0,"""Det spyer med lava fra vulkane…",2023-06-08 05:49:20,"[9803493, 9803494, … 9803494]",0,"""https://ekstrabladet.dk/nyhede…",[],"[""LOC"", ""LOC"", … ""ORG""]",[],0,[0],"""nyheder""",535989,100120,4.112624e6,0.6095,1,12,"[11285, 15434, … 30071]","[20, 72, 59]",3,5
9803505,"""Flyvende Antonsen knuser topsp…","""Verdens nummer syv, Chou Tien-…",2023-06-29 06:49:26,0,"""Anders Antonsen har holdt paus…",2023-06-08 05:54:06,[9803516],0,"""https://ekstrabladet.dk/sport/…","[0, 0]","[""PER"", ""PROD"", … ""LOC""]","[3, 0, … 10]",2,[],"""sport""",13320,959,55691.0,0.8884,2,22,"[892, 995, … 28251]","[21, 3, … 58]",3,5
9803525,"""Dansk skuespiller: - Jeg nægte…","""Julie R. Ølgaard fik akut kejs…",2023-06-29 06:49:26,0,"""Mens hun lå søvnløs, lød kakof…",2023-06-08 06:45:46,"[9803518, 9803519, … 9803524]",0,"""https://ekstrabladet.dk/underh…",[],"[""PER"", ""PROD"", … ""MISC""]","[3, 5, 11]",3,[],"""underholdning""",315391,50361,2.550671e6,0.7737,0,23,"[5626, 8166, … 27754]","[21, 32, … 62]",3,6
9803560,"""Så slemt er det: 14.000 huse e…","""Tusindvis af huse står under v…",2023-06-29 06:49:26,0,"""Et område på omkring 600 kvadr…",2023-06-08 06:25:42,,0,"""https://ekstrabladet.dk/nyhede…",[],"[""LOC"", ""LOC"", … ""LOC""]",[7],0,[],"""nyheder""",21318,1237,67514.0,0.9927,0,12,"[7192, 15388, … 30140]","[19, 20, … 49]",3,6


In [48]:
# Fill missing gender / postcode / age values with a -1 sentinel.
# NOTE(review): the training cell further down currently stops with
# `IndexError: index out of range in self`; if any of these columns is used
# as an embedding index inside the model, -1 would be out of range — confirm
# against the model definition.
sentinel_filled = [
    pl.col(column).fill_null(-1) for column in ("gender", "postcode", "age")
]
behaviors_train = behaviors_train.with_columns(sentinel_filled)
behaviors_train

impression_id,article_id,impression_time,read_time,scroll_percentage,device_type,article_ids_inview,article_ids_clicked,user_id,is_sso_user,gender,postcode,age,is_subscriber,session_id,next_read_time,next_scroll_percentage,article_delta_time,number_articles,impression_hour,impression_weekday
u32,i32,datetime[μs],f32,f32,i8,list[i64],list[i64],u32,i8,i8,i8,i8,i8,u32,f32,f64,list[i64],i64,i8,i8
48401,,2023-05-21 21:06:50,21.0,,2,"[9657, 9360, … 8514]",[10],27,0,-1,-1,-1,0,21,16.0,0.27,"[56, 14, … 55]",,21,6
152513,9778745,2023-05-24 07:31:26,30.0,100.0,1,"[10004, 10014, … 9918]",[4],741,0,-1,-1,-1,0,298,2.0,0.48,"[49, 3, … 64]",,7,2
155390,,2023-05-24 07:30:33,45.0,,1,"[9987, 9945, … 9994]",[1],8,0,-1,-1,-1,0,401,215.0,1.0,"[72, 76, … 76]",,7,2
214679,,2023-05-23 05:25:40,33.0,,2,"[9861, 9837, … 9864]",[2],574,0,-1,-1,-1,0,1357,40.0,0.47,"[58, 13, … 19]",,5,1
214681,,2023-05-23 05:31:54,21.0,,2,"[9706, 9864, … 9853]",[5],574,0,-1,-1,-1,0,1358,5.0,0.49,"[95, 24, … 85]",,5,1
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
579983230,,2023-05-22 08:30:52,35.0,,1,"[9736, 9751, … 9746]",[0],69,0,-1,-1,-1,0,170832,7.0,0.28,"[67, 15, … 5]",,8,0
579983231,,2023-05-22 08:31:34,89.0,,1,"[9715, 4894, … 9681]",[19],69,0,-1,-1,-1,0,170832,0.0,0.0,"[80, 98, … 73]",,8,0
579984721,9774541,2023-05-22 08:51:33,123.0,100.0,2,"[9750, 9757, … 9713]",[0],101,0,-1,-1,-1,0,107303,73.0,1.0,"[35, 22, … 20]",,8,0
579984723,9775699,2023-05-22 08:53:36,73.0,100.0,2,"[9467, 8368, … 9379]",[1],101,0,-1,-1,-1,0,107303,101.0,1.0,"[96, 75, … 62]",,8,0


In [49]:
# Wrap the training frames in the project dataset, then batch it with the
# training collate function; batches are reshuffled on every pass.
ds_train = EBDataset(
    behaviors_train,
    history_train,
    article,
    article_image_embeddings,
    categories,
)
dl_train = torch.utils.data.DataLoader(
    dataset=ds_train,
    batch_size=64,
    shuffle=True,
    collate_fn=pad_train,
)

Converting format
Loading embeddings...


In [50]:
# Instantiate the ranking model and move it to the device every later cell
# uses for its tensors.
device = 'cpu'
model = EBRank().to(device)

In [52]:
# Train the model with Adam, lowering the learning rate at fixed epochs and
# reporting the mean per-batch loss and hit rate of each epoch.
optimizer = torch.optim.Adam(model.parameters())
epochs = 10  # Number of epochs to train from scratch

# Make sure the model is in training mode even when this cell is re-run after
# a cell that called model.eval() (the inference cell does).
model.train()

for epoch in range(1, epochs + 1):
    if epoch == 2:
        print('Setting LR to 1e-4')
        set_lr(optimizer, 1e-4)
    if epoch == 5:
        print('Setting LR to 1e-5')
        set_lr(optimizer, 1e-5)

    # Reset the running totals at the start of every epoch. Keeping them
    # across epochs would make the figure printed below a sum over all epochs
    # so far divided by a single epoch's batch count.
    acc_loss = 0
    acc_hit_rate = 0

    for (_, (in_view_len, behavior), history), (clicked, scroll) in tqdm(dl_train):
        optimizer.zero_grad()
        torch.cuda.empty_cache()
        behavior = to_device(behavior, device)
        history = to_device(history, device)
        pred = model(behavior, history)
        c_loss, hit_rate = balance_bce_loss(pred, in_view_len, clicked)
        c_loss.backward()
        optimizer.step()
        # Accumulate the scalar value without rebinding `c_loss` from a
        # tensor to a float.
        acc_loss += c_loss.item()
        acc_hit_rate += hit_rate

    print(f'Epoch {epoch} - Loss: {acc_loss / len(dl_train)} - Hit Rate: {acc_hit_rate / len(dl_train)}')

  0%|          | 0/387 [00:00<?, ?it/s]

IndexError: index out of range in self

**Validating model**

In [None]:
print('Loading validation data...')

behaviors_validation = pl.read_parquet('preprocess/behaviors_validation.parquet')
history_validation = pl.read_parquet('preprocess/history_validation.parquet')

# Reuses the article, embedding and category frames loaded for training.
ds_validation = EBDataset(behaviors_validation, history_validation, article, article_image_embeddings, categories)

# Evaluation does not need shuffling: a fixed batch order keeps the per-batch
# averaged metrics independent of the global RNG state.
dl_validation = torch.utils.data.DataLoader(ds_validation, batch_size=64, collate_fn=pad_train, shuffle=False)

In [None]:
# Evaluate the trained model on the validation loader and report the mean
# per-batch loss and hit rate.
#
# No optimizer is created here: validation never updates weights, and
# rebinding `optimizer` to a fresh Adam instance would discard the state of
# the one used for training.
acc_loss = 0
acc_hit_rate = 0

# Switch to evaluation mode for the duration of the loop and restore the
# previous mode afterwards, so a later training cell is not left in eval mode.
was_training = model.training
model.eval()
try:
    # Gradients are not needed for evaluation; skipping them avoids building
    # autograd graphs for every batch.
    with torch.no_grad():
        for (_, (in_view_len, behavior), history), (clicked, scroll) in tqdm(dl_validation):
            torch.cuda.empty_cache()
            behavior = to_device(behavior, device)
            history = to_device(history, device)
            pred = model(behavior, history)
            c_loss, hit_rate = balance_bce_loss(pred, in_view_len, clicked)
            acc_loss += c_loss.item()
            acc_hit_rate += hit_rate
finally:
    model.train(was_training)

print(f'Validation Loss: {acc_loss / len(dl_validation)} - Hit Rate: {acc_hit_rate / len(dl_validation)}')

**Model Inference**

In [None]:
# Read the preprocessed test frames; each path is bound, in order, to the
# name at the same position on the left-hand side.
behaviors_test, history_test = map(
    pl.read_parquet,
    (
        'preprocess/behaviors_test.parquet',
        'preprocess/history_test.parquet',
    ),
)

# The test dataset is built with labels=False and batched in file order with
# the inference collate function.
ds_test = EBDataset(
    behaviors_test,
    history_test,
    article,
    article_image_embeddings,
    categories,
    labels=False,
)
dl_test = torch.utils.data.DataLoader(
    dataset=ds_test,
    batch_size=10,
    shuffle=False,
    collate_fn=pad_inference,
)

In [None]:
import pickle

# Score every test batch in evaluation mode without tracking gradients and
# collect the interpreted predictions in `res`.
model.eval()

res = []

with torch.no_grad():
    for idx, (in_view_len, behavior), history in tqdm(dl_test):
        # Move both inputs to the model's device right where they are used.
        pred = model(to_device(behavior, device), to_device(history, device))
        batch_results = interpret_inference(idx, pred.cpu().numpy(), in_view_len)
        res.extend(batch_results)

res