Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #16 from farrell236/master
se3 pose estimation example
- Loading branch information
Showing
4 changed files
with
575 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# Geomstats Pose Estimation Example | ||
|
||
This example trains a pose estimation network using the SE3 Geodesic Loss function. | ||
|
||
## Requirements: | ||
|
||
* geomstats | ||
* imageio | ||
* numpy | ||
* scikit-image | ||
* tensorflow | ||
* tqdm | ||
|
||
## Generating Dataset | ||
|
||
This example uses the original King's College [dataset](http://mi.eng.cam.ac.uk/projects/relocalisation/#dataset) by Kendall et al. To create TFRecords of the dataset, run: | ||
|
||
``` | ||
python make_dataset_kingscollege.py \ | ||
--root_dir path/to/KingsCollege \ | ||
--dataset dataset_train.txt \ | ||
--out_file dataset_train.tfrecord | ||
``` | ||
|
||
Run this command twice: once with `--dataset dataset_train.txt` and once with `--dataset dataset_test.txt`, giving each run its own `--out_file`. | ||
|
||
## Train Network | ||
|
||
To train the network run: | ||
|
||
``` | ||
python train_se3_kingscollege.py --dataset path/to/dataset_train.tfrecord | ||
``` | ||
|
||
Optional Parameters: | ||
|
||
``` | ||
Train SE3 PoseNet Inception v1 Model. | ||
optional arguments: | ||
-h, --help show this help message and exit | ||
--batch_size BATCH_SIZE | ||
Batch size to train. | ||
--init_lr INIT_LR Initial Learning rate. | ||
--max_iter MAX_ITER The number of iteration to train. | ||
--epsilon EPSILON Gradient Epsilon | ||
--snapshot SNAPSHOT Save model weights every X iterations | ||
--dataset DATASET Training dataset | ||
--model_dir MODEL_DIR | ||
The path to the model directory. | ||
--logs_path LOGS_PATH | ||
The path to the logs directory. | ||
--resume Resume training from previous saved checkpoint. | ||
--cuda CUDA Specify default GPU to use. | ||
--debug Enables debugging mode. | ||
``` | ||
|
||
|
||
## Evaluating | ||
|
||
To evaluate the trained model: | ||
|
||
``` | ||
python test_se3_kingscollege.py --dataset path/to/dataset_test.tfrecord --model_dir /path/to/model_saved_weights | ||
``` | ||
|
||
Optional Parameters: | ||
|
||
``` | ||
Test SE3 PoseNet Inception v1 Model. | ||
optional arguments: | ||
-h, --help show this help message and exit | ||
--model_dir MODEL_DIR | ||
The path to the model directory. | ||
--dataset DATASET The path to the TFRecords dataset. | ||
--cuda CUDA Specify default GPU to use. | ||
--debug Enables debugging mode. | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
''' | ||
Original | ||
File: create_posenet_lmdb_dataset.py | ||
Link: https://git.io/fpPuw | ||
Author: Alex Kendall <https://alexgkendall.com> | ||
Modified | ||
File: make_dataset_kingscollege.py | ||
Author: Benjamin Hou <bh1511@imperial.ac.uk> | ||
Download KingsCollege Dataset: | ||
http://mi.eng.cam.ac.uk/projects/relocalisation/#dataset | ||
''' | ||
|
||
import argparse | ||
import imageio | ||
import logging | ||
import random | ||
import sys | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
|
||
from skimage import exposure | ||
from tqdm import tqdm | ||
from geomstats.special_orthogonal_group import SpecialOrthogonalGroup | ||
|
||
|
||
# command line argument parser | ||
# Option table for the dataset-creation CLI. Each entry pairs an option
# string with the keyword arguments forwarded unchanged to ``add_argument``;
# the order here is the order the options are registered in.
_DATASET_CLI_OPTIONS = (
    ('--root_dir', dict(
        required=True, type=str,
        help='Path to KingsCollege Dataset root directory.')),
    ('--dataset', dict(
        required=True, type=str,
        help='Dataset text file')),
    ('--out_file', dict(
        required=True, type=str,
        help='Path to save TFRecords')),
    ('--hist_norm', dict(
        default=False, action='store_true',
        help='Histogram normalise image')),
    ('--verbose', dict(
        default=False, action='store_true',
        help='Verbose mode')),
)

ARGPARSER = argparse.ArgumentParser(
    description='Create KingsCollege TFRecords Dataset')
for _option, _settings in _DATASET_CLI_OPTIONS:
    ARGPARSER.add_argument(_option, **_settings)
|
||
|
||
# Tensorflow feature wrapper | ||
def _bytes_feature(value):
    """Return a ``tf.train.Feature`` whose bytes list holds only ``value``."""
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)
|
||
|
||
def _int64_feature(value):
    """Return a ``tf.train.Feature`` whose int64 list holds only ``value``."""
    payload = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=payload)
|
||
|
||
def _float_feature(value):
    """Return a ``tf.train.Feature`` whose float list holds only ``value``."""
    payload = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=payload)
|
||
|
||
def _read_pose_labels(label_path, root_dir):
    """Parse a KingsCollege label file into image paths and pose tuples.

    The first three lines of the file are header lines and are skipped.
    Every following non-blank line must hold an image file name followed by
    seven numbers; ``main`` uses the first three as the position and the
    last four as the quaternion.

    Args:
        label_path: Path of the label text file to read.
        root_dir: Directory prefix joined (with '/') to each image name.

    Returns:
        A ``(images, poses)`` pair of equally long lists: image paths and
        7-tuples of floats.

    Raises:
        ValueError: If a line does not hold a name plus seven numbers.
    """
    images = []
    poses = []
    with open(label_path) as label_file:
        # Skip the 3 header lines; tolerate files shorter than the header.
        for _ in range(3):
            next(label_file, None)
        for line in label_file:
            fields = line.split()
            if not fields:
                # A blank (e.g. trailing) line carries no sample.
                continue
            if len(fields) != 8:
                raise ValueError(
                    'Expected "<image> p0 p1 p2 p3 p4 p5 p6", got: %r' % line)
            poses.append(tuple(float(value) for value in fields[1:]))
            images.append(root_dir + '/' + fields[0])
    return images, poses


def main(args):
    """Convert a KingsCollege label file and its images into TFRecords.

    Reads the module-level ``FLAGS`` (``root_dir``, ``dataset``,
    ``out_file``, ``hist_norm``) and ``logger`` that the ``__main__``
    section sets up. Samples are written in random order. Each record holds
    the image size, the raw image bytes, and three float32 byte strings:
    ``pose`` (rotation vector followed by position), ``pose_q``
    (quaternion) and ``pose_x`` (position).

    Args:
        args: Remaining command line arguments; not used.
    """
    logger.info('Processing Image Labels')
    images, poses = _read_pose_labels(
        FLAGS.root_dir + '/' + FLAGS.dataset, FLAGS.root_dir)

    # One shuffle already yields a uniformly random order.
    order = list(range(len(images)))
    random.shuffle(order)

    logger.info('Writing TFRecords')

    SO3_GROUP = SpecialOrthogonalGroup(3)

    # The context manager closes the writer even if an image fails to load.
    with tf.python_io.TFRecordWriter(FLAGS.out_file) as writer:
        for i in tqdm(order):

            pose_q = np.array(poses[i][3:7])
            pose_x = np.array(poses[i][0:3])

            rot_vec = SO3_GROUP.rotation_vector_from_quaternion(pose_q)[0]
            pose = np.concatenate((rot_vec, pose_x), axis=0)

            logger.info('Processing Image: %s', images[i])
            X = imageio.imread(images[i])
            # Keep every 4th row and column to shrink the image.
            X = X[::4, ::4, :]
            if FLAGS.hist_norm:
                # equalize_hist yields a float image; scale it back to uint8
                # so the stored bytes keep the one-byte-per-value layout
                # that consumers decoding 'image' as uint8 expect.
                X = np.rint(exposure.equalize_hist(X) * 255.0).astype(np.uint8)

            # tobytes() replaces the deprecated ndarray.tostring().
            img_raw = X.tobytes()
            pose_raw = pose.astype('float32').tobytes()
            pose_q_raw = pose_q.astype('float32').tobytes()
            pose_x_raw = pose_x.astype('float32').tobytes()

            example = tf.train.Example(features=tf.train.Features(feature={
                'height': _int64_feature(X.shape[0]),
                'width': _int64_feature(X.shape[1]),
                'channel': _int64_feature(X.shape[2]),
                'image': _bytes_feature(img_raw),
                'pose': _bytes_feature(pose_raw),
                'pose_q': _bytes_feature(pose_q_raw),
                'pose_x': _bytes_feature(pose_x_raw)}))

            writer.write(example.SerializeToString())

    # A single message: extra positional arguments are %-format parameters.
    logger.info('Creating Dataset Success.')
|
||
|
||
if __name__ == '__main__':

    print('Generating KingsCollege Dataset.')
    FLAGS, UNPARSED_ARGV = ARGPARSER.parse_known_args()
    print('Dataset:', FLAGS.dataset)

    # --verbose lowers the logging threshold from WARNING to INFO.
    logging.basicConfig(
        level=logging.INFO if FLAGS.verbose else logging.WARNING)
    logger = logging.getLogger(__name__)

    # Hand main() the program name plus whatever the parser did not consume.
    main([sys.argv[0]] + UNPARSED_ARGV)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
''' | ||
Example Pose Estimation Network with SE3 loss function (Inference Script) | ||
Dataset: KingsCollege | ||
Network: Inception v1 | ||
Loss Function: Geomstats SE(3) Loss | ||
''' | ||
|
||
import argparse | ||
import sys | ||
import os | ||
os.environ['GEOMSTATS_BACKEND'] = 'tensorflow' # NOQA | ||
|
||
import geomstats.lie_group as lie_group | ||
import tensorflow as tf | ||
|
||
from geomstats.special_euclidean_group import SpecialEuclideanGroup | ||
from tensorflow.contrib.slim.python.slim.nets import inception | ||
|
||
|
||
# command line argument parser | ||
# Option table for the inference CLI. Each entry pairs an option string with
# the keyword arguments forwarded unchanged to ``add_argument``; the order
# here is the order the options are registered in.
_TEST_CLI_OPTIONS = (
    ('--model_dir', dict(
        type=str, default='./model',
        help='The path to the model directory.')),
    ('--dataset', dict(
        type=str, default='dataset_test.tfrecords',
        help='The path to the TFRecords dataset.')),
    ('--cuda', dict(
        type=str, default='0',
        help='Specify default GPU to use.')),
    ('--debug', dict(
        default=False, action='store_true',
        help="Enables debugging mode.")),
)

ARGPARSER = argparse.ArgumentParser(
    description='Test SE3 PoseNet Inception v1 Model.')
for _option, _settings in _TEST_CLI_OPTIONS:
    ARGPARSER.add_argument(_option, **_settings)
|
||
|
||
class PoseNetReader:
    """Queue-based reader producing (image, pose) tensors from TFRecords."""

    def __init__(self, tfrecord_list):
        """Create the filename queue.

        Args:
            tfrecord_list: TFRecord file names handed to
                ``tf.train.string_input_producer``; the queue is limited to
                one epoch.
        """
        self.file_q = tf.train.string_input_producer(
            tfrecord_list, num_epochs=1)

    def read_and_decode(self):
        """Build the ops that read and decode one serialized example.

        Returns:
            An ``(image, pose)`` pair: the uint8 image reshaped to
            (1, 480, 270, 3) and then cropped or padded to 224 x 224, and
            the float32 pose with its static shape set to 6 values.
        """
        record_reader = tf.TFRecordReader()
        _, serialized = record_reader.read(self.file_q)

        # Both entries are stored as raw byte strings.
        feature_spec = {
            'image': tf.FixedLenFeature([], tf.string),
            'pose': tf.FixedLenFeature([], tf.string)
        }
        parsed = tf.parse_single_example(serialized, features=feature_spec)

        raw_pixels = tf.decode_raw(parsed['image'], tf.uint8)
        pose = tf.decode_raw(parsed['pose'], tf.float32)

        # NOTE(review): the image size is hard-coded; confirm that
        # (1, 480, 270, 3) matches the row/column order the dataset writer
        # stores before relying on the spatial layout.
        batched = tf.reshape(raw_pixels, (1, 480, 270, 3))
        # `(6)` is the plain int 6, not a tuple; presumably TensorFlow reads
        # it as a rank-1 shape of length 6 - TODO confirm.
        pose.set_shape((6))

        # Random augmentations could be inserted at this point, before the
        # image is cut down to the fixed network input size. Resizing with
        # tf.image.resize_images is an alternative to crop-or-pad.
        image = tf.image.resize_image_with_crop_or_pad(
            image=batched, target_height=224, target_width=224)

        return image, pose
|
||
|
||
def main(args):
    """Run the saved SE3 PoseNet over a TFRecords dataset and print results.

    Reads the module-level ``FLAGS`` (``dataset``, ``model_dir``) set in the
    ``__main__`` section. For each record, prints the iteration index, the
    loss, the predicted pose and the true pose, until the single-epoch input
    queue is exhausted or the user interrupts.

    Args:
        args: argv list passed in by ``tf.app.run``; not used.

    Raises:
        ValueError: If no checkpoint can be found in ``FLAGS.model_dir``.
    """
    # Resolve the checkpoint first so a wrong --model_dir fails with a clear
    # message before the graph, the session or any queue thread exists.
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.model_dir)
    if latest_checkpoint is None:
        raise ValueError(
            'No checkpoint found in model directory: %s' % FLAGS.model_dir)

    SE3_GROUP = SpecialEuclideanGroup(3)
    metric = SE3_GROUP.left_canonical_metric

    reader_train = PoseNetReader([FLAGS.dataset])

    # Get Input Tensors
    image, y_true = reader_train.read_and_decode()

    # Construct model and encapsulating all ops into scopes, making
    # Tensorboard's Graph visualization more convenient
    print('Making Model')
    with tf.name_scope('Model'):
        py_x, _ = inception.inception_v1(tf.cast(image, tf.float32),
                                         num_classes=6,
                                         is_training=False)
        # tanh(pred_angle) required to prevent infinite spins on rotation axis
        y_pred = tf.concat((tf.nn.tanh(py_x[:, :3]), py_x[:, 3:]), axis=1)
        loss = tf.reduce_mean(
            lie_group.loss(y_pred, y_true, SE3_GROUP, metric))

    print('Initializing Variables...')
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    saver = tf.train.Saver()

    # Main Testing Routine
    with tf.Session() as sess:
        # Run the initializer
        sess.run(init_op)

        # Load saved weights before any queue thread is started, so a failed
        # restore leaves no running threads behind.
        print('Loading Trained Weights')
        saver.restore(sess, latest_checkpoint)

        # Start Queue Threads
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        i = 0

        # Inference cycle
        try:
            while True:
                _y_pred, _y_true, _loss = sess.run([y_pred, y_true, loss])
                print('Iteration:', i, 'loss:', _loss)
                print('_y_pred:', _y_pred)
                print('_y_true:', _y_true)
                print('\n')
                i = i + 1

        except tf.errors.OutOfRangeError:
            # Raised once the input queue has run out of records.
            print('End of Testing Data')

        except KeyboardInterrupt:
            print('KeyboardInterrupt!')

        finally:
            print('Stopping Threads')
            coord.request_stop()
            coord.join(threads)
|
||
|
||
if __name__ == '__main__':

    print('Testing SE3 PoseNet Inception v1 Model.')
    FLAGS, UNPARSED_ARGV = ARGPARSER.parse_known_args()
    print('FLAGS:', FLAGS)

    # Set verbosity: --debug selects the chattier level for both the
    # TF_CPP_MIN_LOG_LEVEL variable and the Python-side TensorFlow logger.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' if FLAGS.debug else '3'
    tf.logging.set_verbosity(
        tf.logging.INFO if FLAGS.debug else tf.logging.ERROR)

    # GPU allocation options
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.cuda

    # tf.app.run receives the program name plus the arguments the parser
    # did not consume.
    tf.app.run(main=main, argv=[sys.argv[0]] + UNPARSED_ARGV)
Oops, something went wrong.