In [30]:
from sentence_transformers import SentenceTransformer
from config.common import cfg
import os
from os.path import join
import polars as pl
import pandas as pd
from tqdm.notebook import tqdm
import numpy as np

In [31]:
news = pl.read_parquet(join(cfg['data_path'], 'news.parquet'))

In [32]:
train=pl.read_parquet(join(cfg['train_data_path'], 'behaviors.parquet'))
dev=pl.read_parquet(join(cfg['dev_data_path'], 'behaviors.parquet'))

In [33]:
userID_mapping = {}
itemID_mapping = {}

for line in train.iter_rows(named=True):
    userID = line['user_id']
    history = line['history']
    for itemID in history:
        if userID not in userID_mapping:
            userID_mapping[userID] = len(userID_mapping) + 1

        if itemID not in itemID_mapping:
            itemID_mapping[itemID] = len(itemID_mapping) + 1
for line in dev.iter_rows(named=True):
    userID = line['user_id']
    history = line['history']
    for itemID in history:
        if userID not in userID_mapping:
            userID_mapping[userID] = len(userID_mapping) + 1

        if itemID not in itemID_mapping:
            itemID_mapping[itemID] = len(itemID_mapping) + 1
for itemID in news.iter_rows(named=True):
    itemID=itemID['news_id']
    if itemID not in itemID_mapping:
            itemID_mapping[itemID] = len(itemID_mapping) + 1

np.save(cfg['user_dict'], userID_mapping)
print("user_num:", len(userID_mapping))
print("the first five userID mapping:", list(userID_mapping.items())[:5])
np.save(cfg['item_dict'], itemID_mapping)
print("item_num:", len(itemID_mapping))
print("the first five itemID mapping:", list(itemID_mapping.items())[:5])


user_num: 94057
the first five userID mapping: [('U13740', 1), ('U91836', 2), ('U73700', 3), ('U34670', 4), ('U8125', 5)]
item_num: 65239
the first five itemID mapping: [('N55189', 1), ('N42782', 2), ('N34694', 3), ('N45794', 4), ('N18445', 5)]


In [34]:
data = pl.read_parquet(os.path.join(cfg['dev_data_path'], "behaviors.parquet"))
data = data.filter(pl.col("history").list.len() > 0)
data_sorted = data.sort(["user_id", "time"])

# 对每个用户取最后一条记录
data = data_sorted.group_by("user_id").tail(1)
data=data.select(['user_id','history','impressions'])

In [35]:
data.head()

user_id,history,impressions
str,list[str],list[str]
"""U72782""","[""N36920"", ""N25522"", … ""N12431""]","[""N28640-0"", ""N52492-0"", … ""N29862-0""]"
"""U72073""","[""N56698"", ""N56446"", … ""N31978""]","[""N1878-0"", ""N7556-0"", … ""N32237-0""]"
"""U82367""","[""N871"", ""N51706"", … ""N8148""]","[""N19990-1"", ""N36779-0"", … ""N6638-0""]"
"""U37201""","[""N11282"", ""N47558"", … ""N13636""]","[""N53572-0"", ""N50775-0"", … ""N58251-1""]"
"""U84200""","[""N38118"", ""N32893"", … ""N6074""]","[""N19685-0"", ""N20036-0"", … ""N48487-0""]"


In [36]:
def filter_and_map_impressions(imp_list):
    # imp 格式为 "N123-1"，只保留后缀为 -1 的，并提取 ID 部分进行映射
    clicked_ids = [
        imp.split('-')[0]
        for imp in imp_list
        if imp.endswith('-1')
    ]
    return clicked_ids

In [37]:
data = data.with_columns([
    pl.col("impressions").map_elements(filter_and_map_impressions, return_dtype=pl.List(pl.Utf8))
])

In [38]:
data.head()

user_id,history,impressions
str,list[str],list[str]
"""U72782""","[""N36920"", ""N25522"", … ""N12431""]","[""N12409""]"
"""U72073""","[""N56698"", ""N56446"", … ""N31978""]","[""N57359""]"
"""U82367""","[""N871"", ""N51706"", … ""N8148""]","[""N19990"", ""N20187"", ""N20036""]"
"""U37201""","[""N11282"", ""N47558"", … ""N13636""]","[""N5472"", ""N58251""]"
"""U84200""","[""N38118"", ""N32893"", … ""N6074""]","[""N36779""]"


In [39]:
def prepare_data(data_dict):
    "数据保存polar形式"
    rows = []
    for userID, item_sequenec in data_dict.items():
        history = item_sequenec[:-1]
        target = item_sequenec[-1]
        rows.append({'user_id': userID, 'history': history, 'target': target})
    return pl.DataFrame(rows)


"划分训练集，测试集，验证集"
res={}

for row in data.iter_rows(named=True):
    userID = row["user_id"]
    item_sequence = [itemID_mapping[t] for t in  row["history"]]
    target_sequence = [itemID_mapping[t] for  t in row["impressions"]]
    if len(target_sequence) >0:
        res[userID] = {'history': item_sequence, 'target': target_sequence}

In [40]:
res

{'U72782': {'history': [6887,
   14268,
   14595,
   4115,
   2724,
   903,
   11419,
   109,
   2304,
   567,
   2346,
   2921,
   864,
   23,
   243,
   1768,
   1863,
   1420,
   3342,
   1065,
   1182,
   2750,
   198,
   2089,
   46,
   1021,
   543,
   2929,
   3806,
   2736,
   1322,
   7,
   1184,
   720,
   509,
   9,
   2220,
   58,
   400,
   511,
   9496,
   1322,
   890,
   2665,
   112,
   1187,
   4435,
   2778,
   1083,
   8749,
   4374,
   617,
   286,
   4400,
   1390,
   967,
   448,
   3361,
   288,
   407,
   1496,
   4051,
   575,
   6315,
   1165,
   2070,
   458,
   821,
   113,
   974,
   2899,
   532,
   1679,
   83,
   2842,
   1461,
   900,
   4518,
   2767],
  'target': [65197]},
 'U72073': {'history': [1771,
   359,
   11629,
   4448,
   923,
   1182,
   26633,
   949,
   170,
   1073,
   6869,
   2673,
   150,
   6627,
   106,
   974,
   185,
   2644],
  'target': [46210]},
 'U82367': {'history': [109, 864, 1137, 170, 1787],
  'target': [58979, 49139, 472

In [28]:
rows = []
for user_id, v in res.items():
    rows.append({
        "user_id": user_id,
        "history": v["history"],   # List[Int]
        "target": v["target"],     # List[Int]，变长 OK
    })

df = pl.DataFrame(
    rows,
    schema={
        "user_id": pl.Utf8,
        "history": pl.List(pl.Int64),
        "target": pl.List(pl.Int64),
    }
)

df.write_parquet(os.path.join(cfg['eval_data']))

In [29]:
df.head()

user_id,history,target
str,list[i64],list[i64]
"""U55481""","[14400, 5636, … 3235]",[50259]
"""U50201""","[3995, 1571, … 6920]","[56685, 61389]"
"""U92745""","[3038, 205, … 1187]",[45902]
"""U70312""","[1803, 4402, … 1377]",[58871]
"""U15713""","[8447, 4266, … 597]",[56087]
