
feat: medusa v2 #1734

Merged: 3 commits from feat/medusa_v2 into main on Apr 12, 2024
Conversation

OlivierDehaene (Member)

No description provided.

@Narsil (Collaborator) left a comment

Nice.

A few nits.

     super().__init__()
     self.blocks = torch.nn.ModuleList(
         [
             ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
-            for i in range(config["medusa_num_layers"])
+            for i in range(medusa_config["medusa_num_layers"])
@Narsil (Collaborator):

This should probably take that information from `speculate` instead, to avoid loading layers we're not going to use.

We can do it in another PR.
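
For illustration, a minimal sketch of what that could look like, assuming the number of speculative tokens is exposed as a `speculate` entry on the config (a hypothetical name here; `ResBlock` is the block from the diff above):

```python
import torch

# Sketch only, not the PR's code. Treating "speculate" as a key on
# `config` is an assumption for this example; ResBlock is the module
# from the snippet above.
class MedusaModel(torch.nn.Module):
    def __init__(self, config, medusa_config, prefix, weights):
        super().__init__()
        requested = medusa_config["medusa_num_layers"]
        # Never build more blocks than speculation will actually use.
        num_layers = min(requested, config.get("speculate", requested))
        self.blocks = torch.nn.ModuleList(
            [
                ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
                for i in range(num_layers)
            ]
        )
```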


     self.act = torch.nn.SiLU()

     self.lm_head = TensorParallelHead.load(config, prefix, weights)
@Narsil (Collaborator):

I thought what I had done was cleverer: passing the lm_head directly, which avoids a reload.

Since we're using sharding here, `Weights` will actually load the LM head n times with the current code (IIRC).
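
As a sketch of the alternative Narsil describes, the constructor could accept an already-loaded `lm_head` instead of reloading it (the extra parameter is hypothetical, not the PR's actual signature; `TensorParallelHead` is the loader from the snippet above):

```python
import torch

# Sketch only: an optional, hypothetical lm_head parameter lets the
# caller pass in an already-loaded head, so the sharded LM head weights
# are not read from disk n times.
class MedusaHeadV2(torch.nn.Module):
    def __init__(self, config, prefix, weights, lm_head=None):
        super().__init__()
        self.act = torch.nn.SiLU()
        if lm_head is None:
            # Fall back to loading it ourselves, as the PR currently does.
            lm_head = TensorParallelHead.load(config, prefix, weights)
        self.lm_head = lm_head
```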

-        medusa = MedusaModel(config, weights)
+        lm_head = None
+        try:
+            medusa = MedusaHeadV2(config, prefix, weights)
@Narsil (Collaborator):

Shouldn't it be the reverse? MedusaV2 will always load properly (since it's only a subset of V1)?
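
For context, the fallback pattern being questioned looks roughly like this (the `except` branch is inferred from the comment, not quoted from the diff):

```python
# Sketch of the fallback pattern under discussion.
lm_head = None
try:
    # Try the V2 checkpoint layout first...
    medusa = MedusaHeadV2(config, prefix, weights)
except Exception:
    # ...and fall back to V1. Narsil's point: because V2's weights are
    # a subset of V1's, V2 can appear to load successfully even for a
    # V1 checkpoint, so the attempt order may need to be flipped.
    medusa = MedusaHeadV1(config, prefix, weights)
```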

@@ -467,46 +467,159 @@ def forward(self, x):
         return x


-class SpeculativeHead(nn.Module):
+class MedusaHeadV1(nn.Module):
@Narsil (Collaborator):

Maybe it's time we put them into their own file (both in the same file is OK, I think)?

Narsil previously approved these changes on Apr 12, 2024
@Narsil (Collaborator) left a comment

LGTM.

Let's move the nits to a follow-up chore.

@OlivierDehaene merged commit eefea5e into main on Apr 12, 2024
1 of 3 checks passed
@OlivierDehaene deleted the feat/medusa_v2 branch on April 12, 2024 at 14:24
kdamaszk pushed a commit to kdamaszk/tgi-gaudi that referenced this pull request Apr 29, 2024
Nilabhra pushed a commit to TII-AI-Research-Center/text-generation-inference that referenced this pull request May 14, 2024
@fxmarty (Collaborator) commented May 19, 2024

It would be helpful to document it.
