# 3D-LOTUS policy

This notebook shows how to run the trained 3D-LOTUS policy in the RLBench simulator.

Before starting, make sure that you have followed the instructions in [our Github repository](https://github.com/vlc-robot/robot-3dlotus?tab=readme-ov-file) to set up the environment and download our pretrained models.

In [1]:
import os
import ctypes

# Root of the CoppeliaSim installation. Export COPPELIASIM_ROOT (as in the PyRep
# setup instructions) to point at your own install instead of editing this default.
COPPELIASIM_ROOT = os.environ.get(
    'COPPELIASIM_ROOT',
    '/home/uhcc/Desktop/robot-3dlotus/CoppeliaSim_Edu_V4_1_0_Ubuntu20_04',
)

# The dynamic loader reads LD_LIBRARY_PATH only at process start-up, so this update
# only affects subprocesses spawned from this kernel; the explicit CDLL below is what
# makes libcoppeliaSim available here. Prepend the root only once (so re-running the
# cell is harmless) and never leave an empty entry, which would mean "current directory".
_ld_library_path = os.environ.get('LD_LIBRARY_PATH', '')
if COPPELIASIM_ROOT not in _ld_library_path.split(':'):
    os.environ['LD_LIBRARY_PATH'] = (
        COPPELIASIM_ROOT + ':' + _ld_library_path if _ld_library_path else COPPELIASIM_ROOT
    )

# Force-load the shared library by absolute path.
ctypes.CDLL(os.path.join(COPPELIASIM_ROOT, 'libcoppeliaSim.so.1'))


<CDLL '/home/uhcc/Desktop/robot-3dlotus/CoppeliaSim_Edu_V4_1_0_Ubuntu20_04/libcoppeliaSim.so.1', handle 33f22ee0 at 0x7f706c77bf10>

In [2]:
import os
import numpy as np
from easydict import EasyDict
import matplotlib.pyplot as plt

from genrobo3d.rlbench.environments import RLBenchEnv, Mover
from rlbench.backend.utils import task_file_to_task_class
from pyrep.errors import IKError, ConfigurationPathError
from rlbench.backend.exceptions import InvalidActionError

from genrobo3d.train.utils.misc import set_random_seed
from genrobo3d.evaluation.eval_simple_policy import Actioner

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
# This notebook is expected to run from a direct subfolder of the robot-3dlotus root,
# and the relative paths used below (e.g. 'data/experiments/...') are resolved from
# that root. Remember the root the first time this cell runs so that re-running the
# cell does not keep climbing directories.
if 'REPO_ROOT' not in globals():
    REPO_ROOT = os.path.abspath('..')
os.chdir(REPO_ROOT)

## Build the model

In [4]:
seed = 100

# Arguments for the evaluation Actioner: the GemBench 3D-LOTUS v1 training config and
# its step-40000 checkpoint, run on 'cuda' in simulation (real_robot=False).
model_args = EasyDict({
    'exp_config': 'data/experiments/gembench/3dlotus/v1/logs/training_config.yaml',
    'checkpoint': 'data/experiments/gembench/3dlotus/v1/ckpts/model_step_40000.pt',
    'device': 'cuda',
    'real_robot': False,
    'save_obs_outs_dir': None,
    'best_disc_pos': 'max',
    'num_ensembles': 1,
    'remained_args': {},
})

In [5]:
# Seed first so that any randomness during model construction is covered by `seed`.
set_random_seed(seed)
# Build the policy wrapper from model_args (presumably loads model_args.checkpoint
# onto model_args.device — see genrobo3d.evaluation.eval_simple_policy.Actioner).
actioner = Actioner(model_args)

## Build the RLBench environment

In [6]:
# A task variation is written as "<task name>+<variation index>".
taskvar = 'push_button+0'
(task_str, variation_id) = taskvar.split(sep='+')
variation_id = int(variation_id, base=10)

In [7]:
image_size = [256] * 2   # camera image size passed to RLBenchEnv
mover_max_tries = 10     # forwarded to Mover as `max_tries`
max_steps = 25           # upper bound on policy steps in the rollout loop

In [8]:
# Simulated RLBench environment with RGB, point-cloud and mask outputs enabled
# (apply_* flags), no GUI (headless) and cam_rand_factor=0 (presumably no camera
# randomization). data_path is left empty since no stored demos are loaded here.
env = RLBenchEnv(
    data_path='',
    image_size=image_size,
    headless=True,
    cam_rand_factor=0,
    apply_rgb=True,
    apply_pc=True,
    apply_mask=True,
)

In [9]:
# Start the simulator, then load the requested task and pin its variation.
# NOTE(review): re-running this cell calls launch() again; presumably the simulator
# should be shut down first (see the last cell) — confirm before re-running.
env.env.launch()
task_type = task_file_to_task_class(task_str)  # presumably task name -> RLBench task class
task = env.env.get_task(task_type)
task.set_variation(variation_id)

# Mover executes the policy's predicted actions on `task` (it is called as
# move(action) in the rollout loop); `max_tries` presumably bounds attempts per
# action — TODO confirm in genrobo3d.rlbench.environments.Mover.
move = Mover(task, max_tries=mover_max_tries)


: 

## Run the policy

In [None]:
demo_id = 0  # episode index passed to the actioner and printed in error logs

instructions, obs = task.reset()
print('Instructions:', instructions)

# Show the four camera views side by side, concatenated along the width axis in the
# order: left shoulder | right shoulder | wrist | front.
fig, ax = plt.subplots(figsize=(16, 4))
ax.imshow(np.concatenate([obs.left_shoulder_rgb, obs.right_shoulder_rgb, obs.wrist_rgb, obs.front_rgb], 1))
ax.set_title('Initial observation: left shoulder | right shoulder | wrist | front')
ax.axis('off')  # pixel coordinates carry no meaning here
plt.show()

In [None]:
# Convert the raw RLBench observation into the dict passed to the actioner as
# `obs_state_dict`, and reset the mover with that dict's 'gripper' entry.
# NOTE: the rollout cell below continues from the current obs_state_dict, so re-run
# this cell (and the reset cell above) before re-running the rollout.
obs_state_dict = env.get_observation(obs)
move.reset(obs_state_dict['gripper'])

In [None]:
# Closed-loop rollout: at each step the actioner predicts one action from the latest
# observation, the mover executes it, and the resulting observation is fed back.
# The loop starts from the current `obs_state_dict`, so re-run the two cells above
# before re-running this one.
reward = 0  # defined up front so the final print works even if no action is executed
for step_id in range(max_steps):
    batch = {
        'task_str': task_str,
        'variation': variation_id,
        'step_id': step_id,
        'obs_state_dict': obs_state_dict,
        'episode_id': demo_id,
        'instructions': instructions,
    }
    output = actioner.predict(**batch)
    action = output["action"]

    if action is None:
        # The actioner produced no action; end the episode.
        break

    # Only the action execution is guarded: these are the motion-planning / invalid
    # action errors raised while moving the robot.
    try:
        obs, reward, terminate, _step_info = move(action, verbose=False)
    except (IKError, ConfigurationPathError, InvalidActionError) as e:
        print(taskvar, demo_id, step_id, e)
        reward = 0  # a failed action counts as a failed episode
        break

    print('Step id:', step_id + 1)
    fig, ax = plt.subplots(figsize=(16, 4))
    ax.imshow(np.concatenate([obs.left_shoulder_rgb, obs.right_shoulder_rgb, obs.wrist_rgb, obs.front_rgb], 1))
    ax.set_title(f'Step {step_id + 1}: left shoulder | right shoulder | wrist | front')
    ax.axis('off')
    plt.show()

    obs_state_dict = env.get_observation(obs)

    if reward == 1:
        break  # task solved
    if terminate:
        # NOTE(review): the rollout keeps stepping after termination (matching the
        # original behaviour); consider `break` here if that is not intended.
        print("The episode has terminated!")

print('Reward:', reward)

In [None]:
# Shut down the simulator started by env.env.launch(); re-run the environment cells
# to start a new session.
env.env.shutdown()