# PairTally Dataset Demonstration

## Overview

PairTally is the first benchmark specifically designed to evaluate fine-grained visual counting capabilities in computer vision models. This notebook provides a comprehensive demonstration of the dataset structure, visualization capabilities, and evaluation pipeline.

### Paper

**Can Current AI Models Count What We Mean, Not What They See? A Benchmark and Systematic Evaluation**  
Gia Khanh Nguyen, Yifeng Huang, Minh Hoai  
Digital Image Computing: Techniques and Applications (DICTA) 2025

### Dataset Specifications

- **Total Images**: 681 high-resolution images
- **Categories**: 54 object categories across 98 subcategories
- **Task Types**: Inter-category (different objects) and Intra-category (same object, different attributes)
- **Attribute Differences**: Color (43.5%), Shape/Texture (42.5%), Size (14.0%)

### Key Findings

Current state-of-the-art models achieve a Mean Absolute Error (MAE) of 53.07, revealing critical gaps in fine-grained visual understanding and discrimination capabilities.

## 1. Environment Setup and Dependencies

In [None]:
%pip install pandas matplotlib seaborn scikit-learn

In [None]:
# Standard library imports
import os
import json
import random
import warnings
from pathlib import Path
from typing import Dict, List, Tuple, Optional

# Data handling
import numpy as np
import pandas as pd

# Visualization
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Rectangle
import seaborn as sns

# Image processing
from PIL import Image
import cv2

# Progress bar
from tqdm.notebook import tqdm

# Configure visualization settings
# NOTE(review): 'seaborn-v0_8-whitegrid' is a matplotlib style-sheet name; presumably it
# requires a matplotlib release that ships the "seaborn-v0_8-*" styles — confirm.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")
# NOTE(review): a blanket 'ignore' filter hides every warning category for the rest of the
# session, including deprecation notices — consider narrowing it.
warnings.filterwarnings('ignore')

# Set random seed for reproducibility
# Seeds both the stdlib `random` module and NumPy's global RNG; the random image
# selection further down draws from the stdlib generator.
random.seed(42)
np.random.seed(42)

print("Environment configured successfully")

## 2. Dataset Configuration

In [None]:
# Configure dataset paths
# BASE_DIR is the absolute form of the current working directory, so every path below is
# relative to wherever the notebook kernel was started.
BASE_DIR = Path(".").resolve()
DATASET_DIR = BASE_DIR / "dataset" / "pairtally_dataset"
# Annotation JSON files and image files live in sibling sub-directories of the dataset root.
ANNOTATIONS_DIR = DATASET_DIR / "annotations"
IMAGES_DIR = DATASET_DIR / "images"
# Root folder for model code/checkpoints.
MODELS_DIR = BASE_DIR / "models"

def verify_dataset():
    """Check that the dataset folders exist and report the image count.

    Prints one ``ERROR:`` line per problem and returns ``False`` when the
    dataset, annotations or images directory is missing. Otherwise prints a
    short summary containing the number of ``*.jpg`` files found in the images
    directory and returns ``True``.
    """
    # (directory, messages reported when it is absent) in reporting order
    required = (
        (DATASET_DIR, (f"Dataset directory not found: {DATASET_DIR}",)),
        (ANNOTATIONS_DIR, (f"Annotations directory not found: {ANNOTATIONS_DIR}",)),
        (IMAGES_DIR, (
            f"Images directory not found: {IMAGES_DIR}",
            "Download images from: https://drive.google.com/file/d/1TnenXS4yFicjo81NnmClfzgc8ltmmeBv/view",
        )),
    )

    problems = []
    for directory, messages in required:
        if not directory.exists():
            problems.extend(messages)

    if problems:
        for problem in problems:
            print(f"ERROR: {problem}")
        return False

    # Only top-level .jpg files are counted
    jpg_total = sum(1 for _ in IMAGES_DIR.glob("*.jpg"))
    print("Dataset verified successfully")
    print(f"Found {jpg_total} images at {DATASET_DIR}")
    return True

# Run the directory check once; the cells below only touch the dataset when this flag is True.
dataset_ready = verify_dataset()

## 3. Load Dataset Annotations

In [None]:
class PairTallyDataset:
    """PairTally dataset interface for loading and accessing annotations.

    On construction the class reads, from ``<dataset_dir>/annotations``:

    * ``pairtally_annotations_<version>.json``        -> ``annotations``
    * ``pairtally_annotations_inter_<version>.json``  -> ``inter_annotations``
    * ``pairtally_annotations_intra_<version>.json``  -> ``intra_annotations``
    * ``image_metadata.json``                         -> ``metadata``
    * ``filename_mapping.json``                       -> ``filename_mapping``

    Any file that does not exist is represented by an empty dict instead of
    raising, so a partially downloaded dataset still yields a usable object.
    """

    def __init__(self, dataset_dir: Path, version: str = "simple"):
        """Load all annotation and metadata files and print a short summary.

        Args:
            dataset_dir: Dataset root containing ``annotations/`` and ``images/``.
            version: Annotation file variant, "simple" or "augmented".
        """
        self.dataset_dir = dataset_dir
        self.annotations_dir = dataset_dir / "annotations"
        self.images_dir = dataset_dir / "images"
        self.version = version  # "simple" or "augmented"

        # Load annotations
        self.annotations = self._load_annotations()
        self.inter_annotations = self._load_inter_annotations()
        self.intra_annotations = self._load_intra_annotations()

        # Load metadata
        self.metadata = self._load_metadata()
        self.filename_mapping = self._load_filename_mapping()

        print(f"Loaded {len(self.annotations)} total annotations")
        print(f"  Inter-category: {len(self.inter_annotations)} images")
        print(f"  Intra-category: {len(self.intra_annotations)} images")

    def _load_json(self, filename: str) -> Dict:
        """Return the parsed content of ``annotations/<filename>``, or ``{}`` if it is missing."""
        filepath = self.annotations_dir / filename
        if not filepath.exists():
            return {}
        # JSON is UTF-8 by specification; do not rely on the platform's locale encoding.
        with open(filepath, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _load_annotations(self) -> Dict:
        """Load main annotations"""
        return self._load_json(f"pairtally_annotations_{self.version}.json")

    def _load_inter_annotations(self) -> Dict:
        """Load inter-category annotations"""
        return self._load_json(f"pairtally_annotations_inter_{self.version}.json")

    def _load_intra_annotations(self) -> Dict:
        """Load intra-category annotations"""
        return self._load_json(f"pairtally_annotations_intra_{self.version}.json")

    def _load_metadata(self) -> Dict:
        """Load image metadata"""
        return self._load_json("image_metadata.json")

    def _load_filename_mapping(self) -> Dict:
        """Load filename mapping"""
        return self._load_json("filename_mapping.json")

    def get_random_image(self, subset: str = "all") -> Optional[str]:
        """Pick a random image name from the requested subset.

        Args:
            subset: "inter", "intra", or anything else for the full annotation set.

        Returns:
            An image name, or ``None`` when the chosen subset has no entries.
        """
        if subset == "inter":
            source = self.inter_annotations
        elif subset == "intra":
            source = self.intra_annotations
        else:
            source = self.annotations

        images = list(source.keys())
        return random.choice(images) if images else None

    def get_annotation(self, image_name: str) -> Dict:
        """Return the annotation entry for ``image_name`` (``{}`` when unknown)."""
        return self.annotations.get(image_name, {})

    def get_counts(self, image_name: str) -> Tuple[int, int]:
        """Return ``(positive_count, negative_count)`` for an image.

        The counts are the lengths of the ``points`` and ``negative_points``
        lists in the annotation; an unknown image yields ``(0, 0)``.
        """
        anno = self.get_annotation(image_name)
        if not anno:
            return 0, 0
        return len(anno.get('points', [])), len(anno.get('negative_points', []))

    def get_image_path(self, image_name: str) -> Path:
        """Return the path ``<dataset_dir>/images/<image_name>`` (existence is not checked)."""
        return self.images_dir / image_name

# Initialize dataset
# `dataset` is only bound when the directory check succeeded, so every later
# use of it is guarded by the same `dataset_ready` flag.
if dataset_ready:
    dataset = PairTallyDataset(DATASET_DIR, version="simple")
    # The constructor already prints these three counts; they are repeated here
    # under an explicit "Dataset Statistics" heading.
    print("\nDataset Statistics:")
    print(f"Total images: {len(dataset.annotations)}")
    print(f"Inter-category pairs: {len(dataset.inter_annotations)}")
    print(f"Intra-category pairs: {len(dataset.intra_annotations)}")

## 4. Visualization Functions

In [None]:
def visualize_image_with_annotations(dataset: PairTallyDataset, 
                                    image_name: str,
                                    show_boxes: bool = True,
                                    show_points: bool = False,
                                    figsize: Tuple[int, int] = None):
    """
    Show one image with its exemplar boxes, optional points and per-class counts.

    The positive class is drawn in blue and the negative class in red. At most
    three exemplar boxes per class are drawn; a box is given as four corner
    points and rendered as their axis-aligned bounding rectangle.

    Args:
        dataset: PairTallyDataset instance
        image_name: Name of the image to visualize
        show_boxes: Whether to show bounding box examples
        show_points: Whether to show point annotations
        figsize: Figure size; when None a 10-inch-wide figure matching the
            image aspect ratio is used

    Returns:
        ``(positive_points, negative_points)`` from the annotation, or ``None``
        when the image file does not exist.
    """
    image_file = dataset.get_image_path(image_name)
    record = dataset.get_annotation(image_name)

    # Guard clause: nothing to draw without the image file
    if not image_file.exists():
        print(f"ERROR: Image not found: {image_file}")
        return

    pixels = np.array(Image.open(image_file))
    img_h, img_w = pixels.shape[:2]

    # Derive a figure size that keeps the image proportions (+1 inch for the header)
    if figsize is None:
        base_width = 10
        figsize = (base_width, base_width / (img_w / img_h) + 1)

    positive_points = record.get('points', [])
    negative_points = record.get('negative_points', [])
    positive_boxes = record.get('box_examples_coordinates', [])
    negative_boxes = record.get('negative_box_exemples_coordinates', [])
    positive_prompt = record.get('positive_prompt', 'Class 1')
    negative_prompt = record.get('negative_prompt', 'Class 2')

    blue, red = '#0066FF', '#FF0040'

    fig = plt.figure(figsize=figsize, facecolor='white')
    ax = fig.add_subplot(111)
    ax.imshow(pixels)
    ax.axis('off')

    # Header: "<positive>: n | <negative>: m", each part in its class colour
    header_parts = (
        (0.45, f"{positive_prompt}: {len(positive_points)}", 'right', blue, 'bold'),
        (0.5, " | ", 'center', '#666666', 'normal'),
        (0.55, f"{negative_prompt}: {len(negative_points)}", 'left', red, 'bold'),
    )
    for x_pos, label, align, colour, weight in header_parts:
        fig.text(x_pos, 0.95, label,
                 ha=align, va='top', fontsize=14, color=colour,
                 fontweight=weight, transform=fig.transFigure)

    if show_boxes:
        for exemplars, colour in ((positive_boxes, blue), (negative_boxes, red)):
            for corners in exemplars[:3]:
                if len(corners) != 4:
                    continue
                xs = [pt[0] for pt in corners]
                ys = [pt[1] for pt in corners]
                left, top = min(xs), min(ys)
                ax.add_patch(Rectangle((left, top), max(xs) - left, max(ys) - top,
                                       linewidth=2, edgecolor=colour,
                                       facecolor='none'))

    if show_points:
        for class_points, colour in ((positive_points, blue), (negative_points, red)):
            if class_points:
                coords = np.array(class_points)
                ax.scatter(coords[:, 0], coords[:, 1],
                           c=colour, s=8, alpha=0.4, marker='.')

    plt.tight_layout(rect=[0, 0.02, 1, 0.94])
    plt.show()

    return positive_points, negative_points

def visualize_three_images(dataset: PairTallyDataset, 
                          image_names: List[str],
                          show_boxes: bool = True,
                          show_points: bool = False):
    """
    Render three annotated images side by side in a single figure.

    Each panel shows the image, its filename underneath, and the positive
    (blue) / negative (red) prompt with its point count above. A panel whose
    image file is missing shows a placeholder message instead.

    Args:
        dataset: PairTallyDataset instance
        image_names: List of exactly 3 image names to visualize
        show_boxes: Whether to show bounding box examples
        show_points: Whether to show point annotations
    """
    if len(image_names) != 3:
        print(f"ERROR: This function requires exactly 3 images, got {len(image_names)}")
        return

    blue, red = '#0066FF', '#FF0040'
    fig, axes = plt.subplots(1, 3, figsize=(18, 6), facecolor='white')

    for col, image_name in enumerate(image_names):
        panel = axes[col]
        image_file = dataset.get_image_path(image_name)
        record = dataset.get_annotation(image_name)

        if not image_file.exists():
            print(f"ERROR: Image not found: {image_file}")
            panel.text(0.5, 0.5, f"Image not found:\n{image_name}",
                       ha='center', va='center', transform=panel.transAxes)
            panel.axis('off')
            continue

        pixels = np.array(Image.open(image_file))

        positive_points = record.get('points', [])
        negative_points = record.get('negative_points', [])
        positive_boxes = record.get('box_examples_coordinates', [])
        negative_boxes = record.get('negative_box_exemples_coordinates', [])
        positive_prompt = record.get('positive_prompt', 'Class 1')
        negative_prompt = record.get('negative_prompt', 'Class 2')

        panel.imshow(pixels)
        panel.axis('off')

        # Two stacked header lines centred over this panel (figure coordinates)
        header_lines = (
            (0.92, f"{positive_prompt}: {len(positive_points)}", blue),
            (0.88, f"{negative_prompt}: {len(negative_points)}", red),
        )
        for y_pos, label, colour in header_lines:
            fig.text(0.17 + col*0.33, y_pos, label,
                     ha='center', va='top', fontsize=12, color=colour,
                     fontweight='bold', transform=fig.transFigure)

        # Filename caption just below the image
        panel.text(0.5, -0.05, image_name,
                   ha='center', va='top', fontsize=10,
                   transform=panel.transAxes, weight='bold')

        if show_boxes:
            for exemplars, colour in ((positive_boxes, blue), (negative_boxes, red)):
                for corners in exemplars[:3]:
                    if len(corners) != 4:
                        continue
                    xs = [pt[0] for pt in corners]
                    ys = [pt[1] for pt in corners]
                    left, top = min(xs), min(ys)
                    panel.add_patch(Rectangle((left, top), max(xs) - left, max(ys) - top,
                                              linewidth=2, edgecolor=colour,
                                              facecolor='none'))

        if show_points:
            for class_points, colour in ((positive_points, blue), (negative_points, red)):
                if class_points:
                    coords = np.array(class_points)
                    panel.scatter(coords[:, 0], coords[:, 1],
                                  c=colour, s=8, alpha=0.4, marker='.')

    plt.tight_layout(rect=[0, 0.05, 1, 0.85])
    plt.show()

def display_dataset_statistics(dataset: PairTallyDataset):
    """Plot count distributions and a text summary for the PairTally dataset.

    Produces a 2x2 figure: histogram of total counts per image, overlaid
    histograms of positive/negative counts, a positive-vs-negative scatter
    plot with an "equal counts" diagonal, and a monospace summary panel.

    Args:
        dataset: PairTallyDataset instance whose ``annotations`` are summarised.

    Returns:
        None. When no annotations are loaded a message is printed and nothing
        is plotted (``min``/``max`` of the empty count lists would otherwise raise).
    """
    if not dataset.annotations:
        print("No annotations loaded - nothing to plot")
        return

    # Collect statistics
    all_positive_counts = []
    all_negative_counts = []
    for img_name in dataset.annotations:
        pos_count, neg_count = dataset.get_counts(img_name)
        all_positive_counts.append(pos_count)
        all_negative_counts.append(neg_count)
    all_total_counts = [p + n for p, n in zip(all_positive_counts, all_negative_counts)]
    mean_total = np.mean(all_total_counts)

    # Create visualization
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.patch.set_facecolor('white')

    # Distribution of total counts
    axes[0, 0].hist(all_total_counts, bins=30, color='#2E86AB', edgecolor='#1B4F72', alpha=0.8)
    axes[0, 0].set_title('Distribution of Total Object Counts', fontsize=12, fontweight='bold')
    axes[0, 0].set_xlabel('Total Count')
    axes[0, 0].set_ylabel('Frequency')
    axes[0, 0].axvline(mean_total, color='#E74C3C', linestyle='--', linewidth=2,
                      label=f'Mean: {mean_total:.1f}')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Distribution of positive vs negative counts
    axes[0, 1].hist([all_positive_counts, all_negative_counts],
                   bins=25, label=['Positive Class', 'Negative Class'],
                   color=['#0066FF', '#FF0040'], alpha=0.7, edgecolor='#2C3E50')
    axes[0, 1].set_title('Distribution of Class Counts', fontsize=12, fontweight='bold')
    axes[0, 1].set_xlabel('Count')
    axes[0, 1].set_ylabel('Frequency')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Scatter plot of positive vs negative counts
    axes[1, 0].scatter(all_positive_counts, all_negative_counts,
                      alpha=0.6, s=25, c='#8E44AD', edgecolors='#5B2C6F', linewidth=0.5)
    axes[1, 0].set_title('Positive vs Negative Class Counts', fontsize=12, fontweight='bold')
    axes[1, 0].set_xlabel('Positive Class Count')
    axes[1, 0].set_ylabel('Negative Class Count')
    # The y = x reference line must span both axes, so size it from the larger class maximum
    diagonal_end = max(max(all_positive_counts), max(all_negative_counts))
    axes[1, 0].plot([0, diagonal_end], [0, diagonal_end],
                   'k--', alpha=0.3, label='Equal counts', linewidth=1)
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)

    # Summary statistics
    stats_text = f"""Dataset Summary Statistics
    
Total Images: {len(dataset.annotations)}
Inter-category: {len(dataset.inter_annotations)}
Intra-category: {len(dataset.intra_annotations)}

Object Count Statistics:
  Mean total: {mean_total:.1f} ± {np.std(all_total_counts):.1f}
  Min/Max: {min(all_total_counts)} / {max(all_total_counts)}
  Median: {np.median(all_total_counts):.1f}

Positive Class:
  Mean: {np.mean(all_positive_counts):.1f} ± {np.std(all_positive_counts):.1f}
  Range: [{min(all_positive_counts)}, {max(all_positive_counts)}]

Negative Class:
  Mean: {np.mean(all_negative_counts):.1f} ± {np.std(all_negative_counts):.1f}
  Range: [{min(all_negative_counts)}, {max(all_negative_counts)}]
"""

    axes[1, 1].text(0.1, 0.5, stats_text, transform=axes[1, 1].transAxes,
                   fontsize=11, verticalalignment='center', fontfamily='monospace',
                   bbox=dict(boxstyle='round,pad=0.8', facecolor='#F8F9FA', alpha=0.9, edgecolor='#BDC3C7'))
    axes[1, 1].axis('off')

    plt.suptitle('PairTally Dataset Statistics', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Display statistics
# Guarded because `dataset` is only created when the dataset directories were found.
if dataset_ready:
    print("Generating dataset statistics visualization...")
    display_dataset_statistics(dataset)

## 5. Demo: Select and Visualize Random Images

In [None]:
def demo_random_images(dataset: PairTallyDataset, 
                       num_images: int = 6,
                       subset: str = "all"):
    """
    Demonstrate dataset visualization with random image selection.
    Show images in groups of 3 to reduce scrolling.

    Distinct images are drawn without replacement, so exactly
    ``min(num_images, available)`` images are shown.

    Args:
        dataset: PairTallyDataset instance
        num_images: Number of random images to display
        subset: Which subset to use ("all", "inter", "intra")
    """
    print(f"\nSelecting {num_images} random images from {subset} subset...\n")

    # Same subset rule as PairTallyDataset.get_random_image
    if subset == "inter":
        pool = list(dataset.inter_annotations)
    elif subset == "intra":
        pool = list(dataset.intra_annotations)
    else:
        pool = list(dataset.annotations)

    # Sampling without replacement guarantees uniqueness without retry loops
    selected_images = random.sample(pool, min(max(num_images, 0), len(pool)))

    if not selected_images:
        print("No images found")
        return

    # Display info about selected images
    for i, image_name in enumerate(selected_images):
        annotation = dataset.get_annotation(image_name)
        pos_count, neg_count = dataset.get_counts(image_name)
        pos_prompt = annotation.get('positive_prompt', 'Class 1')
        neg_prompt = annotation.get('negative_prompt', 'Class 2')
        category_type = "INTER" if image_name in dataset.inter_annotations else "INTRA"

        print(f"Image {i+1}: {image_name[:40]}")
        print(f"  Type: {category_type} | Classes: {pos_prompt} vs {neg_prompt} | Counts: {pos_count}+{neg_count}={pos_count+neg_count}")

    # "\n" (not "\\n"): emit a real line break rather than a literal backslash-n
    print("\n" + "="*80)

    # Display images in groups of 3
    for i in range(0, len(selected_images), 3):
        group = selected_images[i:i+3]

        if len(group) == 3:
            print(f"\nGroup {i//3 + 1}: Images {i+1}-{i+3}")
            visualize_three_images(dataset, group, show_boxes=True, show_points=False)
        else:
            # For remaining images that don't make a group of 3
            print("\nRemaining images:")
            for img_name in group:
                visualize_image_with_annotations(dataset, img_name,
                                                 show_boxes=True, show_points=False)

def demo_specific_attribute_images(dataset: PairTallyDataset, 
                                  attribute_type: str = "color",
                                  num_examples: int = 3):
    """
    Demonstrate specific attribute-based intra-category examples.

    Candidate images are intra-category images whose filename contains one of
    the keywords for the requested attribute (case-insensitive substring
    match). When nothing matches, every intra-category image is a candidate.

    Args:
        dataset: PairTallyDataset instance
        attribute_type: Type of attribute difference ("color", "size"; any
            other value selects the texture/shape keywords)
        num_examples: Number of examples to show. Exactly three are rendered
            side by side; any other number is rendered one image at a time.
    """
    keyword_table = {
        "color": ['COL', 'color', 'black', 'white', 'red', 'blue', 'green', 'yellow'],
        "size": ['SIZ', 'size', 'big', 'small', 'large', 'tiny'],
    }
    texture_shape_keywords = ['TEX', 'SHA', 'round', 'square', 'smooth', 'rough']
    keywords = [kw.lower() for kw in keyword_table.get(attribute_type, texture_shape_keywords)]

    # Find intra-category images with specific patterns
    intra_images = list(dataset.intra_annotations.keys())
    candidates = [img for img in intra_images
                  if any(kw in img.lower() for kw in keywords)]

    # If no specific matches, just use any intra images
    if not candidates:
        candidates = intra_images

    # Honour the requested number of examples, capped by what is available
    selected = random.sample(candidates, min(max(num_examples, 0), len(candidates)))

    # "\n" (not "\\n"): emit a real line break rather than a literal backslash-n
    print(f"\n{attribute_type.upper()} DIFFERENCE EXAMPLES:")
    print("="*60)

    # Display info
    for i, image_name in enumerate(selected):
        annotation = dataset.get_annotation(image_name)
        pos_count, neg_count = dataset.get_counts(image_name)
        pos_prompt = annotation.get('positive_prompt', 'Class 1')
        neg_prompt = annotation.get('negative_prompt', 'Class 2')

        print(f"Example {i+1}: {image_name[:40]}")
        print(f"  Classes: {pos_prompt} vs {neg_prompt} | Counts: {pos_count}+{neg_count}={pos_count+neg_count}")

    print()

    # Visualize as a group of 3
    if len(selected) == 3:
        visualize_three_images(dataset, selected, show_boxes=True, show_points=False)
    else:
        # Fallback for any other number of images
        for image_name in selected:
            visualize_image_with_annotations(dataset, image_name,
                                             show_boxes=True, show_points=False)

# Run comprehensive demonstration
# Everything below needs `dataset`, which only exists when the directory check passed.
# The banner prints use "\n" (a real newline); "\\n" would print a literal backslash-n.
if dataset_ready:
    print("="*80)
    print("COMPREHENSIVE DATASET VISUALIZATION")
    print("="*80)
    
    # 1. Show 6 INTER-category examples (2 groups of 3)
    print("\n" + "="*60)
    print("1. INTER-CATEGORY EXAMPLES (Different Object Types)")
    print("="*60)
    demo_random_images(dataset, num_images=6, subset="inter")
    
    # 2. Show 6 INTRA-category examples (2 groups of 3)
    print("\n" + "="*60)
    print("2. INTRA-CATEGORY EXAMPLES (Same Object, Different Attributes)")
    print("="*60)
    demo_random_images(dataset, num_images=6, subset="intra")
    
    # 3. Show specific attribute differences (3 examples each)
    print("\n" + "="*60)
    print("3. SPECIFIC ATTRIBUTE DIFFERENCES")
    print("="*60)
    
    print("Demonstrating attribute-based discrimination challenges:")
    
    # Color difference examples
    demo_specific_attribute_images(dataset, attribute_type="color", 
                                 num_examples=3)
    
    # Size difference examples  
    demo_specific_attribute_images(dataset, attribute_type="size", 
                                 num_examples=3)
    
    # Texture/Shape difference examples
    demo_specific_attribute_images(dataset, attribute_type="texture", 
                                 num_examples=3)

## 6. CountGD Model Evaluation Demo

This section demonstrates how to run the CountGD model on images from the PairTally dataset. CountGD is a state-of-the-art object counting model that can use both exemplar boxes and text prompts for counting.

In [None]:
# Import CountGD model components
import sys
import torch
from pathlib import Path
import torchvision.transforms.functional as F

# Add CountGD to path
# The CountGD sources are looked up under <BASE_DIR>/models/countgd/CountGD. Appending that
# directory to sys.path is what lets the bare `util`, `datasets_inference` and `models`
# imports below resolve (presumably CountGD's own packages — confirm against the checkout).
countgd_path = BASE_DIR / "models" / "countgd" / "CountGD"
sys.path.append(str(countgd_path))

# Import lazily and record the outcome in `countgd_available` instead of failing the cell,
# so the rest of the notebook still runs without CountGD installed.
try:
    from util.slconfig import SLConfig, DictAction
    from util.misc import nested_tensor_from_tensor_list
    import datasets_inference.transforms as T
    from models.registry import MODULE_BUILD_FUNCS
    
    print("CountGD imports successful")
    countgd_available = True
except ImportError as e:
    print(f"CountGD import failed: {e}")
    print("Make sure CountGD dependencies are installed")
    countgd_available = False

def load_countgd_model():
    """Build the CountGD model, load its checkpoint and create the input transform.

    Uses fixed locations under ``countgd_path``: ``config/cfg_fsc147_vit_b.py``
    for the configuration, ``checkpoints/checkpoint_fsc147_best.pth`` for the
    weights and ``checkpoints/bert-base-uncased`` as the text encoder.

    Returns:
        ``(model, transforms)`` on success. ``(None, None)`` when CountGD could
        not be imported, a required file is missing, the model builder is not
        registered, or any exception occurs while loading (the traceback is
        printed in that case).
    """
    if not countgd_available:
        print("CountGD not available - skipping model loading")
        return None, None
        
    # Set default paths
    config_path = countgd_path / "config" / "cfg_fsc147_vit_b.py"
    checkpoint_path = countgd_path / "checkpoints" / "checkpoint_fsc147_best.pth"
    
    # Fail early with a readable message instead of a deep stack trace
    if not config_path.exists():
        print(f"Config file not found: {config_path}")
        print("Please ensure CountGD model files are properly set up")
        return None, None
        
    if not checkpoint_path.exists():
        print(f"Checkpoint file not found: {checkpoint_path}")
        print("Please download the CountGD model checkpoint")
        return None, None
    
    try:
        # Load config
        cfg = SLConfig.fromfile(str(config_path))
        # Point the text encoder at the local checkpoint directory
        # (presumably so nothing has to be downloaded — confirm)
        cfg.merge_from_dict({"text_encoder_type": str(countgd_path / "checkpoints" / "bert-base-uncased")})
        
        # Set device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")
        
        # Build model
        build_func = MODULE_BUILD_FUNCS.get("groundingdino")
        if build_func is None:
            print("Model builder not found")
            return None, None
            
        # Minimal attribute container handed to the builder in place of parsed CLI arguments
        class Args:
            def __init__(self):
                self.modelname = "groundingdino"
                self.device = str(device)
        
        args = Args()
        # Copy every config entry onto the args object.
        # NOTE(review): relies on SLConfig's private `_cfg_dict` attribute; config keys named
        # `modelname` or `device` would overwrite the values set above — confirm intended.
        for k, v in cfg._cfg_dict.items():
            setattr(args, k, v)
            
        # The builder returns three values; only the model is used here
        model, _, _ = build_func(args)
        model.to(device)
        
        # Load checkpoint - weights_only=False is set explicitly for PyTorch 2.6+
        # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects, which can
        # execute code. Only load checkpoints obtained from a trusted source.
        print("Loading checkpoint...")
        checkpoint = torch.load(str(checkpoint_path), map_location=device, weights_only=False)["model"]
        # strict=False: missing or unexpected keys in the checkpoint are not treated as errors
        model.load_state_dict(checkpoint, strict=False)
        model.eval()
        
        # Create transform
        # Resize, tensor conversion and per-channel normalisation with the listed mean/std.
        # NOTE(review): with a single size in the list the resize is presumably
        # deterministic despite the "Random" name — confirm in CountGD's transforms.
        transforms = T.Compose([
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        
        print("CountGD model loaded successfully!")
        return model, transforms
        
    except Exception as e:
        # Demo-level boundary: report the failure and keep the notebook running
        print(f"Failed to load CountGD model: {e}")
        import traceback
        traceback.print_exc()
        return None, None

# Load the model
# Bind the model globals up front so that code referencing `countgd_model` /
# `countgd_transform` does not hit a NameError when CountGD could not be imported.
countgd_model, countgd_transform = None, None

if countgd_available:
    print("Loading CountGD model...")
    countgd_model, countgd_transform = load_countgd_model()
    model_ready = (countgd_model is not None)
else:
    model_ready = False
    print("CountGD model not available for this demo")

print(f"Model ready status: {model_ready}")

In [None]:
def run_countgd_inference(model, transform, image, exemplar_boxes, text_prompt, confidence_thresh=0.3):
    """
    Count objects in one image with CountGD, using exemplar boxes and a text prompt.

    Args:
        model: CountGD model
        transform: Callable ``(image, target) -> (tensor, target)``; the target
            dict carries the exemplar boxes under the key "exemplars"
        image: PIL Image
        exemplar_boxes: List of exemplar boxes in [x1, y1, x2, y2] format;
            only the first three are used
        text_prompt: Text description of objects to count
        confidence_thresh: Minimum sigmoid score a query must reach to be counted

    Returns:
        pred_boxes: Boxes of the queries that passed the threshold
        pred_logits: Sigmoid scores of those queries
        pred_count: Number of queries that passed the threshold
    """
    # Run on whatever device the model weights live on
    target_device = next(model.parameters()).device

    if not exemplar_boxes or len(exemplar_boxes) == 0:
        # No exemplars: an empty (0, 4) tensor keeps the downstream shape contract
        exemplar_tensor = torch.tensor([], dtype=torch.float32).reshape(0, 4)
    else:
        # Scale pixel coordinates by the image size, keeping at most 3 exemplars.
        # NOTE(review): the boxes are divided by the image size here and then passed through
        # `transform` as well — confirm the transform does not normalise them a second time.
        img_width, img_height = image.size
        scaled_boxes = [
            [x1 / img_width, y1 / img_height, x2 / img_width, y2 / img_height]
            for x1, y1, x2, y2 in exemplar_boxes[:3]
        ]
        exemplar_tensor = torch.tensor(scaled_boxes, dtype=torch.float32)

    # Image and exemplars go through the same transform pipeline
    input_image, target = transform(image, {"exemplars": exemplar_tensor})
    input_image = input_image.to(target_device)
    input_exemplar = target["exemplars"].to(target_device)

    # The caption is the prompt followed by " ."
    caption = text_prompt + " ."

    with torch.no_grad():
        outputs = model(
            input_image.unsqueeze(0),
            [input_exemplar],
            [torch.tensor([0]).to(target_device)],
            captions=[caption],
        )

    # Single-image batch: take element 0 of each output
    scores = outputs["pred_logits"][0].sigmoid()
    candidate_boxes = outputs["pred_boxes"][0]

    # Keep queries whose best score exceeds the threshold
    keep = scores.max(dim=-1).values > confidence_thresh
    kept_boxes = candidate_boxes[keep, :]
    kept_scores = scores[keep, :]

    return kept_boxes, kept_scores, kept_boxes.shape[0]

def visualize_countgd_results(dataset, image_name, model, transform, confidence_thresh=0.3):
    """
    Run CountGD on one PairTally image and plot exemplars and predictions per class.

    The positive and the negative class are queried separately, each with its
    own exemplar boxes and text prompt. A three-panel figure is shown:
    exemplar boxes only, positive-class predictions, negative-class predictions.

    Args:
        dataset: Object exposing get_image_path(image_name) and get_annotation(image_name)
        image_name: Name of the image to evaluate
        model: CountGD model, forwarded to run_countgd_inference
        transform: Preprocessing transform, forwarded to run_countgd_inference
        confidence_thresh: Confidence threshold, forwarded to run_countgd_inference

    Returns:
        Dict with ground-truth counts, predicted counts and absolute errors for
        both classes, or None when the model is unavailable, the image file is
        missing, or an exception occurred (the traceback is printed).
    """
    if not model_ready:
        print("CountGD model not available")
        return

    pos_color = '#0066FF'
    neg_color = '#FF0040'

    def corners_to_xyxy(corner_lists):
        # Collapse corner-point boxes (FSC147 style) into [x1, y1, x2, y2]; only the first 3 are used
        extents = []
        for corners in corner_lists[:3]:
            if len(corners) != 4:
                continue
            xs = [pt[0] for pt in corners]
            ys = [pt[1] for pt in corners]
            extents.append([min(xs), min(ys), max(xs), max(ys)])
        return extents

    def outline_exemplars(ax, exemplars, color):
        # Exemplars are drawn dashed so they can be told apart from predictions
        for left, top, right, bottom in exemplars:
            ax.add_patch(Rectangle((left, top), right - left, bottom - top, linewidth=2,
                                   edgecolor=color, facecolor='none', linestyle='--'))

    def outline_predictions(ax, pred_boxes, pred_count, color, width, height):
        # Predictions are treated as normalized [cx, cy, w, h]; convert to clamped pixel corners
        if not (len(pred_boxes) > 0 and pred_count > 0):
            return
        for cx, cy, w, h in pred_boxes.clone().cpu().numpy():
            x1 = max(0, min(width, (cx - w/2) * width))
            y1 = max(0, min(height, (cy - h/2) * height))
            x2 = max(0, min(width, (cx + w/2) * width))
            y2 = max(0, min(height, (cy + h/2) * height))
            if x2 > x1 and y2 > y1:  # skip boxes that collapsed after clamping
                ax.add_patch(Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2,
                                       edgecolor=color, facecolor='none'))

    try:
        # Locate the image and its annotation record
        img_path = dataset.get_image_path(image_name)
        annotation = dataset.get_annotation(image_name)

        if not img_path.exists():
            print(f"ERROR: Image not found: {img_path}")
            return

        image = Image.open(img_path).convert('RGB')
        img_array = np.array(image)

        # Pull ground-truth points, exemplar boxes and prompts for both classes
        positive_points = annotation.get('points', [])
        negative_points = annotation.get('negative_points', [])
        positive_boxes = annotation.get('box_examples_coordinates', [])
        negative_boxes = annotation.get('negative_box_exemples_coordinates', [])
        positive_prompt = annotation.get('positive_prompt', 'objects')
        negative_prompt = annotation.get('negative_prompt', 'objects')

        pos_exemplar_boxes = corners_to_xyxy(positive_boxes)
        neg_exemplar_boxes = corners_to_xyxy(negative_boxes)

        img_width, img_height = image.size

        print(f"Running CountGD on: {image_name}")
        gt_pos = len(positive_points)
        gt_neg = len(negative_points)
        print(f"Ground Truth - {positive_prompt}: {gt_pos}, {negative_prompt}: {gt_neg}")

        # One inference pass per class
        print(f"\nInferring positive class ({positive_prompt})...")
        pos_pred_boxes, _pos_scores, pos_pred_count = run_countgd_inference(
            model, transform, image, pos_exemplar_boxes, positive_prompt, confidence_thresh=confidence_thresh
        )

        print(f"\nInferring negative class ({negative_prompt})...")
        neg_pred_boxes, _neg_scores, neg_pred_count = run_countgd_inference(
            model, transform, image, neg_exemplar_boxes, negative_prompt, confidence_thresh=confidence_thresh
        )

        print(f"\nCountGD Predictions - {positive_prompt}: {pos_pred_count}, {negative_prompt}: {neg_pred_count}")

        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        fig.suptitle(f'CountGD Evaluation: {image_name}', fontsize=14, fontweight='bold')

        # Panel 1: exemplar boxes of both classes (ground-truth points are not drawn)
        axes[0].imshow(img_array)
        axes[0].set_title(f'Ground Truth + Exemplars\n{positive_prompt}: {gt_pos}, {negative_prompt}: {gt_neg}', fontsize=10)
        axes[0].axis('off')
        outline_exemplars(axes[0], pos_exemplar_boxes, pos_color)
        outline_exemplars(axes[0], neg_exemplar_boxes, neg_color)

        # Panels 2 and 3: per-class exemplars (dashed) plus predictions (solid)
        class_panels = [
            (axes[1], positive_prompt, gt_pos, pos_exemplar_boxes, pos_pred_boxes, pos_pred_count, pos_color),
            (axes[2], negative_prompt, gt_neg, neg_exemplar_boxes, neg_pred_boxes, neg_pred_count, neg_color),
        ]
        for ax, prompt, gt_count, exemplars, pred_boxes, pred_count, color in class_panels:
            ax.imshow(img_array)
            ax.set_title(f'{prompt} Predictions\nGT: {gt_count}, Pred: {pred_count}', fontsize=10)
            ax.axis('off')
            outline_exemplars(ax, exemplars, color)
            outline_predictions(ax, pred_boxes, pred_count, color, img_width, img_height)

        plt.tight_layout()
        plt.show()

        # Absolute counting error per class and combined
        pos_error = abs(pos_pred_count - gt_pos)
        neg_error = abs(neg_pred_count - gt_neg)
        total_error = pos_error + neg_error

        print(f"\nEvaluation Results:")
        print(f"  {positive_prompt} - GT: {gt_pos}, Pred: {pos_pred_count}, Error: {pos_error}")
        print(f"  {negative_prompt} - GT: {gt_neg}, Pred: {neg_pred_count}, Error: {neg_error}")
        print(f"  Total Error: {total_error}")
        print(f"  Confidence Threshold: {confidence_thresh}")

        return {
            'positive_gt': gt_pos,
            'positive_pred': pos_pred_count,
            'positive_error': pos_error,
            'negative_gt': gt_neg,
            'negative_pred': neg_pred_count,
            'negative_error': neg_error,
            'total_error': total_error
        }

    except Exception as e:
        # Demo-level boundary: report the failure and keep the notebook running
        print(f"Error in CountGD evaluation: {e}")
        import traceback
        traceback.print_exc()
        return None

# Cell-level confirmation that the CountGD helper functions above were defined
print("CountGD inference functions loaded successfully")

In [None]:
# Demo CountGD on sample images
def demo_countgd_inference(dataset, model, transform, num_examples=3):
    """
    Demonstrate CountGD inference on sample PairTally images.

    Randomly draws image names from the inter- and intra-category subsets,
    runs visualize_countgd_results on the first `num_examples` of them
    (inter-category names first) and prints summary error statistics.

    Args:
        dataset: Object exposing `inter_annotations` and `intra_annotations`
            mappings keyed by image name
        model: CountGD model, forwarded to visualize_countgd_results
        transform: Preprocessing transform, forwarded to visualize_countgd_results
        num_examples: Maximum number of images to evaluate

    Returns:
        None. Results are printed and plotted.
    """
    if not model_ready:
        print("CountGD model not available - please set up the model first")
        print("\nTo set up CountGD:")
        print("1. cd models/countgd")
        print("2. Follow setup instructions in README")
        print("3. Download model checkpoint")
        return

    print("COUNTGD INFERENCE DEMO")
    print("="*60)
    print("Demonstrating CountGD on PairTally images...")
    print("This shows how counting models perform on fine-grained discrimination tasks\n")

    # Draw enough names per subset to honour num_examples. The floor of two
    # keeps the draws for num_examples <= 4 identical to the previous fixed
    # "two per subset" sampling, so seeded runs stay reproducible.
    per_subset = max(2, (num_examples + 1) // 2)
    inter_names = list(dataset.inter_annotations.keys())
    intra_names = list(dataset.intra_annotations.keys())
    inter_samples = random.sample(inter_names, min(per_subset, len(inter_names)))
    intra_samples = random.sample(intra_names, min(per_subset, len(intra_names)))

    selected = (inter_samples + intra_samples)[:num_examples]
    results = []

    for idx, image_name in enumerate(selected, start=1):
        print(f"\nExample {idx}/{len(selected)}: {image_name}")
        print("-" * 50)

        result = visualize_countgd_results(dataset, image_name, model, transform)
        if result:  # None means the image was skipped or evaluation failed
            results.append(result)

    # Summary statistics
    if results:
        print("\n" + "="*60)
        print("SUMMARY STATISTICS")
        print("="*60)

        total_pos_error = sum(r['positive_error'] for r in results)
        total_neg_error = sum(r['negative_error'] for r in results)
        total_error = sum(r['total_error'] for r in results)
        avg_error = total_error / len(results)

        print(f"Images evaluated: {len(results)}")
        print(f"Total positive class error: {total_pos_error}")
        print(f"Total negative class error: {total_neg_error}")
        print(f"Total combined error: {total_error}")
        print(f"Average error per image: {avg_error:.2f}")

        print("\nDetailed Results:")
        for idx, result in enumerate(results, start=1):
            print(f"  Image {idx}: Pos Error = {result['positive_error']}, Neg Error = {result['negative_error']}, Total = {result['total_error']}")

# Launch the CountGD demo only when the model was loaded; otherwise print setup hints
if not model_ready:
    for _hint in (
        "CountGD model not available for inference demo",
        "\nTo enable CountGD inference:",
        "1. Set up CountGD model in models/countgd/",
        "2. Download required checkpoint files",
        "3. Install CountGD dependencies",
        "4. Re-run this notebook",
    ):
        print(_hint)
else:
    print("Running CountGD inference demonstration...")
    demo_countgd_inference(dataset, countgd_model, countgd_transform, num_examples=3)

## 7. Replicate Results from Paper

This section allows you to replicate the GeCo results from the paper.

### Instructions:
1. **Set up the environment and model** by following the instructions in:
   ```
   models/geco/README.md
   ```
2. **Run single-class evaluation:**
   ```bash
   cd models/geco
   bash run_count_one_class.sh
   ```

3. **Run dual-class evaluation:**
   ```bash
   cd models/geco
   bash run_count_both_classes.sh
   ```

In [None]:
# ========= CONFIG (edit to your paths/names) =========
# Root directory that holds the per-model result folders
RESULTS_DIR = 'results'
# Dataset folder name; also used as a suffix of the exported CSV file names
DATASET     = 'pairtally_dataset'
# Model whose results are analysed; used in result paths, plot titles and output file names
MODEL_NAME  = 'CountGD'
# Directory that receives the saved figures and CSV tables
OUTPUT_DIR  = 'countgd_analysis'
# When True, figures and CSV tables are written to OUTPUT_DIR
SAVE_FIGS   = True

# Single-class (qualitative) results JSON; the run section raises if this file is missing
QUAL_PATH   = f'{RESULTS_DIR}/{MODEL_NAME}-qualitative/{DATASET}/complete_qualitative_data.json'
# =====================================================

import json, os
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display, Markdown

# Pandas display settings for the tables rendered further down
pd.options.display.precision = 2
pd.options.display.max_rows = 200

# Known locations of the combined (both-classes) results file, in priority
# order; COMB_PATH is the first one that exists, or None if none does.
_COMB_CANDIDATES = [
    f'{RESULTS_DIR}/{MODEL_NAME}-count-both-classes-qualitative/{DATASET}/{MODEL_NAME}-Combined_detailed_results.json',
    f'{RESULTS_DIR}/{MODEL_NAME}-qualitative-combined/{DATASET}/combined_inference_data.json',
    f'{RESULTS_DIR}/{MODEL_NAME}-qualitative-combined/{DATASET}/{MODEL_NAME}-Combined_detailed_results.json',
]
COMB_PATH = next(filter(os.path.exists, _COMB_CANDIDATES), None)

# ---------------- utils ----------------
def read_json(path):
    """
    Load and return the JSON document stored at `path`.

    The file is decoded explicitly as UTF-8 instead of the platform's locale
    encoding, so non-ASCII content (e.g. in file names or prompts) is read the
    same way on every operating system.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

def basename(p):
    """Return the final component of path `p` (the part after the last separator)."""
    _directory, tail = os.path.split(p)
    return tail

def parse_filename_meta(filename):
    """
    Extract test_type and super_category from a filename shaped like:
      ..._{INTRA|INTER}_{super_category}_...
    Only this metadata is parsed; counts encoded in the filename are ignored
    because the GT stored in the JSON is authoritative.

    Returns a dict with keys 'test_type' and 'super_category'; both are None
    when no INTRA/INTER token is found, and 'super_category' is None when the
    token is the last one in the name.
    """
    stem, _extension = os.path.splitext(basename(filename))
    tokens = stem.split('_')
    for idx, token in enumerate(tokens):
        if token == 'INTRA' or token == 'INTER':
            following = idx + 1
            category = tokens[following] if following < len(tokens) else None
            # Only the first INTRA/INTER token counts
            return {'test_type': token, 'super_category': category}
    return {'test_type': None, 'super_category': None}

def schema_report(obj, name):
    """
    Print a short structural overview of a loaded JSON object.

    For a dict: the first 20 top-level keys, then the first key whose value is
    a non-empty list of dicts together with that list's length and the keys of
    its first item. For a non-empty list of dicts: its length and the keys of
    its first item. Anything else only gets the header line.
    """
    def is_record_list(value):
        # A non-empty list whose first element is a dict
        return isinstance(value, list) and bool(value) and isinstance(value[0], dict)

    print(f'\n=== Schema Report: {name} ===')

    if isinstance(obj, dict):
        print('Top-level keys:', list(obj.keys())[:20])
        first_hit = next(((key, val) for key, val in obj.items() if is_record_list(val)), None)
        if first_hit is not None:
            key, records = first_hit
            print(f'Array-of-dicts key: {key} (len {len(records)})')
            print('Sample item keys:', list(records[0].keys())[:30])
        return

    if is_record_list(obj):
        print('Top-level list len:', len(obj))
        print('Sample item keys:', list(obj[0].keys())[:30])

# ---------------- loaders ----------------
def extract_single_class_items(data):
    """
    Extract rows with: image_name, pred_count, class_type,
    gt_pos_count, gt_neg_count, gt_total_count (when present).
    Also maps per-class `gt_count` into pos/neg depending on class_type.

    Accepted input shapes:
      * {"class_results": {"positive": {"images": [...]}, "negative": {"images": [...]}}}
      * a list of item dicts
      * any other dict: every top-level list value is scanned for item dicts

    Items lacking an image name or a predicted count are skipped. A predicted
    count of 0 is a valid prediction and is kept.

    Returns a list of dicts, one per accepted item.
    """
    out = []
    pred_keys = ('pred_count', 'predicted_count', 'count', 'prediction')

    def first_present(record, keys):
        # First value that is actually provided. Chaining with `or` would
        # treat a legitimate count of 0 as missing, so compare explicitly.
        for key in keys:
            value = record.get(key)
            if value is not None and value != '':
                return value
        return None

    def push(it, cls_hint=None):
        image_name = it.get('image_name') or it.get('filename') or it.get('image') or it.get('name')
        pred_count = first_present(it, pred_keys)
        cls        = (it.get('class_type') or it.get('class') or it.get('label') or cls_hint or '').lower()
        if cls not in ('positive','negative'):
            # Unlabelled items are treated as positive-class queries
            cls = 'positive'

        # Prefer explicit GT fields if present
        gt_pos = it.get('gt_pos_count')
        gt_neg = it.get('gt_neg_count')
        gt_tot = it.get('gt_total_count')

        # Some dumps only have gt_count for the queried class
        gt_one = it.get('gt_count')
        if gt_one is not None:
            if cls == 'positive' and gt_pos is None:
                gt_pos = int(gt_one)
            if cls == 'negative' and gt_neg is None:
                gt_neg = int(gt_one)

        if image_name is not None and pred_count is not None:
            out.append({
                'image_name': image_name,
                'pred_count': int(pred_count),
                'class_type': cls,
                'gt_pos_count': (int(gt_pos) if gt_pos is not None else None),
                'gt_neg_count': (int(gt_neg) if gt_neg is not None else None),
                'gt_total_count': (int(gt_tot) if gt_tot is not None else None),
            })

    # CountGD qualitative usually: {"class_results": {"positive": {"images":[...]}, "negative": {"images":[...]}}}
    if isinstance(data, dict) and 'class_results' in data:
        cr = data['class_results']
        for cls in ('positive','negative'):
            for it in cr.get(cls, {}).get('images', []):
                if isinstance(it, dict):
                    push(it, cls_hint=cls)
    elif isinstance(data, list):
        for it in data:
            if isinstance(it, dict):
                push(it)
    elif isinstance(data, dict):
        # Fallback: scan every top-level list for item dicts
        for v in data.values():
            if isinstance(v, list):
                for it in v:
                    if isinstance(it, dict):
                        push(it)
    return out

def extract_combined_items(data):
    """
    Return dict keyed by BASENAME(filename) -> {'pred_total': int|None, 'gt_total': int|None}
    Supports common CountGD combined shapes and the `combined_gt_count` key.

    Accepted input shapes:
      A) {"image_results": {filename: {...}}}
      B) {"results": {"images": [{...}, ...]}}
      C) a dict with any other top-level lists of item dicts
      D) a plain list of item dicts

    Counts of 0 are valid values and are kept. If the same basename appears
    more than once, later non-missing values overwrite earlier ones.
    Falsy input (None, empty dict/list) yields an empty dict.
    """
    out = {}
    if not data:
        return out

    name_keys = ('image_name', 'filename', 'image', 'name')
    list_pred_keys = ('pred_total_count', 'predicted_count', 'pred_count', 'count', 'prediction')
    # The filename -> record mapping additionally knows `combined_predicted_count`
    mapping_pred_keys = ('combined_predicted_count',) + list_pred_keys
    gt_keys = ('combined_gt_count', 'gt_total_count', 'total_gt', 'gt_count_total')

    def first_present(record, keys):
        # First value that is actually provided. Chaining with `or` would
        # treat a legitimate count of 0 as missing, so compare explicitly.
        for key in keys:
            value = record.get(key)
            if value is not None and value != '':
                return value
        return None

    def put(fn, pred=None, gt=None):
        if not fn:
            return
        # Same normalisation as the file's basename() helper
        rec = out.setdefault(os.path.basename(fn), {'pred_total': None, 'gt_total': None})
        if pred is not None:
            rec['pred_total'] = int(pred)
        if gt is not None:
            rec['gt_total'] = int(gt)

    def put_record(record):
        # Register one list-style item dict; non-dict entries are ignored
        if not isinstance(record, dict):
            return
        fn = next((record[key] for key in name_keys if record.get(key)), None)
        put(fn, first_present(record, list_pred_keys), first_present(record, gt_keys))

    if isinstance(data, dict):
        # A) image_results mapping
        image_results = data.get('image_results')
        if isinstance(image_results, dict):
            for fn, obj in image_results.items():
                if isinstance(obj, dict):
                    put(fn, first_present(obj, mapping_pred_keys), first_present(obj, gt_keys))

        # B) results -> images list
        res = data.get('results')
        if isinstance(res, dict) and isinstance(res.get('images'), list):
            for it in res['images']:
                put_record(it)

        # C) any other top-level lists
        for v in data.values():
            if isinstance(v, list):
                for it in v:
                    put_record(it)

    elif isinstance(data, list):
        for it in data:
            put_record(it)

    return out

def load_predictions_countgd(qual_path, comb_path):
    """
    Build the per-(image, class) prediction table from the result files.

    Reads the single-class (qualitative) JSON at `qual_path` and, when
    `comb_path` is given and exists, the combined (both-classes) JSON. For
    every single-class item the ground truth is resolved as follows:
      * total = explicit total, else pos + neg when both are known,
        else the combined file's GT total for that image;
      * a = GT of the queried class, b = total - a.
    Items for which `a` or the total cannot be determined are skipped.

    Returns a dict keyed by (basename, class_type) with fields
    filename, class_type, test_type, super_category, f_A, a, b, a_plus_b and
    f_A_plus_B (NaN when no combined prediction is available).
    """
    qualitative = read_json(qual_path)
    schema_report(qualitative, 'QUALITATIVE JSON')

    combined = None
    if comb_path and os.path.exists(comb_path):
        combined = read_json(comb_path)
    if combined is not None:
        schema_report(combined, 'COMBINED JSON')

    single_items = extract_single_class_items(qualitative)
    combined_map = extract_combined_items(combined)

    print(f"\nFound {len(single_items)} single-class items; combined predictions: {len(combined_map)}")

    image_predictions = {}
    kept = 0
    skipped = 0

    for item in single_items:
        base = basename(item['image_name'])
        cls = item['class_type']

        # Only test_type / super_category come from the filename
        meta = parse_filename_meta(base)

        # Combined-file entry for this image, if any
        combined_rec = combined_map[base] if base in combined_map else None
        combined_gt = combined_rec['gt_total'] if combined_rec is not None else None
        combined_pred = combined_rec['pred_total'] if combined_rec is not None else None

        gt_pos = item['gt_pos_count']
        gt_neg = item['gt_neg_count']
        gt_tot = item['gt_total_count']
        if gt_tot is None:
            if gt_pos is not None and gt_neg is not None:
                # JSON GT for both classes wins over the combined file
                gt_tot = gt_pos + gt_neg
            elif combined_gt is not None:
                gt_tot = combined_gt

        # GT count of the class that was actually queried
        a_true = gt_neg if cls == 'negative' else gt_pos

        # Strict: both the class GT and the total are required
        if a_true is None or gt_tot is None:
            skipped += 1
            continue

        image_predictions[(base, cls)] = {
            'filename': base,
            'class_type': cls,
            'test_type': meta['test_type'],
            'super_category': meta['super_category'],
            'f_A': int(item['pred_count']),
            'a': int(a_true),
            'b': int(gt_tot - a_true),
            'a_plus_b': int(gt_tot),
            'f_A_plus_B': combined_pred if combined_pred is not None else np.nan,
        }
        kept += 1

    print(f"Kept {kept} rows; skipped {skipped} (missing GT).")
    return image_predictions

# ---------------- metrics & summaries ----------------
def per_image_metrics(image_predictions):
    """
    Turn the (filename, class) -> prediction mapping into a per-row metrics DataFrame.

    Each row carries the raw values (f_A, f_A_plus_B, a, b, a_plus_b), four
    absolute errors and two boolean overcounting indicators. When the combined
    prediction f_A_plus_B is missing (NaN), f_A is substituted for it in the
    two metrics that involve f(A+B).
    """
    rows = []
    for (_fname, _cls), rec in image_predictions.items():
        pred_single = rec['f_A']
        true_single = rec['a']
        true_total = rec['a_plus_b']
        pred_total = rec.get('f_A_plus_B', np.nan)

        # Stand-in for f(A+B) when no combined prediction exists
        effective_total_pred = pred_single if np.isnan(pred_total) else pred_total
        err_to_class = abs(pred_single - true_single)
        err_to_total = abs(pred_single - true_total)

        rows.append({
            'filename': rec['filename'],
            'class_type': rec['class_type'],
            'test_type': rec['test_type'],
            'super_category': rec['super_category'],
            'f_A': pred_single,
            'f_A_plus_B': pred_total,
            'a': true_single,
            'b': rec['b'],
            'a_plus_b': true_total,
            '|f(A)-a|': err_to_class,
            '|f(A+B)-(a+b)|': abs(effective_total_pred - true_total),
            '|f(A)-(a+b)|': err_to_total,
            '|f(A)-f(A+B)|': abs(pred_single - effective_total_pred),
            'f(A)>a': pred_single > true_single,
            '|f(A)-a|>|f(A)-(a+b)|': err_to_class > err_to_total,
        })
    return pd.DataFrame(rows)

def summarize_metrics(df):
    """
    Collapse a per-row metrics frame into a single-row summary DataFrame.

    The summary holds the row count, the mean of each absolute-error column
    and the percentage of rows for which each boolean indicator is True.
    For an empty frame the count is 0 and every statistic is NaN.
    """
    mean_columns = [
        ('mean |f(A)-a|', '|f(A)-a|'),
        ('mean |f(A+B)-(a+b)|', '|f(A+B)-(a+b)|'),
        ('mean |f(A)-(a+b)|', '|f(A)-(a+b)|'),
        ('mean |f(A)-f(A+B)|', '|f(A)-f(A+B)|'),
    ]
    percent_columns = [
        ('f(A)>a %', 'f(A)>a'),
        ('|f(A)-a|>|f(A)-(a+b)| %', '|f(A)-a|>|f(A)-(a+b)|'),
    ]

    if df.empty:
        record = {'total_predictions': 0}
        for label, _source in mean_columns + percent_columns:
            record[label] = np.nan
        return pd.DataFrame([record])

    record = {'total_predictions': len(df)}
    for label, source in mean_columns:
        record[label] = df[source].mean()
    for label, source in percent_columns:
        # Mean of a boolean column is the fraction of True rows
        record[label] = 100.0 * df[source].mean()
    return pd.DataFrame([record])

def summarize_by(df, group_cols):
    """
    Summarize the per-row metrics separately for each group of `group_cols`.

    Args:
        df: Per-row metrics frame (as produced by per_image_metrics).
        group_cols: List of column names to group by.

    Returns:
        DataFrame with one row per group: the group key columns followed by
        the columns of summarize_metrics. An empty DataFrame is returned when
        `df` is empty or when grouping yields no groups. The latter happens
        when every row has a missing key (e.g. test_type is None for all
        files), because groupby drops missing keys by default.
    """
    if df.empty:
        return pd.DataFrame()

    parts = []
    for keys, group in df.groupby(group_cols):
        # Depending on the pandas version a single grouping column yields a scalar or a 1-tuple
        key_tuple = keys if isinstance(keys, tuple) else (keys,)
        labels = pd.DataFrame([dict(zip(group_cols, key_tuple))])
        parts.append(pd.concat([labels, summarize_metrics(group)], axis=1))

    # pd.concat raises on an empty list, so handle "no groups" explicitly
    if not parts:
        return pd.DataFrame()
    return pd.concat(parts, ignore_index=True)

def plot_overall_bars(summary_df, title_prefix=MODEL_NAME):
    """
    Draw two bar charts from a one-row summary frame.

    The first chart shows the four mean absolute-error metrics, the second the
    two overcounting percentages. When SAVE_FIGS is set, each chart is also
    written to OUTPUT_DIR as PNG (300 dpi) and PDF. Prints a notice and
    returns without plotting when the summary holds no predictions.
    """
    nothing_to_plot = summary_df.empty or summary_df['total_predictions'].iloc[0] == 0
    if nothing_to_plot:
        print("No data to plot.")
        return

    file_stem = MODEL_NAME.lower()

    # Chart 1: mean absolute errors
    error_metrics = ['mean |f(A)-a|', 'mean |f(A+B)-(a+b)|', 'mean |f(A)-(a+b)|', 'mean |f(A)-f(A+B)|']
    error_values = summary_df.loc[0, error_metrics].astype(float).values

    plt.figure(figsize=(10,5))
    plt.bar(error_metrics, error_values)
    plt.ylabel('Mean Absolute Error')
    plt.title(f'{title_prefix}: Fine-Grained Counting Metrics (Means)')
    plt.xticks(rotation=20, ha='right')
    plt.grid(True, axis='y', alpha=0.3)
    plt.tight_layout()
    if SAVE_FIGS:
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        plt.savefig(os.path.join(OUTPUT_DIR, f'{file_stem}_fine_grained_metrics.png'), dpi=300, bbox_inches='tight')
        plt.savefig(os.path.join(OUTPUT_DIR, f'{file_stem}_fine_grained_metrics.pdf'), bbox_inches='tight')
    plt.show()

    # Chart 2: overcounting indicators (percent of rows)
    percent_metrics = ['f(A)>a %', '|f(A)-a|>|f(A)-(a+b)| %']
    percent_values = summary_df.loc[0, percent_metrics].astype(float).values

    plt.figure(figsize=(7,5))
    plt.bar(percent_metrics, percent_values)
    plt.ylabel('Percentage')
    plt.title(f'{title_prefix}: Overcounting Indicators')
    plt.xticks(rotation=10, ha='right')
    plt.grid(True, axis='y', alpha=0.3)
    plt.tight_layout()
    if SAVE_FIGS:
        # OUTPUT_DIR was already created for the first chart
        plt.savefig(os.path.join(OUTPUT_DIR, f'{file_stem}_overcounting_percentages.png'), dpi=300, bbox_inches='tight')
        plt.savefig(os.path.join(OUTPUT_DIR, f'{file_stem}_overcounting_percentages.pdf'), bbox_inches='tight')
    plt.show()

# ---------------- run ----------------
# Number formats shared by the overall and the grouped summary tables
_SUMMARY_FORMATS = {
    'mean |f(A)-a|': '{:.2f}',
    'mean |f(A+B)-(a+b)|': '{:.2f}',
    'mean |f(A)-(a+b)|': '{:.2f}',
    'mean |f(A)-f(A+B)|': '{:.2f}',
    'f(A)>a %': '{:.1f}',
    '|f(A)-a|>|f(A)-(a+b)| %': '{:.1f}',
    'total_predictions': '{:,.0f}'
}


def _show_breakdown(table, heading):
    # Display a grouped summary under `heading`; empty tables are skipped
    if table.empty:
        return
    table.insert(0, 'Model', MODEL_NAME)
    display(Markdown(heading))
    display(table.style.format(_SUMMARY_FORMATS))


display(Markdown(f"# {MODEL_NAME} Overcounting Experiment"))
print("Hypothesis: the single-class prediction f(A) is closer to total count (a+b) than to a.\n")
print(f"Single-class data: {QUAL_PATH}")
print("Combined data    :", (COMB_PATH or "(not found — proceeding without it)"))

# The single-class file is mandatory; the combined file is optional
if not os.path.exists(QUAL_PATH):
    raise FileNotFoundError(f"Qualitative data file not found: {QUAL_PATH}")

preds_dict = load_predictions_countgd(QUAL_PATH, COMB_PATH)
df_img = per_image_metrics(preds_dict)

print(f"\nParsed {len(df_img):,} per-(image,class) rows.\n")
if not df_img.empty:
    display(Markdown("### Sample rows"))
    display(df_img.head(10))
else:
    display(Markdown("> **No rows parsed** — check that your qualitative JSON includes `gt_pos_count/gt_neg_count/gt_total_count` or at least per-row `gt_count` for the queried class."))

# Overall summary across all rows
summary = summarize_metrics(df_img)
summary.insert(0, 'Model', MODEL_NAME)
display(Markdown("## Overall Summary"))
display(summary.style.format(_SUMMARY_FORMATS))

# Breakdowns by test type and by super category
by_test = summarize_by(df_img, ['test_type'])
_show_breakdown(by_test, "## By Test Type")

by_super = summarize_by(df_img, ['super_category'])
_show_breakdown(by_super, "## By Super Category")

plot_overall_bars(summary, title_prefix=MODEL_NAME)

# Export the per-row and summary tables next to the figures
if SAVE_FIGS:
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    _per_image_csv = os.path.join(OUTPUT_DIR, f'{MODEL_NAME.lower()}_per_image_metrics_{DATASET}.csv')
    _summary_csv = os.path.join(OUTPUT_DIR, f'{MODEL_NAME.lower()}_overcounting_summary_{DATASET}.csv')
    df_img.to_csv(_per_image_csv, index=False)
    summary.to_csv(_summary_csv, index=False)
    print(f"Saved tables:\n  {_per_image_csv}\n  {_summary_csv}")
