# Script to generate Figure 5

In [None]:
import numpy as np

The relevant data are the policy $\pi(t)$ and the inferred event $e(t)$.

In [None]:
# Column indices into each row of a loaded run:
#   event_t      -> column holding the event label
#   policy_t     -> column holding the policy index
#   event_prob_t -> first of the four consecutive event-probability columns
event_t, policy_t, event_prob_t = 1, 2, 3

In [None]:
def get_relevant_data(data, n_steps=270):
    """Extract the policy trace, event boundaries and event probabilities of a run.

    Parameters
    ----------
    data : array_like, 2-D
        One row per time step. Column ``event_t`` holds the event label,
        column ``policy_t`` the policy index, and columns ``event_prob_t``
        to ``event_prob_t + 3`` the four event probabilities.
    n_steps : int, optional
        Number of leading time steps to extract. Defaults to 270.

    Returns
    -------
    policies : numpy.ndarray, shape (n_steps,)
        Policy index plus one, where a shifted value of 3 is mapped to 0.
    boundaries : numpy.ndarray, shape (2,)
        ``boundaries[0]`` is the last step ``t`` whose event label is 2 while
        the label at ``t + 1`` is 3; ``boundaries[1]`` is the last step whose
        label is 3 while the label at ``t + 1`` is 1. An entry stays 0 when
        the corresponding transition never occurs.
    p_event : numpy.ndarray, shape (4, n_steps)
        The four event probabilities, one row per event.

    Raises
    ------
    IndexError
        If ``data`` has fewer than ``n_steps`` rows or lacks the four
        probability columns.
    """
    data = np.asarray(data, dtype=float)
    n_rows = data.shape[0]
    if n_rows < n_steps or data.shape[1] < event_prob_t + 4:
        raise IndexError(
            "data of shape {} is too small for {} steps and {} columns".format(
                data.shape, n_steps, event_prob_t + 4
            )
        )

    # Shift the policy index by one and wrap the resulting 3 back to 0.
    policies = data[:n_steps, policy_t] + 1.0
    policies[policies == 3] = 0

    # A transition at step t needs row t + 1, so only steps that actually
    # have a successor row are inspected (the last row has none).
    events = data[:, event_t]
    n_pairs = max(0, min(n_steps, n_rows - 1))
    current = events[:n_pairs]
    following = events[1:n_pairs + 1]

    boundaries = np.zeros(2)
    for slot, (before, after) in enumerate(((2, 3), (3, 1))):
        hits = np.flatnonzero((current == before) & (following == after))
        if hits.size:
            # Keep the last occurrence, as repeated overwriting in a
            # forward loop over t would.
            boundaries[slot] = hits[-1]

    # Transpose so that each row is one event's probability over time.
    p_event = np.ascontiguousarray(
        data[:n_steps, event_prob_t:event_prob_t + 4].T
    )
    return policies, boundaries, p_event

# Hand plot

This file plots the event and policy inference over one testing run. First, select which run to plot by setting the local path to its result file:

In [None]:
# Hand agent: a run with interesting event inference / gaze behavior.
# The paper uses simulation 16, run 7.
# The leading directory is a placeholder for the local data path.
filename = "PATH:/to/your/data/res_tau_2_sim16_epoch29_hand_run7.txt"

Load data from this run

In [None]:
# Skip the header row and parse the comma-separated values.
# np.loadtxt only accepts single-character delimiters since NumPy 1.23, so
# ',' is used instead of ', '; the blank after each comma is expected to be
# ignored when the field is converted to float.
data = np.loadtxt(filename, dtype='float64', skiprows=1, delimiter=',')
policies, boundaries, p_event = get_relevant_data(data)

Plot the run

In [None]:
# Color palette as RGB tuples with components in [0, 1]; plots pick entries
# from this list by index.
colors = [
    (0.368, 0.507, 0.71),
    (0.881, 0.611, 0.142),
    (0.56, 0.692, 0.195),
    (0.923, 0.386, 0.209),
    (0.528, 0.471, 0.701),
    (0.772, 0.432, 0.102),
    (0.364, 0.619, 0.782),
    (0.572, 0.586, 0.),
]

In [None]:
import matplotlib.pyplot as plt
import tikzplotlib

# Two stacked panels without vertical gap: policy on top, event
# probabilities below.
fig, axs = plt.subplots(2)
policy_ax, event_ax = axs
fig.set_figwidth(15)
fig.subplots_adjust(hspace=0)

timesteps = range(270)

# Policy plot:
policy_ax.plot(timesteps, policies, color='k', linewidth=1)
policy_ax.set_yticks(np.array([0, 1, 2]))
# Raw strings keep the backslashes literal without relying on invalid
# escape sequences such as '\p'.
policy_ax.set_yticklabels([r'\pi_{none}', r'\pi_{agent}', r'\pi_{patient}'])
policy_ax.set_ylabel('Policy')
# NOTE(review): this panel ends at 270 while the lower one ends at 269 --
# confirm whether the difference is intended.
policy_ax.set_xlim([0, 270])
policy_ax.set_xticks([])

# Event probabilities as stacked bands. The running sums of the first three
# events give the band edges; the last band is filled up to 1.0.
stacked = np.cumsum(p_event[:3], axis=0)
band_lower = [0, stacked[0], stacked[1], stacked[2]]
band_upper = [stacked[0], stacked[1], stacked[2], 1.0]
band_colors = [colors[1], colors[5], colors[2], colors[4]]
for lower, upper, band_color in zip(band_lower, band_upper, band_colors):
    event_ax.fill_between(timesteps, lower, upper, color=band_color, linewidth=1)
event_ax.set_xlabel('t')
event_ax.set_yticks([0, 0.5, 1.0])
event_ax.set_ylim([0, 1.1])
event_ax.set_ylabel(r'P(e(t)| O(t), \Pi(t))')
event_ax.legend(['still', 'random', 'reach', 'transport'])

# Dotted vertical lines marking both event boundaries in both panels.
for boundary in boundaries:
    for ax in axs:
        ax.plot([boundary, boundary], [0.0, 2.0], 'k:')
event_ax.set_xticks([boundaries[0], boundaries[1]])
event_ax.set_xticklabels(['reach -> transport', 'transport -> random'])
event_ax.set_xlim([0, 269])

# plt.show()  # uncomment to preview the figure interactively
tikzplotlib.save("resultPlots/tikz_one_run_hand.tex")

# Claw plot

In [None]:
# Claw agent: a run with interesting event inference / gaze behavior.
# The paper uses simulation 6, run 9.
# The leading directory is a placeholder for the local data path.
filename = "PATH:/to/your/data/res_tau_2_sim6_epoch29_claw_run9.txt"


In [None]:
# Skip the header row and parse the comma-separated values.
# np.loadtxt only accepts single-character delimiters since NumPy 1.23, so
# ',' is used instead of ', '; the blank after each comma is expected to be
# ignored when the field is converted to float.
data = np.loadtxt(filename, dtype='float64', skiprows=1, delimiter=',')
policies, boundaries, p_event = get_relevant_data(data)

In [None]:
import matplotlib.pyplot as plt
import tikzplotlib

# Two stacked panels without vertical gap: policy on top, event
# probabilities below.
fig, axs = plt.subplots(2)
policy_ax, event_ax = axs
fig.set_figwidth(15)
fig.subplots_adjust(hspace=0)

timesteps = range(270)

# Policy plot:
policy_ax.plot(timesteps, policies, color='k', linewidth=1)
policy_ax.set_yticks(np.array([0, 1, 2]))
# Raw strings keep the backslashes literal without relying on invalid
# escape sequences such as '\p'.
policy_ax.set_yticklabels([r'\pi_{none}', r'\pi_{agent}', r'\pi_{patient}'])
policy_ax.set_ylabel('Policy')
# NOTE(review): this panel ends at 270 while the lower one ends at 269 --
# confirm whether the difference is intended.
policy_ax.set_xlim([0, 270])
policy_ax.set_xticks([])

# Event probabilities as stacked bands. The running sums of the first three
# events give the band edges; the last band is filled up to 1.0.
stacked = np.cumsum(p_event[:3], axis=0)
band_lower = [0, stacked[0], stacked[1], stacked[2]]
band_upper = [stacked[0], stacked[1], stacked[2], 1.0]
band_colors = [colors[1], colors[5], colors[2], colors[4]]
for lower, upper, band_color in zip(band_lower, band_upper, band_colors):
    event_ax.fill_between(timesteps, lower, upper, color=band_color, linewidth=1)
event_ax.set_xlabel('t')
event_ax.set_yticks([0, 0.5, 1.0])
event_ax.set_ylim([0, 1.1])
event_ax.set_ylabel(r'P(e(t)| O(t), \Pi(t))')
event_ax.legend(['still', 'random', 'reach', 'transport'])

# Dotted vertical lines marking both event boundaries in both panels.
for boundary in boundaries:
    for ax in axs:
        ax.plot([boundary, boundary], [0.0, 2.0], 'k:')
event_ax.set_xticks([boundaries[0], boundaries[1]])
event_ax.set_xticklabels(['reach -> transport', 'transport -> random'])
event_ax.set_xlim([0, 269])

# plt.show()  # uncomment to preview the figure interactively
tikzplotlib.save("resultPlots/tikz_one_run_claw.tex")