# Dental Cavity Detection - Model Training

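This notebook trains and evaluates three image-classification models for dental cavity detection: a custom CNN trained from scratch, a MobileNetV2 transfer-learning model, and a fine-tuned version of that MobileNetV2 model. It then compares them on a held-out test set.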

In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.metrics import classification_report, confusion_matrix

# Add the src directory to the path
sys.path.append('../')
from src.data_preprocessing import DataPreprocessor
from src.model import DentalCavityModel
from src.evaluation import ModelEvaluator
from src.utils import save_class_names, plot_learning_curves

# Set random seeds for reproducibility.
# Python's `random`, NumPy and TensorFlow each keep their own generator, so all three are
# seeded; helper code (e.g. augmentation) may draw from any of them.
# NOTE: bitwise-identical GPU runs additionally need deterministic ops
# (TF >= 2.8: tf.config.experimental.enable_op_determinism()).
import random

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Set plot style. sns.set runs after the ggplot style, so its settings win where they overlap.
plt.style.use('ggplot')
sns.set(style="whitegrid")

## 1. Data Preparation

First, load the dataset at 224×224 resolution and split it into training, validation, and test sets.

In [None]:
# Set data directory
data_dir = '../data/cavity_dataset'

# Later cells write checkpoints and class names under ../models and figures under ../results.
# Create both up front: e.g. plt.savefig raises FileNotFoundError for a missing folder.
for output_dir in ('../models', '../results'):
    os.makedirs(output_dir, exist_ok=True)

# Initialize data preprocessor (20% test, 10% validation, images at 224x224)
preprocessor = DataPreprocessor(
    data_dir=data_dir,
    img_size=(224, 224),
    test_size=0.2,
    val_size=0.1
)

# Load and preprocess data
X_train, y_train, X_val, y_val, X_test, y_test, class_names = preprocessor.load_and_preprocess_data()

# Sanity-check the splits before spending time on training.
for split_name, X_split, y_split in (
    ('train', X_train, y_train),
    ('val', X_val, y_val),
    ('test', X_test, y_test),
):
    assert len(X_split) > 0, f"{split_name} split is empty"
    assert len(X_split) == len(y_split), (
        f"{split_name} split has {len(X_split)} images but {len(y_split)} labels"
    )

# Save class names for later use
save_class_names(class_names, '../models/class_names.txt')

# Visualize sample images
preprocessor.visualize_samples(X_train, y_train, class_names)

## 2. Data Augmentation

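Let's apply data augmentation to increase the diversity of our training data. Augmentation is applied only to the training split. The validation and test sets are left untouched, so evaluation reflects unmodified images.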

In [None]:
# Augment the training split only; validation and test data stay untouched.
X_train_aug, y_train_aug = preprocessor.apply_augmentation(X_train, y_train)

# Preview a few entries from the tail of the augmented arrays.
# NOTE(review): assumes apply_augmentation appends new samples after the originals, so the
# last entries are augmented images — confirm against src/data_preprocessing.py.
last_ten = slice(-10, None)
preprocessor.visualize_samples(
    X_train_aug[last_ten],
    y_train_aug[last_ten],
    class_names,
    num_samples=2,
)

## 3. Model Training

Let's train different models and compare their performance.

### 3.1 Custom CNN Model

In [None]:
# Baseline: a CNN trained from scratch. Input shape and class count are taken from the
# preprocessed data so they stay in sync with the DataPreprocessor settings above.
custom_model = DentalCavityModel(
    model_type='custom',
    num_classes=len(class_names),
    input_shape=X_train.shape[1:],
)
custom_model.model.summary()

# Fit on the augmented training set and validate on the non-augmented validation split.
# NOTE(review): section 4.1 reloads '../models/dental_cavity_custom_best.h5'; presumably
# train() writes that file into save_dir — confirm against src/model.py.
custom_history = custom_model.train(
    X_train_aug,
    y_train_aug,
    X_val,
    y_val,
    epochs=30,
    batch_size=32,
    save_dir='../models',
)

custom_model.plot_training_history(custom_history, '../results/custom_model_history.png')

### 3.2 Transfer Learning with MobileNetV2

In [None]:
# Transfer-learning variant: DentalCavityModel builds it around a MobileNetV2 backbone.
# Input shape and class count come from the preprocessed data, as for the custom CNN.
mobilenet_model = DentalCavityModel(
    model_type='mobilenetv2',
    num_classes=len(class_names),
    input_shape=X_train.shape[1:],
)
mobilenet_model.model.summary()

# Same training budget as the custom CNN so the comparison is like-for-like.
# NOTE(review): sections 3.3 and 4.2 reload '../models/dental_cavity_mobilenetv2_best.h5';
# presumably train() writes that file into save_dir — confirm against src/model.py.
mobilenet_history = mobilenet_model.train(
    X_train_aug,
    y_train_aug,
    X_val,
    y_val,
    epochs=30,
    batch_size=32,
    save_dir='../models',
)

mobilenet_model.plot_training_history(mobilenet_history, '../results/mobilenet_model_history.png')

### 3.3 Fine-tuning MobileNetV2

Let's fine-tune the MobileNetV2 model by unfreezing the top 20 layers of its backbone and continuing training with a lower learning rate (1e-4).

In [None]:
# Load the best MobileNetV2 model from the first training stage
best_mobilenet = tf.keras.models.load_model('../models/dental_cavity_mobilenetv2_best.h5')

# Locate the nested MobileNetV2 backbone by name. Fail with a clear message if it is missing,
# instead of a NameError or (worse) silently reusing a stale `base_model` left in the kernel.
base_model = next(
    (layer for layer in best_mobilenet.layers if 'mobilenetv2' in layer.name),
    None,
)
if base_model is None:
    raise ValueError(
        "Could not find a MobileNetV2 backbone in the loaded model; layer names: "
        f"{[layer.name for layer in best_mobilenet.layers]}"
    )

# Number of backbone layers, counted from the top, to unfreeze
N_UNFREEZE = 20

# Unfreeze the backbone container first. In Keras, a layer/model with trainable=False exposes
# no trainable weights regardless of its sublayers' flags. Flipping only the inner layers is
# therefore a no-op when the backbone itself was frozen for the first training stage.
base_model.trainable = True
for layer in base_model.layers[:-N_UNFREEZE]:
    layer.trainable = False
for layer in base_model.layers[-N_UNFREEZE:]:
    # Keep BatchNormalization frozen (in TF2 this also runs it in inference mode) so its
    # moving statistics are not overwritten during fine-tuning.
    layer.trainable = not isinstance(layer, tf.keras.layers.BatchNormalization)

print(f"Trainable weight tensors after unfreezing: {len(best_mobilenet.trainable_weights)}")

# Recompile (required for trainable changes to take effect) with a lower learning rate.
# NOTE(review): categorical_crossentropy assumes one-hot labels — confirm against the preprocessor.
best_mobilenet.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Define callbacks.
# The checkpoint keeps the best-val_accuracy epoch on disk. EarlyStopping restores the
# best-val_loss weights in memory, so the two can differ; section 4.3 evaluates the file.
checkpoint = ModelCheckpoint(
    '../models/dental_cavity_mobilenetv2_finetuned_best.h5',
    monitor='val_accuracy',
    save_best_only=True,
    mode='max',
    verbose=1
)

early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
    verbose=1
)

reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=5,
    min_lr=1e-6,
    verbose=1
)

# Fine-tune the model
finetuned_history = best_mobilenet.fit(
    X_train_aug, y_train_aug,
    batch_size=32,
    epochs=20,
    validation_data=(X_val, y_val),
    callbacks=[checkpoint, early_stopping, reduce_lr]
)

# Plot training history
plot_learning_curves(
    finetuned_history.history,
    save_path='../results/mobilenet_finetuned_history.png'
)

## 4. Model Evaluation

Let's evaluate the performance of our models on the test set.

### 4.1 Custom CNN Model Evaluation

In [None]:
# Reload the custom CNN checkpoint from disk rather than using the in-memory `custom_model`.
best_custom_model = tf.keras.models.load_model('../models/dental_cavity_custom_best.h5')

# evaluate() returns [loss, *metrics]; this unpack assumes exactly one compiled metric (accuracy).
custom_loss, custom_accuracy = best_custom_model.evaluate(X_test, y_test)
for metric_label, metric_value in (("Loss", custom_loss), ("Accuracy", custom_accuracy)):
    print(f"Custom CNN - Test {metric_label}: {metric_value:.4f}")

# Per-class diagnostics on the same test split; figures are written under ../results.
custom_evaluator = ModelEvaluator(best_custom_model, X_test, y_test, class_names)
custom_evaluator.plot_confusion_matrix(
    normalize=True,
    save_path='../results/custom_confusion_matrix.png',
)
custom_evaluator.print_classification_report()
custom_evaluator.plot_roc_curve(save_path='../results/custom_roc_curve.png')
custom_evaluator.visualize_predictions(save_path='../results/custom_predictions.png')

### 4.2 MobileNetV2 Model Evaluation

In [None]:
# Reload the first-stage (pre-fine-tuning) MobileNetV2 checkpoint. Section 3.3 saves its
# results to a separate '_finetuned_best.h5' file, so this file is unaffected by fine-tuning.
best_mobilenet_model = tf.keras.models.load_model('../models/dental_cavity_mobilenetv2_best.h5')

# evaluate() returns [loss, *metrics]; this unpack assumes exactly one compiled metric (accuracy).
mobilenet_loss, mobilenet_accuracy = best_mobilenet_model.evaluate(X_test, y_test)
for metric_label, metric_value in (("Loss", mobilenet_loss), ("Accuracy", mobilenet_accuracy)):
    print(f"MobileNetV2 - Test {metric_label}: {metric_value:.4f}")

# Per-class diagnostics on the same test split; figures are written under ../results.
mobilenet_evaluator = ModelEvaluator(best_mobilenet_model, X_test, y_test, class_names)
mobilenet_evaluator.plot_confusion_matrix(
    normalize=True,
    save_path='../results/mobilenet_confusion_matrix.png',
)
mobilenet_evaluator.print_classification_report()
mobilenet_evaluator.plot_roc_curve(save_path='../results/mobilenet_roc_curve.png')
mobilenet_evaluator.visualize_predictions(save_path='../results/mobilenet_predictions.png')

### 4.3 Fine-tuned MobileNetV2 Model Evaluation

In [None]:
# Reload the fine-tuned checkpoint written by ModelCheckpoint in section 3.3 (best val_accuracy
# epoch). It can differ from the in-memory `best_mobilenet`, which EarlyStopping restores by val_loss.
best_finetuned_model = tf.keras.models.load_model('../models/dental_cavity_mobilenetv2_finetuned_best.h5')

# Section 3.3 compiled this model with metrics=['accuracy'], so evaluate() yields [loss, accuracy].
finetuned_loss, finetuned_accuracy = best_finetuned_model.evaluate(X_test, y_test)
for metric_label, metric_value in (("Loss", finetuned_loss), ("Accuracy", finetuned_accuracy)):
    print(f"Fine-tuned MobileNetV2 - Test {metric_label}: {metric_value:.4f}")

# Per-class diagnostics on the same test split; figures are written under ../results.
finetuned_evaluator = ModelEvaluator(best_finetuned_model, X_test, y_test, class_names)
finetuned_evaluator.plot_confusion_matrix(
    normalize=True,
    save_path='../results/finetuned_confusion_matrix.png',
)
finetuned_evaluator.print_classification_report()
finetuned_evaluator.plot_roc_curve(save_path='../results/finetuned_roc_curve.png')
finetuned_evaluator.visualize_predictions(save_path='../results/finetuned_predictions.png')

## 5. Model Comparison

Let's compare the performance of all models.

In [None]:
# Compare model performance using the test-set metrics computed in section 4.
from IPython.display import display

models = ['Custom CNN', 'MobileNetV2', 'Fine-tuned MobileNetV2']
accuracies = [custom_accuracy, mobilenet_accuracy, finetuned_accuracy]
losses = [custom_loss, mobilenet_loss, finetuned_loss]

# One row per model
comparison_df = pd.DataFrame({
    'Model': models,
    'Accuracy': accuracies,
    'Loss': losses
})

# Rich (HTML) table rendering instead of a plain-text print
print("Model Performance Comparison:")
display(comparison_df)

# Side-by-side bar charts using the explicit Axes API
fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 5))

sns.barplot(x='Model', y='Accuracy', data=comparison_df, ax=acc_ax)
acc_ax.set(title='Model Accuracy Comparison', ylim=(0, 1), xlabel='')

sns.barplot(x='Model', y='Loss', data=comparison_df, ax=loss_ax)
loss_ax.set(title='Model Loss Comparison', xlabel='')

# Long model names can crowd each other at this width; tilt the tick labels slightly.
for ax in (acc_ax, loss_ax):
    plt.setp(ax.get_xticklabels(), rotation=15, ha='right')

fig.tight_layout()
fig.savefig('../results/model_comparison.png')
plt.show()

## 6. Conclusion

In this notebook, we've trained and evaluated different models for dental cavity detection. We've compared a custom CNN model with transfer learning approaches using MobileNetV2, including fine-tuning.

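Key findings (confirm against the test metrics and comparison table above, since they depend on the data split and training run):
- Transfer learning with MobileNetV2 outperforms the custom CNN model
- Fine-tuning further improves the performance of the MobileNetV2 model
- Data augmentation is expected to improve model generalization. This notebook never trains a model without augmentation, however, so that effect is not measured directly here.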

If the comparison above confirms that it achieves the best performance, the fine-tuned MobileNetV2 model is the recommended choice for dental cavity detection.