In [1]:
from pathlib import Path
import sys
import pickle
import os
import polars as pl
from collections import defaultdict
from math import sqrt
from heapq import heappush, heapreplace
from tqdm import tqdm
from math import exp
import numpy as np

In [2]:
def processData(actions: pl.DataFrame):
    """Flatten an impression log into one row per clicked item.

    Each entry of the ``impressions`` column is exploded and split on ``"-"``:
    the first piece becomes ``item_id`` and the second piece is cast to an
    unsigned 8-bit ``label``. Only entries whose label equals 1 are kept.

    :param actions: DataFrame with at least user_id, time and impressions columns
    :return: DataFrame with columns user_id, item_id, time (clicked entries only)
    """
    # Split once and reuse the expression for both derived columns.
    pieces = pl.col("impressions").str.split("-")
    item_expr = pieces.list.get(0).alias("item_id")
    label_expr = pieces.list.get(1).cast(pl.UInt8).alias("label")

    exploded = actions.select(["user_id", "time", "impressions"]).explode("impressions")
    labelled = exploded.with_columns([item_expr, label_expr])
    clicked = labelled.filter(pl.col("label") == 1)

    return clicked.select(["user_id", "item_id", "time"])


In [3]:
def read_dev_behaviors():
    """Placeholder loader for the dev behaviors table.

    This notebook never implemented the loader. Failing loudly here gives a
    clear message instead of returning ``None`` and letting the caller crash
    later with an unrelated ``AttributeError``.

    :raises NotImplementedError: always, until a real loader is provided
    """
    raise NotImplementedError(
        "read_dev_behaviors() is a stub: pass pred_path explicitly or implement this loader"
    )
def read_train_behaviors():
    """Placeholder loader for the training behaviors table.

    This notebook never implemented the loader. Failing loudly here gives a
    clear message instead of returning ``None`` and letting the caller crash
    later with an unrelated ``AttributeError``.

    :raises NotImplementedError: always, until a real loader is provided
    """
    raise NotImplementedError(
        "read_train_behaviors() is a stub: pass train_path explicitly or implement this loader"
    )

In [4]:
def build_itemcf_from_history(
    click,
    max_hist_len=30,
    cooc_window=100,
    topk=50,
    tau=86400
):
    """
    Build an ItemCF (item-to-item) similarity table from per-user click sequences.

    Clicks are sorted by time per user and truncated to the most recent
    ``max_hist_len``. Every pair of *distinct* items that lie within
    ``cooc_window`` positions of each other adds ``exp(-dt / tau)`` to the
    pair's co-occurrence weight, where ``dt`` is the absolute time gap in
    seconds. The similarity is ``cooc_ij / sqrt(cnt_i * cnt_j + 5)``.

    :param click: polars DataFrame with columns user_id, item_id, time;
        subtracting two ``time`` values must give an object exposing
        ``total_seconds()`` (e.g. datetimes)
    :param max_hist_len: maximum number of (most recent) clicks used per user
    :param cooc_window: co-occurrence window size (this many positions on each side)
    :param topk: number of most similar items kept per item
    :param tau: time-decay constant, in seconds
    :return: dict[item_id] -> list[(sim, item_j)]; each list is kept in
        heapq (min-heap) order, NOT sorted by similarity
    """

    item_cnt = defaultdict(int)  # click count per item
    cooc = defaultdict(float)    # time-decayed co-occurrence per ordered (item_i, item_j)

    # One row per user with time-ordered item/time lists. This relies on
    # group_by keeping the within-group row order produced by the sort.
    user_hist = (
        click
        .sort(["user_id", "time"])
        .group_by("user_id")
        .agg([
            pl.col("item_id").alias("items"),
            pl.col("time").alias("times"),
        ])
    )

    for row in tqdm(user_hist.iter_rows(named=True), total=user_hist.height):
        items = row["items"][-max_hist_len:]
        times = row["times"][-max_hist_len:]
        L = len(items)

        # Users with fewer than two clicks contribute neither pairs nor counts.
        if L < 2:
            continue

        for i in range(L):
            item_i = items[i]
            ti = times[i]
            item_cnt[item_i] += 1

            left = max(0, i - cooc_window)
            right = min(L, i + cooc_window + 1)

            for j in range(left, right):
                if i == j:
                    continue
                item_j = items[j]
                # A user may click the same item more than once; never pair an
                # item with itself, otherwise it wastes a slot in its own top-k.
                if item_j == item_i:
                    continue

                # Time-decay weight. math.exp on a scalar is much cheaper than
                # np.exp and yields a plain Python float.
                dt = abs(ti - times[j]).total_seconds()
                cooc[(item_i, item_j)] += exp(-dt / tau)

    # Keep the top-k neighbours per item with a bounded min-heap.
    item_sim = defaultdict(list)

    for (i, j), cij in cooc.items():
        # "+ 5" smooths the denominator so rarely clicked pairs are damped.
        sim = cij / sqrt(item_cnt[i] * item_cnt[j] + 5)
        heap = item_sim[i]
        if len(heap) < topk:
            heappush(heap, (sim, j))
        elif sim > heap[0][0]:
            # heap[0] is the weakest neighbour kept so far; replace it.
            heapreplace(heap, (sim, j))

    return item_sim

In [5]:
def recall_itemcf_valid_impression(
    train_path=None,
    pred_path=None,
    sim_path="itemcf_sim.pkl",
    topk=50,
    tau=86400,
):
    """
    ItemCF recall for every impression (row) of the prediction behaviors table.

    The item similarity table is loaded from ``sim_path`` when that file exists;
    otherwise it is built from the training clicks and, if ``sim_path`` is set,
    cached there. For each impression, the neighbours of every history item are
    accumulated into candidate scores, history items are excluded, and the
    ``topk`` best candidates are returned.

    :param train_path: parquet file with training behaviors; when falsy,
        ``read_train_behaviors()`` is used. Only read if the similarity table
        has to be built.
    :param pred_path: parquet file with the behaviors to predict for (needs
        user_id, history, time columns); when falsy, ``read_dev_behaviors()`` is used
    :param sim_path: pickle cache for the similarity table; ``None`` disables caching
    :param topk: number of recalled items kept per impression
    :param tau: time-decay constant (seconds) of the per-impression weight
    :return: polars DataFrame with columns impr_id, user_id, rec_list, rec_score
    """
    # 1. Load the cached ItemCF table, or build it from the training clicks.
    if sim_path and os.path.exists(sim_path):
        # NOTE(security): pickle.load can execute arbitrary code. Only point
        # sim_path at cache files produced by this notebook.
        with open(sim_path, "rb") as f:
            item_sim = pickle.load(f)
    else:
        # Training data is only needed when there is no usable cache.
        train = pl.read_parquet(train_path) if train_path else read_train_behaviors()
        # NOTE(review): the table is built with build_itemcf_from_history's own
        # default topk/tau, not the values passed to this function - confirm intent.
        item_sim = build_itemcf_from_history(processData(train))
        del train
        if sim_path:
            with open(sim_path, "wb") as f:
                pickle.dump(item_sim, f)

    # 2. Read the behaviors to predict for (one row = one impression).
    if pred_path:
        pred = pl.read_parquet(pred_path)
    else:
        pred = read_dev_behaviors()

    pred = (
        pred
        .with_row_index("impr_id")  # with_row_count is deprecated
        .select(["impr_id", "user_id", "history", "time"])
    )

    # Reference time for the decay: the latest impression in the table.
    current = pred.select(pl.col("time").max()).item().timestamp()

    results = []

    for row in pred.iter_rows(named=True):
        # A null history is treated as "no clicks" instead of raising TypeError.
        hist = row["history"] or []
        seen = set(hist)

        # Constant per impression: it scales the scores but does not change the ranking.
        time_w = exp(-(current - row["time"].timestamp()) / tau)

        scores = defaultdict(float)
        for item in hist:
            for sim, j in item_sim.get(item, ()):
                # Never recommend something the user already clicked.
                if j not in seen:
                    scores[j] += sim * time_w

        # Best candidates first; an impression without candidates gets empty lists.
        rec = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:topk]

        results.append(
            (
                row["impr_id"],
                row["user_id"],
                [i for i, _ in rec],
                [s for _, s in rec],
            )
        )

    return pl.DataFrame(
        results,
        schema=["impr_id", "user_id", "rec_list", "rec_score"],
        orient="row",  # rows are tuples; being explicit avoids the orientation warning
    )


In [6]:
# Build (or load the cached) ItemCF table from the training file, then recall
# candidates for every impression of the prediction file.
res = recall_itemcf_valid_impression(
    train_path='behaviors.parquet',
    pred_path='behaviors_1.parquet',
)

100%|██████████| 711222/711222 [01:56<00:00, 6127.46it/s]
  .with_row_count("impr_id")
  return pl.DataFrame(


In [7]:
def get_impression_gt(data):
    """
    Impression-level ground truth: the clicked items of each impression.

    Every input row gets a positional ``impr_id``. Its ``impressions`` entries
    are exploded and split on ``"-"`` into a news id and a label, and the ids
    with label 1 are collected per impression. Impressions without any
    label-1 entry do not appear in the result.

    :param data: polars DataFrame with user_id and impressions columns
    :return: DataFrame with columns impr_id, user_id, gt (list[str])
    """
    # Split once and reuse the expression for both derived columns.
    pieces = pl.col("impressions").str.split("-")

    gt = (
        data
        .with_row_index("impr_id")  # with_row_count is deprecated
        .select(["impr_id", "user_id", "impressions"])
        .explode("impressions")
        .with_columns([
            pieces.list.get(0).alias("news_id"),
            pieces.list.get(1).cast(pl.Int8).alias("label"),
        ])
        .filter(pl.col("label") == 1)
        .group_by(["impr_id", "user_id"])
        .agg(pl.col("news_id").alias("gt"))
    )
    return gt


def valid_recall_impression(pred, topk=5, gt_path='behaviors_1.parquet'):
    """
    Print impression-level Recall@k for k = 10, 20, ..., ``topk * 10``.

    For each impression, recall is (# ground-truth items among the first k
    recommendations) / (# ground-truth items); the printed value is the mean
    over all impressions present in both ``pred`` and the ground truth.

    :param pred: DataFrame with columns impr_id, user_id, rec_list
    :param topk: number of cut-offs to report (each step adds 10 to k), not k itself
    :param gt_path: parquet file the ground truth is derived from; ``impr_id``
        is positional, so it must be the same file ``pred`` was generated from
    :return: None (results are printed)
    """
    gt = get_impression_gt(pl.read_parquet(gt_path))

    data = (
        pred
        .join(gt, on=["impr_id", "user_id"], how="inner")
        # The ground-truth size does not depend on k, so compute it once.
        .with_columns(pl.col("gt").list.len().alias("gt_len"))
    )

    for step in range(1, topk + 1):
        k = step * 10

        recall_k = (
            data
            # 1. Truncate the recommendation list to the first k items.
            .with_columns(pl.col("rec_list").list.slice(0, k).alias("rec_k"))
            # 2. One row per recommended item.
            .explode("rec_k")
            # 3. Flag hits against the impression's ground truth.
            .with_columns(
                pl.col("rec_k").is_in(pl.col("gt")).cast(pl.Int8).alias("hit")
            )
            # 4. Aggregate back to impression level.
            .group_by(["impr_id", "user_id", "gt_len"])
            .agg(pl.sum("hit").alias("hit_cnt"))
            # 5. Per-impression recall.
            .with_columns(
                (pl.col("hit_cnt") / pl.col("gt_len")).alias("recall")
            )
            # 6. Mean over all impressions.
            .select(pl.col("recall").mean())
            .item()
        )

        print(f"Recall@{k}", recall_k)


In [8]:
# Report Recall@10 ... Recall@50 for the recalled lists (default of 5 cut-offs).
valid_recall_impression(pred=res)

  .with_row_count("impr_id")


Recall@10 0.00041214529236182956
Recall@20 0.0007465628730334807
Recall@30 0.001134424261476587
Recall@40 0.0014065993786479677
Recall@50 0.0016875200005241778


In [9]:
# Preview the first rows of the recall output (impr_id, user_id, rec_list, rec_score).
res.head()

impr_id,user_id,rec_list,rec_score
i64,str,list[str],list[f64]
0,"""U134050""","[""N25215"", ""N101760"", … ""N94157""]","[0.21786, 0.21786, … 0.001495]"
1,"""U254959""","[""N5997"", ""N62199"", … ""N35663""]","[0.173019, 0.173019, … 0.017575]"
2,"""U499841""","[""N81885"", ""N47196"", … ""N122628""]","[0.09102, 0.054119, … 0.010571]"
3,"""U107107""","[""N15166"", ""N44094"", … ""N55261""]","[0.175347, 0.139743, … 0.024534]"
4,"""U492344""","[""N15166"", ""N44094"", … ""N86745""]","[0.169587, 0.135152, … 0.010729]"
