In [1]:
# test_phase_4_final.ipynb

import torch
import numpy as np
import pandas as pd
from neural_mi.data.processors import ContinuousProcessor
from neural_mi.analysis.workflow import AnalysisWorkflow

SEED = 0  # fixed seed so repeated runs of this smoke test see the same data
N_WORKERS = 2
GAMMA_RANGE = (1,)


def summarize_results(results, columns, title, *, expect_variational):
    """Tabulate one workflow's results, print a summary, and sanity-check them.

    Args:
        results: Whatever ``AnalysisWorkflow.run`` returned; it is passed
            straight to ``pd.DataFrame`` (presumably a list of per-task dicts
            -- confirm against the workflow).
        columns: Column names to include in the printed summary.
        title: Label used in the summary header and in error messages.
        expect_variational: Value every row's ``use_variational`` must equal.

    Returns:
        The results as a ``pd.DataFrame``.

    Raises:
        AssertionError: If there are no results, if any row's
            ``use_variational`` differs from ``expect_variational``, or if any
            ``test_mi`` value is missing or non-finite.
    """
    df = pd.DataFrame(results)
    # Fail with a clear message instead of a KeyError/IndexError further down.
    if df.empty:
        raise AssertionError(f"{title}: workflow returned no results")

    print(f"\n{title} Results Summary:")
    print(df[list(columns)].round(4))

    # Explicit raises (not ``assert``) so the checks survive ``python -O``;
    # every row is checked, not only the first one.
    if not (df['use_variational'] == expect_variational).all():
        raise AssertionError(
            f"{title}: expected use_variational={expect_variational!r} for all "
            f"rows, got {df['use_variational'].tolist()!r}"
        )
    # MI estimates may legitimately be negative, but NaN/inf (or a missing
    # value) means training diverged or a task failed.
    if not np.isfinite(df['test_mi'].to_numpy(dtype=float)).all():
        raise AssertionError(
            f"{title}: non-finite test_mi values: {df['test_mi'].tolist()!r}"
        )
    return df


def run_case(banner, title, x_data, y_data, base_params, *, critic_type,
             use_variational, param_grid, columns):
    """Run one ``AnalysisWorkflow`` configuration and validate its output.

    Args:
        banner: Text for the "Running workflow with ..." header line.
        title: Text for the results-summary header (see ``summarize_results``).
        x_data, y_data: Processed inputs forwarded to ``AnalysisWorkflow``.
        base_params: Base hyper-parameters forwarded to ``AnalysisWorkflow``.
        critic_type: Forwarded to ``AnalysisWorkflow``.
        use_variational: Forwarded to ``AnalysisWorkflow`` and also used as the
            expected value of every result row's ``use_variational``.
        param_grid: Forwarded to ``AnalysisWorkflow.run``.
        columns: Columns shown in the printed summary.

    Returns:
        ``(workflow, results, df)`` so later notebook cells can still inspect
        the workflow object and the raw results.
    """
    print(f"\n--- Running workflow with {banner} ---")
    workflow = AnalysisWorkflow(x_data, y_data, base_params,
                                critic_type=critic_type,
                                use_variational=use_variational)
    results = workflow.run(param_grid=param_grid,
                           gamma_range=list(GAMMA_RANGE),
                           n_workers=N_WORKERS)
    df = summarize_results(results, columns, title,
                           expect_variational=use_variational)
    print("-" * 50)
    return workflow, results, df


if __name__ == '__main__':
    print("--- Testing Phase 4: The Final Modular Workflow ---")
    device = 'cpu'

    # --- 1. Generate Data ---
    # Seed NumPy (data) and torch (model init in this process). Whether worker
    # processes inherit or reseed is up to AnalysisWorkflow -- TODO confirm.
    torch.manual_seed(SEED)
    rng = np.random.default_rng(SEED)
    x_raw = rng.standard_normal((5, 8000))
    # y is x plus independent Gaussian noise (std 0.5), so X and Y are dependent.
    y_raw = x_raw + rng.standard_normal((5, 8000)) * 0.5
    cont_proc = ContinuousProcessor(window_size=100, step_size=50)
    x_data = cont_proc.process(x_raw).to(device)
    y_data = cont_proc.process(y_raw).to(device)
    print(f"Data shape: {x_data.shape}")
    print("-" * 50)

    # --- 2. Define Base Parameters ---
    base_params = {
        'hidden_dim': 128, 'n_layers': 2, 'learning_rate': 1e-4,
        'n_epochs': 5, 'batch_size': 32, 'patience': 3,
    }

    # --- 3. Run Workflow with SeparableCritic (Standard) ---
    separable_param_grid = {'embedding_dim': [8]}  # Keep test fast
    separable_cols = ['critic_type', 'use_variational', 'embedding_dim', 'gamma', 'test_mi']
    workflow_sep, results_sep, df_sep = run_case(
        "SeparableCritic (Standard)", "Separable (Standard)",
        x_data, y_data, base_params,
        critic_type='separable', use_variational=False,
        param_grid=separable_param_grid, columns=separable_cols,
    )

    # --- 4. Run Workflow with ConcatCritic (Standard) ---
    workflow_concat, results_concat, df_concat = run_case(
        "ConcatCritic (Standard)", "Concat (Standard)",
        x_data, y_data, base_params,
        critic_type='concat', use_variational=False,
        param_grid={}, columns=['critic_type', 'use_variational', 'gamma', 'test_mi'],
    )

    # --- 5. Run Workflow with SeparableCritic (VARIATIONAL) ---
    workflow_var, results_var, df_var = run_case(
        "SeparableCritic (VARIATIONAL)", "Separable (Variational)",
        x_data, y_data, base_params,
        critic_type='separable', use_variational=True,
        param_grid=separable_param_grid, columns=separable_cols,
    )

    print("\n✅ Phase 4 final test completed successfully!")

--- Testing Phase 4: The Final Modular Workflow ---
Data shape: torch.Size([159, 5, 100])
--------------------------------------------------

--- Running workflow with SeparableCritic (Standard) ---
Starting analysis with 2 workers for critic: 'separable'...
Created 1 tasks to run...
Epoch 1/5 | Test MI: -0.1223
Epoch 2/5 | Test MI: -0.1215
Epoch 3/5 | Test MI: -0.1234
Epoch 4/5 | Test MI: -0.1278
Epoch 5/5 | Test MI: -0.1340
Early stopping triggered after 3 epochs.
Best epoch identified (via smoothed curve): 1 (Smoothed MI: -0.1224)
Analysis workflow finished.

Separable (Standard) Results Summary:
  critic_type  use_variational  embedding_dim  gamma  test_mi
0   separable            False              8      1  -0.1215
--------------------------------------------------

--- Running workflow with ConcatCritic (Standard) ---
Starting analysis with 2 workers for critic: 'concat'...
Created 1 tasks to run...
Epoch 1/5 | Test MI: -0.1014
Epoch 2/5 | Test MI: -0.0919
Epoch 3/5 | Test MI: -