# Document classifier training

This notebook builds a TensorFlow EfficientNet-based classifier for KYC/AML document images. It expects the dataset to be organized as one folder per class (e.g., `dataset_generator/output_dataset/<class_name>/*.png`).


## Setup
Update the `dataset_root` path if your dataset lives elsewhere. The notebook uses TensorFlow's `image_dataset_from_directory` to split the data into training and validation subsets.


In [None]:
from pathlib import Path
import tensorflow as tf

# Paths and hyperparameters.
# Point `dataset_root` at dataset_generator/output_dataset or your own dataset
# directory, where files are organized as dataset_root/<class_name>/*.png.

dataset_root = Path('dataset_generator/output_dataset')
img_size = (224, 224)  # (height, width) every image is resized to
batch_size = 16
epochs = 10  # max epochs per training phase; early stopping may end a phase sooner
validation_split = 0.2  # fraction of images held out for validation
seed = 1337  # also fixes the train/validation split below

# Seed the Python, NumPy and TensorFlow RNGs so augmentation, dropout and
# dataset shuffling are repeatable across Restart Kernel -> Run All.
# (GPU kernels may still introduce small nondeterminism.)
tf.keras.utils.set_random_seed(seed)

print(f"Dataset root: {dataset_root.resolve()}")


## Load the dataset
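This uses TensorFlow's `image_dataset_from_directory` to create batched training and validation datasets from the same seeded split. The validation set is cached in memory, and both datasets are prefetched to keep the input pipeline fast.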


In [None]:
# Ensure the dataset exists before handing it to Keras.
data_dir = dataset_root
if not data_dir.exists():
    raise FileNotFoundError(f"Dataset directory {data_dir} does not exist. Generate it with dataset_generator or update dataset_root.")

# Both subsets must be built with the same seed and validation_split so that
# Keras partitions the (shuffled) file list identically and the training and
# validation subsets do not overlap.
split_kwargs = dict(
    validation_split=validation_split,
    seed=seed,
    image_size=img_size,
    batch_size=batch_size,
)
train_ds = tf.keras.utils.image_dataset_from_directory(data_dir, subset="training", **split_kwargs)
val_ds = tf.keras.utils.image_dataset_from_directory(data_dir, subset="validation", **split_kwargs)

class_names = train_ds.class_names
num_classes = len(class_names)
print(f"Classes ({num_classes}): {class_names}")
# A single class makes the softmax head trivially "100% accurate" and useless.
if num_classes < 2:
    raise ValueError(f"Expected at least 2 class folders in {data_dir}, found {num_classes}: {class_names}")

autotune = tf.data.AUTOTUNE
# image_dataset_from_directory already shuffles the training files (shuffle=True
# by default). An extra .shuffle(1000) here would act on decoded *batches* and
# could buffer up to 1000 of them (~9.6 MB each at 16x224x224x3 float32).
train_ds = train_ds.prefetch(autotune)
# The validation set is read unchanged every epoch, so cache it after the first pass.
val_ds = val_ds.cache().prefetch(autotune)


## Build the model
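We start from an ImageNet-pretrained EfficientNetB0 and add light augmentation. We keep the backbone frozen and train only the new classification head (dropout followed by a softmax layer). The optional fine-tuning section below unfreezes the top of the backbone if more accuracy is needed.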


In [None]:
# Light augmentation. Keras random-augmentation layers are active only during
# training (fit) and pass images through unchanged in evaluate/predict.
# NOTE(review): horizontal flips mirror document text and layout, which never
# happens at inference time — consider dropping RandomFlip for document images.
data_augmentation = tf.keras.Sequential(
    [
        tf.keras.layers.RandomFlip('horizontal'),
        tf.keras.layers.RandomRotation(0.05),
        tf.keras.layers.RandomZoom(0.1),
    ],
    name='augmentation',
)

# Per the Keras docs, EfficientNet's preprocess_input is a pass-through
# (normalization is built into the network); it is kept for API symmetry.
preprocess = tf.keras.applications.efficientnet.preprocess_input

# Graph: RGB image -> augmentation -> preprocessing -> backbone with
# global-average pooling -> dropout -> softmax over the dataset's classes.
inputs = tf.keras.Input(shape=(*img_size, 3))
x = preprocess(data_augmentation(inputs))

# Passing the augmented tensor as `input_tensor` wires the backbone into this
# graph; with pooling='avg', base_model.output is the pooled feature vector.
base_model = tf.keras.applications.EfficientNetB0(
    weights='imagenet',
    include_top=False,
    pooling='avg',
    input_tensor=x,
)
# Phase 1: freeze every backbone weight so only the new head is trained.
base_model.trainable = False

x = tf.keras.layers.Dropout(0.3)(base_model.output)
outputs = tf.keras.layers.Dense(num_classes, activation='softmax')(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# image_dataset_from_directory yields integer labels by default, which pair
# with the sparse categorical cross-entropy loss.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=['accuracy'],
)

model.summary()


## Train
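Early stopping halts training once validation loss has not improved for three epochs and restores the best weights. Model checkpointing saves the best model seen so far to `training/model/`.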


In [None]:
# Older tf.keras versions of ModelCheckpoint do not create missing parent
# directories. Create the checkpoint directory before training so the first
# save cannot fail on a fresh checkout.
checkpoint_path = Path('training/model/efficientnet_document_classifier.keras')
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

# Both callbacks judge "best" by the same metric, so the weights EarlyStopping
# restores and the file ModelCheckpoint keeps do not disagree.
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True),
    tf.keras.callbacks.ModelCheckpoint(str(checkpoint_path), monitor='val_loss', save_best_only=True),
]

# Phase 1: train only the classification head (backbone frozen).
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=callbacks,
)


## (Optional) Fine-tune the base network
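If you need extra accuracy, unfreeze the top layers of the backbone and continue training with a 10x lower learning rate. This step is optional; skip it if the frozen-backbone results are good enough.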


In [None]:
# Index of the first backbone layer to unfreeze; earlier layers stay frozen.
fine_tune_at = 200  # unfreeze the last blocks

# Unfreeze the container first: while base_model.trainable is False, Keras
# exposes no trainable weights for any of its sublayers.
base_model.trainable = True
for index, layer in enumerate(base_model.layers):
    # Keep BatchNormalization frozen. A frozen BN layer runs in inference mode,
    # so its ImageNet moving statistics are not overwritten by noisy
    # small-batch statistics during fine-tuning.
    is_batch_norm = isinstance(layer, tf.keras.layers.BatchNormalization)
    layer.trainable = index >= fine_tune_at and not is_batch_norm

# Recompile so the new trainable flags take effect. The learning rate is 10x
# lower than in phase 1, so the pretrained weights are only nudged.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

fine_tune_history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
    callbacks=callbacks
)


## Evaluate and export
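Evaluate the current model on the validation set, then save it together with a label map (`labels.txt`, one class name per line) for inference. This overwrites the checkpoint file written during training.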


In [None]:
# Validation metrics of the current in-memory model; evaluate() returns
# [loss, accuracy] in the order the metrics were compiled.
eval_results = model.evaluate(val_ds)
print(f"Validation loss: {eval_results[0]:.4f} - acc: {eval_results[1]:.4f}")

export_dir = Path('training/model')
export_dir.mkdir(parents=True, exist_ok=True)

model_path = export_dir / 'efficientnet_document_classifier.keras'
labels_path = export_dir / 'labels.txt'

# NOTE(review): this is the same path the ModelCheckpoint callback writes to,
# so saving here replaces the checkpointed model with the current weights.
model.save(model_path)
# One class name per line, in class-index order (the order of the softmax
# outputs), so inference code can map argmax indices back to labels.
# Explicit UTF-8 keeps non-ASCII class folder names portable across platforms.
labels_path.write_text('\n'.join(class_names), encoding='utf-8')

print(f"Saved model to {model_path}")
print(f"Saved label map to {labels_path}")
