diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py
index 81ba072a2c72..ebb5f6a5e874 100644
--- a/src/transformers/models/t5gemma/modeling_t5gemma.py
+++ b/src/transformers/models/t5gemma/modeling_t5gemma.py
@@ -448,11 +448,28 @@ def forward(
         return hidden_states
 
 
-class T5GemmaDecoderLayer(T5GemmaEncoderLayer):
+class T5GemmaDecoderLayer(GradientCheckpointingLayer):
     """Decoder sub-layer: an extra cross-attention layer."""
 
     def __init__(self, config, layer_idx: int):
-        super().__init__(config, layer_idx)
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.config = config
+        self.layer_idx = layer_idx
+        self.attention_type = config.layer_types[layer_idx]
+
+        self.self_attn = T5GemmaSelfAttention(
+            config=config,
+            layer_idx=layer_idx,
+        )
+        self.pre_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.mlp = T5GemmaMLP(config)
+        self.pre_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.dropout = nn.Dropout(config.dropout_rate)
         self.cross_attn = T5GemmaCrossAttention(config=config, layer_idx=layer_idx)
         self.pre_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -732,7 +749,7 @@ def forward(
         )
 
 
-class T5GemmaDecoder(T5GemmaEncoder):
+class T5GemmaDecoder(T5GemmaPreTrainedModel):
     _can_record_outputs = {
         "attentions": OutputRecorder(T5GemmaSelfAttention, index=1),
         "cross_attentions": OutputRecorder(T5GemmaCrossAttention, index=1),
@@ -741,11 +758,20 @@ class T5GemmaDecoder(T5GemmaEncoder):
 
     def __init__(self, config):
         super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.norm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.gradient_checkpointing = False
+
         self.layers = nn.ModuleList(
             [T5GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
+        self.dropout = nn.Dropout(config.dropout_rate)
         self.rotary_emb = T5GemmaRotaryEmbedding(config=config)
 
+        # Initialize weights and apply final processing
         self.post_init()
 
     @check_model_inputs()
diff --git a/src/transformers/models/t5gemma/modular_t5gemma.py b/src/transformers/models/t5gemma/modular_t5gemma.py
index 86ecf53ae6e4..a4c1d5083e74 100644
--- a/src/transformers/models/t5gemma/modular_t5gemma.py
+++ b/src/transformers/models/t5gemma/modular_t5gemma.py
@@ -517,11 +517,28 @@ def forward(
         return hidden_states
 
 
-class T5GemmaDecoderLayer(T5GemmaEncoderLayer):
+class T5GemmaDecoderLayer(GradientCheckpointingLayer):
     """Decoder sub-layer: an extra cross-attention layer."""
 
     def __init__(self, config, layer_idx: int):
-        super().__init__(config, layer_idx)
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.config = config
+        self.layer_idx = layer_idx
+        self.attention_type = config.layer_types[layer_idx]
+
+        self.self_attn = T5GemmaSelfAttention(
+            config=config,
+            layer_idx=layer_idx,
+        )
+        self.pre_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_self_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.mlp = T5GemmaMLP(config)
+        self.pre_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_feedforward_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.dropout = nn.Dropout(config.dropout_rate)
         self.cross_attn = T5GemmaCrossAttention(config=config, layer_idx=layer_idx)
         self.pre_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -770,7 +787,7 @@ def forward(
         )
 
 
-class T5GemmaDecoder(T5GemmaEncoder):
+class T5GemmaDecoder(T5GemmaPreTrainedModel):
     _can_record_outputs = {
         "attentions": OutputRecorder(T5GemmaSelfAttention, index=1),
         "cross_attentions": OutputRecorder(T5GemmaCrossAttention, index=1),
@@ -779,11 +796,20 @@ class T5GemmaDecoder(T5GemmaEncoder):
 
     def __init__(self, config):
         super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.norm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.gradient_checkpointing = False
+
         self.layers = nn.ModuleList(
             [T5GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
+        self.dropout = nn.Dropout(config.dropout_rate)
         self.rotary_emb = T5GemmaRotaryEmbedding(config=config)
 
+        # Initialize weights and apply final processing
         self.post_init()
 
     @check_model_inputs()
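Not part of the diff: since the point of the change is to drop the `T5GemmaEncoderLayer` / `T5GemmaEncoder` inheritance while keeping the module layout (and therefore checkpoint keys) identical, a quick local sanity check is to instantiate the de-inherited decoder layer standalone and confirm its submodule names. The sketch below assumes `T5GemmaModuleConfig` is the per-stack config consumed by these classes and that it accepts Gemma2-style size kwargs with sensible defaults; the tiny sizes are arbitrary and illustrative only.

```python
# Sanity-check sketch (not part of the PR): instantiate the refactored
# decoder layer on its own and verify the expected submodules exist.
# T5GemmaModuleConfig and its kwargs are assumptions about the config API.
from transformers.models.t5gemma.configuration_t5gemma import T5GemmaModuleConfig
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaDecoderLayer

config = T5GemmaModuleConfig(
    vocab_size=32,            # tiny, illustrative values only
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    head_dim=16,
)
layer = T5GemmaDecoderLayer(config, layer_idx=0)

# The refactor copies the former parent-class modules verbatim, so these
# attribute names should match the old inheritance-based implementation
# and existing state_dict keys should keep mapping cleanly.
expected = {
    "self_attn", "cross_attn", "mlp", "dropout",
    "pre_self_attn_layernorm", "post_self_attn_layernorm",
    "pre_cross_attn_layernorm", "post_cross_attn_layernorm",
    "pre_feedforward_layernorm", "post_feedforward_layernorm",
}
assert expected.issubset({name for name, _ in layer.named_children()})
```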