In [1]:
import math
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from model import HDDriveDQN, HDMapSensorDQN, hd_net_args
from srunner.tools import dotdict

In [2]:
with open('state.pickle', 'rb') as fp:
    state = pickle.load(fp)
state.keys()

dict_keys(['frame', 'accelerometer', 'gyroscope', 'compass', 'gnss', 'velocity', 'hd_map', 'front_rgb'])

In [3]:
h_size = 128
hd_sensor_model = HDMapSensorDQN((3,96,96), 7, p=0.3)
npImg_to_tensor = lambda x: torch.tensor(np.expand_dims(x.transpose(2,0,1), axis=0), dtype=torch.float32)

In [4]:
acc_X = torch.tensor(state['accelerometer'], dtype=torch.float32).unsqueeze(0)
comp_X = torch.tensor(state['compass'], dtype=torch.float32).unsqueeze(0)
gyro_X = torch.tensor(state['gyroscope'], dtype=torch.float32).unsqueeze(0)
bev_X = npImg_to_tensor(state['hd_map'])
front_X = npImg_to_tensor(state['front_rgb'])
vel_X = torch.tensor(state['velocity']).unsqueeze(0)
bev_X.shape, front_X.shape, acc_X.shape, gyro_X.shape, comp_X.shape, vel_X.shape

(torch.Size([1, 3, 96, 96]),
 torch.Size([1, 3, 96, 96]),
 torch.Size([1, 3]),
 torch.Size([1, 3]),
 torch.Size([1, 1]),
 torch.Size([1]))

In [5]:
bev_X_h, front_X_h, ego_X_h, vel_X_h = hd_sensor_model(bev_X.repeat(32,1,1,1), front_X.repeat(32,1,1,1), acc_X.repeat(32,1), comp_X.repeat(32,1), gyro_X.repeat(32,1), vel_X.repeat(32))
bev_X_h.shape, front_X_h.shape, ego_X_h.shape, vel_X_h.shape

(torch.Size([1, 32, 128]),
 torch.Size([1, 32, 128]),
 torch.Size([1, 32, 128]),
 torch.Size([1, 32, 128]))

In [6]:
args = hd_net_args
hd_drive_net = HDDriveDQN(args)

In [7]:
bev_X.repeat(4,32,1,1,1).shape

torch.Size([4, 32, 3, 96, 96])

In [8]:
steering_q_vals, throttle_q_vals, brake_q_vals, t_or_b_vals = hd_drive_net(bev_X.repeat(args.n_frames,32,1,1,1), 
             front_X.repeat(args.n_frames, 32,1,1,1), 
             acc_X.repeat(args.n_frames, 32,1), 
             comp_X.repeat(args.n_frames, 32,1), 
             gyro_X.repeat(args.n_frames, 32,1), 
             vel_X.repeat(args.n_frames, 32))

torch.Size([4, 32, 128]) torch.Size([4, 32, 128])
torch.Size([4, 32, 128]) torch.Size([4, 32, 128])
torch.Size([4, 32, 128]) torch.Size([4, 32, 128])
torch.Size([4, 32, 128]) torch.Size([4, 32, 128])


In [9]:
steering_q_vals.shape, throttle_q_vals.shape, brake_q_vals.shape

(torch.Size([32, 10]), torch.Size([32, 10]), torch.Size([32, 10]))

In [10]:
sensor_params = sum(p.numel() for p in hd_drive_net.sensor_net.parameters())
fusion_params = sum(p.numel() for p in hd_drive_net.fusion_net.parameters())
temporal_params = sum(p.numel() for p in hd_drive_net.temporal_net.parameters())
throttle_params = sum(p.numel() for p in hd_drive_net.throttle_net.parameters())
steering_params = sum(p.numel() for p in hd_drive_net.steering_net.parameters())
brake_params = sum(p.numel() for p in hd_drive_net.brake_net.parameters())
t_or_b_params = sum(p.numel() for p in hd_drive_net.t_or_b_net.parameters())
sensor_params, fusion_params, temporal_params, throttle_params, steering_params, brake_params, t_or_b_params

(750784, 661632, 661632, 1290, 1290, 1290, 258)

In [11]:
sum(p.numel() for p in hd_drive_net.parameters())

2078176