# COMP 7310 Personal Project

### 1. Config Setup

In [20]:
# Make Dirs for Model Storage and Test Results
import os

try:
    # Model Storage Path
    model_prediction_path = '/kaggle/working/prediction_model'
    os.mkdir(model_prediction_path)
except:
    pass
try:
    # Test Results Path
    test_result_path = '/kaggle/working/test_results'
    os.mkdir(test_result_path)
except:
    pass
try:
    # Test Seen & Unseen Path
    os.mkdir(test_result_path + '/test_seen')
    os.mkdir(test_result_path + '/test_unseen')
except:
    pass

In [21]:
"""
config settings:
"""
### TFNET parameters
BATCH_SIZE = 72 # Training batch size
TEST_BATCH_SIZE = 1 # Test batch size
EPOCHS = 150 # Traning epoch
SAVE_INTERVAL = 20 # Save model every 20 epochs
STEP_SIZE = 200 # Step size for moving forward the window (For training)
TEST_STEP_SIZE = 400 # Step size for moving forward the window (For testing)
WINDOW_SIZE = 400 # Window size for training and testing
INPUT_CHANNEL = 6 # Input feature dimension (Gryo + Acce)
OUTPUT_CHANNEL = 2 # Output dimension (2D velocity vector)
SAMPLING_RATE = 200 # Sampling rate
LAYER_SIZE = 100 # The size of LSTM
LAYERS = 3 # The layer size of LSTM
DROPOUT = 0.1 # Dropout probability
LEARNING_RATE = 0.0003 # Learning rate
NUM_WORKERS = 8
### ------------------ ###

### Data preprocessing parameters
FEATURE_SIGMA = 2.0 # Sigma for feature gaussian smoothing
TARGET_SIGMA = 30.0 # Sigma for target gaussian smoothing
### ------------------ ###

### Device for training
DEVICE = "cuda:0" # You can choose GPU or CPU
### ------------------ ###

### Training and testing setting
DATA_DIR = '/kaggle/input/comp7310-project-1-imu-indoor-tracking/original_data/train_dataset' # Dataset directory for training
VAL_DATA_DIR = '/kaggle/input/comp7310-project-1-imu-indoor-tracking/original_data/val_dataset' # Dataset directory for validation
TEST_DIR = '/kaggle/input/comp7310-project-1-imu-indoor-tracking/original_data/test_seen' # Dataset directory for testing (unseen_subjects_test_set)
OUT_DIR = '/kaggle/working/prediction_model' # Output directory for both traning and testing
MODEL_PATH = '' # Model path for testing
### ------------------ ###

def load_config():
    kwargs = {}
    kwargs['batch_size'] = BATCH_SIZE
    kwargs['test_batch_size'] = TEST_BATCH_SIZE
    kwargs['epochs'] = EPOCHS
    kwargs['save_interval'] = SAVE_INTERVAL
    kwargs['step_size'] = STEP_SIZE
    kwargs['test_step_size'] = TEST_STEP_SIZE
    kwargs['window_size'] = WINDOW_SIZE
    kwargs['sampling_rate'] = SAMPLING_RATE
    kwargs['input_channel'] = INPUT_CHANNEL
    kwargs['output_channel'] = OUTPUT_CHANNEL
    kwargs['layer_size'] = LAYER_SIZE
    kwargs['layers'] = LAYERS
    kwargs['dropout'] = DROPOUT
    kwargs['learning_rate'] = LEARNING_RATE
    kwargs['num_workers'] = NUM_WORKERS

    kwargs['feature_sigma'] = FEATURE_SIGMA
    kwargs['target_sigma'] = TARGET_SIGMA

    kwargs['device'] = DEVICE

    kwargs['data_dir'] = DATA_DIR
    kwargs['val_data_dir'] = VAL_DATA_DIR
    kwargs['test_dir'] = TEST_DIR
    kwargs['out_dir'] = OUT_DIR
    kwargs['model_path'] = MODEL_PATH

    return kwargs


### 2. Model Design

In [22]:
import torch
from scipy import signal
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.models as models
from torch.nn.utils import weight_norm

### BiLSTM model
class BilinearLSTMSeqNetwork(torch.nn.Module):
    def __init__(self, input_size, out_size, batch_size, device,
                 lstm_size=100, lstm_layers=3, dropout=0):
        """
        LSTM network with Bilinear layer
        Input: torch array [batch x frames x input_size]
        Output: torch array [batch x frames x out_size]
        :param input_size: num. channels in input
        :param out_size: num. channels in output
        :param batch_size:
        :param device: torch device
        :param lstm_size: number of LSTM units per layer
        :param lstm_layers: number of LSTM layers
        :param dropout: dropout probability of LSTM (@ref https://pytorch.org/docs/stable/nn.html#lstm)
        """
        super(BilinearLSTMSeqNetwork, self).__init__()
        self.input_size = input_size
        self.lstm_size = lstm_size
        self.output_size = out_size
        self.num_layers = lstm_layers
        self.batch_size = batch_size
        self.device = device

        self.bilinear = torch.nn.Bilinear(self.input_size, self.input_size, self.input_size * 4)
        self.lstm = torch.nn.LSTM(self.input_size * 5, self.lstm_size, self.num_layers, batch_first=True, dropout=dropout)
        self.linear1 = torch.nn.Linear(self.lstm_size + self.input_size * 5, self.output_size * 5)
        self.linear2 = torch.nn.Linear(self.output_size * 5, self.output_size)
        self.hidden = self.init_weights()

    def forward(self, input):
        input_mix = self.bilinear(input, input)
        input_mix = torch.cat([input, input_mix], dim=2)
        output, self.hidden = self.lstm(input_mix, self.init_weights())
        output = torch.cat([input_mix, output], dim=2)
        output = self.linear1(output)
        output = self.linear2(output)
        return output

    def init_weights(self):
        h0 = torch.zeros(self.num_layers, self.batch_size, self.lstm_size)
        c0 = torch.zeros(self.num_layers, self.batch_size, self.lstm_size)
        h0 = h0.to(self.device)
        c0 = c0.to(self.device)
        return Variable(h0), Variable(c0)
### --------------------- End of the model --------------------- ###

### 3. Data Loader

In [23]:
!pip install pyquaternion==0.9.9
!pip install numpy-quaternion==2022.4.3



In [24]:
import json
import random
import os
from os import path as osp

import h5py
import torch
import numpy as np
import quaternion
import math
from scipy.ndimage import gaussian_filter1d
from torch.utils.data import Dataset

def convert_data(data_path):
    """
    Data Processing
    :param data_path
    This function is used to convert the raw data
    stored in hdf5 format to numpy array
    """
    # read hdf5 file
    file = h5py.File(os.path.join(data_path, 'data.hdf5'), 'r')
    synced = file['synced'] # synced data (gyro, acce, magn ...)
    pose = file['pose'] # pose data (ground truth)
    # timestamp (1D)
    timestamps = np.array(synced['time'])
    timestamps = timestamps.reshape((len(timestamps), 1)) / 10**9 # - timestamps[0] # start from 0
    timestamps = timestamps - timestamps[0]
    # gyroscope (3D), accelerometer (3D), magnetometer (3D), rotation vector (4D)
    gyro = np.array(synced['gyro'])
    acce = np.array(synced['acce'])
    # game rotation vector can convert acce and gyro
    # from body frame to navigation frame
    rotation_vector = synced['game_rv']

    # position (3D), orientation (4D)
    pose = file['pose']
    tango_pos = np.array(pose['tango_pos'])
    tango_ori = np.array(pose['tango_ori'])

    # Compute the IMU orientation in the Tango coordinate frame.
    init_tango_ori = quaternion.quaternion(*tango_ori[0])
    ori = rotation_vector
    ori_q = quaternion.from_float_array(ori)
    init_rotor = init_tango_ori * ori_q[0].conj()
    ori_q = init_rotor * ori_q

    gyro_q = quaternion.from_float_array(np.concatenate([np.zeros([gyro.shape[0], 1]), gyro], axis=1))
    acce_q = quaternion.from_float_array(np.concatenate([np.zeros([acce.shape[0], 1]), acce], axis=1))
    glob_gyro = quaternion.as_float_array(ori_q * gyro_q * ori_q.conj())[:, 1:]
    glob_acce = quaternion.as_float_array(ori_q * acce_q * ori_q.conj())[:, 1:]

    rawdata = np.concatenate((timestamps, glob_gyro, glob_acce), axis = 1)
    groundtruth = np.concatenate((timestamps, tango_pos, tango_ori), axis = 1)

    return rawdata, groundtruth

def convert_data_test(data_path):
    """
    Data Processing
    :param data_path
    This function is used to convert the raw data
    stored in hdf5 format to numpy array
    """
    # read hdf5 file
    file = h5py.File(os.path.join(data_path, 'data.hdf5'), 'r')
    synced = file['synced'] # synced data (gyro, acce, magn ...)
    pose = file['pose'] # pose data (ground truth)
    # timestamp (1D)
    timestamps = np.array(synced['time'])
    timestamps = timestamps.reshape((len(timestamps), 1)) / 10**9 # - timestamps[0] # start from 0
    timestamps = timestamps - timestamps[0]
    # gyroscope (3D), accelerometer (3D), magnetometer (3D), rotation vector (4D)
    gyro = np.array(synced['gyro'])
    acce = np.array(synced['acce'])
    # game rotation vector can convert acce and gyro
    # from body frame to navigation frame
    rotation_vector = synced['game_rv']

    # position (3D), orientation (4D)
    pose = file['pose']
    tango_ori = np.array(pose['tango_ori'])

    # Compute the IMU orientation in the Tango coordinate frame.
    init_tango_ori = quaternion.quaternion(*tango_ori[0])
    ori = rotation_vector
    ori_q = quaternion.from_float_array(ori)
    init_rotor = init_tango_ori * ori_q[0].conj()
    ori_q = init_rotor * ori_q

    gyro_q = quaternion.from_float_array(np.concatenate([np.zeros([gyro.shape[0], 1]), gyro], axis=1))
    acce_q = quaternion.from_float_array(np.concatenate([np.zeros([acce.shape[0], 1]), acce], axis=1))
    glob_gyro = quaternion.as_float_array(ori_q * gyro_q * ori_q.conj())[:, 1:]
    glob_acce = quaternion.as_float_array(ori_q * acce_q * ori_q.conj())[:, 1:]

    rawdata = np.concatenate((timestamps, glob_gyro, glob_acce), axis = 1)

    return rawdata

class GlobSequence():
    """
    Property: global coordinate frame
    """
    # add 3-axis magnetometer
    # feature_dim = 9
    feature_dim = 6
    target_dim = 2
    aux_dim = 8

    def __init__(self, data_path = None, **kwargs):
        super().__init__()
        self.ts, self.features, self.targets, self.gt_pos = None, None, None, None
        # self.info = {}
        self.w = kwargs.get('interval', 1)
        if data_path is not None:
            self.load(data_path)

    def load(self, data_path):
        # print("the data_path is:", data_path)
        data, ground_truth = convert_data(data_path)
        # already in global coordinate frame and start from start frame
        # timestamp (1D) gyroscope (3D), accelerometer (3D)
        gyro = data[:, 1:4]
        acce = data[:, 4:7]
        ts = data[:, 0]
        # tango position
        tango_pos = ground_truth[:, 1:4]
        # tango orientation
        tango_ori = ground_truth[:, 4:8]

        dt = (ts[self.w:] - ts[:-self.w])[:, None]
        # calculate the global velocity
        glob_v = (tango_pos[self.w:] - tango_pos[:-self.w]) / dt

        self.ts = ts
        self.features = np.concatenate([gyro, acce], axis = 1)
        # We only use the global velocity in the floor plane
        self.targets = glob_v[:, :2]
        self.orientations = tango_ori # quaternion.as_float_array(tango_ori)
        self.gt_pos = tango_pos

    def get_feature(self):
        return self.features

    def get_target(self):
        return self.targets

    def get_aux(self):
        return np.concatenate([self.ts[:, None], self.orientations, self.gt_pos], axis = 1)

class GlobSequenceTest():
    """
    Property: global coordinate frame
    """
    # add 3-axis magnetometer
    # feature_dim = 9
    feature_dim = 6
    aux_dim = 8

    def __init__(self, data_path = None, **kwargs):
        super().__init__()
        self.ts, self.features, self.targets, self.gt_pos = None, None, None, None
        # self.info = {}
        self.w = kwargs.get('interval', 1)
        if data_path is not None:
            self.load(data_path)

    def load(self, data_path):
        # print("the data_path is:", data_path)
        data = convert_data_test(data_path)
        # already in global coordinate frame and start from start frame
        # timestamp (1D) gyroscope (3D), accelerometer (3D)
        gyro = data[:, 1:4]
        acce = data[:, 4:7]
        ts = data[:, 0]

        dt = (ts[self.w:] - ts[:-self.w])[:, None]

        self.ts = ts
        self.features = np.concatenate([gyro, acce], axis = 1)

    def get_feature(self):
        return self.features

    def get_aux(self):
        return 0

def load_sequences(seq_type, root_dir, data_list, **kwargs):
    features_all, targets_all, aux_all = [], [], []

    for i in range(len(data_list)):
        seq = seq_type(osp.join(root_dir, data_list[i]), **kwargs)
        feat, targ, aux = seq.get_feature(), seq.get_target(), seq.get_aux()
        # add feat, targ, aux to list
        features_all.append(feat)
        targets_all.append(targ)
        aux_all.append(aux)
    return features_all, targets_all, aux_all

def load_sequences_test(seq_type, root_dir, data_list, **kwargs):
    features_all, aux_all = [], []

    for i in range(len(data_list)):
        seq = seq_type(osp.join(root_dir, data_list[i]), **kwargs)
        feat, aux = seq.get_feature(), seq.get_aux()
        # add feat, targ, aux to list
        features_all.append(feat)
        aux_all.append(aux)
    return features_all, aux_all

class SequenceToSequenceDataset(Dataset):
    def __init__(self, seq_type, root_dir, data_list, step_size = 100, window_size = 400,
                 random_shift = 0, transform = None, **kwargs):
        super(SequenceToSequenceDataset, self).__init__()
        self.seq_type = seq_type
        self.feature_dim = seq_type.feature_dim
        self.target_dim = seq_type.target_dim
        self.aux_dim = seq_type.aux_dim
        self.window_size = window_size
        self.step_size = step_size
        self.random_shift = random_shift
        self.transform = transform
        self.projection_width = kwargs.get('projection_width', 0)

        self.data_path = [osp.join(root_dir, data) for data in data_list]
        self.index_map = []

        self.features, self.targets, aux = load_sequences(
            seq_type, root_dir, data_list, **kwargs)

        # Optionally smooth the sequence
        feat_sigma = kwargs.get('feature_sigma,', -1)
        targ_sigma = kwargs.get('target_sigma,', -1)
        if feat_sigma > 0:
            self.features = [gaussian_filter1d(feat, sigma=feat_sigma, axis=0) for feat in self.features]
        if targ_sigma > 0:
            self.targets = [gaussian_filter1d(targ, sigma=targ_sigma, axis=0) for targ in self.targets]

        max_norm = 3.0 #
        self.ts, self.orientations, self.gt_pos, self.local_v = [], [], [], []
        for i in range(len(data_list)):
            self.features[i] = self.features[i][:-1]
            self.targets[i] = self.targets[i]
            self.ts.append(aux[i][:-1, :1])
            self.orientations.append(aux[i][:-1, 1:5])
            self.gt_pos.append(aux[i][:-1, 5:8])

            velocity = np.linalg.norm(self.targets[i], axis=1)  # Remove outlier ground truth data
            bad_data = velocity > max_norm
            for j in range(window_size + random_shift, self.targets[i].shape[0], step_size):
                if not bad_data[j - window_size - random_shift:j + random_shift].any():
                    self.index_map.append([i, j])

        # if shuffle is necessary here? As the training data should be in a logical sequence
        if kwargs.get('shuffle', True):
            random.shuffle(self.index_map)

    def __getitem__(self, item):
        # output format: input, target, seq_id, frame_id
        seq_id, frame_id = self.index_map[item][0], self.index_map[item][1]

        feat = np.copy(self.features[seq_id][frame_id - self.window_size:frame_id])
        targ = np.copy(self.targets[seq_id][frame_id - self.window_size:frame_id])
        # random rotate the sequence in the horizontal plane
        if self.transform is not None:
            feat, targ = self.transform(feat, targ)

            return feat.astype(np.float32), targ.astype(np.float32), seq_id, frame_id

    def __len__(self):
        return len(self.index_map)

    def get_lstm_test_seq(self):
        return np.array(self.features).astype(np.float32), np.array(self.targets).astype(np.float32)

class SequenceToSequenceDatasetTest(Dataset):
    def __init__(self, seq_type, root_dir, data_list, step_size = 100, window_size = 400,
                 random_shift = 0, transform = None, **kwargs):
        super(SequenceToSequenceDatasetTest, self).__init__()
        self.seq_type = seq_type
        self.feature_dim = seq_type.feature_dim
        self.aux_dim = seq_type.aux_dim
        self.window_size = window_size
        self.step_size = step_size
        self.random_shift = random_shift
        self.transform = transform
        self.projection_width = kwargs.get('projection_width', 0)

        self.data_path = [osp.join(root_dir, data) for data in data_list]
        self.index_map = []

        self.features, aux = load_sequences_test(
            seq_type, root_dir, data_list, **kwargs)

        # Optionally smooth the sequence
        feat_sigma = kwargs.get('feature_sigma,', -1)
        if feat_sigma > 0:
            self.features = [gaussian_filter1d(feat, sigma=feat_sigma, axis=0) for feat in self.features]

        max_norm = 3.0 #
        for i in range(len(data_list)):
            self.features[i] = self.features[i][:-1]

        # if shuffle is necessary here? As the training data should be in a logical sequence
        if kwargs.get('shuffle', True):
            random.shuffle(self.index_map)

    def __getitem__(self, item):
        # output format: input, target, seq_id, frame_id
        seq_id, frame_id = self.index_map[item][0], self.index_map[item][1]
        feat = np.copy(self.features[seq_id][frame_id - self.window_size:frame_id])

        return feat.astype(np.float32), seq_id, frame_id

    def __len__(self):
        return len(self.index_map)

    def get_lstm_test_seq(self):
        return np.array(self.features).astype(np.float32)

def change_cf(ori, vectors):
    """
    Euler-Rodrigous formula v'=v+2s(rxv)+2rx(rxv)
    :param ori: quaternion [n]x4
    :param vectors: vector nx3
    :return: rotated vector nx3
    """
    assert ori.shape[-1] == 4
    assert vectors.shape[-1] == 3

    if len(ori.shape) == 1:
        ori = ori.reshape(1, -1)

    q_s = ori[:, :1]
    q_r = ori[:, 1:]

    tmp = np.cross(q_r, vectors)
    vectors = np.add(np.add(vectors, np.multiply(2, np.multiply(q_s, tmp))), np.multiply(2, np.cross(q_r, tmp)))
    return vectors

class ComposeTransform:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, feat, targ, **kwargs):
        for t in self.transforms:
            feat, targ = t(feat, targ)
        return feat, targ

class RandomHoriRotateSeq:
    def __init__(self, input_format, output_format=None):
        """
        Rotate global input, global output by a random angle
        @:param input format - input feature vector(x,3) boundaries as array (E.g [0,3,6])
        @:param output format - output feature vector(x,2/3) boundaries as array (E.g [0,2,5])
                                if 2, 0 is appended as z.
        """
        self.i_f = input_format
        self.o_f = output_format

    def __call__(self, feature, target):
        a = np.random.random() * 2 * np.math.pi
        # print("Rotating by {} degrees", a/np.math.pi * 180)
        t = np.array([np.cos(a), 0, 0, np.sin(a)])

        for i in range(len(self.i_f) - 1):
            feature[:, self.i_f[i]: self.i_f[i + 1]] = \
                change_cf(t, feature[:, self.i_f[i]: self.i_f[i + 1]])

        for i in range(len(self.o_f) - 1):
            if self.o_f[i + 1] - self.o_f[i] == 3:
                # vector = target[:, self.o_f[i]: self.o_f[i + 1]]
                # target[:, self.o_f[i]: self.o_f[i + 1]] = change_cf(t, vector)
                vector = target[self.o_f[i]: self.o_f[i + 1]]
                target[:, self.o_f[i]: self.o_f[i + 1]] = change_cf(t, vector)
            elif self.o_f[i + 1] - self.o_f[i] == 2:
                vector = np.concatenate([target[:, self.o_f[i]: self.o_f[i + 1]], np.zeros([target.shape[0], 1])], axis=1)
                target[:, self.o_f[i]: self.o_f[i + 1]] = change_cf(t, vector)[:, :2]

        return feature.astype(np.float32), target.astype(np.float32)

class RandomHoriRotateSeqTensor:
    def __init__(self):
        """
        Rotate global input, global output by a random angle
        @:param input format - input feature vector(x,3) boundaries as array (E.g [0,3,6])
        @:param output format - output feature vector(x,2/3) boundaries as array (E.g [0,2,5])
                                if 2, 0 is appended as z.
        """

    def __call__(self, feature, target):
        # Tensor random rotation matrix
        a = torch.rand(1) * 2 * np.math.pi
        rotation_matrix_feat = torch.tensor([[torch.cos(a), torch.sin(a), 0, 0, 0, 0],
                                            [-torch.sin(a), torch.cos(a), 0, 0, 0, 0],
                                            [0, 0, 1, 0, 0, 0],
                                            [0, 0, 0, torch.cos(a), torch.sin(a), 0],
                                            [0, 0, 0, -torch.sin(a), torch.cos(a), 0],
                                            [0, 0, 0, 0, 0, 1]], dtype=torch.float32)

        rotation_matrix_targ = torch.tensor([[torch.cos(a), torch.sin(a)],
                                            [-torch.sin(a), torch.cos(a)]], dtype=torch.float32)

        # Matrix multiplication
        feature = torch.matmul(feature, rotation_matrix_feat)
        target = torch.matmul(target, rotation_matrix_targ)

        return feature, target

def get_dataset(root_dir, data_list, mode, **kwargs):
    # load config
    global_step_size = 0
    # input data includes: accelemeters, gyroscopes
    input_format = [0, 3, 6]
    # output data is the moving distance and its direction
    output_format = [0, 2]

    random_shift, shuffle, transforms = 0, False, []

    if mode == 'train':
        random_shift = global_step_size // 2
        shuffle = True
        transforms.append(RandomHoriRotateSeq(input_format, output_format))
        global_step_size = kwargs.get('step_size')
        transforms = ComposeTransform(transforms)
        seq_type = GlobSequence
        global_window_size = kwargs.get('window_size')
        dataset = SequenceToSequenceDataset(seq_type, root_dir, data_list, global_step_size,
                                            global_window_size, random_shift = random_shift,
                                            transform = transforms, shuffle = shuffle)
    elif mode == 'val':
        shuffle = True
        global_step_size = kwargs.get('step_size')
        transforms = ComposeTransform(transforms)
        seq_type = GlobSequence
        global_window_size = kwargs.get('window_size')
        dataset = SequenceToSequenceDataset(seq_type, root_dir, data_list, global_step_size,
                                            global_window_size, random_shift = random_shift,
                                            transform = transforms, shuffle = shuffle)
    elif mode == 'val_test':
        shuffle = False
        global_step_size = kwargs.get('test_step_size')
        transforms = ComposeTransform(transforms)
        seq_type = GlobSequence
        global_window_size = kwargs.get('window_size')
        dataset = SequenceToSequenceDataset(seq_type, root_dir, data_list, global_step_size,
                                            global_window_size, random_shift = random_shift,
                                            transform = transforms, shuffle = shuffle)
    elif mode == 'test':
        shuffle = False
        global_step_size = kwargs.get('test_step_size')
        transforms = ComposeTransform(transforms)
        seq_type = GlobSequenceTest
        global_window_size = kwargs.get('window_size')
        dataset = SequenceToSequenceDatasetTest(seq_type, root_dir, data_list, global_step_size,
                                                global_window_size, random_shift = random_shift,
                                                transform = transforms, shuffle = shuffle)

    return dataset

def read_dir(dir_path):
    # read dirs from dir_path
    for _, dirs, _ in os.walk(dir_path):
        return dirs

def get_train_dataset(root_dir, **kwargs):
    trainlist = read_dir(root_dir)
    return get_dataset(root_dir, trainlist, mode = 'train', **kwargs)

def get_valid_dataset(root_dir, **kwargs):
    validlist = read_dir(root_dir)
    return get_dataset(root_dir, validlist, mode = 'val', **kwargs)

def get_valid_test_dataset(root_dir, dir, **kwargs):
    return get_dataset(root_dir, dir, mode = 'val_test', **kwargs)

def get_test_dataset(root_dir, dir, **kwargs):
    return get_dataset(root_dir, dir, mode = 'test', **kwargs)

### 4. Criterion

In [25]:
import json
import os
import sys
import time
import random
import argparse
from os import path as osp
from pathlib import Path

import numpy as np
import torch

class GlobalPosLoss(torch.nn.Module):
    def __init__(self):
        """
        Calculate position loss in global coordinate frame
        Target :- Global Velocity
        Prediction :- Global Velocity
        """
        super(GlobalPosLoss, self).__init__()
        self.mse_loss = torch.nn.MSELoss(reduction = 'none')

    def forward(self, pred, targ):
        # dts = 1 / 200
        dts = 1
        pred = pred * dts
        targ = targ * dts
        gt_pos = torch.cumsum(targ[:, 1:, ], 1)
        pred_pos = torch.cumsum(pred[:, 1:, ], 1)
        loss = self.mse_loss(pred_pos, gt_pos)
        # calculate the sum of absolute trajectory error
        return torch.mean(loss)

class MSEAverage():
    def __init__(self):
        self.count = 0
        self.targets = []
        self.predictions = []
        self.average = []

    def add(self, pred, targ):
        self.targets.append(targ)
        self.predictions.append(pred)
        self.average.append(np.average((pred - targ) ** 2, axis=(0, 1)))
        # print("The shape of average is: ", np.array(self.average).shape)
        # print("THe shape of np.average(np.array(self.average), axis=0) is: ", np.average(np.array(self.average), axis=0).shape)
        self.count += 1

    def get_channel_avg(self):
        average = np.average(np.array(self.average), axis=0)
        return average

    def get_total_avg(self):
        average = np.average(np.array(self.average), axis=0)
        return np.average(average)

    def get_elements(self, axis):
        return np.concatenate(self.predictions, axis=axis), np.concatenate(self.targets, axis=axis)

def reconstruct_traj(vector, **kwargs):
    global_sampling_rate = kwargs.get('sampling_rate', None)
    # reconstruct the vector to one sequence
    # velocity_sequence = vector.reshape(len(vector) * global_window_size, global_output_channel)

    velocity_sequence = vector * 1 / global_sampling_rate
    glob_pos = np.cumsum(velocity_sequence, axis = 0)

    return glob_pos

def compute_absolute_trajectory_error(pred, gt):
    """
    The Absolute Trajectory Error (ATE) defined in:
    A Benchmark for the evaluation of RGB-D SLAM Systems
    http://ais.informatik.uni-freiburg.de/publications/papers/sturm12iros.pdf

    Args:
        est: estimated trajectory
        gt: ground truth trajectory. It must have the same shape as est.

    Return:
        Absolution trajectory error, which is the Root Mean Squared Error between
        two trajectories.
    """
    return np.sqrt(np.mean((pred - gt) ** 2))


def compute_relative_trajectory_error(est, gt, delta, max_delta = -1):
    """
    The Relative Trajectory Error (RTE) defined in:
    A Benchmark for the evaluation of RGB-D SLAM Systems
    http://ais.informatik.uni-freiburg.de/publications/papers/sturm12iros.pdf

    Args:
        est: the estimated trajectory
        gt: the ground truth trajectory.
        delta: fixed window size. If set to -1, the average of all RTE up to max_delta will be computed.
        max_delta: maximum delta. If -1 is provided, it will be set to the length of trajectories.

    Returns:
        Relative trajectory error. This is the mean value under different delta.
    """
    if max_delta == -1:
        max_delta = est.shape[0]
    # print("delta: ", delta)
    deltas = np.array([min(delta, max_delta - 1)])
    # deltas = np.array([delta]) if delta > 0 else np.arange(1, min(est.shape[0], max_delta))
    rtes = np.zeros(deltas.shape[0])
    for i in range(deltas.shape[0]):
        # For each delta, the RTE is computed as the RMSE of endpoint drifts from fixed windows
        # slided through the trajectory.
        err = est[deltas[i]:] + gt[:-deltas[i]] - est[:-deltas[i]] - gt[deltas[i]:]
        rtes[i] = np.sqrt(np.mean(err ** 2))

    # The average of RTE of all window sized is returned.
    rtes = rtes[~np.isnan(rtes)]
    return np.mean(rtes)

def compute_position_drift_error(pos_pred, pos_gt):
    """
    Params:
        pos_pred: predicted position [seq_len, 2]
        pos_gt: ground truth position [seq_len, 2]
    """
    position_drift = np.linalg.norm((pos_gt[-1] - pos_pred[-1]))
    delta_position = pos_gt[1:] - pos_gt[:-1]
    delta_length = np.linalg.norm(delta_position, axis=1)
    moving_len = np.sum(delta_length)

    return position_drift / moving_len

def compute_distance_error(pos_pred, pos_gt):
    """
    Params:
        pos_pred: predicted position [seq_len, 2]
        pos_gt: ground truth position [seq_len, 2]
    """
    distance_error = np.linalg.norm((pos_gt - pos_pred), axis=1)

    return distance_error

def compute_heading_error(preds, targets):
    """
    Params:
        pos_pred: predicted position [seq_len, 2]
        pos_gt: ground truth position [seq_len, 2]
    """
    # Find the index of preds with zero norm
    zero_norm_index = np.where(np.linalg.norm(preds, axis=1) == 0)[0]
    # Remove the zero norm index
    preds = np.delete(preds, zero_norm_index, axis=0)
    targets = np.delete(targets, zero_norm_index, axis=0)
    # Find the index of targets with zero norm
    zero_norm_index = np.where(np.linalg.norm(targets, axis=1) == 0)[0]
    # Remove the zero norm index
    preds = np.delete(preds, zero_norm_index, axis=0)
    targets = np.delete(targets, zero_norm_index, axis=0)

    pred_v = np.linalg.norm(preds, axis=1)
    targ_v = np.linalg.norm(targets, axis=1)

    pred_o = preds / pred_v[:, np.newaxis]
    targ_o = targets / targ_v[:, np.newaxis]

    # calculate the heading angle of the predicted and target vectors
    pred_heading = np.arctan2(pred_o[:, 1], pred_o[:, 0])
    targ_heading = np.arctan2(targ_o[:, 1], targ_o[:, 0])

    # calculate the heading error
    heading_error = np.mean(np.abs(pred_heading - targ_heading))
    # convert to degrees
    heading_error = heading_error * 180 / np.pi

    return heading_error

### 5. Utils

In [47]:
import os
import numpy as np
import matplotlib.pyplot as plt
import os.path as osp

def format_string(*argv, sep=' '):
    result = ''
    for val in argv:
        if isinstance(val, (tuple, list, np.ndarray)):
            for v in val:
                result += format_string(v, sep=sep) + sep
        else:
            result += str(val) + sep
    return result[:-1]

def draw_trajectory(pos_pred, pos_gt, dir_name, ate, rte, **kwargs):
    """
    :param data:
    :pos_pred: (N, 2)
    :pos_gt: (N, 2)
    :dir_name: test directory
    :ate: average trajectory error
    :rte: relative trajectory error
    """
    global_out_dir = kwargs.get('out_dir', None)
    if global_out_dir is None:
        raise ValueError('out_dir is needed')

    plt.figure(figsize=(8, 5), dpi = 400)
    plt.plot(pos_pred[:, 0], pos_pred[:, 1], label = 'Predicted')
    plt.plot(pos_gt[:, 0], pos_gt[:, 1], label = 'Ground truth')
    plt.title(dir_name)
    print("make title success")
    # Show words in latex format
    plt.xlabel('$m$')
    plt.ylabel('$m$')
    plt.axis('equal')
    plt.legend()
    plt.title('ATE:{:.3f}, RTE:{:.3f}'.format(ate, rte), y = 0, loc = 'right')
    plt.show()
    plt.savefig(osp.join(global_out_dir, '{}.png'.format(dir_name)))

pos_pred= (N, 2)
pos_gt= (N, 2)
dir_name= '/kaggle/working/test_results'
draw_trajectory(pos_pred, pos_gt, dir_name, ate, rte, out_dir='output_directory')


NameError: name 'N' is not defined

### 6. Main Function

In [27]:
import os
import time
import argparse
from os import path as osp
from pathlib import Path
from tqdm import tqdm
import numpy as np
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

def get_model(mode, **kwargs):
    global_input_channel = kwargs.get('input_channel')
    global_output_channel = kwargs.get('output_channel')
    global_dropout = kwargs.get('dropout')
    global_batch_size = kwargs.get('batch_size')
    global_test_batch_size = kwargs.get('test_batch_size')
    global_device = kwargs.get('device')
    global_layers = kwargs.get('layers')
    global_layer_size = kwargs.get('layer_size')

    if mode == 'train':
        print("LSTM model")
        network = BilinearLSTMSeqNetwork(global_input_channel, global_output_channel, global_batch_size, global_device,
                        lstm_size = global_layer_size, lstm_layers = global_layers, dropout = global_dropout).to(global_device)
    elif mode == 'test':
        print("LSTM model")
        network = BilinearLSTMSeqNetwork(global_input_channel, global_output_channel, global_test_batch_size, global_device,
                        lstm_size = global_layer_size, lstm_layers = global_layers, dropout = global_dropout).to(global_device)
    try:
        pytorch_total_params = sum(p.numel() for p in network.parameters() if p.requires_grad)
    except:
        pytorch_total_params = 0
    print('Network constructed. trainable parameters: {}'.format(pytorch_total_params))
    return network

def train(**kwargs):
    # load config
    global_data_dir = kwargs.get('data_dir')
    global_val_data_dir = kwargs.get('val_data_dir')
    global_batch_size = kwargs.get('batch_size')
    global_epochs = kwargs.get('epochs')
    global_num_workers = kwargs.get('num_workers')
    global_device = kwargs.get('device')
    global_out_dir = kwargs.get('out_dir', None)
    global_learning_rate = kwargs.get('learning_rate')
    global_save_interval = kwargs.get('save_interval')
    global_sampling_rate = kwargs.get('sampling_rate')
    # Loading data
    start_t = time.time()
    train_dataset = get_train_dataset(global_data_dir, **kwargs)
    val_dataset = get_valid_dataset(global_val_data_dir, **kwargs)
    train_loader = DataLoader(train_dataset, batch_size = global_batch_size, num_workers = global_num_workers, shuffle = True,
                              drop_last = True)
    val_loader = DataLoader(val_dataset, batch_size = global_batch_size, shuffle = True, drop_last = True)
    end_t = time.time()
    print('Training and validation set loaded. Time usage: {:.3f}s'.format(end_t - start_t))
    # read val for sequence test
    test_dirs = read_dir(global_val_data_dir)



    global device
    device = torch.device(global_device if torch.cuda.is_available() else 'cpu')
    print("Device: {}".format(device))

    if global_out_dir:
        if not osp.isdir(global_out_dir):
            os.makedirs(global_out_dir)
        if not osp.isdir(osp.join(global_out_dir, 'checkpoints')):
            os.makedirs(osp.join(global_out_dir, 'checkpoints'))

    print('\nNumber of train samples: {}'.format(len(train_dataset)))
    train_mini_batches = len(train_loader)
    if val_dataset:
        print('Number of val samples: {}'.format(len(val_dataset)))
        val_mini_batches = len(val_loader)

    network = get_model('train', **kwargs).to(device)
    testnetwork = get_model('test', **kwargs).to(device)
    criterion = GlobalPosLoss()

    optimizer = torch.optim.Adam(network.parameters(), global_learning_rate)
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience = 10, factor = 0.75, verbose = True, eps = 1e-12)
    quiet_mode = kwargs.get('quiet', False)
    use_scheduler = kwargs.get('use_scheduler', True)

    start_epoch = 0
    step = 0
    best_val_loss = np.inf
    train_errs = np.zeros(global_epochs)

    print("Starting from epoch {}".format(start_epoch))
    try:
        for epoch in range(start_epoch, global_epochs):
            log_line = ''
            network.train()
            train_vel = MSEAverage()
            train_loss = 0
            start_t = time.time()

            for bid, batch in enumerate(tqdm(train_loader)):
                feat, targ, _, _ = batch
                feat, targ = feat.to(device), targ.to(device)
                optimizer.zero_grad()
                predicted = network(feat)
                train_vel.add(predicted.cpu().detach().numpy(), targ.cpu().detach().numpy())
                loss = criterion(predicted, targ)
                train_loss += loss.cpu().detach().numpy()
                loss.backward()
                optimizer.step()
                step += 1

            train_errs[epoch] = train_loss / train_mini_batches
            end_t = time.time()
            if not quiet_mode:
                print('-' * 25)
                print('Epoch {}, time usage: {:.3f}s, loss: {}, vec_loss {}/{:.6f}'.format(
                    epoch, end_t - start_t, train_errs[epoch], train_vel.get_channel_avg(), train_vel.get_total_avg()))

            saved_model = False
            if val_loader:
                network.eval()
                val_vel = MSEAverage()
                val_loss = 0
                for bid, batch in enumerate(val_loader):
                    feat, targ, _, _ = batch
                    feat, targ = feat.to(device), targ.to(device)
                    optimizer.zero_grad()
                    pred = network(feat)
                    val_vel.add(pred.cpu().detach().numpy(), targ.cpu().detach().numpy())
                    val_loss += criterion(pred, targ).cpu().detach().numpy()
                val_loss = val_loss / val_mini_batches

                if not quiet_mode:
                    print('Validation loss: {} vec_loss: {}/{:.6f}'.format(val_loss, val_vel.get_channel_avg(),
                                                                                val_vel.get_total_avg()))

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    saved_model = True
                    if global_out_dir:
                        model_path = osp.join(global_out_dir, 'checkpoints', 'checkpoint_%d.pt' % epoch)
                        torch.save({'model_state_dict': network.state_dict(),
                                    'epoch': epoch,
                                    'loss': train_errs[epoch],
                                    'optimizer_state_dict': optimizer.state_dict()}, model_path)
                        print('Best Validation Model saved to ' + model_path)
                if use_scheduler:
                    scheduler.step(val_loss)

            print("Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE")
            testnetwork.load_state_dict(network.state_dict())
            network.eval().to(device)
            ate_all, rte_all, pde_all = [], [], []
            aye_all = []
            # Every minute
            pred_per_min = global_sampling_rate * 60 # 2 + 0.5*(x - 1) = 60
            # Test for every sequence
            for i in range(len(test_dirs)):
            # for i in range(1):
                seq_dir = [test_dirs[i]]
                seq_dataset = get_valid_test_dataset(global_val_data_dir, seq_dir, **kwargs)
                feat, targ = seq_dataset.get_lstm_test_seq()

                feat = torch.from_numpy(feat).to(device)
                pred = testnetwork(feat).cpu().detach().numpy()

                pred = np.squeeze(pred, axis = 0)
                targ = np.squeeze(targ, axis = 0)
                # Reconstruct the trajectory
                pos_pred = reconstruct_traj(pred, **kwargs)
                pos_gt = reconstruct_traj(targ, **kwargs)

                # Compute the ATE and RTE
                ate = compute_absolute_trajectory_error(pos_pred, pos_gt)
                rte = compute_relative_trajectory_error(pos_pred, pos_gt, delta = pred_per_min)
                pde = compute_position_drift_error(pos_pred, pos_gt)
                heading_error = compute_heading_error(pred, targ)
                ate_all.append(ate)
                if rte >= 0:
                    rte_all.append(rte)
                pde_all.append(pde)
                aye_all.append(heading_error)

            ate_all = np.array(ate_all)
            rte_all = np.array(rte_all)
            pde_all = np.array(pde_all)
            aye_all = np.array(aye_all)

            measure = format_string('ATE', 'RTE', 'PDE', 'AYE',  sep = '\t')
            values = format_string(np.mean(ate_all), np.mean(rte_all), np.mean(pde_all), np.mean(aye_all), sep = '\t')
            print(measure + '\n' + values)

            if global_out_dir and not saved_model and (epoch + 1) % global_save_interval == 0:  # save even with validation
                model_path = osp.join(global_out_dir, 'checkpoints', 'icheckpoint_%d.pt' % epoch)
                torch.save({'ate': np .mean(ate_all),
                            'rte': np.mean(rte_all),
                            'pde': np.mean(pde_all),
                            'aye': np.mean(aye_all),}, model_path)
                torch.save({'model_state_dict': network.state_dict(),
                            'epoch': epoch,
                            'loss': train_errs[epoch],
                            'optimizer_state_dict': optimizer.state_dict()}, model_path)
                print('Model saved to ' + model_path)

            if np.isnan(train_loss):
                print("Invalid value. Stopping training.")
                break
    except KeyboardInterrupt:
        print('-' * 60)
        print('Early terminate')

    print('Training completed')
    if global_out_dir:
        model_path = osp.join(global_out_dir, 'checkpoints', 'checkpoint_latest.pt')
        torch.save({'model_state_dict': network.state_dict(),
                    'epoch': epoch,
                    'optimizer_state_dict': optimizer.state_dict()}, model_path)

# Test and Save the Trajectory for Both Seen and Unseen Data
seen_unseen_dataset = {'id1': 'tracermini_hw101_test20230311112635T',
                'id2': 'tracermini_hw101_test20230311111507T',
                'id3': 'tracermini_hw101_test20230311112235T',
                'id4': 'tracermini_hw101_test20230313011357T',
                'id5': 'tracermini_hw101_test20230311111842T',
                'id6': 'tracermini_hw101_test20230313010954T',
                'id7': 'tracermini_hw101_test20230313010546T',
                'id8': 'tracermini_hw101_test20230313011731T',
                'id9': 'tracermini_hw101_test20230313013335T',
                'id10': 'tracermini_hw101_test20230311111027T',
                'id11': 'tracermini_hw101_test20230313012956T',
                'id12': 'tracermini_hw101_test20230313010204T',
                'id13': 'tracermini_unseen_hw520230314091844T',
                'id14': 'tracermini_unseen_cym20230314101001T',
                'id15': 'tracermini_unseen_hw520230314082319T',
                'id16': 'tracermini_unseen_cym20230314101636T',
                'id17': 'tracermini_unseen_hw520230314083031T',
                'id18': 'tracermini_unseen_hw520230314091110T',
                'id19': 'tracermini_unseen_cym20230314100816T',
                'id20': 'tracermini_unseen_cym20230314101230T',
                'id21': 'tracermini_unseen_cym20230314100559T',
                'id22': 'tracermini_unseen_hw520230314081212T',
                'id23': 'tracermini_unseen_hw520230314082000T',
                'id24': 'tracermini_unseen_hw520230314091603T',
                'id25': 'tracermini_unseen_cym20230314100325T',
                'id26': 'tracermini_unseen_cym20230314101434T',
                'id27': 'tracermini_unseen_cym20230314100103T',
                'id28': 'tracermini_unseen_hw520230314091338T',
                'id29': 'tracermini_unseen_cym20230314101010T',
                'id30': 'tracermini_unseen_cym20230314101850T',
                'id31': 'tracermini_unseen_hw520230314081542T',
                'id32': 'tracermini_unseen_hw520230314090739T'}
seen_unseen_pred = {'id1': [],
            'id2': [],
            'id3': [],
            'id4': [],
            'id5': [],
            'id6': [],
            'id7': [],
            'id8': [],
            'id9': [],
            'id10': [],
            'id11': [],
            'id12': [],
            'id13': [],
            'id14': [],
            'id15': [],
            'id16': [],
            'id17': [],
            'id18': [],
            'id19': [],
            'id20': [],
            'id21': [],
            'id22': [],
            'id23': [],
            'id24': [],
            'id25': [],
            'id26': [],
            'id27': [],
            'id28': [],
            'id29': [],
            'id30': [],
            'id31': [],
            'id32': []}
def test_lstm(**kwargs):
    # load config
    global_dataset = kwargs.get('dataset')
    global_model_type = kwargs.get('model_type')
    global_num_workers = kwargs.get('num_workers')
    global_sampling_rate = kwargs.get('sampling_rate')
    global_device = kwargs.get('device')
    global_out_dir = kwargs.get('out_dir', None)
    global_test_dir = kwargs.get('test_dir', None)
    global_out_dir = kwargs.get('out_dir', None)
    global_model_path = kwargs.get('model_path', None)

    global device
    device = torch.device(global_device if torch.cuda.is_available() else 'cpu')

    if global_test_dir is None:
        raise ValueError('Test_path is needed.')

    # read dirs
    test_dirs = read_dir(global_test_dir)

    # Make sure the test output dir exists
    if global_out_dir and not osp.exists(global_out_dir):
        os.makedirs(global_out_dir)
    # Load the model config
    if global_model_path is None:
        raise ValueError('Model path is needed.')

    checkpoint = torch.load(global_model_path, map_location=global_device)

    network = get_model('test', **kwargs)
    network.load_state_dict(checkpoint.get('model_state_dict'))
    # network.load_state_dict(checkpoint.get('model_state_dict'))
    print("The model is loaded.")
    network.eval().to(device)
    print('Model {} loaded to device {}.'.format(global_model_path, device))

    # Test for every sequence
    for i in range(len(test_dirs)):
    # for i in range(1):
        seq_dir = [test_dirs[i]]
        seq_dataset = get_test_dataset(global_test_dir, seq_dir, **kwargs)
        feat = seq_dataset.get_lstm_test_seq()

        feat = torch.from_numpy(feat).to(device)
        pred = network(feat).cpu().detach().numpy()

        pred = np.squeeze(pred, axis = 0)

        # Reconstruct the trajectory
        print("Reconstruct the {}".format(seq_dir[0]))
        pos_pred = reconstruct_traj(pred, **kwargs)

        # Find the index of the sequence
        for key, value in seen_unseen_dataset.items():
            if value == seq_dir[0]:
                index = key
                seen_unseen_pred[index] = pos_pred

        # # Make directory
        # if global_out_dir and not osp.exists(osp.join(global_out_dir, seq_dir[0])):
        #     os.makedirs(osp.join(global_out_dir, seq_dir[0]))
        # # Save the trajectory
        # np.save(osp.join(global_out_dir, seq_dir[0], 'pred.npy'), pos_pred)

def test(**kwargs):
    # Model type
    print("Testing LSTM model")
    test_lstm(**kwargs)

### 7. Training

In [28]:
# Load config settings
kwargs = load_config()

import warnings
# Suspend warnings
warnings.filterwarnings('ignore')
train(**kwargs)

Training and validation set loaded. Time usage: 12.639s
Device: cuda:0

Number of train samples: 18634
Number of val samples: 3270
LSTM model
Network constructed. trainable parameters: 216620
LSTM model
Network constructed. trainable parameters: 216620
Starting from epoch 0


100%|██████████| 258/258 [00:10<00:00, 24.81it/s]


-------------------------
Epoch 0, time usage: 10.402s, loss: 6231.685725722202, vec_loss [0.27687997 0.34201962]/0.309450
Validation loss: 4251.38990342882 vec_loss: [0.31316587 0.35118878]/0.332177
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_0.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
17.489653	9.011737216602672	0.32183152	119.31425731404524


100%|██████████| 258/258 [00:10<00:00, 25.17it/s]


-------------------------
Epoch 1, time usage: 10.253s, loss: 3582.7976869095205, vec_loss [0.27719176 0.48369846]/0.380445
Validation loss: 3534.950244140625 vec_loss: [0.25240335 0.4703616 ]/0.361382
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_1.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
22.017792	10.387772560119629	0.39470088	111.90974423655649


100%|██████████| 258/258 [00:10<00:00, 25.09it/s]


-------------------------
Epoch 2, time usage: 10.288s, loss: 3320.765686508297, vec_loss [0.23451698 0.52801895]/0.381268
Validation loss: 2861.336691623264 vec_loss: [0.19581926 0.45754877]/0.326684
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_2.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
30.567223	13.011122876947576	0.56549996	117.34427428398895


100%|██████████| 258/258 [00:10<00:00, 25.45it/s]


-------------------------
Epoch 3, time usage: 10.142s, loss: 3113.277181935865, vec_loss [0.1975892  0.45087337]/0.324231
Validation loss: 2859.6264377170137 vec_loss: [0.19801208 0.3491679 ]/0.273590
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_3.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
18.921253	9.464521278034557	0.3405883	110.75101568331591


100%|██████████| 258/258 [00:10<00:00, 24.82it/s]


-------------------------
Epoch 4, time usage: 10.399s, loss: 2933.0286023043845, vec_loss [0.1858126  0.37985796]/0.282835
Validation loss: 2687.579931640625 vec_loss: [0.16014013 0.3421206 ]/0.251130
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_4.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
13.293316	7.9177279038862745	0.23081931	104.24184912985146


100%|██████████| 258/258 [00:10<00:00, 25.41it/s]


-------------------------
Epoch 5, time usage: 10.158s, loss: 2813.747224556383, vec_loss [0.17310004 0.3417965 ]/0.257448
Validation loss: 2549.1943739149306 vec_loss: [0.15817368 0.29140264]/0.224788
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_5.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
26.192657	10.699111765081232	0.49024618	110.35610283847382


100%|██████████| 258/258 [00:10<00:00, 25.27it/s]


-------------------------
Epoch 6, time usage: 10.213s, loss: 2766.5345236608223, vec_loss [0.17444341 0.30622843]/0.240336
Validation loss: 3139.0931206597224 vec_loss: [0.17816557 0.2819229 ]/0.230044
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
12.99671	5.722079428759488	0.22365032	90.17810527542284


100%|██████████| 258/258 [00:10<00:00, 25.38it/s]


-------------------------
Epoch 7, time usage: 10.170s, loss: 2576.9627598015836, vec_loss [0.17031042 0.28298017]/0.226645
Validation loss: 2205.071511501736 vec_loss: [0.14933573 0.24721536]/0.198276
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_7.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
14.3639345	6.188002933155406	0.25182846	97.46343155189487


100%|██████████| 258/258 [00:10<00:00, 25.35it/s]


-------------------------
Epoch 8, time usage: 10.180s, loss: 2545.132871406023, vec_loss [0.16281165 0.26769388]/0.215253
Validation loss: 2830.7053955078127 vec_loss: [0.14895198 0.26608032]/0.207516
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
27.857359	9.94462758844549	0.5013234	99.93217983219701


100%|██████████| 258/258 [00:10<00:00, 25.06it/s]


-------------------------
Epoch 9, time usage: 10.299s, loss: 2504.1118821728137, vec_loss [0.16067211 0.27177548]/0.216224
Validation loss: 2261.465131293403 vec_loss: [0.15206492 0.24580202]/0.198933
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
8.808211	4.827901406721636	0.14279361	84.76305492750329


100%|██████████| 258/258 [00:10<00:00, 25.38it/s]


-------------------------
Epoch 10, time usage: 10.169s, loss: 2342.183153729106, vec_loss [0.15412046 0.27070004]/0.212410
Validation loss: 2055.9482123480902 vec_loss: [0.13877141 0.242378  ]/0.190575
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_10.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.130753	3.5152304606004194	0.11507973	87.20398446824228


100%|██████████| 258/258 [00:10<00:00, 25.38it/s]


-------------------------
Epoch 11, time usage: 10.166s, loss: 2347.016940567845, vec_loss [0.1594503  0.26998118]/0.214716
Validation loss: 1954.4348090277779 vec_loss: [0.15052205 0.24209578]/0.196309
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_11.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
15.965393	6.067760944366455	0.29275963	87.80159186967255


100%|██████████| 258/258 [00:10<00:00, 25.19it/s]


-------------------------
Epoch 12, time usage: 10.246s, loss: 2333.817286055217, vec_loss [0.16526112 0.26277223]/0.214017
Validation loss: 2093.711764865451 vec_loss: [0.14294259 0.2539258 ]/0.198434
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
17.446156	6.488318161530928	0.32412165	87.8327319439512


100%|██████████| 258/258 [00:10<00:00, 25.43it/s]


-------------------------
Epoch 13, time usage: 10.150s, loss: 2267.5608572553297, vec_loss [0.15428352 0.2885424 ]/0.221413
Validation loss: 1862.6039984809029 vec_loss: [0.13464794 0.24830036]/0.191474
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_13.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.9646077	3.63158934766596	0.11741492	84.43519839622722


100%|██████████| 258/258 [00:10<00:00, 25.28it/s]


-------------------------
Epoch 14, time usage: 10.208s, loss: 2214.1620835888293, vec_loss [0.15802737 0.28112257]/0.219575
Validation loss: 2108.45090874566 vec_loss: [0.14920686 0.2661271 ]/0.207667
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
12.734095	5.0034480528398	0.23645209	89.79966808871139


100%|██████████| 258/258 [00:10<00:00, 25.46it/s]


-------------------------
Epoch 15, time usage: 10.137s, loss: 2255.0547416746153, vec_loss [0.15922491 0.2770152 ]/0.218120
Validation loss: 1902.0355170355904 vec_loss: [0.16236146 0.25241357]/0.207388
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
10.052685	4.2975616021589795	0.18149933	84.75774662866561


100%|██████████| 258/258 [00:10<00:00, 25.38it/s]


-------------------------
Epoch 16, time usage: 10.171s, loss: 2120.6286453128787, vec_loss [0.16191278 0.273125  ]/0.217519
Validation loss: 1722.67291531033 vec_loss: [0.1344919  0.24901205]/0.191752
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_16.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.3602366	3.58837628364563	0.12806179	84.18145624763496


100%|██████████| 258/258 [00:10<00:00, 25.44it/s]


-------------------------
Epoch 17, time usage: 10.146s, loss: 2143.1654996650163, vec_loss [0.15195525 0.271295  ]/0.211625
Validation loss: 2141.5477105034724 vec_loss: [0.1629073  0.25119045]/0.207049
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
12.975432	5.2561657645485615	0.24864879	78.85006702408003


100%|██████████| 258/258 [00:10<00:00, 25.32it/s]


-------------------------
Epoch 18, time usage: 10.192s, loss: 2125.1647215850594, vec_loss [0.15716992 0.2762029 ]/0.216686
Validation loss: 2116.2652777777776 vec_loss: [0.1604551 0.2776858]/0.219070
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
11.282751	4.75612245906483	0.21195358	84.02781300503437


100%|██████████| 258/258 [00:10<00:00, 24.93it/s]


-------------------------
Epoch 19, time usage: 10.351s, loss: 2102.9707177923633, vec_loss [0.15698385 0.27664474]/0.216814
Validation loss: 1894.3981119791667 vec_loss: [0.15078877 0.23709299]/0.193941
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.6665893	3.8909741314974697	0.14211892	85.65736077906861
Model saved to /kaggle/working/prediction_model/checkpoints/icheckpoint_19.pt


100%|██████████| 258/258 [00:10<00:00, 25.43it/s]


-------------------------
Epoch 20, time usage: 10.147s, loss: 2039.4086613618126, vec_loss [0.15268712 0.2724987 ]/0.212593
Validation loss: 1879.3773003472222 vec_loss: [0.1365245  0.27383557]/0.205180
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
14.13045	5.4432831460779365	0.26949927	82.80634481778317


100%|██████████| 258/258 [00:10<00:00, 25.34it/s]


-------------------------
Epoch 21, time usage: 10.185s, loss: 2040.564459985541, vec_loss [0.15552582 0.26762322]/0.211575
Validation loss: 1765.5643174913193 vec_loss: [0.1507226 0.2192913]/0.185007
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.6903896	3.962141838940707	0.14006573	80.16787006216394


100%|██████████| 258/258 [00:10<00:00, 25.37it/s]


-------------------------
Epoch 22, time usage: 10.171s, loss: 1950.0990288313044, vec_loss [0.15690252 0.24759007]/0.202246
Validation loss: 1724.9018798828124 vec_loss: [0.15081392 0.24070464]/0.195759
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
8.954512	4.100091197273948	0.16675091	82.90342726386689


100%|██████████| 258/258 [00:10<00:00, 25.35it/s]


-------------------------
Epoch 23, time usage: 10.182s, loss: 1972.5088780129602, vec_loss [0.1521423  0.26027805]/0.206210
Validation loss: 1603.8458713107639 vec_loss: [0.1389426  0.21287791]/0.175910
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_23.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.315404	3.5889184149828823	0.14215566	74.6038477001098


100%|██████████| 258/258 [00:10<00:00, 25.36it/s]


-------------------------
Epoch 24, time usage: 10.177s, loss: 1977.21106430172, vec_loss [0.1460762 0.2583184]/0.202197
Validation loss: 1726.7187269422743 vec_loss: [0.12555619 0.24113734]/0.183347
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
11.320698	4.83999549258839	0.21094304	80.38897384134415


100%|██████████| 258/258 [00:10<00:00, 25.42it/s]


-------------------------
Epoch 25, time usage: 10.151s, loss: 1918.0672027824462, vec_loss [0.14619517 0.25619122]/0.201193
Validation loss: 1716.0901638454861 vec_loss: [0.12482962 0.22650515]/0.175667
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
8.150663	3.91960072517395	0.13917781	82.52777803618937


100%|██████████| 258/258 [00:10<00:00, 24.98it/s]


-------------------------
Epoch 26, time usage: 10.331s, loss: 1980.8824507838997, vec_loss [0.1430872  0.25139043]/0.197239
Validation loss: 1535.6253662109375 vec_loss: [0.12895124 0.23462638]/0.181789
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_26.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
9.748684	4.366967504674738	0.15457183	83.19576681145888


100%|██████████| 258/258 [00:10<00:00, 25.42it/s]


-------------------------
Epoch 27, time usage: 10.152s, loss: 1857.0992043665212, vec_loss [0.14491546 0.23550135]/0.190208
Validation loss: 1591.946149359809 vec_loss: [0.14108743 0.22702283]/0.184055
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.6776967	3.1733072129162876	0.11237728	78.10074621895285


100%|██████████| 258/258 [00:10<00:00, 25.20it/s]


-------------------------
Epoch 28, time usage: 10.240s, loss: 1885.7590824097626, vec_loss [0.13991734 0.24214242]/0.191030
Validation loss: 1565.345145670573 vec_loss: [0.13461453 0.21964084]/0.177128
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
8.807873	3.9474116455424917	0.16741472	81.11560662082739


100%|██████████| 258/258 [00:10<00:00, 25.35it/s]


-------------------------
Epoch 29, time usage: 10.180s, loss: 1827.354367278343, vec_loss [0.13996984 0.23542102]/0.187695
Validation loss: 1641.2385226779513 vec_loss: [0.12283069 0.21702372]/0.169927
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.009633	3.7812120697715064	0.11509531	82.89676099923696


100%|██████████| 258/258 [00:10<00:00, 25.24it/s]


-------------------------
Epoch 30, time usage: 10.226s, loss: 1814.49172453178, vec_loss [0.13637702 0.24206413]/0.189221
Validation loss: 1367.8524441189236 vec_loss: [0.11693339 0.20876569]/0.162850
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_30.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.336218	3.141607956452803	0.10395086	78.98297443328673


100%|██████████| 258/258 [00:10<00:00, 25.20it/s]


-------------------------
Epoch 31, time usage: 10.240s, loss: 1855.7240503592084, vec_loss [0.14249353 0.2376228 ]/0.190058
Validation loss: 1705.452303059896 vec_loss: [0.13065536 0.22965753]/0.180156
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
13.034088	5.271572589874268	0.23054376	86.84139828781555


100%|██████████| 258/258 [00:10<00:00, 25.38it/s]


-------------------------
Epoch 32, time usage: 10.168s, loss: 1819.055328605711, vec_loss [0.14144604 0.23856287]/0.190004
Validation loss: 1428.4985785590277 vec_loss: [0.11899353 0.2130021 ]/0.165998
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.6169662	2.8891997770829634	0.077934176	76.7986600670754


100%|██████████| 258/258 [00:10<00:00, 25.15it/s]


-------------------------
Epoch 33, time usage: 10.262s, loss: 1745.1436956834423, vec_loss [0.13368562 0.2326111 ]/0.183148
Validation loss: 1365.1490926106771 vec_loss: [0.11237953 0.20370427]/0.158042
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_33.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
8.265734	3.830613136291504	0.13548703	81.21239723155564


100%|██████████| 258/258 [00:10<00:00, 25.52it/s]


-------------------------
Epoch 34, time usage: 10.113s, loss: 1734.4506536675979, vec_loss [0.14376865 0.23291843]/0.188344
Validation loss: 1497.184507921007 vec_loss: [0.13989143 0.21199428]/0.175943
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.605554	3.9643837972120806	0.13328968	79.5338273270694


100%|██████████| 258/258 [00:10<00:00, 24.84it/s]


-------------------------
Epoch 35, time usage: 10.390s, loss: 1715.8710180474807, vec_loss [0.14026648 0.21909021]/0.179678
Validation loss: 1681.5441338433159 vec_loss: [0.12816541 0.20240252]/0.165284
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
11.996081	4.9976444027640605	0.22829963	74.25836307151222


100%|██████████| 258/258 [00:10<00:00, 25.50it/s]


-------------------------
Epoch 36, time usage: 10.120s, loss: 1741.9710085373517, vec_loss [0.1343612  0.21212067]/0.173241
Validation loss: 1609.808322482639 vec_loss: [0.12587541 0.19060585]/0.158241
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.162169	3.3636933240023525	0.11002014	75.19765875821412


100%|██████████| 258/258 [00:10<00:00, 25.31it/s]


-------------------------
Epoch 37, time usage: 10.197s, loss: 1693.1103210449219, vec_loss [0.1300098  0.21198305]/0.170996
Validation loss: 1768.1545166015626 vec_loss: [0.11153096 0.19882414]/0.155178
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.556551	3.670787366953763	0.12165925	79.8237850014199


100%|██████████| 258/258 [00:10<00:00, 25.00it/s]


-------------------------
Epoch 38, time usage: 10.323s, loss: 1637.011229877324, vec_loss [0.13060762 0.21154934]/0.171078
Validation loss: 1301.4941379123263 vec_loss: [0.1191001  0.18890196]/0.154001
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_38.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.9081519	2.6679722504182295	0.062985584	76.13382683920244


100%|██████████| 258/258 [00:10<00:00, 25.30it/s]


-------------------------
Epoch 39, time usage: 10.202s, loss: 1687.1120546326156, vec_loss [0.12697195 0.20195141]/0.164462
Validation loss: 1299.9919609917536 vec_loss: [0.11804325 0.18188666]/0.149965
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_39.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
8.248328	3.7353681217540395	0.14346544	73.49691185041911


100%|██████████| 258/258 [00:10<00:00, 25.31it/s]


-------------------------
Epoch 40, time usage: 10.196s, loss: 1623.992351798124, vec_loss [0.12676474 0.19433002]/0.160547
Validation loss: 1355.1442843967013 vec_loss: [0.11943263 0.18452671]/0.151980
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.5030065	2.960660121657632	0.09264286	74.80608090846586


100%|██████████| 258/258 [00:10<00:00, 25.34it/s]


-------------------------
Epoch 41, time usage: 10.183s, loss: 1647.0694099840268, vec_loss [0.12808055 0.20829676]/0.168189
Validation loss: 1348.287178548177 vec_loss: [0.12345931 0.19498801]/0.159224
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
8.163444	3.6475915908813477	0.15867567	72.71087505294541


100%|██████████| 258/258 [00:10<00:00, 25.37it/s]


-------------------------
Epoch 42, time usage: 10.172s, loss: 1681.7583911511324, vec_loss [0.12919614 0.20584378]/0.167520
Validation loss: 1451.2353271484376 vec_loss: [0.11775406 0.17670825]/0.147231
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
10.514499	4.456276806918058	0.20275792	73.6963111511959


100%|██████████| 258/258 [00:10<00:00, 25.43it/s]


-------------------------
Epoch 43, time usage: 10.149s, loss: 1702.9040898759235, vec_loss [0.12993436 0.2067762 ]/0.168355
Validation loss: 2095.948074001736 vec_loss: [0.12295185 0.22819822]/0.175575
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
16.06448	6.331226739016446	0.30550095	89.72601691703858


100%|██████████| 258/258 [00:10<00:00, 25.18it/s]


-------------------------
Epoch 44, time usage: 10.250s, loss: 1677.4324191780977, vec_loss [0.12918627 0.1948179 ]/0.162002
Validation loss: 1229.1601996527777 vec_loss: [0.11356787 0.15520413]/0.134386
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_44.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.1278	2.6412189548665825	0.07320287	72.6014715665253


100%|██████████| 258/258 [00:10<00:00, 25.26it/s]


-------------------------
Epoch 45, time usage: 10.216s, loss: 1571.0560056701188, vec_loss [0.12634988 0.18538827]/0.155869
Validation loss: 1461.31806640625 vec_loss: [0.11198892 0.1853581 ]/0.148674
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.4141784	3.5912295905026523	0.14317946	71.39536944212776


100%|██████████| 258/258 [00:10<00:00, 25.46it/s]


-------------------------
Epoch 46, time usage: 10.138s, loss: 1578.6378561803538, vec_loss [0.12125573 0.19180636]/0.156531
Validation loss: 1410.9000691731771 vec_loss: [0.11758117 0.16599427]/0.141788
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
9.572592	4.106801423159513	0.18183666	77.91672850418696


100%|██████████| 258/258 [00:10<00:00, 25.10it/s]


-------------------------
Epoch 47, time usage: 10.283s, loss: 1488.0312832381373, vec_loss [0.12382719 0.18059072]/0.152209
Validation loss: 1232.5315687391494 vec_loss: [0.10373309 0.16273929]/0.133236
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.159708	2.908699935132807	0.08714703	73.77447211049979


100%|██████████| 258/258 [00:10<00:00, 25.36it/s]


-------------------------
Epoch 48, time usage: 10.175s, loss: 1565.034707239432, vec_loss [0.1216304  0.18434922]/0.152990
Validation loss: 1560.884494357639 vec_loss: [0.12087287 0.1812168 ]/0.151045
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
10.634531	4.823387622833252	0.20490307	69.47481214461082


100%|██████████| 258/258 [00:10<00:00, 25.28it/s]


-------------------------
Epoch 49, time usage: 10.211s, loss: 1558.4084480936212, vec_loss [0.11951696 0.19300853]/0.156263
Validation loss: 1473.2375298394097 vec_loss: [0.11709653 0.17807347]/0.147585
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.026535	3.3600258176976983	0.07262395	80.7792496086661


100%|██████████| 258/258 [00:10<00:00, 25.12it/s]


-------------------------
Epoch 50, time usage: 10.276s, loss: 1557.537225531053, vec_loss [0.12508044 0.19526017]/0.160170
Validation loss: 1271.2038628472221 vec_loss: [0.09932417 0.17469387]/0.137009
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.430409	2.8952105045318604	0.10210344	77.51822608447284


100%|██████████| 258/258 [00:10<00:00, 25.17it/s]


-------------------------
Epoch 51, time usage: 10.254s, loss: 1514.4850265148075, vec_loss [0.11796291 0.18466128]/0.151312
Validation loss: 1366.6053860134548 vec_loss: [0.10521487 0.16321053]/0.134213
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.266081	3.062692869793285	0.1006746	73.24467172251796


100%|██████████| 258/258 [00:10<00:00, 24.87it/s]


-------------------------
Epoch 52, time usage: 10.378s, loss: 1493.1561064017835, vec_loss [0.11876869 0.1797399 ]/0.149254
Validation loss: 1405.973597547743 vec_loss: [0.12079586 0.17639968]/0.148598
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.4962144	3.6037153113972056	0.13811666	76.50024599760611


100%|██████████| 258/258 [00:10<00:00, 25.38it/s]


-------------------------
Epoch 53, time usage: 10.167s, loss: 1461.9862261631692, vec_loss [0.11586195 0.19000676]/0.152934
Validation loss: 1296.6485039605034 vec_loss: [0.11107709 0.18695948]/0.149018
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.5387287	3.3559610843658447	0.116074845	79.32984931584355


100%|██████████| 258/258 [00:10<00:00, 25.01it/s]


-------------------------
Epoch 54, time usage: 10.321s, loss: 1492.7280398819798, vec_loss [0.12472934 0.18971705]/0.157223
Validation loss: 1196.7599500868055 vec_loss: [0.10961169 0.16104956]/0.135331
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_54.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.7405243	3.3063348423350942	0.12307267	77.45551811770297


100%|██████████| 258/258 [00:10<00:00, 25.48it/s]


-------------------------
Epoch 55, time usage: 10.128s, loss: 1482.2167952190073, vec_loss [0.12248344 0.18208778]/0.152286
Validation loss: 1250.3198825412326 vec_loss: [0.11216889 0.18822077]/0.150195
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.960091	3.0350716655904595	0.09370511	71.93830800944124


100%|██████████| 258/258 [00:10<00:00, 25.45it/s]


-------------------------
Epoch 56, time usage: 10.141s, loss: 1456.6246489295663, vec_loss [0.11766449 0.17530663]/0.146486
Validation loss: 1306.2347412109375 vec_loss: [0.10429568 0.16430317]/0.134299
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
13.325733	5.317630702799017	0.24779604	82.18372277990717


100%|██████████| 258/258 [00:10<00:00, 25.06it/s]


-------------------------
Epoch 57, time usage: 10.297s, loss: 1434.378552459007, vec_loss [0.11770118 0.17169589]/0.144699
Validation loss: 1314.3411607530381 vec_loss: [0.1099925  0.15074717]/0.130370
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.880093	3.2368331172249536	0.11068871	68.7564594171805


100%|██████████| 258/258 [00:10<00:00, 25.47it/s]


-------------------------
Epoch 58, time usage: 10.132s, loss: 1385.320237034051, vec_loss [0.11698619 0.16145562]/0.139221
Validation loss: 1320.7412109375 vec_loss: [0.1034654  0.17056492]/0.137015
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.985142	3.571103269403631	0.13189194	73.44074782580039


100%|██████████| 258/258 [00:10<00:00, 25.18it/s]


-------------------------
Epoch 59, time usage: 10.251s, loss: 1476.1588534569555, vec_loss [0.11758213 0.16949041]/0.143536
Validation loss: 1262.346240234375 vec_loss: [0.10199223 0.14593935]/0.123966
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.533579	2.8005909486250444	0.08844454	71.57369636746247
Model saved to /kaggle/working/prediction_model/checkpoints/icheckpoint_59.pt


100%|██████████| 258/258 [00:10<00:00, 25.48it/s]


-------------------------
Epoch 60, time usage: 10.131s, loss: 1411.1115922558215, vec_loss [0.11441929 0.16445245]/0.139436
Validation loss: 1447.6473958333333 vec_loss: [0.10447733 0.16425487]/0.134366
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
13.431113	5.29665114662864	0.25033072	74.45439881458593


100%|██████████| 258/258 [00:10<00:00, 25.41it/s]


-------------------------
Epoch 61, time usage: 10.157s, loss: 1403.616214012915, vec_loss [0.1159866  0.16035554]/0.138171
Validation loss: 1170.5129896375868 vec_loss: [0.09537084 0.13917363]/0.117272
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_61.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.1398168	2.7839153679934414	0.08974592	71.54454387823228


100%|██████████| 258/258 [00:10<00:00, 25.35it/s]


-------------------------
Epoch 62, time usage: 10.181s, loss: 1387.351466097573, vec_loss [0.11851026 0.1538566 ]/0.136183
Validation loss: 1255.5691053602432 vec_loss: [0.1010697  0.13606773]/0.118569
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.6026177	2.8475897203792226	0.106763	70.82506168588374


100%|██████████| 258/258 [00:10<00:00, 25.41it/s]


-------------------------
Epoch 63, time usage: 10.159s, loss: 1364.8518980750741, vec_loss [0.11691918 0.1534808 ]/0.135200
Validation loss: 1286.1233737521702 vec_loss: [0.11292891 0.13926102]/0.126095
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.68018	2.9554462974721734	0.090577915	73.80403751571163


100%|██████████| 258/258 [00:10<00:00, 25.17it/s]


-------------------------
Epoch 64, time usage: 10.252s, loss: 1371.436234230219, vec_loss [0.11482825 0.1610451 ]/0.137937
Validation loss: 1094.0762329101562 vec_loss: [0.09915193 0.15198168]/0.125567
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_64.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
9.840249	4.035350409421054	0.18254629	73.42620138564713


100%|██████████| 258/258 [00:10<00:00, 25.36it/s]


-------------------------
Epoch 65, time usage: 10.178s, loss: 1327.9748694841253, vec_loss [0.1127537  0.15377565]/0.133265
Validation loss: 1244.616593424479 vec_loss: [0.10015886 0.14838743]/0.124273
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.219821	2.98860671303489	0.10110101	68.20537926437636


100%|██████████| 258/258 [00:10<00:00, 24.93it/s]


-------------------------
Epoch 66, time usage: 10.354s, loss: 1430.8337746553643, vec_loss [0.11697467 0.15778574]/0.137380
Validation loss: 1283.1649956597223 vec_loss: [0.0975806  0.13523798]/0.116409
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
11.103024	4.500553586266258	0.19116291	80.5235024890893


100%|██████████| 258/258 [00:10<00:00, 25.30it/s]


-------------------------
Epoch 67, time usage: 10.201s, loss: 1357.0269623985587, vec_loss [0.11845217 0.15088864]/0.134670
Validation loss: 1326.869377983941 vec_loss: [0.11636387 0.15543734]/0.135901
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.5675244	3.629561597650701	0.14009356	69.78634297145486


100%|██████████| 258/258 [00:10<00:00, 25.52it/s]


-------------------------
Epoch 68, time usage: 10.113s, loss: 1356.8933201279751, vec_loss [0.11778054 0.1558449 ]/0.136813
Validation loss: 1101.7802585177951 vec_loss: [0.09281593 0.13259992]/0.112708
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.2124815	3.2146058190952647	0.120996974	69.84913072725277


100%|██████████| 258/258 [00:10<00:00, 25.27it/s]


-------------------------
Epoch 69, time usage: 10.213s, loss: 1313.5630319284837, vec_loss [0.11697339 0.15154962]/0.134262
Validation loss: 1073.1453640407985 vec_loss: [0.10130733 0.13607837]/0.118693
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_69.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.9831476	2.9073788794604214	0.08922134	70.73351293488034


100%|██████████| 258/258 [00:10<00:00, 25.50it/s]


-------------------------
Epoch 70, time usage: 10.121s, loss: 1301.9457979424055, vec_loss [0.11639223 0.14871524]/0.132554
Validation loss: 1206.7618286132813 vec_loss: [0.10556453 0.14955075]/0.127558
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.7838726	3.0123237046328457	0.102236815	72.46791404749487


100%|██████████| 258/258 [00:10<00:00, 25.40it/s]


-------------------------
Epoch 71, time usage: 10.163s, loss: 1315.5229773706244, vec_loss [0.11392351 0.15121888]/0.132571
Validation loss: 1209.3913275824652 vec_loss: [0.09278546 0.1163175 ]/0.104551
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.6865194	2.598185506733981	0.058469273	67.68773847589601


100%|██████████| 258/258 [00:10<00:00, 25.48it/s]


-------------------------
Epoch 72, time usage: 10.129s, loss: 1310.9279457506284, vec_loss [0.1130161  0.14161436]/0.127315
Validation loss: 1365.5482367621528 vec_loss: [0.10512701 0.12752384]/0.116325
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.635783	3.446214654228904	0.13300273	75.24190099809972


100%|██████████| 258/258 [00:10<00:00, 25.43it/s]


-------------------------
Epoch 73, time usage: 10.148s, loss: 1327.0122448825098, vec_loss [0.11458032 0.15018883]/0.132385
Validation loss: 1201.2328667534723 vec_loss: [0.09658708 0.12585606]/0.111222
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.986351	3.298346909609708	0.11623206	70.96179338009938


100%|██████████| 258/258 [00:10<00:00, 25.37it/s]


-------------------------
Epoch 74, time usage: 10.172s, loss: 1315.9688604783641, vec_loss [0.11331339 0.15648292]/0.134898
Validation loss: 1449.7372463650174 vec_loss: [0.10482005 0.16049625]/0.132658
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
10.951741	4.672132535414263	0.20486592	74.54378487932759


100%|██████████| 258/258 [00:10<00:00, 25.43it/s]


-------------------------
Epoch 75, time usage: 10.149s, loss: 1317.8465699188469, vec_loss [0.11372273 0.15498956]/0.134356
Validation loss: 1272.5079535590278 vec_loss: [0.10584038 0.15542032]/0.130630
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.0399914	3.2886778874830767	0.10954382	76.56768542269197


100%|██████████| 258/258 [00:10<00:00, 25.23it/s]


-------------------------
Epoch 76, time usage: 10.227s, loss: 1258.8414585793664, vec_loss [0.10572159 0.15694277]/0.131332
Validation loss: 1344.1548638237848 vec_loss: [0.11150476 0.14942458]/0.130465
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
8.777426	4.3867575255307285	0.17118317	70.02119995745518


100%|██████████| 258/258 [00:10<00:00, 25.46it/s]


-------------------------
Epoch 77, time usage: 10.139s, loss: 1235.882593672405, vec_loss [0.10217323 0.15409428]/0.128134
Validation loss: 1363.9106757269965 vec_loss: [0.08993641 0.13955082]/0.114744
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
11.608835	4.721093871376731	0.21432947	73.23619794146315


100%|██████████| 258/258 [00:10<00:00, 25.24it/s]


-------------------------
Epoch 78, time usage: 10.225s, loss: 1275.1305429621261, vec_loss [0.11238693 0.14645144]/0.129419
Validation loss: 1181.5230007595487 vec_loss: [0.10385448 0.13613537]/0.119995
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.4572554	3.0652880560268057	0.07053637	71.76897892694859


100%|██████████| 258/258 [00:10<00:00, 25.58it/s]


-------------------------
Epoch 79, time usage: 10.089s, loss: 1246.9832253862721, vec_loss [0.11231297 0.14526501]/0.128789
Validation loss: 1804.9087212456598 vec_loss: [0.11647932 0.17725793]/0.146869
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
15.467093	6.383971842852506	0.27377063	78.84944982334957
Model saved to /kaggle/working/prediction_model/checkpoints/icheckpoint_79.pt


100%|██████████| 258/258 [00:10<00:00, 25.50it/s]


-------------------------
Epoch 80, time usage: 10.120s, loss: 1261.9831514580305, vec_loss [0.10693015 0.15088269]/0.128906
Validation loss: 1243.4833577473958 vec_loss: [0.10607036 0.13946745]/0.122769
Epoch 00081: reducing learning rate of group 0 to 2.2500e-04.
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.92829	2.7665779807350854	0.069457084	71.49188684528922


100%|██████████| 258/258 [00:10<00:00, 25.33it/s]


-------------------------
Epoch 81, time usage: 10.188s, loss: 1216.8987122764884, vec_loss [0.1125743  0.15192758]/0.132251
Validation loss: 1060.1796495225694 vec_loss: [0.10606138 0.1432061 ]/0.124634
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_81.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.1535783	2.9201869639483364	0.10110235	69.77115977139306


100%|██████████| 258/258 [00:10<00:00, 25.43it/s]


-------------------------
Epoch 82, time usage: 10.150s, loss: 1081.7515386537063, vec_loss [0.10865024 0.15398547]/0.131318
Validation loss: 1033.5821899414063 vec_loss: [0.08199131 0.13555254]/0.108772
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_82.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.031681	2.4828605326739224	0.07827576	68.31165824971434


100%|██████████| 258/258 [00:10<00:00, 25.19it/s]


-------------------------
Epoch 83, time usage: 10.246s, loss: 1129.2412271425706, vec_loss [0.10393747 0.15101877]/0.127478
Validation loss: 1189.483768717448 vec_loss: [0.09632653 0.14904799]/0.122687
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.259405	2.8962230032140557	0.09224885	67.82130530882078


100%|██████████| 258/258 [00:10<00:00, 25.30it/s]


-------------------------
Epoch 84, time usage: 10.203s, loss: 1158.2475834336392, vec_loss [0.1090683  0.14772527]/0.128397
Validation loss: 1304.065121799045 vec_loss: [0.10521103 0.15465128]/0.129931
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
11.387424	4.6861770369789815	0.20211278	76.81294820189281


100%|██████████| 258/258 [00:10<00:00, 25.19it/s]


-------------------------
Epoch 85, time usage: 10.244s, loss: 1201.8466582778813, vec_loss [0.1111699  0.14866884]/0.129919
Validation loss: 1294.2050645616318 vec_loss: [0.11074061 0.15705639]/0.133898
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.2808166	3.6451701792803677	0.1428551	72.1426896227544


100%|██████████| 258/258 [00:10<00:00, 25.45it/s]


-------------------------
Epoch 86, time usage: 10.142s, loss: 1245.2224886398908, vec_loss [0.10669862 0.15544835]/0.131073
Validation loss: 1140.4359822591146 vec_loss: [0.09191119 0.13520561]/0.113558
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.2485323	3.2093971967697144	0.1098027	74.18191524783742


100%|██████████| 258/258 [00:10<00:00, 25.51it/s]


-------------------------
Epoch 87, time usage: 10.115s, loss: 1214.4857293653859, vec_loss [0.10934372 0.1504723 ]/0.129908
Validation loss: 1251.5327324761286 vec_loss: [0.09768704 0.1435923 ]/0.120640
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
11.009862	4.48697380586104	0.19999246	75.14135564610685


100%|██████████| 258/258 [00:10<00:00, 25.11it/s]


-------------------------
Epoch 88, time usage: 10.279s, loss: 1095.2260496154313, vec_loss [0.10496976 0.15212081]/0.128545
Validation loss: 1678.5029120551214 vec_loss: [0.10813906 0.14341775]/0.125778
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
8.09355	4.285809560255571	0.14163834	72.57928835395366


100%|██████████| 258/258 [00:10<00:00, 25.50it/s]


-------------------------
Epoch 89, time usage: 10.119s, loss: 1141.9787948371827, vec_loss [0.10949236 0.15051414]/0.130003
Validation loss: 1288.100008138021 vec_loss: [0.09871569 0.14867222]/0.123694
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.857311	3.747442852367054	0.12204499	74.81725758165315


100%|██████████| 258/258 [00:10<00:00, 25.35it/s]


-------------------------
Epoch 90, time usage: 10.182s, loss: 1081.0140736897786, vec_loss [0.10433909 0.14468326]/0.124511
Validation loss: 1195.021233452691 vec_loss: [0.09851799 0.1556899 ]/0.127104
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.627212	2.8923952796242456	0.08916369	69.78868541587703


100%|██████████| 258/258 [00:10<00:00, 25.41it/s]


-------------------------
Epoch 91, time usage: 10.158s, loss: 1076.5633853646211, vec_loss [0.10674446 0.14786775]/0.127306
Validation loss: 1073.61162109375 vec_loss: [0.1008333  0.14299072]/0.121912
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.914068	2.895843755115162	0.11221876	66.24677843410875


100%|██████████| 258/258 [00:10<00:00, 25.17it/s]


-------------------------
Epoch 92, time usage: 10.255s, loss: 1099.7647690883903, vec_loss [0.10816431 0.14957282]/0.128869
Validation loss: 1121.452288140191 vec_loss: [0.10276274 0.13850866]/0.120636
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.6124706	2.8245975971221924	0.08059305	68.72522123654831


100%|██████████| 258/258 [00:10<00:00, 25.57it/s]


-------------------------
Epoch 93, time usage: 10.092s, loss: 1012.9287434659263, vec_loss [0.11065323 0.14198115]/0.126317
Validation loss: 1201.91533203125 vec_loss: [0.09585244 0.14058174]/0.118217
Epoch 00094: reducing learning rate of group 0 to 1.6875e-04.
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
10.188477	4.429918440905484	0.19062094	71.7064032174781


100%|██████████| 258/258 [00:10<00:00, 25.54it/s]


-------------------------
Epoch 94, time usage: 10.105s, loss: 1045.50470514815, vec_loss [0.10890415 0.14407589]/0.126490
Validation loss: 998.9842054578993 vec_loss: [0.09509049 0.13246794]/0.113779
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_94.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.52443	2.6812127395109697	0.080419354	70.93660674707182


100%|██████████| 258/258 [00:10<00:00, 25.08it/s]


-------------------------
Epoch 95, time usage: 10.292s, loss: 1051.1404213129088, vec_loss [0.10962006 0.14445873]/0.127039
Validation loss: 1045.1400906032986 vec_loss: [0.09927723 0.1406811 ]/0.119979
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.6621575	2.545417991551486	0.0836023	66.62913242384606


100%|██████████| 258/258 [00:10<00:00, 25.35it/s]


-------------------------
Epoch 96, time usage: 10.182s, loss: 1053.3003579073174, vec_loss [0.11163716 0.13981634]/0.125727
Validation loss: 1027.9958570692274 vec_loss: [0.09736503 0.13797352]/0.117669
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.0366898	2.733622746034102	0.072945	67.812234817603


100%|██████████| 258/258 [00:10<00:00, 24.91it/s]


-------------------------
Epoch 97, time usage: 10.363s, loss: 998.0610189511794, vec_loss [0.104691   0.14524299]/0.124967
Validation loss: 1036.2691596137154 vec_loss: [0.10353487 0.14037733]/0.121956
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.3934236	2.6206627650694414	0.08172335	69.66739367474335


100%|██████████| 258/258 [00:10<00:00, 25.20it/s]


-------------------------
Epoch 98, time usage: 10.243s, loss: 993.7602636056353, vec_loss [0.10893331 0.1464104 ]/0.127672
Validation loss: 1032.6633246527779 vec_loss: [0.09472505 0.13785547]/0.116290
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.338257	3.0667378143830732	0.10828335	72.01554751413347


100%|██████████| 258/258 [00:10<00:00, 25.52it/s]


-------------------------
Epoch 99, time usage: 10.113s, loss: 1023.8642995671709, vec_loss [0.10947244 0.14393395]/0.126703
Validation loss: 998.4792317708333 vec_loss: [0.10445254 0.14434531]/0.124399
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_99.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.868645	3.108125686645508	0.12498	69.74346770079268


100%|██████████| 258/258 [00:10<00:00, 25.31it/s]


-------------------------
Epoch 100, time usage: 10.198s, loss: 974.2549695155417, vec_loss [0.10846271 0.1460384 ]/0.127251
Validation loss: 1057.5443488226997 vec_loss: [0.1034916  0.12331767]/0.113405
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.7468944	2.7314558679407295	0.07672895	69.967801538524


100%|██████████| 258/258 [00:10<00:00, 25.44it/s]


-------------------------
Epoch 101, time usage: 10.145s, loss: 977.8987062439438, vec_loss [0.1072774  0.14444388]/0.125861
Validation loss: 1045.6090969509548 vec_loss: [0.0901682  0.13444002]/0.112304
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.8959584	3.536170092496005	0.14432849	72.02000886951409


100%|██████████| 258/258 [00:10<00:00, 25.01it/s]


-------------------------
Epoch 102, time usage: 10.318s, loss: 964.2162483865901, vec_loss [0.10464521 0.14709671]/0.125871
Validation loss: 1144.6827840169271 vec_loss: [0.10816511 0.14034347]/0.124254
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.8231626	3.3425618301738393	0.11283507	68.97682056852005


100%|██████████| 258/258 [00:10<00:00, 25.42it/s]


-------------------------
Epoch 103, time usage: 10.154s, loss: 993.3611582674721, vec_loss [0.11092471 0.14252926]/0.126727
Validation loss: 1003.898368326823 vec_loss: [0.09663843 0.14157979]/0.119109
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.8406768	2.5206289941614326	0.06965654	72.20787198078271


100%|██████████| 258/258 [00:10<00:00, 25.32it/s]


-------------------------
Epoch 104, time usage: 10.192s, loss: 986.1335212648377, vec_loss [0.10497593 0.14796893]/0.126472
Validation loss: 1076.5731472439236 vec_loss: [0.09675874 0.13625297]/0.116506
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.7431183	3.526787216013128	0.14219661	74.54730553138158


100%|██████████| 258/258 [00:10<00:00, 25.36it/s]


-------------------------
Epoch 105, time usage: 10.178s, loss: 929.8378144493399, vec_loss [0.10727102 0.14293024]/0.125101
Validation loss: 942.4593343098958 vec_loss: [0.09546874 0.13283232]/0.114151
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_105.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.0875535	2.328872030431574	0.058212522	67.28105038492482


100%|██████████| 258/258 [00:10<00:00, 25.46it/s]


-------------------------
Epoch 106, time usage: 10.138s, loss: 920.9034880408944, vec_loss [0.10756557 0.1431319 ]/0.125349
Validation loss: 1067.9956637912326 vec_loss: [0.09996216 0.13617532]/0.118069
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
9.020362	4.015014290809631	0.16084121	73.19507771895451


100%|██████████| 258/258 [00:10<00:00, 25.20it/s]


-------------------------
Epoch 107, time usage: 10.241s, loss: 962.8726279088693, vec_loss [0.1071659  0.14202799]/0.124597
Validation loss: 1147.5411499023437 vec_loss: [0.10643517 0.1469957 ]/0.126715
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
9.676517	4.408969261429527	0.1602961	75.74761379258408


100%|██████████| 258/258 [00:10<00:00, 25.46it/s]


-------------------------
Epoch 108, time usage: 10.135s, loss: 946.2661628427431, vec_loss [0.1083988  0.14008662]/0.124243
Validation loss: 1348.3271308051214 vec_loss: [0.11776033 0.13011578]/0.123938
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.6831815	2.880445046858354	0.059832834	73.13900807513158


100%|██████████| 258/258 [00:10<00:00, 25.08it/s]


-------------------------
Epoch 109, time usage: 10.290s, loss: 983.932557571766, vec_loss [0.10723555 0.13895996]/0.123098
Validation loss: 1026.6098402235243 vec_loss: [0.10059704 0.13408281]/0.117340
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.8237834	2.8021909323605625	0.08429841	72.79581839009512


100%|██████████| 258/258 [00:10<00:00, 25.44it/s]


-------------------------
Epoch 110, time usage: 10.145s, loss: 1059.4964535735373, vec_loss [0.1054944  0.14681222]/0.126153
Validation loss: 1027.2909342447917 vec_loss: [0.10079519 0.12302928]/0.111912
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.62534	2.7426917769692163	0.08027355	70.47638860079634


100%|██████████| 258/258 [00:10<00:00, 25.29it/s]


-------------------------
Epoch 111, time usage: 10.206s, loss: 949.4556807880253, vec_loss [0.10781755 0.13814522]/0.122981
Validation loss: 1152.842842610677 vec_loss: [0.11031705 0.13091753]/0.120617
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.040874	3.623968341133811	0.13633636	67.33472418466481


100%|██████████| 258/258 [00:10<00:00, 25.27it/s]


-------------------------
Epoch 112, time usage: 10.212s, loss: 940.653033707493, vec_loss [0.10299122 0.13888752]/0.120939
Validation loss: 975.8831359863282 vec_loss: [0.09415624 0.14126898]/0.117713
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.9667792	2.423843882300637	0.07271708	65.83002264430094


100%|██████████| 258/258 [00:10<00:00, 25.55it/s]


-------------------------
Epoch 113, time usage: 10.101s, loss: 916.0566873476487, vec_loss [0.1036374  0.13885142]/0.121244
Validation loss: 1206.902859157986 vec_loss: [0.11720528 0.13889855]/0.128052
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
9.191547	4.169969981366938	0.17158101	68.46358711483323


100%|██████████| 258/258 [00:10<00:00, 25.25it/s]


-------------------------
Epoch 114, time usage: 10.220s, loss: 931.6354323985964, vec_loss [0.10751031 0.13761806]/0.122564
Validation loss: 1010.5498087565104 vec_loss: [0.0960563  0.12605442]/0.111055
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.420374	2.700398152524775	0.0770494	68.03999989400285


100%|██████████| 258/258 [00:10<00:00, 25.43it/s]


-------------------------
Epoch 115, time usage: 10.149s, loss: 925.1560664213905, vec_loss [0.10344063 0.13481402]/0.119127
Validation loss: 989.1267829047309 vec_loss: [0.08911356 0.13308771]/0.111101
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.4020963	3.590104753320867	0.1377648	70.34461562391814


100%|██████████| 258/258 [00:10<00:00, 24.99it/s]


-------------------------
Epoch 116, time usage: 10.328s, loss: 908.1855131637218, vec_loss [0.10403197 0.13924387]/0.121638
Validation loss: 1179.0957397460938 vec_loss: [0.10074497 0.13199137]/0.116368
Epoch 00117: reducing learning rate of group 0 to 1.2656e-04.
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.4162326	3.4076384847814385	0.12252838	70.69374571075367


100%|██████████| 258/258 [00:10<00:00, 25.40it/s]


-------------------------
Epoch 117, time usage: 10.160s, loss: 890.7400161388308, vec_loss [0.10233022 0.14347214]/0.122901
Validation loss: 1032.7568820529514 vec_loss: [0.10231198 0.14128953]/0.121801
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
8.730452	3.9961246143687856	0.14883772	70.7535669912098


100%|██████████| 258/258 [00:10<00:00, 25.31it/s]


-------------------------
Epoch 118, time usage: 10.195s, loss: 888.1349523973096, vec_loss [0.10410143 0.13895755]/0.121529
Validation loss: 1003.5248969184028 vec_loss: [0.09781466 0.13587792]/0.116846
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.2003193	3.113843885335055	0.09109226	70.99194618157983


100%|██████████| 258/258 [00:10<00:00, 25.33it/s]


-------------------------
Epoch 119, time usage: 10.187s, loss: 848.2764771927235, vec_loss [0.10142268 0.13886689]/0.120145
Validation loss: 1243.7030449761285 vec_loss: [0.11617385 0.1461283 ]/0.131151
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
11.3132	4.877926566384056	0.20504206	74.58094731767534
Model saved to /kaggle/working/prediction_model/checkpoints/icheckpoint_119.pt


100%|██████████| 258/258 [00:10<00:00, 25.36it/s]


-------------------------
Epoch 120, time usage: 10.176s, loss: 883.6029383932897, vec_loss [0.10268991 0.13725406]/0.119972
Validation loss: 1102.5541137695313 vec_loss: [0.10496408 0.13778193]/0.121373
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.369202	3.25661039352417	0.11977029	70.92945926456854


100%|██████████| 258/258 [00:10<00:00, 25.25it/s]


-------------------------
Epoch 121, time usage: 10.220s, loss: 847.1235185963238, vec_loss [0.10342455 0.13867459]/0.121050
Validation loss: 953.9444173177084 vec_loss: [0.09302819 0.12383813]/0.108433
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.807303	2.4398494850505483	0.06684912	68.82084885233881


100%|██████████| 258/258 [00:10<00:00, 25.32it/s]


-------------------------
Epoch 122, time usage: 10.194s, loss: 866.3497802970945, vec_loss [0.10384918 0.13270608]/0.118278
Validation loss: 1015.5862182617187 vec_loss: [0.09617706 0.13120903]/0.113693
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
9.151315	4.0259081667119805	0.1684536	71.23665323175275


100%|██████████| 258/258 [00:10<00:00, 25.41it/s]


-------------------------
Epoch 123, time usage: 10.162s, loss: 863.8642712970112, vec_loss [0.10052317 0.13300315]/0.116763
Validation loss: 993.7603000217014 vec_loss: [0.10192209 0.13774055]/0.119831
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
10.365692	4.448184923692183	0.18284798	72.70452359794358


100%|██████████| 258/258 [00:10<00:00, 25.36it/s]


-------------------------
Epoch 124, time usage: 10.178s, loss: 848.6438241411549, vec_loss [0.10316017 0.13450095]/0.118831
Validation loss: 948.1390238444011 vec_loss: [0.09398837 0.12947841]/0.111733
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.723946	3.211759242144498	0.115204625	70.9492860618765


100%|██████████| 258/258 [00:10<00:00, 25.46it/s]


-------------------------
Epoch 125, time usage: 10.136s, loss: 836.9409236464389, vec_loss [0.10105703 0.13099499]/0.116026
Validation loss: 1082.4875230577256 vec_loss: [0.09841512 0.1337627 ]/0.116089
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
8.344319	3.91842963478782	0.15058622	70.46586390021164


100%|██████████| 258/258 [00:10<00:00, 25.30it/s]


-------------------------
Epoch 126, time usage: 10.200s, loss: 822.7944655307504, vec_loss [0.10182911 0.13526206]/0.118546
Validation loss: 946.2277947319878 vec_loss: [0.09720011 0.13863246]/0.117916
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.991581	2.876376899805936	0.09488084	67.67384462927164


100%|██████████| 258/258 [00:10<00:00, 25.33it/s]


-------------------------
Epoch 127, time usage: 10.189s, loss: 842.6764166780221, vec_loss [0.09829897 0.1346043 ]/0.116452
Validation loss: 1114.2062913682726 vec_loss: [0.1051081 0.14281  ]/0.123959
Epoch 00128: reducing learning rate of group 0 to 9.4922e-05.
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
10.631193	4.820643533359874	0.18527803	72.60512571773126


100%|██████████| 258/258 [00:10<00:00, 25.14it/s]


-------------------------
Epoch 128, time usage: 10.273s, loss: 836.2269754335862, vec_loss [0.09868345 0.13398749]/0.116335
Validation loss: 1069.670060221354 vec_loss: [0.10320441 0.13823852]/0.120721
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.9097843	3.29885779727589	0.099748515	71.77306772655629


100%|██████████| 258/258 [00:10<00:00, 25.40it/s]


-------------------------
Epoch 129, time usage: 10.161s, loss: 795.7204088314559, vec_loss [0.09805309 0.132905  ]/0.115479
Validation loss: 971.1251085069445 vec_loss: [0.09992956 0.1256929 ]/0.112811
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.8765264	3.669079152020541	0.13604994	69.87754307234746


100%|██████████| 258/258 [00:10<00:00, 25.39it/s]


-------------------------
Epoch 130, time usage: 10.164s, loss: 795.0434098354606, vec_loss [0.10217358 0.13196519]/0.117069
Validation loss: 1024.045650906033 vec_loss: [0.09655978 0.12304993]/0.109805
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.082405	3.325688535516912	0.13956672	66.45561488307648


100%|██████████| 258/258 [00:10<00:00, 25.29it/s]


-------------------------
Epoch 131, time usage: 10.206s, loss: 806.1626323138097, vec_loss [0.10292105 0.13290697]/0.117914
Validation loss: 1031.3955234103732 vec_loss: [0.1035616  0.13826647]/0.120914
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.428975	2.579487908970226	0.08404786	68.10246445766775


100%|██████████| 258/258 [00:10<00:00, 25.51it/s]


-------------------------
Epoch 132, time usage: 10.115s, loss: 881.7901477665864, vec_loss [0.10571848 0.13805735]/0.121888
Validation loss: 1029.7871466742622 vec_loss: [0.09896695 0.12232914]/0.110648
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.5625718	2.6482598889957774	0.06478787	66.8577959805849


100%|██████████| 258/258 [00:10<00:00, 25.43it/s]


-------------------------
Epoch 133, time usage: 10.150s, loss: 824.3834550251332, vec_loss [0.09934701 0.13413462]/0.116741
Validation loss: 963.2663058810764 vec_loss: [0.10291195 0.1338376 ]/0.118375
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
9.289112	4.282109260559082	0.15812969	71.38025174978938


100%|██████████| 258/258 [00:10<00:00, 25.36it/s]


-------------------------
Epoch 134, time usage: 10.177s, loss: 798.4092689928158, vec_loss [0.10241355 0.13118239]/0.116798
Validation loss: 1012.7077473958333 vec_loss: [0.09781406 0.12493339]/0.111374
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.6416426	2.4475537105040117	0.06158956	68.84185354681546


100%|██████████| 258/258 [00:10<00:00, 25.49it/s]


-------------------------
Epoch 135, time usage: 10.127s, loss: 812.0246065124985, vec_loss [0.10014487 0.13304038]/0.116593
Validation loss: 1067.5325995551216 vec_loss: [0.09934398 0.12967588]/0.114510
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
9.594624	4.006695487282493	0.17632093	71.16372730198451


100%|██████████| 258/258 [00:10<00:00, 25.21it/s]


-------------------------
Epoch 136, time usage: 10.236s, loss: 800.4182707320812, vec_loss [0.09918154 0.12992474]/0.114553
Validation loss: 945.6022779676649 vec_loss: [0.0934989 0.1265404]/0.110020
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
8.770304	3.9636550816622647	0.15434644	71.71285339975974


100%|██████████| 258/258 [00:10<00:00, 25.43it/s]


-------------------------
Epoch 137, time usage: 10.148s, loss: 791.4742749827777, vec_loss [0.10030297 0.12717865]/0.113741
Validation loss: 972.5850802951389 vec_loss: [0.10267358 0.13196939]/0.117321
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.26074	3.083827257156372	0.09285534	69.33443499215728


100%|██████████| 258/258 [00:10<00:00, 25.30it/s]


-------------------------
Epoch 138, time usage: 10.202s, loss: 790.9724350567012, vec_loss [0.09769913 0.1273742 ]/0.112537
Validation loss: 908.66036851671 vec_loss: [0.08998753 0.12189384]/0.105941
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_138.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.8769765	3.1440143476833	0.12494873	70.70831822981103


100%|██████████| 258/258 [00:10<00:00, 25.56it/s]


-------------------------
Epoch 139, time usage: 10.097s, loss: 847.5767165782839, vec_loss [0.10062405 0.13249023]/0.116557
Validation loss: 914.8404656304253 vec_loss: [0.08733544 0.12087022]/0.104103
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.202085	2.4363280859860508	0.07764935	66.75617518426087
Model saved to /kaggle/working/prediction_model/checkpoints/icheckpoint_139.pt


100%|██████████| 258/258 [00:10<00:00, 25.39it/s]


-------------------------
Epoch 140, time usage: 10.165s, loss: 818.7387750906538, vec_loss [0.09977883 0.12955587]/0.114667
Validation loss: 981.6172030978732 vec_loss: [0.10217346 0.12595333]/0.114063
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
5.828555	3.177441943775524	0.10121079	70.36719312468277


100%|██████████| 258/258 [00:10<00:00, 25.58it/s]


-------------------------
Epoch 141, time usage: 10.088s, loss: 769.0564671006314, vec_loss [0.10429828 0.12726836]/0.115783
Validation loss: 949.2754753960503 vec_loss: [0.1031535  0.12428628]/0.113720
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.368789	3.522965344515714	0.13144493	68.97649582407134


100%|██████████| 258/258 [00:10<00:00, 25.48it/s]


-------------------------
Epoch 142, time usage: 10.128s, loss: 790.6022622751635, vec_loss [0.10267761 0.12811962]/0.115399
Validation loss: 934.4181864420573 vec_loss: [0.09538932 0.1269878 ]/0.111189
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.1389756	3.1758647832003506	0.117678575	69.92918985821964


100%|██████████| 258/258 [00:10<00:00, 25.21it/s]


-------------------------
Epoch 143, time usage: 10.239s, loss: 764.8191078836604, vec_loss [0.10061154 0.12700194]/0.113807
Validation loss: 992.6402214898003 vec_loss: [0.10248788 0.13063699]/0.116562
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
7.3705463	3.5188650868155738	0.1331792	68.39479779599688


100%|██████████| 258/258 [00:10<00:00, 25.36it/s]


-------------------------
Epoch 144, time usage: 10.177s, loss: 795.113142383191, vec_loss [0.09886958 0.12930176]/0.114086
Validation loss: 951.2319620768229 vec_loss: [0.09810895 0.13234   ]/0.115224
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.309619	3.1510352221402256	0.12020686	66.49143332647346


100%|██████████| 258/258 [00:10<00:00, 25.43it/s]


-------------------------
Epoch 145, time usage: 10.148s, loss: 751.827265776405, vec_loss [0.09691747 0.12957656]/0.113247
Validation loss: 911.9638821072049 vec_loss: [0.0920293  0.12358377]/0.107807
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.8450887	2.5231389674273403	0.06562205	67.85708991173506


100%|██████████| 258/258 [00:10<00:00, 25.40it/s]


-------------------------
Epoch 146, time usage: 10.162s, loss: 753.6441090901693, vec_loss [0.09484363 0.12632698]/0.110585
Validation loss: 899.5246141221788 vec_loss: [0.09266023 0.12483371]/0.108747
Best Validation Model saved to /kaggle/working/prediction_model/checkpoints/checkpoint_146.pt
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
4.1195827	2.5389111908999356	0.06755109	66.85417784310366


100%|██████████| 258/258 [00:10<00:00, 25.49it/s]


-------------------------
Epoch 147, time usage: 10.125s, loss: 799.1540478846823, vec_loss [0.09876986 0.13082145]/0.114796
Validation loss: 927.8447258843316 vec_loss: [0.08890598 0.11321028]/0.101058
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.792725	3.1362285072153266	0.1317988	64.1180798146037


100%|██████████| 258/258 [00:10<00:00, 25.14it/s]


-------------------------
Epoch 148, time usage: 10.265s, loss: 813.160144184911, vec_loss [0.1025702 0.1257487]/0.114159
Validation loss: 910.3170545789931 vec_loss: [0.10174907 0.11627926]/0.109014
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
3.629222	2.330399578267878	0.06567859	68.27602142725665


100%|██████████| 258/258 [00:10<00:00, 25.41it/s]


-------------------------
Epoch 149, time usage: 10.158s, loss: 762.8375784703927, vec_loss [0.10284672 0.12330511]/0.113076
Validation loss: 1031.197664388021 vec_loss: [0.10543503 0.12984084]/0.117638
Reconstruct the validation trajectory and compute the ATE, RTE, PDE, and AYE
ATE	RTE	PDE	AYE
6.535742	3.350755648179488	0.11837521	70.05320187742183
Training completed


In [29]:
# Check Saved Checkpoints
# Define the directory path
dir_path = '/kaggle/working/prediction_model/checkpoints'

# Loop over all files in the directory
for filename in os.listdir(dir_path):
    # Check if the file is a regular file (not a directory)
    if os.path.isfile(os.path.join(dir_path, filename)):
        # Do something with the file
        print(filename)

checkpoint_3.pt
checkpoint_24.pt
checkpoint_1.pt
checkpoint_5.pt
checkpoint_44.pt
checkpoint_92.pt
icheckpoint_39.pt
icheckpoint_99.pt
checkpoint_138.pt
checkpoint_4.pt
checkpoint_54.pt
checkpoint_50.pt
icheckpoint_19.pt
checkpoint_38.pt
checkpoint_13.pt
checkpoint_146.pt
checkpoint_79.pt
checkpoint_31.pt
checkpoint_83.pt
checkpoint_11.pt
checkpoint_30.pt
checkpoint_16.pt
checkpoint_64.pt
checkpoint_59.pt
checkpoint_40.pt
checkpoint_53.pt
checkpoint_65.pt
checkpoint_61.pt
checkpoint_10.pt
checkpoint_111.pt
checkpoint_39.pt
checkpoint_2.pt
checkpoint_69.pt
checkpoint_82.pt
checkpoint_73.pt
checkpoint_26.pt
icheckpoint_119.pt
checkpoint_94.pt
checkpoint_22.pt
checkpoint_106.pt
checkpoint_99.pt
checkpoint_0.pt
icheckpoint_79.pt
checkpoint_81.pt
checkpoint_85.pt
checkpoint_76.pt
checkpoint_23.pt
checkpoint_42.pt
checkpoint_latest.pt
checkpoint_7.pt
checkpoint_33.pt
icheckpoint_59.pt
checkpoint_17.pt
icheckpoint_139.pt
checkpoint_105.pt


Download the checkpoints.

In [30]:
!zip -r checkpoints.zip /kaggle/working/prediction_model

  adding: kaggle/working/prediction_model/ (stored 0%)
  adding: kaggle/working/prediction_model/checkpoints/ (stored 0%)
  adding: kaggle/working/prediction_model/checkpoints/checkpoint_3.pt (deflated 8%)
  adding: kaggle/working/prediction_model/checkpoints/checkpoint_24.pt (deflated 7%)
  adding: kaggle/working/prediction_model/checkpoints/checkpoint_1.pt (deflated 8%)
  adding: kaggle/working/prediction_model/checkpoints/checkpoint_5.pt (deflated 8%)
  adding: kaggle/working/prediction_model/checkpoints/checkpoint_44.pt (deflated 8%)
  adding: kaggle/working/prediction_model/checkpoints/checkpoint_92.pt (deflated 7%)
  adding: kaggle/working/prediction_model/checkpoints/icheckpoint_39.pt (deflated 7%)
  adding: kaggle/working/prediction_model/checkpoints/icheckpoint_99.pt (deflated 7%)
  adding: kaggle/working/prediction_model/checkpoints/checkpoint_138.pt (deflated 8%)
  adding: kaggle/working/prediction_model/checkpoints/checkpoint_4.pt (deflated 7%)
  adding: kaggle/working/pred

In [31]:
from IPython.display import FileLink
FileLink(r'checkpoints.zip')

### 8. Testing

#### 8.1 Test for Seen Dataset

In [33]:
# Please change the output dir & model path
TEST_DIR = '/kaggle/input/comp7310-project-1-imu-indoor-tracking/original_data/test_seen' # Dataset directory for testing (unseen_subjects_test_set)
OUT_DIR = '/kaggle/working/test_results/test_seen' # Output directory for both traning and testing
MODEL_PATH = '/kaggle/working/prediction_model/checkpoints/checkpoint_105.pt' # Model path for testing
# Load config settings
kwargs = load_config()

import warnings
# Suspend warnings
warnings.filterwarnings('ignore')
test(**kwargs)

Testing LSTM model
LSTM model
Network constructed. trainable parameters: 216620
The model is loaded.
Model /kaggle/working/prediction_model/checkpoints/checkpoint_105.pt loaded to device cuda:0.
Reconstruct the tracermini_hw101_test20230313011731T
Reconstruct the tracermini_hw101_test20230311111842T
Reconstruct the tracermini_hw101_test20230311111507T
Reconstruct the tracermini_hw101_test20230311112635T
Reconstruct the tracermini_hw101_test20230313012956T
Reconstruct the tracermini_hw101_test20230313010954T
Reconstruct the tracermini_hw101_test20230313010204T
Reconstruct the tracermini_hw101_test20230311111027T
Reconstruct the tracermini_hw101_test20230313010546T
Reconstruct the tracermini_hw101_test20230311112235T
Reconstruct the tracermini_hw101_test20230313011357T
Reconstruct the tracermini_hw101_test20230313013335T


#### 8.2 Test for Unseen Dataset

In [35]:
# Please change the output dir & model path
TEST_DIR = '/kaggle/input/comp7310-project-1-imu-indoor-tracking/original_data/test_unseen' # Dataset directory for testing (unseen_subjects_test_set)
OUT_DIR = '/kaggle/working/test_results/test_unseen' # Output directory for both traning and testing
MODEL_PATH = '/kaggle/working/prediction_model/checkpoints/checkpoint_105.pt' # Model path for testing
# Load config settings
kwargs = load_config()
import warnings
# Suspend warnings
warnings.filterwarnings('ignore')
test(**kwargs)

Testing LSTM model
LSTM model
Network constructed. trainable parameters: 216620
The model is loaded.
Model /kaggle/working/prediction_model/checkpoints/checkpoint_105.pt loaded to device cuda:0.
Reconstruct the tracermini_unseen_cym20230314101010T
Reconstruct the tracermini_unseen_hw520230314081212T
Reconstruct the tracermini_unseen_cym20230314101001T
Reconstruct the tracermini_unseen_hw520230314091603T
Reconstruct the tracermini_unseen_hw520230314083031T
Reconstruct the tracermini_unseen_cym20230314101850T
Reconstruct the tracermini_unseen_cym20230314101230T
Reconstruct the tracermini_unseen_cym20230314100103T
Reconstruct the tracermini_unseen_hw520230314090739T
Reconstruct the tracermini_unseen_cym20230314100325T
Reconstruct the tracermini_unseen_hw520230314082000T
Reconstruct the tracermini_unseen_cym20230314100816T
Reconstruct the tracermini_unseen_hw520230314091844T
Reconstruct the tracermini_unseen_hw520230314091338T
Reconstruct the tracermini_unseen_hw520230314081542T
Reconstruc

In [38]:
# Load the checkpoint
checkpoint_path = '/kaggle/working/prediction_model/checkpoints/checkpoint_105.pt'
checkpoint = torch.load(checkpoint_path, map_location='cpu')
# Load ATE，RTE，PDE, and AYE values from the checkpoint
ate = checkpoint.get('ate',None)
rte = checkpoint.get('rte',None)
pde = checkpoint.get('pde',None)  
aye = checkpoint.get('aye',None)

#### 8.3 Check the Results

In [39]:
for key, value in seen_unseen_pred.items():
    # print the length of each sequence
    print(seen_unseen_dataset[key], len(value))

tracermini_hw101_test20230311112635T 42097
tracermini_hw101_test20230311111507T 37913
tracermini_hw101_test20230311112235T 43462
tracermini_hw101_test20230313011357T 38324
tracermini_hw101_test20230311111842T 42839
tracermini_hw101_test20230313010954T 39610
tracermini_hw101_test20230313010546T 38378
tracermini_hw101_test20230313011731T 37610
tracermini_hw101_test20230313013335T 38254
tracermini_hw101_test20230311111027T 44186
tracermini_hw101_test20230313012956T 38694
tracermini_hw101_test20230313010204T 39085
tracermini_unseen_hw520230314091844T 19545
tracermini_unseen_cym20230314101001T 283
tracermini_unseen_hw520230314082319T 31657
tracermini_unseen_cym20230314101636T 22690
tracermini_unseen_hw520230314083031T 21208
tracermini_unseen_hw520230314091110T 25327
tracermini_unseen_cym20230314100816T 18914
tracermini_unseen_cym20230314101230T 21364
tracermini_unseen_cym20230314100559T 24483
tracermini_unseen_hw520230314081212T 37662
tracermini_unseen_hw520230314082000T 34311
tracermini_un

#### 8.4 Save the Results

In [40]:
with open("submission.csv", "w") as f:
    # The first row must be "Id, Category"
    f.write("Id,Prediction\n")

    # For the rest of the rows, each image id corresponds to a predicted class.
    for key, value in seen_unseen_pred.items():
        # print the length of each sequence
        print(seen_unseen_dataset[key], len(value))
        # print the prediction
        for i in range(len(value)):
            f.write("{},{}\n".format(key+'_'+str(i), value[i]))

tracermini_hw101_test20230311112635T 42097
tracermini_hw101_test20230311111507T 37913
tracermini_hw101_test20230311112235T 43462
tracermini_hw101_test20230313011357T 38324
tracermini_hw101_test20230311111842T 42839
tracermini_hw101_test20230313010954T 39610
tracermini_hw101_test20230313010546T 38378
tracermini_hw101_test20230313011731T 37610
tracermini_hw101_test20230313013335T 38254
tracermini_hw101_test20230311111027T 44186
tracermini_hw101_test20230313012956T 38694
tracermini_hw101_test20230313010204T 39085
tracermini_unseen_hw520230314091844T 19545
tracermini_unseen_cym20230314101001T 283
tracermini_unseen_hw520230314082319T 31657
tracermini_unseen_cym20230314101636T 22690
tracermini_unseen_hw520230314083031T 21208
tracermini_unseen_hw520230314091110T 25327
tracermini_unseen_cym20230314100816T 18914
tracermini_unseen_cym20230314101230T 21364
tracermini_unseen_cym20230314100559T 24483
tracermini_unseen_hw520230314081212T 37662
tracermini_unseen_hw520230314082000T 34311
tracermini_un

In [48]:
#zip
import os
import zipfile
import datetime

def file2zip(packagePath, zipPath):
    '''
  :param packagePath: 文件夹路径
  :param zipPath: 压缩包路径
  :return:
  '''
    zip = zipfile.ZipFile(zipPath, 'w', zipfile.ZIP_DEFLATED)
    for path, dirNames, fileNames in os.walk(packagePath):
        fpath = path.replace(packagePath, '')
        for name in fileNames:
            fullName = os.path.join(path, name)
            name = fpath + '\\' + name
            zip.write(fullName, name)
    zip.close()


if __name__ == "__main__":
    # 文件夹路径
    packagePath = '/kaggle/working/'
    zipPath = '/kaggle/working/output.zip'
    if os.path.exists(zipPath):
        os.remove(zipPath)
    file2zip(packagePath, zipPath)
    print("打包完成")
    print(datetime.datetime.utcnow())


打包完成
2023-10-27 14:16:36.116966
