In [14]:
from sentence_transformers import SentenceTransformer
from config.common import cfg
import os
from os.path import join
import polars as pl
import pandas as pd
from tqdm.notebook import tqdm
import numpy as np

In [15]:
news = pl.read_parquet(join(cfg['data_path'], 'news.parquet'))

In [16]:
train=pl.read_parquet(join(cfg['train_data_path'], 'behaviors.parquet'))
dev=pl.read_parquet(join(cfg['dev_data_path'], 'behaviors.parquet'))

In [17]:
userID_mapping = {}
itemID_mapping = {}

for line in train.iter_rows(named=True):
    userID = line['user_id']
    history = line['history']
    for itemID in history:
        if userID not in userID_mapping:
            userID_mapping[userID] = len(userID_mapping) + 1

        if itemID not in itemID_mapping:
            itemID_mapping[itemID] = len(itemID_mapping) + 1
for line in dev.iter_rows(named=True):
    userID = line['user_id']
    history = line['history']
    for itemID in history:
        if userID not in userID_mapping:
            userID_mapping[userID] = len(userID_mapping) + 1

        if itemID not in itemID_mapping:
            itemID_mapping[itemID] = len(itemID_mapping) + 1
for itemID in news.iter_rows(named=True):
    itemID=itemID['news_id']
    if itemID not in itemID_mapping:
            itemID_mapping[itemID] = len(itemID_mapping) + 1

np.save(cfg['user_dict'], userID_mapping)
print("user_num:", len(userID_mapping))
print("the first five userID mapping:", list(userID_mapping.items())[:5])
np.save(cfg['item_dict'], itemID_mapping)
print("item_num:", len(itemID_mapping))
print("the first five itemID mapping:", list(itemID_mapping.items())[:5])


user_num: 94057
the first five userID mapping: [('U13740', 1), ('U91836', 2), ('U73700', 3), ('U34670', 4), ('U8125', 5)]
item_num: 65239
the first five itemID mapping: [('N55189', 1), ('N42782', 2), ('N34694', 3), ('N45794', 4), ('N18445', 5)]


In [18]:
data = pl.read_parquet(os.path.join(cfg['dev_data_path'], "behaviors.parquet"))
data = data.filter(pl.col("history").list.len() > 0)
data_sorted = data.sort(["user_id", "time"])

# 对每个用户取最后一条记录
data = data_sorted.group_by("user_id").first()
data=data.select(['user_id','history','impressions'])

In [19]:
data.head()

user_id,history,impressions
str,list[str],list[str]
"""U80902""","[""N10617"", ""N27120"", … ""N24261""]","[""N24802-1"", ""N29862-0""]"
"""U73558""","[""N1132"", ""N7127"", … ""N61664""]","[""N58656-0"", ""N23490-1"", … ""N39907-0""]"
"""U80440""","[""N48699"", ""N28257"", ""N15676""]","[""N1952-1"", ""N19990-0"", … ""N11390-0""]"
"""U86230""","[""N57737"", ""N32089"", … ""N13539""]","[""N4241-0"", ""N23513-0"", … ""N11930-0""]"
"""U86981""","[""N39885"", ""N619"", … ""N41698""]","[""N28345-0"", ""N24176-0"", … ""N7895-0""]"


In [20]:
def filter_and_map_impressions(imp_list):
    # imp 格式为 "N123-1"，只保留后缀为 -1 的，并提取 ID 部分进行映射
    clicked_ids = [
        imp.split('-')[0]
        for imp in imp_list
        if imp.endswith('-1')
    ]
    return clicked_ids

In [21]:
data = data.with_columns([
    pl.col("impressions").map_elements(filter_and_map_impressions, return_dtype=pl.List(pl.Utf8))
])

In [22]:
data.head()

user_id,history,impressions
str,list[str],list[str]
"""U80902""","[""N10617"", ""N27120"", … ""N24261""]","[""N24802""]"
"""U73558""","[""N1132"", ""N7127"", … ""N61664""]","[""N23490""]"
"""U80440""","[""N48699"", ""N28257"", ""N15676""]","[""N1952""]"
"""U86230""","[""N57737"", ""N32089"", … ""N13539""]","[""N496"", ""N30290""]"
"""U86981""","[""N39885"", ""N619"", … ""N41698""]","[""N32237"", ""N30290"", ""N56113""]"


In [23]:
data = data.with_columns(
    pl.concat_list(["history", "impressions"]).alias("history")
).drop("impressions")


In [24]:
data.head()

user_id,history
str,list[str]
"""U80902""","[""N10617"", ""N27120"", … ""N24802""]"
"""U73558""","[""N1132"", ""N7127"", … ""N23490""]"
"""U80440""","[""N48699"", ""N28257"", … ""N1952""]"
"""U86230""","[""N57737"", ""N32089"", … ""N30290""]"
"""U86981""","[""N39885"", ""N619"", … ""N56113""]"


In [25]:
def prepare_data(data_dict):
    "数据保存polar形式"
    rows = []
    for userID, item_sequenec in data_dict.items():
        history = item_sequenec[:-1]
        target = item_sequenec[-1]
        rows.append({'user_id': userID, 'history': history, 'target': target})
    return pl.DataFrame(rows)


"划分训练集，测试集，验证集"
train_data = {}
val_data = {}
test_data = {}

for row in data.iter_rows(named=True):
    userID = row["user_id"]
    item_sequence = [itemID_mapping[t] for t in  row["history"]]
    if len(item_sequence) > 2:
        train_data[userID] = item_sequence[:-2]
        val_data[userID] = item_sequence[:-1]
        test_data[userID] = item_sequence

print("training data:", list(train_data.items())[:5])
print("validation data:", list(val_data.items())[:5])
print("testing data:", list(test_data.items())[:5])

train_df = prepare_data(train_data)
print("\nTraining data shape:", train_df.shape)
print("the first 3 rows of training data:\n", train_df.head(3))
test_df = prepare_data(test_data)
print("\nTesting data shape:", test_df.shape)
print("the first 3 rows of testing data:\n", test_df.head(3))
val_df = prepare_data(val_data)
print("\nValidation data shape:", val_df.shape)
print("the first 3 rows of validation data:\n", val_df.head(3))

test_df.shape

training data: [('U80902', [6869, 7036, 176, 8194, 2113, 7845]), ('U73558', [13945, 3391, 4649, 40, 1067, 956, 499, 966, 2825]), ('U80440', [1758, 468]), ('U86230', [368, 363, 1486, 1729, 597, 3598]), ('U86981', [10022, 550, 9328, 4300, 443, 119, 4951, 688, 8190, 13659, 46178])]
validation data: [('U80902', [6869, 7036, 176, 8194, 2113, 7845, 2638]), ('U73558', [13945, 3391, 4649, 40, 1067, 956, 499, 966, 2825, 3560]), ('U80440', [1758, 468, 1508]), ('U86230', [368, 363, 1486, 1729, 597, 3598, 52629]), ('U86981', [10022, 550, 9328, 4300, 443, 119, 4951, 688, 8190, 13659, 46178, 55165])]
testing data: [('U80902', [6869, 7036, 176, 8194, 2113, 7845, 2638, 61881]), ('U73558', [13945, 3391, 4649, 40, 1067, 956, 499, 966, 2825, 3560, 49416]), ('U80440', [1758, 468, 1508, 62363]), ('U86230', [368, 363, 1486, 1729, 597, 3598, 52629, 55165]), ('U86981', [10022, 550, 9328, 4300, 443, 119, 4951, 688, 8190, 13659, 46178, 55165, 54077])]

Training data shape: (48413, 3)
the first 3 rows of trainin

(48413, 3)

In [26]:
print("Data saved to parquet files.")
train_df.write_parquet(os.path.join(cfg['dev_data_path'], "train_df_1.parquet"))
test_df.write_parquet(os.path.join(cfg['dev_data_path'], "test_df_1.parquet"))
val_df.write_parquet(os.path.join(cfg['dev_data_path'], "valid_df_1.parquet"))

Data saved to parquet files.
