Attempt to use torch.compile
Error messages: https://gist.github.com/dzhulgakov/188b997ce56e3540218d8991b49b85aa

Tried:
1. mode='reduce-overhead': fails with a tangents-mismatch error
2. mode='reduce-overhead' with inference_mode commented out: takes 3+ minutes to compile and performance degrades versus no compilation (13 t/s -> 7 t/s); warns about the performance of complex-number ops
3. mode='cudagraphs': cryptic error about FakeTensor (maybe also related to complex numbers?)

What worked: mode='reduce-overhead', with distributed (fairscale model parallelism) disabled and no_grad in place of inference_mode; a minimal sketch follows.
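
For reference, a minimal sketch of the working configuration described above (placeholder paths; single process, matching the model_parallel_size = 1 change below; not part of the commit itself):

import torch

from llama import Llama

# Build the generator, then swap in a compiled model (mode='reduce-overhead').
generator = Llama.build(
    ckpt_dir="llama-2-7b/",            # placeholder checkpoint directory
    tokenizer_path="tokenizer.model",  # placeholder tokenizer path
    max_seq_len=128,
    max_batch_size=1,
)
generator.model = torch.compile(generator.model, mode="reduce-overhead")

# generate() is decorated with @torch.no_grad() after this commit, so no extra
# context manager is needed around the call.
result = generator.text_completion(["The theory of everything is"])
print(result[0]["generation"])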
dzhulgakov committed Aug 25, 2023
1 parent a41de26 commit ddf596e
Showing 3 changed files with 95 additions and 31 deletions.
llama/generation.py: 2 changes (1 addition, 1 deletion)
@@ -102,7 +102,7 @@ def __init__(self, model: Transformer, tokenizer: Tokenizer):
         self.model = model
         self.tokenizer = tokenizer
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def generate(
         self,
         prompt_tokens: List[List[int]],
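The inference_mode to no_grad swap lines up with the tangents-mismatch failure from attempt 1: both decorators disable gradient tracking, but inference_mode also produces inference tensors, which is what appeared to trip up torch.compile in this experiment. A toy illustration (hypothetical module, not from this repo):

import torch


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    @torch.no_grad()            # after this commit
    # @torch.inference_mode()   # before this commit; failed under torch.compile here
    def forward(self, x):
        return self.linear(x)


compiled = torch.compile(Toy())
print(compiled(torch.randn(2, 8)).requires_grad)  # False: autograd is disabled either way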
llama/model.py: 46 changes (16 additions, 30 deletions)
@@ -5,14 +5,8 @@
 from dataclasses import dataclass
 from typing import Any, Optional, Tuple
 
-import fairscale.nn.model_parallel.initialize as fs_init
 import torch
 import torch.nn.functional as F
-from fairscale.nn.model_parallel.layers import (
-    ColumnParallelLinear,
-    ParallelEmbedding,
-    RowParallelLinear,
-)
 from torch import nn

@@ -90,39 +84,31 @@ class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
-        model_parallel_size = fs_init.get_model_parallel_world_size()
+        model_parallel_size = 1
         self.n_local_heads = args.n_heads // model_parallel_size
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
         self.head_dim = args.dim // args.n_heads
 
-        self.wq = ColumnParallelLinear(
+        self.wq = nn.Linear(
             args.dim,
             args.n_heads * self.head_dim,
             bias=False,
-            gather_output=False,
-            init_method=lambda x: x,
         )
-        self.wk = ColumnParallelLinear(
+        self.wk = nn.Linear(
             args.dim,
             self.n_kv_heads * self.head_dim,
             bias=False,
-            gather_output=False,
-            init_method=lambda x: x,
         )
-        self.wv = ColumnParallelLinear(
+        self.wv = nn.Linear(
             args.dim,
             self.n_kv_heads * self.head_dim,
             bias=False,
-            gather_output=False,
-            init_method=lambda x: x,
         )
-        self.wo = RowParallelLinear(
+        self.wo = nn.Linear(
             args.n_heads * self.head_dim,
             args.dim,
             bias=False,
-            input_is_parallel=True,
-            init_method=lambda x: x,
         )
 
         self.cache_k = torch.zeros(
@@ -198,14 +184,14 @@ def __init__(
             hidden_dim = int(ffn_dim_multiplier * hidden_dim)
         hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
 
-        self.w1 = ColumnParallelLinear(
-            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+        self.w1 = nn.Linear(
+            dim, hidden_dim, bias=False
         )
-        self.w2 = RowParallelLinear(
-            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
+        self.w2 = nn.Linear(
+            hidden_dim, dim, bias=False
         )
-        self.w3 = ColumnParallelLinear(
-            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+        self.w3 = nn.Linear(
+            dim, hidden_dim, bias=False
         )
 
     def forward(self, x):
@@ -250,24 +236,24 @@ def __init__(self, params: ModelArgs):
         self.vocab_size = params.vocab_size
         self.n_layers = params.n_layers
 
-        self.tok_embeddings = ParallelEmbedding(
-            params.vocab_size, params.dim, init_method=lambda x: x
+        self.tok_embeddings = nn.Embedding(
+            params.vocab_size, params.dim
         )
 
         self.layers = torch.nn.ModuleList()
         for layer_id in range(params.n_layers):
             self.layers.append(TransformerBlock(layer_id, params))
 
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = ColumnParallelLinear(
-            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
+        self.output = nn.Linear(
+            params.dim, params.vocab_size, bias=False
        )
 
         self.freqs_cis = precompute_freqs_cis(
             self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
         )
 
-    @torch.inference_mode()
+    @torch.no_grad()
     def forward(self, tokens: torch.Tensor, start_pos: int):
         _bsz, seqlen = tokens.shape
         h = self.tok_embeddings(tokens)
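Because the fairscale layers and their plain nn replacements expose their parameters under the same attribute names (wq.weight, tok_embeddings.weight, and so on), a single-shard checkpoint should still load unchanged into the de-parallelized model. A hedged sketch (placeholder path and illustrative hyperparameters):

import torch

from llama.model import ModelArgs, Transformer

# Single-shard (7B-style) checkpoint; the path is a placeholder.
checkpoint = torch.load("llama-2-7b/consolidated.00.pth", map_location="cpu")
params = ModelArgs(dim=4096, n_layers=32, n_heads=32, vocab_size=32000)  # illustrative values
model = Transformer(params)
# strict=False tolerates extra checkpoint keys such as rope.freqs.
model.load_state_dict(checkpoint, strict=False)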
text_completion_benchmark.py: 78 changes (78 additions, 0 deletions)
@@ -0,0 +1,78 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from contextlib import nullcontext
import fire
import time
import torch
import torch.profiler

from llama import Llama


def benchmark(
    ckpt_dir: str,
    tokenizer_path: str,
    max_seq_len: int = 128,
    warmup_iterations: int = 2,
    test_iterations: int = 5,
    use_torch_compile: bool = False,
    use_inductor: bool = True,
    profile: bool = False,
):
    # Build the Llama generator
    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=1,
    )
    if use_torch_compile:
        if use_inductor:
            generator.model = torch.compile(generator.model, mode='reduce-overhead')
        else:
            generator.model = torch.compile(generator.model, backend='cudagraphs')

    # Sample prompt for warmup and benchmarking
    prompt = "The theory of everything is"

    # Warmup iterations
    for i in range(warmup_iterations):
        print(f"Warmup iteration {i}")
        _ = generator.text_completion([prompt])

    # Ensure GPU operations have completed
    torch.cuda.synchronize()

    # Benchmark iterations
    start_time = time.perf_counter()
    total_tokens = 0
    # Profiler schedule: no wait steps, two warmup steps, then record a single active step once.
    benchmark_schedule = torch.profiler.schedule(wait=0, warmup=2, active=1, repeat=1)
    with torch.profiler.profile(
        schedule=benchmark_schedule,
        on_trace_ready=torch.profiler.tensorboard_trace_handler(f'./log_torch_compile_{use_torch_compile}_{use_inductor}'),
        record_shapes=True,
    ) if profile else nullcontext() as prof:
        for i in range(test_iterations):
            print(f'Benchmark iteration {i}')
            result = generator.text_completion([prompt])
            # Whitespace-split word count serves as a rough proxy for generated tokens.
            total_tokens += len(result[0]['generation'].split())
            if profile:
                prof.step()

    # Ensure GPU operations have completed
    torch.cuda.synchronize()

    end_time = time.perf_counter()
    elapsed_time = end_time - start_time
    seconds_per_example = elapsed_time / test_iterations
    tokens_per_second = total_tokens / elapsed_time

    print(f"Results after {test_iterations} iterations:")
    print(f"Seconds per example: {seconds_per_example:.4f} sec")
    print(f"Tokens per second: {tokens_per_second:.2f} tokens/sec")


if __name__ == "__main__":
    fire.Fire(benchmark)
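
A possible invocation of the benchmark (hypothetical paths; Fire maps the same keyword arguments to command-line flags, e.g. --ckpt_dir and --use_torch_compile):

# Hypothetical direct call with placeholder paths.
from text_completion_benchmark import benchmark

benchmark(
    ckpt_dir="llama-2-7b/",            # placeholder checkpoint directory
    tokenizer_path="tokenizer.model",  # placeholder tokenizer path
    use_torch_compile=True,            # torch.compile(..., mode='reduce-overhead')
    use_inductor=True,
    profile=False,
)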
