Commit

repeat
zheyuye committed Jun 29, 2020
1 parent aea936f commit 8ee381b
Showing 3 changed files with 82 additions and 86 deletions.
13 changes: 4 additions & 9 deletions scripts/pretraining/run_electra.py
@@ -382,13 +382,13 @@ def train(args):
if args.num_accumulated != 1:
# set grad to zero for gradient accumulation
model.collect_params().zero_grad()
while not finish_flag:
for update_count, batch_data in enumerate(grouper(repeat(train_dataloader), len(ctx_l) * args.num_accumulated)):
tic = time.time()
batch_id = 0
is_last_batch = False
train_dataloader = grouper(data_train, len(ctx_l))
sample_l = next(train_dataloader)
while not is_last_batch:

for sample_l in grouper(batch_data, len(ctx_l)):
loss_l = []
mlm_loss_l = []
rtd_loss_l = []
@@ -438,14 +438,9 @@ def train(args):
for ele in rtd_loss_l]).asnumpy()
log_total_loss += sum([ele.as_in_ctx(ctx_l[0])
for ele in loss_l]).asnumpy() * loss_denom
# pre fetch next batch
try:
sample_l = next(train_dataloader)
except StopIteration:
is_last_batch = True

# update
if (batch_id + 1) % args.num_accumulated == 0 or is_last_batch:
if (batch_id + 1) % args.num_accumulated == 0:
trainer.allreduce_grads()
# Here, the accumulated gradients are
# \sum_{n=1}^N g_n / loss_denom
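To make the new batching pattern concrete: repeat turns the dataloader into an endless stream, the outer grouper collects one optimizer update's worth of mini-batches (len(ctx_l) * num_accumulated of them), and the inner grouper splits each chunk across devices. Below is a minimal sketch of that control flow, not repository code: the dataloader, device list, and step budget are made-up stand-ins, and it assumes this branch of GluonNLP is installed so the two helpers can be imported. The same structure reappears in run_squad.py further down.

from gluonnlp.utils.misc import grouper, repeat

# Toy stand-ins (hypothetical): two "devices", three accumulation steps,
# and a "dataloader" that yields five mini-batches per pass.
ctx_l = ['gpu(0)', 'gpu(1)']
num_accumulated = 3
num_train_steps = 2
train_dataloader = ['batch-{}'.format(i) for i in range(5)]

step_num = 0
for update_count, batch_data in enumerate(
        grouper(repeat(train_dataloader), len(ctx_l) * num_accumulated)):
    # batch_data holds len(ctx_l) * num_accumulated mini-batches: the data for
    # exactly one optimizer update, regardless of epoch boundaries.
    for sample_l in grouper(batch_data, len(ctx_l)):
        # sample_l holds one mini-batch per device.
        for sample, ctx in zip(sample_l, ctx_l):
            pass  # forward/backward for sample on device ctx would go here
    step_num += 1
    if step_num >= num_train_steps:  # mirrors the num_train_steps stop condition
        break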
146 changes: 69 additions & 77 deletions scripts/question_answering/run_squad.py
@@ -21,7 +21,7 @@
from eval_utils import squad_eval
from squad_utils import SquadFeature, get_squad_examples, convert_squad_example_to_feature
from gluonnlp.models import get_backbone
from gluonnlp.utils.misc import grouper, set_seed, parse_ctx, logging_config, count_parameters
from gluonnlp.utils.misc import grouper, repeat, set_seed, parse_ctx, logging_config, count_parameters
from gluonnlp.initializer import TruncNorm
from gluonnlp.utils.parameter import clip_grad_global_norm

@@ -531,7 +531,6 @@ def train(args):
update_on_kvstore=False)
step_num = 0
finish_flag = False
epoch_id = 0
num_samples_per_update = 0
loss_denom = float(len(ctx_l) * args.num_accumulated)

@@ -542,23 +541,21 @@ def train(args):
if args.num_accumulated != 1:
# set grad to zero for gradient accumulation
qa_net.collect_params().zero_grad()

# start training
global_tic = time.time()
while not finish_flag:
epoch_tic = time.time()
tic = time.time()
epoch_sample_num = 0
for batch_id, sample_l in enumerate(grouper(train_dataloader, len(ctx_l))):
loss_l = []
span_loss_l = []
answerable_loss_l = []
is_last_batch = (batch_id == epoch_size - 1)
tic = time.time()
for update_count, batch_data in enumerate(grouper(repeat(train_dataloader), len(ctx_l) * args.num_accumulated)):
loss_l = []
span_loss_l = []
answerable_loss_l = []
for sample_l in grouper(batch_data, len(ctx_l)):
for sample, ctx in zip(sample_l, ctx_l):
if sample is None:
continue
# Copy the data to device
tokens = sample.data.as_in_ctx(ctx)
log_sample_num += len(tokens)
epoch_sample_num += len(tokens)
num_samples_per_update += len(tokens)
segment_ids = sample.segment_ids.as_in_ctx(ctx) if use_segmentation else None
valid_length = sample.valid_length.as_in_ctx(ctx)
@@ -589,70 +586,65 @@ def train(args):
for ele in loss_l]).asnumpy() * loss_denom
log_answerable_loss += sum([ele.as_in_ctx(ctx_l[0])
for ele in answerable_loss_l]).asnumpy()
# update
if (batch_id + 1) % args.num_accumulated == 0 or is_last_batch:
trainer.allreduce_grads()
# Here, the accumulated gradients are
# \sum_{n=1}^N g_n / loss_denom
# Thus, in order to clip the average gradient
# \frac{1}{N} \sum_{n=1}^N --> clip to args.max_grad_norm
# We need to change the ratio to be
# \sum_{n=1}^N g_n / loss_denom --> clip to args.max_grad_norm * N / loss_denom
total_norm, ratio, is_finite = clip_grad_global_norm(
params, args.max_grad_norm * num_samples_per_update / loss_denom)
total_norm = total_norm / (num_samples_per_update / loss_denom)

trainer.update(num_samples_per_update / loss_denom, ignore_stale_grad=True)
step_num += 1
if args.num_accumulated != 1:
# set grad to zero for gradient accumulation
qa_net.collect_params().zero_grad()

# saving
if step_num % save_interval == 0 or step_num >= num_train_steps:
version_prefix = 'squad' + args.version
ckpt_name = '{}_{}_{}.params'.format(args.model_name,
version_prefix,
step_num)
params_saved = os.path.join(args.output_dir, ckpt_name)
qa_net.save_parameters(params_saved)
ckpt_candidates = [
f for f in os.listdir(
args.output_dir) if f.endswith('.params')]
# keep last 10 checkpoints
if len(ckpt_candidates) > args.max_saved_ckpt:
ckpt_candidates.sort(key=lambda ele: (len(ele), ele))
os.remove(os.path.join(args.output_dir, ckpt_candidates[0]))
logging.info('Params saved in: {}'.format(params_saved))

# logging
if step_num % log_interval == 0:
log_span_loss /= log_sample_num
log_answerable_loss /= log_sample_num
log_total_loss /= log_sample_num
toc = time.time()
logging.info(
'Epoch: {}, Batch: {}/{}, Loss span/answer/total={:.4f}/{:.4f}/{:.4f},'
' LR={:.8f}, grad_norm={:.4f}. Time cost={:.2f}, Throughput={:.2f} samples/s'
' ETA={:.2f}h'.format(epoch_id + 1, batch_id + 1, epoch_size, log_span_loss,
log_answerable_loss, log_total_loss, trainer.learning_rate, total_norm,
toc - tic, log_sample_num / (toc - tic),
(num_train_steps - step_num) / (step_num / (toc - global_tic)) / 3600))
tic = time.time()
log_span_loss = 0
log_answerable_loss = 0
log_total_loss = 0
log_sample_num = 0
num_samples_per_update = 0

if step_num >= num_train_steps:
logging.info('Finish training step: %d', step_num)
finish_flag = True
break
logging.info('Epoch: {}, #Samples: {}, Throughput={:.2f} samples/s'
.format(epoch_id + 1, epoch_sample_num,
epoch_sample_num / (time.time() - epoch_tic)))
epoch_id += 1
# update
trainer.allreduce_grads()
# Here, the accumulated gradients are
# \sum_{n=1}^N g_n / loss_denom
# Thus, in order to clip the average gradient
# \frac{1}{N} \sum_{n=1}^N --> clip to args.max_grad_norm
# We need to change the ratio to be
# \sum_{n=1}^N g_n / loss_denom --> clip to args.max_grad_norm * N / loss_denom
total_norm, ratio, is_finite = clip_grad_global_norm(
params, args.max_grad_norm * num_samples_per_update / loss_denom)
total_norm = total_norm / (num_samples_per_update / loss_denom)

trainer.update(num_samples_per_update / loss_denom, ignore_stale_grad=True)
step_num += 1
if args.num_accumulated != 1:
# set grad to zero for gradient accumulation
qa_net.collect_params().zero_grad()

# saving
if step_num % save_interval == 0 or step_num >= num_train_steps:
version_prefix = 'squad' + args.version
ckpt_name = '{}_{}_{}.params'.format(args.model_name,
version_prefix,
step_num)
params_saved = os.path.join(args.output_dir, ckpt_name)
qa_net.save_parameters(params_saved)
ckpt_candidates = [
f for f in os.listdir(
args.output_dir) if f.endswith('.params')]
# keep last 10 checkpoints
if len(ckpt_candidates) > args.max_saved_ckpt:
ckpt_candidates.sort(key=lambda ele: (len(ele), ele))
os.remove(os.path.join(args.output_dir, ckpt_candidates[0]))
logging.info('Params saved in: {}'.format(params_saved))

# logging
if step_num % log_interval == 0:
log_span_loss /= log_sample_num
log_answerable_loss /= log_sample_num
log_total_loss /= log_sample_num
toc = time.time()
logging.info(
'Batch: {}/{}, Loss span/answer/total={:.4f}/{:.4f}/{:.4f},'
' LR={:.8f}, grad_norm={:.4f}. Time cost={:.2f}, Throughput={:.2f} samples/s'
' ETA={:.2f}h'.format(update_count + 1, epoch_size, log_span_loss,
log_answerable_loss, log_total_loss, trainer.learning_rate, total_norm,
toc - tic, log_sample_num / (toc - tic),
(num_train_steps - step_num) / (step_num / (toc - global_tic)) / 3600))
tic = time.time()
log_span_loss = 0
log_answerable_loss = 0
log_total_loss = 0
log_sample_num = 0
num_samples_per_update = 0

if step_num >= num_train_steps:
logging.info('Finish training step: %d', step_num)
break

return params_saved
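
The comment block above the clip_grad_global_norm call derives why the clipping threshold is max_grad_norm * N / loss_denom rather than max_grad_norm itself. A small self-contained check of that arithmetic with made-up numbers (plain NumPy, not repository code):

import numpy as np

max_grad_norm = 1.0
loss_denom = 8.0                       # len(ctx_l) * num_accumulated in the script
g = [np.array([3.0, 4.0]), np.array([6.0, 8.0])]   # per-sample gradients g_n
N = len(g)                             # num_samples_per_update

stored = sum(g) / loss_denom           # what actually sits in the parameter grads
avg = sum(g) / N                       # the average gradient we want to clip

threshold = max_grad_norm * N / loss_denom
stored_norm = np.linalg.norm(stored)
avg_norm = np.linalg.norm(avg)

# Clipping the stored gradient against the rescaled threshold is equivalent to
# clipping the average gradient against max_grad_norm, and dividing by
# N / loss_denom (as the script does) recovers the average-gradient norm.
assert np.isclose(stored_norm / threshold, avg_norm / max_grad_norm)
assert np.isclose(stored_norm / (N / loss_denom), avg_norm)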


@@ -859,7 +851,7 @@ def eval_validation(ckpt_name, best_eval):
p_mask = sample.masks.as_in_ctx(ctx)
p_mask = 1 - p_mask # In the network, we use 1 --> no_mask, 0 --> mask
start_top_logits, start_top_index, end_top_logits, end_top_index, answerable_logits \
= qa_net.inference(tokens, segment_ids, valid_length, p_mask,
args.start_top_n, args.end_top_n)
for i, qas_id in enumerate(sample.qas_id):
result = RawResultExtended(qas_id=qas_id,
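
An aside on the checkpoint-rotation snippet in the training loop above: the sort key (len(ele), ele) orders same-prefix filenames by step number even though they are strings, because a smaller step gives a shorter name and equal lengths fall back to lexicographic (here numeric) order, so index 0 is always the oldest checkpoint. A tiny illustration with invented filenames:

ckpt_candidates = ['electra_base_squad2.0_900.params',
                   'electra_base_squad2.0_1000.params',
                   'electra_base_squad2.0_10000.params']
ckpt_candidates.sort(key=lambda ele: (len(ele), ele))
# The shortest name (smallest step number) sorts first and is the one removed
# once the number of checkpoints exceeds args.max_saved_ckpt.
assert ckpt_candidates[0] == 'electra_base_squad2.0_900.params'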
9 changes: 9 additions & 0 deletions src/gluonnlp/utils/misc.py
@@ -262,6 +262,15 @@ def grouper(iterable, n, fillvalue=None):
args = [iter(iterable)] * n
return itertools.zip_longest(*args, fillvalue=fillvalue)

def repeat(iterable, count=None):
if count is None:
while True:
for sample in iterable:
yield sample
else:
for i in range(count):
for sample in iterable:
yield sample

def parse_ctx(data_str):
import mxnet as mx
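
A quick sanity check of the new helper's semantics, assuming this branch is installed so gluonnlp.utils.misc.repeat is importable (the values below are made up):

import itertools
from gluonnlp.utils.misc import repeat

data = [1, 2, 3]
# Finite: two full passes over the input.
assert list(repeat(data, count=2)) == [1, 2, 3, 1, 2, 3]
# Infinite: take the first seven samples from the endless stream.
assert list(itertools.islice(repeat(data), 7)) == [1, 2, 3, 1, 2, 3, 1]

Because repeat simply re-iterates its argument from the start, it should be fed a re-iterable object such as a list or a DataLoader, not a one-shot generator.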
