In [None]:
# Pin this kernel to GPU index 5 (adjust for your machine). It is set before
# `import torch` below so that CUDA only ever sees this device.
%env CUDA_VISIBLE_DEVICES=5

import torch
from torch import nn

from linear import HiggsLinear

def replace_with_higgs_linear(
    model,
    quantization_config=None,
    current_key_name=None,
    has_been_replaced=False,
):
    """Recursively swap selected ``nn.Linear`` submodules of ``model`` for ``HiggsLinear``.

    The module tree is modified in place: each ``nn.Linear`` whose fully-qualified
    name (e.g. ``"model.layers.0.mlp.up_proj"``) is a key of
    ``quantization_config`` is replaced by a ``HiggsLinear``. The replacement has
    the same in/out features, bias presence and weight dtype, and the config value
    is passed as its HIGGS dimension.

    Args:
        model: Root ``nn.Module`` to walk.
        quantization_config: Mapping ``{qualified_module_name: higgs_d}``.
            ``None`` is treated as an empty mapping, in which case nothing is
            replaced.
        current_key_name: Recursion state holding the name components leading
            to ``model``. Callers should normally leave it as ``None``.
        has_been_replaced: Recursion state recording whether any layer has been
            replaced so far.

    Returns:
        ``(model, has_been_replaced)``: the same ``model`` object and whether at
        least one layer was replaced.
    """
    if quantization_config is None:
        quantization_config = {}
    if current_key_name is None:
        current_key_name = []

    for name, module in model.named_children():
        current_key_name.append(name)
        full_name = ".".join(current_key_name)

        # Only linear layers explicitly listed in the config are quantized.
        if isinstance(module, nn.Linear) and full_name in quantization_config:
            new_module = HiggsLinear(
                module.in_features,
                module.out_features,
                quantization_config[full_name],
                bias=module.bias is not None,
                dtype=module.weight.dtype,
            )
            # Store the module class in case we need to transpose the weight later
            new_module.source_cls = type(module)
            # Force requires grad to False to avoid unexpected errors
            new_module.requires_grad_(False)
            model._modules[name] = new_module
            has_been_replaced = True

        # `module` is still the original child here; nn.Linear has no children,
        # so only container modules are descended into.
        if len(list(module.children())) > 0:
            _, has_been_replaced = replace_with_higgs_linear(
                module,
                quantization_config=quantization_config,
                current_key_name=current_key_name,
                has_been_replaced=has_been_replaced,
            )
        # Remove this child's name before moving on to its sibling.
        current_key_name.pop(-1)
    return model, has_been_replaced

In [2]:
from transformers import AutoModelForCausalLM


# Load the fp16 base model straight onto the GPU. A later cell swaps its
# linear layers for HiggsLinear. To try eager attention instead of the default
# implementation, add attn_implementation="eager" to the call below.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B",
    device_map="cuda",
    torch_dtype=torch.float16,
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
from typing import Optional

def build_layerwise_edenn_config(
    edenn_d: Optional[int] = None,
    blockwise_edenn_config: Optional[list[int]] = None,
    layerwise_edenn_config: Optional[dict[str, int]] = None,
    num_layers: int = 32,
) -> dict[str, int]:
    """Build a ``{qualified_linear_name: dimension}`` map for a Llama-style model.

    Exactly one source of dimensions is used, in this priority:

    * ``layerwise_edenn_config``: returned unchanged (the other two must be ``None``).
    * ``blockwise_edenn_config``: one value per decoder block; block ``i``'s value
      is applied to all seven attention/MLP projections of ``model.layers.{i}``.
    * ``edenn_d``: the same value for every projection of ``num_layers`` blocks.

    Args:
        edenn_d: Uniform dimension used when no per-block list is given.
        blockwise_edenn_config: Per-block dimensions; its length sets the block count.
        layerwise_edenn_config: A ready-made per-layer mapping.
        num_layers: Number of decoder blocks when expanding ``edenn_d``
            (default 32, the previously hard-coded value).

    Returns:
        Dict keyed by names such as ``"model.layers.0.self_attn.q_proj"``.

    Raises:
        AssertionError: If the arguments are inconsistent or none is provided.
    """
    if layerwise_edenn_config is not None:
        assert edenn_d is None and blockwise_edenn_config is None, (
            "layerwise_edenn_config cannot be combined with edenn_d or blockwise_edenn_config"
        )
        return layerwise_edenn_config

    if blockwise_edenn_config is None:
        assert edenn_d is not None, (
            "one of edenn_d, blockwise_edenn_config or layerwise_edenn_config is required"
        )
        blockwise_edenn_config = [edenn_d] * num_layers

    layer_names = [
        "self_attn.q_proj",
        "self_attn.k_proj",
        "self_attn.v_proj",
        "self_attn.o_proj",
        "mlp.gate_proj",
        "mlp.up_proj",
        "mlp.down_proj",
    ]

    return {
        f"model.layers.{i}.{layer_name}": blockwise_edenn_config[i]
        for layer_name in layer_names
        for i in range(len(blockwise_edenn_config))
    }


In [None]:
import os

from accelerate import load_checkpoint_and_dispatch

# HIGGS dimension applied to every quantized linear layer. Set to -1 to skip
# quantization and keep the fp16 model loaded earlier.
DIM = 1
# Expand "~" explicitly: accelerate resolves the checkpoint path with plain
# os.path checks, which do not expand the home directory.
HIGGS_CHECKPOINT = os.path.expanduser("~/models/higgs/Meta-Llama-3.1-8B.pt")

if DIM != -1:
    model, replaced_any = replace_with_higgs_linear(
        model,
        build_layerwise_edenn_config(DIM),
    )
    # Fail fast if no module name matched the config (e.g. a different architecture),
    # rather than loading a quantized checkpoint into an unmodified model.
    assert replaced_any, "no nn.Linear layers matched the HIGGS config"
    model = load_checkpoint_and_dispatch(
        model,
        HIGGS_CHECKPOINT,
    ).to("cuda")

In [5]:
from transformers import AutoTokenizer

# The tokenizer comes from the same hub repository as the model.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="meta-llama/Meta-Llama-3.1-8B",
)

In [None]:
# Sanity check before compiling: generate from a short prompt and decode the
# first returned sequence back to text.
prompt_batch = tokenizer("Hi!", return_tensors='pt').to("cuda")
output_ids = model.generate(**prompt_batch)
tokenizer.decode(output_ids[0].cpu())

In [7]:
# Compile the forward pass. mode="reduce-overhead" uses CUDA graphs, and
# fullgraph=True turns any graph break into an error instead of a silent fallback.
# torch.compile is lazy, so nothing is traced until the next call.
model.forward = torch.compile(
    model.forward,
    fullgraph=True,
    mode="reduce-overhead",
)
# A static KV cache keeps tensor shapes fixed across decoding steps, which the
# compiled, CUDA-graphed forward relies on during generate().
model.generation_config.cache_implementation = "static"

In [None]:
# First generate() after compiling. Because torch.compile is lazy, this call
# pays the compilation cost.
warmup_batch = tokenizer("Hi!", return_tensors='pt').to("cuda")
model.generate(**warmup_batch)

In [None]:
with torch.no_grad():
    for _ in range(10):
        model(**tokenizer("Hi!", return_tensors='pt').to("cuda"))
    
    # benchmarking with jupyter macro
    %timeit model(**tokenizer("Hi!", return_tensors='pt').to("cuda")); torch.cuda.synchronize()