This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Add embedding related methods in numpy version #1263

Merged 12 commits on Jul 28, 2020

Conversation

acphile
Contributor

@acphile acphile commented Jul 15, 2020

Description

Add embedding-related methods to 'gluonnlp.embedding':

embed_loader.list_sources: Get the valid token embedding names and their pre-trained file names.
embed_loader.load_embeddings: Load a pre-trained embedding file to build an embedding matrix for a given Vocab.
evaluation.CosineSimilarity: a function that computes cosine similarity.
evaluation.HyperbolicCosineSimilarity: a function that computes cosine similarity in hyperbolic space.
evaluation.ThreeCosAdd: a class for the 3CosAdd analogy method.
evaluation.ThreeCosMul: a class for the 3CosMul analogy method.
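
For readers unfamiliar with the two analogy objectives named in the list above, here is a minimal NumPy sketch of the usual 3CosAdd and 3CosMul scoring rules for "a is to b as c is to ?". It is not the code from this PR: the function names, the toy vectors, and the `eps` values are assumptions chosen for illustration.

```python
import numpy as np

def _unit(x, eps=1e-10):
    """L2-normalize along the last axis."""
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)

def three_cos_add(emb, a, b, c):
    """Score every row d as cos(d, b) - cos(d, a) + cos(d, c)."""
    emb = _unit(emb)
    return emb @ (emb[b] - emb[a] + emb[c])

def three_cos_mul(emb, a, b, c, eps=1e-3):
    """Score every row d as cos(d, b) * cos(d, c) / (cos(d, a) + eps),
    with cosines shifted from [-1, 1] to [0, 1]."""
    emb = _unit(emb)
    sim = (emb @ emb[[a, b, c]].T + 1.0) / 2.0  # shape (vocab, 3)
    return sim[:, 1] * sim[:, 2] / (sim[:, 0] + eps)

def best_answer(scores, exclude):
    """Index of the highest score, ignoring the three query words."""
    scores = scores.copy()
    scores[list(exclude)] = -np.inf
    return int(np.argmax(scores))

# Toy vocabulary (hypothetical vectors, not real embeddings).
words = ["man", "woman", "king", "queen", "apple"]
emb = np.array([[1, 0, 0], [1, 1, 0], [1, 0, 1], [1, 1, 1], [0, -1, -1]], dtype=float)
print(words[best_answer(three_cos_add(emb, 0, 1, 2), {0, 1, 2})])
print(words[best_answer(three_cos_mul(emb, 0, 1, 2), {0, 1, 2})])
```

Both rules rank candidates by cosine similarity only, which is why they depend on an L2-normalization step.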

About evaluation

Currently the implementations of embedding.evaluation are not very satisfactory. Suggestions are welcome.

Checklist

Essentials

  • PR's title starts with a category (e.g. [BUGFIX], [MODEL], [TUTORIAL], [FEATURE], [DOC], etc)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented

Changes

  • Feature1, tests, (and when applicable, API doc)
  • Feature2, tests, (and when applicable, API doc)

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Interesting edge cases to note here

cc @dmlc/gluon-nlp-team

@codecov

codecov bot commented Jul 15, 2020

Codecov Report

Merging #1263 into numpy will increase coverage by 0.23%.
The diff coverage is 83.04%.

@@            Coverage Diff             @@
##            numpy    #1263      +/-   ##
==========================================
+ Coverage   82.44%   82.67%   +0.23%     
==========================================
  Files          38       41       +3     
  Lines        5450     5702     +252     
==========================================
+ Hits         4493     4714     +221     
- Misses        957      988      +31     
Impacted Files                              Coverage Δ
src/gluonnlp/embedding/embed_loader.py      81.52% <81.52%> (ø)
src/gluonnlp/__init__.py                    100.00% <100.00%> (ø)
src/gluonnlp/attention_cell.py              79.74% <100.00%> (-0.26%) ⬇️
src/gluonnlp/embedding/__init__.py          100.00% <100.00%> (ø)
src/gluonnlp/embedding/_constants.py        100.00% <100.00%> (ø)
src/gluonnlp/op.py                          60.00% <100.00%> (+2.10%) ⬆️
src/gluonnlp/models/roberta.py              88.78% <0.00%> (-4.48%) ⬇️
src/gluonnlp/models/xlmr.py                 86.88% <0.00%> (-1.12%) ⬇️
src/gluonnlp/layers.py                      86.78% <0.00%> (-0.45%) ⬇️
src/gluonnlp/models/transformer_xl.py       82.71% <0.00%> (-0.22%) ⬇️
... and 16 more

x = x.reshape(-1, dim)
y = y.reshape(-1, dim)
x = mx.nd.L2Normalization(x, eps=eps).asnumpy()
y = mx.nd.L2Normalization(y, eps=eps).asnumpy()
Member

How about moving this function to op.py and then reusing it?

def l2_normalize(F, data, axis=-1, eps=1E-6):
    """Normalize the data by L2 normalization.

    Parameters
    ----------
    F : mx.sym or mx.nd
    data : symbol or ndarray
    axis : int, default -1
    eps : float, default 1E-6

    Returns
    -------
    ret : mx.sym or mx.nd
    """
    ret = data / (F.np.linalg.norm(data, axis=axis, keepdims=True) + eps)
    return ret
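
To make the reuse concrete, here is a rough plain-NumPy analogue of the helper above together with a row-wise cosine similarity built on it. The names `l2_normalize_np` and `cosine_similarity` and the sample arrays are assumptions for illustration; the MXNet version takes an extra `F` argument that is dropped here.

```python
import numpy as np

def l2_normalize_np(data, axis=-1, eps=1e-6):
    """NumPy analogue of the l2_normalize helper quoted above."""
    return data / (np.linalg.norm(data, axis=axis, keepdims=True) + eps)

def cosine_similarity(x, y, eps=1e-6):
    """Row-wise cosine similarity between two arrays of shape (n, dim)."""
    x = l2_normalize_np(x, eps=eps)
    y = l2_normalize_np(y, eps=eps)
    # After normalization the dot product of two rows is their cosine.
    return (x * y).sum(axis=-1)

x = np.array([[3.0, 4.0], [1.0, 0.0]])
y = np.array([[6.0, 8.0], [0.0, 2.0]])
print(cosine_similarity(x, y))  # approximately [1. 0.]
```

Because `eps` is added to the norm rather than clamped, the result for parallel vectors is very slightly below 1.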


return None

def load_embeddings(vocab, pretrained_name_or_dir='glove.6B.50d', unknown='<unk>',
Member

How about using separate classes to handle the different kinds of embedding vectors? You may just implement some basic functions: for example, a model.load() method to load the parameters and a model.similarity(a, b, method=None) function to evaluate similarity. That way it will be more structured and closer to how we will evaluate it in the paper.

class KeyedVector:
    def __init__(self):
        ...

    @classmethod
    def load(cls, path):
        ...

    def similarity(self, a, b):
        ...

class FastText:
    def __init__(self):
        ...

    @classmethod
    def load(cls, path):
        ...

Member

If we only need to load the embedding vectors for a small vocabulary, we can also add some flags to the load function, e.g., def load(cls, path, vocab=None, num_tokens=None):, in which vocab is an optional vocabulary object that helps reduce the number of tokens to load. For example, the data augmentation algorithm in TinyBERT only loads the first 100000 words for fast lookup: https://github.com/huawei-noah/Pretrained-Language-Model/blob/e670706c041246b975a3646bc6a27c48786f6c15/TinyBERT/data_augmentation.py#L75
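
A minimal sketch of what such flags could do, assuming a plain-text embedding file with one "token v1 v2 ..." entry per line. The function name `load_vectors` and the exact meaning of `num_tokens` (a cap on how many lines are read) are assumptions, not the API added in this PR.

```python
import io

def load_vectors(lines, vocab=None, num_tokens=None):
    """Read 'token v1 v2 ...' lines into a dict, optionally restricted.

    vocab      : optional set of tokens to keep (others are skipped)
    num_tokens : optional cap on how many lines of the file are read
    """
    vectors = {}
    for i, line in enumerate(lines):
        if num_tokens is not None and i >= num_tokens:
            break  # stop early instead of scanning the whole file
        token, *values = line.rstrip().split(" ")
        if vocab is not None and token not in vocab:
            continue
        vectors[token] = [float(v) for v in values]
    return vectors

# An in-memory stand-in for an embedding file.
fake_file = io.StringIO("the 0.1 0.2\ncat 0.3 0.4\nsat 0.5 0.6\nmat 0.7 0.8\n")
print(load_vectors(fake_file, vocab={"cat", "mat"}, num_tokens=3))
```

Here "mat" is in the requested vocabulary but sits beyond the first three lines, so only "cat" is loaded.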

Contributor Author

I think we can define a class EmbeddingModel(HybridBlock) to serve as the base class of embedding models, and attach some evaluation functions to this class. For just loading an embedding matrix, we can simply use the current load_embeddings and let users call set_data manually, or have a class WordEmbedding(EmbeddingModel) and move the functionality of load_embeddings into it. Complex embedding models like FastText or a character-level CNN can then be implemented on top of EmbeddingModel. These embedding models may be implemented in models/

@sxjscience
Member

sxjscience commented Jul 18, 2020 via email

@acphile acphile marked this pull request as draft July 20, 2020 10:28
@acphile
Contributor Author

acphile commented Jul 21, 2020

Is it possible to get the embedding of words in raw text if it’s a HybridBlock? We may want to calculate the embedding from raw text or out-of-vocabulary words, which is the purpose of FastText.

For getting the embeddings for unknown words, there are the following situations:

1. words in the vocabulary but not in the embedding file

The default method is to sample from a normal distribution, and users can pass unk_method to define their own approach. In some simple tests I found that fasttext.cc actually computes faster than the original gluon approach in v0.9.x, so you can simply do as follows:

import fasttext
import numpy as np

fast = fasttext.load_model('model.bin')
def ngram(words):
    return np.array([fast[word] for word in words])
embedding_matrix = load_embeddings(vocab, source, unk_method=ngram)

In this case, we get an embedding matrix for a given vocabulary.
I have now added a feature that uses the embedding file itself as the vocabulary:

embedding_matrix, vocab = load_embeddings(vocab=None, pretrained_name_or_dir=source)

2. words not in the vocabulary

I think generally we just use the embedding of <unk> for these OOV words. Of course we can still generate some initial embedding vectors with FastText, but since they are not updated during training, I don't think that is very useful.

To retain more information from these words, in practice we may use a character-level NN. For example, we may use a character-level CNN to compute the embedding of a word, with parameters that are learnable during training. That's why I think we should create a base class EmbeddingModel(HybridBlock) and have some embedding models (in other words, parts of the neural network) so that it is easier for users to build their NLP models.
EmbeddingModel (and its child classes) can serve as a black box: inside there may be only a simple embedding lookup, or some NN layers. What we want is to input List[word index] or List[List[character index]] (maybe both, or others) and get word representations, which can even be contextualized. So we can define the embedding evaluation functions as class methods in a more general way.
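
The two situations for unknown words described in this comment can be sketched in plain NumPy as follows. This is only an illustration of the behaviour discussed, not the implementation of load_embeddings: `build_matrix`, `lookup`, the `dim` and `seed` parameters, and the toy data are all hypothetical.

```python
import numpy as np

def build_matrix(vocab_tokens, pretrained, dim, unk_method=None, seed=0):
    """Situation 1: tokens in the vocabulary but not in the embedding file.

    Missing rows come from unk_method(list_of_tokens) when it is given,
    otherwise they are sampled from a normal distribution.
    """
    missing = [t for t in vocab_tokens if t not in pretrained]
    if unk_method is not None and missing:
        filled = np.asarray(unk_method(missing), dtype=float)
    else:
        filled = np.random.default_rng(seed).normal(size=(len(missing), dim))
    fill = dict(zip(missing, filled))
    rows = [np.asarray(pretrained[t], dtype=float) if t in pretrained else fill[t]
            for t in vocab_tokens]
    return np.stack(rows)

def lookup(matrix, token_to_idx, words, unknown="<unk>"):
    """Situation 2: words outside the vocabulary share the <unk> row."""
    unk_idx = token_to_idx[unknown]
    return matrix[[token_to_idx.get(w, unk_idx) for w in words]]

vocab_tokens = ["<unk>", "cat", "dog"]
pretrained = {"cat": [1.0, 2.0]}  # "dog" and "<unk>" are absent from the file
matrix = build_matrix(vocab_tokens, pretrained, dim=2)
token_to_idx = {t: i for i, t in enumerate(vocab_tokens)}
print(lookup(matrix, token_to_idx, ["cat", "bird"]).shape)
```

Passing a callable such as the `ngram` function shown earlier as `unk_method` would replace the random rows with FastText-derived ones.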

@sxjscience
Member

The advantage of fasttext is that there is no need to care about OOV words. Thus, you may still need to offer this functionality.

@sxjscience
Member

sxjscience commented Jul 21, 2020

@acphile This is the advantage of using subwords. Basically, there will be no (or fewer) OOV words if you are using a subword representation. For example, GPT-2/GPT-3 chose to use the byte-based BPE encoding because there will never be OOV words. Also, you may check Section 2.1 of https://arxiv.org/pdf/1911.03688.pdf to see how different models may adopt different strategies for dealing with the OOV problem.
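
A small sketch of why a byte-based encoding cannot produce OOV tokens: every string is first mapped to UTF-8 bytes, so the base vocabulary is just the 256 possible byte values. The sketch shows only this base step and omits the BPE merges that a real tokenizer applies on top; `byte_tokens` is a hypothetical helper name.

```python
def byte_tokens(text):
    """Map any string to ids drawn from a fixed base vocabulary of 256 bytes."""
    return list(text.encode("utf-8"))

ids = byte_tokens("naïve 日本語")
print(ids)
# Every id falls inside the 256-entry base vocabulary...
print(all(0 <= i < 256 for i in ids))
# ...and the original text can always be recovered from the ids.
print(bytes(ids).decode("utf-8") == "naïve 日本語")
```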

@acphile
Contributor Author

acphile commented Jul 21, 2020

@acphile This is the advantage of using subwords. Basically, there will be no (or fewer) OOV words if you are using a subword representation. For example, GPT-2/GPT-3 chose to use the byte-based BPE encoding because there will never be OOV words. Also, you may check Section 2.1 of https://arxiv.org/pdf/1911.03688.pdf to see how different models may adopt different strategies for dealing with the OOV problem.

I understand that. In my context, "vocabulary" refers not only to a vocabulary of words but also to a lookup dict that records the distinct tokens for a certain input. For example, a trigram vocabulary records most of the trigrams that occur in the dataset. Raw text can be transformed by the tokenizer into several different types of model input (like List[word], List[List[ngram]], List[BPE]) and further transformed to List[int] or List[List[int]] by Vocab. For the embedding part, we only need to handle integers, and that's why I previously suggested using EmbeddingModel(HybridBlock). Of course FastText is very useful, and I think it is better to implement FastText as Class FastText(EmbeddingModel) instead of Class FastText:. It is somewhat like https://github.com/dmlc/gluon-nlp/blob/v0.9.x/src/gluonnlp/model/train/embedding.py#L175 but I think we could improve its implementation.

@sxjscience
Member

@acphile The problem is that it will be inefficient to have the tokenizer output all the ngram combinations. Instead, you ask the tokenizer to output a list of tokens and each token will be converted to the embedding.

@sxjscience
Member

@acphile Is it possible to also refer to the implementation in gensim https://radimrehurek.com/gensim/models/fasttext.html#module-gensim.models.fasttext?

@acphile
Contributor Author

acphile commented Jul 21, 2020

@acphile The problem is that it will be inefficient to have the tokenizer output all the ngram combinations. Instead, you ask the tokenizer to output a list of tokens and each token will be converted to the embedding.

For each token, gensim still outputs all ngrams to compute the corresponding embeddings: https://github.com/RaRe-Technologies/gensim/blob/c0e0169565116854993b22efef29e3c402ec6c69/gensim/models/fasttext_inner.pyx#L672
And they use hash buckets to convert ngrams to indices: https://github.com/RaRe-Technologies/gensim/blob/c0e0169565116854993b22efef29e3c402ec6c69/gensim/models/fasttext.py#L1289
I think maybe we can add hash-based lookup to Vocab as a supplement.
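
As a rough illustration of the hash-bucket idea, the sketch below enumerates the character n-grams of a word and maps each one to a bucket index with a 32-bit FNV-1a hash. It is not claimed to reproduce the bucket ids of gensim or fastText; the function names, the n-gram range, and the bucket count are assumptions for illustration.

```python
def char_ngrams(word, min_n=3, max_n=6):
    """Character n-grams of '<word>', with '<' and '>' as boundary markers."""
    wrapped = f"<{word}>"
    return [wrapped[i:i + n]
            for n in range(min_n, max_n + 1)
            for i in range(len(wrapped) - n + 1)]

def fnv1a_32(text):
    """32-bit FNV-1a hash over the UTF-8 bytes of text."""
    h = 2166136261  # FNV offset basis
    for byte in text.encode("utf-8"):
        h ^= byte
        h = (h * 16777619) & 0xFFFFFFFF  # FNV prime, kept to 32 bits
    return h

def ngram_bucket_ids(word, num_buckets=2_000_000, min_n=3, max_n=6):
    """Bucket index for every n-gram; no n-gram vocabulary is needed."""
    return [fnv1a_32(g) % num_buckets for g in char_ngrams(word, min_n, max_n)]

print(char_ngrams("where", 3, 3))
print(ngram_bucket_ids("where", min_n=3, max_n=3))
```

Because the index is computed from the n-gram itself, an unseen word still maps to valid rows of the bucket embedding table, at the cost of occasional collisions.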

@sxjscience
Member

sxjscience commented Jul 21, 2020 via email

from .evaluation import CosineSimilarity, HyperbolicCosineSimilarity
from ..data import Vocab

class StaticEmbedding:
Contributor

For the AutoML use-cases of data augmentation, we may need more efficient similarity implementations (e.g., https://github.com/facebookresearch/faiss). So associating a similarity function with an embedding class may not be a good way forward. For this PR we can focus on adding load_embeddings and an equivalent for the fasttext package, and it may not be necessary to add embed_container.py.

import numpy as np
from ..op import l2_normalize

__all__ = ['CosineSimilarity', 'ThreeCosMul', 'ThreeCosAdd', 'HyperbolicCosineSimilarity']
Contributor

We can add these in a scripts/ folder as part of the evaluation scripts for the similarity and analogy datasets. We don't need to add it to the main API as part of this PR. (The scripts/ for evaluation are not required for automl and you don't need to add it if you're not interested in trying it)

@acphile acphile marked this pull request as ready for review July 22, 2020 06:50
@sxjscience
Member

Can you add some tests in https://github.com/dmlc/gluon-nlp/tree/numpy/tests?

with pytest.raises(ValueError):
    get_fasttext_model('wiki.multi.ar')


Member

nit: newline.

Member

@sxjscience sxjscience left a comment

LGTM

raise ValueError('Cannot recognize {} for the bin file'.format(source))
file_name, file_hash = C.FAST_TEXT_BIN_SHA1[source]
file_path = _get_file_path('fasttext', file_name, file_hash)
return fasttext.load_model(file_path)
Member

I just noticed one issue. Can we support multiprocessing with FastText embedding?

Contributor Author

Can you give a concrete example about multiprocessing with FastText embedding? I'm not very clear about it. Thanks!

Member

Basically, for example:

fasttext_model = load_fasttext(...)
with multiprocessing.Pool(4) as pool:
   out = pool.map(..., fasttext_model.encode(...))

Contributor Author

Do you mean that we need a function that supports multiprocessing internally using fasttext.cc's APIs, as in your example, or are you worried about whether fasttext.cc's APIs work in a multiprocessing setting?

Member

After offline discussion with Leo, this looks good and let me merge this in.

@sxjscience sxjscience merged commit d76897b into dmlc:numpy Jul 28, 2020
zheyuye added a commit to zheyuye/gluon-nlp that referenced this pull request Jul 29, 2020
commit 232e0b6
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 01:05:17 2020 +0800

    update

commit 995e5d7
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 01:01:56 2020 +0800

    fix

commit 9623240
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 00:52:17 2020 +0800

    fix

commit d9c4140
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Wed Jul 29 23:07:10 2020 +0800

    fix transformer

commit e49fbe1
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Wed Jul 29 22:18:12 2020 +0800

    update

commit 1f75b26
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Wed Jul 29 22:04:08 2020 +0800

    test bart

commit 5bab516
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Wed Jul 29 21:34:47 2020 +0800

    fix cfg

commit 6c62a29
Merge: 3366cf3 033214e
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Wed Jul 29 21:33:10 2020 +0800

    Merge remote-tracking branch 'upstream/numpy' into bart

commit 033214e
Author: Xingjian Shi <xshiab@connect.ust.hk>
Date:   Wed Jul 29 00:36:57 2020 -0700

    [Numpy] Fix SQuAD + Fix GLUE downloading (dmlc#1280)

    * Update run_squad.py

    * Update run_squad.py

    * Update prepare_glue.py

commit 3c87457
Author: Xingjian Shi <xshiab@connect.ust.hk>
Date:   Tue Jul 28 18:03:21 2020 -0700

    Add layout + compute_layout support: TransformerNMT, BERT, ALBERT, ELECTRA, MobileBERT, RoBERTA, XLMR (dmlc#1258)

    * Add layout support

    * fix test

    * Update transformer.py

    * Update transformer.py

    * Update README.md

    * try to add set_layout

    * update test case

    * fix

    * update

    * update

    * update

    * Update bert.py

    * fix bug

    * update

    * Update test_models_bert.py

    * Update tokenizers.py

    * add compute layout

    * Update xlmr.py

    * Update test_models_bert.py

    * revise test cases

    * Update layers.py

    * move jieba to try import

    * fix

    * Update transformer.py

    * fix

    * Update bert.py

    * Update setup.py

    * Update test_models_bert.py

    * Update test_models_bert.py

    * fix

    * update

    * Revise

    * Update electra.py

    * Update electra.py

    * Update test_models_electra.py

    * fix

    * fix bug

    * Update test_models_albert.py

    * add more testcases

    * fix

    * Update albert.py

    * Update albert.py

    * fix bug

    * fix testcase

    * Update test_models_electra.py

    * Update bert.py

    * update

    * Update test_models_electra.py

    * Update mobilebert.py

    * Update mobilebert.py

    * update mobilebert

    * Update test_models_mobilebert.py

    * Update mobilebert.py

    * fix bug

    * Update roberta.py

    * fix roberta

    * update

    * update

    * fix import

    * fix bug

    * update

    * reduce test workloads

    * address comment

    * address comment

commit 4d43f82
Author: Sheng Zha <szha@users.noreply.github.com>
Date:   Mon Jul 27 20:21:00 2020 -0700

    add subversion/wget to docker, add readme (dmlc#1279)

commit d76897b
Author: phile <phile_999@126.com>
Date:   Tue Jul 28 10:10:13 2020 +0800

    Add embedding related methods in numpy version (dmlc#1263)

    * A draft for embedding

    * fix embed_loader

    * add hyperbolic space and some updates

    * revise evaluation

    * fix

    * simple fixes

    * move l2norm to op.py

    * new features

    * fix

    * update

    * add tests, update

    * newline
zheyuye added a commit to zheyuye/gluon-nlp that referenced this pull request Jul 29, 2020
commit 510d991
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 02:33:22 2020 +0800

    test

commit 1b5fa7b
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 01:48:01 2020 +0800

    fix comment1

commit 6533601
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 01:27:44 2020 +0800

    fix comment

commit a8853f9
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 01:10:06 2020 +0800

    Squashed commit of the following:

zheyuye added a commit to zheyuye/gluon-nlp that referenced this pull request Jul 30, 2020
commit 9e1ffde
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 11:42:01 2020 +0800

    todo

commit 9a7c343
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 10:53:15 2020 +0800

    revert gelu

commit 0425346
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 10:49:52 2020 +0800

    re-upload bart

commit 516ae84
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 03:32:35 2020 +0800

    use_qkv_bias for transformer

commit 9d60cda
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 03:17:28 2020 +0800

    classifier_activation

sxjscience pushed a commit that referenced this pull request Jul 30, 2020
* init

* fix convert roberta

* rename TransformerNMTModel as TransformerModel

* update bart

* fix

* fix

* update init

* add layernorm_embedding for transformer

* convert script

* encoder

* fix

* fix vocab

* fix roberta

* fix

* fix electra

* add conversion bash for roberta and xlmr

* ELECTRA SETUP

* convert bart decoder

* fix

* update

* testing output

* remove arange_like for embeddings

* fix

* update

* use_pooler for bart

* fix

* upload params for bart

* add test_models_bart

* fix cfg

* test bart

* update

* fix transformer

* Squashed commit of the following:


        * update

        * update

        * fix import

        * fix bug

        * update

        * reduce test workloads

        * address comment

        * address comment

    commit 4d43f82
    Author: Sheng Zha <szha@users.noreply.github.com>
    Date:   Mon Jul 27 20:21:00 2020 -0700

        add subversion/wget to docker, add readme (#1279)

    commit d76897b
    Author: phile <phile_999@126.com>
    Date:   Tue Jul 28 10:10:13 2020 +0800

        Add embedding related methods in numpy version (#1263)

        * A draft for embedding

        * fix embed_loader

        * add hyperbolic space and some updates

        * revise evaluation

        * fix

        * simple fixes

        * move l2norm to op.py

        * new features

        * fix

        * update

        * add tests, update

        * newline

* Squashed commit of the following:

commit 9e1ffde
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 11:42:01 2020 +0800

    todo

commit 9a7c343
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 10:53:15 2020 +0800

    revert gelu

commit 0425346
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 10:49:52 2020 +0800

    re-upload bart

commit 516ae84
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 03:32:35 2020 +0800

    use_qkv_bias for transformer

commit 9d60cda
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 03:17:28 2020 +0800

    classifier_activation

commit 510d991
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 02:33:22 2020 +0800

    test

commit 1b5fa7b
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 01:48:01 2020 +0800

    fix comment1

commit 6533601
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 01:27:44 2020 +0800

    fix comment

commit a8853f9
Author: ZheyuYe <zheyu.ye1995@gmail.com>
Date:   Thu Jul 30 01:10:06 2020 +0800

    Squashed commit of the following:

    commit 232e0b6
    Author: ZheyuYe <zheyu.ye1995@gmail.com>
    Date:   Thu Jul 30 01:05:17 2020 +0800

        update

    commit 995e5d7
    Author: ZheyuYe <zheyu.ye1995@gmail.com>
    Date:   Thu Jul 30 01:01:56 2020 +0800

        fix

    commit 9623240
    Author: ZheyuYe <zheyu.ye1995@gmail.com>
    Date:   Thu Jul 30 00:52:17 2020 +0800

        fix

    commit d9c4140
    Author: ZheyuYe <zheyu.ye1995@gmail.com>
    Date:   Wed Jul 29 23:07:10 2020 +0800

        fix transformer

    commit e49fbe1
    Author: ZheyuYe <zheyu.ye1995@gmail.com>
    Date:   Wed Jul 29 22:18:12 2020 +0800

        update

    commit 1f75b26
    Author: ZheyuYe <zheyu.ye1995@gmail.com>
    Date:   Wed Jul 29 22:04:08 2020 +0800

        test bart

    commit 5bab516
    Author: ZheyuYe <zheyu.ye1995@gmail.com>
    Date:   Wed Jul 29 21:34:47 2020 +0800

        fix cfg

    commit 6c62a29
    Merge: 3366cf3 033214e
    Author: ZheyuYe <zheyu.ye1995@gmail.com>
    Date:   Wed Jul 29 21:33:10 2020 +0800

        Merge remote-tracking branch 'upstream/numpy' into bart

    commit 033214e
    Author: Xingjian Shi <xshiab@connect.ust.hk>
    Date:   Wed Jul 29 00:36:57 2020 -0700

        [Numpy] Fix SQuAD + Fix GLUE downloading (#1280)

        * Update run_squad.py

        * Update run_squad.py

        * Update prepare_glue.py

    commit 3c87457
    Author: Xingjian Shi <xshiab@connect.ust.hk>
    Date:   Tue Jul 28 18:03:21 2020 -0700

        Add layout + compute_layout support: TransformerNMT, BERT, ALBERT, ELECTRA, MobileBERT, RoBERTA, XLMR (#1258)

        * Add layout support

        * fix test

        * Update transformer.py

        * Update transformer.py

        * Update README.md

        * try to add set_layout

        * update test case

        * fix

        * update

        * update

        * update

        * Update bert.py

        * fix bug

        * update

        * Update test_models_bert.py

        * Update tokenizers.py

        * add compute layout

        * Update xlmr.py

        * Update test_models_bert.py

        * revise test cases

        * Update layers.py

        * move jieba to try import

        * fix

        * Update transformer.py

        * fix

        * Update bert.py

        * Update setup.py

        * Update test_models_bert.py

        * Update test_models_bert.py

        * fix

        * update

        * Revise

        * Update electra.py

        * Update electra.py

        * Update test_models_electra.py

        * fix

        * fix bug

        * Update test_models_albert.py

        * add more testcases

        * fix

        * Update albert.py

        * Update albert.py

        * fix bug

        * fix testcase

        * Update test_models_electra.py

        * Update bert.py

        * update

        * Update test_models_electra.py

        * Update mobilebert.py

        * Update mobilebert.py

        * update mobilebert

        * Update test_models_mobilebert.py

        * Update mobilebert.py

        * fix bug

        * Update roberta.py

        * fix roberta

        * update

        * update

        * fix import

        * fix bug

        * update

        * reduce test workloads

        * address comment

        * address comment

    commit 4d43f82
    Author: Sheng Zha <szha@users.noreply.github.com>
    Date:   Mon Jul 27 20:21:00 2020 -0700

        add subversion/wget to docker, add readme (#1279)

    commit d76897b
    Author: phile <phile_999@126.com>
    Date:   Tue Jul 28 10:10:13 2020 +0800

        Add embedding related methods in numpy version (#1263)

        * A draft for embedding

        * fix embed_loader

        * add hyperbolic space and some updates

        * revise evaluation

        * fix

        * simple fixes

        * move l2norm to op.py

        * new features

        * fix

        * update

        * add tests, update

        * newline

* fix comment

* use xavier for embedding initializer