# Step 1: Understanding the MNIST Dataset

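This notebook explores the MNIST handwritten-digit dataset. It covers how the raw files are loaded, the shape and pixel range of the images, and the class balance of the training and test splits. It also shows what a typical digit of each class looks like. A summary of the findings is at the end.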

In [None]:
import numpy as np
import struct
import matplotlib.pyplot as plt
from collections import Counter

%matplotlib inline

## Load the Dataset

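MNIST is distributed in the IDX binary format. Each file starts with a big-endian header: a magic number followed by one 32-bit size per dimension. The raw unsigned-byte values follow the header. NumPy has no built-in IDX reader, so we define two small helpers: one for the 3-D image files and one for the 1-D label files.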

In [None]:
def read_idx_images(filename):
    """Read an IDX3 unsigned-byte image file (e.g. MNIST ``train-images.idx3-ubyte``).

    The file begins with a 16-byte big-endian header: a magic number
    (2051 for 3-D unsigned-byte data) followed by the image count, the row
    count and the column count. The remainder of the file is the raw pixels.

    Args:
        filename: Path to the IDX3 image file.

    Returns:
        np.ndarray of dtype ``uint8`` with shape ``(num_images, rows, cols)``.

    Raises:
        ValueError: If the header is truncated, the magic number is not 2051
            (e.g. a label file was passed by mistake), or the pixel payload
            size does not match the dimensions declared in the header.
    """
    with open(filename, 'rb') as f:
        header = f.read(16)
        if len(header) != 16:
            raise ValueError(
                f"{filename}: truncated IDX header ({len(header)} of 16 bytes)"
            )
        magic, num_images, rows, cols = struct.unpack('>IIII', header)
        if magic != 2051:
            raise ValueError(
                f"{filename}: bad magic number {magic}, expected 2051 (IDX3 uint8 images)"
            )
        # np.fromfile continues from the current file position, i.e. right after the header.
        pixels = np.fromfile(f, dtype=np.uint8)

    expected = num_images * rows * cols
    if pixels.size != expected:
        raise ValueError(
            f"{filename}: header declares {expected} pixel bytes, found {pixels.size}"
        )
    return pixels.reshape(num_images, rows, cols)

def read_idx_labels(filename):
    """Read an IDX1 unsigned-byte label file (e.g. MNIST ``train-labels.idx1-ubyte``).

    The file begins with an 8-byte big-endian header: a magic number
    (2049 for 1-D unsigned-byte data) followed by the label count. The
    remainder of the file is one byte per label.

    Args:
        filename: Path to the IDX1 label file.

    Returns:
        1-D np.ndarray of dtype ``uint8`` with ``num_labels`` entries.

    Raises:
        ValueError: If the header is truncated, the magic number is not 2049
            (e.g. an image file was passed by mistake), or the number of
            label bytes does not match the count declared in the header.
    """
    with open(filename, 'rb') as f:
        header = f.read(8)
        if len(header) != 8:
            raise ValueError(
                f"{filename}: truncated IDX header ({len(header)} of 8 bytes)"
            )
        magic, num_labels = struct.unpack('>II', header)
        if magic != 2049:
            raise ValueError(
                f"{filename}: bad magic number {magic}, expected 2049 (IDX1 uint8 labels)"
            )
        # np.fromfile continues from the current file position, i.e. right after the header.
        labels = np.fromfile(f, dtype=np.uint8)

    # The header count was previously ignored, so truncated files went unnoticed.
    if labels.size != num_labels:
        raise ValueError(
            f"{filename}: header declares {num_labels} labels, found {labels.size}"
        )
    return labels

In [None]:
# Location of the raw MNIST IDX files, relative to this notebook.
DATA_DIR = '../data'

train_images = read_idx_images(f'{DATA_DIR}/train-images.idx3-ubyte')
train_labels = read_idx_labels(f'{DATA_DIR}/train-labels.idx1-ubyte')
test_images = read_idx_images(f'{DATA_DIR}/t10k-images.idx3-ubyte')
test_labels = read_idx_labels(f'{DATA_DIR}/t10k-labels.idx1-ubyte')

# Images and labels live in separate files and are paired only by position,
# so a count mismatch would silently misalign every later analysis.
assert len(train_images) == len(train_labels), "train images/labels count mismatch"
assert len(test_images) == len(test_labels), "test images/labels count mismatch"

print(f"Training images: {train_images.shape}")
print(f"Training labels: {train_labels.shape}")
print(f"Test images: {test_images.shape}")
print(f"Test labels: {test_labels.shape}")

## Dataset Structure

In [None]:
# Unpack the per-image dimensions once; axis 0 is the image index.
n_rows, n_cols = train_images.shape[1:]

print(f"Image dimensions: {n_rows}x{n_cols} pixels")
print(f"Pixel value range: [{train_images.min()}, {train_images.max()}]")
print(f"Data type: {train_images.dtype}")
print(f"Total pixels per image: {n_rows * n_cols}")

## Visualize Sample Images

In [None]:
fig, axes = plt.subplots(2, 10, figsize=(15, 3))
fig.suptitle('MNIST Sample Images (20 random samples)', fontsize=14)

# Seeded generator so the same samples reappear on every Restart & Run All;
# change SAMPLE_SEED to see a different selection.
SAMPLE_SEED = 0
rng = np.random.default_rng(SAMPLE_SEED)
# Draw one index per panel, without replacement, so no image is shown twice.
sample_indices = rng.choice(len(train_images), size=axes.size, replace=False)

# axes.flat walks the 2x10 grid row by row.
for ax, idx in zip(axes.flat, sample_indices):
    ax.imshow(train_images[idx], cmap='gray')
    ax.set_title(f'{train_labels[idx]}', fontsize=10)
    ax.axis('off')

plt.tight_layout()
plt.show()

## Class Distribution

In [None]:
# Per-digit sample counts; both counters are reused by the bar chart below.
train_counter = Counter(train_labels)
test_counter = Counter(test_labels)


def print_class_distribution(title, counter, total):
    """Print each digit's sample count and its percentage of ``total``.

    Args:
        title: Heading printed before the per-digit lines.
        counter: Mapping of digit label -> number of samples.
        total: Number of samples the percentages are relative to.
    """
    print(title)
    for digit in sorted(counter.keys()):
        count = counter[digit]
        percentage = (count / total) * 100
        print(f"  Digit {digit}: {count:5d} samples ({percentage:.2f}%)")


print_class_distribution("Training set distribution:", train_counter, len(train_labels))
print_class_distribution("\nTest set distribution:", test_counter, len(test_labels))

In [None]:
# Grouped bar chart: one training bar and one test bar per digit class.
digit_classes = range(10)
train_counts = [train_counter[digit] for digit in digit_classes]
test_counts = [test_counter[digit] for digit in digit_classes]

positions = np.arange(len(digit_classes))
width = 0.35
half = width / 2  # offset that places the two bars side by side around each tick

fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(positions - half, train_counts, width, label='Training', alpha=0.8)
ax.bar(positions + half, test_counts, width, label='Test', alpha=0.8)

ax.set_title('MNIST Class Distribution', fontsize=14)
ax.set_xlabel('Digit Class', fontsize=12)
ax.set_ylabel('Number of Samples', fontsize=12)
ax.set_xticks(positions)
ax.set_xticklabels(digit_classes)
ax.grid(axis='y', alpha=0.3)
ax.legend()

plt.tight_layout()
plt.show()

## Pixel Statistics

In [None]:
print(f"Training set mean pixel value: {train_images.mean():.2f}")
print(f"Training set std pixel value: {train_images.std():.2f}")
print(f"Test set mean pixel value: {test_images.mean():.2f}")
print(f"Test set std pixel value: {test_images.std():.2f}")

# Pixel count taken from the array itself instead of a hard-coded 784.
pixels_per_image = train_images.shape[1] * train_images.shape[2]

# Mean number of non-zero pixels per image (count over rows and cols, then average over images).
non_zero_pixels_train = (train_images > 0).sum(axis=(1, 2)).mean()
non_zero_pixels_test = (test_images > 0).sum(axis=(1, 2)).mean()
print(f"\nAvg non-zero pixels per image (train): {non_zero_pixels_train:.2f}")
print(f"Avg non-zero pixels per image (test): {non_zero_pixels_test:.2f}")
print(f"Image sparsity: {(1 - non_zero_pixels_train / pixels_per_image) * 100:.1f}% pixels are zero")

## Average Digit Images

What does the "average" digit look like for each class?

In [None]:
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
fig.suptitle('Average Digit Images per Class', fontsize=14)

# axes.flat visits the 2x5 grid row by row, so panel k shows digit k.
for digit, ax in enumerate(axes.flat):
    class_mask = train_labels == digit
    mean_image = train_images[class_mask].mean(axis=0)  # pixel-wise mean over all images of this class
    ax.imshow(mean_image, cmap='gray')
    ax.set_title(f'Digit {digit}', fontsize=11)
    ax.axis('off')

plt.tight_layout()
plt.show()

## Inspect Individual Samples

In [None]:
idx = 0  # training-set index to inspect; change to look at another sample
sample_image = train_images[idx]
sample_label = train_labels[idx]

fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(sample_image, cmap='gray')
ax.set_title(f'Label: {sample_label}', fontsize=14)
ax.axis('off')
plt.show()

print(f"Image shape: {sample_image.shape}")
print(f"Label: {sample_label}")
print(f"Min pixel value: {sample_image.min()}")
print(f"Max pixel value: {sample_image.max()}")

## Key Takeaways

1. **Dataset size:** 60,000 training + 10,000 test images
2. **Image format:** 28×28 grayscale (784 pixels)
3. **Balanced classes:** Each digit appears ~10% of the time
4. **Sparse images:** Only ~19% of pixels are non-zero
5. **Pixel range:** 0 (black) to 255 (white)