# Train KYC/AML Document Classifier

This notebook trains an **EfficientNetB0** model to classify 5 types of ID documents:
- Aadhaar Card
- Driving License
- PAN Card
- Voter ID
- Passport

The trained model will be saved to `training/model/` for use by the inference microservice.

## Step 0: Install Required Dependencies

Before running this notebook, make sure you have activated the conda environment and installed all required packages.

In [5]:
!python --version   
!hostname
!uname -a

In [14]:
# TensorFlow 2.16.1 fully supports Python 3.9–3.12 and TPUs
%pip install tensorflow==2.16.1 pillow numpy==1.26.4 matplotlib scikit-learn seaborn

## Step 1: Import Required Libraries

In [13]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, Callback

# Record the library versions in the cell output so the run can be reproduced
print(f"TensorFlow version: {tf.__version__}")
print(f"Keras version: {keras.__version__}")

# Suppress every Python warning for the rest of the kernel session.
# NOTE(review): the original note gave "avoid serialization issues with
# TensorBoard" as the reason, but this filter is global — it also hides
# deprecation and other library warnings. Consider narrowing it.
import warnings
warnings.filterwarnings('ignore')

# Keras callback that rewrites each epoch's logged metrics into JSON-safe values
class ConvertHistoryCallback(Callback):
    """Replace tensor/array metric values in the epoch logs with plain Python numbers.

    After every epoch each entry of ``logs`` is overwritten in place:
    scalars become ``float``, multi-element values become a flat list of
    ``float``, and anything that cannot be converted is left as it is.
    """

    @staticmethod
    def _plain(metric):
        """Return ``metric`` as a float, a flat list of floats, or unchanged."""
        # Objects exposing .numpy() (e.g. eager tensors) are unwrapped first
        if hasattr(metric, "numpy"):
            metric = metric.numpy()

        if isinstance(metric, (np.ndarray, np.generic)):
            if metric.shape == ():
                return float(metric)
            return [float(item) for item in metric.flatten().tolist()]

        if isinstance(metric, (list, tuple)):
            return [float(item) for item in np.asarray(metric).flatten().tolist()]

        try:
            return float(metric)
        except (TypeError, ValueError):
            # Not numeric: keep the (possibly unwrapped) value untouched
            return metric

    def on_epoch_end(self, epoch, logs=None):
        """Convert every value in ``logs`` in place; do nothing for empty logs."""
        if not logs:
            return

        for metric_name in list(logs):
            logs[metric_name] = self._plain(logs[metric_name])

## Step 2: Configure Training Parameters

In [None]:
# --- Input data: train / validation splits, relative to this notebook ---
TRAIN_DIR = "../dataset_generator/dataset/train"
VALID_DIR = "../dataset_generator/dataset/valid"

# --- Hyperparameters ---
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
EPOCHS = 20
LEARNING_RATE = 0.001

# --- Output artifacts ---
MODEL_DIR = "model"
MODEL_PATH = os.path.join(MODEL_DIR, "efficientnet_model.h5")

# The output folder must exist before any checkpoint is written
os.makedirs(MODEL_DIR, exist_ok=True)

# Echo the effective configuration, one setting per line
for setting_label, setting_value in (
    ("Train directory", TRAIN_DIR),
    ("Valid directory", VALID_DIR),
    ("Image size", IMG_SIZE),
    ("Batch size", BATCH_SIZE),
    ("Epochs", EPOCHS),
    ("Model will be saved to", MODEL_PATH),
):
    print(f"{setting_label}: {setting_value}")

## Step 3: Verify Dataset Structure

In [None]:
def count_images_in_dataset(base_dir, extensions=('.png', '.jpg', '.jpeg')):
    """Count image files in each class sub-directory of ``base_dir``.

    Args:
        base_dir: Directory whose immediate sub-directories are class folders.
        extensions: File suffixes (matched case-insensitively) that count as
            images. Defaults to PNG/JPG/JPEG.

    Returns:
        dict mapping class folder name -> number of matching files directly
        inside it. Empty when ``base_dir`` is missing or is not a directory
        (a message is printed in that case).
    """
    stats = {}

    # isdir rather than exists: a regular file at this path would otherwise
    # pass the check and make os.listdir raise NotADirectoryError
    if not os.path.isdir(base_dir):
        print(f"‚ùå Directory not found: {base_dir}")
        return stats

    suffixes = tuple(ext.lower() for ext in extensions)

    for class_name in os.listdir(base_dir):
        class_dir = os.path.join(base_dir, class_name)
        # Stray files next to the class folders are ignored
        if os.path.isdir(class_dir):
            stats[class_name] = sum(
                1 for file_name in os.listdir(class_dir)
                if file_name.lower().endswith(suffixes)
            )

    return stats

# Count images per class folder in each split
train_stats = count_images_in_dataset(TRAIN_DIR)
valid_stats = count_images_in_dataset(VALID_DIR)


def _print_split_stats(title, stats):
    """Print per-class image counts for one split and return the split total."""
    print(f"\n{title}:")
    for class_name, count in sorted(stats.items()):
        print(f"  - {class_name}: {count} images")
    total = sum(stats.values())
    print(f"  TOTAL: {total} images")
    return total


print("\n" + "="*60)
print("üìä DATASET STATISTICS")
print("="*60)

train_total = _print_split_stats("TRAIN SET", train_stats)
valid_total = _print_split_stats("VALIDATION SET", valid_stats)

print(f"\nüéØ GRAND TOTAL: {train_total + valid_total} images")
print(f"üìã Number of classes: {len(train_stats)}")

# Matching class *counts* are not enough: folder names must also agree, or the
# two splits would describe different classes while still passing the check.
mismatched_classes = sorted(set(train_stats) ^ set(valid_stats))
counts_ok = len(train_stats) == 5 and len(valid_stats) == 5

if counts_ok and not mismatched_classes:
    print("\n‚úÖ All 5 classes present in both splits!")
if not counts_ok:
    print(f"\n‚ö†Ô∏è Expected 5 classes, found {len(train_stats)} in train, {len(valid_stats)} in valid")
if mismatched_classes:
    print(f"\nWARNING: class folders not present in both splits: {mismatched_classes}")

# A class folder with zero images cannot be learned or evaluated
for split_name, split_stats in (("train", train_stats), ("valid", valid_stats)):
    empty_classes = sorted(name for name, count in split_stats.items() if count == 0)
    if empty_classes:
        print(f"\nWARNING: no images found in {split_name} for: {empty_classes}")

## Step 4: Create Data Generators with Augmentation

In [None]:
# Fixed seed so shuffling and random augmentation are repeatable across runs
GENERATOR_SEED = 42

# Training data generator with augmentation
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1,
    # ID documents are not mirrored in real scans/photos; flipping would train
    # the classifier on reversed text and layouts it will not see at inference
    horizontal_flip=False,
    brightness_range=[0.8, 1.2],
    fill_mode='nearest'
)

# Validation data generator (only rescaling, no augmentation)
valid_datagen = ImageDataGenerator(rescale=1./255)

# Create generators; both read one sub-folder per class
train_generator = train_datagen.flow_from_directory(
    TRAIN_DIR,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=True,
    seed=GENERATOR_SEED
)

# shuffle=False keeps the sample order aligned with valid_generator.classes
valid_generator = valid_datagen.flow_from_directory(
    VALID_DIR,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False
)

# Each generator derives its label indices from its own folder names, so the
# two mappings must be identical for validation metrics to mean anything.
if train_generator.class_indices != valid_generator.class_indices:
    raise ValueError(
        "Train and validation class folders differ: "
        f"train={train_generator.class_indices}, valid={valid_generator.class_indices}"
    )

# Print class indices
print("\nüìã Class Indices:")
for class_name, idx in sorted(train_generator.class_indices.items(), key=lambda x: x[1]):
    print(f"  {idx}: {class_name}")

NUM_CLASSES = len(train_generator.class_indices)
print(f"\n‚úì Data generators created")
print(f"  Train samples: {train_generator.samples}")
print(f"  Valid samples: {valid_generator.samples}")
print(f"  Number of classes: {NUM_CLASSES}")

## Step 5: Display Sample Images

In [None]:
# Pull one (augmented) batch from the training generator
sample_batch, sample_labels = next(train_generator)

# 3x3 grid of sample images
fig, axes = plt.subplots(3, 3, figsize=(12, 12))
axes = axes.flatten()

# Reverse class indices for display (index -> class name)
class_names = {v: k for k, v in train_generator.class_indices.items()}

# A batch can hold fewer than 9 images (small dataset or small BATCH_SIZE);
# only index what exists and leave the remaining panels blank.
num_shown = min(len(axes), len(sample_batch))
for i, ax in enumerate(axes):
    ax.axis('off')
    if i >= num_shown:
        continue

    img = sample_batch[i]
    label_idx = np.argmax(sample_labels[i])  # one-hot label -> class index
    class_name = class_names[label_idx]

    ax.imshow(img)
    ax.set_title(f'{class_name}', fontsize=12, fontweight='bold')

plt.suptitle('Sample Training Images (with augmentation)', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

# Reset generator so training starts from the beginning of the data
train_generator.reset()

## Step 6: Build EfficientNetB0 Model

In [None]:
def build_model(num_classes, img_size=(224, 224), inputs_rescaled=True):
    """Build an EfficientNetB0 classifier with a custom classification head.

    Args:
        num_classes: Number of output classes (size of the softmax layer).
        img_size: (height, width) of the RGB input images.
        inputs_rescaled: True when the input pipeline already divides pixels
            by 255 (as the ImageDataGenerators in this notebook do with
            ``rescale=1./255``). Pass False when feeding raw [0, 255] pixels.

    Returns:
        (model, base_model): the full Keras model and the EfficientNetB0
        backbone, returned separately so it can be unfrozen for fine-tuning.
    """

    # Load pre-trained EfficientNetB0 (without top classification layer)
    base_model = tf.keras.applications.EfficientNetB0(
        include_top=False,
        weights='imagenet',
        input_shape=(*img_size, 3),
        pooling='avg'
    )

    # Freeze base model initially
    base_model.trainable = False

    inputs = keras.Input(shape=(*img_size, 3))
    x = inputs
    if inputs_rescaled:
        # Keras EfficientNet ships its own normalization and expects raw
        # [0, 255] pixels. Inputs already scaled to [0, 1] would be
        # normalized twice and the ImageNet features would be computed on
        # near-black images, so map them back to [0, 255] first.
        x = layers.Rescaling(255.0)(x)

    # training=False keeps the backbone's BatchNorm layers in inference mode,
    # including later when the backbone is unfrozen for fine-tuning
    x = base_model(x, training=False)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    model = keras.Model(inputs, outputs)

    return model, base_model

# Instantiate the classifier and keep a handle on its backbone
model, base_model = build_model(NUM_CLASSES, IMG_SIZE)

# Optimizer and metrics for the first training phase
phase1_optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)
phase1_metrics = [
    'accuracy',
    keras.metrics.TopKCategoricalAccuracy(k=2, name='top2_accuracy'),
]

# Labels are one-hot ('categorical' generators), hence categorical cross-entropy
model.compile(
    optimizer=phase1_optimizer,
    loss='categorical_crossentropy',
    metrics=phase1_metrics,
)

print("\n‚úì Model built and compiled")
print(f"\nModel Summary:")
model.summary()

## Step 7: Configure Training Callbacks

In [None]:
# Converts logged metric values to plain floats
convert_callback = ConvertHistoryCallback()

# Writes the whole model to MODEL_PATH each time val_accuracy hits a new maximum
checkpoint_callback = ModelCheckpoint(
    filepath=MODEL_PATH,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    save_weights_only=False,
    verbose=1,
)

# Ends training after 5 epochs without val_accuracy improvement and restores
# the best weights seen
early_stopping_callback = EarlyStopping(
    monitor='val_accuracy',
    patience=5,
    restore_best_weights=True,
    verbose=1,
)

# Halves the learning rate after 3 epochs of stalled val_loss (floor: 1e-7)
reduce_lr_callback = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=3,
    min_lr=1e-7,
    verbose=1,
)

# The converter is listed first so it runs ahead of the other callbacks
callbacks = [convert_callback, checkpoint_callback, early_stopping_callback, reduce_lr_callback]

print("‚úì Callbacks configured:")
for callback_note in (
    "ConvertHistoryCallback: Convert EagerTensor to float",
    "ModelCheckpoint: Save best model based on val_accuracy",
    "EarlyStopping: Stop if no improvement for 5 epochs",
    "ReduceLROnPlateau: Reduce LR if val_loss plateaus",
):
    print(f"  - {callback_note}")

## Step 8: Train Model (Phase 1 - Frozen Base)

In [None]:
# Phase 1 uses the first half of the epoch budget; the backbone stays frozen
phase1_epoch_budget = EPOCHS // 2
banner = "="*60

print("\n" + banner)
print("üöÄ PHASE 1: Training with frozen base model")
print(banner)
print(f"Training for {phase1_epoch_budget} epochs...\n")

# Fit only the classification head; the callbacks handle checkpointing,
# early stopping, LR scheduling and metric conversion
history_phase1 = model.fit(
    train_generator,
    epochs=phase1_epoch_budget,
    validation_data=valid_generator,
    callbacks=callbacks,
    verbose=1,
)

print("\n‚úÖ Phase 1 training complete!")

## Step 9: Fine-tune Model (Phase 2 - Unfrozen Base)

In [None]:
print("\n" + "="*60)
print("üöÄ PHASE 2: Fine-tuning with unfrozen base model")
print("="*60)

# Unfreeze base model for fine-tuning
base_model.trainable = True

# Freeze early layers, unfreeze later layers
for layer in base_model.layers[:100]:
    layer.trainable = False

# Recompile so the trainable-flag changes take effect, with a 10x lower
# learning rate to avoid wrecking the pre-trained weights
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE / 10),
    loss='categorical_crossentropy',
    metrics=['accuracy', keras.metrics.TopKCategoricalAccuracy(k=2, name='top2_accuracy')]
)

# Number of fine-tuning epochs: the remainder of the overall epoch budget
phase2_epochs = EPOCHS - (EPOCHS // 2)
# Epochs phase 1 actually ran (may be fewer than planned if it stopped early)
phase1_epochs_run = len(history_phase1.history['loss'])

print(f"‚úì Base model unfrozen (layers 100+ trainable)")
print(f"‚úì Learning rate reduced to {LEARNING_RATE / 10}")
print(f"\nTraining for {EPOCHS - (EPOCHS // 2)} more epochs...\n")

# Continue training with unfrozen base.
# Keras interprets `epochs` as the index of the LAST epoch, not as a count of
# additional epochs. Passing only the phase-2 count together with
# initial_epoch made fit() return immediately (e.g. initial_epoch=10,
# epochs=10), so the end index must be initial_epoch + phase2_epochs.
history_phase2 = model.fit(
    train_generator,
    validation_data=valid_generator,
    epochs=phase1_epochs_run + phase2_epochs,
    initial_epoch=phase1_epochs_run,
    callbacks=callbacks,
    verbose=1
)

print("\n‚úÖ Phase 2 fine-tuning complete!")

## Step 10: Plot Training History

In [None]:
# Combine histories and convert to standard Python types
history = {}

def _to_float_list(values):
    """Return ``values`` as a list of plain Python floats.

    Each element may be a number, an object exposing ``.numpy()``, a
    list/tuple, or a NumPy array. Single-element containers are unwrapped;
    containers with several elements are reduced to their mean.
    """
    def _as_scalar(entry):
        raw = entry.numpy() if hasattr(entry, 'numpy') else entry
        if isinstance(raw, (list, tuple)):
            raw = np.asarray(raw)
        if isinstance(raw, np.ndarray):
            # One element -> that element; otherwise average to a scalar
            raw = raw.item() if raw.size == 1 else float(np.mean(raw))
        return float(raw)

    return [_as_scalar(entry) for entry in values]

# Concatenate the per-epoch metrics of both phases into one series per metric
for key in history_phase1.history.keys():
    phase1_values = _to_float_list(history_phase1.history[key])
    # .get(): phase 2 has no entry for a metric when it logged no epochs,
    # and direct indexing would raise KeyError
    phase2_values = _to_float_list(history_phase2.history.get(key, []))
    history[key] = phase1_values + phase2_values

# Epoch index (0-based) at which fine-tuning begins
fine_tune_start = len(history_phase1.history['loss'])

# Plot training history
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# One panel per metric: (axis, history key, y-label, title, train label, val label)
metric_panels = [
    (axes[0, 0], 'accuracy', 'Accuracy', 'Model Accuracy', 'Train Accuracy', 'Val Accuracy'),
    (axes[0, 1], 'loss', 'Loss', 'Model Loss', 'Train Loss', 'Val Loss'),
    (axes[1, 0], 'top2_accuracy', 'Top-2 Accuracy', 'Top-2 Accuracy', 'Train Top-2 Acc', 'Val Top-2 Acc'),
]
for ax, metric, y_label, title, train_label, val_label in metric_panels:
    ax.plot(history[metric], label=train_label)
    ax.plot(history[f'val_{metric}'], label=val_label)
    # Vertical marker where the backbone was unfrozen
    ax.axvline(x=fine_tune_start, color='r', linestyle='--', label='Fine-tuning start')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

# Summary text (fourth panel)
final_train_acc = history['accuracy'][-1]
final_val_acc = history['val_accuracy'][-1]
best_val_acc = max(history['val_accuracy'])
summary_text = f"""Final Metrics:

Train Accuracy: {final_train_acc:.4f}
Val Accuracy: {final_val_acc:.4f}
Best Val Accuracy: {best_val_acc:.4f}

Total Epochs: {len(history['loss'])}
Phase 1: {len(history_phase1.history['loss'])} epochs
Phase 2: {len(history_phase2.history.get('loss', []))} epochs
"""
axes[1, 1].text(0.1, 0.5, summary_text, fontsize=12, verticalalignment='center')
axes[1, 1].axis('off')

plt.suptitle('Training History', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

print(f"\nüìä Final Training Accuracy: {final_train_acc:.4f}")
print(f"üìä Final Validation Accuracy: {final_val_acc:.4f}")
print(f"üìä Best Validation Accuracy: {best_val_acc:.4f}")

## Step 11: Evaluate Model

In [None]:
banner = "="*60
print("\n" + banner)
print("üìä EVALUATING MODEL ON VALIDATION SET")
print(banner)

# Score the in-memory model on the validation generator; values come back in
# compile order: loss first, then the metrics passed to compile()
results = model.evaluate(valid_generator, verbose=1)

print(f"\nValidation Metrics:")
for position, metric_label in enumerate(("Loss", "Accuracy", "Top-2 Accuracy")):
    print(f"  {metric_label}: {results[position]:.4f}")

## Step 12: Generate Predictions and Confusion Matrix

In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Generate predictions over the whole validation set.
# The generator was created with shuffle=False, so prediction order matches
# valid_generator.classes after a reset.
print("Generating predictions...")
valid_generator.reset()
predictions = model.predict(valid_generator, verbose=1)
predicted_classes = np.argmax(predictions, axis=1)

# Get true labels
true_classes = valid_generator.classes

# Order class names by their integer index instead of relying on dict order
class_labels = [
    name for name, _ in sorted(valid_generator.class_indices.items(), key=lambda item: item[1])
]
label_ids = list(range(len(class_labels)))

# Passing labels= pins the matrix to one row/column per known class even if a
# class never occurs in the true or predicted labels; otherwise the matrix
# shrinks and the tick labels no longer line up with it.
cm = confusion_matrix(true_classes, predicted_classes, labels=label_ids)

# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_labels, yticklabels=class_labels)
plt.title('Confusion Matrix', fontsize=16, fontweight='bold')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.show()

# Classification report (labels= keeps target_names aligned with the classes)
print("\n" + "="*60)
print("üìã CLASSIFICATION REPORT")
print("="*60)
print(classification_report(true_classes, predicted_classes,
                            labels=label_ids, target_names=class_labels))

## Step 13: Test Model on Sample Images

In [None]:
# Get the first batch of validation images
valid_generator.reset()
sample_batch, sample_labels = next(valid_generator)

# Make predictions for that batch
sample_predictions = model.predict(sample_batch)

# 3x3 grid of predictions
fig, axes = plt.subplots(3, 3, figsize=(15, 15))
axes = axes.flatten()

# index -> class name lookup
class_names = {v: k for k, v in valid_generator.class_indices.items()}

# A batch can hold fewer than 9 images (small validation set or small
# BATCH_SIZE); only index what exists and leave the remaining panels blank.
num_shown = min(len(axes), len(sample_batch))
for i, ax in enumerate(axes):
    ax.axis('off')
    if i >= num_shown:
        continue

    img = sample_batch[i]
    true_label_idx = np.argmax(sample_labels[i])
    pred_label_idx = np.argmax(sample_predictions[i])
    # Softmax probability assigned to the predicted class
    confidence = sample_predictions[i][pred_label_idx]

    true_label = class_names[true_label_idx]
    pred_label = class_names[pred_label_idx]

    # Color: green if correct, red if wrong
    color = 'green' if true_label == pred_label else 'red'

    ax.imshow(img)
    ax.set_title(f'True: {true_label}\nPred: {pred_label}\nConf: {confidence:.2%}',
                 fontsize=10, color=color, fontweight='bold')

plt.suptitle('Sample Predictions on Validation Set', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

## Step 14: Save Model and Class Indices

In [None]:
import json

# 1) Model as it stands in memory now (the checkpoint callback separately
#    wrote the best-scoring model to MODEL_PATH during training)
final_model_path = os.path.join(MODEL_DIR, "efficientnet_model_final.h5")
model.save(final_model_path)
print(f"‚úÖ Final model saved to: {final_model_path}")

# 2) Class-name -> index mapping, needed to decode predictions at inference
class_indices_path = os.path.join(MODEL_DIR, "class_indices.json")
with open(class_indices_path, 'w') as mapping_file:
    json.dump(train_generator.class_indices, mapping_file, indent=2)
print(f"‚úÖ Class indices saved to: {class_indices_path}")

# 3) Architecture only (no weights), serialized as JSON
model_json_path = os.path.join(MODEL_DIR, "model_architecture.json")
with open(model_json_path, 'w') as architecture_file:
    architecture_file.write(model.to_json())
print(f"‚úÖ Model architecture saved to: {model_json_path}")

banner = "="*60
print("\n" + banner)
print("‚úÖ TRAINING COMPLETE!")
print(banner)
print(f"\nModel files saved in: {os.path.abspath(MODEL_DIR)}")
for artifact_path, artifact_role in (
    (MODEL_PATH, "best model"),
    (final_model_path, "final model"),
    (class_indices_path, "class mapping"),
    (model_json_path, "architecture"),
):
    print(f"  - {os.path.basename(artifact_path)} ({artifact_role})")

## Summary

✅ Model trained on 5 document classes  
✅ Two-phase training: frozen base → fine-tuned base  
✅ Best model saved based on validation accuracy  
✅ Model ready for inference microservice  

**Model Files:**
```
training/model/
├── efficientnet_model.h5          ← Best model (use this for inference)
├── efficientnet_model_final.h5    ← Final epoch model
├── class_indices.json             ← Class name to index mapping
└── model_architecture.json        ← Model architecture (optional)
```

**Next Steps:**
1. Load the model in the inference microservice
2. Test the API: `uvicorn api.main:app --reload --port 8000`
3. Deploy the microservice using Docker