When working on compiling or exporting code that has data dependent sizes, it's common to run into GuardOnDataDependentSymNode errors. Unfortunately, if it's the first time you've needed to fix one of these problems, it can be quite difficult to get your bearings in a real world model. The purpose of these puzzlers is to give you some practice fixing GuardOnDataDependentSymNode errors in a controlled setting, so you can get a sense of what works and what doesn't, and build your mental model.

I highly recommend reading the [Dealing with GuardOnDataDependentSymNode errors](https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit#heading=h.44gwi83jepaj) document first.  If you have questions or need help, you can reach out to [PT2 Data Dependent Shapes WG](https://fb.workplace.com/groups/6829516587176185).

In [3]:
# pyre-ignore-all-errors

import torch
import torch._dynamo.config
from typing import List, Union, Any
from torch import Tensor
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
from dataclasses import dataclass
import random

import sys
import traceback
import IPython

def customshowtraceback(self, *args, **kwargs):
    """Print the exception currently being handled with ``traceback.print_exc``.

    Installed below as a replacement for IPython's ``showtraceback``; every
    argument it is called with is ignored.
    """
    traceback.print_exc()

# Route IPython's traceback display through the plain printer defined above.
IPython.core.interactiveshell.InteractiveShell.showtraceback = customshowtraceback

# Dynamo settings shared by every exercise in this notebook.
# NOTE(review): presumably these let .item()/.tolist() and dynamic-output-shape
# ops be captured in the graph instead of graph-breaking - confirm against the
# torch._dynamo.config documentation.
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True

# Selects which branch each exercise runs: True takes the `if TEMPLATE:`
# branches, False takes the `else:` branches.
TEMPLATE = False
# Printed by run_test after a test function returns without raising.
MSG = "🎉 Correct! 🎉"

@dataclass
class U:
    """Test-only marker for a size that should reach compiled code as tensor data.

    ``pack_size`` turns a ``U`` into a tensor holding ``val``.
    """

    # Describes an unbacked SymInt whose real value is val, for testing
    val: int

# A size as written in a test: a plain int, or U(...) wrapping the real value.
UnpackedTy = Union[int, U]
# A size as handed to the function under test: a plain int, or the Tensor
# produced by pack_size.
PackedTy = Union[int, Tensor]

def assert_eq(actual, expected):
    """Check that ``actual`` matches ``expected``; raise if it does not.

    Three cases, tried in this order:

    * either side is a Tensor: both must be Tensors, compared with
      ``torch.allclose``;
    * either side is a list: both must be lists, elements are compared
      pairwise (recursively) and the lengths must match;
    * anything else: both must have exactly the same type and ``==`` must
      return a plain ``bool``.

    NOTE(review): ``torch.allclose`` broadcasts its arguments, so the tensor
    case does not compare shapes.

    Raises:
        AssertionError: the two sides are of different kinds/types.
        RuntimeError: the two sides are comparable but not equal.
    """
    either_is_tensor = isinstance(actual, Tensor) or isinstance(expected, Tensor)
    either_is_list = isinstance(actual, list) or isinstance(expected, list)

    if either_is_tensor:
        assert isinstance(actual, Tensor), actual
        assert isinstance(expected, Tensor), expected
        matches = torch.allclose(actual, expected)
    elif either_is_list:
        assert isinstance(actual, list), actual
        assert isinstance(expected, list), expected
        # Compare the common prefix first; a length mismatch is reported below.
        for pair in zip(actual, expected):
            assert_eq(*pair)
        matches = len(actual) == len(expected)
    else:
        assert type(actual) == type(expected), (actual, expected)
        matches = actual == expected
        assert isinstance(matches, bool), (actual, expected)

    if matches:
        return
    raise RuntimeError(f"{actual} != {expected}")

def pack_size(s: UnpackedTy) -> PackedTy:
    """Convert one test-side size into the form passed to the function under test.

    A plain value is returned unchanged; a ``U`` becomes a two-element tensor
    with ``val`` in both slots.
    """
    if not isinstance(s, U):
        return s
    return torch.tensor([s.val] * 2)

def pack_sizes(shape: List[UnpackedTy]) -> List[PackedTy]:
    """Apply ``pack_size`` to every entry of ``shape``, preserving order.

    Args:
        shape: sizes as written in a test (ints and/or ``U`` markers).

    Returns:
        A new list with one packed entry per input entry.
    """
    return [pack_size(s) for s in shape]

# NB: the unpack functions will be Dynamo traced, so they're not really inverses
def unpack_size(s: PackedTy) -> int:
    """Turn a packed size back into an integer usable as a size.

    Non-tensor input is returned unchanged. For a tensor, element 0 is read
    with ``.item()`` and passed to ``torch._check_is_size`` before being
    returned.
    """
    if not isinstance(s, Tensor):
        return s
    size = s[0].item()
    torch._check_is_size(size)
    return size

def unpack_sizes(ss: List[PackedTy]) -> List[int]:
    """Apply ``unpack_size`` to every entry of ``ss``, preserving order."""
    return [unpack_size(s) for s in ss]

def unpack_tensor(ss: List[PackedTy]) -> Tensor:
    """Return a ``torch.randn`` tensor whose shape is the unpacked ``ss``."""
    return torch.randn(unpack_sizes(ss))

def force_guard(x: bool):
    """Branch on the truth value of ``x`` and return it as a bool tensor.

    The explicit branch (rather than ``torch.tensor(x)``) is what makes the
    caller's condition get evaluated as a Python bool.
    """
    return torch.tensor(True) if x else torch.tensor(False)

def run_test(fn):
    """Decorator that runs a test immediately at definition time.

    Resets Dynamo state, calls ``fn`` with no arguments, and prints ``MSG`` if
    it returns without raising.

    Returns:
        ``fn`` itself, so the decorated name stays bound to the test function
        (and can be re-run by hand) instead of being rebound to ``None``.
    """
    # Start every test from a clean compile cache so tests don't affect each other.
    torch._dynamo.reset()
    fn()
    print(MSG)
    return fn

## Basic symbolic reasoning

Here are some basic warmups to test your understanding of [Dealing with GuardOnDataDependentSymNode errors](https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit#heading=h.44gwi83jepaj), and familiarize you with how this notebook is set up.

**[check]** Use torch._check to resolve the compile time error below.

In [95]:
@torch.compile(backend="eager", fullgraph=True)
def cf_check(x):
    """[check] exercise: branch on ``u0 * 2 == u1 * 3`` for two ints read from ``x``."""
    # Both values are read out of tensor data, so the comparison below is
    # data dependent.
    u0, u1 = x.tolist()
    if TEMPLATE:
        pass
    else:
        # Solution: assert the equality with torch._check before it is branched on.
        torch._check(u0 * 2 == u1 * 3)
    # Do not modify the code below here (imagine it's in framework code you can't edit)
    # NB: In future exercises, we'll use force_guard as a shorthand for this pattern.
    if u0 * 2 == u1 * 3:
        return torch.tensor(True)
    else:
        return torch.tensor(False)

@run_test
def test_check():
    # 12 * 2 == 8 * 3, so the checked condition holds for this input.
    assert cf_check(torch.tensor([12, 8])).item()

**[checkand]**  It is best not to use logical conjunction inside `torch._check` operators.  Modify the torch._check below so that compilation passes.

In [96]:
@torch.compile(backend="eager", fullgraph=True)
def cf_checkand(x):
    """[checkand] exercise: assert two data-dependent facts without ``and`` inside torch._check."""
    u0, u1 = x.tolist()
    if TEMPLATE:
        # Original exercise code: both conditions joined with `and` in one call.
        torch._check((u0 // 3 == 0) and (u1 // 5 == 0))
    else:
        # Solution: one torch._check per conjunct.
        torch._check(u0 // 3 == 0)
        torch._check(u1 // 5 == 0)
    # Do not modify the code below
    return force_guard((u0 // 3 == 0) and (u1 // 5 == 0))

@run_test
def test_checkand():
    # 2 // 3 == 0 and 4 // 5 == 0, so both checked conditions hold.
    # Read the result with .item(), consistent with the other tests here.
    assert cf_checkand(torch.tensor([2, 4])).item()

**[checksize]** Marking a variable as size-like means we will assume it is `>= 2` inside size oblivious guards, but we will still permit it to have 0/1 value at runtime.

In [97]:
@torch.compile(backend="eager", fullgraph=True)
def cf_checksize(x):
    """[checksize] exercise: evaluate ``u0 != 0`` under guard_size_oblivious."""
    u0 = x[0].item()
    # For extra credit, mark u0 as size-like WITHOUT explicitly calling torch._check_is_size
    if TEMPLATE:
        pass
    else:
        # Solution: explicitly mark u0 as size-like.
        torch._check_is_size(u0)
    return force_guard(guard_size_oblivious(u0 != 0))

@run_test
def test_checksize():
    # Both inputs are expected to report True, including the one whose real value is 0.
    assert cf_checksize(torch.tensor([5, 5])).item()
    assert cf_checksize(torch.tensor([0, 0])).item()  # this is true!  You can make some strange things happen if you combine this with torch._check(u0 == 0)

**[implicitsize]** Some APIs do not implicitly specify a size is size-like.  So you may need to explicitly specify it in those situations.  When this is not a bug in PyTorch, this is typically because the API accepts sentinel values like -1 and you have to explicitly indicate that this is not a permitted input.

In [98]:
@torch.compile(backend="eager", fullgraph=True)
def cf_implicitsize(x, y):
    """[implicitsize] exercise: index ``y`` with an integer read from ``x[0]``."""
    u0 = x[0].item()
    if TEMPLATE:
        pass
    else:
        # Solution: indexing does not mark u0 as size-like by itself, so say it explicitly.
        torch._check_is_size(u0)
    # Upper bound: the index must be inside y's first dimension.
    torch._check(u0 < y.size(0))
    return y[u0]

@run_test
def test_implicitsize():
    # Use a deterministic input so the assertion checks the selected element
    # instead of relying on a random value happening to be nonzero.
    y = torch.arange(10.0)
    assert cf_implicitsize(torch.tensor([2, 2]), y).item() == 2.0

**[nomemo]** Some APIs in PyTorch accept Tensor directly where int is expected. This results in an implicit item() call. Although sometimes we memoize the returned unbacked SymInts, in general it is better to make sure you call item() once, and use the resulting unbacked SymInt consistently through the rest of the program.  Fix the program below.  Note that the commented `torch._check_is_size` does not work (why?)

In [99]:
@torch.compile(fullgraph=True, backend="eager")
def cf_nomemo(x, y):
    """[nomemo] exercise: expand ``x`` to a width read from ``y[0]``."""
    if TEMPLATE:
        # Original exercise code: passes the Tensor y[0] where an int is expected.
        # torch._check_is_size(y[0])
        return x.unsqueeze(1).expand(-1, y[0])
    else:
        # Solution: call item() once and reuse the same integer for the check
        # and for the expand.
        u0 = y[0].item()
        torch._check_is_size(u0)
        return x.unsqueeze(1).expand(-1, u0)

@run_test
def test_nomemo():
    # Only checks that the call completes; the result is not inspected.
    cf_nomemo(torch.randn(8), torch.tensor([2]))

## What's in a tensor constructor?

The very first guards you are likely to encounter with data-dependent shapes are those associated with tensor constructors.  Modern PyTorch has been patched to work with data-dependent shapes; in this section, we reimplement various important functions that are called during tensor construction, to make them data-dependent shape friendly.  If you don't feel comfortable working with strides, consider reading the [PyTorch Internals](http://blog.ezyang.com/2019/05/pytorch-internals/) blog post.  We have already imported `guard_size_oblivious` so that it is in scope.

**[contigstride]**  When you call a constructor like torch.empty() which produces a contiguous tensor, we must compute the contiguous strides for the tensor. Intuitively, the stride of a contiguous tensor is the product of all sizes to the right of it; e.g., a tensor with size `[2, 3, 4]` has stride `[3*4, 4, 1]`.

This calculation has a subtle special case for zero sized tensors (`t.numel() == 0`). Ordinarily, a stride tells us how much we must advance the physical pointer of a tensor to access the next element.  But in a zero element tensor, there is no next element: this means we could validly put whatever we want in the stride of a tensor.  The convention that Numpy and PyTorch have for this situation is, we compute the strides as normal, but if a size is zero, we treat it as if it had size one for the purpose of stride computation, preventing us from zeroing out any subsequent strides.

Below, we've reproduced `make_contiguous_strides_for` function in PyTorch.  Modify it so that it works with data dependent sizes.

In [100]:
# Change this FUNCTION only
def make_contiguous_strides_for(shape: List[int]) -> List[int]:
    """Compute the strides a contiguous tensor of the given ``shape`` would have.

    The sizes are walked from last to first. Each dimension's stride is the
    running product of the sizes after it, where a size below one contributes
    a factor of one instead.

    Args:
        shape: the sizes; each is passed to ``torch._check_is_size``.

    Returns:
        One stride per dimension, in the same order as ``shape``.
    """
    running = 1
    back_to_front = []
    for dim_size in reversed(shape):
        torch._check_is_size(dim_size)
        back_to_front.append(running)
        if TEMPLATE:
            # Original exercise code: branches on the size.
            if dim_size >= 1:
                running = running * dim_size
        else:
            # Solution: same arithmetic with no branch on the size.
            running = running * max(dim_size, 1)

    return back_to_front[::-1]

def f_contigstride(pxs):
    """Build an ``empty_strided`` tensor from packed sizes using the strides computed above."""
    xs = unpack_sizes(pxs)
    return torch.empty_strided(xs, make_contiguous_strides_for(xs))

# Compiled twin of f_contigstride; the test calls both.
cf_contigstride = torch.compile(fullgraph=True, backend="eager")(f_contigstride)

@run_test
def test_contigstride():
    # All-nonzero sizes: checked in both eager and compiled mode.
    inp1, ans1 = pack_sizes([U(2), U(3), U(4)]), (12, 4, 1)
    assert_eq(f_contigstride(inp1).stride(), ans1)
    assert_eq(cf_contigstride(inp1).stride(), ans1)

    # A zero size in the middle: only the eager result is checked.
    inp2, ans2 = pack_sizes([U(2), U(0), U(4)]), (4, 4, 1)
    assert_eq(f_contigstride(inp2).stride(), ans2)
    # Extra credit: can you make this test pass?
    # assert_eq(cf_contigstride(inp2).stride(), ans2)

**[manualcontig]** Another thing we must do on tensor construction is compute if it is contiguous or not.  Although a call to torch.empty is obviously contiguous, constructors like torch.empty_strided allow construction of tensors with arbitrary strides, so we need an algorithm `is_contiguous` that can take the sizes and strides of a tensor and determine if it is contiguous.

Modify the function below so that it works with data dependent sizes.

In [101]:
# Change this FUNCTION only
def is_contiguous(a: Tensor) -> bool:
    """Decide from ``a``'s sizes and strides whether it is contiguous.

    Tensors with fewer than two elements count as contiguous. Otherwise the
    dimensions are walked from last to first, skipping dimensions of length 1,
    and each remaining stride must equal the product of the sizes already
    visited.

    The ``TEMPLATE`` branch uses plain comparisons; the other branch wraps the
    same comparisons in ``guard_size_oblivious``.
    """
    if TEMPLATE:
        if a.numel() < 2:
            return True

        expected_stride = 1
        for x, y in reversed(tuple(zip(a.shape, a.stride()))):
            # Skips checking strides when a dimension has length 1
            if x == 1:
                continue

            if y != expected_stride:
                return False
            expected_stride = expected_stride * x

        return True
    else:
        if guard_size_oblivious(a.numel() < 2):
            return True

        expected_stride = 1
        # x is the size and y the stride of the dimension being visited.
        for x, y in reversed(tuple(zip(a.shape, a.stride()))):
            # Skips checking strides when a dimension has length 1
            if guard_size_oblivious(x == 1):
                continue

            if guard_size_oblivious(y != expected_stride):
                return False
            expected_stride = expected_stride * x

        return True

def f_manualcontig(x):
    """Run ``is_contiguous`` on a fresh tensor of the packed shape ``x``; result as a bool tensor."""
    return force_guard(is_contiguous(unpack_tensor(x)))

# Compiled twin of f_manualcontig.
cf_manualcontig = torch.compile(fullgraph=True, backend="eager")(f_manualcontig)

@run_test
def test_manualcontig():
    # A freshly created randn tensor of shape [2, 3, 4] should be reported contiguous.
    inp1 = pack_sizes([U(2), U(3), U(4)])
    assert cf_manualcontig(inp1).item()

**[transposecontig]** In PyTorch, `is_contiguous` is implemented with `guard_size_oblivious`.  The use of `guard_size_oblivious` means that sometimes the semantics of plain and compiled programs can diverge.  The most obvious consequence of this is that with unbacked SymInts, we will generally conclude that tensors are noncontiguous, even if there are specific values for which they might actually be contiguous.  For example, we always consider a transposed 2D tensor of size `[u0, u1]` non-contiguous.

In the test below, change `inp1` to show a 2D tensor which is contiguous in eager mode even after transposition.


In [102]:
def f_transposecontig(sizes):
    """Create zeros of shape ``sizes.tolist()``, transpose, and report ``is_contiguous()`` as a bool tensor."""
    x = torch.zeros(sizes.tolist()).T
    return force_guard(x.is_contiguous())

# Compiled twin of f_transposecontig; the test compares the two results.
cf_transposecontig = torch.compile(fullgraph=True, backend="eager")(f_transposecontig)

@run_test
def test_transposecontig():
    # Change this line ONLY
    if TEMPLATE:
        inp1 = torch.tensor([3, 4])
    else:
        # Solution input: one of the two sizes is zero.
        inp1 = torch.tensor([0, 4])

    # The test passes only when eager and compiled disagree on contiguity.
    assert not torch.allclose(f_transposecontig(inp1), cf_transposecontig(inp1)), f"{f_transposecontig(inp1)} == {cf_transposecontig(inp1)}"

A similar version of the problem above also occurs when answering `is_contiguous(memory_format=torch.channels_last)`, since NCHW contiguous (the default) tensors are typically not NHWC contiguous, except under the same conditions as the input above.

## Pointwise operations: broadcasting

Whenever you perform a pointwise operation between two tensors, we must compute the output shape of the tensor. Because PyTorch supports implicit broadcasting, this is not simply "assert the two input tensors have the same size and produce a tensor of the same size."  Instead, whenever you compare two dimensions for equality, you allow for a size mismatch if one of the dimensions is size one.   Unlike the previous examples, you will need to make more changes than just slapping `guard_size_oblivious` on the conditions.

In [103]:
def infer_size(a: List[int], b: List[int]) -> List[int]:
    """Compute the broadcast shape of shapes ``a`` and ``b``.

    Dimensions are aligned from the right; a dimension missing from the
    shorter shape is treated as size 1. Every size is passed to
    ``torch._check_is_size``.

    Raises:
        RuntimeError: in the ``TEMPLATE`` branch, when two sizes differ and
            neither is 1. The other branch states the same requirement with
            ``torch._check`` instead.
    """
    dimsA = len(a)
    dimsB = len(b)
    ndim = max(dimsA, dimsB)
    expandedSizes = [0] * ndim
    for i in range(ndim - 1, -1, -1):
        # offset counts dimensions from the right; dimA/dimB go negative once
        # the corresponding shape has run out of dimensions.
        offset = ndim - 1 - i
        dimA = dimsA - 1 - offset
        dimB = dimsB - 1 - offset
        sizeA = a[dimA] if dimA >= 0 else 1
        sizeB = b[dimB] if dimB >= 0 else 1
        torch._check_is_size(sizeA)
        torch._check_is_size(sizeB)
        if TEMPLATE:
            # Original exercise code: tests equality first, then the size-1 cases.
            if sizeA == sizeB:
                expandedSizes[i] = sizeA
            elif sizeA == 1:
                expandedSizes[i] = sizeB
            elif sizeB == 1:
                expandedSizes[i] = sizeA
            else:
                raise RuntimeError("shape mismatch")
        else:
            # Solution: test the size-1 cases size-obliviously first, then
            # require equality with torch._check rather than branching on it.
            if guard_size_oblivious(sizeA == 1):
                expandedSizes[i] = sizeB
            elif guard_size_oblivious(sizeB == 1):
                expandedSizes[i] = sizeA
            else:
                torch._check(sizeA == sizeB, "shape mismatch")
                expandedSizes[i] = sizeA
    return expandedSizes

def f_infersize(sz1, sz2):
    """Return zeros whose shape is the broadcast of the two packed shapes."""
    return torch.zeros(infer_size(unpack_sizes(sz1), unpack_sizes(sz2)))

# Compiled twin of f_infersize.
cf_infersize = torch.compile(fullgraph=True, backend="eager")(f_infersize)

@run_test
def test_infersize():
    # Every case below expects the broadcast shape [2, 3, 4].
    # Some baseline tests: these should pass even without changes
    assert_eq(cf_infersize(pack_sizes([2, 3, 4]), pack_sizes([2, 3, 4])).shape, torch.Size([2, 3, 4]))
    assert_eq(cf_infersize(pack_sizes([2, 1, 4]), pack_sizes([2, 3, 4])).shape, torch.Size([2, 3, 4]))
    # Now test with unbacked SymInts
    assert_eq(cf_infersize(pack_sizes([2, 1, 4]), pack_sizes([2, U(3), 4])).shape, torch.Size([2, 3, 4]))
    # NB: The two occurrences of U(2) have distinct unbacked SymInts!
    assert_eq(cf_infersize(pack_sizes([U(2), 1, 4]), pack_sizes([U(2), U(3), 4])).shape, torch.Size([2, 3, 4]))

## More advanced symbolic reasoning

**[latespec]** Reasoning about guards is done in an *online* fashion; whenever we have to test a guard, we must immediately determine if it is true or false; otherwise tracing cannot continue.  This means that whether or not a program traces is order sensitive.

Without removing the guard in question or adding any deferred runtime asserts to u0, modify this program so it compiles.

In [104]:
@torch.compile(dynamic=True, fullgraph=True, backend="eager")
def cf_latespec(x, y):
    """[latespec] exercise: the ``s0 == 3`` check must come before the guard that needs it."""
    u0 = unpack_size(x)
    s0 = y.size(0)
    if TEMPLATE:
        pass
    else:
        # Solution: state s0 == 3 BEFORE the guard below is evaluated.
        torch._check(s0 == 3)
    b = force_guard(u0 * s0 ** 2 == 9 * u0)  # do not change this line
    if TEMPLATE:
        # Original exercise code: the same check, but placed after the guard.
        torch._check(s0 == 3)
    else:
        pass
    return b

@run_test
def test_latespec():
    # y has size 3, so u0 * 3 ** 2 == 9 * u0 holds.
    assert cf_latespec(pack_size(U(10)), torch.randn(3)).item()

**[uselessunbacked]** Sometimes, poorly written user code reads out unbacked variables from Tensors for use with sizes, even when they actually statically *know* the size in question.  Rewriting code to preferentially use backed quantities, and only reference unbacked ones to generate runtime asserts, can help you bypass data dependent exceptions.

Without changing the runtime semantics (including exception throwing behavior), modify this program so it compiles.

In [4]:
@torch.compile(dynamic=True, fullgraph=True, backend="eager")
def cf_uselessunbacked(x, y):
    """[uselessunbacked] exercise: add ``y`` to an arange whose length must equal ``y.size(0)``."""
    u0 = unpack_size(x)
    if TEMPLATE:
        # Original exercise code: the arange length comes from tensor data.
        x2 = torch.arange(u0 // 2)
    else:
        # Solution: take the length from y's size, and use the value read from
        # x only to assert that it agrees.
        torch._check(u0 // 2 == y.size(0))
        x2 = torch.arange(y.size(0))
    return x2 + y

@run_test
def test_uselessunbacked():
    # 20 // 2 == 10 matches y's length: result is arange(10) + zeros.
    assert_eq(cf_uselessunbacked(pack_size(U(20)), torch.zeros(10, dtype=torch.int64)), torch.arange(10,))
    # 10 // 2 == 5 does not match y's length of 10: a RuntimeError is required.
    try:
        cf_uselessunbacked(pack_size(U(10)), torch.zeros(10, dtype=torch.int64))
    except RuntimeError:
        pass
    else:
        raise AssertionError("Expected runtime error")

**[changevar]** When you call item() on a tensor, you trigger the allocation of a fresh unbacked variable.  It matters whether or not a given size is a variable `u0` or an expression `u0 // 2`, because we only ascribe value ranges to variables; expressions only ever have their value ranges derived from the value ranges of their free symbols.

Without changing the valid range of values for x, modify the program below so that it compiles.

In [106]:
@torch.compile(dynamic=True, fullgraph=True, backend="eager")
def cf_changevar(x):
    """[changevar] exercise: build ``arange(x[0] // 2)`` and return it doubled."""
    if TEMPLATE:
        # Original exercise code: item() first, so the size is the expression u0 // 2.
        u0 = x[0].item()
        torch._check_is_size(u0)
        r = torch.arange(u0 // 2)
    else:
        # Solution: divide on the tensor, then item(), so the size is itself
        # a single variable.
        u0 = (x[0] // 2).item()
        torch._check_is_size(u0)
        r = torch.arange(u0)
    return r + r

@run_test
def test_changevar():
    # Covers a non-empty result (length 10) and an empty one (length 0).
    assert_eq(cf_changevar(torch.tensor([20, 20])), torch.arange(10) * 2)
    assert_eq(cf_changevar(torch.tensor([0, 0])), torch.arange(0) * 2)

**[varrange]** Value ranges are a simple but powerful reasoning mechanism, whereby we propagate the lower and upper bounds of variables through expressions.  When you perform a `torch._check(u0 >= c)`, where c is a constant, you refine the value range of that variable.

Using only one call to `torch._check` (do NOT use `torch._check_is_size`), discharge all of the guards in this program.

In [107]:
@torch.compile(dynamic=True, fullgraph=True, backend="eager")
def cf_varrange(x):
    """[varrange] exercise: six guards on ``u0`` discharged by one lower-bound check.

    Returns the sum of the six bool tensors, i.e. how many guards were true.
    """
    u0 = x[0].item()
    # Put your torch._check call here
    if TEMPLATE:
        pass
    else:
        # Solution: a single lower bound on u0.
        torch._check(u0 >= 2)
    a = force_guard(u0 != 0)
    b = force_guard(u0 != 1)
    c = force_guard(u0 != -1)
    d = force_guard(u0 * u0 != 0)
    e = force_guard(u0 >= 0)
    f = force_guard(u0 * u0 >= 2)
    return sum([a, b, c, d, e, f])

@run_test
def test_varrange():
    # Every input is >= 2, so all six guards are true and the sum is 6.
    assert cf_varrange(torch.tensor([2, 2])).item() == 6
    assert cf_varrange(torch.tensor([10, 10])).item() == 6
    assert cf_varrange(torch.tensor([100, 100])).item() == 6

## Judo moves

Sometimes, fixing your code doesn't involve just bashing `torch._check` everywhere, you need a more semantic change, a judo move.

**[stacklist]** Unbacked SymInts cannot be used to index into Python data structures.  However, if the data structure can be expressed as a single Tensor, it can be done symbolically.

In [108]:
@torch.compile(fullgraph=True, backend="eager")
def cf_stacklist(xs: List[Tensor], y: Tensor):
    """[stacklist] exercise: select ``xs[y[0]]`` where the index is read from tensor data."""
    if TEMPLATE:
        # Original exercise code: indexes the Python list with the value from y.
        u0 = y[0].item()
        torch._check_is_size(u0)
        torch._check(u0 < len(xs))
        return xs[u0]
    else:
        # Solution: stack the list into one tensor and index that instead.
        stack = torch.stack(xs)
        u0 = y[0].item()
        torch._check_is_size(u0)
        torch._check(u0 < len(xs))
        return stack[u0]

@run_test
def test_stacklist():
    # Index 0 selects the first tensor in the list (the zeros).
    assert_eq(cf_stacklist([torch.zeros(5), torch.ones(5)], torch.tensor([0, 0])), torch.zeros(5))

**[vb]** Some code in torchrec tests if all batches are the same size, and has an optimized codepath for this case.  If it's not statically known if this condition is true, it is better to force the unoptimized path unconditionally.  Modify the code so that it no longer fails to compile.  In eager mode, you should still use the optimized codepath.  You may find the function `torch.compiler.is_dynamo_compiling()` helpful.

In [109]:
@torch.compile(fullgraph=True, backend="eager")
def cf_vb(xst: Tensor, y: Tensor):
    """[vb] exercise: add each batch's index to its slice of ``y``.

    ``xst`` holds the batch lengths. Both branches share the same two code
    paths: a view-based one used when all lengths are equal, and a
    split/concat one otherwise.
    """
    xs = xst.tolist()
    if TEMPLATE:
        # Original exercise code: compares the data-dependent lengths directly.
        if all(xs[0] == x for x in xs):
            return (y.view(len(xs), -1) + torch.arange(len(xs)).unsqueeze(1)).view(-1)
        else:
            return torch.concat([t + i for i, t in enumerate(torch.split(y, xs, dim=0))], dim=0)
    else:
        # Solution: skip the equal-lengths test whenever Dynamo is compiling,
        # so the split/concat path is always taken there.
        if not torch.compiler.is_dynamo_compiling() and all(xs[0] == x for x in xs):
            return (y.view(len(xs), -1) + torch.arange(len(xs)).unsqueeze(1)).view(-1)
        else:
            return torch.concat([t + i for i, t in enumerate(torch.split(y, xs, dim=0))], dim=0)

@run_test
def test_vb():
    # Equal batch lengths, then unequal ones; y is all zeros so the output is
    # each batch's index repeated for that batch's length.
    assert_eq(cf_vb(torch.tensor([2, 2, 2]), torch.zeros(6, dtype=torch.int64)), torch.tensor([0, 0, 1, 1, 2, 2]))
    assert_eq(cf_vb(torch.tensor([2, 3, 1]), torch.zeros(6, dtype=torch.int64)), torch.tensor([0, 0, 1, 1, 1, 2]))

## Jagged tensors: tensor splits

**[tensorsplit]**  Your first idea will be to use `torch._check` to solve the errors. It will not work.  Instead, apply the idea from [changevar] to change the split computation from operating on unbacked SymInts representing offsets, to unbacked SymInts representing sizes.  You may find `torch.diff` and `torch.narrow` useful.

In [13]:
# TODO: Actually the original code works if you use torch.narrow
from torch._dynamo.comptime import comptime

@torch.compile(fullgraph=True, backend="eager")
def cf_tensorsplit(x, offsets_t):
    """[tensorsplit] exercise: split ``x`` along dim 0 at the boundaries in ``offsets_t``.

    Returns a list with one ``torch.narrow`` slice per consecutive pair of offsets.
    """
    if TEMPLATE:
        # Original exercise code: works on (start, end) offset pairs.
        offsets = offsets_t.tolist()
        rs = []
        for start, end in zip(offsets, offsets[1:]):
            # NB: The code here shows what you might end up with if you kept bashing
            # more runtime asserts, but you are still going to get stuck
            torch._check_is_size(start)
            torch._check_is_size(end)
            torch._check(end <= x.size(0))
            torch._check(start <= x.size(0))
            # Specifically, these asserts will be necessary to get past the last set
            # of conditionals, but these asserts are not necessarily true at runtime!
            """
            torch._check(end - start != 0)  # NB: this is wrong at runtime
            torch._check(end - start != 1)  # NB: this is wrong at runtime
            torch._check(end > start)  # NB: this is wrong at runtime
            """
            rs.append(torch.narrow(x, 0, start, end - start))
    else:
        # Solution: turn offsets into sizes on the tensor side (torch.diff), so
        # each slice length is read directly from tensor data.
        sizes_t = torch.diff(offsets_t)
        sizes = sizes_t.tolist()
        rs = []
        # Running start position, accumulated from the sizes. It starts at 0,
        # so the first entry of offsets_t is not used in this branch.
        offset = 0
        for s in sizes:
            torch._check_is_size(s)
            rs.append(torch.narrow(x, 0, offset, s))
            offset += s
    return rs

@run_test
def test_tensorsplit():
    # Two different offset tensors over the same arange(10) input.
    assert_eq(
        cf_tensorsplit(torch.arange(10), torch.tensor([0, 2, 5, 7, 10])),
        [torch.tensor([0, 1]), torch.tensor([2, 3, 4]), torch.tensor([5, 6]), torch.tensor([7, 8, 9])]
    )
    assert_eq(
        cf_tensorsplit(torch.arange(10), torch.tensor([0, 2, 3, 7, 10])),
        [torch.tensor([0, 1]), torch.tensor([2]), torch.tensor([3, 4, 5, 6]), torch.tensor([7, 8, 9])]
    )

## TODO: Views: divisibility and contiguity

contiguous discontiguous slice problem: tensor[start:end, :] - analogous problem split on dim=0 of 2d https://github.com/pytorch/pytorch/issues/125519

wrapping slice problem: slice_forward decomp, sym_max/sym_min

strides matter https://github.com/pytorch/pytorch/issues/124581

## TODO: Custom operators

Overspecialization - how to find the site

C++ how to interpret the C++ stack

C++ Operator is not using dispatcher

Custom op overallocate unbacked symints https://www.internalfb.com/diff/D54314970?transaction_fbid=1125375575483406


Pending unbacked symints that don't escape custom op https://github.com/pytorch/pytorch/issues/125368

Unsolvable data dependent unbind, wrap a custom op around it https://fb.workplace.com/groups/6829516587176185/posts/1417880972183443/

## Scratch space

Everything below is not an exercise per se, but you might find something useful playing around with the samples here.

**[printlocals]** TODO THIS PUZZLER IS BROKEN, SEE https://github.com/pytorch/pytorch/issues/133650 
The program below is missing one more runtime assert before it compiles.  To understand what the allocated symbols for variables in scope are, you can use `comptime.print()` to print information about variables that are in scope.  TODO: demo propagate_real_tensors  TODO: demo extended guard creation trace env var

In [111]:
from torch._dynamo.comptime import comptime

import torch._logging
import logging
torch._logging.set_logs()

@torch._dynamo.config.patch(do_not_emit_runtime_asserts=True)
@torch.compile(dynamic=True, fullgraph=True, backend="eager")
def cf_printlocals(x):
    """[printlocals] scratch example: several overlapping reads of ``x``, then a view.

    The final ``view(2, u5, u3)`` is applied to a ``(2, u6 + 2 * u4)`` tensor.
    """
    # The same tensor is read out several times through different slices.
    u5, u3 = x[2:].tolist()
    u6, *u10 = x.tolist()
    u4 = x[1].item()
    u9, u8, *u11 = x[:-1].tolist()
    torch._check(u3 != 1)
    torch._check(u5 != u6 + 2 * u4)
    torch._check_is_size(u6)
    torch._check_is_size(u4)
    torch._check_is_size(u5)
    torch._check((u6 + 2*u4) % u5 == 0)
    # Put the missing assertion here
    if TEMPLATE:
        pass
    else:
        # Solution: tie u3 to the other values so the view below is consistent.
        torch._check(u3 == (u6 + 2 * u4) // u5)
        # Print the values at trace time, one entry per local.
        comptime.print({
            "u5": u5,
            "u3": u3,
            "u6": u6,
            "u10": u10,
            "u4": u4,
            "u9": u9,
            "u8": u8,
            "u11": u11,
        })
    u2 = torch.randn(u5, u3)
    u0 = torch.zeros(u6)
    # NOTE(review): u4 was already passed to _check_is_size above; this second
    # call looks redundant - confirm before removing.
    torch._check_is_size(u4)
    u1 = torch.zeros(u4 * 2)
    stk = torch.cat([u0, u1], dim=0)
    return torch.stack([stk, stk]).view(2, *u2.size())

@run_test
def test_printlocals():
    # For input [20, 2, 3, 8] the function views a (2, 24) tensor of zeros as
    # (2, 3, 8), so compare against that exact shape rather than relying on
    # torch.allclose broadcasting a (3, 8) expectation.
    assert_eq(cf_printlocals(torch.tensor([20, 2, 3, 8])), torch.zeros(2, 3, 8))

In [41]:
import torch.library
from typing import Tuple

@torch.library.custom_op(
    "mylib::index_op", mutates_args=(), device_types=["cpu"]
)
def index_op(indices: Tensor, values: Tensor) -> Tensor:
    """Custom op ``mylib::index_op``: add the maximum of ``indices`` to ``values``."""
    t0 = indices.max()
    # Read the maximum out as a Python scalar.
    u0 = t0.item()
    return (values + u0).clone()

def index_op_backward(ctx, grad):
    """Backward for ``mylib::index_op``.

    Reads the single saved tensor as a Python scalar and scales the incoming
    gradient by it.

    NOTE(review): the forward op adds the scalar to ``values``, so scaling the
    gradient by it looks like a deliberate way to exercise ``.item()`` in
    backward rather than the exact derivative - confirm.

    Returns:
        ``(None, grad * scalar)``: no gradient for ``indices``, the scaled
        gradient for ``values``.
    """
    (saved_max,) = ctx.saved_tensors
    return None, grad * saved_max.item()

def index_op_setup_context(ctx, inputs, output):
    """Save what ``index_op_backward`` needs: the maximum of ``indices``.

    Args:
        ctx: autograd context; only ``save_for_backward`` is called on it.
        inputs: the ``(indices, values)`` pair the op was called with.
        output: the op's result (unused).
    """
    # Unpack both inputs so a wrong arity still fails here; only indices is used.
    indices, _values = inputs
    ctx.save_for_backward(indices.max())

# Attach the backward and setup_context functions defined above to the custom op.
torch.library.register_autograd(
    "mylib::index_op", index_op_backward, setup_context=index_op_setup_context)

@torch.library.register_fake("mylib::index_op")
def _(indices: Tensor, values: Tensor):
    # Fake implementation: the output has the same metadata as `values`.
    return torch.empty_like(values)

@torch.compile(backend="inductor")
def cf_index_op(indices, values):
    """Compiled wrapper that just calls the ``index_op`` custom op."""
    return index_op(indices, values)

@run_test
def test_index_op():
    # Runs forward and backward through the compiled custom op; only checks
    # that both complete, the gradient itself is not inspected.
    x = torch.randn(20, requires_grad=True)
    r = cf_index_op(torch.tensor([0, 1, 2]), x)
    r[1].sum().backward()

In [42]:
@torch.compile(fullgraph=True, backend="eager")
def cf(x):
    """Scratch example: a size-oblivious guard on ``sym_max(1, u0 + u1)`` after checking the sum is 20."""
    u0, u1 = x.tolist()
    torch._check_is_size(u0)
    torch._check_is_size(u1)
    torch._check(u0 + u1 == 20)
    # With the sum asserted to be 20 above, test the max against 20.
    if guard_size_oblivious(torch.sym_max(1, u0 + u1) == 20):
        return torch.tensor(True)
    else:
        return torch.tensor(False)

@run_test
def test_symmax():
    # 10 + 10 == 20 satisfies the check inside cf.
    assert cf(torch.tensor([10, 10])).item()