In [1]:
import torch
from torch.nn import functional as F

# copy from https://github.com/LeeSinLiang/microGPT/blob/ed40cf9780dbeb180adfe94c227d4aa97e69250e/gpt.py
def top_k_top_p_filter(logits: torch.Tensor, top_k: int = 0, top_p: float = 0.0):
    """

    Args:
        logits (torch.Tensorpe_): 2D tensor with shape (batch, vocab)
        top_k (int, optional): top_k. Defaults to 0.
        top_p (float, optional): top_p. Defaults to 0.0.

    Returns:
        torch.Tensor: a renormalized logits
    """
    if top_k > 0:
        filter = torch.topk(logits, min(top_k, logits.size(-1)))[0]
        logits[logits < filter[:, [-1]]] = float('-inf')
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1)
        filter = cumulative_probs > top_p
        filter[..., 1:] = filter[..., :-1].clone()
        filter[..., 0] = 0
        indices_to_remove = filter.scatter(1, sorted_indices, filter)
        logits[indices_to_remove] = float('-inf')
    return logits


def norm_logits(logits : torch.Tensor, temperature : float, top_k : float, top_p : float) -> torch.Tensor:
    """

    Args:
        logits (torch.Tensor): shape (1, vocab)
        temperature (float): temperature
        top_k (float): top_k
        top_p (float): top_p

    Returns:
        torch.Tensor: next token with shape as (batch,  1)
    """
    assert logits.dim() == 2
    logits = logits / temperature
    logits = top_k_top_p_filter(logits, top_k=top_k, top_p=top_p)
    probs = F.softmax(logits, dim=1)
    return probs


def sample(probs : torch.Tensor, num_samples: int = 1):
    idx_next = torch.multinomial(probs, num_samples=num_samples)
    if (idx_next.item() == 0):
        raise RuntimeError
    return idx_next


def max_fn(x):
    """
        norm(max (x, 0))
    """
    x_max = torch.where(x > 0, x, torch.zeros_like(x))
    x_max_sum = torch.sum(x_max, dim=1, keepdim=True) 
    return x_max / x_max_sum

In [18]:
from model import Mamba, ModelArgs
from transformers import AutoTokenizer

# One of:
#     'state-spaces/mamba-2.8b-slimpj'
#     'state-spaces/mamba-2.8b'
#     'state-spaces/mamba-1.4b'
#     'state-spaces/mamba-790m'
#     'state-spaces/mamba-370m'
#     'state-spaces/mamba-130m'
model = Mamba.from_pretrained('state-spaces/mamba-370m')
assistant_model = Mamba.from_pretrained('state-spaces/mamba-130m')
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [17]:
def autoregressive_sampling(x : torch.Tensor, model : torch.nn.Module, N : int, 
                            temperature : float = 1, top_k : int = 0, top_p : float = 0):
    n = len(x)
    T = len(x) + N

    while n < T:
        # outputs = model(x)
        outputs = model(x)
        last_p = norm_logits(outputs[::, -1, :], temperature, top_k, top_p)
        idx_next = sample(last_p)
        x = torch.cat((x, idx_next), dim=1)
        n += 1
    return x

In [19]:
prompt = 'The meaning of life is '
input_ids = tokenizer(prompt, return_tensors='pt').input_ids

In [41]:
top_k = 20
top_p = 0.9
num_tokens = 20

torch.manual_seed(123)
autoregressive_output = autoregressive_sampling(input_ids, model, num_tokens, top_k = top_k, top_p=top_p)

In [42]:
output_completions = [tokenizer.decode(output.tolist()) for output in autoregressive_output][0]
output_completions

'The meaning of life is \nthe ultimate goal of your life \nand the end of your life. \nAnd so'

In [34]:
import tqdm
@torch.no_grad()
def speculative_sampling_v2(prefix : torch.Tensor, approx_model : torch.nn.Module, target_model : torch.nn.Module, 
                         max_len : int , gamma : int = 4,
                         temperature : float = 1, top_k : int = 0, top_p : float = 0, random_seed : int = None) -> torch.Tensor:
    """
    DeepMind version Speculative Sampling.
    Accelerating Large Language Model Decoding with Speculative Sampling
    https://arxiv.org/abs/2302.01318
    No KV Cache Optimization
    
    Args:
        x (torch.Tensor): input sequence, (batch, prefix_seqlen), Note that the batch dim is always 1 now.
        approx_model (torch.nn.Module): approx model, the small one
        target_model (torch.nn.Module): target model, the large one
        max_len (int): the max overall generated tokens number.
        gamma (int): $\gamma$, the token number small model guesses.
        temperature (float, optional): Defaults to 1.
        top_k (int, optional): Defaults to 0.
        top_p (float, optional): Defaults to 0.

    Returns:
        torch.Tensor: generated tokens (batch, target_seqlen)
    """
    seq_len = prefix.shape[1]
    T = seq_len + max_len
    
    assert prefix.shape[0] == 1, "input batch size must be 1"

    while prefix.shape[1] < T:
        # q = M_q[prefix + x_0, x_1, .., x_(gamma-2)]
        x = prefix
        prefix_len = prefix.shape[1]
        for _ in range(gamma):
            # p.logits shape (batch, seq, vocab)
            q = approx_model(x)
            next_tok = sample(norm_logits(q[:, -1, :], 
                              temperature, top_k, top_p))
            x = torch.cat((x, next_tok), dim=1)

        # normalize the logits
        for i in range(q.shape[1]):
            q[:,i,:] = norm_logits(q[:,i,:],
                            temperature, top_k, top_p)
        # p  = M_p[prefix + x_0, x_0, .., x_(gamma-1)]
        p = target_model(x)
        for i in range(p.shape[1]):
            p[:,i,:] = norm_logits(p[:,i,:],
                            temperature, top_k, top_p)

        # n the end position of the valid prefix
        # x = x_[:prefix_len-1] + x_0, ... x_(gamma-1)

        is_all_accept = True
        n = prefix_len - 1
        for i in range(gamma):
            if random_seed:
                torch.manual_seed(random_seed)
            r = torch.rand(1, device = p.device)
            j = x[:, prefix_len + i]

            if r < torch.min(torch.tensor([1], device=q.device), p[:, prefix_len + i - 1, j] / q[:, prefix_len + i - 1, j]):
                # accept, and update n
                n += 1
            else:
                # reject
                t = sample(max_fn(p[:, n, :] - q[:, n, :]))
                is_all_accept = False
                break

        prefix = x[:, :n + 1]

        if is_all_accept:
            t = sample(p[:, -1, :])

        prefix = torch.cat((prefix, t), dim=1)

    return prefix

In [39]:
speculative_output = speculative_sampling_v2(input_ids, assistant_model, model, num_tokens)
output_completions = [tokenizer.decode(output.tolist()) for output in speculative_output][0]
output_completions

'The meaning of life is 不會 Any time the word “ someday “ is substituted by the word “ is … No need'

In [40]:
output

tensor([[  510,  4495,   273,  1495,   310,   209,   187, 12563,   281,  4264,
          1495,    15,   187,   187,   395,   187,   187,  5288,  1539,  2637,
          3481, 49367,   187, 50276,  4527,  8057,   272]])

In [38]:
input_ids

tensor([[ 510, 4495,  273, 1495,  310,  209]])