diff --git a/README.md b/README.md
index faf18f2..dd434a9 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,123 @@
 # Hierarchical Co-Attention for Visual Question Answering
-Train a Hierarchical Co-Attention model for Visual Question Answering. This current code can get 62.1 on Open-Ended and 66.1 on Multiple-Choice on test-standard split. You can check Codalab leaderboard for more details.
+Train a Hierarchical Co-Attention model for Visual Question Answering. The current code can get 62.1 on Open-Ended and 66.1 on Multiple-Choice on the test-standard split. For COCO-QA, this code can get 65.4 accuracy. For more information, please refer to the paper [https://arxiv.org/abs/1606.00061](https://arxiv.org/abs/1606.00061)
-![teaser results](https://raw.github.com/jiasenlu/HieCoAttenVQA/master/vis/demo.png)
+### Requirements
+This code is written in Lua and requires [Torch](http://torch.ch/). The preprocessing code is in Python, and you need to install [NLTK](http://www.nltk.org/) if you want to use NLTK to tokenize the questions.
+
+You also need to install the following packages in order to run the code successfully.
+
+- [cudnn.torch](https://github.com/soumith/cudnn.torch)
+- [torch-hdf5](https://github.com/deepmind/torch-hdf5)
+- [lua-cjson](http://www.kyne.com.au/~mark/software/lua-cjson.php)
+- [loadcaffe](https://github.com/szagoruyko/loadcaffe)
+- [iTorch](https://github.com/facebook/iTorch)
+
+### Training
+
+We have prepared everything for you ;)
+
+##### Download Dataset
+The first thing you need to do is download the data and do some preprocessing. Head over to the `data/` folder and run
+
+For **VQA**:
+
+```
+$ python vqa_preprocess.py --download True --split 1
+```
+`--download True` means you choose to download the VQA data from the [VQA website](http://www.visualqa.org/), and `--split 1` means you use the COCO train set for training and the validation set for evaluation. `--split 2` means you use the COCO train+val set for training and the test set for evaluation. After this step, two files will be generated under the `data` folder: `vqa_raw_train.json` and `vqa_raw_test.json`.
+
+For **COCO-QA**:
+
+```
+$ python cocoqa_preprocess.py --download True
+```
+This will download the COCO-QA dataset from [here](http://www.cs.toronto.edu/~mren/imageqa/data/cocoqa/) and generate two files under the `data` folder: `cocoqa_raw_train.json` and `cocoqa_raw_test.json`.
+
+##### Download Image Model
+Here we use the VGG_ILSVRC_19_layers [model](https://gist.github.com/ksimonyan/3785162f95cd2d5fee77) and the Deep Residual network [model](https://github.com/facebook/fb.resnet.torch) implemented by Facebook.
+
+Head over to the `image_model` folder and run
+
+```
+$ python download_model.py --download 'VGG'
+```
+This will download the VGG_ILSVRC_19_layers model into the `image_model` folder. To download the Deep Residual model, change `VGG` to `Residual`.
+
+##### Generate Image/Question Features
+
+Head over to the `prepro` folder and run
+
+For **VQA**:
+
+```
+$ python prepro_vqa.py --input_train_json ../data/vqa_raw_train.json --input_test_json ../data/vqa_raw_test.json --num_ans 1000
+```
+to get the question features. `--num_ans` specifies how many of the top answers you want to use during training. You will also see some question and answer statistics in the terminal output. This will generate two files in the `data/` folder: `vqa_data_prepro.h5` and `vqa_data_prepro.json`.
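+
+If you want to sanity-check the preprocessed question file, you can open it with torch-hdf5 from the repository root (a minimal sketch; the dataset names below follow what `misc/DataLoader.lua` reads):
+
+```
+require 'hdf5'
+-- open the question file produced by prepro_vqa.py
+local f = hdf5.open('data/vqa_data_prepro.h5', 'r')
+-- ques_train is num_questions x seq_length; answers holds one label index per training question
+print(f:read('/ques_train'):dataspaceSize())
+print(f:read('/answers'):dataspaceSize())
+f:close()
+```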
+
+
+For **COCO-QA**:
+
+```
+$ python prepro_cocoqa.py --input_train_json ../data/cocoqa_raw_train.json --input_test_json ../data/cocoqa_raw_test.json
+```
+COCO-QA uses all the answers in the train split, so there is no `--num_ans` option. This will generate two files in the `data/` folder: `cocoqa_data_prepro.h5` and `cocoqa_data_prepro.json`.
+
+Then we are ready to extract the image features.
-## Comming soon!
+For **VGG** image features:
+
+```
+$ th prepro_img_vgg.lua -input_json ../data/vqa_data_prepro.json -image_root /home/jiasenlu/data/ -cnn_proto ../image_model/VGG_ILSVRC_19_layers_deploy.prototxt -cnn_model ../image_model/VGG_ILSVRC_19_layers.caffemodel
+```
+Set `-image_root` to the directory that contains your COCO images, and change `-gpuid`, `-backend`, and `-batch_size` based on your GPU.
+
+For **Deep Residual** image features:
+
+##### Train the model
+
+We now have everything ready to train the VQA and COCO-QA models. Go back to the `main` folder and run
+
+```
+th train.lua -input_img_train_h5 data/vqa_data_img_vgg_train.h5 -input_img_test_h5 data/vqa_data_img_vgg_test.h5 -input_ques_h5 data/vqa_data_prepro.h5 -input_json data/vqa_data_prepro.json -co_atten_type Alternating -feature_type VGG
+```
+
+to train the **Alternating co-attention** model on VQA using VGG image features. You can train the **Parallel co-attention** model by setting `-co_atten_type Parallel`. Parallel co-attention usually takes more time to train than alternating co-attention.
+
+##### Note
+- The Deep Residual image features are 4 times larger than the VGG features, so make sure you have enough RAM when you extract or load them.
+- If you don't have enough RAM, replace `require 'misc.DataLoader'` (Line 11 in `train.lua`) with `require 'misc.DataLoaderDisk'`. The model will then read the data directly from the hard disk (SSD preferred).
+
+### Evaluation
+
+##### Evaluate using Pre-trained Model
+The pre-trained models can be downloaded [here](https://filebox.ece.vt.edu/~jiasenlu/codeRelease/co_atten/model/).
+
+##### Metric
+
+To evaluate VQA, you need to download the VQA [evaluation tool](https://github.com/VT-vision-lab/VQA). To evaluate COCO-QA, you can use the script `evaluate_cocoqa.py` under the `metric/` folder. If you need to evaluate with WUPS, download the evaluation script from [here](http://datasets.d2.mpi-inf.mpg.de/mateusz14visual-turing/calculate_wups.py).
+
+##### VQA on Single Image with Free Form Question
+
+We use iTorch to demo visual question answering with a pre-trained model.
+In the root folder, open `itorch notebook`; you can then load any image and ask free-form questions in the notebook.
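+
+If you just want to inspect a downloaded checkpoint from the Torch REPL, the sketch below shows the layout `eval.lua` expects (the path is simply the default checkpoint name used there; you may additionally need `require 'cutorch'`/`require 'cunn'` if the checkpoint stores CUDA tensors):
+
+```
+require 'torch'
+-- point this at whichever checkpoint you downloaded
+local checkpoint = torch.load('model/vqa_model/model_alternating_train_vgg.t7')
+print(checkpoint.lmOpt)               -- options the model was trained with
+print(checkpoint.wparams:nElement())  -- word-level parameters (also pparams, qparams, aparams)
+```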
+ +##### Attention Visualization + + +### Reference + +If you use this code as part of any published research, please acknowledge the following paper + +``` +@misc{Lu2016Hie, +author = {Lu, Jiasen and Yang, Jianwei and Batra, Dhruv and Parikh, Devi}, +title = {Hierarchical Question-Image Co-Attention for Visual Question Answering}, +journal = {arXiv preprint arXiv:1606.00061v2}, +year = {2016} +} +``` + +### Attention Demo + +![teaser results](https://raw.github.com/jiasenlu/HieCoAttenVQA/master/vis/demo.png) diff --git a/data/cocoqa_preprocess.py b/data/cocoqa_preprocess.py index 677827b..b6593d2 100755 --- a/data/cocoqa_preprocess.py +++ b/data/cocoqa_preprocess.py @@ -128,8 +128,8 @@ def main(params): tp = types[i] test.append({'ques_id': question_id, 'img_path': image_path, 'question': question, 'types': tp,'ans': ans}) - json.dump(train, open('coco_qa_raw_train.json', 'w')) - json.dump(test, open('coco_qa_raw_test.json', 'w')) + json.dump(train, open('cocoqa_raw_train.json', 'w')) + json.dump(test, open('cocoqa_raw_test.json', 'w')) if __name__ == "__main__": diff --git a/data/vqa_preprocess.py b/data/vqa_preprocess.py index 183f8c0..1c4b139 100755 --- a/data/vqa_preprocess.py +++ b/data/vqa_preprocess.py @@ -122,8 +122,8 @@ def main(params): parser = argparse.ArgumentParser() # input json - parser.add_argument('--download', default=0, type=int, help='Download and extract data from VQA server') - parser.add_argument('--split', default=2, type=int, help='1: train on Train and test on Val, 2: train on Train+Val and test on Test') + parser.add_argument('--download', default=1, type=int, help='Download and extract data from VQA server') + parser.add_argument('--split', default=1, type=int, help='1: train on Train and test on Val, 2: train on Train+Val and test on Test') args = parser.parse_args() params = vars(args) diff --git a/eval.lua b/eval.lua new file mode 100644 index 0000000..1e15792 --- /dev/null +++ b/eval.lua @@ -0,0 +1,193 @@ + +require 'nn' +require 'torch' +require 'optim' +require 'misc.DataLoader' +require 'misc.word_level' +require 'misc.phrase_level' +require 'misc.ques_level' +require 'misc.recursive_atten' +require 'misc.optim_updates' +local utils = require 'misc.utils' +require 'xlua' + + +cmd = torch.CmdLine() +cmd:text() +cmd:text('evaluate a Visual Question Answering model') +cmd:text() +cmd:text('Options') + +-- Data input settings +cmd:option('-input_img_train_h5','data/vqa_data_img_vgg_train.h5','path to the h5file containing the image feature') +cmd:option('-input_img_test_h5','data/vqa_data_img_vgg_test.h5','path to the h5file containing the image feature') +cmd:option('-input_ques_h5','data/vqa_data_prepro.h5','path to the h5file containing the preprocessed dataset') +cmd:option('-input_json','data/vqa_data_prepro.json','path to the json file containing additional info and vocab') + +cmd:option('-start_from', 'model/vqa_model/model_alternating_train_vgg.t7', 'path to a model checkpoint to initialize model weights from. Empty = don\'t') +cmd:option('-co_atten_type', 'Alternating', 'co_attention type. Parallel or Alternating, alternating trains more faster than parallel.') +cmd:option('-feature_type', 'VGG', 'VGG or Residual') + +-- misc +cmd:option('-backend', 'cudnn', 'nn|cudnn') +cmd:option('-gpuid', 2, 'which gpu to use. 
-1 = use CPU') +cmd:option('-seed', 123, 'random number generator seed to use') + +cmd:text() + +local batch_size = 256 + +------------------------------------------------------------------------------- +-- Basic Torch initializations +------------------------------------------------------------------------------- +local opt = cmd:parse(arg) +torch.manualSeed(opt.seed) +print(opt) +torch.setdefaulttensortype('torch.FloatTensor') -- for CPU + +if opt.gpuid >= 0 then + require 'cutorch' + require 'cunn' + if opt.backend == 'cudnn' then + require 'cudnn' + end + cutorch.manualSeed(opt.seed) + --cutorch.setDevice(opt.gpuid+1) -- note +1 because lua is 1-indexed +end + +opt = cmd:parse(arg) + +------------------------------------------------------------------------ +--Design Parameters and Network Definitions +------------------------------------------------------------------------ +local protos = {} +print('Building the model...') +-- intialize language model +local loaded_checkpoint +local lmOpt +if string.len(opt.start_from) > 0 then + + loaded_checkpoint = torch.load(opt.start_from) + lmOpt = loaded_checkpoint.lmOpt +else + lmOpt = {} + lmOpt.vocab_size = loader:getVocabSize() + lmOpt.input_encoding_size = opt.input_encoding_size + lmOpt.rnn_size = opt.rnn_size + lmOpt.num_layers = opt.rnn_layers + lmOpt.dropout = 0.5 + lmOpt.seq_length = loader:getSeqLength() + lmOpt.batch_size = opt.batch_size + lmOpt.output_size = opt.rnn_size + lmOpt.atten_type = opt.co_atten_type + lmOpt.feature_type = opt.feature_type +end +lmOpt.hidden_size = 512 +lmOpt.feature_type = 'VGG' +lmOpt.atten_type = opt.co_atten_type +print(lmOpt) + +protos.word = nn.word_level(lmOpt) +protos.phrase = nn.phrase_level(lmOpt) +protos.ques = nn.ques_level(lmOpt) + +protos.atten = nn.recursive_atten() +protos.crit = nn.CrossEntropyCriterion() + +if opt.gpuid >= 0 then + for k,v in pairs(protos) do v:cuda() end +end + +local wparams, grad_wparams = protos.word:getParameters() +local pparams, grad_pparams = protos.phrase:getParameters() +local qparams, grad_qparams = protos.ques:getParameters() +local aparams, grad_aparams = protos.atten:getParameters() + + +if string.len(opt.start_from) > 0 then + print('Load the weight...') + wparams:copy(loaded_checkpoint.wparams) + pparams:copy(loaded_checkpoint.pparams) + qparams:copy(loaded_checkpoint.qparams) + aparams:copy(loaded_checkpoint.aparams) +end + +print('total number of parameters in word_level: ', wparams:nElement()) +assert(wparams:nElement() == grad_wparams:nElement()) + +print('total number of parameters in phrase_level: ', pparams:nElement()) +assert(pparams:nElement() == grad_pparams:nElement()) + +print('total number of parameters in ques_level: ', qparams:nElement()) +assert(qparams:nElement() == grad_qparams:nElement()) +protos.ques:shareClones() + +print('total number of parameters in recursive_attention: ', aparams:nElement()) +assert(aparams:nElement() == grad_aparams:nElement()) + +------------------------------------------------------------------------------- +-- Create the Data Loader instance +------------------------------------------------------------------------------- + +local loader = DataLoader{h5_img_file_train = opt.input_img_train_h5, h5_img_file_test = opt.input_img_test_h5, h5_ques_file = opt.input_ques_h5, json_file = opt.input_json, feature_type = opt.feature_type} + +collectgarbage() + +function eval_split(split) + + protos.word:evaluate() + protos.phrase:evaluate() + protos.ques:evaluate() + protos.atten:evaluate() + loader:resetIterator(split) + 
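+  -- run the word -> phrase -> question co-attention levels batch by batch, feed the six
+  -- attended features to the recursive attention module, and collect the 1000-way answer
+  -- scores together with the question ids so the result json can be written afterwards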
+ local n = 0 + local loss_evals = 0 + local predictions = {} + local total_num = loader:getDataNum(2) + print(total_num) + local logprob_all = torch.Tensor(total_num, 1000) + local ques_id = torch.Tensor(total_num) + + for i = 1, total_num, batch_size do + xlua.progress(i, total_num) + local r = math.min(i+batch_size-1, total_num) + + local data = loader:getBatch{batch_size = r-i+1, split = split} + -- ship the data to cuda + if opt.gpuid >= 0 then + data.images = data.images:cuda() + data.questions = data.questions:cuda() + data.ques_len = data.ques_len:cuda() + end + + local word_feat, img_feat, w_ques, w_img, mask = unpack(protos.word:forward({data.questions, data.images})) + + local conv_feat, p_ques, p_img = unpack(protos.phrase:forward({word_feat, data.ques_len, img_feat, mask})) + + local q_ques, q_img = unpack(protos.ques:forward({conv_feat, data.ques_len, img_feat, mask})) + + local feature_ensemble = {w_ques, w_img, p_ques, p_img, q_ques, q_img} + local out_feat = protos.atten:forward(feature_ensemble) + + logprob_all:sub(i, r):copy(out_feat:float()) + ques_id:sub(i, r):copy(data.ques_id) + + end + + + tmp,pred=torch.max(logprob_all,2); + + for i=1,total_num do + local ans = loader.ix_to_ans[tostring(pred[{i,1}])] + table.insert(predictions,{question_id=ques_id[i],answer=ans}) + end + + return {predictions} +end + +predictions = eval_split(2) + +utils.write_json('OpenEnded_mscoco_co-atten_results.json', predictions[1]) + +--utils.write_json('MultipleChoice_mscoco_co-atten_results.json', predictions[2]) diff --git a/misc/DataLoader.lua b/misc/DataLoader.lua new file mode 100644 index 0000000..f4390f7 --- /dev/null +++ b/misc/DataLoader.lua @@ -0,0 +1,161 @@ +require 'hdf5' +local utils = require 'misc.utils' +local DataLoader = torch.class('DataLoader') + +function DataLoader:__init(opt) + + if opt.h5_img_file_train ~= nil then + print('DataLoader loading h5 image file: ', opt.h5_img_file_train) + local h5_file = hdf5.open(opt.h5_img_file_train, 'r') + self.fv_im_train = h5_file:read('/images_train'):all() + h5_file:close() + end + + if opt.h5_img_file_train ~= nil then + print('DataLoader loading h5 image file: ', opt.h5_img_file_test) + local h5_file = hdf5.open(opt.h5_img_file_test, 'r') + self.fv_im_test = h5_file:read('/images_test'):all() + h5_file:close() + end + + print('DataLoader loading h5 question file: ', opt.h5_ques_file) + local h5_file = hdf5.open(opt.h5_ques_file, 'r') + self.ques_train = h5_file:read('/ques_train'):all() + self.ques_len_train = h5_file:read('/ques_len_train'):all() + self.img_pos_train = h5_file:read('/img_pos_train'):all() + self.ques_id_train = h5_file:read('/ques_id_train'):all() + self.answer = h5_file:read('/answers'):all() + self.split_train = h5_file:read('/split_train'):all() + + self.ques_test = h5_file:read('/ques_test'):all() + self.ques_len_test = h5_file:read('/ques_len_test'):all() + self.img_pos_test = h5_file:read('/img_pos_test'):all() + self.ques_id_test = h5_file:read('/ques_id_test'):all() + self.split_test = h5_file:read('/split_test'):all() + self.ans_test = h5_file:read('/ans_test'):all() + + h5_file:close() + print('Transform the image feature...') + + if opt.h5_img_file_train ~= nil then + if opt.feature_type == 'VGG' then + self.fv_im_train = self.fv_im_train:view(-1, 196, 512):contiguous() + elseif opt.feature_type == 'Residual' then + self.fv_im_train = self.fv_im_train:view(-1, 196, 2048):contiguous() + else + error('feature type error') + end + end + + if opt.h5_img_file_test ~= nil then + if opt.feature_type == 
'VGG' then + self.fv_im_test = self.fv_im_test:view(-1, 196, 512):contiguous() + elseif opt.feature_type == 'Residual' then + self.fv_im_test = self.fv_im_test:view(-1, 196, 2048):contiguous() + else + error('feature type error') + end + end + + print('DataLoader loading json file: ', opt.json_file) + local json_file = utils.read_json(opt.json_file) + self.ix_to_word = json_file.ix_to_word + self.ix_to_ans = json_file.ix_to_ans + + self.seq_length = self.ques_train:size(2) + + -- count the vocabulary key! + self.vocab_size = utils.count_key(self.ix_to_word) + + -- Let's get the split for train and val and test. + self.split_ix = {} + self.iterators = {} + + for i = 1,self.split_train:size(1) do + local idx = self.split_train[i] + if not self.split_ix[idx] then + self.split_ix[idx] = {} + self.iterators[idx] = 1 + end + table.insert(self.split_ix[idx], i) + end + + for i = 1,self.split_test:size(1) do + local idx = self.split_test[i] + if not self.split_ix[idx] then + self.split_ix[idx] = {} + self.iterators[idx] = 1 + end + table.insert(self.split_ix[idx], i) + end + + for k,v in pairs(self.split_ix) do + print(string.format('assigned %d images to split %s', #v, k)) + end + collectgarbage() -- do it often and there is no harm ;) +end + +function DataLoader:resetIterator(split) + self.iterators[split] = 1 +end + +function DataLoader:getVocabSize() + return self.vocab_size +end + +function DataLoader:getSeqLength() + return self.seq_length +end + +function DataLoader:getDataNum(split) + return #self.split_ix[split] +end + +function DataLoader:getBatch(opt) + local split = utils.getopt(opt, 'split') -- lets require that user passes this in, for safety + local batch_size = utils.getopt(opt, 'batch_size', 128) + + local split_ix = self.split_ix[split] + assert(split_ix, 'split ' .. split .. ' not found.') + + local max_index = #split_ix + local infos = {} + local ques_idx = torch.LongTensor(batch_size):fill(0) + local img_idx = torch.LongTensor(batch_size):fill(0) + + for i=1,batch_size do + local ri = self.iterators[split] -- get next index from iterator + local ri_next = ri + 1 -- increment iterator + if ri_next > max_index then ri_next = 1 end + self.iterators[split] = ri_next + if split == 0 then + ix = split_ix[torch.random(max_index)] + else + ix = split_ix[ri] + end + assert(ix ~= nil, 'bug: split ' .. split .. ' was accessed out of bounds with ' .. ri) + ques_idx[i] = ix + if split == 0 or split == 1 then + img_idx[i] = self.img_pos_train[ix] + else + img_idx[i] = self.img_pos_test[ix] + end + end + + local data = {} + -- fetch the question and image features. 
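+  -- split 0 and 1 index into the train tensors, anything else into the test tensors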
+ if split == 0 or split == 1 then + data.images = self.fv_im_train:index(1,img_idx) + data.questions = self.ques_train:index(1, ques_idx) + data.ques_id = self.ques_id_train:index(1, ques_idx) + data.ques_len = self.ques_len_train:index(1, ques_idx) + data.answer = self.answer:index(1, ques_idx) + else + data.images = self.fv_im_test:index(1,img_idx) + data.questions = self.ques_test:index(1, ques_idx) + data.ques_id = self.ques_id_test:index(1, ques_idx) + data.ques_len = self.ques_len_test:index(1, ques_idx) + data.answer = self.ans_test:index(1, ques_idx) + end + return data +end diff --git a/misc/DataLoaderDisk.lua b/misc/DataLoaderDisk.lua new file mode 100644 index 0000000..46b1ec1 --- /dev/null +++ b/misc/DataLoaderDisk.lua @@ -0,0 +1,168 @@ +require 'hdf5' +local utils = require 'misc.utils' +local DataLoader = torch.class('DataLoader') + +function DataLoader:__init(opt) + + if opt.h5_img_file_train ~= nil then + print('DataLoader loading h5 image file: ', opt.h5_img_file_train) + self.h5_img_file_train = hdf5.open(opt.h5_img_file_train, 'r') + end + + if opt.h5_img_file_train ~= nil then + print('DataLoader loading h5 image file: ', opt.h5_img_file_test) + self.h5_img_file_test = hdf5.open(opt.h5_img_file_test, 'r') + end + + print('DataLoader loading h5 question file: ', opt.h5_ques_file) + local h5_file = hdf5.open(opt.h5_ques_file, 'r') + self.ques_train = h5_file:read('/ques_train'):all() + self.ques_len_train = h5_file:read('/ques_len_train'):all() + self.img_pos_train = h5_file:read('/img_pos_train'):all() + self.ques_id_train = h5_file:read('/ques_id_train'):all() + self.answer = h5_file:read('/answers'):all() + self.split_train = h5_file:read('/split_train'):all() + + self.ques_test = h5_file:read('/ques_test'):all() + self.ques_len_test = h5_file:read('/ques_len_test'):all() + self.img_pos_test = h5_file:read('/img_pos_test'):all() + self.ques_id_test = h5_file:read('/ques_id_test'):all() + self.split_test = h5_file:read('/split_test'):all() + self.ans_test = h5_file:read('/ans_test'):all() + + h5_file:close() + + print('DataLoader loading json file: ', opt.json_file) + local json_file = utils.read_json(opt.json_file) + self.ix_to_word = json_file.ix_to_word + self.ix_to_ans = json_file.ix_to_ans + self.feature_type = opt.feature_type + self.seq_length = self.ques_train:size(2) + + -- count the vocabulary key! + self.vocab_size = utils.count_key(self.ix_to_word) + + -- Let's get the split for train and val and test. 
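+  -- split_ix[s] holds the question indices that belong to split s, and iterators[s] is the
+  -- cursor that getBatch uses to walk through that split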
+ self.split_ix = {} + self.iterators = {} + + for i = 1,self.split_train:size(1) do + local idx = self.split_train[i] + if not self.split_ix[idx] then + self.split_ix[idx] = {} + self.iterators[idx] = 1 + end + table.insert(self.split_ix[idx], i) + end + + for i = 1,self.split_test:size(1) do + local idx = self.split_test[i] + if not self.split_ix[idx] then + self.split_ix[idx] = {} + self.iterators[idx] = 1 + end + table.insert(self.split_ix[idx], i) + end + + for k,v in pairs(self.split_ix) do + print(string.format('assigned %d images to split %s', #v, k)) + end + collectgarbage() -- do it often and there is no harm ;) +end + +function DataLoader:resetIterator(split) + self.iterators[split] = 1 +end + +function DataLoader:getVocabSize() + return self.vocab_size +end + +function DataLoader:getSeqLength() + return self.seq_length +end + +function DataLoader:getDataNum(split) + return #self.split_ix[split] +end + +function DataLoader:getBatch(opt) + local split = utils.getopt(opt, 'split') -- lets require that user passes this in, for safety + local batch_size = utils.getopt(opt, 'batch_size', 128) + + local split_ix = self.split_ix[split] + assert(split_ix, 'split ' .. split .. ' not found.') + + local max_index = #split_ix + local infos = {} + local ques_idx = torch.LongTensor(batch_size):fill(0) + local img_idx = torch.LongTensor(batch_size):fill(0) + + if self.feature_type == 'VGG' then + self.img_batch = torch.Tensor(batch_size, 14, 14, 512) + elseif self.feature_type == 'Residual' then + self.img_batch = torch.Tensor(batch_size, 14, 14, 2048) + end + + for i=1,batch_size do + local ri = self.iterators[split] -- get next index from iterator + local ri_next = ri + 1 -- increment iterator + if ri_next > max_index then ri_next = 1 end + self.iterators[split] = ri_next + if split == 0 then + ix = split_ix[torch.random(max_index)] + else + ix = split_ix[ri] + end + assert(ix ~= nil, 'bug: split ' .. split .. ' was accessed out of bounds with ' .. ri) + ques_idx[i] = ix + if split == 0 or split == 1 then + img_idx[i] = self.img_pos_train[ix] + if self.h5_img_file_train ~= nil then + if self.feature_type == 'VGG' then + local img = self.h5_img_file_train:read('/images_train'):partial({img_idx[i],img_idx[i]},{1,14}, + {1,14},{1,512}) + self.img_batch[i] = img + elseif self.feature_type == 'Residual' then + local img = self.h5_img_file_train:read('/images_train'):partial({img_idx[i],img_idx[i]},{1,14}, + {1,14},{1,2048}) + self.img_batch[i] = img + else + error('feature type error') + end + end + else + img_idx[i] = self.img_pos_test[ix] + if self.h5_img_file_test ~= nil then + if self.feature_type == 'VGG' then + local img = self.h5_img_file_test:read('/images_test'):partial({img_idx[i],img_idx[i]},{1,14}, + {1,14},{1,512}) + self.img_batch[i] = img + elseif self.feature_type == 'Residual' then + local img = self.h5_img_file_test:read('/images_test'):partial({img_idx[i],img_idx[i]},{1,14}, + {1,14},{1,2048}) + self.img_batch[i] = img + else + error('feature type error') + end + end + end + end + + local data = {} + -- fetch the question and image features. 
+ if split == 0 or split == 1 then + data.images = self.img_batch:view(batch_size, 196, -1):contiguous() + data.questions = self.ques_train:index(1, ques_idx) + data.ques_id = self.ques_id_train:index(1, ques_idx) + data.ques_len = self.ques_len_train:index(1, ques_idx) + data.answer = self.answer:index(1, ques_idx) + else + data.images = self.img_batch:view(batch_size, 196, -1):contiguous() + data.questions = self.ques_test:index(1, ques_idx) + data.ques_id = self.ques_id_test:index(1, ques_idx) + data.ques_len = self.ques_len_test:index(1, ques_idx) + data.answer = self.ans_test:index(1, ques_idx) + end + return data +end diff --git a/misc/DataLoaderRaw.lua b/misc/DataLoaderRaw.lua new file mode 100644 index 0000000..4f7f857 --- /dev/null +++ b/misc/DataLoaderRaw.lua @@ -0,0 +1,163 @@ +-- Load raw image and question. +-- image and question is saved in predict/ folder + +local utils = require 'misc.utils' +local DataLoaderRaw = torch.class('DataLoaderRaw') + +function DataLoader:__init(opt) + + if opt.h5_img_file_train ~= nil then + print('DataLoader loading h5 image file: ', opt.h5_img_file_train) + local h5_file = hdf5.open(opt.h5_img_file_train, 'r') + self.fv_im_train = h5_file:read('/images_train'):all() + h5_file:close() + end + + if opt.h5_img_file_train ~= nil then + print('DataLoader loading h5 image file: ', opt.h5_img_file_test) + local h5_file = hdf5.open(opt.h5_img_file_test, 'r') + self.fv_im_test = h5_file:read('/images_test'):all() + h5_file:close() + end + + print('DataLoader loading h5 question file: ', opt.h5_ques_file) + local h5_file = hdf5.open(opt.h5_ques_file, 'r') + self.ques_train = h5_file:read('/ques_train'):all() + self.ques_len_train = h5_file:read('/ques_len_train'):all() + self.img_pos_train = h5_file:read('/img_pos_train'):all() + self.ques_id_train = h5_file:read('/ques_id_train'):all() + self.answer = h5_file:read('/answers'):all() + self.split_train = h5_file:read('/split_train'):all() + + self.ques_test = h5_file:read('/ques_test'):all() + self.ques_len_test = h5_file:read('/ques_len_test'):all() + self.img_pos_test = h5_file:read('/img_pos_test'):all() + self.ques_id_test = h5_file:read('/ques_id_test'):all() + self.split_test = h5_file:read('/split_test'):all() + self.ans_test = h5_file:read('/ans_test'):all() + + h5_file:close() + print('Transform the image feature...') + + if opt.h5_img_file_train ~= nil then + if opt.feature_type == 'VGG' then + self.fv_im_train = self.fv_im_train:view(-1, 196, 512):contiguous() + elseif opt.feature_type == 'Residual' then + self.fv_im_train = self.fv_im_train:view(-1, 196, 2048):contiguous() + else + error('feature type error') + end + end + + if opt.h5_img_file_test ~= nil then + if opt.feature_type == 'VGG' then + self.fv_im_test = self.fv_im_test:view(-1, 196, 512):contiguous() + elseif opt.feature_type == 'Residual' then + self.fv_im_test = self.fv_im_test:view(-1, 196, 2048):contiguous() + else + error('feature type error') + end + end + + print('DataLoader loading json file: ', opt.json_file) + local json_file = utils.read_json(opt.json_file) + self.ix_to_word = json_file.ix_to_word + self.ix_to_ans = json_file.ix_to_ans + + self.seq_length = self.ques_train:size(2) + + -- count the vocabulary key! + self.vocab_size = utils.count_key(self.ix_to_word) + + -- Let's get the split for train and val and test. 
+ self.split_ix = {} + self.iterators = {} + + for i = 1,self.split_train:size(1) do + local idx = self.split_train[i] + if not self.split_ix[idx] then + self.split_ix[idx] = {} + self.iterators[idx] = 1 + end + table.insert(self.split_ix[idx], i) + end + + for i = 1,self.split_test:size(1) do + local idx = self.split_test[i] + if not self.split_ix[idx] then + self.split_ix[idx] = {} + self.iterators[idx] = 1 + end + table.insert(self.split_ix[idx], i) + end + + for k,v in pairs(self.split_ix) do + print(string.format('assigned %d images to split %s', #v, k)) + end + collectgarbage() -- do it often and there is no harm ;) +end + +function DataLoader:resetIterator(split) + self.iterators[split] = 1 +end + +function DataLoader:getVocabSize() + return self.vocab_size +end + +function DataLoader:getSeqLength() + return self.seq_length +end + +function DataLoader:getDataNum(split) + return #self.split_ix[split] +end + +function DataLoader:getBatch(opt) + local split = utils.getopt(opt, 'split') -- lets require that user passes this in, for safety + local batch_size = utils.getopt(opt, 'batch_size', 128) + + local split_ix = self.split_ix[split] + assert(split_ix, 'split ' .. split .. ' not found.') + + local max_index = #split_ix + local infos = {} + local ques_idx = torch.LongTensor(batch_size):fill(0) + local img_idx = torch.LongTensor(batch_size):fill(0) + + for i=1,batch_size do + local ri = self.iterators[split] -- get next index from iterator + local ri_next = ri + 1 -- increment iterator + if ri_next > max_index then ri_next = 1 end + self.iterators[split] = ri_next + if split == 0 then + ix = split_ix[torch.random(max_index)] + else + ix = split_ix[ri] + end + assert(ix ~= nil, 'bug: split ' .. split .. ' was accessed out of bounds with ' .. ri) + ques_idx[i] = ix + if split == 0 or split == 1 then + img_idx[i] = self.img_pos_train[ix] + else + img_idx[i] = self.img_pos_test[ix] + end + end + + local data = {} + -- fetch the question and image features. + if split == 0 or split == 1 then + data.images = self.fv_im_train:index(1,img_idx) + data.questions = self.ques_train:index(1, ques_idx) + data.ques_id = self.ques_id_train:index(1, ques_idx) + data.ques_len = self.ques_len_train:index(1, ques_idx) + data.answer = self.answer:index(1, ques_idx) + else + data.images = self.fv_im_test:index(1,img_idx) + data.questions = self.ques_test:index(1, ques_idx) + data.ques_id = self.ques_id_test:index(1, ques_idx) + data.ques_len = self.ques_len_test:index(1, ques_idx) + data.answer = self.ans_test:index(1, ques_idx) + end + return data +end diff --git a/misc/LSTM.lua b/misc/LSTM.lua new file mode 100644 index 0000000..bb10040 --- /dev/null +++ b/misc/LSTM.lua @@ -0,0 +1,64 @@ +require 'nn' +require 'nngraph' + +local LSTM = {} +function LSTM.lstm(input_size, rnn_size, n, dropout) + dropout = dropout or 0.5 + + -- there will be 2*n+1 inputs + local inputs = {} + table.insert(inputs, nn.Identity()()) -- indices giving the sequence of symbols + for L = 1,n do + table.insert(inputs, nn.Identity()()) -- prev_c[L] + table.insert(inputs, nn.Identity()()) -- prev_h[L] + end + + local x, input_size_L + local outputs = {} + for L = 1,n do + -- c,h from previos timesteps + local prev_h = inputs[L*2+1] + local prev_c = inputs[L*2] + -- the input to this layer + if L == 1 then + x = inputs[1] + --x = nn.BatchNormalization(input_size)(x) + input_size_L = input_size + else + x = outputs[(L-1)*2] + if dropout > 0 then x = nn.Dropout(dropout)(x):annotate{name='drop_' .. 
L} end -- apply dropout, if any + input_size_L = rnn_size + end + -- evaluate the input sums at once for efficiency + local i2h = nn.Linear(input_size_L, 4 * rnn_size)(x):annotate{name='i2h_'..L} + local h2h = nn.Linear(rnn_size, 4 * rnn_size)(prev_h):annotate{name='h2h_'..L} + local all_input_sums = nn.CAddTable()({i2h, h2h}) + + local reshaped = nn.Reshape(4, rnn_size)(all_input_sums) + local n1, n2, n3, n4 = nn.SplitTable(2)(reshaped):split(4) + -- decode the gates + local in_gate = nn.Sigmoid()(n1) + local forget_gate = nn.Sigmoid()(n2) + local out_gate = nn.Sigmoid()(n3) + -- decode the write inputs + local in_transform = nn.Tanh()(n4) + -- perform the LSTM update + local next_c = nn.CAddTable()({ + nn.CMulTable()({forget_gate, prev_c}), + nn.CMulTable()({in_gate, in_transform}) + }) + -- gated cells form the output + local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)}) + + table.insert(outputs, next_c) + table.insert(outputs, next_h) + end + -- set up the decoder + local top_h = nn.Identity()(outputs[#outputs]) + if dropout > 0 then top_h = nn.Dropout(dropout)(top_h):annotate{name='drop_final'} end + table.insert(outputs, top_h) + + return nn.gModule(inputs, outputs) +end + +return LSTM \ No newline at end of file diff --git a/misc/LanguageEmbedding.lua b/misc/LanguageEmbedding.lua new file mode 100644 index 0000000..9e382dc --- /dev/null +++ b/misc/LanguageEmbedding.lua @@ -0,0 +1,79 @@ +-- +require 'nn' +require 'nngraph' +require 'rnn' +require 'cudnn' + +local LanguageEmbedding = {} + +function LanguageEmbedding.LE(vocab_size, embedding_size, conv_size, seq_length) + local inputs = {} + local outputs = {} + + table.insert(inputs, nn.Identity()()) + + local seq = inputs[1] + + local embed = nn.Dropout(0.5)(nn.Tanh()(nn.LookupTableMaskZero(vocab_size, embedding_size)(seq))) + + table.insert(outputs, embed) + + return nn.gModule(inputs, outputs) +end + + +function LanguageEmbedding.conv(conv_size,embedding_size, seq_length) + local inputs = {} + local outputs = {} + + table.insert(inputs, nn.Identity()()) + + local embed = inputs[1] + + local unigram = cudnn.TemporalConvolution(embedding_size, conv_size, 1, 1, 0)(embed) + local bigram = cudnn.TemporalConvolution(embedding_size, conv_size, 2, 1, 1)(embed) + local trigram = cudnn.TemporalConvolution(embedding_size,conv_size,3, 1, 1)(embed) + + local bigram = nn.Narrow(2,1,seq_length)(bigram) + + local unigram_dim = nn.View(-1, seq_length, conv_size, 1)(unigram) + local bigram_dim = nn.View(-1, seq_length, conv_size, 1)(bigram) + local trigram_dim = nn.View(-1, seq_length, conv_size, 1)(trigram) + + local feat = nn.JoinTable(4)({unigram_dim, bigram_dim, trigram_dim}) + local max_feat = nn.Dropout(0.5)(nn.Tanh()(nn.Max(3, 3)(feat))) + + table.insert(outputs, max_feat) + + return nn.gModule(inputs, outputs) +end +--[[ +function LanguageEmbedding.conv(conv_size,embedding_size, seq_length) + local inputs = {} + local outputs = {} + + table.insert(inputs, nn.Identity()()) + + local embed = inputs[1] + + local unigram = nn.TemporalConvolution(embedding_size, conv_size, 1)(embed) + local bigram = nn.TemporalConvolution(embedding_size, conv_size, 2)(embed) + local trigram = nn.TemporalConvolution(embedding_size,conv_size,3)(embed) + + local bigram_pad = nn.Padding(1,-1,2,0)(bigram) + local trigram_pad = nn.Padding(1,1,2,0)(trigram) + local trigram_pad = nn.Padding(1,-1,2,0)(trigram_pad) + + local unigram_dim = nn.View(seq_length, conv_size, 1):setNumInputDims(3)(unigram) + local bigram_dim = nn.View(seq_length, conv_size, 
1):setNumInputDims(3)(bigram_pad) + local trigram_dim = nn.View(seq_length, conv_size, 1):setNumInputDims(3)(trigram_pads) + + local feat = nn.JoinTable(4)({unigram_dim, bigram_dim, trigram_dim}) + local max_feat = nn.Dropout(0.5)(nn.Max(3, 3)(feat)) + + table.insert(outputs, max_feat) + + return nn.gModule(inputs, outputs) +end +]]-- +return LanguageEmbedding diff --git a/misc/attention.lua b/misc/attention.lua new file mode 100644 index 0000000..4a63758 --- /dev/null +++ b/misc/attention.lua @@ -0,0 +1,170 @@ +require 'nngraph' +require 'nn' +require 'misc.maskSoftmax' +local attention = {} +function attention.parallel_atten(input_size_ques, input_size_img, embedding_size, ques_seq_size, img_seq_size) + local inputs = {} + local outputs = {} + + table.insert(inputs, nn.Identity()()) + table.insert(inputs, nn.Identity()()) + table.insert(inputs, nn.Identity()()) + + local ques_feat = inputs[1] + local img_feat = inputs[2] + local mask = inputs[3] + + + local img_corr_dim = nn.Linear(input_size_img, input_size_ques)(nn.View(input_size_img):setNumInputDims(2)(img_feat)) + local img_corr = nn.View(img_seq_size, embedding_size):setNumInputDims(2)(img_corr_dim) + + local weight_matrix = nn.Tanh()(nn.MM(false, true)({ques_feat, img_corr})) + + local ques_embed_dim = nn.Linear(input_size_ques, embedding_size)(nn.View(input_size_ques):setNumInputDims(2)(ques_feat)) + local ques_embed = nn.View(ques_seq_size, embedding_size):setNumInputDims(2)(ques_embed_dim) + + local img_embed_dim = nn.Linear(input_size_img, input_size_ques)(nn.View(input_size_img):setNumInputDims(2)(img_feat)) + local img_embed = nn.View(img_seq_size, embedding_size):setNumInputDims(2)(img_embed_dim) + + local transform_img = nn.MM(false, false)({weight_matrix, img_embed}) + local ques_atten_sum = nn.Dropout(0.5)(nn.Tanh()(nn.CAddTable()({transform_img, ques_embed}))) + local ques_atten_embedding = nn.Linear(embedding_size, 1)(nn.View(embedding_size):setNumInputDims(2)(ques_atten_sum)) + local ques_atten = nn.maskSoftMax()({nn.View(ques_seq_size):setNumInputDims(2)(ques_atten_embedding),mask}) + + local transform_ques = nn.MM(true, false)({weight_matrix, ques_embed}) + local img_atten_sum = nn.Dropout(0.5)(nn.Tanh()(nn.CAddTable()({transform_ques, img_embed}))) + local img_atten_embedding = nn.Linear(embedding_size, 1)(nn.View(embedding_size):setNumInputDims(2)(img_atten_sum)) + local img_atten = nn.SoftMax()(nn.View(img_seq_size):setNumInputDims(2)(img_atten_embedding)) + + local ques_atten_dim = nn.View(1,-1):setNumInputDims(1)(ques_atten) + local img_atten_dim = nn.View(1,-1):setNumInputDims(1)(img_atten) + + local ques_atten_feat = nn.MM(false, false)({ques_atten_dim, ques_feat}) + local ques_atten_feat = nn.View(input_size_ques):setNumInputDims(2)(ques_atten_feat) + + local img_atten_feat = nn.MM(false, false)({img_atten_dim, img_feat}) + local img_atten_feat = nn.View(input_size_img):setNumInputDims(2)(img_atten_feat) + + table.insert(outputs, ques_atten_feat) + table.insert(outputs, img_atten_feat) + + return nn.gModule(inputs, outputs) +end + + +function attention.alternating_atten(input_size_ques, input_size_img, embedding_size, ques_seq_size, img_seq_size) + local inputs = {} + local outputs = {} + + table.insert(inputs, nn.Identity()()) + table.insert(inputs, nn.Identity()()) + table.insert(inputs, nn.Identity()()) + + local ques_feat = inputs[1] + local img_feat = inputs[2] + local mask = inputs[3] + + local ques_embed_dim = nn.Linear(input_size_ques, embedding_size)(nn.View(-1, input_size_ques)(ques_feat)) + local 
ques_embed = nn.View(-1, ques_seq_size, embedding_size)(ques_embed_dim) + + local feat = nn.Dropout(0.5)(nn.Tanh()(ques_embed)) + local h1 = nn.Linear(embedding_size, 1)(nn.View(-1, embedding_size)(feat)) + local P1 = nn.maskSoftMax()({nn.View(-1, ques_seq_size)(h1),mask}) + local ques_atten = nn.View(1,-1):setNumInputDims(1)(P1) + local quesAtt1 = nn.MM(false, false)({ques_atten, ques_feat}) + local ques_atten_feat_1 = nn.View(-1, input_size_ques)(quesAtt1) + + + -- img attention + local ques_embed_img = nn.Linear(input_size_ques, embedding_size)(ques_atten_feat_1) + + local img_embed_dim = nn.Linear(input_size_img, embedding_size)(nn.View(-1, input_size_img)(img_feat)) + + local img_embed = nn.View(-1, img_seq_size, embedding_size)(img_embed_dim) + + local ques_replicate = nn.Replicate(img_seq_size,2)(ques_embed_img) + + local feat = nn.Dropout(0.5)(nn.Tanh()(nn.CAddTable()({img_embed, ques_replicate}))) + local h2 = nn.Linear(embedding_size, 1)(nn.View(-1, embedding_size)(feat)) + local P2 = nn.SoftMax()(nn.View(-1, img_seq_size)(h2)) + local img_atten = nn.View(1,-1):setNumInputDims(1)(P2) + local visAtt = nn.MM(false, false)({img_atten, img_feat}) + local img_atten_feat = nn.View(-1, input_size_img)(visAtt) + + -- question attention + + local img_embed = nn.Linear(input_size_img, embedding_size)(img_atten_feat) + local img_replicate = nn.Replicate(ques_seq_size,2)(img_embed) + + local ques_embed_dim = nn.Linear(input_size_ques, embedding_size)(nn.View(-1, input_size_ques)(ques_feat)) + local ques_embed = nn.View(-1, ques_seq_size, embedding_size)(ques_embed_dim) + + local feat = nn.Dropout(0.5)(nn.Tanh()(nn.CAddTable()({ques_embed, img_replicate}))) + + local h3 = nn.Linear(embedding_size, 1)(nn.View(-1, embedding_size)(feat)) + local P3 = nn.maskSoftMax()({nn.View(-1, ques_seq_size)(h3),mask}) + local probs3dim = nn.View(1,-1):setNumInputDims(1)(P3) + local quesAtt = nn.MM(false, false)({probs3dim, ques_feat}) + local ques_atten_feat = nn.View(-1, 512)(quesAtt) + + -- combine image attention feature and language attention feature + + table.insert(outputs, ques_atten_feat) + table.insert(outputs, img_atten_feat) + + --table.insert(outputs, probs3dim) + ---table.insert(outputs, img_atten) + + return nn.gModule(inputs, outputs) +end + +function attention.fuse(input_size, ques_seq_size) + local inputs = {} + local outputs = {} + table.insert(inputs, nn.Identity()()) + table.insert(inputs, nn.Identity()()) + + local fore_lstm = inputs[1] + local back_lstm = inputs[2] + + local concat_lstm = nn.JoinTable(2)({fore_lstm, back_lstm}) + + local ques_feat_dim = nn.Linear(input_size*2, input_size)(nn.View(input_size*2):setNumInputDims(2)(concat_lstm)) + local ques_feat = nn.Dropout(0.5)(nn.View(ques_seq_size, input_size):setNumInputDims(2)(ques_feat_dim)) + + table.insert(outputs, ques_feat) + return nn.gModule(inputs, outputs) +end + + +function attention.recursive_atten(input_size, embedding_size, last_embed_size, output_size) + local inputs = {} + local outputs = {} + table.insert(inputs, nn.Identity()()) + table.insert(inputs, nn.Identity()()) + table.insert(inputs, nn.Identity()()) + table.insert(inputs, nn.Identity()()) + table.insert(inputs, nn.Identity()()) + table.insert(inputs, nn.Identity()()) + + local embed_ques = inputs[1] + local embed_img = inputs[2] + local conv_ques = inputs[3] + local conv_img = inputs[4] + local lstm_ques = inputs[5] + local lstm_img = inputs[6] + + local feat1 = nn.Dropout(0.5)(nn.CAddTable()({embed_ques, embed_img})) + local hidden1 = 
nn.Tanh()(nn.Linear(input_size, embedding_size)(feat1)) + local feat2 = nn.Dropout(0.5)(nn.JoinTable(2)({nn.CAddTable()({conv_ques, conv_img}), hidden1})) + local hidden2 = nn.Tanh()(nn.Linear(embedding_size+input_size, embedding_size)(feat2)) + local feat3 = nn.Dropout(0.5)(nn.JoinTable(2)({nn.CAddTable()({lstm_ques, lstm_img}), hidden2})) + local hidden3 = nn.Tanh()(nn.Linear(embedding_size+input_size, last_embed_size)(feat3)) + local outfeat = nn.Linear(last_embed_size, output_size)(nn.Dropout(0.5)(hidden3)) + + table.insert(outputs, outfeat) + + return nn.gModule(inputs, outputs) +end + +return attention diff --git a/misc/attention_visu.lua b/misc/attention_visu.lua new file mode 100644 index 0000000..b52cc3a --- /dev/null +++ b/misc/attention_visu.lua @@ -0,0 +1,174 @@ +require 'nngraph' +require 'nn' +require 'misc.maskSoftmax' +local attention = {} + +function attention.parallel_atten(input_size_ques, input_size_img, embedding_size, ques_seq_size, img_seq_size) + local inputs = {} + local outputs = {} + + table.insert(inputs, nn.Identity()()) + table.insert(inputs, nn.Identity()()) + table.insert(inputs, nn.Identity()()) + + local ques_feat = inputs[1] + local img_feat = inputs[2] + local mask = inputs[3] + + + local img_corr_dim = nn.Linear(input_size_img, input_size_ques)(nn.View(input_size_img):setNumInputDims(2)(img_feat)) + local img_corr = nn.View(img_seq_size, embedding_size):setNumInputDims(2)(img_corr_dim) + + local weight_matrix = nn.Tanh()(nn.MM(false, true)({ques_feat, img_corr})) + + local ques_embed_dim = nn.Linear(input_size_ques, embedding_size)(nn.View(input_size_ques):setNumInputDims(2)(ques_feat)) + local ques_embed = nn.View(ques_seq_size, embedding_size):setNumInputDims(2)(ques_embed_dim) + + local img_embed_dim = nn.Linear(input_size_img, input_size_ques)(nn.View(input_size_img):setNumInputDims(2)(img_feat)) + local img_embed = nn.View(img_seq_size, embedding_size):setNumInputDims(2)(img_embed_dim) + + local transform_img = nn.MM(false, false)({weight_matrix, img_embed}) + local ques_atten_sum = nn.Dropout(0.5)(nn.Tanh()(nn.CAddTable()({transform_img, ques_embed}))) + local ques_atten_embedding = nn.Linear(embedding_size, 1)(nn.View(embedding_size):setNumInputDims(2)(ques_atten_sum)) + local ques_atten = nn.maskSoftMax()({nn.View(ques_seq_size):setNumInputDims(2)(ques_atten_embedding),mask}) + + local transform_ques = nn.MM(true, false)({weight_matrix, ques_embed}) + local img_atten_sum = nn.Dropout(0.5)(nn.Tanh()(nn.CAddTable()({transform_ques, img_embed}))) + local img_atten_embedding = nn.Linear(embedding_size, 1)(nn.View(embedding_size):setNumInputDims(2)(img_atten_sum)) + local img_atten = nn.SoftMax()(nn.View(img_seq_size):setNumInputDims(2)(img_atten_embedding)) + + local ques_atten_dim = nn.View(1,-1):setNumInputDims(1)(ques_atten) + local img_atten_dim = nn.View(1,-1):setNumInputDims(1)(img_atten) + + local ques_atten_feat = nn.MM(false, false)({ques_atten_dim, ques_feat}) + local ques_atten_feat = nn.View(input_size_ques):setNumInputDims(2)(ques_atten_feat) + + local img_atten_feat = nn.MM(false, false)({img_atten_dim, img_feat}) + local img_atten_feat = nn.View(input_size_img):setNumInputDims(2)(img_atten_feat) + + table.insert(outputs, ques_atten_feat) + table.insert(outputs, img_atten_feat) + + table.insert(outputs, ques_atten) + table.insert(outputs, img_atten) + + return nn.gModule(inputs, outputs) +end + + +function attention.alternating_atten(input_size_ques, input_size_img, embedding_size, ques_seq_size, img_seq_size) + local inputs = {} + 
local outputs = {} + + table.insert(inputs, nn.Identity()()) + table.insert(inputs, nn.Identity()()) + table.insert(inputs, nn.Identity()()) + + local ques_feat = inputs[1] + local img_feat = inputs[2] + local mask = inputs[3] + + local ques_embed_dim = nn.Linear(input_size_ques, embedding_size)(nn.View(input_size_ques):setNumInputDims(2)(ques_feat)) + local ques_embed = nn.View(ques_seq_size, embedding_size):setNumInputDims(2)(ques_embed_dim) + + local feat = nn.Dropout(0.5)(nn.Tanh()(ques_embed)) + local h1 = nn.Linear(embedding_size, 1)(nn.View(embedding_size):setNumInputDims(2)(feat)) + local P1 = nn.maskSoftMax()({nn.View(ques_seq_size):setNumInputDims(2)(h1),mask}) + local ques_atten = nn.View(1,-1):setNumInputDims(1)(P1) + local quesAtt1 = nn.MM(false, false)({ques_atten, ques_feat}) + local ques_atten_feat_1 = nn.View(input_size_ques):setNumInputDims(2)(quesAtt1) + + + -- img attention + local ques_embed_img = nn.Linear(input_size_ques, embedding_size)(ques_atten_feat_1) + + local img_embed_dim = nn.Linear(input_size_img, embedding_size)(nn.View(input_size_img):setNumInputDims(2)(img_feat)) + + local img_embed = nn.View(img_seq_size, embedding_size):setNumInputDims(2)(img_embed_dim) + + local ques_replicate = nn.Replicate(img_seq_size,2)(ques_embed_img) + + local feat = nn.Dropout(0.5)(nn.Tanh()(nn.CAddTable()({img_embed, ques_replicate}))) + local h2 = nn.Linear(embedding_size, 1)(nn.View(embedding_size):setNumInputDims(2)(feat)) + local P2 = nn.SoftMax()(nn.View(img_seq_size):setNumInputDims(2)(h2)) + local img_atten = nn.View(1,-1):setNumInputDims(1)(P2) + local visAtt = nn.MM(false, false)({img_atten, img_feat}) + local img_atten_feat = nn.View(input_size_img):setNumInputDims(2)(visAtt) + + -- question attention + + local img_embed = nn.Linear(input_size_img, embedding_size)(img_atten_feat) + local img_replicate = nn.Replicate(ques_seq_size,2)(img_embed) + + local ques_embed_dim = nn.Linear(input_size_ques, embedding_size)(nn.View(input_size_ques):setNumInputDims(2)(ques_feat)) + local ques_embed = nn.View(ques_seq_size, embedding_size):setNumInputDims(2)(ques_embed_dim) + + local feat = nn.Dropout(0.5)(nn.Tanh()(nn.CAddTable()({ques_embed, img_replicate}))) + + local h3 = nn.Linear(embedding_size, 1)(nn.View(embedding_size):setNumInputDims(2)(feat)) + local P3 = nn.maskSoftMax()({nn.View(ques_seq_size):setNumInputDims(2)(h3),mask}) + local probs3dim = nn.View(1,-1):setNumInputDims(1)(P3) + local quesAtt = nn.MM(false, false)({probs3dim, ques_feat}) + local ques_atten_feat = nn.View(512):setNumInputDims(2)(quesAtt) + + -- combine image attention feature and language attention feature + + table.insert(outputs, ques_atten_feat) + table.insert(outputs, img_atten_feat) + + table.insert(outputs, probs3dim) + table.insert(outputs, img_atten) + + return nn.gModule(inputs, outputs) +end + +function attention.fuse(input_size, ques_seq_size) + local inputs = {} + local outputs = {} + table.insert(inputs, nn.Identity()()) + table.insert(inputs, nn.Identity()()) + + local fore_lstm = inputs[1] + local back_lstm = inputs[2] + + local concat_lstm = nn.JoinTable(2)({fore_lstm, back_lstm}) + + local ques_feat_dim = nn.Linear(input_size*2, input_size)(nn.View(input_size*2):setNumInputDims(2)(concat_lstm)) + local ques_feat = nn.Dropout(0.5)(nn.View(ques_seq_size, input_size):setNumInputDims(2)(ques_feat_dim)) + + table.insert(outputs, ques_feat) + return nn.gModule(inputs, outputs) +end + + +function attention.recursive_atten(input_size, embedding_size, last_embed_size, output_size) + local 
inputs = {} + local outputs = {} + table.insert(inputs, nn.Identity()()) + table.insert(inputs, nn.Identity()()) + table.insert(inputs, nn.Identity()()) + table.insert(inputs, nn.Identity()()) + table.insert(inputs, nn.Identity()()) + table.insert(inputs, nn.Identity()()) + + local embed_ques = inputs[1] + local embed_img = inputs[2] + local conv_ques = inputs[3] + local conv_img = inputs[4] + local lstm_ques = inputs[5] + local lstm_img = inputs[6] + + local feat1 = nn.Dropout(0.5)(nn.CAddTable()({embed_ques, embed_img})) + local hidden1 = nn.Tanh()(nn.Linear(input_size, embedding_size)(feat1)) + local feat2 = nn.Dropout(0.5)(nn.JoinTable(2)({nn.CAddTable()({conv_ques, conv_img}), hidden1})) + local hidden2 = nn.Tanh()(nn.Linear(embedding_size+input_size, embedding_size)(feat2)) + local feat3 = nn.Dropout(0.5)(nn.JoinTable(2)({nn.CAddTable()({lstm_ques, lstm_img}), hidden2})) + local hidden3 = nn.Tanh()(nn.Linear(embedding_size+input_size, last_embed_size)(feat3)) + local outfeat = nn.Linear(last_embed_size, output_size)(nn.Dropout(0.5)(hidden3)) + + table.insert(outputs, outfeat) + + return nn.gModule(inputs, outputs) +end + +return attention diff --git a/misc/cnnModel.lua b/misc/cnnModel.lua new file mode 100644 index 0000000..4f39292 --- /dev/null +++ b/misc/cnnModel.lua @@ -0,0 +1,75 @@ +require 'nn' +local utils = require 'misc.utils' +require 'loadcaffe' + +--local LSTM_img = require 'misc.LSTM_img' + +------------------------------------------------------------------------------- +-- Image Model core +------------------------------------------------------------------------------- + +local layer, parent = torch.class('nn.cnnModel', 'nn.Module') +function layer:__init(opt) + parent.__init(self) + + local layer_num = utils.getopt(opt, 'layer_num', 37) + self.input_size = utils.getopt(opt, 'input_size_image') + local dropout = utils.getopt(opt, 'dropout', 0) + self.output_size = utils.getopt(opt, 'output_size') + self.cnn_proto = utils.getopt(opt, 'cnn_proto') + self.cnn_model = utils.getopt(opt, 'cnn_model') + self.backend = utils.getopt(opt, 'backend') + -- option for Image Model + self.h = utils.getopt(opt, 'h') + self.w = utils.getopt(opt, 'w') + assert(self.h==self.w) -- h and w should be same here + self.seq_length = self.h * self.w + print(self.cnn_proto, self.cnn_model, self.backend) + local cnn_raw = loadcaffe.load(self.cnn_proto, self.cnn_model, self.backend) + self.cnn_part = nn.Sequential() + for i = 1, layer_num do + local layer = cnn_raw:get(i) + self.cnn_part:add(layer) + end + self.cnn_part:add(nn.View(-1, 512, 196)) + self.cnn_part:add(nn.Transpose({2,3})) +end + + +function layer:parameters() + local params = {} + local grad_params = {} + + local p2,g2 = self.cnn_part:parameters() + for k,v in pairs(p2) do table.insert(params, v) end + for k,v in pairs(g2) do table.insert(grad_params, v) end + + return params, grad_params +end + +function layer:training() + self.cnn_part:training() + +end + +function layer:evaluate() + self.cnn_part:evaluate() +end + +function layer:updateOutput(input) + local imgs = input + self.output = self.cnn_part:forward(imgs) + + return self.output + +end + +function layer:updateGradInput(input, gradOutput) + local imgs = input + + local dummy = self.cnn_part:backward(imgs, gradOutput) + self.gradInput = {} + return self.gradInput +end + + diff --git a/misc/maskSoftmax.lua b/misc/maskSoftmax.lua new file mode 100644 index 0000000..63da95a --- /dev/null +++ b/misc/maskSoftmax.lua @@ -0,0 +1,30 @@ +local maskSoftMax, _ = 
torch.class('nn.maskSoftMax', 'nn.Module') + +function maskSoftMax:updateOutput(input) + local data = input[1] + local mask = input[2] + data:maskedFill(mask, -9999999) + + data.THNN.SoftMax_updateOutput( + data:cdata(), + self.output:cdata() + ) + return self.output +end + +function maskSoftMax:updateGradInput(input, gradOutput) + local data = input[1] + local mask = input[2] + data:maskedFill(mask, -9999999) + data.THNN.SoftMax_updateGradInput( + data:cdata(), + gradOutput[1]:cdata(), + self.gradInput:cdata(), + self.output:cdata() + ) + if not self.dummy_out then + self.dummy_out = mask:clone() + end + self.dummy_out:resizeAs(mask):zero() + return {self.gradInput, self.dummy_out} +end \ No newline at end of file diff --git a/misc/optim_updates.lua b/misc/optim_updates.lua new file mode 100644 index 0000000..a116190 --- /dev/null +++ b/misc/optim_updates.lua @@ -0,0 +1,84 @@ + +-- optim, simple as it should be, written from scratch. That's how I roll + +function sgd(x, dx, lr) + x:add(-lr, dx) +end + +function sgdm(x, dx, lr, alpha, state) + -- sgd with momentum, standard update + if not state.v then + state.v = x.new(#x):zero() + end + state.v:mul(alpha) + state.v:add(lr, dx) + x:add(-1, state.v) +end + +function sgdmom(x, dx, lr, alpha, state) + -- sgd momentum, uses nesterov update (reference: http://cs231n.github.io/neural-networks-3/#sgd) + if not state.m then + state.m = x.new(#x):zero() + state.tmp = x.new(#x) + end + state.tmp:copy(state.m) + state.m:mul(alpha):add(-lr, dx) + x:add(-alpha, state.tmp) + x:add(1+alpha, state.m) +end + +function adagrad(x, dx, lr, epsilon, state) + if not state.m then + state.m = x.new(#x):zero() + state.tmp = x.new(#x) + end + -- calculate new mean squared values + state.m:addcmul(1.0, dx, dx) + -- perform update + state.tmp:sqrt(state.m):add(epsilon) + x:addcdiv(-lr, dx, state.tmp) +end + +-- rmsprop implementation, simple as it should be +function rmsprop(x, dx, lr, alpha, epsilon, state) + if not state.m then + state.m = x.new(#x):zero() + state.tmp = x.new(#x) + end + -- calculate new (leaky) mean squared values + state.m:mul(alpha) + state.m:addcmul(1.0-alpha, dx, dx) + -- perform update + state.tmp:sqrt(state.m):add(epsilon) + x:addcdiv(-lr, dx, state.tmp) +end + +function adam(x, dx, lr, beta1, beta2, epsilon, state) + local beta1 = beta1 or 0.9 + local beta2 = beta2 or 0.999 + local epsilon = epsilon or 1e-8 + + if not state.m then + -- Initialization + state.t = 0 + -- Exponential moving average of gradient values + state.m = x.new(#dx):zero() + -- Exponential moving average of squared gradient values + state.v = x.new(#dx):zero() + -- A tmp tensor to hold the sqrt(v) + epsilon + state.tmp = x.new(#dx):zero() + end + + -- Decay the first and second moment running average coefficient + state.m:mul(beta1):add(1-beta1, dx) + state.v:mul(beta2):addcmul(1-beta2, dx, dx) + state.tmp:copy(state.v):sqrt():add(epsilon) + + state.t = state.t + 1 + local biasCorrection1 = 1 - beta1^state.t + local biasCorrection2 = 1 - beta2^state.t + local stepSize = lr * math.sqrt(biasCorrection2)/biasCorrection1 + + -- perform update + x:addcdiv(-stepSize, state.m, state.tmp) +end diff --git a/misc/phrase_level.lua b/misc/phrase_level.lua new file mode 100644 index 0000000..e9634b5 --- /dev/null +++ b/misc/phrase_level.lua @@ -0,0 +1,91 @@ +require 'nn' +local utils = require 'misc.utils' +local LanguageEmbedding = require 'misc.LanguageEmbedding' +local attention = require 'misc.attention' + + +local layer, parent = torch.class('nn.phrase_level', 'nn.Module') 
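+-- phrase level: runs unigram/bigram/trigram temporal convolutions over the word embeddings,
+-- max-pools across the three n-gram responses at every position, then applies parallel or
+-- alternating co-attention between the phrase features and the 196 image-region features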
+function layer:__init(opt) + parent.__init(self) + self.vocab_size = utils.getopt(opt, 'vocab_size') -- required + self.hidden_size = utils.getopt(opt, 'hidden_size') + local dropout = utils.getopt(opt, 'dropout', 0) + + self.seq_length = utils.getopt(opt, 'seq_length') + self.atten_type = utils.getopt(opt, 'atten_type') + self.feature_type = utils.getopt(opt, 'feature_type') + + self.conv = LanguageEmbedding.conv(self.hidden_size, self.hidden_size, self.seq_length) + + if self.atten_type == 'Alternating' then + self.atten = attention.alternating_atten(self.hidden_size, self.hidden_size, self.hidden_size, self.seq_length, 196) + elseif self.atten_type == 'Parallel' then + self.atten = attention.parallel_atten(self.hidden_size, self.hidden_size, self.hidden_size, self.seq_length, 196) + else + error('Must provide an valid attention type.') + end +end + +function layer:getModulesList() + return {self.conv, self.atten} +end + +function layer:parameters() + -- we only have two internal modules, return their params + local p1,g1 = self.conv:parameters() + local p3,g3 = self.atten:parameters() + + local params = {} + for k,v in pairs(p1) do table.insert(params, v) end + for k,v in pairs(p3) do table.insert(params, v) end + + local grad_params = {} + for k,v in pairs(g1) do table.insert(grad_params, v) end + for k,v in pairs(g3) do table.insert(grad_params, v) end + + return params, grad_params +end + + +function layer:training() + self.atten:training() + self.conv:training() + +end + +function layer:evaluate() + self.atten:evaluate() + self.conv:evaluate() +end + +function layer:updateOutput(input) + local ques = input[1] + local seq_len = input[2] + local img = input[3] + self.mask = input[4] + + self.conv_out = self.conv:forward(ques) + local w_conv_ques, w_conv_img, ques_atten, img_atten = unpack(self.atten:forward({self.conv_out, img, self.mask})) + + return {self.conv_out, w_conv_ques, w_conv_img, ques_atten, img_atten} + +end + +function layer:updateGradInput(input, gradOutput) + local ques = input[1] + local seq_len = input[2] + local img = input[3] + + local batch_size = ques:size(1) + + local d_core_output, d_imgfeat, dummy = unpack(self.atten:backward({self.conv_out, img, self.mask}, {gradOutput[2], gradOutput[3]})) + + d_core_output:add(gradOutput[1]) + + local d_embedding = self.conv:backward(ques, d_core_output) + + self.gradInput = {d_embedding,d_imgfeat} + + + return self.gradInput +end diff --git a/misc/ques_level.lua b/misc/ques_level.lua new file mode 100644 index 0000000..3db9b7f --- /dev/null +++ b/misc/ques_level.lua @@ -0,0 +1,168 @@ +require 'nn' +local utils = require 'misc.utils' +local LSTM = require 'misc.LSTM' +local LanguageEmbedding = require 'misc.LanguageEmbedding' +local attention = require 'misc.attention' +local LanguageEmbedding = require 'misc.LanguageEmbedding' + +local layer, parent = torch.class('nn.ques_level', 'nn.Module') +function layer:__init(opt) + parent.__init(self) + self.rnn_size = utils.getopt(opt, 'rnn_size') + self.num_layers = utils.getopt(opt, 'num_layers', 1) + local dropout = utils.getopt(opt, 'dropout', 0) + self.hidden_size = utils.getopt(opt, 'hidden_size') + self.seq_length = utils.getopt(opt, 'seq_length') + self.atten_type = utils.getopt(opt, 'atten_type') + + self.core = LSTM.lstm(self.rnn_size, self.rnn_size, self.num_layers, dropout) + + if self.atten_type == 'Alternating' then + self.atten = attention.alternating_atten(self.hidden_size, self.hidden_size, self.hidden_size, self.seq_length, 196) + elseif self.atten_type == 
'Parallel' then + self.atten = attention.parallel_atten(self.hidden_size, self.hidden_size, self.hidden_size, self.seq_length, 196) + else + error('Must provide an valid attention type.') + end + self:_createInitState(1) + self.mask = torch.Tensor() + self.core_output = torch.Tensor() +end + +function layer:_createInitState(batch_size) + assert(batch_size ~= nil, 'batch size must be provided') + -- construct the initial state for the LSTM + if not self.init_state then self.init_state = {} end -- lazy init + for h=1,self.num_layers*2 do + -- note, the init state Must be zeros because we are using init_state to init grads in backward call too + if self.init_state[h] then + if self.init_state[h]:size(1) ~= batch_size then + self.init_state[h]:resize(batch_size, self.rnn_size):zero() -- expand the memory + end + else + self.init_state[h] = torch.zeros(batch_size, self.rnn_size) + end + end + self.num_state = #self.init_state + end + +function layer:createClones() + print('constructing clones inside the ques_level') + self.cores = {self.core} + for t=1,self.seq_length do + self.cores[t] = self.core:clone('weight', 'bias', 'gradWeight', 'gradBias') + end +end + +function layer:shareClones() + if self.cores == nil then self:createClones(); return; end + print('resharing clones inside the ques_level') + self.cores[1] = self.core + for t=1,self.seq_length do + self.cores[t]:share(self.core, 'weight', 'bias', 'gradWeight', 'gradBias') + end +end + +function layer:getModulesList() + return {self.core, self.atten} +end + +function layer:parameters() + -- we only have two internal modules, return their params + local p1,g1 = self.core:parameters() + local p3,g3 = self.atten:parameters() + + local params = {} + for k,v in pairs(p1) do table.insert(params, v) end + for k,v in pairs(p3) do table.insert(params, v) end + + local grad_params = {} + for k,v in pairs(g1) do table.insert(grad_params, v) end + for k,v in pairs(g3) do table.insert(grad_params, v) end + + return params, grad_params +end + + +function layer:training() + if self.cores == nil then self:createClones() end -- create these lazily if needed + for k,v in pairs(self.cores) do v:training() end + self.atten:training() +end + +function layer:evaluate() + if self.cores == nil then self:createClones() end -- create these lazily if needed + for k,v in pairs(self.cores) do v:evaluate() end + self.atten:evaluate() +end + +function layer:updateOutput(input) + local ques = input[1] + local seq_len = input[2] + local img = input[3] + self.mask = input[4] + + + if self.cores == nil then self:createClones() end -- lazily create clones on first forward pass + local batch_size = ques:size(1) + self.tmax = torch.max(seq_len) + self.tmin = torch.min(seq_len) + + self:_createInitState(batch_size) + self.fore_state = {[0] = self.init_state} + self.fore_inputs = {} + self.core_output:resize(batch_size, self.seq_length, self.rnn_size):zero() + + + for t=1,self.tmax do + self.fore_inputs[t] = {ques:narrow(2,t,1):contiguous():view(-1, self.rnn_size), unpack(self.fore_state[t-1])} + local out = self.cores[t]:forward(self.fore_inputs[t]) + if t > self.tmin then + for i=1,self.num_state+1 do + out[i]:maskedFill(self.mask:narrow(2,t,1):contiguous():view(batch_size,1):expandAs(out[i]), 0) + end + end + self.fore_state[t] = {} -- the rest is state + for i=1,self.num_state do table.insert(self.fore_state[t], out[i]) end + + self.core_output:narrow(2,t,1):copy(out[self.num_state+1]) + end + + local w_lstm_ques, w_lstm_img, ques_atten, img_atten = 
unpack(self.atten:forward({self.core_output, img, self.mask})) + + return {w_lstm_ques, w_lstm_img, ques_atten, img_atten} +end + +function layer:updateGradInput(input, gradOutput) + local ques = input[1] + local seq_len = input[2] + local img = input[3] + + local batch_size = ques:size(1) + + local d_core_output, d_imgfeat, dummy = unpack(self.atten:backward({self.core_output, img, self.mask}, gradOutput)) + + -- go backwards and lets compute gradients + local d_core_state = {[self.tmax] = self.init_state} -- initial dstates + local d_embed_core = d_embed_core or self.core_output:new() + d_embed_core:resize(batch_size, self.seq_length, self.rnn_size):zero() + + for t=self.tmax,1,-1 do + -- concat state gradients and output vector gradients at time step t + local dout = {} + for k=1,#d_core_state[t] do table.insert(dout, d_core_state[t][k]) end + table.insert(dout, d_core_output:narrow(2,t,1):contiguous():view(-1, self.hidden_size)) + local dinputs = self.cores[t]:backward(self.fore_inputs[t], dout) + + if t > self.tmin then + for k=1,self.num_state+1 do + dinputs[k]:maskedFill(self.mask:narrow(2,t,1):contiguous():view(batch_size,1):expandAs(dinputs[k]), 0) + end + end + d_core_state[t-1] = {} -- copy over rest to state grad + for k=2,self.num_state+1 do table.insert(d_core_state[t-1], dinputs[k]) end + d_embed_core:narrow(2,t,1):copy(dinputs[1]) + end + self.gradInput = {d_embed_core, d_imgfeat} + return self.gradInput +end diff --git a/misc/recursive_atten.lua b/misc/recursive_atten.lua new file mode 100644 index 0000000..42c6cb0 --- /dev/null +++ b/misc/recursive_atten.lua @@ -0,0 +1,44 @@ +require 'nn' +local utils = require 'misc.utils' +local attention = require 'misc.attention' + +local layer, parent = torch.class('nn.recursive_atten', 'nn.Module') +function layer:__init(opt) + parent.__init(self) + + self.atten_encode = attention.recursive_atten(512,512,1024,1000) + + -- self.atten_encode = attention.recursive_atten(512,512,512,1000) -- coco_qa +end + +function layer:getModulesList() + return {self.atten_encode} +end + +function layer:parameters() + local p1,g1 = self.atten_encode:parameters() + local params = {} + for k,v in pairs(p1) do table.insert(params, v) end + local grad_params = {} + for k,v in pairs(g1) do table.insert(grad_params, v) end + + return params, grad_params +end + +function layer:training() + self.atten_encode:training() +end + +function layer:evaluate() + self.atten_encode:evaluate() +end + +function layer:updateOutput(input) + local out_feat = self.atten_encode:forward(input) + return out_feat +end + +function layer:updateGradInput(input, gradOutput) + self.gradInput = self.atten_encode:backward(input, gradOutput) + return self.gradInput +end diff --git a/misc/utils.lua b/misc/utils.lua new file mode 100644 index 0000000..0c8e55d --- /dev/null +++ b/misc/utils.lua @@ -0,0 +1,71 @@ +local cjson = require 'cjson' +local utils = {} +require 'nn' +-- Assume required if default_value is nil +function utils.getopt(opt, key, default_value) + if default_value == nil and (opt == nil or opt[key] == nil) then + error('error: required key ' .. key .. 
' was not provided in an opt.') + end + if opt == nil then return default_value end + + local v = opt[key] + if v == nil then v = default_value end + return v +end + +function utils.read_json(path) + local file = io.open(path, 'r') + local text = file:read() + file:close() + local info = cjson.decode(text) + return info +end + +function utils.write_json(path, j) + -- API reference http://www.kyne.com.au/~mark/software/lua-cjson-manual.html#encode + --cjson.encode_sparse_array(true, 2, 10) + local text = cjson.encode(j) + local file = io.open(path, 'w') + file:write(text) + file:close() +end + +function utils.right_align(seq, lengths) + -- right align the questions. + local v=seq:clone():fill(0) + local N=seq:size(2) + for i=1,seq:size(1) do + v[i][{{N-lengths[i]+1,N}}]=seq[i][{{1,lengths[i]}}] + end + return v +end + +function utils.normlize_image(imgFeat) + local length = imgFeat:size(2) + local nm=torch.sqrt(torch.sum(torch.cmul(imgFeat,imgFeat),2)) + return torch.cdiv(imgFeat,torch.repeatTensor(nm,1,length)):float() +end + +function utils.count_key(t) + local count = 1 + for i, w in pairs(t) do + count = count + 1 + end + return count +end + + +function utils.prepro(im, on_gpu) + assert(on_gpu ~= nil, 'pass this in. careful here.') + + im=im*255 + local im2=im:clone() + im2[{{},{3},{},{}}]=im[{{},{1},{},{}}]-123.68 + im2[{{},{2},{},{}}]=im[{{},{2},{},{}}]-116.779 + im2[{{},{1},{},{}}]=im[{{},{3},{},{}}]-103.939 + + return im2 +end + + +return utils \ No newline at end of file diff --git a/misc/word_level.lua b/misc/word_level.lua new file mode 100644 index 0000000..e2f275e --- /dev/null +++ b/misc/word_level.lua @@ -0,0 +1,115 @@ +require 'nn' +local utils = require 'misc.utils' +local LanguageEmbedding = require 'misc.LanguageEmbedding' +local attention = require 'misc.attention' + + +local layer, parent = torch.class('nn.word_level', 'nn.Module') +function layer:__init(opt) + parent.__init(self) + self.vocab_size = utils.getopt(opt, 'vocab_size') -- required + self.hidden_size = utils.getopt(opt, 'hidden_size') + local dropout = utils.getopt(opt, 'dropout', 0) + self.seq_length = utils.getopt(opt, 'seq_length') + self.atten_type = utils.getopt(opt, 'atten_type') + self.feature_type = utils.getopt(opt, 'feature_type') + self.LE = LanguageEmbedding.LE(self.vocab_size, self.hidden_size, self.hidden_size, self.seq_length) + + if self.atten_type == 'Alternating' then + self.atten = attention.alternating_atten(self.hidden_size, self.hidden_size, self.hidden_size, self.seq_length, 196) + elseif self.atten_type == 'Parallel' then + self.atten = attention.parallel_atten(self.hidden_size, self.hidden_size, self.hidden_size, self.seq_length, 196) + else + error('Must provide an valid attention type.') + end + + if self.feature_type == 'VGG' then + self.cnn = nn.Sequential() + :add(nn.View(512):setNumInputDims(2)) + :add(nn.Linear(512, self.hidden_size)) + :add(nn.View(-1, 196, self.hidden_size)) + :add(nn.Tanh()) + :add(nn.Dropout(0.5)) + elseif self.feature_type == 'Residual' then + self.cnn = nn.Sequential() + :add(nn.View(2048):setNumInputDims(2)) + :add(nn.Linear(2048, self.hidden_size)) + :add(nn.View(-1, 196, self.hidden_size)) + :add(nn.Tanh()) + :add(nn.Dropout(0.5)) + end + + self.mask = torch.Tensor() +end + + +function layer:getModulesList() + return {self.LE, self.atten, self.cnn} +end + +function layer:parameters() + local p1,g1 = self.cnn:parameters() + + local p2,g2 = self.LE:parameters() + local p3,g3 = self.atten:parameters() + + local params = {} + for k,v in pairs(p1) do 
table.insert(params, v) end + for k,v in pairs(p2) do table.insert(params, v) end + for k,v in pairs(p3) do table.insert(params, v) end + + local grad_params = {} + for k,v in pairs(g1) do table.insert(grad_params, v) end + for k,v in pairs(g2) do table.insert(grad_params, v) end + for k,v in pairs(g3) do table.insert(grad_params, v) end + + return params, grad_params +end + +function layer:training() + self.LE:training() + self.atten:training() + self.cnn:training() +end + +function layer:evaluate() + self.atten:evaluate() + self.LE:evaluate() + self.cnn:evaluate() +end + +function layer:updateOutput(input) + local seq = input[1] + local img = input[2] + + local batch_size = seq:size(1) + self.mask:resizeAs(seq):zero() + self.mask[torch.eq(seq, 0)] = 1 + + self.img_feat = self.cnn:forward(img) + + self.embed_output = self.LE:forward(seq) + local w_embed_ques, w_embed_img, ques_atten, img_atten = unpack(self.atten:forward({self.embed_output, self.img_feat, self.mask})) + + return {self.embed_output, self.img_feat, w_embed_ques, w_embed_img, self.mask, ques_atten, img_atten} +end + +function layer:updateGradInput(input, gradOutput) + local seq = input[1] + local img = input[2] + + local batch_size = seq:size(1) + + local d_embed_ques, d_embed_img, dummy = unpack(self.atten:backward({self.embed_output, self.img_feat, self.mask}, {gradOutput[2], gradOutput[3]})) + + d_embed_ques:add(gradOutput[1]) + + local dummy = self.LE:backward(seq, d_embed_ques) + + d_embed_img:add(gradOutput[4]) + d_embed_img:add(gradOutput[5]) + local d_imgfeat = self.cnn:backward(img, d_embed_img) + self.gradInput = d_imgfeat + + return self.gradInput +end diff --git a/predict.ipynb b/predict.ipynb new file mode 100644 index 0000000..6da564d --- /dev/null +++ b/predict.ipynb @@ -0,0 +1,423 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 55, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "image_model/VGG_ILSVRC_19_layers_deploy.prototxt\timage_model/VGG_ILSVRC_19_layers.caffemodel\tcudnn\t\n" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "Successfully loaded image_model/VGG_ILSVRC_19_layers.caffemodel\n" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "conv1_1: 64 3 3 3\n", + "conv1_2: 64 64 3 3\n" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "conv2_1: 128 64 3 3\n", + "conv2_2: 128 128 3 3\n", + "conv3_1: 256 128 3 3\n" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "conv3_2: 256 256 3 3\n" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "conv3_3: 256 256 3 3\n" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "conv3_4: 256 256 3 3\n", + "conv4_1: 512 256 3 3\n" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "conv4_2: 512 512 3 3\n" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "conv4_3: 512 512 3 3\n" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + 
"conv4_4: 512 512 3 3\n" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "conv5_1: 512 512 3 3\n" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "conv5_2: 512 512 3 3\n" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "conv5_3: 512 512 3 3\n" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "conv5_4: 512 512 3 3\n" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "fc6: 1 1 25088 4096\n" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "fc7: 1 1 4096 4096\n" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "fc8: 1 1 4096 1000\n" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "0.01 *\n", + "-4.1915\n", + "[torch.CudaTensor of size 1]\n", + "\n", + "Load the weight...\t\n" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + " 0.1113\n", + "[torch.CudaTensor of size 1]\n", + "\n", + "total number of parameters in cnn_model: \t20024384\t\n", + "total number of parameters in word_level: \t8031747\t\n", + "total number of parameters in phrase_level: \t2889219\t\n", + "total number of parameters in ques_level: \t5517315\t\n", + "constructing clones inside the ques_level\t\n" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "total number of parameters in recursive_attention: \t2862056\t\n" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "require 'nn'\n", + "require 'torch'\n", + "require 'optim'\n", + "require 'misc.DataLoader'\n", + "require 'misc.word_level'\n", + "require 'misc.phrase_level'\n", + "require 'misc.ques_level'\n", + "require 'misc.recursive_atten'\n", + "require 'misc.cnnModel'\n", + "require 'misc.optim_updates'\n", + "utils = require 'misc.utils'\n", + "require 'xlua'\n", + "require 'image'\n", + "\n", + "\n", + "opt = {}\n", + "\n", + "opt.vqa_model = 'model/vqa_model/model_alternating_train_vgg.t7'\n", + "opt.cnn_proto = 'image_model/VGG_ILSVRC_19_layers_deploy.prototxt'\n", + "opt.cnn_model = 'image_model/VGG_ILSVRC_19_layers.caffemodel'\n", + "opt.json_file = 'data/vqa_data_prepro.json'\n", + "opt.backend = 'cudnn'\n", + "opt.gpuid = 1\n", + "if opt.gpuid >= 0 then\n", + " require 'cutorch'\n", + " require 'cunn'\n", + " if opt.backend == 'cudnn' then \n", + " require 'cudnn' \n", + " end\n", + " --cutorch.setDevice(opt.gpuid+1) -- note +1 because lua is 1-indexed\n", + "end\n", + "\n", + "local loaded_checkpoint = torch.load(opt.vqa_model)\n", + "local lmOpt = loaded_checkpoint.lmOpt\n", + "\n", + "lmOpt.hidden_size = 512\n", + "lmOpt.feature_type = 'VGG'\n", + "lmOpt.atten_type = 'Alternating'\n", + "cnnOpt = {}\n", + "cnnOpt.cnn_proto = opt.cnn_proto\n", + "cnnOpt.cnn_model = opt.cnn_model\n", + "cnnOpt.backend = opt.backend\n", + "cnnOpt.input_size_image = 512\n", + "cnnOpt.output_size = 512\n", + "cnnOpt.h = 14\n", + "cnnOpt.w = 14\n", + 
"cnnOpt.layer_num = 37\n", + "\n", + "-- load the vocabulary and answers.\n", + "\n", + "local json_file = utils.read_json(opt.json_file)\n", + "ix_to_word = json_file.ix_to_word\n", + "ix_to_ans = json_file.ix_to_ans\n", + "\n", + "word_to_ix = {}\n", + "for ix, word in pairs(ix_to_word) do\n", + " word_to_ix[word]=ix\n", + "end\n", + "\n", + "-- load the model\n", + "protos = {}\n", + "protos.word = nn.word_level(lmOpt)\n", + "protos.phrase = nn.phrase_level(lmOpt)\n", + "protos.ques = nn.ques_level(lmOpt)\n", + "\n", + "protos.atten = nn.recursive_atten()\n", + "protos.crit = nn.CrossEntropyCriterion()\n", + "protos.cnn = nn.cnnModel(cnnOpt)\n", + "\n", + "if opt.gpuid >= 0 then\n", + " for k,v in pairs(protos) do v:cuda() end\n", + "end\n", + "\n", + "cparams, grad_cparams = protos.cnn:getParameters()\n", + "wparams, grad_wparams = protos.word:getParameters()\n", + "pparams, grad_pparams = protos.phrase:getParameters()\n", + "qparams, grad_qparams = protos.ques:getParameters()\n", + "aparams, grad_aparams = protos.atten:getParameters()\n", + "\n", + "print(wparams:sub(1,1))\n", + "print('Load the weight...')\n", + "wparams:copy(loaded_checkpoint.wparams)\n", + "pparams:copy(loaded_checkpoint.pparams)\n", + "qparams:copy(loaded_checkpoint.qparams)\n", + "aparams:copy(loaded_checkpoint.aparams)\n", + "print(pparams:sub(1,10))\n", + "\n", + "print('total number of parameters in cnn_model: ', cparams:nElement())\n", + "assert(cparams:nElement() == grad_cparams:nElement())\n", + "\n", + "print('total number of parameters in word_level: ', wparams:nElement())\n", + "assert(wparams:nElement() == grad_wparams:nElement())\n", + "\n", + "print('total number of parameters in phrase_level: ', pparams:nElement())\n", + "assert(pparams:nElement() == grad_pparams:nElement())\n", + "\n", + "print('total number of parameters in ques_level: ', qparams:nElement())\n", + "assert(qparams:nElement() == grad_qparams:nElement())\n", + "protos.ques:shareClones()\n", + "\n", + "print('total number of parameters in recursive_attention: ', aparams:nElement())\n", + "assert(aparams:nElement() == grad_aparams:nElement())" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "ename": "[string \"-- specify the image and the question....\"]:46: ')' expected near 'ans'", + "evalue": "", + "output_type": "error", + "traceback": [ + "[string \"-- specify the image and the question....\"]:46: ')' expected near 'ans'" + ] + } + ], + "source": [ + "-- specify the image and the question.\n", + "local img_path = 'visu/demo_img1.jpg'\n", + "local question = 'what is the color of the hat ?'\n", + "\n", + "-- load the image\n", + "local img = image.load(img_path)\n", + "-- scale the image\n", + "img = image.scale(img,448,448)\n", + "itorch.image(img)\n", + "img = img:view(1,img:size(1),img:size(2),img:size(3))\n", + "-- parse and encode the question (in a simple way).\n", + "local ques_encode = torch.IntTensor(26):zero()\n", + "\n", + "local count = 1\n", + "for word in string.gmatch(question, \"%S+\") do\n", + " ques_encode[count] = word_to_ix[word] or word_to_ix['UNK']\n", + " count = count + 1\n", + "end\n", + "ques_encode = ques_encode:view(1,ques_encode:size(1))\n", + "-- doing the prediction\n", + "\n", + "protos.word:evaluate()\n", + "protos.phrase:evaluate()\n", + "protos.ques:evaluate()\n", + "protos.atten:evaluate()\n", + "protos.cnn:evaluate()\n", + "\n", + "local image_raw = utils.prepro(img, false)\n", + "image_raw = image_raw:cuda()\n", + 
"ques_encode = ques_encode:cuda()\n", + "\n", + "local image_feat = protos.cnn:forward(image_raw)\n", + "local ques_len = torch.Tensor(1,1):cuda()\n", + "ques_len[1] = count-1\n", + "\n", + "local word_feat, img_feat, w_ques, w_img, mask = unpack(protos.word:forward({ques_encode, image_feat}))\n", + "local conv_feat, p_ques, p_img = unpack(protos.phrase:forward({word_feat, ques_len, img_feat, mask}))\n", + "local q_ques, q_img = unpack(protos.ques:forward({conv_feat, ques_len, img_feat, mask}))\n", + "\n", + "local feature_ensemble = {w_ques, w_img, p_ques, p_img, q_ques, q_img}\n", + "local out_feat = protos.atten:forward(feature_ensemble)\n", + "\n", + "local tmp,pred=torch.max(out_feat,2)\n", + "local ans = ix_to_ans[tostring(pred[1][1])]\n", + "\n", + "print('The answer is:' ans)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "-- Attention Visualization\n", + "\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "iTorch", + "language": "lua", + "name": "itorch" + }, + "language_info": { + "name": "lua", + "version": "5.1" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/prepro/prepro_cocoqa.py b/prepro/prepro_cocoqa.py index 26b8d46..b727b7c 100755 --- a/prepro/prepro_cocoqa.py +++ b/prepro/prepro_cocoqa.py @@ -240,8 +240,8 @@ def main(params): parser = argparse.ArgumentParser() # input json - parser.add_argument('--input_train_json', default='../data/coco_qa_raw_train.json', help='input json file to process into hdf5') - parser.add_argument('--input_test_json', default='../data/coco_qa_raw_test.json', help='input json file to process into hdf5') + parser.add_argument('--input_train_json', default='../data/cocoqa_raw_train.json', help='input json file to process into hdf5') + parser.add_argument('--input_test_json', default='../data/cocoqa_raw_test.json', help='input json file to process into hdf5') parser.add_argument('--output_json', default='../data/cocoqa_data_prepro.json', help='output json file') parser.add_argument('--output_h5', default='../data/cocoqa_data_prepro.h5', help='output h5 file') diff --git a/prepro/prepro_img_residule.lua b/prepro/prepro_img_residule.lua index 2469932..58e05e6 100755 --- a/prepro/prepro_img_residule.lua +++ b/prepro/prepro_img_residule.lua @@ -10,7 +10,7 @@ require 'hdf5' cjson=require('cjson') require 'xlua' require 'cudnn' -local t = require 'misc/transforms' +local t = require 'image_model.transforms' ------------------------------------------------------------------------------- -- Input arguments and options @@ -18,8 +18,8 @@ local t = require 'misc/transforms' cmd = torch.CmdLine() cmd:text() cmd:text('Options') -cmd:option('-input_json','','path to the json file containing vocab and answers') -cmd:option('-image_root','','path to the image root') +cmd:option('-input_json','../data/vqa_data_prepro.json','path to the json file containing vocab and answers') +cmd:option('-image_root','/home/jiasenlu/data/','path to the image root') cmd:option('-residule_path', '') cmd:option('-batch_size', 10, 'batch_size') diff --git a/prepro/prepro_img_vgg.lua b/prepro/prepro_img_vgg.lua index 96c8fe8..e0d9b27 100755 --- a/prepro/prepro_img_vgg.lua +++ b/prepro/prepro_img_vgg.lua @@ -16,17 +16,15 @@ require 'xlua' cmd = torch.CmdLine() cmd:text() cmd:text('Options') -cmd:option('-input_json','data/cocoqa_data_prepro.json','path to the json file containing vocab and answers') 
+cmd:option('-input_json','../data/vqa_data_prepro.json','path to the json file containing vocab and answers') cmd:option('-image_root','/home/jiasenlu/data/','path to the image root') -cmd:option('-cnn_proto', '/home/jiasenlu/code/vqa_code/vqaSoA/CNN_model/VGG_ILSVRC_19_layers_deploy.prototxt', 'path to the cnn prototxt') -cmd:option('-cnn_model', '/home/jiasenlu/code/vqa_code/vqaSoA/CNN_model/VGG_ILSVRC_19_layers.caffemodel', 'path to the cnn model') ---cmd:option('-cnn_proto', '/home/jiasenlu/code/vqa_code/vqaSoA/CNN_model/ResNet-101-deploy.prototxt', 'path to the cnn prototxt') ---cmd:option('-cnn_model', '/home/jiasenlu/code/vqa_code/vqaSoA/CNN_model/ResNet-101-model.caffemodel', 'path to the cnn model') +cmd:option('-cnn_proto', '../image_model/VGG_ILSVRC_19_layers_deploy.prototxt', 'path to the cnn prototxt') +cmd:option('-cnn_model', '../image_model/VGG_ILSVRC_19_layers.caffemodel', 'path to the cnn model') cmd:option('-batch_size', 20, 'batch_size') -cmd:option('-out_name_train', 'data/cocoqa_data_img_vgg_train.h5', 'output name') -cmd:option('-out_name_test', 'data/cocoqa_data_img_vgg_test.h5', 'output name') +cmd:option('-out_name_train', '../data/vqa_data_img_vgg_train.h5', 'output name train') +cmd:option('-out_name_test', '../data/vqa_data_img_vgg_test.h5', 'output name test') cmd:option('-gpuid', 6, 'which gpu to use. -1 = use CPU') cmd:option('-backend', 'cudnn', 'nn|cudnn') diff --git a/prepro/prepro_vqa.py b/prepro/prepro_vqa.py index 6aeeaae..fec5859 100755 --- a/prepro/prepro_vqa.py +++ b/prepro/prepro_vqa.py @@ -182,8 +182,8 @@ def main(params): imgs_train = json.load(open(params['input_train_json'], 'r')) imgs_test = json.load(open(params['input_test_json'], 'r')) - #imgs_train = imgs_train[:5000] - #imgs_test = imgs_test[:5000] + imgs_train = imgs_train[:5000] + imgs_test = imgs_test[:5000] # get top answers top_ans = get_top_answers(imgs_train, params) atoi = {w:i+1 for i,w in enumerate(top_ans)} @@ -214,7 +214,7 @@ def main(params): # get the answer encoding. ans_train = encode_answer(imgs_train, atoi) - #ans_test = encode_answer(imgs_test, atoi) + ans_test = encode_answer(imgs_test, atoi) MC_ans_test = encode_mc_answer(imgs_test, atoi) # get the split @@ -234,7 +234,7 @@ def main(params): f.create_dataset("ques_test", dtype='uint32', data=ques_test) f.create_dataset("answers", dtype='uint32', data=ans_train) - #f.create_dataset("ans_test", dtype='uint32', data=ans_test) + f.create_dataset("ans_test", dtype='uint32', data=ans_test) f.create_dataset("ques_id_train", dtype='uint32', data=question_id_train) f.create_dataset("ques_id_test", dtype='uint32', data=question_id_test) @@ -277,8 +277,8 @@ def main(params): parser.add_argument('--input_test_json', default='../data/vqa_raw_test.json', help='input json file to process into hdf5') parser.add_argument('--num_ans', default=1000, type=int, help='number of top answers for the final classifications.') - parser.add_argument('--output_json', default='../data_prepro_all.json', help='output json file') - parser.add_argument('--output_h5', default='../data_prepro_all.h5', help='output h5 file') + parser.add_argument('--output_json', default='../data/vqa_data_prepro.json', help='output json file') + parser.add_argument('--output_h5', default='../data/vqa_data_prepro.h5', help='output h5 file') # options parser.add_argument('--max_length', default=26, type=int, help='max length of a caption, in number of words. 
captions longer than this get clipped.') diff --git a/train.lua b/train.lua new file mode 100644 index 0000000..281de26 --- /dev/null +++ b/train.lua @@ -0,0 +1,362 @@ +------------------------------------------------------------------------------ +-- Hierarchical Question-Image Co-Attention for Visual Question Answering +-- J. Lu, J. Yang, D. Batra, and D. Parikh +-- https://arxiv.org/abs/1606.00061, 2016 +-- if you have any question about the code, please contact jiasenlu@vt.edu +----------------------------------------------------------------------------- + +require 'nn' +require 'torch' +require 'optim' +require 'misc.DataLoaderDisk' +require 'misc.word_level' +require 'misc.phrase_level' +require 'misc.ques_level' +require 'misc.recursive_atten' +require 'misc.optim_updates' +local utils = require 'misc.utils' +require 'xlua' + +------------------------------------------------------------------------------- +-- Input arguments and options +------------------------------------------------------------------------------- + +cmd = torch.CmdLine() +cmd:text() +cmd:text('Train a Visual Question Answering model') +cmd:text() +cmd:text('Options') + +-- Data input settings +cmd:option('-input_img_train_h5','data/vqa_data_img_vgg_train.h5','path to the h5file containing the image feature') +cmd:option('-input_img_test_h5','data/vqa_data_img_vgg_test.h5','path to the h5file containing the image feature') +cmd:option('-input_ques_h5','data/vqa_data_prepro.h5','path to the h5file containing the preprocessed dataset') +cmd:option('-input_json','data/vqa_data_prepro.json','path to the json file containing additional info and vocab') + +cmd:option('-start_from', '', 'path to a model checkpoint to initialize model weights from. Empty = don\'t') +cmd:option('-co_atten_type', 'Alternating', 'co_attention type. Parallel or Alternating, alternating trains more faster than parallel.') +cmd:option('-feature_type', 'VGG', 'VGG or Residual') + + +cmd:option('-hidden_size',512,'the hidden layer size of the model.') +cmd:option('-rnn_size',512,'size of the rnn in number of hidden nodes in each layer') +cmd:option('-batch_size',200,'what is theutils batch size in number of images per batch? (there will be x seq_per_img sentences)') +cmd:option('-output_size', 1000, 'number of output answers') +cmd:option('-rnn_layers',2,'number of the rnn layer') + + +-- Optimization +cmd:option('-optim','rmsprop','what update to use? rmsprop|sgd|sgdmom|adagrad|adam') +cmd:option('-learning_rate',4e-4,'learning rate') +cmd:option('-learning_rate_decay_start', 0, 'at what iteration to start decaying learning rate? (-1 = dont)') +cmd:option('-learning_rate_decay_every', 300, 'every how many epoch thereafter to drop LR by 0.1?') +cmd:option('-optim_alpha',0.99,'alpha for adagrad/rmsprop/momentum/adam') +cmd:option('-optim_beta',0.995,'beta used for adam') +cmd:option('-optim_epsilon',1e-8,'epsilon that goes into denominator in rmsprop') +cmd:option('-max_iters', -1, 'max number of iterations to run for (-1 = run forever)') +cmd:option('-iterPerEpoch', 1200) + +-- Evaluation/Checkpointing +cmd:option('-save_checkpoint_every', 6000, 'how often to save a model checkpoint?') +cmd:option('-checkpoint_path', 'save/train_vgg', 'folder to save checkpoints into (empty = this folder)') + +-- Visualization +cmd:option('-losses_log_every', 600, 'How often do we save losses, for inclusion in the progress dump? (0 = disable)') + +-- misc +cmd:option('-id', '0', 'an id identifying this run/job. 
used in cross-val and appended when writing progress files') +cmd:option('-backend', 'cudnn', 'nn|cudnn') +cmd:option('-gpuid', 6, 'which gpu to use. -1 = use CPU') +cmd:option('-seed', 123, 'random number generator seed to use') + +cmd:text() + +------------------------------------------------------------------------------- +-- Basic Torch initializations +------------------------------------------------------------------------------- +local opt = cmd:parse(arg) +torch.manualSeed(opt.seed) +print(opt) +torch.setdefaulttensortype('torch.FloatTensor') -- for CPU + +if opt.gpuid >= 0 then + require 'cutorch' + require 'cunn' + if opt.backend == 'cudnn' then + require 'cudnn' + end + --cutorch.manualSeed(opt.seed) + --cutorch.setDevice(opt.gpuid+1) -- note +1 because lua is 1-indexed +end + +opt = cmd:parse(arg) + +------------------------------------------------------------------------------- +-- Create the Data Loader instance +------------------------------------------------------------------------------- +local loader = DataLoader{h5_img_file_train = opt.input_img_train_h5, h5_img_file_test = opt.input_img_test_h5, h5_ques_file = opt.input_ques_h5, json_file = opt.input_json, feature_type = opt.feature_type} +------------------------------------------------------------------------ +--Design Parameters and Network Definitions +------------------------------------------------------------------------ +local protos = {} +print('Building the model...') +-- intialize language model +local loaded_checkpoint +local lmOpt +if string.len(opt.start_from) > 0 then + local start_path = path.join(opt.checkpoint_path .. '_' .. opt.co_atten_type , opt.start_from) + loaded_checkpoint = torch.load(start_path) + lmOpt = loaded_checkpoint.lmOpt +else + lmOpt = {} + lmOpt.vocab_size = loader:getVocabSize() + lmOpt.hidden_size = opt.hidden_size + lmOpt.rnn_size = opt.rnn_size + lmOpt.num_layers = opt.rnn_layers + lmOpt.dropout = 0.5 + lmOpt.seq_length = loader:getSeqLength() + lmOpt.batch_size = opt.batch_size + lmOpt.output_size = opt.rnn_size + lmOpt.atten_type = opt.co_atten_type + lmOpt.feature_type = opt.feature_type +end + +protos.word = nn.word_level(lmOpt) +protos.phrase = nn.phrase_level(lmOpt) +protos.ques = nn.ques_level(lmOpt) + +protos.atten = nn.recursive_atten() +protos.crit = nn.CrossEntropyCriterion() +-- ship everything to GPU, maybe + +if opt.gpuid >= 0 then + for k,v in pairs(protos) do v:cuda() end +end + +local wparams, grad_wparams = protos.word:getParameters() +local pparams, grad_pparams = protos.phrase:getParameters() +local qparams, grad_qparams = protos.ques:getParameters() +local aparams, grad_aparams = protos.atten:getParameters() + + +if string.len(opt.start_from) > 0 then + print('Load the weight...') + wparams:copy(loaded_checkpoint.wparams) + pparams:copy(loaded_checkpoint.pparams) + qparams:copy(loaded_checkpoint.qparams) + aparams:copy(loaded_checkpoint.aparams) +end + +print('total number of parameters in word_level: ', wparams:nElement()) +assert(wparams:nElement() == grad_wparams:nElement()) + +print('total number of parameters in phrase_level: ', pparams:nElement()) +assert(pparams:nElement() == grad_pparams:nElement()) + +print('total number of parameters in ques_level: ', qparams:nElement()) +assert(qparams:nElement() == grad_qparams:nElement()) +protos.ques:shareClones() + +print('total number of parameters in recursive_attention: ', aparams:nElement()) +assert(aparams:nElement() == grad_aparams:nElement()) + +collectgarbage() + 
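+-- Note: getParameters() has flattened each sub-module's weights into a single tensor
+-- (wparams/pparams/qparams/aparams); the hand-rolled rmsprop in misc/optim_updates.lua
+-- (the only optimizer wired into the update step below) modifies these flat tensors in
+-- place, keeping its running averages in the w/p/q/a_optim_state tables created before
+-- the main loop.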
+------------------------------------------------------------------------------- +-- Validation evaluation +------------------------------------------------------------------------------- +local function eval_split(split) + + protos.word:evaluate() + protos.phrase:evaluate() + protos.ques:evaluate() + protos.atten:evaluate() + loader:resetIterator(split) + + local n = 0 + local loss_sum = 0 + local loss_evals = 0 + local right_sum = 0 + local predictions = {} + local total_num = loader:getDataNum(split) + while true do + local data = loader:getBatch{batch_size = opt.batch_size, split = split} + -- ship the data to cuda + if opt.gpuid >= 0 then + data.answer = data.answer:cuda() + data.images = data.images:cuda() + data.questions = data.questions:cuda() + data.ques_len = data.ques_len:cuda() + end + n = n + data.images:size(1) + xlua.progress(n, total_num) + + local word_feat, img_feat, w_ques, w_img, mask = unpack(protos.word:forward({data.questions, data.images})) + + local conv_feat, p_ques, p_img = unpack(protos.phrase:forward({word_feat, data.ques_len, img_feat, mask})) + + local q_ques, q_img = unpack(protos.ques:forward({conv_feat, data.ques_len, img_feat, mask})) + + local feature_ensemble = {w_ques, w_img, p_ques, p_img, q_ques, q_img} + local out_feat = protos.atten:forward(feature_ensemble) + + -- forward the language model criterion + local loss = protos.crit:forward(out_feat, data.answer) + + local tmp,pred=torch.max(out_feat,2) + + for i = 1, pred:size()[1] do + + if pred[i][1] == data.answer[i] then + right_sum = right_sum + 1 + end + end + + loss_sum = loss_sum + loss + loss_evals = loss_evals + 1 + if n >= total_num then break end + end + + return loss_sum/loss_evals, right_sum / total_num +end + + +------------------------------------------------------------------------------- +-- Loss function +------------------------------------------------------------------------------- +local iter = 0 +local function lossFun() + protos.word:training() + grad_wparams:zero() + + protos.phrase:training() + grad_pparams:zero() + + protos.ques:training() + grad_qparams:zero() + + protos.atten:training() + grad_aparams:zero() + + ---------------------------------------------------------------------------- + -- Forward pass + ----------------------------------------------------------------------------- + -- get batch of data + local data = loader:getBatch{batch_size = opt.batch_size, split = 0} + if opt.gpuid >= 0 then + data.answer = data.answer:cuda() + data.questions = data.questions:cuda() + data.ques_len = data.ques_len:cuda() + data.images = data.images:cuda() + end + + local word_feat, img_feat, w_ques, w_img, mask = unpack(protos.word:forward({data.questions, data.images})) + + local conv_feat, p_ques, p_img = unpack(protos.phrase:forward({word_feat, data.ques_len, img_feat, mask})) + + local q_ques, q_img = unpack(protos.ques:forward({conv_feat, data.ques_len, img_feat, mask})) + + local feature_ensemble = {w_ques, w_img, p_ques, p_img, q_ques, q_img} + local out_feat = protos.atten:forward(feature_ensemble) + + -- forward the language model criterion + local loss = protos.crit:forward(out_feat, data.answer) + ----------------------------------------------------------------------------- + -- Backward pass + ----------------------------------------------------------------------------- + -- backprop criterion + local dlogprobs = protos.crit:backward(out_feat, data.answer) + + local d_w_ques, d_w_img, d_p_ques, d_p_img, d_q_ques, d_q_img = unpack(protos.atten:backward(feature_ensemble, 
dlogprobs)) + + local d_ques_feat, d_ques_img = unpack(protos.ques:backward({conv_feat, data.ques_len, img_feat}, {d_q_ques, d_q_img})) + + --local d_ques1 = protos.bl1:backward({ques_feat_0, data.ques_len}, d_ques2) + local d_conv_feat, d_conv_img = unpack(protos.phrase:backward({word_feat, data.ques_len, img_feat}, {d_ques_feat, d_p_ques, d_p_img})) + + local dummy = protos.word:backward({data.questions, data.images}, {d_conv_feat, d_w_ques, d_w_img, d_conv_img, d_ques_img}) + + ----------------------------------------------------------------------------- + -- and lets get out! + local stats = {} + stats.dt = dt + local losses = {} + losses.total_loss = loss + return losses, stats +end + +------------------------------------------------------------------------------- +-- Main loop +------------------------------------------------------------------------------- + +local loss0 +local w_optim_state = {} +local p_optim_state = {} +local q_optim_state = {} +local a_optim_state = {} +local loss_history = {} +local accuracy_history = {} +local learning_rate_history = {} +local best_val_loss = 10000 +local ave_loss = 0 +local timer = torch.Timer() +local decay_factor = math.exp(math.log(0.1)/opt.learning_rate_decay_every/opt.iterPerEpoch) +local learning_rate = opt.learning_rate +-- create the path to save the model. +paths.mkdir(opt.checkpoint_path .. '_' .. opt.co_atten_type) + +while true do + -- eval loss/gradient + local losses, stats = lossFun() + ave_loss = ave_loss + losses.total_loss + -- decay the learning rate + if iter > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0 then + learning_rate = learning_rate * decay_factor -- set the decayed rate + end + + if iter % opt.losses_log_every == 0 then + ave_loss = ave_loss / opt.losses_log_every + loss_history[iter] = losses.total_loss + accuracy_history[iter] = ave_loss + learning_rate_history[iter] = learning_rate + + print(string.format('iter %d: %f, %f, %f, %f', iter, losses.total_loss, ave_loss, learning_rate, timer:time().real)) + + ave_loss = 0 + end + + -- save checkpoint once in a while (or on final iteration) + if (iter % opt.save_checkpoint_every == 0 or iter == opt.max_iters) then + local val_loss, val_accu = eval_split(2) + print('validation loss: ', val_loss, 'accuracy ', val_accu) + + local checkpoint_path = path.join(opt.checkpoint_path .. '_' .. opt.co_atten_type, 'model_id' .. opt.id .. '_iter'.. iter) + torch.save(checkpoint_path..'.t7', {wparams=wparams, pparams = pparams, qparams=qparams, aparams=aparams, lmOpt=lmOpt}) + + local checkpoint = {} + checkpoint.opt = opt + checkpoint.iter = iter + checkpoint.loss_history = loss_history + checkpoint.accuracy_history = accuracy_history + checkpoint.learning_rate_history = learning_rate_history + + local checkpoint_path = path.join(opt.checkpoint_path .. '_' .. opt.co_atten_type, 'checkpoint' .. '.json') + + utils.write_json(checkpoint_path, checkpoint) + print('wrote json checkpoint to ' .. checkpoint_path .. 
'.json') + + end + + -- perform a parameter update + if opt.optim == 'rmsprop' then + rmsprop(wparams, grad_wparams, learning_rate, opt.optim_alpha, opt.optim_epsilon, w_optim_state) + rmsprop(pparams, grad_pparams, learning_rate, opt.optim_alpha, opt.optim_epsilon, p_optim_state) + rmsprop(qparams, grad_qparams, learning_rate, opt.optim_alpha, opt.optim_epsilon, q_optim_state) + rmsprop(aparams, grad_aparams, learning_rate, opt.optim_alpha, opt.optim_epsilon, a_optim_state) + else + error('bad option opt.optim') + end + + iter = iter + 1 + if opt.max_iters > 0 and iter >= opt.max_iters then break end -- stopping criterion +end diff --git a/vis/demo_img1.jpg b/vis/demo_img1.jpg new file mode 100644 index 0000000..0ce087e Binary files /dev/null and b/vis/demo_img1.jpg differ