In [1]:
import torch
from torch import nn

In [2]:
class LTVDynamicsModel(nn.Module):  # Linear Time-Varying Dynamics Model
    """Linear time-varying (LTV) one-step dynamics model.

    An MLP maps the first three input features to the coefficients of an
    affine map, which is then applied to the last two input features:

        [F0, F1, f] = MLP(xu[:, 0:3])
        x_{t+1}     = F0 * xu[:, 3] + F1 * xu[:, 4] + f

    Args:
        layers: hidden-layer widths of the MLP; must be non-empty.
        hidden_activation: if True, insert a ReLU after every hidden-to-hidden
            Linear layer. Defaults to False, which reproduces the original
            architecture (a single ReLU after the first layer) so existing
            checkpoints load with identical state_dict keys.

    Raises:
        ValueError: if ``layers`` is empty.
    """

    def __init__(self, layers=(64, 64), hidden_activation=False):
        super().__init__()
        widths = list(layers)
        if not widths:
            raise ValueError("layers must contain at least one hidden width")

        modules = [nn.Linear(3, widths[0]), nn.ReLU()]
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(in_features, out_features))
            # NOTE(review): with no activation here, consecutive Linear layers
            # compose into a single affine map, so extra hidden layers add
            # parameters but no expressiveness. Off by default only to keep
            # compatibility with checkpoints trained on this architecture.
            if hidden_activation:
                modules.append(nn.ReLU())
        modules.append(nn.Linear(widths[-1], 3))
        self.time_varying_F = nn.Sequential(*modules)

    def forward(self, xu):
        """Predict the next state for a batch of state/control samples.

        Args:
            xu: tensor of shape (batch, 5) with columns
                [roll_lataccel, v_ego, a_ego, lataccel, steer_command].
                The first three condition the time-varying coefficients;
                the last two form the [x, u] vector they are applied to.

        Returns:
            Tensor of shape (batch,) holding the predicted x_{t+1}.

        Raises:
            ValueError: if ``xu`` is not 2-D with exactly 5 columns.
        """
        # Fail early with a readable message; a wrong width would otherwise
        # surface as an opaque reshape/bmm shape error further down.
        if xu.dim() != 2 or xu.shape[1] != 5:
            raise ValueError(
                f"expected xu of shape (batch, 5), got {tuple(xu.shape)}"
            )
        coeffs = self.time_varying_F(xu[:, :3])  # Ft(roll_lataccel, v_ego, a_ego)
        Ft = coeffs[:, 0:2].reshape(-1, 1, 2)  # linear coefficients on [x, u]
        ft = coeffs[:, 2:].reshape(-1, 1, 1)   # affine offset
        xt = xu[:, 3:].reshape(-1, 2, 1)       # [x, u] as column vectors
        # (batch, 1, 2) @ (batch, 2, 1) -> (batch, 1, 1), flattened to (batch,)
        return (torch.bmm(Ft, xt) + ft).reshape(-1)

model = LTVDynamicsModel(layers=[128] * 4)  # four hidden layers of width 128; must match the checkpoint loaded below

In [4]:
# Load trained weights. map_location='cpu' remaps any CUDA-saved tensors so the
# checkpoint also loads on a CPU-only machine; weights_only=True restricts
# unpickling to tensors and primitive containers.
model.load_state_dict(
    torch.load('best_model_train.pth', map_location='cpu', weights_only=True)
)
model.cpu()
model.eval()

LTVDynamicsModel(
  (time_varying_F): Sequential(
    (0): Linear(in_features=3, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): Linear(in_features=128, out_features=128, bias=True)
    (4): Linear(in_features=128, out_features=128, bias=True)
    (5): Linear(in_features=128, out_features=3, bias=True)
  )
)

In [6]:
# Presumably per-feature normalization statistics (confirm against the training
# code); map_location='cpu' keeps them loadable without CUDA and on the same
# device as the CPU model.
mean, std = torch.load('mean_std.pt', map_location='cpu', weights_only=True)