Merge pull request #29 from huggingface/first-release
First release
thomwolf committed Nov 17, 2018
2 parents 02173a1 + 47a7d4e commit 4132a02
Showing 23 changed files with 6,941 additions and 1,128 deletions.
31 changes: 0 additions & 31 deletions CONTRIBUTING.md

This file was deleted.

401 changes: 326 additions & 75 deletions README.md

Large diffs are not rendered by default.

15 changes: 0 additions & 15 deletions __init__.py

This file was deleted.

2 changes: 2 additions & 0 deletions bin/pytorch_pretrained_bert
@@ -0,0 +1,2 @@
#!/bin/sh
python -m pytorch_pretrained_bert "$@"
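The new launcher above simply hands its arguments to the package's `__main__` module. A rough Python equivalent of what `python -m pytorch_pretrained_bert "$@"` does (illustrative only; the sub-commands that module accepts are not shown in this diff):

```python
import runpy
import sys

# Mimic `python -m pytorch_pretrained_bert <args...>`:
# keep the caller's arguments and execute the package's __main__ submodule.
sys.argv = ["pytorch_pretrained_bert"] + sys.argv[1:]
runpy.run_module("pytorch_pretrained_bert", run_name="__main__", alter_sys=True)
```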
33 changes: 10 additions & 23 deletions extract_features.py → examples/extract_features.py
@@ -19,18 +19,17 @@
from __future__ import print_function

import argparse
import codecs
import collections
import logging
import json
import re

import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

import tokenization
from modeling import BertConfig, BertModel
from pytorch_pretrained_bert.tokenization import convert_to_unicode, BertTokenizer
from pytorch_pretrained_bert.modeling import BertModel

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
@@ -171,7 +170,7 @@ def read_examples(input_file):
unique_id = 0
with open(input_file, "r") as reader:
while True:
line = tokenization.convert_to_unicode(reader.readline())
line = convert_to_unicode(reader.readline())
if not line:
break
line = line.strip()
@@ -194,23 +193,16 @@ def main():

## Required parameters
parser.add_argument("--input_file", default=None, type=str, required=True)
parser.add_argument("--vocab_file", default=None, type=str, required=True,
help="The vocabulary file that the BERT model was trained on.")
parser.add_argument("--output_file", default=None, type=str, required=True)
parser.add_argument("--bert_config_file", default=None, type=str, required=True,
help="The config json file corresponding to the pre-trained BERT model. "
"This specifies the model architecture.")
parser.add_argument("--init_checkpoint", default=None, type=str, required=True,
help="Initial checkpoint (usually from a pre-trained BERT model).")
parser.add_argument("--bert_model", default=None, type=str, required=True,
help="Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")

## Other parameters
parser.add_argument("--layers", default="-1,-2,-3,-4", type=str)
parser.add_argument("--max_seq_length", default=128, type=int,
help="The maximum total input sequence length after WordPiece tokenization. Sequences longer "
"than this will be truncated, and sequences shorter than this will be padded.")
parser.add_argument("--do_lower_case", default=True, action='store_true',
help="Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
parser.add_argument("--batch_size", default=32, type=int, help="Batch size for predictions.")
parser.add_argument("--local_rank",
type=int,
@@ -227,14 +219,11 @@ def main():
n_gpu = 1
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')
logger.info("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1))
logger.info("device: {} n_gpu: {} distributed training: {}".format(device, n_gpu, bool(args.local_rank != -1)))

layer_indexes = [int(x) for x in args.layers.split(",")]

bert_config = BertConfig.from_json_file(args.bert_config_file)

tokenizer = tokenization.FullTokenizer(
vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
tokenizer = BertTokenizer.from_pretrained(args.bert_model)

examples = read_examples(args.input_file)

@@ -245,9 +234,7 @@ def main():
for feature in features:
unique_id_to_feature[feature.unique_id] = feature

model = BertModel(bert_config)
if args.init_checkpoint is not None:
model.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
model = BertModel.from_pretrained(args.bert_model)
model.to(device)

if args.local_rank != -1:
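Taken together, the extract_features.py changes replace the explicit config/vocab/checkpoint arguments with the new `from_pretrained` loading path. A minimal sketch of that path (the model name is taken from the `--bert_model` choices above; the shape of the forward-pass outputs is assumed from the library at this release):

```python
import torch
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertModel

# Download/load tokenizer and weights by name instead of passing vocab/config/checkpoint files.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

tokens = ["[CLS]"] + tokenizer.tokenize("Hello world") + ["[SEP]"]
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokens)])

with torch.no_grad():
    # encoded_layers is assumed to be a list with one tensor per Transformer layer,
    # so its last four entries correspond to --layers "-1,-2,-3,-4" above.
    encoded_layers, pooled_output = model(input_ids)

print(len(encoded_layers), encoded_layers[-1].shape)
```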
90 changes: 33 additions & 57 deletions run_classifier.py → examples/run_classifier.py
@@ -30,9 +30,9 @@
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

import tokenization
from modeling import BertConfig, BertForSequenceClassification
from optimization import BERTAdam
from pytorch_pretrained_bert.tokenization import printable_text, convert_to_unicode, BertTokenizer
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
from pytorch_pretrained_bert.optimization import BertAdam

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
@@ -122,9 +122,9 @@ def _create_examples(self, lines, set_type):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[3])
text_b = tokenization.convert_to_unicode(line[4])
label = tokenization.convert_to_unicode(line[0])
text_a = convert_to_unicode(line[3])
text_b = convert_to_unicode(line[4])
label = convert_to_unicode(line[0])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
@@ -154,14 +154,14 @@ def _create_examples(self, lines, set_type):
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0]))
text_a = tokenization.convert_to_unicode(line[8])
text_b = tokenization.convert_to_unicode(line[9])
label = tokenization.convert_to_unicode(line[-1])
guid = "%s-%s" % (set_type, convert_to_unicode(line[0]))
text_a = convert_to_unicode(line[8])
text_b = convert_to_unicode(line[9])
label = convert_to_unicode(line[-1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples


class ColaProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version)."""
@@ -185,8 +185,8 @@ def _create_examples(self, lines, set_type):
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[3])
label = tokenization.convert_to_unicode(line[1])
text_a = convert_to_unicode(line[3])
label = convert_to_unicode(line[1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
@@ -273,19 +273,18 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
logger.info("*** Example ***")
logger.info("guid: %s" % (example.guid))
logger.info("tokens: %s" % " ".join(
[tokenization.printable_text(x) for x in tokens]))
[printable_text(x) for x in tokens]))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
logger.info(
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
logger.info("label: %s (id = %d)" % (example.label, label_id))

features.append(
InputFeatures(
input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
label_id=label_id))
InputFeatures(input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
label_id=label_id))
return features


@@ -307,7 +306,7 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):

def accuracy(out, labels):
outputs = np.argmax(out, axis=1)
return np.sum(outputs==labels)
return np.sum(outputs == labels)

def copy_optimizer_params_to_model(named_params_model, named_params_optimizer):
""" Utility function for optimize_on_cpu and 16-bits training.
@@ -328,11 +327,14 @@ def set_optimizer_params_grad(named_params_optimizer, named_params_model, test_n
if name_opti != name_model:
logger.error("name_opti != name_model: {} {}".format(name_opti, name_model))
raise ValueError
if test_nan and torch.isnan(param_model.grad).sum() > 0:
is_nan = True
if param_opti.grad is None:
param_opti.grad = torch.nn.Parameter(param_opti.data.new().resize_(*param_opti.data.size()))
param_opti.grad.data.copy_(param_model.grad.data)
if param_model.grad is not None:
if test_nan and torch.isnan(param_model.grad).sum() > 0:
is_nan = True
if param_opti.grad is None:
param_opti.grad = torch.nn.Parameter(param_opti.data.new().resize_(*param_opti.data.size()))
param_opti.grad.data.copy_(param_model.grad.data)
else:
param_opti.grad = None
return is_nan

def main():
@@ -344,37 +346,21 @@ def main():
type=str,
required=True,
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
parser.add_argument("--bert_config_file",
default=None,
type=str,
required=True,
help="The config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture.")
parser.add_argument("--bert_model", default=None, type=str, required=True,
help="Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
parser.add_argument("--task_name",
default=None,
type=str,
required=True,
help="The name of the task to train.")
parser.add_argument("--vocab_file",
default=None,
type=str,
required=True,
help="The vocabulary file that the BERT model was trained on.")
parser.add_argument("--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the model checkpoints will be written.")

## Other parameters
parser.add_argument("--init_checkpoint",
default=None,
type=str,
help="Initial checkpoint (usually from a pre-trained BERT model).")
parser.add_argument("--do_lower_case",
default=False,
action='store_true',
help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument("--max_seq_length",
default=128,
type=int,
@@ -478,13 +464,6 @@ def main():
if not args.do_train and not args.do_eval:
raise ValueError("At least one of `do_train` or `do_eval` must be True.")

bert_config = BertConfig.from_json_file(args.bert_config_file)

if args.max_seq_length > bert_config.max_position_embeddings:
raise ValueError(
"Cannot use sequence length {} because the BERT model was only trained up to sequence length {}".format(
args.max_seq_length, bert_config.max_position_embeddings))

if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
os.makedirs(args.output_dir, exist_ok=True)
@@ -497,8 +476,7 @@ def main():
processor = processors[task_name]()
label_list = processor.get_labels()

tokenizer = tokenization.FullTokenizer(
vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
tokenizer = BertTokenizer.from_pretrained(args.bert_model)

train_examples = None
num_train_steps = None
@@ -508,9 +486,7 @@ def main():
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

# Prepare model
model = BertForSequenceClassification(bert_config, len(label_list))
if args.init_checkpoint is not None:
model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
model = BertForSequenceClassification.from_pretrained(args.bert_model, len(label_list))
if args.fp16:
model.half()
model.to(device)
@@ -534,7 +510,7 @@ def main():
{'params': [p for n, p in param_optimizer if n not in no_decay], 'weight_decay_rate': 0.01},
{'params': [p for n, p in param_optimizer if n in no_decay], 'weight_decay_rate': 0.0}
]
optimizer = BERTAdam(optimizer_grouped_parameters,
optimizer = BertAdam(optimizer_grouped_parameters,
lr=args.learning_rate,
warmup=args.warmup_proportion,
t_total=num_train_steps)
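The run_classifier.py changes above switch the example to `BertForSequenceClassification.from_pretrained` and the renamed `BertAdam` optimizer. A minimal sketch of that setup, following the calls shown in this diff (the `no_decay` name list, the hyperparameter values, and building `param_optimizer` from `model.named_parameters()` are illustrative assumptions; the full script is not reproduced here):

```python
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
from pytorch_pretrained_bert.optimization import BertAdam

# Load a pre-trained classifier head by name; the second argument is the number of labels,
# matching from_pretrained(args.bert_model, len(label_list)) in the diff.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", 2)

param_optimizer = list(model.named_parameters())          # assumed, as in the full script
no_decay = ["bias", "gamma", "beta"]                       # assumed: names exempt from weight decay
optimizer_grouped_parameters = [
    {"params": [p for n, p in param_optimizer if n not in no_decay], "weight_decay_rate": 0.01},
    {"params": [p for n, p in param_optimizer if n in no_decay], "weight_decay_rate": 0.0},
]

optimizer = BertAdam(optimizer_grouped_parameters,
                     lr=5e-5,        # illustrative learning rate
                     warmup=0.1,     # fraction of steps used for linear LR warm-up
                     t_total=1000)   # illustrative total number of optimization steps
```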
