In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to imported modules
# are picked up before each cell executes, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torchrl
import matplotlib.pyplot as plt
from tensordict import TensorDict
from flipper_training.environment.env import Env
from flipper_training.configs import *
from flipper_training.rl_objectives import *
from flipper_training.utils.heightmap_generators import *

In [3]:
# Number of robots; reused below for grid generation, the physics config,
# the Env batch size and the shape of the control tensor.
num_robots = 1

In [4]:
from flipper_training.utils.environment import make_x_y_grids, generate_heightmaps

In [5]:
# Terrain / heightmap setup.
# Reminder: use torch's XY indexing here!
grid_res = 0.05  # grid cell size: 5 cm
max_coord = 6.4  # meters

# Generator settings, grouped as (min, max) pairs per quantity.
# NOTE(review): the variable name is spelled `heighmap_gen` (sic); it is kept
# unchanged because other cells may refer to it.
heighmap_gen = MultiGaussianHeightmapGenerator(
    min_gaussians=400, max_gaussians=600,
    min_height_fraction=0.03, max_height_fraction=0.12,
    min_std_fraction=0.03, max_std_fraction=0.08,
    min_sigma_ratio=0.6,
)

# Build the x/y grids first, then generate heights and the suitability mask on them.
x_grid, y_grid = make_x_y_grids(max_coord, grid_res, num_robots)
z_grid, suit_mask = generate_heightmaps(x_grid, y_grid, heighmap_gen)

In [6]:
from flipper_training.vis.static_vis import plot_heightmap_3d

# Plot only the entry at index 0 of each grid
# (presumably the first robot's terrain — TODO confirm the leading axis).
plot_heightmap_3d(x_grid[0], y_grid[0], z_grid[0])

In [7]:
# Device string handed to the Env constructor further below.
device = "cpu"

In [8]:
# Instantiate the robot model, world and physics engine configs.
robot_model = RobotModelConfig(robot_type="marv")

# The world config bundles the generated terrain grids with the grid geometry.
world_config = WorldConfig(
    x_grid=x_grid,
    y_grid=y_grid,
    z_grid=z_grid,
    grid_res=grid_res,
    max_coord=max_coord,
    k_stiffness=30000,
    suitable_mask=suit_mask,
)

physics_config = PhysicsEngineConfig(num_robots=num_robots)

Loading robot model from cache: /Users/davidkorcak/Documents/ctu/bachelors/flipper_training/.robot_cache/marv_0.080_192.pt
Robot has 1023 points


In [9]:
from flipper_training.environment.env import EnvConfig
from flipper_training.rl_objectives import *

# Environment options: per-track control, with the differentiable mode switched off.
env_config = EnvConfig(
    control_type="per-track",
    differentiable=False,
)

In [10]:
from flipper_training.observations import *
from functools import partial

# Observation factories keyed by name. Each value is a functools.partial around
# the observation class, so the remaining constructor arguments can be supplied
# later (presumably by Env — TODO confirm).
obs = dict(
    perception=partial(
        Heightmap,
        percep_shape=(128, 128),
        percep_extent=(1.0, 1.0, -1.0, -1.0),
    ),
    observation=partial(RobotStateVector),
)

# Display the perception factory to verify its bound arguments.
obs["perception"]

functools.partial(<class 'flipper_training.observations.heightmap.Heightmap'>, percep_shape=(128, 128), percep_extent=(1.0, 1.0, -1.0, -1.0))

In [11]:
from flipper_training.rl_rewards.rewards import *

# NOTE(review): the four arguments are passed positionally and their meaning is
# not visible from this notebook — consider keyword arguments once the
# RollPitchGoal signature is confirmed.
reward = RollPitchGoal(1000, -1000, 1.0, 1.0)

In [18]:
# Assemble the environment from the pieces built in the cells above.
# Arguments are positional, so their order must stay exactly as written.
torch_env = Env(
    SimpleStabilizationObjective(),  # objective instance
    reward, obs,  # reward object and observation factories
    env_config, world_config, physics_config, robot_model,  # configuration objects
    device,
    batch_size=[num_robots],
)

In [19]:
# Run TorchRL's spec check on the freshly built environment; it logs a
# success message when the check passes.
torchrl.envs.utils.check_env_specs(torch_env)

2025-02-19 19:02:09,017 [torchrl][INFO] check_env_specs succeeded!


In [20]:
# Controls: a constant command, laid out as
# [speed entries..., flipper commands...] and repeated once per robot.
speed = 0.5  # m/s forward
# Number of leading speed entries in the action vector
# (presumably one per track, given control_type="per-track" — TODO confirm).
num_tracks = 4

# Create the tensors on the same `device` that was passed to Env. Previously they
# were created on torch's default device, which would no longer match the
# environment if `device` were changed to something other than "cpu".
track_controls = torch.full((num_tracks,), speed, device=device)
flipper_controls = torch.zeros(robot_model.num_joints, device=device)  # all flipper commands zero

# Shape: (num_robots, num_tracks + robot_model.num_joints).
control_vec = torch.cat([track_controls, flipper_controls]).unsqueeze(0).repeat(num_robots, 1)
control_vec.shape

torch.Size([1, 8])

In [21]:
# Wrap the control tensor under the "action" key. The batch size is written as
# a one-element list, mirroring the batch_size passed to Env.
control_td = TensorDict(
    {"action": control_vec},
    batch_size=[num_robots],
)
control_td.shape

torch.Size([1])

In [22]:
# Reset the environment (with reset_all=True) and show the state right after the reset.
torch_env.reset(reset_all=True)
torch_env.visualize_curr_state()

In [23]:
# Step the environment 100 times with the same constant control input.
# `o` ends up holding the result of the final step.
for step_idx in range(100):
    o = torch_env.step(control_td)
    # NOTE(review): a `control_td.pop("next")` call was left disabled here —
    # presumably step() writes a "next" entry into the input TensorDict;
    # confirm whether it needs clearing between steps.

# Show the state reached after the rollout.
torch_env.visualize_curr_state()