In [5]:
import polars as pl
from config.common import cfg
import os

In [6]:
data=pl.read_parquet(os.path.join(cfg['train_data_path'],"behaviors.parquet"))

data_sorted = data.sort(["user_id", "time"])

# 对每个用户取最后一条记录
data=data_sorted.group_by("user_id").tail(1)

In [7]:
def prepare_data(data_dict):
    "数据保存polar形式"
    rows=[]
    for userID,item_sequenec in data_dict.items():
        history=item_sequenec[:-1]
        target=item_sequenec[-1]
        rows.append({'user':userID,'history':history,'target':target})
    return pl.DataFrame(rows)

In [8]:
"划分训练集，测试集，验证集"
train_data = {}
val_data = {}
test_data = {}

for row in data.iter_rows(named=True):
    userID = row["user_id"]  # 注意列名可能是 'userID' 或 'user_id'
    item_sequence = row["history"]
    if len(item_sequence)>2:
        train_data[userID]=item_sequence[:-2]
        val_data[userID]=item_sequence[:-1]
        test_data[userID]=item_sequence

print("training data:", list(train_data.items())[:5])
print("validation data:", list(val_data.items())[:5])
print("testing data:", list(test_data.items())[:5])

train_df=prepare_data(train_data)
print("\nTraining data shape:", train_df.shape)
print("the first 3 rows of training data:\n", train_df.head(3))
test_df=prepare_data(test_data)
print("\nTesting data shape:", test_df.shape)
print("the first 3 rows of testing data:\n", test_df.head(3))
val_df=prepare_data(val_data)
print("\nValidation data shape:", val_df.shape)
print("the first 3 rows of validation data:\n", val_df.head(3))

print("Data saved to parquet files.")
train_df.write_parquet(os.path.join(cfg['train_data_path'],"train_df.parquet"))
test_df.write_parquet(os.path.join(cfg['train_data_path'],"test_df.parquet"))
val_df.write_parquet(os.path.join(cfg['train_data_path'],"valid_df.parquet"))

training data: [('U79044', ['N45544', 'N477', 'N42470', 'N22697']), ('U18373', ['N22130', 'N24406', 'N36576', 'N19741', 'N33276', 'N46244', 'N56514', 'N35488', 'N53531', 'N12034', 'N53033', 'N31611', 'N15677', 'N52500', 'N14385', 'N13935', 'N42620', 'N48680', 'N46267', 'N42677', 'N52551', 'N29149', 'N42128', 'N46795', 'N1969', 'N18870', 'N60702', 'N63435', 'N55911', 'N2420', 'N40509', 'N40785']), ('U15844', ['N46868', 'N12194', 'N28992']), ('U19498', ['N47765', 'N56742', 'N39276', 'N44442', 'N42281', 'N38585', 'N30353', 'N18285']), ('U91516', ['N53538'])]
validation data: [('U79044', ['N45544', 'N477', 'N42470', 'N22697', 'N57955']), ('U18373', ['N22130', 'N24406', 'N36576', 'N19741', 'N33276', 'N46244', 'N56514', 'N35488', 'N53531', 'N12034', 'N53033', 'N31611', 'N15677', 'N52500', 'N14385', 'N13935', 'N42620', 'N48680', 'N46267', 'N42677', 'N52551', 'N29149', 'N42128', 'N46795', 'N1969', 'N18870', 'N60702', 'N63435', 'N55911', 'N2420', 'N40509', 'N40785', 'N37377']), ('U15844', ['N46