# Debug Validation Data

This notebook inspects the validation data across all formats to identify label issues.

In [None]:
import os
import sys
from pathlib import Path
import pandas as pd
import webdataset as wds
import tensorflow as tf
import lmdb
import pickle
from collections import Counter

# Environment detection: the presence of KAGGLE_KERNEL_RUN_TYPE in the
# process environment is treated as "running on Kaggle".
IS_KAGGLE = os.environ.get("KAGGLE_KERNEL_RUN_TYPE") is not None

# Project root: a fixed working directory on Kaggle, otherwise the resolved
# parent of the current working directory.
if IS_KAGGLE:
    BASE_DIR = Path('/kaggle/working/format-matters')
else:
    BASE_DIR = Path('..').resolve()

# Root of the built dataset artifacts, and the dataset inspected below.
BUILT_DIR = BASE_DIR.joinpath('data', 'built')
DATASET = 'cifar10'

# Opening banner: rule, title, rule.
banner_rule = "=" * 80
print(banner_rule, "VALIDATION DATA INSPECTION", banner_rule, sep="\n")

## 1. CSV Format

In [None]:
# Inspect the CSV manifest of the validation split: sample count, number of
# distinct labels, per-label counts, and a preview of the first rows.
print("\n1. CSV FORMAT (val.csv)")
csv_path = BUILT_DIR.joinpath(DATASET, 'csv', 'default', 'val.csv')
if not csv_path.exists():
    print("   NOT FOUND")
else:
    df = pd.read_csv(csv_path)
    label_column = df['label']
    print(f"   Total samples: {len(df)}")
    print(f"   Unique labels: {label_column.nunique()}")
    print("   Label distribution:")
    # Counts per label, ordered by label value rather than by frequency.
    label_counts = label_column.value_counts().sort_index()
    for label_value, n_rows in label_counts.items():
        print(f"      Label {label_value}: {n_rows} samples")
    print("\n   First 20 samples:")
    preview = df[['path', 'label']].head(20)
    print(preview)

## 2. WebDataset Format

In [None]:
# Inspect the WebDataset shards of the validation split: read the 'cls' label
# of every sample (up to a safety cap), then print the label distribution and
# the first few (key, label) pairs.
print("\n2. WEBDATASET FORMAT (val-*.tar)")
wds_dir = BUILT_DIR / DATASET / 'webdataset' / 'shard256_none'
val_shards = sorted(wds_dir.glob('val-*.tar'))
if val_shards:
    print(f"   Found {len(val_shards)} shard(s): {[s.name for s in val_shards]}")
    # Upper bound on the number of labelled samples read across all shards.
    max_samples = 20000
    # keys[i] and labels[i] always describe the same sample: both lists are
    # appended together, and only for samples that carry a 'cls' entry.
    labels = []
    keys = []
    # Keys of samples that have no 'cls' entry; reported instead of being
    # dropped silently, since they are exactly the kind of label issue this
    # notebook is looking for.
    unlabeled_keys = []

    for shard in val_shards:
        shard_path = "file://" + shard.as_posix()
        dataset = wds.WebDataset(shard_path)

        for sample in dataset:
            sample_key = sample['__key__'] if '__key__' in sample else '<no key>'
            if 'cls' in sample:
                label_bytes = sample['cls']
                # The label may arrive as raw bytes (undecoded) or as an
                # already-decoded value; normalise both to int.
                label = int(label_bytes.decode('utf-8') if isinstance(label_bytes, bytes) else label_bytes)
                keys.append(sample_key)
                labels.append(label)
            else:
                unlabeled_keys.append(sample_key)
            if len(labels) >= max_samples:
                break
        if len(labels) >= max_samples:
            break

    print(f"   Total samples inspected: {len(labels)}")
    if len(labels) >= max_samples:
        print(f"   NOTE: reached the {max_samples}-sample cap; any remaining samples were not read")
    if unlabeled_keys:
        print(f"   WARNING: {len(unlabeled_keys)} sample(s) without a 'cls' entry, e.g. {unlabeled_keys[:5]}")
    print(f"   Unique labels: {len(set(labels))}")
    print("   Label distribution:")
    label_counts = Counter(labels)
    for label in sorted(label_counts.keys()):
        print(f"      Label {label}: {label_counts[label]} samples")
    print("\n   First 20 samples:")
    for sample_key, label in zip(keys[:20], labels[:20]):
        print(f"      {sample_key}: label={label}")
else:
    print("   NOT FOUND")

## 3. TFRecord Format

In [None]:
# Inspect the TFRecord shards of the validation split: parse each record as a
# tf.train.Example, read its int64 'label' feature (up to a safety cap), then
# print the label distribution and the first few labels.
print("\n3. TFRECORD FORMAT (val-*.tfrecord)")
tfr_dir = BUILT_DIR / DATASET / 'tfrecord' / 'shard256_none'
val_shards = sorted(tfr_dir.glob('val-*.tfrecord'))
if val_shards:
    print(f"   Found {len(val_shards)} shard(s): {[s.name for s in val_shards]}")
    # Upper bound on the number of labelled records read across all shards.
    max_samples = 20000
    labels = []
    # Records whose 'label' feature has no int64 value. Indexing value[0] on
    # such a record would raise IndexError and abort the whole inspection, so
    # they are counted and reported instead.
    missing_label_count = 0

    for shard in val_shards:
        dataset = tf.data.TFRecordDataset(str(shard))

        for raw_record in dataset:
            example = tf.train.Example()
            example.ParseFromString(raw_record.numpy())
            label_values = example.features.feature['label'].int64_list.value
            if label_values:
                labels.append(int(label_values[0]))
            else:
                missing_label_count += 1
            if len(labels) >= max_samples:
                break
        if len(labels) >= max_samples:
            break

    print(f"   Total samples inspected: {len(labels)}")
    if len(labels) >= max_samples:
        print(f"   NOTE: reached the {max_samples}-sample cap; any remaining records were not read")
    if missing_label_count:
        print(f"   WARNING: {missing_label_count} record(s) without an int64 'label' feature")
    print(f"   Unique labels: {len(set(labels))}")
    print("   Label distribution:")
    label_counts = Counter(labels)
    for label in sorted(label_counts.keys()):
        print(f"      Label {label}: {label_counts[label]} samples")
    print(f"\n   First 20 labels: {labels[:20]}")
else:
    print("   NOT FOUND")

## 4. LMDB Format

In [None]:
# Inspect the LMDB store of the validation split: read the sample count from
# the '__metadata__' record, fetch each entry by its zero-padded index key
# (up to a safety cap), then print the label distribution and first labels.
print("\n4. LMDB FORMAT (val.lmdb)")
lmdb_path = BUILT_DIR / DATASET / 'lmdb' / 'compress_none' / 'val.lmdb'
if lmdb_path.exists():
    # Upper bound on the number of entries read.
    max_samples = 20000
    labels = []
    # Indices in range(num_samples) whose key is absent from the store;
    # reported instead of being skipped silently.
    missing_indices = []
    metadata_found = False
    num_samples = 0

    env = lmdb.open(str(lmdb_path), readonly=True, lock=False)
    try:
        with env.begin() as txn:
            metadata_bytes = txn.get(b'__metadata__')
            if metadata_bytes:
                metadata_found = True
                # NOTE(review): pickle.loads can execute arbitrary code; only
                # run this against LMDB files you built yourself.
                metadata = pickle.loads(metadata_bytes)
                num_samples = metadata['num_samples']
                print(f"   Total samples: {num_samples}")

                # Entries are keyed by their index, zero-padded to 8 digits.
                for idx in range(min(max_samples, num_samples)):
                    key = f"{idx:08d}".encode('utf-8')
                    entry_bytes = txn.get(key)
                    if entry_bytes:
                        entry = pickle.loads(entry_bytes)
                        labels.append(entry['label'])
                    else:
                        missing_indices.append(idx)
    finally:
        # Always release the environment, even if a read or unpickle fails.
        env.close()

    if not metadata_found:
        print("   WARNING: no '__metadata__' record found; no samples were read")
    if num_samples > max_samples:
        print(f"   NOTE: only the first {max_samples} of {num_samples} samples were read")
    if missing_indices:
        print(f"   WARNING: {len(missing_indices)} expected key(s) missing, e.g. indices {missing_indices[:5]}")
    print(f"   Total samples inspected: {len(labels)}")
    print(f"   Unique labels: {len(set(labels))}")
    print("   Label distribution:")
    label_counts = Counter(labels)
    for label in sorted(label_counts.keys()):
        print(f"      Label {label}: {label_counts[label]} samples")
    print(f"\n   First 20 labels: {labels[:20]}")
else:
    print("   NOT FOUND")

## Summary

Compare the label distributions across all formats to identify discrepancies.

In [None]:
# Closing banner: blank line, rule, completion message, rule.
closing_rule = "=" * 80
print(f"\n{closing_rule}\nINSPECTION COMPLETE\n{closing_rule}")