In [10]:
import os
import traceback

import cv2
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

from utils.data_loader import load_and_prepare_data, get_input_shape
from utils.models_util import get_model
from models.training.trainer import train_model
from models.training.evaluate import evaluate_model, plot_history


In [11]:

# Constants
DATA_PATH = "dataset"            # root folder expected to hold the class sub-directories
MODEL_SAVE_DIR = "saved_models"  # created below if missing
BATCH_SIZE = 32
EPOCHS = 20

# Create model save directory if it doesn't exist
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)


def _resolve_class_dir(base_dir, label, candidates):
    """Return the first existing sub-directory of ``base_dir`` among ``candidates``.

    Args:
        base_dir: Directory that should contain the class folder.
        label: Human-readable class name, used only in the error message.
        candidates: Folder names to try, in order of preference
            (e.g. differently-cased spellings of the same name).

    Returns:
        The joined path of the first candidate that is an existing directory.

    Raises:
        FileNotFoundError: If none of the candidates is an existing directory.
    """
    tried = [os.path.join(base_dir, name) for name in candidates]
    for path in tried:
        if os.path.isdir(path):
            return path
    tried_text = ", ".join(repr(path) for path in tried)
    # Raise rather than call exit(): inside a Jupyter kernel exit() shuts the
    # kernel down, discarding all state, instead of just stopping this cell.
    raise FileNotFoundError(
        f"{label} fingerprints directory not found; tried: {tried_text}"
    )


print("Checking data directories...")

# Prefer the capitalised folder names, fall back to lowercase.
real_dir = _resolve_class_dir(DATA_PATH, "Real", ("Real", "real"))
altered_dir = _resolve_class_dir(DATA_PATH, "Altered", ("Altered", "altered"))

# NOTE(review): print_directory_contents is not defined in the visible cells;
# it must already exist in the kernel namespace — confirm it is defined in an
# earlier cell so Restart & Run All works.
print_directory_contents(DATA_PATH)
print_directory_contents(real_dir)
print_directory_contents(altered_dir)

def _count_bmp_files(directory):
    """Count files under ``directory`` (recursively) whose name ends in ``.bmp``.

    The extension match is case-insensitive, so ``.BMP`` files are counted too.

    Args:
        directory: Root of the tree to walk.

    Returns:
        Number of matching files as an int (0 if the tree has none).
    """
    return sum(
        1
        for _root, _dirs, filenames in os.walk(directory)
        for filename in filenames
        if filename.lower().endswith(".bmp")
    )


# Count files once per class directory (the count was previously computed and
# printed twice, walking both trees a second time for no benefit).
print("\nCounting files...")
real_count = _count_bmp_files(real_dir)
altered_count = _count_bmp_files(altered_dir)

print(f"Found {real_count} BMP files in real directory")
print(f"Found {altered_count} BMP files in altered directory")

# List of models to train (you can modify this list)
models_to_train = ["lenet", "alexnet", "vgg", "googlenet", "resnet"]

# model name -> short error text, for the summary printed after the loop.
# Only strings are stored (not the exception objects) so no traceback frames,
# and therefore no training arrays, are kept alive by this dict.
failed_models = {}

for model_name in models_to_train:
    print(f"\n{'='*50}")
    print(f"Training {model_name.upper()} model")
    print(f"{'='*50}")

    try:
        # Load and prepare data
        X_train, X_test, y_train, y_test = load_and_prepare_data(model_name, DATA_PATH)

        # Get model
        input_shape = get_input_shape(model_name)
        model = get_model(model_name, input_shape)

        # Train model
        model, history = train_model(
            model,
            X_train,
            y_train,
            X_test,
            y_test,
            model_name,
            batch_size=BATCH_SIZE,
            epochs=EPOCHS,
            model_save_dir=MODEL_SAVE_DIR
        )

        # Evaluate model
        evaluate_model(model, X_test, y_test)

        # Plot training history
        plot_history(history)
    except Exception as e:
        # Best-effort: a failure in one architecture must not stop the others.
        print(f"Error training {model_name}: {e}")
        traceback.print_exc()
        failed_models[model_name] = f"{type(e).__name__}: {e}"

# Surface failures in one place instead of leaving them buried in the
# per-model output above.
if failed_models:
    print(f"\n{len(failed_models)} of {len(models_to_train)} models failed:")
    for failed_name, reason in failed_models.items():
        print(f"  {failed_name}: {reason}")

Checking data directories...
Contents of dataset:
  [FILE] .DS_Store
  [DIR] Real
  [DIR] Altered
Contents of dataset/Real:
  [FILE] 364__M_Right_little_finger.BMP
  [FILE] 292__M_Left_little_finger.BMP
  [FILE] 53__M_Left_thumb_finger.BMP
  [FILE] 225__M_Left_little_finger.BMP
  [FILE] 277__M_Right_middle_finger.BMP
  [FILE] 97__M_Right_little_finger.BMP
  [FILE] 503__M_Right_ring_finger.BMP
  [FILE] 426__M_Right_thumb_finger.BMP
  [FILE] 569__M_Left_ring_finger.BMP
  [FILE] 495__M_Right_index_finger.BMP
  [FILE] 405__M_Right_little_finger.BMP
  [FILE] 33__M_Left_little_finger.BMP
  [FILE] 562__F_Right_middle_finger.BMP
  [FILE] 350__M_Left_middle_finger.BMP
  [FILE] 84__M_Left_little_finger.BMP
  [FILE] 299__M_Right_ring_finger.BMP
  [FILE] 516__M_Right_middle_finger.BMP
  [FILE] 422__M_Right_index_finger.BMP
  [FILE] 491__M_Right_thumb_finger.BMP
  [FILE] 64__M_Left_thumb_finger.BMP
  [FILE] 199__M_Left_thumb_finger.BMP
  [FILE] 341__M_Right_ring_finger.BMP
  [FILE] 88__F_Left_ring_

Traceback (most recent call last):
  File "/var/folders/tf/06n5fz4j6gdb4zj5j5tkrqgc0000gn/T/ipykernel_69647/273351051.py", line 67, in <module>
    X_train, X_test, y_train, y_test = load_and_prepare_data(model_name, DATA_PATH)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/sc/Desktop/AAI/project/Fingerprint_Recognition/utils/data_loader.py", line 109, in load_and_prepare_data
    target_shape = get_input_shape(model_type)
                                       ^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/sklearn/utils/_param_validation.py", line 216, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/sklearn/model_selection/_split.py", line 2851, in train_test_split
    n_train, n_test = _validate_shuffle_split(
                      ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Libr