In [2]:
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb
import itertools
from tabulate import tabulate
from bitsandbytes.nn import Linear4bit, Params4bit
from torch import Tensor
from tqdm.auto import tqdm

In [3]:
class BNBLinear(Linear4bit):
    def __init__(
            self,
            input_features,
            output_features,
            bias=False,
            compute_dtype=None,
            compress_statistics=True,
            quant_type='fp4',
            quant_storage=torch.uint8,
            num_bits=8,
            device=None):
        super().__init__(input_features, output_features, bias, device)
        self.weight = Params4bit(self.weight.data,
                                 requires_grad=False,
                                 compress_statistics=compress_statistics,
                                 quant_type=quant_type,
                                 quant_storage=quant_storage,
                                 module=self)
        self.norm: nn.Module = nn.LayerNorm(input_features)
        self.compute_dtype = compute_dtype
        self.compute_type_is_set = False
        self.quant_state = None
        self.quant_storage = quant_storage
        self.eps:float = 1e-5
        self.quantization_range: int = 2 ** (num_bits - 1)
    
    def binarize_weights(self, weights_gamma: float) -> torch.Tensor:
        scaled_weights:torch.Tensor = self.weight / (weights_gamma + self.eps)
        binarized_input_no_grad: torch.Tensor = torch.clamp(torch.round(scaled_weights), min=-1, max=1)
        binarized_input_with_grad: torch.Tensor = (binarized_input_no_grad - self.weight).detach() + self.weight
        return binarized_input_with_grad
    
    def quantize_activations(self, input:torch.Tensor, input_gamma: float) -> torch.Tensor:
        # Equation 4 BitNet paper
        quantized_input = torch.clamp(
                input * self.quantization_range / input_gamma,
                -self.quantization_range + self.eps,
                self.quantization_range - self.eps,
            )
        return quantized_input
    
    def dequantize_activations(self, input: torch.Tensor, input_gamma: float, beta: float) -> torch.Tensor:
        return input * input_gamma * beta / self.quantization_range


    def forward(self, x: torch.Tensor):
        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        if getattr(self.weight, 'quant_state', None) is None:
            if getattr(self, 'quant_state', None) is not None:
                # the quant state got lost when the parameter got converted. This happens for example for fsdp
                # since we registered the module, we can recover the state here
                assert self.weight.shape[1] == 1
                if not isinstance(self.weight, Params4bit):
                    self.weight = Params4bit(self.weight, quant_storage=self.quant_storage)
                self.weight.quant_state = self.quant_state
            else:
                print('FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.')
        if not self.compute_type_is_set:
            self.set_compute_type(x)
            self.compute_type_is_set = True

        inp_dtype = x.dtype
        if self.compute_dtype is not None:
            x = x.to(self.compute_dtype)

        input_gamma = x.abs().max().item()
        weight_abs_mean = self.weight.float().abs().mean().item()
        binarized_weights = self.binarize_weights(weight_abs_mean)
        x = self.quantize_activations(x, input_gamma)

        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
        out = bnb.matmul_4bit(x, binarized_weights.t(), bias=bias, quant_state=self.weight.quant_state)
        out = self.dequantize_activations(out, input_gamma, weight_abs_mean)
        out = out.to(inp_dtype)

        return out


class CiscoHalf(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        num_bits: int = 8,
    ):
        super().__init__(in_features, out_features, bias)
        self.eps:float = 1e-5
        self.quantization_range: int = 2 ** (num_bits - 1) # Q_b in the paper
        self.norm: nn.Module = nn.LayerNorm(in_features)


    def ste_weights(self, weights_gamma: float) -> torch.Tensor:
        eps: float = 1e-8
        scaled_weights:torch.Tensor = self.weight / (weights_gamma + eps)
        binarized_input_no_grad: torch.Tensor = torch.clamp(torch.round(scaled_weights), min=-1, max=1)
        binarized_input_with_grad: torch.Tensor = (binarized_input_no_grad - self.weight).detach() + self.weight
        return binarized_input_with_grad.to(torch.float16)


    def binarize_weights(self, weights_gamma: float) -> torch.Tensor:
        binarized_weights = self.ste_weights(weights_gamma)
        return binarized_weights


    def quantize_activations(self, _input:torch.Tensor, input_gamma: float) -> torch.Tensor:
        # Equation 4 BitNet paper
        quantized_input = torch.clamp(
                _input * self.quantization_range / input_gamma,
                -self.quantization_range + self.eps,
                self.quantization_range - self.eps,
            )
        return quantized_input.to(torch.float16)


    def dequantize_activations(self, _input: torch.Tensor, input_gamma: float, beta: float) -> torch.Tensor:
        return _input * input_gamma * beta / self.quantization_range


    def forward(self, _input: torch.Tensor) -> torch.Tensor:
        normalized_input: torch.Tensor = self.norm(_input)
        input_gamma: float = normalized_input.abs().max().item()
        weight_abs_mean: float = self.weight.abs().mean().item()

        binarized_weights = self.binarize_weights(weight_abs_mean)
        input_quant = self.quantize_activations(normalized_input, input_gamma)
        output = F.linear(input_quant, binarized_weights, self.bias)
        output = self.dequantize_activations(output, input_gamma, weight_abs_mean)

        return output


class Cisco(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        num_bits: int = 8,
    ):
        super().__init__(in_features, out_features, bias)
        self.eps:float = 1e-5
        self.quantization_range: int = 2 ** (num_bits - 1) # Q_b in the paper
        self.norm: nn.Module = nn.LayerNorm(in_features)


    def ste_weights(self, weights_gamma: float) -> torch.Tensor:
        eps: float = 1e-8
        scaled_weights:torch.Tensor = self.weight / (weights_gamma + eps)
        binarized_input_no_grad: torch.Tensor = torch.clamp(torch.round(scaled_weights), min=-1, max=1)
        binarized_input_with_grad: torch.Tensor = (binarized_input_no_grad - self.weight).detach() + self.weight
        return binarized_input_with_grad


    def binarize_weights(self, weights_gamma: float) -> torch.Tensor:
        binarized_weights = self.ste_weights(weights_gamma)
        return binarized_weights


    def quantize_activations(self, _input:torch.Tensor, input_gamma: float) -> torch.Tensor:
        # Equation 4 BitNet paper
        quantized_input = torch.clamp(
                _input * self.quantization_range / input_gamma,
                -self.quantization_range + self.eps,
                self.quantization_range - self.eps,
            )
        return quantized_input


    def dequantize_activations(self, _input: torch.Tensor, input_gamma: float, beta: float) -> torch.Tensor:
        return _input * input_gamma * beta / self.quantization_range


    def forward(self, _input: torch.Tensor) -> torch.Tensor:
        normalized_input: torch.Tensor = self.norm(_input)
        input_gamma: float = normalized_input.abs().max().item()
        weight_abs_mean: float = self.weight.abs().mean().item()

        binarized_weights = self.binarize_weights(weight_abs_mean)
        input_quant = self.quantize_activations(normalized_input, input_gamma)
        output = F.linear(input_quant, binarized_weights, self.bias)
        output = self.dequantize_activations(output, input_gamma, weight_abs_mean)

        return output

In [4]:
def test_forward_speed(layer_class, input_sizes, num_runs=1000):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    results = {}

    for size in tqdm(input_sizes):
        in_features = size[-1]
        out_features = size[-1]
        layer = layer_class(in_features, out_features).to(device)
        input_tensor = torch.randn(size).to(device)

        # Warm-up run
        _ = layer(input_tensor)

        # Timing runs
        times = []
        for _ in range(num_runs):
            start_time = time.perf_counter()
            _ = layer(input_tensor)
            end_time = time.perf_counter()
            times.append(end_time - start_time)

        avg_time = sum(times) / num_runs
        results[size] = avg_time

    return results

# Example usage
input_sizes = [
    (1, 64),
    (1, 640),
    (1, 6400),
    (8, 64),
    (8, 640),
    (8, 6400),
    (128, 64),
    (128, 640),
    (128, 6400),
    (256, 12800), 
]

layer_classes = [nn.Linear, BNBLinear, CiscoHalf, Cisco]  # Add more layer classes as needed
layer_results = {}

for layer_class in layer_classes:
    layer_results[layer_class.__name__] = test_forward_speed(layer_class, input_sizes)

print("Speed comparison matrix:")
headers = ["Input Size"] + [cls.__name__ for cls in layer_classes]
table_data = []

for size in input_sizes:
    row = [str(size)]
    for layer_class in layer_classes:
        row.append(f"{layer_results[layer_class.__name__][size]:.6f}")
    table_data.append(row)

print(tabulate(table_data, headers, tablefmt="grid"))

table_data = []
for layer1 in layer_classes:
    row = [layer1.__name__]
    for layer2 in layer_classes:
        if layer1 == layer2:
            row.append("1.00")
        else:
            time1_sum = sum(layer_results[layer1.__name__].values())
            time2_sum = sum(layer_results[layer2.__name__].values())
            ratio = time1_sum / time2_sum
            row.append(f"{ratio:.2f}")
    table_data.append(row)

print(tabulate(table_data, headers, tablefmt="grid"))

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

Speed comparison matrix:
+--------------+----------+-------------+-------------+----------+
| Input Size   |   Linear |   BNBLinear |   CiscoHalf |    Cisco |
| (1, 64)      |  3.3e-05 |    0.000491 |    0.000496 | 0.000469 |
+--------------+----------+-------------+-------------+----------+
| (1, 640)     |  3.3e-05 |    0.000527 |    0.000588 | 0.00053  |
+--------------+----------+-------------+-------------+----------+
| (1, 6400)    |  3.3e-05 |    0.006594 |    0.009963 | 0.009347 |
+--------------+----------+-------------+-------------+----------+
| (8, 64)      |  3.1e-05 |    0.000619 |    0.000547 | 0.000424 |
+--------------+----------+-------------+-------------+----------+
| (8, 640)     |  3.2e-05 |    0.000674 |    0.000647 | 0.000507 |
+--------------+----------+-------------+-------------+----------+
| (8, 6400)    |  3.8e-05 |    0.007255 |    0.009561 | 0.008994 |
+--------------+----------+-------------+-------------+----------+
| (128, 64)    |  4e-05   |    0.0005