Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
rodrigonogueira4 committed Jan 12, 2019
1 parent 224dc74 commit e6fdc7f
Show file tree
Hide file tree
Showing 14 changed files with 2,890 additions and 1 deletion.
29 changes: 29 additions & 0 deletions LICENSE
@@ -0,0 +1,29 @@
BSD 3-Clause License

Copyright (c) 2019, New York University (Kyunghyun Cho and Rodrigo Nogueira)
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
91 changes: 90 additions & 1 deletion README.md
@@ -1 +1,90 @@
# dl4marco-bert
# BERT as Passage-Reranker

## Introduction
**\*\*\*\*\* Most of the code in this repository was copied from the original
[BERT repository](https://github.com/google-research/bert).**\*\*\*\*\*

This repository contains the code to reproduce our entry to the [MSMARCO passage
ranking task](http://www.msmarco.org/leaders.aspx)

The paper describing our implementation is [here]().


MSMARCO Passage Re-Ranking Leaderboard (Jan 8th 2019) | Eval MRR@10 | Eval MRR@10
------------------------------------- | :------: | :------:
1st Place - BERT (this code) | **35.87** | **36.53**
2nd Place - IRNet | 28.06 | 27.80
3rd Place - Conv-KNRM | 27.12 | 29.02

## Download and extract the data
First, we need to download and extract MS MARCO and BERT files:
```
DATA_DIR=./data
mkdir ${DATA_DIR}
wget https://msmarco.blob.core.windows.net/msmarcoranking/triples.train.small.tar.gz -P ${DATA_DIR}
wget https://msmarco.blob.core.windows.net/msmarcoranking/top1000.dev.tar.gz -P ${DATA_DIR}
wget https://msmarco.blob.core.windows.net/msmarcoranking/top1000.eval.tar.gz -P ${DATA_DIR}
wget https://msmarco.blob.core.windows.net/msmarcoranking/qrels.dev.tsv -P ${DATA_DIR}
wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-24_H-1024_A-16.zip -P ${DATA_DIR}
tar -xvf ${DATA_DIR}/triples.train.small.tar.gz -C ${DATA_DIR}
tar -xvf ${DATA_DIR}/top1000.dev.tar.gz -C ${DATA_DIR}
tar -xvf ${DATA_DIR}/top1000.eval.tar.gz -C ${DATA_DIR}
unzip ${DATA_DIR}/uncased_L-24_H-1024_A-16.zip -d ${DATA_DIR}
```

## Convert MS MARCO to tfrecord
Next, we need to convert MS MARCO train, dev, and eval file to tfrecord files,
which will be later consumed by BERT.

```
mkdir ${DATA_DIR}/tfrecord
python convert_msmarco_to_tfrecord.py \
--tfrecord_folder=${DATA_DIR}/tfrecord \
--vocab_file=${DATA_DIR}/uncased_L-24_H-1024_A-16/vocab.txt \
--train_dataset_path=${DATA_DIR}/triples.train.small.tsv \
--dev_dataset_path=${DATA_DIR}/top1000.dev.tsv \
--eval_dataset_path=${DATA_DIR}/top1000.eval.tsv \
--dev_qrels_path=${DATA_DIR}/qrels.dev.tsv \
--max_query_length=64\
--max_seq_length=512 \
--num_eval_docs=1000
```

This conversion takes 30-40 hours. Alternatively, you can download the
[tfrecords file here]() (~23GB):

## Training
We can now start training. We highly recommend to use a TPU, which are free in
[Google's colab](https://colab.research.google.com/github/tensorflow/tpu/blob/master/tools/colab/bert_finetuning_with_cloud_tpus.ipynb). Otherwise, a modern V100 GPU with 16GB
cannot fit even a small batch size of 2 when training a BERT Large model.

```
TFRECORD_FOLDER=${DATA_DIR}/tfrecord
mkdir ${TFRECORD_FOLDER}
python run.py \
--data_dir=${TFRECORD_FOLDER} \
--vocab_file=${DATA_DIR}/uncased_L-24_H-1024_A-16/vocab.txt \
--bert_config_file=${DATA_DIR}/uncased_L-24_H-1024_A-16/bert_config.json \
--init_checkpoint=${DATA_DIR}/uncased_L-24_H-1024_A-16/bert_model.ckpt \
--output_dir=${DATA_DIR}/output \
--msmarco_output=True \
--do_train=True\
--do_eval=True\
--num_train_steps=400000\
--
```

Training for 400k iterations takes approximately 70 hours on a TPU v2.
Alternatively, you can [download the trained model used in our submission here](https://storage.googleapis.com/bert_msmarco_data/pretrained_models/trained_bert_large.zip) (~3.4GB).

#### How do I cite this work?
```
@article{nogueira2019passage,
title={Passage Re-ranking with BERT},
author={Nogueira, Rodrigo and Cho, Kyunghyun},
journal={arXiv preprint},
year={2019}
}
```
15 changes: 15 additions & 0 deletions __init__.py
@@ -0,0 +1,15 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

217 changes: 217 additions & 0 deletions convert_msmarco_to_tfrecord.py
@@ -0,0 +1,217 @@
"""
This code converts MS MARCO train, dev and eval tsv data into the tfrecord files
that will be consumed by BERT.
"""
import collections
import csv
import os
import re
import tensorflow as tf
import time
import tokenization


flags = tf.flags

FLAGS = flags.FLAGS


flags.DEFINE_string(
"tfrecord_folder", None,
"Folder where the tfrecord files will be writen.")

flags.DEFINE_string(
"vocab_file",
"./data/bert/uncased_L-24_H-1024_A-16/vocab.txt",
"The vocabulary file that the BERT model was trained on.")

flags.DEFINE_string(
"train_dataset_path",
"./data/triples.train.small.tsv",
"Path to the MSMARCO training dataset containing the tab separated "
"<query, positive_paragraph, negative_paragraph> tuples.")

flags.DEFINE_string(
"dev_dataset_path",
"./data/top1000.dev.tsv",
"Path to the MSMARCO training dataset containing the tab separated "
"<query, positive_paragraph, negative_paragraph> tuples.")

flags.DEFINE_string(
"eval_dataset_path",
"./data/top1000.eval.tsv",
"Path to the MSMARCO eval dataset containing the tab separated "
"<query, positive_paragraph, negative_paragraph> tuples.")

flags.DEFINE_string(
"dev_qrels_path",
"./data/qrels.dev.tsv",
"Path to the query_id relevant doc ids mapping.")

flags.DEFINE_integer(
"max_seq_length", 512,
"The maximum total input sequence length after WordPiece tokenization. "
"Sequences longer than this will be truncated, and sequences shorter "
"than this will be padded.")

flags.DEFINE_integer(
"max_query_length", 64,
"The maximum query sequence length after WordPiece tokenization. "
"Sequences longer than this will be truncated.")

flags.DEFINE_integer(
"num_eval_docs", 1000,
"The maximum number of docs per query for dev and eval sets.")


def write_to_tf_record(writer, tokenizer, query, docs, labels,
ids_file=None, query_id=None, doc_ids=None):
query = tokenization.convert_to_unicode(query)
query_token_ids = tokenization.convert_to_bert_input(
text=query, max_seq_length=FLAGS.max_query_length, tokenizer=tokenizer,
add_cls=True)

query_token_ids_tf = tf.train.Feature(
int64_list=tf.train.Int64List(value=query_token_ids))

for i, (doc_text, label) in enumerate(zip(docs, labels)):

doc_token_id = tokenization.convert_to_bert_input(
text=tokenization.convert_to_unicode(doc_text),
max_seq_length=FLAGS.max_seq_length - len(query_token_ids),
tokenizer=tokenizer,
add_cls=False)

doc_ids_tf = tf.train.Feature(
int64_list=tf.train.Int64List(value=doc_token_id))

labels_tf = tf.train.Feature(
int64_list=tf.train.Int64List(value=[label]))

features = tf.train.Features(feature={
'query_ids': query_token_ids_tf,
'doc_ids': doc_ids_tf,
'label': labels_tf,
})
example = tf.train.Example(features=features)
writer.write(example.SerializeToString())

if ids_file:
ids_file.write('\t'.join([query_id, doc_ids[i]]) + '\n')

def convert_eval_dataset(set_name, tokenizer):
print('Converting {} set to tfrecord...'.format(set_name))
start_time = time.time()

if set_name == 'dev':
dataset_path = FLAGS.dev_dataset_path
relevant_pairs = set()
with open(FLAGS.dev_qrels_path) as f:
for line in f:
query_id, _, doc_id, _ = line.strip().split('\t')
relevant_pairs.add('\t'.join([query_id, doc_id]))
else:
dataset_path = FLAGS.eval_dataset_path

queries_docs = collections.defaultdict(list)
query_ids = {}
with open(dataset_path, 'r') as f:
for i, line in enumerate(f):
query_id, doc_id, query, doc = line.strip().split('\t')
label = 0
if set_name == 'dev':
if '\t'.join([query_id, doc_id]) in relevant_pairs:
label = 1
queries_docs[query].append((doc_id, doc, label))
query_ids[query] = query_id

# Add fake paragraphs to the queries that have less than FLAGS.num_eval_docs.
queries = list(queries_docs.keys()) # Need to copy keys before iterating.
for query in queries:
docs = queries_docs[query]
docs += max(
0, FLAGS.num_eval_docs - len(docs)) * [('00000000', 'FAKE DOCUMENT', 0)]
queries_docs[query] = docs

assert len(
set(len(docs) == FLAGS.num_eval_docs for docs in queries_docs.values())) == 1, (
'Not all queries have {} docs'.format(FLAGS.num_eval_docs))

writer = tf.python_io.TFRecordWriter(
FLAGS.tfrecord_folder + '/dataset_' + set_name + '.tf')

query_doc_ids_path = (
FLAGS.tfrecord_folder + '/query_doc_ids_' + set_name + '.txt')
with open(query_doc_ids_path, 'w') as ids_file:
for i, (query, doc_ids_docs) in enumerate(queries_docs.items()):
doc_ids, docs, labels = zip(*doc_ids_docs)
query_id = query_ids[query]

write_to_tf_record(writer=writer,
tokenizer=tokenizer,
query=query,
docs=docs,
labels=labels,
ids_file=ids_file,
query_id=query_id,
doc_ids=doc_ids)

if i % 100 == 0:
print('Writing {} set, query {} of {}'.format(
set_name, i, len(queries_docs)))
time_passed = time.time() - start_time
hours_remaining = (
len(queries_docs) - i) * time_passed / (max(1.0, i) * 3600)
print('Estimated hours remaining to write the {} set: {}'.format(
set_name, hours_remaining))
writer.close()


def convert_train_dataset(tokenizer):
print('Converting to Train to tfrecord...')

start_time = time.time()

print('Counting number of examples...')
num_lines = sum(1 for line in open(FLAGS.train_dataset_path, 'r'))
print('{} examples found.'.format(num_lines))
writer = tf.python_io.TFRecordWriter(
FLAGS.tfrecord_folder + '/dataset_train.tf')

with open(FLAGS.train_dataset_path, 'r') as f:
for i, line in enumerate(f):
if i % 1000 == 0:
time_passed = int(time.time() - start_time)
print('Processed training set, line {} of {} in {} sec'.format(
i, num_lines, time_passed))
hours_remaining = (num_lines - i) * time_passed / (max(1.0, i) * 3600)
print('Estimated hours remaining to write the training set: {}'.format(
hours_remaining))

query, positive_doc, negative_doc = line.rstrip().split('\t')

write_to_tf_record(writer=writer,
tokenizer=tokenizer,
query=query,
docs=[positive_doc, negative_doc],
labels=[1, 0])

writer.close()


def main():

print('Loading Tokenizer...')
tokenizer = tokenization.FullTokenizer(
vocab_file=FLAGS.vocab_file, do_lower_case=True)

if not os.path.exists(FLAGS.tfrecord_folder):
os.mkdir(FLAGS.tfrecord_folder)

convert_train_dataset(tokenizer=tokenizer)
convert_eval_dataset(set_name='dev', tokenizer=tokenizer)
convert_eval_dataset(set_name='eval', tokenizer=tokenizer)
print('Done!')

if __name__ == '__main__':
main()

0 comments on commit e6fdc7f

Please sign in to comment.