Skip to content

Commit

Permalink
Add @aaronnech's dataset.lua improvements for #132.
Browse files Browse the repository at this point in the history
  • Loading branch information
Brandon Amos committed May 8, 2016
1 parent 4c0522f commit 1d5491a
Showing 1 changed file with 53 additions and 108 deletions.
161 changes: 53 additions & 108 deletions training/dataset.lua
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ local argcheck = require 'argcheck'
require 'sys'
require 'xlua'
require 'image'
tds = require 'tds'

local dataset = torch.class('dataLoader')

Expand Down Expand Up @@ -92,118 +93,78 @@ function dataset:__init(...)
if not self.sampleHookTest then self.sampleHookTest = self.defaultSampleHook end

-- find class names
print('finding class names')
self.classes = {}
local classPaths = {}
local classPaths = tds.Hash()
local classToIdx = tds.Hash()
if self.forceClasses then
print('Adding forceClasses class names')
for k,v in pairs(self.forceClasses) do
self.classes[k] = v
classPaths[k] = {}
classPaths[k] = tds.Hash()
end
end
local function tableFind(t, o) for k,v in pairs(t) do if v == o then return k end end end
-- loop over each paths folder, get list of unique class names,
-- also store the directory paths per class
-- for each class,
print('Adding all path folders')
for _,path in ipairs(self.paths) do
local dirs = dir.getdirectories(path);
for _,dirpath in ipairs(dirs) do
for dirpath in paths.iterdirs(path) do
dirpath = path .. '/' .. dirpath
local class = paths.basename(dirpath)
local idx = tableFind(self.classes, class)
if not idx then
table.insert(self.classes, class)
idx = #self.classes
classPaths[idx] = {}
end
if not tableFind(classPaths[idx], dirpath) then
table.insert(classPaths[idx], dirpath);
end
self.classes[#self.classes + 1] = class
classPaths[#classPaths + 1] = dirpath
end
end

print(#self.classes .. ' class names found')
self.classIndices = {}
for k,v in ipairs(self.classes) do
self.classIndices[v] = k
end

-- define command-line tools, try your best to maintain OSX compatibility
local wc = 'wc'
local cut = 'cut'
local find = 'find'
if jit.os == 'OSX' then
wc = 'gwc'
cut = 'gcut'
find = 'gfind'
end
----------------------------------------------------------------------
-- Options for the GNU find command
local extensionList = {'jpg', 'png','JPG','PNG','JPEG', 'ppm', 'PPM', 'bmp', 'BMP'}
local findOptions = ' -iname "*.' .. extensionList[1] .. '"'
for i=2,#extensionList do
findOptions = findOptions .. ' -o -iname "*.' .. extensionList[i] .. '"'
end

-- find the image path names
print('Finding path for each image')
self.imagePath = torch.CharTensor() -- path to each image in dataset
self.imageClass = torch.LongTensor() -- class index of each image (class index in self.classes)
self.classList = {} -- index of imageList to each image of a particular class
self.classListSample = self.classList -- the main list used when sampling data

print('running "find" on each class directory, and concatenate all'
.. ' those filenames into a single file containing all image paths for a given class')
-- so, generates one file per class
local classFindFiles = {}
for i=1,#self.classes do
classFindFiles[i] = os.tmpname()
end
local combinedFindList = os.tmpname();

local tmpfile = os.tmpname()
local tmphandle = assert(io.open(tmpfile, 'w'))
-- iterate over classes
for i, _ in ipairs(self.classes) do
-- iterate over classPaths
for _,path in ipairs(classPaths[i]) do
local command = find .. ' "' .. path .. '" ' .. findOptions
.. ' >>"' .. classFindFiles[i] .. '" \n'
tmphandle:write(command)
local counts = tds.Hash()
local maxPathLength = 0

print('Calculating maximum class name length and counting files')
local length = 0

local fullPaths = tds.Hash()
-- iterate over classPaths
for _,path in pairs(classPaths) do
local count = 0
-- iterate over files in the class path
for f in paths.iterfiles(path) do
local fullPath = path .. '/' .. f
maxPathLength = math.max(fullPath:len(), maxPathLength)
count = count + 1
length = length + 1
fullPaths[#fullPaths + 1] = fullPath
end
counts[path] = count
end
io.close(tmphandle)
os.execute('bash ' .. tmpfile)
os.execute('rm -f ' .. tmpfile)

print('now combine all the files to a single large file')
tmpfile = os.tmpname()
tmphandle = assert(io.open(tmpfile, 'w'))
-- concat all finds to a single large file in the order of self.classes
for i=1,#self.classes do
local command = 'cat "' .. classFindFiles[i] .. '" >>' .. combinedFindList .. ' \n'
tmphandle:write(command)
end
io.close(tmphandle)
os.execute('bash ' .. tmpfile)
os.execute('rm -f ' .. tmpfile)

--==========================================================================
print('load the large concatenated list of sample paths to self.imagePath')
local maxPathLength = tonumber(sys.fexecute(wc .. " -L '"
.. combinedFindList .. "' |"
.. cut .. " -f1 -d' '")) + 1
local length = tonumber(sys.fexecute(wc .. " -l '"
.. combinedFindList .. "' |"
.. cut .. " -f1 -d' '"))
assert(length > 0, "Could not find any image file in the given input paths")
assert(maxPathLength > 0, "paths of files are length 0?")

self.imagePath:resize(length, maxPathLength):fill(0)
local s_data = self.imagePath:data()
local count = 0
for line in io.lines(combinedFindList) do
ffi.copy(s_data, line)
s_data = s_data + maxPathLength
if self.verbose and count % 10000 == 0 then
xlua.progress(count, length)
end;
count = count + 1
for _,line in pairs(fullPaths) do
ffi.copy(s_data, line)
s_data = s_data + maxPathLength
if self.verbose and count % 10000 == 0 then
xlua.progress(count, length)
end;
count = count + 1
end

self.numSamples = self.imagePath:size(1)
Expand All @@ -214,9 +175,7 @@ function dataset:__init(...)
local runningIndex = 0
for i=1,#self.classes do
if self.verbose then xlua.progress(i, #(self.classes)) end
local clsLength = tonumber(sys.fexecute(wc .. " -l '"
.. classFindFiles[i] .. "' |"
.. cut .. " -f1 -d' '"))
local clsLength = counts[classPaths[i]]
if clsLength == 0 then
error('Class has zero samples: ' .. self.classes[i])
else
Expand All @@ -226,19 +185,6 @@ function dataset:__init(...)
runningIndex = runningIndex + clsLength
end

--==========================================================================
-- clean up temporary files
print('Cleaning up temporary files')
local tmpfilelistall = ''
for i=1,#(classFindFiles) do
tmpfilelistall = tmpfilelistall .. ' "' .. classFindFiles[i] .. '"'
if i % 1000 == 0 then
os.execute('rm -f ' .. tmpfilelistall)
tmpfilelistall = ''
end
end
os.execute('rm -f ' .. tmpfilelistall)
os.execute('rm -f "' .. combinedFindList .. '"')
--==========================================================================

if self.split == 100 then
Expand Down Expand Up @@ -373,13 +319,13 @@ function dataset:sample(quantity)
error('No training mode when split is set to 0')
end
quantity = quantity or 1
local dataTable = {}
local scalarTable = {}
local dataTable = tds.Hash()
local scalarTable = tds.Hash()
for _=1,quantity do
local class = torch.random(1, #self.classes)
local out = self:getByClass(class)
table.insert(dataTable, out)
table.insert(scalarTable, class)
dataTable[#dataTable + 1] = out
scalarTable[#scalarTable + 1] = class
end
local data, scalarLabels, labels = tableToOutput(self, dataTable, scalarTable)
return data, scalarLabels, labels
Expand Down Expand Up @@ -429,31 +375,30 @@ function dataset:samplePeople(peoplePerBatch, imagesPerPerson)
end

local classes = torch.randperm(#trainLoader.classes)[{{1,peoplePerBatch}}]:int()
local nSamplesPerClass = torch.Tensor(peoplePerBatch)
local numPerClass = torch.Tensor(peoplePerBatch)
for i=1,peoplePerBatch do
local nSample = math.min(self.classListSample[classes[i]]:nElement(), imagesPerPerson)
nSamplesPerClass[i] = nSample
local n = math.min(self.classListSample[classes[i]]:nElement(), imagesPerPerson)
numPerClass[i] = n
end

local data = torch.Tensor(nSamplesPerClass:sum(),
local data = torch.Tensor(numPerClass:sum(),
self.sampleSize[1], self.sampleSize[2], self.sampleSize[3])

local dataIdx = 1
for i=1,peoplePerBatch do
local cls = classes[i]
local nSamples = nSamplesPerClass[i]
local nTotal = self.classListSample[classes[i]]:nElement()
local shuffle = torch.randperm(nTotal)
for j = 1, nSamples do
local n = numPerClass[i]
local shuffle = torch.randperm(n)
for j=1,n do
imgNum = self.classListSample[cls][shuffle[j]]
imgPath = ffi.string(torch.data(self.imagePath[imgNum]))
data[dataIdx] = self:sampleHookTrain(imgPath)
dataIdx = dataIdx + 1
end
end
assert(dataIdx - 1 == nSamplesPerClass:sum())
assert(dataIdx - 1 == numPerClass:sum())

return data, nSamplesPerClass
return data, numPerClass
end

function dataset:get(i1, i2)
Expand Down

1 comment on commit 1d5491a

@aaronnech
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, one change: add 1 to maxPathLength after line 156. The string extraction stuff won't parse out the maximum string correctly otherwise (due to the length being 1 character too long).

Thanks for integrating this!

Please sign in to comment.