# Quantum Data Selection - Experiment 1

**Streaming, sketch-based quantum data selection at trillion-token scale**

## Overview

Experiment 0 was a 100-sample proof of concept.  
This experiment designs and validates an architecture that can be applied to datasets at the **1-trillion-token scale**.

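### Problem: why Experiment 0 does not scale

| Issue | Experiment 0 | Target scale |
|---|---|---|
| QUBO size | N=100 → 5,050 entries | N=2B docs → O(N²) entries is infeasible |
| Surprise computation | All data held in memory | A trillion tokens do not fit in memory |
| Diversity | Not considered | Deduplication and coverage are required |
| Processing time | A few minutes | A few thousand GPU-hours is the practical upper bound |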

### Solution: a three-pass pipeline

```
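Raw Corpus (1T tokens)
    │
    ▼ Pass 1: Streaming Surprise + Sketch
    │   - Stream the corpus chunk by chunk
    │   - Detect near-duplicates with MinHash LSH
    │   - Compute SimHash diversity fingerprints
    │   - Attach a (surprise, minhash, simhash) tuple to each document
    │
    ▼ Pass 2: Shard-Local QUBO
    │   - Split the data into S shards (~1M docs each)
    │   - Solve a small QUBO per shard with quantum annealing (Leap hybrid solver)
    │   - Select K_local documents per shard
    │
    ▼ Pass 3: Global Merge QUBO
        - Pool the per-shard selections (S × K_local candidates)
        - Select the final K_global documents with a global QUBO
        - Maximize coverage via the SimHash diversity term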
```
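
The Pass 2 → Pass 3 hierarchy can be sketched with a pluggable selector. This is a minimal illustration, not the notebook's solver: the hypothetical `topk_select` stands in for a QUBO solve, and sharding is round-robin as in the demo. For a purely additive objective such as top-k by score, two-level selection returns exactly the global top K whenever K_local ≥ K_global, because no shard can hold more than K_global of the global winners. The pairwise dedup and diversity terms break this additivity, so for the real QUBO the hierarchy is an approximation.

```python
import numpy as np

def topk_select(doc_ids, scores, k):
    """Hypothetical stand-in for a QUBO solve: keep the k highest-scoring docs."""
    order = sorted(doc_ids, key=lambda d: scores[d], reverse=True)
    return order[:k]

def hierarchical_select(scores, n_shards, k_local, k_global, select=topk_select):
    """Pass 2 (shard-local selection) followed by Pass 3 (global merge)."""
    n = len(scores)
    # Round-robin sharding, as in the demo notebook
    shards = [list(range(s, n, n_shards)) for s in range(n_shards)]
    # Pass 2: each shard contributes k_local candidates
    candidates = []
    for shard in shards:
        candidates.extend(select(shard, scores, k_local))
    # Pass 3: one global selection over the pooled S * k_local candidates
    return select(candidates, scores, k_global)

rng = np.random.default_rng(0)
scores = rng.normal(size=2000)
hier = hierarchical_select(scores, n_shards=4, k_local=25, k_global=20)
flat = topk_select(list(range(len(scores))), scores, 20)
print(sorted(hier) == sorted(flat))  # True: additive objective and k_local >= k_global
```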

## Runtime: 15-30 minutes

## Requirement: D-Wave API token

## Cell 1: Installation

In [None]:
!pip install transformers datasets dwave-ocean-sdk torch matplotlib seaborn xxhash mmh3 -q

## Cell 2: Imports and constants

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import time
import hashlib
import struct
from collections import defaultdict
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from datasets import load_dataset
from dwave.system import LeapHybridSampler
import dimod
import warnings
warnings.filterwarnings('ignore')

# --- Scaling constants (for theoretical analysis) ---
TRILLION = 1_000_000_000_000
AVG_TOKENS_PER_DOC = 500
TOTAL_DOCS_1T = TRILLION // AVG_TOKENS_PER_DOC  # 2B documents

# --- Demo constants (for actual execution) ---
N_SAMPLES = 2000          # Simulate with 2K docs
N_SHARDS = 4              # Number of shards
K_LOCAL = 25              # Selections per shard
K_GLOBAL = 20             # Final global selections
MINHASH_PERMS = 128       # MinHash signature length
SIMHASH_BITS = 64         # SimHash fingerprint bits
LSH_BANDS = 16            # LSH band count
LSH_ROWS = MINHASH_PERMS // LSH_BANDS  # Rows per band = 8

print("All imports successful")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"\nTheoretical scale: {TOTAL_DOCS_1T:,} documents ({TRILLION/1e12:.0f}T tokens)")
print(f"Demo scale: {N_SAMPLES:,} documents, {N_SHARDS} shards")

## Cell 3: D-Wave API connection

In [None]:
import os

# os.environ['DWAVE_API_TOKEN'] = 'your-token-here'

try:
    sampler = LeapHybridSampler()
    print("D-Wave API connection successful")
    print(f"  Solver: {sampler.solver.name}")
    USE_QUANTUM = True
except Exception as e:
    print(f"D-Wave API connection failed: {e}")
    print("Falling back to simulated annealing (classical)")
    USE_QUANTUM = False

## Cell 4: Data preparation

Load 2,000 samples from WikiText-103.  
In production, a trillion-token corpus such as The Pile, RedPajama, or FineWeb is assumed.

In [None]:
print("Loading WikiText-103 dataset...")
dataset = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")

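# Keep only texts longer than 50 characters (after stripping whitespace).
# islice stops as soon as N_SAMPLES matches are found instead of scanning the whole split.
from itertools import islice
texts = list(islice((x['text'] for x in dataset if len(x['text'].strip()) > 50), N_SAMPLES))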

print(f"Loaded {len(texts)} text samples")
print(f"Total characters: {sum(len(t) for t in texts):,}")
print(f"Average length: {np.mean([len(t) for t in texts]):.0f} chars")

---

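## Part 1: Streaming surprise computation

### Scaling strategy

At trillion-token scale, the full dataset cannot be held in memory.  
Instead, each document is processed in a single streaming pass, and its surprise (mean per-token cross-entropy under a small proxy language model) is computed.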

```
Stream input → Chunking → Proxy-model inference → Surprise output
                  │                │
                  └─ fixed memory ─┘
```

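**Production optimizations:**
- Proxy model: DistilGPT-2 (82M parameters), estimated at roughly 1/10 the inference cost of a large model
- Batched inference: 256-512 docs/batch on GPU
- Estimated throughput: ~50K docs/sec on one A100 → 2B docs ≈ 11 hours (2 × 10⁹ / 5 × 10⁴ s)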
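
The chunking stage in the diagram can be driven by a generator so that only one batch of raw text is held in memory at a time. A minimal sketch, assuming documents arrive as any Python iterable (for example a Hugging Face dataset loaded with `streaming=True`); `iter_batches` is a hypothetical helper, not part of the notebook.

```python
from itertools import islice
from typing import Iterable, Iterator, List

def iter_batches(stream: Iterable[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive lists of up to batch_size documents from any iterable."""
    it = iter(stream)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield batch

# Usage: the generator below is never materialized as a full list
docs = (f"document {i}" for i in range(10))
sizes = [len(b) for b in iter_batches(docs, batch_size=4)]
print(sizes)  # [4, 4, 2]
```

Each yielded batch can be passed to `compute_surprise_batch` in the same way `streaming_surprise` passes list slices.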

In [None]:
print("Loading DistilGPT-2 proxy model...")
model = GPT2LMHeadModel.from_pretrained("distilgpt2")
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
tokenizer.pad_token = tokenizer.eos_token

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()

print(f"Model loaded on {device}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
def compute_surprise_batch(texts_batch, model, tokenizer, device, max_length=128):
    """
    Compute the surprise (mean per-token cross-entropy) of each document in a batch.
    In streaming mode, this function is called once per chunk.
    Only the first max_length tokens of each document are scored.

    Returns: list of float (surprise per document)
    """
    inputs = tokenizer(
        texts_batch,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        padding=True  # pad to the longest doc in the batch; padding is masked below
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        # No labels are passed: the built-in loss is a batch-wide mean,
        # so the per-document loss is computed manually from the logits.
        outputs = model(**inputs)
        logits = outputs.logits[:, :-1, :]  # (B, T-1, V): predictions for the next token
        labels = inputs["input_ids"][:, 1:]  # (B, T-1): next-token targets

        # Mask of real (non-pad) target tokens; GPT-2 pads on the right
        attn = inputs["attention_mask"][:, 1:]  # (B, T-1)

        # Per-token cross-entropy
        loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
        per_token_loss = loss_fn(
            logits.reshape(-1, logits.size(-1)),
            labels.reshape(-1)
        ).reshape(labels.shape)  # (B, T-1)

        # Mask padding and average per document
        masked_loss = per_token_loss * attn
        doc_lengths = attn.sum(dim=1).clamp(min=1)
        doc_surprises = (masked_loss.sum(dim=1) / doc_lengths).cpu().numpy()

    return doc_surprises.tolist()


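def streaming_surprise(texts, model, tokenizer, device, batch_size=32):
    """
    Streaming surprise computation.
    Model activation memory scales with batch_size and does not depend on the
    dataset size; only one float per document is accumulated.
    """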
    all_surprises = []
    n_batches = (len(texts) + batch_size - 1) // batch_size

    t0 = time.time()
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        batch_surprises = compute_surprise_batch(batch, model, tokenizer, device)
        all_surprises.extend(batch_surprises)

        batch_idx = i // batch_size + 1
        if batch_idx % 10 == 0 or batch_idx == n_batches:
            elapsed = time.time() - t0
            docs_per_sec = len(all_surprises) / elapsed
            print(f"  Batch {batch_idx}/{n_batches} | "
                  f"{len(all_surprises)}/{len(texts)} docs | "
                  f"{docs_per_sec:.0f} docs/sec")

    elapsed = time.time() - t0
    return np.array(all_surprises), elapsed


print("Computing streaming surprises...")
surprises, surprise_time = streaming_surprise(texts, model, tokenizer, device)

print(f"\nSurprise computation complete in {surprise_time:.1f}s")
print(f"  Throughput: {len(texts) / surprise_time:.0f} docs/sec")
print(f"  Mean surprise: {surprises.mean():.4f}")
print(f"  Std surprise:  {surprises.std():.4f}")

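# Scaling projection
docs_per_sec = len(texts) / surprise_time
# Assumed A100 speedup over CPU (a rough estimate, not measured);
# no extra factor is applied when this run already used a GPU.
gpu_speedup = 50 if device.type == "cpu" else 1
projected_a100_dps = docs_per_sec * gpu_speedup
hours_for_2B = TOTAL_DOCS_1T / projected_a100_dps / 3600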
print(f"\n--- Scaling Projection ---")
print(f"  Current throughput: {docs_per_sec:.0f} docs/sec")
print(f"  Projected A100 throughput: {projected_a100_dps:,.0f} docs/sec")
print(f"  Time for {TOTAL_DOCS_1T/1e9:.0f}B docs on 1x A100: {hours_for_2B:.0f} hours")
print(f"  Time for {TOTAL_DOCS_1T/1e9:.0f}B docs on 64x A100: {hours_for_2B/64:.1f} hours")

---

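## Part 2: MinHash LSH (near-duplicate detection)

### Why it is needed

Web corpora at trillion-token scale contain large numbers of exact and near duplicates.  
Selecting several documents that carry the same information wastes the selection budget  
and reduces diversity. MinHash LSH finds near-duplicate candidates with roughly constant-time bucket lookups per document, avoiding O(N²) pairwise comparison.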

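### Algorithm

1. Split each document into character n-grams (shingles)
2. Hash every shingle with 128 hash functions and keep, for each function, the minimum over all shingles → MinHash signature
3. Split the signature into 16 bands of 8 rows; pairs whose band hash matches in at least one band become near-duplicate candidates
4. Confirm duplicates with an estimated Jaccard similarity threshold (0.5)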
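
The candidate probability behind step 3 follows from the banding: a band collides only when all r rows agree (probability s^r for Jaccard similarity s), so a pair becomes a candidate with probability 1 - (1 - s^r)^b. The short sketch below evaluates this for the notebook's b = 16 bands and r = 8 rows; `lsh_candidate_probability` is a hypothetical helper name.

```python
def lsh_candidate_probability(s: float, n_bands: int = 16, rows_per_band: int = 8) -> float:
    """Probability that two docs with Jaccard similarity s share at least one LSH bucket."""
    # A single band matches only if all of its rows agree: s ** r
    # At least one of b independent bands matches: 1 - (1 - s ** r) ** b
    return 1.0 - (1.0 - s ** rows_per_band) ** n_bands

# Similarity at which the S-curve is steepest, approximately (1/b) ** (1/r)
threshold = (1.0 / 16) ** (1.0 / 8)

for s in (0.3, 0.5, 0.7, 0.9):
    print(f"J={s:.1f}: P(candidate)={lsh_candidate_probability(s):.3f}")
print(f"approx. LSH threshold: {threshold:.3f}")
```

With these settings the curve is steepest near J ≈ 0.71, so a pair at exactly the 0.5 verification threshold becomes a candidate only about 6% of the time. If recall near 0.5 matters, more bands with fewer rows (for example 32 × 4) shift the curve toward lower similarity.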

**Memory:** a signature is 128 × 4 bytes = 512 bytes/doc → ~1 TB for 2B docs  
**Production optimization:** distributed processing with Spark / MapReduce; the LSH index is sharded

In [None]:
def text_to_shingles(text, k=5):
    """テキストを k-gram (shingle) セットに変換"""
    text = text.lower().strip()
    if len(text) < k:
        return set()
    return set(text[i:i+k] for i in range(len(text) - k + 1))


def minhash_signature(shingles, n_perms=MINHASH_PERMS, seed=42):
    """
    MinHash 署名を計算。
    n_perms 個の独立ハッシュ関数を使い、各ハッシュの最小値を署名とする。

    本番では mmh3 や xxhash を使用して高速化。
    ここでは再現性のため hashlib ベースで実装。
    """
    if not shingles:
        return np.zeros(n_perms, dtype=np.uint32)

    signature = np.full(n_perms, np.iinfo(np.uint32).max, dtype=np.uint32)

    for shingle in shingles:
        shingle_bytes = shingle.encode('utf-8')
        for i in range(n_perms):
            # Hash(shingle || seed || perm_index)
            h = hashlib.md5(shingle_bytes + struct.pack('<II', seed, i)).digest()
            val = struct.unpack('<I', h[:4])[0]
            if val < signature[i]:
                signature[i] = val

    return signature


def lsh_buckets(signature, n_bands=LSH_BANDS):
    """
    LSH バンディング: 署名を n_bands 個のバンドに分割し、
    各バンドのハッシュをバケットキーとする。

    同じバケットに入るペアは近似重複候補。
    """
    rows_per_band = len(signature) // n_bands
    buckets = []
    for b in range(n_bands):
        band = signature[b * rows_per_band : (b + 1) * rows_per_band]
        bucket_key = hashlib.md5(band.tobytes()).hexdigest()
        buckets.append((b, bucket_key))
    return buckets


def estimated_jaccard(sig_a, sig_b):
    """MinHash 署名から Jaccard 類似度を推定"""
    return np.mean(sig_a == sig_b)


print("Computing MinHash signatures...")
t0 = time.time()

signatures = []
shingle_counts = []
for i, text in enumerate(texts):
    shingles = text_to_shingles(text, k=5)
    sig = minhash_signature(shingles)
    signatures.append(sig)
    shingle_counts.append(len(shingles))

    if (i + 1) % 500 == 0:
        print(f"  {i+1}/{N_SAMPLES} signatures computed")

minhash_time = time.time() - t0
print(f"\nMinHash signatures computed in {minhash_time:.1f}s")
print(f"  Throughput: {N_SAMPLES / minhash_time:.0f} docs/sec")
print(f"  Avg shingles per doc: {np.mean(shingle_counts):.0f}")
print(f"  Signature size: {MINHASH_PERMS * 4} bytes per doc")

In [None]:
print("Building LSH index and detecting near-duplicates...")
t0 = time.time()

# Build LSH index: band -> bucket -> list of doc indices
lsh_index = defaultdict(lambda: defaultdict(list))
for doc_idx, sig in enumerate(signatures):
    for band_id, bucket_key in lsh_buckets(sig):
        lsh_index[band_id][bucket_key].append(doc_idx)

# Find candidate pairs (docs sharing at least one bucket)
candidate_pairs = set()
for band_id in lsh_index:
    for bucket_key, doc_indices in lsh_index[band_id].items():
        if len(doc_indices) > 1:
            for i in range(len(doc_indices)):
                for j in range(i + 1, len(doc_indices)):
                    pair = (min(doc_indices[i], doc_indices[j]),
                            max(doc_indices[i], doc_indices[j]))
                    candidate_pairs.add(pair)

# Verify candidates with Jaccard threshold
JACCARD_THRESHOLD = 0.5
duplicate_pairs = []
for i, j in candidate_pairs:
    jaccard = estimated_jaccard(signatures[i], signatures[j])
    if jaccard >= JACCARD_THRESHOLD:
        duplicate_pairs.append((i, j, jaccard))

# Build duplicate clusters (union-find)
parent = list(range(N_SAMPLES))
def find(x):
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x
def union(a, b):
    ra, rb = find(a), find(b)
    if ra != rb:
        parent[ra] = rb

for i, j, _ in duplicate_pairs:
    union(i, j)

# Cluster analysis
clusters = defaultdict(list)
for i in range(N_SAMPLES):
    clusters[find(i)].append(i)

dup_clusters = {k: v for k, v in clusters.items() if len(v) > 1}

# Duplicate mask: keep only cluster representative (highest surprise)
is_duplicate = np.zeros(N_SAMPLES, dtype=bool)
for cluster_id, members in dup_clusters.items():
    # Keep the member with highest surprise
    best = max(members, key=lambda idx: surprises[idx])
    for m in members:
        if m != best:
            is_duplicate[m] = True

lsh_time = time.time() - t0

print(f"\nLSH deduplication complete in {lsh_time:.1f}s")
print(f"  Candidate pairs: {len(candidate_pairs):,}")
print(f"  Confirmed duplicates: {len(duplicate_pairs):,}")
print(f"  Duplicate clusters: {len(dup_clusters):,}")
print(f"  Documents removed: {is_duplicate.sum()} ({is_duplicate.mean()*100:.1f}%)")
print(f"  Remaining documents: {(~is_duplicate).sum()}")

if duplicate_pairs:
    print(f"\n  Example duplicate pair:")
    i, j, jac = duplicate_pairs[0]
    print(f"    Doc {i}: '{texts[i][:80]}...'")
    print(f"    Doc {j}: '{texts[j][:80]}...'")
    print(f"    Jaccard similarity: {jac:.3f}")

---

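## Part 3: SimHash (diversity fingerprint)

### Purpose

Used for the diversity term of the QUBO. The content distance between two documents  
is approximated by the Hamming distance between their SimHash fingerprints.

- Large Hamming distance → different content → high diversity
- Small Hamming distance → similar content → redundant

### SimHash algorithm

1. Split the text into features (character n-grams)
2. Hash each feature and expand the hash into a 64-element ±1 vector (bit 0 → -1, bit 1 → +1)
3. Sum the vectors over all features (each feature has weight 1 here) and set each fingerprint bit from the sign of the sum

**Memory:** 8 bytes/doc → ~16 GB for 2B docs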

In [None]:
def simhash(text, n_bits=SIMHASH_BITS, k=3):
    """
    SimHash フィンガープリントを計算。

    各 k-gram をハッシュし、ビットごとの加重和の符号を取る。
    結果は n_bits ビットの整数。
    """
    text = text.lower().strip()
    if len(text) < k:
        return 0

    v = np.zeros(n_bits, dtype=np.float64)

    for i in range(len(text) - k + 1):
        gram = text[i:i+k]
        h = int(hashlib.md5(gram.encode('utf-8')).hexdigest(), 16)
        for bit in range(n_bits):
            if (h >> bit) & 1:
                v[bit] += 1.0
            else:
                v[bit] -= 1.0

    # Convert to fingerprint
    fingerprint = 0
    for bit in range(n_bits):
        if v[bit] > 0:
            fingerprint |= (1 << bit)

    return fingerprint


def hamming_distance(a, b, n_bits=SIMHASH_BITS):
    """2つの SimHash 間の Hamming 距離"""
    xor = a ^ b
    return bin(xor).count('1')


def hamming_to_diversity(dist, n_bits=SIMHASH_BITS):
    """Hamming 距離を [0, 1] の多様性スコアに変換"""
    return dist / n_bits


print("Computing SimHash fingerprints...")
t0 = time.time()

simhashes = []
for i, text in enumerate(texts):
    sh = simhash(text)
    simhashes.append(sh)
    if (i + 1) % 500 == 0:
        print(f"  {i+1}/{N_SAMPLES} fingerprints computed")

simhash_time = time.time() - t0

print(f"\nSimHash computation complete in {simhash_time:.1f}s")
print(f"  Throughput: {N_SAMPLES / simhash_time:.0f} docs/sec")
print(f"  Fingerprint size: {SIMHASH_BITS // 8} bytes per doc")

# Sample pairwise distances (all pairs among the first 100 docs)
sample_pairs = [(i, j) for i in range(min(100, N_SAMPLES))
                       for j in range(i+1, min(100, N_SAMPLES))]
sample_dists = [hamming_distance(simhashes[i], simhashes[j])
                for i, j in sample_pairs]

print(f"\n  Sample pairwise Hamming distances (first 100 docs, {len(sample_pairs)} pairs):")
print(f"    Mean: {np.mean(sample_dists):.1f} / {SIMHASH_BITS}")
print(f"    Std:  {np.std(sample_dists):.1f}")
print(f"    Min:  {np.min(sample_dists)}")
print(f"    Max:  {np.max(sample_dists)}")

---

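## Part 4: Hierarchical QUBO (Shard-Local → Global Merge)

### Scaling problem and solution

Building a QUBO directly over N = 2B documents needs on the order of N² ≈ 4 × 10¹⁸ entries (about 2 × 10¹⁸ unique pairs), which is infeasible.

**Hierarchical approach:**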

```
2B docs ÷ 2000 shards = 1M docs/shard
  │
  ├── Shard 1: QUBO(1M) → select 500 (top-K pre-filter + QUBO)
  ├── Shard 2: QUBO(1M) → select 500
  ├── ...
  └── Shard 2000: QUBO(1M) → select 500
         │
         ▼
  Global: QUBO(1M candidates) → select 100K final
```
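
The "top-K pre-filter" in the diagram is what keeps each shard's QUBO tractable: a dense QUBO over all 1M docs of a shard would have about 5 × 10¹¹ upper-triangular entries. A minimal sketch, assuming a hypothetical `prefilter_shard` helper and an illustrative pre-filter size M = 5,000 (neither is specified in the original):

```python
import numpy as np

def prefilter_shard(doc_indices, surprises, is_duplicate, m):
    """Keep the m highest-surprise non-duplicate docs of a shard (hypothetical helper)."""
    idx = np.asarray([d for d in doc_indices if not is_duplicate[d]], dtype=np.int64)
    if len(idx) <= m:
        return idx
    # argpartition finds the top m in O(n) without fully sorting the shard
    top = np.argpartition(-surprises[idx], m - 1)[:m]
    return idx[top]

def dense_qubo_entries(n):
    """Upper-triangular entries (diagonal included) of a dense n-variable QUBO."""
    return n * (n + 1) // 2

print(f"{dense_qubo_entries(1_000_000):,}")  # 500,000,500,000 (full 1M-doc shard)
print(f"{dense_qubo_entries(5_000):,}")      # 12,502,500 (after a hypothetical M=5,000 pre-filter)
```

Because this pre-filter ranks by surprise alone, documents with moderate surprise but high diversity value can be dropped before the QUBO sees them; M trades solver cost against that loss.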

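**Per-shard QUBO (extended formulation):**

$$H_{\text{shard}} = -\alpha \sum_i S_i x_i + \beta \sum_{i<j} \text{dup}(i,j) x_i x_j - \delta \sum_{i<j} d_H(i,j) x_i x_j + \gamma \left(\sum_i x_i - K\right)^2$$

where $x_i \in \{0,1\}$ indicates that document $i$ is selected, $S_i$ is its surprise (z-scored within the shard), $\text{dup}(i,j)$ is the estimated MinHash Jaccard similarity (counted only when it exceeds 0.3), and $d_H(i,j)$ is the SimHash Hamming distance divided by the number of fingerprint bits.

| Term | Meaning | Effect |
|---|---|---|
| $-\alpha S_i$ | Surprise maximization | Favors high-information documents |
| $\beta \cdot \text{dup}$ | Duplicate penalty | Discourages selecting MinHash near-duplicate pairs together |
| $-\delta \cdot d_H$ | Diversity bonus | Favors pairs with a large SimHash distance |
| $\gamma (\sum - K)^2$ | Cardinality constraint | Drives the selection toward exactly K documents |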
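
Expanding the constraint term with $x_i^2 = x_i$ gives $(\sum_i x_i - K)^2 = \sum_i (1 - 2K) x_i + 2\sum_{i<j} x_i x_j + K^2$, which is where the diagonal coefficient $\gamma(1 - 2K)$ and the off-diagonal $2\gamma$ in `build_enhanced_qubo` come from (the constant $\gamma K^2$ is dropped). The brute-force check below, assuming only the $\gamma$ term is present (the helper names are hypothetical), confirms that every minimum-energy state selects exactly K variables, with energy $-\gamma K^2$.

```python
from itertools import product

def cardinality_qubo(n, k, gamma):
    """Only the gamma * (sum x - K)^2 term, with the constant gamma * K^2 dropped."""
    Q = {(i, i): gamma * (1 - 2 * k) for i in range(n)}
    Q.update({(i, j): 2 * gamma for i in range(n) for j in range(i + 1, n)})
    return Q

def energy(Q, x):
    """QUBO energy sum over (i, j) of Q_ij * x_i * x_j; diagonal entries act on x_i alone."""
    return sum(w * x[i] * x[j] for (i, j), w in Q.items())

n, k, gamma = 8, 3, 10.0
Q = cardinality_qubo(n, k, gamma)
energies = {x: energy(Q, x) for x in product((0, 1), repeat=n)}
best = min(energies.values())
ground_states = [x for x, e in energies.items() if e == best]
print(best, {sum(x) for x in ground_states})  # -90.0 {3}
```

Once the $\alpha$, $\beta$, and $\delta$ terms are added, the optimum can drift away from exactly K when $\gamma$ is not large relative to them, which is why the shard loop reports `selected/K_LOCAL`.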

In [None]:
def build_enhanced_qubo(surprises, signatures, simhashes, is_duplicate,
                        doc_indices, K,
                        alpha=1.0, beta=5.0, delta=0.3, gamma=10.0):
    """
    拡張 QUBO 行列を構築。

    Parameters
    ----------
    surprises : array - 各ドキュメントの surprise 値
    signatures : list - MinHash 署名
    simhashes : list - SimHash フィンガープリント
    is_duplicate : array - 重複フラグ
    doc_indices : list - このシャード内のドキュメントインデックス
    K : int - 選択数
    alpha : float - Surprise の重み
    beta : float - 重複ペナルティの重み
    delta : float - 多様性ボーナスの重み
    gamma : float - カーディナリティ制約の重み

    Returns
    -------
    Q : dict - QUBO 行列
    var_to_doc : dict - QUBO 変数 → ドキュメントインデックスのマッピング
    """
    # Filter out already-removed duplicates
    valid_docs = [idx for idx in doc_indices if not is_duplicate[idx]]
    N = len(valid_docs)
    var_to_doc = {var: doc for var, doc in enumerate(valid_docs)}

    Q = {}

    # Normalize surprises for this shard
    shard_surprises = np.array([surprises[var_to_doc[v]] for v in range(N)])
    if shard_surprises.std() > 0:
        norm_surprises = (shard_surprises - shard_surprises.mean()) / shard_surprises.std()
    else:
        norm_surprises = np.zeros(N)

    # Diagonal terms: -alpha * S_i + gamma * (1 - 2K)
    for v in range(N):
        Q[(v, v)] = -alpha * norm_surprises[v] + gamma * (1 - 2 * K)

    # Off-diagonal terms
    for vi in range(N):
        for vj in range(vi + 1, N):
            doc_i = var_to_doc[vi]
            doc_j = var_to_doc[vj]

            # Cardinality constraint
            val = 2 * gamma

            # Duplicate penalty (MinHash Jaccard > threshold)
            jaccard = estimated_jaccard(signatures[doc_i], signatures[doc_j])
            if jaccard > 0.3:  # Soft penalty starts at 0.3
                val += beta * jaccard

            # Diversity bonus (SimHash Hamming distance)
            h_dist = hamming_to_diversity(
                hamming_distance(simhashes[doc_i], simhashes[doc_j]))
            val -= delta * h_dist

            Q[(vi, vj)] = val

    return Q, var_to_doc


def solve_qubo(Q, use_quantum=True, label='shard'):
    """
    Solve a QUBO with the Leap hybrid solver when available;
    otherwise fall back to classical simulated annealing (SA).
    """
    if use_quantum and USE_QUANTUM:
        sampler = LeapHybridSampler()
        response = sampler.sample_qubo(Q, label=label)
    else:
        # The maintained SA sampler ships in dwave-samplers (installed with dwave-ocean-sdk)
        from dwave.samplers import SimulatedAnnealingSampler
        bqm = dimod.BinaryQuadraticModel.from_qubo(Q)
        sampler = SimulatedAnnealingSampler()
        response = sampler.sample(bqm, num_reads=100, num_sweeps=1000)

    solution = response.first.sample
    energy = response.first.energy
    selected = [v for v, val in solution.items() if val == 1]
    return selected, energy


print("Enhanced QUBO builder ready")
print(f"  QUBO terms: surprise (alpha), dedup (beta), diversity (delta), constraint (gamma)")

## Cell 12: Shard splitting and shard-local QUBO

In [None]:
print(f"Splitting {N_SAMPLES} documents into {N_SHARDS} shards...")

# Assign documents to shards (round-robin; in production, use hash-based)
shard_assignments = [[] for _ in range(N_SHARDS)]
for i in range(N_SAMPLES):
    shard_assignments[i % N_SHARDS].append(i)

for s in range(N_SHARDS):
    n_valid = sum(1 for idx in shard_assignments[s] if not is_duplicate[idx])
    print(f"  Shard {s}: {len(shard_assignments[s])} docs ({n_valid} after dedup)")

print(f"\nRunning shard-local QUBO selection (K_local={K_LOCAL} per shard)...")
print(f"  Use quantum: {USE_QUANTUM}")
print()

shard_results = []
total_qubo_time = 0

for s in range(N_SHARDS):
    t0 = time.time()

    Q, var_to_doc = build_enhanced_qubo(
        surprises, signatures, simhashes, is_duplicate,
        doc_indices=shard_assignments[s],
        K=K_LOCAL,
        alpha=1.0, beta=5.0, delta=0.3, gamma=10.0
    )

    selected_vars, energy = solve_qubo(
        Q,
        use_quantum=USE_QUANTUM,
        label=f'QDS-Exp1-Shard{s}'
    )

    # Map QUBO variables back to document indices
    selected_docs = [var_to_doc[v] for v in selected_vars if v in var_to_doc]
    shard_time = time.time() - t0
    total_qubo_time += shard_time

    avg_surprise = surprises[selected_docs].mean() if selected_docs else 0

    print(f"  Shard {s}: selected {len(selected_docs)}/{K_LOCAL} docs | "
          f"energy={energy:.1f} | avg_surprise={avg_surprise:.4f} | "
          f"time={shard_time:.1f}s")

    shard_results.append({
        'shard': s,
        'selected_docs': selected_docs,
        'energy': energy,
        'avg_surprise': avg_surprise
    })

# Merge all shard selections
all_shard_selected = []
for r in shard_results:
    all_shard_selected.extend(r['selected_docs'])

print(f"\nShard-local selection complete")
print(f"  Total selected: {len(all_shard_selected)} docs")
print(f"  Total QUBO time: {total_qubo_time:.1f}s")
print(f"  Avg surprise (all shards): {surprises[all_shard_selected].mean():.4f}")

## Cell 13: Global merge QUBO

The selections from all shards are pooled, and the final K_global documents are chosen.  
Cross-shard diversity matters most at this stage.

In [None]:
print(f"Running global merge QUBO...")
print(f"  Input: {len(all_shard_selected)} candidates from {N_SHARDS} shards")
print(f"  Target: K_global = {K_GLOBAL}")

# For global merge, we want stronger diversity emphasis
t0 = time.time()

# Build a dummy is_duplicate array (no additional dedup at global level)
global_no_dup = np.zeros(N_SAMPLES, dtype=bool)

Q_global, var_to_doc_global = build_enhanced_qubo(
    surprises, signatures, simhashes, global_no_dup,
    doc_indices=all_shard_selected,
    K=K_GLOBAL,
    alpha=1.0, beta=3.0, delta=0.5, gamma=12.0  # Higher diversity weight at global level
)

global_selected_vars, global_energy = solve_qubo(
    Q_global,
    use_quantum=USE_QUANTUM,
    label='QDS-Exp1-GlobalMerge'
)

# Map back to document indices
global_selected = [var_to_doc_global[v] for v in global_selected_vars
                   if v in var_to_doc_global]

global_time = time.time() - t0

print(f"\nGlobal merge complete in {global_time:.1f}s")
print(f"  Selected: {len(global_selected)} docs")
print(f"  Energy: {global_energy:.1f}")
print(f"  Avg surprise: {surprises[global_selected].mean():.4f}")

---

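## Part 5: Baseline comparison

Three selection methods are compared (both baselines draw from the deduplicated pool):

1. **Quantum Hierarchical** (this method): Surprise + MinHash + SimHash + hierarchical QUBO
2. **Top-K Surprise** (greedy): select the K documents with the highest surprise (no diversity term)
3. **Random**: select K documents uniformly at random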

In [None]:
print("=" * 70)
print("BASELINE COMPARISON")
print("=" * 70)

# --- Baseline 1: Top-K Surprise (greedy) ---
# Remove duplicates first, then take top K by surprise
non_dup_indices = [i for i in range(N_SAMPLES) if not is_duplicate[i]]
sorted_by_surprise = sorted(non_dup_indices, key=lambda i: surprises[i], reverse=True)
topk_selected = sorted_by_surprise[:K_GLOBAL]

# --- Baseline 2: Random ---
N_RANDOM_TRIALS = 200
rng = np.random.default_rng(0)  # fixed seed so the random baseline is reproducible
random_results = []
for _ in range(N_RANDOM_TRIALS):
    random_sel = rng.choice(non_dup_indices, K_GLOBAL, replace=False)
    random_results.append({
        'avg_surprise': surprises[random_sel].mean(),
        'indices': random_sel
    })
random_avg_surprise = np.mean([r['avg_surprise'] for r in random_results])
random_std_surprise = np.std([r['avg_surprise'] for r in random_results])

# --- Compute diversity for each method ---
def compute_diversity(selected_indices):
    """選択されたドキュメント間の平均 SimHash 多様性"""
    if len(selected_indices) < 2:
        return 0.0
    total_dist = 0
    n_pairs = 0
    for i in range(len(selected_indices)):
        for j in range(i + 1, len(selected_indices)):
            total_dist += hamming_to_diversity(
                hamming_distance(
                    simhashes[selected_indices[i]],
                    simhashes[selected_indices[j]]))
            n_pairs += 1
    return total_dist / n_pairs if n_pairs > 0 else 0.0

quantum_diversity = compute_diversity(global_selected)
topk_diversity = compute_diversity(topk_selected)

# Random diversity (average over trials)
random_diversities = []
for r in random_results[:20]:  # Sample 20 trials for speed
    random_diversities.append(compute_diversity(r['indices'].tolist()))
random_avg_diversity = np.mean(random_diversities)

# --- Results table ---
quantum_surprise = surprises[global_selected].mean()
topk_surprise = surprises[topk_selected].mean()

print(f"\n{'Method':<30} {'Avg Surprise':>15} {'Diversity':>12} {'Count':>8}")
print("-" * 70)
print(f"{'Quantum Hierarchical':<30} {quantum_surprise:>15.4f} {quantum_diversity:>12.4f} {len(global_selected):>8}")
print(f"{'Top-K Surprise (greedy)':<30} {topk_surprise:>15.4f} {topk_diversity:>12.4f} {len(topk_selected):>8}")
print(f"{'Random (n=200 trials)':<30} {random_avg_surprise:>15.4f} {random_avg_diversity:>12.4f} {K_GLOBAL:>8}")
print(f"{'  Random std':<30} {'+/-' + f'{random_std_surprise:.4f}':>15}")

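# --- Composite score: balances surprise and diversity ---
# Min-max normalize each metric across the three methods only, so scores are
# relative to these baselines (best method on a metric gets 1, worst gets 0).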
all_surp = [quantum_surprise, topk_surprise, random_avg_surprise]
surp_min, surp_max = min(all_surp), max(all_surp)
surp_range = surp_max - surp_min if surp_max > surp_min else 1.0

all_div = [quantum_diversity, topk_diversity, random_avg_diversity]
div_min, div_max = min(all_div), max(all_div)
div_range = div_max - div_min if div_max > div_min else 1.0

def composite_score(surprise, diversity, w_surp=0.6, w_div=0.4):
    norm_s = (surprise - surp_min) / surp_range
    norm_d = (diversity - div_min) / div_range
    return w_surp * norm_s + w_div * norm_d

print("\nComposite Score (0.6×Surprise + 0.4×Diversity)")
print(f"  Quantum:  {composite_score(quantum_surprise, quantum_diversity):.4f}")
print(f"  Top-K:    {composite_score(topk_surprise, topk_diversity):.4f}")
print(f"  Random:   {composite_score(random_avg_surprise, random_avg_diversity):.4f}")

---

## Part 6: Visualization

In [None]:
fig, axes = plt.subplots(2, 3, figsize=(20, 12))

# --- Plot 1: Surprise distribution with selections ---
ax = axes[0, 0]
ax.hist(surprises, bins=50, alpha=0.5, color='gray', label='All docs')
ax.hist(surprises[global_selected], bins=20, alpha=0.7, color='red',
        label=f'Quantum ({len(global_selected)})')
ax.hist(surprises[topk_selected], bins=20, alpha=0.5, color='green',
        label=f'Top-K ({len(topk_selected)})')
ax.set_xlabel('Surprise')
ax.set_ylabel('Frequency')
ax.set_title('Surprise Distribution')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)

# --- Plot 2: Per-shard selection quality ---
ax = axes[0, 1]
shard_labels = [f'Shard {r["shard"]}' for r in shard_results]
shard_avg_surprises = [r['avg_surprise'] for r in shard_results]
shard_counts = [len(r['selected_docs']) for r in shard_results]
colors = plt.cm.viridis(np.linspace(0.2, 0.8, N_SHARDS))
bars = ax.bar(shard_labels, shard_avg_surprises, color=colors)
for bar, count in zip(bars, shard_counts):
    ax.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 0.01,
            f'n={count}', ha='center', va='bottom', fontsize=9)
ax.axhline(quantum_surprise, color='red', linestyle='--', alpha=0.7,
           label=f'Global avg: {quantum_surprise:.3f}')
ax.set_ylabel('Avg Surprise')
ax.set_title('Shard-Local Selection Quality')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)

# --- Plot 3: Diversity comparison ---
ax = axes[0, 2]
methods = ['Quantum\nHierarchical', 'Top-K\nSurprise', 'Random']
diversities = [quantum_diversity, topk_diversity, random_avg_diversity]
bar_colors = ['red', 'green', 'blue']
ax.bar(methods, diversities, color=bar_colors, alpha=0.7)
ax.set_ylabel('Diversity (avg pairwise SimHash distance)')
ax.set_title('Selection Diversity')
ax.grid(True, alpha=0.3)

# --- Plot 4: Surprise vs Diversity scatter ---
ax = axes[1, 0]
ax.scatter(quantum_surprise, quantum_diversity, c='red', s=200, marker='*',
           zorder=5, label='Quantum')
ax.scatter(topk_surprise, topk_diversity, c='green', s=200, marker='s',
           zorder=5, label='Top-K')
# Random trials as a cloud (one point per trial with a computed diversity)
for r, d in zip(random_results, random_diversities):
    ax.scatter(r['avg_surprise'], d, c='blue', s=20, alpha=0.3)
ax.scatter(random_avg_surprise, random_avg_diversity, c='blue', s=200, marker='o',
           zorder=5, label='Random (avg)')
ax.set_xlabel('Average Surprise')
ax.set_ylabel('Diversity')
ax.set_title('Surprise-Diversity Pareto')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)

# --- Plot 5: Document selection map ---
ax = axes[1, 1]
quantum_set = set(global_selected)
topk_set = set(topk_selected)
colors_map = []
for i in range(N_SAMPLES):
    if i in quantum_set and i in topk_set:
        colors_map.append('purple')
    elif i in quantum_set:
        colors_map.append('red')
    elif i in topk_set:
        colors_map.append('green')
    elif is_duplicate[i]:
        colors_map.append('orange')
    else:
        colors_map.append('lightgray')
ax.scatter(range(N_SAMPLES), surprises, c=colors_map, s=15, alpha=0.6)
ax.set_xlabel('Document index')
ax.set_ylabel('Surprise')
ax.set_title('Selection Map (red=quantum, green=top-k, purple=both, orange=dup)')
ax.grid(True, alpha=0.3)

# --- Plot 6: Random baseline distribution ---
ax = axes[1, 2]
random_surp_list = [r['avg_surprise'] for r in random_results]
ax.hist(random_surp_list, bins=30, alpha=0.7, color='blue', label='Random trials')
ax.axvline(quantum_surprise, color='red', linestyle='--', linewidth=2,
           label=f'Quantum: {quantum_surprise:.4f}')
ax.axvline(topk_surprise, color='green', linestyle='--', linewidth=2,
           label=f'Top-K: {topk_surprise:.4f}')
ax.set_xlabel('Avg Surprise')
ax.set_ylabel('Frequency')
ax.set_title('Random Baseline Distribution')
ax.legend(fontsize=9)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('experiment1_results.png', dpi=150, bbox_inches='tight')
print("Visualization saved: experiment1_results.png")
plt.show()

---

## Part 7: Scaling analysis at trillion-token scale

In [None]:
print("=" * 70)
print("SCALING ANALYSIS: Experiment 0 → 1 → Production")
print("=" * 70)

# Measured timings from this experiment
measured = {
    'n_docs': N_SAMPLES,
    'surprise_time': surprise_time,
    'minhash_time': minhash_time,
    'simhash_time': simhash_time,
    'lsh_time': lsh_time,
    'qubo_time': total_qubo_time + global_time,
}

print(f"\n--- Measured Timings (N={N_SAMPLES:,} docs) ---")
print(f"  Surprise computation:  {measured['surprise_time']:>8.1f}s")
print(f"  MinHash signatures:    {measured['minhash_time']:>8.1f}s")
print(f"  SimHash fingerprints:  {measured['simhash_time']:>8.1f}s")
print(f"  LSH deduplication:     {measured['lsh_time']:>8.1f}s")
print(f"  QUBO solving:          {measured['qubo_time']:>8.1f}s")
total_measured = sum(measured[k] for k in ['surprise_time', 'minhash_time',
                                            'simhash_time', 'lsh_time', 'qubo_time'])
print(f"  Total:                 {total_measured:>8.1f}s")

# Scaling projections
scales = [
    ('Experiment 0', 100, 1),
    ('Experiment 1 (demo)', N_SAMPLES, 1),
    ('Medium scale', 1_000_000, 1),
    ('Large scale', 100_000_000, 64),
    ('Production (1T tokens)', TOTAL_DOCS_1T, 256),
]

print(f"\n--- Scaling Projections ---")
print(f"{'Scale':<25} {'N docs':>15} {'GPUs':>6} {'Pass 1':>12} {'Pass 2':>12} {'Pass 3':>12} {'Total':>12}")
print("-" * 100)

def fmt_time(s):
    """Format a duration in seconds as s / min / hr / day."""
    if s < 60:
        return f"{s:.0f}s"
    elif s < 3600:
        return f"{s/60:.0f}min"
    elif s < 86400:
        return f"{s/3600:.1f}hr"
    else:
        return f"{s/86400:.1f}day"

for name, n_docs, n_gpus in scales:
    scale_factor = n_docs / N_SAMPLES

    # Pass 1: Streaming (linear, parallelizable)
    pass1_sec = (measured['surprise_time'] + measured['minhash_time'] +
                 measured['simhash_time']) * scale_factor / n_gpus

    # Pass 2: Shard-local QUBO
    # Each shard holds ~1M docs at most; the number of shards grows linearly.
    n_shards_projected = max(1, n_docs // 1_000_000)
    # Assumption: each shard is pre-filtered to a fixed-size candidate set, so the
    # per-shard solve time stays at the measured demo value, and n_gpus // 4 solver
    # calls run in parallel. Only shard-local solves are counted (no global merge).
    pass2_sec = (total_qubo_time / N_SHARDS) * n_shards_projected / max(1, n_gpus // 4)
    # LSH work scales ~linearly with N
    pass2_sec += measured['lsh_time'] * scale_factor / n_gpus

    # Pass 3: Global merge QUBO over S * K_local candidates (capped at 100K)
    # Dense QUBO size is O((S * K_local)^2), so time is scaled quadratically.
    global_candidates = min(n_shards_projected * K_LOCAL, 100_000)
    pass3_sec = global_time * (global_candidates / len(all_shard_selected)) ** 2
    pass3_sec = min(pass3_sec, 300)  # Assumed solver-time budget of ~5 min per call

    total_sec = pass1_sec + pass2_sec + pass3_sec

    print(f"{name:<25} {n_docs:>15,} {n_gpus:>6} {fmt_time(pass1_sec):>12} "
          f"{fmt_time(pass2_sec):>12} {fmt_time(pass3_sec):>12} {fmt_time(total_sec):>12}")

print(f"\n--- Memory Requirements ---")
print(f"{'Component':<25} {'Per Doc':>10} {'2B docs':>15}")
print("-" * 55)
mem_items = [
    ('Surprise (float32)', 4, TOTAL_DOCS_1T * 4),
    ('MinHash sig (128*u32)', 512, TOTAL_DOCS_1T * 512),
    ('SimHash (u64)', 8, TOTAL_DOCS_1T * 8),
    ('Duplicate flag (bool)', 1, TOTAL_DOCS_1T),
    ('Shard metadata', 16, TOTAL_DOCS_1T * 16),
]
total_mem = 0
for comp_name, per_doc, total_bytes in mem_items:
    total_mem += total_bytes
    if total_bytes > 1e12:
        print(f"{comp_name:<25} {per_doc:>10} B {total_bytes/1e12:>12.1f} TB")
    else:
        print(f"{comp_name:<25} {per_doc:>10} B {total_bytes/1e9:>12.1f} GB")
print("-" * 55)
print(f"{'Total':<25} {'':>12} {total_mem/1e12:>12.1f} TB")
print(f"\n  Note: MinHash signatures dominate. Options to reduce:")
print(f"    - Use 64 perms instead of 128 (uint32) -> {TOTAL_DOCS_1T * 64 * 4 / 1e12:.1f} TB")
print(f"    - Keep 128 perms but store uint16 hashes (b-bit MinHash) -> {TOTAL_DOCS_1T * 128 * 2 / 1e12:.1f} TB")
print(f"    - Shard-local LSH (no global MinHash storage, but cross-shard near-duplicates are missed)")

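The uint16 option above is essentially b-bit MinHash: keep only the low 16 bits of each hash, then correct the Jaccard estimate for the 2^-16 chance that two different minima agree in those bits. A minimal sketch, assuming a universal hash family h(x) = (a·x + b) mod (2^31 - 1) over blake2b token hashes; `minhash`, `truncate16`, and `jaccard_estimate` are hypothetical helpers and may not match how the earlier MinHash cell builds signatures.

```python
import hashlib

import numpy as np

P = (1 << 31) - 1  # Mersenne prime 2^31 - 1: keeps a*x + b below 2^63 in int64


def token_ids(tokens):
    # Stable per-token hash reduced mod P (Python's hash() is salted per process)
    return np.array(
        [int.from_bytes(hashlib.blake2b(t.encode(), digest_size=8).digest(), "little") % P
         for t in tokens],
        dtype=np.int64,
    )


def minhash(tokens, num_perm=128, seed=0):
    """uint32 MinHash signature from the hash family h_i(x) = (a_i * x + b_i) mod P."""
    rng = np.random.default_rng(seed)  # same seed => same hash functions for every doc
    a = rng.integers(1, P, size=(num_perm, 1), dtype=np.int64)
    b = rng.integers(0, P, size=(num_perm, 1), dtype=np.int64)
    x = token_ids(tokens)[None, :]  # shape (1, n_tokens)
    return ((a * x + b) % P).min(axis=1).astype(np.uint32)


def truncate16(sig):
    """b-bit MinHash with b = 16: keep only the low 16 bits of each hash."""
    return (sig & 0xFFFF).astype(np.uint16)


def jaccard_estimate(s1, s2, bits=None):
    """Fraction of matching slots; for b-bit hashes, remove the 2^-b chance-match rate."""
    match = float(np.mean(s1 == s2))
    if bits is None:
        return match
    c = 2.0 ** -bits
    return (match - c) / (1.0 - c)


doc_a = set("the quick brown fox jumps over the lazy dog".split())
doc_b = set("the quick brown fox leaps over the lazy cat".split())
sa, sb = minhash(doc_a), minhash(doc_b)
print(sa.nbytes, truncate16(sa).nbytes)  # 512 256
# Both estimates should land near the true Jaccard, 6/10 = 0.6
print(jaccard_estimate(sa, sb), jaccard_estimate(truncate16(sa), truncate16(sb), bits=16))
```

The correction follows from P(match) = J + (1 - J)·2^-b. For b = 16 the chance-match rate is about 1.5e-5, so halving storage changes the estimate very little.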
---

## Part 8: D-Wave QPU Usage Analysis

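The Pass 3 projection in Part 7 scales global-merge time quadratically in the number of candidates. Before projecting Leap cost, it helps to see how large that dense QUBO actually gets. The sketch below counts upper-triangular entries n(n+1)/2 (the same counting that gives 5,050 entries for N=100 in Experiment 0), assuming ~1M docs per shard and K_LOCAL = 25 as in this notebook. `largest_hierarchical_qubo` and its `k_regional` parameter are hypothetical, illustrating optimization strategy 3 (local→regional→global). Whether a given size is feasible should be checked against the limits reported in `LeapHybridSampler().properties`.

```python
def dense_qubo_entries(n_vars: int) -> int:
    """Upper-triangular entries (diagonal included) of a fully dense QUBO: n(n+1)/2."""
    return n_vars * (n_vars + 1) // 2


def global_merge_qubo(n_docs: int, docs_per_shard: int = 1_000_000, k_local: int = 25):
    """Return (shards, global-merge variables, dense entries) for the 2-level pipeline."""
    n_shards = max(1, n_docs // docs_per_shard)
    n_vars = n_shards * k_local
    return n_shards, n_vars, dense_qubo_entries(n_vars)


def largest_hierarchical_qubo(n_shards, n_regions, k_local=25, k_regional=100):
    """Largest QUBO (in variables) when shards first merge into regions, then globally."""
    shards_per_region = -(-n_shards // n_regions)  # ceiling division
    regional_vars = shards_per_region * k_local
    global_vars = n_regions * k_regional
    return max(regional_vars, global_vars)


for n_docs in (1_000_000, 100_000_000, 2_000_000_000):
    shards, n_vars, entries = global_merge_qubo(n_docs)
    print(f"{n_docs:>15,} docs -> {shards:>5,} shards -> {n_vars:>7,} vars, {entries:>16,} entries")

# 2-level: 50,000 global variables; 3-level with 10 regions: largest QUBO has 5,000
print(largest_hierarchical_qubo(2000, n_regions=10))
```

At production scale the 2-level global merge reaches about 1.25B dense entries, which is why the hierarchical merge is worth considering even before cost.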
In [None]:
print("=" * 70)
print("D-WAVE QPU USAGE PROJECTION")
print("=" * 70)

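# D-Wave Leap pricing (approximate; Leap plans change, so verify current rates)
# Hybrid solver: ~$0.10 per minute of solver time
# Free tier: 20 minutes/month

print(f"\n--- Current Experiment ---")
n_qubo_calls = N_SHARDS + 1  # shards + global
# Measured wall-clock per call (includes network and queue latency), so this
# may overstate billed solver time
avg_call_time = (total_qubo_time + global_time) / n_qubo_calls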
print(f"  QUBO calls: {n_qubo_calls}")
print(f"  Avg call time: {avg_call_time:.1f}s")
print(f"  Total solver time: {total_qubo_time + global_time:.1f}s")

print(f"\n--- Production Projection (2B docs) ---")
prod_shards = TOTAL_DOCS_1T // 1_000_000  # 2000 shards
prod_qubo_calls = prod_shards + 1
prod_solver_minutes = prod_qubo_calls * avg_call_time / 60
prod_cost = prod_solver_minutes * 0.10

print(f"  Shards: {prod_shards:,}")
print(f"  QUBO calls: {prod_qubo_calls:,}")
print(f"  Estimated solver time: {prod_solver_minutes:,.0f} minutes ({prod_solver_minutes/60:.0f} hours)")
print(f"  Estimated cost: ${prod_cost:,.0f}")

print(f"\n--- Optimization Strategies ---")
print(f"  1. Pre-filter top-50% by surprise before QUBO (halves variables, ~4x fewer quadratic terms)")
print(f"  2. Use simulated annealing (SA) for low-value shards, quantum only for competitive shards")
print(f"  3. Hierarchical merge: local→regional→global (3 levels instead of 2)")
print(f"  4. Cache QUBO solutions for shards with similar surprise distributions")

# Optimized projection
opt_shards = prod_shards
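pct_quantum = 0.1  # Only 10% of shards use quantum; the rest use SA locally
n_regional = 10    # regional merge QUBOs (hierarchical strategy 3)
opt_quantum_calls = int(opt_shards * pct_quantum) + n_regional + 1  # + regional + global
opt_solver_min = opt_quantum_calls * avg_call_time / 60
opt_cost = opt_solver_min * 0.10
print(f"\n  Optimized ({pct_quantum:.0%} quantum shards, {n_regional} regional merges):")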
print(f"    Quantum calls: {opt_quantum_calls:,}")
print(f"    Solver time: {opt_solver_min:,.0f} minutes ({opt_solver_min/60:.0f} hours)")
print(f"    Cost: ${opt_cost:,.0f}")

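The cost cell above multiplies a call count by average wall-clock time per call and an approximate per-minute price. Wrapping that arithmetic in a function makes it easy to test sensitivity to each assumption. A minimal sketch: `project_cost` and the 6-second average call time in the example are hypothetical, and the $0.10/min default is the notebook's own approximate figure, not a quoted Leap price.

```python
from dataclasses import dataclass


@dataclass
class CostProjection:
    leap_calls: int
    solver_minutes: float
    cost_usd: float


def project_cost(n_shards, avg_call_sec, price_per_min=0.10, frac_quantum=1.0, n_regional=0):
    """Leap calls = quantum shard solves + regional merges + 1 global merge."""
    calls = round(n_shards * frac_quantum) + n_regional + 1
    minutes = calls * avg_call_sec / 60
    return CostProjection(calls, minutes, minutes * price_per_min)


# Hypothetical 6 s per call, 2,000 production shards
baseline = project_cost(2000, avg_call_sec=6.0)
optimized = project_cost(2000, avg_call_sec=6.0, frac_quantum=0.1, n_regional=10)
print(baseline)
print(optimized)
```

Because cost is linear in both the call count and `avg_call_sec`, measuring billed time per call matters as much as cutting the number of quantum shards.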
---

## Part 9: Summary and Next Steps

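The summary reports an average pairwise SimHash distance for the selected documents. Assuming that metric is the mean normalized Hamming distance over all pairs, a minimal sketch of both the fingerprint and the metric looks like this. `simhash64` uses unweighted blake2b token hashes and is a hypothetical helper that may differ from the notebook's SimHash cell.

```python
import hashlib
from itertools import combinations


def simhash64(tokens):
    """64-bit SimHash: each bit is the sign of a +1/-1 vote over the token hashes."""
    votes = [0] * 64
    for t in tokens:
        h = int.from_bytes(hashlib.blake2b(t.encode(), digest_size=8).digest(), "little")
        for i in range(64):
            votes[i] += 1 if (h >> i) & 1 else -1
    return sum(1 << i for i in range(64) if votes[i] > 0)


def hamming(a, b):
    """Number of differing bits between two fingerprints."""
    return bin(a ^ b).count("1")


def avg_pairwise_distance(fingerprints, bits=64):
    """Mean normalized Hamming distance over all pairs (0 = identical, ~0.5 = unrelated)."""
    pairs = list(combinations(fingerprints, 2))
    if not pairs:
        return 0.0
    return sum(hamming(a, b) for a, b in pairs) / (len(pairs) * bits)


fps = [simhash64(doc.split()) for doc in (
    "the quick brown fox",
    "lorem ipsum dolor sit amet",
    "quantum annealing selects training data",
)]
print(avg_pairwise_distance(fps))
```

For K_GLOBAL = 20 there are only 190 pairs, so the O(K²) loop is negligible next to the QUBO solve.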
In [None]:
print("=" * 70)
print("EXPERIMENT 1 COMPLETE")
print("=" * 70)

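# Projection assumptions (not measured): each A100 runs ~50x the demo's
# surprise throughput, and the work splits evenly across 64 GPUs
A100_SPEEDUP = 50
N_A100 = 64
demo_docs_per_sec = len(texts) / surprise_time
proj_hours = TOTAL_DOCS_1T / (demo_docs_per_sec * A100_SPEEDUP * N_A100) / 3600

print(f"""
Key Results:
  1. Streaming surprise computation: {demo_docs_per_sec:.0f} docs/sec
     → Projects to {TOTAL_DOCS_1T/1e9:.0f}B docs in ~{proj_hours:.0f} hours on {N_A100}x A100 (assumes ~{A100_SPEEDUP}x per-GPU speedup)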

  2. MinHash LSH deduplication: {is_duplicate.sum()} near-duplicates removed ({is_duplicate.mean()*100:.1f}%)
     → Constant per-document cost (one bucket probe per LSH band), shard-parallelizable

  3. SimHash diversity: avg pairwise distance {quantum_diversity:.3f} (quantum-selected set)
     → 8 bytes/doc fingerprint, enables QUBO diversity term

  4. Hierarchical QUBO: {N_SHARDS} shards × {K_LOCAL} selections → {K_GLOBAL} via global merge
     → Quantum surprise: {quantum_surprise:.4f}
     → Top-K surprise:   {topk_surprise:.4f}
     → Random surprise:  {random_avg_surprise:.4f}

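Architecture Validated (demo scale, N={N_SAMPLES:,}):
  - 3-pass pipeline (Stream → Shard QUBO → Global Merge)
  - Passes 1 and 2 parallelize across shards; Pass 3 is a single global merge
  - Shard QUBO size is bounded by shard size, not total corpus;
    the global merge QUBO grows with S × K_local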
  - D-Wave hybrid solver handles shard-local optimization

Next Steps (Experiment 2):
  1. Downstream validation: train small LM on quantum-selected vs random data
  2. Measure perplexity/benchmark improvement
  3. Test with larger corpus (WikiText-103 full or C4 subset)
  4. Optimize QUBO parameters (alpha, beta, delta, gamma) via grid search
  5. Compare with D4 / DSIR / other data selection baselines
""")