In [303]:
import typing as tp

import torch
import torch.nn as nn
import torch.nn.functional as F

### Simulation - StreamingTransformerEncoder's forward pass

In [304]:
# Toy input laid out as (batch B, time T, channels C)
x = torch.randn((2, 5, 4))
B, T, C = x.size()

x

tensor([[[ 2.0865, -0.2938, -1.6713, -0.7352],
         [ 0.5066, -1.7875,  2.4585, -0.8830],
         [ 1.3333, -0.8979,  1.5627, -1.3295],
         [ 0.4372, -0.2338,  0.8174,  0.5218],
         [ 0.3918, -1.0793,  0.3148, -0.3560]],

        [[-0.6961, -0.1982,  0.7532,  0.6678],
         [ 1.6565,  0.4308, -0.3309, -0.0367],
         [-1.4224, -0.2807,  1.0541,  0.3142],
         [ 2.3469, -1.4884,  0.3607, -1.8405],
         [-0.1464, -1.2865,  0.7448, -1.5723]]])

In [305]:
# Slicing with a range (instead of indexing with 0) keeps the time axis: (B, 1, C)
x[:, 0:1]

tensor([[[ 2.0865, -0.2938, -1.6713, -0.7352]],

        [[-0.6961, -0.1982,  0.7532,  0.6678]]])

In [306]:
# Three zero tensors, each shaped like a single time step of x: (B, 1, C)
states = [torch.zeros_like(x[:, 0:1]) for _ in range(3)]

states

[tensor([[[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]]]),
 tensor([[[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]]]),
 tensor([[[0., 0., 0., 0.]],
 
         [[0., 0., 0., 0.]]])]

In [307]:
# Frame indices 0..T-1 shaped (1, T, 1), shifted by `offset`
offset = 0
positions = torch.arange(T)[None, :, None] + offset

positions

tensor([[[0],
         [1],
         [2],
         [3],
         [4]]])

### Function - create_sin_embedding

In [308]:
# Index of each frequency, shaped (1, 1, half_dim)
v_dim = 4
half_dim = v_dim // 2
adim = torch.arange(half_dim)[None, None, :]

adim

tensor([[[0, 1]]])

In [309]:
# Per-frequency divisor: 10000 raised to an exponent that runs from 0 to 1 across adim
phase_division = torch.pow(10000, adim / (half_dim - 1))

phase_division

tensor([[[1.0000e+00, 1.0000e+04]]])

In [310]:
# Divide by the `phase_division` computed in the previous cell;
# (1, T, 1) / (1, 1, half_dim) broadcasts to (1, T, half_dim)
phase = positions / phase_division

phase

tensor([[[0.0000e+00, 0.0000e+00],
         [1.0000e+00, 1.0000e-04],
         [2.0000e+00, 2.0000e-04],
         [3.0000e+00, 3.0000e-04],
         [4.0000e+00, 4.0000e-04]]])

In [311]:
# Cosine half of the embedding
phase_cos = phase.cos()

phase_cos

tensor([[[ 1.0000,  1.0000],
         [ 0.5403,  1.0000],
         [-0.4161,  1.0000],
         [-0.9900,  1.0000],
         [-0.6536,  1.0000]]])

In [312]:
# Sine half of the embedding
phase_sin = phase.sin()

phase_sin

tensor([[[ 0.0000e+00,  0.0000e+00],
         [ 8.4147e-01,  1.0000e-04],
         [ 9.0930e-01,  2.0000e-04],
         [ 1.4112e-01,  3.0000e-04],
         [-7.5680e-01,  4.0000e-04]]])

In [313]:
# dim=-1 puts cos and sin side by side on the feature axis -> (1, T, 2 * half_dim)
phase_concat = torch.cat((phase_cos, phase_sin), dim=-1)
# dim=1 stacks them along the time axis instead -> (1, 2 * T, half_dim)
phase_concat_opposite = torch.cat((phase_cos, phase_sin), dim=1)

phase_concat

tensor([[[ 1.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 5.4030e-01,  1.0000e+00,  8.4147e-01,  1.0000e-04],
         [-4.1615e-01,  1.0000e+00,  9.0930e-01,  2.0000e-04],
         [-9.8999e-01,  1.0000e+00,  1.4112e-01,  3.0000e-04],
         [-6.5364e-01,  1.0000e+00, -7.5680e-01,  4.0000e-04]]])

In [314]:
def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000):
    """Create sinusoidal embeddings of width `dim` for the given positions.

    For frequency index ``i`` in ``0 .. dim // 2 - 1`` the phase is
    ``positions / max_period ** (i / (dim // 2 - 1))``. The result is the
    cosines of all phases followed by the sines, concatenated on the last axis.

    Args:
        positions: Tensor of positions. It is divided by a ``(1, 1, dim // 2)``
            tensor, so it must broadcast against that shape (this notebook
            passes ``(1, T, 1)``).
        dim: Width of the embedding; must be even.
        max_period: Base of the geometric progression of divisors.

    Returns:
        Tensor whose last axis has size `dim`: ``[cos(phase), sin(phase)]``.
    """
    assert dim % 2 == 0
    half_dim = dim // 2
    adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1)
    # With dim == 2 there is a single frequency and `half_dim - 1` is 0, so the
    # exponent used to be 0 / 0 = NaN. Clamping the divisor to 1 gives exponent 0
    # (divisor max_period ** 0 == 1) and leaves every other `dim` unchanged.
    exponent_scale = max(half_dim - 1, 1)
    phase = positions / (max_period ** (adim / exponent_scale))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)

In [315]:
# Same embedding as the step-by-step cells, now produced by the function
pos_emb = create_sin_embedding(positions, C, max_period=10000)

pos_emb

tensor([[[ 1.0000e+00,  1.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 5.4030e-01,  1.0000e+00,  8.4147e-01,  1.0000e-04],
         [-4.1615e-01,  1.0000e+00,  9.0930e-01,  2.0000e-04],
         [-9.8999e-01,  1.0000e+00,  1.4112e-01,  3.0000e-04],
         [-6.5364e-01,  1.0000e+00, -7.5680e-01,  4.0000e-04]]])

In [316]:
# One embedding row per position, C features each
pos_emb.size()

torch.Size([1, 5, 4])

### Simulation - StreamingTransformerEncoderLayer

In [317]:
# Positions of 4 new query frames that follow 2 cached frames, as a column: (4, 1)
v_queries_pos = torch.arange(2, 2 + 4, device=x.device)[:, None]

v_queries_pos

tensor([[2],
        [3],
        [4],
        [5]])

In [318]:
# Positions of all 2 + 4 key frames (cached followed by new), as a row: (1, 6)
v_keys_pos = torch.arange(2 + 4, device=x.device)[None, :]

v_keys_pos

tensor([[0, 1, 2, 3, 4, 5]])

### Class - StreamingTransformerEncoderLayer

In [319]:
class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer):
    """Encoder layer whose self-attention also sees frames cached from earlier calls.

    Queries come from the current chunk only; keys and values are the cached
    frames followed by the current chunk. Each query may attend to itself and
    to at most `past_context` frames before it.
    """
    def forward(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int):
        """Apply self-attention and feed-forward to `x`.

        Args:
            x: Current chunk; unpacked as three dimensions with time on axis 1.
            x_past: Cached frames, concatenated before `x` along axis 1.
            past_context: Largest allowed distance (in frames) between a query
                and a key it may attend to.

        Returns:
            Tuple ``(output, sa_input)`` where `sa_input` is what was fed to
            self-attention: ``norm1(x)`` when `norm_first` is set, else `x`.
        """
        if not self.norm_first:
            # Post-norm: normalize after each residual connection.
            sa_input = x
            attended = self.norm1(x + self._sa_block(sa_input, x_past, past_context))
            out = self.norm2(attended + self._ff_block(attended))
            return out, sa_input

        # Pre-norm: normalize the input of each sub-block.
        sa_input = self.norm1(x)
        attended = x + self._sa_block(sa_input, x_past, past_context)
        out = attended + self._ff_block(self.norm2(attended))
        return out, sa_input

    def _sa_block(self, x: torch.Tensor, x_past: torch.Tensor, past_context: int):
        """Masked self-attention of `x` over `x_past` followed by `x`, then dropout."""
        _, num_new, _ = x.shape
        _, num_past, _ = x_past.shape
        total = num_past + num_new

        # Keys and values: cached frames first, current chunk after them.
        context = torch.cat((x_past, x), dim=1)

        key_idx = torch.arange(total, device=x.device).view(1, -1)
        query_idx = torch.arange(num_past, total, device=x.device).view(-1, 1)
        lag = query_idx - key_idx
        # True marks a key the query must not see: a later frame (lag < 0)
        # or one further back than `past_context`.
        blocked = (lag < 0) | (lag > past_context)

        attended = self.self_attn(
            x, context, context, attn_mask=blocked, need_weights=False
        )[0]
        return self.dropout1(attended)

### Class - StreamingTransformerEncoder

In [320]:
class StreamingTransformerEncoder(nn.Module):
    """Stack of `StreamingTransformerEncoderLayer`s with sinusoidal position
    embeddings and one cache of past frames per layer.

    Args:
        dim: Channel width of the input; must be divisible by `num_heads`.
        hidden_scale: The feed-forward width of each layer is `dim * hidden_scale`.
        num_heads: Number of attention heads per layer.
        num_layers: Number of stacked layers.
        max_period: Passed to `create_sin_embedding`.
        past_context: Largest distance (in frames) a query may look back, and
            the maximum number of frames kept in each layer's cache.
        gelu: Use `F.gelu` as activation when true, `F.relu` otherwise.
        norm_in: Apply `nn.LayerNorm(dim)` to the input when true, otherwise
            pass it through unchanged.
        dropout: Dropout probability handed to every layer.
        **kwargs: Forwarded to `StreamingTransformerEncoderLayer`.
    """
    def __init__(
            self,
            dim: int,
            hidden_scale: int = 4,
            num_heads: int = 8,
            num_layers: int = 5,
            max_period: int = 10000,
            past_context: int = 1000,
            gelu: bool = True,
            norm_in: bool = True,
            dropout: float = 0.,
            **kwargs
            ):
        super().__init__()
        assert dim % num_heads == 0
        hidden_dim = dim * hidden_scale

        self.max_period = max_period
        self.past_context = past_context
        activation: tp.Any = F.gelu if gelu else F.relu

        self.norm_in: nn.Module
        if norm_in:
            self.norm_in = nn.LayerNorm(dim)
        else:
            self.norm_in = nn.Identity()

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                StreamingTransformerEncoderLayer(
                    dim,
                    num_heads,
                    hidden_dim,
                    activation=activation,
                    batch_first=True,
                    dropout=dropout,
                    **kwargs
                    )
            )

    def forward(
            self,
            x: torch.Tensor,
            states: tp.Optional[tp.List[torch.Tensor]] = None,
            offset: tp.Union[int, torch.Tensor] = 0
            ):
        """Encode one chunk, optionally continuing from a previous call.

        Args:
            x: Input of shape (B, C, T).
            states: One cached tensor per layer as returned by a previous call,
                or None to start from a single all-zero frame per layer.
            offset: Position of the first frame of `x`; added to the frame
                indices before the position embedding is computed.

        Returns:
            Tuple ``(x, new_states, offset + T)``. `x` is returned as
            (B, T, C) — it is NOT permuted back to (B, C, T). `new_states`
            holds one tensor per layer.

        Raises:
            ValueError: If `states` has fewer entries than there are layers.
        """
        _, C, T = x.shape
        positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset
        pos_emb = create_sin_embedding(positions, C, self.max_period)

        x = x.permute(0, 2, 1)
        x = self.norm_in(x)
        # Out-of-place add. When `norm_in` is nn.Identity, `x` is still a permuted
        # view of the caller's tensor, so `x += pos_emb` would overwrite the
        # caller's data. The cast keeps x's dtype, as the in-place add did.
        x = x + pos_emb.to(x.dtype)

        num_layers = len(self.layers)
        if states is None:
            states = [torch.zeros_like(x[:, :1]) for _ in range(num_layers)]
        elif len(states) < num_layers:
            # zip() below would otherwise silently skip the layers without a state.
            raise ValueError(
                f"expected at least {num_layers} states, got {len(states)}"
            )

        new_states: tp.List[torch.Tensor] = []
        for state, layer in zip(states, self.layers):
            x, new_layer = layer(x, state, self.past_context)
            new_layer = torch.cat([state, new_layer], dim=1)
            if self.past_context > 0:
                kept = new_layer[:, -self.past_context:, :]
            else:
                # `[-0:]` is `[0:]` and would keep every frame, so the cache
                # would grow on each call; with no look-back nothing is needed.
                kept = new_layer[:, :0, :]
            new_states.append(kept)
        return x, new_states, offset + T

### Test with dummy data

In [321]:
model = StreamingTransformerEncoder(16, num_heads=4, num_layers=3)

# Random stand-in for audio features, ordered (B, C, T) as the encoder's forward unpacks it
simulated_audio_data = torch.randn((2, 16, 10))

print(f"simulated_audio_data:\n{simulated_audio_data.shape}")
simulated_audio_data

simulated_audio_data:
torch.Size([2, 16, 10])


tensor([[[ 1.3661,  0.0818, -0.8510,  1.3325, -0.5567,  1.2054,  0.1599,
           1.2376,  1.3392, -0.5552],
         [-0.6186, -1.8180,  0.6838,  0.8326, -1.7848, -0.0643,  0.2513,
          -0.8455,  0.3437, -0.2087],
         [ 0.6285, -0.7177,  0.4643, -1.7050, -0.5364, -2.0309, -0.4000,
           0.0221,  0.8664,  1.0746],
         [-0.3609,  0.6172,  1.0386,  2.1425,  1.9318, -0.6743,  0.4873,
           0.8535,  0.8368,  0.0124],
         [-0.5369,  1.0483, -0.6455, -1.5342, -0.3790,  0.2303,  0.4899,
           0.8170, -0.3221, -0.2212],
         [-1.4927,  0.2404, -1.8879,  0.7581, -0.3032, -1.6557, -0.9824,
           1.7419, -0.1908,  0.4613],
         [-0.4685, -0.2954, -0.3005,  0.7489,  0.0460, -0.6783, -0.5357,
          -1.6337, -0.1599, -0.2351],
         [-0.3424,  1.2997, -0.9775, -1.3538, -0.5587,  0.4933,  1.5763,
          -0.9482, -1.4368, -0.7848],
         [-0.2884,  1.1783, -0.0507,  0.6311, -3.2857, -0.7116, -0.7674,
           0.5432,  0.3099,  0.4940],
 

In [322]:
output, new_states, new_offset = model(simulated_audio_data)
# The encoder returns (B, T, C); swap the last two axes to get back to (B, C, T)
output = output.transpose(1, 2)

print(f"output:\n{output.shape}")
output

output:
torch.Size([2, 16, 10])


tensor([[[ 2.7096e+00,  2.2952e-01, -1.1216e+00,  6.7533e-01, -4.3854e-01,
           1.3946e+00,  7.4558e-01,  1.7598e+00,  1.3663e+00, -1.6166e+00],
         [ 1.4923e-01, -1.2565e+00,  8.1344e-01,  4.6468e-01, -1.7833e+00,
          -6.0504e-01, -8.4396e-01, -1.9806e+00, -1.2982e+00, -1.2926e+00],
         [ 1.0823e+00, -2.1431e-01,  1.1000e+00, -2.3361e-01,  6.6281e-02,
          -4.8274e-01, -4.5998e-01,  1.0350e-01,  9.9526e-01,  1.2699e+00],
         [-4.1495e-01,  1.0108e+00,  1.5395e+00,  1.3699e+00,  1.4542e+00,
           3.6927e-01,  7.7336e-01,  3.8977e-01,  3.4878e-01,  7.5649e-01],
         [-6.1157e-01,  1.0671e+00, -6.6021e-01, -1.5160e+00, -5.2035e-02,
           4.9523e-01,  4.7828e-01,  9.6017e-01, -9.3844e-01, -4.0545e-02],
         [-1.4059e+00,  2.3096e-02, -1.7226e+00,  4.5063e-01, -2.4989e-02,
          -8.2499e-01, -1.6819e+00,  1.4045e+00, -6.0121e-01,  8.9652e-01],
         [ 3.8415e-01, -1.2221e-01,  4.1802e-01,  1.0778e+00,  2.9510e-01,
          -1.8279e-

In [323]:
# Show how many per-layer states were returned and the returned offset (input offset + T)
print(f"\n\nnew_states:\n{len(new_states)}")
print(f"\n\nnew_offset:\n{new_offset}")



new_states:
3


new_offset:
10
