# 7. 双塔召回（Two-Tower + In-batch Negatives + FAISS）

这一节把基础版项目缺失的**双塔召回训练**补齐，并给出一条更贴近工业界的端到端流程：

1) 构造训练样本（序列 next-item）

2) 训练双塔（用户塔 / 物品塔）

3) 导出 item embedding，建立 FAISS 索引

4) 离线评估 Recall（HitRate@K / NDCG@K）

5) 产出召回候选集（给后续排序模型训练用）

## 面试/实习加分点（建议你在 README / 简历里写）

- **双塔训练**：in-batch negatives（对比学习 / InfoNCE）
- **负样本策略**：in-batch 负样本 +（可扩展：hard negative）
- **向量检索**：FAISS IndexFlatIP / IVF（可扩展）
- **一致性**：离线 item 向量 + 在线 user 向量 + ANN 检索
- **评估**：HitRate@K / NDCG@K 指标 + 过滤已看历史（避免重复推荐已点击内容）；验证集取每个用户的最后一次点击，避免数据泄漏

## 产物位置

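默认写入：`tmp/projects/news_recommendation_system/artifacts/two_tower/dssm_inbatch/`（召回候选集 `recall_candidates_two_tower.pkl` 写在项目目录 `tmp/projects/news_recommendation_system/` 下）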


In [1]:
import os
import pickle
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Sequence, Tuple

import numpy as np
import pandas as pd
import tensorflow as tf
import faiss
from dotenv import find_dotenv, load_dotenv
from tqdm import tqdm

# Only show TensorFlow log records at ERROR level or above.
tf.get_logger().setLevel('ERROR')


def find_repo_root(start: Path) -> Path:
    """Return the nearest directory at or above ``start`` that marks a repo root.

    A directory qualifies when it contains ``pyproject.toml`` or ``.git``.
    Only ``start`` and its first nine ancestors are inspected; when none of
    them qualifies, ``start`` itself is returned unchanged.
    """
    markers = ('pyproject.toml', '.git')
    # ``start`` followed by its ancestors, capped at ten directories in total.
    lineage = [start, *start.parents][:10]
    for directory in lineage:
        if any((directory / marker).exists() for marker in markers):
            return directory
    return start


# Resolve the repository root starting from the current working directory.
REPO_ROOT = find_repo_root(Path.cwd())

# Load a .env file when one can be located from the working directory.
dotenv_path = find_dotenv(usecwd=True)
if dotenv_path:
    load_dotenv(dotenv_path)

# Default data locations (kept consistent with the baseline notebook); they
# only take effect when the variables are not already set.
os.environ.setdefault('FUNREC_RAW_DATA_PATH', str(REPO_ROOT.joinpath('data')))
os.environ.setdefault('FUNREC_PROCESSED_DATA_PATH', str(REPO_ROOT.joinpath('tmp')))

# Both variables are guaranteed to exist after the setdefault calls above.
RAW_DATA_PATH = Path(os.environ['FUNREC_RAW_DATA_PATH'])
PROCESSED_DATA_PATH = Path(os.environ['FUNREC_PROCESSED_DATA_PATH'])

# Prefer <raw>/dataset/news_recommendation and fall back to <raw>/news_recommendation.
DATA_PATH = RAW_DATA_PATH.joinpath('dataset', 'news_recommendation')
if not DATA_PATH.exists():
    DATA_PATH = RAW_DATA_PATH.joinpath('news_recommendation')

PROJECT_PATH = PROCESSED_DATA_PATH.joinpath('projects', 'news_recommendation_system')
ARTIFACTS_DIR = PROJECT_PATH.joinpath('artifacts', 'two_tower', 'dssm_inbatch')
ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)

DATA_PATH, PROJECT_PATH, ARTIFACTS_DIR


(PosixPath('/Users/wangjunfei/Desktop/fun-rec/data/dataset/news_recommendation'),
 PosixPath('/Users/wangjunfei/Desktop/fun-rec/tmp/projects/news_recommendation_system'),
 PosixPath('/Users/wangjunfei/Desktop/fun-rec/tmp/projects/news_recommendation_system/artifacts/two_tower/dssm_inbatch'))

## 1) 数据准备：离线切分（每个用户最后一次点击做验证）

如果你已经跑过基础版 2.baseline.ipynb，会存在：

- `tmp/projects/news_recommendation_system/train_hist.pkl`
- `tmp/projects/news_recommendation_system/valid_last.pkl`

这里会优先复用；否则自动从原始点击日志构建。


In [2]:
def build_offline_split_last_click(click_df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Split a click log into per-user history and each user's held-out last click.

    Rows are ordered by ``user_id`` then ``click_timestamp``; the final row of
    every user becomes the validation click and the remaining rows form the
    history. A held-out click is kept only when its user still has at least
    one history row. Both returned frames have a fresh 0..n-1 index.

    Returns:
        ``(history, last_click)`` data frames.
    """
    ordered = click_df.sort_values(['user_id', 'click_timestamp'])
    held_out = ordered.groupby('user_id').tail(1)
    history = ordered[~ordered.index.isin(held_out.index)]

    # Users whose only click was held out have no history to predict from.
    users_with_history = history['user_id'].unique()
    held_out = held_out[held_out['user_id'].isin(users_with_history)]
    return history.reset_index(drop=True), held_out.reset_index(drop=True)


train_hist_path = PROJECT_PATH / 'train_hist.pkl'
valid_last_path = PROJECT_PATH / 'valid_last.pkl'

# Rebuild the split from the raw click log only when a cached file is missing;
# otherwise reuse the pickles already present in the project directory.
if not (train_hist_path.exists() and valid_last_path.exists()):
    train_click = pd.read_csv(DATA_PATH / 'train_click_log.csv')
    train_hist, valid_last = build_offline_split_last_click(train_click)
    train_hist.to_pickle(train_hist_path)
    valid_last.to_pickle(valid_last_path)
else:
    train_hist = pd.read_pickle(train_hist_path)
    valid_last = pd.read_pickle(valid_last_path)

articles = pd.read_csv(DATA_PATH / 'articles.csv')

train_hist.head(), valid_last.head(), articles.head()


(   user_id  click_article_id  click_timestamp  click_environment  \
 0        0             30760    1508211672520                  4   
 1        1            289197    1508211316889                  4   
 2        2             36162    1508211438695                  4   
 3        3             50644    1508211359672                  4   
 4        4             42567    1508211625466                  4   
 
    click_deviceGroup  click_os  click_country  click_region  \
 0                  1        17              1            25   
 1                  1        17              1            25   
 2                  3        20              1            25   
 3                  3         2              1            25   
 4                  1        12              1            16   
 
    click_referrer_type  
 0                    2  
 1                    6  
 2                    2  
 3                    2  
 4                    1  ,
    user_id  click_article_id  click_time

In [3]:
# Optional DEBUG down-sampling of users (faster iteration).
DEBUG = True
MAX_USERS = 20000
SEED = 42

if DEBUG:
    rng = np.random.default_rng(SEED)
    users = train_hist['user_id'].unique()
    if len(users) > MAX_USERS:
        # Draw MAX_USERS distinct users and keep only their rows in both frames.
        sample_users = rng.choice(users, size=MAX_USERS, replace=False)
        train_hist = train_hist.loc[train_hist['user_id'].isin(sample_users)]
        valid_last = valid_last.loc[valid_last['user_id'].isin(sample_users)]

train_hist['user_id'].nunique(), len(train_hist), len(valid_last)


(20000, 92646, 20000)

## 2) ID 编码与样本构造

我们需要把原始 `user_id/article_id/category_id` 编码成从 1 开始的连续整数：

- 0 预留给 padding/unknown
- embedding table 的 vocab_size = num_classes + 1

训练样本采用 next-item：

- 输入：用户 id + 历史点击序列（定长 padding）
- label：下一个点击 item（物品塔输入）

为避免样本爆炸，这里提供 `MAX_SAMPLES_PER_USER` 控制每个用户最多生成多少条样本（面试时可解释：工程上可用采样/重放/窗口裁剪）。


In [4]:
@dataclass(frozen=True)
class IdMap:
    """Mapping from raw ids to contiguous integer codes for embedding lookup.

    ``classes_[i]`` is encoded as ``i + offset``; every value that is not in
    ``classes_`` is encoded as ``unknown_value``.
    """

    name: str
    classes_: np.ndarray
    offset: int = 1
    unknown_value: int = 0

    @property
    def vocab_size(self) -> int:
        """Return ``len(classes_) + offset``, the embedding-table size to allocate."""
        return int(len(self.classes_) + self.offset)

    def _lookup_index(self) -> pd.Index:
        """Return the ``pd.Index`` over ``classes_``, building it on first use.

        ``transform`` may be called once per sample; rebuilding the index (and
        its hash table) on every call is wasted work for a large vocabulary,
        so the index is cached on the instance.
        """
        cached = self.__dict__.get('_index_cache')
        if cached is None:
            cached = pd.Index(self.classes_)
            # The dataclass is frozen, so bypass its __setattr__ for the cache.
            object.__setattr__(self, '_index_cache', cached)
        return cached

    def __getstate__(self):
        """Pickle only the declared fields; the cached index is rebuilt lazily."""
        state = dict(self.__dict__)
        state.pop('_index_cache', None)
        return state

    def transform(self, values) -> np.ndarray:
        """Encode ``values`` (scalar, list or array) into an ``int32`` array.

        The result has the same shape as ``np.asarray(values)``.
        """
        arr = np.asarray(values)
        positions = self._lookup_index().get_indexer(arr.reshape(-1))
        codes = positions.astype(np.int64) + self.offset
        # get_indexer reports -1 for values that are not in the index.
        codes[positions < 0] = self.unknown_value
        return codes.reshape(arr.shape).astype(np.int32)

    @classmethod
    def fit(cls, name: str, values, offset: int = 1) -> 'IdMap':
        """Build an ``IdMap`` from the distinct entries of ``values``.

        Classes are sorted when the values are mutually comparable; otherwise
        they keep their first-seen order.
        """
        uniq = pd.unique(pd.Series(list(values)))
        try:
            uniq = np.array(sorted(uniq))
        except (TypeError, ValueError):
            # Values that cannot be ordered against each other: keep as found.
            uniq = np.array(list(uniq))
        return cls(name=name, classes_=uniq, offset=offset)


def pad_left(seqs: List[List[int]], max_len: int, pad_value: int = 0) -> np.ndarray:
    """Left-pad (and left-truncate) sequences into a ``(len(seqs), max_len)`` int32 array.

    Each sequence keeps its last ``max_len`` elements, right-aligned in its
    row; unused leading slots and rows for empty sequences hold ``pad_value``.
    """
    padded = np.full((len(seqs), max_len), pad_value, dtype=np.int32)
    # Iterating a 2-D array yields row views, so writes land in ``padded``.
    for row, seq in zip(padded, seqs):
        if seq:
            kept = seq[-max_len:]
            row[-len(kept):] = np.asarray(kept, dtype=np.int32)
    return padded


# Article -> category lookup (ids missing from it are read as 0 via .get(..., 0)).
item_to_cat = dict(
    zip(
        articles['article_id'].astype(int),
        articles['category_id'].astype(int),
    )
)

# Encoders: item and category vocabularies come from `articles`, so they stay
# stable regardless of which users were sampled.
user_id_map = IdMap.fit(name='user_id', values=train_hist['user_id'].unique(), offset=1)
item_id_map = IdMap.fit(name='article_id', values=articles['article_id'].unique(), offset=1)
cat_id_map = IdMap.fit(name='category_id', values=articles['category_id'].unique(), offset=1)

user_id_map.vocab_size, item_id_map.vocab_size, cat_id_map.vocab_size


(20001, 364048, 462)

In [5]:
# ==================== Sample construction (sliding-window next-item) ====================
MAX_SEQ_LEN = 30
MIN_HIST = 1
MAX_SAMPLES_PER_USER = 20 if DEBUG else 200

train_hist_sorted = train_hist.sort_values(['user_id', 'click_timestamp'])
user_hist = (
    train_hist_sorted.groupby('user_id')['click_article_id']
    .apply(list)
    .to_dict()
)

samples_user: List[int] = []
samples_hist: List[List[int]] = []
samples_item: List[int] = []
samples_cat: List[int] = []

rng = np.random.default_rng(SEED)

for u, seq in tqdm(user_hist.items(), desc='build_samples'):
    # A user needs at least MIN_HIST history clicks plus one target click.
    if len(seq) <= MIN_HIST:
        continue
    # Every position after the first click is a candidate prediction target.
    positions = list(range(1, len(seq)))
    if len(positions) > MAX_SAMPLES_PER_USER:
        # Too many targets: sample MAX_SAMPLES_PER_USER of them from the most
        # recent 3 * MAX_SAMPLES_PER_USER positions (biased towards recency).
        tail = positions[-3 * MAX_SAMPLES_PER_USER:]
        positions = sorted(rng.choice(tail, size=MAX_SAMPLES_PER_USER, replace=False).tolist())

    for t in positions:
        # History window: up to MAX_SEQ_LEN clicks immediately preceding the target.
        hist = seq[max(0, t - MAX_SEQ_LEN):t]
        if len(hist) >= MIN_HIST:
            target_item = int(seq[t])
            target_cat = int(item_to_cat.get(target_item, 0))
            samples_user.append(int(u))
            samples_hist.append([int(x) for x in hist])
            samples_item.append(target_item)
            samples_cat.append(target_cat)

print('num_samples:', len(samples_user))
print('num_users:', len(set(samples_user)))


build_samples: 100%|██████████| 20000/20000 [00:00<00:00, 150475.32it/s]

num_samples: 65809
num_users: 11924





## 3) 双塔模型（In-batch Negatives / InfoNCE）

核心做法：

- 一个 batch 内，(user_i, item_i) 作为正样本
- 其它 (user_i, item_j) 作为负样本（in-batch negatives）
- 用 softmax 交叉熵做对比学习（可选双向：user→item + item→user）

注意：如果 batch 内存在重复 item，会引入“伪负样本”；工业里会用去重 batch / 多正样本 loss / sampled softmax 等手段缓解。


In [6]:
def masked_mean(emb: tf.Tensor, ids: tf.Tensor) -> tf.Tensor:
    """Average ``emb`` over axis 1, counting only positions whose id is non-zero.

    ``emb`` is ``[B, L, D]`` and ``ids`` is ``[B, L]``; the result is ``[B, D]``.
    Rows whose ids are all zero yield a zero vector, because the divisor is
    clamped to at least 1.
    """
    # [B, L, 1] float mask: 1.0 where the id is non-zero, 0.0 elsewhere.
    keep = tf.expand_dims(tf.cast(tf.not_equal(ids, 0), tf.float32), axis=-1)
    total = tf.reduce_sum(emb * keep, axis=1)  # [B, D]
    count = tf.reduce_sum(keep, axis=1)  # [B, 1]
    return total / tf.maximum(count, 1.0)


def build_two_tower_model(
    user_vocab_size: int,
    item_vocab_size: int,
    cat_vocab_size: int,
    max_seq_len: int = 30,
    emb_dim: int = 32,
    dnn_units: Sequence[int] = (128, 64, 32),
    temperature: float = 0.05,
):
    """Build the in-batch-negative two-tower model and its two sub-towers.

    The user tower concatenates the user-id embedding with the masked mean of
    the history item embeddings; the item tower concatenates the item-id
    embedding with the category embedding. Each tower runs its features
    through Dense(relu) + Dropout(0.1) layers and L2-normalises the result.
    The item embedding table is shared between history items and target items.

    Args:
        user_vocab_size: Rows of the user embedding table.
        item_vocab_size: Rows of the item embedding table.
        cat_vocab_size: Rows of the category embedding table.
        max_seq_len: Fixed length of the padded click-history input.
        emb_dim: Width of the user/item embeddings; the category embedding
            uses ``max(4, emb_dim // 4)``.
        dnn_units: Hidden sizes applied in both towers (list or tuple). The
            last entry is the dimension of the output vectors.
        temperature: Positive divisor applied to the similarity logits.

    Returns:
        ``(model, user_tower, item_tower)``. ``model`` outputs the ``[B, B]``
        matrix of user-by-item similarities divided by ``temperature``; the
        towers output the normalised user and item vectors.

    Raises:
        ValueError: If ``dnn_units`` is empty or ``temperature`` is not positive.
    """
    hidden_units = list(dnn_units)
    if not hidden_units:
        # Without hidden layers the two towers end with different widths and
        # cannot be multiplied together.
        raise ValueError('dnn_units must contain at least one layer size')
    if temperature <= 0:
        raise ValueError('temperature must be positive')

    def _tower(features, name):
        # Dense + Dropout stack followed by L2 normalisation of the output.
        hidden = features
        for units in hidden_units:
            hidden = tf.keras.layers.Dense(units, activation='relu')(hidden)
            hidden = tf.keras.layers.Dropout(0.1)(hidden)
        return tf.keras.layers.Lambda(lambda t: tf.nn.l2_normalize(t, axis=1), name=name)(hidden)

    # Inputs
    user_id_inp = tf.keras.layers.Input(shape=(), dtype=tf.int32, name='user_id')
    hist_item_inp = tf.keras.layers.Input(shape=(max_seq_len,), dtype=tf.int32, name='hist_article_id')
    target_item_inp = tf.keras.layers.Input(shape=(), dtype=tf.int32, name='article_id')
    target_cat_inp = tf.keras.layers.Input(shape=(), dtype=tf.int32, name='category_id')

    # Embeddings
    user_emb_layer = tf.keras.layers.Embedding(user_vocab_size, emb_dim, name='emb_user')
    item_emb_layer = tf.keras.layers.Embedding(item_vocab_size, emb_dim, name='emb_item')
    cat_emb_layer = tf.keras.layers.Embedding(cat_vocab_size, max(4, emb_dim // 4), name='emb_cat')

    # User tower
    user_id_emb = tf.keras.layers.Flatten()(user_emb_layer(user_id_inp))  # [B, D]
    hist_item_emb = item_emb_layer(hist_item_inp)  # [B, L, D]
    hist_mean = tf.keras.layers.Lambda(lambda x: masked_mean(x[0], x[1]))([hist_item_emb, hist_item_inp])
    user_feat = tf.keras.layers.Concatenate()([user_id_emb, hist_mean])
    user_vec = _tower(user_feat, 'user_vec')

    # Item tower
    item_id_emb = tf.keras.layers.Flatten()(item_emb_layer(target_item_inp))
    cat_emb = tf.keras.layers.Flatten()(cat_emb_layer(target_cat_inp))
    item_feat = tf.keras.layers.Concatenate()([item_id_emb, cat_emb])
    item_vec = _tower(item_feat, 'item_vec')

    # In-batch logits: [B, D] x [B, D]^T => [B, B]
    logits = tf.keras.layers.Lambda(lambda z: tf.matmul(z[0], z[1], transpose_b=True) / temperature, name='logits')(
        [user_vec, item_vec]
    )

    model = tf.keras.Model(
        inputs={'user_id': user_id_inp, 'hist_article_id': hist_item_inp, 'article_id': target_item_inp, 'category_id': target_cat_inp},
        outputs=logits,
        name='two_tower_inbatch',
    )
    user_tower = tf.keras.Model(inputs={'user_id': user_id_inp, 'hist_article_id': hist_item_inp}, outputs=user_vec, name='user_tower')
    item_tower = tf.keras.Model(inputs={'article_id': target_item_inp, 'category_id': target_cat_inp}, outputs=item_vec, name='item_tower')
    return model, user_tower, item_tower


# Model hyper-parameters.
EMB_DIM = 32
DNN_UNITS = [128, 64, 32]
TEMPERATURE = 0.05

# Vocabulary sizes come from the fitted id maps (first three positional arguments).
model, user_tower, item_tower = build_two_tower_model(
    user_id_map.vocab_size,
    item_id_map.vocab_size,
    cat_id_map.vocab_size,
    max_seq_len=MAX_SEQ_LEN,
    emb_dim=EMB_DIM,
    dnn_units=DNN_UNITS,
    temperature=TEMPERATURE,
)

model.summary()


Model: "two_tower_inbatch"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 user_id (InputLayer)        [(None,)]                    0         []                            
                                                                                                  
 hist_article_id (InputLaye  [(None, 30)]                 0         []                            
 r)                                                                                               
                                                                                                  
 article_id (InputLayer)     [(None,)]                    0         []                            
                                                                                                  
 category_id (InputLayer)    [(None,)]                    0         []            

In [None]:
def inbatch_symmetric_loss(y_true, logits):
    """Symmetric in-batch softmax cross-entropy over a square ``logits`` matrix.

    Row ``i`` is scored against label ``i`` (user -> item), and the transposed
    matrix is scored the same way (item -> user). The two per-example losses
    are summed, averaged over the batch and halved. ``y_true`` is ignored.
    """
    # The positive for row i sits on the diagonal, i.e. at column i.
    diagonal = tf.range(tf.shape(logits)[0])
    xent = tf.nn.sparse_softmax_cross_entropy_with_logits
    user_to_item = xent(labels=diagonal, logits=logits)
    item_to_user = xent(labels=diagonal, logits=tf.transpose(logits))
    return tf.reduce_mean(user_to_item + item_to_user) / 2.0


try:
    # On Mac M1/M2 the new optimizer in Keras v2.11+ can be noticeably slower;
    # the legacy implementation is used as the more stable choice.
    optimizer = tf.keras.optimizers.legacy.Adam(learning_rate=2e-4)
except Exception:
    # Fall back to the regular Adam when the legacy one cannot be created.
    optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4)

model.compile(optimizer=optimizer, loss=inbatch_symmetric_loss)

# Assemble the training inputs: encode raw ids and left-pad the histories.
X_user = user_id_map.transform(np.asarray(samples_user, dtype=np.int64))
X_hist = pad_left(
    [item_id_map.transform(np.asarray(s, dtype=np.int64)).tolist() for s in samples_hist],
    max_len=MAX_SEQ_LEN,
)
X_item = item_id_map.transform(np.asarray(samples_item, dtype=np.int64))
X_cat = cat_id_map.transform(np.asarray(samples_cat, dtype=np.int64))

train_X = dict(
    user_id=X_user,
    hist_article_id=X_hist,
    article_id=X_item,
    category_id=X_cat,
)

# Placeholder targets: the loss ignores y_true, but fit() is given an array
# with one entry per sample.
dummy_y = np.zeros(X_user.shape[0], dtype=np.float32)

BATCH_SIZE = 1024
EPOCHS = 3

history = model.fit(x=train_X, y=dummy_y, batch_size=BATCH_SIZE, epochs=EPOCHS, verbose=1)




Epoch 1/3
Epoch 2/3
Epoch 3/3


: 

## 4) 导出 item embedding + FAISS 建索引

这里用内积检索（IndexFlatIP），并做 L2 normalize，让内积等价于 cosine。

你可以在面试里扩展：

- Flat → IVF / HNSW 加速
- offline build + online serving（增量更新）


In [None]:
# Inputs for every article (the raw article_id doubles as the FAISS index id).
all_items_raw = articles['article_id'].astype(int).unique()
all_items_enc = item_id_map.transform(all_items_raw)
all_cats_raw = np.fromiter(
    (item_to_cat.get(int(i), 0) for i in all_items_raw),
    dtype=np.int64,
    count=len(all_items_raw),
)
all_cats_enc = cat_id_map.transform(all_cats_raw)

item_embs = item_tower.predict(
    {'article_id': all_items_enc, 'category_id': all_cats_enc},
    batch_size=4096,
    verbose=0,
)
item_embs = item_embs.astype('float32')

# In-place L2 normalisation so that inner product equals cosine similarity.
faiss.normalize_L2(item_embs)

# Exact inner-product index wrapped so that results carry raw article ids.
flat_ip = faiss.IndexFlatIP(item_embs.shape[1])
index = faiss.IndexIDMap2(flat_ip)
index.add_with_ids(item_embs, all_items_raw.astype('int64'))

faiss.write_index(index, str(ARTIFACTS_DIR / 'faiss_index.bin'))
np.save(ARTIFACTS_DIR / 'item_embeddings.npy', item_embs)

print('indexed items:', index.ntotal)


## 5) 离线召回评估（过滤已看历史）


In [None]:
def evaluate_recall_at_k(
    user_hist: Dict[int, List[int]],
    valid_last: pd.DataFrame,
    topk: int = 20,
):
    """Compute HitRate@K and NDCG@K of the two-tower recall on held-out clicks.

    For every row of ``valid_last`` whose user has a non-empty entry in
    ``user_hist``, the user vector is computed from the last ``MAX_SEQ_LEN``
    history clicks, the FAISS ``index`` is searched, items already in the
    user's history are removed, and the first ``topk`` remaining items are
    compared with the held-out ``click_article_id``.

    Args:
        user_hist: Raw user id -> chronologically ordered raw article ids.
        valid_last: Frame with ``user_id`` and ``click_article_id`` columns.
        topk: Number of recommendations kept per user.

    Returns:
        Dict with ``hit_rate@{topk}``, ``ndcg@{topk}`` and ``num_users_eval``.
        All three are zero when no user can be evaluated.
    """
    users = valid_last['user_id'].astype(int).tolist()
    targets = valid_last['click_article_id'].astype(int).tolist()

    # Evaluation samples: users without any training history are skipped.
    X_u_raw = []
    X_hist_raw = []
    y_raw = []
    for u, t in zip(users, targets):
        seq = user_hist.get(u)
        if not seq:
            continue
        X_u_raw.append(u)
        X_hist_raw.append(seq[-MAX_SEQ_LEN:])
        y_raw.append(t)

    if not X_u_raw:
        # Nothing to score; avoid running the model on an empty batch.
        return {
            f'hit_rate@{topk}': 0.0,
            f'ndcg@{topk}': 0.0,
            'num_users_eval': 0,
        }

    X_u = user_id_map.transform(np.asarray(X_u_raw, dtype=np.int64))
    X_hist = pad_left([item_id_map.transform(np.asarray(s, dtype=np.int64)).tolist() for s in X_hist_raw], max_len=MAX_SEQ_LEN)

    user_embs = user_tower.predict({'user_id': X_u, 'hist_article_id': X_hist}, batch_size=4096, verbose=0).astype('float32')
    faiss.normalize_L2(user_embs)

    # Retrieve extra neighbours so that topk items can remain after the
    # history filter.
    # NOTE(review): the filter uses the user's full history, which may be
    # longer than MAX_SEQ_LEN, so fewer than topk items can survive.
    search_k = topk + MAX_SEQ_LEN + 10
    D, I = index.search(user_embs, search_k)

    hit = 0
    ndcg = 0.0
    for u, target, retrieved in zip(X_u_raw, y_raw, I.tolist()):
        hist_set = set(user_hist.get(u, []))
        recs = []
        for item_id in retrieved:
            # FAISS fills missing results with -1. Index ids are raw article
            # ids, so 0 is a legitimate item and must not be dropped.
            if item_id < 0 or item_id in hist_set:
                continue
            recs.append(item_id)
            if len(recs) >= topk:
                break
        target = int(target)
        if target in recs:
            hit += 1
            rank = recs.index(target)
            ndcg += 1.0 / np.log2(rank + 2)

    return {
        f'hit_rate@{topk}': hit / max(1, len(X_u_raw)),
        f'ndcg@{topk}': ndcg / max(1, len(X_u_raw)),
        'num_users_eval': len(X_u_raw),
    }


# Offline recall metrics at K=20 on the held-out last clicks.
metrics = evaluate_recall_at_k(user_hist, valid_last, topk=20)
metrics


## 6) 生成召回候选集（给排序训练用）

输出列对齐基础版的 `recall_candidates.pkl`：

- `user_id`
- `article_id`
- `recall_score`
- `recall_rank`

后续在 `8.deep_ranking.ipynb` 里可以把该候选集与 ItemCF/热门等进行融合，或做 rerank 训练样本。


In [None]:
TOPK_CANDIDATES = 100
# Retrieve extra neighbours so enough items remain after the history filter.
SEARCH_K = TOPK_CANDIDATES + MAX_SEQ_LEN + 10

rows = []
users_all = list(user_hist.keys())

batch_size = 4096
for start in tqdm(range(0, len(users_all), batch_size), desc='recall_all_users'):
    end = min(start + batch_size, len(users_all))
    u_raw_batch = users_all[start:end]
    hist_raw_batch = [user_hist[u][-MAX_SEQ_LEN:] for u in u_raw_batch]

    X_u = user_id_map.transform(np.asarray(u_raw_batch, dtype=np.int64))
    X_hist = pad_left([item_id_map.transform(np.asarray(s, dtype=np.int64)).tolist() for s in hist_raw_batch], max_len=MAX_SEQ_LEN)

    user_embs = user_tower.predict({'user_id': X_u, 'hist_article_id': X_hist}, batch_size=4096, verbose=0).astype('float32')
    faiss.normalize_L2(user_embs)
    D, I = index.search(user_embs, SEARCH_K)

    for local_i, u in enumerate(u_raw_batch):
        hist_set = set(user_hist.get(u, []))
        rank = 0
        for item_id, score in zip(I[local_i].tolist(), D[local_i].tolist()):
            item_id = int(item_id)
            # FAISS fills missing results with -1. Index ids are raw article
            # ids, so 0 is a legitimate item and must not be dropped.
            if item_id < 0:
                continue
            # Never recommend something the user has already clicked.
            if item_id in hist_set:
                continue
            rank += 1
            rows.append((int(u), int(item_id), float(score), int(rank)))
            if rank >= TOPK_CANDIDATES:
                break

two_tower_recall_df = pd.DataFrame(rows, columns=['user_id', 'article_id', 'recall_score', 'recall_rank'])
out_path = PROJECT_PATH / 'recall_candidates_two_tower.pkl'
two_tower_recall_df.to_pickle(out_path)

out_path, two_tower_recall_df.head()


## 7) 保存模型与编码器


In [None]:
# Persist the joint model and both towers next to the FAISS artifacts.
model.save(ARTIFACTS_DIR.joinpath('two_tower_model.keras'))
user_tower.save(ARTIFACTS_DIR.joinpath('user_tower.keras'))
item_tower.save(ARTIFACTS_DIR.joinpath('item_tower.keras'))

# The id maps are stored together so raw ids can be encoded the same way later.
with (ARTIFACTS_DIR / 'id_maps.pkl').open('wb') as f:
    pickle.dump(dict(user=user_id_map, item=item_id_map, cat=cat_id_map), f)

with (ARTIFACTS_DIR / 'metrics.pkl').open('wb') as f:
    pickle.dump(metrics, f)

print('saved to:', ARTIFACTS_DIR)
