In [3]:
# test_phase_4c_final_workflow.ipynb

import numpy as np
import pandas as pd
from neural_mi.data.processors import ContinuousProcessor
from neural_mi.analysis.workflow import AnalysisWorkflow

if __name__ == '__main__':
    # End-to-end smoke test for the Phase 4c workflow: build two correlated
    # signals, window them with ContinuousProcessor, run AnalysisWorkflow over
    # a range of gammas, and check that the single returned summary contains a
    # positive bias-corrected MI estimate plus the expected diagnostic fields.
    import torch  # only used here to seed torch's RNG for reproducibility

    print("--- Testing Phase 4c: The Final Rigorous Workflow ---")
    device = 'cpu'

    # Fix seeds so the synthetic data (and therefore the final
    # `mi_corrected > 0` assertion) is reproducible across runs.
    # NOTE(review): n_workers > 1 may train in worker processes whose RNG state
    # is not derived from this seed — confirm how AnalysisWorkflow seeds workers.
    seed = 42
    np.random.seed(seed)
    torch.manual_seed(seed)

    # 1. Generate Data: y is x plus independent Gaussian noise (std 0.5),
    # so X and Y share information and the true MI is strictly positive.
    n_channels, n_samples = 5, 4000
    x_raw = np.random.randn(n_channels, n_samples)
    y_raw = x_raw + np.random.randn(n_channels, n_samples) * 0.5
    cont_proc = ContinuousProcessor(window_size=100, step_size=50)
    x_data = cont_proc.process(x_raw).to(device)
    y_data = cont_proc.process(y_raw).to(device)
    print(f"Data shape: {x_data.shape}")
    print("-" * 40)

    # 2. Define Parameters for the Rigorous Workflow
    base_params = {
        'hidden_dim': 128, 'n_layers': 2, 'learning_rate': 5e-4,
        'n_epochs': 50, 'batch_size': 128, 'patience': 10
    }
    # A single grid point, so exactly one result row is expected below.
    param_grid = {'embedding_dim': [16]}
    gamma_range = range(1, 11)  # Use a decent range for a good fit

    # 3. Instantiate and Run the Workflow
    workflow = AnalysisWorkflow(x_data, y_data, base_params, critic_type='separable')
    results = workflow.run(param_grid=param_grid, gamma_range=gamma_range, n_workers=4, verbose=True)

    # 4. Verification
    print("\n--- Verification ---")
    assert len(results) == 1, f"Expected 1 result for a 1-point grid, got {len(results)}"
    final_result = results[0]

    # Check every key read below, including 'mi_error', so a missing field
    # fails here with a clear message rather than as a KeyError while printing.
    required_keys = ('mi_corrected', 'mi_error', 'is_reliable', 'gammas_used')
    missing = [key for key in required_keys if key not in final_result]
    assert not missing, f"Result is missing keys: {missing}"

    print("\nFinal Bias-Corrected MI Estimate:")
    print(f"  MI = {final_result['mi_corrected']:.4f} ± {final_result['mi_error']}")
    print(f"  Reliable Fit: {final_result['is_reliable']}")
    print(f"  Gammas Used in Final Fit: {final_result['gammas_used']}")
    assert final_result['mi_corrected'] > 0, "Bias-corrected MI should be positive for correlated data"

    print("\n✅ Phase 4c (Iterative Bias Correction) test completed successfully!")

--- Testing Phase 4c: The Final Rigorous Workflow ---
Data shape: torch.Size([79, 5, 100])
----------------------------------------
Starting rigorous analysis with 4 workers...
Created 55 tasks to run...




Epoch 1/50 | Test MI: -0.0137
Epoch 2/50 | Test MI: -0.0169
Epoch 3/50 | Test MI: -0.0188
Epoch 4/50 | Test MI: -0.0188
Epoch 5/50 | Test MI: -0.0170
Epoch 6/50 | Test MI: -0.0135
Epoch 7/50 | Test MI: -0.0074
Epoch 8/50 | Test MI: 0.0009
Epoch 9/50 | Test MI: 0.0111
Epoch 10/50 | Test MI: 0.0223
Epoch 11/50 | Test MI: 0.0341
Epoch 12/50 | Test MI: 0.0450
Epoch 13/50 | Test MI: 0.0534
Epoch 14/50 | Test MI: 0.0587
Epoch 15/50 | Test MI: 0.0599
Epoch 16/50 | Test MI: 0.0598
Epoch 17/50 | Test MI: 0.0590
Epoch 18/50 | Test MI: 0.0573
Epoch 19/50 | Test MI: 0.0556
Epoch 20/50 | Test MI: 0.0541
Epoch 21/50 | Test MI: 0.0526
Epoch 22/50 | Test MI: 0.0513
Epoch 23/50 | Test MI: 0.0503
Epoch 24/50 | Test MI: 0.0496
Epoch 25/50 | Test MI: 0.0490
Epoch 26/50 | Test MI: 0.0484
Early stopping triggered after 10 epochs.
Best epoch identified (via smoothed curve): 16 (Smoothed MI: 0.0594)
Epoch 1/50 | Test MI: -0.0050
Epoch 2/50 | Test MI: -0.0041
Epoch 3/50 | Test MI: -0.0036
Epoch 4/50 | Test MI:

