# 🎓 Experimentos com Learned Wavelet (LearnedWaveletDWT1D_QMF)

## Objetivo
Avaliar o impacto de usar wavelets aprendidas (end-to-end) vs wavelets fixas:
- **LearnedWavelet + CNN**
- **LearnedWavelet + LSTM**
- **LearnedWavelet + Transformer**

## Hipótese
Wavelets aprendidas podem adaptar-se às características específicas do sinal,
potencialmente superando wavelets fixas como db2.

## Arquitetura
```
Input (raw signal) -> LearnedWaveletDWT1D_QMF -> [CNN/LSTM/Transformer] -> Output
```

A camada LearnedWaveletDWT1D_QMF aprende os filtros low/high pass durante o treinamento.

In [None]:
# Imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
import warnings
warnings.filterwarnings('ignore')

# TensorFlow
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, Dense, Dropout, Conv1D, MaxPooling1D, LSTM,
    Flatten, BatchNormalization, GlobalAveragePooling1D
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
print(f"TensorFlow version: {tf.__version__}")

# Imports locais - modelos LWT
import sys
sys.path.append('.')
sys.path.append('../../models')
from LWT import LearnedWaveletDWT1D_QMF, LearnedWaveletPair1D_QMF

from src.models import get_callbacks, TransformerBlock
from src.evaluation import RegressionEvaluator, ResultsManager
from src.visualization import ExperimentVisualizer
from config.experiment_config import (
    DATA_DIR, RESULTS_DIR, MODELS_DIR,
    DL_TRAINING_CONFIG, LEARNED_WAVELET_CONFIG
)

# Plot style and output directory used by the rest of the notebook.
plt.style.use('seaborn-v0_8-whitegrid')
_experiment_dir = RESULTS_DIR / "learned_wavelet_experiments"
_experiment_dir.mkdir(parents=True, exist_ok=True)

# Status banners shown once the statements above have run.
print("\n‚úÖ Imports realizados com sucesso!")
print("\nüì¶ LearnedWaveletDWT1D_QMF carregado")

## 1. Carregar Dados

In [None]:
# Load the raw train / validation / test arrays stored as .npy files in DATA_DIR.
X_train, y_train = np.load(DATA_DIR / "X_train.npy"), np.load(DATA_DIR / "y_train.npy")
X_val, y_val = np.load(DATA_DIR / "X_val.npy"), np.load(DATA_DIR / "y_val.npy")
X_test, y_test = np.load(DATA_DIR / "X_test.npy"), np.load(DATA_DIR / "y_test.npy")

# Append a trailing length-1 axis to each feature array (targets are left as loaded).
X_train = np.expand_dims(X_train, axis=-1)
X_val = np.expand_dims(X_val, axis=-1)
X_test = np.expand_dims(X_test, axis=-1)

print(f"üì¶ Dados Carregados (Raw + Canal):")
print(f"  Train: {X_train.shape}")
print(f"  Val:   {X_val.shape}")
print(f"  Test:  {X_test.shape}")

# Per-sample shape: everything after the leading axis of X_train.
input_shape = X_train.shape[1:]
print(f"\nInput shape: {input_shape}")

## 2. Configuração das Learned Wavelets

In [None]:
# Copies of the imported wavelet and training configurations.
wavelet_config = LEARNED_WAVELET_CONFIG.copy()
training_config = DL_TRAINING_CONFIG.copy()

# Echo every wavelet setting, one per line.
print("Configura√ß√£o LearnedWaveletDWT1D_QMF:")
for option, setting in wavelet_config.items():
    print(f"  {option}: {setting}")

# Project helpers for result logging, metric computation and plotting.
results_manager = ResultsManager(RESULTS_DIR / "learned_wavelet_experiments")
evaluator = RegressionEvaluator()
visualizer = ExperimentVisualizer()

# Containers filled by each experiment cell, keyed by model name.
all_results, all_histories = {}, {}

## 3. Funções para Criar Modelos com Learned Wavelet

In [None]:
def create_learned_wavelet_cnn(input_shape, wavelet_config, learning_rate=0.001):
    """Build and compile a LearnedWaveletDWT1D_QMF front end followed by a 1-D CNN.

    Args:
        input_shape: Shape handed to ``Input`` (batch axis excluded).
        wavelet_config: Mapping queried with ``.get`` for ``levels``,
            ``kernel_size``, ``wavelet_net_units``, ``reg_energy``,
            ``reg_high_dc`` and ``reg_smooth``; missing keys fall back to
            3, 32, 32, 1e-2, 1e-2 and 1e-3 respectively.
        learning_rate: Learning rate given to the Adam optimizer.

    Returns:
        A compiled Keras ``Model`` named ``LearnedWavelet_CNN`` with a single
        linear output unit, MSE loss and an MAE metric.
    """
    signal_in = Input(shape=input_shape)

    # Learned wavelet layer, always built with mode="concat".
    wavelet_kwargs = dict(
        levels=wavelet_config.get('levels', 3),
        kernel_size=wavelet_config.get('kernel_size', 32),
        wavelet_net_units=wavelet_config.get('wavelet_net_units', 32),
        mode="concat",
        reg_energy=wavelet_config.get('reg_energy', 1e-2),
        reg_high_dc=wavelet_config.get('reg_high_dc', 1e-2),
        reg_smooth=wavelet_config.get('reg_smooth', 1e-3),
    )
    features = LearnedWaveletDWT1D_QMF(**wavelet_kwargs)(signal_in)

    # Two Conv -> BatchNorm -> MaxPool -> Dropout stages.
    for n_filters, kernel_width in ((64, 7), (128, 5)):
        features = Conv1D(
            n_filters, kernel_width, activation='relu', padding='same',
            kernel_regularizer=l2(0.001),
        )(features)
        features = BatchNormalization()(features)
        features = MaxPooling1D(2)(features)
        features = Dropout(0.3)(features)

    # Last convolution stage ends in global average pooling instead of max pooling.
    features = Conv1D(
        256, 3, activation='relu', padding='same', kernel_regularizer=l2(0.001)
    )(features)
    features = BatchNormalization()(features)
    features = GlobalAveragePooling1D()(features)

    # Dense head: two regularized ReLU layers, each followed by dropout.
    for n_units in (128, 64):
        features = Dense(n_units, activation='relu', kernel_regularizer=l2(0.001))(features)
        features = Dropout(0.3)(features)

    prediction = Dense(1)(features)

    cnn_model = Model(inputs=signal_in, outputs=prediction, name='LearnedWavelet_CNN')
    cnn_model.compile(
        optimizer=Adam(learning_rate=learning_rate), loss='mse', metrics=['mae']
    )
    return cnn_model


def create_learned_wavelet_lstm(input_shape, wavelet_config, learning_rate=0.001):
    """Build and compile a LearnedWaveletDWT1D_QMF front end followed by stacked LSTMs.

    Args:
        input_shape: Shape handed to ``Input`` (batch axis excluded).
        wavelet_config: Mapping queried with ``.get`` for ``levels``,
            ``kernel_size``, ``wavelet_net_units``, ``reg_energy``,
            ``reg_high_dc`` and ``reg_smooth``; missing keys fall back to
            3, 32, 32, 1e-2, 1e-2 and 1e-3 respectively.
        learning_rate: Learning rate given to the Adam optimizer.

    Returns:
        A compiled Keras ``Model`` named ``LearnedWavelet_LSTM`` with a single
        linear output unit, MSE loss and an MAE metric.
    """
    signal_in = Input(shape=input_shape)
    lookup = wavelet_config.get

    # Learned wavelet layer, always built with mode="concat".
    dwt_layer = LearnedWaveletDWT1D_QMF(
        levels=lookup('levels', 3),
        kernel_size=lookup('kernel_size', 32),
        wavelet_net_units=lookup('wavelet_net_units', 32),
        mode="concat",
        reg_energy=lookup('reg_energy', 1e-2),
        reg_high_dc=lookup('reg_high_dc', 1e-2),
        reg_smooth=lookup('reg_smooth', 1e-3),
    )
    coefficients = dwt_layer(signal_in)

    # First LSTM returns the full sequence; the second returns only its last output.
    sequence = LSTM(
        128, return_sequences=True, dropout=0.3, kernel_regularizer=l2(0.001)
    )(coefficients)
    summary_vec = LSTM(
        64, return_sequences=False, dropout=0.3, kernel_regularizer=l2(0.001)
    )(sequence)

    # Dense head with dropout, then the single regression unit.
    hidden = Dense(64, activation='relu', kernel_regularizer=l2(0.001))(summary_vec)
    hidden = Dropout(0.3)(hidden)
    prediction = Dense(1)(hidden)

    lstm_model = Model(inputs=signal_in, outputs=prediction, name='LearnedWavelet_LSTM')
    lstm_model.compile(
        optimizer=Adam(learning_rate=learning_rate), loss='mse', metrics=['mae']
    )
    return lstm_model


def create_learned_wavelet_transformer(input_shape, wavelet_config, learning_rate=0.001):
    """Build and compile a LearnedWaveletDWT1D_QMF front end followed by Transformer blocks.

    Args:
        input_shape: Shape handed to ``Input`` (batch axis excluded).
        wavelet_config: Mapping queried with ``.get`` for ``levels``,
            ``kernel_size``, ``wavelet_net_units``, ``reg_energy``,
            ``reg_high_dc`` and ``reg_smooth``; missing keys fall back to
            3, 32, 32, 1e-2, 1e-2 and 1e-3 respectively.
        learning_rate: Learning rate given to the Adam optimizer.

    Returns:
        A compiled Keras ``Model`` named ``LearnedWavelet_Transformer`` with a
        single linear output unit, MSE loss and an MAE metric.
    """
    signal_in = Input(shape=input_shape)

    # Learned wavelet layer, always built with mode="concat".
    wavelet_kwargs = dict(
        levels=wavelet_config.get('levels', 3),
        kernel_size=wavelet_config.get('kernel_size', 32),
        wavelet_net_units=wavelet_config.get('wavelet_net_units', 32),
        mode="concat",
        reg_energy=wavelet_config.get('reg_energy', 1e-2),
        reg_high_dc=wavelet_config.get('reg_high_dc', 1e-2),
        reg_smooth=wavelet_config.get('reg_smooth', 1e-3),
    )
    tokens = LearnedWaveletDWT1D_QMF(**wavelet_kwargs)(signal_in)

    # Linear projection to 64 * 4 = 256 features before the Transformer blocks.
    # NOTE(review): presumably chosen to match head_size * num_heads — confirm in TransformerBlock.
    tokens = Dense(64 * 4)(tokens)

    # Two identically configured Transformer blocks applied back to back.
    for _ in range(2):
        tokens = TransformerBlock(head_size=64, num_heads=4, ff_dim=128, dropout=0.2)(tokens)

    pooled = GlobalAveragePooling1D()(tokens)

    # Dense head: two ReLU layers, each followed by dropout.
    for n_units in (128, 64):
        pooled = Dense(n_units, activation='relu')(pooled)
        pooled = Dropout(0.2)(pooled)

    prediction = Dense(1)(pooled)

    transformer_model = Model(
        inputs=signal_in, outputs=prediction, name='LearnedWavelet_Transformer'
    )
    transformer_model.compile(
        optimizer=Adam(learning_rate=learning_rate), loss='mse', metrics=['mae']
    )
    return transformer_model

# Status banner printed once the cell defining the model builders has run.
print("‚úÖ Fun√ß√µes de cria√ß√£o de modelos definidas")

## 4. Experimento 1: LearnedWavelet + CNN

In [None]:
_rule = "=" * 70
print(_rule)
print("üéì Experimento: LearnedWaveletDWT1D_QMF + CNN")
print(_rule)

# Reset Keras' global state before building this experiment's model.
tf.keras.backend.clear_session()

# Build the network and print its layer summary.
model_lwt_cnn = create_learned_wavelet_cnn(input_shape, wavelet_config)
model_lwt_cnn.summary()

# Callbacks come from the project helper, which receives the checkpoint path
# and the two patience values.
model_path = str(MODELS_DIR / "learned_wavelet_cnn_best.keras")
callbacks = get_callbacks(model_path, patience_early=15, patience_lr=7)

# Fit the model, measuring the wall-clock duration of the fit() call.
_fit_options = dict(
    validation_data=(X_val, y_val),
    epochs=training_config['epochs'],
    batch_size=training_config['batch_size'],
    callbacks=callbacks,
    verbose=1,
)
t0 = time.time()
history_lwt_cnn = model_lwt_cnn.fit(X_train, y_train, **_fit_options)
elapsed = time.time() - t0

# Test-set predictions flattened to 1-D, then scored by the evaluator.
y_pred_lwt_cnn = model_lwt_cnn.predict(X_test, verbose=0).flatten()
lwt_cnn_metrics = evaluator.evaluate(y_test, y_pred_lwt_cnn)

print(f"\nüìä Resultados LearnedWavelet + CNN:")
print(f"  RMSE: {lwt_cnn_metrics['rmse']:.6f}")
print(f"  MAE:  {lwt_cnn_metrics['mae']:.6f}")
print(f"  R¬≤:   {lwt_cnn_metrics['r2']:.6f}")
print(f"  Tempo: {elapsed:.2f}s")

# Keep everything later cells need: metrics, timing, epoch count, predictions,
# the model object and its parameter count.
all_histories['LearnedWavelet_CNN'] = history_lwt_cnn.history
all_results['LearnedWavelet_CNN'] = {
    'metrics': lwt_cnn_metrics,
    'time': elapsed,
    'epochs': len(history_lwt_cnn.history['loss']),
    'y_pred': y_pred_lwt_cnn,
    'model': model_lwt_cnn,
    'params': model_lwt_cnn.count_params(),
}

results_manager.log_experiment(
    'DL_LearnedWavelet', 'CNN', lwt_cnn_metrics, {'wavelet_config': wavelet_config}
)

## 5. Experimento 2: LearnedWavelet + LSTM

In [None]:
print("="*70)
print("üéì Experimento: LearnedWaveletDWT1D_QMF + LSTM")
print("="*70)

# Reset Keras' global state before building this experiment's model.
tf.keras.backend.clear_session()

# Build the model and print its layer summary.
model_lwt_lstm = create_learned_wavelet_lstm(input_shape, wavelet_config)
model_lwt_lstm.summary()

# Callbacks come from the project helper, which receives the checkpoint path
# and the two patience values.
model_path = str(MODELS_DIR / "learned_wavelet_lstm_best.keras")
callbacks = get_callbacks(model_path, patience_early=15, patience_lr=7)

# Train, measuring the wall-clock duration of the fit() call.
t0 = time.time()
history_lwt_lstm = model_lwt_lstm.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=training_config['epochs'],
    batch_size=training_config['batch_size'],
    callbacks=callbacks,
    verbose=1
)
elapsed = time.time() - t0

# Predict on the test split with the weights held in memory after fit().
# NOTE(review): whether these are the best-epoch weights depends on what
# get_callbacks configures — confirm in src.models.
y_pred_lwt_lstm = model_lwt_lstm.predict(X_test, verbose=0).flatten()

# Score the predictions.
lwt_lstm_metrics = evaluator.evaluate(y_test, y_pred_lwt_lstm)

print(f"\nüìä Resultados LearnedWavelet + LSTM:")
print(f"  RMSE: {lwt_lstm_metrics['rmse']:.6f}")
print(f"  MAE:  {lwt_lstm_metrics['mae']:.6f}")
print(f"  R¬≤:   {lwt_lstm_metrics['r2']:.6f}")
# Report the training time in the same format as the CNN experiment; `elapsed`
# was measured here but previously never shown.
print(f"  Tempo: {elapsed:.2f}s")

# Keep everything later cells need: metrics, timing, epoch count, predictions,
# the model object and its parameter count.
all_results['LearnedWavelet_LSTM'] = {
    'metrics': lwt_lstm_metrics,
    'time': elapsed,
    'epochs': len(history_lwt_lstm.history['loss']),
    'y_pred': y_pred_lwt_lstm,
    'model': model_lwt_lstm,
    'params': model_lwt_lstm.count_params()
}
all_histories['LearnedWavelet_LSTM'] = history_lwt_lstm.history

results_manager.log_experiment(
    'DL_LearnedWavelet', 'LSTM', lwt_lstm_metrics,
    {'wavelet_config': wavelet_config}
)

## 6. Experimento 3: LearnedWavelet + Transformer

In [None]:
print("="*70)
print("üéì Experimento: LearnedWaveletDWT1D_QMF + Transformer")
print("="*70)

# Reset Keras' global state before building this experiment's model.
tf.keras.backend.clear_session()

# Build the model and print its layer summary.
model_lwt_transformer = create_learned_wavelet_transformer(input_shape, wavelet_config)
model_lwt_transformer.summary()

# Callbacks come from the project helper, which receives the checkpoint path
# and the two patience values.
model_path = str(MODELS_DIR / "learned_wavelet_transformer_best.keras")
callbacks = get_callbacks(model_path, patience_early=15, patience_lr=7)

# Train, measuring the wall-clock duration of the fit() call.
t0 = time.time()
history_lwt_transformer = model_lwt_transformer.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=training_config['epochs'],
    batch_size=training_config['batch_size'],
    callbacks=callbacks,
    verbose=1
)
elapsed = time.time() - t0

# Predict on the test split with the weights held in memory after fit().
# NOTE(review): whether these are the best-epoch weights depends on what
# get_callbacks configures — confirm in src.models.
y_pred_lwt_transformer = model_lwt_transformer.predict(X_test, verbose=0).flatten()

# Score the predictions.
lwt_transformer_metrics = evaluator.evaluate(y_test, y_pred_lwt_transformer)

print(f"\nüìä Resultados LearnedWavelet + Transformer:")
print(f"  RMSE: {lwt_transformer_metrics['rmse']:.6f}")
print(f"  MAE:  {lwt_transformer_metrics['mae']:.6f}")
print(f"  R¬≤:   {lwt_transformer_metrics['r2']:.6f}")
# Report the training time in the same format as the CNN experiment; `elapsed`
# was measured here but previously never shown.
print(f"  Tempo: {elapsed:.2f}s")

# Keep everything later cells need: metrics, timing, epoch count, predictions,
# the model object and its parameter count.
all_results['LearnedWavelet_Transformer'] = {
    'metrics': lwt_transformer_metrics,
    'time': elapsed,
    'epochs': len(history_lwt_transformer.history['loss']),
    'y_pred': y_pred_lwt_transformer,
    'model': model_lwt_transformer,
    'params': model_lwt_transformer.count_params()
}
all_histories['LearnedWavelet_Transformer'] = history_lwt_transformer.history

results_manager.log_experiment(
    'DL_LearnedWavelet', 'Transformer', lwt_transformer_metrics,
    {'wavelet_config': wavelet_config}
)

## 7. Visualização dos Filtros Aprendidos

In [None]:
# Extrair e visualizar filtros aprendidos do melhor modelo
def extract_learned_filters(model):
    """Return the low/high-pass filters of the first learned-wavelet layer in ``model``.

    A layer is selected when its name contains ``learned_wavelet``
    (case-insensitive) or, failing that, when it exposes a ``pairs`` attribute.
    The second test keeps the lookup working for a wavelet layer that was
    created with a custom ``name=``.

    Args:
        model: Object exposing a ``layers`` iterable; each layer must have a
            ``name`` string.

    Returns:
        A list with one dict per entry of the selected layer's ``pairs``
        (keys ``level``, ``low_pass``, ``high_pass``, ``scale``,
        ``translation``; ``level`` counts from 1), or ``None`` when no layer
        is selected.
    """
    for layer in model.layers:
        is_wavelet_layer = (
            'learned_wavelet' in layer.name.lower() or hasattr(layer, 'pairs')
        )
        if not is_wavelet_layer:
            continue

        filters_info = []
        for level, pair in enumerate(layer.pairs, start=1):
            # Recompute the filter taps from the pair's own pieces: sample
            # points from _make_t(), shifted by `translation` and divided by
            # softplus(raw_scale) + 1e-3, then base_net -> low_head ->
            # _normalize_h for the low-pass and _qmf_from_h for the high-pass.
            # NOTE(review): presumably mirrors the layer's own forward pass —
            # confirm against the LWT module.
            t = pair._make_t()
            scale = tf.nn.softplus(pair.raw_scale) + 1e-3
            t_adj = (t - pair.translation) / scale

            z = pair.base_net(t_adj)
            h = pair._normalize_h(pair.low_head(z))
            g = pair._qmf_from_h(h)

            filters_info.append({
                'level': level,
                'low_pass': h.numpy().flatten(),
                'high_pass': g.numpy().flatten(),
                'scale': scale.numpy(),
                'translation': pair.translation.numpy()
            })
        # Only the first selected layer is reported.
        return filters_info
    return None

# Use the CNN model for the filter visualisation.
filters = extract_learned_filters(all_results['LearnedWavelet_CNN']['model'])

if filters:
    n_levels = len(filters)
    # squeeze=False keeps `axes` two-dimensional even when n_levels == 1, so
    # the axes[i, col] indexing below is valid for any number of levels.
    fig, axes = plt.subplots(n_levels, 2, figsize=(14, 4*n_levels), squeeze=False)
    
    for i, filt in enumerate(filters):
        # Left column: low-pass filter taps.
        axes[i, 0].plot(filt['low_pass'], 'b-', linewidth=2)
        axes[i, 0].set_title(f'N√≠vel {filt["level"]} - Filtro Low-Pass (h)')
        axes[i, 0].set_xlabel('Coeficiente')
        axes[i, 0].set_ylabel('Amplitude')
        axes[i, 0].grid(True, alpha=0.3)
        axes[i, 0].axhline(y=0, color='r', linestyle='--', alpha=0.5)
        
        # Right column: high-pass filter taps.
        axes[i, 1].plot(filt['high_pass'], 'r-', linewidth=2)
        axes[i, 1].set_title(f'N√≠vel {filt["level"]} - Filtro High-Pass (g) [QMF]')
        axes[i, 1].set_xlabel('Coeficiente')
        axes[i, 1].set_ylabel('Amplitude')
        axes[i, 1].grid(True, alpha=0.3)
        axes[i, 1].axhline(y=0, color='b', linestyle='--', alpha=0.5)
    
    plt.suptitle('Filtros Wavelet Aprendidos (LearnedWaveletDWT1D_QMF)', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(RESULTS_DIR / "learned_wavelet_experiments" / "learned_filters.png", dpi=150, bbox_inches='tight')
    plt.show()
    
    print("\nüìä Par√¢metros dos Filtros Aprendidos:")
    for filt in filters:
        # .item() converts a size-1 array (0-d or shape (1,)) to a Python
        # scalar; formatting a shape-(1,) ndarray with ':.4f' raises TypeError.
        scale_value = np.asarray(filt['scale']).item()
        translation_value = np.asarray(filt['translation']).item()
        print(f"  N√≠vel {filt['level']}: scale={scale_value:.4f}, translation={translation_value:.4f}")
else:
    # Reached when no filters were returned (None or an empty list).
    print("‚ö†Ô∏è N√£o foi poss√≠vel extrair os filtros")

## 8. Comparação dos Resultados

In [None]:
# One summary row per trained model, taken from the stored experiment results.
comparison_data = [
    {
        'Model': model_name,
        'RMSE': result['metrics']['rmse'],
        'MAE': result['metrics']['mae'],
        'R¬≤': result['metrics']['r2'],
        'Params': result['params'],
        'Time (s)': result['time'],
        'Epochs': result['epochs'],
    }
    for model_name, result in all_results.items()
]

# Rank the models by ascending RMSE.
comparison_df = pd.DataFrame(comparison_data).sort_values('RMSE')

_rule = "=" * 70
print("\n" + _rule)
print("üìä COMPARA√á√ÉO - Modelos com LearnedWaveletDWT1D_QMF")
print(_rule)
print(comparison_df.to_string(index=False))

# Write the ranked table to CSV next to the other experiment artifacts.
_comparison_csv = RESULTS_DIR / "learned_wavelet_experiments" / "comparison_learned_wavelet.csv"
comparison_df.to_csv(_comparison_csv, index=False)

In [None]:
# Side-by-side horizontal bar charts, one panel per metric.
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

metrics_to_plot = ['RMSE', 'MAE', 'R¬≤']
colors = plt.cm.Purples(np.linspace(0.4, 0.9, len(comparison_df)))
_by_model = comparison_df.set_index('Model')

for ax, metric in zip(axes, metrics_to_plot):
    # Every metric is sorted ascending except the R-squared column, which is
    # sorted descending.
    _ascending = metric != 'R¬≤'
    data = _by_model[metric].sort_values(ascending=_ascending)

    bars = ax.barh(data.index, data.values, color=colors)
    ax.set_xlabel(metric)
    ax.set_title(f'Compara√ß√£o: {metric}')
    ax.grid(True, alpha=0.3, axis='x')

    # Write each value at the end of its bar, vertically centred on it.
    for bar, val in zip(bars, data.values):
        _bar_centre = bar.get_y() + bar.get_height()/2
        ax.text(val, _bar_centre, f'{val:.4f}', va='center', ha='left', fontsize=9)

plt.suptitle('Learned Wavelet (LearnedWaveletDWT1D_QMF)', fontsize=14, fontweight='bold')
plt.tight_layout()
_comparison_png = RESULTS_DIR / "learned_wavelet_experiments" / "comparison_learned_wavelet.png"
plt.savefig(_comparison_png, dpi=150, bbox_inches='tight')
plt.show()

## 9. Análise de Predições

In [None]:
# The first row of the RMSE-sorted table is the best model.
best_model_name = comparison_df['Model'].iloc[0]
best_result = all_results[best_model_name]

print(f"\nüèÜ Melhor Modelo: {best_model_name}")

# Plot true vs. predicted targets for the best model and save the figure.
_predictions_png = (
    RESULTS_DIR / "learned_wavelet_experiments" / f"predictions_{best_model_name}.png"
)
fig = visualizer.plot_prediction_comparison(
    y_test,
    best_result['y_pred'],
    model_name=best_model_name,
    n_samples=500,
    save_path=_predictions_png,
)
plt.show()

## 10. Resumo

In [None]:
# Final recap: model count, best model and its headline metrics, output folder.
_rule = "=" * 70
_top_row = comparison_df.iloc[0]

print("\n" + _rule)
print("üìã RESUMO - Experimentos com Learned Wavelets")
print(_rule)
print(f"\n‚úÖ Modelos avaliados: {len(all_results)}")
print(f"‚úÖ Melhor modelo: {best_model_name}")
print(f"‚úÖ Melhor RMSE: {_top_row['RMSE']:.6f}")
print(f"‚úÖ Melhor R¬≤: {_top_row['R¬≤']:.6f}")
print(f"\nüìÅ Resultados salvos em: {RESULTS_DIR / 'learned_wavelet_experiments'}")
print("\nüéâ Notebook conclu√≠do com sucesso!")