# Document Image Classification - Exploratory Data Analysis

This notebook explores the document image classification dataset, which consists of:
- Images (TIF format)
- OCR text data (TXT files)

The goal is to understand the characteristics of the data before building our models.

In [None]:
# Import necessary libraries
import os, glob, json, re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import cv2
from nltk.corpus import stopwords

# Configure plots
plt.style.use('seaborn-v0_8')
%matplotlib inline
sns.set(font_scale=1.2)
plt.rcParams['figure.figsize'] = (12, 8)


## 1. Data Loading

First, let's locate our data directories and check the structure.

In [None]:
# Define paths to data directories
base_dir = os.path.join('..', 'data')  # Relative to this notebook
image_dir = os.path.join(base_dir, 'images')
ocr_dir = os.path.join(base_dir, 'ocr')

# Check that directories exist
print(f"Image directory exists: {os.path.exists(image_dir)}")
print(f"OCR directory exists: {os.path.exists(ocr_dir)}")

# Fail fast with the resolved path: os.listdir below would otherwise raise a
# bare FileNotFoundError that hides which directory is missing.
for data_root in (image_dir, ocr_dir):
    if not os.path.isdir(data_root):
        raise FileNotFoundError(f"Data directory not found: {os.path.abspath(data_root)}")


def list_class_dirs(root):
    """Return the sorted names of the immediate subdirectories of ``root``."""
    return sorted(d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d)))


# List classes (subdirectories)
image_classes = list_class_dirs(image_dir)
ocr_classes = list_class_dirs(ocr_dir)

print(f"\nImage classes: {image_classes}")
print(f"OCR classes: {ocr_classes}")

# A class folder that exists under only one root cannot be paired image-to-OCR.
unmatched_classes = sorted(set(image_classes) ^ set(ocr_classes))
if unmatched_classes:
    print(f"Classes present in only one modality: {unmatched_classes}")


## 2. Data Distribution

Let's check how many files we have for each class.

In [None]:
# Count files per class. Iterate over the union of image and OCR class folders so
# a class that exists on only one side shows up as a mismatch instead of being
# silently dropped (glob on a missing folder simply returns no matches -> 0).
all_classes = sorted(set(image_classes) | set(ocr_classes))
image_counts = {cls: len(glob.glob(os.path.join(image_dir, cls, '*.TIF'))) for cls in all_classes}
ocr_counts = {cls: len(glob.glob(os.path.join(ocr_dir, cls, '*.TIF.txt'))) for cls in all_classes}

# Combine and display
counts_df = pd.DataFrame({
    'Class': all_classes,
    'Image Count': [image_counts[cls] for cls in all_classes],
    'OCR Count': [ocr_counts[cls] for cls in all_classes],
})
display(counts_df)

# Check for mismatches (note: equal counts do not by themselves prove that each
# image has its own matching OCR file).
mismatch = counts_df[counts_df['Image Count'] != counts_df['OCR Count']]
if mismatch.empty:
    print("No mismatches found - all classes have equal numbers of image and OCR files.")
else:
    print("Mismatches detected:")
    display(mismatch)


## 3. Sample Loading

Let's load a few samples from each class to examine them.

In [None]:
def load_samples(num_per_class=5, classes=None, img_root=None, txt_root=None):
    """Collect up to ``num_per_class`` image/OCR pairs from each class folder.

    Parameters
    ----------
    num_per_class : int
        Maximum number of ``*.TIF`` images taken from each class directory.
    classes : list of str, optional
        Class subdirectory names to sample; defaults to ``image_classes``.
    img_root, txt_root : str, optional
        Root directories for images and OCR text; default to ``image_dir``
        and ``ocr_dir``.

    Returns
    -------
    pd.DataFrame
        Columns ``class``, ``image_path`` and ``ocr_text``. ``ocr_text`` is the
        content of ``<txt_root>/<class>/<image file name>.txt``, or ``None`` when
        that file does not exist. The columns are present even if no images
        were found.
    """
    classes = image_classes if classes is None else classes
    img_root = image_dir if img_root is None else img_root
    txt_root = ocr_dir if txt_root is None else txt_root

    samples = []
    for cls in classes:
        # glob returns files in filesystem order, which differs between machines;
        # sort so the same files are sampled on every run.
        image_files = sorted(glob.glob(os.path.join(img_root, cls, '*.TIF')))[:num_per_class]
        for img_path in image_files:
            ocr_path = os.path.join(txt_root, cls, os.path.basename(img_path) + '.txt')
            ocr_text = None
            if os.path.exists(ocr_path):
                with open(ocr_path, 'r', encoding='utf-8', errors='ignore') as f:
                    ocr_text = f.read()
            samples.append({'class': cls, 'image_path': img_path, 'ocr_text': ocr_text})
    # Explicit columns keep downstream column access working when nothing matched.
    return pd.DataFrame(samples, columns=['class', 'image_path', 'ocr_text'])

samples_df = load_samples()
loaded_classes = samples_df['class'].unique()
print(f"Loaded {len(samples_df)} samples across {len(loaded_classes)} classes")
samples_df.head()


## 4. Image Analysis

Let's analyze the characteristics of our image data.

In [None]:
def analyze_image(image_path):
    """Return ``(height, width, channels)`` for the image at ``image_path``.

    The file is read with ``cv2.IMREAD_UNCHANGED``, so a single-plane image
    reports 1 channel. ``cv2.imread`` signals a read failure by returning
    ``None`` instead of raising; in that case a warning is emitted and
    ``(nan, nan, nan)`` is returned, so one unreadable file does not abort the
    analysis (NaNs are skipped by ``describe`` and by the scatter plot).
    """
    img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        import warnings
        warnings.warn(f"cv2 could not read image: {image_path}")
        return np.nan, np.nan, np.nan
    h, w = img.shape[:2]
    channels = 1 if img.ndim == 2 else img.shape[2]
    return h, w, channels

# Measure every sampled image. Existing dimension columns are dropped before
# joining, so re-running this cell replaces them instead of appending duplicates
# (which pd.concat would do, breaking the selections below).
dim_cols = ['height', 'width', 'channels']
dims = pd.DataFrame(
    [analyze_image(p) for p in samples_df['image_path']],
    columns=dim_cols,
    index=samples_df.index,
)
samples_df = samples_df.drop(columns=dim_cols, errors='ignore').join(dims)

display(samples_df[dim_cols].describe())

fig, ax = plt.subplots()
ax.scatter(samples_df['width'], samples_df['height'], alpha=0.6)
ax.set(xlabel='Width (px)', ylabel='Height (px)', title='Image Dimensions Scatter')
plt.show()


## 5. Text Analysis

Now let's examine the OCR text data.

In [None]:
def _char_len(text):
    """Number of characters in an OCR string; 0 when the text is missing or empty."""
    return len(text) if text else 0


def _token_len(text):
    """Number of whitespace-separated tokens in an OCR string; 0 when missing or empty."""
    return 0 if not text else len(text.split())


ocr_series = samples_df['ocr_text']
samples_df['text_length'] = ocr_series.apply(_char_len)
samples_df['word_count'] = ocr_series.apply(_token_len)

length_summary = samples_df[['text_length', 'word_count']].describe()
print(length_summary)

ax = sns.boxplot(data=samples_df, x='word_count')
ax.set_title('Distribution of Word Counts')
plt.show()


Let's look at the most common words in each class.

In [None]:
def get_common_words(texts, n=20, stop_words=None):
    """Return the ``n`` most frequent content words in ``texts``.

    Parameters
    ----------
    texts : str
        Raw text, e.g. the concatenated OCR output of one class.
    n : int
        Number of ``(word, count)`` pairs to return.
    stop_words : iterable of str, optional
        Words to exclude. Defaults to NLTK's English stop-word list, which
        needs the NLTK ``stopwords`` corpus to be available locally.

    Returns
    -------
    list of (str, int)
        Lower-cased alphabetic words longer than one character with their
        counts, most frequent first; ties keep first-occurrence order
        (``Counter.most_common``).
    """
    if stop_words is None:
        stop_words = stopwords.words('english')
    stop_set = set(stop_words)
    # Replace (rather than delete) non-letters with a space: deleting glues
    # neighbouring tokens together, e.g. OCR "total,amount" -> "totalamount".
    words = re.sub(r'[^a-zA-Z\s]', ' ', texts.lower()).split()
    filtered = [w for w in words if w not in stop_set and len(w) > 1]
    return Counter(filtered).most_common(n)

# Report the top words per class, in the order classes first appear in samples_df.
for class_name in samples_df['class'].unique():
    in_class = samples_df['class'] == class_name
    combined_text = ' '.join(samples_df.loc[in_class, 'ocr_text'].dropna())
    print(f"\nMost common words for class {class_name}:")
    for term, freq in get_common_words(combined_text):
        print(f"{term}: {freq}")


## 6. Further Analysis and Next Steps

**TODO:** Record observations here after running the exploratory steps above (e.g. class balance, typical image dimensions, OCR text length and quality).