## This notebook demonstrates how to create and configure a quantized model

Define a custom model

In [None]:
import torch
class CustomNet(torch.nn.Module):
    """Toy network: one Linear layer, a selectable activation, and an optional residual add.

    Args:
        in_dim: Size of the last dimension of the input tensor.
        out_dim: Size of the last dimension of the output tensor.
    """

    def __init__(self, in_dim, out_dim) -> None:
        super().__init__()
        # Project in_dim features to out_dim features (out_dim must actually be used here).
        self.layer = torch.nn.Linear(in_dim, out_dim)

    def forward(self, x, use_gelu=False, old_x=None):
        """Apply the linear layer, then GELU or ReLU, then optionally add ``old_x``.

        Args:
            x: Input tensor whose last dimension is ``in_dim``.
            use_gelu: If True use GELU as the activation, otherwise ReLU.
            old_x: Optional tensor added to the activated output; it must be
                broadcastable to the output shape (so with ``old_x=x`` this
                requires ``in_dim == out_dim``).

        Returns:
            Tensor whose last dimension is ``out_dim``.
        """
        x = self.layer(x)
        if use_gelu:
            x = torch.nn.functional.gelu(x)
        else:
            x = torch.nn.functional.relu(x)
        if old_x is not None:
            # Out-of-place add: ReLU saves its output for the backward pass, so an
            # in-place ``+=`` on it would make autograd raise during backward().
            x = x + old_x
        return x

Create a DmxModel

In [None]:
from dmx.compressor.modeling import DmxModel

# Seed first so both the Linear weights and the sample input below are reproducible.
torch.manual_seed(0)
# Wrap the plain PyTorch module as a DmxModel.
model = DmxModel.from_torch(CustomNet(10, 10))
# A single sample input with 10 features, matching in_dim.
x = torch.rand(1, 10)


Configure the model with numerical formats equivalent to basic-mode execution on d-Matrix hardware

In [None]:
from dmx.compressor import config_rules

# Pass the model's current dmx_config together with the predefined BASIC rule set to transform().
basic_rules = config_rules.BASIC
model = model.transform(model.dmx_config, *basic_rules)

In [None]:
# Run one forward pass with the BASIC configuration; the result is the cell's displayed output.
basic_out = model(x)
basic_out

In [None]:
# Inspect the DmxModel's internal `_gm` attribute (private; shown here for demonstration only).
graph_module = model._gm
graph_module

Configure the model with other formats

In [None]:
from dmx.compressor.modeling import DmxConfigRule, Linear
from dmx.compressor import format  # NOTE: this name shadows Python's builtin `format`

# One rule: every `Linear` module gets MXINT8_K64 input formats and an MXINT4_K64 weight format.
rules = (
    DmxConfigRule(
        module_types=(Linear,),
        module_config=dict(
            input_formats=[format.MXINT8_K64],
            weight_format=format.MXINT4_K64,
        ),
    ),
)

# Defining the rules alone changes nothing: apply them to the model the same way
# the BASIC rules were applied, so the following cells run with these formats.
model = model.transform(model.dmx_config, *rules)

In [None]:
# Call the model with different arguments so that the ReLU/GELU branches and the
# optional residual input of CustomNet.forward are all exercised.
branch_calls = [
    ((x,), {}),
    ((-x,), {}),
    ((x,), {"use_gelu": True}),
]
for call_args, call_kwargs in branch_calls:
    model(*call_args, **call_kwargs)
# The last call stays a bare expression so its result is the cell's displayed output.
model(-x, use_gelu=True, old_x=x)

Visualize the latest computation graph

In [None]:
# Name passed to visualize_graph as its output file.
GRAPH_FILE = "graph"
model.visualize_graph(out_file=GRAPH_FILE)