Skip to content

Commit

Permalink
Fixed progress bar display issue and no error printing issue
Browse files Browse the repository at this point in the history
  • Loading branch information
rahuldey91 committed May 11, 2018
1 parent 4769a8a commit 138211e
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 14 deletions.
19 changes: 8 additions & 11 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import config
import os
import sys
import traceback
import time
import datetime
import tensorflow as tf
Expand All @@ -17,7 +18,7 @@
from checkpoints import Checkpoints

def main():
# parse the arguments
# Parse the Arguments
args = config.parse_args()
random.seed(args.manual_seed)
tf.set_random_seed(args.manual_seed)
Expand All @@ -27,36 +28,32 @@ def main():
if args.save_results:
utils.saveargs(args)

# initialize the checkpoint class
# Initialize the Checkpoints Class
checkpoints = Checkpoints(args)

# Create Model
models = Model(args)
model, criterion, evaluation = models.setup(checkpoints)

# initialize a sample input to build the model for the first time and print its summary
batch = tf.zeros((1, args.nchannels, args.resolution_high, args.resolution_wide))
model(batch)
# Print Model Summary
print('Model summary: {}'.format(model.name))
print(model.summary())

# Data Loading
dataloader_obj = Dataloader(args)
dataloader = dataloader_obj.create()

# initialize trainer and tester
# Initialize Trainer and Tester
trainer = Trainer(args, model, criterion, evaluation)
tester = Tester(args, model, criterion, evaluation)

# start training !!!
# Start Training !!!
loss_best = 1e10
for epoch in range(args.nepochs):
print('\nEpoch %d/%d\n' % (epoch + 1, args.nepochs))
print('\nEpoch %d/%d' % (epoch + 1, args.nepochs))

# train for a single epoch
print("Training...")
# Train and Test for a Single Epoch
loss_train = trainer.train(epoch, dataloader["train"])
print("Testing...")
loss_test = tester.test(epoch, dataloader["test"])

if loss_best > loss_test:
Expand Down
5 changes: 5 additions & 0 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import models
import losses
import evaluate
import tensorflow as tf

class Model:
def __init__(self, args):
Expand All @@ -26,4 +27,8 @@ def setup(self, checkpoints):
else:
model = checkpoints.load(model, ckpt)

# initialize a sample input to build the model for the first time
batch = tf.zeros((1, self.args.nchannels, self.args.resolution_high, self.args.resolution_wide))
model(batch)

return model, criterion, evaluation
2 changes: 1 addition & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def test(self, epoch, dataloader):
if self.log_type == 'progressbar':
# Progress bar
processed_data_len = 0
bar = plugins.Bar('{:<10}'.format('Train'), max=data_len)
bar = plugins.Bar('{:<10}'.format('Test'), max=data_len//self.batch_size)
end = time.time()

with self.summary_writer.as_default():
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def train(self, epoch, dataloader):
if self.log_type == 'progressbar':
# Progress bar
processed_data_len = 0
bar = plugins.Bar('{:<10}'.format('Train'), max=data_len)
bar = plugins.Bar('{:<10}'.format('Train'), max=data_len//self.batch_size)
end = time.time()

with self.summary_writer.as_default():
Expand Down
2 changes: 1 addition & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def cleanup():
except OSError as ex:
raise Exception("wasn't able to kill the child process (pid:{}).".format(child.pid))
# # os.waitpid(child.pid, os.P_ALL)
print('\x1b[?25h', end='', flush=True) # show cursor
print('\n\n\x1b[?25h', end='', flush=True) # show cursor
sys.exit(0)


Expand Down

0 comments on commit 138211e

Please sign in to comment.