# Q1 — Image Data Exploration

This notebook explores an image dataset: dataset shape, class balance, class prototypes (mean/median images), 
2D visualisation with PCA, and intra-class similarity via MSE histograms.

> Expected data files in `datasets/`:
> - `image_data.npz` with arrays: `train_X`, `train_Y`, `test_X`, `test_Y`.

In [None]:
# Imports
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import math
from sklearn.decomposition import PCA
from pathlib import Path

# Set a larger default figure size
plt.rcParams['figure.figsize'] = (8, 5)

# Load dataset
data_path = Path('datasets/image_data.npz')
if not data_path.exists():
    raise FileNotFoundError("Expected datasets/image_data.npz. Please place the file in datasets/.")

# np.load on an .npz archive returns an NpzFile that keeps the underlying file
# open; the context manager closes it as soon as the arrays have been read.
with np.load(data_path) as data:
    train_X = data['train_X']
    train_Y = data['train_Y'].ravel()  # flatten labels to 1-D
    test_X  = data['test_X']
    test_Y  = data['test_Y'].ravel()   # flatten labels to 1-D

# Notebook display: shapes of the four arrays
train_X.shape, train_Y.shape, test_X.shape, test_Y.shape

## 1a) Dataset dimensions & class distribution

In [None]:
# Dataset dimensions: rows of train_X/test_X are items, columns are features
train_items, n_features = train_X.shape
test_items = len(test_X)
print(f"Training items: {train_items}, Features: {n_features}")
print(f"Testing  items: {test_items}, Features: {test_X.shape[1]}")

# How many training items fall into each class
classes, counts = np.unique(train_Y, return_counts=True)

# Bar chart of the per-class population
fig, ax = plt.subplots()
ax.bar(classes, counts)
ax.set(
    title='Population size of classes in training data',
    xlabel='Class',
    ylabel='Count',
)
plt.show()

## 1b) Class prototypes (Mean/Median images)

In [None]:
# Pick up to 4 classes at random (fixed seed keeps the selection reproducible)
rng = np.random.default_rng(42)
available_classes = np.unique(train_Y)
# Sampling without replacement raises if fewer than 4 classes exist, so cap the draw
selected_classes = rng.choice(
    available_classes, min(4, len(available_classes)), replace=False
)

# Images are assumed to be square: recover the side length from the flat
# feature count and fail loudly if the count is not a perfect square,
# instead of letting a truncated side length break the reshape below.
image_side = int(round(math.sqrt(n_features)))
if image_side * image_side != n_features:
    raise ValueError(
        f"n_features={n_features} is not a perfect square; "
        "cannot reshape the flat feature vectors into square images."
    )

# Per-class prototypes: pixel-wise mean and median over all images of the class
mean_images, median_images = [], []
for cls in selected_classes:
    idx = np.where(train_Y == cls)[0]
    imgs = train_X[idx].reshape(-1, image_side, image_side)
    mean_images.append(imgs.mean(axis=0))
    median_images.append(np.median(imgs, axis=0))

# One figure per class: mean prototype on the left, median on the right
for cls, mean_img, median_img in zip(selected_classes, mean_images, median_images):
    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    axes[0].imshow(mean_img, cmap='gray')
    axes[0].set_title(f'Class {cls} — Mean')
    axes[0].axis('off')
    axes[1].imshow(median_img, cmap='gray')
    axes[1].set_title(f'Class {cls} — Median')
    axes[1].axis('off')
    plt.tight_layout()
    plt.show()

## 1c) PCA scatter for two classes

In [None]:
# Choose two classes (override here if you want specific IDs)
class_1, class_2 = 15, 42
idx1 = np.where(train_Y == class_1)[0]
idx2 = np.where(train_Y == class_2)[0]

# Fail early with a readable message if a chosen ID has no training samples;
# otherwise PCA would be fitted on an empty or single-class matrix.
missing = [c for c, ix in ((class_1, idx1), (class_2, idx2)) if len(ix) == 0]
if missing:
    raise ValueError(
        f"No training samples found for class(es) {missing}; "
        "choose class IDs that occur in train_Y."
    )

# Stack the samples of both classes and build a matching label vector
X = np.vstack([train_X[idx1], train_X[idx2]])
y = np.hstack([np.full(len(idx1), class_1), np.full(len(idx2), class_2)])

# Project onto the first two principal components
pca = PCA(n_components=2, random_state=0)
X2 = pca.fit_transform(X)

# Scatter plot, one colour per class
fig, ax = plt.subplots()
for cls_id in (class_1, class_2):
    members = y == cls_id
    ax.scatter(X2[members, 0], X2[members, 1], label=f"Class {cls_id}", s=12)
ax.set_xlabel('PC1')
ax.set_ylabel('PC2')
ax.set_title('PCA (2D) of two classes')
ax.legend()
plt.show()

## 1d) Intra-class similarity with MSE

In [None]:
def pairwise_mse(A):
    """Return the mean squared error between every unordered pair of samples.

    Parameters
    ----------
    A : array-like, shape (n, ...)
        One sample per entry along the first axis. Values are converted to
        float64 before subtracting so that unsigned integer data (e.g. uint8
        pixels) cannot wrap around in the difference or the square.

    Returns
    -------
    numpy.ndarray, shape (n * (n - 1) // 2,)
        float64 MSE values for the pairs (i, j) with i < j, ordered as
        (0, 1), (0, 2), ..., (0, n-1), (1, 2), ... — i.e. the row-major
        upper triangle without the diagonal. Empty when n < 2.
    """
    samples = np.asarray(A, dtype=np.float64)
    n = samples.shape[0]
    n_pairs = n * (n - 1) // 2

    # Write straight into the condensed result; no n x n matrix is needed.
    tri = np.empty(n_pairs, dtype=np.float64)
    if n_pairs == 0:
        return tri

    # One flat feature vector per sample so the mean runs over all elements.
    flat = samples.reshape(n, samples.size // n)

    start = 0
    for i in range(n - 1):
        # Compare sample i against all later samples in one vectorised step.
        diffs = flat[i + 1:] - flat[i]
        stop = start + (n - 1 - i)
        tri[start:stop] = np.mean(diffs * diffs, axis=1)
        start = stop
    return tri

# Subset 1: every training image that belongs to one chosen class
subset1_class = 6
idx = np.flatnonzero(train_Y == subset1_class)
subset1 = train_X[idx]

# Subset 2: an equally sized sample drawn without replacement from the
# whole training set (fixed seed for reproducibility)
rng = np.random.default_rng(0)
rand_idx = rng.choice(train_items, size=len(idx), replace=False)
subset2 = train_X[rand_idx]

# Pairwise MSE values within each subset
mse1 = pairwise_mse(subset1)
mse2 = pairwise_mse(subset2)

# Side-by-side histograms of the two MSE distributions
fig, axes = plt.subplots(1, 2, figsize=(10,4))
panels = (
    (mse1, f'MSE — Subset 1 (class {subset1_class})'),
    (mse2, 'MSE — Subset 2 (random classes)'),
)
for panel_ax, (values, panel_title) in zip(axes, panels):
    panel_ax.hist(values, bins=20)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('MSE')
# Only the left panel carries the y-axis label
axes[0].set_ylabel('Frequency')
plt.tight_layout()
plt.show()