In [1]:
from nuscenes.nuscenes import NuScenes
from nuscenes.can_bus.can_bus_api import NuScenesCanBus
from nuscenes.utils import splits
import mmcv
import numpy as np
import pprint
import argparse
import os
import torch
import logging
from path import Path
from utils import custom_transform
from dataset.KITTI_dataset import KITTI
from dataset.NuScenes_dataset import NuScenes_Dataset
from model import DeepVIO
from collections import defaultdict
from utils.kitti_eval import KITTI_tester, data_partition
import numpy as np
import math
import os
import glob
import numpy as np
import time
import scipy.io as sio
import torch
from PIL import Image
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
import math
from utils.utils import *

from utils.utils import rotationError, read_pose_from_text
from collections import Counter
from scipy.ndimage import gaussian_filter1d
from scipy.signal.windows import triang
from scipy.ndimage import convolve1d
from torch.utils.data import Dataset
from utils import custom_transform

In [2]:
def quaternion_rotation_matrix(Q):
    """
    Covert a quaternion into a full three-dimensional rotation matrix.
 
    Input
    :param Q: A 4 element array representing the quaternion (q0,q1,q2,q3) 
 
    Output
    :return: A 3x3 element matrix representing the full 3D rotation matrix. 
             This rotation matrix converts a point in the local reference 
             frame to a point in the global reference frame.
    """
    # Extract the values from Q
    q0 = Q[0]
    q1 = Q[1]
    q2 = Q[2]
    q3 = Q[3]
     
    # First row of the rotation matrix
    r00 = 2 * (q0 * q0 + q1 * q1) - 1
    r01 = 2 * (q1 * q2 - q0 * q3)
    r02 = 2 * (q1 * q3 + q0 * q2)
     
    # Second row of the rotation matrix
    r10 = 2 * (q1 * q2 + q0 * q3)
    r11 = 2 * (q0 * q0 + q2 * q2) - 1
    r12 = 2 * (q2 * q3 - q0 * q1)
     
    # Third row of the rotation matrix
    r20 = 2 * (q1 * q3 - q0 * q2)
    r21 = 2 * (q2 * q3 + q0 * q1)
    r22 = 2 * (q0 * q0 + q3 * q3) - 1
     
    # 3x3 rotation matrix
    rot_matrix = np.array([[r00, r01, r02],
                           [r10, r11, r12],
                           [r20, r21, r22]])
                            
    return rot_matrix

def euler_from_matrix(matrix):
    '''
    Extract the eular angle from a rotation matrix
    '''
    _EPS = np.finfo(float).eps * 4.0
    
    M = np.array(matrix, dtype=np.float64, copy=False)[:3, :3]
    cy = math.sqrt(M[0, 0] * M[0, 0] + M[1, 0] * M[1, 0])
    ay = math.atan2(-M[2, 0], cy)
    if ay < -math.pi / 2 + _EPS and ay > -math.pi / 2 - _EPS:  # pitch = -90 deg
        ax = 0
        az = math.atan2(-M[1, 2], -M[0, 2])
    elif ay < math.pi / 2 + _EPS and ay > math.pi / 2 - _EPS:
        ax = 0
        az = math.atan2(M[1, 2], M[0, 2])
    else:
        ax = math.atan2(M[2, 1], M[2, 2])
        az = math.atan2(M[1, 0], M[0, 0])
    return np.array([ax, ay, az])

def get_lds_kernel_window(kernel, ks, sigma):
    assert kernel in ['gaussian', 'triang', 'laplace']
    half_ks = (ks - 1) // 2
    if kernel == 'gaussian':
        base_kernel = [0.] * half_ks + [1.] + [0.] * half_ks
        kernel_window = gaussian_filter1d(base_kernel, sigma=sigma) / max(gaussian_filter1d(base_kernel, sigma=sigma))
    elif kernel == 'triang':
        kernel_window = triang(ks)
    else:
        laplace = lambda x: np.exp(-abs(x) / sigma) / (2. * sigma)
        kernel_window = list(map(laplace, np.arange(-half_ks, half_ks + 1))) / max(map(laplace, np.arange(-half_ks, half_ks + 1)))

    return kernel_window

In [119]:
class NuScenes_Val_Dataset(Dataset):
    def __init__(self, img_path_list, pose_rel_list, imu_list, args):
        super(NuScenes_Val_Dataset, self).__init__()
        self.img_path_list = img_path_list
        self.pose_rel_list = pose_rel_list
        self.imu_list = imu_list
        self.args = args
        
    def __getitem__(self, index):
        image_path_sequence = self.img_path_list[index]
        image_sequence = []
        for img_path in image_path_sequence:
            img_as_img = Image.open(img_path)
            img_as_img = TF.resize(img_as_img, size=(self.args.img_h, self.args.img_w))
            img_as_tensor = TF.to_tensor(img_as_img) - 0.5
            img_as_tensor = img_as_tensor.unsqueeze(0)
            image_sequence.append(img_as_tensor)
        image_sequence = torch.cat(image_sequence, 0)
        gt_sequence = self.pose_rel_list[index][:, :6]
        imu_sequence = torch.FloatTensor(self.imu_list[index])
        return image_sequence, imu_sequence, gt_sequence
    
    def __len__(self):
        return len(self.img_path_list)

class NuScenes_Dataset(Dataset):
    def __init__(self, 
                 data_root,
                 mode='train', # or 'val'
                 sequence_length=11,
                 max_imu_length=10,
                 cam_names = ["CAM_FRONT", "CAM_FRONT_RIGHT", "CAM_BACK_RIGHT", "CAM_BACK", "CAM_BACK_LEFT", "CAM_FRONT_LEFT"],
                 transform=None,
                 nusc=None,
                 nusc_can=None,
                 args=None):
        self.data_root = data_root
        if nusc is None:
            self.nusc = NuScenes(version='v1.0-trainval', dataroot=self.data_root, verbose=False)
        else:
            self.nusc = nusc
        if nusc_can is None:
            self.nusc_can = NuScenesCanBus(dataroot=self.data_root)
        else:
            self.nusc_can = nusc_can
        self.cam_names = cam_names
        self.sequence_length = sequence_length
        self.max_imu_length = max_imu_length
        self.transform = transform
        self.mode = mode
        if self.mode == 'train':
            self.make_train_dataset()
        self.args = args
    
    def get_available_scene_tokens(self):
        """Code from bevdet codebase - tools/data_converter/nuscenes_converter.py"""
        train_scenes = splits.train
        val_scenes = splits.val

        available_scenes = []
        for scene in self.nusc.scene:
            scene_token = scene['token']
            scene_rec = self.nusc.get('scene', scene_token)
            sample_rec = self.nusc.get('sample', scene_rec['first_sample_token'])
            sd_rec = self.nusc.get('sample_data', sample_rec['data']['LIDAR_TOP'])
            has_more_frames = True
            scene_not_exist = False
            while has_more_frames:
                lidar_path, boxes, _ = self.nusc.get_sample_data(sd_rec['token'])
                lidar_path = str(lidar_path)
                if os.getcwd() in lidar_path:
                    # path from lyftdataset is absolute path
                    lidar_path = lidar_path.split(f'{os.getcwd()}/')[-1]
                    # relative path
                if not mmcv.is_filepath(lidar_path):
                    scene_not_exist = True
                    break
                else:
                    break
            if scene_not_exist:
                continue
            available_scenes.append(scene)

        available_scene_names = [s['name'] for s in available_scenes]
        train_scenes = list(filter(lambda x: x in available_scene_names, train_scenes))
        val_scenes = list(filter(lambda x: x in available_scene_names, val_scenes))
        
        train_scenes = set([
            available_scenes[available_scene_names.index(s)]['token']
            for s in train_scenes
        ])
        val_scenes = set([
            available_scenes[available_scene_names.index(s)]['token']
            for s in val_scenes
        ])
        
        train_scenes = [self.nusc.get('scene', token) for token in train_scenes]
        val_scenes = [self.nusc.get('scene', token) for token in val_scenes]
        
        return train_scenes, val_scenes
    
    def get_scene_data(self, scene_record, cam_name):
        scene_name = scene_record['name']

        # Get images and poses of target scene
        first_sample_token = scene_record['first_sample_token']
        cur_sample = self.nusc.get('sample', first_sample_token)
        cur_sample_data = self.nusc.get('sample_data', cur_sample['data'][cam_name])

        scene_sample_data = []
        while True:
            try:
                scene_sample_data.append(cur_sample_data)
                cur_sample_data = self.nusc.get('sample_data', cur_sample_data['next'])
            except:
                break
        
        scene_imu_data = self.nusc_can.get_messages(scene_name, 'ms_imu')
        
        return scene_sample_data, scene_imu_data
    
    def format_scene_inputs(self, scene_sample_data, scene_imu_data):
        """ Collect image (12hz), pose (12hz), imu data (96hz) of target scene - single training input contains 2 images,  """
        # 1. 일단 각 scene input 모으기 - 2 images, 2 pose, 1 relative pose, 8 imu data
        scene_inputs = []
        for data_idx, cur_sample_data in enumerate(scene_sample_data):
            
            # 1. get image 
            cur_img_path = os.path.join(self.data_root, cur_sample_data['filename'])
            if cur_sample_data['next'] != "":
                next_sample_data = self.nusc.get('sample_data', cur_sample_data['next'])
                next_img_path = os.path.join(self.data_root, next_sample_data['filename'])
            else:
                break
            
            # 2. get ego pose
            # read_pose in utils.py
            cur_ego_pose = self.nusc.get('ego_pose', cur_sample_data['ego_pose_token'])
            trans = np.array(cur_ego_pose['translation'])
            trans = trans.reshape(3, -1)
            rot_mat = quaternion_rotation_matrix(cur_ego_pose['rotation']) # (w, x, y, z)
            cur_ego_pose_mat = np.concatenate((rot_mat, trans), axis=1)
            cur_ego_pose_mat = np.array(cur_ego_pose_mat).reshape(3, 4)
            cur_ego_pose_mat = np.concatenate((cur_ego_pose_mat, np.array([[0, 0, 0, 1]])), 0)
            
            next_ego_pose = self.nusc.get('ego_pose', next_sample_data['ego_pose_token'])
            trans = np.array(next_ego_pose['translation'])
            trans = trans.reshape(3, -1)
            rot_mat = quaternion_rotation_matrix(next_ego_pose['rotation']) # (w, x, y, z)
            next_ego_pose_mat = np.concatenate((rot_mat, trans), axis=1)
            next_ego_pose_mat = np.array(next_ego_pose_mat).reshape(3, 4)
            next_ego_pose_mat = np.concatenate((next_ego_pose_mat, np.array([[0, 0, 0, 1]])), 0)    

            # 3. get relative pose
            relative_pose = np.dot(np.linalg.inv(cur_ego_pose_mat), next_ego_pose_mat)
            R_rel = relative_pose[:3, :3]
            t_rel = relative_pose[:3, 3]

                # Extract the Eular angle from the relative rotation matrix
            x, y, z = euler_from_matrix(R_rel)
            theta = [x, y, z]

            pose_rel = np.concatenate((theta, t_rel))
            
            # 4. get imu data
            cur_timestamp = cur_sample_data['timestamp']
            next_timestamp = next_sample_data['timestamp']
            
            # get imu data between cur and next timestamp
            imu_data = []
            for imu in scene_imu_data:
                imu_timestamp = imu['utime']
                if imu_timestamp > cur_timestamp and imu_timestamp < next_timestamp:
                    data = imu['linear_accel'] + imu['rotation_rate']
                    imu_data.append(data)
            
            # if no matched imu data, skip
            if len(imu_data) <= 2:
                # continue
                return None
                
            # if imu data length is less than max_imu_length, pad with zeros
            if len(imu_data) < self.max_imu_length:
                imu_data = np.pad(imu_data, ((0, self.max_imu_length - len(imu_data)), (0, 0)), 'constant', constant_values=0)
            else:
                imu_data = imu_data[:self.max_imu_length]
            
            # 5. make training input
            training_input = {
                'cur_img_path': cur_img_path,
                'next_img_path': next_img_path,
                'cur_ego_pose': cur_ego_pose_mat,
                'next_ego_pose': next_ego_pose_mat,
                'pose_rel': pose_rel,
                'imu_data': imu_data
            }
            scene_inputs.append(training_input)
        return scene_inputs
    
    def segment_training_inputs(self, training_inputs):
        samples = []

        input_idx = 0
        while True:
            # get training input chunk of sequence_length
            training_input_chunk = training_inputs[input_idx : input_idx + (self.sequence_length-1)]
            input_idx += 1 # training sequence간 겹치는 images 존재함
            if len(training_input_chunk) < (self.sequence_length-1):
                break
            
            img_samples = []
            pose_samples = []
            for training_input in training_input_chunk:
                img_samples.append(training_input['cur_img_path'])
                pose_samples.append(training_input['cur_ego_pose'])
            img_samples.append(training_input_chunk[-1]['next_img_path'])
            pose_samples.append(training_input_chunk[-1]['next_ego_pose'])
            
            pose_rel_samples = []
            imu_samples = np.empty((0, 6))
            for training_input in training_input_chunk:
                pose_rel_samples.append(training_input['pose_rel'])
                imu_samples = np.vstack((imu_samples, np.array(training_input['imu_data'])))
            
            pose_samples = np.array(pose_samples)
            pose_rel_samples = np.array(pose_rel_samples)
            imu_samples = np.array(imu_samples)
    
            segment_rot = rotationError(pose_samples[0], pose_samples[-1])
            sample = {'imgs':img_samples, 'imus':imu_samples, 'gts': pose_rel_samples, 'rot': segment_rot}
            
            samples.append(sample)
            
        # Generate weights based on the rotation of the training segments
        # Weights are calculated based on the histogram of rotations according to the method in https://github.com/YyzHarry/imbalanced-regression
        rot_list = np.array([np.cbrt(item['rot']*180/np.pi) for item in samples])
        rot_range = np.linspace(np.min(rot_list), np.max(rot_list), num=10)
        indexes = np.digitize(rot_list, rot_range, right=False)
        num_samples_of_bins = dict(Counter(indexes))
        emp_label_dist = [num_samples_of_bins.get(i, 0) for i in range(1, len(rot_range)+1)]

        # Apply 1d convolution to get the smoothed effective label distribution
        lds_kernel_window = get_lds_kernel_window(kernel='gaussian', ks=7, sigma=5)
        eff_label_dist = convolve1d(np.array(emp_label_dist), weights=lds_kernel_window, mode='constant')

        weights = [np.float32(1/eff_label_dist[bin_idx-1]) for bin_idx in indexes]
        
        assert len(samples) == len(weights)
        
        return samples, weights
    
    def segment_val_inputs(self, scene_inputs):
        img_samples, pose_rel_samples, imu_samples = [], [], []
        input_idx = 0
        while True:
            val_input_chunk = scene_inputs[input_idx : input_idx + (self.sequence_length - 1)]
            input_idx = input_idx + (self.sequence_length - 1)
            if len(val_input_chunk) < (self.sequence_length-1):
                break
            
            imgs = []
            for val_input in val_input_chunk:
                imgs.append(val_input['cur_img_path'])
            imgs.append(val_input_chunk[-1]['next_img_path'])
            
            pose_rels = []
            imus = np.empty((0, 6))
            for val_input in val_input_chunk:
                pose_rels.append(val_input['pose_rel'])
                imus = np.vstack((imus, np.array(val_input['imu_data'])))

            img_samples.append(imgs)
            pose_rel_samples.append(np.array(pose_rels))
            imu_samples.append(np.array(imus))

        return img_samples, pose_rel_samples, imu_samples
                
    def filter_dataset(self, scenes):
        skipped_scene = []
        imuavail_scenes = []
        for idx, train_scene in enumerate(scenes):
            scene_name = train_scene['name']
            scene_idx = int(scene_name.split('-')[-1])
            if scene_idx in self.nusc_can.route_blacklist or scene_idx in self.nusc_can.can_blacklist: # skip if scene has no can_bus data
                skipped_scene.append(scene_name)
                continue
            imuavail_scenes.append(train_scene)
        
        target_scenes = []
        for idx, train_scene in enumerate(imuavail_scenes):
            avail_cam_num = 0
            for cam_name in self.cam_names:
                scene_sample_data, scene_imu_data = self.get_scene_data(train_scene, cam_name)
                scene_inputs = self.format_scene_inputs(scene_sample_data, scene_imu_data)
                if scene_inputs is None: # skip if there are any scene samples that have no associated imu data
                    break
                avail_cam_num += 1
            if avail_cam_num == len(self.cam_names):
                target_scenes.append(train_scene)
            else:
                skipped_scene.append(train_scene['name'])
        print('skipped scenes: {}'.format(len(skipped_scene)))
        return target_scenes
    
    def make_train_dataset(self):
        train_scenes, val_scenes = self.get_available_scene_tokens()
        target_train_scenes = self.filter_dataset(train_scenes)
        
        self.samples, self.weights = [], []
        for idx, train_scene in enumerate(target_train_scenes):
            
            # select camera one by one
            cam_name = self.cam_names[idx % len(self.cam_names)]
            
            # collect samples and weights                
            scene_sample_data, scene_imu_data = self.get_scene_data(train_scene, cam_name)
            scene_training_inputs = self.format_scene_inputs(scene_sample_data, scene_imu_data)
            scene_samples, scene_weights = self.segment_training_inputs(scene_training_inputs)
            self.samples.extend(scene_samples)
            self.weights.extend(scene_weights)
        
        print('total samples: {}'.format(len(self.samples)))
        assert len(self.samples) == len(self.weights)
    
    def get_val_dataset(self):
        _, val_scenes = self.get_available_scene_tokens()
        target_val_scenes = self.filter_dataset(val_scenes)
        
        total_samples_num = 0
        val_scene_datasets = []
        for idx, val_scene in enumerate(target_val_scenes):
            img_path_list, pose_rel_list, imu_list = [], [], []
            
            """
            TODO
            camera to ego transformation을 고려해야 하는지?
            """
            # select camera one by one
            cam_name = self.cam_names[idx % len(self.cam_names)]
            # cam_name = "CAM_FRONT"
            
            scene_sample_data, scene_imu_data = self.get_scene_data(val_scene, cam_name)
            scene_val_inputs = self.format_scene_inputs(scene_sample_data, scene_imu_data)
            img_samples, pose_rel_samples, imu_samples = self.segment_val_inputs(scene_val_inputs)
            
            img_path_list.extend(img_samples)
            pose_rel_list.extend(pose_rel_samples)
            imu_list.extend(imu_samples)
            
            total_samples_num += len(img_path_list)
            
            val_scene_datasets.append(NuScenes_Val_Dataset(img_path_list, pose_rel_list, imu_list, self.args))
            
            # TEMP
            # if idx == 2:
            #     break

        print('total samples: {}'.format(total_samples_num))
        
        return val_scene_datasets
    
    # the Dataset class implementation only works for training set
    def __getitem__(self, index):
        sample = self.samples[index]
        imgs = [np.asarray(Image.open(img)) for img in sample['imgs']]
        
        if self.transform is not None:
            # imgs, imus, gts = self.transform(imgs, np.copy(sample['imus']), np.copy(sample['gts']))
            imgs, imus, gts = self.transform(imgs, np.copy(sample['imus']).astype(np.float32), np.copy(sample['gts']).astype(np.float32))
        else:
            imus = np.copy(sample['imus'])
            gts = np.copy(sample['gts']).astype(np.float32)
        
        rot = sample['rot'].astype(np.float32)
        weight = self.weights[index]

        return imgs, imus, gts, rot, weight

    def __len__(self):
        return len(self.samples)

---

In [110]:
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
os.environ["CUDA_VISIBLE_DEVICES"]= "7"  # Set the GPUs 2 and 3 to use
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = '7'
# device = 'cuda:2'

In [4]:
#########################################################################################
dataroot = '/data/public/360_3D_OD_Dataset/nuscenes'
cam_names = ["CAM_FRONT", "CAM_FRONT_RIGHT", "CAM_BACK_RIGHT", "CAM_BACK", "CAM_BACK_LEFT", "CAM_FRONT_LEFT"]
#########################################################################################

nusc_can = NuScenesCanBus(dataroot=dataroot)
nusc = NuScenes(version='v1.0-trainval', dataroot=dataroot, verbose=False)

In [120]:
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--save_dir', type=str, default='./results', help='path to save the result')
parser.add_argument('--seed', type=int, default=0, help='random seed')

# jeho
# parser.add_argument('--img_w', type=int, default=512, help='image width')
# parser.add_argument('--img_h', type=int, default=256, help='image height')
parser.add_argument('--img_w', type=int, default=448, help='image width')
parser.add_argument('--img_h', type=int, default=256, help='image height')

parser.add_argument('--v_f_len', type=int, default=512, help='visual feature length')
parser.add_argument('--i_f_len', type=int, default=256, help='imu feature length')
parser.add_argument('--fuse_method', type=str, default='cat', help='fusion method [cat, soft, hard]')
parser.add_argument('--imu_dropout', type=float, default=0, help='dropout for the IMU encoder')
parser.add_argument('--rnn_hidden_size', type=int, default=1024, help='size of the LSTM latent')
parser.add_argument('--rnn_dropout_out', type=float, default=0.2, help='dropout for the LSTM output layer')
parser.add_argument('--rnn_dropout_between', type=float, default=0.2, help='dropout within LSTM')
parser.add_argument('--weight_decay', type=float, default=5e-6, help='weight decay for the optimizer')

parser.add_argument('--seq_len', type=int, default=11, help='sequence length for LSTM')
parser.add_argument('--workers', type=int, default=4, help='number of workers')

# jeho
# NuScenes - 68,000 training samples, total 25 epochs -> 1,700,000 iterations assuming batch size 1
parser.add_argument('--batch_size', type=int, default=64, help='batch size')
parser.add_argument('--epochs_warmup', type=int, default=5, help='number of epochs for warmup')
parser.add_argument('--epochs_joint', type=int, default=15, help='number of epochs for joint training')
parser.add_argument('--epochs_fine', type=int, default=5, help='number of epochs for finetuning')

# KITTI - 17,000 training samples, total 100 epochs -> 1,700,000 iterations assuming batch size 1
# parser.add_argument('--epochs_warmup', type=int, default=40, help='number of epochs for warmup')
# parser.add_argument('--epochs_joint', type=int, default=40, help='number of epochs for joint training')
# parser.add_argument('--epochs_fine', type=int, default=20, help='number of epochs for finetuning')


parser.add_argument('--lr_warmup', type=float, default=5e-4, help='learning rate for warming up stage')
parser.add_argument('--lr_joint', type=float, default=5e-5, help='learning rate for joint training stage')
parser.add_argument('--lr_fine', type=float, default=1e-6, help='learning rate for finetuning stage')
parser.add_argument('--eta', type=float, default=0.05, help='exponential decay factor for temperature')
parser.add_argument('--temp_init', type=float, default=5, help='initial temperature for gumbel-softmax')
parser.add_argument('--Lambda', type=float, default=3e-5, help='penalty factor for the visual encoder usage')

parser.add_argument('--experiment_name', type=str, default='experiment', help='experiment name')
parser.add_argument('--optimizer', type=str, default='Adam', help='type of optimizer [Adam, SGD]')

parser.add_argument('--pretrain_flownet',type=str, default='./pretrained_models/flownets_bn_EPE2.459.pth.tar', help='wehther to use the pre-trained flownet')
parser.add_argument('--pretrain', type=str, default=None, help='path to the pretrained model')
parser.add_argument('--hflip', default=False, action='store_true', help='whether to use horizonal flipping as augmentation')
parser.add_argument('--color', default=False, action='store_true', help='whether to use color augmentations')

parser.add_argument('--print_frequency', type=int, default=10, help='print frequency for loss values')
parser.add_argument('--weighted', default=False, action='store_true', help='whether to use weighted sum')

args = parser.parse_args(args=[])

In [161]:
from tqdm import tqdm 

def plotPath_2D(seq, poses_gt_mat, poses_est_mat, plot_path_dir, decision, speed, window_size):
    
    # Apply smoothing to the decision
    decision = np.insert(decision, 0, 1)
    decision = moving_average(decision, window_size)

    fontsize_ = 10
    plot_keys = ["Ground Truth", "Ours"]
    start_point = [0, 0]
    style_pred = 'b-'
    style_gt = 'r-'
    style_O = 'ko'

    # get the value
    x_gt = np.asarray([pose[0, 3] for pose in poses_gt_mat])
    y_gt = np.asarray([pose[1, 3] for pose in poses_gt_mat])
    z_gt = np.asarray([pose[2, 3] for pose in poses_gt_mat])

    x_pred = np.asarray([pose[0, 3] for pose in poses_est_mat])
    y_pred = np.asarray([pose[1, 3] for pose in poses_est_mat])
    z_pred = np.asarray([pose[2, 3] for pose in poses_est_mat])

    # Plot 2d trajectory estimation map
    fig = plt.figure(figsize=(6, 6), dpi=100)
    ax = plt.gca()
    plt.plot(x_gt, z_gt, style_gt, label=plot_keys[0])
    plt.plot(x_pred, z_pred, style_pred, label=plot_keys[1])
    plt.plot(start_point[0], start_point[1], style_O, label='Start Point')
    plt.legend(loc="upper right", prop={'size': fontsize_})
    plt.xlabel('x (m)', fontsize=fontsize_)
    plt.ylabel('z (m)', fontsize=fontsize_)
    # set the range of x and y
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    xmean = np.mean(xlim)
    ymean = np.mean(ylim)
    plot_radius = max([abs(lim - mean_)
                       for lims, mean_ in ((xlim, xmean),
                                           (ylim, ymean))
                       for lim in lims])
    ax.set_xlim([xmean - plot_radius, xmean + plot_radius])
    ax.set_ylim([ymean - plot_radius, ymean + plot_radius])

    plt.title('2D path')
    png_title = "{}_path_2d".format(seq)
    plt.savefig(plot_path_dir + "/" + png_title + ".png", bbox_inches='tight', pad_inches=0.1)
    plt.close()


    # Plot 2d xy trajectory estimation map
    fig = plt.figure(figsize=(6, 6), dpi=100)
    ax = plt.gca()
    plt.plot(x_gt, y_gt, style_gt, label=plot_keys[0])
    plt.plot(x_pred, y_pred, style_pred, label=plot_keys[1])
    plt.plot(start_point[0], start_point[1], style_O, label='Start Point')
    plt.legend(loc="upper right", prop={'size': fontsize_})
    plt.xlabel('x (m)', fontsize=fontsize_)
    plt.ylabel('y (m)', fontsize=fontsize_)
    # set the range of x and y
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    xmean = np.mean(xlim)
    ymean = np.mean(ylim)
    plot_radius = max([abs(lim - mean_)
                       for lims, mean_ in ((xlim, xmean),
                                           (ylim, ymean))
                       for lim in lims])
    ax.set_xlim([xmean - plot_radius, xmean + plot_radius])
    ax.set_ylim([ymean - plot_radius, ymean + plot_radius])

    plt.title('2D path')
    png_title = "{}_path_2d_xy".format(seq)
    plt.savefig(plot_path_dir + "/" + png_title + ".png", bbox_inches='tight', pad_inches=0.1)
    plt.close()
    
    # 3D trajectory map 
    fig = plt.figure(figsize=(6, 6), dpi=100)
    ax = fig.add_subplot(111, projection='3d')
    ax.plot(x_gt, y_gt, z_gt, style_gt, label=plot_keys[0])
    ax.plot(x_pred, y_pred, z_pred, style_pred, label=plot_keys[1])
    ax.plot(0, 0, 0, style_O, label='Start Point')
    plt.legend(loc="upper right", prop={'size': fontsize_})
    plt.xlabel('x (m)', fontsize=fontsize_)
    plt.ylabel('y (m)', fontsize=fontsize_)
    ax.set_zlabel('z (m)')
    
    plt.title('3D path')
    png_title = "{}_path_3d".format(seq)
    plt.savefig(plot_path_dir + "/" + png_title + ".png", bbox_inches='tight', pad_inches=0.1)
    plt.close()
    


    # Plot decision hearmap
    fig = plt.figure(figsize=(8, 6), dpi=100)
    ax = plt.gca()
    cout = np.insert(decision, 0, 0) * 100
    cax = plt.scatter(x_pred, z_pred, marker='o', c=cout)
    plt.xlabel('x (m)', fontsize=fontsize_)
    plt.ylabel('z (m)', fontsize=fontsize_)
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    xmean = np.mean(xlim)
    ymean = np.mean(ylim)
    ax.set_xlim([xmean - plot_radius, xmean + plot_radius])
    ax.set_ylim([ymean - plot_radius, ymean + plot_radius])
    max_usage = max(cout)
    min_usage = min(cout)
    ticks = np.floor(np.linspace(min_usage, max_usage, num=5))
    cbar = fig.colorbar(cax, ticks=ticks)
    cbar.ax.set_yticklabels([str(i) + '%' for i in ticks])

    plt.title('decision heatmap with window size {}'.format(window_size))
    png_title = "{}_decision_smoothed".format(seq)
    plt.savefig(plot_path_dir + "/" + png_title + ".png", bbox_inches='tight', pad_inches=0.1)
    plt.close()

    # Plot the speed map
    fig = plt.figure(figsize=(8, 6), dpi=100)
    ax = plt.gca()
    cout = speed
    cax = plt.scatter(x_pred, z_pred, marker='o', c=cout)
    plt.xlabel('x (m)', fontsize=fontsize_)
    plt.ylabel('z (m)', fontsize=fontsize_)
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    xmean = np.mean(xlim)
    ymean = np.mean(ylim)
    ax.set_xlim([xmean - plot_radius, xmean + plot_radius])
    ax.set_ylim([ymean - plot_radius, ymean + plot_radius])
    max_speed = max(cout)
    min_speed = min(cout)
    ticks = np.floor(np.linspace(min_speed, max_speed, num=5))
    cbar = fig.colorbar(cax, ticks=ticks)
    cbar.ax.set_yticklabels([str(i) + 'm/s' for i in ticks])

    plt.title('speed heatmap')
    png_title = "{}_speed".format(seq)
    plt.savefig(plot_path_dir + "/" + png_title + ".png", bbox_inches='tight', pad_inches=0.1)
    plt.close()


def kitti_err_cal(pose_est_mat, pose_gt_mat):

    # metric lengths in meters
    
    lengths = [100, 200, 300, 400, 500, 600, 700, 800]
    # lengths = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
    
    num_lengths = len(lengths)

    err = []
    dist, speed = trajectoryDistances(pose_gt_mat)
    step_size = 10  # 10Hz

    for first_frame in range(0, len(pose_gt_mat), step_size):

        calculated_metric_length = 0
        
        for i in range(num_lengths):
            metric_length = lengths[i]
            last_frame = lastFrameFromSegmentLength(dist, first_frame, metric_length)
            # Continue if sequence not long enough
            if last_frame == -1 or last_frame >= len(pose_est_mat) or first_frame >= len(pose_est_mat):
                continue
            
            calculated_metric_length += 1

            pose_delta_gt = np.dot(np.linalg.inv(pose_gt_mat[first_frame]), pose_gt_mat[last_frame])
            pose_delta_result = np.dot(np.linalg.inv(pose_est_mat[first_frame]), pose_est_mat[last_frame])
            
            r_err = rotationError(pose_delta_result, pose_delta_gt)
            t_err = translationError(pose_delta_result, pose_delta_gt)

            err.append([first_frame, r_err / metric_length, t_err / metric_length, metric_length])

        # print("calculated_metric_length: ", calculated_metric_length)
        
    t_rel, r_rel = computeOverallErr(err)
    
    # print("t_rel: ", t_rel)
    # print("r_rel: ", r_rel)
    
    return err, t_rel, r_rel, np.asarray(speed)

def kitti_eval(pose_est, dec_est, pose_gt):
    
    # First decision is always true
    dec_est = np.insert(dec_est, 0, 1)
    
    # Calculate the translational and rotational RMSE
    t_rmse, r_rmse = rmse_err_cal(pose_est, pose_gt)

    # Transfer to 3x4 pose matrix
    pose_est_mat = path_accu(pose_est)
    pose_gt_mat = path_accu(pose_gt)

    # Using KITTI metric
    err_list, t_rel, r_rel, speed = kitti_err_cal(pose_est_mat, pose_gt_mat)
    
    """
    여기 아닌것 같음 - Kitti로도 찍어보자
    찍어보니 맞음 - kitti와 nuscenes와의 coord. system 차이에 의해 발생?
    """
    t_rel = t_rel * 100
    r_rel = r_rel / np.pi * 180 * 100
    r_rmse = r_rmse / np.pi * 180
    usage = np.mean(dec_est) * 100

    return pose_est_mat, pose_gt_mat, t_rel, r_rel, t_rmse, r_rmse, usage, speed

class NuScenes_Tester():
    def __init__(self, val_scene_datasets):
        super(NuScenes_Tester, self).__init__()
        self.val_scene_datasets = val_scene_datasets
    
    def test_one_scene(self, model, scene_dataset, selection, num_gpu=1, p=0.5):
        hc = None
        pose_list, decision_list, probs_list, pose_rel_gt_list = [], [], [], []
        for i, (image_seq, imu_seq, gt_seq) in tqdm(enumerate(scene_dataset), total=len(scene_dataset), smoothing=0.9):
            x_in = image_seq.unsqueeze(0).repeat(num_gpu,1,1,1,1).cuda()
            i_in = imu_seq.unsqueeze(0).repeat(num_gpu,1,1).cuda()
            with torch.no_grad():
                pose, decision, probs, hc = model(x_in, i_in, is_first=(i==0), hc=hc, selection=selection, p=p)
            pose_list.append(pose[0,:,:].detach().cpu().numpy())
            decision_list.append(decision[0,:,:].detach().cpu().numpy()[:, 0])
            probs_list.append(probs[0,:,:].detach().cpu().numpy())
            pose_rel_gt_list.append(np.array(gt_seq))
        pose_est = np.vstack(pose_list)
        dec_est = np.hstack(decision_list)
        prob_est = np.vstack(probs_list)
        pose_rel_gt_list = np.vstack(pose_rel_gt_list)
        return pose_est, dec_est, prob_est, pose_rel_gt_list
    
    def eval(self, model, selection, num_gpu=1, p=0.5):
        self.errors = []
        self.est = []

        for i, scene_dataset in enumerate(self.val_scene_datasets):
            print(f'testing sequence {i}')
            
            pose_est, dec_est, prob_est, pose_rel_gt_list = self.test_one_scene(model, scene_dataset, selection, num_gpu=num_gpu, p=p)  
            pose_est_global, pose_gt_global, t_rel, r_rel, t_rmse, r_rmse, usage, speed = kitti_eval(pose_est, dec_est, pose_rel_gt_list)
            
            self.est.append({'pose_est_global':pose_est_global, 'pose_gt_global':pose_gt_global, 'decs':dec_est, 'probs':prob_est, 'speed':speed})
            self.errors.append({'t_rel':t_rel, 'r_rel':r_rel, 't_rmse':t_rmse, 'r_rmse':r_rmse, 'usage':usage})
        
        return self.errors
    
    def generate_plots(self, save_dir, window_size):
        for i, scene_dataset in enumerate(self.val_scene_datasets):
            plotPath_2D(scene_dataset, 
                        self.est[i]['pose_gt_global'], 
                        self.est[i]['pose_est_global'], 
                        save_dir, 
                        self.est[i]['decs'], 
                        self.est[i]['speed'], 
                        window_size)
            
    def save_text(self, save_dir):
        for i, scene_dataset in enumerate(self.val_scene_datasets):
            path = save_dir/'{}_pred.txt'.format(scene_dataset)
            saveSequence(self.est[i]['pose_est_global'], path)
            print('scene_dataset {} saved'.format(scene_dataset))

In [164]:
class data_partition():
    def __init__(self, opt, folder):
        super(data_partition, self).__init__()
        self.opt = opt
        self.data_dir = opt.data_dir
        self.seq_len = opt.seq_len
        self.folder = folder
        self.load_data()

    def load_data(self):
        image_dir = self.data_dir + '/sequences/'
        imu_dir = self.data_dir + '/imus/'
        pose_dir = self.data_dir + '/poses/'

        self.img_paths = glob.glob('{}{}/image_2/*.png'.format(image_dir, self.folder))
        self.imus = sio.loadmat('{}{}.mat'.format(imu_dir, self.folder))['imu_data_interp']
        self.poses, self.poses_rel = read_pose_from_text('{}{}.txt'.format(pose_dir, self.folder))
        self.img_paths.sort()

        self.img_paths_list, self.poses_list, self.imus_list = [], [], []
        start = 0
        n_frames = len(self.img_paths)
        while start + self.seq_len < n_frames:
            self.img_paths_list.append(self.img_paths[start:start + self.seq_len])
            self.poses_list.append(self.poses_rel[start:start + self.seq_len - 1])
            self.imus_list.append(self.imus[start * 10:(start + self.seq_len - 1) * 10 + 1])
            start += self.seq_len - 1
        self.img_paths_list.append(self.img_paths[start:])
        self.poses_list.append(self.poses_rel[start:])
        self.imus_list.append(self.imus[start * 10:])

    def __len__(self):
        return len(self.img_paths_list)

    def __getitem__(self, i):
        image_path_sequence = self.img_paths_list[i]
        image_sequence = []
        for img_path in image_path_sequence:
            img_as_img = Image.open(img_path)
            img_as_img = TF.resize(img_as_img, size=(self.opt.img_h, self.opt.img_w))
            img_as_tensor = TF.to_tensor(img_as_img) - 0.5
            img_as_tensor = img_as_tensor.unsqueeze(0)
            image_sequence.append(img_as_tensor)
        image_sequence = torch.cat(image_sequence, 0)
        imu_sequence = torch.FloatTensor(self.imus_list[i])
        gt_sequence = self.poses_list[i][:, :6]
        return image_sequence, imu_sequence, gt_sequence


class KITTI_tester():
    def __init__(self, args):
        super(KITTI_tester, self).__init__()
        
        # generate data loader for each path
        self.dataloader = []
        for seq in args.val_seq:
            self.dataloader.append(data_partition(args, seq))

        self.args = args
    
    def test_one_path(self, net, df, selection, num_gpu=1, p=0.5):
        hc = None
        pose_list, decision_list, probs_list= [], [], []
        for i, (image_seq, imu_seq, gt_seq) in tqdm(enumerate(df), total=len(df), smoothing=0.9):  
            x_in = image_seq.unsqueeze(0).repeat(num_gpu,1,1,1,1).cuda()
            i_in = imu_seq.unsqueeze(0).repeat(num_gpu,1,1).cuda()
            with torch.no_grad():
                pose, decision, probs, hc = net(x_in, i_in, is_first=(i==0), hc=hc, selection=selection, p=p)
            pose_list.append(pose[0,:,:].detach().cpu().numpy())
            decision_list.append(decision[0,:,:].detach().cpu().numpy()[:, 0])
            probs_list.append(probs[0,:,:].detach().cpu().numpy())
        pose_est = np.vstack(pose_list)
        dec_est = np.hstack(decision_list)
        prob_est = np.vstack(probs_list)        
        return pose_est, dec_est, prob_est

    def eval(self, net, selection, num_gpu=1, p=0.5):
        self.errors = []
        self.est = []
        for i, seq in enumerate(self.args.val_seq):
            print(f'testing sequence {seq}')
            pose_est, dec_est, prob_est = self.test_one_path(net, self.dataloader[i], selection, num_gpu=num_gpu, p=p)            
            pose_est_global, pose_gt_global, t_rel, r_rel, t_rmse, r_rmse, usage, speed = kitti_eval(pose_est, dec_est, self.dataloader[i].poses_rel)
            
            self.est.append({'pose_est_global':pose_est_global, 'pose_gt_global':pose_gt_global, 'decs':dec_est, 'probs':prob_est, 'speed':speed})
            self.errors.append({'t_rel':t_rel, 'r_rel':r_rel, 't_rmse':t_rmse, 'r_rmse':r_rmse, 'usage':usage})
            
        return self.errors

    def generate_plots(self, save_dir, window_size):
        for i, seq in enumerate(self.args.val_seq):
            plotPath_2D(seq, 
                        self.est[i]['pose_gt_global'], 
                        self.est[i]['pose_est_global'], 
                        save_dir, 
                        self.est[i]['decs'], 
                        self.est[i]['speed'], 
                        window_size)
    
    def save_text(self, save_dir):
        for i, seq in enumerate(self.args.val_seq):
            path = save_dir/'{}_pred.txt'.format(seq)
            saveSequence(self.est[i]['pose_est_global'], path)
            print('Seq {} saved'.format(seq))

In [163]:
mmcv.mkdir_or_exist(args.save_dir)
checkpoints_dir = os.path.join(args.save_dir, "experiment_1")
mmcv.mkdir_or_exist(checkpoints_dir)

# Create logs
logger = logging.getLogger(args.experiment_name)
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger.info('----------------------------------------TRAINING----------------------------------')
logger.info('PARAMETER ...')
logger.info(args)

# Load the dataset
transform_train = [custom_transform.ToTensor(), custom_transform.Resize((args.img_h, args.img_w))]
if args.hflip:
    transform_train += [custom_transform.RandomHorizontalFlip()]
if args.color:
    transform_train += [custom_transform.RandomColorAug()]
transform_train = custom_transform.Compose(transform_train)

##############################################################
max_imu_length = 11 # KITTI
##############################################################

val_dataset = NuScenes_Dataset(dataroot,
                                 mode='val',
                             sequence_length=args.seq_len,
                             max_imu_length=max_imu_length,
                             cam_names=cam_names,
                             transform=transform_train,
                             nusc=nusc,
                             nusc_can=nusc_can,
                             args=args)

val_scene_datasets = val_dataset.get_val_dataset()

AttributeError: 'Namespace' object has no attribute 'hflip'

In [165]:
# GPU selections
str_ids = device.split(',')
gpu_ids = []
for str_id in str_ids:
    id = int(str_id)
    if id >= 0:
        gpu_ids.append(id)
if len(gpu_ids) > 0:
    torch.cuda.set_device(gpu_ids[0])

# Initialize the tester
tester = NuScenes_Tester(val_scene_datasets)

# Model initialization
model = DeepVIO(args)

ckpt_path = './pretrained_models/vf_512_if_256_3e-05.model'
model.load_state_dict(torch.load(ckpt_path))
print('load model %s'%ckpt_path)

# Feed model to GPU
model.cuda(gpu_ids[0])
model = torch.nn.DataParallel(model, device_ids = gpu_ids)

model.eval()
errors = tester.eval(model, 'gumbel-softmax', num_gpu=len(gpu_ids))

load model ./pretrained_models/vf_512_if_256_3e-05.model
testing sequence 0


100%|██████████| 23/23 [00:10<00:00,  2.15it/s]


t_rel:  1.4327209127773595
r_rel:  0.004775360509008873
testing sequence 1


100%|██████████| 22/22 [00:08<00:00,  2.59it/s]


t_rel:  1.7356491427690386
r_rel:  0.005148786236864521
testing sequence 2


100%|██████████| 22/22 [00:09<00:00,  2.35it/s]

t_rel:  1.4624186933996497
r_rel:  0.004441598131854362





In [166]:
tester.generate_plots('./results/experiment_1/nuscenes/', 30)

In [152]:
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--data_dir', type=str, default='./data/kitti/', help='path to the dataset')
parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
parser.add_argument('--save_dir', type=str, default='./results', help='path to save the result')
parser.add_argument('--seq_len', type=int, default=11, help='sequence length for LSTM')

parser.add_argument('--train_seq', type=list, default=['00', '01', '02', '04', '06', '08', '09'], help='sequences for training')
parser.add_argument('--val_seq', type=list, default=['07'], help='sequences for validation')
parser.add_argument('--seed', type=int, default=0, help='random seed')

parser.add_argument('--img_w', type=int, default=512, help='image width')
parser.add_argument('--img_h', type=int, default=256, help='image height')
parser.add_argument('--v_f_len', type=int, default=512, help='visual feature length')
parser.add_argument('--i_f_len', type=int, default=256, help='imu feature length')
parser.add_argument('--fuse_method', type=str, default='cat', help='fusion method [cat, soft, hard]')
parser.add_argument('--imu_dropout', type=float, default=0, help='dropout for the IMU encoder')

parser.add_argument('--rnn_hidden_size', type=int, default=1024, help='size of the LSTM latent')
parser.add_argument('--rnn_dropout_out', type=float, default=0.2, help='dropout for the LSTM output layer')
parser.add_argument('--rnn_dropout_between', type=float, default=0.2, help='dropout within LSTM')

parser.add_argument('--workers', type=int, default=4, help='number of workers')
parser.add_argument('--experiment_name', type=str, default='test', help='experiment name')
parser.add_argument('--model', type=str, default='./pretrain_models/vf_512_if_256_3e-05.model', help='path to the pretrained model')
args = parser.parse_args(args=[])


In [153]:
 # Initialize the tester
tester = KITTI_tester(args)

In [154]:
errors = tester.eval(model, 'gumbel-softmax', num_gpu=len(gpu_ids))

testing sequence 07


100%|██████████| 110/110 [00:24<00:00,  4.51it/s]

t_rel:  0.017259536462358335
r_rel:  0.00012587293560097494





In [167]:
errors

[{'t_rel': 143.27209127773594,
  'r_rel': 27.360800281965293,
  't_rmse': 1.1589466929184404,
  'r_rmse': 0.7952650806784035,
  'usage': 18.26086938381195},
 {'t_rel': 173.56491427690386,
  'r_rel': 29.500372098738243,
  't_rmse': 1.1845246747014186,
  'r_rmse': 0.7671784817745677,
  'usage': 32.72727131843567},
 {'t_rel': 146.24186933996498,
  'r_rel': 25.448482724844588,
  't_rmse': 1.1813577173057124,
  'r_rmse': 0.8037528325557416,
  'usage': 17.72727221250534}]

In [157]:
tester.generate_plots("./results/experiment_1/", 30)

In [None]:
# GPU selections
str_ids = device.split(',')
gpu_ids = []
for str_id in str_ids:
    id = int(str_id)
    if id >= 0:
        gpu_ids.append(id)
if len(gpu_ids) > 0:
    torch.cuda.set_device(gpu_ids[0])

# Initialize the tester
tester = KITTI_tester(args)
    
# Model initialization
model = DeepVIO(args)

# Continual training or not
if args.pretrain is not None:
    model.load_state_dict(torch.load(args.pretrain))
    print('load model %s'%args.pretrain)
    logger.info('load model %s'%args.pretrain)
else:
    print('Training from scratch')
    logger.info('Training from scratch')

# Use the pre-trained flownet or not
if args.pretrain_flownet and args.pretrain is None:
    pretrained_w = torch.load(args.pretrain_flownet, map_location='cpu')
    model_dict = model.Feature_net.state_dict()
    update_dict = {k: v for k, v in pretrained_w['state_dict'].items() if k in model_dict}
    model_dict.update(update_dict)
    model.Feature_net.load_state_dict(model_dict)

# Feed model to GPU
# model.to(device)
# model = torch.nn.DataParallel(model, device_ids = [device])

# model = model.cuda()
model.cuda(gpu_ids[0])
model = torch.nn.DataParallel(model, device_ids = gpu_ids)

pretrain = args.pretrain 
if args.pretrain is None or pretrain[-5:] == 'model':
    init_epoch = 0
else:
    init_epoch = int(pretrain[-7:-4])+1

# Initialize the optimizer
if args.optimizer == 'SGD':
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
elif args.optimizer == 'Adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.999), 
                                    eps=1e-08, weight_decay=args.weight_decay)

best = 10000

for ep in range(init_epoch, args.epochs_warmup+args.epochs_joint+args.epochs_fine):
    lr, selection, temp = update_status(ep, args, model)
    optimizer.param_groups[0]['lr'] = lr
    message = f'Epoch: {ep}, lr: {lr}, selection: {selection}, temperaure: {temp:.5f}'
    print(message)
    logger.info(message)
    
    model.train()
    avg_pose_loss, avg_penalty_loss = train(model, optimizer, train_loader, selection, temp, logger, ep, p=0.5)
    
    if ep > args.epochs_warmup+args.epochs_joint:
        # Save the model after training
        torch.save(model.module.state_dict(), f'{checkpoints_dir}/{ep:003}.pth')
        message = f'Epoch {ep} training finished, pose loss: {avg_pose_loss:.6f}, penalty_loss: {avg_penalty_loss:.6f}, model saved'
        print(message)
        logger.info(message)
    
        # Evaluate the model
        # print('Evaluating the model')
        # logger.info('Evaluating the model')
        # with torch.no_grad(): 
        #     model.eval()
        #     errors = tester.eval(model, selection='gumbel-softmax', num_gpu=len(gpu_ids))
    
        # t_rel = np.mean([errors[i]['t_rel'] for i in range(len(errors))])
        # r_rel = np.mean([errors[i]['r_rel'] for i in range(len(errors))])
        # t_rmse = np.mean([errors[i]['t_rmse'] for i in range(len(errors))])
        # r_rmse = np.mean([errors[i]['r_rmse'] for i in range(len(errors))])
        # usage = np.mean([errors[i]['usage'] for i in range(len(errors))])

        # if t_rel < best:
        #     best = t_rel 
        #     torch.save(model.module.state_dict(), f'{checkpoints_dir}/best_{best:.2f}.pth')
    
        # message = f'Epoch {ep} evaluation finished , t_rel: {t_rel:.4f}, r_rel: {r_rel:.4f}, t_rmse: {t_rmse:.4f}, r_rmse: {r_rmse:.4f}, usage: {usage:.4f}, best t_rel: {best:.4f}'
        # logger.info(message)
        # print(message)

message = f'Training finished, best t_rel: {best:.4f}'
logger.info(message)
print(message)