Commit 4198859: Initial commit
calclavia committed Aug 24, 2019
Showing 17 changed files with 1,912 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -0,0 +1,3 @@
*.pyc
.vscode
out
12 changes: 12 additions & 0 deletions Dockerfile
@@ -0,0 +1,12 @@
FROM pytorch/pytorch:1.0.1-cuda10.0-cudnn7-devel

RUN apt update && apt install -y rsync

ADD requirements.txt /tmp/requirements.txt
RUN pip install -r /tmp/requirements.txt
RUN git clone https://github.com/nvidia/apex && \
    cd apex && \
    python setup.py install --cuda_ext --cpp_ext && \
    cd .. && rm -rf apex

CMD ["/bin/bash"]
58 changes: 58 additions & 0 deletions README.md
@@ -0,0 +1,58 @@
# Improving Neural Story Generation by Targeted Common Sense Grounding
This repository contains the code to replicate the paper "Improving Neural Story Generation by Targeted Common Sense Grounding".

## Environment Setup
We use Docker to ensure a consistent development environment.
First, ensure that Docker and NVIDIA-Docker are installed.

Build the Docker image:
```
docker build -t storygen .
```

Run a bash shell inside the image:
```
docker run --rm -w /src -v $(pwd):/src storygen /bin/bash
```
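
Depending on your Docker version, you may also need to expose the GPUs explicitly; a sketch assuming Docker 19.03+ (older setups typically use the `nvidia-docker` wrapper instead):
```
docker run --rm --gpus all -w /src -v $(pwd):/src storygen /bin/bash
```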
Now you can run scripts within the shell.

Before running any of the scripts below, download the corresponding datasets.
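
Based on the paths referenced in the evaluation scripts, the data directory (default `../data`, configurable via `--data-dir`) is expected to contain at least the GPT-2 BPE vocabulary files and the WritingPrompts splits; other scripts may need additional files:
```
data/
├── gpt2-vocab.json
├── gpt2-merges.txt
└── writingPrompts/
    ├── valid.wp_source
    ├── valid.wp_target
    ├── test.wp_source
    └── test.wp_target
```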

## Training
To train a model, run the following. See `--help` for CLI argument options.
```
python train.py [experiment_name]
```

## Evaluation
Generate text from a model:
```
python -m analysis.generate
```

Compute perplexity of a model:
```
python -m analysis.eval_ppl
```

Compute prompt ranking accuracy of a model:
```
python -m analysis.eval_prompt_rank
```

Compute common sense reasoning accuracy of a model:
```
python -m analysis.eval_csr
```
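
Each evaluation script loads a trained checkpoint via `--model-path` and reads data from `--data-dir` (default `../data`); `eval_csr` additionally takes `--dataset` (default `swag`). An illustrative invocation (the checkpoint path is hypothetical):
```
python -m analysis.eval_csr --model-path out/experiment_name/model.pt --dataset swag --data-dir ../data
```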

## Attribution
If you use this code in your research, please cite our paper with the following BibTeX entry.

```
@inproceedings{mao2019emnlp,
  title={Improving Neural Story Generation by Targeted Common Sense Grounding},
  author={Mao, Huanru Henry and Majumder, Bodhisattwa Prasad and McAuley, Julian and Cottrell, Garrison W.},
  booktitle={EMNLP},
  year={2019}
}
```
70 changes: 70 additions & 0 deletions analysis/eval_csr.py
@@ -0,0 +1,70 @@
"""
Calculates the SWAG/Story Cloze accuracy given a model.
"""
import pickle
import os, re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import argparse
from tokenizer import GPT2Tokenizer
from random import randint
from datetime import datetime
from pytorch_transformers import GPT2LMHeadModel, GPT2Config
from data.util import prepare_dataset
from train import compute_ranking_lp

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-path', type=str, help='pretrained model path to local checkpoint')
    parser.add_argument('--dataset', type=str, default='swag')
    parser.add_argument('--data-dir', type=str, default='../data')
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--seq-len", type=int, default=128)
    # Evaluate on the test split instead of the validation split
    parser.add_argument("--test", action='store_true', default=False)
    args = parser.parse_args()
    print(args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir='out/cache')

    if args.model_path:
        if args.model_path == 'random':
            model.apply(model.init_weights)
        else:
            state = torch.load(args.model_path, map_location='cpu')
            model.load_state_dict(state)

    tokenizer = GPT2Tokenizer(os.path.join(args.data_dir, 'gpt2-vocab.json'), os.path.join(args.data_dir, 'gpt2-merges.txt'))

    model.half().to(device)
    model.eval()
    print('Model loaded.')

    loader = prepare_dataset(args.data_dir, args.dataset, tokenizer, args.batch_size, args.seq_len, args.batch_size, args.seq_len,
                             distributed=False, make_train=False, make_val=not args.test, make_test=args.test)[0]
    print('Data loaded.')

    correct = 0
    total = 0

    outputs = []

    with torch.no_grad():
        for tokens, mask in loader:
            lprobs = compute_ranking_lp(device, model, tokens, mask, random_shift=False)
            chosen = lprobs.argmax(dim=-1)

            total += int(chosen.size(0))

            if args.test:
                print('Collecting results...', total)
                outputs += chosen.tolist()
            else:
                correct += (chosen == 0).sum().item()
                print('Accuracy', correct / total)


if __name__ == '__main__':
    main()
181 changes: 181 additions & 0 deletions analysis/eval_ppl.py
@@ -0,0 +1,181 @@
"""
Calculates the perplexity given a model.
"""
import pickle
import os, re
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import argparse
from tokenizer import GPT2Tokenizer
from random import randint
from pytorch_transformers import GPT2LMHeadModel, GPT2Config
from data import PromptDataset, TextDataset
from data.util import wp_preprocess, compose

def compute_logprobs(token_tensor, model):
    input_tokens = token_tensor[:, :-1]
    target_tokens = token_tensor[:, 1:]

    logits, _ = model(input_tokens)
    lprobs = torch.log_softmax(logits, dim=-1)
    # Extract the probability of the target token at each position
    lprobs = lprobs.gather(-1, target_tokens.unsqueeze(-1)).squeeze(-1)
    return lprobs

def word_level_ppl(target_tokens, lprobs, tokenizer, raw_token=None):
    assert len(target_tokens) == len(lprobs), (len(target_tokens), len(lprobs))

    # Convert BPE lprobs to word lprobs
    word_lprobs = []
    cur_lp = []
    new_add = ''
    i = 0
    start = False

    for token, lp in zip(target_tokens, lprobs):
        # Follow how it's detokenized.
        chars = tokenizer.decoder[token]
        new_add += bytearray([tokenizer.byte_decoder[c] for c in chars]).decode('utf-8', errors=tokenizer.errors)
        cur_lp.append(lp)

        if not start:
            # Wait for end of prompt
            start = '---\n' in new_add
            if start:
                cur_lp = []
                new_add = ''
            continue

        # Reverse preprocessing
        text = new_add
        text = re.sub('"', ' " ', text)
        text = re.sub('(\'|\.|\,|\:|\?|\!|;)', ' \g<1>', text)
        # Fix contraction
        text = text.replace("n 't", " n't")
        text = text.replace('\n', ' <newline> ')
        text = re.sub(' +', ' ', text)
        text = text.replace('. . .', '...')
        # Edge cases
        text = text.replace("ca n't-", "can't-")
        text = text.replace("St .", "St.")
        text = re.sub(r"//www \.(.*) \.(.*)/", r"//www\.\g<1>\.\g<1>\/", text)

        tokens = text.strip().split(' ')

        # Once a new word is starting to be formed, remove the previous one
        if len(tokens) > i + 1:
            # Token length changed, which means a new word has been added.
            # Grab all but the last prob (excluding the unformed next word)
            word_lprobs.append(sum(cur_lp[:-1]))
            cur_lp = cur_lp[-1:]
            i += 1

    # Add final token
    word_lprobs.append(sum(cur_lp))

    token_diff = None
    if raw_token is not None:
        token_diff = abs(len(word_lprobs) - len(raw_token))

    word_lprobs = torch.tensor(word_lprobs)
    ppl = torch.exp(-word_lprobs.mean()).item()

    if ppl == float('inf'):
        raise Exception('Infinite PPL', raw_token)

    if ppl > 1000:
        print(ppl)
        print(word_lprobs)
        print(len(word_lprobs), len(raw_token))

        raise Exception('Large PPL', tokens, raw_token)
    return ppl, token_diff

def run_model():
    parser = argparse.ArgumentParser(description="")
    parser.add_argument('--model-path', type=str, help='pretrained model path to local checkpoint')
    parser.add_argument("--batch-size", type=int, default=40)
    parser.add_argument('--data-dir', type=str, default='../data')
    parser.add_argument('--dataset', type=str, default='../data')
    parser.add_argument("--test", action='store_true', default=False)
    args = parser.parse_args()
    print(args)

    if args.batch_size == -1:
        args.batch_size = 1

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir='out/cache')

    if args.model_path:
        state = torch.load(args.model_path, map_location='cpu')
        model.load_state_dict(state)

    tokenizer = GPT2Tokenizer(os.path.join(args.data_dir, 'gpt2-vocab.json'), os.path.join(args.data_dir, 'gpt2-merges.txt'))
    # Hack to allow tokenizing longer sequences.
    tokenizer.max_len = int(1e12)

    model.half().to(device)
    model.eval()
    print('Model loaded.')

    d_val = PromptDataset(
        os.path.join(args.data_dir, 'writingPrompts/{}.wp_source'.format('test' if args.test else 'valid')),
        os.path.join(args.data_dir, 'writingPrompts/{}.wp_target'.format('test' if args.test else 'valid')),
        wp_preprocess
    )
    d_val_raw = PromptDataset(
        os.path.join(args.data_dir, 'writingPrompts/{}.wp_source'.format('test' if args.test else 'valid')),
        os.path.join(args.data_dir, 'writingPrompts/{}.wp_target'.format('test' if args.test else 'valid'))
    )

    print('Data loaded.')

    print('Running evaluation...')
    with torch.no_grad():
        ppls = []
        word_ppls = []
        token_diffs = []
        num_errs = 0

        batch = []
        for sample_id, (text, check_text) in enumerate(zip(d_val, d_val_raw)):
            bpe_tokens = [tokenizer.encoder['<|endoftext|>']] + tokenizer.encode(text)
            # (This limit applies to GPT2)
            bpe_tokens = bpe_tokens[:1025]
            # Pad
            batch.append((bpe_tokens + [0] * (1025 - len(bpe_tokens)), len(bpe_tokens), check_text.split('---\n')[1].split(' ')))

            if len(batch) == args.batch_size or len(word_ppls) == len(d_val) - 1:
                x, x_lens, raw_tokens = zip(*batch)
                token_tensor = torch.tensor(x, dtype=torch.long, device=device)

                # Compute log probs
                lps = compute_logprobs(token_tensor, model)
                token_tensor = token_tensor.cpu().numpy()

                # Compute individually
                for i in range(lps.shape[0]):
                    try:
                        # Mask out some tokens
                        target_tokens = token_tensor[i, 1:x_lens[i]]
                        log_probs = lps[i, :x_lens[i] - 1]
                        ppl, token_diff = word_level_ppl(target_tokens, log_probs.cpu().float().numpy(), tokenizer, raw_tokens[i])
                        token_diffs.append(token_diff)
                        word_ppls.append(ppl)
                        ppls.append(torch.exp(-log_probs.mean()).item())
                    except Exception as e:
                        print('Skipping anomaly.')
                        print(e)
                        num_errs += 1

                print('Word Level PPL {:.2f} BPE PPL {:.2f} Diff {:.2f} Done: {:.2f}% Skip {}'.format(
                    np.mean(word_ppls), np.mean(ppls), np.mean(token_diffs),
                    sample_id / len(d_val) * 100, num_errs
                ))
                batch = []


if __name__ == '__main__':
    run_model()
