Commit
minor changes
holyseven committed Apr 22, 2018
1 parent b68945f commit 7400e90
Showing 3 changed files with 17 additions and 10 deletions.
5 changes: 4 additions & 1 deletion run_pspmg/predict.py
@@ -10,6 +10,7 @@
 import cv2
 from database.helper_cityscapes import trainid_to_labelid, coloring
 from database.helper_segmentation import *
+from experiment_manager.utils import sorted_str_dict

 import argparse
 parser = argparse.ArgumentParser()
@@ -134,6 +135,8 @@ def predict(i_ckpt):
     else:
         max_iter = FLAGS.test_max_iter

+    # IMG_MEAN = [123.680000305, 116.778999329, 103.939002991] # RGB mean from official PSPNet
+
     step = 0
     while step < max_iter:
         image, label = cv2.imread(images_filenames[step], 1), cv2.imread(labels_filenames[step], 0)
@@ -245,7 +248,7 @@ def predict(i_ckpt):


 def main(_):
-    print FLAGS.__dict__
+    print(sorted_str_dict(FLAGS.__dict__))

     # ============================================================================
     # ===================== Prediction =========================
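The functional change running through predict.py (and, below, train.py) is that the FLAGS dump now goes through sorted_str_dict from experiment_manager/utils.py instead of str(FLAGS.__dict__). The helper's implementation is not shown in this commit; the following is a minimal sketch of what such a function might look like, assuming it simply renders the dict with its keys in sorted order so that logged FLAGS are deterministic:

def sorted_str_dict(d):
    """Render a dict as a string with its keys in sorted order."""
    parts = ["'%s': %s" % (k, repr(d[k])) for k in sorted(d.keys())]
    return '{' + ', '.join(parts) + '}'

# Example: print(sorted_str_dict({'lr': 0.01, 'batch_size': 4}))
# -> {'batch_size': 4, 'lr': 0.01}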
18 changes: 9 additions & 9 deletions run_pspmg/train.py
@@ -5,7 +5,7 @@
 import os

 import tensorflow as tf
-from experiment_manager.utils import LogDir
+from experiment_manager.utils import LogDir, sorted_str_dict
 from model import pspnet_mg
 import math
 import numpy as np
@@ -211,20 +211,20 @@ def train(resume_step=None):
         # This can transform .npy weights with variables names being the same to the tf ckpt model.
         fine_tune_variables = []
         npy_dict = np.load(FLAGS.fine_tune_filename).item()
-        new_layers_names = model.new_layers_names
-        new_layers_names.append('Momentum')
+        new_layers_names = ['Momentum']
         for v in tf.global_variables():
+            print '=====Saving initial snapshot process:',
             if any(elem in v.name for elem in new_layers_names):
-                print '=====Saving initial snapshot process: not import', v.name
+                print 'not import', v.name
                 continue

             name = v.name.split(':0')[0]
             if name not in npy_dict:
-                print '=====Saving initial snapshot process: not find ', v.name
+                print 'not find', v.name
                 continue

             v.load(npy_dict[name], sess)
-            print '=====Saving initial snapshot process: saving %s' % v.name
+            print 'saving', v.name
             fine_tune_variables.append(v)

         saver = tf.train.Saver(var_list=fine_tune_variables)
@@ -260,7 +260,7 @@ def train(resume_step=None):
     print '=========================== training process begins ================================='
     f_log = open(logdir.exp_dir + '/' + str(datetime.datetime.now()) + '.txt', 'w')
     f_log.write('step,loss,precision,wd\n')
-    f_log.write(str(FLAGS.__dict__) + '\n')
+    f_log.write(sorted_str_dict(FLAGS.__dict__) + '\n')

     average_loss = 0.0
     show_period = 20
@@ -518,7 +518,7 @@ def main(_):
     # ============================================================================
     # ============================= TRAIN ========================================
     # ============================================================================
-    print FLAGS.__dict__
+    print(sorted_str_dict(FLAGS.__dict__))
     if FLAGS.resume_step is not None:
         print 'Ready to resume from step %d.' % FLAGS.resume_step

@@ -529,7 +529,7 @@
         logdir.print_all_info()
         f_log = open(logdir.exp_dir + '/' + str(datetime.datetime.now()) + '.txt', 'w')
         f_log.write('step,loss,precision,wd\n')
-        f_log.write(str(FLAGS.__dict__) + '\n')
+        f_log.write(sorted_str_dict(FLAGS.__dict__) + '\n')
     else:
         f_log, logdir, has_nan = train(FLAGS.resume_step)

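For the initial-snapshot hunk above (the loop around np.load(FLAGS.fine_tune_filename).item()), the .npy file is presumably a pickled Python dict mapping variable names, without the ':0' suffix, to NumPy arrays. A small illustrative sketch of writing and reading such a file, with hypothetical variable names not taken from the repository:

import numpy as np

# Hypothetical weight dict; keys must match tf.Variable names minus ':0'.
weights = {
    'resnet_v1_101/conv1/weights': np.zeros((7, 7, 3, 64), dtype=np.float32),
    'resnet_v1_101/conv1/BatchNorm/gamma': np.ones(64, dtype=np.float32),
}
np.save('pretrained_weights.npy', weights)

# train.py reads it back with np.load(...).item(); recent NumPy versions
# additionally require allow_pickle=True for pickled dicts.
npy_dict = np.load('pretrained_weights.npy', allow_pickle=True).item()
print(sorted(npy_dict.keys()))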
4 changes: 4 additions & 0 deletions z_pretrained_weights/download_resnet_v1_101.sh
@@ -0,0 +1,4 @@
+wget http://download.tensorflow.org/models/resnet_v1_101_2016_08_28.tar.gz
+tar -xvf ./resnet_v1_101_2016_08_28.tar.gz
+rm ./resnet_v1_101_2016_08_28.tar.gz
+
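Usage note (an assumption, not stated in the commit): the archive extracts resnet_v1_101.ckpt into the working directory, and that checkpoint is presumably what FLAGS.fine_tune_filename should point at when it does not end in .npy.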