In [1]:
import torch
import numpy as np
import time
import h5py

In [2]:
# Seeded generator so the random benchmark inputs are reproducible across runs.
rng = np.random.default_rng(seed=1)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

  return torch._C._cuda_getDeviceCount() > 0


In [3]:
weights_path = './vicuna_weight.h5'

# Benchmark fixtures filled from the HDF5 file below:
#   weights / w_input       - selected "model" arrays and same-shaped random inputs
#   attn_weights / aw_input - arrays whose name contains "attn" and random vectors
#   q_weights / k_weights   - the "q_proj" / "k_proj" subsets of the attn arrays
weights, w_input = [], []
attn_weights, aw_input = [], []
q_weights, k_weights = [], []

with h5py.File(weights_path, 'r') as weight_file:
    for layer_name in weight_file:
        # Read the dataset, drop size-1 axes and force float32.
        layer_weight = np.squeeze(np.array(weight_file[layer_name])).astype(np.float32)

        is_model_entry = "model" in layer_name
        is_excluded = "embed_tokens" in layer_name or "layernorm" in layer_name
        if is_model_entry and not is_excluded:
            weights.append(layer_weight)
            w_input.append(rng.random(layer_weight.shape, dtype=np.float32))

        if "attn" not in layer_name:
            continue

        attn_weights.append(layer_weight)
        # NOTE(review): indexing shape[1] assumes every "attn" entry is at least
        # 2-D after the squeeze - confirm against the contents of the file.
        aw_input.append(rng.random(layer_weight.shape[1], dtype=np.float32))
        if "q_proj" in layer_name:
            q_weights.append(layer_weight)
        if "k_proj" in layer_name:
            k_weights.append(layer_weight)



In [4]:
def timer(input1, input2, f, runner):
    """Run `runner` ten times and print the mean and standard deviation.

    Args:
        input1: First argument forwarded to `runner`.
        input2: Second argument forwarded to `runner`.
        f: Third argument forwarded to `runner` unchanged.
        runner: Callable invoked as `runner(input1, input2, f)`; its return
            value is treated as a duration in milliseconds.

    The printed label is the runner's `__name__` with its last six
    characters removed, followed by "pytorch".
    """
    runs = 10
    times = np.array([runner(input1, input2, f) for _ in range(runs)])
    print(f"{runner.__name__[:-6]}pytorch")
    print(f"{np.average(times)}ms +/- {np.std(times)}ms")

In [5]:
def transformer_part4_torch(input1, input2, hidden_dim):
    """Multiply the first `hidden_dim` entries of the two inputs element-wise."""
    left = input1[:hidden_dim]
    right = input2[:hidden_dim]
    return left * right

def transformer_part4_runner(inputs1, inputs2, f=None):
    """Time `transformer_part4_torch` over paired arrays.

    Args:
        inputs1: Sequence of NumPy arrays; each one is flattened before use.
        inputs2: Sequence of NumPy arrays paired index-by-index with `inputs1`.
        f: Unused; kept so every runner shares the same signature.

    Returns:
        Total wall-clock time of the kernel calls in milliseconds. Tensor
        creation and transfer to `device` happen outside the timed region.
    """
    total_time = 0
    for i in range(len(inputs1)):
        flat1 = inputs1[i].flatten()
        flat2 = inputs2[i].flatten()
        hd = len(flat1)

        inp1 = torch.from_numpy(flat1).to(dtype=torch.float32).to(device)
        inp2 = torch.from_numpy(flat2).to(dtype=torch.float32).to(device)
        hidden_dim = torch.tensor(hd).to(device)

        # CUDA work is queued asynchronously: drain pending transfers before
        # starting the clock and wait for the kernel before stopping it,
        # otherwise only the launch overhead would be measured.
        if inp1.is_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        transformer_part4_torch(inp1, inp2, hidden_dim)
        if inp1.is_cuda:
            torch.cuda.synchronize()
        end_time = time.perf_counter()

        del inp2
        del inp1
        total_time += (end_time - start_time) * 1000
    return total_time

In [6]:
def matmul_torch(weight, input):
    """Return the matrix product `weight @ input` computed by `torch.matmul`."""
    product = torch.matmul(weight, input)
    return product

def matmul_runner(weights, inputs, f=None):
    """Time `matmul_torch` over paired weight/input arrays.

    Args:
        weights: Sequence of NumPy arrays used as the left operand.
        inputs: Sequence of NumPy arrays used as the right operand; its
            length determines how many pairs are processed.
        f: Unused; kept so every runner shares the same signature.

    Returns:
        Total wall-clock time of the matmul calls in milliseconds. Tensor
        creation and transfer to `device` happen outside the timed region.
    """
    total_time = 0
    for i in range(len(inputs)):
        w = torch.from_numpy(weights[i]).to(dtype=torch.float32).to(device)
        inp = torch.from_numpy(inputs[i]).to(dtype=torch.float32).to(device)

        # CUDA kernels run asynchronously; synchronise on both sides of the
        # call so the clock brackets the computation, not just its launch.
        if w.is_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        matmul_torch(w, inp)
        if w.is_cuda:
            torch.cuda.synchronize()
        end_time = time.perf_counter()

        del inp
        del w
        total_time += (end_time - start_time) * 1000
    return total_time

In [7]:
def transformer_part1_torch(token_position, head, head_size, key_cache_layer, q):
    """Scaled dot product of one head's key rows with the matching query slice.

    Takes the first `token_position` rows of `key_cache_layer`, restricted to
    the `head_size` columns belonging to `head`, multiplies them with the same
    column range of `q`, and divides by `sqrt(head_size)`.
    """
    col_start = head * head_size
    col_stop = col_start + head_size
    head_keys = key_cache_layer[:token_position][:, col_start:col_stop]
    head_query = q[col_start:col_stop]
    scale = torch.sqrt(torch.as_tensor(head_size * 1))
    return torch.matmul(head_keys, head_query) / scale

def transformer_part1_runner(k_matrixes, q_matrixes, f=None):
    """Time `transformer_part1_torch` over paired key/query matrices.

    Args:
        k_matrixes: Sequence of 2-D NumPy arrays used as the key cache.
        q_matrixes: Sequence of NumPy arrays; each is flattened and used as
            the query vector.
        f: Unused; kept so every runner shares the same signature.

    Returns:
        Total wall-clock time of the kernel calls in milliseconds. Tensor
        creation and transfer to `device` happen outside the timed region.
    """
    total_time = 0
    num_head = 32
    for i in range(len(k_matrixes)):
        k_matrix = k_matrixes[i]
        q_matrix = q_matrixes[i]
        token_position = k_matrix.shape[0] - 1

        # A random head is drawn from the shared generator for every matrix.
        head = int(rng.integers(low=0, high=num_head))
        # Heads are sliced along the column axis, so the per-head width comes
        # from the column count (same value as before for square matrices).
        head_size = k_matrix.shape[1] // num_head

        key_cache_layer = torch.from_numpy(k_matrix).to(device)
        q = torch.from_numpy(q_matrix.flatten()).to(device)

        # CUDA kernels run asynchronously; synchronise on both sides of the
        # call so the clock brackets the computation, not just its launch.
        if key_cache_layer.is_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        transformer_part1_torch(token_position, head, head_size, key_cache_layer, q)
        if key_cache_layer.is_cuda:
            torch.cuda.synchronize()
        end_time = time.perf_counter()

        del key_cache_layer
        del q
        total_time += (end_time - start_time) * 1000
    return total_time

In [8]:
def transformer_part2_torch(token_position, head, head_size, key_cache_layer, attention):
    """Weight one head's cached rows by the attention vector.

    Uses the first `token_position + 1` rows of `key_cache_layer`, restricted
    to the `head_size` columns belonging to `head`, transposes that slice and
    multiplies it with the first `token_position + 1` entries of `attention`.
    """
    row_count = token_position + 1
    col_start = head * head_size
    head_block = key_cache_layer[:row_count][:, col_start:col_start + head_size]
    return torch.matmul(torch.transpose(head_block, 0, 1), attention[:row_count])

def transformer_part2_runner(k_matrixes, q_matrixes, f=None):
    """Time `transformer_part2_torch` over paired key/query matrices.

    The attention vector fed to the kernel is produced by
    `transformer_part1_torch` and extended with one trailing zero so that it
    has `token_position + 1` entries; that preparation is not timed.

    Args:
        k_matrixes: Sequence of 2-D NumPy arrays used as the cache matrix.
        q_matrixes: Sequence of NumPy arrays; each is flattened and used as
            the query vector for the attention computation.
        f: Unused; kept so every runner shares the same signature.

    Returns:
        Total wall-clock time of the kernel calls in milliseconds.
    """
    total_time = 0
    num_head = 32
    for i in range(len(k_matrixes)):
        k_matrix = k_matrixes[i]
        q_matrix = q_matrixes[i]
        token_position = k_matrix.shape[0] - 1

        # A random head is drawn from the shared generator for every matrix.
        head = int(rng.integers(low=0, high=num_head))
        # Heads are sliced along the column axis, so the per-head width comes
        # from the column count (same value as before for square matrices).
        head_size = k_matrix.shape[1] // num_head

        key_cache_layer = torch.from_numpy(k_matrix).to(device)
        q = torch.from_numpy(q_matrix.flatten()).to(device)

        attention = transformer_part1_torch(token_position, head, head_size, key_cache_layer, q).to(device)
        attention = torch.cat((attention, torch.tensor([0]).to(device))).to(device)

        # CUDA kernels run asynchronously; the first synchronise also makes
        # sure the attention preparation above has finished before timing.
        if key_cache_layer.is_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        transformer_part2_torch(token_position, head, head_size, key_cache_layer, attention)
        if key_cache_layer.is_cuda:
            torch.cuda.synchronize()
        end_time = time.perf_counter()

        del key_cache_layer
        del attention
        del q
        total_time += (end_time - start_time) * 1000
    return total_time

In [9]:
def rmsnorm_part1_torch(input, weight):
    """Return the sum of squares of `input`; `weight` is accepted but unused."""
    squared = input * input
    return torch.sum(squared)

def rmsnorm_part1_runner(weights, inputs, f=None):
    """Time `rmsnorm_part1_torch` over paired weight/input arrays.

    Args:
        weights: Sequence of NumPy arrays; each one is flattened before use.
        inputs: Sequence of NumPy arrays, flattened before use; its length
            determines how many pairs are processed.
        f: Unused; kept so every runner shares the same signature.

    Returns:
        Total wall-clock time of the kernel calls in milliseconds. Tensor
        creation and transfer to `device` happen outside the timed region.
    """
    total_time = 0
    for i in range(len(inputs)):
        flat_input = inputs[i].flatten()
        flat_weight = weights[i].flatten()

        inp = torch.from_numpy(flat_input).to(dtype=torch.float32).to(device)
        w = torch.from_numpy(flat_weight).to(dtype=torch.float32).to(device)

        # CUDA kernels run asynchronously; synchronise on both sides of the
        # call so the clock brackets the computation, not just its launch.
        if inp.is_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        rmsnorm_part1_torch(inp, w)
        if inp.is_cuda:
            torch.cuda.synchronize()
        end_time = time.perf_counter()

        del w
        del inp
        total_time += (end_time - start_time) * 1000
    return total_time

In [10]:
def rmsnorm_part2_torch(input, weight, ss):
    """Scale `input * weight` by `1 / sqrt(ss / n + 1)`.

    `n` is the size of `input` along dimension 0 and `ss` is a precomputed
    scalar supplied by the caller.
    """
    length = input.size(dim=0)
    shifted_mean = (ss / length) + 1
    inverse_norm = 1 / torch.sqrt(torch.as_tensor(shifted_mean))
    return inverse_norm * (input * weight)

def rmsnorm_part2_runner(weights, inputs, f=None):
    """Time `rmsnorm_part2_torch` over paired weight/input arrays.

    The sum of squares passed to the kernel is computed with NumPy before the
    timed region.

    Args:
        weights: Sequence of NumPy arrays; each one is flattened before use.
        inputs: Sequence of NumPy arrays, flattened before use; its length
            determines how many pairs are processed.
        f: Unused; kept so every runner shares the same signature.

    Returns:
        Total wall-clock time of the kernel calls in milliseconds.
    """
    total_time = 0
    for i in range(len(inputs)):
        flat_input = inputs[i].flatten()
        flat_weight = weights[i].flatten()
        ssum = np.sum(flat_input * flat_input)

        inp = torch.from_numpy(flat_input).to(dtype=torch.float32).to(device)
        w = torch.from_numpy(flat_weight).to(dtype=torch.float32).to(device)
        ss = torch.tensor(ssum).to(dtype=torch.float32).to(device)

        # CUDA kernels run asynchronously; synchronise on both sides of the
        # call so the clock brackets the computation, not just its launch.
        if inp.is_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        rmsnorm_part2_torch(inp, w, ss)
        if inp.is_cuda:
            torch.cuda.synchronize()
        end_time = time.perf_counter()

        del w
        del inp
        total_time += (end_time - start_time) * 1000
    return total_time

In [11]:
def transformer_part3_torch(input, hidden_dim):
    """Return `x * (1 / (1 + exp(0 - x)))` for the first `hidden_dim` entries of `input`."""
    head = input[:hidden_dim]
    gate = 1 / (1 + torch.exp(torch.as_tensor(0 - head)))
    return head * gate

def transformer_part3_runner(inputs, _, f=None):
    """Time `transformer_part3_torch` over a sequence of arrays.

    Args:
        inputs: Sequence of NumPy arrays; each one is flattened and processed
            over its full length.
        _: Ignored; present so every runner shares the same signature.
        f: Unused; kept so every runner shares the same signature.

    Returns:
        Total wall-clock time of the kernel calls in milliseconds. Tensor
        creation and transfer to `device` happen outside the timed region.
    """
    total_time = 0
    for i in range(len(inputs)):
        flat_input = inputs[i].flatten()
        hd = len(flat_input)

        inp = torch.from_numpy(flat_input).to(dtype=torch.float32).to(device)
        hidden_dim = torch.tensor(hd).to(device)

        # CUDA kernels run asynchronously; synchronise on both sides of the
        # call so the clock brackets the computation, not just its launch.
        if inp.is_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        transformer_part3_torch(inp, hidden_dim)
        if inp.is_cuda:
            torch.cuda.synchronize()
        end_time = time.perf_counter()

        del inp
        total_time += (end_time - start_time) * 1000
    return total_time

In [12]:
def softmax_part1_torch(input, max_pos):
    """Return the maximum of the first `max_pos` entries of `input`."""
    prefix = input[:max_pos]
    return torch.max(prefix)

def softmax_part1_runner(inputs, _, f=None):
    """Time `softmax_part1_torch` over a sequence of arrays.

    Args:
        inputs: Sequence of NumPy arrays; each one is flattened and processed
            over its full length.
        _: Ignored; present so every runner shares the same signature.
        f: Unused; kept so every runner shares the same signature.

    Returns:
        Total wall-clock time of the kernel calls in milliseconds. Tensor
        creation and transfer to `device` happen outside the timed region.
    """
    total_time = 0
    for i in range(len(inputs)):
        flat_input = inputs[i].flatten()
        mp = len(flat_input)

        inp = torch.from_numpy(flat_input).to(device)
        max_pos = torch.tensor(mp).to(device)

        # CUDA kernels run asynchronously; synchronise on both sides of the
        # call so the clock brackets the computation, not just its launch.
        if inp.is_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        softmax_part1_torch(inp, max_pos)
        if inp.is_cuda:
            torch.cuda.synchronize()
        end_time = time.perf_counter()

        del inp
        total_time += (end_time - start_time) * 1000
    return total_time

In [13]:
def softmax_part2_torch(input, max_pos, max_val):
    """Return `exp(x - max_val)` for the first `max_pos` entries of `input`."""
    shifted = input[:max_pos] - max_val
    return torch.exp(torch.as_tensor(shifted))

def softmax_part2_runner(inputs, _, f=None):
    """Time `softmax_part2_torch` over a sequence of arrays.

    The maximum passed to the kernel is computed with NumPy before the timed
    region.

    Args:
        inputs: Sequence of NumPy arrays; each one is flattened and processed
            over its full length.
        _: Ignored; present so every runner shares the same signature.
        f: Unused; kept so every runner shares the same signature.

    Returns:
        Total wall-clock time of the kernel calls in milliseconds.
    """
    total_time = 0
    for i in range(len(inputs)):
        flat_input = inputs[i].flatten()
        mp = len(flat_input)

        inp = torch.from_numpy(flat_input).to(dtype=torch.float32).to(device)
        max_pos = torch.tensor(mp).to(device)
        max_val = torch.tensor(np.max(flat_input[:mp])).to(dtype=torch.float32).to(device)

        # CUDA kernels run asynchronously; synchronise on both sides of the
        # call so the clock brackets the computation, not just its launch.
        if inp.is_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        softmax_part2_torch(inp, max_pos, max_val)
        if inp.is_cuda:
            torch.cuda.synchronize()
        end_time = time.perf_counter()

        del inp
        total_time += (end_time - start_time) * 1000
    return total_time

In [14]:
def softmax_part3_torch(output, max_pos):
    """Return the sum of the first `max_pos` entries of `output`."""
    prefix = output[:max_pos]
    return torch.sum(prefix)

def softmax_part3_runner(inputs, _, f=None):
    """Time `softmax_part3_torch` over a sequence of arrays.

    The shifted exponentials fed to the kernel are computed with NumPy before
    the timed region.

    Args:
        inputs: Sequence of NumPy arrays; each one is flattened and processed
            over its full length.
        _: Ignored; present so every runner shares the same signature.
        f: Unused; kept so every runner shares the same signature.

    Returns:
        Total wall-clock time of the kernel calls in milliseconds.
    """
    total_time = 0
    for i in range(len(inputs)):
        flat_input = inputs[i].flatten()
        mp = len(flat_input)
        output = np.exp(flat_input[:mp] - np.max(flat_input[:mp]))

        outp = torch.from_numpy(output).to(dtype=torch.float32).to(device)
        max_pos = torch.tensor(mp).to(device)

        # CUDA kernels run asynchronously; synchronise on both sides of the
        # call so the clock brackets the computation, not just its launch.
        if outp.is_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        softmax_part3_torch(outp, max_pos)
        if outp.is_cuda:
            torch.cuda.synchronize()
        end_time = time.perf_counter()

        del outp
        total_time += (end_time - start_time) * 1000
    return total_time

In [15]:
def softmax_part4_torch(unnormalized_output, max_pos, sum):
    """Divide the first `max_pos` entries of `unnormalized_output` by `sum`."""
    prefix = unnormalized_output[:max_pos]
    return prefix / sum

def softmax_part4_runner(inputs, _, f=None):
    """Time `softmax_part4_torch` over a sequence of arrays.

    The shifted exponentials and their sum are computed with NumPy before the
    timed region.

    Args:
        inputs: Sequence of NumPy arrays; each one is flattened and processed
            over its full length.
        _: Ignored; present so every runner shares the same signature.
        f: Unused; kept so every runner shares the same signature.

    Returns:
        Total wall-clock time of the kernel calls in milliseconds.
    """
    total_time = 0
    for i in range(len(inputs)):
        flat_input = inputs[i].flatten()
        mp = len(flat_input)
        output = np.exp(flat_input[:mp] - np.max(flat_input[:mp]))
        s = np.sum(output[:mp])

        outp = torch.from_numpy(output).to(dtype=torch.float32).to(device)
        max_pos = torch.tensor(mp).to(device)
        # Named `total` rather than `sum` so the builtin is not shadowed.
        total = torch.tensor(s).to(dtype=torch.float32).to(device)

        # CUDA kernels run asynchronously; synchronise on both sides of the
        # call so the clock brackets the computation, not just its launch.
        if outp.is_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        softmax_part4_torch(outp, max_pos, total)
        if outp.is_cuda:
            torch.cuda.synchronize()
        end_time = time.perf_counter()

        del outp
        total_time += (end_time - start_time) * 1000
    return total_time

In [16]:
# Benchmark the element-wise multiply kernel on each weight / random-input pair.
timer(input1=weights, input2=w_input, f=None, runner=transformer_part4_runner)

In [None]:
# Benchmark the matmul kernel on each attention weight and its random input vector.
timer(input1=attn_weights, input2=aw_input, f=None, runner=matmul_runner)

matmul_pytorch
24.433968123048544ms +/- 1.381028715149651ms


In [None]:
# Benchmark the first attention kernel on the k_proj / q_proj weight pairs.
timer(input1=k_weights, input2=q_weights, f=None, runner=transformer_part1_runner)

multiquery_attention_part1_pytorch
1.2479783035814762ms +/- 1.318031165149219ms


In [None]:
# Benchmark the second attention kernel on the k_proj / q_proj weight pairs.
timer(input1=k_weights, input2=q_weights, f=None, runner=transformer_part2_runner)

multiquery_attention_part2_pytorch
0.2821787726134062ms +/- 0.18459898719304943ms


In [None]:
# Benchmark the sum-of-squares kernel; the runner takes weights first, inputs second.
timer(input1=weights, input2=w_input, f=None, runner=rmsnorm_part1_runner)

rmsnorm_part1_pytorch
454.0018748957664ms +/- 0.7874046614157534ms


In [None]:
# Benchmark the scaling kernel; the runner takes weights first, inputs second.
timer(input1=weights, input2=w_input, f=None, runner=rmsnorm_part2_runner)

rmsnorm_part2_pytorch
766.2844286300242ms +/- 3.051371663684736ms


In [None]:
# Benchmark the gated-activation kernel on the weights; the second input is ignored.
timer(input1=weights, input2=None, f=None, runner=transformer_part3_runner)

silu_pytorch
1947.5853705778718ms +/- 5.557897591486259ms


In [None]:
# Benchmark the max-reduction kernel on the attention weights; the second input is ignored.
timer(input1=attn_weights, input2=None, f=None, runner=softmax_part1_runner)

softmax_part1_pytorch
27.303209900856018ms +/- 0.14051284226447797ms


In [None]:
# Benchmark the shifted-exponential kernel on the attention weights; the second input is ignored.
timer(input1=attn_weights, input2=None, f=None, runner=softmax_part2_runner)

softmax_part2_pytorch
237.33236687257886ms +/- 0.7364762389287917ms


In [None]:
# Benchmark the sum-reduction kernel on the attention weights; the second input is ignored.
timer(input1=attn_weights, input2=None, f=None, runner=softmax_part3_runner)

softmax_part3_pytorch
24.55094694159925ms +/- 0.2237995992287591ms


In [None]:
# Benchmark the normalising-division kernel on the attention weights; the second input is ignored.
timer(input1=attn_weights, input2=None, f=None, runner=softmax_part4_runner)

softmax_part4_pytorch
149.70479435287416ms +/- 0.17329120919903354ms
