diff --git a/benchmarks/benchmark_mqar.py b/benchmarks/benchmark_mqar.py
deleted file mode 100644
index 80fcc89..0000000
--- a/benchmarks/benchmark_mqar.py
+++ /dev/null
@@ -1,62 +0,0 @@
-from zoology.config import TrainConfig, ModelConfig, DataConfig, ModuleConfig
-from zoology.data.associative_recall import MQARConfig
-
-
-seq_len = 2048
-d_model = 256 # 32, 64, 128, 256
-vocab_size = seq_len + 1
-num_kv_pairs = 512
-
-if d_model == 32:
-    learning_rate = 4e-4
-elif d_model == 64:
-    learning_rate = 3e-4
-elif d_model == 128:
-    learning_rate = 2e-4
-elif d_model == 256:
-    learning_rate = 1e-4
-
-if seq_len == 1024:
-    batch_size = 64
-elif seq_len == 2048:
-    batch_size = 32
-elif seq_len == 4096:
-    batch_size = 16
-elif seq_len == 8192:
-    batch_size = 8
-
-config = TrainConfig(
-    learning_rate=learning_rate,
-    data=DataConfig(
-        cache_dir=".cache",
-        batch_size=batch_size,
-        train_configs=[
-            MQARConfig(
-                num_examples=250_000,
-                vocab_size=vocab_size,
-                input_seq_len=seq_len,
-                num_kv_pairs=num_kv_pairs,
-            )
-        ],
-        test_configs=[
-            MQARConfig(
-                num_examples=1_000,
-                vocab_size=vocab_size,
-                input_seq_len=seq_len,
-                num_kv_pairs=num_kv_pairs,
-            )
-        ]
-    ),
-    model=ModelConfig(
-        vocab_size=vocab_size,
-        d_model=d_model,
-        max_position_embeddings=seq_len,
-        sequence_mixer=ModuleConfig(
-            name="zoology.mixers.dma.DynamicMaskAttention",
-            kwargs={"keep_window_size": num_kv_pairs, "num_heads": 1},
-        )
-    ),
-
-)
-
-configs = [config]
\ No newline at end of file
diff --git a/benchmarks/benchmark_forward_equivalence.py b/benchmarks/forward_equivalence.py
similarity index 97%
rename from benchmarks/benchmark_forward_equivalence.py
rename to benchmarks/forward_equivalence.py
index fac85d2..f55c33b 100644
--- a/benchmarks/benchmark_forward_equivalence.py
+++ b/benchmarks/forward_equivalence.py
@@ -518,27 +518,27 @@ def test_cuda_forward_equivalence(accuracy_threshold=0.95):
     # If you encounter NAN issues when running multiple configurations, try running a single configuration
     test_configs = [
         # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
-        (1, 1, 1, 64, 64, 32, True),
-        (1, 1, 1, 64, 64, 32, False),
-        (1, 1, 1, 128, 128, 32, True),
-        (1, 1, 1, 128, 128, 32, False),
-        (1, 1, 1, 256, 256, 32, True),
-        (1, 1, 1, 256, 256, 32, False),
-        (1, 1, 1, 512, 512, 32, True),
-        (1, 1, 1, 512, 512, 32, False),
-        (1, 1, 1, 1024, 1024, 32, True),
-        (1, 1, 1, 1024, 1024, 32, False),
-        (1, 1, 1, 2048, 2048, 32, True),
-        (1, 1, 1, 2048, 2048, 32, False),
+        # (1, 1, 1, 64, 64, 32, True),
+        # (1, 1, 1, 64, 64, 32, False),
+        # (1, 1, 1, 128, 128, 32, True),
+        # (1, 1, 1, 128, 128, 32, False),
+        # (1, 1, 1, 256, 256, 32, True),
+        # (1, 1, 1, 256, 256, 32, False),
+        # (1, 1, 1, 512, 512, 32, True),
+        # (1, 1, 1, 512, 512, 32, False),
+        # (1, 1, 1, 1024, 1024, 32, True),
+        # (1, 1, 1, 1024, 1024, 32, False),
+        # (1, 1, 1, 2048, 2048, 32, True),
+        # (1, 1, 1, 2048, 2048, 32, False),
         (1, 1, 1, 4096, 4096, 32, True),
-        (1, 1, 1, 4096, 4096, 32, False),
-        (1, 2, 1, 64, 64, 32, True),
-        (2, 1, 1, 128, 128, 32, True),
-        (2, 2, 1, 128, 128, 32, True),
-        (1, 2, 1, 64, 64, 128, True),
-        (1, 2, 1, 128, 128, 128, True),
-        (1, 2, 1, 256, 256, 128, True),
-        (1, 2, 1, 512, 512, 128, True),
+        # (1, 1, 1, 4096, 4096, 32, False),
+        # (1, 2, 1, 64, 64, 32, True),
+        # (2, 1, 1, 128, 128, 32, True),
+        # (2, 2, 1, 128, 128, 32, True),
+        # (1, 2, 1, 64, 64, 128, True),
+        # (1, 2, 1, 128, 128, 128, True),
+        # (1, 2, 1, 256, 256, 128, True),
+        # (1, 2, 1, 512, 512, 128, True),
     ]

     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -1050,13 +1050,13 @@ def main():
         print("\n" + "šŸ“" + " Starting Standard Forward Pass Tests " + "šŸ“")
         test_results['cuda'] = test_cuda_forward_equivalence(args.accuracy_threshold)

-    if args.test_type in ['all', 'triton']:
-        print("\n" + "šŸ”„" + " Starting Python vs Triton Tests " + "šŸ”„")
-        test_results['triton'] = test_triton_forward_equivalence(args.accuracy_threshold)
+    # if args.test_type in ['all', 'triton']:
+    #     print("\n" + "šŸ”„" + " Starting Python vs Triton Tests " + "šŸ”„")
+    #     test_results['triton'] = test_triton_forward_equivalence(args.accuracy_threshold)

-    if args.test_type in ['all', 'flex']:
-        print("\n" + "🌟" + " Starting Python vs Flex Attention Tests " + "🌟")
-        test_results['flex'] = test_flex_forward_equivalence(args.accuracy_threshold)
+    # if args.test_type in ['all', 'flex']:
+    #     print("\n" + "🌟" + " Starting Python vs Flex Attention Tests " + "🌟")
+    #     test_results['flex'] = test_flex_forward_equivalence(args.accuracy_threshold)

     # Print overall summary

diff --git a/benchmarks/benchmark_forward_performance.py b/benchmarks/forward_performance.py
similarity index 94%
rename from benchmarks/benchmark_forward_performance.py
rename to benchmarks/forward_performance.py
index 0c1c042..7511946 100644
--- a/benchmarks/benchmark_forward_performance.py
+++ b/benchmarks/forward_performance.py
@@ -732,57 +732,57 @@ def run_performance_benchmark(test_type='all', num_runs=3, warmup_runs=2):
         (1, 2, 1, 4096, 4096, 128, 2048, True),
         (1, 2, 1, 8192, 8192, 128, 2048, True),
         (1, 2, 1, 16384, 16384, 128, 2048, True),
-        (1, 2, 1, 32768, 32768, 128, 2048, True),
-
-        # Inference
-        (1, 2, 1, 2, 256, 128, 2048, True),
-        (1, 2, 1, 2, 512, 128, 2048, True),
-        (1, 2, 1, 2, 1024, 128, 2048, True),
-        (1, 2, 1, 2, 2048, 128, 2048, True),
-        (1, 2, 1, 2, 4096, 128, 2048, True),
-        (1, 2, 1, 2, 8192, 128, 2048, True),
-        (1, 2, 1, 2, 16384, 128, 2048, True),
-        (1, 2, 1, 2, 32768, 128, 2048, True),
+        # (1, 2, 1, 32768, 32768, 128, 2048, True),
+
+        # # Inference
+        # (1, 2, 1, 2, 256, 128, 2048, True),
+        # (1, 2, 1, 2, 512, 128, 2048, True),
+        # (1, 2, 1, 2, 1024, 128, 2048, True),
+        # (1, 2, 1, 2, 2048, 128, 2048, True),
+        # (1, 2, 1, 2, 4096, 128, 2048, True),
+        # (1, 2, 1, 2, 8192, 128, 2048, True),
+        # (1, 2, 1, 2, 16384, 128, 2048, True),
+        # (1, 2, 1, 2, 32768, 128, 2048, True),
         (1, 2, 1, 2, 65536, 128, 2048, True),
-        (1, 2, 1, 2, 131072, 128, 2048, True),
-        (1, 2, 1, 2, 262144, 128, 2048, True),
-        (1, 2, 1, 2, 524288, 128, 2048, True),
-
-        # Vary batch size
-        (1, 2, 1, 4096, 4096, 32, 2048, True),
-        (2, 2, 1, 4096, 4096, 32, 2048, True),
-        (4, 2, 1, 4096, 4096, 32, 2048, True),
-        (8, 2, 1, 4096, 4096, 32, 2048, True),
-
-        # Vary head count
-        (1, 1, 1, 4096, 4096, 32, 2048, True),
-        (1, 2, 1, 4096, 4096, 32, 2048, True),
-        (1, 4, 1, 4096, 4096, 32, 2048, True),
-        (1, 8, 2, 4096, 4096, 32, 2048, True),
-
-        # Vary head dimension
-        (1, 2, 1, 4096, 4096, 32, 2048, True),
-        (1, 2, 1, 4096, 4096, 64, 2048, True),
-        (1, 2, 1, 4096, 4096, 96, 2048, True),
-        (1, 2, 1, 4096, 4096, 128, 2048, True),
-        (1, 2, 1, 4096, 4096, 192, 2048, True),
-        (1, 2, 1, 4096, 4096, 256, 2048, True),
-
-        # Vary keep_window_size
-        (1, 2, 1, 32768, 32768, 128, 32, True),
-        (1, 2, 1, 32768, 32768, 128, 64, True),
-        (1, 2, 1, 32768, 32768, 128, 128, True),
-        (1, 2, 1, 32768, 32768, 128, 256, True),
-        (1, 2, 1, 32768, 32768, 128, 512, True),
-        (1, 2, 1, 32768, 32768, 128, 1024, True),
-        (1, 2, 1, 32768, 32768, 128, 2048, True),
-        (1, 2, 1, 32768, 32768, 128, 4096, True),
-        (1, 2, 1, 32768, 32768, 128, 8192, True),
-        (1, 2, 1, 32768, 32768, 128, 16384, True),
-        (1, 2, 1, 32768, 32768, 128, 32768, True),
-
-        # Test non-causal
-        (1, 2, 1, 4096, 4096, 128, 2048, False),
+        # (1, 2, 1, 2, 131072, 128, 2048, True),
+        # (1, 2, 1, 2, 262144, 128, 2048, True),
+        # (1, 2, 1, 2, 524288, 128, 2048, True),
+
+        # # Vary batch size
+        # (1, 2, 1, 4096, 4096, 32, 2048, True),
+        # (2, 2, 1, 4096, 4096, 32, 2048, True),
+        # (4, 2, 1, 4096, 4096, 32, 2048, True),
+        # (8, 2, 1, 4096, 4096, 32, 2048, True),
+
+        # # Vary head count
+        # (1, 1, 1, 4096, 4096, 32, 2048, True),
+        # (1, 2, 1, 4096, 4096, 32, 2048, True),
+        # (1, 4, 1, 4096, 4096, 32, 2048, True),
+        # (1, 8, 2, 4096, 4096, 32, 2048, True),
+
+        # # Vary head dimension
+        # (1, 2, 1, 4096, 4096, 32, 2048, True),
+        # (1, 2, 1, 4096, 4096, 64, 2048, True),
+        # (1, 2, 1, 4096, 4096, 96, 2048, True),
+        # (1, 2, 1, 4096, 4096, 128, 2048, True),
+        # (1, 2, 1, 4096, 4096, 192, 2048, True),
+        # (1, 2, 1, 4096, 4096, 256, 2048, True),
+
+        # # Vary keep_window_size
+        # (1, 2, 1, 32768, 32768, 128, 32, True),
+        # (1, 2, 1, 32768, 32768, 128, 64, True),
+        # (1, 2, 1, 32768, 32768, 128, 128, True),
+        # (1, 2, 1, 32768, 32768, 128, 256, True),
+        # (1, 2, 1, 32768, 32768, 128, 512, True),
+        # (1, 2, 1, 32768, 32768, 128, 1024, True),
+        # (1, 2, 1, 32768, 32768, 128, 2048, True),
+        # (1, 2, 1, 32768, 32768, 128, 4096, True),
+        # (1, 2, 1, 32768, 32768, 128, 8192, True),
+        # (1, 2, 1, 32768, 32768, 128, 16384, True),
+        # (1, 2, 1, 32768, 32768, 128, 32768, True),
+
+        # # Test non-causal
+        # (1, 2, 1, 4096, 4096, 128, 2048, False),
     ]

     print(f"\nšŸ“Š Benchmark Results (averaged over {num_runs} runs):")
diff --git a/benchmarks/benchmark_grad.py b/benchmarks/grad_equivalence.py
similarity index 100%
rename from benchmarks/benchmark_grad.py
rename to benchmarks/grad_equivalence.py
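
For context, a minimal sketch of the comparison these equivalence scripts perform. The (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal) tuple layout and the accuracy_threshold parameter are taken from the diff above; `reference_forward` and `cuda_forward` are hypothetical stand-ins for the actual Python and CUDA implementations, and the tolerances are assumptions:

```python
import torch

def check_equivalence(reference_forward, cuda_forward, accuracy_threshold=0.95):
    # The one config left enabled in the diff above:
    # (batch_size, num_heads, num_kv_heads, query_len, key_len, head_dim, is_causal)
    batch, heads, kv_heads, q_len, k_len, head_dim, is_causal = (1, 1, 1, 4096, 4096, 32, True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Random inputs; kv_heads may differ from heads under grouped-query attention.
    q = torch.randn(batch, heads, q_len, head_dim, device=device)
    k = torch.randn(batch, kv_heads, k_len, head_dim, device=device)
    v = torch.randn(batch, kv_heads, k_len, head_dim, device=device)

    ref = reference_forward(q, k, v, is_causal=is_causal)
    out = cuda_forward(q, k, v, is_causal=is_causal)

    # Score = fraction of elements that agree within tolerance; the test
    # passes when it clears the accuracy_threshold argument seen above.
    matched = torch.isclose(out, ref, rtol=1e-3, atol=1e-3).float().mean().item()
    return matched >= accuracy_threshold
```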