In [1]:
import torch
import numpy as np
import time
import h5py

In [2]:
rng = np.random.default_rng(1)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
cpu = 'cpu'

  return torch._C._cuda_getDeviceCount() > 0


In [3]:
weights_path = './vicuna_weight.h5'

weights = []
w_input = []
attn_weights = []
aw_input = []
q_weights = []
k_weights = []

with h5py.File(weights_path, 'r') as weight_file:
    for layer_name in weight_file:
        w = np.squeeze(np.array(weight_file[layer_name])).astype(np.float32)
        if "model" in layer_name and "embed_tokens" not in layer_name and "layernorm" not in layer_name:
            weights.append(w)
            w_input.append(rng.random(w.shape, dtype = np.float32))
        if "attn" in layer_name:
            attn_weights.append(w)
            aw_input.append(rng.random(w.shape[1], dtype = np.float32))
            if "q_proj" in layer_name:
                q_weights.append(w)
            if "k_proj" in layer_name:
                k_weights.append(w)


: 

In [None]:
def timer(input1, input2, f, runner):
    runs = 10
    times = []
    for _ in range(runs):
        times.append(runner(input1, input2, f))
    times = np.array(times)
    print(f"{runner.__name__[:-6]}pytorch_with_load")
    print(f"{np.average(times)}ms +/- {np.std(times)}ms")

In [None]:
def elewise_mul_torch(input1, input2, hidden_dim):
    return torch.multiply(input1[:hidden_dim], input2[:hidden_dim])

def elewise_mul_runner(inputs1, inputs2, f=None):
    total_time = 0
    for i in range(len(inputs1)):
        input1 = inputs1[i].flatten()
        input2 = inputs2[i].flatten()
        hd = len(input1)

        inp1 = torch.from_numpy(input1).to(dtype=torch.float32).to(cpu)
        inp2 = torch.from_numpy(input2).to(dtype=torch.float32).to(cpu)
        hidden_dim = torch.tensor(hd).to(cpu)
        
        start_time = time.perf_counter()
        inp1 = inp1.to(device)
        inp2 = inp2.to(device)
        hidden_dim = hidden_dim.to(device)
        res = elewise_mul_torch(inp1, inp2, hidden_dim)
        res = res.to(cpu)
        end_time = time.perf_counter()

        del inp2
        del inp1
        total_time += (end_time - start_time) * 1000
    return total_time

In [None]:
def matmul_torch(weight, input):
    return torch.matmul(weight, input)

def matmul_runner(weights, inputs, f=None):
    total_time = 0
    for i in range(len(inputs)):
        weight = weights[i]
        input = inputs[i]

        w = torch.from_numpy(weight).to(dtype=torch.float32).to(cpu)
        inp = torch.from_numpy(input).to(dtype=torch.float32).to(cpu)
        
        start_time = time.perf_counter()
        w = w.to(device)
        inp = inp.to(device)
        res = matmul_torch(w, inp)
        res = res.to(cpu)
        end_time = time.perf_counter()
        
        del inp
        del w
        total_time += (end_time - start_time) * 1000
    return total_time

In [None]:
def multiquery_attention_part1_torch(token_position, head, head_size, key_cache_layer, q):
    return torch.matmul(key_cache_layer[:token_position][:, (head * head_size):(head * head_size) + head_size], q[(head * head_size):(head * head_size) + head_size])/ torch.sqrt(torch.as_tensor(head_size * 1))

def multiquery_attention_part1_runner(k_matrixes, q_matrixes, f=None):
    total_time = 0
    for i in range(len(k_matrixes)):
        k_matrix = k_matrixes[i]
        q_matrix = q_matrixes[i]
        token_position = k_matrix.shape[0] - 1

        num_head = 32
        head = int(rng.integers(low=0, high=num_head))
        head_size = k_matrix.shape[0] // num_head
        
        key_cache_layer = torch.from_numpy(k_matrix).to(cpu)
        q = torch.from_numpy(q_matrix.flatten()).to(cpu)
        
        start_time = time.perf_counter()
        key_cache_layer = key_cache_layer.to(device)
        q = q.to(device)
        res = multiquery_attention_part1_torch(token_position, head, head_size, key_cache_layer, q)
        res = res.to(cpu)
        end_time = time.perf_counter()
        
        del key_cache_layer
        del q
        total_time += (end_time - start_time) * 1000
    return total_time

In [None]:
def multiquery_attention_part2_torch(token_position, head, head_size, key_cache_layer, attention):
    return torch.matmul(torch.transpose(key_cache_layer[:token_position + 1][:, (head * head_size):(head * head_size) + head_size], 0, 1), attention[:token_position + 1])

def multiquery_attention_part2_runner(k_matrixes, q_matrixes, f=None):
    total_time = 0
    for i in range(len(k_matrixes)):
        k_matrix = k_matrixes[i]
        q_matrix = q_matrixes[i]
        token_position = k_matrix.shape[0] - 1

        num_head = 32
        head = int(rng.integers(low=0, high=num_head))
        head_size = k_matrix.shape[0] // num_head
        
        key_cache_layer = torch.from_numpy(k_matrix).to(cpu)
        q = torch.from_numpy(q_matrix.flatten()).to(cpu)

        attention = multiquery_attention_part1_torch(token_position, head, head_size, key_cache_layer, q).to(cpu)
        attention = torch.cat((attention, torch.tensor([0]).to(cpu))).to(cpu)
        
        start_time = time.perf_counter()
        key_cache_layer = key_cache_layer.to(device)
        attention = attention.to(device)
        res = multiquery_attention_part2_torch(token_position, head, head_size, key_cache_layer, attention)
        res = res.to(cpu)
        end_time = time.perf_counter()
        
        del key_cache_layer
        del attention
        total_time += (end_time - start_time) * 1000
    return total_time

In [None]:
def rmsnorm_part1_torch(input, weight):
    return torch.sum(torch.multiply(input, input))

def rmsnorm_part1_runner(weights, inputs, f=None):
    total_time = 0
    for i in range(len(inputs)):
        input = inputs[i].flatten()
        weight = weights[i].flatten()
        
        inp = torch.from_numpy(input).to(dtype=torch.float32).to(cpu)
        w = torch.from_numpy(weight).to(dtype=torch.float32).to(cpu)

        start_time = time.perf_counter()
        inp = inp.to(device)
        w = w.to(device)
        res = rmsnorm_part1_torch(inp, w)
        res = res.to(cpu)
        end_time = time.perf_counter()

        del w
        del inp
        total_time += (end_time - start_time) * 1000
    return total_time

In [None]:
def rmsnorm_part2_torch(input, weight, ss):
    return torch.multiply((1 / torch.sqrt(torch.as_tensor((ss/input.size(dim=0))) + 1)), torch.multiply(input, weight))

def rmsnorm_part2_runner(weights, inputs, f=None):
    total_time = 0
    for i in range(len(inputs)):
        input = inputs[i].flatten()
        weight = weights[i].flatten()
        ssum = np.sum(input * input)

        inp = torch.from_numpy(input).to(dtype=torch.float32).to(cpu)
        w = torch.from_numpy(weight).to(dtype=torch.float32).to(cpu)
        ss = torch.tensor(ssum).to(dtype=torch.float32).to(cpu)

        start_time = time.perf_counter()
        inp = inp.to(device)
        w = w.to(device)
        ss = ss.to(device)
        res = rmsnorm_part2_torch(inp, w, ss)
        res = res.to(cpu)
        end_time = time.perf_counter()

        del w
        del inp
        total_time += (end_time - start_time) * 1000
    return total_time

In [None]:
def silu_torch(input, hidden_dim):
    return torch.multiply(torch.divide(1, (torch.exp(0 - input[:hidden_dim]) + 1)), input[:hidden_dim])

def silu_runner(inputs, _, f=None):
    total_time = 0
    for i in range(len(inputs)):
        input = inputs[i].flatten()
        hd = len(input)

        inp = torch.from_numpy(input).to(dtype=torch.float32).to(cpu)
        hidden_dim = torch.tensor(hd).to(cpu)

        start_time = time.perf_counter()
        inp = inp.to(device)
        hidden_dim = hidden_dim.to(device)
        res = silu_torch(inp, hidden_dim)
        res = res.to(cpu)
        end_time = time.perf_counter()

        del inp
        total_time += (end_time - start_time) * 1000
    return total_time

In [None]:
def softmax_part1_torch(input, max_pos):
    return torch.max(input[:max_pos])

def softmax_part1_runner(inputs, _, f=None):
    total_time = 0
    for i in range(len(inputs)):
        input = inputs[i].flatten()
        mp = len(input)
        
        inp = torch.from_numpy(input).to(cpu)
        max_pos = torch.tensor(mp).to(cpu)

        start_time = time.perf_counter()
        inp = inp.to(device)
        max_pos = max_pos.to(device)
        res = softmax_part1_torch(inp, max_pos)
        res = res.to(cpu)
        end_time = time.perf_counter()
        
        del inp
        total_time += (end_time - start_time) * 1000
    return total_time

In [None]:
def softmax_part2_torch(input, max_pos, max_val):
    return torch.exp(input[:max_pos] - max_val)

def softmax_part2_runner(inputs, _, f=None):
    total_time = 0
    for i in range(len(inputs)):
        input = inputs[i].flatten()
        mp = len(input)
        
        inp = torch.from_numpy(input).to(dtype=torch.float32).to(cpu)
        max_pos = torch.tensor(mp).to(cpu)
        max_val = torch.tensor(np.max(input[:mp])).to(dtype=torch.float32).to(cpu)

        start_time = time.perf_counter()
        inp = inp.to(device)
        max_pos = max_pos.to(device)
        max_val = max_val.to(device)
        res = softmax_part2_torch(inp, max_pos, max_val)
        res = res.to(cpu)
        end_time = time.perf_counter()
        
        del inp
        total_time += (end_time - start_time) * 1000
    return total_time

In [None]:
def softmax_part3_torch(output, max_pos):
    return torch.sum(output[:max_pos])

def softmax_part3_runner(inputs, _, f=None):
    total_time = 0
    for i in range(len(inputs)):
        input = inputs[i].flatten()
        mp = len(input)
        output = np.exp(input[:mp]-np.max(input[:mp]))
        
        outp = torch.from_numpy(output).to(dtype=torch.float32).to(cpu)
        max_pos = torch.tensor(mp).to(cpu)
        
        start_time = time.perf_counter()
        outp = outp.to(device)
        max_pos = max_pos.to(device)
        res = softmax_part3_torch(outp, max_pos)
        res = res.to(cpu)
        end_time = time.perf_counter()
        
        del outp
        total_time += (end_time - start_time) * 1000
    return total_time

In [None]:
def softmax_part4_torch(unnormalized_output, max_pos, sum):
    return unnormalized_output[:max_pos]/sum

def softmax_part4_runner(inputs, _, f=None):
    total_time = 0
    for i in range(len(inputs)):
        input = inputs[i].flatten()
        mp = len(input)
        output = np.exp(input[:mp]-np.max(input[:mp]))
        s = np.sum(output[:mp])
        
        outp = torch.from_numpy(output).to(dtype=torch.float32).to(device)
        max_pos = torch.tensor(mp).to(device)
        sum = torch.tensor(s).to(dtype=torch.float32).to(device)
        
        start_time = time.perf_counter()
        softmax_part4_torch(outp, max_pos, sum)
        end_time = time.perf_counter()
        del outp
        total_time += (end_time - start_time) * 1000
    return total_time

In [None]:
timer(weights, w_input, None, elewise_mul_runner)

elewise_mul_pytorch_with_load
558.834465360269ms +/- 64.03464096943887ms


In [None]:
timer(attn_weights, aw_input, None, matmul_runner)

matmul_pytorch_with_load
26.87900783494115ms +/- 1.230684701351786ms


In [None]:
timer(k_weights, q_weights, None, multiquery_attention_part1_runner)

multiquery_attention_part1_pytorch_with_load
1.1015573516488075ms +/- 0.1503313896433275ms


In [None]:
timer(k_weights, q_weights, None, multiquery_attention_part2_runner)

torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([

In [None]:
timer(weights, w_input, None, rmsnorm_part1_runner)

rmsnorm_part1_pytorch_with_load
330.09126214310527ms +/- 1.800282390717216ms


In [None]:
timer(weights, w_input,None, rmsnorm_part2_runner)

rmsnorm_part2_pytorch_with_load
1010.9781612642109ms +/- 12.108922029933726ms


In [None]:
timer(weights, None, None, silu_runner)

silu_pytorch_with_load
875.178828695789ms +/- 3.14554028634092ms


In [None]:
timer(attn_weights, None, None, softmax_part1_runner)

softmax_part1_pytorch_with_load
55.707092909142375ms +/- 2.5383555690753643ms


In [None]:
timer(attn_weights, None, None, softmax_part2_runner)

softmax_part2_pytorch_with_load
288.1737082730979ms +/- 1.0722158977785117ms


In [None]:
timer(attn_weights, None, None, softmax_part3_runner)

softmax_part3_pytorch_with_load
54.56754267215729ms +/- 0.11126528259816412ms


In [None]:
timer(attn_weights, None, None, softmax_part4_runner)

softmax_part4_pytorch_with_load
1.0392235592007637ms +/- 0.2029464233078125ms
