In [None]:
# analyze_overlaps.ipynb
# (시각화 제거 버전 - 빠르고 간단!)

# ============================================
# 1. 경로 확인 및 설정
# ============================================

import os
import json
import numpy as np
import pandas as pd
import cv2
from tqdm import tqdm
from itertools import combinations
from collections import defaultdict

# 경로 설정
# Dataset root. Defaults to the original cluster path, but can be overridden
# with the PROJECT5_DATA_ROOT environment variable so the notebook also runs
# on other machines / accounts without editing the source.
BASE_ROOT = os.environ.get(
    "PROJECT5_DATA_ROOT",
    "/home/jkim0094/au31_scratch2/jkim0094/project5/data",
)

# Candidate (image dir, label dir) layouts under BASE_ROOT, tried in order;
# the first layout whose two directories both exist is used.
possible_paths = [
    {
        'image': os.path.join(BASE_ROOT, "train/DCM"),
        'label': os.path.join(BASE_ROOT, "train/outputs_json"),
    },
    {
        'image': os.path.join(BASE_ROOT, "ephemeral/home/data/train/DCM"),
        'label': os.path.join(BASE_ROOT, "ephemeral/home/data/train/outputs_json"),
    },
]

IMAGE_ROOT = None
LABEL_ROOT = None

print("🔍 Searching for correct paths...\n")
for i, paths in enumerate(possible_paths, 1):
    print(f"Option {i}: {paths['image']}")
    # isdir rather than exists: a stray *file* at either path must not be
    # mistaken for a dataset directory (os.walk on it would find nothing).
    if os.path.isdir(paths['image']) and os.path.isdir(paths['label']):
        IMAGE_ROOT = paths['image']
        LABEL_ROOT = paths['label']
        print("  ✅ Found!\n")
        break
    print("  ❌ Not found\n")

if IMAGE_ROOT is None:
    print("❌ Could not find data!")
    print("\n🔧 Please check manually:")
    print(f"   ls {BASE_ROOT}")
    print("   (or set PROJECT5_DATA_ROOT to the dataset root)")
    raise FileNotFoundError("Data not found")

print("=" * 60)
print(f"✅ IMAGE_ROOT: {IMAGE_ROOT}")
print(f"✅ LABEL_ROOT: {LABEL_ROOT}")
print("=" * 60)

# ============================================
# 2. 파일 수집
# ============================================

print("\n🔍 Collecting files...")

# Every *.json file anywhere under LABEL_ROOT, stored as a path relative to
# LABEL_ROOT and sorted so the processing order is deterministic.
jsons = sorted(
    os.path.relpath(os.path.join(dirpath, name), LABEL_ROOT)
    for dirpath, _subdirs, filenames in os.walk(LABEL_ROOT)
    for name in filenames
    if name.endswith('.json')
)

print(f'✅ Found {len(jsons)} JSON files')
if not jsons:
    raise FileNotFoundError("No JSON files found!")
print(f'   Example: {jsons[0]}')

# ============================================
# 3. 클래스 정의
# ============================================

# Segmentation target names: 'finger-1' .. 'finger-19' followed by ten named
# bones. A class's position in this list is its mask channel index.
CLASSES = [f'finger-{n}' for n in range(1, 20)] + [
    'Trapezium', 'Trapezoid', 'Capitate', 'Hamate', 'Scaphoid', 'Lunate',
    'Triquetrum', 'Pisiform', 'Radius', 'Ulna',
]

# Reverse lookup: class name -> channel index.
CLASS2IND = dict(zip(CLASSES, range(len(CLASSES))))

print(f'\n📚 Total classes: {len(CLASSES)}')

# ============================================
# 4. 마스크 생성 함수
# ============================================

def create_mask_from_json(json_path, image_size=(2048, 2048)):
    """Rasterise the polygon annotations of one label JSON into per-class masks.

    Args:
        json_path: Path to a label file whose top-level ``annotations`` list
            holds dicts with a ``label`` (class name) and ``points``
            (polygon vertices).
        image_size: ``(height, width)`` of the output masks.
            NOTE(review): the 2048x2048 default assumes every image has that
            size — confirm against the dataset.

    Returns:
        ``np.uint8`` array of shape ``(len(CLASSES), height, width)``; channel
        ``CLASS2IND[label]`` is 1 inside that label's polygons and 0 elsewhere.
        Annotations with a label not in ``CLASS2IND`` or with fewer than three
        points are skipped.
    """
    with open(json_path, 'r') as fp:
        label_data = json.load(fp)

    height, width = image_size[0], image_size[1]
    class_masks = np.zeros((len(CLASSES), height, width), dtype=np.uint8)

    for annotation in label_data.get('annotations', []):
        class_name = annotation.get('label')
        polygon = annotation.get('points', [])

        # A fillable polygon needs a known class and at least three vertices.
        if class_name not in CLASS2IND or len(polygon) <= 2:
            continue

        vertices = np.array(polygon, dtype=np.int32)
        # Fills in place: class_masks[i] is a view into class_masks.
        cv2.fillPoly(class_masks[CLASS2IND[class_name]], [vertices], 1)

    return class_masks

# ============================================
# 5. Overlap 분석 (핵심!)
# ============================================

print("\n" + "="*60)
print("🔍 Starting Overlap Analysis...")
print("="*60)

# (class_a_idx, class_b_idx) with class_a_idx < class_b_idx -> accumulated stats.
# An entry is only created once the pair actually overlaps in some image,
# so num_images >= 1 for every stored pair.
overlap_stats = defaultdict(lambda: {
    'total_pixels': 0,
    'num_images': 0,
    'per_image_pixels': []
})

num_classes = len(CLASSES)

print(f"\nProcessing {len(jsons)} images...")
for json_file in tqdm(jsons, desc="Analyzing"):
    json_path = os.path.join(LABEL_ROOT, json_file)

    # (num_classes, H, W) masks holding 0/1 per class channel.
    masks = create_mask_from_json(json_path)

    # Only pixels labelled by two or more classes can contribute to any
    # pairwise overlap. Restricting the work to them replaces the previous
    # full-image comparison per class pair (up to 406 passes over 4M pixels).
    multi_label = masks.sum(axis=0) >= 2
    if not multi_label.any():
        continue

    # (num_classes, K) 0/1 matrix over the K multi-label pixels. int64 keeps
    # the dot products exact; co[a, b] is the number of pixels where classes
    # a and b are both present (absent classes give all-zero rows).
    shared = masks[:, multi_label].astype(np.int64)
    co = shared @ shared.T

    # combinations() yields cls_a < cls_b, matching the sorted pair keys.
    for cls_a, cls_b in combinations(range(num_classes), 2):
        overlap_pixels = int(co[cls_a, cls_b])
        if overlap_pixels > 0:
            stats = overlap_stats[(cls_a, cls_b)]
            stats['total_pixels'] += overlap_pixels
            stats['num_images'] += 1
            stats['per_image_pixels'].append(overlap_pixels)

# ============================================
# 6. 결과 정리
# ============================================

print("\n" + "="*60)
print("📊 Calculating Statistics...")
print("="*60)

# Fixed output schema. Passing it to the DataFrame constructor keeps the
# columns present even when no pair overlaps, so the sort below (and the
# threshold filters later on) operate on an empty frame instead of raising
# KeyError on a column-less DataFrame.
result_columns = [
    'class_a', 'class_b', 'class_a_idx', 'class_b_idx',
    'total_overlap_pixels', 'num_images_overlap', 'avg_overlap_pixels',
    'max_overlap_pixels', 'min_overlap_pixels',
]

results = []
for (cls_a, cls_b), stats in overlap_stats.items():
    per_image = stats['per_image_pixels']
    # num_images >= 1 here: a pair is only recorded after it has overlapped.
    avg_pixels = stats['total_pixels'] / stats['num_images']

    results.append({
        'class_a': CLASSES[cls_a],
        'class_b': CLASSES[cls_b],
        'class_a_idx': cls_a,
        'class_b_idx': cls_b,
        'total_overlap_pixels': stats['total_pixels'],
        'num_images_overlap': stats['num_images'],
        'avg_overlap_pixels': avg_pixels,
        'max_overlap_pixels': max(per_image),
        'min_overlap_pixels': min(per_image),
    })

df_overlap = pd.DataFrame(results, columns=result_columns)
df_overlap = df_overlap.sort_values('avg_overlap_pixels', ascending=False)

# Number of unordered class pairs, derived from CLASSES instead of a hard-coded 29.
num_classes = len(CLASSES)
print(f"\n✅ Total pairs with overlap: {len(df_overlap)}")
print(f"   (Out of {num_classes * (num_classes - 1) // 2} possible)")

# ============================================
# 7. 결과 출력
# ============================================

print("\n" + "=" * 60)
print("🏆 Top 20 Most Overlapping Pairs")
print("=" * 60)

# df_overlap is already sorted by average overlap, descending.
top_pairs = df_overlap.head(20)
for rec in top_pairs.itertuples(index=False):
    line = (
        f"{rec.class_a:15} ↔ {rec.class_b:15} : "
        f"{rec.avg_overlap_pixels:7.0f} px  "
        f"({rec.num_images_overlap:3} images)"
    )
    print(line)

# ============================================
# 8. 임계값별 필터링
# ============================================

print("\n" + "=" * 60)
print("📏 Filtering by Threshold")
print("=" * 60)

thresholds = [50, 100, 200, 500, 1000]
for min_px in thresholds:
    kept = df_overlap.loc[df_overlap['avg_overlap_pixels'] >= min_px]
    n_kept = len(kept)
    print(f"\n≥ {min_px:4} pixels: {n_kept:2} pairs")

    # Only list the individual pairs when the result is short enough to read.
    if not 0 < n_kept <= 10:
        continue
    for rec in kept.itertuples(index=False):
        print(f"  - {rec.class_a:15} ↔ {rec.class_b:15} : {rec.avg_overlap_pixels:7.0f} px")

# ============================================
# 9. JSON 저장
# ============================================

print("\n" + "=" * 60)
print("💾 Saving Results...")
print("=" * 60)

# Pairs whose average per-image overlap reaches this many pixels are saved.
SAVE_THRESHOLD = 100

filtered_pairs = df_overlap[df_overlap['avg_overlap_pixels'] >= SAVE_THRESHOLD]

# numpy scalars are converted to plain int/float so json can serialise them.
output_data = {
    'metadata': {
        'total_images_analyzed': len(jsons),
        'threshold_pixels': SAVE_THRESHOLD,
        'total_pairs_found': len(filtered_pairs),
        'analysis_date': pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S'),
    },
    'pairs': [
        {
            'classes': [int(rec.class_a_idx), int(rec.class_b_idx)],
            'class_names': [rec.class_a, rec.class_b],
            'avg_overlap_pixels': float(rec.avg_overlap_pixels),
            'total_overlap_pixels': int(rec.total_overlap_pixels),
            'num_images': int(rec.num_images_overlap),
        }
        for rec in filtered_pairs.itertuples(index=False)
    ],
}

# Written relative to the current working directory.
output_file = 'overlap_pairs.json'
with open(output_file, 'w') as fh:
    json.dump(output_data, fh, indent=2)

saved = output_data['pairs']
print(f"\n✅ Saved to: {output_file}")
print(f"   Threshold: {SAVE_THRESHOLD} pixels")
print(f"   Pairs saved: {len(saved)}")

print("\n📋 Saved pairs:")
for entry in saved:
    name_a, name_b = entry['class_names']
    print(f"  {name_a:15} ↔ {name_b:15} : "
          f"{entry['avg_overlap_pixels']:7.0f} px  "
          f"({entry['num_images']} imgs)")

# ============================================
# 10. 완료
# ============================================

print("\n" + "=" * 60)
print("✅ Analysis Complete!")
print("=" * 60)

summary_lines = [
    "\nSummary:",
    f"  📊 Images analyzed: {len(jsons)}",
    f"  🔗 Pairs with overlap: {len(df_overlap)}",
    f"  ✨ Pairs ≥ {SAVE_THRESHOLD}px: {len(filtered_pairs)}",
    f"  💾 Output: {output_file}",
]
for line in summary_lines:
    print(line)

# Follow-up instructions for wiring the result into training.
next_steps = [
    "\n📝 Next steps:",
    "  1. Check 'overlap_pairs.json' file",
    "  2. Update config.py:",
    "     OVERLAP_ANALYSIS_FILE = 'overlap_pairs.json'",
    "  3. Modify util.py to load this file",
    "  4. Run training with Overlap Loss!",
]
for line in next_steps:
    print(line)

print("\n" + "=" * 60)