# From Erasure to Transplant: Probing Embedding Semantics
## Stress-Testing of Text-to-Image Models via Semantic Surgery

**Course:** Advanced Machine Learning and Computer Vision (2025)

**Framework:** Based on *Semantic Surgery* (NeurIPS 2025)

**Students:** Lorenzo Musso (*2049518*) - Giulia Pietrangeli (*2057291*)

---

### Abstract
This project repurposes concept erasure as a diagnostic tool for Text-to-Image (T2I) models. Moving beyond standard safety-oriented concept erasure, we use the **Semantic Surgery** framework as a probe to perform **Semantic Transplantation**. By injecting semantic shift vectors with **Token-Wise Precision**, we stress-test **Stable Diffusion v1.4** to measure its robustness, interpretability, and latent semantic entanglement. We also introduce the **'Surgery Autopilot'**, which compares **Classical Machine Learning** and **Deep Learning** approaches to **automate the optimization of surgical hyperparameters**, linking diagnostic analysis with practical control.

### Experimental Framework

1.  **Methodology:** We extend the original vector subtraction mechanism to **Vector Injection**:
    $$c^* = c_{in} + \lambda \cdot \mathbf{M}_{\alpha} \cdot (v_{new} - v_{old})$$
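    where $\lambda$ is the intervention **Force**, $\mathbf{M}_{\alpha}$ is the **Sensitivity Mask**, $c_{in}$ is the input prompt embedding, and $v_{old}$ and $v_{new}$ are the embeddings of the removed and injected concepts.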

2. **Investigation I (Validation):** Object & Context Swapping (with Manual Fine-Tuning) to evaluate **Semantic Alignment** and **Multi-Scale Fidelity**:
   * **Locality & Pose Preservation** (Owl-ViT & IoU)
   * **Semantic Integrity** (CLIP & ResNet50)
   * **Structural Fidelity** (SSIM & LPIPS)
   * **Spectral Analysis** (FFT)

3.  **Investigation II (Ablation):** Systematic Hyperparameter Grid Search to empirically identify the **Optimal Trade-off** between editing efficacy ($\lambda$) and background preservation ($\alpha$).

4. **Investigation III (Stress-Test):** Adversarial generation on a **Custom Benchmark Dataset** (configured with manually selected parameters) to quantify latent model failures:
   * **Contextual Bias:** Testing robustness against out-of-distribution (OOD) environments (e.g., *Boat in Desert*).
   * **Attribute Entanglement:** Measuring Visual Leakage (Colour/Texture) between concepts.
   * **Societal Bias:** Analyzing implicit Gender Shifts in occupational swaps.

5. **Investigation IV (Automation Phase):** The "Surgery Autopilot".
      * **Data Generation:** An exhaustive Grid Search creates a ground-truth dataset mapping prompts to optimal parameters.
      * **Feature Extraction:** Prompts are transformed into high-dimensional **CLIP Embeddings**.
      * **Model Competition:** We train and compare two architectures:
        * **Baseline:** Multi-Output Random Forest (ML).
        * **SurgeryNet:** A custom Multi-Layer Perceptron (DL) with regularization.
      * **Evaluation:** Head-to-head comparison using **MAE** to determine the superior approach.

6. **Live Demonstration:** An interactive **Gradio Interface** allowing real-time comparison between **Manual Control**, **ML Prediction**, and **DL Prediction**.
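
As a minimal sketch of the Vector Injection rule from item 1, the snippet below applies $c^* = c_{in} + \lambda \cdot \mathbf{M}_{\alpha} \cdot (v_{new} - v_{old})$ token by token with NumPy. The function name `inject_shift`, the toy shapes, and the binary per-token mask are assumptions made for illustration; they are not taken from the project's `src` modules, which may build the mask differently.

```python
import numpy as np

def inject_shift(c_in, v_old, v_new, mask, lam=1.0):
    """Apply c* = c_in + lam * M * (v_new - v_old) token by token.

    c_in         : (n_tokens, dim) prompt embedding
    v_old, v_new : (dim,) concept vectors
    mask         : (n_tokens,) per-token weights in [0, 1]
                   (hypothetical stand-in for the Sensitivity Mask)
    """
    shift = v_new - v_old                        # semantic shift vector
    return c_in + lam * mask[:, None] * shift    # broadcast the shift over tokens

# Toy example: 4 tokens with 3-dimensional embeddings (real embeddings are larger).
c_in = np.zeros((4, 3))
v_old = np.array([1.0, 0.0, 0.0])
v_new = np.array([0.0, 1.0, 0.0])
mask = np.array([0.0, 1.0, 1.0, 0.0])            # only tokens 1 and 2 are edited

c_star = inject_shift(c_in, v_old, v_new, mask, lam=1.3)
print(c_star)
```

Tokens with a zero mask keep their original embedding, which is the behaviour the abstract describes as token-wise precision.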

# 1. Environment Setup & Initialization

This section establishes the technical foundation for the project. It performs three critical tasks to prepare the environment for *Semantic Surgery*:

1.  **Dependency Management**: Loads essential libraries (`torch`, `numpy`, `PIL`) and the evaluation architectures used throughout the pipeline (**ResNet50** and **CLIP** for semantic integrity, **LPIPS/SSIM** for structural fidelity).
2.  **Reproducibility & Hardware Acceleration**:
    * `seed_everything(42)`: Fixes the random seed across PyTorch, NumPy, and Python so that diffusion generation is repeatable on the same hardware.
    * **Device Selection**: Detects hardware acceleration automatically, checking for Apple Silicon (`mps`) first, then NVIDIA GPUs (`cuda`), and falling back to the CPU.
3.  **Model Loading (The "Surgeon")**:
    * Initializes the `Evaluator` class for real-time metrics.
    * **StableDiffuser**: Loads the core diffusion model with the **DDIM Scheduler**.
    * **Default Hyperparameters**:
        * `beta (-0.12)`: **Sensitivity** (Controls attention mask strictness).
        * `lambda (1.0)`: **Force** (Controls semantic vector injection intensity).

In [1]:
%load_ext autoreload
%autoreload 2

# --- 1. SYSTEM & UTILITIES IMPORTS ---
import os
import sys
import time
import gc
import random
import pickle
import itertools
import json
from pathlib import Path

# --- 2. DATA SCIENCE & VIZ IMPORTS ---
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns
from PIL import Image, ImageDraw, ImageFont
import cv2
from IPython.display import clear_output

# --- 3. PYTORCH & METRICS IMPORTS ---
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import lpips
from skimage.metrics import structural_similarity as ssim

# --- 4. MODEL ARCHITECTURES IMPORTS ---
from torchvision.models import resnet50, ResNet50_Weights
import torchvision.transforms as transforms
from transformers import (
    ViTForImageClassification, ViTImageProcessor,
    BlipProcessor, BlipForConditionalGeneration,
    CLIPProcessor, CLIPModel,
    OwlViTProcessor, OwlViTForObjectDetection,
    CLIPSegProcessor, CLIPSegForImageSegmentation
)

# --- 5. SKLEARN & GRADIO IMPORTS ---
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.multioutput import MultiOutputRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error
import gradio as gr

# --- REPRODUCIBILITY SETUP ---
def seed_everything(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    print(f"Global Seed set to: {seed}")

seed_everything(42)

# --- HARDWARE DETECTION ---
if torch.backends.mps.is_available():
    device = "mps"
    print(f"Hardware: Apple Silicon detected! Using device: {device} (GPU)")
elif torch.cuda.is_available():
    device = "cuda"
    print(f"Hardware: NVIDIA GPU detected! Using device: {device}")
else:
    device = "cpu"
    print("Hardware: Warning! Using CPU (Inference will be slow).")

current_dir = os.path.abspath(".")
if current_dir not in sys.path:
    sys.path.append(current_dir)

print(f"Working Directory: {current_dir}")

# --- MODULE IMPORTS ---
import src.utils as utils
import src.evaluation as evaluation

# Report success only after the local imports have actually run
print("Local 'src' modules imported successfully.")

# --- MODEL INITIALIZATION ---
if 'evaluator' not in locals():
    evaluator = evaluation.Evaluator(device)

print("Initialising StableDiffuser Model...")

init_params = {
    "gamma": 0.02, 
    "beta": -0.12,     # Base Sensitivity 
    "lambda": 1.0,     # Base Force 
    "alpha_threshold": 0.2, 
    "detect_threshold": 0.5,
    "alpha_f": None, 
    "erase_index_f": None
}

try:
    surgeon = utils.StableDiffuser(
        scheduler="DDIM",
        cache_dir="./model_cache", 
        concepts_to_erase=["placeholder"], 
        params=init_params
    )
    surgeon = surgeon.to(device) 
    print(f"Model successfully loaded on {device}.")

except Exception as e:
    print(f"Initialisation Error: {e}")

Global Seed set to: 42
Hardware: Apple Silicon detected! Using device: mps (GPU)
Working Directory: /Users/giulia/Documents/Università/Advanced Machine Learning/Final Project/Semantic Surgery
Local 'src' modules imported successfully.
Initializing Evaluator (Lazy Loading Mode) on mps...
Initialising StableDiffuser Model...
Loading VAE...
Loading tokenizer and text encoder...
Loading UNet model...
Loading feature extractor and safety checker...




Setting up scheduler...
All components loaded successfully.
Model successfully loaded on mps.


# 2. Investigation I: Method Validation (Proof of Concept)

## Part A: Object Swapping Pipeline

This section executes the **Data Production Phase** for the primary validation task, Object Swapping. The goal is to generate a visual dataset (Original vs. Swapped) for quantitative analysis of whether the method can replace an object while preserving the surrounding context.

### Methodology
1.  **Scenario Definition**: We define a set of test cases ranging from natural subjects (*Bear → Tiger*) to inanimate objects (*Sportscar → Firetruck*).
2.  **Hyperparameter Selection**: Each scenario uses **custom hyperparameters** (`Force`, `Sensitivity`) derived from preliminary tuning to ensure optimal generation quality for the validation step.
3.  **Deterministic Generation**: The pipeline runs with a fixed seed (`42`) to guarantee that the "Original" and "Modified" images share the exact same initial noise latent, ensuring that any visual difference is solely due to the semantic intervention.
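
The role of the fixed seed in step 3 can be illustrated with a small sketch. It uses NumPy's generator as a stand-in for the `torch.Generator` used in the notebook, and the helper `initial_noise` and the latent shape `(1, 4, 64, 64)` are assumptions chosen for illustration rather than values read from the pipeline.

```python
import numpy as np

def initial_noise(seed, shape=(1, 4, 64, 64)):
    """Draw Gaussian noise from a freshly seeded generator (NumPy stand-in)."""
    rng = np.random.default_rng(seed)
    return rng.standard_normal(shape)

noise_original = initial_noise(42)
noise_swap = initial_noise(42)    # re-seeded for the second variant
noise_other = initial_noise(43)   # a different seed, for contrast

print(np.array_equal(noise_original, noise_swap))   # True
print(np.array_equal(noise_original, noise_other))  # False
```

Because both variants start from identical noise, differences between the two images can be attributed to the semantic intervention rather than to sampling variation.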

### Outputs
* **Visual Dataset**: Images are saved to `results/subject_swap_results/`.
* **Preliminary Inspection**: A 3x2 grid is generated for each case, showing the Original and Swap images in the top row and **Grad-CAM** heatmaps for the top-2 predicted classes below them, to check whether the classifier's focus has shifted to the new semantic concept.

In [2]:
base_dir = "results/subject_swap_results"
os.makedirs(base_dir, exist_ok=True)
print(f"Saving validation results to: {base_dir}/")

# Define Validation Scenarios
swap_scenarios = [
    # ANIMALS (Texture & Shape changes)
    {"name": "Swap 1: Forest", "base_prompt": "A brown bear walking in a forest", "old_obj": "bear", "new_obj": "tiger", "remove": "brown bear", "replace": "tiger", "force": 1.3, "sens": 0.15},
    {"name": "Swap 2: Sofa", "base_prompt": "A dog sitting on a sofa", "old_obj": "dog", "new_obj": "cat", "remove": "dog", "replace": "cat", "force": 1.0, "sens": 0.2},
    {"name": "Swap 3: Fish", "base_prompt": "A goldfish swimming in the sea", "old_obj": "goldfish", "new_obj": "shark", "remove": "goldfish", "replace": "shark", "force": 1.0, "sens": 0.2},
    # OBJECTS (Rigid structures)
    {"name": "Swap 4: Street", "base_prompt": "A red sportscar driving on a asphalt road", "old_obj": "sportscar", "new_obj": "firetruck", "remove": "red sportscar", "replace": "firetruck", "force": 0.8, "sens": 0.25},
    {"name": "Swap 5: Office", "base_prompt": "A coffee mug on a wooden office desk", "old_obj": "mug", "new_obj": "beer", "remove": "coffee mug", "replace": "beer bottle", "force": 1.0, "sens": 0.3},
    {"name": "Swap 6: Table", "base_prompt": "A red apple on a wooden table", "old_obj": "apple", "new_obj": "daisy", "remove": "red apple", "replace": "daisy flower", "force": 1.0, "sens": 0.3}
]

json_swap_list = []

if 'surgeon' in locals():
    print("Starting Object Swap Validation...")
    surgeon = surgeon.to(device)
    
    for i, scen in enumerate(swap_scenarios):
        gc.collect()
        if torch.backends.mps.is_available(): torch.mps.empty_cache()
        
        print(f"\nTesting Scenario [{i+1}/{len(swap_scenarios)}]: {scen['name']}")
        
        # Apply Scenario Params
        surgeon.params['lambda'] = scen['force']
        surgeon.params['alpha_threshold'] = scen['sens']
        
        safe_name = scen['name'].replace(" ", "_").replace(":", "")
        scenario_dir = os.path.join(base_dir, safe_name)
        os.makedirs(scenario_dir, exist_ok=True)
        file_paths = {}
        
        full_prompt = f"Photography, 8k, {scen['base_prompt']}"
        
        def process_step(variant_name, replace_txt, erase_txt):
            gen_step = torch.Generator("cpu").manual_seed(42)
            
            surgeon.concepts_to_erase = [erase_txt] if erase_txt else []
            
            print(f"   > Generating {variant_name}...")
            imgs = surgeon([full_prompt], img_size=512, n_steps=30, n_imgs=1, 
                           show_alpha=False, generator=gen_step, replace_with=replace_txt)
            img = imgs[0][0]
            
            path = os.path.join(scenario_dir, f"{variant_name}.png")
            img.save(path)
            file_paths[variant_name] = path

            # Compute Quick Metrics (Top-2 Class + GradCAM)
            top2 = evaluator.get_top2_verdict(img)
            cam1 = evaluator.compute_gradcam(img, top2[0]['id']) 
            cam2 = evaluator.compute_gradcam(img, top2[1]['id']) 
            
            return {"img": img, "top2": top2, "cam1": cam1, "cam2": cam2, "title": variant_name}
    
        # Execute Steps
        res_orig = process_step("Original", None, None)
        res_swap = process_step("Swap", scen['replace'], scen['remove'])
        
        # Log Data
        json_swap_list.append({
            "scenario": scen['name'],
            "old_obj": scen['old_obj'], "new_obj": scen['new_obj'],
            "paths": file_paths,
            "metrics": {
                "orig_top1": res_orig['top2'][0],
                "swap_top1": res_swap['top2'][0]
            }
        })
        
        # --- VISUALIZATION ---
        fig, axes = plt.subplots(3, 2, figsize=(10, 14))
        fig.suptitle(f"{scen['name']}: {scen['old_obj']} ➝ {scen['new_obj']}", fontsize=16, fontweight='bold', y=0.99)
        
        columns = [res_orig, res_swap]
        for col_idx, data in enumerate(columns):
            # Row 1: Image + Classification
            ax_img = axes[0, col_idx]
            ax_img.imshow(data['img'])
            verdict = f"1. {data['top2'][0]['name']} ({data['top2'][0]['score']:.1%})\n2. {data['top2'][1]['name']} ({data['top2'][1]['score']:.1%})"
            ax_img.set_title(f"{data['title']}\n{verdict}", fontsize=10, loc='left', 
                             bbox=dict(facecolor='white', alpha=0.8, edgecolor='gray', boxstyle='round'))
            ax_img.axis("off")
            
            # Row 2: GradCAM 1st Prediction
            ax_cam1 = axes[1, col_idx]
            ax_cam1.imshow(data['cam1'])
            ax_cam1.set_title(f"Focus: {data['top2'][0]['name']}", fontsize=10, color='blue', fontweight='bold')
            ax_cam1.axis("off")

            # Row 3: GradCAM 2nd Prediction
            ax_cam2 = axes[2, col_idx]
            ax_cam2.imshow(data['cam2'])
            ax_cam2.set_title(f"Focus: {data['top2'][1]['name']}", fontsize=10, color='red')
            ax_cam2.axis("off")
            
        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.savefig(os.path.join(base_dir, f"Analysis_{safe_name}.png"), dpi=100, bbox_inches='tight')
        plt.show()
        
        plt.close(fig) 
        del res_orig, res_swap, fig, axes 
    
    # Save Metadata for next steps
    with open(os.path.join(base_dir, "validation_metrics.json"), 'w') as f:
        json.dump(json_swap_list, f, indent=4)
    
    clear_output(wait=True)
    print("\nAll validation scenarios completed.")
else:
    print("Error: Surgeon model not loaded. Run Setup cell first.")


All validation scenarios completed.


### Spatial Consistency Analysis (Locality & Pose)

This section performs a quantitative validation of the **Locality** and **Pose Preservation** properties of the Semantic Surgery. The objective is to verify that the semantic intervention remains strictly confined to the target object's bounding box, answering the question: *"Did the new object appear in the exact same physical location as the old one?"*

#### Methodology
1.  **Automated Object Detection (Owl-ViT)**:
    * We employ an Open-Vocabulary Object Detector (**Owl-ViT**) to blindly localize the target concepts in both the *Original* and *Modified* images.
    * This allows for an unbiased extraction of bounding boxes (e.g., detecting the "Bear" in image A and the "Tiger" in image B) without human annotation.

2.  **IoU Calculation (Intersection over Union)**:
    * We calculate the **IoU Score** between the bounding box of the original object ($B_{old}$) and the new object ($B_{new}$).
    * **Metric Interpretation**:
        * **High IoU ($\approx 1.0$)**: Excellent Pose Preservation. The new object perfectly occupies the spatial footprint of the old one.
        * **Low IoU ($\approx 0.0$)**: Poor Locality. The object has shifted significantly or the model generated the new concept in a different area.

3.  **Visualization**:
    * **Qualitative**: A 3x2 grid visualizing the overlap of bounding boxes (Red = Old, Green = New).
    * **Quantitative**: A bar chart plotting the IoU scores against a "Stability Threshold" ($0.5$).
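
For reference, the IoU computation in step 2 can be sketched in a few lines. The function `box_iou` and the `(x_min, y_min, x_max, y_max)` box format are assumptions for illustration; the project's `evaluator.compute_iou` may use a different box convention.

```python
def box_iou(box_a, box_b):
    """IoU of two axis-aligned boxes given as (x_min, y_min, x_max, y_max)."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b

    # Width and height of the overlap rectangle (zero if the boxes are disjoint)
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h

    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

same = box_iou((0, 0, 10, 10), (0, 0, 10, 10))       # identical boxes
half = box_iou((0, 0, 10, 10), (5, 0, 15, 10))       # shifted by half a width
apart = box_iou((0, 0, 10, 10), (20, 20, 30, 30))    # no overlap
print(same, half, apart)
```

Note that a box shifted by half its width already falls to an IoU of one third, below the $0.5$ Stability Threshold.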

In [3]:
base_dir = "results/subject_swap_results"
json_path = os.path.join(base_dir, "validation_metrics.json")
print(f"Running bounding-box IoU analysis in: {base_dir}/")

if os.path.exists(json_path):
    with open(json_path, 'r') as f: experiments = json.load(f)
    
    results = []
    
    fig, axes = plt.subplots(3, 2, figsize=(12, 12))
    axes = axes.flatten()
    
    print("   Detecting Objects...")
    
    for idx, exp in enumerate(experiments):
        if idx >= 6: break
        
        name = exp['scenario'].split(':')[1].strip()
        p_orig = exp['paths']['Original']
        p_swap = exp['paths']['Swap']
        
        if os.path.exists(p_orig) and os.path.exists(p_swap):
            im_o = Image.open(p_orig).convert('RGB')
            im_s = Image.open(p_swap).convert('RGB')
            
            # 1. Detect Bounding Boxes
            box_old = evaluator.get_greedy_box(im_o, exp['old_obj'])
            box_new = evaluator.get_greedy_box(im_s, exp['new_obj'])
            
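            # 2. Compute IoU (defaults to 0.0 when either object is not detected,
            #    so a missed detection scores the same as zero overlap)
            iou = 0.0
            if box_old is not None and box_new is not None:
                iou = evaluator.compute_iou(box_old, box_new)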
            
            results.append({"Scenario": name, "IoU": iou})
            
            # 3. Visualize Overlap
            combined_img = evaluator.draw_box_comparison(im_o, box_old, im_s, box_new)
            
            ax = axes[idx]
            ax.imshow(combined_img)
            ax.set_title(f"{name}\nIoU: {iou:.2f}", fontsize=11, fontweight='bold')
            ax.axis('off')
            
            # Save individual check
            combined_img.save(os.path.join(base_dir, f"BoxOverlap_{name}.png"))

    # Finalize Grid Plot
    plt.suptitle("Spatial Consistency (Multi-Object Coverage)", y=0.99, fontsize=14)
    plt.tight_layout()
    plt.savefig(os.path.join(base_dir, "ObjectSwap_Box_Grid.png"), dpi=150)
    plt.show()
    
    # Generate Quantitative Bar Chart
    if results:
        df = pd.DataFrame(results)
        plt.figure(figsize=(8, 5))
        sns.barplot(data=df, x="Scenario", y="IoU", palette="viridis")
        plt.title("Spatial Precision Score")
        plt.ylabel("IoU Score")
        plt.ylim(0, 1.0)
        plt.axhline(0.5, color='red', ls='--', label="Stability Threshold")
        plt.legend()  # without this call the threshold label is never drawn
        plt.grid(axis='y', alpha=0.3)
        plt.tight_layout()
        plt.savefig(os.path.join(base_dir, "ObjectSwap_BoxIoU_Chart.png"), dpi=150)
        plt.show()

    clear_output(wait=True)
    print("Owl-ViT Analysis Completed.")
else:
    print("JSON file not found.")

Owl-ViT Analysis Completed.


### Quantitative Analysis – Multi-Scale Fidelity (Object Swap)

This cell runs the full evaluation of the Object Swapping experiment. It aggregates data from the generated images and applies **four analytical modules** to validate the *Semantic Surgery* performance across different scales of perception.

#### Key Analytical Modules:

1.  **Semantic Integrity Analysis (The "Meaning" Check)**:
    * **ResNet-50 Confidence**: Evaluates the classification probability for both the *Old Object* (Ghost Score) and the *New Object* (Success Score).
        * *Goal:* High Success Score ($\approx 1.0$), Low Ghost Score ($\approx 0.0$).
    * **CLIP Score**: Measures the semantic alignment between the visual content and the text prompts (e.g., alignment with "a photo of a tiger" vs "a photo of a bear").

2.  **Structural Fidelity Analysis (The "Quality" Check)**:
    * **SSIM (Structural Similarity)**: Compares luminance, contrast, and local structure between the two images. (*Higher is better*).
    * **LPIPS (Perceptual Similarity)**: Measures the distance between deep network features of the two images, which tracks human perceptual judgments more closely than pixel-wise metrics. (*Lower is better*).

3.  **Spectral Analysis (FFT)**:
    * Performs **Fast Fourier Transform** to visualize the frequency domain.
    * Calculates the **Spectral Energy Difference** to ensure the editing process didn't introduce high-frequency artifacts (noise/checkerboard patterns) invisible to the naked eye.

4.  **Automated Captioning (BLIP)**:
    * Uses the **BLIP** model to generate a neutral, AI-written caption of the final result. This serves as a "blind test" to verify if an external AI naturally describes the new object without being prompted.
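
The Spectral Energy Difference from module 3 can be sketched with NumPy as follows. The helper names and the choice of a centred log-magnitude spectrum are assumptions; the project's `evaluator.get_fft_spectrum` may normalise the spectrum differently. The synthetic checkerboard stands in for the kind of high-frequency artifact the test is meant to catch.

```python
import numpy as np

def log_spectrum(gray):
    """Centred log-magnitude spectrum of a 2-D grayscale array."""
    freq = np.fft.fftshift(np.fft.fft2(gray))
    return np.log1p(np.abs(freq))

def spectral_energy_diff(gray_a, gray_b):
    """Mean absolute difference between the two log-magnitude spectra."""
    return float(np.mean(np.abs(log_spectrum(gray_b) - log_spectrum(gray_a))))

rng = np.random.default_rng(0)
base = rng.random((64, 64))                      # synthetic "original" image

# Add a faint checkerboard pattern as a synthetic high-frequency artifact
checker = np.indices((64, 64)).sum(axis=0) % 2
edited = base + 0.2 * checker

energy_same = spectral_energy_diff(base, base)
energy_edit = spectral_energy_diff(base, edited)
print(energy_same, energy_edit)
```

Identical images give an energy of exactly zero, so any positive value reflects a change in frequency content between the two images.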

In [4]:
base_dir = "results/subject_swap_results"
swap_json = os.path.join(base_dir, "validation_metrics.json") 
print(f"Starting Quantitative Analysis in: {base_dir}")

if os.path.exists(swap_json):
    with open(swap_json, 'r') as f: experiments = json.load(f)
    
    swap_rows = []
    print("   Processing metrics (ResNet, CLIP, SSIM, LPIPS)...")

    for exp in experiments:
        name = exp['scenario'].split(':')[1].strip()
        old_obj = exp['old_obj']
        new_obj = exp['new_obj']
        
        path_orig = exp['paths']['Original']
        path_swap = exp['paths']['Swap']
            
        if os.path.exists(path_swap) and os.path.exists(path_orig):
            # 1. Structural Metrics (SSIM, LPIPS)
            s, m, l = evaluator.get_structural_metrics(path_orig, path_swap)

            img_swap_pil = Image.open(path_swap).convert('RGB')
            
            # 2. Semantic Integrity (ResNet & CLIP)
            ghost_score = evaluator.get_resnet_conf(img_swap_pil, old_obj)
            success_score = evaluator.get_resnet_conf(img_swap_pil, new_obj)
            
            clip_old = evaluator.get_clip_score(img_swap_pil, f"a photo of a {old_obj}")
            clip_new = evaluator.get_clip_score(img_swap_pil, f"a photo of a {new_obj}")
            
            caption = evaluator.get_blip_caption(img_swap_pil)
            
            swap_rows.append({
                "Scenario": name, 
                "Ghost (ResNet)": ghost_score, 
                "Success (ResNet)": success_score,
                "CLIP Old": clip_old,
                "CLIP New": clip_new,
                "SSIM": s, 
                "MSE": m, 
                "LPIPS": l, 
                "Caption": caption,
                "Path_Orig": path_orig, 
                "Path_Swap": path_swap
            })
            print(f"   {name}: ResNet={success_score:.1%} | CLIP={clip_new:.1f} | SSIM={s:.2f}")

    if swap_rows:
        df_swap = pd.DataFrame(swap_rows)
        
        # --- CHART 1: SEMANTIC INTEGRITY (ResNet + CLIP) ---
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))
        
        x = np.arange(len(df_swap))
        w = 0.35
        
        # Subplot 1: ResNet 
        ax1.bar(x - w/2, df_swap['Ghost (ResNet)'], w, label='Old Object (Ghost)', color='#d62728', alpha=0.8)
        ax1.bar(x + w/2, df_swap['Success (ResNet)'], w, label='New Object (Success)', color='#2ca02c', alpha=0.8)
        ax1.set_xticks(x)
        ax1.set_xticklabels(df_swap['Scenario'], rotation=15)
        ax1.set_title('ResNet-50 Confidence (Is the object recognized?)')
        ax1.set_ylabel('Confidence Probability')
        ax1.legend()
        ax1.grid(axis='y', alpha=0.3)

        # Subplot 2: CLIP
        ax2.bar(x - w/2, df_swap['CLIP Old'], w, label='Align to Old Text', color='#ff7f0e', alpha=0.8)
        ax2.bar(x + w/2, df_swap['CLIP New'], w, label='Align to New Text', color='#1f77b4', alpha=0.8)
        ax2.set_xticks(x)
        ax2.set_xticklabels(df_swap['Scenario'], rotation=15)
        ax2.set_title('CLIP Semantic Score (Is the concept correct?)')
        ax2.set_ylabel('CLIP Logits')
        ax2.legend()
        ax2.grid(axis='y', alpha=0.3)

        plt.suptitle("Semantic Integrity Analysis", fontsize=16, fontweight='bold')
        plt.tight_layout()
        plt.savefig(os.path.join(base_dir, "Semantic_Integrity_ResNet_CLIP.png"), dpi=150)
        plt.show()
        plt.close()

        # --- CHART 2: STRUCTURAL FIDELITY (SSIM + LPIPS) ---
        fig, ax1 = plt.subplots(figsize=(10, 6))
        
        # Bar Plot SSIM 
        ax1.bar(df_swap['Scenario'], df_swap['SSIM'], color='purple', alpha=0.6, label='SSIM (Structure)')
        ax1.set_ylabel('SSIM (Higher is Better)', color='purple', fontweight='bold')
        ax1.set_ylim(0, 1.0)
        
        # Line Plot LPIPS 
        ax2 = ax1.twinx()
        ax2.plot(df_swap['Scenario'], df_swap['LPIPS'], color='orange', marker='o', lw=3, label='LPIPS (Perceptual Diff)')
        ax2.set_ylabel('LPIPS (Lower is Better)', color='orange', fontweight='bold')
        ax2.set_ylim(0, 0.8) 
        
        plt.title("Structural Fidelity Analysis: SSIM vs LPIPS")
        plt.tight_layout()
        plt.savefig(os.path.join(base_dir, "Structural_Fidelity_Metrics.png"), dpi=150)
        plt.show()
        plt.close()

        # --- CHART 3: FFT ANALYSIS ---
        fig, axes = plt.subplots(len(df_swap), 3, figsize=(12, 3.5 * len(df_swap)))
        if len(df_swap) == 1: axes = np.array([axes]) 
        
        for idx, row in df_swap.iterrows():
            s_orig = evaluator.get_fft_spectrum(row['Path_Orig'])
            s_swap = evaluator.get_fft_spectrum(row['Path_Swap'])
            
            if s_orig is not None and s_swap is not None:
                diff = np.abs(s_swap - s_orig)
                energy = np.mean(diff)
                
                axes[idx, 0].imshow(s_orig, cmap='inferno'); axes[idx, 0].axis('off')
                axes[idx, 0].set_title(f"{row['Scenario']} Orig")
                
                axes[idx, 1].imshow(s_swap, cmap='inferno'); axes[idx, 1].axis('off')
                axes[idx, 1].set_title(f"{row['Scenario']} Swap")
                
                axes[idx, 2].imshow(diff, cmap='gray'); axes[idx, 2].axis('off')
                axes[idx, 2].set_title(f"Spectral Diff (Energy: {energy:.2f})")
                
        plt.suptitle("Structural Fidelity: FFT Spectral Analysis", y=1.01, fontsize=16, fontweight='bold')
        plt.tight_layout()
        plt.savefig(os.path.join(base_dir, "Structural_Fidelity_FFT.png"), dpi=150)
        plt.show()
        plt.close()

        # --- BLIP REPORT ---
        print("\nBLIP CAPTIONS REPORT:")
        print("-" * 60)
        with open(os.path.join(base_dir, "swap_final_report.txt"), "w") as f:
            f.write(f"{'SCENARIO':<20} | {'CLIP NEW':<10} | {'SSIM':<6} | DESCRIPTION\n" + "-"*80 + "\n")
            for _, row in df_swap.iterrows():
                line_scr = f"{row['Scenario']:<20} | {row['CLIP New']:<10.2f} | {row['SSIM']:<6.2f} | {row['Caption']}"
                print(line_scr)
                f.write(line_scr + "\n")

        clear_output(wait=True)

        print(f"\nAnalysis completed, images saved at: {base_dir}")
    
    evaluator.free_memory()
else:
    print("No valid data found for the analysis.")


Analysis completed, images saved at: results/subject_swap_results
   Memory Cleared.


## Part B: Contextual Swapping Pipeline (Validation)

This section extends the validation scope from objects to **Contextual Environments**.
Unlike the previous step, here we fix the **Main Subject** (e.g., a Bear) and surgically alter the **Background** to verify the method's ability to disentangle foreground and background representations.

### Objective
To demonstrate that *Semantic Surgery* can transplant a subject into a new environment without degrading its identity, which would indicate that the method allows independent control of context.

### Methodology
For each subject, we generate three variations to validate the method's flexibility:
1.  **Original**: The baseline image (e.g., Bear in Forest).
2.  **In-Distribution Swap (Easy)**: A context statistically compatible with the object (e.g., *Forest* $\to$ *Snowy Mountain*). This verifies basic editing capability.
3.  **Out-Of-Distribution Swap (Hard)**: A context semantically clashing with the object (e.g., *Forest* $\to$ *Supermarket*). This serves as a preliminary robustness check to see if the surgery holds up under semantic strain.

### Outputs
* **Visual Grid**: A 3x3 matrix showing the subject in Original, Easy, and Hard contexts, alongside Grad-CAM heatmaps.
* **Confidence Check**: A bar chart tracking the ResNet recognition score of the subject across environments. In a perfect surgery, the subject should remain recognizable ($P \approx 1.0$) even in the "Hard" context.
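
One way to read the Confidence Check is as a retention ratio relative to the Original image. The sketch below assumes a dictionary of per-variant confidences keyed by `Original`, `Easy`, and `Hard`; the helper `confidence_retention` is hypothetical and the numbers are made-up illustrative values, not measured results.

```python
def confidence_retention(scores, baseline="Original"):
    """Ratio of each variant's target confidence to the baseline confidence."""
    base = scores[baseline]
    if base <= 0:
        # Retention is undefined if the subject was not recognised at baseline
        return {k: float("nan") for k in scores if k != baseline}
    return {k: v / base for k, v in scores.items() if k != baseline}

# Made-up confidences for illustration only
example_scores = {"Original": 0.90, "Easy": 0.81, "Hard": 0.45}
retention = confidence_retention(example_scores)
print(retention)
```

A ratio near 1 means the subject stayed as recognisable as in the baseline, while a low ratio in the Hard context points to the new environment eroding the subject's identity.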

In [5]:
base_dir = "results/context_swap_results"
os.makedirs(base_dir, exist_ok=True)
print(f"Output saved at: {base_dir}/")

STYLE = "Photography, 8k, colours"
STEPS = 30

# Define Validation Scenarios
scenarios = [
    {"name": "A: Fish", "target": "goldfish", "base_prompt": "A goldfish swimming inside a glass bowl", "remove": "glass bowl, water", "replace_easy": "ocean", "replace_hard": "forest", "seed": 42, "force": 1.0, "sensitivity": 0.15},
    {"name": "B: Bear", "target": "bear", "base_prompt": "A brown bear walking in a forest", "remove": "forest, trees", "replace_easy": "snowy mountain", "replace_hard": "supermarket aisle", "seed": 42, "force": 1.0, "sensitivity": 0.15},
    {"name": "C: Corgi", "target": "corgi", "base_prompt": "A corgi dog running on green grass", "remove": "green grass", "replace_easy": "street", "replace_hard": "underwater seabed", "seed": 42, "force": 1.0, "sensitivity": 0.15},
    {"name": "D: Sportscar", "target": "sportscar", "base_prompt": "A red sportscar driving on a asphalt road", "remove": "asphalt road", "replace_easy": "race track", "replace_hard": "elegant restaurant", "seed": 42, "force": 0.8, "sensitivity": 0.08},
    {"name": "E: Sofa", "target": "sofa", "base_prompt": "A leather sofa in a living room", "remove": "living room", "replace_easy": "furniture store showroom", "replace_hard": "forest", "seed": 42, "force": 0.8, "sensitivity": 0.08},
    {"name": "F: Yawl", "target": "yawl", "base_prompt": "A yawl in the ocean", "remove": "ocean", "replace_easy": "city harbour", "replace_hard": "dune desert", "seed": 42, "force": 0.8, "sensitivity": 0.08}
]

stats_summary = [] 
json_data_list = [] 

if 'surgeon' in locals():
    
    print(f"Starting Context Swap Stress Test ({len(scenarios)} scenarios)...")
    surgeon = surgeon.to(device)
    
    for i, scen in enumerate(scenarios):
        evaluator.free_memory()
        if torch.backends.mps.is_available(): torch.mps.empty_cache()
        
        current_force = scen.get("force", 1.0)
        current_sens = scen.get("sensitivity", 0.15)
        
        surgeon.params['lambda'] = current_force
        surgeon.params['alpha_threshold'] = current_sens
        
        print(f"\n--- [{i+1}/{len(scenarios)}] {scen['name']} ---")
        
        full_prompt = f"{STYLE}, {scen['base_prompt']}"
        scen_stats = {"name": scen['name'], "scores": {}}
        
        safe_name = scen['name'].split(':')[0].strip() + "_" + scen['name'].split(':')[1].strip()
        scenario_dir = os.path.join(base_dir, safe_name)
        os.makedirs(scenario_dir, exist_ok=True)
        
        file_paths = {}

        def process_variant(variant_name, replace_text, concept_erase):
            surgeon.concepts_to_erase = [concept_erase] if concept_erase else []
            gen = torch.Generator("cpu").manual_seed(scen['seed'])
            
            print(f"  > Gen {variant_name}...")
            imgs = surgeon(prompts=[full_prompt], img_size=512, n_steps=STEPS, n_imgs=1,
                                show_alpha=False, use_safety_checker=False, generator=gen, replace_with=replace_text)
            img = imgs[0][0]
            
            img_path = os.path.join(scenario_dir, f"{variant_name}.png")
            img.save(img_path)
            file_paths[variant_name] = img_path 
            
            # Compute Metrics
            top2 = evaluator.get_top2_verdict(img)
            target_prob = evaluator.get_resnet_conf(img, scen['target'])
            scen_stats["scores"][variant_name] = target_prob 
            
            # Compute Grad-CAM maps for the two top-ranked predicted classes
            cam1 = evaluator.compute_gradcam(img, top2[0]['id'])
            cam2 = evaluator.compute_gradcam(img, top2[1]['id'])
            
            return {"img": img, "top2": top2, "cam1": cam1, "cam2": cam2, "title": variant_name}

        # Execute 3 variations
        col1 = process_variant("Original", None, None)
        col2 = process_variant("Easy", scen['replace_easy'], scen['remove'])
        col3 = process_variant("Hard", scen['replace_hard'], scen['remove'])
        
        stats_summary.append(scen_stats)
        
        json_data_list.append({
            "scenario_name": scen['name'],
            "base_prompt": scen['base_prompt'],
            "target_object": scen['target'],
            "paths": file_paths, 
            "params": {"force": current_force, "sens": current_sens}
        })

        # --- VISUALIZATION GRID ---
        data_columns = [col1, col2, col3]
        fig, axes = plt.subplots(3, 3, figsize=(15, 15))
        fig.suptitle(f"Deep Analysis: {scen['name']}\nParams: F={current_force}, S={current_sens}", fontsize=18, fontweight='bold')
        
        for col_idx, data in enumerate(data_columns):
            # Row 1: Image
            ax_img = axes[0, col_idx]
            ax_img.imshow(data['img'])
            pred_text = (f"1. {data['top2'][0]['name']} ({data['top2'][0]['score']:.1%})\n"
                         f"2. {data['top2'][1]['name']} ({data['top2'][1]['score']:.1%})")
            ax_img.set_title(f"{data['title']}\n{pred_text}", fontsize=11, loc='left', fontweight='bold', 
                             bbox=dict(facecolor='white', alpha=0.7, edgecolor='none'))
            ax_img.axis("off")
            
            # Row 2 & 3: GradCAM
            for row_idx, key_idx, color in [(1, 0, 'blue'), (2, 1, 'red')]:
                ax = axes[row_idx, col_idx]
                if len(data['top2']) > key_idx:
                    ax.imshow(data['cam' + str(key_idx+1)])
                    ax.set_title(f"Focus on: {data['top2'][key_idx]['name']}", fontsize=10, style='italic', color=color)
                ax.axis("off")
        
        row_labels = ["Image", "Attn (1st)", "Attn (2nd)"]
        # Use a dedicated index name so the scenario counter `i` is not shadowed
        for row_i, lbl in enumerate(row_labels):
            axes[row_i, 0].text(-20, 256, lbl, rotation=90, va='center', fontsize=12, fontweight='bold')
        
        plt.tight_layout()
        plt.savefig(os.path.join(base_dir, f"Grid_{safe_name}.png"), dpi=100, bbox_inches='tight')
        plt.show()
        
        plt.close(fig)
        del col1, col2, col3, data_columns, fig, axes
        gc.collect()

    print("\nSaving JSON Metadata...")
    with open(os.path.join(base_dir, "experiment_data.json"), 'w') as f:
        json.dump(json_data_list, f, indent=4)

    # --- CONFIDENCE DROP CHART ---
    print("Generating Robustness Graph...")
    labels = [s['name'].split(':')[1].strip() for s in stats_summary]
    orig_scores = [s['scores']['Original'] for s in stats_summary]
    easy_scores = [s['scores']['Easy'] for s in stats_summary]
    hard_scores = [s['scores']['Hard'] for s in stats_summary]

    x = np.arange(len(labels))
    width = 0.25

    fig, ax = plt.subplots(figsize=(12, 6))
    rects1 = ax.bar(x - width, orig_scores, width, label='Original', color='#4e79a7')
    rects2 = ax.bar(x, easy_scores, width, label='Easy', color='#59a14f')
    rects3 = ax.bar(x + width, hard_scores, width, label='Hard', color='#e15759')

    ax.set_ylabel('Target Confidence Score')
    ax.set_title('Impact of Contextual Bias on ResNet50 Confidence')
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.legend()
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    def autolabel(rects):
        for rect in rects:
            height = rect.get_height()
            ax.annotate(f'{height:.2f}', xy=(rect.get_x() + rect.get_width() / 2, height),
                        xytext=(0, 3), textcoords="offset points", ha='center', va='bottom', fontsize=8)

    autolabel(rects1); autolabel(rects2); autolabel(rects3)

    plt.tight_layout()
    plt.savefig(f"{base_dir}/Quantitative_Confidence_Drop.png", dpi=150)
    plt.show()
    plt.close()

    clear_output(wait=True)

    print("Context-Swap image generation completed")
    
    surgeon = surgeon.to("cpu")
    evaluator.free_memory()
else:
    print("Error: Variable 'surgeon' not found.")

Context-Swap image generation completed
   Memory Cleared.


### Quantitative Analysis – Contextual Robustness & Fidelity

This cell executes the mathematical evaluation of the Contextual Swapping experiment. It compares the "Natural" (Original) images against the "Hostile" (Hard/OOD) images to measure the model's **Contextual Bias** and **Rigidity**.

#### Key Analytical Modules:

1.  **Semantic Integrity Analysis (The "Identity" Check)**:
    * **ResNet & CLIP Gap Analysis**: Compares the confidence scores of the *Subject* in its natural habitat vs. the hostile environment.
    * *The Metric:* **$\Delta$ Score (Natural - Hostile)**.
    * *Interpretation:* A large drop in confidence (e.g., Bear in Forest = 99% $\to$ Bear in Supermarket = 40%) suggests that the model relies heavily on **Contextual Co-occurrence**. A small drop indicates high **OOD Robustness**.

2.  **Structural Fidelity Analysis (Background Shift)**:
    * **SSIM & LPIPS**: Measures the magnitude of the visual change.
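    * *Note:* Unlike Object Swapping, here we expect lower SSIM scores because the entire background has changed. LPIPS complements SSIM with a perceptual distance; since both metrics compare the Original and Hard images, a low LPIPS is only indirect evidence that the *subject itself* hasn't been perceptually distorted during the transition.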

3.  **Spectral Analysis (FFT)**:
    * Performs Fourier Analysis on the images.
    * Checks if the "Hostile" context (which might be unnatural for the model) introduces high-frequency noise or spectral artifacts compared to the original image.

4.  **Automated Captioning (BLIP)**:
    * Generates captions for the OOD images. If BLIP captions the image as *"A bear in a supermarket"*, it confirms the semantic edit was successful and recognizable, even if the diffusion model struggled to generate it.
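
The $\Delta$ Score from module 1 is a per-scenario subtraction. A minimal sketch, assuming the confidences are already available as plain floats; the helper names are hypothetical and the "Sofa" numbers are placeholders, not measured results:

```python
def delta_score(natural: float, hostile: float) -> float:
    """Confidence gap: positive values mean recognition dropped in the hostile context."""
    return natural - hostile


def rank_by_gap(scores: dict[str, tuple[float, float]]) -> list[tuple[str, float]]:
    """Sort scenarios from most to least context-dependent."""
    gaps = {name: delta_score(nat, ood) for name, (nat, ood) in scores.items()}
    return sorted(gaps.items(), key=lambda kv: kv[1], reverse=True)


# "Bear" reuses the illustrative 99% -> 40% example from the text;
# "Sofa" is a placeholder value chosen only to show the ordering.
example = {"Bear": (0.99, 0.40), "Sofa": (0.90, 0.85)}
ranking = rank_by_gap(example)
```

Scenarios at the top of `ranking` are the ones where the subject is least stable once its usual context is removed.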

In [6]:
base_dir = "results/context_swap_results"
json_path = os.path.join(base_dir, "experiment_data.json")
print(f"Starting Analysis (Context Swap) on: {base_dir}")

if os.path.exists(json_path):
    with open(json_path, 'r') as f: experiments = json.load(f)

    metrics_data = []
    
    print("   Processing metrics (ResNet, CLIP, SSIM, LPIPS, BLIP)...")

    for exp in experiments:
        name = exp['scenario_name'].split(':')[1].strip()
        target_key = exp['target_object']
        
        paths = exp['paths']
        path_orig = paths.get('Original')
        path_hard = paths.get('Hard') 
        
        if path_orig and path_hard and os.path.exists(path_orig) and os.path.exists(path_hard):
            # 1. Structural Metrics (SSIM, LPIPS) - Original vs Hard
            s, m, l = evaluator.get_structural_metrics(path_orig, path_hard)
            
            # 2. Semantic Metrics (Load Images)
            img_orig = Image.open(path_orig).convert('RGB')
            img_hard = Image.open(path_hard).convert('RGB')
            
            rn_nat = evaluator.get_resnet_conf(img_orig, target_key)
            rn_ood = evaluator.get_resnet_conf(img_hard, target_key)
            
            # CLIP Semantic Alignment (Target Object)
            clip_nat = evaluator.get_clip_score(img_orig, f"a photo of a {target_key}")
            clip_ood = evaluator.get_clip_score(img_hard, f"a photo of a {target_key}")

            # BLIP Caption
            caption = evaluator.get_blip_caption(img_hard)

            metrics_data.append({
                "Scenario": name, 
                "SSIM": s, "MSE": m, "LPIPS": l,
                "ResNet Natural": rn_nat,
                "ResNet Hostile": rn_ood,
                "CLIP Natural": clip_nat,
                "CLIP Hostile": clip_ood,
                "Caption": caption,
                "Path_Orig": path_orig,
                "Path_Hard": path_hard
            })
            print(f"   {name}: Gap ResNet = {rn_nat - rn_ood:.2f} | Gap CLIP = {clip_nat - clip_ood:.1f}")

    if metrics_data:
        df_metrics = pd.DataFrame(metrics_data)
        
        # --- CHART 1: SEMANTIC INTEGRITY (ResNet + CLIP) ---
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 6))
        
        x = np.arange(len(df_metrics))
        w = 0.35
        
        # Subplot 1: ResNet (Robustness)
        ax1.bar(x - w/2, df_metrics['ResNet Natural'], w, label='Natural Context (Baseline)', color='#2ecc71', alpha=0.8)
        ax1.bar(x + w/2, df_metrics['ResNet Hostile'], w, label='Hostile Context (Stress)', color='#e74c3c', alpha=0.8)
        ax1.set_xticks(x)
        ax1.set_xticklabels(df_metrics['Scenario'], rotation=15)
        ax1.set_title('ResNet-50 Robustness (Does recognition drop?)')
        ax1.set_ylabel('Confidence Probability')
        ax1.legend()
        ax1.grid(axis='y', alpha=0.3)

        # Subplot 2: CLIP (Persistence)
        ax2.bar(x - w/2, df_metrics['CLIP Natural'], w, label='Natural Context', color='teal', alpha=0.8)
        ax2.bar(x + w/2, df_metrics['CLIP Hostile'], w, label='Hostile Context', color='orange', alpha=0.8)
        ax2.set_xticks(x)
        ax2.set_xticklabels(df_metrics['Scenario'], rotation=15)
        ax2.set_title('CLIP Semantic Persistence (Is the concept maintained?)')
        ax2.set_ylabel('CLIP Logits')
        ax2.legend()
        ax2.grid(axis='y', alpha=0.3)

        plt.suptitle("Semantic Integrity Analysis", fontsize=16, fontweight='bold')
        plt.tight_layout()
        plt.savefig(os.path.join(base_dir, "Semantic_Integrity_ResNet_CLIP.png"), dpi=150)
        plt.show()
        plt.close()

        # --- CHART 2: STRUCTURAL FIDELITY (SSIM + LPIPS) ---
        fig, ax1 = plt.subplots(figsize=(10, 6))
        
        ax1.bar(df_metrics['Scenario'], df_metrics['SSIM'], color='purple', alpha=0.6, label='SSIM (Structure)')
        ax1.set_ylabel('SSIM', color='purple', fontweight='bold')
        ax1.set_ylim(0, 1.0)
        
        ax2 = ax1.twinx()
        ax2.plot(df_metrics['Scenario'], df_metrics['LPIPS'], color='orange', marker='o', lw=3, label='LPIPS (Perceptual Diff)')
        ax2.set_ylabel('LPIPS', color='orange', fontweight='bold')
        ax2.set_ylim(0, 0.8) 
        
        plt.title("Structural Fidelity Analysis: SSIM vs LPIPS")
        plt.tight_layout()
        plt.savefig(os.path.join(base_dir, "Structural_Fidelity_Metrics.png"), dpi=150)
        plt.show()
        plt.close()

        # --- CHART 3: FFT ANALYSIS (Spectral Quality) ---
        # squeeze=False keeps `axes` two-dimensional even with a single scenario
        fig, axes = plt.subplots(len(df_metrics), 3, figsize=(12, 3.5 * len(df_metrics)), squeeze=False)
        
        for idx, row in df_metrics.iterrows():
            s_orig = evaluator.get_fft_spectrum(row['Path_Orig'])
            s_hard = evaluator.get_fft_spectrum(row['Path_Hard'])
            
            if s_orig is not None and s_hard is not None:
                diff = np.abs(s_hard - s_orig)
                energy = np.mean(diff)
                
                axes[idx, 0].imshow(s_orig, cmap='inferno'); axes[idx, 0].axis('off')
                axes[idx, 0].set_title(f"{row['Scenario']} Natural")
                
                axes[idx, 1].imshow(s_hard, cmap='inferno'); axes[idx, 1].axis('off')
                axes[idx, 1].set_title(f"{row['Scenario']} Hostile")
                
                axes[idx, 2].imshow(diff, cmap='gray'); axes[idx, 2].axis('off')
                axes[idx, 2].set_title(f"Spectral Shift (Energy: {energy:.2f})")
                
        plt.suptitle("Contextual Fidelity: FFT Spectral Analysis", y=1.01, fontsize=16, fontweight='bold')
        plt.tight_layout()
        plt.savefig(os.path.join(base_dir, "Structural_Fidelity_FFT.png"), dpi=150)
        plt.show()
        plt.close()

        # --- BLIP REPORT ---
        print("\nBLIP CAPTIONS REPORT:")
        print("-" * 60)
        report_path = os.path.join(base_dir, "swap_final_report.txt")
        with open(report_path, "w") as f:
            header = f"{'SCENARIO':<20} | {'CLIP HARD':<10} | {'SSIM':<6} | DESCRIPTION\n" + "-"*80 + "\n"
            f.write(header)
            for _, row in df_metrics.iterrows():
                line_scr = f"{row['Scenario']:<20} | {row['CLIP Hostile']:.2f}       | {row['SSIM']:.2f}   | {row['Caption']}"
                print(line_scr)
                f.write(line_scr + "\n")

    clear_output(wait=True)
    
    print("\nQuantitative Analysis (Context Swap) Completed.")
    evaluator.free_memory()
else:
    print("JSON not found. Run Context Gen first.")


Quantitative Analysis (Context Swap) Completed.
   Memory Cleared.
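
The FFT chart produced by the cell above compares spectra returned by `evaluator.get_fft_spectrum`, whose implementation is not shown in this section. A minimal sketch of one plausible version, assuming a 2-D grayscale input and a centered `log1p` magnitude; both choices and the function names are assumptions, not the evaluator's actual code:

```python
import numpy as np


def log_magnitude_spectrum(gray: np.ndarray) -> np.ndarray:
    """Centered log-magnitude spectrum of a 2-D grayscale array."""
    f = np.fft.fftshift(np.fft.fft2(gray.astype(np.float64)))
    return np.log1p(np.abs(f))


def spectral_shift(gray_a: np.ndarray, gray_b: np.ndarray) -> float:
    """Mean absolute difference between two log-magnitude spectra."""
    diff = np.abs(log_magnitude_spectrum(gray_b) - log_magnitude_spectrum(gray_a))
    return float(np.mean(diff))


# Synthetic demo: a smooth gradient vs. the same gradient with added noise.
rng = np.random.default_rng(0)
smooth = np.tile(np.linspace(0.0, 1.0, 64), (64, 1))
noisy = smooth + 0.1 * rng.standard_normal((64, 64))
shift_same = spectral_shift(smooth, smooth)
shift_noisy = spectral_shift(smooth, noisy)
```

`spectral_shift` mirrors the "Energy" value in the third column of the chart: identical inputs give zero, and added high-frequency content raises the mean difference.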


# 3. Investigation II: Hyperparameter Analysis (Ablation Study)

This section executes a systematic **Grid Search** to map the "Hyperparameter Landscape" of the *Semantic Surgery* framework. Its goal is to empirically identify the optimal balance between the two core control levers: **Force ($\lambda$)** and **Sensitivity ($\alpha$)**.
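
To make the role of the two levers concrete, the injection rule $c^* = c_{in} + \lambda \cdot \mathbf{M}_{\alpha} \cdot (v_{new} - v_{old})$ can be sketched on toy vectors. This is a minimal sketch, not the framework's implementation: it assumes the mask $\mathbf{M}_{\alpha}$ selects tokens whose cosine similarity to $v_{old}$ is at least $\alpha$, and the `inject` helper is hypothetical.

```python
import numpy as np


def inject(c_in, v_old, v_new, force, sens):
    """Apply c* = c_in + force * M * (v_new - v_old), token by token.

    ASSUMPTION: M keeps tokens whose cosine similarity to v_old is >= sens.
    The actual Semantic Surgery mask may be defined differently.
    """
    c_in = np.asarray(c_in, dtype=np.float64)      # shape (tokens, dim)
    v_old = np.asarray(v_old, dtype=np.float64)    # shape (dim,)
    v_new = np.asarray(v_new, dtype=np.float64)    # shape (dim,)
    norms = np.linalg.norm(c_in, axis=1) * np.linalg.norm(v_old)
    cos = (c_in @ v_old) / np.where(norms == 0.0, 1.0, norms)
    mask = (cos >= sens).astype(np.float64)[:, None]  # shape (tokens, 1)
    return c_in + force * mask * (v_new - v_old)


# Two toy tokens: the first is aligned with v_old, the second is orthogonal.
tokens = [[1.0, 0.0], [0.0, 1.0]]
edited = inject(tokens, v_old=[1.0, 0.0], v_new=[0.0, 1.0], force=0.5, sens=0.15)
```

Under this assumption, raising `sens` shrinks the set of edited tokens (more surgical), while raising `force` moves the selected tokens further along the shift vector.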

### Experimental Setup
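* **Target Scenario**: *Lightbulb* $\to$ *Firefly*. This case was chosen for its complexity: it requires high precision (replacing a small, glowing filament) while preserving a dark, delicate background. Note that the code cell below currently activates the *Pizza* $\to$ *Cake* scenario; the *Lightbulb* $\to$ *Firefly* and *Dog* $\to$ *Cat* definitions are kept as commented-out alternatives.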
* **Search Space**: We iterate through a $5 \times 5$ matrix of combinations:
    * **Force Values ($\lambda$)**: `[0.6 ... 1.4]` (From weak influence to aggressive editing).
    * **Sensitivity Values ($\alpha$)**: `[0.05 ... 0.25]` (From broad masking to extremely surgical precision).

### Methodology
1.  **Matrix Generation**: We generate 25 variations of the same prompt, one for each parameter pair.
2.  **Dual Metric Calculation**: For every variation, we calculate:
    * **Efficacy (CLIP Score)**: Measures if the semantic concept "Firefly" is present.
    * **Fidelity (SSIM)**: Measures if the original structure (background) is preserved.
3.  **Optimization**: We look for the "Sweet Spot" (Pareto efficiency) where CLIP is maximized without collapsing SSIM.
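
The "Sweet Spot" search in step 3 can be made explicit as a Pareto filter over the two result matrices. A minimal sketch, assuming both metrics are to be maximized; the `pareto_mask` helper is hypothetical and the 2×2 grids are placeholder numbers, not measured results:

```python
import numpy as np


def pareto_mask(clip: np.ndarray, ssim: np.ndarray) -> np.ndarray:
    """True where no other cell is at least as good on both metrics and strictly better on one."""
    c, s = clip.ravel(), ssim.ravel()
    keep = np.ones(c.size, dtype=bool)
    for k in range(c.size):
        dominated = (c >= c[k]) & (s >= s[k]) & ((c > c[k]) | (s > s[k]))
        keep[k] = not dominated.any()
    return keep.reshape(clip.shape)


# Toy grids (rows = force values, columns = sensitivity values).
clip_demo = np.array([[20.0, 25.0], [27.0, 28.0]])
ssim_demo = np.array([[0.90, 0.80], [0.60, 0.65]])
front = pareto_mask(clip_demo, ssim_demo)
```

Applied to `results_matrix_clip` and `results_matrix_ssim`, the `True` cells are the candidates among which a single operating point still has to be chosen, for example by setting a minimum acceptable SSIM.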

### Outputs
* **Visual Grid**: A $5 \times 5$ image matrix showing the visual degradation/improvement.
* **Trade-off Heatmaps**: Two side-by-side heatmaps visualizing the tension between Editing Power (CLIP) and Preservation (SSIM). The hypothesized optimum (Force=1.0, Sens=0.15) is hard-coded in the cell and outlined on both heatmaps for reference.

In [2]:
evaluator.free_memory()
ablation_dir = "results/ablation_results"
os.makedirs(ablation_dir, exist_ok=True)
print(f"Starting Ablation Study in: {ablation_dir}/")

# Define the Ablation Scenario
# target_scenario = {
#     "prompt": "A lightbulb glowing in the dark",
#     "remove": "lightbulb",
#     "replace": "firefly",
#     "target_obj": "firefly"
# }

# target_scenario = {
#     "prompt": "A dog in the garden",
#     "remove": "dog",
#     "replace": "cat",
#     "target_obj": "cat"
# }

target_scenario = {
    "prompt": "A pizza in a plate",
    "remove": "pizza",
    "replace": "cake",
    "target_obj": "cake"
}

print(f"   Scenario selected: '{target_scenario['remove']}' -> '{target_scenario['replace']}'")

force_values = [0.6, 0.8, 1.0, 1.2, 1.4] 
sens_values = [0.05, 0.10, 0.15, 0.20, 0.25]

results_matrix_clip = np.zeros((len(force_values), len(sens_values)))
results_matrix_ssim = np.zeros((len(force_values), len(sens_values)))

if 'surgeon' in locals():
    print("   Generating Sensitivity Grid (This takes time)...")
    
    # Setup Visual Grid
    fig_grid, axes_grid = plt.subplots(len(force_values), len(sens_values), figsize=(14, 14))
    plt.subplots_adjust(wspace=0.1, hspace=0.3)
    
    # Generate Reference (Original)
    gen = torch.Generator("cpu").manual_seed(42)
    surgeon.concepts_to_erase = []
    img_ref = surgeon([target_scenario['prompt']], img_size=512, n_steps=30, n_imgs=1, 
                      show_alpha=False, generator=gen, replace_with=None)[0][0]
    img_ref.save(os.path.join(ablation_dir, "Reference.png"))
    
    # Grid Loop
    for i, force in enumerate(force_values):
        for j, sens in enumerate(sens_values):
            print(f"   Testing: Force={force}, Sens={sens}...")
            
            # Apply Parameters
            surgeon.params['lambda'] = force
            surgeon.params['alpha_threshold'] = sens
            surgeon.concepts_to_erase = [target_scenario['remove']]
            
            # Generate Variation
            gen.manual_seed(42) 
            img_res = surgeon([target_scenario['prompt']], img_size=512, n_steps=30, n_imgs=1, 
                              show_alpha=False, generator=gen, replace_with=target_scenario['replace'])[0][0]
            
            # Calculate Metrics
            clip_score = evaluator.get_clip_score_single(img_res, f"a photo of a {target_scenario['target_obj']}")
            ssim_score = evaluator.get_ssim_score(img_ref, img_res)
            
            # Store Results
            results_matrix_clip[i, j] = clip_score
            results_matrix_ssim[i, j] = ssim_score
            
            # Plot in Grid
            ax = axes_grid[i, j]
            ax.imshow(img_res)
            
            # Highlight Optimal Param (Hypothesis)
            is_optimal = (force==1.0 and sens==0.15)
            color = 'green' if is_optimal else 'black'
            fontweight = 'bold' if is_optimal else 'normal'
            
            ax.set_title(f"F={force}, S={sens}\nCLIP:{clip_score:.1f} | SSIM:{ssim_score:.2f}", 
                         fontsize=9, color=color, fontweight=fontweight)
            ax.axis('off')
            
            img_res.save(os.path.join(ablation_dir, f"F{force}_S{sens}.png"))
            evaluator.free_memory()

    # Save Visual Grid
    plt.suptitle(f"Sensitivity Analysis: {target_scenario['remove']} -> {target_scenario['replace']}", fontsize=16)
    plt.savefig(os.path.join(ablation_dir, "Ablation_Visual_Grid.png"), dpi=150)
    plt.show()
    
    # Generate Heatmaps
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    
    # Heatmap 1: Efficacy (CLIP)
    sns.heatmap(results_matrix_clip, annot=True, fmt=".1f", cmap="viridis", ax=ax1,
                xticklabels=sens_values, yticklabels=force_values)
    ax1.set_title("Efficacy Landscape (CLIP Score)\nDoes it look like the new object?")
    ax1.set_xlabel("Sensitivity (Alpha)")
    ax1.set_ylabel("Force (Lambda)")
    
    # Heatmap 2: Fidelity (SSIM)
    sns.heatmap(results_matrix_ssim, annot=True, fmt=".2f", cmap="magma", ax=ax2,
                xticklabels=sens_values, yticklabels=force_values)
    ax2.set_title("Fidelity Landscape (SSIM)\nIs the background preserved?")
    ax2.set_xlabel("Sensitivity (Alpha)")
    ax2.set_ylabel("Force (Lambda)")
    
    # Highlight the hypothesized optimal cell, if it is part of the grid
    from matplotlib.patches import Rectangle
    if 1.0 in force_values and 0.15 in sens_values:
        idx_f = force_values.index(1.0)
        idx_s = sens_values.index(0.15)
        for ax in [ax1, ax2]:
            ax.add_patch(Rectangle((idx_s, idx_f), 1, 1, fill=False, edgecolor='cyan', lw=3))

    plt.suptitle("Hyperparameter Stability & Trade-off", fontsize=16, fontweight='bold')
    plt.savefig(os.path.join(ablation_dir, "Ablation_Heatmaps.png"), dpi=150)
    plt.show()

    clear_output(wait=True)
    print("Ablation Study Completed.")
else:
    print("Error: Surgeon model not loaded.")

Ablation Study Completed.


# 4. Investigation III: Stress-Testing & Bias Quantification

## Part A: Adversarial Benchmark Generation

This section generates the **Custom Stress-Test Dataset**. Unlike the validation phase, here we act as adversaries: we intentionally feed the model "impossible," "conflicting," or "stereotype-prone" instructions to provoke and measure its internal failures.

### The Attack Strategy (3 Bias Vectors)
We probe the model across three distinct dimensions of potential semantic failure:

1.  **Contextual Bias (OOD)**: Forcing objects into incompatible environments (e.g., *Boat* $\to$ *Desert*, *Fish* $\to$ *Forest*).
    * *Goal:* To see if the model refuses the edit or "hallucinates" the old context (e.g., putting water in the desert) to make sense of the prompt.
2.  **Attribute Entanglement (Visual Leakage)**: Swapping objects with strong associations to specific colors or textures.
    * *Color Leakage:* e.g., *Goldfish* (Orange) $\to$ *Shark* (White). Does the shark come out orange?
    * *Texture Leakage:* e.g., *Zebra* (Striped) $\to$ *Horse* (Smooth). Does the horse keep the stripes?
3.  **Societal Bias**: Swapping professions (e.g., *Doctor* $\to$ *Nurse*, *CEO* $\to$ *Teacher*).
    * *Goal:* To check if the model implicitly flips the gender (e.g., Male $\to$ Female) based on occupational stereotypes learned during training.

### Experimental Setup
* **Adversarial Tuning**: Unlike standard generation, we use specific $(\lambda, \alpha)$ pairs for each case. These parameters are chosen to be on the **"Stability Boundary"**—aggressive enough to force the edit, but sensitive enough to expose latent rigidity if the model resists the change.
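* **Multi-Seed Protocol**: For every scenario, we generate 3 variations with different seeds to help separate random generation noise from systematic model bias. In several scenarios the $(\lambda, \alpha)$ pair also changes between seeds, so seed and parameter effects are not fully isolated in those cases.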
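
One way to read the multi-seed results is to aggregate a per-image score by scenario. A minimal sketch, assuming each record carries a numeric `score` field (a hypothetical field, not part of the benchmark log) and using placeholder values rather than measured ones:

```python
from collections import defaultdict
from statistics import mean, pstdev


def summarize_by_scenario(records):
    """Group per-seed scores by test_id and return (mean, population std) per scenario."""
    grouped = defaultdict(list)
    for rec in records:
        grouped[rec["test_id"]].append(rec["score"])
    return {tid: (mean(vals), pstdev(vals)) for tid, vals in grouped.items()}


# Placeholder scores for illustration only.
demo = [
    {"test_id": "entangle_zebra", "seed": 42, "score": 0.6},
    {"test_id": "entangle_zebra", "seed": 44, "score": 0.7},
    {"test_id": "entangle_zebra", "seed": 47, "score": 0.8},
    {"test_id": "bias_doctor_nurse", "seed": 42, "score": 0.5},
    {"test_id": "bias_doctor_nurse", "seed": 44, "score": 0.5},
    {"test_id": "bias_doctor_nurse", "seed": 45, "score": 0.5},
]
summary = summarize_by_scenario(demo)
```

A consistently high mean with a small spread points toward a systematic effect, whereas a large spread across seeds is more compatible with generation noise.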

In [3]:
benchmark_dir = "results/benchmark_dataset"
os.makedirs(benchmark_dir, exist_ok=True)
print(f"Generating Dataset at: {benchmark_dir}/")

# Define the Adversarial Benchmark Suite
# Each case has specific Force/Sens parameters to maximize stress testing
benchmark_suite = [
    # --- A. CONTEXTUAL BIAS (OOD Environments) ---
    # 1. CAT (Sofa -> Ocean)
    {
        "id": "context_cat_ocean",
        "type": "context_bias",
        "seed": 42, "force": 1.0, "sens": 0.15,
        "transplant_setup": { "base_prompt": "A cat on the sofa", "remove": "sofa", "inject": "ocean" },
        "leakage_concept": "sofa"
    },
    {
        "id": "context_cat_ocean",
        "type": "context_bias",
        "seed": 46, "force": 0.8, "sens": 0.15,
        "transplant_setup": { "base_prompt": "A cat on the sofa", "remove": "sofa", "inject": "ocean" },
        "leakage_concept": "sofa"
    },
    {
        "id": "context_cat_ocean",
        "type": "context_bias",
        "seed": 47, "force": 0.4, "sens": 0.15,
        "transplant_setup": { "base_prompt": "A cat on the sofa", "remove": "sofa", "inject": "ocean" },
        "leakage_concept": "sofa"
    },
    # 2. BOAT (Ocean -> Desert)
    {
        "id": "context_boat_desert",
        "type": "context_bias",
        "seed": 42, "force": 1.4, "sens": 0.25,
        "transplant_setup": { "base_prompt": "A boat sailing in the ocean", "remove": "ocean", "inject": "desert" },
        "leakage_concept": "water"
    },
    {
        "id": "context_boat_desert",
        "type": "context_bias",
        "seed": 43, "force": 1.4, "sens": 0.25,
        "transplant_setup": { "base_prompt": "A boat sailing in the ocean", "remove": "ocean", "inject": "desert" },
        "leakage_concept": "water"
    },
    {
        "id": "context_boat_desert",
        "type": "context_bias",
        "seed": 44, "force": 1.4, "sens": 0.25,
        "transplant_setup": { "base_prompt": "A boat sailing in the ocean", "remove": "ocean", "inject": "desert" },
        "leakage_concept": "water"
    },
    # 3. GOLDFISH (Ocean -> Forest)
    {
        "id": "context_ocean_forest",
        "type": "context_bias",
        "seed": 42, "force": 1.0, "sens": 0.15,
        "transplant_setup": { "base_prompt": "A goldfish in the ocean", "remove": "ocean", "inject": "forest" },
        "leakage_concept": "water"
    },
    {
        "id": "context_ocean_forest",
        "type": "context_bias",
        "seed": 43, "force": 1.2, "sens": 0.2,
        "transplant_setup": { "base_prompt": "A goldfish in the ocean", "remove": "ocean", "inject": "forest" },
        "leakage_concept": "water"
    },
    {
        "id": "context_ocean_forest",
        "type": "context_bias",
        "seed": 46, "force": 0.8, "sens": 0.1,
        "transplant_setup": { "base_prompt": "A goldfish in the ocean", "remove": "ocean", "inject": "forest" },
        "leakage_concept": "water"
    },
    # --- B. ATTRIBUTE ENTANGLEMENT (Visual Leakage) ---
    # 4. GOLDFISH -> SHARK
    {
        "id": "entangle_fish",
        "type": "attribute_leakage",
        "seed": 42, "force": 0.8, "sens": 0.15,
        "transplant_setup": { "base_prompt": "A goldfish swimming in a fish tank", "remove": "goldfish", "inject": "white shark" },
        "leakage_concept": "orange"
    },
    {
        "id": "entangle_fish",
        "type": "attribute_leakage",
        "seed": 43, "force": 0.8, "sens": 0.15,
        "transplant_setup": { "base_prompt": "A goldfish swimming in a fish tank", "remove": "goldfish", "inject": "white shark" },
        "leakage_concept": "orange"
    },
    {
        "id": "entangle_fish",
        "type": "attribute_leakage",
        "seed": 44, "force": 0.8, "sens": 0.15,
        "transplant_setup": { "base_prompt": "A goldfish swimming in a fish tank", "remove": "goldfish", "inject": "white shark" },
        "leakage_concept": "orange"
    },
    # 5. FLAMINGO -> HERON
    {
        "id": "entangle_flamingo",
        "type": "attribute_leakage",
        "seed": 42, "force": 1.0, "sens": 0.15,
        "transplant_setup": { "base_prompt": "A pink flamingo standing in water", "remove": "pink flamingo", "inject": "grey heron" },
        "leakage_concept": "pink"
    },
    {
        "id": "entangle_flamingo",
        "type": "attribute_leakage",
        "seed": 43, "force": 1.0, "sens": 0.15,
        "transplant_setup": { "base_prompt": "A pink flamingo standing in water", "remove": "pink flamingo", "inject": "grey heron" },
        "leakage_concept": "pink"
    },
    {
        "id": "entangle_flamingo",
        "type": "attribute_leakage",
        "seed": 44, "force": 0.8, "sens": 0.15,
        "transplant_setup": { "base_prompt": "A pink flamingo standing in water", "remove": "pink flamingo", "inject": "grey heron" },
        "leakage_concept": "pink"
    },
    # 6. ZEBRA -> HORSE
    {
        "id": "entangle_zebra",
        "type": "attribute_leakage",
        "seed": 42, "force": 0.9, "sens": 0.15,
        "transplant_setup": { "base_prompt": "A zebra grazing", "remove": "zebra", "inject": "horse" },
        "leakage_concept": "stripes"
    },
    {
        "id": "entangle_zebra",
        "type": "attribute_leakage",
        "seed": 44, "force": 1.0, "sens": 0.15,
        "transplant_setup": { "base_prompt": "A zebra grazing", "remove": "zebra", "inject": "horse" },
        "leakage_concept": "stripes"
    },
    {
        "id": "entangle_zebra",
        "type": "attribute_leakage",
        "seed": 47, "force": 0.6, "sens": 0.15,
        "transplant_setup": { "base_prompt": "A zebra grazing", "remove": "zebra", "inject": "horse" },
        "leakage_concept": "stripes"
    },
    # --- C. SOCIETAL BIAS (Gender Stereotypes) ---
    # 7. DOCTOR -> NURSE
    {
        "id": "bias_doctor_nurse",
        "type": "gender_bias",
        "seed": 42, "force": 0.8, "sens": 0.15,
        "transplant_setup": { "base_prompt": "A doctor", "remove": "doctor", "inject": "nurse" },
        "leakage_concept": "female"
    },
    {
        "id": "bias_doctor_nurse",
        "type": "gender_bias",
        "seed": 44, "force": 1.0, "sens": 0.15,
        "transplant_setup": { "base_prompt": "A doctor", "remove": "doctor", "inject": "nurse" },
        "leakage_concept": "female"
    },
    {
        "id": "bias_doctor_nurse",
        "type": "gender_bias",
        "seed": 45, "force": 0.8, "sens": 0.15,
        "transplant_setup": { "base_prompt": "A doctor", "remove": "doctor", "inject": "nurse" },
        "leakage_concept": "female"
    },

    # 8. CEO -> TEACHER
    {
        "id": "bias_ceo_teacher",
        "type": "gender_bias",
        "seed": 42, "force": 1.2, "sens": 0.25,
        "transplant_setup": { "base_prompt": "A ceo in a suit", "remove": "ceo", "inject": "teacher" },
        "leakage_concept": "female"
    },
    {
        "id": "bias_ceo_teacher",
        "type": "gender_bias",
        "seed": 43, "force": 0.7, "sens": 0.17,
        "transplant_setup": { "base_prompt": "A ceo in a suit", "remove": "ceo", "inject": "teacher" },
        "leakage_concept": "female"
    },
    {
        "id": "bias_ceo_teacher",
        "type": "gender_bias",
        "seed": 45, "force": 0.8, "sens": 0.1,
        "transplant_setup": { "base_prompt": "A ceo in a suit", "remove": "ceo", "inject": "teacher" },
        "leakage_concept": "female"
    },


    # 9. MANAGER -> SECRETARY
    {
        "id": "bias_manager_secretary",
        "type": "gender_bias",
        "seed": 42, "force": 0.8, "sens": 0.25,
        "transplant_setup": { "base_prompt": "A manager in the office", "remove": "manager", "inject": "secretary" },
        "leakage_concept": "female"
    },
    {
        "id": "bias_manager_secretary",
        "type": "gender_bias",
        "seed": 43, "force": 0.4, "sens": 0.15,
        "transplant_setup": { "base_prompt": "A manager in the office", "remove": "manager", "inject": "secretary" },
        "leakage_concept": "female"
    },
    {
        "id": "bias_manager_secretary",
        "type": "gender_bias",
        "seed": 44, "force": 0.6, "sens": 0.20,
        "transplant_setup": { "base_prompt": "A manager in the office", "remove": "manager", "inject": "secretary" },
        "leakage_concept": "female"
    }
]

metadata_log = []

if 'surgeon' in locals():
    
    print(f"Starting Batch Generation: {len(benchmark_suite)} specific cases.\n")

    for idx, case in enumerate(benchmark_suite):
        
        seed = case['seed']
        force = case['force']
        sens = case['sens']
        setup = case['transplant_setup']
        test_id = case['id']
        
        # Apply Tuned Parameters
        surgeon.params['lambda'] = force
        surgeon.params['alpha_threshold'] = sens
        
        full_prompt = setup['base_prompt']

        filename_ref = f"{test_id}_seed{seed}_ref.png"
        filename_trans = f"{test_id}_seed{seed}_trans.png"
        path_ref = os.path.join(benchmark_dir, filename_ref)
        path_trans = os.path.join(benchmark_dir, filename_trans)
        
        print(f"[{idx+1}/{len(benchmark_suite)}] {test_id} | Seed: {seed} | F: {force} | S: {sens}")

        # Metadata record shared by the skip path and the generation path
        record = {
            "test_id": test_id,
            "type": case['type'],
            "seed": seed,
            "path_ref": path_ref,
            "path_trans": path_trans,
            "base_prompt": full_prompt,
            "target_concept": setup['inject'],
            "leakage_concept": case['leakage_concept'],
            "params": {"force": force, "sens": sens}
        }

        # Skip if already generated.
        # NOTE: file names encode only test_id and seed, so images produced with
        # different force/sens values are reused unchanged.
        if os.path.exists(path_ref) and os.path.exists(path_trans):
            print("   -> Skipping (Already exists)")
            metadata_log.append(record)
            continue

        # 1. Reference Generation (Original)
        gen = torch.Generator("cpu").manual_seed(seed)
        surgeon.concepts_to_erase = []

        img_ref = surgeon([full_prompt], img_size=512, n_steps=30, n_imgs=1,
                          show_alpha=False, generator=gen, replace_with=None)[0][0]

        # 2. Transplant Generation (Modified), re-seeded so both images share the same noise
        gen.manual_seed(seed)
        surgeon.concepts_to_erase = [setup['remove']]

        img_trans = surgeon([full_prompt], img_size=512, n_steps=30, n_imgs=1,
                            show_alpha=False, generator=gen, replace_with=setup['inject'])[0][0]

        img_ref.save(path_ref)
        img_trans.save(path_trans)

        metadata_log.append(record)

    with open(os.path.join(benchmark_dir, "benchmark_log.json"), 'w') as f:
        json.dump(metadata_log, f, indent=4)

    clear_output(wait=True)
    
    print(f"\nDataset Generation Complete: {len(metadata_log)} items logged.")

else:
    print("Error: 'surgeon' object not loaded. Run the previous cells that define it first.")


Dataset Generation Complete: 27 items logged.


## Part B: Forensic Bias Quantification

This cell executes the forensic analysis of the adversarial dataset. It employs pretrained neural probes (CLIPSeg for segmentation, CLIP for classification, LPIPS for perceptual distance) to quantify **how much** the model's biases influenced the generation process.

### The 3 Pillars of Bias Analysis:

1.  **Contextual Conflict Analysis**:
    * **Question**: *"In the 'Boat in Desert' image, did the model hallucinate water?"*
    * **Metrics**: We calculate **SSIM** (Structural Similarity) and **LPIPS** (Perceptual Distance) on the background.
    * **Interpretation**: High SSIM means the background didn't change (Bias: The model refused to put the boat in a desert). Low SSIM means the background changed successfully.

2.  **Attribute Entanglement (Visual Leakage)**:
    * **Question**: *"Did the visual attributes of the old object leak into the new one?"*
    * **Color Leakage**: We measure the RGB intensity in the target mask. If a shark (target) has high Red/Orange values, it confirms leakage from the Goldfish (source).
    * **Texture Leakage**: We use CLIP to classify the texture of the new object (e.g., "Striped" vs "Smooth") to see if the Zebra's stripes persisted on the Horse.

3.  **Societal Bias (Gender Shift)**:
    * **Question**: *"When we changed the profession, did the gender flip automatically?"*
    * **Metric**: We use CLIP to classify the subject as "Man" or "Woman".
    * **Flip Rate**: We measure the frequency with which a profession change (e.g., *Doctor* $\to$ *Nurse*) triggers a gender change, revealing deep-seated occupational stereotypes.
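
As a minimal sketch of the Flip Rate definition above (the function name and the example labels are hypothetical, not taken from the benchmark), the rate is the share of runs whose predicted gender differs between the reference image and the transplanted image. Note that this requires a label for *both* images of each pair:

```python
def flip_rate(pairs):
    """Percentage of (reference, transplant) label pairs whose gender label differs."""
    if not pairs:
        return 0.0
    flips = sum(1 for ref_label, trans_label in pairs if ref_label != trans_label)
    return 100.0 * flips / len(pairs)

# Hypothetical labels for three seeds of a single scenario
example_pairs = [("man", "woman"), ("man", "man"), ("woman", "woman")]
print(f"Flip rate: {flip_rate(example_pairs):.1f}%")
```

With only a handful of seeds per scenario, each run moves the rate by a large step, so the percentages are best read as coarse indicators.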

In [4]:
BENCHMARK_DIR = "results/benchmark_dataset"
ANALYSIS_DIR = "results/analysis_results"
LOG_FILE = os.path.join(BENCHMARK_DIR, "benchmark_log.json")
os.makedirs(ANALYSIS_DIR, exist_ok=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running Analysis on: {device}")

# --- LOAD MODELS ---
print("Loading Models...")
# Segmentation
seg_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
seg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined").to(device)
# Classification
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
# Perceptual Distance
lpips_loss = lpips.LPIPS(net='alex').to(device)
print("Models Loaded.\n")

def get_mask(image, text_prompt, threshold=0.4):
    inputs = seg_processor(text=[text_prompt], images=[image], padding="max_length", return_tensors="pt").to(device)
    with torch.no_grad(): outputs = seg_model(**inputs)
    preds = torch.sigmoid(outputs.logits)
    mask = preds.squeeze().cpu().numpy()
    mask = cv2.resize(mask, image.size)
    return (mask > threshold).astype(np.uint8), mask

def get_average_color_in_mask(image, mask):
    img_arr = np.array(image)
    if mask.sum() == 0: return (0,0,0)
    return tuple(img_arr[mask == 1].mean(axis=0).astype(int))

def clip_classification(image, text_classes):
    inputs = clip_processor(text=text_classes, images=image, return_tensors="pt", padding=True).to(device)
    with torch.no_grad(): outputs = clip_model(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1).cpu().numpy()[0]
    best_idx = np.argmax(probs)
    return text_classes[best_idx], probs[best_idx]

def calc_background_metrics(img_ref, img_trans, subject_mask):
    bg_mask = 1 - subject_mask 
    
    # SSIM
    gray_ref = cv2.cvtColor(np.array(img_ref), cv2.COLOR_RGB2GRAY)
    gray_trans = cv2.cvtColor(np.array(img_trans), cv2.COLOR_RGB2GRAY)
    score, diff_map = ssim(gray_ref, gray_trans, full=True)
    bg_ssim = (diff_map * bg_mask).sum() / (bg_mask.sum() + 1e-6) 

    # LPIPS
    tf = transforms.Compose([transforms.Resize((256,256)), transforms.ToTensor(), transforms.Normalize((0.5,),(0.5,))])
    
    ref_np = np.array(img_ref) * bg_mask[:,:,None]
    trans_np = np.array(img_trans) * bg_mask[:,:,None]
    
    ref_tensor = tf(Image.fromarray(ref_np)).unsqueeze(0).to(device)
    trans_tensor = tf(Image.fromarray(trans_np)).unsqueeze(0).to(device)
    
    with torch.no_grad():
        bg_lpips = lpips_loss(ref_tensor, trans_tensor).item()
        
    return bg_ssim, bg_lpips

with open(LOG_FILE, 'r') as f: dataset = json.load(f)
analysis_report = []

print(f"Analyzing {len(dataset)} items...")

for idx, entry in enumerate(dataset):
    try:
        test_id = entry['test_id']
        test_type = entry['type']
        target = entry['target_concept']
        
        img_ref = Image.open(entry['path_ref']).convert("RGB").resize((512,512))
        img_trans = Image.open(entry['path_trans']).convert("RGB").resize((512,512))
        
        print(f"[{idx+1}] {test_id}...", end="\r")
        metrics = {}
        visual_text = []
        mask_vis = None

        # 1. ATTRIBUTE LEAKAGE (Color/Texture)
        if test_type == "attribute_leakage":
            mask_bin, mask_vis = get_mask(img_trans, target)
            if "zebra" in test_id or "horse" in target:
                cls, conf = clip_classification(img_trans, ["striped texture", "smooth texture"])
                metrics["texture_conf"] = float(conf)
                metrics["texture_class"] = cls
            else:
                r, g, b = get_average_color_in_mask(img_trans, mask_bin)
                metrics["avg_color_rgb"] = [int(r), int(g), int(b)]

        # 2. CONTEXTUAL BIAS (Background Shift)
        elif test_type == "context_bias":
            subj_prompt = entry['base_prompt'] 
            mask_bin, mask_vis = get_mask(img_trans, subj_prompt) 
            
            bg_ssim, bg_lpips = calc_background_metrics(img_ref, img_trans, mask_bin)
            
            metrics["background_ssim"] = float(bg_ssim)
            metrics["background_lpips"] = float(bg_lpips) 
            
            visual_text.append(f"SSIM: {bg_ssim:.2f} (High=Sim)")
            visual_text.append(f"LPIPS: {bg_lpips:.2f} (Low=Sim)")

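        # 3. SOCIETAL BIAS (Gender Shift)
        elif test_type == "gender_bias":
            cls, conf = clip_classification(img_trans, ["man", "woman"])
            metrics["gender_class"] = cls
            metrics["gender_conf"] = float(conf)
            # Classify the reference image too, so the report can detect a flip
            cls_ref, _ = clip_classification(img_ref, ["man", "woman"])
            metrics["gender_original"] = cls_ref
            mask_bin, mask_vis = get_mask(img_trans, "person")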

        # --- CREATE VISUAL REPORT IMAGE ---
        res_img = Image.new('RGB', (1536, 562), (20, 20, 20))
        res_img.paste(img_ref, (0, 0))
        res_img.paste(img_trans, (512, 0))
        if mask_vis is not None:
            mask_c = cv2.applyColorMap((mask_vis * 255).astype(np.uint8), cv2.COLORMAP_JET)
            res_img.paste(Image.fromarray(cv2.cvtColor(mask_c, cv2.COLOR_BGR2RGB)).resize((512,512)), (1024, 0))
        
        draw = ImageDraw.Draw(res_img)
        try: font = ImageFont.truetype("arial.ttf", 40)
        except OSError: font = ImageFont.load_default()  # font file not available on this system
        draw.text((20, 520), f"{test_id} | {entry['seed']}", fill="white", font=font)
        for i, line in enumerate(visual_text): draw.text((1040, 20 + i*50), line, fill="white", font=font)
        
        res_img.save(os.path.join(ANALYSIS_DIR, f"ANALYSIS_{test_id}_{entry['seed']}.jpg"))
        entry['analysis_metrics'] = metrics
        analysis_report.append(entry)

    except Exception as e: print(f"\nError: {e}")

# Save JSON Report
with open(os.path.join(ANALYSIS_DIR, "final_analysis_report.json"), 'w') as f:
    json.dump(analysis_report, f, indent=4)

clear_output(wait=True)
print("\nAnalysis + LPIPS Complete.")


Analysis + LPIPS Complete.
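
The background SSIM computed in `calc_background_metrics` is a masked average: the per-pixel SSIM map is averaged only over pixels outside the subject mask. The following pure-Python sketch illustrates that idea on a toy 2x3 grid. The values and the helper name `masked_mean` are hypothetical, and it guards against an empty mask with an explicit check rather than the small epsilon used in the cell above:

```python
def masked_mean(score_map, mask):
    """Average score_map over the pixels where mask == 1."""
    total, count = 0.0, 0
    for row_scores, row_mask in zip(score_map, mask):
        for score, keep in zip(row_scores, row_mask):
            if keep:
                total += score
                count += 1
    return total / count if count else 0.0

# Toy per-pixel SSIM map and subject mask (1 = subject, 0 = background)
ssim_map = [[1.0, 0.9, 0.2],
            [0.8, 0.1, 0.3]]
subject_mask = [[0, 0, 1],
                [0, 1, 1]]

# Invert the subject mask to select the background
bg_mask = [[1 - m for m in row] for row in subject_mask]
print(f"Background SSIM: {masked_mean(ssim_map, bg_mask):.2f}")  # prints 0.90
```

The low values inside the subject region do not affect the result, which is why the metric isolates background changes from the intended object swap.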


### Visualizing the Bias Landscape

This cell generates the final visual report for the Stress Test. It translates the raw JSON metrics into three high-level charts, designed to offer an immediate understanding of the model's vulnerabilities.

#### The Visual Report Structure:

1.  **Contextual Bias (Dual Chart)**:
    * **Left (SSIM)**: Shows structural preservation. *Ideally High*.
    * **Right (LPIPS)**: Shows perceptual distance. *Ideally Low*.
    * *Insight:* Allows us to see if the background was successfully swapped (Low SSIM) or if the model refused the edit (High SSIM).

2.  **Attribute Entanglement (Dual Chart)**:
    * **Left (Texture)**: CLIP Confidence for "Striped" vs "Smooth". High bars indicate texture leakage.
    * **Right (Color)**: Red Channel Intensity. High bars for "Goldfish $\to$ Shark" indicate color leakage.

3.  **Societal Bias (Stereotype Analysis)**:
    * **Top (Average Gender)**: Shows the model's default assumption for a profession (Pink Zone = Female, Blue Zone = Male).
    * **Bottom (Flip Rate)**: Shows the percentage of runs in which the predicted gender differs from the original. A 100% bar means the gender flipped on every tested seed, which points to a **systematic** bias rather than seed-level noise.
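
The "Average Gender" chart relies on converting each two-class CLIP result into a single probability of the subject being female, then averaging per scenario. A minimal sketch of that conversion, with hypothetical scenario IDs and confidences:

```python
from collections import defaultdict

def female_probability(label, confidence):
    """Map a two-class result (label, confidence) to P(woman)."""
    return confidence if label == "woman" else 1.0 - confidence

# Hypothetical (scenario id, predicted label, confidence) triples
results = [
    ("doctor_to_nurse", "woman", 0.90),
    ("doctor_to_nurse", "man", 0.60),
    ("nurse_to_doctor", "man", 0.80),
]

# Group the per-run probabilities by scenario and average them
by_id = defaultdict(list)
for scenario_id, label, conf in results:
    by_id[scenario_id].append(female_probability(label, conf))

avg_female_prob = {sid: sum(vals) / len(vals) for sid, vals in by_id.items()}
print(avg_female_prob)
```

When matching labels by substring, the test must look for "woman" rather than "man", because "man" is itself a substring of "woman".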

In [5]:
ANALYSIS_DIR = "results/analysis_results"
REPORT_FILE = os.path.join(ANALYSIS_DIR, "final_analysis_report.json")

# Set Visual Style
sns.set_theme(style="whitegrid", context="talk")

def generate_visual_report():
    print("Generating Final Visual Report (Compact Groups)...")
    
    if not os.path.exists(REPORT_FILE):
        print(f"Error: {REPORT_FILE} not found.")
        return

    with open(REPORT_FILE, 'r') as f: data = json.load(f)
    
    # --- DATA PREPARATION ---
    rows = []
    for entry in data:
        row = {
            "id": entry['test_id'], 
            "type": entry['type'], 
            "seed": entry['seed'],
            "target": entry['target_concept']
        }
        metrics = entry.get('analysis_metrics', {})
        
        # Flatten metrics
        for k, v in metrics.items():
            if isinstance(v, list): 
                row[f"{k}_R"], row[f"{k}_G"], row[f"{k}_B"] = v
            else: row[k] = v
            
        # Logic for Gender Flip Rate
        if entry['type'] == 'gender_bias':
            conf = metrics.get('gender_conf', 0.5)
            cls = metrics.get('gender_class', '')
            row['female_prob'] = conf if "woman" in cls else 1.0 - conf
            
            gen_ref = metrics.get('gender_original', 'unknown')
            gen_trans = metrics.get('gender_class', 'unknown')
            ref_label = "woman" if "woman" in gen_ref else "man"
            trans_label = "woman" if "woman" in gen_trans else "man"
            row['has_flipped'] = 1 if (ref_label != trans_label) else 0

        rows.append(row)
    
    df = pd.DataFrame(rows)

    df = df.sort_values(['id', 'seed'])
    df['Variation'] = df.groupby('id').cumcount() + 1
    df['Variation'] = df['Variation'].astype(str) 

    # --- CHART 1: CONTEXT BIAS ---
    df_ctx = df[df['type'] == 'context_bias']
    if not df_ctx.empty:
        fig, axes = plt.subplots(1, 2, figsize=(16, 7))
        
        # SSIM
        sns.barplot(data=df_ctx, x="id", y="background_ssim", hue="Variation", ax=axes[0], palette="viridis")
        axes[0].set_title("Structural Similarity (SSIM)", fontsize=16, fontweight='bold')
        axes[0].set_ylabel("Similarity (1.0 = Identical)")
        axes[0].set_xlabel("")
        axes[0].set_ylim(0, 1.0)
        axes[0].axhline(0.5, color='r', linestyle='--', alpha=0.3)
        axes[0].legend(title="Run #", loc='lower right', fontsize=10)

        # LPIPS
        sns.barplot(data=df_ctx, x="id", y="background_lpips", hue="Variation", ax=axes[1], palette="magma_r")
        axes[1].set_title("Perceptual Distance (LPIPS)", fontsize=16, fontweight='bold')
        axes[1].set_ylabel("Distance (Lower is Better)")
        axes[1].set_xlabel("")
        axes[1].set_ylim(0, 0.8)
        axes[1].legend(title="Run #", loc='lower right', fontsize=10)
        
        plt.suptitle("Context Bias Analysis", fontsize=18)
        plt.tight_layout()
        plt.savefig(os.path.join(ANALYSIS_DIR, "CHART_1_Context_Bias_Dual.png"), bbox_inches='tight')
        plt.close()

    # --- CHART 2: ATTRIBUTE LEAKAGE ---
    df_attr = df[df['type'] == 'attribute_leakage']
    if not df_attr.empty:
        fig, axes = plt.subplots(1, 2, figsize=(16, 7))
        
        # Texture (Zebra) - Using Seed directly on X for detail since it's one scenario
        df_tex = df_attr[df_attr['id'].str.contains("zebra")]
        if not df_tex.empty:
            sns.barplot(data=df_tex, x="seed", y="texture_conf", hue="texture_class", ax=axes[0], palette="magma")
            axes[0].set_title("Texture Entanglement (Zebra)", fontsize=16)
            axes[0].set_ylabel("Confidence (0.0 - 1.0)")
            axes[0].set_ylim(0, 1.0)
        
        # Color (Fish/Flamingo) - Grouped by Variation
        df_col = df_attr[~df_attr['id'].str.contains("zebra")]
        if not df_col.empty:
            sns.barplot(data=df_col, x="id", y="avg_color_rgb_R", hue="Variation", ax=axes[1], palette="Reds")
            axes[1].set_title("Color Leakage (Red Intensity)", fontsize=16)
            axes[1].set_ylabel("Pixel Intensity (0 - 255)")
            axes[1].set_ylim(0, 255)
            axes[1].legend(title="Run #")
            
        plt.tight_layout()
        plt.savefig(os.path.join(ANALYSIS_DIR, "CHART_2_Attribute_Leakage.png"), bbox_inches='tight')
        plt.close()

    # --- CHART 3: SOCIETAL BIAS ---
    df_soc = df[df['type'] == 'gender_bias']
    if not df_soc.empty:
        fig, axes = plt.subplots(1, 2, figsize=(18, 8))

        # Avg Probability 
        avg_probs = df_soc.groupby('id')['female_prob'].mean().sort_values()
        custom_palette = sns.color_palette("coolwarm", as_cmap=True)(avg_probs.values)

        sns.barplot(
            data=df_soc, x="id", y="female_prob", 
            order=list(avg_probs.index),  # keep bar order aligned with the sorted palette
            estimator=np.mean, ci=None, 
            ax=axes[0], palette=custom_palette
        )
        
        axes[0].set_title("A. Average Predicted Gender", fontsize=16, fontweight='bold')
        axes[0].set_ylabel("Prob. of being Female (0-1)", fontsize=14)
        axes[0].set_xlabel("")
        axes[0].set_ylim(0, 1.0)
        axes[0].axhline(0.5, color='gray', linestyle='--', linewidth=2)
        
        axes[0].axhspan(0.5, 1.0, color='pink', alpha=0.1, label="Female Zone")
        axes[0].axhspan(0.0, 0.5, color='lightblue', alpha=0.1, label="Male Zone")
        axes[0].text(0.5, 0.9, "Female Prediction", ha='center', color='purple', transform=axes[0].transAxes)
        axes[0].text(0.5, 0.1, "Male Prediction", ha='center', color='navy', transform=axes[0].transAxes)

        # Flip Rate 
        sns.barplot(
            data=df_soc, x="id", y="has_flipped", 
            estimator=lambda x: np.mean(x)*100, 
            ci=None, 
            ax=axes[1], palette="Reds"
        )
        
        axes[1].set_title("B. Frequency of Gender Flip", fontsize=16, fontweight='bold')
        axes[1].set_ylabel("Flip Rate (%)", fontsize=14)
        axes[1].set_xlabel("")
        axes[1].set_ylim(0, 100)
        
        for container in axes[1].containers:
            axes[1].bar_label(container, fmt='%.0f%%', padding=5, fontsize=12, fontweight='bold')

        plt.suptitle("Societal Bias Analysis: Stereotype & Consistency", fontsize=20)
        plt.tight_layout()
        plt.subplots_adjust(top=0.88) 
        
        plt.savefig(os.path.join(ANALYSIS_DIR, "CHART_3_Societal_Bias_Combined.png"), bbox_inches='tight')
        plt.close()

    clear_output(wait=True)
    print("Final Report Generated.")

generate_visual_report()

Final Report Generated.


# 5. Investigation IV: The "Surgery Autopilot" (Automation)

## Part A: Automated Training Data Generation

This section executes an exhaustive **Grid Search** to automatically generate the Ground Truth data required to train the Autopilot Regression Model.

### Objective
To build a dataset mapping arbitrary semantic shifts (Text Prompts) to their ideal surgical parameters ($\lambda, \alpha$).

### Methodology: Multi-Objective Optimization
For each of the **52 diverse scenarios** (covering Animals, Vehicles, Objects, and Contexts), the system performs a grid search across a pre-defined hyperparameter space to find the configuration that maximizes a weighted quality score.

* **Search Space**:
    * **Force ($\lambda$):** $\{0.8, 1.0, 1.2, 1.4\}$
    * **Sensitivity ($\alpha$):** $\{0.10, 0.15, 0.20\}$
* **The "Golden Score"**:
    $$\text{Score} = (W_{CLIP} \cdot \frac{\text{CLIP}}{30}) + (W_{SSIM} \cdot \text{SSIM})$$
    * Where $W_{CLIP}=0.6$ (prioritizing semantic change) and $W_{SSIM}=0.4$ (preserving structure).

The resulting best parameters for each scenario are saved to `results/final_training_dataset/dataset_log.csv` and will serve as the **Target Labels ($Y$)** for the machine learning models.
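
To make the scoring rule concrete, the sketch below evaluates the Golden Score for one hand-picked pair of metric values and then selects the best grid point. The function `toy_metrics` is a hypothetical stand-in for "generate an image and measure it"; its linear form is an assumption chosen only so the example runs without a diffusion model:

```python
from itertools import product

W_CLIP, W_SSIM = 0.6, 0.4  # weights stated in the text

def golden_score(clip_val, ssim_val):
    """Weighted sum of the rescaled CLIP score and the SSIM."""
    return W_CLIP * (clip_val / 30.0) + W_SSIM * ssim_val

def toy_metrics(force, sens):
    """Hypothetical metrics: stronger edits raise CLIP but lower SSIM."""
    clip_val = 18.0 + 8.0 * force
    ssim_val = 1.0 - 0.5 * force + sens
    return clip_val, ssim_val

forces = [0.8, 1.0, 1.2, 1.4]
sensitivities = [0.10, 0.15, 0.20]

# Worked example: CLIP = 27, SSIM = 0.70 -> 0.6 * 0.9 + 0.4 * 0.7 = 0.82
worked = golden_score(27.0, 0.70)
print(f"Worked example: {worked:.2f}")

# Grid search: keep the (force, sensitivity) pair with the highest score
best_config = max(product(forces, sensitivities),
                  key=lambda fs: golden_score(*toy_metrics(*fs)))
print(f"Best toy configuration: {best_config}")
```

Dividing by 30 rescales the CLIP score rather than bounding it, so a CLIP value above 30 contributes more than $W_{CLIP}$ to the total.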

In [6]:
dataset_dir = "results/final_training_dataset"
os.makedirs(dataset_dir, exist_ok=True)

# Grid Search Space
search_forces = [0.8, 1.0, 1.2, 1.4]
search_sensitivities = [0.10, 0.15, 0.20]

# Optimization Weights
W_CLIP = 0.6
W_SSIM = 0.4

# 52 Scenarios covering multiple domains
scenarios_to_test = [
    {"id": "Ani_Bear_Tiger", "prompt": "A brown bear walking in a forest", "remove": "brown bear", "inject": "tiger"}, 
    {"id": "Ani_Dog_Cat", "prompt": "A golden retriever sitting on a sofa", "remove": "golden retriever", "inject": "tabby cat"}, 
    {"id": "Ani_Fish_Shark", "prompt": "A goldfish swimming in a glass bowl", "remove": "goldfish", "inject": "shark"}, 
    {"id": "Ani_Horse_Zebra", "prompt": "A brown horse galloping in a field", "remove": "brown horse", "inject": "zebra"}, 
    {"id": "Ani_Bird_Parrot", "prompt": "A sparrow sitting on a tree branch", "remove": "sparrow", "inject": "colorful parrot"}, 
    {"id": "Veh_Car_Firetruck", "prompt": "A red sedan driving on a highway", "remove": "red sedan", "inject": "firetruck"}, 
    {"id": "Veh_Boat_Yacht", "prompt": "A wooden fishing boat in the ocean", "remove": "wooden fishing boat", "inject": "luxury yacht"}, 
    {"id": "Veh_Bike_Moto", "prompt": "A bicycle parked against a wall", "remove": "bicycle", "inject": "motorcycle"}, 
    {"id": "Veh_Bus_Train", "prompt": "A yellow school bus on the street", "remove": "yellow school bus", "inject": "green tram"}, 
    {"id": "Obj_Apple_Orange", "prompt": "A red apple on a wooden table", "remove": "red apple", "inject": "orange"}, 
    {"id": "Obj_Pizza_Cake", "prompt": "A pepperoni pizza on a plate", "remove": "pepperoni pizza", "inject": "chocolate cake"}, 
    {"id": "Obj_Coffee_Beer", "prompt": "A cup of hot coffee on a desk", "remove": "cup of hot coffee", "inject": "glass of beer"}, 
    {"id": "Obj_Bottle_Can", "prompt": "A glass wine bottle on a counter", "remove": "glass wine bottle", "inject": "soda can"}, 
    {"id": "Obj_Shoe_Boot", "prompt": "A running shoe on the floor", "remove": "running shoe", "inject": "leather boot"}, 
    {"id": "Ctx_Corgi_Snow", "prompt": "A corgi dog running on green grass", "remove": "green grass", "inject": "snowy field"}, 
    {"id": "Ani_Lion_Cat", "prompt": "A lion lying in the savanna", "remove": "lion", "inject": "house cat"}, 
    {"id": "Ani_Wolf_Dog", "prompt": "A grey wolf standing in the snow", "remove": "grey wolf", "inject": "husky dog"}, 
    {"id": "Ani_Duck_Swan", "prompt": "A duck swimming in a pond", "remove": "duck", "inject": "white swan"}, 
    {"id": "Veh_Truck_Van", "prompt": "A large truck driving on a road", "remove": "large truck", "inject": "delivery van"}, 
    {"id": "Veh_Scooter_Bike", "prompt": "A vespa scooter parked on the street", "remove": "vespa scooter", "inject": "bicycle"}, 
    {"id": "Obj_Laptop_Book", "prompt": "A laptop open on a wooden desk", "remove": "laptop", "inject": "open book"}, 
    {"id": "Obj_Candle_Lamp", "prompt": "A lit candle on a dark table", "remove": "lit candle", "inject": "table lamp"}, 
    {"id": "Obj_Burger_Sandwich", "prompt": "A cheeseburger on a fast food tray", "remove": "cheeseburger", "inject": "sandwich"}, 
    {"id": "Fash_Hat_Helmet", "prompt": "A baseball cap on a wooden shelf", "remove": "baseball cap", "inject": "bicycle helmet"}, 
    {"id": "Fash_Shoe_Sandal", "prompt": "A running shoe on the floor", "remove": "running shoe", "inject": "leather sandal"}, 
    {"id": "Fash_Shirt_Jacket", "prompt": "A folded t-shirt on a bed", "remove": "folded t-shirt", "inject": "denim jacket"}, 
    {"id": "Fash_Bag_Backpack", "prompt": "A leather handbag on a table", "remove": "leather handbag", "inject": "school backpack"}, 
    {"id": "Furn_Chair_Armchair", "prompt": "A wooden chair in an empty room", "remove": "wooden chair", "inject": "red armchair"}, 
    {"id": "Furn_Lamp_Plant", "prompt": "A floor lamp standing in a corner", "remove": "floor lamp", "inject": "potted plant"}, 
    {"id": "Furn_Clock_Painting", "prompt": "A round wall clock hanging on a wall", "remove": "round wall clock", "inject": "framed painting"}, 
    {"id": "Furn_Pillow_Cushion", "prompt": "A white pillow on a bed", "remove": "white pillow", "inject": "decorative cushion"}, 
    {"id": "Food_Banana_Cucumber", "prompt": "A yellow banana on a white plate", "remove": "yellow banana", "inject": "green cucumber"}, 
    {"id": "Food_Donut_Bagel", "prompt": "A pink donut with sprinkles", "remove": "pink donut", "inject": "plain bagel"}, 
    {"id": "Food_Mushroom_Flower", "prompt": "A red mushroom growing in grass", "remove": "red mushroom", "inject": "red tulip"}, 
    {"id": "Food_Egg_Ball", "prompt": "A white egg sitting on a table", "remove": "white egg", "inject": "ping pong ball"}, 
    {"id": "Food_Burger_Taco", "prompt": "A cheeseburger on a fast food tray", "remove": "cheeseburger", "inject": "mexican taco"}, 
    {"id": "Veh_Tractor_Tank", "prompt": "A green tractor in a field", "remove": "green tractor", "inject": "military tank"}, 
    {"id": "Veh_Helicopter_Drone", "prompt": "A helicopter flying in the sky", "remove": "helicopter", "inject": "quadcopter drone"}, 
    {"id": "Veh_Scooter_Skateboard", "prompt": "A scooter parked on pavement", "remove": "scooter", "inject": "skateboard"}, 
    {"id": "OOD_Fish_Bird", "prompt": "A goldfish swimming in a bowl", "remove": "goldfish", "inject": "small bird"}, 
    {"id": "OOD_Car_Boat", "prompt": "A car parked in a garage", "remove": "car", "inject": "small boat"}, 
    {"id": "OOD_Tree_Lamp", "prompt": "A large oak tree in a park", "remove": "large oak tree", "inject": "giant street lamp"}, 
    {"id": "Sport_Soccer_Basket", "prompt": "A soccer ball on a grass field", "remove": "soccer ball", "inject": "basketball"}, 
    {"id": "Sport_Tennis_Baseball", "prompt": "A tennis ball lying on the court", "remove": "tennis ball", "inject": "baseball"},
    {"id": "Toy_Teddy_Robot", "prompt": "A brown teddy bear sitting on a bed", "remove": "brown teddy bear", "inject": "toy robot"}, 
    {"id": "Music_Guitar_Electric", "prompt": "An acoustic guitar leaning against a wall", "remove": "acoustic guitar", "inject": "electric guitar"}, 
    {"id": "Music_Violin_Cello", "prompt": "A violin resting on a chair", "remove": "violin", "inject": "cello"}, 
    {"id": "Kit_Mug_Glass", "prompt": "A white ceramic mug on a table", "remove": "white ceramic mug", "inject": "transparent glass cup"}, 
    {"id": "Kit_Bowl_Pot", "prompt": "A wooden bowl on a kitchen counter", "remove": "wooden bowl", "inject": "cooking pot"}, 
    {"id": "Nat_Rock_Bush", "prompt": "A large gray rock in a garden", "remove": "large gray rock", "inject": "green bush"}, 
    {"id": "Nat_Mushroom_Stump", "prompt": "A red mushroom growing in the forest", "remove": "red mushroom", "inject": "tree stump"}, 
    {"id": "Acc_Wallet_Phone", "prompt": "A leather wallet on a table", "remove": "leather wallet", "inject": "smartphone"}
]

dataset_rows = []
csv_path = os.path.join(dataset_dir, "dataset_log.csv")

# Resume capability
if os.path.exists(csv_path):
    print("Partial dataset found, loading existing data...")
    dataset_rows = pd.read_csv(csv_path).to_dict('records')
    completed_ids = [row['Scenario_ID'] for row in dataset_rows]
else:
    completed_ids = []

print(f"Starting Dataset Generation: {len(scenarios_to_test)} Total Scenarios")
print(f"Completed until now: {len(completed_ids)}")

if 'surgeon' in locals() and 'evaluator' in locals():
    
    start_time = time.time()
    
    for idx, test in enumerate(scenarios_to_test):
        if test['id'] in completed_ids:
            continue 
            
        print(f"\n[{idx+1}/{len(scenarios_to_test)}] Processing: {test['id']}...")
        
        scen_dir = os.path.join(dataset_dir, test['id'])
        os.makedirs(scen_dir, exist_ok=True)
        
        # 1. Reference Image
        gen = torch.Generator("cpu").manual_seed(42)
        surgeon.concepts_to_erase = []
        img_ref = surgeon([test['prompt']], img_size=512, n_steps=30, n_imgs=1, 
                          show_alpha=False, generator=gen, replace_with=None)[0][0]
        img_ref.save(os.path.join(scen_dir, "reference.png"))
        
        # 2. Grid Search Loop
        best_run = {"score": -1, "config": None, "img": None}
        
        combinations = list(itertools.product(search_forces, search_sensitivities))
        
        for force, sens in combinations:
            surgeon.params['lambda'] = force
            surgeon.params['alpha_threshold'] = sens
            surgeon.concepts_to_erase = [test['remove']]
            
            gen.manual_seed(42)
            img_cand = surgeon([test['prompt']], img_size=512, n_steps=30, n_imgs=1, 
                               show_alpha=False, generator=gen, replace_with=test['inject'])[0][0]
            
            # Calculate Scores (named *_val to avoid shadowing the imported `ssim` function)
            clip_val = evaluator.get_clip_score_single(img_cand, f"a photo of a {test['inject']}")
            ssim_val = evaluator.get_ssim_score(img_ref, img_cand)

            # Optimization Formula
            score = (W_CLIP * (clip_val/30.0)) + (W_SSIM * ssim_val)
            
            if score > best_run["score"]:
                best_run["score"] = score
                best_run["config"] = {"F": force, "S": sens, "CLIP": clip_val, "SSIM": ssim_val}
                best_run["img"] = img_cand
        
        # 3. Save Best Result
        cfg = best_run["config"]
        print(f"   Best: F={cfg['F']}, S={cfg['S']} (Score: {best_run['score']:.3f})")
        
        best_name = f"best_F{cfg['F']}_S{cfg['S']}.png"
        best_run["img"].save(os.path.join(scen_dir, best_name))
        
        new_row = {
            "Scenario_ID": test['id'],
            "Original_Object": test['remove'],
            "Target_Object": test['inject'],
            "Prompt": test['prompt'],
            "Optimal_Force": cfg['F'],
            "Optimal_Sens": cfg['S'],
            "Score": best_run['score'],
            "CLIP_Val": cfg['CLIP'],
            "SSIM_Val": cfg['SSIM'],
            "Path_Ref": os.path.join(scen_dir, "reference.png"),
            "Path_Best": os.path.join(scen_dir, best_name)
        }
        dataset_rows.append(new_row)
        
        # Save progress incrementally
        pd.DataFrame(dataset_rows).to_csv(csv_path, index=False)
    
    elapsed = (time.time() - start_time) / 60

    clear_output(wait=True)
    print(f"\nDataset Successfully Generated in {elapsed:.1f} minutes!")
    print(f"Data saved in: {csv_path}")
    
    display(pd.DataFrame(dataset_rows)[['Scenario_ID', 'Optimal_Force', 'Optimal_Sens', 'Score']].head(10))

else:
    print("Error: Load 'surgeon' and 'evaluator' before running.")


Dataset Successfully Generated in 0.0 minutes!
Data saved in: results/final_training_dataset/dataset_log.csv


| | Scenario_ID | Optimal_Force | Optimal_Sens | Score |
|---|---|---|---|---|
| 0 | Ani_Bear_Tiger | 1.2 | 0.1 | 0.733057 |
| 1 | Ani_Dog_Cat | 1.0 | 0.15 | 0.810182 |
| 2 | Ani_Fish_Shark | 1.0 | 0.15 | 0.763661 |
| 3 | Ani_Horse_Zebra | 1.0 | 0.2 | 0.885776 |
| 4 | Ani_Bird_Parrot | 0.8 | 0.1 | 0.828637 |
| 5 | Veh_Car_Firetruck | 0.8 | 0.2 | 0.806777 |
| 6 | Veh_Boat_Yacht | 1.2 | 0.2 | 0.792884 |
| 7 | Veh_Bike_Moto | 0.8 | 0.1 | 0.811055 |
| 8 | Veh_Plane_Bird | 1.0 | 0.2 | 0.786461 |
| 9 | Veh_Bus_Train | 1.4 | 0.15 | 0.835648 |


## Part B: Training the "Surgery Autopilot" (Machine Learning)

This cell trains the baseline Multi-Output Regression model designed to automatically predict the optimal surgical parameters ($\lambda$ and $\alpha$) based on the semantic shift defined by the input text.

### Feature Extraction and Data Preparation
- **Feature Set ($\mathbf{X}$):** Textual concepts are processed by **CLIP (ViT-B/32)** to generate high-dimensional text embeddings (512-dim), serving as the input features. The input text is formatted as `{Prompt} -> change to {Target_Object}`.
- **Target Variables ($\mathbf{Y}$):** The targets are the empirically determined optimal values found in the previous step: $[\text{Optimal\_Force}, \text{Optimal\_Sens}]$.
- **Data Split:** The dataset is split into **80% Training** and **20% Test** sets (with a fixed `random_state=42`) for held-out evaluation.

### Model Specification
- **Algorithm:** We employ a **Random Forest Regressor** (wrapped in a `MultiOutputRegressor`) with 100 estimators. Random Forests are chosen as the baseline for their robustness to overfitting on small datasets and ability to capture non-linear relationships without extensive tuning.
- **Evaluation Metrics:**
    - **Mean Absolute Error (MAE)**: Provides the average deviation (error margin) in predicting the Force and Sensitivity values.

In [7]:
dataset_path = "results/final_training_dataset/clean_dataset_for_automation.csv"

# Output Directory
models_dir = "results/models"
os.makedirs(models_dir, exist_ok=True)
model_save_path = os.path.join(models_dir, "surgery_autopilot_model.pkl")

print("Start of 'Surgery Autopilot' training...")

if not os.path.exists(dataset_path):
    dataset_path = "results/final_training_dataset/dataset_log.csv"
    
df = pd.read_csv(dataset_path)
print(f"Loaded dataset: {len(df)} examples.")

print("Loading CLIP for Feature Extraction...")
model_id = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_id)
clip_model = CLIPModel.from_pretrained(model_id)

def extract_features(prompts, targets):
    inputs = processor(text=prompts, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        text_features = clip_model.get_text_features(**inputs)
    return text_features.numpy()

print("   > Extracting Embeddings from the prompts...")

combined_text = [f"{row['Prompt']} -> change to {row['Target_Object']}" for _, row in df.iterrows()]

X = extract_features(combined_text, None) 
y = df[['Optimal_Force', 'Optimal_Sens']].values 

# 80/20 Split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

print(f"   > Training Set: {len(X_train)} | Test Set: {len(X_test)}")

# Train Random Forest
regr = MultiOutputRegressor(RandomForestRegressor(n_estimators=100, random_state=42))
regr.fit(X_train, y_train)

clear_output(wait=True)

print("Model Trained!")

# Evaluate
y_pred = regr.predict(X_test)

mse = mean_squared_error(y_test, y_pred)
mae = mean_absolute_error(y_test, y_pred)

print(f"\nTest Set Results:")
print(f"   Mean Absolute Error (MAE): {mae:.4f}")
print(f"   (It means that it misses the Force/Sens by about +/- {mae:.4f})")

print("\nReal-World Comparison (Top 5 of the test set):")
print(f"{'Real (F, S)':<20} | {'Predicted (F, S)':<20} | {'Error':<10}")
print("-" * 60)
for i in range(min(5, len(y_test))):
    real = y_test[i]
    pred = y_pred[i]
    diff = np.abs(real - pred)
    print(f"[{real[0]:.2f}, {real[1]:.2f}]      | [{pred[0]:.2f}, {pred[1]:.2f}]      | {np.mean(diff):.3f}")

# Save Model
with open(model_save_path, 'wb') as f:
    pickle.dump(regr, f)
print(f"\nModel saved in: {model_save_path}")

Model Trained!

Test Set Results:
   Mean Absolute Error (MAE): 0.1043
   (It means that it misses the Force/Sens by about +/- 0.1043)

Real-World Comparison (Top 5 of the test set):
Real (F, S)          | Predicted (F, S)     | Error     
------------------------------------------------------------
[1.20, 0.10]      | [1.09, 0.14]      | 0.071
[0.80, 0.15]      | [1.04, 0.16]      | 0.122
[0.80, 0.20]      | [0.99, 0.15]      | 0.119
[1.20, 0.20]      | [1.00, 0.14]      | 0.126
[0.80, 0.10]      | [0.94, 0.15]      | 0.094

Model saved in: results/models/surgery_autopilot_model.pkl


## Part C: Advanced Automation with Deep Learning (SurgeryNet Training)

In this step, we train a custom Deep Learning model to challenge the baseline Machine Learning approach. We introduce **SurgeryNet**, a neural network designed to capture the non-linear relationships between semantic embeddings and surgical parameters.

### Architecture: SurgeryNet
We designed a lightweight **Multi-Layer Perceptron (MLP)** optimized for regression on high-dimensional vectors:
* **Input**: 512-dim CLIP vector.
* **Hidden Layers**: 3 fully connected layers ($256 \to 128 \to 64$), each followed by a **ReLU Activation** to model non-linearities. The first two layers (256 and 128 units) additionally use:
    * **Batch Normalization**: To stabilize learning.
    * **Dropout (0.2)**: To limit overfitting on the small dataset.
* **Output**: 2 continuous values (Force $\lambda$, Sensitivity $\alpha$).
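
As a rough size check, the layer widths listed above imply the number of trainable parameters computed below. This is a minimal sketch in plain Python, not part of the pipeline. It assumes, following the `SurgeryNet` class in the training cell, that batch normalization is attached to the first two hidden layers only. The helper names are illustrative.

```python
# Layer widths taken from the architecture described above.
LAYER_DIMS = [512, 256, 128, 64, 2]
# Assumption: batch norm follows only the 256- and 128-unit layers.
BATCHNORM_WIDTHS = [256, 128]


def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out


def batchnorm_params(width: int) -> int:
    """Learnable scale and shift of one batch-norm layer."""
    return 2 * width


def count_trainable_params(layer_dims, batchnorm_widths) -> int:
    """Sum the parameters of all linear and batch-norm layers."""
    total = sum(linear_params(a, b) for a, b in zip(layer_dims, layer_dims[1:]))
    total += sum(batchnorm_params(w) for w in batchnorm_widths)
    return total


total_params = count_trainable_params(LAYER_DIMS, BATCHNORM_WIDTHS)
print(f"Trainable parameters: {total_params:,}")
```

The first layer alone ($512 \to 256$) accounts for most of the parameters, which is one reason dropout is applied early in the stack.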

### Training Protocol
We train the network using the **Mean Squared Error (MSE)** loss function and the **Adam** optimizer (learning rate 0.001, batch size 8). Based on our previous optimization analysis, we fix the training duration at **600 epochs**, the empirically determined stopping point at which the model generalized best without memorizing noise.

The trained weights are saved to `results/models/surgery_net.pth` for the upcoming comparative evaluation.
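
The epoch count above is fixed in advance. If a validation split were available, the stopping point could instead be chosen automatically. The following is a minimal sketch of patience-based early stopping, not the procedure used in this notebook: the function name, the `patience` value, and the toy loss values are all hypothetical.

```python
def best_epoch_with_patience(val_losses, patience=50, min_delta=0.0):
    """Return (best_epoch, stop_epoch), both 1-indexed.

    Training would stop once `patience` consecutive epochs fail to improve
    the best validation loss by more than `min_delta`.
    """
    best_loss = float("inf")
    best_epoch = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss - min_delta:
            # New best: remember it and reset the patience window.
            best_loss = loss
            best_epoch = epoch
        elif epoch - best_epoch >= patience:
            # No improvement for `patience` epochs: stop here.
            return best_epoch, epoch
    return best_epoch, len(val_losses)


# Toy validation curve (made-up numbers): improves for 4 epochs, then degrades.
toy_val_losses = [0.50, 0.30, 0.20, 0.15, 0.16, 0.18, 0.17, 0.19]
best, stopped = best_epoch_with_patience(toy_val_losses, patience=3)
print(f"best epoch: {best}, stopped at epoch: {stopped}")
```

In a real run, the weights from `best_epoch` would be the ones kept, which requires saving a checkpoint whenever the validation loss improves.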

In [8]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# --- SURGERY NET ARCHITECTURE ---
class SurgeryNet(nn.Module):
    def __init__(self, input_dim=512, output_dim=2):
        super(SurgeryNet, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2),
            
            nn.Linear(128, 64),
            nn.ReLU(),
            
            nn.Linear(64, output_dim)
        )
    
    def forward(self, x):
        return self.model(x)

def train_dl_model(X_train, y_train, epochs=200, batch_size=8, lr=0.001):
    set_seed(42) 
    
    print("Training Deep Learning Model (SurgeryNet)...")
    
    X_t = torch.tensor(X_train, dtype=torch.float32)
    y_t = torch.tensor(y_train, dtype=torch.float32)
    
    dataset = TensorDataset(X_t, y_t)
    
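    # drop_last=True avoids a trailing batch of size 1, which BatchNorm1d cannot process in training mode
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)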
    
    model = SurgeryNet()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    loss_history = []
    
    model.train()
    for epoch in range(epochs):
        epoch_loss = 0
        batch_count = 0
        for batch_X, batch_y in loader:
            optimizer.zero_grad()
            outputs = model(batch_X)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            batch_count += 1
            
        if batch_count > 0:
            loss_history.append(epoch_loss / batch_count)
        
        if (epoch+1) % 50 == 0:
            print(f"   Epoch {epoch+1}/{epochs} | Loss: {loss_history[-1]:.4f}")
            
    return model, loss_history

if 'X_train' in locals() and 'X_test' in locals():
    
    # 1. Train
    dl_model, history = train_dl_model(X_train, y_train, epochs=600)
    
    # 2. Save Model
    models_dir = "results/models"
    os.makedirs(models_dir, exist_ok=True)
    save_path = os.path.join(models_dir, "surgery_net.pth")
    torch.save(dl_model.state_dict(), save_path)

    print(f"\nSurgeryNet Model saved to: {save_path}")

else:
    print("Error: Run the previous training cell to define X_train, X_test first.")

Training Deep Learning Model (SurgeryNet)...
   Epoch 50/600 | Loss: 0.0120
   Epoch 100/600 | Loss: 0.0075
   Epoch 150/600 | Loss: 0.0070
   Epoch 200/600 | Loss: 0.0046
   Epoch 250/600 | Loss: 0.0031
   Epoch 300/600 | Loss: 0.0035
   Epoch 350/600 | Loss: 0.0034
   Epoch 400/600 | Loss: 0.0033
   Epoch 450/600 | Loss: 0.0023
   Epoch 500/600 | Loss: 0.0027
   Epoch 550/600 | Loss: 0.0027
   Epoch 600/600 | Loss: 0.0025

SurgeryNet Model saved to: results/models/surgery_net.pth


## Part D: Model Selection & Comparative Analysis (The "Showdown")

This cell executes the final head-to-head comparison between the classical **Machine Learning** approach (Random Forest) and the proposed **Deep Learning** architecture (SurgeryNet).

### Objective
To empirically determine which model better captures the non-linear relationship between the semantic meaning of a prompt (CLIP Embedding) and the optimal surgical parameters required to modify it.

### Evaluation Metrics
Both models are evaluated on the held-out **Test Set** (20% of data) using:
1.  **Mean Absolute Error (MAE):** Measures the average "distance" between the predicted parameters and the ground truth. A lower MAE indicates higher precision.
2.  **Training Stability Analysis:** We plot the **Loss Curve** of the neural network to verify that training converges within 600 epochs. Because the curve tracks training loss only, overfitting is judged from the test-set MAE rather than from the plot itself.
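
Because Force and Sensitivity live on different scales, a single averaged MAE can hide which parameter is harder to predict. The sketch below splits the error per output, using only the five test rows printed in the Random Forest comparison table above. Those values are rounded to two decimals and cover a fraction of the test set, so the numbers are indicative only. The helper name `column_mae` is illustrative.

```python
# (real, predicted) pairs as (Force, Sensitivity), copied from the printed table.
rows = [
    ((1.20, 0.10), (1.09, 0.14)),
    ((0.80, 0.15), (1.04, 0.16)),
    ((0.80, 0.20), (0.99, 0.15)),
    ((1.20, 0.20), (1.00, 0.14)),
    ((0.80, 0.10), (0.94, 0.15)),
]


def column_mae(rows, col):
    """Mean absolute error of one output column (0 = Force, 1 = Sensitivity)."""
    return sum(abs(real[col] - pred[col]) for real, pred in rows) / len(rows)


force_mae = column_mae(rows, 0)
sens_mae = column_mae(rows, 1)
# Uniform average over the two outputs, as in a single combined MAE.
uniform_mae = (force_mae + sens_mae) / 2

print(f"Force MAE:       {force_mae:.3f}")
print(f"Sensitivity MAE: {sens_mae:.3f}")
print(f"Uniform average: {uniform_mae:.3f}")
```

On these five rows the Force error is several times larger than the Sensitivity error. For the full test set, the same breakdown is available from scikit-learn by passing `multioutput='raw_values'` to `mean_absolute_error`.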

### Visualization
* **Loss Dynamics:** A line chart showing the Neural Network's learning progress.
* **The Verdict:** A bar chart comparing the final MAE of both models to visually declare the winner of the automation challenge.

In [9]:
if 'X_train' in locals() and 'X_test' in locals():

    # 3. Evaluate
    dl_model.eval()
    with torch.no_grad():
        X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
        y_pred_dl = dl_model(X_test_tensor).numpy()
    
    mae_dl = mean_absolute_error(y_test, y_pred_dl)

    y_pred_rf = regr.predict(X_test)
    mae_rf = mean_absolute_error(y_test, y_pred_rf)
    
    print("\n" + "="*40)
    print("      FINAL SHOWDOWN: ML vs DL")
    print("="*40)
    print(f"Random Forest MAE: {mae_rf:.4f}")
    print(f"Deep Learning MAE: {mae_dl:.4f}")
    
    if mae_rf < mae_dl:
        winner = "Random Forest (ML)"
        diff = mae_dl - mae_rf
        reason = "Small dataset favors classical ML."
    else:
        winner = "Neural Network (DL)"
        diff = mae_rf - mae_dl
        reason = "The network generalized better."
        
    print(f"\nWINNER: {winner}")
    print(f"Difference: {diff:.4f}")
    print(f"Note: {reason}")
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    if history:
        axes[0].plot(history, label="Training Loss", color='purple')
        axes[0].set_title("DL Training Curve (SurgeryNet)")
        axes[0].set_xlabel("Epochs")
        axes[0].set_ylabel("MSE Loss")
        axes[0].grid(True, alpha=0.3)
    
    models_names = ['Random Forest (ML)', 'SurgeryNet (DL)']
    maes = [mae_rf, mae_dl]
    colors = ['#4e79a7', '#e15759'] 
    
    bars = axes[1].bar(models_names, maes, color=colors, alpha=0.8)
    axes[1].set_title(f"Performance Comparison (Lower MAE is Better)")
    axes[1].set_ylabel("Mean Absolute Error")
    axes[1].set_ylim(0, max(maes) * 1.3)
    
    for bar in bars:
        height = bar.get_height()
        axes[1].text(bar.get_x() + bar.get_width()/2., height,
                f'{height:.4f}', ha='center', va='bottom', fontsize=12, fontweight='bold')

    plt.tight_layout()
    plot_filename = "results/models/ML_vs_DL_Comparison.png"
    plt.savefig(plot_filename)
    plt.close() 
    
    print(f"\nComparison plot saved successfully to: {plot_filename}")

else:
    print("Error: Run the previous training cell to define X_train, X_test first.")


      FINAL SHOWDOWN: ML vs DL
Random Forest MAE: 0.1043
Deep Learning MAE: 0.0918

WINNER: Neural Network (DL)
Difference: 0.0125
Note: The network generalized better.

Comparison plot saved successfully to: results/models/ML_vs_DL_Comparison.png


## Part E: Automated Comparative Testing (ML vs. DL)

This cell integrates both trained models—the classical **Random Forest** (ML) and the neural **SurgeryNet** (DL)—into the live generation pipeline to perform a side-by-side comparative evaluation on unseen prompts.

### Workflow: The "Head-to-Head" Challenge
For every test scenario, the system executes two parallel inference tracks:

1.  **Track A (Classical ML)**:
    * **Input:** CLIP embedding (Numpy array).
    * **Model:** Random Forest Regressor.
    * **Output:** Predicted parameters $(\lambda_{ML}, \alpha_{ML})$.
    * **Result:** Image generated with ML parameters saved to `results/model_testing/ML/`.

2.  **Track B (Deep Learning)**:
    * **Input:** CLIP embedding (PyTorch Tensor).
    * **Model:** SurgeryNet (MLP).
    * **Output:** Predicted parameters $(\lambda_{DL}, \alpha_{DL})$.
    * **Result:** Image generated with DL parameters saved to `results/model_testing/DL/`.

### Real-World Test Cases
We test the models on a diverse set of conceptual swaps (semantic, contextual, object-level) to validate generalization:

| Domain | Swap Example | Type |
| :--- | :--- | :--- |
| **Context** | `sea` $\rightarrow$ `sand dunes` | Background Injection |
| **Object** | `dog` $\rightarrow$ `cat` | Object-to-Object |
| **Abstract** | `lightbulb` $\rightarrow$ `firefly` | Conceptual/Metaphorical |
| **Profession** | `doctor` $\rightarrow$ `nurse` | Gender/Role Bias Test |

The output images allow for a direct visual comparison: **which model understood the semantic intent better?**
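
Before generation, both tracks pass their raw predictions through the same post-processing step: each value is rounded to two decimals and clamped to the range explored during tuning (Force in $[0.6, 1.5]$, Sensitivity in $[0.05, 0.30]$, as in the cell below). A minimal standalone sketch of that step, with illustrative helper names:

```python
FORCE_RANGE = (0.6, 1.5)    # bounds for lambda (Force)
SENS_RANGE = (0.05, 0.30)   # bounds for alpha (Sensitivity)


def clamp(value, low, high):
    """Restrict value to the closed interval [low, high]."""
    return max(low, min(value, high))


def postprocess_prediction(raw_force, raw_sens):
    """Round to two decimals, then clamp each parameter to its valid range."""
    force = clamp(round(float(raw_force), 2), *FORCE_RANGE)
    sens = clamp(round(float(raw_sens), 2), *SENS_RANGE)
    return force, sens


print(postprocess_prediction(1.0934, 0.1421))  # in range: only rounded
print(postprocess_prediction(1.80, 0.01))      # out of range: clamped to the bounds
```

Clamping matters mostly for the neural network, whose linear output layer is unbounded. A Random Forest averages training targets and so cannot predict outside the range of values it has seen.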

In [10]:
BASE_TEST_DIR = "results/model_testing"
ML_DIR = os.path.join(BASE_TEST_DIR, "ML")
DL_DIR = os.path.join(BASE_TEST_DIR, "DL")

os.makedirs(ML_DIR, exist_ok=True)
os.makedirs(DL_DIR, exist_ok=True)

print(f"Output folders set:\n  -> {ML_DIR}/\n  -> {DL_DIR}/")

MODELS_DIR = "results/models"
ml_path = os.path.join(MODELS_DIR, "surgery_autopilot_model.pkl")
dl_path = os.path.join(MODELS_DIR, "surgery_net.pth")

# Re-define DL Architecture for loading
class SurgeryNet(nn.Module):
    def __init__(self, input_dim=512, output_dim=2):
        super(SurgeryNet, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, output_dim)
        )
    def forward(self, x): return self.model(x)

if os.path.exists(ml_path) and os.path.exists(dl_path) and 'surgeon' in locals():
    print("Loading Models...")
    
    # Load ML
    with open(ml_path, 'rb') as f:
        ml_model = pickle.load(f)
    
    # Load DL
    dl_model = SurgeryNet()
    dl_model.load_state_dict(torch.load(dl_path))
    dl_model.eval()
        
    # Load CLIP
    if 'processor' not in locals():
        model_id = "openai/clip-vit-base-patch32"
        processor = CLIPProcessor.from_pretrained(model_id)
        clip_model = CLIPModel.from_pretrained(model_id)
    
    print("Models Loaded Successfully.")

    def run_comparative_surgery(prompt, remove_obj, target_obj, seed=42):
        print(f"\n--- TEST: '{remove_obj}' -> '{target_obj}' ---")

        # Feature Extraction
        text_input = f"{prompt} -> change to {target_obj}"
        inputs = processor(text=[text_input], return_tensors="pt", padding=True, truncation=True)
        with torch.no_grad():
            emb_tensor = clip_model.get_text_features(**inputs) 
            emb_numpy = emb_tensor.numpy()                      
        
        # Define Tests
        tests = [
            ("ML", ml_model, emb_numpy, ML_DIR),
            ("DL", dl_model, emb_tensor, DL_DIR)
        ]
        
        for model_name, model, input_data, save_dir in tests:
            # Predict
            if model_name == "ML":
                pred = model.predict(input_data)[0]
            else: 
                with torch.no_grad():
                    pred = model(input_data).numpy()[0]
            
            # Post-Process Prediction
            pred_force = max(0.6, min(round(float(pred[0]), 2), 1.5))
            pred_sens = max(0.05, min(round(float(pred[1]), 2), 0.30))
            
            print(f"   [{model_name}] Suggestion: Force={pred_force}, Sens={pred_sens}")
            
            # Generate
            gen = torch.Generator("cpu").manual_seed(seed)
            
            # Original Reference
            surgeon.concepts_to_erase = []
            img_ref = surgeon([prompt], img_size=512, n_steps=30, n_imgs=1, 
                              show_alpha=False, generator=gen, replace_with=None)[0][0]

            # Surgery
            gen.manual_seed(seed) 
            surgeon.params['lambda'] = pred_force
            surgeon.params['alpha_threshold'] = pred_sens
            surgeon.concepts_to_erase = [remove_obj]
            
            img_res = surgeon([prompt], img_size=512, n_steps=30, n_imgs=1, 
                              show_alpha=False, generator=gen, replace_with=target_obj)[0][0]
            
            # Save Comparison
            fig, ax = plt.subplots(1, 2, figsize=(12, 6))
            ax[0].imshow(img_ref); ax[0].set_title("Original"); ax[0].axis('off')
            ax[1].imshow(img_res); ax[1].set_title(f"{model_name} Result\n(F={pred_force}, S={pred_sens})"); ax[1].axis('off')
            
            clean_prompt = prompt.replace(" ", "_")[:30]
            clean_target = target_obj.replace(" ", "_")
            filename = f"{model_name}_{clean_prompt}_TO_{clean_target}.png"
            save_path = os.path.join(save_dir, filename)
            
            plt.tight_layout()
            plt.savefig(save_path, bbox_inches='tight')
            plt.close(fig)

    # Run Scenarios 
    test_cases = [
        ("A red boat in the sea", "sea", "sand dunes"),
        ("A goldfish in the ocean", "ocean", "forest"),
        ("A lamp on the table", "lamp", "plant"),
        ("A red mushroom in the green grass", "mushroom", "rose"),
        ("A sofa in a living room", "sofa", "armchair"),
        ("A doctor at work", "doctor", "nurse"),
        ("A lightbulb glowing in the dark", "lightbulb", "firefly"),
        ("A dog on the sofa", "dog", "cat")
    ]

    for p, rem, inj in test_cases:
        run_comparative_surgery(p, rem, inj)

    clear_output(wait=True)
    print(f"\nAll tests completed!")
    print(f"ML Results: {ML_DIR}/")
    print(f"DL Results: {DL_DIR}/")

else:
    print("Error: Models not found in 'results/models/' directory or Surgeon not loaded.")


All tests completed!
ML Results: results/model_testing/ML/
DL Results: results/model_testing/DL/


# 6. Live Demonstration

This final cell launches a **Gradio Web Interface** that provides an interactive playground for the *Semantic Surgery* framework. It serves as a hands-on validation tool, letting users compare manual control with automated prediction in real time.

### The Interface Structure
The demo is divided into three distinct tabs, representing the evolution of the project:

1.  **Tab 1: Manual Control**
    * Allows the user to manually adjust **Force** and **Sensitivity** sliders.
    * *Purpose:* To understand the mechanics "under the hood" and experience the difficulty of finding the right parameters by hand.

2.  **Tab 2: ML Autopilot (Random Forest)**
    * The user inputs a semantic shift (e.g., "Apple" $\to$ "Orange").
    * The system extracts CLIP embeddings and queries the trained **Random Forest** to predict the parameters.
    * *Purpose:* To test the baseline automation performance.

3.  **Tab 3: DL Autopilot (Neural Network)**
    * The user interacts with the advanced **SurgeryNet**.
    * *Purpose:* To demonstrate the Deep Learning model trained in the previous step, which reached a lower test MAE than the Random Forest (0.0918 vs. 0.1043).

### Technical Note
For every request, the system generates two images side-by-side (Original vs. Modified) using **synchronized random seeds**. Both generations therefore start from the same initial noise, so differences between the two images can be attributed to the semantic vector injection rather than to sampling randomness, giving immediate visual evidence of the surgery's precision.
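
The role of the synchronized seeds can be illustrated without running the diffusion model. The sketch below uses Python's `random` module as a stand-in for `torch.Generator`; `initial_noise` is a hypothetical helper, not part of the pipeline.

```python
import random


def initial_noise(seed, size=8):
    """Stand-in for latent noise sampling: Gaussian draws from a seeded generator."""
    rng = random.Random(seed)
    return [rng.gauss(0.0, 1.0) for _ in range(size)]


noise_original = initial_noise(42)   # reference generation
noise_modified = initial_noise(42)   # surgery generation, same seed
noise_other = initial_noise(43)      # a different seed, for contrast

print(noise_original == noise_modified)  # True: same seed, same starting noise
print(noise_original == noise_other)     # False: different seed, different noise
```

With the starting noise held fixed, the only input that changes between the two generations is the text conditioning modified by the surgery.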

In [11]:
# --- 1. SETUP AND LOADING MODELS ---
MODELS_DIR = "results/models"
ML_PATH = os.path.join(MODELS_DIR, "surgery_autopilot_model.pkl")
DL_PATH = os.path.join(MODELS_DIR, "surgery_net.pth")

print("Initializing Demo...")

# A. Definition of DL Class (Required for loading weights)
class SurgeryNet(nn.Module):
    def __init__(self, input_dim=512, output_dim=2):
        super(SurgeryNet, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(256, 128), nn.BatchNorm1d(128), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, output_dim)
        )
    def forward(self, x): return self.model(x)

# B. Loading Models (Error Handling)
models_ready = True
try:
    with open(ML_PATH, 'rb') as f:
        ml_model = pickle.load(f)
    
    # Load DL (Neural Network)
    dl_model = SurgeryNet()
    dl_model.load_state_dict(torch.load(DL_PATH))
    dl_model.eval() 
    
    # Load CLIP
    if 'clip_model' not in globals():
        from transformers import CLIPProcessor, CLIPModel
        processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        
    print("All Models Loaded Successfully (ML & DL).")
    
except Exception as e:
    print(f"Warning: Could not load models. Ensure training cells are run. Error: {e}")
    models_ready = False

def get_clip_features(text):
    inputs = processor(text=[text], return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        emb = clip_model.get_text_features(**inputs)
    return emb

def core_generation(prompt, remove, inject, force, sens, seed):
    if 'surgeon' not in globals(): return None, None
    
    # 1. Original (Reference)
    gen_orig = torch.Generator("cpu").manual_seed(int(seed))
    surgeon.concepts_to_erase = []
    img_orig = surgeon([prompt], img_size=512, n_steps=30, n_imgs=1, 
                       show_alpha=False, generator=gen_orig, replace_with=None)[0][0]
    
    surgeon.params['lambda'] = force
    surgeon.params['alpha_threshold'] = sens
    surgeon.concepts_to_erase = [remove]
    
    # 2. Modified (Surgery)
    gen_mod = torch.Generator("cpu").manual_seed(int(seed))
    img_mod = surgeon([prompt], img_size=512, n_steps=30, n_imgs=1, 
                      show_alpha=False, generator=gen_mod, replace_with=inject)[0][0]
    
    return img_orig, img_mod

# --- 2. LOGIC FOR EACH TAB ---
def manual_process(prompt, remove, inject, force, sens, seed):
    return core_generation(prompt, remove, inject, force, sens, seed)

def ml_process(prompt, remove, inject, seed):
    if not models_ready: return None, None, "Model Error"
    
    # Predict
    emb = get_clip_features(f"{prompt} -> change to {inject}").numpy()
    pred = ml_model.predict(emb)[0]
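    # Clamp to the same ranges used by the manual sliders and the comparative test
    f = max(0.6, min(float(pred[0]), 1.5))
    s = max(0.05, min(float(pred[1]), 0.30))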
    
    # Generate
    img_o, img_m = core_generation(prompt, remove, inject, f, s, seed)
    return img_o, img_m, f"🤖 Random Forest suggests: Force={f:.2f}, Sens={s:.2f}"

def dl_process(prompt, remove, inject, seed):
    if not models_ready: return None, None, "Model Error"
    
    # Predict
    emb = get_clip_features(f"{prompt} -> change to {inject}") # Tensor
    with torch.no_grad():
        pred = dl_model(emb).numpy()[0]
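    # Clamp to the same ranges used by the manual sliders and the comparative test
    f = max(0.6, min(float(pred[0]), 1.5))
    s = max(0.05, min(float(pred[1]), 0.30))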
    
    # Generate
    img_o, img_m = core_generation(prompt, remove, inject, f, s, seed)
    return img_o, img_m, f"🧠 Deep Learning suggests: Force={f:.2f}, Sens={s:.2f}"

# --- 3. GRADIO INTERFACE ---
with gr.Blocks() as demo:
    gr.Markdown("# 🏥 Semantic Surgery: The Ultimate Comparison", elem_id="title")
    gr.Markdown("Explore manual control vs. automated predictions using Machine Learning and Deep Learning.")

    with gr.Tabs():
        
        # --- TAB 1: MANUAL ---
        with gr.TabItem("🎛️ Manual Control"):
            with gr.Row():
                with gr.Column():
                    m_prompt = gr.Textbox(label="Base Prompt", value="a cat in a park")
                    m_remove = gr.Textbox(label="Remove", value="cat")
                    m_inject = gr.Textbox(label="Inject", value="bunny")
                    with gr.Group():
                        gr.Markdown("**Surgical Parameters**")
                        m_force = gr.Slider(0.6, 1.5, value=1.0, label="Force")
                        m_sens = gr.Slider(0.05, 0.3, value=0.15, label="Sensitivity")
                    m_seed = gr.Number(value=42, label="Seed")
                    m_btn = gr.Button("Run Manual Surgery", variant="primary")
                with gr.Column():
                    with gr.Row():
                        m_out1 = gr.Image(label="Original")
                        m_out2 = gr.Image(label="Modified")
            m_btn.click(manual_process, [m_prompt, m_remove, m_inject, m_force, m_sens, m_seed], [m_out1, m_out2])

        # --- TAB 2: ML AUTOMATION ---
        with gr.TabItem("🤖 ML Autopilot (Random Forest)"):
            gr.Markdown("Uses Classical ML to predict parameters based on CLIP embeddings.")
            with gr.Row():
                with gr.Column():
                    ml_prompt = gr.Textbox(label="Base Prompt", value="a cat in a park")
                    ml_remove = gr.Textbox(label="Remove", value="cat")
                    ml_inject = gr.Textbox(label="Inject", value="bunny")
                    ml_seed = gr.Number(value=42, label="Seed")
                    ml_btn = gr.Button("Ask Random Forest", variant="secondary")
                    ml_info = gr.Textbox(label="Model Output", interactive=False)
                with gr.Column():
                    with gr.Row():
                        ml_out1 = gr.Image(label="Original")
                        ml_out2 = gr.Image(label="Modified")
            ml_btn.click(ml_process, [ml_prompt, ml_remove, ml_inject, ml_seed], [ml_out1, ml_out2, ml_info])

        # --- TAB 3: DL AUTOMATION ---
        with gr.TabItem("🧠 DL Autopilot (Neural Net)"):
            gr.Markdown("Uses the **SurgeryNet (Deep Learning)** model trained for 600 epochs.")
            with gr.Row():
                with gr.Column():
                    dl_prompt = gr.Textbox(label="Base Prompt", value="a cat in a park")
                    dl_remove = gr.Textbox(label="Remove", value="cat")
                    dl_inject = gr.Textbox(label="Inject", value="bunny")
                    dl_seed = gr.Number(value=42, label="Seed")
                    dl_btn = gr.Button("Ask Neural Network", variant="primary")
                    dl_info = gr.Textbox(label="Model Output", interactive=False)
                with gr.Column():
                    with gr.Row():
                        dl_out1 = gr.Image(label="Original")
                        dl_out2 = gr.Image(label="Modified")
            dl_btn.click(dl_process, [dl_prompt, dl_remove, dl_inject, dl_seed], [dl_out1, dl_out2, dl_info])

# Launch
demo.launch(share=True)

Initializing Demo...
All Models Loaded Successfully (ML & DL).
* Running on local URL:  http://127.0.0.1:7860

