In [16]:
import logging

import torch
from torch.utils._pytree import tree_map

# Tensor Subclass
- Huge in scope, with a lot to break down
- Read this for fuller picture: https://dev-discuss.pytorch.org/t/what-and-why-is-torch-dispatch/557
- Adds extensibility from the Python layer instead of having to drop further down into C++
- Examples of subclasses: https://github.com/albanD/subclass_zoo
- Lets you extend the dispatcher without dropping down to C++


In [17]:
class BaseTensor(torch.Tensor):
    """Minimal ``torch.Tensor`` subclass used as a base for other subclasses.

    Construct with ``BaseTensor(elem)`` or ``BaseTensor(elem, requires_grad=flag)``.
    """

    # See https://github.com/pytorch/pytorch/pull/73727 ; this is necessary
    # so that the super().__new__ calls of cooperating subclasses work together.
    @staticmethod
    def __new__(cls, elem, *, requires_grad=None):
        """Create the subclass instance from the tensor ``elem``.

        Args:
            elem: the tensor to build the instance from.
            requires_grad: if ``None``, the default ``torch.Tensor`` constructor
                is used; otherwise ``_make_subclass`` builds the instance with
                this ``requires_grad`` flag.
        """
        if requires_grad is None:
            return super().__new__(cls, elem)
        else:
            return cls._make_subclass(cls, elem, requires_grad)

    # Python invokes __init__ with the same arguments that were given to
    # __new__, so __init__ must accept (and ignore) everything __new__ accepts.
    # Without the keyword-only ``requires_grad`` here,
    # ``BaseTensor(t, requires_grad=True)`` would raise TypeError even though
    # __new__ had already succeeded.
    def __init__(self, elem, *, requires_grad=None):
        super().__init__()

    # If __torch_dispatch__ is defined (which it will be for all our examples)
    # the default torch function implementation (which preserves subclasses)
    # typically must be disabled
    __torch_function__ = torch._C._disabled_torch_function_impl

### Scaffolding 
- https://github.com/albanD/subclass_zoo/blob/main/trivial_tensors.py: the template to start from when building a new tensor subclass

In [19]:
from types import FunctionType

import torch
# from base_tensor import BaseTensor
from torch import Tensor
from torch.fx import Graph, GraphModule, Tracer
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._pytree import tree_map


class TrivialTensorViaInheritance(BaseTensor):
    """
    Extends tensor behavior through inheritance ("is a"): the subclass *is*
    the tensor, and the base behavior is reached with the ordinary
    object-oriented idiom, super().  These implementations are very
    straightforward; prefer them whenever they fit your use case.

    Trade-offs of this representation:
        + Efficient: only a single tensor exists.
        + No metadata has to be kept in sync between an inner and an outer
          tensor.
        = Composition needs multiple inheritance.  That *does* work, but it
          is mind-bending: you face the diamond inheritance problem, and
          traditionally you only get a fixed set of compositions (rather
          than dynamic ones, as in functorch) unless you are willing to
          generate classes on the fly.
        - Unsuitable when internal PyTorch subsystems (e.g., autograd) must
          run multiple times.
        - Unsuitable when the internal tensor has a different shape than the
          outer tensor.
        - Unsuitable when multiple internal tensors are needed.
    """

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Defer to the inherited __torch_dispatch__ to compute the op's outputs.
        outputs = super().__torch_dispatch__(func, types, args, kwargs)

        def rewrap(value):
            # An output can already be an instance of cls -- most commonly when
            # an input tensor is handed back as an output (in-place or out=
            # implementations).  Such values, and all non-tensors, pass through.
            needs_wrapping = isinstance(value, torch.Tensor) and not isinstance(value, cls)
            return cls(value) if needs_wrapping else value

        return tree_map(rewrap, outputs)



### Example of a tensor subclass that is defined internally and used a lot for testing: LoggingTensor
- https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/logging_tensor.py


In [90]:

from pprint import pprint

from torch.testing._internal.logging_tensor import (
    LoggingTensor,
    capture_logs,
    capture_logs_with_logging_tensor_mode,
)

# Two random matrices; their product has shape (2, 2).
x = torch.randn(2, 10)
y = torch.randn(10, 2)

# Every aten op executed while the mode is active is appended to `log` as one
# string; the output below shows `@` recorded as aten.mm and `** 2` as
# aten.pow.Tensor_Scalar.
with capture_logs_with_logging_tensor_mode() as log:
    (x @ y) ** 2
pprint(log)

['$2 = torch._ops.aten.mm.default($0, $1)',
 '$3 = torch._ops.aten.pow.Tensor_Scalar($2, 2)']


In [91]:
# Same trace on an accelerator. The MPS backend exists only on Apple hardware,
# so fall back to CPU elsewhere to keep "Restart & Run All" working.
device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.randn(2, 10, device=device)
y = torch.randn(10, 2, device=device)

# The logged aten ops are the same as on CPU; only the exponent differs (3).
with capture_logs_with_logging_tensor_mode() as log:
    c = (x @ y) ** 3
pprint(log)

['$2 = torch._ops.aten.mm.default($0, $1)',
 '$3 = torch._ops.aten.pow.Tensor_Scalar($2, 3)']


## How does logging tensor do this?

In [None]:
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
    """Excerpt of ``LoggingTensor.__torch_dispatch__``.

    Unwraps subclass arguments to their inner ``elem`` tensor, runs ``func``
    on the plain tensors inside ``cls.context()``, re-wraps every tensor
    output in ``cls``, and logs the call.

    Args:
        cls: the tensor subclass; instances expose ``.elem``, the class is
            constructible from a tensor and provides a ``context()`` method.
        func: the operator being dispatched.
        types: subclass types participating in the call (unused here).
        args: positional arguments for ``func``.
        kwargs: keyword arguments for ``func``; ``None`` is treated as ``{}``.

    Returns:
        The result of ``func`` with every tensor output wrapped in ``cls``.
    """
    # kwargs defaults to None, and `**None` raises TypeError, so normalize it.
    # The normalized dict is also what gets logged below.
    kwargs = kwargs or {}

    def unwrap(e):
        return e.elem if isinstance(e, cls) else e

    def wrap(e):
        return cls(e) if isinstance(e, torch.Tensor) else e

    with cls.context():
        # Call the op on the unwrapped tensors, then re-wrap its outputs.
        rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))

    # The added functionality: log "<module>.<name>" of the op together with
    # its args, kwargs and result as the record's extra arguments
    # (NOTE(review): presumably consumed by the LoggingTensor log handler).
    logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
    return rs

### Other cool use cases:

- FairScale uses this to offload tensors to SSD: [here](https://github.com/facebookresearch/fairscale/blob/6f18e779a794badba1fc19bb161ed4382fd337f7/fairscale/experimental/nn/ssd_offload.py)

- Flop counter that works with forward and backward: [here](https://pastebin.com/AkvAyJBw)

- An intern was able to implement a memory profiler in a week with `__torch_dispatch__`: [here](https://fb.workplace.com/notes/116830607720079/)

- Potentially adding a new device without writing any C++ code: [here](https://github.com/albanD/subclass_zoo/pull/36/files)


In [82]:
# Trace big models: log every aten op EfficientNet-B0 runs on a LoggingTensor input.
import torchvision.models as models
from torch.testing._internal.logging_tensor import log_input

eff_weights = models.EfficientNet_B0_Weights.DEFAULT
eff_preprocessor = eff_weights.transforms()

efficientnet_b0 = models.efficientnet_b0(weights=eff_weights)
# Module.eval() switches mode in place and returns the module itself,
# so effic_cpu is the same object as efficientnet_b0.
effic_cpu = efficientnet_b0.eval()

# Input for benchmarking: a batch of fake 8-bit RGB images. The preset
# transforms rescale integer images by the maximum value of their dtype, so the
# raw pixels must be uint8 (randint's default int64 would be divided by
# 2**63 - 1 and collapse to ~0).
BATCH_SIZE = 64
x_cpu = eff_preprocessor(
    torch.randint(0, 256, size=(BATCH_SIZE, 3, 224, 224), dtype=torch.uint8)
)

with capture_logs() as logs:
    x = LoggingTensor(x_cpu)
    log_input("x", x)

    efficientnet_b0(x)

# Drop everything from a weight's "Parameter containing ..." repr onwards.
# str.partition leaves entries without that marker intact, whereas
# log[:log.find(...)] would chop their last character (find returns -1).
logs = [log.partition('Parameter containing')[0] for log in logs]
print('\n'.join(logs))