
Starcoder2 model #29120

Closed · wants to merge 7 commits

Conversation

@jlamypoirier (Contributor) commented on Feb 19, 2024

The Starcoder2 model, adapted from Mistral. All changes are made through config options, so Mistral itself is still supported. Main changes:

  • Use layer norm (RMS norm still available as an option)
  • Use a standard MLP (the gated MLP still available as an option)
  • Add back biases (optional)
  • Change the (default?) tokenizer class
  • Add embedding and residual dropout

It does not support absolute position embeddings, so it can't support Santacoder or Starcoder.
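
A rough sketch of how these options might look on the config (illustrative only: norm_type, mlp_type and norm_epsilon appear in the diff quoted further down, while the bias/dropout attribute names here are assumptions, not the PR's actual API):

from dataclasses import dataclass

import torch.nn as nn

@dataclass
class ToyStarcoder2Config:
    hidden_size: int = 3072
    norm_type: str = "layer_norm"      # "rms_norm" keeps the Mistral-style norm
    mlp_type: str = "default"          # "gated" keeps the Mistral-style gated MLP
    use_bias: bool = True              # assumed name: biases added back relative to Mistral
    embedding_dropout: float = 0.0     # assumed name
    residual_dropout: float = 0.0      # assumed name
    norm_epsilon: float = 1e-5

# Example of how such an option could select a building block:
NORMALIZATION_CLASSES = {"layer_norm": nn.LayerNorm}  # "rms_norm" would map to the PR's RMSNorm

def build_input_layernorm(config: ToyStarcoder2Config) -> nn.Module:
    return NORMALIZATION_CLASSES[config.norm_type](config.hidden_size, eps=config.norm_epsilon)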

Todo:

@younesbelkada

@ArthurZucker (Collaborator) left a comment:

🔥 looks very good!

return self.weight * hidden_states.to(input_dtype)


# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Starcoder2
Collaborator:

fix-copies will not let this pass; it should be copied from Mistral, as we changed Llama for the compiled static cache.
I would also rather we support the static cache, as the API got quite a lot cleaner.
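
For reference, the requested change is essentially to retarget the annotation, roughly as follows (a sketch only, with the body elided; the source class would be Mistral's rotary embedding):

import torch.nn as nn

# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Starcoder2
class Starcoder2RotaryEmbedding(nn.Module):
    ...

make fix-copies then keeps the body in sync with the referenced Mistral class instead of the Llama one, which has diverged for the compiled static cache.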

return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
Collaborator:

Same here, Llama is different; make fix-copies will help you fix this!
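
For context, the rotary helpers covered by this annotation look roughly like the following (a sketch of the common transformers pattern, not the exact Llama or Mistral signatures):

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Rotates half the hidden dims of the input.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # Applies the precomputed cos/sin rotary tables to the query and key states.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed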

return hidden_states


class Starcoder2GatedMLP(nn.Module):
Collaborator:

Probably missing a "Copied from" mention here (from Mistral).

Contributor (author):

It has small changes (bias + dropout I think)
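
To make those "small changes" concrete, a gated MLP with optional bias and dropout might look roughly like this (a sketch, not the PR's code; parameter names are assumptions):

import torch
import torch.nn as nn
from torch.nn import functional as F

class GatedMLPSketch(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int, use_bias: bool = True, dropout: float = 0.0):
        super().__init__()
        # Mistral-style SwiGLU projections, with the bias made optional
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=use_bias)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=use_bias)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=use_bias)
        self.dropout = dropout

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.down_proj(F.silu(self.gate_proj(hidden_states)) * self.up_proj(hidden_states))
        # Residual dropout added on top of the Mistral version
        return F.dropout(hidden_states, p=self.dropout, training=self.training)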

Contributor:

Should we remove the "Copied from" mention from all the classes/methods where we added dropout?

# Copied from transformers.models.mistral.modeling_mistral.MistralAttention.forward

# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Starcoder2

# Copied from transformers.models.mistral.modeling_mistral.MistralModel.forward with MISTRAL->STARCODER2,Mistral->Starcoder2

Collaborator:

yes otherwise the check-copies will fail 😉

Comment on lines +259 to +261
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

Collaborator:

Suggested change (remove these lines):
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

this is not used in Mistral anyways

return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class Starcoder2Attention(nn.Module):
Collaborator:

It would make sense to follow the Llama implementation IMO for the static cache (with the additional cache positions), but this can go in another PR, no worries 🤗
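
(For context, the repeat_kv helper whose last line is quoted above is the standard grouped-query-attention expansion; a sketch of the usual transformers implementation:)

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Repeat each key/value head n_rep times so the key/value tensors match the number of query heads.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)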

Comment on lines +754 to +763
        self.self_attn = STARCODER2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

        self.mlp = STARCODER2_MLP_CLASSES[config.mlp_type](config)

        self.input_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type](
            config.hidden_size, eps=config.norm_epsilon
        )
        self.post_attention_layernorm = STARCODER2_NORMALIZATION_CLASSES[config.norm_type](
            config.hidden_size, eps=config.norm_epsilon
        )
Collaborator:

This is not what we usually do in transformers. The attention is a specific case 😅

  • Are all of these used in the default Starcoder2?
  • If not, then let's not support Mistral here; Mistral is a different architecture.
    The reason the attention dispatch is allowed is that it uses the same parameters -> the same "Attention" with a different forward, whereas here it's really a different architecture, which goes against the transformers philosophy.
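
(For contrast, a minimal sketch of the direction being suggested: the layer hard-codes Starcoder2's own blocks, LayerNorm and a non-gated MLP with bias, instead of dispatching between Mistral-style and Starcoder2-style blocks at runtime. Illustrative only, not the merged implementation.)

import torch.nn as nn
from torch.nn import functional as F

class Starcoder2MLPSketch(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.c_fc = nn.Linear(hidden_size, intermediate_size, bias=True)
        self.c_proj = nn.Linear(intermediate_size, hidden_size, bias=True)

    def forward(self, hidden_states):
        return self.c_proj(F.gelu(self.c_fc(hidden_states)))

class Starcoder2DecoderLayerSketch(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int, norm_epsilon: float = 1e-5):
        super().__init__()
        # self.self_attn would still be selected from an attention-implementation mapping,
        # since that is the one dispatch pattern transformers does use.
        self.mlp = Starcoder2MLPSketch(hidden_size, intermediate_size)
        self.input_layernorm = nn.LayerNorm(hidden_size, eps=norm_epsilon)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=norm_epsilon)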

Comment on lines +1042 to +1062
        if self._attn_implementation == "flash_attention_2":
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self._attn_implementation == "sdpa" and not output_attentions:
            # output_attentions=True can not be supported when using SDPA, and we fall back on
            # the manual implementation that requires a 4D causal mask in all cases.
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
                sliding_window=self.config.sliding_window,
            )
Collaborator:

See the new Llama code for this, which was simplified. I'd rather we take it directly for the attention 😉
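
(Ignoring the sliding window and past-key-values offset handled in the snippet above, the 4D causal mask these helpers produce boils down to something like the following; a rough illustration, not the transformers helper itself:)

import torch

def build_causal_mask_sketch(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # attention_mask: (batch, seq_len) with 1 for real tokens and 0 for padding.
    # Returns an additive (batch, 1, seq_len, seq_len) mask: 0 where attention is
    # allowed, dtype-min where it is not.
    batch_size, seq_length = attention_mask.shape
    device = attention_mask.device
    causal = torch.tril(torch.ones(seq_length, seq_length, dtype=torch.bool, device=device))
    allowed = causal[None, None, :, :] & attention_mask[:, None, None, :].bool()
    zero = torch.zeros((), dtype=dtype, device=device)
    neg_inf = torch.full((), torch.finfo(dtype).min, dtype=dtype, device=device)
    return torch.where(allowed, zero, neg_inf)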

Comment on lines +289 to +291
@unittest.skip("Starcoder2 buffers include complex numbers, which breaks this test")
def test_save_load_fast_init_from_base(self):
pass
Collaborator:

I might have missed it, but I haven't seen where these complex-number buffers are?
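
(For what it's worth, complex-valued rotary buffers usually come from reference-style implementations that precompute frequencies with torch.polar; a sketch of where such a buffer would originate, assuming that is the case here:)

import torch

def precompute_freqs_cis(dim: int, max_positions: int, theta: float = 10000.0) -> torch.Tensor:
    # Precompute rotary frequencies as a complex64 tensor, as in the original
    # Llama/Mistral reference code; registering this as a buffer is one way
    # complex numbers could end up in the state dict.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    t = torch.arange(max_positions, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)  # shape (max_positions, dim // 2)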

@RaymondLi0 (Contributor) commented on Feb 22, 2024

I re-created a PR here since Joel is on vacation: #29215
Sorry for the inconvenience


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ArthurZucker (Collaborator):

Closing, as #29215 was merged and Starcoder2 is officially supported.
