In [None]:
import os
import sys
import numpy as np
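# Add VLA_DIR to PYTHONPATH.
# NOTE: '__file__' is a quoted string literal here, so os.path.dirname() returns ''
# and the '../' below is resolved against the current working directory.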
sys.path.append(os.path.abspath(os.path.join(os.path.dirname('__file__'), '../')))
# Add LIBERO to PYTHONPATH
sys.path.append(os.path.abspath(os.path.join(os.path.dirname('__file__'), '../external/LIBERO')))
from libero.libero import benchmark, get_libero_path
from utils.LIBERO_utils import get_task_names, extract_task_info

## User-specific configuration
# TODO: read these values through argparse once this becomes a Python script.
# Name of the dataset to load. Listed options:
#   "libero_object", "libero_spatial", "libero_goal", "libero_10", "libero_90"
DATASET_NAME = "libero_spatial"
# FILTER_KEY and VERBOSE do not need to be changed for now.
# Filter key handed to extract_task_info, e.g. "valid" for validation; None when unused.
FILTER_KEY = None
# Verbosity flag handed to extract_task_info.
VERBOSE = True

## Resolve the LIBERO paths and report them
BENCHMARK_PATH = get_libero_path("benchmark_root")
DATASET_BASE_PATH = get_libero_path("datasets")
# Directory holding the demonstrations of the selected dataset.
DATASET_PATH_DEMO = os.path.join(DATASET_BASE_PATH, DATASET_NAME)
# One print call, one line per entry. The double space after the first two
# labels is kept on purpose so the output matches the two-argument print form.
print(
    "=====================================",
    f"LIBERO benchmark root path:  {BENCHMARK_PATH}",
    f"LIBERO dataset root path:  {DATASET_BASE_PATH}",
    f"LIBERO demonstration dataset for {DATASET_NAME} path: {DATASET_PATH_DEMO}",
    "=====================================",
    sep="\n",
)

## Load demonstration dataset
# Collect the names of all tasks found under the demonstration dataset path.
task_names_demo = get_task_names(DATASET_PATH_DEMO)
# print(f"Tasks in the demonstration dataset: {task_names_demo}")
# Maps task name -> [language_instruction, actions_batch, images_batch].
dataset_demo = {}
print("Start loading demonstration data for each task...")
print("-------------------------------------")
for task_name_demo in task_names_demo:
    print(f"Loading demonstration data for task:\n {task_name_demo}")
    language_instruction, actions_batch, images_batch = extract_task_info(
        DATASET_PATH_DEMO, task_name_demo, filter_key=FILTER_KEY, verbose=VERBOSE
    )
    # Validate before storing so a bad task never ends up in dataset_demo.
    # An explicit raise is used instead of `assert` so the check still runs
    # under `python -O`; len() works for arrays and plain lists alike.
    num_demos = len(actions_batch)
    if num_demos != len(images_batch):
        raise ValueError("Dataset problem: the number of actions and images should be the same!")
    dataset_demo[task_name_demo] = [language_instruction, actions_batch, images_batch]
    # print dataset information
    print("Loaded successfully!")
    print(f"Total demonstrations: {num_demos}")
    if num_demos:
        # Average number of steps per demonstration.
        ave_len = np.mean([len(x) for x in actions_batch])
        print(f"Average demonstration length: {ave_len}")
        # Shapes are read from the first step of the first demonstration.
        action_shape = actions_batch[0][0].shape
        print(f"Action shape: {action_shape}")
        img_shape = images_batch[0][0].shape
        print(f"Image shape: {img_shape}")
    else:
        # Nothing to index into: report it instead of failing on actions_batch[0].
        print("No demonstrations found for this task; skipping statistics.")
    print("-------------------------------------")