In [1]:
import numpy as np
import torch 
from provider import *

def bfloat16arr_to_uint8_arr(nparr, endian='little'):
    # Directly convert bfloat16 to uint8 since numpy does not support bfloat16
    if nparr.dtype != torch.bfloat16:
        raise ValueError('bfloat16_to_uint8: nparr must be of type bfloat16')
    if nparr.ndim != 1:
        nparr = nparr.reshape(-1)
    uint8_list = []
    # pdb.set_trace()
    # Since bfloat16 does not offer bitwise operation, we need to convert to int16 first
    nparr = nparr.view(torch.int16) # int32로
    for elem in nparr:
        if endian == 'little':
            uint8_list.append(elem & 0xFF)
            uint8_list.append((elem >> 8) & 0xFF)

        elif endian == 'big':
            uint8_list.append((elem >> 8) & 0xFF)
            uint8_list.append(elem & 0xFF)
        else:
            raise ValueError('bfloat16arr_to_uint8: endian must be either little or big')
    nparr = np.array(uint8_list, dtype=np.uint8)
    return nparr

def fp16arr_to_uint8_arr(nparr, endian='little'):
    # Directly convert bfloat16 to uint8 since numpy does not support bfloat16
    
    if nparr.ndim != 1:
        nparr = nparr.reshape(-1)
    uint8_list = []
    # pdb.set_trace()
    # Since bfloat16 does not offer bitwise operation, we need to convert to int16 first
    nparr = nparr.view(torch.int16) # int32로
    for elem in nparr:
        if endian == 'little':
            uint8_list.append(elem & 0xFF)
            uint8_list.append((elem >> 8) & 0xFF)

        elif endian == 'big':
            uint8_list.append((elem >> 8) & 0xFF)
            uint8_list.append(elem & 0xFF)
        else:
            raise ValueError('bfloat16arr_to_uint8: endian must be either little or big')
    nparr = np.array(uint8_list, dtype=np.uint8)
    return nparr

# Change the function above to convert float32 array to uint8 arr
def fp32arr_to_uint8_arr(nparr, endian='little'):
    if nparr.dtype != torch.float32:
        raise ValueError('fp32arr_to_uint8_arr: nparr must be of type float32')
    if nparr.ndim != 1:
        nparr = nparr.reshape(-1)
    uint8_list = []
    pdb.set_trace()
    # Convert float32 to uint8 using numpy
    # nparr = (nparr * 255).astype(np.uint8)
    nparr = (nparr).view(torch.int32)
    for elem in nparr:
        if endian == 'little':
            uint8_list.append(elem & 0xFF)
            uint8_list.append((elem >> 8) & 0xFF)
            # Add more since this is float32 (has 4 bytes)
            uint8_list.append((elem >> 16) & 0xFF)
            uint8_list.append((elem >> 24) & 0xFF)
        else:
            uint8_list.append((elem >> 24) & 0xFF)
            uint8_list.append((elem >> 16) & 0xFF)
            uint8_list.append((elem >> 8) & 0xFF)
            uint8_list.append(elem & 0xFF)
            
    nparr = np.array(uint8_list, dtype=np.uint8)
    return nparr

def int4_to_uint16_array(nparr):
    if nparr.dtype != torch.int8:
        raise ValueError('packed_int4_to_uint16_array: nparr must be of type int8')
    assert nparr.ndim == 2
    assert nparr.shape[1] == 4, "Only shapes (N, 4) are supported"
    assert nparr.max() < 8, nparr.min() > -9
    uint16_list = []
    for elem in nparr:
        uint16 = 0
        for i in range(4):
            uint16 |= (elem[i].item() & 0xF) << (i * 4)
        uint16_list.append(uint16)
    
    nparr = np.array(uint16_list, dtype=np.uint16)
    return nparr

def int4_to_uint8_array(nparr):
    if nparr.dtype != torch.int8:
        raise ValueError('packed_int4_to_uint16_array: nparr must be of type int8')
    assert nparr.ndim == 2
    assert nparr.shape[1] == 4, "Only shapes (N, 4) are supported"
    assert nparr.max() < 8, nparr.min() > -9
    uint8_list = []
    for elem in nparr:
        uint16 = 0
        for i in range(4):
            uint16 |= (elem[i].item() & 0xF) << (i * 4)
        uint8_list.append(uint16 & 0xFF)
        uint8_list.append((uint16 >> 8) & 0xFF)
    
    nparr = np.array(uint8_list, dtype=np.uint8)
    return nparr

def twos_complement_int4_to_binary(nparr):
    if nparr.dtype != torch.int8:
        raise ValueError('twos_complement_int4_to_binary: nparr must be of type int8')
    binary_list = []
    for elem in nparr:
        # Each elenmement is 4 bits, but represented as 8 bits
        # We need to convert to 4 bits
        elem = elem & 0xF # But this is for unsigned int
        # Convert to binary
        binary = bin(elem)[2:][-4:]
        # Also get sign
        sign = 1 if elem < 0 else 0
        # Add sign bit
        binary = str(sign) + binary
        # Add padding
        binary_list.append(binary.zfill(4))
    return binary_list

In [3]:
data_load = torch.load("test_data.pth")

In [4]:
x = data_load['in_A'].bfloat16()
w = data_load['in_B'].bfloat16()
result_ref = x @ w

result_uint8 = bfloat16arr_to_uint8_arr(result_ref)
result_uint16 = result_uint8.reshape(-1,2).view(np.uint16)
with open('data_bf_int/dout_result.hex', 'w') as f:
    for row in result_uint16:
        f.write(' '.join(format(val, 'd') for val in row) + '\n')

#x_f = x.float().detach()
#x_np = x_f.numpy()
x_uint8 = bfloat16arr_to_uint8_arr(x)
x_uint16 = x_uint8.reshape(-1, 2).view(np.uint16)

with open('data_bf_int/din_a.hex', 'w') as f:
    for row in x_uint16:
        f.write(' '.join(format(val, 'X').zfill(4) for val in row) + '\n')

w_i = torch.tensor(w, dtype=torch.int8)
w_uint8 = int4_to_uint8_array(w_i)
w_uint16 = w_uint8.reshape(-1, 2).view(np.uint16)

with open('data_bf_int/din_b.hex', 'w') as f:
    for row in w_uint16:
        f.write(' '.join(format(val, 'X').zfill(4) for val in row) + '\n')

  w_i = torch.tensor(w, dtype=torch.int8)


In [15]:
x_uint8 = bfloat16arr_to_uint8_arr(x * 10)
x_uint16 = x_uint8.reshape(-1, 2).view(np.uint16)
xmaxidx = torch.argmax(x)
x_uint16[xmaxidx-56]

with open('data_bf_bf/din_a.hex', 'w') as f:
    for row in x_uint16:
        f.write(' '.join(format(val, 'X').zfill(4) for val in row) + '\n')

with open('data_bf_bf/din_b.hex', 'w') as f:
    for row in w_uint16:
        f.write(' '.join(format(val, 'X').zfill(4) for val in row) + '\n')

array([48585], dtype=uint16)

In [18]:
x = torch.randn(100,32).bfloat16()
w = torch.randn(100,32).bfloat16()

x_uint8 = bfloat16arr_to_uint8_arr(x)
x_uint16 = x_uint8.reshape(-1, 2).view(np.uint16)

w_uint8 = bfloat16arr_to_uint8_arr(w)
w_uint16 = w_uint8.reshape(-1, 2).view(np.uint16)

In [17]:
x

tensor([[ 1.3672,  0.0728, -1.5625,  ...,  0.1855, -1.2266, -1.1641],
        [-0.0212,  0.3438,  0.4277,  ...,  1.2422, -0.2539, -1.2109],
        [ 1.3594,  0.7617, -0.9570,  ...,  0.3281,  0.0237, -1.3125],
        ...,
        [-0.8242,  1.0781,  0.3105,  ..., -0.1650,  0.0320,  1.0000],
        [-1.6328,  0.2695, -0.7148,  ...,  1.1562, -0.6914, -0.8867],
        [ 0.8008,  1.1016, -0.8828,  ..., -0.0903,  0.2168,  1.5781]],
       dtype=torch.bfloat16)

In [127]:
x = data_load['in_A'].half()
w = data_load['in_B'].half()
result_ref = x @ w

result_uint8 = fp16arr_to_uint8_arr(result_ref)
result_uint16 = result_uint8.reshape(-1,2).view(np.uint16)
with open('data_fp_int/dout_result.hex', 'w') as f:
    for row in result_uint16:
        f.write(' '.join(format(val, 'd') for val in row) + '\n')

#x_f = x.float().detach()
#x_np = x_f.numpy()
x_uint8 = fp16arr_to_uint8_arr(x)
x_uint16 = x_uint8.reshape(-1, 2).view(np.uint16)

with open('data_fp_int/din_a.hex', 'w') as f:
    for row in x_uint16:
        f.write(' '.join(format(val, 'X').zfill(4) for val in row) + '\n')

w_i = torch.tensor(w, dtype=torch.int8)
w_uint8 = int4_to_uint8_array(w_i)
w_uint16 = w_uint8.reshape(-1, 2).view(np.uint16)

with open('data_fp_int/din_b.hex', 'w') as f:
    for row in w_uint16:
        f.write(' '.join(format(val, 'X').zfill(4) for val in row) + '\n')

  w_i = torch.tensor(w, dtype=torch.int8)


In [17]:
w.shape

torch.Size([768, 4])

In [123]:
w_i = torch.tensor(w, dtype=torch.int8)
w_uint8 = int4_to_uint8_array(w_i)
w_uint16 = w_uint8.reshape(-1, 2).view(np.uint16)

  w_i = torch.tensor(w, dtype=torch.int8)


In [124]:
int4_to_uint16_array(w_i).reshape(-1, 1).shape
w_uint16.shape

(768, 1)

In [None]:
x = data_load['in_A'].half()
w = data_load['in_B'].half()
result_ref = x @ w

In [92]:
x_f = x.float().detach()
x_np = x_f.numpy()
x_uint8 = fp16arr_to_uint8_arr(x_f)
x_uint16 = x_uint8.reshape(-1, 2).view(np.uint16)

In [118]:
with open('data_half/din_a.hex', 'w') as f:
    for row in x_uint16:
        f.write(' '.join(format(val, 'X').zfill(4) for val in row) + '\n')

In [117]:
format(0, 'x').zfill(4)

'0000'

In [100]:
result_uint8 = fp16arr_to_uint8_arr(result_ref)
result_uint16 = result_uint8.reshape(-1,2).view(np.uint16)
with open('data_half/dout_result.hex', 'w') as f:
    for row in result_uint16:
        f.write(' '.join(format(val, 'd') for val in row) + '\n')

In [119]:
with open('data_half/din_b.hex', 'w') as f:
    for row in w_uint16:
        f.write(' '.join(format(val, 'X').zfill(4) for val in row) + '\n')

In [113]:
elem = w_i[6]
uint16 = 0
for i in range(4):
    uint16 |= (elem[i].item() & 0xF) << (i * 4)

In [115]:
hex(uint16)
w_i[6]
#bin(elem[3]&0xF) + bin(elem[2]&0xF)[2:] + bin(elem[1]&0xF)[2:] + bin(elem[0]&0xF)[2:]

tensor([0, 0, 1, 0], dtype=torch.int8)