In [69]:
# Reload edited Python modules (e.g. the peripheral helpers imported below) before executing each cell.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [70]:
from pathlib import Path
import gymnasium as gym
import sys, os
import cv2
import numpy as np
from lerobot.common.envs.factory import make_env
from hydra import compose, initialize
from omegaconf import OmegaConf
import time
from collections import defaultdict

# Directory containing the `second_wind` package (joystick -> action helpers).
# Override with the PERIPHERAL_CONTROLLER_PATH environment variable instead of editing this cell.
PERIPHERAL_DIR = os.path.expanduser(
    os.environ.get("PERIPHERAL_CONTROLLER_PATH", "~/workspace/fastrl/nov20")
)
if PERIPHERAL_DIR not in sys.path:  # re-running the cell must not keep growing sys.path
    sys.path.append(PERIPHERAL_DIR)
from second_wind.peripheral import get_memorymaze_action_from_joystick, get_pusht_action_from_joystick, get_calvin_action_from_joystick, get_pinpad_action_from_joystick, get_ev3_action_from_joystick


In [71]:
# Environment to tele-operate ("xarm" is another available choice).
env_name = "robosuite"

# Recorded demonstrations for this environment live under local/demonstration/<env_name>.
output_directory = Path("local", "demonstration", env_name)
output_directory.mkdir(parents=True, exist_ok=True)

# Compose the lerobot Hydra config, selecting the env config group by name,
# and print the resolved tree for inspection.
hydra_overrides = [f"env={env_name}"]
with initialize(version_base=None, config_path="../lerobot/configs", job_name="test_app"):
    cfg = compose(config_name="default", overrides=hydra_overrides)
    print(OmegaConf.to_yaml(cfg))

# n_envs=1: observations/actions carry a leading batch dimension of size 1.
env = make_env(cfg, n_envs=1)

resume: false
device: cuda
use_amp: false
seed: 100000
dataset_repo_id: lerobot/pusht
video_backend: pyav
training:
  offline_steps: 200000
  online_steps: 0
  online_steps_between_rollouts: 1
  online_sampling_ratio: 0.5
  online_env_seed: ???
  eval_freq: 5000
  log_freq: 250
  save_checkpoint: true
  save_freq: 5000
  num_workers: 4
  batch_size: 64
  image_transforms:
    enable: false
    max_num_transforms: 3
    random_order: false
    brightness:
      weight: 1
      min_max:
      - 0.8
      - 1.2
    contrast:
      weight: 1
      min_max:
      - 0.8
      - 1.2
    saturation:
      weight: 1
      min_max:
      - 0.5
      - 1.5
    hue:
      weight: 1
      min_max:
      - -0.05
      - 0.05
    sharpness:
      weight: 1
      min_max:
      - 0.8
      - 1.2
  grad_clip_norm: 10
  lr: 0.0001
  lr_scheduler: cosine
  lr_warmup_steps: 500
  adam_betas:
  - 0.95
  - 0.999
  adam_eps: 1.0e-08
  adam_weight_decay: 1.0e-06
  delta_timestamps:
    observation.image: '[i 

In [64]:
# List every observation entry: array-like values report their shape, anything
# else (e.g. plain sequences) its length.
obs, info = env.reset()
for k, v in obs.items():
    print(k, v.shape if hasattr(v, "shape") else len(v))

gripper_image (1, 128, 128, 3)
image (1, 128, 128, 3)
is_first (1,)
is_last (1,)
is_terminal (1,)
reward (1,)
vector_state (1, 21)


In [67]:
import pygame

# Allow joystick events while no pygame window has focus (this notebook only opens
# cv2 windows). SDL reads this hint when the joystick subsystem is initialised, so
# it must be in the environment *before* pygame.joystick.init().
# NOTE(review): if pygame/SDL was already initialised by another module, set this
# even earlier (before that module is imported).
os.environ['SDL_JOYSTICK_ALLOW_BACKGROUND_EVENTS'] = "1"
pygame.joystick.init()
try:
    joysticks = [pygame.joystick.Joystick(x) for x in range(pygame.joystick.get_count())]
    print(f"Found {len(joysticks)} joysticks.")
except pygame.error:
    # Only treat pygame device errors as "no joystick"; a bare `except:` would also
    # hide KeyboardInterrupt and genuine programming errors.
    joysticks = []
    print("No joysticks found.")

Found 1 joysticks.


In [45]:
# Tele-operate the env with the joystick for `n_eps` episodes, showing both camera
# views live and recording each step into `saved_episodes[ep]`
# ("obs" = observation returned by env.step, "action" = action sent, "reward").
cv2.destroyAllWindows()
obs, info = env.reset()
n_eps = 1
step_time_s = 1 / cfg.env.fps  # wall-clock budget per step, to pace the loop at the env's fps
stp = 0  # global step counter across episodes (only used for the progress print)
saved_episodes = [defaultdict(list) for _ in range(n_eps)]
try:
    for ep in range(n_eps):
        done, truncated = False, False
        # done/truncated come back as length-1 arrays (see printed output); `not` on
        # them is well-defined only because n_envs=1.
        while not done and not truncated:
            t0 = time.time()
            pygame.event.pump()  # refresh SDL's joystick state before polling it
            # Wrapped in a list: presumably the single-env batch expected by env.step — TODO confirm.
            action = [get_ev3_action_from_joystick(joysticks)]
            obs, reward, done, truncated, info = env.step(action)

            # Images are (1, H, W, C); drop the batch dim first so np.hstack joins
            # the two views along the width axis (side by side). Stacking the 4-D
            # arrays directly would concatenate along the height axis instead.
            # NOTE(review): cv2 expects BGR; if the env renders RGB, colours appear swapped.
            img = np.hstack([obs["image"][0], obs["gripper_image"][0]])
            cv2.imshow(f"img {img.shape}", img); cv2.waitKey(1)  # waitKey lets HighGUI repaint

            leftover_time = step_time_s - (time.time() - t0)
            if leftover_time > 0:
                time.sleep(leftover_time)

            # Keep the action (and reward) with each observation so the recording
            # is usable as a demonstration.
            saved_episodes[ep]["obs"].append(obs)
            saved_episodes[ep]["action"].append(action)
            saved_episodes[ep]["reward"].append(reward)

            if reward > 0 or done or truncated:
                print(stp, reward, done, truncated)
                if done or truncated:
                    obs, info = env.reset()
            stp += 1
finally:
    # Close the live view even when the loop is interrupted (e.g. KeyboardInterrupt).
    cv2.destroyAllWindows()

499 [0.] [False] [ True]


In [72]:
# Show the main camera frame for a few consecutive resets. The first reset is
# seeded, so the sequence should be repeatable; presumably this is for eyeballing
# how much the initial scene varies between episodes.
cv2.destroyAllWindows()
import random  # NOTE(review): unused in the visible cells; kept in case a later cell relies on it
obs, info = env.reset(seed=0)
n_eps = 3
# This cell only previews resets: it deliberately does NOT reassign
# `saved_episodes`, so demonstrations recorded by the tele-op cell are preserved.
for ep in range(n_eps):
    img = obs["image"].squeeze()
    # One window per reset; waitKey(500) keeps each frame up for ~0.5 s.
    cv2.imshow(f"{ep} {img.shape}", img); cv2.waitKey(500)
    obs, info = env.reset()
cv2.destroyAllWindows()