# COMP 7310 Personal Project

### 1. Config Setup

In [None]:
# Make dirs for model storage and test results.
# NOTE(review): these paths are under /content/drive, but drive.mount() runs in the
# next cell — run the mount cell first, otherwise the folders are created on the
# local Colab disk under the mount point.
import os

# Model storage path
model_prediction_path = r'/content/drive/MyDrive/ColabNotebooks/out/working/prediction_model'
# Test results path, with one sub-folder each for the seen / unseen test sets
test_result_path = r'/content/drive/MyDrive/ColabNotebooks/out/working/test_results'

for _path in (model_prediction_path,
              test_result_path,
              os.path.join(test_result_path, 'test_seen'),
              os.path.join(test_result_path, 'test_unseen')):
    # exist_ok=True makes re-running the cell a no-op. The previous bare `except: pass`
    # also hid real errors (e.g. permissions) and, because test_seen and test_unseen
    # shared one try-block, skipped test_unseen whenever test_seen already existed.
    os.makedirs(_path, exist_ok=True)

In [None]:
# Mount Google Drive so the /content/drive/MyDrive/... data and output paths used
# throughout this notebook resolve.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
"""
config settings:
"""
### TFNET parameters
BATCH_SIZE = 64 # Training batch size
TEST_BATCH_SIZE = 1 # Test batch size
EPOCHS = 150 # Traning epoch
SAVE_INTERVAL = 20 # Save model every 20 epochs
STEP_SIZE = 120 # Step size for moving forward the window (For training)
TEST_STEP_SIZE = 400 # Step size for moving forward the window (For testing)
WINDOW_SIZE = 480 # Window size for training and testing
INPUT_CHANNEL = 12 # Input feature dimension (Gryo + Acce +lin + gra)
OUTPUT_CHANNEL = 2 # Output dimension (2D velocity vector)
SAMPLING_RATE = 200 # Sampling rate
LAYER_SIZE = 120 # The size of LSTM
LAYERS = 4 # The layer size of LSTM
DROPOUT = 0.2 # Dropout probability
LEARNING_RATE = 0.003 # Learning rate
NUM_WORKERS = 8
### ------------------ ###

### Data preprocessing parameters
FEATURE_SIGMA = 2.0 # Sigma for feature gaussian smoothing
TARGET_SIGMA = 30.0 # Sigma for target gaussian smoothing
### ------------------ ###

### Device for training
DEVICE = "cuda:0" # You can choose GPU or CPU
### ------------------ ###

### Training and testing setting
DATA_DIR = r'/content/drive/MyDrive/ColabNotebooks/original_data/train_dataset/' # Dataset directory for training
VAL_DATA_DIR = r'/content/drive/MyDrive/ColabNotebooks/original_data/val_dataset/' # Dataset directory for validation
TEST_DIR = r'/content/drive/MyDrive/ColabNotebooks/original_data/test_seen/' # Dataset directory for testing (unseen_subjects_test_set)
OUT_DIR = r'/content/drive/MyDrive/ColabNotebooks/out/working/prediction_model' # Output directory for both traning and testing
MODEL_PATH = '' # Model path for testing
### ------------------ ###

def load_config():
    kwargs = {}
    kwargs['batch_size'] = BATCH_SIZE
    kwargs['test_batch_size'] = TEST_BATCH_SIZE
    kwargs['epochs'] = EPOCHS
    kwargs['save_interval'] = SAVE_INTERVAL
    kwargs['step_size'] = STEP_SIZE
    kwargs['test_step_size'] = TEST_STEP_SIZE
    kwargs['window_size'] = WINDOW_SIZE
    kwargs['sampling_rate'] = SAMPLING_RATE
    kwargs['input_channel'] = INPUT_CHANNEL
    kwargs['output_channel'] = OUTPUT_CHANNEL
    kwargs['layer_size'] = LAYER_SIZE
    kwargs['layers'] = LAYERS
    kwargs['dropout'] = DROPOUT
    kwargs['learning_rate'] = LEARNING_RATE
    kwargs['num_workers'] = NUM_WORKERS

    kwargs['feature_sigma'] = FEATURE_SIGMA
    kwargs['target_sigma'] = TARGET_SIGMA

    kwargs['device'] = DEVICE

    kwargs['data_dir'] = DATA_DIR
    kwargs['val_data_dir'] = VAL_DATA_DIR
    kwargs['test_dir'] = TEST_DIR
    kwargs['out_dir'] = OUT_DIR
    kwargs['model_path'] = MODEL_PATH

    return kwargs


### 2. Model Design

In [None]:
import torch
from scipy import signal
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.models as models
from torch.nn.utils import weight_norm

### Bilinear-LSTM model (a unidirectional LSTM fed with bilinear feature interactions)
class BilinearLSTMSeqNetwork(torch.nn.Module):
    def __init__(self, input_size, out_size, batch_size, device,
                 lstm_size=100, lstm_layers=3, dropout=0):
        """
        LSTM network with a Bilinear feature-mixing layer.

        The input is augmented with 4 * input_size bilinear (pairwise) feature
        interactions, passed through a stacked LSTM, and the LSTM output is
        concatenated with the augmented input before two linear layers.

        Input: torch array [batch x frames x input_size]
        Output: torch array [batch x frames x out_size]
        :param input_size: num. channels in input
        :param out_size: num. channels in output
        :param batch_size: batch size of the hidden state created at construction
            (``self.hidden``); ``forward`` sizes its initial state from the actual
            input, so any batch size can be fed
        :param device: torch device on which the LSTM states are allocated
        :param lstm_size: number of LSTM units per layer
        :param lstm_layers: number of LSTM layers
        :param dropout: dropout probability between LSTM layers (@ref https://pytorch.org/docs/stable/nn.html#lstm)
        """
        super(BilinearLSTMSeqNetwork, self).__init__()
        self.input_size = input_size
        self.lstm_size = lstm_size
        self.output_size = out_size
        self.num_layers = lstm_layers
        self.batch_size = batch_size
        self.device = device

        self.bilinear = torch.nn.Bilinear(self.input_size, self.input_size, self.input_size * 4)
        self.lstm = torch.nn.LSTM(self.input_size * 5, self.lstm_size, self.num_layers, batch_first=True, dropout=dropout)
        self.linear1 = torch.nn.Linear(self.lstm_size + self.input_size * 5, self.output_size * 5)
        self.linear2 = torch.nn.Linear(self.output_size * 5, self.output_size)
        self.hidden = self.init_weights()

    def forward(self, input):
        # [B, T, C] -> [B, T, 4C] pairwise feature interactions, then [B, T, 5C]
        input_mix = self.bilinear(input, input)
        input_mix = torch.cat([input, input_mix], dim=2)
        # Zero initial state sized to the *actual* batch. Using the construction-time
        # batch size made the LSTM raise on any other batch size (e.g. the last,
        # smaller batch of an epoch, or evaluation with a different batch size).
        hidden0 = self.init_weights(input.size(0))
        output, self.hidden = self.lstm(input_mix, hidden0)
        output = torch.cat([input_mix, output], dim=2)
        output = self.linear1(output)
        output = self.linear2(output)
        return output

    def init_weights(self, batch_size=None):
        """
        Return zero initial LSTM states ``(h0, c0)``.

        Despite the (kept-for-compatibility) name, this creates hidden/cell states,
        not network parameters.
        :param batch_size: batch dimension of the states; defaults to ``self.batch_size``
        :return: tuple of two zero tensors [num_layers x batch_size x lstm_size] on ``self.device``
        """
        if batch_size is None:
            batch_size = self.batch_size
        shape = (self.num_layers, batch_size, self.lstm_size)
        h0 = torch.zeros(shape, device=self.device)
        c0 = torch.zeros(shape, device=self.device)
        return h0, c0
### --------------------- End of the model --------------------- ###

### 3. Data Loader

In [None]:
!pip install pyquaternion==0.9.9
!pip install numpy-quaternion==2022.4.3

Collecting pyquaternion==0.9.9
  Downloading pyquaternion-0.9.9-py3-none-any.whl (14 kB)
Installing collected packages: pyquaternion
Successfully installed pyquaternion-0.9.9
Collecting numpy-quaternion==2022.4.3
  Downloading numpy_quaternion-2022.4.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (205 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m205.9/205.9 kB[0m [31m4.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: numpy-quaternion
Successfully installed numpy-quaternion-2022.4.3


In [None]:
import json
import random
import os
from os import path as osp

import h5py
import torch
import numpy as np
import quaternion
import math
from scipy.ndimage import gaussian_filter1d
from torch.utils.data import Dataset

def _imu_to_global_frame(synced, tango_ori):
    """
    Build the raw feature matrix from the ``synced`` HDF5 group (shared by
    ``convert_data`` and ``convert_data_test``).

    Gyro and accelerometer samples are rotated from the device frame into the
    Tango frame using the game rotation vector, aligned with the first Tango
    orientation. Linear acceleration and gravity are copied as stored.

    :param synced: the open ``synced`` h5py group
    :param tango_ori: array of Tango orientation quaternions; only row 0 is used
    :return: numpy array with columns
             [t, glob_gyro(3), glob_acce(3), linacce(3), gravity(3)]
    """
    # timestamp (1D), shifted so the sequence starts at 0
    timestamps = np.array(synced['time'])
    timestamps = timestamps.reshape((len(timestamps), 1)) / 10**9  # presumably ns -> s
    timestamps = timestamps - timestamps[0]
    # gyroscope (3D), accelerometer (3D), linear acceleration (3D), gravity (3D)
    gyro = np.array(synced['gyro'])
    acce = np.array(synced['acce'])
    lina = np.array(synced["linacce"])
    grav = np.array(synced["gravity"])
    # The game rotation vector rotates acce and gyro from the body frame to the
    # navigation frame. Materialise it as an array so nothing refers to the file later.
    ori_q = quaternion.from_float_array(np.array(synced['game_rv']))

    # Compute the IMU orientation in the Tango coordinate frame.
    init_tango_ori = quaternion.quaternion(*tango_ori[0])
    init_rotor = init_tango_ori * ori_q[0].conj()
    ori_q = init_rotor * ori_q

    # Rotate vectors as pure quaternions: v' = q v q*
    gyro_q = quaternion.from_float_array(np.concatenate([np.zeros([gyro.shape[0], 1]), gyro], axis=1))
    acce_q = quaternion.from_float_array(np.concatenate([np.zeros([acce.shape[0], 1]), acce], axis=1))
    glob_gyro = quaternion.as_float_array(ori_q * gyro_q * ori_q.conj())[:, 1:]
    glob_acce = quaternion.as_float_array(ori_q * acce_q * ori_q.conj())[:, 1:]

    return np.concatenate((timestamps, glob_gyro, glob_acce, lina, grav), axis = 1)


def convert_data(data_path):
    """
    Convert one training/validation sequence stored as ``<data_path>/data.hdf5``
    into numpy arrays.

    :param data_path: sequence directory
    :return: (rawdata, groundtruth) where rawdata has columns
             [t, glob_gyro(3), glob_acce(3), linacce(3), gravity(3)] and
             groundtruth has columns [t, tango_pos(3), tango_ori(4)]
    """
    # The context manager closes the HDF5 file; it used to stay open forever.
    with h5py.File(os.path.join(data_path, 'data.hdf5'), 'r') as file:
        pose = file['pose']  # pose data (ground truth)
        tango_pos = np.array(pose['tango_pos'])
        tango_ori = np.array(pose['tango_ori'])
        rawdata = _imu_to_global_frame(file['synced'], tango_ori)

    timestamps = rawdata[:, :1]
    groundtruth = np.concatenate((timestamps, tango_pos, tango_ori), axis = 1)
    return rawdata, groundtruth


def convert_data_test(data_path):
    """
    Convert one test sequence stored as ``<data_path>/data.hdf5`` into a numpy array.

    Only the Tango orientation is read from ``pose``, because it is needed to align
    the IMU frame; positions are not loaded.
    :param data_path: sequence directory
    :return: rawdata with columns [t, glob_gyro(3), glob_acce(3), linacce(3), gravity(3)]
    """
    with h5py.File(os.path.join(data_path, 'data.hdf5'), 'r') as file:
        tango_ori = np.array(file['pose']['tango_ori'])
        return _imu_to_global_frame(file['synced'], tango_ori)

class GlobSequence():
    """
    One training/validation sequence expressed in the global (Tango-aligned) frame.

    After ``load``:
      features: [N, 12] = gyro(3) + acce(3) + linacce(3) + gravity(3)
      targets:  [N - interval, 2] horizontal (x, y) velocity from finite-differenced Tango positions
      aux (get_aux): [N, 8] = timestamp(1) + Tango orientation(4) + Tango position(3)
    """
    # Number of feature columns built by `load` (4 sensors x 3 axes). This used to
    # say 6, which contradicted the 12 columns actually produced (and INPUT_CHANNEL).
    feature_dim = 12
    target_dim = 2
    aux_dim = 8

    def __init__(self, data_path = None, **kwargs):
        """
        :param data_path: sequence directory containing data.hdf5; None creates an empty sequence
        :param kwargs: 'interval' (default 1) is the frame gap used for the velocity finite difference
        """
        super().__init__()
        self.ts, self.features, self.targets, self.gt_pos = None, None, None, None
        # Declared here too, so every data attribute exists even before load().
        self.orientations = None
        self.w = kwargs.get('interval', 1)
        if data_path is not None:
            self.load(data_path)

    def load(self, data_path):
        data, ground_truth = convert_data(data_path)
        # Already in the global frame; timestamps start from 0.
        gyro = data[:, 1:4]
        acce = data[:, 4:7]
        lina = data[:, 7:10]
        grav = data[:, 10:13]
        ts = data[:, 0]
        tango_pos = ground_truth[:, 1:4]
        tango_ori = ground_truth[:, 4:8]

        # Global velocity by finite difference over `w` frames.
        dt = (ts[self.w:] - ts[:-self.w])[:, None]
        glob_v = (tango_pos[self.w:] - tango_pos[:-self.w]) / dt

        self.ts = ts
        self.features = np.concatenate([gyro, acce, lina, grav], axis = 1)
        # We only use the global velocity in the floor plane.
        self.targets = glob_v[:, :2]
        self.orientations = tango_ori
        self.gt_pos = tango_pos

    def get_feature(self):
        return self.features

    def get_target(self):
        return self.targets

    def get_aux(self):
        """Return [timestamp, orientation(4), position(3)] per frame as one array."""
        return np.concatenate([self.ts[:, None], self.orientations, self.gt_pos], axis = 1)

class GlobSequenceTest():
    """
    One test sequence (no ground-truth positions) expressed in the global frame.

    After ``load``: features is [N, 12] = gyro(3) + acce(3) + linacce(3) + gravity(3).
    targets stays None and get_aux() returns the placeholder 0.
    """
    # Number of feature columns built by `load` (4 sensors x 3 axes); previously 6.
    feature_dim = 12
    aux_dim = 8

    def __init__(self, data_path = None, **kwargs):
        """
        :param data_path: sequence directory containing data.hdf5; None creates an empty sequence
        :param kwargs: 'interval' (default 1), kept for parity with GlobSequence
        """
        super().__init__()
        self.ts, self.features, self.targets, self.gt_pos = None, None, None, None
        self.w = kwargs.get('interval', 1)
        if data_path is not None:
            self.load(data_path)

    def load(self, data_path):
        data = convert_data_test(data_path)
        # Already in the global frame; timestamps start from 0.
        # (An unused finite-difference `dt` computation was removed.)
        gyro = data[:, 1:4]
        acce = data[:, 4:7]
        lina = data[:, 7:10]
        grav = data[:, 10:13]

        self.ts = data[:, 0]
        self.features = np.concatenate([gyro, acce, lina, grav], axis = 1)

    def get_feature(self):
        return self.features

    def get_aux(self):
        # Placeholder: test sequences carry no pose ground truth.
        return 0

def load_sequences(seq_type, root_dir, data_list, **kwargs):
    """
    Instantiate every sequence in ``data_list`` and gather its arrays.

    :param seq_type: sequence class, constructed as ``seq_type(path, **kwargs)``
    :param root_dir: directory that contains the sequence folders
    :param data_list: sequence folder names under ``root_dir``
    :return: (features, targets, aux) lists, one entry per sequence, in ``data_list`` order
    """
    features_all, targets_all, aux_all = [], [], []
    for name in data_list:
        sequence = seq_type(osp.join(root_dir, name), **kwargs)
        features_all.append(sequence.get_feature())
        targets_all.append(sequence.get_target())
        aux_all.append(sequence.get_aux())
    return features_all, targets_all, aux_all

def load_sequences_test(seq_type, root_dir, data_list, **kwargs):
    """
    Instantiate every test sequence in ``data_list`` and gather features and aux data.

    :param seq_type: sequence class, constructed as ``seq_type(path, **kwargs)``
    :param root_dir: directory that contains the sequence folders
    :param data_list: sequence folder names under ``root_dir``
    :return: (features, aux) lists, one entry per sequence, in ``data_list`` order
    """
    features_all, aux_all = [], []
    for name in data_list:
        sequence = seq_type(osp.join(root_dir, name), **kwargs)
        features_all.append(sequence.get_feature())
        aux_all.append(sequence.get_aux())
    return features_all, aux_all

class SequenceToSequenceDataset(Dataset):
    """
    Windowed sequence-to-sequence dataset for training / validation.

    Each item is (features [window_size, C] float32, targets [window_size, 2] float32,
    seq_id, frame_id), where frame_id is the exclusive end index of the window in
    sequence seq_id. Windows touching a ground-truth velocity with norm > 3.0 are skipped.
    """
    def __init__(self, seq_type, root_dir, data_list, step_size = 100, window_size = 400,
                 random_shift = 0, transform = None, **kwargs):
        super(SequenceToSequenceDataset, self).__init__()
        self.seq_type = seq_type
        self.feature_dim = seq_type.feature_dim
        self.target_dim = seq_type.target_dim
        self.aux_dim = seq_type.aux_dim
        self.window_size = window_size
        self.step_size = step_size
        self.random_shift = random_shift
        self.transform = transform
        self.projection_width = kwargs.get('projection_width', 0)
        self.data_path = [osp.join(root_dir, data) for data in data_list]
        self.index_map = []

        self.features, self.targets, aux = load_sequences(
            seq_type, root_dir, data_list, **kwargs)

        # Optionally smooth the sequences. The keys used to be looked up as
        # 'feature_sigma,' / 'target_sigma,' (stray comma), so smoothing could never be enabled.
        feat_sigma = kwargs.get('feature_sigma', -1)
        targ_sigma = kwargs.get('target_sigma', -1)
        if feat_sigma > 0:
            self.features = [gaussian_filter1d(feat, sigma=feat_sigma, axis=0) for feat in self.features]
        if targ_sigma > 0:
            self.targets = [gaussian_filter1d(targ, sigma=targ_sigma, axis=0) for targ in self.targets]

        max_norm = 3.0  # velocity norm above which ground truth is treated as an outlier
        self.ts, self.orientations, self.gt_pos, self.local_v = [], [], [], []
        for i in range(len(data_list)):
            # Drop the last frame: with interval == 1 the finite-difference targets
            # are one frame shorter than the features.
            self.features[i] = self.features[i][:-1]
            self.ts.append(aux[i][:-1, :1])
            self.orientations.append(aux[i][:-1, 1:5])
            self.gt_pos.append(aux[i][:-1, 5:8])

            # Remove windows that overlap outlier ground-truth data.
            velocity = np.linalg.norm(self.targets[i], axis=1)
            bad_data = velocity > max_norm
            for j in range(window_size + random_shift, self.targets[i].shape[0], step_size):
                if not bad_data[j - window_size - random_shift:j + random_shift].any():
                    self.index_map.append([i, j])

        if kwargs.get('shuffle', True):
            random.shuffle(self.index_map)

    def __getitem__(self, item):
        # output format: input, target, seq_id, frame_id
        seq_id, frame_id = self.index_map[item][0], self.index_map[item][1]

        feat = np.copy(self.features[seq_id][frame_id - self.window_size:frame_id])
        targ = np.copy(self.targets[seq_id][frame_id - self.window_size:frame_id])
        # e.g. random rotation of the sequence in the horizontal plane
        if self.transform is not None:
            feat, targ = self.transform(feat, targ)

        # Returned for every item; this `return` used to sit inside the `if`, so the
        # dataset yielded None whenever no transform was configured.
        return feat.astype(np.float32), targ.astype(np.float32), seq_id, frame_id

    def __len__(self):
        return len(self.index_map)

    def get_lstm_test_seq(self):
        # NOTE(review): np.array over the per-sequence list only yields a regular array
        # when all sequences have the same length (always true for a single sequence).
        return np.array(self.features).astype(np.float32), np.array(self.targets).astype(np.float32)

class SequenceToSequenceDatasetTest(Dataset):
    """
    Windowed dataset over test sequences that have no ground truth.

    Each item is (features [window_size, C] float32, seq_id, frame_id), where
    frame_id is the exclusive end index of the window in sequence seq_id.
    """
    def __init__(self, seq_type, root_dir, data_list, step_size = 100, window_size = 400,
                 random_shift = 0, transform = None, **kwargs):
        super(SequenceToSequenceDatasetTest, self).__init__()
        self.seq_type = seq_type
        self.feature_dim = seq_type.feature_dim
        self.aux_dim = seq_type.aux_dim
        self.window_size = window_size
        self.step_size = step_size
        self.random_shift = random_shift
        # Stored for API parity with the training dataset; not applied to test windows.
        self.transform = transform
        self.projection_width = kwargs.get('projection_width', 0)

        self.data_path = [osp.join(root_dir, data) for data in data_list]
        self.index_map = []

        self.features, _ = load_sequences_test(
            seq_type, root_dir, data_list, **kwargs)

        # Optionally smooth the features. The key used to carry a stray trailing
        # comma ('feature_sigma,'), so smoothing could never be switched on.
        feat_sigma = kwargs.get('feature_sigma', -1)
        if feat_sigma > 0:
            self.features = [gaussian_filter1d(feat, sigma=feat_sigma, axis=0) for feat in self.features]

        # Drop the last frame, matching the preprocessing of SequenceToSequenceDataset.
        self.features = [feat[:-1] for feat in self.features]

        # The index map used to be left empty, so len() was 0 and __getitem__ unreachable.
        self._build_index_map()

        if kwargs.get('shuffle', True):
            random.shuffle(self.index_map)

    def _build_index_map(self):
        """Append one [seq_id, end_frame] entry per window, using the same stride rule as training."""
        for seq_id, feat in enumerate(self.features):
            first_end = self.window_size + self.random_shift
            for end_frame in range(first_end, feat.shape[0], self.step_size):
                self.index_map.append([seq_id, end_frame])

    def __getitem__(self, item):
        # output format: input, seq_id, frame_id
        seq_id, frame_id = self.index_map[item][0], self.index_map[item][1]
        feat = np.copy(self.features[seq_id][frame_id - self.window_size:frame_id])

        return feat.astype(np.float32), seq_id, frame_id

    def __len__(self):
        return len(self.index_map)

    def get_lstm_test_seq(self):
        # NOTE(review): only a regular array when all sequences have equal length
        # (always true for a single sequence).
        return np.array(self.features).astype(np.float32)

def change_cf(ori, vectors):
    """
    Rotate 3-D vectors by unit quaternion(s) using the Euler-Rodrigues formula
    v' = v + 2s(r x v) + 2r x (r x v), where ori = [s, r].

    :param ori: quaternion(s) [s, x, y, z], shape [4] (broadcast to every vector) or [n, 4]
    :param vectors: vectors to rotate, shape [n, 3]
    :return: rotated vectors, shape [n, 3]
    """
    assert ori.shape[-1] == 4
    assert vectors.shape[-1] == 3

    quats = ori.reshape(1, -1) if ori.ndim == 1 else ori
    scalar, axis = quats[:, :1], quats[:, 1:]

    r_cross_v = np.cross(axis, vectors)
    return vectors + 2 * (scalar * r_cross_v) + 2 * np.cross(axis, r_cross_v)

class ComposeTransform:
    """Chain (feat, targ) -> (feat, targ) transforms, applying them in list order."""

    def __init__(self, transforms):
        # iterable of callables, each taking and returning a (feat, targ) pair
        self.transforms = transforms

    def __call__(self, feat, targ, **kwargs):
        """Thread the pair through every transform; extra keyword arguments are ignored."""
        result = (feat, targ)
        for transform in self.transforms:
            new_feat, new_targ = transform(*result)
            result = (new_feat, new_targ)
        return result

class RandomHoriRotateSeq:
    def __init__(self, input_format, output_format=None):
        """
        Rotate global-frame input and output vectors by one random angle about the z axis.

        :param input_format: column boundaries of the 3-D vectors in the feature array (e.g. [0, 3, 6])
        :param output_format: column boundaries of the 2-D/3-D vectors in the target array
                              (e.g. [0, 2, 5]); 2-D blocks get z = 0 appended before rotating.
                              None leaves the targets unrotated.
        """
        self.i_f = input_format
        self.o_f = output_format

    def __call__(self, feature, target):
        # np.pi rather than np.math.pi: the np.math alias was removed in NumPy 2.0.
        a = np.random.random() * 2 * np.pi
        # Quaternion [cos a, 0, 0, sin a] rotates about z by 2a; since a is uniform on
        # [0, 2*pi), the rotation angle is still uniformly distributed.
        t = np.array([np.cos(a), 0, 0, np.sin(a)])

        in_format = self.i_f
        out_format = self.o_f if self.o_f is not None else []

        for i in range(len(in_format) - 1):
            feature[:, in_format[i]: in_format[i + 1]] = \
                change_cf(t, feature[:, in_format[i]: in_format[i + 1]])

        for i in range(len(out_format) - 1):
            width = out_format[i + 1] - out_format[i]
            if width == 3:
                # Column slice (this previously sliced rows: target[a:b]).
                vector = target[:, out_format[i]: out_format[i + 1]]
                target[:, out_format[i]: out_format[i + 1]] = change_cf(t, vector)
            elif width == 2:
                vector = np.concatenate([target[:, out_format[i]: out_format[i + 1]],
                                         np.zeros([target.shape[0], 1])], axis=1)
                target[:, out_format[i]: out_format[i + 1]] = change_cf(t, vector)[:, :2]

        return feature.astype(np.float32), target.astype(np.float32)

class RandomHoriRotateSeqTensor:
    def __init__(self):
        """
        Rotate torch feature/target tensors by one random angle in the horizontal (x-y) plane.

        Expects ``feature[..., 6]`` laid out as two 3-D vectors (x, y, z, x, y, z) and
        ``target[..., 2]`` as one (x, y) vector; z components are left unchanged.
        NOTE(review): the 12-channel features built by GlobSequence do not fit the
        6x6 rotation used here.
        """

    def __call__(self, feature, target):
        # Random angle in [0, 2*pi). np.pi replaces np.math.pi (removed in NumPy 2.0).
        a = torch.rand(1) * 2 * np.pi
        c = torch.cos(a).item()
        s = torch.sin(a).item()
        # Row-vector convention: x @ R rotates each (x, y) pair.
        rotation_matrix_feat = torch.tensor([[c, s, 0, 0, 0, 0],
                                             [-s, c, 0, 0, 0, 0],
                                             [0, 0, 1, 0, 0, 0],
                                             [0, 0, 0, c, s, 0],
                                             [0, 0, 0, -s, c, 0],
                                             [0, 0, 0, 0, 0, 1]], dtype=torch.float32)

        rotation_matrix_targ = torch.tensor([[c, s],
                                             [-s, c]], dtype=torch.float32)

        feature = torch.matmul(feature, rotation_matrix_feat)
        target = torch.matmul(target, rotation_matrix_targ)

        return feature, target

def get_dataset(root_dir, data_list, mode, **kwargs):
    """
    Build the dataset for one pipeline stage.

    :param root_dir: directory containing the sequence folders
    :param data_list: sequence folder names to load
    :param mode: 'train' (shuffled, random horizontal-rotation augmentation, half-step shift),
                 'val' (shuffled, no augmentation),
                 'val_test' (ordered, test stride) or
                 'test' (ordered, test stride, sequences without ground truth)
    :param kwargs: config dict from load_config(); reads 'step_size' (train/val),
                   'test_step_size' (val_test/test) and 'window_size'
    :return: SequenceToSequenceDataset, or SequenceToSequenceDatasetTest for 'test'
    :raises ValueError: for an unknown mode (previously an UnboundLocalError)
    NOTE(review): 'feature_sigma' / 'target_sigma' are not forwarded to the datasets,
    so the config's smoothing settings are not applied here.
    """
    if mode not in ('train', 'val', 'val_test', 'test'):
        raise ValueError('Unknown dataset mode: {!r}'.format(mode))

    # Feature column boundaries rotated by the augmentation: [0:3] global gyro,
    # [3:6] global acce (the linacce / gravity columns are not rotated).
    input_format = [0, 3, 6]
    # Target column boundaries: [0:2] is the planar (x, y) velocity.
    output_format = [0, 2]

    if mode in ('train', 'val'):
        step_size = kwargs.get('step_size')
    else:
        step_size = kwargs.get('test_step_size')
    window_size = kwargs.get('window_size')
    shuffle = mode in ('train', 'val')

    random_shift, transforms = 0, []
    if mode == 'train':
        # Half-step shift as intended. It used to be computed before the step size
        # was read (0 // 2), so it was always 0.
        random_shift = step_size // 2
        transforms.append(RandomHoriRotateSeq(input_format, output_format))
    transform = ComposeTransform(transforms)

    if mode == 'test':
        return SequenceToSequenceDatasetTest(GlobSequenceTest, root_dir, data_list, step_size,
                                             window_size, random_shift = random_shift,
                                             transform = transform, shuffle = shuffle)
    return SequenceToSequenceDataset(GlobSequence, root_dir, data_list, step_size,
                                     window_size, random_shift = random_shift,
                                     transform = transform, shuffle = shuffle)

def read_dir(dir_path):
    """
    List the immediate sub-directories of ``dir_path`` (each is loaded as one sequence folder).

    :param dir_path: directory to inspect
    :return: sorted list of sub-directory names; empty if there are none or
             ``dir_path`` does not exist (previously None, which crashed callers later)
    """
    # os.walk's first tuple describes dir_path itself; it yields nothing for a missing path.
    _, subdirs, _ = next(os.walk(dir_path), (dir_path, [], []))
    # Sorted so sequence order (and thus seq_id) is reproducible across filesystems.
    return sorted(subdirs)

def get_train_dataset(root_dir, **kwargs):
    """Training dataset ('train' mode) over every sequence folder found in ``root_dir``."""
    return get_dataset(root_dir, read_dir(root_dir), mode='train', **kwargs)

def get_valid_dataset(root_dir, **kwargs):
    """Shuffled validation dataset ('val' mode) over every sequence folder in ``root_dir``."""
    sequences = read_dir(root_dir)
    return get_dataset(root_dir, data_list=sequences, mode='val', **kwargs)

def get_valid_test_dataset(root_dir, dir, **kwargs):
    """
    Ordered validation dataset ('val_test' mode).

    :param dir: list of sequence folder names under ``root_dir`` (passed on as data_list)
    """
    return get_dataset(root_dir, data_list=dir, mode='val_test', **kwargs)

def get_test_dataset(root_dir, dir, **kwargs):
    """
    Ordered test dataset ('test' mode, sequences without ground truth).

    :param dir: list of sequence folder names under ``root_dir`` (passed on as data_list)
    """
    return get_dataset(root_dir, data_list=dir, mode='test', **kwargs)

### 4. Loss Function and Evaluation Metrics

In [None]:
import json
import os
import sys
import time
import random
import argparse
from os import path as osp
from pathlib import Path

import numpy as np
import torch

class GlobalPosLoss(torch.nn.Module):
    def __init__(self):
        """
        Position loss in the global coordinate frame.

        Predicted and target global velocities are integrated with a cumulative sum
        over dim 1 (frames), and the resulting tracks are compared with MSE.
        Target :- Global Velocity
        Prediction :- Global Velocity
        """
        super(GlobalPosLoss, self).__init__()
        self.mse_loss = torch.nn.MSELoss(reduction = 'none')

    def forward(self, pred, targ):
        # Unit time step (dt = 1): integrated "positions" are velocity x frames, which
        # only rescales the loss. The first frame is skipped before integrating.
        pred_track = pred[:, 1:].cumsum(dim=1)
        targ_track = targ[:, 1:].cumsum(dim=1)
        per_element = self.mse_loss(pred_track, targ_track)
        return per_element.mean()

class MSEAverage():
    """Accumulate prediction / target batches together with their per-channel MSE."""

    def __init__(self):
        self.count = 0
        self.targets = []
        self.predictions = []
        self.average = []  # one per-channel MSE vector per add() call

    def add(self, pred, targ):
        """Store a batch and its squared error averaged over axes 0 and 1."""
        self.targets.append(targ)
        self.predictions.append(pred)
        squared_error = (pred - targ) ** 2
        self.average.append(np.average(squared_error, axis=(0, 1)))
        self.count += 1

    def get_channel_avg(self):
        """Per-channel mean of the stored batch MSEs (each batch weighted equally)."""
        return np.average(np.array(self.average), axis=0)

    def get_total_avg(self):
        """Scalar mean over channels of the per-channel average."""
        return np.average(self.get_channel_avg())

    def get_elements(self, axis):
        """Concatenate all stored predictions and all stored targets along ``axis``."""
        return np.concatenate(self.predictions, axis=axis), np.concatenate(self.targets, axis=axis)

def reconstruct_traj(vector, **kwargs):
    """
    Integrate a velocity sequence into positions by cumulative summation.

    :param vector: velocities, one row per sample
    :param kwargs: must contain 'sampling_rate' (Hz); each sample spans 1 / sampling_rate
    :return: positions with the same shape as ``vector``; row k is the position after
             k + 1 samples (the origin itself is not included)
    :raises ValueError: if 'sampling_rate' is missing or not positive
                        (previously an opaque TypeError from 1 / None)
    """
    sampling_rate = kwargs.get('sampling_rate', None)
    if sampling_rate is None or sampling_rate <= 0:
        raise ValueError("reconstruct_traj needs a positive 'sampling_rate', got {!r}".format(sampling_rate))

    velocity_sequence = np.asarray(vector) / sampling_rate
    return np.cumsum(velocity_sequence, axis = 0)

def compute_absolute_trajectory_error(pred, gt):
    """
    The Absolute Trajectory Error (ATE) defined in:
    A Benchmark for the evaluation of RGB-D SLAM Systems
    http://ais.informatik.uni-freiburg.de/publications/papers/sturm12iros.pdf

    No alignment step is applied; the trajectories are compared as given.

    Args:
        pred: estimated trajectory
        gt: ground truth trajectory. It must have the same shape as pred.

    Return:
        Root-mean-square of all elements of (pred - gt), i.e. averaged over
        both frames and coordinates.
    """
    diff = pred - gt
    return np.sqrt(np.mean(diff * diff))


def compute_relative_trajectory_error(est, gt, delta, max_delta = -1):
    """
    The Relative Trajectory Error (RTE) defined in:
    A Benchmark for the evaluation of RGB-D SLAM Systems
    http://ais.informatik.uni-freiburg.de/publications/papers/sturm12iros.pdf

    Args:
        est: the estimated trajectory
        gt: the ground truth trajectory (same shape as est).
        delta: fixed window size in frames, clamped to the longest window that fits.
               If <= 0 (e.g. -1), the average RTE over every window size
               1 .. min(len, max_delta) - 1 is returned.
        max_delta: maximum delta. If -1 is provided, it will be set to the length of trajectories.

    Returns:
        Relative trajectory error: mean over the evaluated window sizes of the RMSE
        of endpoint drift. nan if the trajectory is too short for any window.
    """
    num_frames = est.shape[0]
    if max_delta == -1:
        max_delta = num_frames

    if delta > 0:
        deltas = np.array([min(delta, max_delta - 1, num_frames - 1)])
    else:
        # Documented "average over all windows" mode (was previously unimplemented).
        deltas = np.arange(1, min(num_frames, max_delta))
    # A window of 0 frames is meaningless (and gt[:-0] would be empty).
    deltas = deltas[deltas >= 1]
    if deltas.size == 0:
        return np.nan

    rtes = np.zeros(deltas.shape[0])
    for i, d in enumerate(deltas):
        # RMSE of endpoint drift over all windows of length d slid along the trajectory.
        err = est[d:] + gt[:-d] - est[:-d] - gt[d:]
        rtes[i] = np.sqrt(np.mean(err ** 2))

    # The average of RTE of all window sizes is returned.
    rtes = rtes[~np.isnan(rtes)]
    return np.mean(rtes)

def compute_position_drift_error(pos_pred, pos_gt):
    """
    Final-position drift divided by the length of the ground-truth path.

    Params:
        pos_pred: predicted position [seq_len, 2]
        pos_gt: ground truth position [seq_len, 2]
    """
    final_drift = np.linalg.norm(pos_gt[-1] - pos_pred[-1])
    step_lengths = np.linalg.norm(np.diff(pos_gt, axis=0), axis=1)
    return final_drift / np.sum(step_lengths)

def compute_distance_error(pos_pred, pos_gt):
    """
    Per-frame Euclidean distance between predicted and ground-truth positions.

    Params:
        pos_pred: predicted position [seq_len, 2]
        pos_gt: ground truth position [seq_len, 2]
    Returns:
        array [seq_len] of distances
    """
    residual = pos_pred - pos_gt
    return np.linalg.norm(residual, axis=1)

def compute_heading_error(preds, targets):
    """
    Mean absolute heading (direction) error between predicted and target 2-D vectors.

    Params:
        preds: predicted vectors, e.g. velocities [seq_len, 2]
        targets: ground-truth vectors [seq_len, 2]
    Rows where either vector has zero norm have no heading and are skipped.
    Returns:
        mean absolute heading error in degrees, in [0, 180]; nan if no row remains
    """
    preds = np.asarray(preds)
    targets = np.asarray(targets)

    # Keep only rows where both vectors have a defined heading.
    valid = (np.linalg.norm(preds, axis=1) != 0) & (np.linalg.norm(targets, axis=1) != 0)
    preds, targets = preds[valid], targets[valid]
    if preds.shape[0] == 0:
        return np.nan

    pred_heading = np.arctan2(preds[:, 1], preds[:, 0])
    targ_heading = np.arctan2(targets[:, 1], targets[:, 0])

    # Wrap the difference into [-pi, pi): headings of 179 and -179 degrees are 2 degrees
    # apart, not 358 (the unwrapped difference used to be averaged directly).
    diff = (pred_heading - targ_heading + np.pi) % (2 * np.pi) - np.pi
    return np.degrees(np.mean(np.abs(diff)))

### 5. Utils

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import os.path as osp

def format_string(*argv, sep=' '):
    """
    Join values into one string, flattening tuples/lists/ndarrays recursively.

    Each top-level scalar becomes str(value). Each element of a tuple, list or
    ndarray is formatted recursively with the same separator. All resulting
    pieces are joined with `sep`.

    Uses sep.join instead of appending `sep` and slicing off one trailing
    character. The slicing approach only worked when `sep` was exactly one
    character: it left a partial separator for longer ones and truncated the
    last value when sep == ''.
    """
    parts = []
    for val in argv:
        if isinstance(val, (tuple, list, np.ndarray)):
            parts.extend(format_string(v, sep=sep) for v in val)
        else:
            parts.append(str(val))
    return sep.join(parts)

def draw_trajectory(pos_pred, pos_gt, dir_name, ate, rte, **kwargs):
    """
    Plot predicted vs ground-truth 2D trajectories and save them as <out_dir>/<dir_name>.png.

    :param pos_pred: predicted positions, (N, 2)
    :param pos_gt: ground-truth positions, (N, 2)
    :param dir_name: test directory name; used as the plot title and the file name
    :param ate: average trajectory error, printed in the bottom-right corner
    :param rte: relative trajectory error, printed in the bottom-right corner
    :param out_dir: (keyword, required) directory the PNG is written to
    :raises ValueError: if out_dir is not given
    The figure is closed after saving so repeated calls do not accumulate figures.
    """
    global_out_dir = kwargs.get('out_dir', None)
    if global_out_dir is None:
        raise ValueError('out_dir is needed')

    fig, ax = plt.subplots(figsize=(8, 5), dpi=400)
    try:
        ax.plot(pos_pred[:, 0], pos_pred[:, 1], label='Predicted')
        ax.plot(pos_gt[:, 0], pos_gt[:, 1], label='Ground truth')
        ax.set_title(dir_name)
        # Axis labels rendered via mathtext; presumably metres
        ax.set_xlabel('$m$')
        ax.set_ylabel('$m$')
        ax.axis('equal')
        ax.legend()
        # Separate right-aligned title placed at the bottom of the axes (y=0)
        ax.set_title('ATE:{:.3f}, RTE:{:.3f}'.format(ate, rte), y=0, loc='right')
        fig.savefig(osp.join(global_out_dir, '{}.png'.format(dir_name)))
    finally:
        # Release the figure; each unclosed 400-dpi figure otherwise stays in memory
        plt.close(fig)


### 6. Main Function

In [None]:
import os
import time
import argparse
from os import path as osp
from pathlib import Path
from tqdm import tqdm
import numpy as np
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

def get_model(mode, **kwargs):
    """
    Build a BilinearLSTMSeqNetwork from the config and move it to the configured device.

    Params:
        mode: 'train' builds the network with kwargs['batch_size'];
              'test' builds it with kwargs['test_batch_size'].
        **kwargs: config dict; reads input_channel, output_channel, dropout,
                  batch_size / test_batch_size, device, layers, layer_size.
    Returns:
        The constructed network, already moved to kwargs['device'].
    Raises:
        ValueError: if mode is neither 'train' nor 'test'. Previously an unknown
        mode fell through and crashed later with UnboundLocalError.
    """
    batch_size_key = {'train': 'batch_size', 'test': 'test_batch_size'}.get(mode)
    if batch_size_key is None:
        raise ValueError("mode must be 'train' or 'test', got {!r}".format(mode))

    global_device = kwargs.get('device')
    print("LSTM model")
    # The train and test networks differ only in the batch size they are built for
    network = BilinearLSTMSeqNetwork(kwargs.get('input_channel'), kwargs.get('output_channel'),
                                     kwargs.get(batch_size_key), global_device,
                                     lstm_size=kwargs.get('layer_size'),
                                     lstm_layers=kwargs.get('layers'),
                                     dropout=kwargs.get('dropout')).to(global_device)
    pytorch_total_params = sum(p.numel() for p in network.parameters() if p.requires_grad)
    print('Network constructed. trainable parameters: {}'.format(pytorch_total_params))
    return network

def _save_checkpoint(model_path, network, optimizer, epoch, loss=None):
    """Save model/optimizer state and the epoch (plus the training loss when given) to model_path."""
    state = {'model_state_dict': network.state_dict(), 'epoch': epoch}
    if loss is not None:
        state['loss'] = loss
    state['optimizer_state_dict'] = optimizer.state_dict()
    torch.save(state, model_path)


def _run_validation(network, val_loader, criterion, torch_device):
    """Run one gradient-free pass over val_loader; return (mean batch loss, MSEAverage of outputs)."""
    network.eval()
    val_vel = MSEAverage()
    val_loss = 0
    with torch.no_grad():
        for feat, targ, _, _ in val_loader:
            feat, targ = feat.to(torch_device), targ.to(torch_device)
            pred = network(feat)
            val_vel.add(pred.cpu().detach().numpy(), targ.cpu().detach().numpy())
            val_loss += criterion(pred, targ).cpu().detach().numpy()
    return val_loss / len(val_loader), val_vel


def _evaluate_val_trajectories(testnetwork, test_dirs, pred_per_min, torch_device, cfg):
    """
    Reconstruct every validation sequence with testnetwork and compute per-sequence metrics.

    Returns numpy arrays (ate, rte, pde, aye). RTE values < 0 are dropped; they
    are presumably a "sequence too short" sentinel from
    compute_relative_trajectory_error (TODO confirm).
    `cfg` is the full config dict and is forwarded to the dataset/reconstruction helpers.
    """
    # Must be in eval mode: otherwise dropout stays active and the metrics become random
    testnetwork.eval()
    ate_all, rte_all, pde_all, aye_all = [], [], [], []
    with torch.no_grad():
        for seq_name in test_dirs:
            seq_dataset = get_valid_test_dataset(cfg.get('val_data_dir'), [seq_name], **cfg)
            feat, targ = seq_dataset.get_lstm_test_seq()

            feat = torch.from_numpy(feat).to(torch_device)
            pred = np.squeeze(testnetwork(feat).cpu().numpy(), axis=0)
            targ = np.squeeze(targ, axis=0)

            # Reconstruct the trajectory
            pos_pred = reconstruct_traj(pred, **cfg)
            pos_gt = reconstruct_traj(targ, **cfg)

            ate_all.append(compute_absolute_trajectory_error(pos_pred, pos_gt))
            rte = compute_relative_trajectory_error(pos_pred, pos_gt, delta=pred_per_min)
            if rte >= 0:
                rte_all.append(rte)
            pde_all.append(compute_position_drift_error(pos_pred, pos_gt))
            aye_all.append(compute_heading_error(pred, targ))
    return np.array(ate_all), np.array(rte_all), np.array(pde_all), np.array(aye_all)


def train(**kwargs):
    """
    Train the LSTM velocity network. Validation and trajectory evaluation run every epoch.

    Reads from kwargs: data_dir, val_data_dir, batch_size, epochs, num_workers, device,
    out_dir, learning_rate, save_interval, sampling_rate, and optionally quiet
    (default False) and use_scheduler (default True). The whole dict is also
    forwarded to the dataset builders, get_model and reconstruct_traj.

    Each epoch runs four steps:
      1. One optimisation pass.
      2. A validation pass. A new best validation loss writes
         <out_dir>/checkpoints/checkpoint_<epoch>.pt.
      3. A sequence-level evaluation on the validation directories that prints
         mean ATE/RTE/PDE/AYE.
      4. An interval checkpoint, icheckpoint_<epoch>.pt, every save_interval
         epochs when no best checkpoint was written that epoch.

    Training stops early on a NaN training loss or on Ctrl-C. Afterwards,
    checkpoint_latest.pt is written if out_dir is set.
    """
    # load config
    global_data_dir = kwargs.get('data_dir')
    global_val_data_dir = kwargs.get('val_data_dir')
    global_batch_size = kwargs.get('batch_size')
    global_epochs = kwargs.get('epochs')
    global_num_workers = kwargs.get('num_workers')
    global_device = kwargs.get('device')
    global_out_dir = kwargs.get('out_dir', None)
    global_learning_rate = kwargs.get('learning_rate')
    global_save_interval = kwargs.get('save_interval')
    global_sampling_rate = kwargs.get('sampling_rate')
    quiet_mode = kwargs.get('quiet', False)
    use_scheduler = kwargs.get('use_scheduler', True)

    # Loading data
    start_t = time.time()
    train_dataset = get_train_dataset(global_data_dir, **kwargs)
    val_dataset = get_valid_dataset(global_val_data_dir, **kwargs)
    train_loader = DataLoader(train_dataset, batch_size=global_batch_size, num_workers=global_num_workers,
                              shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=global_batch_size, shuffle=True, drop_last=True)
    end_t = time.time()
    print('Training and validation set loaded. Time usage: {:.3f}s'.format(end_t - start_t))
    # Validation sequence directories, used for trajectory-level evaluation
    test_dirs = read_dir(global_val_data_dir)

    # Keep publishing the module-level `device` as before; other notebook code may read it
    global device
    device = torch.device(global_device if torch.cuda.is_available() else 'cpu')
    print("Device: {}".format(device))

    if global_out_dir:
        os.makedirs(osp.join(global_out_dir, 'checkpoints'), exist_ok=True)

    print('\nNumber of train samples: {}'.format(len(train_dataset)))
    train_mini_batches = len(train_loader)
    if val_dataset:
        print('Number of val samples: {}'.format(len(val_dataset)))

    network = get_model('train', **kwargs).to(device)
    # Built with the test batch size; weights are copied from `network` every epoch
    testnetwork = get_model('test', **kwargs).to(device)
    criterion = GlobalPosLoss()

    optimizer = torch.optim.Adam(network.parameters(), global_learning_rate)
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.75, verbose=True, eps=1e-12)

    start_epoch = 0
    epoch = start_epoch - 1  # stays -1 if no epoch starts (e.g. epochs == 0)
    best_val_loss = np.inf
    train_errs = np.zeros(global_epochs)
    # RTE is computed over windows of one minute of samples
    pred_per_min = global_sampling_rate * 60

    print("Starting from epoch {}".format(start_epoch))
    try:
        for epoch in range(start_epoch, global_epochs):
            network.train()
            train_vel = MSEAverage()
            train_loss = 0
            start_t = time.time()

            for feat, targ, _, _ in tqdm(train_loader):
                feat, targ = feat.to(device), targ.to(device)
                optimizer.zero_grad()
                predicted = network(feat)
                train_vel.add(predicted.cpu().detach().numpy(), targ.cpu().detach().numpy())
                loss = criterion(predicted, targ)
                train_loss += loss.cpu().detach().numpy()
                loss.backward()
                optimizer.step()

            train_errs[epoch] = train_loss / train_mini_batches
            end_t = time.time()
            if not quiet_mode:
                print('-' * 25)
                print('Epoch {}, time usage: {:.3f}s, loss: {}, vec_loss {}/{:.6f}'.format(
                    epoch, end_t - start_t, train_errs[epoch], train_vel.get_channel_avg(), train_vel.get_total_avg()))

            saved_model = False
            if val_loader:
                val_loss, val_vel = _run_validation(network, val_loader, criterion, device)

                if not quiet_mode:
                    print('Validation loss: {} vec_loss: {}/{:.6f}'.format(val_loss, val_vel.get_channel_avg(),
                                                                            val_vel.get_total_avg()))

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    saved_model = True
                    if global_out_dir:
                        model_path = osp.join(global_out_dir, 'checkpoints', 'checkpoint_%d.pt' % epoch)
                        _save_checkpoint(model_path, network, optimizer, epoch, loss=train_errs[epoch])
                        print('Best Validation Model saved to ' + model_path)
                if use_scheduler:
                    scheduler.step(val_loss)

            print("Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE")
            testnetwork.load_state_dict(network.state_dict())
            ate_all, rte_all, pde_all, aye_all = _evaluate_val_trajectories(
                testnetwork, test_dirs, pred_per_min, device, kwargs)

            measure = format_string('ATE', 'RTE', 'PDE', 'AYE', sep='\t')
            values = format_string(np.mean(ate_all), np.mean(rte_all), np.mean(pde_all), np.mean(aye_all),
                                   sep='\t')
            print(measure + '\n' + values)

            # Periodic checkpoint, skipped when a best-validation checkpoint was just written
            if global_out_dir and not saved_model and (epoch + 1) % global_save_interval == 0:
                model_path = osp.join(global_out_dir, 'checkpoints', 'icheckpoint_%d.pt' % epoch)
                _save_checkpoint(model_path, network, optimizer, epoch, loss=train_errs[epoch])
                print('Model saved to ' + model_path)

            if np.isnan(train_loss):
                print("Invalid value. Stopping training.")
                break
    except KeyboardInterrupt:
        print('-' * 60)
        print('Early terminate')

    print('Training completed')
    if global_out_dir:
        model_path = osp.join(global_out_dir, 'checkpoints', 'checkpoint_latest.pt')
        _save_checkpoint(model_path, network, optimizer, epoch)

# Test and Save the Trajectory for Both Seen and Unseen Data
# id key -> sequence directory name.
# id1-id12 are the tracermini_hw101_* (seen) recordings; id13-id32 are the tracermini_unseen_* ones.
seen_unseen_dataset = {
    'id1': 'tracermini_hw101_test20230311112635T',
    'id2': 'tracermini_hw101_test20230311111507T',
    'id3': 'tracermini_hw101_test20230311112235T',
    'id4': 'tracermini_hw101_test20230313011357T',
    'id5': 'tracermini_hw101_test20230311111842T',
    'id6': 'tracermini_hw101_test20230313010954T',
    'id7': 'tracermini_hw101_test20230313010546T',
    'id8': 'tracermini_hw101_test20230313011731T',
    'id9': 'tracermini_hw101_test20230313013335T',
    'id10': 'tracermini_hw101_test20230311111027T',
    'id11': 'tracermini_hw101_test20230313012956T',
    'id12': 'tracermini_hw101_test20230313010204T',
    'id13': 'tracermini_unseen_hw520230314091844T',
    'id14': 'tracermini_unseen_cym20230314101001T',
    'id15': 'tracermini_unseen_hw520230314082319T',
    'id16': 'tracermini_unseen_cym20230314101636T',
    'id17': 'tracermini_unseen_hw520230314083031T',
    'id18': 'tracermini_unseen_hw520230314091110T',
    'id19': 'tracermini_unseen_cym20230314100816T',
    'id20': 'tracermini_unseen_cym20230314101230T',
    'id21': 'tracermini_unseen_cym20230314100559T',
    'id22': 'tracermini_unseen_hw520230314081212T',
    'id23': 'tracermini_unseen_hw520230314082000T',
    'id24': 'tracermini_unseen_hw520230314091603T',
    'id25': 'tracermini_unseen_cym20230314100325T',
    'id26': 'tracermini_unseen_cym20230314101434T',
    'id27': 'tracermini_unseen_cym20230314100103T',
    'id28': 'tracermini_unseen_hw520230314091338T',
    'id29': 'tracermini_unseen_cym20230314101010T',
    'id30': 'tracermini_unseen_cym20230314101850T',
    'id31': 'tracermini_unseen_hw520230314081542T',
    'id32': 'tracermini_unseen_hw520230314090739T',
}
# One empty (distinct) slot per id; test_lstm() stores each reconstructed trajectory here
seen_unseen_pred = {seq_id: [] for seq_id in seen_unseen_dataset}
def test_lstm(**kwargs):
    # load config
    global_dataset = kwargs.get('dataset')
    global_model_type = kwargs.get('model_type')
    global_num_workers = kwargs.get('num_workers')
    global_sampling_rate = kwargs.get('sampling_rate')
    global_device = kwargs.get('device')
    global_out_dir = kwargs.get('out_dir', None)
    global_test_dir = kwargs.get('test_dir', None)
    global_out_dir = kwargs.get('out_dir', None)
    global_model_path = kwargs.get('model_path', None)

    global device
    device = torch.device(global_device if torch.cuda.is_available() else 'cpu')

    if global_test_dir is None:
        raise ValueError('Test_path is needed.')

    # read dirs
    test_dirs = read_dir(global_test_dir)

    # Make sure the test output dir exists
    if global_out_dir and not osp.exists(global_out_dir):
        os.makedirs(global_out_dir)
    # Load the model config
    if global_model_path is None:
        raise ValueError('Model path is needed.')

    checkpoint = torch.load(global_model_path, map_location=global_device)

    network = get_model('test', **kwargs)
    network.load_state_dict(checkpoint.get('model_state_dict'))
    # network.load_state_dict(checkpoint.get('model_state_dict'))
    print("The model is loaded.")
    network.eval().to(device)
    print('Model {} loaded to device {}.'.format(global_model_path, device))

    # Test for every sequence
    for i in range(len(test_dirs)):
    # for i in range(1):
        seq_dir = [test_dirs[i]]
        seq_dataset = get_test_dataset(global_test_dir, seq_dir, **kwargs)
        feat = seq_dataset.get_lstm_test_seq()

        feat = torch.from_numpy(feat).to(device)
        pred = network(feat).cpu().detach().numpy()

        pred = np.squeeze(pred, axis = 0)

        # Reconstruct the trajectory
        print("Reconstruct the {}".format(seq_dir[0]))
        pos_pred = reconstruct_traj(pred, **kwargs)

        # Find the index of the sequence
        for key, value in seen_unseen_dataset.items():
            if value == seq_dir[0]:
                index = key
                seen_unseen_pred[index] = pos_pred

        # # Make directory
        # if global_out_dir and not osp.exists(osp.join(global_out_dir, seq_dir[0])):
        #     os.makedirs(osp.join(global_out_dir, seq_dir[0]))
        # # Save the trajectory
        # np.save(osp.join(global_out_dir, seq_dir[0], 'pred.npy'), pos_pred)

def test(**kwargs):
    """
    Evaluation entry point.

    Prints a banner, then forwards the config dict unchanged to test_lstm().
    Returns None.
    """
    print("Testing LSTM model")
    test_lstm(**kwargs)
    return None

### 7. Training

In [None]:
import warnings

# Load config settings
kwargs = load_config()
# Suppress warnings only while training runs. catch_warnings() restores the
# previous filters afterwards, so later cells still surface warnings; a bare
# filterwarnings('ignore') would stay in force for the whole kernel session.
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    train(**kwargs)

Training and validation set loaded. Time usage: 2.954s
Device: cuda:0

Number of train samples: 30993
Number of val samples: 5436
LSTM model
Network constructed. trainable parameters: 444632
LSTM model
Network constructed. trainable parameters: 444632
Starting from epoch 0


100%|██████████| 484/484 [00:14<00:00, 32.98it/s]


-------------------------
Epoch 0, time usage: 14.677s, loss: 50007.518136521016, vec_loss [1.0504694 4.186765 ]/2.618617
Validation loss: 4325.428405761719 vec_loss: [0.5009264 1.2236528]/0.862290
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_0.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
10.199997	5.02634258703752	0.19289719	106.88733290141172


100%|██████████| 484/484 [00:14<00:00, 33.03it/s]


-------------------------
Epoch 1, time usage: 14.659s, loss: 4165.889948821265, vec_loss [0.47737652 0.98442745]/0.730902
Validation loss: 3111.76511492048 vec_loss: [0.44671822 0.53194916]/0.489334
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_1.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.2596254	3.8892373821952124	0.11400742	102.75126031262859


100%|██████████| 484/484 [00:14<00:00, 33.06it/s]


-------------------------
Epoch 2, time usage: 14.644s, loss: 3718.674800147695, vec_loss [0.44045594 0.57340455]/0.506930
Validation loss: 3053.2745535714284 vec_loss: [0.34750816 0.49213514]/0.419822
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_2.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
17.703836	7.36112455888228	0.315259	98.55265796292028


100%|██████████| 484/484 [00:14<00:00, 33.03it/s]


-------------------------
Epoch 3, time usage: 14.658s, loss: 3601.778373529103, vec_loss [0.45243388 0.5278861 ]/0.490160
Validation loss: 2798.352106003534 vec_loss: [0.37701976 0.4689254 ]/0.422973
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_3.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
10.246911	4.729341680353338	0.18592349	99.98263568144847


100%|██████████| 484/484 [00:14<00:00, 32.68it/s]


-------------------------
Epoch 4, time usage: 14.813s, loss: 3402.512205770193, vec_loss [0.40864328 0.49733943]/0.452991
Validation loss: 3328.907004220145 vec_loss: [0.38202313 0.46371815]/0.422871
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
18.215122	6.942073475230824	0.33114558	94.6786926416838


100%|██████████| 484/484 [00:14<00:00, 32.70it/s]


-------------------------
Epoch 5, time usage: 14.805s, loss: 3328.446477464408, vec_loss [0.40590852 0.4630153 ]/0.434462
Validation loss: 2664.5476582845054 vec_loss: [0.35810322 0.4284167 ]/0.393260
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_5.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.2574153	3.339162739840421	0.061159	102.02886721784108


100%|██████████| 484/484 [00:14<00:00, 32.78it/s]


-------------------------
Epoch 6, time usage: 14.769s, loss: 3166.2102658610697, vec_loss [0.39188373 0.44534335]/0.418614
Validation loss: 2792.5699448358446 vec_loss: [0.37598124 0.4545321 ]/0.415257
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
14.117291	6.157823649319735	0.24764414	95.42502927506995


100%|██████████| 484/484 [00:14<00:00, 32.55it/s]


-------------------------
Epoch 7, time usage: 14.872s, loss: 3094.791019408171, vec_loss [0.37084752 0.4425053 ]/0.406676
Validation loss: 2812.23525710333 vec_loss: [0.3322365  0.43572557]/0.383981
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
11.372114	4.520477793433449	0.20611839	103.44321249613796


100%|██████████| 484/484 [00:14<00:00, 32.49it/s]


-------------------------
Epoch 8, time usage: 14.902s, loss: 2890.923400122272, vec_loss [0.3692386  0.41193536]/0.390587
Validation loss: 2186.540975661505 vec_loss: [0.32236218 0.38713327]/0.354748
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_8.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.0455403	3.367032592946833	0.12782049	95.16783415538912


100%|██████████| 484/484 [00:14<00:00, 32.85it/s]


-------------------------
Epoch 9, time usage: 14.738s, loss: 2705.365660738354, vec_loss [0.35791963 0.40431055]/0.381115
Validation loss: 2380.085970924014 vec_loss: [0.31686988 0.41344076]/0.365155
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
9.544002	4.116758823394775	0.1724907	88.56585899148256


100%|██████████| 484/484 [00:14<00:00, 32.85it/s]


-------------------------
Epoch 10, time usage: 14.738s, loss: 2714.1371689473303, vec_loss [0.34389842 0.3821882 ]/0.363043
Validation loss: 2130.641584123884 vec_loss: [0.311569  0.3721021]/0.341836
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_10.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.44536	3.4321188709952613	0.10285056	91.40836515798843


100%|██████████| 484/484 [00:14<00:00, 32.47it/s]


-------------------------
Epoch 11, time usage: 14.911s, loss: 2424.7394706789128, vec_loss [0.3312936 0.3519505]/0.341622
Validation loss: 2496.5476960681735 vec_loss: [0.26812723 0.33297688]/0.300552
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
14.225525	5.640175255862149	0.26123577	86.3739755338183


100%|██████████| 484/484 [00:14<00:00, 32.67it/s]


-------------------------
Epoch 12, time usage: 14.818s, loss: 2282.4271927510413, vec_loss [0.30190745 0.33255723]/0.317232
Validation loss: 1919.3773629324776 vec_loss: [0.26623243 0.31747016]/0.291851
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_12.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.899905	3.4592760042710737	0.1436395	86.29732528817357


100%|██████████| 484/484 [00:14<00:00, 32.89it/s]


-------------------------
Epoch 13, time usage: 14.721s, loss: 2169.878234484964, vec_loss [0.2805564  0.30504915]/0.292803
Validation loss: 1731.252960931687 vec_loss: [0.23770866 0.2751673 ]/0.256438
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_13.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.2460475	2.6048381328582764	0.065472804	84.18159781882865


100%|██████████| 484/484 [00:14<00:00, 32.93it/s]


-------------------------
Epoch 14, time usage: 14.702s, loss: 1923.5876588427332, vec_loss [0.2577121  0.28777447]/0.272743
Validation loss: 1857.4058452787854 vec_loss: [0.19697574 0.24427216]/0.220624
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.3173733	2.7516028881073	0.08348629	80.36892661520388


100%|██████████| 484/484 [00:14<00:00, 32.95it/s]


-------------------------
Epoch 15, time usage: 14.692s, loss: 1950.5270644258862, vec_loss [0.24974652 0.25582653]/0.252787
Validation loss: 1967.4829980759393 vec_loss: [0.21686558 0.24441001]/0.230638
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
12.43128	4.5214786963029345	0.2245284	83.61846921232863


100%|██████████| 484/484 [00:14<00:00, 32.82it/s]


-------------------------
Epoch 16, time usage: 14.753s, loss: 1749.9599937249807, vec_loss [0.2197718  0.25108427]/0.235428
Validation loss: 2121.834729875837 vec_loss: [0.1753638  0.23254576]/0.203955
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
9.257331	3.7047910473563452	0.16101354	80.5558434348093


100%|██████████| 484/484 [00:14<00:00, 33.13it/s]


-------------------------
Epoch 17, time usage: 14.613s, loss: 1711.0261844603483, vec_loss [0.1977189  0.22395916]/0.210839
Validation loss: 1435.7273130871001 vec_loss: [0.16765893 0.2029797 ]/0.185319
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_17.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.0962915	2.202175422148271	0.06375589	74.2813745997523


100%|██████████| 484/484 [00:14<00:00, 33.44it/s]


-------------------------
Epoch 18, time usage: 14.479s, loss: 1588.4119081103113, vec_loss [0.1959983  0.21064423]/0.203321
Validation loss: 1683.0390959240142 vec_loss: [0.16454962 0.20636004]/0.185455
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.219496	2.867186264558272	0.136255	78.20055887792569


100%|██████████| 484/484 [00:14<00:00, 33.07it/s]


-------------------------
Epoch 19, time usage: 14.642s, loss: 1567.4630570214642, vec_loss [0.18457296 0.21225865]/0.198416
Validation loss: 1541.4254397437685 vec_loss: [0.16187195 0.20508528]/0.183479
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.868414	2.868905717676336	0.12763992	76.01896913626695
Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/icheckpoint_19.pt


100%|██████████| 484/484 [00:14<00:00, 33.12it/s]


-------------------------
Epoch 20, time usage: 14.618s, loss: 1470.6995126393215, vec_loss [0.18435708 0.20652632]/0.195442
Validation loss: 1464.3265497116815 vec_loss: [0.15838711 0.21768697]/0.188037
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.017776	2.2073594981973823	0.09134147	77.93455778162583


100%|██████████| 484/484 [00:14<00:00, 33.40it/s]


-------------------------
Epoch 21, time usage: 14.493s, loss: 1425.8609173356995, vec_loss [0.1827559  0.19573136]/0.189244
Validation loss: 1499.1587742396764 vec_loss: [0.15589727 0.17604563]/0.165971
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.7375903	3.168663740158081	0.14744371	69.43460778837726


100%|██████████| 484/484 [00:14<00:00, 33.05it/s]


-------------------------
Epoch 22, time usage: 14.650s, loss: 1425.1663250883748, vec_loss [0.17207876 0.18254268]/0.177311
Validation loss: 1565.7552751813616 vec_loss: [0.13363841 0.15525167]/0.144445
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.522224	3.3214738260615957	0.14224829	68.51581446678489


100%|██████████| 484/484 [00:14<00:00, 33.06it/s]


-------------------------
Epoch 23, time usage: 14.645s, loss: 1258.5143903621956, vec_loss [0.15600796 0.16566354]/0.160836
Validation loss: 1610.1364382789247 vec_loss: [0.12598187 0.16842358]/0.147203
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
11.1630335	4.060256567868319	0.20692936	72.36962079514922


100%|██████████| 484/484 [00:14<00:00, 33.20it/s]


-------------------------
Epoch 24, time usage: 14.586s, loss: 1152.6412146702285, vec_loss [0.14750345 0.16351733]/0.155510
Validation loss: 1391.0062284923736 vec_loss: [0.12594184 0.15267295]/0.139307
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_24.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.7849982	2.052733703093095	0.07026095	69.36273805422526


100%|██████████| 484/484 [00:14<00:00, 32.93it/s]


-------------------------
Epoch 25, time usage: 14.703s, loss: 1207.574598769511, vec_loss [0.14386727 0.15553886]/0.149703
Validation loss: 1467.995767139253 vec_loss: [0.11367573 0.15730931]/0.135493
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.7380934	2.9158553210171787	0.11951792	68.60346814718253


100%|██████████| 484/484 [00:14<00:00, 32.97it/s]


-------------------------
Epoch 26, time usage: 14.685s, loss: 1124.1861667475425, vec_loss [0.1384139  0.14462939]/0.141522
Validation loss: 1328.8613070533388 vec_loss: [0.11926911 0.15342876]/0.136349
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_26.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.747391	1.727595481005582	0.03726918	66.86954390595733


100%|██████████| 484/484 [00:14<00:00, 33.17it/s]


-------------------------
Epoch 27, time usage: 14.594s, loss: 1090.5798303903628, vec_loss [0.13017133 0.13055797]/0.130365
Validation loss: 1674.4327559698195 vec_loss: [0.1410348  0.14494017]/0.142987
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.3207073	3.2130624814466997	0.13407947	66.63540781537758


100%|██████████| 484/484 [00:14<00:00, 33.05it/s]


-------------------------
Epoch 28, time usage: 14.648s, loss: 1109.7798724213908, vec_loss [0.13195957 0.13262948]/0.132295
Validation loss: 1233.5104094005767 vec_loss: [0.10524491 0.11504763]/0.110146
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_28.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.3332632	1.5470689210024746	0.04145598	62.790045463606255


100%|██████████| 484/484 [00:14<00:00, 33.07it/s]


-------------------------
Epoch 29, time usage: 14.638s, loss: 1070.3141483275358, vec_loss [0.12206616 0.12395356]/0.123010
Validation loss: 1444.6978018624443 vec_loss: [0.11946786 0.15305454]/0.136261
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.577706	1.8562686118212612	0.064805515	67.0046077921629


100%|██████████| 484/484 [00:14<00:00, 33.07it/s]


-------------------------
Epoch 30, time usage: 14.642s, loss: 1223.7466156384176, vec_loss [0.12469247 0.13393274]/0.129313
Validation loss: 1342.5996148245674 vec_loss: [0.09530225 0.12589301]/0.110598
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.2548647	1.7270260778340427	0.054565456	64.81392644779255


100%|██████████| 484/484 [00:14<00:00, 33.33it/s]


-------------------------
Epoch 31, time usage: 14.524s, loss: 1069.7067500343007, vec_loss [0.11501898 0.12139634]/0.118208
Validation loss: 1219.2584649948847 vec_loss: [0.09737814 0.10717736]/0.102278
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_31.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.882226	1.749811215834184	0.049535338	62.95766389417495


100%|██████████| 484/484 [00:14<00:00, 32.99it/s]


-------------------------
Epoch 32, time usage: 14.676s, loss: 978.8273934608649, vec_loss [0.11307865 0.11206884]/0.112574
Validation loss: 1282.8181413922991 vec_loss: [0.09516993 0.12191753]/0.108544
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.762709	1.7896878719329834	0.06247097	64.83893114941769


100%|██████████| 484/484 [00:14<00:00, 33.01it/s]


-------------------------
Epoch 33, time usage: 14.670s, loss: 919.0602791526101, vec_loss [0.10392107 0.10287021]/0.103396
Validation loss: 1253.123758951823 vec_loss: [0.08730622 0.1118313 ]/0.099569
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.6548233	2.4242129759355024	0.10689595	59.33388095112059


100%|██████████| 484/484 [00:14<00:00, 32.82it/s]


-------------------------
Epoch 34, time usage: 14.753s, loss: 864.1886973735715, vec_loss [0.09691923 0.09796732]/0.097443
Validation loss: 1356.843970889137 vec_loss: [0.09039958 0.11653575]/0.103468
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.3285384	1.7162868597290732	0.055488203	63.24832228384172


100%|██████████| 484/484 [00:14<00:00, 32.97it/s]


-------------------------
Epoch 35, time usage: 14.684s, loss: 875.0576783487619, vec_loss [0.09774575 0.09647536]/0.097111
Validation loss: 1526.9017355782646 vec_loss: [0.11764383 0.12620725]/0.121926
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.3472962	1.8369903781197288	0.05959365	62.41019132310517


100%|██████████| 484/484 [00:14<00:00, 33.03it/s]


-------------------------
Epoch 36, time usage: 14.658s, loss: 930.1517826427113, vec_loss [0.10252184 0.09429499]/0.098408
Validation loss: 1287.7389631725493 vec_loss: [0.08211844 0.10100148]/0.091560
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.7314575	1.796267032623291	0.06644969	55.748534276694166


100%|██████████| 484/484 [00:14<00:00, 33.20it/s]


-------------------------
Epoch 37, time usage: 14.583s, loss: 948.1245249598479, vec_loss [0.10638902 0.09574123]/0.101065
Validation loss: 1146.0249949863978 vec_loss: [0.09965082 0.11366703]/0.106659
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_37.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.2287307	2.472135836427862	0.115358956	58.148189294741925


100%|██████████| 484/484 [00:14<00:00, 33.12it/s]


-------------------------
Epoch 38, time usage: 14.619s, loss: 835.6633967880375, vec_loss [0.09533805 0.08698195]/0.091160
Validation loss: 1116.494719369071 vec_loss: [0.07612471 0.08648939]/0.081307
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_38.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.0556827	1.9047261801632969	0.07563442	55.177403172192044


100%|██████████| 484/484 [00:14<00:00, 33.11it/s]


-------------------------
Epoch 39, time usage: 14.621s, loss: 856.6394477402869, vec_loss [0.09940103 0.09020042]/0.094801
Validation loss: 1120.1270494006928 vec_loss: [0.07727743 0.08468534]/0.080981
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.2054715	2.143124894662337	0.097801924	54.49275742831572
Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/icheckpoint_39.pt


100%|██████████| 484/484 [00:14<00:00, 33.09it/s]


-------------------------
Epoch 40, time usage: 14.632s, loss: 789.6546639056245, vec_loss [0.09327026 0.09129913]/0.092285
Validation loss: 1355.8513394310362 vec_loss: [0.08270863 0.11535034]/0.099029
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.6361268	1.7362412214279175	0.065682426	59.37383010664971


100%|██████████| 484/484 [00:14<00:00, 32.96it/s]


-------------------------
Epoch 41, time usage: 14.690s, loss: 777.0437309958718, vec_loss [0.09579613 0.0942764 ]/0.095036
Validation loss: 1251.8736397879463 vec_loss: [0.09074189 0.10550857]/0.098125
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.898394	1.838898859240792	0.07606738	54.874501668438036


100%|██████████| 484/484 [00:14<00:00, 32.66it/s]


-------------------------
Epoch 42, time usage: 14.824s, loss: 753.8204094122264, vec_loss [0.09188994 0.09262647]/0.092258
Validation loss: 1146.3309100923084 vec_loss: [0.09609292 0.11452159]/0.105307
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.134627	2.2390385541048916	0.09041072	57.623770164436635


100%|██████████| 484/484 [00:14<00:00, 32.91it/s]


-------------------------
Epoch 43, time usage: 14.711s, loss: 991.4384858312685, vec_loss [0.10571153 0.0995933 ]/0.102652
Validation loss: 1146.7824260166713 vec_loss: [0.08375131 0.11004554]/0.096898
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.5759747	1.5499261942776767	0.047047183	58.10598183030268


100%|██████████| 484/484 [00:14<00:00, 33.20it/s]


-------------------------
Epoch 44, time usage: 14.582s, loss: 764.8711236370497, vec_loss [0.09877631 0.09298496]/0.095881
Validation loss: 1136.8098126366026 vec_loss: [0.08662387 0.11604623]/0.101335
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.3837478	1.4677666208960793	0.044710115	55.77938655163826


100%|██████████| 484/484 [00:14<00:00, 33.26it/s]


-------------------------
Epoch 45, time usage: 14.554s, loss: 672.0072668406589, vec_loss [0.09114104 0.0879081 ]/0.089525
Validation loss: 1068.4422985258557 vec_loss: [0.09650024 0.11596538]/0.106233
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_45.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.730744	1.6120697476647117	0.052433822	59.97708359081842


100%|██████████| 484/484 [00:14<00:00, 33.08it/s]


-------------------------
Epoch 46, time usage: 14.634s, loss: 783.14887086222, vec_loss [0.09614489 0.09394471]/0.095045
Validation loss: 1470.855977376302 vec_loss: [0.10450085 0.1445813 ]/0.124541
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.894196	2.333814577622847	0.10511134	65.44894403715956


100%|██████████| 484/484 [00:14<00:00, 33.21it/s]


-------------------------
Epoch 47, time usage: 14.577s, loss: 773.7660352848778, vec_loss [0.10057487 0.09369539]/0.097135
Validation loss: 1081.272357395717 vec_loss: [0.08363632 0.0845608 ]/0.084099
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.615335	2.141470432281494	0.08356163	56.935665550672


100%|██████████| 484/484 [00:14<00:00, 33.34it/s]


-------------------------
Epoch 48, time usage: 14.523s, loss: 640.5970153178065, vec_loss [0.09476463 0.08955441]/0.092160
Validation loss: 1123.162081037249 vec_loss: [0.08133116 0.10080853]/0.091070
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.252395	1.6316818324002353	0.06031884	52.79116419798993


100%|██████████| 484/484 [00:14<00:00, 33.35it/s]


-------------------------
Epoch 49, time usage: 14.518s, loss: 617.6341147304566, vec_loss [0.0963918  0.08605756]/0.091225
Validation loss: 959.9718678792318 vec_loss: [0.06667899 0.0824271 ]/0.074553
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_49.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.8365573	1.324116517197002	0.02935211	55.28376970813665


100%|██████████| 484/484 [00:14<00:00, 33.37it/s]


-------------------------
Epoch 50, time usage: 14.508s, loss: 769.8420651018128, vec_loss [0.09604388 0.09458595]/0.095315
Validation loss: 1008.4652619134812 vec_loss: [0.08234868 0.09559306]/0.088971
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.154057	1.8586015809666028	0.07789269	51.610621262392876


100%|██████████| 484/484 [00:14<00:00, 33.53it/s]


-------------------------
Epoch 51, time usage: 14.440s, loss: 794.0698493768361, vec_loss [0.10011626 0.09239206]/0.096254
Validation loss: 1048.064095996675 vec_loss: [0.08075721 0.09329097]/0.087024
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.7959511	1.8445288918235085	0.06667229	55.05207665203777


100%|██████████| 484/484 [00:14<00:00, 33.47it/s]


-------------------------
Epoch 52, time usage: 14.466s, loss: 690.9889165074372, vec_loss [0.09665404 0.08666145]/0.091658
Validation loss: 878.5716076805478 vec_loss: [0.08937596 0.10103437]/0.095205
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_52.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.668552	1.218928868120367	0.028872974	56.19249905645017


100%|██████████| 484/484 [00:14<00:00, 33.48it/s]


-------------------------
Epoch 53, time usage: 14.462s, loss: 768.2139948064631, vec_loss [0.10352416 0.09409147]/0.098808
Validation loss: 950.9859746297201 vec_loss: [0.07537478 0.09289715]/0.084136
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.9680238	1.5368428772146052	0.050752833	54.53234489428198


100%|██████████| 484/484 [00:14<00:00, 33.85it/s]


-------------------------
Epoch 54, time usage: 14.305s, loss: 635.8460454389084, vec_loss [0.09327967 0.09096827]/0.092124
Validation loss: 1008.336443219866 vec_loss: [0.0779356  0.09842506]/0.088180
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.553301	1.4637381705370816	0.03994315	56.466039750204274


100%|██████████| 484/484 [00:14<00:00, 33.70it/s]


-------------------------
Epoch 55, time usage: 14.368s, loss: 582.8626321209364, vec_loss [0.08981339 0.08651939]/0.088166
Validation loss: 1057.4489313761394 vec_loss: [0.08263043 0.10445601]/0.093543
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.46754	1.3510958606546575	0.042084314	54.71719558166347


100%|██████████| 484/484 [00:14<00:00, 33.93it/s]


-------------------------
Epoch 56, time usage: 14.267s, loss: 672.3066853294688, vec_loss [0.09452987 0.09425672]/0.094393
Validation loss: 1029.013517470587 vec_loss: [0.09667262 0.11763167]/0.107152
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.1635633	2.185150048949502	0.092027195	57.99528588581372


100%|██████████| 484/484 [00:14<00:00, 33.38it/s]


-------------------------
Epoch 57, time usage: 14.504s, loss: 676.5558534732535, vec_loss [0.09721269 0.09322724]/0.095220
Validation loss: 897.5125499906994 vec_loss: [0.07564646 0.10943931]/0.092543
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.8777539	1.387297267263586	0.034722093	52.78744143442911


100%|██████████| 484/484 [00:14<00:00, 33.80it/s]


-------------------------
Epoch 58, time usage: 14.325s, loss: 677.8473399453912, vec_loss [0.09675185 0.09519877]/0.095975
Validation loss: 956.1883377801804 vec_loss: [0.07956548 0.09590721]/0.087736
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.9817272	1.3466508117589084	0.038291603	52.035632577665204


100%|██████████| 484/484 [00:14<00:00, 33.68it/s]


-------------------------
Epoch 59, time usage: 14.373s, loss: 620.1324629665406, vec_loss [0.09546962 0.09133627]/0.093403
Validation loss: 859.2911260695685 vec_loss: [0.08383409 0.09822646]/0.091030
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_59.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.8097069	1.2295178229158574	0.03542594	52.66863898678479


100%|██████████| 484/484 [00:14<00:00, 32.62it/s]


-------------------------
Epoch 60, time usage: 14.843s, loss: 581.695896621578, vec_loss [0.09686039 0.0944723 ]/0.095666
Validation loss: 897.1683033534458 vec_loss: [0.07539072 0.09293663]/0.084164
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.2292712	1.33002230795947	0.037874963	53.492379049403254


100%|██████████| 484/484 [00:15<00:00, 31.95it/s]


-------------------------
Epoch 61, time usage: 15.156s, loss: 576.123583454731, vec_loss [0.09444346 0.09143674]/0.092940
Validation loss: 1128.7357435680572 vec_loss: [0.08598834 0.09998007]/0.092984
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.5861974	2.0156525048342617	0.07965534	59.2561509146607


100%|██████████| 484/484 [00:15<00:00, 32.07it/s]


-------------------------
Epoch 62, time usage: 15.099s, loss: 656.7810293465607, vec_loss [0.09918606 0.09807722]/0.098632
Validation loss: 810.961556207566 vec_loss: [0.08391928 0.10093106]/0.092425
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_62.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.4782593	1.6762801517139783	0.06654573	55.2852698660489


100%|██████████| 484/484 [00:14<00:00, 32.42it/s]


-------------------------
Epoch 63, time usage: 14.934s, loss: 666.4221877736494, vec_loss [0.09985014 0.09710702]/0.098479
Validation loss: 1034.4306600661505 vec_loss: [0.08989436 0.10445274]/0.097174
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.8606112	1.9902881492267956	0.06989643	58.12895330879715


100%|██████████| 484/484 [00:14<00:00, 32.90it/s]


-------------------------
Epoch 64, time usage: 14.714s, loss: 523.8924338916116, vec_loss [0.09400418 0.08992458]/0.091964
Validation loss: 1027.3120894659132 vec_loss: [0.09112452 0.10836049]/0.099743
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.106493	1.4222208044745706	0.03905559	58.91065604077935


100%|██████████| 484/484 [00:14<00:00, 32.51it/s]


-------------------------
Epoch 65, time usage: 14.891s, loss: 520.4290207476655, vec_loss [0.09471379 0.09350142]/0.094108
Validation loss: 893.1733725411551 vec_loss: [0.07584421 0.08914256]/0.082493
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.6512752	1.486221963709051	0.048639074	52.780696312160586


100%|██████████| 484/484 [00:14<00:00, 32.55it/s]


-------------------------
Epoch 66, time usage: 14.876s, loss: 618.9396693016872, vec_loss [0.10081641 0.09662568]/0.098721
Validation loss: 1132.5193822951544 vec_loss: [0.07559684 0.08290564]/0.079251
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.340326	2.0001637285405938	0.08186159	58.52622629446064


100%|██████████| 484/484 [00:14<00:00, 32.58it/s]


-------------------------
Epoch 67, time usage: 14.862s, loss: 616.6757146661931, vec_loss [0.09772219 0.09286371]/0.095293
Validation loss: 921.9034819830032 vec_loss: [0.08664418 0.11006919]/0.098357
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.9809015	1.5631343668157405	0.05795003	50.455081476105214


100%|██████████| 484/484 [00:14<00:00, 32.85it/s]


-------------------------
Epoch 68, time usage: 14.740s, loss: 662.6964644755213, vec_loss [0.0982824  0.09318709]/0.095735
Validation loss: 864.4150521414621 vec_loss: [0.07932377 0.09276006]/0.086042
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.4147565	1.7133962349458174	0.06486552	51.758255801302695


100%|██████████| 484/484 [00:14<00:00, 32.98it/s]


-------------------------
Epoch 69, time usage: 14.678s, loss: 713.6717889643897, vec_loss [0.10393263 0.09842791]/0.101180
Validation loss: 871.0614278884161 vec_loss: [0.07582647 0.09753121]/0.086679
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.9691029	1.2892259467731824	0.03585374	53.274969781637736


100%|██████████| 484/484 [00:14<00:00, 33.39it/s]


-------------------------
Epoch 70, time usage: 14.500s, loss: 557.8811450043986, vec_loss [0.09410524 0.09242754]/0.093266
Validation loss: 860.5471714564732 vec_loss: [0.07611234 0.09489737]/0.085505
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.5524416	1.37649174711921	0.04586413	53.00122169850336


100%|██████████| 484/484 [00:14<00:00, 33.53it/s]


-------------------------
Epoch 71, time usage: 14.437s, loss: 519.3587662562851, vec_loss [0.09503324 0.09227614]/0.093655
Validation loss: 946.3053686959403 vec_loss: [0.07805292 0.09501123]/0.086532
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.1495013	1.8374200842597268	0.077516794	49.51990532486227


100%|██████████| 484/484 [00:14<00:00, 32.99it/s]


-------------------------
Epoch 72, time usage: 14.673s, loss: 510.92592374943507, vec_loss [0.09351459 0.09172717]/0.092621
Validation loss: 941.8491697765533 vec_loss: [0.08121398 0.10202079]/0.091617
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.6928442	1.7143431143327192	0.066500306	53.5250279745622


100%|██████████| 484/484 [00:14<00:00, 32.88it/s]


-------------------------
Epoch 73, time usage: 14.725s, loss: 763.4039794984927, vec_loss [0.10698909 0.09946015]/0.103225
Validation loss: 995.3193944295248 vec_loss: [0.08491828 0.09999541]/0.092457
Epoch 00074: reducing learning rate of group 0 to 7.5000e-04.
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.4291313	1.4793610518628901	0.043744456	59.242058992550376


100%|██████████| 484/484 [00:14<00:00, 33.24it/s]


-------------------------
Epoch 74, time usage: 14.566s, loss: 506.9864565005972, vec_loss [0.09923913 0.09560147]/0.097420
Validation loss: 852.9664513724191 vec_loss: [0.08890036 0.10662238]/0.097761
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.2926939	1.4621072194793008	0.04112863	55.016528311576145


100%|██████████| 484/484 [00:14<00:00, 32.88it/s]


-------------------------
Epoch 75, time usage: 14.727s, loss: 484.76162669284287, vec_loss [0.0967707  0.09460825]/0.095689
Validation loss: 847.962906610398 vec_loss: [0.07975356 0.0933729 ]/0.086563
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.9584941	1.2848061594096096	0.03519942	51.63794915972523


100%|██████████| 484/484 [00:14<00:00, 32.76it/s]


-------------------------
Epoch 76, time usage: 14.779s, loss: 438.65198094391627, vec_loss [0.09113164 0.09160952]/0.091371
Validation loss: 771.3825211297898 vec_loss: [0.07283755 0.09309705]/0.082967
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_76.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.1432805	1.3185345259579746	0.04151009	50.27274119370684


100%|██████████| 484/484 [00:14<00:00, 33.03it/s]


-------------------------
Epoch 77, time usage: 14.657s, loss: 455.5104406372575, vec_loss [0.09046222 0.09265283]/0.091558
Validation loss: 782.220078604562 vec_loss: [0.07398756 0.10052107]/0.087254
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.14721	1.5132400881160388	0.058634397	51.49058224004561


100%|██████████| 484/484 [00:14<00:00, 32.55it/s]


-------------------------
Epoch 78, time usage: 14.873s, loss: 421.42763109443604, vec_loss [0.09000149 0.08911182]/0.089557
Validation loss: 796.7571269444057 vec_loss: [0.07709197 0.0930046 ]/0.085048
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.0581946	1.3195221424102783	0.03960711	51.84187935962688


100%|██████████| 484/484 [00:15<00:00, 32.24it/s]


-------------------------
Epoch 79, time usage: 15.017s, loss: 420.53109829484924, vec_loss [0.09019462 0.08772039]/0.088958
Validation loss: 830.840586344401 vec_loss: [0.07701126 0.08912772]/0.083069
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.2481508	1.4181756431406194	0.0401671	50.04436823047087
Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/icheckpoint_79.pt


100%|██████████| 484/484 [00:14<00:00, 32.51it/s]


-------------------------
Epoch 80, time usage: 14.892s, loss: 483.5751690509891, vec_loss [0.09616448 0.09325686]/0.094711
Validation loss: 766.2412316458566 vec_loss: [0.08639573 0.0886949 ]/0.087545
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_80.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.0318139	1.2253305045041172	0.03855699	51.36372917261183


100%|██████████| 484/484 [00:14<00:00, 32.52it/s]


-------------------------
Epoch 81, time usage: 14.889s, loss: 416.73207193169713, vec_loss [0.0921398  0.08894175]/0.090541
Validation loss: 801.75937579927 vec_loss: [0.08999982 0.09792415]/0.093962
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.6431327	1.3528769666498357	0.05205883	49.61523800066586


100%|██████████| 484/484 [00:14<00:00, 32.83it/s]


-------------------------
Epoch 82, time usage: 14.748s, loss: 405.8000125727378, vec_loss [0.09046015 0.08677723]/0.088619
Validation loss: 742.4229685465494 vec_loss: [0.0802053  0.09480162]/0.087503
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_82.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.5231465	1.1137211214412341	0.028035842	50.90879790321713


100%|██████████| 484/484 [00:14<00:00, 32.70it/s]


-------------------------
Epoch 83, time usage: 14.805s, loss: 425.455091366098, vec_loss [0.09071087 0.08507855]/0.087895
Validation loss: 867.3877240135556 vec_loss: [0.0831065  0.09181122]/0.087459
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.591534	1.3963476094332608	0.048036624	50.805665461844505


100%|██████████| 484/484 [00:14<00:00, 32.77it/s]


-------------------------
Epoch 84, time usage: 14.776s, loss: 434.82576909341105, vec_loss [0.09291731 0.09069864]/0.091808
Validation loss: 778.5480626424154 vec_loss: [0.09182102 0.11635431]/0.104088
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.4779233	1.0976056781682102	0.027879644	52.028146690334886


100%|██████████| 484/484 [00:14<00:00, 32.75it/s]


-------------------------
Epoch 85, time usage: 14.781s, loss: 398.98194068719533, vec_loss [0.08903681 0.08878141]/0.088909
Validation loss: 761.695067632766 vec_loss: [0.08242601 0.09137284]/0.086899
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.771199	1.1955205256288701	0.029675066	50.98630720038156


100%|██████████| 484/484 [00:14<00:00, 32.76it/s]


-------------------------
Epoch 86, time usage: 14.779s, loss: 376.86687393819005, vec_loss [0.08998279 0.08506039]/0.087522
Validation loss: 791.9277677990142 vec_loss: [0.08246532 0.09964608]/0.091056
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.4721745	1.1067520054903897	0.03011446	50.922554835594944


100%|██████████| 484/484 [00:14<00:00, 32.90it/s]


-------------------------
Epoch 87, time usage: 14.714s, loss: 429.57813947062846, vec_loss [0.09168944 0.08825286]/0.089971
Validation loss: 885.688111441476 vec_loss: [0.09045344 0.10362235]/0.097038
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.8262618	1.3940802541646091	0.053756654	51.641738051935725


100%|██████████| 484/484 [00:15<00:00, 32.21it/s]


-------------------------
Epoch 88, time usage: 15.032s, loss: 396.224292534442, vec_loss [0.09036278 0.08657003]/0.088466
Validation loss: 867.7607062203543 vec_loss: [0.07674628 0.09092652]/0.083836
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.5994027	1.6090124953876843	0.06542612	52.49311430178864


100%|██████████| 484/484 [00:14<00:00, 32.46it/s]


-------------------------
Epoch 89, time usage: 14.914s, loss: 407.42916094567164, vec_loss [0.09390448 0.09016014]/0.092032
Validation loss: 849.5362745012555 vec_loss: [0.08601437 0.10260591]/0.094310
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.219128	1.3115712675181301	0.043997087	52.85833383060355


100%|██████████| 484/484 [00:14<00:00, 32.29it/s]


-------------------------
Epoch 90, time usage: 14.994s, loss: 422.0844244523482, vec_loss [0.09337655 0.09182817]/0.092602
Validation loss: 839.0206291562035 vec_loss: [0.0805684  0.09125131]/0.085910
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.601171	1.8995981108058582	0.08083352	56.54203522907813


100%|██████████| 484/484 [00:15<00:00, 32.05it/s]


-------------------------
Epoch 91, time usage: 15.104s, loss: 381.8227655395003, vec_loss [0.0911976 0.0879911]/0.089594
Validation loss: 909.2659690493629 vec_loss: [0.08420168 0.10162772]/0.092915
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.2156663	1.4767603820020503	0.05691998	54.10333077769287


100%|██████████| 484/484 [00:14<00:00, 32.30it/s]


-------------------------
Epoch 92, time usage: 14.987s, loss: 409.08443368959036, vec_loss [0.09133524 0.09098863]/0.091162
Validation loss: 777.6931101481119 vec_loss: [0.08149169 0.09537329]/0.088432
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.4793968	1.3604382439093157	0.042043954	52.05215542619376


100%|██████████| 484/484 [00:14<00:00, 32.68it/s]


-------------------------
Epoch 93, time usage: 14.817s, loss: 370.92675604702026, vec_loss [0.08783223 0.08643074]/0.087131
Validation loss: 980.3308861142114 vec_loss: [0.09993063 0.11269268]/0.106312
Epoch 00094: reducing learning rate of group 0 to 5.6250e-04.
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.7098086	1.4178504727103494	0.04990673	53.56623698953004


100%|██████████| 484/484 [00:14<00:00, 32.42it/s]


-------------------------
Epoch 94, time usage: 14.933s, loss: 352.59395917782115, vec_loss [0.09020816 0.08633704]/0.088273
Validation loss: 778.526249113537 vec_loss: [0.09118131 0.09907419]/0.095128
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.3462441	1.088817379691384	0.024529189	51.39676462438518


100%|██████████| 484/484 [00:14<00:00, 32.48it/s]


-------------------------
Epoch 95, time usage: 14.906s, loss: 341.85482107115183, vec_loss [0.08647697 0.0839607 ]/0.085219
Validation loss: 775.8887111118862 vec_loss: [0.08502195 0.09089495]/0.087958
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.215584	1.2804843187332153	0.041766997	52.10767461886219


100%|██████████| 484/484 [00:14<00:00, 32.81it/s]


-------------------------
Epoch 96, time usage: 14.757s, loss: 333.0385408953202, vec_loss [0.08631409 0.08451588]/0.085415
Validation loss: 831.247192746117 vec_loss: [0.0774755 0.0902297]/0.083853
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.0523438	1.4973749789324673	0.052784026	54.46987350042789


100%|██████████| 484/484 [00:14<00:00, 32.63it/s]


-------------------------
Epoch 97, time usage: 14.839s, loss: 344.61394185468185, vec_loss [0.08913282 0.08698384]/0.088058
Validation loss: 796.8549150739398 vec_loss: [0.08484727 0.09917172]/0.092009
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.431616	1.3300555077466099	0.041867133	51.24485469751807


100%|██████████| 484/484 [00:14<00:00, 33.01it/s]


-------------------------
Epoch 98, time usage: 14.669s, loss: 318.60580321382884, vec_loss [0.08586691 0.08438727]/0.085127
Validation loss: 817.7066977364676 vec_loss: [0.08419748 0.10212853]/0.093163
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.8052758	1.2065142664042385	0.031600177	53.58381572321164


100%|██████████| 484/484 [00:14<00:00, 32.64it/s]


-------------------------
Epoch 99, time usage: 14.835s, loss: 319.2365829215562, vec_loss [0.08478322 0.08293644]/0.083860
Validation loss: 714.0249056134905 vec_loss: [0.07922374 0.09271059]/0.085967
Best Validation Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_99.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.3554049	1.3252866322344	0.045825113	50.27926247204848


100%|██████████| 484/484 [00:14<00:00, 32.90it/s]


-------------------------
Epoch 100, time usage: 14.715s, loss: 316.4984054880694, vec_loss [0.08503153 0.08152901]/0.083280
Validation loss: 798.1942796253022 vec_loss: [0.07546332 0.08595261]/0.080708
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.4825543	1.0847711779854514	0.02792529	50.50845443705857


100%|██████████| 484/484 [00:14<00:00, 32.96it/s]


-------------------------
Epoch 101, time usage: 14.686s, loss: 329.97029116922175, vec_loss [0.08554759 0.0826331 ]/0.084090
Validation loss: 812.7952753702799 vec_loss: [0.07467933 0.08428839]/0.079484
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.0982494	1.2628262855789878	0.03753533	53.06256424836553


100%|██████████| 484/484 [00:14<00:00, 32.59it/s]


-------------------------
Epoch 102, time usage: 14.855s, loss: 324.01359526579046, vec_loss [0.0856318  0.08129945]/0.083466
Validation loss: 830.2880714053199 vec_loss: [0.08162402 0.09854742]/0.090086
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.203788	1.2746576124971563	0.040358935	50.73807825630208


100%|██████████| 484/484 [00:14<00:00, 32.83it/s]


-------------------------
Epoch 103, time usage: 14.748s, loss: 341.7701330894281, vec_loss [0.08697613 0.08344465]/0.085210
Validation loss: 859.6872834705171 vec_loss: [0.07525746 0.09134661]/0.083302
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.5781724	1.568762324073098	0.06750498	50.23704724603064


100%|██████████| 484/484 [00:14<00:00, 32.58it/s]


-------------------------
Epoch 104, time usage: 14.860s, loss: 303.9202074098193, vec_loss [0.08532438 0.08094129]/0.083133
Validation loss: 819.6076285952613 vec_loss: [0.09483742 0.10249331]/0.098665
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.3307191	1.1228314692323857	0.02689869	53.84166467827395


100%|██████████| 484/484 [00:14<00:00, 32.85it/s]


-------------------------
Epoch 105, time usage: 14.740s, loss: 294.49690224513535, vec_loss [0.08231695 0.08057184]/0.081444
Validation loss: 832.2365341186523 vec_loss: [0.07509641 0.09313494]/0.084116
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.8343929	1.1504436893896623	0.037428096	47.896631700506994


100%|██████████| 484/484 [00:15<00:00, 32.18it/s]


-------------------------
Epoch 106, time usage: 15.044s, loss: 306.6103989467148, vec_loss [0.08456425 0.08130892]/0.082937
Validation loss: 751.0468245006743 vec_loss: [0.08061121 0.09699943]/0.088805
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.483396	1.5225817290219394	0.06344897	52.77753548689059


100%|██████████| 484/484 [00:14<00:00, 32.83it/s]


-------------------------
Epoch 107, time usage: 14.749s, loss: 307.24649624785116, vec_loss [0.08308426 0.08177225]/0.082428
Validation loss: 1213.7563291277204 vec_loss: [0.07525002 0.09411641]/0.084683
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.246546	1.612485576759685	0.057284046	50.976933076208866


100%|██████████| 484/484 [00:14<00:00, 32.74it/s]


-------------------------
Epoch 108, time usage: 14.789s, loss: 343.4305232339654, vec_loss [0.08650581 0.0862335 ]/0.086370
Validation loss: 867.8847848801386 vec_loss: [0.07955196 0.08833407]/0.083943
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.9845394	1.2780941291288896	0.03633607	49.41344223487985


100%|██████████| 484/484 [00:14<00:00, 32.68it/s]


-------------------------
Epoch 109, time usage: 14.815s, loss: 294.1828332696079, vec_loss [0.08101564 0.08097845]/0.080997
Validation loss: 853.1505047026135 vec_loss: [0.07925645 0.09140597]/0.085331
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.5178548	1.1066437851298938	0.025195606	49.55063062038043


100%|██████████| 484/484 [00:14<00:00, 32.75it/s]


-------------------------
Epoch 110, time usage: 14.784s, loss: 306.5003579825409, vec_loss [0.08301008 0.08302429]/0.083017
Validation loss: 908.4810391380673 vec_loss: [0.07725551 0.09483357]/0.086045
Epoch 00111: reducing learning rate of group 0 to 4.2188e-04.
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.6187538	1.114108221097426	0.029036343	49.92794858745743


100%|██████████| 484/484 [00:14<00:00, 32.54it/s]


-------------------------
Epoch 111, time usage: 14.878s, loss: 278.60915570219686, vec_loss [0.08165616 0.08063567]/0.081146
Validation loss: 858.9378531319754 vec_loss: [0.08424   0.1027624]/0.093501
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.5065494	1.3933989622376182	0.042394068	50.697848752553845


100%|██████████| 484/484 [00:14<00:00, 33.09it/s]


-------------------------
Epoch 112, time usage: 14.632s, loss: 274.2006293683013, vec_loss [0.08318701 0.08026264]/0.081725
Validation loss: 738.8330248878116 vec_loss: [0.08923912 0.09576292]/0.092501
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.9182137	1.259758933023973	0.036514476	48.14415899788083


100%|██████████| 484/484 [00:14<00:00, 32.85it/s]


-------------------------
Epoch 113, time usage: 14.736s, loss: 272.35678453681885, vec_loss [0.08174139 0.08037879]/0.081060
Validation loss: 869.3335117158435 vec_loss: [0.07012421 0.08704428]/0.078584
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.6888787	1.1496864774010398	0.032018043	49.405647126258


100%|██████████| 484/484 [00:14<00:00, 33.08it/s]


-------------------------
Epoch 114, time usage: 14.636s, loss: 268.00855226753174, vec_loss [0.08152421 0.07978748]/0.080656
Validation loss: 836.6995504470099 vec_loss: [0.07523851 0.09071201]/0.082975
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.5741467	1.1432366100224582	0.030969197	48.69934937224959


100%|██████████| 484/484 [00:14<00:00, 33.05it/s]


-------------------------
Epoch 115, time usage: 14.649s, loss: 269.89467690207744, vec_loss [0.08144275 0.07976294]/0.080603
Validation loss: 847.6831512451172 vec_loss: [0.08546302 0.09342045]/0.089442
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.6026511	1.119465860453519	0.03225498	48.720479804299615


100%|██████████| 484/484 [00:14<00:00, 32.55it/s]


-------------------------
Epoch 116, time usage: 14.874s, loss: 273.8576233919002, vec_loss [0.08174723 0.08039891]/0.081073
Validation loss: 876.8937465122768 vec_loss: [0.07574378 0.09026424]/0.083004
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.8192489	1.2371976104649631	0.033382796	50.47045871563207


100%|██████████| 484/484 [00:14<00:00, 32.93it/s]


-------------------------
Epoch 117, time usage: 14.700s, loss: 265.31591355505066, vec_loss [0.08146265 0.08041365]/0.080938
Validation loss: 843.4497575305757 vec_loss: [0.07901903 0.09808226]/0.088551
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.486615	1.0730341564525256	0.028493112	48.813260783521926


100%|██████████| 484/484 [00:14<00:00, 32.98it/s]


-------------------------
Epoch 118, time usage: 14.682s, loss: 263.94849900174734, vec_loss [0.07913303 0.07951888]/0.079326
Validation loss: 858.75475529262 vec_loss: [0.0751958  0.09326988]/0.084233
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.1913803	1.0488141666759143	0.019591438	48.93002907664617


100%|██████████| 484/484 [00:14<00:00, 32.71it/s]


-------------------------
Epoch 119, time usage: 14.803s, loss: 264.3535986971264, vec_loss [0.08145273 0.07989953]/0.080676
Validation loss: 851.4305841355097 vec_loss: [0.08999592 0.11149824]/0.100747
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.9399395	1.1576356291770935	0.032196313	51.032934418140094
Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/icheckpoint_119.pt


100%|██████████| 484/484 [00:14<00:00, 32.70it/s]


-------------------------
Epoch 120, time usage: 14.806s, loss: 260.27491733456446, vec_loss [0.0812915  0.08018091]/0.080736
Validation loss: 833.0599615914481 vec_loss: [0.08007258 0.09883902]/0.089456
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.647678	1.1535519849170337	0.03213882	49.297496358624294


100%|██████████| 484/484 [00:14<00:00, 32.57it/s]


-------------------------
Epoch 121, time usage: 14.866s, loss: 253.1084834957911, vec_loss [0.07976954 0.07832705]/0.079048
Validation loss: 767.147809346517 vec_loss: [0.07955579 0.09478283]/0.087169
Epoch 00122: reducing learning rate of group 0 to 3.1641e-04.
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.4825019	1.1137400513345546	0.027342744	48.141871815898305


100%|██████████| 484/484 [00:14<00:00, 32.39it/s]


-------------------------
Epoch 122, time usage: 14.948s, loss: 243.34154882509847, vec_loss [0.07981651 0.07689434]/0.078355
Validation loss: 839.1111610049293 vec_loss: [0.07907198 0.09771091]/0.088391
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.3507333	1.0883743979714133	0.025589257	48.951027251396795


100%|██████████| 484/484 [00:14<00:00, 32.77it/s]


-------------------------
Epoch 123, time usage: 14.777s, loss: 242.17456637926338, vec_loss [0.07876801 0.07745969]/0.078114
Validation loss: 804.9075669788178 vec_loss: [0.07738647 0.09427897]/0.085833
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.4721993	1.1141613071615046	0.028443987	48.865283540261615


100%|██████████| 484/484 [00:14<00:00, 32.65it/s]


-------------------------
Epoch 124, time usage: 14.826s, loss: 239.62023088754702, vec_loss [0.0777176  0.07682893]/0.077273
Validation loss: 846.8177755446661 vec_loss: [0.0750268  0.09377953]/0.084403
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.4757257	1.1163831407373601	0.024996297	48.519156563418896


100%|██████████| 484/484 [00:14<00:00, 33.13it/s]


-------------------------
Epoch 125, time usage: 14.611s, loss: 238.6152524711672, vec_loss [0.07905816 0.07742082]/0.078239
Validation loss: 850.7602960495722 vec_loss: [0.08455954 0.09967704]/0.092118
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.2626145	1.241433783011003	0.045833945	48.35241550169539


100%|██████████| 484/484 [00:14<00:00, 33.09it/s]


-------------------------
Epoch 126, time usage: 14.633s, loss: 237.987485727988, vec_loss [0.07992022 0.07775176]/0.078836
Validation loss: 827.7026650565011 vec_loss: [0.08431546 0.10226107]/0.093288
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.4341916	1.0677617896686902	0.02646294	48.08805489264892


100%|██████████| 484/484 [00:14<00:00, 32.76it/s]


-------------------------
Epoch 127, time usage: 14.780s, loss: 234.80153760043058, vec_loss [0.07867604 0.07777701]/0.078227
Validation loss: 795.4303999401275 vec_loss: [0.07622503 0.09995504]/0.088090
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.043402	1.2222556580196728	0.035094373	50.080435354846266


100%|██████████| 484/484 [00:14<00:00, 32.54it/s]


-------------------------
Epoch 128, time usage: 14.879s, loss: 232.5237015968512, vec_loss [0.07745362 0.07684547]/0.077150
Validation loss: 826.9491199311756 vec_loss: [0.08340368 0.1004613 ]/0.091932
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.7192711	1.1654029867865823	0.033838477	48.51839654964217


100%|██████████| 484/484 [00:14<00:00, 32.56it/s]


-------------------------
Epoch 129, time usage: 14.868s, loss: 238.96055041857002, vec_loss [0.07818139 0.07712171]/0.077652
Validation loss: 782.9398331415085 vec_loss: [0.07674483 0.09185039]/0.084298
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.1423436	1.0175343914465471	0.018242696	47.63033845838856


100%|██████████| 484/484 [00:14<00:00, 32.62it/s]


-------------------------
Epoch 130, time usage: 14.843s, loss: 230.34646237586156, vec_loss [0.07821099 0.07671598]/0.077463
Validation loss: 840.4531124659946 vec_loss: [0.08126992 0.097693  ]/0.089481
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
2.051772	1.212151977148923	0.03952648	46.828798683814625


100%|██████████| 484/484 [00:14<00:00, 32.54it/s]


-------------------------
Epoch 131, time usage: 14.879s, loss: 227.08124044118833, vec_loss [0.07741294 0.07525404]/0.076333
Validation loss: 802.5998389834449 vec_loss: [0.08135932 0.09815122]/0.089755
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.9538932	1.1950996789065274	0.034037467	48.903178050703126


100%|██████████| 484/484 [00:14<00:00, 32.27it/s]


-------------------------
Epoch 132, time usage: 15.001s, loss: 230.0139212332481, vec_loss [0.07879385 0.07626925]/0.077532
Validation loss: 840.6306929815383 vec_loss: [0.08038362 0.0967924 ]/0.088588
Epoch 00133: reducing learning rate of group 0 to 2.3730e-04.
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.7646502	1.1314500895413486	0.031638987	49.012020840685004


100%|██████████| 484/484 [00:14<00:00, 32.96it/s]


-------------------------
Epoch 133, time usage: 14.690s, loss: 219.77623802374217, vec_loss [0.07687637 0.07530403]/0.076090
Validation loss: 841.4003463018508 vec_loss: [0.08413013 0.09518774]/0.089659
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.2635201	1.0446890592575073	0.021581436	48.42930915044511


100%|██████████| 484/484 [00:14<00:00, 32.88it/s]


-------------------------
Epoch 134, time usage: 14.725s, loss: 218.11783336607877, vec_loss [0.07686035 0.07541678]/0.076139
Validation loss: 796.2480813889276 vec_loss: [0.07517715 0.08889145]/0.082034
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.2532902	1.0228353467854587	0.023004351	47.93530423158761


100%|██████████| 484/484 [00:14<00:00, 32.72it/s]


-------------------------
Epoch 135, time usage: 14.797s, loss: 215.05513910025604, vec_loss [0.07604612 0.07496928]/0.075508
Validation loss: 780.4802118937174 vec_loss: [0.07594054 0.09079789]/0.083369
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.1887355	1.0646525485949083	0.022194311	48.7303671240495


100%|██████████| 484/484 [00:14<00:00, 32.74it/s]


-------------------------
Epoch 136, time usage: 14.788s, loss: 215.5276743045523, vec_loss [0.07675872 0.07547612]/0.076117
Validation loss: 799.1545415605817 vec_loss: [0.07323328 0.08609325]/0.079663
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.5734767	1.023123719475486	0.030463321	47.68158164217343


100%|██████████| 484/484 [00:14<00:00, 32.73it/s]


-------------------------
Epoch 137, time usage: 14.795s, loss: 216.0241760380012, vec_loss [0.07631106 0.07436921]/0.075340
Validation loss: 808.9047597249349 vec_loss: [0.07476705 0.0894371 ]/0.082102
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.26797	1.0144291520118713	0.02374959	48.423905849885344


100%|██████████| 484/484 [00:14<00:00, 32.82it/s]


-------------------------
Epoch 138, time usage: 14.754s, loss: 211.5967551932847, vec_loss [0.07581245 0.073557  ]/0.074685
Validation loss: 812.736989520845 vec_loss: [0.07620158 0.09061115]/0.083406
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.9933769	1.1774923205375671	0.038061567	47.19319113881366


100%|██████████| 484/484 [00:14<00:00, 32.59it/s]


-------------------------
Epoch 139, time usage: 14.854s, loss: 210.61435585967766, vec_loss [0.07498959 0.07313094]/0.074060
Validation loss: 825.8350292387463 vec_loss: [0.08773214 0.09750249]/0.092617
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.6592113	1.1290253373709591	0.03361793	47.79173427616022
Model saved to /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/icheckpoint_139.pt


100%|██████████| 484/484 [00:14<00:00, 33.07it/s]


-------------------------
Epoch 140, time usage: 14.640s, loss: 209.6457461995527, vec_loss [0.07381058 0.07360949]/0.073710
Validation loss: 886.3613117762974 vec_loss: [0.07090273 0.08594877]/0.078426
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.1658231	0.9949472085996107	0.016691016	48.94952305746367


100%|██████████| 484/484 [00:14<00:00, 32.59it/s]


-------------------------
Epoch 141, time usage: 14.854s, loss: 214.59142443759382, vec_loss [0.07703058 0.07512885]/0.076080
Validation loss: 809.4507885887509 vec_loss: [0.07805875 0.09365787]/0.085858
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.4546665	1.0302851714871146	0.027895903	47.65501282722837


100%|██████████| 484/484 [00:14<00:00, 32.39it/s]


-------------------------
Epoch 142, time usage: 14.946s, loss: 205.8827073278506, vec_loss [0.07514529 0.07409911]/0.074622
Validation loss: 795.2832347324917 vec_loss: [0.07195335 0.08951274]/0.080733
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.536584	1.0896422808820552	0.029751211	46.69803253891273


100%|██████████| 484/484 [00:14<00:00, 32.56it/s]


-------------------------
Epoch 143, time usage: 14.870s, loss: 205.32235355219566, vec_loss [0.07485264 0.07474585]/0.074799
Validation loss: 779.5668398539225 vec_loss: [0.07943976 0.09324015]/0.086340
Epoch 00144: reducing learning rate of group 0 to 1.7798e-04.
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.0189713	0.9832066785205494	0.01693774	47.57365533939289


100%|██████████| 484/484 [00:14<00:00, 32.52it/s]


-------------------------
Epoch 144, time usage: 14.887s, loss: 205.72540508617055, vec_loss [0.07640126 0.07454728]/0.075474
Validation loss: 794.517072405134 vec_loss: [0.07621621 0.09248379]/0.084350
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.3338554	1.0565293675119227	0.02582956	47.409519684374615


100%|██████████| 484/484 [00:14<00:00, 32.83it/s]


-------------------------
Epoch 145, time usage: 14.746s, loss: 200.6781418697893, vec_loss [0.07466937 0.07429902]/0.074484
Validation loss: 787.897784096854 vec_loss: [0.07459617 0.09050177]/0.082549
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.0412403	0.9732892513275146	0.018652149	47.473998843381544


100%|██████████| 484/484 [00:14<00:00, 32.81it/s]


-------------------------
Epoch 146, time usage: 14.756s, loss: 199.33203749223188, vec_loss [0.07618358 0.07412705]/0.075155
Validation loss: 793.3563657488141 vec_loss: [0.07723652 0.09139782]/0.084317
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.1652275	0.9963513179258867	0.021210859	47.05413500703824


100%|██████████| 484/484 [00:14<00:00, 32.64it/s]


-------------------------
Epoch 147, time usage: 14.836s, loss: 197.35420900139925, vec_loss [0.07437666 0.07266647]/0.073522
Validation loss: 762.0697966076079 vec_loss: [0.07778112 0.09524151]/0.086511
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.3673062	1.0468673272566362	0.02468226	48.08557336073819


100%|██████████| 484/484 [00:14<00:00, 32.62it/s]


-------------------------
Epoch 148, time usage: 14.842s, loss: 197.33533465172633, vec_loss [0.0747766  0.07292187]/0.073849
Validation loss: 749.4926950363886 vec_loss: [0.07261117 0.09147649]/0.082044
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.1407638	1.0203043655915693	0.020994721	47.270383763571154


100%|██████████| 484/484 [00:14<00:00, 32.50it/s]


-------------------------
Epoch 149, time usage: 14.895s, loss: 199.32048893763013, vec_loss [0.07348505 0.07305236]/0.073269
Validation loss: 807.974243527367 vec_loss: [0.08078028 0.09391274]/0.087347
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
1.9695786	1.1752062385732478	0.03904834	46.66415914711347
Training completed


In [None]:
import torch

# DEVICE is configured as "cuda:0" in the config cell, so confirm this runtime exposes a GPU.
gpu_visible = torch.cuda.is_available()
gpu_visible

True

In [None]:
# Check Saved Checkpoints
# Define the directory path
dir_path = r'/content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints'

# Loop over all files in the directory
for filename in os.listdir(dir_path):
    # Check if the file is a regular file (not a directory)
    if os.path.isfile(os.path.join(dir_path, filename)):
        # Do something with the file
        print(filename)

checkpoint_0.pt
checkpoint_1.pt
checkpoint_2.pt
checkpoint_3.pt
checkpoint_5.pt
checkpoint_8.pt
checkpoint_10.pt
checkpoint_12.pt
checkpoint_13.pt
checkpoint_17.pt
icheckpoint_19.pt
checkpoint_24.pt
checkpoint_26.pt
checkpoint_28.pt
checkpoint_31.pt
checkpoint_37.pt
checkpoint_38.pt
icheckpoint_39.pt
checkpoint_45.pt
checkpoint_49.pt
checkpoint_52.pt
checkpoint_59.pt
checkpoint_62.pt
checkpoint_76.pt
icheckpoint_79.pt
checkpoint_80.pt
checkpoint_82.pt
checkpoint_99.pt
icheckpoint_119.pt
icheckpoint_139.pt
checkpoint_latest.pt


Download the checkpoints.

In [None]:
# Archive the whole prediction_model directory (checkpoints included) into a single zip.
# shutil.make_archive rewrites the archive from scratch on every run, so checkpoints deleted
# since a previous run do not linger in it (`zip -r` onto an existing archive only adds or
# updates entries). Entries are stored relative to the working directory
# ("prediction_model/checkpoints/...") instead of under the absolute "content/drive/..." prefix.
import os
import shutil

# model_prediction_path is created in the first config cell; normpath drops any trailing slash
# so that basename() below yields the directory name.
_model_dir = os.path.normpath(model_prediction_path)
CHECKPOINT_ARCHIVE_BASE = r'/content/drive/MyDrive/ColabNotebooks/out/checkpoints'  # make_archive appends ".zip"

archive_path = shutil.make_archive(
    CHECKPOINT_ARCHIVE_BASE,
    'zip',
    root_dir=os.path.dirname(_model_dir),
    base_dir=os.path.basename(_model_dir),
)
print(f'Archive written to {archive_path}')

  adding: content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/ (stored 0%)
  adding: content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/ (stored 0%)
  adding: content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_0.pt (deflated 7%)
  adding: content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_1.pt (deflated 7%)
  adding: content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_2.pt (deflated 7%)
  adding: content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_3.pt (deflated 7%)
  adding: content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_5.pt (deflated 7%)
  adding: content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_8.pt (deflated 7%)
  adding: content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_10.

In [None]:
# Offer the checkpoint archive for download.
# IPython's FileLink only emits an <a href> to the path it is given, which works only if the
# notebook server serves that path; an absolute Drive path in Colab gives a dead link.
# In Colab use google.colab.files.download; elsewhere fall back to FileLink.
CHECKPOINT_ARCHIVE = r'/content/drive/MyDrive/ColabNotebooks/out/checkpoints.zip'

try:
    from google.colab import files as colab_files
except ImportError:
    from IPython.display import FileLink, display
    display(FileLink(CHECKPOINT_ARCHIVE))
else:
    colab_files.download(CHECKPOINT_ARCHIVE)

### 8. Testing

#### 8.1 Test for Seen Dataset

In [None]:
# Evaluate the trained model on the *seen*-subjects test set.
# Please change the output dir & model path
# NOTE(review): load_config() and test() are defined in an earlier cell. load_config() takes no
# arguments, so it presumably reads TEST_DIR / OUT_DIR / MODEL_PATH from the notebook globals;
# they must therefore be assigned before the call. Confirm.
# NOTE(review): checkpoint_latest.pt holds the final epoch's weights. The validation log shows
# lower ATE at some earlier epochs, so a numbered checkpoint_<N>.pt may be the better choice.
import warnings

TEST_DIR = r'/content/drive/MyDrive/ColabNotebooks/original_data/test_seen'  # Dataset directory for testing (seen-subjects test set)
OUT_DIR = r'/content/drive/MyDrive/ColabNotebooks/out/working/test_results/test_seen'  # Output directory for this test run
MODEL_PATH = r'/content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_latest.pt'  # Model path for testing

# Load config settings
kwargs = load_config()

# Silence warnings only while testing. A top-level warnings.filterwarnings('ignore') would
# suppress them for the rest of the kernel session and hide problems in later cells.
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    test(**kwargs)

Testing LSTM model
LSTM model
Network constructed. trainable parameters: 444632
The model is loaded.
Model /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_latest.pt loaded to device cuda:0.
Reconstruct the tracermini_hw101_test20230311112635T
Reconstruct the tracermini_hw101_test20230313011731T
Reconstruct the tracermini_hw101_test20230313010204T
Reconstruct the tracermini_hw101_test20230311112235T
Reconstruct the tracermini_hw101_test20230313010546T
Reconstruct the tracermini_hw101_test20230313011357T
Reconstruct the tracermini_hw101_test20230311111842T
Reconstruct the tracermini_hw101_test20230311111507T
Reconstruct the tracermini_hw101_test20230313010954T
Reconstruct the tracermini_hw101_test20230311111027T
Reconstruct the tracermini_hw101_test20230313013335T
Reconstruct the tracermini_hw101_test20230313012956T


#### 8.2 Test for Unseen Dataset

In [None]:
# Evaluate the trained model on the *unseen*-subjects test set.
# Please change the output dir & model path
# NOTE(review): load_config() and test() are defined in an earlier cell. load_config() takes no
# arguments, so it presumably reads TEST_DIR / OUT_DIR / MODEL_PATH from the notebook globals;
# they must therefore be assigned before the call. Confirm.
# NOTE(review): checkpoint_latest.pt holds the final epoch's weights. The validation log shows
# lower ATE at some earlier epochs, so a numbered checkpoint_<N>.pt may be the better choice.
import warnings

TEST_DIR = r'/content/drive/MyDrive/ColabNotebooks/original_data/test_unseen'  # Dataset directory for testing (unseen-subjects test set)
OUT_DIR = r'/content/drive/MyDrive/ColabNotebooks/out/working/test_results/test_unseen'  # Output directory for this test run
MODEL_PATH = r'/content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_latest.pt'  # Model path for testing

# Load config settings
kwargs = load_config()

# Silence warnings only while testing. A top-level warnings.filterwarnings('ignore') would
# suppress them for the rest of the kernel session and hide problems in later cells.
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    test(**kwargs)

Testing LSTM model
LSTM model
Network constructed. trainable parameters: 444632
The model is loaded.
Model /content/drive/MyDrive/ColabNotebooks/out/working/prediction_model/checkpoints/checkpoint_latest.pt loaded to device cuda:0.
Reconstruct the tracermini_unseen_cym20230314100816T
Reconstruct the tracermini_unseen_cym20230314101001T
Reconstruct the tracermini_unseen_cym20230314101434T
Reconstruct the tracermini_unseen_cym20230314100559T
Reconstruct the tracermini_unseen_cym20230314101010T
Reconstruct the tracermini_unseen_cym20230314101230T
Reconstruct the tracermini_unseen_cym20230314100325T
Reconstruct the tracermini_unseen_cym20230314100103T
Reconstruct the tracermini_unseen_cym20230314101850T
Reconstruct the tracermini_unseen_hw520230314082319T
Reconstruct the tracermini_unseen_hw520230314082000T
Reconstruct the tracermini_unseen_hw520230314091110T
Reconstruct the tracermini_unseen_hw520230314091338T
Reconstruct the tracermini_unseen_hw520230314081542T
Reconstruct the tracermini

#### 8.3 Check the Results

In [None]:
# Show each test sequence's name next to the number of predictions produced for it.
# NOTE(review): seen_unseen_pred and seen_unseen_dataset are built in an earlier cell not shown here.
for seq_id, seq_pred in seen_unseen_pred.items():
    seq_name = seen_unseen_dataset[seq_id]
    print(seq_name, len(seq_pred))

tracermini_hw101_test20230311112635T 42097
tracermini_hw101_test20230311111507T 37913
tracermini_hw101_test20230311112235T 43462
tracermini_hw101_test20230313011357T 38324
tracermini_hw101_test20230311111842T 42839
tracermini_hw101_test20230313010954T 39610
tracermini_hw101_test20230313010546T 38378
tracermini_hw101_test20230313011731T 37610
tracermini_hw101_test20230313013335T 38254
tracermini_hw101_test20230311111027T 44186
tracermini_hw101_test20230313012956T 38694
tracermini_hw101_test20230313010204T 39085
tracermini_unseen_hw520230314091844T 19545
tracermini_unseen_cym20230314101001T 283
tracermini_unseen_hw520230314082319T 31657
tracermini_unseen_cym20230314101636T 22690
tracermini_unseen_hw520230314083031T 21208
tracermini_unseen_hw520230314091110T 25327
tracermini_unseen_cym20230314100816T 18914
tracermini_unseen_cym20230314101230T 21364
tracermini_unseen_cym20230314100559T 24483
tracermini_unseen_hw520230314081212T 37662
tracermini_unseen_hw520230314082000T 34311
tracermini_un

#### 8.4 Save the Results

In [None]:
# Write the per-timestep predictions of every test sequence (seen + unseen) to the submission CSV.
# NOTE(review): seen_unseen_pred / seen_unseen_dataset are built in an earlier cell not shown here;
# presumably seen_unseen_pred maps a sequence id to an indexable sequence of predictions.
import csv

SUBMISSION_PATH = r'/content/drive/MyDrive/ColabNotebooks/out/submission.csv'

with open(SUBMISSION_PATH, 'w', newline='') as f:
    # csv.writer quotes any field that contains a comma, quote or newline, which plain string
    # formatting would not. lineterminator keeps the original "\n" row endings.
    writer = csv.writer(f, lineterminator='\n')

    # Header row: "Id" and "Prediction" columns.
    writer.writerow(['Id', 'Prediction'])

    # One row per timestep; the row id is "<sequence id>_<timestep index>".
    for key, value in seen_unseen_pred.items():
        # Report the sequence name and how many rows it contributes.
        print(seen_unseen_dataset[key], len(value))
        # format(pred) produces exactly the text "{}".format(pred) did. Passing pred unformatted
        # would let csv use repr() for float subclasses (e.g. numpy float64), changing the output.
        writer.writerows([f'{key}_{i}', format(pred)] for i, pred in enumerate(value))

tracermini_hw101_test20230311112635T 42097
tracermini_hw101_test20230311111507T 37913
tracermini_hw101_test20230311112235T 43462
tracermini_hw101_test20230313011357T 38324
tracermini_hw101_test20230311111842T 42839
tracermini_hw101_test20230313010954T 39610
tracermini_hw101_test20230313010546T 38378
tracermini_hw101_test20230313011731T 37610
tracermini_hw101_test20230313013335T 38254
tracermini_hw101_test20230311111027T 44186
tracermini_hw101_test20230313012956T 38694
tracermini_hw101_test20230313010204T 39085
tracermini_unseen_hw520230314091844T 19545
tracermini_unseen_cym20230314101001T 283
tracermini_unseen_hw520230314082319T 31657
tracermini_unseen_cym20230314101636T 22690
tracermini_unseen_hw520230314083031T 21208
tracermini_unseen_hw520230314091110T 25327
tracermini_unseen_cym20230314100816T 18914
tracermini_unseen_cym20230314101230T 21364
tracermini_unseen_cym20230314100559T 24483
tracermini_unseen_hw520230314081212T 37662
tracermini_unseen_hw520230314082000T 34311
tracermini_un