feat: medusa v2 #1734
Conversation
Nice. A few nits.
```diff
     super().__init__()
     self.blocks = torch.nn.ModuleList(
         [
             ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
-            for i in range(config["medusa_num_layers"])
+            for i in range(medusa_config["medusa_num_layers"])
```
This should probably take that information from `speculate` instead (to avoid loading layers which we're not going to use). We can do that in another PR.
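A minimal sketch of that idea, assuming a `speculate` value (the configured number of speculative tokens) is in scope at construction time; the surrounding names mirror the quoted diff:

```python
# Hedged sketch: cap construction by `speculate` (assumed available here)
# so blocks that speculation will never use are not loaded at all.
n_blocks = min(medusa_config["medusa_num_layers"], speculate)
self.blocks = torch.nn.ModuleList(
    [
        ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
        for i in range(n_blocks)
    ]
)
```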
```python
        self.act = torch.nn.SiLU()

        self.lm_head = TensorParallelHead.load(config, prefix, weights)
```
I thought what I had done was cleverer by passing the `lm_head` directly (which avoids a reload). Since we're using sharding here, `Weights` will actually load the LM head n times with the current code (IIRC).
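A hedged sketch of the alternative being described: load the shared head once and pass it in, so the sharded `Weights` reader doesn't re-read the same tensor per head. The `lm_head` keyword argument is illustrative, not the actual signature:

```python
# Load the shared LM head exactly once...
lm_head = TensorParallelHead.load(config, prefix, weights)
# ...then hand it to the Medusa head instead of letting it reload the
# tensor itself (n times under sharding).
medusa = MedusaHeadV2(config, prefix, weights, lm_head=lm_head)
```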
```diff
-        medusa = MedusaModel(config, weights)
+        lm_head = None
+        try:
+            medusa = MedusaHeadV2(config, prefix, weights)
```
Shouldn't it be the reverse? MedusaV2 will always load properly, since it only needs a subset of the V1 tensors.
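A hedged sketch of the suggested reversal, following the reasoning above: V1 is the load that can actually fail on a V2 checkpoint, so it has to be the one inside the `try`:

```python
try:
    # V1 needs tensors a V2 checkpoint doesn't have, so this attempt can
    # fail and thereby discriminate between the two formats.
    medusa = MedusaHeadV1(config, prefix, weights)
except Exception:
    # V2 only reads a subset of the V1 tensors, so it loads either way.
    medusa = MedusaHeadV2(config, prefix, weights)
```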
```diff
@@ -467,46 +467,159 @@ def forward(self, x):
         return x


-class SpeculativeHead(nn.Module):
+class MedusaHeadV1(nn.Module):
```
Maybe it's time we put them into their own file (having both in the same file is OK, I think)?
LGTM.
Let's move the nits to a follow-up chore.
It would be helpful to document it.