In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import json
from pathlib import Path
from scipy.ndimage import gaussian_filter

# ===== CONFIGURATION =====
# Display resolution as (width, height) in pixels; the plotting code unpacks it in that order.
SCREEN_SIZE = (1024, 768)

# Where each dataset lives, relative to the current working directory.
#   json        - trial list for the dataset
#   images      - folder holding the stimulus image files
#   dual_images - True for datasets whose trials carry both a first and a second image
datasets = {
    "Training 1": dict(
        json="Training 1/training1.json",
        images="Training 1/training1_images",
        dual_images=False,
    ),
    "Training 2": dict(
        json="Training 2/training2.json",
        images="Training 2/training2_images",
        dual_images=True,
    ),
    "Testing": dict(
        json="Testing/testing.json",
        images="Testing/testing_images",
        dual_images=True,
    ),
}

# ===== LOAD DATA =====
def load_datasets(datasets_config):
    """Load each configured dataset's trial JSON together with its image metadata.

    Args:
        datasets_config: Mapping of dataset name -> {"json": path, "images": path,
            "dual_images": bool}.

    Returns:
        Dict keyed by dataset name, containing only datasets whose JSON file exists.
        Each value has ``data`` (the parsed JSON), ``image_folder`` (a Path) and
        ``dual_images``. A missing JSON file is reported and skipped, not raised.
        A missing image folder is reported, but the dataset is still returned.
    """
    all_data = {}
    for dataset_name, paths in datasets_config.items():
        json_path = Path(paths["json"])
        if not json_path.exists():
            print(f"⚠️ {dataset_name} not found at {json_path}")
            continue

        with open(json_path, "r", encoding="utf-8") as f:
            trials = json.load(f)

        image_folder = Path(paths["images"])
        all_data[dataset_name] = {
            "data": trials,
            "image_folder": image_folder,
            "dual_images": paths["dual_images"],
        }
        print(f"✅ Loaded {dataset_name}: {len(trials)} trials")
        # Warn early: otherwise the problem only shows up when a stimulus image
        # is opened during plotting.
        if not image_folder.is_dir():
            print(f"⚠️ {dataset_name}: image folder not found at {image_folder}")
    return all_data

# ===== PLOTTING FUNCTIONS =====
def plot_fixations_only(trials_list, image_folder, image_key, screen_size=(1024, 768), ax=None,
                        image_half_size=310):
    """Draw the stimulus image with every on-image fixation as a coral dot.

    Args:
        trials_list: Non-empty list of trial dicts. Each provides ``fix_x`` / ``fix_y``
            (screen px) and optionally ``fix_index`` and ``test_image_fixation_idx``.
            The image file name is read from ``trials_list[0][image_key]``.
        image_folder: Folder containing the stimulus image.
        image_key: ``"first_image"`` selects fixations before
            ``test_image_fixation_idx``. Any other key selects fixations from that
            index onward.
        screen_size: ``(width, height)`` of the screen in px.
        ax: Matplotlib axes to draw on; defaults to the current axes.
        image_half_size: Half the side (px) of the square, screen-centred region
            where the image is drawn. Fixations outside it are dropped.

    Raises:
        ValueError: if ``trials_list`` is empty or its first trial lacks ``image_key``.
    """
    if ax is None:
        ax = plt.gca()
    rect = _image_rect(screen_size, image_half_size)
    l, t, r, b = rect

    img = _load_stimulus(trials_list, image_folder, image_key)
    xs, ys = _collect_fixations(trials_list, image_key, rect)

    ax.imshow(img, extent=(l, r, b, t))
    if xs.size:
        ax.scatter(xs, ys, s=30, c='#FF6B6B',
                   alpha=0.5, edgecolor='white', linewidth=0.8)

    _finish_axes(ax, screen_size)


def plot_heatmap_only(trials_list, image_folder, image_key, screen_size=(1024, 768), ax=None,
                      image_half_size=310, sigma=15):
    """Draw the stimulus image overlaid with a smoothed fixation-count heatmap.

    The arguments are as for ``plot_fixations_only``. ``sigma`` is the Gaussian
    smoothing width, measured in histogram bins.

    Returns:
        The heatmap ``AxesImage`` (for a colorbar), or None if no fixation fell on
        the image.

    Raises:
        ValueError: if ``trials_list`` is empty or its first trial lacks ``image_key``.
    """
    if ax is None:
        ax = plt.gca()
    rect = _image_rect(screen_size, image_half_size)
    l, t, r, b = rect

    img = _load_stimulus(trials_list, image_folder, image_key)
    xs, ys = _collect_fixations(trials_list, image_key, rect)

    ax.imshow(img, extent=(l, r, b, t))

    im = None
    if xs.size:
        # 150 edges per axis over the image region.
        heatmap, xedges, yedges = np.histogram2d(
            xs, ys,
            bins=[np.linspace(l, r, 150), np.linspace(t, b, 150)]
        )
        heatmap = gaussian_filter(heatmap, sigma=sigma)

        # histogram2d puts x on axis 0, so transpose to get rows indexed by y.
        # With origin='upper' and the top extent set to yedges[0], row 0 is
        # drawn at y = t, matching the inverted screen y-axis.
        extent = [xedges[0], xedges[-1], yedges[-1], yedges[0]]
        im = ax.imshow(heatmap.T, extent=extent, origin='upper',
                       cmap='jet', alpha=0.6, interpolation='bilinear')

    _finish_axes(ax, screen_size)
    return im


def _image_rect(screen_size, image_half_size):
    """Return (left, top, right, bottom) of a square image centred on the screen."""
    screen_w, screen_h = screen_size
    cx, cy = screen_w // 2, screen_h // 2
    return (cx - image_half_size, cy - image_half_size,
            cx + image_half_size, cy + image_half_size)


def _load_stimulus(trials_list, image_folder, image_key):
    """Open the image named by ``trials_list[0][image_key]`` as an RGB image.

    Raises:
        ValueError: if ``trials_list`` is empty or its first trial lacks ``image_key``.
    """
    if not trials_list:
        raise ValueError("trials_list is empty; nothing to plot")
    img_name = trials_list[0].get(image_key)
    if img_name is None:
        raise ValueError(f"first trial has no {image_key!r} entry")
    # The with-block closes the file handle. convert() returns a loaded copy
    # that does not depend on the open file.
    with Image.open(Path(image_folder) / img_name) as src:
        return src.convert("RGB")


def _collect_fixations(trials_list, image_key, rect):
    """Pool the fixations on ``image_key``'s image that fall inside ``rect``.

    Returns:
        Two float arrays ``(xs, ys)``, concatenated across trials in trial order.
    """
    l, t, r, b = rect
    xs_parts, ys_parts = [], []
    for trial in trials_list:
        xs = np.asarray(trial["fix_x"], dtype=float)
        ys = np.asarray(trial["fix_y"], dtype=float)
        # Without an explicit fix_index, fixations are numbered 1..n.
        order = np.asarray(trial.get("fix_index", np.arange(1, len(xs) + 1)), dtype=float)

        split_idx = trial.get("test_image_fixation_idx")
        if split_idx is not None:
            # Fixations before the split index belong to the first image;
            # the rest belong to the second.
            on_image = order < split_idx if image_key == "first_image" else order >= split_idx
            xs, ys = xs[on_image], ys[on_image]

        inside = (xs >= l) & (xs <= r) & (ys >= t) & (ys <= b)
        xs_parts.append(xs[inside])
        ys_parts.append(ys[inside])

    if not xs_parts:
        return np.empty(0), np.empty(0)
    return np.concatenate(xs_parts), np.concatenate(ys_parts)


def _finish_axes(ax, screen_size):
    """Show the full screen in screen coordinates (y grows downward) and label the axes."""
    screen_w, screen_h = screen_size
    ax.set_xlim(0, screen_w)
    ax.set_ylim(screen_h, 0)
    ax.set_xlabel("x (screen px)")
    ax.set_ylabel("y (screen px)")


# ===== GENERATE EXAMPLES =====
def generate_examples(all_data):
    """Render one example FDM PDF per loaded dataset into ``fdm_examples/``.

    Each example uses the stimulus of the first usable trial, then gathers all
    trials showing that same stimulus:

    * Training 1 groups by ``first_image``.
    * Training 2 groups by the (first, second) image pair.
    * Testing groups by the pair, restricted to ``viewing_condition == "central"``.

    Datasets without a usable trial are reported and skipped instead of crashing.
    """
    output_dir = Path("fdm_examples")
    output_dir.mkdir(exist_ok=True)

    def central_pair_key(trial):
        return _pair_key(trial) if trial.get("viewing_condition") == "central" else None

    # (dataset, grouping key, renderer, figure title, file name, log label)
    examples = (
        ("Training 1", _stimulus_key, _render_single_image_figure,
         "Training 1 • All Subjects", "training1_fdm_example.pdf", "Training 1"),
        ("Training 2", _pair_key, _render_pair_figure,
         "Training 2 • All Subjects", "training2_fdm_example.pdf", "Training 2"),
        ("Testing", central_pair_key, _render_pair_figure,
         "Testing • Central Viewing • All Subjects", "testing_fdm_example.pdf",
         "Testing (central)"),
    )

    for name, key_fn, render, title, filename, label in examples:
        if name not in all_data:
            continue
        dataset = all_data[name]
        image_trials = _trials_like_first(dataset["data"], key_fn)
        if not image_trials:
            print(f"⚠️ {label}: no usable trials; skipping example")
            continue
        render(image_trials, dataset["image_folder"], title, output_dir / filename)
        print(f"✅ {label}: N={len(image_trials)} observations")

    print(f"\n📁 All examples saved to: {output_dir}/")


# ===== FULL BATCH GENERATION =====
def generate_all_fdms(all_data):
    """Generate one FDM PDF per stimulus (or stimulus combination), organised by dataset.

    Output goes to ``fdm_outputs/<dataset>_fdms/``. Files are numbered in sorted
    key order.
    """
    base_output_dir = Path("fdm_outputs")
    base_output_dir.mkdir(exist_ok=True)

    # Training 1 - one PDF per unique first_image
    if "Training 1" in all_data:
        print("\n📊 Generating Training 1 FDMs...")
        output_dir = base_output_dir / "training1_fdms"
        output_dir.mkdir(exist_ok=True)
        dataset = all_data["Training 1"]

        groups = _group_trials(dataset["data"], _stimulus_key)
        total = len(groups)
        print(f"   Found {total} unique images")

        for idx, image_trials in enumerate(groups.values(), 1):
            _render_single_image_figure(
                image_trials, dataset["image_folder"],
                f"Training 1 • All Subjects • Image {idx}/{total}",
                output_dir / f"training1_fdm_{idx:02d}.pdf",
            )
            if idx % 10 == 0:
                print(f"   Generated {idx}/{total}...")

        print(f"   ✅ Complete! Saved to: {output_dir}/")

    # Training 2 - one PDF per unique (first_image, second_image) pair
    if "Training 2" in all_data:
        print("\n📊 Generating Training 2 FDMs...")
        output_dir = base_output_dir / "training2_fdms"
        output_dir.mkdir(exist_ok=True)
        dataset = all_data["Training 2"]

        groups = _group_trials(dataset["data"], _pair_key)
        total = len(groups)
        print(f"   Found {total} unique image pairs")

        for idx, image_trials in enumerate(groups.values(), 1):
            _render_pair_figure(
                image_trials, dataset["image_folder"],
                f"Training 2 • All Subjects • Pair {idx}/{total}",
                output_dir / f"training2_fdm_{idx:02d}.pdf",
            )
            if idx % 10 == 0:
                print(f"   Generated {idx}/{total}...")

        print(f"   ✅ Complete! Saved to: {output_dir}/")

    # Testing - one PDF per (image pair, viewing condition) combination present in the data
    if "Testing" in all_data:
        print("\n📊 Generating Testing FDMs...")
        output_dir = base_output_dir / "testing_fdms"
        output_dir.mkdir(exist_ok=True)
        dataset = all_data["Testing"]

        groups = _group_trials(dataset["data"], _pair_condition_key)
        total = len(groups)
        print(f"   Found {total} unique image pairs × viewing conditions")

        for idx, ((_, _, viewing), image_trials) in enumerate(groups.items(), 1):
            _render_pair_figure(
                image_trials, dataset["image_folder"],
                f"Testing • {viewing.capitalize()} Viewing • All Subjects • {idx}/{total}",
                output_dir / f"testing_fdm_{idx:03d}_{viewing}.pdf",
            )
            if idx % 20 == 0:
                print(f"   Generated {idx}/{total}...")

        print(f"   ✅ Complete! Saved to: {output_dir}/")

    print("\n🎉 All FDMs generated!")
    print(f"📁 Output directory: {base_output_dir}/")


# ===== SHARED FIGURE / GROUPING HELPERS =====
def _stimulus_key(trial):
    """Grouping key for single-image trials: ``first_image``, or None if missing or empty."""
    return trial.get("first_image") or None


def _pair_key(trial):
    """Grouping key for two-image trials: ``(first_image, second_image)``, or None if either is missing."""
    first, second = trial.get("first_image"), trial.get("second_image")
    return (first, second) if first and second else None


def _pair_condition_key(trial):
    """Grouping key ``(first_image, second_image, viewing_condition)``, or None if any part is missing."""
    pair, viewing = _pair_key(trial), trial.get("viewing_condition")
    return (*pair, viewing) if pair and viewing else None


def _group_trials(data, key_fn):
    """Group trials by ``key_fn`` in a single pass, dropping trials whose key is None.

    Returns:
        Dict ordered by sorted key. Each value lists its trials in original order.
    """
    groups = {}
    for trial in data:
        key = key_fn(trial)
        if key is not None:
            groups.setdefault(key, []).append(trial)
    return {key: groups[key] for key in sorted(groups)}


def _trials_like_first(data, key_fn):
    """Return every trial that shares the key of the first trial with a non-None key ([] if none)."""
    target = next((key for key in map(key_fn, data) if key is not None), None)
    if target is None:
        return []
    return [trial for trial in data if key_fn(trial) == target]


def _add_suptitle(fig, title, image_trials):
    """Add the two-line figure title: ``title``, then paradigm and observation count."""
    fig.suptitle(f"{title}\n"
                 f"Paradigm: {image_trials[0].get('paradigm', 'N/A')} • "
                 f"N={len(image_trials)} observations",
                 fontsize=12, fontweight='bold', y=0.96)


def _render_single_image_figure(image_trials, image_folder, title, output_path):
    """Save a PDF with fixation points on top and the density heatmap below for one image."""
    fig = plt.figure(figsize=(10, 14), dpi=150)
    try:
        gs = fig.add_gridspec(2, 2, height_ratios=[1, 1], width_ratios=[1, 0.05],
                              hspace=0.35, wspace=0.05)

        ax_fix = fig.add_subplot(gs[0, 0])
        plot_fixations_only(image_trials, image_folder, "first_image",
                            screen_size=SCREEN_SIZE, ax=ax_fix)
        ax_fix.set_title("Fixation Points Only", fontsize=11, fontweight='bold')

        ax_heat = fig.add_subplot(gs[1, 0])
        im = plot_heatmap_only(image_trials, image_folder, "first_image",
                               screen_size=SCREEN_SIZE, ax=ax_heat)
        ax_heat.set_title("Density Heatmap Only", fontsize=11, fontweight='bold')

        if im is not None:
            cbar = fig.colorbar(im, cax=fig.add_subplot(gs[1, 1]))
            cbar.set_label('Density', fontsize=9)

        _add_suptitle(fig, title, image_trials)
        fig.savefig(output_path, bbox_inches='tight', dpi=150)
    finally:
        # Release the figure even if plotting fails (e.g. a missing image file).
        plt.close(fig)


def _render_pair_figure(image_trials, image_folder, title, output_path):
    """Save a 2x2 PDF (fixations / heatmap for each image) with one shared colorbar."""
    fig = plt.figure(figsize=(16, 14), dpi=150)
    try:
        gs = fig.add_gridspec(2, 3, height_ratios=[1, 1], width_ratios=[1, 1, 0.05],
                              hspace=0.35, wspace=0.3)
        panels = (("first_image", "First Image"), ("second_image", "Second Image"))

        for col, (image_key, label) in enumerate(panels):
            ax = fig.add_subplot(gs[0, col])
            plot_fixations_only(image_trials, image_folder, image_key,
                                screen_size=SCREEN_SIZE, ax=ax)
            ax.set_title(f"{label} - Fixation Points Only", fontsize=11, fontweight='bold')

        heatmaps = []
        for col, (image_key, label) in enumerate(panels):
            ax = fig.add_subplot(gs[1, col])
            im = plot_heatmap_only(image_trials, image_folder, image_key,
                                   screen_size=SCREEN_SIZE, ax=ax)
            ax.set_title(f"{label} - Density Heatmap", fontsize=11, fontweight='bold')
            if im is not None:
                heatmaps.append(im)

        if heatmaps:
            # Each imshow autoscales to its own data range. Put both heatmaps on
            # one colour range so the single colorbar is valid for both panels.
            vmin = min(im.get_clim()[0] for im in heatmaps)
            vmax = max(im.get_clim()[1] for im in heatmaps)
            for im in heatmaps:
                im.set_clim(vmin, vmax)
            cbar = fig.colorbar(heatmaps[0], cax=fig.add_subplot(gs[1, 2]))
            cbar.set_label('Fixation Density', fontsize=10)

        _add_suptitle(fig, title, image_trials)
        fig.savefig(output_path, bbox_inches='tight', dpi=150)
    finally:
        plt.close(fig)


# ===== RUN =====
# Load every configured dataset, then render the full batch of FDM PDFs.
print("Loading datasets...")
all_data = load_datasets(datasets)

# Banner: blank line, rule, heading, rule - emitted in a single print call.
print("", "=" * 50, "GENERATING ALL FDM PDFs", "=" * 50, sep="\n")
generate_all_fdms(all_data)

Loading datasets...
✅ Loaded Training 1: 1364 trials
✅ Loaded Training 2: 2320 trials
✅ Loaded Testing: 2304 trials

GENERATING ALL FDM PDFs

📊 Generating Training 1 FDMs...
   Found 40 unique images
   Generated 10/40...
   Generated 20/40...
   Generated 30/40...
   Generated 40/40...
   ✅ Complete! Saved to: fdm_outputs/training1_fdms/

📊 Generating Training 2 FDMs...
   Found 72 unique image pairs
   Generated 10/72...
   Generated 20/72...
   Generated 30/72...
   Generated 40/72...
   Generated 50/72...
   Generated 60/72...
   Generated 70/72...
   ✅ Complete! Saved to: fdm_outputs/training2_fdms/

📊 Generating Testing FDMs...
   Found 72 unique image pairs × viewing conditions
   Generated 20/72...
   Generated 40/72...
   Generated 60/72...
   ✅ Complete! Saved to: fdm_outputs/testing_fdms/

🎉 All FDMs generated!
📁 Output directory: fdm_outputs/
