# Week 12: ASR Inference

In this seminar we are going to implement chunked streaming for the encoder of a Conformer model.

Some useful links (it is not *necessary* to read them right now &mdash; but they are here if you need them):

* [Conformer paper](https://arxiv.org/pdf/2005.08100)
* [Chunked streaming paper](https://arxiv.org/pdf/2312.17279)
* [NeMo repository](https://github.com/NVIDIA/NeMo) &mdash; source of model weights and basic idea for our seminar
* [Google streaming kws](https://github.com/google-research/google-research/tree/master/kws_streaming) &mdash; Implementation of streaming architectures for keyword spotting tasks by Google


## 0. Preparation.

First, we install the necessary libraries and import them.


In [None]:
# Note: almost any version of PyTorch will do. A CPU is enough for this seminar.

!pip install librosa==0.10.1
!pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cpu

In [None]:
import json
import librosa
import math
import numpy as np
import os
import pickle
import queue
import requests
import torch
import torch.nn
import torch.nn.functional as F
import wave

from IPython.display import Audio
from typing import Callable, Optional
from urllib.parse import urlencode

In [None]:
# this may speed up cpu-inference
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

## 1. Basic streaming framework

How are we going to implement streaming? Layer-by-layer.

We will add a couple of methods to each layer that needs streaming:

1) `streaming_forward`: takes an additional argument, `state`, which is passed from call to call (in a sort of autoregressive manner). Depending on the layer, `state` can be anything: a tensor, a list of tensors or some more complicated structure.

2) `get_initial_state`: returns the state that we pass to the first invocation of `streaming_forward`.

So, overall, streaming layer (and model) usage will look something like this:

```python
state = model.get_initial_state()
for chunk in chunk_iterator:
    output, state = model.streaming_forward(chunk, state)
    process_output(output)
```

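To make the protocol concrete, here is a toy example in NumPy that is not part of the seminar code: a hypothetical `StreamingCumSum` layer (a running sum over time) whose entire state is the last running total. If we feed it chunk by chunk and concatenate the outputs, we should get the regular forward pass back.

```python
import numpy as np


class StreamingCumSum:
    """Toy layer (hypothetical): y[..., t] = x[..., 0] + ... + x[..., t] along the time axis."""

    def forward(self, x):
        # Regular (non-streaming) forward over the whole sequence.
        return np.cumsum(x, axis=-1)

    def get_initial_state(self):
        # Nothing has been summed yet; a scalar zero broadcasts over channels.
        return 0.0

    def streaming_forward(self, chunk, state):
        # Continue the running sum from where the previous chunk stopped.
        out = state + np.cumsum(chunk, axis=-1)
        # Keep a size-1 time axis so the state broadcasts against the next chunk.
        return out, out[..., -1:]


layer = StreamingCumSum()
x = np.random.default_rng(0).standard_normal((3, 12))  # [channels, time]

state = layer.get_initial_state()
outputs = []
for start in range(0, x.shape[-1], 4):
    y, state = layer.streaming_forward(x[..., start:start + 4], state)
    outputs.append(y)
streamed = np.concatenate(outputs, axis=-1)
print(np.allclose(streamed, layer.forward(x)))
```
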
#### Simplification:
For now, we implement streaming only for `batch_size = 1`.



## 2. First example

Let's look at an example and implement streaming for 1D causal convolution.

A causal convolution is basically a convolution with $left\_padding = kernel\_size - 1$ and $right\_padding = 0$.

<img src="./images/CausalConv.png" style="margin-left:auto; margin-right:auto; height: 200px; width: auto" />

In order to produce an output at frame $i$, we need not only the current frame $x_i$, but also the $kernel\_size - 1$ previous frames:
$x_{i - kernel\_size + 1}, \ldots, x_{i - 1}$

(if $i < kernel\_size - 1$, some of these frames do not exist and are set to zero, which exactly emulates left padding)

So let's store these previous $kernel\_size - 1$ frames in our state! The initial state is then just $kernel\_size - 1$ frames of zeros.

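It can help to see the state mechanics in isolation before writing the PyTorch version. Below is a framework-free NumPy sketch. The helper names `conv1d_valid`, `causal_conv1d` and `causal_conv1d_step` are hypothetical and not part of the seminar code. Each chunk is prepended with the cached frames and convolved without padding. The last $kernel\_size - 1$ frames of the concatenation then become the new state. The sketch also covers stride > 1, assuming that the chunk length is divisible by the stride, as the class below demands.

```python
import numpy as np


def conv1d_valid(x, w, stride=1):
    """Unpadded 1D cross-correlation (as in PyTorch).
    x: [in_channels, time], w: [out_channels, in_channels, kernel_size] -> [out_channels, n_out]
    """
    k = w.shape[-1]
    n_out = (x.shape[1] - k) // stride + 1
    return np.stack(
        [np.einsum("oik,ik->o", w, x[:, t * stride:t * stride + k]) for t in range(n_out)],
        axis=1,
    )


def causal_conv1d(x, w, stride=1):
    # Full-sequence causal convolution: left-pad the time axis with kernel_size - 1 zeros.
    k = w.shape[-1]
    return conv1d_valid(np.pad(x, ((0, 0), (k - 1, 0))), w, stride)


def causal_conv1d_step(chunk, state, w, stride=1):
    # state holds the previous kernel_size - 1 input frames: [in_channels, kernel_size - 1].
    k = w.shape[-1]
    buf = np.concatenate([state, chunk], axis=1)
    out = conv1d_valid(buf, w, stride)
    # The last kernel_size - 1 frames become the next state (empty when kernel_size == 1).
    return out, buf[:, buf.shape[1] - (k - 1):]


rng = np.random.default_rng(0)
w = rng.standard_normal((5, 3, 9))  # [out_channels=5, in_channels=3, kernel_size=9]
x = rng.standard_normal((3, 16))    # [in_channels=3, time=16]

matches = {}
for stride, chunk in [(1, 8), (2, 4)]:
    state = np.zeros((3, w.shape[-1] - 1))  # initial state: kernel_size - 1 zero frames
    outs = []
    for start in range(0, x.shape[1], chunk):
        y, state = causal_conv1d_step(x[:, start:start + chunk], state, w, stride)
        outs.append(y)
    matches[stride] = np.allclose(np.concatenate(outs, axis=1), causal_conv1d(x, w, stride))
print(matches)
```

This works with a stride because a chunk length divisible by the stride makes the next output window start exactly at the first cached frame.
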
Go ahead and implement it!

In [None]:
class CausalConv1D(torch.nn.Conv1d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        groups: int = 1,
        bias: bool = True,
    ):
        self._in_channels = in_channels
        self._left_padding = kernel_size - 1
        self._right_padding = 0
        self._stride = stride

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        x = F.pad(x, pad=(self._left_padding, self._right_padding))
        return super().forward(x)

    def get_initial_state(self):
        """
        Returns:
            (torch.Tensor): [1, in_channels, kernel_size - 1] - initial state (all zeros)
        """
        # Note: do not forget to set the right device and dtype
        device = self.weight.device
        dtype = self.weight.dtype

        # Your code goes here:
        raise NotImplementedError()

    def streaming_forward(self, x, state):
        """
        Args:
            x (torch.Tensor): [1, in_channels, time] - chunk of data.
                Time should be divisible by stride.
            state (torch.Tensor): [1, in_channels, kernel_size - 1] previous input values.
        Returns:
            tuple[torch.Tensor, torch.Tensor] - output and new state.
        """

        # we demand that the time dimension of each input chunk is divisible by stride
        # for the sake of clarity and simplicity
        assert x.size(2) % self._stride == 0

        # Your code goes here:
        raise NotImplementedError()

Let's run a simple test:

In [None]:
def test_streaming(
    layer: torch.nn.Module,
    input_shape: tuple,
    time_dimention: int,
    chunk_size: int,
    tolerance: float = 1e-5,
    additional_inputs_fn: Optional[Callable[[torch.Tensor], list[torch.Tensor]]] = None,
):
    with torch.no_grad():
        assert input_shape[time_dimention] % chunk_size == 0
        test_input = torch.tensor(np.random.randn(*input_shape).astype(np.float32))
        try:
            # if layer has parameters, we should move input to the same device
            device = next(iter(layer.parameters())).device
            test_input = test_input.to(device)
        except StopIteration:
            pass

        if additional_inputs_fn is None:
            regular_output = layer(test_input)
        else:
            # Some layers have additional inputs (e.g. lengths)
            # which could be mocked in test settings via additional_inputs_fn
            additional_inputs = additional_inputs_fn(test_input)
            regular_output = layer(test_input, *additional_inputs)

        if isinstance(regular_output, tuple):
            # some layers output several tensors, for test purposes we only need
            # the first one
            regular_output = regular_output[0]

        
        streaming_output = []
        state = layer.get_initial_state()
        for chunk_start in range(0, input_shape[time_dimention], chunk_size):
            indices = [slice(None) for _ in range(len(input_shape))]
            indices[time_dimention] = slice(chunk_start, chunk_start + chunk_size)
            step_input = test_input[tuple(indices)]
            step_output, state = layer.streaming_forward(step_input, state)
            streaming_output.append(step_output)
        streaming_output = torch.cat(streaming_output, axis=time_dimention)
        assert streaming_output.shape == regular_output.shape, (streaming_output.shape, regular_output.shape)
        assert torch.abs(streaming_output - regular_output).max() < tolerance
    print('Test OK')

In [None]:
test_streaming(
    layer=CausalConv1D(in_channels=3, out_channels=5, kernel_size=9, stride=1),
    input_shape=(1, 3, 16),
    time_dimention=2,
    chunk_size=8
)

Congratulations! You have successfully implemented a streaming layer!

## 3. Implementing other streaming layers

In the "chunked" Conformer architecture, we have several types of layers:

1) Activations, linear layers, convolutions with kernel_size=1 and other "pointwise" layers.
2) CausalConv1D
3) CausalConv2D
4) RelPosMultiHeadAttention

We do not need to do anything with 1) &mdash; these layers work as-is in streaming mode.

We have already implemented streaming mode for 2).

Now let's implement streaming mode for 3) and 4)!

### 3.1 CausalConv2D.

This is basically the same as CausalConv1D, but with an extra dimension. Note that this extra dimension (features) is not time-related, so we hardly need to care about it for streaming purposes.

In [None]:
class CausalConv2D(torch.nn.Conv2d):
    """
    A causal version of nn.Conv2d where each location in the 2D matrix would have no access to locations on its right or down
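    Arguments are the same as in nn.Conv2d, except that there is no padding argument (padding is computed internally)
    and in_feats is the size of the feature (last) dimension of the input.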
    """

    def __init__(
        self,
        in_feats: int,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        groups: int = 1,
    ) -> None:
        self._in_feats = in_feats
        self._in_channels = in_channels
        self._stride = stride

        # Side note: originally (in the NeMo repo) right_padding = bottom_padding = stride - 1,
        # but we change right_padding to 0 for better streaming consistency
        # and keep _bottom_padding at stride - 1 so that the output feature size
        # (and hence the shapes of the downstream pretrained weights) stays unchanged
        self._left_padding = kernel_size - 1
        self._right_padding = 0
        self._top_padding = kernel_size - 1
        self._bottom_padding = stride - 1

        super(CausalConv2D, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            groups=groups,
        )

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): [batch, channels, time, in_feats]
        Returns:
            torch.Tensor - output
        """
        x = F.pad(x, pad=(self._top_padding, self._bottom_padding, self._left_padding, self._right_padding))
        x = super().forward(x)
        return x

    def get_initial_state(self):
        """
        Returns:
            (???): ??? - initial state
        """
        # Your code goes here:
        raise NotImplementedError()

    def streaming_forward(self, x, state):
        """
        Args:
            x (torch.Tensor): [1, in_channels, time, in_feats].
                Time should be divisible by stride.
            state (???): ???
        Returns:
            tuple[torch.Tensor, ???] - output and new state.
        """
        assert x.shape[2] % self._stride == 0

        # Your code goes here:
        raise NotImplementedError()


In [None]:
test_streaming(
    layer=CausalConv2D(in_channels=3, in_feats=7, out_channels=5, kernel_size=3, stride=2),
    input_shape=(1, 3, 16, 7),
    time_dimention=2,
    chunk_size=4
)

### 3.2 RelPosMultiHeadAttention


#### 3.2.1 Recap

First, let's recap how a (non-streaming) chunked relative position multi-head self-attention works.

For simplicity of notation, in this recap we will assume batch=1 and num_heads=1.

RelPosMultiHeadAttention performs the following steps:

1. Given vectors $x_0, \ldots, x_{t - 1}$, transform them into three other "sets" of vectors: $q, k, v$

<img src="./images/Attention-QKV.png" style="margin-left:auto; margin-right:auto; height: 200px; width: auto" />

2. Using vectors $q$, a learned bias vector $U$ (`pos_bias_u`) and vectors $k$, compute the first attention matrix AC: $AC_{ij} := (q_i + U) \cdot k_j^T$

<img src="./images/Attention-AC.png" style="margin-left:auto; margin-right:auto; height: 300px; width: auto" />

3. (start of the "relative position" part) Given relative positional embeddings $PE_{t - 1}, \ldots, PE_0, \ldots, PE_{-(t - 1)}$, transform them (linearly) into vectors $p_{t - 1}, \ldots, p_{-(t - 1)}$

<img src="./images/Attention-PE.png" style="margin-left:auto; margin-right:auto; height: 150px; width: auto" />


4. Using vectors $q$, a learned bias vector $V$ (`pos_bias_v`; do not confuse it with the value vectors $v$) and vectors $p$, compute the matrix BD: $BD_{ij} := (q_i + V) \cdot p_{i - j}^T$

    4.1. To get BD we can first compute preliminary matrix BD': $BD'_{ij} := (q_i + V) \cdot p_j^T$
    <img src="./images/Attention-BD'.png" style="margin-left:auto; margin-right:auto; height: 700px; width: auto" />
    <h5 align="center">(Dark-colored squares represent BD matrix we want to extract) </h5>
    4.2. Then we extract the BD matrix from BD'. We will discuss the specific operation a bit later.

<img src="./images/Attention-BD.png" style="margin-left:auto; margin-right:auto; height: 250px; width: auto" />

5. We compute the attention scores matrix $A := \frac{AC + BD}{\sqrt{d_k}}$, where $d_k$ is the dimension of one attention head.

6. (start of the "chunked" part) What does "chunked" attention mean? It means that we split the input into chunks, and each query vector attends only to key/value vectors in the same chunk or in a fixed number of previous chunks:

<img src="./images/Chunked-paper.png" style="margin-left:auto; margin-right:auto; height: 250px; width: auto" />
<h5 align="center">Note: this image is taken from the <a href="https://arxiv.org/pdf/2312.17279">paper</a>, so its notation is a bit different from all other images </h5>

We can achieve this with a mask over the attention scores (in this example chunk_size = 2, and each element of a chunk also attends to the previous 2 chunks):

<img src="./images/Mask.png" style="margin-left:auto; margin-right:auto; height: 250px; width: auto" />
<h5 align="center">In this mask 1 means valid value and 0 means invalid. In the implementation we will use an inverse notation, with <span>True</span> indicating that value should be masked.</h5>

Implementation detail: since mask will be identical for all MHA layers in the model, it will be computed externally (at the top-level of our model) and passed to MHA layer as an argument.

7. What's left is the basic attention routine: mask the attention scores, apply softmax and combine the result with the values:

```
masked_A = A.masked_fill(~Mask, -1e4)
attn = softmax(masked_A, dim=-1).masked_fill(~Mask, 0.0)
result = matmul(attn, v)
```
8. Combine all the heads: concatenate them and apply the output linear projection.


Note: in step 3 we took relative positional embeddings for all distances $t - 1, \ldots, -(t - 1)$. Only the distances $(left\_context\_in\_chunks + 1) \cdot chunk\_size - 1, \ldots, -(chunk\_size - 1)$ are actually needed (all other values are cancelled by the mask), but extracting the BD matrix is a bit easier with the whole range of positional embeddings.

If you want, you may try to optimize this.

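Which distances can actually survive the mask? Write $C$ for $chunk\_size$ and $L$ for $left\_context\_in\_chunks$. Take a query at position $i$ and a key at position $j$, with chunk indices $c_i = \lfloor i / C \rfloor$ and $c_j = \lfloor j / C \rfloor$. The chunked mask keeps the pair only if $0 \le c_i - c_j \le L$. We also have $c_i C \le i \le c_i C + C - 1$, and likewise for $j$. Together these bound the distance $i - j$ used in $p_{i - j}$ on both sides:

$$
\begin{aligned}
i - j &\le (c_i C + C - 1) - (c_i - L) C = (L + 1) C - 1, \\
i - j &\ge c_i C - (c_i C + C - 1) = -(C - 1).
\end{aligned}
$$
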

### 3.2.2 Streaming RelPosMultiHeadAttention - ideas

So, let's figure out how to stream the RelPosMultiHeadAttention layer.

Let's denote chunk_size as $C$ and left context (number of chunks to attend to on the left) as $L$.

Imagine we have received $m \geq 1$ new chunks of input $x_{t}, \ldots, x_{t + m \cdot C - 1}$.

We can compute queries, keys and values for these inputs: $q_{t}, \ldots, q_{t + m \cdot C - 1}$; $k_{t}, \ldots, k_{t + m \cdot C - 1}$; $v_{t}, \ldots, v_{t + m \cdot C - 1}$.


In order to compute our attention matrices (AC and BD) we need

* keys for the previous $L \cdot C$ inputs
* relative positional embeddings for distances $(L + 1) \cdot C - 1, \ldots, -(C - 1)$

Positional embeddings can be provided by the caller (the same as for non-streaming inference), but the previous keys have to come from somewhere, so let's store them in our state.

After computing the attention matrices, we need to apply the mask, take the softmax and combine the attention probabilities with the values.

As with the keys, we need access to the values of the previous $L \cdot C$ inputs, so we store them in our state as well.

The mask is basically identical to the one in the non-streaming version. However, we need to handle the earliest time frames, where there are fewer than $L \cdot C$ previous inputs. There are two ways to handle this:

* Either store variable-length tensors of previous keys and values in the state, so that the state size grows from 0 to $L \cdot C$;
* Or store tensors of constant size ($L \cdot C$) in the state, but modify the mask so that we never attend to non-existent inputs. In this case, we have to keep track of the number of inputs processed so far (to know what to mask). We can keep this counter in a state too, but in the state of the layer that forms the mask rather than in the state of the attention layer.

Variable-length tensors can be tricky to work with: e.g., if one wants to convert a streaming model to another format (ONNX, TFLite, etc.), variable-length tensors make the process harder, if not impossible. So let's stick with the latter approach.

#### To sum up
Our state consists of:
* Previous $L \cdot C$ keys
* Previous $L \cdot C$ values

We should also store the number of processed inputs in the state of the layer that creates the mask.

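The state design can be checked end to end with a small NumPy sketch. It is a simplification under explicit assumptions: a single head, random projection matrices, and no relative positional term (BD is left out). As a result, only the fixed-size key/value cache and the counter-based mask are exercised. The helpers `chunked_mask` and `attend` are hypothetical and not part of the seminar code. Negative positions stand for cache slots that precede the start of the utterance.

```python
import numpy as np


def chunked_mask(q_pos, k_pos, C, L):
    """True = masked. Keys at negative positions are cache slots before the utterance start."""
    diff = q_pos[:, None] // C - k_pos[None, :] // C
    valid = (k_pos[None, :] >= 0) & (diff >= 0) & (diff <= L)
    return ~valid


def attend(q, k, v, mask):
    """Single-head attention with the masking convention from the recap (True = masked)."""
    scores = np.where(mask, -1e4, q @ k.T / np.sqrt(q.shape[-1]))
    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs = np.where(mask, 0.0, probs / probs.sum(axis=-1, keepdims=True))
    return probs @ v


C, L, d, T = 2, 2, 4, 12  # chunk size, left chunks, feature size, utterance length
rng = np.random.default_rng(0)
x = rng.standard_normal((T, d))
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))

# Non-streaming reference over the whole utterance.
pos = np.arange(T)
full = attend(x @ Wq, x @ Wk, x @ Wv, chunked_mask(pos, pos, C, L))

# Streaming: fixed-size caches of the previous L * C keys/values plus a processed-inputs counter.
k_cache, v_cache, processed = np.zeros((L * C, d)), np.zeros((L * C, d)), 0
outs = []
for start in range(0, T, 2 * C):  # feed two chunks per call
    xc = x[start:start + 2 * C]
    k_all = np.concatenate([k_cache, xc @ Wk])
    v_all = np.concatenate([v_cache, xc @ Wv])
    q_pos = processed + np.arange(len(xc))
    k_pos = processed - L * C + np.arange(len(k_all))
    outs.append(attend(xc @ Wq, k_all, v_all, chunked_mask(q_pos, k_pos, C, L)))
    # Keep only the newest L * C keys/values (assumes L * C > 0).
    k_cache, v_cache = k_all[-L * C:], v_all[-L * C:]
    processed += len(xc)

streamed = np.concatenate(outs)
print(np.allclose(streamed, full))
```

Feeding two chunks per call shows that the cache and the mask still line up when a single call spans several chunks.
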
### 3.2.3 Streaming RelPosMultiHeadAttention - implementation preparation 

Let's start our implementation with the `create_streaming_attn_mask` function. Since the mask is the same for all attention layers in our architecture, we will pass it to `streaming_forward` as an argument, the same way the mask is passed to `forward`.

First, let's look at the `create_attn_mask` function:

In [None]:
def create_attn_mask(chunk_size: int, left_chunks_num: int, input_size: int, device: torch.device):
    """
    Args:
        chunk_size (int): chunk size
        left_chunks_num (int): number of chunks to attend to on the left
        input_size (int): number of inputs. Should be divisible by chunk size.
        device (torch.device): device to store mask on.
    Returns:
        torch.Tensor: [1, input_size, input_size], bool, True means value should be masked.
    """

    assert input_size % chunk_size == 0

    # chunk_idx is tensor of shape [input_size]
    chunk_idx = torch.arange(0, input_size, dtype=torch.int, device=device)
    chunk_idx = torch.div(chunk_idx, chunk_size, rounding_mode="floor")

    # diff_chunks is tensor of shape [input_size, input_size]: diff_chunks[i, j] = chunk_idx[i] - chunk_idx[j]
    diff_chunks = chunk_idx.unsqueeze(1) - chunk_idx.unsqueeze(0)

    mask = torch.logical_and(
        torch.le(diff_chunks, left_chunks_num), torch.ge(diff_chunks, 0)
    )
    return ~mask.unsqueeze(0)

Let's look at an example

In [None]:
create_attn_mask(chunk_size=2, left_chunks_num=2, input_size=10, device=torch.device('cpu'))

Let's look at some streaming mask examples:

1) `processed_inputs` is low (we are at the beginning of an utterance)

<img src="./images/StreamingMaskStart.png" style="margin-left:auto; margin-right:auto; height: 600px; width: auto" />

2) `processed_inputs` is high (we are in the middle of an utterance)

<img src="./images/StreamingMaskMiddle.png" style="margin-left:auto; margin-right:auto; height: 600px; width: auto" />

Now it's your turn to implement `create_streaming_attn_mask`!

In [None]:
def create_streaming_attn_mask(
    chunk_size: int,
    left_chunks_num: int,
    new_inputs_size: int,
    processed_inputs: int,
    device: torch.device
):
    """
    Args:
        chunk_size (int): chunk size
        left_chunks_num (int): number of chunks to attend to on the left
        new_inputs_size (int): number of new inputs to process. Should be divisible by chunk size.
        processed_inputs (int): number of inputs already processed. Should be divisible by chunk size.
        device (torch.device): device to store mask on.
    Returns:
        torch.Tensor: [1, new_inputs_size, chunk_size * left_chunks_num + new_inputs_size], bool, True means value should be masked.
    """
    assert new_inputs_size % chunk_size == 0
    assert processed_inputs % chunk_size == 0

    # Your code goes here...
    raise NotImplementedError()


In [None]:
# let's run some tests.
# converting result from bool to int here to make constants more compact

assert (create_streaming_attn_mask(
    chunk_size=2, 
    left_chunks_num=2,
    new_inputs_size=4,
    processed_inputs=0,
    device=torch.device('cpu')
).int() == torch.tensor(
    [
        [1, 1, 1, 1, 0, 0, 1, 1],
        [1, 1, 1, 1, 0, 0, 1, 1],
        [1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0],
    ]
).unsqueeze(0)).all()

assert (create_streaming_attn_mask(
    chunk_size=1, 
    left_chunks_num=5,
    new_inputs_size=3,
    processed_inputs=3,
    device=torch.device('cpu')
).int() == torch.tensor(
    [
        [1, 1, 0, 0, 0, 0, 1, 1],
        [1, 1, 0, 0, 0, 0, 0, 1],
        [1, 1, 0, 0, 0, 0, 0, 0],
    ]
).unsqueeze(0)).all()

assert (create_streaming_attn_mask(
    chunk_size=2, 
    left_chunks_num=2,
    new_inputs_size=4,
    processed_inputs=10,
    device=torch.device('cpu')
).int() == torch.tensor(
    [
        [0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1],
        [1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0],
    ]
).unsqueeze(0)).all()

assert (create_streaming_attn_mask(
    chunk_size=2,
    left_chunks_num=2,
    new_inputs_size=4,
    processed_inputs=6,
    device=torch.device('cpu')
).int() == torch.tensor(
    [
        [0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1],
        [1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0],
    ]
).unsqueeze(0)).all()

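For reference, one way to build this mask is sketched below in NumPy. The name `streaming_attn_mask_np` is hypothetical and is not the function you are asked to write. The sketch assigns absolute positions to the queries and to the cached and new keys. It then reuses the chunk arithmetic of `create_attn_mask` and additionally masks cache slots with negative positions. It reproduces the four test cases above.

```python
import numpy as np


def streaming_attn_mask_np(chunk_size, left_chunks_num, new_inputs_size, processed_inputs):
    """Returns [1, new_inputs_size, chunk_size * left_chunks_num + new_inputs_size], True = masked."""
    cache_size = chunk_size * left_chunks_num
    # Absolute positions: queries are the new inputs; keys are the cache followed by the new inputs.
    q_pos = processed_inputs + np.arange(new_inputs_size)
    k_pos = processed_inputs - cache_size + np.arange(cache_size + new_inputs_size)
    diff = q_pos[:, None] // chunk_size - k_pos[None, :] // chunk_size
    # Cache slots before the start of the utterance (negative positions) are not real inputs.
    valid = (k_pos[None, :] >= 0) & (diff >= 0) & (diff <= left_chunks_num)
    return ~valid[None]


print(streaming_attn_mask_np(2, 2, 4, 0).astype(int))
```
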
Now let's implement extraction of BD from BD'.

As we've already mentioned, although we only need positional embeddings for distances $(L + 1) \cdot C - 1, \ldots, -(C - 1)$, we will use embeddings for all distances $num\_keys - 1, \ldots, -(num\_queries - 1)$: it simplifies the implementation and adds little (if any) overhead.

* $num\_queries$ is just the number of inputs ($input\_length$ in the non-streaming case, $new\_inputs\_size$ in the streaming case)
* $num\_keys$ is the total number of keys in our attention ($input\_length$ in the non-streaming case, $L \cdot C + new\_inputs\_size$ in the streaming case)

Let's look at how the extraction works. First, the non-streaming case:

<img src="./images/Extraction-1.png" style="margin-left:auto; margin-right:auto; height: 250px; width: auto" />

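This is our BD' matrix. Dark-colored squares represent the BD matrix we want to extract. Note that if we look at the matrix as a contiguous array, the dark-colored runs are evenly spaced. The last dark-colored value of one row and the first dark-colored value of the next row are always $num\_queries - 1$ positions apart. In other words, each row's run starts one column further to the left than the run above it.

So let's perform the following transformations: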
1) view our matrix as an array
2) drop first $num\_queries - 1$ elements and one last element
3) view array as a matrix ($num\_queries$, $num\_queries + num\_keys - 2$)
4) drop last $num\_queries - 2$ columns

Here is the visualization:

<img src="./images/Extraction-2.png" style="margin-left:auto; margin-right:auto; height: 800px; width: auto" />

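The four steps above can be sketched in NumPy as follows. The name `extract_bd_np` is hypothetical, and the batch and head dimensions are dropped for brevity. The sketch checks the result against a direct gather, `BD[i, j] = BD'[i, num_queries - 1 - i + j]`, which is the index rule these steps implement. The `(6, 10)` case corresponds to a streaming setup, where the queries are the last `num_queries` keys.

```python
import numpy as np


def extract_bd_np(bd_prime):
    """bd_prime: [num_queries, num_keys + num_queries - 1] -> [num_queries, num_keys]"""
    nq, width = bd_prime.shape
    nk = width - nq + 1
    if nq == 1:
        return bd_prime  # a single row needs no shifting
    flat = np.ascontiguousarray(bd_prime).reshape(-1)  # 1) view the matrix as an array
    flat = flat[nq - 1:-1]                              # 2) drop first nq - 1 elements and the last one
    shifted = flat.reshape(nq, nq + nk - 2)             # 3) view as [nq, nq + nk - 2]
    return shifted[:, :nk]                              # 4) drop the last nq - 2 columns


# Compare with a direct gather on random inputs of several shapes.
rng = np.random.default_rng(0)
results = {}
for nq, nk in [(6, 6), (6, 10), (2, 6), (1, 4)]:
    bd_prime = rng.standard_normal((nq, nk + nq - 1))
    i, j = np.meshgrid(np.arange(nq), np.arange(nk), indexing="ij")
    results[(nq, nk)] = np.array_equal(extract_bd_np(bd_prime), bd_prime[i, nq - 1 - i + j])
print(results)
```

Because every step is a reshape or a slice, the PyTorch equivalent can use views of a contiguous tensor instead of an explicit gather.
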
Let's try the same with the streaming setup:

<img src="./images/Extraction-3.png" style="margin-left:auto; margin-right:auto; height: 200px; width: auto" />

In this picture $C=2, L=2$, our new inputs are (6, 7, 8, 9, 10, 11) (3 chunks in total).

Dark-colored and grey-colored squares represent the BD matrix we want to extract: grey-colored squares are the values of BD that will eventually be masked out by the "chunked" masking.

Let's see how our transformation looks:

<img src="./images/Extraction-4.png" style="margin-left:auto; margin-right:auto; height: 650px; width: auto" />

Great. Now, let's write some code:

In [None]:
def extract_bd_from_bd_prime(bd_prime: torch.Tensor):
    """Extract BD matrix from BD'
    Args:
        bd_prime (torch.Tensor): [batch, num_heads, num_queries, num_keys + num_queries - 1]
            Reminder what BD' matrix is:
            Given (for batch element b and head h)
                - query vectors q_s, ..., q_{s + num_queries - 1}, where s = num_keys - num_queries
                  (queries correspond to the last num_queries keys)
                - constant vector V (pos_bias_v)
                - (transformed) positional embeddings p_{num_keys - 1}, ..., p_0, ..., p_{-(num_queries - 1)}
            BD'[b, h, i, j] = (q[b, h]_{s + i} + V) * (p_{num_keys - 1 - j})^T
            Note:
                this is a bit different from the notation in the recap, specifically:
                1) BD' last dimension is indexed with non-negative numbers, so column j holds distance num_keys - 1 - j
                2) we allow the first query index s to be non-zero - this will help us in the streaming case
    Returns:
        torch.Tensor of shape [batch, num_heads, num_queries, num_keys]: BD matrix
            BD[b, h, i, j] = BD'[b, h, i, num_queries - 1 - i + j] = (q[b, h]_{s + i} + V) * p_{s + i - j}^T
    """
    batch_size, num_heads, num_queries, num_pos_embeddings = bd_prime.size()
    if num_queries == 1:
        # in this case BD' and BD matrices are the same
        return bd_prime

    # we need the input tensor to be contiguous in memory for the next set of tricks
    bd = bd_prime.contiguous()

    # Your code goes here...
    raise NotImplementedError()


In [None]:
# let's run some tests.

assert (extract_bd_from_bd_prime(torch.tensor(
    [
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5],
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5],
    ]
).view(1, 1, 6, 6 + 10 - 1)) == torch.tensor(
    [
        [ 4,  3,  2,  1,  0, -1, -2, -3, -4, -5],
        [ 5,  4,  3,  2,  1,  0, -1, -2, -3, -4],
        [ 6,  5,  4,  3,  2,  1,  0, -1, -2, -3],
        [ 7,  6,  5,  4,  3,  2,  1,  0, -1, -2],
        [ 8,  7,  6,  5,  4,  3,  2,  1,  0, -1],
        [ 9,  8,  7,  6,  5,  4,  3,  2,  1,  0]
    ]
).view(1, 1, 6, 10)).all()

assert (extract_bd_from_bd_prime(torch.tensor(
    [
        [-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7],
        [-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7],
        [-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7],
        [-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7],
        [-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7],
        [-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7],
        [-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7],
        [-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7],
    ]
).view(1, 1, 8, 8 + 8 - 1)) == torch.tensor(
    [
        [ 0,  1,  2,  3,  4,  5,  6,  7],
        [-1,  0,  1,  2,  3,  4,  5,  6],
        [-2, -1,  0,  1,  2,  3,  4,  5],
        [-3, -2, -1,  0,  1,  2,  3,  4],
        [-4, -3, -2, -1,  0,  1,  2,  3],
        [-5, -4, -3, -2, -1,  0,  1,  2],
        [-6, -5, -4, -3, -2, -1,  0,  1],
        [-7, -6, -5, -4, -3, -2, -1,  0],
    ]
).view(1, 1, 8, 8)).all()


assert (extract_bd_from_bd_prime(torch.tensor(
    [
        [5, 4, 3, 2, 1, 0, -1],
        [5, 4, 3, 2, 1, 0, -1],
    ]
).view(1, 1, 2, 2 + 6 - 1)) == torch.tensor(
    [
        [ 4,  3,  2,  1,  0, -1],
        [ 5,  4,  3,  2,  1,  0],
    ]
).view(1, 1, 2, 6)).all()

Before implementing the streaming layer, let's define the `RelPositionalEncoding` module; it will be useful for testing.

In [None]:
class RelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding for TransformerXL's layers
    See : Appendix B in https://arxiv.org/abs/1901.02860
    Args:
        d_model (int): embedding dim
        device (torch.device): device to store embeddings on.
        max_len (int): maximum input length
    """
    def __init__(self, d_model, device, max_len=5000):
        """Construct a RelPositionalEncoding object."""
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len

        # [2 * max_len - 1, 1]
        positions = torch.arange(max_len - 1, -max_len, -1, dtype=torch.float32, device=device).unsqueeze(1)

        # [2 * max_len - 1, d_model]
        pe = torch.zeros(2 * max_len - 1, self.d_model, device=device)

        # [1, d_model / 2]
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32, device=device)
            * -(math.log(10000.0) / self.d_model)
        ).unsqueeze(0)

        pe[:, 0::2] = torch.sin(positions * div_term)
        pe[:, 1::2] = torch.cos(positions * div_term)

        # [1, 2 * max_len - 1, d_model]
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, end_idx, start_idx):
        """Return positional embeddings from end_idx down to start_idx.
        Args:
            end_idx (int): end index (largest distance)
            start_idx (int): start index (smallest distance)
        Note: it is required that start_idx <= end_idx

        Returns:
            torch.Tensor of shape [1, end_idx - start_idx + 1, d_model] - embeddings for
                (end_idx, ..., start_idx) distances
        """
        # self.pe rows hold distances max_len - 1, ..., 0, ..., -(max_len - 1),
        # so distance d is stored at row zero_pos - d
        zero_pos = self.pe.size(1) // 2

        first = zero_pos - end_idx       # row of distance end_idx
        last = zero_pos - start_idx + 1  # one past the row of distance start_idx

        assert 0 <= first < last <= self.pe.size(1)
        return self.pe[:, first:last]

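The slicing arithmetic in `forward` is easy to get off by one. Here is a small NumPy check of the index mapping that uses the distance values themselves instead of sinusoids. It assumes the table is ordered exactly as built in `__init__` (distances $max\_len - 1$ down to $-(max\_len - 1)$), so that distance $d$ sits at row $max\_len - 1 - d$. The helper name `rows_for` is hypothetical.

```python
import numpy as np

max_len = 8
# Stand-in for the rows of self.pe: the distance each row encodes,
# ordered max_len - 1, ..., 0, ..., -(max_len - 1) as in __init__.
distances = np.arange(max_len - 1, -max_len, -1)
zero_pos = len(distances) // 2  # = max_len - 1, the row of distance 0


def rows_for(end_idx, start_idx):
    # Distance d lives in row zero_pos - d, so (end_idx, ..., start_idx) is one contiguous slice.
    return distances[zero_pos - end_idx:zero_pos - start_idx + 1]


print(rows_for(3, -3).tolist())  # [3, 2, 1, 0, -1, -2, -3]
```
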

### 3.2.4 Streaming RelPosMultiHeadAttention - implementation

Now we are ready to implement the streaming version of RelPosMultiHeadAttention.

In [None]:
class RelPositionMultiHeadAttention(torch.nn.Module):
    """Multi-Head Attention layer of Transformer-XL with support of relative positional encoding.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head (int): number of heads
        n_feat (int): size of the features
        chunk_size (int): chunk size
        left_chunks_num (int): number of chunks to attend to on the left
        pos_bias_nonzero_init (bool, optional): initialize pos_bias vectors with nonzero values -- useful for testing

    Note: the forward method takes a mask argument, which should already account for which elements to attend to.
    The chunk_size and left_chunks_num parameters are mainly needed in the get_initial_state() method.
    """

    def __init__(
        self,
        n_head: int,
        n_feat: int,
        chunk_size: int,
        left_chunks_num: int,
        pos_bias_nonzero_init: bool = False
    ):
        super().__init__()
        assert n_feat % n_head == 0
        self.d_k = n_feat // n_head
        self.s_d_k = math.sqrt(self.d_k)
        self.n_head = n_head
        self.chunk_size = chunk_size
        self.left_chunks_num = left_chunks_num

        self.linear_q = torch.nn.Linear(n_feat, n_feat)
        self.linear_k = torch.nn.Linear(n_feat, n_feat)
        self.linear_v = torch.nn.Linear(n_feat, n_feat)
        self.linear_out = torch.nn.Linear(n_feat, n_feat)

        # linear transformation for positional encoding
        self.linear_pos = torch.nn.Linear(n_feat, n_feat, bias=False)

        self.pos_bias_u = torch.nn.Parameter(torch.FloatTensor(self.n_head, self.d_k))
        self.pos_bias_v = torch.nn.Parameter(torch.FloatTensor(self.n_head, self.d_k))

        if pos_bias_nonzero_init:
            torch.nn.init.xavier_uniform_(self.pos_bias_u)
            torch.nn.init.xavier_uniform_(self.pos_bias_v)
        else:
            torch.nn.init.zeros_(self.pos_bias_u)
            torch.nn.init.zeros_(self.pos_bias_v)

    def forward_qkv(self, x):
        """Transforms query, key and value.
        Args:
            x (torch.Tensor): [batch, num_inputs, d_model] - input tensor
        returns:
            q (torch.Tensor): [batch, n_head, num_inputs, d_k]
            k (torch.Tensor): [batch, n_head, num_inputs, d_k]
            v (torch.Tensor): [batch, n_head, num_inputs, d_k]
        """
        n_batch = x.size(0)
        q = self.linear_q(x).view(n_batch, -1, self.n_head, self.d_k)
        k = self.linear_k(x).view(n_batch, -1, self.n_head, self.d_k)
        v = self.linear_v(x).view(n_batch, -1, self.n_head, self.d_k)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        return q, k, v

    def forward_attention(self, value, scores, mask):
        """Compute attention context vector.
        Args:
            value (torch.Tensor): [batch, n_head, num_keys, d_k]
                (num_keys is the same as num_values)
            scores (torch.Tensor): [batch, n_head, num_queries, num_keys]
            mask (torch.Tensor): [batch, num_queries, num_keys]
                bool, True means value should be masked.
        returns:
            (torch.Tensor): [batch, n_head, num_queries, d_k] transformed `value` weighted by the attention scores
        """

        n_batch = value.size(0)

        # [batch, 1, num_queries, num_keys]
        mask = mask.unsqueeze(1)
        scores = scores.masked_fill(mask, -10000.0)

        # [batch, n_head, num_queries, num_keys]
        attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)

        # [batch, n_head, num_queries, d_k]
        x = torch.matmul(attn, value)

        # [batch, num_queries, d_model]
        x = x.transpose(1, 2).reshape(n_batch, -1, self.n_head * self.d_k)

        # [batch, num_queries, d_model]
        return self.linear_out(x)

    def forward(self, x, pos_emb, mask):
        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
        Args:
            x (torch.Tensor): [batch, num_inputs, d_model] - input tensor
            pos_emb (torch.Tensor): [1, 2 * num_inputs - 1, size] - relative positional embeddings
                for distances [num_inputs - 1, ..., 0, ..., -(num_inputs - 1)]
            mask (torch.Tensor): [batch, num_inputs, num_inputs] - attention mask.
                True means value should be masked.
        Returns:
            (torch.Tensor): [batch, num_inputs, d_model] - output
        """

        q, k, v = self.forward_qkv(x)

        # [1, 2 * num_inputs - 1, n_head, d_k]
        p = self.linear_pos(pos_emb).view(1, -1, self.n_head, self.d_k)

        # [1, n_head, 2 * num_inputs - 1, d_k]
        p = p.transpose(1, 2)

        # [batch, head, num_inputs, d_k]
        q_with_bias_u = (q + self.pos_bias_u.view(1, self.n_head, 1, self.d_k))

        # [batch, head, num_inputs, d_k]
        q_with_bias_v = (q + self.pos_bias_v.view(1, self.n_head, 1, self.d_k))

        # compute attention score
        # first compute matrix a and matrix c
        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
        # [batch, head, num_inputs, num_inputs]
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # compute matrix bd
        # [batch, n_head, num_inputs, 2 * num_inputs - 1]
        matrix_bd_prime = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        
        # [batch, n_head, num_inputs, num_inputs]
        matrix_bd = extract_bd_from_bd_prime(matrix_bd_prime)

        # [batch, n_head, num_inputs, num_inputs]
        scores = (matrix_ac + matrix_bd) / self.s_d_k

        out = self.forward_attention(v, scores, mask)

        return out

    def get_initial_state(self):
        """
        Returns:
            state (dict): {
                'keys': (torch.Tensor) - previous chunk_size * left_chunks_num keys
                'values': (torch.Tensor) - previous chunk_size * left_chunks_num values
            } - initial state
        """
        # Do not forget to set correct device/dtype for your tensors
        device = self.linear_q.weight.device
        dtype = self.linear_q.weight.dtype

        # Note: you may choose to store keys and values in a state as
        # [batch, n_head, chunk_size * left_chunks_num, d_k] tensors or
        # maybe some other transposed way for efficiency reasons - your choice!

        # Your code goes here...
        raise NotImplementedError()

    def streaming_forward(self, x, pos_emb, mask, state):
        """
        Args:
            x (torch.Tensor): [1, num_queries, d_model] - new inputs
            pos_emb (torch.Tensor): [1, num_queries + num_keys - 1, d_pos_emb] - relative positional embeddings for distances
                num_keys - 1, ..., 0, ..., -(num_queries - 1)
            mask (torch.Tensor): [1, num_queries, num_keys] - attention mask.
                True means value should be masked.
            state (dict): {
                'keys': (torch.Tensor) - previous chunk_size * left_chunks_num keys
                'values': (torch.Tensor) - previous chunk_size * left_chunks_num values
            } - current state
        Returns:
            tuple[torch.Tensor, dict]: output and new_state
                output (torch.Tensor): [1, num_queries, d_model] - new outputs
                state: dict: {
                    'keys': (torch.Tensor) - previous chunk_size * left_chunks_num keys
                    'values': (torch.Tensor) - previous chunk_size * left_chunks_num values
                } - new state
        """
        # Your code goes here...
        # Feel free to reuse forward_qkv and forward_attention methods.
        raise NotImplementedError()


In [None]:
def test_mha_layer(
    n_head: int,
    n_feat: int,
    chunk_size: int,
    left_chunks_num: int,
    input_size: int,
    chunks_per_step: int,
    tolerance: float = 1e-5,
    layer_constructor: Callable = RelPositionMultiHeadAttention
):
    with torch.no_grad():
        assert input_size % (chunk_size * chunks_per_step) == 0
        device = torch.device('cpu')
        pos_enc = RelPositionalEncoding(d_model=n_feat, device=device)

        layer = layer_constructor(
            n_head=n_head,
            n_feat=n_feat,
            chunk_size=chunk_size,
            left_chunks_num=left_chunks_num,
            pos_bias_nonzero_init=True
        )

        mask = create_attn_mask(chunk_size=chunk_size, left_chunks_num=left_chunks_num, input_size=input_size, device=device)
        
        test_input = torch.tensor(np.random.randn(1, input_size, n_feat).astype(np.float32)).to(device)
        regular_output = layer(
            x=test_input,
            pos_emb=pos_enc(end_idx=(input_size - 1), start_idx=-(input_size - 1)),
            mask=mask
        )

        state = layer.get_initial_state()
        streaming_output = []
        streaming_pe = pos_enc(end_idx=(left_chunks_num + chunks_per_step) * chunk_size - 1, start_idx=-(chunk_size * chunks_per_step - 1))
        for start_idx in range(0, input_size, chunk_size * chunks_per_step):
            chunk_input = test_input[:, start_idx:start_idx + chunk_size * chunks_per_step, :]
            step_mask = create_streaming_attn_mask(
                chunk_size=chunk_size, 
                left_chunks_num=left_chunks_num,
                new_inputs_size=chunk_size * chunks_per_step,
                processed_inputs=start_idx,
                device=device
            )
            step_output, state = layer.streaming_forward(
                x=chunk_input,
                pos_emb=streaming_pe,
                mask=step_mask,
                state=state
            )
            streaming_output.append(step_output)
        streaming_output = torch.cat(streaming_output, axis=1)

        assert torch.abs(regular_output - streaming_output).max() < tolerance
    print('Test ok')


In [None]:
test_mha_layer(n_head=2, n_feat=4, chunk_size=2, left_chunks_num=1, input_size=32, chunks_per_step=1)
test_mha_layer(n_head=2, n_feat=4, chunk_size=2, left_chunks_num=1, input_size=32, chunks_per_step=2)

test_mha_layer(n_head=2, n_feat=4, chunk_size=2, left_chunks_num=5, input_size=32, chunks_per_step=1)
test_mha_layer(n_head=2, n_feat=4, chunk_size=2, left_chunks_num=5, input_size=32, chunks_per_step=2)

test_mha_layer(n_head=4, n_feat=32, chunk_size=3, left_chunks_num=5, input_size=60, chunks_per_step=1)
test_mha_layer(n_head=4, n_feat=32, chunk_size=3, left_chunks_num=5, input_size=60, chunks_per_step=2)

Great! We have implemented all basic streaming layers.
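For reference, the `keys`/`values` state described in the `RelPositionMultiHeadAttention` docstrings can be pictured as a rolling buffer that always holds the last `chunk_size * left_chunks_num` keys (and values). The standalone sketch below shows only that bookkeeping on plain tensors, assuming the `[batch, n_head, time, d_k]` layout suggested in `get_initial_state`; the helper `update_kv_cache` is hypothetical, and the attention computation and masking are left out. The zero-filled initial cache is a placeholder that the streaming mask has to hide on the first steps.

```python
import torch


def update_kv_cache(cache: torch.Tensor, new: torch.Tensor, cache_size: int):
    """Hypothetical helper for one streaming step.

    Args:
        cache: [batch, n_head, cache_size, d_k] - keys (or values) kept from earlier steps
        new: [batch, n_head, num_queries, d_k] - keys (or values) of the new inputs
        cache_size: chunk_size * left_chunks_num
    Returns:
        all_keys: [batch, n_head, cache_size + num_queries, d_k] - everything the
            new queries may attend to (the streaming mask hides invalid entries)
        new_cache: [batch, n_head, cache_size, d_k] - the most recent cache_size entries
    """
    all_keys = torch.cat([cache, new], dim=2)
    # slicing from (length - cache_size) also behaves correctly for cache_size == 0
    new_cache = all_keys[:, :, all_keys.size(2) - cache_size:, :]
    return all_keys, new_cache


torch.manual_seed(0)
chunk_size, left_chunks_num, n_head, d_k = 2, 3, 2, 4
cache_size = chunk_size * left_chunks_num

# initial state: zeros, which the streaming mask must hide on the first steps
cache = torch.zeros(1, n_head, cache_size, d_k)
keys = torch.randn(1, n_head, 20, d_k)  # 10 chunks of chunk_size keys
for start in range(0, 20, chunk_size):
    all_keys, cache = update_kv_cache(cache, keys[:, :, start:start + chunk_size], cache_size)
```

After the loop the cache holds exactly the last `cache_size` keys, and each step's `all_keys` has `cache_size + chunk_size` entries, matching the `num_keys` dimension of the streaming mask.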

### 4. Conformer Feed Forward layer.

This section is a short break: we define the `ConformerFeedForward` layer, which needs no streaming support because it processes every time step independently.
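A quick standalone check of that claim: since the feed-forward block acts on each time step separately, running it on consecutive chunks and concatenating the results gives the same output as running it on the whole sequence. The sketch rebuilds the same Linear, SiLU, Linear stack with plain `torch.nn` modules (the sizes are arbitrary assumptions) instead of using the class defined below.

```python
import torch

torch.manual_seed(0)
d_model, d_ff, time = 8, 16, 12

# same structure as ConformerFeedForward: Linear -> SiLU -> Linear
ff = torch.nn.Sequential(
    torch.nn.Linear(d_model, d_ff),
    torch.nn.SiLU(),
    torch.nn.Linear(d_ff, d_model),
)

x = torch.randn(1, time, d_model)
with torch.no_grad():
    full = ff(x)
    # process 3 frames at a time and stitch the pieces back together
    chunked = torch.cat([ff(x[:, t:t + 3]) for t in range(0, time, 3)], dim=1)

max_diff = (full - chunked).abs().max().item()
```

`max_diff` stays at floating-point noise level, so no state needs to be carried between calls.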

In [None]:
class ConformerFeedForward(torch.nn.Module):
    def __init__(
        self,
        d_model: int,
        d_ff: int,
        activation: torch.nn.Module = torch.nn.SiLU()
    ):
        super().__init__()
        self._d_model = d_model
        self._d_ff = d_ff

        self.linear1 = torch.nn.Linear(self._d_model, self._d_ff)
        self.activation = activation
        self.linear2 = torch.nn.Linear(self._d_ff, self._d_model)

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        return x

### 5. Composite layers and putting it all together.

Now let's implement streaming for "composite" layers: layers that apply several inner layers, some of which are streaming.

The basic recipe for a streaming version of a composite layer is simple: its state is just a combination of the states of its inner streaming layers (e.g., a list of states, or a dict mapping inner layer names to states).

The `streaming_forward` method looks almost identical to `forward`, except that for inner streaming layers we call `streaming_forward` instead of `forward` (and update their states along the way).
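The recipe can be illustrated on toy layers. In the sketch below, `ToyCausalSum` (hypothetical, not part of the notebook) is a streaming layer that sums the last `k` frames and keeps the previous `k - 1` inputs as its state; `ToyComposite` (also hypothetical) chains two of them around a pointwise `tanh`. Its state is a dict of the inner states, and `streaming_forward` mirrors `forward` line by line. The notebook's layers may just as well use a list.

```python
import torch


class ToyCausalSum(torch.nn.Module):
    """Hypothetical streaming layer: y[t] = x[t] + x[t-1] + ... + x[t-k+1].

    State: the last (k - 1) inputs, shape [1, k - 1, feats].
    """

    def __init__(self, k: int, feats: int):
        super().__init__()
        self.k = k
        self.feats = feats

    def _window_sum(self, padded):
        # padded: [batch, (k - 1) + time, feats] -> [batch, time, feats]
        time = padded.size(1) - (self.k - 1)
        return sum(padded[:, i:i + time] for i in range(self.k))

    def forward(self, x):
        # non-streaming: the missing past is zero padding
        pad = x.new_zeros(x.size(0), self.k - 1, x.size(2))
        return self._window_sum(torch.cat([pad, x], dim=1))

    def get_initial_state(self):
        # zeros reproduce the zero padding used by forward
        return torch.zeros(1, self.k - 1, self.feats)

    def streaming_forward(self, x, state):
        padded = torch.cat([state, x], dim=1)
        new_state = padded[:, padded.size(1) - (self.k - 1):]
        return self._window_sum(padded), new_state


class ToyComposite(torch.nn.Module):
    """Hypothetical composite layer: causal sum -> pointwise tanh -> causal sum."""

    def __init__(self, feats: int):
        super().__init__()
        self.first = ToyCausalSum(k=3, feats=feats)
        self.second = ToyCausalSum(k=2, feats=feats)

    def forward(self, x):
        return self.second(torch.tanh(self.first(x)))

    def get_initial_state(self):
        # composite state = states of the inner streaming layers
        return {"first": self.first.get_initial_state(),
                "second": self.second.get_initial_state()}

    def streaming_forward(self, x, state):
        # same as forward, but streaming inner layers consume and return states
        x, first_state = self.first.streaming_forward(x, state["first"])
        x = torch.tanh(x)  # pointwise: no state needed
        x, second_state = self.second.streaming_forward(x, state["second"])
        return x, {"first": first_state, "second": second_state}


torch.manual_seed(0)
layer = ToyComposite(feats=4)
x = torch.randn(1, 10, 4)
full = layer(x)

state = layer.get_initial_state()
outputs = []
for t in range(0, 10, 2):  # feed 2 frames per call
    y, state = layer.streaming_forward(x[:, t:t + 2], state)
    outputs.append(y)
streamed = torch.cat(outputs, dim=1)
```

The streamed output matches the full-sequence output, which is exactly the property the `test_streaming` checks below verify for the real layers.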

### 5.1 ConformerConvolution
Let's look at an example:

The `ConformerConvolution` layer consists of several pointwise operations (activations, 1x1 convolutions, LayerNorm) as well as one depthwise `CausalConv1D`.

Since there is only one streaming inner layer, the state of `ConformerConvolution` can simply be the state of its `CausalConv1D`.

`streaming_forward` looks almost exactly like `forward`, except that we call `streaming_forward` for the causal convolution.
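Why only the convolution needs memory: GLU, LayerNorm, SiLU and the 1x1 convolutions all act on each time step separately, while the depthwise convolution looks `kernel_size - 1` frames into the past. The standalone sketch below uses plain `F.conv1d` rather than the notebook's `CausalConv1D`, and assumes one common design for its state: the last `kernel_size - 1` input frames, laid out `[1, channels, kernel_size - 1]` (the notebook's version may store them differently). The functions `causal_depthwise` and `causal_depthwise_streaming` are hypothetical names.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
channels, kernel_size, time = 5, 9, 16
# depthwise weights: one kernel per channel, as with groups=d_model
weight = torch.randn(channels, 1, kernel_size)
bias = torch.randn(channels)


def causal_depthwise(x):
    """Non-streaming reference. x: [batch, channels, time]."""
    x = F.pad(x, (kernel_size - 1, 0))  # pad on the left only -> causal
    return F.conv1d(x, weight, bias, groups=channels)


def causal_depthwise_streaming(x, state):
    """x: [1, channels, chunk]; state: [1, channels, kernel_size - 1]."""
    x = torch.cat([state, x], dim=2)  # prepend the cached left context
    new_state = x[:, :, x.size(2) - (kernel_size - 1):]
    return F.conv1d(x, weight, bias, groups=channels), new_state


x = torch.randn(1, channels, time)
reference = causal_depthwise(x)

state = torch.zeros(1, channels, kernel_size - 1)  # initial state = zero padding
outputs = []
for t in range(0, time, 4):
    y, state = causal_depthwise_streaming(x[:, :, t:t + 4], state)
    outputs.append(y)
streamed = torch.cat(outputs, dim=2)
```

Note that the chunk length (4) may be shorter than the cached context (8 frames); the state is still just the last `kernel_size - 1` frames of the concatenation.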




In [None]:
class ConformerConvolution(torch.nn.Module):
    def __init__(
        self,
        d_model: int,
        kernel_size: int,
    ):
        super().__init__()
        assert (kernel_size - 1) % 2 == 0
        self._d_model = d_model
        self._kernel_size = kernel_size
        self.pointwise_activation = lambda x: torch.nn.functional.glu(x, dim=1)

        self.pointwise_conv1 = torch.nn.Conv1d(
            in_channels=self._d_model, out_channels=self._d_model * 2, kernel_size=1, stride=1, padding=0, bias=True
        )
        self.depthwise_conv = CausalConv1D(
            in_channels=self._d_model,
            out_channels=self._d_model,
            kernel_size=self._kernel_size,
            stride=1,
            groups=self._d_model,
            bias=True,
        )

        # Despite its name, `batch_norm` is a LayerNorm: the name is kept for weight compatibility.
        self.batch_norm = torch.nn.LayerNorm(self._d_model)
        self.activation = torch.nn.SiLU()
        self.pointwise_conv2 = torch.nn.Conv1d(
            in_channels=self._d_model, out_channels=d_model, kernel_size=1, stride=1, padding=0, bias=True
        )

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): [batch, time, feats] - input tensor
        Returns:
            (torch.Tensor) - layer output
        """

        x = x.transpose(1, 2)
        x = self.pointwise_conv1(x)
        x = self.pointwise_activation(x)

        x = self.depthwise_conv(x)

        x = x.transpose(1, 2)
        x = self.batch_norm(x)
        x = x.transpose(1, 2)

        x = self.activation(x)
        x = self.pointwise_conv2(x)
        x = x.transpose(1, 2)
        return x

    def get_initial_state(self):
        """
        Returns:
            torch.Tensor: initial state
        """
        # Your code goes here...
        raise NotImplementedError()

    def streaming_forward(self, x, state):
        """Args:
            x (torch.Tensor): [1, time, feats] - input tensor
            state (torch.Tensor): state
        Returns:
            tuple[torch.Tensor, torch.Tensor]: output tensor and new state.
        """

        # Your code goes here...
        raise NotImplementedError()

In [None]:
test_streaming(
    layer=ConformerConvolution(d_model=5, kernel_size=9),
    input_shape=(1, 16, 5),
    time_dimention=1,
    chunk_size=4
)

### 5.2 ConvSubsampling

The next composite layer is `ConvSubsampling`. It consists of several `CausalConv2D` layers, as well as some 1x1 convolutions and activations.

Making a streaming version of it is, again, pretty straightforward: the state is just a list of the states of all inner `CausalConv2D` layers, and instead of calling `forward` on each `CausalConv2D` we call its `streaming_forward`.

The only new thing is that `forward` takes a `lengths` argument (the length of each input element in the batch) and returns not only the output tensor but also a `lengths` tensor (the length of each output element in the batch).

Since we only implement streaming for `batch_size=1`, we do not need to worry about the `lengths` argument and return value: `streaming_forward` neither takes nor returns lengths.
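A worked check of the length arithmetic in `calc_length`: each stride-2 layer with `kernel_size=3` maps a length `L` to `floor((L + padding - 3) / 2 + 1)`. On the time axis the total padding is `_left_padding + _right_padding = 2`, which gives `ceil(L / 2)`; on the feature axis it is `_top_padding + _bottom_padding = 3`, which gives `L // 2 + 1`. The sketch re-derives this with plain integers; `conv_out_len` and `subsampled_len` are hypothetical helpers, not part of the notebook.

```python
def conv_out_len(length: int, padding: int, kernel_size: int = 3, stride: int = 2) -> int:
    """One layer of the formula in calc_length: floor((L + padding - k) / s + 1)."""
    return (length + padding - kernel_size) // stride + 1


def subsampled_len(length: int, padding: int, num_layers: int) -> int:
    """Apply conv_out_len num_layers times (num_layers = log2(subsampling_factor))."""
    for _ in range(num_layers):
        length = conv_out_len(length, padding)
    return length


TIME_PAD = 2  # left padding (kernel_size - 1) + right padding 0
FREQ_PAD = 3  # top padding (kernel_size - 1) + bottom padding (stride - 1)

# subsampling_factor=8 means 3 stride-2 layers
time_full = subsampled_len(40, TIME_PAD, 3)   # 40 -> 20 -> 10 -> 5
time_chunk = subsampled_len(8, TIME_PAD, 3)   # 8 -> 4 -> 2 -> 1
freq_out = subsampled_len(5, FREQ_PAD, 3)     # 5 -> 3 -> 2 -> 2
```

Because the time axis halves with rounding up, a length that is a multiple of `subsampling_factor` shrinks by exactly that factor, which is presumably why the streaming tests use chunk sizes that are multiples of it. For the test configuration (`feat_in=5`, `conv_channels=3`), `freq_out=2` means the final `out` Linear receives `3 * 2 = 6` inputs per frame.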


In [None]:
# First, a little helper function

def calc_length(lengths: torch.Tensor, paddings: int, kernel_size: int, stride: int, repeat_num: int):
    """Calculates the output length of a tensor passed through a series of convolution or max pooling layers"""
    add_pad: int = paddings - kernel_size
    one: float = 1.0
    for i in range(repeat_num):
        lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one
        lengths = torch.floor(lengths)
    return lengths.to(dtype=torch.int)


In [None]:
class ConvSubsampling(torch.nn.Module):
    def __init__(
        self,
        subsampling_factor: int,
        feat_in: int,
        feat_out: int,
        conv_channels: int,
        activation: torch.nn.Module,
    ):
        super().__init__()
        self._conv_channels = conv_channels
        self._feat_in = feat_in
        self._feat_out = feat_out

        # checking that subsampling_factor is a power of 2
        assert subsampling_factor & (subsampling_factor - 1) == 0

        self._sampling_num = int(math.log(subsampling_factor, 2))
        self._subsampling_factor = subsampling_factor

        in_channels = 1
        layers = []

        self._stride = 2
        self._kernel_size = 3

        self._left_padding = self._kernel_size - 1
        self._right_padding = 0
        self._top_padding = self._kernel_size - 1
        self._bottom_padding = self._stride - 1

        layers.append(
            CausalConv2D(
                in_feats=self._feat_in,
                in_channels=in_channels,
                out_channels=conv_channels,
                kernel_size=self._kernel_size,
                stride=self._stride,
            )
        )

        in_channels = conv_channels
        out_length = int(
            calc_length(
                torch.tensor(self._feat_in, dtype=torch.float),
                paddings=self._top_padding + self._bottom_padding,
                kernel_size=self._kernel_size,
                stride=self._stride,
                repeat_num=1
            )
        )

        layers.append(activation)

        for i in range(self._sampling_num - 1):
            layers.append(
                CausalConv2D(
                    in_feats=out_length,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    kernel_size=self._kernel_size,
                    stride=self._stride,
                    groups=in_channels,
                )
            )

            layers.append(
                torch.nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=conv_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    groups=1,
                )
            )
            layers.append(activation)
            in_channels = conv_channels
            out_length = int(
                calc_length(
                    torch.tensor(out_length, dtype=torch.float),
                    paddings=self._top_padding + self._bottom_padding,
                    kernel_size=self._kernel_size,
                    stride=self._stride,
                    repeat_num=1
                )
            )

        self.out = torch.nn.Linear(conv_channels * out_length, self._feat_out)
        self.conv = torch.nn.Sequential(*layers)

    def forward(self, x, lengths):
        """
        Args:
            x (torch.Tensor): [batch, max_time, features] - input tensor
            lengths (torch.Tensor): [batch] -  lengths of inputs
        Returns:
            tuple[torch.Tensor, torch.Tensor]: output tensor and output lengths
        """
        lengths = calc_length(
            lengths,
            paddings=self._left_padding + self._right_padding,
            kernel_size=self._kernel_size,
            stride=self._stride,
            repeat_num=self._sampling_num,
        )

        # [batch, 1, time, features]
        x = x.unsqueeze(1)
        x = self.conv(x)

        batch, channels, time, features = x.size()

        # [batch, time, channels * features]
        x = x.transpose(1, 2).reshape(batch, time, -1)
        x = self.out(x)
        return x, lengths

    def get_initial_state(self):
        """
        Returns:
            list[torch.Tensor]: initial layer state - list of initial states of
                inner layers
        """
        # Your code goes here...
        # hint: you can distinguish CausalConv2D layers using isinstance(layer, CausalConv2D)
        raise NotImplementedError()


    def streaming_forward(self, x, state):
        """
        Args:
            x (torch.Tensor): [1, time, features] - input.
            state (list[torch.Tensor]): state.
        Returns:
            tuple[torch.Tensor, list[torch.Tensor]]: output and new state.
        """
        # Your code goes here
        raise NotImplementedError()


In [None]:
test_streaming(
    layer=ConvSubsampling(subsampling_factor=8, feat_in=5, feat_out=3, conv_channels=3, activation=torch.nn.ReLU()),
    input_shape=(1, 40, 5),
    time_dimention=1,
    chunk_size=8,
    additional_inputs_fn = lambda x: [torch.tensor([x.size(1)])]
)

### 5.3 ConformerLayer

The next composite layer is `ConformerLayer`. It consists of several `LayerNorm` layers, two `ConformerFeedForward` layers, a `ConformerConvolution` and a `RelPositionMultiHeadAttention` layer.

Although it may look intimidating, there is actually nothing new here: only `ConformerConvolution` and `RelPositionMultiHeadAttention` carry state, so just combine their states and pass them along!


In [None]:
class ConformerLayer(torch.nn.Module):
    """A single block of the Conformer encoder.

    Args:
        d_model (int): input dimension of RelPositionMultiHeadAttention and ConformerFeedForward
        d_ff (int): hidden dimension of ConformerFeedForward
        n_heads (int): number of heads for multi-head attention
        conv_kernel_size (int): kernel size for depthwise convolution in convolution module
        chunk_size (int): number of layer inputs in one attention chunk
        left_chunks_num (int): number of chunks to attend to on the left
        pos_bias_nonzero_init (bool, optional): initialize pos_bias vectors in RelPositionMultiHeadAttention module
            with nonzero values - useful for testing
    
    Note: the mask argument of forward should already encode which elements may be attended to;
    chunk_size and left_chunks_num are only passed through to RelPositionMultiHeadAttention.
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        n_heads: int,
        conv_kernel_size: int,
        chunk_size: int,
        left_chunks_num: int,
        pos_bias_nonzero_init: bool = False
    ):
        super().__init__()
        self._d_model = d_model
        self._d_ff = d_ff
        self._n_heads = n_heads
        self._conv_kernel_size = conv_kernel_size

        self._fc_factor = 0.5

        self.norm_feed_forward1 = torch.nn.LayerNorm(self._d_model)
        self.feed_forward1 = ConformerFeedForward(d_model=self._d_model, d_ff=self._d_ff)

        self.norm_conv = torch.nn.LayerNorm(self._d_model)
        self.conv = ConformerConvolution(
            d_model=self._d_model,
            kernel_size=self._conv_kernel_size,
        )

        self.norm_self_att = torch.nn.LayerNorm(self._d_model)

        self.self_attn = RelPositionMultiHeadAttention(
            n_head=self._n_heads,
            n_feat=self._d_model,
            chunk_size=chunk_size,
            left_chunks_num=left_chunks_num,
            pos_bias_nonzero_init=pos_bias_nonzero_init,
        )

        self.norm_feed_forward2 = torch.nn.LayerNorm(self._d_model)
        self.feed_forward2 = ConformerFeedForward(d_model=self._d_model, d_ff=self._d_ff)

        self.norm_out = torch.nn.LayerNorm(self._d_model)

    def forward(self, x, pos_emb, mask):
        """
        Args:
            x (torch.Tensor): [batch, num_inputs, d_model] - input
            pos_emb (torch.Tensor): [1, 2 * num_inputs - 1, size] - relative positional embeddings
                for distances [num_inputs - 1, ..., 0, ..., -(num_inputs - 1)]
            mask (torch.Tensor): [batch, num_inputs, num_inputs] - attention mask
                True means value should be masked.
        Returns:
            torch.Tensor: [batch, num_inputs, d_model] - output
        """
        residual = x
        x = self.norm_feed_forward1(x)
        x = self.feed_forward1(x)
        residual = residual + x * self._fc_factor

        x = self.norm_self_att(residual)
        x = self.self_attn(x, pos_emb=pos_emb, mask=mask)

        residual = residual + x

        x = self.norm_conv(residual)
        x = self.conv(x)
        residual = residual + x

        x = self.norm_feed_forward2(residual)
        x = self.feed_forward2(x)
        residual = residual + x * self._fc_factor

        x = self.norm_out(residual)
        return x

    def get_initial_state(self):
        """Returns:
            ???: initial state.
        """
        # Your code goes here
        raise NotImplementedError()

    def streaming_forward(self, x, pos_emb, mask, state):
        """Args:
            x (torch.Tensor): [1, num_queries, d_model] - new inputs
            pos_emb (torch.Tensor): [1, 2 * num_queries + chunk_size * left_chunks_num - 1, size] -
                relative positional embeddings for distances
                    chunk_size * left_chunks_num + num_queries - 1, ..., 0, ..., -(num_queries - 1)
            mask (torch.Tensor): [1, num_queries, chunk_size * left_chunks_num + num_queries] - attention mask.
                True means value should be masked.
            state (???): state.
        Returns:
            tuple[torch.Tensor, ???]: [1, num_queries, d_model] new outputs, and new state.
        """
        # Your code goes here
        raise NotImplementedError()


The call signature of `ConformerLayer` is the same as that of `RelPositionMultiHeadAttention`, so we reuse the `test_mha_layer` function for testing, passing it a small constructor adapter.


In [None]:
def make_conformer_layer_constructor(d_ff: int = 4, conv_kernel_size: int = 9):
    """Adapter to pass to test_mha_layer function"""

    def constructor(n_head: int, n_feat: int, chunk_size: int, left_chunks_num: int, pos_bias_nonzero_init: bool):
        return ConformerLayer(
            d_model=n_feat,
            d_ff=d_ff,
            n_heads=n_head,
            conv_kernel_size=conv_kernel_size,
            chunk_size=chunk_size,
            left_chunks_num=left_chunks_num,
            pos_bias_nonzero_init=pos_bias_nonzero_init
        )
    return constructor

test_mha_layer(
    n_head=2,
    n_feat=4,
    chunk_size=2,
    left_chunks_num=1,
    input_size=32,
    chunks_per_step=1,
    layer_constructor=make_conformer_layer_constructor()
)
test_mha_layer(
    n_head=2,
    n_feat=4,
    chunk_size=2,
    left_chunks_num=1,
    input_size=32,
    chunks_per_step=2,
    layer_constructor=make_conformer_layer_constructor()
)

test_mha_layer(
    n_head=2,
    n_feat=4,
    chunk_size=2,
    left_chunks_num=5,
    input_size=32,
    chunks_per_step=1,
    layer_constructor=make_conformer_layer_constructor()
)
test_mha_layer(
    n_head=2,
    n_feat=4,
    chunk_size=2,
    left_chunks_num=5,
    input_size=32,
    chunks_per_step=2,
    layer_constructor=make_conformer_layer_constructor()
)

test_mha_layer(
    n_head=4,
    n_feat=32,
    chunk_size=3,
    left_chunks_num=5,
    input_size=60,
    chunks_per_step=1,
    layer_constructor=make_conformer_layer_constructor()
)
test_mha_layer(
    n_head=4,
    n_feat=32,
    chunk_size=3,
    left_chunks_num=5,
    input_size=60,
    chunks_per_step=2,
    layer_constructor=make_conformer_layer_constructor()
)

### 5.4 ConformerEncoder &mdash; putting it all together

Finally, we are going to implement streaming Conformer Encoder.

There are several differences from other "composite" layers:
* We need to call the `RelPositionalEncoding` layer (with different arguments in the streaming and non-streaming versions)
* We need to create masks, which also differ between the streaming and non-streaming versions
* To create the streaming masks, we need to keep track of the number of processed inputs, so the state contains not only the inner layer states but also an integer: the number of processed inputs

Otherwise, implementation is pretty straightforward.

Note: when using the number of processed inputs, be careful about which inputs you are counting: inputs to `ConformerEncoder` or inputs to `ConformerLayer`? They differ because of subsampling!
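The bookkeeping per step is plain arithmetic. The sketch below, with a hypothetical helper `streaming_step_geometry`, assumes every step feeds the encoder a whole number of chunks, i.e. `chunk_size * chunks_per_step * subsampling_factor` feature frames, so the layers see `chunk_size * chunks_per_step` new inputs. The positional-embedding range is taken from `test_mha_layer`, and the processed-input count passed to the mask is in layer units, as `processed_inputs=start_idx` is there.

```python
def streaming_step_geometry(step: int, chunk_size: int, left_chunks_num: int,
                            chunks_per_step: int, subsampling_factor: int) -> dict:
    """Hypothetical helper: bookkeeping for streaming step number `step` (0-based)."""
    layer_inputs_per_step = chunk_size * chunks_per_step
    encoder_inputs_per_step = layer_inputs_per_step * subsampling_factor
    # keys seen by each query block: cached left context + the new inputs
    num_keys = left_chunks_num * chunk_size + layer_inputs_per_step
    end_idx = num_keys - 1
    start_idx = -(layer_inputs_per_step - 1)
    return {
        "encoder_inputs_per_step": encoder_inputs_per_step,
        "layer_inputs_per_step": layer_inputs_per_step,
        # counts before this step, in both units
        "processed_encoder_inputs": step * encoder_inputs_per_step,
        "processed_layer_inputs": step * layer_inputs_per_step,
        # arguments for pos_enc(end_idx=..., start_idx=...), as in test_mha_layer
        "pos_emb_end_idx": end_idx,
        "pos_emb_start_idx": start_idx,
        "pos_emb_len": end_idx - start_idx + 1,
        "num_keys": num_keys,
    }


# settings of the first ConformerEncoder test: 8 feature frames per step
g = streaming_step_geometry(step=3, chunk_size=2, left_chunks_num=3,
                            chunks_per_step=1, subsampling_factor=4)
```

For the fourth step (`step=3`) of that configuration, 24 feature frames have been consumed but only 6 layer inputs, and the embedding covers distances 7 down to -1 (9 positions, i.e. `num_keys + num_queries - 1`).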


<details> 
  <summary>Hint</summary>
   You may look at the <span>test_mha_layer</span> function if you have trouble creating the relative positional embeddings or the mask
</details>

In [None]:
class ConformerEncoder(torch.nn.Module):
    def __init__(
        self,
        feat_in: int,
        n_layers: int,
        d_model: int,
        ff_expansion_factor: int,
        n_heads: int,
        subsampling_factor: int,
        subsampling_conv_channels: int,
        chunk_size: int,
        left_chunks_num: int,
        conv_kernel_size: int,
        pos_emb_max_len: int = 5000,
    ):
        super().__init__()

        self._feat_in = feat_in
        self._n_layers = n_layers
        self._d_model = d_model
        self._ff_expansion_factor = ff_expansion_factor
        self._n_heads = n_heads
        self._subsampling_factor = subsampling_factor
        self._subsampling_conv_channels = subsampling_conv_channels
        self._x_scale = math.sqrt(self._d_model)

        self._chunk_size = chunk_size
        self._left_chunks_num = left_chunks_num
        self._conv_kernel_size = conv_kernel_size
        self._pos_emb_max_len = pos_emb_max_len

        self.pre_encode = ConvSubsampling(
            subsampling_factor=self._subsampling_factor,
            feat_in=self._feat_in,
            feat_out=self._d_model,
            conv_channels=self._subsampling_conv_channels,
            activation=torch.nn.ReLU(),
        )

        self._feat_out = d_model

        self.pos_enc = RelPositionalEncoding(
            d_model=d_model,
            max_len=pos_emb_max_len,
            device=next(iter(self.pre_encode.parameters())).device
        )

        self.layers = torch.nn.ModuleList()
        for i in range(n_layers):
            layer = ConformerLayer(
                d_model=self._d_model,
                d_ff=self._d_model * self._ff_expansion_factor,
                n_heads=self._n_heads,
                conv_kernel_size=self._conv_kernel_size,
                chunk_size=self._chunk_size,
                left_chunks_num=self._left_chunks_num,
            )
            self.layers.append(layer)

    def forward(self, features, lengths):
        """
        Args:
            features (torch.Tensor): [batch, input_size, features] - input features.
            lengths (torch.Tensor): [batch] - input lengths.
        Returns:
            tuple[torch.Tensor, torch.Tensor] - output features and lengths.
        """
        features, lengths = self.pre_encode(x=features, lengths=lengths)
        lengths = lengths.to(torch.int64)

        features = features * self._x_scale

        # this is different from input_size because of subsampling!
        layers_input_size = features.size(1)

        pos_emb = self.pos_enc(end_idx=layers_input_size - 1, start_idx=-(layers_input_size - 1))

        # [1, layers_input_size, layers_input_size]
        chunked_mask = create_attn_mask(
            chunk_size=self._chunk_size,
            left_chunks_num=self._left_chunks_num,
            input_size=layers_input_size,
            device=features.device
        )

        # [batch, layers_input_size, 1]
        # padding_mask[b, t, 0] = ~(t < lengths[b]), i.e. True for padded frames
        padding_mask = ~(
            torch.arange(0, layers_input_size, device=features.device).unsqueeze(0) < lengths.unsqueeze(-1)
        ).unsqueeze(-1)
        mask = torch.logical_or(chunked_mask, padding_mask)

        for layer in self.layers:
            features = layer(
                x=features,
                mask=mask,
                pos_emb=pos_emb
            )

        return features, lengths

    def get_initial_state(self):
        """Returns:
            ???: initial state
        """
        # Your code goes here...
        raise NotImplementedError()

    def streaming_forward(self, features, state):
        """Args:
            features (torch.Tensor): [1, new_input_size, features] - new inputs
            state (???): input state
        Returns:
            tuple[torch.Tensor, ???] - new outputs and new state.
        """

        # Your code goes here...
        raise NotImplementedError()


In [None]:
test_streaming(
    layer=ConformerEncoder(
        feat_in=3,
        n_layers=2,
        d_model=8,
        ff_expansion_factor=2,
        n_heads=2,
        subsampling_factor=4,
        subsampling_conv_channels=3,
        chunk_size=2,
        left_chunks_num=3,
        conv_kernel_size=9,
    ),
    input_shape=(1, 80, 3),
    time_dimention=1,
    chunk_size=2 * 4,
    additional_inputs_fn = lambda x: [torch.tensor([x.size(1)])]
)

test_streaming(
    layer=ConformerEncoder(
        feat_in=3,
        n_layers=2,
        d_model=8,
        ff_expansion_factor=2,
        n_heads=2,
        subsampling_factor=4,
        subsampling_conv_channels=3,
        chunk_size=2,
        left_chunks_num=3,
        conv_kernel_size=9,
    ),
    input_shape=(1, 80, 3),
    time_dimention=1,
    chunk_size=4 * 4,
    additional_inputs_fn = lambda x: [torch.tensor([x.size(1)])]
)

Congratulations! You have successfully implemented streaming ConformerEncoder.

## 6. Testing with real-life data.

Now, let's relax, load some real weights and put our hard work into action!

### 6.1. Helper classes.

But first, we need to define several helper classes: a filterbank feature calculator and a greedy CTC decoder.

Note: you don't need to understand their implementation if you don't want to.
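The streaming featurizer relies on simple frame arithmetic. With `center=False`, the STFT emits one frame per complete 400-sample window, advancing 160 samples per frame, so `num_samples` samples give `(num_samples - 400) // 160 + 1` frames (98 frames for one second of 16 kHz audio). After emitting `k` frames, the next frame starts at sample `k * 160`, so everything from there on has to be kept for the next call, which is what `ChunkedStreamingFbank.add` does with `buffer_signal`. The pure-Python sketch below tracks only frame start positions (the helper names are hypothetical) and checks that chunked processing yields exactly the frames of the full signal.

```python
def num_frames(num_samples: int, win: int = 400, hop: int = 160) -> int:
    """Frames produced by a non-centered STFT (one per complete window)."""
    if num_samples < win:
        return 0
    return (num_samples - win) // hop + 1


def frame_starts_streaming(chunk_sizes, win: int = 400, hop: int = 160):
    """Hypothetical helper: absolute start sample of every frame emitted when the
    signal arrives in chunks and the unconsumed tail is buffered between calls."""
    starts = []
    consumed = 0    # absolute index of the first buffered sample
    buffer_len = 0  # samples carried over from earlier calls
    for n in chunk_sizes:
        total = buffer_len + n
        k = num_frames(total, win, hop)
        starts.extend(consumed + i * hop for i in range(k))
        # the next frame starts k * hop samples into the current buffer,
        # so drop everything before it and keep the rest
        consumed += k * hop
        buffer_len = total - k * hop
    return starts


full_starts = [i * 160 for i in range(num_frames(4000))]
chunked_starts = frame_starts_streaming([1000, 1000, 1000, 1000])
```

Both lists contain the same 23 frame positions; any split of the 4000 samples into chunks gives the same result, because the leftover buffer is always shorter than one window.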

In [None]:
def int_singal_to_tensor(int_signal: np.ndarray, device: torch.device) -> torch.Tensor:
    return torch.Tensor(
        int_signal.astype(np.float32) / np.float32(2. ** 15)
    ).to(device)


class FilterbankFeatures(torch.nn.Module):
    def __init__(
        self,
        sample_rate=16000,
        n_window_size=400,
        n_window_stride=160,
        preemph=0.97,
        nfilt=80,
        lowfreq=0,
        highfreq=None,
        log=True,
        log_zero_guard_value=2 ** -24,
        pad_value=0,
        nb_max_freq=4000,
        mel_norm="slaney",
    ):
        super().__init__()
        self.log_zero_guard_value = log_zero_guard_value

        self.win_length = n_window_size
        self.hop_length = n_window_stride
        self.n_fft = n_window_size

        window_fn = torch.hann_window
        window_tensor = window_fn(self.win_length, periodic=False)
        self.register_buffer("window", window_tensor)
        self.stft = lambda x: torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            center=False,
            window=self.window.to(dtype=torch.float),
            return_complex=True,
        )

        self.nfilt = nfilt
        self.preemph = preemph
        highfreq = highfreq or sample_rate / 2

        filterbanks = torch.tensor(
            librosa.filters.mel(
                sr=sample_rate, n_fft=self.n_fft, n_mels=nfilt, fmin=lowfreq, fmax=highfreq, norm=mel_norm
            ),
            dtype=torch.float,
        ).unsqueeze(0)
        self.register_buffer("fb", filterbanks)

        self.pad_value = pad_value

        self.forward = torch.no_grad()(self.forward)

    def forward(self, x):
        """Args:
            x (torch.Tensor): [num_samples] - input float32 waveform with values from -1 to 1
        Returns:
            torch.Tensor: [num_frames, nfilt] fbank features, where
                num_frames = (num_samples - n_window_size) // n_window_stride + 1
        """
        if x.shape[0] < self.win_length:
            raise ValueError('Not enough data')

        x = torch.cat((x[0].unsqueeze(0), x[1:] - self.preemph * x[:-1]), dim=0)
        # disable autocast to get full range of stft values
        with torch.cuda.amp.autocast(enabled=False):
            x = self.stft(x)
        x = torch.view_as_real(x)
        x = x.pow(2).sum(-1)
        # dot with filterbank energies
        x = torch.matmul(self.fb.to(x.dtype), x).squeeze(0)
        x = torch.log(x + self.log_zero_guard_value)
        return x.transpose(0, 1)


class ChunkedStreamingFbank:
    """Streaming adapter for FilterbankFeatures.
    Consumes waveform chunks of arbitrary size and outputs feature chunks of a fixed size.

    Args:
        chunk_size_feats (int): number of feature frames in each output chunk.
        featurizer (FilterbankFeatures): feature calculator to wrap.
    """
    def __init__(self, chunk_size_feats: int, featurizer: FilterbankFeatures):
        self.featurizer = featurizer
        
        self.win_length_samples = self.featurizer.win_length
        self.hop_length_samples = self.featurizer.hop_length
        
        self.chunk_size_feats = chunk_size_feats
        
        self.buffer_signal = None
        self.buffer_feature_chunks = queue.Queue()
        self.last_feature_chunk_prefix = None

    def reset(self):
        self.buffer_signal = None
        self.buffer_feature_chunks = queue.Queue()
        self.last_feature_chunk_prefix = None

    def _get_valid_samples_and_feats(self, signal_length: int) -> tuple[int, int]:
        if signal_length < self.win_length_samples:
            return 0, 0
        valid_feats = (signal_length - self.win_length_samples) // self.hop_length_samples + 1
        return (valid_feats - 1) * self.hop_length_samples + self.win_length_samples, valid_feats
    
    def add(self, signal_chunk):
        """Appends a waveform chunk and enqueues every complete feature chunk it yields.

        Args:
            signal_chunk (torch.Tensor): [num_samples] input waveform chunk
        """
        if self.buffer_signal is not None:
            signal_chunk = torch.cat([self.buffer_signal, signal_chunk], axis=0)
            self.buffer_signal = None

        valid_samples, valid_feats = self._get_valid_samples_and_feats(signal_chunk.shape[0])
        self.buffer_signal = signal_chunk[valid_feats * self.hop_length_samples:]

        if valid_samples == 0:
            return None
    
        signal_chunk = signal_chunk[:valid_samples]

        feats = self.featurizer(signal_chunk)
        if self.last_feature_chunk_prefix is not None:
            feats = torch.cat((self.last_feature_chunk_prefix, feats), axis=0)
            self.last_feature_chunk_prefix = None

        idx = 0
        while (idx + 1) * self.chunk_size_feats <= feats.shape[0]:
            self.buffer_feature_chunks.put(feats[idx * self.chunk_size_feats:(idx + 1) * self.chunk_size_feats])
            idx += 1
        if idx * self.chunk_size_feats != feats.shape[0]:
            self.last_feature_chunk_prefix = feats[idx * self.chunk_size_feats:]

    def get_next_feature_chunk(self) -> torch.Tensor | None:
        if self.buffer_feature_chunks.empty():
            return None
        return self.buffer_feature_chunks.get()


class GreedyCtcDecoder(torch.nn.Module):
    def __init__(self, enc_output_size, tokenizer_settings):
        super().__init__()
        self._tokenizer_settings = tokenizer_settings
        self.decoder_layers = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=enc_output_size,
                out_channels=len(tokenizer_settings['token_to_piece']) + 1,
                kernel_size=1,
                stride=1
            )
        )

    def forward(self, enc_output, enc_lengths):
        """Args:
            enc_output (torch.Tensor): [batch, time, features]
            enc_lengths (torch.Tensor): [batch]
        Returns:
            tuple[torch.Tensor, torch.Tensor]: logits [batch, time, num_logits] and logits lengths [batch]
        """
        enc_output = enc_output.transpose(1, 2)
        logits = self.decoder_layers(enc_output).transpose(1, 2)
        return logits, enc_lengths

    def decode(self, logits, logits_lengths):
        """Args:
            logits (torch.Tensor): [batch, time, num_logits]
            logits_lengths (torch.Tensor): [batch]
        Returns:
            list[str]: [batch] of greedy hypos.
        """

        result = []
        logits_lengths = logits_lengths.detach().cpu().numpy()
        for idx in range(len(logits)):
            tokens = list(map(int, logits[idx, :logits_lengths[idx]].max(dim=-1)[1].detach().cpu().numpy()))
            prediction = []
            prev_token = None
            for token in tokens:
                if token != prev_token and token != self._tokenizer_settings['blank_idx']:
                    prediction.append(self._tokenizer_settings['token_to_piece'][str(token)])
                prev_token = token
            result.append(''.join(prediction).replace(self._tokenizer_settings['special_symbol'], ' ').strip())
        return result
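`FilterbankFeatures` calls `torch.stft` with `center=False`, so no padding is added at the edges. The frame count therefore follows the formula in the `forward` docstring, using integer division. Below is a quick check. It assumes a 400-sample window and a 160-sample hop (25 ms / 10 ms at 16 kHz). These values are our assumption and are not read from the class defaults:

```python
import torch

# Hypothetical settings (NOT read from the notebook): 25 ms window, 10 ms hop at 16 kHz.
SAMPLE_RATE = 16000
WIN_LENGTH = 400
HOP_LENGTH = 160


def expected_num_frames(num_samples: int, win: int = WIN_LENGTH, hop: int = HOP_LENGTH) -> int:
    """Number of STFT frames when center=False, i.e. without edge padding."""
    if num_samples < win:
        return 0
    return (num_samples - win) // hop + 1


signal = torch.randn(SAMPLE_RATE)  # one second of noise
spec = torch.stft(
    signal,
    n_fft=WIN_LENGTH,
    hop_length=HOP_LENGTH,
    win_length=WIN_LENGTH,
    window=torch.hann_window(WIN_LENGTH, periodic=False),
    center=False,
    return_complex=True,
)
# spec has shape [n_fft // 2 + 1, num_frames]
print(spec.shape[-1], expected_num_frames(signal.shape[0]))  # 98 98
```

This is also why `FilterbankFeatures.forward` raises `ValueError` for inputs shorter than one window: they contain zero complete frames.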

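`ChunkedStreamingFbank` stores in `buffer_signal` everything from the start of the first frame it cannot compute yet. Because of this, frame boundaries never drift. The pure-Python sketch below reuses the arithmetic of `_get_valid_samples_and_feats` with hypothetical `WIN`/`HOP` values. It checks that streaming yields exactly the same frame start positions as offline featurization, for several input chunk sizes:

```python
# Hypothetical window/hop sizes in samples; the real ones come from FilterbankFeatures.
WIN, HOP = 400, 160


def valid_samples_and_feats(n: int, win: int = WIN, hop: int = HOP) -> tuple[int, int]:
    """Same arithmetic as ChunkedStreamingFbank._get_valid_samples_and_feats."""
    if n < win:
        return 0, 0
    feats = (n - win) // hop + 1
    return (feats - 1) * hop + win, feats


def offline_frame_starts(total: int) -> list[int]:
    """Start sample of every frame when the whole signal is featurized at once."""
    _, feats = valid_samples_and_feats(total)
    return [i * HOP for i in range(feats)]


def streaming_frame_starts(total: int, chunk: int) -> list[int]:
    """Start sample of every frame when the signal arrives in `chunk`-sample pieces."""
    starts = []
    buffer_start = 0  # absolute index of the first sample kept in the buffer
    received = 0
    while received < total:
        received = min(received + chunk, total)
        _, feats = valid_samples_and_feats(received - buffer_start)
        starts.extend(buffer_start + i * HOP for i in range(feats))
        # like `signal_chunk[valid_feats * hop_length_samples:]`: drop the consumed hops
        buffer_start += feats * HOP
    return starts


for chunk in (3200, 1000, 37):
    assert streaming_frame_starts(16000, chunk) == offline_frame_starts(16000)
print(len(offline_frame_starts(16000)))  # 98
```

This sketch checks frame positions only. Except at the very start of the signal, the first frame produced by each `add` call can differ slightly from its offline counterpart. The reason is that `FilterbankFeatures.forward` leaves `x[0]` unchanged instead of subtracting `preemph` times the preceding sample, and the streaming call no longer sees that sample.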

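`GreedyCtcDecoder.decode` applies the standard greedy CTC rule in four steps:

1) take the argmax per frame;
2) merge consecutive repeats;
3) drop blanks;
4) join the pieces and turn the word-boundary symbol into spaces.

Here is the same rule on a toy vocabulary. The vocabulary, the blank index (placed last, matching the extra output channel) and the `\u2581` SentencePiece-style boundary symbol are all assumptions for illustration. The real values come from `token.json` via `tokenizer_settings`:

```python
def ctc_greedy_collapse(token_ids, blank_idx, token_to_piece, special_symbol="\u2581"):
    """Greedy CTC post-processing: merge repeats, drop blanks, map pieces to text."""
    pieces = []
    prev = None
    for tok in token_ids:
        if tok != prev and tok != blank_idx:
            pieces.append(token_to_piece[str(tok)])
        prev = tok
    return "".join(pieces).replace(special_symbol, " ").strip()


# Toy vocabulary (hypothetical). Keys are strings, as GreedyCtcDecoder expects.
toy_vocab = {"0": "\u2581he", "1": "llo", "2": "\u2581world"}
BLANK = len(toy_vocab)  # assumed blank: the extra output channel, index 3 here

# Per-frame argmax ids, e.g. what logits.argmax(dim=-1) could produce.
frames = [0, 0, 3, 1, 1, 3, 3, 2, 2, 3]
print(ctc_greedy_collapse(frames, BLANK, toy_vocab))      # hello world
print(ctc_greedy_collapse([1, 1], BLANK, toy_vocab))      # llo
print(ctc_greedy_collapse([1, 3, 1], BLANK, toy_vocab))   # llollo
```

The last two calls show why the blank exists. Repeats are merged only when nothing separates them, so a blank between two identical tokens is what lets CTC emit the same token twice.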
### 6.2 Real-life data

Let's download real model weights, tokenizer settings, and a test audio file.

In [None]:
def download_file(public_link, filename='archive.tgz'):
    # Yandex Disk public API: first resolve the public link to a direct download URL...
    base_url = 'https://cloud-api.yandex.net/v1/disk/public/resources/download?'
    final_url = base_url + urlencode(dict(public_key=public_link))
    response = requests.get(final_url)
    response.raise_for_status()
    download_href = response.json()['href']

    # ...then fetch the file itself and save it to the current directory
    file_response = requests.get(download_href)
    file_response.raise_for_status()
    final_link = os.path.join(os.getcwd(), filename)
    print(final_link)
    with open(final_link, 'wb') as ff:
        ff.write(file_response.content)

In [None]:
link_to_archive = "https://disk.yandex.ru/d/Omgg4HryF5AWLQ"
download_file(link_to_archive, filename='archive.tgz')
!mkdir -p ../data
!mv archive.tgz ../data/
!tar xzvf ../data/archive.tgz -C ../data
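The `!mkdir`/`!mv`/`!tar` lines require a notebook with a Unix shell. If you run the code as a plain Python script or on Windows, a rough standard-library equivalent could look like the sketch below. The helper name `move_and_extract` is ours and is not part of the notebook:

```python
import os
import shutil
import tarfile


def move_and_extract(archive_path: str, data_dir: str = "../data") -> list[str]:
    """Pure-Python stand-in for the `mkdir -p` / `mv` / `tar xzvf -C` shell lines."""
    os.makedirs(data_dir, exist_ok=True)              # mkdir -p ../data
    target = os.path.join(data_dir, os.path.basename(archive_path))
    shutil.move(archive_path, target)                 # mv <archive> ../data/
    with tarfile.open(target, "r:gz") as tar:         # tar xzf <archive> -C ../data
        names = tar.getnames()
        tar.extractall(path=data_dir)
    for name in names:                                # the "v" in xzvf: list members
        print(name)
    return names
```

Call it with the file name you passed to `download_file`. On recent Python versions (3.12+), you may also want to pass `filter="data"` to `extractall` so that unsafe archive members are rejected.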

### 6.3 Testing on real-life data

In [None]:
encoder = ConformerEncoder(
    feat_in=80,
    n_layers=17,
    d_model=512,
    ff_expansion_factor=4,
    n_heads=8,
    subsampling_factor=8,
    subsampling_conv_channels=256,
    chunk_size=2,
    left_chunks_num=70,
    conv_kernel_size=9,
)
# chunk_size * subsampling_factor
encoder_step = 2 * 8
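Here is what this configuration means in audio time. The encoder values are copied from the cell above. Two things are assumptions: the 10 ms feature hop (the `FilterbankFeatures` defaults are not shown in this section), and reading `left_chunks_num` as the number of past chunks that self-attention can see:

```python
# Values copied from the ConformerEncoder config above.
chunk_size = 2            # encoder frames per streaming chunk
subsampling_factor = 8    # input feature frames per encoder frame
left_chunks_num = 70      # assumed: past chunks visible to self-attention

# Assumption: 10 ms feature hop (160 samples at 16 kHz).
hop_ms = 10

encoder_step = chunk_size * subsampling_factor           # feature frames per step
step_ms = encoder_step * hop_ms                          # audio consumed per step
left_context_frames = left_chunks_num * chunk_size       # encoder frames of left context
left_context_s = left_context_frames * subsampling_factor * hop_ms / 1000

print(encoder_step, step_ms, left_context_frames, left_context_s)  # 16 160 140 11.2
```

Under these assumptions, the model emits new output every 160 ms and attends to roughly 11 s of past audio.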

In [None]:
with open('../data/week12_data/encoder_state.pkl', 'rb') as fp:
    encoder.load_state_dict(pickle.load(fp), strict=False)
encoder = encoder.cpu().eval()

In [None]:
with open('../data/week12_data/token.json') as fp:
    tokenizer_settings = json.load(fp)

In [None]:
decoder = GreedyCtcDecoder(512, tokenizer_settings)

In [None]:
with open('../data/week12_data/decoder_state.pkl', 'rb') as fp:
    decoder.load_state_dict(pickle.load(fp))
decoder = decoder.cpu().eval()

In [None]:
with open('../data/week12_data/audio.wav', 'rb') as fp:
    with wave.open(fp, 'r') as wfp:
        pcm_data = wfp.readframes(wfp.getnframes())

signal = np.frombuffer(pcm_data, dtype=np.int16)
signal = int_singal_to_tensor(signal, device=torch.device('cpu'))
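The streaming cells below assume 16 kHz audio (16 samples per ms), and the reader above assumes 16-bit PCM. `int_singal_to_tensor` is defined earlier in the notebook. The sketch below is a hypothetical stand-alone reader, not that helper. It validates the WAV header and scales `int16` samples to float by dividing by 32768, which is our choice of constant:

```python
import io
import wave

import numpy as np
import torch


def read_pcm16_mono(fp, expected_rate: int = 16000) -> torch.Tensor:
    """Read a 16-bit mono WAV from a file object; return float32 samples in [-1, 1)."""
    with wave.open(fp, "rb") as wfp:
        # the streaming code assumes 16 kHz, mono, 16-bit little-endian PCM
        if wfp.getframerate() != expected_rate:
            raise ValueError(f"expected {expected_rate} Hz, got {wfp.getframerate()}")
        if wfp.getnchannels() != 1 or wfp.getsampwidth() != 2:
            raise ValueError("expected mono 16-bit PCM")
        pcm = wfp.readframes(wfp.getnframes())
    ints = np.frombuffer(pcm, dtype="<i2")  # WAV samples are little-endian
    # astype copies into a writable array; dividing by 2**15 maps -32768 to -1.0
    return torch.from_numpy(ints.astype(np.float32) / 32768.0)


# Round trip through an in-memory WAV file.
buf = io.BytesIO()
with wave.open(buf, "wb") as w:
    w.setnchannels(1)
    w.setsampwidth(2)
    w.setframerate(16000)
    w.writeframes(np.array([0, 16384, -32768, 32767], dtype="<i2").tobytes())
buf.seek(0)
x = read_pcm16_mono(buf)
print(x[:3].tolist())  # [0.0, 0.5, -1.0]
```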


In [None]:
Audio('../data/week12_data/audio.wav')

In [None]:
featurizer = FilterbankFeatures()

In [None]:
with torch.no_grad():
    features = featurizer(signal)

    # trim the number of feature frames to a multiple of encoder_step
    features = features[:(features.shape[0] // encoder_step) * encoder_step, :]
    encoded, encoded_len = encoder(features.unsqueeze(0), torch.tensor([features.size(0)]))
    logits, logits_len = decoder(encoded, encoded_len)
    print(decoder.decode(logits, logits_len)[0])

In [None]:
features_chunker = ChunkedStreamingFbank(chunk_size_feats=encoder_step, featurizer=featurizer)

# 200 ms chunks: at 16 kHz there are 16000 / 1000 = 16 samples per ms
signal_chunk_size_samples = 200 * 16
with torch.no_grad():
    whole_logits = None
    state = encoder.get_initial_state()
    for start_idx in range(0, signal.shape[0], signal_chunk_size_samples):
        # clamp so the last printed time does not exceed the audio duration
        total_processed_samples = min(start_idx + signal_chunk_size_samples, signal.shape[0])
        signal_chunk = signal[start_idx:total_processed_samples]
        features_chunker.add(signal_chunk)
        while (features_chunk := features_chunker.get_next_feature_chunk()) is not None:
            encoder_step_output, state = encoder.streaming_forward(features_chunk.unsqueeze(0), state)
            step_logits, _ = decoder(encoder_step_output, torch.tensor([encoder_step_output.shape[1]]))
            if whole_logits is not None:
                whole_logits = torch.cat([whole_logits, step_logits], axis=1)
            else:
                whole_logits = step_logits
            hypo = decoder.decode(whole_logits, torch.tensor([whole_logits.shape[1]]))[0]
            print(f"time: {total_processed_samples / 16000:.2f}s: '{hypo}'")
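The loop above works only because every stateful layer carries its own cache between `streaming_forward` calls, as described in section 1. The hypothetical toy layer below (`ToyCausalConv1d`, not part of the notebook) shows this contract on a causal convolution:

* the state holds the last `kernel_size - 1` input frames;
* the initial state equals the zero padding of the offline `forward`;
* as a result, chunked output matches offline output.

A real Conformer also needs caches for attention and subsampling, which this sketch does not cover:

```python
import torch
import torch.nn.functional as F


class ToyCausalConv1d(torch.nn.Module):
    """Hypothetical toy layer following the get_initial_state / streaming_forward contract."""

    def __init__(self, channels: int, kernel_size: int):
        super().__init__()
        self.kernel_size = kernel_size
        self.conv = torch.nn.Conv1d(channels, channels, kernel_size)  # no built-in padding

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, channels, time]; left padding makes output[t] depend on inputs <= t only
        return self.conv(F.pad(x, (self.kernel_size - 1, 0)))

    def get_initial_state(self, batch_size: int = 1) -> torch.Tensor:
        # zeros = exactly the left padding that forward() would add
        return torch.zeros(batch_size, self.conv.in_channels, self.kernel_size - 1)

    def streaming_forward(self, x: torch.Tensor, state: torch.Tensor):
        x = torch.cat([state, x], dim=-1)  # prepend the cached past frames
        new_state = x[..., x.shape[-1] - (self.kernel_size - 1):]  # keep the last k - 1 frames
        return self.conv(x), new_state


torch.manual_seed(0)
layer = ToyCausalConv1d(channels=4, kernel_size=9).eval()
x = torch.randn(1, 4, 50)
with torch.no_grad():
    full = layer(x)
    state = layer.get_initial_state()
    outputs = []
    for chunk in x.split(16, dim=-1):  # chunks of 16, 16, 16 and 2 frames
        out, state = layer.streaming_forward(chunk, state)
        outputs.append(out)
    streamed = torch.cat(outputs, dim=-1)
print(torch.allclose(full, streamed, atol=1e-5))  # True
```

The chunk sizes need not divide the sequence length here, because the cache always keeps exactly the `kernel_size - 1` frames the next output needs.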

Let's compare the naive approach (rerunning the model on the whole prefix every time a new chunk arrives) with the streaming approach:

In [None]:
%%time
features = featurizer(signal)
with torch.no_grad():
    for end_idx in range(encoder_step, features.shape[0] + 1, encoder_step):
        prefix_features = features[:end_idx]
        encoded, encoded_len = encoder(prefix_features.unsqueeze(0), torch.tensor([prefix_features.size(0)]))
        logits, logits_len = decoder(encoded, encoded_len)

In [None]:
%%time
features_chunker = ChunkedStreamingFbank(chunk_size_feats=encoder_step, featurizer=featurizer)

signal_chunk_size_samples = 200 * 16
with torch.no_grad():
    whole_logits = None
    state = encoder.get_initial_state()
    for start_idx in range(0, signal.shape[0], signal_chunk_size_samples):
        total_processed_samples = start_idx + signal_chunk_size_samples
        signal_chunk = signal[start_idx:total_processed_samples]
        features_chunker.add(signal_chunk)
        while (features_chunk := features_chunker.get_next_feature_chunk()) is not None:
            encoder_step_output, state = encoder.streaming_forward(features_chunk.unsqueeze(0), state)
            step_logits, _ = decoder(encoder_step_output, torch.tensor([encoder_step_output.shape[1]]))
            if whole_logits is not None:
                whole_logits = torch.cat([whole_logits, step_logits], axis=1)
            else:
                whole_logits = step_logits
            hypo = decoder.decode(whole_logits, torch.tensor([whole_logits.shape[1]]))[0]
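A back-of-the-envelope count explains the gap between the two timings. After `n` steps of `encoder_step` frames, the naive loop has fed the encoder `encoder_step * n * (n + 1) / 2` frames, while the streaming loop has fed only `encoder_step * n`. The sketch below counts encoder input frames only. It ignores two things: streaming attention still reads the cached left context, and in the cells above only the streaming version runs the feature chunker and the greedy `decode` inside the loop:

```python
def frames_processed(num_steps: int, step: int) -> tuple[int, int]:
    """Encoder input frames after num_steps steps: (naive prefix re-encoding, streaming)."""
    naive = sum(k * step for k in range(1, num_steps + 1))  # every prefix is re-encoded
    streaming = num_steps * step                           # every frame is encoded once
    return naive, streaming


naive, streaming = frames_processed(num_steps=100, step=16)
print(naive, streaming, naive / streaming)  # 80800 1600 50.5
```

The ratio grows linearly with the number of steps, so the naive approach gets worse the longer the audio is.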

## 7. The task

1) If you've successfully filled in the gaps and now have a working streaming Conformer encoder, congratulations: you can collect your **5 points**.
2) For **5 bonus points**, modify the code so that streaming works with batch_size > 1. Things to consider:
* You may need to modify signatures to accommodate an additional "length" argument.
* You may need to define `combine_states` methods to merge states from different examples into one batch.
* You only need to handle input lengths divisible by `encoder_step`.
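For the bonus task, here is one possible shape of the `combine_states` helper mentioned above, plus its inverse. Everything here is an assumption. The state layout (a flat list of tensors with the batch dimension first) is hypothetical, and the real encoder state may be nested. Also, `torch.cat` requires equal non-batch shapes. If caches of different examples can have different lengths, you would need padding plus per-example lengths, which is where the extra "length" argument may come in:

```python
import torch

# Hypothetical state layout: a list of tensors with the batch dimension first.
# The encoder's real state (see get_initial_state) may be structured differently.
State = list[torch.Tensor]


def combine_states(states: list[State]) -> State:
    """Concatenate per-example states (batch size 1 each) into one batched state."""
    return [torch.cat(parts, dim=0) for parts in zip(*states)]


def split_states(state: State) -> list[State]:
    """Inverse of combine_states: slice a batched state back into per-example states."""
    batch_size = state[0].shape[0]
    return [[t[i:i + 1] for t in state] for i in range(batch_size)]


a = [torch.zeros(1, 4, 8), torch.zeros(1, 3)]
b = [torch.ones(1, 4, 8), torch.ones(1, 3)]
batched = combine_states([a, b])
print([tuple(t.shape) for t in batched])  # [(2, 4, 8), (2, 3)]
```

Keeping the batch dimension (slicing with `i:i + 1` rather than indexing with `i`) means a split state can be fed straight back into a batch-size-1 `streaming_forward`.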


    