# In [2]:
# -*- coding: utf-8 -*-
"""
Created on ... 

@author: ...
"""

import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
import matplotlib.pyplot as plt 
from sklearn.metrics import mean_absolute_error

# --- Model Definition ---
def create_model(input_shape=(64, 64, 3)):
    """Build an uncompiled CNN regressor for age prediction.

    Architecture: three stages of Conv2D (3x3 kernel, 'same' padding, ReLU)
    each followed by 2x2 max-pooling, with 32, 64 and 128 filters
    respectively; then Flatten, a 128-unit ReLU dense layer and a single
    linear output unit.

    Args:
        input_shape: Shape of one input sample, excluding the batch axis.
            Defaults to (64, 64, 3).

    Returns:
        A ``tf.keras.Sequential`` model. It is not compiled here; the caller
        chooses optimizer, loss and metrics.
    """
    stack = [layers.Input(shape=input_shape)]

    # Feature extractor: channel count grows as spatial size is halved.
    for n_filters in (32, 64, 128):
        stack.append(layers.Conv2D(n_filters, (3, 3), activation='relu', padding='same'))
        stack.append(layers.MaxPooling2D((2, 2)))

    # Regression head: no activation on the last unit, so the output is an
    # unbounded scalar.
    stack += [
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(1),
    ]

    return tf.keras.Sequential(stack)

# --- Data Loading ---
data_dir = "../data/processed"  # Adjust if necessary

X_train = np.load(os.path.join(data_dir, 'X_train.npy'))
y_train = np.load(os.path.join(data_dir, 'y_train.npy'))
X_test = np.load(os.path.join(data_dir, 'X_test.npy'))
y_test = np.load(os.path.join(data_dir, 'y_test.npy'))

# --- Model Training ---

# Build a fresh network and configure it for regression: MSE is the loss
# being minimised, MAE is tracked as an additional metric.
model = create_model()
model.compile(loss='mse', optimizer='adam', metrics=['mae'])

# Train.
# NOTE(review): the test split is passed as `validation_data`, so the
# per-epoch 'val_*' entries in `history` are test-set numbers. Nothing here
# selects epochs or weights from them automatically, but tuning `epochs` by
# looking at those curves would leak test information into the final MAE;
# consider validating on a split carved out of the training data instead.
history = model.fit(
    x=X_train,
    y=y_train,
    batch_size=32,
    epochs=20,  # Adjust as needed
    validation_data=(X_test, y_test),
)

# --- Model Evaluation ---
# Score the trained model on the test arrays. With the compile() settings
# above, evaluate() yields the loss (MSE) followed by the 'mae' metric.
loss, mae = model.evaluate(x=X_test, y=y_test, verbose=0)
print(f"Mean Absolute Error: {mae:.2f}")

# --- Visualization of Training (Optional) ---

# Per-epoch MAE recorded by fit(): 'mae' on the training data and 'val_mae'
# on whatever was passed to fit() as validation_data.
plt.plot(history.history['mae'], label='Training MAE')
plt.plot(history.history['val_mae'], label='Validation MAE')
plt.gca().set(title='Model MAE', xlabel='Epoch', ylabel='MAE')
plt.legend()
plt.show()

# --- Save the Trained Model (Optional) ---

model_save_path = "../saved_models/age_prediction_model.h5"
model.save(model_save_path)
print(f"Model saved to: {model_save_path}")

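# NOTE: running this cell raised
#   ModuleNotFoundError: No module named 'distutils'
# `distutils` was removed from the standard library in Python 3.12 (PEP 632),
# and nothing in this cell imports it directly, so one of the packages used
# above still depends on it (the traceback was not kept, so which one is
# unknown). Upgrading that package to a release that supports your Python
# version, or installing `setuptools` (which ships a `distutils`
# compatibility shim), usually resolves the error.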