In [None]:
# EMNIST Character Classifier - feature-fusion ensemble of frozen EfficientNetB4 and ResNet152 backbones, with augmentation, regularization, and visualization
# Goal: a human-level EMNIST character classifier
"""
Required library versions (pip install):
tensorflow>=2.12
tensorflow-datasets>=4.8
scikit-learn>=1.2
matplotlib>=3.7
seaborn>=0.13
"""
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.applications import EfficientNetB4, ResNet152
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D, Input, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
import pickle


In [None]:
# ==================== CONFIG ====================
# Presumably chosen to match EfficientNetB4's default 380x380 input; every glyph is
# upsampled to IMG_SIZE x IMG_SIZE.
# NOTE(review): in the recorded run, this size with two backbones ran at ~27 s/step
# (~81 h per epoch). Consider a much smaller IMG_SIZE, or caching the frozen backbones'
# features, before attempting a full EPOCHS run.
IMG_SIZE = 380
BATCH_SIZE = 64
NUM_CLASSES = 62
AUTOTUNE = tf.data.AUTOTUNE
EPOCHS = 20
SEED = 42

# Seed Python, NumPy and TensorFlow RNGs (shuffle order, augmentation, head
# initialization) so runs are repeatable; GPU kernels may still be nondeterministic.
tf.keras.utils.set_random_seed(SEED)

# ==================== 1. Load Dataset ====================
(ds_train, ds_test), ds_info = tfds.load(
    'emnist/byclass',
    split=['train', 'test'],
    shuffle_files=False,  # fixed example order (the tfds default, stated explicitly)
    as_supervised=True,
    with_info=True,
)

# Fail fast if NUM_CLASSES (used by the one-hot encoding and the output layer) does not
# match the dataset variant that was loaded.
assert ds_info.features['label'].num_classes == NUM_CLASSES, ds_info.features['label'].num_classes

# ==================== 2. Preprocessing ====================
def preprocess(image, label):
    """Turn one raw EMNIST example into model-ready tensors.

    Args:
        image: raw image tensor from `tfds` (assumed uint8 of shape (28, 28, 1);
            TODO confirm against `ds_info.features['image']`).
        label: integer class index (the dataset is loaded with `as_supervised=True`).

    Returns:
        (image, label): `image` is float32 (IMG_SIZE, IMG_SIZE, 3) in [0, 1] with the
        grayscale channel replicated to RGB; `label` is a one-hot vector of length
        NUM_CLASSES.
    """
    # EMNIST stores its images transposed (rows and columns swapped), so swapping the two
    # spatial axes restores upright glyphs. The previous rot90(k=1) + flip_left_right
    # produced the anti-transpose instead (transpose rotated 180 degrees), i.e.
    # characters that were upside-down and mirrored.
    image = tf.transpose(image, perm=[1, 0, 2])
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    # The ImageNet backbones expect 3 channels.
    image = tf.image.grayscale_to_rgb(image)
    label = tf.one_hot(label, NUM_CLASSES)
    return image, label

def augment(image, label):
    """Apply label-preserving photometric jitter to one preprocessed training example.

    Args:
        image: float32 image in [0, 1], as produced by `preprocess`.
        label: one-hot label; returned unchanged.

    Returns:
        (image, label) with brightness/contrast jitter applied and pixels kept in [0, 1].
    """
    # Deliberately no horizontal flip: mirroring a character can turn it into another
    # class (b <-> d, p <-> q) or into a glyph of no class, silently corrupting the label.
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.image.random_contrast(image, 0.8, 1.2)
    # Brightness/contrast jitter can push values outside [0, 1]; clip back to the range
    # that `preprocess` produces for the validation/test images.
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label

# Hold out a slice of the *training* split for validation. EarlyStopping,
# ReduceLROnPlateau and ModelCheckpoint all select on val_loss, so validating on the
# test split would leak it into model selection and bias the reported test score.
# take/skip yield a fixed split because tfds.load does not shuffle files by default
# (assumes the stored order is not sorted by class; TODO confirm).
VAL_FRACTION = 0.1
num_val_examples = int(ds_info.splits['train'].num_examples * VAL_FRACTION)

# Shuffle the raw (28x28 uint8) examples *before* `preprocess` upsamples them: a
# 2048-element buffer of IMG_SIZE x IMG_SIZE x 3 float32 images is ~3.5 GB at
# IMG_SIZE=380, versus ~1.6 MB for the raw examples.
train_ds = (
    ds_train.skip(num_val_examples)
    .shuffle(2048)
    .map(preprocess, num_parallel_calls=AUTOTUNE)
    .map(augment, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)
val_ds = (
    ds_train.take(num_val_examples)
    .map(preprocess, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)
test_ds = ds_test.map(preprocess, num_parallel_calls=AUTOTUNE).batch(BATCH_SIZE).prefetch(AUTOTUNE)

# ==================== 3. Build Model ====================
# The pipeline feeds RGB images scaled to [0, 1], but each ImageNet backbone expects its
# own input convention, so convert per branch inside the graph:
#   * Keras EfficientNet models rescale/normalize internally and expect pixel values
#     in [0, 255].
#   * ResNet152 expects `resnet.preprocess_input` ("caffe" mode): BGR channel order with
#     the ImageNet per-channel mean subtracted, no further scaling. All three channels
#     are copies of the same grayscale image (grayscale_to_rgb), so the channel swap is
#     a no-op and the transform reduces to `x * 255 - mean`.
IMAGENET_BGR_MEAN = [103.939, 116.779, 123.68]

input_layer = Input(shape=(IMG_SIZE, IMG_SIZE, 3))
pixels_0_255 = tf.keras.layers.Rescaling(255.0)(input_layer)
resnet_input = tf.keras.layers.Normalization(
    mean=IMAGENET_BGR_MEAN, variance=[1.0, 1.0, 1.0]
)(pixels_0_255)

base1 = EfficientNetB4(include_top=False, weights='imagenet', input_shape=(IMG_SIZE, IMG_SIZE, 3))
base2 = ResNet152(include_top=False, weights='imagenet', input_shape=(IMG_SIZE, IMG_SIZE, 3))

# Freeze both backbones; only the fusion head below is trained. Calling them with
# training=False also keeps their BatchNorm layers in inference mode if they are
# unfrozen later for fine-tuning.
base1.trainable = False
base2.trainable = False

gap1 = GlobalAveragePooling2D()(base1(pixels_0_255, training=False))
gap2 = GlobalAveragePooling2D()(base2(resnet_input, training=False))

# Fusion head: concatenate both pooled feature vectors and classify.
merged = Concatenate()([gap1, gap2])
dense1 = Dense(512, activation='relu')(merged)
drop1 = Dropout(0.5)(dense1)
out = Dense(NUM_CLASSES, activation='softmax')(drop1)

model = Model(inputs=input_layer, outputs=out)

model.compile(
    optimizer=Adam(learning_rate=5e-4),
    loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
    metrics=['accuracy']
)

model.summary()

# ==================== 4. Callbacks ====================
# All three watch the held-out validation split (never the test split).
callbacks = [
    EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3, min_lr=1e-6),
    ModelCheckpoint('EMNIST_V4_best_model.h5', monitor='val_loss', save_best_only=True),
]

# ==================== 5. Training ====================
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=callbacks
)

# One evaluation on the untouched test split, after all model selection is done.
test_metrics = model.evaluate(test_ds, return_dict=True)
print(f"Test accuracy: {test_metrics['accuracy']:.4f}")

# ==================== 6. Plot ====================
# Training curves: per-epoch accuracy and loss for the training data and for the
# validation data passed to model.fit, side by side.
fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 5))
epochs_run = range(1, len(history.history['loss']) + 1)  # 1-based epoch numbers
for ax, metric, title in ((ax_acc, 'accuracy', 'Accuracy'), (ax_loss, 'loss', 'Loss')):
    ax.plot(epochs_run, history.history[metric], label=f'Train {title}')
    ax.plot(epochs_run, history.history[f'val_{metric}'], label=f'Val {title}')
    ax.set(title=title, xlabel='Epoch', ylabel=title)
    ax.legend()

fig.tight_layout()
plt.show()

# ============ Save training history ================
# Persist the per-epoch metric dict so the curves can be re-plotted without retraining.
# Only unpickle copies you trust: pickle.load can execute arbitrary code.
history_file_path = 'emnist_V4_byclass_history.pkl'
with open(history_file_path, 'wb') as history_file:
    pickle.dump(history.history, history_file)
print("History saved as emnist_V4_byclass_history.pkl")

Epoch 1/20
[1m    2/10906[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m81:32:03[0m 27s/step - accuracy: 0.0352 - loss: 4.7025 

KeyboardInterrupt: 