In [1]:
%env CUDA_VISIBLE_DEVICES=6

import math

import torch
from linear import HiggsLinear

env: CUDA_VISIBLE_DEVICES=6


In [2]:
# Benchmark configuration.
# SIZE is passed as the first two positional arguments of the layer below
# (presumably its input and output feature counts — confirm against HiggsLinear).
SIZE = 4096
# Chunk length used for the Hadamard transform and for grouping the dequantized
# weight in the reference computation further down.
HADAMARD_SIZE = 1024

# Layer under test, created on the GPU in float16.
# NOTE(review): the positional arguments 2 and 256 are not explained here; 256
# presumably matches the "EDEN2-256" grid loaded later and 2 its vector
# dimension — confirm against HiggsLinear's signature.
layer = HiggsLinear(
    SIZE, SIZE,
    2, 256,
    device="cuda",
    dtype=torch.float16,
)

In [3]:
# Random float16 activation of shape (1, 1, SIZE) on the GPU, drawn uniformly from [0, 1).
# NOTE(review): the name `input` shadows the Python builtin, and no seed is set,
# so the values (and the outputs displayed below) differ between runs.
input = torch.rand(1, 1, SIZE, device="cuda", dtype=torch.float16)

# Forward pass through the layer; as the last expression, the result is displayed.
layer(input)

tensor([[[ 9.6875, 47.5625,  6.3594,  ..., 20.6406, 35.1250,  8.8281]]],
       device='cuda:0', dtype=torch.float16)

In [4]:
from aqlm.utils import unpack_int_data, _dequantize_weight

In [5]:
@torch.inference_mode()
def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
    """Widen integer codes to int64 and reduce them modulo ``2**nbits``.

    Because the remainder takes the sign of the (positive) divisor, negative
    values of a signed storage dtype map to their unsigned equivalents,
    e.g. int8 ``-1`` becomes ``255`` for ``nbits=8``.

    NOTE(review): this definition shadows the same-named function imported
    from ``aqlm.utils`` in the preceding cell.

    Args:
        data: Integer tensor holding the stored codes.
        nbits: Number of bits per code; results lie in ``[0, 2**nbits)``.

    Returns:
        An int64 tensor with the same shape as ``data``.
    """
    widened = data.to(torch.int64)
    modulus = 2**nbits
    return torch.remainder(widened, modulus)

# Load the grid tensor from disk and convert it to float16.
# NOTE(review): torch.load unpickles the file, so only load trusted files;
# consider weights_only=True (and an explicit map_location) where the installed
# torch version supports them.
codebooks = torch.load("../grids/EDEN2-256.pt").half()

# Rebuild a dense reference weight from the layer's stored codes:
# unpack the codes as 8-bit values, decode them through the codebook, group the
# result into chunks of HADAMARD_SIZE, and multiply by layer.scales broadcast
# over the last axis (presumably one scale per chunk — confirm).
weight = _dequantize_weight(
    unpack_int_data(layer.codes, 8)[:,:,None],
    codebooks[None,:,None,:],
    # Scales are not passed to the helper; they are applied by the multiplication below.
).reshape(-1, SIZE//HADAMARD_SIZE, HADAMARD_SIZE) * layer.scales[...,None]

# Flatten the chunks back into rows of length SIZE.
weight = weight.reshape(-1, SIZE)

In [6]:
from fast_hadamard_transform import hadamard_transform

# Reference computation: apply the Hadamard transform (scaled by
# 1/sqrt(HADAMARD_SIZE)) to the input in chunks of HADAMARD_SIZE, flatten the
# result to a single row, and multiply by the transposed dequantized weight.
# The displayed result can be compared by eye with the layer(input) output above.
hadamard_transform(input.reshape(-1, HADAMARD_SIZE), scale=1/math.sqrt(HADAMARD_SIZE)).reshape(1, -1) @ weight.T

tensor([[ 9.6875, 47.5625,  6.3555,  ..., 20.6406, 35.1250,  8.8203]],
       device='cuda:0', dtype=torch.float16)

In [7]:
%%time

# Time 1000 forward passes of the layer with autograd disabled.
with torch.inference_mode():
    for i in range(1000):
        layer(input)
        # Block until the GPU has finished this call, so the measured wall time
        # covers kernel execution rather than only the asynchronous launch.
        torch.cuda.synchronize()

CPU times: user 160 ms, sys: 92 µs, total: 160 ms
Wall time: 158 ms


In [8]:
%%time

# Baseline: time 1000 dense float16 matmuls against the dequantized weight,
# with autograd disabled. Note that this omits the Hadamard transform of the input.
with torch.inference_mode():
    for i in range(1000):
        torch.nn.functional.linear(input, weight)
        # Block until the GPU has finished this call, so the measured wall time
        # covers kernel execution rather than only the asynchronous launch.
        torch.cuda.synchronize()

CPU times: user 33.1 ms, sys: 506 µs, total: 33.7 ms
Wall time: 31.9 ms
