In [66]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import bitmat

In [67]:
# Every layer and input tensor created below is placed on this device.
device_type = "cuda"
device = torch.device(device_type)

In [14]:
# (removed stale cell) This cell used to run `y_mat = bitlinear(x)` here, but
# `bitlinear` and `x` are only created further down in the notebook, so it raised
# NameError on a fresh kernel ("Restart & Run All"). The same call is executed
# later, once both layers and the input exist.

In [70]:
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learnable per-feature gain.

    Paper: https://arxiv.org/abs/1910.07467

    Args:
        dim: size of the last (feature) dimension of the inputs.
        eps: constant added to the mean square before taking the reciprocal
            square root, to avoid division by zero.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Gain is initialised to one, so a fresh module only normalizes.
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: Tensor) -> Tensor:
        """Scale ``x`` by the reciprocal RMS computed over its last dimension."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` over its last dimension and apply the gain.

        The normalization itself is computed in float32 and cast back to the
        dtype of ``x`` before being multiplied by ``self.weight``.
        """
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def activation_quant(x):
    """Fake-quantize activations to the signed 8-bit range, one scale per token.

    Each slice along the last dimension is scaled so that its largest magnitude
    maps to 127, rounded, clamped to [-128, 127] and scaled back. The result
    keeps the input's shape and stays a floating-point tensor.

    Args:
        x: an activation tensor with shape [n, d]
    Returns:
        a quantized activation tensor with shape [n, d]
    """
    # Largest magnitude per row; the floor keeps an all-zero row from dividing by zero.
    row_absmax = x.abs().max(dim=-1, keepdim=True).values
    scale = 127.0 / torch.clamp(row_absmax, min=1e-5)
    levels = torch.clamp(torch.round(x * scale), -128, 127)
    return levels / scale

def weight_quant(w):
    """Fake-quantize a weight tensor to ternary levels with one per-tensor scale.

    The tensor is divided by its mean absolute value (floored at 1e-5), rounded,
    clamped to {-1, 0, 1} and multiplied back by that mean, so the result has
    the same shape as the input and stays floating point.

    Args:
        w: a weight tensor with shape [d, k]
    Returns:
        a quantized weight with shape [d, k]
    """
    mean_abs = w.abs().mean()
    scale = 1.0 / torch.clamp(mean_abs, min=1e-5)
    ternary = torch.clamp(torch.round(w * scale), -1, 1)
    return ternary / scale


class BitLinear(nn.Linear):
    """Linear layer with fake-quantized activations and weights.

    The input is RMS-normalized and passed through ``activation_quant``; the
    weight is passed through ``weight_quant``. Both use a straight-through
    estimator so gradients reach the unquantized tensors.

    This is only for training, and kernel optimization is needed for efficiency.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None, config=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        # `config` is accepted for interface compatibility but is not used here.
        self.norm = RMSNorm(in_features)

    def forward(self, x):
        """
        Args:
            x: an input tensor whose last dimension is ``in_features``
        Returns:
            y: an output tensor whose last dimension is ``out_features``
        """
        w = self.weight  # nn.Linear weight, shape [out_features, in_features]
        x_norm = self.norm(x)
        # A trick for implementing the Straight-Through Estimator (STE) using detach():
        # the forward value is the quantized tensor, the gradient flows to the original.
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        # self.bias is None when the layer was built with bias=False, which leaves the
        # output unchanged; when a bias exists it is applied instead of being ignored.
        return F.linear(x_quant, w_quant, self.bias)


In [71]:
# Reference layer (defined above) and the bitmat layer under test: same shape, no bias.
blin = BitLinear(in_features=2000, out_features=200, bias=False).to(device)
bitlinear = bitmat.BitLinear(2000, 200, bias=False).to(device)
# Random batch of 8 rows with 2000 features each, created directly on the device.
x = torch.randn((8, 2000), device=device)

In [72]:
# Run bitmat's weight conversion first, then copy the bitmat layer's state into the
# reference layer so both layers are compared on the same parameters.
# NOTE(review): presumably convert_weights_to_parameters() exposes the stored weights
# as regular parameters so the state_dict keys line up with nn.Linear — confirm in bitmat.
bitlinear.convert_weights_to_parameters()
# Last expression of the cell: the result of load_state_dict is displayed
# (the recorded output reports that all keys matched).
blin.load_state_dict(bitlinear.state_dict())

<All keys matched successfully>

In [73]:
# Same input through the reference layer (expected) and the bitmat layer (under test).
y_exp, y_mat = blin(x), bitlinear(x)

In [74]:
# True only if every output element agrees within torch.isclose's default tolerances.
torch.all(torch.isclose(y_exp, y_mat))

tensor(False, device='cuda:0')

In [76]:
# Signed sum of the element-wise differences between the two outputs.
# NOTE(review): positive and negative errors cancel in a signed sum; something like
# (y_exp - y_mat).abs().max() would be a stricter measure of the mismatch.
torch.sum(y_exp - y_mat)

tensor(2.0575, device='cuda:0', grad_fn=<SumBackward0>)

In [77]:
def _apply_norm(layer, inp):
    """Run ``layer.norm`` on ``inp`` cast to the norm weight's dtype, then cast back."""
    norm_dtype = layer.norm.weight.dtype
    return layer.norm(inp.to(norm_dtype)).to(inp.dtype)


# Normalized input as seen by the reference layer and by the bitmat layer.
x_ref = _apply_norm(blin, x)
x_tri = _apply_norm(bitlinear, x)

In [78]:
# Check that both normalization paths give the same result before comparing the matmul.
torch.all(torch.isclose(x_ref, x_tri))

tensor(True, device='cuda:0')

In [85]:
from bitmat.utils.bitmat import bitmat as pi_bitmat

In [95]:
# Call bitmat's low-level `bitmat` op directly with the bitmat layer's weight and the
# normalized input computed above, bypassing the layer's own forward.
# NOTE(review): the third positional argument is passed as None; its meaning is not
# visible in this notebook — confirm against bitmat.utils.bitmat.
out_tri = pi_bitmat(bitlinear.weight,x_tri,None)
# Last expression of the cell: display the result tensor.
out_tri

tensor([[ 0.8170,  1.1203, -0.1639,  ...,  0.2053, -0.8501, -2.2957],
        [ 1.6535, -1.3252,  1.0182,  ...,  0.5661, -0.0032,  0.7885],
        [ 1.9974,  0.1114, -1.1349,  ...,  1.1120, -1.1433, -0.3425],
        ...,
        [-0.5807, -0.2230,  0.7941,  ...,  1.2162, -0.1268, -0.4903],
        [ 0.0493,  0.3854, -0.9208,  ...,  1.3942, -0.2622, -2.5241],
        [-1.3806,  0.6822, -1.7021,  ...,  0.1908, -0.3309,  0.2036]],
       device='cuda:0', grad_fn=<BitMatBackward>)

In [None]:
# Manual replay of the reference forward pass on the already-normalized input.
# `w` only exists as a local inside BitLinear.forward, so it has to be bound here;
# without this line the cell raises NameError.
w = blin.weight
# Straight-through estimator: forward uses the quantized values, gradients flow to
# the unquantized tensors.
x_quant = x_ref + (activation_quant(x_ref) - x_ref).detach()
w_quant = w + (weight_quant(w) - w).detach()
y = F.linear(x_quant, w_quant)