# counterfactual generation with CLIP sampling

In [None]:
from huggingface_hub import login
from diffusers import StableDiffusion3Pipeline
import torch



In [None]:
!pip install -q diffusers==0.36.0 transformers accelerate lpips

In [None]:
# ============================================================
# 0. Imports + basic setup
# ============================================================
import os, glob, random
from dataclasses import dataclass

import torch
from PIL import Image
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from diffusers import StableDiffusion3Img2ImgPipeline
from transformers import CLIPModel, CLIPProcessor

from google.colab import drive
# Mount Google Drive so the thesis folders under /content/drive are reachable.
drive.mount("/content/drive")

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

In [None]:
# ============================================================
# 1. Paths & global config
# ============================================================
# Root of the project on Drive; every other path below is derived from it.
BASE_DIR        = "/content/drive/MyDrive/thesis2"
# Folder globbed for neutral images in section 5.
NEUTRAL_DIR     = f"{BASE_DIR}/classifier_dataset/neutral"
# Folder where the selected counterfactual images are written.
CF_OUT_DIR      = f"{BASE_DIR}/phase4_clip_sampling_images"
# CSV that receives one metadata row per saved counterfactual.
CSV_CF_PATH     = f"{BASE_DIR}/phase4_clip_sampling_metadata.csv"

# Create the output folder up front so image saves cannot fail on a missing directory.
os.makedirs(CF_OUT_DIR, exist_ok=True)

In [None]:

# SD settings
# Source images are resized to WIDTH x HEIGHT before being passed to img2img.
HEIGHT  = 768
WIDTH   = 768
NUM_STEPS = 30          # passed as num_inference_steps
GUIDANCE = 7.0          # passed as guidance_scale
STRENGTH = 0.55         # how strong the transformation is in img2img

# NOTE(review): N_SOURCE only truncates `neutral_paths`; the sources actually
# processed come from CUSTOM_SOURCE_DIR and are not limited by it.
N_SOURCE = 5            # how many neutral sources
N_RAW_CF = 12           # how many raw CF per source before CLIP selection
N_KEEP_CF = 6           # how many we keep after CLIP scoring

# Per-candidate seeds are derived from this base (see generate_cfs_for_source).
BASE_SEED = 12345

# ============================================================
# 2. Load Stable Diffusion 3.5 **img2img** pipeline
# ============================================================


model_id = "stabilityai/stable-diffusion-3.5-medium"

# Hub access token for downloading the SD3.5 weights. Reuse an HF_TOKEN that an
# earlier cell may have defined, otherwise read the HF_TOKEN environment
# variable. A value of None lets the hub client fall back to credentials stored
# by `huggingface_hub.login()` instead of crashing on an undefined name.
HF_TOKEN = globals().get("HF_TOKEN") or os.environ.get("HF_TOKEN")

# Half precision only makes sense on the GPU; use float32 when running on CPU.
sd_dtype = torch.float16 if device.type == "cuda" else torch.float32

pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
    model_id,
    token=HF_TOKEN,
    torch_dtype=sd_dtype,
)
pipe = pipe.to(device)

# light memory optimizations
pipe.enable_attention_slicing()
# Silence the per-call denoising progress bar; the outer loop has its own tqdm bar.
pipe.set_progress_bar_config(disable=True)

print("SD3.5 Img2Img pipeline loaded.")

# ============================================================
# 3. Load CLIP model (for semantic scoring)
# ============================================================
# Model and processor come from the same checkpoint id so preprocessing matches the weights.
clip_model_id = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(clip_model_id).to(device)
clip_processor = CLIPProcessor.from_pretrained(clip_model_id)
# Inference only: switch the model to eval mode.
clip_model.eval()

@torch.no_grad()
def clip_text_embeddings(texts):
    """Encode a list of strings with the CLIP text tower.

    Returns one embedding row per input string, each divided by its own
    L2 norm so that dot products between rows are cosine similarities.
    """
    batch = clip_processor(text=texts, images=None, return_tensors="pt", padding=True)
    batch = batch.to(device)
    feats = clip_model.get_text_features(**batch)
    return feats / feats.norm(dim=-1, keepdim=True)

# Text anchors for the two expressions; the smile score contrasts them.
neutral_text = "portrait photo of a person with a neutral facial expression"
smile_text   = "portrait photo of a person with a big visible smile, teeth showing, joyful expression"

# Embed both anchors once; the slices keep a leading row axis so each anchor
# can be matrix-multiplied against an image embedding.
txt_embs = clip_text_embeddings([neutral_text, smile_text])
txt_neutral_emb = txt_embs[0:1]
txt_smile_emb   = txt_embs[1:2]

@torch.no_grad()
def clip_smile_score(img: Image.Image) -> float:
    """
    Score how much more an image resembles the smile text than the neutral text.

    The image embedding is L2-normalized and compared with the two
    precomputed text anchors:
      score = sim(image, smile_text) - sim(image, neutral_text)
    A larger score means the image looks more like a smile.
    """
    pixel_inputs = clip_processor(text=None, images=img, return_tensors="pt").to(device)
    emb = clip_model.get_image_features(**pixel_inputs)
    emb = emb / emb.norm(dim=-1, keepdim=True)

    smile_sim, neutral_sim = (
        (emb @ anchor.T).item() for anchor in (txt_smile_emb, txt_neutral_emb)
    )
    return float(smile_sim - neutral_sim)

print("CLIP model loaded.")

# ============================================================
# 4. Prompts
# ============================================================
# Shared description of the portrait; both expression prompts extend it so the
# two prompts differ only in their expression clause.
BASE_FACE_PROMPT = (
    "a photorealistic portrait of a human face, studio lighting, high resolution, "
    "natural skin texture, realistic anatomy, professional photography, symmetric face, looking forward"
)

# NOTE(review): NEUTRAL_PROMPT is printed below but is not passed to the
# pipeline in the visible code; only SMILE_PROMPT is used for generation.
NEUTRAL_PROMPT = (
    BASE_FACE_PROMPT +
    ", neutral facial expression, relaxed mouth, closed lips, no smile"
)

# Prompt used for the img2img counterfactual edit.
SMILE_PROMPT = (
    BASE_FACE_PROMPT +
    ", big, bright smile, teeth clearly visible, joyful and expressive, cheeks raised, eyes slightly squinting"
)

print("Neutral prompt:\n ", NEUTRAL_PROMPT)
print("Smile prompt:\n ", SMILE_PROMPT)

# ============================================================
# 5. Build df_sources from neutral dataset (existing SD images)
# ============================================================
# Collect png/jpg/jpeg files from the neutral classifier dataset, sorted by path.
# NOTE(review): `neutral_paths` is not read again in the visible code -- the
# sources used for generation are globbed from CUSTOM_SOURCE_DIR below -- yet an
# empty NEUTRAL_DIR still aborts the run here. Confirm this is intended.
neutral_paths = sorted(
    glob.glob(os.path.join(NEUTRAL_DIR, "*.png")) +
    glob.glob(os.path.join(NEUTRAL_DIR, "*.jpg")) +
    glob.glob(os.path.join(NEUTRAL_DIR, "*.jpeg"))
)

if len(neutral_paths) == 0:
    raise RuntimeError(f"No images found in {NEUTRAL_DIR}")

# Keep only the first N_SOURCE paths.
neutral_paths = neutral_paths[:N_SOURCE]

import glob, os
import pandas as pd

CUSTOM_SOURCE_DIR = "/content/drive/MyDrive/thesis2/custom_sources"

# Only regular image files are valid sources. A bare "*" glob also returns
# sub-folders and stray non-image files, which would crash later inside
# Image.open(). The extension set matches list_images() in the baseline section.
SOURCE_IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".webp")

source_paths = sorted(
    p for p in glob.glob(os.path.join(CUSTOM_SOURCE_DIR, "*"))
    if os.path.isfile(p) and p.lower().endswith(SOURCE_IMAGE_EXTS)
)
print("Found", len(source_paths), "source images")

# Fail early with a clear message instead of silently producing an empty run.
if not source_paths:
    raise RuntimeError(f"No images found in {CUSTOM_SOURCE_DIR}")

# source_id is the position of the path in the sorted list.
df_sources = pd.DataFrame({
    "source_id": list(range(len(source_paths))),
    "source_path": source_paths
})

df_sources.head()


# ============================================================
# 6. Helper: generate CLIP-sampled smiling CFs for ONE source
# ============================================================
@dataclass
class CFEntry:
    # Metadata record for one counterfactual image that was kept and saved.
    source_id: int      # id of the source image the counterfactual was generated from
    cf_rank: int        # position after sorting candidates by CLIP score (0 = highest)
    raw_idx: int        # index of the candidate in generation order
    seed: int           # seed given to the torch.Generator for this candidate
    clip_score: float   # value returned by clip_smile_score() for the image
    cf_path: str        # path the image was saved to

@torch.no_grad()
def generate_cfs_for_source(source_id: int, source_path: str) -> list:
    """
    Produce the CLIP-selected smiling counterfactuals for one source image.

    The source is loaded as RGB and resized to WIDTH x HEIGHT. N_RAW_CF img2img
    candidates are generated with SMILE_PROMPT, each from its own deterministic
    seed, and scored with clip_smile_score(). The N_KEEP_CF highest-scoring
    candidates are written to CF_OUT_DIR.

    Returns a list of CFEntry records, one per saved image, ordered by rank.
    """
    source_image = (
        Image.open(source_path)
        .convert("RGB")
        .resize((WIDTH, HEIGHT), Image.LANCZOS)
    )

    scored = []
    for raw_idx in range(N_RAW_CF):
        # A distinct seed per (source, candidate) pair keeps runs repeatable.
        seed = BASE_SEED + source_id * 1000 + raw_idx
        rng = torch.Generator(device=device).manual_seed(seed)

        result = pipe(
            prompt=SMILE_PROMPT,
            image=source_image,
            strength=STRENGTH,
            num_inference_steps=NUM_STEPS,
            guidance_scale=GUIDANCE,
            generator=rng,
        )
        candidate = result.images[0]
        scored.append((raw_idx, seed, clip_smile_score(candidate), candidate))

    # Highest CLIP smile score first; the sort is stable, so ties keep generation order.
    best = sorted(scored, key=lambda item: item[2], reverse=True)[:N_KEEP_CF]

    entries = []
    for rank, (raw_idx, seed, score, image) in enumerate(best):
        out_path = os.path.join(
            CF_OUT_DIR,
            f"src{source_id:03d}_cf{rank:02d}_raw{raw_idx:02d}_seed{seed}_clip.png",
        )
        image.save(out_path)
        entries.append(
            CFEntry(
                source_id=source_id,
                cf_rank=rank,
                raw_idx=raw_idx,
                seed=seed,
                clip_score=score,
                cf_path=out_path,
            )
        )

    return entries




In [None]:
# ============================================================
# 7. Run CLIP sampling for all sources
# ============================================================
# Generate and select counterfactuals for every source, collecting the records.
all_entries = []

source_iter = tqdm(df_sources.iterrows(), total=len(df_sources), desc="Sources")
for _, row in source_iter:
    sid = int(row["source_id"])
    spath = row["source_path"]
    print(f"\n=== Source {sid} ===")
    print(spath)

    entries = generate_cfs_for_source(sid, spath)
    all_entries += entries

# Flatten the records into plain dicts; the score column is named
# "clip_smile_score" in the CSV.
rows = [
    {
        "source_id": e.source_id,
        "cf_rank": e.cf_rank,
        "raw_idx": e.raw_idx,
        "seed": e.seed,
        "clip_smile_score": e.clip_score,
        "cf_path": e.cf_path,
    }
    for e in all_entries
]

df_cf = pd.DataFrame(rows)
df_cf.to_csv(CSV_CF_PATH, index=False)
print("\nSaved CLIP-sampled CF metadata to:", CSV_CF_PATH)
display(df_cf.head())



In [None]:
# ============================================================
# 8. Quick visual sanity check
# ============================================================
from IPython.display import display as ipy_display

def show_source_and_cfs(source_id: int, max_show: int = 4):
    """Display a source image followed by up to `max_show` of its best-ranked counterfactuals."""
    matches = df_sources[df_sources["source_id"] == source_id]
    src_row = matches.iloc[0]
    source_path = src_row["source_path"]

    print(f"\nSource {source_id}:", source_path)
    ipy_display(Image.open(source_path).convert("RGB").resize((WIDTH, HEIGHT)))

    # Lowest cf_rank first, i.e. the highest CLIP-scored counterfactuals.
    top_cfs = (
        df_cf[df_cf["source_id"] == source_id]
        .sort_values("cf_rank")
        .head(max_show)
    )
    for _, cf_row in top_cfs.iterrows():
        picture = Image.open(cf_row["cf_path"]).convert("RGB")
        print(f"CF rank={cf_row['cf_rank']}  score={cf_row['clip_smile_score']:.3f}")
        ipy_display(picture)

# Example usage:
show_source_and_cfs(0, max_show=4)

In [None]:
# See a few sources + their top smiling CFs
for _sid in (0, 1, 2):
    show_source_and_cfs(_sid, max_show=6)


# Analysis

In [None]:
!pip install -q diffusers==0.36.0 transformers lpips scikit-image

In [None]:
import os, glob, itertools
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

from skimage.metrics import structural_similarity as ssim
import lpips

from transformers import CLIPModel, CLIPProcessor

from google.colab import drive
# Mount Drive again so this analysis section also works in a fresh runtime.
drive.mount("/content/drive")

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


In [None]:
BASE_DIR     = "/content/drive/MyDrive/thesis2"
# Metadata CSV written by the generation section.
CF_CSV_PATH  = f"{BASE_DIR}/phase4_clip_sampling_metadata.csv"


# NOTE(review): the generation section assigns source_id by sorting the files
# in ".../thesis2/custom_sources", while this folder is
# ".../classifier_dataset/neutral". If the two folders differ, the source_id
# values in the CF metadata will not refer to the images listed here. Confirm
# which folder is intended.
SOURCES_DIR  = f"{BASE_DIR}/classifier_dataset/neutral"


# Output CSVs for the evaluation results.
CSV_PER_PATH      = f"{BASE_DIR}/phase4_cf_eval_per_image.csv"
CSV_PAIRWISE_PATH = f"{BASE_DIR}/phase4_cf_eval_pairwise.csv"


In [None]:
# Gather png/jpg/jpeg files from SOURCES_DIR in sorted order; a file's position
# in that order becomes its source_id.
source_paths = sorted(
    path
    for ext in ("png", "jpg", "jpeg")
    for path in glob.glob(os.path.join(SOURCES_DIR, f"*.{ext}"))
)

source_ids = list(range(len(source_paths)))
df_sources = pd.DataFrame({
    "source_id": source_ids,
    "source_path": source_paths,
})
print("Sources:")
df_sources.head()


In [None]:
# Load the counterfactual metadata saved by the generation section.
df_cf = pd.read_csv(CF_CSV_PATH)
df_cf.head()


In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision import models
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

ROOT = "/content/drive/MyDrive/thesis2"
os.makedirs(ROOT, exist_ok=True)
# Checkpoint of the trained smile classifier.
CLASSIFIER_CKPT = f"{ROOT}/models/smile_classifier_best.pt"

# Channel statistics used to normalize classifier inputs.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

# Classifier preprocessing. It is built from `T` (imported in this cell) and the
# constants above, so the cell does not depend on another cell's `transforms`
# import and the statistics are declared in exactly one place.
clf_transform = T.Compose([
    T.Resize((224, 224)),  # NO central cropping
    T.ToTensor(),
    T.Normalize(
        mean=IMAGENET_MEAN,
        std=IMAGENET_STD,
    ),
])


class SmileResNet18(nn.Module):
    """ResNet-18 whose final fully connected layer is replaced by a `num_classes`-way linear head."""

    def __init__(self, num_classes: int = 2):
        super().__init__()
        # Start from the ImageNet-pretrained network, then swap in a fresh head.
        self.backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        self.backbone.fc = nn.Linear(self.backbone.fc.in_features, num_classes)

    def forward(self, x):
        """Return the backbone's logits for the batch `x`."""
        return self.backbone(x)

# build model on device
classifier = SmileResNet18(num_classes=2).to(device)

# load checkpoint onto same device
state = torch.load(CLASSIFIER_CKPT, map_location=device)
# NOTE(review): strict=False means mismatched keys are only reported, not
# rejected. If "Missing keys" below is non-empty, those layers keep their
# initial weights -- check the printed lists before trusting the results.
missing, unexpected = classifier.load_state_dict(state, strict=False)
print("Loaded classifier checkpoint.")
print("Missing keys:", missing)
print("Unexpected keys:", unexpected)

# Inference only: switch the model to eval mode.
classifier.eval()
print("Classifier param device:", next(classifier.parameters()).device)

@torch.no_grad()
def classifier_prob_smile(pil_img: Image.Image) -> float:
    """
    Run the smile classifier on one PIL RGB image.

    The image goes through `clf_transform`, is wrapped in a batch of one, and
    the softmax over the logits is taken.
    Returns: probability of class index 1 ('smiling').
    """
    batch = clf_transform(pil_img).unsqueeze(0).to(device)
    smile_prob = F.softmax(classifier(batch), dim=1)[0, 1]
    return float(smile_prob.item())


In [None]:
# CLIP model and processor for the image-to-image similarity metric; both come
# from the same checkpoint id.
clip_model_id = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(clip_model_id).to(device)
clip_processor = CLIPProcessor.from_pretrained(clip_model_id)
# Inference only: switch the model to eval mode.
clip_model.eval()

@torch.no_grad()
def clip_image_embedding(pil_img: Image.Image) -> torch.Tensor:
    """Return the CLIP image embedding of `pil_img`, divided by its L2 norm,
    with the leading axis squeezed away and cast to float32."""
    pixel_inputs = clip_processor(images=pil_img, return_tensors="pt").to(device)
    feats = clip_model.get_image_features(**pixel_inputs)
    unit = feats / feats.norm(dim=-1, keepdim=True)
    return unit.squeeze(0).float()


In [None]:
# LPIPS perceptual distance with the VGG backbone.
lpips_model = lpips.LPIPS(net="vgg").to(device)

def pil_to_lpips_tensor(img: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a batched tensor on `device`, rescaled from [0,1] to [-1,1]."""
    to_unit_range = transforms.ToTensor()
    rescaled = to_unit_range(img).unsqueeze(0) * 2 - 1
    return rescaled.to(device)

def mse_numpy(a, b) -> float:
    """Mean squared difference of two arrays, computed in float32.

    Casting before subtracting prevents wrap-around on unsigned integer inputs.
    """
    diff = a.astype(np.float32) - b.astype(np.float32)
    squared = diff ** 2
    return float(squared.mean())

def ssim_numpy(a, b) -> float:
    """SSIM between two arrays with pixel values on a 0-255 scale.

    When `a` is HxWx3 the last axis is passed as the channel axis; any other
    shape is handed to `ssim` without a channel axis.
    """
    is_three_channel = a.ndim == 3 and a.shape[2] == 3
    channel_kwargs = {"channel_axis": 2} if is_three_channel else {}
    return float(ssim(a, b, data_range=255, **channel_kwargs))


In [None]:
# Same values as in the analysis path cell above; repeated so this cell can be
# run on its own.
BASE_DIR = "/content/drive/MyDrive/thesis2"
CSV_PER_PATH = f"{BASE_DIR}/phase4_cf_eval_per_image.csv"


In [None]:
WIDTH, HEIGHT = 768, 768  # same as generation

@torch.no_grad()
def evaluate_counterfactuals_vs_source(df_sources, df_cf) -> pd.DataFrame:
    """
    Score every counterfactual in `df_cf` against its source image.

    For each row the source path is looked up by source_id. Both images are
    loaded as RGB and resized to WIDTH x HEIGHT, and these metrics are computed:
    classifier smile probability and thresholded label for the counterfactual,
    LPIPS, MSE, SSIM, and CLIP cosine similarity to the source.

    The resulting table is written to CSV_PER_PATH and returned.
    Raises KeyError if a source_id in `df_cf` is absent from `df_sources`.
    """
    # source_id -> path lookup
    path_by_source = {
        int(sid): path
        for sid, path in zip(df_sources["source_id"], df_sources["source_path"])
    }

    def _load_resized(path):
        return Image.open(path).convert("RGB").resize((WIDTH, HEIGHT), Image.LANCZOS)

    records = []
    progress = tqdm(df_cf.iterrows(), total=len(df_cf), desc="Evaluating CFs vs sources")
    for _, cf_row in progress:
        source_id = int(cf_row["source_id"])
        cf_path = cf_row["cf_path"]
        source_path = path_by_source[source_id]

        src_res = _load_resized(source_path)
        cf_res = _load_resized(cf_path)

        # 1) classifier decision on the counterfactual
        p_smile = classifier_prob_smile(cf_res)
        predicted_label = int(p_smile >= 0.5)

        # 2) perceptual distance
        lpips_val = float(
            lpips_model(pil_to_lpips_tensor(src_res), pil_to_lpips_tensor(cf_res)).item()
        )

        # 3) pixel-level metrics
        src_pixels = np.array(src_res)
        cf_pixels = np.array(cf_res)
        mse_val = mse_numpy(src_pixels, cf_pixels)
        ssim_val = ssim_numpy(src_pixels, cf_pixels)

        # 4) cosine similarity of the unit-norm CLIP embeddings
        src_emb = clip_image_embedding(src_res)
        cf_emb = clip_image_embedding(cf_res)
        clip_sim = float((src_emb * cf_emb).sum().item())

        records.append({
            "source_id": source_id,
            "cf_rank": cf_row["cf_rank"],
            "raw_idx": cf_row["raw_idx"],
            "seed": cf_row["seed"],
            "cf_path": cf_path,
            "clip_smile_score_gen": cf_row["clip_smile_score"],  # from generation stage
            "clf_prob_smile_cf": p_smile,
            "clf_label_cf": predicted_label,
            "lpips_to_source": lpips_val,
            "mse_to_source": mse_val,
            "ssim_to_source": ssim_val,
            "clip_sim_to_source": clip_sim,
        })

    df_eval = pd.DataFrame(records)
    df_eval.to_csv(CSV_PER_PATH, index=False)
    print("Saved per-image CF eval to:", CSV_PER_PATH)
    return df_eval

df_cf_eval = evaluate_counterfactuals_vs_source(df_sources, df_cf)
df_cf_eval.head()


In [None]:
# NOTE(review): this repeats the evaluation already run at the end of the
# previous cell, recomputing every metric and rewriting CSV_PER_PATH. Confirm
# whether this cell is still needed.
df_cf_eval = evaluate_counterfactuals_vs_source(df_sources, df_cf)


In [None]:
print("======== COUNTERFACTUAL VALIDITY ========")
# A counterfactual counts as a successful flip when the classifier labels it 1.
is_smiling = df_cf_eval["clf_label_cf"] == 1
total = len(df_cf_eval)
flipped = is_smiling.sum()
flip_rate = flipped / total

for line in (
    f"Total counterfactuals: {total}",
    f"Smiling according to classifier: {flipped}",
    f"Decision flip success rate: {flip_rate:.4f}",
):
    print(line)


In [None]:
print("\n======== PROXIMITY TO SOURCE ========")
# Descriptive statistics of the four source-proximity metrics.
proximity_columns = [
    "lpips_to_source",
    "ssim_to_source",
    "mse_to_source",
    "clip_sim_to_source",
]
prox_stats = df_cf_eval[proximity_columns].describe()

print(prox_stats)


In [None]:
import matplotlib.pyplot as plt

# Distribution of the classifier's smile probability over all counterfactuals.
fig, ax = plt.subplots()
ax.hist(df_cf_eval["clf_prob_smile_cf"], bins=25)
ax.set_title("Classifier smile probabilities")
ax.set_xlabel("P(smile)")
ax.set_ylabel("count")
plt.show()


In [None]:
import matplotlib.pyplot as plt


# Distribution of the LPIPS distance between each counterfactual and its source.
fig, ax = plt.subplots()
ax.hist(df_cf_eval["lpips_to_source"], bins=25)
ax.set_title("LPIPS distance to source")
ax.set_xlabel("LPIPS")
ax.set_ylabel("count")
plt.show()


# Saving summaries

In [None]:
# One-row table with the validity numbers computed above.
validity_record = {
    "total_cfs": total,
    "successful_flips": flipped,
    "flip_rate": flip_rate,
}
summary_validity = pd.DataFrame([validity_record])

summary_validity_path = f"{BASE_DIR}/phase4_cf_summary_validity.csv"
summary_validity.to_csv(summary_validity_path, index=False)
print("Saved:", summary_validity_path)


In [None]:
# Persist the describe() table; the index is kept because it holds the
# statistic names.
prox_stats_path = f"{BASE_DIR}/phase4_cf_summary_proximity.csv"
prox_stats.to_csv(prox_stats_path)
print("Saved:", prox_stats_path)


In [None]:
# ============================================
# PHASE IMAGE TABLE BUILDER (Drive scanner)
# ============================================
import os, re, glob
from pathlib import Path
from datetime import datetime

import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

# ----------------------------
# CONFIG: edit these paths
# ----------------------------
# NOTE: this rebinds ROOT, which an earlier cell set to the thesis2 folder.
ROOT = "/content/drive/MyDrive/thesis2/phase1_smile/images"

# Maps a phase name (written to the "phase" column) to the folder to scan.
PATHS = {
    # Synthetic dataset (you created)
    "dataset_neutral": f"{ROOT}/neutral",
    "dataset_smiling": f"{ROOT}/smiling",

    # Phase 4 sources (if you used them; else ignore)
    #"phase1_sources_neutral": f"{ROOT}/custom_sources",

    # Phase 4 counterfactuals (where you saved CF images)
    #"phase1_counterfactuals": f"{ROOT}/phase4_clip_sampling_images",


}

# All generated tables are written below this folder.
OUT_DIR = f"{ROOT}/tables"
os.makedirs(OUT_DIR, exist_ok=True)

OUT_IMAGES_TABLE_CSV = f"{OUT_DIR}/images_index_all.csv"            # one row per image
OUT_COUNTS_TABLE_CSV = f"{OUT_DIR}/images_counts_by_phase_label.csv"  # counts per phase x label
OUT_COUNTS_SRC_CSV   = f"{OUT_DIR}/images_counts_by_source.csv"       # counts per source_id

# ----------------------------
# Helpers
# ----------------------------
def safe_exists(p):
    """Return True when `p` is a non-empty path that exists on disk, else False.

    Always returns a real bool: a missing path value (None or "") yields False
    instead of being passed through unchanged.
    """
    return bool(p) and os.path.exists(p)

def parse_ids_from_filename(fname: str):
    """
    Extract the identifiers encoded in an image filename.

    Only the basename of `fname` is inspected. Two naming schemes are recognised:
    - dataset: neutral_0007_p1.png, smiling_0043_p2.png
    - counterfactuals: src000_cf01_seed98766.png or source0_cf0_seed123.png etc

    Returns a dict with the keys "label_or_cond", "index", "prompt_variant",
    "source_id", "cf_id" and "seed"; a field that cannot be parsed stays None.
    """
    base = os.path.basename(fname)

    out = {
        "label_or_cond": None,
        "index": None,
        "prompt_variant": None,
        "source_id": None,
        "cf_id": None,
        "seed": None,
    }

    # dataset pattern: label_0007_p1.png
    # .webp is accepted as well because scan_folder() also indexes "*.webp" files.
    m = re.match(
        r"^(neutral|smiling|soft_smile|big_smile)_(\d+)_p(\d+)\.(png|jpg|jpeg|webp)$",
        base,
        re.IGNORECASE,
    )
    if m:
        out["label_or_cond"] = m.group(1).lower()
        out["index"] = int(m.group(2))
        out["prompt_variant"] = int(m.group(3))
        return out

    # cf pattern: src000_cf01_seed98766.png -- each token is optional and is
    # searched for independently anywhere in the name.
    for key, pattern in (
        ("source_id", r"src(\d+)"),
        ("cf_id", r"cf(\d+)"),
        ("seed", r"seed(\d+)"),
    ):
        m = re.search(pattern, base, re.IGNORECASE)
        if m:
            out[key] = int(m.group(1))

    # If filename includes "neutral"/"smiling"; "smil" is checked last, so it
    # wins when both words are present.
    if re.search("neutral", base, re.IGNORECASE):
        out["label_or_cond"] = "neutral"
    if re.search("smil", base, re.IGNORECASE):
        out["label_or_cond"] = "smiling"

    return out


def scan_folder(phase_name: str, folder: str):
    """Index the png/jpg/jpeg/webp files directly inside `folder`.

    Returns a list of dicts, one per file in sorted path order, combining the
    given phase name, the file's location, the ids parsed from its name, its
    size in KB and its last-modified timestamp.
    """
    patterns = ("*.png", "*.jpg", "*.jpeg", "*.webp")
    files = sorted(
        path
        for pattern in patterns
        for path in glob.glob(os.path.join(folder, pattern))
    )

    def describe(fp):
        st = os.stat(fp)
        meta = parse_ids_from_filename(fp)
        modified = datetime.fromtimestamp(st.st_mtime)
        return {
            "phase": phase_name,
            "folder": folder,
            "filename": os.path.basename(fp),
            "filepath": fp,
            "label_or_cond": meta["label_or_cond"],
            "index": meta["index"],
            "prompt_variant": meta["prompt_variant"],
            "source_id": meta["source_id"],
            "cf_id": meta["cf_id"],
            "seed": meta["seed"],
            "filesize_kb": round(st.st_size / 1024, 2),
            "modified_time": modified.strftime("%Y-%m-%d %H:%M:%S"),
        }

    return [describe(fp) for fp in files]


def preview_grid(df, title, n=12, seed=0):
    """
    Plot up to `n` randomly sampled images from `df` in a grid five columns wide.

    `df` must have a "filepath" column. `seed` fixes the random sample. When
    `df` has no rows a message is printed and nothing is drawn.
    """
    if not len(df):
        print(f"[preview_grid] No images for: {title}")
        return

    n_show = min(n, len(df))
    sample = df.sample(n_show, random_state=seed).reset_index(drop=True)
    cols = 5
    rows = -(-len(sample) // cols)  # ceiling division

    plt.figure(figsize=(cols * 3, rows * 3))
    plt.suptitle(title, y=1.02)

    for position, path in enumerate(sample["filepath"], start=1):
        ax = plt.subplot(rows, cols, position)
        ax.imshow(Image.open(path).convert("RGB"))
        ax.axis("off")

    plt.tight_layout()
    plt.show()


# ----------------------------
# Build master table
# ----------------------------
# Column layout of the master table; it matches the keys produced by
# scan_folder(). Passing it explicitly keeps the columns present even when no
# image is found, so the summary code that follows does not fail with a KeyError.
IMAGE_TABLE_COLUMNS = [
    "phase",
    "folder",
    "filename",
    "filepath",
    "label_or_cond",
    "index",
    "prompt_variant",
    "source_id",
    "cf_id",
    "seed",
    "filesize_kb",
    "modified_time",
]

all_rows = []
missing = []

# Scan every configured folder that exists; remember the ones that do not.
for phase_name, folder in PATHS.items():
    if safe_exists(folder):
        all_rows.extend(scan_folder(phase_name, folder))
    else:
        missing.append((phase_name, folder))

df_images = pd.DataFrame(all_rows, columns=IMAGE_TABLE_COLUMNS)

print("Total images indexed:", len(df_images))
if missing:
    print("\n Missing folders (ignored):")
    for k, p in missing:
        print(f" - {k}: {p}")

# Save master table
df_images.to_csv(OUT_IMAGES_TABLE_CSV, index=False)
print("\nSaved:", OUT_IMAGES_TABLE_CSV)

# ----------------------------
# Summary tables
# ----------------------------
# 1) counts per phase x label/cond
# Images without a parsed label are counted under "unknown".
labelled = df_images.assign(label_or_cond=df_images["label_or_cond"].fillna("unknown"))
counts_keys = ["phase", "label_or_cond"]
df_counts = (
    labelled.groupby(counts_keys)
    .size()
    .reset_index(name="n_images")
    .sort_values(counts_keys)
)
df_counts.to_csv(OUT_COUNTS_TABLE_CSV, index=False)
print("Saved:", OUT_COUNTS_TABLE_CSV)

# 2) counts per source_id (useful for CF runs)
with_source = df_images[df_images["source_id"].notna()]
df_by_source = (
    with_source.groupby(["phase", "source_id", "label_or_cond"])
    .size()
    .reset_index(name="n_images")
    .sort_values(["phase", "source_id"])
)
df_by_source.to_csv(OUT_COUNTS_SRC_CSV, index=False)
print("Saved:", OUT_COUNTS_SRC_CSV)

display(df_images.head(10))
display(df_counts)

# ----------------------------
# Optional: visual sanity previews
# ----------------------------
# Dataset previews
#preview_grid(df_images[df_images["phase"]=="dataset_neutral"], "Phase 1 neutral (sample)", n=5, seed=1)
# Show five sampled images from the "dataset_smiling" phase.
preview_grid(df_images[df_images["phase"]=="dataset_smiling"], "Phase 2 smiling (sample)", n=5, seed=2)

# CF preview (if present)
# NOTE(review): the phase names filtered on below are not keys of PATHS as
# configured above, so these previews would be empty if re-enabled.
#preview_grid(df_images[df_images["phase"].str.contains("phase4_counterfactuals")], "Phase4 counterfactuals (sample)", n=30, seed=3)
#preview_grid(df_images[df_images["phase"].str.contains("phase4_sources_neutral")], "Phase4 counterfactuals (sample)", n=5, seed=3)


In [None]:
def show_neutral_plus_cf_folder(
    neutral_path: str,
    cf_folder: str,
    cf_glob: str = "*.png",
    cols_cf: int = 6,
    resize_to=(768, 768),
    save_path: str = None,
):
    """
    Plot one neutral image next to every counterfactual found in a folder.

    Parameters:
        neutral_path: path of the neutral image, drawn in the first grid cell.
        cf_folder: folder searched for counterfactual images.
        cf_glob: glob pattern applied inside `cf_folder`.
        cols_cf: maximum number of counterfactuals per row.
        resize_to: (width, height) every image is resized to before plotting.
        save_path: when given, the figure is also saved to this path.

    Returns the sorted list of counterfactual paths that were plotted.
    Raises AssertionError if the neutral image or folder does not exist, or if
    the folder contains no file matching `cf_glob`.
    """
    assert os.path.exists(neutral_path), f"Neutral not found: {neutral_path}"
    assert os.path.exists(cf_folder), f"CF folder not found: {cf_folder}"

    cf_paths = sorted(glob.glob(os.path.join(cf_folder, cf_glob)))
    assert len(cf_paths) > 0, f"No CFs found in {cf_folder}"

    n_cf = len(cf_paths)
    cols_cf = min(cols_cf, n_cf)
    rows = (n_cf + cols_cf - 1) // cols_cf
    grid_cols = 1 + cols_cf  # first column is reserved for the neutral image

    plt.figure(figsize=(grid_cols * 3, rows * 3))

    ax0 = plt.subplot(rows, grid_cols, 1)
    ax0.imshow(Image.open(neutral_path).convert("RGB").resize(resize_to))
    ax0.axis("off")

    for i, p in enumerate(cf_paths):
        # Place CF i at (row, 1 + col) so every row holds exactly cols_cf
        # counterfactuals and later rows do not wrap under the neutral image.
        row, col = divmod(i, cols_cf)
        ax = plt.subplot(rows, grid_cols, row * grid_cols + col + 2)
        ax.imshow(Image.open(p).convert("RGB").resize(resize_to))
        ax.axis("off")

    plt.tight_layout()

    if save_path:
        # A bare filename has no directory part; os.makedirs("") would raise.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(save_path, dpi=200, bbox_inches="tight")
        print("Saved:", save_path)

    plt.show()

    return cf_paths


In [None]:
neutral_path = "/content/drive/MyDrive/thesis2/custom_sources/smaira.png"

cf_folder = "/content/drive/MyDrive/thesis2/fethi"

cf_paths = sorted(glob.glob(os.path.join(cf_folder, "*.png")))

print("Found CF images:", len(cf_paths))
print(cf_paths[:3])  # sanity check

# Use the panel helper that this notebook defines. It takes the CF folder and
# globs "*.png" itself, which is the same set of files as `cf_paths` above.
# The previous call targeted `show_neutral_and_cfs_panel`, which is not
# defined in the visible code.
show_neutral_plus_cf_folder(
    neutral_path=neutral_path,
    cf_folder=cf_folder,
    cols_cf=6,
    save_path="/content/drive/MyDrive/thesis2/plots/smairaa.png"
)

# Pairwise comparison

In [None]:
!pip -q install lpips scikit-image transformers


In [None]:
import os, glob, itertools
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

# NOTE: this rebinds ROOT back to the thesis2 folder for the pairwise section.
ROOT = "/content/drive/MyDrive/thesis2"

# Use whichever you have (eval has more columns; metadata is enough too)
CSV_IN = f"{ROOT}/phase4_cf_eval_per_image.csv"   # or phase4_clip_sampling_metadata.csv

df = pd.read_csv(CSV_IN)
print("Loaded:", df.shape)
print(df.columns)



# Keep only what we need
df = df[["source_id", "cf_rank", "seed", "cf_path"]].copy()
df["cf_path"] = df["cf_path"].astype(str)

# sanity check: how many files exist?
# NOTE(review): missing files are only reported here, not removed from `df`;
# opening them later in the pairwise loop would fail.
exists = df["cf_path"].apply(os.path.isfile)
print("CF image files found:", int(exists.sum()), "/", len(df))
print("Missing examples:", df.loc[~exists, "cf_path"].head(5).tolist())


In [None]:
import torch
import torch.nn.functional as F
import lpips
from skimage.metrics import structural_similarity as ssim

# Here `device` is a plain string ("cuda" or "cpu"), not a torch.device object.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# LPIPS model (VGG is standard)
lpips_model = lpips.LPIPS(net="vgg").to(device).eval()

def pil_to_lpips_tensor(pil_img: Image.Image) -> torch.Tensor:
    """Turn a PIL image into the LPIPS input layout: values in [-1,1], shape [1,3,H,W], on `device`."""
    unit = np.array(pil_img).astype(np.float32) / 255.0
    chw = torch.from_numpy(unit).permute(2, 0, 1)
    batched = chw.unsqueeze(0)  # [1,3,H,W]
    return (batched * 2.0 - 1.0).to(device)

def mse_numpy(a: np.ndarray, b: np.ndarray) -> float:
    """Mean squared difference of two arrays, computed in float32.

    Casting before subtracting prevents wrap-around on unsigned integer inputs.
    """
    a32 = a.astype(np.float32)
    b32 = b.astype(np.float32)
    squared_error = (a32 - b32) ** 2
    return float(np.mean(squared_error))

def ssim_numpy(a: np.ndarray, b: np.ndarray) -> float:
    """SSIM of two HxWxC arrays with pixel values on a 0-255 scale; axis 2 is the channel axis."""
    score = ssim(a, b, channel_axis=2, data_range=255)
    return float(score)

# ---- Optional CLIP image embeddings ----
# Toggle for the CLIP-based metric; when False no CLIP model is loaded.
USE_CLIP = True

if USE_CLIP:
    from transformers import CLIPProcessor, CLIPModel
    clip_id = "openai/clip-vit-base-patch32"
    clip_model = CLIPModel.from_pretrained(clip_id).to(device).eval()
    clip_proc  = CLIPProcessor.from_pretrained(clip_id)

    @torch.no_grad()
    def clip_img_emb(pil_img: Image.Image) -> torch.Tensor:
        """Return the L2-normalized CLIP image embedding of `pil_img` as a 1-D tensor."""
        encoded = clip_proc(images=pil_img, return_tensors="pt")
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
        features = clip_model.get_image_features(**encoded)
        return F.normalize(features, dim=-1).squeeze(0)  # [D]


In [None]:
# Output CSVs: one row per image pair, and one summary row per source.
OUT_PAIRS   = f"{ROOT}/phase4_cf_pairwise_pairs.csv"
OUT_SUMMARY = f"{ROOT}/phase4_cf_pairwise_summary.csv"

# for consistent resizing (match your generation)
WIDTH, HEIGHT = 768, 768

def load_img(path: str) -> Image.Image:
    """Open `path` as RGB and resize it to WIDTH x HEIGHT with Lanczos resampling."""
    rgb = Image.open(path).convert("RGB")
    return rgb.resize((WIDTH, HEIGHT), Image.LANCZOS)

# Column layout of the pair table. It is passed explicitly so the table keeps
# its columns even when no source has two or more counterfactuals.
PAIR_COLUMNS = [
    "source_id",
    "cf_rank_i",
    "cf_rank_j",
    "cf_path_i",
    "cf_path_j",
    "lpips",
    "ssim",
    "mse",
    "clip_cos",
    "clip_dist",
]

rows = []

for src_id, g in df.groupby("source_id"):
    g = g.sort_values("cf_rank").reset_index(drop=True)

    paths = g["cf_path"].tolist()
    ranks = g["cf_rank"].tolist()

    # Load and convert every image once per source. The pair loop below reuses
    # these instead of rebuilding arrays and tensors for each pair.
    imgs = [load_img(p) for p in paths]
    arrays = [np.array(im) for im in imgs]
    lpips_inputs = [pil_to_lpips_tensor(im) for im in imgs]

    # precompute clip embeddings once
    if USE_CLIP:
        embs = [clip_img_emb(im) for im in imgs]

    # all unordered pairs
    for i, j in itertools.combinations(range(len(imgs)), 2):
        # LPIPS (inference only, so no autograd graph is needed)
        with torch.no_grad():
            lp = float(lpips_model(lpips_inputs[i], lpips_inputs[j]).item())

        # SSIM / MSE
        mse_val = mse_numpy(arrays[i], arrays[j])
        ssim_val = ssim_numpy(arrays[i], arrays[j])

        # CLIP distance = 1 - cosine similarity of the normalized embeddings
        if USE_CLIP:
            cos = float((embs[i] * embs[j]).sum().item())
            clip_dist = float(1.0 - cos)
        else:
            cos, clip_dist = np.nan, np.nan

        rows.append({
            "source_id": int(src_id),
            "cf_rank_i": int(ranks[i]),
            "cf_rank_j": int(ranks[j]),
            "cf_path_i": paths[i],
            "cf_path_j": paths[j],
            "lpips": lp,
            "ssim": ssim_val,
            "mse": mse_val,
            "clip_cos": cos,
            "clip_dist": clip_dist,
        })

df_pairs = pd.DataFrame(rows, columns=PAIR_COLUMNS)
df_pairs.to_csv(OUT_PAIRS, index=False)
print("Saved pairs:", OUT_PAIRS, " shape:", df_pairs.shape)

# summary per source
agg_cols = ["lpips", "ssim", "mse"] + (["clip_dist"] if USE_CLIP else [])
agg_stats = ["mean", "std", "count"]
if df_pairs.empty:
    # No pairs to aggregate: emit an empty table with the usual column names.
    df_summary = pd.DataFrame(
        columns=["source_id"] + [f"{c}_{s}" for c in agg_cols for s in agg_stats]
    )
else:
    df_summary = (
        df_pairs
        .groupby("source_id")[agg_cols]
        .agg(agg_stats)
        .reset_index()
    )
    # Flatten the (metric, statistic) column MultiIndex into "metric_statistic".
    df_summary.columns = ["_".join([c for c in col if c]).strip("_") for col in df_summary.columns.values]
df_summary.to_csv(OUT_SUMMARY, index=False)
print("Saved summary:", OUT_SUMMARY, " shape:", df_summary.shape)

display(df_summary)


# baseline

In [None]:
# ============================================================
# Standard DDPM-style counterfactual baseline (no CLIP ranking)
# SD3.5 Img2Img: generate K smile candidates per source by seed variation
# Evaluate: validity (classifier flip), proximity-to-source, within-set diversity
# ============================================================

import os
import math
import random
from pathlib import Path
from typing import List, Dict, Tuple

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn.functional as F


from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image

# ---------------------------
# 0) Set paths + device
# ---------------------------
# `device` is a plain string here; half precision is used only on the GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32



MODEL_ID = "stabilityai/stable-diffusion-3.5-medium"


# Folder holding the source images, and the output folder for this baseline.
SOURCES_DIR = "/content/drive/MyDrive/thesis2/custom_sources"
OUT_DIR     = "/content/drive/MyDrive/thesis2/baseline_ddpm_cf"
os.makedirs(OUT_DIR, exist_ok=True)

# ---------------------------
# 1) Load SD3 Img2Img pipeline
# ---------------------------
# NOTE(review): unlike the CLIP-sampling section, no `token=` is passed here, so
# the download relies on credentials already stored in the environment.
# NOTE(review): variant="fp16" asks for fp16-variant weight files; confirm that
# the model repository actually provides them.
pipe = AutoPipelineForImage2Image.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    variant="fp16" if dtype == torch.float16 else None,
)
pipe = pipe.to(device)



# ---------------------------
# 2) Prompts + generation params
# ---------------------------
# Positive prompt describing the target smiling face.
PROMPT_SMILE = (
    "a photorealistic human face, studio lighting, high resolution, realistic anatomy, "
    "smiling, natural smile, subtle teeth visibility, joyful expression"
)

# Negative prompt listing artifacts to steer away from.
NEG_PROMPT = (
    "cartoon, illustration, deformed face, disfigured, extra limbs, bad anatomy, "
    "weird teeth, exaggerated grin, low quality, blurry"
)

# Img2img knobs
# NOTE: STRENGTH rebinds the value (0.55) set in the CLIP-sampling section.
CFG_SCALE   = 7.5
STRENGTH    = 0.60        # how strong the edit is
STEPS       = 30
K           = 6           # number of CFs per source (like your Phase 4)
SEED0       = 1000        # base seed (repeatability)

# ---------------------------
# 3) Utility: list source images
# ---------------------------
def list_images(folder: str) -> List[str]:
    """
    Return the sorted paths of all image files directly inside `folder`.

    Only regular files whose suffix (case-insensitive) is .png, .jpg, .jpeg
    or .webp are returned; sub-directories are skipped even if their name
    ends in one of those suffixes. The listing is not recursive.

    Raises whatever `Path.iterdir()` raises when `folder` cannot be listed
    (e.g. FileNotFoundError for a missing folder).
    """
    exts = {".png", ".jpg", ".jpeg", ".webp"}
    return [
        str(p)
        for p in sorted(Path(folder).iterdir())
        # Suffix test first: it is cheap and avoids a stat() on unrelated entries.
        if p.suffix.lower() in exts and p.is_file()
    ]

# Collect the source image paths once; the main loop further down iterates over this list.
source_paths = list_images(SOURCES_DIR)
print(f"Found {len(source_paths)} neutral sources in {SOURCES_DIR}")

# ---------------------------
# 4) Plug in smile classifier (must return prob_smile in [0,1])
#
# ---------------------------
@torch.no_grad()
def predict_smile_prob(pil_img: Image.Image) -> float:
    """
    Placeholder for the smile classifier.

    Intended contract (per the section header): return the probability of
    "smile" for `pil_img` as a float in [0, 1]. Until a real classifier is
    plugged in, every call raises NotImplementedError.
    """
    reason = "Replace predict_smile_prob() with your classifier inference."
    raise NotImplementedError(reason)


# ---------------------------
# 5) Metrics: MSE, SSIM (simple), LPIPS (optional)
# ---------------------------
def pil_to_torch(img: Image.Image, size: Tuple[int,int]=None) -> torch.Tensor:
    """
    Convert a PIL image to a float32 tensor of shape 1x3xHxW with values in [0, 1].

    img:  input image in any PIL mode; it is converted to RGB first so that
          grayscale, palette and alpha-carrying images all yield exactly
          three channels.
    size: optional target size forwarded to `Image.resize` (PIL expects
          (width, height)); None keeps the original resolution.
    """
    # Normalise the mode before resizing: without this, RGBA/LA inputs would
    # produce 4/2 channels and palette images would be read as raw indices.
    img = img.convert("RGB")
    if size is not None:
        img = img.resize(size, Image.BICUBIC)
    arr = np.asarray(img).astype(np.float32) / 255.0  # HxWx3 in [0, 1]
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)  # 1x3xHxW

@torch.no_grad()
def mse(a: Image.Image, b: Image.Image) -> float:
    """
    Mean squared error between two images.

    Both images are resized to 512x512 and converted to tensors by
    `pil_to_torch` before comparison, so inputs of different resolution are
    comparable. Returns a plain Python float.
    """
    tensor_a, tensor_b = (pil_to_torch(im, size=(512, 512)) for im in (a, b))
    return F.mse_loss(tensor_a, tensor_b).item()


@torch.no_grad()
def ssim_simple(a: Image.Image, b: Image.Image) -> float:
    """
    Quick global SSIM-like score between two images.

    Both images are resized to 512x512 and reduced to a single channel by
    averaging over the channel dimension. Means, variances and the covariance
    are then computed over the whole image (no sliding window), and the
    resulting score is clamped to [-1, 1].
    """
    target_size = (512, 512)

    # Channel-mean as a cheap luminance proxy.
    gray_a = pil_to_torch(a, size=target_size).mean(dim=1, keepdim=True)
    gray_b = pil_to_torch(b, size=target_size).mean(dim=1, keepdim=True)

    mean_a, mean_b = gray_a.mean(), gray_b.mean()
    var_a = gray_a.var(unbiased=False)
    var_b = gray_b.var(unbiased=False)
    cov_ab = ((gray_a - mean_a) * (gray_b - mean_b)).mean()

    # Small constants that keep the ratio defined when the statistics are near zero.
    c1 = 0.01**2
    c2 = 0.03**2

    numerator = (2*mean_a*mean_b + c1) * (2*cov_ab + c2)
    denominator = (mean_a**2 + mean_b**2 + c1) * (var_a + var_b + c2)
    score = numerator / denominator
    return float(score.clamp(-1, 1).item())


@torch.no_grad()
def lpips_stub(a: Image.Image, b: Image.Image) -> float:
    """
    Placeholder for an LPIPS distance.

    Ignores both images and always returns NaN, so downstream code keeps an
    LPIPS column without a real perceptual metric being computed.
    """
    return math.nan

# ---------------------------
# 6) Generate baseline CFs for one source (seed diversity only)
# ---------------------------
@torch.no_grad()
def generate_baseline_cfs_for_source(
    source_img: Image.Image,
    source_id: str,
    k: int = 6,
    seed0: int = 0,
) -> List[Dict]:
    """
    Generate `k` counterfactual images for one source, varying only the seed.

    Image number i uses seed `seed0 + i`. Each result is written to OUT_DIR
    as "<source_id>_cf_<i>_seed<seed>.png".

    Returns one dict per generated image with the keys
    "source_id", "cf_idx", "seed" and "cf_path".
    """
    records: List[Dict] = []
    for cf_idx, seed in enumerate(range(seed0, seed0 + k)):
        # A dedicated, explicitly seeded generator per image.
        rng = torch.Generator(device=device)
        rng.manual_seed(seed)

        result = pipe(
            image=source_img,
            prompt=PROMPT_SMILE,
            negative_prompt=NEG_PROMPT,
            num_inference_steps=STEPS,
            guidance_scale=CFG_SCALE,
            strength=STRENGTH,
            generator=rng,
        )
        cf_image = result.images[0]

        file_name = f"{source_id}_cf_{cf_idx:02d}_seed{seed}.png"
        target_path = os.path.join(OUT_DIR, file_name)
        cf_image.save(target_path)

        records.append({
            "source_id": source_id,
            "cf_idx": cf_idx,
            "seed": seed,
            "cf_path": target_path,
        })
    return records

# ---------------------------
# 7) Evaluate proximity + validity for baseline CFs
# ---------------------------
@torch.no_grad()
def evaluate_source_to_cf(source_img: Image.Image, cf_img: Image.Image) -> Dict:
    """
    Proximity metrics between a source image and one of its counterfactuals.

    Returns a dict with "mse_to_source", "ssim_to_source" and
    "lpips_to_source", each computed as metric(source_img, cf_img).
    """
    metric_fns = (
        ("mse_to_source", mse),
        ("ssim_to_source", ssim_simple),
        ("lpips_to_source", lpips_stub),
        # plus classifier validity, once a classifier is available:
        # "prob_smile" -> predict_smile_prob(cf_img)
    )
    return {key: fn(source_img, cf_img) for key, fn in metric_fns}

# ---------------------------
# 8) Evaluate within-set diversity among CFs of same source
# ---------------------------
@torch.no_grad()
def evaluate_within_set(cf_imgs: List[Image.Image]) -> Dict:
    """
    Mean pairwise metrics over all unordered CF-CF pairs within a set.

    cf_imgs: counterfactual images generated from the same source.

    Returns an empty dict when fewer than two images are given (no pairs).
    Otherwise returns a dict with:
      "K"                 - number of images in the set
      "n_pairs"           - number of unordered pairs evaluated
      "within_mse_mean"   - mean pairwise `mse`
      "within_ssim_mean"  - mean pairwise `ssim_simple`
      "within_lpips_mean" - NaN-ignoring mean of pairwise `lpips_stub`
                            (NaN when every pair is NaN)
    """
    n_cfs = len(cf_imgs)  # local name so the module-level K is not shadowed
    pairs = [(i, j) for i in range(n_cfs) for j in range(i + 1, n_cfs)]
    if not pairs:
        return {}

    lp_list, ssim_list, mse_list = [], [], []
    for i, j in pairs:
        a, b = cf_imgs[i], cf_imgs[j]
        mse_list.append(mse(a, b))
        ssim_list.append(ssim_simple(a, b))
        lp_list.append(lpips_stub(a, b))

    # np.nanmean raises RuntimeWarning("Mean of empty slice") when every value
    # is NaN, which is the case on every call while the LPIPS placeholder is in
    # use. Report NaN directly in that situation; otherwise keep the
    # NaN-ignoring mean.
    has_lpips = any(not math.isnan(v) for v in lp_list)
    lpips_mean = float(np.nanmean(lp_list)) if has_lpips else float("nan")

    return {
        "K": n_cfs,
        "n_pairs": len(pairs),
        "within_mse_mean": float(np.mean(mse_list)),
        "within_ssim_mean": float(np.mean(ssim_list)),
        "within_lpips_mean": lpips_mean,
    }

# ---------------------------
# 9) Main loop: generate + evaluate baseline
# ---------------------------
# Accumulators: one row per generated CF, one proximity row per CF, and one
# diversity row per source.
all_cf_rows, prox_rows, within_rows = [], [], []

for s_idx, sp in enumerate(source_paths):
    source_id = Path(sp).stem
    src = Image.open(sp).convert("RGB")

    # Generate baseline K counterfactuals (seed diversity only).
    # Seeds are offset by 100 per source to keep them separated across sources.
    cf_rows = generate_baseline_cfs_for_source(
        src,
        source_id,
        k=K,
        seed0=SEED0 + 100*s_idx,
    )
    all_cf_rows += cf_rows

    # Reload the saved CF images from disk for evaluation.
    cf_imgs = [Image.open(row["cf_path"]).convert("RGB") for row in cf_rows]

    # Proximity-to-source: generation metadata merged with source-vs-CF metrics.
    prox_rows.extend(
        {**row, **evaluate_source_to_cf(src, cf_img)}
        for row, cf_img in zip(cf_rows, cf_imgs)
    )

    # Within-set diversity: CF-vs-CF metrics for this source.
    within_rows.append({"source_id": source_id, **evaluate_within_set(cf_imgs)})

# Save outputs
index_csv  = os.path.join(OUT_DIR, "baseline_ddpm_cf_index.csv")
prox_csv   = os.path.join(OUT_DIR, "baseline_ddpm_proximity.csv")
within_csv = os.path.join(OUT_DIR, "baseline_ddpm_within_set_diversity.csv")

df_index  = pd.DataFrame(all_cf_rows)
df_prox   = pd.DataFrame(prox_rows)
df_within = pd.DataFrame(within_rows)

for frame, csv_path in ((df_index, index_csv), (df_prox, prox_csv), (df_within, within_csv)):
    frame.to_csv(csv_path, index=False)

print("Saved:")
for csv_path in (index_csv, prox_csv, within_csv):
    print(" -", csv_path)


