In [1]:
!pip install bitsandbytes transformers accelerate vllm



In [2]:
import pprint
from typing import List, Tuple

import torch
from torch.distributions import Categorical
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, set_seed, BitsAndBytesConfig

# Select the compute device: CUDA when a GPU is visible, CPU otherwise.
# NOTE(review): `device` is not used for model placement below (the models are
# loaded with device_map="auto"); it is kept for any cell that references it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print('CUDA GPU is available')
    print(f"Device name: {torch.cuda.get_device_name(0)}")


# Target (verifier) model and smaller draft model used for speculative decoding.
# The `-bnb-4bit` suffix suggests pre-quantized checkpoints — confirm on the hub.
target_model_name = "unsloth/Llama-3.2-3B-Instruct-bnb-4bit"
auxilary_model_name = "unsloth/Llama-3.2-1B-Instruct-bnb-4bit"

# 4-bit NF4 quantization settings with bfloat16 compute.
# NOTE(review): this config is never passed to from_pretrained below, so it has
# no effect on loading; the checkpoints presumably carry their own quantization
# config (see the "Unused kwargs" warnings in the output). The llm_int8_* options
# are documented for 8-bit loading only.
compute_dtype = torch.bfloat16
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=compute_dtype,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
)

# Both models share the same dtype and automatic device placement.
model_args = dict(torch_dtype=compute_dtype, device_map="auto")
tgt_m = AutoModelForCausalLM.from_pretrained(target_model_name, **model_args)
drf_m = AutoModelForCausalLM.from_pretrained(auxilary_model_name, **model_args)

# A single tokenizer (the target's) serves both models; pad with EOS on the left.
tok = AutoTokenizer.from_pretrained(target_model_name)
tok.pad_token = tok.eos_token
tok.padding_side = "left"

CUDA GPU is available
Device name: NVIDIA GeForce RTX 4090


Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


In [3]:
@torch.no_grad()
def generate(model, tokenizer, input_ids: torch.Tensor, max_new_tokens: int,
             temperature: float = 1.0) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """Sample up to ``max_new_tokens`` tokens from ``model`` (ancestral sampling).

    Each step re-runs the model on the full prefix of every still-active
    sequence (no KV cache), divides the last-position logits by ``temperature``
    and samples the next token from the resulting categorical distribution.
    A sequence becomes inactive once it emits ``tokenizer.eos_token_id``; from
    then on it is extended with ``tokenizer.pad_token_id``.

    NOTE(review): no attention_mask is passed to the model, so left-padded
    prompts in a batch will attend to their pad tokens.

    Args:
        model: causal LM; ``model(input_ids=...)`` must return an object with a
            ``.logits`` tensor, and ``model.device`` is used for new tensors.
        tokenizer: provides ``pad_token_id`` and ``eos_token_id``.
        input_ids: prompt token ids, shape (batch, prompt_len).
        max_new_tokens: maximum number of tokens to append.
        temperature: logits are divided by this value before sampling.

    Returns:
        ``(generated_ids, log_probs)``. ``generated_ids`` is the prompt followed
        by the new tokens. ``log_probs`` holds one (batch, vocab) tensor per step
        with each sequence's log next-token distribution; finished sequences get
        a distribution concentrated on the pad token.
    """
    # Finite stand-in for log(0) so the buffers stay finite in low-precision dtypes.
    log_zero = -1e4
    batch_size = input_ids.shape[0]

    generated_ids = input_ids
    finished_sequences = torch.zeros(batch_size, dtype=torch.bool, device=model.device)
    log_probs = []

    for _ in range(max_new_tokens):
        # Only sequences that have not produced EOS are fed to the model.
        active_indices = torch.nonzero(~finished_sequences).squeeze(-1)
        if len(active_indices) == 0:
            break

        logits = model(input_ids=generated_ids[active_indices]).logits

        # Temperature-scaled log distribution over the next token; sample from it.
        next_token_log_probs = torch.nn.functional.log_softmax(logits[:, -1, :] / temperature, dim=-1)
        next_token_id = Categorical(logits=next_token_log_probs).sample()

        # Record the next-token distribution of every sequence. The width is taken
        # from the logits rather than len(tokenizer), which may differ from it;
        # finished sequences emit the pad token with probability 1.
        vocab_size = next_token_log_probs.shape[-1]
        curr_log_probs = torch.full((batch_size, vocab_size), log_zero,
                                    dtype=next_token_log_probs.dtype, device=model.device)
        curr_log_probs[:, tokenizer.pad_token_id] = 0.0
        curr_log_probs[active_indices] = next_token_log_probs
        log_probs.append(curr_log_probs)

        finished_sequences[active_indices] |= (next_token_id == tokenizer.eos_token_id)

        # Active sequences get their sampled token, finished ones get padding.
        new_tokens = torch.full((batch_size, 1), tokenizer.pad_token_id, dtype=torch.long, device=model.device)
        new_tokens[active_indices] = next_token_id.unsqueeze(-1)
        generated_ids = torch.cat([generated_ids, new_tokens], dim=-1)

    return generated_ids, log_probs

In [4]:
@torch.no_grad()
def compute_ll_rejs(tgt_lprob: torch.Tensor, spec_lprob: torch.Tensor, spec_tok_id: torch.Tensor) -> torch.Tensor:
    """Return the positions of drafted tokens rejected by the accept/reject test.

    A drafted token x at position i is kept when u <= p_i(x) / q_i(x) for
    u ~ Uniform[0, 1), tested in log space. ``p`` comes from ``tgt_lprob``
    (its final row is ignored) and ``q`` from ``spec_lprob``.

    Args:
        tgt_lprob: target log-probabilities, one row per drafted position plus one.
        spec_lprob: draft log-probabilities, one row per drafted position.
        spec_tok_id: drafted token ids; flattened to one id per position.

    Returns:
        ``torch.nonzero`` indices (shape (k, 1)) of the rejected positions, in
        ascending order.
    """
    token_col = spec_tok_id.view(-1, 1)
    target_ll = tgt_lprob[:-1].gather(1, token_col)
    draft_ll = spec_lprob.gather(1, token_col)
    llrs = target_ll - draft_ll

    log_u = torch.rand_like(llrs).log()
    rejected_mask = (llrs <= log_u).squeeze(-1)
    return rejected_mask.nonzero()

@torch.no_grad()
def compute_adjusted_dist(tgt_lprob: torch.Tensor, spec_lprob: torch.Tensor, rej_idx: torch.Tensor) -> torch.Tensor:
    """Residual distribution to resample from after a drafted token is rejected.

    Computes ``max(0, p - q)`` normalised along the last dimension. ``p`` and
    ``q`` are the target and draft next-token distributions at ``rej_idx``,
    given as log-probabilities (hence the ``exp``).

    If a residual row carries no probability mass, 0/0 would produce NaNs. In
    exact arithmetic this happens only when p == q, but floating-point underflow
    can also cause it. Such rows fall back to the target distribution ``p``.

    Args:
        tgt_lprob: target log-probabilities, one row per position.
        spec_lprob: draft log-probabilities, one row per drafted position.
        rej_idx: index (into the first dimension) of the rejected position.

    Returns:
        Probabilities shaped like ``tgt_lprob[rej_idx]``, summing to one along
        the last dimension.
    """
    tgt_prob = torch.exp(tgt_lprob[rej_idx])
    residual = torch.clamp(tgt_prob - torch.exp(spec_lprob[rej_idx]), min=0)

    # Rows with no residual mass fall back to the target distribution.
    mass = residual.sum(dim=-1, keepdim=True)
    residual = torch.where(mass > 0, residual, tgt_prob.to(residual.dtype))
    return residual / residual.sum(dim=-1, keepdim=True)

@torch.no_grad()
def specualative_decode(tgt_m, drf_m, tok, inp: torch.Tensor,
                        max_tok: int, n_spec: int = 5, t: float = 1.0) -> torch.Tensor:
    """Speculative sampling: draft with ``drf_m``, verify with ``tgt_m``.

    Each round:
    - The draft model proposes up to ``n_spec`` tokens (via ``generate``).
    - The target model scores the whole proposal in one forward pass.
    - Drafted tokens are kept up to the first rejection (``compute_ll_rejs``).
    - At the rejected position, a replacement token is sampled from the residual
      distribution (``compute_adjusted_dist``).
    - If every drafted token is accepted and the last is not EOS, a bonus token
      is sampled from the target's distribution after the proposal.

    When only one token of budget remains, the target model samples it directly.

    Only a single sequence is supported: the batch dimension is squeezed away
    and EOS is checked on ``gen[0, -1]``.

    Args:
        tgt_m: target (verifier) causal LM.
        drf_m: draft causal LM; presumably must share the target's vocabulary,
            since both models' log-probs are indexed with the same token ids.
        tok: tokenizer providing ``eos_token_id`` (and ``pad_token_id`` for ``generate``).
        inp: prompt ids, shape (1, prompt_len).
        max_tok: maximum number of new tokens.
        n_spec: maximum number of drafted tokens per round.
        t: sampling temperature applied to both models.

    Returns:
        The prompt followed by at most ``max_tok`` generated tokens, shape (1, L).

    Raises:
        ValueError: if ``inp`` is not a 2-D tensor holding exactly one sequence.
    """
    if inp.dim() != 2 or inp.shape[0] != 1:
        raise ValueError(
            f"specualative_decode supports a single sequence; got input of shape {tuple(inp.shape)}")

    gen = inp
    max_len = inp.shape[1] + max_tok

    while gen.shape[1] < max_len:
        tok_left = max_len - gen.shape[1]
        # Reserve one slot for the corrected/bonus token sampled after verification.
        spec_size = min(n_spec, tok_left - 1)

        if spec_size > 0:
            # Draft. The draft may stop early at EOS, so recount what it produced.
            spec_id, spec_lprob = generate(drf_m, tok, gen, spec_size, t)
            spec_size = spec_id.shape[1] - gen.shape[1]
            spec_tok_id = spec_id[:, -spec_size:]
            spec_lprob = torch.stack(spec_lprob, dim=1).squeeze(0)  # (spec_size, vocab)

            # Verify: one target pass; keep the spec_size + 1 positions that predict
            # each drafted token plus the bonus token.
            outputs = tgt_m(input_ids=spec_id)
            tgt_logit = outputs.logits[:, -(spec_size + 1):, :].squeeze(0) / t
            tgt_lprob = torch.nn.functional.log_softmax(tgt_logit, dim=-1)

            rejs = compute_ll_rejs(tgt_lprob, spec_lprob, spec_tok_id)

            if len(rejs) > 0:
                # Keep the drafted tokens before the first rejection and resample
                # the rejected position from the residual distribution.
                rej_idx = rejs[0]
                accepted = spec_tok_id[:, :rej_idx]
                adj_dist = compute_adjusted_dist(tgt_lprob, spec_lprob, rej_idx)
                next_tok = Categorical(probs=adj_dist).sample()
            else:
                # Everything accepted: sample a bonus token unless the draft ended in EOS.
                accepted = spec_tok_id
                if accepted[0, -1].item() != tok.eos_token_id:
                    next_tok = Categorical(logits=tgt_logit[[-1]]).sample()

            # Append the sampled token unless the accepted prefix already ends in EOS.
            if accepted.numel() == 0 or accepted[0, -1].item() != tok.eos_token_id:
                new_tok = torch.cat([accepted, next_tok.unsqueeze(-1)], dim=-1)
            else:
                new_tok = accepted

            gen = torch.cat([gen, new_tok], dim=-1)

        else:
            # No room to speculate: sample the final token from the target model.
            outputs = tgt_m(input_ids=gen)
            tgt_logit = outputs.logits[:, -1, :].squeeze(0) / t
            tgt_lprob = torch.nn.functional.log_softmax(tgt_logit, dim=-1).unsqueeze(0)
            next_tok = Categorical(logits=tgt_lprob).sample()
            gen = torch.cat([gen, next_tok.unsqueeze(0)], dim=-1)

        if gen[0, -1] == tok.eos_token_id:
            break

    return gen

In [5]:
# Chat prompts to decode; each inner list is one conversation.
messages = [
    [
        {'role': 'system', 'content': 'You are an algebra assistant. The user will ask you math questions and you will solve them.'},
        {'role': 'user', 'content': "Peter purchased 20 popsicles at $0.25 each. He also purchased 2730244 ice cream bars at $0.50 each. How much did he pay in total in dollars?"},
    ],
]
max_tok = 120  # new-token budget for both decoders
# Near-zero temperature makes sampling almost greedy. The two decoders consume
# the RNG differently, so a token-for-token match is presumably only expected
# in this near-greedy regime.
t = 0.001

# add_generation_prompt=True ends the prompt with the assistant header, so the
# model starts its reply directly instead of having to generate the header itself.
inputs = tok.apply_chat_template(messages, add_generation_prompt=True)
for inp, message in zip(inputs, messages):
    inp = torch.tensor(inp, device=tgt_m.device).unsqueeze(0)  # (1, prompt_len)

    set_seed(42)
    speculative_ids = specualative_decode(tgt_m, drf_m, tok, inp, max_tok, t=t)

    set_seed(42)
    sampled_ids, log_probs = generate(tgt_m, tok, inp, max_tok, temperature=t)

    if torch.equal(speculative_ids, sampled_ids):
        print("The outputs match!")
    else:
        print("The outputs do not match.")

    speculative_text = tok.batch_decode(speculative_ids, skip_special_tokens=True)
    sampled_text = tok.batch_decode(sampled_ids, skip_special_tokens=True)

    pprint.pprint({"Prompt": message, "Speculative": speculative_text, "Sampled": sampled_text})


The outputs match!
{'Prompt': [{'content': 'You are an algebra assistant. The user will ask you '
                        'math questions and you will solve them.',
             'role': 'system'},
            {'content': 'Peter purchased 20 popsicles at $0.25 each. He also '
                        'purchased 2730244 ice cream bars at $0.50 each. How '
                        'much did he pay in total in dollars?',
             'role': 'user'}],
 'Sampled': ['system\n'
             '\n'
             'Cutting Knowledge Date: December 2023\n'
             'Today Date: 21 Nov 2024\n'
             '\n'
             'You are an algebra assistant. The user will ask you math '
             'questions and you will solve them.user\n'
             '\n'
             'Peter purchased 20 popsicles at $0.25 each. He also purchased '
             '2730244 ice cream bars at $0.50 each. How much did he pay in '
             'total in dollars?assistant\n'
             '\n'
             'To find the tota