In [1]:
from sentence_transformers import SentenceTransformer
from config.common import cfg
import os
from os.path import join
import polars as pl
import pandas as pd
from tqdm.notebook import tqdm
import numpy as np

  return torch._C._cuda_getDeviceCount() > 0


In [2]:
news = pl.read_parquet(join(cfg['data_path'], 'news.parquet'))

In [3]:
train=pl.read_parquet(join(cfg['train_data_path'], 'behaviors.parquet'))
dev=pl.read_parquet(join(cfg['dev_data_path'], 'behaviors.parquet'))

In [4]:
userID_mapping = {}
itemID_mapping = {}

for line in train.iter_rows(named=True):
    userID = line['user_id']
    history = line['history']
    for itemID in history:
        if userID not in userID_mapping:
            userID_mapping[userID] = len(userID_mapping) + 1

        if itemID not in itemID_mapping:
            itemID_mapping[itemID] = len(itemID_mapping) + 1
for line in dev.iter_rows(named=True):
    userID = line['user_id']
    history = line['history']
    for itemID in history:
        if userID not in userID_mapping:
            userID_mapping[userID] = len(userID_mapping) + 1

        if itemID not in itemID_mapping:
            itemID_mapping[itemID] = len(itemID_mapping) + 1
for itemID in news.iter_rows(named=True):
    itemID=itemID['news_id']
    if itemID not in itemID_mapping:
            itemID_mapping[itemID] = len(itemID_mapping) + 1

np.save(cfg['user_dict'], userID_mapping)
print("user_num:", len(userID_mapping))
print("the first five userID mapping:", list(userID_mapping.items())[:5])
np.save(cfg['item_dict'], itemID_mapping)
print("item_num:", len(itemID_mapping))
print("the first five itemID mapping:", list(itemID_mapping.items())[:5])


user_num: 94057
the first five userID mapping: [('U13740', 1), ('U91836', 2), ('U73700', 3), ('U34670', 4), ('U8125', 5)]
item_num: 65239
the first five itemID mapping: [('N55189', 1), ('N42782', 2), ('N34694', 3), ('N45794', 4), ('N18445', 5)]


In [5]:
data = pl.read_parquet(os.path.join(cfg['dev_data_path'], "behaviors.parquet"))
data = data.filter(pl.col("history").list.len() > 0)
data_sorted = data.sort(["user_id", "time"])

# 对每个用户取最后一条记录
data = data_sorted.group_by("user_id").tail(1)
data=data.select(['user_id','history','impressions'])

In [6]:
data.head()

user_id,history,impressions
str,list[str],list[str]
"""U9325""","[""N21746"", ""N45729"", … ""N40509""]","[""N17807-0"", ""N27936-0"", … ""N31958-0""]"
"""U31113""","[""N57571"", ""N3479"", … ""N34983""]","[""N5472-0"", ""N53572-0"", … ""N31958-1""]"
"""U31370""","[""N11405""]","[""N34130-1"", ""N50775-0"", … ""N29862-0""]"
"""U44712""","[""N55846"", ""N1150"", … ""N22816""]","[""N26485-0"", ""N48426-0"", … ""N14273-0""]"
"""U57210""","[""N2794"", ""N6233"", … ""N54625""]","[""N35216-0"", ""N36779-1"", … ""N23513-0""]"


In [7]:
def filter_and_map_impressions(imp_list):
    # imp 格式为 "N123-1"，只保留后缀为 -1 的，并提取 ID 部分进行映射
    clicked_ids = [
        imp.split('-')[0]
        for imp in imp_list
        if imp.endswith('-1')
    ]
    return clicked_ids

In [8]:
data = data.with_columns([
    pl.col("impressions").map_elements(filter_and_map_impressions, return_dtype=pl.List(pl.Utf8))
])

In [9]:
data.head()

user_id,history,impressions
str,list[str],list[str]
"""U9325""","[""N21746"", ""N45729"", … ""N40509""]","[""N56391""]"
"""U31113""","[""N57571"", ""N3479"", … ""N34983""]","[""N31958""]"
"""U31370""","[""N11405""]","[""N34130""]"
"""U44712""","[""N55846"", ""N1150"", … ""N22816""]","[""N33507"", ""N23675""]"
"""U57210""","[""N2794"", ""N6233"", … ""N54625""]","[""N36779""]"


In [10]:
data = data.with_columns(
    pl.concat_list(["history", "impressions"]).alias("history")
).drop("impressions")


In [11]:
data.head()

user_id,history
str,list[str]
"""U9325""","[""N21746"", ""N45729"", … ""N56391""]"
"""U31113""","[""N57571"", ""N3479"", … ""N31958""]"
"""U31370""","[""N11405"", ""N34130""]"
"""U44712""","[""N55846"", ""N1150"", … ""N23675""]"
"""U57210""","[""N2794"", ""N6233"", … ""N36779""]"


In [12]:
def prepare_data(data_dict):
    "数据保存polar形式"
    rows = []
    for userID, item_sequenec in data_dict.items():
        history = item_sequenec[:-1]
        target = item_sequenec[-1]
        rows.append({'user_id': userID, 'history': history, 'target': target})
    return pl.DataFrame(rows)


"划分训练集，测试集，验证集"
train_data = {}
val_data = {}
test_data = {}

for row in data.iter_rows(named=True):
    userID = row["user_id"]
    item_sequence = [itemID_mapping[t] for t in  row["history"]]
    if len(item_sequence) > 2:
        train_data[userID] = item_sequence[:-2]
        val_data[userID] = item_sequence[:-1]
        test_data[userID] = item_sequence

print("training data:", list(train_data.items())[:5])
print("validation data:", list(val_data.items())[:5])
print("testing data:", list(test_data.items())[:5])

train_df = prepare_data(train_data)
print("\nTraining data shape:", train_df.shape)
print("the first 3 rows of training data:\n", train_df.head(3))
test_df = prepare_data(test_data)
print("\nTesting data shape:", test_df.shape)
print("the first 3 rows of testing data:\n", test_df.head(3))
val_df = prepare_data(val_data)
print("\nValidation data shape:", val_df.shape)
print("the first 3 rows of validation data:\n", val_df.head(3))

test_df.shape

training data: [('U9325', [3377, 107, 368, 931, 784, 1848, 996, 12695, 3274, 2590, 8645, 815, 1383, 824]), ('U31113', [1057, 21727, 46, 541, 258, 260, 499, 546, 1884, 2752, 709, 550, 1599, 111, 14197, 2969, 574, 443, 2508, 2067, 458, 1518, 453, 5819, 2835, 897, 79, 3229, 974, 930, 1461, 1983, 4517, 1787]), ('U44712', [759, 1005, 785, 1020, 3961, 4949, 3140, 3502, 573, 6913, 1521, 564]), ('U57210', [1179, 21, 2329, 1203, 236, 5090, 3062, 7656, 917, 368, 2737, 1617]), ('U854', [5546, 5961, 18825, 1005, 3991, 1402, 4593, 2322, 15212, 4067, 911, 3049, 117, 6950, 1660, 3786, 567, 1197, 21, 16435, 8258, 27, 1016, 4010, 487, 4487, 5712, 8269, 6248, 132, 14849, 1090, 6041, 1257, 40, 8481, 42, 4082, 16847, 6136, 6755, 6040, 1182, 4083, 10598, 1669, 10003, 2662, 11193, 10783, 1303, 10461, 7357, 1389, 14854, 5333, 4090, 550, 1187, 4152, 1344, 4052, 1335, 83, 824])]
validation data: [('U9325', [3377, 107, 368, 931, 784, 1848, 996, 12695, 3274, 2590, 8645, 815, 1383, 824, 185]), ('U31113', [1057, 2

(48403, 3)

In [13]:
print("Data saved to parquet files.")
train_df.write_parquet(os.path.join(cfg['dev_data_path'], "train_df_1.parquet"))
test_df.write_parquet(os.path.join(cfg['dev_data_path'], "test_df_1.parquet"))
val_df.write_parquet(os.path.join(cfg['dev_data_path'], "valid_df_1.parquet"))

Data saved to parquet files.
