-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
97 changed files
with
7,534 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
*.pyc | ||
test.ipynb | ||
.ipynb_checkpoints | ||
.idea | ||
*history | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[submodule "cider"] | ||
path = cider | ||
url = https://github.com/littlekobe/cider |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,79 @@ | ||
# AREL-for-Visual-Storytelling | ||
Code coming soon. | ||
# No Metrics Are Perfect: Adversarial REward Learning for Visual Storytelling | ||
|
||
In this paper, we not only introduce a novel adversarial reward learning algorithm to generate more human-like stories given image sequences, but also empirically analyze the limitations of the automatic metrics for story evaluation. | ||
|
||
For more details, please check the latest version of the paper: [https://arxiv.org/abs/1804.09160](https://arxiv.org/abs/1804.09160). | ||
|
||
## Prerequisites | ||
- Python 2.7 | ||
- PyTorch 0.3 | ||
- TensorFlow (optional, only using the fantastic tensorboard) | ||
- cuda & cudnn | ||
|
||
## Usage | ||
### 1. Setup | ||
Clone this github repository recursively: | ||
|
||
``` | ||
git clone --recursive https://github.com/littlekobe/Visual-Storytelling.git ./ | ||
``` | ||
|
||
Download the preprocessed ResNet-152 features [here](http://nlp.cs.ucsb.edu/data/VIST_resnet_features.zip) and unzip it into `DATADIR/resnet_features`. | ||
|
||
### 2. Supervised Learning | ||
We use cross entropy loss to warm start the model first: | ||
|
||
``` | ||
python train.py --id XE --data_dir DATADIR --start_rl -1 | ||
``` | ||
|
||
Check the file `opt.py` for more options, where you can play with some other settings. | ||
|
||
### 3. AREL Learning | ||
To train an AREL model, run | ||
|
||
``` | ||
python train_AREL.py --id AREL --start_from_model PRETRAINED_MODEL | ||
``` | ||
|
||
Note that `PRETRAINED_MODEL` can be `data/save/XE/model.pth` or some other saved models. | ||
Check `opt.py` for more information. | ||
|
||
### 4. Monitor your training | ||
TensorBoard is used to monitor the training process. Suppose you set the option `checkpoint_path` as `data/save`, then run | ||
|
||
``` | ||
tensorboard --logdir data/save/tensorboard | ||
``` | ||
|
||
And then open your browser and go to `[IP address]:6006` (the default port for tensorboard is `6006`). | ||
|
||
### 5. Testing | ||
To test the model's performance, run | ||
|
||
``` | ||
python train.py --option test --start_from_model data/save/XE/model.pth | ||
``` | ||
|
||
or | ||
|
||
``` | ||
python train_AREL.py --option test --start_from_model data/save/AREL/model.pth | ||
``` | ||
|
||
## If you find this code useful, please cite the paper | ||
``` | ||
@InProceedings{xinwang-wenhuchen-ACL-2018, | ||
author = "Wang, Xin and Chen, Wenhu and Wang, Yuan-Fang and Wang, William Yang", | ||
title = "No Metrics Are Perfect: Adversarial Reward Learning for Visual Storytelling", | ||
booktitle = "Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)", | ||
year = "2018", | ||
publisher = "Association for Computational Linguistics", | ||
pages = "899--909", | ||
location = "Melbourne, Australia", | ||
url = "http://aclweb.org/anthology/P18-1083" | ||
} | ||
``` | ||
|
||
## Acknowledgement | ||
* [VIST evaluation code by Licheng Yu](https://github.com/lichengunc/vist_eval) |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
import collections | ||
import time | ||
import sys | ||
import torch | ||
import torch.nn as nn | ||
from torch.autograd import Variable | ||
import numpy as np | ||
import logging | ||
from vist_eval.meteor.meteor import Meteor | ||
import misc.utils as utils | ||
|
||
|
||
def to_contiguous(tensor): | ||
if tensor.is_contiguous(): | ||
return tensor | ||
else: | ||
return tensor.contiguous() | ||
|
||
|
||
class ReinforceCriterion(nn.Module): | ||
def __init__(self, opt, dataset): | ||
super(ReinforceCriterion, self).__init__() | ||
self.dataset = dataset | ||
self.reward_type = opt.reward_type | ||
self.bleu = None | ||
|
||
if self.reward_type == 'METEOR': | ||
from vist_eval.meteor.meteor import Meteor | ||
self.reward_scorer = Meteor() | ||
elif self.reward_type == 'CIDEr': | ||
sys.path.append("cider") | ||
from pyciderevalcap.ciderD.ciderD import CiderD | ||
self.reward_scorer = CiderD(df=opt.cached_tokens) | ||
elif self.reward_type == 'Bleu_4' or self.reward_type == 'Bleu_3': | ||
from vist_eval.bleu.bleu import Bleu | ||
self.reward_scorer = Bleu(4) | ||
self.bleu = int(self.reward_type[-1]) - 1 | ||
elif self.reward_type == 'ROUGE_L': | ||
from vist_eval.rouge.rouge import Rouge | ||
self.reward_scorer = Rouge() | ||
else: | ||
err_msg = "{} scorer hasn't been implemented".format(self.reward_type) | ||
logging.error(err_msg) | ||
raise Exception(err_msg) | ||
|
||
def _cal_action_loss(self, log_probs, reward, mask): | ||
output = - log_probs * reward * mask | ||
output = torch.sum(output) / torch.sum(mask) | ||
return output | ||
|
||
def _cal_value_loss(self, reward, baseline, mask): | ||
output = (reward - baseline).pow(2) * mask | ||
output = torch.sum(output) / torch.sum(mask) | ||
return output | ||
|
||
def forward(self, seq, seq_log_probs, baseline, index, rewards=None): | ||
''' | ||
:param seq: (batch_size, 5, seq_length) | ||
:param seq_log_probs: (batch_size, 5, seq_length) | ||
:param baseline: (batch_size, 5, seq_length) | ||
:param indexes: (batch_size,) | ||
:param rewards: (batch_size, 5, seq_length) | ||
:return: | ||
''' | ||
if rewards is None: | ||
# compute the reward | ||
sents = utils.decode_story(self.dataset.get_vocab(), seq) | ||
|
||
rewards = [] | ||
batch_size = seq.size(0) | ||
for i, story in enumerate(sents): | ||
vid, _ = self.dataset.get_id(index[i]) | ||
GT_story = self.dataset.get_GT(index[i]) | ||
result = {vid: [story]} | ||
gt = {vid: [GT_story]} | ||
score, _ = self.reward_scorer.compute_score(gt, result) | ||
if self.bleu is not None: | ||
rewards.append(score[self.bleu]) | ||
else: | ||
rewards.append(score) | ||
rewards = torch.FloatTensor(rewards) # (batch_size,) | ||
avg_reward = rewards.mean() | ||
rewards = Variable(rewards.view(batch_size, 1, 1).expand_as(seq)).cuda() | ||
else: | ||
avg_reward = rewards.mean() | ||
rewards = rewards.view(-1, 5, 1) | ||
|
||
# get the mask | ||
mask = (seq > 0).float() # its size is supposed to be (batch_size, 5, seq_length) | ||
if mask.size(2) > 1: | ||
mask = torch.cat([mask.new(mask.size(0), mask.size(1), 1).fill_(1), mask[:, :, :-1]], 2).contiguous() | ||
else: | ||
mask.fill_(1) | ||
mask = Variable(mask) | ||
|
||
# compute the loss | ||
advantage = Variable(rewards.data - baseline.data) | ||
value_loss = self._cal_value_loss(rewards, baseline, mask) | ||
action_loss = self._cal_action_loss(seq_log_probs, advantage, mask) | ||
|
||
return action_loss + value_loss, avg_reward | ||
|
||
|
||
class LanguageModelCriterion(nn.Module): | ||
def __init__(self, weight=0.0): | ||
self.weight = weight | ||
super(LanguageModelCriterion, self).__init__() | ||
|
||
def forward(self, input, target, weights=None, compute_prob=False): | ||
if len(target.size()) == 3: # separate story | ||
input = input.view(-1, input.size(2), input.size(3)) | ||
target = target.view(-1, target.size(2)) | ||
|
||
seq_length = input.size(1) | ||
# truncate to the same size | ||
target = target[:, :input.size(1)] | ||
mask = (target > 0).float() | ||
mask = to_contiguous(torch.cat([Variable(mask.data.new(mask.size(0), 1).fill_(1)), mask[:, :-1]], 1)) | ||
|
||
# reshape the variables | ||
input = to_contiguous(input).view(-1, input.size(2)) | ||
target = to_contiguous(target).view(-1, 1) | ||
mask = mask.view(-1, 1) | ||
|
||
if weights is None: | ||
output = - input.gather(1, target) * mask | ||
else: | ||
output = - input.gather(1, target) * mask * to_contiguous(weights).view(-1, 1) | ||
|
||
if compute_prob: | ||
output = output.view(-1, seq_length) | ||
mask = mask.view(-1, seq_length) | ||
return output.sum(-1) / mask.sum(-1) | ||
|
||
output = torch.sum(output) / torch.sum(mask) | ||
|
||
entropy = -(torch.exp(input) * input).sum(-1) * mask | ||
entropy = torch.sum(entropy) / torch.sum(mask) | ||
|
||
return output + self.weight * entropy |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
* | ||
!.gitignore |
Oops, something went wrong.