In [9]:
import pandas as pd
import numpy as np
import json
from data_utils import load_json_files

# Load all JSON files from the data directory
data_root = "data"
json_files = load_json_files(data_root=data_root)

# Maximum number of object names shown in the 'Sample Objects' column
MAX_OBJECTS_SHOWN = 10
JSON_SUFFIX = '.json'


def _unique_object_strings(records):
    """Return the sorted, de-duplicated string values stored under 'object'.

    Entries that are not dicts, that lack an 'object' key, or whose 'object'
    value is not a str are skipped.
    """
    return sorted({
        record['object']
        for record in records
        if isinstance(record, dict) and isinstance(record.get('object'), str)
    })


# Create summary table: one record (dict) per dataset file
summary_data = []

for json_file in json_files:
    dataset_path = f"{data_root}/{json_file}"

    # Decode explicitly as UTF-8: the datasets contain non-ASCII names, and the
    # platform default encoding is locale-dependent.
    with open(dataset_path, 'r', encoding='utf-8') as f:
        dataset = json.load(f)

    # Relation name = path without its trailing '.json' extension only
    # (str.replace would also strip '.json' occurring elsewhere in the path).
    if json_file.endswith(JSON_SUFFIX):
        relation_name = json_file[:-len(JSON_SUFFIX)]
    else:
        relation_name = json_file

    # Standard format: a dict with a "samples" key
    if isinstance(dataset, dict) and 'samples' in dataset:
        samples = dataset['samples']
    # Fallback: the file is a bare list of samples
    elif isinstance(dataset, list):
        samples = dataset
    # Unrecognised layout: report it as an empty dataset
    else:
        samples = []

    num_samples = len(samples)
    unique_objects = _unique_object_strings(samples)

    # Truncate object list if too long for display
    objects_str = ', '.join(unique_objects[:MAX_OBJECTS_SHOWN])
    if len(unique_objects) > MAX_OBJECTS_SHOWN:
        objects_str += f'... (+{len(unique_objects) - MAX_OBJECTS_SHOWN} more)'

    summary_data.append({
        'Relation': relation_name,
        'File Path': json_file,
        'Num Samples': num_samples,
        'Num Unique Objects': len(unique_objects),
        'Sample Objects': objects_str
    })

# Create DataFrame. The columns are listed explicitly (same keys and order as
# the records in `summary_data`) so that an empty `summary_data` still yields
# a frame with the expected columns; a column-less frame would make the
# sort_values call below raise KeyError.
summary_columns = [
    'Relation',
    'File Path',
    'Num Samples',
    'Num Unique Objects',
    'Sample Objects',
]
summary_df = pd.DataFrame(summary_data, columns=summary_columns)

# Sort by Num Samples in descending order
summary_df = summary_df.sort_values('Num Samples', ascending=False)

# Display summary table (plain-text rendering of every row)
print(f"\n{'='*100}")
print("DATASET SUMMARY TABLE")
print(f"{'='*100}\n")
print(f"Total datasets: {len(json_files)}")
print(f"Total samples across all datasets: {summary_df['Num Samples'].sum()}")
print(f"\n{summary_df.to_string(index=False)}")
print(f"\n{'='*100}")

# Also save to CSV for easy viewing; the relative path resolves against the
# current working directory and an existing file is overwritten.
summary_df.to_csv('dataset_summary.csv', index=False)
print("\nSummary saved to: dataset_summary.csv")

# Filter to datasets with exactly 2 unique objects
binary_df = summary_df[summary_df['Num Unique Objects'] == 2]

print(f"\n{'='*100}")
print("BINARY CLASSIFICATION DATASETS (2 unique objects)")
print(f"{'='*100}\n")
print(f"Total binary datasets: {len(binary_df)}")
print(f"Total samples in binary datasets: {binary_df['Num Samples'].sum()}")
print(f"\n{binary_df.to_string(index=False)}")
print(f"\n{'='*100}")

Found 47 relation files:

DATASET SUMMARY TABLE

Total datasets: 47
Total samples across all datasets: 11089

                              Relation                                   File Path  Num Samples  Num Unique Objects                                                                                                                                                                                                                                                                                         Sample Objects
                 factual/person_mother                  factual/person_mother.json          994                 962                                                                                     Abby May, Abiah Folger, Abigail Adams, Abigail Erick, Adelaide Antici Leopardi, Adelaide Suzanne de Sellon, Adeline Daudet, Adeline Maria de l'Étang, Adriaentje van Geertenryck, Adèle Schillinger... (+952 more)
                 factual/person_father                  factual/pers