In [57]:
import torch
from utils.utils import HadamardTransform
from utils.hadamard_utils import matmul_hadU_cuda, get_hadK
import math
import numpy as np
import struct

In [60]:
# Function to get the bit pattern of a floating-point number
def float_to_hex(f):
    # Pack the float into 4 bytes (32-bit float) or 8 bytes (64-bit float)
    packed = struct.pack('>f', f)  # Use '>f' for 32-bit float, '>d' for 64-bit double
    # Unpack as an integer to get the bit pattern
    hex_value = ''.join(f'{byte:02x}' for byte in packed)
    return hex_value

In [61]:
q = torch.rand(2,2,11008)

np.log2([11008])
print(2**14)
print(16384-11008)

dtype = q.dtype
q_ref = (HadamardTransform.apply(q.float()) / math.sqrt(q.shape[-1])).to(dtype)

q_pad = torch.concat([q, torch.zeros(q.shape[0], q.shape[1], 5376, dtype=q.dtype)], dim=-1)

q_evl = (HadamardTransform.apply(q_pad.float()) / math.sqrt(q.shape[-1])).to(dtype)

(q_ref-q_evl[...,:11008]).abs().mean()

(q_ref-q_evl[...,:11008]).abs()/(q_ref.abs() + 1e-6)

16384
5376


tensor([[[7.4804, 0.4414, 2.9439,  ..., 0.0269, 4.0045, 4.4479],
         [7.9524, 1.3467, 3.6151,  ..., 2.4852, 0.2197, 0.7258]],

        [[8.6133, 0.6591, 0.2078,  ..., 1.3946, 0.8357, 0.7427],
         [7.3354, 1.7494, 3.9259,  ..., 3.4990, 4.5255, 2.2099]]])

In [63]:
n1, n2, n3 = 2, 2, 4*172
q = torch.zeros(n1*n2*n3, dtype=torch.float16)
for i in range(n1*n2*n3):
  q[i] = i/(n1*n2*n3)
q = q.reshape(n1, n2, n3)


had_K, K = get_hadK(q.shape[-1])
q_had = matmul_hadU_cuda(q, had_K, K)

print(had_K)
print(K)
print(q_had)

data = q_had.numpy()
data.tofile('./hadamard_test/hadamard_transform.bin')

tensor([[ 1., -1., -1.,  ..., -1., -1.,  1.],
        [-1.,  1., -1.,  ..., -1., -1., -1.],
        [-1., -1.,  1.,  ...,  1., -1., -1.],
        ...,
        [ 1.,  1., -1.,  ...,  1., -1., -1.],
        [ 1.,  1.,  1.,  ..., -1.,  1., -1.],
        [-1.,  1.,  1.,  ..., -1., -1.,  1.]])
172
tensor([[[ 5.7770e-02, -2.1434e-04, -4.4322e-04,  ..., -1.5831e-04,
          -3.3617e-04, -2.8014e-06],
         [ 3.6279e-01, -2.0456e-04, -5.9605e-04,  ..., -1.6737e-04,
          -3.1614e-04, -1.8597e-05]],

        [[ 6.6797e-01, -2.6035e-04, -4.4680e-04,  ..., -1.6737e-04,
          -5.5885e-04,  5.5790e-05],
         [ 9.7314e-01, -2.6035e-04, -4.4680e-04,  ..., -1.6737e-04,
          -5.5885e-04,  5.5790e-05]]], dtype=torch.float16)


In [64]:
q_had_ref = q_had.reshape(-1, n3).to(torch.float32)

In [None]:
for i in range(q_had_ref.shape[0]):
    for j in range(q_had_ref.shape[1]):
        hex_value = float_to_hex(q_had_ref[i, j].item())
        print(f"0x{hex_value}", end=" ")

0x3d6ca000
0xb960c000
0xb9e86000
0x36b20000
0x3d4ca000
0xb9610000
0xb9e6c000
0xb6da0000
0x3d7ba000
0xb9670000
0xb9ee2000
0x36240000
0x3d94e000
0xb96c4000
0xb9e74000
0x36da0000
0x3d84e000
0xb9668000
0xb9e02000
0xb6140000
0x3d6a8000
0xb96d0000
0xb9dbe000
0x35100000
0x3d4a6000
0xb9650000
0xb9de2000
0x361c0000
0x3d2aa000
0xb96cc000
0xb9e90000
0xb7090000
0x3d0b4000
0xb9670000
0xb9efe000
0xb61c0000
0x3d39c000
0xb95f4000
0xb9e5a000
0x36d60000
0x3d19a000
0xb9600000
0xb9e70000
0xb6b60000
0x3d47e000
0xb9634000
0xb9e1c000
0xb5d80000
0x3d762000
0xb971c000
0xb9dd6000
0x37010000
0x3d560000
0xb96f8000
0xb9ea2000
0xb61c0000
0x3d824000
0xb96cc000
0xb9f3e000
0x363c0000
0x3d64c000
0xb9690000
0xb9db6000
0x37110000
0x3d452000
0xb9680000
0xb9f18000
0x37250000
0x3d252000
0xb9678000
0xb9e7e000
0x36fe0000
0x3d53c000
0xb969c000
0xb9ea2000
0x36440000
0x3d340000
0xb9774000
0xb9e88000
0x36a20000
0x3d620000
0xb9634000
0xb9e88000
0xb6b20000
0x3d882000
0xb9618000
0xb9da6000
0xb6da0000
0x3d9f4000
0xb9664000
0xb9e3e000