# Data Exploration - WAD Dataset
## Khám phá dữ liệu WAD_Images cho Vision-Language Model

In [None]:
# Setup
# Imports and plotting configuration used by the notebook's cells.
import sys
# Put the parent directory on the module search path.
# NOTE(review): presumably so project-local packages resolve; no visible cell imports one — confirm.
sys.path.append('../')

from datasets import load_dataset
# NOTE(review): defaultdict and pandas (pd) are not referenced in the cells visible here.
from collections import defaultdict, Counter
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from PIL import Image
import io
import tarfile

# Apply seaborn's 'whitegrid' style to the figures drawn below.
sns.set_style('whitegrid')
# IPython magic (not plain Python): render matplotlib figures inline in the notebook.
%matplotlib inline

## 1. Load Metadata

In [None]:
# Load the train/test metadata JSON files of the WAD_Images dataset repo
print("Loading metadata...")
split_files = {
    "train": "train.json",
    "test": "test_alter.json",
}
metadata = load_dataset("minhdang0901/WAD_Images", data_files=split_files)

train_split = metadata['train']
test_split = metadata['test']
print(f"Train samples: {len(train_split)}")
print(f"Test samples: {len(test_split)}")

# Peek at the first training record: list every field with its Python type
sample = train_split[0]
print("\nSample structure:")
for field_name, field_value in sample.items():
    print(f"  - {field_name}: {type(field_value).__name__}")

## 2. Analyze Metadata Distribution

In [None]:
# Plot how the training samples are distributed over area type, weather
# condition and traffic-flow rating, and save the figure under
# ../experiments/results/.
import os

# Extract metadata fields
train_data = metadata['train']

area_types = [s['area_type'] for s in train_data]
weather_conditions = [s['weather_condition'] for s in train_data]
traffic_flow = [s['traffic_flow_rating'] for s in train_data]

# Frequency of each distinct value per field (reused by the summary cell)
area_counts = Counter(area_types)
weather_counts = Counter(weather_conditions)
traffic_counts = Counter(traffic_flow)

# One entry per panel: (counts, title, x-label, bar colour, rotate x ticks?)
# color=None lets matplotlib pick its default colour for the first panel.
panels = [
    (area_counts, 'Area Type Distribution', 'Area Type', None, True),
    (weather_counts, 'Weather Condition Distribution', 'Weather', 'orange', True),
    (traffic_counts, 'Traffic Flow Distribution', 'Traffic Level', 'green', False),
]

# Create subplots
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for ax, (counts, title, xlabel, color, rotate_ticks) in zip(axes, panels):
    # Materialise the dict views so matplotlib receives plain sequences
    ax.bar(list(counts.keys()), list(counts.values()), color=color)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Count')
    if rotate_ticks:
        ax.tick_params(axis='x', rotation=45)

plt.tight_layout()
# savefig does not create missing directories, so make sure the target exists
os.makedirs('../experiments/results', exist_ok=True)
plt.savefig('../experiments/results/metadata_distribution.png', dpi=300, bbox_inches='tight')
plt.show()

## 3. Analyze BBox Annotations

In [None]:
# Load the bounding-box annotations, count how often each object label
# occurs, and plot/save the 20 most frequent labels.
import os

# Load bbox data
print("Loading bbox annotations...")
bbox_dataset = load_dataset(
    "minhdang0901/WAD_Images",
    data_files="all_bboxes.jsonl",
    split="train"
)

print(f"Total bbox entries: {len(bbox_dataset)}")

# Analyze object labels
labels = [entry['label'] for entry in bbox_dataset]
label_counts = Counter(labels)

# Plot top 20 objects
# most_common() yields labels in descending frequency; dict() keeps that order
top_objects = dict(label_counts.most_common(20))

plt.figure(figsize=(12, 6))
plt.barh(list(top_objects.keys()), list(top_objects.values()))
plt.xlabel('Count', fontsize=12)
plt.ylabel('Object Label', fontsize=12)
plt.title('Top 20 Detected Objects', fontsize=14, fontweight='bold')
# barh draws the first entry at the bottom; flip so the most frequent is on top
plt.gca().invert_yaxis()
plt.tight_layout()
# savefig does not create missing directories, so make sure the target exists
os.makedirs('../experiments/results', exist_ok=True)
plt.savefig('../experiments/results/top_objects.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"\nTotal unique objects: {len(label_counts)}")

## 4. Visualize a Sample Image (BBox overlay not yet implemented)

In [None]:
# Display the first indexed frame of the first training sample.
#
# This cell requires frame_index.pkl to be built first
# Run: python scripts/build_frame_index.py

import pickle
import os

index_file = "../wad_dataset/frame_index.pkl"

if os.path.exists(index_file):
    # NOTE(review): pickle.load can execute arbitrary code; only load an index
    # file you built yourself with the script above.
    with open(index_file, 'rb') as f:
        frame_index = pickle.load(f)

    # Load and display a sample image
    sample = metadata['train'][0]
    frame_path = sample['frame_path']

    if frame_path not in frame_index:
        # Report the miss instead of silently showing nothing
        print(f"No entry for {frame_path!r} in {index_file}; nothing to display.")
    elif not frame_index[frame_path]:
        print(f"Index entry for {frame_path!r} contains no frames; nothing to display.")
    else:
        frame_ids = sorted(frame_index[frame_path].keys())
        first_frame_id = frame_ids[0]

        frame_info = frame_index[frame_path][first_frame_id]

        # Load image: read the member's bytes into memory while the tar is open
        with tarfile.open(frame_info['shard'], 'r') as tar:
            member = tar.getmember(frame_info['tar_path'])
            file_obj = tar.extractfile(member)
            if file_obj is None:
                # extractfile returns None for members that are not regular files
                raise ValueError(
                    f"{frame_info['tar_path']!r} in {frame_info['shard']!r} "
                    "is not a regular file"
                )
            img = Image.open(io.BytesIO(file_obj.read()))

        plt.figure(figsize=(10, 8))
        plt.imshow(img)
        plt.title(f'Sample Image: {frame_path}/{first_frame_id}', fontsize=14)
        plt.axis('off')
        plt.tight_layout()
        plt.show()
else:
    print("Frame index not found. Run: python scripts/build_frame_index.py")

## 5. Instruction Analysis

In [None]:
# Analyze instruction lengths: word counts of each training sample's
# instruction text, plotted as a histogram and summarised numerically.
import os


def _instruction_text(record):
    """Return the instruction text of *record*, or None if it has none.

    A non-empty 'alter' field takes precedence; otherwise the 'A' entry of a
    dict-valued 'QA' field is used. Non-string values are treated as missing.
    """
    if 'alter' in record and record['alter']:
        text = record['alter']
    elif 'QA' in record and isinstance(record['QA'], dict):
        text = record['QA'].get('A')
    else:
        text = None
    return text if isinstance(text, str) else None


# Whitespace-separated word count per sample that has instruction text
instruction_lengths = [
    len(text.split())
    for text in map(_instruction_text, train_data)
    if text is not None
]

if instruction_lengths:
    mean_length = np.mean(instruction_lengths)

    # Plot distribution
    plt.figure(figsize=(10, 5))
    plt.hist(instruction_lengths, bins=30, edgecolor='black')
    plt.xlabel('Number of Words', fontsize=12)
    plt.ylabel('Frequency', fontsize=12)
    plt.title('Instruction Length Distribution', fontsize=14, fontweight='bold')
    plt.axvline(mean_length, color='red', linestyle='--', label=f'Mean: {mean_length:.1f}')
    plt.legend()
    plt.tight_layout()
    # savefig does not create missing directories, so make sure the target exists
    os.makedirs('../experiments/results', exist_ok=True)
    plt.savefig('../experiments/results/instruction_length_dist.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("Instruction statistics:")
    print(f"  Mean: {mean_length:.2f} words")
    print(f"  Median: {np.median(instruction_lengths):.2f} words")
    print(f"  Min: {np.min(instruction_lengths)} words")
    print(f"  Max: {np.max(instruction_lengths)} words")
else:
    # np.min/np.max raise on an empty sequence and np.mean would yield NaN
    print("No instruction text found in the training split; skipping length analysis.")

## 6. Summary Statistics

In [None]:
# Collect the headline numbers computed by the earlier cells, print them,
# and persist them as JSON under ../experiments/results/.
import json
import os

# Avoid a NaN (and a NumPy warning) when no instruction text was found
if instruction_lengths:
    avg_instruction_length = f"{np.mean(instruction_lengths):.2f} words"
else:
    avg_instruction_length = "n/a"

summary = {
    'Total Train Samples': len(metadata['train']),
    'Total Test Samples': len(metadata['test']),
    'Total BBox Annotations': len(bbox_dataset),
    'Unique Object Types': len(label_counts),
    'Unique Area Types': len(area_counts),
    'Unique Weather Conditions': len(weather_counts),
    'Avg Instruction Length': avg_instruction_length,
}

print("\n" + "="*60)
print("DATASET SUMMARY")
print("="*60)

for key, value in summary.items():
    print(f"  {key}: {value}")

print("="*60)

# Save summary
# open(..., 'w') does not create missing directories, so make sure it exists
os.makedirs('../experiments/results', exist_ok=True)
with open('../experiments/results/data_summary.json', 'w', encoding='utf-8') as f:
    json.dump(summary, f, indent=2)

print("\n✓ Summary saved to experiments/results/data_summary.json")