Description
Environment info
- transformers version: 3.5.0
- Platform: Linux-5.4.0-53-generic-x86_64-with-glibc2.10
- Python version: 3.8.3
- PyTorch version (GPU?): 1.4.0 (True)
- Tensorflow version (GPU?): 2.3.0 (False)
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no
Who can help
Information
Model I am using (Bert, XLNet ...): Encoder (RoBERTa) Decoder (GPT2) model
The problem arises when using:
- the official example scripts: (give details below)
- my own modified scripts: (give details below)
```python
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    EncoderDecoderModel,
    EncoderDecoderConfig,
    GPT2Config,
)

encoder_config = AutoConfig.from_pretrained('microsoft/codebert-base')
encoder = AutoModel.from_pretrained('microsoft/codebert-base')

decoder_config = GPT2Config(
    n_layer=6,
    n_head=encoder_config.num_attention_heads,
    add_cross_attention=True,
)
decoder = AutoModelForCausalLM.from_config(decoder_config)

encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
encoder_decoder_config.tie_encoder_decoder = True

shared_codebert2gpt = EncoderDecoderModel(encoder=encoder, decoder=decoder, config=encoder_decoder_config)
```
The task I am working on is: N/A
To reproduce
Steps to reproduce the behavior:
Running the above code produces the following warning:

```
The following encoder weights were not tied to the decoder ['transformer/pooler', 'transformer/embeddings', 'transformer/encoder']
```
Checking the parameter count confirms this: `shared_codebert2gpt` has 220,741,632 parameters, which is the same number I get if I do not attempt to tie the encoder and decoder parameters at all :(.
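As a sanity check for whether tying actually happened, one way to count parameters is to de-duplicate weight tensors by object identity, so a tensor shared between encoder and decoder is only counted once. Below is a minimal pure-Python sketch of that idea (the `Module` class and names here are illustrative stand-ins, not transformers or PyTorch internals):

```python
class Module:
    """Illustrative stand-in for a module holding named weight "tensors" (plain lists)."""
    def __init__(self, **weights):
        self.weights = weights

def unique_param_count(modules):
    # De-duplicate by object identity: a weight tensor shared between two
    # modules (i.e. tied) contributes to the count only once.
    seen, total = set(), 0
    for m in modules:
        for w in m.weights.values():
            if id(w) not in seen:
                seen.add(id(w))
                total += len(w)
    return total

shared = [0.0] * 100                         # one tensor shared by both modules
enc = Module(wte=shared, proj=[0.0] * 50)    # 100 + 50 params
dec = Module(wte=shared, head=[0.0] * 25)    # wte is tied to the encoder's
print(unique_param_count([enc, dec]))        # 175, not 275
```

If tying had worked, the de-duplicated count for the combined model should be well below the naive sum of encoder and decoder sizes; here it stays at the full 220M.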
Expected behavior
The above snippet should produce a model with roughly 172,503,552 parameters.
My big question is: am I doing this correctly? I can tie the model parameters successfully if I use the `EncoderDecoderModel.from_encoder_decoder_pretrained` constructor and pass `tie_encoder_decoder=True`. However, for my task I don't want to use a pretrained decoder, so I am unable to use that constructor.
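My guess is that the warning happens because tying appears to match modules by name between the two towers, and RoBERTa's submodule names (`embeddings`, `encoder`, `pooler`) have no counterparts in GPT2's structure, so nothing gets shared. A hypothetical sketch of that name-matching idea (this is my illustration of the mechanism, not the library's actual implementation):

```python
def tie_by_name(enc, dec, path=""):
    """enc/dec are nested dicts of weight names.
    Share a weight only when the same attribute path exists on both sides;
    return the encoder paths that could not be tied."""
    untied = []
    for name, sub in enc.items():
        here = f"{path}/{name}" if path else name
        if name not in dec:
            untied.append(here)          # no matching module on the decoder side
        elif isinstance(sub, dict):
            untied += tie_by_name(sub, dec[name], here)
        else:
            dec[name] = sub              # tie: share the encoder's tensor
    return untied

# Top-level module names of the two architectures (simplified).
roberta = {"pooler": {}, "embeddings": {}, "encoder": {}}
gpt2 = {"wte": {}, "h": {}, "ln_f": {}}
print(tie_by_name(roberta, gpt2))  # ['pooler', 'embeddings', 'encoder']
```

The untied list mirrors the warning above, which would explain why the parameter count never drops regardless of `tie_encoder_decoder`.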
Any help with this would be greatly appreciated!