# Testing Tensor of shapes

One of the things its bothering me is that `spatial_shapes` is a tensor of shapes that obtained from a list of "specialized" tensors (or at least that's the intention).

To explain this a bit further, let's implement minimal reproducible example (MRE) to isolate the error from all the complicated code. We'll use `dynamo.explain` to debug our code.

In [5]:
import torch
import torch.nn as nn
import typing
import torch._dynamo as dynamo
import logging


In [22]:
def shapes_to_tensor(x: typing.List[int], device: typing.Optional[torch.device] = None) -> torch.Tensor:
    """
    Turn a list of integer scalars or integer Tensor scalars into a vector,
    in a way that's both traceable and scriptable.

    In tracing, `x` should be a list of scalar Tensor, so the output can trace to the inputs.
    In scripting or eager, `x` should be a list of int.
    """
    if torch.jit.is_scripting():
        return torch.as_tensor(x, device=device)
    if torch.jit.is_tracing() or torch._dynamo.is_compiling():
        # assert all(
        #     [isinstance(t, torch.Tensor) for t in x]
        # ), "Shape should be tensor during tracing!"
        # as_tensor should not be used in tracing because it records a constant
        ret = torch.stack(x)
        if ret.device != device:  # avoid recording a hard-coded device if not necessary
            ret = ret.to(device=device)
        return ret
    return torch.as_tensor(x, device=device)

In [None]:
from torch._dynamo.utils import get_fake_value

In [94]:
import numpy as np
class NetMRE(nn.Module):
    def __init__(self):
        super().__init__()
        self.ss: typing.Final[typing.List[typing.Tuple[int, int]]] =[(150, 100), (75, 50), (37, 25), (19, 13)]
    
    def forward(self, x: typing.List[torch.Tensor]):
        # Let's annotate possible culprits with C{N}
        # C2: tensor casting
        # C3: type overriding typing.List[typing.Tuple[int, int]] to torch.Tensor
        # spatial_shapes_tensor = torch.as_tensor(spatial_shapes, dtype=torch.long)
        reference_points = self.get_reference_points(torch.asarray(self.ss, copy=False))
        return reference_points
    
    @staticmethod
    def get_reference_points(
        spatial_shapes: torch.Tensor,
    ): 
        reference_points_list = []
        for H, W in spatial_shapes:
            lin = torch.linspace(0.5, H - 0.5, H, dtype=torch.float32) # --- breaks here
            reference_points_list.append(lin)
        return reference_points_list


In [96]:
import numpy as np
import array
class NetMRE(nn.Module):
    def forward(self, x: typing.List[torch.Tensor]):
        # Let's annotate possible culprits with C{N}
        spatial_shapes: typing.List[typing.Tuple[int, int]] = []
        for i, xi in enumerate(x):
            spatial_shapes.append(xi.shape[2:4])  # C1: tuple casting

        # C2: tensor casting
        # C3: type overriding typing.List[typing.Tuple[int, int]] to torch.Tensor
        spatial_shapes = array.array("i", spatial_shapes)
        spatial_shapes_tensor = torch.frombuffer(spatial_shapes)
        # spatial_shapes_tensor = torch.as_tensor(spatial_shapes, dtype=torch.long)
        reference_points = self.get_reference_points(spatial_shapes_tensor)
        return reference_points
    
    @staticmethod
    def get_reference_points(
        spatial_shapes: torch.Tensor,
    ): 
        reference_points_list = []
        for H, W in spatial_shapes:
            lin = torch.linspace(0.5, H - 0.5, H, dtype=torch.float32) # --- breaks here
            reference_points_list.append(lin)
        return reference_points_list


In [97]:
# sst = [(150, 100), (75, 50), (37, 25), (19, 13)]
example_kwargs = {
    "x": [
        # torch.rand(1, 3, 64, 64),
        # torch.rand(1, 3, 32, 32),
        torch.rand(1, 3, 150, 100),
        torch.rand(1, 3, 75, 50),
        torch.rand(1, 3, 37, 25),
        torch.rand(1, 3, 19, 13),
    ],
}

In [98]:
exported_program: torch.export.ExportedProgram = torch.export.export(
    NetMRE(), (), kwargs=example_kwargs, strict=True,
)
print(exported_program)

V0904 22:04:10.005000 440 torch/_dynamo/convert_frame.py:855] [9/0] torchdynamo start compiling forward /var/folders/rk/fqb5s6dn6sl66ntxp1ssybz80000gn/T/ipykernel_440/1310961134.py:4, stack (elided 4 frames):
V0904 22:04:10.005000 440 torch/_dynamo/convert_frame.py:855] [9/0]   File "/Users/dgcnz/.pyenv/versions/3.10.12/lib/python3.10/runpy.py", line 196, in _run_module_as_main
V0904 22:04:10.005000 440 torch/_dynamo/convert_frame.py:855] [9/0]     return _run_code(code, main_globals, None,
V0904 22:04:10.005000 440 torch/_dynamo/convert_frame.py:855] [9/0]   File "/Users/dgcnz/.pyenv/versions/3.10.12/lib/python3.10/runpy.py", line 86, in _run_code
V0904 22:04:10.005000 440 torch/_dynamo/convert_frame.py:855] [9/0]     exec(code, run_globals)
V0904 22:04:10.005000 440 torch/_dynamo/convert_frame.py:855] [9/0]   File "/Users/dgcnz/development/amsterdam/edge/.venv/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
V0904 22:04:10.005000 440 torch/_dynamo/convert_fra

Unsupported: call_function UserDefinedClassVariable(<class 'array.array'>) [ConstantVariable(), ListVariable(length=4)] {}

from user code:
   File "/var/folders/rk/fqb5s6dn6sl66ntxp1ssybz80000gn/T/ipykernel_440/1310961134.py", line 12, in forward
    spatial_shapes = array.array("i", spatial_shapes)


In [None]:
dynamo.config.verbose=False
dynamo.config.capture_scalar_outputs = True
torch._logging.set_logs(dynamo = logging.WARN)

explanation = dynamo.explain(NetMRE())(**example_kwargs)
print(explanation.break_reasons[0])

This error basically tells us that we are trying to compute something (creating the tensor `lin`) whose values and more importantly, shape, is dependent on whatever values the `spatial_shapes` tensor contains.

However, this should not be the case, since `spatial_shapes` are the shapes of the input tensor, which we know to be constant.

## C2: Avoiding extra tensor casting

A simple fix is to not cast `spatial_shapes` into a tensor, but that doesn't seem satisfactory and would require further rewrites in the actual code that use the tensor methods like `spatial_shapes.your_nice_tensor_method`. Anyway, let's do that to see what changes.

In [13]:
class NetC2(nn.Module):
    def forward(self, x: typing.List[torch.Tensor]):
        spatial_shapes: typing.List[typing.Tuple[int, int]] = []
        for i, xi in enumerate(x):
            spatial_shapes.append(tuple(xi.shape[2:]))
        reference_points = self.get_reference_points(spatial_shapes)
        return reference_points

    @staticmethod
    def get_reference_points(
        spatial_shapes: typing.List[typing.Tuple[int, int]],
    ): 
        reference_points_list = []
        for H, W in spatial_shapes:
            lin = torch.linspace(0.5, H - 0.5, H, dtype=torch.float32)
            reference_points_list.append(lin)
        return reference_points_list


In [None]:
dynamo.config.verbose=False
dynamo.config.capture_scalar_outputs = True
torch._logging.set_logs(dynamo = logging.WARN)

explanation = dynamo.explain(NetC2())(**example_kwargs)
print(explanation)

In [None]:
dynamo.config.verbose=True
# dynamo.config.capture_scalar_outputs = True
# dynamo.config.capture_dynamic_output_shape_ops = True
# torch._dynamo.config.fake_tensor_cache_enabled = False
torch._logging.set_logs(dynamo = logging.INFO)

# explanation = dynamo.explain(NetC4v4())(**example_kwargs)
# print(explanation)
exported_program: torch.export.ExportedProgram = torch.export.export(
    NetC2(), (), kwargs=example_kwargs, strict=True,
)
print(exported_program)

## C3

In [None]:
class NetC3(nn.Module):
    def forward(self, x: typing.List[torch.Tensor]):
        spatial_shapes: typing.List[typing.Tuple[int, int]] = []
        for i, xi in enumerate(x):
            spatial_shapes.append(xi.shape[2:])
        spatial_shapes_tensor: typing.Final[torch.Tensor] = torch.as_tensor(spatial_shapes, dtype=torch.long)
        reference_points = self.get_reference_points(spatial_shapes_tensor)
        return reference_points

    @staticmethod
    def get_reference_points(
        spatial_shapes: torch.Tensor,
    ): 
        reference_points_list = []
        for H, W in spatial_shapes:
            lin = torch.linspace(0.5, H - 0.5, H, dtype=torch.float32) # --- breaks here
            reference_points_list.append(lin)
        return reference_points_list


In [None]:
dynamo.config.verbose=False
dynamo.config.capture_scalar_outputs = True
torch._logging.set_logs(dynamo = logging.WARN)

explanation = dynamo.explain(NetC3())(**example_kwargs)
print(explanation.break_reasons[0].reason)

In [None]:
import torch._dynamo.variables as dynamo_vars

class NetC4v2(nn.Module):

    def forward(self, x: typing.List[torch.Tensor]):
        spatial_shapes: typing.List[typing.Tuple[int, int]] = []
        for i, xi in enumerate(x):
            spatial_shapes.append(xi.shape[2:])
        spatial_shapes_tensor = torch.as_tensor(spatial_shapes, dtype=torch.long)

        # Tried:
        # dynamo.assume_constant_result
        spatial_shapes_tensor._dynamo_marked_constant = True
        spatial_shapes_tensor._dynamo_static_input_type = "unguarded"
        reference_points_list = []
        for H, W in spatial_shapes_tensor:
            lin = torch.linspace(0.5, H - 0.5, H, dtype=torch.float32) # --- breaks here
            reference_points_list.append(lin)
        return reference_points_list


In [None]:
dynamo.config.verbose=False
dynamo.config.capture_scalar_outputs = True
torch._logging.set_logs(dynamo = logging.INFO)

explanation = dynamo.explain(NetC4v2())(**example_kwargs)
print(explanation)

Surprisingly, even if we 

In [None]:
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorConverter, FakeTensorMode, unset_fake_temporarily
from torch._subclasses.functional_tensor import FunctionalTensor, disable_functional_mode
from torch._dynamo.variables import TensorVariable
import numpy as np
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
from torch._dispatch.python import suspend_functionalization
from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing
class NetC4v4(nn.Module):

    def __init__(self):
        super().__init__()
        # self.ss: typing.Final[typing.List[typing.Tuple[int, int]]] =[(150, 100), (75, 50), (37, 25), (19, 13)]
        # self.sst: typing.Final[torch.Tensor] = torch.tensor(self.ss)
    def forward(self, x: typing.List[torch.Tensor]):
        # dynamo.decorators.mark_static(self.sst)
        with unset_fake_temporarily():
            with suspend_functionalization(), disable_functional_mode():
                with disable_proxy_modes_tracing():
                    sst = [(150, 100), (75, 50), (37, 25), (19, 13)]
                    ss = torch.tensor(sst) 
        reference_points_list = []
        for i in range(len(ss)):
            # h = sst.constant[i][0]
            # torch._check(h == ss[i][0])
            with unset_fake_temporarily():
                with suspend_functionalization(), disable_functional_mode():
                    with disable_proxy_modes_tracing():
                        h = int(ss[i][0])
            lin = torch.linspace(0.5, h  - 0.5, h)
            # torch._check(lin.shape[0] == sst[i][0])
            reference_points_list.append(lin)
        return reference_points_list

In [None]:
dynamo.config.verbose=True
dynamo.config.capture_scalar_outputs = True
dynamo.config.capture_dynamic_output_shape_ops = True
torch._dynamo.config.fake_tensor_cache_enabled = False
torch._logging.set_logs(dynamo = logging.INFO)

# explanation = dynamo.explain(NetC4v4())(**example_kwargs)
# print(explanation)
exported_program: torch.export.ExportedProgram = torch.export.export(
    NetC4v4(), (), kwargs=example_kwargs, strict=False,
)
print(exported_program)

In [17]:
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorConverter, FakeTensorMode, unset_fake_temporarily
from torch._subclasses.functional_tensor import FunctionalTensor, disable_functional_mode
from torch._dynamo.variables import TensorVariable
import numpy as np
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
from torch._dispatch.python import suspend_functionalization
from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing
class NetC4v5(nn.Module):

    def __init__(self):
        super().__init__()
        # self.ss: typing.Final[typing.List[typing.Tuple[int, int]]] =[(150, 100), (75, 50), (37, 25), (19, 13)]
        # self.sst: typing.Final[torch.Tensor] = torch.tensor(self.ss)
    def forward(self, x: typing.List[torch.Tensor]):
        # dynamo.decorators.mark_static(self.sst)
        with unset_fake_temporarily():
            # with suspend_functionalization(), disable_functional_mode():
                with disable_proxy_modes_tracing():
                    sst = [(150, 100), (75, 50), (37, 25), (19, 13)]
                    ss = torch.tensor(sst) 
        reference_points_list = []
        for i in range(len(ss)):
            # h = sst.constant[i][0]
            # torch._check(h == ss[i][0])
            with unset_fake_temporarily():
                # with suspend_functionalization(), disable_functional_mode():
                    with disable_proxy_modes_tracing():
                        h = int(ss[i][0])
            lin = torch.linspace(0.5, h  - 0.5, h)
            # torch._check(lin.shape[0] == sst[i][0])
            reference_points_list.append(lin)
        return reference_points_list

In [18]:
dynamo.config.verbose=True
dynamo.config.capture_scalar_outputs = True
dynamo.config.capture_dynamic_output_shape_ops = True
torch._dynamo.config.fake_tensor_cache_enabled = False
torch._logging.set_logs(dynamo = logging.INFO)

# explanation = dynamo.explain(NetC4v4())(**example_kwargs)
# print(explanation)
exported_program: torch.export.ExportedProgram = torch.export.export(
    NetC4v5(), (), kwargs=example_kwargs, strict=False,
)
print(exported_program)

I0821 19:57:27.517000 7923618560 torch/fx/experimental/symbolic_shapes.py:3639] produce_guards


ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x_0: "f32[1, 3, 150, 100]", x_1: "f32[1, 3, 75, 50]", x_2: "f32[1, 3, 37, 25]", x_3: "f32[1, 3, 19, 13]"):
            # File: /var/folders/rk/fqb5s6dn6sl66ntxp1ssybz80000gn/T/ipykernel_65943/2942393893.py:29 in forward, code: lin = torch.linspace(0.5, h  - 0.5, h)
            linspace: "f32[150]" = torch.ops.aten.linspace.default(0.5, 149.5, 150, device = device(type='cpu'), pin_memory = False)
            linspace_1: "f32[75]" = torch.ops.aten.linspace.default(0.5, 74.5, 75, device = device(type='cpu'), pin_memory = False)
            linspace_2: "f32[37]" = torch.ops.aten.linspace.default(0.5, 36.5, 37, device = device(type='cpu'), pin_memory = False)
            linspace_3: "f32[19]" = torch.ops.aten.linspace.default(0.5, 18.5, 19, device = device(type='cpu'), pin_memory = False)
            return (linspace, linspace_1, linspace_2, linspace_3)
            
Graph signature: ExportGraphSignature(input

In [25]:
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorConverter, FakeTensorMode, unset_fake_temporarily
from torch._subclasses.functional_tensor import FunctionalTensor, disable_functional_mode
from torch._dynamo.variables import TensorVariable
import numpy as np
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
from torch._dispatch.python import suspend_functionalization
from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing
class NetC4v5(nn.Module):

    def __init__(self):
        super().__init__()
        # self.ss: typing.Final[typing.List[typing.Tuple[int, int]]] =[(150, 100), (75, 50), (37, 25), (19, 13)]
        # self.sst: typing.Final[torch.Tensor] = torch.tensor(self.ss)
    def forward(self, x: typing.List[torch.Tensor]):
        # dynamo.decorators.mark_static(self.sst)

        sst = [(150, 100), (75, 50), (37, 25), (19, 13)]
        with unset_fake_temporarily():
            # with suspend_functionalization(), disable_functional_mode():
                with disable_proxy_modes_tracing():
                    ss = torch.tensor(sst) 
        reference_points_list = []
        for i in range(len(ss)):
            # h = sst.constant[i][0]
            # torch._check(h == ss[i][0])
            with unset_fake_temporarily():
                # with suspend_functionalization(), disable_functional_mode():
                    with disable_proxy_modes_tracing():
                        h = int(ss[i][0])
            lin = torch.linspace(0.5, h  - 0.5, h)
            # torch._check(lin.shape[0] == sst[i][0])
            reference_points_list.append(lin)
        return reference_points_list

In [28]:
dynamo.config.verbose=True
# dynamo.config.capture_scalar_outputs = True
# dynamo.config.capture_dynamic_output_shape_ops = True
# torch._dynamo.config.fake_tensor_cache_enabled = False
torch._logging.set_logs(dynamo = logging.INFO)

# explanation = dynamo.explain(NetC4v4())(**example_kwargs)
# print(explanation)
exported_program: torch.export.ExportedProgram = torch.export.export(
    NetC4v5(), (), kwargs=example_kwargs, strict=False,
)
print(exported_program)

V0821 20:00:35.062000 7923618560 torch/fx/experimental/symbolic_shapes.py:2529] create_env
I0821 20:00:35.085000 7923618560 torch/fx/experimental/symbolic_shapes.py:3639] produce_guards
V0821 20:00:35.086000 7923618560 torch/fx/experimental/symbolic_shapes.py:3821] track_symint L['args'][1]['x'][0].size()[0] 1 None
V0821 20:00:35.086000 7923618560 torch/fx/experimental/symbolic_shapes.py:3821] track_symint L['args'][1]['x'][0].size()[1] 3 None
V0821 20:00:35.087000 7923618560 torch/fx/experimental/symbolic_shapes.py:3821] track_symint L['args'][1]['x'][0].size()[2] 150 None
V0821 20:00:35.087000 7923618560 torch/fx/experimental/symbolic_shapes.py:3821] track_symint L['args'][1]['x'][0].size()[3] 100 None
V0821 20:00:35.087000 7923618560 torch/fx/experimental/symbolic_shapes.py:3821] track_symint L['args'][1]['x'][0].stride()[0] 45000 None
V0821 20:00:35.088000 7923618560 torch/fx/experimental/symbolic_shapes.py:3821] track_symint L['args'][1]['x'][0].stride()[1] 15000 None
V0821 20:00:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x_0: "f32[1, 3, 150, 100]", x_1: "f32[1, 3, 75, 50]", x_2: "f32[1, 3, 37, 25]", x_3: "f32[1, 3, 19, 13]"):
            # File: /var/folders/rk/fqb5s6dn6sl66ntxp1ssybz80000gn/T/ipykernel_65943/1980386621.py:30 in forward, code: lin = torch.linspace(0.5, h  - 0.5, h)
            linspace: "f32[150]" = torch.ops.aten.linspace.default(0.5, 149.5, 150, device = device(type='cpu'), pin_memory = False)
            linspace_1: "f32[75]" = torch.ops.aten.linspace.default(0.5, 74.5, 75, device = device(type='cpu'), pin_memory = False)
            linspace_2: "f32[37]" = torch.ops.aten.linspace.default(0.5, 36.5, 37, device = device(type='cpu'), pin_memory = False)
            linspace_3: "f32[19]" = torch.ops.aten.linspace.default(0.5, 18.5, 19, device = device(type='cpu'), pin_memory = False)
            return (linspace, linspace_1, linspace_2, linspace_3)
            
Graph signature: ExportGraphSignature(input

In [24]:
dynamo.config.verbose=True
dynamo.config.capture_scalar_outputs = True
dynamo.config.capture_dynamic_output_shape_ops = True
torch._dynamo.config.fake_tensor_cache_enabled = False
torch._logging.set_logs(dynamo = logging.INFO)

exported_program: torch.export.ExportedProgram = torch.export.export(
    NetC2(), (), kwargs=example_kwargs, strict=False,
)
print(exported_program)

I0821 19:59:11.458000 7923618560 torch/fx/experimental/symbolic_shapes.py:3639] produce_guards


ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x_0: "f32[1, 3, 150, 100]", x_1: "f32[1, 3, 75, 50]", x_2: "f32[1, 3, 37, 25]", x_3: "f32[1, 3, 19, 13]"):
            # File: /var/folders/rk/fqb5s6dn6sl66ntxp1ssybz80000gn/T/ipykernel_65943/1787364395.py:6 in forward, code: reference_points = self.get_reference_points(spatial_shapes)
            linspace: "f32[150]" = torch.ops.aten.linspace.default(0.5, 149.5, 150, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            linspace_1: "f32[75]" = torch.ops.aten.linspace.default(0.5, 74.5, 75, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            linspace_2: "f32[37]" = torch.ops.aten.linspace.default(0.5, 36.5, 37, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            linspace_3: "f32[19]" = torch.ops.aten.linspace.default(0.5, 18.5, 19, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
         

In [6]:
class NetC2v2(nn.Module):
    def forward(self, x: typing.List[torch.Tensor]):
        spatial_shapes: typing.List[typing.Tuple[int, int]] = []
        for i, xi in enumerate(x):
            spatial_shapes.append(tuple(xi.shape[2:]))
        
        with unset_fake_temporarily():
            with disable_proxy_modes_tracing():
                spatial_shapes_tensor = torch.tensor(spatial_shapes, dtype=torch.long)
        reference_points = self.get_reference_points(spatial_shapes_tensor)
        return reference_points

    @staticmethod
    def get_reference_points(
        spatial_shapes: torch.Tensor,
    ): 
        reference_points_list = []
        for i in range(len(spatial_shapes)):
            with disable_proxy_modes_tracing():
                with unset_fake_temporarily():
                    h = int(spatial_shapes[i][0])
            lin = torch.linspace(0.5, h - 0.5, h, dtype=torch.float32)
            reference_points_list.append(lin)
        return reference_points_list


In [7]:
dynamo.config.verbose=True
dynamo.config.capture_scalar_outputs = True
dynamo.config.capture_dynamic_output_shape_ops = True
torch._dynamo.config.fake_tensor_cache_enabled = False
torch._logging.set_logs(dynamo = logging.INFO)

exported_program: torch.export.ExportedProgram = torch.export.export(
    NetC2v2(), (), kwargs=example_kwargs, strict=False,
)
print(exported_program)

I0821 20:57:12.049000 7923618560 torch/fx/experimental/symbolic_shapes.py:3639] produce_guards


ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x_0: "f32[1, 3, 150, 100]", x_1: "f32[1, 3, 75, 50]", x_2: "f32[1, 3, 37, 25]", x_3: "f32[1, 3, 19, 13]"):
            # File: /var/folders/rk/fqb5s6dn6sl66ntxp1ssybz80000gn/T/ipykernel_84265/4072546233.py:10 in forward, code: reference_points = self.get_reference_points(spatial_shapes_tensor)
            linspace: "f32[150]" = torch.ops.aten.linspace.default(0.5, 149.5, 150, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            linspace_1: "f32[75]" = torch.ops.aten.linspace.default(0.5, 74.5, 75, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            linspace_2: "f32[37]" = torch.ops.aten.linspace.default(0.5, 36.5, 37, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
            linspace_3: "f32[19]" = torch.ops.aten.linspace.default(0.5, 18.5, 19, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
 

In [13]:
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorConverter, FakeTensorMode, unset_fake_temporarily
from torch._subclasses.functional_tensor import FunctionalTensor, disable_functional_mode
from torch._dynamo.variables import TensorVariable
import numpy as np
from torch._dispatch.python import suspend_functionalization
from torch.utils._python_dispatch import _unset_infra_mode, _push_mode
from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing
import torch.nn as nn
import torch._dynamo as dynamo
import typing
import torch
import logging

class NetC4v6(nn.Module):

    def __init__(self):
        super().__init__()
        # self.ss: typing.Final[typing.List[typing.Tuple[int, int]]] =[(150, 100), (75, 50), (37, 25), (19, 13)]
        # old = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
        # with disable_proxy_modes_tracing():
        # self.sst: typing.Final[torch.Tensor] = torch.tensor(self.ss)
    def forward(self, x: typing.List[torch.Tensor]):
        ss: typing.List[typing.Tuple[int, int]] =[(150, 100), (75, 50), (37, 25), (19, 13)]
        # old = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
        # with disable_proxy_modes_tracing():

        old = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
        # with unset_fake_temporarily():
        # with disable_proxy_modes_tracing():
        # return _disable_infra_mode(torch._C._TorchDispatchModeKey.PROXY)
        sst: torch.Tensor = torch.as_tensor(ss)
        torch._C._set_dispatch_mode(old)
        # dynamo.decorators.mark_static(sst)
        reference_points_list = []
        for i in range(len(sst)):
            old = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
            mode_unset = _unset_infra_mode(torch._C._TorchDispatchModeKey.PROXY)
            h = int(sst[i][0])
            _push_mode(mode_unset)
            # torch._C._push_on_torch_dispatch_stack(mode_unset)
            torch._C._set_dispatch_mode(old)
            lin = torch.linspace(0.5, h  - 0.5, h)
            # torch._check(lin.shape[0] == sst[i][0])
            reference_points_list.append(lin)
        return reference_points_list

In [14]:
dynamo.config.verbose=True
# dynamo.config.capture_scalar_outputs = True
# dynamo.config.capture_dynamic_output_shape_ops = True
# torch._dynamo.config.fake_tensor_cache_enabled = False
torch._logging.set_logs(dynamo = logging.DEBUG)

exported_program: torch.export.ExportedProgram = torch.export.export(
    NetC4v6(), (), kwargs=example_kwargs, strict=True,
)
print(exported_program)

V0904 20:34:44.916000 440 torch/_dynamo/convert_frame.py:1216] skipping: _wrapped_call_impl (reason: in skipfiles, file: /Users/dgcnz/development/amsterdam/edge/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py)
V0904 20:34:44.916000 440 torch/_dynamo/convert_frame.py:1216] skipping: _call_impl (reason: in skipfiles, file: /Users/dgcnz/development/amsterdam/edge/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py)
V0904 20:34:44.922000 440 torch/_dynamo/convert_frame.py:855] [0/0] torchdynamo start compiling forward /var/folders/rk/fqb5s6dn6sl66ntxp1ssybz80000gn/T/ipykernel_440/300679688.py:22, stack (elided 4 frames):
V0904 20:34:44.922000 440 torch/_dynamo/convert_frame.py:855] [0/0]   File "/Users/dgcnz/.pyenv/versions/3.10.12/lib/python3.10/runpy.py", line 196, in _run_module_as_main
V0904 20:34:44.922000 440 torch/_dynamo/convert_frame.py:855] [0/0]     return _run_code(code, main_globals, None,
V0904 20:34:44.922000 440 torch/_dynamo/convert_frame.py:855] [0

-1 0
-1 0


Unsupported: call_function args: UserDefinedObjectVariable(_TorchDispatchModeKey) 

from user code:
   File "/var/folders/rk/fqb5s6dn6sl66ntxp1ssybz80000gn/T/ipykernel_440/300679688.py", line 27, in forward
    old = torch._C._unset_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)


Seems to work for strict=False only, that's a shame. Also it requires non-trivial rewrites that I probably should avoid anyway.