# Example 2: The Importance of Time

In the first example, we assumed each sample $(x_i, y_i)$ was independent. However, in most biological and neural data, relationships are spread out over time. A stimulus now might affect a neural response hundreds of milliseconds later.

This notebook demonstrates how to handle such temporal dependencies.

**Goal:**
1.  Introduce the `data.processors` for windowing continuous data.
2.  Showcase `run(mode='sweep')` to find the optimal `window_size`.
3.  Demonstrate why choosing the correct timescale is critical for MI estimation.

## 1. Imports

We'll need our standard imports, plus `matplotlib` and `seaborn` for plotting the results of our sweep.

In [None]:
import torch
import numpy as np
import neural_mi as nmi
import matplotlib.pyplot as plt
import seaborn as sns

# Use a nice style for plotting
sns.set_context("talk")

## 2. Generating Temporally Dependent Data

We'll use `generate_temporally_convolved_data`. This function creates a latent signal `Z`, then generates `X` and `Y` by convolving (smearing) `Z` with different time kernels. As a result, the relationship between the single points $x_t$ and $y_t$ is weak, but the relationship between a *window* of X and a *window* of Y is strong.

For this type of data, the ground truth MI isn't easily known, which is realistic. Our goal is not to match a known value, but to find the timescale that *maximizes* the MI.
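
To make the idea of "smearing" concrete, here is a minimal NumPy sketch of how such data could be produced. It is not the implementation of `generate_temporally_convolved_data`: the function name `toy_convolved_pair`, the two kernels, and the noise level are all hypothetical choices made for illustration.

```python
import numpy as np

def toy_convolved_pair(n_samples=2000, seed=0):
    """Toy stand-in: X and Y are noisy, differently smoothed copies of one latent Z."""
    rng = np.random.default_rng(seed)
    z = rng.standard_normal(n_samples)        # latent white-noise signal Z

    # Two different causal kernels (hypothetical shapes, normalized to unit sum)
    t = np.arange(50)
    kernel_x = np.exp(-t / 5.0)               # fast exponential decay
    kernel_y = t * np.exp(-t / 10.0)          # slower, delayed response
    kernel_x /= kernel_x.sum()
    kernel_y /= kernel_y.sum()

    # Causal convolution: keep the first n_samples outputs of the full convolution
    x = np.convolve(z, kernel_x)[:n_samples] + 0.05 * rng.standard_normal(n_samples)
    y = np.convolve(z, kernel_y)[:n_samples] + 0.05 * rng.standard_normal(n_samples)

    # Match the [n_channels, n_timepoints] layout used in this notebook
    return x[np.newaxis, :], y[np.newaxis, :]

x_toy, y_toy = toy_convolved_pair()
print(x_toy.shape, y_toy.shape)
```

Because each output sample mixes many past values of `Z`, no single pair of timepoints carries the full relationship; it is spread across the kernel lengths.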

In [None]:
# --- Dataset Parameters ---
n_samples = 10000

# --- Generate Data ---
# The output shape is [1, n_samples] which represents [n_channels, n_timepoints]
x_raw, y_raw = nmi.datasets.generate_temporally_convolved_data(n_samples=n_samples, use_torch=False)

print(f"Generated raw X data shape: {x_raw.shape}")
print(f"Generated raw Y data shape: {y_raw.shape}")

# Let's visualize the raw data to see the temporal relationship
plt.figure(figsize=(12, 4))
plt.plot(x_raw[0, :200], label='X', alpha=0.8)
plt.plot(y_raw[0, :200], label='Y', alpha=0.8)
plt.xlabel("Timepoints")
plt.ylabel("Signal")
plt.title("Raw Temporal Data (first 200 points)")
plt.legend()
plt.show()

## 3. The Problem: A Naive Estimate Fails

If we treat each timepoint as an independent sample (a window size of 1), we fail to capture the smeared relationship. Let's check this by processing the data with `window_size=1` and running a quick estimate.

In [None]:
# Process the data with a window size of 1
processor = nmi.data.ContinuousProcessor(window_size=1, step_size=1)
x_naive = processor.process(x_raw)
y_naive = processor.process(y_raw)

print(f"Shape after processing with window_size=1: {x_naive.shape}")

# Use the same base parameters as before
base_params = {
    'n_epochs': 50, 'learning_rate': 1e-3, 'batch_size': 128,
    'patience': 5, 'embedding_dim': 16, 'hidden_dim': 64, 'n_layers': 2
}

naive_mi = nmi.run(x_data=x_naive, y_data=y_naive, mode='estimate', base_params=base_params)
print(f"\nNaive MI estimate (window_size=1): {naive_mi:.3f} bits")

As expected, the naive MI estimate is low compared with what the sweep in the next section recovers. A single timepoint misses most of the relationship.
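
To see what a window size means in practice, the sketch below slices a `[n_channels, n_timepoints]` array into overlapping windows with NumPy. The helper `make_windows` and its `[n_windows, n_channels, window_size]` output layout are assumptions for illustration; they are not taken from `ContinuousProcessor`, whose output format may differ.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def make_windows(data, window_size, step_size=1):
    """Slice [n_channels, n_timepoints] data into overlapping windows.

    Returns an array of shape [n_windows, n_channels, window_size].
    This layout is an assumption for illustration only.
    """
    # sliding_window_view appends the window axis last:
    # [n_channels, n_timepoints - window_size + 1, window_size]
    windows = sliding_window_view(data, window_size, axis=1)
    windows = windows[:, ::step_size, :]      # keep every step_size-th window
    return np.ascontiguousarray(windows.transpose(1, 0, 2))

signal = np.arange(10, dtype=float)[np.newaxis, :]    # one channel, 10 timepoints
print(make_windows(signal, window_size=1).shape)      # (10, 1, 1)
print(make_windows(signal, window_size=4).shape)      # (7, 1, 4)
print(make_windows(signal, window_size=4, step_size=2)[1, 0])  # [2. 3. 4. 5.]
```

With `window_size=1`, every sample is a single timepoint, so the estimator never sees the temporal context that carries the relationship.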

## 4. The Solution: Sweeping Over Window Size

To find the correct timescale, we need to test many different window sizes. This is a hyperparameter search, which is exactly what `mode='sweep'` is for.

The process is:
1.  Define a `sweep_grid`. This dictionary tells the `run` function which parameters to vary. Our key will be `window_size`.
2.  The `run` function will iterate through each value in the grid.
3.  **Crucially, we do not pre-process the data.** The sweep engine handles the processing internally for each `window_size`.

*Note: `window_size` is a data processing parameter, not a model parameter. The library handles this by looking for specific keys in the sweep grid and applying them to the processor before training.*
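
Conceptually, a window-size sweep is a loop that re-windows the raw data for each candidate size and re-runs the estimator. The sketch below illustrates that loop with NumPy only. It is not how `neural_mi` implements the sweep: a closed-form Gaussian MI formula stands in for the neural estimator, and the toy signals, the kernels, and the `gaussian_mi_bits` helper are hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def gaussian_mi_bits(x, y):
    """Plug-in MI (bits) between sample matrices x and y, assuming joint Gaussianity."""
    dx = x.shape[1]
    cov = np.cov(np.hstack([x, y]), rowvar=False)
    logdet_x = np.linalg.slogdet(cov[:dx, :dx])[1]
    logdet_y = np.linalg.slogdet(cov[dx:, dx:])[1]
    logdet_xy = np.linalg.slogdet(cov)[1]
    # I(X;Y) = 0.5 * (log|Sx| + log|Sy| - log|Sxy|), converted from nats to bits
    return 0.5 * (logdet_x + logdet_y - logdet_xy) / np.log(2)

# Toy raw signals: one latent Z smoothed by two different kernels, plus noise
rng = np.random.default_rng(0)
n = 5000
z = rng.standard_normal(n)
t = np.arange(30)
x_sig = np.convolve(z, np.exp(-t / 3.0))[:n] + 0.5 * rng.standard_normal(n)
y_sig = np.convolve(z, np.exp(-t / 8.0))[:n] + 0.5 * rng.standard_normal(n)

# The sweep: window the RAW data anew for every candidate window size
results = []
for window_size in [1, 2, 5, 10]:
    x_win = sliding_window_view(x_sig, window_size)   # [n_windows, window_size]
    y_win = sliding_window_view(y_sig, window_size)
    results.append({'window_size': window_size,
                    'mi_bits': gaussian_mi_bits(x_win, y_win)})

for row in results:
    print(f"window_size={row['window_size']:>2}: {row['mi_bits']:.3f} bits")
```

The Gaussian plug-in estimate is used here only because it needs no training. Its curve need not match the shape of the neural estimate in this notebook, which is also limited by finite data and model capacity at large window sizes.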

In [None]:
# The sweep will create a new ContinuousProcessor for each value.
sweep_grid = {
    'window_size': [1, 5, 10, 15, 20, 25, 30, 40, 50, 100, 200, 500, 1000]
}

# Notice we pass the RAW data to the run function
sweep_results_df = nmi.run(
    x_data=torch.from_numpy(x_raw).float(), # The runner needs tensors
    y_data=torch.from_numpy(y_raw).float(),
    mode='sweep',
    base_params=base_params,
    sweep_grid=sweep_grid,
    # This tells the run function to perform a processing sweep.
    processor_type='continuous',
    # Speed up the sweep with parallel workers
    n_workers=4
)

display(sweep_results_df)

## 5. Analyzing the Results

The output is a pandas DataFrame containing the results for each hyperparameter combination. Now we can simply plot the `test_mi` against the `window_size` to find the peak.

In [None]:
best_run = sweep_results_df.loc[sweep_results_df['test_mi'].idxmax()]

# --- Now create the plot ---
plt.figure(figsize=(10, 6))
sns.lineplot(data=sweep_results_df, x='window_size', y='test_mi', marker='o')
plt.axvline(x=best_run['window_size'], c='r', ls=':', label=f"Optimal ({int(best_run['window_size'])})")
plt.xlabel("Window Size (timepoints)")
plt.ylabel("Estimated MI (bits)")
plt.title("MI vs. Window Size")
plt.grid(True, linestyle=':')
plt.xscale('log')
plt.legend()
plt.show()

print("--- Best Result ---")
print(f"Optimal Window Size: {int(best_run['window_size'])}")
print(f"Maximum MI Estimated: {best_run['test_mi']:.3f} bits")

## 6. Conclusion

The result is clear! The estimated MI peaks at a window size of roughly 100 to 200 timepoints, which is a few times the characteristic timescale of the relationship in our generated data. With the right window, we recover a much stronger MI ($\sim 5$ bits) than the naive, point-by-point estimate suggested ($\sim 2$ bits). Exact numbers will differ slightly from run to run.

This example demonstrates a core workflow for analyzing real experimental data:
1.  Start with raw time-series data.
2.  Use a `sweep` over `window_size` to find the timescale that maximizes information.
3.  Interpret the optimal window size, which is itself a valuable scientific finding.

In the next example, we'll explore the internal structure of a single, high-dimensional dataset to estimate its 'latent dimensionality'.