In [None]:
import os, torch
# wurlitzer: presumably loaded so C-level printf output from the compiled
# CUDA kernels is captured into the cell output.
%load_ext wurlitzer
from importlib import reload
# CUDA version PyTorch was built against.
print(torch.version.cuda)

In [None]:
# Import the FP16 flash-attention module; reload() re-executes the Python module
# (presumably to pick up kernel edits without restarting the kernel — note that
# an already-loaded compiled extension is not re-linked by reload()).
import kernels.flashattn16_128
reload(kernels.flashattn16_128)

In [None]:
# Problem sizes for the FP16 flash-attention benchmark: qbar query rows,
# xbar key/value rows, DIM head dimension.
qbar = 23040
xbar = 8192
DIM = 128
# NOTE(review): no torch.manual_seed, so Q/K differ on every run.
big_randomQ = torch.randn(qbar, DIM, device='cuda', dtype=torch.half)
big_randomK = torch.randn(xbar, DIM, device='cuda', dtype=torch.half)
#big_randomV = torch.randn(xbar, DIM, device='cuda', dtype=torch.half)
# V is all ones (random V left commented out above). Softmax rows sum to 1, so an
# exact attention output is all ones — presumably chosen as an easy correctness check.
big_randomV = torch.ones(xbar, DIM, device='cuda', dtype=torch.half)

In [None]:
import time, math

# Bytes moved per attention call assuming 2-byte (fp16) elements. The 2*Q term
# presumably covers reading Q and writing an equally sized output O; K and V
# are read once each.
data_size = 2*(2*big_randomQ.numel() + big_randomK.numel() + big_randomV.numel())

class Moments:
    """Running first and second raw moments of a stream of samples.

    ``m1`` is the running mean, ``m2`` the running mean of squares and ``n``
    the number of samples folded in so far.
    """
    def __init__(self, m1=0.0, m2=0.0, n=0):
        self.m1 = m1
        self.m2 = m2
        self.n = n
    def add(self, v):
        """Fold sample ``v`` into the running means."""
        self.m1 = (self.m1*self.n + v)/(self.n+1)
        self.m2 = (self.m2*self.n + v**2)/(self.n+1)
        self.n += 1
    def std(self):
        """Return ``(mean, population standard deviation)``.

        ``m2 - m1**2`` is mathematically non-negative but can come out as a
        tiny negative number from rounding when the samples are (nearly)
        identical, which would make ``math.sqrt`` raise ValueError; it is
        clamped at zero.
        """
        variance = max(self.m2 - self.m1**2, 0.0)
        return self.m1, math.sqrt(variance)
    def __str__(self):
        return str_std(*self.std())

def str_std(m, s, additional=1):
    """Format a mean and its spread as ``"<m>±<s>"`` in 2-decimal scientific notation.

    Works for any finite ``m`` and ``s``, including ``s == 0`` (a single
    repetition) and ``m <= 0``. ``additional`` is accepted for backward
    compatibility and is unused.
    """
    return f"{m:.2e}±{s:.2e}"

def benchmarkFlash(inputLambda, reps=1):
    """Time ``inputLambda`` over ``reps`` calls and print in/out throughput.

    ``inputLambda()`` must return ``(input, O, mz)``. ``input[1]`` is read as a
    byte count and ``input[2]`` as a block count (as reported by the kernel —
    TODO confirm against kernels.flashattn16_128).

    The per-call rate 1/elapsed is averaged (mean of rates), and its spread is
    printed alongside both the read and the deposit throughput.
    Returns the ``(input, O, mz)`` of the last call.

    Raises ValueError if ``reps < 1``.
    """
    if reps < 1:
        raise ValueError("reps must be >= 1")
    mn = Moments()
    # Drain previously queued GPU work (e.g. tensor creation) so it is not
    # billed to the first timed call.
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(reps):
        s0 = time.perf_counter()
        result, O, mz = inputLambda()
        torch.cuda.synchronize()
        mn.add(1/(time.perf_counter()-s0))
    t1 = time.perf_counter()
    ht, sht = mn.std()
    d = result[1].item()
    b = result[2].item()
    gib = 1024**3
    # Read rate
    T0 = (d*ht) / gib
    Ts = (d*sht) / gib
    # Deposit rate: d bytes delivered to each of b blocks (assumed layout)
    T1 = (d*b*ht) / gib
    sT1 = (d*b*sht) / gib
    print(f"Throughput: {str_std(T0,Ts)} GiB/s in -> {str_std(T1,sT1)} GiB/s out ({int(b)}bl, {t1-t0:.2e}s)")
    return result, O, mz

# Benchmark the FP16 kernel over 50 calls; iv/O/mz are the last call's outputs
# (mz's meaning is not visible here — TODO document). Single-shot variant below.
iv, O, mz = benchmarkFlash(lambda: kernels.flashattn16_128.flashAttention(big_randomQ, big_randomK, big_randomV), 50)
#iv, O, mz = benchmarkFlash(lambda: kernels.flashattn16_128.flashAttention(big_randomQ, big_randomK, big_randomV), 1)

In [None]:
print(mz)
# text.txt: one float per line (presumably values dumped by the kernel —
# TODO confirm); only the first 4096 are kept.
with open("text.txt","r") as f:
    list_of = [float(r) for r in f.readlines()]

list_of = list_of[:4096]
# NOTE(review): the two bare expressions below are not the cell's last line,
# so Jupyter discards their values; wrap them in display() if they are wanted.
sum(list_of)
list_of
print(len(list_of))

In [None]:
# NOTE(review): mid-notebook install; move to a top cell and pin a version.
%pip install numpy

In [None]:
# Scaled dot-product scores; for 2-D Q (qbar, DIM) and K (xbar, DIM) this is (qbar, xbar).
QK = (big_randomQ @ big_randomK.transpose(0,1))/math.sqrt(DIM)
# process_like (below) appends every running z value to this list.
list_of3 = []
# NOTE(review): redundant — math is already used on the first line of this cell.
import math
from numpy import float16, float32

def process_like(xs, out=None):
    """Replay an online-softmax denominator over ``xs`` with fp16 math and an fp32 accumulator.

    ``m`` is the running max (float16) and ``z`` the running sum of
    ``exp(x - m)`` (float32), rescaled whenever a new max appears. Every
    intermediate ``z`` is appended to ``out``, which defaults to the notebook
    global ``list_of3`` (the original behaviour).

    Returns ``(m, z, max_counter)``; ``max_counter`` counts how often the
    running max was raised after the first element. Empty ``xs`` returns
    ``(None, float32(0), 0)``.
    """
    if out is None:
        out = list_of3
    max_counter = 0
    m = None
    # Start from float32 explicitly: a plain int accumulator lets NumPy's scalar
    # promotion pick the dtype (float64 under legacy NumPy 1.x rules).
    z = float32(0)
    for i, y in enumerate(xs):
        if i == 0:
            m = float16(y)
            z = float32(1)
        elif y < m:
            y = float16(math.exp(float16(float16(y)-m)))
            z += float32(y)
        else:
            max_counter += 1
            mult = float16(math.exp(float16(m-float16(y))))
            m = float16(y)
            z = float32(mult*z + 1)
        out.append(z)
    return m, z, max_counter
# Replay the first query row; process_like appends every running z to list_of3.
process_like([x.item() for x in QK[0,:]])
print(len(list_of3))


In [None]:
# Import libraries
import torch, math
from matplotlib import pyplot as plt
from numpy import float16, float32
# Generate the random variables
# NOTE(review): no torch.manual_seed, so results change run to run. This cell also
# rebinds qbar/xbar/big_randomQ/big_randomK (now on CPU), which other cells reuse.
qbar = 2048
xbar = 4096*2 # This matches 2*the window size of Mixtral
DIM = 128   # This matches the head-dim of Mixtral
big_randomQ = torch.randn(qbar, DIM, device='cpu', dtype=torch.half)
big_randomK = torch.randn(xbar, DIM, device='cpu', dtype=torch.half)
# QK matrix, [qbar, xbar]
QK = (big_randomQ @ big_randomK.transpose(0,1))/math.sqrt(DIM)

# NOTE(review): not used anywhere in the visible notebook.
current_l = []

# Process in f32
def process_f32(xs):
    """Running softmax denominator ``l`` over ``xs`` computed entirely in float32.

    Every element is converted to float32 first. The notebook feeds rows of a
    torch.half tensor; without the conversion, ``y - m`` was evaluated in half
    precision and ``l`` accumulated as a Python float (float64), so this "FP32"
    reference was neither.

    Returns the list of running ``l`` values, one per element.
    """
    list_l = []
    for i, x in enumerate(xs):
        y = float32(float(x))
        if i == 0:
            m = y
            l = float32(1)
        elif y < m:
            l = float32(l + float32(math.exp(y - m)))
        else:
            # New running max: rescale the old sum to the new reference.
            mult = float32(math.exp(m - y))
            m = y
            l = float32(mult*l + 1)
        list_l.append(l)
    return list_l

# Accumulate in f32, but calculate in f16
def process_mixed(xs):
    """Running softmax denominator with float16 arithmetic and a float32 accumulator.

    ``m`` and each ``exp(y - m)`` term are rounded to float16, while ``l`` is
    kept in float32. The accumulator dtype is set explicitly. Starting from a
    plain ``int`` and mixing in float16 scalars leaves the dtype to NumPy's
    promotion rules: float16 under NumPy 2 and float64 under legacy NumPy 1.x,
    never float32.

    Returns the list of running ``l`` values, one per element.
    """
    list_l = []
    for i, x in enumerate(xs):
        y = float16(float(x))
        if i == 0:
            m = y
            l = float32(1)
        elif y < m:
            term = float16(math.exp(float16(y - m)))
            l = float32(l + float32(term))
        else:
            # New running max: rescale the old sum (fp16 factor, fp32 result).
            mult = float16(math.exp(float16(m - y)))
            m = y
            l = float32(mult*l + 1)
        list_l.append(l)
    return list_l

# Process in f16
def process_f16(xs):
    """Running softmax denominator ``l`` over ``xs`` computed entirely in float16.

    Inputs, differences, exponentials and the accumulator are all rounded to
    float16. The accumulator starts as ``float16(1)`` rather than ``int`` 1;
    otherwise legacy NumPy 1.x promotes ``int + float16`` to float64.

    Returns the list of running ``l`` values, one per element.
    """
    list_l = []
    for i, x in enumerate(xs):
        y = float16(float(x))
        if i == 0:
            m = y
            l = float16(1)
        elif y < m:
            l = float16(l + float16(math.exp(float16(y - m))))
        else:
            # New running max: rescale the old sum to the new reference.
            mult = float16(math.exp(float16(m - y)))
            m = y
            l = float16(mult*l + 1)
        list_l.append(l)
    return list_l

# First Plot - Accumulated Values
for q_index in range(1):
    plt.plot(process_f16  (QK[q_index,:])   , c='r', linestyle="-.", label="FP16")
    plt.plot(process_mixed(QK[q_index,:]) , c='y', linestyle="solid", label="FP16, FP32 $l$")
    plt.plot(process_f32  (QK[q_index,:])   , c='g', linestyle="-.", label="FP32")
    plt.plot([4096,4096],[0,500], label="Mixtral Window Size (4096)")
plt.title("Accumulated $l$ Value for Various Precisions")
# Raw string: "\o" in a normal literal is an invalid escape (SyntaxWarning).
plt.xlabel(r"Progress Along $\overline{x}$ Axis")
plt.ylabel("Accumulated $l$ Value")
plt.legend()
plt.show()

# Second Plot - Ratio of the FP16 accumulator to the FP32 one, first 12 query rows
for q_index in range(12):
    row = QK[q_index,:]
    plt.plot([x/y for x,y in zip(process_f16(row), process_f32(row))])
plt.plot([4096,4096],[0.65,1], label="Mixtral Window Size (4096)")
plt.title("FP16 / FP32 Accumulated $l$ Ratio")
plt.xlabel(r"Progress Along $\overline{x}$ Axis")
plt.ylabel("FP16 $l$ / FP32 $l$")
# The window-size line carries a label, so a legend is needed for it to show.
plt.legend()
plt.show()


In [None]:
# torch.float16 is a dtype object, not a constructor, so `torch.float16(2.3)`
# raises TypeError. Build a 0-d half tensor to see how 2.3 rounds in fp16.
half_value = torch.tensor(2.3, dtype=torch.float16)
half_value

In [None]:
# NOTE(review): mid-notebook install; move to a top cell and pin a version.
%pip install matplotlib

In [None]:
from matplotlib import pyplot as plt
# list_of: values read from text.txt; list_of3: running z values from process_like.
# FIXME: list_of2 is not defined anywhere in this notebook, so this cell raises
# NameError on a fresh kernel.
plt.plot(list_of, c='b')
plt.plot(list_of2, c='r')
plt.plot(list_of3, c='g')
plt.show()

In [None]:
# FIXME: list_of2 is never defined in this notebook (NameError on a fresh kernel).
print(list_of2)

In [None]:
# mz: third value returned by the last benchmarked flashAttention call.
print(mz)

In [None]:
# Recompute scaled scores from whichever big_randomQ/K are currently bound
# (the result depends on which setup cell ran last).
QK = (big_randomQ @ big_randomK.transpose(0,1))/math.sqrt(DIM)


In [None]:
print(QK)
print(QK.shape)

In [None]:
# Scores of the first query row.
QK[0,:]

In [None]:
print(QK.shape)

In [None]:
# Row-wise softmax denominators: max over axis 1 ([0] takes the values from
# torch's (values, indices) result), then the sum of exp(QK - row max).
m = QK.max(axis=1)[0]
print(m.shape)
print(torch.exp(QK-m.unsqueeze(1)).sum(axis=1))

In [None]:
# 24 hard-coded fp16 scores (origin not shown — presumably copied from kernel
# debug output; TODO confirm). Their softmax denominator is computed in fp16 below.
ys=torch.tensor([-0.951660,
-0.529785,
-1.196289,
-1.516602,
-1.274414,
1.180664,
-0.781738,
-0.147705,
-0.207153,
-0.065247,
0.023468,
1.379883,
0.393555,
0.532227,
-0.342529,
0.837402,
0.633301,
0.030380,
-1.551758,
1.186523,
0.636230,
0.350098,
-0.115295,
-1.531250],dtype=torch.float16)
# sum(exp(ys - max(ys))): the softmax denominator relative to the max score.
torch.exp(ys-torch.max(ys)).sum()

In [None]:
import flash_attn
def benchmarkBase(inputLambda, reps=1, data_size=data_size):
    """Time a reference attention function such as ``flash_attn.flash_attn_func``.

    Calls ``inputLambda(Qs, Ks, Vs)`` ``reps`` times. The inputs are the global
    ``big_randomQ/K/V`` reshaped to ``(1, seq, 1, DIM)``. The function prints
    the mean±std throughput, assuming ``data_size`` bytes move per call; the
    default is the module-level ``data_size`` captured when this cell ran.
    Returns the output of the last call.

    Raises ValueError if ``reps < 1``.
    """
    if reps < 1:
        raise ValueError("reps must be >= 1")
    Qs = big_randomQ.reshape(1, qbar, 1, DIM)
    Ks = big_randomK.reshape(1, xbar, 1, DIM)
    Vs = big_randomV.reshape(1, xbar, 1, DIM)
    mn = Moments()
    # Drain previously queued GPU work so it is not billed to the first call.
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(reps):
        s0 = time.perf_counter()
        Onew = inputLambda(Qs, Ks, Vs)
        torch.cuda.synchronize()
        mn.add(1/(time.perf_counter()-s0))
    t1 = time.perf_counter()
    ht, sht = mn.std()
    # Read/Write Rate
    T0 = (data_size*ht) / (1024**3)
    Ts = (data_size*sht)/ (1024**3)
    print(f"Base Throughput: {str_std(T0,Ts)} GiB/s in ({t1-t0:.2e}s)")
    return Onew

# Reference run: flash_attn over 100 repetitions.
Obase = benchmarkBase(flash_attn.flash_attn_func, 100)
# If big_randomV is still all ones, a correct output is (nearly) a single value,
# since softmax rows sum to 1.
len(Obase.unique())

In [None]:
print(O)
print(Obase)

In [None]:
print(O.shape)
print(Obase.shape)

In [None]:
# Elementwise difference between the reference output and the custom kernel's.
# NOTE(review): relies on qbar/DIM still matching the tensors O and Obase were
# built from; other setup cells rebind qbar.
delta = Obase.reshape(qbar,DIM)-O

In [None]:
import numpy as np
Alpha = np.float32(3.14)
# Take bytes 1 and 3 of the float32 (little-endian host assumed) and reinterpret
# them as one float16. NOTE(review): plain truncation to the high half would be
# [2:4], and that yields a bfloat16 bit pattern, not float16 — confirm intent.
Beta = Alpha.tobytes()[1::2]
# Named alpha_f16 rather than `float16`: rebinding `float16` would shadow numpy's
# float16 scalar type that process_like / process_mixed / process_f16 rely on.
alpha_f16 = np.frombuffer(Beta, dtype='f2')
print(Alpha,alpha_f16)

In [None]:
# NOTE(review): signed max only, so large negative errors go unnoticed;
# delta.abs().max() bounds the absolute error.
torch.max(delta)

In [None]:
# NOTE(review): wurlitzer and reload were already set up in the first cell;
# repeating %load_ext only prints an "already loaded" message.
%load_ext wurlitzer
from importlib import reload
import os
import kernels.benchmarks
reload(kernels.benchmarks)

In [None]:
# Load/reload the FP32 flash-attention module.
%load_ext wurlitzer
from importlib import reload
import kernels.flashattn32_128
reload(kernels.flashattn32_128)

In [None]:
# Rebinds the problem sizes and inputs for the flashattn32_128 benchmarks.
# torch.rand draws uniform [0, 1) values in the default dtype (float32 unless
# changed), replacing the fp16 tensors created by earlier cells.
qbar = 4096
xbar = 9060*8
DIM = 128
big_randomQ = torch.rand(qbar, DIM, device='cuda')
big_randomK = torch.rand(xbar, DIM, device='cuda')
big_randomV = torch.rand(xbar, DIM, device='cuda')

In [None]:
def benchmarkFlash(inputVector):
    """Print in/out throughput from a kernel's self-reported timing vector.

    ``inputVector`` is ``(input, O)``. ``input`` holds, in order: elapsed time
    (presumably ms, given the /1000 to seconds), bytes per repetition, block
    count and repetitions — TODO confirm against the kernel.
    Returns None.
    """
    info, O = inputVector
    t = info[0].item() / 1000
    d = info[1].item()
    b = info[2].item()
    r = info[3].item()
    # Read Rate
    T0 = r*(d/t) / (1024**3)
    # Deposit Rate
    T1 = r*(d*b/t) / (1024**3)
    # The block count is printed as an integer: `{b:.0}` rendered floats as
    # "8e+00" and raised ValueError for int values.
    print(f"Throughput: {T0:.2e} GiB/s in -> {T1:.2e} GiB/s out ({int(b)}bl, {t:.2e}s)")

# Sweep FP32-kernel launch configurations. The meaning of the four trailing ints
# is not visible here (TODO document from kernels.flashattn32_128). The first
# call repeats a later configuration, presumably to serve as a warm-up.
benchmarkFlash(kernels.flashattn32_128.flashAttention(big_randomQ, big_randomK, big_randomV, 12, 8,  8, 1))

benchmarkFlash(kernels.flashattn32_128.flashAttention(big_randomQ, big_randomK, big_randomV, 12, 1,  8, 1))
benchmarkFlash(kernels.flashattn32_128.flashAttention(big_randomQ, big_randomK, big_randomV, 12, 2,  8, 1))
benchmarkFlash(kernels.flashattn32_128.flashAttention(big_randomQ, big_randomK, big_randomV, 12, 4,  8, 1))
benchmarkFlash(kernels.flashattn32_128.flashAttention(big_randomQ, big_randomK, big_randomV, 12, 8,  8, 1))
benchmarkFlash(kernels.flashattn32_128.flashAttention(big_randomQ, big_randomK, big_randomV, 12, 12, 12,1))
benchmarkFlash(kernels.flashattn32_128.flashAttention(big_randomQ, big_randomK, big_randomV, 8, 8, 8,1))
benchmarkFlash(kernels.flashattn32_128.flashAttention(big_randomQ, big_randomK, big_randomV, 4, 4, 4,1))

#benchmarkFlash(kernels.benchmarks.loadFlash(big_randomQ, big_randomK, big_randomV, 32, 32, 8, 1280))

# Keep one FP32 run's (info vector, output) for inspection.
vect, O = kernels.flashattn32_128.flashAttention(big_randomQ, big_randomK, big_randomV, 12, 1,  8, 1)

In [None]:
# Load/reload the FP16 flash-attention module (repeated %load_ext only prints
# an "already loaded" message).
%load_ext wurlitzer
from importlib import reload
import kernels.flashattn16_128
reload(kernels.flashattn16_128)

In [None]:
def benchmarkFlash(inputVector):
    """Print in/out throughput from a kernel's self-reported timing vector.

    Redefines benchmarkFlash for the FP16 sweep below. ``inputVector`` is
    ``(input, O)``. ``input`` holds, in order: elapsed time (presumably ms,
    given the /1000 to seconds), bytes per repetition, block count and
    repetitions — TODO confirm against the kernel.
    Returns None.
    """
    info, O = inputVector
    t = info[0].item() / 1000
    d = info[1].item()
    b = info[2].item()
    r = info[3].item()
    # Read Rate
    T0 = r*(d/t) / (1024**3)
    # Deposit Rate
    T1 = r*(d*b/t) / (1024**3)
    # The block count is printed as an integer: `{b:.0}` rendered floats as
    # "8e+00" and raised ValueError for int values.
    print(f"Throughput: {T0:.2e} GiB/s in -> {T1:.2e} GiB/s out ({int(b)}bl, {t:.2e}s)")

# Sweep FP16-kernel launch configurations. The meaning of the four trailing ints
# is not visible here (TODO document from kernels.flashattn16_128).
# NOTE(review): big_randomQ/K/V from the torch.rand setup cell are float32;
# confirm this FP16 kernel accepts or converts them.
benchmarkFlash(kernels.flashattn16_128.flashAttention(big_randomQ, big_randomK, big_randomV, 12, 8,  8, 1))
benchmarkFlash(kernels.flashattn16_128.flashAttention(big_randomQ, big_randomK, big_randomV, 12, 1,  8, 1))
benchmarkFlash(kernels.flashattn16_128.flashAttention(big_randomQ, big_randomK, big_randomV, 12, 2,  8, 1))
benchmarkFlash(kernels.flashattn16_128.flashAttention(big_randomQ, big_randomK, big_randomV, 12, 4,  8, 1))
benchmarkFlash(kernels.flashattn16_128.flashAttention(big_randomQ, big_randomK, big_randomV, 12, 8,  8, 1))
benchmarkFlash(kernels.flashattn16_128.flashAttention(big_randomQ, big_randomK, big_randomV, 12, 12, 12,1))
benchmarkFlash(kernels.flashattn16_128.flashAttention(big_randomQ, big_randomK, big_randomV, 8,  8,  8, 1))
benchmarkFlash(kernels.flashattn16_128.flashAttention(big_randomQ, big_randomK, big_randomV, 4,  4,  4, 1))

#benchmarkFlash(kernels.benchmarks.loadFlash(big_randomQ, big_randomK, big_randomV, 32, 32, 8, 1280))

# Keep one FP16 run's (info vector, output). This line called flashattn32_128
# (copy-pasted from the FP32 cell), so O did not come from the kernel
# benchmarked in this cell.
vect, O = kernels.flashattn16_128.flashAttention(big_randomQ, big_randomK, big_randomV, 12, 1,  8, 1)

In [None]:
# Display O (bound by whichever cell last assigned it).
O

In [None]:
import flash_attn, time, math

b, h = 1, 1
qb = 4096
xb = 9060*8
d = 128

Qs = big_randomQ.reshape(b, qb, h, d)
Ks = big_randomK.reshape(b, xb, h, d)
Vs = big_randomV.reshape(b, xb, h, d)

# Qs = torch.rand(b, qbar, h, DIM, device='cuda', dtype=torch.float16)
# Ks = torch.rand(b, xbar, h, DIM, device='cuda', dtype=torch.float16)
# Vs = torch.rand(b, xbar, h, DIM, device='cuda', dtype=torch.float16)

# Bytes per call assuming 2-byte (fp16) elements: Q read + O written + K + V.
# NOTE(review): big_randomQ/K/V from the torch.rand setup cell are float32, and
# flash_attn_func only supports fp16/bf16 — confirm the intended dtype.
dsize = 2*(2*Qs.numel()+Ks.numel()+Vs.numel())
print(Qs.shape)
reps = 100
# running mean / second moment of per-call time (s)
t1 = 0
t2 = 0
# running mean / second moment of per-call throughput (GiB/s)
T1 = 0
T2 = 0

# Don't bill previously queued GPU work to the first timed call.
torch.cuda.synchronize()
for n in range(reps):
    s0 = time.perf_counter()
    Os = flash_attn.flash_attn_func(Qs, Ks, Vs)
    torch.cuda.synchronize()
    s1 = time.perf_counter()
    dt = s1 - s0
    t1 = (n*t1 + dt) / (n+1)
    t2 = (n*t2 + dt**2) / (n+1)
    Tn = (dsize / dt)/(1024**3)
    T1 = (n*T1 + Tn) / (n+1)
    T2 = (n*T2 + Tn**2) / (n+1)
    # Stop early once the sample std of the time drops below a third of its mean.
    # max(..., 0) guards against tiny negative variances from rounding.
    if n > 10:
        sd = math.sqrt(max(t2-t1**2, 0.0)/(1-1/(n+1)))
        if sd < t1/3:
            break

# The loop can break early, so the Bessel correction must use the number of
# samples actually taken (n+1), not reps.
samples = n + 1
print(Os.shape)
sd = math.sqrt(max(t2-t1**2, 0.0)/(1-1/samples))
Tsd = math.sqrt(max(T2-T1**2, 0.0)/(1-1/samples))
T = (dsize / t1)/(1024**3)
print(f"Throughput: {T:.2e} GiB/s ({samples*t1:.2e}s, {t1:.2e}~{sd:.2e}s, {dsize:.2e}B)")
print(f"Throughput: {T1:.2e} GiB/s ({samples*t1:.2e}s, {T1:.2e}~{Tsd:.2e}GiB/s, {dsize:.2e}B)")


In [None]:
#Os
# NOTE(review): Os has shape (b, qb, h, d) while O's shape depends on the kernel;
# broadcasting can silently produce a larger tensor, so reshape first. The signed
# max also hides large negative errors; (O - Os).abs().max() bounds |error|.
(O-Os).max()

In [None]:
big_randomA = torch.rand(112,96).cuda()
big_randomB = torch.rand(112,96).cuda()
# The second output is the kernel time in us (presumably microseconds).
# kernels.benchmarks.matmul(big_randomA, big_randomB)

# Broadcast-load throughput: every block reads the same data, so it is served
# mostly from the L2 cache. Expected to scale with the number of blocks.
# Observed ~1.90 TB/s with all blocks receiving (32 warps, 120 — presumably blocks).
def displayBenchmark(input : torch.Tensor):
    """Print and return ``(read_rate, deposit_rate)`` in GiB/s from a benchmark vector.

    ``input`` holds, in order: elapsed time (presumably ms, given the /1000 to
    seconds), bytes per repetition, warps, blocks and repetitions. The deposit
    rate is the read rate multiplied by the block count.
    """
    elapsed_ms, nbytes, warps, blocks, repeats = (input[k].item() for k in range(5))
    seconds = elapsed_ms / 1000
    gib = 1024**3
    read_rate = (nbytes * repeats / seconds) / gib
    deposit_rate = (blocks * nbytes * repeats / seconds) / gib
    print(f"Throughput: {read_rate:.2e} GiB/s in -> {deposit_rate:.2e} GiB/s out ({seconds:.2e}s,{nbytes/1024:.0f}KiB,{warps:.0f}w,{blocks:.0f}bl,{repeats:.1e}reps)")
    return read_rate, deposit_rate
# Disabled broadcast-load sweep; change False to True to rerun it.
# NOTE(review): later cells call loadBroadcasted with six arguments, so these
# four-argument calls may no longer match its signature.
if False:
    displayBenchmark(kernels.benchmarks.loadBroadcasted(big_randomA, 8000000, 4, 1))
    displayBenchmark(kernels.benchmarks.loadBroadcasted(big_randomA, 8000000, 8, 1))
    displayBenchmark(kernels.benchmarks.loadBroadcasted(big_randomA, 8000000, 16, 1))
    displayBenchmark(kernels.benchmarks.loadBroadcasted(big_randomA, 8000000, 32, 1))
    displayBenchmark(kernels.benchmarks.loadBroadcasted(big_randomA, 800000, 16, 30))
    displayBenchmark(kernels.benchmarks.loadBroadcasted(big_randomA, 800000, 32, 30))
    displayBenchmark(kernels.benchmarks.loadBroadcasted(big_randomA, 800000, 32, 60))
    displayBenchmark(kernels.benchmarks.loadBroadcasted(big_randomA, 800000, 32, 120))
    displayBenchmark(kernels.benchmarks.loadBroadcasted(big_randomA, 800000, 32, 240))

#displayBenchmark(kernels.benchmarks.loadBroadcasted(big_randomA, 10000000, 16, 120))
#displayBenchmark(kernels.benchmarks.loadBroadcasted(big_randomA, 10000000, 16, 240))
#displayBenchmark(kernels.benchmarks.loadBroadcasted(big_randomA, 10000000, 32, 120))

In [None]:
# NOTE(review): four-argument loadDistributed calls; the next cell passes six
# arguments, so these may target an older signature.
displayBenchmark(kernels.benchmarks.loadDistributed(big_randomA, 10000, 32, 24000))
displayBenchmark(kernels.benchmarks.loadDistributed(big_randomA, 10000, 16, 24000))

In [None]:
big_randomA = torch.rand(128,96).cuda()
#displayBenchmark(kernels.benchmarks.loadDistributed(big_randomA, 10000, 32, 24000))

def _run_distributed_sweep(heading, configs, sweep_mode):
    """Print ``heading``, then benchmark loadDistributed for each ``(a, b)`` pair.

    Each call is ``loadDistributed(big_randomA, 1000, a, b, 24000, sweep_mode)``.
    The section titles suggest ``b`` is a warp width and ``a <= b`` a subset of
    it — TODO confirm against kernels.benchmarks.
    """
    print(heading)
    for a, b in configs:
        displayBenchmark(kernels.benchmarks.loadDistributed(big_randomA, 1000, a, b, 24000, sweep_mode))

_widths32 = [32, 24, 16, 12, 8, 4, 2]
_run_distributed_sweep("Load GOOD", [(w, 32) for w in _widths32], 1)
_run_distributed_sweep("Load ANDSAVE", [(w, 32) for w in _widths32], 4)
_run_distributed_sweep("Load PIPELINE", [(w, 32) for w in _widths32 + [1]], 5)
_run_distributed_sweep("Load BAD", [(w, 32) for w in _widths32 + [1]], 2)
_run_distributed_sweep("Warp Size 16", [(w, 16) for w in (16, 8, 4, 2, 1)], 5)
mode = 5
_run_distributed_sweep("Warp Size 12", [(w, 12) for w in (12, 10, 8, 6, 4, 2, 1)], mode)
_run_distributed_sweep("Smaller Warp Sizes", [(8, 8), (4, 8), (4, 4), (2, 4), (2, 2)], 5)

_run_distributed_sweep("Load NONE", [(w, 32) for w in _widths32], 3)



In [None]:
big_randomA = torch.rand(48,128).cuda()
#displayBenchmark(kernels.benchmarks.loadDistributed(big_randomA, 10000, 32, 24000))
mode = 2

def _run_broadcast_sweep(heading, configs, first_n, last_n):
    """Print ``heading``, then benchmark loadBroadcasted for each ``(a, b)`` pair.

    Each call is ``loadBroadcasted(big_randomA, first_n, a, b, last_n, mode)``,
    with ``mode`` read from the notebook global. The meaning of the integer
    arguments is not visible here — TODO confirm against kernels.benchmarks.
    """
    print(heading)
    for a, b in configs:
        displayBenchmark(kernels.benchmarks.loadBroadcasted(big_randomA, first_n, a, b, last_n, mode))

_run_broadcast_sweep("Load PIPELINE", [(w, 32) for w in (32, 24, 16, 12, 8, 4, 2, 1)], 1000, 24000)
_run_broadcast_sweep("Warp Size 16", [(w, 16) for w in (16, 8, 4, 2, 1)], 1000, 24000)
_run_broadcast_sweep("Warp Size 12", [(w, 12) for w in (12, 10, 8, 6, 4, 2, 1)], 100000, 2400)
_run_broadcast_sweep("Smaller Warp Sizes", [(8, 8), (4, 8), (4, 4), (2, 4), (2, 2)], 1000, 24000)

In [None]:
# NOTE(review): five-argument loadDistributed call; other cells pass four or six
# arguments — confirm which signature is current.
displayBenchmark(kernels.benchmarks.loadDistributed(big_randomA, 10000, 16, 32, 24000))

In [None]:
import flash_attn, time, math

b, h = 1, 1
qb = 1024
xb = 9060*8
d = 128

Qs = torch.rand(b, qb, h, d, device='cuda', dtype=torch.float16)
Ks = torch.rand(b, xb, h, d, device='cuda', dtype=torch.float16)
Vs = torch.rand(b, xb, h, d, device='cuda', dtype=torch.float16)

# Bytes read per call (fp16 = 2 bytes/element).
# NOTE(review): unlike the other throughput cells, this omits the output write
# (the extra Qs.numel() term there) — confirm which definition is intended.
dsize = 2*(Qs.numel()+Ks.numel()+Vs.numel())
print(Qs.shape)
reps = 1000
# the average time
t1 = 0
# the second moment of time
t2 = 0

# Don't bill previously queued GPU work (tensor creation above) to the first call.
torch.cuda.synchronize()
for n in range(reps):
    s0 = time.perf_counter()
    Os = flash_attn.flash_attn_func(Qs, Ks, Vs)
    torch.cuda.synchronize()
    s1 = time.perf_counter()
    dt = s1 - s0
    t1 = (n*t1 + dt) / (n+1)
    t2 = (n*t2 + dt**2) / (n+1)
    # Stop early once the sample std drops below a third of the mean;
    # max(..., 0) guards against tiny negative variances from rounding.
    if n > 10:
        sd = math.sqrt(max(t2-t1**2, 0.0)/(1-1/(n+1)))
        if sd < t1/3:
            break

# The loop can break early, so the Bessel correction must use the number of
# samples actually taken (n+1), not reps.
samples = n + 1
print(Os.shape)
sd = math.sqrt(max(t2-t1**2, 0.0)/(1-1/samples))
T = (dsize / t1)/(1024**3)
print(f"Throughput: {T:.2e} GiB/s ({samples*t1:.2e}s, {t1:.2e}~{sd:.2e}s, {dsize:.2e}B)")


In [None]:
# NOTE(review): `product` is imported but not used in this cell.
from itertools import product
# Saturation experiments with four-argument loadDistributed calls; the argument
# meanings are inferred only from the prediction comments (TODO confirm).
# displayBenchmark(kernels.benchmarks.loadDistributed(big_randomA, 80000, 4, 1))
# displayBenchmark(kernels.benchmarks.loadDistributed(big_randomA, 80000, 8, 1))
# displayBenchmark(kernels.benchmarks.loadDistributed(big_randomA, 80000, 16, 1))
# displayBenchmark(kernels.benchmarks.loadDistributed(big_randomA, 80000, 32, 1))
# displayBenchmark(kernels.benchmarks.loadDistributed(big_randomA, 80000, 32, 2))
# displayBenchmark(kernels.benchmarks.loadDistributed(big_randomA, 80000, 32, 4))
# displayBenchmark(kernels.benchmarks.loadDistributed(big_randomA, 80000, 32, 8))
# Undersaturated - Prediction: Slower
displayBenchmark(kernels.benchmarks.loadDistributed(big_randomA, 80000, 32, 30))
# Saturated - Prediction: Fastest
displayBenchmark(kernels.benchmarks.loadDistributed(big_randomA, 80000, 32, 60))
# Oversaturated - Prediction: Equally Fastest
displayBenchmark(kernels.benchmarks.loadDistributed(big_randomA, 80000, 32, 120))
# Undertimed - Prediction: Equally Fastest
displayBenchmark(kernels.benchmarks.loadDistributed(big_randomA, 8000, 32, 120))
# Overtimed - Prediction: Equally Fastest
displayBenchmark(kernels.benchmarks.loadDistributed(big_randomA, 800000, 32, 120))
# Undersaturated - Prediction: Slower (Warp Dependent, SM Dependent)
displayBenchmark(kernels.benchmarks.loadDistributed(big_randomA, 80000, 16, 60))
# Saturated - Prediction: Slower (Warp Dependent), Equally Fastest (SM dependent)
displayBenchmark(kernels.benchmarks.loadDistributed(big_randomA, 80000, 16, 120))


In [None]:
# Few repetitions over many blocks (four-argument loadDistributed form).
displayBenchmark(kernels.benchmarks.loadDistributed(big_randomA, 60, 32, 9600))

In [None]:
# Single broadcast-load run (four-argument loadBroadcasted form).
displayBenchmark(kernels.benchmarks.loadBroadcasted(big_randomA, 800000, 16, 360))


In [None]:
# Disabled long-running broadcast runs, kept for reference.
# displayBenchmark(kernels.benchmarks.loadBroadcasted(big_randomA, 20000000, 16, 240))
# displayBenchmark(kernels.benchmarks.loadBroadcasted(big_randomA, 20000000, 8, 480))

In [None]:
# Broadcast loads from a small 32x32 tensor with many blocks.
big_randomC = torch.rand(32,32).cuda()
displayBenchmark(kernels.benchmarks.loadBroadcasted(big_randomC, 2000000, 4, 960))
displayBenchmark(kernels.benchmarks.loadBroadcasted(big_randomC, 2000000, 2, 1920))

In [None]:
times = []
bs = [240,480,720,960,1200,1440,1680,1920,2160,2400]
# Total work is held roughly constant: reps // b repetitions for b blocks.
# NOTE(review): rebinds the globals b and reps used by other cells.
reps = 2000000000
for b in bs:
   times.append(displayBenchmark(kernels.benchmarks.loadBroadcasted(big_randomC, reps // b, 4, b)))

import matplotlib.pyplot as plt
# Each entry of times is a (read, deposit) GiB/s tuple, so this draws two lines
# against the block count. NOTE(review): no axis labels or legend.
plt.plot(bs, times)

In [None]:
# testCopy(True)/testCopy(False) presumably time an HBM-sized and an L2-sized
# copy; each returns (elapsed, size, iterations). Keep each call's iteration
# count separately so the second call does not overwrite the first.
hbm, hbm_size, hbm_niter = kernels.benchmarks.testCopy(True)
l2, l2_size, l2_niter = kernels.benchmarks.testCopy(False)

# Copy throughput in bytes/s: 2 (read + write) * iterations * 4 * size, over the
# elapsed time / 1000 (same constants as the original formula; presumably 4-byte
# elements and a time in ms). These lines were commented out while the values
# were still printed below, which raised NameError.
hbm_through = 2*hbm_niter*(4*hbm_size)/(hbm/1000)
l2_through = 2*l2_niter*(4*l2_size)/(l2/1000)
print(hbm, hbm_size, hbm_niter)
print(l2, l2_size, l2_niter)

print(f"{hbm_through:0.2e}")
print(f"{l2_through:0.2e}")

In [None]:
# Single broadcast-load run (four-argument loadBroadcasted form; other cells use six).
displayBenchmark(kernels.benchmarks.loadBroadcasted(big_randomA, 4000000, 32, 120))