# In [1]:
import numpy as np
import matplotlib.pyplot as plt

# In [10]:
class Sampler(object):
    """Generate a sampled signal with injected spikes over a fixed time window.

    Attributes:
        fs: Sampling frequency, in samples per second.
        T: Length of the period of interest, in seconds.
        t: Sample times in seconds covering the half-open interval [0, T).
        n_spikes: Number of spikes injected by ``sample``.
        x: Most recently generated signal (set by ``sample``).
        spike_indices: Value returned by ``sample_spikes`` for the most
            recent signal (set by ``sample``).
    """

    def __init__(self, fs=20000, T=2, n_spikes=8):
        """Store the sampling parameters and build the time axis.

        Args:
            fs: Sampling frequency, in samples per second. Must be positive.
            T: Period of interest, in seconds.
            n_spikes: Number of spikes to inject into each sampled signal.

        Raises:
            ValueError: If ``fs`` is not positive.
        """
        if fs <= 0:
            raise ValueError("fs must be positive")
        self.fs = fs
        self.T = T
        # Derive the axis from an integer sample count instead of a float
        # step: np.arange with a non-integer step can gain or lose a sample
        # to rounding. round() strips float noise from T * fs, and ceil keeps
        # the half-open [0, T) semantics when T * fs is fractional.
        n_samples = max(int(np.ceil(round(T * fs, 6))), 0)
        self.t = np.arange(n_samples) / fs
        self.n_spikes = n_spikes

    def sample(self, epsilon=0.00005):
        """Generate a new signal, inject spikes, plot it and return it.

        Args:
            epsilon: Noise coefficient forwarded to ``sample_signal`` and
                ``sample_spikes``.

        Returns:
            The generated signal, also stored as ``self.x``.
        """
        self.x = sample_signal(self.t, epsilon)
        # Keep the spike locations instead of discarding them, so callers
        # can inspect where the spikes were placed.
        self.spike_indices = sample_spikes(self.x, self.n_spikes, epsilon)
        plot(self.t, self.x)
        return self.x
    
        
def time_to_index(t, fs=20000):
    """Convert a time in seconds to the index of the corresponding sample.

    Args:
        t: Time in seconds.
        fs: Sampling frequency in samples per second. Defaults to 20000,
            the default sampling frequency of ``Sampler``.

    Returns:
        The integer sample index, truncated toward zero.
    """
    # Round away float noise before truncating so that on-grid times map to
    # the right index (0.29 * 100 evaluates to 28.999999999999996).
    return int(round(t * fs, 6))

def small_random_noise(t, epsilon = 0.00005):
    """Return zero-mean Gaussian noise with one value per entry of ``t``.

    The noise is drawn from a standard normal distribution and scaled by
    ``epsilon``, so ``epsilon`` is its standard deviation; the values are
    not bounded to [-epsilon, epsilon].

    Args:
        t: Sequence of sample times; only its length is used.
        epsilon: Scale (standard deviation) of the noise.

    Returns:
        A 1-D array of ``len(t)`` noise samples.
    """
    n_samples = len(t)
    unit_noise = np.random.standard_normal(n_samples)
    return unit_noise * epsilon

def sample_signal(t, epsilon = 0.00005):
    """Generate the noisy signal sampled at times ``t``.

    The signal is the sum of random noise from ``small_random_noise`` and
    two sinusoids: amplitude 0.1 at 3000 Hz and amplitude 0.2 at 4000 Hz.

    Args:
        t: Array of sample times in seconds.
        epsilon: Noise coefficient passed to ``small_random_noise``.

    Returns:
        The array of signal values, one per entry of ``t``.
    """
    # (amplitude, frequency in Hz) of each sinusoidal component.
    tones = ((0.1, 3000), (0.2, 4000))

    signal = small_random_noise(t, epsilon=epsilon)
    for amplitude, frequency in tones:
        signal = signal + amplitude * np.sin(2 * np.pi * frequency * t)
    return signal

def sample_spikes(x, n_spikes, epsilon = 0.00005):
    """Inject ``n_spikes`` three-sample spikes into ``x`` in place.

    Spike positions are drawn without replacement from the indices
    ``offset .. len(x) - 1``. For a spike at index ``i`` the samples at
    ``i - 2``, ``i - 1`` and ``i`` are overwritten with 3, -4 and 8 times
    the maximum of the original signal. Spikes are written in ascending
    index order, so where two spikes overlap the later one wins.

    Args:
        x: Signal array; modified in place.
        n_spikes: Number of spikes to inject.
        epsilon: Unused; kept so existing callers keep working.

    Returns:
        The sorted array of spike indices.

    Raises:
        ValueError: Raised by ``np.random.choice`` when ``n_spikes`` exceeds
            the number of candidate indices.
    """
    ampl = np.max(x)

    # Keep spikes away from the start of the signal so that index - 2 never
    # wraps around to the end of the array.
    offset = 4

    # Candidates come from the signal itself; the previous code used the
    # length of an undefined global ``t`` here.
    candidates = np.arange(offset, len(x))
    sampled_spike_indices = np.sort(
        np.random.choice(candidates, size=n_spikes, replace=False)
    )

    for index in sampled_spike_indices:
        x[index - 2] = 3 * ampl
        x[index - 1] = -4 * ampl
        x[index] = 8 * ampl

    return sampled_spike_indices

def plot(t, data):
    """Display the received (sampled) signal as a line plot.

    Args:
        t: Sample times in seconds, used for the x-axis.
        data: Signal values, used for the y-axis.
    """
    _, axes = plt.subplots(1, 1, figsize=(6, 3), dpi=600)

    axes.set_title("Received Neural Signal")
    axes.set_ylabel("Voltage (mV)")
    axes.set_xlabel("Time (s)")

    axes.plot(t, data)
    plt.show()