In [6]:
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
    Conv2D, BatchNormalization, MaxPool2D, SpatialDropout2D,
    Flatten, Dense, Dropout
)
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import os
import numpy as np
from PIL import Image
import kagglehub
from kagglehub import KaggleDatasetAdapter
import kaggle
from kaggle.api.kaggle_api_extended import KaggleApi

# Root folder that holds every dataset as well as the saved models
BASE_DATA_DIR = 'data'

# Logical name -> directory beneath BASE_DATA_DIR (insertion order is kept)
DATASETS: dict[str, str] = {
    name: os.path.join(BASE_DATA_DIR, subfolder)
    for name, subfolder in (
        ('mnist', 'mnist'),
        ('emnist', 'emnist'),
        ('handwritten_digits', 'handwritten-digits-not-mnist'),
        ('usps', 'usps'),
        ('models', 'models'),
    )
}

# Make sure every target directory exists before anything reads or writes it
for directory in DATASETS.values():
    os.makedirs(directory, exist_ok=True)





def print_dataset_summary(dataset_name, count):
    """Print one aligned summary row for a dataset.

    The name is left-aligned in a 30-character column and the count is
    right-aligned in 8 characters with thousands separators.
    """
    name_column = format(dataset_name, "<30")
    count_column = format(count, ">8,d")
    print(name_column + " : " + count_column + " samples")
    

def print_class_distribution(labels, dataset_name):
    """Print how many samples each distinct label has, then the total count."""
    classes, frequencies = np.unique(labels, return_counts=True)

    header = f"\n{dataset_name} class distribution:"
    print(header)

    rows = (
        f"Digit {cls}: {freq} samples"
        for cls, freq in zip(classes, frequencies)
    )
    for row in rows:
        print(row)

    print(f"Total samples: {len(labels)}")


def print_total_samples(labels, dataset_name):
    """Print a single line giving the number of samples in a dataset."""
    sample_count = len(labels)
    print(f"{dataset_name}: {sample_count} samples")





def load_handwritten_digits():
    """Download and load the Kaggle "Handwritten Digits Dataset (not in MNIST)".

    The archive is downloaded on every call (``force=True``) and unzipped
    into ``DATASETS['handwritten_digits']``. Each ``.png`` found under
    ``dataset/<digit>/<digit>/`` is converted to grayscale, resized to
    28x28, scaled to [0, 1] and inverted.

    Missing digit folders and unreadable files are reported on stdout and
    skipped rather than aborting the load.

    Returns:
        tuple: ``(images, labels)`` where ``images`` is a float32 array of
        shape ``(N, 28, 28)`` and ``labels`` is an integer array of shape
        ``(N,)`` holding the digit of each image.
    """
    api = KaggleApi()
    api.authenticate()

    api.dataset_download_files(
        'jcprogjava/handwritten-digits-dataset-not-in-mnist',
        path=DATASETS['handwritten_digits'],
        unzip=True,
        force=True
    )

    images = []
    labels = []
    dataset_path = os.path.join(DATASETS['handwritten_digits'], 'dataset')

    print(f"\nProcessing Handwritten Digits Dataset (not in MNIST) from: {dataset_path}")

    for label in range(10):
        # The archive nests each digit as <digit>/<digit>/
        folder_path = os.path.join(dataset_path, str(label), str(label))
        if not os.path.exists(folder_path):
            print(f"Warning: Missing directory for label {label} - {folder_path}")
            continue

        # List the folder once, and sort it: os.listdir order is arbitrary,
        # so sorting keeps the sample order (and any seeded shuffle applied
        # to the result) reproducible across machines.
        png_names = sorted(
            name for name in os.listdir(folder_path) if name.endswith('.png')
        )
        print(f"Processing {len(png_names)} samples for digit {label}")

        for filename in png_names:
            img_path = os.path.join(folder_path, filename)
            try:
                # The context manager releases the file handle promptly.
                # NOTE(review): convert('L') discards any alpha channel - if
                # these PNGs encode the digit in transparency, confirm the
                # resulting arrays are not blank.
                with Image.open(img_path) as raw:
                    img = raw.convert('L').resize((28, 28))
                img_array = np.asarray(img, dtype='float32') / 255.0
                img_array = 1.0 - img_array  # Invert to white-on-black
            except Exception as e:
                # Best effort: report the bad file and keep going
                print(f"Error processing {img_path}: {str(e)}")
                continue
            images.append(img_array)
            labels.append(label)

    # Explicit dtype/shape so an empty result is still (0, 28, 28) float32
    images = np.array(images, dtype='float32').reshape(-1, 28, 28)
    labels = np.array(labels, dtype=int)

    print("\nHandwritten Digits Dataset (not in MNIST) Summary:")
    print("-" * 40)
    print_class_distribution(labels, "Handwritten Digits Dataset (not in MNIST)")
    return images, labels


def _csv_first_row_is_header(csv_path):
    """Return True when the first line of ``csv_path`` has a non-numeric field."""
    with open(csv_path, 'r', encoding='utf-8') as fh:
        first_line = fh.readline().strip()
    if not first_line:
        return False
    for field in first_line.split(','):
        try:
            float(field)
        except ValueError:
            return True
    return False


def load_emnist_data():
    """Load and process the EMNIST digits training CSV from the local directory.

    Reads ``emnist-digits-train.csv`` from ``DATASETS['emnist']``. The first
    column is taken as the label and the remaining columns as the pixels of
    one 28x28 image per row. Images are transposed, flipped along the row
    axis and scaled to [0, 1].

    Returns:
        tuple: ``(images, labels)`` with ``images`` a float32 array of shape
        ``(N, 28, 28)`` and ``labels`` a 1-D array, or ``(None, None)`` when
        the file is missing or cannot be processed (the problem is printed).
    """
    # Construct file paths using the DATASETS dictionary
    train_file = os.path.join(DATASETS['emnist'], 'emnist-digits-train.csv')

    try:
        # First check if file exists
        if not os.path.exists(train_file):
            print(f"Error: EMNIST data file not found at {train_file}")
            print("\nAvailable files in EMNIST directory:")
            for file in os.listdir(DATASETS['emnist']):
                print(f"- {file}")
            raise FileNotFoundError(f"EMNIST data file not found at {train_file}")

        print(f"Loading EMNIST data from: {train_file}")

        # pandas is imported lazily: it is only needed by this loader
        import pandas as pd

        # read_csv's default header inference treats the first line as column
        # names, which silently drops the first sample of a headerless file.
        # Only consume a header row when the first line really looks like one.
        header_row = 0 if _csv_first_row_is_header(train_file) else None
        data = pd.read_csv(train_file, header=header_row)

        if data.empty:
            raise ValueError("Loaded CSV file is empty")

        print(f"Loaded data shape: {data.shape}")

        # Extract labels and pixels
        labels = data.iloc[:, 0].values
        pixels = data.iloc[:, 1:].values

        print(f"Labels shape: {labels.shape}")
        print(f"Pixels shape: {pixels.shape}")

        # Reshape and reorient images
        # NOTE(review): transpose followed by a vertical flip - render a few
        # samples to confirm the digits come out upright.
        images = pixels.reshape(-1, 28, 28)
        images = images.transpose(0, 2, 1)  # Swap rows and columns
        images = np.flip(images, axis=1)  # Vertical flip

        # Normalize pixel values to [0, 1]
        images = images.astype('float32') / 255.0

        print_total_samples(labels, "EMNIST Dataset")
        print(f"Final images shape: {images.shape}")

        # Verify no NaN values
        if np.isnan(images).any():
            raise ValueError("NaN values found in processed images")

        return images, labels

    except FileNotFoundError as e:
        print(str(e))
        return None, None
    except Exception as e:
        # Best effort: report the full failure and signal it with (None, None)
        print(f"Error loading EMNIST data: {str(e)}")
        print(f"Error type: {type(e)}")
        import traceback
        print(traceback.format_exc())
        return None, None


def load_mnist_data():
    """Fetch MNIST through Keras and scale both image arrays to [0, 1].

    Returns:
        tuple: ``(x_train, labels_train, x_test, labels_test)`` where the
        image arrays are float32 and the labels are passed through exactly
        as returned by ``mnist.load_data()``.
    """
    train_split, test_split = mnist.load_data()
    raw_train_images, labels_train = train_split
    raw_test_images, labels_test = test_split

    # Report the size of each split before any processing
    for split_labels, split_name in (
        (labels_train, "MNIST Training Set"),
        (labels_test, "MNIST Test Set"),
    ):
        print_total_samples(split_labels, split_name)

    pixel_max = 255.0
    scaled_train = raw_train_images.astype('float32') / pixel_max
    scaled_test = raw_test_images.astype('float32') / pixel_max
    return scaled_train, labels_train, scaled_test, labels_test


def load_usps_data():
    """Load the USPS training split from a local HDF5 file and resize it.

    Reads ``train/data`` and ``train/target`` from ``usps.h5`` inside
    ``DATASETS['usps']``, reshapes every sample to 16x16, resizes it to
    28x28 and makes sure the result is float32 in the [0, 1] range.

    Returns:
        tuple: ``(images, labels)`` with ``images`` a float32 array of shape
        ``(N, 28, 28)`` and ``labels`` the ``train/target`` array.

    Raises:
        FileNotFoundError: if ``usps.h5`` is not present.
        KeyError: if the file lacks the ``train/data`` or ``train/target``
            datasets.
    """
    import h5py
    usps_path = os.path.join(DATASETS['usps'], 'usps.h5')

    if not os.path.exists(usps_path):
        # Both fragments must be f-strings, otherwise the directory
        # placeholder is shown literally instead of the real path.
        raise FileNotFoundError(
            f"USPS dataset file not found at {usps_path}. "
            f"Please download usps.h5 from Kaggle and place it in the {DATASETS['usps']} directory"
        )

    print(f"Loading USPS data from {usps_path}")

    with h5py.File(usps_path, 'r') as hf:
        # Item access raises a descriptive KeyError on a malformed file,
        # instead of an opaque error on a None returned by .get()
        train = hf['train']
        X_train = train['data'][:]
        y_train = train['target'][:]

    # Reshape and resize images from 16x16 to 28x28
    from skimage.transform import resize
    print("Resizing USPS images from 16x16 to 28x28...")
    resized_images = [
        resize(img.reshape(16, 16), (28, 28), anti_aliasing=True)
        for img in X_train
    ]

    images = np.array(resized_images).astype('float32')

    # Ensure values are in [0, 1] range (skip the check for an empty split,
    # where max() would raise)
    if images.size and images.max() > 1.0:
        images /= 255.0

    print_total_samples(y_train, "USPS Dataset")
    return images, y_train



def _as_model_inputs(images, labels):
    """Reshape images to (N, 28, 28, 1) and one-hot encode labels over 10 classes."""
    return images.reshape(-1, 28, 28, 1), tf.keras.utils.to_categorical(labels, 10)


def prepare_data():
    """Load every dataset, merge them and split into train/validation sets.

    MNIST (training split), the handwritten-digits dataset, EMNIST and USPS
    are concatenated in that order, shuffled with a fixed seed and split
    90/10. If the EMNIST loader reports failure by returning
    ``(None, None)``, EMNIST is left out with a warning instead of crashing.

    Returns:
        tuple: ``(x_train, y_train, x_val, y_val)`` where the image arrays
        have shape ``(N, 28, 28, 1)`` and the label arrays are one-hot
        encoded over 10 classes.
    """
    # Load every source (same order as they are concatenated below)
    x_train_mnist, labels_train_mnist, x_test, labels_test = load_mnist_data()
    handwritten_images, handwritten_labels = load_handwritten_digits()
    emnist_images, emnist_labels = load_emnist_data()
    usps_images, usps_labels = load_usps_data()

    sources = [
        (x_train_mnist, labels_train_mnist),
        (handwritten_images, handwritten_labels),
    ]
    if emnist_images is None or emnist_labels is None:
        # load_emnist_data signals failure with (None, None)
        print("Warning: EMNIST data unavailable - continuing without it")
        emnist_labels = []
    else:
        sources.append((emnist_images, emnist_labels))
    sources.append((usps_images, usps_labels))

    # Bring each source into the model's input/label format, then combine
    prepared = [_as_model_inputs(images, labels) for images, labels in sources]
    x_train_combined = np.concatenate([images for images, _ in prepared])
    y_train_combined = np.concatenate([labels for _, labels in prepared])

    print(f"\nTotal combined samples: {len(x_train_combined)}")

    # Shuffle and split
    x_train_combined, y_train_combined = shuffle(x_train_combined, y_train_combined, random_state=42)
    x_train, x_val, y_train, y_val = train_test_split(
        x_train_combined, y_train_combined, test_size=0.1, random_state=42
    )

    print(f"\nAfter splitting:")
    print(f"Training samples: {len(x_train)}")
    print(f"Validation samples: {len(x_val)}")

    print("\nOriginal Datasets:")
    print("-" * 40)
    print_dataset_summary("MNIST Train", len(labels_train_mnist))
    print_dataset_summary("MNIST Test", len(labels_test))
    print_dataset_summary("Handwritten Digits (not MNIST)", len(handwritten_labels))
    print_dataset_summary("EMNIST", len(emnist_labels))
    print_dataset_summary("USPS", len(usps_labels))

    # After combining
    total_samples = len(x_train_combined)
    print("\nCombined Dataset:")
    print("-" * 40)
    print_dataset_summary("Total Combined", total_samples)

    # After splitting
    print("\nAfter Train/Val Split:")
    print("-" * 40)
    print_dataset_summary("Training Set", len(x_train))
    print_dataset_summary("Validation Set", len(x_val))

    return x_train, y_train, x_val, y_val






# Prepare data: runs the whole loading pipeline at cell execution and keeps the
# resulting train/validation image and one-hot label arrays as module globals
x_train, y_train, x_val, y_val = prepare_data()





# Mini-batch size shared by the augmentation generator and the reporting below
BATCH_SIZE = 256

# Data augmentation: small random rotations, shifts and zooms.
# datagen.fit() is deliberately not called: it only computes statistics for
# featurewise_center / featurewise_std_normalization / zca_whitening, none of
# which are enabled, so it would just copy the whole training set for nothing.
datagen = ImageDataGenerator(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.1
)

# The generator yields ceil(samples / batch) batches per epoch (the last one
# may be smaller), so every training sample is augmented once per epoch.
steps_per_epoch = -(-len(x_train) // BATCH_SIZE)
print(f"\nAfter augmentation (per epoch):")
print(f"Original training samples: {len(x_train)}")
print(f"Augmented samples per epoch: {len(x_train)}")

# Model architecture: two conv blocks (32 and 64 filters) followed by a dense
# classifier head. The input shape is declared with an explicit Input layer
# rather than the input_shape= argument on the first Conv2D.
model = Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    Conv2D(32, (3,3), activation='relu'),
    BatchNormalization(),
    Conv2D(32, (3,3), activation='relu'),
    BatchNormalization(),
    MaxPool2D((2,2)),
    SpatialDropout2D(0.2),

    Conv2D(64, (3,3), activation='relu'),
    BatchNormalization(),
    Conv2D(64, (3,3), activation='relu'),
    BatchNormalization(),
    MaxPool2D((2,2)),
    SpatialDropout2D(0.2),

    Flatten(),
    Dense(256, activation='relu', kernel_regularizer='l2'),
    BatchNormalization(),
    Dropout(0.5),
    Dense(10, activation='softmax')
])

# Labels are one-hot encoded, hence categorical (not sparse) cross-entropy
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Training callbacks
# NOTE(review): the checkpoint tracks val_accuracy while early stopping and the
# LR schedule track val_loss, so "best" can mean different epochs - confirm
# this is intended.
callbacks = [
    EarlyStopping(
        monitor='val_loss',
        patience=5,
        restore_best_weights=True
    ),
    ModelCheckpoint(
        os.path.join(DATASETS['models'], 'best_model.h5'),
        save_best_only=True,
        monitor='val_accuracy'
    ),
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.2,
        patience=3,
        min_lr=1e-6
    )
]

# Train model on augmented batches; validation data is not augmented
history = model.fit(
    datagen.flow(x_train, y_train, batch_size=BATCH_SIZE),
    validation_data=(x_val, y_val),
    epochs=1,
    callbacks=callbacks
)

# Save final model in HDF5 format
model.save(os.path.join(DATASETS['models'], 'final_model.h5'))

MNIST Training Set: 60000 samples
MNIST Test Set: 10000 samples
Dataset URL: https://www.kaggle.com/datasets/jcprogjava/handwritten-digits-dataset-not-in-mnist

Processing Handwritten Digits Dataset (not in MNIST) from: data/handwritten-digits-not-mnist/dataset
Processing 10773 samples for digit 0
Processing 10773 samples for digit 1
Processing 10773 samples for digit 2
Processing 10773 samples for digit 3
Processing 10773 samples for digit 4
Processing 10773 samples for digit 5
Processing 10773 samples for digit 6
Processing 10773 samples for digit 7
Processing 10773 samples for digit 8
Processing 10773 samples for digit 9

Handwritten Digits Dataset (not in MNIST) Summary:
----------------------------------------

Handwritten Digits Dataset (not in MNIST) class distribution:
Digit 0: 10773 samples
Digit 1: 10773 samples
Digit 2: 10773 samples
Digit 3: 10773 samples
Digit 4: 10773 samples
Digit 5: 10773 samples
Digit 6: 10773 samples
Digit 7: 10773 samples
Digit 8: 10773 samples
Digit

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
  self._warn_if_super_not_called()


1460/1460 ━━━━━━━━━━━━━━━━━━━━ 0s 501ms/step - accuracy: 0.6492 - loss: 2.2125



1460/1460 ━━━━━━━━━━━━━━━━━━━━ 756s 513ms/step - accuracy: 0.6493 - loss: 2.2119 - val_accuracy: 0.7557 - val_loss: 0.7698 - learning_rate: 0.0010


