In [12]:
import torch.utils.cpp_extension

compiled_lib = torch.utils.cpp_extension.load(
    name='LlamaCppFloatLinear',
    sources=['LlamaCppFloatLinear.mm'],
    extra_cflags=['-std=c++17'],
)

In [38]:
import time
import torch

assert torch.backends.mps.is_available()

def test_speedup():
    mps_device = torch.device("mps")  # Device object representing GPU.
    custom_mps_linear = 0
    default_linear = 0
    weight = torch.randn(2048, 768, device=mps_device, dtype=torch.int8)
    input = torch.randn(128, 768, device=mps_device, dtype=torch.float32) # M = 128, N = 2048, K = 768
    scale = torch.randn(2048, device=mps_device, dtype=torch.float32)
    # Quantized linear: Linear(in_features=768, out_features=2048) * scale

    # Measures time.
    for _ in range(100):
        start = time.time()
        torch.ops.aten._weight_int8pack_mm(input, weight, scale)
        torch.mps.synchronize()
        default_linear += time.time() - start

        start = time.time()
        compiled_lib.llama_cpp_mps_int8_linear(input, weight, scale)
        torch.mps.synchronize()
        custom_mps_linear += time.time() - start

    speedup = default_linear / custom_mps_linear
    print('Default int8 QLinear: {:.3f} us | Custom int8 QLinear {:.3f} us. ({:.3f} times faster)'.format(
        default_linear * 1e6/1e5, custom_mps_linear * 1e6/1e5, speedup))



In [45]:
test_speedup()

Default int8 QLinear: 0.849 us | Custom int8 QLinear 0.661 us. (1.284 times faster)


In [17]:
import torch
m = torch.nn.Linear(768, 2048)
weight = torch.randn(2048, 768)
m.weight = torch.nn.Parameter(weight)
input = torch.randn(128, 768)
output = m(input)
print(output.size())

torch.Size([128, 2048])


In [29]:
mps_device = torch.device("mps")  # Device object representing GPU.
weight = torch.randn(2048, 768, device=mps_device, dtype=torch.int8)
input = torch.randn(128, 768, device=mps_device, dtype=torch.float32) # M = 128, N = 2048, K = 768
scale = torch.randn(2048, device=mps_device, dtype=torch.float32)
res1 = torch.ops.aten._weight_int8pack_mm(input, weight, scale)
print(res1.size())

torch.Size([128, 2048])


In [30]:
res2 = compiled_lib.llama_cpp_mps_int8_linear(input, weight, scale)

In [31]:
print(res2.size())

torch.Size([128, 2048])


In [32]:
print(res1)

tensor([[ -7.2068, -13.4481, -18.3802,  ...,  10.1890,  -2.2464,   1.5332],
        [ -3.5948,   8.0580,  13.1837,  ..., -14.9112,  -9.6422,  -0.7589],
        [  8.6552,  26.3780,  -3.4192,  ..., -36.6986, -10.1032,  14.1250],
        ...,
        [  4.3094, -31.0385,   3.6036,  ...,  -1.5528, -36.5997,  -2.1174],
        [  5.9805,   7.4148, -27.7217,  ..., -18.8047,  -4.7120,   8.9746],
        [ -2.2707, -12.3608,  14.5149,  ...,  24.0632,  25.4688,  -6.0178]],
       device='mps:0')


In [33]:
print(res2)

tensor([[ 14.5421,  10.2357,  40.4167,  ..., -21.8944,  60.1607,  16.0577],
        [  8.4213,   5.3934,  91.2100,  ...,  28.1776, -24.6787,  -7.3789],
        [-14.5463,  -1.6143,  10.3095,  ...,  -3.4682, -63.5541, -27.2303],
        ...,
        [ 11.3661, -27.6764,  14.9066,  ...,  -6.6825,  -2.6358,   5.4985],
        [ 32.8333,  31.9119, -30.6715,  ..., -13.7488, -39.2429, -15.3178],
        [  7.5256,  25.4646,   3.0442,  ...,  -8.3438,  21.2945,  30.2778]],
       device='mps:0')


In [14]:
torch.allclose(res1, res2, atol=1e-4)

True

In [59]:
mps_device = torch.device("mps")  # Device object representing GPU.
# Create a tensor with values from 0 to 127
row = torch.arange(32, device=mps_device, dtype=torch.int8)
# Repeat the row 64 times to create a 64x128 tensor
weight = row.repeat(64, 1)
input = torch.ones(32, 32, device=mps_device, dtype=torch.float32) # M = 128, N = 2048, K = 768
scale = torch.ones(64, device=mps_device, dtype=torch.float32)
res1 = torch.ops.aten._weight_int8pack_mm(input, weight, scale)
print(res1.size())
res2 = compiled_lib.llama_cpp_mps_int8_linear(input, weight, scale)

torch.Size([32, 64])


In [60]:
print(res1)

tensor([[496., 496., 496.,  ..., 496., 496., 496.],
        [496., 496., 496.,  ..., 496., 496., 496.],
        [496., 496., 496.,  ..., 496., 496., 496.],
        ...,
        [496., 496., 496.,  ..., 496., 496., 496.],
        [496., 496., 496.,  ..., 496., 496., 496.],
        [496., 496., 496.,  ..., 496., 496., 496.]], device='mps:0')


In [61]:
print(res2)

tensor([[496., 496., 496.,  ..., 496., 496., 496.],
        [496., 496., 496.,  ..., 496., 496., 496.],
        [496., 496., 496.,  ..., 496., 496., 496.],
        ...,
        [496., 496., 496.,  ..., 496., 496., 496.],
        [496., 496., 496.,  ..., 496., 496., 496.],
        [496., 496., 496.,  ..., 496., 496., 496.]], device='mps:0')


In [1]:
import torch

In [2]:
# Create a tensor with values from 0 to 127
row = torch.arange(128)
# Repeat the row 64 times to create a 64x128 tensor
tensor = row.repeat(64, 1)

In [3]:
print(tensor)

tensor([[  0,   1,   2,  ..., 125, 126, 127],
        [  0,   1,   2,  ..., 125, 126, 127],
        [  0,   1,   2,  ..., 125, 126, 127],
        ...,
        [  0,   1,   2,  ..., 125, 126, 127],
        [  0,   1,   2,  ..., 125, 126, 127],
        [  0,   1,   2,  ..., 125, 126, 127]])


In [67]:
mps_device = torch.device("mps")  # Device object representing GPU.
# Create a tensor with values from 0 to 127
row = torch.arange(32, device=mps_device, dtype=torch.int8)
# Repeat the row 64 times to create a 64x128 tensor
weight = torch.ones(64, 32, device=mps_device, dtype=torch.int8)
input = torch.ones(32, 32, device=mps_device, dtype=torch.float32) # M = 128, N = 2048, K = 768
scale = torch.ones(64, device=mps_device, dtype=torch.float32)
res1 = torch.ops.aten._weight_int8pack_mm(input, weight, scale)
print(res1.size())
res2 = compiled_lib.llama_cpp_mps_int8_linear(input, weight, scale)

torch.Size([32, 64])


In [68]:
print(res1)

tensor([[32., 32., 32.,  ..., 32., 32., 32.],
        [32., 32., 32.,  ..., 32., 32., 32.],
        [32., 32., 32.,  ..., 32., 32., 32.],
        ...,
        [32., 32., 32.,  ..., 32., 32., 32.],
        [32., 32., 32.,  ..., 32., 32., 32.],
        [32., 32., 32.,  ..., 32., 32., 32.]], device='mps:0')


In [69]:
print(res2)

tensor([[32., 32., 32.,  ..., 32., 32., 32.],
        [32., 32., 32.,  ..., 32., 32., 32.],
        [32., 32., 32.,  ..., 32., 32., 32.],
        ...,
        [32., 32., 32.,  ..., 32., 32., 32.],
        [32., 32., 32.,  ..., 32., 32., 32.],
        [32., 32., 32.,  ..., 32., 32., 32.]], device='mps:0')


In [16]:
mps_device = torch.device("mps")  # Device object representing GPU.
# Create a tensor with values from 0 to 127
row = torch.arange(32, device=mps_device, dtype=torch.int8)
# Repeat the row 64 times to create a 64x128 tensor
weight = torch.ones(64, 32, device=mps_device, dtype=torch.float32)
input = torch.ones(32, 32, device=mps_device, dtype=torch.float32) # M = 128, N = 2048, K = 768

res2 = compiled_lib.llama_cpp_mm(input, weight)

In [18]:
print(res2)

tensor([[32., 32., 32.,  ..., 32., 32., 32.],
        [32., 32., 32.,  ..., 32., 32., 32.],
        [32., 32., 32.,  ..., 32., 32., 32.],
        ...,
        [ 1.,  1.,  1.,  ...,  1.,  1.,  1.],
        [ 1.,  1.,  1.,  ...,  1.,  1.,  1.],
        [ 1.,  1.,  1.,  ...,  1.,  1.,  1.]], device='mps:0')


In [19]:
torch.ops.aten.mm(input, weight.transpose(1, 0))

tensor([[32., 32., 32.,  ..., 32., 32., 32.],
        [32., 32., 32.,  ..., 32., 32., 32.],
        [32., 32., 32.,  ..., 32., 32., 32.],
        ...,
        [32., 32., 32.,  ..., 32., 32., 32.],
        [32., 32., 32.,  ..., 32., 32., 32.],
        [32., 32., 32.,  ..., 32., 32., 32.]], device='mps:0')

In [15]:
mps_device = torch.device("mps")  # Device object representing GPU.

weight = torch.ones(3, 1, device=mps_device, dtype=torch.float32)
input = torch.ones(4, 1, device=mps_device, dtype=torch.float32)

res1 = torch.mm(input, weight.transpose(1, 0).contiguous())
res2 = compiled_lib.llama_cpp_mm(input, weight)
print(res1)
print(res2)
torch.allclose(res1, res2, atol=1e-2)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], device='mps:0')
tensor([[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan]], device='mps:0')


False

In [92]:
import torch
def blockwise_outer_product_matrix_mult(A, B):
    
    # Check if the matrices are compatible for multiplication
    if A.size(1) != B.size(0):
        raise ValueError("Incompatible matrix dimensions for multiplication.")
    
    # Initialize an empty matrix for the result
    C = torch.zeros((A.size(0), B.size(1)))
    
    # Perform the outer product for each pair of vectors and add the result to C
    for b in range(2):
        blockA = [A[b*2:b*2+2,0:2], A[b*2:b*2+2, 2:4]]
        blockB = [B[0:2,b*2:b*2+2], B[2:4, b*2:b*2+2]]
        tempC = torch.zeros(2, 2)
        for i in range(2):
            tempC += torch.outer(blockA[i, :], blockB[:, i])
        C[b*2:b*2+2, b*2:b*2+2] = tempC
    
    return C
# Test the function with two matrices
A = torch.arange(1, 17, dtype=torch.int32).reshape(4, 4)
B = torch.arange(1, 17, dtype=torch.int32).reshape(4, 4)

print(blockwise_outer_product_matrix_mult(A, B))

tensor([[ 11.,  35.,   0.,   0.],
        [ 14.,  46.,   0.,   0.],
        [  0.,   0., 301., 405.],
        [  0.,   0., 324., 436.]])


In [87]:
res = torch.matmul(A, B)
print(res)

tensor([[ 90, 100, 110, 120],
        [202, 228, 254, 280],
        [314, 356, 398, 440],
        [426, 484, 542, 600]], dtype=torch.int32)


In [90]:
for b in range(2):
    blockA = A[b*2:b*2+2,b*2:b*2+2]
    blockB = B[b*2:b*2+2,b*2:b*2+2]
    print(blockA)
    print(blockB)

tensor([[1, 2],
        [5, 6]], dtype=torch.int32)
tensor([[1, 2],
        [5, 6]], dtype=torch.int32)
tensor([[11, 12],
        [15, 16]], dtype=torch.int32)
tensor([[11, 12],
        [15, 16]], dtype=torch.int32)


In [None]:
def outer_product_mm(input, weight):
    assert(input.size(1) == weight.size(1), "input and weight need to have the same K")
    