Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit d6203d6

Browse files
authored Feb 28, 2024
[Inference/SpecDec] Add Basic Drafter Model Container (#5405)
* [Infer/Fix] Fix Dependency in test - RMSNorm kernel (#5399) fix dependency in pytest * add drafter model container (basic ver)
1 parent 2d62aca commit d6203d6

File tree

4 files changed

+216
-0
lines changed

4 files changed

+216
-0
lines changed
 

‎colossalai/inference/spec/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .drafter import Drafter
2+
from .struct import DrafterOutput
3+
4+
__all__ = ["Drafter", "DrafterOutput"]

‎colossalai/inference/spec/drafter.py

+142
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
from typing import Optional, Tuple
2+
3+
import torch
4+
import torch.nn as nn
5+
from transformers import PreTrainedTokenizer
6+
7+
from colossalai.utils import get_current_device
8+
9+
from .struct import DrafterOutput
10+
11+
12+
class Drafter:
    """Container for the Drafter Model (Assistant Model) used in Speculative Decoding.

    Args:
        model (nn.Module): The drafter model.
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the drafter model.
        max_spec_num (int): The maximum number of tokens to speculate.
        device (torch.device): The device for the drafter model.
    """

    def __init__(
        self, model: nn.Module, tokenizer: PreTrainedTokenizer, max_spec_num: int, device: torch.device = None
    ):
        self._drafter_model = model
        self._tokenizer = tokenizer
        self.max_spec_num = max_spec_num
        # Greedy decoding by default; enable sampling via `reset_sample_method`
        self.do_sample = False
        self.sample_fn = None
        self._device = device or get_current_device()
        # Lazily populated by `speculate` (and replaceable via `reset_past_key_values`)
        self._past_key_values = None

    @property
    def past_key_values(self) -> Optional[Tuple[Tuple[torch.FloatTensor]]]:
        """Cached past key values from the drafter model, or None before first use."""
        return self._past_key_values

    # Debug usage for now
    @property
    def past_key_values_shape(self):
        """Shape of the first layer's key cache, or [] when no cache exists. Debug helper."""
        if self._past_key_values is None:
            return []
        return self._past_key_values[0][0].shape

    def get_model(self) -> nn.Module:
        """Return the underlying drafter model."""
        return self._drafter_model

    def reset_sample_method(self, sample_fn: callable) -> None:
        """Enable sampling, using `sample_fn` to map next-token logits to probabilities."""
        self.do_sample = True
        self.sample_fn = sample_fn

    def clear_sample_method(self) -> None:
        """Disable sampling; fall back to greedy (argmax) decoding."""
        self.do_sample = False
        self.sample_fn = None

    def reset_max_spec_num(self, n: int) -> None:
        """Set the maximum number of tokens to speculate; requires an int greater than 1."""
        assert isinstance(n, int) and n > 1
        self.max_spec_num = n

    def reset_past_key_values(self, past_key_values: Tuple[Tuple[torch.FloatTensor]] = None) -> None:
        """Replace (or clear, when None) the cached past key values."""
        self._past_key_values = past_key_values

    def trim_kv_cache(self, invalid_token_num: int) -> Tuple[Tuple[torch.FloatTensor]]:
        """Trim the last `invalid_token_num` entries from every layer's kv cache.

        The verifier (main model) might reject `invalid_token_num` tokens,
        and so that we have to trim the invalid tokens for the kv cache of the drafter model.

        Args:
            invalid_token_num (int): Number of trailing token positions to drop (>= 0).

        Returns:
            Tuple of kv cache tensors: num_layers x 2 x (bsz x num_heads x seq_len x head_dim)
        """
        assert self._past_key_values is not None
        # Guard the degenerate cases: a negative count would silently truncate from the
        # front, and slicing with `[:-0]` would wipe the whole cache instead of no-op.
        assert invalid_token_num >= 0
        if invalid_token_num == 0:
            return self._past_key_values
        trimmed_past_key_values = []
        for layer_idx in range(len(self._past_key_values)):
            past_key_value = self._past_key_values[layer_idx]
            trimmed_past_key_values.append(
                (
                    past_key_value[0][:, :, :-invalid_token_num, :],
                    past_key_value[1][:, :, :-invalid_token_num, :],
                )
            )
        self._past_key_values = tuple(trimmed_past_key_values)
        return self._past_key_values

    @torch.inference_mode()
    def speculate(
        self, input_ids: torch.Tensor, n: int, past_key_values: Tuple[Tuple[torch.FloatTensor]] = None
    ) -> DrafterOutput:
        """Generate n tokens using the drafter model.

        Args:
            input_ids (torch.Tensor): Input token ids.
            n (int): Number of tokens to speculate.
            past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values of the input sequence.
        """

        assert 0 <= n <= self.max_spec_num, f"Invalid number {n} to speculate"

        # FIXME For compatibility with transformers 4.36.2 (versions before 4.38.0)
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        # Fall back to the internally cached kv when the caller does not provide one
        if past_key_values is None:
            past_key_values = self._past_key_values

        logits = []
        token_ids = []

        for _ in range(n):
            outputs = self._drafter_model(
                input_ids,
                return_dict=True,
                use_cache=True,
                past_key_values=past_key_values,
            )
            next_token_logits = outputs.logits[:, -1, :]

            # Skip logits_processor for drafter model

            # Sample
            if self.do_sample:
                if self.sample_fn is not None:
                    probs = self.sample_fn(next_token_logits)
                else:
                    probs = nn.functional.softmax(next_token_logits, dim=-1)
                next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_token_ids = torch.argmax(next_token_logits, dim=-1)

            logits.append(next_token_logits)
            token_ids.append(next_token_ids)
            # `.item()` assumes a single sequence here
            if next_token_ids.item() == self._tokenizer.eos_token_id:
                # TODO support bsz > 1
                break
            # Feed only the freshly generated token back in; the rest lives in the kv cache
            input_ids = next_token_ids[:, None]
            past_key_values = outputs.past_key_values

        speculated_length = len(token_ids)  # TODO For now, only support bsz 1
        logits = torch.concat(logits, dim=0)
        token_ids = torch.concat(token_ids, dim=-1)
        # update past_key_values
        self._past_key_values = past_key_values

        out = DrafterOutput(
            speculated_length=speculated_length, logits=logits, next_tokens=token_ids, past_key_values=past_key_values
        )
        return out

‎colossalai/inference/spec/struct.py

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from dataclasses import dataclass
2+
from typing import Optional, Tuple
3+
4+
import torch
5+
6+
7+
@dataclass
class DrafterOutput:
    """
    Dataclass for drafter model outputs.

    Args:
        speculated_length (int): Speculated length of the output sequence
            It is always less than or equal to spec_num during drafter's speculation process
        logits (torch.FloatTensor): Logits of the output sequence
        next_tokens (torch.Tensor): Next token ids
        past_key_values (Optional[Tuple[Tuple[torch.FloatTensor]]]): Past key values of the output sequence
    """

    # Annotations are Optional to match the None defaults; `__post_init__` enforces
    # that `speculated_length` is actually provided.
    speculated_length: Optional[int] = None
    logits: Optional[torch.FloatTensor] = None
    next_tokens: Optional[torch.Tensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None

    def __post_init__(self):
        # A speculated length must be supplied and be non-negative
        assert self.speculated_length is not None and self.speculated_length >= 0
        if self.past_key_values is not None:
            assert isinstance(self.past_key_values, tuple), "Past key values should be a tuple"
            # Each element is a per-layer (key, value) pair, so it must also be a tuple
            assert all(isinstance(past_key_value, tuple) for past_key_value in self.past_key_values)

‎tests/test_infer/test_drafter.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import pytest
2+
import torch
3+
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
4+
5+
from colossalai.inference.spec.drafter import Drafter
6+
from colossalai.utils import get_current_device
7+
8+
NUM_LAYERS = 2
9+
10+
11+
@pytest.mark.parametrize("spec_num", [5])
def test_drafter(spec_num: int):
    """Smoke test: a toy Llama drafter speculates `spec_num` tokens; check output shapes,
    kv-cache length after speculation, and kv-cache length after trimming rejected tokens."""
    torch.manual_seed(123)

    device = get_current_device()

    toy_config = LlamaConfig(num_hidden_layers=NUM_LAYERS)
    toy_config.pad_token_id = toy_config.eos_token_id
    drafter_model = LlamaForCausalLM(toy_config)
    # Move to the detected device instead of hard-coding `.cuda()`, so the test
    # matches the device used for `input_ids` and also runs on CPU-only hosts
    drafter_model = drafter_model.eval().to(device)
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

    drafter = Drafter(drafter_model, tokenizer, spec_num, device=device)

    input_ids = torch.randint(low=5, high=1000, size=(1, 6)).to(device)
    out = drafter.speculate(input_ids, spec_num)
    # Expected cache length: prompt tokens plus (spec_num - 1) single-token decode steps
    past_kv_length = input_ids.size(1) + spec_num - 1

    assert out.speculated_length == spec_num
    assert out.next_tokens.shape == (spec_num,)
    assert out.logits.shape == (spec_num, len(tokenizer))
    assert drafter._past_key_values[0][0].size(2) == out.past_key_values[0][0].size(2) == past_kv_length

    reject_num = 3
    assert reject_num <= spec_num
    drafter.trim_kv_cache(reject_num)
    assert drafter._past_key_values[0][0].size(2) == past_kv_length - reject_num


if __name__ == "__main__":
    test_drafter(spec_num=5)

0 commit comments

Comments
 (0)
Please sign in to comment.