3 changes: 2 additions & 1 deletion .gitignore
@@ -1,7 +1,8 @@
# Initially taken from Github's Python gitignore file

# TB logs
# Data created by runs (not to be tracked)
tb_logs/
results/

# Byte-compiled / optimized / DLL files
__pycache__/
20 changes: 19 additions & 1 deletion README.md
@@ -27,7 +27,9 @@ You can profile a short run with `--profile`, with the TB logs being stored in `
python run_llama.py --model huggingface/llama-7b --preallocate --profile
```

gives, with `batch_size=1`, `prompt_length=1000`, `new_tokens=200`, `cache_length=1200`, `dtype=fp16`:
## Results

Running the command above with `batch_size=1`, `prompt_length=1000`, `new_tokens=200`, `cache_length=1200`, and `dtype=fp16` gives:

| changes | compile | tok_per_s | max_mem_mb | hash | commit |
|-------------------------------------------------------------|---------|-----------|------------|----------|------------------------------------------|
@@ -56,3 +58,19 @@ BATCH_SIZES = [1, 2, 4, 8]
PROMPT_LENGTHS = [500, 1000, 4000]
NEW_TOKENS = [1000]
```

## Predefined sweeps

You can sweep over predefined configurations of batch sizes (for a fixed prompt length) and prompt lengths (for a
fixed batch size) with the `--sweep` flag, e.g.

```
python scripts/run_llama.py --model huggingface/llama-7b --sweep batch
```

If you run the sweep for multiple generation alternatives (original code, with preallocated tensors, and with
preallocation + compilation), you can easily compare the results with

```
python scripts/plot_results.py --sweep batch
```
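
Each run also appends its raw measurements to `./results/results_llama.csv`. If you prefer to inspect the numbers directly rather than through the plot, here is a minimal sketch with pandas, assuming the column names written by `scripts/run_llama.py` and filtering to the current commit the same way `scripts/plot_results.py` does (shown for a batch sweep; use `"Prompt length"` as the index for a length sweep):

```
import git
import pandas as pd

# Load every result appended so far and keep only the current commit's rows.
df = pd.read_csv("./results/results_llama.csv")
repo = git.Repo(search_parent_directories=True)
git_hash = repo.git.rev_parse(repo.head, short=True)
df = df[df["Git hash"] == git_hash]

# One throughput column per code version (preallocation / compile setting).
print(
    df.pivot_table(
        index="Batch size",
        columns=["Preallocate", "Compile"],
        values="Tokens per second",
        aggfunc="max",
    )
)
```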
79 changes: 79 additions & 0 deletions scripts/plot_results.py
@@ -0,0 +1,79 @@
"""
Plots the results of a sweep for the current git hash.
"""
import argparse

import git
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd


DEFAULT_BATCH_SIZE = 1
DEFAULT_PROMPT_LENGTH = 1000


parser = argparse.ArgumentParser()
parser.add_argument(
    "--sweep",
    type=str,
    choices=["batch", "length"],
    required=True,
    help="Select which type of sweep to plot"
)
args = parser.parse_args()

# 1. Read file and retrieve relevant data
results_file = "./results/results_llama.csv"
df = pd.read_csv(results_file)

repo = git.Repo(search_parent_directories=True)
current_git_hash = repo.git.rev_parse(repo.head, short=True)
df = df[df["Git hash"] == current_git_hash]
if df.empty:
    raise ValueError(f"No results found for current git hash ({current_git_hash})")

if args.sweep == "batch":
df = df[df["Prompt length"] == DEFAULT_PROMPT_LENGTH]
else:
df = df[df["Batch size"] == DEFAULT_BATCH_SIZE]
df = df[df["Prompt length"] != DEFAULT_PROMPT_LENGTH]

if df.empty:
    raise ValueError("Something went wrong -- no results in the filtered dataframe")

# 2. Plot -- we expect 3 series: original model, preallocated, and preallocated + compiled
if args.sweep == "batch":
x_col_name = "Batch size"
else:
x_col_name = "Prompt length"

df["Type"] = df["Preallocate"].astype("str") + df["Compile"]
df["Type"] = df["Type"].replace({"Falseno": "original", "Trueno": "Preallocate", "Truestatic": "Pre + comp."})

g = sns.catplot(
    data=df,
    kind="bar",
    x=x_col_name,
    y="Tokens per second",
    hue="Type",
    palette={"original": "blue", "Preallocate": "orange", "Pre + comp.": "red"},
    alpha=.9,
)
g.despine(left=True)
g.set_axis_labels("Batch size" if args.sweep == "batch" else "Prompt length", "Tokens per second")
g.legend.set_title("LLaMA code version")
plt.setp(g._legend.get_texts(), fontsize='7') # for legend text

title_constant = f'{"Batch size = " + str(DEFAULT_BATCH_SIZE) if args.sweep == "length" else "Prompt length = " + str(DEFAULT_PROMPT_LENGTH)}'
g.set(title=f'LLaMA sweep ({title_constant})')

# Add the number to the top of each bar
ax = g.facet_axis(0, 0)
for i in ax.containers:
    ax.bar_label(i, fontsize=7)

g.tight_layout()
plt_path = f"./results/llama_sweep_{current_git_hash}_{args.sweep}.png"
plt.savefig(plt_path, dpi=300)
print(f"Plot stored at {plt_path}")
115 changes: 76 additions & 39 deletions scripts/run_llama.py
@@ -2,27 +2,37 @@
import copy
import contextlib
import hashlib
import os
from typing import Dict

from tqdm import tqdm
import pandas as pd
import torch
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
import git
from torch.profiler import ProfilerActivity, profile, tensorboard_trace_handler

from transformers import AutoModelForCausalLM, AutoTokenizer

from trfs_fast.llama import LlamaForCausalLM
from trfs_fast.utils import recurse_getattr, recurse_hasattr, recurse_setattr, recurse_delattr


# Default case
BATCH_SIZES = [1]
PROMPT_LENGTHS = [1000]
NEW_TOKENS = [200]
WARMUP_RUNS = 2
NUM_RUNS = 5

# Modifiers for profiling (we want a short run)
PROFILE_NEW_TOKENS = 10
PROFILE_NUM_RUNS = 1

# Modifiers for parameter sweeps
SWEEP_BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128]
SWEEP_PROMPT_LENGTHS = [100, 200, 400, 800, 1600]
SWEEP_NUM_RUNS = 10

parser = argparse.ArgumentParser()

# TODO: support other archs than llama
@@ -55,6 +65,14 @@
default="no",
help="If (and how) to compile the model forward pass with torch.compile",
)
parser.add_argument(
"--sweep",
type=str,
choices=["", "batch", "length"],
required=False,
default="",
help="Select which type of sweep to gather data for",
)


def timing_cuda(
@@ -71,15 +89,11 @@ def timing_cuda(
    warmup_start_event = torch.cuda.Event(enable_timing=True)
    warmup_end_event = torch.cuda.Event(enable_timing=True)

    if do_profile:
        num_runs = PROFILE_NUM_RUNS
        max_new_tokens = PROFILE_NEW_TOKENS

    if preallocate:
        inputs["cache_length"] = cache_length

    with torch.no_grad():
        print("Warming up...")
        print(f"Warming up ({WARMUP_RUNS} runs)...")
        warmup_start_event.record()
        for _ in range(WARMUP_RUNS):
            res = generate_method(
@@ -149,10 +163,6 @@ def timing_cuda(
tokenizer = AutoTokenizer.from_pretrained(args.model)
tokenizer.pad_token = tokenizer.eos_token

header = "batch_size,compile,prompt_length,new_tokens,cache_length,dtype,tok_per_s,max_mem_mb,hash"
stats = {}


if args.preallocate:
    with device:
        original_model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype)
@@ -200,14 +210,27 @@ def timing_cuda(
if args.compile != "no":
dynamic = args.compile == "dynamic"
fullgraph = args.compile == "fullgraph"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=fullgraph, dynamic=dynamic)
mode = "reduce-overhead" if not args.sweep else "max-autotune"
Owner: why?

Collaborator Author: To see how fast we can get with compile :) Happy to revert if you feel like it won't be the planned use case!
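For context on the exchange above, a minimal illustrative sketch of the two `torch.compile` modes being discussed (the toy function below is an assumption for illustration only; the mode names are the ones used in the change):

```
import torch

def toy_step(x):
    # Stand-in for a model forward pass.
    return torch.relu(x @ x.T).sum()

# "reduce-overhead" focuses on cutting per-call overhead (e.g. via CUDA graphs),
# while "max-autotune" spends extra compilation time searching for faster
# kernels, which can pay off over the many measured runs of a sweep.
low_overhead = torch.compile(toy_step, mode="reduce-overhead")
autotuned = torch.compile(toy_step, mode="max-autotune")
```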
    model.forward = torch.compile(model.forward, mode=mode, fullgraph=fullgraph, dynamic=dynamic)

if model.config.model_type != "llama":
raise ValueError("This script currently only supports LLAMA")

for batch_size in tqdm(BATCH_SIZES):
for prompt_length in tqdm(PROMPT_LENGTHS):
for max_new_tokens in tqdm(NEW_TOKENS):
if args.profile and args.sweep:
raise ValueError("Cannot profile and sweep at the same time")
batch_sizes = BATCH_SIZES if args.sweep != "batch" else SWEEP_BATCH_SIZES
prompt_lengths = PROMPT_LENGTHS if args.sweep != "length" else SWEEP_PROMPT_LENGTHS
new_tokens = NEW_TOKENS if not args.profile else PROFILE_NEW_TOKENS
num_runs = NUM_RUNS
if args.profile:
num_runs = PROFILE_NUM_RUNS
elif args.sweep:
num_runs = SWEEP_NUM_RUNS

stats = {}
for batch_size in tqdm(batch_sizes):
for prompt_length in tqdm(prompt_lengths):
for max_new_tokens in tqdm(new_tokens):
            cache_length = 1 * (prompt_length + max_new_tokens)

            inp = {
@@ -225,19 +248,22 @@ def timing_cuda(
print("Cache preallocation:", args.preallocate)

generate_method = model.generate if not args.preallocate else model.generate_minimal
time_per_generation, max_memory, sha_hash = timing_cuda(
tokenizer=tokenizer,
num_runs=NUM_RUNS,
inputs=inp,
device=device,
max_new_tokens=max_new_tokens,
cache_length=cache_length,
generate_method=generate_method,
preallocate=args.preallocate,
do_profile=args.profile,
)
try:
time_per_generation, max_memory, sha_hash = timing_cuda(
tokenizer=tokenizer,
num_runs=num_runs,
inputs=inp,
device=device,
max_new_tokens=max_new_tokens,
cache_length=cache_length,
generate_method=generate_method,
preallocate=args.preallocate,
do_profile=args.profile,
)
except:
break # in a sweep, might get OOM

tok_per_s = max_new_tokens / time_per_generation
tok_per_s = (max_new_tokens * batch_size) / time_per_generation

stats[(batch_size, prompt_length, max_new_tokens)] = {
"cache_length": cache_length,
@@ -246,18 +272,29 @@ def timing_cuda(
"max_mem": max_memory
}

# print csv
print(header)
# build dataframe with the results and store it
rows = []
repo = git.Repo(search_parent_directories=True)
current_git_hash = repo.git.rev_parse(repo.head, short=True)
for key, value in stats.items():
    batch_size, prompt_length, new_tokens = key
    print(",".join([
        str(batch_size),
        args.compile,
        str(prompt_length),
        str(new_tokens),
        str(value["cache_length"]),
        args.dtype,
        f"{value['tok_per_s']:.3f}",
        f"{value['max_mem']:.2f}",
        value["hash"]])
    )
    rows.append({
        'Preallocate': args.preallocate,
        'Compile': args.compile,
        'Batch size': str(batch_size),
        'Prompt length': str(prompt_length),
        'New tokens': str(new_tokens),
        'Cache length': str(value["cache_length"]),
        'Weights dtype': args.dtype,
        'Tokens per second': f"{value['tok_per_s']:.3f}",
        'Max GPU memory (MB)': f"{value['max_mem']:.2f}",
        'Results hash': value["hash"],
        'Git hash': current_git_hash,
    })
df = pd.DataFrame(rows)
print(df)

os.makedirs("./results", exist_ok=True)
output_path = "./results/results_llama.csv"
df.to_csv(output_path, mode='a', header=not os.path.exists(output_path))
print(f"Results also appended to {output_path}")
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

from setuptools import find_packages, setup

REQUIRED_PKGS = ["torch", "transformers"]
REQUIRED_PKGS = ["torch", "transformers", "gitpython", "seaborn"]

QUALITY_REQUIRE = ["black~=22.0", "flake8>=3.8.3", "isort>=5.0.0", "pyyaml>=5.3.1"]
