# T-Maze Connectivity Analysis

This notebook demonstrates connectivity analysis for T-maze data, using simulated fMRI and EEG recordings as stand-ins:
- Functional connectivity matrices
- Dynamic connectivity
- Graph theory metrics
- EEG phase-based connectivity

In [None]:
import sys

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Make the project package importable before the local imports below
sys.path.append('..')

from connectivity import (
    compute_fc_matrix,
    partial_correlation,
    fc_condition_contrast
)
from connectivity.functional import fc_to_adjacency
from connectivity.dynamic import (
    sliding_window_fc,
    detect_fc_states,
    fc_variability
)
from connectivity.graph import (
    compute_graph_metrics,
    modularity_detection,
    small_world_index,
    hub_identification
)
from connectivity.eeg_connectivity import (
    phase_lag_index,
    weighted_phase_lag_index,
    coherence,
    compute_connectivity_bands
)

plt.style.use('seaborn-v0_8-whitegrid')
%matplotlib inline

## 1. Simulate fMRI Timeseries

In [None]:
# Simulate fMRI timeseries with known connectivity structure
np.random.seed(42)

n_timepoints = 200  # TRs
n_rois = 50

# Define 4 networks as contiguous blocks of ROI indices
networks = {
    'DMN': list(range(0, 12)),
    'FPN': list(range(12, 24)),
    'SAL': list(range(24, 36)),
    'VIS': list(range(36, 50))
}

# Create ground truth connectivity matrix with community structure
true_fc = np.zeros((n_rois, n_rois))

# Strong within-network connectivity (0.6-0.8). Each unordered ROI pair is
# drawn once and mirrored so the matrix is symmetric; drawing (i, j) and
# (j, i) independently would make the plotted "ground truth" disagree with
# the matrix that is actually factorised below.
for roi_idx in networks.values():
    for pos, i in enumerate(roi_idx):
        for j in roi_idx[pos + 1:]:
            true_fc[i, j] = true_fc[j, i] = 0.6 + np.random.rand() * 0.2

# Weak between-network connectivity (0.1-0.2) for every pair still unset
for i in range(n_rois):
    for j in range(i + 1, n_rois):
        if true_fc[i, j] == 0:
            true_fc[i, j] = true_fc[j, i] = 0.1 + np.random.rand() * 0.1

np.fill_diagonal(true_fc, 1)

# A symmetric matrix filled with random entries is not guaranteed to be
# positive-definite, which the Cholesky factorisation requires. When needed,
# clip the eigenvalues to a small positive floor and rescale back to a unit
# diagonal so the result is still a valid correlation matrix.
min_eigenvalue = 1e-3
eigvals, eigvecs = np.linalg.eigh(true_fc)
if eigvals.min() < min_eigenvalue:
    repaired = (eigvecs * np.clip(eigvals, min_eigenvalue, None)) @ eigvecs.T
    scale = np.sqrt(np.diag(repaired))
    repaired = repaired / np.outer(scale, scale)
    true_fc = (repaired + repaired.T) / 2  # remove floating-point asymmetry
    np.fill_diagonal(true_fc, 1)

# Generate timeseries from this structure: with true_fc = L @ L.T, the rows of
# white noise multiplied by L.T have covariance true_fc.
L = np.linalg.cholesky(true_fc)
noise = np.random.randn(n_timepoints, n_rois)
timeseries = noise @ L.T

print(f"Timeseries shape: {timeseries.shape}")
print(f"Networks: {list(networks.keys())}")

## 2. Functional Connectivity Matrix

In [None]:
# Compute FC matrix
roi_names = [f'ROI_{i:02d}' for i in range(n_rois)]

fc_result = compute_fc_matrix(
    timeseries,
    method='pearson',
    roi_names=roi_names
)

print(f"FC matrix shape: {fc_result.matrix.shape}")
print(f"Method: {fc_result.method}")

# Plot the ground truth next to the estimate on a shared [-1, 1] colour scale
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
fc_panels = [
    (true_fc, 'Ground Truth FC'),
    (fc_result.matrix, 'Estimated FC (Pearson)'),
]

for ax, (matrix, title) in zip(axes, fc_panels):
    im = ax.imshow(matrix, cmap='RdBu_r', vmin=-1, vmax=1)
    ax.set_title(title)
    plt.colorbar(im, ax=ax, label='Correlation')

    # Add network boundaries: one horizontal and one vertical line on the
    # leading edge (first ROI index minus half a pixel) of every network
    for roi_idx in networks.values():
        boundary = roi_idx[0] - 0.5
        ax.axhline(boundary, color='black', linewidth=0.5)
        ax.axvline(boundary, color='black', linewidth=0.5)

plt.tight_layout()
plt.show()

## 3. Partial Correlation (Direct Connectivity)

In [None]:
# Partial correlation, computed with regularisation enabled (shrinkage 0.1)
partial_fc = partial_correlation(
    timeseries, roi_names=roi_names, regularize=True, shrinkage=0.1
)

# Pearson and partial estimates side by side; the partial panel uses a
# narrower colour range (+/-0.5) than the Pearson panel (+/-1)
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
ax_pearson, ax_partial = axes

im1 = ax_pearson.imshow(fc_result.matrix, cmap='RdBu_r', vmin=-1, vmax=1)
ax_pearson.set_title('Pearson Correlation')
plt.colorbar(im1, ax=ax_pearson)

im2 = ax_partial.imshow(partial_fc.matrix, cmap='RdBu_r', vmin=-0.5, vmax=0.5)
ax_partial.set_title('Partial Correlation')
plt.colorbar(im2, ax=ax_partial)

plt.tight_layout()
plt.show()

# Mean absolute strength over the strict upper triangle (each ROI pair once)
upper_tri = np.triu_indices(n_rois, k=1)
pearson_strength = np.mean(np.abs(fc_result.matrix[upper_tri]))
partial_strength = np.mean(np.abs(partial_fc.matrix[upper_tri]))
print(f"Pearson mean (off-diagonal): {pearson_strength:.3f}")
print(f"Partial mean (off-diagonal): {partial_strength:.3f}")

## 4. Dynamic Connectivity

In [None]:
# Dynamic FC from overlapping windows: 30 TRs long, advanced 5 TRs at a time
dfc = sliding_window_fc(
    timeseries, window_size=30, step_size=5, method='pearson', window_type='hamming'
)

print(f"Number of windows: {dfc.n_windows}")
print(f"Window size: {dfc.window_size} TRs")

# Per-connection variability across windows (metric='std')
fc_var = fc_variability(dfc, metric='std')

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_mean, ax_var = axes

# Left panel: FC averaged over windows
im1 = ax_mean.imshow(dfc.mean_fc(), cmap='RdBu_r', vmin=-1, vmax=1)
ax_mean.set_title('Mean Dynamic FC')
plt.colorbar(im1, ax=ax_mean)

# Right panel: variability map
im2 = ax_var.imshow(fc_var, cmap='hot')
ax_var.set_title('FC Variability (SD across windows)')
plt.colorbar(im2, ax=ax_var)

plt.tight_layout()
plt.show()

In [None]:
# Detect FC states
# A single value drives both the number of states requested and the number
# of centroid panels, so the two cannot drift apart when it is changed.
n_states = 4
states = detect_fc_states(dfc, n_states=n_states, method='kmeans')

print(f"Detected {states['n_states']} states")
print(f"State fractions: {states['state_fractions']}")

# Plot state centroids, one panel per requested state.
# squeeze=False keeps `axes` a 2-D array even for a single panel.
fig, axes = plt.subplots(1, n_states, figsize=(4 * n_states, 4), squeeze=False)
for i, ax in enumerate(axes.ravel()):
    im = ax.imshow(states['centroids'][i], cmap='RdBu_r', vmin=-1, vmax=1)
    ax.set_title(f"State {i+1} ({states['state_fractions'][i]:.1%})")
    plt.colorbar(im, ax=ax)
plt.tight_layout()
plt.show()

# Plot state time course: thin line for the sequence, coloured marker per window
state_labels = states['state_labels']
fig, ax = plt.subplots(figsize=(14, 3))
ax.plot(state_labels, 'k-', linewidth=0.5)
ax.scatter(range(len(state_labels)), state_labels,
           c=state_labels, cmap='tab10', s=20)
ax.set_xlabel('Window')
ax.set_ylabel('State')
ax.set_title('FC State Time Course')
plt.tight_layout()
plt.show()

## 5. Graph Theory Metrics

In [None]:
# Compute graph metrics

# Threshold FC to create a weighted adjacency matrix at 20% density
adj = fc_to_adjacency(fc_result.matrix, density=0.2, binarize=False)

# Compute metrics
graph = compute_graph_metrics(adj, weighted=True, node_names=roi_names)

print("\n" + "="*50)
print("GRAPH METRICS")
print("="*50)
print(f"Density: {graph.density:.3f}")
print(f"Global efficiency: {graph.global_efficiency:.3f}")
print(f"Clustering coefficient: {graph.clustering_coefficient:.3f}")
print(f"Characteristic path length: {graph.characteristic_path_length:.3f}")
print(f"Modularity: {graph.modularity:.3f}")
print(f"Number of modules: {graph.n_communities}")
# Compare against None explicitly: a truthiness test would hide a genuine
# value of 0.0 and would print an empty line when the metric is missing.
if graph.small_worldness is not None:
    print(f"Small-worldness (sigma): {graph.small_worldness:.3f}")

In [None]:
# Hub detection on the thresholded graph
hubs = hub_identification(adj, metrics=graph, threshold=1.5, method='multi')

print(f"\nIdentified {hubs['n_hubs']} hub nodes:")
for idx in hubs['hub_indices']:
    hub_degree = graph.degree[idx]
    hub_betweenness = graph.betweenness[idx]
    print(f"  {roi_names[idx]}: degree={hub_degree:.2f}, betweenness={hub_betweenness:.3f}")

# Node-level overview: degree, betweenness, community membership
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
roi_positions = range(n_rois)

# The first two panels are plain bar charts of a per-node metric
bar_panels = [
    (graph.degree, 'Degree', 'Node Degree'),
    (graph.betweenness, 'Betweenness', 'Betweenness Centrality'),
]
for ax, (values, ylabel, title) in zip(axes, bar_panels):
    ax.bar(roi_positions, values)
    ax.set_xlabel('ROI')
    ax.set_ylabel(ylabel)
    ax.set_title(title)

# Third panel: unit-height bars coloured by each node's community label
community_ax = axes[2]
colors = plt.cm.tab10(graph.community_assignments)
community_ax.bar(roi_positions, np.ones(n_rois), color=colors)
community_ax.set_xlabel('ROI')
community_ax.set_title('Community Assignment')
community_ax.set_yticks([])

plt.tight_layout()
plt.show()

## 6. EEG Phase-Based Connectivity

In [None]:
# Simulate EEG data
np.random.seed(42)

n_epochs = 100
n_channels = 32
sfreq = 250  # sampling rate in Hz
n_times = 500  # 2s at 250 Hz

# Background noise for every epoch and channel
eeg_data = np.random.randn(n_epochs, n_channels, n_times) * 5

# Sample times spaced exactly 1/sfreq apart. np.linspace(0, 2, n_times)
# includes the endpoint, giving a spacing of 2/(n_times - 1) s, which puts the
# sinusoids slightly off 10 Hz for an analysis that assumes `sfreq`.
t = np.arange(n_times) / sfreq

alpha_freq = 10  # Hz
alpha = np.sin(2 * np.pi * alpha_freq * t)
alpha_lagged = np.sin(2 * np.pi * alpha_freq * t + np.pi/4)

# Add 10 Hz alpha to occipital channels (same waveform in every epoch)
occipital = [28, 29, 30, 31]
for ch in occipital:
    eeg_data[:, ch, :] += 10 * alpha

# Add phase-locked alpha between occipital channels: the k-th channel in the
# list receives k copies of the pi/4-lagged component (none for the first),
# so each channel carries a different amount of lagged signal and therefore a
# different, consistent phase offset relative to the first channel.
for k, ch in enumerate(occipital):
    eeg_data[:, ch, :] += 3 * k * alpha_lagged

print(f"EEG data shape: {eeg_data.shape}")

In [None]:
# Alpha-band (8-13 Hz) weighted phase lag index across all channel pairs
channel_names = [f'Ch{i:02d}' for i in range(n_channels)]

alpha_band = {'fmin': 8, 'fmax': 13}
wpli = weighted_phase_lag_index(
    eeg_data,
    sfreq=sfreq,
    channel_names=channel_names,
    **alpha_band
)

print(f"wPLI matrix shape: {wpli.matrix.shape}")

# Heat map of the channel-by-channel matrix, colour scale capped at 0.5
fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(wpli.matrix, cmap='hot', vmin=0, vmax=0.5)
plt.colorbar(im, label='wPLI')
ax.set_title('Weighted Phase Lag Index (Alpha Band: 8-13 Hz)')
for set_label in (ax.set_xlabel, ax.set_ylabel):
    set_label('Channel')
plt.tight_layout()
plt.show()

In [None]:
# Compute connectivity in multiple bands
# Band name -> (low, high) frequency limits in Hz
bands = {
    'delta': (1, 4),
    'theta': (4, 8),
    'alpha': (8, 13),
    'beta': (13, 30)
}

band_conn = compute_connectivity_bands(
    eeg_data,
    sfreq=sfreq,
    method='wpli',
    bands=bands
)

# Plot band-specific connectivity, one panel per entry in `bands` so that
# adding or removing a band does not silently drop or break a panel.
# squeeze=False keeps `axes` a 2-D array even for a single band.
n_bands = len(bands)
fig, axes = plt.subplots(1, n_bands, figsize=(4 * n_bands, 4), squeeze=False)

for ax, (band_name, matrix) in zip(axes.ravel(), band_conn.items()):
    im = ax.imshow(matrix, cmap='hot', vmin=0, vmax=0.3)
    ax.set_title(f'{band_name.capitalize()} ({bands[band_name][0]}-{bands[band_name][1]} Hz)')
    plt.colorbar(im, ax=ax)

plt.tight_layout()
plt.show()

# Mean connectivity per band over the strict upper triangle (each channel
# pair once); the index arrays are the same for every band, so build them once
upper_tri = np.triu_indices(n_channels, k=1)
print("\nMean connectivity per band:")
for band_name, matrix in band_conn.items():
    mean_conn = np.mean(matrix[upper_tri])
    print(f"  {band_name}: {mean_conn:.3f}")

## Summary

In [None]:
# Final recap of the quantities computed in the sections above
print("\n" + "="*60)
print("CONNECTIVITY ANALYSIS SUMMARY")
print("="*60)
print("\nfMRI Connectivity:")
print("  - Pearson correlation FC computed")
print("  - Partial correlation (regularized) computed")
print(f"  - Dynamic FC: {dfc.n_windows} windows")
print(f"  - FC states: {states['n_states']} states detected")
print("\nGraph Metrics:")
print(f"  - Modularity: {graph.modularity:.3f}")
# Compare against None explicitly: a truthiness test would hide a genuine
# value of 0.0 and would print an empty line when the metric is missing.
if graph.small_worldness is not None:
    print(f"  - Small-world: {graph.small_worldness:.3f}")
print(f"  - Hub nodes: {hubs['n_hubs']}")
print("\nEEG Connectivity:")
print("  - wPLI computed for alpha band")
print("  - Multi-band connectivity computed")
print("="*60)