In [1]:
import json
import pandas as pd
from isaac.constants import BASIC_TRAINING_COLS, MASS_CLASS_COLS, FORCE_CLASS_COLS
import numpy as np

In [2]:
from isaac.dataset import read_dataset

In [None]:
from rpy2.robjects import r, pandas2ri
pandas2ri.activate()

In [3]:
import numpy as np
from tqdm import tqdm

In [None]:
hdf_path = "data/r_passive_trials.h5"
trial_id = 0

for world_i in tqdm(range(1, 2188)):
    rdata_path = "data/for_hector/passive_simulations/w_%d.rdata" % world_i
    r['load'](rdata_path)
    key = r["key"].iloc[world_i -1]
    trials = r["sim_trials"]
    
    
    for world_trial in trials:
        world_trial = pandas2ri.ri2py_dataframe(world_trial)
        
        world_trial["A"] = np.full(world_trial.shape[0], key.target_heavier == "A")
        world_trial["B"] = np.full(world_trial.shape[0], key.target_heavier == "B")
        world_trial["same"] = np.full(world_trial.shape[0], key.target_heavier == "same")
        
        world_trial["attract"] = np.full(world_trial.shape[0], key.target_fAB == 3)
        world_trial["none"] = np.full(world_trial.shape[0], key.target_fAB == 0)
        world_trial["repel"] = np.full(world_trial.shape[0], key.target_fAB == -3)

        world_trial["target_fAB"] = np.full(world_trial.shape[0], key.target_fAB == 3)
        world_trial["fAC"] = np.full(world_trial.shape[0], key.fAC == 3)
        world_trial["fAD"] = np.full(world_trial.shape[0], key.fAD == 0)
        world_trial["fBC"] = np.full(world_trial.shape[0], key.fBC == -3)
        world_trial["fBD"] = np.full(world_trial.shape[0], key.fBD == 0)
        world_trial["fCD"] = np.full(world_trial.shape[0], key.fCD == -3)

        world_trial["world_id"] = np.full(world_trial.shape[0], world_i)
        world_trial.to_hdf(hdf_path, key="trial_%d" % trial_id)
        
        trial_id += 1

## Dividing into train, validation and test trials

In [4]:
from sklearn.model_selection import train_test_split

In [5]:
all_trials = read_dataset("data/r_passive_trials.h5")

100%|██████████| 10935/10935 [01:03<00:00, 173.05it/s]


In [6]:
w_to_s = {}

for trial in all_trials:
    mass_sol = MASS_CLASS_COLS[np.argmax(trial[list(MASS_CLASS_COLS)].iloc[0].values)]
    rel_sol = FORCE_CLASS_COLS[np.argmax(trial[list(FORCE_CLASS_COLS)].iloc[0].values)]
                               
    w_to_s[trial.world_id.iloc[0]] = mass_sol + "_" + rel_sol

In [7]:
np.random.seed(37)
train_wids, test_wids, train_sols, test_sols = train_test_split(list(w_to_s.keys()), list(w_to_s.values()), stratify=list(w_to_s.values()), test_size=0.5)
val_wids, test_wids, val_sols, test_sols = train_test_split(test_wids, test_sols, stratify=test_sols, test_size=0.5)

In [8]:
print(pd.Series(train_sols).value_counts())
print(pd.Series(val_sols).value_counts())
print(pd.Series(test_sols).value_counts())

same_none       122
A_attract       122
B_none          122
A_none          122
same_attract    121
same_repel      121
B_attract       121
B_repel         121
A_repel         121
dtype: int64
same_attract    61
B_none          61
A_attract       61
same_repel      61
B_attract       61
B_repel         61
A_repel         61
same_none       60
A_none          60
dtype: int64
same_attract    61
same_none       61
same_repel      61
B_attract       61
A_none          61
B_repel         61
A_repel         61
B_none          60
A_attract       60
dtype: int64


In [9]:
train_trials = []
val_trials = []
test_trials = []

for trial in all_trials:
    world_id = trial.world_id.unique()[0]
    
    if world_id in train_wids:
        train_trials.append(trial)
    elif world_id in val_wids:
        val_trials.append(trial)
    else:
        test_trials.append(trial)

In [10]:
def save_dataset(hdf_path, trials):
    for trial_i, trial in tqdm(enumerate(trials), total=len(trials)):
        trial.to_hdf(hdf_path, key="trial_%d" % trial_i)
        
train_hdf = "data/r_train_trials.h5"
val_hdf = "data/r_val_trials.h5"
test_hdf = "data/r_test_trials.h5"

save_dataset(train_hdf, train_trials)
save_dataset(val_hdf, val_trials)
save_dataset(test_hdf, test_trials)

100%|██████████| 5465/5465 [08:24<00:00,  5.66it/s]
100%|██████████| 2735/2735 [02:12<00:00, 20.70it/s]
100%|██████████| 2735/2735 [02:13<00:00, 20.54it/s]
