In [1]:
import os
import sys
from scipy.integrate import solve_ivp
from scipy.stats import ttest_rel
import numpy as np
from numpy.linalg import norm
from numpy import sqrt
from math import pi
import pickle
import matplotlib.pyplot as plt
from packages import data_container
from packages.data_container import Data
from packages.helper import play_trajs, rotate, sp2a, v2sp, dist, psi, beta, d_theta, d_psi, sp2v, dist, \
    traj_speed, min_sep
from packages.ode_simulator import ODESimulator
# The pickled Data object refers to the module by its bare name `data_container`;
# alias packages.data_container under that name so pickle can resolve the class.
sys.modules['data_container'] = data_container

# Dataset to analyse. The Bai data set sits in the same Raw_Data folder as
# 'Bai_movObst1_data.pickle' and can be swapped in for the file name below.
# SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
raw_data_dir = os.path.join(os.getcwd(), os.pardir, 'Raw_Data')
file = os.path.abspath(os.path.join(raw_data_dir, 'Cohen_movObst1_data.pickle'))
with open(file, 'rb') as fh:
    data = pickle.load(fh)

'''Models'''
# Default parameter sets, keyed by the model name the simulator dispatches on.
fajen_approach = dict(name='fajen_approach', ps=1, b1=3.25, k1=7.5, c1=0.4, c2=0.4, k2=1.4)
fajen_approach2 = dict(name='fajen_approach2',
                       ps=1.3, b1=3.25, k1=7.5, c1=0.4, c2=0.4, b2=4.8, k2=6)
cohen_avoid = dict(name='cohen_avoid', ps=1.3, b1=3.25, k1=530, c5=6, c6=1.3,
                   b2=3.25, k2=530, c7=6, c8=1.3)
# Same parameters as cohen_avoid; only the model name differs.
cohen_avoid2 = {**cohen_avoid, 'name': 'cohen_avoid2'}
cohen_avoid3 = dict(name='cohen_avoid3', k1=530, c5=6, c6=1.3, k2=50, c7=6, c8=1.3)
# Same parameters as cohen_avoid3; only the model name differs.
cohen_avoid4 = {**cohen_avoid3, 'name': 'cohen_avoid4'}

In [6]:
'''Simulation with var0'''
%matplotlib qt
Hz = 100
xg0, yg0, xo0, yo0, x0, y0 = 0, 10, 5, 5, 0, 0
vxo0, vyo0, vx0, vy0 = -1, 0, 0, 1
s0, phi0 = v2sp([vx0, vy0])
dphi0 = dds = 0
var0 = [xg0, yg0, xo0, yo0, vxo0, vyo0, x0, y0, vx0, vy0, phi0, s0, dphi0, dds]
models = [fajen_approach, cohen_avoid]
args = {'w_goal': 0.1, 'w_obst': 0.1}
sim = ODESimulator(Hz=90, models=models, args=args, ref=[0, 1])
sim.simulate(var0, total_time=10)
sim.play()

Simulation finished in 0:00:00 t_total 9.988889


In [3]:
'''Simulate one trial'''
############
i_trial = 996
approach = {'name': 'fajen_approach', 'b1': 2.01938384, 'k1': 4.90527274, 'c1': 2.96094879, 'c2': 0.50896525, 'k2': 1.61216734}
avoid = {'name': 'cohen_avoid', 'b1': 9.765711751753193, 'k1': 267.9131480127584, 'c5': 4.177541546041094, 'c6': 5.658734281737619, 'b2': 41.98739870194379, 'k2': 323.75304694723644, 'c7': 0.40326520167129853, 'c8': 0.49703900483204855}
models = [approach, avoid]
############
%matplotlib qt
sim = ODESimulator(data=data, ref=[0, 1])
sim.models = models
sim.simulate_all(trials=[i_trial], t_start='obst_onset', t_end='obst_out', ps='trial')
title = 'subj ' + str(sim.data.info['subj_id'][i_trial]) + \
        ' trial ' + str(sim.data.info['trial_id'][i_trial]) + \
        ' obst_angle: ' + str(sim.data.info['obst_angle'][i_trial]) + \
        ' obst_speed: ' + str(sim.data.info['obst_speed'][i_trial])
sim.test('p_dist')
sim.play(title=title)

Loading finished
simulation ended early on trial 996, switch to Euler method
0.00028726910622838524 0.0011682181661080086
Simulated 1 trials in 0:00:00


<matplotlib.animation.FuncAnimation at 0x238b8e60f08>

In [5]:
# Play the animation of whatever simulation `sim` currently holds again, this time without a title.
sim.play()

In [2]:
'''Simuate all trials model 1'''
############
subject = 1
approach = {'name': 'fajen_approach', 'b1': 2.01938384, 'k1': 4.90527274, 'c1': 2.96094879, 'c2': 0.50896525, 'k2': 1.61216734}
avoid = {'name': 'cohen_avoid', 'b1': 9.765711751753193, 'k1': 267.9131480127584, 'c5': 4.177541546041094, 'c6': 5.658734281737619, 'b2': 41.98739870194379, 'k2': 323.75304694723644, 'c7': 0.40326520167129853, 'c8': 0.49703900483204855}
############
%matplotlib qt
models = [approach, avoid]
sim1 = ODESimulator(data=data, ref=[0, 1])
sim1.reset()
sim1.models = models
trials = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 
421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, 821, 
822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007]

# for i in range(len(data.trajs)):
#     if i in data.dump or data.info['subj_id'][i] != subject or data.info['obst_speed'][i] == 0:
#         continue
#     trials.append(i)
sim1.simulate_all(trials, t_start='stimuli_onset', t_end='obst_out', ps='trial')
sim1.test('p_dist')
sim1.test('accuracy')

Loading finished
Simulated 1005 trials in 0:00:44


0.19508337878808443

In [3]:
'''Simuate all trials model 2'''
############
subject = 1
approach = {'name': 'fajen_approach2', 'b1': 2.04992354, 'k1': 2.85641543, 'c1': 0.54294928, 'c2': 0.73857217, 'b2': 3.89580222, 'k2': 5.04511601, 'ps': 1.3075058885951525}
avoid =  {'name': 'cohen_avoid4_thres', 'k1': 2.88211554599569, 'c5': 1.8402080467654376, 'c6': 10.786835005989973, 'k2': 3.380457830359567, 'c7': 5.4360445618996565, 'c8': 12.67287871326611, 'thres': 0.006145124061037794, 'ps': 1.3075058885951525}
############
%matplotlib qt
models = [approach, avoid]
sim2 = ODESimulator(data=data, ref=[0, 1])
sim2.reset()
sim2.models = models
trials = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 
421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, 611, 612, 613, 614, 615, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, 821, 
822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007]

# for i in range(len(data.trajs)):
#     if i in data.dump or data.info['subj_id'][i] != subject or data.info['obst_speed'][i] == 0:
#         continue
#     trials.append(i)
sim2.simulate_all(trials, t_start='stimuli_onset', t_end='obst_out', ps='trial')
sim2.test('p_dist')
sim2.test('accuracy')

Loading finished
Simulated 1005 trials in 0:01:00


0.18389660270629649

In [19]:
'''Minimum Passing Distance'''
%matplotlib qt
sim = sim2
fig = plt.figure()
ax = fig.add_subplot()
ax.set_title('Signed predicted minimum passing distance (SMPD)')
ax.set_ylabel('SMPD (m)')
ax.set_xlabel('normalized time (%)')
ax.set_ylim((-2, 2))
for j, i in enumerate(sim.i_trials):
    t0 = data.info['stimuli_onset'][i]
    t1 = data.info['stimuli_out'][i]
    p0 = sim.p_pred[j]
    p1 = data.info['p_obst'][i][t0:t1]
    v0 = sim.v_pred[j]
    v1 = data.info['v_obst'][i][t0:t1]
    t = np.linspace(0, 100, len(p0))
    smpd = []
    for _p0, _p1, _v0, _v1 in zip(p0, p1, v0, v1):        
        smpd.append(min_sep(_p0, _p1, _v0, _v1)[0])
    ax.plot(t, smpd, 'k', linewidth=0.1, alpha=0.5)

In [6]:
# Sanity check: p0/p1 are whatever a previously run cell left behind (e.g. the
# last trial of the SMPD loop), so their lengths should agree.
print(*[len(seq) for seq in (p0, p1, sim.i_trials)])

123 123 1005


In [16]:
'''Compute pred_order'''
s = sim
orders = []
# Compare the recorded pass order with one predicted from the final simulated frame.
for i in range(len(s.p_pred)):
    j = s.i_trials[i]  # row of data.info for simulated trial i
    true_order = s.data.info['pass_order'][j]
    # Last-frame walker position/velocity and obstacle position.
    p0, v0 = s.p_pred[i][-1], s.v_pred[i][-1]
    p1 = s.p_obst[i][-1]
    _beta = beta(p0, p1, v0)
    angle = s.data.info['obst_angle'][j]
    # Predicted order is the sign of beta flipped by the sign of the obstacle angle.
    pred_order = np.sign(-angle * _beta)
    orders.append(true_order == pred_order)
print(np.mean(orders))

0.8398009950248756


In [13]:
# Pass-order accuracy computed by ODESimulator.test. Note: in the saved outputs this
# value differs from the manual 'Compute pred_order' result above.
sim.test('order_accuracy')

0.6716417910447762

In [9]:
'''t-test between two models'''
# Paired t-test on per-trial passing-distance errors: model 1 (sim1) vs model 2 (sim2).
# Use sim1 explicitly -- `sim` is rebound to sim2 in the SMPD cell, so on a
# top-to-bottom run `sim` vs `sim2` would compare model 2 with itself.
# A paired test is only meaningful if both models ran on the same trials in the same order.
assert list(sim1.i_trials) == list(sim2.i_trials), 'models were simulated on different trials'
err1 = sim1.test('p_dist', all_errors=True)
err2 = sim2.test('p_dist', all_errors=True)
print(ttest_rel(err1, err2))
# np.std defaults to ddof=0 (population SD).
print(f'sd1 = {np.std(err1)}, sd2 = {np.std(err2)}')
print(f'df = {len(err1) - 1}')

Ttest_relResult(statistic=6.88882649896061, pvalue=9.913386664917606e-12)
sd1 = 0.17176037737296396, sd2 = 0.16896790728323252
df = 1004


In [72]:
'''Show one trial from batch simulations'''
######## index in simulated trials only
i = 83
########
# `i` indexes the simulation results (p_pred, pass_order_pred, ...);
# `i_trial` is the matching row of sim.data.info.
i_trial = sim.i_trials[i]
title = 'subj ' + str(sim.data.info['subj_id'][i_trial]) + \
        ' trial ' + str(sim.data.info['trial_id'][i_trial]) + \
        ' obst_angle: ' + str(sim.data.info['obst_angle'][i_trial]) + \
        ' obst_speed: ' + str(sim.data.info['obst_speed'][i_trial])
sim.play(i, title=title, save=False)

# When beta and d_psi have the same sign the pass is in front; otherwise it is from behind.
# data.info is indexed by data-trial index, so look up the true order with i_trial (not i).
pass_order = sim.data.info['pass_order'][i_trial]
pred = sim.pass_order_pred[i]
print('pass order ', pass_order, 'prediction ', pred)
# NOTE(review): test() receives the simulated-trial index through its `i_trial`
# keyword; confirm which index space ODESimulator.test expects there.
print(f"err is {sim.test('p_dist', i_trial=i)}")

pass order  -1.0 prediction  1.0
err is 0.5442221570927824


In [65]:
'''true vs predicted speed by condition'''
#####################
subject = 1
#####################
%matplotlib qt
Hz = sim.Hz
fig = plt.figure()
fig.suptitle('Subject ' + str(subject))
axes = {}
obst_angle = sorted(set([abs(x) for x in data.info['obst_angle'] if x != 0]))
obst_speed = sorted(set([abs(x) for x in data.info['obst_speed'] if x != 0]))
i_plot = 1
for angle in obst_angle:
    for speed in obst_speed:
        axes[(angle, speed)] = fig.add_subplot(5, 5, i_plot)
        axes[(angle, speed)].set_xlim(0, 5)
        axes[(angle, speed)].set_ylim(0.5, 2)
        axes[(angle, speed)].set_title(str(angle) + '° ' + str(speed) + 'm/s')
#         axes[(angle, speed)].set_aspect('equal')
        i_plot += 1
for i, i_trial in enumerate(sim.i_trials):
    speed = sim.data.info['obst_speed'][i_trial]
    angle = sim.data.info['obst_angle'][i_trial]
    subj_id = sim.data.info['subj_id'][i_trial]
    if subj_id != subject or speed == 0:
        continue
    
    subj, pred = sim.p_subj[i], sim.p_pred[i]
    s_subj, s_pred = traj_speed(subj, Hz), traj_speed(pred, Hz)
    t = np.linspace(0, len(pred)-1, len(pred)) / Hz
    # Speed
    axes[(abs(angle), speed)].plot(t, s_subj, 'r')
    axes[(abs(angle), speed)].plot(t, s_pred, 'b')


In [67]:
'''true vs predicted heading by condition'''
#####################
subject = 1
#####################
%matplotlib qt
Hz = sim.Hz
fig = plt.figure()
fig.suptitle('Subject ' + str(subject))
axes = {}
obst_angle = sorted(set([abs(x) for x in data.info['obst_angle'] if x != 0]))
obst_speed = sorted(set([abs(x) for x in data.info['obst_speed'] if x != 0]))
i_plot = 1
for angle in obst_angle:
    for speed in obst_speed:
        axes[(angle, speed)] = fig.add_subplot(5, 5, i_plot)
        axes[(angle, speed)].set_xlim(0, 5)
        axes[(angle, speed)].set_ylim(0, 1.5)
        axes[(angle, speed)].set_title(str(angle) + '° ' + str(speed) + 'm/s')
#         axes[(angle, speed)].set_aspect('equal')
        i_plot += 1
for i, i_trial in enumerate(sim.i_trials):
    speed = sim.data.info['obst_speed'][i_trial]
    angle = sim.data.info['obst_angle'][i_trial]
    subj_id = sim.data.info['subj_id'][i_trial]
    if subj_id != subject or speed == 0:
        continue
    
    subj, pred = sim.p_subj[i], sim.p_pred[i]
    h_subj, h_pred = v2sp(np.gradient(subj, axis=0) * Hz)[1], v2sp(np.gradient(pred, axis=0) * Hz)[1]
    if h_subj[0] < 0:
        h_subj += pi
        h_pred += pi
    t = np.linspace(0, len(pred)-1, len(pred)) / Hz
    # Heading
    axes[(abs(angle), speed)].plot(t, h_subj, 'r')
    axes[(abs(angle), speed)].plot(t, h_pred, 'b')
    if abs(angle) == 112.5 and speed == 1.2:
        print(i, i_trial)


18 184
69 253
77 263
83 273
111 311
118 318


In [8]:
# Unpack the matched initial state of simulated trial 3 and evaluate
# beta * d_psi for it; the product is positive when the two share a sign.
(xg, yg, xo, yo, vxo, vyo,
 x, y, vx, vy, phi, s, dphi, ds) = sim.var0_match[3]
p0 = [x, y]      # walker position
p1 = [xo, yo]    # obstacle position
v0 = [vx, vy]    # walker velocity
v1 = [vxo, vyo]  # obstacle velocity
beta(p0, p1, v0) * d_psi(p0, p1, v0, v1)

-8.916084053041905e-06

In [29]:
# Pass-order accuracy for whichever simulator `sim` currently refers to
# (sim2 after the SMPD cell on a top-to-bottom run).
sim.test('order_accuracy')

0.65