In [10]:
## Import functions

import numpy as np  
import pandas as pd
np.set_printoptions(suppress=True) # Supress scientific notation when printing
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [10, 5]
import seaborn as sns
from scipy.ndimage import gaussian_filter1d
import re # Regular expressions
import networkx as nx # Package for graph represenations 
from datetime import datetime, time 
import pygraphviz as gv

from networkx.drawing.nx_agraph import graphviz_layout, to_agraph, write_dot
from scipy.ndimage.filters import uniform_filter1d


from sklearn.mixture import GaussianMixture
from scipy.stats import norm
from scipy.signal import argrelextrema
from scipy import signal

In [11]:
edgelist_filename = './data/graph_edgelist.dat'

### load graph
print("load graph")
G = nx.read_edgelist(edgelist_filename)  
print("graph: number of nodes = ",G.number_of_nodes(),", edges = ",G.number_of_edges())

### pre-compute all possible shortest path lengths and save in dictionary (does not save shortest paths itself)
DD = nx.shortest_path_length(G) 
DD = dict(DD)
print("dictionary: len =", len(DD), ", total items =", sum([len(dv) for dv in DD.values()]))

# get node positions from recosntructed graph 
pos = graphviz_layout(G)

load graph
graph: number of nodes =  96 , edges =  125
dictionary: len = 96 , total items = 9216


In [12]:
## Set file names and paths

# Read full processed csv file and select rat and session

# SET CORRECT FILENAME!!!!!!!, since import date may differ!!!
filename = './results/Rat_HM_Ephys_AggProc_20220202.csv'

data_full = pd.read_csv(filename)

In [13]:
from __future__ import division 

def interpolated_intercept(x, y1, y2):
    """Find the intercept of two curves, given by the same x data"""

    def intercept(point1, point2, point3, point4):
        """find the intersection between two lines
        the first line is defined by the line between point1 and point2
        the second line is defined by the line between point3 and point4
        each point is an (x,y) tuple.

        So, for example, you can find the intersection between
        intercept((0,0), (1,1), (0,1), (1,0)) = (0.5, 0.5)

        Returns: the intercept, in (x,y) format
        """    

        def line(p1, p2):
            A = (p1[1] - p2[1])
            B = (p2[0] - p1[0])
            C = (p1[0]*p2[1] - p2[0]*p1[1])
            return A, B, -C

        def intersection(L1, L2):
            D  = L1[0] * L2[1] - L1[1] * L2[0]
            Dx = L1[2] * L2[1] - L1[1] * L2[2]
            Dy = L1[0] * L2[2] - L1[2] * L2[0]

            x = Dx / D
            y = Dy / D
            return x,y
        L1 = line([point1[0],point1[1]], [point2[0],point2[1]])
        L2 = line([point3[0],point3[1]], [point4[0],point4[1]])

        R = intersection(L1, L2)

        return R

    idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)
    xc, yc = intercept((x[idx], y1[idx]),((x[idx+1], y1[idx+1])), ((x[idx], y2[idx])), ((x[idx+1], y2[idx+1])))
    return xc.flatten(),yc.flatten()

   
# #For the model with 3 Gaussians
# x_ax  = f_axis1
# y1 = c[0]
# y2 = c[2]

# plt.plot(x_ax, y1, marker='o', mec='none', ms=4, lw=1, label='y1')
# plt.plot(x_ax, y2, marker='o', mec='none', ms=4, lw=1, label='y2')

# idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)

# plt.plot(x_ax[idx], y1[idx], 'ms', ms=7, label='Nearest data-point method')

# # new method!
# xc, yc = interpolated_intercept(x_ax,y1,y2)
# plt.plot(xc, yc, 'co', ms=5, label='Nearest data-point, with linear interpolation')
# print(xc,yc)


# plt.legend(frameon=False, fontsize=10, numpoints=1, loc='lower left')

# plt.show()

In [18]:
def intersectionPoint(ratNr):
    
    dataRatNr = data_full.loc[data_full['rat_no']==ratNr]
    dates = dataRatNr['date'].unique()
    meanIntersection = list()
    mean1 = list()
    mean2 = list()
    intersectionPoint = list()
    
    for i in dates:
        dataRatSession = data_full.loc[(data_full['rat_no']==ratNr) & (data_full['date']==i)]
        sp = dataRatSession['speed_ff']*-1

        sptotal = pd.concat([dataRatSession['speed_ff'],(sp)])

        gmm2 = GaussianMixture(n_components = 3).fit(np.asarray(sptotal).reshape(-1, 1))
        
        #plt.figure()
        #plt.hist(sptotal, bins=np.linspace(-1,1,50), histtype='stepfilled', density = True, alpha=0.5)
        #plt.xlim(-1, 1)
        
        f_axis1 = sptotal.copy().ravel()
        f_axis1.sort()
        a = []
        c = []
        for weight, mean, covar in zip(gmm2.weights_, gmm2.means_, gmm2.covariances_):
            a.append(mean)
            c.append(weight*norm.pdf(f_axis1, mean, np.sqrt(covar)).ravel())
            #plt.plot(f_axis1, c[-1])
            
        #plt.plot(f_axis1, np.array(c).sum(axis =0), 'k-')
        #plt.xlabel('Variable')
        #plt.ylabel('PDF')
        #plt.tight_layout()
        #plt.show()
        a=np.asarray(a)
        
        g1=np.argmin(abs(a))
        g2=np.argmax(a)

        a = (a[g1],a[g2])
        c = (c[g1],c[g2])
        
        x_axis  = f_axis1
        y1 = c[0]
        y2 = c[1]
        #plt.plot(x_axis, y1, marker='o', mec='none', ms=4, lw=1, label='y1')
        #plt.plot(x_axis, y2, marker='o', mec='none', ms=4, lw=1, label='y2')
        idx = np.argwhere(np.diff(np.sign(y1 - y2)) != 0)
       

        #plt.plot(x_axis[idx], y1[idx], 'ms', ms=7, label='Nearest data-point method')
        y1 = y1[x_axis>0]
        y2 = y2[x_axis>0]
        x_axis = x_axis[x_axis>0]
        xc, yc = interpolated_intercept(x_axis,y1,y2)
        #plt.plot(xc, yc, 'co', ms=5, label='Nearest data-point, with linear interpolation')
        mean1.append(a[0])
        mean2.append(a[1])
        intersectionPoint.append((xc,yc))
        meanIntersection.append((a[0].flatten(),(xc.flatten(),yc.flatten()), a[1].flatten()))
    return mean1 , mean2, intersectionPoint, dates, dataRatNr
#mean1,mean2,intersectionP, dateD, dataRatNr = intersectionPoint('Rat5')
#print(intersectionPoint('Rat5'))

# Plotting in time the averages of the intersection point and means per session.
# Calculating the intersection point of average point within trials (min aantal puntnen?)

In [55]:
def getnewDataframe (dataR, sessionNumber, ratNr):
    listAtWhatSpeedCrossesLine = list()
    listSpeed = list()
    listTimes = list()
    indexList = list()
    goalNodeList = list()
    trialNumberList = list()
    intersectionNodeList = list()
    listSessionNumber = list()
    dates = dataRatNr['date'].unique()
    if sessionNumber==0:
        for n_sess,z in enumerate(dates):
            print(intersection_point[n_sess][0])
            dataRatSession = data_full.loc[(data_full['rat_no']==ratNr) & (data_full['date']==z)]
            dataRatSession.set_index('trial_no', inplace=True)
            tr_list = dataRatSession.index.unique()
            for tr in tr_list:
                q = 0
                p = 0
                b = dataRatSession.loc[tr, 'speed_ff']
                i = -1
                for j in b:
                        
                        if(j<=i and j>=intersection_point[n_sess][0]):
                            listAtWhatSpeedCrossesLine.append((tr, j))
                            if q==0:
                                trialNumberList.append(tr)
                                listSpeed.append((tr,j))
                                listTimes.append(dataRatSession.loc[tr,'cum_seconds'].iloc[p])
                                indexList.append(p)
                                goalNodeList.append(dataRatSession.loc[tr,'node'].iloc[-1])
                                intersectionNodeList.append(dataRatSession.loc[tr,'node'].iloc[p]) 
                                listSessionNumber.append(z)
                            q+=1   
                        i=j
                        p+=1

    else:
        dataRatSession = data_full.loc[(data_full['rat_no']==ratNr) & (data_full['date']==sessionNumber)]
        dataRatSession.set_index('trial_no', inplace=True)
        tr_list = dataRatSession.index.unique()
        
        n_sess = np.argwhere(session_number==sessionNumber)[0][0]
        print(intersection_point[n_sess][0])
        for tr in tr_list:
            q = 0
            p = 0
            b = dataRatSession.loc[tr,'speed_ff']
            i = -1
            for j in b:
                    
                    if(j<=i and j>=intersection_point[n_sess][0]):
                        listAtWhatSpeedCrossesLine.append((tr, j))
                        if q==0:
                            trialNumberList.append(tr)
                            listSpeed.append((tr,j))
                            listTimes.append(dataRatSession.loc[tr,'cum_seconds'].iloc[p])
                            indexList.append(p)
                            goalNodeList.append(dataRatSession.loc[tr,'node'].iloc[-1])
                            intersectionNodeList.append(dataRatSession.loc[tr,'node'].iloc[p]) 
                            listSessionNumber.append(sessionNumber)
                        q+=1   
                    i=j
                    p+=1
                    
    res = pd.DataFrame(list(zip(trialNumberList, listTimes, indexList, intersectionNodeList, goalNodeList,listSessionNumber)),
               columns =['trial_number', 'time', 'index', 'intersection_node', 'goal_node','session_number'])
    res.sort_values(by = ['session_number', 'trial_number'], ascending=[True,True], inplace = True)
    res.drop_duplicates(subset=['session_number','trial_number'],keep=False, inplace = True)
    return res

In [56]:
def getInfoRat(ratNr, sessionNr):
    mean1,mean2,intersection_point, session_number, dataRatNr = intersectionPoint(ratNr)
    df = getnewDataframe(dataRatNr, sessionNr, ratNr)
    cf = df.copy()
    cf = cf.groupby(['intersection_node'])['intersection_node'].count().reset_index(
  name='Count').sort_values(['Count'], ascending=False)
    #print(cf)
    return df, cf

In [57]:
dataFram, dictF = getInfoRat('Rat5', '2021-06-29')
#print(dataFram)
print(dictF)

[0.17783625]
    intersection_node  Count
9                 301      4
1                 118      2
2                 120      2
3                 201      2
5                 207      2
11                318      2
12                403      2
0                 115      1
4                 202      1
6                 209      1
7                 210      1
8                 220      1
10                308      1
13                405      1
14                410      1
15                417      1


In [19]:
mean1,mean2,intersection_point, session_number, dataRatNr = intersectionPoint('Rat5')

In [58]:
dataFram, dictF = getInfoRat('Rat5', 0)
#print(dataFram)
print(dictF)

[0.16087066]
[0.17783625]
[0.16338227]
[0.16134963]
[0.18747182]
[0.14360431]
    intersection_node  Count
38                401     13
23                301      7
5                 118      7
11                202      7
8                 122      6
13                207      6
6                 120      5
10                201      4
28                308      4
29                310      4
36                323      4
41                407      4
25                303      4
12                203      3
26                305      3
37                324      3
15                210      3
39                403      3
34                318      3
19                220      3
20                222      3
1                 109      3
4                 117      3
43                409      2
35                322      2
33                317      2
0                 108      2
2                 115      2
7                 121      2
27                307      1
14                209  

In [50]:
n_sess = np.argwhere(session_number=='2021-06-28')[0][0]

In [51]:
n_sess

0

In [36]:
session_number

array(['2021-06-28', '2021-06-29', '2021-07-10', '2021-07-11',
       '2021-07-13', '2021-07-15'], dtype=object)

In [68]:
dictF.loc[0]['intersection_node']

108

In [67]:
dictF.shape[0]

46