# Kalman filter of position data

We want to remove as much noise as possible from the position tracking data. The filter is applied to the reported x and y positions of the mouse.

Before the Kalman filter is applied, a few preprocessing steps remove values that are clearly wrong.


In [1]:
# Reload edited modules automatically, then run the project setup script.
# setup_project.py is not shown here; its output reports that it creates the myProject and
# sSesList objects used below (and presumably imports np, pd, plt, KalmanFilter — TODO confirm).
%load_ext autoreload
%autoreload 2
%run setup_project.py
prepareSessionsForSpatialAnalysisProject(sSesList,myProject.sessionList)


Project name: autopi_ca1
dataPath: /adata/projects/autopi_ca1
dlcModelPath: /adata/models
Reading /adata/projects/autopi_ca1/sessionList
We have 40 testing sessions in the list
See myProject and sSesList objects
Loading Animal_pose and Spike_train, sSes.ap and sSes.cg


100%|██████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:33<00:00,  1.20it/s]

Loading ses.trial_table_simple as ses.trials
Create condition intervals in ses.intervalDict





## Preprocessing steps


These functions apply corrections to the raw data: they remove obvious jumps, correct 180° flips of the head direction, and, when available, use offline tracking to fill gaps.

In [2]:
from pathlib import Path
from spikeA.Session import Session

def remove_speed_jumps(df, time, px_per_cm , max_speed=200):
    """
    Set x, y, hd (and speed) to np.nan for frames whose speed is above a threshold.

    The speed of frame i is computed from the displacement between frame i and frame i+1,
    so the last frame always gets a speed of np.nan. The DataFrame is modified in place
    (a "speed" column is added) and also returned.

    Arguments
    df: DataFrame with x, y, hd columns
    time: time of each frame in df, in seconds
    px_per_cm: pixels per cm
    max_speed: speed threshold, in cm/s

    Returns
    df, with the frames above max_speed set to np.nan
    """
    dt = np.diff(time, append=np.nan)
    dx = np.diff(df.x, append=np.nan)
    dy = np.diff(df.y, append=np.nan)

    # duplicated timestamps (dt == 0) give inf/nan speeds: silence the warnings, inf is still flagged below
    with np.errstate(divide="ignore", invalid="ignore"):
        speed = (np.sqrt(dx**2 + dy**2) / dt) / px_per_cm
        speed_issues = speed > max_speed
    df["speed"] = speed

    print("Number of speed issues: {}".format(np.sum(speed_issues)))

    # One .loc assignment: chained assignment (df.x[mask] = ...) does not modify df under
    # pandas Copy-on-Write, so the jumps would silently stay in the data.
    df.loc[speed_issues, ["x", "y", "hd", "speed"]] = np.nan

    return df

def correct_head_direction_180_flip(df, time, px_per_cm , min_speed=10, max_angle=np.pi/1.2):
    """
    Rotate the head direction by pi when it points against the movement direction.

    Mice do not move forward with their head pointing backward. If we get this in our data,
    it is likely due to a swap of the 2 LEDs. A frame is flipped when the animal moves faster
    than min_speed and the angle between heading (movement direction) and head direction is
    larger than max_angle. The DataFrame is modified in place ("speed" and "heading" columns
    are added) and also returned.

    Arguments
    df: DataFrame with x, y, hd columns (hd in radians)
    time: time of each frame in df (in seconds, so that speed is in cm/s)
    px_per_cm: pixels per cm
    min_speed: minimal speed at which heading and head direction are compared
    max_angle: largest angle (radians) between heading and head direction that is considered valid

    Returns
    df, with the flipped head directions rotated by pi (wrapped to [-pi, pi])
    """
    dt = np.diff(time, append=np.nan)
    dx = np.diff(df.x, append=np.nan)
    dy = np.diff(df.y, append=np.nan)

    with np.errstate(divide="ignore", invalid="ignore"):
        speed = (np.sqrt(dx**2 + dy**2) / dt) / px_per_cm
        heading = np.arctan2(dy, dx)
        df["speed"] = speed
        df["heading"] = heading

        hd = df["hd"].to_numpy(dtype=float)
        # angle between the two unit vectors is the arccos of their dot product; clip because
        # rounding can push the dot product slightly outside [-1, 1], where arccos returns nan
        cos_delta = np.clip(np.cos(hd) * np.cos(heading) + np.sin(hd) * np.sin(heading), -1.0, 1.0)
        delta_angle = np.arccos(cos_delta)

        swap_indices = np.logical_and(speed > min_speed, delta_angle > max_angle)

    flipped = hd[swap_indices] + np.pi
    # .loc instead of chained assignment so the change reaches df under pandas Copy-on-Write
    df.loc[swap_indices, "hd"] = np.arctan2(np.sin(flipped), np.cos(flipped))

    return df

def merge_online_offline_tracking(df,df1,prob=0.995):
    """
    Fill missing x, y, hd of the online tracking with the offline (dlc model) output.

    This only works if we have done offline tracking. A frame is filled only when the online
    x is np.nan and dlc was confident about both LEDs (LedLeft_p and LedRight_p above prob).
    df is modified in place and returned; when df1 is None, df is returned unchanged.

    Arguments:
    df : online tracking data (positrack2 file)
    df1: offline tracking data (offline dlc model output), same number of rows as df, or None
    prob: minimal dlc probability of both LEDs for a frame to be used

    Returns
    df, with the filled values
    """
    online_missing = np.isnan(df.x).to_numpy()  # df is nan
    print("Proportion of nan in online tracking: {:.3f}".format(np.sum(np.isnan(df.x)/df.x.shape[0])))

    if df1 is None:
        print("No valid df1")
        # return the data unchanged rather than None, so `df = merge_...(df, df1)` is always safe
        return df

    # we have good dlc confidence
    confident = np.logical_and(df1.LedLeft_p > prob, df1.LedRight_p > prob).to_numpy()
    to_fill = np.logical_and(online_missing, confident)

    print("Proportion of nan in offline tracking: {:.3f}".format(np.sum(np.isnan(df1.x)/df1.x.shape[0])))
    print("Number of np.nan filling using dlc: {}".format(np.sum(to_fill)))

    # single .loc assignment (chained assignment does not update df under pandas Copy-on-Write);
    # .to_numpy() copies positionally, rows are matched by position
    pose_columns = ["x", "y", "hd"]
    df.loc[to_fill, pose_columns] = df1.loc[to_fill, pose_columns].to_numpy()

    print("Proportion of nan after merging: {:.3f}".format(np.sum(np.isnan(df.x)/df.x.shape[0])))

    return df

def post_processing_positrack2_file(sSes, trialName, extension="positrack2",alt_extension="offline.positrack2", output_extension="positrack2_post"):
    """
    Apply post-processing to the position data of a single positrack2 file and save the result.

    Steps:
    1. frames with the same acquisition time (acq_time_source_2) as the next frame are set to np.nan
    2. head direction is converted to radians if its maximum is above pi (degrees)
    3. speed jumps are removed (remove_speed_jumps)
    4. 180-degree head-direction flips are corrected (correct_head_direction_180_flip)
    5. if a non-empty offline tracking file exists, it fills np.nan of the online tracking
       (merge_online_offline_tracking)
    The result is written to <sSes.path>/<trialName>.<output_extension>.

    Arguments
    sSes: spikeA session
    trialName: name of a single trial
    extension: extension of the positrack2 file
    alt_extension: extension of the alternative (offline) position data file (see dlc_on_positrack2_videos)
    output_extension: extension of the saved, post-processed file

    Raises
    TypeError: sSes is None or not a spikeA Session
    OSError: the positrack2 file is missing
    ValueError: the online and offline files do not have the same length
    """
    if sSes is None:
        raise TypeError("Please provide a session object with the ses argument")
    if not isinstance(sSes, Session):
        raise TypeError("ses should be a subclass of the Session class")

    print("***************")
    print(trialName)
    positrack_file = Path(sSes.path + "/" + trialName + "." + extension)
    if not positrack_file.exists():
        raise OSError("positrack file {} missing".format(positrack_file))
    print("Loading ", positrack_file)
    df = pd.read_csv(positrack_file)
    print("positrack2 file with {} lines".format(df.shape[0]))
    print("Proportion of nan in original positrack2 file: {:.3f}".format(np.sum(np.isnan(df.x)/df.x.shape[0])))

    # optional offline tracking; treated as absent when the file is missing or empty
    positrack_alt_file = Path(sSes.path + "/" + trialName + "." + alt_extension)
    df1 = None
    if positrack_alt_file.exists():
        print("Loading ", positrack_alt_file)
        df1 = pd.read_csv(positrack_alt_file)
        print(positrack_alt_file, "has {} lines".format(df1.shape[0]))
        if df1.shape[0] == 0:
            df1 = None
    else:
        print("No ", positrack_alt_file)

    if df1 is None:
        print("No offline tracking to improve on online tracking")
    elif df.shape[0] != df1.shape[0]:
        raise ValueError("the two position data file do not have the same length")

    frames = [df] if df1 is None else [df, df1]
    pose_columns = ["x", "y", "hd"]

    # if there are subsequent frames with the same acquisition time, set to np.nan as it will confuse the filter
    dt = np.diff(df.acq_time_source_2, append=np.nan)
    indices = dt == 0
    print("Number of delta time == 0: {}".format(np.sum(indices)))
    for frame in frames:
        # .loc instead of chained assignment so the change reaches the frame under pandas Copy-on-Write
        frame.loc[indices, pose_columns] = np.nan

    # we want the head-direction data in radians, wrapped to [-pi, pi]
    for frame in frames:
        if np.nanmax(frame.hd) > np.pi:  # if in degrees
            frame["hd"] = frame.hd / 360 * np.pi * 2
            frame["hd"] = np.arctan2(np.sin(frame.hd), np.cos(frame.hd))

    # remove speed jumps; df1 has the same length as df, so both use df's acquisition times
    df = remove_speed_jumps(df=df, time=df.acq_time_source_2, px_per_cm=sSes.px_per_cm)
    if df1 is not None:
        df1 = remove_speed_jumps(df=df1, time=df.acq_time_source_2, px_per_cm=sSes.px_per_cm)

    # remove hd flips
    # NOTE(review): this step uses acq_time_source_0 while the steps above use acq_time_source_2 — confirm intended.
    df = correct_head_direction_180_flip(df, df.acq_time_source_0, sSes.px_per_cm)
    if df1 is not None:
        df1 = correct_head_direction_180_flip(df1, df.acq_time_source_0, sSes.px_per_cm)

    # use the dlc model output to fill in the np.nan from online tracking
    if df1 is not None:
        df = merge_online_offline_tracking(df, df1)

    ## saving "positrack2_post" file
    outPosiFn = sSes.path + "/" + trialName + "." + output_extension
    df.to_csv(outPosiFn, index=False)
    print("Saving output to ", outPosiFn)

def post_processing_positrack_file(sSes, trialName, extension="positrack", output_extension="positrack_post"):
    """
    Apply post-processing to the position data of a single positrack file and save the result.

    Steps:
    1. frames where x or y is -1.0 (invalid) or np.nan get x, y and hd set to np.nan
    2. head direction is converted to radians if its maximum is above pi (degrees)
    3. speed jumps are removed (remove_speed_jumps), with capTime converted from ms to s
    The result is written to <sSes.path>/<trialName>.<output_extension>.

    Arguments
    sSes: spikeA session
    trialName: name of a single trial
    extension: extension of the positrack file (space-separated)
    output_extension: name of the extension for the processed data

    Raises
    TypeError: sSes is None or not a spikeA Session
    OSError: the positrack file is missing
    """
    if sSes is None:
        raise TypeError("Please provide a session object with the ses argument")
    if not isinstance(sSes, Session):
        raise TypeError("ses should be a subclass of the Session class")

    print("***************")
    print(trialName)
    positrack_file = Path(sSes.path + "/" + trialName + "." + extension)
    if not positrack_file.exists():
        raise OSError("positrack file {} missing".format(positrack_file))
    print("Loading ", positrack_file)
    df = pd.read_csv(positrack_file, sep=" ")
    print("positrack file with {} lines".format(df.shape[0]))

    # -1.0 marks an invalid coordinate. A frame missing either coordinate is unusable: the
    # Kalman filter needs both x and y, so invalidate the whole frame (x, y and hd).
    invalid = (df.x == -1.0) | (df.y == -1.0) | df.x.isna() | df.y.isna()
    df.loc[invalid, ["x", "y", "hd"]] = np.nan

    # we want the head-direction data in radians, wrapped to [-pi, pi]
    if np.nanmax(df.hd) > np.pi:  # if in degrees
        df["hd"] = df.hd / 360 * np.pi * 2
        df["hd"] = np.arctan2(np.sin(df.hd), np.cos(df.hd))

    print("Proportion of nan in original positrack file: {:.3f}".format(np.sum(np.isnan(df.x)/df.x.shape[0])))

    # remove speed jumps
    df = remove_speed_jumps(df=df,
                            time=df.capTime / 1000,  # capTime is in ms, convert to seconds
                            px_per_cm=sSes.px_per_cm)

    outPosiFn = sSes.path + "/" + trialName + "." + output_extension
    df.to_csv(outPosiFn, index=False)
    print("Saving output to ", outPosiFn)

Split the sessions by tracking system. For positrack2 sessions the video is available, so offline tracking of the video can be used to improve the online tracking.

In [680]:
# Split the sessions according to the tracking system that produced their position data
positrack2_sessions = list(filter(lambda s: s.ap.positrack_type() == "positrack2", sSesList))
print("Number of positrack2 sessions: ", len(positrack2_sessions))

positrack_sessions = list(filter(lambda s: s.ap.positrack_type() == "positrack", sSesList))
print("Number of positrack sessions: ", len(positrack_sessions))

Number of positrack2 sessions:  16
Number of positrack sessions:  28


In [None]:
# Post-process every trial of every positrack session;
# each call writes a <trial>.positrack_post file in the session directory.
for i, sSes in enumerate(positrack_sessions):
    print(i, sSes.name)
    for t in sSes.trial_names:
        post_processing_positrack_file(sSes,t)

Run the preprocessing steps on the positrack2 sessions.

In [None]:
# Post-process every trial of every positrack2 session; each call writes a <trial>.positrack2_post file,
# using <trial>.offline.positrack2 (when present) to fill gaps of the online tracking.
for i, sSes in enumerate(positrack2_sessions):
    print(i, sSes.name)
    for t in sSes.trial_names:
        post_processing_positrack2_file(sSes,t)

## Kalman filter

In [3]:
def pos_vel_filter(x, P, R, Q=0., dt=1.0):
    """
    Build a KalmanFilter with a constant-velocity model for the state [x, y, dx, dy].T.

    Arguments
    x: initial state; its first four items are location x, location y, velocity x, velocity y
    P: initial state covariance; a scalar scales kf.P, anything else is copied into kf.P
    R: factor applied to the measurement uncertainty kf.R
    Q: process noise; a scalar is the variance given to Q_discrete_white_noise, anything else is copied into kf.Q
    dt: time step of the model

    Returns
    the configured KalmanFilter (2 measured values: location x and location y)
    """
    kf = KalmanFilter(dim_x=4, dim_z=2)

    # initial state
    loc_x, loc_y, vel_x, vel_y = x[0], x[1], x[2], x[3]
    kf.x = np.array([loc_x, loc_y, vel_x, vel_y])

    # state transition: location += velocity * dt, velocity is kept constant
    transition = np.eye(4)
    transition[0, 2] = dt
    transition[1, 3] = dt
    kf.F = transition

    # measurement function: only the two locations are observed
    kf.H = np.eye(2, 4)

    # measurement uncertainty
    kf.R *= R

    if not np.isscalar(P):
        kf.P[:] = P  # copy the values into the existing covariance array
    else:
        kf.P *= P

    if not np.isscalar(Q):
        kf.Q[:] = Q
    else:
        # order_by_dim=False keeps the [x, y, dx, dy] ordering used above
        kf.Q = Q_discrete_white_noise(dim=2, dt=dt, var=Q, block_size=2, order_by_dim=False)

    return kf

def run(x0=(400.,400.,0.,0.), P=200, R=0, Q=0, dt=1.0, zs=None,maxConsicutiveInvalids=15):
    """
    Run a constant-velocity Kalman filter (pos_vel_filter) over a sequence of 2D measurements.

    This filter assumes that dt is constant.

    A measurement is invalid when its x or its y is np.nan. For an invalid measurement the filter
    only predicts, for up to maxConsicutiveInvalids - 1 consecutive invalid measurements; after that
    the output state is np.nan until a valid measurement arrives.

    Arguments
    x0: starting state of the filter (location x, location y, velocity x, velocity y)
    P: initial state covariance passed to pos_vel_filter (a scalar scales kf.P)
    R: measurement uncertainty factor passed to pos_vel_filter
    Q: process noise passed to pos_vel_filter (a scalar is used as the noise variance)
    dt: time between two measurements
    zs: sequence of measurements, each indexable as (x, y)
    maxConsicutiveInvalids: number of consecutive invalid measurements at which prediction stops

    Returns
    xs: array with one filtered state per measurement (presumably shape (n, 4) — kf.x is 1-D here)
    cov: array with one state covariance per measurement

    Raises
    TypeError: zs is None
    """
    if zs is None:
        raise TypeError("zs, the sequence of measurements, must be provided")

    # create the Kalman filter
    kf = pos_vel_filter(x0, R=R, P=P, Q=Q, dt=dt)

    consecutiveInvalids = 0
    # run the kalman filter and store the results
    xs, cov = [], []
    for z in zs:
        # both coordinates must be valid: updating with a nan y would make the state nan for good
        if not np.isnan(z[0]) and not np.isnan(z[1]):
            consecutiveInvalids = 0
            kf.predict()
            kf.update(z)
            xs.append(kf.x.copy())
        else:
            consecutiveInvalids += 1
            if consecutiveInvalids < maxConsicutiveInvalids:
                kf.predict()
                xs.append(kf.x.copy())
            else:  # don't predict, set to invalid
                xs.append(np.array([np.nan, np.nan, np.nan, np.nan]))
        cov.append(kf.P.copy())

    return np.array(xs), np.array(cov)
def filter_positrack_file(sSes, trialName,input_extension="positrack2_post", output_extension="positrack2_kf",with_plot=True, R=0.5, Q=1500, P=200):
    """
    Smooth the x, y position of one trial with a Kalman filter and save the result.

    Reads <sSes.path>/<trialName>.<input_extension>, filters x and y with run(), sets filtered
    positions outside the range of the original data to np.nan, replaces the x and y columns
    with the filtered values and writes <sSes.path>/<trialName>.<output_extension>.

    Arguments
    sSes: spikeA session (only sSes.path is used)
    trialName: name of a single trial
    input_extension: "positrack2_post" (frame time in acq_time_source_2) or "positrack_post"
                     (frame time in capTime, ms); for other extensions the time column is looked up in the file
    output_extension: extension of the saved file
    with_plot: plot the raw and the filtered data
    R, Q, P: Kalman filter parameters passed to run()

    Raises
    ValueError: the time between frames cannot be determined, or no frame has valid x and y
    """
    posiFn = sSes.path + "/" + trialName + "." + input_extension
    print(posiFn)
    df = pd.read_csv(posiFn)
    print(input_extension, "file with {} lines".format(df.shape[0]))
    print("Proportion of nan in original {} file: {:.3f}".format(input_extension, np.sum(np.isnan(df.x)/df.x.shape[0])))

    # measurements for the filter, one (x, y) row per frame
    zs = np.stack([df.x.to_numpy(), df.y.to_numpy()]).T

    if with_plot:
        _plot_filter_input(zs)

    # start the filter at the first frame where both x and y are valid
    valid = ~np.isnan(zs).any(axis=1)
    if not valid.any():
        raise ValueError("No valid position in {}".format(posiFn))
    x_start, y_start = zs[valid][0]

    dt = _frame_interval(df, input_extension)

    # apply the filter
    Ms, _ = run(x0=(x_start, y_start, 0, 0), R=R, Q=Q, P=P, zs=zs, dt=dt)
    print("Proportion of nan after Kalman filter: {:.3f}".format(np.sum(np.isnan(Ms[:,0])/Ms.shape[0])))

    # remove any position that would be out of range of the original data
    out_x = np.logical_or(Ms[:, 0] < df.x.min(), Ms[:, 0] > df.x.max())
    out_y = np.logical_or(Ms[:, 1] < df.y.min(), Ms[:, 1] > df.y.max())
    toInvalid = np.logical_or(out_x, out_y)
    print("out of range: {}".format(np.sum(toInvalid)))
    Ms[toInvalid, :] = np.nan
    print("Proportion of nan: {:.3f}".format(np.sum(np.isnan(Ms[:,0])/Ms.shape[0])))

    # replace original position data with filtered data
    df["x"] = Ms[:, 0]
    df["y"] = Ms[:, 1]

    # save the new file
    outFn = sSes.path + "/" + trialName + "." + output_extension
    df.to_csv(outFn, index=False)
    print("Saving output to", outFn)

    if with_plot:
        _plot_filter_output(zs, Ms)


def _frame_interval(df, input_extension):
    """Return the median time between frames of df, in seconds."""
    if input_extension == "positrack2_post":
        return df.acq_time_source_2.diff().median()
    if input_extension == "positrack_post":
        return df.capTime.diff().median() / 1000  # capTime is in ms
    # other extensions: use whichever time column the file has
    if "acq_time_source_2" in df.columns:
        return df.acq_time_source_2.diff().median()
    if "capTime" in df.columns:
        return df.capTime.diff().median() / 1000
    raise ValueError("Cannot determine the time between frames for input_extension {}".format(input_extension))


def _plot_filter_input(zs):
    """Plot the raw x and y over frames and the raw x-y path."""
    fig, axs = plt.subplots(1, 3, figsize=(17, 4))
    axs[0].plot(zs[:, 0])
    axs[1].plot(zs[:, 1])
    axs[2].plot(zs[:, 0], zs[:, 1])


def _plot_filter_output(zs, Ms, maxX=20000, width=1):
    """Plot raw (red) against filtered (black) positions over frames and as x-y paths."""
    fig, axs = plt.subplots(2, 1, figsize=(30, 20))
    for i in range(2):
        axs[i].scatter(range(len(zs[:, i])), zs[:, i], zorder=1, c="red")
        axs[i].plot(Ms[:, i], zorder=0, c="black", lw=width)
        axs[i].set_xlim(0, maxX)
    plt.show()

    plt.figure(figsize=(10, 10))
    plt.plot(zs[:, 0], zs[:, 1], zorder=1, c="red", lw=width)
    plt.plot(Ms[:, 0], Ms[:, 1], zorder=0, c="black", lw=width)
    plt.show()

    plt.figure(figsize=(10, 10))
    plt.plot(Ms[:, 0], Ms[:, 1], zorder=0, c="black", lw=width)

In [None]:
# Kalman-filter every trial of every positrack session: reads <trial>.positrack_post, writes <trial>.positrack_kf.
# To inspect a single trial instead, set e.g. sSes = positrack_sessions[-1] and trialName = sSes.trial_names[4],
# then call filter_positrack_file(..., with_plot=True).
for sSes in positrack_sessions:
    for trialName in sSes.trial_names:
        filter_positrack_file(sSes,trialName,with_plot=False,input_extension="positrack_post", output_extension="positrack_kf")

In [None]:
# Kalman-filter every trial of every positrack2 session: reads <trial>.positrack2_post, writes <trial>.positrack2_kf.
# To inspect a single trial instead, set e.g. sSes = positrack2_sessions[-1] and trialName = sSes.trial_names[4],
# then call filter_positrack_file(..., with_plot=True).
for sSes in positrack2_sessions:
    for trialName in sSes.trial_names:
        filter_positrack_file(sSes, trialName, with_plot=False, input_extension="positrack2_post", output_extension="positrack2_kf")

Each session directory now contains files with the extension `positrack_kf` or `positrack2_kf` holding the filtered data.