In [None]:
# --- Standard library -------------------------------------------------------
import json
import os
import sys
import typing as tp
from dataclasses import dataclass
from functools import cache

# --- Third-party ------------------------------------------------------------
import torch
from torch import nn
from torch.nn import functional as F

from transformers import AutoConfig, AutoTokenizer
from transformers.models.mixtral import MixtralForCausalLM, MixtralConfig
from transformers.utils import logging as hf_logging

from safetensors.torch import load_file
from huggingface_hub import snapshot_download
from hqq.core.quantize import BaseQuantizeConfig

from IPython.display import clear_output
from tqdm.auto import trange

# --- Project-local (`src`) --------------------------------------------------
# The path entry must be registered BEFORE any `src.*` import executes;
# otherwise a fresh kernel started outside the checkout cannot resolve `src`
# (presumably the package lives in the "mixtral-offloading" directory --
# confirm against your layout). The membership check keeps re-running this
# cell from piling up duplicate entries in sys.path.
if "mixtral-offloading" not in sys.path:
    sys.path.append("mixtral-offloading")

from src.build_model import OffloadConfig, QuantConfig, build_model
from src.custom_layers import (
    HQQLinearTritonSavable,
    MixtralBLockSparseTop2MLP_HQQ,
    SparseMoeWrapper,
)
from src.expert_cache import ExpertCache
from src.expert_wrapper import MixtralExpertWrapper
from src.utils import with_default_dtype

In [None]:
# Runtime knobs: plain constants, no I/O involved.
device = torch.device("cpu")
offload_per_layer = 4

# Hub repository ids for the reference model and for the checkpoint whose
# configuration is loaded below. `state_path` is presumably the local
# directory name of that checkpoint -- it is not used in this part of the
# notebook.
model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
quantized_model_name = "lavawolfiee/Mixtral-8x7B-Instruct-v0.1-offloading-demo"
state_path = "Mixtral-8x7B-Instruct-v0.1-offloading-demo"

# Load the model configuration for `quantized_model_name` and read the
# expert count straight from it.
config = AutoConfig.from_pretrained(quantized_model_name)
num_experts = config.num_local_experts

In [None]:
# Attention quantization settings: nbits=4, group_size=64, with the
# `quant_zero` and `quant_scale` flags switched on.
attn_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=True, quant_scale=True)
# Override the group size stored under the nested "scale_quant_params" entry.
attn_config["scale_quant_params"]["group_size"] = 256

# Feed-forward quantization settings: nbits=2, group_size=16, same flags.
ffn_config = BaseQuantizeConfig(nbits=2, group_size=16, quant_zero=True, quant_scale=True)

# Bundle both settings into the project's QuantConfig container.
quant_config = QuantConfig(attn_config=attn_config, ffn_config=ffn_config)

In [None]:
# Obtain the two FFN meta objects for this model's hidden / intermediate sizes.
meta1, meta2 = quant_config.get_ffn_metas(config.hidden_size, config.intermediate_size)

# Instantiate a single HQQ expert MLP block from the model config, the FFN
# quantization settings and the two metas, then display its repr.
mblock = MixtralBLockSparseTop2MLP_HQQ(
    config,
    quant_config.ffn_config,
    meta1,
    meta2,
)
mblock

In [None]:
# Seed torch's global RNG first: both the nn.Linear initialisation and the
# random input below draw from it, so their order must not change.
torch.manual_seed(0)

# Wrap a bias-free 128x128 linear layer with the FFN quantization settings.
hqqlayer = HQQLinearTritonSavable(
    nn.Linear(128, 128, bias=False),
    quant_config.ffn_config,
    use_gpu=False,
)

# One random row of 128 features, pushed through the wrapped layer.
inp = torch.randn(1, 128)
out = hqqlayer(inp)
out

In [None]:
# Second forward pass over the same input; the displayed tensor can be
# compared by eye with `out` from the previous cell (presumably a
# repeatability check -- confirm intent).
hqqlayer(inp)