In [17]:
import sys

class ABC:
    a: int
    b: int = 1
    pass

cls = ABC
print (cls.__module__)

__main__


In [106]:
from typing import Tuple, List, Union
from torchtyping import TensorType
import torch
import numpy as np

_IDENTIFER = "__tensordataclass__"

def is_tensordataclass(x):
    return hasattr(x.__class__, _IDENTIFER)   

def tensordataclass(cls):

    FIELD_KEYS = list(cls.__annotations__.keys())
    FIELD_TYPES = list(cls.__annotations__.values())

    def __init__(self, *args, **kwargs) -> None:
        # Quietly pop out and save these arguments from kwargs.
        self._shape: Tuple[int] = kwargs.pop("_shape", None)
        self._packed_info: TensorType["num_chunks", 3] = kwargs.pop("_packed_info", None)

        # check the remaining arguments
        n_params = len(args) + len(kwargs.values())
        n_params_max = len(FIELD_KEYS)
        assert n_params <= n_params_max, (
            f"__init__() takes from 1 to {n_params_max + 1} positional arguments"
            f" but {n_params + 1} were given"
        )

        # collect attributes
        init_params = {}
        for i, value in enumerate(args):
            key = FIELD_KEYS[i]
            init_params.update({key: value})
        for key, value in kwargs.items():
            assert key not in init_params, (
                f"__init__() got multiple values for argument '{key}'"
            )
            assert key in FIELD_KEYS, (
                f"__init__() got an unexpected keyword argument '{key}'"
            )
            init_params.update({key: value})
        for key in FIELD_KEYS:
            if key in init_params:
                continue
            assert hasattr(self, key), (
                f"__init__() missing 1 required positional argument: '{key}'"
            )
            value = getattr(self, key)
            init_params.update({key: value})

        # set attributes
        for i, (key, value) in enumerate(init_params.items()):
            assert value is None or isinstance(value, FIELD_TYPES[i]), (
                f"__init__() got a wrong type {type(value)} v.s. {FIELD_TYPES[i]}"
                f" for argument `{key}`"
            )
            setattr(self, key, value)
        
        self.__post_init__()
        
    def __post_init__(self) -> None:
        if self.is_packed():
            # Do nothing for a packed tensor
            return
        
        batch_shapes = []
        for f in FIELD_KEYS:
            v = self.__getattribute__(f)
            if v is not None:
                if isinstance(v, torch.Tensor):
                    batch_shapes.append(v.shape[:-1])
                elif is_tensordataclass(v):
                    batch_shapes.append(v.shape)
        if len(batch_shapes) == 0:
            raise ValueError("TensorDataclass must have at least one tensor")
        batch_shape = torch.broadcast_shapes(*batch_shapes)

        for f in FIELD_KEYS:
            v = self.__getattribute__(f)
            if v is not None:
                if isinstance(v, torch.Tensor):
                    self.__setattr__(f, v.broadcast_to((*batch_shape, v.shape[-1])))
                elif is_tensordataclass(v):
                    self.__setattr__(f, v.broadcast_to(batch_shape))

        self._shape = batch_shape

    def __getitem__(self, indices) -> cls:
        if self.is_packed():
            raise IndexError("Packed TensorDataClass can not be indexed!")
        if isinstance(indices, torch.Tensor):
            return self._apply_fn_to_fields(lambda x: x[indices])
        if isinstance(indices, (int, slice)):
            indices = (indices,)
        tensor_fn = lambda x: x[indices + (slice(None),)]
        dataclass_fn = lambda x: x[indices]
        return self._apply_fn_to_fields(tensor_fn, dataclass_fn)

    def __setitem__(self, indices, value) -> cls:
        raise RuntimeError("Index assignment is not supported for TensorDataclass")

    def __len__(self) -> int:
        return self.shape[0]

    def __bool__(self) -> bool:
        if len(self) == 0:
            raise ValueError(
                f"The truth value of {self.__class__.__name__} when `len(x) == 0` "
                "is ambiguous. Use `len(x)` or `x is not None`."
            )
        return True

    @property
    def shape(self) -> tuple:
        """Returns the batch shape of the tensor dataclass."""
        if self._shape is None:
            raise RuntimeError("Packed TensorDataClass does not have a shape defined!")
        else:
            return self._shape
    
    @property
    def size(self) -> int:
        """Returns the number of elements in the tensor dataclass batch dimension."""
        if len(self.shape) == 0:
            return 1
        return int(np.prod(self.shape))

    @property
    def ndim(self) -> int:
        """Returns the number of dimensions of the tensor dataclass."""
        return len(self.shape)

    def reshape(self, shape: Tuple[int, ...]) -> cls:
        """Returns a new TensorDataclass with the same data but with a new shape.

        Args:
            shape (Tuple[int]): The new shape of the tensor dataclass.

        Returns:
            TensorDataclass: A new TensorDataclass with the same data but with a new shape.
        """
        if self.is_packed():
            raise RuntimeError("Packed TensorDataclass can not be reshaped!")
        if isinstance(shape, int):
            shape = (shape,)
        tensor_fn = lambda x: x.reshape((*shape, x.shape[-1]))
        dataclass_fn = lambda x: x.reshape(shape)
        return self._apply_fn_to_fields(tensor_fn, dataclass_fn)

    def flatten(self) -> cls:
        """Returns a new TensorDataclass with flattened batch dimensions

        Returns:
            TensorDataclass: A new TensorDataclass with the same data but with a new shape.
        """
        if self.is_packed():
            raise RuntimeError("Packed TensorDataclass can not be flatten!")
        return self.reshape((-1,))

    def broadcast_to(self, shape: Union[torch.Size, Tuple[int]]) -> "TensorDataclass":
        """Returns a new TensorDataclass broadcast to new shape.

        Args:
            shape (Tuple[int]): The new shape of the tensor dataclass.

        Returns:
            TensorDataclass: A new TensorDataclass with the same data but with a new shape.
        """
        if self.is_packed():
            raise RuntimeError("Packed TensorDataclass can not be broadcasted!")
        return self._apply_fn_to_fields(lambda x: x.broadcast_to((*shape, x.shape[-1])))

    def to(self, device) -> cls:
        """Returns a new TensorDataclass with the same data but on the specified device.

        Args:
            device: The device to place the tensor dataclass.

        Returns:
            TensorDataclass: A new TensorDataclass with the same data but on the specified device.
        """
        return self._apply_fn_to_fields(lambda x: x.to(device))

    def _apply_fn_to_fields(
        self,
        fn: callable,
        dataclass_fn: callable = None,
        exclude_fields: List[str] = None,
        **kwargs,
    ) -> cls:
        """Applies a function to all fields of the tensor dataclass.

        Args:
            fn (callable): The function to apply to tensor fields.
            dataclass_fn (callable): The function to apply to TensorDataclass fields. Else use fn.
            exclude_fields (List[str]): The fields to be excluded from calling fn and dataclass_fn.
            **kwargs: Additional arguments to initialize the new TensorDataclass.

        Returns:
            cls: A new class with the same data but with a new shape.
        """
        if exclude_fields is None:
            exclude_fields = []

        field_names = [f for f in FIELD_KEYS if f not in exclude_fields]
        new_fields = {}
        for f in field_names:
            v = self.__getattribute__(f)
            if v is not None:
                if is_tensordataclass(v) and dataclass_fn is not None:
                    new_fields[f] = dataclass_fn(v)
                elif isinstance(v, torch.Tensor) or is_tensordataclass(v):
                    new_fields[f] = fn(v)
        new_fields.update(kwargs)

        return cls(**new_fields)

    def is_packed(self) -> bool:
        """Returns whether the data are packed."""
        return self._packed_info is not None

    setattr(cls, _IDENTIFER, {})
    setattr(cls, "__init__", __init__)
    setattr(cls, "__post_init__", __post_init__)
    setattr(cls, "__getitem__", __getitem__)
    setattr(cls, "__setitem__", __setitem__)
    setattr(cls, "__len__", __len__)
    setattr(cls, "__bool__", __bool__)
    setattr(cls, "shape", shape)
    setattr(cls, "size", size)
    setattr(cls, "ndim", ndim)
    setattr(cls, "reshape", reshape)
    setattr(cls, "flatten", flatten)
    setattr(cls, "broadcast_to", broadcast_to)
    setattr(cls, "to", to)
    setattr(cls, "_apply_fn_to_fields", _apply_fn_to_fields)
    setattr(cls, "is_packed", is_packed)
    return cls

@tensordataclass
class TestNestedClass():
    """Dummy dataclass"""

    x: torch.Tensor


@tensordataclass
class TestTensorDataclass():
    """Dummy dataclass"""

    a: torch.Tensor
    b: torch.Tensor
    c: TestNestedClass = None

a = torch.ones((4, 6, 3))
b = torch.ones((6, 2))
c = TestNestedClass(x=torch.ones(6, 5))
tensor_dataclass = TestTensorDataclass(a=a, b=b, c=c)
print (is_tensordataclass(tensor_dataclass))
print (is_tensordataclass(tensor_dataclass.c))
out = tensor_dataclass.to("cuda:0")
print (out.a.device)
print (out.c.x.device)

        

AttributeError: 'TestNestedClass' object has no attribute 'shape'

In [18]:
print (cls.__annotations__)

{'a': <class 'int'>, 'b': <class 'int'>}


In [23]:
print (getattr(cls, "a", None))

None


In [None]:
class _MISSING_TYPE:
    pass


MISSING = _MISSING_TYPE()

# pyline: disable=redefined-builtin
def _create_fn(name, args, body, *, globals=None, locals=None, return_type=MISSING):
    # Note that we mutate locals when exec() is called.  Caller
    # beware!  The only callers are internal to this module, so no
    # worries about external callers.
    if locals is None:
        locals = {}
    return_annotation = ""
    if return_type is not MISSING:
        locals["_return_type"] = return_type
        return_annotation = "->_return_type"
    args = ",".join(args)
    body = "\n".join(f"  {b}" for b in body)

    # Compute the text of the entire function.
    txt = f" def {name}({args}){return_annotation}:\n{body}"

    local_vars = ", ".join(locals.keys())
    txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"

    ns = {}
    exec(txt, globals, ns)  # pyline: disable=exec-used
    return ns["__create_fn__"](**locals)

locals = {[name, type, ] for name, type in cls.__annotations__.items()}

getattr(cls, "a", None)

_create_fn('__init__',
                    [self_name] + [_init_param(f) for f in fields if f.init],
                    body_lines,
                    locals=locals,
                    globals=globals,
                    return_type=None)

In [24]:
__create_fn__

NameError: name '__create_fn__' is not defined

In [None]:


def tensordataclass(cls):

    def __init__(self, *args, **kwargs):
        
        

    run = getattr(cls, 'run')
    def new_run(self):
        print('before')
        run(self)
        print('after')
    setattr(cls, 'run', new_run)
    return cls


class Task(object): pass

@decorate_run
class MyTask(Task):
    def run(self):
        pass
