Write a sharded transformer block in nvFuser API. #2199

Closed
wujingyue opened this issue May 3, 2024 · 21 comments · Fixed by #3232

Comments
@wujingyue
Collaborator

This is to unblock @cowanmeg's and @samnordmann's distributed matmul experiments.

I'll start with the tensor parallelism proposed by the original Megatron-LM paper.

  1. Only MHA and MLP are sharded.
  2. Activations are sharded in 2D, along the batch and hidden dimensions. However, the batch-dimension sharding is just for data parallelism, and that dimension is never resharded.
  3. Weights are sharded in 1D, along the hidden dimension.
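
For illustration, here is a minimal single-process sketch of the MLP math under this scheme (tensor names and sizes are made up; each loop iteration stands in for one device):

```python
import torch
import torch.nn.functional as F

# Megatron-style MLP: Y = GeLU(X @ A) @ B. A is split column-wise and B
# row-wise along the hidden dimension, so each "device" produces a partial
# output and a single all-reduce (the sum below) recovers Y exactly.
world, batch, d_model, d_ff = 2, 8, 16, 64
X = torch.randn(batch, d_model)
A = torch.randn(d_model, d_ff)
B = torch.randn(d_ff, d_model)

reference = F.gelu(X @ A) @ B
partials = [
    F.gelu(X @ A_shard) @ B_shard
    for A_shard, B_shard in zip(A.chunk(world, dim=1), B.chunk(world, dim=0))
]
assert torch.allclose(sum(partials), reference, atol=1e-4)
```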
@wujingyue wujingyue self-assigned this May 3, 2024
@wujingyue
Collaborator Author

Note to myself: I'll first try to get a single-device nvFuser python definition from Thunder, and then we can manually shard it using nvFuser's API.

@Priya2698 pointed me to the nv_enable_linear flag (https://github.com/Lightning-AI/lightning-thunder/blob/90a0f4c0d0a90d1e94684a847f3adfe2230985b4/thunder/tests/test_nvfuser.py#L875) that I'll need to turn on to enable prims.linear via nvFuser. I'll probably need to set nv_enable_bookend=False as well.
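
If I understand the compile-option plumbing correctly, these would be passed as keyword arguments to thunder.jit (a sketch; `model` is a placeholder for whatever is being traced):

```python
import thunder

# Sketch: compile options are looked up by name via get_compile_option, so
# they should be forwardable as thunder.jit keyword arguments.
jitted = thunder.jit(
    model,                    # the module/function being traced (placeholder)
    nv_enable_linear=True,    # let the nvFuser executor claim prims.linear
    nv_enable_bookend=False,  # disable the bookending heuristic
)
```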

@wujingyue
Collaborator Author

Note to myself: I'll start with the following benchmark

$ pytest thunder/benchmarks/targets.py -k test_nanogpt_block_fwd[thunder] -s

which exercises one transformer layer in nanoGPT:
https://github.com/Lightning-AI/lightning-thunder/blob/cab020881765594fd9552d4deb8cc4e0f64410d2/thunder/tests/nanogpt_model.py#L132-L143
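
For context, that layer has the standard pre-norm structure (a paraphrased sketch of the linked code, not a verbatim copy):

```python
import torch.nn as nn

class Block(nn.Module):
    """One nanoGPT transformer layer: pre-norm attention and MLP, each with a residual."""
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)  # defined in the linked file
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)                   # fc -> GELU -> proj (+ dropout)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
```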

@wujingyue
Collaborator Author

wujingyue commented May 4, 2024

cc @Priya2698

a.ndim==2 is the first check that failed. Here's how you can reproduce the problem:

With the following patch

diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index c955da06..4767ab9c 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -2201,6 +2201,7 @@ def _linear_check(a: TensorProxy, b: TensorProxy, bias: TensorProxy | None) -> b
         return False

     enable_linear: None | bool = get_compile_option("nv_enable_linear", "Enable nvFuser matmul.")
+    enable_linear = True
     if not enable_linear:
         return False
     # Verify linear inputs and bias (optional) are supported tensors.
@@ -2210,6 +2211,7 @@ def _linear_check(a: TensorProxy, b: TensorProxy, bias: TensorProxy | None) -> b
         return False

     # nvFuser only supports 2D inputs in v0.2.3.
+    import pdb; pdb.set_trace()
     if not a.ndim == 2:
         return False
     return True
$ NVFUSER_DUMP=python_definition pytest thunder/benchmarks/targets.py -k test_nanogpt_block_fwd[thunder] -s
=========================== test session starts ===========================
platform linux -- Python 3.10.12, pytest-8.1.1, pluggy-1.4.0
Test order randomisation NOT enabled. Enable with --random-order or --random-order-bucket=<bucket_type>
benchmark: 4.0.0 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /opt/pytorch/lightning-thunder
configfile: pyproject.toml
plugins: timestamper-0.0.10, xdist-3.5.0, random-order-1.1.1, cov-4.1.0, benchmark-4.0.0, hypothesis-6.100.0, timeout-2.2.0, shard-0.1.2
timeout: 900.0s
timeout method: signal
timeout func_only: False
collected 162 items / 161 deselected / 1 selected
Running 1 items in this shard

thunder/benchmarks/targets.py
>>>>>>>>>>>>>>>>>>>>>>>>>>> PDB set_trace >>>>>>>>>>>>>>>>>>>>>>>>>>>
> /opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py(2215)_linear_check()
-> if not a.ndim == 2:
(Pdb) p a.ndim
3
(Pdb)

Unsurprisingly, the printed Python definition consists of five fusions, none of which contain a matmul or linear.

@wujingyue
Collaborator Author

wujingyue commented May 8, 2024

Below is a WAR for the above Thunder check, but it runs into an nvFuser issue.

diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index c955da06..137da102 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -4,6 +4,7 @@ from numbers import Number
 from typing import Union, List, Any, Optional, Dict, Set, Tuple, Type
 from types import NoneType
 from collections.abc import Callable, Mapping, Hashable, Sequence
+import math
 import os
 import time
 from copy import copy
@@ -796,7 +797,7 @@ instantiated) this heuristic actually leads to worse code.
             enable_bookend: None | bool = get_compile_option("nv_enable_bookend", bookend_help)
             # Set default value.
             if enable_bookend is None:
-                enable_bookend = True
+                enable_bookend = False
             assert isinstance(enable_bookend, bool)

             if enable_bookend:
@@ -2200,7 +2201,7 @@ def _linear_check(a: TensorProxy, b: TensorProxy, bias: TensorProxy | None) -> b
     if nv_version < LooseVersion("0.2.3"):
         return False

-    enable_linear: None | bool = get_compile_option("nv_enable_linear", "Enable nvFuser matmul.")
+    enable_linear = True
     if not enable_linear:
         return False
     # Verify linear inputs and bias (optional) are supported tensors.
@@ -2209,8 +2210,11 @@ def _linear_check(a: TensorProxy, b: TensorProxy, bias: TensorProxy | None) -> b
     if bias is not None and not is_supported_tensor(bias):
         return False

-    # nvFuser only supports 2D inputs in v0.2.3.
-    if not a.ndim == 2:
+    if a.ndim < 2:
+        return False
+    if b.ndim != 2:
+        return False
+    if bias.ndim != 1:
         return False
     return True

@@ -2226,7 +2230,10 @@ def linear(
     nva = getnv(a, fd, lc_to_nv_map)
     nvb = getnv(b, fd, lc_to_nv_map)
     nvbias = None if bias is None else getnv(bias, fd, lc_to_nv_map)
-    return fd.ops.linear(nva, nvb, nvbias)
+
+    nva_2d = fd.ops.reshape(nva, (math.prod(a.shape[:-1]), a.shape[-1]))
+    nvc_2d = fd.ops.linear(nva_2d, nvb, nvbias)
+    return fd.ops.reshape(nvc_2d, a.shape[:-1] + (b.shape[-2],))


 register_supported(PrimIDs.LINEAR, linear, _linear_check)
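
The WAR relies on linear over a 3D input being equivalent to flattening the batch dimensions, doing a 2D linear, and restoring the shape. A quick PyTorch sanity check of that equivalence (shapes match the benchmark's first linear):

```python
import torch
import torch.nn.functional as F

a = torch.randn(16, 128, 1600)  # [batch, seq, hidden]
w = torch.randn(4800, 1600)     # [out_features, in_features]
b = torch.randn(4800)

out_3d = F.linear(a, w, b)
out_2d = F.linear(a.reshape(-1, 1600), w, b).reshape(16, 128, 4800)
assert torch.allclose(out_3d, out_2d)  # same GEMM, just reshaped
```

With the patch applied, the benchmark emits the standalone reproducer below, which fails in nvFuser's segmenter: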
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T3 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T4 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T5 = fd.ops.cast(T4, dtype=DataType.Float)
    T6, T7 = fd.ops.var_mean(T5, dims=[2], correction=0, keepdim=False)
    S8 = fd.define_scalar(16, dtype=DataType.Int)
    S9 = fd.define_scalar(128, dtype=DataType.Int)
    S10 = fd.define_scalar(1, dtype=DataType.Int)
    V11 = fd.define_vector([S8, S9, S10], dtype=DataType.Int)
    T12 = fd.ops.broadcast_in_dim(T6, shape=V11, broadcast_dims=[0, 1])
    S13 = fd.define_scalar(16, dtype=DataType.Int)
    S14 = fd.define_scalar(128, dtype=DataType.Int)
    S15 = fd.define_scalar(1, dtype=DataType.Int)
    V16 = fd.define_vector([S13, S14, S15], dtype=DataType.Int)
    T17 = fd.ops.broadcast_in_dim(T7, shape=V16, broadcast_dims=[0, 1])
    S18 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
    T19 = fd.ops.add(T12, S18)
    T20 = fd.ops.rsqrt(T19)
    S21 = fd.define_scalar(16, dtype=DataType.Int)
    S22 = fd.define_scalar(128, dtype=DataType.Int)
    S23 = fd.define_scalar(1600, dtype=DataType.Int)
    V24 = fd.define_vector([S21, S22, S23], dtype=DataType.Int)
    T25 = fd.ops.broadcast_in_dim(T17, shape=V24, broadcast_dims=[0, 1, 2])
    T26 = fd.ops.sub(T5, T25)
    S27 = fd.define_scalar(16, dtype=DataType.Int)
    S28 = fd.define_scalar(128, dtype=DataType.Int)
    S29 = fd.define_scalar(1600, dtype=DataType.Int)
    V30 = fd.define_vector([S27, S28, S29], dtype=DataType.Int)
    T31 = fd.ops.broadcast_in_dim(T20, shape=V30, broadcast_dims=[0, 1, 2])
    T32 = fd.ops.mul(T26, T31)
    S33 = fd.define_scalar(16, dtype=DataType.Int)
    S34 = fd.define_scalar(128, dtype=DataType.Int)
    S35 = fd.define_scalar(1600, dtype=DataType.Int)
    V36 = fd.define_vector([S33, S34, S35], dtype=DataType.Int)
    T37 = fd.ops.broadcast_in_dim(T3, shape=V36, broadcast_dims=[2])
    T38 = fd.ops.cast(T37, dtype=DataType.Float)
    T39 = fd.ops.mul(T32, T38)
    S40 = fd.define_scalar(16, dtype=DataType.Int)
    S41 = fd.define_scalar(128, dtype=DataType.Int)
    S42 = fd.define_scalar(1600, dtype=DataType.Int)
    V43 = fd.define_vector([S40, S41, S42], dtype=DataType.Int)
    T44 = fd.ops.broadcast_in_dim(T2, shape=V43, broadcast_dims=[2])
    T45 = fd.ops.cast(T44, dtype=DataType.Float)
    T46 = fd.ops.add(T39, T45)
    T47 = fd.ops.cast(T46, dtype=DataType.BFloat16)
    S48 = fd.define_scalar(2048, dtype=DataType.Int)
    S49 = fd.define_scalar(1600, dtype=DataType.Int)
    V50 = fd.define_vector([S48, S49], dtype=DataType.Int)
    T51 = fd.ops.reshape(T47, new_shape=V50)
    T52 = fd.ops.linear(T51, T1, T0)
    fd.add_output(T52)

with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)

inputs = [
    torch.randn((4800,), dtype=torch.bfloat16, device='cuda:0').as_strided((4800,), (1,)),
    torch.randn((7680000,), dtype=torch.bfloat16, device='cuda:0').as_strided((4800, 1600), (1600, 1)),
    torch.randn((1600,), dtype=torch.bfloat16, device='cuda:0').as_strided((1600,), (1,)),
    torch.randn((1600,), dtype=torch.bfloat16, device='cuda:0').as_strided((1600,), (1,)),
    torch.randn((3276800,), dtype=torch.bfloat16, device='cuda:0').as_strided((16, 128, 1600), (204800, 1600, 1)),
]
fd.execute(inputs)
Traceback (most recent call last):
  File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 146, in execute
    result = self._execute(
RuntimeError: h.has_value() INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/fusion_segmenter.cpp":3671, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Can not find a scheduler to schedule fusion segment
Exception raised from deriveHeuristic at /opt/pytorch/nvfuser/csrc/fusion_segmenter.cpp:3671 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x7fbf362d8381 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x53 (0x7fbf365d51b3 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: <unknown function> + 0x4bde42 (0x7fbf36675e42 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x4c5032 (0x7fbf3667d032 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: <unknown function> + 0x4d0c42 (0x7fbf36688c42 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: nvfuser::SegmentCandidateFinder::SegmentCandidateFinder(std::unique_ptr<nvfuser::Fusion, std::default_delete<nvfuser::Fusion> >, nvfuser::KernelArgumentHolder const*, nvfuser::SegmentCandidateFinderOptions) + 0x46f (0x7fbf366897ff in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0x4a8082 (0x7fbf36660082 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: <unknown function> + 0x4d1a0e (0x7fbf36689a0e in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #8: nvfuser::FusionKernelRuntime::FusionKernelRuntime(std::unique_ptr<nvfuser::Fusion, std::default_delete<nvfuser::Fusion> >, nvfuser::KernelArgumentHolder const&, nvfuser::serde::FusionKernelRuntime const*, std::optional<nvfuser::PrimDataType>, long, long, long, bool) + 0x373 (0x7fbf36799ed3 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #9: <unknown function> + 0x5e5b57 (0x7fbf3679db57 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #10: nvfuser::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0x1e7 (0x7fbf3679e8b7 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #11: nvfuser::python_frontend::FusionDefinition::execute(c10::ArrayRef<c10::IValue> const&, bool, bool, std::optional<signed char>) const + 0x3c8 (0x7fbf36981998 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #12: <unknown function> + 0x18ca25 (0x7fbf36344a25 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #13: <unknown function> + 0x2009c2 (0x7fbf363b89c2 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #14: <unknown function> + 0x288d00 (0x7fbf36440d00 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #15: <unknown function> + 0x15a10e (0x555dc548510e in /usr/bin/python3)
frame #16: _PyObject_MakeTpCall + 0x25b (0x555dc547ba7b in /usr/bin/python3)
frame #17: <unknown function> + 0x168acb (0x555dc5493acb in /usr/bin/python3)
frame #18: _PyEval_EvalFrameDefault + 0x198c (0x555dc546f53c in /usr/bin/python3)
frame #19: <unknown function> + 0x16893e (0x555dc549393e in /usr/bin/python3)
frame #20: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #21: _PyObject_FastCallDictTstate + 0xc4 (0x555dc547ac14 in /usr/bin/python3)
frame #22: _PyObject_Call_Prepend + 0xc1 (0x555dc54908d1 in /usr/bin/python3)
frame #23: <unknown function> + 0x280700 (0x555dc55ab700 in /usr/bin/python3)
frame #24: _PyObject_MakeTpCall + 0x25b (0x555dc547ba7b in /usr/bin/python3)
frame #25: _PyEval_EvalFrameDefault + 0x64e6 (0x555dc5474096 in /usr/bin/python3)
frame #26: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #27: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #28: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #29: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #30: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #31: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #32: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #33: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #34: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #35: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #36: <unknown function> + 0x16893e (0x555dc549393e in /usr/bin/python3)
frame #37: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #38: <unknown function> + 0x16893e (0x555dc549393e in /usr/bin/python3)
frame #39: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #40: _PyObject_FastCallDictTstate + 0xc4 (0x555dc547ac14 in /usr/bin/python3)
frame #41: _PyObject_Call_Prepend + 0x5c (0x555dc549086c in /usr/bin/python3)
frame #42: <unknown function> + 0x280700 (0x555dc55ab700 in /usr/bin/python3)
frame #43: PyObject_Call + 0xbb (0x555dc549442b in /usr/bin/python3)
frame #44: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #45: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #46: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #47: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #48: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #49: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #50: _PyEval_EvalFrameDefault + 0x6bd (0x555dc546e26d in /usr/bin/python3)
frame #51: <unknown function> + 0x1687f1 (0x555dc54937f1 in /usr/bin/python3)
frame #52: _PyEval_EvalFrameDefault + 0x198c (0x555dc546f53c in /usr/bin/python3)
frame #53: <unknown function> + 0x1687f1 (0x555dc54937f1 in /usr/bin/python3)
frame #54: _PyEval_EvalFrameDefault + 0x198c (0x555dc546f53c in /usr/bin/python3)
frame #55: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #56: PyObject_Call + 0x122 (0x555dc5494492 in /usr/bin/python3)
frame #57: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #58: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #59: _PyEval_EvalFrameDefault + 0x2a27 (0x555dc54705d7 in /usr/bin/python3)
frame #60: _PyFunction_Vectorcall + 0x7c (0x555dc54859fc in /usr/bin/python3)
frame #61: _PyEval_EvalFrameDefault + 0x614a (0x555dc5473cfa in /usr/bin/python3)
frame #62: <unknown function> + 0x1687f1 (0x555dc54937f1 in /usr/bin/python3)
frame #63: _PyEval_EvalFrameDefault + 0x614a (0x555dc5473cfa in /usr/bin/python3)

wujingyue added a commit that referenced this issue May 8, 2024
@wujingyue
Collaborator Author

FYI, NVFUSER_DUMP=segmenter_logging prints the following

**Segmenter** Considering fusion:
T34_g[ iS97{( 16 * 128 )}, iS138{i0}, rS99{1600} ]
   = mma(T32_g[ iS91{( 16 * 128 )}, bS92{1}, iS93{1600} ],
         T33_g[ bS94{1}, iS136{i0}, iS137{1600} ])

Scheduler _no_op_ ***rejected*** because : output has a concrete dimension
Scheduler _matmul_ ***rejected*** because : MmaOp input has unsupported dependency
Scheduler _reduction_ ***rejected*** because : No reduction op to schedule
Scheduler _transpose_ ***rejected*** because : no support for mma ops.
Scheduler _pointwise_ ***rejected*** because : no support for mma ops.
Scheduler _inner_persistent_ ***rejected*** because : needs a reduction op
Scheduler _outer_persistent_ ***rejected*** because : needs a reduction op
Scheduler _inner_outer_persistent_ ***rejected*** because : needs a reduction op

@wujingyue
Collaborator Author

The matmul scheduler failed at

const auto areMmaOpInputDependeciesValid = [](const Val* val) {

Looks like it assumes both operands are broadcast. I was under the impression that we removed that assumption in #1628. What am I missing? @zasdfgbnm

@wujingyue
Collaborator Author

FYI, below is the complete fusion after preseg optimizations. The MmaOp is indeed part of the beautiful broadcast+broadcast+mma+add+float2bfloat subgraph, which is good. However, due to other ops in the fusion, this subgraph is not given to the matmul scheduler immediately. Instead, it's decomposed into singletons, and the segmenter has trouble merging them into the expected subgraph.

$ NVFUSER_DUMP=fusion_ir_preseg python repro.py 
Fusion IR after pre-segmenter optimization passes:
Inputs:
  T0_g[ iS0{i0} ], __bfloat
  T1_g[ iS134{i0}, iS135{1600} ], __bfloat
  T2_g[ iS132{1600} ], __bfloat
  T3_g[ iS130{1600} ], __bfloat
  T4_g[ iS107{16}, iS108{128}, iS109{1600} ], __bfloat
Outputs:
  T38_g[ iS105{2048}, iS140{i0} ], __bfloat

%kernel_math {
T5_l[ iS110{16}, iS111{128}, iS112{1600} ]
   = __bfloat2float(T4_g[ iS107{16}, iS108{128}, iS109{1600} ]);
T6_l[ iS116{16}, iS117{128}, rS118{1600} ](Avg),
T7_l[ iS122{16}, iS123{128}, rS124{1600} ](Var),
T8_l[ iS113{16}, iS114{128}, rS115{1600} ](Count)
 = Welford ( T5_l[ iS110{16}, iS111{128}, iS112{1600} ](Avg), 
  allreduce = false )
T12_l[ iS119{16}, iS120{128}, bS30{1} ]
   = broadcast( T6_l[ iS116{16}, iS117{128}, rS118{1600} ] )
T13_l[ iS31{16}, iS32{128}, bS33{1} ]
   = Set( T12_l[ iS119{16}, iS120{128}, bS30{1} ], cache_op=Streaming )
T16_l[ iS40{16}, iS41{128}, bS42{1} ]
   = Set( T13_l[ iS31{16}, iS32{128}, bS33{1} ], cache_op=Streaming )
T17_l[ iS43{16}, iS44{128}, bS45{1 ex 1600} ] = expand( T16_l[ iS40{16}, iS41{128}, bS42{1} ], {16, 128, 1600} )
T18_l[ iS46{16}, iS47{128}, iS121{1600} ]
   = T5_l[ iS110{16}, iS111{128}, iS112{1600} ]
   - T17_l[ iS43{16}, iS44{128}, bS45{1 ex 1600} ];
d17 = (double)(1600);
d19 = double(1) * d17;
d23 = (double)(0);
d25 = d19 - d23;
d27 = (double)(0);
b29 = d25 >= d27;
d31 = (double)(0);
d33 = where(b29, d25, d31);
d39 = reciprocal(d33);
T9_l[ iS125{16}, iS126{128} ]
   = T7_l[ iS122{16}, iS123{128}, rS124{1600} ]
   * d39;
T10_l[ iS127{16}, iS128{128}, bS24{1} ]
   = broadcast( T9_l[ iS125{16}, iS126{128} ] )
T11_l[ iS25{16}, iS26{128}, bS27{1} ]
   = Set( T10_l[ iS127{16}, iS128{128}, bS24{1} ], cache_op=Streaming )
T14_l[ iS34{16}, iS35{128}, bS36{1} ]
   = T11_l[ iS25{16}, iS26{128}, bS27{1} ]
   + double(1.0000000000000001e-05);
T15_l[ iS37{16}, iS38{128}, bS39{1} ]
   = rsqrtf(T14_l[ iS34{16}, iS35{128}, bS36{1} ]);
T19_l[ iS49{16}, iS50{128}, bS51{1} ]
   = Set( T15_l[ iS37{16}, iS38{128}, bS39{1} ], cache_op=Streaming )
T20_l[ iS52{16}, iS53{128}, bS54{1 ex 1600} ] = expand( T19_l[ iS49{16}, iS50{128}, bS51{1} ], {16, 128, 1600} )
T21_l[ iS55{16}, iS56{128}, iS129{1600} ]
   = T18_l[ iS46{16}, iS47{128}, iS121{1600} ]
   * T20_l[ iS52{16}, iS53{128}, bS54{1 ex 1600} ];
T22_l[ bS58{1}, bS59{1}, iS131{1600} ]
   = broadcast( T3_g[ iS130{1600} ] )
T23_l[ bS61{1 ex 16}, bS62{1 ex 128}, iS63{1600} ] = expand( T22_l[ bS58{1}, bS59{1}, iS131{1600} ], {16, 128, 1600} )
T24_l[ bS64{1 ex 16}, bS65{1 ex 128}, iS66{1600} ]
   = __bfloat2float(T23_l[ bS61{1 ex 16}, bS62{1 ex 128}, iS63{1600} ]);
T25_l[ iS67{16}, iS68{128}, iS69{1600} ]
   = T21_l[ iS55{16}, iS56{128}, iS129{1600} ]
   * T24_l[ bS64{1 ex 16}, bS65{1 ex 128}, iS66{1600} ];
T26_l[ bS70{1}, bS71{1}, iS133{1600} ]
   = broadcast( T2_g[ iS132{1600} ] )
T27_l[ bS73{1 ex 16}, bS74{1 ex 128}, iS75{1600} ] = expand( T26_l[ bS70{1}, bS71{1}, iS133{1600} ], {16, 128, 1600} )
T28_l[ bS76{1 ex 16}, bS77{1 ex 128}, iS78{1600} ]
   = __bfloat2float(T27_l[ bS73{1 ex 16}, bS74{1 ex 128}, iS75{1600} ]);
T29_l[ iS79{16}, iS80{128}, iS81{1600} ]
   = T25_l[ iS67{16}, iS68{128}, iS69{1600} ]
   + T28_l[ bS76{1 ex 16}, bS77{1 ex 128}, iS78{1600} ];
T30_l[ iS82{16}, iS83{128}, iS84{1600} ]
   = __float2bfloat(T29_l[ iS79{16}, iS80{128}, iS81{1600} ]);
T31_l[ iS90{( 16 * 128 )}rf, iS87{1600} ] = view( T30_l[ iS82{16}, iS83{128}, iS84{1600} ] )
T32_l[ iS91{( 16 * 128 )}, bS92{1}, iS93{1600} ]
   = broadcast( T31_l[ iS90{( 16 * 128 )}rf, iS87{1600} ] )
T33_l[ bS94{1}, iS136{i0}, iS137{1600} ]
   = broadcast( T1_g[ iS134{i0}, iS135{1600} ] )
T34_l[ iS97{( 16 * 128 )}, iS138{i0}, rS99{1600} ]
   = mma(T32_l[ iS91{( 16 * 128 )}, bS92{1}, iS93{1600} ],
         T33_l[ bS94{1}, iS136{i0}, iS137{1600} ])
T35_l[ iS100{i0} ]
   = __bfloat2float(T0_g[ iS0{i0} ]);
T36_l[ bS101{1}, iS102{i0} ]
   = broadcast( T35_l[ iS100{i0} ] )
T37_l[ iS103{2048}, iS139{i0} ]
   = T34_l[ iS97{( 16 * 128 )}, iS138{i0}, rS99{1600} ]
   + T36_l[ bS101{1}, iS102{i0} ];
T38_g[ iS105{2048}, iS140{i0} ]
   = __float2bfloat(T37_l[ iS103{2048}, iS139{i0} ]);
}

@Priya2698
Collaborator

This issue looks related to: #2127.
The failure stems from assuming that the inputs are created through BroadcastOp.

@wujingyue What do you get after #2221?

While the ATen evaluation for matmul/linear will drop these assumptions once the new IR nodes are merged, at present we assume the same in pattern matching as well (BroadcastOp -> MmaOp -> CastOp).

wujingyue added a commit that referenced this issue May 9, 2024
For #2199 

Broadcasts before Mma are optional. matmul_expr_eval still has problems
with this, but I'll file a separate issue for that.
@wujingyue
Collaborator Author

> at present, we assume the same in pattern matching as well (BroadcastOp -> MmaOp -> CastOp)

That's right. I already merged #2221, so you can reproduce this by running the reproducer in #2199 (comment). Anyhow, I'll run my experiments with matmul_expr_eval disabled, so #2221 is sufficient to unblock me at this moment.

That being said, there's a variation of the same problem for the ATen evaluation: the segmenter doesn't guarantee to put cast into the same segment, just as it didn't put broadcast that way. I think the new IR nodes will help that but I'm not sure and I'll leave that to you.

@wujingyue
Collaborator Author

With some more hacks (which I'll try to find a way to submit), I'm getting some useful nvFusions to hopefully start with. Now the forward pass runs two nvFusions. The first one has one fd.ops.linear, which I suspect is the input linear layer. The second one has three fd.ops.linear, which I suspect is the output linear layer followed by the two-layer MLP.

I'll confirm this and try to include SDPA as well.

$ NVFUSER_DUMP=python_definition NVFUSER_DISABLE=matmul_expr_eval pytest thunder/benchmarks/targets.py -k test_nanogpt_block_fwd[thunder] -s
============================================================================================================================ test session starts =============================================================================================================================
platform linux -- Python 3.10.12, pytest-8.1.1, pluggy-1.4.0
Test order randomisation NOT enabled. Enable with --random-order or --random-order-bucket=<bucket_type>
benchmark: 4.0.0 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /opt/pytorch/lightning-thunder
configfile: pyproject.toml
plugins: timestamper-0.0.10, xdist-3.5.0, random-order-1.1.1, cov-4.1.0, benchmark-4.0.0, hypothesis-6.100.0, timeout-2.2.0, shard-0.1.2
timeout: 900.0s
timeout method: signal
timeout func_only: False
collected 162 items / 161 deselected / 1 selected
Running 1 items in this shard

thunder/benchmarks/targets.py
def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T3 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T4 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T5 = fd.ops.cast(T4, dtype=DataType.Float)
    T6, T7 = fd.ops.var_mean(T5, dims=[2], correction=0, keepdim=False)
    S8 = fd.define_scalar(16, dtype=DataType.Int)
    S9 = fd.define_scalar(128, dtype=DataType.Int)
    S10 = fd.define_scalar(1, dtype=DataType.Int)
    V11 = fd.define_vector([S8, S9, S10], dtype=DataType.Int)
    T12 = fd.ops.broadcast_in_dim(T6, shape=V11, broadcast_dims=[0, 1])
    S13 = fd.define_scalar(16, dtype=DataType.Int)
    S14 = fd.define_scalar(128, dtype=DataType.Int)
    S15 = fd.define_scalar(1, dtype=DataType.Int)
    V16 = fd.define_vector([S13, S14, S15], dtype=DataType.Int)
    T17 = fd.ops.broadcast_in_dim(T7, shape=V16, broadcast_dims=[0, 1])
    S18 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
    T19 = fd.ops.add(T12, S18)
    T20 = fd.ops.rsqrt(T19)
    S21 = fd.define_scalar(16, dtype=DataType.Int)
    S22 = fd.define_scalar(128, dtype=DataType.Int)
    S23 = fd.define_scalar(1600, dtype=DataType.Int)
    V24 = fd.define_vector([S21, S22, S23], dtype=DataType.Int)
    T25 = fd.ops.broadcast_in_dim(T17, shape=V24, broadcast_dims=[0, 1, 2])
    T26 = fd.ops.sub(T5, T25)
    S27 = fd.define_scalar(16, dtype=DataType.Int)
    S28 = fd.define_scalar(128, dtype=DataType.Int)
    S29 = fd.define_scalar(1600, dtype=DataType.Int)
    V30 = fd.define_vector([S27, S28, S29], dtype=DataType.Int)
    T31 = fd.ops.broadcast_in_dim(T20, shape=V30, broadcast_dims=[0, 1, 2])
    T32 = fd.ops.mul(T26, T31)
    S33 = fd.define_scalar(16, dtype=DataType.Int)
    S34 = fd.define_scalar(128, dtype=DataType.Int)
    S35 = fd.define_scalar(1600, dtype=DataType.Int)
    V36 = fd.define_vector([S33, S34, S35], dtype=DataType.Int)
    T37 = fd.ops.broadcast_in_dim(T3, shape=V36, broadcast_dims=[2])
    T38 = fd.ops.cast(T37, dtype=DataType.Float)
    T39 = fd.ops.mul(T32, T38)
    S40 = fd.define_scalar(16, dtype=DataType.Int)
    S41 = fd.define_scalar(128, dtype=DataType.Int)
    S42 = fd.define_scalar(1600, dtype=DataType.Int)
    V43 = fd.define_vector([S40, S41, S42], dtype=DataType.Int)
    T44 = fd.ops.broadcast_in_dim(T2, shape=V43, broadcast_dims=[2])
    T45 = fd.ops.cast(T44, dtype=DataType.Float)
    T46 = fd.ops.add(T39, T45)
    T47 = fd.ops.cast(T46, dtype=DataType.BFloat16)
    S48 = fd.define_scalar(2048, dtype=DataType.Int)
    S49 = fd.define_scalar(1600, dtype=DataType.Int)
    V50 = fd.define_vector([S48, S49], dtype=DataType.Int)
    T51 = fd.ops.reshape(T47, new_shape=V50)
    T52 = fd.ops.linear(T51, T1, T0)
    S53 = fd.define_scalar(16, dtype=DataType.Int)
    S54 = fd.define_scalar(128, dtype=DataType.Int)
    S55 = fd.define_scalar(4800, dtype=DataType.Int)
    V56 = fd.define_vector([S53, S54, S55], dtype=DataType.Int)
    T57 = fd.ops.reshape(T52, new_shape=V56)
    T58 = fd.ops.slice(T57, start_indices=[0, 0, 0], end_indices=[16, 128, 1600], strides=[1, 1, 1])
    T59 = fd.ops.slice(T57, start_indices=[0, 0, 1600], end_indices=[16, 128, 3200], strides=[1, 1, 1])
    T60 = fd.ops.slice(T57, start_indices=[0, 0, 3200], end_indices=[16, 128, 4800], strides=[1, 1, 1])
    S61 = fd.define_scalar(16, dtype=DataType.Int)
    S62 = fd.define_scalar(128, dtype=DataType.Int)
    S63 = fd.define_scalar(25, dtype=DataType.Int)
    S64 = fd.define_scalar(64, dtype=DataType.Int)
    V65 = fd.define_vector([S61, S62, S63, S64], dtype=DataType.Int)
    T66 = fd.ops.reshape(T59, new_shape=V65)
    T67 = fd.ops.permute(T66, dims=[0, 2, 1, 3])
    S68 = fd.define_scalar(16, dtype=DataType.Int)
    S69 = fd.define_scalar(128, dtype=DataType.Int)
    S70 = fd.define_scalar(25, dtype=DataType.Int)
    S71 = fd.define_scalar(64, dtype=DataType.Int)
    V72 = fd.define_vector([S68, S69, S70, S71], dtype=DataType.Int)
    T73 = fd.ops.reshape(T58, new_shape=V72)
    T74 = fd.ops.permute(T73, dims=[0, 2, 1, 3])
    S75 = fd.define_scalar(16, dtype=DataType.Int)
    S76 = fd.define_scalar(128, dtype=DataType.Int)
    S77 = fd.define_scalar(25, dtype=DataType.Int)
    S78 = fd.define_scalar(64, dtype=DataType.Int)
    V79 = fd.define_vector([S75, S76, S77, S78], dtype=DataType.Int)
    T80 = fd.ops.reshape(T60, new_shape=V79)
    T81 = fd.ops.permute(T80, dims=[0, 2, 1, 3])
    fd.add_output(T74)
    fd.add_output(T67)
    fd.add_output(T81)

[W509 16:47:44.547956141 matmul_utils.cpp:386] Warning: Scheduling a matmul without heuristic plugin. Specify plugin location like this: NVFUSER_MATMUL_HEURISTIC_PLUGIN=/path/to/libmatmulheuristic.so (function operator())

def nvfuser_fusion_id1(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T3 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T4 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T5 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T6 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T7 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T8 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T9 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T10 = fd.ops.permute(T9, dims=[0, 2, 1, 3])
    T11 = fd.ops.stride_order(T10, stride_order=[3, 2, 1, 0])
    S12 = fd.define_scalar(16, dtype=DataType.Int)
    S13 = fd.define_scalar(128, dtype=DataType.Int)
    S14 = fd.define_scalar(1600, dtype=DataType.Int)
    V15 = fd.define_vector([S12, S13, S14], dtype=DataType.Int)
    T16 = fd.ops.reshape(T11, new_shape=V15)
    S17 = fd.define_scalar(2048, dtype=DataType.Int)
    S18 = fd.define_scalar(1600, dtype=DataType.Int)
    V19 = fd.define_vector([S17, S18], dtype=DataType.Int)
    T20 = fd.ops.reshape(T16, new_shape=V19)
    T21 = fd.ops.linear(T20, T1, T0)
    S22 = fd.define_scalar(16, dtype=DataType.Int)
    S23 = fd.define_scalar(128, dtype=DataType.Int)
    S24 = fd.define_scalar(1600, dtype=DataType.Int)
    V25 = fd.define_vector([S22, S23, S24], dtype=DataType.Int)
    T26 = fd.ops.reshape(T21, new_shape=V25)
    S27 = fd.define_scalar(0.00000, dtype=DataType.Double)
    S28 = fd.define_scalar(1.00000, dtype=DataType.Double)
    S29 = fd.define_scalar(16, dtype=DataType.Int)
    S30 = fd.define_scalar(128, dtype=DataType.Int)
    S31 = fd.define_scalar(1600, dtype=DataType.Int)
    V32 = fd.define_vector([S29, S30, S31], dtype=DataType.Int)
    T33 = fd.ops.uniform(S27, S28, shape=V32, dtype=DataType.BFloat16)
    S34 = fd.define_scalar(0.900000, dtype=DataType.Double)
    T35 = fd.ops.lt(T33, S34)
    T36 = fd.ops.cast(T26, dtype=DataType.Float)
    T37 = fd.ops.cast(T35, dtype=DataType.Float)
    T38 = fd.ops.mul(T36, T37)
    S39 = fd.define_scalar(1.11111, dtype=DataType.Double)
    T40 = fd.ops.mul(T38, S39)
    T41 = fd.ops.cast(T8, dtype=DataType.Float)
    T42 = fd.ops.add(T41, T40)
    T43, T44 = fd.ops.var_mean(T42, dims=[2], correction=0, keepdim=False)
    S45 = fd.define_scalar(16, dtype=DataType.Int)
    S46 = fd.define_scalar(128, dtype=DataType.Int)
    S47 = fd.define_scalar(1, dtype=DataType.Int)
    V48 = fd.define_vector([S45, S46, S47], dtype=DataType.Int)
    T49 = fd.ops.broadcast_in_dim(T43, shape=V48, broadcast_dims=[0, 1])
    S50 = fd.define_scalar(16, dtype=DataType.Int)
    S51 = fd.define_scalar(128, dtype=DataType.Int)
    S52 = fd.define_scalar(1, dtype=DataType.Int)
    V53 = fd.define_vector([S50, S51, S52], dtype=DataType.Int)
    T54 = fd.ops.broadcast_in_dim(T44, shape=V53, broadcast_dims=[0, 1])
    S55 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
    T56 = fd.ops.add(T49, S55)
    T57 = fd.ops.rsqrt(T56)
    S58 = fd.define_scalar(16, dtype=DataType.Int)
    S59 = fd.define_scalar(128, dtype=DataType.Int)
    S60 = fd.define_scalar(1600, dtype=DataType.Int)
    V61 = fd.define_vector([S58, S59, S60], dtype=DataType.Int)
    T62 = fd.ops.broadcast_in_dim(T54, shape=V61, broadcast_dims=[0, 1, 2])
    T63 = fd.ops.sub(T42, T62)
    S64 = fd.define_scalar(16, dtype=DataType.Int)
    S65 = fd.define_scalar(128, dtype=DataType.Int)
    S66 = fd.define_scalar(1600, dtype=DataType.Int)
    V67 = fd.define_vector([S64, S65, S66], dtype=DataType.Int)
    T68 = fd.ops.broadcast_in_dim(T57, shape=V67, broadcast_dims=[0, 1, 2])
    T69 = fd.ops.mul(T63, T68)
    S70 = fd.define_scalar(16, dtype=DataType.Int)
    S71 = fd.define_scalar(128, dtype=DataType.Int)
    S72 = fd.define_scalar(1600, dtype=DataType.Int)
    V73 = fd.define_vector([S70, S71, S72], dtype=DataType.Int)
    T74 = fd.ops.broadcast_in_dim(T3, shape=V73, broadcast_dims=[2])
    T75 = fd.ops.cast(T74, dtype=DataType.Float)
    T76 = fd.ops.mul(T69, T75)
    S77 = fd.define_scalar(16, dtype=DataType.Int)
    S78 = fd.define_scalar(128, dtype=DataType.Int)
    S79 = fd.define_scalar(1600, dtype=DataType.Int)
    V80 = fd.define_vector([S77, S78, S79], dtype=DataType.Int)
    T81 = fd.ops.broadcast_in_dim(T2, shape=V80, broadcast_dims=[2])
    T82 = fd.ops.cast(T81, dtype=DataType.Float)
    T83 = fd.ops.add(T76, T82)
    T84 = fd.ops.cast(T83, dtype=DataType.BFloat16)
    S85 = fd.define_scalar(2048, dtype=DataType.Int)
    S86 = fd.define_scalar(1600, dtype=DataType.Int)
    V87 = fd.define_vector([S85, S86], dtype=DataType.Int)
    T88 = fd.ops.reshape(T84, new_shape=V87)
    T89 = fd.ops.linear(T88, T5, T4)
    S90 = fd.define_scalar(16, dtype=DataType.Int)
    S91 = fd.define_scalar(128, dtype=DataType.Int)
    S92 = fd.define_scalar(6400, dtype=DataType.Int)
    V93 = fd.define_vector([S90, S91, S92], dtype=DataType.Int)
    T94 = fd.ops.reshape(T89, new_shape=V93)
    T95 = fd.ops.cast(T94, dtype=DataType.Float)
    T96 = fd.ops.mul(T95, T95)
    T97 = fd.ops.mul(T96, T95)
    S98 = fd.define_scalar(0.500000, dtype=DataType.Double)
    T99 = fd.ops.mul(S98, T95)
    S100 = fd.define_scalar(0.0447150, dtype=DataType.Double)
    T101 = fd.ops.mul(S100, T97)
    T102 = fd.ops.add(T95, T101)
    S103 = fd.define_scalar(0.797885, dtype=DataType.Double)
    T104 = fd.ops.mul(S103, T102)
    T105 = fd.ops.tanh(T104)
    S106 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T107 = fd.ops.add(S106, T105)
    T108 = fd.ops.mul(T99, T107)
    T109 = fd.ops.cast(T108, dtype=DataType.BFloat16)
    S110 = fd.define_scalar(2048, dtype=DataType.Int)
    S111 = fd.define_scalar(6400, dtype=DataType.Int)
    V112 = fd.define_vector([S110, S111], dtype=DataType.Int)
    T113 = fd.ops.reshape(T109, new_shape=V112)
    T114 = fd.ops.linear(T113, T7, T6)
    S115 = fd.define_scalar(16, dtype=DataType.Int)
    S116 = fd.define_scalar(128, dtype=DataType.Int)
    S117 = fd.define_scalar(1600, dtype=DataType.Int)
    V118 = fd.define_vector([S115, S116, S117], dtype=DataType.Int)
    T119 = fd.ops.reshape(T114, new_shape=V118)
    S120 = fd.define_scalar(0.00000, dtype=DataType.Double)
    S121 = fd.define_scalar(1.00000, dtype=DataType.Double)
    S122 = fd.define_scalar(16, dtype=DataType.Int)
    S123 = fd.define_scalar(128, dtype=DataType.Int)
    S124 = fd.define_scalar(1600, dtype=DataType.Int)
    V125 = fd.define_vector([S122, S123, S124], dtype=DataType.Int)
    T126 = fd.ops.uniform(S120, S121, shape=V125, dtype=DataType.BFloat16)
    S127 = fd.define_scalar(0.900000, dtype=DataType.Double)
    T128 = fd.ops.lt(T126, S127)
    T129 = fd.ops.cast(T119, dtype=DataType.Float)
    T130 = fd.ops.cast(T128, dtype=DataType.Float)
    T131 = fd.ops.mul(T129, T130)
    S132 = fd.define_scalar(1.11111, dtype=DataType.Double)
    T133 = fd.ops.mul(T131, S132)
    T134 = fd.ops.add(T42, T133)
    T135 = fd.ops.cast(T134, dtype=DataType.BFloat16)
    fd.add_output(T135)

@Priya2698
Collaborator

> > at present, we assume the same in pattern matching as well (BroadcastOp -> MmaOp -> CastOp)
>
> That's right. I already merged #2221, so you can reproduce this by running the reproducer in #2199 (comment). Anyhow, I'll run my experiments with matmul_expr_eval disabled, so #2221 is sufficient to unblock me at this moment.
>
> That being said, there's a variation of the same problem for the ATen evaluation: the segmenter doesn't guarantee to put cast into the same segment, just as it didn't put broadcast that way. I think the new IR nodes will help that but I'm not sure and I'll leave that to you.

Yes, the new IR nodes will fix this issue since we won't evaluate a decomposed IR. The pattern matching will be redundant and removed once the API is modified to use the new IR nodes.

@wujingyue
Collaborator Author

@Priya2698 Wdyt about the drafted WAR for nvFuser not supporting 3D linear? Should I submit that WAR in Thunder or wait for your PRs?

@Priya2698
Collaborator

> @Priya2698 Wdyt about the drafted WAR for nvFuser not supporting 3D linear? Should I submit that WAR in Thunder or wait for your PRs?

It looks like the WAR will still run into the segmentation issue due to the reshapes.

If you don't necessarily need that change in Thunder to proceed, then adding the new nodes will lift that restriction anyway. I am estimating the new PRs within a couple of days, early next week.

We can go ahead with it if it unblocks you in the interim.

@wujingyue
Collaborator Author

Cool -- I closed Lightning-AI/lightning-thunder#391.

wujingyue added a commit to Lightning-AI/lightning-thunder that referenced this issue May 9, 2024
For NVIDIA/Fuser#2199.

To run them,

```
NVFUSER_DISABLE=matmul_expr_eval python before_sdpa.py
NVFUSER_DISABLE=matmul_expr_eval python after_sdpa.py
```

`matmul_expr_eval` is disabled for a known limitation that will be fixed
soon.

I'll try to include SDPA as well. Currently, the two files implement
things before and after SDPA. For your understanding, code around
`fd.ops.uniform` corresponds to dropout. Code around `fd.ops.tanh`
corresponds to an approximated GELU layer. Code around `fd.ops.var_mean`
corresponds to layernorm.
@wujingyue
Collaborator Author

@cowanmeg Lightning-AI/lightning-thunder@bf84b04 checked in what's in the forward pass of a single-device transformer block modulo SDPA. See the message of that commit for more details. With that, we should be able to work on this in parallel. I'll try to include SDPA and backprop, and you'll try to build a sharded version. How does that sound?

@cowanmeg
Collaborator

cowanmeg commented May 9, 2024

Thanks @wujingyue! This is super helpful, I'll start working on the sharding soon!

@cowanmeg
Collaborator

cowanmeg commented May 18, 2024

I annotated the sharding of the MLP layer of the example: https://gist.github.com/cowanmeg/75b4144a3627df74efcfc12dda01a2a3

Some comments:
(1) The two linear layers and GeLU have sharded computation. The dropout, layernorm, and residual add are replicated and computed on each device. (BTW, I don't think it would be too hard to represent SP.)
Sharding propagation is relatively straightforward if we annotate only the Linear layer inputs and outputs. I think the current naive one will suffice for now.
(2) Now that LinearOp and MatmulOp are part of the compute definition, we should reconsider how we insert resharding expressions and DID leaf parallelization. (cc @Priya2698 @jacobhinkle)
(3) The pointwise scheduler needs to be modified to ignore DID axes. This should be as straightforward as reordering DID axes to the front and ignoring them.

While we discuss our design for (2), I will manually translate these programs and decompose the LinearOp myself. Regardless, this is necessary since we need to logically split sharded axes in the compute definition because of our RFactor restriction. For MLP, this isn't too hard and would let us get a small example working.
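
To make (1) concrete, here is a rough sketch of the row-parallel second GEMM using plain torch.distributed collectives (assumes an initialized process group; names are illustrative, not from the gist):

```python
import torch
import torch.nn.functional as F
import torch.distributed as dist

def row_parallel_linear(x_local, w_shard, bias):
    # x_local: [tokens, d_ff / world], the activation shard produced by the
    # column-parallel first linear; w_shard: [d_out, d_ff / world].
    y_partial = F.linear(x_local, w_shard)            # local GEMM on the shard
    dist.all_reduce(y_partial, op=dist.ReduceOp.SUM)  # the MLP's only communication
    return y_partial + bias  # bias added once, after the reduction
```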

wujingyue added a commit to Lightning-AI/lightning-thunder that referenced this issue May 29, 2024
@wujingyue
Collaborator Author

FYI, Lightning-AI/lightning-thunder@af6bfc1 added the forward pass of the whole transformer block (i.e. with SDPA). Caveat: the speed is probably far from SOL because nvFuser can't fuse matmul+softmax+matmul at this moment. #2278 is going to add an SDPA IR node so we can fall back to the existing flash attention implementation in ATen. When that's done, we'll see simply the SDPA node in the fusion definition instead of the decomposed form.

@wujingyue
Collaborator Author

Lightning-AI/lightning-thunder@b06bf4e adds the backprop. It's hard to verify because https://github.com/Lightning-AI/lightning-thunder/blob/bc3925be04e7ec58d9b24fb7ac55fbe007862a65/thunder/core/rematerialization.py#L569 mixes in some ops from the forward pass. However, when I try to print the backprop trace before rematerialization (see below), I do see 12 prims.matmuls, which looks right. (There are 4 linear layers and 2 matmuls in the forward pass, each of which becomes 2 matmuls in backprop).

import thunder
import thunder.core.prims as prims
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, C1, = saved_for_backward
  t27, = cotangents
  t0, t4, t8, t11, t12, t13, t15, t20, t_attn_c_attn_weight, t42, t45, t49, t59, \
  t71, t75, t80, t84, t_attn_c_proj_weight, t92, t100, t108, t113, t116, t117, \
  t118, t120, t125, t_mlp_c_fc_weight, t127, t127, t127, t129, t149, t152, t136, \
  t157, t_mlp_c_proj_weight, t168, = C0
  i6, i9, f43, f47, i54, f62, f78, i85, f89, f91, f93, f102, = C1
  i622 = prims.sub(1600, i85)  # i622: "int 1600"
  i807 = prims.sub(1600, i6)  # i807: "int 1600"
  [t524, t527, t588, t591, t596, t602, t661, t664, t773, t776, t781, t787, t827] = nvFusion0(f102, f43, f47, f62, f78, f89, f91, f93, i54, i622, i807, i9, t0, t100, t108, t11, t113, t116, t117, t118, t12, t120, t125, t127, t129, t13, t136, t149, t15, t152, t157, t168, t20, t27, t4, t42, t45, t49, t59, t71, t75, t8, t80, t84, t92, t_attn_c_attn_weight, t_attn_c_proj_weight, t_mlp_c_fc_weight, t_mlp_c_proj_weight)
    # t506 = prims.convert_element_type(t27, dtypes.float32)  # t506: "cuda:0 f32[16, 128, 1600]"
    # t511 = prims.mul(f102, t506)  # t511: "cuda:0 f32[16, 128, 1600]"
    # t514 = prims.mul(t168, t511)  # t514: "cuda:0 f32[16, 128, 1600]"
    # t517 = prims.convert_element_type(t514, dtypes.bfloat16)  # t517: "cuda:0 bf16[16, 128, 1600]"
    # t518 = prims.reshape(t517, (2048, 1600))  # t518: "cuda:0 bf16[2048, 1600]"
    # t519 = prims.matmul(t518, t_mlp_c_proj_weight)  # t519: "cuda:0 bf16[2048, 6400]"
    # t520 = prims.reshape(t519, (16, 128, 6400))  # t520: "cuda:0 bf16[16, 128, 6400]"
    # t522 = prims.transpose(t518, (1, 0))  # t522: "cuda:0 bf16[1600, 2048]"
    # t523 = prims.reshape(t157, (2048, 6400))  # t523: "cuda:0 bf16[2048, 6400]"
    # t524 = prims.matmul(t522, t523)  # t524: "cuda:0 bf16[1600, 6400]"
    # t526 = prims.sum(t514, (0, 1))  # t526: "cuda:0 f32[1600]"
    # t527 = prims.convert_element_type(t526, dtypes.bfloat16)  # t527: "cuda:0 bf16[1600]"
    # t528 = prims.convert_element_type(t520, dtypes.float32)  # t528: "cuda:0 f32[16, 128, 6400]"
    # t529 = prims.mul(t152, t528)  # t529: "cuda:0 f32[16, 128, 6400]"
    # t530 = prims.mul(t136, t528)  # t530: "cuda:0 f32[16, 128, 6400]"
    # t537 = prims.mul(t149, t149)  # t537: "cuda:0 f32[16, 128, 6400]"
    # t538 = prims.sub(1.0, t537)  # t538: "cuda:0 f32[16, 128, 6400]"
    # t539 = prims.mul(t530, t538)  # t539: "cuda:0 f32[16, 128, 6400]"
    # t543 = prims.mul(f93, t539)  # t543: "cuda:0 f32[16, 128, 6400]"
    # t550 = prims.mul(f91, t543)  # t550: "cuda:0 f32[16, 128, 6400]"
    # t554 = prims.mul(f89, t529)  # t554: "cuda:0 f32[16, 128, 6400]"
    # t558 = prims.add(t543, t554)  # t558: "cuda:0 f32[16, 128, 6400]"
    # t561 = prims.mul(t127, t550)  # t561: "cuda:0 f32[16, 128, 6400]"
    # t562 = prims.mul(t129, t550)  # t562: "cuda:0 f32[16, 128, 6400]"
    # t567 = prims.add(t558, t562)  # t567: "cuda:0 f32[16, 128, 6400]"
    # t570 = prims.mul(t127, t561)  # t570: "cuda:0 f32[16, 128, 6400]"
    # t576 = prims.add(t567, t570)  # t576: "cuda:0 f32[16, 128, 6400]"
    # t580 = prims.add(t576, t570)  # t580: "cuda:0 f32[16, 128, 6400]"
    # t581 = prims.convert_element_type(t580, dtypes.bfloat16)  # t581: "cuda:0 bf16[16, 128, 6400]"
    # t582 = prims.reshape(t581, (2048, 6400))  # t582: "cuda:0 bf16[2048, 6400]"
    # t583 = prims.matmul(t582, t_mlp_c_fc_weight)  # t583: "cuda:0 bf16[2048, 1600]"
    # t584 = prims.reshape(t583, (16, 128, 1600))  # t584: "cuda:0 bf16[16, 128, 1600]"
    # t586 = prims.transpose(t582, (1, 0))  # t586: "cuda:0 bf16[6400, 2048]"
    # t587 = prims.reshape(t125, (2048, 1600))  # t587: "cuda:0 bf16[2048, 1600]"
    # t588 = prims.matmul(t586, t587)  # t588: "cuda:0 bf16[6400, 1600]"
    # t590 = prims.sum(t580, (0, 1))  # t590: "cuda:0 f32[6400]"
    # t591 = prims.convert_element_type(t590, dtypes.bfloat16)  # t591: "cuda:0 bf16[6400]"
    # t592 = prims.convert_element_type(t584, dtypes.float32)  # t592: "cuda:0 f32[16, 128, 1600]"
    # t595 = prims.sum(t592, (0, 1))  # t595: "cuda:0 f32[1600]"
    # t596 = prims.convert_element_type(t595, dtypes.bfloat16)  # t596: "cuda:0 bf16[1600]"
    # t597 = prims.mul(t120, t592)  # t597: "cuda:0 f32[16, 128, 1600]"
    # t598 = prims.mul(t118, t592)  # t598: "cuda:0 f32[16, 128, 1600]"
    # t601 = prims.sum(t598, (0, 1))  # t601: "cuda:0 f32[1600]"
    # t602 = prims.convert_element_type(t601, dtypes.bfloat16)  # t602: "cuda:0 bf16[1600]"
    # t603 = prims.mul(t117, t597)  # t603: "cuda:0 f32[16, 128, 1600]"
    # t604 = prims.mul(t116, t597)  # t604: "cuda:0 f32[16, 128, 1600]"
    # t605 = prims.sum(t604, (2,))  # t605: "cuda:0 f32[16, 128]"
    # t606 = prims.broadcast_in_dim(t605, [16, 128, 1], [0, 1])  # t606: "cuda:0 f32[16, 128, 1]"
    # t607 = prims.neg(t603)  # t607: "cuda:0 f32[16, 128, 1600]"
    # t609 = prims.sum(t607, (2,))  # t609: "cuda:0 f32[16, 128]"
    # t610 = prims.broadcast_in_dim(t609, [16, 128, 1], [0, 1])  # t610: "cuda:0 f32[16, 128, 1]"
    # t611 = prims.mul(-0.5, t606)  # t611: "cuda:0 f32[16, 128, 1]"
    # t612 = prims.pow(t113, 3.0)  # t612: "cuda:0 f32[16, 128, 1]"
    # t613 = prims.mul(t611, t612)  # t613: "cuda:0 f32[16, 128, 1]"
    # t615 = prims.sum(t610, (2,))  # t615: "cuda:0 f32[16, 128]"
    # t616 = prims.sum(t613, (2,))  # t616: "cuda:0 f32[16, 128]"
    # t619 = prims.broadcast_in_dim(t615, [16, 128, 1], [0, 1])  # t619: "cuda:0 f32[16, 128, 1]"
    # t620 = prims.broadcast_in_dim(t619, (16, 128, 1600), (0, 1, 2))  # t620: "cuda:0 f32[16, 128, 1600]"
    # t621 = prims.mul(0.000625, t620)  # t621: "cuda:0 f32[16, 128, 1600]"
    # t623 = prims.broadcast_in_dim(t616, [16, 128, 1], [0, 1])  # t623: "cuda:0 f32[16, 128, 1]"
    # t624 = prims.broadcast_in_dim(t623, (16, 128, 1600), (0, 1, 2))  # t624: "cuda:0 f32[16, 128, 1600]"
    # t626 = prims.broadcast_in_dim(t108, [16, 128, 1], [0, 1])  # t626: "cuda:0 f32[16, 128, 1]"
    # t627 = prims.broadcast_in_dim(t626, (16, 128, 1600), (0, 1, 2))  # t627: "cuda:0 f32[16, 128, 1600]"
    # t628 = prims.mul(2.0, t624)  # t628: "cuda:0 f32[16, 128, 1600]"
    # t629 = prims.sub(t100, t627)  # t629: "cuda:0 f32[16, 128, 1600]"
    # t630 = prims.mul(t628, t629)  # t630: "cuda:0 f32[16, 128, 1600]"
    # f631 = prims.convert_element_type(i622, float)  # f631: "float 1600.0"
    # t632 = prims.div(t630, f631)  # t632: "cuda:0 f32[16, 128, 1600]"
    # t633 = prims.add(t621, t632)  # t633: "cuda:0 f32[16, 128, 1600]"
    # t637 = prims.add(t603, t633)  # t637: "cuda:0 f32[16, 128, 1600]"
    # t641 = prims.add(t506, t637)  # t641: "cuda:0 f32[16, 128, 1600]"
    # t648 = prims.mul(f78, t641)  # t648: "cuda:0 f32[16, 128, 1600]"
    # t651 = prims.mul(t92, t648)  # t651: "cuda:0 f32[16, 128, 1600]"
    # t654 = prims.convert_element_type(t651, dtypes.bfloat16)  # t654: "cuda:0 bf16[16, 128, 1600]"
    # t655 = prims.reshape(t654, (2048, 1600))  # t655: "cuda:0 bf16[2048, 1600]"
    # t656 = prims.matmul(t655, t_attn_c_proj_weight)  # t656: "cuda:0 bf16[2048, 1600]"
    # t657 = prims.reshape(t656, (16, 128, 1600))  # t657: "cuda:0 bf16[16, 128, 1600]"
    # t659 = prims.transpose(t655, (1, 0))  # t659: "cuda:0 bf16[1600, 2048]"
    # t660 = prims.reshape(t84, (2048, 1600))  # t660: "cuda:0 bf16[2048, 1600]"
    # t661 = prims.matmul(t659, t660)  # t661: "cuda:0 bf16[1600, 1600]"
    # t663 = prims.sum(t651, (0, 1))  # t663: "cuda:0 f32[1600]"
    # t664 = prims.convert_element_type(t663, dtypes.bfloat16)  # t664: "cuda:0 bf16[1600]"
    # t668 = prims.reshape(t657, (16, 128, 25, 64))  # t668: "cuda:0 bf16[16, 128, 25, 64]"
    # t671 = prims.transpose(t668, (0, 2, 1, 3))  # t671: "cuda:0 bf16[16, 25, 128, 64]"
    # t672 = prims.transpose(t42, (0, 1, 3, 2))  # t672: "cuda:0 bf16[16, 25, 64, 128]"
    # t673 = prims.matmul(t671, t672)  # t673: "cuda:0 bf16[16, 25, 128, 128]"
    # t674 = prims.transpose(t80, (0, 1, 3, 2))  # t674: "cuda:0 bf16[16, 25, 128, 128]"
    # t675 = prims.matmul(t674, t671)  # t675: "cuda:0 bf16[16, 25, 128, 64]"
    # t676 = prims.convert_element_type(t673, dtypes.float32)  # t676: "cuda:0 f32[16, 25, 128, 128]"
    # t678 = prims.mul(f62, t676)  # t678: "cuda:0 f32[16, 25, 128, 128]"
    # t681 = prims.mul(t75, t678)  # t681: "cuda:0 f32[16, 25, 128, 128]"
    # t685 = prims.convert_element_type(t71, dtypes.float32)  # t685: "cuda:0 f32[16, 25, 128, 128]"
    # t687 = prims.mul(t685, t681)  # t687: "cuda:0 f32[16, 25, 128, 128]"
    # i691 = prims.add(i54, 4)  # i691: "int 3"
    # t701 = prims.sum(t687, (i691,))  # t701: "cuda:0 f32[16, 25, 128]"
    # t710 = prims.broadcast_in_dim(t701, [16, 25, 128, 1], [0, 1, 2])  # t710: "cuda:0 f32[16, 25, 128, 1]"
    # t711 = prims.convert_element_type(t710, dtypes.bfloat16)  # t711: "cuda:0 bf16[16, 25, 128, 1]"
    # t712 = prims.broadcast_in_dim(t711, (16, 25, 128, 128), (0, 1, 2, 3))  # t712: "cuda:0 bf16[16, 25, 128, 128]"
    # t714 = prims.convert_element_type(t712, dtypes.float32)  # t714: "cuda:0 f32[16, 25, 128, 128]"
    # t715 = prims.sub(t681, t714)  # t715: "cuda:0 f32[16, 25, 128, 128]"
    # t719 = prims.mul(t685, t715)  # t719: "cuda:0 f32[16, 25, 128, 128]"
    # t720 = prims.convert_element_type(t719, dtypes.bfloat16)  # t720: "cuda:0 bf16[16, 25, 128, 128]"
    # t722 = prims.where(t59, t720, 0.0)  # t722: "cuda:0 bf16[16, 25, 128, 128]"
    # t723 = prims.transpose(t49, (0, 1, 3, 2))  # t723: "cuda:0 bf16[16, 25, 128, 64]"
    # t724 = prims.matmul(t722, t723)  # t724: "cuda:0 bf16[16, 25, 128, 64]"
    # t725 = prims.transpose(t45, (0, 1, 3, 2))  # t725: "cuda:0 bf16[16, 25, 64, 128]"
    # t726 = prims.matmul(t725, t722)  # t726: "cuda:0 bf16[16, 25, 64, 128]"
    # t727 = prims.convert_element_type(t726, dtypes.float32)  # t727: "cuda:0 f32[16, 25, 64, 128]"
    # t729 = prims.mul(f47, t727)  # t729: "cuda:0 f32[16, 25, 64, 128]"
    # t730 = prims.convert_element_type(t729, dtypes.bfloat16)  # t730: "cuda:0 bf16[16, 25, 64, 128]"
    # t733 = prims.transpose(t730, (0, 1, 3, 2))  # t733: "cuda:0 bf16[16, 25, 128, 64]"
    # t734 = prims.convert_element_type(t724, dtypes.float32)  # t734: "cuda:0 f32[16, 25, 128, 64]"
    # t736 = prims.mul(f43, t734)  # t736: "cuda:0 f32[16, 25, 128, 64]"
    # t737 = prims.convert_element_type(t736, dtypes.bfloat16)  # t737: "cuda:0 bf16[16, 25, 128, 64]"
    # t740 = prims.transpose(t675, (0, 2, 1, 3))  # t740: "cuda:0 bf16[16, 128, 25, 64]"
    # t745 = prims.reshape(t740, (16, 128, 1600))  # t745: "cuda:0 bf16[16, 128, 1600]"
    # t748 = prims.transpose(t737, (0, 2, 1, 3))  # t748: "cuda:0 bf16[16, 128, 25, 64]"
    # t753 = prims.reshape(t748, (16, 128, 1600))  # t753: "cuda:0 bf16[16, 128, 1600]"
    # t756 = prims.transpose(t733, (0, 2, 1, 3))  # t756: "cuda:0 bf16[16, 128, 25, 64]"
    # t761 = prims.reshape(t756, (16, 128, 1600))  # t761: "cuda:0 bf16[16, 128, 1600]"
    # t766 = prims.cat((t753, t761, t745), i9)  # t766: "cuda:0 bf16[16, 128, 4800]"
    # t767 = prims.reshape(t766, (2048, 4800))  # t767: "cuda:0 bf16[2048, 4800]"
    # t768 = prims.matmul(t767, t_attn_c_attn_weight)  # t768: "cuda:0 bf16[2048, 1600]"
    # t769 = prims.reshape(t768, (16, 128, 1600))  # t769: "cuda:0 bf16[16, 128, 1600]"
    # t771 = prims.transpose(t767, (1, 0))  # t771: "cuda:0 bf16[4800, 2048]"
    # t772 = prims.reshape(t20, (2048, 1600))  # t772: "cuda:0 bf16[2048, 1600]"
    # t773 = prims.matmul(t771, t772)  # t773: "cuda:0 bf16[4800, 1600]"
    # t774 = prims.convert_element_type(t766, dtypes.float32)  # t774: "cuda:0 f32[16, 128, 4800]"
    # t775 = prims.sum(t774, (0, 1))  # t775: "cuda:0 f32[4800]"
    # t776 = prims.convert_element_type(t775, dtypes.bfloat16)  # t776: "cuda:0 bf16[4800]"
    # t777 = prims.convert_element_type(t769, dtypes.float32)  # t777: "cuda:0 f32[16, 128, 1600]"
    # t780 = prims.sum(t777, (0, 1))  # t780: "cuda:0 f32[1600]"
    # t781 = prims.convert_element_type(t780, dtypes.bfloat16)  # t781: "cuda:0 bf16[1600]"
    # t782 = prims.mul(t15, t777)  # t782: "cuda:0 f32[16, 128, 1600]"
    # t783 = prims.mul(t13, t777)  # t783: "cuda:0 f32[16, 128, 1600]"
    # t786 = prims.sum(t783, (0, 1))  # t786: "cuda:0 f32[1600]"
    # t787 = prims.convert_element_type(t786, dtypes.bfloat16)  # t787: "cuda:0 bf16[1600]"
    # t788 = prims.mul(t12, t782)  # t788: "cuda:0 f32[16, 128, 1600]"
    # t789 = prims.mul(t11, t782)  # t789: "cuda:0 f32[16, 128, 1600]"
    # t790 = prims.sum(t789, (2,))  # t790: "cuda:0 f32[16, 128]"
    # t791 = prims.broadcast_in_dim(t790, [16, 128, 1], [0, 1])  # t791: "cuda:0 f32[16, 128, 1]"
    # t792 = prims.neg(t788)  # t792: "cuda:0 f32[16, 128, 1600]"
    # t794 = prims.sum(t792, (2,))  # t794: "cuda:0 f32[16, 128]"
    # t795 = prims.broadcast_in_dim(t794, [16, 128, 1], [0, 1])  # t795: "cuda:0 f32[16, 128, 1]"
    # t796 = prims.mul(-0.5, t791)  # t796: "cuda:0 f32[16, 128, 1]"
    # t797 = prims.pow(t8, 3.0)  # t797: "cuda:0 f32[16, 128, 1]"
    # t798 = prims.mul(t796, t797)  # t798: "cuda:0 f32[16, 128, 1]"
    # t800 = prims.sum(t795, (2,))  # t800: "cuda:0 f32[16, 128]"
    # t801 = prims.sum(t798, (2,))  # t801: "cuda:0 f32[16, 128]"
    # t804 = prims.broadcast_in_dim(t800, [16, 128, 1], [0, 1])  # t804: "cuda:0 f32[16, 128, 1]"
    # t805 = prims.broadcast_in_dim(t804, (16, 128, 1600), (0, 1, 2))  # t805: "cuda:0 f32[16, 128, 1600]"
    # t806 = prims.mul(0.000625, t805)  # t806: "cuda:0 f32[16, 128, 1600]"
    # t808 = prims.broadcast_in_dim(t801, [16, 128, 1], [0, 1])  # t808: "cuda:0 f32[16, 128, 1]"
    # t809 = prims.broadcast_in_dim(t808, (16, 128, 1600), (0, 1, 2))  # t809: "cuda:0 f32[16, 128, 1600]"
    # t811 = prims.broadcast_in_dim(t4, [16, 128, 1], [0, 1])  # t811: "cuda:0 f32[16, 128, 1]"
    # t812 = prims.broadcast_in_dim(t811, (16, 128, 1600), (0, 1, 2))  # t812: "cuda:0 f32[16, 128, 1600]"
    # t813 = prims.mul(2.0, t809)  # t813: "cuda:0 f32[16, 128, 1600]"
    # t814 = prims.sub(t0, t812)  # t814: "cuda:0 f32[16, 128, 1600]"
    # t815 = prims.mul(t813, t814)  # t815: "cuda:0 f32[16, 128, 1600]"
    # f816 = prims.convert_element_type(i807, float)  # f816: "float 1600.0"
    # t817 = prims.div(t815, f816)  # t817: "cuda:0 f32[16, 128, 1600]"
    # t818 = prims.add(t806, t817)  # t818: "cuda:0 f32[16, 128, 1600]"
    # t822 = prims.add(t788, t818)  # t822: "cuda:0 f32[16, 128, 1600]"
    # t826 = prims.add(t641, t822)  # t826: "cuda:0 f32[16, 128, 1600]"
    # t827 = prims.convert_element_type(t826, dtypes.bfloat16)  # t827: "cuda:0 bf16[16, 128, 1600]"
  return (t827, t776, t773, t664, t661, t781, t787, t596, t602, t591, t588, t527, t524)

@wujingyue
Collaborator Author

@cowanmeg Here's how you can get a Thunder trace to help you understand the backprop nvFusion. The Thunder trace tends to be more concise than the nvFusion and has shapes annotated. You can also dump the intermediate traces to see how the final trace is derived.

  1. Check out the branch wjy/sharded, which disables bookend and the SDPA executor so that Thunder hands nvFuser the entire transformer block. It also patches the linear layer to work around "Merging IterDomains requires that their iteration types match" (#2317).
  2. Add prints for whatever traces you'd like to examine. https://github.com/Lightning-AI/lightning-thunder/blob/27158e62a19e144a3081be9507c81084b702c58e/thunder/executors/torch_autograd.py#L109 is where Thunder generates the forward and backward passes from the forward-only original trace. You can print any TraceCtx to see what it looks like. For example, in my previous comment, I printed bw_trace before https://github.com/Lightning-AI/lightning-thunder/blob/27158e62a19e144a3081be9507c81084b702c58e/thunder/executors/torch_autograd.py#L214 to see what the trace looks like before rematerialization; see the sketch after this list.
  3. Run `pytest thunder/benchmarks/targets.py -k test_nanogpt_block_grad[thunder] -s`.
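A minimal sketch of step 2, assuming you're editing `thunder/executors/torch_autograd.py` on that branch; `bw_trace` is the variable named in the linked source, and the exact line numbers may have drifted since this comment was written:

```python
# Right before rematerialization in thunder/executors/torch_autograd.py.
# A TraceCtx renders as a Python-like program with shapes annotated,
# so a plain print is all you need.
print(bw_trace)

# Or write it to a file for easier diffing between compilation stages:
with open("bw_trace.py", "w") as f:
    f.write(str(bw_trace))
```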

@wujingyue
Collaborator Author

FYI, Lightning-AI/lightning-thunder@e19f6ea tries to update the test case to use the GPT-3 config, the one used in the two most recent Megatron papers: https://arxiv.org/pdf/2104.04473 and https://arxiv.org/pdf/2205.05198. At the moment, it hits #2359.

wujingyue added a commit to Lightning-AI/lightning-thunder that referenced this issue Jun 7, 2024
For NVIDIA/Fuser#2199.

To run them,

```
NVFUSER_DISABLE=matmul_expr_eval python before_sdpa.py
NVFUSER_DISABLE=matmul_expr_eval python after_sdpa.py
```

`matmul_expr_eval` is disabled due to a known limitation that will be fixed
soon.

I'll try to include SDPA as well. Currently, the two files implement the
parts of the block before and after SDPA. For orientation: code around
`fd.ops.uniform` corresponds to dropout, code around `fd.ops.tanh` to the
tanh-approximated GELU, and code around `fd.ops.var_mean` to layernorm.
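For readers mapping those op clusters back to layers, here is a small PyTorch sketch of the math each cluster computes; the constants and `eps` are the standard ones, and this is an illustration rather than the generated code:

```python
import torch

def dropout_like_uniform_cluster(x, p):
    # fd.ops.uniform + comparison + scale: keep with probability 1-p, rescale.
    mask = torch.rand_like(x) < (1.0 - p)
    return x * mask / (1.0 - p)

def gelu_like_tanh_cluster(x):
    # tanh-approximated GELU; 0.7978845608 is sqrt(2/pi).
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608 * (x + 0.044715 * x**3)))

def layernorm_like_var_mean_cluster(x, weight, bias, eps=1e-5):
    # fd.ops.var_mean feeds the normalization that follows it.
    var, mean = torch.var_mean(x, dim=-1, keepdim=True, unbiased=False)
    return (x - mean) / torch.sqrt(var + eps) * weight + bias
```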
cowanmeg added a commit that referenced this issue Jun 26, 2024
Manually sharded tensor-parallel multilayer perceptron (MLP) layer.

The input is a manually translated and sharded MLP layer taken from nanoGPT.
See #2199 for where the
initial compute trace comes from.
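For intuition, here is a single-process schematic of the Megatron-style tensor-parallel MLP being sharded; a Python list stands in for the device mesh, and the final `sum` stands in for the all-reduce:

```python
import torch
import torch.nn.functional as F

def tp_mlp(x, w1_shards, w2_shards):
    # w1 is column-sharded: each "device" computes a slice of the hidden
    # activations, so the GELU can be applied locally with no communication.
    hidden = [F.gelu(x @ w1) for w1 in w1_shards]
    # w2 is row-sharded: each "device" produces a partial output of full
    # width; summing the partials is the all-reduce on a real mesh.
    return sum(h @ w2 for h, w2 in zip(hidden, w2_shards))

# Replicated reference vs. sharded result on two "devices":
x = torch.randn(2, 8)
w1, w2 = torch.randn(8, 32), torch.randn(32, 8)
ref = F.gelu(x @ w1) @ w2
out = tp_mlp(x, w1.chunk(2, dim=1), w2.chunk(2, dim=0))
torch.testing.assert_close(out, ref)
```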
wujingyue added a commit to Lightning-AI/lightning-thunder that referenced this issue Jul 29, 2024
For NVIDIA/Fuser#2199.

Both use the GPT-3 sizes to be consistent with the Megatron paper:
https://arxiv.org/pdf/2104.04473

With #541 fixed, I saw one fusion for forward, and one for backward.
Yay!

SDPA is still in decomposed form, pending
NVIDIA/Fuser#2483 and changes to Thunder's
executors.
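For reference, the GPT-3 (175B) sizes listed in those papers; these numbers come from the papers themselves, not from the commit, and the benchmark may use a smaller batch or sequence:

```python
# GPT-3 175B architecture hyperparameters, per https://arxiv.org/pdf/2104.04473
gpt3_config = dict(
    num_layers=96,
    hidden_size=12288,
    num_attention_heads=96,
    ffn_hidden_size=4 * 12288,  # the standard 4x MLP expansion
    sequence_length=2048,
)
```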
wujingyue added a commit that referenced this issue Jul 30, 2024
For #2199.

I've been maintaining the nvFusions and the inputs in a branch. This PR
checks them into nvFuser's main for convenience.

```shell
$ pytest benchmarks/python/test_transformer.py
================================================================================================ test session starts =================================================================================================
platform linux -- Python 3.10.12, pytest-8.1.1, pluggy-1.5.0
Test order randomisation NOT enabled. Enable with --random-order or --random-order-bucket=<bucket_type>
benchmark: 4.0.0 (defaults: timer=time.perf_counter disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=False warmup_iterations=100000)
rootdir: /opt/pytorch/nvfuser
plugins: xdist-3.6.1, hypothesis-6.104.2, timestamper-0.0.10, cov-5.0.0, timeout-2.3.1, random-order-1.1.1, benchmark-4.0.0, shard-0.1.2
collected 2 items
Running 2 items in this shard

benchmarks/python/test_transformer.py ..                                                                                                                                                                       [100%]

--------------------------------------------------------------------------------------- benchmark: 2 tests --------------------------------------------------------------------------------------
Name (time in ms)                  Min                 Max                Mean            StdDev              Median               IQR            Outliers      OPS            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_transformer_forward       54.7508 (1.0)       72.9630 (1.0)       67.7275 (1.0)      4.8153 (1.38)      68.7300 (1.0)      2.0469 (1.01)          2;2  14.7650 (1.0)          10           1
test_transformer_backward     174.6965 (3.19)     187.7991 (2.57)     183.9202 (2.72)     3.4975 (1.0)      184.6459 (2.69)     2.0344 (1.0)           2;1   5.4371 (0.37)         10           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

Legend:
  Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile.
  OPS: Operations Per Second, computed as 1 / Mean
================================================================================================= 2 passed in 8.56s ==================================================================================================
```
wujingyue added a commit that referenced this issue Aug 10, 2024
For #2199

```
$ pytest benchmarks/python/test_transformer.py
```

Before:
```
--------------------------------------------------------------------------------------- benchmark: 2 tests --------------------------------------------------------------------------------------
Name (time in ms)                  Min                 Max                Mean            StdDev              Median               IQR            Outliers      OPS            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_transformer_forward       53.0883 (1.0)       69.7684 (1.0)       65.8204 (1.0)      6.0816 (1.62)      68.9426 (1.0)      4.2709 (2.16)          2;2  15.1929 (1.0)          10           1
test_transformer_backward     174.3857 (3.28)     187.1334 (2.68)     184.6143 (2.80)     3.7561 (1.0)      185.1308 (2.69)     1.9769 (1.0)           1;1   5.4167 (0.36)         10           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
```

After:
```
--------------------------------------------------------------------------------------- benchmark: 2 tests --------------------------------------------------------------------------------------
Name (time in ms)                  Min                 Max                Mean            StdDev              Median               IQR            Outliers      OPS            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_transformer_forward       53.3807 (1.0)       66.7263 (1.0)       63.6231 (1.0)      3.7131 (1.15)      64.7397 (1.0)      1.0460 (1.0)           1;2  15.7176 (1.0)          10           1
test_transformer_backward     160.4337 (3.01)     171.0229 (2.56)     168.4271 (2.65)     3.2160 (1.0)      169.6143 (2.62)     3.7713 (3.61)          1;1   5.9373 (0.38)         10           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
```
wujingyue added a commit that referenced this issue Aug 16, 2024
For #2199

Thanks to Lightning-AI/lightning-thunder#951,
I'm now able to generate microbenchmarks with SDPA nodes!

```
$ pytest benchmarks/python/test_transformer.py
```

Before:
```
--------------------------------------------------------------------------------------- benchmark: 2 tests --------------------------------------------------------------------------------------
Name (time in ms)                  Min                 Max                Mean            StdDev              Median               IQR            Outliers      OPS            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_transformer_forward       53.0883 (1.0)       69.7684 (1.0)       65.8204 (1.0)      6.0816 (1.62)      68.9426 (1.0)      4.2709 (2.16)          2;2  15.1929 (1.0)          10           1
test_transformer_backward     174.3857 (3.28)     187.1334 (2.68)     184.6143 (2.80)     3.7561 (1.0)      185.1308 (2.69)     1.9769 (1.0)           1;1   5.4167 (0.36)         10           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
```

After:
```
--------------------------------------------------------------------------------------- benchmark: 2 tests --------------------------------------------------------------------------------------
Name (time in ms)                  Min                 Max                Mean            StdDev              Median               IQR            Outliers      OPS            Rounds  Iterations
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_transformer_forward       53.3807 (1.0)       66.7263 (1.0)       63.6231 (1.0)      3.7131 (1.15)      64.7397 (1.0)      1.0460 (1.0)           1;2  15.7176 (1.0)          10           1
test_transformer_backward     160.4337 (3.01)     171.0229 (2.56)     168.4271 (2.65)     3.2160 (1.0)      169.6143 (2.62)     3.7713 (3.61)          1;1   5.9373 (0.38)         10           1
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
```
wujingyue added a commit that referenced this issue Sep 25, 2024
This is a fresh dump from the latest
https://github.com/Lightning-AI/lightning-thunder/tree/wjy/sharded.

The main differences are:
1. The code size of the fusion is cut in half because shapes are inlined.
2. Two more outputs are added and therefore cached for backprop.

For #2199.
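A hypothetical minimal illustration of what "shapes are inlined" means for the generated definition: concrete sizes replace symbolic `-1` extents, which removes most of the size-computation lines from the printed Python (the actual dump is much larger):

```python
from nvfuser import FusionDefinition, DataType

with FusionDefinition() as fd:
    # Old dump: symbolic extents, resolved at runtime from the actual inputs.
    # t0 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True],
    #                       dtype=DataType.BFloat16)
    # New dump: concrete sizes baked into the definition.
    t0 = fd.define_tensor(shape=[16, 128, 1600], contiguity=[True, True, True],
                          dtype=DataType.BFloat16)
    fd.add_output(fd.ops.cast(t0, DataType.Float))
```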
@wujingyue wujingyue mentioned this issue Oct 5, 2024
wujingyue added a commit that referenced this issue Oct 9, 2024
For #2199. 

This PR only shards the MLP. MHA will come in a separate PR (#3115) to
keep changes small and incremental.
wujingyue added a commit that referenced this issue Oct 17, 2024
All tensors are replicated to all devices to start with. Future PRs will
try to shard them.

For #2199.
wujingyue added a commit that referenced this issue Oct 22, 2024
This PR tries to parallelize the inputs according to
https://arxiv.org/pdf/1909.08053. `propagate_shardings` then propagates the
parallelization to intermediate tensors and outputs.

Fixes #2199.
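A conceptual sketch of that propagation; this is not nvFuser's implementation, just a hypothetical graph walk that pushes input sharding annotations through elementwise ops in topological order:

```python
def propagate_shardings(ops, input_shardings):
    """ops: iterable of (inputs, outputs) name tuples in topological order.
    input_shardings: name -> sharded axis (or None for replicated)."""
    shardings = dict(input_shardings)
    for inputs, outputs in ops:
        axes = {shardings.get(name) for name in inputs} - {None}
        # For elementwise ops, sharded producers must agree on the axis;
        # the consumers then inherit it. Anything else stays replicated here.
        axis = axes.pop() if len(axes) == 1 else None
        for name in outputs:
            shardings[name] = axis
    return shardings

# Example: a batch-sharded input flows through two elementwise ops.
print(propagate_shardings(
    [(["x"], ["y"]), (["y", "w"], ["z"])],
    {"x": 0, "w": None},
))  # {'x': 0, 'w': None, 'y': 0, 'z': 0}
```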