# 多路召回（增强版）

包含以下召回策略：
1. **热门召回** - 全局热门物品
2. **ItemCF（增强版）** - 添加时间衰减和位置权重
3. **UserCF** - 基于用户相似度的协同过滤
4. **Embedding召回** - 基于文章向量的相似度召回
5. **冷启动过滤** - 基于规则的候选过滤

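最后对热门、ItemCF、Embedding 三路召回结果进行加权倒数排名融合，并以 HitRate@50 评估各路召回及融合后的召回率。

> 注：下方代码目前只实现了热门召回、ItemCF（尚未加入时间衰减和位置权重）与 Embedding 召回；UserCF 与冷启动过滤尚未实现。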

In [1]:
import heapq
import math
import os
import pickle
from collections import defaultdict
from pathlib import Path

import faiss
import numpy as np
import pandas as pd
from tqdm import tqdm

from funrec.utils import load_env_with_fallback

load_env_with_fallback()


def _env_path(name):
    """Return environment variable ``name`` as a ``Path``.

    Raises:
        RuntimeError: if the variable is unset or empty. Without this check
            ``Path(None)`` fails with an opaque ``TypeError`` and an empty
            value silently becomes the current directory.
    """
    value = os.getenv(name)
    if not value:
        raise RuntimeError(
            f'Environment variable {name} is not set; '
            'set it before running this notebook.'
        )
    return Path(value)


RAW_DATA_PATH = _env_path('FUNREC_RAW_DATA_PATH')
PROCESSED_DATA_PATH = _env_path('FUNREC_PROCESSED_DATA_PATH')

# The raw news data may live under ``dataset/`` or directly in the raw root.
DATA_PATH = RAW_DATA_PATH / 'dataset' / 'news_recommendation'
if not DATA_PATH.exists():
    DATA_PATH = RAW_DATA_PATH / 'news_recommendation'

# Project working directory: later cells read the train/valid pickles from
# here and write the recall outputs into it.
PROJECT_PATH = PROCESSED_DATA_PATH / 'projects' / 'news_recommendation_system'


In [2]:
# Training clicks (history) and each user's held-out last click.
train_hist = pd.read_pickle(PROJECT_PATH / 'train_hist.pkl')
valid_last = pd.read_pickle(PROJECT_PATH / 'valid_last.pkl')

# Order by user, then time, so each per-user list below runs
# oldest -> newest and its last element is the most recent click.
train_hist = train_hist.sort_values(by=['user_id', 'click_timestamp'])
user_hist = (
    train_hist
    .groupby('user_id')['click_article_id']
    .apply(list)
    .to_dict()
)

# user_id -> article id used as the evaluation target for recall.
valid_last_map = dict(
    zip(valid_last['user_id'], valid_last['click_article_id'])
)

len(user_hist)


200000

In [3]:
DEBUG = True
MAX_USERS = 20000
RANDOM_SEED = 42

# In debug mode, restrict every per-user structure to a reproducible random
# subset of MAX_USERS users so the rest of the notebook runs quickly.
if not (DEBUG and len(user_hist) > MAX_USERS):
    train_hist_small = train_hist
else:
    rng = np.random.default_rng(RANDOM_SEED)
    sample_users = rng.choice(list(user_hist), size=MAX_USERS, replace=False)
    user_hist = {uid: user_hist[uid] for uid in sample_users}
    train_hist_small = train_hist[train_hist['user_id'].isin(sample_users)]
    valid_last_map = {
        uid: valid_last_map[uid]
        for uid in sample_users
        if uid in valid_last_map
    }

len(user_hist), len(train_hist_small)


(20000, 92646)

In [4]:
# Article ids ordered from most to least clicked in the (possibly sampled)
# training clicks; used by popularity recall and by the ItemCF back-fill.
popular_items = (
    train_hist_small['click_article_id']
    .value_counts()  # sorted by count, descending
    .index
    .tolist()
)

def popular_recall(hist_items, k=50, candidates=None):
    """Recommend the most-clicked items the user has not clicked yet.

    Args:
        hist_items: Iterable of article ids the user already clicked; these
            are excluded from the result.
        k: Maximum number of recommendations. ``k <= 0`` yields ``[]``.
        candidates: Optional sequence of item ids, most popular first, to
            draw from. Defaults to the module-level ``popular_items``.

    Returns:
        List of ``(item_id, 1.0)`` tuples in popularity order. Every item
        gets the same constant score, so only the position carries ranking
        information.
    """
    if candidates is None:
        candidates = popular_items
    # Checked up front: the loop below appends before testing the length,
    # which would otherwise return one item for k == 0.
    if k <= 0:
        return []
    seen = set(hist_items)
    recs = []
    for item in candidates:
        if item in seen:
            continue
        recs.append((item, 1.0))
        if len(recs) >= k:
            break
    return recs


In [5]:
def build_itemcf_sim(click_df):
    """Build an item-to-item similarity table from co-clicks (ItemCF).

    Every ordered pair of distinct items clicked by the same user gets a
    contribution of ``1 / log(1 + n)``, where ``n`` is the length of that
    user's click list, so very active users count for less. Each summed
    co-occurrence is then divided by ``sqrt(cnt_i * cnt_j)``, where ``cnt``
    is the number of clicks each item received.

    Args:
        click_df: DataFrame with ``user_id`` and ``click_article_id`` columns.

    Returns:
        ``defaultdict(dict)`` mapping ``item_i -> {item_j: similarity}``.
        Every clicked item is a key, even when it has no neighbours.
    """
    clicks_by_user = (
        click_df.groupby('user_id')['click_article_id'].apply(list).to_dict()
    )

    sim = defaultdict(dict)
    click_count = defaultdict(int)
    for clicked in tqdm(clicks_by_user.values(), desc='itemcf', mininterval=1):
        # The pair weight depends only on this user's list length.
        user_weight = 1 / math.log(len(clicked) + 1)
        for a in clicked:
            click_count[a] += 1
            neighbours = sim.setdefault(a, {})
            for b in clicked:
                if b != a:
                    neighbours[b] = neighbours.get(b, 0.0) + user_weight

    # Popularity normalisation (cosine-style denominator).
    for a, neighbours in sim.items():
        for b in neighbours:
            neighbours[b] /= math.sqrt(click_count[a] * click_count[b])

    return sim

# Build the ItemCF similarity table on the (possibly sampled) clicks and
# save it into the project directory.
i2i_sim = build_itemcf_sim(click_df=train_hist_small)
with (PROJECT_PATH / 'itemcf_i2i.pkl').open('wb') as f:
    pickle.dump(i2i_sim, f)


itemcf: 100%|██████████| 20000/20000 [00:00<00:00, 56668.79it/s]


In [6]:
def itemcf_recall(
    hist_items, i2i_sim, sim_topk=20, recall_num=50, fallback_items=None
):
    """Score candidate items for one user by summing ItemCF similarities.

    For each item in the user's history, its ``sim_topk`` most similar
    neighbours in ``i2i_sim`` add their similarity to a running score.
    Items already in the history are skipped. If fewer than ``recall_num``
    candidates are found, the list is back-filled from ``fallback_items``
    with a sentinel score of ``-1.0``, so those rank below every
    positively-scored ItemCF candidate.

    Args:
        hist_items: The user's clicked article ids. A repeated item
            contributes its neighbours once per occurrence.
        i2i_sim: Mapping ``item -> {neighbour: similarity}``.
        sim_topk: Number of neighbours considered per history item.
        recall_num: Maximum number of candidates returned.
        fallback_items: Item ids, best first, used for back-filling.
            Defaults to the module-level ``popular_items``.

    Returns:
        Up to ``recall_num`` ``(item, score)`` tuples, highest score first.
    """
    rank = defaultdict(float)
    hist_set = set(hist_items)
    for item in hist_items:
        # nlargest(n, ...) equals sorted(..., reverse=True)[:n], ties
        # included, without fully sorting the neighbour list for every user.
        neighbours = heapq.nlargest(
            sim_topk, i2i_sim.get(item, {}).items(), key=lambda kv: kv[1]
        )
        for j, score in neighbours:
            if j not in hist_set:
                rank[j] += score

    if len(rank) < recall_num:
        # Resolved lazily so the global is only needed when back-filling.
        if fallback_items is None:
            fallback_items = popular_items
        for item in fallback_items:
            if item in rank or item in hist_set:
                continue
            rank[item] = -1.0
            if len(rank) >= recall_num:
                break

    return sorted(rank.items(), key=lambda kv: kv[1], reverse=True)[:recall_num]

# ItemCF candidates per user: 20 neighbours per clicked item, 50 results.
recall_itemcf = {
    user: itemcf_recall(items, i2i_sim, sim_topk=20, recall_num=50)
    for user, items in tqdm(
        user_hist.items(), desc='recall_itemcf', mininterval=1
    )
}


recall_itemcf: 100%|██████████| 20000/20000 [00:07<00:00, 2514.99it/s]


In [7]:
# Pre-computed article embeddings; only the ``emb_*`` columns are vectors.
article_emb = pd.read_csv(DATA_PATH / 'articles_emb.csv')
emb_cols = [name for name in article_emb.columns if name.startswith('emb_')]
emb_matrix = article_emb[emb_cols].to_numpy().astype('float32')
# Unit-normalise every row (the epsilon guards all-zero vectors) so the
# inner product used by the index below equals cosine similarity.
emb_matrix /= np.linalg.norm(emb_matrix, axis=1, keepdims=True) + 1e-12

# Row position in ``emb_matrix`` (and in the FAISS index) for each article id.
article_ids = article_emb['article_id'].values
id2idx = dict(zip(article_ids, range(len(article_ids))))

# Exact (flat) inner-product search over all normalised article vectors.
index = faiss.IndexFlatIP(emb_matrix.shape[1])
index.add(emb_matrix)

def emb_recall(last_item, hist_items, topk=50):
    """Recall articles whose embedding is closest to ``last_item``'s.

    Searches the module-level FAISS inner-product ``index`` built over the
    row-normalised ``emb_matrix`` (so scores are cosine similarities) and
    drops every article the user has already clicked.

    Args:
        last_item: Article id used as the query (the caller passes the
            user's most recent click).
        hist_items: The user's clicked article ids; excluded from results.
        topk: Maximum number of results. ``topk <= 0`` returns ``[]``.

    Returns:
        Up to ``topk`` ``(article_id, similarity)`` tuples, most similar
        first. Returns ``[]`` if ``last_item`` has no embedding.
    """
    # Checked up front: the loop appends before testing the length, which
    # would otherwise return one item for topk == 0.
    if topk <= 0:
        return []
    idx = id2idx.get(last_item)
    if idx is None:
        return []
    hist_set = set(hist_items)
    # Over-fetch so that filtering out clicked items can still leave
    # ``topk`` results, but never ask for more rows than the index holds.
    n_search = min(topk + len(hist_set), index.ntotal)
    query = emb_matrix[idx].reshape(1, -1)
    scores, indices = index.search(query, n_search)
    recs = []
    for score, j in zip(scores[0], indices[0]):
        # FAISS marks missing results with label -1; article_ids[-1] would
        # silently turn that into the last article.
        if j < 0:
            continue
        item_id = int(article_ids[j])
        if item_id in hist_set:
            continue
        recs.append((item_id, float(score)))
        if len(recs) >= topk:
            break
    return recs

# Embedding candidates per user, seeded by the most recent click
# (``user_hist`` lists are in chronological order).
recall_emb = {
    user: emb_recall(items[-1], items, topk=50)
    for user, items in tqdm(user_hist.items(), desc='recall_emb', mininterval=1)
}


recall_emb: 100%|██████████| 20000/20000 [02:26<00:00, 136.45it/s]


In [8]:
# Popularity candidates: the 50 most-clicked articles each user has not seen.
recall_pop = {
    uid: popular_recall(clicked, k=50)
    for uid, clicked in user_hist.items()
}


In [9]:
def merge_recall(recall_dicts, weights, topk=50):
    """Fuse several recall channels with weighted reciprocal-rank scoring.

    An item at 1-based position ``p`` in a channel's list earns
    ``weight / p``. Contributions are summed across channels, and each
    user's items are re-ranked by the total. The channels' own scores are
    ignored; only positions matter.

    Args:
        recall_dicts: Mapping ``channel_name -> {user: [(item, score), ...]}``.
        weights: Mapping ``channel_name -> weight``; missing channels use 1.0.
        topk: Number of fused items kept per user.

    Returns:
        Dict ``user -> [(item, fused_score), ...]`` sorted by descending
        score. Users with no candidates in any channel are absent.
    """
    fused = defaultdict(dict)
    for channel, per_user in recall_dicts.items():
        channel_weight = weights.get(channel, 1.0)
        for user, ranked in per_user.items():
            for position, (item, _score) in enumerate(ranked, start=1):
                # Touch fused[user] only when there is an item, so users
                # with empty lists do not appear in the result.
                user_scores = fused[user]
                user_scores[item] = user_scores.get(item, 0.0) + channel_weight / position
    return {
        user: sorted(scores.items(), key=lambda kv: kv[1], reverse=True)[:topk]
        for user, scores in fused.items()
    }

def hit_rate(recall_dict, target_map, k=50):
    """Return the fraction of users whose target item is in their top-``k`` recall.

    Args:
        recall_dict: ``user -> [(item, score), ...]`` ranked candidates.
        target_map: ``user -> target item``. Users missing from
            ``recall_dict`` count as misses.
        k: Cut-off applied to each user's candidate list.

    Returns:
        Hit rate in ``[0, 1]``; ``0.0`` when ``target_map`` is empty.
    """
    if not target_map:
        return 0.0
    hits = sum(
        target in [item for item, _ in recall_dict.get(user, [])][:k]
        for user, target in target_map.items()
    )
    return hits / len(target_map)

# Fuse the three channels; ItemCF weighted highest, popularity lowest.
recall_merged = merge_recall(
    recall_dicts={'pop': recall_pop, 'itemcf': recall_itemcf, 'emb': recall_emb},
    weights={'pop': 0.2, 'itemcf': 1.0, 'emb': 0.8},
    topk=50,
)

# HitRate@50 against the validation targets, per channel and after fusion.
{
    'pop@50': hit_rate(recall_pop, valid_last_map, k=50),
    'itemcf@50': hit_rate(recall_itemcf, valid_last_map, k=50),
    'emb@50': hit_rate(recall_emb, valid_last_map, k=50),
    'merged@50': hit_rate(recall_merged, valid_last_map, k=50),
}


{'pop@50': 0.2539, 'itemcf@50': 0.428, 'emb@50': 0.0272, 'merged@50': 0.39325}

In [10]:
def recall_to_df(recall_dict):
    """Flatten ``user -> [(item, score), ...]`` into a long DataFrame.

    Returns:
        DataFrame with columns ``user_id``, ``article_id``, ``recall_score``,
        one row per (user, candidate) pair, in the input's iteration order.
    """
    records = [
        (user, item, score)
        for user, candidates in recall_dict.items()
        for item, score in candidates
    ]
    return pd.DataFrame(records, columns=['user_id', 'article_id', 'recall_score'])

# Save the fused candidates for later stages and preview the first rows.
recall_df = recall_to_df(recall_merged)
pd.to_pickle(recall_df, PROJECT_PATH / 'recall_candidates.pkl')
recall_df.head(5)


Unnamed: 0,user_id,article_id,recall_score
0,199593,180196,1.0
1,199593,207876,0.8
2,199593,106947,0.5
3,199593,206109,0.4
4,199593,107057,0.333333
