Fix TunableOp bug (#1920)
fxmarty committed May 17, 2024
1 parent 422bf1f commit b5f1c9d
Showing 3 changed files with 47 additions and 1 deletion.
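For context (not part of this diff): PyTorch's TunableOp feature benchmarks several GEMM implementations for each problem shape it encounters and caches the fastest choice, which is why TGI runs dedicated warmup passes over representative shapes before serving. Below is a minimal sketch of enabling it in a standalone script, assuming the standard `PYTORCH_TUNABLEOP_*` environment variables; the results filename is an illustrative choice.

```python
import os

# Sketch: enable PyTorch TunableOp via environment variables before any GEMM runs,
# so matmuls are tuned on first use and the results are cached to disk.
os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"    # turn TunableOp on
os.environ["PYTORCH_TUNABLEOP_TUNING"] = "1"     # allow tuning, not only replaying cached results
os.environ["PYTORCH_TUNABLEOP_FILENAME"] = "tunableop_results.csv"  # illustrative output path

import torch

# Any matmul executed now may trigger tuning for its problem shape, which is
# why a warmup pass over representative shapes is useful before serving.
a = torch.randn(512, 512, device="cuda", dtype=torch.float16)
b = torch.randn(512, 512, device="cuda", dtype=torch.float16)
c = a @ b
```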
2 changes: 1 addition & 1 deletion docs/source/basic_tutorials/monitoring.md
@@ -72,4 +72,4 @@ Once Prometheus data source is configured, we can finally create our dashboard!

Community contributed dashboard templates are also available, for example [here](https://grafana.com/grafana/dashboards/19831-text-generation-inference-dashboard/) or [here](https://grafana.com/grafana/dashboards/20246-text-generation-inference/).

-Load your dashboard configuration, and your TGI dashboard should be ready to go!
\ No newline at end of file
+Load your dashboard configuration, and your TGI dashboard should be ready to go!
22 changes: 22 additions & 0 deletions server/text_generation_server/models/flash_mistral.py
@@ -391,6 +391,28 @@ def max_past(self) -> int:
    def batch_type(self) -> Type[FlashMistralBatch]:
        return FlashMistralBatch

    def tunableop_warmup(self, seqlen: int):
        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
        kv_cache = get_cache_manager().kv_cache

        # We pass a `cu_seqlen_prefill` so that we do not have to deal with
        # paged attention cache allocation/deallocation.
        self.model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            cu_seqlen_prefill=torch.tensor(
                [0, seqlen], device=self.device, dtype=torch.int32
            ),
            kv_cache=kv_cache,
            block_tables=None,
            input_lengths=None,
            slots=slots,
            max_s=seqlen,
            lm_head_indices=None,
            prefill_cache_indices=None,
        )
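A hedged sketch of how this warmup might be driven, assuming a hypothetical `run_tunableop_warmup` helper and an illustrative set of sequence lengths; TGI's actual warmup entry point is not shown in this diff.

```python
# Hypothetical warmup driver (illustrative only): exercise a few prefill
# lengths so that TunableOp has tuned the corresponding GEMM shapes
# before the server starts handling requests.
def run_tunableop_warmup(model, max_seqlen: int = 4096):
    seqlen = 2
    while seqlen <= max_seqlen:
        model.tunableop_warmup(seqlen)
        seqlen *= 2
```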

    def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
24 changes: 24 additions & 0 deletions server/text_generation_server/models/mamba.py
@@ -522,6 +522,30 @@ def cuda_graph_warmup(self, batch_size: int):
        }
        self.cuda_graphs[batch_size] = graph_dict

    def tunableop_warmup(self, seqlen: int):
        # seqlen is used as the batch dimension here: each warmup "sequence"
        # is a single decode token.
        input_ids = torch.zeros((seqlen, 1), dtype=torch.int64, device=self.device)
        n_blocks = len(self.model.blocks)

        d_state = self.model.config.d_state
        d_conv = self.model.config.d_conv
        # d_inner already includes the expand factor
        d_inner = self.model.config.d_inner

        # A non-zero seqlen_offset is important so that the forward pass goes
        # through the state update mechanism.
        seqlen_offset = 1
        inference_params = new_inference_params(
            n_blocks=n_blocks,
            batch_size=seqlen,
            d_state=d_state,
            d_conv=d_conv,
            d_inner=d_inner,
            seqlen_offset=seqlen_offset,
            device=self.device,
            dtype=self.dtype,
        )

        self.model.forward(input_ids=input_ids, inference_params=inference_params)
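For readers unfamiliar with Mamba's recurrent state, here is a hedged sketch of what a helper like `new_inference_params` plausibly allocates. The `SketchInferenceParams` class, field names, and tensor shapes (`conv_state` of shape `(batch, d_inner, d_conv)`, `ssm_state` of shape `(batch, d_inner, d_state)`) follow the common Mamba convention and are assumptions, not TGI's exact implementation.

```python
from dataclasses import dataclass, field
from typing import Dict, Tuple

import torch


@dataclass
class SketchInferenceParams:
    # Assumed layout: one (conv_state, ssm_state) pair per residual block.
    seqlen_offset: int
    key_value_memory_dict: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = field(
        default_factory=dict
    )


def sketch_new_inference_params(
    n_blocks, batch_size, d_state, d_conv, d_inner, seqlen_offset, device, dtype
):
    params = SketchInferenceParams(seqlen_offset=seqlen_offset)
    for i in range(n_blocks):
        # Rolling convolution buffer and SSM hidden state for block i.
        conv_state = torch.zeros(batch_size, d_inner, d_conv, device=device, dtype=dtype)
        ssm_state = torch.zeros(batch_size, d_inner, d_state, device=device, dtype=dtype)
        params.key_value_memory_dict[i] = (conv_state, ssm_state)
    return params
```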

    def forward(
        self, input_ids: torch.Tensor, inference_params: Any
    ) -> Tuple[torch.Tensor, torch.Tensor]:
