-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.lua
68 lines (49 loc) · 2.28 KB
/
train.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
--[[
Train a Fast R-CNN detector network on either the Pascal VOC 2007 or the COCO dataset.
]]
require 'paths'
require 'torch'
local fastrcnn = require 'fastrcnn'
-- Make single-precision float tensors the default tensor type for this script.
torch.setdefaulttensortype 'torch.FloatTensor'
-- Execute the project directory script that sits next to this file.
-- NOTE(review): presumably it configures project paths; confirm in projectdir.lua.
paths.dofile 'projectdir.lua'
--------------------------------------------------------------------------------
-- Load options
--------------------------------------------------------------------------------
-- options.lua yields a table exposing parse(); it is handed the script's
-- command-line arguments and its result is the configuration used below.
print('==> (1/5) Load options')
local opt = paths.dofile('options.lua').parse(arg)
--------------------------------------------------------------------------------
-- Load dataset data loader
--------------------------------------------------------------------------------
-- fastrcnn.train() is given a table of loading functions that fetch the data it
-- needs from a data structure, which keeps the fastrcnn package easy to use
-- with other datasets.
print('==> (2/5) Load dataset data loader')
-- data.lua yields a function; call it for the 'train' split of the chosen dataset.
local make_data_loader = paths.dofile('data.lua')
local data_gen = make_data_loader(opt.dataset, 'train')
--------------------------------------------------------------------------------
-- Load regions-of-interest (RoIs)
--------------------------------------------------------------------------------
print('==> (3/5) Load roi proposals data')
-- rois.lua yields a function; call it for the 'train' split of the chosen dataset.
local fetch_rois = paths.dofile('rois.lua')
local rois = fetch_rois(opt.dataset, 'train')
--------------------------------------------------------------------------------
-- Setup model
--------------------------------------------------------------------------------
-- Either build a fresh network or restore one from a checkpoint file.
--
-- The checkpoint path is resolved once and used both to pick the branch and to
-- load the file. Previously the branch tested opt.loadModel while the load read
-- opt.load, so a path given through opt.loadModel was never the one opened.
-- opt.load is still honoured as a fallback when opt.loadModel is not set.
local model, model_parameters
local checkpoint_file = opt.loadModel or opt.load or ''
if checkpoint_file == '' then
    print('==> (4/5) Setup model:')
    -- 80 classes when the dataset is 'coco', 20 for any other dataset.
    local nclasses = (opt.dataset=='coco' and 80) or 20
    local load_model = paths.dofile('models/init.lua')
    model, model_parameters = load_model(opt.netType, opt.nGPU, nclasses)
else
    print('==> (4/5) Load model from file: ' .. tostring(checkpoint_file))
    -- The checkpoint is expected to be a table with 'model' and 'params' fields.
    local model_data = torch.load(checkpoint_file)
    model, model_parameters = model_data.model, model_data.params
end
--------------------------------------------------------------------------------
-- Train a Fast R-CNN detector
--------------------------------------------------------------------------------
print('==> (5/5) Train Fast-RCNN model')
-- Pass everything prepared above (data loader, RoI proposals, network, its
-- parameters and the parsed options) to the fastrcnn package's training entry.
fastrcnn.train(
    data_gen,
    rois,
    model,
    model_parameters,
    opt
)
print('Script complete.')