In [1]:
import os
import sys
import json
import random
from itertools import product
from collections import Counter, defaultdict
from scipy.integrate import solve_ivp
from scipy.interpolate import interp1d
from scipy.stats import ttest_ind
import numpy as np
from numpy.linalg import norm
from numpy import sqrt
from sklearn.metrics import accuracy_score
from sklearn.manifold import MDS
from sklearn.decomposition import PCA
from math import pi
import pickle
import matplotlib.pyplot as plt
from packages import data_container
from packages.data_container import Data
from packages.helper import play_trajs, rotate, sp2a, v2sp, psi, beta, d_theta, d_psi, sp2v, dist, min_dist, \
    vector_angle, signed_angle, side, inner, theta, min_sep, traj_speed
from packages.ode_simulator import ODESimulator
from parameters import approaches, avoidances
# Pickled Data objects were saved while their class lived in a top-level module named
# 'data_container'; register packages.data_container under that name so unpickling finds it.
sys.modules.update({'data_container': data_container})
data = None
def load_data(filename, data_dir=None):
    """Load ``<filename>.pickle`` into the module-level ``data`` global.

    Parameters
    ----------
    filename : str
        Base name of the pickle file, without the ``.pickle`` extension.
    data_dir : str, optional
        Directory holding the pickle. Defaults to ``../Raw_Data`` relative to the
        current working directory (the location previously hard-coded here).

    Returns
    -------
    object
        The unpickled object. It is also stored in the global ``data`` so cells that
        read ``data`` directly keep working.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code; only load files you trust.
    """
    global data
    if data_dir is None:
        data_dir = os.path.join(os.getcwd(), os.pardir, 'Raw_Data')
    file = os.path.abspath(os.path.join(data_dir, filename + '.pickle'))
    with open(file, 'rb') as f:
        data = pickle.load(f)
    return data
        
# file = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'Raw_Data', 'Fajen_steer1a_data.pickle'))
# with open(file, 'rb') as f:
#     data = pickle.load(f)

# file = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'Raw_Data', 'Bai_movObst2_data.pickle'))
# with open(file, 'rb') as f:
#     data = pickle.load(f)

# file = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'Raw_Data', 'Bai_movObst1_data_30Hz.pickle'))
# with open(file, 'rb') as f:
#     data = pickle.load(f)

# file = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'Raw_Data', 'Cohen_movObst1_data.pickle'))
# with open(file, 'rb') as f:
#     data = pickle.load(f)

    
%matplotlib qt

In [None]:
'''Multidimensional Scaling'''
model = 'cohen_avoid'
subjects = [0, 1, 2, 3, 4, 6, 7, 9, 10, 11, 12, 13]
# One row per subject: the fitted parameter values with the first entry of each
# per-subject dict dropped.
fits = avoidances['Bai_movObst1'][model]['differential_evolution']
X = np.array([list(fits[s].values())[1:] for s in subjects])
# MDS starts from a random configuration; fix the seed so the embedding is reproducible.
embedding = MDS(n_components=2, random_state=0)
X_transformed = embedding.fit_transform(X)
# Fresh figure: pyplot would otherwise draw onto whatever figure is currently active.
fig, ax = plt.subplots()
ax.scatter(X_transformed[:, 0], X_transformed[:, 1], label=model)
for i, txt in enumerate(subjects):
    ax.annotate(str(txt), X_transformed[i])
# Only `model` is plotted (the old title claimed 4 models).
ax.set_title(f"Parameters of {model} fitted to individuals")
ax.set_xlabel('embedding dimension 1')
ax.set_ylabel('embedding dimension 2')
ax.set_aspect('equal')
ax.legend()

In [None]:
'''PCA'''
model = 'cohen_avoid'
subjects = [0, 1, 2, 3, 4, 6, 7, 9, 10, 11, 12, 13]
fits = avoidances['Bai_movObst1'][model]['differential_evolution']
# One row per subject; the first entry of each per-subject dict is dropped.
X = np.array([list(fits[s].values())[1:] for s in subjects])
pca = PCA(n_components=2)
pca.fit(X)
X_transformed = pca.transform(X)
# Columns of pca.components_ follow the parameter names *after* dropping the first key,
# so print the same slice (printing all keys misaligned the names by one).
print(list(fits[subjects[-1]].keys())[1:])
print(pca.components_)
print(pca.explained_variance_ratio_)
# Fresh figure: without it the points were drawn onto the previous (e.g. MDS) axes.
fig, ax = plt.subplots()
ax.scatter(X_transformed[:, 0], X_transformed[:, 1], label=model)
for i, txt in enumerate(subjects):
    ax.annotate(str(txt), X_transformed[i])
ax.set_title(f"PCA of {model} parameters fitted to individuals")
ax.set_xlabel('PC 1')
ax.set_ylabel('PC 2')
ax.legend()

In [None]:
# Tally trials with a moving obstacle (obst_speed != 0): how many there are, how many are
# head-on (|obst_angle| == 180), and how many were kept (not in data.dump). The dump entry
# of every discarded trial is printed.
# NOTE(review): `headon` also counts dumped trials while `cnt` does not, so `cnt - headon`
# equals "kept non-head-on trials" only if no head-on trial was dumped — confirm intent.
total = cnt = headon = 0
for i in range(len(data.trajs)):
    if data.info['obst_speed'][i] != 0:
        total += 1
        if abs(data.info['obst_angle'][i]) == 180:
            headon += 1
        if i in data.dump:
            print(data.dump[i])
        else:
            cnt += 1
print(total, cnt, headon, cnt - headon)

In [None]:
# Distinct obstacle speeds present in the loaded dataset.
set(data.info['obst_speed'])

### Cohen_movObst1

In [None]:
'''Animate data'''
load_data('Cohen_movObst1_data')
############
subject = 3
trial = 111
############
%matplotlib qt
i = subject * 160 + trial - 1
i = 616
# p_obst = np.array(data.info['p_obst'][i])
t0 = data.info['obst_onset'][i]
t1 = data.info['obst_out'][i]
p_goal = data.info['p_goal'][i]
p_subj = data.info['p_subj'][i]
p_obst = data.info['p_obst'][i]
trajs = [p_goal, p_obst, p_subj]
ws = [data.info['w_goal'], data.info['w_obst'], 0.4]
title = 'subj ' + str(data.info['subj_id'][i]) + ' trial ' + str(data.info['trial_id'][i]) + ' obst_angle: ' + str(data.info['obst_angle'][i]) + ' obst_speed: ' + str(data.info['obst_speed'][i])
play_trajs(trajs, ws, data.Hz, title=title, save=False)


In [None]:
'''Time averaged traj by condition'''
'''
In within subject design, the standard error is within subject sum of squares (s_w) divided by root(2n(m-1))
n is the number of subjects m is number of conditions.
Confidence interval is +-z*SE, for 95% CI z = 1.96
'''
%matplotlib qt
for angle in set(np.abs(data.info['obst_angle'])):
    for speed in set(data.info['obst_speed']):
        fig = plt.figure()
        ax = fig.add_subplot(1,1,1)
        ax.set_title(f'obst_angle {angle}, obst_speed {speed}')
        ax.set_xlabel('x (m)')
        ax.set_ylabel('y (m)')
        ax.set_aspect('equal')
        ax.set_xlim(-1.5, 1.5)
        ax.set_ylim(-0.5, 8.5)
        front = []
        behind = []
        len_front = float('inf')
        len_behind = float('inf')
        # Plot trials
        for i in range(len(data.trajs)):
            if angle != abs(data.info['obst_angle'][i]) or speed != data.info['obst_speed'][i] or i in data.dump:
                continue
            traj = np.array(data.info['p_subj'][i])
            order = data.info['pass_order'][i]
            if data.info['obst_angle'][i] < 0:
                traj[:, 0] *= -1
                
            ax.plot(traj[:, 0], traj[:, 1], color='k', linewidth=0.1, alpha=0.5)
            m = len(traj) // 2
            if order == 1:
                front.append(traj)
                len_front = min(len(traj), len_front)
            elif order == -1:
                behind.append(traj)
                len_behind = min(len(traj), len_behind)
        # Plot average trials
        front = [traj[:len_front] for traj in front]
        behind = [traj[:len_behind] for traj in behind]
        front = np.mean(front, axis=0)
        behind = np.mean(behind, axis=0)
        ax.plot(front[:, 0], front[:, 1], color='k')
        ax.plot(behind[:, 0], behind[:, 1], color='k')

In [None]:
# Scratch inspection of trial 0: its obstacle path as an array, and (displayed) its obst_angle.
traj = np.array(data.info['p_obst'][0])
np.array(data.info['obst_angle'][0])

In [None]:
'''Space averaged traj by condition'''
%matplotlib qt
y_max = float('inf')
y_min = -float('inf')
for i in range(len(data.trajs)):
    if i in data.dump:
        continue
    traj = np.array(data.info['p_subj'][i])
    y_max = min(y_max, max(traj[:, 1]))
    y_min = max(y_min, min(traj[:, 1]))
print(f'common y range among all trials is [{y_min}, {y_max}]')

# Compute y positions
y_step = 0.01
y = np.linspace(y_min, y_max, int((y_max - y_min) / y_step))
# Interpolate x by y
iplot = 1
angles = sorted(set(np.abs(data.info['obst_angle'])))
speeds = sorted(set(data.info['obst_speed']))
fig = plt.figure()
for angle in angles:
    for speed in speeds:
        ax = fig.add_subplot(len(angles), len(speeds), iplot)
        iplot += 1
        ax.set_title(f'angle {angle}°, speed {speed} m/s')
        ax.set_xlabel('x (m)')
        ax.set_ylabel('y (m)')
        ax.set_aspect('equal')
        ax.set_xlim(-1.5, 1.5)
        ax.set_ylim(-0.2, 8)
        front = []
        behind = []
        # Plot trials
        for i in range(len(data.trajs)):
            if angle != abs(data.info['obst_angle'][i]) or speed != data.info['obst_speed'][i] or i in data.dump:
                continue
            traj = np.array(data.info['p_subj'][i])
            if data.info['obst_angle'][i] < 0:
                traj[:,0] = -traj[:,0]
            x = np.interp(y, traj[:,1], traj[:,0])
            order = data.info['pass_order'][i]
            if order == 1:
                front.append(x)
                ax.plot(traj[:, 0], traj[:, 1], color='r', linewidth=0.1, alpha=0.5)
            elif order == -1:
                behind.append(x)
                ax.plot(traj[:, 0], traj[:, 1], color='b', linewidth=0.1, alpha=0.5)
        # Plot average trials
        front = np.mean(front, axis=0)
        behind = np.mean(behind, axis=0)
        ax.plot(front, y, color='r')
        ax.plot(behind, y, color='b')

In [None]:
# Names of the fields stored in data.info.
data.info.keys()

In [None]:
'''Speed control vs heading control all trials'''
%matplotlib qt
sc = []
hc = []
for i in range(len(data.trajs)):
    if i in data.dump: continue
    t0 = data.info['obst_onset'][i]
    t1 = data.info['obst_out'][i]
    v = data.info['v_subj'][i][t0:t1]
    a = data.info['a_subj'][i][t0:t1]
    speed_ctrl = np.array([np.inner(x, y)/norm(y) for x, y in zip(a, v)]) # a along direction of v
    heading_ctrl = np.sqrt(norm(a, axis=1) ** 2 - speed_ctrl ** 2) # a perpendicular to v
    sc.extend(speed_ctrl)
    hc.extend(heading_ctrl)
    plt.plot(speed_ctrl, color='r', linewidth=0.1, alpha=0.5)
    plt.plot(heading_ctrl, color='b', linewidth=0.1, alpha=0.5)
plt.title(f'mean speed control {np.mean(sc):.2f} mean heading control {np.mean(hc):.2f} ratio {np.mean(sc) / np.mean(hc): .2f}')

In [None]:
'''Signed speed vs heading control by conditions (abs obst_angle)'''
%matplotlib qt
for angle in set(np.abs(data.info['obst_angle'])):
    for speed in set(data.info['obst_speed']):
        fig = plt.figure()
        ax = fig.add_subplot(1,1,1)
        ax.set_title(f'obst_angle {angle}, obst_speed {speed}')
        ax.set_ylabel('acceleration (m/s**2)')
        ax.set_xlabel('time (s)')
        ax.set_xlim(0, 4.5)
        ax.set_ylim(-1.5, 1.5)
        sc = []
        hc = []
        for i in range(len(data.trajs)):
            if angle != abs(data.info['obst_angle'][i]) or speed != data.info['obst_speed'][i] or i in data.dump:
                continue
            t0 = data.info['obst_onset'][i]
            t1 = data.info['obst_out'][i]
            v = data.info['v_subj'][i][t0:t1]
            a = data.info['a_subj'][i][t0:t1]
            speed_ctrl = np.array([np.inner(x, y)/norm(y) for x, y in zip(a, v)]) # a along direction of v
            heading_ctrl = np.sqrt(norm(a, axis=1) ** 2 - speed_ctrl ** 2) # a perpendicular to v
            # Check the side of heading ctrl: front (+) or behind (-) driving
            pass_side = np.sign(inner(rotate(a, pi/2), v)) * np.sign(data.info['obst_angle'][i]) * -1
            heading_ctrl *= pass_side
            sc.extend(speed_ctrl)
            hc.extend(heading_ctrl)
            plt.plot(np.linspace(0, len(speed_ctrl)/data.Hz, len(speed_ctrl)), speed_ctrl, color='r', linewidth=0.1, alpha=0.5)
            plt.plot(np.linspace(0, len(speed_ctrl)/data.Hz, len(speed_ctrl)), heading_ctrl, color='b', linewidth=0.1, alpha=0.5)
        ax.set_title(f'obst_angle {angle}, obst_speed {speed}, \n mean speed control {np.mean(np.abs(sc)):.2f} mean heading control {np.mean(np.abs(hc)):.2f} ratio {np.mean(np.abs(sc)) / np.mean(np.abs(hc)): .2f}')

In [None]:
'''Signed speed vs heading control by subject'''
# One figure per subject showing, over obst_onset -> obst_out of each kept trial, speed
# control (red; acceleration along velocity) and heading control (blue; magnitude of
# acceleration perpendicular to velocity, signed by pass_side below).
for s in sorted(set(data.info['subj_id'])):
    fig, ax = plt.subplots()
    ax.set_ylabel('acceleration (m/s**2)')
    ax.set_xlabel('time (s)')
    ax.set_xlim(0, 4.5)
    ax.set_ylim(-1.5, 1.5)
    sc = []
    hc = []
    for i in range(len(data.trajs)):
        if s != data.info['subj_id'][i] or i in data.dump:
            continue
        t0 = data.info['obst_onset'][i]
        t1 = data.info['obst_out'][i]
        v = data.info['v_subj'][i][t0:t1]
        a = data.info['a_subj'][i][t0:t1]
        if len(v) == 0:
            continue  # empty window: nothing to decompose
        v_arr = np.asarray(v, dtype=float)
        a_arr = np.asarray(a, dtype=float)
        speed_ctrl = np.einsum('ij,ij->i', a_arr, v_arr) / norm(v_arr, axis=1) # a along direction of v
        # a perpendicular to v; clipped so rounding residue cannot produce NaN in sqrt.
        heading_ctrl = np.sqrt(np.clip(norm(a_arr, axis=1) ** 2 - speed_ctrl ** 2, 0, None))
        # Check the side of heading ctrl: front (+) or behind (-) driving
        pass_side = np.sign(inner(rotate(a, pi/2), v)) * np.sign(data.info['obst_angle'][i]) * -1
        heading_ctrl *= pass_side
        sc.extend(speed_ctrl)
        hc.extend(heading_ctrl)
        # Sample k is at k / Hz seconds (linspace(0, n/Hz, n) stretched the spacing).
        t_axis = np.arange(len(speed_ctrl)) / data.Hz
        ax.plot(t_axis, speed_ctrl, color='r', linewidth=0.1, alpha=0.5)
        ax.plot(t_axis, heading_ctrl, color='b', linewidth=0.1, alpha=0.5)
    ax.set_title(f'subject {s}, \n mean speed control {np.mean(np.abs(sc)):.2f} mean heading control {np.mean(np.abs(hc)):.2f} ratio {np.mean(np.abs(sc)) / np.mean(np.abs(hc)): .2f}')

In [None]:
'''Speed/heading control ratio by condition (signed angle)'''
%matplotlib qt
conditions = np.zeros((3, 6)) # speed by angle
angle_ins = {-110: 0, -90: 1, -70: 2, 70: 3, 90: 4, 110: 5}
speed_ins = {0.4: 0, 0.6: 1, 0.8: 2}
for angle in set(data.info['obst_angle']):
    for speed in set(data.info['obst_speed']):
        sc = []
        hc = []
        for i in range(len(data.trajs)):
            if angle != data.info['obst_angle'][i] or speed != data.info['obst_speed'][i] or i in data.dump:
                continue
            t0 = data.info['obst_onset'][i]
            t1 = data.info['obst_out'][i]
            v = data.info['v_subj'][i][t0:t1]
            a = data.info['a_subj'][i][t0:t1]
            speed_ctrl = np.abs([np.inner(x, y)/norm(y) for x, y in zip(a, v)]) # a along direction of v
            heading_ctrl = np.sqrt(norm(a, axis=1) ** 2 - speed_ctrl ** 2) # a perpendicular to v
            sc.extend(speed_ctrl)
            hc.extend(heading_ctrl)
        conditions[speed_ins[speed], angle_ins[angle]] = round(np.mean(sc) / np.mean(hc), 2)

# Heatmap
plt.imshow(conditions, cmap='plasma')
for i in range(conditions.shape[0]):
    for j in range(conditions.shape[1]):
        text = plt.text(j, i, str(conditions[i, j]),
                       ha="center", va="center", color="w")
plt.xticks(range(0, 6), list(angle_ins.keys()))
plt.yticks(range(0, 3), list(speed_ins.keys()))
plt.xlabel("angle")
plt.ylabel("speed")

In [None]:
'''Speed/heading control ratio by condition (abs angle)'''
%matplotlib qt
conditions = np.zeros((3, 3)) # speed by angle
angle_ins = {70: 0, 90: 1, 110: 2}
speed_ins = {0.4: 0, 0.6: 1, 0.8: 2}
for angle in set(np.abs(data.info['obst_angle'])):
    for speed in set(data.info['obst_speed']):
        sc = []
        hc = []
        for i in range(len(data.trajs)):
            if angle != abs(data.info['obst_angle'][i]) or speed != data.info['obst_speed'][i] or i in data.dump:
                continue
            t0 = data.info['obst_onset'][i]
            t1 = data.info['obst_out'][i]
            v = data.info['v_subj'][i][t0:t1]
            a = data.info['a_subj'][i][t0:t1]
            speed_ctrl = np.abs([np.inner(x, y)/norm(y) for x, y in zip(a, v)]) # a along direction of v
            heading_ctrl = np.sqrt(norm(a, axis=1) ** 2 - speed_ctrl ** 2) # a perpendicular to v
            sc.extend(speed_ctrl)
            hc.extend(heading_ctrl)
        conditions[speed_ins[speed], angle_ins[angle]] = round(np.mean(sc) / np.mean(hc), 2)

# Heatmap
plt.imshow(conditions, cmap='plasma')
for i in range(conditions.shape[0]):
    for j in range(conditions.shape[1]):
        text = plt.text(j, i, str(conditions[i, j]),
                       ha="center", va="center", color="w")
plt.xticks(range(0, 3), list(angle_ins.keys()))
plt.yticks(range(0, 3), list(speed_ins.keys()))
plt.xlabel("angle")
plt.ylabel("speed")

In [None]:
'''effective data length'''
%matplotlib qt
l = []
sim = ODESimulator(data=data, ref=[0,1])
for i in range(len(data.trajs)):
    if i not in data.dump and data.info['obst_speed'][i] > 0 and abs(data.info['obst_angle'][i]) != 180:
        t0 = data.info['threshold_onset'][i]
        t1 = data.info['obst_out'][i]
        if t0 and t1:
            l.append((t1 - t0) / data.Hz)
plt.hist(l, bins=40)
# plt.plot(l)

In [None]:
'''% Correct pass order by dtheta threshold'''
load_data('Cohen_movObst1_data')
# For each candidate threshold, predict each kept trial's pass order as the sign of rel_dpsi
# at the first frame of [obst_onset, obst_out) where dtheta exceeds the threshold (or at the
# window's last frame if it never does), and score it against pass_order.
threshold = 0.02
thress = np.linspace(0.001, 0.03, 20)
# Snap to the nearest grid value so the exact `thres == threshold` test below can match.
threshold = thress[np.argmin(np.abs(thress - threshold))]
accuracies = []
lengths_thres = []  # at `threshold`: time from the crossing frame to obst_out (s)
lengths = []        # at `threshold`: time from obst_onset to obst_out (s)
for thres in thress:
    correct = 0
    total = 0
    for i in range(len(data.trajs)):
        if i in data.dump:
            continue
        t0 = data.info['obst_onset'][i]
        t1 = data.info['obst_out'][i]
        if t1 <= t0:
            # Empty window: the old for/else silently reused `t` from the previous trial.
            continue
        rel_dpsi = data.info['rel_dpsi'][i]
        dtheta = data.info['dtheta'][i]
        t = next((k for k in range(t0, t1) if dtheta[k] > thres), t1 - 1)
        dpsi = rel_dpsi[t]
        if thres == threshold:
            lengths.append((t1 - t0) / data.Hz)
            lengths_thres.append((t1 - t) / data.Hz)
        if np.sign(dpsi) == data.info['pass_order'][i]:
            correct += 1
        total += 1
    accuracies.append(correct / total if total else np.nan)
plt.figure()
plt.plot(thress, accuracies)
plt.title('Cohen_movObst1')
plt.ylabel('pass order accuracy')
plt.xlabel('dtheta threshold (rad/s)')
plt.figure()
plt.hist(lengths_thres)
plt.title(f'Cohen_movObst1 thres={threshold}')
plt.xlabel('effective trial length (s)')
plt.figure()
plt.hist(lengths)
plt.title('Cohen_movObst1')
plt.xlabel('effective trial length (s)')

In [None]:
'''% Correct pass order by dtheta + dpsi threshold'''
load_data('Cohen_movObst1_data')
# For each candidate threshold, predict each kept trial's pass order as the sign of rel_dpsi
# at the first frame of [obst_onset, obst_out) where dtheta + |rel_dpsi| exceeds the threshold
# (or at the window's last frame if it never does), and score it against pass_order.
threshold = 0.05
thress = np.linspace(0.001, 0.1, 20)
# Snap to the nearest grid value so the exact `thres == threshold` test below can match.
threshold = thress[np.argmin(np.abs(thress - threshold))]
accuracies = []
lengths_thres = []  # at `threshold`: time from the crossing frame to obst_out (s)
lengths = []        # at `threshold`: time from obst_onset to obst_out (s)
for thres in thress:
    correct = 0
    total = 0
    for i in range(len(data.trajs)):
        if i in data.dump:
            continue
        t0 = data.info['obst_onset'][i]
        t1 = data.info['obst_out'][i]
        if t1 <= t0:
            # Empty window: the old for/else silently reused `t` from the previous trial.
            continue
        rel_dpsi = data.info['rel_dpsi'][i]
        dtheta = data.info['dtheta'][i]
        t = next((k for k in range(t0, t1) if dtheta[k] + abs(rel_dpsi[k]) > thres), t1 - 1)
        dpsi = rel_dpsi[t]
        if thres == threshold:
            lengths.append((t1 - t0) / data.Hz)
            lengths_thres.append((t1 - t) / data.Hz)
        if np.sign(dpsi) == data.info['pass_order'][i]:
            correct += 1
        total += 1
    accuracies.append(correct / total if total else np.nan)
plt.figure()
plt.plot(thress, accuracies)
plt.title('Cohen_movObst1')
plt.ylabel('pass order accuracy')
plt.xlabel('dtheta + |dpsi| threshold (rad/s)')
plt.figure()
plt.hist(lengths_thres)
plt.title(f'Cohen_movObst1 thres={threshold}')
plt.xlabel('effective trial length (s)')
plt.figure()
plt.hist(lengths)
plt.title('Cohen_movObst1')
plt.xlabel('effective trial length (s)')

In [None]:
'''% Correct pass order by dtheta * dpsi threshold'''
load_data('Cohen_movObst1_data')
# For each candidate threshold, predict each kept trial's pass order as the sign of rel_dpsi
# at the first frame of [obst_onset, obst_out) where dtheta * |rel_dpsi| exceeds the threshold
# (or at the window's last frame if it never does), and score it against pass_order.
threshold = 0.0005
thress = np.linspace(1E-4, 2E-3, 20)
# Snap to the nearest grid value so the exact `thres == threshold` test below can match.
threshold = thress[np.argmin(np.abs(thress - threshold))]
accuracies = []
lengths_thres = []  # at `threshold`: time from the crossing frame to obst_out (s)
lengths = []        # at `threshold`: time from obst_onset to obst_out (s)
for thres in thress:
    correct = 0
    total = 0
    for i in range(len(data.trajs)):
        if i in data.dump:
            continue
        t0 = data.info['obst_onset'][i]
        t1 = data.info['obst_out'][i]
        if t1 <= t0:
            # Empty window: the old for/else silently reused `t` from the previous trial.
            continue
        rel_dpsi = data.info['rel_dpsi'][i]
        dtheta = data.info['dtheta'][i]
        t = next((k for k in range(t0, t1) if dtheta[k] * abs(rel_dpsi[k]) > thres), t1 - 1)
        dpsi = rel_dpsi[t]
        if thres == threshold:
            lengths.append((t1 - t0) / data.Hz)
            lengths_thres.append((t1 - t) / data.Hz)
        if np.sign(dpsi) == data.info['pass_order'][i]:
            correct += 1
        total += 1
    accuracies.append(correct / total if total else np.nan)
plt.figure()
plt.plot(thress, accuracies)
plt.title('Cohen_movObst1')
plt.ylabel('pass order accuracy')
# A product of two rates has units rad²/s², not rad/s.
plt.xlabel('dtheta * |dpsi| threshold (rad²/s²)')
plt.figure()
plt.hist(lengths_thres)
plt.title(f'Cohen_movObst1 thres={threshold}')
plt.xlabel('effective trial length (s)')
plt.figure()
plt.hist(lengths)
plt.title('Cohen_movObst1')
plt.xlabel('effective trial length (s)')

In [None]:
'''% Correct pass order by dtheta and dpsi'''
%matplotlib qt
load_data('Cohen_movObst1_data')
dthetas = np.linspace(0, 0.02, 10)
dpsis = np.linspace(0, 0.04, 10)
accuracies = np.zeros((len(dthetas), len(dpsis)))
tpps = np.zeros((len(dthetas), len(dpsis)))
for a, dtheta_thres in enumerate(dthetas):
    for b, dpsi_thres in enumerate(dpsis):
        correct = 0
        total = 0
        tpp = []
        for i in range(len(data.trajs)):
            if i not in data.dump:
                t0 = data.info['obst_onset'][i]
                t1 = data.info['obst_out'][i]
                rel_dpsi = data.info['rel_dpsi'][i]
                dtheta = data.info['dtheta'][i]
                for t in range(t0, t1):
                    if dtheta[t] > dtheta_thres and abs(rel_dpsi[t]) > dpsi_thres:
                        dpsi = rel_dpsi[t]
                        break
                else:
                    dpsi = rel_dpsi[t]
                tpp.append((t1 - t) / data.Hz)
                if np.sign(dpsi) == data.info['pass_order'][i]:
                    correct += 1
                total += 1
        accuracies[a, b] = round(correct / total, 2)
        tpps[a, b] = round(np.mean(tpp), 2)

# Heatmap of pass accuracy
plt.figure()
plt.imshow(accuracies, cmap='plasma')
for i in range(accuracies.shape[0]):
    for j in range(accuracies.shape[1]):
        text = plt.text(j, i, str(accuracies[i, j]),
                       ha="center", va="center", color="w")
plt.xticks(range(0, accuracies.shape[0]), np.round(dthetas, 3))
plt.yticks(range(0, accuracies.shape[1]), np.round(dpsis, 3))
plt.xlabel("dtheta (rad/s)")
plt.ylabel("dpsi (rad/s)")
plt.title('Pass order accuracy of AND threshold')

# Heatmap of trial length loss
plt.figure()
plt.imshow(tpps, cmap='plasma')
for i in range(tpps.shape[0]):
    for j in range(tpps.shape[1]):
        text = plt.text(j, i, str(tpps[i, j]),
                       ha="center", va="center", color="w")
plt.xticks(range(0, tpps.shape[0]), np.round(dthetas, 3))
plt.yticks(range(0, tpps.shape[1]), np.round(dpsis, 3))
plt.xlabel("dtheta (rad/s)")
plt.ylabel("dpsi (rad/s)")
plt.title('Time before passing using AND threshold')
        

### Cohen_movObst2

In [None]:
'''Plot speed'''
# Speed profile of every kept trial between obst_onset and obst_out.
# Fresh figure: pyplot would otherwise draw onto whichever figure was last active.
fig, ax = plt.subplots()
for i in range(len(data.trajs)):
    t0, t1 = data.info['obst_onset'][i], data.info['obst_out'][i]
    if not (t1 and t0) or i in data.dump:
        continue
    s_subj = norm(data.info['v_subj'][i], axis=1)
    segment = s_subj[t0:t1]
    ax.plot(np.arange(len(segment)) / data.Hz, segment, linewidth=0.1, alpha=0.5)
ax.set_xlabel('time since obst_onset (s)')
ax.set_ylabel('speed (m/s)')

In [None]:
'''Animate data'''
############
subject = 3
trial = 111
############
%matplotlib qt
i = subject * 160 + trial - 1
i = 252
# p_obst = np.array(data.info['p_obst'][i])
t0 = data.info['obst_onset'][i]
t1 = data.info['obst_out'][i]
p_goal = data.info['p_goal'][i]
p_subj = data.info['p_subj'][i]
p_obst = data.info['p_obst'][i]
trajs = [p_goal, p_obst, p_subj]
ws = [data.info['w_goal'], data.info['w_obst'], 0.4]
title = 'subj ' + str(data.info['subj_id'][i]) + ' trial ' + str(data.info['trial_id'][i]) + ' obst_angle: ' + str(data.info['obst_angle'][i]) + ' obst_dist: ' + str(data.info['obst_dist'][i])
play_trajs(trajs, ws, data.Hz, title=title, save=False)


In [None]:
# Names of the fields stored in data.info.
data.info.keys()

In [None]:
'''Space averaged traj by condition'''
%matplotlib qt
y_max = float('inf')
y_min = -float('inf')
for i in range(len(data.trajs)):
    if i in data.dump:
        continue
    traj = np.array(data.info['p_subj'][i])
    y_max = min(y_max, max(traj[:, 1]))
    y_min = max(y_min, min(traj[:, 1]))
print(f'common y range among all trials is [{y_min}, {y_max}]')

# Compute y positions
y_step = 0.01
y = np.linspace(y_min, y_max, int((y_max - y_min) / y_step))
# Interpolate x by y
iplot = 1
angles = sorted(set([abs(x) for x in data.info['obst_angle'] if x]))
dists = sorted(set([x for x in data.info['obst_dist'] if x]))
fig = plt.figure()
for dist in dists:
    for angle in angles:
        ax = fig.add_subplot(len(dists), len(angles), iplot)
        iplot += 1
        ax.set_title(f'angle {angle}°, distance {dist} m')
        ax.set_xlabel('x (m)')
        ax.set_ylabel('y (m)')
        ax.set_aspect('equal')
        ax.set_xlim(-2.2, 2.2)
        ax.set_ylim(-0.2, 8.2)
        front = []
        behind = []
        # Plot trials
        for i in range(len(data.trajs)):
            if not data.info['obst_angle'][i] or angle != abs(data.info['obst_angle'][i]) or dist != data.info['obst_dist'][i] or i in data.dump:
                continue
            traj = np.array(data.info['p_subj'][i])
            if data.info['obst_angle'][i] < 0:
                traj[:,0] = -traj[:,0]
            x = np.interp(y, traj[:,1], traj[:,0])
            order = data.info['pass_order'][i]
            if order == 1:
                front.append(x)
                ax.plot(traj[:, 0], traj[:, 1], color='r', linewidth=0.1, alpha=0.5)
            elif order == -1:
                behind.append(x)
                ax.plot(traj[:, 0], traj[:, 1], color='b', linewidth=0.1, alpha=0.5)
        # Plot average trials
        front = np.mean(front, axis=0)
        behind = np.mean(behind, axis=0)
        ax.plot(front, y, color='r')
        ax.plot(behind, y, color='b')

In [None]:
'''% Correct pass order by dtheta threshold'''
load_data('Cohen_movObst2_data')
# For each candidate threshold, predict the pass order of each kept trial with a moving
# obstacle (truthy obst_speed) as the sign of rel_dpsi at the first frame of
# [obst_onset, obst_out) where dtheta exceeds the threshold (or at the window's last frame
# if it never does), and score it against pass_order.
threshold = 0.02
thress = np.linspace(0.001, 0.03, 20)
# Snap to the nearest grid value so the exact `thres == threshold` test below can match.
threshold = thress[np.argmin(np.abs(thress - threshold))]
accuracies = []
lengths_thres = []  # at `threshold`: time from the crossing frame to obst_out (s)
lengths = []        # at `threshold`: time from obst_onset to obst_out (s)
for thres in thress:
    correct = 0
    total = 0
    for i in range(len(data.trajs)):
        if i in data.dump or not data.info['obst_speed'][i]:
            continue
        t0 = data.info['obst_onset'][i]
        t1 = data.info['obst_out'][i]
        if t1 <= t0:
            # Empty window: the old for/else silently reused `t` from the previous trial.
            continue
        rel_dpsi = data.info['rel_dpsi'][i]
        dtheta = data.info['dtheta'][i]
        t = next((k for k in range(t0, t1) if dtheta[k] > thres), t1 - 1)
        dpsi = rel_dpsi[t]
        if thres == threshold:
            lengths.append((t1 - t0) / data.Hz)
            lengths_thres.append((t1 - t) / data.Hz)
        if np.sign(dpsi) == data.info['pass_order'][i]:
            correct += 1
        total += 1
    accuracies.append(correct / total if total else np.nan)
plt.figure()
plt.plot(thress, accuracies)
plt.title('Cohen_movObst2')
plt.ylabel('pass order accuracy')
plt.xlabel('dtheta threshold (rad/s)')
plt.figure()
plt.hist(lengths_thres)
plt.title(f'Cohen_movObst2 thres={threshold}')
plt.xlabel('effective trial length (s)')
plt.figure()
plt.hist(lengths)
plt.title('Cohen_movObst2')
plt.xlabel('effective trial length (s)')

In [None]:
'''% Correct pass order by dtheta + dpsi threshold'''
load_data('Cohen_movObst2_data')
# For each candidate threshold, predict the pass order of each kept trial with a moving
# obstacle (truthy obst_speed) as the sign of rel_dpsi at the first frame of
# [obst_onset, obst_out) where dtheta + |rel_dpsi| exceeds the threshold (or at the window's
# last frame if it never does), and score it against pass_order.
threshold = 0.05
thress = np.linspace(0.001, 0.1, 20)
# Snap to the nearest grid value so the exact `thres == threshold` test below can match.
threshold = thress[np.argmin(np.abs(thress - threshold))]
accuracies = []
lengths_thres = []  # at `threshold`: time from the crossing frame to obst_out (s)
lengths = []        # at `threshold`: time from obst_onset to obst_out (s)
for thres in thress:
    correct = 0
    total = 0
    for i in range(len(data.trajs)):
        if i in data.dump or not data.info['obst_speed'][i]:
            continue
        t0 = data.info['obst_onset'][i]
        t1 = data.info['obst_out'][i]
        if t1 <= t0:
            # Empty window: the old for/else silently reused `t` from the previous trial.
            continue
        rel_dpsi = data.info['rel_dpsi'][i]
        dtheta = data.info['dtheta'][i]
        t = next((k for k in range(t0, t1) if dtheta[k] + abs(rel_dpsi[k]) > thres), t1 - 1)
        dpsi = rel_dpsi[t]
        if thres == threshold:
            lengths.append((t1 - t0) / data.Hz)
            lengths_thres.append((t1 - t) / data.Hz)
        if np.sign(dpsi) == data.info['pass_order'][i]:
            correct += 1
        total += 1
    accuracies.append(correct / total if total else np.nan)
plt.figure()
plt.plot(thress, accuracies)
plt.title('Cohen_movObst2')
plt.ylabel('pass order accuracy')
# The swept quantity is dtheta + |dpsi| (this label previously said only "dtheta").
plt.xlabel('dtheta + |dpsi| threshold (rad/s)')
plt.figure()
plt.hist(lengths_thres)
plt.title(f'Cohen_movObst2 thres={threshold}')
plt.xlabel('effective trial length (s)')
plt.figure()
plt.hist(lengths)
plt.title('Cohen_movObst2')
plt.xlabel('effective trial length (s)')

In [None]:
'''% Correct pass order by dtheta * dpsi threshold'''
load_data('Cohen_movObst2_data')
# For each candidate threshold, predict the pass order of each kept trial with a moving
# obstacle (truthy obst_speed) as the sign of rel_dpsi at the first frame of
# [obst_onset, obst_out) where dtheta * |rel_dpsi| exceeds the threshold (or at the window's
# last frame if it never does), and score it against pass_order.
threshold = 0.0005
thress = np.linspace(1E-4, 2E-3, 20)
# Snap to the nearest grid value so the exact `thres == threshold` test below can match.
threshold = thress[np.argmin(np.abs(thress - threshold))]
accuracies = []
lengths_thres = []  # at `threshold`: time from the crossing frame to obst_out (s)
lengths = []        # at `threshold`: time from obst_onset to obst_out (s)
for thres in thress:
    correct = 0
    total = 0
    for i in range(len(data.trajs)):
        if i in data.dump or not data.info['obst_speed'][i]:
            continue
        t0 = data.info['obst_onset'][i]
        t1 = data.info['obst_out'][i]
        if t1 <= t0:
            # Empty window: the old for/else silently reused `t` from the previous trial.
            continue
        rel_dpsi = data.info['rel_dpsi'][i]
        dtheta = data.info['dtheta'][i]
        t = next((k for k in range(t0, t1) if dtheta[k] * abs(rel_dpsi[k]) > thres), t1 - 1)
        dpsi = rel_dpsi[t]
        if thres == threshold:
            lengths.append((t1 - t0) / data.Hz)
            lengths_thres.append((t1 - t) / data.Hz)
        if np.sign(dpsi) == data.info['pass_order'][i]:
            correct += 1
        total += 1
    accuracies.append(correct / total if total else np.nan)
plt.figure()
plt.plot(thress, accuracies)
plt.title('Cohen_movObst2')
plt.ylabel('pass order accuracy')
# A product of two rates has units rad²/s², not rad/s.
plt.xlabel('dtheta * |dpsi| threshold (rad²/s²)')
plt.figure()
plt.hist(lengths_thres)
plt.title(f'Cohen_movObst2 thres={threshold}')
plt.xlabel('effective trial length (s)')
plt.figure()
plt.hist(lengths)
plt.title('Cohen_movObst2')
plt.xlabel('effective trial length (s)')

In [None]:
'''% Correct pass order by dtheta and dpsi'''
%matplotlib qt
load_data('Cohen_movObst2_data')
dthetas = np.linspace(0, 0.02, 10)
dpsis = np.linspace(0, 0.04, 10)
accuracies = np.zeros((len(dthetas), len(dpsis)))
tpps = np.zeros((len(dthetas), len(dpsis)))
for a, dtheta_thres in enumerate(dthetas):
    for b, dpsi_thres in enumerate(dpsis):
        correct = 0
        total = 0
        tpp = []
        for i in range(len(data.trajs)):
            if i not in data.dump and data.info['obst_speed'][i]:
                t0 = data.info['obst_onset'][i]
                t1 = data.info['obst_out'][i]
                rel_dpsi = data.info['rel_dpsi'][i]
                dtheta = data.info['dtheta'][i]
                for t in range(t0, t1):
                    if dtheta[t] > dtheta_thres and abs(rel_dpsi[t]) > dpsi_thres:
                        dpsi = rel_dpsi[t]
                        break
                else:
                    dpsi = rel_dpsi[t]
                tpp.append((t1 - t) / data.Hz)
                if np.sign(dpsi) == data.info['pass_order'][i]:
                    correct += 1
                total += 1
        accuracies[a, b] = round(correct / total, 2)
        tpps[a, b] = round(np.mean(tpp), 2)

# Heatmap of pass accuracy
plt.figure()
plt.imshow(accuracies, cmap='plasma')
for i in range(accuracies.shape[0]):
    for j in range(accuracies.shape[1]):
        text = plt.text(j, i, str(accuracies[i, j]),
                       ha="center", va="center", color="w")
plt.xticks(range(0, accuracies.shape[0]), np.round(dthetas, 3))
plt.yticks(range(0, accuracies.shape[1]), np.round(dpsis, 3))
plt.xlabel("dtheta (rad/s)")
plt.ylabel("dpsi (rad/s)")
plt.title('Pass order accuracy of AND threshold')

# Heatmap of trial length loss
plt.figure()
plt.imshow(tpps, cmap='plasma')
for i in range(tpps.shape[0]):
    for j in range(tpps.shape[1]):
        text = plt.text(j, i, str(tpps[i, j]),
                       ha="center", va="center", color="w")
plt.xticks(range(0, tpps.shape[0]), np.round(dthetas, 3))
plt.yticks(range(0, tpps.shape[1]), np.round(dpsis, 3))
plt.xlabel("dtheta (rad/s)")
plt.ylabel("dpsi (rad/s)")
plt.title('Time before passing using AND threshold')
        

### Bai_movObst1

In [None]:
'''trial lookup'''
i = 167
t0, t1 = data.info['obst_onset'][i], data.info['obst_out'][i]
# Index with `i` (previously a repeated literal 167, so changing `i` silently looked up the
# wrong trial). NOTE(review): t1 + 1 raises IndexError if obst_out is within one frame of
# the end of the recording.
data.info['rel_dpsi'][i][t1+1]

In [None]:
'''Counting trials'''
# Number of kept (not dumped) trials whose obstacle moves (obst_speed > 0) and is not
# head-on (|obst_angle| != 180); shown as the cell output.
cnt = sum(
    1
    for i in range(len(data.trajs))
    if i not in data.dump
    and data.info['obst_speed'][i] > 0
    and abs(data.info['obst_angle'][i]) != 180
)
cnt

In [None]:
'''effective data length'''
%matplotlib qt
l = []
sim = ODESimulator(data=data, ref=[0,1])
for i in range(len(data.trajs)):
    if i not in data.dump and data.info['obst_speed'][i] > 0 and abs(data.info['obst_angle'][i]) != 180:
#         t0 = data.info['obst_onset'][i]
#         t0 = data.info['decision_onset'][i]
        t0 = data.info['threshold_onset'][i]
        t1 = data.info['obst_out'][i]
        if t0 and t1:
            l.append((t1 - t0) / data.Hz)
#             if t1 - t0 < 2:
#                 print(i, t0, t1)
#                 order = []
#                 for t in range(t0, t1):
#                     xg, yg, xo, yo, vxo, vyo, x, y, vx, vy, a, phi, s, dphi, ds, w0 = sim.compute_var0(i, t)
#                     # When beta and dpsi has the same sign it means pass in front, otherwise it means pass from behind
#                     dpsi = d_psi([x, y], [xo, yo], [vx, vy], [vxo, vyo])
#                     angle = sim.data.info['obst_angle'][i]
#                     order.append(-dpsi * np.sign(angle))
#                 plt.plot(order)
#                 break
plt.hist(l, bins=40)
# plt.plot(l)

In [None]:
'''Animate a trial'''
############
subject = 3
trial = 111
############
%matplotlib qt
i = subject * 160 + trial - 1
i = 1
# p_obst = np.array(data.info['p_obst'][i])
td = data.info['decision_onset'][i]
t0 = data.info['obst_onset'][i]
t1 = data.info['obst_out'][i]
p_goal = data.info['p_goal'][i]
p_subj = data.info['p_subj'][i]
p_obst = data.info['p_obst'][i]
trajs = [p_goal, p_obst, p_subj]
ws = [data.info['w_goal'], data.info['w_obst'], 0.4]
title = 'subj ' + str(data.info['subj_id'][i]) + ' trial ' + str(data.info['trial_id'][i]) + ' obst_angle: ' + str(data.info['obst_angle'][i]) + ' obst_speed: ' + str(data.info['obst_speed'][i])
print(data.info['subj_id'][i], data.info['trial_id'][i], td, t1)
for t in range(t0, t1):
    if data.info['dtheta'][i][t] > 0.02:
        print(f'dtheta = {data.info["dtheta"][i][t]} at {t/data.Hz} seconds')
        break
play_trajs(trajs, ws, data.Hz, title=title, save=False)


### Bai_movObst1: Plot trajectories

In [None]:
'''Plot data by condition'''
#####################
subject = 3
con_ang = [112.5]
con_spd = [1.1]
con_ipd = [0, 0.07]
# con_ang = set(data.info['obst_angle'])
# con_spd = set(data.info['obst_speed'])
#####################
%matplotlib qt
plt.figure()
n = 0
for i in range(len(data.trajs)):
    obst_speed = data.info['obst_speed'][i]
    obst_angle = data.info['obst_angle'][i]
    ipd = data.info['ipd'][i]
    subj_id = data.info['subj_id'][i]
    if subject != -1 and subj_id != subject:
        continue
    if not (obst_speed in con_spd and abs(obst_angle) in con_ang and ipd in con_ipd):
        continue
    subj = data.info['p_subj'][i]
    obst = np.array(data.info['p_obst'][i])
    n += 1
    if i % 2 == 0:
        subj = rotate(subj, np.arctan(11 / 9) - pi / 2)
        obst = rotate(obst, np.arctan(11 / 9) - pi / 2)
    else:
        subj = rotate(subj, np.arctan(11 / 9) + pi / 2)
        obst = rotate(obst, np.arctan(11 / 9) + pi / 2)
    if obst_angle < 0:
        subj[:, 0] *= -1
        obst[:, 0] *= -1
    if ipd == 0:
        plt.plot(subj[:, 0], subj[:, 1], color='r')
    else:
        plt.plot(subj[:, 0], subj[:, 1], color='b')
    plt.plot(obst[:, 0], obst[:, 1])
ax = plt.gca()
# ax.set_aspect('equal')
ax.set_title('subj ' + str(subject) + ' angle: ' + str(con_ang[0]) + ' speed: ' + str(con_spd[0]))
print(n)

In [None]:
'''Plot data by subject'''
#####################
subject = 13
#####################
%matplotlib qt
fig = plt.figure()
fig.suptitle('Subject ' + str(subject))
axes = {}
obst_angle = [90, 112.5, 135, 157.5, 180]
obst_speed = [0.9, 1.0, 1.1, 1.2, 1.3]
i_plot = 1
for angle in obst_angle:
    for speed in obst_speed:
        axes[(angle, speed)] = fig.add_subplot(5, 5, i_plot)
        axes[(angle, speed)].set_xlim(-3, 3)
        axes[(angle, speed)].set_ylim(-7, 5)
        axes[(angle, speed)].set_title(str(angle) + '° ' + str(speed) + 'm/s')
        axes[(angle, speed)].set_aspect('equal')
        i_plot += 1
for i in range(len(data.trajs)):
    speed = data.info['obst_speed'][i]
    angle = data.info['obst_angle'][i]
    subj_id = data.info['subj_id'][i]
    if subj_id != subject or speed == 0:
        continue
    subj = np.array(data.info['p_subj'][i])
    obst = np.array(data.info['p_obst'][i])
    if i % 2 == 0:
        subj = rotate(subj, np.arctan(11 / 9) - pi / 2)
        obst = rotate(obst, np.arctan(11 / 9) - pi / 2)
    else:
        subj = rotate(subj, np.arctan(11 / 9) + pi / 2)
        obst = rotate(obst, np.arctan(11 / 9) + pi / 2)
    if angle < 0:
        subj[:, 0] *= -1
        obst[:, 0] *= -1
    axes[(abs(angle), speed)].plot(subj[:, 0], subj[:, 1])
    axes[(abs(angle), speed)].plot(obst[:, 0], obst[:, 1])

In [14]:
'''Plot obst trajectories'''
load_data('Bai_movObst1_data_30Hz')
# Draw the first 10 s of one obstacle path per (signed angle, speed) condition,
# rotated into a common frame; 'x' marks the start of each path and 'o' the
# participant. `visited` (shown at the end) lists the conditions drawn.
plt.figure()  # don't draw over whatever figure a previous cell left active
visited = set()
n_frames = int(data.Hz * 10)  # slice bound must be an int even if Hz is stored as float
for i in range(len(data.trajs)):
    angle = data.info['obst_angle'][i]
    speed = data.info['obst_speed'][i]
    if speed != 0 and (angle, speed) not in visited:
        visited.add((angle, speed))
        obst = data.info['p_obst'][i]
        # Rotation direction chosen by the sign of data.trajs[i][0][0]
        # (presumably the x of the trial's first sample -- confirm).
        if data.trajs[i][0][0] > 0:
            obst = rotate(obst, np.arctan(11 / 9) + pi / 2)
        else:
            obst = rotate(obst, np.arctan(11 / 9) - pi / 2)
        plt.plot(obst[:n_frames, 0], obst[:n_frames, 1], 'k')
        plt.scatter(obst[0, 0], obst[0, 1], c='k', marker='x')
plt.scatter(0, -7.1, c='k', marker='o')
plt.annotate("Participant", (0, -12.1))
plt.gca().set_aspect('equal')
plt.xlabel("x (m)")
plt.ylabel("y (m)")
visited

{(-180.0, 0.9),
 (-180.0, 1.0),
 (-180.0, 1.1),
 (-180.0, 1.2),
 (-180.0, 1.3),
 (-157.5, 0.9),
 (-157.5, 1.0),
 (-157.5, 1.1),
 (-157.5, 1.2),
 (-157.5, 1.3),
 (-135.0, 0.9),
 (-135.0, 1.0),
 (-135.0, 1.1),
 (-135.0, 1.2),
 (-135.0, 1.3),
 (-112.5, 0.9),
 (-112.5, 1.0),
 (-112.5, 1.1),
 (-112.5, 1.2),
 (-112.5, 1.3),
 (-90.0, 0.9),
 (-90.0, 1.0),
 (-90.0, 1.1),
 (-90.0, 1.2),
 (-90.0, 1.3),
 (90.0, 0.9),
 (90.0, 1.0),
 (90.0, 1.1),
 (90.0, 1.2),
 (90.0, 1.3),
 (112.5, 0.9),
 (112.5, 1.0),
 (112.5, 1.1),
 (112.5, 1.2),
 (112.5, 1.3),
 (135.0, 0.9),
 (135.0, 1.0),
 (135.0, 1.1),
 (135.0, 1.2),
 (135.0, 1.3),
 (157.5, 0.9),
 (157.5, 1.0),
 (157.5, 1.1),
 (157.5, 1.2),
 (157.5, 1.3),
 (180.0, 0.9),
 (180.0, 1.0),
 (180.0, 1.1),
 (180.0, 1.2),
 (180.0, 1.3)}

In [None]:
'''Space averaged traj by condition'''
%matplotlib qt
y_max = float('inf')
for i in range(len(data.trajs)):
    if angle != abs(data.info['obst_angle'][i]) or speed != data.info['obst_speed'][i] or i in data.dump:
        continue
    traj = np.array(data.info['p_subj'][i])
    traj = -traj if i % 2 == 1 else traj
    traj = rotate(traj, np.arctan(11 / 9) - pi / 2) - traj[0, 0]
    y_max = min(y_max, max(traj[:, 1]))
    print(i, y_max)
print(f'minimum y value among all trials is {y_max}')

# Compute y positions
y_step = 0.01
y = np.linspace(0, y_max, int(y_max / y_step))
# Interpolate x by y
for angle in set(np.abs(data.info['obst_angle'])):
    for speed in set(data.info['obst_speed']):
        fig = plt.figure()
        ax = fig.add_subplot(1,1,1)
        ax.set_title(f'obst_angle {angle}, obst_speed {speed}')
        ax.set_xlabel('x (m)')
        ax.set_ylabel('y (m)')
        ax.set_aspect('equal')
        ax.set_xlim(-1.5, 1.5)
        ax.set_ylim(-0.5, 8.5)
        front = []
        behind = []
        # Plot trials
        for i in range(len(data.trajs)):
            if angle != abs(data.info['obst_angle'][i]) or speed != data.info['obst_speed'][i] or i in data.dump:
                continue
            traj = np.array(data.info['p_subj'][i])
            if traj[0, 0] < 0:
                traj = rotate(traj, np.arctan(11 / 9) - pi / 2) - traj[0, 0]
            else:
                traj = rotate(traj, np.arctan(11 / 9) + pi / 2) - traj[0, 0]
            x = np.interp(y, traj[:,1], traj[:,0])
            if order == 1:
                front.append(x)
            elif order == -1:
                behind.append(x)
        # Plot average trials
        front = np.mean(front, axis=0)
        behind = np.mean(behind, axis=0)
        ax.plot(front, y, color='k')
        ax.plot(behind, y, color='k')

In [None]:
'''Space averaged traj by condition'''
%matplotlib qt
load_data('Bai_movObst1_data_30Hz')
y_max = float('inf')
y_min = -float('inf')
for i in range(len(data.trajs)):
    if i in data.dump or data.info['obst_speed'][i] == 0:
        continue
    traj = np.array(data.info['p_subj'][i])
    traj = -traj if i % 2 == 1 else traj
    traj = traj - traj[0]
    traj = rotate(traj, np.arctan(11 / 9) - pi / 2)
    y_max = min(y_max, max(traj[:, 1]))
    y_min = max(y_min, min(traj[:, 1]))
print(f'common y range among all trials is [{y_min}, {y_max}]')

# Compute y positions
y_step = 0.01
y = np.linspace(y_min, y_max, int((y_max - y_min) / y_step))
# Interpolate x by y
iplot = 1
angles = sorted(set([abs(x) for x in data.info['obst_angle'] if x]))
speeds = sorted(set([abs(x) for x in data.info['obst_speed'] if x]))
print(angles, speeds)
fig = plt.figure()
for angle in angles:
    for speed in speeds:
        ax = fig.add_subplot(len(angles), len(speeds), iplot)
        iplot += 1
        ax.set_title(f'{angle}°, {speed} m/s')
        if angle == 180:
            ax.set_xlabel('x (m)')
        if speed == 0.9:
            ax.set_ylabel('y (m)')
        ax.set_aspect('equal')
        ax.set_xlim(-2, 2)
        ax.set_ylim(-0.2, 12)
        front = []
        behind = []
        # Plot trials
        for i in range(len(data.trajs)):
            if angle != abs(data.info['obst_angle'][i]) or speed != data.info['obst_speed'][i] or i in data.dump:
                continue
            traj = np.array(data.info['p_subj'][i])
            traj = -traj if i % 2 == 1 else traj
            traj = traj - traj[0]
            traj = rotate(traj, np.arctan(11 / 9) - pi / 2)
            if data.info['obst_angle'][i] < 0:
                traj[:,0] = -traj[:,0]
            x = np.interp(y, traj[:,1], traj[:,0])
            order = data.info['pass_order'][i]
            if order == 1:
                front.append(x)
                ax.plot(traj[:, 0], traj[:, 1], color='r', linewidth=0.1, alpha=0.5)
            elif order == -1:
                behind.append(x)
                ax.plot(traj[:, 0], traj[:, 1], color='b', linewidth=0.1, alpha=0.5)
        # Plot average trials
        front = np.mean(front, axis=0)
        behind = np.mean(behind, axis=0)
        ax.plot(front, y, color='r')
        ax.plot(behind, y, color='b')

### Bai_movObst1: dpsi and order of passing

In [None]:
'''Compute initial passing order (-dpsi * sign(obst_angle)) time series'''
%matplotlib qt
orders = {}
conditions = np.zeros((5, 8))
angle_ins = {-157.5: 0, -135: 1, -112.5: 2, -90: 3, 90: 4, 112.5: 5, 135: 6, 157.5: 7}
speed_ins = {0.9: 0, 1: 1, 1.1: 2, 1.2: 3, 1.3: 4}
for i in range(len(data.trajs)):
    if i not in data.dump and data.info['obst_speed'][i] != 0 and abs(data.info['obst_angle'][i]) != 180:
        rel_dpsi = data.info['rel_dpsi'][i]
        subj = data.info['subj_id'][i]
        angle = data.info['obst_angle'][i]
        speed = data.info['obst_speed'][i]
        t0, t1 = data.info['obst_onset'][i], data.info['obst_out'][i]
        if subj not in orders:
            orders[subj] = {'correct': [], 'incorrect': []}
        t0, t1 = data.info['obst_onset'][i], data.info['obst_out'][i]
        if np.sign(rel_dpsi[t0]) == data.info['pass_order'][i]:
            orders[subj]['correct'].append(rel_dpsi)
        else:
            orders[subj]['incorrect'].append(rel_dpsi)
            conditions[speed_ins[speed], angle_ins[angle]] += 1
# Correct order ratio
correct = incorrect = 0
for order in orders.values():
    correct += len(order['correct'])
    incorrect += len(order['incorrect'])
print(correct, incorrect, correct / (correct + incorrect))

In [None]:
'''Plot conditions of incorrect trial'''
# Heatmap of how many trials per (speed, angle) condition had an initial
# rel_dpsi sign that disagreed with the actual pass order. Uses `conditions`,
# `angle_ins` and `speed_ins` built by the "Compute initial passing order" cell.
plt.figure()  # don't draw on top of a previous cell's figure
plt.imshow(conditions, cmap='plasma')
for i in range(conditions.shape[0]):
    for j in range(conditions.shape[1]):
        # Counts are stored in a float array; show them as integers.
        plt.text(j, i, str(int(conditions[i, j])),
                 ha="center", va="center", color="w")
# One tick per cell centre (0..n-1). Ticks placed outside the image would make
# matplotlib widen the axes with empty margins.
plt.xticks(range(len(angle_ins)), list(angle_ins.keys()))
plt.yticks(range(len(speed_ins)), list(speed_ins.keys()))
plt.xlabel("angle")
plt.ylabel("speed")
plt.title("Trials whose initial dpsi sign mismatched the pass order");

In [None]:
'''Check initial dpsi and subject passing choice, plot dpsi of matching and non-matching trials'''
%matplotlib qt
subjects = set(data.info['subj_id'])
print(subjects)
side_pred = []
side_true = []
dpsi_match = []
dpsi_not = []
for i in range(len(sim.data.trajs)):
    if i not in data.dump and data.info['obst_speed'][i] != 0 and abs(data.info['obst_angle'][i]) != 180:
        # When angle and dpsi has the opposite sign it means pass in front, otherwise it means pass from behind
        t0 = data.info['obst_onset'][i]
        rel_dpsi = data.info['rel_dpsi'][i][t0]
        side_pred.append(1 if rel_dpsi > 0 else -1)
        side_true.append(sim.data.info['pass_order'][i])
        if side_pred[-1] == side_true[-1]:
            dpsi_match.append(rel_dpsi)
        else:
            dpsi_not.append(rel_dpsi)
print('passing order matching rate ', accuracy_score(side_true, side_pred))
thres = 0.02
print(f'{np.mean(np.absolute(np.array(dpsi_not)) < thres) * 100}% mis-matching trials have an intial dpsi < {thres}')
plt.scatter(np.random.uniform(size=len(dpsi_match)), dpsi_match, label='matching trials')
plt.scatter(np.random.uniform(size=len(dpsi_not)), dpsi_not, label='mis-matching trials')
plt.ylabel('dpsi')
plt.xlabel('arbitrary')
plt.legend()


In [None]:
'''Plot relative dpsi * sign(obst_angle) time series with avoidance onset points'''
# rel_dpsi traces from obstacle onset to obst_out (green: sign at onset matches
# the pass order, red: it does not). Black dot: frame of minimum |rel_dpsi|.
# Red dot: frame of maximum |rel_dpsi| up to that minimum. dtheta and
# |rel_dpsi * dtheta| at the red-dot frame are histogrammed afterwards.
plt.figure()  # fresh figure, not whatever a previous cell left active
dtheta_thres = []
product_thres = []
for i in range(len(data.trajs)):
    if i not in data.dump and data.info['obst_speed'][i] > 0 and abs(data.info['obst_angle'][i]) != 180:
        t0 = data.info['obst_onset'][i]
        t1 = data.info['obst_out'][i]
        rel_dpsi = data.info['rel_dpsi'][i][t0:t1]
        dtheta = data.info['dtheta'][i][t0:t1]
        # Plot minimum |dpsi|
        mi = np.argmin(np.absolute(rel_dpsi))
        plt.scatter(mi, rel_dpsi[mi], c='k')
        # Plot maximum |dpsi| from time 0 to t_minimum|dpsi|
        ma = np.argmax(np.absolute(rel_dpsi[:mi + 1]))
        plt.scatter(ma, rel_dpsi[ma], c='r')
        dtheta_thres.append(dtheta[ma])
        product_thres.append(abs(rel_dpsi[ma] * dtheta[ma]))
        color = 'g' if np.sign(rel_dpsi[0]) == data.info['pass_order'][i] else 'r'
        plt.plot(range(len(rel_dpsi)), rel_dpsi, linewidth=0.1, alpha=0.5, color=color)
plt.xlabel("Time since obst appear", fontsize=20)
plt.ylabel("dpsi (rad/s), positive means front", fontsize=20)
plt.figure()
plt.hist(dtheta_thres, bins=40)
plt.figure()
plt.hist(product_thres, bins=40)

In [None]:
'''Plot dtheta time series'''
%matplotlib qt
load_data('Bai_movObst1_data_30Hz')
for i in range(len(data.trajs)):
    if i not in data.dump and data.info['obst_speed'][i] > 0 and abs(data.info['obst_angle'][i]) != 180:
        t0, t1 = data.info['obst_onset'][i], data.info['obst_out'][i]
        dtheta = data.info['dtheta'][i][t0:t1]
        if data.info['pass_order'][i] > 0:
            plt.plot(range(len(dtheta)), dtheta, linewidth=0.1, alpha=0.5, color='g')
        else:
            plt.plot(range(len(dtheta)), dtheta, linewidth=0.1, alpha=0.5, color='r')
        for t in range(t1 - t0):
            if dtheta[t] > 0.02:
                plt.scatter(t, dtheta[t])
                break
plt.xlabel("Time since obst appear", fontsize=20)
plt.ylabel("dtheta (rad/s)", fontsize=20)

In [None]:
'''Plot dpsi * dthetha * sign(obst_angle)) time series'''
%matplotlib qt
for i in range(len(data.trajs)):
    if i not in data.dump and data.info['obst_speed'][i] > 0 and abs(data.info['obst_angle'][i]) != 180:
        t0, t1 = data.info['obst_onset'][i], data.info['obst_out'][i]
        rel_dpsi = data.info['rel_dpsi'][i][t0:t1]
        dtheta = data.info['dtheta'][i][t0:t1]
        if data.info['pass_order'][i] > 0:
            plt.plot(range(len(rel_dpsi)), rel_dpsi * dtheta, linewidth=0.1, alpha=0.5, color='g')
        else:
            plt.plot(range(len(rel_dpsi)), rel_dpsi * dtheta, linewidth=0.1, alpha=0.5, color='r')
plt.xlabel("Time since obst appear", fontsize=20)
plt.ylabel("dpsi*dtheta), positive means front", fontsize=20)

In [None]:
'''% Correct pass order by dpsi threshold'''
# For each usable trial, take the largest |rel_dpsi| between obstacle onset and
# the frame where |rel_dpsi| is smallest, then plot the empirical CDF of these
# values: fraction of trials whose value is <= the threshold on the x-axis.
thress = []
for i in range(len(data.trajs)):
    if i not in data.dump and data.info['obst_speed'][i] > 0 and abs(data.info['obst_angle'][i]) != 180:
        t0 = data.info['obst_onset'][i]
        t1 = data.info['obst_out'][i]
        rel_dpsi = data.info['rel_dpsi'][i]
        # Find minimum |dpsi|
        mi = np.argmin(np.absolute(rel_dpsi[t0:t1]))
        # Find maximum |dpsi| from time 0 to t_minimum|dpsi|
        thres = np.max(np.absolute(rel_dpsi[t0:t0 + mi + 1]))
        thress.append(thres)
# Exact ECDF from the sorted values (a cumulative histogram plotted against the
# left bin edges is shifted by one bin).
thress_sorted = np.sort(thress)
ecdf = np.arange(1, len(thress_sorted) + 1) / len(thress_sorted)
plt.figure()
plt.step(thress_sorted, ecdf, where='post', color='blue')
plt.xlabel("threshold (rad/s)", fontsize=14)
plt.ylabel("fraction of trials with threshold <= x", fontsize=14)

# Mark points of interest with the exact fraction of trials at or below them.
for mark_x in [0.02, 0.03, 0.04]:
    mark_y = round(float(np.mean(thress_sorted <= mark_x)), 2)
    plt.scatter(mark_x, mark_y, c='r')
    plt.annotate(f"{mark_x}, {round(mark_y * 100, 2)}%", (mark_x + 0.0025, mark_y - 0.05), fontsize=14)

In [None]:
'''% Correct pass order by dtheta threshold'''
%matplotlib qt
load_data('Bai_movObst1_data_30Hz')
threshold = 0.02
thress = np.linspace(0.001, 0.03, 20)
threshold = thress[np.argmin(np.abs(thress - threshold))]
accuracies = []
lengths_thres = []
lengths = []
for thres in thress:
    correct = 0
    total = 0
    for i in range(len(data.trajs)):
        if i not in data.dump and data.info['obst_speed'][i] > 0 and abs(data.info['obst_angle'][i]) != 180:
            t0 = data.info['obst_onset'][i]
            t1 = data.info['obst_out'][i]
            rel_dpsi = data.info['rel_dpsi'][i]
            dtheta = data.info['dtheta'][i]
            for t in range(t0, t1):
                if dtheta[t] > thres:
                    dpsi = rel_dpsi[t]
                    break
            else:
                dpsi = rel_dpsi[t]
            if thres == threshold:
                lengths.append((t1 - t0) / data.Hz)
                lengths_thres.append((t1 - t) / data.Hz)
            if np.sign(dpsi) == data.info['pass_order'][i]:
                correct += 1
            total += 1
    accuracies.append(correct / total)
plt.figure()
plt.plot(thress, accuracies)
plt.title('Bai_movObst1')
plt.ylabel('pass order accuracy')
plt.xlabel('dthetha threshold (rad/s)')
plt.figure()
plt.hist(lengths_thres)
plt.title(f'Bai_movObst1 thres={threshold}')
plt.xlabel('effective trial length (s)')
plt.figure()
plt.hist(lengths)
plt.title('Bai_movObst1')
plt.xlabel('effective trial length (s)')

In [18]:
'''% Correct pass order by dtheta + dpsi threshold'''
%matplotlib qt
load_data('Bai_movObst1_data_30Hz')
threshold = 0.05
thress = np.linspace(0.001, 0.1, 20)
threshold = thress[np.argmin(np.abs(thress - threshold))]
accuracies = []
lengths_thres = []
lengths = []
for thres in thress:
    correct = 0
    total = 0
    for i in range(len(data.trajs)):
        if i not in data.dump and data.info['obst_speed'][i] > 0 and abs(data.info['obst_angle'][i]) != 180:
            t0 = data.info['obst_onset'][i]
            t1 = data.info['obst_out'][i]
            rel_dpsi = data.info['rel_dpsi'][i]
            dtheta = data.info['dtheta'][i]
            for t in range(t0, t1):
                if dtheta[t] + abs(rel_dpsi[t]) > thres:
                    dpsi = rel_dpsi[t]
                    break
            else:
                dpsi = rel_dpsi[t]
            if thres == threshold:
                lengths.append((t1 - t0) / data.Hz)
                lengths_thres.append((t1 - t) / data.Hz)
            if np.sign(dpsi) == data.info['pass_order'][i]:
                correct += 1
            total += 1
    accuracies.append(correct / total)
plt.figure()
plt.plot(thress, accuracies)
plt.title('Bai_movObst1')
plt.ylabel('pass order accuracy')
plt.xlabel('dthetha + dpsi threshold (rad/s)')
plt.figure()
plt.hist(lengths_thres)
plt.title(f'Bai_movObst1 thres={threshold}')
plt.xlabel('effective trial length (s)')
plt.figure()
plt.hist(lengths)
plt.title('Bai_movObst1')
plt.xlabel('effective trial length (s)')


Text(0.5, 0, 'effective trial length (s)')

In [None]:
'''% Correct pass order by dtheta * dpsi threshold'''
%matplotlib qt
load_data('Bai_movObst1_data_30Hz')
threshold = 0.0005
thress = np.linspace(1E-4, 2E-3, 20)
threshold = thress[np.argmin(np.abs(thress - threshold))]
accuracies = []
lengths_thres = []
lengths = []
for thres in thress:
    correct = 0
    total = 0
    for i in range(len(data.trajs)):
        if i not in data.dump and data.info['obst_speed'][i] > 0 and abs(data.info['obst_angle'][i]) != 180:
            t0 = data.info['obst_onset'][i]
            t1 = data.info['obst_out'][i]
            rel_dpsi = data.info['rel_dpsi'][i]
            dtheta = data.info['dtheta'][i]
            for t in range(t0, t1):
                if dtheta[t] * abs(rel_dpsi[t]) > thres:
                    dpsi = rel_dpsi[t]
                    break
            else:
                dpsi = rel_dpsi[t]
            if thres == threshold:
                lengths.append((t1 - t0) / data.Hz)
                lengths_thres.append((t1 - t) / data.Hz)
            if np.sign(dpsi) == data.info['pass_order'][i]:
                correct += 1
            total += 1
    accuracies.append(correct / total)
plt.figure()
plt.plot(thress, accuracies)
plt.title('Bai_movObst1')
plt.ylabel('pass order accuracy')
plt.xlabel('dthetha * dpsi threshold (rad/s)')
plt.figure()
plt.hist(lengths_thres)
plt.title(f'Bai_movObst1 thres={threshold}')
plt.xlabel('effective trial length (s)')
plt.figure()
plt.hist(lengths)
plt.title('Bai_movObst1')
plt.xlabel('effective trial length (s)')

In [None]:
'''% Correct pass order by dtheta and dpsi'''
%matplotlib qt
load_data('Bai_movObst1_data_30Hz')
dthetas = np.linspace(0, 0.02, 10)
dpsis = np.linspace(0, 0.04, 10)
accuracies = np.zeros((len(dthetas), len(dpsis)))
tpps = np.zeros((len(dthetas), len(dpsis)))
for a, dtheta_thres in enumerate(dthetas):
    for b, dpsi_thres in enumerate(dpsis):
        correct = 0
        total = 0
        tpp = []
        for i in range(len(data.trajs)):
            if i not in data.dump and data.info['obst_speed'][i] > 0 and abs(data.info['obst_angle'][i]) != 180:
                t0 = data.info['obst_onset'][i]
                t1 = data.info['obst_out'][i]
                rel_dpsi = data.info['rel_dpsi'][i]
                dtheta = data.info['dtheta'][i]
                for t in range(t0, t1):
                    if dtheta[t] > dtheta_thres and abs(rel_dpsi[t]) > dpsi_thres:
                        dpsi = rel_dpsi[t]
                        break
                else:
                    dpsi = rel_dpsi[t]
                tpp.append((t1 - t) / data.Hz)
                if np.sign(dpsi) == data.info['pass_order'][i]:
                    correct += 1
                total += 1
        accuracies[a, b] = round(correct / total, 2)
        tpps[a, b] = round(np.mean(tpp), 2)

# Heatmap of pass accuracy
plt.figure()
plt.imshow(accuracies, cmap='plasma')
for i in range(accuracies.shape[0]):
    for j in range(accuracies.shape[1]):
        text = plt.text(j, i, str(accuracies[i, j]),
                       ha="center", va="center", color="w")
plt.xticks(range(0, accuracies.shape[0]), np.round(dthetas, 3))
plt.yticks(range(0, accuracies.shape[1]), np.round(dpsis, 3))
plt.xlabel("dtheta (rad/s)")
plt.ylabel("dpsi (rad/s)")
plt.title('Pass order accuracy of AND threshold')

# Heatmap of trial length loss
plt.figure()
plt.imshow(tpps, cmap='plasma')
for i in range(tpps.shape[0]):
    for j in range(tpps.shape[1]):
        text = plt.text(j, i, str(tpps[i, j]),
                       ha="center", va="center", color="w")
plt.xticks(range(0, tpps.shape[0]), np.round(dthetas, 3))
plt.yticks(range(0, tpps.shape[1]), np.round(dpsis, 3))
plt.xlabel("dtheta (rad/s)")
plt.ylabel("dpsi (rad/s)")
plt.title('Time before passing using AND threshold')
        

In [None]:
'''Count pass order by condition'''
#########
angle = 157.5
speed = 1.1
#########
# Front passes (pass_order == 1) among all trials of this |angle| / speed
# condition (dumped trials included).
matching_orders = [
    data.info['pass_order'][k]
    for k in range(len(data.trajs))
    if abs(data.info['obst_angle'][k]) == angle and data.info['obst_speed'][k] == speed
]
total = len(matching_orders)
fpass = sum(1 for po in matching_orders if po == 1)
print(fpass, total)
print(fpass / total)

In [None]:
'''Pass order by subject'''
load_data('Bai_movObst1_data_30Hz')
subject = 3
# Pass orders of this subject's usable trials: not dumped, moving obstacle and
# not head-on. obst_angle is signed (both 180 and -180 occur), so compare |angle|.
pass_orders = []
for i in range(len(data.trajs)):
    if (subject != data.info['subj_id'][i] or i in data.dump
            or data.info['obst_speed'][i] == 0 or abs(data.info['obst_angle'][i]) == 180):
        continue
    pass_orders.append(data.info['pass_order'][i])
print(f'Subject {subject} front pass ratio is {np.mean(np.array(pass_orders) > 0)}')

In [None]:
'''Plot dpsi / dtheta ratio by time'''
%matplotlib qt
subjects = range(16)
n = len(data.trajs)
fig0 = plt.figure()
ax0 = fig0.add_subplot()
fig1 = plt.figure()
ax1 = fig1.add_subplot()
fig2 = plt.figure()
ax2 = fig2.add_subplot()
w = data.info['w_obst']
for i in range(2,3):
    if (data.info['subj_id'][i] in subjects and
        i not in data.dump and
        data.info['obst_speed'][i] != 0 and
        abs(data.info['obst_angle'][i]) != 180):
        t0, t1 = data.info['obst_onset'][i], data.info['obst_out'][i]
        p0, p1 = np.array(data.info['p_subj'][i][t0:t1]), np.array(data.info['p_obst'][i][t0:t1])
        v0, v1 = np.array(data.info['v_subj'][i][t0:t1]), np.array(data.info['v_obst'][i][t0:t1])
        a0 = np.array(data.info['a_subj'][i][t0:t1])
        a0 = norm(a0, axis=-1)
        dpsis = np.absolute(d_psi(p0, p1, v0, v1))
        dthetas = d_theta(p0, p1, v0, v1, w=w)
        thetas = theta(p0, p1, w=w)
        ratio = dthetas/thetas
        ax0.plot(a0)
        c1, c2 = 0, 0.2
        ax1.plot((ratio+c1)/(dpsis+c2))
        ax2.scatter(ratio[60:80], a0[60:80], s=1)
#         ax2.scatter(range(len(dpsis)), dthetas/dpsis, s=1)

In [None]:
'''Heatmap and line plot of front passing ratio'''
%matplotlib qt
load_data('Bai_movObst1_data_30Hz')
name = data.info['experiment_name']
conditions = np.zeros((5, 4)) # speed by angle
angle_ins = {90: 0, 112.5: 1, 135: 2, 157.5: 3}
speed_ins = {0.9: 0, 1: 1, 1.1: 2, 1.2: 3, 1.3: 4}
for angle in angle_ins.keys():
    for speed in speed_ins.keys():
        orders = []
        for i in range(len(data.trajs)):
            if angle != abs(data.info['obst_angle'][i]) or speed != data.info['obst_speed'][i] or i in data.dump:
                continue
            orders.append(data.info['pass_order'][i])
        conditions[speed_ins[speed], angle_ins[angle]] = round(np.mean(np.array(orders) > 0), 2)

# Heatmap
plt.figure()
plt.imshow(conditions, cmap='plasma')
for i in range(conditions.shape[0]):
    for j in range(conditions.shape[1]):
        text = plt.text(j, i, str(conditions[i, j]),
                       ha="center", va="center", color="w")
plt.xticks(range(0, 4), list(angle_ins.keys()))
plt.yticks(range(0, 5), list(speed_ins.keys()))
plt.xlabel("angle (degree)")
plt.ylabel("speed (m/s)")
plt.title('Front pass ratio')
# Line plot
plt.figure()
x = list(speed_ins.keys())
for i, angle in zip(range(conditions.shape[1]), angle_ins.keys()):
    plt.plot(x, conditions[:, i], label=f'{angle}°', color='r', alpha=(angle-67.5)/90)
plt.ylabel("Front pass ratio")
plt.xlabel("speed (m/s)")
plt.legend()
plt.title(name)

### Bai_movObst1: Minimum Passing Distance (MPD)

In [None]:
'''Minimum Passing Distance'''
%matplotlib qt
fig = plt.figure()
ax = fig.add_subplot()
ax.set_title('Signed predicted minimum passing distance (SMPD)')
ax.set_ylabel('SMPD (m)')
ax.set_xlabel('normalized time (%)')
ax.set_ylim((-2, 2))
for i in range(len(data.trajs)):
    t0 = data.info['stimuli_onset'][i]
    t1 = data.info['stimuli_out'][i]
    p0 = data.info['p_subj'][i][t0:t1]
    p1 = data.info['p_obst'][i][t0:t1]
    v0 = data.info['v_subj'][i][t0:t1]
    v1 = data.info['v_obst'][i][t0:t1]
    t = np.linspace(0, 100, len(p0))
    smpd = []
    for _p0, _p1, _v0, _v1 in zip(p0, p1, v0, v1):        
        smpd.append(min_sep(_p0, _p1, _v0, _v1)[0])
    ax.plot(t, smpd, 'k', linewidth=0.1, alpha=0.5)
    
    

### Bai_movObst1: Speed vs Heading control

In [None]:
'''Plot acceleration angle and magnitude'''
%matplotlib qt
trials = range(400,401)
subject = 1
con_angle = [90, -90]
con_speed = []
fig0 = plt.figure()
ax0 = fig0.add_subplot()
fig1 = plt.figure()
ax1 = fig1.add_subplot()
for i in trials:
    angle = data.info['obst_angle'][i]
    speed = data.info['obst_speed'][i]
    subj_id = data.info['subj_id'][i]
    if i in data.dump:
        continue
#     if angle not in con_angle or subj_id != subject:
#         continue
    t0, t1 = data.info['stimuli_onset'][i], data.info['stimuli_out'][i]
    p0, p1, a0 = np.array(data.info['p_subj'][i][t0:t1]), np.array(data.info['p_obst'][i][t0:t1]), np.array(data.info['a_subj'][i][t0:t1])
    angles = signed_angle(p1 - p0, a0)
    ax0.scatter(range(len(angles)), angles, s=1)
    ax1.plot(norm(a0, axis=-1))
    print(data.info['subj_id'][i], data.info['trial_id'][i])

In [None]:
'''Signed speed vs heading control by condition (abs angle)'''
%matplotlib qt
angle_ins = {90: 0, 112.5: 1, 135: 2, 157.5: 3}
speed_ins = {0.9: 0, 1: 1, 1.1: 2, 1.2: 3, 1.3: 4}
for angle in angle_ins.keys():
    for speed in speed_ins.keys():
        fig = plt.figure()
        ax = fig.add_subplot(1,1,1)
        ax.set_title(f'obst_angle {angle}, obst_speed {speed}')
        ax.set_ylabel('acceleration (m/s**2)')
        ax.set_xlabel('time (s)')
        ax.set_xlim(0, 4.5)
        ax.set_ylim(-1.5, 1.5)
        sc = []
        hc = []
        for i in range(len(data.trajs)):
            if angle != abs(data.info['obst_angle'][i]) or speed != data.info['obst_speed'][i] or i in data.dump:
                continue
            if data.info['pass_order'][i] > 0:
                continue
            t0 = data.info['obst_onset'][i]
            t1 = data.info['obst_out'][i]
            v = data.info['v_subj'][i][t0:t1]
            a = data.info['a_subj'][i][t0:t1]
            speed_ctrl = np.array([np.inner(x, y)/norm(y) for x, y in zip(a, v)]) # a along direction of v
            heading_ctrl = np.sqrt(norm(a, axis=1) ** 2 - speed_ctrl ** 2) # a perpendicular to v
            # Check the side of heading ctrl: front (+) or behind (-) driving
            pass_side = np.sign(inner(rotate(a, pi/2), v)) * np.sign(data.info['obst_angle'][i]) * -1
            heading_ctrl *= pass_side
            sc.extend(speed_ctrl)
            hc.extend(heading_ctrl)
            plt.plot(np.linspace(0, len(speed_ctrl)/data.Hz, len(speed_ctrl)), speed_ctrl, color='r', linewidth=0.1, alpha=0.5)
            plt.plot(np.linspace(0, len(speed_ctrl)/data.Hz, len(speed_ctrl)), heading_ctrl, color='b', linewidth=0.1, alpha=0.5)
        ax.set_title(f'obst_angle {angle}, obst_speed {speed}, \n mean speed control {np.mean(np.abs(sc)):.2f} mean heading control {np.mean(np.abs(hc)):.2f} ratio {np.mean(np.abs(sc)) / np.mean(np.abs(hc)): .2f}')

In [None]:
'''Signed speed vs heading control by subject

One figure per subject over all kept moving-obstacle trials (obst_speed != 0 and
|obst_angle| != 180). Red: acceleration component along the velocity (speed control);
blue: perpendicular magnitude (heading control) signed by `pass_side`. The title
reports mean |speed ctrl|, mean |heading ctrl| and their ratio.
'''
for s in set(data.info['subj_id']):
    fig = plt.figure()
    ax = fig.add_subplot(1,1,1)
    ax.set_ylabel('acceleration (m/s**2)')
    ax.set_xlabel('time (s)')
    ax.set_xlim(0, 4.5)
    ax.set_ylim(-1.5, 1.5)
    sc = []
    hc = []
    for i in range(len(data.trajs)):
        # obst_angle is signed, so compare its magnitude to exclude both +180 and -180
        if s != data.info['subj_id'][i] or i in data.dump or data.info['obst_speed'][i] == 0 or abs(data.info['obst_angle'][i]) == 180:
            continue
        t0 = data.info['obst_onset'][i]
        t1 = data.info['obst_out'][i]
        v = data.info['v_subj'][i][t0:t1]
        a = data.info['a_subj'][i][t0:t1]
        # a along direction of v (NaN where |v| == 0)
        speed_ctrl = np.array([np.inner(x, y)/norm(y) for x, y in zip(a, v)])
        # a perpendicular to v; clipped at 0 so round-off cannot make np.sqrt return NaN
        heading_ctrl = np.sqrt(np.clip(norm(a, axis=1) ** 2 - speed_ctrl ** 2, 0, None))
        # Check the side of heading ctrl: front (+) or behind (-) driving
        pass_side = np.sign(inner(rotate(a, pi/2), v)) * np.sign(data.info['obst_angle'][i]) * -1
        heading_ctrl *= pass_side
        sc.extend(speed_ctrl)
        hc.extend(heading_ctrl)
        time_s = np.linspace(0, len(speed_ctrl)/data.Hz, len(speed_ctrl))
        ax.plot(time_s, speed_ctrl, color='r', linewidth=0.1, alpha=0.5)
        ax.plot(time_s, heading_ctrl, color='b', linewidth=0.1, alpha=0.5)
    # Title is set once, after the statistics for this subject are known
    ax.set_title(f'subject {s}, \n mean speed control {np.mean(np.abs(sc)):.2f} mean heading control {np.mean(np.abs(hc)):.2f} ratio {np.mean(np.abs(sc)) / np.mean(np.abs(hc)): .2f}')

In [None]:
'''Heatmap of speed/heading control ratio by condition (signed angle)'''
%matplotlib qt
conditions = np.zeros((5, 8)) # speed by angle
angle_ins = {-157.5: 0, -135: 1, -112.5: 2, -90: 3, 90: 4, 112.5: 5, 135: 6, 157.5: 7}
speed_ins = {0.9: 0, 1: 1, 1.1: 2, 1.2: 3, 1.3: 4}
for angle in angle_ins.keys():
    for speed in speed_ins.keys():
        sc = []
        hc = []
        for i in range(len(data.trajs)):
            if angle != data.info['obst_angle'][i] or speed != data.info['obst_speed'][i] or i in data.dump or \
            data.info['obst_speed'][i] == 0 or abs(data.info['obst_angle'][i]) == 180:
                continue
            t0 = data.info['obst_onset'][i]
            t1 = data.info['obst_out'][i]
            v = data.info['v_subj'][i][t0:t1]
            a = data.info['a_subj'][i][t0:t1]
            speed_ctrl = np.abs([np.inner(x, y)/norm(y) for x, y in zip(a, v)]) # a along direction of v
            heading_ctrl = np.sqrt(norm(a, axis=1) ** 2 - speed_ctrl ** 2) # a perpendicular to v
            sc.extend(speed_ctrl)
            hc.extend(heading_ctrl)
        conditions[speed_ins[speed], angle_ins[angle]] = round(np.mean(sc) / np.mean(hc), 2)

# Heatmap
plt.imshow(conditions, cmap='plasma')
for i in range(conditions.shape[0]):
    for j in range(conditions.shape[1]):
        text = plt.text(j, i, str(conditions[i, j]),
                       ha="center", va="center", color="w")
plt.xticks(range(0, 8), list(angle_ins.keys()))
plt.yticks(range(0, 5), list(speed_ins.keys()))
plt.xlabel("angle")
plt.ylabel("speed")

In [None]:
'''Heatmap of speed and heading control by condition (abs angle)'''
%matplotlib qt
conditions = np.zeros((5, 4)) # speed by angle
angle_ins = {90: 0, 112.5: 1, 135: 2, 157.5: 3}
speed_ins = {0.9: 0, 1: 1, 1.1: 2, 1.2: 3, 1.3: 4}
for angle in angle_ins.keys():
    for speed in speed_ins.keys():
        sc = []
        hc = []
        for i in range(len(data.trajs)):
            if angle != abs(data.info['obst_angle'][i]) or speed != data.info['obst_speed'][i] or i in data.dump:
                continue
            t0 = data.info['obst_onset'][i]
            t1 = data.info['obst_out'][i]
            v = data.info['v_subj'][i][t0:t1]
            a = data.info['a_subj'][i][t0:t1]
            speed_ctrl = np.abs([np.inner(x, y)/norm(y) for x, y in zip(a, v)]) # a along direction of v
            heading_ctrl = np.sqrt(norm(a, axis=1) ** 2 - speed_ctrl ** 2) # a perpendicular to v
            sc.extend(speed_ctrl)
            hc.extend(heading_ctrl)
        conditions[speed_ins[speed], angle_ins[angle]] = round(np.mean(sc) / np.mean(hc), 2)

# Heatmap
plt.imshow(conditions, cmap='plasma')
for i in range(conditions.shape[0]):
    for j in range(conditions.shape[1]):
        text = plt.text(j, i, str(conditions[i, j]),
                       ha="center", va="center", color="w")
plt.xticks(range(0, 4), list(angle_ins.keys()))
plt.yticks(range(0, 5), list(speed_ins.keys()))
plt.xlabel("angle")
plt.ylabel("speed")

In [None]:
'''Peak speed/heading control by conditions (speed on x abs angle separate lines)'''
%matplotlib qt
speed_data = np.zeros((5, 4)) # speed by angle
angle_data = np.zeros((5, 4)) # speed by angle
angle_ins = {90: 0, 112.5: 1, 135: 2, 157.5: 3}
speed_ins = {0.9: 0, 1: 1, 1.1: 2, 1.2: 3, 1.3: 4}
for angle in angle_ins.keys():
    for speed in speed_ins.keys():
        for i in range(len(data.trajs)):
            if angle != abs(data.info['obst_angle'][i]) or speed != data.info['obst_speed'][i] or i in data.dump:
                continue
            order = data.info['pass_order'][i]
            t0 = data.info['obst_onset'][i]
            t1 = data.info['obst_out'][i]
            v = data.info['v_subj'][i][t0:t1]
            a = data.info['a_subj'][i][t0:t1]
            speed_ctrl = np.array([np.inner(x, y)/norm(y) for x, y in zip(a, v)]) # a along direction of v
            heading_ctrl = np.sqrt(norm(a, axis=1) ** 2 - speed_ctrl ** 2) # a perpendicular to v
            # Check the side of heading ctrl: front (+) or behind (-) driving
            pass_side = np.sign(inner(rotate(a, pi/2), v)) * np.sign(data.info['obst_angle'][i]) * -1
            heading_ctrl *= pass_side
            if order > 0:
                sc_peak = np.max(speed_ctrl)
                hc_peak = np.max(heading_ctrl)
            else:
                sc_peak = -np.min(speed_ctrl)
                hc_peak = -np.min(heading_ctrl)
            speed_data[speed_ins[speed], angle_ins[angle]] = sc_peak
            angle_data[speed_ins[speed], angle_ins[angle]] = hc_peak
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.set_ylabel('peak acceleration (m/s**2)')
ax.set_xlabel('speed (m/s)')
for angle, i in angle_ins.items():
    ax.plot(list(speed_ins.keys()), speed_data[:,i], label=f'speed_ctrl in {angle}°', color='r', alpha=(angle-67.5)/90)
    ax.plot(list(speed_ins.keys()), angle_data[:,i], label=f'angle_ctrl in {angle}°', color='b', alpha=(angle-67.5)/90)
ax.legend()

In [None]:
'''Peak speed/heading control by conditions (angle on x abs speed separate lines)'''
%matplotlib qt
speed_data = np.zeros((5, 4)) # speed by angle
angle_data = np.zeros((5, 4)) # speed by angle
angle_ins = {90: 0, 112.5: 1, 135: 2, 157.5: 3}
speed_ins = {0.9: 0, 1: 1, 1.1: 2, 1.2: 3, 1.3: 4}
for angle in angle_ins.keys():
    for speed in speed_ins.keys():
        for i in range(len(data.trajs)):
            if angle != abs(data.info['obst_angle'][i]) or speed != data.info['obst_speed'][i] or i in data.dump:
                continue
            order = data.info['pass_order'][i]
            t0 = data.info['obst_onset'][i]
            t1 = data.info['obst_out'][i]
            v = data.info['v_subj'][i][t0:t1]
            a = data.info['a_subj'][i][t0:t1]
            speed_ctrl = np.array([np.inner(x, y)/norm(y) for x, y in zip(a, v)]) # a along direction of v
            heading_ctrl = np.sqrt(norm(a, axis=1) ** 2 - speed_ctrl ** 2) # a perpendicular to v
            # Check the side of heading ctrl: front (+) or behind (-) driving
            pass_side = np.sign(inner(rotate(a, pi/2), v)) * np.sign(data.info['obst_angle'][i]) * -1
            heading_ctrl *= pass_side
            if order > 0:
                sc_peak = np.max(speed_ctrl)
                hc_peak = np.max(heading_ctrl)
            else:
                sc_peak = -np.min(speed_ctrl)
                hc_peak = -np.min(heading_ctrl)
            speed_data[speed_ins[speed], angle_ins[angle]] = sc_peak
            angle_data[speed_ins[speed], angle_ins[angle]] = hc_peak
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.set_ylabel('peak acceleration (m/s**2)')
ax.set_xlabel('angle (°)')
for speed, i in speed_ins.items():
    ax.plot(list(angle_ins.keys()), speed_data[i,:], label=f'speed_ctrl in {speed}m/s', color='r', alpha=(speed-0.8)/0.5)
    ax.plot(list(angle_ins.keys()), angle_data[i,:], label=f'angle_ctrl in {speed}m/s', color='b', alpha=(speed-0.8)/0.5)
ax.legend()

### Bai_movObst1b (approach model)

In [None]:
'''Summarize Bai_movObst1b conditions and excluded trials

Print the distinct goal_angle, leader_s0 and goal_d0 values, then the data.dump
entry of every excluded trial, then the number of kept trials with a truthy
leader_s0 together with the number of excluded trials.
'''
for field in ('goal_angle', 'leader_s0', 'goal_d0'):
    print(set(data.info[field]))
cnt = sum(1 for k in range(len(data.trajs)) if data.info['leader_s0'][k] and k not in data.dump)
for i in range(len(data.trajs)):
    if i in data.dump:
        print(data.dump[i])
print(cnt, len(data.dump))

In [None]:
'''Animate approach data

Animate the goal and subject positions of trial `i` between stimuli_onset and
stimuli_out with play_trajs, titled with the trial's condition.
'''
i = 1
t0, t1 = data.info['stimuli_onset'][i], data.info['stimuli_out'][i]
print(t0, t1)
window = slice(t0, t1)
p_goal = data.info['p_goal'][i][window]
p_subj = data.info['p_subj'][i][window]
trajs = [p_goal, p_subj]
# One size per trajectory, in the same order as `trajs` (presumably drawing widths — see play_trajs)
ws = [data.info['w_goal'], 0.4]
title = ('subj ' + str(data.info['subj_id'][i])
         + ' trial ' + str(data.info['trial_id'][i])
         + ' goal_d0: ' + str(data.info['goal_d0'][i])
         + ' goal_angle: ' + str(data.info['goal_angle'][i])
         + ' leader_s0: ' + str(data.info['leader_s0'][i]))
print(data.info['subj_id'][i], data.info['trial_id'][i])
play_trajs(trajs, ws, data.Hz, title=title, save=False)

In [None]:
'''Plot speed

Subject speed (|v_subj|) between stimuli_onset and stimuli_out for every kept trial
with a truthy leader_s0. NOTE: a stimulus frame index of 0 is also treated as missing.
'''
for i in range(len(data.trajs)):
    # leader_s0 must be indexed per trial: the unindexed (non-empty) list is always
    # truthy, which silently disabled this filter
    if data.info['stimuli_out'][i] and data.info['stimuli_onset'][i] and data.info['leader_s0'][i] and i not in data.dump:
        t0, t1 = data.info['stimuli_onset'][i], data.info['stimuli_out'][i]
        v_subj = data.info['v_subj'][i]
        s_subj = norm(v_subj, axis=1)
        plt.plot(s_subj[t0:t1])

In [None]:
'''Check trial length'''
%matplotlib qt
a = []
for i in range(len(data.trajs)):
    if data.info['stimuli_out'][i] and data.info['stimuli_onset'][i] and data.info['leader_s0'] and i not in data.dump:
        l = data.info['stimuli_out'][i] - data.info['stimuli_onset'][i]
        a.append(l)
        if (l < 100 or l > 1000) and i not in data.dump:
            print(i, l)
plt.plot(a)

In [None]:
'''Check trajectories

Subject paths of kept trials with a truthy leader_s0 and goal_d0 == 8.
'''
for i in range(len(data.trajs)):
    # leader_s0 must be indexed per trial: the unindexed (non-empty) list is always truthy
    if i not in data.dump and data.info['leader_s0'][i]:
        if data.info['goal_d0'][i] == 8:
            plt.plot(data.info['p_subj'][i][:, 0], data.info['p_subj'][i][:, 1])

In [None]:
'''Plot sample traj in each condition

For every (goal_angle, goal_d0, leader_s0) condition with leader_s0 == 1.4, plot the
subject path and initial goal position of the first kept trial, rotated by
arctan(11/9) + pi/2 or - pi/2 depending on the sign of data.trajs[i][0][0].
The plotted conditions are recorded in `visited`.
'''
visited = set()
for i in range(len(data.trajs)):
    # leader_s0 must be indexed per trial: the unindexed (non-empty) list is always truthy
    if i not in data.dump and data.info['leader_s0'][i]:
        angle = data.info['goal_angle'][i]
        distance = data.info['goal_d0'][i]
        speed = data.info['leader_s0'][i]
        if (angle, distance, speed) not in visited and speed == 1.4:
            goal = data.info['p_goal'][i][0]
            subj = data.info['p_subj'][i]
            if data.trajs[i][0][0] > 0:
                subj = rotate(subj, np.arctan(11 / 9) + pi / 2)
                goal = rotate(goal, np.arctan(11 / 9) + pi / 2)
            else:
                subj = rotate(subj, np.arctan(11 / 9) - pi / 2)
                goal = rotate(goal, np.arctan(11 / 9) - pi / 2)
            plt.plot(subj[:, 0], subj[:, 1], 'k')
            plt.scatter(goal[0], goal[1], c='k', marker='x')
            visited.add((angle, distance, speed))
# plt.gca().set_aspect('equal') 
plt.xlabel("x (m)")
plt.ylabel("y (m)")

In [None]:
# Display the set of (goal_angle, goal_d0, leader_s0) tuples collected in `visited`
# by the "Plot sample traj in each condition" cell.
visited

In [None]:
'''Plot data by condition'''
# Plot subject and goal paths for one subject's trials whose leader_s0, goal_d0 and
# |goal_angle| fall in the selected sets, rotated by arctan(11/9) -/+ pi/2.
#####################
subject = 0
# Single-condition presets; the three assignments after them immediately override
# these with every level present in the data (comment out one group to choose).
s0 = [1.4]
d0 = [8]
angle = [15]
s0 = set(data.info['leader_s0'])
d0 = set(data.info['goal_d0'])
angle = set(data.info['goal_angle'])
#####################
%matplotlib qt
plt.figure()
for i in range(len(data.trajs)):
    goal_s0 = data.info['leader_s0'][i]
    goal_d0 = data.info['goal_d0'][i]
    goal_angle = data.info['goal_angle'][i]
    subj_id = data.info['subj_id'][i]
    # Note: trials listed in data.dump are not excluded in this cell
    if subj_id != subject:
        continue
    # NOTE(review): |goal_angle| is tested against `angle`; when `angle` is built from the
    # signed values, a negative angle passes only if its positive counterpart is present
    if not (goal_s0 in s0 and goal_d0 in d0 and abs(goal_angle) in angle):
        continue
    subj = np.array(data.info['p_subj'][i])
    goal = np.array(data.info['p_goal'][i])
    # NOTE(review): the rotation direction is chosen by trial-index parity here, while other
    # cells use the sign of data.trajs[i][0][0]; confirm the starting side alternates with i
    if i % 2 == 0:
        subj = rotate(subj, np.arctan(11 / 9) - pi / 2)
        goal = rotate(goal, np.arctan(11 / 9) - pi / 2)
    else:
        subj = rotate(subj, np.arctan(11 / 9) + pi / 2)
        goal = rotate(goal, np.arctan(11 / 9) + pi / 2)
#     if goal_angle < 0:
#         subj[:, 0] *= -1
#         goal[:, 0] *= -1
    plt.plot(subj[:, 0], subj[:, 1])
    plt.plot(goal[:, 0], goal[:, 1])
ax = plt.gca()
ax.set_aspect('equal')
ax.set_title(f'subj {str(subject)} angle: {str(angle)} s0: {str(s0)} d0: {str(d0)}')

### Bai_movObst2

In [12]:
'''Animate data'''
load_data('Bai_movObst2_data_30Hz')
%matplotlib qt
############
i = 603
############
t0 = data.info['obst_onset'][i]
t1 = data.info['obst_out'][i]
p_goal = data.info['p_goal'][i]
p_subj = data.info['p_subj'][i]
p_obst = data.info['p_obst'][i]
trajs = [p_goal, p_obst, p_subj]
ws = [data.info['w_goal'], data.info['w_obst'], 0.4]
title = 'subj ' + str(data.info['subj_id'][i]) + ' trial ' + str(data.info['trial_id'][i]) +\
        ' obst_angle: ' + str(data.info['obst_angle'][i]) + ' obst_speed: ' + str(data.info['obst_speed'][i]) +\
        ' ground: ' + str(data.info['ground'][i]) + ' ipd: ' + str(data.info['ipd'][i]) +\
        ' dsize: ' + str(data.info['dsize'][i])
play_trajs(trajs, ws, data.Hz, title=title, save=False)


<matplotlib.animation.FuncAnimation at 0x20991784dd8>

In [13]:
'''Plot obst trajectories

Draw the first 10 s of one obstacle path per moving (obst_angle, obst_speed)
condition, rotated by arctan(11/9) +/- pi/2 depending on the sign of
data.trajs[i][0][0]. An 'x' marks each path start and an 'o' the point (0, -7.1).
The drawn conditions are left in `visited` and displayed.
'''
load_data('Bai_movObst2_data_30Hz')
visited = set()
for i in range(len(data.trajs)):
    angle = data.info['obst_angle'][i]
    speed = data.info['obst_speed'][i]
    # Skip static obstacles and conditions that were already drawn
    if speed == 0 or (angle, speed) in visited:
        continue
    visited.add((angle, speed))
    side_turn = pi / 2 if data.trajs[i][0][0] > 0 else -pi / 2
    obst = rotate(data.info['p_obst'][i], np.arctan(11 / 9) + side_turn)
    n_frames = data.Hz*10
    plt.plot(obst[:n_frames, 0], obst[:n_frames, 1], 'k')
    plt.scatter(obst[0, 0], obst[0, 1], c='k', marker='x')
plt.scatter(0, -7.1, c='k', marker='o')
plt.annotate("Participant", (0, -12.1))
plt.gca().set_aspect('equal')
plt.xlabel("x (m)")
plt.ylabel("y (m)")
visited

{(-157.5, 0.9),
 (-157.5, 1.1),
 (-112.5, 0.9),
 (-112.5, 1.1),
 (112.5, 0.9),
 (112.5, 1.1),
 (157.5, 0.9),
 (157.5, 1.1)}

In [None]:
'''Average time series of dpsi by subj and condition, Bai_movObst2

Sums |dpsi| over the first `length` frames of every kept moving-obstacle trial, per
subject and pooled, for each ground x ipd x dsize condition, then plots the pooled
mean per condition. dpsi[key][ground][ipd][dsize] = {'vals': summed |dpsi|, 'n': count}.
The pooled entry uses the key 'all' (a numeric key such as 0 can collide with a real
subject id and double-count that subject's trials).
'''
load_data('Bai_movObst2_data_30Hz')
length = 267
# Time axis from the data's own sampling rate (was hard-coded to 90 Hz)
t = np.arange(length) / data.Hz
grounds = [1, 0]
ipds = [0.07, 0]
dsizes = [-0.1, 0, 0.1]
subjs = set(data.info['subj_id'])
subjs.add('all')
dpsi = {
    subj: {ground: {ipd: {dsize: {'vals': np.zeros(length), 'n': 0} for dsize in dsizes}
                    for ipd in ipds}
           for ground in grounds}
    for subj in subjs
}

for i in range(len(data.trajs)):
    if i in data.dump or data.info['obst_speed'][i] == 0:
        continue
    ground = data.info['ground'][i]
    ipd = data.info['ipd'][i]
    dsize = data.info['dsize'][i]
    vals = np.abs(np.asarray(data.info['dpsi'][i][:length], dtype=float))
    if len(vals) < length:
        # A shorter trial cannot be added sample-by-sample to the fixed-length sums
        print(f'skipping trial {i}: only {len(vals)} dpsi samples (< {length})')
        continue
    for key in (data.info['subj_id'][i], 'all'):
        entry = dpsi[key][ground][ipd][dsize]
        entry['vals'] += vals
        entry['n'] += 1

ground_con = {1:'ground', 0:'no ground'}
ipd_con = {0.07:'disparity', 0:'no disparity'}
dsize_con = {0.1:'grow', -0.1:'shrink', 0:'constant'}
plt.figure()
for ground in grounds:
    for ipd in ipds:
        for dsize in dsizes:
            entry = dpsi['all'][ground][ipd][dsize]
            if entry['n'] == 0:
                continue  # no trials in this condition
            plt.plot(t, entry['vals'] / entry['n'],
                     label=f'{ground_con[ground]}, {ipd_con[ipd]}, {dsize_con[dsize]}')
plt.xlabel('time (s)')
plt.ylabel('dpsi (rad/s)')
plt.legend()


In [None]:
'''Plot average time series by one condition

Mean |dpsi| over the first `length` frames of kept moving-obstacle trials with a
constant obstacle size (dsize == 0).
'''
load_data('Bai_movObst2_data_30Hz')
length = 267
# Time axis from the data's own sampling rate (was hard-coded to 90 Hz)
t = np.arange(length) / data.Hz
n = 0
vals = np.zeros(length)
for i in range(len(data.trajs)):
    if i in data.dump or data.info['obst_speed'][i] == 0 or data.info['dsize'][i] != 0:
        continue
    dpsi = np.abs(np.asarray(data.info['dpsi'][i][:length], dtype=float))
    if len(dpsi) < length:
        # A shorter trial cannot be added sample-by-sample to the fixed-length sum
        print(f'skipping trial {i}: only {len(dpsi)} dpsi samples (< {length})')
        continue
    n += 1
    vals += dpsi
if n == 0:
    print('no constant-size trials to average')
else:
    vals /= n
    plt.figure()
    plt.plot(t, vals, label='constant size')
    plt.xlabel('time (s)')
    plt.ylabel('dpsi (rad/s)')
    plt.legend()

In [None]:
'''Plot average time series by subject

Mean |dpsi| over the first `length` frames of kept moving-obstacle trials, one line
per subject (subjects without usable trials are not plotted).
'''
load_data('Bai_movObst2_data_30Hz')
length = 267
# Time axis from the data's own sampling rate (was hard-coded to 90 Hz)
t = np.arange(length) / data.Hz
sums = defaultdict(lambda: np.zeros(length))
counts = Counter()
# Single pass over the trials (previously one full pass per subject)
for i in range(len(data.trajs)):
    if i in data.dump or data.info['obst_speed'][i] == 0:
        continue
    dpsi = np.abs(np.asarray(data.info['dpsi'][i][:length], dtype=float))
    if len(dpsi) < length:
        # A shorter trial cannot be added sample-by-sample to the fixed-length sum
        print(f'skipping trial {i}: only {len(dpsi)} dpsi samples (< {length})')
        continue
    subj = data.info['subj_id'][i]
    sums[subj] += dpsi
    counts[subj] += 1
plt.figure()
for subj, n in counts.items():
    plt.plot(t, sums[subj] / n, label='subj ' + str(subj))
plt.xlabel('time (s)')
plt.ylabel('dpsi (rad/s)')
plt.legend()

In [None]:
'''Check trial length

Histograms (in frames) of the full obst_onset -> obst_out window and of the effective
'dtheta+dpsi_onset' -> obst_out window for kept moving-obstacle trials. Trials leaving
fewer than 30 frames after the detected onset are printed with their dsize.
'''
load_data('Bai_movObst2_data_30Hz')
full, effective = [], []
for i in range(len(data.trajs)):
    if i in data.dump or data.info['obst_speed'][i] == 0:
        continue
    t0, t1 = data.info['obst_onset'][i], data.info['obst_out'][i]
    tt = data.info['dtheta+dpsi_onset'][i]
    full.append(t1 - t0)
    effective.append(t1 - tt)
    if t1 - tt < 30:
        print(i, t1 - tt, data.info['dsize'][i])
plt.hist(full)
plt.hist(effective)

In [None]:
'''Plot all trials

Draw the last 10 samples of the subject path of every kept trial.
'''
load_data('Bai_movObst2_data_30Hz')
for i in range(len(data.trajs)):
    if i in data.dump:
        continue
    p0 = data.info['p_subj'][i]
    tail = p0[-10:]
    plt.plot(tail[:, 0], tail[:, 1])

### Bai_movObst2: Pass order, pass distance

In [None]:
'''Passing distance, pass order by condition Bai_movObst2'''
# For every kept moving-obstacle trial, record the subject-obstacle distance at obst_out,
# the pass order (-1 recoded to 0) and the subject speed at obst_onset, split into
# dsize == 0.1 (lists ending in 1) and dsize == -0.1 (lists ending in 2); dsize == 0 trials
# are not collected. Then compare the selected variable between the two groups.
load_data('Bai_movObst2_data_30Hz')
%matplotlib qt
#########
# NOTE(review): grounds, ipds and dsizes are not read anywhere in this cell
grounds = [1, 0]
ipds = [0.07, 0]
dsizes = [-0.1, 0, 0.1]
#########
pd1 = []  # passing distance at obst_out
pd2 = []
po1 = []  # pass order (-1 recoded to 0)
po2 = []
dp1 = []  # always 0: the frame search that would set j is commented out below
dp2 = []
ps1 = []  # subject speed (|v_subj|) at obst_onset
ps2 = []
for i in range(len(data.trajs)):
    if i in data.dump or data.info['obst_speed'][i] == 0:
        continue
    ground = data.info['ground'][i]
    ipd = data.info['ipd'][i]
    dsize = data.info['dsize'][i]
    t0 = data.info['obst_onset'][i]
    t1 = data.info['obst_out'][i]
    p0 = data.info['p_subj'][i][t1]
    p1 = data.info['p_obst'][i][t1]
    ps = norm(data.info['v_subj'][i][t0])
    pass_order = data.info['pass_order'][i]
    j = 0
    # Disabled: search for the first frame with dpsi > 0.05. As written it would need
    # vx, vy, vxo, vyo, which this cell never defines.
#     for j in range(t0, t1):
#         x, y = data.info['p_subj'][i][j]
#         xo, yo = data.info['p_obst'][i][j]
#         dpsi = d_psi([x, y], [xo, yo], [vx, vy], [vxo, vyo])
#         if dpsi > 0.05:
#             break
    if pass_order == -1:
        pass_order = 0
    d = dist(p0, p1)
    if dsize == 0.1:
        pd1.append(d)
        po1.append(pass_order)
        dp1.append(j)
        ps1.append(ps)
    elif dsize == -0.1:
        pd2.append(d)
        po2.append(pass_order)
        dp2.append(j)
        ps2.append(ps)

# Pick the pair of lists to compare (currently passing distance): means, Welch-free
# two-sample t-test (scipy default equal_var=True) and a boxplot
var1, var2 = pd1, pd2
print(np.mean(var1), np.mean(var2))
print(ttest_ind(var1, var2))
plt.boxplot([var1,var2])


### Bai_movObst2: Threshold

In [None]:
'''Match rate by dpsi threshold by condition, Bai_movObst2'''
# For each dpsi threshold, predict each kept moving-obstacle trial's pass order from the
# frame where |dpsi| first exceeds the threshold, and score the predictions against
# pass_order over all trials and within the ipd / ground / dsize subsets.
%matplotlib qt
load_data('Bai_movObst2_data_30Hz')
sim = ODESimulator(data=data, ref=[0,1])
thress = np.linspace(0, 0.1, 11)
match_rates = []
match_rates_ipd1 = []
match_rates_ipd0 = []
match_rates_ground1 = []
match_rates_ground0 = []
match_rates_dsize1 = []
match_rates_dsize0 = []
# Per-condition trial masks (ipd 0.07/0, ground 1/0, dsize 0.1/-0.1). They are filled
# during the first threshold pass (ii == 0) and then turned into boolean arrays, since
# the kept trial set is the same for every threshold.
i_ipd1 = []
i_ipd0 = []
i_ground1 = []
i_ground0 = []
i_dsize1 = []
i_dsize0 = []
# Frame at which each trial crossed the previous threshold. thress is ascending, so the
# first crossing of a larger threshold cannot come earlier and the scan resumes there.
t0s = {}
for ii, thres in enumerate(thress):
    print(f'thres = {thres}')
    side_pred = []
    side_true = []
    for i in range(len(sim.data.trajs)):
        if i in sim.data.dump or sim.data.info['obst_speed'][i] == 0:
            continue
        t0 = t0s.get(i, sim.data.info['obst_onset'][i])
        t1 = sim.data.info['obst_out'][i]
        ipd = sim.data.info['ipd'][i]
        ground = sim.data.info['ground'][i]
        dsize = sim.data.info['dsize'][i]
        # Scan for the first frame where |dpsi| exceeds the threshold; if none does, dpsi
        # keeps the value from the last frame (t1 - 1).
        # NOTE(review): if t0 >= t1 the loop never runs and dpsi is left over from the
        # previous trial.
        for j in range(t0, t1):
            # Only x, y, vx, vy, xo, yo, vxo, vyo are used below
            xg, yg, xo, yo, vxo, vyo, x, y, vx, vy, a, phi, s, dphi, ds, w0 = sim.compute_var0(i, j)
            dpsi = d_psi([x, y], [xo, yo], [vx, vy], [vxo, vyo])
            # NOTE: also printed when the threshold is first met on the last frame
            if j == t1 - 1:
                print('reached t1 before meeting the threshold')
            if abs(dpsi) > thres:
                t0s[i] = j
                break
        angle = sim.data.info['obst_angle'][i]
        # Predicted pass order: 1 when obst_angle and dpsi have opposite signs, else -1
        side_pred.append(1 if angle * dpsi < 0 else -1)
        side_true.append(sim.data.info['pass_order'][i])
        if ii == 0:
            i_ipd1.append(True if ipd==0.07 else False)
            i_ipd0.append(True if ipd==0 else False)
            i_ground1.append(True if ground==1 else False)
            i_ground0.append(True if ground==0 else False)
            i_dsize1.append(True if dsize==0.1 else False)
            i_dsize0.append(True if dsize==-0.1 else False)
    if ii == 0:
        i_ipd1 = np.array(i_ipd1)
        i_ipd0 = np.array(i_ipd0)
        i_ground1 = np.array(i_ground1)
        i_ground0 = np.array(i_ground0)
        i_dsize1 = np.array(i_dsize1)
        i_dsize0 = np.array(i_dsize0)
    side_pred = np.array(side_pred)
    side_true = np.array(side_true)
    match_rates.append(accuracy_score(side_true, side_pred))
    match_rates_ipd1.append(accuracy_score(side_true[i_ipd1], side_pred[i_ipd1]))
    match_rates_ipd0.append(accuracy_score(side_true[i_ipd0], side_pred[i_ipd0]))
    match_rates_ground1.append(accuracy_score(side_true[i_ground1], side_pred[i_ground1]))
    match_rates_ground0.append(accuracy_score(side_true[i_ground0], side_pred[i_ground0]))
    match_rates_dsize1.append(accuracy_score(side_true[i_dsize1], side_pred[i_dsize1]))
    match_rates_dsize0.append(accuracy_score(side_true[i_dsize0], side_pred[i_dsize0]))
    
plt.plot(thress, match_rates, label='All trials')
plt.plot(thress, match_rates_ipd1, label='disparity')
plt.plot(thress, match_rates_ipd0, label='no disparity')
plt.plot(thress, match_rates_ground1, label='ground')
plt.plot(thress, match_rates_ground0, label='no ground')
plt.plot(thress, match_rates_dsize1, label='grow')
plt.plot(thress, match_rates_dsize0, label='shrink')
plt.ylabel('Percentage')
plt.xlabel('dpsi threshold')
plt.legend()


In [4]:
'''% Correct pass order by dtheta threshold'''
%matplotlib qt
load_data('Bai_movObst2_data_30Hz')
threshold = 0.02
thress = np.linspace(0.001, 0.03, 20)
threshold = thress[np.argmin(np.abs(thress - threshold))]
accuracies = []
lengths_thres = []
lengths = []
for thres in thress:
    correct = 0
    total = 0
    for i in range(len(data.trajs)):
        if i not in data.dump and data.info['obst_speed'][i] > 0:
            t0 = data.info['obst_onset'][i]
            t1 = data.info['obst_out'][i]
            rel_dpsi = data.info['rel_dpsi'][i]
            dtheta = data.info['dtheta'][i]
            for t in range(t0, t1):
                if dtheta[t] > thres:
                    dpsi = rel_dpsi[t]
                    break
            else:
                dpsi = rel_dpsi[t]
            if thres == threshold:
                lengths.append((t1 - t0) / data.Hz)
                lengths_thres.append((t1 - t) / data.Hz)
            if np.sign(dpsi) == data.info['pass_order'][i]:
                correct += 1
            total += 1
    accuracies.append(correct / total)
plt.figure()
plt.plot(thress, accuracies)
plt.title('Bai_movObst2')
plt.ylabel('pass order accuracy')
plt.xlabel('dthetha threshold (rad/s)')
plt.figure()
plt.hist(lengths_thres)
plt.title(f'Bai_movObst2 thres={threshold}')
plt.xlabel('Time to pass (s)')
plt.figure()
plt.hist(lengths)
plt.title('Bai_movObst2')
plt.xlabel('Time to pass (s)')

Text(0.5, 0, 'Time to pass (s)')

In [None]:
'''% Correct pass order by dtheta+dpsi threshold'''
%matplotlib qt
load_data('Bai_movObst2_data_30Hz')
threshold = 0.05
thress = np.linspace(0.001, 0.1, 20)
threshold = thress[np.argmin(np.abs(thress - threshold))]
accuracies = []
lengths_thres = []
lengths = []
for thres in thress:
    correct = 0
    total = 0
    for i in range(len(data.trajs)):
        if i not in data.dump and data.info['obst_speed'][i] > 0:
            t0 = data.info['obst_onset'][i]
            t1 = data.info['obst_out'][i]
            rel_dpsi = data.info['rel_dpsi'][i]
            dtheta = data.info['dtheta'][i]
            for t in range(t0, t1):
                if dtheta[t] + abs(rel_dpsi[t]) > thres:
                    dpsi = rel_dpsi[t]
                    break
            else:
                dpsi = rel_dpsi[t]
            if thres == threshold:
                lengths.append((t1 - t0) / data.Hz)
                lengths_thres.append((t1 - t) / data.Hz)
                if abs(t - data.info['dtheta+dpsi_onset'][i]) > 10:
                    print(i, t, data.info['dtheta+dpsi_onset'][i])
            if np.sign(dpsi) == data.info['pass_order'][i]:
                correct += 1
            total += 1
    accuracies.append(correct / total)
plt.figure()
plt.plot(thress, accuracies)
plt.title('Bai_movObst2')
plt.ylabel('pass order accuracy')
plt.xlabel('dthetha+dpsi threshold (rad/s)')
plt.figure()
plt.hist(lengths_thres)
plt.title(f'Bai_movObst2 thres={threshold}')
plt.xlabel('Time to pass (s)')
plt.figure()
plt.hist(lengths)
plt.title('Bai_movObst2')
plt.xlabel('Time to pass (s)')

In [25]:
'''% Correct pass order by dtheta*dpsi threshold'''
%matplotlib qt
load_data('Bai_movObst2_data_30Hz')
threshold = 0.0005
thress = np.linspace(1E-4, 2E-3, 20)
threshold = thress[np.argmin(np.abs(thress - threshold))]
accuracies = []
lengths_thres = []
lengths = []
for thres in thress:
    correct = 0
    total = 0
    for i in range(len(data.trajs)):
        if i not in data.dump and data.info['obst_speed'][i] > 0:
            t0 = data.info['obst_onset'][i]
            t1 = data.info['obst_out'][i]
            rel_dpsi = data.info['rel_dpsi'][i]
            dtheta = data.info['dtheta'][i]
            for t in range(t0, t1):
                if dtheta[t] * abs(rel_dpsi[t]) > thres:
                    dpsi = rel_dpsi[t]
                    break
            else:
                dpsi = rel_dpsi[t]
            if thres == threshold:
                lengths.append((t1 - t0) / data.Hz)
                lengths_thres.append((t1 - t) / data.Hz)
                if abs(t - data.info['dtheta*dpsi_onset'][i]) > 10:
                    print(i, t, data.info['dtheta*dpsi_onset'][i])
            if np.sign(dpsi) == data.info['pass_order'][i]:
                correct += 1
            total += 1
    accuracies.append(correct / total)
plt.figure()
plt.plot(thress, accuracies)
plt.title('Bai_movObst2')
plt.ylabel('pass order accuracy')
plt.xlabel('dthetha*dpsi threshold (rad/s)')
plt.figure()
plt.hist(lengths_thres)
plt.title(f'Bai_movObst2 thres={threshold}')
plt.xlabel('Time to pass (s)')
plt.figure()
plt.hist(lengths)
plt.title('Bai_movObst2')
plt.xlabel('Time to pass (s)')

Text(0.5, 0, 'Time to pass (s)')

### Bai_movObst2: Testing speed

In [3]:
'''Plot average speed in last two seconds ground vs no ground'''
load_data('Bai_movObst2_data_30Hz')
%matplotlib qt
shortest = float("inf")
for i in range(len(data.trajs)):
    if i in data.dump or not data.info['obst_speed'][i]:
        continue
    t0, t1 = data.info['obst_onset'][i], data.info['obst_out'][i]
    shortest = min(shortest, t1 - t0)
t = np.linspace(0, shortest/data.Hz, int(shortest))
angles = sorted(set([abs(x) for x in data.info['obst_angle'] if x]))
speeds = sorted(set([abs(x) for x in data.info['obst_speed'] if x]))
print(angles, speeds)
for ground in [0, 1]:
    iplot = 1
    fig = plt.figure()
    for angle in angles:
        for speed in speeds:
            ax = fig.add_subplot(len(angles), len(speeds), iplot)
            iplot += 1
            ax.set_title(f'{angle}°, {speed} m/s, ground={ground}')
            if angle == 180:
                ax.set_xlabel('time (s)')
            if speed == 0.9:
                ax.set_ylabel('speed (m/s)')
            ax.set_ylim(0.4, 2)
            front = []
            behind = []
            # Plot trials
            for i in range(len(data.trajs)):
                if (angle != abs(data.info['obst_angle'][i]) or speed != data.info['obst_speed'][i] or i in data.dump or
                data.info['ground'][i] != ground):
                    continue
                t1 = data.info['obst_out'][i]
                v = data.info['v_subj'][i][t1-shortest:t1]
                s = norm(v, axis=1)
                order = data.info['pass_order'][i]
                if order == 1:
                    front.append(s)
                    ax.plot(t, s, color='r', linewidth=0.1, alpha=0.3)
                elif order == -1:
                    behind.append(s)
                    ax.plot(t, s, color='b', linewidth=0.1, alpha=0.3)
            # Plot average trials
            front = np.mean(front, axis=0)
            behind = np.mean(behind, axis=0)
            ax.plot(t, front, color='r')
            ax.plot(t, behind, color='b')

[112.5, 157.5] [0.9, 1.1]


In [4]:
'''Plot average speed in last two seconds disparity vs no disparity'''
load_data('Bai_movObst2_data_30Hz')
%matplotlib qt
shortest = float("inf")
for i in range(len(data.trajs)):
    if i in data.dump or not data.info['obst_speed'][i]:
        continue
    t0, t1 = data.info['obst_onset'][i], data.info['obst_out'][i]
    shortest = min(shortest, t1 - t0)
t = np.linspace(0, shortest/data.Hz, int(shortest))
angles = sorted(set([abs(x) for x in data.info['obst_angle'] if x]))
speeds = sorted(set([abs(x) for x in data.info['obst_speed'] if x]))
print(angles, speeds)
for ipd in [0, 0.07]:
    iplot = 1
    fig = plt.figure()
    for angle in angles:
        for speed in speeds:
            ax = fig.add_subplot(len(angles), len(speeds), iplot)
            iplot += 1
            ax.set_title(f'{angle}°, {speed} m/s, ipd={ipd}')
            if angle == 180:
                ax.set_xlabel('time (s)')
            if speed == 0.9:
                ax.set_ylabel('speed (m/s)')
            ax.set_ylim(0.4, 2)
            front = []
            behind = []
            # Plot trials
            for i in range(len(data.trajs)):
                if (angle != abs(data.info['obst_angle'][i]) or speed != data.info['obst_speed'][i] or i in data.dump or
                data.info['ipd'][i] != ipd):
                    continue
                t1 = data.info['obst_out'][i]
                v = data.info['v_subj'][i][t1 - shortest:t1]
                s = norm(v, axis=1)
                order = data.info['pass_order'][i]
                if order == 1:
                    front.append(s)
                    ax.plot(t, s, color='r', linewidth=0.1, alpha=0.3)
                elif order == -1:
                    behind.append(s)
                    ax.plot(t, s, color='b', linewidth=0.1, alpha=0.3)
            # Plot average trials
            front = np.mean(front, axis=0)
            behind = np.mean(behind, axis=0)
            ax.plot(t, front, color='r')
            ax.plot(t, behind, color='b')

[112.5, 157.5] [0.9, 1.1]


In [5]:
'''Plot average speed in last two seconds shrink vs normal vs grow'''
load_data('Bai_movObst2_data_30Hz')
%matplotlib qt
shortest = float("inf")
for i in range(len(data.trajs)):
    if i in data.dump or not data.info['obst_speed'][i]:
        continue
    t0, t1 = data.info['obst_onset'][i], data.info['obst_out'][i]
    shortest = min(shortest, t1 - t0)
t = np.linspace(0, shortest/data.Hz, int(shortest))
angles = sorted(set([abs(x) for x in data.info['obst_angle'] if x]))
speeds = sorted(set([abs(x) for x in data.info['obst_speed'] if x]))
print(angles, speeds)
for dsize in [-0.1, 0, 0.1]:
    iplot = 1
    fig = plt.figure()
    for angle in angles:
        for speed in speeds:
            ax = fig.add_subplot(len(angles), len(speeds), iplot)
            iplot += 1
            ax.set_title(f'{angle}°, {speed} m/s, dsize={dsize}')
            if angle == 180:
                ax.set_xlabel('time (s)')
            if speed == 0.9:
                ax.set_ylabel('speed (m/s)')
            ax.set_ylim(0.4, 2)
            front = []
            behind = []
            # Plot trials
            for i in range(len(data.trajs)):
                if (angle != abs(data.info['obst_angle'][i]) or speed != data.info['obst_speed'][i] or i in data.dump or
                data.info['dsize'][i] != dsize):
                    continue
                t1 = data.info['obst_out'][i]
                v = data.info['v_subj'][i][t1-shortest:t1]
                s = norm(v, axis=1)
                order = data.info['pass_order'][i]
                if order == 1:
                    front.append(s)
                    ax.plot(t, s, color='r', linewidth=0.1, alpha=0.3)
                elif order == -1:
                    behind.append(s)
                    ax.plot(t, s, color='b', linewidth=0.1, alpha=0.3)
            # Plot average trials
            front = np.mean(front, axis=0)
            behind = np.mean(behind, axis=0)
            ax.plot(t, front, color='r')
            ax.plot(t, behind, color='b')

[112.5, 157.5] [0.9, 1.1]


### Bai_movObst2: Testing heading

In [6]:
'''Plot average heading in last two seconds ground vs no ground'''
load_data('Bai_movObst2_data_30Hz')
%matplotlib qt
shortest = float("inf")
for i in range(len(data.trajs)):
    if i in data.dump or not data.info['obst_speed'][i]:
        continue
    t0, t1 = data.info['obst_onset'][i], data.info['obst_out'][i]
    shortest = min(shortest, t1 - t0)
t = np.linspace(0, shortest/data.Hz, int(shortest))
angles = sorted(set([abs(x) for x in data.info['obst_angle'] if x]))
speeds = sorted(set([abs(x) for x in data.info['obst_speed'] if x]))
print(angles, speeds)
for ground in [0, 1]:
    iplot = 1
    fig = plt.figure()
    for angle in angles:
        for speed in speeds:
            ax = fig.add_subplot(len(angles), len(speeds), iplot)
            iplot += 1
            ax.set_title(f'{angle}°, {speed} m/s, ground={ground}')
            if angle == 180:
                ax.set_xlabel('time (s)')
            if speed == 0.9:
                ax.set_ylabel('heading (°)')
            ax.set_ylim(-90, 90)
            front = []
            behind = []
            # Plot trials
            for i in range(len(data.trajs)):
                if (angle != abs(data.info['obst_angle'][i]) or speed != data.info['obst_speed'][i] or i in data.dump or
                data.info['ground'][i] != ground):
                    continue
                t1 = data.info['obst_out'][i]
                v = data.info['v_subj'][i][t1-shortest:t1]
                ref = [9, 11] if i % 2 == 0 else [-9, -11]
                s, h = v2sp(v, ref=ref)
                # Convert to degree
                h *= 180/pi
                h = -h if data.info['obst_angle'][i] < 0 else h
                order = data.info['pass_order'][i]
                if order == 1:
                    front.append(h)
                    ax.plot(t, h, color='r', linewidth=0.1, alpha=0.3)
                elif order == -1:
                    behind.append(h)
                    ax.plot(t, h, color='b', linewidth=0.1, alpha=0.3)
            # Plot average trials
            front = np.mean(front, axis=0)
            behind = np.mean(behind, axis=0)
            ax.plot(t, front, color='r')
            ax.plot(t, behind, color='b')

[112.5, 157.5] [0.9, 1.1]


In [7]:
'''Plot average heading in last two seconds disparity vs no disparity'''
load_data('Bai_movObst2_data_30Hz')
%matplotlib qt
shortest = float("inf")
for i in range(len(data.trajs)):
    if i in data.dump or not data.info['obst_speed'][i]:
        continue
    t0, t1 = data.info['obst_onset'][i], data.info['obst_out'][i]
    shortest = min(shortest, t1 - t0)
t = np.linspace(0, shortest/data.Hz, int(shortest))
angles = sorted(set([abs(x) for x in data.info['obst_angle'] if x]))
speeds = sorted(set([abs(x) for x in data.info['obst_speed'] if x]))
print(angles, speeds)
for ipd in [0, 0.07]:
    iplot = 1
    fig = plt.figure()
    for angle in angles:
        for speed in speeds:
            ax = fig.add_subplot(len(angles), len(speeds), iplot)
            iplot += 1
            ax.set_title(f'{angle}°, {speed} m/s, ipd={ipd}')
            if angle == 180:
                ax.set_xlabel('time (s)')
            if speed == 0.9:
                ax.set_ylabel('heading (°)')
            ax.set_ylim(-90, 90)
            front = []
            behind = []
            # Plot trials
            for i in range(len(data.trajs)):
                if (angle != abs(data.info['obst_angle'][i]) or speed != data.info['obst_speed'][i] or i in data.dump or
                data.info['ipd'][i] != ipd):
                    continue
                t1 = data.info['obst_out'][i]
                v = data.info['v_subj'][i][t1-shortest:t1]
                ref = [9, 11] if i % 2 == 0 else [-9, -11]
                s, h = v2sp(v, ref=ref)
                # Convert to degree
                h *= 180/pi
                h = -h if data.info['obst_angle'][i] < 0 else h
                order = data.info['pass_order'][i]
                if order == 1:
                    front.append(h)
                    ax.plot(t, h, color='r', linewidth=0.1, alpha=0.3)
                elif order == -1:
                    behind.append(h)
                    ax.plot(t, h, color='b', linewidth=0.1, alpha=0.3)
            # Plot average trials
            front = np.mean(front, axis=0)
            behind = np.mean(behind, axis=0)
            ax.plot(t, front, color='r')
            ax.plot(t, behind, color='b')

[112.5, 157.5] [0.9, 1.1]


In [8]:
'''Plot average heading in last two seconds disparity vs no disparity'''
load_data('Bai_movObst2_data_30Hz')
%matplotlib qt
shortest = float("inf")
for i in range(len(data.trajs)):
    if i in data.dump or not data.info['obst_speed'][i]:
        continue
    t0, t1 = data.info['obst_onset'][i], data.info['obst_out'][i]
    shortest = min(shortest, t1 - t0)
t = np.linspace(0, shortest/data.Hz, int(shortest))
angles = sorted(set([abs(x) for x in data.info['obst_angle'] if x]))
speeds = sorted(set([abs(x) for x in data.info['obst_speed'] if x]))
print(angles, speeds)
for dsize in [-0.1, 0, 0.1]:
    iplot = 1
    fig = plt.figure()
    for angle in angles:
        for speed in speeds:
            ax = fig.add_subplot(len(angles), len(speeds), iplot)
            iplot += 1
            ax.set_title(f'{angle}°, {speed} m/s, dsize={dsize}')
            if angle == 180:
                ax.set_xlabel('time (s)')
            if speed == 0.9:
                ax.set_ylabel('heading (°)')
            ax.set_ylim(-90, 90)
            front = []
            behind = []
            # Plot trials
            for i in range(len(data.trajs)):
                if (angle != abs(data.info['obst_angle'][i]) or speed != data.info['obst_speed'][i] or i in data.dump or
                data.info['dsize'][i] != dsize):
                    continue
                t1 = data.info['obst_out'][i]
                v = data.info['v_subj'][i][t1-shortest:t1]
                ref = [9, 11] if i % 2 == 0 else [-9, -11]
                s, h = v2sp(v, ref=ref)
                # Convert to degree
                h *= 180/pi
                h = -h if data.info['obst_angle'][i] < 0 else h
                order = data.info['pass_order'][i]
                if order == 1:
                    front.append(h)
                    ax.plot(t, h, color='r', linewidth=0.1, alpha=0.3)
                elif order == -1:
                    behind.append(h)
                    ax.plot(t, h, color='b', linewidth=0.1, alpha=0.3)
            # Plot average trials
            front = np.mean(front, axis=0)
            behind = np.mean(behind, axis=0)
            ax.plot(t, front, color='r')
            ax.plot(t, behind, color='b')

[112.5, 157.5] [0.9, 1.1]


### Bai_movObst2: Output speed and heading data for LME analysis

In [10]:
'''Output speed and heading data for LME analysis by trials'''
# DV: speed and heading averaged over the last two seconds before obstacle exit.
# Fixed effect: ground, disparity, expansion. Random effect: subject
# Do LME to 8 conditions separately
# Columns (no header row is written): subject, trial, angle, speed, order, ground, ipd, dsize,
# mean speed, mean heading (degrees, sign-flipped for negative obstacle angles)

load_data('Bai_movObst2_data_30Hz')
window = int(round(2 * data.Hz))  # two seconds in frames; slice bounds must be integers
with open('Bai_movObst2_final_states.csv', 'w') as f:
    for i in range(len(data.trajs)):
        if i in data.dump or not data.info['obst_speed'][i]:
            continue
        t1 = data.info['obst_out'][i]
        if t1 < window:
            # A negative start index would wrap around to the end of the array and average the
            # wrong frames (or none), so report and skip instead
            print(f'trial {i}: only {t1} frames before obstacle exit (< {window}); skipped')
            continue
        v = data.info['v_subj'][i][t1-window:t1]
        # Odd-indexed trials use the mirrored reference direction
        ref = [9, 11] if i % 2 == 0 else [-9, -11]
        s, h = v2sp(v, ref=ref)
        # Convert to degree
        h *= 180/pi
        h = -h if data.info['obst_angle'][i] < 0 else h
        log = [data.info['subj_id'][i], data.info['trial_id'][i],
               abs(data.info['obst_angle'][i]), data.info['obst_speed'][i], data.info['pass_order'][i],
               data.info['ground'][i], data.info['ipd'][i], data.info['dsize'][i],
               np.mean(s), np.mean(h)]
        log = [str(x) for x in log]
        f.write(','.join(log) + '\n')


In [11]:
'''Output speed and heading data for LME analysis by timestamps'''
# Writes one CSV row per kept trial per frame of the final `shortest` frames before obstacle exit,
# where `shortest` is the shortest onset-to-exit span among non-dumped moving-obstacle trials.
# Heading `h` is in degrees and sign-flipped for negative obstacle angles.
load_data('Bai_movObst2_data_30Hz')
shortest = float("inf")
for i in range(len(data.trajs)):
    if i in data.dump or not data.info['obst_speed'][i]:
        continue
    t0, t1 = data.info['obst_onset'][i], data.info['obst_out'][i]
    shortest = min(shortest, t1 - t0)
# Time stamps (s) of the window frames, shared by all trials
ts = np.linspace(0, shortest/data.Hz, int(shortest))
with open('Bai_movObst2_speed_heading_time_series.csv', 'w') as f:
    f.write('subject,trial,angle,speed,order,ground,ipd,dsize,time,s,h\n')
    for i in range(len(data.trajs)):
        if i in data.dump or not data.info['obst_speed'][i]:
            continue
        t1 = data.info['obst_out'][i]
        v = data.info['v_subj'][i][t1-shortest:t1]
        # Odd-indexed trials use the mirrored reference direction
        ref = [9, 11] if i % 2 == 0 else [-9, -11]
        s, h = v2sp(v, ref=ref)
        # Convert to degree
        h *= 180/pi
        h = -h if data.info['obst_angle'][i] < 0 else h
        for j, t in enumerate(ts):
            log = [data.info['subj_id'][i], data.info['trial_id'][i],
                   abs(data.info['obst_angle'][i]), data.info['obst_speed'][i], data.info['pass_order'][i],
                   data.info['ground'][i], data.info['ipd'][i], data.info['dsize'][i],
                   t, s[j], h[j]]
            log = [str(x) for x in log]
            f.write(','.join(log) + '\n')

In [None]:
# Inspect which per-trial fields are available in data.info
data.info.keys()

### Bai_movObst2: testing features

In [6]:
# Time (frame index within the analysis window) of maximum absolute heading, per subject.
# h_t_max: mean over a subject's trials of each trial's peak frame.
# h_t_max_mean: peak frame of the subject's trial-averaged |heading| curve.
load_data('Bai_movObst2_data_30Hz')
# Trials analysed: not dumped and with a moving obstacle
moving = [k for k in range(len(data.trajs)) if k not in data.dump and data.info['obst_speed'][k]]
# Window length: shortest onset-to-exit span among those trials
shortest = min((data.info['obst_out'][k] - data.info['obst_onset'][k] for k in moving),
               default=float("inf"))
peak_frames = {}   # subject -> per-trial argmax frames
abs_headings = {}  # subject -> per-trial |heading| series
for k in moving:
    end = data.info['obst_out'][k]
    window = data.info['v_subj'][k][end - shortest:end]
    _, heading = v2sp(window, ref=[-9, -11] if k % 2 else [9, 11])
    heading = np.abs(heading)
    subject = data.info['subj_id'][k]
    peak_frames.setdefault(subject, []).append(np.argmax(heading))
    abs_headings.setdefault(subject, []).append(heading)
h_t_max = {subject: np.mean(frames) for subject, frames in peak_frames.items()}
h_t_max_mean = {subject: np.argmax(np.mean(series, axis=0)) for subject, series in abs_headings.items()}
print(h_t_max, h_t_max_mean)

{1: 35.46689895470383, 2: 32.572916666666664, 3: 38.92446043165467, 4: 27.319444444444443, 5: 33.88811188811189, 6: 33.21875, 7: 46.1184668989547, 8: 50.21478873239437, 9: 41.02439024390244, 10: 34.498245614035085, 11: 52.36619718309859, 12: 37.90625, 13: 57.666666666666664} {1: 15, 2: 35, 3: 42, 4: 25, 5: 36, 6: 36, 7: 53, 8: 53, 9: 43, 10: 33, 11: 60, 12: 40, 13: 61}


In [13]:
# Output time and value of max heading and speed by trial.
# Times are frame indices within the final `shortest` frames before obstacle exit; heading is |deg|.
load_data('Bai_movObst2_data_30Hz')
# Trials analysed: not dumped and with a moving obstacle
kept = [k for k in range(len(data.trajs)) if k not in data.dump and data.info['obst_speed'][k]]
# Find the shortest trial (onset-to-exit span) among them
shortest = min((data.info['obst_out'][k] - data.info['obst_onset'][k] for k in kept),
               default=float("inf"))
with open('Bai_movObst2_feature_trial.csv', 'w') as out:
    out.write('subject,trial,angle,speed,order,ground,ipd,dsize,hmax,hmaxtime,smax,smaxtime\n')
    for k in kept:
        info = data.info
        end = info['obst_out'][k]
        spd, hdg = v2sp(info['v_subj'][k][end - shortest:end], ref=[-9, -11] if k % 2 else [9, 11])
        # radians -> degrees, then keep the magnitude only
        hdg *= 180/pi
        hdg = np.abs(hdg)
        row = (info['subj_id'][k], info['trial_id'][k],
               abs(info['obst_angle'][k]), info['obst_speed'][k], info['pass_order'][k],
               info['ground'][k], info['ipd'][k], info['dsize'][k],
               np.max(hdg), np.argmax(hdg), np.max(spd), np.argmax(spd))
        out.write(','.join(str(value) for value in row) + '\n')
                

In [21]:
# Output time of maximum heading by subject: trials are pooled per
# (subject, angle, speed, order, ground, ipd, dsize) condition, their |heading| and speed series are
# averaged, and the peak value and peak frame of each averaged series are written.
load_data('Bai_movObst2_data_30Hz')
logs = {}  # condition tuple -> [[|heading| series], [speed series]]
# Trials analysed: not dumped and with a moving obstacle
kept = [k for k in range(len(data.trajs)) if k not in data.dump and data.info['obst_speed'][k]]
# Find the shortest trial (onset-to-exit span) among them
shortest = min((data.info['obst_out'][k] - data.info['obst_onset'][k] for k in kept),
               default=float("inf"))
for k in kept:
    info = data.info
    end = info['obst_out'][k]
    spd, hdg = v2sp(info['v_subj'][k][end - shortest:end], ref=[-9, -11] if k % 2 else [9, 11])
    # radians -> degrees, then keep the magnitude only
    hdg *= 180/pi
    hdg = np.abs(hdg)
    key = (info['subj_id'][k], abs(info['obst_angle'][k]), info['obst_speed'][k],
           info['pass_order'][k], info['ground'][k], info['ipd'][k],
           info['dsize'][k])
    bucket = logs.setdefault(key, [[], []])
    bucket[0].append(hdg)
    bucket[1].append(spd)
with open('Bai_movObst2_feature_subject.csv', 'w') as out:
    out.write('subject,angle,speed,order,ground,ipd,dsize,hmax,hmaxtime,smax,smaxtime\n')
    for key, (hdg_list, spd_list) in logs.items():
        mean_h = np.mean(hdg_list, axis=0)
        mean_s = np.mean(spd_list, axis=0)
        row = [*key, np.max(mean_h), np.argmax(mean_h), np.max(mean_s), np.argmax(mean_s)]
        out.write(','.join(map(str, row)) + '\n')

624

In [20]:
# Distinct obstacle speeds in the loaded dataset; speed-0 trials are the ones the analysis cells skip
set(data.info['obst_speed'])

{0.0, 0.9, 1.1}

### Bai_movObst2: Testing average trajectories

In [None]:
# Testing the effect of a variable:
# 1. Repeatedly divide the pooled trials randomly into two groups and record the mean distance
#    between the two groups' averaged trajectories; this gives a null distribution.
# 2. Check whether the mean distance between the averaged trajectories of the variable's two
#    levels is significantly different from the mean of that null distribution (as a z-score).


In [None]:
'''Plot average trajectories ground vs no ground'''
load_data('Bai_movObst2_data_30Hz')
%matplotlib qt
y_max = float('inf')
y_min = -float('inf')
for i in range(len(data.trajs)):
    if i in data.dump or data.info['obst_speed'][i] == 0:
        continue
    traj = np.array(data.info['p_subj'][i])
    traj = -traj if i % 2 == 1 else traj
    traj = traj - traj[0]
    traj = rotate(traj, np.arctan(11 / 9) - pi / 2)
    y_max = min(y_max, max(traj[:, 1]))
    y_min = max(y_min, min(traj[:, 1]))
    if y_max == 0:
        print(i)
print(f'common y range among all trials is [{y_min}, {y_max}]')

# Compute y positions
y_step = 0.01
y = np.linspace(y_min, y_max, int((y_max - y_min) / y_step))
# Interpolate x by y
angles = sorted(set([abs(x) for x in data.info['obst_angle'] if x]))
speeds = sorted(set([abs(x) for x in data.info['obst_speed'] if x]))
print(angles, speeds)
trajs = {}
for ground in [0, 1]:
    iplot = 1
    fig = plt.figure()
    for angle in angles:
        for speed in speeds:
            ax = fig.add_subplot(len(angles), len(speeds), iplot)
            iplot += 1
            ax.set_title(f'{angle}°, {speed} m/s, ground={ground}')
            if angle == 180:
                ax.set_xlabel('x (m)')
            if speed == 0.9:
                ax.set_ylabel('y (m)')
            ax.set_aspect('equal')
            ax.set_xlim(-2, 2)
            ax.set_ylim(-0.2, 12)
            front = []
            behind = []
            # Plot trials
            for i in range(len(data.trajs)):
                if (angle != abs(data.info['obst_angle'][i]) or speed != data.info['obst_speed'][i] or i in data.dump or
                data.info['ground'][i] != ground):
                    continue
                traj = np.array(data.info['p_subj'][i])
                traj = -traj if i % 2 == 1 else traj
                traj = traj - traj[0]
                traj = rotate(traj, np.arctan(11 / 9) - pi / 2)
                if data.info['obst_angle'][i] < 0:
                    traj[:,0] = -traj[:,0]
                x = np.interp(y, traj[:,1], traj[:,0])
                order = data.info['pass_order'][i]
                if order == 1:
                    front.append(x)
                    ax.plot(traj[:, 0], traj[:, 1], color='r', linewidth=0.1, alpha=0.3)
                elif order == -1:
                    behind.append(x)
                    ax.plot(traj[:, 0], traj[:, 1], color='b', linewidth=0.1, alpha=0.3)
            trajs[(ground, angle, speed, 1)] = front
            trajs[(ground, angle, speed, -1)] = behind
            # Plot average trials
            front = np.mean(front, axis=0)
            behind = np.mean(behind, axis=0)
            ax.plot(front, y, color='r')
            ax.plot(behind, y, color='b')

In [None]:
'''Test average trajectories ground vs no ground'''
# Permutation z-score of the mean absolute x-distance between the two condition-average
# trajectories. The null distribution re-splits the pooled trials into groups of the SAME sizes as
# the real conditions; equal halves would understate the null spread when sizes differ and inflate z.
rng = random.Random(0)  # fixed seed so the z-scores are reproducible across re-runs
n_perm = 1000
zs = {}
for angle in angles:
    for speed in speeds:
        for order in [-1, 1]:
            group1 = trajs[(0, angle, speed, order)]
            group2 = trajs[(1, angle, speed, order)]
            if not group1 or not group2:
                # One level has no trials: the comparison is undefined
                zs[(angle, speed, order)] = float('nan')
                continue
            observed = np.mean(np.abs(np.mean(group1, axis=0) - np.mean(group2, axis=0)))
            pooled = group1 + group2  # new list, so shuffling never reorders trajs
            n1 = len(group1)
            sample = []
            for _ in range(n_perm):
                rng.shuffle(pooled)
                g1 = np.mean(pooled[:n1], axis=0)
                g2 = np.mean(pooled[n1:], axis=0)
                sample.append(np.mean(np.abs(g1 - g2)))
            zs[(angle, speed, order)] = (observed - np.mean(sample)) / np.std(sample)
zs

In [None]:
'''Plot average trajectories disparity vs no disparity'''
load_data('Bai_movObst2_data_30Hz')
%matplotlib qt
y_max = float('inf')
y_min = -float('inf')
for i in range(len(data.trajs)):
    if i in data.dump or data.info['obst_speed'][i] == 0:
        continue
    traj = np.array(data.info['p_subj'][i])
    traj = -traj if i % 2 == 1 else traj
    traj = traj - traj[0]
    traj = rotate(traj, np.arctan(11 / 9) - pi / 2)
    y_max = min(y_max, max(traj[:, 1]))
    y_min = max(y_min, min(traj[:, 1]))
    if y_max == 0:
        print(i)
print(f'common y range among all trials is [{y_min}, {y_max}]')

# Compute y positions
y_step = 0.01
y = np.linspace(y_min, y_max, int((y_max - y_min) / y_step))
# Interpolate x by y
angles = sorted(set([abs(x) for x in data.info['obst_angle'] if x]))
speeds = sorted(set([abs(x) for x in data.info['obst_speed'] if x]))
print(angles, speeds)
trajs = {}
for ipd in [0, 0.07]:
    iplot = 1
    fig = plt.figure()
    for angle in angles:
        for speed in speeds:
            ax = fig.add_subplot(len(angles), len(speeds), iplot)
            iplot += 1
            ax.set_title(f'{angle}°, {speed} m/s, ipd={ipd}')
            if angle == 180:
                ax.set_xlabel('x (m)')
            if speed == 0.9:
                ax.set_ylabel('y (m)')
            ax.set_aspect('equal')
            ax.set_xlim(-2, 2)
            ax.set_ylim(-0.2, 12)
            front = []
            behind = []
            # Plot trials
            for i in range(len(data.trajs)):
                if (angle != abs(data.info['obst_angle'][i]) or speed != data.info['obst_speed'][i] or i in data.dump or
                data.info['ipd'][i] != ipd):
                    continue
                traj = np.array(data.info['p_subj'][i])
                traj = -traj if i % 2 == 1 else traj
                traj = traj - traj[0]
                traj = rotate(traj, np.arctan(11 / 9) - pi / 2)
                if data.info['obst_angle'][i] < 0:
                    traj[:,0] = -traj[:,0]
                x = np.interp(y, traj[:,1], traj[:,0])
                order = data.info['pass_order'][i]
                if order == 1:
                    front.append(x)
                    ax.plot(traj[:, 0], traj[:, 1], color='r', linewidth=0.1, alpha=0.3)
                elif order == -1:
                    behind.append(x)
                    ax.plot(traj[:, 0], traj[:, 1], color='b', linewidth=0.1, alpha=0.3)
            trajs[(ipd, angle, speed, 1)] = front
            trajs[(ipd, angle, speed, -1)] = behind
            # Plot average trials
            front = np.mean(front, axis=0)
            behind = np.mean(behind, axis=0)
            ax.plot(front, y, color='r')
            ax.plot(behind, y, color='b')

In [None]:
'''Test average trajectories disparity vs no disparity'''
# Permutation z-score of the mean absolute x-distance between the ipd=0 and ipd=0.07 average
# trajectories. The null distribution re-splits the pooled trials into groups of the SAME sizes as
# the real conditions; equal halves would understate the null spread when sizes differ and inflate z.
rng = random.Random(0)  # fixed seed so the z-scores are reproducible across re-runs
n_perm = 1000
zs = {}
for angle in angles:
    for speed in speeds:
        for order in [-1, 1]:
            group1 = trajs[(0, angle, speed, order)]
            group2 = trajs[(0.07, angle, speed, order)]
            if not group1 or not group2:
                # One level has no trials: the comparison is undefined
                zs[(angle, speed, order)] = float('nan')
                continue
            observed = np.mean(np.abs(np.mean(group1, axis=0) - np.mean(group2, axis=0)))
            pooled = group1 + group2  # new list, so shuffling never reorders trajs
            n1 = len(group1)
            sample = []
            for _ in range(n_perm):
                rng.shuffle(pooled)
                g1 = np.mean(pooled[:n1], axis=0)
                g2 = np.mean(pooled[n1:], axis=0)
                sample.append(np.mean(np.abs(g1 - g2)))
            zs[(angle, speed, order)] = (observed - np.mean(sample)) / np.std(sample)
zs

In [None]:
# Plot average trajectories expansion vs contraction
load_data('Bai_movObst2_data_30Hz')
%matplotlib qt
y_max = float('inf')
y_min = -float('inf')
for i in range(len(data.trajs)):
    if i in data.dump or data.info['obst_speed'][i] == 0:
        continue
    traj = np.array(data.info['p_subj'][i])
    traj = -traj if i % 2 == 1 else traj
    traj = traj - traj[0]
    traj = rotate(traj, np.arctan(11 / 9) - pi / 2)
    y_max = min(y_max, max(traj[:, 1]))
    y_min = max(y_min, min(traj[:, 1]))
    if y_max == 0:
        print(i)
print(f'common y range among all trials is [{y_min}, {y_max}]')

# Compute y positions
y_step = 0.01
y = np.linspace(y_min, y_max, int((y_max - y_min) / y_step))
# Interpolate x by y
angles = sorted(set([abs(x) for x in data.info['obst_angle'] if x]))
speeds = sorted(set([abs(x) for x in data.info['obst_speed'] if x]))
print(angles, speeds)
trajs = {}
for dsize in [-0.1, 0, 0.1]:
    iplot = 1
    fig = plt.figure()
    for angle in angles:
        for speed in speeds:
            ax = fig.add_subplot(len(angles), len(speeds), iplot)
            iplot += 1
            ax.set_title(f'{angle}°, {speed} m/s, dsize={dsize}')
            if angle == 180:
                ax.set_xlabel('x (m)')
            if speed == 0.9:
                ax.set_ylabel('y (m)')
            ax.set_aspect('equal')
            ax.set_xlim(-2, 2)
            ax.set_ylim(-0.2, 12)
            front = []
            behind = []
            # Plot trials
            for i in range(len(data.trajs)):
                if (angle != abs(data.info['obst_angle'][i]) or speed != data.info['obst_speed'][i] or i in data.dump or
                data.info['dsize'][i] != dsize):
                    continue
                traj = np.array(data.info['p_subj'][i])
                traj = -traj if i % 2 == 1 else traj
                traj = traj - traj[0]
                traj = rotate(traj, np.arctan(11 / 9) - pi / 2)
                if data.info['obst_angle'][i] < 0:
                    traj[:,0] = -traj[:,0]
                x = np.interp(y, traj[:,1], traj[:,0])
                order = data.info['pass_order'][i]
                if order == 1:
                    front.append(x)
                    ax.plot(traj[:, 0], traj[:, 1], color='r', linewidth=0.1, alpha=0.3)
                elif order == -1:
                    behind.append(x)
                    ax.plot(traj[:, 0], traj[:, 1], color='b', linewidth=0.1, alpha=0.3)
            trajs[(dsize, angle, speed, 1)] = front
            trajs[(dsize, angle, speed, -1)] = behind
            # Plot average trials
            front = np.mean(front, axis=0)
            behind = np.mean(behind, axis=0)
            ax.plot(front, y, color='r')
            ax.plot(behind, y, color='b')

In [None]:
'''Test average trajectories contraction vs expansion (dsize -0.1 vs 0.1)'''
# Permutation z-score of the mean absolute x-distance between the dsize=-0.1 and dsize=0.1 average
# trajectories (dsize=0 is not part of this comparison). The null distribution re-splits the pooled
# trials into groups of the SAME sizes as the real conditions; equal halves would understate the
# null spread when sizes differ and inflate z.
rng = random.Random(0)  # fixed seed so the z-scores are reproducible across re-runs
n_perm = 1000
zs = {}
for angle in angles:
    for speed in speeds:
        for order in [-1, 1]:
            group1 = trajs[(-0.1, angle, speed, order)]
            group2 = trajs[(0.1, angle, speed, order)]
            if not group1 or not group2:
                # One level has no trials: the comparison is undefined
                zs[(angle, speed, order)] = float('nan')
                continue
            observed = np.mean(np.abs(np.mean(group1, axis=0) - np.mean(group2, axis=0)))
            pooled = group1 + group2  # new list, so shuffling never reorders trajs
            n1 = len(group1)
            sample = []
            for _ in range(n_perm):
                rng.shuffle(pooled)
                g1 = np.mean(pooled[:n1], axis=0)
                g2 = np.mean(pooled[n1:], axis=0)
                sample.append(np.mean(np.abs(g1 - g2)))
            zs[(angle, speed, order)] = (observed - np.mean(sample)) / np.std(sample)
zs

In [None]:
# Plot average trajectories trial set 1 vs trial set 2
load_data('Bai_movObst2_data_30Hz')
%matplotlib qt
y_max = float('inf')
y_min = -float('inf')
for i in range(len(data.trajs)):
    if i in data.dump or data.info['obst_speed'][i] == 0:
        continue
    traj = np.array(data.info['p_subj'][i])
    traj = -traj if i % 2 == 1 else traj
    traj = traj - traj[0]
    traj = rotate(traj, np.arctan(11 / 9) - pi / 2)
    y_max = min(y_max, max(traj[:, 1]))
    y_min = max(y_min, min(traj[:, 1]))
    if y_max == 0:
        print(i)
print(f'common y range among all trials is [{y_min}, {y_max}]')

# Compute y positions
y_step = 0.01
y = np.linspace(y_min, y_max, int((y_max - y_min) / y_step))
# Interpolate x by y
angles = sorted(set([abs(x) for x in data.info['obst_angle'] if x]))
speeds = sorted(set([abs(x) for x in data.info['obst_speed'] if x]))
print(angles, speeds)
trajs = {}
for n in [0, 1]:
    iplot = 1
    fig = plt.figure()
    for angle in angles:
        for speed in speeds:
            ax = fig.add_subplot(len(angles), len(speeds), iplot)
            iplot += 1
            ax.set_title(f'{angle}°, {speed} m/s, trial set {n}')
            if angle == 180:
                ax.set_xlabel('x (m)')
            if speed == 0.9:
                ax.set_ylabel('y (m)')
            ax.set_aspect('equal')
            ax.set_xlim(-2, 2)
            ax.set_ylim(-0.2, 12)
            front = []
            behind = []
            # Plot trials
            for i in range(len(data.trajs)):
                if (angle != abs(data.info['obst_angle'][i]) or speed != data.info['obst_speed'][i] or i in data.dump or
                i % 2 != n):
                    continue
                traj = np.array(data.info['p_subj'][i])
                traj = -traj if i % 2 == 1 else traj
                traj = traj - traj[0]
                traj = rotate(traj, np.arctan(11 / 9) - pi / 2)
                if data.info['obst_angle'][i] < 0:
                    traj[:,0] = -traj[:,0]
                x = np.interp(y, traj[:,1], traj[:,0])
                order = data.info['pass_order'][i]
                if order == 1:
                    front.append(x)
                    ax.plot(traj[:, 0], traj[:, 1], color='r', linewidth=0.1, alpha=0.3)
                elif order == -1:
                    behind.append(x)
                    ax.plot(traj[:, 0], traj[:, 1], color='b', linewidth=0.1, alpha=0.3)
            trajs[(n, angle, speed, 1)] = front
            trajs[(n, angle, speed, -1)] = behind
            # Plot average trials
            front = np.mean(front, axis=0)
            behind = np.mean(behind, axis=0)
            ax.plot(front, y, color='r')
            ax.plot(behind, y, color='b')

In [None]:
'''Test average trajectories trial set 0 vs trial set 1'''
# Permutation z-score of the mean absolute x-distance between the even-index (set 0) and odd-index
# (set 1) average trajectories. The null distribution re-splits the pooled trials into groups of the
# SAME sizes as the real sets; equal halves would understate the null spread when sizes differ and
# inflate z.
rng = random.Random(0)  # fixed seed so the z-scores are reproducible across re-runs
n_perm = 1000
zs = {}
for angle in angles:
    for speed in speeds:
        for order in [-1, 1]:
            group1 = trajs[(0, angle, speed, order)]
            group2 = trajs[(1, angle, speed, order)]
            if not group1 or not group2:
                # One set has no trials: the comparison is undefined
                zs[(angle, speed, order)] = float('nan')
                continue
            observed = np.mean(np.abs(np.mean(group1, axis=0) - np.mean(group2, axis=0)))
            pooled = group1 + group2  # new list, so shuffling never reorders trajs
            n1 = len(group1)
            sample = []
            for _ in range(n_perm):
                rng.shuffle(pooled)
                g1 = np.mean(pooled[:n1], axis=0)
                g2 = np.mean(pooled[n1:], axis=0)
                sample.append(np.mean(np.abs(g1 - g2)))
            zs[(angle, speed, order)] = (observed - np.mean(sample)) / np.std(sample)
zs

In [None]:
'''% Correct pass order by dtheta threshold'''
%matplotlib qt
load_data('Bai_movObst1_data_30Hz')
threshold = 0.02
thress = np.linspace(0.001, 0.03, 20)
threshold = thress[np.argmin(np.abs(thress - threshold))]
accuracies = []
lengths_thres = []
lengths = []
for thres in thress:
    correct = 0
    total = 0
    for i in range(len(data.trajs)):
        if i not in data.dump and data.info['obst_speed'][i] > 0 and abs(data.info['obst_angle'][i]) != 180:
            t0 = data.info['obst_onset'][i]
            t1 = data.info['obst_out'][i]
            rel_dpsi = data.info['rel_dpsi'][i]
            dtheta = data.info['dtheta'][i]
            for t in range(t0, t1):
                if dtheta[t] > thres:
                    dpsi = rel_dpsi[t]
                    break
            else:
                dpsi = rel_dpsi[t]
            if thres == threshold:
                lengths.append((t1 - t0) / data.Hz)
                lengths_thres.append((t1 - t) / data.Hz)
            if np.sign(dpsi) == data.info['pass_order'][i]:
                correct += 1
            total += 1
    accuracies.append(correct / total)
plt.figure()
plt.plot(thress, accuracies)
plt.title('Bai_movObst1')
plt.ylabel('pass order accuracy')
plt.xlabel('dthetha threshold (rad/s)')
plt.figure()
plt.hist(lengths_thres)
plt.title(f'Bai_movObst1 thres={threshold}')
plt.xlabel('effective trial length (s)')
plt.figure()
plt.hist(lengths)
plt.title('Bai_movObst1')
plt.xlabel('effective trial length (s)')

### Check training results

In [None]:
'''Check optimal parameters for individuals from training results'''
# For each result file, keep the line (line index 11 onward) with the smallest error in the third
# tab-separated column (first 10 characters; only values starting with '0' are considered), keyed by
# the subject id sliced from line index 1. Keys are now str (previously bytes).
bests = {}
path = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'Results', 'Bai_movObst1_cross_validation_cohen_avoid2_with_approach_trained_on_Bai_movObst1b_all_threshold_onset'))
filenames = [os.path.join(path, name) for name in os.listdir(path) if name[-3:] == 'txt']
for filename in filenames:
    subj_id = None  # reset per file so a short file cannot inherit the previous file's id
    best = 0
    e_min = float('inf')
    with open(filename, 'rb') as f:
        for i, raw in enumerate(f):
            # Decode instead of parsing str(bytes): that repr shows tabs as "\t", and stripping the
            # backslashes left stray "t" characters in the reported line
            line = raw.decode(errors='replace')
            if i == 1:
                # Same slice as before, taken on the undecoded line ending (keeps the \n vs \r\n offset)
                subj_id = line[-5:-2]
            if i >= 11:
                text = line.rstrip('\r\n')
                try:
                    err = text.split('\t')[2][:10]
                    if err[0] == '0':
                        err = float(err)
                        if err < e_min:
                            e_min = err
                            best = text
                except Exception as e:
                    # Best effort: report unparsable lines and keep scanning
                    print(e)
    if subj_id is None:
        print(f'{filename}: no subject id line; skipped')
        continue
    bests[subj_id] = best
for i, best in bests.items():
    print('\n')
    print(i, best)

In [None]:
'''Check optimal parameters for all subjects from training results'''
# For each result file, print the line (line index 11 onward) with the smallest error in the third
# tab-separated column; only values starting with '0' are considered.
path = os.path.abspath(os.path.join(os.getcwd(), os.pardir, 'Results', 'Bai_movObst1_all_all_models_with_approach_trained_on_Bai_movObst1b_all_threshold_onset'))
filenames = [os.path.join(path, name) for name in os.listdir(path) if name[-3:] == 'txt']
for filename in filenames:
    best = 0
    e_min = float('inf')
    with open(filename, 'rb') as f:
        for i, raw in enumerate(f):
            if i <= 10:
                continue
            # Decode instead of parsing str(bytes): that repr shows tabs as "\t", and stripping the
            # backslashes left stray "t" characters in the reported line
            text = raw.decode(errors='replace').rstrip('\r\n')
            try:
                err = text.split('\t')[2]
                if err[0] == '0':
                    err = float(err)
                    if err < e_min:
                        e_min = err
                        best = text
            except Exception as e:
                # Best effort: report unparsable lines and keep scanning
                print(e)

    print('\n')
    print(best)
    

In [None]:
# Display the best training-result line per subject id collected into `bests`
bests