In [1]:
import os

import warnings
import glob
import pickle

warnings.filterwarnings('ignore')

In [50]:
import torch 
from transformers4rec import torch as tr
from transformers4rec.torch.ranking_metric import NDCGAt, AvgPrecisionAt, RecallAt
from transformers4rec.torch.utils.examples_utils import wipe_memory
import pandas as pd
from nvtabular.workflow import Workflow

from merlin.schema import Schema
from merlin.io import Dataset
import numpy as np

from transformers4rec import torch as tr

In [1]:
import pandas as pd

In [3]:
INPUT_DATA_DIR = "../../../data/train_10/"
OUTPUT_DIR = "../../../data/train_10/sessions_by_ts/"
TEST_INPUT_DATA_DIR = "../../../data/test_processed/"

In [51]:
train = Dataset(os.path.join(INPUT_DATA_DIR, "processed_nvt/part_0.parquet"))
schema = train.schema
workflow = Workflow.load(os.path.join(INPUT_DATA_DIR, "workflow_etl"))

In [68]:
ARTICLE_ID_MAPPING = pd.read_parquet("../../../data/train_10/workflow_etl/categories/unique.article_id.parquet")["article_id"].to_dict()

In [5]:
schema = schema.excluding_by_name(["day_index", "session_id"])

In [6]:
schema

Unnamed: 0,name,tags,dtype,is_list,is_ragged,properties.freq_threshold,properties.num_buckets,properties.cat_path,properties.max_size,properties.embedding_sizes.dimension,properties.embedding_sizes.cardinality,properties.domain.min,properties.domain.max,properties.domain.name,properties.value_count.min,properties.value_count.max
0,article_id-list,"(Tags.ID, Tags.LIST, Tags.CATEGORICAL, Tags.ITEM)","DType(name='int64', element_type=<ElementType....",True,True,0.0,,.//categories/unique.article_id.parquet,0.0,462.0,24748.0,0.0,24747.0,article_id,2,20
1,is_premium-list,"(Tags.LIST, Tags.CATEGORICAL)","DType(name='int64', element_type=<ElementType....",True,True,0.0,,.//categories/unique.is_premium.parquet,0.0,16.0,5.0,0.0,4.0,is_premium,2,20
2,article_type-list,"(Tags.LIST, Tags.CATEGORICAL)","DType(name='int64', element_type=<ElementType....",True,True,0.0,,.//categories/unique.article_type.parquet,0.0,16.0,16.0,0.0,15.0,article_type,2,20
3,category-list,"(Tags.LIST, Tags.CATEGORICAL)","DType(name='int64', element_type=<ElementType....",True,True,0.0,,.//categories/unique.category.parquet,0.0,16.0,30.0,0.0,29.0,category,2,20
4,topic-list,"(Tags.LIST, Tags.CATEGORICAL)","DType(name='int64', element_type=<ElementType....",True,True,0.0,,.//categories/unique.topic.parquet,0.0,16.0,39.0,0.0,38.0,topic,2,20
5,read_time-list,"(Tags.CONTINUOUS, Tags.LIST)","DType(name='float64', element_type=<ElementTyp...",True,True,,,,,,,,,,2,20
6,topics_count-list,"(Tags.CONTINUOUS, Tags.LIST)","DType(name='int64', element_type=<ElementType....",True,True,,,,,,,,,,2,20
7,sentiment_score-list,"(Tags.CONTINUOUS, Tags.LIST)","DType(name='float64', element_type=<ElementTyp...",True,True,,,,,,,,,,2,20


In [7]:
TOP_K = int(schema['article_id-list'].properties["embedding_sizes"]["cardinality"])

In [8]:
inputs = tr.TabularSequenceFeatures.from_schema(
        schema,
        max_sequence_length=20,
        continuous_projection=64,
        masking="mlm",
        d_output=100,
)

article_id-list article_id-list {<Tags.ID: 'id'>, <Tags.LIST: 'list'>, <Tags.CATEGORICAL: 'categorical'>, <Tags.ITEM: 'item'>}
is_premium-list is_premium-list {<Tags.LIST: 'list'>, <Tags.CATEGORICAL: 'categorical'>}
article_type-list article_type-list {<Tags.LIST: 'list'>, <Tags.CATEGORICAL: 'categorical'>}
category-list category-list {<Tags.LIST: 'list'>, <Tags.CATEGORICAL: 'categorical'>}
topic-list topic-list {<Tags.LIST: 'list'>, <Tags.CATEGORICAL: 'categorical'>}
article_id-list article_id-list {<Tags.ID: 'id'>, <Tags.LIST: 'list'>, <Tags.CATEGORICAL: 'categorical'>, <Tags.ITEM: 'item'>}
is_premium-list is_premium-list {<Tags.LIST: 'list'>, <Tags.CATEGORICAL: 'categorical'>}
article_type-list article_type-list {<Tags.LIST: 'list'>, <Tags.CATEGORICAL: 'categorical'>}
category-list category-list {<Tags.LIST: 'list'>, <Tags.CATEGORICAL: 'categorical'>}
topic-list topic-list {<Tags.LIST: 'list'>, <Tags.CATEGORICAL: 'categorical'>}
read_time-list read_time-list {<Tags.CONTINUOUS: 'cont

In [9]:
# Define XLNetConfig class and set default parameters for HF XLNet config  
transformer_config = tr.XLNetConfig.build(
    d_model=64, n_head=4, n_layer=2, total_seq_length=20
)
# Define the model block including: inputs, masking, projection and transformer block.
body = tr.SequentialBlock(
    inputs, tr.MLPBlock([64]), tr.TransformerBlock(transformer_config, masking=inputs.masking)
)

# Define the evaluation top-N metrics and the cut-offs
metrics = [NDCGAt(top_ks=[20, 40], labels_onehot=True),  
           RecallAt(top_ks=[20, 40], labels_onehot=True)]

# Define a head related to next item prediction task 
head = tr.Head(
    body,
    tr.NextItemPredictionTask(
        weight_tying=True, 
        metrics=metrics
    ),
    inputs=inputs,
)

# Get the end-to-end Model class 
model = tr.Model(head)

In [10]:
per_device_train_batch_size = int(os.environ.get(
    "per_device_train_batch_size", 
    '64'
))

per_device_eval_batch_size = int(os.environ.get(
    "per_device_eval_batch_size", 
    "2"
))

In [11]:
# Set hyperparameters for training 
train_args = tr.T4RecTrainingArguments(data_loader_engine='merlin', 
                                    dataloader_drop_last = True,
                                    gradient_accumulation_steps = 1,
                                    per_device_train_batch_size = per_device_train_batch_size, 
                                    per_device_eval_batch_size = per_device_eval_batch_size,
                                    output_dir = "./tmp", 
                                    learning_rate=0.0005,
                                    lr_scheduler_type='cosine', 
                                    learning_rate_num_cosine_cycles_by_epoch=1.5,
                                    num_train_epochs=5,
                                    predict_top_k=TOP_K,
                                    max_sequence_length=20, 
                                    report_to = [],
                                    logging_steps=500,
                                    no_cuda=False)

In [12]:
# Instantiate the T4Rec Trainer, which manages training and evaluation for the PyTorch API
trainer = tr.Trainer(
    model=model,
    args=train_args,
    schema=schema,
    compute_metrics=True,
)

In [13]:
timestamps = os.listdir(f"{OUTPUT_DIR}/")
timestamps = sorted(timestamps)[1:]
int_list = [int(i) for i in timestamps]
int_list.sort()
sorted_string_list = [str(i) for i in int_list]

In [14]:
#Iterating over days of one week
for idx, time_val in enumerate(sorted_string_list):
    # Set data
    time_index_train = time_val
    time_index_eval = sorted_string_list[idx+1]
    train_paths = glob.glob(os.path.join(OUTPUT_DIR, f"{time_index_train}/train.parquet"))
    eval_paths = glob.glob(os.path.join(OUTPUT_DIR, f"{time_index_eval}/valid.parquet"))

    # Train on day related to time_index
    print('*'*20)
    print("Launch training for day %s are:" %time_val)
    print('*'*20 + '\n')
    trainer.train_dataset_or_path = train_paths
    trainer.reset_lr_scheduler()
    trainer.train()
    trainer.state.global_step +=1
    print('finished')

    # Evaluate on the following day
    trainer.eval_dataset_or_path = eval_paths
    train_metrics = trainer.evaluate(metric_key_prefix='eval')
    print('*'*20)
    print("Eval results for day %s are:\t" %time_index_eval)
    print('\n' + '*'*20 + '\n')
    for key in sorted(train_metrics.keys()):
        print(" %s = %s" % (key, str(train_metrics[key])))
    wipe_memory()
    break

********************
Launch training for day 2 are:
********************



Step,Training Loss
500,7.1724
1000,6.3524
1500,6.205


finished


********************
Eval results for day 3 are:	

********************

 eval_/loss = 6.460936069488525
 eval_/next-item/ndcg_at_20 = 0.034348174929618835
 eval_/next-item/ndcg_at_40 = 0.04883221164345741
 eval_/next-item/recall_at_20 = 0.09379906207323074
 eval_/next-item/recall_at_40 = 0.16405023634433746
 eval_runtime = 46.3453
 eval_samples_per_second = 54.979
 eval_steps_per_second = 27.489


In [49]:
schema['article_id-list']

ColumnSchema(name='article_id-list', tags={<Tags.ID: 'id'>, <Tags.LIST: 'list'>, <Tags.CATEGORICAL: 'categorical'>, <Tags.ITEM: 'item'>}, properties={'freq_threshold': 0.0, 'num_buckets': None, 'cat_path': './/categories/unique.article_id.parquet', 'max_size': 0.0, 'embedding_sizes': {'dimension': 462.0, 'cardinality': 24748.0}, 'domain': {'min': 0, 'max': 24747, 'name': 'article_id'}, 'value_count': {'min': 2, 'max': 20}}, dtype=DType(name='int64', element_type=<ElementType.Int: 'int'>, element_size=64, element_unit=None, signed=True, shape=Shape(dims=(Dimension(min=0, max=None), Dimension(min=2, max=20)))), is_list=True, is_ragged=True)

In [20]:
with open('../../../data/ebnerd_testset/inviews.pkl', 'rb') as file:
    INVIEWS = pickle.load(file)

In [None]:
pd.read_pa

In [21]:
trainer.predict(f"{TEST_INPUT_DATA_DIR}processed_nvt/{1}/part_0.parquet")

KeyboardInterrupt: 

In [41]:
len(INVIEWS)

13536710

In [4]:
pd.read_parquet(f"{TEST_INPUT_DATA_DIR}processed_nvt/{1}/part_0.parquet")

Unnamed: 0,article_id-list,is_premium-list,article_type-list,category-list,topic-list,read_time-list,topics_count-list,sentiment_score-list,session_id,day_index
0,"[3196, 2645, 3220, 3238, 2742, 2628, 3776, 320...","[3, 3, 3, 3, 3, 3, 4, 3, 3, 3]","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3]","[3, 5, 5, 4, 5, 8, 6, 3, 8, 6]","[5, 4, 4, 3, 4, 11, 3, 5, 8, 13]","[3.0, 116.0, 9.0, 16.0, 36.0, 16.0, 11.0, 1.0,...","[5, 2, 3, 7, 3, 2, 6, 3, 2, 2]","[0.5519000291824341, 0.9657999873161316, 0.965...",0,21
1,"[3275, 3169, 2774, 3224, 2935, 1898, 2953, 355...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3]","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3]","[3, 3, 5, 3, 6, 3, 6, 6, 5, 5]","[5, 8, 9, 5, 13, 5, 3, 3, 4, 4]","[8.0, 6.0, 35.0, 1.0, 93.0, 27.0, 25.0, 45.0, ...","[4, 4, 3, 4, 2, 4, 3, 3, 1, 1]","[0.7477999925613403, 0.9742000102996826, 0.962...",1,20
2,"[2935, 3220, 3202, 3281, 2667, 3183, 3238, 319...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3]","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3]","[6, 5, 3, 8, 6, 5, 4, 3, 3, 5]","[13, 4, 5, 5, 3, 4, 3, 5, 5, 4]","[66.0, 1.0, 9.0, 1.0, 4.0, 9.0, 2.0, 2.0, 8.0,...","[2, 3, 3, 6, 3, 1, 7, 5, 4, 3]","[0.722599983215332, 0.965399980545044, 0.69900...",2,21
3,"[3196, 3238, 2935, 3220, 3413, 3183, 2667, 328...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3]","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3]","[3, 4, 6, 5, 3, 5, 6, 8, 3, 5]","[5, 3, 13, 4, 5, 4, 3, 5, 5, 4]","[2.0, 2.0, 66.0, 2.0, 8.0, 9.0, 4.0, 1.0, 9.0,...","[5, 7, 2, 3, 4, 1, 3, 6, 3, 3]","[0.5519000291824341, 0.9957000017166138, 0.722...",3,21
4,"[2935, 3413, 3281, 3238, 3196, 3183, 3220, 322...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3]","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3]","[6, 3, 8, 4, 3, 5, 5, 5, 3, 6]","[13, 5, 5, 3, 5, 4, 4, 4, 5, 3]","[66.0, 8.0, 1.0, 2.0, 2.0, 9.0, 1.0, 2.0, 9.0,...","[2, 4, 6, 7, 5, 1, 3, 3, 3, 3]","[0.722599983215332, 0.8378000259399414, 0.9197...",4,21
...,...,...,...,...,...,...,...,...,...,...
1353665,"[3478, 3089, 3478, 3478, 2517, 3631, 3478, 323...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3]","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3]","[4, 3, 4, 4, 4, 4, 4, 4, 7, 4]","[3, 3, 3, 3, 3, 4, 3, 3, 6, 5]","[16.0, 46.0, 9.0, 3.0, 115.0, 11.0, 14.0, 8.0,...","[6, 3, 6, 6, 4, 4, 6, 7, 5, 5]","[0.9926000237464905, 0.6837999820709229, 0.992...",1353665,21
1353666,"[3478, 2517, 3089, 2963, 3507, 3478, 3631, 347...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3]","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3]","[4, 4, 3, 7, 4, 4, 4, 4, 4, 4]","[3, 3, 3, 6, 5, 3, 4, 3, 3, 3]","[16.0, 115.0, 46.0, 54.0, 695.0, 9.0, 11.0, 3....","[6, 4, 3, 5, 5, 6, 4, 6, 7, 6]","[0.9926000237464905, 0.963699996471405, 0.6837...",1353666,21
1353667,"[2562, 3001, 2861, 4346, 4489, 3262, 3652, 318...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3]","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3]","[5, 5, 5, 5, 8, 8, 8, 5, 6, 5]","[4, 16, 4, 6, 5, 5, 5, 4, 5, 10]","[23.0, 7.0, 11.0, 7.0, 3.0, 35.0, 13.0, 10.0, ...","[2, 4, 3, 4, 3, 3, 3, 1, 5, 4]","[0.9846000075340271, 0.9941999912261963, 0.993...",1353667,20
1353668,"[3169, 3501, 3888, 2952, 2732, 2579, 1852, 348...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3]","[3, 3, 3, 3, 3, 4, 3, 3, 3, 4]","[3, 5, 4, 4, 3, 4, 6, 4, 4, 4]","[8, 9, 3, 3, 3, 3, 3, 3, 3, 3]","[8.0, 1.0, 119.0, 35.0, 0.0, 0.0, 87.0, 0.0, 1...","[4, 3, 4, 5, 3, 4, 7, 5, 5, 4]","[0.9742000102996826, 0.9681000113487244, 0.932...",1353668,20


In [25]:
1353670 / 54.979 / 60 / 60

6.839328551709643

In [26]:
a = pd.read_parquet(f"{TEST_INPUT_DATA_DIR}processed_nvt/{1}/part_0.parquet")

In [28]:
a.head(1000).to_parquet(f"{TEST_INPUT_DATA_DIR}processed_nvt/1/part_min.parquet")

In [75]:
def mask_logits(article_ids, logits, inviews):
    batch_size, num_classes = logits.shape

    # Create a copy of logits to apply masking
    masked_logits = np.copy(logits)

    for i in range(batch_size):
        # Create a mask initialized to all True (we will mask these positions to -inf)
        mask = np.ones(num_classes, dtype=bool)

        # Get the indices of article_ids that are in inviews
        if inviews:
            # This finds the indices in article_ids[i] that should not be masked
            valid_indices = [np.where(article_ids[i] == vid)[0][0] for vid in inviews[i] if vid in article_ids[i]]
            mask[valid_indices] = False  # Set False where positions should not be masked

        # Apply the mask to logits
        masked_logits[i, mask] = -np.inf
    
    return masked_logits

def get_article_rankings(trainer, data_path, inviews=None):
    trainer.args.dataloader_drop_last = False
    trainer.model.eval()
    trainer.per_device_eval_batch_size = 512
    print("started predicting")
    article_ids, logits = trainer.predict(data_path).predictions
    fn = np.vectorize(lambda x: ARTICLE_ID_MAPPING.get(x, x))
    article_ids = fn(article_ids)
    print("got predictions")
    trainer.args.dataloader_drop_last = True

    # Apply masking, if inviews is provided
    if inviews:
        logits = mask_logits(article_ids, logits, inviews)

    # Apply softmax to the logits (after masking, if applicable)
    exp_logits = np.exp(logits - np.max(logits, axis=1, keepdims=True))
    softmax_probs = exp_logits / exp_logits.sum(axis=1, keepdims=True)

    ranks = []
    if inviews:  # Only adjust rankings if inviews are provided
        for i in range(softmax_probs.shape[0]):
            if inviews[i]:  # If inviews for this batch element is non-empty
                # Get the probabilities for only the valid indices
                valid_indices_probs = softmax_probs[i, [np.where(article_ids[i] == vid)[0][0] for vid in inviews[i] if vid in article_ids[i]]]
                # Get the ranks for these valid indices
                valid_indices_rankings = np.argsort(-valid_indices_probs).tolist()
                ranks.append(valid_indices_rankings)
            else:
                ranks.append([]) # No inviews for this element
    else:
        # If no inviews are provided, just get the rankings for the entire softmax_probs
        ranks = np.argsort(-softmax_probs, axis=1).tolist()

    return ranks, article_ids

In [85]:
for idx in range(1, 11):
    test_ds_part_path =  f"{TEST_INPUT_DATA_DIR}processed_nvt/1/part_min.parquet"
    indexes = pd.read_parquet(test_ds_part_path, columns=['session_id'])['session_id'].values.tolist()
    filtered_inviews = [INVIEWS[i] for i in indexes]
    ranks, ids = get_article_rankings(trainer, test_ds_part_path, filtered_inviews)
    break

started predicting


got predictions


In [88]:
for i in list(set(sum(filtered_inviews, []))):
    if i not in ARTICLE_ID_MAPPING.keys():
        print(i)

9793538
9791493
9486351
9799697
9791515
9648160
9789473
9795620
9799726
9652271
9777200
9789493
9789494
9787455
9619528
9476171
9771091
8067162
9443420
9089120
9791587
9797733
9797735
9754730
9797738
9734255
9791602
9756788
9494650
9787524
9654405
6111369
9779345
9793684
9472149
9793687
9482394
9797792
9787553
9789605
9779365
9795764
9791670
9793726
9685186
9797827
9797828
9797830
9779411
9799896
9775325
9795807
9797857
9052390
9466087
9793776
9793777
9793785
9797882
9785593
9793786
9566461
9777406
9791743
9797890
9791748
9787659
9378062
9525523
9797912
9789721
9795870
9795871
9799969
9795876
9793829
9142564
9492777
9720107
9789745
9793842
9797938
9789747
9474355
9793846
9777457
9787701
9789757
9793856
9718081
9793860
9771333
9800008
9789773
9800022
9793888
9666919
9793900
9787767
9720184
9582969
9795964
9791878
6842758
9787784
9791881
9793930
9800078
9800084
9800095
9800109
9798065
6691251
9380293
9564613
9787848
9798093
9798094
9796047
9783757
9787863
9798109
9529823
9785826
9796077


In [79]:
for val in ids[0]

array([9743795, 9747762, 9745912, ..., 8860105, 9748841, 6533312])

In [77]:
ranks

[[1, 2, 0],
 [],
 [],
 [1, 0],
 [0],
 [0],
 [],
 [1, 3, 2, 5, 4, 0],
 [],
 [0],
 [],
 [],
 [3, 4, 0, 2, 1],
 [0, 1, 2],
 [0, 2, 1],
 [2, 0, 1],
 [],
 [0],
 [],
 [1, 0, 2],
 [],
 [1, 2, 0, 3],
 [],
 [0],
 [],
 [0, 1, 3, 2],
 [],
 [],
 [],
 [1, 3, 0, 2],
 [1, 0, 3, 2, 4],
 [8, 0, 3, 4, 7, 1, 6, 5, 2],
 [0],
 [0],
 [1, 2, 3, 0],
 [0, 3, 1, 2],
 [3, 0, 2, 1, 4, 5],
 [],
 [0],
 [3, 2, 1, 0],
 [],
 [4, 1, 0, 3, 2],
 [],
 [1, 0],
 [1, 0, 2],
 [4, 0, 2, 5, 6, 1, 3],
 [3, 2, 0, 4, 1],
 [2, 4, 3, 0, 1],
 [],
 [2, 1, 0],
 [],
 [8, 1, 6, 4, 9, 10, 2, 11, 0, 5, 3, 7],
 [1, 0, 8, 9, 4, 10, 2, 5, 6, 7, 3],
 [13, 0, 12, 9, 4, 11, 8, 5, 6, 3, 1, 14, 7, 2, 10],
 [1, 0],
 [],
 [],
 [3, 2, 1, 0],
 [],
 [],
 [0],
 [2, 1, 0],
 [0],
 [0],
 [],
 [2, 0, 1],
 [],
 [],
 [2, 0, 3, 1],
 [0, 5, 2, 7, 1, 6, 4, 3],
 [1, 0],
 [],
 [],
 [1, 0],
 [0, 1, 2, 3],
 [1, 0, 2],
 [0],
 [0],
 [4, 1, 3, 0, 2],
 [1, 2, 3, 0],
 [],
 [],
 [],
 [3, 0, 1, 2],
 [4, 1, 3, 0, 2],
 [1, 2, 3, 0],
 [3, 1, 4, 0, 2],
 [],
 [],
 [0, 1],
 [2, 

In [71]:
for val in ids[0]:
    try:
        ARTICLE_ID_MAPPING[val]
    except:
        pass

In [39]:
len(ids[0])

24748