# Tiny-Cats-Model — Data Exploration

This notebook guides you through:
1. Exploring the dataset (class distribution)
2. Viewing sample images per class
3. Visualising data augmentation transforms
4. Inspecting model architecture and parameter counts
5. Training a quick sanity-check run on synthetic data
6. Plotting training curves

In [None]:
# Standard imports
import os
import sys
import random
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import seaborn as sns
import torch
import torchvision.transforms as T
from PIL import Image

# Make the current working directory importable so the project modules imported
# in later cells (`dataset`, `model`, `train`) can be found.
# NOTE(review): assumes the notebook is launched from the project root — confirm.
# The membership check keeps the cell idempotent: re-running it no longer
# prepends a duplicate entry to sys.path each time.
PROJECT_ROOT = str(Path('.').resolve())
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

sns.set_theme(style='whitegrid')

# Seed every RNG used in this notebook so random samples are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print('PyTorch version:', torch.__version__)
print('CUDA available:', torch.cuda.is_available())

## 1. Dataset Overview

In [None]:
from dataset import cats_dataloader, get_transforms

DATA_DIR = 'data/cats'  # adjust if needed

# Lower-case file suffixes that are treated as images when counting.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def count_images_per_class(data_dir, extensions=IMAGE_EXTENSIONS):
    """Count image files in each immediate sub-directory of ``data_dir``.

    Args:
        data_dir: Dataset root; every direct sub-directory is treated as a class.
        extensions: Lower-case suffixes that qualify a file as an image. The
            comparison is case-insensitive, so ``photo.JPG`` is counted too.

    Returns:
        dict mapping class-directory name -> number of image files, with keys
        in sorted directory order. Empty when ``data_dir`` is not a directory.
    """
    root = Path(data_dir)
    if not root.is_dir():
        return {}
    counts = {}
    for cls in sorted(root.iterdir()):
        if cls.is_dir():
            counts[cls.name] = sum(
                1 for f in cls.iterdir()
                if f.is_file() and f.suffix.lower() in extensions
            )
    return counts


# Count images per class
data_path = Path(DATA_DIR)
class_counts = {}
if data_path.exists():
    class_counts = count_images_per_class(data_path)

    print('Classes found:', list(class_counts.keys()))
    for name, count in class_counts.items():
        print(f'  {name}: {count} images')
else:
    print(f'Data directory {DATA_DIR} not found. Run data/download.sh first.')

In [None]:
# Plot class distribution
# Bar chart of images per class; skipped when no classes were counted.
if class_counts:
    class_names = list(class_counts)
    image_totals = list(class_counts.values())

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.bar(class_names, image_totals, color='steelblue', edgecolor='white')
    ax.set(title='Class Distribution', xlabel='Class', ylabel='Number of images')
    fig.tight_layout()
    plt.show()

## 2. Sample Images

In [None]:
def show_samples(data_dir, n_per_class=4, extensions=('.jpg', '.jpeg', '.png')):
    """Plot a grid of randomly chosen images, one row per class.

    Args:
        data_dir: Directory whose immediate sub-directories are the classes.
        n_per_class: Maximum number of images drawn per class (grid columns).
        extensions: Lower-case file suffixes treated as images; matching is
            case-insensitive.

    Returns:
        None. The figure is displayed via ``plt.show()``. If ``data_dir`` has
        no sub-directories a message is printed and nothing is plotted.

    Raises:
        ValueError: If ``n_per_class`` is smaller than 1.
    """
    if n_per_class < 1:
        raise ValueError(f'n_per_class must be >= 1, got {n_per_class}')

    path = Path(data_dir)
    classes = sorted(d for d in path.iterdir() if d.is_dir())
    if not classes:
        # plt.subplots cannot build a grid with zero rows.
        print(f'No class sub-directories found in {data_dir}; nothing to show.')
        return

    # squeeze=False always yields a 2-D axes array, so axes[row][col] is valid
    # even with a single class or a single column.
    fig, axes = plt.subplots(
        len(classes), n_per_class,
        figsize=(n_per_class * 3, len(classes) * 3),
        squeeze=False,
    )
    # Hide every frame up front so cells left empty (class with fewer than
    # n_per_class images) do not show blank axes.
    for ax in axes.ravel():
        ax.axis('off')

    for row, cls in enumerate(classes):
        # Sorted so that, for a fixed random seed, the drawn sample does not
        # depend on filesystem listing order.
        images = sorted(
            f for f in cls.iterdir()
            if f.is_file() and f.suffix.lower() in extensions
        )
        samples = random.sample(images, min(n_per_class, len(images)))
        # Label the row on its first cell, even if the class has no images.
        axes[row][0].set_title(cls.name, fontsize=12, fontweight='bold')
        for col, img_path in enumerate(samples):
            # Context manager releases the file handle once pixels are decoded.
            with Image.open(img_path) as img:
                axes[row][col].imshow(img.convert('RGB'))

    fig.suptitle('Sample Images per Class', fontsize=14)
    fig.tight_layout()
    plt.show()

# Render the sample grid only when the dataset directory is present;
# otherwise this cell is a no-op.
if data_path.exists():
    show_samples(DATA_DIR)

## 3. Augmentation Visualisation

In [None]:
# Build both pipelines; only the training one is visualised below.
train_tf = get_transforms('train')
val_tf   = get_transforms('val')

N_AUGMENTED = 5  # augmented variants shown next to the original

if data_path.exists():
    # First image of any supported type (case-insensitive). The `None` default
    # avoids a StopIteration when the dataset holds no matching file.
    sample_img_path = next(
        (p for p in data_path.rglob('*')
         if p.is_file() and p.suffix.lower() in ('.jpg', '.jpeg', '.png')),
        None,
    )

    if sample_img_path is None:
        print(f'No .jpg/.jpeg/.png images found under {data_path}; '
              'skipping augmentation preview.')
    else:
        # Decode inside a context manager so the file handle is released.
        with Image.open(sample_img_path) as img:
            orig = img.convert('RGB')

        n_panels = N_AUGMENTED + 1
        fig, axes = plt.subplots(1, n_panels, figsize=(3 * n_panels, 3))
        axes[0].imshow(orig.resize((224, 224)))
        axes[0].set_title('Original')
        axes[0].axis('off')
        for i in range(1, n_panels):
            # assumes train_tf returns a (C, H, W) float tensor — TODO confirm.
            aug = train_tf(orig).permute(1, 2, 0).numpy()
            # Clip to the [0, 1] range imshow expects for float RGB data.
            # NOTE(review): if train_tf normalises with mean/std, colours will
            # look distorted here — confirm and de-normalise if needed.
            aug = np.clip(aug, 0, 1)
            axes[i].imshow(aug)
            axes[i].set_title(f'Augmented {i}')
            axes[i].axis('off')
        plt.suptitle('Train Augmentations')
        plt.tight_layout()
        plt.show()

## 4. Model Inspection

In [None]:
from model import cats_model, count_parameters, SUPPORTED_BACKBONES

print('Supported backbones:', SUPPORTED_BACKBONES)

# Instantiate each backbone with the same settings and report its size as
# computed by count_parameters.
model_kwargs = dict(num_classes=2, pretrained=False)
for backbone in SUPPORTED_BACKBONES:
    m = cats_model(backbone=backbone, **model_kwargs)
    params = count_parameters(m)
    print(f'  {backbone}: {params:,} parameters')

## 5. Quick Training Sanity Check

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from model import cats_model
from train import get_optimizer, train_one_epoch, validate

# Synthetic data sanity check: random inputs and labels, just enough to push a
# few batches through the training and validation code paths on CPU.
N_EPOCHS = 5

x = torch.randn(16, 3, 224, 224)
y = torch.randint(0, 2, (16,))
ds = TensorDataset(x, y)
loader = DataLoader(ds, batch_size=4)

model = cats_model(num_classes=2, backbone='resnet18', pretrained=False)
opt   = get_optimizer(model, 'adamw', lr=1e-3)

train_losses, val_accs = [], []
for epoch in range(N_EPOCHS):
    loss = train_one_epoch(model, loader, opt, device='cpu')
    # The same loader is reused for validation, so the reported accuracy is
    # measured on the training batches, not on held-out data.
    metrics = validate(model, loader, device='cpu')
    train_losses.append(loss)
    val_accs.append(metrics['accuracy'])
    print(f'Epoch {epoch+1}/{N_EPOCHS} | loss={loss:.4f} | acc={metrics["accuracy"]:.3f}')

# Check something before declaring success: every epoch must have produced a
# finite loss (catches NaN/inf from a broken forward/backward pass).
# assumes train_one_epoch returns a scalar convertible with float() — TODO confirm.
assert len(train_losses) == N_EPOCHS, 'training loop did not run every epoch'
assert all(np.isfinite(float(l)) for l in train_losses), \
    f'non-finite training loss encountered: {train_losses}'
print('Sanity check passed!')

## 6. Training Curves

In [None]:
# Two side-by-side panels: per-epoch training loss (left) and accuracy (right),
# both read from the lists filled by the sanity-check loop.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

panels = [
    (ax1, train_losses, 'Training Loss', 'Loss', 'steelblue'),
    (ax2, val_accs, 'Validation Accuracy', 'Accuracy', 'darkorange'),
]
for axis, values, panel_title, y_label, colour in panels:
    epochs = range(1, len(values) + 1)  # 1-based epoch numbers on the x-axis
    axis.plot(epochs, values, marker='o', color=colour)
    axis.set_title(panel_title)
    axis.set_xlabel('Epoch')
    axis.set_ylabel(y_label)

plt.tight_layout()
plt.show()