In [1]:
%load_ext autoreload
%autoreload 3

In [2]:
%env TORCH_LOGS=dynamo 
%env TORCHDYNAMO_VERBOSE=1
%env TORCHINDUCTOR_FORCE_DISABLE_CACHES=1

env: TORCH_LOGS=dynamo
env: TORCHDYNAMO_VERBOSE=1
env: TORCHINDUCTOR_FORCE_DISABLE_CACHES=1


In [3]:
import torch
import matplotlib.pyplot as plt
from flipper_training.configs import (
    TerrainConfig,
    RobotModelConfig,
    PhysicsEngineConfig,
)
import time
from flipper_training.engine.engine import DPhysicsEngine, PhysicsState
from flipper_training.utils.dynamics import *
from flipper_training.utils.geometry import *
from flipper_training.utils.environment import *
from flipper_training.utils.numerical import *
from collections import deque
from copy import deepcopy
from tqdm import tqdm, trange

In [4]:
from flipper_training.utils.torch_utils import set_device

In [5]:
torch.random.manual_seed(420)

<torch._C.Generator at 0x7f39fbf69d30>

In [6]:
num_robots = 128

In [7]:
# Heightmap setup - use torch's XY indexing !!!!!
grid_res = 0.05  # 5cm per grid cell
max_coord = 3.2  # meters
DIM = int(2 * max_coord / grid_res)
xint = torch.linspace(-max_coord, max_coord, DIM)
yint = torch.linspace(-max_coord, max_coord, DIM)
x, y = torch.meshgrid(xint, yint, indexing="xy")
z = torch.zeros_like(x)
for thresh in [1.0, 0, -1.0, -2]:
    z[torch.logical_and(x > -thresh, y < thresh)] += 0.2
x_grid = x.repeat(num_robots, 1, 1)
y_grid = y.repeat(num_robots, 1, 1)
z_grid = z.repeat(num_robots, 1, 1)

In [8]:
# Instatiate the physics config
robot_model = RobotModelConfig(kind="marv", points_per_driving_part=384, points_per_body=512)
world_config = TerrainConfig(
    x_grid=x_grid,
    y_grid=y_grid,
    z_grid=z_grid,
    grid_res=grid_res,
    max_coord=max_coord,
    k_stiffness=40000,
    k_friction_lat=0.5,
    k_friction_lon=0.8,
)
physics_config = PhysicsEngineConfig(num_robots=num_robots, damping_alpha=5.0, dt=0.005)

2025-04-30 12:41:03,648 [RobotModelConfig][[92mINFO[00m]: Loading robot model from cache: /mnt/personal/korcadav/flipper_training/.robot_cache/marv_vx0.010_dp384_b512_whl0.02_trck0.05_eaecc2d5466de1eb8911703837d75c759b5c075158ced88ea318e932700dabb2 (robot_config.py:155)


In [None]:
compile_opts = {
    "fullgraph": True,
    "options": {
        "triton.cudagraphs": True,
        "coordinate_descent_tuning": True,
        "max-autotune": True,
    },
}

In [10]:
from flipper_training.utils.geometry import euler_to_quaternion

num_steps = 1000
x0 = torch.tensor([[-2.5, -2.5, 0.400]]).repeat(num_robots, 1)
xd0 = torch.zeros_like(x0)
q0 = euler_to_quaternion(*torch.tensor([0, 0, 1.0 * torch.pi])).repeat(num_robots, 1)
omega0 = torch.zeros_like(x0)
thetas0 = torch.zeros((num_robots, robot_model.num_driving_parts))
controls_all = torch.rand((num_robots, 2 * robot_model.num_driving_parts))

In [11]:
init_state = PhysicsState(x0, xd0, q0, omega0, thetas0)

In [12]:
trajectories = {}

for i in range(3):
    device_str = f"cuda:{i}"
    print(f"Running on {device_str}")
    device = set_device(device_str)
    world_config.to(device)
    robot_model.to(device)
    engine = DPhysicsEngine(physics_config, robot_model, device)
    engine.compile(**compile_opts)
    comp_state = deepcopy(init_state).to(device)
    comp_controls = torch.zeros((num_robots, 2 * robot_model.num_driving_parts)).to(device)
    engine(comp_state, comp_controls, world_config)
    state = deepcopy(init_state).to(device)
    traj = []
    time_sum = 0
    for _ in trange(num_steps, desc=f"Trajectory on {device_str}"):
        s = time.time()
        next_state, state_der = engine(state, comp_controls, world_config)
        e = time.time()
        time_sum += e - s
        next_state = deepcopy(next_state)
        state_der = deepcopy(state_der)
        traj.append((state.cpu(), state_der.cpu()))
        state = next_state
    trajectories[device_str] = traj
    print(f"Average time per step on {device_str}: {time_sum / num_steps:.4f} seconds")

Running on cuda:0


I0430 12:41:05.944000 371008 torch/_dynamo/utils.py:1162] [0/0] ChromiumEventLogger initialized with id ee7e3791-30fc-4304-855a-fd5e8c1e500f
I0430 12:41:05.948000 371008 torch/_dynamo/symbolic_convert.py:2706] [0/0] Step 1: torchdynamo start tracing forward /mnt/personal/korcadav/flipper_training/flipper_training/engine/engine.py:27
I0430 12:41:05.949000 371008 torch/fx/experimental/symbolic_shapes.py:3192] [0/0] create_env
  warn_once(
I0430 12:41:07.302000 371008 torch/_dynamo/symbolic_convert.py:3028] [0/0] Step 1: torchdynamo done tracing forward (RETURN_VALUE)
I0430 12:41:07.329000 371008 torch/_dynamo/output_graph.py:1458] [0/0] Step 2: calling compiler function inductor
AUTOTUNE bmm(128x3x1, 128x1x3)
  triton_bmm_17 0.0072 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=16, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=1
  triton_bmm_18 0.0072 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=1

Average time per step on cuda:0: 0.0006 seconds
Running on cuda:1


I0430 12:42:21.809000 371008 torch/_dynamo/symbolic_convert.py:2706] [0/1] Step 1: torchdynamo start tracing forward /mnt/personal/korcadav/flipper_training/flipper_training/engine/engine.py:27
I0430 12:42:21.812000 371008 torch/fx/experimental/symbolic_shapes.py:3192] [0/1] create_env
I0430 12:42:22.721000 371008 torch/_dynamo/symbolic_convert.py:3028] [0/1] Step 1: torchdynamo done tracing forward (RETURN_VALUE)
I0430 12:42:22.742000 371008 torch/_dynamo/output_graph.py:1458] [0/1] Step 2: calling compiler function inductor
AUTOTUNE bmm(128x3x1, 128x1x3)
  bmm 0.0051 ms 100.0% 
  triton_bmm_50 0.0072 ms 71.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=16, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=1, num_warps=1
  triton_bmm_51 0.0072 ms 71.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=16, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=1
  triton_bmm_52 0.0072 ms 71.4% ACC_TYP

Average time per step on cuda:1: 0.0010 seconds
Running on cuda:2


I0430 12:43:20.969000 371008 torch/_dynamo/symbolic_convert.py:2706] [0/2] Step 1: torchdynamo start tracing forward /mnt/personal/korcadav/flipper_training/flipper_training/engine/engine.py:27
I0430 12:43:20.971000 371008 torch/fx/experimental/symbolic_shapes.py:3192] [0/2] create_env
I0430 12:43:21.937000 371008 torch/_dynamo/symbolic_convert.py:3028] [0/2] Step 1: torchdynamo done tracing forward (RETURN_VALUE)
I0430 12:43:21.972000 371008 torch/_dynamo/output_graph.py:1458] [0/2] Step 2: calling compiler function inductor
AUTOTUNE bmm(128x3x1, 128x1x3)
  bmm 0.0051 ms 100.0% 
  triton_bmm_87 0.0072 ms 71.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=16, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=1
  triton_bmm_88 0.0072 ms 71.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=16, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=1
  triton_bmm_89 0.0072 ms 71.4% ACC_TYP

Average time per step on cuda:2: 0.0007 seconds





In [13]:
atol = 1e-5

In [14]:
for i in range(num_steps):
    for t_to_diff in [
        ("cuda:0", "cuda:1"),
        ("cuda:0", "cuda:2"),
        ("cuda:1", "cuda:2"),
    ]:
        t1 = trajectories[t_to_diff[0]][i]
        t2 = trajectories[t_to_diff[1]][i]
        state1, state_der1 = t1
        state2, state_der2 = t2
        state_abs_diff = (state1.cpu() - state2.cpu()).abs()
        state_der_abs_diff = (state_der1.cpu() - state_der2.cpu()).abs()
        for k, v in state_abs_diff.items():
            if v.max() > atol:
                print(f"State {k} diff {v.max()} between {t_to_diff[0]} and {t_to_diff[1]} at step {i}")
                print(f"Device {t_to_diff[0]}: {k}={getattr(state1, k)}")
                print(f"Device {t_to_diff[1]}: {k}={getattr(state2, k)}")
        for k, v in state_der_abs_diff.items():
            if v.max() > atol:
                print(f"State Derivative {k} diff {v.max()} between {t_to_diff[0]} and {t_to_diff[1]} at step {i}")
                print(f"Device {t_to_diff[0]}: {k}={getattr(state_der1, k)}")
                print(f"Device {t_to_diff[1]}: {k}={getattr(state_der2, k)}")
                print(f"Torque for device {t_to_diff[0]}: {state_der1.torque}")
                print(f"Torque for device {t_to_diff[1]}: {state_der2.torque}")

State Derivative omega_d diff 0.00013399124145507812 between cuda:0 and cuda:1 at step 47
Device cuda:0: omega_d=tensor[128, 3] n=384 (1.5Kb) x∈[-0.210, 39.614] μ=14.904 σ=17.640
Device cuda:1: omega_d=tensor[128, 3] n=384 (1.5Kb) x∈[-0.210, 39.614] μ=14.904 σ=17.640
Torque for device cuda:0: tensor[128, 3] n=384 (1.5Kb) x∈[0.001, 0.997] μ=0.356 σ=0.455
Torque for device cuda:1: tensor[128, 3] n=384 (1.5Kb) x∈[0.001, 0.997] μ=0.356 σ=0.455
State Derivative f_spring diff 0.001129150390625 between cuda:0 and cuda:1 at step 47
Device cuda:0: f_spring=tensor[128, 2048, 3] n=786432 (3Mb) x∈[0., 1.229e+03] μ=4.713 σ=55.569
Device cuda:1: f_spring=tensor[128, 2048, 3] n=786432 (3Mb) x∈[0., 1.229e+03] μ=4.713 σ=55.569
Torque for device cuda:0: tensor[128, 3] n=384 (1.5Kb) x∈[0.001, 0.997] μ=0.356 σ=0.455
Torque for device cuda:1: tensor[128, 3] n=384 (1.5Kb) x∈[0.001, 0.997] μ=0.356 σ=0.455
State Derivative omega_d diff 0.00013399124145507812 between cuda:1 and cuda:2 at step 47
Device cuda:1: