In [None]:
import gym
import interaction_gym
import numpy as np
import event_inference as event
import random
import torch
from torch.utils.data import Dataset

## Load the data

In [None]:
def index_to_component_name(i):
    """
    Translates a component index into the name of the event component.

    Index 0 is 'start' and index 1 is 'dynamics'; every other value is
    mapped to 'end'.
    """
    return 'start' if i == 0 else ('dynamics' if i == 1 else 'end')

def get_event_name(e_i):
    """
    Translates an event index into the name of the event.

    Indices 0, 1 and 2 are 'still', 'rand' and 'reach'; every other value
    is mapped to 'transport'.
    """
    for index, name in enumerate(('still', 'rand', 'reach')):
        if e_i == index:
            return name
    return 'transport'


In [None]:
class EventComponentDataset(Dataset):
    """
    Input/target pairs of a single component of a single event.

    Both arrays are read from .npy files in the 'Data' directory; the sample
    at position idx is the pair (input_data[idx], target_data[idx]).
    """

    def __init__(self, e_i, component_name):
        """
        Loads the input and target data of event e_i for one component.

        :param e_i: index of the event, turned into its name by get_event_name
        :param component_name: name of the component as used in the file
            names (one of 'start', 'dynamics', 'end')
        """
        # Input and target files only differ in their prefix
        file_suffix = get_event_name(e_i) + '_' + component_name + '.npy'
        self.input_data = np.load('Data/input_' + file_suffix)
        self.target_data = np.load('Data/target_' + file_suffix)

    def __len__(self):
        """Number of samples, i.e. the size of the first axis of the inputs."""
        num_samples = self.input_data.shape[0]
        return num_samples

    def __getitem__(self, idx):
        """Returns the (input, target) pair stored at position idx."""
        sample_input = self.input_data[idx]
        sample_target = self.target_data[idx]
        return sample_input, sample_target

## Create the model

In [None]:
# Random seed that is handed to the model and to the train/ignore split
seed = 42
# Number of samples per batch produced by the data loaders
batch_size = 128
# Number of samples per event component that are used for training
num_data_train = 1280

In [None]:
# Build the event inference system; the shared seed is passed on as r_seed.
model = event.EventInferenceSystem(
    epsilon_start=0.01,
    epsilon_dynamics=0.001,
    epsilon_end=0.001,
    no_transition_prior=0.9,
    dim_observation=18,
    num_policies=3,
    num_models=4,
    r_seed=seed,
    sampling_rate=2,
)

Load the data for every event $e_i$ and every subcomponent

In [None]:
# dataloaders[e_i][c] is the training data loader of event e_i and component c
# (components in the order start, dynamics, end).
dataloaders = [[] for _ in range(4)]
for e_i in range(4):
    for c in range(3):
        dataset = EventComponentDataset(e_i, index_to_component_name(c))

        num_data = len(dataset)
        if num_data < num_data_train:
            # Fail early with a clear message: otherwise num_data_ignore below
            # would be negative and the split could not have the requested size.
            raise ValueError(
                "Dataset for event '{}', component '{}' holds {} samples, "
                "but {} are required for training".format(
                    get_event_name(e_i), index_to_component_name(c),
                    num_data, num_data_train))
        num_data_ignore = num_data - num_data_train

        # A freshly seeded generator per split keeps the selected training
        # subset the same across runs; the remaining samples are discarded.
        train_dataset, _ = torch.utils.data.random_split(dataset, [num_data_train, num_data_ignore],
                                                         generator=torch.Generator().manual_seed(seed))

        dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        dataloaders[e_i].append(dataloader)

## Training the system using batches

We train our model for 501 epochs (numbered 0 to 500). In every epoch, every component of every event is updated based on 1280 datapoints that are randomly assigned to batches of size 128. Every 100 epochs we print the negative log likelihood averaged over the batches of that epoch.

In [None]:
for epoch in range(501):
    # Iterate over epochs; results are only reported every 100th epoch
    log_epoch = epoch % 100 == 0
    if log_epoch:
        print("--------------- EPOCH ", epoch, "---------------")
    for e_i in range(4):
        # Iterate over events

        for c in range(3):
            # Iterate over components (start, dynamics, end)
            component_name = index_to_component_name(c)
            dataloader_ei_c = dataloaders[e_i][c]

            nll_sum = 0.0

            for inps, targets in dataloader_ei_c:
                # One update per batch; accumulate the value returned by the update
                nll = model.update_batch(inps, targets, e_i, component_name)
                nll_sum += nll
            if log_epoch:
                # Average over the actual number of batches instead of a fixed 10,
                # so the report stays correct when num_data_train or batch_size change.
                # max(..., 1) avoids a division by zero for an empty loader.
                num_batches = max(len(dataloader_ei_c), 1)
                print("Event ", get_event_name(e_i), " and ", component_name, "-component: NLL of ", nll_sum/num_batches)

## Testing the system:

Create the environment

In [None]:
# Interaction environment on which the trained system is tested
env = interaction_gym.InteractionEventGym(
    sensory_noise_base=1.0,
    sensory_noise_focus=0.01,
    randomize_colors=True,
)

The code below runs the testing phase 3 times with event and policy inference of the system.
Here the agent is a hand. At every time step we print the inferred event probabilities together with the real event.

In [None]:
try:
    for episode in range(3):
        print("--------------- EPISODE ", episode, "---------------")
        # Reset environment to new event sequence
        observation = env.reset_to_grasping(claw=False) # claw=False for hand-agent

        # One-hot encoding of the initial policy pi(0): always the policy with index 2
        policy_t = np.array([0.0, 0.0, 1.0])
        for t in range(3000):
            # Rendering if desired:
            env.render() #store_video=True, video_identifier=0)

            # Perform pi(t) and receive new observation o(t)
            observation, reward, done, info = env.step(policy_t)

            # Update the event probabilities, event schemata, and infer next policy
            policy_t, P_ei = model.step(o_t=observation, pi_t=policy_t, training=False, done=done, e_i=info)
            print("P_ei[t=", t, "] = ", P_ei, " with real event ", info)

            # Next sequence when event sequence is over
            if done:
                break
finally:
    # Close the environment even when an episode fails or the run is interrupted
    env.close()