## This notebook demonstrates how to run quantization calibration for INT formats

In [None]:
from dmx.compressor.modeling.hf import pipeline

# Text-generation pipeline around the OPT-125m revision of d-matrix/opt,
# starting from the "BASELINE" dmx config.
pipeline_kwargs = dict(
    task="text-generation",
    model="d-matrix/opt",
    revision="opt-125m",
    dmx_config="BASELINE",
    trust_remote_code=True,
    device_map="auto",  # enables model parallelism on multi-GPU nodes
)
pipe = pipeline(**pipeline_kwargs)

The next cell configures the model's modules to use the target numerical format.

Each `*_format` argument (e.g. `weight_format`) takes a single format value.

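`input_formats` takes either a list or a dictionary. When a list is passed, the formats are assigned to the CastTo modules in `input_casts`, in order. When a dictionary is passed, each format is assigned to the cast with the matching name (e.g. `"input_cast"`).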

In [None]:
from dmx.compressor.modeling import DmxConfigRule, nn

# Numerical format applied to every cast configured below. It is named
# `quant_format` rather than `format` so that the Python builtin `format()` is
# not shadowed for the rest of the notebook.
quant_format = "XP[8,0](CSN)"

rules = (
    # Linear layers: input activation and weight.
    DmxConfigRule(
        module_types=(nn.Linear,),
        module_config=dict(
            input_formats=[quant_format],  # option 1: list, applied in input_casts order
            # input_formats={"input_cast": quant_format},  # option 2: dict keyed by cast name
            weight_format=quant_format,
        ),
    ),
    # Scaled dot-product attention: query, key and value inputs
    # (option 2 shows that attn_mask_cast is left as None).
    DmxConfigRule(
        module_types=(nn.ScaledDotProductAttention,),
        module_config=dict(
            input_formats=[quant_format, quant_format, quant_format],  # option 1
            # input_formats={
            #     "query_states_cast": quant_format,
            #     "key_states_cast": quant_format,
            #     "value_states_cast": quant_format,
            #     "attn_mask_cast": None,
            # },  # option 2
            # NOTE(review): confirm that weight_format has an effect on
            # ScaledDotProductAttention modules.
            weight_format=quant_format,
        ),
    ),
    # Activation-activation matmuls: the "input_cast" and "multiplier_cast" operands.
    DmxConfigRule(
        module_types=(nn.ActActMatMul,),
        module_config=dict(
            input_formats=[quant_format, quant_format],  # option 1
            # input_formats={"input_cast": quant_format, "multiplier_cast": quant_format},  # option 2
        ),
    ),
)

# Configure the model based on the rules above.
pipe.model.configure(None, *rules)

### Note: if the data format does not require calibration, all steps under this collapsed section can be skipped

A forward pass must be run before calibration so that the JIT transformation is triggered and the dmx modules exist.

In [None]:
import torch

# Shape and token-id range of the dummy warm-up batch.
WARMUP_SEQ_LEN = 1024
WARMUP_TOKEN_RANGE = (1, 100)  # torch.randint bounds: low inclusive, high exclusive
WARMUP_SEED = 0

# A dedicated, seeded generator makes the dummy batch reproducible without
# consuming or reseeding torch's global RNG.
warmup_generator = torch.Generator().manual_seed(WARMUP_SEED)
x = torch.randint(
    *WARMUP_TOKEN_RANGE, (1, WARMUP_SEQ_LEN), generator=warmup_generator
)

# This forward pass runs before calibration so that the JIT transformation is
# triggered and the dmx modules exist for the layer selection below.
with torch.no_grad():
    y = pipe.model(x)


Inspect the contents of the transformed model.

In [None]:
# The transformed model's graph (`_gm` is a private attribute; inspection only).
transformed_gm = pipe.model._gm
transformed_gm

Specifying layers to calibrate

In [None]:
def dmx_modules_of_type(model, module_types):
    """Collect the dmx modules of `model` that are instances of `module_types`.

    Args:
        model: object exposing `named_dmx_modules()`, which yields (name, module) pairs.
        module_types: anything `isinstance` accepts (a class or a tuple of classes).

    Returns:
        dict mapping module name -> module, in `named_dmx_modules()` order.
    """
    return {
        name: module
        for name, module in model.named_dmx_modules()
        if isinstance(module, module_types)
    }


calibration_layers_matmul = dmx_modules_of_type(pipe.model, nn.ActActMatMul)
calibration_layers_lin = dmx_modules_of_type(pipe.model, nn.Linear)
calibration_layers_attention = dmx_modules_of_type(
    pipe.model, nn.ScaledDotProductAttention
)


Specifying hyperparameters to use for calibration

In [None]:
from dmx.compressor.numerical.observer import HistogramObserver, MinMaxObserver

# Observer settings for the two ActActMatMul input casts:
# - "input_cast": histogram observer, no qscheme override, no grouping/channel axis;
# - "multiplier_cast": min/max observer, per-channel affine qscheme along axis -2.
matmul_hyperparams = dict(
    input_cast={
        "observer_cls": HistogramObserver,
        "qscheme_to_overload": None,
        "group_size": None,
        "ch_axis": None,
    },
    multiplier_cast={
        "observer_cls": MinMaxObserver,
        "qscheme_to_overload": torch.per_channel_affine,
        "ch_axis": -2,
    },
)



If `hyperparams` is `None`, the `calibrating_activations` method defaults to:
```python
{
    "input_cast": dict(
        observer_cls=HistogramObserver,
        qscheme_to_overload=None,
        group_size=None,
        ch_axis=None,
    ),
}
```

In [None]:
# Passing None makes calibrating_activations fall back to its default for the
# Linear layers: a HistogramObserver on "input_cast", with qscheme_to_overload,
# group_size and ch_axis all set to None.
lin_hyperparams = None

If the values in `hyperparams` are empty dicts, the `calibrating_activations` method fills each of them with:
``` python
dict(
    observer_cls=HistogramObserver,
    qscheme_to_overload=None,
    group_size=None,
    ch_axis=None,
)
```

In [None]:
# One (initially empty) settings dict per attention input cast; calibrating_activations
# fills empty dicts with its default observer settings. A comprehension is used
# so that each cast gets its own distinct dict object.
attention_hyperparams = {
    cast_name: {}
    for cast_name in ("query_states_cast", "key_states_cast", "value_states_cast")
}


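Run calibration: observe weights and activations while a few wikitext-2 training samples are forwarded through the model.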

In [None]:
from contextlib import ExitStack

# Number of wikitext-2 training samples forwarded while the observers are active.
NUM_CALIBRATION_SAMPLES = 10

# (layers, hyperparams) pairs whose activations are calibrated, in the order
# their calibration contexts are entered.
activation_calibration_plan = [
    (calibration_layers_matmul, matmul_hyperparams),
    (calibration_layers_lin, lin_hyperparams),
    (calibration_layers_attention, attention_hyperparams),
]

# ExitStack enters the contexts in the same order as a single chained `with`
# statement (weights first, then each activation group) and exits them in
# reverse order.
with torch.no_grad(), ExitStack() as calibration_stack:
    calibration_stack.enter_context(
        pipe.model.calibrating_weights(calibration_layers_lin.items())
    )
    for layers, hyperparams in activation_calibration_plan:
        calibration_stack.enter_context(
            pipe.model.calibrating_activations(layers.items(), hyperparams)
        )
    pipe.do_forward_on(
        dataset="wikitext",
        dataset_version="wikitext-2-raw-v1",
        column_name="text",
        dataset_split="train",
        num_samples=NUM_CALIBRATION_SAMPLES,
    )


### Evaluation

In [None]:
# Evaluate the calibrated model with the d-matrix/dmx_perplexity metric on wikitext-2.
metric = pipe.evaluate(
    "d-matrix/dmx_perplexity",
    dataset="wikitext",
    dataset_version="wikitext-2-raw-v1",
)
# A bare last expression uses the notebook's rich display instead of print().
metric
