# 📊 TBX11K Dataset Exploration

**Goal**: Understand the TBX11K chest X-ray dataset structure, distribution, and characteristics.

---

## Dataset Overview

- **Total Images**: 11,200 chest X-rays
- **Classes**: Healthy, Sick (non-TB), Active TB, Latent TB, Uncertain
- **Resolution**: 512x512 pixels
- **Format**: PNG/JPEG
- **Annotations**: Bounding boxes for TB regions

---

In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from pathlib import Path
from collections import Counter

# Set a consistent plotting style for every figure in this notebook
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

print("✅ Libraries imported successfully")

## 1. Load Dataset Metadata

In [None]:
# Define paths (relative to this notebook's working directory)
data_dir = Path('../data/raw')

# Check that the dataset directory exists. Its contents are listed only when it
# does: calling iterdir() on a missing path (or on a file) would raise right
# after printing the warning.
if not data_dir.is_dir():
    print("❌ Dataset not found. Please download TBX11K dataset first.")
else:
    print(f"✅ Dataset directory found: {data_dir.absolute()}")

    # List contents, sorted so the listing is stable across file systems
    print("\nDataset contents:")
    for item in sorted(data_dir.iterdir()):
        print(f"  - {item.name}")

## 2. Analyze Class Distribution

In [None]:
# Collect every image file under data_dir.
# NOTE: if the downloaded dataset ships a label/metadata file (e.g. a CSV),
# load it here instead of inferring labels from folders — the exact file name
# depends on how the dataset was obtained.
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg'}

# Match suffixes case-insensitively: pathlib globbing is case-sensitive on POSIX,
# so '*.png' would miss 'IMG.PNG'. Sort the result because rglob order depends on
# the file system, and the seeded train/val/test split later in this notebook is
# only reproducible if its input order is fixed as well.
all_images = sorted(
    path for path in data_dir.rglob('*')
    if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS
)

print(f"Total images found: {len(all_images)}")
print("\nSample image paths:")
for img_path in all_images[:5]:
    print(f"  {img_path}")

In [None]:
# Label each image by the name of the folder that contains it
# (assumes a <class>/<image> layout — adjust if the dataset differs).
class_counts = Counter(path.parent.name for path in all_images)

print("Class distribution:")
total_images = len(all_images)
for label, n_images in class_counts.most_common():
    share = n_images / total_images * 100
    print(f"  {label}: {n_images} images ({share:.1f}%)")

In [None]:
# Visualize class distribution as a bar chart with count/percentage labels.
# `classes` and `counts` follow class_counts' insertion order; `classes` is
# reused by later cells.
plt.figure(figsize=(10, 6))
classes = list(class_counts.keys())
counts = list(class_counts.values())

plt.bar(classes, counts, color='steelblue', alpha=0.8)
plt.xlabel('Class', fontsize=12)
plt.ylabel('Number of Images', fontsize=12)
plt.title('TBX11K Dataset - Class Distribution', fontsize=14, fontweight='bold')
plt.xticks(rotation=45, ha='right')
plt.grid(axis='y', alpha=0.3)

# Offset each label by a small fraction of the tallest bar rather than a fixed
# 50 images, which floats far above the bars on small datasets. Reserve extra
# headroom because plt.text does not take part in autoscaling, so the label on
# the tallest bar could otherwise be clipped.
label_offset = 0.01 * max(counts, default=0)
plt.margins(y=0.15)

for i, cnt in enumerate(counts):
    plt.text(i, cnt + label_offset, f'{cnt}\n({cnt/len(all_images)*100:.1f}%)',
             ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.show()

## 3. Analyze Image Properties

In [None]:
# Sample images for analysis (to avoid loading all 11K images).
# A seeded generator makes the sample, and the statistics below, reproducible.
rng = np.random.default_rng(42)
sample_size = min(500, len(all_images))
sample_images = rng.choice(all_images, sample_size, replace=False)

image_stats = []

print(f"Analyzing {sample_size} sample images...")
for img_path in sample_images:
    try:
        # Image.open is lazy and keeps the file open; the context manager
        # releases the handle once the header fields have been read.
        with Image.open(img_path) as img:
            image_stats.append({
                'width': img.width,
                'height': img.height,
                'mode': img.mode,
                'size_mb': img_path.stat().st_size / (1024 * 1024),
            })
    except Exception as e:  # best-effort: report unreadable files and keep going
        print(f"Error loading {img_path}: {e}")

# Explicit columns keep later stats_df[...] lookups valid even if nothing loaded.
stats_df = pd.DataFrame(image_stats, columns=['width', 'height', 'mode', 'size_mb'])
print("\n✅ Image analysis complete")

In [None]:
# Print summary tables for the sampled image properties, one section at a time.
sections = (
    ("Resolution", lambda df: df[['width', 'height']].describe()),
    ("Image Modes", lambda df: df['mode'].value_counts()),
    ("File Sizes", lambda df: df['size_mb'].describe()),
)

print("Image Statistics:")
for heading, summarise in sections:
    print(f"\n{heading}:")
    print(summarise(stats_df))

In [None]:
# Two side-by-side panels: resolution scatter (left) and file-size histogram (right)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
res_ax, size_ax = axes

# Left panel: width vs. height of every sampled image
res_ax.scatter(stats_df['width'], stats_df['height'], alpha=0.5, color='steelblue')
res_ax.set_xlabel('Width (pixels)')
res_ax.set_ylabel('Height (pixels)')
res_ax.set_title('Image Resolution Distribution')

# Right panel: on-disk file sizes
size_ax.hist(stats_df['size_mb'], bins=30, color='coral', alpha=0.7)
size_ax.set_xlabel('File Size (MB)')
size_ax.set_ylabel('Frequency')
size_ax.set_title('File Size Distribution')

# Same light grid on both panels
for panel in axes:
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 4. Visualize Sample Images

In [None]:
# Display sample images from each class: one row per class, up to
# `samples_per_class` randomly chosen images per row.
samples_per_class = 3
rng = np.random.default_rng(42)  # seeded so every run shows the same examples

if not classes:
    # plt.subplots with zero rows raises, so skip the grid explicitly.
    print("No classes found - skipping sample image grid.")
else:
    # squeeze=False always returns a 2-D array of Axes, so axes[i, j] works for
    # any number of classes without special-casing a single row.
    fig, axes = plt.subplots(len(classes), samples_per_class,
                             figsize=(15, len(classes) * 4), squeeze=False)

    for i, class_name in enumerate(classes):
        class_images = [img for img in all_images if img.parent.name == class_name]
        n_samples = min(samples_per_class, len(class_images))
        samples = rng.choice(class_images, n_samples, replace=False)

        for j in range(samples_per_class):
            ax = axes[i, j]
            # Hide ticks/frames on every cell, including cells left empty when
            # a class has fewer than `samples_per_class` images.
            ax.axis('off')
            if j >= n_samples:
                continue
            # imshow converts the PIL image to an array when called, so the file
            # can be closed as soon as the cell is drawn.
            with Image.open(samples[j]) as img:
                ax.imshow(img, cmap='gray' if img.mode == 'L' else None)
                ax.set_title(f'{class_name}\n{img.size[0]}x{img.size[1]}', fontsize=10)

    plt.suptitle('Sample Chest X-rays from Each Class', fontsize=16, fontweight='bold', y=1.0)
    plt.tight_layout()
    plt.show()

## 5. Pixel Intensity Analysis

In [None]:
# Analyze pixel intensity distributions on a reproducible random subset.
rng = np.random.default_rng(42)
sample_for_intensity = rng.choice(all_images, min(100, len(all_images)), replace=False)

# Gather flattened uint8 arrays per class and concatenate them once at the end.
# Extending a Python list with array elements would box every pixel into its own
# scalar object (tens of millions for 100 full-size X-rays), which is very slow
# and memory-hungry; a single uint8 array per class holds one byte per pixel.
intensity_chunks = {}
for img_path in sample_for_intensity:
    class_name = img_path.parent.name
    with Image.open(img_path) as img:
        img_array = np.asarray(img.convert('L'))  # Convert to grayscale
    intensity_chunks.setdefault(class_name, []).append(img_array.ravel())

# class name -> 1-D uint8 array of all sampled pixel intensities for that class
intensity_distributions = {
    name: np.concatenate(chunks) for name, chunks in intensity_chunks.items()
}

print("✅ Pixel intensity analysis complete")

In [None]:
# Plot intensity distributions by class, overlaid as normalised histograms.
# All classes share one set of integer-aligned bin edges. With bins=50, each
# class would get its own edges spanning its own min..max, so the overlays would
# not line up. Fifty bins over 256 integer levels would also put 5 or 6 levels
# into different bins, which creates a sawtooth artefact in the densities.
bin_width = 4
intensity_bins = np.arange(0, 256 + bin_width, bin_width)  # 64 bins covering 0-255

plt.figure(figsize=(12, 6))

for class_name, intensities in intensity_distributions.items():
    plt.hist(intensities, bins=intensity_bins, alpha=0.5, label=class_name, density=True)

plt.xlabel('Pixel Intensity (0-255)', fontsize=12)
plt.ylabel('Density', fontsize=12)
plt.title('Pixel Intensity Distribution by Class', fontsize=14, fontweight='bold')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 6. Create Train/Val/Test Splits

In [None]:
from sklearn.model_selection import train_test_split

# One row per image: its path (as a string) and its folder-derived class label
rows = [{'image_path': str(path), 'class': path.parent.name} for path in all_images]
df_all = pd.DataFrame(rows)

print(f"Total images in dataset: {len(df_all)}")
print("\nClass distribution:")
print(df_all['class'].value_counts())

In [None]:
# Split: 70% train, 15% val, 15% test.
# Both steps stratify on 'class' so every split keeps the overall class
# proportions; a fixed random_state makes the shuffle reproducible.
VAL_FRACTION = 0.15
TEST_FRACTION = 0.15
SPLIT_SEED = 42

# Step 1: carve off the combined validation + test pool...
train_df, temp_df = train_test_split(
    df_all,
    test_size=VAL_FRACTION + TEST_FRACTION,
    stratify=df_all['class'],
    random_state=SPLIT_SEED,
)

# Step 2: ...then divide that pool between validation and test. The ratio is
# derived from the fractions above so the two steps cannot drift apart.
val_df, test_df = train_test_split(
    temp_df,
    test_size=TEST_FRACTION / (VAL_FRACTION + TEST_FRACTION),
    stratify=temp_df['class'],
    random_state=SPLIT_SEED,
)

for split_label, split_df in (("Train set", train_df),
                              ("Validation set", val_df),
                              ("Test set", test_df)):
    print(f"{split_label}: {len(split_df)} images ({len(split_df)/len(df_all)*100:.1f}%)")

In [None]:
# Print per-class counts for each split so the stratification can be eyeballed
print("\nClass distribution across splits:")
for split_label, split_frame in (("Train", train_df),
                                 ("Validation", val_df),
                                 ("Test", test_df)):
    print(f"\n{split_label}:")
    print(split_frame['class'].value_counts())

In [None]:
# Save each split to its own CSV. The image_path column holds paths as built
# from data_dir, so with a relative data_dir they are relative to this
# notebook's working directory.
splits_dir = Path('../data/splits')
splits_dir.mkdir(parents=True, exist_ok=True)

for split_name, split_df in (('train', train_df), ('val', val_df), ('test', test_df)):
    split_df.to_csv(splits_dir / f'{split_name}.csv', index=False)

print(f"✅ Splits saved to {splits_dir}")

## 7. Summary & Next Steps

In [None]:
# Print a one-screen summary of the dataset, sampled image properties and splits.
print("="*60)
print("DATASET SUMMARY")
print("="*60)
print(f"\nTotal Images: {len(df_all):,}")
print(f"\nClasses: {len(classes)}")
for cls in classes:
    count = int((df_all['class'] == cls).sum())
    print(f"  - {cls}: {count:,} ({count/len(df_all)*100:.1f}%)")

print("\nImage Properties:")
if stats_df.empty:
    # mean()/mode()[0] would print nan or raise when no sampled image loaded.
    print("  - No sampled images could be analysed")
else:
    print(f"  - Average resolution: {stats_df['width'].mean():.0f}x{stats_df['height'].mean():.0f}")
    print(f"  - Average file size: {stats_df['size_mb'].mean():.2f} MB")
    print(f"  - Image mode: {stats_df['mode'].mode()[0]}")

print("\nData Splits:")
print(f"  - Train: {len(train_df):,} images")
print(f"  - Validation: {len(val_df):,} images")
print(f"  - Test: {len(test_df):,} images")

print(f"\n{'='*60}")
print("NEXT STEPS")
print("="*60)
print("\n1. ✅ Data exploration complete")
print("2. 📝 Run preprocessing notebook (02_preprocessing.ipynb)")
print("3. 🧠 Train baseline model (03_baseline_model.ipynb)")
print("4. ⚡ Apply AST training (04_ast_training.ipynb)")
print("5. 🚀 Build Gradio demo")
print("="*60)