In [None]:
# add the path to the source code of the MonoForce package
import sys

# Guard the insertion so that re-running this cell does not keep appending
# duplicate entries to sys.path.
if '../src' not in sys.path:
    sys.path.append('../src')

In [None]:
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Seed the PyTorch and NumPy random generators with one shared value
# so that repeated runs of the notebook are comparable
random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)

In [None]:
# Experiment settings (everything a reader may want to tune lives here)
num_epochs = 100  # number of passes over the training dataloader
lr = 1e-4  # learning rate given to the optimizer
robot = 'marv'  # robot name passed to DPhysConfig
traj_sim_time = 1.0  # value written to dphys_cfg.traj_sim_time
small_data = True  # forwarded to compile_data(small_data=...)
# use the GPU when PyTorch can see one, otherwise stay on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
from monoforce.models.dphysics import DPhysics
from monoforce.dphys_config import DPhysConfig
from monoforce.utils import read_yaml

# load configs: Differentiable Physics
# (built for the selected robot, then the trajectory simulation time is overridden)
dphys_cfg = DPhysConfig(robot=robot)
dphys_cfg.traj_sim_time = traj_sim_time

# load configs: LSS (Terrain Encoder)
lss_config_path = '../config/lss_cfg.yaml'
lss_cfg = read_yaml(lss_config_path)
# plain string literal: the path has no placeholders, so no f-string is needed
pretrained_model_path = '../config/weights/lss/val.pth'

In [None]:
# Load Differentiable Physics
# Build the physics model from the config created above and bind it to the
# device chosen in the training-parameters cell.
dphysics = DPhysics(dphys_cfg, device=device)

In [None]:
from monoforce.models.terrain_encoder.lss import LiftSplatShoot

# Load LSS (Terrain Encoder)
# The grid and data-augmentation sections come from the YAML config read earlier.
# NOTE(review): outC=1 presumably selects a single output channel — confirm
# against the LiftSplatShoot definition.
lss = LiftSplatShoot(lss_cfg['grid_conf'], lss_cfg['data_aug_conf'], outC=1)
# Restore the pretrained weights from `pretrained_model_path`.
lss.from_pretrained(pretrained_model_path)
# Move the model to the selected device; the trailing `;` keeps the notebook
# from printing the return value of `.to(...)` as the cell output.
lss.to(device);

In [None]:
# Build the training and validation datasets from the two configs
from monoforce.utils import compile_data

train_ds, val_ds = compile_data(dphys_cfg=dphys_cfg, lss_cfg=lss_cfg, small_data=small_data)
# report how many samples each split contains
print(f'Train dataset: {len(train_ds)}')
print(f'Validation dataset: {len(val_ds)}')

In [None]:
# Wrap both splits in dataloaders that yield one sample per batch
from torch.utils.data import DataLoader

# training samples are visited in a fresh random order every epoch
train_dl = DataLoader(dataset=train_ds, shuffle=True, batch_size=1)
# validation keeps the dataset order
val_dl = DataLoader(dataset=val_ds, shuffle=False, batch_size=1)

In [None]:
def monoforce_forward(inputs, lss, dphysics, controls=None):
    """Run the terrain encoder and feed its predictions to the physics model.

    Args:
        inputs: sequence of arguments that is unpacked into ``lss(*inputs)``.
        lss: terrain encoder; it is called as ``lss(*inputs)`` and must return
            a mapping with ``'terrain'`` and ``'friction'`` entries.
        dphysics: physics model; it is called with the keyword arguments
            ``z_grid``, ``controls`` and ``friction`` and must return a pair
            whose first element holds the predicted states.
        controls: control inputs forwarded to ``dphysics``. When omitted, the
            module-level variable named ``controls`` is used, so call sites
            that relied on that global keep working.

    Returns:
        The tuple ``(height_terrain_pred, friction_pred, states_pred)``.

    Raises:
        NameError: if ``controls`` is omitted and no module-level ``controls``
            is available.
    """
    if controls is None:
        # Backward-compatible fallback: older call sites pass only three
        # arguments and expect the controls of the current batch to be picked
        # up from the surrounding namespace.
        controls = globals().get('controls')
        if controls is None:
            raise NameError("monoforce_forward: 'controls' was not passed and "
                            "no module-level 'controls' is defined")

    # terrain encoder forward pass
    out = lss(*inputs)
    height_terrain_pred, friction_pred = out['terrain'], out['friction']

    # predict states with differentiable physics
    # squeeze(1) drops dimension 1 when it has size one
    # (presumably the channel axis of the predicted maps — TODO confirm)
    states_pred, _ = dphysics(z_grid=height_terrain_pred.squeeze(1),
                              controls=controls,
                              friction=friction_pred.squeeze(1))

    return height_terrain_pred, friction_pred, states_pred

In [None]:
# Fine-tune only the friction head of the terrain encoder
# https://discuss.pytorch.org/t/how-to-train-a-part-of-a-network/8923/2
friction_head = lss.bevencode.up_friction

# put the whole encoder in eval mode and freeze all of its parameters ...
lss.eval()
lss.requires_grad_(False)
# ... then re-enable gradients and train mode for the friction head alone
friction_head.requires_grad_(True)
friction_head.train()

# the optimizer only ever sees the friction-head parameters
optimizer = torch.optim.Adam(friction_head.parameters(), lr=lr)

In [None]:
%matplotlib inline
from monoforce.models.terrain_encoder.utils import denormalize_img

# train loop
loss_history = []
for epoch in tqdm(range(num_epochs)):
    # epoch loop
    loss_epoch = 0
    for batch in train_dl:
        batch = [b.to(device) for b in batch]
        
        # unpack batch
        (imgs, rots, trans, intrins, post_rots, post_trans,
         hm_geom, hm_terrain,
         control_ts, controls,
         pose0,
         traj_ts, Xs, Xds, Rs, Omegas) = batch
        # monoforce inputs
        inputs = [imgs, rots, trans, intrins, post_rots, post_trans]
        
        # forward pass
        height_terrain_pred, friction_pred, states_pred = monoforce_forward(inputs, lss, dphysics)
        
        # unpack states
        Xs_pred, Xds_pred, Rs_pred, Omegas_pred = states_pred

        # find the closest timesteps in the trajectory to the ground truth timesteps
        ts_ids = torch.argmin(torch.abs(control_ts.unsqueeze(1) - traj_ts.unsqueeze(2)), dim=2)

        # compute the loss as the mean squared error between the predicted and ground truth poses
        batch_size = Xs.shape[0]
        loss = torch.nn.functional.mse_loss(Xs_pred[torch.arange(batch_size).unsqueeze(1), ts_ids], Xs)
        
        # accumulate loss
        loss_epoch += loss.item()

        # backpropagate
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    # print epoch loss
    loss_epoch /= len(train_dl)
    loss_history.append(loss_epoch)
    if epoch % 10 == 0:
        print('Train epoch:', epoch, 'Mean loss:', loss_epoch)

In [None]:
# Mean training loss per epoch
fig, ax = plt.subplots()
ax.plot(loss_history)
ax.set_xlabel('Epoch')
ax.set_ylabel('Mean loss')
ax.grid()
plt.show()

In [None]:
# evaluate model and visualize predictions
# The friction head is left in train mode by the fine-tuning setup. Switch the
# whole encoder to eval mode for the evaluation pass and restore the head's
# previous mode afterwards, so training can be resumed without re-running the
# setup cell.
friction_head_was_training = lss.bevencode.up_friction.training
lss.eval()
try:
    with torch.no_grad():
        # first batch of the validation set, moved to the compute device
        batch = next(iter(val_dl))
        batch = [b.to(device) for b in batch]

        # unpack batch; `controls` must keep this name because
        # monoforce_forward picks it up from the notebook namespace
        (imgs, rots, trans, intrins, post_rots, post_trans,
         hm_geom, hm_terrain,
         control_ts, controls,
         pose0,
         traj_ts, Xs, Xds, Rs, Omegas) = batch
        # monoforce inputs
        inputs = [imgs, rots, trans, intrins, post_rots, post_trans]

        # forward pass
        height_terrain_pred, friction_pred, states_pred = monoforce_forward(inputs, lss, dphysics)

        # unpack states
        Xs_pred, Xds_pred, Rs_pred, Omegas_pred = states_pred

        # for every ground-truth timestamp, index of the nearest control timestamp
        ts_ids = torch.argmin(torch.abs(control_ts.unsqueeze(1) - traj_ts.unsqueeze(2)), dim=2)

        # copy the first sample to host memory once instead of once per plotted point
        xs_gt = Xs[0].cpu().numpy()
        xs_pred = Xs_pred[0].cpu().numpy()
        matched_ids = ts_ids[0].cpu().numpy()

        # visualize
        plt.figure(figsize=(20, 10))
        plt.subplot(2, 3, 1)
        plt.imshow(denormalize_img(imgs[0, 0]))
        plt.title('Input Image')
        plt.axis('off')

        plt.subplot(2, 3, 2)
        plt.imshow(height_terrain_pred[0, 0].cpu().numpy().T, origin='lower', vmin=-1, vmax=1, cmap='jet')
        plt.colorbar()
        plt.title('Predicted Heightmap')

        plt.subplot(2, 3, 3)
        plt.imshow(hm_terrain[0, 0].cpu().numpy().T, origin='lower', vmin=-1, vmax=1, cmap='jet')
        plt.colorbar()
        plt.title('Ground Truth Heightmap')

        plt.subplot(2, 3, 4)
        plt.imshow(friction_pred[0, 0].cpu().numpy().T, origin='lower', vmin=0, vmax=1, cmap='jet')
        plt.colorbar()
        plt.title('Predicted Friction')

        plt.subplot(2, 3, 5)
        plt.plot(xs_gt[:, 0], xs_gt[:, 1], 'xr', label='GT poses')
        plt.plot(xs_pred[:, 0], xs_pred[:, 1], '.b', label='Pred poses')
        plt.grid()
        plt.axis('equal')
        plt.legend()
        plt.title('Trajectories XY')
        # connect every ground-truth point with the predicted point matched to it by ts_ids
        for j in range(Xs.shape[1]):
            k = matched_ids[j]
            plt.plot([xs_gt[j, 0], xs_pred[k, 0]],
                     [xs_gt[j, 1], xs_pred[k, 1]], 'g')

        plt.subplot(2, 3, 6)
        plt.plot(traj_ts[0].cpu().numpy(), xs_gt[:, 2], 'xr', label='GT poses')
        plt.plot(control_ts[0].cpu().numpy(), xs_pred[:, 2], '.b', label='Pred poses')
        plt.grid()
        plt.ylim(-1, 1)
        plt.legend()
        plt.title('Trajectories Z')

        plt.show()
finally:
    # put the friction head back into the mode it was in before evaluation
    lss.bevencode.up_friction.train(friction_head_was_training)