# LSTM Optimization with Groq API

In this notebook, we will:
 - Describe an LSTM and create a NumPy baseline model to compare against
 - Design reusable components for LSTM sub-kernels
 - Design an optimized LSTM in Groq API
 - Visualize and analyze the performance results with GroqView
 - Compare the results with state-of-the-art performance

## LSTM Introduction

Long Short-Term Memory (LSTM) models are a class of recurrent neural network (RNN). 

![LSTM dataflow graph](img/lstm-dag.png)

[Source: Microsoft Brainwave architecture ISCA paper](https://www.microsoft.com/en-us/research/uploads/prod/2018/06/ISCA18-Brainwave-CameraReady.pdf)

The figure above shows the compute graph for a single LSTM layer. We can observe a few traits:
 - The primary compute kernel is matrix-vector multiplication (MVM)
   - MVMs have poor data reuse, which makes this an SRAM bandwidth-bound problem.
   - We should be able to get strong batching performance on the GroqChip™ processor by packing multiple activation vectors into a matrix.
 - There is high potential for VXM chaining
   - Few tensors must be written to memory.
   - Many tensors are only consumed once.
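
The batching claim can be illustrated with a minimal NumPy sketch (plain NumPy, not Groq API code). It assumes small illustrative sizes and made-up variable names. It shows that packing activation vectors into a matrix produces the same results as separate MVMs, while each weight element read is reused once per batch row.

```python
import numpy as np

rng = np.random.default_rng(0)

hidden = 8   # illustrative size, not the notebook's 640
batch = 3

# Integer-valued data keeps the comparison exact
W = rng.integers(-4, 5, size=(hidden, hidden)).astype(np.int32)
vectors = [rng.integers(-4, 5, size=hidden).astype(np.int32) for _ in range(batch)]

# One MVM per vector: the weight matrix is read `batch` times
separate = np.stack([v @ W for v in vectors])

# One matrix-matrix product: the weight matrix is read once
batched = np.stack(vectors) @ W

assert np.array_equal(separate, batched)

# Multiply-accumulates performed per weight element read
macs_per_weight_unbatched = 1
macs_per_weight_batched = batch
```

With weights as the dominant memory traffic, the batched form does `batch` times more arithmetic per weight read, which is why batching helps a bandwidth-bound kernel.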

## LSTM Optimization Plan
The following diagram shows our plan for mapping the LSTM cell onto Groq API components.

We will define 4 major components:

1. MVMQuad: 4x concurrent matrix-vector multiplication (MVM) ops, one per LSTM gate.
1. DequantizeBias: Convert from the MXM output format, quantized `int32`, to `float32` then add the bias vector. Support 4x concurrent DequantizeBias operations.
1. Activations: Apply an activation function, either `sigmoid` or `hyperbolic tangent`, to each of the 4x streams from DequantizeBias. Support 2x concurrent activation operations.
1. Reduction: Pipeline of vector operations that reduce the 4x streams from the activations into the next-state vectors `cnext` and `hnext`.

![LSTM Mapping](img/lstm_mapping.png)

## Boilerplate Imports

In [None]:
import groq.api as g
from groq.runner import tsp
from groq.api import nn
import numpy as np

## Quantization

The literature shows that LSTMs can maintain high accuracy using narrow precision data types. Given that our performance is SRAM bandwidth-bound, we want to maximize our effective bandwidth by minimizing the data type size. On GroqChip, we select the int8 data type for matmuls (and therefore weight and activation data) because it is the narrowest available type. VXM functions will remain in float32.

We select a fixed point 1.7 scheme on top of int8 so that we can express fractional numbers, which is important in LSTMs because the sigmoid and tanh activation functions expect a narrow range of inputs. Ideally, the fixed point scheme is chosen based on calibration against real-world data.

To accommodate our chosen data types we must:
 - Offline, quantize our weights and inputs to fixed point
 - Online, dequantize the int32 (fixed point 18.14) matmul results into float32 by:
   - Casting int32 to float32
   - Multiplying by a scaling factor of 2^-14
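
As a worked example of the arithmetic above, the following NumPy sketch quantizes a few hand-picked values to fixed point 1.7, multiplies them as integers, and dequantizes the result with a scale of 2^-14. The values and variable names are illustrative assumptions, chosen so that every step is exactly representable.

```python
import numpy as np

# 1.7 fixed point: 1 sign bit, 7 fractional bits, stored in int8
a = np.int8(0.5 * 2**7)     # 0.5   -> 64
b = np.int8(-0.25 * 2**7)   # -0.25 -> -32

# The product of two 1.7 values carries 7 + 7 = 14 fractional bits
product = np.int32(a) * np.int32(b)   # 64 * -32 = -2048

# Dequantize: cast to float32, then scale by 2^-14
value = np.float32(product) * np.float32(2.0**-14)
print(value)  # -0.125, which equals 0.5 * -0.25

# The same steps for a small dot product with an int32 accumulator
x = np.array([0.5, -0.25, 0.75], dtype=np.float32)
w = np.array([0.125, 0.5, -0.5], dtype=np.float32)
xq = (x * 2**7).astype(np.int8)   # [64, -32, 96]
wq = (w * 2**7).astype(np.int8)   # [16, 64, -64]
acc = np.dot(xq.astype(np.int32), wq.astype(np.int32))   # 1024 - 2048 - 6144 = -7168
result = np.float32(acc) * np.float32(2.0**-14)
print(result)  # -0.4375, matching the float dot product of x and w
```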



In [None]:

quantize_scale = 2 ** 7
dequantize_scale = 1 / (2 ** 14)

# Quantize float32 to fixed point 1.7 format
def quantize(tensor, scale):
    tensor = tensor * scale
    tensor = tensor.astype(np.int8)
    return tensor

# Dequantize fixed point 18.14 to float32
def dequantize(tensor, scale):
    tensor = tensor.astype(np.float32)
    tensor = tensor * scale
    return tensor

## Data Preparation
Let's set up some randomized data for the activations, weights, bias, and constants. We can also perform offline quantization here. 

Note that LSTMs are numerically fickle, so we must choose the range of our random initialization carefully. The reason is that sigm() and tanh() expect a narrow range of input values, and it is easy to saturate them accidentally.
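
To show what saturation looks like, here is a short NumPy sketch that evaluates both activations on a handful of inputs and measures how far each output is from its limiting value. The sample inputs are illustrative assumptions, not values taken from the notebook's data.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

inputs = np.array([-16.0, -4.0, -1.0, 0.0, 1.0, 4.0, 16.0], dtype=np.float32)

sig = sigmoid(inputs)
tanh = np.tanh(inputs)

# Distance of each output from its nearest limit; near zero means saturated
sig_margin = np.minimum(sig, 1.0 - sig)   # limits are 0 and 1
tanh_margin = 1.0 - np.abs(tanh)          # limits are -1 and 1

for x, s, t in zip(inputs, sig_margin, tanh_margin):
    print(f"x={x:6.1f}  sigmoid margin={s:.2e}  tanh margin={t:.2e}")
```

Inputs of only a few units in magnitude already push both functions close to their limits, so outputs for widely different inputs become nearly indistinguishable.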

In [None]:
hidden_size = 640
input_size = 640
batch = 3

w_shape = (input_size, hidden_size)
u_shape = (hidden_size, hidden_size)
w_shape_transposed = (w_shape[1], w_shape[0])
u_shape_transposed = (u_shape[1], u_shape[0])
x_shape = (batch, input_size)
v_shape = (batch, hidden_size)

in_concat_shape = (batch, input_size + hidden_size)
mat_concat_shape = (input_size + hidden_size, hidden_size)
mat_concat_transpose_shape = (mat_concat_shape[1], mat_concat_shape[0])

x_data = quantize((-2) * np.random.random_sample(size=x_shape) + 1, quantize_scale)
cprev_data = (-2) * np.random.random_sample(size=v_shape).astype(np.float32) + 1
hprev_data = quantize(
    (-2) * np.random.random_sample(size=v_shape).astype(np.float32) + 1,
    quantize_scale,
)

bi_data = (-64) * np.random.random_sample(size=v_shape).astype(np.float32) + 32
bf_data = (-64) * np.random.random_sample(size=v_shape).astype(np.float32) + 32
bo_data = (-64) * np.random.random_sample(size=v_shape).astype(np.float32) + 32
bc_data = (-64) * np.random.random_sample(size=v_shape).astype(np.float32) + 32

Wi_data = quantize(
    (-0.5) * np.random.random_sample(size=w_shape).astype(np.float32) + 0.25,
    quantize_scale,
)
Wf_data = quantize(
    (-0.5) * np.random.random_sample(size=w_shape).astype(np.float32) + 0.25,
    quantize_scale,
)
Wo_data = quantize(
    (-0.5) * np.random.random_sample(size=w_shape).astype(np.float32) + 0.25,
    quantize_scale,
)
Wc_data = quantize(
    (-0.5) * np.random.random_sample(size=w_shape).astype(np.float32) + 0.25,
    quantize_scale,
)

Ui_data = quantize(
    (-0.5) * np.random.random_sample(size=u_shape).astype(np.float32) + 0.25,
    quantize_scale,
)
Uf_data = quantize(
    (-0.5) * np.random.random_sample(size=u_shape).astype(np.float32) + 0.25,
    quantize_scale,
)
Uo_data = quantize(
    (-0.5) * np.random.random_sample(size=u_shape).astype(np.float32) + 0.25,
    quantize_scale,
)
Uc_data = quantize(
    (-0.5) * np.random.random_sample(size=u_shape).astype(np.float32) + 0.25,
    quantize_scale,
)

# Concatenate data as an optimization
WUi_data = np.concatenate((Wi_data, Ui_data), axis=0)
WUf_data = np.concatenate((Wf_data, Uf_data), axis=0)
WUo_data = np.concatenate((Wo_data, Uo_data), axis=0)
WUc_data = np.concatenate((Wc_data, Uc_data), axis=0)
xhprev_data = np.concatenate((x_data, hprev_data), 1)

# Constants
log2e = np.ones(shape=v_shape, dtype=np.float32) * np.log2(np.e)
ones = np.ones(shape=v_shape, dtype=np.float32)
dequantize_vector = np.full(
    shape=v_shape, fill_value=dequantize_scale, dtype=np.float32
)

## NumPy Baseline

We implement a baseline model in NumPy to ensure that our LSTM is functionally correct. We break the baseline into a few functions so that it matches the structure of our Groq API program:
 - Activation functions for sigm() and tanh(). NumPy has a tanh we can use but we need our own sigmoid.
 - The dequantize function we defined earlier.
 - lstmgate_np implements the pattern of matmul -> dequantize -> bias -> activation, which repeats 4 times in an LSTM layer.
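
The baseline sigmoid is written in terms of `exp2` and `log2(e)` rather than `exp`. As a quick sanity check of that rewrite, the sketch below compares the two forms; the function names `sigm_exp2` and `sigm_exp` are illustrative only.

```python
import numpy as np

def sigm_exp2(x):
    # e^(-x) rewritten as 2^(-x * log2(e))
    return 1 / (1 + np.exp2(np.log2(np.e) * np.negative(x)))

def sigm_exp(x):
    # Textbook form for comparison
    return 1 / (1 + np.exp(-x))

x = np.linspace(-8, 8, 33, dtype=np.float32)
np.testing.assert_allclose(sigm_exp2(x), sigm_exp(x), rtol=1e-5)
```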

In [None]:
def sigm_np(x):
    return 1 / (1 + np.exp2(np.log2(np.e) * np.negative(x)))


def lstmgate_np(x, hprev, b, W, U, dequantize, act_type):
    result = np.matmul(x.astype(np.int32), W.astype(np.int32))
    result = result + np.matmul(hprev.astype(np.int32), U.astype(np.int32))

    result = dequantize(result, dequantize_scale)

    result = result + b

    if act_type == "sigm":
        result = sigm_np(result)
    if act_type == "tanh":
        result = np.tanh(result)

    return result


def lstm_np(
    x, cprev, hprev, bi, bf, bo, bc, Wi, Wf, Wo, Wc, Ui, Uf, Uo, Uc, dequantize
):
    i_gate = lstmgate_np(x, hprev, bi, Wi, Ui, dequantize, "sigm")
    f_gate = lstmgate_np(x, hprev, bf, Wf, Uf, dequantize, "sigm")
    o_gate = lstmgate_np(x, hprev, bo, Wo, Uo, dequantize, "sigm")
    c_gate = lstmgate_np(x, hprev, bc, Wc, Uc, dequantize, "tanh")

    cnext = f_gate * cprev + i_gate * c_gate

    hnext = np.tanh(cnext) * o_gate

    return (cnext, hnext)

## Groq API Components
We implement a set of composable Groq API components that can be used to build an LSTM that matches our NumPy baseline above. 


### VXM Components Legend
The diagrams of our LSTM VXM components show ALU allocation and chaining as they map onto the VXM topology shown below. For readability, the component diagrams omit the topology itself.

VXM diagrams give the ALU number and opcode for each operation.

![VXM topology](img/vxm.png)

### Dequantize -> Bias

We fuse the dequantize and bias functions into a single VXM pass because the LSTM always performs them in sequence. We also write the component so that it can take streaming input, which allows it to be chained after a matmul. 

To allow 4x concurrency, we find 4 paths through the VXM, 2 starting in the western hemisphere and 2 starting in the eastern hemisphere. These paths must be valid in terms of both VXM placement and stream routing, and cannot conflict with each other. We also specify the layout of our scale and bias constants so they enter the VXM from the same direction as the input data. Finally, we have our outputs exit the VXM to the opposite hemisphere from the input, which we will need to take into account when laying out the activation functions.

The following diagram shows a single dequantize -> bias VXM kernel:

![Dequantize Bias x1](img/dequantize_bias_x1.png)

This is the mapping for 4x concurrent DequantizeBias:

![Dequantize Bias x4](img/dequantize_bias_x4.png)


In [None]:
class DequantizeBias(g.tensor.Component):
    def __init__(self, gate_name, bias_data, scale_data, mxm_plane=0, **kwargs):
        super().__init__(**kwargs)

        self.gate_name = gate_name
        self.bias_data = bias_data
        self.scale_data = scale_data
        self.mxm_plane = mxm_plane

        # Allocate ALUs, streams, and layouts
        if mxm_plane == 0:
            # Start in NW corner of VXM
            alus = [8, 9, 14]
            self.output_stream = [g.SG4_E[7]]
            self.scale_stream = [g.SG4_E[6]]
            self.bias_stream = [g.SG4_E[0]]
            self.layout = "H1(W), -1, S4"
        elif mxm_plane == 1:
            # Start in SW corner of VXM
            alus = [5, 2, 3]
            self.output_stream = [g.SG4_E[3]]
            self.scale_stream = [g.SG4_E[1]]
            self.bias_stream = [g.SG4_E[2]]
            self.layout = "H1(W), -1, S4"
        elif mxm_plane == 2:
            # Start in NE corner of VXM
            alus = [10, 13, 12]
            self.output_stream = [g.SG4_W[7]]
            self.scale_stream = [g.SG4_W[6]]
            self.bias_stream = [g.SG4_W[1]]
            self.layout = "H1(E), -1, S4"
        else:  # 3
            # Start in SE corner of VXM
            alus = [7, 6, 1]
            self.output_stream = [g.SG4_W[3]]
            self.scale_stream = [g.SG4_W[2]]
            self.bias_stream = [g.SG4_W[0]]
            self.layout = "H1(E), -1, S4"

        self.alus = g.tensor.create_alu_request(alus=alus)

    def build(self, vec):
        # Constants
        self.scale_mt = g.from_data(
            data=self.scale_data,
            name="scale_{}".format(self.gate_name),
            layout=self.layout,
        )
        self.bias_mt = g.from_data(
            data=self.bias_data,
            name="bias_{}".format(self.gate_name),
            layout=self.layout,
        )

        scale_st = self.scale_mt.read(streams=self.scale_stream)
        bias_st = self.bias_mt.read(streams=self.bias_stream)

        result = g.cast(
            vec,
            g.float32,
            False,
            alus=self.alus[0],
            input_streams=self.output_stream,
            output_streams=self.output_stream,
            time=0,
        )
        result = g.mul(
            result, scale_st, alus=self.alus[1], output_streams=self.output_stream
        )
        result = g.add(
            result, bias_st, alus=self.alus[2], output_streams=self.output_stream
        )

        return result

### MvmQuad Component
The `MvmQuad` component performs 4 matrix-vector multiplies (MVMs), which stream into 4 dequantize->bias operations in parallel across the 4 MXM planes and 4 VXM rows.

The MVMs in a single hemisphere (e.g., i and c) are interleaved in time to maximize use of the MEM slices that load the weights. Recall that MVM is a memory-bandwidth-bound problem, so it is vital to keep the MEM slices working as much as possible.

 - We map each gate matmul to an independent MXM plane
 - Weights in the same hemisphere share the same MEM slices due to space constraints
 - We make 4 copies of the input so we can stream in parallel, and place each copy next to its corresponding MXM plane
 - We start the matmuls at a slight offset so that the MEM slices are almost always busy reading out weights. Remember, this is an SRAM bandwidth-bound problem.
 - All 4 matmuls stream their results into VXM concurrently, each hitting a different DequantizeBias component in parallel
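
Each MVM consumes concatenated operands (`xhprev` and `WUi`, `WUf`, `WUo`, `WUc`) instead of separate `W` and `U` matrices. The following NumPy sketch, using small illustrative shapes rather than the notebook's, checks that one matmul over the concatenated operands equals the two matmuls plus an add used in the NumPy baseline.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, input_size, hidden_size = 3, 6, 4  # small illustrative sizes

x = rng.integers(-128, 128, size=(batch, input_size)).astype(np.int32)
h = rng.integers(-128, 128, size=(batch, hidden_size)).astype(np.int32)
W = rng.integers(-32, 32, size=(input_size, hidden_size)).astype(np.int32)
U = rng.integers(-32, 32, size=(hidden_size, hidden_size)).astype(np.int32)

# Two matmuls plus an add, as in the NumPy baseline
two_matmuls = x @ W + h @ U

# One matmul over concatenated operands
xh = np.concatenate((x, h), axis=1)   # (batch, input_size + hidden_size)
WU = np.concatenate((W, U), axis=0)   # (input_size + hidden_size, hidden_size)
one_matmul = xh @ WU

assert np.array_equal(two_matmuls, one_matmul)
```

Folding the add into the matmul's accumulation means each gate needs only one pass through an MXM plane.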

In [None]:
class MvmQuad(g.tensor.Component):
    def __init__(self, WUi, WUf, WUo, WUc, **kwargs):
        super().__init__(**kwargs)

        self.WUi_data = WUi
        self.WUf_data = WUf
        self.WUo_data = WUo
        self.WUc_data = WUc
        wait_btw_mxm_passes = 40
        self.mm_i = nn.MatMul(name="WUi_mm",wait_btw_passes=wait_btw_mxm_passes, planes=[0], arith_mode_warmup=True)
        self.mm_c = nn.MatMul(name="WUc_mm",wait_btw_passes=wait_btw_mxm_passes, planes=[1], arith_mode_warmup=True)
        self.mm_o = nn.MatMul(name="WUo_mm",wait_btw_passes=wait_btw_mxm_passes, planes=[2], arith_mode_warmup=True)
        self.mm_f = nn.MatMul(name="WUf_mm",wait_btw_passes=wait_btw_mxm_passes, planes=[3], arith_mode_warmup=True)

    def build(self, xhprev_0_mt, xhprev_1_mt, xhprev_2_mt, xhprev_3_mt, **kwargs):
        super().build(**kwargs)
        
        # Weights: Transpose is necessary because Groq and NumPy expect
        # column-major and row-major weights, respectively
        self.WUi_mt = g.from_data(
            data=self.WUi_data.transpose(),
            name="WUi_weights",
            layout="H1(W), -1, S16(12-39)",
        )

        self.WUo_mt = g.from_data(
            data=self.WUo_data.transpose(),
            name="WUo_weights",
            layout="H1(E), -1, S16(12-39)",
        )

        self.WUc_mt = g.from_data(
            data=self.WUc_data.transpose(),
            name="WUc_weights",
            layout="H1(W), -1, S16(12-39)",
        )

        self.WUf_mt = g.from_data(
            data=self.WUf_data.transpose(),
            name="WUf_weights",
            layout="H1(E), -1, S16(12-39)",
        )
        
        g.add_mem_constraints(
            [self.WUf_mt], 
            [self.WUo_mt], 
            g.MemConstraintType.NOT_MUTUALLY_EXCLUSIVE
        )
        g.add_mem_constraints(
            [self.WUi_mt],
            [self.WUc_mt],
            g.MemConstraintType.NOT_MUTUALLY_EXCLUSIVE,
        )
        with g.ResourceScope(name="mvm_x4", time=0):
            i_product_mt = self.mm_i(xhprev_0_mt, self.WUi_mt, time=0).write(
                layout="H1(W), -1, S4", name="i_product_mt"
            )
            o_product_mt = self.mm_o(xhprev_2_mt, self.WUo_mt, time=0).write(
                layout="H1(E), -1, S4", name="o_product_mt"
            )
            c_product_mt = self.mm_c(xhprev_1_mt, self.WUc_mt, time=20).write(
                layout="H1(W), -1, S4", name="c_product_mt"
            )
            f_product_mt = self.mm_f(xhprev_3_mt, self.WUf_mt, time=20).write(
                layout="H1(E), -1, S4", name="f_product_mt"
            )

        return (i_product_mt, f_product_mt, o_product_mt, c_product_mt)

### MvmQuad Unit Test
There's a lot flying around in parallel in `MvmQuad` so we should unit test. We recommend creating a unit test of every significant component of Groq API functionality.

In [None]:
def MvmQuad_unit_test():
    g.reset_program_context()

    xhprev_0_mt = g.from_data(
        data=xhprev_data, name="xhprev_cpy0", layout="H1(W), -1, S1(41)"
    )
    xhprev_1_mt = g.from_data(
        data=xhprev_data, name="xhprev_cpy1", layout="H1(W), -1, S1(40)"
    )
    xhprev_2_mt = g.from_data(
        data=xhprev_data, name="xhprev_cpy2", layout="H1(E), -1, S1(41)"
    )
    xhprev_3_mt = g.from_data(
        data=xhprev_data, name="xhprev_cpy3", layout="H1(E), -1, S1(40)"
    )

    unit = MvmQuad(WUi_data, WUf_data, WUo_data, WUc_data,)

    (i_mt, f_mt, o_mt, c_mt) = unit(
        xhprev_0_mt, xhprev_1_mt, xhprev_2_mt, xhprev_3_mt
    )

    print("Compiling...")

    # Compile and run
    iop_file = g.compile(
        base_name="mvmquad_test", result_tensor=[i_mt, f_mt, o_mt, c_mt]
    )
    gate_program = tsp.create_tsp_runner(iop_file)

    result = gate_program()

    g.write_visualizer_data("quad_mvm")

    # Oracle
    oracle_i = np.matmul(xhprev_data.astype(np.int32), WUi_data.astype(np.int32))
    oracle_o = np.matmul(xhprev_data.astype(np.int32), WUo_data.astype(np.int32))
    oracle_c = np.matmul(xhprev_data.astype(np.int32), WUc_data.astype(np.int32))
    oracle_f = np.matmul(xhprev_data.astype(np.int32), WUf_data.astype(np.int32))

    np.testing.assert_allclose(result["i_product_mt"], oracle_i, atol=0.01)
    np.testing.assert_allclose(result["o_product_mt"], oracle_o, atol=0.01)
    np.testing.assert_allclose(result["c_product_mt"], oracle_c, atol=0.01)
    np.testing.assert_allclose(result["f_product_mt"], oracle_f, atol=0.01)

    print("MvmQuad unit test success")


MvmQuad_unit_test()

### Sigmoid (sigm)
Our sigmoid function is based on Groq API's nn.sigmoid; however, we wrote our own to get the flow of data we need for our top-level LSTM mapping. Specifically, we want data to enter the VXM on one side and exit on the opposite side, whereas nn.sigmoid has data exit the VXM on the same side it entered.

Our numerics are the same as those of the NumPy sigmoid above, with one exception: GroqChip instructions do not include a division (`d = y/x`) or reciprocal (`1/x`) instruction, so we implement division as `d = y * rsqrt(x)^2`.
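
The identity behind this substitution is `rsqrt(x)^2 == 1/x` for positive `x`. Here is a minimal NumPy sketch of a sigmoid built that way, following the op order neg, mul, exp2, add, rsqrt, mul. The `rsqrt` helper is a NumPy stand-in for the hardware instruction, and the function names are illustrative assumptions.

```python
import numpy as np

def rsqrt(x):
    # NumPy stand-in for a hardware reciprocal square root
    return 1.0 / np.sqrt(x)

def sigm_rsqrt(x):
    log2e = np.float32(np.log2(np.e))
    t = np.negative(x)        # neg
    t = t * log2e             # mul
    t = np.exp2(t)            # exp2
    t = t + np.float32(1.0)   # add
    t = rsqrt(t)              # rsqrt
    return t * t              # mul: rsqrt(d)^2 == 1/d

def sigm_div(x):
    # Reference form using a true division
    return 1.0 / (1.0 + np.exp(-x))

x = np.linspace(-8, 8, 33, dtype=np.float32)
np.testing.assert_allclose(sigm_rsqrt(x), sigm_div(x), rtol=1e-5)
```

The denominator `1 + 2^(-x * log2(e))` is always greater than 1, so the square root is well defined for every input.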

We enable 2x concurrency by finding one mapping that stays in the top 2 rows of the VXM and another that stays in the bottom 2 rows.

A single sigmoid looks like this:

![Sigmoid_x1](img/sigm_x1.png)

The two concurrent sigmoids together look like this:

![Sigmoid_x2](img/sigm_x2.png)

In [None]:
class Sigm(g.tensor.Component):
    def __init__(self, use_upper=True, **kwargs):
        super().__init__(**kwargs)

        # Flow = neg (small) - mul (small) - exp (large) - add (small) - rsqrt (large) - mul (small)
        if use_upper:
            alus = [12, 13, 8, 9, 10, 11]
            self.compute_streams = [
                g.SG4_E[7],
                g.SG4_E[7],
                g.SG4_W[7],
                g.SG4_E[5],
                g.SG4_E[5],
                g.SG4_E[5],
                g.SG4_E[5],
            ]
            self.log2e_stream = g.SG4_E[6]
            self.ones_stream = g.SG4_E[4]
            self.layout = "H1(W), -1, S4"
        else:
            alus = [3, 2, 7, 6, 5, 4]
            self.compute_streams = [
                g.SG4_W[1],
                g.SG4_W[1],
                g.SG4_E[2],
                g.SG4_W[2],
                g.SG4_W[2],
                g.SG4_W[2],
                g.SG4_W[2],
            ]
            self.log2e_stream = g.SG4_W[0]
            self.ones_stream = g.SG4_W[3]
            self.layout = "H1(E), -1, S4"

        self.alus = g.tensor.create_alu_request(alus=alus)

    def build(self, vec):
        # Constants
        log2e_mt = g.from_data(data=log2e, name="log2e", layout=self.layout)
        ones_mt = g.from_data(data=ones, name="ones", layout=self.layout)

        log2e_st = log2e_mt.read(streams=self.log2e_stream)
        ones_st = ones_mt.read(streams=self.ones_stream)

        vec_st = vec.read(streams=self.compute_streams[0])

        result_st = g.neg(
            vec_st, alus=self.alus[0], output_streams=self.compute_streams[1], time=0,
        )
        result_st = g.mul(
            result_st,
            log2e_st,
            alus=self.alus[1],
            output_streams=self.compute_streams[2],
        )
        result_st = g.exp2(
            result_st, alus=self.alus[2], output_streams=self.compute_streams[3]
        )
        result_st = g.add(
            result_st,
            ones_st,
            alus=self.alus[3],
            output_streams=self.compute_streams[4],
        )
        result_st = g.rsqrt(
            result_st, alus=self.alus[4], output_streams=self.compute_streams[5]
        )
        result_st = g.mul(
            result_st,
            result_st,
            alus=self.alus[5],
            output_streams=self.compute_streams[6],
        )

        return result_st

### Hyperbolic Tangent (tanh)
Our tanh function uses 5 VXM ALUs to perform the entire function in a single chain. This function leverages the VXM's special `tanh` instruction, which requires a five-step sequence: scale the data by a constant `pre-tanh`, cast to `int16`, apply the `tanh` instruction, cast back to `float32`, and finally scale by another constant `post-tanh`.
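
The two constants used in the code below, `5461.17` and `3.05185e-05`, are numerically close to `32767 / 6` and `1 / 32767`. The following NumPy sketch emulates the five-step chain under one plausible reading of those values: that the `int16` input encodes the range [-6, 6] and the `int16` output encodes [-1, 1]. That interpretation, the saturating cast, and the `tanh_int16_model` helper are assumptions made for illustration, not a description of the actual instruction.

```python
import numpy as np

PRE_TANH = np.float32(5461.17)       # from the notebook; close to 32767 / 6
POST_TANH = np.float32(3.05185e-05)  # from the notebook; close to 1 / 32767

def tanh_int16_model(q):
    # ASSUMPTION: int16 in, int16 out, where the input encodes
    # x * 32767 / 6 and the output encodes tanh(x) * 32767.
    x = q.astype(np.float32) * np.float32(6.0 / 32767.0)
    return np.round(np.tanh(x) * 32767.0).astype(np.int16)

def tanh_chain(x):
    t = x * PRE_TANH                                          # mul by pre-tanh
    t = np.clip(np.round(t), -32767, 32767).astype(np.int16)  # cast to int16 (saturation assumed)
    t = tanh_int16_model(t)                                   # tanh
    t = t.astype(np.float32)                                  # cast to float32
    return t * POST_TANH                                      # mul by post-tanh

x = np.linspace(-4, 4, 17, dtype=np.float32)
np.testing.assert_allclose(tanh_chain(x), np.tanh(x), atol=1e-3)
```

Under these assumptions the chain tracks `np.tanh` to within the `int16` quantization step, which is well inside the `atol=0.01` tolerance used by the unit tests in this notebook.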

Our route through the VXM satisfies two use cases:
1. Concurrency with a `sigmoid` occupying the upper two rows of VXM
2. Composability with the other ops for `lstm reduction` 

Note that the code lets Groq API infer the direction of the input stream, while the user sets the direction of the output stream. This provides the flexibility needed for the two use cases above.

Tanh by itself (LSTM doesn't do this, but it's the simplest view of the tanh):

![Tanh](img/tanh.png)

Tanh concurrency with sigmoid:

![Tanh and Sigmoid](img/tanh_and_sigmoid.png)

See the LSTM Reduction section below for how tanh is incorporated as a subcomponent.

In [None]:
class Tanh(g.tensor.Component):
    def __init__(self, output_stream, **kwargs):
        super().__init__(**kwargs)

        self.output_stream = output_stream

    def build(self, vec):
        if vec.is_memory_tensor():
            start_time = 0
            vec_st = vec.read(streams=g.SG4_W[5])
        else:
            # vec is a stream, and we should inherit time from the stream
            start_time = None
            vec_st = vec

        pre_tanh_mt = g.from_data(
            np.ones((1, v_shape[1]), dtype=np.float32) * 5461.17,
            name="pre_tanh_mt",
            layout="H1(E),-1,S4",
        )
        post_tanh_mt = g.from_data(
            np.ones((1, v_shape[1]), dtype=np.float32) * 3.05185e-05,
            name="post_tanh_mt",
            layout="H1(W),-1,S4",
        )

        pre_tanh_st = pre_tanh_mt.read(streams=g.SG4_W[4])
        post_tanh_st = post_tanh_mt.read(streams=g.SG4_E[3])

        tmp_st = vec_st.mul(
            pre_tanh_st, alus=[6], output_streams=[g.SG4_E[3]], time=start_time
        )
        tmp_st = tmp_st.cast(
            dtype=g.int16, alus=[7], output_streams=[g.SG4_W[3]], fp16_inf=False
        )
        tmp_st = tmp_st.tanh(alus=[2], output_streams=[g.SG4_W[2]])
        tmp_st = tmp_st.cast(
            dtype=g.float32, alus=[5], output_streams=[g.SG4_W[2]], fp16_inf=False
        )
        tmp_st = tmp_st.mul(
            post_tanh_st, alus=[4], output_streams=[self.output_stream]
        )

        return tmp_st

### LSTM Reduction
We pack all of the pointwise ops after the LSTM gates into a single component, which executes all ops in a single VXM chain:

    cnext = f_gate * cprev + i_gate * c_gate
    hnext = np.tanh(cnext) * o_gate

The LSTM Reduction VXM mapping looks like this:

![LSTM Reduction](img/lstm_reduction.png)

In [None]:
class LstmReduction(g.tensor.Component):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.tanh = Tanh(g.SG4_E[3])

    def build(self, fgate_mt, cprev_mt, igate_mt, cgate_mt, ogate_mt, **kwargs):
        super().build(**kwargs)
        
        cgate_st = cgate_mt.read(streams=g.SG4_E[7])
        igate_st = igate_mt.read(streams=g.SG4_E[6])
        fgate_st = fgate_mt.read(streams=g.SG4_W[6])
        cprev_st = cprev_mt.read(streams=g.SG4_W[7])
        ogate_st = ogate_mt.read(streams=g.SG4_W[0])

        ic_st = igate_st.mul(cgate_st, alus=[12], output_streams=[g.SG4_E[6]], time=0)
        fcprev_st = fgate_st.mul(
            cprev_st, alus=[15], output_streams=[g.SG4_W[6]]
        )
        cnext_st = fcprev_st.add(
            ic_st, alus=[9], output_streams=[g.SG4_E[4], g.SG4_E[5]]
        )
        cnext_mt = cnext_st.write(
            streams=g.SG4_E[4], name="cnext_out", layout="H1(E), -1, S4"
        )

        tmp_st = self.tanh(cnext_st)
        hnext_mt = tmp_st.mul(ogate_st, alus=[1], output_streams=[g.SG4_W[1]]).write(
            name="hnext_out", layout="H1(W), -1, S4"
        )

        return (cnext_mt, hnext_mt)

### LstmReduction Unit Test

In [None]:
def LstmReduction_unit_test():
    g.reset_program_context()

    fgate_data = np.random.rand(v_shape[0], v_shape[1]).astype(np.float32)
    cprev_data = np.random.rand(v_shape[0], v_shape[1]).astype(np.float32)
    igate_data = np.random.rand(v_shape[0], v_shape[1]).astype(np.float32)
    cgate_data = np.random.rand(v_shape[0], v_shape[1]).astype(np.float32)
    ogate_data = np.random.rand(v_shape[0], v_shape[1]).astype(np.float32)

    fgate_mt = g.from_data(fgate_data, name="fgate", layout="H1(E), -1, S4")
    cprev_mt = g.from_data(cprev_data, name="cprev", layout="H1(E), -1, S4")
    igate_mt = g.from_data(igate_data, name="igate", layout="H1(W), -1, S4")
    cgate_mt = g.from_data(cgate_data, name="cgate", layout="H1(W), -1, S4")
    ogate_mt = g.from_data(ogate_data, name="ogate", layout="H1(E), -1, S4")

    unit = LstmReduction()

    (cnext_mt, hnext_mt) = unit(fgate_mt, cprev_mt, igate_mt, cgate_mt, ogate_mt)

    print("Compiling...")

    # Compile and run
    iop_file = g.compile(
        base_name="lstmreduction_test", result_tensor=[cnext_mt, hnext_mt]
    )
    gate_program = tsp.create_tsp_runner(iop_file)

    result = gate_program()

    g.write_visualizer_data("lstm_reduction")

    # Oracle
    def oracle_np(fgate, cprev, igate, cgate, ogate):
        cnext = (fgate * cprev) + (igate * cgate)
        hnext = np.tanh(cnext) * ogate
        return (cnext, hnext)

    (cnext_np, hnext_np) = oracle_np(
        fgate_data, cprev_data, igate_data, cgate_data, ogate_data
    )

    np.testing.assert_allclose(result["cnext_out"], cnext_np, atol=0.01)
    np.testing.assert_allclose(result["hnext_out"], hnext_np, atol=0.01)

    print("LstmReduction unit test success")


LstmReduction_unit_test()

### Copying the Input
We create a quick component to take the input data (`xhprev_mt`, the concatenation of x and hprev) and copy it 4 times. We do this to produce an independent copy for each MXM plane to use concurrently.


In [None]:
class CopyQuad(g.tensor.Component):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_mt, **kwargs):
        super().build(**kwargs)

        with g.ResourceScope(
            name="input_copying_W", is_buffered=True, time=0
        ) as input_copying_W:
            input_st = input_mt.read(streams=g.SG1_W, time=0)
            copy_0W_mt = input_st.write(name="copy_0W_mt", layout="H1(W), -1, S1(41)")
            copy_1W_mt = input_st.write(name="copy_1W_mt", layout="H1(W), -1, S1(40)")

        with g.ResourceScope(
            name="input_copying_E",
            is_buffered=True,
            predecessors=[input_copying_W],
            time=None,
        ) as input_copying_E:
            input_st = input_mt.read(streams=g.SG1_E, time=0)
            copy_0E_mt = input_st.write(name="copy_0E_mt", layout="H1(E), -1, S1(41)")
            copy_1E_mt = input_st.write(name="copy_1E_mt", layout="H1(E), -1, S1(40)")

        return (copy_0W_mt, copy_1W_mt, copy_0E_mt, copy_1E_mt)


### Activations Component
We group all of the activation functions together to take advantage of the memory allocation features of buffered components. Specifically, we want to make sure that all of the inputs to LstmReduction are in mutually exclusive memory slices. This is accomplished by placing the producers of those tensors (`i_gate`, `f_gate`, `c_gate`, and `o_gate`) into the same component.

We'll do the same thing for the 4x dequantize bias for the sake of consistency in our programming style.

In [None]:
class QuadActivations(g.tensor.Component):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.i_activation = Sigm(name="i_activation", use_upper=False)
        self.c_activation = Tanh(name="c_activation", output_stream=g.SG4_W[2])

        self.f_activation = Sigm(name="f_activation", use_upper=True)
        self.o_activation = Sigm(name="o_activation", use_upper=True)

    def build(self, i_partial_mt, f_partial_mt, c_partial_mt, o_partial_mt, **kwargs):
        super().build(**kwargs)

        with g.ResourceScope(
            name="i_o_activations", is_buffered=True, time=0,
        ) as i_o_activations:
            i_st = self.i_activation(i_partial_mt)
            igate_mt = i_st.write(name="igate_mt", layout="H1(W), -1, S4")

            o_st = self.o_activation(o_partial_mt)
            ogate_mt = o_st.write(name="ogate_mt", layout="H1(E), -1, S4")

        with g.ResourceScope(
            name="f_c_activations",
            is_buffered=True,
            time=None,
            predecessors=[i_o_activations],
        ) as f_c_activations:
            c_st = self.c_activation(c_partial_mt)
            cgate_mt = c_st.write(name="cgate_mt", layout="H1(W), -1, S4")

            f_st = self.f_activation(f_partial_mt)
            fgate_mt = f_st.write(name="fgate_mt", layout="H1(E), -1, S4")

        return (igate_mt, fgate_mt, cgate_mt, ogate_mt)


class QuadDequantizeBias(g.tensor.Component):
    def __init__(self, bi, bf, bo, bc, dequantize, **kwargs):
        super().__init__(**kwargs)

        self.dequantize_bias_i = DequantizeBias("i", bi, dequantize, mxm_plane=0)
        self.dequantize_bias_c = DequantizeBias("c", bc, dequantize, mxm_plane=1)
        self.dequantize_bias_o = DequantizeBias("o", bo, dequantize, mxm_plane=2)
        self.dequantize_bias_f = DequantizeBias("f", bf, dequantize, mxm_plane=3)

    def build(self, i_product_mt, f_product_mt, o_product_mt, c_product_mt, **kwargs):
        super().build(**kwargs)

        i_partial_mt = self.dequantize_bias_i(i_product_mt).write(
            name="i_partial", layout="H1(E), -1, S4"
        )
        o_partial_mt = self.dequantize_bias_o(o_product_mt).write(
            name="o_partial", layout="H1(W), -1, S4"
        )
        c_partial_mt = self.dequantize_bias_c(c_product_mt).write(
            name="c_partial", layout="H1(E), -1, S4"
        )
        f_partial_mt = self.dequantize_bias_f(f_product_mt).write(
            name="f_partial", layout="H1(W), -1, S4"
        )

        return (i_partial_mt, f_partial_mt, c_partial_mt, o_partial_mt)

### Top-Level LSTM Component
Now we build the optimized LSTM component out of our existing components to match the LSTM optimization plan from the top of the notebook.

In [None]:
class Lstm(g.tensor.Component):
    def __init__(self, bi, bf, bo, bc, WUi, WUf, WUo, WUc, dequantize, **kwargs):
        super().__init__(**kwargs)

        self.copy_x4 = CopyQuad(name="copy_x4", is_buffered=True)
        self.mvm_x4 = MvmQuad(
            WUi,
            WUf,
            WUo,
            WUc,
            name="mvm_x4",
            is_buffered=True,
            is_resource_scope=True,
        )
        self.dequantize_bias_x4 = QuadDequantizeBias(
            bi,
            bf,
            bo,
            bc,
            dequantize,
            name="dequantize_bias_x4",
            is_buffered=True,
            is_resource_scope=True,
        )
        self.activations_x4 = QuadActivations(name="activations_x4", is_buffered=True)
        self.reduction = LstmReduction(
            name="reduction", is_buffered=True, is_resource_scope=True
        )

    def build(self, xhprev_mt, cprev_mt):

        # Copy the input 4 times
        (xhprev_0_mt, xhprev_1_mt, xhprev_2_mt, xhprev_3_mt) = self.copy_x4(xhprev_mt, time=0)

        # Perform 4x MVMs in parallel
        (i_product_mt, f_product_mt, o_product_mt, c_product_mt) = self.mvm_x4(
            xhprev_0_mt, xhprev_1_mt, xhprev_2_mt, xhprev_3_mt, predecessors=[self.copy_x4]
        )

        # Perform 4x dequantize -> bias operations
        (
            i_partial_mt,
            f_partial_mt,
            c_partial_mt,
            o_partial_mt,
        ) = self.dequantize_bias_x4(
            i_product_mt,
            f_product_mt,
            o_product_mt,
            c_product_mt,
            predecessors=[self.mvm_x4],
        )

        # Run the activation function on each of our 4 tensors
        (igate_mt, fgate_mt, cgate_mt, ogate_mt) = self.activations_x4(
            i_partial_mt,
            f_partial_mt,
            c_partial_mt,
            o_partial_mt,
            predecessors=[self.dequantize_bias_x4],
        )

        # Final reduction operations to get c next and h next
        (cnext_mt, hnext_mt) = self.reduction(
            fgate_mt,
            cprev_mt,
            igate_mt,
            cgate_mt,
            ogate_mt,
            predecessors=[self.activations_x4],
        )

        return (cnext_mt, hnext_mt)
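
For orientation, the four stages wired together in `Lstm.build` can be mirrored in a few lines of NumPy. This is only a sketch under stated assumptions: `lstm_stages_np`, `sigmoid`, and the `scale` argument are illustrative names that are not part of Groq API, the int8 operand format is assumed, and the gate math is the textbook LSTM formulation rather than a transcription of the components defined earlier in this notebook.

```python
import numpy as np

GATES = ("i", "f", "o", "c")


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def lstm_stages_np(xh, c_prev, WU, b, scale):
    """Hypothetical staged reference for one LSTM step.

    xh:     int8 vector, concatenation of x and h_prev (assumed format)
    c_prev: float32 vector of length hidden_size
    WU:     dict gate -> int8 matrix [input_size + hidden_size, hidden_size]
    b:      dict gate -> float32 bias vector of length hidden_size
    scale:  dequantization factor (scalar or broadcastable vector)
    """
    # Stage 1 (MvmQuad analogue): four MVMs with int32 accumulation
    product = {k: xh.astype(np.int32) @ WU[k].astype(np.int32) for k in GATES}

    # Stage 2 (dequantize -> bias): int32 to float32, then add the bias
    partial = {k: product[k].astype(np.float32) * scale + b[k] for k in GATES}

    # Stage 3 (activations): sigmoid for i, f, o; tanh for the candidate c
    i_gate = sigmoid(partial["i"])
    f_gate = sigmoid(partial["f"])
    o_gate = sigmoid(partial["o"])
    c_gate = np.tanh(partial["c"])

    # Stage 4 (reduction): combine the four streams into the next state
    c_next = f_gate * c_prev + i_gate * c_gate
    h_next = o_gate * np.tanh(c_next)
    return c_next, h_next
```

Each stage consumes only the outputs of the stage before it, which is the same ordering that the `predecessors` arguments express in the component above.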

### LSTM Unit Test

In [None]:
def Lstm_unit_test():
    g.reset_program_context()
    
    xhprev_mt = g.from_data(
        data=xhprev_data, name="xhprev_cpy0", layout="H1(W), -1, S1(0)"
    )

    cprev_mt = g.from_data(
        data=cprev_data, name="cprev_mt", layout="H1(E), -1, S4"
    )

    model = Lstm(
        bi_data,
        bf_data,
        bo_data,
        bc_data,
        WUi_data,
        WUf_data,
        WUo_data,
        WUc_data,
        dequantize_vector,
    )

    (cnext_mt, hnext_mt) = model(xhprev_mt, cprev_mt)

    print("Compiling...")

    # Compile and run
    iop_file = g.compile(
        base_name="lstm_test", result_tensor=[cnext_mt, hnext_mt]
    )
    lstm_program = tsp.create_tsp_runner(iop_file)

    result = lstm_program()

    g.write_visualizer_data("lstm")

    (cnext_np, hnext_np) = lstm_np(
        x_data,
        cprev_data,
        hprev_data,
        bi_data,
        bf_data,
        bo_data,
        bc_data,
        Wi_data,
        Wf_data,
        Wo_data,
        Wc_data,
        Ui_data,
        Uf_data,
        Uo_data,
        Uc_data,
        dequantize,
    )

    np.testing.assert_allclose(result["cnext_out"], cnext_np, atol=0.01)
    np.testing.assert_allclose(result["hnext_out"], hnext_np, atol=0.01)

    print("Optimized LSTM unit test success")


Lstm_unit_test()

In [None]:

def calcLstmPerformance(input_size, hidden_size, batch, cycles):
    alan_clock_rate = 900000000  # clock rate in Hz (900 MHz)

    # Calculate ops as:
    #  - Consider MXM ops only
    #  - MVM of size [1 x N] * [N x M] has M*N multiply-accumulates (MACs)
    #  - 1 add and 1 multiply op per MAC
    #  - Matrix size is (input_size + hidden_size) * hidden_size
    #  - 4 such matrices (one per gate), each applied once per batch element
    ops = batch * 2 * 4 * (input_size + hidden_size) * hidden_size
    seconds = cycles / alan_clock_rate
    microseconds = seconds * 1000000
    tops = (ops / seconds) / (1000000000000)

    print("----- Performance -----")

    print("batch                    =", batch)
    print("input size               =", input_size)
    print("hidden size              =", hidden_size)
    print("ops                      =", ops)
    print("latency (cycles)         =", cycles)
    print("latency (microseconds)   = {0:.2f}".format(microseconds))
    print("tops                     = {0:.2f}".format(tops))
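
Written out as a formula, the calculation in `calcLstmPerformance` is as follows, where $B$ is the batch size, $I$ the input size, $H$ the hidden size, $N_\text{cyc}$ the measured cycle count, and $f$ the 900 MHz clock rate hard-coded in the function (the symbols are introduced here for illustration only):

$$
\text{ops} = B \cdot 2 \cdot 4 \cdot (I + H) \cdot H, \qquad
t = \frac{N_\text{cyc}}{f}, \qquad
\text{TOPS} = \frac{\text{ops}}{t \cdot 10^{12}}
$$

As a worked example, take $I = H = 512$, $B = 1$, and $N_\text{cyc} = 670$:

$$
\text{ops} = 2 \cdot 4 \cdot 1024 \cdot 512 = 4{,}194{,}304, \qquad
t = \frac{670}{9 \times 10^{8}\ \text{Hz}} \approx 0.744\ \mu\text{s}, \qquad
\text{TOPS} \approx 5.63
$$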

### Get the performance
1. Run the GroqView command `groqview lstm/visdata.json` and read off the cycle count for the program.
1. Pass that cycle count, together with the input size, hidden size, and batch size, to `calcLstmPerformance()`.

In [None]:
calcLstmPerformance(512, 512, 1, 670)
calcLstmPerformance(512, 512, 3, 718)
calcLstmPerformance(512, 512, 6, 790)
calcLstmPerformance(512, 512, 8, 838)
calcLstmPerformance(512, 512, 32, 1413)

calcLstmPerformance(1024, 1024, 1, 1958)
calcLstmPerformance(1024, 1024, 3, 2086)
calcLstmPerformance(1024, 1024, 6, 2278)
calcLstmPerformance(1024, 1024, 8, 2406)
calcLstmPerformance(1024, 1024, 32, 3941)
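
To see how throughput scales with batch size, it can help to return the numbers instead of printing them. The sketch below reuses the formula from `calcLstmPerformance` and the LSTM-512 cycle counts from the calls above. The names `lstm_tops`, `batch_scaling`, and `CLOCK_HZ` are hypothetical helpers introduced here for illustration.

```python
CLOCK_HZ = 900_000_000  # same 900 MHz clock rate used by calcLstmPerformance


def lstm_tops(input_size, hidden_size, batch, cycles, clock_hz=CLOCK_HZ):
    """Return MXM-only TOPS using the same op count as calcLstmPerformance."""
    ops = batch * 2 * 4 * (input_size + hidden_size) * hidden_size
    seconds = cycles / clock_hz
    return ops / seconds / 1e12


def batch_scaling(size, runs):
    """Return (batch, tops, speedup relative to the first run) per run."""
    base = lstm_tops(size, size, *runs[0])
    rows = []
    for batch, cycles in runs:
        tops = lstm_tops(size, size, batch, cycles)
        rows.append((batch, tops, tops / base))
    return rows


# (batch, cycles) pairs copied from the LSTM-512 calls above
lstm512 = [(1, 670), (3, 718), (6, 790), (8, 838), (32, 1413)]

for batch, tops, speedup in batch_scaling(512, lstm512):
    print(f"batch={batch:2d}  tops={tops:6.2f}  speedup vs batch 1={speedup:5.2f}x")
```

Going from batch 1 to batch 32 raises the cycle count from 670 to 1413, roughly 2.1x, while the amount of work grows 32x. That gap is where the throughput gain from batching comes from.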

## State of the Art (SOTA) Comparison
At the time of writing, the most recent paper on LSTM acceleration is [a collaboration between FPGA experts at Intel and top universities](https://www.intel.com/content/dam/www/public/us/en/documents/white-papers/a1153843-beyond-peak-performance-white-paper.pdf), published in late 2020.

FPGAs are a strong target for LSTM acceleration thanks to:
 - Many small distributed SRAMs that provide high memory bandwidth to MVMs
 - Flexibility to specify deeply pipelined datapaths for pointwise vector ops

The paper above targets the Intel Stratix 10 NX FPGA, which is Intel's Stratix 10 MX (HBM) FPGA modified with special block-floating-point matmul digital signal processing units. The device has a theoretical peak of 120 TOPS at 500 MHz at int8 / block-float-16 precision.

The authors map a clone of the Microsoft Brainwave RNN architecture to the NX and compare against Nvidia's V100 and T4 GPUs. See below for a comparison of results across different LSTM sizes and batch sizes.

Copy of Intel's results, for convenience: 

![Intel LSTM benchmarks](img/intel-lstm-results.png)

Intel's results for LSTM-512, plus GroqChip 1 results:

![Groq LSTM-512 benchmarks](img/lstm-512.png)

Intel's results for LSTM-1024, plus GroqChip 1 results:

![Groq LSTM-1024 benchmarks](img/lstm-1024.png)