In [21]:
from dataclasses import dataclass

import torch

from moshi.modules.streaming import StreamingModule, StreamingAdd, StreamingContainer

In [6]:
streaming_add = StreamingAdd()

x = torch.tensor([1, 1])
y = torch.tensor([1, 2])
out = streaming_add.forward(x, y)
print(out)

tensor([3])


In [9]:
B = 1
with streaming_add.streaming(B):
    for i in range(4):
        out = streaming_add.forward(x, y)
        print(streaming_add._streaming_state)
        print(out)

# B = 1
# with streaming_add.streaming(B):
#     for i in range(4):
#         out = streaming_add.forward(x, y)
#         print(out)


_StreamingAddState(previous_x=tensor([], dtype=torch.int64), previous_y=tensor([], dtype=torch.int64))
tensor([3])
_StreamingAddState(previous_x=tensor([], dtype=torch.int64), previous_y=tensor([], dtype=torch.int64))
tensor([3])
_StreamingAddState(previous_x=tensor([], dtype=torch.int64), previous_y=tensor([], dtype=torch.int64))
tensor([3])
_StreamingAddState(previous_x=tensor([], dtype=torch.int64), previous_y=tensor([], dtype=torch.int64))
tensor([3])


### Simple custom streaming module

In [16]:
@dataclass
class _State:
    inputs: torch.Tensor | None = None

def reset(self):
    self.inputs = None


class StreamingInputs(StreamingModule[_State]):
    def _init_streaming_state(self, batch_size: int) -> _State:
        return _State()

    def forward(self, x):
        if self._streaming_state is None:
            return x
        else:
            prev_state = self._streaming_state.inputs
            if prev_state is not None:
                x = torch.cat([prev_state, x], dim=-1)
            self._streaming_state.inputs = x
            return x

In [17]:
B = 1
streaming_inputs = StreamingInputs()
with streaming_inputs.streaming(B):
    for i in range(4):
        out = streaming_inputs.forward(x)
        print(streaming_inputs._streaming_state)
        print(out)

_State(inputs=tensor([1]))
tensor([1])
_State(inputs=tensor([1, 1]))
tensor([1, 1])
_State(inputs=tensor([1, 1, 1]))
tensor([1, 1, 1])
_State(inputs=tensor([1, 1, 1, 1]))
tensor([1, 1, 1, 1])


In [19]:
for i in range(4):
    out = streaming_inputs.forward(x)
    print(streaming_inputs._streaming_state)
    print(out)

None
tensor([1])
None
tensor([1])
None
tensor([1])
None
tensor([1])


### Streaming module with children

In [22]:
class ParentStreamingModule(StreamingContainer):
    # We use StreamingContainer s.t. we can stream the submodules recursively, even though
    # this parent class has no streaming state itself
    def __init__(self):
        super().__init__()
        # We define two streaming submodules
        self.streaming_x = StreamingInputs()
        self.streaming_y = StreamingInputs()
    

    def forward(self, x, y):
        x = self.streaming_x(x)
        y = self.streaming_y(y)
        return x + y

In [28]:
B = 1
streaming_parent = ParentStreamingModule()
for i in range(4):
    x = torch.randint(0 , 10, (B, 1))
    y = torch.randint(0 , 10, (B, 1))
    out = streaming_parent.forward(x, y)
    # Parent class has no straming state --> None
    print(streaming_inputs._streaming_state)
    # But we can get the streaming state of the submodules recursively
    print(streaming_parent.get_streaming_state())
    print(x, y, out)

None
{'': None, 'streaming_x': None, 'streaming_y': None}
tensor([[8]]) tensor([[8]]) tensor([[16]])
None
{'': None, 'streaming_x': None, 'streaming_y': None}
tensor([[7]]) tensor([[5]]) tensor([[12]])
None
{'': None, 'streaming_x': None, 'streaming_y': None}
tensor([[4]]) tensor([[5]]) tensor([[9]])
None
{'': None, 'streaming_x': None, 'streaming_y': None}
tensor([[1]]) tensor([[0]]) tensor([[1]])


In [29]:
B = 1
streaming_parent = ParentStreamingModule()
with streaming_parent.streaming(B):
    for i in range(4):
        x = torch.randint(0 , 10, (B, 1))
        y = torch.randint(0 , 10, (B, 1))
        out = streaming_parent.forward(x, y)
        # Parent class has no straming state --> None
        print(streaming_inputs._streaming_state)
        # But we can get the streaming state of the submodules recursively
        print(streaming_parent.get_streaming_state())
        print(x, y, out)

None
{'': _NullState(), 'streaming_x': _State(inputs=tensor([[3]])), 'streaming_y': _State(inputs=tensor([[5]]))}
tensor([[3]]) tensor([[5]]) tensor([[8]])
None
{'': _NullState(), 'streaming_x': _State(inputs=tensor([[3, 8]])), 'streaming_y': _State(inputs=tensor([[5, 8]]))}
tensor([[8]]) tensor([[8]]) tensor([[ 8, 16]])
None
{'': _NullState(), 'streaming_x': _State(inputs=tensor([[3, 8, 5]])), 'streaming_y': _State(inputs=tensor([[5, 8, 9]]))}
tensor([[5]]) tensor([[9]]) tensor([[ 8, 16, 14]])
None
{'': _NullState(), 'streaming_x': _State(inputs=tensor([[3, 8, 5, 6]])), 'streaming_y': _State(inputs=tensor([[5, 8, 9, 6]]))}
tensor([[6]]) tensor([[6]]) tensor([[ 8, 16, 14, 12]])
