Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[-------------------------------------------- attn --------------------------------------------] | reference | cutlass 1 threads: ------------------------------------------------------------------------------------- (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0) | 12.5 | 7.6 (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5) | 15.4 | 9.1 (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0) | 12.6 | 7.7 (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5) | 15.5 | 9.2 (8, 512, 64, 128, torch.float16, None, False, 0.0) | 10.3 | 6.0 (8, 512, 64, 128, torch.float16, None, False, 0.5) | 12.9 | 7.6 (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0) | 45.1 | 29.2 (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5) | 55.7 | 35.2 (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0) | 45.6 | 29.3 (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5) | 56.1 | 35.1 (8, 1024, 64, 128, torch.float16, None, False, 0.0) | 38.7 | 22.7 (8, 1024, 64, 128, torch.float16, None, False, 0.5) | 46.8 | 29.0 Times are in milliseconds (ms). [------------------------------------------ attn-bwd ------------------------------------------] | reference | cutlass 1 threads: ------------------------------------------------------------------------------------- (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.0) | 19.3 | 24.1 (8, 512, 64, 128, torch.float16, (512, 512, 512), False, 0.5) | 19.4 | 24.6 (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.0) | 22.3 | 28.7 (8, 512, 64, 128, torch.float16, (512, 512, 512), True, 0.5) | 22.3 | 29.1 (8, 512, 64, 128, torch.float16, None, False, 0.0) | 19.4 | 22.7 (8, 512, 64, 128, torch.float16, None, False, 0.5) | 19.5 | 23.4 (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.0) | 63.8 | 91.3 (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), False, 0.5) | 63.9 | 94.1 (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.0) | 75.4 | 109.7 (8, 1024, 64, 128, torch.float16, (512, 1024, 1024), True, 0.5) | 75.6 | 111.2 (8, 1024, 64, 128, torch.float16, None, False, 0.0) | 63.9 | 85.8 (8, 1024, 64, 128, torch.float16, None, False, 0.5) | 64.3 | 90.3
- Loading branch information