
add model resources for CPMAnt (new) #20906

Merged: 113 commits merged into huggingface:main on Apr 12, 2023

Conversation

@pioliverse (Contributor) commented Dec 27, 2022

What does this PR do?

Since the previous submission (#20711) had problems here and there, we have now resubmitted a new one.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@pioliverse (Contributor Author) commented Dec 27, 2022

Thanks very much @pioliverse for iterating! I left a couple of comments; I think that some refactoring needs to be considered, and after that we should be close to merging this! My main comments are:

  • I think that you can wrap CPMAntEmbedding around an nn.Embedding layer even though scaling is needed. You can scale down after each call to the embedding module and make sure the input is scaled down before the projection call.
  • Make sure to inherit CPMAntForCausalLM from CPMAntPreTrainedModel, and also make sure to follow the convention / good practices by checking what is done in OPT, for instance:
    class OPTForCausalLM(OPTPreTrainedModel):
    • this includes correctly defining an lm_head module and functions such as get_input_embeddings, set_input_embeddings, etc.
  • A lot of arguments from the modules' init seem to be unused, e.g. init_std. Also try to take the config object as a single argument in the init whenever possible (e.g. CPMAntEncoder).
  • Please make sure to follow the correct styling for docstrings (check my comments about that below).
  • If you have to initialize some weights with a specific distribution, try to initialize all the submodule weights inside the _init_weights function of CPMAntPreTrainedModel.
  • It's unclear to me why the forward function is not defined in CPMAntForCausalLM.
  • The code can be optimized here and there; I left some comments below on how you can achieve that.
  • Please do not raise RuntimeErrors outside if is_torch_available(), otherwise the flax & tf tests will fail.

Again thanks a lot for your efforts!

@younesbelkada Thanks for your patience in reviewing. I followed the OPT convention and made the following changes:

  • CPMAntEmbedding and CPMAntLinear have been replaced by nn.Embedding and nn.Linear, respectively.
  • CPMAntForCausalLM now inherits from CPMAntPreTrainedModel, and lm_head and the related helper functions (get_input_embeddings, set_input_embeddings, etc.) have been added; a generic sketch of this pattern is shown below.
  • Unused init arguments have been removed.
  • forward has been defined in CPMAntForCausalLM.
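For readers unfamiliar with the OPT-style convention referenced above, here is a minimal, self-contained sketch of the pattern (the tiny classes, sizes, and names below are illustrative assumptions, not the CPMAnt code itself):

import torch
import torch.nn as nn


class TinyBackbone(nn.Module):
    # Stand-in for the base model; it owns the input embedding.
    def __init__(self, vocab_size=100, hidden_size=32):
        super().__init__()
        self.input_embedding = nn.Embedding(vocab_size, hidden_size)
        self.layer = nn.Linear(hidden_size, hidden_size)

    def forward(self, input_ids):
        return self.layer(self.input_embedding(input_ids))


class TinyForCausalLM(nn.Module):
    # Causal-LM wrapper: backbone + lm_head + embedding accessors.
    def __init__(self, vocab_size=100, hidden_size=32):
        super().__init__()
        self.model = TinyBackbone(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def get_input_embeddings(self):
        return self.model.input_embedding

    def set_input_embeddings(self, value):
        self.model.input_embedding = value

    def forward(self, input_ids):
        hidden_states = self.model(input_ids)
        return self.lm_head(hidden_states)  # logits over the vocabulary


logits = TinyForCausalLM()(torch.randint(0, 100, (2, 8)))  # shape (2, 8, 100)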

@pioliverse (Contributor Author):

@younesbelkada Thanks again for your patience in reviewing.

@younesbelkada (Contributor) left a comment:

Thanks so much for your patience! Looks pretty clean, thank you! We should be close to merging this once most of the comments are addressed. My comments are:

Docstring and comments:

  • please harmonize the function docstrings to match the convention transformers models follow
  • please make sure to clean up some comments
  • it would also be nice to add a small explanation in the code of why generate needs to be overridden

dtype:

  • I don't think the argument dtype is needed. The dtype of the whole model is managed by the kwarg torch_dtype so you can load your model using model = xxxForCausalLM.from_pretrained(xxx, torch_dtype=torch.float16) or torch_dtype="auto" (if the weights are pushed in fp16) and the model will be loaded in the desired precision.
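    As an illustration of that suggestion (the checkpoint name is the one discussed in this PR; any causal-LM checkpoint is handled the same way, so treat this as a sketch rather than a recommendation to load the 10B model):

    import torch
    from transformers import AutoModelForCausalLM

    # Load directly in half precision, or let torch_dtype="auto" reuse the dtype stored with the weights.
    model = AutoModelForCausalLM.from_pretrained("openbmb/cpm-ant-10b", torch_dtype=torch.float16)
    model = AutoModelForCausalLM.from_pretrained("openbmb/cpm-ant-10b", torch_dtype="auto")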

tests

  • I think that a test is failing, please double check that

general comments

Thanks!

Comment on lines 109 to 111
# assert (
# pointer.shape == array.shape
# ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
Contributor:

to clean up!

Comment on lines 227 to 228
# assert x.size(-1) == self.dim_norm
if x.size(-1) != self.dim_norm:
Contributor:

to clean up

Comment on lines 341 to 346
dim_model (int): Main dimension of modules in transformer blocks.
num_heads (int): Number of attention heads in the Transformer encoder.
dim_head (int): Dimension of attention heads for each attention layer in the Transformer encoder.
dtype (optional): Defaults to torch.float.
eps (float, optional): The epsilon used by the layer normalization layers.
dropout_p (float, optional): Defaults to 0.
Contributor:

Please respect the convention as above ;)

dim_model (`int`):
    Main dimension of ...
...

Comment on lines 499 to 503
dim_model (int): Main dimension of modules in transformer blocks.
dim_ff (int): Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
dtype (optional): Defaults to torch.float.
eps (float, optional): The epsilon used by the layer normalization layers.
dropout_p (float, optional): Defaults to 0.
Contributor:

Same comment as above!

Comment on lines 648 to 655
num_layers (int): Number of layers.
dim_model (int): Main dimension of modules in transformer blocks.
dim_ff (int): Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_heads (int): Number of attention heads in the Transformer encoder.
dim_head (int): Dimension of attention heads for each attention layer in the Transformer encoder.
dtype (optional): Defaults to torch.float.
eps (float, optional): The epsilon used by the layer normalization layers.
dropout_p (float, optional): Defaults to 0.
Contributor:

same as above

attention_mask,
position_bias,
past_key_value=past_key_values[i] if past_key_values else None,
use_cache=use_cache,
Contributor:

I see that use_cache is also passed here, I am slightly confused why you don't pass it above too

Comment on lines 795 to 796
# assert key_pos.size(0) == query_pos.size(0)
# assert keylen == key_segment.size(1) and querylen == query_segment.size(1)
Contributor:

to clean up

Comment on lines 1041 to 1051
with torch.no_grad():
device = input.device
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(
-1, 1
)
attention_mask = context[:, None, :] | (
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
)
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
mask_1d = torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) < length[:, None]
attention_mask = mask_1d.view(batch, seqlen, 1) & mask_1d.view(batch, 1, seqlen) & attention_mask
Contributor:

nit: You can probably wrap that in a class method _prepare_attention_mask
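A sketch of what that refactor could look like; the method name comes from the comment above, the body simply reuses the mask logic quoted in the diff, and deriving batch/seqlen from input is an assumption:

def _prepare_attention_mask(self, input, span, context, length):
    batch, seqlen = input.size(0), input.size(1)
    device = input.device
    directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(
        seqlen, device=device
    ).view(-1, 1)
    attention_mask = context[:, None, :] | (
        context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
    )
    attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
    mask_1d = torch.arange(seqlen, device=device)[None, :].repeat(batch, 1) < length[:, None]
    return mask_1d.view(batch, seqlen, 1) & mask_1d.view(batch, 1, seqlen) & attention_mask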

Comment on lines 1111 to 1124
with torch.no_grad():
device = input.device
directional_mask_2d = torch.arange(seqlen, device=device) <= torch.arange(seqlen, device=device).view(
-1, 1
)
attention_mask = context[:, None, :] | (
context[:, :, None].logical_not() & directional_mask_2d.view(1, seqlen, seqlen)
)
attention_mask = attention_mask & (span[:, None, :] == span[:, :, None])
# mask for left padding
mask_1d = (
torch.tensor(list(range(seqlen))[::-1], device=device)[None, :].repeat(batch, 1) < length[:, None]
)
attention_mask = mask_1d.view(batch, seqlen, 1) & mask_1d.view(batch, 1, seqlen) & attention_mask
Contributor:

same, you can wrap that in a method

hidden_states, attention_mask, position_bias, True, past_key_values
)
logits = self.lm_head(hidden_states)
return logits, hidden_states, present_key_values
Contributor:

Can you return a dataclass object instead? For example:

or you can also define your own class
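One possibility is an existing ModelOutput dataclass such as CausalLMOutputWithPast from transformers.modeling_outputs (a hedged sketch with dummy tensors; the PR may use a different output class or define its own):

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast

# Dummy tensors stand in for the values computed in the model's forward pass.
logits = torch.zeros(1, 4, 10)
hidden_states = torch.zeros(1, 4, 8)

outputs = CausalLMOutputWithPast(logits=logits, hidden_states=(hidden_states,))
print(outputs.logits.shape)  # named attribute access
print(outputs[0].shape)      # tuple-style indexing still works; None fields are skipped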

@HuggingFaceDocBuilderDev commented Jan 9, 2023

The documentation is not available anymore as the PR was closed or merged.

@pioliverse pioliverse closed this Jan 12, 2023
@pioliverse pioliverse reopened this Jan 12, 2023
@pioliverse pioliverse marked this pull request as draft January 12, 2023 03:22
@pioliverse pioliverse changed the title add model resources for CPMAnt (new) 【WIP】add model resources for CPMAnt (new) Jan 12, 2023
@pioliverse pioliverse changed the title 【WIP】add model resources for CPMAnt (new) [WIP] add model resources for CPMAnt (new) Jan 12, 2023
@pioliverse pioliverse changed the title [WIP] add model resources for CPMAnt (new) WIP: add model resources for CPMAnt (new) Jan 12, 2023
@HuggingFaceDocBuilderDev commented Jan 12, 2023

The documentation is not available anymore as the PR was closed or merged.

@pioliverse pioliverse changed the title WIP: add model resources for CPMAnt (new) add model resources for CPMAnt (new) Feb 2, 2023
@pioliverse (Contributor Author):

(Quoting @younesbelkada's review above.)

Hi @younesbelkada, we have made some changes as follows:

  1. added some docstrings.
  2. modified forward following the style of transformers.
  3. rewrote some functions to adapt to the generate function (see the sketch below):
  • in modeling_cpmant.py, we rewrote functions such as prepare_inputs_for_generation and _expand_inputs_for_generation
  • in tokenization_cpmant.py, we rewrote functions such as prepare_for_model, _pad, _encode_plus, and _batch_encode_plus
  4. cleaned up some comments.
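For context, such an override typically looks roughly like the sketch below (it follows other decoder-only models in transformers and is not the PR's exact implementation):

def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
    # Once a cache exists, only the most recent token needs to be fed to the model.
    if past_key_values is not None:
        input_ids = input_ids[:, -1:]
    return {
        "input_ids": input_ids,
        "past_key_values": past_key_values,
        "use_cache": kwargs.get("use_cache"),
    }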

@pioliverse pioliverse marked this pull request as ready for review February 2, 2023 10:58
@younesbelkada (Contributor) left a comment:

Thanks a lot for addressing most of the comments of the previous review! And thank you for your huge work on refactoring the modeling script.
I left some comments, mostly nits that can be solved easily. Note that for arguments such as use_cache etc., we prefer to pass them through the forward pass rather than setting them as class attributes.
Also, please consider passing a CPMAntConfig to the classes that have several attributes, such as CPMAntEncoder.
Also make sure to correctly pass the required keyword arguments such as past_key_values, output_attentions, etc., which are crucial for the caching mechanism. You can check how this is done in OPT, for example.
Finally, the naming convention in transformers has changed a bit: we prefer to name models with a single capital letter (i.e. here CPMAnt -> Cpmant).
Again thanks for your efforts on this! Once the comments are resolved, we should be very close to merging this!

]


def load_tf_weights_in_cpmant(model, config, tf_checkpoint_path):
Contributor:

This seems to be a function that is adapted from:

def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
Can you add a # Adapted from statement at the top of the function?

Collaborator:

This can be removed entirely, there are no TF weights to convert here, no?

Comment on lines 136 to 141

super().__init__()

self.eps = eps
self.dim_norm = dim_norm
self.weight = torch.nn.parameter.Parameter(torch.full((dim_norm,), init_var))
Contributor:

Suggested change
super().__init__()
self.eps = eps
self.dim_norm = dim_norm
self.weight = torch.nn.parameter.Parameter(torch.full((dim_norm,), init_var))
super().__init__()
self.eps = eps
self.dim_norm = dim_norm
self.weight = torch.nn.parameter.Parameter(torch.full((dim_norm,), init_var))

Comment on lines 128 to 131
"""RMS LayerNorm"""

Contributor:

Could you add more description here?

Contributor Author:

We added some comments.

Comment on lines 193 to 194
Indices of input sequence tokens of shape `(batch, len_q, dim_model)`. It will be embedded by model's
internal embedding lookup matrix.
Contributor:

This description seems to be wrong

Contributor Author:

Revised.

Avoid invalid areas to participate in the calculation of self-attention.
position_bias (`torch.Tensor` of shape `(batch, len_seq, len_seq)`):
Provide positional information to self-attention block.
past_kv (`Tuple(torch.FloatTensor)`, *optional*): Cached past key and value projection states.
Contributor:

Suggested change
past_kv (`Tuple(torch.FloatTensor)`, *optional*): Cached past key and value projection states.
past_kv (`Tuple(torch.FloatTensor)`, *optional*):
Cached past key and value projection states.

return hidden_states, current_key_values


class CPMAntIntermediate(nn.Module):
Contributor:

This seems to be copied from BertIntermediate:

class BertIntermediate(nn.Module):
can you add a # Copied from statement?

Contributor:

Here I meant:

# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->CPMAnt

Check for example here:

# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->SwitchTransformers

return relative_buckets


class CPMAntOutput(nn.Module):
Contributor:

This seems to be copied from BertOutput:

class BertOutput(nn.Module):

Contributor Author:

Yes, we added a statement.

Contributor:

Same as above, what I meant is:

# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->CPMAnt

Check

# Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->SwitchTransformers

position_bias = self.position_bias(position, position, segment, segment)

hidden_states = self.encoder(hidden_states, attention_mask, position_bias)
logits = F.linear(hidden_states, self.input_embedding.weight)
Contributor:

can't you call self.input_embedding directly?

Contributor Author:

Maybe not, self.input_embedding works for input_ids, but not for hidden_states.
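A minimal demo of the distinction being made here: nn.Embedding maps integer ids to vectors, while F.linear with the embedding weight projects hidden states back onto the vocabulary (the usual weight-tying trick). Shapes are arbitrary:

import torch
import torch.nn.functional as F
from torch import nn

embedding = nn.Embedding(num_embeddings=10, embedding_dim=4)
input_ids = torch.tensor([[1, 2, 3]])

hidden_states = embedding(input_ids)                # ids -> vectors, shape (1, 3, 4)
logits = F.linear(hidden_states, embedding.weight)  # vectors -> vocab scores, shape (1, 3, 10)
# embedding(hidden_states) would fail: nn.Embedding only accepts integer indices.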

if not return_dict:
return tuple(v for v in [logits, hidden_states] if v is not None)

return BaseModelOutput(hidden_states=hidden_states)
Contributor:

some attributes are missing such as the attention outputs etc.

Contributor Author:

We added some attributes.

Comment on lines 883 to 759
span: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = False,
**kwargs,
Contributor:

keyword arguments that are needed for public models such as output_attentions seem to be missing here

Contributor Author:

We added output_attentions.

@pioliverse (Contributor Author):

I am a bit surprised that when I use make style, some other files are also reformatted, which causes check_code_quality to fail.

@younesbelkada (Contributor):

Hi @pioliverse
You need to rebase with the main branch, as the styling has been updated for most of the files in transformers, and update your black version as follows:

pip install --upgrade -e .["quality"]

Then make style or make fixup

@pioliverse (Contributor Author):

(Quoting @younesbelkada's instructions above.)

Thanks @younesbelkada, this has been solved.

@pioliverse (Contributor Author):

(Quoting @younesbelkada's review above.)

Thanks for your review @younesbelkada, we have modified some code.

  • We now pass use_cache through the forward function instead of setting it as a class attribute.
  • We simplified the class attribute assignments and replaced them with a CPMAntConfig.
  • We added past_key_values and output_attentions to the forward of CPMAntModel.
  • I kind of wonder whether all files that contain the name CPMAnt should be changed to Cpmant?

@gongbaitao (Contributor) commented Feb 21, 2023

Hi @younesbelkada, I am a member of OpenBMB, and I will help @pioliverse finish this PR.

All the issues mentioned above have been resolved. Please kindly have a look.

For the unit tests, I rebased pioliverse:cpmantmodel onto huggingface:main, but it still cannot pass the tests. It seems some other models cause the failure?

For instance, in tests_onnx I hit this error:
ERROR tests/models/altclip/test_modeling_altclip.py ============ 72 passed, 551 skipped, 29 warnings, 1 error in 28.26s ============
How can I avoid such an error?

@younesbelkada (Contributor):

Hi @gongbaitao
Thanks for jumping in! And sorry for the delay.
Rebasing with main should probably solve this issue. I will look into the PR asap; let me know once you think this is ready for review!

@gongbaitao (Contributor) commented Mar 5, 2023

@younesbelkada @sgugger Thanks for the valuable comments!
According to the new comments, I have dropped some redundant code and renamed the model classes in a camel-cased way :)

@sgugger (Collaborator) left a comment:

Thanks for iterating! There are still a couple of issues here and there. Also, the test added for the tokenizer will need to be decorated with a requires_jieba decorator (which you will need to define in testing_utils, similar to the other requires_xxx functions). Lastly, you also need to add an import error for jieba in import_utils.py so that requires_backend(["jieba"]) works without error.
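A sketch of the requested pattern, mirroring how other optional backends are handled in the library; the exact helper names and messages in import_utils.py / testing_utils.py are assumptions here:

import importlib.util
import unittest

_jieba_available = importlib.util.find_spec("jieba") is not None


def is_jieba_available():
    return _jieba_available


def require_jieba(test_case):
    # Decorator that skips a test when jieba is not installed.
    return unittest.skipUnless(is_jieba_available(), "test requires jieba")(test_case)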

README.md Outdated
@@ -309,6 +309,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h
1. **[ConvBERT](https://huggingface.co/docs/transformers/model_doc/convbert)** (from YituTech) released with the paper [ConvBERT: Improving BERT with Span-based Dynamic Convolution](https://arxiv.org/abs/2008.02496) by Zihang Jiang, Weihao Yu, Daquan Zhou, Yunpeng Chen, Jiashi Feng, Shuicheng Yan.
1. **[ConvNeXT](https://huggingface.co/docs/transformers/model_doc/convnext)** (from Facebook AI) released with the paper [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545) by Zhuang Liu, Hanzi Mao, Chao-Yuan Wu, Christoph Feichtenhofer, Trevor Darrell, Saining Xie.
1. **[CPM](https://huggingface.co/docs/transformers/model_doc/cpm)** (from Tsinghua University) released with the paper [CPM: A Large-scale Generative Chinese Pre-trained Language Model](https://arxiv.org/abs/2012.00413) by Zhengyan Zhang, Xu Han, Hao Zhou, Pei Ke, Yuxian Gu, Deming Ye, Yujia Qin, Yusheng Su, Haozhe Ji, Jian Guan, Fanchao Qi, Xiaozhi Wang, Yanan Zheng, Guoyang Zeng, Huanqi Cao, Shengqi Chen, Daixuan Li, Zhenbo Sun, Zhiyuan Liu, Minlie Huang, Wentao Han, Jie Tang, Juanzi Li, Xiaoyan Zhu, Maosong Sun.
1. **[CpmAnt](https://huggingface.co/docs/transformers/main/model_doc/cpmant)** (from OpenBMB) released by the [OpenBMB](https://www.openbmb.org/).
Collaborator:

Suggested change
1. **[CpmAnt](https://huggingface.co/docs/transformers/main/model_doc/cpmant)** (from OpenBMB) released by the [OpenBMB](https://www.openbmb.org/).
1. **[CPMAnt](https://huggingface.co/docs/transformers/main/model_doc/cpmant)** (from OpenBMB) released by the [OpenBMB](https://www.openbmb.org/).

This should use the model name casing. It's only the model/config/tokenizer classes that should be CpmAntXxx

Contributor:

Solved.

@@ -402,6 +404,7 @@
("convbert", "ConvBERT"),
("convnext", "ConvNeXT"),
("cpm", "CPM"),
("cpmant", "CpmAnt"),
Collaborator:

Suggested change
("cpmant", "CpmAnt"),
("cpmant", "CPM-Ant"),

Contributor:

Solved.

logger = logging.get_logger(__name__)

CPMANT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"cpm-ant-10b": "https://huggingface.co/openbmb/cpm-ant-10b/blob/main/config.json"
Collaborator:

Suggested change
"cpm-ant-10b": "https://huggingface.co/openbmb/cpm-ant-10b/blob/main/config.json"
"openbmb/cpm-ant-10b": "https://huggingface.co/openbmb/cpm-ant-10b/blob/main/config.json"

Contributor:

Solved.

This is the configuration class to store the configuration of a [`CpmAntModel`]. It is used to instantiate an
CPMAnt model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the CPMAnt
[cpm-ant-10b](https://huggingface.co/openbmb/cpm-ant-10b) architecture.
Collaborator:

Suggested change
[cpm-ant-10b](https://huggingface.co/openbmb/cpm-ant-10b) architecture.
[openbmb/cpm-ant-10b](https://huggingface.co/openbmb/cpm-ant-10b) architecture.

Contributor:

Solved.


PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {
"cpm-ant-10b": "https://huggingface.co/openbmb/cpm-ant-10b/blob/main/vocab.txt",
Collaborator:

Suggested change
"cpm-ant-10b": "https://huggingface.co/openbmb/cpm-ant-10b/blob/main/vocab.txt",
"openbmb/cpm-ant-10b": "https://huggingface.co/openbmb/cpm-ant-10b/blob/main/vocab.txt",

Contributor:

Solved.

}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
"cpm-ant-10b": 1024,
Collaborator:

Suggested change
"cpm-ant-10b": 1024,
"openbmb/cpm-ant-10b": 1024,

Contributor:

Solved.

Comment on lines 101 to 115
bod_token (`str`, *optional*, defaults to `<d>`):
The beginning of document token.
eod_token (`str`, *optional*, defaults to `</d>`):
The end of document token.
bos_token (`str`, *optional*, defaults to `<s>`):
The beginning of sequence token.
eos_token (`str`, *optional*, defaults to `</s>`):
The end of sequence token.
pad_token (`str`, *optional*, defaults to `<pad>`):
The token used for padding.
unk_token (`str`, *optional*, defaults to `<unk>`):
The unknown token.
line_token (`str`, *optional*, defaults to `</n>`):
The line token.
space_token (`str`, *optional*, defaults to `</_>`):
Collaborator:

Default values are all missing double quotes here.

Contributor:

Solved.

from transformers.models.cpmant import CpmAntTokenizer


@unittest.skip("CPMAntTokenizer process vocab in list format, so we skip the common test.")
Collaborator:

The test below won't be executed because of this global skip here.

Contributor:

I skip the test because its load_vocab logic is different from the TokenizerTesterMixin. Refactoring it does not seem convenient or necessary to me, so I just skip it. Is that OK, or do I need to make some changes?

Collaborator:

No this is not ok, since the test you wrote below will never be executed. If all tests fail in the TokenizerTesterMixin, that means your tokenizer does not have an API consistent with the other tokenizers of Transformers, and thus we can't accept it. You should fix your tokenizer so that it passes most of the tests of the common tester.

Contributor:

OK I will fix it, thanks!

Contributor:

The test below won't be executed because of this global skip here.

Solved.

@github-actions:

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@gongbaitao (Contributor):

Hi @sgugger @younesbelkada, sorry for the delay!
In the last few weeks, I have fixed the problems mentioned above and refactored the CPMAnt tokenizer. Please kindly have a look again, thanks for your help!

@sgugger (Collaborator) left a comment:

Thanks for iterating! We're almost good to go, just a couple of comments left to address in the tests.

@require_torch
class CPMAntModelIntegrationTest(unittest.TestCase):
@slow
@unittest.skip("skip this test as the model is very large for our daily runner")
Collaborator:

Corresponding changes here do not seem to have been pushed.

@require_torch
class CPMAntForCausalLMlIntegrationTest(unittest.TestCase):
@slow
@unittest.skip("skip this test as the model is very large for our daily runner")
Collaborator:

Same here.


@custom_tokenizers
class CPMAntTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
def test_pre_tokenization(self):
Collaborator:

Still needs to be addressed.

Comment on lines 50 to 53
def __init__(
self,
config: CpmAntConfig,
):
Collaborator:

Suggested change
def __init__(
self,
config: CpmAntConfig,
):
def __init__(self, config: CpmAntConfig):

Contributor:

It seems to be the result of make style. After I fit the code on one line, I cannot pass the code quality check in CI.

Collaborator:

That's because you are not copy-pasting this suggestion as is (there is a button to accept it directly in GitHub FYI) but are leaving a trailing comma.

Contributor:

Sorry, I misunderstood the problem as being about fitting the code on one line. The trailing comma has now been fixed :)

Comment on lines 428 to 431
def __init__(
self,
config: CpmAntConfig,
):
Collaborator:

Suggested change
def __init__(
self,
config: CpmAntConfig,
):
def __init__(self, config: CpmAntConfig):

@gongbaitao (Contributor):

Thanks for your quick review, @sgugger!
It seems this problem (#20906 (comment)) arises because the changed-files view didn't show all commits. Maybe checking this page https://github.com/huggingface/transformers/pull/20906/files will be helpful :)
As for #20906 (comment), it cannot pass the code quality check, so shall I keep it unchanged?

@sgugger (Collaborator) left a comment:

Replied on the comment for the styling issue. The exact same line is present multiple times in the modeling file so I think you did not take the suggestion as it is written.

As for the tests, I'm sorry I was unclear: I meant that we should use the @tooslow decorator instead of skipping.

@require_torch
class CPMAntModelIntegrationTest(unittest.TestCase):
@slow
@unittest.skip("skip this test as the model is very large for our daily runner")
Collaborator:

Sorry, I meant we should use the @tooslow decorator here instead of skipping. You can import it from testing_utils.
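So the test class quoted above would end up looking roughly like this (a sketch; the test method name and body are placeholders):

import unittest

from transformers.testing_utils import require_torch, tooslow


@require_torch
class CPMAntModelIntegrationTest(unittest.TestCase):
    @tooslow
    def test_inference(self):
        ...  # full-size checkpoint test, skipped on the daily runners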

@gongbaitao (Contributor):

@sgugger Thanks for your valuable comments!
Sorry, I forgot to drop the trailing comma for the styling issue. I have now fixed the trailing comma problem and added the tooslow decorator. Please kindly have a review :)

@sgugger (Collaborator) left a comment:

Perfect, thanks for bearing with me :-)
Congrats on getting this new model merged into Transformers and thanks again for all your work!

@sgugger merged commit 523ca4e into huggingface:main Apr 12, 2023
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
* resolve conflicts

* rebase and make style

* test

* test

* test

* rebase and make style

* rebase and make style

* tests

* tests

* rewrite some functions

* rebase and make style

* fix load_tf_weights_in_cpmant

* reformat some unrelated files

* upgrade quality

* fix some bugs & docstring

* add models and tests

* solve conflicts

* resolve conflicts

* resolve conflicts

* resolve conflicts

* resolve conflicts

* tests

* resolve conflicts

* resolve conflicts

* fix load_tf_weights_in_cpmant

* reformat some unrelated files

* upgrade quality

* fix some bugs & docstring

* save resolution

* make style

* delete redefinition code

* reformat function

* reformat

* resolve conflicts

* resolve conflicts

* resolve conflicts

* resolve conflicts

* resolve conflicts

* tests

* resolve conflicts

* resolve conflicts

* fix load_tf_weights_in_cpmant

* reformat some unrelated files

* upgrade quality

* resolve conflicts

* resolve conflicts

* resolve conflicts

* resolve conflicts

* resolve conflicts

* fix load_tf_weights_in_cpmant

* reformat some unrelated files

* upgrade quality

* resolve conflicts

* make style

* fix bugs and refactor

* modify docstrings and make style

* unify import format in __init__.py

* fix import-altclp bug

* fix copies to update index.md

* fix unused config parameters

* fix unused config parameters

* fix unused config parameters

* update README_ja.md

* dummy commit for unit test

* fix attention mask

* add CPMAntTokenizer&-Fast to auto-mapping

* drop redundant changes in README_ko

* fix  defaults in docstring

* fix use_cache and some docstring

* add missing args in tokenizer

* modify tester inheritance

* add is_jieba_available

* fix some bugs

* make style and fix-copies

* add doctests

* skip integration tests

* add is_jieba_available

* fix bugs in common tests

* adjust docstrings and make style

* add argument docstring

* adjust code to some specifications

* make style and fix-copies

* add fast tokenization test

* dummy commit for unit test

* dummy commit for unit test

* dummy commit for unit test

* normalize some comments and names

* Bert->CPMAnt

* camel names and drop redundant codes

* make style and fix-coies

* add CpmTokenizerFast _import_structure

* drop cpmanttokenizerfast in model_doc

* fix some problems

* fix CPMAnt tokenization for common test

* make style and fixup

* fix copies and fixup

* fix bugs in tokenization test

* dummy commit for connection failure in unittest

* fix copies

* drop trailing comma

* fix decorator in tests

* dummy commit for connection failure in unittest

---------

Co-authored-by: Gong Baitao <gongbaitao11@gmail.com>