In [6]:
from tensor import Tensor, TensorSpec, TensorShape
from utils.index import Index
from algorithm import vectorize
from sys.info import simdwidthof
from random import rand
import math
from python import Python
let json = Python.import_module("json")

"""
    B: batch size                       (`B` in Mamba paper [1] Algorithm 2)
    L: sequence length                  (`L` in [1] Algorithm 2)
    D_MODEL: hidden dim
    D_STATE: latent state dim           (`N` in [1] Algorithm 2)
    EXPAND: expansion factor            (`E` in [1] Section 3.4)
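    D_INNER: D_MODEL * EXPAND           (`D` in [1] Algorithm 2; the model dim is expanded by E, see [1] Section 3.4)
    A, B, C, D: state space parameters  (continuous-time SSM: h'(t) = A h(t) + B x(t),  y(t) = C h(t) + D x(t))
                                        (B, C are input-dependent (aka selective, a key innovation in Mamba); A, D are not)
    Δ or delta: input-dependent step size used to discretize A and B
    dt_rank: rank of Δ                  (See [1] Section 3.6 "Parameterization of ∆")

    [1] A. Gu and T. Dao, "Mamba: Linear-Time Sequence Modeling with Selective State Spaces", arXiv:2312.00752 (2023)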
"""
let B = 4
let L = 128
let D_MODEL = 512
let D_STATE = 128
let EXPAND = 4
# The inner (expanded) width scales the model width, not the state size.
let D_INNER = D_MODEL * EXPAND

# Placeholder token ids, one per batch element. The values are arbitrary
# random uint32s and are not bounded by any vocabulary size yet.
# TODO: replace with real tokenized input.
let input_ids = rand[DType.uint32](B)
print(input_ids)

Tensor([[2042135635, 1345947507, 2073208346, 3565425363]], dtype=uint32, shape=4)


In [12]:
# TODO: implement
fn get_embedding(input_ids: Tensor[DType.uint32]) -> Tensor[DType.float32]:
    # input_ids:  Tensor of shape (b, 1)
    # return:     Tensor of shape (b, 1, d_model)
    let EMBEDDING_DIM = 4
    return rand[DType.float32](EMBEDDING_DIM)

print(get_embedding(input_ids))

Tensor([[0.51489287614822388, 0.78978455066680908, 0.544272780418396, 0.093629911541938782]], dtype=float32, shape=4)


1


In [29]:
alias floattensor = Tensor[DType.float32]
fn naive_matmul(A: floattensor, B: floattensor) -> floattensor:
    """Triple-loop matrix product A @ B for 2-D float32 tensors.

    Args:
        A: Matrix of shape (m, k).
        B: Matrix of shape (k, n). The caller must ensure B.shape()[0] equals
            A.shape()[1]; this is not checked, and a mismatch reads out of bounds.

    Returns:
        New (m, n) tensor with output[i, j] = sum_k A[i, k] * B[k, j].
    """
    let rows = A.shape()[0]
    let inner = A.shape()[1]
    let cols = B.shape()[1]
    var output = floattensor(rows, cols)
    for i in range(rows):
        for j in range(cols):
            # Accumulate in a local so the output element is written exactly once
            # and does not depend on the tensor's initial contents.
            var acc: Float32 = 0.0
            for k in range(inner):
                # Multi-dimensional access needs an index tuple; `t[i][j]` would
                # index the scalar returned by the flat `t[i]` instead.
                acc += A[Index(i, k)] * B[Index(k, j)]
            output[Index(i, j)] = acc
    return output

In [33]:
struct LinearLayer:
    """Fully connected layer computing `input @ weights`, plus `bias` when `add_bias` is set.

    `weights` has shape (D_IN, D_OUT); `bias` holds D_OUT values (stored as (D_OUT, 1)).
    Both are random placeholders drawn with `rand` at construction time.
    """
    var D_IN: Int
    var D_OUT: Int
    var weights: Tensor[DType.float32]
    var bias: Tensor[DType.float32]
    var add_bias: Bool

    fn __init__(inout self, in_features: Int, out_features: Int, add_bias: Bool = False):
        self.D_IN = in_features
        self.D_OUT = out_features
        self.weights = rand[DType.float32](self.D_IN, self.D_OUT)
        self.bias = rand[DType.float32](self.D_OUT, 1)
        self.add_bias = add_bias

    fn __copyinit__(inout self, existing: Self):
        # Allows the layer to be copied (e.g. stored as a field of another struct).
        self.D_IN = existing.D_IN
        self.D_OUT = existing.D_OUT
        self.weights = existing.weights
        self.bias = existing.bias
        self.add_bias = existing.add_bias

    fn __moveinit__(inout self, owned existing: Self):
        # Transfers the tensors' buffers instead of copying them.
        self.D_IN = existing.D_IN
        self.D_OUT = existing.D_OUT
        self.weights = existing.weights^
        self.bias = existing.bias^
        self.add_bias = existing.add_bias

    fn forward(self, input: Tensor[DType.float32]) -> Tensor[DType.float32]:
        """Apply the layer to a 2-D input of shape (n, D_IN); returns shape (n, D_OUT)."""
        # input (n, D_IN) @ weights (D_IN, D_OUT) -> (n, D_OUT). The weight matrix
        # must be on the right given its (D_IN, D_OUT) layout.
        var output = naive_matmul(input, self.weights)
        if self.add_bias:
            let rows = output.shape()[0]
            for i in range(rows):
                for j in range(self.D_OUT):
                    # bias has D_OUT elements, so its flat index j is the j-th bias.
                    output[Index(i, j)] = output[Index(i, j)] + self.bias[j]
        return output

In [None]:
struct Mamba:
    """Parameter container for one Mamba block (no forward pass yet).

    Holds the input projection, a depthwise conv1d kernel, the x/dt projections,
    the A_log and D state-space parameters, and the output projection. All
    weights except A_log and D are random placeholders.
    """
    var in_projection: LinearLayer    # d_model -> 2 * d_inner (x and residual branches)
    var x_projection: LinearLayer     # d_inner -> dt_rank + 2 * d_state (input-specific Δ, B, C)
    var dt_projection: LinearLayer    # dt_rank -> d_inner, with bias (projects Δ up to d_inner)
    var out_projection: LinearLayer   # d_inner -> d_model
    var conv_weight: Tensor[DType.float32]  # depthwise conv kernel, (d_inner, d_conv)
    var conv_bias: Tensor[DType.float32]    # (d_inner,)
    var A_log: Tensor[DType.float32]        # (d_inner, d_state); row d holds log(1), ..., log(d_state)
    var D: Tensor[DType.float32]            # (d_inner,), initialized to ones
    var d_conv: Int

    fn __init__(
        inout self,
        d_model: Int = 512,
        d_state: Int = 128,
        expand: Int = 4,
        dt_rank: Int = 32,
        d_conv: Int = 4,
        bias: Bool = False,
    ):
        """Build the block's parameters.

        Defaults mirror the notebook constants (D_MODEL=512, D_STATE=128, EXPAND=4);
        dt_rank defaults to ceil(d_model / 16) = 32 for d_model = 512.
        `bias` controls the bias on the in/out projections only.
        """
        let d_inner = d_model * expand

        self.in_projection = LinearLayer(d_model, d_inner * 2, bias)
        self.x_projection = LinearLayer(d_inner, dt_rank + d_state * 2, False)
        self.dt_projection = LinearLayer(dt_rank, d_inner, True)
        self.out_projection = LinearLayer(d_inner, d_model, bias)

        self.conv_weight = rand[DType.float32](d_inner, d_conv)
        self.conv_bias = rand[DType.float32](d_inner)
        self.d_conv = d_conv
        self.A_log = Tensor[DType.float32](d_inner, d_state)
        self.D = Tensor[DType.float32](d_inner)

        # A = [1, 2, ..., d_state] repeated for every inner channel; store log(A).
        for d in range(d_inner):
            for n in range(d_state):
                self.A_log[Index(d, n)] = math.log(Float32(n + 1))

        for d in range(d_inner):
            self.D[d] = 1.0

In [26]:
# TODO: implement this
fn compute_layer(input: Tensor[DType.float32]) -> Tensor[DType.float32]:
    """Placeholder for one model layer.

    Ignores the values in `input` and returns a freshly drawn random float32
    tensor with the same shape.
    """
    let target_shape = input.shape()
    let placeholder = rand[DType.float32](target_shape)
    return placeholder

In [6]:
fn RMS_norm(input: Tensor[DType.float32], eps: Float32 = 1e-5) -> Tensor[DType.float32]:
    """Root-mean-square normalize `input` along its last axis.

    Each row r (a contiguous run of last-axis length) becomes
        r / sqrt(mean(r^2) + eps).

    Args:
        input: Float32 tensor of any rank; its last axis is normalized.
        eps: Small constant added inside the sqrt so all-zero rows do not
            divide by zero.

    Returns:
        A new tensor with the same shape as `input`.

    NOTE(review): no learnable scale weight is applied yet.
    """
    alias simd_width: Int = simdwidthof[DType.float32]()
    var output = Tensor[DType.float32](input.shape())
    let total = input.num_elements()
    if total == 0:
        return output

    let row_len = input.shape()[input.rank() - 1]
    let num_rows = total // row_len

    # Shared state read/written by the vectorized closures below; they capture
    # by reference, so updating these per row is seen inside the closures.
    var row_start: Int = 0
    var sum_sq: Float32 = 0.0
    var rms: Float32 = 1.0

    @parameter
    fn accumulate_squares[width: Int](idx: Int) -> None:
        let v = input.simd_load[width](row_start + idx)
        sum_sq += (v * v).reduce_add()

    @parameter
    fn divide_by_rms[width: Int](idx: Int) -> None:
        output.simd_store[width](row_start + idx, input.simd_load[width](row_start + idx) / rms)

    for row in range(num_rows):
        row_start = row * row_len
        sum_sq = 0.0
        vectorize[simd_width, accumulate_squares](row_len)
        rms = math.sqrt(sum_sq / row_len + eps)
        vectorize[simd_width, divide_by_rms](row_len)

    return output

In [None]:
fn in_projection(input: Tensor[DType.float32]) -> Tensor[DType.float32]:
    """Placeholder input projection: multiply the last axis by a random (d_model, d_model) matrix.

    Args:
        input: Float32 tensor whose last axis has length d_model. Documented as
            (b, 1, d_model), but any rank works because leading axes are treated
            as independent rows.

    Returns:
        Tensor with the same shape as `input`.

    NOTE(review): Mamba's real in_proj maps d_model -> 2 * d_inner (x and residual);
    this placeholder keeps d_model. W is redrawn on every call, so results differ
    between calls.
    """
    var output = Tensor[DType.float32](input.shape())
    if input.num_elements() == 0:
        return output

    let d_model = input.shape()[input.rank() - 1]
    let rows = input.num_elements() // d_model
    let W = rand[DType.float32](d_model, d_model)

    for r in range(rows):
        let base = r * d_model
        for j in range(d_model):
            var acc: Float32 = 0.0
            for k in range(d_model):
                # Row-major flat indexing: input[r, k] and W[k, j].
                acc += input[base + k] * W[k * d_model + j]
            output[base + j] = acc
    return output

In [23]:
# Mamba block as shown in Fig. 3 in Section 3.4 of the paper
fn compute_mamba_block(input: Tensor[DType.float32]) -> Tensor[DType.float32]:
    """Placeholder for the Mamba block.

    Only the input projection is computed so far. The result is random but has
    the same shape as `input`, so callers that add it element-wise to `input`
    (the residual connection) stay in bounds.
    """
    # in_projection is meant to produce the x and residual branches; the
    # input-specific Δ, B, C come later from the x projection.
    # Computed but not consumed yet.
    let x_and_residual = in_projection(input)

    # TODO:
    # split x and residual
    # rearrange and do conv1d
    # rearrange back
    # silu activation
    # y = ssm(x)
    # y *= silu(residual)
    # let output = out_projection(y)

    return rand[DType.float32](input.shape())

In [24]:
fn residual_block_forward(input: Tensor[DType.float32]) -> Tensor[DType.float32]:
    """Pre-norm residual block: input + compute_mamba_block(RMS_norm(input)).

    The element-wise add runs over `input.num_elements()` elements, so the block
    output must hold at least that many values.
    """
    alias lane_count: Int = simdwidthof[DType.float32]()

    let block_out = compute_mamba_block(RMS_norm(input))
    var result = Tensor[DType.float32](input.shape())

    @parameter
    fn add_skip[w: Int](pos: Int) -> None:
        let branch = block_out.simd_load[w](pos)
        let skip = input.simd_load[w](pos)
        result.simd_store[w](pos, branch + skip)

    vectorize[lane_count, add_skip](result.num_elements())
    return result

In [62]:
# TODO: implement this
fn lm_head(input: Tensor[DType.float32]) -> Tensor[DType.float32]:
    """Placeholder language-model head.

    Ignores the values in `input` and returns a freshly drawn random float32
    tensor with the same shape.
    """
    let logits_shape = input.shape()
    let fake_logits = rand[DType.float32](logits_shape)
    return fake_logits

In [50]:
# TODO: implement this
fn mamba_forward_pass(input_ids: Tensor[DType.uint32], num_layers: Int = 2) -> Tensor[DType.float32]:
    """Run the (placeholder) model end to end: embed, stack layers, final norm, LM head.

    Args:
        input_ids: Token ids passed to get_embedding.
        num_layers: Number of times compute_layer is applied (default 2).

    Returns:
        The tensor produced by lm_head.
    """
    var x = get_embedding(input_ids)

    for _ in range(num_layers):
        x = compute_layer(x)

    let normalized = RMS_norm(x)
    return lm_head(normalized)