collect losses in train. updated README
charlesq34 committed Feb 23, 2018
1 parent 8a8883e commit cce809f
Showing 4 changed files with 25 additions and 24 deletions.
21 changes: 5 additions & 16 deletions README.md
@@ -28,21 +28,6 @@ The TF operators are included under `tf_ops`, you need to compile them (check `t

There is also a handy point cloud visualization tool under `utils`. Run `sh compile_render_balls_so.sh` to compile it, then try the demo with `python show3d_balls.py`. The original code is from <a href="https://github.com/fanhqme/PointSetGeneration">here</a>.

-### Point Cloud Data
-
-Note: You can skip this step if you simply want to try the basic classification example (based on XYZ coordinates of points).
-
-To use normal features for classification: You can get our sampled point clouds of ModelNet40 (XYZ and normal from mesh, 10k points per shape) at this <a href="https://1drv.ms/u/s!ApbTjxa06z9CgQfKl99yUDHL_wHs">OneDrive link</a>.
-
-For object part segmentation: the processed ShapeNetPart dataset (XYZ, normal and part labels) can be found <a href="https://1drv.ms/u/s!ApbTjxa06z9CgQnl-Qm6KI3Ywbe1">here</a>.
-
-After successful downloads, uncompress the zip files into the data folder:
-
-    data/modelnet40_normal_resampled
-    data/shapenetcore_partanno_segmentation_benchmark_v0_normal
-
-so that the training and testing scripts can locate them.

### Usage

#### Shape Classification
@@ -61,14 +46,18 @@ If you have multiple GPUs on your machine, you can also run the multi-gpu versio

<i>Side Note:</i> For the XYZ+normal experiment reported in our paper: (1) 5000 points are used, (2) a further random data dropout augmentation is used during training (see the commented line after `augment_batch_data` in `train.py`), and (3) the model architecture is updated such that `nsample=128` in the first two set abstraction levels, which suits the larger point density of 5000-point samplings.

+To use normal features for classification: You can get our sampled point clouds of ModelNet40 (XYZ and normal from mesh, 10k points per shape) at this <a href="https://1drv.ms/u/s!ApbTjxa06z9CgQfKl99yUDHL_wHs">OneDrive link</a>. Move the uncompressed data folder to `data/modelnet40_normal_resampled`.

#### Object Part Segmentation

To train a model to segment object parts for ShapeNet models:

    cd part_seg
    python train.py

-#### Scene Parsing
+The processed ShapeNetPart dataset (XYZ, normal and part labels) can be found <a href="https://1drv.ms/u/s!ApbTjxa06z9CgQnl-Qm6KI3Ywbe1">here</a>. Move the uncompressed data folder to `data/shapenetcore_partanno_segmentation_benchmark_v0_normal`.
+
+#### Semantic Scene Parsing

See README files and `scannet/train.py` for details.
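As a quick sanity check on the downloaded data, here is a minimal loading sketch. It assumes each `.txt` file under `modelnet40_normal_resampled` stores one point per line as comma-separated x,y,z,nx,ny,nz values; the example path is illustrative.

```python
import numpy as np

def load_shape(path, num_points=1024, use_normal=False):
    # each row: x, y, z, nx, ny, nz (assumed layout of the resampled dump)
    pts = np.loadtxt(path, delimiter=',').astype(np.float32)  # (N, 6)
    pts = pts[:num_points, :]  # keep a fixed-size prefix of points
    return pts if use_normal else pts[:, :3]

# illustrative path; adjust to where you uncompressed the data
# cloud = load_shape('data/modelnet40_normal_resampled/airplane/airplane_0001.txt')
```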

10 changes: 8 additions & 2 deletions evaluate.py
@@ -1,3 +1,7 @@
+'''
+Evaluate classification performance with optional voting.
+Uses the H5 dataset by default. If normals are used, switches to the dataset with surface normals.
+'''
import tensorflow as tf
import numpy as np
import argparse
@@ -68,7 +72,9 @@ def evaluate(num_votes):

# simple model
pred, end_points = MODEL.get_model(pointclouds_pl, is_training_pl)
-loss = MODEL.get_loss(pred, labels_pl, end_points)
+MODEL.get_loss(pred, labels_pl, end_points)
+losses = tf.get_collection('losses', scope)
+total_loss = tf.add_n(losses, name='total_loss')

# Add ops to save and restore all the variables.
saver = tf.train.Saver()
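Note on this hunk: rebuilding the loss from the `losses` collection only works if `MODEL.get_loss` registers its loss op there rather than only returning it. A minimal sketch of that contract, assuming a plain softmax classification loss (not the repo's exact code):

```python
import tensorflow as tf

def get_loss(pred, label, end_points):
    # register the loss in the 'losses' collection so callers can gather it
    # together with any weight-decay terms added elsewhere in the graph
    classify_loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(logits=pred, labels=label))
    tf.add_to_collection('losses', classify_loss)
    return classify_loss

# caller side, as in the diff above:
# losses = tf.get_collection('losses', scope)
# total_loss = tf.add_n(losses, name='total_loss')
```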
@@ -88,7 +94,7 @@ def evaluate(num_votes):
'labels_pl': labels_pl,
'is_training_pl': is_training_pl,
'pred': pred,
-'loss': loss}
+'loss': total_loss}

eval_one_epoch(sess, ops, num_votes)
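The "optional voting" in the docstring means averaging predictions over several perturbed copies of each cloud. Below is a hedged sketch of one common variant, rotation voting about the up axis; the rotation helper and the `pointclouds_pl` key are assumptions, not the repo's exact API.

```python
import numpy as np

def rotate_about_up_axis(batch, angle):
    # rotate xyz about the z axis; if normals occupy columns 3:6,
    # they would need the same rotation applied as well
    c, s = np.cos(angle), np.sin(angle)
    rot = np.array([[c, -s, 0.], [s, c, 0.], [0., 0., 1.]], dtype=np.float32)
    out = batch.copy()
    out[..., :3] = batch[..., :3] @ rot.T
    return out

def vote_predict(sess, ops, batch_data, num_votes=12):
    logits_sum = 0.0
    for v in range(num_votes):
        rotated = rotate_about_up_axis(batch_data, 2.0 * np.pi * v / num_votes)
        logits_sum += sess.run(ops['pred'], feed_dict={
            ops['pointclouds_pl']: rotated,   # assumed placeholder key
            ops['is_training_pl']: False})
    return np.argmax(logits_sum, axis=1)
```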

15 changes: 10 additions & 5 deletions train.py
@@ -1,5 +1,6 @@
'''
Single-GPU training.
+Uses the H5 dataset by default. If normals are used, switches to the dataset with surface normals.
'''
import argparse
import math
@@ -118,8 +119,12 @@ def train():

# Get model and loss
pred, end_points = MODEL.get_model(pointclouds_pl, is_training_pl, bn_decay=bn_decay)
-loss = MODEL.get_loss(pred, labels_pl, end_points)
-tf.summary.scalar('loss', loss)
+MODEL.get_loss(pred, labels_pl, end_points)
+losses = tf.get_collection('losses', scope)
+total_loss = tf.add_n(losses, name='total_loss')
+tf.summary.scalar('total_loss', total_loss)
+for l in losses + [total_loss]:
+    tf.summary.scalar(l.op.name, l)

correct = tf.equal(tf.argmax(pred, 1), tf.to_int64(labels_pl))
accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(BATCH_SIZE)
@@ -133,7 +138,7 @@
optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=MOMENTUM)
elif OPTIMIZER == 'adam':
optimizer = tf.train.AdamOptimizer(learning_rate)
-train_op = optimizer.minimize(loss, global_step=batch)
+train_op = optimizer.minimize(total_loss, global_step=batch)

# Add ops to save and restore all the variables.
saver = tf.train.Saver()
@@ -158,7 +163,7 @@
'labels_pl': labels_pl,
'is_training_pl': is_training_pl,
'pred': pred,
-'loss': loss,
+'loss': total_loss,
'train_op': train_op,
'merged': merged,
'step': batch,
@@ -211,7 +216,7 @@ def train_one_epoch(sess, ops, train_writer):
total_seen += bsize
loss_sum += loss_val
if (batch_idx+1)%50 == 0:
-log_string(' ---- %03d ----' % (batch_idx+1))
+log_string(' ---- batch: %03d ----' % (batch_idx+1))
log_string('mean loss: %f' % (loss_sum / 50))
log_string('accuracy: %f' % (total_correct / float(total_seen)))
total_correct = 0
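Why `tf.get_collection('losses', scope)` passes a scope: the second argument filters collection entries by name prefix, so each tower in a multi-tower graph can gather only its own loss terms. A small self-contained illustration with toy constant losses:

```python
import tensorflow as tf

with tf.Graph().as_default():
    for i in range(2):
        with tf.name_scope('tower_%d' % i) as scope:
            tf.add_to_collection('losses',
                                 tf.constant(1.0 + i, name='classify_loss'))
            tower_losses = tf.get_collection('losses', scope)  # this tower only
            total_loss = tf.add_n(tower_losses, name='total_loss')
            assert len(tower_losses) == 1
    assert len(tf.get_collection('losses')) == 2  # unscoped: both towers
```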
3 changes: 2 additions & 1 deletion train_multi_gpu.py
@@ -1,6 +1,7 @@
'''
Multi-GPU training.
Nearly linear speedup when scaling to multiple GPUs on a single machine.
+Uses the H5 dataset by default. If normals are used, switches to the dataset with surface normals.
'''

import argparse
@@ -289,7 +290,7 @@ def train_one_epoch(sess, ops, train_writer):
total_seen += bsize
loss_sum += loss_val
if (batch_idx+1)%50 == 0:
-log_string(' ---- %03d ----' % (batch_idx+1))
+log_string(' ---- batch: %03d ----' % (batch_idx+1))
log_string('mean loss: %f' % (loss_sum / 50))
log_string('accuracy: %f' % (total_correct / float(total_seen)))
total_correct = 0
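For context on the multi-GPU script: a hedged sketch of the tower pattern it follows (modeled on the TensorFlow CIFAR-10 multi-GPU tutorial, with a toy variable standing in for the PointNet++ model). Each tower collects its scoped losses and computes gradients locally; only the averaged gradients are applied, which is what keeps the scaling nearly linear.

```python
import tensorflow as tf

NUM_GPUS = 2
opt = tf.train.AdamOptimizer(1e-3)
tower_grads = []
with tf.variable_scope(tf.get_variable_scope()):
    for i in range(NUM_GPUS):
        with tf.device('/gpu:%d' % i), tf.name_scope('tower_%d' % i) as scope:
            # toy stand-in for MODEL.get_model / MODEL.get_loss on this shard
            w = tf.get_variable('w', [], initializer=tf.ones_initializer())
            tf.add_to_collection('losses', tf.square(w, name='classify_loss'))
            total_loss = tf.add_n(tf.get_collection('losses', scope),
                                  name='total_loss')
            tower_grads.append(opt.compute_gradients(total_loss))
            tf.get_variable_scope().reuse_variables()  # share weights across towers
# average each variable's gradient across towers, then apply once
avg_grads = [(tf.reduce_mean(tf.stack([g for g, _ in gv]), 0), gv[0][1])
             for gv in zip(*tower_grads)]
train_op = opt.apply_gradients(avg_grads)
```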
