Skip to content

Commit

Permalink
Merge pull request #156 from melgor/mGPU
Browse files Browse the repository at this point in the history
Added Multi-GPU support
  • Loading branch information
Brandon Amos committed Jul 4, 2016
2 parents b0f9610 + 0995e86 commit e030689
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 66 deletions.
3 changes: 3 additions & 0 deletions training/OpenFaceOptim.lua
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ function OpenFaceOptim:__init(model, optState, checkpoint_data)
assert(pl.tablex.compare_no_order(modules, pl.tablex.keys(state)))
self.modulesToOptState = state
end
return self
end

local function get_device_for_module(mod)
Expand Down Expand Up @@ -162,3 +163,5 @@ function OpenFaceOptim:optimizeTriplet(optimMethod, inputs, output,

return err, output
end

return OpenFaceOptim
4 changes: 3 additions & 1 deletion training/main.lua
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ torch.manualSeed(opt.manualSeed)

paths.dofile('data.lua')
paths.dofile('util.lua')
paths.dofile('model.lua')
model = nil
criterion = nil
paths.dofile('train.lua')
paths.dofile('test.lua')

Expand All @@ -44,5 +45,6 @@ for _=1,opt.nEpochs do
if opt.testing then
test()
end
model = saveModel(model)
epoch = epoch + 1
end
71 changes: 44 additions & 27 deletions training/model.lua
Original file line number Diff line number Diff line change
Expand Up @@ -8,43 +8,60 @@ if opt.cuda then
require 'cunn'
if opt.cudnn then
require 'cudnn'
cudnn.benchmark = false
cudnn.benchmark = opt.cudnn_bench
cudnn.fastest = true
cudnn.verbose = false
end
end

paths.dofile('torch-TripletEmbedding/TripletEmbedding.lua')

if opt.retrain ~= 'none' then
assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain)
print('Loading model from file: ' .. opt.retrain);
model = torch.load(opt.retrain)
print("Using imgDim = ", opt.imgDim)
else
paths.dofile(opt.modelDef)
assert(imgDim, "Model definition must set global variable 'imgDim'")
assert(imgDim == opt.imgDim, "Model definiton's imgDim must match imgDim option.")
model = createModel()
end

criterion = nn.TripletEmbeddingCriterion(opt.alpha)
local M = {}

if opt.cuda then
model = model:cuda()
if opt.cudnn then
cudnn.convert(model,cudnn)
end
criterion:cuda()
function M.modelSetup(continue)
if continue then
model = continue
elseif opt.retrain ~= 'none' then
assert(paths.filep(opt.retrain), 'File not found: ' .. opt.retrain)
print('Loading model from file: ' .. opt.retrain);
model = torch.load(opt.retrain)
print("Using imgDim = ", opt.imgDim)
else
paths.dofile(opt.modelDef)
assert(imgDim, "Model definition must set global variable 'imgDim'")
assert(imgDim == opt.imgDim, "Model definiton's imgDim must match imgDim option.")
model = createModel()
end

-- First remove any DataParallelTable
if torch.type(model) == 'nn.DataParallelTable' then
model = model:get(1)
end

criterion = nn.TripletEmbeddingCriterion(opt.alpha)

if opt.cuda then
model = model:cuda()
if opt.cudnn then
cudnn.convert(model,cudnn)
end
criterion:cuda()
else
model:float()
criterion:float()
end

optimizeNet(model, opt.imgDim)

if opt.cuda and opt.nGPU > 1 then
model = makeDataParallel(model, opt.nGPU)
end

collectgarbage()
return model, criterion
end

optimizeNet(model, opt.imgDim)

print('=> Model')
print(model)
print(('Number of Parameters: %d'):format(model:getParameters():size(1)))
return M

print('=> Criterion')
print(criterion)

collectgarbage()
2 changes: 2 additions & 0 deletions training/opts.lua
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ function M.parse(arg)
cmd:option('-manualSeed', 2, 'Manually set RNG seed')
cmd:option('-cuda', true, 'Use cuda.')
cmd:option('-device', 1, 'Cuda device to use.')
cmd:option('-nGPU', 1, 'Number of GPUs to use by default')
cmd:option('-cudnn', true, 'Convert the model to cudnn.')
cmd:option('-cudnn_bench', false, 'Run cudnn to choose fastest option. Increase memory usage')

------------- Data options ------------------------
cmd:option('-nDonkeys', 2, 'number of donkeys to initialize (data loading threads)')
Expand Down
88 changes: 50 additions & 38 deletions training/train.lua
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
require 'optim'
require 'image'
require 'torchx' --for concetration the table of tensors
local optnet_loaded, optnet = pcall(require,'optnet')
local models = require 'model'
local openFaceOptim = require 'OpenFaceOptim'

paths.dofile("OpenFaceOptim.lua")

local optimMethod = optim.adam
local optimState = {} -- Use for other algorithms like SGD
local optimator = OpenFaceOptim(model, optimState)
local optimator = nil

trainLogger = optim.Logger(paths.concat(opt.save, 'train.log'))

Expand All @@ -35,17 +37,13 @@ function train()
print('==> doing epoch on training data:')
print("==> online epoch # " .. epoch)
batchNumber = 0
model,criterion = models.modelSetup(model)
optimator = openFaceOptim:__init(model, optimState)
if opt.cuda then
cutorch.synchronize()
cutorch.synchronize()
end

model:training()
if opt.cuda then
model:cuda()
if opt.cudnn then
cudnn.convert(model,cudnn)
end
end

local tm = torch.Timer()
triplet_loss = 0

Expand All @@ -72,7 +70,7 @@ function train()

donkeys:synchronize()
if opt.cuda then
cutorch.synchronize()
cutorch.synchronize()
end

triplet_loss = triplet_loss / batchNumber
Expand All @@ -86,7 +84,10 @@ function train()
print('\n')

collectgarbage()
end -- of train()


function saveModel(model)
-- Check for nans from https://github.com/cmusatyalab/openface/issues/127
local function checkNans(x, tag)
local I = torch.ne(x,x)
Expand All @@ -103,29 +104,41 @@ function train()
checkNans(mod.running_var, string.format("%d-%s-%s", j, mod, 'running_var'))
end
end

if opt.cudnn then
cudnn.convert(model, nn)
if opt.cuda then
if opt.cudnn then
cudnn.convert(model, nn)
end
end
model = model:float():clearState()

torch.save(paths.concat(opt.save, 'model_' .. epoch .. '.t7'), model)

local dpt
if torch.type(model) == 'nn.DataParallelTable' then
dpt = model
model = model:get(1)
end


if optnet_loaded then
optnet.removeOptimization(model)
end

torch.save(paths.concat(opt.save, 'model_' .. epoch .. '.t7'), model:float():clearState())
torch.save(paths.concat(opt.save, 'optimState_' .. epoch .. '.t7'), optimState)

if opt.cuda then
model = model:cuda()
if opt.cudnn then
cudnn.convert(model, cudnn)
end

if dpt then -- OOM without this
dpt:clearState()
end

collectgarbage()
end -- of train()

return model
end

local inputsCPU = torch.FloatTensor()
local numPerClass = torch.FloatTensor()

local timer = torch.Timer()
function trainBatch(inputsThread, numPerClassThread)
collectgarbage()
if batchNumber >= opt.epochSize then
return
end
Expand All @@ -134,6 +147,7 @@ function trainBatch(inputsThread, numPerClassThread)
cutorch.synchronize()
end
timer:reset()

receiveTensor(inputsThread, inputsCPU)
receiveTensor(numPerClassThread, numPerClass)

Expand Down Expand Up @@ -212,33 +226,31 @@ function trainBatch(inputsThread, numPerClassThread)
local as = torch.concat(as_table):view(table.getn(as_table), opt.embSize)
local ps = torch.concat(ps_table):view(table.getn(ps_table), opt.embSize)
local ns = torch.concat(ns_table):view(table.getn(ns_table), opt.embSize)

local apn
if opt.cuda then
local asCuda = torch.CudaTensor()
local psCuda = torch.CudaTensor()
local nsCuda = torch.CudaTensor()
local asCuda = torch.CudaTensor()
local psCuda = torch.CudaTensor()
local nsCuda = torch.CudaTensor()

local sz = as:size()
asCuda:resize(sz):copy(as)
psCuda:resize(sz):copy(ps)
nsCuda:resize(sz):copy(ns)
local sz = as:size()
asCuda:resize(sz):copy(as)
psCuda:resize(sz):copy(ps)
nsCuda:resize(sz):copy(ns)

apn = {asCuda, psCuda, nsCuda}
apn = {asCuda, psCuda, nsCuda}
else
apn = {as, ps, ns}
apn = {as, ps, ns}
end

local err, _ = optimator:optimizeTriplet(
optimMethod, inputs, apn, criterion,
triplet_idx -- , num_example_per_idx
)

-- DataParallelTable's syncParameters
model:apply(function(m) if m.syncParameters then m:syncParameters() end end)
if opt.cuda then
cutorch.synchronize()
cutorch.synchronize()
end

batchNumber = batchNumber + 1
print(('Epoch: [%d][%d/%d]\tTime %.3f\ttripErr %.2e'):format(
epoch, batchNumber, opt.epochSize, timer:time().real, err))
Expand Down
20 changes: 20 additions & 0 deletions training/util.lua
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,23 @@ function optimizeNet( model, inputSize )
print("Repo: https://github.com/fmassa/optimize-net")
end
end

function makeDataParallel(model, nGPU)
-- Wrap the model with DataParallelTable, if using more than one GPU
if nGPU > 1 then
local gpus = torch.range(1, nGPU):totable()
local fastest, benchmark = cudnn.fastest, cudnn.benchmark

local dpt = nn.DataParallelTable(1, true, true)
:add(model, gpus)
:threads(function()
require ("dpnn")
local cudnn = require 'cudnn'
cudnn.fastest, cudnn.benchmark = fastest, benchmark
end)
dpt.gradInput = nil

model = dpt:cuda()
end
return model
end

0 comments on commit e030689

Please sign in to comment.