# 82. Bloomフィルタ（バンドプリフィルタ）をITQパイプラインに統合

## 目的
- 既存ITQパイプラインのPivot枝刈りの**前**にバンドベースのプリフィルタを追加
- DF-LSHのDAHBF概念をITQハッシュに適用

## パイプライン比較
```
現行:    ITQ hash → Pivot filter → Hamming top-K → Cosine rerank
提案:    ITQ hash → Band pre-filter → Pivot filter → Hamming top-K → Cosine rerank
```

## 0. セットアップ

In [1]:
import sys
import numpy as np
import time
from pathlib import Path
from sklearn.metrics.pairwise import cosine_similarity
import warnings
warnings.filterwarnings('ignore')

sys.path.insert(0, '../src')
from itq_lsh import ITQLSH, hamming_distance, hamming_distance_batch
from dflsh import build_band_index, band_filter, confidence_multiprobe, combined_band_pivot_filter

# Shared experiment settings: evaluation sizes, data location, and the
# legacy global NumPy seed.
N_QUERIES, TOP_K = 100, 10
DATA_DIR = Path('..') / 'data'
np.random.seed(42)

print(f'Configuration: N_QUERIES={N_QUERIES}, TOP_K={TOP_K}')

Configuration: N_QUERIES=100, TOP_K=10


## 1. データロード

In [2]:
# Each corpus is loaded as four arrays, in this order:
# embeddings, 128-bit hashes, per-document pivot distances, pivots.

# English corpus
en_emb, en_hashes, en_pivot_dist, en_pivots = (
    np.load(DATA_DIR / file_name)
    for file_name in (
        '10k_e5_base_en_embeddings.npy',
        '10k_e5_base_en_hashes_128bits.npy',
        '10k_e5_base_en_pivot_distances.npy',
        'pivots_8_e5_base_en.npy',
    )
)

# Japanese corpus
ja_emb, ja_hashes, ja_pivot_dist, ja_pivots = (
    np.load(DATA_DIR / file_name)
    for file_name in (
        '10k_e5_base_ja_embeddings.npy',
        '10k_e5_base_ja_hashes_128bits.npy',
        '10k_e5_base_ja_pivot_distances.npy',
        'pivots_8_e5_base_ja.npy',
    )
)

# MiniLM corpus
minilm_emb, minilm_hashes, minilm_pivot_dist, minilm_pivots = (
    np.load(DATA_DIR / file_name)
    for file_name in (
        '10k_minilm_embeddings.npy',
        '10k_minilm_hashes_128bits.npy',
        '10k_minilm_pivot_distances.npy',
        'pivots_8_minilm.npy',
    )
)

# Fitted ITQ models (one per embedding family)
itq = ITQLSH.load(str(DATA_DIR / 'itq_e5_base_128bits.pkl'))
itq_minilm = ITQLSH.load(str(DATA_DIR / 'itq_minilm_128bits.pkl'))

print(f'English: {en_emb.shape}, Japanese: {ja_emb.shape}, MiniLM: {minilm_emb.shape}')

English: (10000, 768), Japanese: (10000, 768), MiniLM: (10000, 384)


## 2. 評価関数定義

In [3]:
def get_ground_truth(embeddings, qi, top_k=10):
    """Return the exact cosine-similarity neighbours of one query row.

    Args:
        embeddings: 2-D array of document embeddings (one row per document).
        qi: Row index of the query document inside ``embeddings``.
        top_k: Number of nearest neighbours to return.

    Returns:
        A set with the indices of the ``top_k`` rows most similar to row
        ``qi``. The query itself is excluded as long as at least ``top_k``
        other rows exist. An empty set is returned when ``top_k <= 0``.
    """
    # A slice of [-0:] would select *every* index, so handle it explicitly.
    if top_k <= 0:
        return set()

    cos_sims = cosine_similarity(embeddings[qi:qi+1], embeddings)[0]
    # -inf (rather than -1) guarantees the query ranks strictly below every
    # other document, even one whose cosine similarity is exactly -1.
    cos_sims[qi] = -np.inf
    return set(np.argsort(cos_sims)[-top_k:])

def evaluate_pipeline(embeddings, hashes, pivot_distances, pivots, itq_model,
                      label, band_width=16, min_band_matches=1, pivot_threshold=20,
                      use_band=True, use_pivot=True, candidate_limit=500):
    """Evaluate the band -> pivot -> Hamming -> cosine retrieval pipeline.

    ``N_QUERIES`` documents are sampled (fixed seed 42) and used as queries
    against the rest of the corpus. Ground truth is the module-level
    ``TOP_K`` cosine neighbours returned by ``get_ground_truth``.

    Args:
        embeddings: 2-D array of document embeddings.
        hashes: Per-document binary hashes, indexed like ``embeddings``.
        pivot_distances: Per-document distances to each pivot
            (column ``i`` corresponds to ``pivots[i]``).
        pivots: Sequence of pivot hashes.
        itq_model: Accepted for call-site compatibility; not used here.
        label: Free-form name stored in the result.
        band_width: Band width passed to ``build_band_index``/``band_filter``.
        min_band_matches: Passed to ``band_filter`` as ``min_matches``.
        pivot_threshold: Half-width of the accepted window around the
            query's distance to each pivot.
        use_band: Enable the band pre-filter stage.
        use_pivot: Enable the pivot filter stage.
        candidate_limit: Number of Hamming-nearest candidates kept for the
            cosine rerank.

    Returns:
        dict with ``label``, ``filter_recall``, ``recall_at_k`` and
        ``total_time_ms``; plus ``band_candidates``/``band_reduction``/
        ``band_time_ms`` when ``use_band`` and ``pivot_candidates``/
        ``pivot_reduction``/``pivot_time_ms`` when ``use_pivot``. All
        per-stage statistics are means over *all* queries, including queries
        for which an earlier stage produced no candidates.
    """
    rng = np.random.default_rng(42)
    query_indices = rng.choice(len(embeddings), N_QUERIES, replace=False)

    bi = build_band_index(hashes, band_width) if use_band else None
    n_docs = len(embeddings)
    n_pivots = len(pivots)
    # perf_counter is monotonic and high-resolution, which matters for the
    # sub-millisecond stage timings measured here.
    clock = time.perf_counter

    filter_recalls = []
    band_counts = []
    pivot_counts = []
    final_recalls = []
    total_times = []
    band_times = []
    pivot_times = []

    for qi in query_indices:
        gt = get_ground_truth(embeddings, qi, TOP_K)
        query_hash = hashes[qi]

        start_filter = clock()

        # Stage 1: Band filter
        if use_band:
            start_band = clock()
            # band_filter is assumed to return a NumPy array of candidate
            # indices (it is boolean-indexed below) - TODO confirm in dflsh.
            band_cands = band_filter(query_hash, bi, band_width, min_matches=min_band_matches)
            band_cands = band_cands[band_cands != qi]
            band_times.append(clock() - start_band)
            band_counts.append(len(band_cands))
        else:
            band_cands = np.arange(n_docs)
            band_cands = band_cands[band_cands != qi]

        # Stage 2: Pivot filter
        if use_pivot:
            if len(band_cands) > 0:
                start_pivot = clock()
                query_pivot_dists = np.array([
                    hamming_distance(query_hash, p) for p in pivots
                ])
                # Only the columns that have a matching pivot are compared.
                cand_pivot_dists = pivot_distances[band_cands][:, :n_pivots]
                lower = query_pivot_dists - pivot_threshold
                upper = query_pivot_dists + pivot_threshold
                # A candidate survives only if it is inside the window for
                # every pivot (broadcast over the pivot axis).
                mask = ((cand_pivot_dists >= lower) & (cand_pivot_dists <= upper)).all(axis=1)
                pivot_cands = band_cands[mask]
                pivot_times.append(clock() - start_pivot)
            else:
                # Nothing to filter: still record the query so that pivot
                # statistics average over the same queries as band statistics.
                pivot_cands = band_cands
                pivot_times.append(0.0)
            pivot_counts.append(len(pivot_cands))
        else:
            pivot_cands = band_cands

        filter_elapsed = clock() - start_filter

        # Filter recall (evaluation bookkeeping, deliberately not timed)
        filter_recalls.append(len(gt & set(pivot_cands)) / TOP_K)

        # Stage 3: Hamming sort + top candidates
        start_rank = clock()
        if len(pivot_cands) > 0:
            h_dists = hamming_distance_batch(query_hash, hashes[pivot_cands])
            top_idx = np.argsort(h_dists)[:candidate_limit]
            final_cands = pivot_cands[top_idx]

            # Stage 4: Cosine rerank
            cand_cos = cosine_similarity(embeddings[qi:qi+1], embeddings[final_cands])[0]
            top_in_cand = final_cands[np.argsort(cand_cos)[-TOP_K:]]
        else:
            top_in_cand = None
        rank_elapsed = clock() - start_rank

        if top_in_cand is None:
            final_recalls.append(0.0)
        else:
            final_recalls.append(len(gt & set(top_in_cand)) / TOP_K)

        # Total latency covers the pipeline stages only, not recall scoring.
        total_times.append(filter_elapsed + rank_elapsed)

    result = {
        'label': label,
        'filter_recall': np.mean(filter_recalls),
        'recall_at_k': np.mean(final_recalls),
        'total_time_ms': np.mean(total_times) * 1000,
    }

    if use_band:
        mean_band = np.mean(band_counts)
        result['band_candidates'] = mean_band
        result['band_reduction'] = 1 - mean_band / n_docs
        result['band_time_ms'] = np.mean(band_times) * 1000
    if use_pivot:
        mean_pivot = np.mean(pivot_counts)
        pivot_input = np.mean(band_counts) if use_band else n_docs
        result['pivot_candidates'] = mean_pivot
        # With no input candidates the pivot stage cannot reduce anything;
        # report 0 instead of dividing by zero.
        result['pivot_reduction'] = 1 - mean_pivot / pivot_input if pivot_input > 0 else 0.0
        result['pivot_time_ms'] = np.mean(pivot_times) * 1000

    return result

## 3. バンドフィルタ単体評価

In [4]:
banner = '='*70
print(banner)
print('Band Filter Only (no Pivot)')
print(banner)

# Sweep every (band width, minimum matching bands) pair on the English
# corpus with the pivot stage switched off.
band_sweep = [(width, matches) for width in [8, 16, 32] for matches in [1, 2]]

band_results = []
for bw, min_m in band_sweep:
    band_results.append(evaluate_pipeline(
        en_emb, en_hashes, en_pivot_dist, en_pivots, itq,
        label=f'EN bw={bw} min={min_m}',
        band_width=bw, min_band_matches=min_m,
        use_band=True, use_pivot=False
    ))

# One table row per configuration.
header = f'\n{"Config":<25} {"Candidates":>12} {"Reduction":>10} {"FilterRecall":>13} {"Recall@10":>10}'
print(header)
print('-'*75)
for r in band_results:
    row = (f'{r["label"]:<25} {r["band_candidates"]:>10.0f} '
           f'{r["band_reduction"]*100:>9.1f}% '
           f'{r["filter_recall"]*100:>12.1f}% '
           f'{r["recall_at_k"]*100:>9.1f}%')
    print(row)

Band Filter Only (no Pivot)



Config                      Candidates  Reduction  FilterRecall  Recall@10
---------------------------------------------------------------------------
EN bw=8 min=1                   2181      78.2%         68.9%      66.0%
EN bw=8 min=2                    292      97.1%         34.8%      34.7%
EN bw=16 min=1                    33      99.7%          7.9%       7.8%
EN bw=16 min=2                     3     100.0%          1.7%       1.6%
EN bw=32 min=1                     0     100.0%          1.5%       1.4%
EN bw=32 min=2                     0     100.0%          1.4%       1.3%


## 4. Pivotフィルタ単体ベースライン

In [5]:
print('='*70)
print('Pivot Filter Only (Baseline)')
print('='*70)

# Corpus size used for the reduction column; derived from the data instead
# of being hard-coded so the table stays correct for other corpus sizes.
n_docs = len(en_emb)

# Baseline: pivot filter alone over the English corpus, sweeping the
# pivot-distance threshold.
pivot_results = []
for pt in [15, 20, 25]:
    r = evaluate_pipeline(
        en_emb, en_hashes, en_pivot_dist, en_pivots, itq,
        label=f'EN pivot_t={pt}',
        pivot_threshold=pt, use_band=False, use_pivot=True
    )
    pivot_results.append(r)

print(f'\n{"Config":<25} {"Candidates":>12} {"Reduction":>10} {"FilterRecall":>13} {"Recall@10":>10}')
print('-'*75)
for r in pivot_results:
    print(f'{r["label"]:<25} {r["pivot_candidates"]:>10.0f} '
          f'{(1-r["pivot_candidates"]/n_docs)*100:>9.1f}% '
          f'{r["filter_recall"]*100:>12.1f}% '
          f'{r["recall_at_k"]*100:>9.1f}%')

Pivot Filter Only (Baseline)



Config                      Candidates  Reduction  FilterRecall  Recall@10
---------------------------------------------------------------------------
EN pivot_t=15                   6317      36.8%         90.1%      78.1%
EN pivot_t=20                   9055       9.4%         99.2%      84.2%
EN pivot_t=25                   9825       1.7%         99.9%      84.0%


## 5. Band + Pivot 統合評価

In [6]:
print('='*70)
print('Band + Pivot Combined Filter')
print('='*70)

# Corpus size used for the overall reduction column; derived from the data
# instead of being hard-coded.
n_docs = len(en_emb)

# Two-stage filter on the English corpus: band pre-filter (at least one
# matching band) followed by the pivot filter, sweeping both parameters.
combined_results = []

for bw in [8, 16]:
    for pt in [15, 20, 25]:
        r = evaluate_pipeline(
            en_emb, en_hashes, en_pivot_dist, en_pivots, itq,
            label=f'EN bw={bw} pt={pt}',
            band_width=bw, min_band_matches=1,
            pivot_threshold=pt, use_band=True, use_pivot=True
        )
        combined_results.append(r)

print(f'\n{"Config":<20} {"Band→":>8} {"→Pivot":>8} {"Reduction":>10} {"FilterRecall":>13} {"Recall@10":>10} {"Time(ms)":>10}')
print('-'*85)
for r in combined_results:
    # Reduction relative to the full corpus after both filter stages.
    total_reduction = 1 - r['pivot_candidates'] / n_docs
    print(f'{r["label"]:<20} {r["band_candidates"]:>7.0f} {r["pivot_candidates"]:>7.0f} '
          f'{total_reduction*100:>9.1f}% '
          f'{r["filter_recall"]*100:>12.1f}% '
          f'{r["recall_at_k"]*100:>9.1f}% '
          f'{r["total_time_ms"]:>9.2f}')

Band + Pivot Combined Filter



Config                  Band→   →Pivot  Reduction  FilterRecall  Recall@10   Time(ms)
-------------------------------------------------------------------------------------
EN bw=8 pt=15           2181    1511      84.9%         62.9%      61.3%      3.47
EN bw=8 pt=20           2181    2039      79.6%         68.4%      66.1%      4.33
EN bw=8 pt=25           2181    2158      78.4%         68.8%      66.0%      4.43
EN bw=16 pt=15            33      26      99.7%          7.7%       7.6%      0.94
EN bw=16 pt=20            33      32      99.7%          7.9%       7.9%      1.03
EN bw=16 pt=25            33      33      99.7%          7.9%       7.9%      1.13


## 6. 日本語データでの評価

In [7]:
print('='*70)
print('Japanese Data Evaluation')
print('='*70)

ja_results = []

# Pivot only
for pt in [15, 20, 25]:
    r = evaluate_pipeline(
        ja_emb, ja_hashes, ja_pivot_dist, ja_pivots, itq,
        label=f'JA pivot_t={pt}',
        pivot_threshold=pt, use_band=False, use_pivot=True
    )
    ja_results.append(r)

# Band + Pivot
for bw in [8, 16]:
    for pt in [20, 25]:
        r = evaluate_pipeline(
            ja_emb, ja_hashes, ja_pivot_dist, ja_pivots, itq,
            label=f'JA bw={bw} pt={pt}',
            band_width=bw, min_band_matches=1,
            pivot_threshold=pt, use_band=True, use_pivot=True
        )
        ja_results.append(r)

print(f'\n{"Config":<25} {"Candidates":>12} {"FilterRecall":>13} {"Recall@10":>10}')
print('-'*65)
for r in ja_results:
    # Show the candidate count of the last filter stage that ran; with no
    # filter at all every document is a candidate, so fall back to the
    # actual corpus size rather than a hard-coded 10000.
    cands = r.get('pivot_candidates', r.get('band_candidates', len(ja_emb)))
    print(f'{r["label"]:<25} {cands:>10.0f} '
          f'{r["filter_recall"]*100:>12.1f}% '
          f'{r["recall_at_k"]*100:>9.1f}%')

Japanese Data Evaluation



Config                      Candidates  FilterRecall  Recall@10
-----------------------------------------------------------------
JA pivot_t=15                   3842         87.4%      86.9%
JA pivot_t=20                   6956         98.1%      96.7%
JA pivot_t=25                   8738         99.9%      98.0%
JA bw=8 pt=20                    560         63.0%      63.0%
JA bw=8 pt=25                    646         63.8%      63.8%
JA bw=16 pt=20                     5          3.6%       3.6%
JA bw=16 pt=25                     5          3.7%       3.7%


## 7. MiniLM検証

In [8]:
print('='*70)
print('MiniLM Verification')
print('='*70)

minilm_results = []

# Pivot only
r = evaluate_pipeline(
    minilm_emb, minilm_hashes, minilm_pivot_dist, minilm_pivots, itq_minilm,
    label='MiniLM pivot_t=20',
    pivot_threshold=20, use_band=False, use_pivot=True
)
minilm_results.append(r)

# Band + Pivot
for bw in [8, 16]:
    r = evaluate_pipeline(
        minilm_emb, minilm_hashes, minilm_pivot_dist, minilm_pivots, itq_minilm,
        label=f'MiniLM bw={bw} pt=20',
        band_width=bw, min_band_matches=1,
        pivot_threshold=20, use_band=True, use_pivot=True
    )
    minilm_results.append(r)

print(f'\n{"Config":<25} {"Candidates":>12} {"FilterRecall":>13} {"Recall@10":>10}')
print('-'*65)
for r in minilm_results:
    # Show the candidate count of the last filter stage that ran; with no
    # filter at all every document is a candidate, so fall back to the
    # actual corpus size rather than a hard-coded 10000.
    cands = r.get('pivot_candidates', r.get('band_candidates', len(minilm_emb)))
    print(f'{r["label"]:<25} {cands:>10.0f} '
          f'{r["filter_recall"]*100:>12.1f}% '
          f'{r["recall_at_k"]*100:>9.1f}%')

MiniLM Verification



Config                      Candidates  FilterRecall  Recall@10
-----------------------------------------------------------------
MiniLM pivot_t=20               8025         96.7%      95.2%
MiniLM bw=8 pt=20                557         49.2%      49.2%
MiniLM bw=16 pt=20                 3          3.2%       3.2%


## 8. サマリー

In [9]:
# Closing summary: a banner followed by the (Japanese) conclusion text,
# emitted one line at a time.
rule = '='*80
summary_lines = (
    rule,
    'Bloom Filter (Band Pre-filter) Integration Summary',
    rule,
    '\n【結論】',
    'バンドプリフィルタをPivot枝刈りの前に追加することで：',
    '- Pivotフィルタの対象ドキュメント数を削減',
    '- 2段フィルタの候補数とFilter Recallのトレードオフを評価',
    '\n詳細な数値は上記の各セクションを参照。',
    '実験83（Confidence Multi-probe）と組み合わせて実験84で総合評価。',
)
for line in summary_lines:
    print(line)

Bloom Filter (Band Pre-filter) Integration Summary

【結論】
バンドプリフィルタをPivot枝刈りの前に追加することで：
- Pivotフィルタの対象ドキュメント数を削減
- 2段フィルタの候補数とFilter Recallのトレードオフを評価

詳細な数値は上記の各セクションを参照。
実験83（Confidence Multi-probe）と組み合わせて実験84で総合評価。
