## set arguments

In [1]:
cmd = torch.CmdLine()
cmd:option('-iteration', 30,'how many iteration')
cmd:option('-gradclip',5,'magnitude of clip on the RNN gradient')
cmd:option('-modelname','GEINet','wuzifeng model name you want to load')
cmd:option('-dropout',0.5,'fraction of dropout to use between layers')
cmd:option('-learningrate',1e-3)
cmd:option('-datapath', '/home/chenqiang/data/gait-rnn', 'base data path')
cmd:option('-geipath', '/home/chenqiang/data/OULP_C1V1_Pack_GEI/')
cmd:option('-calprecision', 200, 'calculate loss on validation every batch')
cmd:option('-calval', 2, 'calculate loss on validation every batch')
cmd:option('-batchsize', 64, 'how many intance in a traning batch')
cmd:option('-loadmodel', '', 'load fullmodel, rnn model, cnn model')
cmd:option('-gpu', false, 'use GPU')
cmd:option('-gpudevice', 1, 'set gpu device')
arg = arg or ""
opt = cmd:parse(arg)
opt.gpu = true
print(opt)

{
  gpu : true
  gpudevice : 1
  modelname : GEINet
  calval : 2
  loadmodel : 
  gradclip : 5
  datapath : /home/chenqiang/data/gait-rnn
  learningrate : 0.001
  geipath : /home/chenqiang/data/OULP_C1V1_Pack_GEI/
  dropout : 0.5
  batchsize : 64
  calprecision : 200
  iteration : 30
}


## set environment variable firstly

In [2]:
-- set the GPU
if opt.gpu then
    require 'cunn'
    require 'cutorch'
    cutorch.setDevice(opt.gpudevice)
end

## load models

In [3]:
require 'model.getModel';
model, crit = GEINet(opt.gpu, opt.dropout)

## load dataset

In [7]:
local prepDataset = require 'prepareDataset'
dataset = prepDataset.prepareDatasetOULP(opt.datapath, opt.modelname, opt.geipath)
for i, item in ipairs({'train', 'val', 'test'}) do
    local item_count = #dataset[item].index2hid
    local uniq_count = #dataset[item].item_data
    info('%s data instances %05d, uniq  %04d', item, item_count, uniq_count)
end

2017-12-25 22:30:30[INFO] load data from /home/chenqiang/data/gait-rnn/oulp_train_data.txt, /home/chenqiang/data/gait-rnn/oulp_val_data.txt, /home/chenqiang/data/gait-rnn/oulp_test_data.txt	
2017-12-25 22:30:30[INFO] train data instances 00856, uniq  6848	
2017-12-25 22:30:30[INFO] val data instances 00100, uniq  0800	
2017-12-25 22:30:30[INFO] test data instances 00956, uniq  7648	


## start to train

In [5]:
Plot = require 'itorch.Plot';
counter_tb = {}
loss_tb = {}

In [6]:
local parameters, gradParameters = model:getParameters()
info('Number of parameters:%d', parameters:size(1))
local max_val_precision = 0

local timer = torch.Timer()
local i = 0
while i < opt.iteration do
    i = i + 1    
    
    -- calculate loss on validation dataset, and test dataset for testing purpose
--     if i % opt.calval == 0 then
--         local val_in, val_tar = dataset['val']:next_gei_batch(opt.batchsize)
--         val_in = convertToCuda(val_in)
--         local val_loss = cal_loss(model, crit, val_in, val_tar)
--         info('%05dth/%05d Val Error %0.6f', i, opt.iteration, val_loss)
        
--         local tes_in, tes_tar = dataset['test']:next_gei_batch(opt.batchsize)
--         tes_in = convertToCuda(tes_in)
--         local loss = cal_loss(model, crit, tes_in, tes_tar)            
--         info('%05dth/%05d Tes Error %0.6f', i, opt.iteration, loss)
--     end
    

    local inputs, targets = dataset['train']:next_gei_batch(opt.batchsize)
    gradParametersAdd = gradParameters:clone()
    gradParametersAdd:zero()
    total_loss = 0            
    inputs = convertToCuda(inputs)
    for j = 1, opt.batchsize do
        inputPair = inputs[j]
        output = model:forward(inputPair)
        loss = crit:forward(output, targets[j])
        total_loss = total_loss + loss
        local grad = crit:backward(output, targets[j])        
        gradParameters:zero()
        model:backward(inputPair, grad)
        gradParameters:clamp(-opt.gradclip, opt.gradclip)
        gradParametersAdd:add(gradParameters)
    end
    total_loss = total_loss / opt.batchsize
    gradParametersAdd:div(opt.batchsize)
    parameters:add(gradParametersAdd*opt.learningrate*-1)   
--     if i % opt.calprecision == 0 then
--         local same, diff, prec = precisionCASIADatasetBNormal(dataset['val'], model)
--         if prec > max_val_precision then
--             info('change max precision from %0.2f to %0.2f'
--                                         , max_val_precision, prec)
--             max_val_precision = prec
--             local name = string.format('%s_valpre_%0.04f_i%04d', opt.modelname, max_val_precision, i)
--             save_model(model, name)
--         else
--             info('do not change max_precision from %0.2f to %0.2f', max_val_precision, prec)
--         end
--     end
    
    
    local time = timer:time().real
    timer:reset()
    info('%05dth/%05d Tra Error %0.6f, %d', i, opt.iteration, total_loss, time)
end

2017-12-25 22:25:51[INFO] Number of parameters:24226439	


2017-12-25 22:25:52[INFO] 00001th/00030 Tra Error 6.862573, 1	


2017-12-25 22:25:53[INFO] 00002th/00030 Val Error 6.863902	


2017-12-25 22:25:53[INFO] 00002th/00030 Tes Error 6.865671	


2017-12-25 22:25:54[INFO] 00002th/00030 Tra Error 6.865077, 1	


2017-12-25 22:25:55[INFO] 00003th/00030 Tra Error 6.862040, 0	


2017-12-25 22:25:55[INFO] 00004th/00030 Val Error 6.859328	


2017-12-25 22:25:55[INFO] 00004th/00030 Tes Error 6.866451	


2017-12-25 22:25:55[INFO] 00004th/00030 Tra Error 6.861688, 0	


2017-12-25 22:25:56[INFO] 00005th/00030 Tra Error 6.863520, 0	


2017-12-25 22:25:56[INFO] 00006th/00030 Val Error 6.857063	


2017-12-25 22:25:56[INFO] 00006th/00030 Tes Error 6.865277	


2017-12-25 22:25:57[INFO] 00006th/00030 Tra Error 6.864102, 1	


2017-12-25 22:25:58[INFO] 00007th/00030 Tra Error 6.862687, 0	


2017-12-25 22:25:58[INFO] 00008th/00030 Val Error 6.854218	


2017-12-25 22:25:58[INFO] 00008th/00030 Tes Error 6.862837	


2017-12-25 22:25:59[INFO] 00008th/00030 Tra Error 6.860694, 1	


2017-12-25 22:26:00[INFO] 00009th/00030 Tra Error 6.862172, 0	


2017-12-25 22:26:00[INFO] 00010th/00030 Val Error 6.863279	


2017-12-25 22:26:01[INFO] 00010th/00030 Tes Error 6.866897	


2017-12-25 22:26:01[INFO] 00010th/00030 Tra Error 6.861208, 0	


2017-12-25 22:26:02[INFO] 00011th/00030 Tra Error 6.862516, 0	


2017-12-25 22:26:02[INFO] 00012th/00030 Val Error 6.858454	


2017-12-25 22:26:02[INFO] 00012th/00030 Tes Error 6.862455	


2017-12-25 22:26:03[INFO] 00012th/00030 Tra Error 6.868934, 1	


2017-12-25 22:26:04[INFO] 00013th/00030 Tra Error 6.860870, 0	


2017-12-25 22:26:04[INFO] 00014th/00030 Val Error 6.862724	


2017-12-25 22:26:04[INFO] 00014th/00030 Tes Error 6.860374	


2017-12-25 22:26:05[INFO] 00014th/00030 Tra Error 6.865350, 1	


2017-12-25 22:26:06[INFO] 00015th/00030 Tra Error 6.857179, 0	


2017-12-25 22:26:06[INFO] 00016th/00030 Val Error 6.857994	


2017-12-25 22:26:06[INFO] 00016th/00030 Tes Error 6.862984	


2017-12-25 22:26:07[INFO] 00016th/00030 Tra Error 6.863724, 1	


2017-12-25 22:26:07[INFO] 00017th/00030 Tra Error 6.862614, 0	


2017-12-25 22:26:08[INFO] 00018th/00030 Val Error 6.860873	


2017-12-25 22:26:08[INFO] 00018th/00030 Tes Error 6.863042	


2017-12-25 22:26:09[INFO] 00018th/00030 Tra Error 6.862233, 1	


2017-12-25 22:26:09[INFO] 00019th/00030 Tra Error 6.865769, 0	


2017-12-25 22:26:10[INFO] 00020th/00030 Val Error 6.863768	


2017-12-25 22:26:10[INFO] 00020th/00030 Tes Error 6.865135	


2017-12-25 22:26:11[INFO] 00020th/00030 Tra Error 6.860199, 1	


2017-12-25 22:26:11[INFO] 00021th/00030 Tra Error 6.864934, 0	


2017-12-25 22:26:12[INFO] 00022th/00030 Val Error 6.861171	


2017-12-25 22:26:12[INFO] 00022th/00030 Tes Error 6.862191	


2017-12-25 22:26:13[INFO] 00022th/00030 Tra Error 6.862385, 1	


2017-12-25 22:26:13[INFO] 00023th/00030 Tra Error 6.866692, 0	


2017-12-25 22:26:13[INFO] 00024th/00030 Val Error 6.866182	


2017-12-25 22:26:13[INFO] 00024th/00030 Tes Error 6.864545	


2017-12-25 22:26:14[INFO] 00024th/00030 Tra Error 6.864995, 1	


2017-12-25 22:26:15[INFO] 00025th/00030 Tra Error 6.865097, 0	


2017-12-25 22:26:15[INFO] reset self.gei_index to 1, 01 epoch ends	


2017-12-25 22:26:15[INFO] 00026th/00030 Val Error 6.859276	


2017-12-25 22:26:15[INFO] 00026th/00030 Tes Error 6.856068	


2017-12-25 22:26:16[INFO] 00026th/00030 Tra Error 6.861513, 0	


2017-12-25 22:26:17[INFO] 00027th/00030 Tra Error 6.863482, 0	


2017-12-25 22:26:17[INFO] 00028th/00030 Val Error 6.858210	


2017-12-25 22:26:17[INFO] 00028th/00030 Tes Error 6.864064	


2017-12-25 22:26:18[INFO] 00028th/00030 Tra Error 6.866108, 1	


2017-12-25 22:26:18[INFO] 00029th/00030 Tra Error 6.862733, 0	


2017-12-25 22:26:19[INFO] 00030th/00030 Val Error 6.859651	


2017-12-25 22:26:19[INFO] 00030th/00030 Tes Error 6.866508	


2017-12-25 22:26:19[INFO] 00030th/00030 Tra Error 6.862065, 1	
