In [2]:
''' Training Frustum PointNets.

Author: Charles R. Qi
Date: September 2017
'''
from __future__ import print_function

import argparse
import importlib
import os
import shutil
import sys
from datetime import datetime

import numpy as np
import tensorflow as tf
# Resolve the project layout relative to the kernel's working directory
# (os.path.abspath('') is the current directory). The trailing comments give
# the expected locations -- NOTE(review): this only holds when the notebook is
# launched from inside train/; confirm before relying on it.
BASE_DIR = os.path.abspath('') # train/
ROOT_DIR = os.path.dirname(BASE_DIR) # frustum-pointnets/
# Make the training helpers and the model definitions importable. The membership
# check keeps the cell idempotent: re-running it no longer stacks duplicate
# entries onto sys.path.
for _extra_path in (BASE_DIR, os.path.join(ROOT_DIR, 'models')): # models/ allows directly importing models
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)
import provider
from train_util import get_batch


In [5]:
# Set training configurations
# Every tunable value of the run lives in this cell; the ones listed in the
# log header below are also written to LOG_DIR/log_train.txt.
EPOCH_CNT = 0 # presumably advanced by the training loop (not visible in this cell)
BATCH_SIZE = 32
NUM_POINT = 128
MAX_EPOCH = 201
BASE_LEARNING_RATE = 0.001
GPU_INDEX = 0
MOMENTUM = 0.9
OPTIMIZER = 'adam'
DECAY_STEP = 800000
DECAY_RATE = 0.5
NUM_CHANNEL = 4 # point feature channel
NUM_CLASSES = 2 # segmentation has two classes

model_name = 'frustum_pointnets_lite'
MODEL = importlib.import_module(model_name) # import network module
MODEL_FILE = os.path.join(ROOT_DIR, 'models', model_name+'.py')
LOG_DIR = os.path.join(ROOT_DIR, 'train', 'log_lite')
# makedirs (rather than mkdir) also creates a missing parent directory, so the
# cell does not abort when ROOT_DIR/train does not exist yet.
if not os.path.isdir(LOG_DIR):
    os.makedirs(LOG_DIR)
# Keep a copy of the model definition and the training script next to the logs.
# shutil.copy replaces the former shell `cp`: it needs no shell, tolerates
# spaces in paths and works on every platform. The backup stays best-effort:
# a missing or unreadable source is reported instead of stopping the cell.
for _backup_src in (MODEL_FILE, os.path.join(BASE_DIR, 'train.py')):
    try:
        shutil.copy(_backup_src, LOG_DIR)
    except (IOError, OSError) as _backup_err:
        print('WARNING: could not back up %s to %s: %s' % (_backup_src, LOG_DIR, _backup_err))
# Opened in 'w' mode, so re-running this cell truncates the previous log.
# The handle is intentionally left open; nothing in this cell closes it.
LOG_FOUT = open(os.path.join(LOG_DIR, 'log_train.txt'), 'w')
# Log header: one "NAME = value" line per setting, in a fixed order.
LOG_FOUT.writelines(
    '{} = {}\n'.format(label, value)
    for label, value in (
        ('BATCH_SIZE', BATCH_SIZE),
        ('NUM_POINT', NUM_POINT),
        ('MAX_EPOCH', MAX_EPOCH),
        ('BASE_LEARNING_RATE', BASE_LEARNING_RATE),
        ('GPU_INDEX', GPU_INDEX),
        ('MOMENTUM', MOMENTUM),
        ('OPTIMIZER', OPTIMIZER),
        ('DECAY_STEP', DECAY_STEP),
        ('DECAY_RATE', DECAY_RATE),
        ('NUM_CHANNEL', NUM_CHANNEL),
    )
)
# Flush so the header reaches disk even if the kernel dies before the next write.
LOG_FOUT.flush()

# Batch-norm decay settings -- the names suggest a decaying schedule, but the
# code that consumes them is outside this cell (TODO confirm usage).
BN_INIT_DECAY = 0.5
BN_DECAY_DECAY_RATE = 0.5
BN_DECAY_DECAY_STEP = float(DECAY_STEP) # same step count as DECAY_STEP, as a float
BN_DECAY_CLIP = 0.99

In [7]:
# Pick the dataset to train on: RUN_SEL indexes into RUN_MODES.
RUN_MODES = ['KITTI', 'NUSC']
RUN_SEL = 1 # select which dataset to run <=======

# Arguments common to both datasets. The training split additionally enables
# the random flip/shift options; the validation split does not pass them.
train_kwargs = dict(npoints=NUM_POINT, split='train', rotate_to_center=True,
    random_flip=True, random_shift=True, one_hot=True)
val_kwargs = dict(npoints=NUM_POINT, split='val', rotate_to_center=True,
    one_hot=True)

if RUN_MODES[RUN_SEL]!='KITTI':
    # Any mode other than KITTI loads the nuScenes pickles under ROOT_DIR/nuscenes
    # by overriding the dataset path.
    train_file = os.path.join(ROOT_DIR, 'nuscenes', 'nusc_carpedtruck_train.pickle')
    val_file = os.path.join(ROOT_DIR, 'nuscenes', 'nusc_carpedtruck_val.pickle')
    train_kwargs['overwritten_data_path'] = train_file
    val_kwargs['overwritten_data_path'] = val_file
# For KITTI no path override is passed, so the default data paths are used.

# Load Frustum Datasets.
TRAIN_DATASET = provider.FrustumDataset(**train_kwargs)
TEST_DATASET = provider.FrustumDataset(**val_kwargs)