add embed and residual dropout (#30)
RaymondLi0 committed Feb 19, 2024
1 parent e0ec999 commit 4983a75
Showing 2 changed files with 16 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/transformers/models/starcoder2/configuration_starcoder2.py
@@ -125,6 +125,8 @@ def __init__(
         rope_theta=10000.0,
         sliding_window=None,
         attention_dropout=0.0,
+        residual_dropout=0.0,
+        embedding_dropout=0.0,
         # TODO: Implement
         use_bias: bool = True,
         # TODO: Other dropouts?
@@ -153,6 +155,8 @@ def __init__(
         self.use_cache = use_cache
         self.rope_theta = rope_theta
         self.attention_dropout = attention_dropout
+        self.residual_dropout = residual_dropout
+        self.embedding_dropout = embedding_dropout

         super().__init__(
             bos_token_id=bos_token_id,
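
The two new arguments default to 0.0, so existing configurations behave exactly as before unless they opt in. A minimal usage sketch (illustrative values; assumes a transformers build that includes this change):

from transformers.models.starcoder2.configuration_starcoder2 import Starcoder2Config

# Non-zero values turn the new dropouts on; leaving them at the 0.0 default is a no-op.
config = Starcoder2Config(
    residual_dropout=0.1,
    embedding_dropout=0.1,
)
print(config.residual_dropout, config.embedding_dropout)
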
13 changes: 12 additions & 1 deletion src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -169,11 +169,13 @@ def __init__(self, config: Starcoder2Config):
         self.c_fc = nn.Linear(embed_dim, config.intermediate_size, bias=config.use_bias)
         self.c_proj = nn.Linear(config.intermediate_size, embed_dim, bias=config.use_bias)
         self.act = ACT2FN[config.hidden_act]
+        self.residual_dropout = config.residual_dropout

     def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
         hidden_states = self.c_fc(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states = self.c_proj(hidden_states)
+        hidden_states = nn.functional.dropout(hidden_states, p=self.residual_dropout, training=self.training)
         return hidden_states


@@ -188,9 +190,12 @@ def __init__(self, config: Starcoder2Config):
         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.use_bias)
         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
         self.act_fn = ACT2FN[config.hidden_act]
+        self.residual_dropout = config.residual_dropout

     def forward(self, x):
-        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        hidden_states = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        hidden_states = nn.functional.dropout(hidden_states, p=self.residual_dropout, training=self.training)
+        return hidden_states


 # Copied from transformers.models.llama.modeling_llama.repeat_kv
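
For orientation, the MLP and attention edits follow the usual residual-dropout pattern: the sub-layer output is dropped out before the decoder layer adds it back onto the residual stream. A self-contained toy sketch of that pattern (the class and names below are illustrative, not part of the diff):

import torch
from torch import nn

class ToyResidualBlock(nn.Module):
    def __init__(self, hidden_size: int, residual_dropout: float = 0.0):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.residual_dropout = residual_dropout

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.proj(hidden_states)
        # Same call pattern as in the diff: a no-op when p=0.0 or when the module is in eval mode.
        hidden_states = nn.functional.dropout(hidden_states, p=self.residual_dropout, training=self.training)
        return residual + hidden_states

block = ToyResidualBlock(hidden_size=16, residual_dropout=0.1)
out = block(torch.randn(2, 16))  # dropout is active because nn.Module defaults to train mode
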
@@ -233,6 +238,7 @@ def __init__(self, config: Starcoder2Config, layer_idx: Optional[int] = None):
         self.use_bias = config.use_bias
         self.is_causal = True
         self.attention_dropout = config.attention_dropout
+        self.residual_dropout = config.residual_dropout

         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -329,6 +335,7 @@ def forward(
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

         attn_output = self.o_proj(attn_output)
+        attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training)

         if not output_attentions:
             attn_weights = None
@@ -481,6 +488,7 @@ def forward(

         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
         attn_output = self.o_proj(attn_output)
+        attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training)

         if not output_attentions:
             attn_weights = None
@@ -714,6 +722,7 @@ def forward(
         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

         attn_output = self.o_proj(attn_output)
+        attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training)

         return attn_output, None, past_key_value

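Each of the three attention forward paths above applies the same dropout to the o_proj output, and every new call passes training=self.training, so it only has an effect in train mode. A quick illustration of that flag (toy tensors, not from the diff):

import torch
from torch import nn

x = torch.randn(2, 8)
y_train = nn.functional.dropout(x, p=0.5, training=True)   # random zeros, survivors scaled by 1/(1-p)
y_eval = nn.functional.dropout(x, p=0.5, training=False)   # identity, which is what model.eval() gives
assert torch.equal(y_eval, x)
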
@@ -946,6 +955,7 @@ def __init__(self, config: Starcoder2Config):
         self.vocab_size = config.vocab_size

         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.embedding_dropout = config.embedding_dropout
         self.layers = nn.ModuleList(
             [Starcoder2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
@@ -1052,6 +1062,7 @@ def forward(
             )

         hidden_states = inputs_embeds
+        hidden_states = nn.functional.dropout(hidden_states, p=self.embedding_dropout, training=self.training)

         # decoder layers
         all_hidden_states = () if output_hidden_states else None
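
Embedding dropout is applied once, to inputs_embeds right after the token embedding and before the first decoder layer. A toy sketch of that placement (illustrative sizes and names, not from the diff):

import torch
from torch import nn

vocab_size, hidden_size, embedding_dropout = 100, 16, 0.1  # toy values
embed_tokens = nn.Embedding(vocab_size, hidden_size)

input_ids = torch.randint(0, vocab_size, (1, 8))
hidden_states = embed_tokens(input_ids)
hidden_states = nn.functional.dropout(hidden_states, p=embedding_dropout, training=True)
# hidden_states then flows through the stack of decoder layers exactly as before.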
