In [None]:
import sys
sys.path.append('../src/')
import torch
from torch.utils.data import DataLoader
from monoforce.models.terrain_encoder.bevfusion import BEVFusion
from monoforce.utils import read_yaml, position
from monoforce.datasets.rough import ROUGH, rough_seq_paths
from monoforce.models.traj_predictor.dphys_config import DPhysConfig

In [None]:
class Data(ROUGH):
    """ROUGH dataset variant whose samples bundle camera inputs with a point cloud.

    The constructor forwards all arguments to ``ROUGH``; only ``get_sample``
    is overridden, so that it returns the image-related data followed by the
    cloud positions.
    """

    def __init__(self, path, lss_cfg, dphys_cfg=None, is_train=True):
        """Initialise the underlying ROUGH dataset.

        Args:
            path: sequence path passed through to ``ROUGH``.
            lss_cfg: LSS configuration passed through to ``ROUGH``.
            dphys_cfg: physics configuration; when ``None`` (the default) a
                fresh ``DPhysConfig()`` is created for this instance.
            is_train: passed through to ``ROUGH``.
        """
        # A `DPhysConfig()` default in the signature would be evaluated once at
        # definition time and shared by every instance; build it per call instead.
        if dphys_cfg is None:
            dphys_cfg = DPhysConfig()
        super().__init__(path, lss_cfg, dphys_cfg=dphys_cfg, is_train=is_train)

    def get_sample(self, i):
        """Return the camera data and the point cloud positions for sample ``i``.

        Returns:
            A 7-tuple ``(imgs, rots, trans, intrins, post_rots, post_trans,
            points)``: the six values produced by ``get_images_data(i)``
            followed by the transposed position tensor of the cloud.
        """
        imgs, rots, trans, intrins, post_rots, post_trans = self.get_images_data(i)
        # Transposed so the coordinate axis comes first
        # (presumably (3, N) - TODO confirm against `position`).
        points = torch.as_tensor(position(self.get_cloud(i))).T
        return (imgs, rots, trans, intrins, post_rots, post_trans,
                points)

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load the LSS configuration and build the dataset from the second listed sequence.
lss_config = read_yaml('../config/lss_cfg.yaml')
ds = Data(
    lss_cfg=lss_config,
    path=rough_seq_paths[1],
)

# Shuffled mini-batches of four samples.
loader = DataLoader(dataset=ds, shuffle=True, batch_size=4)
print(f'Dataset length: {len(loader.dataset)}')

In [None]:
# Build the fusion model from the grid and augmentation sections of the LSS config.
bevfusion = BEVFusion(
    data_aug_conf=lss_config['data_aug_conf'],
    grid_conf=lss_config['grid_conf'],
)
# Move the model to the selected device before reporting its size.
bevfusion.to(device)
print(f'Number of BEVFusion model parameters: {sum(p.numel() for p in bevfusion.parameters())}')

In [None]:
# Pull a single batch from the loader and split it into its seven components.
batch = next(iter(loader))
imgs, rots, trans, intrins, post_rots, post_trans, points = batch

# Camera-related tensors are passed to the model as one list; the point cloud separately.
img_inputs = [
    tensor.to(device)
    for tensor in (imgs, rots, trans, intrins, post_rots, post_trans)
]
points_input = points.to(device)

In [None]:
# Put the model in evaluation mode so that any train-only layer behaviour
# (e.g. dropout or batch-norm statistics updates, if the model has such layers)
# is disabled for this forward pass. Call `bevfusion.train()` again before training.
bevfusion.eval()

# `torch.inference_mode()` already disables gradient tracking, so an extra
# nested `torch.no_grad()` is redundant and has been dropped.
with torch.inference_mode():
    out = bevfusion(img_inputs, points_input)
    # Report the shape of every entry in the model's output dict.
    for k, v in out.items():
        print(f'{k}: {v.shape}')