Permalink
Browse files

first commit

  • Loading branch information...
0 parents commit a0d35c3bf57a2fd578831af406afd2c6bce69bdf @karpathy committed Nov 20, 2015
Showing with 2,568 additions and 0 deletions.
  1. +8 −0 .gitignore
  2. +141 −0 DataLoader.lua
  3. +90 −0 DataLoaderRaw.lua
  4. +329 −0 LanguageModel.lua
  5. +105 −0 README.md
  6. +5 −0 call_python_caption_eval.sh
  7. +40 −0 coco-caption/myeval.py
  8. +188 −0 coco/coco_preprocess.ipynb
  9. +171 −0 eval.lua
  10. +67 −0 misc/LSTM.lua
  11. +102 −0 misc/gradcheck.lua
  12. +84 −0 misc/optim_updates.lua
  13. +186 −0 net_utils.lua
  14. +236 −0 prepro.py
  15. +268 −0 test_language_model.lua
  16. +405 −0 train.lua
  17. +71 −0 utils.lua
  18. +5 −0 vis/d3.min.js
  19. 0 vis/imgs/dummy
  20. +65 −0 vis/index.html
  21. +2 −0 vis/jquery-1.8.3.min.js
@@ -0,0 +1,8 @@
+coco/
+coco-caption/
+model/
+.ipynb_checkpoints/
+vis/imgs/
+vis/vis.json
+testimages/
+checkpoints/
@@ -0,0 +1,141 @@
+require 'hdf5'
+local utils = require 'utils'
+
+local DataLoader = torch.class('DataLoader')
+
--[[
Constructor. opt must provide:
  - opt.json_file: json companion file with the vocabulary (ix_to_word)
    and the per-image metadata (including each image's split assignment)
  - opt.h5_file: hdf5 file holding /images (N,C,H,W), /labels, and the
    per-image label pointer arrays /label_start_ix and /label_end_ix
--]]
function DataLoader:__init(opt)

  -- read the json companion file: vocabulary + per-image info
  print('DataLoader loading json file: ', opt.json_file)
  self.info = utils.read_json(opt.json_file)
  self.ix_to_word = self.info.ix_to_word
  self.vocab_size = utils.count_keys(self.ix_to_word)
  print('vocab size is ' .. self.vocab_size)

  -- open the hdf5 file with the heavy data (images and label sequences)
  print('DataLoader loading h5 file: ', opt.h5_file)
  self.h5_file = hdf5.open(opt.h5_file, 'r')

  -- record the image tensor dimensions; images are expected square
  local img_dims = self.h5_file:read('/images'):dataspaceSize()
  assert(#img_dims == 4, '/images should be a 4D tensor')
  assert(img_dims[3] == img_dims[4], 'width and height must match')
  self.num_images = img_dims[1]
  self.num_channels = img_dims[2]
  self.max_image_size = img_dims[3]
  print(string.format('read %d images of size %dx%dx%d', self.num_images,
            self.num_channels, self.max_image_size, self.max_image_size))

  -- record the (padded) caption length from the labels dataset
  local label_dims = self.h5_file:read('/labels'):dataspaceSize()
  self.seq_length = label_dims[2]
  print('max sequence length in data is ' .. self.seq_length)
  -- the per-image label pointers are small; keep them fully in RAM
  self.label_start_ix = self.h5_file:read('/label_start_ix'):all()
  self.label_end_ix = self.h5_file:read('/label_end_ix'):all()

  -- bucket image indices by split name, with one linear iterator per split
  self.split_ix = {}
  self.iterators = {}
  for i, img in pairs(self.info.images) do
    local split = img.split
    if self.split_ix[split] == nil then
      -- first time we see this split: create its bucket and iterator
      self.split_ix[split] = {}
      self.iterators[split] = 1
    end
    table.insert(self.split_ix[split], i)
  end
  for k, v in pairs(self.split_ix) do
    print(string.format('assigned %d images to split %s', #v, k))
  end
end
+
-- Rewind the linear iterator for one split back to the first image.
function DataLoader:resetIterator(split)
  self.iterators[split] = 1
end
+
-- Number of words in the vocabulary (counted from ix_to_word at init time).
function DataLoader:getVocabSize()
  return self.vocab_size
end
+
-- The index-to-word mapping table loaded from the json file.
function DataLoader:getVocab()
  return self.ix_to_word
end
+
-- The (padded) caption sequence length, read from /labels at init time.
function DataLoader:getSeqLength()
  return self.seq_length
end
+
--[[
  Fetch the next batch of data from a given split.
  opt.split: string identifier (e.g. train|val|test) -- required
  opt.batch_size: number of images returned at once (default 5)
  opt.seq_per_img: number of caption sequences returned per image (default 5)
  Returns a table with:
  - images: (N,3,256,256) ByteTensor containing the raw images
  - labels: (L, N*seq_per_img) LongTensor; captions go down the columns
    (which is better for contiguous memory during training)
  - bounds: {it_pos_now, it_max, wrapped} describing the iterator state
  - infos: table of length N with {id, file_path} per image
  The data is iterated linearly in order. Iterators for any split can be
  reset manually with resetIterator().
--]]
function DataLoader:getBatch(opt)
  local split = utils.getopt(opt, 'split') -- lets require that user passes this in, for safety
  local batch_size = utils.getopt(opt, 'batch_size', 5) -- how many images get returned at one time (to go through CNN)
  local seq_per_img = utils.getopt(opt, 'seq_per_img', 5) -- number of sequences to return per image

  local split_ix = self.split_ix[split]
  assert(split_ix, 'split ' .. split .. ' not found.')

  local img_batch_raw = torch.ByteTensor(batch_size, 3, 256, 256)
  local label_batch = torch.LongTensor(batch_size * seq_per_img, self.seq_length)
  local max_index = #split_ix
  local wrapped = false
  local infos = {}
  for i=1,batch_size do

    local ri = self.iterators[split] -- get next index from iterator
    local ri_next = ri + 1 -- increment iterator
    if ri_next > max_index then ri_next = 1; wrapped = true end -- wrap back around
    self.iterators[split] = ri_next
    local ix = split_ix[ri] -- fix: was an accidental global (missing 'local')
    assert(ix ~= nil, 'bug: split ' .. split .. ' was accessed out of bounds with ' .. ri)

    -- fetch the image from h5
    local img = self.h5_file:read('/images'):partial({ix,ix},{1,self.num_channels},
                            {1,self.max_image_size},{1,self.max_image_size})
    img_batch_raw[i] = img

    -- fetch the sequence labels via the [start,end] pointer range for this image
    local ix1 = self.label_start_ix[ix]
    local ix2 = self.label_end_ix[ix]
    local ncap = ix2 - ix1 + 1 -- number of captions available for this image
    assert(ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t')
    local seq
    if ncap < seq_per_img then
      -- too few captions: subsample with replacement to fill seq_per_img rows
      seq = torch.LongTensor(seq_per_img, self.seq_length)
      for q=1,seq_per_img do -- fix: was hard-coded 'for q=1,5', breaking seq_per_img ~= 5
        local ixl = torch.random(ix1,ix2)
        seq[{ {q,q} }] = self.h5_file:read('/labels'):partial({ixl, ixl}, {1,self.seq_length})
      end
    else
      -- enough captions: read a contiguous chunk at a random offset
      local ixl = torch.random(ix1, ix2 - seq_per_img + 1) -- generates integer in the range
      seq = self.h5_file:read('/labels'):partial({ixl, ixl+seq_per_img-1}, {1,self.seq_length})
    end
    local il = (i-1)*seq_per_img+1
    label_batch[{ {il,il+seq_per_img-1} }] = seq

    -- and record associated info as well
    local info_struct = {}
    info_struct.id = self.info.images[ix].id
    info_struct.file_path = self.info.images[ix].file_path
    table.insert(infos, info_struct)
  end

  local data = {}
  data.images = img_batch_raw
  data.labels = label_batch:transpose(1,2):contiguous() -- note: make label sequences go down as columns
  data.bounds = {it_pos_now = self.iterators[split], it_max = #split_ix, wrapped = wrapped}
  data.infos = infos
  return data
end
+
@@ -0,0 +1,90 @@
+--[[
+Same as DataLoader but only requires a folder of images.
+Does not have an h5 dependency.
+Only used at test time.
+]]--
+
+local utils = require 'utils'
+require 'lfs'
+require 'image'
+
+local DataLoaderRaw = torch.class('DataLoaderRaw')
+
--[[
Constructor. opt must provide opt.folder_path (directory of images) and may
provide opt.coco_json (coco-style json listing the images; '' or absent
means: scan the folder instead).
--]]
function DataLoaderRaw:__init(opt)
  -- fix: the code below previously read opt.coco_json directly, so a caller
  -- that omitted the key crashed on string.len(nil); use the defaulted local.
  local coco_json = utils.getopt(opt, 'coco_json', '')

  print('DataLoaderRaw loading images from folder: ', opt.folder_path)

  self.files = {}
  self.ids = {}
  if string.len(coco_json) > 0 then
    print('reading from ' .. coco_json)
    -- read in filenames from the coco-style json file
    self.coco_annotation = utils.read_json(coco_json)
    for k,v in pairs(self.coco_annotation.images) do
      -- NOTE(review): 'path' is not required in this file (penlight?) — confirm it is in scope
      local fullpath = path.join(opt.folder_path, v.file_name)
      table.insert(self.files, fullpath)
      table.insert(self.ids, v.id)
    end
  else
    -- read in all the filenames from the folder
    print('listing all images in directory ' .. opt.folder_path)
    local n = 1
    for file in lfs.dir(opt.folder_path) do
      local fullpath = path.join(opt.folder_path, file)
      if lfs.attributes(fullpath,"mode") == "file" then
        table.insert(self.files, fullpath)
        table.insert(self.ids, tostring(n)) -- just order them sequentially
        n=n+1
      end
    end
  end

  self.N = #self.files
  print('DataLoaderRaw found ' .. self.N .. ' images')

  self.iterator = 1
end
+
-- Rewind the linear iterator back to the first image.
function DataLoaderRaw:resetIterator()
  self.iterator = 1
end
+
--[[
  Returns a batch of data:
  - X (N,3,256,256) containing the images as uint8 ByteTensor
  - info table of length N, containing additional information
  The data is iterated linearly in order
--]]
function DataLoaderRaw:getBatch(opt)
  -- how many images get returned at one time (to go through CNN)
  local batch_size = utils.getopt(opt, 'batch_size', 5)
  local img_batch_raw = torch.ByteTensor(batch_size, 3, 256, 256)
  local wrapped = false
  local infos = {}

  for slot = 1, batch_size do
    -- take the current position and step the iterator, wrapping at the end
    local cur = self.iterator
    local nxt = cur + 1
    if nxt > self.N then
      nxt = 1
      wrapped = true
    end
    self.iterator = nxt

    -- load the image from disk and resize it to the fixed 256x256 input size
    local raw = image.load(self.files[cur], 3, 'byte')
    img_batch_raw[slot] = image.scale(raw, 256, 256)

    -- bookkeeping for downstream consumers
    infos[#infos + 1] = { id = self.ids[cur], file_path = self.files[cur] }
  end

  local data = {}
  data.images = img_batch_raw
  data.bounds = { it_pos_now = self.iterator, it_max = self.N, wrapped = wrapped }
  data.infos = infos
  return data
end
+
Oops, something went wrong.

0 comments on commit a0d35c3

Please sign in to comment.