
Commit

Prefix TPU specific comments.
jysohn23 committed Apr 9, 2020
1 parent 306851c commit 1eb47c5
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions examples/run_glue_tpu.py
@@ -144,7 +144,7 @@ def train(args, train_dataset, model, tokenizer, disable_logging=False):
 train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=disable_logging)
 set_seed(args.seed) # Added here for reproductibility (even between python 2 and 3)
 for epoch in train_iterator:
-# Get TPU parallel loader which sends data to TPU in background.
+# tpu-comment: Get TPU parallel loader which sends data to TPU in background.
 train_dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
 epoch_iterator = tqdm(train_dataloader, desc="Iteration", total=len(dataloader), disable=disable_logging)
 for step, batch in enumerate(epoch_iterator):
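
For context on the API this hunk touches: torch_xla's ParallelLoader wraps an ordinary PyTorch DataLoader and copies batches to the TPU in a background thread, and per_device_loader yields batches that are already on the XLA device. A minimal standalone sketch of that pattern (the toy dataset and batch size below are placeholders, not taken from run_glue_tpu.py):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl

    device = xm.xla_device()  # XLA device for the current process (one TPU core)

    # Placeholder data; the real script builds a GLUE dataset.
    dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
    dataloader = DataLoader(dataset, batch_size=16)

    # ParallelLoader prefetches batches onto the device in the background;
    # per_device_loader yields batches that already live on `device`.
    loader = pl.ParallelLoader(dataloader, [device]).per_device_loader(device)
    for step, (features, labels) in enumerate(loader):
        pass  # `features` and `labels` arrive here as XLA tensors
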
@@ -197,7 +197,7 @@ def train(args, train_dataset, model, tokenizer, disable_logging=False):
 )
 )
 if xm.is_master_ordinal():
-# All values must be in CPU and not on TPU device
+# tpu-comment: All values must be in CPU and not on TPU device
 for key, value in results.items():
 tb_writer.add_scalar("eval_{}".format(key), value, global_step)
 tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
@@ -207,6 +207,7 @@ def train(args, train_dataset, model, tokenizer, disable_logging=False):
 epoch_iterator.close()
 break
 if args.metrics_debug:
+# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
 xm.master_print(met.metrics_report())
 if args.max_steps > 0 and global_step > args.max_steps:
 train_iterator.close()
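
The new comment points at PyTorch/XLA's debug metrics: met.metrics_report() returns a text report of counters and timings (graph compilations, executions, aten ops that fell back to CPU), and xm.master_print emits it from the master process only. A minimal sketch, with a hypothetical wrapper name:

    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met

    def maybe_dump_xla_metrics(enabled):
        # Print the PyTorch/XLA metrics report (compile/execute times, op counters).
        if enabled:
            xm.master_print(met.metrics_report())
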
@@ -235,7 +236,6 @@ def evaluate(args, model, tokenizer, prefix="", disable_logging=False):
 if not os.path.exists(eval_output_dir):
 os.makedirs(eval_output_dir)

-# Note that we don't shard for TPU Multiprocess as we don't reduce loss among client processes.
 dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, shuffle=False)
 eval_dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

@@ -267,7 +267,7 @@ def evaluate(args, model, tokenizer, prefix="", disable_logging=False):
 preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
 out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)

-# Get all predictions and labels from all workers
+# tpu-comment: Get all predictions and labels from all worker shards of eval dataset
 preds = xm.mesh_reduce("eval_preds", preds, np.concatenate)
 out_label_ids = xm.mesh_reduce("eval_out_label_ids", out_label_ids, np.concatenate)
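
The reworded comment reflects how xm.mesh_reduce works: it gathers one value per TPU process under the given tag and applies the reduce function to the collected list, so np.concatenate reassembles the full eval set from the per-worker shards. A sketch with synthetic shard arrays:

    import numpy as np
    import torch_xla.core.xla_model as xm

    # Each process holds only its own shard of predictions and labels.
    preds_shard = np.zeros((10, 2), dtype=np.float32)  # placeholder logits
    labels_shard = np.zeros((10,), dtype=np.int64)     # placeholder label ids

    # mesh_reduce collects one array per process and hands the list to the
    # reduce fn, so np.concatenate stitches the shards back together.
    all_preds = xm.mesh_reduce("eval_preds", preds_shard, np.concatenate)
    all_labels = xm.mesh_reduce("eval_out_label_ids", labels_shard, np.concatenate)
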

@@ -290,6 +290,7 @@ def evaluate(args, model, tokenizer, prefix="", disable_logging=False):
 tb_writer.add_scalar(key, results[key])

 if args.metrics_debug:
+# tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
 xm.master_print(met.metrics_report())

 if xm.is_master_ordinal():
@@ -369,7 +370,7 @@ def main(args):
 ).format(args.output_dir)
 )

-# Get TPU/XLA Device
+# tpu-comment: Get TPU/XLA Device
 args.device = xm.xla_device()

 # Setup logging
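
The main() hunk picks up the device with xm.xla_device(); in a typical PyTorch/XLA multiprocessing launch, each worker spawned via xla_multiprocessing sees its own TPU core. A hedged sketch of that setup (the worker body and nprocs value below are illustrative, not how this script is launched):

    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp

    def _mp_fn(index):
        # Each spawned process gets its own XLA device (one TPU core).
        device = xm.xla_device()
        xm.master_print("process {} uses device {}".format(index, device))

    if __name__ == "__main__":
        xmp.spawn(_mp_fn, args=(), nprocs=8)  # nprocs=8 assumes a full v2/v3 TPU
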
