# In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import json
from utils.model_utils import create_cnn_model, decode_predictions
from utils.data_loader import LibriSpeechDataLoader
from utils.evaluation import ModelEvaluator
import config

# Set random seeds
# Seed the NumPy and TensorFlow RNGs so repeated runs start from the same state.
np.random.seed(42)
tf.random.set_seed(42)


def _load_npy_object(path):
    """Return the Python object stored (pickled) in a single-element .npy file."""
    return np.load(path, allow_pickle=True).item()


# Load mappings and feature info.
# NOTE(review): allow_pickle=True can execute arbitrary code on load; only use
# this with trusted, locally produced artifacts.
char_to_num = _load_npy_object('char_to_num.npy')
num_to_char = _load_npy_object('num_to_char.npy')
feature_info = _load_npy_object('feature_info.npy')

print(
    "CNN Model Training",
    "=" * 50,
    f"Vocabulary size: {len(char_to_num)}",
    f"Available feature types: {list(feature_info)}",
    sep="\n",
)

# Build the data loader from config, then attach the vocabulary mappings
# loaded above to it.
data_loader = LibriSpeechDataLoader(config.DATASET_CONFIG)
data_loader.char_to_num = char_to_num
data_loader.num_to_char = num_to_char

# Evaluator shares the same character <-> index mappings.
evaluator = ModelEvaluator(char_to_num, num_to_char)

# Train CNN models for each feature type
cnn_models = {}     # feature_type -> trained keras model
cnn_histories = {}  # feature_type -> History.history dict (metric name -> per-epoch values)

# ModelCheckpoint below writes its .h5 file into models/; make sure the
# directory exists so saving the first checkpoint cannot fail on a fresh run.
os.makedirs('models', exist_ok=True)


def prepare_ctc_data(dataset):
    """Map a ``(features, labels)`` dataset onto the named inputs of the CTC model.

    Each element becomes ``({'input', 'y_true', 'input_length', 'label_length'},
    dummy_target)``. The labels are fed to the model as its ``y_true`` input and
    the Keras target is an all-zeros placeholder, so presumably the CTC loss is
    computed inside the model (confirm in ``create_cnn_model``).

    Args:
        dataset: a batched ``tf.data.Dataset`` of ``(features, labels)`` pairs.
            Assumes features are (batch, time, dim) and labels are
            (batch, max_label_len) -- TODO confirm against the data pipeline.

    Returns:
        The mapped ``tf.data.Dataset``.
    """
    def add_ctc_inputs(features, labels):
        batch_size = tf.shape(features)[0]
        # tf.shape() yields int32 while tf.ones() is float32; TensorFlow does not
        # promote dtypes implicitly, so cast before multiplying.
        time_steps = tf.cast(tf.shape(features)[1], tf.float32)
        label_width = tf.cast(tf.shape(labels)[1], tf.float32)
        # NOTE(review): assumes the model emits one output frame per input frame
        # (no temporal pooling/striding); otherwise input_length must be scaled.
        input_length = tf.ones((batch_size, 1), dtype=tf.float32) * time_steps
        # NOTE(review): if labels are padded per batch, this counts padding as
        # label characters for every example -- consider per-example lengths
        # computed from the pad value. TODO confirm how labels are padded.
        label_length = tf.ones((batch_size, 1), dtype=tf.float32) * label_width
        dummy_output = tf.zeros(batch_size)

        return {
            'input': features,
            'y_true': labels,
            'input_length': input_length,
            'label_length': label_length
        }, dummy_output

    return dataset.map(add_ctc_inputs)


def _to_json_safe(value):
    """``json.dump`` fallback converting NumPy scalars/arrays to builtin types."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f'Object of type {type(value).__name__} is not JSON serializable')


for feature_type in config.FEATURE_CONFIG['feature_types']:
    print(f"\n{'='*50}")
    print(f"Training CNN with {feature_type} features")
    print(f"{'='*50}")

    # Load feature datasets (pickled dict with 'train'/'val'/'test' entries).
    # NOTE(review): assumes the entries are tf.data.Dataset objects, since
    # .map() is called on them below -- confirm. allow_pickle=True can execute
    # arbitrary code, so only load trusted, locally produced files.
    feature_datasets = np.load(f'feature_datasets_{feature_type}.npy', allow_pickle=True).item()
    train_ds = feature_datasets['train']
    val_ds = feature_datasets['val']
    test_ds = feature_datasets['test']

    # Get model dimensions
    input_dim = feature_info[feature_type]['input_dim']
    output_dim = len(char_to_num)

    print(f"Input dimension: {input_dim}")
    print(f"Output dimension: {output_dim}")

    # Create CNN model
    model = create_cnn_model(
        input_dim=input_dim,
        output_dim=output_dim,
        filters=config.MODEL_CONFIG['cnn']['filters'],
        dropout_rate=config.MODEL_CONFIG['cnn']['dropout_rate'],
        learning_rate=config.MODEL_CONFIG['cnn']['learning_rate']
    )

    print("Model architecture:")
    model.summary()

    # Callbacks: early stopping (restoring the best weights), LR decay on
    # plateau, best-checkpoint saving by val_loss, and TensorBoard logging.
    callbacks = [
        keras.callbacks.EarlyStopping(
            patience=config.TRAINING_CONFIG['patience'],
            restore_best_weights=True,
            verbose=1
        ),
        keras.callbacks.ReduceLROnPlateau(
            factor=config.TRAINING_CONFIG['reduce_lr_factor'],
            patience=config.TRAINING_CONFIG['reduce_lr_patience'],
            verbose=1
        ),
        keras.callbacks.ModelCheckpoint(
            f'models/cnn_{feature_type}_best.h5',
            save_best_only=True,
            monitor='val_loss',
            verbose=1
        ),
        keras.callbacks.TensorBoard(
            log_dir=f'logs/cnn_{feature_type}',
            histogram_freq=1
        )
    ]

    # Prepare datasets for CTC training
    train_ctc_ds = prepare_ctc_data(train_ds)
    val_ctc_ds = prepare_ctc_data(val_ds)

    # Train model
    print(f"\nTraining CNN with {feature_type} features...")
    history = model.fit(
        train_ctc_ds,
        epochs=config.TRAINING_CONFIG['epochs'],
        validation_data=val_ctc_ds,
        callbacks=callbacks,
        verbose=1
    )

    # Store model and history
    cnn_models[feature_type] = model
    cnn_histories[feature_type] = history.history

    # Plot training history
    plt.figure(figsize=(15, 5))

    plt.subplot(1, 2, 1)
    plt.plot(history.history['loss'], label='Training Loss')
    plt.plot(history.history['val_loss'], label='Validation Loss')
    plt.title(f'CNN with {feature_type} - Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    if 'accuracy' in history.history:
        plt.plot(history.history['accuracy'], label='Training Accuracy')
        # val_accuracy only exists when validation metrics were computed.
        if 'val_accuracy' in history.history:
            plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
        plt.title(f'CNN with {feature_type} - Accuracy')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()

    plt.tight_layout()
    plt.savefig(f'cnn_{feature_type}_training.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Evaluate on test set
    print(f"\nEvaluating CNN with {feature_type} on test set...")
    test_ctc_ds = prepare_ctc_data(test_ds)
    test_results = model.evaluate(test_ctc_ds, verbose=0)
    # evaluate() returns a scalar for a loss-only model, but a list
    # [loss, *metrics] when metrics are compiled in; the loss is always first.
    test_loss = test_results[0] if isinstance(test_results, (list, tuple)) else test_results
    print(f"Test Loss: {test_loss:.4f}")

    # Save training history. Values may be NumPy scalars (e.g. a logged
    # learning rate), which json rejects without a conversion fallback.
    with open(f'cnn_{feature_type}_history.json', 'w') as f:
        json.dump(history.history, f, indent=2, default=_to_json_safe)

    print(f"✓ CNN with {feature_type} training completed!")

# Compare all CNN models
# Tabulate and plot validation loss for every model trained above, then
# write a compact JSON summary.
print("\n" + "=" * 50, "CNN Models Comparison", "=" * 50, sep="\n")

# One row per feature type: best and last validation loss plus epochs run.
comparison_data = []
for feature_type, run in cnn_histories.items():
    val_losses = run['val_loss']
    min_val_loss, final_val_loss = min(val_losses), val_losses[-1]
    comparison_data.append({
        'feature_type': feature_type,
        'min_val_loss': min_val_loss,
        'final_val_loss': final_val_loss,
        'epochs': len(val_losses),
    })

# Create comparison DataFrame
import pandas as pd
comparison_df = pd.DataFrame(comparison_data)
print("\nCNN Models Performance Comparison:")
print(comparison_df.to_string(index=False))

# Visualize comparison: loss curves on the left, final losses as bars on the right.
plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
for feature_type, history in cnn_histories.items():
    plt.plot(history['val_loss'], label=feature_type, alpha=0.7)
plt.title('CNN Models - Validation Loss Comparison')
plt.xlabel('Epoch')
plt.ylabel('Validation Loss')
plt.grid(True, alpha=0.3)
plt.legend()

plt.subplot(1, 2, 2)
feature_types = list(cnn_histories)
# comparison_data rows follow cnn_histories order, so these align with feature_types.
final_losses = [row['final_val_loss'] for row in comparison_data]
plt.bar(feature_types, final_losses, alpha=0.7)
plt.title('CNN Models - Final Validation Loss')
plt.ylabel('Validation Loss')
plt.xticks(rotation=45)

plt.tight_layout()
plt.savefig('cnn_models_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

# Save models summary
cnn_summary = {
    'feature_types': list(cnn_models),
    'input_dims': {ft: feature_info[ft]['input_dim'] for ft in cnn_models},
    'best_val_loss': {row['feature_type']: row['min_val_loss'] for row in comparison_data},
}

with open('cnn_models_summary.json', 'w') as f:
    json.dump(cnn_summary, f, indent=2)

print(
    "\nCNN training completed for all feature types!",
    "Best models saved in 'models/' directory",
    "Training histories saved as JSON files",
    "Visualizations saved as PNG files",
    sep="\n",
)