# Prototype MVP 测试

逐步验证 `build_base_samples`、`PrototypeDataset`、`PrototypeHead` 与 `PrototypeTrainer`。


In [1]:
import sys
from pathlib import Path

repo_root = Path("..").resolve()
if str(repo_root) not in sys.path:
    sys.path.append(str(repo_root))

In [29]:
from pathlib import Path

from concept_graph.prototypes.prototype_dataset import build_base_samples, PrototypeDataset
from concept_graph.prototypes.prototype_head import PrototypeHead
from concept_graph.prototypes.prototype_trainer import PrototypeTrainer, PrototypeTrainerConfig

In [30]:
DATASET_JSON = Path("../data/dataset/wikiart_5artists_dataset.json")
DIMENSIONS = ["artist"]  # 缩小范围，加快验证
DIMENSION = "artist"
SAVE_DIR = Path("../outputs/prototypes")
SAVE_DIR.mkdir(parents=True, exist_ok=True)


In [31]:
base_samples = build_base_samples(DATASET_JSON, dimensions=DIMENSIONS)
# base_samples = base_samples[:]  # 仅取前 20 张，加快原型构建
print(f"Loaded {len(base_samples)} samples (subset). Example keys: {base_samples[0].keys()}")

Loaded 175 samples (subset). Example keys: dict_keys(['image_path', 'image', 'labels_per_dim'])


In [32]:
concept_head = PrototypeHead(device="cpu", precision="fp32", batch_size=4)
dataset = PrototypeDataset(base_samples, concept_head.preprocess, dimension=DIMENSION)
print(f"Dataset size: {len(dataset)} | Sample label: {dataset[0]['label']}")


Dataset size: 175 | Sample label: vincent-van-gogh


In [33]:
from collections import defaultdict

concept_to_paths = {DIMENSION: defaultdict(list)}
for sample in base_samples:
    label = sample["labels_per_dim"].get(DIMENSION)
    if label is None:
        continue
    concept_to_paths[DIMENSION][label].append(sample["image_path"])


In [36]:
prototypes = concept_head.build_prototypes(concept_to_paths,
                                            save_path=SAVE_DIR / f"{DIMENSION}_prototypes.pt")
print(f"Built {len(prototypes[DIMENSION])} prototypes for dimension '{DIMENSION}'.")


Built 5 prototypes for dimension 'artist'.


In [37]:
test_paths = [sample["image_path"] for sample in base_samples[:4]]
signals = concept_head.extract_signal(test_paths, dimension=DIMENSION)
for path, score_dict in signals.items():
    top_idx = max(score_dict, key=lambda k: score_dict[k][1].item())
    concept = concept_head.idx_to_concept[DIMENSION][top_idx]
    print(path.name, concept, score_dict[top_idx][1].item())


van_gogh_0.jpg vincent-van-gogh 0.7435899376869202
van_gogh_1.jpg vincent-van-gogh 0.8398466110229492
van_gogh_2.jpg vincent-van-gogh 0.7753981947898865
van_gogh_3.jpg claude-monet 0.7478918433189392


In [38]:
from torch.utils.data import DataLoader

def prototype_collate(batch):
    return {
        "image_anchor": torch.stack([b["image_anchor"] for b in batch]),
        "label": [b["label"] for b in batch],
        "labels_per_dim": [b["labels_per_dim"] for b in batch],
        "image_path": [str(b["image_path"]) for b in batch],
    }

dataloader = DataLoader(dataset, batch_size=2, collate_fn=prototype_collate, shuffle=True, drop_last=True)

In [40]:
import torch

trainer_cfg = PrototypeTrainerConfig(
    dimension=DIMENSION,
    batch_size=2,
    epochs=1,
    lr=5e-3,
    temperature=0.07,
    save_path=str(SAVE_DIR / f"{DIMENSION}_prototypes_finetuned.pt")
)
trainer = PrototypeTrainer(trainer_cfg, dataset, concept_head)

trainer.dataloader = DataLoader(
    dataset,
    batch_size=trainer_cfg.batch_size,
    collate_fn=prototype_collate,
    shuffle=True,
    drop_last=True
)
training_log = trainer.train()
training_log


{'epoch_0_loss': 0.28563610400105344}

In [41]:
finetuned_signals = concept_head.extract_signal(test_paths, dimension=DIMENSION)
for path, score_dict in finetuned_signals.items():
    top_idx = max(score_dict, key=lambda k: score_dict[k][1].item())
    concept = concept_head.idx_to_concept[DIMENSION][top_idx]
    print("[Finetuned]", path.name, concept, score_dict[top_idx][1].item())


  with amp.autocast(enabled=self.precision == "fp16"):


[Finetuned] van_gogh_0.jpg vincent-van-gogh 0.3491225242614746
[Finetuned] van_gogh_1.jpg vincent-van-gogh 0.408365398645401
[Finetuned] van_gogh_2.jpg vincent-van-gogh 0.29018086194992065
[Finetuned] van_gogh_3.jpg vincent-van-gogh 0.3063909709453583
