# Dataset Analysis and Statistics for PAI Classification

**Author:** Gerald Torgersen  
**Date:** 2025  
**GitHub:** [github.com/geraldOslo/PAI-meets-AI](https://github.com/geraldOslo/PAI-meets-AI)  

**License**  
SPDX-License-Identifier: MIT  
Copyright (c) 2025 Gerald Torgersen


## Overview

This notebook serves as a preliminary analysis tool for the Periapical Index (PAI) image dataset. Its primary purpose is to compute and display essential statistics required for configuring the main training pipeline. It performs two key tasks:

1.  **Categorical Data Analysis**: It reads the dataset's metadata from CSV files and calculates the distribution (counts and percentages) of key features such as `quadrant`, `tooth`, `root`, and the target `PAI` score. This is crucial for understanding class imbalance and data characteristics.
2.  **Image Pixel Statistics**: It processes the entire image set to compute the channel-wise `mean` and `standard deviation`. These values are essential for normalizing the input data in the training script, which helps stabilize training and improve model convergence.

This notebook is designed to be run **before** model training. The outputs are intended to be copied into the `config.py` file used by the main training notebook.

## Workflow Summary

1.  **Configuration**: The user specifies paths to the metadata CSV files and the root directories containing the images in the "USER CONFIGURATION" cell. Data loading settings such as `BATCH_SIZE` and `NUM_WORKERS` are also defined here.
2.  **Metadata Loading**: The script loads and combines all specified CSV files into a single pandas DataFrame.
3.  **Categorical Statistics**: It iterates through predefined columns of interest (`quadrant`, `tooth`, `root`, `PAI`) and prints detailed tables showing the frequency and percentage of each unique value.
4.  **Image Path Aggregation**: The script builds a list of image paths by joining each filename from the CSV with the root directories and keeping the first path that exists. It also reports how many listed files could not be found.
5.  **Image Statistics Calculation**:
    *   A custom PyTorch `Dataset` is used to load images efficiently.
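    *   Each image is loaded at its original size, converted to RGB, and turned into a tensor with values in [0.0, 1.0]. No resizing is applied.
    *   The script iterates through all images one at a time, accumulating the sum and sum-of-squares of pixel values for each color channel (R, G, B), together with the total pixel count, in a single pass. This is memory-efficient, although the sum-of-squares formula can lose precision when accumulated in single-precision floats over very many pixels.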
    *   Finally, it computes the overall `mean` and `standard deviation` from these accumulated values.
6.  **Output**: All statistics are printed directly to the console. The user can then manually transfer the calculated `mean` and `std` values to their training configuration.
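
The single-pass accumulation in step 5 can be sketched in plain Python for one channel. This is a minimal illustration on toy data, not the notebook's PyTorch code; the function name `channel_mean_std` and the toy images are assumptions made for the example. It relies on the identity Var[x] = E[x²] − (E[x])² and tracks the pixel count per image so that images of different sizes are weighted correctly.

```python
import math
import statistics

def channel_mean_std(images):
    """Single-pass mean/std for one channel over images of differing sizes.

    `images` is an iterable of 2-D lists of floats in [0.0, 1.0].
    (Hypothetical helper for illustration only.)
    """
    total = 0.0      # running sum of pixel values
    total_sq = 0.0   # running sum of squared pixel values
    count = 0        # running number of pixels
    for image in images:
        for row in image:
            total += sum(row)
            total_sq += sum(value * value for value in row)
            count += len(row)
    mean = total / count
    # Var[x] = E[x^2] - (E[x])^2; clamp at zero to absorb rounding error
    variance = max(total_sq / count - mean * mean, 0.0)
    return mean, math.sqrt(variance)

# Two toy "images" with different shapes (2x2 and 1x3)
toy_images = [
    [[0.0, 0.5], [1.0, 0.5]],
    [[0.25, 0.75, 0.5]],
]
mean, std = channel_mean_std(toy_images)

# Cross-check against a two-pass computation over the flattened pixels
flat = [value for image in toy_images for row in image for value in row]
reference_mean = statistics.fmean(flat)
reference_std = statistics.pstdev(flat)
print(f"single-pass: mean={mean:.4f}, std={std:.4f}")
print(f"two-pass:    mean={reference_mean:.4f}, std={reference_std:.4f}")
```

Both computations use the population standard deviation (dividing by the pixel count), which matches what the notebook computes from its accumulated sums.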

## Key Analysis Techniques

*   **Categorical Data Analysis**: Vital for identifying potential data entry issues (e.g., duplicate categories like `'B'` and `' B'`) and quantifying class imbalance. Understanding the PAI score distribution is the first step in deciding on strategies like oversampling, class weighting, or using specialized loss functions in the training script.
*   **Image Normalization Statistics (Mean/Std)**: Neural networks generally train more effectively when input data is normalized (typically to zero mean and unit variance). This notebook calculates these statistics from the images at their original sizes, after pixel values are scaled to [0.0, 1.0]. If the training pipeline resizes images, the statistics of the resized images may differ slightly. The calculated `mean` and `std` are then used in the `transforms.Normalize` step of the main training pipeline.
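
As an illustration of how the categorical analysis can feed into training decisions, the sketch below cleans stray whitespace from labels and derives inverse-frequency class weights. The labels are made-up toy data, and this weighting formula is one common choice, not necessarily the one used in the training script.

```python
from collections import Counter

# Hypothetical PAI labels; note the stray whitespace in ' 2'
raw_labels = ['1', '1', '1', ' 2', '2', '3', '1', '4', '5', '1']

# Strip whitespace so that ' 2' and '2' count as the same class
labels = [label.strip() for label in raw_labels]
counts = Counter(labels)

# Inverse-frequency weights: n_samples / (n_classes * class_count)
n_samples = len(labels)
n_classes = len(counts)
class_weights = {
    label: n_samples / (n_classes * count)
    for label, count in sorted(counts.items())
}

for label, weight in class_weights.items():
    share = 100 * counts[label] / n_samples
    print(f"PAI {label}: n={counts[label]} ({share:.1f}%), weight={weight:.2f}")
```

Rare classes receive weights above 1 and frequent classes below 1, so a weighted loss would pay proportionally more attention to under-represented PAI scores.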

## Requirements

*   `pandas`
*   `torch` & `torchvision`
*   `Pillow` (PIL)
*   `tqdm` (for progress bars)

## Usage

1.  Modify the `CSV_FILES` and `IMAGE_ROOT_DIRS` variables in the "USER CONFIGURATION" cell to point to your data.
2.  Keep `BATCH_SIZE` at 1, because images are loaded at their original sizes and images of different sizes cannot be stacked into one batch.
3.  Run all cells of the notebook sequentially.
4.  Copy the final `Dataset Mean` and `Dataset Std` values and paste them into the `NORMALIZATION` dictionary in your `config.py` file for the training script.
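
A minimal sketch of step 4 follows. The key names `'mean'` and `'std'` and the numeric values are placeholders assumed for illustration; check the actual structure of `NORMALIZATION` in your `config.py`. The helper `normalize_pixel` is hypothetical and only mirrors, in plain Python, the per-channel operation `(value - mean) / std` that `transforms.Normalize` applies.

```python
# Hypothetical key names and values; replace with the notebook's printed output
NORMALIZATION = {
    'mean': [0.40, 0.40, 0.40],
    'std': [0.25, 0.25, 0.25],
}

def normalize_pixel(pixel, mean, std):
    """Apply (value - mean) / std to each channel of one RGB pixel."""
    return [(value - m) / s for value, m, s in zip(pixel, mean, std)]

# A pixel already scaled to [0.0, 1.0], as produced by ToTensor
pixel = [0.40, 0.65, 0.15]
normalized = normalize_pixel(pixel, NORMALIZATION['mean'], NORMALIZATION['std'])
print(normalized)
```

A channel value equal to the dataset mean maps to zero, and a value one standard deviation above or below the mean maps to roughly +1 or -1.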

In [None]:
# --- Standard Library and Core ML Imports ---
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm.notebook import tqdm

In [None]:
# ==============================================================================
#                              USER CONFIGURATION
# ==============================================================================

# 1. List all your CSV files here
CSV_FILES = [
    '/path/to/your/metadata.csv',
    # Add other CSV files if you have more, e.g.,
    # '/path/to/your/second_metadata.csv'
]

# 2. List all the root directories where your images might be stored
IMAGE_ROOT_DIRS = [
    '/path/to/your/images',
    # Add other root directories if you have them
]

# 3. Define image processing settings
BATCH_SIZE = 1  # Keep at 1: images are loaded at their original, possibly different, sizes
NUM_WORKERS = 4  # Number of parallel processes for data loading
# ==============================================================================

In [None]:
# --- Part 1: Load and Analyze Categorical Data from CSVs ---

print("--- Step 1: Loading and Analyzing CSV Metadata ---")

try:
    all_dfs = [pd.read_csv(f) for f in CSV_FILES]
    df = pd.concat(all_dfs, ignore_index=True)
    print(f"Successfully loaded a total of {len(df)} records from {len(CSV_FILES)} CSV file(s).\n")
except FileNotFoundError as e:
    print(f"ERROR: Could not find a CSV file. Please check the paths in CSV_FILES. Details: {e}")
    raise

def print_categorical_stats(dataframe, column_name):
    """Print the count and percentage of each unique value in a column."""
    print(f"--- Statistics for '{column_name}' ---")
    # value_counts() excludes missing values (NaN) by default
    counts = dataframe[column_name].value_counts().sort_index()
    percentages = counts / counts.sum() * 100
    stats_df = pd.DataFrame({'Count': counts, 'Percentage (%)': percentages.round(2)})
    print(stats_df)
    print("-" * 40)

columns_to_analyze = ['quadrant', 'tooth', 'root', 'PAI']
for col in columns_to_analyze:
    if col in df.columns:
        print_categorical_stats(df, col)
    else:
        print(f"Warning: Column '{col}' not found in the CSV file(s).")

In [None]:
# --- Part 2: Calculate Raw Image Mean and Standard Deviation ---

print("\n--- Step 2: Preparing for Raw Image Mean and Standard Deviation Calculation ---")

image_paths = []
missing_files = 0
for filename in df['filename']:
    found_path = None
    for root_dir in IMAGE_ROOT_DIRS:
        potential_path = os.path.join(root_dir, filename)
        if os.path.exists(potential_path):
            found_path = potential_path
            break
    if found_path:
        image_paths.append(found_path)
    else:
        missing_files += 1

print(f"Found {len(image_paths)} image files.")
if missing_files > 0:
    print(f"Warning: {missing_files} image files listed in the CSV were not found.")

class RawImageDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # Use a context manager so the file handle is closed promptly
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

stats_transform = transforms.Compose([
    transforms.ToTensor(),  # Scales pixels to [0.0, 1.0] and changes to [C, H, W]
])

# Create the Dataset and DataLoader
stats_dataset = RawImageDataset(image_paths, transform=stats_transform)
stats_loader = DataLoader(
    stats_dataset,
    batch_size=BATCH_SIZE, # Must be 1 to handle potentially different image sizes
    num_workers=NUM_WORKERS,
    shuffle=False
)

# Accumulate in float64 to limit rounding error over many pixels
psum = torch.zeros(3, dtype=torch.float64)
psum_sq = torch.zeros(3, dtype=torch.float64)
pixel_count = 0

print("\nIterating over original-sized images to compute stats...")
for inputs in tqdm(stats_loader):
    inputs = inputs.double()
    # Sum over the batch, height, and width dimensions, keeping channels
    psum += inputs.sum(dim=[0, 2, 3])
    psum_sq += (inputs**2).sum(dim=[0, 2, 3])

    # Count pixels per channel; image sizes may differ between batches
    pixel_count += inputs.shape[0] * inputs.shape[2] * inputs.shape[3]

if pixel_count == 0:
    raise RuntimeError("No images were processed. Check CSV_FILES and IMAGE_ROOT_DIRS.")

# Calculate the final mean and std: Var[x] = E[x^2] - (E[x])^2
total_mean = psum / pixel_count
total_var = (psum_sq / pixel_count) - (total_mean**2)
# Clamp at zero so rounding error cannot produce a negative variance
total_std = torch.sqrt(total_var.clamp(min=0))

print("\n--- Raw Image Statistics Calculation Complete ---")
print(f"Calculated over {len(stats_dataset)} original-sized images.")
print(f"Dataset Mean (R, G, B): {total_mean.tolist()}")
print(f"Dataset Std (R, G, B):  {total_std.tolist()}")
print("---------------------------------------------")