In [None]:
from llama_cpp import Llama
import torch
import numpy as np
import torch.nn as nn

class ResBlock(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.GELU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

    def forward(self, x):
        x = x.to(torch.float32)
        return x + self.mlp(self.norm(x))


class CustomMedusaHead(nn.Module):
    def __init__(self, hidden_size, vocab_size, num_pred_tokens=10, medusa_num_heads=1, medusa_num_layers=2):
        super().__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_pred_tokens = num_pred_tokens  # ✅ fixed
        self.medusa_num_heads = medusa_num_heads
        self.medusa_num_layers = medusa_num_layers

        self.projections = nn.ModuleList([
            nn.Sequential(
                *[ResBlock(hidden_size) for _ in range(medusa_num_layers)],
                nn.Linear(hidden_size, vocab_size, bias=False)
            )
            for _ in range(medusa_num_heads)
        ])

    def forward(self, input_ids: np.ndarray) -> np.ndarray:
        if not isinstance(input_ids, np.ndarray):
            raise ValueError("CustomMedusa expects input_ids as a numpy.ndarray")
        if input_ids.dtype != np.intc:
            raise ValueError(f"CustomMedusa expects dtype np.intc (int32), got {input_ids.dtype}")

        seq_len = input_ids.shape[-1]
        hidden_states = torch.randn((seq_len, self.hidden_size), dtype=torch.float32)

        logits = [proj(hidden_states) for proj in self.projections]
        logits = torch.stack(logits, dim=0)
        logits = logits[:, -1, :]  # last token

        probs = torch.softmax(logits, dim=-1)
        pred_tokens = torch.multinomial(probs, num_samples=self.num_pred_tokens, replacement=True)
        return pred_tokens.flatten().cpu().numpy().astype(np.intc)

    def __call__(self, input_ids: np.ndarray, /, **kwargs) -> np.ndarray:
        return self.forward(input_ids)

In [58]:
model_gguf = 'vicuna-7b-v1.gguf'
medusa_path = 'medusa_lm_head.pt'
hidden_size = 2048
vocab_size = 32000
d_type = torch.float16

In [61]:
# laoding mesusa head 
print(f'Loading Medusa head ')
medusa_head = CustomMedusaHead(
    hidden_size=hidden_size,     # depends on your base model
    vocab_size=vocab_size,       # depends on your base model
    medusa_num_heads=2,          # check config
    medusa_num_layers=1          # check config
)
# Load the pretrained weights
state_dict = torch.load(medusa_path)
medusa_head.load_state_dict(state_dict,strict=False)
print('Medusa head is loaded successfully .....')

print(f"Loading model from ...")
model = Llama(
    model_path=model_gguf,
    n_ctx=hidden_size,
    use_mlock=True,      # lock into RAM to avoid swapping
    use_mmap=True,       # memory map to load faster
    logits_all=False,    # don't return logits unless necessary
    seed=42,
    verbose=False,
    # draft_model=medusa_head 
)
print(f'Base Model is loaded successfully ..........')

print("Both Model loaded successfully! ........")


Loading Medusa head 
Medusa head is loaded successfully .....
Loading model from ...
Base Model is loaded successfully ..........
Both Model loaded successfully! ........


In [18]:
print(model.context_params)

<llama_cpp.llama_cpp.llama_context_params object at 0x0000023568F86A50>


In [64]:
prompt = "what is speculative decoding ?"

In [65]:
output = model(
            prompt,
            max_tokens=50,
            temperature=0.7,
            top_p=0.9,
            echo=False,
            stop=["</s>"],
        )

In [66]:
output

{'id': 'cmpl-f4b79cef-c198-4f53-a008-6923f8c2b86d',
 'object': 'text_completion',
 'created': 1745904244,
 'model': 'vicuna-7b-v1.gguf',
 'choices': [{'text': '\n\nSpeculative decoding is a technique used by the JIT (Just-In-Time) compiler to optimize the performance of a program. It involves analyzing the control flow and data dependencies of a program, and making intelligent predictions about',
   'index': 0,
   'logprobs': None,
   'finish_reason': 'length'}],
 'usage': {'prompt_tokens': 8, 'completion_tokens': 50, 'total_tokens': 58}}