Skip to content

Commit

Permalink
switched to hdf5 and added a choice for prediction method to main.lua…
Browse files Browse the repository at this point in the history
… (e.g. the classification model has two different prediction methods for the same model, averaging and taking the max.)
  • Loading branch information
paulguerrero committed Jul 7, 2017
1 parent 11fdd70 commit f55154a
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 46 deletions.
38 changes: 23 additions & 15 deletions hough.lua
Original file line number Diff line number Diff line change
Expand Up @@ -358,23 +358,31 @@ end

-- Translate the results of the deep net to a 3D normal by
-- computing the third coordinate and applying pca to each row
function M.postprocess_normals2(normals, pcas, hist_size)
function M.postprocess_normals2(normals, pcas, hist_size, pred_method)
local n = normals:size(1)

-- -- weighted mean of the histogram votes
-- row_ind = torch.range(0,normals:size(2)-1):div(hist_size):floor() -- zero-based
-- col_ind = torch.range(0,normals:size(2)-1) - torch.mul(row_ind,hist_size) -- zero-based
-- row_ind = torch.repeatTensor(row_ind,n,1):float()
-- col_ind = torch.repeatTensor(col_ind,n,1):float()
-- hist_sum = normals:sum(2)
-- row_ind = row_ind:cmul(normals):sum(2):cdiv(hist_sum)
-- col_ind = col_ind:cmul(normals):sum(2):cdiv(hist_sum)

-- maximum of the histogram votes
_,ind = normals:max(2)
ind = ind:float()
row_ind = (ind-1):div(hist_size):floor() -- zero-based
col_ind = (ind-1) - torch.mul(row_ind,hist_size) -- zero-based
if pred_method == '' or pred_method == 'max' then

-- maximum of the histogram votes
_,ind = normals:max(2)
ind = ind:float()
row_ind = (ind-1):div(hist_size):floor() -- zero-based
col_ind = (ind-1) - torch.mul(row_ind,hist_size) -- zero-based

elseif pred_method == 'avg' then

-- weighted mean of the histogram votes
row_ind = torch.range(0,normals:size(2)-1):div(hist_size):floor() -- zero-based
col_ind = torch.range(0,normals:size(2)-1) - torch.mul(row_ind,hist_size) -- zero-based
row_ind = torch.repeatTensor(row_ind,n,1):float()
col_ind = torch.repeatTensor(col_ind,n,1):float()
hist_sum = normals:sum(2)
row_ind = row_ind:cmul(normals):sum(2):cdiv(hist_sum)
col_ind = col_ind:cmul(normals):sum(2):cdiv(hist_sum)

else
error('Unknown prediction method.')
end

-- get normals at bin centers
normals = torch.cat(
Expand Down
53 changes: 37 additions & 16 deletions main.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ require 'cunn'
require 'cudnn'
require 'image' -- is it necessary?
require 'sys'
--require 'hdf5'
require 'hdf5'
require 'os'
local utils = require('utils')

Expand All @@ -21,8 +21,8 @@ local num_of_samples = 1000
local hist_size = 33
local batch_size = 256

local base_path = '/home/yanir/Documents/Projects/DeepCloud/'
-- local base_path = '../'
-- local base_path = '/home/yanir/Documents/Projects/DeepCloud/'
local base_path = '../'

local shape_path = 'data/shapes/'

Expand All @@ -40,18 +40,24 @@ for k,v in pairs(shapes) do
end

------------------------------------------------------------------
---- For day to day testing change only the following two lines:
---- For day to day testing change only the following 2-3 lines:
------------------------------------------------------------------
-- Shapes to run - change this to evaluate specific shapes:
local run_shapes = {'c','f','b','a','cn','fn','bn','an'}
local run_shapes = {'c','f','b','a','d','h','cn','fn','bn','an','dn','hn'}
-- Model to run - change this to evaluate on different models:
local model = model_list['re3']
local model = model_list['pca']
-- Optional: Prediction method with a given model (leave empty if there is only one prediction method)
local pred_method = '';
------------------------------------------------------------------

local out_path = 'data/out/' .. model['id'] .. '/'
local model_path = 'data/out/' .. model['id'] .. '/'
local out_path = model_path
if string.len(pred_method) > 0 then
out_path = out_path .. pred_method .. '/'
end

local model_filename = base_path .. out_path .. 'model.t7'
local mean_filename = base_path .. out_path .. 'mean.t7'
local model_filename = base_path .. model_path .. 'model.t7'
local mean_filename = base_path .. model_path .. 'mean.t7'


for i,sid in ipairs(run_shapes) do
Expand All @@ -71,20 +77,35 @@ for i,sid in ipairs(run_shapes) do

--------------------------------------------------------------------------
---- Load or compute Hough transform and PCA for each point on the shape:
local hough_save_name = string.format('%s%s%s_hough_%d_%d.txt', base_path, shape_path, sn, hist_size, num_of_samples)
local pca_save_name = string.format('%s%s%s_pca_%d_%d.txt', base_path, shape_path, sn, hist_size, num_of_samples)
local hough_save_name = string.format('%s%s%s_hough_%d_%d.h5', base_path, shape_path, sn, hist_size, num_of_samples)
local pca_save_name = string.format('%s%s%s_pca_%d_%d.h5', base_path, shape_path, sn, hist_size, num_of_samples)

local hough, pcas
if not utils.exists(hough_save_name) or not utils.exists(pca_save_name) then
hough, pcas = Hough.hough(v, k, num_of_samples, hist_size)

torch.save(hough_save_name, hough, 'ascii')
torch.save(pca_save_name, pcas, 'ascii')
-- torch.save(hough_save_name, hough, 'ascii')
-- torch.save(pca_save_name, pcas, 'ascii')

local h5file = hdf5.open(hough_save_name, 'w')
h5file:write('hough', hough)
h5file:close()
local h5file = hdf5.open(pca_save_name, 'w')
h5file:write('pcas', pcas)
h5file:close()

else
sys.tic()

hough = torch.load(hough_save_name, 'ascii')
pcas = torch.load(pca_save_name, 'ascii')
-- hough = torch.load(hough_save_name, 'ascii')
-- pcas = torch.load(pca_save_name, 'ascii')

local h5file = hdf5.open(hough_save_name, 'r')
hough = h5file:read('hough'):all()
h5file:close()
local h5file = hdf5.open(pca_save_name, 'r')
pcas = h5file:read('pcas'):all()
h5file:close()

print('Loaded Hough transform and PCA from file in ' .. sys.toc() .. ' seconds.')
end
Expand Down Expand Up @@ -146,7 +167,7 @@ for i,sid in ipairs(run_shapes) do
-- local hough_file = hdf5.open(base_path .. out_path .. sn .. '_hough_opt.h5', 'w')
-- hough_file:write('hough_opt', normals)
-- hough_file:close()
normals = Hough.postprocess_normals2(normals, pcas, hist_size)
normals = Hough.postprocess_normals2(normals, pcas, hist_size, pred_method)
else
normals = Hough.postprocess_normals(normals, pcas)
end
Expand Down
38 changes: 23 additions & 15 deletions main_train.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ require 'cunn'
require 'cudnn'
require 'image' -- is it necessary?
require 'sys'
-- require 'hdf5'
require 'hdf5'
require 'lfs'
local utils = require('utils')

Expand All @@ -17,8 +17,8 @@ local hnet = require('hough_net')
local model_list = require('model_list')


local base_path = '/home/yanir/Documents/Projects/DeepCloud/'
-- local base_path = '../'
-- local base_path = '/home/yanir/Documents/Projects/DeepCloud/'
local base_path = '../'

local shape_path = 'data/shapes/'

Expand Down Expand Up @@ -77,26 +77,34 @@ for i,model_id in ipairs(model_ids) do

--------------------------------------------------------------------------
---- Load or compute Hough transform and PCA for each point on the shape:
local hough_save_name = string.format('%s%s%s_hough_%d_%d.txt', base_path, shape_path, sn, hist_size, num_of_samples)
local pca_save_name = string.format('%s%s%s_pca_%d_%d.txt', base_path, shape_path, sn, hist_size, num_of_samples)
local hough_save_name = string.format('%s%s%s_hough_%d_%d.h5', base_path, shape_path, sn, hist_size, num_of_samples)
local pca_save_name = string.format('%s%s%s_pca_%d_%d.h5', base_path, shape_path, sn, hist_size, num_of_samples)

local h, p
if not utils.exists(hough_save_name) or not utils.exists(pca_save_name) then
h, p = Hough.hough(v, k, num_of_samples, hist_size)

torch.save(hough_save_name, h, 'ascii')
torch.save(pca_save_name, p, 'ascii')
-- local pca_file = hdf5.open(base_path .. out_path .. sn .. '_hough_100.h5', 'w')
-- pca_file:write('hough', h)
-- pca_file:close()
-- local hough_file = hdf5.open(base_path .. out_path .. sn .. '_pca_100.h5', 'w')
-- hough_file:write('pcas', p)
-- hough_file:close()
-- torch.save(hough_save_name, h, 'ascii')
-- torch.save(pca_save_name, p, 'ascii')

local h5file = hdf5.open(hough_save_name, 'w')
h5file:write('hough', h)
h5file:close()
local h5file = hdf5.open(pca_save_name, 'w')
h5file:write('pcas', p)
h5file:close()
else
sys.tic()

h = torch.load(hough_save_name, 'ascii')
p = torch.load(pca_save_name, 'ascii')
-- h = torch.load(hough_save_name, 'ascii')
-- p = torch.load(pca_save_name, 'ascii')

local h5file = hdf5.open(hough_save_name, 'r')
h = h5file:read('hough'):all()
h5file:close()
local h5file = hdf5.open(pca_save_name, 'r')
p = h5file:read('pcas'):all()
h5file:close()

print('Loaded Hough transform and PCA from file in ' .. sys.toc() .. ' seconds.')
end
Expand Down
7 changes: 7 additions & 0 deletions model_list.lua
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ model["name"] = "PCA only"
model["shapes"] = {}
models[model["id"]] = model

model = {}
model["id"] = "reb"
model["method"] = "re"
model["name"] = "Regression (Boulch)"
model["shapes"] = {}
models[model["id"]] = model

model = {}
model["id"] = "re1"
model["method"] = "re"
Expand Down

0 comments on commit f55154a

Please sign in to comment.