add evaluation code
ericyi committed Sep 3, 2018
1 parent 6c15049 commit 0d13e15
Showing 1 changed file with 17 additions and 13 deletions.
train.py: 30 changes (17 additions & 13 deletions)

@@ -10,13 +10,13 @@
 import data_prep
 
 parser = argparse.ArgumentParser()
-parser.add_argument('--gpu', type=int, default=3, help='GPU to use [default: GPU 0]')
+parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]')
 parser.add_argument('--model', default='model', help='Model name [default: model]')
 parser.add_argument('--stage', type=int, default=1, help='Training stage [default: 1]')
 parser.add_argument('--log_dir', default='log', help='Log dir [default: log]')
 parser.add_argument('--max_epoch', type=int, default=100, help='Epoch to run [default: 100]')
-parser.add_argument('--batch_size', type=int, default=12, help='Batch Size during training [default: 16]')
-parser.add_argument('--learning_rate', type=float, default=0.0005, help='Initial learning rate [default: 0.0005]')
+parser.add_argument('--batch_size', type=int, default=32, help='Batch Size during training [default: 16]')
+parser.add_argument('--learning_rate', type=float, default=0.0001, help='Initial learning rate [default: 0.0005]')
 parser.add_argument('--momentum', type=float, default=0.9, help='Initial learning rate [default: 0.9]')
 parser.add_argument('--optimizer', default='adam', help='adam or momentum [default: adam]')
 parser.add_argument('--decay_step', type=int, default=200000, help='Decay step for lr decay [default: 200000]')
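
For reference, the flags above are consumed in the usual argparse way. Hypothetical invocations for the three training stages (flag values are illustrative, not prescribed by the repo; the stage names are inferred from the dataset and loss names below):

python train.py --stage 1 --gpu 0                 # flow prediction stage
python train.py --stage 2 --batch_size 32         # transformation/grouping stage
python train.py --stage 3 --learning_rate 0.0001  # segmentation stage; expects stage-1/2 checkpoints
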
@@ -59,8 +59,12 @@
     TRAIN_DATASET = data_prep.FlowDataset('data/flow_train.mat', npoint=NPOINT)
     VAD_DATASET = data_prep.FlowDataset('data/flow_validation.mat', npoint=NPOINT)
 else:
-    TRAIN_DATASET = data_prep.SegDataset('data/seg_train.mat', npoint=NPOINT)
-    VAD_DATASET = data_prep.SegDataset('data/seg_validation.mat', npoint=NPOINT)
+    if STAGE==2:
+        RELROT = True
+    else:
+        RELROT = False
+    TRAIN_DATASET = data_prep.SegDataset('data/seg_train.mat', npoint=NPOINT, relrot=RELROT)
+    VAD_DATASET = data_prep.SegDataset('data/seg_validation.mat', npoint=NPOINT, relrot=RELROT)
 
 def log_string(out_str):
     LOG_FOUT.write(out_str+'\n')
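
The new branch turns on the SegDataset relrot flag (presumably relative rotation) only for stage 2. As a sketch, the four-line conditional collapses to a single boolean assignment:

# Equivalent to the if/else above: relrot is True exactly when STAGE == 2.
RELROT = (STAGE == 2)
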
@@ -74,7 +78,7 @@ def get_learning_rate(batch):
                        DECAY_STEP,          # Decay step.
                        DECAY_RATE,          # Decay rate.
                        staircase=True)
-    learing_rate = tf.maximum(learning_rate, 0.00001) # CLIP THE LEARNING RATE!
+    learning_rate = tf.maximum(learning_rate, 0.00001) # CLIP THE LEARNING RATE!
     return learning_rate
 
 def get_bn_decay(batch):
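
Only the tail of get_learning_rate is visible in this hunk. The context lines suggest the common PointNet-style tf.train.exponential_decay schedule; a sketch of the whole function under that assumption, where BASE_LEARNING_RATE and BATCH_SIZE are assumed module-level values derived from the flags:

def get_learning_rate(batch):
    # Decay the base rate every DECAY_STEP samples seen, then clip from below.
    learning_rate = tf.train.exponential_decay(
                        BASE_LEARNING_RATE,  # assumed: from --learning_rate
                        batch * BATCH_SIZE,  # assumed: global step in samples
                        DECAY_STEP,          # Decay step.
                        DECAY_RATE,          # Decay rate.
                        staircase=True)
    # The typo fixed above matters: the old code assigned the clipped value to
    # 'learing_rate' but returned the unclipped 'learning_rate', so the 1e-5
    # floor was silently ignored.
    learning_rate = tf.maximum(learning_rate, 0.00001)  # CLIP THE LEARNING RATE!
    return learning_rate
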
@@ -107,8 +111,8 @@ def train():
             pred_trans, pred_grouping_sub, loss1, loss2, loss = MODEL.get_model_loss_stage2(pcpair_pl, vismask_pl, flow_pl, momasks_pl, is_training_pl, bn_decay)
             loss3 = 0*loss
         else:
-            pred_seg_sub, pred_conf, loss = MODEL.get_model_loss_stage3(pcpair_pl, vismask_pl, flow_pl, momasks_pl, is_training_pl, bn_decay)
-            loss1 = loss2 = loss3 = 0*loss
+            pred_seg_sub, pred_conf, loss, loss1, loss2 = MODEL.get_model_loss_stage3(pcpair_pl, vismask_pl, flow_pl, momasks_pl, is_training_pl, bn_decay)
+            loss3 = 0*loss
 
         tf.summary.scalar('loss', loss)
         tf.summary.scalar('loss1', loss1)
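
The stages return different numbers of loss terms, yet the summary ops reference all of loss, loss1, loss2, loss3. Multiplying the main loss by zero is a cheap way to manufacture the missing scalars: it yields a graph-connected zero tensor of the right dtype, so every tf.summary.scalar stays valid in every stage. A minimal sketch of the pattern:

# Sketch: stage 3 now defines loss, loss1, loss2 itself; only loss3 is missing.
loss3 = 0 * loss                   # zero tensor, same dtype/shape as loss
tf.summary.scalar('loss3', loss3)  # summary op stays valid in all stages
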
@@ -149,8 +153,8 @@ def train():
         init = tf.global_variables_initializer()
         sess.run(init)
         if STAGE>2:
-            saver_corrsflow.restore(sess, FLAGS.log_dir+'1/best_model_epoch.ckpt')
-            saver_transgrouping.restore(sess, FLAGS.log_dir+'2/best_model_epoch.ckpt')
+            saver_corrsflow.restore(sess, FLAGS.log_dir+'1/best_model.ckpt')
+            saver_transgrouping.restore(sess, FLAGS.log_dir+'2/best_model.ckpt')
 
         ops = {'pcpair_pl': pcpair_pl,
                'flow_pl': flow_pl,
@@ -175,7 +179,7 @@
         if loss < best_loss:
             best_loss = loss
-            save_path = saver.save(sess, os.path.join(LOG_DIR, "best_model_epoch_%03d.ckpt"%(epoch)))
+            save_path = saver.save(sess, os.path.join(LOG_DIR, "best_model.ckpt"))
             log_string("Model saved in file: %s" % save_path)
 
         # Save the variables to disk.
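
The renaming in the two hunks above is one coherent change: dropping the per-epoch suffix gives every stage a fixed, predictable best-checkpoint filename, which is exactly what the stage-3 restore paths rely on (a restore call cannot know in advance which epoch number was best). A sketch of the resulting contract, with LOG_DIR assumed to be the stage-specific log directory:

# Each stage overwrites one canonical best checkpoint ...
save_path = saver.save(sess, os.path.join(LOG_DIR, "best_model.ckpt"))
# ... which later stages can locate without knowing the best epoch:
saver_corrsflow.restore(sess, FLAGS.log_dir + '1/best_model.ckpt')      # stage-1 weights
saver_transgrouping.restore(sess, FLAGS.log_dir + '2/best_model.ckpt')  # stage-2 weights
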
@@ -239,7 +243,7 @@ def train_one_epoch(sess, ops, train_writer):
         elif STAGE==2:
             log_string('mean loss: %f, mean loss_trans: %f, mean loss_grouping: %f' % (loss_sum / 10, loss1_sum / 10, loss2_sum / 10))
         else:
-            log_string('mean loss: %f' % (loss_sum / 10))
+            log_string('mean loss: %f, mean negative iou: %f, mean loss_seg: %f' % (loss_sum / 10, loss2_sum / 10, loss1_sum / 10))
         loss_sum = 0
         loss1_sum = 0
         loss2_sum = 0
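
Stage 3 now reports a "negative iou" metric (loss2) next to the segmentation loss (loss1). Its exact definition lives inside MODEL.get_model_loss_stage3, which this diff does not show; the sketch below is a generic soft-IoU loss of the kind commonly used for segmentation masks, given purely as an illustrative assumption:

def negative_iou(pred, gt, eps=1e-6):
    # pred, gt: [batch, npoint, nmask] soft segmentation masks in [0, 1].
    inter = tf.reduce_sum(pred * gt, axis=1)              # per-mask intersection
    union = tf.reduce_sum(pred + gt - pred * gt, axis=1)  # per-mask soft union
    return -tf.reduce_mean(inter / (union + eps))         # negated: lower is better
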
@@ -288,7 +292,7 @@ def eval_one_epoch(sess, ops, test_writer):
         log_string('eval mean loss: %f, eval mean loss_trans: %f, eval mean loss_grouping: %f' % (loss_sum / float(num_batches), loss1_sum / float(num_batches), loss2_sum / float(num_batches)))
         return loss_sum / float(num_batches)
     else:
-        log_string('eval mean loss: %f' % (loss_sum / float(num_batches)))
+        log_string('eval mean loss: %f, eval mean negative iou: %f, eval mean loss_seg: %f' % (loss_sum / float(num_batches), loss2_sum / float(num_batches), loss1_sum / float(num_batches)))
         return loss_sum / float(num_batches)

