# In [None]:  (notebook cell marker left over from the export)
# ==============================================================================
# ライブラリのインポート
# ==============================================================================
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline 
import numpy as np
from PIL import Image, ImageDraw
from torchvision import datasets, transforms # datasets を追加
from tqdm.auto import tqdm
import os
import gc
import random
import pickle
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patheffects as pe
from collections import Counter

# ==============================================================================
# 1. パラメータ設定
# ==============================================================================
# Model identifiers handed to `from_pretrained` for the two pipelines.
STANDARD_MODEL_ID = "runwayml/stable-diffusion-v1-5"
INPAINT_MODEL_ID = "runwayml/stable-diffusion-inpainting"

# Directory that contains stl10_binary (used as the STL-10 dataset root).
DATA_ROOT = "./dataset"

# How many images are analysed and how many (noise, timestep) draws each gets.
NUM_IMAGES_TO_VISUALIZE, NUM_TIMESTEP_SAMPLES = 5, 100
# Inclusive bounds for the randomly drawn timestep.
NOISE_LEVEL_MIN, NOISE_LEVEL_MAX = 100, 900
# The image is analysed on an N x N grid of patches.
PATCH_GRID_SIZE = 5
SEED = 45

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# --- Fix the random seeds of every RNG the script draws from ---
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# ==============================================================================
# 2. モデルとデータのロード
# ==============================================================================
print(f"Loading STL-10 dataset (test split) from {DATA_ROOT}...")

# Only the dataset construction is guarded: it is the one step that touches the
# disk / network.  The message is printed for the notebook user and the original
# exception is re-raised unchanged (bare `raise` keeps the traceback intact).
try:
    stl10_dataset_pil = datasets.STL10(root=DATA_ROOT, split='test', download=True)
except Exception as e:
    print(f"エラー: STL-10データセットのロードに失敗しました。パス '{DATA_ROOT}' を確認してください。: {e}")
    raise
else:
    # Class names double as the prompt vocabulary; their count sizes the per-class arrays.
    class_names = stl10_dataset_pil.classes
    NUM_CLASSES = len(class_names)
    print(f"  STL-10 dataset loaded. Found {NUM_CLASSES} classes: {class_names}")


# ==============================================================================
# 3. 関数定義
# ==============================================================================

# --- 関数: モデルのロード ---
def load_model(model_id, pipeline_class):
    """Load a diffusers pipeline, move it to ``DEVICE`` and disable its progress bar.

    Half precision is requested only when ``DEVICE`` is a CUDA device.  On CPU the
    fp16 load itself succeeds, so the full-precision fallback would never trigger
    and the script would end up running a half-precision model on the CPU.

    Args:
        model_id: Identifier passed to ``pipeline_class.from_pretrained``.
        pipeline_class: Pipeline class exposing ``from_pretrained``.

    Returns:
        The loaded pipeline, already moved to ``DEVICE``.

    Raises:
        Exception: Whatever the full-precision ``from_pretrained`` / ``to`` call
            raises; only a failure of the fp16 attempt is swallowed (and reported).
    """
    print(f"Loading {pipeline_class.__name__}: {model_id}")
    pipe = None
    if str(DEVICE).startswith("cuda"):
        try:
            pipe = pipeline_class.from_pretrained(model_id, torch_dtype=torch.float16).to(DEVICE)
        except Exception as e:
            # Best effort: fall through to the full-precision load below.
            print(f"  Could not load fp16 model: {e}. Trying with full precision.")
    if pipe is None:
        pipe = pipeline_class.from_pretrained(model_id).to(DEVICE)
    pipe.set_progress_bar_config(disable=True)
    print(f"  Model loaded.")
    return pipe

# --- 関数: テキスト埋め込みの事前計算 ---
def precompute_embeddings(pipe, class_names):
    """Encode the empty prompt and one "a photo of a <class>" prompt per class.

    Args:
        pipe: Pipeline exposing ``tokenizer`` and ``text_encoder``.
        class_names: Iterable of class names inserted into the prompt template.

    Returns:
        A tensor on ``DEVICE`` whose first entry along dim 0 is the embedding of
        the empty prompt, followed by one entry per class in the given order.
    """
    print("Pre-computing text embeddings...")
    tok, encoder = pipe.tokenizer, pipe.text_encoder

    def _embed(prompts):
        # Pad every prompt to the tokenizer's maximum length and keep the
        # first element of the encoder output.
        tokens = tok(
            prompts,
            padding="max_length",
            max_length=tok.model_max_length,
            return_tensors="pt",
        )
        return encoder(tokens.input_ids.to(DEVICE))[0]

    with torch.no_grad():
        null_embedding = _embed("")
        class_embeddings = _embed([f"a photo of a {name}" for name in class_names])

    stacked = torch.cat([null_embedding, class_embeddings])
    return stacked.to(DEVICE)

# --- 関数: 標準モデルのパッチ分析 (analyze_standard_patch) ---
@torch.no_grad()
def analyze_standard_patch(pipe, init_latents, noise_timestep_list, text_embeds_batch, true_label):
    """Score every class prompt per patch with the standard UNet.

    For each ``(noise, timestep)`` pair the latents are noised, the UNet predicts
    the noise once per row of ``text_embeds_batch``, and the squared prediction
    error is pooled onto a ``PATCH_GRID_SIZE`` x ``PATCH_GRID_SIZE`` grid.  Row 0
    of ``text_embeds_batch`` is skipped when picking winners (the caller puts the
    empty-prompt embedding there); rows 1.. are the candidate classes.

    All statistics are normalised by the number of pairs actually processed, so
    the result stays correct when the list length differs from
    ``NUM_TIMESTEP_SAMPLES``.  An empty list yields all-zero maps.

    Args:
        pipe: Pipeline exposing ``unet`` and ``scheduler``.
        init_latents: Clean latents handed to ``scheduler.add_noise``.
        noise_timestep_list: Iterable of ``(noise_tensor, int_timestep)`` pairs.
        text_embeds_batch: Text embeddings, one row per prompt.
        true_label: Index of the correct class among rows 1.. of the batch.

    Returns:
        Tuple ``(accuracy_map, avg_pred_variance_map, most_frequent_winner_map,
        overall_accuracy)``: per-patch accuracy in percent, per-patch mean
        variance of the class predictions, per-patch most frequent winning class
        index, and the percentage of samples whose grid-averaged error was lowest
        for ``true_label``.
    """
    unet, scheduler = pipe.unet, pipe.scheduler
    grid = (PATCH_GRID_SIZE, PATCH_GRID_SIZE)
    batch_size = text_embeds_batch.shape[0]
    num_classes = batch_size - 1  # row 0 is not a class candidate

    patch_correct_counts = np.zeros(grid)
    patch_winner_counts = np.zeros((num_classes,) + grid)
    sum_of_pred_variances = np.zeros(grid)
    overall_correct_timesteps = 0
    num_samples = 0
    rows, cols = np.indices(grid)

    for noise, start_timestep in noise_timestep_list:
        num_samples += 1
        timestep = torch.tensor([start_timestep], device=DEVICE)
        noisy_latents = scheduler.add_noise(init_latents, noise, timestep)

        # One forward pass evaluates every prompt on the same noisy latents.
        latent_model_input = noisy_latents.repeat(batch_size, 1, 1, 1)
        timestep_input = timestep.repeat(batch_size)
        pred_noise_batch = unet(latent_model_input, timestep_input, encoder_hidden_states=text_embeds_batch).sample

        # Squared error against the true noise, averaged over dim 1 and pooled per patch.
        error_maps = (pred_noise_batch - noise) ** 2
        spatial_error_maps = error_maps.mean(dim=1)
        patch_errors = F.adaptive_avg_pool2d(spatial_error_maps, grid)

        # Disagreement between the class-conditional predictions, pooled per patch.
        class_pred_noise_batch = pred_noise_batch[1:]
        variance_of_preds = torch.var(class_pred_noise_batch, dim=0, unbiased=False)
        spatial_variance_map = variance_of_preds.mean(dim=0)
        patch_variance = F.adaptive_avg_pool2d(spatial_variance_map.unsqueeze(0), grid).squeeze(0)
        sum_of_pred_variances += patch_variance.cpu().numpy()

        class_patch_errors = patch_errors[1:].cpu().numpy()
        avg_errors_per_class = class_patch_errors.mean(axis=(1, 2))
        if np.argmin(avg_errors_per_class) == true_label:
            overall_correct_timesteps += 1

        patch_predictions = np.argmin(class_patch_errors, axis=0)
        patch_correct_counts += (patch_predictions == true_label)
        # Each (row, col) occurs exactly once, so a fancy-indexed += is safe here.
        patch_winner_counts[patch_predictions, rows, cols] += 1

    denominator = max(num_samples, 1)  # avoid 0/0 when no samples were given
    accuracy_map = (patch_correct_counts / denominator) * 100
    avg_pred_variance_map = sum_of_pred_variances / denominator
    most_frequent_winner_map = np.argmax(patch_winner_counts, axis=0)
    overall_accuracy = (overall_correct_timesteps / denominator) * 100

    return accuracy_map, avg_pred_variance_map, most_frequent_winner_map, overall_accuracy

# --- 関数: Inpaintingモデルのパッチ分析 (analyze_inpainting_patch) ---
@torch.no_grad()
def analyze_inpainting_patch(pipe, init_latents_common, noise_timestep_list, text_embeds_batch, true_label, image_pil_resized_for_analysis):
    """Score every class prompt per patch with the inpainting UNet.

    The image is divided into a ``PATCH_GRID_SIZE`` x ``PATCH_GRID_SIZE`` grid.
    For each cell, that cell alone is masked out, the masked image is VAE-encoded,
    and for every ``(noise, timestep)`` pair the UNet predicts the noise once per
    row of ``text_embeds_batch``.  Only the prediction error inside the mask is
    compared between classes.  Row 0 of ``text_embeds_batch`` is skipped (the
    caller puts the empty-prompt embedding there).

    Differences from a naive implementation worth knowing about:
      * accuracy is normalised by the number of pairs actually supplied rather
        than by ``NUM_TIMESTEP_SAMPLES``;
      * the mask rectangle ends one pixel before the next cell begins, because
        ``ImageDraw.rectangle`` includes both corner points;
      * the noisy latents do not depend on the patch, so they are built once.

    When the image side is not divisible by ``PATCH_GRID_SIZE`` the remainder
    pixels at the right/bottom edge are never masked.

    Args:
        pipe: Inpainting pipeline exposing ``vae``, ``unet`` and ``scheduler``.
        init_latents_common: Clean latents handed to ``scheduler.add_noise``.
        noise_timestep_list: Iterable of ``(noise_tensor, int_timestep)`` pairs.
        text_embeds_batch: Text embeddings, one row per prompt; also fixes the
            dtype used for image and mask tensors.
        true_label: Index of the correct class among rows 1.. of the batch.
        image_pil_resized_for_analysis: PIL image that the latents were encoded
            from; its size determines the mask size.

    Returns:
        Tuple ``(accuracy_map, avg_pred_variance_map, most_frequent_winner_map,
        -1.0)``.  The winner map holds ``-1`` in every cell when no samples were
        supplied.  The last element is a placeholder: no overall accuracy is
        computed for this model.
    """
    vae, unet, scheduler = pipe.vae, pipe.unet, pipe.scheduler
    batch_size = text_embeds_batch.shape[0]
    dtype = text_embeds_batch.dtype
    grid = (PATCH_GRID_SIZE, PATCH_GRID_SIZE)

    accuracy_map = np.zeros(grid)
    avg_pred_variance_map = np.zeros(grid)
    most_frequent_winner_map = np.zeros(grid, dtype=int)

    img_w, img_h = image_pil_resized_for_analysis.size
    patch_w, patch_h = img_w // PATCH_GRID_SIZE, img_h // PATCH_GRID_SIZE

    image_transform_inner = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
    mask_transform_inner = transforms.ToTensor()

    image_tensor_common = image_transform_inner(image_pil_resized_for_analysis).unsqueeze(0).to(DEVICE, dtype=dtype)

    # Noising is independent of the masked patch: do it once per sample instead
    # of once per (patch, sample).
    noised_samples = []
    for noise, start_timestep in noise_timestep_list:
        timestep = torch.tensor([start_timestep], device=DEVICE)
        noisy_latents = scheduler.add_noise(init_latents_common, noise, timestep)
        noised_samples.append((noise, noisy_latents, timestep.repeat(batch_size)))
    num_samples = len(noised_samples)

    desc_inner = f"Analyzing {PATCH_GRID_SIZE**2} patches (Inpainting)"
    for i in tqdm(range(PATCH_GRID_SIZE), desc=desc_inner, leave=False):
        for j in range(PATCH_GRID_SIZE):
            # Mask value 255 marks the single patch hidden from the model.
            mask_pil = Image.new("L", (img_w, img_h), 0)
            left, top = j * patch_w, i * patch_h
            # Both corners are drawn, hence the -1 to stop before the next patch.
            ImageDraw.Draw(mask_pil).rectangle(
                [(left, top), (left + patch_w - 1, top + patch_h - 1)], fill=255
            )

            current_mask_tensor = mask_transform_inner(mask_pil).unsqueeze(0).to(DEVICE, dtype=dtype)
            masked_image_tensor = image_tensor_common * (1 - current_mask_tensor)

            mask_latents = F.interpolate(current_mask_tensor, size=init_latents_common.shape[-2:])
            masked_image_latents = vae.encode(masked_image_tensor).latent_dist.sample() * 0.18215

            # Per-patch constants shared by every sample below.
            mask_input_ = torch.cat([mask_latents] * batch_size)
            masked_latents_input_ = torch.cat([masked_image_latents] * batch_size)
            mask_pixels = mask_latents[0, 0] > 0
            has_mask_pixels = bool(mask_pixels.any())

            patch_winners_list = []
            patch_pred_variances_list = []

            for noise, noisy_latents, timestep_input in noised_samples:
                latent_model_input_ = torch.cat([noisy_latents] * batch_size)
                # The UNet input is the noisy latents, the mask and the masked-image
                # latents concatenated along dim 1.
                latent_model_input = torch.cat([latent_model_input_, mask_input_, masked_latents_input_], dim=1)

                pred_noise_batch = unet(latent_model_input, timestep_input, encoder_hidden_states=text_embeds_batch).sample
                class_pred_noise_batch = pred_noise_batch[1:]

                # In-mask squared error for all classes at once; a single
                # device-to-host transfer replaces one .item() per class.
                masked_preds = class_pred_noise_batch * mask_latents
                masked_noise = noise * mask_latents
                per_class_errors = ((masked_preds - masked_noise) ** 2).mean(dim=(1, 2, 3))
                winner_idx = int(np.argmin(per_class_errors.float().cpu().numpy()))
                patch_winners_list.append(winner_idx)

                # Disagreement between class predictions, averaged over masked latent pixels.
                variance_of_preds = torch.var(masked_preds, dim=0, unbiased=False)
                spatial_variance_map = variance_of_preds.mean(dim=0)
                if has_mask_pixels:
                    patch_variance_value = spatial_variance_map[mask_pixels].mean().item()
                else:
                    patch_variance_value = 0.0
                patch_pred_variances_list.append(patch_variance_value)

            if num_samples:
                accuracy_map[i, j] = (patch_winners_list.count(true_label) / num_samples) * 100
                avg_pred_variance_map[i, j] = np.mean(patch_pred_variances_list)
                most_frequent_winner_map[i, j] = Counter(patch_winners_list).most_common(1)[0][0]
            else:
                # No samples: accuracy/variance stay 0 and the winner is flagged as missing.
                most_frequent_winner_map[i, j] = -1

    return accuracy_map, avg_pred_variance_map, most_frequent_winner_map, -1.0

# ==============================================================================
# 5. メイン実行ループ
# ==============================================================================
print(f"\nSelecting {NUM_IMAGES_TO_VISUALIZE} random images from STL-10...")
all_indices = [*range(len(stl10_dataset_pil))]
selected_indices = random.sample(all_indices, NUM_IMAGES_TO_VISUALIZE)

# Per-image results, keyed by dataset index.
results_standard, results_inpaint = {}, {}

# --- 5.1 Standard model ---
pipe_std = load_model(STANDARD_MODEL_ID, StableDiffusionPipeline)
text_embeds_std = precompute_embeddings(pipe_std, class_names)

vae_std, scheduler_std = pipe_std.vae, pipe_std.scheduler
sd_transform_for_latents = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)

print("\n--- Running Standard Model Analysis ---")
for image_index in tqdm(selected_indices, desc="Standard Model"):
    sample_image, true_label = stl10_dataset_pil[image_index]
    rgb_image = sample_image if sample_image.mode == 'RGB' else sample_image.convert('RGB')
    image_pil_resized_std = rgb_image.resize((512, 512))

    # 1. Encode the image into latents exactly once; the inpainting run reuses them.
    pixel_batch = sd_transform_for_latents(image_pil_resized_std).unsqueeze(0)
    pixel_batch = pixel_batch.to(DEVICE, dtype=text_embeds_std.dtype)
    with torch.no_grad():
        init_latents = vae_std.encode(pixel_batch).latent_dist.sample() * 0.18215

    # 2. Draw the (noise, timestep) pairs exactly once so both models see the same ones.
    #    Tuple elements are evaluated left to right: noise first, then the timestep.
    noise_timestep_list = [
        (torch.randn_like(init_latents), random.randint(NOISE_LEVEL_MIN, NOISE_LEVEL_MAX))
        for _ in range(NUM_TIMESTEP_SAMPLES)
    ]

    metrics = analyze_standard_patch(
        pipe_std, init_latents, noise_timestep_list, text_embeds_std, true_label
    )

    entry = dict(zip(('acc', 'var', 'win', 'overall'), metrics))
    # Keep the shared inputs (and the label) next to the metrics for the second pass.
    entry.update(
        image_pil=image_pil_resized_std,
        init_latents=init_latents,
        noise_list=noise_timestep_list,
        true_label=true_label,
    )
    results_standard[image_index] = entry

# Drop every reference to the standard pipeline before the next model is loaded.
del vae_std, scheduler_std, pipe_std, text_embeds_std
gc.collect()
torch.cuda.empty_cache()

# --- 5.2 Inpaintingモデルでの計算 ---
pipe_inpaint = load_model(INPAINT_MODEL_ID, StableDiffusionInpaintPipeline)
text_embeds_inpaint = precompute_embeddings(pipe_inpaint, class_names)

print("\n--- Running Inpainting Model Analysis ---")
for image_index in tqdm(selected_indices, desc="Inpainting Model"):
    shared = results_standard.get(image_index)
    if shared is None:
        print(f"Skipping inpaint for index {image_index} as standard results are missing.")
        continue

    # Reuse the latents, noise draws, label and image stored by the standard-model pass.
    inp_acc, inp_var, inp_win, _ = analyze_inpainting_patch(
        pipe_inpaint,
        shared['init_latents'],
        shared['noise_list'],
        text_embeds_inpaint,
        shared['true_label'],
        shared['image_pil'],
    )

    results_inpaint[image_index] = {'acc': inp_acc, 'var': inp_var, 'win': inp_win}

# Release the inpainting pipeline.
del pipe_inpaint, text_embeds_inpaint
gc.collect()
torch.cuda.empty_cache()

# ==============================================================================
# 6. 比較プロットの生成
# ==============================================================================
print("\nGenerating comparison plots...")


def _outlined_text_style(fontsize, stroke_width):
    """Return ``ax.text`` keyword arguments for centred bold text with a black outline."""
    return {
        'ha': "center", 'va': "center", 'fontsize': fontsize, 'fontweight': 'bold',
        'path_effects': [pe.Stroke(linewidth=stroke_width, foreground='black'), pe.Normal()],
    }


plot_text_settings = _outlined_text_style(12, 1.5)
plot_text_settings_small = _outlined_text_style(10, 2)


def _show_image_only(ax, image, title):
    """Show the plain image with a title and no axes."""
    ax.imshow(image)
    ax.set_title(title)
    ax.axis('off')


def _overlay_grid(fig, ax, image, grid_values, cmap, norm, colorbar_label):
    """Draw the image, a translucent nearest-neighbour upscaled grid on top, and a colour bar.

    Returns the colour bar so callers can customise its ticks.
    """
    ax.imshow(image)
    upscaled = Image.fromarray(grid_values.astype(np.float32)).resize((512, 512), Image.NEAREST)
    overlay = ax.imshow(np.array(upscaled), cmap=cmap, norm=norm, alpha=0.6)
    return fig.colorbar(overlay, ax=ax, label=colorbar_label, shrink=0.7)


def _write_cell_text(ax, row, col, text, color, style):
    """Write ``text`` at the centre of grid cell (row, col) of a 512-pixel-wide image."""
    ax.text((col + 0.5) * (512 / PATCH_GRID_SIZE), (row + 0.5) * (512 / PATCH_GRID_SIZE),
            text, color=color, **style)


def _accuracy_panel(fig, ax, image, acc_map, title, cmap, norm):
    """Heat map of per-patch accuracy with the percentage printed in every cell."""
    _overlay_grid(fig, ax, image, acc_map, cmap, norm, 'Patch Accuracy (%)')
    ax.set_title(title)
    for row, col in np.ndindex(PATCH_GRID_SIZE, PATCH_GRID_SIZE):
        val = acc_map[row, col]
        # Dark text on bright (high-accuracy) cells, light text elsewhere.
        _write_cell_text(ax, row, col, f"{val:.0f}%", "k" if val > 60 else "w", plot_text_settings)
    ax.axis('off')


def _variance_panel(fig, ax, image, var_map, title, cmap, norm):
    """Heat map of the per-patch prediction variance."""
    _overlay_grid(fig, ax, image, var_map, cmap, norm, 'Avg. Pred. Noise Variance')
    ax.set_title(title)
    ax.axis('off')


def _winner_panel(fig, ax, image, win_map, true_label, title, cmap, norm, handle_missing):
    """Heat map of the most frequent winning class with the class name in every cell.

    With ``handle_missing`` set, a cell value of -1 is rendered as a gray "ERR".
    """
    cbar = _overlay_grid(fig, ax, image, win_map, cmap, norm, 'Most Frequent Winner Class')
    cbar.set_ticks(np.arange(len(class_names)))
    cbar.set_ticklabels(class_names)
    ax.set_title(title)
    for row, col in np.ndindex(PATCH_GRID_SIZE, PATCH_GRID_SIZE):
        winner_idx = win_map[row, col]
        if handle_missing and winner_idx == -1:
            winner_name, text_color = "ERR", "gray"
        else:
            winner_name = class_names[winner_idx]
            # Green when the patch winner is the true class, red otherwise.
            text_color = "lime" if winner_idx == true_label else "red"
        _write_cell_text(ax, row, col, winner_name, text_color, plot_text_settings_small)
    ax.axis('off')


# Colour maps / norms that are identical for every figure.
cmap_var = plt.get_cmap('magma')
cmap_acc = plt.get_cmap('viridis')
norm_acc = mcolors.Normalize(vmin=0, vmax=100)
cmap_winner = plt.get_cmap('tab10', len(class_names))
norm_winner = mcolors.Normalize(vmin=-0.5, vmax=len(class_names) - 0.5)

for idx in selected_indices:
    if not (idx in results_standard and idx in results_inpaint):
        print(f"Skipping plot for index {idx} due to missing results.")
        continue

    res_std, res_inp = results_standard[idx], results_inpaint[idx]
    true_label = res_std['true_label']
    true_label_name = class_names[true_label]
    image_pil = res_std['image_pil']

    fig, axes = plt.subplots(2, 4, figsize=(32, 16))
    fig.suptitle(f"Comparison Analysis for '{true_label_name}' (Index: {idx}, Standard Overall: {res_std['overall']:.0f}%)", fontsize=20, y=0.97)

    # The variance colour scale of BOTH rows comes from this image's standard-model map;
    # a constant map gets a tiny symmetric margin so the range is not empty.
    var_lo, var_hi = res_std['var'].min(), res_std['var'].max()
    margin = 1e-6 if var_lo == var_hi else 0.0
    norm_var_local = mcolors.Normalize(vmin=var_lo - margin, vmax=var_hi + margin)

    # --- Row 1: standard model ---
    _show_image_only(axes[0, 0], image_pil, "Original Image (Standard Model Row)")
    _accuracy_panel(fig, axes[0, 1], image_pil, res_std['acc'],
                    "Std: Patch Accuracy (vs True)", cmap_acc, norm_acc)
    _variance_panel(fig, axes[0, 2], image_pil, res_std['var'],
                    "Std: Prediction Variance (Importance)", cmap_var, norm_var_local)
    _winner_panel(fig, axes[0, 3], image_pil, res_std['win'], true_label,
                  "Std: Most Frequent Winner", cmap_winner, norm_winner, False)

    # --- Row 2: inpainting model ---
    _show_image_only(axes[1, 0], image_pil, "Original Image (Inpainting Model Row)")
    _accuracy_panel(fig, axes[1, 1], image_pil, res_inp['acc'],
                    "Inpaint: Patch Accuracy (In-Mask)", cmap_acc, norm_acc)
    _variance_panel(fig, axes[1, 2], image_pil, res_inp['var'],
                    "Inpaint: Prediction Variance (In-Mask)", cmap_var, norm_var_local)
    _winner_panel(fig, axes[1, 3], image_pil, res_inp['win'], true_label,
                  "Inpaint: Most Frequent Winner (In-Mask)", cmap_winner, norm_winner, True)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

# --- 7. クリーンアップ ---
print("\nCleaning up final resources...")
# Module-level names that hold the large per-image results and dataset objects.
variables_to_delete = ['results_standard', 'results_inpaint', 'stl10_dataset_pil',
                       'selected_indices', 'sd_transform_for_latents']
for var_name in variables_to_delete:
    # Remove the global if it exists; pop() with a default replaces the
    # exec("del ...") round-trip and is a no-op for names that are already gone.
    globals().pop(var_name, None)
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print("--- Comparison analysis finished. ---")