Skip to content

Commit

Permalink
adding tests to examples - updating summary module - coverage update
Browse files Browse the repository at this point in the history
  • Loading branch information
thomwolf committed Jul 9, 2019
1 parent c079d7d commit d5481cb
Show file tree
Hide file tree
Showing 17 changed files with 139 additions and 116 deletions.
3 changes: 3 additions & 0 deletions .coveragerc
@@ -1,5 +1,8 @@
[run]
source=pytorch_transformers
omit =
# skip convertion scripts from testing for now
*/convert_*
[report]
exclude_lines =
pragma: no cover
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -126,4 +126,5 @@ models
proc_data

# examples
runs
examples/runs
124 changes: 64 additions & 60 deletions examples/run_glue.py
Expand Up @@ -60,25 +60,14 @@
'xlm': XLMTokenizer,
}

def train(args, train_features, model):
def train(args, train_dataset, model):
""" Train the model """
if args.local_rank in [-1, 0]:
tb_writer = SummaryWriter()

# Convert in tensors and build dataloader
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
if args.output_mode == "classification":
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
elif args.output_mode == "regression":
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.float)

args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
train_sampler = RandomSampler(train_data) if args.local_rank == -1 else DistributedSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

Expand Down Expand Up @@ -109,19 +98,24 @@ def train(args, train_features, model):

# Train!
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_features))
logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(" Batch size = %d", args.train_batch_size)
logger.info(" Num steps = %d", num_train_optimization_steps)
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", num_train_optimization_steps)

global_step = 0
tr_loss = 0
model.train()
optimizer.zero_grad()
for _ in trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]):
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
batch = tuple(t.to(args.device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch

ouputs = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids)
inputs = {'input_ids': batch[0],
'attention_mask': batch[1],
'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,
'labels': batch[3]}
ouputs = model(**inputs)
loss = ouputs[0]

if args.n_gpu > 1:
Expand Down Expand Up @@ -150,30 +144,20 @@ def train(args, train_features, model):
return global_step, tr_loss / global_step


def evalutate(args, eval_task, eval_output_dir, eval_features, model):
def evalutate(args, eval_task, eval_output_dir, dataset, model):
""" Evaluate the model """
if os.path.exists(eval_output_dir) and os.listdir(eval_output_dir) and args.do_train and not args.overwrite_output_dir:
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(eval_output_dir))
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
os.makedirs(eval_output_dir)

# Convert in tensors and build dataloader
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
if args.output_mode == "classification":
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
elif args.output_mode == "regression":
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.float)

eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
# Note that DistributedSampler samples randomly
eval_sampler = SequentialSampler(eval_data) if args.local_rank == -1 else DistributedSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset)
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

# Eval!
logger.info("***** Running evaluation *****")
logger.info(" Num examples = %d", len(eval_features))
logger.info(" Num examples = %d", len(dataset))
logger.info(" Batch size = %d", args.eval_batch_size)
model.eval()
eval_loss = 0
Expand Down Expand Up @@ -214,36 +198,47 @@ def evalutate(args, eval_task, eval_output_dir, eval_features, model):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))

return result

def load_and_cache_examples(args, task, tokenizer, eval=False):
processor = processors[task]()
output_mode = output_modes[task]
label_list = processor.get_labels()

# Load and cache data
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
processor = processors[task]()
examples = processor.get_dev_examples(args.data_dir)
cached_features_file = os.path.join(args.data_dir, '{}_{}_{}_{}'.format(
'dev' if eval else 'train',
output_mode = output_modes[task]
# Load data features from cache or dataset file
cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}'.format(
'dev' if evaluate else 'train',
list(filter(None, args.model_name.split('/'))).pop(),
str(args.max_seq_length),
str(task)))

if os.path.exists(cached_features_file):
logger.info("Loading features from cached file %s", cached_features_file)
features = torch.load(cached_features_file)
else:
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode)
logger.info("Creating features from dataset file at %s", args.data_dir)
label_list = processor.get_labels()
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer, output_mode,
cls_token_at_end=bool(args.model_type not in ['bert', 'xlm']),
cls_token_at_end=bool(args.model_type in ['xlnet']), # xlnet has a cls token at the end
cls_token=tokenizer.cls_token,
sep_token=tokenizer.sep_token, cls_token_segment_id=2,
pad_on_left=True, pad_token_segment_id=4)
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
sep_token=tokenizer.sep_token,
cls_token_segment_id=2 if args.model_type in ['xlnet'] else 1,
pad_on_left=bool(args.model_type in ['xlnet']), # pad on the left for xlnet
pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0)
if args.local_rank in [-1, 0]:
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)

return features
# Convert to Tensors and build dataset
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
if output_mode == "classification":
all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
elif output_mode == "regression":
all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float)

dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
return dataset


def main():
Expand Down Expand Up @@ -350,10 +345,10 @@ def main():
torch.distributed.barrier()

args.model_type = args.model_name.lower().split('-')[0]
args.tokenizer_class = TOKENIZER_CLASSES[args.model_type]
args.model_class = MODEL_CLASSES[args.model_type]
tokenizer = args.tokenizer_class.from_pretrained(args.model_name, do_lower_case=args.do_lower_case)
model = args.model_class.from_pretrained(args.model_name, num_labels=num_labels)
tokenizer_class = TOKENIZER_CLASSES[args.model_type]
model_class = MODEL_CLASSES[args.model_type]
tokenizer = tokenizer_class.from_pretrained(args.model_name, do_lower_case=args.do_lower_case)
model = model_class.from_pretrained(args.model_name, num_labels=num_labels)

if args.local_rank == 0:
torch.distributed.barrier()
Expand All @@ -372,23 +367,30 @@ def main():

# Training
if args.do_train:
train_features = load_and_cache_examples(args, args.task_name, tokenizer, eval=False)
global_step, tr_loss = train(args, train_features, model)
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
global_step, tr_loss = train(args, train_dataset, model)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)


# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
# Save a trained model, configuration and tokenizer
# Create output directory if needed
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
os.makedirs(args.output_dir)

# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
model.save_pretrained(args.output_dir)
tokenizer.save_vocabulary(args.output_dir)
tokenizer.save_pretrained(args.output_dir)

# Good practice: save your training arguments together with the trained model
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))

# Load a trained model and vocabulary that you have fine-tuned
model = args.model_class.from_pretrained(args.output_dir)
tokenizer = args.tokenizer_class.from_pretrained(args.output_dir)
model = model_class.from_pretrained(args.output_dir)
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
model.to(args.device)

# Evaluation
Expand All @@ -398,9 +400,11 @@ def main():
eval_outputs_dirs = (args.output_dir, args.output_dir + '-MM') if args.task_name == "mnli" else (args.output_dir,)

for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
eval_features = load_and_cache_examples(args, eval_task, tokenizer, eval=True)
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)

result = evalutate(args, eval_task, eval_output_dir, eval_dataset, model)

evalutate(args, eval_task, eval_output_dir, eval_features, model)
return result


if __name__ == "__main__":
Expand Down
25 changes: 18 additions & 7 deletions examples/test_examples.py
Expand Up @@ -19,14 +19,19 @@
import sys
import unittest
import argparse
import logging

try:
# python 3.4+ can use builtin unittest.mock instead of mock package
from unittest.mock import patch
except ImportError:
from mock import patch

import run_bert_squad as rbs
import run_glue

logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()

def get_setup_file():
parser = argparse.ArgumentParser()
Expand All @@ -36,12 +41,18 @@ def get_setup_file():

class ExamplesTests(unittest.TestCase):

def test_run_squad(self):
testargs = ["prog", "-f", "/home/test/setup.py"]
with patch.object(sys, 'argv', testargs):
setup = get_setup_file()
assert setup == "/home/test/setup.py"
# rbs.main()
def test_run_glue(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)

testargs = ["run_glue.py", "--data_dir=./examples/tests_samples/MRPC/",
"--task_name=mrpc", "--do_train", "--do_eval", "--output_dir=./examples/tests_samples/temp_dir",
"--train_batch_size=4", "--eval_batch_size=2", "--num_train_epochs=2.0", "--overwrite_output_dir"]
model_name = "--model_name=xlnet-large-cased"
with patch.object(sys, 'argv', testargs + [model_name]):
result = run_glue.main()
for value in result.values():
self.assertGreaterEqual(value, 0.75)


if __name__ == "__main__":
Expand Down
5 changes: 5 additions & 0 deletions examples/tests_samples/.gitignore
@@ -0,0 +1,5 @@
*.*
cache*
temp*
!*.tsv
!.gitignore
7 changes: 7 additions & 0 deletions examples/tests_samples/MRPC/dev.tsv
@@ -0,0 +1,7 @@
Quality #1 ID #2 ID #1 String #2 String
1 1355540 1355592 He said the foodservice pie business doesn 't fit the company 's long-term growth strategy . " The foodservice pie business does not fit our long-term growth strategy .
0 2029631 2029565 Magnarelli said Racicot hated the Iraqi regime and looked forward to using his long years of training in the war . His wife said he was " 100 percent behind George Bush " and looked forward to using his years of training in the war .
0 487993 487952 The dollar was at 116.92 yen against the yen , flat on the session , and at 1.2891 against the Swiss franc , also flat . The dollar was at 116.78 yen JPY = , virtually flat on the session , and at 1.2871 against the Swiss franc CHF = , down 0.1 percent .
1 1989515 1989458 The AFL-CIO is waiting until October to decide if it will endorse a candidate . The AFL-CIO announced Wednesday that it will decide in October whether to endorse a candidate before the primaries .
0 1783137 1782659 No dates have been set for the civil or the criminal trial . No dates have been set for the criminal or civil cases , but Shanley has pleaded not guilty .
1 3039165 3039036 Wal-Mart said it would check all of its million-plus domestic workers to ensure they were legally employed . It has also said it would review all of its domestic employees more than 1 million to ensure they have legal status .
7 changes: 7 additions & 0 deletions examples/tests_samples/MRPC/train.tsv
@@ -0,0 +1,7 @@
Quality #1 ID #2 ID #1 String #2 String
1 1355540 1355592 He said the foodservice pie business doesn 't fit the company 's long-term growth strategy . " The foodservice pie business does not fit our long-term growth strategy .
0 2029631 2029565 Magnarelli said Racicot hated the Iraqi regime and looked forward to using his long years of training in the war . His wife said he was " 100 percent behind George Bush " and looked forward to using his years of training in the war .
0 487993 487952 The dollar was at 116.92 yen against the yen , flat on the session , and at 1.2891 against the Swiss franc , also flat . The dollar was at 116.78 yen JPY = , virtually flat on the session , and at 1.2871 against the Swiss franc CHF = , down 0.1 percent .
1 1989515 1989458 The AFL-CIO is waiting until October to decide if it will endorse a candidate . The AFL-CIO announced Wednesday that it will decide in October whether to endorse a candidate before the primaries .
0 1783137 1782659 No dates have been set for the civil or the criminal trial . No dates have been set for the criminal or civil cases , but Shanley has pleaded not guilty .
1 3039165 3039036 Wal-Mart said it would check all of its million-plus domestic workers to ensure they were legally employed . It has also said it would review all of its domestic employees more than 1 million to ensure they have legal status .
1 change: 0 additions & 1 deletion pytorch_transformers/modeling_bert.py
Expand Up @@ -28,7 +28,6 @@
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

from .file_utils import cached_path
from .modeling_utils import WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel, prune_linear_layer

logger = logging.getLogger(__name__)
Expand Down
7 changes: 2 additions & 5 deletions pytorch_transformers/modeling_gpt2.py
Expand Up @@ -30,7 +30,6 @@
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter

from .file_utils import cached_path
from .modeling_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
PreTrainedModel, prune_conv1d_layer, SequenceSummary)
from .modeling_bert import BertLayerNorm as LayerNorm
Expand Down Expand Up @@ -122,9 +121,8 @@ def __init__(
predict_special_tokens=True,
summary_type='token_ids',
summary_use_proj=True,
summary_num_classes=1,
summary_activation=None,
summary_dropout=0.1,
summary_first_dropout=0.1,
**kwargs
):
"""Constructs GPT2Config.
Expand Down Expand Up @@ -172,9 +170,8 @@ def __init__(
self.predict_special_tokens = predict_special_tokens
self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
self.summary_num_classes = summary_num_classes
self.summary_activation = summary_activation
self.summary_dropout = summary_dropout
self.summary_first_dropout = summary_first_dropout
else:
raise ValueError(
"First argument must be either a vocabulary size (int)"
Expand Down
9 changes: 3 additions & 6 deletions pytorch_transformers/modeling_openai.py
Expand Up @@ -30,9 +30,8 @@
from torch.nn import CrossEntropyLoss
from torch.nn.parameter import Parameter

from .file_utils import cached_path
from .modeling_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig,
PreTrainedModel, prune_conv1d_layer, SequenceSummary)
PreTrainedModel, prune_conv1d_layer, SequenceSummary)
from .modeling_bert import BertLayerNorm as LayerNorm

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -150,9 +149,8 @@ def __init__(
predict_special_tokens=True,
summary_type='token_ids',
summary_use_proj=True,
summary_num_classes=1,
summary_activation=None,
summary_dropout=0.1,
summary_first_dropout=0.1,
**kwargs
):
"""Constructs OpenAIGPTConfig.
Expand Down Expand Up @@ -203,9 +201,8 @@ def __init__(
self.predict_special_tokens = predict_special_tokens
self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
self.summary_num_classes = summary_num_classes
self.summary_activation = summary_activation
self.summary_dropout = summary_dropout
self.summary_first_dropout = summary_first_dropout
else:
raise ValueError(
"First argument must be either a vocabulary size (int)"
Expand Down
1 change: 0 additions & 1 deletion pytorch_transformers/modeling_transfo_xl.py
Expand Up @@ -36,7 +36,6 @@

from .modeling_bert import BertLayerNorm as LayerNorm
from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits
from .file_utils import cached_path
from .modeling_utils import CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel

logger = logging.getLogger(__name__)
Expand Down

0 comments on commit d5481cb

Please sign in to comment.