In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes (handy while editing local helper modules).
%load_ext autoreload
%autoreload 2

In [2]:
from data import load_parquets_from_zip
from tqdm.auto import tqdm
from torch.utils.data import Dataset
from functools import partial
from itertools import chain
import polars as pl
import numpy as np
import os
from collections import defaultdict, Counter
from typing import Tuple, Dict, List, Any, Callable, Optional, Union, Iterable
from datetime import datetime
from bisect import bisect
import math
import sentence_transformers as st


# Directory that receives the preprocessed parquet files.
base_dir = 'preprocess'
# exist_ok=True removes the check-then-create race and keeps the cell re-runnable.
os.makedirs(base_dir, exist_ok=True)

In [3]:
from collections import defaultdict
import numpy as np
def binary_encoding(x: bool) -> int:
    """Encode a flag as an integer: 1 when ``x`` is truthy, 0 otherwise."""
    return int(bool(x))


def dict_encoding(x: Any, map: Dict[Any, int]) -> int:
    """Look up the integer code stored for ``x`` in ``map``.

    The lookup uses plain subscription, so a regular dict raises ``KeyError``
    for an unseen value while a mapping with a default factory returns its
    default.
    """
    code = map[x]
    return code


def build_dict_encoding(labels: List[Any], unknown: bool=False) -> Callable[[Any], int]:
    """Create an encoder that maps each label to its position in ``labels``.

    Args:
        labels: Values to encode; the code of a label is its index in this list.
        unknown: When True, any value absent from ``labels`` is encoded as
            ``len(labels)`` instead of raising ``KeyError``.

    Returns:
        A callable taking a value and returning its integer code.
    """
    codes = {}
    for position, label in enumerate(labels):
        codes[label] = position
    if not unknown:
        return partial(dict_encoding, map=codes)
    # One extra code, right after the known labels, is reserved for unseen values.
    fallback = len(labels)
    return partial(dict_encoding, map=defaultdict(lambda: fallback, codes))


def quantile_limits(x: Iterable[int], quantiles: int=100) -> List[float]:
    """Return the upper limits of the first ``quantiles - 1`` quantile buckets of ``x``.

    The values are sorted and the element at position ``ceil(i * n / quantiles)``
    is taken for ``i = 1 .. quantiles - 1``; the limit of the last bucket is
    omitted.

    Args:
        x: Values to summarise; consumed once and copied into a list.
        quantiles: Number of buckets to split the data into (must be >= 1).

    Returns:
        A list with ``quantiles - 1`` limits, or an empty list when ``x`` is empty.

    Raises:
        ValueError: If ``quantiles`` is smaller than 1.
    """
    if quantiles < 1:
        raise ValueError(f'quantiles must be >= 1, got {quantiles}')
    data = sorted(x)
    if not data:
        return []
    last = len(data) - 1
    step = len(data) / quantiles
    # Clamp to the last element: with fewer items than quantiles the computed
    # position can land one past the end of the data.
    return [data[min(math.ceil(i * step), last)] for i in range(1, quantiles)]


def time_encoding(x: int, limits: List[int]) -> int:
    """Return the index of the bucket that ``x`` falls into.

    ``limits`` must be sorted in ascending order. The result is the insertion
    point to the right of any entries equal to ``x``, i.e. the number of limits
    that are less than or equal to ``x``.
    """
    return bisect(limits, x)


def count_in_list(df: pl.DataFrame, col: str) -> Dict[Any, int]:
    """Count how often each element appears inside the list column ``col``.

    Rows whose list is null are ignored; every element of the remaining lists
    contributes one occurrence.
    """
    non_null_rows = (cell for cell in df[col] if cell is not None)
    return Counter(chain.from_iterable(non_null_rows))


def count(df: pl.DataFrame, col: str) -> Dict[Any, int]:
    """Count how often each value appears in column ``col`` (nulls are counted too)."""
    tally = Counter()
    tally.update(iter(df[col]))
    return tally


def get_map_for_feature(df: pl.DataFrame, col: str, min_reps: int, is_list_column: bool=False, unknown: bool=False) -> Tuple[Callable[[Any], int], List[Any]]:
    """Build an integer encoder for the frequent values of ``col``.

    Args:
        df: Frame holding the column.
        col: Name of the column to encode.
        min_reps: Only values occurring strictly more than this many times are kept.
        is_list_column: Count the elements inside each list instead of whole cells.
        unknown: Forwarded to ``build_dict_encoding``.

    Returns:
        The encoder callable and the sorted list of values it knows about.
    """
    counter = count_in_list if is_list_column else count
    frequencies = counter(df, col)
    kept = sorted(value for value, reps in frequencies.items() if reps > min_reps)
    return build_dict_encoding(kept, unknown=unknown), kept


def _sorted_list_vocabulary(df: pl.DataFrame, col: str) -> List[Any]:
    """Return the distinct elements of list column ``col`` in sorted order, skipping null rows."""
    vocabulary = set()
    for values in df[col]:
        if values is not None:
            for value in values:
                vocabulary.add(value)
    return sorted(vocabulary)


def _link_list_column(col: str, index_of: Dict[Any, int], alias: str) -> pl.Expr:
    """Build an expression that maps every element of list column ``col`` through ``index_of``.

    Null rows become empty lists and the result is exposed under ``alias``.
    """
    def to_links(values):
        return [index_of[v] for v in values] if values is not None else []

    return pl.col(col).map_elements(to_links, return_dtype=pl.List(pl.Int64), skip_nulls=False).alias(alias)


def preprocess_article(article_df: pl.DataFrame, min_reps: int=100,
                       image_ids: Optional[Dict[int, int]]=None,
                       model_name: str='paraphrase-multilingual-mpnet-base-v2') -> Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame, pl.DataFrame, Dict[int, int], Dict[int, datetime]]:
    """Encode the article table and build embedding tables for its text vocabularies.

    Steps:
      1. Embed the distinct categories, NER clusters and topics with the
         sentence-transformer ``model_name`` and add ``*_link`` columns holding
         the position of each value in its sorted vocabulary.
      2. Replace ``ner_clusters`` by the first three lower-cased characters of
         each entry.
      3. Integer-encode the categorical features, keeping only values that
         occur more than ``min_reps`` times.

    Args:
        article_df: Article table.
        min_reps: Frequency threshold passed to ``get_map_for_feature``.
        image_ids: Optional mapping applied to the ``image_ids`` list column;
            ids missing from the mapping are dropped.
        model_name: Name handed to ``sentence_transformers.SentenceTransformer``.

    Returns:
        ``(article_df, categories_df, ner_df, topics_df, articleid_idx, articleid_date)``
        where ``articleid_idx`` maps article_id to its row position and
        ``articleid_date`` maps article_id to the first ``published_time``
        entry found for it.
    """
    model = st.SentenceTransformer(model_name)

    # Link indices are positions in the sorted vocabularies.
    categories = sorted(set(article_df['category_str']))
    categories_map = {c: i for i, c in enumerate(categories)}
    categories_df = pl.DataFrame(data=[categories, model.encode(categories)],
                                 schema=['category_str', 'embeddings'])

    ners = _sorted_list_vocabulary(article_df, 'ner_clusters')
    ner_map = {n: i for i, n in enumerate(ners)}
    ner_df = pl.DataFrame(data=[ners, model.encode(ners)], schema=['ners', 'embeddings'])

    topics = _sorted_list_vocabulary(article_df, 'topics')
    topics_map = {n: i for i, n in enumerate(topics)}
    topics_df = pl.DataFrame(data=[topics, model.encode(topics)], schema=['topics', 'embeddings'])

    link_ops = [
        pl.col('category_str').replace(categories_map).cast(pl.Int64).alias('category_link'),
        _link_list_column('ner_clusters', ner_map, 'ner_clusters_link'),
        _link_list_column('topics', topics_map, 'topics_link'),
    ]
    article_df = article_df.with_columns(link_ops)

    # Shorten every NER cluster to its first three lower-cased characters; the
    # frequency-based encoding below works on these shortened keys.
    article_df = article_df.with_columns(pl.col('ner_clusters').map_elements(lambda x: [n.lower()[:3] for n in x] if x is not None else [],
                                               return_dtype=pl.List(pl.String), skip_nulls=False))

    articleid_date = article_df.select('article_id', 'published_time').rows_by_key(key='article_id')
    articleid_date = {k: v[0] for k, v in articleid_date.items()}
    articleid_idx = {a: i for i, a in enumerate(article_df['article_id'])}

    # Expressions applied to article_df in a single pass at the end.
    ops = [pl.col('premium').map_elements(binary_encoding, return_dtype=pl.Int64),
           (pl.col('published_time').dt.weekday() - 1).alias('pub_weekday'),
           pl.col('published_time').dt.hour().alias('pub_hour')]
    # (column, is list column, reserve a code for unknown values)
    features = [('article_type', False, True),
                ('subcategory', True, False),
                ('sentiment_label', False, False),
                ('category', False, True),
                ('ner_clusters', True, False),
                ('topics', True, False)]
    for feature, is_list, unknown in features:
        map_feature, valid_value = get_map_for_feature(article_df, feature, min_reps, is_list_column=is_list, unknown=unknown)
        if is_list:
            valid_value = set(valid_value)

            # Defaults bind the current encoder and value set to this iteration.
            def process_feature(x, map_feature=map_feature, valid_value=valid_value):
                if x is None:
                    return []
                return [map_feature(v) for v in x if v in valid_value]

            ops.append(pl.col(feature).map_elements(process_feature, return_dtype=pl.List(pl.Int64), skip_nulls=False))
        else:
            ops.append(pl.col(feature).map_elements(map_feature, return_dtype=pl.Int64))
        print(f'Feature {feature} encoded. Number of values: {(len(valid_value) + 1) if unknown else len(valid_value)}')
    if image_ids is not None:
        ops.append(pl.col('image_ids').map_elements(lambda x: [image_ids[v] for v in x if v in image_ids] if x is not None else [],
                                                    return_dtype=pl.List(pl.Int64), skip_nulls=False))
    article_df = article_df.with_columns(ops)

    return article_df, categories_df, ner_df, topics_df, articleid_idx, articleid_date


def compute_behaviors_dates(behavior_df, article_publish):
    """Add per-impression timing features to the behaviors frame.

    New columns:
        article_delta_time: for each article in ``article_ids_inview``, the whole
            seconds between ``impression_time`` and the publication time looked
            up in ``article_publish``.
        number_articles: length of ``article_ids_inview``.
    """
    def seconds_since_publication(row):
        shown_at = row['impression_time']
        return [int((shown_at - article_publish[a]).total_seconds()) for a in row['article_ids_inview']]

    delta_expr = pl.struct(('impression_time', 'article_ids_inview')).map_elements(
        seconds_since_publication, return_dtype=pl.List(pl.Int64))
    size_expr = pl.col('article_ids_inview').map_elements(len, return_dtype=pl.Int64)
    return behavior_df.with_columns(article_delta_time=delta_expr, number_articles=size_expr)


def ids_sort(data: List[Any]) -> List[int]:
    """Return the positions of ``data`` ordered so the referenced values ascend.

    The sort is stable, so equal values keep their original relative order.
    """
    return sorted(range(len(data)), key=lambda position: data[position])


def sort_ids(info: List[Any], idxs: List[int]) -> List[Any]:
    """Return the elements of ``info`` picked in the order given by ``idxs``."""
    reordered = []
    for position in idxs:
        reordered.append(info[position])
    return reordered


def compute_history_dates(history_df: pl.DataFrame, article_publish: Dict[int, datetime], sorted_history: bool=False) -> pl.DataFrame:
    """Add timing features to the per-user history frame.

    When ``sorted_history`` is True, the four ``*_fixed`` list columns are first
    reordered by ascending ``impression_time_fixed``. Afterwards these columns
    are added:

        impression_weekday / impression_hour: weekday (Monday = 0) and hour of
            every impression.
        article_delta_time: whole seconds between each impression and the
            publication time of the article read, from ``article_publish``.
        number_articles: length of ``article_id_fixed``.
    """
    if sorted_history:
        def reordered(col, inner_dtype):
            # Rearrange one list column according to the temporary 'order' column.
            return pl.struct(col, 'order').map_elements(
                lambda row: sort_ids(row[col], row['order']),
                return_dtype=pl.List(inner_dtype))

        ranking = pl.col('impression_time_fixed').map_elements(ids_sort, return_dtype=pl.List(pl.Int64))
        history_df = history_df.with_columns(order=ranking)
        history_df = history_df.with_columns(
            impression_time_fixed=reordered('impression_time_fixed', pl.Datetime),
            scroll_percentage_fixed=reordered('scroll_percentage_fixed', pl.Float64),
            article_id_fixed=reordered('article_id_fixed', pl.Int64),
            read_time_fixed=reordered('read_time_fixed', pl.Float64),
        ).drop('order')

    def weekdays(stamps):
        return [d.weekday() for d in stamps]

    def hours(stamps):
        return [d.hour for d in stamps]

    def seconds_since_publication(row):
        pairs = zip(row['impression_time_fixed'], row['article_id_fixed'])
        return [int((d - article_publish[a]).total_seconds()) for d, a in pairs]

    stamps_col = pl.col('impression_time_fixed')
    return history_df.with_columns(
        impression_weekday=stamps_col.map_elements(weekdays, return_dtype=pl.List(pl.Int64)),
        impression_hour=stamps_col.map_elements(hours, return_dtype=pl.List(pl.Int64)),
        article_delta_time=pl.struct(('impression_time_fixed', 'article_id_fixed')).map_elements(
            seconds_since_publication, return_dtype=pl.List(pl.Int64)),
        number_articles=pl.col('article_id_fixed').map_elements(len, return_dtype=pl.Int64),
    )


def preprocess(behavior: pl.DataFrame, history: pl.DataFrame, article_publish: Dict[int, datetime],
               article_idx: Dict[int, int], sorted_history: bool=False, time_quartiles: Optional[List[int]]=None) -> Tuple[pl.DataFrame, pl.DataFrame, List[int]]:
    """Encode a behaviors/history pair into model-ready integer features.

    Args:
        behavior: Behaviors frame of one split.
        history: History frame of the same split.
        article_publish: article_id -> publication time.
        article_idx: article_id -> article row index.
        sorted_history: Forwarded to ``compute_history_dates``.
        time_quartiles: Bucket limits for ``article_delta_time``. When None they
            are computed from this split's history with ``quantile_limits``;
            pass the limits returned for another split to reuse its buckets.

    Returns:
        ``(behavior, history, time_quartiles)`` with the encoded frames and the
        bucket limits that were used.
    """
    print('Computing history & behavior dates')
    history = compute_history_dates(history, article_publish, sorted_history)
    behavior = compute_behaviors_dates(behavior, article_publish)

    print('Computing history user_id mapping and time percentiles')
    historyid_idx = {a: i for i, a in enumerate(history['user_id'])}
    if time_quartiles is None:
        # from_iterable avoids unpacking one positional argument per history row.
        time_quartiles = quantile_limits(chain.from_iterable(history['article_delta_time'].to_list()))

    # Encoders shared by the behaviors and history frames.
    def encode_delta_times(deltas):
        return [time_encoding(v, time_quartiles) for v in deltas]

    def encode_article_ids(ids):
        return [article_idx[v] for v in ids]

    print('Preprocessing behavior')
    if 'next_scroll_percentage' in behavior.columns:
        behavior = behavior.with_columns(
            # Scale the percentage to [0, 1]; a missing value becomes 0.0.
            pl.col('next_scroll_percentage').
            map_elements(lambda x: x / 100 if x is not None else 0.0,
                         return_dtype=pl.Float64, skip_nulls=False),
            # Replace clicked article ids by their position inside the in-view list.
            article_ids_clicked=pl.struct('article_ids_inview', 'article_ids_clicked').
            map_elements(lambda x: [x['article_ids_inview'].index(v) for v in x['article_ids_clicked']],
                         return_dtype=pl.List(pl.Int64)))
    behavior = behavior.with_columns(
        pl.col('user_id').replace(historyid_idx),
        pl.col('is_sso_user').map_elements(binary_encoding, return_dtype=pl.Int8),
        pl.col('gender').map_elements(build_dict_encoding([0, 1, 2, None]), return_dtype=pl.Int64, skip_nulls=False),
        pl.col('postcode').map_elements(build_dict_encoding([0, 1, 2, 3, 4, None]), return_dtype=pl.Int64, skip_nulls=False),
        pl.col('age').map_elements(build_dict_encoding([0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, None]), return_dtype=pl.Int64, skip_nulls=False),
        pl.col('is_subscriber').map_elements(binary_encoding, return_dtype=pl.Int8),
        pl.col('device_type').map_elements(build_dict_encoding([0, 1, 2, 3]), return_dtype=pl.Int8),
        pl.col('article_delta_time').map_elements(encode_delta_times, return_dtype=pl.List(pl.Int64)),
        pl.col('article_ids_inview').map_elements(encode_article_ids, return_dtype=pl.List(pl.Int64)),
        impression_hour=pl.col('impression_time').dt.hour(),
        impression_weekday=pl.col('impression_time').dt.weekday() - 1
    )
    print('Preprocessing history')
    history = history.with_columns(
        pl.col('scroll_percentage_fixed').
        map_elements(lambda l: [x / 100 if x is not None else 0.0 for x in l],
                     return_dtype=pl.List(pl.Float64)),
        pl.col('article_delta_time').map_elements(encode_delta_times, return_dtype=pl.List(pl.Int64)),
        pl.col('article_id_fixed').map_elements(encode_article_ids, return_dtype=pl.List(pl.Int64))
    )
    return behavior, history, time_quartiles

In [4]:
# Read the tables stored in the ebnerd_large archive and keep the training
# behaviors/history plus the article table.
ds = load_parquets_from_zip('dataset/ebnerd_large.zip')
behavior = ds['train/behaviors']
history = ds['train/history']
article = ds['articles']
# Drop the reference to the full collection once the needed tables are bound.
del ds

In [5]:
# images = load_parquets_from_zip('dataset/Ekstra_Bladet_image_embeddings.zip')['Ekstra_Bladet_image_embeddings/image_embeddings']
# image_ids = {a: i for i, a in enumerate(images['article_id'])}
# del images
# preprocess_article(article, image_ids=image_ids)[0]

In [6]:
# Encode the article table; only feature values seen more than 1000 times get a code.
(article, categories_embs, ner_embs, topics_embs,
 articleid_idx, article_publish_date) = preprocess_article(article, min_reps=1000)

Feature article_type encoded. Number of values: 4
Feature subcategory encoded. Number of values: 28
Feature sentiment_label encoded. Number of values: 3
Feature category encoded. Number of values: 14
Feature ner_clusters encoded. Number of values: 96
Feature topics encoded. Number of values: 54


In [7]:
%%time
# Encode the training split. No limits are passed in, so time_quartiles is
# computed here; the later preprocess calls reuse it.
behavior, history, time_quartiles = preprocess(behavior, history, article_publish_date, articleid_idx)

Computing history & behavior dates
Computing history user_id mapping and time percentiles
Preprocessing behavior
Preprocessing history
CPU times: total: 8min 43s
Wall time: 22min 45s


In [8]:
# Preview the first rows of the encoded training behaviors.
behavior.head()

impression_id,article_id,impression_time,read_time,scroll_percentage,device_type,article_ids_inview,article_ids_clicked,user_id,is_sso_user,gender,postcode,age,is_subscriber,session_id,next_read_time,next_scroll_percentage,article_delta_time,number_articles,impression_hour,impression_weekday
u32,i32,datetime[μs],f32,f32,i8,list[i64],list[i64],i64,i8,i64,i64,i64,i8,u32,f32,f64,list[i64],i64,i8,i8
47727,,2023-05-21 21:35:07,20.0,,1,"[95773, 122630, … 99525]",[1],341731,0,3,5,11,0,265,34.0,1.0,"[99, 75, … 99]",6,21,6
47731,,2023-05-21 21:32:33,13.0,,1,"[122562, 122550, … 120591]",[4],341731,0,3,5,11,0,265,45.0,1.0,"[49, 62, … 61]",5,21,6
47736,,2023-05-21 21:33:32,17.0,,1,"[120591, 122562, … 122638]",[9],341731,0,3,5,11,0,265,78.0,1.0,"[62, 49, … 63]",13,21,6
47737,,2023-05-21 21:38:17,27.0,,1,"[122567, 122626, … 122602]",[9],341731,0,3,5,11,0,265,6.0,0.52,"[81, 76, … 82]",11,21,6
47740,,2023-05-21 21:36:02,48.0,,1,"[122595, 122629, … 122578]",[8],341731,0,3,5,11,0,265,32.0,1.0,"[73, 72, … 76]",9,21,6


In [9]:
# Preview the first rows of the encoded training history.
history.head()

user_id,impression_time_fixed,scroll_percentage_fixed,article_id_fixed,read_time_fixed,impression_weekday,impression_hour,article_delta_time,number_articles
u32,list[datetime[μs]],list[f64],list[i64],list[f32],list[i64],list[i64],list[i64],i64
10029,"[2023-04-28 06:16:57, 2023-04-28 06:17:31, … 2023-05-18 06:59:50]","[0.23, 0.69, … 0.0]","[117123, 117715, … 122002]","[28.0, 24.0, … 0.0]","[4, 4, … 3]","[6, 6, … 6]","[0, 18, … 67]",678
10033,"[2023-04-27 11:11:32, 2023-04-27 11:12:56, … 2023-05-17 20:22:42]","[0.33, 0.41, … 0.29]","[117462, 117475, … 121840]","[2.0, 2.0, … 1.0]","[3, 3, … 2]","[11, 11, … 20]","[81, 76, … 29]",587
10034,"[2023-04-30 09:46:57, 2023-04-30 09:47:33, … 2023-05-16 08:40:52]","[0.0, 0.88, … 1.0]","[118112, 118110, … 121595]","[21.0, 103.0, … 9.0]","[6, 6, … 1]","[9, 9, … 8]","[69, 59, … 33]",140
10041,"[2023-04-27 15:15:28, 2023-04-27 15:16:30, … 2023-05-17 14:54:05]","[0.78, 0.41, … 0.57]","[117552, 117478, … 120316]","[12.0, 11.0, … 22.0]","[3, 3, … 2]","[15, 15, … 14]","[39, 62, … 77]",139
10103,"[2023-04-27 15:37:35, 2023-04-27 15:38:37, … 2023-05-18 04:52:09]","[1.0, 0.0, … 0.63]","[117552, 117571, … 121846]","[45.0, 8.0, … 24.0]","[3, 3, … 3]","[15, 15, … 4]","[49, 3, … 92]",64


In [10]:
# Persist the training split, the encoded articles and the embedding tables.
# Paths are built from base_dir so they always target the directory created earlier.
behavior.write_parquet(os.path.join(base_dir, 'train_behaviors.parquet'))
history.write_parquet(os.path.join(base_dir, 'train_history.parquet'))
article.write_parquet(os.path.join(base_dir, 'article.parquet'))
categories_embs.write_parquet(os.path.join(base_dir, 'categories_embs.parquet'))
ner_embs.write_parquet(os.path.join(base_dir, 'ner_embs.parquet'))
topics_embs.write_parquet(os.path.join(base_dir, 'topics_embs.parquet'))

In [11]:
# Re-open the ebnerd_large archive and rebind behavior/history to the validation
# tables (the training frames were already written to disk).
ds = load_parquets_from_zip('dataset/ebnerd_large.zip')
behavior = ds['validation/behaviors']
history = ds['validation/history']
# Drop the reference to the full collection once the needed tables are bound.
del ds

In [12]:
%%time
# Encode the validation split with the bucket limits obtained from the training
# split; the returned limits are discarded.
behavior, history, _ = preprocess(behavior, history, article_publish_date, articleid_idx, time_quartiles=time_quartiles)

Computing history & behavior dates
Computing history user_id mapping and time percentiles
Preprocessing behavior
Preprocessing history
CPU times: total: 6min 57s
Wall time: 21min 50s


In [13]:
# Persist the validation split; paths are built from base_dir so they always
# target the directory created earlier.
behavior.write_parquet(os.path.join(base_dir, 'validation_behaviors.parquet'))
history.write_parquet(os.path.join(base_dir, 'validation_history.parquet'))

In [14]:
# Drop the validation frames before the test split is loaded.
del behavior, history

In [15]:
# Read the tables stored in the ebnerd_testset archive and keep the test
# behaviors/history.
ds = load_parquets_from_zip('dataset/ebnerd_testset.zip')
behavior = ds['ebnerd_testset/test/behaviors']
history = ds['ebnerd_testset/test/history']
# Drop the reference to the full collection once the needed tables are bound.
del ds

In [16]:
%%time
# Encode the test split with the bucket limits obtained from the training
# split; the returned limits are discarded.
behavior, history, _ = preprocess(behavior, history, article_publish_date, articleid_idx, time_quartiles=time_quartiles)

Computing history & behavior dates
Computing history user_id mapping and time percentiles
Preprocessing behavior
Preprocessing history
CPU times: total: 8min 6s
Wall time: 15min 12s


In [17]:
# Persist the test split; paths are built from base_dir so they always target
# the directory created earlier.
behavior.write_parquet(os.path.join(base_dir, 'test_behaviors.parquet'))
history.write_parquet(os.path.join(base_dir, 'test_history.parquet'))

In [18]:
# Drop the test frames now that they have been written to disk.
del behavior, history