# Subject-Driven Generation Metrics Evaluation

##### We have a dataset of 30 distinct “subjects” (objects and live subjects/pets). Each subject has two sets of images:
##### - **Real images**: stored under `data/<subject_name>/…`
##### - **Generated images**: stored under `results/{no_ppl,ppl}/<subject_name>/…`

##### We also have:
##### - `data/subjects.csv` with columns:
#####     - `subject_name` (matches each folder name)
#####     - `class`        (e.g. “dog”, “backpack”, etc.)
#####     - `live`         (boolean: True for pets, False for objects)
##### - `data/prompts.csv` with columns:
#####     - `prompt` (templates containing `{0}` → “sks”, `{1}` → the `class` value)
#####     - `live`   (boolean: whether this prompt applies to live subjects or objects)
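
As a minimal sketch of how the templates described above could be filled, the snippet below matches prompts to subjects on the `live` flag and substitutes `{0}` and `{1}` with `str.format`. The CSV rows and the helper names `read_rows` and `prompts_for` are made up for illustration and are not part of the project code.

```python
import csv
import io

# Hypothetical rows mirroring the two CSV schemas described above.
SUBJECTS_CSV = """subject_name,class,live
dog2,dog,True
backpack,backpack,False
"""
PROMPTS_CSV = """prompt,live
a {0} {1} on the beach,False
a {0} {1} wearing a red hat,True
"""

def read_rows(text):
    """Parse CSV text into dicts, converting the `live` column to bool."""
    rows = list(csv.DictReader(io.StringIO(text)))
    for row in rows:
        row["live"] = row["live"] == "True"
    return rows

def prompts_for(subject, prompt_rows, token="sks"):
    """Fill every template whose `live` flag matches the subject's."""
    return [
        p["prompt"].format(token, subject["class"])
        for p in prompt_rows
        if p["live"] == subject["live"]
    ]

subjects = read_rows(SUBJECTS_CSV)
prompt_rows = read_rows(PROMPTS_CSV)
filled = {s["subject_name"]: prompts_for(s, prompt_rows) for s in subjects}
print(filled["dog2"])
```

With pandas, the same filter would be `prompts_df[prompts_df['live'] == row['live']]`, since `read_csv` already parses the `True`/`False` column as booleans.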


##### **Evaluation protocol**:
##### - We generated up to **4** samples per prompt for `ppl`, but only **2** for `no_ppl`.
##### - Metrics:
#####     1. **PRES** (average pairwise DINO similarity between real and generated images)
#####     2. **DIV**  (average pairwise LPIPS distance among generated images)
#####     3. **CLIP-I** (average cosine similarity between CLIP image embeddings of real and generated images)
#####     4. **CLIP-T** (average cosine similarity between CLIP text embeddings of the prompts and CLIP image embeddings of the generated images)
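
The two averaging patterns behind these metrics can be sketched as follows: a cross-set average over every (real, generated) pair, as in PRES and CLIP-I, and a within-set average over every unordered pair of generated items, as in DIV. This is an illustration on toy vectors, not the code in the `metrics` package; the function names are hypothetical, and the distance used for the DIV-like score is a stand-in for LPIPS, which operates on images rather than embeddings.

```python
from itertools import combinations, product
from math import sqrt

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (sqrt(sum(a * a for a in u)) * sqrt(sum(b * b for b in v)))

def mean_cross_similarity(real, gen, sim=cosine):
    """Average `sim` over every (real, generated) pair: the PRES / CLIP-I pattern."""
    pairs = list(product(real, gen))
    return sum(sim(r, g) for r, g in pairs) / len(pairs)

def mean_within_distance(gen, dist):
    """Average `dist` over every unordered pair of generated items: the DIV pattern."""
    pairs = list(combinations(gen, 2))
    if not pairs:
        return None  # fewer than two samples: diversity is undefined
    return sum(dist(a, b) for a, b in pairs) / len(pairs)

# Toy 2-D "embeddings" (made up for illustration).
real = [[1.0, 0.0], [0.0, 1.0]]
gen = [[1.0, 0.0], [1.0, 0.0]]

pres_like = mean_cross_similarity(real, gen)
# Stand-in distance; the notebook's DIV uses LPIPS on images instead.
div_like = mean_within_distance(gen, lambda a, b: 1.0 - cosine(a, b))
```

If the within-set average is taken per prompt (an assumption about the protocol), 4 samples give 6 unordered pairs while 2 samples give only 1, so the `no_ppl` diversity estimate rests on far fewer comparisons.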

##### This notebook loops over every `(condition, subject)` pair, computes all four metrics, and tabulates the results.

In [1]:
import os
import pandas as pd
from metrics import clip_embeddings, div, pres

In [2]:
subjects_df = pd.read_csv('../data/subjects.csv')
prompts_df  = pd.read_csv('../data/prompts.csv')

REAL_ROOT  = '../data'
GEN_ROOT  = '../results'
CONDITIONS = ['no_ppl', 'ppl']

In [6]:
results = []
for cond in CONDITIONS:
    data_root = os.path.join(REAL_ROOT, cond)
    res_root = os.path.join(GEN_ROOT, cond)
    for _, row in subjects_df.iterrows():
        subject = row['subject_name']

        real_dir = os.path.join(data_root, subject)
        gen_dir  = os.path.join(res_root, subject)
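        # Skip subjects that are missing either real or generated images.
        if not os.path.isdir(real_dir) or not os.path.isdir(gen_dir):
            continue

        preservation = pres.collect_pres(real_dir, gen_dir)
        # DIV and the CLIP metrics are not computed yet; keep None placeholders.
        div_score = None
        clip_i, clip_t = None, None
        # clip_i, clip_t = clip_embeddings.collect_clip_metrics(real_dir, gen_dir, prompts_df['prompt'].tolist())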

        results.append({
            'condition': cond,
            'subject':   subject,
            'PRES':      preservation,
            'DIV':       div_score,
            'CLIP-I':    clip_i,
            'CLIP-T':    clip_t
        })

        print(f'{cond} {subject} {preservation:.4f}')

results_df = pd.DataFrame(results)

Using cache found in /Users/ignazioperez/.cache/torch/hub/facebookresearch_dino_main

no_ppl backpack 0.6328
no_ppl backpack_dog 0.6952
no_ppl bear_plushie 0.8157
no_ppl berry_bowl 0.6374
no_ppl can 0.6469
no_ppl candle 0.6265
no_ppl cat 0.8191
no_ppl cat2 0.7594
no_ppl clock 0.5880
no_ppl colorful_sneaker 0.8492
no_ppl dog 0.7842
no_ppl dog2 0.8136
no_ppl dog3 0.6044
no_ppl dog5 0.7488
no_ppl dog6 0.7938
no_ppl dog7 0.7045
no_ppl dog8 0.7596
no_ppl duck_toy 0.7295
no_ppl fancy_boot 0.7509
no_ppl grey_sloth_plushie 0.7829
no_ppl monster_toy 0.4212
no_ppl pink_sunglasses 0.6201
no_ppl poop_emoji 0.6943
no_ppl rc_car 0.6491
no_ppl red_cartoon 0.5061
no_ppl robot_toy 0.6106
no_ppl shiny_sneaker 0.8460
no_ppl teapot 0.5380
no_ppl vase 0.6630

KeyboardInterrupt: 
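
Because the interrupt arrives inside the loop, the final `results_df = pd.DataFrame(results)` line of that cell never runs. One way to keep the rows that did finish is to catch `KeyboardInterrupt` around the loop. The sketch below shows the pattern with a hypothetical `run_eval` helper and a fake metric function standing in for `pres.collect_pres`; it is an illustration, not the notebook's code.

```python
def run_eval(jobs, compute, results):
    """Append one row per (condition, subject) job; on Ctrl-C keep finished rows."""
    try:
        for cond, subject in jobs:
            score = compute(cond, subject)
            results.append({"condition": cond, "subject": subject, "PRES": score})
    except KeyboardInterrupt:
        print(f"Interrupted after {len(results)} rows; keeping partial results.")
    return results

# Hypothetical stand-in for the metric call: raises on the third job to mimic Ctrl-C.
def fake_compute(cond, subject):
    if subject == "can":
        raise KeyboardInterrupt
    return 0.5

jobs = [("no_ppl", "backpack"), ("no_ppl", "cat"), ("no_ppl", "can"), ("no_ppl", "vase")]
rows = run_eval(jobs, fake_compute, [])
```

The returned list can then be passed to `pd.DataFrame(rows)` as usual, so a partial table with the expected columns exists even after an interrupted run.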

In [None]:
results_df.to_csv("metric_results.csv", index=False)

KeyError: 'subject'

In [None]:
merged = (
    results_df
    .merge(subjects_df, left_on='subject', right_on='subject_name')
    .drop(columns=['subject_name']) 
)

pivot = (
    merged
    .pivot_table(
       index='class',
       columns='condition',
       values=['PRES','DIV','CLIP-I','CLIP-T'],
       aggfunc='mean'
    )
)
print(pivot)

# Keep the index so the `class` labels are written to the CSV.
pivot.to_csv("average_metric_results.csv")
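
To make the shape of the merged and pivoted tables concrete, here is a toy version of the same steps. The scores are invented for illustration, and the `how="left"` and `indicator=True` arguments are an addition that is not in the cell above: they surface any subject in the results that has no row in `subjects.csv`.

```python
import pandas as pd

# Made-up scores for illustration only.
results_df = pd.DataFrame({
    "condition": ["no_ppl", "ppl", "no_ppl", "ppl"],
    "subject": ["dog", "dog", "dog2", "dog2"],
    "PRES": [0.70, 0.80, 0.60, 0.90],
})
subjects_df = pd.DataFrame({
    "subject_name": ["dog", "dog2"],
    "class": ["dog", "dog"],
    "live": [True, True],
})

# A left merge keeps every result row; `_merge` flags rows without a subject match.
merged = results_df.merge(
    subjects_df, left_on="subject", right_on="subject_name", how="left", indicator=True
)
unmatched = merged.loc[merged["_merge"] != "both", "subject"].tolist()

# One row per class, one column per condition, mean PRES in each cell.
pivot = merged.pivot_table(index="class", columns="condition", values="PRES", aggfunc="mean")
print(pivot)
```

Here the `class` labels live in the pivot's index, which is why writing the table with `index=False` would drop them from the CSV.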