In [None]:
from pathlib import Path
import sys

# Make the project's `code` directory importable so that `functions.*` resolves.
# The directory is located two levels above the current working directory, so
# this depends on where the notebook kernel was started.
ROOT = Path().resolve().parents[1] / "code"
sys.path.append(str(ROOT))
from functions.valid_recall import valid_recall
from gensim.models import Word2Vec
import numpy as np
import cupy as cp

In [None]:
# Load the trained item2vec model (a gensim Word2Vec) from disk.
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
model = Word2Vec.load("/home/ming/GraduateProject/Data/item2vec.model")

# Vocabulary keys in the model's index order, plus the reverse key -> row lookup.
item_ids = list(model.wv.index_to_key)
item_id2idx = dict(zip(item_ids, range(len(item_ids))))

# One embedding row per key, scaled to unit L2 norm (very important: with unit
# rows, dot products between embeddings are cosine similarities).
item_vecs = np.vstack([model.wv[key] for key in item_ids])
row_norms = np.linalg.norm(item_vecs, axis=1, keepdims=True)
item_vecs = item_vecs / row_norms

# CuPy copy of the normalised matrix for the GPU-side steps below.
item_emb = cp.asarray(item_vecs)

In [None]:
def build_user_embedding(history, item_emb, item_id2idx, max_hist=10):
    """Average the embeddings of a user's trailing history items into a unit vector.

    Works with any array type that supports integer-list indexing, ``mean`` and
    ``sum`` (NumPy or CuPy); the result has the same array type as ``item_emb``.

    Args:
        history: Sequence of item ids. Only the trailing ``max_hist`` entries
            are used (presumably the most recent ones — confirm the ordering
            of the upstream data).
        item_emb: 2-D embedding matrix with one row per item.
        item_id2idx: Mapping from item id to its row index in ``item_emb``.
        max_hist: Number of trailing history entries to consider. ``None`` or
            ``0`` uses the whole history.

    Returns:
        The L2-normalised mean embedding, or ``None`` when the history is empty,
        none of its items are in ``item_id2idx``, or the mean has no usable
        norm (zero or NaN) and therefore cannot be normalised.
    """
    if history is None or len(history) == 0:
        return None

    recent = history[-max_hist:] if max_hist else history
    rows = [item_id2idx[nid] for nid in recent if nid in item_id2idx]
    if not rows:
        return None

    # One gather on the matrix instead of slicing row by row and re-stacking.
    user_emb = item_emb[rows].mean(axis=0)

    # float() brings the scalar to the host so the guard below can branch on it.
    norm = float((user_emb * user_emb).sum()) ** 0.5
    if not norm > 0.0:
        # Dividing by a zero/NaN norm would hand a NaN query to the kNN search.
        return None
    return user_emb / norm


In [None]:
from cuml.neighbors import NearestNeighbors

# Cosine-metric nearest-neighbour index over the item embedding matrix.
# n_neighbors=100 is only the default; queries may pass their own n_neighbors.
knn = NearestNeighbors(metric="cosine", n_neighbors=100)
knn.fit(item_emb)


In [None]:
def recall_user_embedding_valid(train_path=None,pred_path=None, topk=500):
    """Recall up to ``topk`` unseen items for every dev user via user-embedding kNN.

    For each user in the prediction frame, the history stored on the user's
    chronologically last training row is turned into a user embedding
    (``build_user_embedding``), the nearest items are looked up in the fitted
    ``knn`` index, and items already present in the history are dropped.

    Relies on the notebook globals ``item_emb``, ``item_ids``, ``item_id2idx``
    and ``knn``.
    NOTE(review): ``pl``, ``tqdm``, ``read_train_behaviors`` and
    ``read_dev_behaviors`` are not imported in the visible cells — confirm they
    are in scope before running on a fresh kernel.

    Args:
        train_path: Optional parquet file with the training behaviours; when
            falsy, ``read_train_behaviors()`` is used instead.
        pred_path: Optional parquet file with the behaviours to predict for;
            when falsy, ``read_dev_behaviors()`` is used instead.
        topk: Maximum number of recommendations per user.

    Returns:
        ``pl.DataFrame`` with columns ``user_id`` and ``rec_list``. Users with
        no history, or whose history yields no embedding, get an empty list.
    """
    train = pl.read_parquet(train_path) if train_path else read_train_behaviors()
    pred = pl.read_parquet(pred_path) if pred_path else read_dev_behaviors()

    # Keep one row per user: the last one after sorting by time.
    latest = (
        train
        .sort("time")
        .group_by("user_id")
        .agg(pl.all().last())
    )

    # user -> history lookup (history is assumed to be a list of item ids — TODO confirm)
    user_hist = {
        row["user_id"]: row["history"]
        for row in latest.iter_rows(named=True)
    }
    user_list = pred["user_id"].unique().to_list()

    # Over-fetch by 20 so that filtering seen items can still leave topk results,
    # but never request more neighbours than there are indexed items.
    # NOTE(review): a user with more than 20 seen items among the neighbours
    # still ends up with fewer than topk recommendations.
    n_query = min(topk + 20, len(item_ids))

    rows = []

    for uid in tqdm(user_list, total=len(user_list)):
        history = user_hist.get(uid, [])
        if not history:
            rows.append((uid, []))
            continue

        # Must be the CuPy matrix: build_user_embedding works on the GPU copy and
        # the `.get()` below expects a device array back from the kNN query.
        user_emb = build_user_embedding(history, item_emb, item_id2idx)
        if user_emb is None:
            rows.append((uid, []))
            continue

        _, idx = knn.kneighbors(user_emb[None, :], n_neighbors=n_query)
        idx = idx[0].get()

        # Set membership instead of scanning the history list per candidate.
        seen = set(history)
        rec = []
        for i in idx:
            if len(rec) >= topk:
                break
            nid = item_ids[int(i)]
            if nid not in seen:
                rec.append(nid)

        rows.append((uid, rec))

    # Explicit row orientation: each tuple is one (user_id, rec_list) record.
    return pl.DataFrame(rows, schema=["user_id", "rec_list"], orient="row")

In [None]:
def get_users_earliest_hist(data):
    """
    Collect the clicked articles from each user's earliest impression log.

    :param data: frame with columns user_id, time, history, impressions
    :return: frame with columns user_id, hist(list[str])
    """
    # Each impression token is "<article_id>-<label>"; split it once and reuse.
    parts = pl.col("impressions").str.split("-")

    # Step 1: for every user keep only the row with the smallest time.
    first_rows = data.sort("time").group_by("user_id").agg(pl.all().first())

    # Step 2: one row per impression token, keeping only the positives (label == 1).
    tokens = first_rows.select(["user_id", "impressions"]).explode("impressions")
    labelled = tokens.with_columns(
        parts.list.get(0).alias("article_id"),
        parts.list.get(1).cast(pl.Int8).alias("label"),
    )
    positives = labelled.filter(pl.col("label") == 1)

    # Step 3: gather the clicked article ids back into one list per user.
    return positives.group_by("user_id").agg(pl.col("article_id").alias("hist"))


def valid_recall(pred, topk=5):
    """Print the mean per-user recall at cutoffs 10, 20, ..., ``topk * 10``.

    Ground truth is the click list returned by ``get_users_earliest_hist`` for
    the dev behaviours; only users present in both frames are scored.

    NOTE(review): this definition rebinds the name ``valid_recall`` that the
    first cell imports from ``functions.valid_recall``.

    :param pred: frame with columns user_id, rec_list
    :param topk: number of cutoffs to report (cutoff = 10 * step)
    :return: None; one line per cutoff is printed
    """
    truth = get_users_earliest_hist(read_dev_behaviors())
    joined = pred.join(truth, on="user_id", how="inner")

    for step in range(1, topk + 1):
        k = step * 10

        # Truncate every recall list to its first k entries and record the
        # size of the ground-truth list.
        truncated = joined.with_columns(
            pl.col("rec_list").list.slice(0, k).alias("rec_k"),
            pl.col("hist").list.len().alias("gt_len"),
        )

        # One row per recommended item, flagged 1 when it is in the user's
        # ground-truth click list.
        flagged = truncated.explode("rec_k").with_columns(
            pl.col("rec_k").is_in(pl.col("hist")).cast(pl.Int8).alias("hit")
        )

        # Per-user recall = hits / ground-truth size.
        per_user = (
            flagged
            .group_by(["user_id", "gt_len"])
            .agg(pl.sum("hit").alias("hit_cnt"))
            .with_columns(
                (pl.col("hit_cnt") / pl.col("gt_len")).alias("recall")
            )
        )

        # Average over all scored users.
        recall_k = per_user.select(pl.col("recall").mean()).item()

        print(f"User-Recall@{k}: {recall_k}")




In [None]:
# Run the recall with its defaults: behaviours come from the reader helpers
# (no parquet paths given) and at most 500 items are recalled per user.
res = recall_user_embedding_valid()

In [None]:
# Print User-Recall@10 ... @50 (default topk=5) for the recalled lists.
valid_recall(res)