In [2]:
%autosave 180
%load_ext autoreload
%autoreload 2
%matplotlib tk

#|
import numpy as np
import matplotlib.pyplot as plt



Autosaving every 180 seconds


In [8]:
##########################################################
########## COMPUTE VELOCITY FROM ROTARY ENCODER ##########
##########################################################
from tqdm import trange

#
def binarize(data):
    
    idx = np.where(data<2.5)[0]
    idx2 = np.where(data>=2.5)[0]
    
    data[idx] = 0
    data[idx2] = 1
    
    #
    return data

#
def get_velocity(data):
    
    # this contains the absolute time stamps of the rotary encoder readings
    try:
        abs_times_ttl_read = data['abs_times_ttl_read']
    except:
        print ("missing abs_times_ttl_read; using abs_times")
        abs_times_ttl_read = data['abs_times']
    print ("lick abs_times_ttl_read: ", abs_times_ttl_read.shape, abs_times_ttl_read)

    
    # lad the rotary encoder data
    rot1 = data['rotary_encoder1_abstime']
    rot2 = data['rotary_encoder2_abstime']
    print ("rot2: ", rot2.shape)
    #abs_times_ttl_read = data['abs_times_ttl_read']
    #print ("abs_times_ttl_read: ", abs_times_ttl_read.shape)

    # distance
    n_clicks_per_rotation = 500
    ball_diameter = 0.2  # distance in meters
    ball_circumference = ball_diameter*3.141592
    dist_per_click = ball_circumference/n_clicks_per_rotation
    
    # time
    sample_rate = 1000
    seconds_per_time_stamp = 1/sample_rate
        
    # detect the clicks
    bin1 = binarize(rot1)
    bin2 = binarize(rot2)
    
    #
    clicks = np.array((bin1, bin2)).T.squeeze()
    print ("Clicks: ", clicks.shape, clicks[:10])
    print ("abs_times_ttl_read: ", abs_times_ttl_read.shape, abs_times_ttl_read[:10], abs_times_ttl_read[-10:])
        
    #
    vel = []
    times = []
    time_since_last_click = 0
    rot1_last_state = clicks[0,0]
    for k in trange(0,clicks.shape[0],1):
        
        # check if the state of the rotary encoder changed
        if clicks[k,0]!=rot1_last_state:
            distance = dist_per_click  #only walked 1 click

            # check to see how long since last time point
            time = time_since_last_click*seconds_per_time_stamp
            
            # compute velocity
            v = distance/time 
            
            #
            vel.append(v)
            times.append(abs_times_ttl_read[k])
            #times.append(k/sample_rate)       
            
            #
            time_since_last_click=0
            last_click_location=k #.copy()
            rot1_last_state=clicks[k,0]

        # add a zero velocity if no click
        #else:
        #    vel.append(0)
        #    times.append(k/sample_rate)


        #    
        time_since_last_click+=1

    #
    times = np.array(times)
    
    # ok, now find the nearest value in times to the abs_times_ca_read
    try:
        abs_times_ca_read = data['abs_times_ca_read']
    except:
        print ("missing abs_times_ca_read; using abs_times")
        abs_times_ca_read = data['ttl_times']
    print ("lick abs_times_ca_read: ", abs_times_ca_read.shape, abs_times_ca_read)

    times_ttl_final = []
    velocity_final = []
    for k in trange(abs_times_ca_read.shape[0]):
        idx = np.argmin(np.abs(times-abs_times_ca_read[k]))
        #print (idx, times[idx], abs_times_ca_read[k])
        times_ttl_final.append(times[idx])
        velocity_final.append(vel[idx])
    
    times_ttl_final = np.array(times_ttl_final)
    velocity_final = np.array(velocity_final)

    #
    reward_times = data['reward_times'].squeeze().T[:,1]

    # remove all reward times ==-1
    idx = np.where(reward_times!=-1)[0]
    reward_times = reward_times[idx]

    #
    return velocity_final, reward_times
    

##################################################################
##################################################################
##################################################################
fname = '/media/cat/8TB/donato/bmi/DON-014266/20230201/results.npz'
data = np.load(fname,
               allow_pickle=True)

#
velocities, reward_times = get_velocity(data)

# 
print ("velocity array has # of entries", velocities.shape)

#
plt.figure()
plt.plot(velocities)
plt.xlabel("TTL pulse #")
plt.ylabel("Velocity (m/s)")

# add reward times 
for k in range(reward_times.shape[0]):
    # use vertical lines from 0 to 1
    plt.plot([reward_times[k], reward_times[k]], [0, 1], 'r', 
             # line style dashed
             linestyle='--',)

#
plt.show()

missing abs_times_ttl_read; using abs_times
lick abs_times_ttl_read:  (3006354,) [ 126.939616   126.9397987  126.9399566 ... 3131.2427038 3131.2444819
 3131.244684 ]
rot2:  (3006354, 1)
Clicks:  (3006354, 2) [[0. 0.]
 [0. 0.]
 [0. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]
 [1. 0.]]
abs_times_ttl_read:  (3006354,) [126.939616  126.9397987 126.9399566 126.9401115 126.9402626 126.9404159
 126.9405659 126.9407162 126.9408652 126.9410151] [3131.2364963 3131.2367064 3131.2384832 3131.2386802 3131.2404755
 3131.2407238 3131.2424825 3131.2427038 3131.2444819 3131.244684 ]


100%|██████████| 3006354/3006354 [00:00<00:00, 3645525.43it/s]


missing abs_times_ca_read; using abs_times
lick abs_times_ca_read:  (89952,) [ 131.2789029  131.3110578  131.3446579 ... 3131.1786582 3131.2109746
 3131.244684 ]


100%|██████████| 89952/89952 [00:01<00:00, 46006.96it/s]


velocity array has # of entries (89952,)
