# Multivariate Decoding Analysis for Neural Replay

Based on **Liu et al. (2019) Cell**: *Human Replay Spontaneously Reorganizes Experience*

This notebook demonstrates the multivariate decoding algorithm for detecting sequential replay of neural representations during rest periods.

## Key Concepts

1. **Stimulus Decoders**: Train classifiers to recognize neural patterns associated with each stimulus
2. **Reactivation Detection**: Apply decoders to rest periods to detect spontaneous reactivations
3. **Sequenceness Measure**: Quantify the degree to which reactivations follow a sequential structure
4. **Statistical Testing**: Use permutation tests to establish significance
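
As a rough illustration of concepts 2 and 3, the sketch below builds a toy reactivation time series in which state 0 is followed by state 1 and then state 2 at a fixed lag, and recovers that lag with a time-lagged regression. The helper name `toy_sequenceness`, the 3-state chain, and the 4-sample lag are assumptions made for this example. It is a simplified stand-in for the full pipeline in Part 1 (no decoders, no alpha control).

```python
import numpy as np

def toy_sequenceness(probs, transitions, lags):
    """Forward-minus-backward sequenceness at each lag (in samples).

    probs: (n_timepoints, n_states); transitions[i, j] = 1 means i -> j.
    """
    n_states = probs.shape[1]
    out = np.zeros(len(lags))
    for k, lag in enumerate(lags):
        # Predict every state at time t from all states at time t - lag
        X = np.hstack([probs[:-lag], np.ones((probs.shape[0] - lag, 1))])
        Y = probs[lag:]
        # beta[i, j]: influence of state i (past) on state j (present)
        beta = np.linalg.lstsq(X, Y, rcond=None)[0][:n_states]
        out[k] = np.sum(beta * transitions) - np.sum(beta * transitions.T)
    return out

rng = np.random.default_rng(0)
probs = rng.random((2000, 3)) * 0.1            # low-level background noise
for onset in range(50, 1950, 50):              # hypothetical replay events
    for step, state in enumerate([0, 1, 2]):
        probs[onset + step * 4, state] += 1.0  # 4-sample lag between states

transitions = np.array([[0, 1, 0],
                        [0, 0, 1],
                        [0, 0, 0]])
lags = np.arange(1, 11)
seq = toy_sequenceness(probs, transitions, lags)
best_lag = lags[np.argmax(seq)]
print(f"Peak forward sequenceness at lag = {best_lag} samples")
```

The regression is run once per candidate lag, so the output is a curve over lags rather than a single number. The lag at which the curve peaks is the estimated replay speed.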

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score
import seaborn as sns
from scipy import stats

# Set style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

print("Libraries imported successfully!")

## Part 1: Implementation of the Decoding Algorithm

In [None]:
class MultivariateReplayDecoder:
    """
    Multivariate decoding analysis for detecting neural replay sequences
    based on Liu et al. (2019) Cell paper.
    """
    
    def __init__(self, n_states=8, max_lag_ms=600, sampling_rate=100):
        """
        Parameters:
        -----------
        n_states : int
            Number of distinct states/stimuli
        max_lag_ms : int
            Maximum time lag to test in milliseconds
        sampling_rate : int
            Sampling rate in Hz
        """
        self.n_states = n_states
        self.max_lag_ms = max_lag_ms
        self.sampling_rate = sampling_rate
        self.max_lag_samples = int(max_lag_ms * sampling_rate / 1000)
        self.classifiers = []
        self.scalers = []
        
    def train_classifiers(self, X_train, y_train, time_point_ms=200, C=1.0):
        """
        Train binary classifiers for each state using lasso logistic regression.
        
        Parameters:
        -----------
        X_train : ndarray, shape (n_trials, n_timepoints, n_sensors)
            Training data from functional localizer
        y_train : ndarray, shape (n_trials,)
            Labels for each trial (0 to n_states-1)
        time_point_ms : int
            Time point relative to stimulus onset to use for training
        C : float
            Inverse regularization strength
        """
        time_idx = int(time_point_ms * self.sampling_rate / 1000)
        X_at_time = X_train[:, time_idx, :]  # Extract specific time point
        
        self.classifiers = []
        self.scalers = []
        
        # Train one binary classifier per state
        for state in range(self.n_states):
            # Create binary labels
            y_binary = (y_train == state).astype(int)
            
            # Standardize features
            scaler = StandardScaler()
            X_scaled = scaler.fit_transform(X_at_time)
            
            # Train L1-regularized logistic regression
            clf = LogisticRegression(penalty='l1', C=C, solver='liblinear',
                                    max_iter=1000, random_state=42)
            clf.fit(X_scaled, y_binary)
            
            self.classifiers.append(clf)
            self.scalers.append(scaler)
            
        print(f"Trained {self.n_states} classifiers")
        
    def decode_states(self, X_rest):
        """
        Apply trained classifiers to resting-state data.
        
        Parameters:
        -----------
        X_rest : ndarray, shape (n_timepoints, n_sensors)
            Resting state MEG/EEG data
            
        Returns:
        --------
        probabilities : ndarray, shape (n_timepoints, n_states)
            Reactivation probabilities for each state at each time point
        """
        n_timepoints = X_rest.shape[0]
        probabilities = np.zeros((n_timepoints, self.n_states))
        
        for state in range(self.n_states):
            X_scaled = self.scalers[state].transform(X_rest)
            probabilities[:, state] = self.classifiers[state].predict_proba(X_scaled)[:, 1]
            
        return probabilities
    
    def compute_sequenceness(self, reactivation_probs, transition_matrix, 
                            time_lags_ms=None, alpha_control=True):
        """
        Compute sequenceness measure using time-lagged regression.
        
        Parameters:
        -----------
        reactivation_probs : ndarray, shape (n_timepoints, n_states)
            State reactivation probabilities from decode_states()
        transition_matrix : ndarray, shape (n_states, n_states)
            Hypothesized transition structure (1 for transitions, 0 otherwise)
        time_lags_ms : array-like, optional
            Specific time lags to test in milliseconds
        alpha_control : bool
            Whether to include nuisance regressors for 10Hz oscillations
            
        Returns:
        --------
        sequenceness : ndarray
            Sequenceness measure at each time lag (forward - backward)
        time_lags : ndarray
            Time lags tested in milliseconds
        """
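        if time_lags_ms is None:
            time_lags_ms = np.arange(10, self.max_lag_ms + 1, 10)
        # Accept lists and tuples as well as arrays
        time_lags_ms = np.asarray(time_lags_ms)
            
        # Round to the nearest whole sample instead of truncating
        time_lags_samples = np.rint(time_lags_ms * self.sampling_rate / 1000).astype(int)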
        sequenceness_forward = np.zeros(len(time_lags_ms))
        sequenceness_backward = np.zeros(len(time_lags_ms))
        
        Y = reactivation_probs  # Current activations
        n_timepoints, n_states = Y.shape
        
        # Normalize transition matrix
        P = transition_matrix / (np.sum(transition_matrix) + 1e-10)
        
        for lag_idx, lag in enumerate(time_lags_samples):
            # Design matrix of lagged predictors; row r predicts Y[offset + r]
            X_reg, offset = self._create_lagged_matrix(reactivation_probs, lag, alpha_control)
            Y_reg = Y[offset:, :]
            
            # Regression: Y(t) = X(t - Δt) * β, fitted for all target states at once
            # by ordinary least squares. beta[j, i] is the influence of state j
            # at time t - Δt on state i at time t.
            beta = np.linalg.lstsq(X_reg, Y_reg, rcond=None)[0]
            beta_matrix = beta[:n_states, :]  # rows: source state, columns: target state
            
            # Project onto transition matrix (Frobenius inner product)
            sequenceness_forward[lag_idx] = np.sum(beta_matrix * P)
            
            # Backward direction (transpose)
            sequenceness_backward[lag_idx] = np.sum(beta_matrix * P.T)
        
        # Sequenceness = forward - backward
        sequenceness = sequenceness_forward - sequenceness_backward
        
        return sequenceness, time_lags_ms
    
    def _create_lagged_matrix(self, reactivation_probs, lag, alpha_control=True):
        """
        Create time-lagged predictor matrix with optional alpha confounds.
        
        Returns (X_lag, offset), where row r of X_lag holds the predictors
        for the target sample reactivation_probs[offset + r].
        """
        if lag < 1:
            raise ValueError("lag must be at least one sample")
        
        n_timepoints, n_states = reactivation_probs.shape
        
        # Lag of interest first, so its coefficients are the first n_states rows
        lags = [lag]
        
        if alpha_control:
            # Add confound regressors at Δt+100ms, Δt+200ms, ... up to Δt+600ms
            for extra_lag_ms in range(100, 700, 100):
                total_lag = lag + int(extra_lag_ms * self.sampling_rate / 1000)
                if total_lag < n_timepoints:
                    lags.append(total_lag)
        
        # Targets start at the largest lag so every predictor block is defined
        offset = max(lags)
        n_rows = n_timepoints - offset
        blocks = [reactivation_probs[offset - k:offset - k + n_rows, :] for k in lags]
        
        # Add constant term
        X_lag = np.hstack(blocks + [np.ones((n_rows, 1))])
        
        return X_lag, offset
    
    def permutation_test(self, reactivation_probs, transition_matrix, 
                        n_permutations=1000, time_lags_ms=None):
        """
        Statistical testing via permutation of stimulus labels.
        
        Parameters:
        -----------
        reactivation_probs : ndarray
            State reactivation probabilities
        transition_matrix : ndarray
            Hypothesized transition structure
        n_permutations : int
            Number of permutations
        time_lags_ms : array-like, optional
            Time lags to test
            
        Returns:
        --------
        p_values : ndarray
            P-values at each time lag
        threshold : float
            Significance threshold (corrected for multiple comparisons)
        true_seq : ndarray
            True sequenceness values
        """
        # Compute true sequenceness
        true_seq, time_lags = self.compute_sequenceness(
            reactivation_probs, transition_matrix, time_lags_ms
        )
        
        # Permutation distribution
        perm_max_abs = np.zeros(n_permutations)
        
        for perm in range(n_permutations):
            # Permute transition matrix (shuffle rows and columns together)
            perm_indices = np.random.permutation(self.n_states)
            P_perm = transition_matrix[perm_indices, :][:, perm_indices]
            
            # Compute sequenceness with permuted matrix
            perm_seq, _ = self.compute_sequenceness(
                reactivation_probs, P_perm, time_lags_ms
            )
            
            # Store maximum absolute value across lags
            perm_max_abs[perm] = np.max(np.abs(perm_seq))
        
        # Threshold at 95th percentile
        threshold = np.percentile(perm_max_abs, 95)
        
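        # Compute p-values against the max-statistic null distribution
        # (add-one correction so that p is never exactly 0)
        p_values = np.array([
            (np.sum(perm_max_abs >= np.abs(seq_val)) + 1) / (n_permutations + 1)
            for seq_val in true_seq
        ])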
        
        return p_values, threshold, true_seq

print("MultivariateReplayDecoder class defined!")

## Part 2: Example 1 - Basic Simulation with Forward Replay

We'll simulate a simple scenario:
- 4 states: A, B, C, D
- Sequence: A → B → C → D
- Replay lag: 40ms (as in the paper)
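
The simulation works in samples rather than milliseconds, so the 40ms lag has to be converted first. The sketch below shows that arithmetic for the 100Hz sampling rate used in this notebook. The helper `ms_to_samples` is a hypothetical name introduced here and is not part of the decoder class.

```python
def ms_to_samples(ms, sampling_rate_hz):
    """Convert a duration in milliseconds to a whole number of samples."""
    samples = ms * sampling_rate_hz / 1000
    if abs(samples - round(samples)) > 1e-9:
        raise ValueError(
            f"{ms} ms is not a whole number of samples at {sampling_rate_hz} Hz"
        )
    return int(round(samples))

sampling_rate_hz = 100  # as in the simulation
lag = ms_to_samples(40, sampling_rate_hz)
sequence = ["A", "B", "C", "D"]

# Sample offset of each state relative to the start of one replay event
offsets = {name: step * lag for step, name in enumerate(sequence)}
print(lag, offsets)
```

At 100Hz one replay event therefore spans 12 samples (120ms) from A to D. A lag that is not a multiple of 10ms cannot be represented exactly at this sampling rate, which is why the helper raises instead of rounding silently.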

In [None]:
# Set random seed for reproducibility
np.random.seed(42)

# Parameters
n_trials = 200
n_timepoints_per_trial = 50  # 500ms at 100Hz
n_sensors = 100
n_states = 4

print("Creating simulated training data...")
print(f"  - {n_trials} trials")
print(f"  - {n_states} states (A, B, C, D)")
print(f"  - {n_sensors} sensors")
print(f"  - {n_timepoints_per_trial} timepoints per trial (500ms)")

# Simulate training data (functional localizer)
X_train = np.random.randn(n_trials, n_timepoints_per_trial, n_sensors) * 0.5
y_train = np.random.randint(0, n_states, n_trials)

# Add signal at 200ms for each state
time_idx = 20  # 200ms
for trial in range(n_trials):
    state = y_train[trial]
    # Add state-specific pattern to specific sensors
    X_train[trial, time_idx, state*10:(state+1)*10] += 2.0

print("\nTraining data created!")
print(f"Shape: {X_train.shape}")

In [None]:
# Visualize training data
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
axes = axes.ravel()

state_names = ['A', 'B', 'C', 'D']
for state in range(n_states):
    # Get trials for this state
    state_trials = X_train[y_train == state]
    
    # Average across trials and plot
    avg_activity = state_trials.mean(axis=0)  # (timepoints, sensors)
    
    im = axes[state].imshow(avg_activity.T, aspect='auto', cmap='RdBu_r', 
                            vmin=-0.5, vmax=0.5, origin='lower')
    axes[state].set_xlabel('Time (ms)', fontsize=10)
    axes[state].set_ylabel('Sensor', fontsize=10)
    axes[state].set_title(f'State {state_names[state]} - Average Activity', fontsize=12, fontweight='bold')
    axes[state].axvline(x=20, color='yellow', linestyle='--', linewidth=2, label='Decoding timepoint')
    axes[state].legend(loc='upper right', fontsize=8)
    
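    # Label x-axis ticks in milliseconds (one sample = 10 ms at 100 Hz).
    # Ticks are set explicitly so that labels and tick positions always match.
    xticks = np.arange(0, n_timepoints_per_trial, 10)
    axes[state].set_xticks(xticks)
    axes[state].set_xticklabels(xticks * 10)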

plt.colorbar(im, ax=axes, label='Activity (a.u.)', fraction=0.046, pad=0.04)
plt.tight_layout()
plt.savefig('training_data_heatmap.png', dpi=150, bbox_inches='tight')
plt.show()

print("Training data visualization saved!")

In [None]:
# Initialize and train decoder
print("Initializing decoder...")
decoder = MultivariateReplayDecoder(n_states=n_states, max_lag_ms=600, sampling_rate=100)

print("\nTraining classifiers...")
decoder.train_classifiers(X_train, y_train, time_point_ms=200, C=1.0)

# Evaluate decoder performance with cross-validation
print("\nEvaluating decoder performance (cross-validation)...")
time_idx = 20
X_at_time = X_train[:, time_idx, :]

for state in range(n_states):
    y_binary = (y_train == state).astype(int)
    scores = cross_val_score(decoder.classifiers[state], 
                            decoder.scalers[state].transform(X_at_time), 
                            y_binary, cv=5, scoring='roc_auc')
    print(f"  State {state_names[state]}: AUC = {scores.mean():.3f} ± {scores.std():.3f}")

In [None]:
# Simulate resting state data with embedded replay
print("Simulating resting state with replay sequence A→B→C→D...")
n_rest_timepoints = 3000  # 30 seconds at 100Hz
X_rest = np.random.randn(n_rest_timepoints, n_sensors) * 0.3

# Embed sequence A→B→C→D at 40ms lag, at multiple time points
lag_samples = 4  # 40ms at 100Hz
sequence = [0, 1, 2, 3]  # A→B→C→D
replay_times = []

for start_time in range(100, n_rest_timepoints - 100, 200):
    replay_times.append(start_time)
    for step, state in enumerate(sequence):
        time_point = start_time + step * lag_samples
        if time_point < n_rest_timepoints:
            # Add signal to state-specific sensors
            X_rest[time_point, state*10:(state+1)*10] += 1.5

print(f"  Embedded {len(replay_times)} replay events")
print(f"  Replay lag: 40ms between states")
print("  Replay interval: 2000ms (200 samples at 100Hz) between event onsets")

In [None]:
# Decode states from resting data
print("\nDecoding states from resting data...")
reactivation_probs = decoder.decode_states(X_rest)
print(f"  Reactivation matrix shape: {reactivation_probs.shape}")
print(f"  Time resolution: 10ms")

# Visualize reactivations
fig, ax = plt.subplots(figsize=(14, 5))

# Plot first 5 seconds
time_window = slice(0, 500)
time_ms = np.arange(500) * 10

for state in range(n_states):
    ax.plot(time_ms, reactivation_probs[time_window, state], 
           label=f'State {state_names[state]}', linewidth=1.5, alpha=0.8)

# Mark embedded replay events
for rt in replay_times:
    if rt < 500:
        ax.axvline(x=rt*10, color='red', linestyle='--', alpha=0.3, linewidth=1)

ax.set_xlabel('Time (ms)', fontsize=12)
ax.set_ylabel('Reactivation Probability', fontsize=12)
ax.set_title('State Reactivations During Rest (First 5 seconds)', fontsize=14, fontweight='bold')
ax.legend(loc='upper right', ncol=4)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('reactivation_timeseries.png', dpi=150, bbox_inches='tight')
plt.show()

print("Reactivation visualization saved!")

In [None]:
# Define transition matrix for A→B→C→D
transition_matrix = np.array([
    [0, 1, 0, 0],  # A → B
    [0, 0, 1, 0],  # B → C
    [0, 0, 0, 1],  # C → D
    [0, 0, 0, 0]   # D → nothing
])

print("Transition matrix (A→B→C→D):")
print(transition_matrix)

# Visualize transition matrix
fig, ax = plt.subplots(figsize=(6, 5))
im = ax.imshow(transition_matrix, cmap='Blues', vmin=0, vmax=1)
ax.set_xticks(range(n_states))
ax.set_yticks(range(n_states))
ax.set_xticklabels(state_names)
ax.set_yticklabels(state_names)
ax.set_xlabel('To State', fontsize=12)
ax.set_ylabel('From State', fontsize=12)
ax.set_title('Hypothesized Transition Matrix', fontsize=14, fontweight='bold')

# Add text annotations
for i in range(n_states):
    for j in range(n_states):
        text = ax.text(j, i, int(transition_matrix[i, j]),
                      ha="center", va="center", color="black", fontsize=16)

plt.colorbar(im, ax=ax, label='Transition')
plt.tight_layout()
plt.savefig('transition_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Compute sequenceness
print("Computing sequenceness...")
sequenceness, time_lags = decoder.compute_sequenceness(
    reactivation_probs, transition_matrix, alpha_control=True
)

# Find peak
peak_idx = np.argmax(np.abs(sequenceness))
peak_lag = time_lags[peak_idx]
peak_value = sequenceness[peak_idx]

print(f"\nResults:")
print(f"  Peak sequenceness: {peak_value:.4f}")
print(f"  Peak lag: {peak_lag}ms")
print(f"  Expected lag: 40ms (embedded in data)")
print(f"  Direction: {'Forward' if peak_value > 0 else 'Backward'}")

In [None]:
# Run permutation test
print("\nRunning permutation test (500 permutations)...")
print("This may take a minute...")

p_values, threshold, true_seq = decoder.permutation_test(
    reactivation_probs, transition_matrix, n_permutations=500
)

significant_lags = time_lags[p_values < 0.05]
print(f"\nStatistical Results:")
print(f"  Significance threshold: {threshold:.4f}")
print(f"  Number of significant lags: {len(significant_lags)}")
if len(significant_lags) > 0:
    print(f"  Significant time lags: {significant_lags}ms")
    print(f"  Peak p-value: {p_values[peak_idx]:.4f}")

In [None]:
# Visualize sequenceness results
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

# Plot 1: Sequenceness
ax1.plot(time_lags, sequenceness, 'b-', linewidth=2.5, label='Sequenceness')
ax1.axhline(y=0, color='k', linestyle='--', alpha=0.3)
ax1.axhline(y=threshold, color='r', linestyle='--', linewidth=2, label=f'Threshold (p<0.05)')
ax1.axhline(y=-threshold, color='r', linestyle='--', linewidth=2)
ax1.axvline(x=40, color='green', linestyle=':', linewidth=2, alpha=0.7, label='True lag (40ms)')
ax1.fill_between(time_lags, 0, sequenceness, where=(sequenceness > threshold), 
                 alpha=0.3, color='blue', label='Significant')
ax1.set_xlabel('Time Lag (ms)', fontsize=12)
ax1.set_ylabel('Sequenceness\n(Forward - Backward)', fontsize=12)
ax1.set_title('Example 1: Forward Replay Detection (A→B→C→D)', fontsize=14, fontweight='bold')
ax1.legend(loc='upper right')
ax1.grid(True, alpha=0.3)

# Plot 2: P-values
ax2.plot(time_lags, p_values, 'purple', linewidth=2)
ax2.axhline(y=0.05, color='r', linestyle='--', linewidth=2, label='p=0.05')
# Shade between the curve and the 0.05 line (a lower bound of 0 is undefined on a log axis)
ax2.fill_between(time_lags, p_values, 0.05, where=(p_values < 0.05), 
                alpha=0.3, color='purple')
ax2.set_xlabel('Time Lag (ms)', fontsize=12)
ax2.set_ylabel('P-value', fontsize=12)
ax2.set_title('Statistical Significance', fontsize=12)
ax2.set_yscale('log')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('sequenceness_example1.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nSequenceness analysis complete!")

## Part 3: Example 2 - Reverse Replay After Reward

The paper showed that replay reverses direction after reward. Let's simulate this phenomenon.
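
One way to see why the same transition matrix can be reused for both directions: sequenceness is the forward projection minus the backward projection, and the backward projection uses the transposed matrix, so data that replay in reverse give a negative value. The sketch below uses made-up influence matrices (`beta_forward`, `beta_reverse`) and a hypothetical helper `sequenceness_from_beta` to show the sign flip. It assumes the convention that `beta[i, j]` is the influence of state `i` at an earlier time on state `j` at a later time.

```python
import numpy as np

def sequenceness_from_beta(beta, transitions):
    """Forward minus backward projection of a lagged-influence matrix.

    beta[i, j]: (hypothetical) influence of state i at time t - lag
    on state j at time t. transitions[i, j] = 1 encodes the hypothesis i -> j.
    """
    P = transitions / transitions.sum()
    return np.sum(beta * P) - np.sum(beta * P.T)

forward_T = np.array([[0, 1, 0, 0],
                      [0, 0, 1, 0],
                      [0, 0, 0, 1],
                      [0, 0, 0, 0]], dtype=float)

# Made-up influence matrices: one for A->B->C->D replay, one for D->C->B->A
beta_forward = 0.8 * forward_T
beta_reverse = 0.8 * forward_T.T

s_fwd = sequenceness_from_beta(beta_forward, forward_T)
s_rev = sequenceness_from_beta(beta_reverse, forward_T)
print(f"forward data: {s_fwd:+.2f}, reverse data: {s_rev:+.2f}")
```

This is why the cells below keep `transition_matrix` unchanged and read the direction of replay from the sign of the peak.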

In [None]:
# Simulate resting state with REVERSE replay (D→C→B→A)
print("Simulating resting state with REVERSE replay (D→C→B→A)...")
X_rest_reverse = np.random.randn(n_rest_timepoints, n_sensors) * 0.3

# Embed REVERSE sequence D→C→B→A at 40ms lag
reverse_sequence = [3, 2, 1, 0]  # D→C→B→A
replay_times_reverse = []

for start_time in range(100, n_rest_timepoints - 100, 200):
    replay_times_reverse.append(start_time)
    for step, state in enumerate(reverse_sequence):
        time_point = start_time + step * lag_samples
        if time_point < n_rest_timepoints:
            X_rest_reverse[time_point, state*10:(state+1)*10] += 1.5

print(f"  Embedded {len(replay_times_reverse)} REVERSE replay events")
print(f"  Replay lag: 40ms between states")

In [None]:
# Decode reverse replay
print("\nDecoding reverse replay...")
reactivation_probs_reverse = decoder.decode_states(X_rest_reverse)

# Compute sequenceness for reverse replay
print("Computing sequenceness for reverse replay...")
sequenceness_reverse, _ = decoder.compute_sequenceness(
    reactivation_probs_reverse, transition_matrix, alpha_control=True
)

# Find peak
peak_idx_reverse = np.argmax(np.abs(sequenceness_reverse))
peak_lag_reverse = time_lags[peak_idx_reverse]
peak_value_reverse = sequenceness_reverse[peak_idx_reverse]

print(f"\nResults:")
print(f"  Peak sequenceness: {peak_value_reverse:.4f}")
print(f"  Peak lag: {peak_lag_reverse}ms")
print(f"  Direction: {'Forward' if peak_value_reverse > 0 else 'Backward (Reverse)'}")

In [None]:
# Compare forward and reverse replay
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Forward replay
axes[0].plot(time_lags, sequenceness, 'b-', linewidth=2.5)
axes[0].axhline(y=0, color='k', linestyle='--', alpha=0.3)
axes[0].axhline(y=threshold, color='r', linestyle='--', linewidth=2)
axes[0].axhline(y=-threshold, color='r', linestyle='--', linewidth=2)
axes[0].axvline(x=40, color='green', linestyle=':', linewidth=2, alpha=0.7)
axes[0].fill_between(time_lags, 0, sequenceness, where=(sequenceness > 0), 
                     alpha=0.3, color='blue')
axes[0].set_xlabel('Time Lag (ms)', fontsize=12)
axes[0].set_ylabel('Sequenceness', fontsize=12)
axes[0].set_title('Forward Replay\n(Before Reward: A→B→C→D)', fontsize=12, fontweight='bold')
axes[0].grid(True, alpha=0.3)
axes[0].text(0.05, 0.95, 'Forward\n(positive)', transform=axes[0].transAxes,
            fontsize=11, verticalalignment='top', bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.5))

# Reverse replay
axes[1].plot(time_lags, sequenceness_reverse, 'r-', linewidth=2.5)
axes[1].axhline(y=0, color='k', linestyle='--', alpha=0.3)
axes[1].axhline(y=threshold, color='r', linestyle='--', linewidth=2)
axes[1].axhline(y=-threshold, color='r', linestyle='--', linewidth=2)
axes[1].axvline(x=40, color='green', linestyle=':', linewidth=2, alpha=0.7)
axes[1].fill_between(time_lags, 0, sequenceness_reverse, where=(sequenceness_reverse < 0), 
                     alpha=0.3, color='red')
axes[1].set_xlabel('Time Lag (ms)', fontsize=12)
axes[1].set_ylabel('Sequenceness', fontsize=12)
axes[1].set_title('Reverse Replay\n(After Reward: D→C→B→A)', fontsize=12, fontweight='bold')
axes[1].grid(True, alpha=0.3)
axes[1].text(0.05, 0.05, 'Backward\n(negative)', transform=axes[1].transAxes,
            fontsize=11, verticalalignment='bottom', bbox=dict(boxstyle='round', facecolor='lightcoral', alpha=0.5))

plt.suptitle('Example 2: Direction Reversal After Reward', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig('sequenceness_example2_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nComparison complete!")

## Part 4: Example 3 - Multiple Sequences

Real experiments often involve several candidate sequences. Here we define two (states 0→1→2→3 and 4→5→6→7), embed only the first in the rest data, and check that sequenceness is specific to it.
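
With several candidate sequences it can be convenient to build each hypothesis matrix from a list of states instead of setting entries by hand. The helper below, `transitions_from_sequence`, is a hypothetical convenience function and not part of the decoder class. It follows the same convention as the rest of the notebook (rows are "from" states, columns are "to" states).

```python
import numpy as np

def transitions_from_sequence(sequence, n_states):
    """Return an (n_states, n_states) matrix with 1 at [a, b] for each step a -> b."""
    T = np.zeros((n_states, n_states))
    for a, b in zip(sequence[:-1], sequence[1:]):
        T[a, b] = 1
    return T

T_seq1 = transitions_from_sequence([0, 1, 2, 3], n_states=8)
T_seq2 = transitions_from_sequence([4, 5, 6, 7], n_states=8)

# The two hypotheses share no transitions, so they can be tested separately
overlap = int(np.sum(T_seq1 * T_seq2))
print(f"Transitions per sequence: {int(T_seq1.sum())}, {int(T_seq2.sum())}; overlap: {overlap}")
```

Because the two matrices do not overlap, evidence for one sequence cannot leak into the projection for the other.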

In [None]:
# Use 8 states for two sequences
n_states_multi = 8
decoder_multi = MultivariateReplayDecoder(n_states=n_states_multi, max_lag_ms=600)

# Create training data for 8 states
print("Creating training data for 8 states (2 sequences)...")
X_train_multi = np.random.randn(400, n_timepoints_per_trial, n_sensors) * 0.5
y_train_multi = np.random.randint(0, n_states_multi, 400)

# Add signal for each state
for trial in range(400):
    state = y_train_multi[trial]
    X_train_multi[trial, time_idx, state*12:(state+1)*12] += 2.0

# Train decoder
print("Training decoder for 8 states...")
decoder_multi.train_classifiers(X_train_multi, y_train_multi, time_point_ms=200)

# Define two sequences: Seq1 (0→1→2→3) and Seq2 (4→5→6→7)
transition_matrix_seq1 = np.zeros((8, 8))
transition_matrix_seq1[0, 1] = 1  # 0→1
transition_matrix_seq1[1, 2] = 1  # 1→2
transition_matrix_seq1[2, 3] = 1  # 2→3

transition_matrix_seq2 = np.zeros((8, 8))
transition_matrix_seq2[4, 5] = 1  # 4→5
transition_matrix_seq2[5, 6] = 1  # 5→6
transition_matrix_seq2[6, 7] = 1  # 6→7

print("\nSequence 1: States 0→1→2→3")
print("Sequence 2: States 4→5→6→7")

In [None]:
# Simulate rest with ONLY Sequence 1 replay
print("\nSimulating rest with Sequence 1 replay only...")
X_rest_multi = np.random.randn(n_rest_timepoints, n_sensors) * 0.3

# Embed Sequence 1: 0→1→2→3
seq1 = [0, 1, 2, 3]
for start_time in range(100, n_rest_timepoints - 100, 200):
    for step, state in enumerate(seq1):
        time_point = start_time + step * lag_samples
        if time_point < n_rest_timepoints:
            X_rest_multi[time_point, state*12:(state+1)*12] += 1.5

# Decode
print("Decoding...")
reactivation_probs_multi = decoder_multi.decode_states(X_rest_multi)

# Compute sequenceness for both sequences
print("\nComputing sequenceness for both sequences...")
seq1_sequenceness, _ = decoder_multi.compute_sequenceness(
    reactivation_probs_multi, transition_matrix_seq1
)
seq2_sequenceness, _ = decoder_multi.compute_sequenceness(
    reactivation_probs_multi, transition_matrix_seq2
)

# Compare peaks
peak1 = np.max(np.abs(seq1_sequenceness))
peak2 = np.max(np.abs(seq2_sequenceness))

print(f"\nResults:")
print(f"  Sequence 1 peak: {peak1:.4f}")
print(f"  Sequence 2 peak: {peak2:.4f}")
print(f"  Ratio (Seq1/Seq2): {peak1/peak2:.2f}")

In [None]:
# Visualize both sequences
fig, ax = plt.subplots(figsize=(12, 6))

ax.plot(time_lags, seq1_sequenceness, 'b-', linewidth=2.5, label='Sequence 1 (0→1→2→3)', alpha=0.8)
ax.plot(time_lags, seq2_sequenceness, 'orange', linewidth=2.5, label='Sequence 2 (4→5→6→7)', alpha=0.8)
ax.axhline(y=0, color='k', linestyle='--', alpha=0.3)
ax.axvline(x=40, color='green', linestyle=':', linewidth=2, alpha=0.7, label='True lag (40ms)')

ax.set_xlabel('Time Lag (ms)', fontsize=12)
ax.set_ylabel('Sequenceness', fontsize=12)
ax.set_title('Example 3: Sequence-Specific Replay\n(Only Sequence 1 was embedded in data)', 
            fontsize=14, fontweight='bold')
ax.legend(loc='upper right', fontsize=11)
ax.grid(True, alpha=0.3)

# Add annotation
ax.text(0.98, 0.95, f'Sequence 1 replays at 40ms\nSequence 2 shows no replay',
       transform=ax.transAxes, fontsize=11, verticalalignment='top', horizontalalignment='right',
       bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.tight_layout()
plt.savefig('sequenceness_example3_multi.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nMulti-sequence analysis complete!")

## Part 5: Summary and Key Findings

This notebook demonstrated the multivariate decoding algorithm from Liu et al. (2019):

In [None]:
# Create summary visualization
fig = plt.figure(figsize=(14, 10))
gs = fig.add_gridspec(3, 2, hspace=0.3, wspace=0.3)

# 1. Training phase
ax1 = fig.add_subplot(gs[0, :])
ax1.text(0.5, 0.5, 'Step 1: Train Classifiers\n\n' +
        '• Record neural activity during stimulus presentations\n' +
        '• Train binary classifiers (one per stimulus)\n' +
        '• Use L1-regularized logistic regression\n' +
        '• Focus on specific time point (200ms post-stimulus)',
        transform=ax1.transAxes, fontsize=12, verticalalignment='center',
        horizontalalignment='center', bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.7))
ax1.axis('off')
ax1.set_title('Algorithm Overview', fontsize=16, fontweight='bold', pad=20)

# 2. Decoding phase
ax2 = fig.add_subplot(gs[1, 0])
ax2.text(0.5, 0.5, 'Step 2: Decode Rest Activity\n\n' +
        '• Apply classifiers to rest data\n' +
        '• Generate reactivation probabilities\n' +
        '• Create time series for each state',
        transform=ax2.transAxes, fontsize=11, verticalalignment='center',
        horizontalalignment='center', bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.7))
ax2.axis('off')

# 3. Sequenceness computation
ax3 = fig.add_subplot(gs[1, 1])
ax3.text(0.5, 0.5, 'Step 3: Compute Sequenceness\n\n' +
        '• Time-lagged regression\n' +
        '• Test multiple time lags\n' +
        '• Control for 10Hz oscillations\n' +
        '• Forward - Backward measure',
        transform=ax3.transAxes, fontsize=11, verticalalignment='center',
        horizontalalignment='center', bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.7))
ax3.axis('off')

# 4. Key findings visualization
ax4 = fig.add_subplot(gs[2, :])
findings_text = (
    'KEY FINDINGS:\n\n'
    '1. FORWARD REPLAY: Sequences play forward during learning/exploration (A→B→C→D)\n'
    f'   Example 1 peak: {peak_value:.3f} at {peak_lag}ms\n\n'
    '2. REVERSE REPLAY: Sequences reverse after reward (D→C→B→A)\n'
    f'   Example 2 peak: {peak_value_reverse:.3f} at {peak_lag_reverse}ms\n\n'
    '3. SEQUENCE-SPECIFIC: Only experienced sequences show replay\n'
    f'   Example 3: Seq1/Seq2 ratio = {peak1/peak2:.2f}\n\n'
    '4. TIME-COMPRESSED: Replay occurs at ~40-50ms intervals (faster than experience)\n\n'
    '5. STATISTICAL VALIDATION: Permutation tests confirm significance'
)
ax4.text(0.05, 0.95, findings_text, transform=ax4.transAxes, fontsize=11,
        verticalalignment='top', fontfamily='monospace',
        bbox=dict(boxstyle='round', facecolor='lavender', alpha=0.8))
ax4.axis('off')

plt.suptitle('Multivariate Decoding Analysis - Summary', fontsize=18, fontweight='bold', y=0.98)
plt.savefig('analysis_summary.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n" + "="*70)
print("ANALYSIS COMPLETE!")
print("="*70)
print("\nGenerated files:")
print("  • training_data_heatmap.png")
print("  • reactivation_timeseries.png")
print("  • transition_matrix.png")
print("  • sequenceness_example1.png")
print("  • sequenceness_example2_comparison.png")
print("  • sequenceness_example3_multi.png")
print("  • analysis_summary.png")
print("\nThis implementation demonstrates the core algorithm from:")
print("Liu et al. (2019) Cell: Human Replay Spontaneously Reorganizes Experience")
print("="*70)

## Next Steps

To apply this to real data:

1. **Load your MEG/EEG data**: Replace simulated data with real recordings
2. **Adjust parameters**: 
   - Sampling rate
   - Number of sensors
   - Regularization strength (C parameter)
   - Time windows
3. **Preprocess data**:
   - Filter (high-pass at 0.5 Hz)
   - Artifact rejection
   - Baseline correction
4. **Define your sequences**: Based on experimental design
5. **Interpret results**: Consider:
   - Behavioral relevance
   - Task structure
   - Individual differences
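
As a starting point for step 3, the sketch below shows one possible way to high-pass filter a continuous recording at 0.5 Hz and to baseline-correct epoched trials. The function names `highpass` and `baseline_correct`, the 4th-order Butterworth design, and the synthetic input are all assumptions for illustration. Real pipelines typically rely on a dedicated MEG/EEG toolbox, and artifact rejection is not covered here.

```python
import numpy as np
from scipy.signal import butter, sosfiltfilt

def highpass(data, sampling_rate_hz, cutoff_hz=0.5, order=4):
    """Zero-phase Butterworth high-pass along the time axis (axis 0)."""
    sos = butter(order, cutoff_hz, btype="highpass", fs=sampling_rate_hz, output="sos")
    return sosfiltfilt(sos, data, axis=0)

def baseline_correct(epochs, n_baseline_samples):
    """Subtract each trial's mean over its first n_baseline_samples time points.

    epochs: (n_trials, n_timepoints, n_sensors)
    """
    baseline = epochs[:, :n_baseline_samples, :].mean(axis=1, keepdims=True)
    return epochs - baseline

# Synthetic stand-in for a continuous recording: 30 s at 100 Hz, 5 sensors
fs = 100
t = np.arange(30 * fs) / fs
raw = 5.0 + np.sin(2 * np.pi * 10 * t)[:, None] * np.ones((1, 5))  # DC offset + 10 Hz
filtered = highpass(raw, fs)

# Synthetic epochs with a constant offset of 3 on every sensor
epochs = np.random.default_rng(0).normal(loc=3.0, size=(20, 50, 5))
corrected = baseline_correct(epochs, n_baseline_samples=10)

print(f"Mean after high-pass: {filtered.mean():.3f}")
print(f"Largest residual baseline mean: {np.abs(corrected[:, :10].mean(axis=1)).max():.2e}")
```

Filtering is applied to the continuous data before epoching so that filter edge effects do not fall inside individual trials.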

---

**Citation:**
Liu, Y., Dolan, R. J., Kurth-Nelson, Z., & Behrens, T. E. (2019). 
Human replay spontaneously reorganizes experience. *Cell*, 178(3), 640-652.