In [None]:
import os
import random

import gym
import numpy as np

import interaction_gym
import event_inference as event

## Create model and environment

In [None]:
seed = 42
# `seed` alone only reaches the model (via r_seed). The data-collection loop
# samples policies with the global `random` module, so seed it too; otherwise
# every run collects a different dataset. NumPy's global RNG is seeded as well
# in case the environment draws from it (not visible here — TODO confirm).
random.seed(seed)
np.random.seed(seed)

In [None]:
# Build the CAPRI event-inference model. The meaning of each hyperparameter is
# defined in event_inference (not visible here).
# num_policies=3 matches the 3-entry one-hot policy vectors built in the
# collection loop; num_models=4 presumably matches the four per-event data
# buckets (TODO confirm).
model = event.CAPRI(
    epsilon_start=0.01,
    epsilon_dynamics=0.001,
    epsilon_end=0.001,
    no_transition_prior=0.9,
    dim_observation=18,
    num_policies=3,
    num_models=4,
    r_seed=seed,
    sampling_rate=2,
)

In [None]:
# Interaction environment that produces the event sequences used for training.
env = interaction_gym.InteractionEventGym(
    sensory_noise_base=1.0,
    sensory_noise_focus=0.01,
    randomize_colors=True,
)

## Collect data

In [None]:
# Data buckets indexed as [event][component]: 4 events, each with 3 component
# lists (start / dynamics / end -> 0 / 1 / 2). Every inner list is a fresh
# object, so extending one bucket never affects another.
event_input_data_list = []
event_target_data_list = []
for i in range(4):
    event_input_data_list.append([[], [], []])
    event_target_data_list.append([[], [], []])
event_input_data_list

In [None]:
def component_name_to_index(name):
    """Map an event-component name to its slot in the per-event data lists.

    'start' -> 0, 'dynamics' -> 1; any other value (in practice 'end') -> 2.
    """
    known_components = (('start', 0), ('dynamics', 1))
    return next((slot for label, slot in known_components if name == label), 2)

In [None]:
# Collect offline training data: run NUM_EPISODES episodes and sort every
# (inputs, targets) pair returned by the model into its [event][component] bucket.
# NOTE: the buckets are extended, not reset. Re-run the list-initialization
# cell before re-running this one, or the collected data is duplicated.
NUM_EPISODES = 10000
MAX_STEPS_PER_EPISODE = 3000
NUM_POLICIES = 3  # must agree with num_policies passed to event.CAPRI
LOG_EVERY = 500  # one progress line every N episodes instead of one per episode

try:
    for episodes in range(NUM_EPISODES):

        # Reset environment to new event sequence
        observation = env.reset()

        # Sample one-hot-encoding of policy pi(0); it stays fixed for the whole episode
        policy_t = np.zeros(NUM_POLICIES)
        policy_t[random.randint(0, NUM_POLICIES - 1)] = 1
        repeat = False
        done = False
        t = 0
        for _ in range(MAX_STEPS_PER_EPISODE):

            # Perform pi(t) and receive new observation o(t). On a repeat pass the
            # environment is not stepped; the previous observation/done/info are reused.
            if not repeat:
                observation, _reward, done, info = env.step(policy_t)

            component, e_t, inputs, targets = model.get_offline_data(
                o_t=observation, pi_t=policy_t, done=done, e_i=info)

            # If we reach the end of an event sequence we run get_offline_data twice:
            # once for the end data and once for the start data. t only advances
            # on the non-'end' pass.
            if component == 'end':
                repeat = True
            else:
                repeat = False
                t += 1

            component_index = component_name_to_index(component)
            event_input_data_list[e_t][component_index] += inputs
            event_target_data_list[e_t][component_index] += targets

            # Next sequence when event sequence is over
            if done:
                if episodes % LOG_EVERY == 0 or episodes == NUM_EPISODES - 1:
                    print("Episode ", episodes, " done after ", t, " time steps")
                break
        else:
            # Inner loop exhausted without `done`: this episode was truncated.
            print("Episode ", episodes, " hit the step limit of ",
                  MAX_STEPS_PER_EPISODE, " steps")
finally:
    # Release the environment even if collection is interrupted or raises.
    env.close()

## Store the data

In [None]:
def index_to_component_name(i):
    """Return the component name stored at slot `i` of a per-event data list.

    0 -> 'start', 1 -> 'dynamics'; any other value -> 'end'.
    """
    for slot, label in ((0, 'start'), (1, 'dynamics')):
        if i == slot:
            return label
    return 'end'

In [None]:
def get_event_name(e_i):
    """Return the label used in saved file names for event index `e_i`.

    0 -> 'still', 1 -> 'rand', 2 -> 'reach'; any other value -> 'transport'.
    """
    for code, label in enumerate(('still', 'rand', 'reach')):
        if e_i == code:
            return label
    return 'transport'

In [None]:
# Write one .npy file per (event, component) bucket for inputs and targets,
# e.g. Data/input_reach_dynamics.npy (np.save appends the .npy suffix).
DATA_DIR = "Data"
# np.save does not create missing directories.
os.makedirs(DATA_DIR, exist_ok=True)

for e in range(4):
    for c in range(3):
        e_name = get_event_name(e)
        comp_name = index_to_component_name(c)
        bucket_inputs = event_input_data_list[e][c]
        bucket_targets = event_target_data_list[e][c]
        # np.stack raises ValueError on an empty sequence, which would abort the
        # whole save loop if some (event, component) pair was never observed.
        if not bucket_inputs or not bucket_targets:
            print("Skipping", e_name, comp_name, "- no samples collected")
            continue
        suffix = e_name + "_" + comp_name
        np.save(os.path.join(DATA_DIR, "input_" + suffix), np.stack(bucket_inputs))
        np.save(os.path.join(DATA_DIR, "target_" + suffix), np.stack(bucket_targets))