In [1]:
import numpy as np
import mlx.core as mx
import time
import h5py

In [2]:
rng = np.random.default_rng(1)

In [3]:
weights_path = './vicuna_weight.h5'

weights = []
w_input = []
attn_weights = []
aw_input = []
q_weights = []
k_weights = []

# with h5py.File(weights_path, 'r') as weight_file:
#     for layer_name in weight_file:
#         w = np.squeeze(np.array(weight_file[layer_name])).astype(np.float32)
#         if "model" in layer_name and "embed_tokens" not in layer_name and "layernorm" not in layer_name:
#             weights.append(w)
#             w_input.append(rng.random(w.shape, dtype = np.float32))
#         if "attn" in layer_name:
#             attn_weights.append(w)
#             aw_input.append(rng.random(w.shape[1], dtype = np.float32))
#             if "q_proj" in layer_name:
#                 q_weights.append(w)
#             if "k_proj" in layer_name:
#                 k_weights.append(w)
#         if (len(q_weights) > 5 and len(q_weights) == len(k_weights)):
#             break

with h5py.File(weights_path, 'r') as weight_file:
    for layer_name in weight_file:
        w = np.squeeze(np.array(weight_file[layer_name])).astype(np.float16)
        if "model" in layer_name and "embed_tokens" not in layer_name and "layernorm" not in layer_name:
            weights.append(w)
            w_input.append(rng.random(w.shape, dtype = np.float32))
        if "attn" in layer_name:
            attn_weights.append(w)
            aw_input.append(rng.random(w.shape[1], dtype = np.float32))
            if "q_proj" in layer_name:
                q_weights.append(w)
            if "k_proj" in layer_name:
                k_weights.append(w)

In [4]:
def timer(input1, input2, f, runner):
    runs = 10
    times = []
    for _ in range(runs):
        times.append(runner(input1, input2, f))
    times = np.array(times)
    print(f"{runner.__name__[:-6]}mlx")
    print(f"{np.average(times)}ms +/- {np.std(times)}ms")

In [5]:
def transformer_part4_mx(input1, input2, hidden_dim):
    return (input1[:hidden_dim]) * (input2[:hidden_dim])

# def transformer_part4_runner(inputs1, inputs2, f=None):
#     total_time = 0
#     for i in range(len(inputs1)):
#         input1 = inputs1[i].flatten()
#         input2 = inputs2[i].flatten()
#         hd = len(input1)

#         inp1 = mx.array(input1, mx.float32)
#         inp2 = mx.array(input2, mx.float32)
#         hidden_dim = hd
        
#         start_time = time.perf_counter()
#         mx.eval(transformer_part4_mx(inp1, inp2, hidden_dim))
#         end_time = time.perf_counter()
#         del inp2
#         del inp1
#         total_time += (end_time - start_time) * 1000
#     return total_time


def transformer_part4_runner(inputs1, inputs2, f=None):
    total_time = 0
    for i in range(len(inputs1)):
        input1 = inputs1[i].flatten().astype(np.float32)
        input2 = inputs2[i].flatten().astype(np.float32)
        hd = len(input1)

        inp1 = mx.array(input1, mx.float32)
        inp2 = mx.array(input2, mx.float32)
        hidden_dim = hd
        
        start_time = time.perf_counter()
        mx.eval(transformer_part4_mx(inp1, inp2, hidden_dim))
        end_time = time.perf_counter()
        del inp2
        del inp1
        total_time += (end_time - start_time) * 1000
    return total_time

In [6]:
def matmul_mx(weight, input):
    return mx.matmul(weight, input)

# def matmul_runner(weights, inputs, f=None):
#     total_time = 0
#     for i in range(len(inputs)):
#         weight = weights[i]
#         input = inputs[i]

#         w = mx.array(weight, mx.float32)
#         inp = mx.array(input, mx.float32)
        
#         start_time = time.perf_counter()
#         mx.eval(matmul_mx(w, inp))
#         end_time = time.perf_counter()
#         del inp
#         del w
#         total_time += (end_time - start_time) * 1000
#     return total_time

def matmul_runner(weights, inputs, f=None):
    total_time = 0
    for i in range(len(inputs)):
        weight = weights[i].astype(np.float32)
        input = inputs[i].astype(np.float32)

        w = mx.array(weight, mx.float32)
        inp = mx.array(input, mx.float32)
        
        start_time = time.perf_counter()
        mx.eval(matmul_mx(w, inp))
        end_time = time.perf_counter()
        del inp
        del w
        total_time += (end_time - start_time) * 1000
    return total_time

In [7]:
def transformer_part1_mx(token_position, head, head_size, key_cache_layer, q):
    return (mx.matmul(key_cache_layer[:token_position][:, (head) * (head_size):(head) * (head_size) + head_size], q[(head) * (head_size):(head) * (head_size) + head_size])) / (mx.sqrt(mx.array([(head_size) * (1)])))

# def transformer_part1_runner(k_matrixes, q_matrixes, f=None):
#     total_time = 0
#     for i in range(len(k_matrixes)):
#         k_matrix = k_matrixes[i]
#         q_matrix = q_matrixes[i]
#         token_position = k_matrix.shape[0] - 1

#         num_head = 32
#         head = int(rng.integers(low=0, high=num_head))
#         head_size = k_matrix.shape[0] // num_head
        
#         key_cache_layer = mx.array(k_matrix, mx.float32)
#         q = mx.array(q_matrix.flatten(), mx.float32)
        
#         start_time = time.perf_counter()
#         mx.eval(transformer_part1_mx(token_position, head, head_size, key_cache_layer, q))
#         end_time = time.perf_counter()
#         del key_cache_layer
#         del q
#         total_time += (end_time - start_time) * 1000
#     return total_time

def transformer_part1_runner(k_matrixes, q_matrixes, f=None):
    total_time = 0
    for i in range(len(k_matrixes)):
        k_matrix = k_matrixes[i].astype(np.float32)
        q_matrix = q_matrixes[i].astype(np.float32)
        token_position = k_matrix.shape[0] - 1

        num_head = 32
        head = int(rng.integers(low=0, high=num_head))
        head_size = k_matrix.shape[0] // num_head
        
        key_cache_layer = mx.array(k_matrix, mx.float32)
        q = mx.array(q_matrix.flatten(), mx.float32)
        
        start_time = time.perf_counter()
        mx.eval(transformer_part1_mx(token_position, head, head_size, key_cache_layer, q))
        end_time = time.perf_counter()
        del key_cache_layer
        del q
        total_time += (end_time - start_time) * 1000
    return total_time


In [28]:
def transformer_part2_mx(token_position, head, head_size, key_cache_layer, attention):
    return mx.matmul(mx.transpose(key_cache_layer[:(token_position) + (1)][:, (head) * (head_size):(head) * (head_size) + head_size]), attention[:(token_position) + (1)])

# def transformer_part2_runner(k_matrixes, q_matrixes, f=None):
#     total_time = 0
#     for i in range(len(k_matrixes)):
#         k_matrix = k_matrixes[i]
#         q_matrix = q_matrixes[i]
#         token_position = k_matrix.shape[0] - 1

#         num_head = 32
#         head = int(rng.integers(low=0, high=num_head))
#         head_size = k_matrix.shape[0] // num_head
        
#         key_cache_layer = mx.array(k_matrix, mx.float32)
#         q = mx.array(q_matrix.flatten(), mx.float32)

#         attention = transformer_part1_mx(token_position, head, head_size, key_cache_layer, q)
#         attention = mx.concatenate([attention, mx.array([0])])
        
#         start_time = time.perf_counter()
#         mx.eval(transformer_part2_mx(token_position, head, head_size, key_cache_layer, attention))
#         end_time = time.perf_counter()
#         del key_cache_layer
#         del attention
#         total_time += (end_time - start_time) * 1000
#     return total_time

def transformer_part2_runner(k_matrixes, q_matrixes, f=None):
    total_time = 0
    for i in range(len(k_matrixes)):
        k_matrix = k_matrixes[i].astype(np.float32)
        q_matrix = q_matrixes[i].astype(np.float32)
        token_position = k_matrix.shape[0] - 1

        num_head = 32
        head = int(rng.integers(low=0, high=num_head))
        head_size = k_matrix.shape[0] // num_head
        
        key_cache_layer = mx.array(k_matrix, mx.float32)
        q = mx.array(q_matrix.flatten(), mx.float32)

        attention = transformer_part1_mx(token_position, head, head_size, key_cache_layer, q)
        attention = mx.concatenate([attention, mx.array([0])])
        
        start_time = time.perf_counter()
        mx.eval(transformer_part2_mx(token_position, head, head_size, key_cache_layer, attention))
        end_time = time.perf_counter()
        del key_cache_layer
        del attention
        total_time += (end_time - start_time) * 1000
    return total_time

In [9]:
def rmsnorm_part1_mx(input, weight):
    return mx.sum((input) * (input))

# def rmsnorm_part1_runner(weights, inputs, f=None):
#     total_time = 0
#     for i in range(len(inputs)):
#         input = inputs[i].flatten()
#         weight = weights[i].flatten()

               
#         inp = mx.array(input, mx.float32)
#         w = mx.array(weight, mx.float32)

#         start_time = time.perf_counter()
#         mx.eval(rmsnorm_part1_mx(inp, w))
#         end_time = time.perf_counter()
#         del w
#         del inp
#         total_time += (end_time - start_time) * 1000
#     return total_time

def rmsnorm_part1_runner(weights, inputs, f=None):
    total_time = 0
    for i in range(len(inputs)):
        input = inputs[i].flatten().astype(np.float32)
        weight = weights[i].flatten().astype(np.float32)

               
        inp = mx.array(input, mx.float32)
        w = mx.array(weight, mx.float32)

        start_time = time.perf_counter()
        mx.eval(rmsnorm_part1_mx(inp, w))
        end_time = time.perf_counter()
        del w
        del inp
        total_time += (end_time - start_time) * 1000
    return total_time

In [27]:
def rmsnorm_part2_mx(input, weight, ss):
    return ((1) / (mx.sqrt(mx.array([((ss) / (input.size)) + (1)])))) * ((input) * (weight))

# def rmsnorm_part2_runner(weights, inputs, f=None):
#     total_time = 0
#     for i in range(len(inputs)):
#         input = inputs[i].flatten()
#         weight = weights[i].flatten()
#         ssum = np.sum(input * input)

#         inp = mx.array(input, mx.float32)
#         w = mx.array(weight, mx.float32)
#         ss = mx.array(ssum, mx.float32)

#         start_time = time.perf_counter()
#         mx.eval(rmsnorm_part2_mx(inp, w, ss))
#         end_time = time.perf_counter()
#         del w
#         del inp
#         total_time += (end_time - start_time) * 1000
#     return total_time

def rmsnorm_part2_runner(weights, inputs, f=None):
    total_time = 0
    for i in range(len(inputs)):
        input = inputs[i].flatten().astype(np.float32)
        weight = weights[i].flatten().astype(np.float32)
        ssum = np.sum(input * input)

        inp = mx.array(input, mx.float32)
        w = mx.array(weight, mx.float32)
        ss = ssum
        
        start_time = time.perf_counter()
        mx.eval(rmsnorm_part2_mx(inp, w, ss))
        end_time = time.perf_counter()
        del w
        del inp
        total_time += (end_time - start_time) * 1000
    return total_time

In [11]:
def transformer_part3_mx(input, hidden_dim):
    return (input[:hidden_dim]) * ((1) / ((1) + (mx.exp((0) - (input[:hidden_dim])))))

# def transformer_part3_runner(inputs, _, f=None):
#     total_time = 0
#     for i in range(len(inputs)):
#         input = inputs[i].flatten()
#         hd = len(input)

#         inp = mx.array(input, mx.float32)
#         hidden_dim = hd
#         start_time = time.perf_counter()
#         mx.eval(transformer_part3_mx(inp, hidden_dim))
#         end_time = time.perf_counter()
#         del inp
#         total_time += (end_time - start_time) * 1000
#     return total_time

def transformer_part3_runner(inputs, _, f=None):
    total_time = 0
    for i in range(len(inputs)):
        input = inputs[i].flatten().astype(np.float32)
        hd = len(input)

        inp = mx.array(input, mx.float32)
        hidden_dim = hd
        start_time = time.perf_counter()
        mx.eval(transformer_part3_mx(inp, hidden_dim))
        end_time = time.perf_counter()
        del inp
        total_time += (end_time - start_time) * 1000
    return total_time

In [12]:
def softmax_part1_mx(input, max_pos):
    return mx.max(input[:max_pos])

# def softmax_part1_runner(inputs, _, f=None):
#     total_time = 0
#     for i in range(len(inputs)):
#         input = inputs[i].flatten()
#         mp = len(input)
        
#         inp = mx.array(input, mx.float32)
#         max_pos = mp
#         start_time = time.perf_counter()
#         mx.eval(softmax_part1_mx(inp, max_pos))
#         end_time = time.perf_counter()
#         del inp
#         total_time += (end_time - start_time) * 1000
#     return total_time

def softmax_part1_runner(inputs, _, f=None):
    total_time = 0
    for i in range(len(inputs)):
        input = inputs[i].flatten().astype(np.float32)
        mp = len(input)
        
        inp = mx.array(input, mx.float32)
        max_pos = mp
        start_time = time.perf_counter()
        mx.eval(softmax_part1_mx(inp, max_pos))
        end_time = time.perf_counter()
        del inp
        total_time += (end_time - start_time) * 1000
    return total_time

In [13]:
def softmax_part2_mx(input, max_pos, max_val):
    return mx.exp((input[:max_pos]) - (max_val))

# def softmax_part2_runner(inputs, _, f=None):
#     total_time = 0
#     for i in range(len(inputs)):
#         input = inputs[i].flatten()
#         mp = len(input)
        
#         inp = mx.array(input, mx.float32)
#         max_pos = mp
#         max_val = mx.array(np.max(input[:mp]), mx.float32)

#         start_time = time.perf_counter()
#         mx.eval(softmax_part2_mx(inp, max_pos, max_val))
#         end_time = time.perf_counter()
#         del inp
#         total_time += (end_time - start_time) * 1000
#     return total_time

def softmax_part2_runner(inputs, _, f=None):
    total_time = 0
    for i in range(len(inputs)):
        input = inputs[i].flatten().astype(np.float32)
        mp = len(input)
        
        inp = mx.array(input, mx.float32)
        max_pos = mp
        max_val = mx.array(np.max(input[:mp]), mx.float32)

        start_time = time.perf_counter()
        mx.eval(softmax_part2_mx(inp, max_pos, max_val))
        end_time = time.perf_counter()
        del inp
        total_time += (end_time - start_time) * 1000
    return total_time

In [14]:
def softmax_part3_mx(output, max_pos):
    return mx.sum(output[:max_pos])

# def softmax_part3_runner(inputs, _, f=None):
#     total_time = 0
#     for i in range(len(inputs)):
#         input = inputs[i].flatten()
#         mp = len(input)
#         output = np.exp(input[:mp]-np.max(input[:mp]))
        
#         outp = mx.array(output, mx.float32)
#         max_pos = mp
        
#         start_time = time.perf_counter()
#         mx.eval(softmax_part3_mx(outp, max_pos))
#         end_time = time.perf_counter()
#         del outp
#         total_time += (end_time - start_time) * 1000
#     return total_time

def softmax_part3_runner(inputs, _, f=None):
    total_time = 0
    for i in range(len(inputs)):
        input = inputs[i].flatten().astype(np.float32)
        mp = len(input)
        output = np.exp(input[:mp]-np.max(input[:mp]))
        
        outp = mx.array(output, mx.float32)
        max_pos = mp
        
        start_time = time.perf_counter()
        mx.eval(softmax_part3_mx(outp, max_pos))
        end_time = time.perf_counter()
        del outp
        total_time += (end_time - start_time) * 1000
    return total_time

In [15]:
def softmax_part4_mx(unnormalized_output, max_pos, sum):
    return (unnormalized_output[:max_pos]) / (sum)

# def softmax_part4_runner(inputs, _, f=None):
#     total_time = 0
#     for i in range(len(inputs)):
#         input = inputs[i].flatten()
#         mp = len(input)
#         output = np.exp(input[:mp]-np.max(input[:mp]))
#         s = np.sum(output[:mp])
        
#         outp = mx.array(output, mx.float32)
#         max_pos = mp
#         sum = mx.array(s, mx.float32)
        
#         start_time = time.perf_counter()
#         mx.eval(softmax_part4_mx(outp, max_pos, sum))
#         end_time = time.perf_counter()
#         del outp
#         total_time += (end_time - start_time) * 1000
#     return total_time

def softmax_part4_runner(inputs, _, f=None):
    total_time = 0
    for i in range(len(inputs)):
        input = inputs[i].flatten().astype(np.float32)
        mp = len(input)
        output = np.exp(input[:mp]-np.max(input[:mp]))
        s = np.sum(output[:mp])
        
        outp = mx.array(output, mx.float32)
        max_pos = mp
        sum = mx.array(s, mx.float32)
        
        start_time = time.perf_counter()
        mx.eval(softmax_part4_mx(outp, max_pos, sum))
        end_time = time.perf_counter()
        del outp
        total_time += (end_time - start_time) * 1000
    return total_time

In [16]:
timer(weights, w_input, None, transformer_part4_runner)

elewise_mul_mlx
124.46504530007587ms +/- 9.615869981105703ms


In [17]:
timer(attn_weights, aw_input, None, matmul_runner)

matmul_mlx
26.145453799381357ms +/- 1.4911706287728572ms


In [18]:
timer(k_weights, q_weights, None, transformer_part1_runner)

multiquery_attention_part1_mlx
9.95410899977287ms +/- 16.71460224299404ms


In [29]:
timer(k_weights, q_weights, None, transformer_part2_runner)

multiquery_attention_part2_mlx
9.576474699770188ms +/- 13.601878665346964ms


In [20]:
timer(weights, w_input, None, rmsnorm_part1_runner)

rmsnorm_part1_mlx
151.8911950999609ms +/- 23.116388873278304ms


In [21]:
timer(weights, w_input,None, rmsnorm_part2_runner)

rmsnorm_part2_mlx
187.68991289998667ms +/- 14.889146486184822ms


In [22]:
timer(weights, None, None, transformer_part3_runner)

silu_mlx
324.14364539918097ms +/- 15.11308505256706ms


In [23]:
timer(attn_weights, None, None, softmax_part1_runner)

softmax_part1_mlx
34.38897890046064ms +/- 3.784112471945636ms


In [24]:
timer(attn_weights, None, None, softmax_part2_runner)

softmax_part2_mlx
64.34085259943458ms +/- 1.6587003885597416ms


In [25]:
timer(attn_weights, None, None, softmax_part3_runner)

softmax_part3_mlx
37.81613349892723ms +/- 0.938680863344943ms


In [26]:
timer(attn_weights, None, None, softmax_part4_runner)

softmax_part4_mlx
65.29268749991388ms +/- 9.300468336661075ms
