In [1]:
from sentence_transformers import SentenceTransformer
from config.common import cfg
import os
from os.path import join
import polars as pl
import pandas as pd
from tqdm.notebook import tqdm
import numpy as np

In [2]:
news = pl.read_parquet(join(cfg['data_path'], 'news.parquet'))

In [3]:
train=pl.read_parquet(join(cfg['train_data_path'], 'behaviors.parquet'))
dev=pl.read_parquet(join(cfg['dev_data_path'], 'behaviors.parquet'))

In [4]:
userID_mapping = {}
itemID_mapping = {}

for line in train.iter_rows(named=True):
    userID = line['user_id']
    history = line['history']
    for itemID in history:
        if userID not in userID_mapping:
            userID_mapping[userID] = len(userID_mapping) + 1

        if itemID not in itemID_mapping:
            itemID_mapping[itemID] = len(itemID_mapping) + 1
for line in dev.iter_rows(named=True):
    userID = line['user_id']
    history = line['history']
    for itemID in history:
        if userID not in userID_mapping:
            userID_mapping[userID] = len(userID_mapping) + 1

        if itemID not in itemID_mapping:
            itemID_mapping[itemID] = len(itemID_mapping) + 1
for itemID in news.iter_rows(named=True):
    itemID=itemID['news_id']
    if itemID not in itemID_mapping:
            itemID_mapping[itemID] = len(itemID_mapping) + 1

np.save(cfg['user_dict'], userID_mapping)
print("user_num:", len(userID_mapping))
print("the first five userID mapping:", list(userID_mapping.items())[:5])
np.save(cfg['item_dict'], itemID_mapping)
print("item_num:", len(itemID_mapping))
print("the first five itemID mapping:", list(itemID_mapping.items())[:5])


user_num: 94057
the first five userID mapping: [('U13740', 1), ('U91836', 2), ('U73700', 3), ('U34670', 4), ('U8125', 5)]
item_num: 65239
the first five itemID mapping: [('N55189', 1), ('N42782', 2), ('N34694', 3), ('N45794', 4), ('N18445', 5)]


In [5]:
data = pl.read_parquet(os.path.join(cfg['dev_data_path'], "behaviors.parquet"))
data = data.filter(pl.col("history").list.len() > 0)
data_sorted = data.sort(["user_id", "time"])

# 对每个用户取最后一条记录
data = data_sorted.group_by("user_id").first()
data=data.select(['user_id','history','impressions'])

In [6]:
data.head()

user_id,history,impressions
str,list[str],list[str]
"""U74897""","[""N16965"", ""N25488"", … ""N52214""]","[""N50055-0"", ""N17807-0"", … ""N6916-0""]"
"""U45165""","[""N37942"", ""N1066"", … ""N26009""]","[""N61838-0"", ""N30049-0"", … ""N4667-0""]"
"""U76958""","[""N44402"", ""N13137"", … ""N4080""]","[""N23355-0"", ""N19990-0"", … ""N5940-0""]"
"""U77523""","[""N22279"", ""N29177"", … ""N55497""]","[""N31958-1"", ""N36779-0"", … ""N19990-0""]"
"""U82709""","[""N6233"", ""N51706"", ""N28888""]","[""N30598-0"", ""N54562-1"", … ""N4734-0""]"


In [7]:
def filter_and_map_impressions(imp_list):
    # imp 格式为 "N123-1"，只保留后缀为 -1 的，并提取 ID 部分进行映射
    clicked_ids = [
        imp.split('-')[0]
        for imp in imp_list
        if imp.endswith('-1')
    ]
    return clicked_ids

In [8]:
data = data.with_columns([
    pl.col("impressions").map_elements(filter_and_map_impressions, return_dtype=pl.List(pl.Utf8))
])

In [9]:
data.head()

user_id,history,impressions
str,list[str],list[str]
"""U74897""","[""N16965"", ""N25488"", … ""N52214""]","[""N11930""]"
"""U45165""","[""N37942"", ""N1066"", … ""N26009""]","[""N56080"", ""N28640""]"
"""U76958""","[""N44402"", ""N13137"", … ""N4080""]","[""N30290""]"
"""U77523""","[""N22279"", ""N29177"", … ""N55497""]","[""N31958""]"
"""U82709""","[""N6233"", ""N51706"", ""N28888""]","[""N54562""]"


In [10]:
data = data.with_columns(
    pl.concat_list(["history", "impressions"]).alias("history")
).drop("impressions")


In [11]:
data.head()

user_id,history
str,list[str]
"""U74897""","[""N16965"", ""N25488"", … ""N11930""]"
"""U45165""","[""N37942"", ""N1066"", … ""N28640""]"
"""U76958""","[""N44402"", ""N13137"", … ""N30290""]"
"""U77523""","[""N22279"", ""N29177"", … ""N31958""]"
"""U82709""","[""N6233"", ""N51706"", … ""N54562""]"


In [12]:
def prepare_data(data_dict):
    "数据保存polar形式"
    rows = []
    for userID, item_sequenec in data_dict.items():
        history = item_sequenec[:-1]
        target = item_sequenec[-1]
        rows.append({'user_id': userID, 'history': history, 'target': target})
    return pl.DataFrame(rows)


"划分训练集，测试集，验证集"
train_data = {}
val_data = {}
test_data = {}

for row in data.iter_rows(named=True):
    userID = row["user_id"]
    item_sequence = [itemID_mapping[t] for t in  row["history"]]
    if len(item_sequence) > 2:
        train_data[userID] = item_sequence[:-2]
        val_data[userID] = item_sequence[:-1]
        test_data[userID] = item_sequence

print("training data:", list(train_data.items())[:5])
print("validation data:", list(val_data.items())[:5])
print("testing data:", list(test_data.items())[:5])

train_df = prepare_data(train_data)
print("\nTraining data shape:", train_df.shape)
print("the first 3 rows of training data:\n", train_df.head(3))
test_df = prepare_data(test_data)
print("\nTesting data shape:", test_df.shape)
print("the first 3 rows of testing data:\n", test_df.head(3))
val_df = prepare_data(val_data)
print("\nValidation data shape:", val_df.shape)
print("the first 3 rows of validation data:\n", val_df.head(3))

test_df.shape

training data: [('U74897', [3648, 6892, 2583, 7250, 6909, 5473, 3549, 4983, 3342, 4084, 609, 4292, 4291, 3082, 11218, 2618, 544, 3325, 2436, 2374, 6869, 2981, 267, 4300, 6402, 727, 2664, 12556, 619, 1330, 8166, 5047, 1093, 1280, 6057, 4577, 2006, 3370, 421, 9923, 5791, 1461, 2654, 3234, 3679, 8690, 5628, 2644, 6551]), ('U45165', [1972, 2808, 5665, 3168, 176, 595, 400, 3100]), ('U76958', [2417, 197, 4, 219, 176, 288, 3151, 3119, 7900, 1165, 6053, 1679, 470]), ('U77523', [2011, 904, 11488, 109, 2796, 17593, 1197, 861, 669, 2696, 544, 1616, 12127, 435, 3507]), ('U82709', [21, 864])]
validation data: [('U74897', [3648, 6892, 2583, 7250, 6909, 5473, 3549, 4983, 3342, 4084, 609, 4292, 4291, 3082, 11218, 2618, 544, 3325, 2436, 2374, 6869, 2981, 267, 4300, 6402, 727, 2664, 12556, 619, 1330, 8166, 5047, 1093, 1280, 6057, 4577, 2006, 3370, 421, 9923, 5791, 1461, 2654, 3234, 3679, 8690, 5628, 2644, 6551, 5337]), ('U45165', [1972, 2808, 5665, 3168, 176, 595, 400, 3100, 50272]), ('U76958', [2417, 1

(48413, 3)

In [13]:
print("Data saved to parquet files.")
train_df.write_parquet(os.path.join(cfg['dev_data_path'], "train_df_1.parquet"))
test_df.write_parquet(os.path.join(cfg['dev_data_path'], "test_df_1.parquet"))
val_df.write_parquet(os.path.join(cfg['dev_data_path'], "valid_df_1.parquet"))

Data saved to parquet files.
