# Animation and Dynamics

This comprehensive tutorial demonstrates how to create dynamic visualizations and animations of neural data using BrainTools. We'll explore techniques for visualizing temporal evolution, creating movies of network dynamics, and optimizing performance for large datasets.

## Learning Objectives

By the end of this tutorial, you will be able to:
- Create animated visualizations of neural activity over time
- Apply 1D and 2D animation techniques to neural data
- Generate movies of evolving network dynamics
- Visualize learning processes through time-lapse animations
- Display dynamic changes in connectivity patterns
- Export animations in various video formats
- Optimize performance for large-scale temporal datasets

## 1. Setup and Imports <a id='setup'></a>

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML, display
from matplotlib.animation import FuncAnimation, PillowWriter, FFMpegWriter

import braintools

# Set up output directory
output_dir = Path('animations')
output_dir.mkdir(exist_ok=True)

# Render matplotlib figures inline in Jupyter
%matplotlib inline

# For animations in notebook
from matplotlib import rc

rc('animation', html='jshtml')

# Set random seed
np.random.seed(42)
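
Later cells try MP4 export and fall back to GIF, so it helps to know up front which writers this environment supports. The quick check below uses matplotlib's writer registry (`matplotlib.animation.writers`) and `FFMpegWriter.isAvailable()`. It does not use anything from braintools.

```python
from matplotlib import animation

# Names of the movie writers matplotlib can use in this environment
available = animation.writers.list()
print("Available writers:", available)

# True only if matplotlib can locate the ffmpeg executable
has_ffmpeg = animation.FFMpegWriter.isAvailable()
print("ffmpeg available:", has_ffmpeg)
```

The `'pillow'` writer ships with matplotlib, so GIF export should work even when ffmpeg is missing.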

## 2. Generate Dynamic Neural Data

In [None]:
# Generate synthetic dynamic neural data

def generate_dynamic_data(n_frames=100, n_neurons=50, duration=10.0):
    """Generate time-varying neural data for animation demonstrations."""

    data = {}

    # 1. Time-varying spike trains (propagating wave)
    spike_data = []
    for frame in range(n_frames):
        frame_spikes = []
        wave_center = frame / n_frames * duration

        for neuron in range(n_neurons):
            # Create wave of activity
            phase = 2 * np.pi * neuron / n_neurons
            rate = 10 * (1 + np.sin(2 * np.pi * frame / n_frames + phase))

            n_spikes = np.random.poisson(rate * duration / n_frames)
            spike_times = np.random.uniform(
                wave_center - 0.5, wave_center + 0.5, n_spikes
            )
            spike_times = spike_times[(spike_times >= 0) & (spike_times <= duration)]
            frame_spikes.append(spike_times)

        spike_data.append(frame_spikes)
    data['spike_trains'] = spike_data

    # 2. Oscillating membrane potential
    time_points = np.linspace(0, duration, 1000)
    membrane_data = np.zeros((n_frames, len(time_points)))

    for frame in range(n_frames):
        phase_shift = 2 * np.pi * frame / n_frames
        base = -70 + 5 * np.sin(phase_shift)
        oscillation = 10 * np.sin(2 * np.pi * time_points + phase_shift)
        noise = np.random.normal(0, 2, len(time_points))
        membrane_data[frame] = base + oscillation + noise

    data['membrane_potential'] = membrane_data
    data['time'] = time_points

    # 3. 2D activity patterns (traveling waves)
    grid_size = 50
    activity_frames = np.zeros((n_frames, grid_size, grid_size))

    for frame in range(n_frames):
        x = np.linspace(-2, 2, grid_size)
        y = np.linspace(-2, 2, grid_size)
        X, Y = np.meshgrid(x, y)

        # Traveling wave
        wave_pos = -2 + 4 * frame / n_frames
        activity = np.exp(-((X - wave_pos) ** 2 + Y ** 2) / 0.5)

        # Rotating spiral
        theta = np.arctan2(Y, X) + 2 * np.pi * frame / n_frames
        spiral = 0.5 * (1 + np.sin(3 * theta - np.sqrt(X ** 2 + Y ** 2)))

        activity_frames[frame] = activity + 0.3 * spiral
        activity_frames[frame] += np.random.normal(0, 0.1, (grid_size, grid_size))

    data['activity_2d'] = activity_frames

    # 4. Evolving connectivity
    connectivity_frames = np.zeros((n_frames, 20, 20))
    base_connectivity = np.random.randn(20, 20)

    for frame in range(n_frames):
        # Gradually strengthen/weaken connections
        modulation = np.sin(2 * np.pi * frame / n_frames)
        connectivity_frames[frame] = base_connectivity * (1 + 0.5 * modulation)

        # Add plastic changes
        if frame > 0:
            plastic_change = np.random.normal(0, 0.01, (20, 20))
            connectivity_frames[frame] += plastic_change

    data['connectivity'] = connectivity_frames

    # 5. Learning curve data
    learning_steps = 100
    performance = np.zeros(learning_steps)
    error = np.zeros(learning_steps)

    for step in range(learning_steps):
        performance[step] = 1 - np.exp(-step / 20) + np.random.normal(0, 0.05)
        error[step] = np.exp(-step / 15) + np.random.normal(0, 0.02)

    data['learning'] = {'performance': performance, 'error': error}

    return data


# Generate all dynamic data
print("Generating dynamic neural data...")
dynamic_data = generate_dynamic_data(n_frames=100)

print("\nGenerated data:")
print(f"  Spike trains: {len(dynamic_data['spike_trains'])} frames")
print(f"  Membrane potential: shape {dynamic_data['membrane_potential'].shape}")
print(f"  2D activity: shape {dynamic_data['activity_2d'].shape}")
print(f"  Connectivity: shape {dynamic_data['connectivity'].shape}")
print(f"  Learning data: {len(dynamic_data['learning']['performance'])} steps")
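
The generator above rebuilds the same `meshgrid` on every frame. Here is a minimal sketch, assuming you only need the traveling-wave component without noise, that builds all frames at once with NumPy broadcasting. `traveling_wave_frames` is a hypothetical helper, not part of braintools.

```python
import numpy as np


def traveling_wave_frames(n_frames=100, grid_size=50, width=0.5):
    """Vectorized traveling-wave component of generate_dynamic_data (noise omitted)."""
    x = np.linspace(-2, 2, grid_size)
    X, Y = np.meshgrid(x, x)                              # (grid, grid), computed once
    wave_pos = -2 + 4 * np.arange(n_frames) / n_frames    # (n_frames,)
    # Broadcast (n_frames, 1, 1) against (1, grid, grid) -> (n_frames, grid, grid)
    dx = X[None, :, :] - wave_pos[:, None, None]
    return np.exp(-(dx ** 2 + Y[None, :, :] ** 2) / width)


frames = traveling_wave_frames()
print(frames.shape)  # (100, 50, 50)
```

This trades a Python loop for one array expression. Memory use is the same, because the full `(n_frames, grid, grid)` array is stored either way.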

## 3. Basic Animation Techniques <a id='basic'></a>

Before applying animations to neural data, we review the fundamentals of matplotlib animation.
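
A common variant of the `FuncAnimation` pattern adds an `init_func` that draws a clear first frame. With `blit=True`, the artists it returns tell matplotlib which parts of the figure will change. The minimal sketch below uses a single sine line. The names `init` and `update` are illustrative.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation

x = np.linspace(0, 2 * np.pi, 100)
fig, ax = plt.subplots(figsize=(6, 3))
ax.set_xlim(0, 2 * np.pi)
ax.set_ylim(-1.5, 1.5)
line, = ax.plot([], [], 'b-', linewidth=2)


def init():
    # Draw an empty frame; the returned artists are the ones blitting will redraw
    line.set_data([], [])
    return line,


def update(frame):
    # Shift the phase by 0.1 rad per frame
    line.set_data(x, np.sin(x + 0.1 * frame))
    return line,


anim = FuncAnimation(fig, update, frames=100, init_func=init,
                     interval=50, blit=True)
```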

In [None]:
# Basic animation using FuncAnimation

# Example 1: Simple oscillating sine wave
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Setup axes
ax1.set_xlim(0, 2 * np.pi)
ax1.set_ylim(-1.5, 1.5)
ax1.set_xlabel('x')
ax1.set_ylabel('y')
ax1.set_title('Animated Sine Wave')
ax1.grid(True, alpha=0.3)

# Create line object
x = np.linspace(0, 2 * np.pi, 100)
line, = ax1.plot([], [], 'b-', linewidth=2)


# Animation function
def animate_sine(frame):
    y = np.sin(x + 0.1 * frame)
    line.set_data(x, y)
    return line,


# Example 2: Growing scatter plot
ax2.set_xlim(-2, 2)
ax2.set_ylim(-2, 2)
ax2.set_xlabel('x')
ax2.set_ylabel('y')
ax2.set_title('Animated Scatter Plot')
ax2.grid(True, alpha=0.3)

# Create scatter plot
scat = ax2.scatter([], [], c=[], s=50, cmap='viridis', vmin=0, vmax=100)


def animate_scatter(frame):
    # Generate random points
    n_points = min(frame + 1, 100)
    x_data = np.random.randn(n_points) * (1 + frame / 100)
    y_data = np.random.randn(n_points) * (1 + frame / 100)
    colors = np.arange(n_points)

    # Update scatter plot
    data = np.c_[x_data, y_data]
    scat.set_offsets(data)
    scat.set_array(colors)
    return scat,


# Combine animations
def animate_both(frame):
    animate_sine(frame)
    animate_scatter(frame)
    return line, scat


# Create animation
anim = FuncAnimation(fig, animate_both, frames=100, interval=50, blit=True)

plt.tight_layout()
plt.show()

# Display animation in notebook
display(HTML(anim.to_jshtml()))

print("\nBasic Animation Techniques:")
print("- FuncAnimation: Updates plot elements frame by frame")
print("- blit=True: Redraws only the artists returned by the update function")
print("- interval: Delay between frames in milliseconds")
print("- Multiple subplots can be animated simultaneously")

## 4. Neural Activity Animation <a id='neural-activity'></a>

Animating spike trains and neural population activity over time.
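
The cell below loops over every neuron on every frame to find the spikes inside the window. One common optimization is to flatten all spike trains once into a time array and a neuron-ID array, then select each window with a single boolean mask. The sketch below makes its own hypothetical spike trains with the same structure. `window_spikes` and `population_rate` are illustrative helpers.

```python
import numpy as np

rng = np.random.default_rng(0)
n_neurons, total_duration = 30, 10.0

# Hypothetical spike trains with the same structure as in the cell below
spike_trains = []
for _ in range(n_neurons):
    n_spikes = rng.poisson(rng.uniform(5, 20) * total_duration)
    spike_trains.append(np.sort(rng.uniform(0, total_duration, n_spikes)))

# Flatten once: all spike times plus the neuron ID of each spike
all_times = np.concatenate(spike_trains)
all_ids = np.concatenate([np.full(len(st), i) for i, st in enumerate(spike_trains)])


def window_spikes(window_start, window_size=2.0):
    """Spike (time, neuron) coordinates inside the window, times relative to its start."""
    mask = (all_times >= window_start) & (all_times < window_start + window_size)
    return all_times[mask] - window_start, all_ids[mask]


def population_rate(spike_x, window_size=2.0, n_bins=99):
    """Mean firing rate per neuron (Hz) in each time bin of the window."""
    edges = np.linspace(0, window_size, n_bins + 1)
    counts, _ = np.histogram(spike_x, bins=edges)
    return counts / (edges[1] - edges[0]) / n_neurons
```

Inside an update function, `spike_x, spike_y = window_spikes(window_start)` followed by `spike_scatter.set_offsets(np.column_stack([spike_x, spike_y]))` would replace the per-neuron loop.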

In [None]:
# Animate spike raster plot with sliding window

fig, axes = plt.subplots(2, 1, figsize=(12, 8))
fig.suptitle('Neural Activity Animation', fontsize=14, fontweight='bold')

# Parameters
window_size = 2.0  # seconds
n_neurons = 30
total_duration = 10.0

# Generate continuous spike trains
all_spike_trains = []
for neuron in range(n_neurons):
    rate = np.random.uniform(5, 20)  # Hz
    n_spikes = np.random.poisson(rate * total_duration)
    spike_times = np.sort(np.random.uniform(0, total_duration, n_spikes))
    all_spike_trains.append(spike_times)

# Setup spike raster axis
ax1 = axes[0]
ax1.set_xlim(0, window_size)
ax1.set_ylim(0, n_neurons)
ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Neuron ID')
ax1.set_title('Sliding Window Spike Raster')

# Create scatter plot for spikes
spike_scatter = ax1.scatter([], [], c='black', s=2, marker='|')

# Setup population rate axis
ax2 = axes[1]
ax2.set_xlim(0, window_size)
ax2.set_ylim(0, 30)
ax2.set_xlabel('Time (s)')
ax2.set_ylabel('Population Rate (Hz)')
ax2.set_title('Population Firing Rate')
ax2.grid(True, alpha=0.3)

# Create line for population rate
time_bins = np.linspace(0, window_size, 100)
rate_line, = ax2.plot([], [], 'b-', linewidth=2)
rate_fill = None


def animate_neural_activity(frame):
    global rate_fill

    # Calculate window position
    window_start = frame * 0.05  # Slide by 50ms per frame
    window_end = window_start + window_size

    # Get spikes in current window
    spike_x = []
    spike_y = []

    for neuron_id, spike_train in enumerate(all_spike_trains):
        window_spikes = spike_train[(spike_train >= window_start) &
                                    (spike_train < window_end)]
        # Shift to window coordinates
        window_spikes_shifted = window_spikes - window_start
        spike_x.extend(window_spikes_shifted)
        spike_y.extend([neuron_id] * len(window_spikes_shifted))

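    # Update spike scatter (clear it when the window contains no spikes)
    if spike_x:
        spike_scatter.set_offsets(np.c_[spike_x, spike_y])
    else:
        spike_scatter.set_offsets(np.empty((0, 2)))

    # Calculate population rate (mean firing rate per neuron, in Hz)
    counts, _ = np.histogram(spike_x, bins=time_bins)
    rate = counts / (time_bins[1] - time_bins[0]) / n_neurons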

    # Smooth the rate
    from scipy.ndimage import gaussian_filter1d
    rate_smooth = gaussian_filter1d(rate, sigma=2)

    # Update rate plot
    rate_line.set_data(time_bins[:-1], rate_smooth)

    # Update fill area (remove old and add new)
    if rate_fill is not None:
        rate_fill.remove()
    rate_fill = ax2.fill_between(time_bins[:-1], 0, rate_smooth, alpha=0.3)

    # Update time indicator
    ax1.set_title(f'Sliding Window Spike Raster (t = {window_start:.1f}s)')

    return spike_scatter, rate_line, rate_fill


# Create animation
n_frames = int((total_duration - window_size) / 0.05)
anim = FuncAnimation(fig, animate_neural_activity, frames=n_frames,
                     interval=50, blit=False)

plt.tight_layout()

# Save as GIF
print("Saving neural activity animation...")
anim.save(output_dir / 'neural_activity.gif', writer='pillow', fps=20)

# Display in notebook
display(HTML(anim.to_jshtml()))

print("\nNeural Activity Animation Features:")
print("- Sliding window through spike trains")
print("- Real-time population rate calculation")
print("- Synchronized multi-panel updates")
print(f"- Animation saved to: {output_dir / 'neural_activity.gif'}")

## 5. 1D Signal Animation <a id='1d-animation'></a>

Animating time series data such as membrane potentials and LFP signals.
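
The cell below slices `signals[:, start_idx:end_idx]` with `start_idx = frame * 5`. As a sketch, NumPy's `sliding_window_view` (available in NumPy 1.20 and later) can expose all of these windows as a read-only view without copying data. Here `signals` is hypothetical random data with the same shape as below.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

n_channels, signal_length = 5, 1000
window, step = 200, 5
signals = np.random.default_rng(0).normal(size=(n_channels, signal_length))

# All windows along the time axis as a view (no copy):
# (n_channels, signal_length - window + 1, window), then every `step`-th window
windows = sliding_window_view(signals, window, axis=1)[:, ::step, :]
n_frames = windows.shape[1]


def frame_window(frame):
    """Samples shown in animation frame `frame`, shape (n_channels, window)."""
    return windows[:, frame, :]
```

This yields 161 frames. The last frame is the window that ends exactly at sample 1000, which the cell below skips because it computes `(signal_length - 200) // 5 = 160`.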

In [None]:
# Animate 1D signals with multiple channels

# Generate multi-channel signals
n_channels = 5
signal_length = 1000
time = np.linspace(0, 10, signal_length)

# Create signals with different frequencies
signals = np.zeros((n_channels, signal_length))
for i in range(n_channels):
    freq = 0.5 + i * 0.5  # Different frequencies
    phase = np.random.uniform(0, 2 * np.pi)
    signals[i] = np.sin(2 * np.pi * freq * time + phase) * (1 + 0.2 * i)
    signals[i] += np.random.normal(0, 0.1, signal_length)

# Setup figure
fig, axes = plt.subplots(3, 1, figsize=(12, 10))
fig.suptitle('1D Signal Animation', fontsize=14, fontweight='bold')

# 1. Oscilloscope-style display
ax1 = axes[0]
ax1.set_xlim(0, 2)
ax1.set_ylim(-3, 3)
ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Amplitude')
ax1.set_title('Oscilloscope View')
ax1.grid(True, alpha=0.3)

lines_osc = []
colors = plt.cm.viridis(np.linspace(0, 1, n_channels))
for i in range(n_channels):
    line, = ax1.plot([], [], color=colors[i], linewidth=1.5,
                     label=f'Ch {i + 1}')
    lines_osc.append(line)
ax1.legend(loc='upper right')

# 2. Waterfall plot
ax2 = axes[1]
ax2.set_xlim(0, 2)
ax2.set_ylim(-1, n_channels * 2)
ax2.set_xlabel('Time (s)')
ax2.set_ylabel('Channel')
ax2.set_title('Waterfall Display')

lines_waterfall = []
for i in range(n_channels):
    line, = ax2.plot([], [], color=colors[i], linewidth=1.5)
    lines_waterfall.append(line)

# 3. Spectrogram-style heatmap
ax3 = axes[2]
ax3.set_xlabel('Time (s)')
ax3.set_ylabel('Channel')
ax3.set_title('Signal Intensity Heatmap')

# Initialize heatmap data
heatmap_data = np.zeros((n_channels, 100))
im = ax3.imshow(heatmap_data, aspect='auto', cmap='hot',
                extent=[0, 2, 0, n_channels], vmin=-2, vmax=2)
plt.colorbar(im, ax=ax3, label='Amplitude')


def animate_1d_signals(frame):
    # Window parameters
    window_size = 200  # samples
    start_idx = frame * 5
    end_idx = start_idx + window_size

    if end_idx > signal_length:
        return lines_osc + lines_waterfall + [im]

    # Time window
    time_window = time[start_idx:end_idx] - time[start_idx]

    # Update oscilloscope
    for i, line in enumerate(lines_osc):
        line.set_data(time_window, signals[i, start_idx:end_idx])

    # Update waterfall (with offset)
    for i, line in enumerate(lines_waterfall):
        offset_signal = signals[i, start_idx:end_idx] + i * 2
        line.set_data(time_window, offset_signal)

    # Update heatmap
    heatmap_slice = signals[:, start_idx:end_idx:2]  # Downsample
    im.set_array(heatmap_slice)

    # Update titles with time
    current_time = time[start_idx]
    ax1.set_title(f'Oscilloscope View (t = {current_time:.1f}s)')

    return lines_osc + lines_waterfall + [im]


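# Create animation (blit=False because the oscilloscope title changes every frame
# and is not among the returned artists, so blitting would leave it stale)
n_frames = (signal_length - 200) // 5
anim = FuncAnimation(fig, animate_1d_signals, frames=n_frames,
                     interval=50, blit=False)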

plt.tight_layout()

# Display in notebook
display(HTML(anim.to_jshtml()))

print("\n1D Signal Animation Techniques:")
print("- Oscilloscope: Real-time signal display")
print("- Waterfall: Multiple channels with vertical offset")
print("- Heatmap: Intensity representation over time")
print("- Synchronized visualization across different representations")

## 6. 2D Spatial Animation <a id='2d-animation'></a>

Animating spatial patterns of neural activity, such as traveling waves and spreading activation.
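
The cell below relies on `braintools.visualize.animator`. For finer control, or for comparison, here is a sketch of the same idea using only matplotlib. It is not braintools' implementation. It assumes a `(n_frames, H, W)` array, reuses one `imshow` artist, and fixes the color limits so colors stay comparable across frames. `animate_movie` is a hypothetical helper.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation


def animate_movie(movie, interval=50, cmap='viridis'):
    """Animate an array of shape (n_frames, H, W) with a single imshow artist."""
    fig, ax = plt.subplots(figsize=(4, 4))
    # Fix the color limits over the whole movie so colors are comparable across frames
    im = ax.imshow(movie[0], cmap=cmap, vmin=movie.min(), vmax=movie.max())
    ax.set_axis_off()

    def update(i):
        im.set_data(movie[i])  # reuse the same image artist every frame
        return im,

    anim = FuncAnimation(fig, update, frames=len(movie), interval=interval, blit=True)
    return fig, im, anim, update


movie = np.random.default_rng(0).random((20, 32, 32))
fig, im, anim, update = animate_movie(movie)
```

Keep a reference to the returned `anim`. If the `FuncAnimation` object is garbage-collected, the animation stops.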

In [None]:
# Use braintools animator for 2D activity

# Create figure for 2D animation
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
fig.suptitle('2D Spatial Activity Animation', fontsize=14, fontweight='bold')

# Generate different types of 2D patterns
n_frames = 100
grid_size = 50

# 1. Traveling wave
wave_data = np.zeros((n_frames, grid_size, grid_size))
for frame in range(n_frames):
    x = np.linspace(-3, 3, grid_size)
    y = np.linspace(-3, 3, grid_size)
    X, Y = np.meshgrid(x, y)

    # Wave position
    wave_pos = -3 + 6 * frame / n_frames
    wave_data[frame] = np.exp(-((X - wave_pos) ** 2 + Y ** 2) / 0.8)

# 2. Rotating spiral
spiral_data = np.zeros((n_frames, grid_size, grid_size))
for frame in range(n_frames):
    x = np.linspace(-2, 2, grid_size)
    y = np.linspace(-2, 2, grid_size)
    X, Y = np.meshgrid(x, y)

    theta = np.arctan2(Y, X)
    r = np.sqrt(X ** 2 + Y ** 2)
    spiral_data[frame] = 0.5 * (1 + np.sin(3 * theta - r + 0.2 * frame))

# 3. Spreading activation
spread_data = np.zeros((n_frames, grid_size, grid_size))
center = grid_size // 2
for frame in range(n_frames):
    x, y = np.ogrid[:grid_size, :grid_size]
    radius = frame * 0.5
    mask = (x - center) ** 2 + (y - center) ** 2 <= radius ** 2
    spread_data[frame][mask] = np.exp(-frame / 20)

# 4. Random hotspots
hotspot_data = np.zeros((n_frames, grid_size, grid_size))
n_hotspots = 5
hotspot_centers = [(np.random.randint(10, 40), np.random.randint(10, 40))
                   for _ in range(n_hotspots)]

for frame in range(n_frames):
    for cx, cy in hotspot_centers:
        x, y = np.ogrid[:grid_size, :grid_size]
        dist = np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
        intensity = np.sin(2 * np.pi * frame / 20 + cx / 10) * 0.5 + 0.5
        hotspot_data[frame] += intensity * np.exp(-dist ** 2 / 50)

# Create animations using braintools animator
ax1 = axes[0, 0]
ax1.set_title('Traveling Wave')
anim1 = braintools.visualize.animator(wave_data, fig, ax1, interval=50, cmap='viridis')

ax2 = axes[0, 1]
ax2.set_title('Rotating Spiral')
anim2 = braintools.visualize.animator(spiral_data, fig, ax2, interval=50, cmap='plasma')

ax3 = axes[1, 0]
ax3.set_title('Spreading Activation')
anim3 = braintools.visualize.animator(spread_data, fig, ax3, interval=50, cmap='hot')

ax4 = axes[1, 1]
ax4.set_title('Dynamic Hotspots')
anim4 = braintools.visualize.animator(hotspot_data, fig, ax4, interval=50, cmap='jet')

plt.tight_layout()

# Display first animation
display(HTML(anim1.to_jshtml()))

print("\n2D Spatial Animation Patterns:")
print("- Traveling waves: Common in cortical dynamics")
print("- Spiral waves: Found in cardiac and neural tissue")
print("- Spreading activation: A simple model of propagating activity, e.g. seizure spread")
print("- Dynamic hotspots: Represent shifting centers of activity")
print("\nUsing braintools.visualize.animator for efficient 2D animation")

## 7. Network Dynamics Movies <a id='network-dynamics'></a>

Creating movies that show evolving network structure and activity.
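
The cell below creates one `Line2D` per edge and updates them in a Python loop. With many edges, a single `matplotlib.collections.LineCollection` is usually cheaper to draw and update. The sketch below rebuilds a thresholded symmetric weight matrix like the one below and encodes edge strength as per-edge alpha and line width. Unlike the cell below, it omits the red/blue co-activation coloring. `update_edges` is illustrative.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection

rng = np.random.default_rng(42)
n_nodes = 20
theta = np.linspace(0, 2 * np.pi, n_nodes, endpoint=False)
node_pos = 2 * np.column_stack([np.cos(theta), np.sin(theta)])

# Symmetric, thresholded weights similar to the cell below
w = rng.random((n_nodes, n_nodes))
w = (w + w.T) / 2
w[w < 0.7] = 0
np.fill_diagonal(w, 0)

# One segment per edge (upper triangle only): shape (n_edges, 2, 2)
i_idx, j_idx = np.nonzero(np.triu(w, k=1))
segments = np.stack([node_pos[i_idx], node_pos[j_idx]], axis=1)

fig, ax = plt.subplots(figsize=(5, 5))
edges = LineCollection(segments, colors='k', linewidths=1)
ax.add_collection(edges)
ax.set_xlim(-3, 3)
ax.set_ylim(-3, 3)


def update_edges(frame):
    strength = w[i_idx, j_idx] * (1 + 0.5 * np.sin(2 * np.pi * frame / 50))
    # Per-edge RGBA colors: black, with alpha proportional to strength
    rgba = np.zeros((len(strength), 4))
    rgba[:, 3] = np.clip(strength, 0, 1)
    edges.set_color(rgba)
    edges.set_linewidth(3 * strength)
    return edges,
```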

In [None]:
# Animate network dynamics with changing connectivity and node activity

# Network parameters
n_nodes = 20
n_frames = 100

# Generate node positions (fixed)
np.random.seed(42)
theta = np.linspace(0, 2 * np.pi, n_nodes, endpoint=False)
node_pos = np.column_stack([np.cos(theta), np.sin(theta)]) * 2

# Add some noise to positions
node_pos += np.random.normal(0, 0.1, node_pos.shape)

# Generate time-varying connectivity
base_connectivity = np.random.rand(n_nodes, n_nodes)
base_connectivity = (base_connectivity + base_connectivity.T) / 2
base_connectivity[base_connectivity < 0.7] = 0
np.fill_diagonal(base_connectivity, 0)

# Generate time-varying node activity
node_activity = np.zeros((n_frames, n_nodes))
for frame in range(n_frames):
    # Oscillating activity
    phase = 2 * np.pi * frame / 30
    for i in range(n_nodes):
        node_activity[frame, i] = 0.5 + 0.5 * np.sin(phase + i * np.pi / 5)

# Setup figure
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
fig.suptitle('Network Dynamics Animation', fontsize=14, fontweight='bold')

# Network visualization axis
ax1 = axes[0]
ax1.set_xlim(-3, 3)
ax1.set_ylim(-3, 3)
ax1.set_aspect('equal')
ax1.set_title('Network Structure')
ax1.axis('off')

# Connectivity matrix axis
ax2 = axes[1]
ax2.set_title('Connectivity Matrix')
im = ax2.imshow(base_connectivity, cmap='RdBu_r', vmin=-1, vmax=1)
plt.colorbar(im, ax=ax2, label='Weight')

# Initialize network plot elements
nodes = ax1.scatter(node_pos[:, 0], node_pos[:, 1],
                    s=200, c=node_activity[0], cmap='hot',
                    vmin=0, vmax=1, edgecolors='black', linewidths=1)

# Create edge lines
edges = []
for i in range(n_nodes):
    for j in range(i + 1, n_nodes):
        if base_connectivity[i, j] > 0:
            line, = ax1.plot([node_pos[i, 0], node_pos[j, 0]],
                             [node_pos[i, 1], node_pos[j, 1]],
                             'k-', alpha=0.3, linewidth=1)
            edges.append((line, i, j))


def animate_network(frame):
    # Update node colors based on activity
    nodes.set_array(node_activity[frame])

    # Update edge properties based on connectivity strength
    connectivity_frame = base_connectivity * (1 + 0.5 * np.sin(2 * np.pi * frame / 50))

    for line, i, j in edges:
        strength = abs(connectivity_frame[i, j])
        line.set_alpha(min(strength, 1.0))
        line.set_linewidth(strength * 3)

        # Color edges red when both endpoints are strongly co-active
        if node_activity[frame, i] * node_activity[frame, j] > 0.5:
            line.set_color('red')
        else:
            line.set_color('blue')

    # Update connectivity matrix
    im.set_array(connectivity_frame)

    # Update title with time
    ax1.set_title(f'Network Structure (t = {frame:.0f})')

    return [nodes, im] + [e[0] for e in edges]


# Create animation
anim = FuncAnimation(fig, animate_network, frames=n_frames,
                     interval=100, blit=False)

plt.tight_layout()

# Save as MP4 if ffmpeg is available, otherwise fall back to GIF
if FFMpegWriter.isAvailable():
    writer = FFMpegWriter(fps=10, bitrate=1800)
    anim.save(output_dir / 'network_dynamics.mp4', writer=writer)
    print(f"Network dynamics movie saved to: {output_dir / 'network_dynamics.mp4'}")
else:
    print("FFmpeg not available, saving as GIF instead")
    anim.save(output_dir / 'network_dynamics.gif', writer='pillow', fps=10)

# Display in notebook
display(HTML(anim.to_jshtml()))

print("\nNetwork Dynamics Features:")
print("- Node activity represented by color intensity")
print("- Edge strength shown by line width and opacity")
print("- Edge color indicates co-activation of the connected nodes")
print("- Synchronized matrix view shows connectivity changes")

## 8. Learning Process Visualization <a id='learning'></a>

Time-lapse visualization of learning dynamics, weight evolution, and performance metrics.
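
In the cell below, the histogram bins are computed from `weights[0]` only and reused for every epoch. `np.histogram` silently drops values outside that range, so bars can undercount as weights drift. Here is a sketch, assuming the whole weight trajectory is available in advance: compute bin edges over all epochs and pre-compute every frame's counts, which also gives a y-limit that fits all frames. The weight data here is a hypothetical random walk with decay.

```python
import numpy as np

rng = np.random.default_rng(0)
n_epochs, n_neurons = 100, 10

# Hypothetical weight trajectory: random walk with decay, as in the cell below
weights = np.empty((n_epochs, n_neurons, n_neurons))
weights[0] = 0.1 * rng.standard_normal((n_neurons, n_neurons))
for epoch in range(1, n_epochs):
    weights[epoch] = 0.99 * (weights[epoch - 1]
                             + 0.01 * rng.standard_normal((n_neurons, n_neurons)))

# Bin edges spanning every epoch, so no weight falls outside the histogram
edges = np.histogram_bin_edges(weights, bins=30)

# Pre-compute counts for all frames (frame caching): shape (n_epochs, 30)
all_counts = np.stack([np.histogram(w, bins=edges)[0] for w in weights])
y_max = all_counts.max()  # one fixed y-limit that fits every frame
```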

In [None]:
# Animate learning process with multiple metrics

# Simulate learning data
n_epochs = 100
n_neurons = 10

# Performance metrics
accuracy = np.zeros(n_epochs)
loss = np.zeros(n_epochs)

# Weight evolution
weights = np.random.randn(n_epochs, n_neurons, n_neurons) * 0.1

# Simulate learning
for epoch in range(n_epochs):
    # Learning curves
    accuracy[epoch] = 1 - np.exp(-epoch / 20) + np.random.normal(0, 0.02)
    loss[epoch] = 2 * np.exp(-epoch / 15) + np.random.normal(0, 0.05)

    # Weight updates (random walk with decay, a stand-in for a plasticity rule)
    if epoch > 0:
        weights[epoch] = weights[epoch - 1] + np.random.randn(n_neurons, n_neurons) * 0.01
        weights[epoch] *= 0.99  # Decay

# Setup figure
fig = plt.figure(figsize=(15, 10))
gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

# Axes
ax1 = fig.add_subplot(gs[0, :])
ax2 = fig.add_subplot(gs[1, 0])
ax3 = fig.add_subplot(gs[1, 1])
ax4 = fig.add_subplot(gs[1, 2])
ax5 = fig.add_subplot(gs[2, :])

fig.suptitle('Learning Process Animation', fontsize=14, fontweight='bold')

# 1. Learning curves
ax1.set_xlim(0, n_epochs)
ax1.set_ylim(0, 1.2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Performance')
ax1.set_title('Learning Curves')
ax1.grid(True, alpha=0.3)

acc_line, = ax1.plot([], [], 'g-', linewidth=2, label='Accuracy')
loss_line, = ax1.plot([], [], 'r-', linewidth=2, label='Loss (scaled)')
ax1.legend(loc='right')

# 2. Weight matrix
im_weights = ax2.imshow(weights[0], cmap='RdBu_r', vmin=-1, vmax=1)
ax2.set_title('Weight Matrix')
ax2.set_xlabel('Post')
ax2.set_ylabel('Pre')

# 3. Weight histogram
ax3.set_xlim(-2, 2)
ax3.set_ylim(0, 30)
ax3.set_xlabel('Weight Value')
ax3.set_ylabel('Count')
ax3.set_title('Weight Distribution')
n_bins = 30
counts, bins = np.histogram(weights[0].flatten(), bins=n_bins)
hist_bars = ax3.bar(bins[:-1], counts, width=bins[1] - bins[0],
                    color='blue', alpha=0.7)

# 4. Network activity
activity = np.random.rand(n_neurons)
activity_bars = ax4.bar(range(n_neurons), activity, color='orange')
ax4.set_xlim(-0.5, n_neurons - 0.5)
ax4.set_ylim(0, 1)
ax4.set_xlabel('Neuron')
ax4.set_ylabel('Activity')
ax4.set_title('Neural Activity')

# 5. Feature importance evolution
feature_importance = np.random.rand(n_epochs, 5)
for i in range(1, n_epochs):
    feature_importance[i] = feature_importance[i - 1] * 0.9 + np.random.rand(5) * 0.1

ax5.set_xlim(0, n_epochs)
ax5.set_ylim(0, 1)
ax5.set_xlabel('Epoch')
ax5.set_ylabel('Importance')
ax5.set_title('Feature Importance Evolution')

feature_lines = []
for i in range(5):
    line, = ax5.plot([], [], linewidth=2, label=f'Feature {i + 1}')
    feature_lines.append(line)
ax5.legend(loc='right')


def animate_learning(frame):
    # Update learning curves
    epochs_so_far = np.arange(frame + 1)
    acc_line.set_data(epochs_so_far, accuracy[:frame + 1])
    loss_line.set_data(epochs_so_far, loss[:frame + 1] / 2)  # Scale for display

    # Update weight matrix
    im_weights.set_array(weights[frame])

    # Update weight histogram
    counts, _ = np.histogram(weights[frame].flatten(), bins=bins)
    for bar, count in zip(hist_bars, counts):
        bar.set_height(count)

    # Update activity (simulate changing patterns)
    new_activity = np.abs(np.sum(weights[frame], axis=0))
    new_activity = new_activity / np.max(new_activity) if np.max(new_activity) > 0 else new_activity
    for bar, act in zip(activity_bars, new_activity):
        bar.set_height(act)

    # Update feature importance
    for i, line in enumerate(feature_lines):
        line.set_data(epochs_so_far, feature_importance[:frame + 1, i])

    # Update main title
    fig.suptitle(f'Learning Process Animation (Epoch {frame})',
                 fontsize=14, fontweight='bold')

    return [acc_line, loss_line, im_weights] + list(hist_bars) + \
        list(activity_bars) + feature_lines


# Create animation
anim = FuncAnimation(fig, animate_learning, frames=n_epochs,
                     interval=100, blit=False)

plt.tight_layout()

# Save animation
anim.save(output_dir / 'learning_process.gif', writer='pillow', fps=10)

# Display in notebook
display(HTML(anim.to_jshtml()))

print("\nLearning Process Visualization:")
print("- Performance metrics tracked over epochs")
print("- Weight matrix evolution shows plasticity")
print("- Weight distribution changes during learning")
print("- Neural activity patterns evolve with training")
print("- Feature importance shows what the network learns")
print(f"- Animation saved to: {output_dir / 'learning_process.gif'}")

## 9. Dynamic Connectivity <a id='connectivity'></a>

Visualizing how connectivity patterns change over time, including synaptic plasticity and network reorganization.
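
The STDP-like rule in the cell below uses a double Python loop over neuron pairs. Here is a vectorized sketch that reproduces the same update, including the `elif` precedence and the untouched diagonal, with boolean outer products. `stdp_like_update` is a hypothetical helper, and like the original it only looks at coincident spikes within one step, not precise spike timing.

```python
import numpy as np


def stdp_like_update(w, pre_spikes, post_spikes, a_plus=0.01, a_minus=0.005):
    """Vectorized form of the nested-loop update in the cell below.

    w[i, j] += a_plus if pre i and post j spiked in this step; otherwise
    w[i, j] -= a_minus if post i and pre j spiked. The diagonal is left unchanged.
    """
    ltp = pre_spikes[:, None] & post_spikes[None, :]
    ltd = post_spikes[:, None] & pre_spikes[None, :] & ~ltp  # mirrors the elif
    np.fill_diagonal(ltp, False)
    np.fill_diagonal(ltd, False)
    return w + a_plus * ltp - a_minus * ltd
```

In the loop below, `connectivity[t] = stdp_like_update(connectivity[t - 1], pre_spikes, post_spikes)` could replace the copy plus the nested loops. The noise, clipping, and symmetrization steps stay as they are.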

In [None]:
# Animate dynamic connectivity with plasticity rules

# Network parameters
n_neurons = 15
n_frames = 150

# Initialize connectivity
connectivity = np.random.randn(n_frames, n_neurons, n_neurons) * 0.1
connectivity[0] = (connectivity[0] + connectivity[0].T) / 2

# Simulate STDP-like plasticity
for t in range(1, n_frames):
    # Copy previous connectivity
    connectivity[t] = connectivity[t - 1].copy()

    # Random spike events in this time step (no precise spike timing)
    pre_spikes = np.random.rand(n_neurons) < 0.1
    post_spikes = np.random.rand(n_neurons) < 0.1

    # STDP update
    for i in range(n_neurons):
        for j in range(n_neurons):
            if i != j:
                if pre_spikes[i] and post_spikes[j]:
                    # Potentiation
                    connectivity[t, i, j] += 0.01
                elif post_spikes[i] and pre_spikes[j]:
                    # Depression
                    connectivity[t, i, j] -= 0.005

    # Add noise
    connectivity[t] += np.random.normal(0, 0.001, (n_neurons, n_neurons))

    # Bounds
    connectivity[t] = np.clip(connectivity[t], -1, 1)

    # Maintain symmetry for display
    connectivity[t] = (connectivity[t] + connectivity[t].T) / 2

# Setup figure
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
fig.suptitle('Dynamic Connectivity Patterns', fontsize=14, fontweight='bold')

# 1. Connectivity matrix
ax1 = axes[0, 0]
im1 = ax1.imshow(connectivity[0], cmap='RdBu_r', vmin=-0.5, vmax=0.5)
ax1.set_title('Connectivity Matrix')
ax1.set_xlabel('Post-synaptic')
ax1.set_ylabel('Pre-synaptic')
plt.colorbar(im1, ax=ax1, label='Weight')

# 2. Connection strength distribution
ax2 = axes[0, 1]
ax2.set_xlim(-0.5, 0.5)
ax2.set_ylim(0, 50)
ax2.set_xlabel('Connection Strength')
ax2.set_ylabel('Count')
ax2.set_title('Weight Distribution')
bins = np.linspace(-0.5, 0.5, 30)
counts, _ = np.histogram(connectivity[0].flatten(), bins=bins)
bars = ax2.bar(bins[:-1], counts, width=bins[1] - bins[0], alpha=0.7)

# 3. Graph visualization
ax3 = axes[1, 0]
ax3.set_xlim(-2, 2)
ax3.set_ylim(-2, 2)
ax3.set_aspect('equal')
ax3.set_title('Network Graph')
ax3.axis('off')

# Node positions (circular layout)
theta = np.linspace(0, 2 * np.pi, n_neurons, endpoint=False)
node_x = 1.5 * np.cos(theta)
node_y = 1.5 * np.sin(theta)

# Draw nodes
nodes = ax3.scatter(node_x, node_y, s=200, c='lightblue',
                    edgecolors='black', linewidths=1, zorder=3)

# Initialize edges
edge_lines = []
for i in range(n_neurons):
    for j in range(i + 1, n_neurons):
        line, = ax3.plot([node_x[i], node_x[j]],
                         [node_y[i], node_y[j]],
                         'k-', alpha=0, linewidth=1, zorder=1)
        edge_lines.append((line, i, j))

# 4. Connectivity metrics
ax4 = axes[1, 1]
ax4.set_xlim(0, n_frames)
ax4.set_ylim(-0.1, 0.5)
ax4.set_xlabel('Time')
ax4.set_ylabel('Metric Value')
ax4.set_title('Network Metrics')
ax4.grid(True, alpha=0.3)

# Calculate metrics
mean_strength = np.mean(np.abs(connectivity), axis=(1, 2))
# Fraction of connections with |w| > 0.1, i.e. connection density (sparsity = 1 - density)
sparsity = np.mean(np.abs(connectivity) > 0.1, axis=(1, 2))

mean_line, = ax4.plot([], [], 'b-', label='Mean |weight|')
sparsity_line, = ax4.plot([], [], 'r-', label='Density (|w| > 0.1)')
ax4.legend()


def animate_connectivity(frame):
    # Update connectivity matrix
    im1.set_array(connectivity[frame])

    # Update histogram
    counts, _ = np.histogram(connectivity[frame].flatten(), bins=bins)
    for bar, count in zip(bars, counts):
        bar.set_height(count)

    # Update graph edges
    for line, i, j in edge_lines:
        weight = connectivity[frame, i, j]
        if abs(weight) > 0.1:  # Threshold for visibility
            line.set_alpha(min(abs(weight) * 2, 1.0))
            line.set_linewidth(abs(weight) * 5)
            if weight > 0:
                line.set_color('red')
            else:
                line.set_color('blue')
        else:
            line.set_alpha(0)

    # Update metrics
    frames_so_far = np.arange(frame + 1)
    mean_line.set_data(frames_so_far, mean_strength[:frame + 1])
    sparsity_line.set_data(frames_so_far, sparsity[:frame + 1])

    # Update title
    fig.suptitle(f'Dynamic Connectivity Patterns (t = {frame})',
                 fontsize=14, fontweight='bold')

    return [im1] + list(bars) + [e[0] for e in edge_lines] + \
        [mean_line, sparsity_line]


# Create animation
anim = FuncAnimation(fig, animate_connectivity, frames=n_frames,
                     interval=50, blit=False)

plt.tight_layout()

# Display in notebook
display(HTML(anim.to_jshtml()))

print("\nDynamic Connectivity Features:")
print("- STDP-like plasticity rules applied")
print("- Weight distribution evolves over time")
print("- Graph visualization shows strong connections")
print("- Network metrics track global changes")
print("- Red edges: excitatory, Blue edges: inhibitory")

## 10. Export and Optimization <a id='export'></a>

Best practices for exporting animations and optimizing performance for large datasets.
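
For long recordings, `FuncAnimation` can pull frames from a generator function instead of a precomputed array. Combined with `cache_frame_data=False`, this avoids keeping every frame's data in memory. Here is a sketch with a stand-in sine signal. `stream_frames` and `N_FRAMES` are illustrative. Passing `save_count` makes the number of frames written by `save()` or `to_jshtml()` explicit, since a generator has no length and some older matplotlib releases fell back to 100 frames.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation

N_FRAMES = 200


def stream_frames():
    """Yield one frame of data at a time (stand-in for reading from disk)."""
    x = np.linspace(0, 10, 1000)
    for i in range(N_FRAMES):
        yield x, np.sin(x + 0.1 * i)


fig, ax = plt.subplots(figsize=(6, 3))
ax.set_xlim(0, 10)
ax.set_ylim(-1.1, 1.1)
line, = ax.plot([], [], 'b-')


def update(frame_data):
    x, y = frame_data  # FuncAnimation passes each yielded item here
    line.set_data(x, y)
    return line,


anim = FuncAnimation(fig, update, frames=stream_frames, save_count=N_FRAMES,
                     cache_frame_data=False, interval=50, blit=True)
```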

In [None]:
import time

# Demonstration of export options and performance optimization

print("Animation Export Options and Optimization")
print("=" * 50)

# Create a simple test animation
fig, ax = plt.subplots(figsize=(6, 4))
ax.set_xlim(0, 10)
ax.set_ylim(-1, 1)
line, = ax.plot([], [], 'b-')


def simple_animate(frame):
    x = np.linspace(0, 10, 100)
    y = np.sin(x + 0.1 * frame)
    line.set_data(x, y)
    return line,


anim = FuncAnimation(fig, simple_animate, frames=50, interval=50, blit=True)

# 1. Export formats
print("\n1. Export Formats:")
print("-" * 30)

# GIF export
try:
    writer = PillowWriter(fps=20)
    anim.save(output_dir / 'test.gif', writer=writer)
    print("✓ GIF: Saved using PillowWriter")
    print("  - Good for: Small animations, web display")
    print("  - File size: Often larger than MP4 (256-color palette, LZW compression)")
    print("  - Quality: Good for simple graphics")
except Exception as e:
    print(f"✗ GIF export failed: {e}")

# MP4 export (requires the ffmpeg executable)
if FFMpegWriter.isAvailable():
    writer = FFMpegWriter(fps=30, bitrate=1800)
    anim.save(output_dir / 'test.mp4', writer=writer)
    print("\n✓ MP4: Saved using FFMpegWriter")
    print("  - Good for: Large animations, presentations")
    print("  - File size: Small (compressed)")
    print("  - Quality: Excellent")
else:
    print("\n✗ MP4: FFmpeg not installed")
    print("  Install with: conda install ffmpeg")

# HTML export: frames embedded as base64 images with JavaScript playback controls
html_str = anim.to_jshtml()
(output_dir / 'test.html').write_text(html_str, encoding='utf-8')
print("\n✓ HTML: Saved with JavaScript controls")
print("  - Good for: Interactive viewing, notebooks")
print("  - File size: Large (embedded data)")
print("  - Quality: Excellent")

plt.close(fig)

# 2. Performance optimization techniques
print("\n2. Performance Optimization:")
print("-" * 30)

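# Technique 1: Blitting
print("\na) Blitting:")
print("   - Use blit=True in FuncAnimation")
print("   - Return every updated artist from the update function")
print("   - Only the returned artists are redrawn over a cached background")
print("   - Speed-up depends on the backend and how much of the figure changes")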

# Technique 2: Data decimation
print("\nb) Data Decimation:")
print("   Example: Downsample large datasets")

# Generate large dataset
large_data = np.random.randn(1000, 1000)
print(f"   Original size: {large_data.shape}")

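# Decimate: keep every 5th sample along both axes.
# Slicing returns a view of large_data; .copy() lets the full array be freed later.
decimated_data = large_data[::5, ::5].copy()
print(f"   Decimated size: {decimated_data.shape}")
print(f"   Points to draw reduced by: {(1 - decimated_data.size / large_data.size) * 100:.1f}%")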

# Technique 3: Frame caching
print("\nc) Frame Caching:")
print("   - Pre-compute expensive operations")
print("   - Store results in memory")
print("   - Trade memory for speed")

# Example: Pre-compute frames
n_frames = 100
frame_cache = []
print("   Pre-computing frames...", end="")
start_time = time.perf_counter()

for i in range(n_frames):
    # Stand-in for an expensive per-frame computation
    frame_data = np.sin(np.linspace(0, 10, 1000) + i * 0.1)
    frame_cache.append(frame_data)

cache_time = time.perf_counter() - start_time
print(f" Done in {cache_time:.2f}s")

# Technique 4: Reduce artists
print("\nd) Reduce Number of Artists:")
print("   - Combine multiple lines into LineCollection")
print("   - Use single scatter plot instead of multiple")
print("   - Batch updates when possible")

# 3. Memory management
print("\n3. Memory Management:")
print("-" * 30)
print("- Close figures you are done with: plt.close(fig)")
print("- Skip the frame-data cache: FuncAnimation(..., cache_frame_data=False)")
print("- Use generators for data streaming (frames=<generator function>)")
print("- Delete large arrays after use (del arr)")

# 4. Format comparison
print("\n4. Format Comparison Table:")
print("-" * 30)
print("Format | Quality | Size  | Save speed | Use case")
print("-------|---------|-------|------------|----------------")
print("GIF    | Medium  | Large | Fast       | Web, email")
print("MP4    | High    | Small | Medium     | Presentations")
print("AVI    | High    | Large | Slow       | Editing")
print("HTML   | High    | Large | Fast       | Interactive")
print("Frames | Highest | Huge  | Slow       | Post-processing")

print("\n✓ Export examples saved to:", output_dir.absolute())
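
Caching only pays off if playback reads from the cache instead of recomputing. A minimal sketch of one way to do that, assuming the whole cache fits in memory: the frames are pre-computed into a NumPy array (rows are frames), and the update function, the hypothetical `play_cached`, does nothing but a row lookup on a single `Line2D`.

```python
import matplotlib
matplotlib.use("Agg")  # standalone-script backend; omit in the notebook
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation

n_frames = 100
x = np.linspace(0, 10, 1000)
# Pre-compute all frames once; rows are frames (a stand-in for costly work)
frame_cache = np.array([np.sin(x + i * 0.1) for i in range(n_frames)])

fig, ax = plt.subplots(figsize=(6, 4))
ax.set_xlim(0, 10)
ax.set_ylim(-1.1, 1.1)
line, = ax.plot(x, frame_cache[0], "b-")


def play_cached(frame):
    # Per-frame work is just a row lookup; nothing is recomputed
    line.set_ydata(frame_cache[frame])
    return line,


anim = FuncAnimation(fig, play_cached, frames=n_frames, interval=50, blit=True)
```

`ArtistAnimation` is another option for pre-computed frames, but it keeps one artist per frame alive, which adds per-artist overhead for long animations.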

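Combining many lines into a `LineCollection`, as suggested under "Reduce Number of Artists", can look like the sketch below. It assumes 50 synthetic sinusoidal traces as stand-ins for per-neuron signals; `traces_at` is a hypothetical helper. All traces live in one collection, and each frame swaps in new vertices with `set_segments`, so Matplotlib updates one artist instead of 50 `Line2D` objects.

```python
import matplotlib
matplotlib.use("Agg")  # standalone-script backend; omit in the notebook
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation
from matplotlib.collections import LineCollection

n_neurons, n_samples = 50, 500
t = np.linspace(0, 1, n_samples)
phases = np.random.default_rng(0).uniform(0, 2 * np.pi, n_neurons)


def traces_at(frame):
    """Hypothetical helper: (n_neurons, n_samples, 2) array of (t, y) vertices."""
    y = np.sin(2 * np.pi * 3 * t + phases[:, None] + 0.2 * frame)
    y = 0.4 * y + np.arange(n_neurons)[:, None]  # offset each trace vertically
    return np.stack([np.broadcast_to(t, y.shape), y], axis=-1)


fig, ax = plt.subplots(figsize=(6, 6))
traces = LineCollection(traces_at(0), colors="k", linewidths=0.8)
ax.add_collection(traces)  # one artist holds all 50 traces
ax.set_xlim(0, 1)
ax.set_ylim(-1, n_neurons)


def update(frame):
    traces.set_segments(traces_at(frame))  # swap vertices, keep the artist
    return traces,


anim = FuncAnimation(fig, update, frames=60, interval=50, blit=True)
```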
## 11. Advanced Techniques <a id='advanced'></a>

Advanced animation techniques including interactive controls and custom animations.
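A custom animation does not need a fixed frame list: `FuncAnimation` also accepts a generator function as `frames`, which suits streamed or on-the-fly data. In this minimal sketch, `stream_chunks` is a hypothetical data source standing in for a recording or simulation. `cache_frame_data=False` stops Matplotlib from keeping every yielded chunk, and `save_count` tells `save()` how many frames to take, because a generator has no length.

```python
import matplotlib
matplotlib.use("Agg")  # standalone-script backend; omit in the notebook
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation

CHUNK_LEN = 200
N_CHUNKS = 30


def stream_chunks():
    """Hypothetical data source: yields (index, chunk) pairs one at a time."""
    rng = np.random.default_rng(0)
    base = np.linspace(0, 4 * np.pi, CHUNK_LEN)
    for k in range(N_CHUNKS):
        noise = 0.1 * rng.standard_normal(CHUNK_LEN)
        yield k, np.sin(base + 0.3 * k) + noise


fig, ax = plt.subplots(figsize=(6, 3))
ax.set_xlim(0, CHUNK_LEN)
ax.set_ylim(-1.5, 1.5)
line, = ax.plot(np.arange(CHUNK_LEN), np.zeros(CHUNK_LEN))


def draw_chunk(item):
    # Each frame receives whatever the generator yielded
    k, chunk = item
    line.set_ydata(chunk)
    return line,


anim = FuncAnimation(fig, draw_chunk, frames=stream_chunks,
                     cache_frame_data=False, save_count=N_CHUNKS,
                     interval=50, blit=True)
```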

In [None]:
# Advanced animation with custom controls

from matplotlib.widgets import Slider, Button

# Create figure with controls
fig = plt.figure(figsize=(12, 8))

# Main plot
ax_main = plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=3)

# Control axes
ax_speed = plt.subplot2grid((4, 3), (3, 0), colspan=2)
ax_button = plt.subplot2grid((4, 3), (3, 2))

# Generate complex data
t = np.linspace(0, 10, 1000)
frequencies = [0.5, 1.0, 2.0]
amplitudes = [1.0, 0.5, 0.3]

# Setup main plot
ax_main.set_xlim(0, 10)
ax_main.set_ylim(-2, 2)
ax_main.set_xlabel('Time (s)')
ax_main.set_ylabel('Signal')
ax_main.set_title('Interactive Animation Control')
ax_main.grid(True, alpha=0.3)

# Create lines
lines = []
for freq, amp in zip(frequencies, amplitudes):
    line, = ax_main.plot([], [], linewidth=2,
                         label=f'f={freq} Hz')
    lines.append(line)
ax_main.legend()

# Speed slider
speed_slider = Slider(ax_speed, 'Speed', 0.1, 5.0,
                      valinit=1.0, valstep=0.1)

# Play/Pause button
button = Button(ax_button, 'Pause')
is_paused = False

# Animation state
phase = 0
speed_multiplier = 1.0


def update_speed(val):
    global speed_multiplier
    speed_multiplier = speed_slider.val


def toggle_pause(event):
    global is_paused
    is_paused = not is_paused
    button.label.set_text('Play' if is_paused else 'Pause')


speed_slider.on_changed(update_speed)
button.on_clicked(toggle_pause)


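# Phase readout placed inside the axes: with blit=True only the axes area is
# refreshed, so a changing title (drawn above the axes) would not update
# between full redraws.
phase_text = ax_main.text(0.02, 0.95, '', transform=ax_main.transAxes, va='top')


def animate_advanced(frame):
    global phase

    if not is_paused:
        phase += 0.1 * speed_multiplier

    for i, (line, freq, amp) in enumerate(zip(lines, frequencies, amplitudes)):
        y = amp * np.sin(2 * np.pi * freq * t + phase + i * np.pi / 3)
        line.set_data(t, y)

    phase_text.set_text(f'Phase: {phase:.1f} rad')

    # Blitting redraws only the artists returned here
    return [*lines, phase_text]


# frames=None runs indefinitely, so disable frame-data caching to bound memory
anim = FuncAnimation(fig, animate_advanced, interval=50, blit=True,
                     cache_frame_data=False)

plt.show()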

print("Advanced Animation Features:")
print("- Interactive speed control with slider")
print("- Play/Pause functionality")
print("- Real-time parameter updates")
print("- Custom widget integration")
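print("\nNote: the slider, button, and live playback need an interactive backend")
print("(a standalone window, or %matplotlib widget with ipympl). With the")
print("inline backend selected in Setup, this cell shows a static figure only.")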

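The cell above keeps the timer running and checks an `is_paused` flag on every frame. Since Matplotlib 3.4, animation objects also provide `pause()` and `resume()`, which stop and restart the underlying timer instead. A minimal sketch, assuming Matplotlib 3.4 or later; the `state` dict and `on_click` callback are illustrative names, not part of BrainTools.

```python
import matplotlib
matplotlib.use("Agg")  # standalone-script backend; omit in the notebook
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation
from matplotlib.widgets import Button

fig, ax = plt.subplots(figsize=(6, 4))
fig.subplots_adjust(bottom=0.2)  # leave room for the button
t = np.linspace(0, 10, 500)
line, = ax.plot(t, np.sin(t))

ax_button = fig.add_axes([0.8, 0.03, 0.15, 0.08])
button = Button(ax_button, "Pause")
state = {"paused": False}  # illustrative UI state


def update(frame):
    line.set_ydata(np.sin(t + 0.1 * frame))
    return line,


anim = FuncAnimation(fig, update, interval=50, blit=True,
                     cache_frame_data=False)


def on_click(event):
    # pause()/resume() stop and restart the animation's timer (Matplotlib >= 3.4)
    if state["paused"]:
        anim.resume()
    else:
        anim.pause()
    state["paused"] = not state["paused"]
    button.label.set_text("Play" if state["paused"] else "Pause")
    fig.canvas.draw_idle()


button.on_clicked(on_click)
```

Stopping the timer means no update calls run while paused, whereas the flag-based version still calls the update function on every tick.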
## Summary and Best Practices

This tutorial covered animation techniques for neural data visualization:

1. **Basic Animation**
   - FuncAnimation for frame-by-frame updates
   - ArtistAnimation for pre-computed frames
   - Blitting for performance optimization

2. **Neural-Specific Animations**
   - Spike raster sliding windows
   - Population activity dynamics
   - Membrane potential oscillations

3. **Spatial Animations**
   - 2D activity patterns
   - Traveling waves and spirals
   - Spreading activation

4. **Network Dynamics**
   - Evolving connectivity
   - Node activity changes
   - Graph visualization

5. **Learning Visualization**
   - Performance metrics over time
   - Weight evolution
   - Feature importance dynamics


Export Options:

| Format | Writer | Best For | Pros | Cons |
|--------|--------|----------|------|------|
| GIF | `PillowWriter` | Web, email | Universal support | Large files, 256-color palette |
| MP4 | `FFMpegWriter` | Presentations | Small, high quality | Requires FFmpeg |
| HTML | `to_jshtml()` | Notebooks | Interactive playback controls | Large, needs a browser |
| Frames | `fig.savefig` per frame | Post-processing | Full control | Very large |
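
Individual frames can be exported without any movie writer: update the artists, then call `fig.savefig` once per frame. A minimal sketch; the `animations/frames` directory and the `frame_0000.png` naming pattern are illustrative choices.

```python
import matplotlib
matplotlib.use("Agg")  # standalone-script backend; omit in the notebook
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

frames_dir = Path("animations") / "frames"
frames_dir.mkdir(parents=True, exist_ok=True)

fig, ax = plt.subplots(figsize=(6, 4))
x = np.linspace(0, 10, 100)
line, = ax.plot(x, np.sin(x), "b-")
ax.set_ylim(-1.1, 1.1)

n_frames = 20
for frame in range(n_frames):
    line.set_ydata(np.sin(x + 0.1 * frame))
    # Zero-padded names keep the files in frame order for downstream tools
    fig.savefig(frames_dir / f"frame_{frame:04d}.png", dpi=100)
plt.close(fig)
```

Because the names sort in frame order, a tool such as FFmpeg can reassemble them later, for example with `ffmpeg -framerate 20 -i frame_%04d.png out.mp4` run inside the frames directory.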