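# Prototype Step 2 Verification

This notebook verifies the following:
1. Concept signals can be precomputed from an existing prototype checkpoint and cached
2. `ConceptGraphDataset` can read the cache and return `concept_signals`
3. Cached results match the results computed online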


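## 0. Environment Setup
- Run this notebook in the virtual environment at the project root (the same one used for training) so that dependencies such as `torch` and `open_clip` are available.
- The notebook lives in `notebooks/`, so the project root must be added to `sys.path`.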


In [1]:
import sys
from pathlib import Path

REPO_ROOT = Path('..').resolve()
if str(REPO_ROOT) not in sys.path:
    sys.path.append(str(REPO_ROOT))
print(f"Repo root: {REPO_ROOT}")


Repo root: /Users/xieyantong/Documents/F25_GenAI/Final_project/MyVLM_art
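
`Path('..')` is resolved against the current working directory, so the cell above assumes the kernel was started inside `notebooks/`. A minimal sketch of a more defensive lookup, assuming the repo root can be recognized by its `scripts/` directory (the helper name `find_repo_root` and the choice of marker are illustrative and not part of the project):

```python
from pathlib import Path


def find_repo_root(start: Path, marker: str = "scripts") -> Path:
    """Walk upward from `start` until a directory containing `marker` is found.

    Both this helper and the default marker are assumptions for illustration.
    """
    start = start.resolve()
    for candidate in (start, *start.parents):
        if (candidate / marker).is_dir():
            return candidate
    raise FileNotFoundError(f"No parent of {start} contains '{marker}/'")
```

With this in place, `REPO_ROOT = find_repo_root(Path.cwd())` would give the same result whether the kernel starts in the repo root or in a subdirectory.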


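## 1. Configure Key Paths
Adjust the following paths to match your setup:
- `DATASET_JSON`: WikiArt dataset JSON
- `IMAGES_ROOT`: image root directory
- `PROTOTYPE_CKPT`: trained prototype checkpoint
- `SIGNALS_CACHE`: output path for the cache
- `DIMENSIONS`: concept dimensions to process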


In [2]:
DATASET_JSON = REPO_ROOT / 'data/dataset/wikiart_5artists_dataset.json'
IMAGES_ROOT = REPO_ROOT / 'data/dataset'
PROTOTYPE_CKPT = REPO_ROOT / 'outputs/prototypes/artist_prototypes.pt'
SIGNALS_CACHE = REPO_ROOT / 'artifacts/concept_signals_artist.json'
DIMENSIONS = ['artist']

print(DATASET_JSON)
print(IMAGES_ROOT)
print(PROTOTYPE_CKPT)
print(SIGNALS_CACHE)
print('Dimensions:', DIMENSIONS)


/Users/xieyantong/Documents/F25_GenAI/Final_project/MyVLM_art/data/dataset/wikiart_5artists_dataset.json
/Users/xieyantong/Documents/F25_GenAI/Final_project/MyVLM_art/data/dataset
/Users/xieyantong/Documents/F25_GenAI/Final_project/MyVLM_art/outputs/prototypes/artist_prototypes.pt
/Users/xieyantong/Documents/F25_GenAI/Final_project/MyVLM_art/artifacts/concept_signals_artist.json
Dimensions: ['artist']
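
Checking the inputs before launching the script can surface a wrong path early. A minimal sketch, assuming the three input paths must already exist (`missing_paths` is a hypothetical helper, not part of the project):

```python
from pathlib import Path
from typing import Iterable, List


def missing_paths(paths: Iterable[Path]) -> List[Path]:
    """Return the subset of `paths` that do not exist on disk."""
    return [Path(p) for p in paths if not Path(p).exists()]


# Possible usage in this notebook (names taken from the cell above):
# missing = missing_paths([DATASET_JSON, IMAGES_ROOT, PROTOTYPE_CKPT])
# if missing:
#     raise FileNotFoundError(f"Missing inputs: {missing}")
# Creating the cache directory up front is harmless if the script also does it:
# SIGNALS_CACHE.parent.mkdir(parents=True, exist_ok=True)
```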


## 2. Generate the Concept Signal Cache via the Script
Wrap `scripts/precompute_prototype_signals.py` in `subprocess.run` and print its output.


In [4]:
import subprocess
import os

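# Prepend the repo root to PYTHONPATH so the script can import project modules.
# os.pathsep is ':' on POSIX and ';' on Windows.
env = os.environ.copy()
if 'PYTHONPATH' in env:
    env['PYTHONPATH'] = f"{REPO_ROOT}{os.pathsep}{env['PYTHONPATH']}"
else:
    env['PYTHONPATH'] = str(REPO_ROOT)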

cmd = [
    sys.executable,
    str(REPO_ROOT / 'scripts/precompute_prototype_signals.py'),
    str(DATASET_JSON),
    str(IMAGES_ROOT),
    str(PROTOTYPE_CKPT),
    str(SIGNALS_CACHE),
]
cmd += ['--dimensions'] + DIMENSIONS
cmd += ['--device', 'cpu', '--batch-size', '4', '--chunk-size', '16', '--precision', 'fp32']
print('Running command:')
print(' '.join(cmd))
proc = subprocess.run(cmd, capture_output=True, text=True, env=env)
print(proc.stdout)
print(proc.stderr)
assert proc.returncode == 0, 'Precompute script failed'
print(f'Cache saved to {SIGNALS_CACHE}')


Running command:
/opt/miniconda3/envs/myvlm/bin/python /Users/xieyantong/Documents/F25_GenAI/Final_project/MyVLM_art/scripts/precompute_prototype_signals.py /Users/xieyantong/Documents/F25_GenAI/Final_project/MyVLM_art/data/dataset/wikiart_5artists_dataset.json /Users/xieyantong/Documents/F25_GenAI/Final_project/MyVLM_art/data/dataset /Users/xieyantong/Documents/F25_GenAI/Final_project/MyVLM_art/outputs/prototypes/artist_prototypes.pt /Users/xieyantong/Documents/F25_GenAI/Final_project/MyVLM_art/artifacts/concept_signals_artist.json --dimensions artist --device cpu --batch-size 4 --chunk-size 16 --precision fp32
Processing dimension 'artist' with 5 concepts...
Saved signals for 175 images to /Users/xieyantong/Documents/F25_GenAI/Final_project/MyVLM_art/artifacts/concept_signals_artist.json
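
The subprocess pattern used above can be folded into a small reusable helper. The following is a sketch under stated assumptions rather than project code: `run_with_repo_on_path` is a hypothetical name, and it raises an exception instead of using `assert`, since `assert` statements are skipped when Python runs with `-O`.

```python
import os
import subprocess
import sys
from pathlib import Path
from typing import Sequence


def run_with_repo_on_path(args: Sequence[str], repo_root: Path) -> subprocess.CompletedProcess:
    """Run `python <args>` with `repo_root` prepended to PYTHONPATH."""
    env = os.environ.copy()
    existing = env.get("PYTHONPATH")
    # os.pathsep keeps the separator correct on both POSIX and Windows.
    env["PYTHONPATH"] = f"{repo_root}{os.pathsep}{existing}" if existing else str(repo_root)
    proc = subprocess.run([sys.executable, *args], capture_output=True, text=True, env=env)
    if proc.returncode != 0:
        # Include stderr so the cause of the failure is visible in the traceback.
        raise RuntimeError(f"Command failed with code {proc.returncode}:\n{proc.stderr}")
    return proc
```

In this notebook the call would pass `cmd[1:]` (everything after `sys.executable`) together with `REPO_ROOT`.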



## 3. Inspect the Cache Structure
Load the JSON and look at the meta information and a few samples to confirm the structure is correct.


In [5]:
import json

with open(SIGNALS_CACHE, 'r') as f:
    cache_payload = json.load(f)
print('Meta:', cache_payload.get('meta', {}))
signals = cache_payload.get('signals', cache_payload)
print('Total cached samples:', len(signals))
first_items = list(signals.items())[:3]
for idx, val in first_items:
    print(idx, 'dimensions:', list(val.keys()))


Meta: {'dataset_json': '/Users/xieyantong/Documents/F25_GenAI/Final_project/MyVLM_art/data/dataset/wikiart_5artists_dataset.json', 'images_root': '/Users/xieyantong/Documents/F25_GenAI/Final_project/MyVLM_art/data/dataset', 'prototype_ckpt': '/Users/xieyantong/Documents/F25_GenAI/Final_project/MyVLM_art/outputs/prototypes/artist_prototypes.pt', 'dimensions': ['artist']}
Total cached samples: 175
0 dimensions: ['artist']
1 dimensions: ['artist']
2 dimensions: ['artist']
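
Printing three samples is a spot check. To cover every cached sample, a check along the following lines could be used. It is a sketch that assumes only what the output above shows, namely that each sample maps to a dict keyed by dimension name; `samples_missing_dimensions` is a hypothetical helper.

```python
from typing import Any, Dict, List


def samples_missing_dimensions(payload: Dict[str, Any], dimensions: List[str]) -> List[str]:
    """Return the keys of cached samples that lack any of the requested dimensions.

    Mirrors the notebook's fallback: if there is no top-level 'signals' key,
    the payload itself is treated as the sample mapping.
    """
    signals = payload.get("signals", payload)
    return [
        key
        for key, per_dimension in signals.items()
        if any(dim not in per_dimension for dim in dimensions)
    ]
```

Calling `samples_missing_dimensions(cache_payload, DIMENSIONS)` should return an empty list when the cache is complete. Note that JSON object keys are always strings, so the sample keys printed above are `'0'`, `'1'`, `'2'` rather than integers.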


## 4. Verify That ConceptGraphDataset Reads the Cache
Initialize the dataset with `precomputed_signals_path` and inspect `concept_signals`.


In [6]:
from concept_graph.datasets.concept_graph_dataset import ConceptGraphDataset

dataset = ConceptGraphDataset(
    dataset_path=DATASET_JSON,
    images_root=IMAGES_ROOT,
    precomputed_signals_path=SIGNALS_CACHE,
    prototype_head=None,
    transforms=None,
)
print('Dataset length:', len(dataset))
sample = dataset[0]
print('Sample image path:', sample['image_path'])
signals = sample['concept_signals']
print('Available dimensions:', list(signals.keys()))
artist_signals = signals['artist']  # fail with a clear KeyError if the dimension is missing
print('Number of artist concepts:', len(artist_signals))
first_idx = sorted(artist_signals.keys())[0]
print('Example concept idx:', first_idx)
print('Scores tensor:', artist_signals[first_idx])


Dataset length: 175
Sample image path: /Users/xieyantong/Documents/F25_GenAI/Final_project/MyVLM_art/data/dataset/van_gogh/van_gogh_0.jpg
Available dimensions: ['artist']
Number of artist concepts: 5
Example concept idx: 0
Scores tensor: tensor([0.2564, 0.7436])
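
The cell above picks the first concept with `sorted(artist_signals.keys())`. If those keys come straight from the JSON cache, they are strings, and strings sort lexicographically. Whether `ConceptGraphDataset` converts them to integers is not shown here, so the following only illustrates the pitfall with made-up scores; `int_keyed` is a hypothetical helper. With five concepts the two orderings coincide, but they diverge once there are more than ten.

```python
import json
from typing import Dict, List


def int_keyed(scores_by_concept: Dict[str, List[float]]) -> Dict[int, List[float]]:
    """Convert JSON string keys such as '0' back to integer concept indices."""
    return {int(key): value for key, value in scores_by_concept.items()}


# JSON round trip: integer keys come back as strings.
original = {0: [0.25, 0.75], 10: [0.5, 0.5], 2: [0.1, 0.9]}
restored = json.loads(json.dumps(original))

print(sorted(restored))             # ['0', '10', '2'] (lexicographic order)
print(sorted(int_keyed(restored)))  # [0, 2, 10] (numeric order)
```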


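## 5. Compare Online Computation with the Cache
Optional: load `PrototypeHead`, compute the signals online once, and compare them with the cached results.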


In [7]:
from concept_graph.prototypes.prototype_head import PrototypeHead

head = PrototypeHead(device='cpu', precision='fp32', batch_size=4)
head.load_prototypes(PROTOTYPE_CKPT)
inline = head.extract_signal([sample['image_path']], dimension='artist')
inline_scores = inline[Path(sample['image_path'])]
print('Cache tensor:', artist_signals[first_idx])
print('Inline tensor:', inline_scores[first_idx])




Cache tensor: tensor([0.2564, 0.7436])
Inline tensor: tensor([0.2564, 0.7436])
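
The two printed tensors agree to four decimal places. A programmatic comparison makes the check explicit and catches differences that printing would hide. This is a minimal sketch: `signals_match` is a hypothetical helper, and the tolerance of `1e-4` is an assumption chosen to match the printed precision, not a value taken from the project.

```python
import math
from typing import Any, List


def to_floats(values: Any) -> List[float]:
    """Accept a tensor-like object (anything with .tolist()) or a plain sequence."""
    raw = values.tolist() if hasattr(values, "tolist") else values
    return [float(v) for v in raw]


def signals_match(cached: Any, inline: Any, atol: float = 1e-4) -> bool:
    """Return True if both score vectors have equal length and agree within `atol`."""
    a, b = to_floats(cached), to_floats(inline)
    return len(a) == len(b) and all(
        math.isclose(x, y, rel_tol=0.0, abs_tol=atol) for x, y in zip(a, b)
    )
```

In this notebook the check would read `assert signals_match(artist_signals[first_idx], inline_scores[first_idx])`, and it could be repeated in a loop over all concept indices.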
