In [85]:
import torch
from dataclasses import dataclass
from typing import List, Union


@dataclass
class MultiModeState:
    """time-dependent hybrid state (t, x, k, mask)"""

    time: torch.Tensor = None
    continuous: torch.Tensor = None
    discrete: torch.Tensor = None
    mask: torch.Tensor = None

    def __post_init__(self):
        self._set_mode_types()
        # self._reshape_time(self)

    def reshape_time_like(self, state: Union[torch.Tensor, MultiModeState]):
        self.time = self.time.reshape(-1, *([1] * (state.ndim - 1)))

    def to(self, device: str):
        return self._apply_fn(lambda tensor: tensor.to(device))
    def detach(self):
        return self._apply_fn(lambda tensor: tensor.detach())
    def cpu(self):
        return self._apply_fn(lambda tensor: tensor.cpu())
    def clone(self):
        return self._apply_fn(lambda tensor: tensor.clone())

    def apply_mask(self):
        self.time = self.time * self.mask
        self.continuous = self.continuous * self.mask
        self.discrete = (self.discrete * self.mask).long()

    @property
    def ndim(self):
        if not len(self.available_modes()):
            return 0
        return len(getattr(self, self.available_modes()[-1]).shape)

    @property
    def shape(self):
        if self.ndim > 0:
            return getattr(self, self.available_modes()[-1]).shape[:-1]
        return None

    def __len__(self):
        if self.ndim > 0:
            return len(getattr(self, self.available_modes()[-1]))
        return 0
        
    @staticmethod
    def cat(states: List["MultiModeState"], dim=0) -> "MultiModeState":
        def cat_attr(attr_name):
            attrs = [getattr(s, attr_name, None) for s in states]
            if all(a is None for a in attrs):
                return None
            attrs = [a for a in attrs if a is not None]
            return torch.cat(attrs, dim=dim)

        return MultiModeState(
            time=cat_attr("time"),
            continuous=cat_attr("continuous"),
            discrete=cat_attr("discrete"),
            mask=cat_attr("mask"),
        )

    def available_modes(self,) -> List[str]:
        """Return a list of non-None modes in the state."""
        available_modes = []
        if self.time is not None:
            available_modes.append("time")
        if self.continuous is not None:
            available_modes.append("continuous")
        if self.discrete is not None:
            available_modes.append("discrete")
        return available_modes

    def _apply_fn(self, func):
        """apply transformation function to all attributes."""
        return MultiModeState(
            time=func(self.time) if isinstance(self.time, torch.Tensor) else None,
            continuous=func(self.continuous)
            if isinstance(self.continuous, torch.Tensor)
            else None,
            discrete=func(self.discrete)
            if isinstance(self.discrete, torch.Tensor)
            else None,
            mask=func(self.mask) if isinstance(self.mask, torch.Tensor) else None,
        )

    def _set_mode_types(self):
        if self.time is not None:
            self.time = self.time.float()
        if self.continuous is not None:
            self.continuous = self.continuous.float()
        if self.discrete is not None:
            self.discrete = self.discrete.long()
        if self.mask is not None:
            self.mask = self.mask.float()


In [86]:

state = MultiModeState()

print(state.available_modes())
print(state.ndim)
print(state.shape)
print(len(state))



[]
0
None
0


In [87]:
state_ = MultiModeState(

time=torch.randn(100, 128, 3),
continuous=torch.randn(100, 128, 3),
discrete=torch.randint(0, 8, (100, 128, 2)),
mask=torch.cat([torch.ones(100, 64, 1), torch.zeros(100, 64, 1)], dim=1), 
)

In [88]:
state = MultiModeState(

time=torch.randn(100),
continuous=torch.randn(100, 128, 3),
discrete=torch.randint(0, 8, (100, 128, 2)),
mask=torch.cat([torch.ones(100, 64, 1), torch.zeros(100, 64, 1)], dim=1), 
)

In [90]:
state.reshape_time_like(state_)
state.time.shape

torch.Size([100, 1, 1])

In [5]:
voc = [8,3]
k = []
torch.arange(0, 7) 

tensor([0, 1, 2, 3, 4, 5, 6])

In [6]:
x = torch.tensor([[
                   [1, 2, 3], 
                   [4, 5, 6]],

                  [[7, 8, 9], 
                  [10, 11, 12]]
                  ])  
                  
                  # Shape (2, 2, 3)

y = torch.tensor([[[2], 
                   [5]], 
                  [[8], 
                   [11]]])  # Shape (2, 2, 1)


In [7]:
x==y

tensor([[[False,  True, False],
         [False,  True, False]],

        [[False,  True, False],
         [False,  True, False]]])

In [2]:
import torch 

k = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.float32)

# Add singleton dimensions (unsqueeze) for the two new dimensions
# k = k.unsqueeze(0).unsqueeze(0)

# Expand along the new dimensions
result = k.expand(5, 6, -1)  # Expands to (5, 6, 8)

print(result)

tensor([[[0., 1., 2., 3., 4., 5., 6., 7.],
         [0., 1., 2., 3., 4., 5., 6., 7.],
         [0., 1., 2., 3., 4., 5., 6., 7.],
         [0., 1., 2., 3., 4., 5., 6., 7.],
         [0., 1., 2., 3., 4., 5., 6., 7.],
         [0., 1., 2., 3., 4., 5., 6., 7.]],

        [[0., 1., 2., 3., 4., 5., 6., 7.],
         [0., 1., 2., 3., 4., 5., 6., 7.],
         [0., 1., 2., 3., 4., 5., 6., 7.],
         [0., 1., 2., 3., 4., 5., 6., 7.],
         [0., 1., 2., 3., 4., 5., 6., 7.],
         [0., 1., 2., 3., 4., 5., 6., 7.]],

        [[0., 1., 2., 3., 4., 5., 6., 7.],
         [0., 1., 2., 3., 4., 5., 6., 7.],
         [0., 1., 2., 3., 4., 5., 6., 7.],
         [0., 1., 2., 3., 4., 5., 6., 7.],
         [0., 1., 2., 3., 4., 5., 6., 7.],
         [0., 1., 2., 3., 4., 5., 6., 7.]],

        [[0., 1., 2., 3., 4., 5., 6., 7.],
         [0., 1., 2., 3., 4., 5., 6., 7.],
         [0., 1., 2., 3., 4., 5., 6., 7.],
         [0., 1., 2., 3., 4., 5., 6., 7.],
         [0., 1., 2., 3., 4., 5., 6., 7.],
     

In [96]:
a = (*tuple(state.shape), -1)
a

(100, 128, -1)

In [68]:
print(state.available_modes())
print(state.ndim)
print(state.shape)
print(len(state))

['time', 'continuous', 'discrete']
3
torch.Size([100, 128])
100


In [69]:
state.time.shape

torch.Size([100])

In [59]:
state.mask

tensor([[[1.],
         [1.],
         [1.],
         ...,
         [0.],
         [0.],
         [0.]],

        [[1.],
         [1.],
         [1.],
         ...,
         [0.],
         [0.],
         [0.]],

        [[1.],
         [1.],
         [1.],
         ...,
         [0.],
         [0.],
         [0.]],

        ...,

        [[1.],
         [1.],
         [1.],
         ...,
         [0.],
         [0.],
         [0.]],

        [[1.],
         [1.],
         [1.],
         ...,
         [0.],
         [0.],
         [0.]],

        [[1.],
         [1.],
         [1.],
         ...,
         [0.],
         [0.],
         [0.]]])

In [61]:
state.apply_mask()

In [62]:
state.continuous[0]

tensor([[-1.6418e+00, -1.3978e-02,  4.8699e-01],
        [-1.9695e-01, -2.5602e-01, -2.1607e+00],
        [ 7.9536e-03, -7.3961e-01,  9.6487e-01],
        [-9.6108e-01, -2.2165e+00, -1.3605e+00],
        [ 3.6764e-01,  1.7479e+00, -3.6501e-02],
        [ 3.2490e-01, -1.0720e-02,  7.8410e-01],
        [ 1.4839e-01, -1.6486e-01, -7.4412e-01],
        [-5.1665e-02,  1.6175e-02,  1.4235e+00],
        [ 1.1597e-02,  6.6513e-01,  4.2468e-01],
        [ 1.6826e+00, -6.0148e-01, -1.6948e-01],
        [ 4.0913e-01, -2.7351e-01, -8.5302e-02],
        [-7.4374e-01,  6.9258e-01,  8.2209e-01],
        [ 8.6051e-02,  3.5968e-01,  7.0999e-01],
        [ 4.9639e-01,  5.7044e-01, -7.0965e-02],
        [-1.5045e+00, -2.7658e-03, -1.7672e-02],
        [ 9.5448e-02,  8.6752e-01,  1.1721e+00],
        [ 6.3235e-01, -8.2718e-02, -2.4021e+00],
        [ 8.3112e-01,  1.4420e+00,  2.0850e-01],
        [ 1.9707e+00,  5.8688e-01,  4.2737e-01],
        [-1.0123e+00, -7.9161e-01,  6.2866e-01],
        [ 3.2591e-01

In [19]:
print(state.available_modes())
print(state.ndim)
print(state.shape)
print(len(state))

['time', 'continuous', 'discrete']
3
torch.Size([100, 128])
100


In [None]:
reshape_time_like