In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torchrl
from flipper_training.environment.env import Env
from flipper_training.configs.terrain_config import TerrainConfig
from flipper_training.configs.robot_config import RobotModelConfig
from flipper_training.configs.engine_config import PhysicsEngineConfig
from flipper_training.utils.torch_utils import set_device

In [3]:
num_robots = 2

In [4]:
from flipper_training.heightmaps.flat import FlatHeightmapGenerator

# Heightmap setup - use torch's XY indexing !!!!!
grid_res = 0.05  # 5cm per grid cell
max_coord = 3.2  # meters
heighmap_gen = FlatHeightmapGenerator()
x_grid, y_grid, z_grid, extras = heighmap_gen(grid_res, max_coord, num_robots, None)

In [6]:
from flipper_training.vis.static_vis import plot_heightmap_3d

plot_heightmap_3d(x_grid[0], y_grid[0], z_grid[0])

In [7]:
device = "cpu"
device = set_device(device)

In [8]:
# Instatiate the physics config
robot_model = RobotModelConfig(kind="marv", points_per_driving_part=256, points_per_body=512)
terrain_config = TerrainConfig(
    x_grid=x_grid,
    y_grid=y_grid,
    z_grid=z_grid,
    grid_res=grid_res,
    max_coord=max_coord,
    k_stiffness=30000,
    grid_extras=extras,
    k_friction_lat=0.5,
    k_friction_lon=1.0,
)
physics_config = PhysicsEngineConfig(num_robots=num_robots, damping_alpha=5.0)

[0m[33m2025-05-04 22:37:31.643 (  77.390s) [         1CBEA1D]            vtkMath.cxx:570   WARN| Unable to factor linear system[0m
[0m[33m2025-05-04 22:37:31.787 (  77.534s) [         1CBEA1D]      vtkDelaunay3D.cxx:514   WARN| vtkDelaunay3D (0x17ef7c2f0): 38 degenerate triangles encountered, mesh quality suspect[0m
[0m[33m2025-05-04 22:37:32.414 (  78.161s) [         1CBEA1D]            vtkMath.cxx:570   WARN| Unable to factor linear system[0m
[0m[33m2025-05-04 22:37:32.542 (  78.289s) [         1CBEA1D]      vtkDelaunay3D.cxx:514   WARN| vtkDelaunay3D (0x168fc75d0): 31 degenerate triangles encountered, mesh quality suspect[0m
[0m[33m2025-05-04 22:37:33.120 (  78.867s) [         1CBEA1D]            vtkMath.cxx:570   WARN| Unable to factor linear system[0m
[0m[33m2025-05-04 22:37:33.254 (  79.001s) [         1CBEA1D]      vtkDelaunay3D.cxx:514   WARN| vtkDelaunay3D (0x17e92afd0): 30 degenerate triangles encountered, mesh quality suspect[0m
[0m[33m2025-05-04 22:37:33

In [23]:
from flipper_training.observations.heightmap import Heightmap
from flipper_training.observations.robot_state import LocalStateVector

obs = {
    Heightmap.make_factory(percep_shape=(128, 128), percep_extent=(1.0, 1.0, -1.0, -1.0), interval=[-1, 1], encoder_opts={}),
    LocalStateVector.make_factory(encoder_opts={}),
}
obs

{<function flipper_training.observations.heightmap.Heightmap(percep_shape: tuple[int, int], percep_extent: tuple[float, float, float, float], interval: tuple[float, float], shift: float | None = None, normalize_to_interval: bool = False, *, env: 'Env', encoder_opts: dict, apply_noise: bool = False, noise_scale: float | torch.Tensor | None = None) -> None>,
 <function flipper_training.observations.robot_state.LocalStateVector(*, env: 'Env', encoder_opts: dict, apply_noise: bool = False, noise_scale: float | torch.Tensor | None = None) -> None>}

In [19]:
from flipper_training.rl_rewards.rewards import PotentialGoal

reward_factory = PotentialGoal.make_factory(
    goal_reached_reward=1.0,
    failed_reward=-1.0,
    potential_coef=20,
    step_penalty=-0.1,
    gamma=0.99,
)

In [29]:
from flipper_training.rl_objectives.fixed_goal import FixedStartGoalNavigation

objective_factory = FixedStartGoalNavigation.make_factory(
    **{
        "start_x_y_z": torch.tensor([-1.0, 0.0, 0.2]),
        "goal_x_y_z": torch.tensor([1.0, 0.0, 0.1]),
        "goal_reached_threshold": 0.05,
        "max_feasible_pitch": torch.deg2rad(torch.tensor(70.0)),
        "max_feasible_roll": torch.deg2rad(torch.tensor(70.0)),
        "iteration_limit": 500,
        "init_joint_angles": torch.tensor([0.0, 0.0, 0.0, 0.0]),
        "rng": None,
    },
)

In [30]:
torch_env = Env(
    objective_factory=objective_factory,
    reward_factory=reward_factory,
    observation_factories=obs,
    terrain_config=terrain_config,
    physics_config=physics_config,
    robot_model_config=robot_model,
    device=device,
    differentiable=False,
    batch_size=[num_robots],
    engine_compile_opts=None,
)

In [31]:
torchrl.envs.utils.check_env_specs(torch_env)

2025-05-04 22:45:54,844 [torchrl][INFO] check_env_specs succeeded!


In [32]:
torch_env.reset(reset_all=True)
torch_env.visualize()