# Tutorial 4: Input and Output

In this tutorial, you'll learn:

- Generating input patterns (Poisson, periodic, custom)
- Input encoding strategies
- Using readout layers
- Population coding and decoding
- Recording and analyzing network outputs

In [None]:
import brainpy as bp
import brainstate
import brainunit as u
import braintools
import matplotlib.pyplot as plt
import jax.numpy as jnp
import numpy as np

## Part 1: Understanding Inputs and Outputs

Neural networks need:

**Inputs** → Convert external signals to neural activity
- Current injection
- Spike trains (Poisson, regular)
- Temporal patterns

**Outputs** → Extract information from network
- Spike counts
- Population vectors
- Readout layers

## Part 2: Constant Current Input

The simplest input: constant current.

In [None]:
brainstate.environ.set(dt=0.1 * u.ms)

# Create a small population of identical LIF neurons
n_neurons = 10  # reused below when converting the spike count into a rate
neuron = bp.LIF(n_neurons, V_rest=-65.*u.mV, V_th=-50.*u.mV, tau=10.*u.ms)
brainstate.nn.init_all_states(neuron)

# Simulate with constant input
duration = 200. * u.ms
times = u.math.arange(0.*u.ms, duration, brainstate.environ.get_dt())

I_constant = 2.0 * u.nA
# for_loop runs one neuron step per entry of `times`; the per-step spikes come
# back stacked so that axis 0 indexes time and axis 1 indexes neurons.
spikes = brainstate.transform.for_loop(
    lambda t: neuron(I_constant),
    times
)

# Plot
t_idx, n_idx = u.math.where(spikes != 0)
plt.figure(figsize=(10, 4))
plt.scatter(times[t_idx].to_decimal(u.ms), n_idx, s=5, c='black')
plt.xlabel('Time (ms)')
plt.ylabel('Neuron Index')
plt.title('Response to Constant Current Input')
plt.grid(True, alpha=0.3)
plt.show()

print(f"Total spikes: {len(t_idx)}")
# Mean rate per neuron = total spikes / (number of neurons * duration in seconds)
print(f"Average rate: {len(t_idx) / (n_neurons * duration.to_decimal(u.second)):.2f} Hz")

## Part 3: Poisson Spike Trains

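A more realistic input: random Poisson spike trains. In each time step of length $\Delta t$, every input channel spikes independently with probability $p = r\,\Delta t$. This approximates a Poisson process of rate $r$ as long as $r\,\Delta t \ll 1$.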

In [None]:
def poisson_input(size, rate, dt):
    """Generate one time step of independent Poisson (Bernoulli) spikes.

    Each of the ``size`` channels spikes with probability ``rate * dt``. This
    approximates a Poisson process as long as ``rate * dt`` is much smaller
    than 1.

    Args:
        size: Number of input channels (neurons).
        rate: Firing rate. Either a ``brainunit`` quantity with frequency
            units (e.g. ``50 * u.Hz``) or a plain number interpreted as Hz.
        dt: Time step, a ``brainunit`` quantity with time units.

    Returns:
        Float array of shape ``(size,)``: 1.0 where a spike occurred, else 0.0.
    """
    # Strip units before comparing with the unitless uniform samples:
    # comparing a dimensionless array to a quantity in Hz is a unit mismatch.
    # This matches how rate_encode / population_encode convert their rates.
    rate_hz = rate.to_decimal(u.Hz) if isinstance(rate, u.Quantity) else rate
    prob = rate_hz * dt.to_decimal(u.second)
    return (brainstate.random.rand(size) < prob).astype(float)

# Drive the 10 LIF neurons with independent 50 Hz Poisson spike trains
brainstate.nn.init_all_states(neuron)
rate = 50 * u.Hz
dt = brainstate.environ.get_dt()

# One (input, output) pair per time step, unzipped into arrays afterwards
step_records = []
for _ in times:
    step_in = poisson_input(10, rate, dt)
    # Simple synapse model: each input spike is a one-step 5 nA current pulse
    neuron(step_in * 5.0 * u.nA)
    step_records.append((step_in, neuron.get_spike()))

input_spikes_hist = jnp.array([rec[0] for rec in step_records])
output_spikes_hist = u.math.asarray([rec[1] for rec in step_records])

In [None]:
# Raster plots: Poisson input on top, LIF response below
fig, (ax_in, ax_out) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

# Top panel: input spikes
t_in, n_in = jnp.where(input_spikes_hist > 0)
ax_in.scatter(times[t_in].to_decimal(u.ms), n_in, s=2, c='blue', alpha=0.5)
ax_in.set_ylabel('Neuron Index')
ax_in.set_title(f'Input: Poisson Spike Train ({rate.to_decimal(u.Hz):.0f} Hz)', fontweight='bold')

# Bottom panel: output spikes
t_out, n_out = u.math.where(output_spikes_hist != 0)
ax_out.scatter(times[t_out].to_decimal(u.ms), n_out, s=2, c='red', alpha=0.5)
ax_out.set_xlabel('Time (ms)')
ax_out.set_ylabel('Neuron Index')
ax_out.set_title('Output: Neuron Response', fontweight='bold')

for ax in (ax_in, ax_out):
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

n_in_spikes, n_out_spikes = len(t_in), len(t_out)
print(f"Input spikes: {n_in_spikes}")
print(f"Output spikes: {n_out_spikes}")
print(f"Gain: {n_out_spikes / n_in_spikes:.2f}x")

## Part 4: Periodic Input Patterns

Regular, rhythmic inputs: a sinusoidally modulated current that oscillates between 0 and its peak amplitude.

In [None]:
def periodic_input(t, frequency, amplitude, phase=0):
    """Return a raised-sinusoid input current evaluated at time ``t``.

    The waveform is ``amplitude * (0.5 + 0.5 * sin(2*pi*f*t + phase))``. It
    oscillates between 0 and ``amplitude`` and never goes negative.

    Args:
        t: Current time (a ``brainunit`` time quantity).
        frequency: Oscillation frequency (a ``brainunit`` frequency quantity).
        amplitude: Peak current (e.g. ``3.0 * u.nA``).
        phase: Phase offset in radians, added inside the sine.

    Returns:
        The current at time ``t``, in the units of ``amplitude``.
    """
    t_sec = t.to_decimal(u.second)
    angle = 2 * jnp.pi * frequency.to_decimal(u.Hz) * t_sec + phase
    return amplitude * (0.5 + 0.5 * jnp.sin(angle))

# Drive the neurons with a 10 Hz raised-sinusoid current of 3 nA peak
brainstate.nn.init_all_states(neuron)
freq = 10 * u.Hz
amp = 3.0 * u.nA

# The input does not depend on the network, so precompute the whole waveform
currents_hist = [periodic_input(t_now, freq, amp) for t_now in times]

spikes_hist = []
for I_step in currents_hist:
    neuron(I_step)
    spikes_hist.append(neuron.get_spike())

currents_hist = u.math.asarray(currents_hist)
spikes_hist = u.math.asarray(spikes_hist)

In [None]:
# Periodic current (top) and the resulting spike raster (bottom)
fig, (ax_cur, ax_spk) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

ax_cur.plot(times.to_decimal(u.ms), currents_hist.to_decimal(u.nA),
            linewidth=2, color='blue')
ax_cur.set_ylabel('Current (nA)')
ax_cur.set_title(f'Input: Periodic Current ({freq})', fontweight='bold')

t_idx, n_idx = u.math.where(spikes_hist != 0)
ax_spk.scatter(times[t_idx].to_decimal(u.ms), n_idx, s=5, c='red', alpha=0.7)
ax_spk.set_xlabel('Time (ms)')
ax_spk.set_ylabel('Neuron Index')
ax_spk.set_title('Output: Phase-Locked Spiking', fontweight='bold')

for ax in (ax_cur, ax_spk):
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Observation: Neurons fire preferentially during high-current phases")

## Part 5: Rate Coding

Encode information in firing rates.

In [None]:
def rate_encode(values, max_rate, dt):
    """Rate-code a vector of values as one step of Bernoulli spikes.

    Value ``v`` maps to the firing rate ``v * max_rate``. The corresponding
    neuron then spikes this step with probability ``v * max_rate * dt``.

    Args:
        values: 1-D array of values to encode, expected in [0, 1].
        max_rate: Rate for a value of 1 (``brainunit`` frequency quantity).
        dt: Time step (``brainunit`` time quantity).

    Returns:
        Float array of length ``len(values)``: 1.0 for a spike, else 0.0.
    """
    spike_prob = values * max_rate.to_decimal(u.Hz) * dt.to_decimal(u.second)
    uniform = brainstate.random.rand(len(values))
    return (uniform < spike_prob).astype(float)

# Rate-code a 5 Hz sine wave (mapped into [0, 1]) with 10 identical neurons
n_neurons = 10
max_rate = 100 * u.Hz
duration = 500. * u.ms
times = u.math.arange(0.*u.ms, duration, brainstate.environ.get_dt())

signal_values = []
encoded_spikes = []
for t_now in times:
    level = 0.5 + 0.5 * jnp.sin(2 * jnp.pi * 5 * t_now.to_decimal(u.second))
    signal_values.append(level)
    # Every neuron encodes the same value, so they share one firing probability
    encoded_spikes.append(
        rate_encode(jnp.ones(n_neurons) * level, max_rate, brainstate.environ.get_dt())
    )

encoded_spikes = jnp.array(encoded_spikes)
signal_values = jnp.array(signal_values)

In [None]:
# Original signal (top) vs. its rate-coded spike raster (bottom)
fig, (ax_sig, ax_spk) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

ax_sig.plot(times.to_decimal(u.ms), signal_values, linewidth=2, color='blue')
ax_sig.set_ylabel('Signal Value')
ax_sig.set_title('Original Signal (to be encoded)', fontweight='bold')

t_idx, n_idx = jnp.where(encoded_spikes > 0)
ax_spk.scatter(times[t_idx].to_decimal(u.ms), n_idx, s=1, c='red', alpha=0.5)
ax_spk.set_xlabel('Time (ms)')
ax_spk.set_ylabel('Neuron Index')
ax_spk.set_title('Rate-Coded Spike Train', fontweight='bold')

for ax in (ax_sig, ax_spk):
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Higher signal → higher spike density")

## Part 6: Population Coding

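Multiple neurons jointly encode a single value. Each neuron has a Gaussian tuning curve centered on its preferred value, so the neurons whose preferences lie closest to the current value fire the most.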

In [None]:
def population_encode(value, n_neurons, pref_values, sigma, max_rate, dt):
    """Encode a scalar with a population of Gaussian-tuned neurons.

    Neuron ``i`` fires at ``max_rate * exp(-0.5 * ((value - pref_i) / sigma)**2)``.
    For this step it spikes with probability ``rate_i * dt``.

    Args:
        value: Scalar to encode, expected in [0, 1].
        n_neurons: Number of neurons (should equal ``len(pref_values)``).
        pref_values: Preferred value of each neuron.
        sigma: Width of the Gaussian tuning curves.
        max_rate: Rate at the preferred value (``brainunit`` frequency).
        dt: Time step (``brainunit`` time quantity).

    Returns:
        Float array of length ``n_neurons``: 1.0 for a spike, else 0.0.
    """
    z = (value - pref_values) / sigma
    tuning = jnp.exp(-0.5 * z**2)
    p_spike = tuning * max_rate.to_decimal(u.Hz) * dt.to_decimal(u.second)
    return (brainstate.random.rand(n_neurons) < p_spike).astype(float)

# Population of 20 neurons whose preferred values tile [0, 1] evenly
n_pop = 20
pref_values = jnp.linspace(0, 1, n_pop)
sigma = 0.2
max_rate = 100 * u.Hz

# Encode a slowly varying value: 0.5 + 0.3 * sin(2*pi*2Hz*t)
duration = 500. * u.ms
times = u.math.arange(0.*u.ms, duration, brainstate.environ.get_dt())

true_values = []
pop_spikes = []
for t_now in times:
    x = 0.5 + 0.3 * jnp.sin(2 * jnp.pi * 2 * t_now.to_decimal(u.second))
    true_values.append(x)
    pop_spikes.append(
        population_encode(x, n_pop, pref_values, sigma, max_rate,
                          brainstate.environ.get_dt())
    )

pop_spikes = jnp.array(pop_spikes)
true_values = jnp.array(true_values)

In [None]:
# Encoded value (top) and the population raster that tracks it (bottom)
fig, (ax_val, ax_pop) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

ax_val.plot(times.to_decimal(u.ms), true_values, linewidth=2, color='blue')
ax_val.set_ylabel('Encoded Value')
ax_val.set_title('True Value (to be encoded)', fontweight='bold')

t_idx, n_idx = jnp.where(pop_spikes > 0)
ax_pop.scatter(times[t_idx].to_decimal(u.ms), n_idx, s=2, c='red', alpha=0.5)
ax_pop.set_xlabel('Time (ms)')
ax_pop.set_ylabel('Neuron Index (Preference)')
ax_pop.set_title('Population Code: Activity Follows Value', fontweight='bold')

for ax in (ax_val, ax_pop):
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Peak activity shifts with encoded value")

## Part 7: Population Decoding

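Extract the encoded value from population activity. Within each time window we take the center of mass of the spike counts, $\hat{x} = \sum_i n_i x_i \big/ \sum_i n_i$. Here $n_i$ is the spike count and $x_i$ the preferred value of neuron $i$.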

In [None]:
def population_decode(spike_counts, pref_values, default=0.5):
    """Decode a scalar from population activity (center-of-mass readout).

    Computes the spike-count-weighted average of the neurons' preferred
    values: ``sum(counts * prefs) / sum(counts)``.

    Args:
        spike_counts: Number of spikes of each neuron in the decoding window.
        pref_values: Preferred value of each neuron (same length as
            ``spike_counts``).
        default: Value returned when the window contains no spikes at all.

    Returns:
        Scalar JAX array holding the decoded value, or ``default`` when the
        total activity is zero.
    """
    spike_counts = jnp.asarray(spike_counts)
    total_activity = jnp.sum(spike_counts)
    weighted_sum = jnp.sum(spike_counts * pref_values)
    has_spikes = total_activity > 0
    # `jnp.where` instead of a Python `if` keeps this usable under jax.jit and
    # gives a consistent return type. The safe denominator keeps the unselected
    # branch from producing a division-by-zero NaN.
    safe_total = jnp.where(has_spikes, total_activity, 1)
    return jnp.where(has_spikes, weighted_sum / safe_total, default)

# Decode the population activity in sliding windows with 50% overlap
window_size = 50  # ms
# Round instead of truncating: floating-point error in window / dt (e.g. 50 / 0.1
# evaluated in float32) could otherwise lose one step.
window_steps = int(round(window_size / float(brainstate.environ.get_dt().to_decimal(u.ms))))
window_stride = max(1, window_steps // 2)  # never let the stride collapse to 0

decoded_values = []
decode_times = []

# The "+ 1" keeps the last window that still fits entirely inside the recording.
for i in range(0, len(times) - window_steps + 1, window_stride):
    # Count spikes per neuron in the window
    window_spikes = pop_spikes[i:i + window_steps]
    spike_counts = jnp.sum(window_spikes, axis=0)

    # Decode, and timestamp the estimate at the window centre
    decoded_values.append(population_decode(spike_counts, pref_values))
    decode_times.append(times[i + window_steps // 2])

decoded_values = jnp.array(decoded_values)
decode_times = u.math.asarray(decode_times)

In [None]:
# Overlay the decoded estimates on the true signal and report the error
t_ms = times.to_decimal(u.ms)
decode_ms = decode_times.to_decimal(u.ms)

plt.figure(figsize=(12, 5))
plt.plot(t_ms, true_values, linewidth=2,
         label='True Value', color='blue', alpha=0.7)
plt.plot(decode_ms, decoded_values, linewidth=2,
         label='Decoded Value', color='red', linestyle='--', alpha=0.7)
plt.xlabel('Time (ms)', fontsize=12)
plt.ylabel('Value', fontsize=12)
plt.title('Population Decoding: True vs Decoded Values', fontsize=14, fontweight='bold')
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Sample the true signal at each window centre, then compare point by point
true_at_decode = jnp.interp(decode_ms, t_ms, true_values)
error = jnp.abs(decoded_values - true_at_decode)
print(f"Mean decoding error: {jnp.mean(error):.4f}")
print(f"Max decoding error: {jnp.max(error):.4f}")

## Part 8: Readout Layers

Use a readout layer to extract output.

In [None]:
# Create network with readout
class NetworkWithReadout(brainstate.nn.Module):
    """Feed-forward spiking network: input spikes -> LIF hidden layer -> readout.

    Args:
        n_input: Number of input spike channels.
        n_hidden: Number of hidden LIF neurons.
        n_output: Number of readout dimensions.
        conn_prob: Probability that a given input channel projects to a given
            hidden neuron (fixed random binary connectivity).
    """

    def __init__(self, n_input, n_hidden, n_output, conn_prob=0.5):
        super().__init__()

        # Input -> hidden connectivity, shape (n_input, n_hidden). Without it an
        # (n_input,) spike vector cannot drive n_hidden neurons when the two
        # sizes differ (e.g. 10 inputs vs 50 hidden neurons).
        self.w_in = (brainstate.random.rand(n_input, n_hidden) < conn_prob).astype(float)

        # Hidden layer (LIF neurons; this network has no recurrent connections)
        self.hidden = bp.LIF(n_hidden, V_rest=-65.*u.mV, V_th=-50.*u.mV, tau=10.*u.ms)

        # Readout layer mapping hidden spikes to n_output values
        self.readout = bp.Readout(
            n_hidden, n_output,
            weight_initializer=braintools.init.KaimingNormal()
        )

    def update(self, spike_input):
        """Advance the network by one time step.

        Args:
            spike_input: Input spikes for this step, shape ``(n_input,)``.

        Returns:
            Tuple ``(output, spikes)``: the readout output and the hidden-layer
            spikes for this step.
        """
        # Each connected input spike injects 5 nA into the target hidden neuron
        I_input = (spike_input @ self.w_in) * 5.0 * u.nA

        # Update hidden neurons
        self.hidden(I_input)
        spikes = self.hidden.get_spike()

        # Readout
        output = self.readout(spikes)

        return output, spikes

# Create network (sizes defined once and reused in the summary below)
n_hidden = 50
n_output = 2
net = NetworkWithReadout(n_input=10, n_hidden=n_hidden, n_output=n_output)
brainstate.nn.init_all_states(net)

print("Network with readout layer created")
print(f"Hidden neurons: {n_hidden}")
print(f"Output dimensions: {n_output}")

In [None]:
# Run the readout network for 200 ms on 10 channels of 50 Hz Poisson input
duration = 200. * u.ms
times = u.math.arange(0.*u.ms, duration, brainstate.environ.get_dt())

outputs_hist = []
spikes_hist = []
for _ in times:
    out_t, spk_t = net.update(poisson_input(10, 50*u.Hz, brainstate.environ.get_dt()))
    outputs_hist.append(out_t)
    spikes_hist.append(spk_t)

outputs_hist = jnp.array(outputs_hist)
spikes_hist = u.math.asarray(spikes_hist)

In [None]:
# Hidden-layer raster (top) and the two readout traces (bottom)
fig, (ax_hid, ax_read) = plt.subplots(2, 1, figsize=(12, 8), sharex=True)

t_idx, n_idx = u.math.where(spikes_hist != 0)
ax_hid.scatter(times[t_idx].to_decimal(u.ms), n_idx, s=1, c='blue', alpha=0.5)
ax_hid.set_ylabel('Neuron Index')
ax_hid.set_title('Hidden Layer Spikes', fontweight='bold')

t_ms = times.to_decimal(u.ms)
for k in range(2):
    ax_read.plot(t_ms, outputs_hist[:, k],
                 linewidth=2, label=f'Output {k + 1}', alpha=0.7)
ax_read.set_xlabel('Time (ms)')
ax_read.set_ylabel('Readout Value')
ax_read.set_title('Readout Layer Outputs', fontweight='bold')
ax_read.legend()

for ax in (ax_hid, ax_read):
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Readout layer converts spikes to continuous values")

## Part 9: Recording Network States

Record variables during simulation.

In [None]:
# Manual recording example: snapshot voltage and spikes after every step
neuron = bp.LIF(5, V_rest=-65.*u.mV, V_th=-50.*u.mV, tau=10.*u.ms)
brainstate.nn.init_all_states(neuron)

duration = 100. * u.ms
times = u.math.arange(0.*u.ms, duration, brainstate.environ.get_dt())

n_steps = len(times)
I_ext = 2.0 * u.nA  # constant drive, built once rather than on every step

# Collect one snapshot per step in Python lists, then stack them into arrays
V_hist = []
spike_hist = []

for t in times:
    neuron(I_ext)

    # Record states (defensive copies of the current values)
    V_hist.append(neuron.V.value.copy())
    spike_hist.append(neuron.get_spike().copy())

V_hist = u.math.asarray(V_hist)  # Shape: (time, neurons)
spike_hist = u.math.asarray(spike_hist)

print(f"Recorded {n_steps} time steps")
print(f"Voltage history shape: {V_hist.shape}")
print(f"Spike history shape: {spike_hist.shape}")

In [None]:
# Plot recorded states: one voltage trace per neuron, with spike times marked
V_th_mV = -50.  # must match V_th of the LIF neurons created above
t_ms = times.to_decimal(u.ms)

plt.figure(figsize=(12, 6))

for i in range(V_hist.shape[1]):
    V_mV = V_hist[:, i].to_decimal(u.mV)
    line, = plt.plot(t_ms, V_mV, linewidth=1.5, alpha=0.7, label=f'Neuron {i}')

    # Mark spikes explicitly with a tick at threshold in the trace's colour;
    # the recorded voltage alone may not make spike times visible.
    spike_t = t_ms[spike_hist[:, i] > 0]
    plt.scatter(spike_t, np.full(spike_t.shape, V_th_mV),
                marker='|', s=80, color=line.get_color())

plt.axhline(y=V_th_mV, color='r', linestyle='--', alpha=0.5, label='Threshold')
plt.xlabel('Time (ms)', fontsize=12)
plt.ylabel('Membrane Potential (mV)', fontsize=12)
plt.title('Recorded Voltage Traces', fontsize=14, fontweight='bold')
plt.legend(loc='upper right', fontsize=9)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## Summary

In this tutorial, you learned:

✅ **Input Generation**: Constant, Poisson, periodic patterns

✅ **Rate Coding**: Encoding values as firing rates

✅ **Population Coding**: Multiple neurons encode single values

✅ **Population Decoding**: Extract values from spike trains

✅ **Readout Layers**: Convert spikes to continuous outputs

✅ **Recording States**: Track network variables over time

## Key Concepts

1. **Input Encoding**: Convert signals → spike patterns
2. **Population Codes**: Distributed representation across neurons
3. **Decoding**: Extract information from population activity
4. **Readout**: Linear combination of spike counts

## Next Steps

- **Tutorial 5**: Learn [SNN training](../advanced/05-snn-training.ipynb)
- **Examples**: See [trained networks](../../examples/gallery.rst#snn-training)
- **Advanced**: Explore [reservoir computing](../../examples/gallery.rst)

## Exercises

1. Implement temporal coding (first-spike latency)
2. Create a 2D population code (e.g., for position)
3. Build a classifier using readout layer
4. Compare different decoding methods (vector, maximum likelihood)
5. Implement sparse coding with inhibition