In [4]:
from dmx.compressor.modeling.hf import DmxModel
import torch


class Submod(torch.nn.Module):
    """Small two-layer MLP used as a test submodule.

    Computes ``lin2(act(lin1(x + y)))``. When ``relu`` is falsy the activation
    is skipped, giving ``forward`` two distinct control-flow paths.
    """

    def __init__(self, indim, hiddim, outdim) -> None:
        super().__init__()
        # Registration order (lin1, act, lin2) fixes the state_dict key order
        # and the order in which parameter initialisation consumes the RNG.
        self.lin1 = torch.nn.Linear(indim, hiddim)
        self.act = torch.nn.ReLU()
        self.lin2 = torch.nn.Linear(hiddim, outdim)

    def forward(self, x, y, relu=True):
        summed = x + y
        hidden = self.lin1(summed)
        hidden = self.act(hidden) if relu else hidden
        return self.lin2(hidden)


class CustomModel(torch.nn.Module):
    """Toy model chaining two ``Submod`` blocks with a GELU in between.

    Each ``Submod`` receives the same tensor for both of its inputs, and the
    ``relu`` flag is forwarded to both so callers can switch control flow.
    """

    def __init__(self):
        super().__init__()
        # Submod(indim, hiddim, outdim): sm1 is 160 -> 6400 -> 6400,
        # sm2 is 6400 -> 12800 -> 10.
        self.sm1 = Submod(160, 6400, 6400)
        self.act = torch.nn.GELU()
        self.sm2 = Submod(6400, 12800, 10)

    def forward(self, x, relu=True):
        hidden = self.act(self.sm1(x, x, relu))
        return self.sm2(hidden, hidden, relu)

Before quantization, the DMX-wrapped model matches the baseline model. After quantization, the outputs of both the main model and its submodules differ from the baseline.

In [None]:
# Seed once so the random weights and input (and therefore every assertion
# below) are reproducible across Restart & Run All.
torch.manual_seed(0)
model0 = CustomModel()
model = CustomModel()
# Copy weights so `model` starts out numerically identical to the reference `model0`.
model.load_state_dict(model0.state_dict())
model = DmxModel.from_torch(model)
inp = torch.rand((1, 160))
ref_output = model0(inp)
# Right after wrapping (before any mode switch) the DMX model must match the reference.
assert torch.allclose(model(inp), ref_output)
# Switching to basic mode changes the numerics, so the whole-model output must now differ...
model.to_basic_mode()
assert not torch.allclose(model(inp), ref_output)
# ...and so must the output of a submodule called on its own.
assert not torch.allclose(model0.sm1(inp, inp), model.sm1(inp, inp))

The output of the whole model is equivalent to running its submodules sequentially.

In [None]:

# `model` is in basic mode (set in the previous cell).
basic_output = model(inp)
assert not torch.allclose(ref_output, basic_output)
# Re-run the forward pass by hand, submodule by submodule. `relu` is left at its
# default (True), matching the `model(inp)` call above.
basic_output_from_submod = model.sm1(inp, inp)
basic_output_from_submod = model.act(basic_output_from_submod)
basic_output_from_submod = model.sm2(basic_output_from_submod, basic_output_from_submod)
# Assert (rather than merely evaluate) the equivalence so a mismatch fails the cell.
assert torch.allclose(basic_output, basic_output_from_submod)

DmxModules are shared across the `_gm` of different model components.

In [7]:
# `_gm` is presumably the DMX-transformed graph module of each wrapper (TODO confirm).
# The submodule's lin1 must be the very same object (identity, not a copy) as the
# lin1 reached through the top-level model's `_gm`.
assert model.sm1._gm.lin1 is model._gm.sm1.lin1
# Likewise, the top-level activation wrapper's `_gm` is the module held by `model._gm`.
assert model.act._gm is model._gm.act

Configure the model to baseline mode; the output of the whole model is still equivalent to running its submodules sequentially.

In [None]:
# Back to baseline: the wrapped model must reproduce the reference output again.
model.to_baseline_mode()
baseline_output = model(inp)
assert torch.allclose(ref_output, baseline_output)
# Re-run the forward pass by hand, submodule by submodule (relu left at its default, True).
baseline_output_from_submod = model.sm1(inp, inp)
baseline_output_from_submod = model.act(baseline_output_from_submod)
baseline_output_from_submod = model.sm2(baseline_output_from_submod, baseline_output_from_submod)
# Assert (rather than merely evaluate) the equivalence so a mismatch fails the cell.
assert torch.allclose(baseline_output, baseline_output_from_submod)

A change in control flow (here `relu=False`) triggers retransformation of the submodule.

In [None]:
# Third positional argument is `relu=False`, taking the other branch in Submod.forward.
# In baseline mode the wrapped submodule must still match the reference on this path;
# presumably this call forces DMX to re-transform sm1 for the new control flow — TODO confirm.
assert torch.allclose(model0.sm1(inp,inp, False),model.sm1(inp,inp, False))

Quantizing a submodule changes the output of the main model.

In [10]:
# Switch only sm1 to basic mode; the rest of `model` was last set to baseline mode.
model.sm1.to_basic_mode()
# The quantized submodule diverges from its reference on the relu=False path...
assert not torch.allclose(model0.sm1(inp,inp, False),model.sm1(inp,inp, False))
# ...and the change is visible through the whole model's forward pass as well.
assert not torch.allclose(model(inp),model0(inp))

In [None]:
# Inspect the top-level `_gm` (presumably the transformed graph module).
model._gm

In [None]:
# Display the DMX-wrapped model's repr.
model

### Whisper

In [None]:
from transformers import AutoProcessor, WhisperForConditionalGeneration
from dmx.compressor.modeling.hf import DmxModel
from datasets import load_dataset
import torch
WHISPER_CHECKPOINT = "openai/whisper-tiny"


def load_whisper(checkpoint: str = WHISPER_CHECKPOINT) -> WhisperForConditionalGeneration:
    """Load a float16, eager-attention Whisper model directly onto the GPU.

    ``device_map="cuda"`` already places the weights on the GPU, so no extra
    ``.to("cuda")`` is needed.
    """
    return WhisperForConditionalGeneration.from_pretrained(
        checkpoint,
        torch_dtype=torch.float16,
        device_map="cuda",
        attn_implementation="eager",
    )


processor = AutoProcessor.from_pretrained(WHISPER_CHECKPOINT)
# Two independent copies: `model0` stays an untouched HF reference,
# `model` is wrapped by DMX.
model0 = load_whisper()
model = DmxModel.from_torch(load_whisper())

# Only the first validation sample is used, so pull it lazily from the stream
# instead of materializing (downloading) the whole split with list(...).
librispeech_stream = load_dataset(
    "librispeech_asr",
    "clean",
    split="validation",
    streaming=True,
)
audio_sample = next(iter(librispeech_stream))["audio"]
input_features = processor(
    audio_sample["array"],
    sampling_rate=audio_sample["sampling_rate"],
    return_tensors="pt",
).input_features
input_features = input_features.to("cuda", dtype=torch.float16)
# Seed so the random 2-token decoder prompt is reproducible.
torch.manual_seed(0)
decoder_input_ids = torch.randint(0, 100, (1, 2)).to("cuda")
# The reference output is only compared against, never backpropagated.
with torch.no_grad():
    ref_output = model0(input_features, decoder_input_ids=decoder_input_ids)


In [None]:
# Before any quantization is configured, the DMX-wrapped Whisper must reproduce the
# reference logits within torch.allclose's default tolerances.
assert torch.allclose(ref_output.logits,model(input_features, decoder_input_ids=decoder_input_ids).logits)

In [15]:
# Explicit imports instead of `import *`, so it is clear where these names come from.
from dmx.compressor.modeling.model import DmxConfigRule, Linear

# Weight format assigned to every DMX `Linear` module by the rule below.
bfp14 = "BFP[6|8]{64}(SN)"
rules = (
    DmxConfigRule(
        module_types=(Linear,),
        module_config=dict(
            weight_format=bfp14,
        ),
    ),
)
model.configure(None, *rules)
# With quantized Linear weights, the logits must no longer match the reference.
assert not torch.allclose(model(input_features, decoder_input_ids=decoder_input_ids).logits, ref_output.logits)

In [None]:
# Calling the encoder submodule on its own: with quantized Linear weights its last
# hidden state must diverge from the reference encoder's.
assert not torch.allclose(model.model.encoder(input_features).last_hidden_state,model0.model.encoder(input_features).last_hidden_state)

`encoder_last_hidden_state` of the quantized main model

In [None]:
# Encoder hidden state as produced inside the full (quantized) model's forward pass.
model(input_features, decoder_input_ids=decoder_input_ids).encoder_last_hidden_state

`encoder_last_hidden_state` of the unquantized HF model

In [None]:
# Encoder hidden state of the untouched HF reference model.
ref_output.encoder_last_hidden_state

Output of the quantized encoder called on its own

In [None]:
# Output of the quantized encoder submodule called directly; compare with the two cells above.
model.model.encoder(input_features)

### CLIP

In [None]:
from PIL import Image
import requests
from transformers import CLIPProcessor, CLIPModel

CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

model0 = CLIPModel.from_pretrained(CLIP_CHECKPOINT, attn_implementation="eager")
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Fail fast on network problems: a timeout avoids hanging forever, and the status
# check avoids handing PIL an HTML error page instead of a JPEG.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

inputs = processor(
    text=["a photo of a cat", "a photo of a dog"],
    images=image,
    return_tensors="pt",
    padding=True,
)
model0.eval()
# The reference output is only compared against, never backpropagated.
with torch.no_grad():
    outputs_ref = model0(**inputs)
# NOTE(review): model0 requests attn_implementation="eager" but this copy uses the
# default implementation; confirm the asymmetry is intended for the DMX comparison.
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
model = DmxModel.from_torch(model)
model.eval()

In [None]:
# Freshly wrapped (no mode switch yet): the DMX CLIP must reproduce the reference
# image-to-text logits.
assert torch.allclose(outputs_ref.logits_per_image, model(**inputs).logits_per_image)
# Basic mode changes the numerics, so the logits must now differ.
model.to_basic_mode()
assert not torch.allclose(outputs_ref.logits_per_image, model(**inputs).logits_per_image)

In [None]:
# Back to baseline: each tower called on its own must match the reference tower.
model.to_baseline_mode()
# NOTE(review): the text tower is compared with atol=1e-6 while the vision tower uses
# allclose's default tolerances; confirm why the text path needs the looser atol.
assert torch.allclose(model.text_model(inputs['input_ids']).pooler_output,model0.text_model(inputs['input_ids']).pooler_output,atol=1e-6)
assert torch.allclose(model.vision_model(inputs['pixel_values']).pooler_output,model0.vision_model(inputs['pixel_values']).pooler_output)

In [None]:
# Basic mode again: each tower called on its own must fail allclose against the
# reference even with the looser atol=1e-3.
model.to_basic_mode()
assert not torch.allclose(model.text_model(inputs['input_ids']).pooler_output,model0.text_model(inputs['input_ids']).pooler_output,atol=1e-3)
assert not torch.allclose(model.vision_model(inputs['pixel_values']).pooler_output,model0.vision_model(inputs['pixel_values']).pooler_output,atol=1e-3)

### Llama

In [None]:
from transformers import AutoModelForCausalLM
import torch
from dmx.compressor.modeling import DmxModel
model_name = "d-matrix/Llama-3.2-1B"
# NOTE(review): only model0 requests attn_implementation="eager"; confirm that the
# copy wrapped by DMX is meant to use the default attention implementation.
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float, device_map="cuda", trust_remote_code=True
)
model0 = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float,
    device_map="cuda",
    attn_implementation="eager",
    trust_remote_code=True,
)
model = DmxModel.from_torch(model)
model.eval()
model0.eval()
# Seed so the random prompt (and every hidden state derived from it) is reproducible.
torch.manual_seed(0)
input_ids = torch.randint(0, 100, (1, 8)).to("cuda:0")
with torch.no_grad():
    # Reference hidden states, reused below as the input to a single decoder layer.
    submod_input = model0.model(input_ids).last_hidden_state
    # Output intentionally discarded; presumably this full forward pass lets DMX
    # trace/transform the model before its submodules are called — TODO confirm.
    model(input_ids)

In [None]:
seq_len = input_ids.shape[1]
# Additive causal mask of shape (1, 1, seq_len, max_position_embeddings):
# 0 on/below the diagonal, -inf above it.
causal_mask = torch.full(
    (1, 1, seq_len, model.config.max_position_embeddings),
    fill_value=-torch.inf,
    device="cuda",
)
causal_mask = torch.triu(causal_mask, diagonal=1)
# torch.arange yields integer (int64) ids 0..seq_len-1. The deprecated torch.range
# returned float32 with an inclusive end point.
position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
with torch.no_grad():
    ref_output = model0.model.layers[1](submod_input, position_ids=position_ids, attention_mask=causal_mask)[0]
    output = model.model.layers[1](submod_input, position_ids=position_ids, attention_mask=causal_mask)[0]
# Unquantized, decoder layer 1 of the DMX model must match the reference (rtol spelled out).
assert torch.allclose(ref_output, output, rtol=1e-4)

In [26]:
# Explicit imports instead of `import *`, so it is clear where these names come from.
from dmx.compressor.modeling.model import DmxConfigRule, Linear

# Weight format assigned to every DMX `Linear` module by the rule below.
bfp14 = "BFP[6|8]{64}(SN)"
rules = (
    DmxConfigRule(
        module_types=(Linear,),
        module_config=dict(
            weight_format=bfp14,
        ),
    ),
)
model.configure(None, *rules)
with torch.no_grad():
    ref_output = model0.model.layers[1](submod_input, position_ids=position_ids, attention_mask=causal_mask)[0]
    output = model.model.layers[1](submod_input, position_ids=position_ids, attention_mask=causal_mask)[0]
# With quantized Linear weights, decoder layer 1 must no longer match within rtol=1e-4.
assert not torch.allclose(ref_output, output, rtol=1e-4)