[Inference/SpecDec] Add Basic Drafter Model Container #5405

Merged
4 changes: 4 additions & 0 deletions colossalai/inference/spec/__init__.py
@@ -0,0 +1,4 @@
from .drafter import Drafter
from .struct import DrafterOutput

__all__ = ["Drafter", "DrafterOutput"]
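
With this `__init__.py`, both public symbols can be imported from the package root rather than from their defining modules:

from colossalai.inference.spec import Drafter, DrafterOutput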
142 changes: 142 additions & 0 deletions colossalai/inference/spec/drafter.py
@@ -0,0 +1,142 @@
from typing import Optional, Tuple

import torch
import torch.nn as nn
from transformers import PreTrainedTokenizer

from colossalai.utils import get_current_device

from .struct import DrafterOutput


class Drafter:
"""Container for the Drafter Model (Assistant Model) used in Speculative Decoding.

Args:
model (nn.Module): The drafter model.
tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the drafter model.
max_spec_num (int): The maximum number of tokens to speculate.
device (torch.device): The device for the drafter model.
"""

def __init__(
self, model: nn.Module, tokenizer: PreTrainedTokenizer, max_spec_num: int, device: torch.device = None
):
self._drafter_model = model
self._tokenizer = tokenizer
self.max_spec_num = max_spec_num
self.do_sample = False
self.sample_fn = None
self._device = device or get_current_device()
self._past_key_values = None

@property
def past_key_values(self) -> Optional[Tuple[Tuple[torch.FloatTensor]]]:
return self._past_key_values

# Debug usage for now
@property
def past_key_values_shape(self):
if self._past_key_values is None:
return []
return self._past_key_values[0][0].shape

def get_model(self) -> nn.Module:
return self._drafter_model

def reset_sample_method(self, sample_fn: callable) -> None:
self.do_sample = True
self.sample_fn = sample_fn

def clear_sample_method(self) -> None:
self.do_sample = False
self.sample_fn = None

def reset_max_spec_num(self, n: int) -> None:
assert isinstance(n, int) and n > 1
self.max_spec_num = n

def reset_past_key_values(self, past_key_values: Tuple[Tuple[torch.FloatTensor]] = None) -> None:
self._past_key_values = past_key_values

def trim_kv_cache(self, invalid_token_num: int) -> Tuple[Tuple[torch.FloatTensor]]:
# Tuple of kv cache tensors: num_layers x 2 x (bsz x num_heads x seq_len x head_dim)
# Trim the kv cache entries of the last `invalid_token_num` tokens.
# The verifier (main model) may reject up to `invalid_token_num` drafted tokens,
# so we have to trim those tokens from the drafter model's kv cache.
assert self._past_key_values is not None
trimmed_past_key_values = []
for layer_idx in range(len(self._past_key_values)):
past_key_value = self._past_key_values[layer_idx]
trimmed_past_key_values.append(
(
past_key_value[0][:, :, :-invalid_token_num, :],
past_key_value[1][:, :, :-invalid_token_num, :],
)
)
self._past_key_values = tuple(trimmed_past_key_values)
return self._past_key_values

@torch.inference_mode()
def speculate(
self, input_ids: torch.Tensor, n: int, past_key_values: Tuple[Tuple[torch.FloatTensor]] = None
) -> DrafterOutput:
"""Generate n tokens using the drafter model.

Args:
input_ids (torch.Tensor): Input token ids.
n (int): Number of tokens to speculate.
past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values of the input sequence.
"""

assert 0 <= n <= self.max_spec_num, f"Invalid number {n} to speculate"

# FIXME For compatibility with transformers 4.36.2 (versions before 4.38.0)
if input_ids.dim() == 1:
input_ids = input_ids.unsqueeze(0)

if past_key_values is None:
past_key_values = self._past_key_values

logits = []
token_ids = []

for _ in range(n):
outputs = self._drafter_model(
input_ids,
return_dict=True,
use_cache=True,
past_key_values=past_key_values,
)
next_token_logits = outputs.logits[:, -1, :]

# Skip logits_processor for drafter model

# Sample
if self.do_sample:
if self.sample_fn is not None:
probs = self.sample_fn(next_token_logits)
else:
probs = nn.functional.softmax(next_token_logits, dim=-1)
next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_token_ids = torch.argmax(next_token_logits, dim=-1)

logits.append(next_token_logits)
token_ids.append(next_token_ids)
if next_token_ids.item() == self._tokenizer.eos_token_id:
# TODO support bsz > 1
break
input_ids = next_token_ids[:, None]
past_key_values = outputs.past_key_values

speculated_length = len(token_ids) # TODO For now, only support bsz 1
logits = torch.concat(logits, dim=0)
token_ids = torch.concat(token_ids, dim=-1)
# update past_key_values
self._past_key_values = past_key_values

out = DrafterOutput(
speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values
)
return out
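
For reviewers, a minimal usage sketch of the container added here, using a toy Llama drafter as in the new unit test below. Verification is out of scope for this PR, so `reject_num` is a stand-in for whatever count the main model would report:

import torch
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM

from colossalai.inference.spec import Drafter

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
drafter_model = LlamaForCausalLM(LlamaConfig(num_hidden_layers=2)).eval().cuda()
drafter = Drafter(drafter_model, tokenizer, max_spec_num=5)

input_ids = torch.randint(low=5, high=1000, size=(1, 6), device="cuda")
out = drafter.speculate(input_ids, n=5)  # greedy drafting; KV cache kept internally

reject_num = 2  # placeholder: number of drafted tokens the verifier rejects
if reject_num > 0:
    drafter.trim_kv_cache(reject_num)  # keep drafter KV cache aligned with accepted tokens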
29 changes: 29 additions & 0 deletions colossalai/inference/spec/struct.py
@@ -0,0 +1,29 @@
from dataclasses import dataclass
from typing import Optional, Tuple

import torch


@dataclass
class DrafterOutput:
"""
Dataclass for drafter model outputs.

Args:
speculated_length (int): Number of tokens actually speculated.
It is always less than or equal to the requested speculation number `n` (and hence `max_spec_num`).
logits (torch.FloatTensor): Logits of the output sequence
next_tokens (torch.Tensor): Next token ids
past_key_values (Optional[Tuple[Tuple[torch.FloatTensor]]]): Past key values of the output sequence
"""

speculated_length: int = None
logits: torch.FloatTensor = None
next_tokens: torch.Tensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None

def __post_init__(self):
assert self.speculated_length is not None and self.speculated_length >= 0
if self.past_key_values is not None:
assert isinstance(self.past_key_values, tuple), "Past key values should be a tuple"
assert all([isinstance(past_key_value, tuple) for past_key_value in self.past_key_values])
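
A quick illustration of the dataclass and its `__post_init__` checks, with toy values (the vocab size 32000 is just an example):

import torch
from colossalai.inference.spec import DrafterOutput

out = DrafterOutput(
    speculated_length=2,
    logits=torch.randn(2, 32000),
    next_tokens=torch.tensor([11, 42]),
    past_key_values=None,  # optional; must be a tuple of (key, value) tuples when provided
)
# DrafterOutput(speculated_length=None) or a negative length would trip the assertions.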
41 changes: 41 additions & 0 deletions tests/test_infer/test_drafter.py
@@ -0,0 +1,41 @@
import pytest
import torch
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM

from colossalai.inference.spec.drafter import Drafter
from colossalai.utils import get_current_device

NUM_LAYERS = 2


@pytest.mark.parametrize("spec_num", [5])
def test_drafter(spec_num: int):
torch.manual_seed(123)

device = get_current_device()

toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
toy_config.pad_token_id = toy_config.eos_token_id
drafter_model = LlamaForCausalLM(toy_config)
drafter_model = drafter_model.eval().cuda()
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

drafter = Drafter(drafter_model, tokenizer, spec_num, device=device)

input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device)
out = drafter.speculate(input_ids, spec_num)
past_kv_length = input_ids.size(1) + spec_num - 1

assert out.speculated_length == spec_num
assert out.next_tokens.shape == (spec_num,)
assert out.logits.shape == (spec_num, len(tokenizer))
assert drafter._past_key_values[0][0].size(2) == out.past_key_values[0][0].size(2) == past_kv_length

reject_num = 3
assert reject_num <= spec_num
drafter.trim_kv_cache(reject_num)
assert drafter._past_key_values[0][0].size(2) == past_kv_length - reject_num


if __name__ == "__main__":
test_drafter(spec_num=5)
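
The test presumably requires a CUDA device, since the toy model is moved with `.cuda()`; it can be run standalone as above or through pytest:

pytest tests/test_infer/test_drafter.py -q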
8 changes: 6 additions & 2 deletions tests/test_infer/test_ops/triton/test_rmsnorm_triton.py
@@ -3,13 +3,12 @@
import triton
from packaging import version
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from vllm.model_executor.layers.layernorm import RMSNorm

from colossalai.kernel.triton import rms_layernorm
from colossalai.testing.utils import parameterize

try:
pass
import triton # noqa

HAS_TRITON = True
except ImportError:
@@ -85,6 +84,11 @@ def benchmark_rms_layernorm(
SEQUENCE_TOTAL: int,
HIDDEN_SIZE: int,
):
try:
from vllm.model_executor.layers.layernorm import RMSNorm
except ImportError:
raise ImportError("Please install vllm from https://github.com/vllm-project/vllm")

warmup = 10
rep = 1000
