In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

In [None]:
# Training parameters
batch_size = 2
num_epochs = 10
lr = 1e-4
device = torch.device('cpu')
robot = 'tradr2'

# Seed every RNG the pipeline may draw from (Python, NumPy, torch) so that DataLoader
# shuffling and the random init of weights not covered by the checkpoint are reproducible.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed);

In [None]:
from monoforce.models.dphysics import DPhysics
from monoforce.config import DPhysConfig
from monoforce.utils import read_yaml

config_dir = '../config'

# Differentiable physics: start from DPhysConfig defaults, then override them from YAML.
dphys_cfg = DPhysConfig()
dphys_config_path = f'{config_dir}/dphys_cfg.yaml'
assert os.path.isfile(dphys_config_path), f'Config file {dphys_config_path} does not exist'
dphys_cfg.from_yaml(dphys_config_path)

# LSS terrain encoder: robot-specific YAML config plus the path of its pretrained
# weights (the weights path is only built here; its existence is not checked).
lss_config_path = f'{config_dir}/lss_cfg_{robot}.yaml'
assert os.path.isfile(lss_config_path), f'LSS config file {lss_config_path} does not exist'
lss_cfg = read_yaml(lss_config_path)
pretrained_model_path = f'{config_dir}/weights/lss/lss_robingas_{robot}.pt'

In [None]:
# Load Differentiable Physics
# Build the differentiable simulator from `dphys_cfg` on `device`; the training loop
# calls it with a predicted heightmap (`z_grid`), `controls` and a `friction` map.
dphysics = DPhysics(dphys_cfg, device=device)

In [None]:
from monoforce.models.terrain_encoder.lss import compile_model

# Load LSS (Terrain Encoder)
lss = compile_model(lss_cfg['grid_conf'], lss_cfg['data_aug_conf'], inpC=3, outC=1)
if os.path.exists(pretrained_model_path):
    # Partially initialise the model from pretrained weights: only tensors whose name
    # AND shape match the freshly built model are copied; everything else keeps its
    # random initialisation instead of making the strict load_state_dict() raise.
    # https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/3
    print('Loading pretrained LSS model from', pretrained_model_path)
    model_dict = lss.state_dict()
    # map_location lets checkpoints saved on a GPU be loaded onto a CPU-only `device`.
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
    pretrained_state = torch.load(pretrained_model_path, map_location=device)
    compatible = {k: v for k, v in pretrained_state.items()
                  if k in model_dict and v.shape == model_dict[k].shape}
    skipped = sorted(set(pretrained_state) - set(compatible))
    if skipped:
        print(f'Skipped {len(skipped)} incompatible pretrained tensors:', skipped)
    model_dict.update(compatible)
    lss.load_state_dict(model_dict)
lss.to(device)
lss.train();

In [None]:
# Load dataset
from monoforce.utils import compile_data

# NOTE(review): small_data=True presumably restricts loading to a reduced subset — confirm.
dataset_kwargs = dict(dataset='robingas',
                      robot=robot,
                      lss_cfg=lss_cfg,
                      dphys_cfg=dphys_cfg,
                      small_data=True)
train_ds, val_ds = compile_data(**dataset_kwargs)

# report the size of each split
for split_name, split_ds in (('Train', train_ds), ('Validation', val_ds)):
    print(f'{split_name} dataset:', len(split_ds))

In [None]:
# Create dataloaders: training samples are reshuffled every epoch, validation order is fixed
from torch.utils.data import DataLoader


def _make_loader(ds, shuffle):
    """Wrap `ds` in a DataLoader that uses the notebook-wide `batch_size`."""
    return DataLoader(ds, batch_size=batch_size, shuffle=shuffle)


train_dl = _make_loader(train_ds, shuffle=True)
val_dl = _make_loader(val_ds, shuffle=False)

In [None]:
%matplotlib inline

# Training: Friction Head
# https://discuss.pytorch.org/t/how-to-train-a-part-of-a-network/8923/2
optimizer = torch.optim.Adam(lss.bevencode.up_friction.parameters(), lr=lr)

# train loop
for epoch in range(num_epochs):
    # epoch loop
    for batch in tqdm(train_dl, total=len(train_dl)):
        batch = [torch.as_tensor(b, dtype=torch.float32, device=device) for b in batch]
        (imgs, rots, trans, intrins, post_rots, post_trans,
         hm_geom, hm_terrain,
         control_ts, controls,
         traj_ts, Xs, Xds, Rs, Omegas) = batch
        
        optimizer.zero_grad()
        
        # terrain encoder forward pass
        inputs = [imgs, rots, trans, intrins, post_rots, post_trans]
        voxel_feats = lss.get_voxels(*inputs)
        height_pred_geom, height_pred_diff, friction_pred = lss.bevencode(voxel_feats)
        
        # predict states
        states_pred, _ = dphysics(z_grid=height_pred_geom.squeeze(1),
                                  controls=controls,
                                  friction=friction_pred.squeeze(1))
        # unpack states
        Xs_pred, Xds_pred, Rs_pred, Omegas_pred, _ = states_pred
    
        # find the closest timesteps in the trajectory to the ground truth timesteps
        ts_ids = torch.argmin(torch.abs(control_ts.unsqueeze(1) - traj_ts.unsqueeze(2)), dim=2)
    
        # compute the loss as the mean squared error between the predicted and ground truth poses
        batch_size = Xs.shape[0]
        loss = torch.nn.functional.mse_loss(Xs_pred[torch.arange(batch_size).unsqueeze(1), ts_ids], Xs)
        
        # backpropagate
        loss.backward()
        optimizer.step()
   
    # visualize predictions
    with torch.no_grad(): 
        print('Epoch:', epoch, 'Loss:', loss.item())
        
        plt.figure(figsize=(20, 5))
        plt.subplot(1, 4, 1)
        plt.imshow(height_pred_geom[0, 0].cpu().numpy().T,  origin='lower', vmin=-1, vmax=1, cmap='terrain')
        plt.colorbar()
        plt.title('Predicted Heightmap')
        
        plt.subplot(1, 4, 2)
        plt.imshow(hm_geom[0, 0].cpu().numpy().T, origin='lower', vmin=-1, vmax=1, cmap='terrain')
        plt.colorbar()
        plt.title('Ground Truth Heightmap')
        
        plt.subplot(1, 4, 3)
        plt.imshow(friction_pred[0, 0].cpu().numpy(), origin='lower', vmin=0, vmax=1, cmap='terrain')
        plt.colorbar()
        plt.title('Predicted Friction')
        
        plt.subplot(1, 4, 4)
        plt.plot(Xs[0, :, 0].cpu().numpy(), Xs[0, :, 1].cpu().numpy(), 'r', label='GT poses')
        plt.plot(Xs_pred[0, :, 0].cpu().numpy(), Xs_pred[0, :, 1].cpu().numpy(), 'b', label='Pred poses')
        plt.grid()
        plt.axis('equal')
        plt.legend()
        plt.title('Trajectory')
        
        plt.show()