In [15]:
# Load IPython's autoreload extension so that edits to imported project modules
# are picked up without restarting the kernel.
%load_ext autoreload
# Mode 3 ("complete" in the IPython docs): reload all modules before running a
# cell and also pick up objects newly added to them.
%autoreload 3

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
# Environment variables controlling TorchDynamo diagnostics; they are what
# produces the torch/_dynamo log lines in the output of the simulation cell.
# Keep comments on their own lines: a line magic receives the rest of its line
# as the argument.
%env TORCH_LOGS=dynamo 
# NOTE(review): presumably enables verbose Dynamo error reporting — confirm
# against the torch.compile troubleshooting docs.
%env TORCHDYNAMO_VERBOSE=1

env: TORCH_LOGS=dynamo
env: TORCHDYNAMO_VERBOSE=1


In [17]:
import torch
import matplotlib.pyplot as plt
from flipper_training.configs import (
    TerrainConfig,
    RobotModelConfig,
    PhysicsEngineConfig,
)
from flipper_training.engine.engine import DPhysicsEngine, PhysicsState
from flipper_training.utils.dynamics import *
from flipper_training.utils.geometry import *
from flipper_training.utils.environment import *
from flipper_training.utils.numerical import *
from collections import deque
from copy import deepcopy
from tqdm import tqdm, trange

In [18]:
from flipper_training.utils.torch_utils import set_device

In [19]:
# Seed torch's global random number generator so the random tensors created
# later in the notebook are reproducible. The call stays the last expression
# of the cell, so the returned generator is still displayed.
SEED = 420
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fe3ddd6dd30>

In [20]:
# Number of robot instances in the batch: used to tile the heightmap and the
# initial state tensors and passed to PhysicsEngineConfig.
num_robots = 128

In [21]:
# Heightmap setup - use torch's XY indexing !!!!!
grid_res = 0.05  # 5cm per grid cell
max_coord = 3.2  # meters
# Number of samples per axis. round() instead of int(): the quotient is a float,
# and truncation would silently drop a cell whenever it lands just below an
# integer (e.g. 0.3 / 0.1 == 2.9999999999999996). For the values above both
# give 128.
DIM = round(2 * max_coord / grid_res)
# NOTE(review): linspace with DIM points over [-max_coord, max_coord] has a
# spacing of 2 * max_coord / (DIM - 1), slightly larger than grid_res — confirm
# this is the convention the engine expects.
xint = torch.linspace(-max_coord, max_coord, DIM)
yint = torch.linspace(-max_coord, max_coord, DIM)
# With indexing="xy", x varies along the last dimension and y along the first.
x, y = torch.meshgrid(xint, yint, indexing="xy")
z = torch.zeros_like(x)
# Build a staircase: every threshold raises the region {x > -thresh, y < thresh}
# by 0.2. The regions are nested (each later one lies inside the previous one),
# so the height accumulates from 0.2 up to 0.8 towards the +x / -y corner.
for thresh in [1.0, 0, -1.0, -2]:
    z[torch.logical_and(x > -thresh, y < thresh)] += 0.2
# Tile the single map along a new leading dimension: (num_robots, DIM, DIM).
x_grid = x.repeat(num_robots, 1, 1)
y_grid = y.repeat(num_robots, 1, 1)
z_grid = z.repeat(num_robots, 1, 1)

In [22]:
# Instantiate the robot model, terrain and physics-engine configurations.
robot_model = RobotModelConfig(
    kind="marv",
    points_per_driving_part=384,
    points_per_body=512,
)

# Terrain description: the tiled heightmap grids plus the grid geometry and the
# contact parameters (stiffness and lateral / longitudinal friction).
terrain_kwargs = dict(
    x_grid=x_grid,
    y_grid=y_grid,
    z_grid=z_grid,
    grid_res=grid_res,
    max_coord=max_coord,
    k_stiffness=40000,
    k_friction_lat=0.5,
    k_friction_lon=0.8,
)
world_config = TerrainConfig(**terrain_kwargs)

# Engine settings: batch size, damping coefficient and integration step dt.
physics_config = PhysicsEngineConfig(
    num_robots=num_robots,
    damping_alpha=5.0,
    dt=0.005,
)

2025-04-29 22:20:36,352 [RobotModelConfig][[92mINFO[00m]: Loading robot model from cache: /mnt/personal/korcadav/flipper_training/.robot_cache/marv_vx0.010_dp384_b512_whl0.02_trck0.05_eaecc2d5466de1eb8911703837d75c759b5c075158ced88ea318e932700dabb2 (robot_config.py:155)
2025-04-29 22:20:36,352 [RobotModelConfig][[92mINFO[00m]: Loading robot model from cache: /mnt/personal/korcadav/flipper_training/.robot_cache/marv_vx0.010_dp384_b512_whl0.02_trck0.05_eaecc2d5466de1eb8911703837d75c759b5c075158ced88ea318e932700dabb2 (robot_config.py:155)


In [23]:
# Keyword arguments that are unpacked into engine.compile(...) further down.
# NOTE(review): presumably forwarded to torch.compile — confirm in
# DPhysicsEngine.compile.
compile_opts = dict(
    fullgraph=True,
    options={
        "triton.cudagraphs": True,
        "coordinate_descent_tuning": True,
        # "max-autotune": True,
    },
)

In [24]:
from flipper_training.utils.geometry import euler_to_quaternion

# Number of simulation steps recorded per device.
num_steps = 1000

# Every robot in the batch starts from the same position, at rest.
start_position = [[-2.5, -2.5, 0.400]]
x0 = torch.tensor(start_position).repeat(num_robots, 1)
xd0 = torch.zeros_like(x0)

# Initial orientation: Euler angles (0, 0, pi) converted to a quaternion and
# tiled over the batch.
# NOTE(review): the axis order is whatever euler_to_quaternion expects — confirm.
euler_angles = torch.tensor([0, 0, 1.0 * torch.pi])
q0 = euler_to_quaternion(*euler_angles).repeat(num_robots, 1)
omega0 = torch.zeros_like(x0)

# One entry per driving part for the joint angles, two per driving part for the
# controls. controls_all is drawn from the seeded global RNG.
thetas0 = torch.zeros(num_robots, robot_model.num_driving_parts)
controls_all = torch.rand(num_robots, 2 * robot_model.num_driving_parts)

In [25]:
# Bundle the initial tensors into a single PhysicsState; it is deep-copied for
# every device in the simulation cell.
# NOTE(review): the arguments are positional — presumably (position, linear
# velocity, orientation quaternion, angular velocity, joint angles); verify
# against the PhysicsState definition.
init_state = PhysicsState(x0, xd0, q0, omega0, thetas0)

In [26]:
# Run the same simulation on several GPUs and keep every trajectory so the
# results can be compared across devices afterwards.
# Maps a device string such as "cuda:0" to a list with one
# (state, state_derivative) pair per simulation step.
trajectories = {}

# NOTE(review): the device indices are hard-coded, so this cell needs at least
# three visible CUDA devices.
for i in range(3):
    device_str = f"cuda:{i}"
    print(f"Running on {device_str}")
    device = set_device(device_str)
    # The return values are discarded, so these presumably move the configs to
    # the device in place — confirm against their .to() implementations.
    world_config.to(device)
    robot_model.to(device)
    # A fresh engine is created and compiled for every device.
    engine = DPhysicsEngine(physics_config, robot_model, device)
    engine.compile(**compile_opts)
    # Warm-up call on a throwaway copy of the initial state; presumably this is
    # what triggers the actual compilation, keeping it out of the loop below.
    comp_state = deepcopy(init_state).to(device)
    # All-zero controls. They are also used for every recorded step below; the
    # random controls_all tensor created earlier is not used in this cell.
    comp_controls = torch.zeros((num_robots, 2 * robot_model.num_driving_parts)).to(device)
    engine(comp_state, comp_controls, world_config)
    # The recorded rollout restarts from a fresh copy of the initial state.
    state = deepcopy(init_state).to(device)
    traj = []
    for _ in trange(num_steps, desc=f"Trajectory on {device_str}"):
        next_state, state_der = engine(state, comp_controls, world_config)
        # NOTE(review): the outputs are copied before the next engine call —
        # presumably because with CUDA graphs enabled the returned buffers may
        # be overwritten by a later call; confirm before removing.
        next_state = deepcopy(next_state)
        state_der = deepcopy(state_der)
        # Record the state the step started from together with the derivative
        # computed from it, both moved to the CPU.
        traj.append((state.cpu(), state_der.cpu()))
        state = next_state
    trajectories[device_str] = traj

I0429 22:20:36.824000 83085 torch/_dynamo/symbolic_convert.py:2706] [0/3] Step 1: torchdynamo start tracing forward /mnt/personal/korcadav/flipper_training/flipper_training/engine/engine.py:27
I0429 22:20:36.826000 83085 torch/fx/experimental/symbolic_shapes.py:3192] [0/3] create_env


Running on cuda:0


I0429 22:20:37.555000 83085 torch/_dynamo/symbolic_convert.py:3028] [0/3] Step 1: torchdynamo done tracing forward (RETURN_VALUE)
I0429 22:20:37.576000 83085 torch/_dynamo/output_graph.py:1458] [0/3] Step 2: calling compiler function inductor
I0429 22:20:44.339000 83085 torch/fx/experimental/symbolic_shapes.py:4547] [0/3] produce_guards
I0429 22:20:44.377000 83085 torch/_dynamo/output_graph.py:1463] [0/3] Step 2: done compiler function inductor
I0429 22:20:44.423000 83085 torch/fx/experimental/symbolic_shapes.py:4547] [0/3] produce_guards
I0429 22:20:44.447000 83085 torch/_dynamo/pgo.py:636] [0/3] put_code_state: no cache key, skipping
Trajectory on cuda:0: 100%|██████████| 1000/1000 [00:02<00:00, 481.18it/s]
I0429 22:21:01.338000 83085 torch/_dynamo/symbolic_convert.py:2706] [0/4] Step 1: torchdynamo start tracing forward /mnt/personal/korcadav/flipper_training/flipper_training/engine/engine.py:27
I0429 22:21:01.340000 83085 torch/fx/experimental/symbolic_shapes.py:3192] [0/4] create_

Running on cuda:1


I0429 22:21:02.054000 83085 torch/_dynamo/symbolic_convert.py:3028] [0/4] Step 1: torchdynamo done tracing forward (RETURN_VALUE)
I0429 22:21:02.076000 83085 torch/_dynamo/output_graph.py:1458] [0/4] Step 2: calling compiler function inductor
I0429 22:21:07.666000 83085 torch/fx/experimental/symbolic_shapes.py:4547] [0/4] produce_guards
I0429 22:21:07.701000 83085 torch/_dynamo/output_graph.py:1463] [0/4] Step 2: done compiler function inductor
I0429 22:21:07.748000 83085 torch/fx/experimental/symbolic_shapes.py:4547] [0/4] produce_guards
I0429 22:21:07.771000 83085 torch/_dynamo/pgo.py:636] [0/4] put_code_state: no cache key, skipping
Trajectory on cuda:1: 100%|██████████| 1000/1000 [00:01<00:00, 566.40it/s]
I0429 22:21:20.644000 83085 torch/_dynamo/symbolic_convert.py:2706] [0/5] Step 1: torchdynamo start tracing forward /mnt/personal/korcadav/flipper_training/flipper_training/engine/engine.py:27
I0429 22:21:20.645000 83085 torch/fx/experimental/symbolic_shapes.py:3192] [0/5] create_

Running on cuda:2


I0429 22:21:21.355000 83085 torch/_dynamo/symbolic_convert.py:3028] [0/5] Step 1: torchdynamo done tracing forward (RETURN_VALUE)
I0429 22:21:21.376000 83085 torch/_dynamo/output_graph.py:1458] [0/5] Step 2: calling compiler function inductor
I0429 22:21:26.965000 83085 torch/fx/experimental/symbolic_shapes.py:4547] [0/5] produce_guards
I0429 22:21:26.999000 83085 torch/_dynamo/output_graph.py:1463] [0/5] Step 2: done compiler function inductor
I0429 22:21:27.048000 83085 torch/fx/experimental/symbolic_shapes.py:4547] [0/5] produce_guards
I0429 22:21:27.072000 83085 torch/_dynamo/pgo.py:636] [0/5] put_code_state: no cache key, skipping
Trajectory on cuda:2: 100%|██████████| 1000/1000 [00:01<00:00, 551.09it/s]


In [27]:
# Absolute tolerance for the cross-device comparison: a field is reported when
# its largest element-wise absolute difference exceeds this value.
atol = 1e-5

In [28]:
# Device pairs whose recorded trajectories are compared against each other.
device_pairs = [
    ("cuda:0", "cuda:1"),
    ("cuda:0", "cuda:2"),
    ("cuda:1", "cuda:2"),
]

# For every step and every device pair, report each field whose largest
# absolute difference between the two devices exceeds atol.
for i in range(num_steps):
    for t_to_diff in device_pairs:
        # Recorded (state, state derivative) pair of each device at step i.
        t1, t2 = (trajectories[name][i] for name in t_to_diff)
        (state1, state_der1), (state2, state_der2) = t1, t2

        # Element-wise absolute differences of the states and their derivatives.
        state_abs_diff = (state1.cpu() - state2.cpu()).abs()
        state_der_abs_diff = (state_der1.cpu() - state_der2.cpu()).abs()

        for k, v in state_abs_diff.items():
            if v.max() > atol:
                print(f"State {k} diff {v.max()} between {t_to_diff[0]} and {t_to_diff[1]} at step {i}")

        for k, v in state_der_abs_diff.items():
            if v.max() > atol:
                print(f"State Derivative {k} diff {v.max()} between {t_to_diff[0]} and {t_to_diff[1]} at step {i}")

State Derivative torque diff 1.1444091796875e-05 between cuda:0 and cuda:1 at step 51
State Derivative torque diff 1.1444091796875e-05 between cuda:1 and cuda:2 at step 51
State Derivative torque diff 1.1444091796875e-05 between cuda:0 and cuda:1 at step 52
State Derivative torque diff 1.1444091796875e-05 between cuda:1 and cuda:2 at step 52
State Derivative f_spring diff 1.2874603271484375e-05 between cuda:0 and cuda:1 at step 55
State Derivative f_spring diff 1.2874603271484375e-05 between cuda:1 and cuda:2 at step 55
State Derivative f_spring diff 1.52587890625e-05 between cuda:0 and cuda:1 at step 58
State Derivative torque diff 1.1444091796875e-05 between cuda:0 and cuda:2 at step 58
State Derivative f_spring diff 1.52587890625e-05 between cuda:1 and cuda:2 at step 58
State Derivative torque diff 1.33514404296875e-05 between cuda:1 and cuda:2 at step 58
State Derivative torque diff 1.1444091796875e-05 between cuda:0 and cuda:1 at step 59
State Derivative torque diff 1.52587890625e