This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Add nt-asgd for language model #170

Merged
merged 2 commits into dmlc:master on Sep 19, 2018

Conversation

cgraywang
Contributor

Description

  1. Add nt-asgd for language model
  2. Online update of nt-asgd

Checklist

Essentials

  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented

Changes

  • Feature1, tests, (and when applicable, API doc)
  • Feature2, tests, (and when applicable, API doc)

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Interesting edge cases to note here

@cgraywang cgraywang requested a review from szha as a code owner June 26, 2018 06:55
@cgraywang cgraywang requested a review from szhengac June 26, 2018 06:56
@cgraywang
Contributor Author

@szhengac Thanks for the very helpful discussion!

@@ -354,30 +357,51 @@ def train():
data_list = gluon.utils.split_and_load(data, context, batch_axis=1, even_split=True)
target_list = gluon.utils.split_and_load(target, context, batch_axis=1, even_split=True)
hiddens = detach(hiddens)
model.collect_params().zero_grad()
Member

Why do you zero out the gradient? I think you are not using gradient accumulation, right?

Contributor Author

Based on the original implementation of the paper, https://github.com/salesforce/awd-lstm-lm/blob/32fcb42562aeb5c7e6c9dec3f2a3baaaf68a5cb5/main.py#L194, it is used before updating the parameters.

Member

They use torch, in which gradients are accumulated across iterations by default. In Gluon, the gradient is replaced by default.
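
For context, a minimal standalone sketch (not part of this PR) of that difference: Gluon parameters default to grad_req='write', so each backward() overwrites the gradient, and zero_grad() only matters once grad_req is switched to 'add' for accumulation.

```python
import mxnet as mx
from mxnet import autograd, gluon

net = gluon.nn.Dense(1)
net.initialize()

# default grad_req='write': backward() replaces the gradient every iteration,
# so calling collect_params().zero_grad() before the update changes nothing here
x = mx.nd.ones((2, 3))
for _ in range(2):
    with autograd.record():
        loss = net(x).sum()
    loss.backward()

# only with explicit accumulation does zeroing become necessary:
# for p in net.collect_params().values():
#     p.grad_req = 'add'
# ... then net.collect_params().zero_grad() after each update
```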

l = joint_loss(output, y, encoder_hs, dropped_encoder_hs)
L = L + l.as_in_context(context[0]) / X.size
Ls.append(l/X.size)
Ls.append(l / X.size)
hiddens[j] = h
L.backward()
Member

You can simply perform backward in a for loop over Ls. There is no need to keep both Ls and L.
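
A sketch of this suggestion, assuming Ls holds the per-device scaled losses as in the hunk above:

```python
# call backward on each per-device loss directly; no need to also accumulate a single L
for l in Ls:
    l.backward()
```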

total_L += sum([mx.nd.sum(L).asscalar() for L in Ls]) / len(context)
trainer.set_learning_rate(lr_batch_start)
if batch_i % args.log_interval == 0 and batch_i > 0:
if batch_i % args.log_interval == 0 and avg_trigger == 0:
Member

Oh, it is 'and'? If so, T will be changed only once.

Contributor Author

Yes, it will change only once each epoch.

Member

I checked the paper again. T is changed only once "throughout the training".

Contributor Author

Yes, the log_interval = number of batches.
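
For readers following the thread, a rough standalone sketch of the non-monotonic trigger from the AWD-LSTM paper (the validation_checks generator and batch_index bookkeeping are hypothetical placeholders, not this script's code): averaging is triggered once validation loss has not improved over the last n checks, and avg_trigger is then left unchanged for the rest of training.

```python
n = 5                      # non-monotone interval from the paper
val_history = []
avg_trigger = 0            # 0 means averaging has not been triggered yet

for batch_index, val_loss in validation_checks():   # hypothetical generator of (batch, val loss)
    if avg_trigger == 0 and len(val_history) > n and val_loss > min(val_history[:-n]):
        # trigger NT-ASGD averaging once, for the remainder of training
        avg_trigger = batch_index
    val_history.append(val_loss)
```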

alpha = 1.0 / max(1, batch_i - avg_trigger + 1)
if param_dict_avg is None:
    param_dict_avg = {k: v.data(context[0]).copy()
                      for k, v in model.collect_params().items()}
Member

Move the initialization before the trainer.step(1), as the average may include the initial point.

param_dict_avg = {k: v.data(context[0]).copy()
                  for k, v in model.collect_params().items()}
for name, param_avg in param_dict_avg.items():
    param_avg[:] += alpha * (param_dict_batch_i[name] - param_avg)
Member

We should move this averaging step after the validation, as avg_trigger is determined after that

trainer.step(1)

alpha = 1.0 / max(1, batch_i - avg_trigger + 1)
Member

@szhengac szhengac Jun 26, 2018

With my other comments, we should add 2 here. Consider the case where we only perform 1 iteration, the final avg_trigger is 0, and the output is the average of these two iterates.

Contributor Author

In the 2-iteration case, alpha = 1 / max(1, 1-0+1) = 1/2, so the output equals the average of the first two iterations; it seems correct to me.

Member

Your batch_i starts from 0. It was 1 iteration. I have updated my comment.

Member

The two iterates refer to w_0 and w_1.

Contributor Author

In iteration 0, param_dict_avg = w0; in iteration 1, param_dict_avg = (w0+w1)/2. Seems correct?

Member

Running the model for only one iteration means only one backward is performed. The routine you mention indicates 2 iterations.
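
To make the indexing in this thread concrete, here is a tiny standalone sketch (plain Python, not the PR code): with alpha = 1 / (t - avg_trigger + 1) and the running average initialized from the first iterate, the update reproduces the arithmetic mean of the iterates seen since the trigger.

```python
avg_trigger = 0
avg = None
iterates = [1.0, 3.0, 5.0]          # hypothetical iterates w_0, w_1, w_2
for t, w in enumerate(iterates):
    if avg is None:
        avg = w                     # initialized from the first iterate
    else:
        alpha = 1.0 / max(1, t - avg_trigger + 1)
        avg += alpha * (w - avg)    # running-average update used in the script
print(avg)                          # 3.0 == (1.0 + 3.0 + 5.0) / 3
```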

Member

@szha szha left a comment

Be sure to update the model performance numbers and the pre-trained models too.

@codecov

codecov bot commented Jun 27, 2018

Codecov Report

❗ No coverage uploaded for pull request base (master@03b0e70). Click here to learn what that means.
The diff coverage is 0%.

Impacted file tree graph

@@            Coverage Diff            @@
##             master     #170   +/-   ##
=========================================
  Coverage          ?   76.74%           
=========================================
  Files             ?       79           
  Lines             ?     6462           
  Branches          ?     1019           
=========================================
  Hits              ?     4959           
  Misses            ?     1263           
  Partials          ?      240
| Impacted Files | Coverage Δ |
| --- | --- |
| gluonnlp/model/language_model.py | 94.28% <ø> (ø) |
| scripts/language_model/cache_language_model.py | 0% <0%> (ø) |
| scripts/language_model/word_language_model.py | 0% <0%> (ø) |

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 03b0e70...bc7d9d1. Read the comment docs.

trainer.step(1)

if args.ntasgd:
    gamma = 1.0 / max(1, batch_i - avg_trigger + 1)
Member

As discussed, it should be 2.

Contributor Author

In the current implementation, iteration 0 generates w1 and iteration 1 generates (w1+w2)/2. What do you mean by 2? If you want iteration 0 to generate w0 and iteration 1 to generate (w0+w1)/2, we can change it, but I don't think it makes a big difference.

Member

The current implementation is OK. But if you want to implement the exact algorithm in the paper, iteration 0 should generate (w_0+w_1)/2.

Contributor Author

Good catch, I didn't really follow the paper here; iteration -1 should return w_0. Will change, thanks.

grads = [p.grad(d.context) for p in parameters for d in data_list]
gluon.utils.clip_global_norm(grads, args.clip)

if args.ntasgd:
    if param_dict_avg is None:
        param_dict_avg = {k.split(model._prefix)[1]: v.data(context[0]).copy()
Member

Why split the prefix?

Contributor Author

model.load_params only loads parameter dicts whose keys do not include the prefix.

Member

load_params is deprecated. Use load_parameters instead.
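
A minimal sketch of the suggested replacement API (the file name is illustrative):

```python
# save_parameters / load_parameters supersede the deprecated save_params / load_params
model.save_parameters('avg.params')
model.load_parameters('avg.params', ctx=context)
```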

hiddens[j] = h
L.backward()
for L in Ls:
    L.backward()
grads = [p.grad(d.context) for p in parameters for d in data_list]
gluon.utils.clip_global_norm(grads, args.clip)
Member

Need to test it in multi-gpu mode.

Contributor Author

Tested.

```
>>> [a,b]
[
 [[[1. 1. 1.]
   [1. 1. 1.]]]
 <NDArray 1x2x3 @gpu(0)>,
 [[[1. 1. 1.]
   [1. 1. 1.]]]
 <NDArray 1x2x3 @gpu(1)>]
>>> grads = [a,b]
>>> mx.gluon.utils.clip_global_norm(grads, 0.25)
3.4641016
>>> grads
[
 [[[0.07216878 0.07216878 0.07216878]
   [0.07216878 0.07216878 0.07216878]]]
 <NDArray 1x2x3 @gpu(0)>,
 [[[0.07216878 0.07216878 0.07216878]
   [0.07216878 0.07216878 0.07216878]]]
 <NDArray 1x2x3 @gpu(1)>]
```

Contributor Author

Looks correct to me

Member

I forgot to mention that clip_global_norm does not perform reduction, so we need to do it manually before feeding the gradients into the clip function.

Contributor Author

The grads have been copied to every context; why do I need a reduction?

Member

@szhengac szhengac Jun 28, 2018

For example, if a parameter has grad_1 and grad_2 w.r.t. two samples, respectively, we need to do (grad_1+grad_2)/2 first and then use the average to compute the norm. Without manual reduction, grad_1^2 + grad_2^2 is used in clip_global_norm.
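
A rough sketch of that reduction (not the PR's code; parameters and context are the script's list of Parameters and devices): average the per-device gradients first, then compute a single global norm from the averages.

```python
import math
import mxnet as mx

def global_norm_of_averaged_grads(parameters, context):
    total = 0.0
    for p in parameters:
        # (grad_1 + ... + grad_k) / k, gathered on the first device
        avg = sum(p.grad(ctx).as_in_context(context[0]) for ctx in context) / len(context)
        total += float(mx.nd.sum(avg * avg).asscalar())
    return math.sqrt(total)
```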

hiddens[j] = h
L.backward()
for L in Ls:
    L.backward()
grads = [p.grad(d.context) for p in parameters for d in data_list]
Member

p.grad only accepts a single context instead of a list. Since data_list is the result returned by split_and_load, the gradient would not be copied to all the contexts.

Contributor Author

p.grad does only accept a single GPU context; here it is used with data.context:

```
>>> for d in data_list:
...     print(d.context)
...
gpu(0)
gpu(1)
```

Member

If the goal is to copy the gradients to all the contexts, you can use the following:
grads = [p.grad(ctx) for p in parameters for ctx in context]

Contributor Author

Yes, they are almost the same.

Member

The one I provided is cheaper, as its complexity is reduced by a factor of batch_size/len(context)

Contributor Author

Well, sometimes a context may be reserved as a computation resource, and not all contexts are used every time. Using data.context is much safer, so I still prefer it.

@mli
Member

mli commented Jun 27, 2018

Job PR-170/3 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/3/index.html

@mli
Member

mli commented Jun 28, 2018

Job PR-170/4 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/4/index.html

@mli
Member

mli commented Jul 3, 2018

Job PR-170/9 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/9/index.html

@mli
Member

mli commented Jul 5, 2018

Job PR-170/10 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/10/index.html

@mli
Member

mli commented Jul 8, 2018

Job PR-170/11 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/11/index.html

@szha szha mentioned this pull request Jul 9, 2018
@mli
Member

mli commented Jul 11, 2018

Job PR-170/12 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/12/index.html

@mli
Member

mli commented Jul 12, 2018

Job PR-170/13 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/13/index.html

@mli
Member

mli commented Jul 13, 2018

Job PR-170/14 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/14/index.html

@mli
Member

mli commented Jul 16, 2018

Job PR-170/15 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/15/index.html

@mli
Member

mli commented Jul 17, 2018

Job PR-170/16 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/16/index.html

@szha
Member

szha commented Jul 17, 2018

@cgraywang @szhengac ping.

@mli
Member

mli commented Jul 18, 2018

Job PR-170/17 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/17/index.html

@mli
Member

mli commented Jul 25, 2018

Job PR-170/20 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/20/index.html

@mli
Member

mli commented Jul 27, 2018

Job PR-170/21 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/21/index.html

@mli
Member

mli commented Jul 29, 2018

Job PR-170/22 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/22/index.html

@mli
Member

mli commented Aug 3, 2018

Job PR-170/25 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/25/index.html

@mli
Member

mli commented Aug 4, 2018

Job PR-170/26 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/26/index.html

@mli
Member

mli commented Aug 5, 2018

Job PR-170/27 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/27/index.html

@mli
Member

mli commented Aug 11, 2018

Job PR-170/29 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/29/index.html

@szha szha mentioned this pull request Aug 12, 2018
20 tasks
@szha szha added the release focus Progress focus for release label Aug 12, 2018
@mli
Member

mli commented Aug 20, 2018

Job PR-170/38 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/38/index.html

@mli
Member

mli commented Aug 21, 2018

Job PR-170/39 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/39/index.html

if len(list(parameters)[0].list_ctx()) > 1:
    for ctx in list(parameters)[0].list_ctx()[1:]:
        for p in parameters:
            p.as_in_context(ctx)
Member

What are you doing here?

Contributor Author

Sync parameters

Member

What do you mean?

trainer.allreduce_grads()
ctx = list(parameters)[0].list_ctx()[0]
global_grads = [p.grad(ctx) for p in parameters]
total_norm, scale = _multi_gpu_clip_global_norm_scale(global_grads, max_norm)
Member

trainer.allreduce_grads() does not perform averaging, so this should be taken into account when computing the scale.

Contributor Author

Since the experiment is currently running on a single GPU, it should be fine. I can test it further after KDD.

Member

Since the script is not used in KDD, let's get it right and not rush this.

l = joint_loss(output, y, encoder_hs, dropped_encoder_hs)
L = L + l.as_in_context(context[0]) / X.size
Ls.append(l/X.size)
Ls.append(l.as_in_context(context[0]) / (len(context) * X.size))
Member

l is already in the corresponding context; why do you copy it to the first GPU?

Contributor Author

The result should be the same; I can probably address this later, after KDD.

Member

Let's make it right since this script is not used in KDD tutorials.

Member

The copy is not necessary in either single GPU or multi-GPU case.

total_L += sum([mx.nd.sum(L).asscalar() for L in Ls]) / len(context)
if args.ntasgd and ntasgd:
    if param_dict_avg is None:
        param_dict_avg = {k.split(model._prefix)[1]: v.data(context[0]).copy()
Member

I think there is no need to remove the model prefix. Otherwise, you cannot use load_parameters directly.

Contributor Author

We should; I have verified it.

('700b532dc96a29e39f45cb7dd632ce44e377a752', 'standard_lstm_lm_200_wikitext-2'),
('a416351377d837ef12d17aae27739393f59f0b82', 'standard_lstm_lm_1500_wikitext-2'),
('631f39040cd65b49f5c8828a0aba65606d73a9cb', 'standard_lstm_lm_650_wikitext-2'),
('b233c700e80fb0846c17fe14846cb7e08db3fd51', 'standard_lstm_lm_200_wikitext-2'),
Member

It would be nice to have these better models soon. Given that the script looks like it requires more time, @cgraywang could you create a separate PR for updating the model?

@mli
Member

mli commented Aug 23, 2018

Job PR-170/42 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/42/index.html


The dataset used for training the models is wikitext-2.

| Model | cache_awd_lstm_lm_1150_wikitext-2 | cache_awd_lstm_lm_600_wikitext-2 | cache_standard_lstm_lm_1500_wikitext-2 | cache_standard_lstm_lm_650_wikitext-2 | cache_standard_lstm_lm_200_wikitext-2 |
| --- | --- | --- | --- | --- | --- |
| Pre-trained setting | Refer to: awd_lstm_lm_1150_wikitext-2 | Refer to: awd_lstm_lm_600_wikitext-2 | Refer to: standard_lstm_lm_1500_wikitext-2 | Refer to: standard_lstm_lm_650_wikitext-2 | Refer to: standard_lstm_lm_200_wikitext-2 |
| Pretrained setting | Refer to: awd_lstm_lm_1150_wikitext-2 | Refer to: awd_lstm_lm_600_wikitext-2 | Refer to: standard_lstm_lm_1500_wikitext-2 | Refer to: standard_lstm_lm_650_wikitext-2 | Refer to: standard_lstm_lm_200_wikitext-2 |
Member

Revert to pre-trained

args = parser.parse_args()

###############################################################################
# Load data
###############################################################################

context = [mx.cpu()] if args.gpus is None or args.gpus == '' else \
    [mx.gpu(int(x)) for x in args.gpus.split(',')]
context = [mx.cpu()] if args.gpu is None or args.gpu == '' else [mx.gpu(int(args.gpu))]
Member

if not args.gpu covers both the None case and the empty-string case.
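
i.e., a sketch of the simplified check, assuming the script's args and mx as above:

```python
# `not args.gpu` is True for both None and the empty string
context = [mx.cpu()] if not args.gpu else [mx.gpu(int(args.gpu))]
```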

l = joint_loss(output, y, encoder_hs, dropped_encoder_hs)
L = L + l.as_in_context(context[0]) / X.size
Ls.append(l/X.size)
Ls.append(l.as_in_context(context[0]) / (len(context) * X.size))
Member

The copy is not necessary in either single GPU or multi-GPU case.

@mli
Member

mli commented Aug 25, 2018

Job PR-170/44 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/44/index.html

@mli
Member

mli commented Aug 27, 2018

Job PR-170/45 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/45/index.html

@mli
Member

mli commented Aug 31, 2018

Job PR-170/46 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/46/index.html

@mli
Member

mli commented Sep 4, 2018

Job PR-170/47 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/47/index.html

@codecov

codecov bot commented Sep 18, 2018

Codecov Report

❗ No coverage uploaded for pull request base (master@fb27033). Click here to learn what that means.
The diff coverage is 0%.

Impacted file tree graph

@@            Coverage Diff            @@
##             master     #170   +/-   ##
=========================================
  Coverage          ?   76.74%           
=========================================
  Files             ?       79           
  Lines             ?     6462           
  Branches          ?     1019           
=========================================
  Hits              ?     4959           
  Misses            ?     1263           
  Partials          ?      240
| Impacted Files | Coverage Δ |
| --- | --- |
| gluonnlp/model/language_model.py | 94.28% <ø> (ø) |
| scripts/language_model/cache_language_model.py | 0% <0%> (ø) |
| scripts/language_model/word_language_model.py | 0% <0%> (ø) |

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update fb27033...541e5b0. Read the comment docs.

@@ -89,6 +87,8 @@ def test_text_models(wikitext2_val_and_counter):
model.collect_params().initialize()
output, state = model(mx.nd.arange(330).reshape(33, 10))
output.wait_to_read()
del model
Contributor Author

Thanks @szha for suggesting these changes!

@mli
Member

mli commented Sep 18, 2018

Job PR-170/56 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/56/index.html

                    help='lr update interval')
parser.add_argument('--lr_update_factor', type=float, default=0.1,
                    help='lr update factor')
parser.add_argument('--gpu', type=str, help='single gpu id')
Member

You can use int as the type for a single GPU id.
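
A self-contained sketch of this suggestion (the default value is illustrative):

```python
import argparse
import mxnet as mx

parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=None, help='single gpu id')
args = parser.parse_args()
context = [mx.cpu()] if args.gpu is None else [mx.gpu(args.gpu)]
```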

Contributor Author

Multi-GPU will eventually be supported; to keep it aligned, I think using str is fine.

@codecov

codecov bot commented Sep 18, 2018

Codecov Report

❗ No coverage uploaded for pull request base (master@fb27033). Click here to learn what that means.
The diff coverage is 0%.

Impacted file tree graph

@@            Coverage Diff            @@
##             master     #170   +/-   ##
=========================================
  Coverage          ?   76.74%           
=========================================
  Files             ?       79           
  Lines             ?     6462           
  Branches          ?     1019           
=========================================
  Hits              ?     4959           
  Misses            ?     1263           
  Partials          ?      240
| Impacted Files | Coverage Δ |
| --- | --- |
| gluonnlp/model/language_model.py | 94.28% <ø> (ø) |
| scripts/language_model/cache_language_model.py | 0% <0%> (ø) |
| scripts/language_model/word_language_model.py | 0% <0%> (ø) |

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update fb27033...3a1de71. Read the comment docs.

@mli
Member

mli commented Sep 18, 2018

Job PR-170/57 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-170/57/index.html

@cgraywang cgraywang merged commit a08eaf2 into dmlc:master Sep 19, 2018
paperplanet pushed a commit to paperplanet/gluon-nlp that referenced this pull request Jun 9, 2019
* Update language model script with NTSGD and new model results

* Polishing scripts and results presentation