diff --git a/.gitignore b/.gitignore
index 0798d2a..7344cbf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,7 +1,8 @@
 # Initially taken from Github's Python gitignore file
 
-# TB logs
+# Data created by runs (not to be tracked)
 tb_logs/
+results/
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/README.md b/README.md
index 4aafe31..a15fa37 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,9 @@ You can profile a short run with `--profile`, with the TB logs being stored in `
 python run_llama.py --model huggingface/llama-7b --preallocate --profile
 ```
 
-gives, with `batch_size=1`, `prompt_length=1000`, `new_tokens=200`, `cache_length=1200`, `dtype=fp16`:
+## Results
+
+Running the command above with `batch_size=1`, `prompt_length=1000`, `new_tokens=200`, `cache_length=1200`, `dtype=fp16`:
 
 | changes | compile | tok_per_s | max_mem_mb | hash | commit |
 |-------------------------------------------------------------|---------|-----------|------------|----------|------------------------------------------|
@@ -56,3 +58,19 @@ BATCH_SIZES = [1, 2, 4, 8]
 PROMPT_LENGTHS = [500, 1000, 4000]
 NEW_TOKENS = [1000]
 ```
+
+## Predefined sweeps
+
+You can sweep over predefined configurations of batch sizes (for a fixed prompt length) and prompt lengths (for a
+fixed batch size) with the `--sweep` flag, e.g.
+
+```
+python scripts/run_llama.py --model huggingface/llama-7b --sweep batch
+```
+
+If you run the sweep for each of the generation alternatives (original code, with preallocated tensors, and
+preallocated + compiled), you can easily compare the results with
+
+```
+python scripts/plot_results.py --sweep batch
+```
diff --git a/scripts/plot_results.py b/scripts/plot_results.py
new file mode 100644
index 0000000..854c680
--- /dev/null
+++ b/scripts/plot_results.py
@@ -0,0 +1,79 @@
+"""
+Plots the results of a sweep for the current git hash.
+"""
+import argparse
+
+import git
+import matplotlib.pyplot as plt
+import seaborn as sns
+import pandas as pd
+
+
+DEFAULT_BATCH_SIZE = 1
+DEFAULT_PROMPT_LENGTH = 1000
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    "--sweep",
+    type=str,
+    choices=["batch", "length"],
+    required=True,
+    help="Select which type of sweep to plot"
+)
+args = parser.parse_args()
+
+# 1. Read file and retrieve relevant data
+results_file = "./results/results_llama.csv"
+df = pd.read_csv(results_file)
+
+repo = git.Repo(search_parent_directories=True)
+current_git_hash = repo.git.rev_parse(repo.head, short=True)
+df = df[df["Git hash"] == current_git_hash]
+if df.empty:
+    raise ValueError(f"No results found for current git hash ({current_git_hash})")
+
+if args.sweep == "batch":
+    df = df[df["Prompt length"] == DEFAULT_PROMPT_LENGTH]
+else:
+    df = df[df["Batch size"] == DEFAULT_BATCH_SIZE]
+    df = df[df["Prompt length"] != DEFAULT_PROMPT_LENGTH]
+
+if df.empty:
+    raise ValueError("Something went wrong -- no results in the filtered dataframe")
+
+# 2. Plot -- we expect 3 series: original model, preallocated, and preallocated + compiled
+if args.sweep == "batch":
+    x_col_name = "Batch size"
+else:
+    x_col_name = "Prompt length"
+
+df["Type"] = df["Preallocate"].astype("str") + df["Compile"]
+df["Type"] = df["Type"].replace({"Falseno": "original", "Trueno": "Preallocate", "Truestatic": "Pre + comp."})
+
+g = sns.catplot(
+    data=df,
+    kind="bar",
+    x=x_col_name,
+    y="Tokens per second",
+    hue="Type",
+    palette={"original": "blue", "Preallocate": "orange", "Pre + comp.": "red"},
+    alpha=.9,
+)
+g.despine(left=True)
+g.set_axis_labels("Batch size" if args.sweep == "batch" else "Prompt length", "Tokens per second")
+g.legend.set_title("LLaMA code version")
+plt.setp(g._legend.get_texts(), fontsize='7')  # for legend text
+
+title_constant = f'{"Batch size = " + str(DEFAULT_BATCH_SIZE) if args.sweep == "length" else "Prompt length = " + str(DEFAULT_PROMPT_LENGTH)}'
+g.set(title=f'LLaMA sweep ({title_constant})')
+
+# Add the number to the top of each bar
+ax = g.facet_axis(0, 0)
+for i in ax.containers:
+    ax.bar_label(i, fontsize=7)
+
+g.tight_layout()
+plt_path = f"./results/llama_sweep_{current_git_hash}_{args.sweep}.png"
+plt.savefig(plt_path, dpi=300)
+print(f"Plot stored at {plt_path}")
diff --git a/scripts/run_llama.py b/scripts/run_llama.py
index 8427df5..5664c97 100644
--- a/scripts/run_llama.py
+++ b/scripts/run_llama.py
@@ -2,11 +2,14 @@
 import copy
 import contextlib
 import hashlib
+import os
 from typing import Dict
 
 from tqdm import tqdm
 
+import pandas as pd
 import torch
-from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
+import git
+from torch.profiler import ProfilerActivity, profile, tensorboard_trace_handler
 
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -14,15 +17,22 @@
 from trfs_fast.utils import recurse_getattr, recurse_hasattr, recurse_setattr, recurse_delattr
 
 
+# Default case
 BATCH_SIZES = [1]
 PROMPT_LENGTHS = [1000]
 NEW_TOKENS = [200]
 WARMUP_RUNS = 2
 NUM_RUNS = 5
 
+# Modifiers for profiling (we want a short run)
 PROFILE_NEW_TOKENS = 10
 PROFILE_NUM_RUNS = 1
 
+# Modifiers for parameter sweeps
+SWEEP_BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64, 128]
+SWEEP_PROMPT_LENGTHS = [100, 200, 400, 800, 1600]
+SWEEP_NUM_RUNS = 10
+
 
 parser = argparse.ArgumentParser()
 # TODO: support other archs than llama
@@ -55,6 +65,14 @@
     default="no",
     help="If (and how) to compile the model forward pass with torch.compile",
 )
+parser.add_argument(
+    "--sweep",
+    type=str,
+    choices=["", "batch", "length"],
+    required=False,
+    default="",
+    help="Select which type of sweep to gather data for",
+)
 
 
 def timing_cuda(
@@ -71,15 +89,11 @@ def timing_cuda(
     warmup_start_event = torch.cuda.Event(enable_timing=True)
     warmup_end_event = torch.cuda.Event(enable_timing=True)
 
-    if do_profile:
-        num_runs = PROFILE_NUM_RUNS
-        max_new_tokens = PROFILE_NEW_TOKENS
-
     if preallocate:
         inputs["cache_length"] = cache_length
 
     with torch.no_grad():
-        print("Warming up...")
+        print(f"Warming up ({WARMUP_RUNS} runs)...")
         warmup_start_event.record()
         for _ in range(WARMUP_RUNS):
             res = generate_method(
@@ -149,10 +163,6 @@ def timing_cuda(
 tokenizer = AutoTokenizer.from_pretrained(args.model)
 tokenizer.pad_token = tokenizer.eos_token
 
-header = "batch_size,compile,prompt_length,new_tokens,cache_length,dtype,tok_per_s,max_mem_mb,hash"
-stats = {}
-
-
 if args.preallocate:
     with device:
         original_model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=dtype)
@@ -200,14 +210,27 @@ def timing_cuda(
 if args.compile != "no":
args.compile == "dynamic" fullgraph = args.compile == "fullgraph" - model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=fullgraph, dynamic=dynamic) + mode = "reduce-overhead" if not args.sweep else "max-autotune" + model.forward = torch.compile(model.forward, mode=mode, fullgraph=fullgraph, dynamic=dynamic) if model.config.model_type != "llama": raise ValueError("This script currently only supports LLAMA") -for batch_size in tqdm(BATCH_SIZES): - for prompt_length in tqdm(PROMPT_LENGTHS): - for max_new_tokens in tqdm(NEW_TOKENS): +if args.profile and args.sweep: + raise ValueError("Cannot profile and sweep at the same time") +batch_sizes = BATCH_SIZES if args.sweep != "batch" else SWEEP_BATCH_SIZES +prompt_lengths = PROMPT_LENGTHS if args.sweep != "length" else SWEEP_PROMPT_LENGTHS +new_tokens = NEW_TOKENS if not args.profile else PROFILE_NEW_TOKENS +num_runs = NUM_RUNS +if args.profile: + num_runs = PROFILE_NUM_RUNS +elif args.sweep: + num_runs = SWEEP_NUM_RUNS + +stats = {} +for batch_size in tqdm(batch_sizes): + for prompt_length in tqdm(prompt_lengths): + for max_new_tokens in tqdm(new_tokens): cache_length = 1 * (prompt_length + max_new_tokens) inp = { @@ -225,19 +248,22 @@ def timing_cuda( print("Cache preallocation:", args.preallocate) generate_method = model.generate if not args.preallocate else model.generate_minimal - time_per_generation, max_memory, sha_hash = timing_cuda( - tokenizer=tokenizer, - num_runs=NUM_RUNS, - inputs=inp, - device=device, - max_new_tokens=max_new_tokens, - cache_length=cache_length, - generate_method=generate_method, - preallocate=args.preallocate, - do_profile=args.profile, - ) + try: + time_per_generation, max_memory, sha_hash = timing_cuda( + tokenizer=tokenizer, + num_runs=num_runs, + inputs=inp, + device=device, + max_new_tokens=max_new_tokens, + cache_length=cache_length, + generate_method=generate_method, + preallocate=args.preallocate, + do_profile=args.profile, + ) + except: + break # in a sweep, might get OOM - tok_per_s = max_new_tokens / time_per_generation + tok_per_s = (max_new_tokens * batch_size) / time_per_generation stats[(batch_size, prompt_length, max_new_tokens)] = { "cache_length": cache_length, @@ -246,18 +272,29 @@ def timing_cuda( "max_mem": max_memory } -# print csv -print(header) +# build dataframe with the results and store it +rows = [] +repo = git.Repo(search_parent_directories=True) +current_git_hash = repo.git.rev_parse(repo.head, short=True) for key, value in stats.items(): batch_size, prompt_length, new_tokens = key - print(",".join([ - str(batch_size), - args.compile, - str(prompt_length), - str(new_tokens), - str(value["cache_length"]), - args.dtype, - f"{value['tok_per_s']:.3f}", - f"{value['max_mem']:.2f}", - value["hash"]]) - ) + rows.append({ + 'Preallocate': args.preallocate, + 'Compile': args.compile, + 'Batch size': str(batch_size), + 'Prompt length': str(prompt_length), + 'New tokens': str(new_tokens), + 'Cache length': str(value["cache_length"]), + 'Weights dtype': args.dtype, + 'Tokens per second': f"{value['tok_per_s']:.3f}", + 'Max GPU memory (MB)': f"{value['max_mem']:.2f}", + 'Results hash': value["hash"], + 'Git hash': current_git_hash, + }) +df = pd.DataFrame(rows) +print(df) + +os.makedirs("./results", exist_ok=True) +output_path = "./results/results_llama.csv" +df.to_csv(output_path, mode='a', header=not os.path.exists(output_path)) +print(f"Results also appended to {output_path}") diff --git a/setup.py b/setup.py index 90198da..7abd91d 100644 --- a/setup.py +++ b/setup.py @@ 
@@ -2,7 +2,7 @@ from setuptools import find_packages, setup
 
 
-REQUIRED_PKGS = ["torch", "transformers"]
+REQUIRED_PKGS = ["torch", "transformers", "gitpython", "seaborn"]
 
 QUALITY_REQUIRE = ["black~=22.0", "flake8>=3.8.3", "isort>=5.0.0", "pyyaml>=5.3.1"]