Commit

Showing 26 changed files with 2,774 additions and 25 deletions.

# Hierarchical Co-Attention for Visual Question Answering

Train a Hierarchical Co-Attention model for Visual Question Answering. The current code gets 62.1 on Open-Ended and 66.1 on Multiple-Choice on the test-standard split; you can check the CodaLab leaderboard for more details. On COCO-QA, this code gets 65.4 accuracy. For more information, please refer to the paper: [https://arxiv.org/abs/1606.00061](https://arxiv.org/abs/1606.00061)

![teaser results](https://raw.github.com/jiasenlu/HieCoAttenVQA/master/vis/demo.png)

### Requirements

This code is written in Lua and requires [Torch](http://torch.ch/). The preprocessing code is in Python, and you need to install [NLTK](http://www.nltk.org/) if you want to use NLTK to tokenize the questions.

You also need to install the following packages in order to run the code successfully (a typical install sequence is sketched after the list):

- [cudnn.torch](https://github.com/soumith/cudnn.torch)
- [torch-hdf5](https://github.com/deepmind/torch-hdf5)
- [lua-cjson](http://www.kyne.com.au/~mark/software/lua-cjson.php)
- [loadcaffe](https://github.com/szagoruyko/loadcaffe)
- [iTorch](https://github.com/facebook/iTorch)
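
Assuming Torch itself is already set up, the Lua packages can usually be pulled in via luarocks; the sequence below is a sketch (torch-hdf5 is built from its repository, and iTorch additionally needs ZeroMQ and an IPython/Jupyter setup, so follow its README):

```
$ luarocks install cudnn
$ luarocks install lua-cjson
$ luarocks install loadcaffe
$ git clone https://github.com/deepmind/torch-hdf5
$ cd torch-hdf5 && luarocks make hdf5-0-0.rockspec
```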

### Training

We have prepared everything for you ;)

##### Download Dataset

The first thing you need to do is download the data and do some preprocessing. Head over to the `data/` folder and run

For **VQA**:

```
$ python vqa_preprocessing.py --download True --split 1
```

`--download True` means you choose to download the VQA data from the [VQA website](http://www.visualqa.org/), and `--split 1` means you use the COCO train set for training and the validation set for evaluation. `--split 2` means you use the COCO train+val set for training and the test set for evaluation. After this step, two files are generated under the `data/` folder: `vqa_raw_train.json` and `vqa_raw_test.json`.

For **COCO-QA**:

```
$ python vqa_preprocessing.py --download True
```

This will download the COCO-QA dataset from [here](http://www.cs.toronto.edu/~mren/imageqa/data/cocoqa/) and generate two files under the `data/` folder: `cocoqa_raw_train.json` and `cocoqa_raw_test.json`.

##### Download Image Model

Here we use the VGG_ILSVRC_19_layers [model](https://gist.github.com/ksimonyan/3785162f95cd2d5fee77) and the Deep Residual network [model](https://github.com/facebook/fb.resnet.torch) implemented by Facebook.

Head over to the `image_model` folder and run

```
$ python download_model.py --download 'VGG'
```

This will download the VGG_ILSVRC_19_layers model into the `image_model` folder. To download the Deep Residual model, change `VGG` to `Residual`.
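
For example, swapping the flag value:

```
$ python download_model.py --download 'Residual'
```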

##### Generate Image/Question Features

Head over to the `prepro` folder and run

For **VQA**:

```
$ python prepro_vqa.py --input_train_json ../data/vqa_raw_train.json --input_test_json ../data/vqa_raw_test.json --num_ans 1000
```

to get the question features. `--num_ans` specifies how many top answers you want to use during training. You will also see some question and answer statistics in the terminal output. This will generate two files in the `data/` folder: `vqa_data_prepro.h5` and `vqa_data_prepro.json`.

For **COCO-QA**:

```
$ python prepro_cocoqa.py --input_train_json ../data/cocoqa_raw_train.json --input_test_json ../data/cocoqa_raw_test.json
```

COCO-QA uses all the answers in train, so there is no `--num_ans` option. This will generate two files in the `data/` folder: `cocoqa_data_prepro.h5` and `cocoqa_data_prepro.json`.

Then we are ready to extract the image features.

For **VGG** image features:

```
$ th prepro_img_vgg.lua -input_json ../data/vqa_data_prepro.json -image_root /home/jiasenlu/data/ -cnn_proto ../image_model/VGG_ILSVRC_19_layers_deploy.prototxt -cnn_model ../image_model/VGG_ILSVRC_19_layers.caffemodel
```

You can change `-gpuid`, `-backend` and `-batch_size` based on your GPU.

For **Deep Residual** image features: coming soon!

##### Train the model

We now have everything ready to train the VQA and COCO-QA models. Back in the `main` folder, run

```
th train.lua -input_img_train_h5 data/vqa_data_img_vgg_train.h5 -input_img_test_h5 data/vqa_data_img_vgg_test.h5 -input_ques_h5 data/vqa_data_prepro.h5 -input_json data/vqa_data_prepro.json -co_atten_type Alternating -feature_type VGG
```

to train the **Alternating co-attention** model on VQA using VGG image features. You can train the **Parallel co-attention** model by setting `-co_atten_type Parallel`; parallel co-attention usually takes more time to train than alternating co-attention.
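
For example, the Parallel variant on the same data:

```
th train.lua -input_img_train_h5 data/vqa_data_img_vgg_train.h5 -input_img_test_h5 data/vqa_data_img_vgg_test.h5 -input_ques_h5 data/vqa_data_prepro.h5 -input_json data/vqa_data_prepro.json -co_atten_type Parallel -feature_type VGG
```

A COCO-QA run would point at the COCO-QA files produced above (the image-feature file names here are an assumption, mirroring the VQA naming):

```
th train.lua -input_img_train_h5 data/cocoqa_data_img_vgg_train.h5 -input_img_test_h5 data/cocoqa_data_img_vgg_test.h5 -input_ques_h5 data/cocoqa_data_prepro.h5 -input_json data/cocoqa_data_prepro.json -co_atten_type Alternating -feature_type VGG
```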

##### Note

- The Deep Residual image features are 4 times larger than the VGG features; make sure you have enough RAM when you extract or load them.
- If you don't have enough RAM, replace `require 'misc.DataLoader'` (line 11 in `train.lua`) with `require 'misc.DataLoaderDisk'`, as sketched below. The model will then read the data directly from the hard disk (SSD preferred).
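
Concretely, the swap at the top of `train.lua` is a one-line change:

```lua
-- require 'misc.DataLoader'    -- in-memory loader: needs RAM for all features
require 'misc.DataLoaderDisk'   -- streams batches from the hard disk instead (SSD preferred)
```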

### Evaluation

##### Evaluate using a Pre-trained Model

The pre-trained models can be downloaded [here](https://filebox.ece.vt.edu/~jiasenlu/codeRelease/co_atten/model/).
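
The evaluation script below reads its data paths and checkpoint from command-line options (with the defaults shown in its option list). Assuming it is saved as `eval.lua` (the script name is an assumption), a typical run is:

```
th eval.lua -start_from model/vqa_model/model_alternating_train_vgg.t7 -co_atten_type Alternating -feature_type VGG
```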

##### Metric

To evaluate VQA, you need to download the VQA [evaluation tool](https://github.com/VT-vision-lab/VQA). To evaluate COCO-QA, you can use the script `evaluate_cocoqa.py` under the `metric/` folder. If you need to evaluate based on WUPS, download the evaluation script from [here](http://datasets.d2.mpi-inf.mpg.de/mateusz14visual-turing/calculate_wups.py).
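
For reference, the official VQA tool scores an open-ended answer as min(#annotators who gave that answer / 3, 1), averaged over annotator subsets; a minimal Lua sketch of the per-question score (helper name hypothetical):

```lua
-- Per-question open-ended VQA score: an answer counts as fully correct
-- when at least 3 of the 10 human annotators gave it.
local function vqa_score(pred, human_answers)
  local matches = 0
  for _, a in ipairs(human_answers) do
    if a == pred then matches = matches + 1 end
  end
  return math.min(matches / 3, 1)
end
```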

##### VQA on a Single Image with a Free-Form Question

We use iTorch to demo visual question answering with the pre-trained model. In the root folder, open `itorch notebook`; you can then load any image and ask questions using the iTorch notebook.

##### Attention Visualization

### Reference

If you use this code as part of any published research, please acknowledge the following paper:

```
@misc{Lu2016Hie,
  author = {Lu, Jiasen and Yang, Jianwei and Batra, Dhruv and Parikh, Devi},
  title = {Hierarchical Question-Image Co-Attention for Visual Question Answering},
  journal = {arXiv preprint arXiv:1606.00061v2},
  year = {2016}
}
```

### Attention Demo

![teaser results](https://raw.github.com/jiasenlu/HieCoAttenVQA/master/vis/demo.png)
require 'nn'
require 'torch'
require 'optim'
require 'misc.DataLoader'
require 'misc.word_level'
require 'misc.phrase_level'
require 'misc.ques_level'
require 'misc.recursive_atten'
require 'misc.optim_updates'
local utils = require 'misc.utils'
require 'xlua'

cmd = torch.CmdLine()
cmd:text()
cmd:text('evaluate a Visual Question Answering model')
cmd:text()
cmd:text('Options')

-- Data input settings
cmd:option('-input_img_train_h5','data/vqa_data_img_vgg_train.h5','path to the h5 file containing the train image features')
cmd:option('-input_img_test_h5','data/vqa_data_img_vgg_test.h5','path to the h5 file containing the test image features')
cmd:option('-input_ques_h5','data/vqa_data_prepro.h5','path to the h5 file containing the preprocessed dataset')
cmd:option('-input_json','data/vqa_data_prepro.json','path to the json file containing additional info and vocab')

cmd:option('-start_from', 'model/vqa_model/model_alternating_train_vgg.t7', 'path to a model checkpoint to initialize model weights from. Empty = don\'t')
cmd:option('-co_atten_type', 'Alternating', 'co-attention type: Parallel or Alternating. Alternating trains faster than Parallel.')
cmd:option('-feature_type', 'VGG', 'VGG or Residual')

-- misc
cmd:option('-backend', 'cudnn', 'nn|cudnn')
cmd:option('-gpuid', 2, 'which gpu to use. -1 = use CPU')
cmd:option('-seed', 123, 'random number generator seed to use')

cmd:text()

local batch_size = 256

-------------------------------------------------------------------------------
-- Basic Torch initializations
-------------------------------------------------------------------------------
local opt = cmd:parse(arg)
torch.manualSeed(opt.seed)
print(opt)
torch.setdefaulttensortype('torch.FloatTensor') -- for CPU

if opt.gpuid >= 0 then
  require 'cutorch'
  require 'cunn'
  if opt.backend == 'cudnn' then
    require 'cudnn'
  end
  cutorch.manualSeed(opt.seed)
  --cutorch.setDevice(opt.gpuid+1) -- note +1 because lua is 1-indexed
end

-------------------------------------------------------------------------------
-- Create the Data Loader instance (constructed before the model so that the
-- no-checkpoint branch below can query it)
-------------------------------------------------------------------------------
local loader = DataLoader{h5_img_file_train = opt.input_img_train_h5, h5_img_file_test = opt.input_img_test_h5, h5_ques_file = opt.input_ques_h5, json_file = opt.input_json, feature_type = opt.feature_type}

------------------------------------------------------------------------
-- Design Parameters and Network Definitions
------------------------------------------------------------------------
local protos = {}
print('Building the model...')
-- initialize the language model, from a checkpoint if one is given
local loaded_checkpoint
local lmOpt
if string.len(opt.start_from) > 0 then
  loaded_checkpoint = torch.load(opt.start_from)
  lmOpt = loaded_checkpoint.lmOpt
else
  lmOpt = {}
  lmOpt.vocab_size = loader:getVocabSize()
  lmOpt.input_encoding_size = opt.input_encoding_size
  lmOpt.rnn_size = opt.rnn_size
  lmOpt.num_layers = opt.rnn_layers
  lmOpt.dropout = 0.5
  lmOpt.seq_length = loader:getSeqLength()
  lmOpt.batch_size = opt.batch_size
  lmOpt.output_size = opt.rnn_size
  lmOpt.atten_type = opt.co_atten_type
  lmOpt.feature_type = opt.feature_type
end
lmOpt.hidden_size = 512
lmOpt.feature_type = opt.feature_type -- was hard-coded to 'VGG'; honor the command-line option
lmOpt.atten_type = opt.co_atten_type
print(lmOpt)

protos.word = nn.word_level(lmOpt)
protos.phrase = nn.phrase_level(lmOpt)
protos.ques = nn.ques_level(lmOpt)

protos.atten = nn.recursive_atten()
protos.crit = nn.CrossEntropyCriterion()

if opt.gpuid >= 0 then
  for k,v in pairs(protos) do v:cuda() end
end

local wparams, grad_wparams = protos.word:getParameters()
local pparams, grad_pparams = protos.phrase:getParameters()
local qparams, grad_qparams = protos.ques:getParameters()
local aparams, grad_aparams = protos.atten:getParameters()

if string.len(opt.start_from) > 0 then
  print('Loading the weights...')
  wparams:copy(loaded_checkpoint.wparams)
  pparams:copy(loaded_checkpoint.pparams)
  qparams:copy(loaded_checkpoint.qparams)
  aparams:copy(loaded_checkpoint.aparams)
end

print('total number of parameters in word_level: ', wparams:nElement())
assert(wparams:nElement() == grad_wparams:nElement())

print('total number of parameters in phrase_level: ', pparams:nElement())
assert(pparams:nElement() == grad_pparams:nElement())

print('total number of parameters in ques_level: ', qparams:nElement())
assert(qparams:nElement() == grad_qparams:nElement())
protos.ques:shareClones()

print('total number of parameters in recursive_attention: ', aparams:nElement())
assert(aparams:nElement() == grad_aparams:nElement())

collectgarbage()

function eval_split(split)
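  -- Run the full co-attention pipeline (word -> phrase -> question level,
  -- then recursive attention) over the whole split in batches, collecting
  -- the 1000-way answer scores and question ids for the JSON dump below.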
  protos.word:evaluate()
  protos.phrase:evaluate()
  protos.ques:evaluate()
  protos.atten:evaluate()
  loader:resetIterator(split)

  local predictions = {}
  local total_num = loader:getDataNum(split)
  print(total_num)
  local logprob_all = torch.Tensor(total_num, 1000)
  local ques_id = torch.Tensor(total_num)

  for i = 1, total_num, batch_size do
    xlua.progress(i, total_num)
    local r = math.min(i+batch_size-1, total_num)

    local data = loader:getBatch{batch_size = r-i+1, split = split}
    -- ship the data to cuda
    if opt.gpuid >= 0 then
      data.images = data.images:cuda()
      data.questions = data.questions:cuda()
      data.ques_len = data.ques_len:cuda()
    end

    local word_feat, img_feat, w_ques, w_img, mask = unpack(protos.word:forward({data.questions, data.images}))

    local conv_feat, p_ques, p_img = unpack(protos.phrase:forward({word_feat, data.ques_len, img_feat, mask}))

    local q_ques, q_img = unpack(protos.ques:forward({conv_feat, data.ques_len, img_feat, mask}))

    local feature_ensemble = {w_ques, w_img, p_ques, p_img, q_ques, q_img}
    local out_feat = protos.atten:forward(feature_ensemble)

    logprob_all:sub(i, r):copy(out_feat:float())
    ques_id:sub(i, r):copy(data.ques_id)
  end

  local tmp, pred = torch.max(logprob_all, 2)
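
  -- Map each predicted answer index back to its answer string using the
  -- ix_to_ans table loaded from the prepro json.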
  for i = 1, total_num do
    local ans = loader.ix_to_ans[tostring(pred[{i, 1}])]
    table.insert(predictions, {question_id = ques_id[i], answer = ans})
  end

  return {predictions}
end

predictions = eval_split(2)

utils.write_json('OpenEnded_mscoco_co-atten_results.json', predictions[1])

--utils.write_json('MultipleChoice_mscoco_co-atten_results.json', predictions[2])