Skip to content

Commit

Permalink
In the process of adding loss plotting to train.lua
Browse files Browse the repository at this point in the history
  • Loading branch information
brannondorsey committed Nov 29, 2016
1 parent 1f3e792 commit 3e654b9
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion train.lua 100755 → 100644
Expand Up @@ -314,6 +314,16 @@ file = torch.DiskFile(paths.concat(opt.checkpoints_dir, opt.name, 'opt.txt'), 'w
file:writeObject(opt)
file:close()

-- display plot config
local plot_config = {
title = "Loss over time",
labels = {"epoch", "errG", "errD", "errL1"},
ylabel = "loss",
}
-- display plot vars
local plot_data = {}
local plot_win

local counter = 0
for epoch = 1, opt.niter do
epoch_tm:reset()
Expand All @@ -328,7 +338,7 @@ for epoch = 1, opt.niter do

-- (2) Update G network: maximize log(D(x,G(x))) + L1(y,G(x))
optim.adam(fGx, parametersG, optimStateG)

-- display
counter = counter + 1
if counter % opt.display_freq == 0 and opt.display then
Expand Down Expand Up @@ -385,6 +395,10 @@ for epoch = 1, opt.niter do
math.floor(math.min(data:size(), opt.ntrain) / opt.batchSize),
tm:time().real / opt.batchSize, data_tm:time().real / opt.batchSize,
errG and errG or -1, errD and errD or -1, errL1 and errL1 or -1))
-- update display plot
table.insert(plot_data, {epoch, errG, errD, errL1})
plot_config.win = plot_win
plot_win = disp.plot(plot_data, plot_config)
end

-- save latest model
Expand Down

0 comments on commit 3e654b9

Please sign in to comment.