Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-xw committed Jul 15, 2018
1 parent 8bf4fb2 commit 4930f91
Show file tree
Hide file tree
Showing 97 changed files with 7,534 additions and 2 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
*.pyc
test.ipynb
.ipynb_checkpoints
.idea
*history

3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "cider"]
path = cider
url = https://github.com/littlekobe/cider
81 changes: 79 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,79 @@
# AREL-for-Visual-Storytelling
Code coming soon.
# No Metrics Are Perfect: Adversarial REward Learning for Visual Storytelling

In this paper, we not only introduce a novel adversarial reward learning algorithm to generate more human-like stories given image sequences, but also empirically analyze the limitations of the automatic metrics for story evaluation.

For more details, please check the latest version of the paper: [https://arxiv.org/abs/1804.09160](https://arxiv.org/abs/1804.09160).

## Prerequisites
- Python 2.7
- PyTorch 0.3
- TensorFlow (optional, only using the fantastic tensorboard)
- cuda & cudnn

## Usage
### 1. Setup
Clone this github repository recursively:

```
git clone --recursive https://github.com/littlekobe/Visual-Storytelling.git ./
```

Download the preprocessed ResNet-152 features [here](http://nlp.cs.ucsb.edu/data/VIST_resnet_features.zip) and unzip it into `DATADIR/resnet_features`.

### 2. Supervised Learning
We use cross entropy loss to warm start the model first:

```
python train.py --id XE --data_dir DATADIR --start_rl -1
```

Check the file `opt.py` for more options, where you can play with some other settings.

### 3. AREL Learning
To train an AREL model, run

```
python train_AREL.py --id AREL --start_from_model PRETRAINED_MODEL
```

Note that `PRETRAINED_MODEL` can be `data/save/XE/model.pth` or some other saved models.
Check `opt.py` for more information.

### 4. Monitor your training
TensorBoard is used to monitor the training process. Suppose you set the option `checkpoint_path` as `data/save`, then run

```
tensorboard --logdir data/save/tensorboard
```

And then open your browser and go to `[IP address]:6006` (the default port for tensorboard is `6006`).

### 5. Testing
To test the model's performance, run

```
python train.py --option test --start_from_model data/save/XE/model.pth
```

or

```
python train_AREL.py --option test --start_from_model data/save/AREL/model.pth
```

## If you find this code useful, please cite the paper
```
@InProceedings{xinwang-wenhuchen-ACL-2018,
author = "Wang, Xin and Chen, Wenhu and Wang, Yuan-Fang and Wang, William Yang",
title = "No Metrics Are Perfect: Adversarial Reward Learning for Visual Storytelling",
booktitle = "Proceedings of the 56th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
year = "2018",
publisher = "Association for Computational Linguistics",
pages = "899--909",
location = "Melbourne, Australia",
url = "http://aclweb.org/anthology/P18-1083"
}
```

## Acknowledgement
* [VIST evaluation code by Licheng Yu](https://github.com/lichengunc/vist_eval)
Binary file added VIST/VIST-train-words.p
Binary file not shown.
Binary file added VIST/description.h5
Binary file not shown.
Binary file added VIST/embedding.npy
Binary file not shown.
Binary file added VIST/full_story.h5
Binary file not shown.
Binary file added VIST/story.h5
Binary file not shown.
1 change: 1 addition & 0 deletions VIST/story_line.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions cider
Submodule cider added at c25ac5
140 changes: 140 additions & 0 deletions criterion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import collections
import time
import sys
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import logging
from vist_eval.meteor.meteor import Meteor
import misc.utils as utils


def to_contiguous(tensor):
if tensor.is_contiguous():
return tensor
else:
return tensor.contiguous()


class ReinforceCriterion(nn.Module):
def __init__(self, opt, dataset):
super(ReinforceCriterion, self).__init__()
self.dataset = dataset
self.reward_type = opt.reward_type
self.bleu = None

if self.reward_type == 'METEOR':
from vist_eval.meteor.meteor import Meteor
self.reward_scorer = Meteor()
elif self.reward_type == 'CIDEr':
sys.path.append("cider")
from pyciderevalcap.ciderD.ciderD import CiderD
self.reward_scorer = CiderD(df=opt.cached_tokens)
elif self.reward_type == 'Bleu_4' or self.reward_type == 'Bleu_3':
from vist_eval.bleu.bleu import Bleu
self.reward_scorer = Bleu(4)
self.bleu = int(self.reward_type[-1]) - 1
elif self.reward_type == 'ROUGE_L':
from vist_eval.rouge.rouge import Rouge
self.reward_scorer = Rouge()
else:
err_msg = "{} scorer hasn't been implemented".format(self.reward_type)
logging.error(err_msg)
raise Exception(err_msg)

def _cal_action_loss(self, log_probs, reward, mask):
output = - log_probs * reward * mask
output = torch.sum(output) / torch.sum(mask)
return output

def _cal_value_loss(self, reward, baseline, mask):
output = (reward - baseline).pow(2) * mask
output = torch.sum(output) / torch.sum(mask)
return output

def forward(self, seq, seq_log_probs, baseline, index, rewards=None):
'''
:param seq: (batch_size, 5, seq_length)
:param seq_log_probs: (batch_size, 5, seq_length)
:param baseline: (batch_size, 5, seq_length)
:param indexes: (batch_size,)
:param rewards: (batch_size, 5, seq_length)
:return:
'''
if rewards is None:
# compute the reward
sents = utils.decode_story(self.dataset.get_vocab(), seq)

rewards = []
batch_size = seq.size(0)
for i, story in enumerate(sents):
vid, _ = self.dataset.get_id(index[i])
GT_story = self.dataset.get_GT(index[i])
result = {vid: [story]}
gt = {vid: [GT_story]}
score, _ = self.reward_scorer.compute_score(gt, result)
if self.bleu is not None:
rewards.append(score[self.bleu])
else:
rewards.append(score)
rewards = torch.FloatTensor(rewards) # (batch_size,)
avg_reward = rewards.mean()
rewards = Variable(rewards.view(batch_size, 1, 1).expand_as(seq)).cuda()
else:
avg_reward = rewards.mean()
rewards = rewards.view(-1, 5, 1)

# get the mask
mask = (seq > 0).float() # its size is supposed to be (batch_size, 5, seq_length)
if mask.size(2) > 1:
mask = torch.cat([mask.new(mask.size(0), mask.size(1), 1).fill_(1), mask[:, :, :-1]], 2).contiguous()
else:
mask.fill_(1)
mask = Variable(mask)

# compute the loss
advantage = Variable(rewards.data - baseline.data)
value_loss = self._cal_value_loss(rewards, baseline, mask)
action_loss = self._cal_action_loss(seq_log_probs, advantage, mask)

return action_loss + value_loss, avg_reward


class LanguageModelCriterion(nn.Module):
def __init__(self, weight=0.0):
self.weight = weight
super(LanguageModelCriterion, self).__init__()

def forward(self, input, target, weights=None, compute_prob=False):
if len(target.size()) == 3: # separate story
input = input.view(-1, input.size(2), input.size(3))
target = target.view(-1, target.size(2))

seq_length = input.size(1)
# truncate to the same size
target = target[:, :input.size(1)]
mask = (target > 0).float()
mask = to_contiguous(torch.cat([Variable(mask.data.new(mask.size(0), 1).fill_(1)), mask[:, :-1]], 1))

# reshape the variables
input = to_contiguous(input).view(-1, input.size(2))
target = to_contiguous(target).view(-1, 1)
mask = mask.view(-1, 1)

if weights is None:
output = - input.gather(1, target) * mask
else:
output = - input.gather(1, target) * mask * to_contiguous(weights).view(-1, 1)

if compute_prob:
output = output.view(-1, seq_length)
mask = mask.view(-1, seq_length)
return output.sum(-1) / mask.sum(-1)

output = torch.sum(output) / torch.sum(mask)

entropy = -(torch.exp(input) * input).sum(-1) * mask
entropy = torch.sum(entropy) / torch.sum(mask)

return output + self.weight * entropy
2 changes: 2 additions & 0 deletions data/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
*
!.gitignore
Loading

0 comments on commit 4930f91

Please sign in to comment.