CUDA graph compilation for bs=1
jamesr66a committed Aug 16, 2023
1 parent ea9f33d commit d8003f5
Showing 4 changed files with 178 additions and 26 deletions.
16 changes: 9 additions & 7 deletions example_text_completion.py
@@ -13,7 +13,7 @@ def main(
top_p: float = 0.9,
max_seq_len: int = 128,
max_gen_len: int = 64,
max_batch_size: int = 4,
max_batch_size: int = 1,
):
generator = Llama.build(
ckpt_dir=ckpt_dir,
@@ -39,12 +39,14 @@
plush girafe => girafe peluche
cheese =>""",
]
results = generator.text_completion(
prompts,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
)
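# Run each prompt separately: the CUDA graph fast path added in this commit is captured for batch size 1.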
results = []
for prompt in prompts:
results.append(generator.text_completion(
[prompt],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
)[0])
for prompt, result in zip(prompts, results):
print(prompt)
print(f"> {result['generation']}")
63 changes: 62 additions & 1 deletion llama/generation.py
@@ -10,6 +10,7 @@

import torch
import torch.nn.functional as F
from torch.profiler import record_function
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
@@ -102,6 +103,62 @@ def __init__(self, model: Transformer, tokenizer: Tokenizer):
self.model = model
self.tokenizer = tokenizer

self.compiled_model = None
self._cuda_graph = None
self._compiled_inputs = None
self._compiled_logits = None

def _compile_model(self, tokens_sliced : torch.Tensor, mask : torch.Tensor, valid_seq_pos : torch.Tensor):
assert self._cuda_graph is None and self._compiled_inputs is None and self._compiled_logits is None, "Already compiled the model"

self._compiled_inputs = tuple(v.clone() for v in (tokens_sliced, mask, valid_seq_pos))

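# Warm-up pass on a side stream so lazy allocations and one-time initialization happen outside the capture.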
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
_ = self.model.forward(*self._compiled_inputs)
torch.cuda.current_stream().wait_stream(s)

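# Capture a single forward pass; replays reuse the recorded kernels and write into a static output tensor.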
self._cuda_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self._cuda_graph):
self._compiled_logits = self.model.forward(*self._compiled_inputs)

def replay(tokens, mask, valid_seq_pos):
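# Copy fresh inputs into the graph's static input buffers, replay the recorded kernels, and return the static logits buffer.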
self._compiled_inputs[0].copy_(tokens)
self._compiled_inputs[1].copy_(mask)
self._compiled_inputs[2].copy_(valid_seq_pos)

self._cuda_graph.replay()

return self._compiled_logits

return replay


def compile_and_call_model(self, tokens : torch.Tensor, prev_pos : int, cur_pos : int, use_cuda_graph : bool):
if prev_pos == 0:
with record_function("prefill"):
tokens_sliced, mask, valid_seq_pos = self.model.params_for_prefill(
tokens, prev_pos, cur_pos, tokens.device)

logits = self.model.forward(tokens=tokens_sliced, mask=mask, valid_seq_pos=valid_seq_pos)
else:
with record_function("incremental_gen"):
tokens_sliced, mask, valid_seq_pos = self.model.params_for_incremental_gen(
tokens, prev_pos, cur_pos, tokens.device)

bsz = tokens.shape[0]
if self.compiled_model is None:
if use_cuda_graph:
assert bsz == 1, "Only support bs=1 for now"
self.compiled_model = self._compile_model(tokens_sliced, mask, valid_seq_pos)
else:
self.compiled_model = self.model.forward

logits = self.compiled_model(tokens=tokens_sliced, mask=mask, valid_seq_pos=valid_seq_pos)

return logits

@torch.inference_mode()
def generate(
self,
@@ -111,6 +168,7 @@ def generate(
top_p: float = 0.9,
logprobs: bool = False,
echo: bool = False,
use_cuda_graph : bool = True,
) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
params = self.model.params
bsz = len(prompt_tokens)
@@ -132,7 +190,8 @@
eos_reached = torch.tensor([False] * bsz, device="cuda")
input_text_mask = tokens != pad_id
for cur_pos in range(min_prompt_len, total_len):
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
logits = self.compile_and_call_model(tokens, prev_pos, cur_pos, use_cuda_graph=use_cuda_graph)

if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
@@ -186,6 +245,7 @@ def text_completion(
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
use_cuda_graph : bool = True,
) -> List[CompletionPrediction]:
if max_gen_len is None:
max_gen_len = self.model.params.max_seq_len - 1
@@ -197,6 +257,7 @@
top_p=top_p,
logprobs=logprobs,
echo=echo,
use_cuda_graph=use_cuda_graph,
)
if logprobs:
return [
54 changes: 36 additions & 18 deletions llama/model.py
@@ -145,7 +145,7 @@ def __init__(self, args: ModelArgs):
def forward(
self,
x: torch.Tensor,
start_pos: int,
valid_seq_pos: torch.Tensor,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
@@ -161,11 +161,15 @@ def forward(
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)

self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
self.cache_k[:bsz, valid_seq_pos] = xk
self.cache_v[:bsz, valid_seq_pos] = xv

keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]
if seqlen == 1:
keys = self.cache_k[:bsz]
values = self.cache_v[:bsz]
else:
keys = self.cache_k[:bsz, valid_seq_pos]
values = self.cache_v[:bsz, valid_seq_pos]
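# Indexing the cache with the valid_seq_pos tensor (rather than slicing with a Python int) keeps shapes and buffer addresses fixed across steps, which CUDA graph replay requires.
# In the seqlen == 1 decode path the full-length cache is read and the attention mask hides positions that have not been written yet.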

# repeat k/v heads if n_kv_heads < n_heads
keys = repeat_kv(keys, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
@@ -232,12 +236,12 @@ def __init__(self, layer_id: int, args: ModelArgs):
def forward(
self,
x: torch.Tensor,
start_pos: int,
valid_seq_pos : torch.Tensor,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
h = x + self.attention.forward(
self.attention_norm(x), start_pos, freqs_cis, mask
self.attention_norm(x), valid_seq_pos, freqs_cis, mask
)
out = h + self.feed_forward.forward(self.ffn_norm(h))
return out
@@ -267,22 +271,36 @@ def __init__(self, params: ModelArgs):
self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
)

def params_for_prefill(self, tokens : torch.Tensor, prev_pos : int, cur_pos : int, device : torch.device):
tokens_sliced = tokens[:, prev_pos:cur_pos].to(device=device)
valid_seq_pos = torch.arange(prev_pos, cur_pos, device=device)
seqlen = cur_pos - prev_pos
mask = torch.full(
(1, 1, seqlen, seqlen), float("-inf"), device=device
)
mask = torch.triu(mask, diagonal=valid_seq_pos[0] + 1)

return tokens_sliced, mask, valid_seq_pos

def params_for_incremental_gen(self, tokens : torch.Tensor, prev_pos : int, cur_pos : int, device : torch.device):
tokens_sliced = tokens[:, prev_pos:cur_pos].to(device=device)
valid_seq_pos = torch.arange(prev_pos, cur_pos, device=device)

mask = torch.full(
(1, 1, 1, self.params.max_seq_len), float("-inf"), device=device
)
mask[:, :, :, :valid_seq_pos.item() + 1] = 0.0
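# The decode mask always spans max_seq_len, so its shape is identical on every step (needed for graph replay); only positions up to and including the current one are unmasked.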

return tokens_sliced, mask, valid_seq_pos

@torch.inference_mode()
def forward(self, tokens: torch.Tensor, start_pos: int):
_bsz, seqlen = tokens.shape
def forward(self, tokens: torch.Tensor, mask : torch.Tensor, valid_seq_pos : torch.Tensor):
h = self.tok_embeddings(tokens)
self.freqs_cis = self.freqs_cis.to(h.device)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]

mask = None
if seqlen > 1:
mask = torch.full(
(1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
)
mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
freqs_cis = self.freqs_cis[valid_seq_pos]

for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
h = layer(h, valid_seq_pos, freqs_cis, mask)
h = self.norm(h)
output = self.output(h).float()
return output
71 changes: 71 additions & 0 deletions text_completion_benchmark.py
@@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from contextlib import nullcontext
import fire
import time
import torch
import torch.profiler

from llama import Llama


def benchmark(
ckpt_dir: str,
tokenizer_path: str,
max_seq_len: int = 128,
warmup_iterations: int = 2,
test_iterations: int = 5,
use_cuda_graph : bool = True,
profile : bool = False,
):
# Build the Llama generator
generator = Llama.build(
ckpt_dir=ckpt_dir,
tokenizer_path=tokenizer_path,
max_seq_len=max_seq_len,
max_batch_size=1,
)

# Sample prompt for warmup and benchmarking
prompt = "The theory of everything is"

# Warmup Iterations
for i in range(warmup_iterations):
print(f"Warmup iteration {i}")
_ = generator.text_completion([prompt], use_cuda_graph=use_cuda_graph)

# Ensure GPU operations have completed
torch.cuda.synchronize()

# Benchmark Iterations
start_time = time.perf_counter()
total_tokens = 0
benchmark_schedule = torch.profiler.schedule(wait=0, warmup=2, active=1, repeat=1)
with torch.profiler.profile(
schedule=benchmark_schedule,
on_trace_ready=torch.profiler.tensorboard_trace_handler(f'./log_cudagraph_{use_cuda_graph}'),
record_shapes=True,
) if profile else nullcontext() as prof:
for i in range(test_iterations):
print(f'Benchmark iteration {i}')
result = generator.text_completion([prompt], use_cuda_graph=use_cuda_graph)
total_tokens += len(result[0]['generation'].split())
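# Note: this counts whitespace-separated words of the generated text, not tokenizer tokens, so tokens/sec is approximate.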
if profile:
prof.step()

# Ensure GPU operations have completed
torch.cuda.synchronize()

end_time = time.perf_counter()
elapsed_time = end_time - start_time
seconds_per_example = elapsed_time / test_iterations
tokens_per_second = total_tokens / elapsed_time

print(f"Results after {test_iterations} iterations:")
print(f"Seconds per example: {seconds_per_example:.4f} sec")
print(f"Tokens per second: {tokens_per_second:.2f} tokens/sec")


if __name__ == "__main__":
fire.Fire(benchmark)
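
For reference, the capture-and-replay pattern used in _compile_model above can be reduced to the following minimal sketch on a toy linear layer. The module, shapes, and names are illustrative placeholders (not part of this commit), and it assumes a CUDA-capable GPU with a PyTorch build that supports CUDA graphs.

import torch

# Toy stand-in for the Transformer: any fixed-shape CUDA module works.
model = torch.nn.Linear(16, 16).cuda()
static_input = torch.randn(1, 16, device="cuda")

# 1. Warm up on a side stream so lazy allocations happen before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    with torch.no_grad():
        _ = model(static_input)
torch.cuda.current_stream().wait_stream(s)

# 2. Capture one forward pass into a CUDA graph.
graph = torch.cuda.CUDAGraph()
with torch.no_grad(), torch.cuda.graph(graph):
    static_output = model(static_input)

# 3. Replay with new data by copying it into the static input buffer;
#    the result lands in the static output buffer.
static_input.copy_(torch.randn(1, 16, device="cuda"))
graph.replay()
print(static_output)

The design point mirrored in _compile_model is that a captured graph always runs against the same tensor addresses, so new inputs are copied into the captured buffers (and outputs read back from a captured buffer) rather than passed and returned as fresh tensors.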
