In [1]:
# Expose only physical GPU 6 to this kernel; the outputs recorded further down
# report the device as 'cuda:0', i.e. the single visible GPU is re-indexed.
# NOTE(review): the index 6 is specific to the machine this was run on —
# adjust it for your host.
%env CUDA_VISIBLE_DEVICES=6

import math

import torch
from linear import HiggsLinear, pad_to_block

env: CUDA_VISIBLE_DEVICES=6


In [2]:
# Experiment configuration shared by the cells below.
GRID_DIM = 2          # interpolated into the grid file name "EDEN{GRID_DIM}-256.pt" later on
HADAMARD_SIZE = 1024  # block length used for the Hadamard transform in the reference path
SIZE = 4096           # passed as both of the first two layer arguments (square layer)

# NOTE(review): 256 presumably is the number of grid points, matching the
# "-256" suffix of the grid file loaded later — TODO confirm against HiggsLinear.
# NOTE(review): HADAMARD_SIZE is not passed to the layer here; the reference
# computation in later cells assumes the layer uses the same block size — confirm.
layer = HiggsLinear(
    SIZE,
    SIZE,
    GRID_DIM,
    256,
    device="cuda",
    dtype=torch.float16,
)

In [3]:
# Seed the global RNG (CPU and CUDA) so the random test input is identical on
# every Restart & Run All; without this the activations differ between runs.
torch.manual_seed(0)

# A single activation vector of shape (1, 1, SIZE): torch.rand samples [0, 1),
# and `* 2 - 1` rescales it to roughly [-1, 1).
# NOTE(review): `input` shadows the Python builtin of the same name; the name is
# kept because the later cells reference it.
input = torch.rand(1, 1, SIZE, device="cuda", dtype=torch.float16) * 2 - 1

In [4]:
# Run the quantized layer once and display the output row for the single
# (batch=0, position=0) input vector.
# NOTE(review): the recorded output is all zeros, so comparing it by eye with
# the reference computation further down does not discriminate much — consider
# checking with non-trivial layer weights.
layer_output = layer(input)
layer_output[0, 0]

tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', dtype=torch.float16)

In [5]:
from aqlm.utils import unpack_int_data, _dequantize_weight

# Rebuild a dense fp16 weight matrix from the layer's packed codes so the
# quantized layer can be compared against a plain matmul.

# Grid (codebook) matching GRID_DIM, cast to half precision.
grid_path = f"../grids/EDEN{GRID_DIM}-256.pt"
codebooks = torch.load(grid_path).half()

# HADAMARD_SIZE rounded up to the next multiple of GRID_DIM (integer ceil-division).
post_hadamard_size = -(-HADAMARD_SIZE // GRID_DIM) * GRID_DIM

# Unpack the 8-bit codes and look them up in the grid. The scales are
# deliberately NOT passed to _dequantize_weight; they are applied by hand below.
unpacked_codes = unpack_int_data(layer.codes, 8)[:, :, None]
dequantized = _dequantize_weight(
    unpacked_codes,
    codebooks[None, :, None, :],
)

# Split each row into blocks of post_hadamard_size and drop the padding
# beyond HADAMARD_SIZE in every block.
blocked = dequantized.reshape(dequantized.shape[0], -1, post_hadamard_size)
unscaled_weight = blocked[..., :HADAMARD_SIZE]

# Apply the scales, broadcast over the last (within-block) axis.
# assumes layer.scales holds one scale per (row, block) — TODO confirm
scaled = unscaled_weight * layer.scales[..., None]

# Flatten the blocks back into a 2-D matrix.
weight = scaled.reshape(scaled.shape[0], -1)

In [6]:
from fast_hadamard_transform import hadamard_transform

# Reference forward pass: apply the Hadamard transform (scaled by
# 1/sqrt(HADAMARD_SIZE)) to the input in blocks of HADAMARD_SIZE, restore the
# (1, 1, SIZE) shape, then multiply by the reconstructed dense weight.
rotated_input = hadamard_transform(
    input.reshape(-1, HADAMARD_SIZE),
    scale=1 / math.sqrt(HADAMARD_SIZE),
).reshape(1, 1, SIZE)

# Display the same output row that was shown for the quantized layer.
(rotated_input @ weight.T)[0, 0]

tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0', dtype=torch.float16)

In [7]:
%%time

with torch.inference_mode():
    for i in range(1000):
        layer(input)
        torch.cuda.synchronize()

CPU times: user 161 ms, sys: 138 µs, total: 161 ms
Wall time: 160 ms


In [8]:
%%time

with torch.inference_mode():
    for i in range(1000):
        torch.nn.functional.linear(input, weight)
        torch.cuda.synchronize()

CPU times: user 65 ms, sys: 0 ns, total: 65 ms
Wall time: 64 ms
