initial cherry-picked from Lightning-AI#250
jjsjann123 committed Jun 4, 2024
1 parent 6850202 commit 5133271
Showing 11 changed files with 119 additions and 51 deletions.
2 changes: 1 addition & 1 deletion thunder/__init__.py
@@ -273,7 +273,7 @@ def jit(
executors: None | Sequence[Executor | str] = None,
sharp_edges: None | SHARP_EDGES_OPTIONS | str = None,
interpretation: None | INTERPRETATION_OPTIONS | str = None,
cache: None | CACHE_OPTIONS | str = None,
cache: None | CACHE_OPTIONS | str = CACHE_OPTIONS.SYMBOLIC_VALUES,
disable_torch_autograd: bool = False, # TODO Revisit this UX for RC1
early_transforms: list | None = None,
additional_transforms: list | None = None,
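Note: the hunk above only changes the default; callers can still choose the cache mode per call. A minimal usage sketch, assuming the public thunder.jit entry point and assuming the string spelling "constant values" resolves to the corresponding CACHE_OPTIONS member (the spelling is an assumption, not part of this diff):

    import torch
    import thunder

    def fn(x, scale):
        return x * scale

    # With this commit the default is symbolic values; passing cache= makes the
    # choice explicit and restores the previous behavior if needed.
    jfn_sym = thunder.jit(fn)                             # symbolic values (new default)
    jfn_const = thunder.jit(fn, cache="constant values")  # previous default, assumed spelling
    out = jfn_sym(torch.randn(2, 2), 3.0)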
3 changes: 2 additions & 1 deletion thunder/core/baseutils.py
@@ -130,7 +130,8 @@ def check_valid_length(length: int):

# maybe we should skip the check for IntegerProxy in general
check_type(length, (int, NumberProxyInterface))
check(length >= 0, lambda: f"Found invalid length {length}!")
if isinstance(length, int):
check(length >= 0, lambda: f"Found invalid length {length}!")


def check_valid_shape(shape: tuple[int, ...] | list[int]):
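Note: a rough, self-contained illustration of the guard above: concrete ints are still validated eagerly, while proxy-like lengths pass through so constraints can be handled later by the caching machinery. FakeLengthProxy is a stand-in, not a thunder class.

    class FakeLengthProxy:
        """Stand-in for a NumberProxy-like length whose value is unknown at trace time."""

    def check_valid_length_sketch(length):
        assert isinstance(length, (int, FakeLengthProxy))
        if isinstance(length, int):
            assert length >= 0, f"Found invalid length {length}!"

    check_valid_length_sketch(5)                  # concrete: validated eagerly
    check_valid_length_sketch(FakeLengthProxy())  # proxied: validation deferred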
14 changes: 11 additions & 3 deletions thunder/core/jit_ext.py
@@ -20,8 +20,7 @@
import inspect
import time

from thunder.core.compile_data import compile_data_and_stats, get_cache_option, using_symbolic_values, get_compile_data
from thunder.core.symbol import bsym_header
from thunder.core.compile_data import compile_data_and_stats, get_cache_option, get_compile_data
import thunder.clang as clang
import thunder.core.transforms

@@ -595,6 +594,9 @@ def proxify(self, value: WrappedValue) -> Any:
co: CACHE_OPTIONS = get_cache_option()
if co is CACHE_OPTIONS.CONSTANT_VALUES:
self.add_constraint((clang.check_tensor_shape_and_metadata, p_orig))
elif co is CACHE_OPTIONS.SYMBOLIC_VALUES:
# TODO: establish guarding logic to allow non-broadcast shape change
self.add_constraint((clang.check_tensor_shape_and_metadata, p_orig))
elif co not in (CACHE_OPTIONS.SAME_INPUT, CACHE_OPTIONS.NO_CACHING):
raise NotImplementedError(f"Unsupported cache option {co}")
return p
@@ -612,6 +614,10 @@ def proxify(self, value: WrappedValue) -> Any:
self.add_constraint((clang.check_string_value, p, uvalue))
else:
self.add_constraint((clang.check_number_type_and_value, p, uvalue))
elif co is CACHE_OPTIONS.SYMBOLIC_VALUES:
# TODO: establish guarding logic
if p is not uvalue:
value.register_proxy(p)
elif co not in (CACHE_OPTIONS.SAME_INPUT, CACHE_OPTIONS.NO_CACHING):
raise NotImplementedError(f"Unsupported cache option {co}")
return p
@@ -1546,7 +1552,7 @@ def thunder_general_jit(
)

co: CACHE_OPTIONS = get_cache_option()
if co not in {CACHE_OPTIONS.CONSTANT_VALUES, CACHE_OPTIONS.NO_CACHING}:
if co not in {CACHE_OPTIONS.CONSTANT_VALUES, CACHE_OPTIONS.NO_CACHING, CACHE_OPTIONS.SYMBOLIC_VALUES}:
raise NotImplementedError(f"Only constant constraints is supported")

prologue_trace: TraceCtx = TraceCtx(fn)
@@ -1578,6 +1584,7 @@
record_history=record_history,
)

# NOTE(jiej): numbers are baked in as constant here vvv
with general_jit_ctx(ctx):
with tracectx(computation_trace):
result = jfn(*args, **kwargs)
@@ -1613,6 +1620,7 @@
else:
epilogue_trace = None

# NOTE(jiej): prologue trace is produced here vvv
pro_to_comp_proxies, pro_to_epi_proxies = unpack_inputs(
ctx, prologue_trace, pro_to_comp, pro_to_epi, args, kwargs, has_epilogue=epilogue_trace is not None
)
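Note: a condensed, self-contained sketch of the per-cache-option dispatch that proxify performs for scalar inputs after this change; the enum and the constraint/registration callables are stand-ins for the real objects in the hunks above, not thunder's API.

    from enum import Enum, auto

    class CacheOption(Enum):
        CONSTANT_VALUES = auto()
        SYMBOLIC_VALUES = auto()
        SAME_INPUT = auto()
        NO_CACHING = auto()

    def constrain_number(co, proxy, concrete_value, add_constraint, register_proxy):
        if co is CacheOption.CONSTANT_VALUES:
            # guard on the exact value: the prologue re-checks it on every call
            add_constraint(("check_number_type_and_value", proxy, concrete_value))
        elif co is CacheOption.SYMBOLIC_VALUES:
            # keep the value symbolic: substitute the proxy, guards still TODO upstream
            register_proxy(proxy)
        elif co not in (CacheOption.SAME_INPUT, CacheOption.NO_CACHING):
            raise NotImplementedError(f"Unsupported cache option {co}")

    constraints = []
    constrain_number(CacheOption.CONSTANT_VALUES, "p_alpha", 2.0, constraints.append, lambda p: None)
    assert constraints == [("check_number_type_and_value", "p_alpha", 2.0)]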
5 changes: 4 additions & 1 deletion thunder/core/prims.py
@@ -65,7 +65,6 @@ def register_method(method_name: str, method: Callable, /) -> None:
from thunder.core.proxies import (
CollectionProxy,
TensorProxy,
IntegerProxy,
NumberProxy,
is_proxyable,
proxy,
@@ -77,6 +76,7 @@ def register_method(method_name: str, method: Callable, /) -> None:
TupleProxy,
AnyProxy,
IntegerProxy,
unwrap_number_proxy,
)
import thunder.core.codeutils as codeutils
from thunder.core.codeutils import Printable
@@ -88,6 +88,7 @@ def register_method(method_name: str, method: Callable, /) -> None:
from thunder.core.trace import get_tracectx
from thunder.core.langctxs import langctx, LanguageContext, register_langctx, Languages


#
# Primitives and helpers for defining them
#
@@ -2610,6 +2611,7 @@ def _exogenous_like_meta(likes: Sequence[TensorProxy], /) -> tuple[TensorProxy]:
# Logically these tensors are constructed intermediate to a trace, so there's no mechanism for a user to
# extract their grad, but we could support compiling forward and backward and accessing grad attributes
# in the future
@unwrap_number_proxy
def _full_meta(shape: Sequence[int], fill_value: Number, *, device: devices.Device, dtype: dtypes.dtype) -> TensorProxy:
# Checks inputs
utils.check_type(fill_value, (Number, NumberProxy))
@@ -3383,6 +3385,7 @@ def transpose_meta(a: TensorProxy, /, permutation: tuple[int, ...]) -> TensorProxy:
view = make_prim(PrimIDs.VIEW, "view", meta=reshape_meta, tags=(OpTags.SHAPE_OP,))


@unwrap_number_proxy
def unfold_meta(a: TensorProxy, /, dim: int, size: int, step: int) -> TensorProxy:
dim = utils.canonicalize_dim(a.ndim, dim)
max_size = 1 if a.ndim == 0 else a.shape[dim]
31 changes: 21 additions & 10 deletions thunder/core/proxies.py
@@ -5,7 +5,7 @@
from typing import Type, Optional, Any, Tuple, List, Union
from collections.abc import Callable
from collections.abc import Sequence
from functools import reduce, partial
from functools import reduce, partial, wraps
import operator
import builtins
import math
@@ -389,6 +389,10 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return f"<StringProxy '{self.value}'>"

def replace_name(self, name: str, /):
"""Return a copy of this proxy with the given name."""
return StringProxy(self.value, name=name, history=self.history)

def type_string(self) -> str:
return "str"

@@ -901,7 +905,19 @@ def __ixor__(self, other):
NumberLike = Number | NumberProxy


def pyval(x: Number | str | AnyProxy) -> Number | str | any:
def unwrap_number_proxy(func):
@wraps(func)
def with_pyval(*args, **kwargs):
args = [pyval(arg) if isinstance(arg, NumberProxy) else arg for arg in args]
for k, v in kwargs.items():
if isinstance(v, NumberProxy):
kwargs[k] = pyval(v)
return func(*args, **kwargs)

return with_pyval


def pyval(x: NumberLike | str | AnyProxy) -> Number | str | any:
baseutils.check_type(x, (NumberProxy, Number, str, AnyProxy))

if isinstance(x, AnyProxy):
@@ -1041,14 +1057,9 @@ def _infer_tensor_properties(
thunder_fsdp_padding_size if thunder_fsdp_padding_size is not None else _thunder_fsdp_padding_size
)

# Extracts actual values for shape
# TODO RC1 Enable this
if using_symbolic_values():
raise NotImplementedError(
f"Trying to construct a tensor proxy while using symbolic values, but this is not yet supported"
)

_shape = tuple(pyval(x) for x in _shape)
if not using_symbolic_values():
# Extracts actual values for shape
_shape = tuple(pyval(x) for x in _shape)

# Computes derived properties
_numel = reduce(operator.mul, _shape, 1)
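Note: a minimal, runnable illustration of how the new unwrap_number_proxy decorator (applied to _full_meta and unfold_meta in thunder/core/prims.py above) behaves; FakeNumberProxy and the local pyval are stand-ins so the snippet runs without thunder, but the unwrapping mirrors the definition above.

    from functools import wraps

    class FakeNumberProxy:
        def __init__(self, value):
            self.value = value

    def pyval(x):  # stand-in for thunder.core.proxies.pyval
        return x.value

    def unwrap_number_proxy(func):
        @wraps(func)
        def with_pyval(*args, **kwargs):
            args = [pyval(a) if isinstance(a, FakeNumberProxy) else a for a in args]
            kwargs = {k: pyval(v) if isinstance(v, FakeNumberProxy) else v for k, v in kwargs.items()}
            return func(*args, **kwargs)
        return with_pyval

    @unwrap_number_proxy
    def unfold_meta_sketch(dim, size, step):
        # meta functions decorated this way see plain Python ints here
        return dim, size, step

    assert unfold_meta_sketch(FakeNumberProxy(0), FakeNumberProxy(3), step=FakeNumberProxy(1)) == (0, 3, 1)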
15 changes: 15 additions & 0 deletions thunder/core/transforms.py
@@ -1089,9 +1089,24 @@ def _div_prim_grad(a: Number | TensorProxy, b: Number | TensorProxy, /) -> Number | TensorProxy:

register_grad(pids.DIV, _div_prim_grad)


# NOTE not differentiable, but it would trigger a flatten failure without a grad function
# NOTE that's probably a bad error message that we should fix.
@torchctx
def _py_floordiv_prim_grad(a: Number | TensorProxy, b: Number | TensorProxy, /) -> Number | TensorProxy:
fwd = prims.py_floordiv(a, b)

return fwd


register_grad(pids.PY_FLOORDIV, _py_floordiv_prim_grad)

# Comparison operators -- these create no grad associations
register_grad(pids.EQ, prims.eq)
register_grad(pids.GE, prims.ge)
register_grad(pids.GT, prims.gt)
register_grad(pids.NE, prims.ne)
register_grad(pids.LE, prims.le)
register_grad(pids.LT, prims.lt)
register_grad(pids.NE, prims.ne)
register_grad(pids.GT, prims.gt)
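Note: for intuition on the floor-division rule above: the result is piecewise constant in its inputs, so there is no gradient to propagate, and a grad function that only computes the forward keeps the autograd flattening from failing on traces that contain it. A toy finite-difference check of the zero-derivative-almost-everywhere claim:

    a, b, eps = 7.3, 2.0, 1e-4
    slope = ((a + eps) // b - (a - eps) // b) / (2 * eps)
    assert slope == 0.0  # flat between the jump points, where it is not differentiable at all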
6 changes: 3 additions & 3 deletions thunder/core/utils.py
@@ -13,7 +13,7 @@

import thunder.core.dtypes as dtypes
from thunder.core.pytree import tree_flatten, tree_unflatten, tree_map
from thunder.core.proxies import Proxy, NumberProxy, TensorProxy, variableify
from thunder.core.proxies import Proxy, NumberProxy, TensorProxy, variableify, pyval
from thunder.core.baseutils import *
from thunder.core.codeutils import *
from thunder.core.trace import TraceCtx
@@ -566,7 +566,7 @@ def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int:
if rank == 0:
check(
wrap_scalar,
lambda: f"Dimension specified as {idx} but tensor has no dimensions!",
lambda: f"Dimension specified as {pyval(idx)} but tensor has no dimensions!",
exception_type=IndexError,
)
rank = 1
@@ -581,7 +581,7 @@

check(
_idx >= 0 and _idx < rank,
lambda: f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], but got {idx})",
lambda: f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], but got {pyval(idx)})",
exception_type=IndexError,
)

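Note: the only change in this file is that the error messages go through pyval so a proxied idx prints its concrete value. For reference, a self-contained sketch of the wrapping rule the surrounding code implements (the middle step is elided in the hunk and assumed here):

    def canonicalize_dim_sketch(rank, idx, wrap_scalar=True):
        if rank == 0:
            if not wrap_scalar:
                raise IndexError(f"Dimension specified as {idx} but tensor has no dimensions!")
            rank = 1
        _idx = idx + rank if idx < 0 else idx  # assumed wrapping step, not shown in the hunk
        if not (0 <= _idx < rank):
            raise IndexError(
                f"Dimension out of range (expected to be in range of [{-rank}, {rank - 1}], but got {idx})"
            )
        return _idx

    assert canonicalize_dim_sketch(4, -1) == 3
    assert canonicalize_dim_sketch(4, 2) == 2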
37 changes: 25 additions & 12 deletions thunder/executors/nvfuserex_impl.py
@@ -21,6 +21,8 @@
from thunder.core.prims import PrimIDs
from thunder.core.proxies import (
NumberProxy,
IntegerProxy,
StringProxy,
Proxy,
TupleProxy,
TensorProxy,
@@ -246,6 +248,10 @@ def add_input(x: Any, y: Any) -> Any:
# TODO: validate x is a tuple of int
utils.check_type(y, type)
nv = fd.define_vector(len(x._value))
elif isinstance(x, StringProxy):
utils.check_type(y, type)
# TODO: should we add a string type? I think we should reject it here and instead ask thunder to bake in string.
nv = x
elif isinstance(x, Proxy):
utils.check(False, lambda: f"Unsupported proxy type {type(x)} in fusion", exception_type=AssertionError)
else:
@@ -366,7 +372,7 @@ def get_tensor_descriptor(t: torch.Tensor) -> tuple[tuple[int, ...], tuple[bool,
# TODO Inline the get_tensor_descriptor call
def to_descriptors(args) -> tuple:
def to_descriptor(arg):
if isinstance(arg, Number):
if isinstance(arg, (Number, str)):
return type(arg)
elif isinstance(arg, tuple):
if len(arg) != 0:
@@ -408,6 +414,8 @@ def __call__(self, *args):
if nv_version >= LooseVersion("0.0.13") and hasattr(fd, "_selected_device")
else {}
)
# TODO: quick hack to drop str
args = [x for x in args if not isinstance(x, str)]
with add_markers(self.name):
return fd.execute(args, **kwargs)

@@ -954,10 +962,11 @@ def full(
) -> Any:
nv_fill_value = getnv(fill_value, fd, lc_to_nv_map)
nvdtype = lcdtype_to_nvdtype(dtype)
nv_shape = [getnv(i, fd, lc_to_nv_map) for i in shape]

_select_device(fd, device)

return fd.ops.full(shape, nv_fill_value, nvdtype)
return fd.ops.full(nv_shape, nv_fill_value, nvdtype)


register_supported(PrimIDs.FULL, full, _full_check)
@@ -1011,11 +1020,11 @@ def uniform(
nv_minval = getnv(minval, fd, lc_to_nv_map)
nv_maxval = getnv(maxval, fd, lc_to_nv_map)

nvshape = list(getnv(x, fd, lc_to_nv_map) for x in shape)
nv_shape = [getnv(i, fd, lc_to_nv_map) for i in shape]

_select_device(fd, device)

return fd.ops.uniform(nv_minval, nv_maxval, nvshape, dtype=nvdtype)
return fd.ops.uniform(nv_minval, nv_maxval, nv_shape, dtype=nvdtype)


register_supported(PrimIDs.UNIFORM, uniform, _uniform_check)
@@ -1034,8 +1043,8 @@ def _uniform_philox_check(
return (
is_supported_device(device)
and is_supported_dtype(dtype)
and is_supported_tensor_or_number(seed)
and is_supported_tensor_or_number(offset)
and isinstance(seed, (int, IntegerProxy))
and isinstance(offset, (int, IntegerProxy))
)


@@ -1056,7 +1065,7 @@ def uniform_philox(
nv_minval = getnv(minval, fd, lc_to_nv_map)
nv_maxval = getnv(maxval, fd, lc_to_nv_map)

nvshape = list(getnv(x, fd, lc_to_nv_map) for x in shape)
nv_shape = [getnv(i, fd, lc_to_nv_map) for i in shape]

nv_rng_seed = getnv(seed, fd, lc_to_nv_map)
nv_rng_offset = getnv(offset, fd, lc_to_nv_map)
@@ -1066,7 +1075,7 @@
return fd.ops.uniform(
nv_minval,
nv_maxval,
nvshape,
nv_shape,
dtype=nvdtype,
rng_seed=nv_rng_seed,
rng_offset=nv_rng_offset,
@@ -1093,8 +1102,9 @@ def broadcast_in_dim(
a: TensorProxy, shape: list[int], broadcast_dimensions: list[int], *, fd: FusionDefinition, lc_to_nv_map: dict
) -> Any:
nva = getnv(a, fd, lc_to_nv_map)
nv_shape = [getnv(i, fd, lc_to_nv_map) for i in shape]

return fd.ops.broadcast_in_dim(nva, shape, broadcast_dimensions)
return fd.ops.broadcast_in_dim(nva, nv_shape, broadcast_dimensions)


register_supported(PrimIDs.BROADCAST_IN_DIM, broadcast_in_dim, _broadcast_in_dim_check)
@@ -1191,11 +1201,12 @@ def _reshape_check(a: TensorProxy, shape: list[int]) -> bool:

def reshape(a: TensorProxy, shape: list[int], *, fd: FusionDefinition, lc_to_nv_map: dict) -> Any:
nv_a = getnv(a, fd, lc_to_nv_map)
nv_shape = [getnv(i, fd, lc_to_nv_map) for i in shape]

if nv_version < LooseVersion("0.0.22"):
return fd.ops.reshape(nv_a, a.shape, shape)
return fd.ops.reshape(nv_a, a.shape, nv_shape)
else:
return fd.ops.reshape(nv_a, shape)
return fd.ops.reshape(nv_a, nv_shape)


register_supported(PrimIDs.RESHAPE, reshape, _reshape_check)
@@ -1879,7 +1890,9 @@ def where(

# TODO Checks that the dtype is supported by nvFuser
def _reduction_check(a: TensorProxy, dims: Sequence[int]) -> bool:
return is_supported_tensor(a, allow_low_precision_floats=False)
return is_supported_tensor(a, allow_low_precision_floats=False) and not any(
isinstance(dim, NumberProxy) for dim in dims
)


# TODO Review if this accepts empty dim sequences
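Note: the recurring edit in this file routes every shape element through getnv so proxied extents resolve to whatever was registered for them in lc_to_nv_map when the fusion inputs were defined. A stripped-down sketch of that lookup pattern; plain strings stand in for proxies and nvFuser scalars, and this is not nvFuser's API:

    def getnv_sketch(x, lc_to_nv_map):
        # proxies were mapped when inputs were added; anything else passes through
        return lc_to_nv_map.get(x, x)

    lc_to_nv_map = {"i0": "fd_scalar_0"}  # e.g. a symbolic batch dimension
    shape = ["i0", 128]                   # mixed symbolic / concrete shape
    nv_shape = [getnv_sketch(i, lc_to_nv_map) for i in shape]
    assert nv_shape == ["fd_scalar_0", 128]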
(The remaining 3 changed files are not rendered in this view.)
