In [1]:
# Auto-reload edited project modules (mode 3: like mode 2, and also loads objects newly added to modules).
%load_ext autoreload
%autoreload 3

In [None]:
# Verbose TorchDynamo logging. Set before `import torch` in the next cell so it is in effect
# from the start (presumably read when torch initialises its logging — confirm).
%env TORCH_LOGS=dynamo 
%env TORCHDYNAMO_VERBOSE=1

In [3]:
import torch
import matplotlib.pyplot as plt
from flipper_training.configs import *
from flipper_training.engine.engine import DPhysicsEngine, PhysicsState, PhysicsStateDer
from flipper_training.utils.dynamics import *
from flipper_training.utils.geometry import *
from flipper_training.utils.environment import *
from flipper_training.utils.numerical import *
from copy import deepcopy
from collections import deque

In [4]:
# Passed to `autodevice` below to choose the compute device.
HIGH_PERFORMANCE = False

In [5]:
from flipper_training.utils.torch_utils import autodevice

# Compute device used for every tensor below; the selection logic lives in `autodevice`.
device = autodevice(HIGH_PERFORMANCE)

In [6]:
from flipper_training.vis.static_vis import *

In [None]:
# Seed torch's RNG so the heightmap noise (`torch.randn_like` below) is reproducible.
torch.random.manual_seed(420)

In [8]:
# Number of robots simulated in parallel; per-robot tensors below are repeated along a leading batch dimension of this size.
num_robots = 1024

In [9]:
# Heightmap coordinates. torch.meshgrid is deliberately called with indexing="xy", so both
# outputs have shape (len(yint), len(xint)): x varies along columns, y along rows.
grid_res = 0.05  # metres per grid cell (5 cm)
max_coord = 6.4  # the map spans [-max_coord, max_coord] metres along each axis
DIM = int(max_coord * 2 / grid_res)  # grid points per side
# NOTE(review): linspace puts DIM points on the closed interval, so the real spacing is
# 2 * max_coord / (DIM - 1) (~0.0502 m), not exactly grid_res — confirm that the engine's
# heightmap lookup tolerates this.
xint = torch.linspace(-max_coord, max_coord, steps=DIM)
yint = torch.linspace(-max_coord, max_coord, steps=DIM)
x, y = torch.meshgrid(xint, yint, indexing="xy")

In [None]:
# Synthetic terrain: a weighted sum of Gaussian bumps plus a little Gaussian noise.
# Each bump has the form weight * exp(-rate * squared distance to its centre).
bump_a = 1.0 * torch.exp(-0.5 * ((x - 0)**2 + (y - 4)**2))  # centred at (0, 4)
bump_b = 0.0 * torch.exp(-0.3 * ((x - 1)**2 + (y + 2)**2))  # weight 0: currently disabled
bump_c = 2.0 * torch.exp(-0.1 * ((x + max_coord)**2 + (y + max_coord)**2))  # at the (-max, -max) corner
noise = 0.01 * torch.randn_like(x)
bump_d = torch.exp(-0.03 * ((x + 5)**2 + (y + 5)**2))  # broad bump centred at (-5, -5)
z = bump_a + bump_b + bump_c + noise + bump_d
del bump_a, bump_b, bump_c, noise, bump_d
# Every robot sees the same terrain: replicate each grid along a new leading batch dimension.
x_grid, y_grid, z_grid = (grid.repeat(num_robots, 1, 1) for grid in (x, y, z))
x_grid.shape

In [None]:
# Show the heightmap of robot 0 (all robots share the same terrain in this notebook).
plot_grids_xyz(x_grid[0], y_grid[0], z_grid[0])

In [None]:
# Instantiate the robot, world (terrain) and engine configuration objects.
robot_model = RobotModelConfig(
    robot_type="marv",
    voxel_size=0.08,
    points_per_driving_part=150,
)
world_config = WorldConfig(
    x_grid=x_grid,
    y_grid=y_grid,
    z_grid=z_grid,
    grid_res=grid_res,
    max_coord=max_coord,
    k_stiffness=30000,  # NOTE(review): terrain stiffness constant; units not visible here
)
physics_config = PhysicsEngineConfig(num_robots=num_robots)

In [13]:
# Controls: a constant forward speed and yaw rate, turned into a control vector by the robot
# model (presumably one entry per driving part). The flipper columns start at zero and are
# overwritten later by the sine-wave flipper cell.
traj_length = 5.0  # seconds
# Round instead of truncating: a float quotient such as 2.9999999999999996 (dt is rarely
# exactly representable in binary) would otherwise silently drop one simulation step.
n_iters = int(round(traj_length / physics_config.dt))
speed = 1.0  # m/s forward
omega = -0.5  # rad/s yaw
controls = robot_model.get_controls(torch.tensor([speed, omega]))
flipper_controls = torch.zeros_like(controls)

In [14]:
# Move every tensor held by the three configuration objects onto the compute device.
for config_obj in (robot_model, world_config, physics_config):
    config_obj.move_all_tensors_to_device(device)

In [15]:
# Physics engine built from the configs prepared above (their tensors are already on `device`).
engine = DPhysicsEngine(physics_config, robot_model, device)

In [16]:
# Initial conditions, batched over num_robots: every robot starts at (-6, -6, 3) at rest,
# with identity orientation and zero joint angles.
x0 = torch.tensor([-6., -6., 3.], device=device).repeat(num_robots, 1)
xd0 = torch.zeros_like(x0)
R0 = torch.eye(3, device=device).repeat(num_robots, 1, 1)
omega0 = torch.zeros_like(x0)
thetas0 = torch.zeros(num_robots, 4, device=device)
local_robot_points0 = robot_model.robot_points.to(device).repeat(num_robots, 1, 1)
# One control row per step and robot: (n_iters, num_robots, len(controls) + len(flipper_controls)).
controls_all = (
    torch.cat([controls, flipper_controls])
    .unsqueeze(0)
    .repeat(n_iters, num_robots, 1)
    .to(device)
)

In [17]:
# Flipper joint velocity commands. Each flipper is driven with a cosine velocity profile,
# i.e. the time derivative of a sine-shaped joint-angle trajectory.
# The cosine spans `periods` = traj_length / 10 cycles over the run, which is one cycle per
# ~10 s of simulated time (half a cycle for the 5 s trajectory). Its velocity amplitude is `amplitude`.
# NOTE(review): if these are joint angular velocities in rad/s, the resulting joint-angle
# swing is about amplitude * 10 / (2 * pi) (= 5 rad for pi) — confirm that is intended.
amplitude = torch.pi
periods = traj_length / 10.
# torch.pi rather than np.pi: `np` is not imported explicitly in this notebook.
phase = torch.linspace(0, periods * 2 * torch.pi, n_iters)
rot_vels = torch.cos(phase) * amplitude
rot_vels = rot_vels.unsqueeze(-1).repeat(1, num_robots)  # (n_iters, num_robots)
# The flipper columns follow the len(controls) track columns: the first two flippers get +v
# and the last two get -v.
first_flipper = len(controls)
controls_all[:, :, first_flipper] = rot_vels
controls_all[:, :, first_flipper + 1] = rot_vels
controls_all[:, :, first_flipper + 2] = -rot_vels
controls_all[:, :, first_flipper + 3] = -rot_vels

In [18]:
# Initial physics state for all robots, built from positional arguments.
# NOTE(review): confirm the argument order (x, xd, R, local points, omega, thetas) matches PhysicsState's fields.
init_state = PhysicsState(x0, xd0, R0, local_robot_points0, omega0, thetas0)

In [19]:
# Stand-in state derived from `init_state`. It is used only for the compile warm-up and the
# timing benchmark below; the rollout itself starts from `init_state`.
bench_state = PhysicsState.dummy_like(init_state)

In [20]:
# Whether to torch.compile the engine, and the inductor options passed to torch.compile.
# NOTE(review): `compile` shadows the Python builtin of the same name in this notebook.
compile = True
compile_opts = dict(
    options={
        "trace.enabled": True,  # dump inductor debug traces
        "trace.graph_diagram": True,
        "max-autotune": True,
        "triton.cudagraphs": True,
    }
)

In [None]:
# Debug aid: list the inductor modes and config option names that torch.compile accepts,
# e.g. to check the keys used in `compile_opts`.
print(torch._inductor.list_mode_options())
print(torch._inductor.list_options())

In [None]:
if compile:
    # Run all torch compilation up front: the first call triggers tracing and autotuning.
    # Modules returned by torch.compile keep the eager module in `_orig_mod`. Compiling that
    # keeps this cell idempotent, so re-running it does not wrap an already-compiled engine again.
    engine = torch.compile(getattr(engine, "_orig_mod", engine), **compile_opts)
    _ = engine(bench_state, controls_all[0], world_config)

In [None]:
%%timeit -o
_ = engine(bench_state, controls_all[0],world_config)

In [24]:
# Per-step rollout buffers (state, state derivative, auxiliary outputs); each keeps at most n_iters entries.
states, dstates, auxs = (deque(maxlen=n_iters) for _ in range(3))

In [None]:
%%time 
state = deepcopy(init_state)
i = 0
for ctrl in controls_all:
    state, der, aux = engine(state, ctrl, world_config)
    states.append(deepcopy(state))
    dstates.append(der)
    auxs.append(aux)
    i += 1

In [26]:
from flipper_training.engine.engine_state import vectorize_iter_of_tensor_tuples

In [27]:
# Stack the per-step records into batched tensors; they are indexed below as [step, robot, ...].
states_vec, dstates_vec, aux_vec = (
    vectorize_iter_of_tensor_tuples(buffer) for buffer in (states, dstates, auxs)
)

In [None]:
# Bird's-eye view of the stored trajectory. `iter_step=40` presumably plots every 40th state — confirm in static_vis.
plot_birdview_trajectory(world_config, states, iter_step=40)

In [29]:
# Index of the robot whose signals are plotted in the cells below.
ROBOT_IDX = 0

In [None]:
# Numerical-stability check of one robot's integrated rotation matrices.
# A proper rotation satisfies R R^T = I and det(R) = 1, so the first curve should stay near 0
# and the second near 1 throughout the rollout.
R = states_vec.R[:, ROBOT_IDX].cpu()
RRTs = torch.bmm(R, R.permute(0, 2, 1))  # R @ R^T for every time step
diffs = torch.norm(RRTs - torch.eye(3), dim=(1, 2))  # Frobenius distance from the identity
dets = torch.linalg.det(R)
plt.plot(diffs, label="||R R^T - I||_F")
plt.plot(dets, label="det(R)")
plt.xlabel("Time step")
plt.title("Rotation matrix orthonormality")
plt.grid()
plt.legend()

In [None]:
# Components of the selected robot's linear acceleration (xdd).
# Component order follows the position vector x (index 2 is height, see the height plot),
# so unpack as x, y, z. The original `ay, ax, az` swapped the x and y labels.
acc_x, acc_y, acc_z = dstates_vec.xdd[:, ROBOT_IDX].T.cpu().numpy()
plt.figure(figsize=(10, 5))
plt.plot(acc_x, label='ax')
plt.plot(acc_y, label='ay')
plt.plot(acc_z, label='az')
plt.grid()
plt.xlabel('Time step')
plt.ylabel('Acceleration')
plt.legend()

In [None]:
# Torque on the selected robot about its CoG, one curve per component.
t1, t2, t3 = aux_vec.torque[:, ROBOT_IDX].T.cpu().numpy()
plt.figure(figsize=(10, 5), dpi=200)
for series, name in zip((t1, t2, t3), ('t1', 't2', 't3')):
    plt.plot(series, label=name)
plt.grid()
plt.xlabel('Time step')
plt.ylabel('Torque (Nm) at CoG')
plt.legend()

In [None]:
# Joint angles (thetas) of the selected robot over the rollout.
plt.figure(figsize=(10, 5), dpi=200)
theta1, theta2, theta3, theta4 = states_vec.thetas[:, ROBOT_IDX].T.cpu().numpy()
plt.plot(theta1, label='theta1')
plt.plot(theta2, label='theta2')
plt.plot(theta3, label='theta3')
plt.plot(theta4, label='theta4')
plt.grid()
plt.xlabel('Time step')
plt.ylabel('Joint angle')
plt.legend()  # the curves carry labels, but the original never drew the legend

In [None]:
# Height of the selected robot (z component of its position x) over the rollout.
# Stored as `robot_z` so the heightmap `z` built earlier is not clobbered; clobbering it
# would break a re-run of the heightmap/grid-replication cell.
plt.figure(figsize=(10, 5), dpi=200)
robot_z = states_vec.x[:, ROBOT_IDX, 2].cpu().numpy()
plt.plot(robot_z)
plt.grid()
plt.xlabel('Time step')
plt.ylabel('Height z')

In [None]:
# 3D view of the rollout, given the world config, the stored states and the auxiliary engine outputs.
plot_3d_trajectory(world_config, states, auxs)

In [None]:
from flipper_training.vis.animator import animate_trajectory

In [None]:
# Animate the rollout. NOTE(review): the trailing 0 presumably selects the robot to show;
# confirm against animate_trajectory's signature (ROBOT_IDX may be the intended value).
animate_trajectory(world_config, physics_config, states, auxs, 0)