In [9]:
import os
import sys
import json
from collections import Counter
from scipy.integrate import solve_ivp
import numpy as np
from numpy.linalg import norm
from numpy import sqrt
from sklearn.metrics import accuracy_score
from math import pi
import pickle
import matplotlib.pyplot as plt
from packages import data_container
from packages.data_container import Data
from packages.helper import play_trajs, rotate, sp2a, v2sp, dist, psi, beta, d_theta, d_psi, sp2v, dist, min_dist, \
    vector_angle, signed_angle, side, inner, theta, min_sep, traj_speed
from packages.ode_simulator import ODESimulator
# For pickle to load the Data object, which is defined in packages.data_container
sys.modules['data_container'] = data_container

# file = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'Raw_Data', 'Fajen_steer1a_data.pickle'))
# with open(file, 'rb') as f:
#     data = pickle.load(f)

# file = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'Raw_Data', 'Bai_movObst1_data.pickle'))
# with open(file, 'rb') as f:
#     data = pickle.load(f)

# file = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'Raw_Data', 'Bai_movObst1b_data.pickle'))
# with open(file, 'rb') as f:
#     data2 = pickle.load(f)

file = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'Raw_Data', 'Cohen_movObst1_data.pickle'))
with open(file, 'rb') as f:
    data = pickle.load(f)


In [12]:
len(data.trajs)

1008

In [46]:
cnt = 0
for i in range(len(data.trajs)):
    if i not in data.dump:
        cnt += 1
cnt

995

In [45]:
len(data.trajs)

1008

In [55]:
vs = []
for i in range(len(data.trajs)):
    if i in data.dump or data.info['obst_speed'][i] == 0:
        continue
    t0 = data.info['obst_onset'][i]
    t1 = data.info['obst_out'][i]
    s = min(traj_speed(data.info['p_subj'][i][t0:t1], Hz=data.Hz))
    vs.append((s, i))
vs.sort()
n = int(len(vs) / 100)
print([x[1] for x in vs[:n]])


[129, 163, 469, 319, 292, 689, 470, 269]


In [33]:
'''Animate data'''
############
subject = 0
trial = 31

############
%matplotlib qt
i = subject * 160 + trial - 1
i = 306
# p_obst = np.array(data.info['p_obst'][i])
t0 = data.info['obst_onset'][i]
p_goal = data.info['p_goal'][i]
p_subj = data.info['p_subj'][i]
p_obst = data.info['p_obst'][i]
trajs = [p_goal, p_obst, p_subj]
ws = [data.info['w_goal'], data.info['w_obst'], 0.4]
title = 'subj ' + str(data.info['subj_id'][i]) + ' trial ' + str(data.info['trial_id'][i]) + ' obst_angle: ' + str(data.info['obst_angle'][i]) + ' obst_speed: ' + str(data.info['obst_speed'][i])
print(data.info['subj_id'][i], data.info['trial_id'][i])
play_trajs(trajs, ws, data.Hz, title=title, save=False)


1 147


<matplotlib.animation.FuncAnimation at 0x2b886ea10c8>

In [None]:
'''Minimum Passing Distance'''
%matplotlib qt
fig = plt.figure()
ax = fig.add_subplot()
ax.set_title('Signed predicted minimum passing distance (SMPD)')
ax.set_ylabel('SMPD (m)')
ax.set_xlabel('normalized time (%)')
ax.set_ylim((-2, 2))
for i in range(len(data.trajs)):
    t0 = data.info['stimuli_onset'][i]
    t1 = data.info['stimuli_out'][i]
    p0 = data.info['p_subj'][i][t0:t1]
    p1 = data.info['p_obst'][i][t0:t1]
    v0 = data.info['v_subj'][i][t0:t1]
    v1 = data.info['v_obst'][i][t0:t1]
    t = np.linspace(0, 100, len(p0))
    smpd = []
    for _p0, _p1, _v0, _v1 in zip(p0, p1, v0, v1):        
        smpd.append(min_sep(_p0, _p1, _v0, _v1)[0])
    ax.plot(t, smpd, 'k', linewidth=0.1, alpha=0.5)
    
    

In [None]:
'''Plot acceleration angle and magnitude'''
%matplotlib qt
trials = range(400,401)
subject = 1
con_angle = [90, -90]
con_speed = []
fig0 = plt.figure()
ax0 = fig0.add_subplot()
fig1 = plt.figure()
ax1 = fig1.add_subplot()
for i in trials:
    angle = data.info['obst_angle'][i]
    speed = data.info['obst_speed'][i]
    subj_id = data.info['subj_id'][i]
    if i in data.dump:
        continue
#     if angle not in con_angle or subj_id != subject:
#         continue
    t0, t1 = data.info['stimuli_onset'][i], data.info['stimuli_out'][i]
    p0, p1, a0 = np.array(data.info['p_subj'][i][t0:t1]), np.array(data.info['p_obst'][i][t0:t1]), np.array(data.info['a_subj'][i][t0:t1])
    angles = signed_angle(p1 - p0, a0)
    ax0.scatter(range(len(angles)), angles, s=1)
    ax1.plot(norm(a0, axis=-1))
    print(data.info['subj_id'][i], data.info['trial_id'][i])

In [None]:
'''Plot dpsi dtheta by time'''
%matplotlib qt
subjects = range(16)
n = len(data.trajs)
fig0 = plt.figure()
ax0 = fig0.add_subplot()
fig1 = plt.figure()
ax1 = fig1.add_subplot()
fig2 = plt.figure()
ax2 = fig2.add_subplot()
for i in range(2,3):
    if (data.info['subj_id'][i] in subjects and
        i not in data.dump and
        data.info['obst_speed'][i] != 0 and
        abs(data.info['obst_angle'][i]) != 180):
        t0, t1 = data.info['stimuli_onset'][i], data.info['stimuli_out'][i]
        p0, p1 = np.array(data.info['p_subj'][i][t0:t1]), np.array(data.info['p_obst'][i][t0:t1])
        v0, v1 = np.array(data.info['v_subj'][i][t0:t1]), np.array(data.info['v_obst'][i])
        a0 = np.array(data.info['a_subj'][i][t0:t1])
        a0 = norm(a0, axis=-1)
        v1 = np.tile(v1, (len(v0), 1))
        dpsis = np.absolute(d_psi(p0, p1, v0, v1))
        dthetas = d_theta(p0, p1, v0, v1, w=0.1)
        thetas = theta(p0, p1, w=0.1)
        ratio = dthetas/thetas
        ax0.plot(a0)
        c1, c2 = 0, 0.2
        ax1.plot((ratio+c1)/(dpsis+c2))
        ax2.scatter(ratio[60:80], a0[60:80], s=1)
#         ax2.scatter(range(len(dpsis)), dthetas/dpsis, s=1)

In [4]:
'''Plot data by condition'''
#####################
subject = -1
con_ang = [90]
con_spd = [1.2]
# con_ang = set(data.info['obst_angle'])
# con_spd = set(data.info['obst_speed'])
#####################
%matplotlib qt
plt.figure()
for i in range(len(data.trajs)):
    obst_speed = data.info['obst_speed'][i]
    obst_angle = data.info['obst_angle'][i]
    subj_id = data.info['subj_id'][i]
    if subject != -1 and subj_id != subject:
        continue
    if not (obst_speed in con_spd and abs(obst_angle) in con_ang):
        continue
    subj = data.info['p_subj'][i]
    obst = np.array(data.info['p_obst'][i])
    if i % 2 == 0:
        subj = rotate(subj, np.arctan(11 / 9) - pi / 2)
        obst = rotate(obst, np.arctan(11 / 9) - pi / 2)
    else:
        subj = rotate(subj, np.arctan(11 / 9) + pi / 2)
        obst = rotate(obst, np.arctan(11 / 9) + pi / 2)
    if obst_angle < 0:
        subj[:, 0] *= -1
        obst[:, 0] *= -1
    plt.plot(subj[:, 0], subj[:, 1])
    plt.plot(obst[:, 0], obst[:, 1])
ax = plt.gca()
ax.set_aspect('equal')
ax.set_title('subj ' + str(subject) + ' angle: ' + str(con_ang[0]) + ' speed: ' + str(con_spd[0]))

Text(0.5, 1.0, 'subj -1 angle: 90 speed: 1.2')

In [32]:
'''Plot data by subject'''
#####################
subject = 3
#####################
%matplotlib qt
fig = plt.figure()
fig.suptitle('Subject ' + str(subject))
axes = {}
obst_angle = [90, 112.5, 135, 157.5, 180]
obst_speed = [0.9, 1.0, 1.1, 1.2, 1.3]
i_plot = 1
for angle in obst_angle:
    for speed in obst_speed:
        axes[(angle, speed)] = fig.add_subplot(5, 5, i_plot)
        axes[(angle, speed)].set_xlim(-3, 3)
        axes[(angle, speed)].set_ylim(-7, 5)
        axes[(angle, speed)].set_title(str(angle) + '° ' + str(speed) + 'm/s')
        axes[(angle, speed)].set_aspect('equal')
        i_plot += 1
for i in range(len(data.trajs)):
    speed = data.info['obst_speed'][i]
    angle = data.info['obst_angle'][i]
    subj_id = data.info['subj_id'][i]
    if subj_id != subject or speed == 0:
        continue
    subj = np.array(data.info['p_subj'][i])
    obst = np.array(data.info['p_obst'][i])
    if i % 2 == 0:
        subj = rotate(subj, np.arctan(11 / 9) - pi / 2)
        obst = rotate(obst, np.arctan(11 / 9) - pi / 2)
    else:
        subj = rotate(subj, np.arctan(11 / 9) + pi / 2)
        obst = rotate(obst, np.arctan(11 / 9) + pi / 2)
    if angle < 0:
        subj[:, 0] *= -1
        obst[:, 0] *= -1
    axes[(abs(angle), speed)].plot(subj[:, 0], subj[:, 1])
    axes[(abs(angle), speed)].plot(obst[:, 0], obst[:, 1])

In [None]:
'''Check initial dpsi and subject passing choice, plot dpsi of matching and non-matching trials'''
subjects = set(data.info['subj_id'])
print(subjects)
sim = ODESimulator(data=data, ref=[0,1])
side_pred = []
side_true = []
dpsi_match = []
dpsi_not = []
for i in range(len(sim.data.trajs)):
    if (sim.data.info['subj_id'][i] in subjects and
        i not in sim.data.dump and
        sim.data.info['obst_speed'][i] != 0 and
        abs(sim.data.info['obst_angle'][i]) != 180):
        xg, yg, xo, yo, vxo, vyo, x, y, vx, vy, a, phi, s, dphi, ds = sim.compute_var0(i, sim.data.info['stimuli_onset'][i])
        # When beta and dpsi has the same sign it means pass in front, otherwise it means pass from behind
        dpsi = d_psi([x, y], [xo, yo], [vx, vy], [vxo, vyo])
        b = beta([x, y], [xo, yo], [vx, vy])
        side_pred.append(1 if b * dpsi > 0 else -1)
        side_true.append(sim.data.info['pass_order'][i])
        if side_pred[-1] == side_true[-1]:
            dpsi_match.append(dpsi)
        else:
            dpsi_not.append(dpsi)
print('passing order matching rate ', accuracy_score(side_true, side_pred))
plt.scatter(np.random.uniform(size=len(dpsi_match)), dpsi_match, label='matching trials')
plt.scatter(np.random.uniform(size=len(dpsi_not)), dpsi_not, label='non-matching trials')
plt.ylabel('dpsi')
plt.xlabel('arbitrary')
plt.legend()
thress = np.linspace(0, 0.1, 11)
match_rates = []
trial_ratios = []
for thres in thress:
    a, b = sum([abs(x) > thres for x in dpsi_match]), sum([abs(x) > thres for x in dpsi_not])
    match_rates.append(a/(a+b))
    trial_ratios.append((a+b)/len(side_pred))
plt.figure()
plt.scatter(thress, match_rates, label='Match Rate')
plt.scatter(thress, trial_ratios, label='Trial Ratio')
plt.legend()
plt.ylabel('Percentage')
plt.xlabel('dpsi')
# print(a/(a+b), ' matching rate for trials with initial dpsi smaller than ', thres)


In [None]:
'''Check trial length from match_onset to stimuli_out'''
match_onsets = []
lens = []
for i in range(len(data.trajs)):
    if i not in data.dump:
        p0 = data.info['p_subj'][i]
        p1 = data.info['p_obst'][i]
        v0 = data.info['v_subj'][i]
        v1 = data.info['v_obst'][i]
        stimuli_out = data.info['stimuli_out'][i]
        stimuli_onset = data.info['stimuli_onset'][i]
        match_onset = data.info['match_onset'][i]
        dpsis = d_psi(p0, p1, v0, v1)
        if stimuli_out - match_onset <= 30:
            plt.plot(dpsis)
            plt.plot(dpsis[match_onset:stimuli_out])
            print(i)
        match_onsets.append(data.info['match_onset'][i])
        lens.append(data.info['stimuli_out'][i] - data.info['match_onset'][i])
# plt.hist(match_onsets)
# plt.figure()
# plt.hist(lens)

In [31]:
'''Count pass order by condition'''
#########
angle = 157.5
speed = 1.1
#########
fpass = 0
total = 0
for i in range(len(data.trajs)):
    if abs(data.info['obst_angle'][i]) != angle or data.info['obst_speed'][i] != speed:
        continue
    if data.info['pass_order'][i] == 1:
        fpass += 1
    total += 1
print(fpass, total)
print(fpass / total)

27 66
0.4090909090909091


### Approach data

In [None]:

'''Plot data by condition'''
#####################
subject = 0
s0 = [1.4]
d0 = [8]
angle = [15]
s0 = set(data2.info['goal_s0'])
d0 = set(data2.info['goal_d0'])
angle = set(data2.info['goal_angle'])
#####################
%matplotlib qt
plt.figure()
for i in range(len(data2.trajs)):
    goal_s0 = data2.info['goal_s0'][i]
    goal_d0 = data2.info['goal_d0'][i]
    goal_angle = data2.info['goal_angle'][i]
    subj_id = data2.info['subj_id'][i]
    if subj_id != subject:
        continue
    if not (goal_s0 in s0 and goal_d0 in d0 and abs(goal_angle) in angle):
        continue
    subj = np.array(data2.info['p_subj'][i])
    goal = np.array(data2.info['p_goal'][i])
    if i % 2 == 0:
        subj = rotate(subj, np.arctan(11 / 9) - pi / 2)
        goal = rotate(goal, np.arctan(11 / 9) - pi / 2)
    else:
        subj = rotate(subj, np.arctan(11 / 9) + pi / 2)
        goal = rotate(goal, np.arctan(11 / 9) + pi / 2)
#     if goal_angle < 0:
#         subj[:, 0] *= -1
#         goal[:, 0] *= -1
    plt.plot(subj[:, 0], subj[:, 1])
    print(subj)
    plt.plot(goal[:, 0], goal[:, 1])
ax = plt.gca()
ax.set_aspect('equal')
ax.set_title(f'subj {str(subject)} angle: {str(angle)} s0: {str(s0)} d0: {str(d0)}')

In [None]:
'''Cohen experiments'''
file = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'Raw_Data', 'Cohen_movObst_exp1_data.pickle'))
with open(file, 'rb') as f:
    cohen1 = pickle.load(f)

In [None]:
cohen1.info['pass_order']
# data.info['pass_order']

In [10]:
''' Check optimal parameters from training results '''
bests = {} 
path = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'Results', 'Cohen_movObst2_all_cohen_avoid1&2&3&4'))
filenames = [os.path.join(path, name) for name in os.listdir(path) if name[-3:] == 'txt']
filenames = [filenames[3]]
# b1s = []
# k1s = []
errs = []
for filename in filenames:
    with open(filename, 'rb') as f:
        best = 0
        e_min = float('inf')
        for i, line in enumerate(f):
            if i == 1:
                subj_id = line[-5:-2]
            if i >= 11:
                try:
                    err = str(line).split("\\t")[-4][:10]
                    if err[0] == '0':
                        err = float(err)
                        if err < e_min:
                            e_min = err
                            best = str(line).replace("\\", "")
                except:
                    pass
                # Find all parameters that have small error
#                 if subj_id == 0 and float(err) < 0.16:
#                     model = str(line).split("\\t")[1].split("}")[1][1:] + '}'
#                     model = model.replace('\'', '\"')
#                     d = json.loads(model)
#                     b1s.append(d['b1'])
#                     k1s.append(d['k1'])
#                     errs.append(float(err))
    bests[subj_id] = best
for i, best in bests.items():
    print('\n')
    print(i, best)



b'all' b"2593t[{'name': 'fajen_approach2', 'b1': 2.04992354, 'k1': 2.85641543, 'c1': 0.54294928, 'c2': 0.73857217, 'b2': 3.89580222, 'k2': 5.04511601, 'ps': 1.1242147210583378}, {'name': 'cohen_avoid4', 'k1': 2.0736814583750123, 'c5': 0.277869777057664, 'c6': 11.55562093696106, 'k2': 0.695174380649082, 'c7': 5.707392197706118, 'c8': 7.709406732831272, 'ps': 1.1242147210583378}]t0.24684128959749946torder_accuracyt0.7885714285714286t0:00:20rn"


In [11]:
''' Check optimal parameters from cross validation'''
bests = {} 
path = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'Results', 'cohen_movObst2_cross_validation_cohen_avoid'))
filenames = [os.path.join(path, name) for name in os.listdir(path) if name[-3:] == 'txt']
errs = []
for filename in filenames:
    with open(filename, 'rb') as f:
        best = 0
        e_min = float('inf')
        if filename[-6] == '_':
            subj_id = int(filename[-5])
        else:
            subj_id = int(filename[-6:-4])
        for i, line in enumerate(f):
            if i >= 5:
                try:
                    err = str(line).split("\\t")[-3][:10]
                    if err[0] == '0':
                        err = float(err)
                        if err < e_min:
                            e_min = err
                            best = str(line).replace("\\", "")
                except:
                    pass

    bests[subj_id] = e_min
for i, err in bests.items():
    print('\n')
    print(i, err)



1 inf


2 inf


3 inf


4 inf


5 inf


6 inf


7 inf


10 inf


11 inf


12 inf


13 inf


14 inf


15 inf


8 inf


9 inf


In [None]:
bests