# SYDE 556/750 --- Assignment 2
**Student ID: 20823934**

*Note:* Please include your numerical student ID only, do *not* include your name.

*Note:* Refer to the [PDF](https://github.com/celiasmith/syde556-f22/raw/master/assignments/assignment_02/syde556_assignment_02.pdf) for the full instructions (including some hints), this notebook contains abbreviated instructions only. Cells you need to fill out are marked with a "writing hand" symbol. Of course, you can add new cells in between the instructions, but please leave the instructions intact to facilitate marking.

In [None]:
# Import numpy and matplotlib -- you shouldn't need any other libraries
import numpy as np
import matplotlib.pyplot as plt

# Fix the numpy random seed for reproducible results
np.random.seed(18945)

# Some formatting options
%config InlineBackend.figure_formats = ['svg']

# 1. Generating a random input signal

## 1.1 Band-limited white noise

**a) Time-domain signals.** Plot $x(t)$ for three randomly generated signals with $\texttt{limit}$ at $5$, $10$, and $20\,\mathrm{Hz}$. For each of these, $\mathtt{T}=1\,\mathrm{s}$, $\mathtt{dt}=1\,\mathrm{ms}$ and $\mathtt{rms}=0.5$.

In [None]:
# Create a function that generates a randomly varying x(t) signal chosen from a band-limited white noise distribution
# Return the time points, the frequency bins, x(t) and X(ω) --> both its time- and Fourier-domain representation
def generate_signal(T, dt, rms, limit, seed):
    """
    T --> The length of the signal in seconds
    dt --> Time time step in seconds
    rms --> The root mean square power of the signal
    limit --> The maximum frequency of the signal
    seed --> Random seed
    """
    np.random.seed(seed)
    ts = np.arange(0, T, dt)                                # Time points
    fs_hz = np.fft.fftshift(np.fft.fftfreq(len(ts), dt))    # Frequency bins in Hz
    fs_rad = 2 * np.pi * fs_hz                              # Frequency bins in rad/s

    # Generate half of the random signal in the frequency domain
    num_samples = len(fs_rad) // 2
    real = np.random.normal(0, 1, num_samples)
    imag = np.random.normal(0, 1, num_samples)
    half_freq_signal = real + 1j * imag

    # Create the full frequency signal by mirroring the half signal
    # Note: The DC component is 0, so the mean of the time domain signal is 0
    freq_signal = np.concatenate((half_freq_signal, np.array([0]), np.conj(np.flip(half_freq_signal))))

    # Cut freq_signal so it has the same length as fs (in case len(fs) is even)
    freq_signal = freq_signal[:len(fs_rad)]

    # Limit the signal to the desired frequency range
    freq_signal[np.abs(fs_hz) > limit] = 0

    # Turn it back into the time domain
    time_signal = np.fft.ifft(np.fft.ifftshift(freq_signal))

    # Check if the time domain signal is real
    imag_threshold = 1e-10  # Error threshold
    if np.all(np.abs(time_signal.imag) < imag_threshold):
        time_signal = time_signal.real
    else:
        raise ValueError("The time domain signal is not real")

    # Scale the time and frequency signals to the RMS power
    old_rms = np.sqrt(np.mean(time_signal ** 2))
    scaling_factor = rms / old_rms
    time_signal *= scaling_factor
    freq_signal *= scaling_factor

    return ts, fs_rad, time_signal, freq_signal

# Adapting the plotting code from the lecture notes
def plot_signals(ts, fs_rad, time_signal, freq_signal):
    # Calculate bandwidth by finding the highest frequency above the energy threshold
    bandwidth = fs_rad[np.max(np.where(np.abs(freq_signal) > 1e-3))]

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(6.5, 2.25))
    ax1.plot(ts, time_signal)
    ax1.set_xlabel('Time $t$ (s)')
    ax1.set_ylabel('Signal x')
    ax1.set_title('Time domain')

    ax2.plot(fs_rad, np.abs(freq_signal))
    ax2.set_xlabel(r'Frequency $\omega$ (rad/s)')
    ax2.set_ylabel('Power spectrum $|X|$')
    ax2.set_title('Frequency domain')

    ax3.plot(fs_rad, np.abs(freq_signal))
    ax3.set_xlabel(r'Frequency $\omega$ (rad/s)')
    ax3.set_ylabel('Power spectrum $|X|$')
    ax3.set_xlim(-min(20*2*np.pi, bandwidth), min(20*2*np.pi, bandwidth))
    ax3.set_title('Frequency domain (magnified)')

    fig.tight_layout(pad=0.5)

    plt.show()

# Generate and plot the signals
for limit in [5, 10, 20]:
    ts, fs_rad, time_signal, freq_signal = generate_signal(T=1, dt=0.001, rms=0.5, limit=limit, seed=0)
    print(f'Bandwidth: {limit} Hz = {np.round(limit * 2 * np.pi, 1)} rad/s')
    plot_signals(ts, fs_rad, time_signal, freq_signal)
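The mirroring step in `generate_signal` can also be delegated to NumPy's real-input FFT helpers. The sketch below is an alternative, not the function used in this notebook: `generate_signal_rfft` is a hypothetical name, it uses `np.random.default_rng` (so the same `seed` will not reproduce the signals above), and it assumes `limit` is at least `1/T` so that some coefficients survive the cutoff. `np.fft.irfft` treats its input as the non-negative half of a conjugate-symmetric spectrum, so the output is real by construction and no imaginary-part check is needed.

```python
import numpy as np

def generate_signal_rfft(T, dt, rms, limit, seed):
    """Band-limited white noise built from non-negative frequencies only (sketch)."""
    rng = np.random.default_rng(seed)
    n = int(round(T / dt))                          # number of time samples
    freqs_hz = np.fft.rfftfreq(n, dt)               # 0, 1/T, ..., Nyquist
    coeffs = rng.normal(size=freqs_hz.size) + 1j * rng.normal(size=freqs_hz.size)
    coeffs[0] = 0                                   # no DC component, so the signal has zero mean
    coeffs[freqs_hz > limit] = 0                    # hard cutoff at `limit` Hz
    x = np.fft.irfft(coeffs, n)                     # real-valued by construction
    x *= rms / np.sqrt(np.mean(x ** 2))             # rescale to the requested RMS
    return np.arange(n) * dt, x

ts, x = generate_signal_rfft(T=1, dt=0.001, rms=0.5, limit=10, seed=0)
```

Only half of the spectrum is stored and drawn, which removes the need for the concatenate/flip/truncate bookkeeping.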

**b) Average power spectrum.** Plot the average $|X(\omega)|$ (the norm of the Fourier coefficients, or “power spectrum”) over $100$ signals generated with $\mathtt{T}=1\,\mathrm{s}$, $\mathtt{dt}=1\,\mathrm{ms}$, $\mathtt{rms}=0.5$, and $\mathtt{limit}=10\,\mathrm{Hz}$ (of course, each of these 100 signals should have a different `seed`). The plot should have the $x$-axis labeled “$\omega$ in radians” and the average $|X|$ value for that $\omega$ on the $y$-axis.

In [None]:
sum_norm_coefficients = 0
for i in range(100):
    ts, fs_rad, time_signal, freq_signal = generate_signal(T=1, dt=0.001, rms=0.5, limit=10, seed=i)
    sum_norm_coefficients += np.abs(freq_signal)
mean_norm_coefficients = sum_norm_coefficients / 100

print(f'Bandwidth: 10 Hz = {np.round(20 * np.pi, 1)} rad/s')
plt.plot(fs_rad, mean_norm_coefficients)
plt.xlabel('ω in radians')
plt.ylabel('Average |X(ω)|')
plt.xlim(-20*2*np.pi, 20*2*np.pi)
plt.title('Average Norm of Fourier Coefficients over 100 signals')
plt.show()
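A quick Parseval check gives a rough idea of where the flat part of this curve should sit. This is a back-of-the-envelope sketch under stated assumptions: NumPy's unnormalized FFT convention (which `generate_signal` relies on), `T = 1 s` so the bins are 1 Hz apart, and the energy shared evenly by the 20 non-zero bins at ±1 to ±10 Hz. Because each signal is rescaled to the target RMS, the measured average should only be close to this value, not equal to it.

```python
import numpy as np

N = 1000          # samples: T = 1 s at dt = 1 ms, so the FFT bins are 1 Hz apart
rms = 0.5
n_bins = 2 * 10   # non-zero bins at +-1, ..., +-10 Hz (the DC bin is forced to 0)

# Parseval with numpy's unnormalized FFT: sum_k |X_k|^2 = N * sum_n x_n^2 = N^2 * rms^2
energy_per_bin = N**2 * rms**2 / n_bins

# Each |X_k| is roughly Rayleigh-distributed (complex Gaussian coefficient), and a
# Rayleigh variable's mean is sqrt(pi/4) times its root-mean-square value
expected_plateau = np.sqrt(np.pi / 4) * np.sqrt(energy_per_bin)
print(f"Expected average |X| inside the band: about {expected_plateau:.0f}")
```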

## 1.2 Gaussian power spectrum noise

**a) Time-domain signals.** Plot $x(t)$ for three randomly generated signals with `bandwidth` at $5$, $10$, and $20\,\mathrm{Hz}$. For each of these, $\mathtt{T}=1\,\mathrm{s}$, $\mathtt{dt}=1\,\mathrm{ms}$ and $\mathtt{rms}=0.5$.

In [None]:
# Modify the function so the power drops off as frequency increases
# Return both time and frequency domain signals
def generate_signal_bandwidth(T, dt, rms, bandwidth, seed):
    """
    T --> The length of the signal in seconds
    dt --> The time step in seconds
    rms --> The root mean square (RMS) amplitude of the signal
    bandwidth --> The bandwidth b of the Gaussian envelope exp(-ω²/(2b²)), compared against ω in rad/s
    seed --> Random seed
    """
    np.random.seed(seed)
    ts = np.arange(0, T, dt)                                # Time points
    fs_hz = np.fft.fftshift(np.fft.fftfreq(len(ts), dt))    # Frequency bins in Hz
    fs_rad = 2 * np.pi * fs_hz                              # Frequency bins in rad/s

    # Generate half of the random signal in the frequency domain; the standard
    # deviation of each coefficient falls off as a Gaussian in frequency
    num_samples = len(fs_rad) // 2
    half_freq_signal = np.zeros(num_samples, dtype=np.complex128)

    for i in range(num_samples):
        stdev = np.exp(-fs_rad[i]**2 / (2 * bandwidth**2))
        real = np.random.normal(0, stdev)
        imag = np.random.normal(0, stdev)
        half_freq_signal[i] = real + 1j * imag

    # Create the full frequency signal by mirroring the half signal
    # Note: The DC component is 0, so the mean of the time domain signal is 0
    freq_signal = np.concatenate((half_freq_signal, np.array([0]), np.conj(np.flip(half_freq_signal))))

    # Cut freq_signal so it has the same length as fs (in case len(fs) is even)
    freq_signal = freq_signal[:len(fs_rad)]

    # Turn it back into the time domain
    time_signal = np.fft.ifft(np.fft.ifftshift(freq_signal))

    # Check if the time domain signal is real
    imag_threshold = 1e-10  # Error threshold
    if np.all(np.abs(time_signal.imag) < imag_threshold):
        time_signal = time_signal.real
    else:
        raise ValueError("The time domain signal is not real")

    # Scale the time and frequency signals to the RMS power
    old_rms = np.sqrt(np.mean(time_signal ** 2))
    scaling_factor = rms / old_rms
    time_signal *= scaling_factor
    freq_signal *= scaling_factor

    return ts, fs_rad, time_signal, freq_signal

# Generate and plot the signals
for bandwidth in [5, 10, 20]:
    ts, fs_rad, time_signal, freq_signal = generate_signal_bandwidth(T=1, dt=0.001, rms=0.5, bandwidth=bandwidth, seed=0)
    print(f'Bandwidth: {bandwidth} Hz = {np.round(bandwidth * 2 * np.pi, 1)} rad/s')
    plot_signals(ts, fs_rad, time_signal, freq_signal)

**b) Average power spectrum.** Plot the average $|X(\omega)|$ (the norm of the Fourier coefficients, or “power spectrum”) over $100$ signals generated with $\mathtt{T}=1\,\mathrm{s}$, $\mathtt{dt}=1\,\mathrm{ms}$, $\mathtt{rms}=0.5$, and $\mathtt{bandwidth}=10$ (of course, each of these 100 signals should have a different `seed`). The plot should have the $x$-axis labeled “$\omega$ in radians” and the average $|X|$ value for that $\omega$ on the $y$-axis.

In [None]:
sum_norm_coefficients = 0
for i in range(100):
    ts, fs_rad, time_signal, freq_signal = generate_signal_bandwidth(T=1, dt=0.001, rms=0.5, bandwidth=10, seed=i)
    sum_norm_coefficients += np.abs(freq_signal)
mean_norm_coefficients = sum_norm_coefficients / 100

print(f'Bandwidth: 10 Hz = {np.round(20 * np.pi, 1)} rad/s')
plt.plot(fs_rad, mean_norm_coefficients)
plt.xlabel('ω in radians')
plt.ylabel('Average |X(ω)|')
plt.xlim(-20*2*np.pi, 20*2*np.pi)
plt.title('Average Norm of Fourier Coefficients over 100 signals')
plt.show()
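The shape of this curve can be predicted from how the coefficients are drawn. If the real and imaginary parts are independent $\mathcal{N}(0, \sigma^2)$ draws, then $|X|$ is Rayleigh-distributed with mean $\sigma\sqrt{\pi/2}$, so the average magnitude should trace the same Gaussian envelope $e^{-\omega^2/(2b^2)}$ used to set $\sigma$ (up to the per-signal RMS rescaling, which multiplies all coefficients of one signal by the same factor). The Monte Carlo sketch below checks this on an illustrative grid; the grid, the number of draws and `b = 10` are arbitrary choices, and `b` is compared against $\omega$ in rad/s as in `generate_signal_bandwidth`.

```python
import numpy as np

rng = np.random.default_rng(0)
b = 10.0                                  # bandwidth parameter, same units as omega
omega = np.linspace(-40, 40, 81)          # illustrative frequency grid in rad/s
sigma = np.exp(-omega**2 / (2 * b**2))    # per-coefficient standard deviation
n_draws = 2000

# Complex coefficients with independent N(0, sigma^2) real and imaginary parts
coeffs = (rng.normal(0, sigma, size=(n_draws, omega.size))
          + 1j * rng.normal(0, sigma, size=(n_draws, omega.size)))
empirical_mean = np.abs(coeffs).mean(axis=0)

# Mean of a Rayleigh distribution with scale sigma
predicted_mean = sigma * np.sqrt(np.pi / 2)

# Largest relative deviation between the Monte Carlo average and the prediction
print(np.max(np.abs(empirical_mean / predicted_mean - 1)))
```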

# 2. Simulating a spiking neuron

**a) Spike plots for constant inputs.** Plot the spike output for a constant input of $x=0$ over $1$ second. Report the number of spikes. Do the same thing for $x=1$. Use a time step of $\Delta t = 1\,\mathrm{ms}$ for the simulation.

In [None]:
tau_ref = 0.002 # 2 ms
tau_rc = 0.02   # 20 ms

def G_inverse(a):
    return 1 / (1 - np.exp((tau_ref - 1/a)/(tau_rc)))

def G(J):
    if J > 1:
        return 1 / (tau_ref - tau_rc * np.log(1 - 1/J))
    else:
        return 0

# Find alpha and J_bias such that the firing rate when x = 0 is 40 Hz and when x = 1 is 150 Hz
# 40 = G(J = J_bias) ==> J_bias = G_inverse(40)
# 150 = G(J = alpha + J_bias) ==> alpha + J_bias = G_inverse(150) ==> alpha = G_inverse(150) - J_bias
J_bias = G_inverse(40)
alpha = G_inverse(150) - J_bias

def spike_train(x, e, alpha, J_bias, tau_ref, tau_rc, v_threshold=1, T=1, delta_t=0.001):
    ts = np.arange(0, T, delta_t)
    voltage = 0
    refractory_countdown = 0
    J = alpha * e * x + J_bias

    spikes = []
    voltages = []
    for i in range(len(ts)):
        # Make sure the voltage is non-negative
        voltage = max(voltage, 0)
        voltages.append(voltage)

        # If not in refractory period, update the voltage
        if refractory_countdown <= 0:
            # If there's a spike, reset the voltage and start the countdown
            if voltage >= v_threshold:
                spikes.append(1)
                voltage = 0
                refractory_countdown = tau_ref

            # If there's no spike, update the voltage
            else:
                spikes.append(0)
                delta_v = (J[i] - voltage) / tau_rc # dv/dt = (J - v) / tau_rc
                voltage += delta_v * delta_t        # dv = dv/dt * dt

        # If in refractory period, only update the countdown
        else:
            spikes.append(0)
            refractory_countdown -= delta_t

    spikes = np.array(spikes)
    num_spikes = sum(spikes)

    return ts, spikes, num_spikes, voltages

T = 1
delta_t = 0.001

x_0 = np.zeros(int(T/delta_t))
ts, spikes, num_spikes, voltages = spike_train(x=x_0, e=1, alpha=alpha, J_bias=J_bias, tau_ref=tau_ref, tau_rc=tau_rc, T=T, delta_t=delta_t)
print(f"{num_spikes} spikes for x = 0, T = 1, and delta_t = 0.001")
plt.plot(ts, spikes)
plt.xlabel('Time (s)')
plt.ylabel('Spikes')
plt.title(f'Spike Train for x = 0 and e = 1')
plt.show()

x_1 = np.ones(int(T/delta_t))
ts, spikes, num_spikes, voltages = spike_train(x=x_1, e=1, alpha=alpha, J_bias=J_bias, tau_ref=tau_ref, tau_rc=tau_rc, T=T, delta_t=delta_t)
print(f"{num_spikes} spikes for x = 1, T = 1, and delta_t = 0.001")
plt.plot(ts, spikes)
plt.xlabel('Time (s)')
plt.ylabel('Spikes')
plt.title(f'Spike Train for x = 1 and e = 1')
plt.show()
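The `G_inverse` function and the choice of `J_bias` and `alpha` above follow from inverting the LIF rate curve that `G` implements. The membrane obeys $\frac{dv}{dt} = \frac{J - v}{\tau_{RC}}$, which `spike_train` steps with Euler's method as $v_{k+1} = v_k + \frac{\Delta t}{\tau_{RC}}(J - v_k)$:

$$
a = G(J) = \frac{1}{\tau_{ref} - \tau_{RC}\ln\left(1 - \frac{1}{J}\right)}
\quad\Longrightarrow\quad
J = G^{-1}(a) = \frac{1}{1 - \exp\left(\dfrac{\tau_{ref} - 1/a}{\tau_{RC}}\right)}
$$

$$
J(x) = \alpha e x + J_{bias}, \qquad
J_{bias} = G^{-1}(40\,\mathrm{Hz}), \qquad
\alpha = G^{-1}(150\,\mathrm{Hz}) - J_{bias}
$$

so that $x = 0$ gives 40 Hz and $x = 1$ with $e = 1$ gives 150 Hz.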

**b) Discussion.** Does the observed number of spikes in the previous part match the expected number of spikes for $x=0$ and $x=1$? Why or why not? What aspects of the simulation would affect this accuracy?

In [None]:
for T in [1, 10, 100, 1000]:
    ts, spikes, num_spikes, voltages = spike_train(x=np.zeros(int(T/0.001)), e=1, alpha=alpha, J_bias=J_bias, tau_ref=tau_ref, tau_rc=tau_rc, T=T)
    print(f"{num_spikes} spikes for x = 0, T = {T}, and delta_t = 0.001")

for dt in [0.01, 0.001, 0.0001, 0.00001]:
    ts, spikes, num_spikes, voltages = spike_train(x=np.zeros(int(1/dt)), e=1, alpha=alpha, J_bias=J_bias, tau_ref=tau_ref, tau_rc=tau_rc, delta_t=dt)
    print(f"{num_spikes} spikes for x = 0, T = 1, and delta_t = {dt}")

No. The observed spike counts are close to, but lower than, the expected counts (for example, 38 spikes instead of 40 for $x=0$). There are 4 main reasons why the simulation may not be accurate.

1. Simulation length T: A spike that would have occurred just after the simulation ends is not counted, so a short simulation can undercount by 1. Running the simulation for longer averages the spike rate over more interspike intervals. As shown above, there are only 38 spikes for `T=1` (38 Hz) but 38461 spikes for `T=1000` (38.461 Hz), which is closer to 40 Hz but still below it, so the simulation length alone does not explain the gap.

2. Simulation time step delta_t: Because we use a numerical method to approximate the integral, decreasing the time step makes the estimate more accurate. As seen above, we count only 25 spikes per second for `delta_t = 0.01`, but 40 spikes per second for `delta_t = 0.00001`.

3. Euler's method for approximating the integral: Euler's method is a very simple numerical integration technique. A higher-order method such as the 4th order Runge-Kutta method might give a better approximation of the integral.

4. Refractory period and spike timing: In this simulation, a spike and the end of the refractory period can only happen on whole time steps, and no integration happens on the step where a spike is registered. Each interspike interval is therefore rounded up to a whole number of steps, which lowers the firing rate. Since this lost time is roughly the same in every interval, it matters more at high firing rates, where the intervals are shorter.
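The undercount can be predicted exactly from the update rules in `spike_train`. With a constant input, Euler's method gives $v_k = J\left(1 - (1 - \Delta t/\tau_{RC})^k\right)$ after $k$ integration steps from $v = 0$; the spike is registered on the following step (which does no integration), and the refractory countdown then skips $\lceil \tau_{ref}/\Delta t \rceil$ more steps. The sketch below re-derives the spike count from that bookkeeping. `predicted_spike_count` is a hypothetical helper; it assumes a constant input, `dt < tau_rc`, and that `T/dt` is (close to) a whole number, and floating-point round-off in the countdown could shift the refractory period by a step for other `dt` values.

```python
import numpy as np

tau_ref, tau_rc = 0.002, 0.02

def G_inverse(a):
    # Current that makes the LIF neuron fire at rate a (Hz), as in the cell above
    return 1 / (1 - np.exp((tau_ref - 1 / a) / tau_rc))

def predicted_spike_count(rate, T=1.0, dt=0.001):
    """Spike count of spike_train() for a constant current J = G_inverse(rate)."""
    J = G_inverse(rate)
    decay = 1 - dt / tau_rc                               # Euler: v <- decay * v + (dt / tau_rc) * J
    # Smallest k with J * (1 - decay**k) >= 1: Euler steps needed to reach threshold from 0
    k = int(np.ceil(np.log(1 - 1 / J) / np.log(decay)))
    refractory_steps = int(np.ceil(tau_ref / dt - 1e-9))  # steps skipped by the countdown
    isi_steps = k + 1 + refractory_steps                  # + 1 for the spike step itself
    n_steps = int(round(T / dt))
    # First spike on step k, then one spike every isi_steps steps
    return 0 if k >= n_steps else 1 + (n_steps - 1 - k) // isi_steps

for rate in [40, 150]:
    print(f"target {rate} Hz -> predicted {predicted_spike_count(rate)} spikes in 1 s")
```

For the 40 Hz case this gives an interspike interval of 26 steps (26 ms) instead of 25 ms, which reproduces the 38 spikes at `T = 1` and 38461 spikes at `T = 1000` reported above.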

**c) Spike plots for white noise inputs.** Plot the spike output for $x(t)$ generated using your function from part 1.1. Use $\mathtt{T}=1\,\mathrm{s}$, $\mathtt{dt}=1\,\mathrm{ms}$, $\mathtt{rms}=0.5$, and $\mathtt{limit}=30\,\mathrm{Hz}$. Overlay on this plot $x(t)$.

In [None]:
T = 1
delta_t = 0.001

ts, fs_rad, time_signal, freq_signal = generate_signal(T=T, dt=delta_t, rms=0.5, limit=30, seed=0)
ts, spikes, num_spikes, voltages = spike_train(x=time_signal, e=1, alpha=alpha, J_bias=J_bias, tau_ref=tau_ref, tau_rc=tau_rc, T=T, delta_t=delta_t)

plt.plot(ts, spikes, label='Spikes')
plt.plot(ts, time_signal, label='Signal')
plt.legend()
plt.xlabel('Time (s)')
plt.ylabel('Spikes & Signals')
plt.title(f'Spike Train for a Noisy Input')
plt.show()

**d) Voltage over time.** Using the same $x(t)$ signal as in part *c)*, plot the neuron's voltage over time for the first $0.2$ seconds, along with the spikes over the same time.

In [None]:
# Plot the first 200ms of the previous result instead of generating another noisy signal
plt.plot(ts[:200], voltages[:200], label='Voltage')
plt.plot(ts[:200], spikes[:200], label='Spikes')
plt.legend()
plt.xlabel('Time (s)')
plt.ylabel('Voltages & Spikes')
plt.title(f'Spike Train and Voltages')
plt.show()

**e) 🌟 Bonus question.** How could you improve this simulation (in terms of how closely the model matches actual equation) without significantly increasing the computation time? $0.5$ marks for having a good idea. Up to $1$ mark for actually implementing it and showing that it works.

In [None]:
# Use the 4th order Runge-Kutta method instead of Euler's method
def spike_train_runge_kutta(x, e, alpha, J_bias, tau_ref, tau_rc, v_threshold=1, T=1, delta_t=0.001):
    ts = np.arange(0, T, delta_t)
    voltage = 0
    refractory_countdown = 0
    J = alpha * e * x + J_bias

    spikes = []
    voltages = []

    def f(v, j):
        return (j - v) / tau_rc

    for i in range(len(ts)):
        # Make sure the voltage is non-negative
        voltage = max(voltage, 0)
        voltages.append(voltage)

        # If not in refractory period, update the voltage
        if refractory_countdown <= 0:
            # If there's a spike, reset the voltage and start the countdown
            if voltage >= v_threshold:
                spikes.append(1)
                voltage = 0
                refractory_countdown = tau_ref
            # If there's no spike, update the voltage using 4th order Runge-Kutta
            else:
                spikes.append(0)
                k1 = f(voltage, J[i]) * delta_t
                k2 = f(voltage + 0.5 * k1, J[i]) * delta_t
                k3 = f(voltage + 0.5 * k2, J[i]) * delta_t
                k4 = f(voltage + k3, J[i]) * delta_t
                voltage += (k1 + 2 * k2 + 2 * k3 + k4) / 6

        # If in refractory period, only update the countdown
        else:
            spikes.append(0)
            refractory_countdown -= delta_t

    spikes = np.array(spikes)
    num_spikes = sum(spikes)

    return ts, spikes, num_spikes, voltages

# Preallocate NumPy arrays instead of appending to Python lists (the time loop is still a Python loop)
def spike_train_vectorized(x, e, alpha, J_bias, tau_ref, tau_rc, v_threshold=1, T=1, delta_t=0.001):
    ts = np.arange(0, T, delta_t)
    voltage = np.zeros_like(ts)
    spikes = np.zeros_like(ts, dtype=int)
    refractory_end = np.zeros_like(ts)
    J = alpha * e * x + J_bias
    
    for i in range(1, len(ts)):  # Starting from 1 as the 0th index is initial conditions
        if ts[i] < refractory_end[i-1]:
            voltage[i] = 0
            refractory_end[i] = refractory_end[i-1]
        else:
            delta_v = (J[i] - voltage[i-1]) / tau_rc
            voltage[i] = voltage[i-1] + delta_v * delta_t
            if voltage[i] >= v_threshold:
                spikes[i] = 1
                refractory_end[i] = ts[i] + tau_ref
                voltage[i] = 0

    num_spikes = np.sum(spikes)

    return ts, spikes, num_spikes, voltage

# Evaluate time efficiency
import time

def eval_runtime(func, func_name, repeats=5, delta_t=0.001):
    runtime = 0
    for i in range(repeats):
        start_time = time.perf_counter()
        ts, spikes, num_spikes, voltages = func(x=np.ones(int(1/delta_t)), e=1, alpha=alpha, J_bias=J_bias, tau_ref=tau_ref, tau_rc=tau_rc, delta_t=delta_t)
        elapsed_time = time.perf_counter() - start_time
        runtime += elapsed_time
    error = np.round(abs(num_spikes - 150) / 150, 3)
    avg_runtime = np.round(runtime / repeats * 1000, 3)
    print(f'{func_name} --> Spikes: {num_spikes}, Error: {error}, Average runtime: {avg_runtime} ms over {repeats} iterations')

eval_runtime(spike_train, "Original")
eval_runtime(spike_train_runge_kutta, "Runge-Kutta")
eval_runtime(spike_train_vectorized, "Vectorized")
eval_runtime(spike_train, "Sampling Interval", delta_t=0.0001)
eval_runtime(spike_train, "Sampling Interval", delta_t=0.00001)

There are 3 main things to try for improving the simulation accuracy. From 2b), we see that increasing the simulation period T doesn't improve accuracy very much, so it doesn't need to be tried.

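1. Use a more complex numerical integration method like the 4th order Runge-Kutta: The results show that Runge-Kutta has the same error but a longer runtime. This suggests that most of the error does not come from the integration method itself, but from spike times and the refractory period being rounded to whole time steps, which a better integrator cannot fix. Although this specific method didn't work, other methods might decrease the error if they also address this rounding.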

2. Refactor the code to only update a vector when something happens, instead of appending the result to a vector after every iteration. This modification intends to take advantage of numpy's efficient handling of arrays instead of updating a list, which is thought to be a slower operation. The results show that the error decreased slightly, but the runtime also increased, almost by the same factor. Therefore, refactoring isn't the solution.

3. Decreasing the sampling interval (time step) delta_t: When `dt = 0.1 ms`, the error decreases from 0.167 to 0.02. At `dt = 0.01 ms`, the error disappears. Although the runtime grows in proportion to the number of time steps (roughly 10× for every 10× decrease in dt), the longest runtime is only 31 ms, which is insignificant in most practical cases. Additionally, the user can choose the time step according to how much error they're willing to tolerate, making this the most flexible option.
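One further idea, not implemented or timed above: for a current that is constant within a step, $\frac{dv}{dt} = \frac{J - v}{\tau_{RC}}$ has the exact solution $v(t + \Delta t) = J + (v(t) - J)e^{-\Delta t/\tau_{RC}}$, so the voltage update can be made exact at about the same cost per step as Euler's method. Solving that same expression for the moment the threshold was crossed also lets the refractory period start inside a step, so spike times and the refractory period are no longer rounded to whole time steps. The sketch below is a hypothetical `spike_train_exact`; it assumes `dt <= tau_ref`, accepts a scalar current or one current per step, and drops the voltage trace for brevity.

```python
import numpy as np

tau_ref, tau_rc = 0.002, 0.02

def G_inverse(a):
    # Current that makes the LIF neuron fire at rate a (Hz)
    return 1 / (1 - np.exp((tau_ref - 1 / a) / tau_rc))

def spike_train_exact(J, T=1.0, dt=0.001, v_threshold=1.0):
    """LIF spikes with an exact exponential update and sub-step spike times.
    J is a scalar or an array with one current per time step; assumes dt <= tau_ref."""
    n_steps = int(round(T / dt))
    J = np.broadcast_to(np.asarray(J, dtype=float), (n_steps,))
    spikes = np.zeros(n_steps, dtype=int)
    v = 0.0
    refractory_left = 0.0                       # refractory time still to serve (s)
    for i in range(n_steps):
        refractory_left -= dt
        # Part of this step that lies after the end of the refractory period
        active = min(max(dt - refractory_left, 0.0), dt)
        # Exact solution of dv/dt = (J - v) / tau_rc over `active` seconds
        v = max(J[i] + (v - J[i]) * np.exp(-active / tau_rc), 0.0)
        if v >= v_threshold:
            spikes[i] = 1
            # Time into this step at which v crossed the threshold
            t_cross = dt + tau_rc * np.log1p(-(v - v_threshold) / (J[i] - v_threshold))
            v = 0.0
            # The refractory period counts from the crossing, not from the end of the step
            refractory_left = tau_ref + t_cross
    return spikes

for rate in [40, 150]:
    print(rate, spike_train_exact(G_inverse(rate)).sum())
```

For constant inputs this should give 40 and 150 spikes in 1 s at `dt = 1 ms`, because the interspike intervals are no longer rounded to whole steps.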

# 3. Simulating two spiking neurons

**a) Spike plots for constant inputs.** Plot $x(t)$ and the spiking output for $x(t)=0$ (both neurons should spike at about $40$ spikes per second), as well as (in a separate plot) $x(t)=1$ (one neuron should spike at $\approx 150$ spikes per second, and the other should not spike at all).

In [None]:
T = 1
delta_t = 0.001

# Plot x = 0
x_0 = np.zeros(int(T/delta_t))
ts, spikes_0p, num_spikes_0p, voltages = spike_train(x=x_0, e=1, alpha=alpha, J_bias=J_bias, tau_ref=tau_ref, tau_rc=tau_rc, T=T, delta_t=delta_t)
ts, spikes_0n, num_spikes_0n, voltages = spike_train(x=x_0, e=-1, alpha=alpha, J_bias=J_bias, tau_ref=tau_ref, tau_rc=tau_rc, T=T, delta_t=delta_t)

print(f"{num_spikes_0p} positive spikes and {num_spikes_0n} negative spikes for x = 0")
plt.plot(ts, spikes_0p, label='Positive Neuron')
plt.plot(ts, -spikes_0n, label='Negative Neuron')
plt.plot(ts, x_0, label='Input')
plt.legend()
plt.xlabel('Time (s)')
plt.ylabel('Spikes')
plt.ylim(-1.1, 1.1)
plt.title(f'Spike Trains for x = 0')
plt.show()

# Plot x = 1
x_1 = np.ones(int(T/delta_t))
ts, spikes_1p, num_spikes_1p, voltages = spike_train(x=x_1, e=1, alpha=alpha, J_bias=J_bias, tau_ref=tau_ref, tau_rc=tau_rc, T=T, delta_t=delta_t)
ts, spikes_1n, num_spikes_1n, voltages = spike_train(x=x_1, e=-1, alpha=alpha, J_bias=J_bias, tau_ref=tau_ref, tau_rc=tau_rc, T=T, delta_t=delta_t)

print(f"{num_spikes_1p} positive spikes and {num_spikes_1n} negative spikes for x = 1")
plt.plot(ts, spikes_1p, label='Positive Neuron')
plt.plot(ts, -spikes_1n, label='Negative Neuron')
plt.plot(ts, x_1, label='Input')
plt.legend()
plt.xlabel('Time (s)')
plt.ylabel('Spikes')
plt.ylim(-1.1, 1.1)
plt.title(f'Spike Trains for x = 1')
plt.show()
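The expected rates for the two neurons follow directly from $J = \alpha e x + J_{bias}$ and the rate curve $G$. The sketch below recomputes them with the same `G` and `G_inverse` definitions as in Section 2 (repeated so the block runs on its own). For $e = -1$ and $x = 1$ the current falls below the threshold current of 1, so that neuron is silent. These are ideal rates; the simulation with `dt = 1 ms` reports somewhat fewer spikes, as discussed in 2b).

```python
import numpy as np

tau_ref, tau_rc = 0.002, 0.02

def G_inverse(a):
    return 1 / (1 - np.exp((tau_ref - 1 / a) / tau_rc))

def G(J):
    # LIF rate curve: zero at or below the threshold current J = 1
    return 1 / (tau_ref - tau_rc * np.log(1 - 1 / J)) if J > 1 else 0.0

J_bias = G_inverse(40)
alpha = G_inverse(150) - J_bias

for e in [1, -1]:
    for x in [0, 1]:
        J = alpha * e * x + J_bias
        print(f"e = {e:+d}, x = {x}: J = {J:.2f}, rate = {G(J):.1f} Hz")
```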

**b) Spike plots for a sinusoidal input.** Plot $x(t)$ and the spiking output for $x(t)=\frac{1}2 \sin(10 \pi t)$.

In [None]:
T = 1
delta_t = 0.001

def x_func(t):
    return 0.5 * np.sin(10 * np.pi * t)

ts = np.arange(0, T, delta_t)
xs = x_func(ts)

ts, spikes_sp, num_spikes_sp, voltages = spike_train(x=xs, e=1, alpha=alpha, J_bias=J_bias, tau_ref=tau_ref, tau_rc=tau_rc, T=T, delta_t=delta_t)
ts, spikes_sn, num_spikes_sn, voltages = spike_train(x=xs, e=-1, alpha=alpha, J_bias=J_bias, tau_ref=tau_ref, tau_rc=tau_rc, T=T, delta_t=delta_t)

print(f"{num_spikes_sp} positive spikes and {num_spikes_sn} negative spikes for x = 0.5sin(10œÄt)")
plt.plot(ts, spikes_sp, label='Positive Neuron')
plt.plot(ts, -spikes_sn, label='Negative Neuron')
plt.plot(ts, xs, label='Input')
plt.legend()
plt.xlabel('Time (s)')
plt.ylabel('Spikes')
plt.title(f'Spike Trains for x = 0.5sin(10πt)')
plt.show()

**c) Spike plot for a white noise signal.** Plot $x(t)$ and the spiking output for a random signal generated with your function for question 1.1 with $\mathtt{T}=2\,\mathrm{s}$, $\mathtt{dt}=1\,\mathrm{ms}$, $\mathtt{rms}=0.5$, and $\mathtt{limit}=5\,\mathrm{Hz}$.

In [None]:
T = 2
delta_t = 0.001

ts, fs_rad, time_signal, freq_signal = generate_signal(T=T, dt=delta_t, rms=0.5, limit=5, seed=0)
ts, spikes_sp, num_spikes_sp, voltages = spike_train(x=time_signal, e=1, alpha=alpha, J_bias=J_bias, tau_ref=tau_ref, tau_rc=tau_rc, T=T, delta_t=delta_t)
ts, spikes_sn, num_spikes_sn, voltages = spike_train(x=time_signal, e=-1, alpha=alpha, J_bias=J_bias, tau_ref=tau_ref, tau_rc=tau_rc, T=T, delta_t=delta_t)

print(f"{num_spikes_sp} positive spikes and {num_spikes_sn} negative spikes for a noisy signal")
plt.figure(figsize=(10,5))
plt.plot(ts, spikes_sp, label='Positive Neuron')
plt.plot(ts, -spikes_sn, label='Negative Neuron')
plt.plot(ts, time_signal, label='Input')
plt.legend()
plt.xlabel('Time (s)')
plt.ylabel('Spikes')
plt.title(f'Spike Trains for a noisy signal')
plt.show()

# 4. Computing an optimal filter

**a) Document the code.** Fill in comments where there are `# !`-signs in the Python code. Make sure that your comments (where this makes sense) describe the semantics of the code and do not just repeat what is obvious from the code itself. Run the function with what you wrote for part 3 above, so that it uses the spike signal generated in 3c).

In [None]:
def compute_optimal_filter(
        # Signal generated from your white noise generator
        x,
        # Fourier coefficients from your white noise generator
        X,
        # Spike train from the previous part
        spikes,
        # Time step size
        dt=1e-3
    ):

    # x and X should (effectively) be 1D-arrays
    assert x.ndim == 1 and X.ndim == 1
    assert x.shape[0] == X.shape[0]

    # Number of time points t
    Nt = x.size

    # Make sure that "spikes" is a 2 x Nt array
    assert spikes.ndim == 2
    assert spikes.shape[0] == 2              
    assert spikes.shape[1] == Nt

    # Total simulation time = Number of time points * time step size
    T = Nt * dt

    # Nt number of time points that go from -T/2 to T/2 with step size dt
    ts = np.arange(Nt) * dt - T / 2.0

    # Nt number of frequency bins that go from -Nt/(2T) aka -Nyquist to Nt/(2T) aka +Nyquist with step size 1/T
    # Nt / T is the sampling rate, so the Nyquist frequency is Nt / (2T)
    fs = np.arange(Nt) / T - Nt / (2.0 * T)

    # Angular frequency in rad/s = 2π * frequency in Hz
    omega = fs * 2.0 * np.pi

    # Response = sum of signed spikes (spikes[0] is positive, spikes[1] is negative)
    r = spikes[0] - spikes[1]

    # Transform the response into the frequency domain so convolution becomes multiplication
    R = np.fft.fftshift(np.fft.fft(r))

    # Width of the Gaussian smoothing window in seconds. In the frequency domain,
    # exp(-ω² σ_t²) has a standard deviation of 1/(√2 σ_t) ≈ 28 rad/s
    sigma_t = 25e-3

    # Unnormalized Gaussian window over frequency: exp(-ω² σ_t²)
    W2 = np.exp(-omega**2*sigma_t**2)

    # Normalize the Gaussian window so its coefficients sum to 1
    W2 = W2 / sum(W2)

    # Numerator of the unsmoothed filter: X R* (cross-spectrum of the signal and the response)
    CP = X*R.conjugate()

    # Smooth the numerator over nearby frequencies to reduce noise: (X R*) ∗ W
    WCP = np.convolve(CP, W2, 'same')

    # Denominator of the unsmoothed filter: |R|² (power spectrum of the response)
    RP = R*R.conjugate()

    # Smooth the denominator with the same window: |R|² ∗ W
    WRP = np.convolve(RP, W2, 'same')

    # Power spectrum of the input signal, |X|² (real² + imag² at each frequency)
    XP = X*X.conjugate()

    # Smooth the input power spectrum with the same Gaussian window
    WXP = np.convolve(XP, W2, 'same')

    # Optimal smoothed filter: H = ((X R*) ∗ W) / (|R|² ∗ W)
    H = WCP / WRP

    # Convert the optimal filter back into the time domain and discard the imaginary components [H --> h]
    h = np.fft.fftshift(np.fft.ifft(np.fft.ifftshift(H))).real

    # Decode: multiplying R by H in the frequency domain is equivalent to convolving r(t) with h(t) in time
    XHAT = H*R

    # Convert the decoded value back into the time domain and discard the imaginary components [XHAT --> x_hat]
    xhat = np.fft.ifft(np.fft.ifftshift(XHAT)).real

    return ts, fs, R, H, h, XHAT, xhat, XP, WXP

ts, fs, R, H, h, XHAT, xhat, XP, WXP = compute_optimal_filter(x=time_signal, X=freq_signal, spikes=np.array([spikes_sp, spikes_sn]))
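The decoding step in `compute_optimal_filter` relies on the convolution theorem: multiplying spectra and transforming back is the same as convolving in time. For discrete FFTs that convolution is circular, so `xhat` near the start of the signal is also influenced by spikes near the end (and vice versa). The sketch below checks the equivalence on random stand-in data; the array length and values are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(1)
N = 64
r = rng.normal(size=N)   # stand-in for a response r(t)
h = rng.normal(size=N)   # stand-in for a filter h(t)

# Product of spectra, transformed back (as with XHAT = H * R)
via_fft = np.fft.ifft(np.fft.fft(h) * np.fft.fft(r)).real

# Direct circular convolution: y[n] = sum_k h[k] * r[(n - k) mod N]
direct = np.array([sum(h[k] * r[(n - k) % N] for k in range(N)) for n in range(N)])

print(np.allclose(via_fft, direct))
```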

**b) Optimal filter.** Plot the time and frequency plots of the optimal filter for the signal you generated in question 3c). Make sure to use appropriate limits for the $x$-axis.

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(12, 5))

axs[0].plot(ts, h)
axs[0].set_xlabel('Time (s)')
axs[0].set_ylabel('Filter h')
axs[0].set_title('Optimal Filter in the Time Domain')
axs[0].set_xlim([-0.25, 0.25])

axs[1].plot(fs, H.real)
axs[1].set_xlabel('Frequency (Hz)')
axs[1].set_ylabel('Filter H')
axs[1].set_title('Optimal Filter in the Frequency Domain')
axs[1].set_xlim([-50, 50])

plt.tight_layout()
plt.show()

**c) Decoded signal.** Plot the $x(t)$ signal, the spikes, and the decoded $\hat x(t)$ value for the signal from 3c).

In [None]:
# Plot the original and decoded signals
plt.figure(figsize=(12,5))
plt.plot(ts, time_signal, label='Original Signal')
plt.plot(ts, xhat, label='Decoded Signal')

# Plot the spikes as dots
pos_spike_idx = np.where(spikes_sp == 1)[0]
neg_spike_idx = np.where(spikes_sn == 1)[0]
pos_spike_times = ts[pos_spike_idx]
neg_spike_times = ts[neg_spike_idx]
plt.scatter(pos_spike_times, np.ones_like(pos_spike_times), color='blue', marker='.', label='Positive Spikes')
plt.scatter(neg_spike_times, -np.ones_like(neg_spike_times), color='red', marker='.', label='Negative Spikes')

# Add legend, labels, and title
plt.legend()
plt.xlabel('Time (s)')
plt.ylabel('Signal')
plt.title('Original and Decoded Signals with Spikes')
plt.show()

**d) Power spectra.** Plot the signal $|X(\omega)|$, spike response $|R(\omega)|$, and filtered signal $|\hat X(\omega)|$ power spectra for the signal from 3c).

In [None]:
# Calculate the power spectra
X_power = np.abs(freq_signal)   # |X(w)|
R_power = np.abs(R)             # |R(w)|
XHAT_power = np.abs(XHAT)       # |XHAT(w)|

# Plot the power spectra
plt.plot(fs, X_power, label='$|X(\\omega)|$')
plt.plot(fs, XHAT_power, label='$|\\hat{X}(\\omega)|$')
plt.plot(fs, R_power, label='$|R(\\omega)|$')
plt.legend()
plt.title('Power Spectra')
plt.xlabel('Frequency (Hz)')
plt.ylabel('Amplitude')
plt.xlim([-50, 50])
plt.show()

**e) Discussion.** How do these spectra relate to the optimal filter?

The figure in 4d) shows that $X(\omega)$ has the most concentrated power spectrum, confined to a 5 Hz bandwidth since the random signal was generated with `limit` = 5 Hz. $\hat{X}(\omega)$ has a slightly more spread-out power spectrum with some power between 5 and 30 Hz, but nothing beyond that. $R(\omega)$, on the other hand, has comparatively little of its power near 0 Hz and a lot of power at high frequencies, up to 500 Hz.

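The role of the optimal filter is to create a smooth decoded $\hat{x}(t)$ from a sharp, discrete $r(t)$. Since $\hat{X}(\omega) = H(\omega)R(\omega)$, the filter has to keep the low frequencies where $X(\omega)$ has power and suppress the high frequencies where $R(\omega)$ mostly contains spike noise; in other words, it acts as a low-pass filter. The power spectra show that this goal was accomplished, since most of the high-frequency power in $R(\omega)$ was eliminated after filtering. Also, the fact that $\hat{X}(\omega)$ has a similar spectrum to $X(\omega)$ within 5 Hz is a good sign that the signal was decoded correctly.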

**f) Filter for different signal bandwidths.** Plot the optimal filter $h(t)$ in the time domain when filtering spike trains for white noise signals with different `limit` values of $2\,\mathrm{Hz}$, $10\,\mathrm{Hz}$, and $30\,\mathrm{Hz}$.

In [None]:
# Put everything in a function so the variables don't get overwritten
def filter_limits():
    T = 1
    dt = 0.001
    hs = []
    for limit in [2, 10, 30]:
        # Generate the random signal according to the bandwidth limit
        ts, fs_rad, time_signal, freq_signal = generate_signal(T=T, dt=dt, rms=0.5, limit=limit, seed=0)

        # Simulate the positive and negative neurons with the local time step
        ts, spikes_sp, num_spikes_sp, voltages = spike_train(x=time_signal, e=1, alpha=alpha, J_bias=J_bias, tau_ref=tau_ref, tau_rc=tau_rc, T=T, delta_t=dt)
        ts, spikes_sn, num_spikes_sn, voltages = spike_train(x=time_signal, e=-1, alpha=alpha, J_bias=J_bias, tau_ref=tau_ref, tau_rc=tau_rc, T=T, delta_t=dt)

        # Calculate the optimal filter h(t)
        ts, fs, R, H, h, XHAT, xhat, XP, WXP = compute_optimal_filter(x=time_signal, X=freq_signal, spikes=np.array([spikes_sp, spikes_sn]), dt=dt)
        hs.append(h)

    plt.plot(ts, hs[0], label='Limit = 2 Hz')
    plt.plot(ts, hs[1], label='Limit = 10 Hz')
    plt.plot(ts, hs[2], label='Limit = 30 Hz')
    plt.legend()
    plt.xlabel('Time (s)')
    plt.ylabel('Filter h')
    plt.title('Optimal Filter in the Time Domain')
    plt.xlim([-0.2, 0.2])
    plt.show()

filter_limits()

**g) Discussion.** Describe the effects on the time plot of the optimal filter as `limit` increases. Why does this happen?

As `limit` increases in `generate_signal()`, the input $x(t)$ takes on more high-frequency components, which means it has more rapid changes in the time domain. Since the optimal filter aims to minimize the decoding error between $x(t)$ and $\hat{x}(t)$, it must also take on a higher temporal resolution to accurately capture the fast variations in the input signal. As a result, an input with a high bandwidth leads the optimal filter to also have a high bandwidth, which makes it sharper and narrower.
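An idealized reference makes this trend concrete. It is not the optimal filter computed above, only a brick-wall low-pass filter with cutoff $B$ Hz, whose impulse response is $h(t) = 2B\,\mathrm{sinc}(2Bt)$: its peak height is $2B$ and its first zeros are at $t = \pm 1/(2B)$, so going from 2 Hz to 30 Hz makes the main lobe 15× narrower and 15× taller. `ideal_lowpass_impulse` is a hypothetical helper for illustration.

```python
import numpy as np

def ideal_lowpass_impulse(t, B):
    # Impulse response of an ideal low-pass filter with cutoff B Hz.
    # np.sinc is the normalized sinc, sin(pi x) / (pi x)
    return 2 * B * np.sinc(2 * B * t)

for B in [2, 10, 30]:
    first_zero_ms = 1000 / (2 * B)
    print(f"B = {B:>2} Hz: peak height {ideal_lowpass_impulse(0.0, B):.0f}, "
          f"first zeros at +/-{first_zero_ms:.1f} ms")
```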

# 5. Using post-synaptic currents as a filter


**a) Plotting the filter for different $n$.** Plot the normalized $h(t)$ for $n=0$, $1$, and $2$, with $\tau=7\,\mathrm{ms}$.

In [None]:
def c_integrand(t, n, tau):
    return t**n * np.exp(-t/tau)

# Normalization constant c, integrated numerically with the trapezoidal rule.
# For millisecond time constants the integrand is negligible after 1 s, so the
# integral is truncated there and evaluated on a fine grid.
def trapezoidal_integration(n, tau, a=0, b=1, dt=1e-5):
    t = np.arange(a, b, dt)
    y = c_integrand(t, n, tau)
    integral = (y[0]/2 + np.sum(y[1:-1]) + y[-1]/2) * dt
    return integral

# Normalized post-synaptic current filter h(t), which is zero for t < 0
def post_synaptic_current(n, tau, a=-0.01, b=0.075, dt=0.001):
    t = np.arange(a, b, dt)
    c = trapezoidal_integration(n, tau)
    h = t**n * np.exp(-t/tau) / c
    h[t < 0] = 0
    return t, h

# Plot the post-synaptic current for n = 0, 1, and 2 with tau = 7ms
t, h0 = post_synaptic_current(n=0, tau=0.007)
t, h1 = post_synaptic_current(n=1, tau=0.007)
t, h2 = post_synaptic_current(n=2, tau=0.007)

plt.plot(t, h0, label=r'$\tau$ = 7ms, n = 0')
plt.plot(t, h1, label=r'$\tau$ = 7ms, n = 1')
plt.plot(t, h2, label=r'$\tau$ = 7ms, n = 2')
plt.legend()
plt.xlabel('Time (s)')
plt.ylabel('Filter Magnitude')
plt.title('Synaptic Filter')
plt.show()
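The normalization constant has a closed form, which is a useful check on the numerical integration: $c = \int_0^\infty t^n e^{-t/\tau}\,dt = n!\,\tau^{n+1}$ (a Gamma-function integral). The sketch compares it with a fine trapezoidal sum; the 1 µs step and the 1 s upper limit are assumptions chosen so that the discretization and truncation errors are negligible for millisecond time constants.

```python
import math
import numpy as np

def c_closed_form(n, tau):
    # Integral of t^n exp(-t / tau) over [0, inf)
    return math.factorial(n) * tau ** (n + 1)

def c_trapezoid(n, tau, t_max=1.0, dt=1e-6):
    # Trapezoidal rule on [0, t_max), truncating the negligible tail
    t = np.arange(0, t_max, dt)
    y = t**n * np.exp(-t / tau)
    return (y[0] / 2 + np.sum(y[1:-1]) + y[-1] / 2) * dt

for n in [0, 1, 2]:
    print(n, c_closed_form(n, 0.007), c_trapezoid(n, 0.007))
```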

**b) Discussion.** What two things do you expect increasing $n$ will do to $\hat{x}(t)$?

As $n$ increases, the post-synaptic current acts as more of a low-pass filter, which has a smoother but slower response curve (delayed peak magnitude). As a result, $\hat{x}(t)$ will become smoother but also more time-shifted.
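Both effects can be read off the normalized filter $h(t) = \frac{t^n e^{-t/\tau}}{n!\,\tau^{n+1}}$ for $t \ge 0$. Setting $h'(t) = 0$ gives the peak time, and its Fourier transform shows how quickly high frequencies are attenuated:

$$
t_{peak} = n\tau, \qquad
H(\omega) = \int_0^\infty h(t)\,e^{-j\omega t}\,dt = \frac{1}{(1 + j\omega\tau)^{n+1}}
$$

With $\tau = 7$ ms the peak moves from 0 to 7 ms to 14 ms for $n = 0, 1, 2$ (the delay), and for $\omega \gg 1/\tau$ the gain falls off as $\omega^{-(n+1)}$ (the extra smoothing).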

**c) Plotting the filter for different $\tau$.** Plot the normalized $h(t)$ for $\tau=2\,\mathrm{ms}$, $\tau=5\,\mathrm{ms}$, $\tau=10\,\mathrm{ms}$, $\tau=20\,\mathrm{ms}$ with $n = 0$.

In [None]:
# Plot the post-synaptic current for tau = 2ms, 5ms, 10ms, and 20ms with n = 0
t, h0 = post_synaptic_current(n=0, tau=0.002)
t, h1 = post_synaptic_current(n=0, tau=0.005)
t, h2 = post_synaptic_current(n=0, tau=0.01)
t, h3 = post_synaptic_current(n=0, tau=0.02)

plt.plot(t, h0, label=r'$\tau$ = 2ms, n = 0')
plt.plot(t, h1, label=r'$\tau$ = 5ms, n = 0')
plt.plot(t, h2, label=r'$\tau$ = 10ms, n = 0')
plt.plot(t, h3, label=r'$\tau$ = 20ms, n = 0')
plt.legend()
plt.xlabel('Time (s)')
plt.ylabel('Filter Magnitude')
plt.title('Synaptic Filter')
plt.show()

**d) Discussion.** What two things do you expect increasing $\tau$ will do to $\hat{x}(t)$?

Unlike changing $n$, increasing $\tau$ doesn't delay the peak, which stays at $t=0$ for $n=0$; instead, it lowers the peak magnitude and stretches the filter out in time. It is also a low-pass filter, so $\hat{x}(t)$ will also become smoother and a bit time-shifted, but not as much as increasing $n$ makes it.
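For $n = 0$ the transfer function is $H(\omega) = 1/(1 + j\omega\tau)$, a first-order low-pass filter whose gain drops to $1/\sqrt{2}$ at $f_c = 1/(2\pi\tau)$. The sketch below tabulates that cutoff for the four time constants listed in part c), which is one way to quantify "smoother": a larger $\tau$ passes a narrower band of frequencies. `lowpass_gain` is a hypothetical helper for illustration.

```python
import numpy as np

def lowpass_gain(f_hz, tau):
    # |H| for the n = 0 synaptic filter h(t) = exp(-t/tau)/tau, where H(w) = 1/(1 + j w tau)
    omega = 2 * np.pi * f_hz
    return np.abs(1 / (1 + 1j * omega * tau))

for tau in [0.002, 0.005, 0.010, 0.020]:
    f_c = 1 / (2 * np.pi * tau)   # -3 dB cutoff in Hz
    print(f"tau = {tau * 1000:4.0f} ms: cutoff {f_c:5.1f} Hz, gain there {lowpass_gain(f_c, tau):.3f}")
```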

**e) Decoding a spike-train using the post-synaptic current filter.** Decode $\hat{x}(t)$ from the spikes generated in question 3c) using an $h(t)$ with $n=0$ and $\tau=7\,\mathrm{ms}$. Do this by generating the spikes, filtering them with $h(t)$, and using that as your activity matrix $A$ to compute your decoders. Plot the time and frequency plots for this $h(t)$. Plot the $x(t)$ signal, the spikes, and the decoded $\hat{x}(t)$ value.

In [None]:
# ✍ <YOUR SOLUTION HERE>
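The recipe in the prompt (filter each spike train with $h(t)$, treat the filtered trains as the activity matrix $A$, and solve for decoders) can be sketched as below. This is an illustrative sketch under assumptions, not a solution: `psc_kernel` and `decode_with_psc` are hypothetical names, the decoders come from an unregularized least-squares solve with `np.linalg.lstsq`, and the spikes are 0/1 per time step as produced by `spike_train`, so each spike adds one copy of $h(t)$. The overall scale of $A$ does not matter because the decoders absorb it. To decode the signal from 3c), one would pass the two spike trains stacked as a `(2, Nt)` array together with $x(t)$.

```python
import math
import numpy as np

def psc_kernel(dt=0.001, n=0, tau=0.007, length=0.1):
    """Causal PSC filter h(t) = t^n exp(-t/tau) / (n! tau^(n+1)), sampled on [0, length)."""
    t = np.arange(0, length, dt)
    return t**n * np.exp(-t / tau) / (math.factorial(n) * tau ** (n + 1))

def decode_with_psc(spikes, x, dt=0.001, n=0, tau=0.007):
    """spikes: (num_neurons, Nt) array of 0/1 spikes; x: (Nt,) target signal.
    Returns decoders d, the decoded signal xhat = d @ A, and the activity matrix A."""
    h = psc_kernel(dt, n, tau)
    # Causal filtering: a spike at step k adds h to the trace from step k onward
    A = np.array([np.convolve(s, h)[: s.size] for s in spikes])
    # Least-squares decoders minimizing ||x - d @ A||^2
    d = np.linalg.lstsq(A.T, x, rcond=None)[0]
    return d, d @ A, A
```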

**f) Decoding a spike-train representing a low-frequency signal.** Use the same decoder and $h(t)$ as in part e), but generate a new $x(t)$ with $\mathtt{limit}=2\,\mathrm{Hz}$. Plot the $x(t)$ signal, the spikes, and the decoded $\hat{x}(t)$ value.

In [None]:
# ✍ <YOUR SOLUTION HERE>

**g) Discussion.** How do the decodings from e) and f) compare? Explain.

✍ \<YOUR SOLUTION HERE\>