In [14]:
import sys
sys.path.append('/home/dyan/stanford/su_cs234/')

import os
import pickle
import torch
import numpy as np
import matplotlib.pyplot as plt
from final_project.code.src.wrapper import CustomRewardEnv
from final_project.code.src.actions import meaningful_actions

import warnings

warnings.filterwarnings('ignore')

def action_to_string(action_arr):
    return "".join([str(x) for x in action_arr])

ACTIONS = {k: v for v, k in enumerate(action_to_string(a) for a in meaningful_actions)}

In [15]:
class HumanDemonstrationEnv:
    def __init__(self, traj_file):
        self.trajectory = self.get_traj(traj_file)
        self.max_step = len(self.trajectory) - 1
    
    def get_traj(self, traj_file):
        with open(os.path.join(traj_file), "rb") as f:
            _t = pickle.load(f)
        return _t
    
    def step(self, action_step):
        if action_step <= self.max_step:
            current_traj = self.trajectory[action_step]
            obs = current_traj["observation"]
            reward = current_traj["reward"]
            info = current_traj["info"]
            if action_step == self.max_step:
                return obs, reward, True, True, info
            else:
                return obs, reward, False, False, info
        else:
            raise ValueError(f"{action_step} step not found.")

    def fetch_action(self, action_step):
        # only keep allowed actions, otherwise marked as -1
        a = self.trajectory[action_step]["action"]
        a_str = action_to_string(a)
        if a_str not in ACTIONS:
            return -1
        else:
            return ACTIONS[a_str]

In [16]:
def get_traj_info(traj_file, skip_frame=4):

    env = CustomRewardEnv(
        HumanDemonstrationEnv(traj_file)
    )

    rewards = []
    states = []
    actions = []
    infos = []
    done = False 

    i = 0
    while not done or i <= 2500:  # truncate automatic frames
        obs, reward, terminated, truncated, info = env.step(i)
        done = terminated or truncated
        action = env.fetch_action(i)
        i += 1
        
        # skip is action is not allowed
        if action == -1:
            continue
        else:
            rewards.append(reward if not done else 0)
            states.append(obs)
            actions.append(action)
            infos.append(info)

    print(np.sum(rewards))
    # projected_rewards = np.cumsum(rewards[::-1], dtype=np.float32)[::-1]

    # desired_dict = {}
    # # for (state, proj_reward) pair
    # traj_output_s_pr = []
    # for s, proj_r, info in zip(
    #     states[::skip_frame], projected_rewards[::skip_frame], infos[::skip_frame]
    # ):
    #     traj_output_s_pr.append((s, proj_r, info))
    # desired_dict["s_pr_pair"] = traj_output_s_pr
    
    # for (state, action) pair
    traj_output_s_a = []
    for s, a in zip(states[::skip_frame], actions[::skip_frame]):
        traj_output_s_a.append((s, a))
    # desired_dict["s_a_pair"] = traj_output_s_a
    
    return traj_output_s_a

In [17]:
trajs = os.listdir("human_demon")
holder = []
for idx, file_name in enumerate(trajs):
    print(idx)
    tf = os.path.join("human_demon", file_name)
    traj_output = get_traj_info(tf)
    holder.extend(traj_output)

with open(os.path.join("human_demon_processed", f"traj_all.pkl"), "wb") as f:
    pickle.dump(holder, f)

0
5201.500000000001
1
4513.700000000001
2
5679.400000000001
3
5582.500000000001
4
5960.6
