# Multi-View Inference Notebook

**Description:** This notebook runs inference for a multi-view model (8 crops per image).

**Steps:**
1. Load the config from the training run
2. Initialize the model (extract the detector from MultiViewSoftTeacher)
3. Load the checkpoint weights
4. Prepare the test data
5. Run inference on each group of 8 crops
6. Evaluate the results

## Step 1: Import libraries and set up the environment

In [None]:
import sys
import os
import torch
import numpy as np
from pathlib import Path

# Add mmdetection to path
sys.path.insert(0, '/home/coder/data/trong/KLTN/Soft_Teacher/mmdetection')

# Import mmengine and mmdet
from mmengine.config import Config
from mmengine.runner import Runner
from mmdet.apis import init_detector, inference_detector

# Check GPU
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"Current GPU: {torch.cuda.current_device()}")
    print(f"GPU name: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## Step 2: Load the config and define paths

Training results from the log:
- **Teacher bbox_mAP_50**: 0.1290 (12.90%)
- **Teacher bbox_mAP**: 0.0410 (4.10%)

In [None]:
# Paths
config_file = '/home/coder/data/trong/KLTN/Soft_Teacher/work_dirs/soft_teacher_custom_multi_view/20251206_182403/vis_data/config.py'
checkpoint_file = '/home/coder/data/trong/KLTN/Soft_Teacher/work_dirs/soft_teacher_custom_multi_view/best_teacher_coco_bbox_mAP_50_epoch_0.pth'

# Check files exist
print(f"Config exists: {os.path.exists(config_file)}")
print(f"Checkpoint exists: {os.path.exists(checkpoint_file)}")

# Load config
cfg = Config.fromfile(config_file)
print(f"\nModel type: {cfg.model.type}")
print(f"Backbone fusion: {cfg.model.detector.backbone.fusion}")
print(f"MVViT spatial_attention: {cfg.model.detector.backbone.mvvit.spatial_attention}")
print(f"Views per sample: {cfg.views_per_sample}")

## Step 3: Reduce the batch size

In [None]:
# Override batch size to 1 (per-crop evaluation)
cfg.test_dataloader.batch_size = 1
cfg.val_dataloader.batch_size = 1

# Reduce num_workers to save memory
cfg.test_dataloader.num_workers = 1
cfg.val_dataloader.num_workers = 1

print(f"Test batch size: {cfg.test_dataloader.batch_size}")
print(f"Val batch size: {cfg.val_dataloader.batch_size}")

## Step 4: Initialize the Runner and load the checkpoint

In [None]:
# Set checkpoint path
cfg.load_from = checkpoint_file

# Set work dir for evaluation results
cfg.work_dir = '/home/coder/data/trong/KLTN/Soft_Teacher/work_dirs/eval_checkpoint_notebook'
os.makedirs(cfg.work_dir, exist_ok=True)

# Build runner
print("Building runner...")
runner = Runner.from_cfg(cfg)
print(f"✅ Runner created successfully!")
print(f"Model type: {type(runner.model).__name__}")

## Step 5: Run evaluation

**Note:** This cell can take a while (~2-3 minutes) and uses a lot of memory. If it runs out of memory (OOM), reduce K in `multi_view_transformer.py`.

In [None]:
print("="*80)
print("Starting evaluation...")
print(f"Checkpoint: {checkpoint_file}")
print(f"Config: {config_file}")
print("="*80)

# Run test
metrics = runner.test()

print("\n" + "="*80)
print("✅ Evaluation completed!")
print("="*80)

## Step 6: Display detailed results

In [None]:
# Extract teacher results
print("\n" + "="*80)
print("TEACHER MODEL RESULTS:")
print("="*80)

teacher_keys = sorted([k for k in metrics.keys() if k.startswith('teacher/')])
for key in teacher_keys:
    metric_name = key.replace('teacher/coco/', '')
    value = metrics[key]
    if isinstance(value, (int, float)) and value >= 0:
        print(f"  {metric_name:30s}: {value:.4f} ({value*100:.2f}%)")
    else:
        print(f"  {metric_name:30s}: {value}")

# Extract student results
print("\n" + "="*80)
print("STUDENT MODEL RESULTS:")
print("="*80)

student_keys = sorted([k for k in metrics.keys() if k.startswith('student/')])
for key in student_keys:
    metric_name = key.replace('student/coco/', '')
    value = metrics[key]
    if isinstance(value, (int, float)) and value >= 0:
        print(f"  {metric_name:30s}: {value:.4f} ({value*100:.2f}%)")
    else:
        print(f"  {metric_name:30s}: {value}")
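
The two loops above differ only in the prefix they filter on. As a sketch, a small helper can group every metric by branch in one pass. It assumes keys follow the `branch/evaluator/name` pattern seen in this notebook (for example `teacher/coco/bbox_mAP`); `group_metrics` is a hypothetical name, not part of mmengine or mmdet, and the values in the example are made up.

```python
from collections import defaultdict


def group_metrics(metrics):
    """Group flat 'branch/evaluator/name' keys into {branch: {name: value}}."""
    grouped = defaultdict(dict)
    for key, value in metrics.items():
        parts = key.split("/", 2)
        if len(parts) == 3:
            branch, _evaluator, name = parts
            grouped[branch][name] = value
        else:
            # Keys that do not match the assumed pattern are kept under "other".
            grouped["other"][key] = value
    return dict(grouped)


# Made-up values, only to show the shape of the result.
example = {
    "teacher/coco/bbox_mAP": 0.25,
    "student/coco/bbox_mAP": 0.125,
    "time": 1.5,
}
for branch, values in sorted(group_metrics(example).items()):
    print(branch, values)
```

With the real `metrics` dict, `group_metrics(metrics).get("teacher", {})` would hold the same entries the teacher loop prints.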

## Step 7: Compare with training results

In [None]:
# Expected results from training log
expected_teacher_map50 = 0.1290
expected_teacher_map = 0.0410

# Actual results from re-evaluation
actual_teacher_map50 = metrics.get('teacher/coco/bbox_mAP_50', -1)
actual_teacher_map = metrics.get('teacher/coco/bbox_mAP', -1)

print("\n" + "="*80)
print("COMPARISON WITH TRAINING RESULTS:")
print("="*80)

print("\n📊 Teacher bbox_mAP_50:")
print(f"  Training:   {expected_teacher_map50:.4f} ({expected_teacher_map50*100:.2f}%)")
print(f"  Re-eval:    {actual_teacher_map50:.4f} ({actual_teacher_map50*100:.2f}%)")
diff_map50 = abs(actual_teacher_map50 - expected_teacher_map50)
print(f"  Difference: {diff_map50:.4f} ({diff_map50*100:.2f}%)")
match_map50 = diff_map50 < 0.001
print(f"  Match:      {'✅ YES' if match_map50 else '❌ NO'}")

print("\n📊 Teacher bbox_mAP:")
print(f"  Training:   {expected_teacher_map:.4f} ({expected_teacher_map*100:.2f}%)")
print(f"  Re-eval:    {actual_teacher_map:.4f} ({actual_teacher_map*100:.2f}%)")
diff_map = abs(actual_teacher_map - expected_teacher_map)
print(f"  Difference: {diff_map:.4f} ({diff_map*100:.2f}%)")
match_map = diff_map < 0.001
print(f"  Match:      {'✅ YES' if match_map else '❌ NO'}")

print("\n" + "="*80)
if match_map50 and match_map:
    print("✅ CONCLUSION: Re-eval results MATCH training! The checkpoint is correct.")
else:
    print("⚠️ CONCLUSION: The results differ. Re-check the config or data.")
print("="*80)
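
The match test in this step is an absolute-tolerance comparison, written out once per metric. A minimal sketch of the same idea as a reusable function follows; `metrics_match` is a hypothetical helper name, and the explicit guard for the `-1` placeholder (used above when a key is missing) is an added assumption about how a missing metric should be treated.

```python
import math


def metrics_match(expected: float, actual: float, tol: float = 1e-3) -> bool:
    """Return True when two metric values agree within an absolute tolerance."""
    # The notebook uses -1 as the placeholder for a missing metric;
    # a missing value is never reported as a match.
    if expected < 0 or actual < 0:
        return False
    return math.isclose(expected, actual, rel_tol=0.0, abs_tol=tol)


# 0.1290 is the training-log value quoted in Step 2; the second argument
# in each call is a made-up re-evaluation result.
print(metrics_match(0.1290, 0.1294))
print(metrics_match(0.1290, 0.1350))
print(metrics_match(0.1290, -1))
```

Note that `math.isclose` accepts a difference exactly equal to `tol`, whereas the cell above uses a strict `<`; for logged metrics rounded to a few decimals the distinction rarely matters.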

## Step 8: Evaluate on the Test Set (Bright Images)

Next, evaluate on the test set to compare performance against validation.
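
Before pointing the dataloader at a different annotation file, it can help to sanity-check that file. The sketch below counts images, annotations, and boxes per class in a COCO-style dict, and flags annotations whose `image_id` has no matching image. `summarize_coco` is a hypothetical helper, and the toy dict is hand-made; in this notebook one would pass the result of `json.load` on `test_ann_file`.

```python
from collections import Counter


def summarize_coco(coco: dict) -> dict:
    """Count images, annotations, and boxes per category name in a COCO dict."""
    id_to_name = {c["id"]: c["name"] for c in coco.get("categories", [])}
    annotations = coco.get("annotations", [])
    image_ids = {img["id"] for img in coco.get("images", [])}
    per_class = Counter(
        id_to_name.get(a["category_id"], f"unknown:{a['category_id']}")
        for a in annotations
    )
    # Annotations pointing at an image id that is not listed in "images".
    orphans = sum(1 for a in annotations if a["image_id"] not in image_ids)
    return {
        "num_images": len(image_ids),
        "num_annotations": len(annotations),
        "per_class": dict(per_class),
        "orphan_annotations": orphans,
    }


# Hand-made toy data, only to illustrate the output shape.
toy = {
    "images": [{"id": 1}, {"id": 2}],
    "categories": [{"id": 1, "name": "Broken"}, {"id": 2, "name": "Chipped"}],
    "annotations": [
        {"image_id": 1, "category_id": 1},
        {"image_id": 1, "category_id": 2},
        {"image_id": 3, "category_id": 2},  # refers to a missing image
    ],
}
print(summarize_coco(toy))
```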

In [None]:
# Configure test set
test_ann_file = '/home/coder/data/trong/KLTN/Soft_Teacher/data_drill/anno_test/_annotations_filtered.bright.coco.json'
test_data_prefix = 'test/'

# Check if test annotation exists
print(f"Test annotation exists: {os.path.exists(test_ann_file)}")

# Update test dataloader config
cfg.test_dataloader.dataset.ann_file = test_ann_file
cfg.test_dataloader.dataset.data_prefix.img = test_data_prefix
cfg.test_dataloader.batch_size = 1
cfg.test_dataloader.num_workers = 1

# Update test evaluator
cfg.test_evaluator.ann_file = test_ann_file

print(f"Test annotation: {test_ann_file}")
print(f"Test data prefix: {test_data_prefix}")
print(f"Batch size: {cfg.test_dataloader.batch_size}")

In [None]:
# Rebuild runner with test set config
cfg.work_dir = '/home/coder/data/trong/KLTN/Soft_Teacher/work_dirs/eval_test_set'
os.makedirs(cfg.work_dir, exist_ok=True)

print("Rebuilding runner for test set...")
runner_test = Runner.from_cfg(cfg)
print(f"✅ Test runner created!")

# Run evaluation on test set
print("\n" + "="*80)
print("Evaluating on TEST SET...")
print("="*80)

test_metrics = runner_test.test()

print("\n✅ Test evaluation completed!")
print("="*80)

## Step 9: Compare Validation vs Test Performance

In [None]:
import pandas as pd

# Extract metrics for comparison
def extract_metrics(metrics_dict, prefix='teacher'):
    results = {}
    for key, value in metrics_dict.items():
        if key.startswith(f'{prefix}/coco/'):
            metric_name = key.replace(f'{prefix}/coco/', '')
            if isinstance(value, (int, float)) and value >= 0:
                results[metric_name] = value
    return results

# Get validation metrics (from previous eval)
val_teacher = extract_metrics(metrics, 'teacher')
val_student = extract_metrics(metrics, 'student')

# Get test metrics
test_teacher = extract_metrics(test_metrics, 'teacher')
test_student = extract_metrics(test_metrics, 'student')

# Create comparison dataframe for Teacher
teacher_comparison = pd.DataFrame({
    'Validation': val_teacher,
    'Test': test_teacher
})

print("\n" + "="*80)
print("TEACHER MODEL: Validation vs Test Comparison")
print("="*80)
print(teacher_comparison.to_string())

# Calculate difference
teacher_comparison['Diff'] = teacher_comparison['Test'] - teacher_comparison['Validation']
teacher_comparison['Diff%'] = (teacher_comparison['Diff'] / teacher_comparison['Validation'] * 100).round(2)

print("\n📊 Key Metrics Comparison (Teacher):")
for metric in ['bbox_mAP', 'bbox_mAP_50', 'bbox_mAP_75']:
    if metric in teacher_comparison.index:
        val_val = teacher_comparison.loc[metric, 'Validation']
        test_val = teacher_comparison.loc[metric, 'Test']
        diff = teacher_comparison.loc[metric, 'Diff']
        print(f"\n{metric}:")
        print(f"  Validation: {val_val:.4f} ({val_val*100:.2f}%)")
        print(f"  Test:       {test_val:.4f} ({test_val*100:.2f}%)")
        print(f"  Difference: {diff:+.4f} ({diff*100:+.2f}%)")

## Step 10: Visualize Per-Class Performance

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (14, 10)

# Class names
classes = ['Broken', 'Chipped', 'Scratched', 'Severe_Rust', 'Tip_Wear']

# Extract per-class precision for Teacher
val_class_metrics = {cls: val_teacher.get(f'{cls}_precision', 0) for cls in classes}
test_class_metrics = {cls: test_teacher.get(f'{cls}_precision', 0) for cls in classes}

# Create figure with 2 subplots
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# 1. Bar chart: Validation vs Test per-class
ax1 = axes[0, 0]
x = np.arange(len(classes))
width = 0.35
bars1 = ax1.bar(x - width/2, list(val_class_metrics.values()), width, label='Validation', alpha=0.8, color='steelblue')
bars2 = ax1.bar(x + width/2, list(test_class_metrics.values()), width, label='Test', alpha=0.8, color='coral')
ax1.set_xlabel('Class', fontsize=12, fontweight='bold')
ax1.set_ylabel('Precision (AP)', fontsize=12, fontweight='bold')
ax1.set_title('Teacher Model: Per-Class Precision (Validation vs Test)', fontsize=14, fontweight='bold')
ax1.set_xticks(x)
ax1.set_xticklabels(classes, rotation=45, ha='right')
ax1.legend(fontsize=11)
ax1.grid(axis='y', alpha=0.3)

# Add value labels on bars
for bars in [bars1, bars2]:
    for bar in bars:
        height = bar.get_height()
        if height > 0.001:
            ax1.text(bar.get_x() + bar.get_width()/2., height,
                    f'{height*100:.1f}%', ha='center', va='bottom', fontsize=9)

# 2. Overall mAP comparison
ax2 = axes[0, 1]
metrics_to_plot = ['bbox_mAP', 'bbox_mAP_50', 'bbox_mAP_75']
val_map_values = [val_teacher.get(m, 0) for m in metrics_to_plot]
test_map_values = [test_teacher.get(m, 0) for m in metrics_to_plot]

x_map = np.arange(len(metrics_to_plot))
bars3 = ax2.bar(x_map - width/2, val_map_values, width, label='Validation', alpha=0.8, color='steelblue')
bars4 = ax2.bar(x_map + width/2, test_map_values, width, label='Test', alpha=0.8, color='coral')
ax2.set_xlabel('Metric', fontsize=12, fontweight='bold')
ax2.set_ylabel('Score', fontsize=12, fontweight='bold')
ax2.set_title('Teacher Model: Overall mAP Metrics', fontsize=14, fontweight='bold')
ax2.set_xticks(x_map)
ax2.set_xticklabels(['mAP', 'mAP@50', 'mAP@75'], rotation=0)
ax2.legend(fontsize=11)
ax2.grid(axis='y', alpha=0.3)

for bars in [bars3, bars4]:
    for bar in bars:
        height = bar.get_height()
        if height > 0.001:
            ax2.text(bar.get_x() + bar.get_width()/2., height,
                    f'{height*100:.1f}%', ha='center', va='bottom', fontsize=9)

# 3. Difference heatmap
ax3 = axes[1, 0]
differences = [test_class_metrics[cls] - val_class_metrics[cls] for cls in classes]
diff_df = pd.DataFrame({
    'Class': classes,
    'Difference (Test - Val)': [d*100 for d in differences]  # Convert to percentage
})
colors = ['red' if x < 0 else 'green' for x in differences]
bars5 = ax3.barh(classes, [d*100 for d in differences], color=colors, alpha=0.7)
ax3.set_xlabel('Difference (%)', fontsize=12, fontweight='bold')
ax3.set_ylabel('Class', fontsize=12, fontweight='bold')
ax3.set_title('Teacher Model: Performance Difference (Test - Validation)', fontsize=14, fontweight='bold')
ax3.axvline(x=0, color='black', linestyle='--', linewidth=1)
ax3.grid(axis='x', alpha=0.3)

# Add value labels
for i, (bar, diff) in enumerate(zip(bars5, differences)):
    width_val = bar.get_width()
    label_x = width_val + (0.5 if width_val > 0 else -0.5)
    ax3.text(label_x, bar.get_y() + bar.get_height()/2, 
            f'{diff*100:+.1f}%', ha='left' if width_val > 0 else 'right', 
            va='center', fontsize=10, fontweight='bold')

# 4. Summary statistics table
ax4 = axes[1, 1]
ax4.axis('off')

summary_data = [
    ['Metric', 'Validation', 'Test', 'Diff'],
    ['─'*20, '─'*12, '─'*12, '─'*12],
    ['mAP', f"{val_teacher.get('bbox_mAP', 0)*100:.2f}%", 
     f"{test_teacher.get('bbox_mAP', 0)*100:.2f}%",
     f"{(test_teacher.get('bbox_mAP', 0) - val_teacher.get('bbox_mAP', 0))*100:+.2f}%"],
    ['mAP@50', f"{val_teacher.get('bbox_mAP_50', 0)*100:.2f}%", 
     f"{test_teacher.get('bbox_mAP_50', 0)*100:.2f}%",
     f"{(test_teacher.get('bbox_mAP_50', 0) - val_teacher.get('bbox_mAP_50', 0))*100:+.2f}%"],
    ['mAP@75', f"{val_teacher.get('bbox_mAP_75', 0)*100:.2f}%", 
     f"{test_teacher.get('bbox_mAP_75', 0)*100:.2f}%",
     f"{(test_teacher.get('bbox_mAP_75', 0) - val_teacher.get('bbox_mAP_75', 0))*100:+.2f}%"],
    ['', '', '', ''],
    ['Per-Class (Precision):', '', '', ''],
]

for cls in classes:
    val_p = val_class_metrics.get(cls, 0)
    test_p = test_class_metrics.get(cls, 0)
    diff_p = test_p - val_p
    summary_data.append([
        f'  {cls}',
        f'{val_p*100:.2f}%',
        f'{test_p*100:.2f}%',
        f'{diff_p*100:+.2f}%'
    ])

table = ax4.table(cellText=summary_data, cellLoc='left', loc='center',
                 colWidths=[0.35, 0.2, 0.2, 0.2])
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1, 2.5)

# Style header row
for i in range(4):
    table[(0, i)].set_facecolor('#4472C4')
    table[(0, i)].set_text_props(weight='bold', color='white')

ax4.set_title('Teacher Model: Summary Statistics', fontsize=14, fontweight='bold', pad=20)

plt.tight_layout()
plt.savefig(f'{cfg.work_dir}/evaluation_comparison.png', dpi=300, bbox_inches='tight')
print(f"✅ Visualization saved to: {cfg.work_dir}/evaluation_comparison.png")
plt.show()

## Step 11: Save results to files

Save all metrics to CSV and JSON for later analysis.
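
Metric dicts sometimes contain values that the `json` module cannot encode directly, such as array-library scalars or paths. The sketch below shows one way to supply a `default=` fallback that unwraps array-like values and stringifies everything else. `json_default` and `FakeScalar` are hypothetical names; `FakeScalar` only stands in for an object exposing `.item()`, which is how NumPy scalars behave.

```python
import json
from pathlib import Path


def json_default(obj):
    """Fallback for json.dump: unwrap array-like values, stringify the rest."""
    if hasattr(obj, "tolist"):
        return obj.tolist()  # arrays and array-library scalars
    if hasattr(obj, "item"):
        return obj.item()  # scalar wrappers without tolist()
    return str(obj)


class FakeScalar:
    """Stand-in for a scalar wrapper; used only for this demonstration."""

    def __init__(self, value):
        self._value = value

    def item(self):
        return self._value


payload = {"bbox_mAP": FakeScalar(0.5), "checkpoint": Path("model.pth")}
print(json.dumps(payload, default=json_default, indent=2))
```

Compared with `default=str`, this keeps numeric values as JSON numbers instead of quoted strings, which makes the saved file easier to load back for analysis.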

In [None]:
import json
from datetime import datetime

# Create results directory
results_dir = f'{cfg.work_dir}/results'
os.makedirs(results_dir, exist_ok=True)

# 1. Save metrics to JSON
results = {
    'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    'checkpoint': checkpoint_file,
    'validation': {
        'teacher': val_teacher,
        'student': val_student
    },
    'test': {
        'teacher': test_teacher,
        'student': test_student
    }
}

json_file = f'{results_dir}/metrics_comparison.json'
with open(json_file, 'w') as f:
    json.dump(results, f, indent=2)
print(f"✅ Metrics saved to: {json_file}")

# 2. Save to CSV
comparison_df = pd.DataFrame({
    'Class': classes + ['Overall mAP', 'Overall mAP@50', 'Overall mAP@75'],
    'Val_Precision': list(val_class_metrics.values()) + [
        val_teacher.get('bbox_mAP', 0),
        val_teacher.get('bbox_mAP_50', 0),
        val_teacher.get('bbox_mAP_75', 0)
    ],
    'Test_Precision': list(test_class_metrics.values()) + [
        test_teacher.get('bbox_mAP', 0),
        test_teacher.get('bbox_mAP_50', 0),
        test_teacher.get('bbox_mAP_75', 0)
    ]
})
comparison_df['Difference'] = comparison_df['Test_Precision'] - comparison_df['Val_Precision']
comparison_df['Difference_%'] = (comparison_df['Difference'] / comparison_df['Val_Precision'] * 100).round(2)

csv_file = f'{results_dir}/metrics_comparison.csv'
comparison_df.to_csv(csv_file, index=False)
print(f"✅ CSV saved to: {csv_file}")

# 3. Save full metrics to separate files
val_metrics_file = f'{results_dir}/validation_metrics.json'
with open(val_metrics_file, 'w') as f:
    json.dump(metrics, f, indent=2, default=str)
print(f"✅ Validation metrics saved to: {val_metrics_file}")

test_metrics_file = f'{results_dir}/test_metrics.json'
with open(test_metrics_file, 'w') as f:
    json.dump(test_metrics, f, indent=2, default=str)
print(f"✅ Test metrics saved to: {test_metrics_file}")

print("\n" + "="*80)
print("📁 ALL RESULTS SAVED TO:")
print(f"  Directory: {results_dir}")
print(f"  - metrics_comparison.json")
print(f"  - metrics_comparison.csv")
print(f"  - validation_metrics.json")
print(f"  - test_metrics.json")
print(f"  - ../evaluation_comparison.png")
print("="*80)

## Step 12: Summary

Print the final performance summary.
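
Part of this summary is a generalization check: the mean per-class gap between test and validation is compared against a 2-point band. A minimal sketch of that rule as two small functions follows; `mean_gap` and `classify_gap` are hypothetical names, the scores in the example are made up, and the 0.02 band is the threshold used in this notebook.

```python
from statistics import fmean


def mean_gap(val_scores, test_scores, classes):
    """Mean of (test - validation) over the given classes; missing scores count as 0."""
    return fmean([test_scores.get(c, 0.0) - val_scores.get(c, 0.0) for c in classes])


def classify_gap(avg_diff, tol=0.02):
    """Label the gap as 'similar', 'test_better', or 'test_worse'."""
    if abs(avg_diff) < tol:
        return "similar"
    return "test_better" if avg_diff > 0 else "test_worse"


# Made-up per-class scores for illustration.
val = {"Broken": 0.5, "Chipped": 0.25}
test = {"Broken": 0.75, "Chipped": 0.5}
gap = mean_gap(val, test, ["Broken", "Chipped"])
print(gap, classify_gap(gap))
```

Branching on the sign once the gap is outside the band means a value exactly on the boundary is still labeled by its direction.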

In [None]:
print("\n" + "="*80)
print("🎯 FINAL EVALUATION SUMMARY")
print("="*80)

print("\n📊 TEACHER MODEL PERFORMANCE:")
print(f"\n  {'Metric':<20} {'Validation':>12} {'Test':>12} {'Difference':>12}")
print(f"  {'-'*20} {'-'*12} {'-'*12} {'-'*12}")

key_metrics = [
    ('mAP', 'bbox_mAP'),
    ('mAP@50', 'bbox_mAP_50'),
    ('mAP@75', 'bbox_mAP_75'),
]

for name, key in key_metrics:
    val_v = val_teacher.get(key, 0)
    test_v = test_teacher.get(key, 0)
    diff = test_v - val_v
    print(f"  {name:<20} {val_v*100:>11.2f}% {test_v*100:>11.2f}% {diff*100:>+11.2f}%")

print(f"\n📌 PER-CLASS PRECISION:")
print(f"\n  {'Class':<15} {'Validation':>12} {'Test':>12} {'Difference':>12}")
print(f"  {'-'*15} {'-'*12} {'-'*12} {'-'*12}")

for cls in classes:
    val_p = val_class_metrics.get(cls, 0)
    test_p = test_class_metrics.get(cls, 0)
    diff = test_p - val_p
    status = "✅" if diff >= 0 else "⚠️"
    print(f"  {cls:<15} {val_p*100:>11.2f}% {test_p*100:>11.2f}% {diff*100:>+11.2f}% {status}")

print("\n" + "="*80)
print("✅ EVALUATION COMPLETED SUCCESSFULLY!")
print("="*80)

# Identify best and worst performing classes
best_class = max(test_class_metrics, key=test_class_metrics.get)
worst_class = min(test_class_metrics, key=test_class_metrics.get)

print(f"\n💡 INSIGHTS:")
print(f"  • Best performing class on test:  {best_class} ({test_class_metrics[best_class]*100:.2f}%)")
print(f"  • Worst performing class on test: {worst_class} ({test_class_metrics[worst_class]*100:.2f}%)")

# Check generalization
avg_diff = np.mean([test_class_metrics[c] - val_class_metrics[c] for c in classes])
if abs(avg_diff) < 0.02:
    print(f"  • Model generalizes well! (avg diff: {avg_diff*100:+.2f}%)")
elif avg_diff > 0:  # gap of at least 2 points in favor of test
    print(f"  • Test performance better than validation (avg diff: {avg_diff*100:+.2f}%)")
else:
    print(f"  • Test performance worse than validation (avg diff: {avg_diff*100:+.2f}%)")

print("\n" + "="*80)

## Step 13: Visualize the 8 Views of Base Images

Display the 8 crops (views) of each base image with their ground-truth boxes; predictions are added in Step 14.
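
Grouping crops into base images relies on the `_crop_<n>` marker in each file name. The sketch below isolates that parsing with a regular expression and sorts views numerically, so that crop 10 comes after crop 2. `parse_view` and `group_views` are hypothetical helpers, and the file names in the example are made up to resemble the pattern used in this dataset.

```python
import re
from collections import defaultdict

# Non-greedy base so the first "_crop_<digits>" marker is the split point.
_CROP_RE = re.compile(r"^(?P<base>.+?)_crop_(?P<num>\d+)")


def parse_view(filename):
    """Return (base_name, crop_number) for a crop file name."""
    match = _CROP_RE.match(filename)
    if match is None:
        # Fallback for names without a crop marker: first "_" token, crop 0.
        return filename.split("_")[0], 0
    return match.group("base"), int(match.group("num"))


def group_views(filenames):
    """Map each base name to its file names, ordered by crop number."""
    groups = defaultdict(list)
    for name in filenames:
        base, crop = parse_view(name)
        groups[base].append((crop, name))
    return {base: [n for _, n in sorted(views)] for base, views in groups.items()}


# Made-up file names following the assumed pattern.
names = [
    "A_bright_1_crop_10_jpg.rf.x.jpg",
    "A_bright_1_crop_2_jpg.rf.y.jpg",
    "B.jpg",
]
for base, views in group_views(names).items():
    print(base, views)
```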

In [None]:
import cv2
from collections import defaultdict
import matplotlib.patches as mpatches

# Load COCO annotation to group images
import json
test_ann_file = '/home/coder/data/trong/KLTN/Soft_Teacher/data_drill/anno_test/_annotations_filtered.bright.coco.json'
with open(test_ann_file, 'r') as f:
    coco_data = json.load(f)

# Group images by base_img_id
image_groups = defaultdict(list)
for img in coco_data['images']:
    # Extract base_img_id from filename
    # Format: S245_Image__2025-11-11__12-09-08_bright_2_crop_5_jpg.rf.xxx.jpg
    filename = img['file_name']
    
    # Try to extract base name (before _crop_)
    if '_crop_' in filename:
        base_name = filename.split('_crop_')[0]  # e.g., S245_...bright_2
    else:
        # Fallback: use first part
        base_name = filename.split('_')[0]
    
    image_groups[base_name].append({
        'id': img['id'],
        'file_name': img['file_name'],
        'width': img['width'],
        'height': img['height']
    })

# Get annotations mapping
annotations_by_image = defaultdict(list)
for ann in coco_data['annotations']:
    annotations_by_image[ann['image_id']].append(ann)

print(f"Found {len(image_groups)} base images")
print(f"Total images: {sum(len(v) for v in image_groups.values())}")

# Show some examples
for i, (base_name, imgs) in enumerate(list(image_groups.items())[:3]):
    print(f"\nBase {i+1}: {base_name} → {len(imgs)} views")
    for img in imgs[:2]:
        print(f"  - {img['file_name']}")

In [None]:
# Function to run inference and visualize
from mmdet.apis import init_detector, inference_detector
from mmdet.structures import DetDataSample
import torch

def visualize_8_views(base_name, image_list, data_prefix, score_threshold=0.3):
    """Visualize up to 8 crops of a base image with their ground-truth boxes."""
    
    # Sort by crop number
    def get_crop_num(filename):
        if '_crop_' in filename:
            try:
                crop_part = filename.split('_crop_')[1]
                crop_num = int(crop_part.split('_')[0])
                return crop_num
            except (IndexError, ValueError):
                # Unexpected file name layout: treat as crop 0
                return 0
        return 0
    
    image_list_sorted = sorted(image_list, key=lambda x: get_crop_num(x['file_name']))
    
    # Create figure for 8 views
    fig, axes = plt.subplots(2, 4, figsize=(20, 10))
    axes = axes.flatten()
    
    # Class names and colors
    class_names = ['Broken', 'Chipped', 'Scratched', 'Severe_Rust', 'Tip_Wear']
    class_colors = [
        (134/255, 34/255, 255/255),   # Broken - Purple
        (0/255, 255/255, 206/255),     # Chipped - Cyan
        (255/255, 128/255, 0/255),     # Scratched - Orange
        (254/255, 0/255, 86/255),      # Severe_Rust - Red
        (199/255, 252/255, 0/255)      # Tip_Wear - Yellow
    ]
    
    for idx, img_info in enumerate(image_list_sorted[:8]):
        ax = axes[idx]
        
        # Load image
        img_path = os.path.join('/home/coder/data/trong/KLTN/Soft_Teacher/data_drill', 
                                data_prefix, img_info['file_name'])
        
        if not os.path.exists(img_path):
            ax.text(0.5, 0.5, 'Image not found', ha='center', va='center')
            ax.axis('off')
            continue
        
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        ax.imshow(img)
        
        # Get ground truth boxes
        gt_anns = annotations_by_image.get(img_info['id'], [])
        
        # Draw ground truth boxes (thick, solid)
        for ann in gt_anns:
            bbox = ann['bbox']  # [x, y, w, h]
            category_id = ann['category_id'] - 1  # COCO is 1-indexed
            
            if 0 <= category_id < len(class_names):
                color = class_colors[category_id]
                rect = mpatches.Rectangle(
                    (bbox[0], bbox[1]), bbox[2], bbox[3],
                    linewidth=3, edgecolor=color, facecolor='none',
                    linestyle='-', label=f'GT: {class_names[category_id]}'
                )
                ax.add_patch(rect)
                
                # Add GT label
                ax.text(bbox[0], bbox[1] - 5, f'GT: {class_names[category_id]}',
                       fontsize=8, color='white', weight='bold',
                       bbox=dict(boxstyle='round,pad=0.3', facecolor=color, alpha=0.8))
        
        # Extract crop number
        crop_num = get_crop_num(img_info['file_name'])
        ax.set_title(f'View {idx+1} (Crop {crop_num})\n{len(gt_anns)} GT boxes', 
                    fontsize=10, weight='bold')
        ax.axis('off')
    
    # Hide extra subplots if less than 8
    for idx in range(len(image_list_sorted), 8):
        axes[idx].axis('off')
    
    # Add legend
    legend_elements = [
        mpatches.Patch(color=class_colors[i], label=class_names[i]) 
        for i in range(len(class_names))
    ]
    fig.legend(handles=legend_elements, loc='lower center', ncol=5, 
              bbox_to_anchor=(0.5, -0.02), fontsize=11, frameon=True)
    
    plt.suptitle(f'8-View Visualization: {base_name}', fontsize=16, weight='bold', y=0.98)
    plt.tight_layout(rect=[0, 0.03, 1, 0.96])
    
    return fig

print("✅ Visualization function ready!")

In [None]:
# Visualize first 3 base images
vis_output_dir = f'{cfg.work_dir}/8view_visualizations'
os.makedirs(vis_output_dir, exist_ok=True)

num_samples_to_visualize = 3
data_prefix = 'test/'

print(f"Visualizing {num_samples_to_visualize} base images...")
print(f"Output directory: {vis_output_dir}")
print("="*80)

for i, (base_name, image_list) in enumerate(list(image_groups.items())[:num_samples_to_visualize]):
    print(f"\n[{i+1}/{num_samples_to_visualize}] Processing: {base_name}")
    print(f"  Number of views: {len(image_list)}")
    
    # Visualize
    fig = visualize_8_views(base_name, image_list, data_prefix)
    
    # Save figure
    output_path = f'{vis_output_dir}/{base_name}_8views.png'
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    print(f"  ✅ Saved to: {output_path}")
    
    plt.show()
    plt.close()

print("\n" + "="*80)
print(f"✅ All visualizations saved to: {vis_output_dir}")
print("="*80)

## Step 14: Visualize with Model Predictions

Run inference and draw both the predictions and the ground truth.
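
Two box conventions meet in this step: COCO annotations store `[x, y, w, h]`, while the predicted boxes are read as `[x1, y1, x2, y2]` and converted to a width and height before drawing. The sketch below spells out both conversions and the score filter in plain Python; the function names are hypothetical and the numbers are made up.

```python
def xywh_to_xyxy(box):
    """Convert (x, y, width, height) to (x1, y1, x2, y2)."""
    x, y, w, h = box
    return (x, y, x + w, y + h)


def xyxy_to_xywh(box):
    """Convert (x1, y1, x2, y2) to (x, y, width, height)."""
    x1, y1, x2, y2 = box
    return (x1, y1, x2 - x1, y2 - y1)


def keep_confident(boxes, scores, labels, score_threshold=0.3):
    """Keep (box, score, label) triples whose score reaches the threshold."""
    return [
        (box, score, label)
        for box, score, label in zip(boxes, scores, labels)
        if score >= score_threshold
    ]


# Made-up predictions in (x1, y1, x2, y2) form.
boxes = [(10, 20, 40, 60), (0, 0, 5, 5)]
scores = [0.9, 0.1]
labels = [2, 0]
for box, score, label in keep_confident(boxes, scores, labels):
    print(label, score, xyxy_to_xywh(box))
```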

In [None]:
# Initialize detector for inference
print("Initializing detector for visualization...")

# Extract teacher detector from MultiViewSoftTeacher
if hasattr(runner.model, 'teacher'):
    detector = runner.model.teacher
elif hasattr(runner.model, 'module') and hasattr(runner.model.module, 'teacher'):
    detector = runner.model.module.teacher
else:
    detector = runner.model

print(f"Detector type: {type(detector).__name__}")
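detector.eval()

# inference_detector looks up the test pipeline through `model.cfg`; a sub-module
# taken from the Runner may not carry that attribute, so attach the loaded config.
if not hasattr(detector, 'cfg'):
    detector.cfg = cfg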

# Move to GPU if available
if torch.cuda.is_available():
    detector = detector.cuda()
    print(f"✅ Detector on GPU")
else:
    print(f"⚠️ Detector on CPU")

print("✅ Detector ready for inference!")

In [None]:
def visualize_8_views_with_predictions(base_name, image_list, data_prefix, 
                                       detector, score_threshold=0.3):
    """Visualize 8 crops: GT on top row, Predictions on bottom row for each view"""
    
    # Sort by crop number
    def get_crop_num(filename):
        if '_crop_' in filename:
            try:
                crop_part = filename.split('_crop_')[1]
                crop_num = int(crop_part.split('_')[0])
                return crop_num
            except (IndexError, ValueError):
                # Unexpected file name layout: treat as crop 0
                return 0
        return 0
    
    image_list_sorted = sorted(image_list, key=lambda x: get_crop_num(x['file_name']))
    
    # Create figure with 2 rows per view: GT (top), Pred (bottom)
    # Total: 2 rows × 8 views = 16 subplots
    fig, axes = plt.subplots(2, 8, figsize=(28, 8))
    
    # Class names and colors
    class_names = ['Broken', 'Chipped', 'Scratched', 'Severe_Rust', 'Tip_Wear']
    class_colors = [
        (134/255, 34/255, 255/255),   # Broken
        (0/255, 255/255, 206/255),     # Chipped
        (255/255, 128/255, 0/255),     # Scratched
        (254/255, 0/255, 86/255),      # Severe_Rust
        (199/255, 252/255, 0/255)      # Tip_Wear
    ]
    
    from mmdet.apis import inference_detector
    
    for idx, img_info in enumerate(image_list_sorted[:8]):
        # Get axes for this view
        ax_gt = axes[0, idx]    # Top row: Ground Truth
        ax_pred = axes[1, idx]  # Bottom row: Predictions
        
        # Load image
        img_path = os.path.join('/home/coder/data/trong/KLTN/Soft_Teacher/data_drill', 
                                data_prefix, img_info['file_name'])
        
        if not os.path.exists(img_path):
            ax_gt.text(0.5, 0.5, 'Image not found', ha='center', va='center')
            ax_gt.axis('off')
            ax_pred.text(0.5, 0.5, 'Image not found', ha='center', va='center')
            ax_pred.axis('off')
            continue
        
        img = cv2.imread(img_path)
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        # Display image on both subplots
        ax_gt.imshow(img_rgb)
        ax_pred.imshow(img_rgb)
        
        # Get ground truth boxes
        gt_anns = annotations_by_image.get(img_info['id'], [])
        
        # ===== TOP ROW: Draw Ground Truth only =====
        for ann in gt_anns:
            bbox = ann['bbox']  # [x, y, w, h]
            category_id = ann['category_id'] - 1
            
            if 0 <= category_id < len(class_names):
                color = class_colors[category_id]
                rect = mpatches.Rectangle(
                    (bbox[0], bbox[1]), bbox[2], bbox[3],
                    linewidth=3, edgecolor=color, facecolor='none',
                    linestyle='-'
                )
                ax_gt.add_patch(rect)
                
                # GT label
                ax_gt.text(bbox[0], bbox[1] - 5, class_names[category_id],
                          fontsize=8, color='white', weight='bold',
                          bbox=dict(boxstyle='round,pad=0.3', 
                                  facecolor=color, alpha=0.9))
        
        crop_num = get_crop_num(img_info['file_name'])
        ax_gt.set_title(f'View {idx+1} (Crop {crop_num}) - GT: {len(gt_anns)} boxes', 
                       fontsize=10, weight='bold', color='darkgreen')
        ax_gt.axis('off')
        
        # ===== BOTTOM ROW: Run inference and draw Predictions =====
        with torch.no_grad():
            result = inference_detector(detector, img_path)
        
        # Extract predictions
        pred_instances = result.pred_instances
        pred_bboxes = pred_instances.bboxes.cpu().numpy()
        pred_scores = pred_instances.scores.cpu().numpy()
        pred_labels = pred_instances.labels.cpu().numpy()
        
        # Draw predictions
        num_preds = 0
        for bbox, score, label in zip(pred_bboxes, pred_scores, pred_labels):
            if score >= score_threshold:
                x1, y1, x2, y2 = bbox
                w, h = x2 - x1, y2 - y1
                
                if 0 <= label < len(class_names):
                    color = class_colors[label]
                    rect = mpatches.Rectangle(
                        (x1, y1), w, h,
                        linewidth=2.5, edgecolor=color, facecolor='none',
                        linestyle='-', alpha=0.9
                    )
                    ax_pred.add_patch(rect)
                    
                    # Prediction label with score
                    ax_pred.text(x1, y1 - 5, f'{class_names[label]}: {score:.2f}',
                               fontsize=8, color='white', weight='bold',
                               bbox=dict(boxstyle='round,pad=0.3', 
                                       facecolor=color, alpha=0.8))
                    num_preds += 1
        
        ax_pred.set_title(f'View {idx+1} (Crop {crop_num}) - Pred: {num_preds} boxes (≥{score_threshold})',
                         fontsize=10, weight='bold', color='darkblue')
        ax_pred.axis('off')
    
    # Hide extra subplots if less than 8
    for idx in range(len(image_list_sorted), 8):
        axes[0, idx].axis('off')
        axes[1, idx].axis('off')
    
    # Create legend
    legend_elements = []
    
    # Add row labels
    from matplotlib.lines import Line2D
    legend_elements.append(Line2D([0], [0], color='darkgreen', linewidth=0, 
                                  marker='s', markersize=10, 
                                  label='■ Top Row: Ground Truth'))
    legend_elements.append(Line2D([0], [0], color='darkblue', linewidth=0, 
                                  marker='s', markersize=10,
                                  label=f'■ Bottom Row: Predictions (score ≥ {score_threshold})'))
    legend_elements.append(Line2D([0], [0], color='white', linewidth=0, label=''))  # Spacer
    
    # Add class colors
    for i, name in enumerate(class_names):
        legend_elements.append(
            mpatches.Patch(color=class_colors[i], label=name)
        )
    
    fig.legend(handles=legend_elements, loc='lower center', ncol=8,
              bbox_to_anchor=(0.5, -0.02), fontsize=11, frameon=True,
              columnspacing=1.5)
    
    plt.suptitle(f'8-View Comparison: {base_name}\n(Top: Ground Truth | Bottom: Predictions)', 
                fontsize=16, weight='bold', y=0.98)
    plt.tight_layout(rect=[0, 0.05, 1, 0.95])
    
    return fig

print("✅ Updated visualization function ready! (GT on top, Predictions on bottom)")

In [None]:
# Visualize with predictions
vis_pred_dir = f'{cfg.work_dir}/8view_predictions'
os.makedirs(vis_pred_dir, exist_ok=True)

num_samples = 3
score_threshold = 0.3

print(f"Visualizing {num_samples} base images with predictions...")
print(f"Score threshold: {score_threshold}")
print(f"Output directory: {vis_pred_dir}")
print("="*80)

for i, (base_name, image_list) in enumerate(list(image_groups.items())[:num_samples]):
    print(f"\n[{i+1}/{num_samples}] Processing: {base_name}")
    print(f"  Number of views: {len(image_list)}")
    
    try:
        # Visualize with predictions
        fig = visualize_8_views_with_predictions(
            base_name, image_list, data_prefix, 
            detector, score_threshold
        )
        
        # Save
        output_path = f'{vis_pred_dir}/{base_name}_predictions.png'
        plt.savefig(output_path, dpi=150, bbox_inches='tight')
        print(f"  ✅ Saved to: {output_path}")
        
        plt.show()
        plt.close()
        
    except Exception as e:
        print(f"  ❌ Error: {e}")
        continue

print("\n" + "="*80)
print(f"✅ All prediction visualizations saved to: {vis_pred_dir}")
print("="*80)

## Step 15: Choose a specific base image to visualize

Enter a base_name or an index to view the details.
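
The selection cell in this step works with an integer index. As a sketch, a small resolver can accept either an index or a base name and return the name to look up in `image_groups`. `resolve_base` is a hypothetical helper and the names in the example are made up.

```python
def resolve_base(selection, base_names):
    """Return the base name for an integer index or an exact base-name string."""
    if isinstance(selection, bool):
        # bool is a subclass of int; reject it to avoid surprising lookups.
        raise TypeError("selection must be an int index or a str name")
    if isinstance(selection, int):
        if not 0 <= selection < len(base_names):
            raise IndexError(f"index must be between 0 and {len(base_names) - 1}")
        return base_names[selection]
    if isinstance(selection, str):
        if selection not in base_names:
            raise KeyError(f"unknown base image: {selection}")
        return selection
    raise TypeError("selection must be an int index or a str name")


# Made-up base names for illustration.
available = ["S245_bright_1", "S245_bright_2", "S246_bright_1"]
print(resolve_base(1, available))
print(resolve_base("S246_bright_1", available))
```

In the notebook, the returned name could be used as `image_groups[resolve_base(selection, base_names_list)]`.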

In [None]:
# List all available base images
print("Available base images:")
print("="*80)
base_names_list = list(image_groups.keys())
for i, base_name in enumerate(base_names_list[:20]):  # Show first 20
    num_views = len(image_groups[base_name])
    num_boxes = sum(len(annotations_by_image.get(img['id'], [])) 
                   for img in image_groups[base_name])
    print(f"[{i:2d}] {base_name:<50} → {num_views} views, {num_boxes} GT boxes")

if len(base_names_list) > 20:
    print(f"\n... and {len(base_names_list) - 20} more")

print(f"\nTotal: {len(base_names_list)} base images")

In [None]:
# Visualize specific base image by index
# Change this index to visualize different base images
selected_index = 0  # Change this: 0, 1, 2, 3, ... 

if 0 <= selected_index < len(base_names_list):
    selected_base = base_names_list[selected_index]
    selected_images = image_groups[selected_base]
    
    print(f"Selected base image: {selected_base}")
    print(f"Number of views: {len(selected_images)}")
    
    # Visualize
    fig = visualize_8_views_with_predictions(
        selected_base, selected_images, data_prefix,
        detector, score_threshold=0.3
    )
    
    # Save
    output_path = f'{vis_pred_dir}/{selected_base}_selected.png'
    plt.savefig(output_path, dpi=200, bbox_inches='tight')
    print(f"✅ Saved to: {output_path}")
    
    plt.show()
else:
    print(f"❌ Invalid index! Choose between 0 and {len(base_names_list)-1}")