In [1]:
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.gridspec import GridSpec
from pathlib import Path
from PIL import Image
from scipy.ndimage import gaussian_filter
from typing import Dict, List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

class QualitativeTrialAnalyzer:
    """Analyze and visualize top vs bottom performing trials for each viewing condition"""
    
    def __init__(self, data_path: Path = Path(".")):
        self.data_path = data_path
        self.screen_size = (1024, 768)
        self.image_size = 620
        
        # Image positioning on screen
        self.box = (
            self.screen_size[0] // 2 - 310,  # left
            self.screen_size[1] // 2 - 310,  # top
            self.screen_size[0] // 2 + 310,  # right
            self.screen_size[1] // 2 + 310   # bottom
        )
        
        # Load all datasets
        self.load_all_data()
        
    def load_all_data(self):
        """Load data from all phases"""
        self.datasets = {}
        
        # Define dataset paths - matching your existing structure
        dataset_configs = {
            "Training 1": {
                "json": "Training 1/training1.json",
                "images": "Training 1/training1_images",
                "dual_images": False
            },
            "Training 2": {
                "json": "Training 2/training2.json",
                "images": "Training 2/training2_images",
                "dual_images": True
            },
            "Testing": {
                "json": "Testing/testing.json",
                "images": "Testing/testing_images",
                "dual_images": True
            }
        }
        
        for name, config in dataset_configs.items():
            json_path = self.data_path / config["json"]
            if json_path.exists():
                with open(json_path, 'r') as f:
                    data = json.load(f)
                self.datasets[name] = {
                    "data": data,
                    "image_folder": self.data_path / config["images"],
                    "dual_images": config["dual_images"]
                }
                print(f"‚úì Loaded {name}: {len(data)} trials")
            else:
                print(f"‚ö† {name} not found at {json_path}")
    
    def load_image(self, image_folder: Path, image_name: str) -> Optional[np.ndarray]:
        """Load an image file and return as numpy array"""
        image_path = image_folder / image_name
        if image_path.exists():
            img = Image.open(image_path).convert("RGB")
            return np.array(img)
        else:
            print(f"‚ö† Image not found: {image_path}")
            return None
    
    def get_trial_performance(self, dataset_name: str) -> pd.DataFrame:
        """Calculate performance metrics for each unique trial configuration"""
        if dataset_name not in self.datasets:
            return pd.DataFrame()
        
        data = self.datasets[dataset_name]["data"]
        
        # Group trials by unique configurations
        if dataset_name == "Training 1":
            # Group by single image
            trial_groups = {}
            for trial in data:
                img = trial.get("first_image")
                if img:
                    if img not in trial_groups:
                        trial_groups[img] = []
                    trial_groups[img].append(trial)
            
            performance = []
            for img, trials in trial_groups.items():
                correct = sum(1 for t in trials if self.is_correct(t))
                total = len(trials)
                performance.append({
                    "trial_id": img,
                    "first_image": img,
                    "accuracy": correct / total if total > 0 else 0,
                    "n_responses": total,
                    "correct_count": correct,
                    "trials": trials
                })
                
        elif dataset_name in ["Training 2", "Testing"]:
            # Group by image pair (and viewing condition for Testing)
            trial_groups = {}
            for trial in data:
                first = trial.get("first_image")
                second = trial.get("second_image")
                
                if dataset_name == "Testing":
                    viewing = trial.get("viewing_condition", "unknown")
                    key = (first, second, viewing)
                else:
                    key = (first, second)
                
                if first and second:
                    if key not in trial_groups:
                        trial_groups[key] = []
                    trial_groups[key].append(trial)
            
            performance = []
            for key, trials in trial_groups.items():
                correct = sum(1 for t in trials if self.is_correct(t))
                total = len(trials)
                
                entry = {
                    "trial_id": str(key),
                    "first_image": key[0],
                    "second_image": key[1],
                    "accuracy": correct / total if total > 0 else 0,
                    "n_responses": total,
                    "correct_count": correct,
                    "trials": trials
                }
                
                if dataset_name == "Testing" and len(key) > 2:
                    entry["viewing_condition"] = key[2]
                
                performance.append(entry)
        
        return pd.DataFrame(performance)
    
    def is_correct(self, trial: Dict) -> bool:
        """Determine if a trial response was correct"""
        # Check for explicit accuracy flag
        if "acc" in trial:
            v = trial["acc"]
            if isinstance(v, bool):
                return v
            if isinstance(v, (int, float)):
                return bool(v)
            if isinstance(v, str):
                sv = str(v).strip().lower()
                if sv in {"1", "true", "correct", "right"}:
                    return True
                if sv in {"0", "false", "incorrect", "wrong"}:
                    return False
        
        # Fallback: compare answer to correct response
        answer = str(trial.get("subj_answer", "")).strip().lower()
        correct = str(trial.get("correct_response", "")).strip().lower()
        if answer and correct:
            return answer == correct
        
        return False
    
    def get_top_bottom_trials(self, dataset_name: str, viewing_condition: Optional[str] = None, n: int = 3):
        """Get top and bottom n trial configurations by success rate"""
        df = self.get_trial_performance(dataset_name)
        
        if df.empty:
            return pd.DataFrame(), pd.DataFrame()
        
        # Filter by viewing condition if specified
        if viewing_condition and "viewing_condition" in df.columns:
            df = df[df["viewing_condition"] == viewing_condition]
        
        # Sort by accuracy
        df = df.sort_values("accuracy", ascending=False)
        
        top = df.head(n)
        bottom = df.tail(n)
        
        return top, bottom
    
    def collect_fixations(self, trials: List[Dict], image_key: str = "first_image") -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Collect all fixations for given trials"""
        l, t, r, b = self.box
        xs_all, ys_all, durations_all = [], [], []
        
        for trial in trials:
            xs = np.asarray(trial.get("fix_x", []), dtype=float)
            ys = np.asarray(trial.get("fix_y", []), dtype=float)
            durations = np.asarray(trial.get("fix_dur", np.ones_like(xs) * 100), dtype=float)
            
            # Handle image order for dual-image trials
            order = np.asarray(trial.get("fix_index", np.arange(1, len(xs) + 1)), dtype=float)
            idx = trial.get("test_image_fixation_idx")
            
            if idx is not None:
                if image_key == "first_image":
                    mask = order < idx
                else:
                    mask = order >= idx
                xs = xs[mask]
                ys = ys[mask]
                durations = durations[mask]
            
            # Filter to image boundaries
            mask = (xs >= l) & (xs <= r) & (ys >= t) & (ys <= b)
            xs_all.extend(xs[mask])
            ys_all.extend(ys[mask])
            durations_all.extend(durations[mask])
        
        return np.array(xs_all), np.array(ys_all), np.array(durations_all)
    
    def create_heatmap(self, xs, ys, durations=None):
        """Create a fixation density heatmap"""
        if len(xs) == 0:
            return np.zeros((150, 150))
        
        l, t, r, b = self.box
        
        # Create 2D histogram
        H, xedges, yedges = np.histogram2d(
            xs, ys, 
            bins=[np.linspace(l, r, 150), np.linspace(t, b, 150)],
            weights=durations if durations is not None else None
        )
        
        # Apply gaussian smoothing
        H = gaussian_filter(H, sigma=15)
        
        return H
    
    def calculate_pattern_metrics(self, xs, ys, durations):
        """Calculate metrics that explain fixation patterns"""
        if len(xs) == 0:
            return {
                'n_fixations': 0,
                'total_duration': 0,
                'avg_duration': 0,
                'spatial_spread': 0,
                'center_bias': 0,
                'coverage': 0
            }
        
        # Convert to image coordinates
        l, t, r, b = self.box
        xs_img = xs - l
        ys_img = ys - t
        
        # Calculate metrics
        metrics = {
            'n_fixations': len(xs),
            'total_duration': np.sum(durations) if durations is not None else len(xs) * 100,
            'avg_duration': np.mean(durations) if durations is not None else 100,
            'spatial_spread': np.std(xs_img) + np.std(ys_img),
            'center_bias': np.mean(np.sqrt((xs_img - 310)**2 + (ys_img - 310)**2)),
            'coverage': len(set([(int(x/20), int(y/20)) for x, y in zip(xs_img, ys_img)])) / 961  # 31x31 grid
        }
        
        return metrics
    
    def plot_trial_with_image(self, ax, trial_config, position: str, dataset_name: str, 
                              image_folder: Path, image_key: str = "first_image", 
                              show_image_num: int = 1):
        """Plot heatmap and fixations with actual image background"""
        trials = trial_config["trials"]
        
        # Determine which image to show
        if image_key == "first_image":
            image_name = trial_config["first_image"]
        else:
            image_name = trial_config.get("second_image", trial_config["first_image"])
        
        # Load and display the image
        l, t, r, b = self.box
        img = self.load_image(image_folder, image_name)
        if img is not None:
            ax.imshow(img, extent=(l, r, b, t), aspect='auto')
        
        # Collect fixations
        xs, ys, durations = self.collect_fixations(trials, image_key)
        
        # Create and overlay heatmap
        if len(xs) > 0:
            H = self.create_heatmap(xs, ys, durations)
            extent = [l, r, b, t]
            im = ax.imshow(H.T, extent=extent, origin="upper", cmap="jet", 
                          alpha=0.6, interpolation="bilinear")
            
            # Add fixation points (coral color as in your original code)
            ax.scatter(xs, ys, s=30, c='#FF6B6B', alpha=0.5, 
                      edgecolors='white', linewidth=0.8)
        
        # Add image boundary
        rect = patches.Rectangle((l, t), r-l, b-t, linewidth=2, 
                                edgecolor='black', facecolor='none')
        ax.add_patch(rect)
        
        # Calculate metrics
        metrics = self.calculate_pattern_metrics(xs, ys, durations)
        
        # Set title
        color = 'green' if position == 'top' else 'red'
        
        if dataset_name == "Training 1":
            title = f"{image_name[:30]}"
        else:
            title = f"Image {show_image_num}: {image_name[:20]}"
        
        title += f"\nAcc: {trial_config['accuracy']:.1%} (N={trial_config['n_responses']})"
        title += f"\nFix: {metrics['n_fixations']} | Spread: {metrics['spatial_spread']:.0f}px"
        
        ax.set_title(title, fontsize=9, color=color)
        
        ax.set_xlim(0, self.screen_size[0])
        ax.set_ylim(self.screen_size[1], 0)
        ax.set_xlabel("x (screen px)", fontsize=8)
        ax.set_ylabel("y (screen px)", fontsize=8)
        ax.tick_params(labelsize=7)
        
        return metrics
    
    def create_qualitative_figures(self, output_dir: Path = Path("qualitative_analysis")):
        """Create comprehensive qualitative comparison figures"""
        output_dir.mkdir(exist_ok=True)
        
        # Focus on Testing dataset with viewing conditions
        if "Testing" not in self.datasets:
            print("‚ö† Testing dataset not found!")
            return
        
        dataset_info = self.datasets["Testing"]
        image_folder = dataset_info["image_folder"]
        
        viewing_conditions = ['full', 'central', 'peripheral']
        
        for condition in viewing_conditions:
            print(f"\nüìä Analyzing {condition} viewing condition...")
            
            # Get top and bottom trials
            top_trials, bottom_trials = self.get_top_bottom_trials("Testing", condition, n=3)
            
            if top_trials.empty or bottom_trials.empty:
                print(f"  ‚ö† Not enough data for {condition}")
                continue
            
            # Create figure for this viewing condition
            fig = plt.figure(figsize=(20, 14))
            fig.suptitle(f'Testing Phase - {condition.upper()} Viewing Condition\n'
                        f'Top 3 Trials (High Accuracy) vs Bottom 3 Trials (Low Accuracy)',
                        fontsize=16, fontweight='bold')
            
            # Create grid: 2 rows (top/bottom), 6 columns (3 trials √ó 2 images each)
            gs = GridSpec(2, 6, figure=fig, hspace=0.3, wspace=0.15)
            
            all_metrics = {'top': [], 'bottom': []}
            
            # Plot top trials
            for i, (_, trial) in enumerate(top_trials.iterrows()):
                # First image
                ax1 = fig.add_subplot(gs[0, i*2])
                m1 = self.plot_trial_with_image(ax1, trial, 'top', "Testing", 
                                                image_folder, "first_image", 
                                                show_image_num=1)
                
                # Second image
                ax2 = fig.add_subplot(gs[0, i*2 + 1])
                m2 = self.plot_trial_with_image(ax2, trial, 'top', "Testing", 
                                                image_folder, "second_image", 
                                                show_image_num=2)
                
                all_metrics['top'].append({'first': m1, 'second': m2})
            
            # Plot bottom trials
            for i, (_, trial) in enumerate(bottom_trials.iterrows()):
                # First image
                ax1 = fig.add_subplot(gs[1, i*2])
                m1 = self.plot_trial_with_image(ax1, trial, 'bottom', "Testing", 
                                                image_folder, "first_image", 
                                                show_image_num=1)
                
                # Second image
                ax2 = fig.add_subplot(gs[1, i*2 + 1])
                m2 = self.plot_trial_with_image(ax2, trial, 'bottom', "Testing", 
                                                image_folder, "second_image", 
                                                show_image_num=2)
                
                all_metrics['bottom'].append({'first': m1, 'second': m2})
            
            # Add colorbar for heatmaps
            cbar_ax = fig.add_axes([0.92, 0.15, 0.015, 0.7])
            sm = plt.cm.ScalarMappable(cmap='jet', norm=plt.Normalize(vmin=0, vmax=1))
            sm.set_array([])
            cbar = fig.colorbar(sm, cax=cbar_ax)
            cbar.set_label('Fixation Density', fontsize=10)
            
            # Save figure
            output_file = output_dir / f"qualitative_{condition}_viewing.pdf"
            plt.savefig(output_file, bbox_inches='tight', dpi=150)
            plt.close()
            print(f"  ‚úì Saved: {output_file}")
            
            # Create metrics summary
            self.create_metrics_summary(condition, all_metrics, top_trials, bottom_trials, output_dir)
    
    def create_metrics_summary(self, condition: str, metrics_data: Dict, 
                              top_trials: pd.DataFrame, bottom_trials: pd.DataFrame, 
                              output_dir: Path):
        """Create a summary figure showing pattern metrics"""
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        fig.suptitle(f'{condition.upper()} Viewing - Pattern Analysis\n'
                    f'Values that Explain the Overall Pattern',
                    fontsize=14, fontweight='bold')
        
        # Extract metrics for comparison
        metric_names = ['n_fixations', 'avg_duration', 'spatial_spread', 
                       'center_bias', 'coverage', 'total_duration']
        metric_labels = ['Number of\nFixations', 'Average\nDuration (ms)', 
                        'Spatial\nSpread (px)', 'Center\nBias (px)', 
                        'Coverage\n(%)', 'Total\nDuration (ms)']
        
        for idx, (metric, label) in enumerate(zip(metric_names, metric_labels)):
            ax = axes[idx // 3, idx % 3]
            
            # Collect values
            top_vals = []
            bottom_vals = []
            
            for trial_metrics in metrics_data['top']:
                val1 = trial_metrics['first'].get(metric, 0)
                val2 = trial_metrics['second'].get(metric, 0)
                top_vals.extend([val1, val2])
            
            for trial_metrics in metrics_data['bottom']:
                val1 = trial_metrics['first'].get(metric, 0)
                val2 = trial_metrics['second'].get(metric, 0)
                bottom_vals.extend([val1, val2])
            
            if metric == 'coverage':
                top_vals = [v * 100 for v in top_vals]
                bottom_vals = [v * 100 for v in bottom_vals]
            
            # Create bar plot
            x = ['Top\nTrials', 'Bottom\nTrials']
            y = [np.mean(top_vals) if top_vals else 0, 
                 np.mean(bottom_vals) if bottom_vals else 0]
            err = [np.std(top_vals) if top_vals else 0, 
                   np.std(bottom_vals) if bottom_vals else 0]
            
            bars = ax.bar(x, y, yerr=err, color=['green', 'red'], 
                         alpha=0.7, capsize=10)
            
            ax.set_ylabel(label, fontsize=10)
            ax.grid(True, alpha=0.3)
            
            # Add value labels
            for bar, val in zip(bars, y):
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., height,
                       f'{val:.0f}', ha='center', va='bottom', fontsize=9)
            
            # Add difference annotation
            if y[1] != 0:
                diff = y[0] - y[1]
                diff_pct = (diff / y[1]) * 100
                ax.set_title(f'Œî = {diff:.0f} ({diff_pct:+.0f}%)', fontsize=9)
        
        # Add summary text
        avg_acc_top = top_trials['accuracy'].mean()
        avg_acc_bottom = bottom_trials['accuracy'].mean()
        
        summary = f"""
Key Patterns for {condition.upper()} viewing:
‚Ä¢ Top trials (avg accuracy: {avg_acc_top:.1%}) show most difference in fixation patterns
‚Ä¢ Bottom trials (avg accuracy: {avg_acc_bottom:.1%}) demonstrate inefficient scanning
‚Ä¢ Success rate difference: {(avg_acc_top - avg_acc_bottom):.1%}
‚Ä¢ These metrics explain the overall pattern between successful and unsuccessful trials
        """
        
        fig.text(0.5, 0.02, summary, ha='center', fontsize=10,
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
        
        plt.tight_layout(rect=[0, 0.08, 1, 0.96])
        
        output_file = output_dir / f"metrics_{condition}_viewing.pdf"
        plt.savefig(output_file, bbox_inches='tight', dpi=150)
        plt.close()
        print(f"  ‚úì Metrics saved: {output_file}")
    
    def create_overall_summary(self, output_dir: Path = Path("qualitative_analysis")):
        """Create an overall summary across all conditions"""
        if "Testing" not in self.datasets:
            return
        
        fig, axes = plt.subplots(3, 4, figsize=(16, 12))
        fig.suptitle('Pattern Discovery: Key Differences Between Top and Bottom Trials',
                    fontsize=14, fontweight='bold')
        
        viewing_conditions = ['full', 'central', 'peripheral']
        metrics_to_show = ['n_fixations', 'spatial_spread', 'center_bias', 'coverage']
        metric_labels = ['Fixation Count', 'Spatial Spread (px)', 'Center Bias (px)', 'Coverage (%)']
        
        for i, condition in enumerate(viewing_conditions):
            # Get trials
            top_trials, bottom_trials = self.get_top_bottom_trials("Testing", condition, n=3)
            
            if top_trials.empty or bottom_trials.empty:
                continue
            
            for j, (metric, label) in enumerate(zip(metrics_to_show, metric_labels)):
                ax = axes[i, j]
                
                # Collect metrics for all trials
                top_vals = []
                bottom_vals = []
                
                for _, trial in top_trials.iterrows():
                    xs1, ys1, dur1 = self.collect_fixations(trial["trials"], "first_image")
                    xs2, ys2, dur2 = self.collect_fixations(trial["trials"], "second_image")
                    
                    m1 = self.calculate_pattern_metrics(xs1, ys1, dur1)
                    m2 = self.calculate_pattern_metrics(xs2, ys2, dur2)
                    
                    val = (m1[metric] + m2[metric]) / 2
                    if metric == 'coverage':
                        val *= 100
                    top_vals.append(val)
                
                for _, trial in bottom_trials.iterrows():
                    xs1, ys1, dur1 = self.collect_fixations(trial["trials"], "first_image")
                    xs2, ys2, dur2 = self.collect_fixations(trial["trials"], "second_image")
                    
                    m1 = self.calculate_pattern_metrics(xs1, ys1, dur1)
                    m2 = self.calculate_pattern_metrics(xs2, ys2, dur2)
                    
                    val = (m1[metric] + m2[metric]) / 2
                    if metric == 'coverage':
                        val *= 100
                    bottom_vals.append(val)
                
                # Create box plot
                bp = ax.boxplot([top_vals, bottom_vals], 
                               tick_labels=['Top', 'Bottom'],
                               patch_artist=True)
                
                # Color the boxes
                bp['boxes'][0].set_facecolor('lightgreen')
                bp['boxes'][1].set_facecolor('lightcoral')
                
                if i == 0:
                    ax.set_title(label, fontsize=10, fontweight='bold')
                if j == 0:
                    ax.set_ylabel(f'{condition.capitalize()}', fontsize=10, fontweight='bold')
                
                ax.grid(True, alpha=0.3)
                ax.tick_params(labelsize=8)
        
        plt.tight_layout()
        
        output_file = output_dir / "overall_pattern_summary.png"
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"\n‚úì Overall summary saved: {output_file}")

# Main execution
def main():
    print("=" * 60)
    print("QUALITATIVE TRIAL ANALYSIS")
    print("Comparing Top vs Bottom Performing Trials")
    print("=" * 60)
    
    # Initialize analyzer
    analyzer = QualitativeTrialAnalyzer()
    
    # Create output directory
    output_dir = Path("qualitative_analysis")
    output_dir.mkdir(exist_ok=True)
    
    print("\nüìä Generating qualitative comparison figures...")
    print("Note: Actual bird images will be displayed as backgrounds")
    analyzer.create_qualitative_figures(output_dir)
    
    print("\nüìà Creating overall pattern summary...")
    analyzer.create_overall_summary(output_dir)
    
    print("\n" + "=" * 60)
    print("‚úÖ Analysis complete!")
    print(f"üìÅ All outputs saved to: {output_dir}/")
    print("\nThe figures will show:")
    print("‚Ä¢ Actual bird images as backgrounds")
    print("‚Ä¢ Coral-colored fixation points overlaid")
    print("‚Ä¢ Jet colormap heatmaps (semi-transparent)")
    print("‚Ä¢ No viewing condition masks")
    print("=" * 60)

if __name__ == "__main__":
    main()

QUALITATIVE TRIAL ANALYSIS
Comparing Top vs Bottom Performing Trials
‚úì Loaded Training 1: 1364 trials
‚úì Loaded Training 2: 2320 trials
‚úì Loaded Testing: 2304 trials

üìä Generating qualitative comparison figures...
Note: Actual bird images will be displayed as backgrounds

üìä Analyzing full viewing condition...
  ‚úì Saved: qualitative_analysis/qualitative_full_viewing.pdf
  ‚úì Metrics saved: qualitative_analysis/metrics_full_viewing.pdf

üìä Analyzing central viewing condition...
  ‚úì Saved: qualitative_analysis/qualitative_central_viewing.pdf
  ‚úì Metrics saved: qualitative_analysis/metrics_central_viewing.pdf

üìä Analyzing peripheral viewing condition...
  ‚úì Saved: qualitative_analysis/qualitative_peripheral_viewing.pdf
  ‚úì Metrics saved: qualitative_analysis/metrics_peripheral_viewing.pdf

üìà Creating overall pattern summary...

‚úì Overall summary saved: qualitative_analysis/overall_pattern_summary.png

‚úÖ Analysis complete!
üìÅ All outputs saved to: qualita

In [None]:
# file: qualitative_trial_analyzer.py
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.gridspec import GridSpec
from pathlib import Path
from PIL import Image
from scipy.ndimage import gaussian_filter
from typing import Dict, List, Tuple, Optional
import warnings
warnings.filterwarnings('ignore')

class QualitativeTrialAnalyzer:
    """Analyze and visualize top vs bottom performing trials for each viewing condition"""
    
    def __init__(self, data_path: Path = Path(".")):
        self.data_path = data_path
        self.screen_size = (1024, 768)
        self.image_size = 620
        
        # Image positioning on screen
        self.box = (
            self.screen_size[0] // 2 - 310,  # left
            self.screen_size[1] // 2 - 310,  # top
            self.screen_size[0] // 2 + 310,  # right
            self.screen_size[1] // 2 + 310   # bottom
        )
        
        # Load all datasets
        self.load_all_data()
    
    # ---------- helpers for trial labeling ----------
    @staticmethod
    def extract_trial_number(trial: Dict) -> Optional[int]:
        """Best-effort extraction of a numeric trial index."""
        # Why: datasets often vary key names; we normalize without requiring schema changes.
        candidate_keys = [
            "trial_index", "trial_num", "trial_number", "trial",
            "trialId", "trial_id", "idx", "index"
        ]
        for k in candidate_keys:
            if k in trial:
                v = trial[k]
                # Normalize strings like "12", "trial_12"
                if isinstance(v, str):
                    digits = ''.join(ch for ch in v if ch.isdigit())
                    if digits.isdigit():
                        return int(digits)
                elif isinstance(v, (int, float)) and not np.isnan(v):
                    return int(v)
        # As a fallback, try position if present (non-standard)
        if "__position__" in trial and isinstance(trial["__position__"], int):
            return int(trial["__position__"])
        return None

    def format_trial_label(self, trials: List[Dict]) -> str:
        """Generate a compact human label like 'Trial 12' or 'Trials 5, 9, 11‚Ä¶'."""
        nums = []
        for t in trials:
            n = self.extract_trial_number(t)
            if n is not None:
                nums.append(n)
        nums = sorted(set(nums))
        if not nums:
            return "Trial (id unavailable)"
        if len(nums) == 1:
            return f"Trial {nums[0]}"
        # cap to avoid long titles
        head = nums[:5]
        suffix = "‚Ä¶" if len(nums) > 5 else ""
        return "Trials " + ", ".join(str(x) for x in head) + suffix

    # ---------- data loading ----------
    def load_all_data(self):
        """Load data from all phases"""
        self.datasets = {}
        
        dataset_configs = {
            "Training 1": {
                "json": "Training 1/training1.json",
                "images": "Training 1/training1_images",
                "dual_images": False
            },
            "Training 2": {
                "json": "Training 2/training2.json",
                "images": "Training 2/training2_images",
                "dual_images": True
            },
            "Testing": {
                "json": "Testing/testing.json",
                "images": "Testing/testing_images",
                "dual_images": True
            }
        }
        
        for name, config in dataset_configs.items():
            json_path = self.data_path / config["json"]
            if json_path.exists():
                with open(json_path, 'r') as f:
                    data = json.load(f)
                # Optionally attach position for fallback labeling
                for i, tr in enumerate(data):
                    tr.setdefault("__position__", i + 1)
                self.datasets[name] = {
                    "data": data,
                    "image_folder": self.data_path / config["images"],
                    "dual_images": config["dual_images"]
                }
                print(f"‚úì Loaded {name}: {len(data)} trials")
            else:
                print(f"‚ö† {name} not found at {json_path}")
    
    def load_image(self, image_folder: Path, image_name: str) -> Optional[np.ndarray]:
        """Load an image file and return as numpy array"""
        image_path = image_folder / image_name
        if image_path.exists():
            img = Image.open(image_path).convert("RGB")
            return np.array(img)
        else:
            print(f"‚ö† Image not found: {image_path}")
            return None
    
    def get_trial_performance(self, dataset_name: str) -> pd.DataFrame:
        """Calculate performance metrics for each unique trial configuration"""
        if dataset_name not in self.datasets:
            return pd.DataFrame()
        
        data = self.datasets[dataset_name]["data"]
        
        if dataset_name == "Training 1":
            trial_groups = {}
            for trial in data:
                img = trial.get("first_image")
                if img:
                    trial_groups.setdefault(img, []).append(trial)
            
            performance = []
            for img, trials in trial_groups.items():
                correct = sum(1 for t in trials if self.is_correct(t))
                total = len(trials)
                performance.append({
                    "trial_id": img,
                    "first_image": img,
                    "accuracy": correct / total if total > 0 else 0,
                    "n_responses": total,
                    "correct_count": correct,
                    "trials": trials
                })
                
        elif dataset_name in ["Training 2", "Testing"]:
            trial_groups = {}
            for trial in data:
                first = trial.get("first_image")
                second = trial.get("second_image")
                if dataset_name == "Testing":
                    viewing = trial.get("viewing_condition", "unknown")
                    key = (first, second, viewing)
                else:
                    key = (first, second)
                if first and second:
                    trial_groups.setdefault(key, []).append(trial)
            
            performance = []
            for key, trials in trial_groups.items():
                correct = sum(1 for t in trials if self.is_correct(t))
                total = len(trials)
                entry = {
                    "trial_id": str(key),
                    "first_image": key[0],
                    "second_image": key[1],
                    "accuracy": correct / total if total > 0 else 0,
                    "n_responses": total,
                    "correct_count": correct,
                    "trials": trials
                }
                if dataset_name == "Testing" and len(key) > 2:
                    entry["viewing_condition"] = key[2]
                performance.append(entry)
        
        return pd.DataFrame(performance)
    
    def is_correct(self, trial: Dict) -> bool:
        """Determine if a trial response was correct"""
        if "acc" in trial:
            v = trial["acc"]
            if isinstance(v, bool):
                return v
            if isinstance(v, (int, float)):
                return bool(v)
            if isinstance(v, str):
                sv = str(v).strip().lower()
                if sv in {"1", "true", "correct", "right"}:
                    return True
                if sv in {"0", "false", "incorrect", "wrong"}:
                    return False
        
        answer = str(trial.get("subj_answer", "")).strip().lower()
        correct = str(trial.get("correct_response", "")).strip().lower()
        if answer and correct:
            return answer == correct
        return False
    
    def get_top_bottom_trials(self, dataset_name: str, viewing_condition: Optional[str] = None, n: int = 3):
        """Get top and bottom n trial configurations by success rate"""
        df = self.get_trial_performance(dataset_name)
        if df.empty:
            return pd.DataFrame(), pd.DataFrame()
        if viewing_condition and "viewing_condition" in df.columns:
            df = df[df["viewing_condition"] == viewing_condition]
        df = df.sort_values("accuracy", ascending=False)
        top = df.head(n)
        bottom = df.tail(n)
        return top, bottom
    
    def collect_fixations(self, trials: List[Dict], image_key: str = "first_image") -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Collect all fixations for given trials"""
        l, t, r, b = self.box
        xs_all, ys_all, durations_all = [], [], []
        
        for trial in trials:
            xs = np.asarray(trial.get("fix_x", []), dtype=float)
            ys = np.asarray(trial.get("fix_y", []), dtype=float)
            durations = np.asarray(trial.get("fix_dur", np.ones_like(xs) * 100), dtype=float)
            
            order = np.asarray(trial.get("fix_index", np.arange(1, len(xs) + 1)), dtype=float)
            idx = trial.get("test_image_fixation_idx")
            
            if idx is not None:
                if image_key == "first_image":
                    mask = order < idx
                else:
                    mask = order >= idx
                xs = xs[mask]
                ys = ys[mask]
                durations = durations[mask]
            
            mask = (xs >= l) & (xs <= r) & (ys >= t) & (ys <= b)
            xs_all.extend(xs[mask])
            ys_all.extend(ys[mask])
            durations_all.extend(durations[mask])
        
        return np.array(xs_all), np.array(ys_all), np.array(durations_all)
    
    def create_heatmap(self, xs, ys, durations=None):
        """Create a fixation density heatmap"""
        if len(xs) == 0:
            return np.zeros((150, 150))
        
        l, t, r, b = self.box
        H, xedges, yedges = np.histogram2d(
            xs, ys, 
            bins=[np.linspace(l, r, 150), np.linspace(t, b, 150)],
            weights=durations if durations is not None else None
        )
        H = gaussian_filter(H, sigma=15)
        return H
    
    def calculate_pattern_metrics(self, xs, ys, durations):
        """Calculate metrics that explain fixation patterns"""
        if len(xs) == 0:
            return {
                'n_fixations': 0,
                'total_duration': 0,
                'avg_duration': 0,
                'spatial_spread': 0,
                'center_bias': 0,
                'coverage': 0
            }
        
        l, t, r, b = self.box
        xs_img = xs - l
        ys_img = ys - t
        
        metrics = {
            'n_fixations': len(xs),
            'total_duration': np.sum(durations) if durations is not None else len(xs) * 100,
            'avg_duration': np.mean(durations) if durations is not None else 100,
            'spatial_spread': np.std(xs_img) + np.std(ys_img),
            'center_bias': np.mean(np.sqrt((xs_img - 310)**2 + (ys_img - 310)**2)),
            'coverage': len(set([(int(x/20), int(y/20)) for x, y in zip(xs_img, ys_img)])) / 961
        }
        return metrics
    
    def plot_trial_with_image(self, ax, trial_config, position: str, dataset_name: str, 
                              image_folder: Path, image_key: str = "first_image", 
                              show_image_num: int = 1):
        """Plot heatmap and fixations with actual image background"""
        trials = trial_config["trials"]
        
        # Determine which image to show (background only)
        if image_key == "first_image":
            image_name = trial_config["first_image"]
        else:
            image_name = trial_config.get("second_image", trial_config["first_image"])
        
        l, t, r, b = self.box
        img = self.load_image(image_folder, image_name)
        if img is not None:
            ax.imshow(img, extent=(l, r, b, t), aspect='auto')
        
        xs, ys, durations = self.collect_fixations(trials, image_key)
        
        if len(xs) > 0:
            H = self.create_heatmap(xs, ys, durations)
            extent = [l, r, b, t]
            im = ax.imshow(H.T, extent=extent, origin="upper", cmap="jet", 
                           alpha=0.6, interpolation="bilinear")
            ax.scatter(xs, ys, s=30, c='#FF6B6B', alpha=0.5, 
                       edgecolors='white', linewidth=0.8)
        
        rect = patches.Rectangle((l, t), r-l, b-t, linewidth=2, 
                                 edgecolor='black', facecolor='none')
        ax.add_patch(rect)
        
        metrics = self.calculate_pattern_metrics(xs, ys, durations)
        
        color = 'green' if position == 'top' else 'red'
        trial_label = self.format_trial_label(trials)

        # Title now highlights the trial(s), not the image name; keep which image slot this is.
        title = f"{trial_label} ¬∑ Image {show_image_num}\n"
        title += f"Acc: {trial_config['accuracy']:.1%} (N={trial_config['n_responses']})\n"
        title += f"Fix: {metrics['n_fixations']} | Gaze Dispersion: {metrics['spatial_spread']:.0f}px"
        
        ax.set_title(title, fontsize=9, color=color)
        
        ax.set_xlim(0, self.screen_size[0])
        ax.set_ylim(self.screen_size[1], 0)
        ax.set_xlabel("x (screen px)", fontsize=8)
        ax.set_ylabel("y (screen px)", fontsize=8)
        ax.tick_params(labelsize=7)
        
        return metrics
    
    def create_qualitative_figures(self, output_dir: Path = Path("qualitative_analysis")):
        """Create comprehensive qualitative comparison figures"""
        output_dir.mkdir(exist_ok=True)
        
        if "Testing" not in self.datasets:
            print("‚ö† Testing dataset not found!")
            return
        
        dataset_info = self.datasets["Testing"]
        image_folder = dataset_info["image_folder"]
        
        viewing_conditions = ['full', 'central', 'peripheral']
        
        for condition in viewing_conditions:
            print(f"\nüìä Analyzing {condition} viewing condition...")
            
            top_trials, bottom_trials = self.get_top_bottom_trials("Testing", condition, n=3)
            
            if top_trials.empty or bottom_trials.empty:
                print(f"  ‚ö† Not enough data for {condition}")
                continue
            
            fig = plt.figure(figsize=(20, 14))
            fig.suptitle(
                f'Testing Phase - {condition.upper()} Viewing Condition\n'
                f'Top 3 Trials (High Accuracy) vs Bottom 3 Trials (Low Accuracy)',
                fontsize=16, fontweight='bold'
            )
            
            gs = GridSpec(2, 6, figure=fig, hspace=0.3, wspace=0.15)
            all_metrics = {'top': [], 'bottom': []}
            
            for i, (_, trial) in enumerate(top_trials.iterrows()):
                ax1 = fig.add_subplot(gs[0, i*2])
                m1 = self.plot_trial_with_image(ax1, trial, 'top', "Testing", 
                                                image_folder, "first_image", 
                                                show_image_num=1)
                ax2 = fig.add_subplot(gs[0, i*2 + 1])
                m2 = self.plot_trial_with_image(ax2, trial, 'top', "Testing", 
                                                image_folder, "second_image", 
                                                show_image_num=2)
                all_metrics['top'].append({'first': m1, 'second': m2})
            
            for i, (_, trial) in enumerate(bottom_trials.iterrows()):
                ax1 = fig.add_subplot(gs[1, i*2])
                m1 = self.plot_trial_with_image(ax1, trial, 'bottom', "Testing", 
                                                image_folder, "first_image", 
                                                show_image_num=1)
                ax2 = fig.add_subplot(gs[1, i*2 + 1])
                m2 = self.plot_trial_with_image(ax2, trial, 'bottom', "Testing", 
                                                image_folder, "second_image", 
                                                show_image_num=2)
                all_metrics['bottom'].append({'first': m1, 'second': m2})
            
            cbar_ax = fig.add_axes([0.92, 0.15, 0.015, 0.7])
            sm = plt.cm.ScalarMappable(cmap='jet', norm=plt.Normalize(vmin=0, vmax=1))
            sm.set_array([])
            cbar = fig.colorbar(sm, cax=cbar_ax)
            cbar.set_label('Fixation Density', fontsize=10)
            
            output_file = output_dir / f"qualitative_{condition}_viewing.pdf"
            plt.savefig(output_file, bbox_inches='tight', dpi=150)
            plt.close()
            print(f"  ‚úì Saved: {output_file}")
            
            self.create_metrics_summary(condition, all_metrics, top_trials, bottom_trials, output_dir)
    
    def create_metrics_summary(self, condition: str, metrics_data: Dict, 
                               top_trials: pd.DataFrame, bottom_trials: pd.DataFrame, 
                               output_dir: Path):
        """Create a summary figure showing pattern metrics"""
        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
        fig.suptitle(
            f'{condition.upper()} Viewing - Pattern Analysis\n'
            f'Values that Explain the Overall Pattern',
            fontsize=14, fontweight='bold'
        )
        
        metric_names = ['n_fixations', 'avg_duration', 'spatial_spread', 
                        'center_bias', 'coverage', 'total_duration']
        metric_labels = ['Number of\nFixations', 'Average\nDuration (ms)', 
                         'Gaze\nDispersion (px)', 'Center\nBias (px)', 
                         'Coverage\n(%)', 'Total\nDuration (ms)']
        
        for idx, (metric, label) in enumerate(zip(metric_names, metric_labels)):
            ax = axes[idx // 3, idx % 3]
            top_vals = []
            bottom_vals = []
            
            for trial_metrics in metrics_data['top']:
                val1 = trial_metrics['first'].get(metric, 0)
                val2 = trial_metrics['second'].get(metric, 0)
                top_vals.extend([val1, val2])
            for trial_metrics in metrics_data['bottom']:
                val1 = trial_metrics['first'].get(metric, 0)
                val2 = trial_metrics['second'].get(metric, 0)
                bottom_vals.extend([val1, val2])
            
            if metric == 'coverage':
                top_vals = [v * 100 for v in top_vals]
                bottom_vals = [v * 100 for v in bottom_vals]
            
            x = ['Top\nTrials', 'Bottom\nTrials']
            y = [np.mean(top_vals) if top_vals else 0, 
                 np.mean(bottom_vals) if bottom_vals else 0]
            err = [np.std(top_vals) if top_vals else 0, 
                   np.std(bottom_vals) if bottom_vals else 0]
            
            bars = ax.bar(x, y, yerr=err, color=['green', 'red'], 
                          alpha=0.7, capsize=10)
            
            ax.set_ylabel(label, fontsize=10)
            ax.grid(True, alpha=0.3)
            
            for bar, val in zip(bars, y):
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2., height,
                        f'{val:.0f}', ha='center', va='bottom', fontsize=9)
            
            if y[1] != 0:
                diff = y[0] - y[1]
                diff_pct = (diff / y[1]) * 100
                ax.set_title(f'Œî = {diff:.0f} ({diff_pct:+.0f}%)', fontsize=9)
        
        avg_acc_top = top_trials['accuracy'].mean()
        avg_acc_bottom = bottom_trials['accuracy'].mean()
        
        summary = f"""
Key Patterns for {condition.upper()} viewing:
‚Ä¢ Top trials (avg accuracy: {avg_acc_top:.1%}) show most difference in fixation patterns
‚Ä¢ Bottom trials (avg accuracy: {avg_acc_bottom:.1%}) demonstrate inefficient scanning
‚Ä¢ Success rate difference: {(avg_acc_top - avg_acc_bottom):.1%}
‚Ä¢ These metrics explain the overall pattern between successful and unsuccessful trials
        """
        
        fig.text(0.5, 0.02, summary, ha='center', fontsize=10,
                 bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
        
        plt.tight_layout(rect=[0, 0.08, 1, 0.96])
        
        output_file = output_dir / f"metrics_{condition}_viewing.pdf"
        plt.savefig(output_file, bbox_inches='tight', dpi=150)
        plt.close()
        print(f"  ‚úì Metrics saved: {output_file}")
    
    def create_overall_summary(self, output_dir: Path = Path("qualitative_analysis")):
        """Create an overall summary across all conditions"""
        if "Testing" not in self.datasets:
            return
        
        fig, axes = plt.subplots(3, 4, figsize=(16, 12))
        fig.suptitle('Pattern Discovery: Key Differences Between Top and Bottom Trials',
                     fontsize=14, fontweight='bold')
        
        viewing_conditions = ['full', 'central', 'peripheral']
        metrics_to_show = ['n_fixations', 'spatial_spread', 'center_bias', 'coverage']
        metric_labels = ['Fixation Count', 'Gaze Dispersion (px)', 'Center Bias (px)', 'Coverage (%)']
        
        for i, condition in enumerate(viewing_conditions):
            top_trials, bottom_trials = self.get_top_bottom_trials("Testing", condition, n=3)
            if top_trials.empty or bottom_trials.empty:
                continue
            
            for j, (metric, label) in enumerate(zip(metrics_to_show, metric_labels)):
                ax = axes[i, j]
                top_vals = []
                bottom_vals = []
                
                for _, trial in top_trials.iterrows():
                    xs1, ys1, dur1 = self.collect_fixations(trial["trials"], "first_image")
                    xs2, ys2, dur2 = self.collect_fixations(trial["trials"], "second_image")
                    m1 = self.calculate_pattern_metrics(xs1, ys1, dur1)
                    m2 = self.calculate_pattern_metrics(xs2, ys2, dur2)
                    val = (m1[metric] + m2[metric]) / 2
                    if metric == 'coverage': val *= 100
                    top_vals.append(val)
                
                for _, trial in bottom_trials.iterrows():
                    xs1, ys1, dur1 = self.collect_fixations(trial["trials"], "first_image")
                    xs2, ys2, dur2 = self.collect_fixations(trial["trials"], "second_image")
                    m1 = self.calculate_pattern_metrics(xs1, ys1, dur1)
                    m2 = self.calculate_pattern_metrics(xs2, ys2, dur2)
                    val = (m1[metric] + m2[metric]) / 2
                    if metric == 'coverage': val *= 100
                    bottom_vals.append(val)
                
                bp = ax.boxplot([top_vals, bottom_vals], 
                                tick_labels=['Top', 'Bottom'],
                                patch_artist=True)
                bp['boxes'][0].set_facecolor('lightgreen')
                bp['boxes'][1].set_facecolor('lightcoral')
                
                if i == 0:
                    ax.set_title(label, fontsize=10, fontweight='bold')
                if j == 0:
                    ax.set_ylabel(f'{condition.capitalize()}', fontsize=10, fontweight='bold')
                
                ax.grid(True, alpha=0.3)
                ax.tick_params(labelsize=8)
        
        plt.tight_layout()
        output_file = output_dir / "overall_pattern_summary.png"
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"\n‚úì Overall summary saved: {output_file}")

# Main execution
def main():
    print("=" * 60)
    print("QUALITATIVE TRIAL ANALYSIS")
    print("Comparing Top vs Bottom Performing Trials")
    print("=" * 60)
    
    analyzer = QualitativeTrialAnalyzer()
    
    output_dir = Path("qualitative_analysis")
    output_dir.mkdir(exist_ok=True)
    
    print("\nüìä Generating qualitative comparison figures...")
    print("Note: Actual bird images will be displayed as backgrounds")
    analyzer.create_qualitative_figures(output_dir)
    
    print("\nüìà Creating overall pattern summary...")
    analyzer.create_overall_summary(output_dir)
    
    print("\n" + "=" * 60)
    print("‚úÖ Analysis complete!")
    print(f"üìÅ All outputs saved to: {output_dir}/")
    print("\nThe figures will show:")
    print("‚Ä¢ Trial numbers in titles (not image names)")
    print("‚Ä¢ Coral-colored fixation points overlaid")
    print("‚Ä¢ Jet colormap heatmaps (semi-transparent)")
    print("‚Ä¢ Gaze Dispersion in place of Spread")
    print("=" * 60)

if __name__ == "__main__":
    main()


QUALITATIVE TRIAL ANALYSIS
Comparing Top vs Bottom Performing Trials
‚úì Loaded Training 1: 1364 trials
‚úì Loaded Training 2: 2320 trials
‚úì Loaded Testing: 2304 trials

üìä Generating qualitative comparison figures...
Note: Actual bird images will be displayed as backgrounds

üìä Analyzing full viewing condition...
  ‚úì Saved: qualitative_analysis/qualitative_full_viewing.pdf
  ‚úì Metrics saved: qualitative_analysis/metrics_full_viewing.pdf

üìä Analyzing central viewing condition...
  ‚úì Saved: qualitative_analysis/qualitative_central_viewing.pdf
  ‚úì Metrics saved: qualitative_analysis/metrics_central_viewing.pdf

üìä Analyzing peripheral viewing condition...
  ‚úì Saved: qualitative_analysis/qualitative_peripheral_viewing.pdf
  ‚úì Metrics saved: qualitative_analysis/metrics_peripheral_viewing.pdf

üìà Creating overall pattern summary...

‚úì Overall summary saved: qualitative_analysis/overall_pattern_summary.png

‚úÖ Analysis complete!
üìÅ All outputs saved to: qualita

In [1]:
# file: fdm_top_bottom_condition_sheet.py
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import argparse
import json
import re
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from PIL import Image
from scipy.ndimage import gaussian_filter

# ===== CONFIG =====
SCREEN_SIZE = (1024, 768)
VIEWINGS = ("full", "central", "peripheral")

# ===== UTIL =====
def ensure_dir(p: Path) -> Path:
    p.mkdir(parents=True, exist_ok=True); return p

def sanitize(name: str, maxlen: int = 80) -> str:
    name = re.sub(r"[^\w\-]+", "-", str(name))
    name = re.sub(r"-+", "-", name).strip("-")
    return name[:maxlen]

def _norm(v) -> Optional[str]:
    if v is None: return None
    s = str(v).strip()
    return s if s else None

def _norm_lower(v) -> Optional[str]:
    s = _norm(v)
    return s.lower() if s is not None else None

def norm_viewing(v: object) -> Optional[str]:
    s = _norm_lower(v)
    if s is None: return None
    aliases = {
        "f": "full", "fullfield": "full",
        "c": "central", "center": "central",
        "p": "peripheral", "peri": "peripheral", "periph": "peripheral",
    }
    s = aliases.get(s, s)
    return s if s in VIEWINGS else None

# ===== TRIAL ID EXTRACTION =====
TRIAL_KEYS = (
    "trial_index","trial_num","trial_number","trial",
    "trialId","trial_id","idx","index"
)

def extract_trial_number(tr: Dict) -> Optional[int]:
    for k in TRIAL_KEYS:
        if k in tr:
            v = tr[k]
            if isinstance(v, (int, float)) and not np.isnan(v):
                return int(v)
            if isinstance(v, str):
                digits = "".join(ch for ch in v if ch.isdigit())
                if digits.isdigit():
                    return int(digits)
    if "__position__" in tr and isinstance(tr["__position__"], int):
        return int(tr["__position__"])
    return None

def format_trials_label(trials: List[Dict]) -> str:
    nums = []
    for t in trials:
        n = extract_trial_number(t)
        if n is not None:
            nums.append(n)
    nums = sorted(set(nums))
    if not nums:
        return "Trial (?)"
    if len(nums) == 1:
        return f"Trial {nums[0]}"
    head = nums[:5]
    suffix = "‚Ä¶" if len(nums) > 5 else ""
    return "Trials " + ", ".join(str(x) for x in head) + suffix

# ===== CORRECTNESS =====
def infer_correctness(t: Dict) -> Optional[bool]:
    if "acc" in t:
        v = t["acc"]
        if isinstance(v, bool): return v
        if isinstance(v, (int, float)): return bool(v)
        if isinstance(v, str):
            sv = _norm_lower(v)
            if sv in {"1","true","correct","right"}:  return True
            if sv in {"0","false","incorrect","wrong"}: return False
    ans = _norm_lower(t.get("subj_answer"))
    gt  = _norm_lower(t.get("correct_response"))
    if ans is not None and gt is not None:
        return ans == gt
    return None

# ===== GROUPING =====
def group_trials(trials: List[Dict]) -> Dict[Tuple[str,str,str], List[Dict]]:
    groups: Dict[Tuple[str,str,str], List[Dict]] = defaultdict(list)
    for t in trials:
        f = _norm(t.get("first_image"))
        s = _norm(t.get("second_image"))
        v = norm_viewing(t.get("viewing_condition"))
        if not (f and s and v): continue
        groups[(f, s, v)].append(t)
    return groups

def summarize_group(trs: List[Dict]) -> Dict[str, float | int]:
    right = wrong = total = 0
    for t in trs:
        flag = infer_correctness(t)
        if flag is True:  right += 1; total += 1
        elif flag is False: wrong += 1; total += 1
    acc = (right / total) if total else 0.0
    effect = abs(right - wrong)
    return {"right": right, "wrong": wrong, "total": total, "acc": acc, "effect": effect}

# ===== FIXATIONS + HEATMAP (combined panel) =====
def _box(screen_size: Tuple[int,int] = SCREEN_SIZE):
    w, h = screen_size
    return (w//2 - 310, h//2 - 310, w//2 + 310, h//2 + 310)

def _imshow_bg(ax, image_folder: Path, img_name: str):
    l, t, r, b = _box()
    img = Image.open(image_folder / img_name).convert("RGB")
    ax.imshow(img, extent=(l, r, b, t), aspect="auto")

def _collect_fixations(trials: List[Dict], image_key: str):
    l, t, r, b = _box()
    xs_all, ys_all = [], []
    for tr in trials:
        xs = np.asarray(tr.get("fix_x", []), dtype=float)
        ys = np.asarray(tr.get("fix_y", []), dtype=float)
        order = np.asarray(tr.get("fix_index", np.arange(1, len(xs) + 1)), dtype=float)
        idx = tr.get("test_image_fixation_idx")
        if idx is not None:
            mask = order < idx if image_key == "first_image" else order >= idx
            xs, ys = xs[mask], ys[mask]
        m = (xs>=l) & (xs<=r) & (ys>=t) & (ys<=b)
        xs_all.extend(xs[m]); ys_all.extend(ys[m])
    return np.array(xs_all), np.array(ys_all)

def plot_panel(ax, trials_subset: List[Dict], image_folder: Path, img: str, image_key: str):
    _imshow_bg(ax, image_folder, img)
    xs, ys = _collect_fixations(trials_subset, image_key)
    if xs.size:
        l, t, r, b = _box()
        H, xe, ye = np.histogram2d(xs, ys, bins=[np.linspace(l, r, 150), np.linspace(t, b, 150)])
        H = gaussian_filter(H, sigma=15)
        extent = [xe[0], xe[-1], ye[-1], ye[0]]
        ax.imshow(H.T, extent=extent, origin="upper", cmap="jet", alpha=0.6, interpolation="bilinear")
        ax.scatter(xs, ys, s=28, c="#FF6B6B", alpha=0.55, edgecolors="white", linewidth=0.8)
    ax.set_xlim(0, SCREEN_SIZE[0]); ax.set_ylim(SCREEN_SIZE[1], 0)
    ax.set_xlabel("x (screen px)"); ax.set_ylabel("y (screen px)")

# ===== SHEET BUILDER =====
def build_condition_sheet(viewing: str,
                          picks_top: List[Tuple[Tuple[str,str,str], Dict]],
                          picks_bottom: List[Tuple[Tuple[str,str,str], Dict]],
                          groups: Dict[Tuple[str,str,str], List[Dict]],
                          image_folder: Path,
                          out_path: Path):
    fig = plt.figure(figsize=(20, 12), dpi=150)
    gs = GridSpec(2, 6, figure=fig, hspace=0.35, wspace=0.15)  # equal boxes

    rows = [("Top 3 (High Accuracy)", picks_top, "green"), ("Bottom 3 (Low Accuracy)", picks_bottom, "red")]
    for row_idx, (row_title, picks, color) in enumerate(rows):
        for i in range(3):
            if i >= len(picks):
                ax_blank1 = fig.add_subplot(gs[row_idx, i*2]); ax_blank1.axis("off")
                ax_blank2 = fig.add_subplot(gs[row_idx, i*2+1]); ax_blank2.axis("off")
                continue
            (first, second, _), stats = picks[i]
            trs = groups[(first, second, viewing)]
            trial_label = format_trials_label(trs)
            subset_all = trs  # overlay both correctness types together

            ax1 = fig.add_subplot(gs[row_idx, i*2])
            plot_panel(ax1, subset_all, image_folder, first, "first_image")
            ax1.set_title(
                f"{trial_label} ¬∑ Image 1\nAcc {stats['acc']:.1%} ¬∑ right={stats['right']} ¬∑ wrong={stats['wrong']}",
                fontsize=9, color=color
            )

            ax2 = fig.add_subplot(gs[row_idx, i*2 + 1])
            plot_panel(ax2, subset_all, image_folder, second, "second_image")
            ax2.set_title(
                f"{trial_label} ¬∑ Image 2\nAcc {stats['acc']:.1%} ¬∑ right={stats['right']} ¬∑ wrong={stats['wrong']}",
                fontsize=9, color=color
            )

    cax = fig.add_axes([0.92, 0.15, 0.015, 0.7])
    sm = plt.cm.ScalarMappable(cmap="jet", norm=plt.Normalize(vmin=0, vmax=1))
    sm.set_array([])
    cbar = fig.colorbar(sm, cax=cax)
    cbar.set_label("Fixation Density", fontsize=10)

    fig.suptitle(
        f"Testing Phase ‚Äì {viewing.upper()} Viewing\nTop 3 Trials vs Bottom 3 Trials (heatmap + fixations overlaid)",
        fontsize=16, fontweight="bold"
    )
    plt.savefig(out_path, bbox_inches="tight")
    plt.close(fig)
    print(f"  ‚úì Saved: {out_path}")

# ===== SELECTION =====
def pick_top_bottom_per_view(groups: Dict[Tuple[str,str,str], List[Dict]], k: int, min_n: int):
    by_view: Dict[str, List[Tuple[Tuple[str,str,str], Dict]]] = {v: [] for v in VIEWINGS}
    for key, trs in groups.items():
        _, _, v = key
        if v not in VIEWINGS: continue
        stats = summarize_group(trs)
        if stats["total"] < min_n:  # filter low N
            continue
        by_view[v].append((key, stats))

    picks: Dict[str, Dict[str, List[Tuple[Tuple[str,str,str], Dict]]]] = {}
    for v in VIEWINGS:
        lst = by_view[v]
        if not lst:
            picks[v] = {"top": [], "bottom": []}
            continue
        top_sorted = sorted(lst, key=lambda x: (x[1]["acc"], x[1]["effect"], x[1]["total"]), reverse=True)
        bot_sorted = sorted(lst, key=lambda x: (x[1]["acc"], x[1]["effect"], x[1]["total"]), reverse=False)
        picks[v] = {"top": top_sorted[:k], "bottom": bot_sorted[:k]}
    return picks

# ===== RUN/CLI =====
def run(json_path: str = "Testing/testing.json",
        images: str = "Testing/testing_images",
        k: int = 3,
        min_n: int = 5,
        outdir: str = "fdm_outputs_correctness/top_bottom_by_condition_sheets") -> None:
    json_path = Path(json_path); img_folder = Path(images); out_root = ensure_dir(Path(outdir))
    with json_path.open("r", encoding="utf-8") as f:
        trials = json.load(f)
    # provide stable fallback indexing for trial labels
    for i, tr in enumerate(trials, start=1):
        tr.setdefault("__position__", i)

    groups = group_trials(trials)
    picks = pick_top_bottom_per_view(groups, k=k, min_n=min_n)

    for viewing in VIEWINGS:
        view_out = ensure_dir(out_root / viewing)
        top_picks = picks[viewing]["top"]
        bot_picks = picks[viewing]["bottom"]
        out = view_out / f"qualitative_top_bottom_{sanitize(viewing)}.pdf"
        build_condition_sheet(viewing, top_picks, bot_picks, groups, img_folder, out)

    print(f"\nDone. Sheets in: {out_root.resolve()}")

def parse_args():
    p = argparse.ArgumentParser(description="One-page sheets: Top/Bottom 3 per viewing (heatmap+fixations; shows trial IDs)")
    p.add_argument("--json", default="Testing/testing.json")
    p.add_argument("--images", default="Testing/testing_images")
    p.add_argument("--k", type=int, default=3)
    p.add_argument("--min-n", type=int, default=5)
    p.add_argument("--outdir", default="fdm_outputs_correctness/top_bottom_by_condition_sheets")
    p.add_argument("-f", "--f", help=argparse.SUPPRESS)  # swallow Jupyter -f
    args, _ = p.parse_known_args()
    return args

if __name__ == "__main__":
    args = parse_args()
    run(json_path=args.json, images=args.images, k=args.k, min_n=args.min_n, outdir=args.outdir)


  ‚úì Saved: fdm_outputs_correctness/top_bottom_by_condition_sheets/full/qualitative_top_bottom_full.pdf
  ‚úì Saved: fdm_outputs_correctness/top_bottom_by_condition_sheets/central/qualitative_top_bottom_central.pdf
  ‚úì Saved: fdm_outputs_correctness/top_bottom_by_condition_sheets/peripheral/qualitative_top_bottom_peripheral.pdf

Done. Sheets in: /Users/daisybuathatseephol/Documents/three_json_output/fdm_outputs_correctness/top_bottom_by_condition_sheets


In [5]:
# scripts/detect_same_pairs_across_conditions.py
#!/usr/bin/env python3
"""
Detect identical (unordered) image pairs appearing across viewing conditions (full/central/peripheral).

Inputs searched:
- JSON: --json (default: Testing/testing.json)
- TSV/CSV: recursively scan folders whose name you pass via --scan-tsv-dir (repeatable).
  If none exist or none match, auto-fallback to scan --data-root recursively.

Understands TSV headers you shared: filename1, filename2, viewing, TRIAL_INDEX/INDEX.
Notebook-safe (ignores unknown CLI args) and includes a helper run_in_notebook(...).
"""

from __future__ import annotations

import argparse
import csv
import json
import re
from collections import Counter, defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Tuple

import pandas as pd

# ---------------- constants ----------------

Pair = Tuple[str, str]
Triple = Tuple[str, str, str]
VALID_CONDITIONS = {"full", "central", "peripheral"}

TRIAL_KEY_CANDIDATES = [
    "trial_index", "trial_num", "trial_number", "trial",
    "trialid", "trial_id", "idx", "index"
]
FIRST_COL_CANDS = [
    "first_image", "image_first", "first", "img1", "image1",
    "left_image", "left", "filename1"
]
SECOND_COL_CANDS = [
    "second_image", "image_second", "second", "img2", "image2",
    "right_image", "right", "filename2"
]
COND_COL_CANDS = ["viewing_condition", "condition", "view", "viewing"]


# ---------------- utils ----------------

def normalize_token(s: str, strip_ext: bool) -> str:
    t = " ".join(str(s).strip().lower().split())
    if strip_ext:
        p = Path(t)
        if p.suffix:
            t = p.with_suffix("").as_posix()
    return t


def unordered_pair(a: str, b: str) -> Pair:
    return tuple(sorted((a, b)))  # type: ignore[return-value]


def pick_first_present(cols: List[str], cands: List[str]) -> Optional[str]:
    for c in cands:
        if c in cols:
            return c
    return None


def normalize_columns(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    df.columns = [re.sub(r"\s+", "_", c.strip().lower()) for c in df.columns]
    return df


def read_any_tsv(path: Path) -> pd.DataFrame:
    try:
        return pd.read_csv(path, sep="\t", dtype=str, encoding="utf-8")
    except Exception:
        return pd.read_csv(path, dtype=str, encoding="utf-8")


def infer_condition_from_name(name: str) -> Optional[str]:
    s = name.lower()
    for c in VALID_CONDITIONS:
        if c in s:
            return c
    return None


# ---------------- loaders ----------------

def load_trials_json(json_path: Path, verbose: bool = False) -> List[Triple]:
    if not json_path.exists():
        if verbose:
            print(f"[verbose] JSON not found: {json_path}")
        return []
    with json_path.open("r", encoding="utf-8") as f:
        data = json.load(f)
    out: List[Triple] = []
    for t in data if isinstance(data, list) else []:
        a, b = t.get("first_image"), t.get("second_image")
        cond = str(t.get("viewing_condition", "")).strip().lower()
        if a and b and cond in VALID_CONDITIONS:
            out.append((str(a), str(b), cond))
    if verbose:
        print(f"[verbose] JSON pairs: {len(out)} from {json_path}")
        if out[:3]:
            print(f"[verbose] JSON sample: {out[:3]}")
    return out


def extract_pairs_direct(df: pd.DataFrame, source_name: str, verbose: bool) -> List[Triple]:
    """When both image columns exist in a single TSV/CSV."""
    df = normalize_columns(df)
    cols = df.columns.tolist()
    c_first = pick_first_present(cols, FIRST_COL_CANDS)
    c_second = pick_first_present(cols, SECOND_COL_CANDS)
    c_cond = pick_first_present(cols, COND_COL_CANDS)
    if verbose:
        print(f"[verbose] {source_name}: first={c_first} second={c_second} cond={c_cond or '(filename hint)'}")
    if not (c_first and c_second):
        return []
    out: List[Triple] = []
    take = df[[c_first, c_second] + ([c_cond] if c_cond else [])]
    fname_cond = infer_condition_from_name(source_name)
    for _, row in take.iterrows():
        a, b = row.get(c_first), row.get(c_second)
        if not (isinstance(a, str) and isinstance(b, str)):
            continue
        cond = (str(row.get(c_cond, "")).strip().lower() if c_cond else "") or (fname_cond or "")
        if cond in VALID_CONDITIONS:
            out.append((a, b, cond))
    if verbose:
        print(f"[verbose] {source_name}: extracted {len(out)} triples")
    return out


def scan_tsvs(scan_roots: List[Path], glob_pattern: str, data_root: Path, verbose: bool) -> List[Path]:
    """Return files matching *glob_pattern*.(tsv|csv); auto-fallback to data_root if no hits."""
    candidates: List[Path] = []
    for root in scan_roots:
        if root.exists():
            hits = [p for p in root.rglob("*") if p.is_file()
                    and glob_pattern.lower() in p.name.lower()
                    and p.suffix.lower() in (".tsv", ".csv")]
            if verbose:
                print(f"[verbose] Scanning {root} -> {len(hits)} file(s)")
            candidates.extend(hits)
        else:
            if verbose:
                print(f"[verbose] Scan root not found: {root}")
    if not candidates:
        # fallback: scan entire data_root
        hits = [p for p in data_root.rglob("*") if p.is_file()
                and glob_pattern.lower() in p.name.lower()
                and p.suffix.lower() in (".tsv", ".csv")]
        if verbose:
            print(f"[verbose] Fallback scan {data_root} -> {len(hits)} file(s)")
        candidates.extend(hits)
    # de-duplicate while preserving order
    seen = set()
    uniq: List[Path] = []
    for p in candidates:
        rp = p.resolve()
        if rp not in seen:
            seen.add(rp)
            uniq.append(rp)
    if verbose and uniq[:10]:
        print("[verbose] First few TSV/CSV files:")
        for p in uniq[:10]:
            print(f"         - {p}")
    return uniq


def load_pairs_from_tsvs(paths: List[Path], verbose: bool) -> List[Triple]:
    triples: List[Triple] = []
    for p in paths:
        try:
            df = read_any_tsv(p)
        except Exception as e:
            if verbose:
                print(f"[verbose] Failed reading {p}: {e}")
            continue
        triples.extend(extract_pairs_direct(df, p.name, verbose))
    if verbose:
        print(f"[verbose] Total TSV-derived triples: {len(triples)}")
    return triples


# ---------------- index & reports ----------------

def build_index(triples: List[Triple], strip_ext: bool):
    pair_to_conditions: Dict[Pair, set] = defaultdict(set)
    pair_to_counts: Dict[Pair, Counter] = defaultdict(Counter)
    for a_raw, b_raw, cond in triples:
        if cond not in VALID_CONDITIONS:
            continue
        a = normalize_token(a_raw, strip_ext)
        b = normalize_token(b_raw, strip_ext)
        if not a or not b:
            continue
        p = unordered_pair(a, b)
        pair_to_conditions[p].add(cond)
        pair_to_counts[p][cond] += 1
    return pair_to_conditions, pair_to_counts


def write_pairs_csv(path: Path, rows: Iterable[Pair]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["first_image", "second_image"])
        for a, b in sorted(rows):
            w.writerow([a, b])


def write_master_csv(path: Path, pair_to_conditions, pair_to_counts) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow([
            "first_image","second_image",
            "in_full","in_central","in_peripheral","n_conditions",
            "count_full","count_central","count_peripheral",
        ])
        for a, b in sorted(pair_to_conditions.keys()):
            conds = pair_to_conditions[(a, b)]
            cnts = pair_to_counts[(a, b)]
            w.writerow([
                a, b,
                int("full" in conds), int("central" in conds), int("peripheral" in conds),
                len(conds),
                cnts.get("full", 0), cnts.get("central", 0), cnts.get("peripheral", 0),
            ])


def summarize_and_write(outdir: Path, pair_to_conditions, pair_to_counts) -> pd.DataFrame:
    rows = []
    for (a, b), conds in pair_to_conditions.items():
        cnts = pair_to_counts[(a, b)]
        rows.append({
            "first_image": a, "second_image": b,
            "in_full": int("full" in conds),
            "in_central": int("central" in conds),
            "in_peripheral": int("peripheral" in conds),
            "n_conditions": len(conds),
            "count_full": cnts.get("full", 0),
            "count_central": cnts.get("central", 0),
            "count_peripheral": cnts.get("peripheral", 0),
        })
    master = pd.DataFrame(rows).sort_values(["n_conditions","first_image","second_image"],
                                            ascending=[False,True,True])

    outdir.mkdir(parents=True, exist_ok=True)
    write_master_csv(outdir / "pairs_master.csv", pair_to_conditions, pair_to_counts)

    in2 = [(r["first_image"], r["second_image"]) for _, r in master[master["n_conditions"] == 2].iterrows()]
    in3 = [(r["first_image"], r["second_image"]) for _, r in master[master["n_conditions"] == 3].iterrows()]
    u_full = [(r["first_image"], r["second_image"]) for _, r in master[(master["n_conditions"] == 1) & (master["in_full"] == 1)].iterrows()]
    u_central = [(r["first_image"], r["second_image"]) for _, r in master[(master["n_conditions"] == 1) & (master["in_central"] == 1)].iterrows()]
    u_periph = [(r["first_image"], r["second_image"]) for _, r in master[(master["n_conditions"] == 1) & (master["in_peripheral"] == 1)].iterrows()]

    write_pairs_csv(outdir / "pairs_in_2_conditions.csv", in2)
    write_pairs_csv(outdir / "pairs_in_3_conditions.csv", in3)
    write_pairs_csv(outdir / "unique_full.csv", u_full)
    write_pairs_csv(outdir / "unique_central.csv", u_central)
    write_pairs_csv(outdir / "unique_peripheral.csv", u_periph)

    print("=== Pair Overlap Summary ===")
    print(f"Total unique unordered pairs: {len(master)}")
    print(f"Pairs in exactly 2 conditions: {len(in2)}")
    print(f"Pairs in all 3 conditions:     {len(in3)}")
    print(f"Unique to FULL / CENTRAL / PERIPHERAL: {len(u_full)} / {len(u_central)} / {len(u_periph)}")
    print(f"Reports saved to: {outdir.resolve()}")
    return master


# ---------------- CLI ----------------

def build_argparser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(
        description="Detect same image pairs across viewing conditions from JSON and *testing*.tsv/.csv files.",
        allow_abbrev=False,
    )
    p.add_argument("--data-root", type=Path, default=Path("."), help="Project root.")
    p.add_argument("--json", type=str, default="Testing/testing.json", help="Relative to --data-root.")
    p.add_argument(
        "--scan-tsv-dir", type=Path, nargs="*", default=[Path("THREE_JSON_OUTPUT/data/three_json_output")],
        help="One or more folders to recursively scan. If none found or folder missing, falls back to scanning --data-root."
    )
    p.add_argument("--tsv-glob", type=str, default="testing", help="Substring to match in filenames (default: 'testing').")
    p.add_argument("--out", type=Path, default=Path("pair_overlap_reports"), help="Output directory.")
    p.add_argument("--strip-ext", action="store_true", help="Match regardless of file extension.")
    p.add_argument("--verbose", action="store_true", help="Print diagnostic info.")
    return p


def main(argv: Optional[Sequence[str]] = None) -> int:
    ap = build_argparser()
    args, unknown = ap.parse_known_args(argv)  # ignore notebook-injected args like --f=...
    if unknown:
        print(f"[info] Ignoring unknown CLI args: {unknown}")

    data_root = args.data_root.resolve()
    json_path = (data_root / args.json).resolve()

    # TSV scan
    scan_roots = [data_root / d for d in args.scan_tsv_dir] if args.scan_tsv_dir else [data_root]
    tsv_paths = scan_tsvs(scan_roots, args.tsv_glob, data_root, verbose=args.verbose)

    # Load
    triples: List[Triple] = []
    triples += load_trials_json(json_path, verbose=args.verbose)
    triples += load_pairs_from_tsvs(tsv_paths, verbose=args.verbose)

    if args.verbose:
        print(f"[verbose] Total triples gathered: {len(triples)}")

    if not triples:
        print("No pairs found in JSON/TSV inputs that match expected conditions.")
        return 0

    pair_to_conditions, pair_to_counts = build_index(triples, args.strip_ext)
    summarize_and_write(args.out, pair_to_conditions, pair_to_counts)
    return 0


# ---------------- Notebook helper ----------------

def run_in_notebook(
    data_root: str | Path = ".",
    json_rel: str = "Testing/testing.json",
    scan_dirs: list[str | Path] | None = None,
    tsv_glob: str = "testing",
    outdir: str | Path = "pair_overlap_reports",
    strip_ext: bool = True,
    verbose: bool = True,
) -> pd.DataFrame:
    """
    Example:
        from scripts.detect_same_pairs_across_conditions import run_in_notebook
        df = run_in_notebook(data_root=".", scan_dirs=["THREE_JSON_OUTPUT"], verbose=True)
    """
    root = Path(data_root).resolve()
    dirs = [root / Path(d) for d in (scan_dirs or ["THREE_JSON_OUTPUT/data/three_json_output"])]
    tsv_paths = scan_tsvs(dirs, tsv_glob, root, verbose=verbose)
    triples: List[Triple] = []
    triples += load_trials_json((root / json_rel).resolve(), verbose=verbose)
    triples += load_pairs_from_tsvs(tsv_paths, verbose=verbose)
    if not triples:
        print("No pairs found in JSON/TSV inputs that match expected conditions.")
        return pd.DataFrame()
    pair_to_conditions, pair_to_counts = build_index(triples, strip_ext)
    return summarize_and_write(Path(outdir), pair_to_conditions, pair_to_counts)


if __name__ == "__main__":
    raise SystemExit(main())


[info] Ignoring unknown CLI args: ['--f=/Users/daisybuathatseephol/Library/Jupyter/runtime/kernel-v3f3702a772b0e7063e8c6aabbccc98f0eb003b113.json']
No pairs found in JSON/TSV inputs that match expected conditions.


SystemExit: 0

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [1]:
import pandas as pd
from pathlib import Path
import json
from collections import defaultdict, Counter

def analyze_image_pairs(root_path=".", json_folder="Testing", data_folder="data files", verbose=True):
    """
    Find and analyze duplicate image pairs across viewing conditions.
    
    Parameters:
    - root_path: Base directory to search from
    - json_folder: Folder name containing JSON files (default: "Testing")
    - data_folder: Folder name containing TSV/CSV files (default: "data files")
    """
    root = Path(root_path).resolve()
    print(f"üîç Searching from root: {root}\n")
    
    # Look for the specific folders
    json_dirs = list(root.rglob(json_folder))
    data_dirs = list(root.rglob(data_folder))
    
    print(f"üìÅ Found {len(json_dirs)} '{json_folder}' folder(s)")
    print(f"üìÅ Found {len(data_dirs)} '{data_folder}' folder(s)")
    
    # Collect all JSON files from Testing folders
    json_files = []
    for jdir in json_dirs:
        json_files.extend(list(jdir.glob("*.json")))
    
    # Collect all TSV/CSV files from data files folders
    data_files = []
    for ddir in data_dirs:
        data_files.extend(list(ddir.glob("*.tsv")))
        data_files.extend(list(ddir.glob("*.csv")))
    
    # Also search in root and immediate subdirectories if folders not found
    if not json_files:
        print(f"‚ö†Ô∏è No JSON files found in '{json_folder}' folders, searching everywhere...")
        json_files = list(root.rglob("*.json"))
    
    if not data_files:
        print(f"‚ö†Ô∏è No TSV/CSV files found in '{data_folder}' folders, searching everywhere...")
        data_files = list(root.rglob("*.tsv")) + list(root.rglob("*.csv"))
    
    print(f"\nüìÑ Found {len(json_files)} JSON file(s)")
    if json_files:
        for f in json_files[:5]:
            print(f"  - {f.name} (in {f.parent.name}/)")
        if len(json_files) > 5:
            print(f"  ... and {len(json_files)-5} more")
    
    print(f"\nüìä Found {len(data_files)} TSV/CSV file(s)")
    if data_files:
        for f in data_files[:5]:
            print(f"  - {f.name} (in {f.parent.name}/)")
        if len(data_files) > 5:
            print(f"  ... and {len(data_files)-5} more")
    
    if not json_files and not data_files:
        print("\n‚ùå No data files found!")
        print("Make sure you're running this from the right directory.")
        print("Current directory structure:")
        # Show directory structure
        for p in sorted(root.iterdir())[:10]:
            if p.is_dir():
                print(f"  üìÅ {p.name}/")
                for sp in sorted(p.iterdir())[:5]:
                    print(f"     - {sp.name}")
        return pd.DataFrame()
    
    all_pairs = []
    
    # Process JSON files
    print("\n" + "="*60)
    print("PROCESSING JSON FILES")
    print("="*60)
    
    for json_file in json_files:
        print(f"\nüìÑ Processing: {json_file.name}")
        
        try:
            with open(json_file) as f:
                data = json.load(f)
            
            # Handle different JSON structures
            if isinstance(data, list):
                items = data
                print(f"  Format: List with {len(items)} items")
            elif isinstance(data, dict):
                # Try various possible keys
                items = data.get("trials") or data.get("data") or data.get("items") or []
                print(f"  Format: Dict with {len(items)} items in '{list(data.keys())[:5]}'")
            else:
                items = []
            
            # If still no items but dict has direct data
            if not items and isinstance(data, dict):
                items = [data]  # Treat the dict itself as a single item
            
            found_count = 0
            sample_item = None
            
            for item in items:
                if not isinstance(item, dict):
                    continue
                
                if not sample_item:
                    sample_item = item
                
                # Try ALL possible key names for images
                img1 = (item.get("first_image") or item.get("image1") or 
                       item.get("left_image") or item.get("filename1") or
                       item.get("image_left") or item.get("leftimage") or
                       item.get("img1") or item.get("image_1") or
                       item.get("imageFirst") or item.get("first") or
                       item.get("image_a") or item.get("stim1"))
                
                img2 = (item.get("second_image") or item.get("image2") or 
                       item.get("right_image") or item.get("filename2") or
                       item.get("image_right") or item.get("rightimage") or
                       item.get("img2") or item.get("image_2") or
                       item.get("imageSecond") or item.get("second") or
                       item.get("image_b") or item.get("stim2"))
                
                # Try ALL possible key names for condition
                condition = (item.get("viewing_condition") or item.get("viewing") or 
                           item.get("condition") or item.get("view") or 
                           item.get("viewing_cond") or item.get("viewingCondition") or
                           item.get("view_condition") or item.get("trial_type") or "")
                
                if img1 and img2:
                    condition = str(condition).strip().lower()
                    # Also check for variations of condition names
                    if condition in ["full", "central", "peripheral", "f", "c", "p"]:
                        # Normalize single letters
                        if condition == "f": condition = "full"
                        elif condition == "c": condition = "central"
                        elif condition == "p": condition = "peripheral"
                        
                        # Store as sorted pair for unordered comparison
                        pair = tuple(sorted([str(img1).strip().lower(), str(img2).strip().lower()]))
                        all_pairs.append({
                            "image1": pair[0],
                            "image2": pair[1],
                            "condition": condition,
                            "source": json_file.name,
                            "source_type": "json"
                        })
                        found_count += 1
            
            print(f"  ‚úì Found {found_count} valid pairs")
            
            # Show sample if no pairs found
            if found_count == 0 and sample_item:
                print(f"  ‚ö†Ô∏è No valid pairs extracted. Sample item keys:")
                print(f"     {list(sample_item.keys())[:15]}")
                if verbose:
                    print(f"  Sample values:")
                    for k, v in list(sample_item.items())[:5]:
                        print(f"     {k}: {str(v)[:50]}")
                
        except Exception as e:
            print(f"  ‚ùå Error: {e}")
    
    # Process TSV/CSV files
    print("\n" + "="*60)
    print("PROCESSING TSV/CSV FILES")
    print("="*60)
    
    for data_file in data_files:
        print(f"\nüìä Processing: {data_file.name}")
        
        try:
            # Try reading as TSV first, then CSV
            try:
                df = pd.read_csv(data_file, sep='\t')
                print(f"  Format: TSV with {len(df)} rows")
            except:
                df = pd.read_csv(data_file)
                print(f"  Format: CSV with {len(df)} rows")
            
            if df.empty:
                print(f"  ‚ö†Ô∏è File is empty")
                continue
            
            # Show original columns
            print(f"  Columns: {list(df.columns)[:10]}{'...' if len(df.columns) > 10 else ''}")
            
            # Normalize column names
            df.columns = [col.strip().lower().replace(' ', '_').replace('-', '_') for col in df.columns]
            
            # Find image columns - be very flexible
            img1_col = None
            img2_col = None
            cond_col = None
            
            # Look for first image column
            for col in df.columns:
                if any(x in col for x in ["filename1", "first", "image1", "img1", "left", "image_1", "stim1", "image_a"]):
                    img1_col = col
                    break
            
            # Look for second image column  
            for col in df.columns:
                if any(x in col for x in ["filename2", "second", "image2", "img2", "right", "image_2", "stim2", "image_b"]):
                    img2_col = col
                    break
                    
            # Look for condition column
            for col in df.columns:
                if any(x in col for x in ["viewing", "condition", "view", "trial_type"]):
                    cond_col = col
                    break
            
            print(f"  Column mapping: img1={img1_col}, img2={img2_col}, cond={cond_col}")
            
            found_count = 0
            if img1_col and img2_col:
                for idx, row in df.iterrows():
                    img1 = str(row[img1_col]).strip() if pd.notna(row[img1_col]) else None
                    img2 = str(row[img2_col]).strip() if pd.notna(row[img2_col]) else None
                    
                    # Skip if images are empty or 'nan'
                    if not img1 or not img2 or img1.lower() == 'nan' or img2.lower() == 'nan':
                        continue
                    
                    # Get condition from column or infer from filename
                    condition = None
                    if cond_col and pd.notna(row[cond_col]):
                        condition = str(row[cond_col]).strip().lower()
                    
                    # If no condition column, try to infer from filename
                    if not condition or condition not in ["full", "central", "peripheral", "f", "c", "p"]:
                        fname_lower = data_file.name.lower()
                        if "full" in fname_lower or "_f_" in fname_lower or "_f." in fname_lower:
                            condition = "full"
                        elif "central" in fname_lower or "_c_" in fname_lower or "_c." in fname_lower:
                            condition = "central"
                        elif "peripheral" in fname_lower or "periph" in fname_lower or "_p_" in fname_lower or "_p." in fname_lower:
                            condition = "peripheral"
                    
                    # Normalize single letter conditions
                    if condition == "f": condition = "full"
                    elif condition == "c": condition = "central"
                    elif condition == "p": condition = "peripheral"
                    
                    if condition in ["full", "central", "peripheral"]:
                        pair = tuple(sorted([img1.lower(), img2.lower()]))
                        all_pairs.append({
                            "image1": pair[0],
                            "image2": pair[1],
                            "condition": condition,
                            "source": data_file.name,
                            "source_type": "csv"
                        })
                        found_count += 1
                
                print(f"  ‚úì Found {found_count} valid pairs")
                
                if found_count == 0 and not df.empty:
                    print(f"  Sample row data:")
                    sample_row = df.iloc[0]
                    if img1_col and img2_col:
                        print(f"    {img1_col}: {sample_row[img1_col]}")
                        print(f"    {img2_col}: {sample_row[img2_col]}")
                        if cond_col:
                            print(f"    {cond_col}: {sample_row[cond_col]}")
            else:
                print(f"  ‚ö†Ô∏è Could not find required image columns")
                print(f"     Available columns: {df.columns.tolist()[:15]}")
                
        except Exception as e:
            print(f"  ‚ùå Error: {e}")
    
    if not all_pairs:
        print("\n‚ùå No valid pairs found across any files!")
        print("\nTroubleshooting:")
        print("1. Check that viewing conditions are: 'full', 'central', 'peripheral' (or 'f', 'c', 'p')")
        print("2. Check that image columns contain actual filenames")
        print("3. Try setting verbose=True to see more details")
        return pd.DataFrame()
    
    # Create analysis
    print(f"\nüîÑ Analyzing {len(all_pairs)} total pairs...")
    df_all = pd.DataFrame(all_pairs)
    
    # Show source breakdown
    print(f"\nPairs by source type:")
    print(f"  JSON files: {len(df_all[df_all['source_type'] == 'json'])}")
    print(f"  CSV files: {len(df_all[df_all['source_type'] == 'csv'])}")
    
    # Group by pair and count conditions
    pair_analysis = df_all.groupby(['image1', 'image2']).agg({
        'condition': lambda x: set(x),
        'source': lambda x: list(x),
        'source_type': 'count'
    }).reset_index()
    
    pair_analysis.columns = ['image1', 'image2', 'conditions', 'sources', 'total_occurrences']
    pair_analysis['n_conditions'] = pair_analysis['conditions'].apply(len)
    
    # Add binary flags for each condition
    pair_analysis['in_full'] = pair_analysis['conditions'].apply(lambda x: 'full' in x).astype(int)
    pair_analysis['in_central'] = pair_analysis['conditions'].apply(lambda x: 'central' in x).astype(int)
    pair_analysis['in_peripheral'] = pair_analysis['conditions'].apply(lambda x: 'peripheral' in x).astype(int)
    
    # Convert conditions set to string
    pair_analysis['conditions_str'] = pair_analysis['conditions'].apply(lambda x: ', '.join(sorted(x)))
    
    # Sort by number of conditions (descending)
    pair_analysis = pair_analysis.sort_values(['n_conditions', 'image1', 'image2'], 
                                              ascending=[False, True, True])
    
    # Print summary
    print("\n" + "="*60)
    print("üìä ANALYSIS RESULTS")
    print("="*60)
    print(f"Total unique pairs: {len(pair_analysis)}")
    print(f"Pairs in 1 condition only: {sum(pair_analysis['n_conditions'] == 1)}")
    print(f"Pairs in 2 conditions: {sum(pair_analysis['n_conditions'] == 2)}")
    print(f"Pairs in all 3 conditions: {sum(pair_analysis['n_conditions'] == 3)}")
    
    # Show duplicates
    duplicates = pair_analysis[pair_analysis['n_conditions'] > 1]
    if not duplicates.empty:
        print(f"\nüîç DUPLICATE PAIRS (in multiple conditions): {len(duplicates)} pairs")
        
        in_all_3 = duplicates[duplicates['n_conditions'] == 3]
        if not in_all_3.empty:
            print(f"\n‚ú® Pairs in ALL 3 conditions ({len(in_all_3)} total):")
            for i, (_, row) in enumerate(in_all_3.head(5).iterrows(), 1):
                print(f"  {i}. {row['image1'][:40]}")
                print(f"     {row['image2'][:40]}")
            if len(in_all_3) > 5:
                print(f"  ... and {len(in_all_3)-5} more")
    
    # Save results
    try:
        output_dir = Path("pair_analysis_results")
        output_dir.mkdir(exist_ok=True)
        
        pair_analysis.to_csv(output_dir / "all_pairs.csv", index=False)
        if not duplicates.empty:
            duplicates.to_csv(output_dir / "duplicate_pairs.csv", index=False)
        
        print(f"\nüíæ Results saved to: {output_dir.resolve()}/")
        
    except Exception as e:
        print(f"\n‚ö†Ô∏è Could not save CSV files: {e}")
    
    return pair_analysis

# RUN THE ANALYSIS - adjust the folder names if needed
df = analyze_image_pairs(
    root_path=".",  # or use "../.." to go up directories
    json_folder="Testing",  # folder containing JSON files
    data_folder="data files",  # folder containing TSV/CSV files
    verbose=True
)

# Show the duplicate pairs
if not df.empty:
    duplicates = df[df['n_conditions'] > 1]
    if not duplicates.empty:
        print("\nüìã All duplicate pairs:")
        display(duplicates[['image1', 'image2', 'conditions_str', 'n_conditions']].head(20))

üîç Searching from root: /Users/daisybuathatseephol/Documents/three_json_output

üìÅ Found 2 'Testing' folder(s)
üìÅ Found 1 'data files' folder(s)

üìÑ Found 1 JSON file(s)
  - testing.json (in Testing/)

üìä Found 11 TSV/CSV file(s)
  - trialdata_training2_firstdisplayonly.tsv (in data files/)
  - trialdata_testing_firstdisplayonly.tsv (in data files/)
  - fixdata_training1.tsv (in data files/)
  - fix_testing.tsv (in data files/)
  - training2_testing_dimensions.tsv (in data files/)
  ... and 6 more

PROCESSING JSON FILES

üìÑ Processing: testing.json
  Format: List with 2304 items
  ‚úì Found 2304 valid pairs

PROCESSING TSV/CSV FILES

üìä Processing: trialdata_training2_firstdisplayonly.tsv
  ‚ùå Error: 'utf-8' codec can't decode byte 0xff in position 0: invalid start byte

üìä Processing: trialdata_testing_firstdisplayonly.tsv
  ‚ùå Error: 'utf-8' codec can't decode byte 0xff in position 0: invalid start byte

üìä Processing: fixdata_training1.tsv
  ‚ùå Error: 'utf-8' co

In [2]:
import pandas as pd
from pathlib import Path
import json
from collections import defaultdict, Counter

def diagnose_pair_matching(root_path=".", json_folder="Testing", data_folder="data files"):
    """
    Diagnose why pairs aren't matching across files
    """
    root = Path(root_path).resolve()
    print(f"üîç Diagnostic Analysis for: {root}\n")
    
    # Find folders
    json_dirs = list(root.rglob(json_folder))
    data_dirs = list(root.rglob(data_folder))
    
    # Collect files
    json_files = []
    for jdir in json_dirs:
        json_files.extend(list(jdir.glob("*.json")))
    
    data_files = []
    for ddir in data_dirs:
        data_files.extend(list(ddir.glob("*.tsv")))
        data_files.extend(list(ddir.glob("*.csv")))
    
    print(f"Files found:")
    print(f"  JSON: {[f.name for f in json_files]}")
    print(f"  TSV/CSV: {[f.name for f in data_files[:10]]}")
    
    # Store all pairs with their sources
    json_pairs = {}  # condition -> set of pairs
    tsv_pairs = {}   # condition -> set of pairs
    
    # Process JSON
    print("\n" + "="*60)
    print("CHECKING JSON FILE")
    print("="*60)
    
    for json_file in json_files:
        with open(json_file) as f:
            data = json.load(f)
        
        items = data if isinstance(data, list) else []
        print(f"\nüìÑ {json_file.name}: {len(items)} items")
        
        # Check first few items
        if items:
            print(f"  Sample item keys: {list(items[0].keys())}")
            
            # Count by condition
            conditions_found = defaultdict(int)
            sample_pairs = defaultdict(list)
            
            for item in items:
                # Get images - try multiple keys
                img1 = (item.get("first_image") or item.get("image1") or 
                       item.get("filename1") or item.get("left_image"))
                img2 = (item.get("second_image") or item.get("image2") or 
                       item.get("filename2") or item.get("right_image"))
                cond = (item.get("viewing_condition") or item.get("viewing") or 
                       item.get("condition") or "")
                
                if img1 and img2:
                    cond = str(cond).strip().lower()
                    if cond in ["full", "central", "peripheral"]:
                        conditions_found[cond] += 1
                        
                        # Store normalized pair
                        pair = tuple(sorted([str(img1).strip().lower(), str(img2).strip().lower()]))
                        
                        if cond not in json_pairs:
                            json_pairs[cond] = set()
                        json_pairs[cond].add(pair)
                        
                        # Keep sample
                        if len(sample_pairs[cond]) < 3:
                            sample_pairs[cond].append((img1, img2))
            
            print(f"  Conditions found:")
            for cond, count in conditions_found.items():
                print(f"    {cond}: {count} pairs")
                if sample_pairs[cond]:
                    print(f"      Sample pairs:")
                    for img1, img2 in sample_pairs[cond][:2]:
                        print(f"        - {img1[:40]}")
                        print(f"          {img2[:40]}")
    
    # Process TSV/CSV files
    print("\n" + "="*60)
    print("CHECKING TSV/CSV FILES")
    print("="*60)
    
    for data_file in data_files[:5]:  # Check first 5 files
        print(f"\nüìä {data_file.name}")
        
        try:
            # Read file
            try:
                df = pd.read_csv(data_file, sep='\t')
            except:
                df = pd.read_csv(data_file)
            
            print(f"  Shape: {df.shape}")
            print(f"  Columns: {list(df.columns)[:10]}")
            
            # Normalize columns
            df.columns = [col.strip().lower().replace(' ', '_') for col in df.columns]
            
            # Find image columns
            img1_col = None
            img2_col = None
            cond_col = None
            
            for col in df.columns:
                if not img1_col and any(x in col for x in ["filename1", "first", "image1", "left"]):
                    img1_col = col
                if not img2_col and any(x in col for x in ["filename2", "second", "image2", "right"]):
                    img2_col = col
                if not cond_col and any(x in col for x in ["viewing", "condition"]):
                    cond_col = col
            
            print(f"  Mapped: img1={img1_col}, img2={img2_col}, cond={cond_col}")
            
            if img1_col and img2_col:
                # Show sample data
                sample = df.head(3)
                print(f"  Sample rows:")
                for idx, row in sample.iterrows():
                    img1 = str(row[img1_col]) if pd.notna(row[img1_col]) else "NA"
                    img2 = str(row[img2_col]) if pd.notna(row[img2_col]) else "NA"
                    cond_val = str(row[cond_col]) if cond_col and pd.notna(row[cond_col]) else "NA"
                    
                    print(f"    Row {idx}:")
                    print(f"      img1: {img1[:50]}")
                    print(f"      img2: {img2[:50]}")
                    print(f"      condition: {cond_val}")
                
                # Count conditions
                conditions_in_file = defaultdict(int)
                for _, row in df.iterrows():
                    img1 = str(row[img1_col]).strip() if pd.notna(row[img1_col]) else None
                    img2 = str(row[img2_col]).strip() if pd.notna(row[img2_col]) else None
                    
                    if not img1 or not img2 or img1.lower() == 'nan' or img2.lower() == 'nan':
                        continue
                    
                    # Get condition
                    if cond_col and pd.notna(row[cond_col]):
                        cond = str(row[cond_col]).strip().lower()
                    else:
                        # Infer from filename
                        fname = data_file.name.lower()
                        if "full" in fname:
                            cond = "full"
                        elif "central" in fname:
                            cond = "central"
                        elif "peripheral" in fname or "periph" in fname:
                            cond = "peripheral"
                        else:
                            cond = None
                    
                    if cond in ["full", "central", "peripheral"]:
                        conditions_in_file[cond] += 1
                        
                        # Store normalized pair
                        pair = tuple(sorted([img1.lower(), img2.lower()]))
                        if cond not in tsv_pairs:
                            tsv_pairs[cond] = set()
                        tsv_pairs[cond].add(pair)
                
                print(f"  Conditions in this file: {dict(conditions_in_file)}")
                
        except Exception as e:
            print(f"  Error: {e}")
    
    # Compare JSON vs TSV pairs
    print("\n" + "="*60)
    print("COMPARISON: JSON vs TSV/CSV")
    print("="*60)
    
    print("\nPairs per condition:")
    all_conditions = set(json_pairs.keys()) | set(tsv_pairs.keys())
    
    for cond in sorted(all_conditions):
        json_count = len(json_pairs.get(cond, set()))
        tsv_count = len(tsv_pairs.get(cond, set()))
        print(f"\n{cond.upper()}:")
        print(f"  JSON pairs: {json_count}")
        print(f"  TSV pairs: {tsv_count}")
        
        if json_count > 0 and tsv_count > 0:
            # Check for overlaps
            json_set = json_pairs[cond]
            tsv_set = tsv_pairs[cond]
            
            overlap = json_set & tsv_set
            json_only = json_set - tsv_set
            tsv_only = tsv_set - json_set
            
            print(f"  Overlapping pairs: {len(overlap)}")
            print(f"  JSON-only pairs: {len(json_only)}")
            print(f"  TSV-only pairs: {len(tsv_only)}")
            
            if len(overlap) == 0 and json_count > 0 and tsv_count > 0:
                print(f"\n  ‚ö†Ô∏è NO OVERLAP! Checking why...")
                
                # Show samples from each
                print(f"  Sample JSON pairs (first 3):")
                for pair in list(json_set)[:3]:
                    print(f"    - {pair[0][:40]}")
                    print(f"      {pair[1][:40]}")
                
                print(f"  Sample TSV pairs (first 3):")
                for pair in list(tsv_set)[:3]:
                    print(f"    - {pair[0][:40]}")
                    print(f"      {pair[1][:40]}")
    
    # Check for duplicates across conditions within each source
    print("\n" + "="*60)
    print("CHECKING FOR DUPLICATES WITHIN EACH SOURCE")
    print("="*60)
    
    # Check JSON duplicates
    all_json_pairs = []
    for cond, pairs in json_pairs.items():
        for pair in pairs:
            all_json_pairs.append((pair, cond))
    
    pair_conditions = defaultdict(set)
    for pair, cond in all_json_pairs:
        pair_conditions[pair].add(cond)
    
    duplicates = {pair: conds for pair, conds in pair_conditions.items() if len(conds) > 1}
    
    print(f"\nJSON duplicates (pairs in multiple conditions): {len(duplicates)}")
    if duplicates:
        for i, (pair, conds) in enumerate(list(duplicates.items())[:5], 1):
            print(f"  {i}. Pair in conditions: {', '.join(sorted(conds))}")
            print(f"     {pair[0][:40]}")
            print(f"     {pair[1][:40]}")
    
    return json_pairs, tsv_pairs

# Run the diagnostic
json_pairs, tsv_pairs = diagnose_pair_matching(".")

print("\n" + "="*60)
print("üí° DIAGNOSIS COMPLETE")
print("="*60)
print("\nPossible issues to check:")
print("1. Image filenames might have different formats between JSON and TSV")
print("2. File extensions might be included in one source but not the other")
print("3. Path information might be included in one source but not the other")
print("4. Case sensitivity issues (uppercase vs lowercase)")
print("5. The TSV files might not have been processed correctly")

üîç Diagnostic Analysis for: /Users/daisybuathatseephol/Documents/three_json_output

Files found:
  JSON: ['testing.json']
  TSV/CSV: ['trialdata_training2_firstdisplayonly.tsv', 'trialdata_testing_firstdisplayonly.tsv', 'fixdata_training1.tsv', 'fix_testing.tsv', 'training2_testing_dimensions.tsv', 'fixdata_training2.tsv', 'trialdata_testing_seconddisplayonly.tsv', 'msgdata_training2.tsv', 'trialdata_training2_seconddisplayonly.tsv', 'training1_dimensions.tsv']

CHECKING JSON FILE

üìÑ testing.json: 2304 items
  Sample item keys: ['subject_id', 'trial_index', 'fix_index', 'fix_start_ms', 'fix_x', 'fix_y', 'fix_dur_ms', 'fix_pupil', 'saccade_dx', 'saccade_dy', 'saccade_amp', 'saccade_dir_deg', 'acc', 'subj_answer', 'paradigm', 'viewing_condition', 'eye_tracked', 'test_image_on_ms', 'test_image_fixation_idx', 'age', 'gender', 'max_calibration_error', 'avg_calibration_error', 'correct_response', 'first_image', 'second_image']
  Conditions found:
    central: 768 pairs
      Sample pair

In [3]:
import pandas as pd
from pathlib import Path
import json
from collections import defaultdict, Counter

def find_duplicate_pairs(root_path="."):
    """
    Find duplicate pairs across viewing conditions from the JSON file.
    Since TSV files have issues, we'll focus on the JSON data.
    """
    root = Path(root_path).resolve()
    print(f"üîç Analyzing pairs from: {root}\n")
    
    # Find the JSON file
    json_files = list(root.rglob("testing.json"))
    if not json_files:
        print("‚ùå testing.json not found!")
        return pd.DataFrame()
    
    json_file = json_files[0]
    print(f"üìÑ Reading: {json_file}\n")
    
    # Load JSON data
    with open(json_file) as f:
        data = json.load(f)
    
    print(f"Total items in JSON: {len(data)}\n")
    
    # Process all pairs
    all_pairs = []
    
    for item in data:
        # Get images
        img1 = item.get("first_image", "")
        img2 = item.get("second_image", "")
        condition = item.get("viewing_condition", "")
        
        if img1 and img2 and condition:
            # Remove .jpg extension and normalize
            img1_clean = str(img1).replace('.jpg', '').strip().lower()
            img2_clean = str(img2).replace('.jpg', '').strip().lower()
            condition_clean = str(condition).strip().lower()
            
            if condition_clean in ["full", "central", "peripheral"]:
                # Create sorted pair (for unordered comparison)
                pair = tuple(sorted([img1_clean, img2_clean]))
                
                all_pairs.append({
                    "img1": pair[0],
                    "img2": pair[1],
                    "condition": condition_clean,
                    "img1_original": img1,
                    "img2_original": img2
                })
    
    if not all_pairs:
        print("‚ùå No valid pairs found!")
        return pd.DataFrame()
    
    # Create DataFrame
    df = pd.DataFrame(all_pairs)
    
    print(f"Valid pairs extracted: {len(df)}")
    print(f"Pairs by condition:")
    for cond in ["full", "central", "peripheral"]:
        count = len(df[df['condition'] == cond])
        print(f"  {cond}: {count}")
    
    # Find duplicates - group by pair and see which appear in multiple conditions
    pair_analysis = df.groupby(['img1', 'img2']).agg({
        'condition': lambda x: sorted(list(x)),
        'img1_original': 'first',
        'img2_original': 'first'
    }).reset_index()
    
    pair_analysis['n_conditions'] = pair_analysis['condition'].apply(len)
    pair_analysis['conditions_str'] = pair_analysis['condition'].apply(lambda x: ', '.join(x))
    
    # Add flags for each condition
    pair_analysis['in_full'] = pair_analysis['condition'].apply(lambda x: 'full' in x).astype(int)
    pair_analysis['in_central'] = pair_analysis['condition'].apply(lambda x: 'central' in x).astype(int)
    pair_analysis['in_peripheral'] = pair_analysis['condition'].apply(lambda x: 'peripheral' in x).astype(int)
    
    # Sort by number of conditions
    pair_analysis = pair_analysis.sort_values(['n_conditions', 'img1'], ascending=[False, True])
    
    # Get unique pairs per condition
    unique_per_condition = {}
    for cond in ["full", "central", "peripheral"]:
        cond_pairs = set(df[df['condition'] == cond][['img1', 'img2']].apply(tuple, axis=1))
        unique_per_condition[cond] = cond_pairs
    
    # Find overlaps
    full_central = unique_per_condition['full'] & unique_per_condition['central']
    full_peripheral = unique_per_condition['full'] & unique_per_condition['peripheral']
    central_peripheral = unique_per_condition['central'] & unique_per_condition['peripheral']
    all_three = unique_per_condition['full'] & unique_per_condition['central'] & unique_per_condition['peripheral']
    
    print("\n" + "="*60)
    print("üìä DUPLICATE ANALYSIS RESULTS")
    print("="*60)
    
    print(f"\nTotal unique pairs: {len(pair_analysis)}")
    print(f"Pairs in 1 condition only: {sum(pair_analysis['n_conditions'] == 1)}")
    print(f"Pairs in 2 conditions: {sum(pair_analysis['n_conditions'] == 2)}")
    print(f"Pairs in all 3 conditions: {sum(pair_analysis['n_conditions'] == 3)}")
    
    print(f"\nOverlaps between condition pairs:")
    print(f"  Full ‚à© Central: {len(full_central)}")
    print(f"  Full ‚à© Peripheral: {len(full_peripheral)}")
    print(f"  Central ‚à© Peripheral: {len(central_peripheral)}")
    print(f"  All three: {len(all_three)}")
    
    # Show duplicates
    duplicates = pair_analysis[pair_analysis['n_conditions'] > 1]
    
    if len(duplicates) > 0:
        print(f"\nüîç DUPLICATE PAIRS (in multiple conditions): {len(duplicates)} pairs")
        
        # Pairs in all 3
        in_all_3 = duplicates[duplicates['n_conditions'] == 3]
        if len(in_all_3) > 0:
            print(f"\n‚ú® Pairs in ALL 3 conditions: {len(in_all_3)}")
            for i, (_, row) in enumerate(in_all_3.head(5).iterrows(), 1):
                print(f"\n  {i}. Pair:")
                print(f"     Image 1: {row['img1_original']}")
                print(f"     Image 2: {row['img2_original']}")
        
        # Pairs in exactly 2
        in_2 = duplicates[duplicates['n_conditions'] == 2]
        if len(in_2) > 0:
            print(f"\nüìç Pairs in exactly 2 conditions: {len(in_2)}")
            
            # Break down by which 2 conditions
            fc = in_2[(in_2['in_full'] == 1) & (in_2['in_central'] == 1) & (in_2['in_peripheral'] == 0)]
            fp = in_2[(in_2['in_full'] == 1) & (in_2['in_peripheral'] == 1) & (in_2['in_central'] == 0)]
            cp = in_2[(in_2['in_central'] == 1) & (in_2['in_peripheral'] == 1) & (in_2['in_full'] == 0)]
            
            print(f"  Full + Central: {len(fc)}")
            print(f"  Full + Peripheral: {len(fp)}")
            print(f"  Central + Peripheral: {len(cp)}")
            
            if len(in_2) > 0:
                print(f"\n  Examples:")
                for i, (_, row) in enumerate(in_2.head(3).iterrows(), 1):
                    print(f"  {i}. {row['conditions_str']}:")
                    print(f"     {row['img1_original']}")
                    print(f"     {row['img2_original']}")
    else:
        print("\n‚úÖ No duplicate pairs found across conditions!")
        print("Each image pair appears in only one viewing condition.")
    
    # Save results
    try:
        output_dir = Path("pair_analysis_results")
        output_dir.mkdir(exist_ok=True)
        
        # Save all pairs
        pair_analysis.to_csv(output_dir / "all_pairs_analysis.csv", index=False)
        
        # Save duplicates if any
        if len(duplicates) > 0:
            duplicates.to_csv(output_dir / "duplicate_pairs.csv", index=False)
            print(f"\nüíæ Results saved to {output_dir}/")
        
        # Save summary
        with open(output_dir / "summary.txt", "w") as f:
            f.write(f"Total unique pairs: {len(pair_analysis)}\n")
            f.write(f"Pairs in 1 condition: {sum(pair_analysis['n_conditions'] == 1)}\n")
            f.write(f"Pairs in 2 conditions: {sum(pair_analysis['n_conditions'] == 2)}\n")
            f.write(f"Pairs in 3 conditions: {sum(pair_analysis['n_conditions'] == 3)}\n")
            f.write(f"\nOverlaps:\n")
            f.write(f"Full ‚à© Central: {len(full_central)}\n")
            f.write(f"Full ‚à© Peripheral: {len(full_peripheral)}\n")
            f.write(f"Central ‚à© Peripheral: {len(central_peripheral)}\n")
            f.write(f"All three: {len(all_three)}\n")
            
    except Exception as e:
        print(f"Could not save files: {e}")
    
    return pair_analysis

# Run the analysis
df = find_duplicate_pairs(".")

# Show the dataframe
if not df.empty:
    print("\nüìã First 10 pairs:")
    display(df[['img1_original', 'img2_original', 'conditions_str', 'n_conditions']].head(10))

üîç Analyzing pairs from: /Users/daisybuathatseephol/Documents/three_json_output

üìÑ Reading: /Users/daisybuathatseephol/Documents/three_json_output/Testing/testing.json

Total items in JSON: 2304

Valid pairs extracted: 2304
Pairs by condition:
  full: 768
  central: 768
  peripheral: 768

üìä DUPLICATE ANALYSIS RESULTS

Total unique pairs: 72
Pairs in 1 condition only: 0
Pairs in 2 conditions: 0
Pairs in all 3 conditions: 0

Overlaps between condition pairs:
  Full ‚à© Central: 0
  Full ‚à© Peripheral: 0
  Central ‚à© Peripheral: 0
  All three: 0

üîç DUPLICATE PAIRS (in multiple conditions): 72 pairs

üíæ Results saved to pair_analysis_results/

üìã First 10 pairs:


Unnamed: 0,img1_original,img2_original,conditions_str,n_conditions
0,002fac9e85104f41a7843820e9e87ee3.jpg,30e447f9321a4c74880915cc77f54950.jpg,"full, full, full, full, full, full, full, full...",32
1,6b94d252b54242748bfb2c611bd3368f.jpg,01f6d2022368434aacc21db2a6ba4fc9.jpg,"peripheral, peripheral, peripheral, peripheral...",32
2,374cad12d807421b9e8295785d6bc44f.jpg,0260cc5cc34344afb50cfaeda417d65b.jpg,"full, full, full, full, full, full, full, full...",32
3,1098187526484c62b20eb0f59faec6ef.jpg,0693816789364e83bc559e435816ec8f.jpg,"peripheral, peripheral, peripheral, peripheral...",32
4,78765bd55e9a4ec7a9c5f5590a36b386.jpg,07150a02be164be4bd0e325cb2428430.jpg,"full, full, full, full, full, full, full, full...",32
5,6f50c2b038564cd19a0873bf396f1549.jpg,0c5a770e3f57463f84a9897407744692.jpg,"peripheral, peripheral, peripheral, peripheral...",32
6,15ef3e209a84417282278db5a1e3b412.jpg,0e68f17fae9949888cd2017b9b5a6754.jpg,"central, central, central, central, central, c...",32
7,0e68f17fae9949888cd2017b9b5a6754.jpg,1e014f55dac948ccb2e9053666e123b4.jpg,"full, full, full, full, full, full, full, full...",32
8,8b20a875626e420fa0abf64466663c6e.jpg,0f4c34f941f243c5b5f19c25924a9202.jpg,"peripheral, peripheral, peripheral, peripheral...",32
9,11b33ba71fc34d88b6a2fd4183ec3cea.jpg,1a8b5b4beb2844bfa7d6d7dd163a829f.jpg,"central, central, central, central, central, c...",32
