# Multi-Signal Analysis with ML4T Diagnostic

This notebook demonstrates how to analyze 50-200 trading signals simultaneously using the `MultiSignalAnalysis` module. Key features:

1. **Batch Analysis**: Efficiently analyze many signals with parallel computation
2. **Multiple Testing Corrections**: FDR (Benjamini-Hochberg) and FWER (Holm-Bonferroni)
3. **Signal Selection**: Top-N, decorrelated, Pareto-frontier, and cluster-based selection of the best signals
4. **Interactive Visualization**: Focus+Context dashboards for exploration

## References
- Benjamini & Hochberg (1995). "Controlling the False Discovery Rate"
- Holm (1979). "A Simple Sequentially Rejective Multiple Test Procedure"
- Lopez de Prado (2018). "Advances in Financial Machine Learning"

In [None]:
# Standard imports
from datetime import datetime, timedelta

import numpy as np
import polars as pl

# ML4T Diagnostic imports
from ml4t.diagnostic.config import MultiSignalAnalysisConfig
from ml4t.diagnostic.evaluation import (
    MultiSignalAnalysis,
    SignalSelector,
)
from ml4t.diagnostic.visualization import (
    MultiSignalDashboard,
    plot_ic_ridge,
    plot_pareto_frontier,
    plot_signal_correlation_heatmap,
    plot_signal_ranking_bar,
)

## 1. Generate Synthetic Data

We'll create a synthetic dataset with:
- 30 assets over roughly 2 years (500 trading days)
- 20 signals with varying noise levels (generated independently of prices, so any measured IC arises by chance)
- Some signals correlated with each other (cluster structure)

In [None]:
def generate_trading_dates(start: str, n_days: int) -> list[datetime]:
    """Generate trading dates (skip weekends)."""
    dates = []
    current = datetime.strptime(start, "%Y-%m-%d")
    while len(dates) < n_days:
        if current.weekday() < 5:  # Mon-Fri
            dates.append(current)
        current += timedelta(days=1)
    return dates


def generate_price_data(
    dates: list[datetime],
    assets: list[str],
    seed: int = 42,
) -> pl.DataFrame:
    """Generate synthetic price data with random walk."""
    np.random.seed(seed)
    records = []
    for asset in assets:
        # Random starting price, floored at 1.0 so the walk never starts negative
        price = max(100.0 * (1 + np.random.randn() * 0.3), 1.0)
        for date in dates:
            price *= 1 + np.random.randn() * 0.02  # Daily volatility ~2%
            records.append({"date": date, "asset": asset, "price": max(price, 1.0)})
    return pl.DataFrame(records)


def generate_signals(
    dates: list[datetime],
    assets: list[str],
    n_signals: int,
    n_clusters: int = 4,
    seed: int = 42,
) -> dict[str, pl.DataFrame]:
    """Generate signals with cluster structure and varying noise levels."""
    np.random.seed(seed)
    signals = {}

    # Assign signals to clusters
    cluster_assignments = np.arange(n_signals) % n_clusters

    # Generate cluster base signals
    cluster_bases = {}
    for c in range(n_clusters):
        cluster_bases[c] = np.random.randn(len(dates), len(assets))

    for sig_idx in range(n_signals):
        cluster = cluster_assignments[sig_idx]

        # Signal is cluster base + noise
        noise_level = 0.3 + np.random.rand() * 0.5
        factor_values = (
            cluster_bases[cluster] * (1 - noise_level)
            + np.random.randn(len(dates), len(assets)) * noise_level
        )

        # Create DataFrame
        records = []
        for i, date in enumerate(dates):
            for j, asset in enumerate(assets):
                records.append(
                    {
                        "date": date,
                        "asset": asset,
                        "factor": factor_values[i, j],
                    }
                )

        # Name with cluster prefix for interpretability
        signal_name = f"cluster{cluster}_signal_{sig_idx:02d}"
        signals[signal_name] = pl.DataFrame(records)

    return signals

In [None]:
# Configuration
N_DAYS = 500
N_ASSETS = 30
N_SIGNALS = 20
N_CLUSTERS = 4

# Generate data
dates = generate_trading_dates("2022-01-03", N_DAYS)
assets = [f"STOCK_{i:02d}" for i in range(N_ASSETS)]

prices = generate_price_data(dates, assets)
signals = generate_signals(dates, assets, N_SIGNALS, N_CLUSTERS)

print(f"Price data: {prices.shape}")
print(f"Number of signals: {len(signals)}")
print(f"Signal names: {list(signals.keys())[:5]} ...")

## 2. Basic Multi-Signal Analysis

The `MultiSignalAnalysis` class provides batch analysis with:
- Parallel computation via joblib
- Smart caching with Polars fingerprinting
- Multiple testing corrections (FDR, FWER)
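
To make the two corrections concrete, here is a minimal pure-Python sketch of the Benjamini-Hochberg step-up procedure (FDR) and the Holm step-down procedure (FWER). It is an illustration only, not the library's implementation; the function names and the toy p-values are invented for this example.

```python
def benjamini_hochberg(p_values, alpha=0.05):
    """Return booleans: True where BH rejects the null (controls FDR)."""
    m = len(p_values)
    order = sorted(range(m), key=lambda i: p_values[i])
    # Find the largest 1-based rank k with p_(k) <= (k / m) * alpha.
    max_k = 0
    for rank, idx in enumerate(order, start=1):
        if p_values[idx] <= rank / m * alpha:
            max_k = rank
    # Reject every hypothesis ranked at or below k.
    reject = [False] * m
    for idx in order[:max_k]:
        reject[idx] = True
    return reject


def holm_bonferroni(p_values, alpha=0.05):
    """Return booleans: True where Holm rejects the null (controls FWER)."""
    m = len(p_values)
    order = sorted(range(m), key=lambda i: p_values[i])
    reject = [False] * m
    for rank, idx in enumerate(order):  # rank is 0-based here
        # Compare the smallest remaining p-value against alpha / (m - rank).
        if p_values[idx] <= alpha / (m - rank):
            reject[idx] = True
        else:
            break  # Step-down: stop at the first failure.
    return reject


# Hypothetical p-values for eight signals.
p_values = [0.001, 0.008, 0.039, 0.041, 0.042, 0.060, 0.074, 0.205]
fdr_flags = benjamini_hochberg(p_values)
fwer_flags = holm_bonferroni(p_values)
print(f"BH rejections: {sum(fdr_flags)}, Holm rejections: {sum(fwer_flags)}")
```

On these p-values BH flags two signals while Holm flags only one, which is the usual pattern: FWER control is stricter than FDR control.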

In [None]:
# Create analyzer with default configuration
analyzer = MultiSignalAnalysis(
    signals=signals,
    prices=prices,
)

# Compute summary metrics for all signals
summary = analyzer.compute_summary(progress=True)

print(f"\nTotal signals: {summary.n_signals}")
print(f"FDR significant (alpha=0.05): {summary.n_fdr_significant}")
print(f"FWER significant (alpha=0.05): {summary.n_fwer_significant}")

In [None]:
# View summary as DataFrame
df = summary.get_dataframe()
print("Summary DataFrame columns:")
print(df.columns)

# Show top signals by IC IR
print("\nTop 5 signals by IC IR:")
df.sort("ic_ir", descending=True).head(5).select(
    ["signal_name", "ic_mean", "ic_ir", "ic_t_stat", "fdr_significant"]
)
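
For readers unfamiliar with the columns above, the following sketch shows how `ic_mean`, `ic_ir`, and `ic_t_stat` are conventionally derived. It assumes the IC is the per-date Spearman rank correlation between factor values and forward returns, the IC IR is mean IC divided by its standard deviation, and the t-statistic is the mean divided by its standard error. The library's exact formulas may differ, and the helper names and toy data are hypothetical.

```python
from math import sqrt
from statistics import mean, stdev


def ranks(values):
    """1-based ranks; tied values share the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    result = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        average_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            result[order[k]] = average_rank
        i = j + 1
    return result


def pearson(x, y):
    """Plain Pearson correlation of two equal-length sequences."""
    mx, my = mean(x), mean(y)
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    var_x = sum((a - mx) ** 2 for a in x)
    var_y = sum((b - my) ** 2 for b in y)
    return cov / sqrt(var_x * var_y)


def spearman_ic(factor, forward_returns):
    """Spearman IC: Pearson correlation of the ranks."""
    return pearson(ranks(factor), ranks(forward_returns))


# Toy cross-sections: 3 dates x 4 assets (made-up numbers).
factor_by_date = [
    [0.1, 0.4, 0.2, 0.3],
    [0.5, 0.1, 0.3, 0.2],
    [0.2, 0.3, 0.1, 0.4],
]
returns_by_date = [
    [0.00, 0.02, 0.01, 0.03],
    [0.01, 0.00, -0.01, 0.02],
    [0.00, 0.01, -0.02, 0.03],
]

daily_ic = [spearman_ic(f, r) for f, r in zip(factor_by_date, returns_by_date)]
ic_mean = mean(daily_ic)
ic_ir = ic_mean / stdev(daily_ic)
ic_t_stat = ic_mean / (stdev(daily_ic) / sqrt(len(daily_ic)))
print(f"daily IC: {daily_ic}")
print(f"ic_mean={ic_mean:.3f}, ic_ir={ic_ir:.3f}, ic_t_stat={ic_t_stat:.3f}")
```

Because the IC IR scales the mean by its variability, a signal with a modest but stable IC can rank above one with a higher but erratic IC.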

## 3. Custom Configuration

Configure the analysis for your specific needs:

In [None]:
from ml4t.diagnostic.config import SignalAnalysisConfig

# Strict configuration for production
strict_config = MultiSignalAnalysisConfig(
    # Tighter significance thresholds
    fdr_alpha=0.01,
    fwer_alpha=0.01,
    # Minimum IC threshold
    min_ic_threshold=0.02,
    # Performance tuning
    n_jobs=-1,  # Use all CPU cores
    cache_enabled=True,
    # Analysis periods
    signal_config=SignalAnalysisConfig(
        periods=(1, 5, 10),  # Forward return periods
    ),
)

# Run with strict config
strict_analyzer = MultiSignalAnalysis(signals, prices, config=strict_config)
strict_summary = strict_analyzer.compute_summary(progress=True)

print("\nWith strict thresholds (alpha=0.01):")
print(f"FDR significant: {strict_summary.n_fdr_significant}")
print(f"FWER significant: {strict_summary.n_fwer_significant}")

## 4. Signal Selection Algorithms

Use `SignalSelector` to identify the most promising signals from a large universe.

In [None]:
# Get summary DataFrame for selection
df = summary.get_dataframe()

# Method 1: Top N by IC IR
top_by_ir = SignalSelector.select_top_n(
    summary_df=df,
    n=5,
    metric="ic_ir",
)
print("Top 5 by IC IR:")
print(top_by_ir)

# Method 2: Top N with low turnover
low_turnover = SignalSelector.select_top_n(
    summary_df=df,
    n=5,
    metric="turnover_mean",
    ascending=True,  # Lowest turnover
)
print("\nTop 5 lowest turnover:")
print(low_turnover)

# Method 3: Only FDR-significant signals
significant_only = SignalSelector.select_top_n(
    summary_df=df,
    n=5,
    metric="ic_ir",
    filter_significant=True,
)
print("\nTop 5 FDR-significant:")
print(significant_only)

In [None]:
# Method 4: Select uncorrelated signals
# First get correlation matrix
correlation_matrix = summary.get_correlation_matrix()

if correlation_matrix is not None:
    uncorrelated = SignalSelector.select_uncorrelated(
        summary_df=df,
        correlation_df=correlation_matrix,
        n=5,
        max_correlation=0.5,  # Exclude pairs with |corr| > 0.5
    )
    print("Top 5 uncorrelated signals:")
    print(uncorrelated)
else:
    print("No correlation matrix available - run with compute_correlation=True")
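
One common way to pick decorrelated signals is a greedy pass: walk the signals from best to worst score and keep each one only if its absolute correlation with everything already kept stays at or below the threshold. The sketch below illustrates that idea in plain Python. Whether `select_uncorrelated` works exactly this way is an assumption; the function name, scores, and correlations here are hypothetical.

```python
def select_uncorrelated_greedy(scores, corr, n, max_correlation=0.5):
    """Greedily keep the best-scoring signals whose pairwise |corr| is small.

    scores: dict mapping signal name -> score (higher is better)
    corr:   nested dict, corr[a][b] is the correlation between a and b
    """
    selected = []
    for name in sorted(scores, key=scores.get, reverse=True):
        if all(abs(corr[name][kept]) <= max_correlation for kept in selected):
            selected.append(name)
        if len(selected) == n:
            break
    return selected


# Hypothetical IC IR scores and pairwise correlations.
scores = {"a": 1.2, "b": 1.1, "c": 0.7, "d": 0.4}
pairs = {
    ("a", "b"): 0.9,
    ("a", "c"): 0.1,
    ("a", "d"): 0.3,
    ("b", "c"): 0.2,
    ("b", "d"): 0.0,
    ("c", "d"): 0.6,
}
# Expand the pair list into a symmetric lookup table.
corr = {name: {} for name in scores}
for (x, y), value in pairs.items():
    corr[x][y] = value
    corr[y][x] = value

picked = select_uncorrelated_greedy(scores, corr, n=3, max_correlation=0.5)
print(picked)
```

Here `b` is skipped because it is highly correlated with the stronger `a`, and `d` is skipped because of its correlation with `c`, so fewer than `n` signals can be returned when the universe is tightly clustered.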

In [None]:
# Method 5: Pareto frontier (IC IR vs Turnover)
pareto_signals = SignalSelector.select_pareto_frontier(
    summary_df=df,
    x_metric="turnover_mean",  # Minimize
    y_metric="ic_ir",  # Maximize
)
print("Pareto-optimal signals (IC IR vs Turnover):")
print(pareto_signals)
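
A signal is Pareto-optimal when no other signal is at least as good on both metrics and strictly better on one. The following plain-Python sketch applies that definition with turnover minimized and IC IR maximized. It is a brute-force illustration rather than the library's algorithm, and the function name and data points are invented.

```python
def pareto_frontier(points):
    """Return names of non-dominated points.

    points: dict mapping name -> (turnover, ic_ir);
            turnover is minimized, ic_ir is maximized.
    """
    frontier = []
    for name, (x, y) in points.items():
        dominated = any(
            (ox <= x and oy >= y) and (ox < x or oy > y)
            for other, (ox, oy) in points.items()
            if other != name
        )
        if not dominated:
            frontier.append(name)
    return frontier


# Hypothetical (turnover_mean, ic_ir) pairs.
points = {
    "a": (0.10, 0.5),
    "b": (0.20, 0.9),
    "c": (0.25, 0.8),
    "d": (0.05, 0.2),
    "e": (0.30, 1.0),
}
efficient = pareto_frontier(points)
print(efficient)
```

Only `c` drops out, because `b` offers both lower turnover and a higher IC IR. The pairwise check is quadratic in the number of signals, which is fine for a few hundred.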

In [None]:
# Method 6: Cluster-based selection (one representative per cluster)
if correlation_matrix is not None:
    cluster_reps = SignalSelector.select_by_cluster(
        correlation_df=correlation_matrix,
        summary_df=df,
        n_clusters=4,
        signals_per_cluster=1,
    )
    print("Cluster representatives:")
    print(cluster_reps)

## 5. Visualization

Create interactive visualizations using the Focus+Context pattern.

In [None]:
# IC Ridge Plot - shows IC distribution per signal
fig = plot_ic_ridge(
    summary,
    max_signals=15,
    sort_by="ic_mean",
    show_significance=True,
)
fig.show()

In [None]:
# Signal Ranking Bar Chart
fig = plot_signal_ranking_bar(
    summary,
    metric="ic_ir",
    max_signals=15,
    show_significance=True,
)
fig.show()

In [None]:
# Signal Correlation Heatmap (with hierarchical clustering)
fig = plot_signal_correlation_heatmap(
    summary,
    max_signals=20,
)
fig.show()

In [None]:
# Pareto Frontier (IC IR vs Turnover)
fig = plot_pareto_frontier(
    summary,
    x_metric="turnover_mean",
    y_metric="ic_ir",
)
fig.show()

## 6. Multi-Signal Dashboard

The `MultiSignalDashboard` combines all visualizations into an interactive HTML dashboard.

In [None]:
# Create multi-tab dashboard
dashboard = MultiSignalDashboard(summary)

# Save as HTML file
dashboard.save_html("multi_signal_dashboard.html")
print("Dashboard saved to multi_signal_dashboard.html")

In [None]:
# Or get as HTML string for embedding
html_content = dashboard.to_html()
print(f"Dashboard HTML size: {len(html_content.encode('utf-8')):,} bytes")


## 7. Signal Comparison

Compare selected signals in detail.

In [None]:
# Select top signals for comparison
top_signals = SignalSelector.select_top_n(df, n=3, metric="ic_ir")

# Run comparison
comparison = analyzer.compare(signal_names=top_signals)

print(f"Comparing: {comparison.signal_names}")
print("\nComparison metrics:")
print(comparison.metrics_df)

In [None]:
# Save comparison as HTML
comparison.save_html("signal_comparison.html")
print("Comparison saved to signal_comparison.html")

## 8. Working with Large Signal Sets

Tips for handling 100+ signals efficiently:

In [None]:
# Configuration for large signal sets
large_config = MultiSignalAnalysisConfig(
    # Enable caching - huge speedup for re-analysis
    cache_enabled=True,
    cache_max_items=500,
    # Use all cores
    n_jobs=-1,
    # Faster single-period analysis for initial screening
    signal_config=SignalAnalysisConfig(
        periods=(5,),  # Just 5-day forward returns
    ),
    # Display limits
    max_display_signals=50,
)

print("Configuration for large signal sets:")
print(f"  Cache enabled: {large_config.cache_enabled}")
print(f"  N jobs: {large_config.n_jobs}")
print(f"  Periods: {large_config.signal_config.periods}")
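
To show why caching helps on re-analysis, here is a minimal sketch of content-addressed caching: results are stored under a hash of the input data, so identical inputs skip recomputation. How the library fingerprints Polars frames is not shown in this notebook, so everything below (the `fingerprint` helper, the `ResultCache` class, and the oldest-first eviction) is a hypothetical stand-in that uses only the standard library.

```python
import hashlib


def fingerprint(rows):
    """Hash the content of a sequence of rows into a stable hex key."""
    digest = hashlib.sha256()
    for row in rows:
        digest.update(repr(row).encode("utf-8"))
    return digest.hexdigest()


class ResultCache:
    """Tiny bounded cache keyed by data fingerprint (oldest entry evicted)."""

    def __init__(self, max_items=500):
        self.max_items = max_items
        self._store = {}

    def get_or_compute(self, rows, compute):
        key = fingerprint(rows)
        if key not in self._store:
            if len(self._store) >= self.max_items:
                # Dicts preserve insertion order, so this drops the oldest key.
                self._store.pop(next(iter(self._store)))
            self._store[key] = compute(rows)
        return self._store[key]


calls = []


def expensive_metric(rows):
    """Stand-in for a costly per-signal computation."""
    calls.append(1)
    return sum(value for _, value in rows)


cache = ResultCache(max_items=2)
rows = [("2022-01-03", 0.5), ("2022-01-04", -0.2)]
first = cache.get_or_compute(rows, expensive_metric)
second = cache.get_or_compute(list(rows), expensive_metric)  # same content
print(f"result={first:.2f}, computed {len(calls)} time(s)")
```

The second call passes a different list object with the same content and is served from the cache, which is the behavior that makes repeated screening runs cheap.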

## Summary

Key takeaways for multi-signal analysis:

1. **Always apply multiple testing corrections** - Use FDR or FWER control to limit false discoveries when testing many signals
2. **Use signal selection algorithms** - Don't just take top N by raw IC
3. **Consider correlations** - Uncorrelated signals provide more diversification
4. **Trade off IC vs Turnover** - Pareto frontier shows efficient signals
5. **Enable caching** - Speeds up re-analysis significantly
6. **Use dashboards** - Interactive exploration is key for 50-200 signals