Skip to content

Commit

Permalink
output image and compute score while train
Browse files Browse the repository at this point in the history
tb_ploteval had been removed.
Every <tb_savempi> it will call evaluation() which also write into tensorboard, and it will output the rendered image at runs/evaluation.
  • Loading branch information
pureexe committed May 7, 2021
1 parent e7637f3 commit eeff38c
Showing 1 changed file with 4 additions and 38 deletions.
42 changes: 4 additions & 38 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@
parser.add_argument('-epochs', type=int, default=4000, help='total epochs to train')
parser.add_argument('-steps', type=int, default=-1, help='total steps to train. In our paper, we proposed to use epoch instead.')
parser.add_argument('-tb_saveimage', type=int, default=50, help='write an output image to tensorboard for every <tb_saveimage> epochs')
parser.add_argument('-tb_savempi', type=int, default=200, help='generate MPI (WebGL) for every <tb_savempi> epochs')
parser.add_argument('-tb_savempi', type=int, default=200, help='generate MPI (WebGL) and measure PSNR/SSIM of validation image for every <tb_savempi> epochs')
parser.add_argument('-checkpoint', type=int, default=100, help='save checkpoint for every <checkpoint> epochs. Be aware that! It will replace the previous checkpoint.')
parser.add_argument('-tb_ploteval',type=int, default=0, help='measure PSNR/SSIM of validation images for every <tb_ploteval> epochs.')
parser.add_argument('-tb_toc',type=int, default=500, help="print output to terminal for every tb_toc epochs")

#lr schedule
Expand Down Expand Up @@ -488,9 +487,6 @@ def generateAlpha(model, dataset, dataloader, writer, runpath, suffix="", datalo
model. --> trained model
dataset. --> valiade dataset
writer. --> tensorboard
Returns:
Webgl
validation score
'''
suffix_str = "/%06d" % suffix if isinstance(suffix, int) else "/"+str(suffix)
# create webgl only when using -predict or finish training
Expand All @@ -507,9 +503,8 @@ def generateAlpha(model, dataset, dataloader, writer, runpath, suffix="", datalo
args.invz,
webpath=args.web_path,
web_width= args.web_width)
return info

if not args.no_eval:
if not args.no_eval and len(dataloader) > 0:
out = evaluation(model,
dataset,
dataloader,
Expand Down Expand Up @@ -693,38 +688,9 @@ def train():
if np.isnan(loss_total.item()):
exit()
checkpoint(ckpt, model, optimizer, epoch+1)
if not args.no_eval and args.tb_ploteval > 0 and epoch % args.tb_ploteval == 0 and len(sampler_val) > 0:
#print("Evaluating on valid set...")
epoch_mse, epoch_psnr, epoch_ssim = 0, 0, 0
model.eval()
with pt.no_grad():
for i, feature in enumerate(dataloader_val):

output_shape = feature['image'].shape[-2:]

#randomly select patch (size: ray^(0.5))
size = math.ceil(math.sqrt(args.ray))
size = min(size, output_shape[0], output_shape[1])

gt, sel = patchtify(feature['image'], size)

output = model(dataset.sfm, feature, output_shape, sel)
mse = pt.mean(((output - gt))**2)
epoch_mse += mse

output_cpu = output.view(1, 3, size, size).permute(0, 2, 3, 1).cpu().detach().numpy()[0]
gt_cpu = gt.view(1, 3, size, size).permute(0, 2, 3, 1).cpu().numpy()[0]

epoch_ssim += structural_similarity(output_cpu, gt_cpu, win_size=11,
multichannel=True, gaussian_weights=True)
epoch_psnr += peak_signal_noise_ratio(output_cpu, gt_cpu, data_range=1.0)
pt.cuda.empty_cache()
writer.add_scalar('loss/mse_valid', epoch_mse / len(sampler_val), epoch)
writer.add_scalar('loss/PSNR_valid', epoch_psnr / len(sampler_val), epoch)
writer.add_scalar('loss/SSIM_valid', epoch_ssim / len(sampler_val), epoch)

print('Finished Training')
mpi = generateAlpha(model, dataset, dataloader_val, None, runpath, dataloader_train = dataloader_train)
generateAlpha(model, dataset, dataloader_val, None, runpath, dataloader_train = dataloader_train)
if not args.no_video:
render_video(model, dataset, args.ray, os.path.join(runpath, 'video_output', args.model_dir))
if args.http:
Expand Down Expand Up @@ -782,4 +748,4 @@ def loadDataset(dpath):

if __name__ == "__main__":
sys.excepthook = colored_hook(os.path.dirname(os.path.realpath(__file__)))
train()
train()

0 comments on commit eeff38c

Please sign in to comment.