# PairTally Dataset Demonstration

## Overview

PairTally is the first benchmark specifically designed to evaluate fine-grained visual counting capabilities in computer vision models. This notebook provides a comprehensive demonstration of the dataset structure, visualization capabilities, and evaluation pipeline.

### Paper

**Can Current AI Models Count What We Mean, Not What They See? A Benchmark and Systematic Evaluation**  
Gia Khanh Nguyen, Yifeng Huang, Minh Hoai  
Digital Image Computing: Techniques and Applications (DICTA) 2025

### Dataset Specifications

- **Total Images**: 681 high-resolution images
- **Categories**: 54 object categories comprising 98 subcategories
- **Task Types**: Inter-category (different objects) and Intra-category (same object, different attributes)
- **Attribute Differences** (intra-category pairs): Color (43.5%), Shape/Texture (42.5%), Size (14.0%)

### Key Findings

Current state-of-the-art models achieve a Mean Absolute Error (MAE) of 53.07, revealing critical gaps in fine-grained visual understanding and discrimination capabilities.

## 1. Environment Setup and Dependencies

In [None]:
%pip install pandas matplotlib seaborn scikit-learn

In [None]:
# Standard library imports
import os
import json
import random
import warnings
from pathlib import Path
from typing import Dict, List, Tuple, Optional

# Data handling
import numpy as np
import pandas as pd

# Visualization
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Rectangle
import seaborn as sns

# Image processing
from PIL import Image
import cv2

# Progress bar
from tqdm.notebook import tqdm

# Configure visualization settings.
# matplotlib >= 3.6 ships the seaborn styles as 'seaborn-v0_8-*'; older
# releases only know 'seaborn-whitegrid'. Use whichever exists instead of
# crashing with OSError on an unknown style name.
for _style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break
else:
    print("Seaborn whitegrid style not available; using matplotlib defaults")
sns.set_palette("husl")
# Deliberately silences *all* warnings to keep the notebook output clean;
# remove this line when debugging.
warnings.filterwarnings('ignore')

# Set random seed for reproducibility (image sampling below uses `random`)
random.seed(42)
np.random.seed(42)

print("Environment configured successfully")

## 2. Dataset Configuration

In [None]:
# Dataset / model locations, all anchored at the notebook's working directory.
BASE_DIR = Path(".").resolve()
DATASET_DIR = BASE_DIR.joinpath("dataset", "pairtally_dataset")  # dataset root
ANNOTATIONS_DIR = DATASET_DIR.joinpath("annotations")             # annotation JSON files
IMAGES_DIR = DATASET_DIR.joinpath("images")                       # image files
MODELS_DIR = BASE_DIR.joinpath("models")                          # model code / checkpoints

def verify_dataset():
    """Check the PairTally directory layout and report how many images exist.

    Checks ``DATASET_DIR``, ``ANNOTATIONS_DIR`` and ``IMAGES_DIR``; every
    missing directory is printed as an ERROR (with a download hint for the
    images). When all exist, the number of image files is printed, and a
    warning is added if the images directory holds none.

    Returns:
        bool: True when all three directories exist, False otherwise.
    """
    issues = []

    if not DATASET_DIR.exists():
        issues.append(f"Dataset directory not found: {DATASET_DIR}")

    if not ANNOTATIONS_DIR.exists():
        issues.append(f"Annotations directory not found: {ANNOTATIONS_DIR}")

    if not IMAGES_DIR.exists():
        issues.append(f"Images directory not found: {IMAGES_DIR}")
        issues.append("Download images from: https://drive.google.com/file/d/1TnenXS4yFicjo81NnmClfzgc8ltmmeBv/view")

    if issues:
        for issue in issues:
            print(f"ERROR: {issue}")
        return False

    # Match extensions case-insensitively so .JPG / .jpeg / .png files are
    # counted too (a plain "*.jpg" glob misses them).
    image_suffixes = {".jpg", ".jpeg", ".png"}
    num_images = sum(
        1 for p in IMAGES_DIR.glob("*")
        if p.is_file() and p.suffix.lower() in image_suffixes
    )
    print("Dataset verified successfully")
    print(f"Found {num_images} images at {DATASET_DIR}")
    if num_images == 0:
        print(f"WARNING: no images found in {IMAGES_DIR}; images cannot be displayed")
    return True


dataset_ready = verify_dataset()

## 3. Load Dataset Annotations

In [None]:
class PairTallyDataset:
    """PairTally dataset interface for loading and accessing annotations.

    All annotation JSON files are read eagerly in ``__init__`` from
    ``<dataset_dir>/annotations``; image paths resolve against
    ``<dataset_dir>/images``. A missing JSON file is reported and treated as
    an empty mapping, so the notebook still runs on a partial download.
    """

    def __init__(self, dataset_dir: Path, version: str = "simple"):
        """
        Args:
            dataset_dir: Dataset root directory (``Path`` or ``str``).
            version: Annotation variant, "simple" or "augmented"; selects
                ``pairtally_annotations[_inter|_intra]_<version>.json``.
        """
        # Coerce so a plain string path works with the ``/`` joins below.
        self.dataset_dir = Path(dataset_dir)
        self.annotations_dir = self.dataset_dir / "annotations"
        self.images_dir = self.dataset_dir / "images"
        self.version = version  # "simple" or "augmented"

        # Annotation mappings keyed by image filename (see get_annotation /
        # get_image_path).
        self.annotations = self._load_annotations()
        self.inter_annotations = self._load_inter_annotations()
        self.intra_annotations = self._load_intra_annotations()

        # Load metadata
        self.metadata = self._load_metadata()
        self.filename_mapping = self._load_filename_mapping()

        print(f"Loaded {len(self.annotations)} total annotations")
        print(f"  Inter-category: {len(self.inter_annotations)} images")
        print(f"  Intra-category: {len(self.intra_annotations)} images")

    def _load_json(self, filename: str) -> Dict:
        """Load ``annotations_dir / filename``; return ``{}`` (with a warning) if absent."""
        filepath = self.annotations_dir / filename
        if not filepath.exists():
            # Best-effort: keep going with an empty mapping, but say so instead
            # of silently reporting "0 annotations".
            print(f"WARNING: annotation file not found: {filepath}")
            return {}
        # JSON is UTF-8 by specification; don't rely on the platform default.
        with open(filepath, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _load_annotations(self) -> Dict:
        """Load main annotations (all images)."""
        return self._load_json(f"pairtally_annotations_{self.version}.json")

    def _load_inter_annotations(self) -> Dict:
        """Load inter-category annotations."""
        return self._load_json(f"pairtally_annotations_inter_{self.version}.json")

    def _load_intra_annotations(self) -> Dict:
        """Load intra-category annotations."""
        return self._load_json(f"pairtally_annotations_intra_{self.version}.json")

    def _load_metadata(self) -> Dict:
        """Load image metadata."""
        return self._load_json("image_metadata.json")

    def _load_filename_mapping(self) -> Dict:
        """Load filename mapping."""
        return self._load_json("filename_mapping.json")

    def get_random_image(self, subset: str = "all") -> Optional[str]:
        """Return a random image filename, or None if the subset is empty.

        Args:
            subset: "inter", "intra", or anything else for all annotations.
        """
        if subset == "inter":
            images = list(self.inter_annotations.keys())
        elif subset == "intra":
            images = list(self.intra_annotations.keys())
        else:
            images = list(self.annotations.keys())

        return random.choice(images) if images else None

    def get_annotation(self, image_name: str) -> Dict:
        """Return the main annotation dict for ``image_name`` ({} if unknown)."""
        return self.annotations.get(image_name, {})

    def get_counts(self, image_name: str) -> Tuple[int, int]:
        """Return ``(positive, negative)`` counts = number of point annotations per class.

        Unknown images yield ``(0, 0)``.
        """
        anno = self.get_annotation(image_name)
        return len(anno.get('points', [])), len(anno.get('negative_points', []))

    def get_image_path(self, image_name: str) -> Path:
        """Return the full path of ``image_name`` inside the images directory."""
        return self.images_dir / image_name

# Build the dataset object only when the directory check above succeeded,
# then echo the per-subset sizes.
if dataset_ready:
    dataset = PairTallyDataset(DATASET_DIR, version="simple")
    _summary_rows = (
        ("Total images", dataset.annotations),
        ("Inter-category pairs", dataset.inter_annotations),
        ("Intra-category pairs", dataset.intra_annotations),
    )
    print("\nDataset Statistics:")
    for _label, _mapping in _summary_rows:
        print(f"{_label}: {len(_mapping)}")

## 4. Visualization Functions

In [None]:
def _corner_box_to_xyxy(box_coords):
    """Return ``(x1, y1, x2, y2)`` for a 4-corner box given as ``[[x, y], ...]``.

    Returns None when the entry does not have exactly four corners, so callers
    can skip malformed exemplars.
    """
    if len(box_coords) != 4:
        return None
    x_coords = [pt[0] for pt in box_coords]
    y_coords = [pt[1] for pt in box_coords]
    return min(x_coords), min(y_coords), max(x_coords), max(y_coords)


def visualize_image_with_annotations(dataset: PairTallyDataset, 
                                    image_name: str,
                                    show_boxes: bool = True,
                                    show_points: bool = False,
                                    figsize: Optional[Tuple[float, float]] = None,
                                    max_exemplars: int = 3):
    """
    Visualize an image with bounding box annotations and object counts.

    The title shows each class prompt with its number of point annotations
    (positive class in blue, negative class in red).

    Args:
        dataset: PairTallyDataset instance
        image_name: Name of the image to visualize
        show_boxes: Whether to draw the exemplar bounding boxes
        show_points: Whether to scatter the point annotations
        figsize: Figure size; derived from the image aspect ratio when None
        max_exemplars: Maximum number of exemplar boxes drawn per class (default 3)

    Returns:
        ``(positive_points, negative_points)`` from the annotation, or None if
        the image file does not exist.
    """
    pos_color, neg_color = '#0066FF', '#FF0040'

    img_path = dataset.get_image_path(image_name)
    annotation = dataset.get_annotation(image_name)

    if not img_path.exists():
        print(f"ERROR: Image not found: {img_path}")
        return

    # Convert to RGB so grayscale / palette images are shown in their real
    # colours (imshow would apply a colormap to single-channel data), and use
    # a context manager so the file handle is released right away.
    with Image.open(img_path) as img:
        img_array = np.array(img.convert('RGB'))

    height, width = img_array.shape[:2]

    if figsize is None:
        # Preserve the image aspect ratio; the extra unit of height is
        # presumably headroom for the count title.
        fig_width = 10
        figsize = (fig_width, fig_width * height / width + 1)

    positive_points = annotation.get('points', [])
    negative_points = annotation.get('negative_points', [])
    positive_boxes = annotation.get('box_examples_coordinates', [])
    # Key spelled "exemples" here and in visualize_countgd_results;
    # NOTE(review): confirm against the annotation JSON before renaming.
    negative_boxes = annotation.get('negative_box_exemples_coordinates', [])
    positive_prompt = annotation.get('positive_prompt', 'Class 1')
    negative_prompt = annotation.get('negative_prompt', 'Class 2')

    fig = plt.figure(figsize=figsize, facecolor='white')
    ax = fig.add_subplot(111)
    ax.imshow(img_array)
    ax.axis('off')

    # Title: "<positive>: N | <negative>: M", centred on the separator.
    fig.text(0.45, 0.95, f"{positive_prompt}: {len(positive_points)}",
             ha='right', va='top', fontsize=14, color=pos_color,
             fontweight='bold', transform=fig.transFigure)
    fig.text(0.5, 0.95, " | ",
             ha='center', va='top', fontsize=14, color='#666666',
             fontweight='normal', transform=fig.transFigure)
    fig.text(0.55, 0.95, f"{negative_prompt}: {len(negative_points)}",
             ha='left', va='top', fontsize=14, color=neg_color,
             fontweight='bold', transform=fig.transFigure)

    if show_boxes:
        limit = max(0, max_exemplars)
        for boxes, color in ((positive_boxes, pos_color), (negative_boxes, neg_color)):
            for box_coords in boxes[:limit]:
                xyxy = _corner_box_to_xyxy(box_coords)
                if xyxy is None:
                    continue
                x1, y1, x2, y2 = xyxy
                ax.add_patch(Rectangle((x1, y1), x2 - x1, y2 - y1,
                                       linewidth=2, edgecolor=color,
                                       facecolor='none'))

    if show_points:
        for points, color in ((positive_points, pos_color), (negative_points, neg_color)):
            if points:
                pts = np.asarray(points)
                ax.scatter(pts[:, 0], pts[:, 1],
                           c=color, s=8, alpha=0.4, marker='.')

    plt.tight_layout(rect=[0, 0.02, 1, 0.94])
    plt.show()

    return positive_points, negative_points

def display_dataset_statistics(dataset: PairTallyDataset):
    """Display comprehensive statistics about the PairTally dataset.

    Uses the per-image point counts from ``dataset.get_counts`` to build a
    2x2 figure:
    - a histogram of total counts with the mean marked;
    - side-by-side histograms of positive and negative counts;
    - a positive-vs-negative scatter with an equal-count reference line;
    - a text panel of summary statistics.

    Prints a message and returns without plotting when no annotations are
    loaded, because min/max are undefined on empty input.
    """
    image_names = list(dataset.annotations)
    if not image_names:
        print("No annotations loaded - cannot compute dataset statistics")
        return

    counts = [dataset.get_counts(name) for name in image_names]
    all_positive_counts = [pos for pos, _ in counts]
    all_negative_counts = [neg for _, neg in counts]
    all_total_counts = [pos + neg for pos, neg in counts]

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.patch.set_facecolor('white')

    # Distribution of total counts
    axes[0, 0].hist(all_total_counts, bins=30, color='#2E86AB', edgecolor='#1B4F72', alpha=0.8)
    axes[0, 0].set_title('Distribution of Total Object Counts', fontsize=12, fontweight='bold')
    axes[0, 0].set_xlabel('Total Count')
    axes[0, 0].set_ylabel('Frequency')
    axes[0, 0].axvline(np.mean(all_total_counts), color='#E74C3C', linestyle='--', linewidth=2,
                      label=f'Mean: {np.mean(all_total_counts):.1f}')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Distribution of positive vs negative counts
    axes[0, 1].hist([all_positive_counts, all_negative_counts], 
                   bins=25, label=['Positive Class', 'Negative Class'],
                   color=['#0066FF', '#FF0040'], alpha=0.7, edgecolor='#2C3E50')
    axes[0, 1].set_title('Distribution of Class Counts', fontsize=12, fontweight='bold')
    axes[0, 1].set_xlabel('Count')
    axes[0, 1].set_ylabel('Frequency')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Scatter plot of positive vs negative counts
    axes[1, 0].scatter(all_positive_counts, all_negative_counts, 
                      alpha=0.6, s=25, c='#8E44AD', edgecolors='#5B2C6F', linewidth=0.5)
    axes[1, 0].set_title('Positive vs Negative Class Counts', fontsize=12, fontweight='bold')
    axes[1, 0].set_xlabel('Positive Class Count')
    axes[1, 0].set_ylabel('Negative Class Count')
    # The y = x reference line must span both axes; using only the positive
    # maximum cut it short whenever negative counts were larger.
    diag_max = max(max(all_positive_counts), max(all_negative_counts))
    axes[1, 0].plot([0, diag_max], [0, diag_max], 
                   'k--', alpha=0.3, label='Equal counts', linewidth=1)
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)

    # Summary statistics
    stats_text = f"""Dataset Summary Statistics
    
Total Images: {len(dataset.annotations)}
Inter-category: {len(dataset.inter_annotations)}
Intra-category: {len(dataset.intra_annotations)}

Object Count Statistics:
  Mean total: {np.mean(all_total_counts):.1f} ± {np.std(all_total_counts):.1f}
  Min/Max: {min(all_total_counts)} / {max(all_total_counts)}
  Median: {np.median(all_total_counts):.1f}

Positive Class:
  Mean: {np.mean(all_positive_counts):.1f} ± {np.std(all_positive_counts):.1f}
  Range: [{min(all_positive_counts)}, {max(all_positive_counts)}]

Negative Class:
  Mean: {np.mean(all_negative_counts):.1f} ± {np.std(all_negative_counts):.1f}
  Range: [{min(all_negative_counts)}, {max(all_negative_counts)}]
"""

    axes[1, 1].text(0.1, 0.5, stats_text, transform=axes[1, 1].transAxes,
                   fontsize=11, verticalalignment='center', fontfamily='monospace',
                   bbox=dict(boxstyle='round,pad=0.8', facecolor='#F8F9FA', alpha=0.9, edgecolor='#BDC3C7'))
    axes[1, 1].axis('off')

    plt.suptitle('PairTally Dataset Statistics', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Display statistics
if dataset_ready:
    print("Generating dataset statistics visualization...")
    display_dataset_statistics(dataset)

## 5. Demo: Select and Visualize Random Images

In [None]:
def demo_random_images(dataset: PairTallyDataset, 
                       num_images: int = 3,
                       subset: str = "all"):
    """
    Demonstrate dataset visualization with random image selection.

    Images are drawn without replacement, so one call never shows the same
    image twice. If the subset holds fewer than ``num_images`` images, all of
    them are shown.

    Args:
        dataset: PairTallyDataset instance
        num_images: Number of random images to display
        subset: Which subset to use ("inter", "intra", anything else = all)
    """
    print(f"\nSelecting {num_images} random images from {subset} subset...\n")

    # Same subset -> pool mapping as PairTallyDataset.get_random_image.
    if subset == "inter":
        pool = list(dataset.inter_annotations)
    elif subset == "intra":
        pool = list(dataset.intra_annotations)
    else:
        pool = list(dataset.annotations)

    if not pool:
        print(f"No images available in the {subset} subset")
        return

    selected = random.sample(pool, max(0, min(num_images, len(pool))))

    for i, image_name in enumerate(selected, start=1):
        annotation = dataset.get_annotation(image_name)
        pos_count, neg_count = dataset.get_counts(image_name)

        pos_prompt = annotation.get('positive_prompt', 'Class 1')
        neg_prompt = annotation.get('negative_prompt', 'Class 2')

        # Label by actual subset membership; an image listed in neither
        # subset file is reported as UNKNOWN rather than assumed INTRA.
        if image_name in dataset.inter_annotations:
            category_type = "INTER"
        elif image_name in dataset.intra_annotations:
            category_type = "INTRA"
        else:
            category_type = "UNKNOWN"

        print(f"Image {i}/{len(selected)}: {image_name}")
        print(f"  Category Type: {category_type}-category")
        print(f"  Classes: {pos_prompt} vs {neg_prompt}")
        print(f"  Counts: {pos_prompt}={pos_count}, {neg_prompt}={neg_count}, Total={pos_count+neg_count}")
        print()

        visualize_image_with_annotations(dataset, image_name, 
                                         show_boxes=True, show_points=False)
        print("-" * 80)

def _print_and_show_example(dataset, image_name, header):
    """Print the class prompts and counts of ``image_name`` under ``header``, then plot it."""
    annotation = dataset.get_annotation(image_name)
    pos_count, neg_count = dataset.get_counts(image_name)
    pos_prompt = annotation.get('positive_prompt', 'Class 1')
    neg_prompt = annotation.get('negative_prompt', 'Class 2')

    print(f"\n{header}: {image_name}")
    print(f"  Classes: {pos_prompt} vs {neg_prompt}")
    print(f"  Counts: {pos_prompt}={pos_count}, {neg_prompt}={neg_count}")
    print()

    visualize_image_with_annotations(dataset, image_name, 
                                     show_boxes=True, show_points=False)


# Filename keywords used to guess which attribute an intra-category image
# varies. Matching is case-insensitive substring search, so it is only a
# heuristic (e.g. 'red' also matches any filename containing "red").
_ATTRIBUTE_KEYWORDS = {
    "color": ['COL', 'color', 'black', 'white', 'red', 'blue', 'green', 'yellow'],
    "size": ['SIZ', 'size', 'big', 'small', 'large', 'tiny'],
}
_TEXTURE_SHAPE_KEYWORDS = ['TEX', 'SHA', 'round', 'square', 'smooth', 'rough']


def demo_specific_attribute_images(dataset: PairTallyDataset, 
                                  attribute_type: str = "color",
                                  specific_image: Optional[str] = None):
    """
    Demonstrate specific attribute-based intra-category examples.

    Args:
        dataset: PairTallyDataset instance
        attribute_type: Type of attribute difference ("color", "size"; any
            other value is treated as texture/shape). Only used when
            ``specific_image`` is not given.
        specific_image: Specific image name to display (optional). When given,
            it must be present in the main annotations.
    """
    if specific_image:
        if specific_image not in dataset.annotations:
            print(f"ERROR: Image not found: {specific_image}")
            return
        _print_and_show_example(dataset, specific_image, "Specific Image")
        return

    intra_images = list(dataset.intra_annotations.keys())
    if not intra_images:
        print("No intra-category images available")
        return

    keywords = [kw.lower() for kw in _ATTRIBUTE_KEYWORDS.get(attribute_type, _TEXTURE_SHAPE_KEYWORDS)]
    candidates = [img for img in intra_images
                  if any(kw in img.lower() for kw in keywords)]

    # Fall back to any intra-category image when no filename matches.
    selected = random.choice(candidates or intra_images)
    _print_and_show_example(dataset, selected,
                            f"{attribute_type.capitalize()} Difference Example")

# Run comprehensive demonstration: random inter/intra examples followed by
# hand-picked intra-category images, one per attribute type.
if dataset_ready:
    print("=" * 80)
    print("COMPREHENSIVE DATASET VISUALIZATION")
    print("=" * 80)

    # Sections 1 and 2: three random images from each task type.
    for section_title, section_subset in (
        ("1. INTER-CATEGORY EXAMPLES (Different Object Types)", "inter"),
        ("2. INTRA-CATEGORY EXAMPLES (Same Object, Different Attributes)", "intra"),
    ):
        print("\n" + "=" * 60)
        print(section_title)
        print("=" * 60)
        demo_random_images(dataset, num_images=3, subset=section_subset)

    # Section 3: fixed examples of each attribute difference.
    print("\n" + "=" * 60)
    print("3. SPECIFIC ATTRIBUTE DIFFERENCES")
    print("=" * 60)

    print("\nDemonstrating attribute-based discrimination challenges:")

    color_image = "FOO_INTRA_CFC1_CFC2_077_077_33f3fb.jpg"
    size_image = "HOU_INTRA_BAT1_BAT2_014_012_578d7d.jpg"
    texture_image = "OTR_INTRA_NUT1_NUT2_030_055_f00b70.jpg"

    for example_heading, example_image in (
        ("COLOR DIFFERENCE EXAMPLE:", color_image),
        ("SIZE DIFFERENCE EXAMPLE:", size_image),
        ("TEXTURE/SHAPE DIFFERENCE EXAMPLE:", texture_image),
    ):
        print("\n" + example_heading)
        demo_specific_attribute_images(dataset, specific_image=example_image)

## 6. CountGD Model Evaluation Demo

This section demonstrates how to use the CountGD model on PairTally dataset images. CountGD is a state-of-the-art object counting model that can use both exemplar boxes and text prompts for counting.

In [None]:
# Import CountGD model components
import sys
import torch
from pathlib import Path
import torchvision.transforms.functional as F

# Add CountGD to path (MODELS_DIR / "countgd" / "CountGD")
countgd_path = MODELS_DIR / "countgd" / "CountGD"
if not countgd_path.is_dir():
    print(f"CountGD source directory not found: {countgd_path}")

# Prepend rather than append so CountGD's top-level packages (`util`,
# `models`, `datasets_inference`) take precedence over same-named entries
# already on sys.path -- the notebook directory itself contains a `models/`
# folder. The membership check keeps re-runs of this cell from adding
# duplicate entries.
if str(countgd_path) not in sys.path:
    sys.path.insert(0, str(countgd_path))

try:
    from util.slconfig import SLConfig, DictAction
    from util.misc import nested_tensor_from_tensor_list
    import datasets_inference.transforms as T
    from models.registry import MODULE_BUILD_FUNCS

    print("CountGD imports successful")
    countgd_available = True
except ImportError as e:
    print(f"CountGD import failed: {e}")
    print("Make sure CountGD dependencies are installed")
    countgd_available = False

def load_countgd_model():
    """Build the CountGD model and load its FSC-147 checkpoint.

    Expects, under ``countgd_path``:
    - the config ``config/cfg_fsc147_vit_b.py``;
    - the checkpoint ``checkpoints/checkpoint_fsc147_best.pth``;
    - the text encoder at ``checkpoints/bert-base-uncased``.

    Returns:
        ``(model, transforms)`` on success. ``(None, None)`` when CountGD
        could not be imported, a required file is missing, or any step of
        loading fails; in that case the reason is printed.
    """
    if not countgd_available:
        print("CountGD not available - skipping model loading")
        return None, None

    config_path = countgd_path / "config" / "cfg_fsc147_vit_b.py"
    checkpoint_path = countgd_path / "checkpoints" / "checkpoint_fsc147_best.pth"

    if not config_path.exists():
        print(f"Config file not found: {config_path}")
        print("Please ensure CountGD model files are properly set up")
        return None, None

    if not checkpoint_path.exists():
        print(f"Checkpoint file not found: {checkpoint_path}")
        print("Please download the CountGD model checkpoint")
        return None, None

    try:
        cfg = SLConfig.fromfile(str(config_path))
        cfg.merge_from_dict({"text_encoder_type": str(countgd_path / "checkpoints" / "bert-base-uncased")})

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        build_func = MODULE_BUILD_FUNCS.get("groundingdino")
        if build_func is None:
            print("Model builder not found")
            return None, None

        # The builder takes an attribute namespace: start with modelname/device,
        # then copy every config entry onto it (config values win on clashes).
        class Args:
            def __init__(self):
                self.modelname = "groundingdino"
                self.device = str(device)

        args = Args()
        for k, v in cfg._cfg_dict.items():
            setattr(args, k, v)

        model, _, _ = build_func(args)
        model.to(device)

        print("Loading checkpoint...")
        # weights_only=False is required on PyTorch >= 2.6 for this full-pickle
        # checkpoint. Unpickling can execute arbitrary code, so only load
        # checkpoints obtained from a trusted source.
        checkpoint = torch.load(str(checkpoint_path), map_location=device, weights_only=False)
        # The weights live under "model"; also accept a bare state dict.
        if isinstance(checkpoint, dict) and "model" in checkpoint:
            state_dict = checkpoint["model"]
        else:
            state_dict = checkpoint
        load_result = model.load_state_dict(state_dict, strict=False)
        # strict=False tolerates mismatched keys silently; report them so a
        # wrong or partial checkpoint does not go unnoticed.
        if load_result.missing_keys or load_result.unexpected_keys:
            print(f"Checkpoint mismatch: {len(load_result.missing_keys)} missing keys, "
                  f"{len(load_result.unexpected_keys)} unexpected keys")
        model.eval()

        # Inference preprocessing: resize (shorter side 800, max 1333),
        # to tensor, ImageNet mean/std normalization.
        transforms = T.Compose([
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        print("CountGD model loaded successfully!")
        return model, transforms

    except Exception as e:
        # Notebook boundary: report the full traceback but keep the notebook running.
        print(f"Failed to load CountGD model: {e}")
        import traceback
        traceback.print_exc()
        return None, None

# Load the model. The globals are always defined so later cells can test
# them even when CountGD is unavailable.
countgd_model, countgd_transform = None, None
if countgd_available:
    print("Loading CountGD model...")
    countgd_model, countgd_transform = load_countgd_model()
else:
    print("CountGD model not available for this demo")
model_ready = countgd_model is not None

print(f"Model ready status: {model_ready}")

In [None]:
def run_countgd_inference(model, transform, image, exemplar_boxes, text_prompt,
                          confidence_thresh=0.3, max_exemplars=3):
    """
    Run CountGD inference on an image with exemplar boxes and a text prompt.

    Args:
        model: CountGD model (torch module; inputs are moved to the device of
            its parameters).
        transform: Pipeline called as ``transform(image, target)`` and returning
            ``(tensor, target)``. It must rescale ``target["exemplars"]``
            together with the image.
        image: PIL Image.
        exemplar_boxes: Exemplar boxes as ``[x1, y1, x2, y2]`` in *pixel*
            coordinates of ``image``. May be empty/None for text-only counting.
        text_prompt: Text description of objects to count.
        confidence_thresh: A query is kept when its highest sigmoid score
            exceeds this value.
        max_exemplars: Maximum number of exemplars passed to the model (default 3).

    Returns:
        pred_boxes: Boxes of the kept queries, in the model's ``pred_boxes``
            format (presumably normalized cx, cy, w, h -- confirm against CountGD).
        pred_logits: Sigmoid scores of the kept queries.
        pred_count: Number of kept queries.
    """
    device = next(model.parameters()).device

    # Keep exemplars in pixel coordinates. CountGD's datasets_inference
    # transforms rescale target["exemplars"] by the same ratio as the image, and
    # CountGD's own demo passes pixel boxes. Normalizing to [0, 1] first (as
    # before) shrank every exemplar to a sub-pixel box after resizing.
    # NOTE(review): confirm against datasets_inference/transforms.py.
    if exemplar_boxes:
        rows = []
        for box in exemplar_boxes[:max(0, max_exemplars)]:
            x1, y1, x2, y2 = box  # enforces exactly four values per box
            rows.append([float(x1), float(y1), float(x2), float(y2)])
        exemplar_tensor = torch.tensor(rows, dtype=torch.float32).reshape(-1, 4)
    else:
        # Text-only counting: an empty (0, 4) exemplar tensor.
        exemplar_tensor = torch.zeros((0, 4), dtype=torch.float32)

    # Prepare input
    input_image, target = transform(image, {"exemplars": exemplar_tensor})
    input_image = input_image.to(device)
    input_exemplar = target["exemplars"].to(device)

    # Caption format used with CountGD: prompt followed by " ."
    input_text = text_prompt + " ."

    with torch.no_grad():
        model_output = model(
            input_image.unsqueeze(0),
            [input_exemplar],
            # presumably the label of the exemplars' class (single class 0)
            [torch.tensor([0]).to(device)],
            captions=[input_text],
        )

    # First (only) batch element: per-query token scores and boxes.
    logits = model_output["pred_logits"][0].sigmoid()
    boxes = model_output["pred_boxes"][0]

    # Keep queries whose best token score clears the threshold.
    box_mask = logits.max(dim=-1).values > confidence_thresh
    filtered_logits = logits[box_mask, :]
    filtered_boxes = boxes[box_mask, :]
    pred_count = filtered_boxes.shape[0]

    return filtered_boxes, filtered_logits, pred_count

def visualize_countgd_results(dataset, image_name, model, transform, confidence_thresh=0.3):
    """
    Run CountGD on a PairTally image and visualize results comparing ground truth vs predictions.
    """
    if not model_ready:
        print("CountGD model not available")
        return
        
    try:
        # Get image and annotation
        img_path = dataset.get_image_path(image_name)
        annotation = dataset.get_annotation(image_name)
        
        if not img_path.exists():
            print(f"ERROR: Image not found: {img_path}")
            return
        
        # Load image
        image = Image.open(img_path).convert('RGB')
        img_array = np.array(image)
        
        # Get annotation data
        positive_points = annotation.get('points', [])
        negative_points = annotation.get('negative_points', [])
        positive_boxes = annotation.get('box_examples_coordinates', [])
        negative_boxes = annotation.get('negative_box_exemples_coordinates', [])
        positive_prompt = annotation.get('positive_prompt', 'objects')
        negative_prompt = annotation.get('negative_prompt', 'objects')
        
        # Convert FSC147 format boxes to [x1, y1, x2, y2] format
        def convert_boxes(box_coords_list):
            boxes = []
            for box_coords in box_coords_list[:3]:  # Use first 3 exemplars
                if len(box_coords) == 4:
                    x_coords = [pt[0] for pt in box_coords]
                    y_coords = [pt[1] for pt in box_coords]
                    x1, x2 = min(x_coords), max(x_coords)
                    y1, y2 = min(y_coords), max(y_coords)
                    boxes.append([x1, y1, x2, y2])
            return boxes
        
        pos_exemplar_boxes = convert_boxes(positive_boxes)
        neg_exemplar_boxes = convert_boxes(negative_boxes)
        
        # Get image dimensions for box conversion
        img_width, img_height = image.size
        
        print(f"Running CountGD on: {image_name}")
        print(f"Ground Truth - {positive_prompt}: {len(positive_points)}, {negative_prompt}: {len(negative_points)}")
        
        # Run inference for positive class
        print(f"\nInferring positive class ({positive_prompt})...")
        pos_pred_boxes, pos_pred_logits, pos_pred_count = run_countgd_inference(
            model, transform, image, pos_exemplar_boxes, positive_prompt, confidence_thresh=confidence_thresh
        )
        
        # Run inference for negative class  
        print(f"\nInferring negative class ({negative_prompt})...")
        neg_pred_boxes, neg_pred_logits, neg_pred_count = run_countgd_inference(
            model, transform, image, neg_exemplar_boxes, negative_prompt, confidence_thresh=confidence_thresh
        )
        
        print(f"\nCountGD Predictions - {positive_prompt}: {pos_pred_count}, {negative_prompt}: {neg_pred_count}")
        
        # Create visualization
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        fig.suptitle(f'CountGD Evaluation: {image_name}', fontsize=14, fontweight='bold')
        
        # 1. Ground Truth with exemplars (no dots, just exemplar boxes)
        axes[0].imshow(img_array)
        axes[0].set_title(f'Ground Truth + Exemplars\n{positive_prompt}: {len(positive_points)}, {negative_prompt}: {len(negative_points)}', fontsize=10)
        axes[0].axis('off')
        
        # Draw exemplar boxes only (no ground truth points)
        for box in pos_exemplar_boxes:
            x1, y1, x2, y2 = box
            rect = Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='#0066FF', facecolor='none', linestyle='--')
            axes[0].add_patch(rect)
        
        for box in neg_exemplar_boxes:
            x1, y1, x2, y2 = box
            rect = Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='#FF0040', facecolor='none', linestyle='--')
            axes[0].add_patch(rect)
        
        # 2. Positive class predictions
        axes[1].imshow(img_array)
        axes[1].set_title(f'{positive_prompt} Predictions\nGT: {len(positive_points)}, Pred: {pos_pred_count}', fontsize=10)
        axes[1].axis('off')
        
        # Draw positive exemplars (dashed)
        for box in pos_exemplar_boxes:
            x1, y1, x2, y2 = box
            rect = Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='#0066FF', facecolor='none', linestyle='--')
            axes[1].add_patch(rect)
        
        # Draw positive predictions (solid boxes)
        if len(pos_pred_boxes) > 0 and pos_pred_count > 0:
            # Convert from normalized coordinates [0,1] to pixel coordinates
            pos_pred_boxes_pixel = pos_pred_boxes.clone().cpu().numpy()
            
            # CountGD outputs are in format [cx, cy, w, h] normalized
            # Convert to [x1, y1, x2, y2] pixel coordinates
            for box in pos_pred_boxes_pixel:
                cx, cy, w, h = box
                # Convert from normalized [cx, cy, w, h] to pixel [x1, y1, x2, y2]
                x1 = (cx - w/2) * img_width
                y1 = (cy - h/2) * img_height
                x2 = (cx + w/2) * img_width  
                y2 = (cy + h/2) * img_height
                
                # Clamp to image bounds
                x1 = max(0, min(img_width, x1))
                y1 = max(0, min(img_height, y1))
                x2 = max(0, min(img_width, x2))
                y2 = max(0, min(img_height, y2))
                
                if x2 > x1 and y2 > y1:  # Valid box
                    rect = Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='#0066FF', facecolor='none')
                    axes[1].add_patch(rect)
        
        # 3. Negative class predictions
        axes[2].imshow(img_array)
        axes[2].set_title(f'{negative_prompt} Predictions\nGT: {len(negative_points)}, Pred: {neg_pred_count}', fontsize=10)
        axes[2].axis('off')
        
        # Draw negative exemplars (dashed)
        for box in neg_exemplar_boxes:
            x1, y1, x2, y2 = box
            rect = Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='#FF0040', facecolor='none', linestyle='--')
            axes[2].add_patch(rect)
        
        # Draw negative predictions (solid boxes)
        if len(neg_pred_boxes) > 0 and neg_pred_count > 0:
            # Convert from normalized coordinates [0,1] to pixel coordinates
            neg_pred_boxes_pixel = neg_pred_boxes.clone().cpu().numpy()
            
            # CountGD outputs are in format [cx, cy, w, h] normalized
            # Convert to [x1, y1, x2, y2] pixel coordinates
            for box in neg_pred_boxes_pixel:
                cx, cy, w, h = box
                # Convert from normalized [cx, cy, w, h] to pixel [x1, y1, x2, y2]
                x1 = (cx - w/2) * img_width
                y1 = (cy - h/2) * img_height
                x2 = (cx + w/2) * img_width
                y2 = (cy + h/2) * img_height
                
                # Clamp to image bounds
                x1 = max(0, min(img_width, x1))
                y1 = max(0, min(img_height, y1))
                x2 = max(0, min(img_width, x2))
                y2 = max(0, min(img_height, y2))
                
                if x2 > x1 and y2 > y1:  # Valid box
                    rect = Rectangle((x1, y1), x2-x1, y2-y1, linewidth=2, edgecolor='#FF0040', facecolor='none')
                    axes[2].add_patch(rect)
        
        plt.tight_layout()
        plt.show()
        
        # Calculate errors
        pos_error = abs(pos_pred_count - len(positive_points))
        neg_error = abs(neg_pred_count - len(negative_points))
        total_error = pos_error + neg_error
        
        print(f"\nEvaluation Results:")
        print(f"  {positive_prompt} - GT: {len(positive_points)}, Pred: {pos_pred_count}, Error: {pos_error}")
        print(f"  {negative_prompt} - GT: {len(negative_points)}, Pred: {neg_pred_count}, Error: {neg_error}")
        print(f"  Total Error: {total_error}")
        print(f"  Confidence Threshold: {confidence_thresh}")
        
        return {
            'positive_gt': len(positive_points),
            'positive_pred': pos_pred_count,
            'positive_error': pos_error,
            'negative_gt': len(negative_points),
            'negative_pred': neg_pred_count,
            'negative_error': neg_error,
            'total_error': total_error
        }
        
    except Exception as e:
        print(f"Error in CountGD evaluation: {e}")
        import traceback
        traceback.print_exc()
        return None

print("CountGD inference functions loaded successfully")

In [None]:
# CountGD Model Evaluation Demo
#
# Runs visualize_countgd_results on three hand-picked intra-category images
# plus one random inter-category image, then prints aggregate and per-image
# absolute count errors. Requires both the CountGD model (model_ready) and the
# dataset (dataset_ready); otherwise it explains what is missing.
if model_ready and dataset_ready:
    print("="*80)
    print("COUNTGD MODEL EVALUATION DEMONSTRATION")
    print("="*80)
    
    # Select a few example images to demonstrate CountGD
    example_images = [
        "FOO_INTRA_CFC1_CFC2_077_077_33f3fb.jpg",  # Color difference example
        "HOU_INTRA_BAT1_BAT2_014_012_578d7d.jpg",  # Size difference example
        "OTR_INTRA_NUT1_NUT2_030_055_f00b70.jpg",  # Texture/shape difference example
    ]
    
    # Also include a random inter-category example
    random_inter = dataset.get_random_image(subset="inter")
    if random_inter:
        example_images.append(random_inter)
    
    print(f"Running CountGD evaluation on {len(example_images)} example images...")
    print("This demonstrates how CountGD performs on PairTally's challenging fine-grained counting tasks.\n")
    
    evaluation_results = []
    
    for i, image_name in enumerate(example_images):
        if image_name in dataset.annotations:
            print(f"\n{'='*60}")
            print(f"EXAMPLE {i+1}/{len(example_images)}: {image_name}")
            print('='*60)
            
            # Get image info
            annotation = dataset.get_annotation(image_name)
            pos_prompt = annotation.get('positive_prompt', 'Class 1')
            neg_prompt = annotation.get('negative_prompt', 'Class 2')
            category_type = "INTER" if image_name in dataset.inter_annotations else "INTRA"
            
            print(f"Category Type: {category_type}-category")
            print(f"Task: Count '{pos_prompt}' vs '{neg_prompt}'")
            print()
            
            # Run CountGD evaluation with visualization
            try:
                result = visualize_countgd_results(
                    dataset, image_name, countgd_model, countgd_transform, 
                    confidence_thresh=0.3
                )
                
                # result is None when the evaluation failed or was skipped
                if result:
                    evaluation_results.append({
                        'image_name': image_name,
                        'category_type': category_type,
                        'positive_class': pos_prompt,
                        'negative_class': neg_prompt,
                        **result
                    })
                    
            except Exception as e:
                print(f"Error evaluating {image_name}: {e}")
                continue
                
        else:
            print(f"Skipping {image_name} - not found in annotations")
    
    # Summary of results
    if evaluation_results:
        print(f"\n{'='*80}")
        print("COUNTGD EVALUATION SUMMARY")
        print('='*80)
        
        total_pos_error = sum(r['positive_error'] for r in evaluation_results)
        total_neg_error = sum(r['negative_error'] for r in evaluation_results)
        total_images = len(evaluation_results)
        
        avg_pos_error = total_pos_error / total_images
        avg_neg_error = total_neg_error / total_images
        # Mean over images of (positive error + negative error) for each image.
        avg_total_error = (total_pos_error + total_neg_error) / total_images
        
        print(f"Results across {total_images} example images:")
        print(f"  Average Positive Class MAE: {avg_pos_error:.2f}")
        print(f"  Average Negative Class MAE: {avg_neg_error:.2f}")
        print(f"  Average Total MAE: {avg_total_error:.2f}")
        print()
        
        print("Per-Image Results:")
        for result in evaluation_results:
            print(f"  {result['image_name'][:30]:<30} | {result['category_type']:<5} | "
                  f"Total Error: {result['total_error']:2d} | "
                  f"{result['positive_class']}: {result['positive_error']:2d} | "
                  f"{result['negative_class']}: {result['negative_error']:2d}")
        
        print(f"\n{'='*80}")
        print("KEY OBSERVATIONS:")
        print('='*80)
        print("• CountGD uses exemplar boxes and text prompts for object counting")
        print("• The model processes each class (positive/negative) separately")
        print("• Performance varies significantly across different attribute types")
        print("• Fine-grained discrimination (e.g., color, size differences) remains challenging")
        print("• PairTally reveals limitations in current state-of-the-art counting models")
        
else:
    print("CountGD model evaluation demo not available")
    print("Reasons:")
    if not model_ready:
        print("  - CountGD model not loaded (missing dependencies or checkpoints)")
    if not dataset_ready:
        print("  - Dataset not ready (missing images or annotations)")
    print()
    print("To enable this demo:")
    print("  1. Install CountGD dependencies")
    print("  2. Download CountGD model checkpoints")
    print("  3. Ensure PairTally dataset images are available")