Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 4198859
Showing
17 changed files
with
1,912 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
*.pyc | ||
.vscode | ||
out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
FROM pytorch/pytorch:1.0.1-cuda10.0-cudnn7-devel

RUN apt update && apt install -y rsync

ADD requirements.txt /tmp/requirements.txt
RUN pip install -r /tmp/requirements.txt
# Build NVIDIA Apex with CUDA/C++ extensions, then delete the source tree.
# BUG FIX: the clone previously landed in $PWD/apex while the cleanup removed
# /apex, so the checkout was silently left inside the image. Clone into /apex
# explicitly so the cleanup actually removes it.
RUN git clone https://github.com/nvidia/apex /apex && \
    cd /apex && \
    python setup.py install --cuda_ext --cpp_ext && \
    cd / && \
    rm -rf /apex

CMD "/bin/bash"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Improving Neural Story Generation by Targeted Common Sense Grounding | ||
This repository contains the code to replicate the paper "Improving Neural Story Generation by Targeted Common Sense Grounding". | ||
|
||
## Environment Setup | ||
We use Docker to ensure a consistent development environment. | ||
First, ensure Docker and NVIDIA-Docker are installed. | ||
|
||
Build Docker image: | ||
``` | ||
docker build -t storygen . | ||
``` | ||
|
||
Run bash shell in image: | ||
``` | ||
docker run --rm -w /src -v $(pwd):/src storygen /bin/bash | ||
``` | ||
Now you can run scripts within the shell. | ||
|
||
For all scripts you will need to download the corresponding datasets before running. | ||
|
||
## Training | ||
To train a model, run the following. See `--help` for CLI argument options. | ||
``` | ||
python train.py [experiment_name] | ||
``` | ||
|
||
## Evaluation | ||
Generate text from model | ||
``` | ||
python -m analysis.generate | ||
``` | ||
|
||
Compute perplexity from model | ||
``` | ||
python -m analysis.eval_ppl | ||
``` | ||
|
||
Compute prompt ranking accuracy from model | ||
``` | ||
python -m analysis.eval_prompt_rank | ||
``` | ||
|
||
Compute common sense reasoning accuracy from model | ||
``` | ||
python -m analysis.eval_csr | ||
``` | ||
|
||
## Attribution | ||
If you use this code in your research, cite our paper via the following BibTeX. | ||
|
||
``` | ||
@inproceedings{mao2019emnlp, | ||
title={Improving Neural Story Generation by Targeted Common Sense Grounding}, | ||
author={Mao, Huanru Henry and Majumder, Bodhisattwa Prasad and McAuley, Julian and Cottrell, Garrison W.}, | ||
booktitle={EMNLP}, | ||
year={2019} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
""" | ||
Calculates the SWAG/Story Cloze accuracy given a model. | ||
""" | ||
import pickle | ||
import os, re | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import numpy as np | ||
import argparse | ||
from tokenizer import GPT2Tokenizer | ||
from random import randint | ||
from datetime import datetime | ||
from pytorch_transformers import GPT2LMHeadModel, GPT2Config | ||
from data.util import prepare_dataset | ||
from train import compute_ranking_lp | ||
|
||
def main():
    """Evaluate a GPT-2 model on a multiple-choice ranking dataset
    (e.g. SWAG / Story Cloze) and report accuracy.

    In ``--test`` mode the split is unlabeled, so the chosen candidate
    indices are collected and dumped to disk instead of being scored.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model-path', type=str, help='pretrained model path to local checkpoint')
    parser.add_argument('--dataset', type=str, default='swag')
    parser.add_argument('--data-dir', type=str, default='../data')
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--seq-len", type=int, default=128)
    # BUG FIX: `args.test` is read below but the flag was never declared,
    # which crashed with AttributeError on every invocation.
    parser.add_argument("--test", action='store_true', default=False,
                        help='evaluate on the (unlabeled) test split')
    args = parser.parse_args()
    print(args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir='out/cache')

    if args.model_path:
        if args.model_path == 'random':
            # Random baseline: re-initialize all weights instead of loading.
            model.apply(model.init_weights)
        else:
            state = torch.load(args.model_path, map_location='cpu')
            model.load_state_dict(state)

    tokenizer = GPT2Tokenizer(os.path.join(args.data_dir, 'gpt2-vocab.json'), os.path.join(args.data_dir, 'gpt2-merges.txt'))

    # FP16 inference for speed/memory; evaluation only, so no grads needed.
    model.half().to(device)
    model.eval()
    print('Model loaded.')

    loader = prepare_dataset(args.data_dir, args.dataset, tokenizer, args.batch_size, args.seq_len, args.batch_size, args.seq_len,
                             distributed=False, make_train=False, make_val=not args.test, make_test=args.test)[0]
    print('Data loaded.')

    correct = 0
    total = 0

    outputs = []

    with torch.no_grad():
        for tokens, mask in loader:
            # Log-probability of each candidate continuation; highest wins.
            lprobs = compute_ranking_lp(device, model, tokens, mask, random_shift=False)
            chosen = lprobs.argmax(dim=-1)

            total += int(chosen.size(0))

            if args.test:
                print('Collecting results...', total)
                outputs += chosen.tolist()
            else:
                # The correct candidate is expected at index 0 in the val split.
                correct += (chosen == 0).sum().item()
                print('Accuracy', correct / total)

    # BUG FIX: test-mode predictions were collected but silently discarded.
    if args.test:
        with open('test_outputs.pkl', 'wb') as f:
            pickle.dump(outputs, f)
        print('Saved {} predictions to test_outputs.pkl'.format(len(outputs)))
|
||
|
||
# Script entry point.
if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
""" | ||
Calculates the perplexity given a model. | ||
""" | ||
import pickle | ||
import os, re | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import numpy as np | ||
import argparse | ||
from tokenizer import GPT2Tokenizer | ||
from random import randint | ||
from pytorch_transformers import GPT2LMHeadModel, GPT2Config | ||
from data import PromptDataset, TextDataset | ||
from data.util import wp_preprocess, compose | ||
|
||
def compute_logprobs(token_tensor, model):
    """Return the log-probability the model assigns to each next token.

    Feeds all but the final token through the model and, at every
    position, picks out the log-probability of the token that actually
    follows it.
    """
    inputs, targets = token_tensor[:, :-1], token_tensor[:, 1:]
    logits = model(inputs)[0]
    all_lprobs = F.log_softmax(logits, dim=-1)
    # Select the log-prob of the observed next token at each position.
    return all_lprobs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
|
||
def word_level_ppl(target_tokens, lprobs, tokenizer, raw_token=None):
    """Convert per-BPE-token log-probs into a word-level perplexity.

    Incrementally detokenizes the BPE stream, reverses the WritingPrompts
    preprocessing, and sums the BPE log-probs that belong to each
    reconstructed word. Everything up to and including the '---\\n'
    prompt separator is excluded from the perplexity.

    Args:
        target_tokens: sequence of BPE token ids (prediction targets).
        lprobs: per-token log-probabilities, same length as target_tokens.
        tokenizer: GPT2Tokenizer; its decoder/byte_decoder/errors are used
            to reproduce the exact detokenization.
        raw_token: optional list of raw ground-truth words, used only to
            report how far the reconstructed word count drifted.

    Returns:
        (ppl, token_diff) — token_diff is None when raw_token is None.

    Raises:
        Exception: when the perplexity is infinite or > 1000; the caller
            treats these as anomalies and skips the sample.
    """
    assert len(target_tokens) == len(lprobs), (len(target_tokens), len(lprobs))

    # Convert BPE lprobs to word lprobs
    word_lprobs = []  # summed log-prob per reconstructed word
    cur_lp = []       # log-probs of BPE pieces of the word(s) in progress
    new_add = ''      # text detokenized so far (since the prompt separator)
    i = 0             # number of completed words emitted so far
    start = False     # becomes True once the prompt separator has passed

    for token, lp in zip(target_tokens, lprobs):
        # Follow how it's detokenized.
        chars = tokenizer.decoder[token]
        new_add += bytearray([tokenizer.byte_decoder[c] for c in chars]).decode('utf-8', errors=tokenizer.errors)
        cur_lp.append(lp)

        if not start:
            # Wait for end of prompt
            start = '---\n' in new_add
            if start:
                # Discard everything belonging to the prompt.
                cur_lp = []
                new_add = ''
            continue

        # Reverse preprocessing: re-apply the tokenization-time transforms so
        # the running text splits into the same words as the raw dataset.
        text = new_add
        text = re.sub('"', ' " ', text)
        text = re.sub('(\'|\.|\,|\:|\?|\!|;)', ' \g<1>', text)
        # Fix contraction
        text = text.replace("n 't", " n't")
        text = text.replace('\n', ' <newline> ')
        text = re.sub(' +', ' ', text)
        text = text.replace('. . .', '...')
        # Edge cases
        text = text.replace("ca n't-", "can't-")
        text = text.replace("St .", "St.")
        text = re.sub(r"//www \.(.*) \.(.*)/", r"//www\.\g<1>\.\g<1>\/", text)

        tokens = text.strip().split(' ')

        # Once a new word is starting to be formed, remove the previous one
        if len(tokens) > i + 1:
            # Token length changed, which means new word has been added.
            # Grab all but the last prob (excluding the unformed next word)
            word_lprobs.append(sum(cur_lp[:-1]))
            cur_lp = cur_lp[-1:]
            i += 1

    # Add final token
    word_lprobs.append(sum(cur_lp))

    token_diff = None
    if raw_token is not None:
        token_diff = abs(len(word_lprobs) - len(raw_token))

    word_lprobs = torch.tensor(word_lprobs)
    # Perplexity = exp(mean negative log-likelihood) over words.
    ppl = torch.exp(-word_lprobs.mean()).item()

    if ppl == float('inf'):
        raise Exception('Infinite PPL', raw_token)

    if ppl > 1000:
        print(ppl)
        print(word_lprobs)
        # NOTE(review): if raw_token is None here, len(raw_token) raises
        # TypeError — confirm callers on this path always pass raw_token.
        print(len(word_lprobs), len(raw_token))

        raise Exception('Large PPL', tokens, raw_token)
    return ppl, token_diff
|
||
def run_model():
    """Compute word-level and BPE-level perplexity of a GPT-2 model on the
    WritingPrompts validation (or, with --test, test) split, printing
    running averages after each batch."""
    parser = argparse.ArgumentParser(description="")
    parser.add_argument('--model-path', type=str, help='pretrained model path to local checkpoint')
    parser.add_argument("--batch-size", type=int, default=40)
    parser.add_argument('--data-dir', type=str, default='../data')
    # NOTE(review): default '../data' looks copy-pasted from --data-dir and
    # this argument is never read below — confirm before relying on it.
    parser.add_argument('--dataset', type=str, default='../data')
    parser.add_argument("--test", action='store_true', default=False)
    args = parser.parse_args()
    print(args)

    if args.batch_size == -1:
        args.batch_size = 1

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = GPT2LMHeadModel.from_pretrained('gpt2', cache_dir='out/cache')

    # Optionally load fine-tuned weights on top of the pretrained model.
    if args.model_path:
        state = torch.load(args.model_path, map_location='cpu')
        model.load_state_dict(state)

    tokenizer = GPT2Tokenizer(os.path.join(args.data_dir, 'gpt2-vocab.json'), os.path.join(args.data_dir, 'gpt2-merges.txt'))
    # Hack to allow tokenizing longer sequences.
    tokenizer.max_len = int(1e12)

    # FP16 inference; gradients are not needed for evaluation.
    model.half().to(device)
    model.eval()
    print('Model loaded.')

    # d_val is preprocessed text fed to the model; d_val_raw keeps the raw
    # text so word_level_ppl can compare against the original word count.
    d_val = PromptDataset(
        os.path.join(args.data_dir, 'writingPrompts/{}.wp_source'.format('test' if args.test else 'valid')),
        os.path.join(args.data_dir, 'writingPrompts/{}.wp_target'.format('test' if args.test else 'valid')),
        wp_preprocess
    )
    d_val_raw = PromptDataset(
        os.path.join(args.data_dir, 'writingPrompts/{}.wp_source'.format('test' if args.test else 'valid')),
        os.path.join(args.data_dir, 'writingPrompts/{}.wp_target'.format('test' if args.test else 'valid'))
    )

    print('Data loaded.')

    print('Running evaluation...')
    with torch.no_grad():
        ppls = []         # BPE-level perplexity per sample
        word_ppls = []    # word-level perplexity per sample
        token_diffs = []  # |reconstructed word count - raw word count|
        num_errs = 0      # samples skipped due to PPL anomalies

        batch = []
        for sample_id, (text, check_text) in enumerate(zip(d_val, d_val_raw)):
            # Prepend <|endoftext|> as the BOS token.
            bpe_tokens = [tokenizer.encoder['<|endoftext|>']] + tokenizer.encode(text)
            # (This limit applies to GPT2)
            bpe_tokens = bpe_tokens[:1025]
            # Pad with 0s to a fixed length; the true length rides along so
            # padding can be masked out. The raw target words (after the
            # '---\n' prompt separator) are kept for word-count comparison.
            batch.append((bpe_tokens + [0] * (1025 - len(bpe_tokens)), len(bpe_tokens), check_text.split('---\n')[1].split(' ')))

            # Flush on a full batch or on (what is assumed to be) the last
            # sample. NOTE(review): the end check compares len(word_ppls),
            # which lags sample_id whenever samples were skipped — confirm
            # the final partial batch is always flushed.
            if len(batch) == args.batch_size or len(word_ppls) == len(d_val) - 1:
                x, x_lens, raw_tokens = zip(*batch)
                token_tensor = torch.tensor(x, dtype=torch.long, device=device)

                # Compute log probs
                lps = compute_logprobs(token_tensor, model)
                token_tensor = token_tensor.cpu().numpy()

                # Compute individually
                for i in range(lps.shape[0]):
                    try:
                        # Mask out some tokens
                        target_tokens = token_tensor[i, 1:x_lens[i]]
                        log_probs = lps[i, :x_lens[i] - 1]
                        ppl, token_diff = word_level_ppl(target_tokens, log_probs.cpu().float().numpy(), tokenizer, raw_tokens[i])
                        token_diffs.append(token_diff)
                        word_ppls.append(ppl)
                        # BPE perplexity straight from the mean log-prob.
                        ppls.append(torch.exp(-log_probs.mean()).item())
                    except Exception as e:
                        # word_level_ppl raises on infinite/huge PPL; skip it.
                        print('Skipping anomaly.')
                        print(e)
                        num_errs += 1
                print('World Level PPL {:.2f} BPE PPL {:.2f} Diff {:.2f} Done: {:.2f}% Skip {}'.format(
                    np.mean(word_ppls), np.mean(ppls), np.mean(token_diffs),
                    sample_id / len(d_val) * 100, num_errs
                ))
                batch = []
|
||
# Script entry point.
if __name__ == '__main__':
    run_model()
Oops, something went wrong.