This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

revise train_transformer
zheyuye committed Jun 29, 2020
1 parent 1450f5c commit ca83fac
Showing 1 changed file with 42 additions and 47 deletions.
89 changes: 42 additions & 47 deletions scripts/machine_translation/train_transformer.py
@@ -37,15 +37,15 @@
 import random
 import os
 import logging
-import itertools
 import math
 import numpy as np
 import mxnet as mx
 from mxnet import gluon
 from gluonnlp.models.transformer import TransformerNMTModel
 from gluonnlp.utils.misc import logging_config, AverageSGDTracker, count_parameters,\
-    md5sum, grouper
-from gluonnlp.data.sampler import *
+    md5sum, grouper, repeat
+from gluonnlp.data.sampler import ConstWidthBucket, LinearWidthBucket, ExpWidthBucket,\
+    FixedBucketSampler
 import gluonnlp.data.batchify as bf
 from gluonnlp.data import Vocab
 from gluonnlp.data import tokenizers
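
The import changes above drop the wildcard sampler import in favor of explicit names and add repeat from gluonnlp.utils.misc, so that several epochs can be consumed as one flat stream of batches. As a rough guide to the loop rewrite below, here is a minimal sketch of what the two helpers are assumed to do; the real implementations live in gluonnlp.utils.misc, and this illustration is not the library code:

    import itertools

    def repeat(iterable, count=None):
        # Re-iterate over `iterable` `count` times (indefinitely when count is None).
        loops = itertools.count() if count is None else range(count)
        for _ in loops:
            for item in iterable:
                yield item

    def grouper(iterable, n, fillvalue=None):
        # Collect items into fixed-length chunks of size n, padding the last
        # chunk with `fillvalue` when the stream runs out.
        args = [iter(iterable)] * n
        return itertools.zip_longest(*args, fillvalue=fillvalue)

Under these assumed semantics, grouper(repeat(loader, count=args.epochs), len(ctx_l) * args.num_accumulated) yields one chunk of mini-batches per parameter update, which is what the rewritten training loop iterates over.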
@@ -355,19 +355,15 @@ def train(args):
     log_start_time = time.time()
     num_params, num_fixed_params = None, None
     # TODO(sxjscience) Add a log metric class
-    accum_count = 0
     loss_denom = 0
-    n_train_iters = 0
     log_wc = 0
     log_avg_loss = 0.0
     log_loss_denom = 0
-    for epoch_id in range(args.epochs):
-        n_epoch_train_iters = 0
-        processed_batch_num = 0
-        train_multi_data_loader = grouper(train_data_loader, len(ctx_l))
-        is_last_batch = False
-        sample_data_l = next(train_multi_data_loader)
-        while not is_last_batch:
+    epoch_id = 0
+    processed_batch_num = 0
+    for n_train_iters, batch_data in enumerate(
+            grouper(repeat(train_data_loader, count=args.epochs), len(ctx_l) * args.num_accumulated)):
+        for sample_data_l in grouper(batch_data, len(ctx_l)):
             processed_batch_num += len(sample_data_l)
             loss_l = []
             for sample_data, ctx in zip(sample_data_l, ctx_l):
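
The old code nested an epoch loop around a while loop that pulled batches with next() and caught StopIteration by hand; the rewritten loop instead enumerates chunks of len(ctx_l) * args.num_accumulated samples, so each outer iteration corresponds to one optimizer step and each inner group to one forward/backward pass per device. A rough illustration of the new control flow with placeholder values (the names below are stand-ins, not the script's variables), assuming the grouper/repeat semantics sketched above:

    batches = list(range(10))                   # stand-in for the training data loader
    epochs, num_ctx, num_accumulated = 2, 2, 2

    for step, chunk in enumerate(grouper(repeat(batches, count=epochs),
                                         num_ctx * num_accumulated)):
        # chunk: the mini-batches consumed by one parameter update
        for device_group in grouper(chunk, num_ctx):
            # device_group: one mini-batch per device for a single
            # forward/backward pass (None entries pad an incomplete final chunk)
            pass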
@@ -394,48 +390,47 @@
                     loss_l.append(loss.sum() / rescale_loss)
             for l in loss_l:
                 l.backward()
-            accum_count += 1
-            try:
-                sample_data_l = next(train_multi_data_loader)
-            except StopIteration:
-                is_last_batch = True
+
             if num_params is None:
                 num_params, num_fixed_params = count_parameters(model.collect_params())
                 logging.info('Total Number of Parameters (not-fixed/fixed): {}/{}'
                              .format(num_params, num_fixed_params))
             sum_loss = sum([l.as_in_ctx(mx.cpu()) for l in loss_l]) * rescale_loss
             log_avg_loss += sum_loss
             mx.npx.waitall()
-            if accum_count == args.num_accumulated or is_last_batch:
-                # Update the parameters
-                n_train_iters += 1
-                n_epoch_train_iters += 1
-                trainer.step(loss_denom.asnumpy() / rescale_loss)
-                accum_count = 0
-                loss_denom = 0
-                model.collect_params().zero_grad()
-                if epoch_id >= (args.epochs - args.num_averages):
-                    model_averager.step()
-                if n_epoch_train_iters % args.log_interval == 0:
-                    log_end_time = time.time()
-                    log_wc = log_wc.asnumpy()
-                    wps = log_wc / (log_end_time - log_start_time)
-                    log_avg_loss = (log_avg_loss / log_loss_denom).asnumpy()
-                    logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, '
-                                 'throughput={:.2f}K wps, wc={:.2f}K, LR={}'
-                                 .format(epoch_id, processed_batch_num, num_batches,
-                                         log_avg_loss, np.exp(log_avg_loss),
-                                         wps / 1000, log_wc / 1000, trainer.learning_rate))
-                    log_start_time = time.time()
-                    log_avg_loss = 0
-                    log_loss_denom = 0
-                    log_wc = 0
-        model.save_parameters(os.path.join(args.save_dir,
-                                           'epoch{:d}.params'.format(epoch_id)),
-                              deduplicate=True)
-        avg_valid_loss = validation(model, val_data_loader, ctx_l)
-        logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}'
-                     .format(epoch_id, avg_valid_loss, np.exp(avg_valid_loss)))
+
+        # Update the parameters
+        trainer.step(loss_denom.asnumpy() / rescale_loss)
+        loss_denom = 0
+        model.collect_params().zero_grad()
+
+        if epoch_id >= (args.epochs - args.num_averages):
+            model_averager.step()
+        if (n_train_iters + 1) % args.log_interval == 0:
+            log_end_time = time.time()
+            log_wc = log_wc.asnumpy()
+            wps = log_wc / (log_end_time - log_start_time)
+            log_avg_loss = (log_avg_loss / log_loss_denom).asnumpy()
+            logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, '
+                         'throughput={:.2f}K wps, wc={:.2f}K, LR={}'
+                         .format(epoch_id, processed_batch_num, num_batches,
+                                 log_avg_loss, np.exp(log_avg_loss),
+                                 wps / 1000, log_wc / 1000, trainer.learning_rate))
+            log_start_time = time.time()
+            log_avg_loss = 0
+            log_loss_denom = 0
+            log_wc = 0
+
+        # save parameters every epochs
+        if processed_batch_num >= num_batches:
+            epoch_id += 1
+            model.save_parameters(os.path.join(args.save_dir,
+                                               'epoch{:d}.params'.format(epoch_id)),
+                                  deduplicate=True)
+            avg_valid_loss = validation(model, val_data_loader, ctx_l)
+            logging.info('[Epoch {}] validation loss/ppl={:.4f}/{:.4f}'
+                         .format(epoch_id, avg_valid_loss, np.exp(avg_valid_loss)))
+
     if args.num_averages > 0:
         model_averager.copy_back(model.collect_params())  # TODO(sxjscience) Rewrite using update
         model.save_parameters(os.path.join(args.save_dir, 'average.params'),
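
For context, the rewritten update block follows the usual Gluon gradient-accumulation recipe: gradients are summed across several backward passes, a single trainer.step call rescales them by the accumulated denominator, and the gradients are cleared manually before the next update. A generic, hedged sketch of that recipe on a toy model (not this script's code; it assumes grad_req is set to 'add' the way the full script presumably does elsewhere):

    import mxnet as mx
    from mxnet import autograd, gluon

    net = gluon.nn.Dense(1)
    net.initialize()
    # Sum gradients across backward() calls instead of overwriting them.
    for p in net.collect_params().values():
        p.grad_req = 'add'
    trainer = gluon.Trainer(net.collect_params(), 'adam')

    num_accumulated, loss_denom = 2, 0.0
    data = [mx.nd.ones((4, 8)) for _ in range(4)]
    for i, x in enumerate(data):
        with autograd.record():
            loss = net(x).sum()
        loss.backward()
        loss_denom += x.shape[0]
        if (i + 1) % num_accumulated == 0:
            trainer.step(loss_denom)            # rescale by the accumulated denominator
            net.collect_params().zero_grad()    # clear the summed gradients
            loss_denom = 0.0

The epoch bookkeeping that used to come from the explicit epoch loop is now handled inline: epoch_id is incremented and epoch{:d}.params is saved once processed_batch_num reaches num_batches, and during the last num_averages epochs model_averager.step() keeps accumulating a parameter average that copy_back() later writes into the model before average.params is saved.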
