In [10]:
from pathlib import Path
import sys


ROOT = Path().resolve().parents[1] / "code"
sys.path.append(str(ROOT))
from functions.valid_recall import valid_recall
from functions.read_behaviors import read_train_behaviors,read_dev_behaviors

import pickle
from collections import defaultdict
import polars as pl
from tqdm import tqdm
import numpy as np
import numpy as np
import pickle
from cuml.neighbors import NearestNeighbors
import cupy as cp
import warnings
warnings.filterwarnings("ignore")

In [11]:
embeddings = np.load("/home/ming/GraduateProject/Data/content_emb/news_emb.npy")
with open("/home/ming/GraduateProject/Data/content_emb/news_ids.pkl", "rb") as f:
    article_ids = pickle.load(f)
news_id2idx = {nid: i for i, nid in enumerate(article_ids)}

news_emb = cp.asarray(embeddings)
knn = NearestNeighbors(n_neighbors=51, metric='cosine')  # 最近邻查找，每个样本返回21个（包括自己）
knn.fit(news_emb)

In [12]:
def processData(actions: pl.DataFrame):
    """
    拆开impression数据，转换为user-item-hist_time
    :param actions: user_id,time,history,impressions
    :return: user_id, item_id, time
    """
    click = (
        actions
        .select(["user_id", "time", "impressions"])
        .explode("impressions")
        .with_columns([
            pl.col("impressions").str.split("-").list.get(0).alias("item_id"),
            pl.col("impressions").str.split("-").list.get(1).cast(pl.UInt8).alias("label"),
        ])
        .filter(pl.col("label") == 1)
        .select(["user_id", "item_id", "time"])
    )
    click = click.group_by("user_id").agg(
        pl.col("item_id").alias("item_ids"),
        pl.col("time").alias("times")
    )
    return click


In [13]:
def build_user_embedding(history,times,max_time, news_emb, news_id2idx, K=50,tau=86400*0.55):
    # 取最近 K 条（history 本身是按时间顺序的）
    recent = history[-K:]
    times=times[-K:]

    vecs = []
    weights=[]

    for nid,t in zip(recent,times):
        if nid not in news_id2idx:
            continue
        vec=news_emb[news_id2idx[nid]]
        dt = abs(max_time- t).total_seconds()
        weight= np.exp(-dt / tau)

        vecs.append(vec*weight)
        weights.append(weight)

    if not vecs:
        return None

    vecs = cp.stack(vecs)              # [k, d]
    user_emb = vecs.sum(axis=0)/sum(weights)       # mean pooling
    user_emb = user_emb / cp.linalg.norm(user_emb)

    return user_emb


In [14]:
def recall_user_embedding_valid(train_path=None,pred_path=None, topk=50):
    if train_path:
        train = pl.read_parquet(train_path)
    else:
        train = read_train_behaviors()
    train=processData(train)
    if pred_path:
        pred = pl.read_parquet(pred_path)
    else:
        pred = read_dev_behaviors()

    pred=(pred.sort('time').group_by('user_id').agg(pl.all().first())).select(['user_id','time'])
    user_time={uid:time for uid,time in pred.iter_rows()}
    user_set=set(pred['user_id'].unique().to_list())
    del pred

    rows = []

    for uid, history, times in tqdm(train.iter_rows(),total=train.shape[0]):
        if uid not in user_set or uid not in user_time:
            continue

        user_emb = build_user_embedding(history, times,user_time[uid],news_emb, news_id2idx)
        if user_emb is None:
            rows.append((uid, []))
            continue

        _, idx = knn.kneighbors(user_emb[None, :], n_neighbors=topk+20)
        idx = idx[0].get()

        rec = []
        for i in idx:
            nid = article_ids[int(i)]
            if nid not in history:
                rec.append(nid)
            if len(rec) >= topk:
                break

        rows.append((uid, rec))

    return pl.DataFrame(rows, schema=["user_id", "rec_list"])

In [15]:
from sklearn.cluster import KMeans

def build_multi_interest_embeddings(history,news_emb,news_id2idx,K_clusters=3):
    recent = history

    vecs=[]
    for nid in recent:
        if nid in news_id2idx:
            vecs.append(news_emb[news_id2idx[nid]])
    if not vecs:
        return None
    X=cp.stack(vecs).get()

    actual_clusters=min(K_clusters,len(X))
    if actual_clusters<=1:
        user_emb=cp.mean(cp.stack(vecs),axis=0)
        return [user_emb/cp.linalg.norm(user_emb)]
    kmeans=KMeans(n_clusters=actual_clusters,random_state=42,n_init='auto')
    kmeans.fit(X)

    interest_vectors=[]
    for center in kmeans.cluster_centers_:
        v=cp.asarray(center)
        interest_vectors.append(v/cp.linalg.norm(v))
    return interest_vectors


In [21]:
def recall_multi_interest_embedding_valid(train_path=None,pred_path=None, K_clusters=5,topk=50):
    if train_path:
        train = pl.read_parquet(train_path)
    else:
        train = read_train_behaviors()
    train=processData(train)
    #train=train.sort('time').group_by('user_id').agg(pl.all().last()).select(['user_id','history'])

    if pred_path:
        pred = pl.read_parquet(pred_path)
    else:
        pred = read_dev_behaviors()

    pred=(pred.sort('time').group_by('user_id').agg(pl.all().first())).select(['user_id','time'])
    user_time={uid:time for uid,time in pred.iter_rows()}
    user_set=set(pred['user_id'].unique().to_list())
    del pred

    rows = []

    for uid, history in tqdm(train.iter_rows(),total=train.shape[0]):
        if uid not in user_set or uid not in user_time:
            continue
        combined_rec=[]
        user_embs = build_multi_interest_embeddings(history, news_emb, news_id2idx,K_clusters)
        if user_embs:
            for emb in user_embs:
                # 针对每个兴趣点找 topk 个候选，后面再统一截断
                _, idx = knn.kneighbors(emb[None, :], n_neighbors=topk//K_clusters+5)
                idx = idx[0].get()

                for i in idx:
                    nid = article_ids[int(i)]
                    # 过滤历史和已加入的重复新闻
                    if nid not in history and nid not in combined_rec:
                        combined_rec.append(nid)

            # 在处理完该用户所有兴趣点后，保存 combined_rec
            rows.append((uid, combined_rec[:topk]))
        else:
            # 如果没有兴趣向量，也存个空，保持 DataFrame 长度对齐（可选）
            rows.append((uid, []))

    return pl.DataFrame(rows, schema=["user_id", "rec_list"])

In [22]:
res=recall_multi_interest_embedding_valid()

100%|██████████| 711222/711222 [21:19<00:00, 555.94it/s]


In [24]:
res=recall_user_embedding_valid()

100%|██████████| 711222/711222 [04:46<00:00, 2483.21it/s]


In [25]:
valid_recall(res)

User-Recall@10: 0.004159803034559255
User-Recall@20: 0.00703525341556781
User-Recall@30: 0.009200001790552002
User-Recall@40: 0.011129831015739096
User-Recall@50: 0.012691627797011563


In [None]:
res