In [None]:
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.interpolate import make_interp_spline
from scipy.signal import butter, lfilter, sosfilt
from scipy.stats import poisson


In [None]:
# Parameters
sample_rate = 1017.25 # (Hz) - based on normal data recording rate
t = 1800 # (s) - based on std experiment
cutoff = 0.1 # (Hz) low-pass cutoff of the movement component - based on OG simulation paper
n_dtpts = int(sample_rate*t)  # Number of data points
movement_attenuation = 50  # Example attenuation percentage as per OG sim paper
noise_factor = 2 # as per OG sim paper
# Timestamp (s) of every sample: sample i sits at i / sample_rate, the same
# index-to-seconds conversion used for the event onsets. np.linspace(0, t, n_dtpts)
# would include the endpoint and stretch the spacing to t / (n_dtpts - 1).
time_pts = np.arange(n_dtpts) / sample_rate

In [None]:
def calculate_movement_component(cutoff = 0.1, sample_rate = sample_rate, movement_attenuation = 50, n_samples = None):
    '''
    Calculate the movement component of the signal,
    based on low-pass filtered uniform random data and the movement attenuation parameter.

    cutoff: low-pass cutoff, in the same units as sample_rate
    sample_rate: sampling rate of the simulated recording
    movement_attenuation: percentage by which the filtered data attenuates the signal
    n_samples: number of samples to generate; when None the notebook-level
        n_dtpts is used

    Returns an array of n_samples multiplicative factors,
    1 - lowpass * (movement_attenuation / 100).
    '''
    if n_samples is None:
        # Fall back to the notebook-wide sample count
        n_samples = n_dtpts

    # Without fs=, butter expects Wn normalised to the Nyquist frequency
    # (sample_rate / 2), so cutoff / (sample_rate / 2) is the right value.
    # The filter is designed as second-order sections: at such a small
    # normalised cutoff the (b, a) polynomial form is numerically sensitive.
    sos = butter(N=4, Wn=cutoff / (sample_rate / 2), btype='low', output='sos')

    # Apply the filter to uniform random data in [0, 1)
    lowpass_values = sosfilt(sos, np.random.rand(n_samples))

    movement_component = 1 - (lowpass_values * (movement_attenuation / 100))
    return movement_component

def calculate_decay_component(time_pts, decay_rate1 = 0.02, decay_rate2 = 0.002, decay_base = 40):
    '''
    Make a double exponential decay curve, sampled at every entry of time_pts.

    The curve is the mean of (1 - decay_rate1) ** t and (1 - decay_rate2) ** t,
    scaled by decay_base / 100 and offset by 1 - decay_base / 100, so it equals
    1 at t = 0 and approaches 1 - decay_base / 100 for large t when both rates
    lie between 0 and 1.

    time_pts: array-like of time points (any sequence numpy can convert)
    decay_rate1, decay_rate2: per-unit-time decay rates of the two exponentials
    decay_base: percentage of the signal that decays away

    Returns an array with the same shape as time_pts.
    '''
    # Accept plain lists/tuples as well as arrays
    time_pts = np.asarray(time_pts, dtype=float)

    fast_decay = (1 - decay_rate1) ** time_pts
    slow_decay = (1 - decay_rate2) ** time_pts
    mean_decay = (fast_decay + slow_decay) / 2

    decay = mean_decay*(decay_base/100)+(1-decay_base/100)

    return decay
    

def calculate_ERT(lambda_val = 2, peak = 1, scale = 5, vis = False):
    '''
    Make an event-related transient shaped like a Poisson distribution.

    The Poisson pmf with mean lambda_val is evaluated at the integers
    0 .. lambda_val*5 - 1, normalised so its largest value equals peak,
    stretched along x by scale and interpolated with a quadratic B-spline at
    every integer sample. Only interpolated samples >= 0.01 are returned, so the
    result is re-indexed from 0 and is shorter than lambda_val * 5 * scale.

    lambda_val: mean of the Poisson distribution
    peak: value the normalised pmf is scaled to at its maximum
    scale: number of output samples per pmf step
    vis: when True, plot the returned transient

    Returns a 1-D array holding the thresholded, interpolated transient.
    '''
    # evaluate lambda over a duration 5 times longer to capture the whole distribution
    t = lambda_val*5

    # Generate discrete values of the theoretical Poisson probability mass
    # function (pmf) from 0 to t
    x = np.arange(0, t)
    pmf = poisson.pmf(x, lambda_val)
    # Rescale x axis. The lowest reasonable value of lambda is 2,
    # corresponding to t = 10, our response timescale is >50ms
    x = x * scale
    # Rescale y axis
    pmf = pmf / pmf.max() * peak

    # Interpolate pmf with a quadratic B-spline. The last knot is at
    # (t - 1) * scale, so the final samples of the dense grid below are
    # extrapolated by the spline rather than interpolated.
    b = make_interp_spline(x, pmf, k=2)
    x = np.arange(0, t * scale)
    pmf = b(x)

    # Keep only the samples where the pmf is >= 0.01; this re-indexes the transient
    indices = np.where(pmf >= 0.01)[0]
    pmf = pmf[indices]

    if vis:
        # Plot against the re-indexed sample axis: the dense grid x is longer
        # than pmf once samples have been dropped, and plotting the two
        # together fails on the length mismatch.
        plt.plot(np.arange(len(pmf)), pmf)
        plt.xlabel('time (ms; 1017.25Hz)')
        plt.show()

    return pmf


def calculate_noise_component(n_dtpts, sample_rate, noise_factor=8):
    '''
    Draw n_dtpts standard-normal samples and scale them by noise_factor.

    sample_rate is accepted but not used by the computation.
    Returns a 1-D array of length n_dtpts.
    '''
    white_noise = np.random.randn(n_dtpts)
    return white_noise * noise_factor




   

    

In [None]:
# %%
def make_event(n_dtpts = 1000, n_events = False, lambda_val = False, peak_m = False, vis = False, delay_og = 0):
    '''
    Place randomly timed event-related transients on a zero baseline.

    n_dtpts: length of the generated vectors
    n_events: number of events; a falsy value picks 2 or 3 at random
    lambda_val: Poisson mean passed to calculate_ERT; a falsy value picks an
        integer from 2 to 5 at random
    peak_m: mean transient peak; a falsy value picks a value in [5, 15] at
        random. Each event's peak is peak_m plus a uniform jitter in [-2, 2]
    vis: when True, plot the resulting signal
    delay_og: earliest allowed onset index, before a random 0-5 sample jitter

    Returns (events, true_signal): events is 1 at each onset index and 0
    elsewhere, true_signal is the sum of all transients.

    Raises ValueError when a transient is too long to fit between the earliest
    allowed onset and the end of the signal.
    '''
    true_signal = np.zeros(n_dtpts)
    if not n_events: n_events = random.randint(2,3)
    events = np.zeros(n_dtpts)
    if not lambda_val: lambda_val = random.randint(2, 5)
    if not peak_m: peak_m = random.uniform(5,15)

    for _ in range(n_events):
        delay = delay_og + random.randint(0,5)
        peak = peak_m + random.uniform(-2, 2)
        ert = calculate_ERT(lambda_val, peak, scale=10)
        event_duration = len(ert)

        # Latest onset that still leaves room for the whole transient
        latest_onset = n_dtpts - event_duration
        if latest_onset < delay:
            raise ValueError(
                f"event of {event_duration} samples does not fit into {n_dtpts} "
                f"samples with an onset delay of {delay}"
            )

        initial_response = random.randint(delay, latest_onset)
        events[initial_response] = 1

        # Overlapping transients add up
        true_signal[initial_response:initial_response+event_duration] += ert

    if vis: plt.plot(true_signal)

    return events, true_signal

In [None]:
# Sanity check: largest value of a transient requested with peak = 7
ert = calculate_ERT(lambda_val=20, peak=7, scale=10)
ert.max()

In [None]:
# Put it all together to make the simulated signal out of noise, the underlying
# true signal, photobleaching decay and movement

# Three event types that differ in count, minimum onset delay, peak and Poisson mean
events1, true_signal1 = make_event(n_dtpts=n_dtpts, n_events=20, delay_og=1, peak_m = 8, vis = False, lambda_val=2)
events2, true_signal2 = make_event(n_dtpts=n_dtpts, n_events=15, delay_og=2, peak_m = 10, vis = False, lambda_val=20)
events3, true_signal3 = make_event(n_dtpts=n_dtpts, n_events=21, delay_og=0, peak_m = 12, vis = False, lambda_val=50)

true_signal = true_signal1 + true_signal2 + true_signal3

movement_component = calculate_movement_component(cutoff, sample_rate, movement_attenuation)

# Pass the configured noise_factor explicitly; leaving it out silently falls
# back to the function default (8) instead of the value set in the parameters.
# The two channels get independent noise draws.
noise_component = calculate_noise_component(n_dtpts, sample_rate, noise_factor)
noise_component_iso = calculate_noise_component(n_dtpts, sample_rate, noise_factor)

decay_component = calculate_decay_component(time_pts)

# Signal channel: events on a baseline of 200, modulated by movement and decay
data = (true_signal + 200) * movement_component * decay_component + noise_component

# Second channel (saved as "iso" — presumably the isosbestic control): baseline
# of 100 with the same movement and decay but no events
isob = 100 * movement_component * decay_component + noise_component_iso

# One panel per component; the axes returned by plt.subplots are used directly
p, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
ax1.plot(true_signal)
ax1.set_title('true signal')
ax2.plot(decay_component)
ax2.set_title('decay component')
ax3.plot(movement_component)
ax3.set_title('movement component')
ax4.plot(noise_component)
ax4.set_title('noise component')
p.tight_layout()
plt.show()



In [None]:
# Quick look at the first event type on its own (x axis is the sample index)
plt.plot(true_signal1)

In [None]:
# Rescale onset of each event from sample index to seconds
aa = np.flatnonzero(events1 == 1) / sample_rate
bb = np.flatnonzero(events2 == 1) / sample_rate
cc = np.flatnonzero(events3 == 1) / sample_rate

# Combine all event times and labels; every event gets a nominal duration of 0.1
onsets = np.concatenate([aa, bb, cc])
duration = [0.1] * len(onsets)
event_ids = ['event1'] * len(aa) + ['event2'] * len(bb) + ['event3'] * len(cc)

# Create the DataFrame, sort by time and index it by onset
df = (
    pd.DataFrame({'onset': onsets, 'duration': duration, 'event_id': event_ids})
    .sort_values(by='onset')
    .set_index('onset')
)

# Forward slashes match the np.save paths used for the signal channels and
# resolve to the same location on Windows, while a leading-backslash path is
# not a valid absolute path on other platforms.
df.to_csv('/Users/levip/Desktop/NSB/BrainHack/behapy/SIM/rawdata/sub-test1/ses-TEST1/sub-test1_ses-TEST1_task-TEST_run-1_events.csv')




In [None]:
# Write both simulated channels to disk as .npy files
np.save(
    file='/Users/levip/Desktop/NSB/BrainHack/behapy/SIM/rawdata/sub-test1/ses-TEST1/fp/sub-test1_ses-TEST1_task-TEST_run-1_label-LNAc_channel-ACh.npy',
    arr=data,
)
np.save(
    file='/Users/levip/Desktop/NSB/BrainHack/behapy/SIM/rawdata/sub-test1/ses-TEST1/fp/sub-test1_ses-TEST1_task-TEST_run-1_label-LNAc_channel-iso.npy',
    arr=isob,
)


In [None]:
# Plot the full simulated signal channel (x axis is the sample index)
plt.plot(data)

In [None]:
#%%
import numpy as np
import holoviews as hv
import datashader as ds
from holoviews.operation.datashader import datashade
from bokeh.plotting import output_notebook

# Use the Bokeh backend for Holoviews rendering in the notebook
hv.extension('bokeh')

# Wrap the third event type in a Holoviews Curve indexed by sample number,
# then rasterise it with datashader so the long trace renders quickly
sample_index = np.arange(len(true_signal3))
curve = hv.Curve((sample_index, true_signal3))
shaded_curve = datashade(curve).opts(width=800)

shaded_curve