In [1]:
import torch
import torch.nn as nn
import copy
from torchao.quantization import Int8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig, quantize_
from torchvision import datasets, models, transforms
from tqdm import tqdm
from torch.utils.data import DataLoader, random_split

In [2]:
# --- MHA を Q/K/V 分解して nn.Linear に置き換える ---
class MHAWithExplicitLinear(nn.Module):
    def __init__(self, mha: nn.MultiheadAttention):
        super().__init__()
        self.embed_dim = mha.embed_dim
        self.num_heads = mha.num_heads
        self.dropout = mha.dropout
        self.batch_first = mha.batch_first

        # in_proj を分割して q/k/v の Linear を作成
        w = mha.in_proj_weight
        b = mha.in_proj_bias
        d = self.embed_dim

        self.q_proj = nn.Linear(d, d, bias=b is not None)
        self.k_proj = nn.Linear(d, d, bias=b is not None)
        self.v_proj = nn.Linear(d, d, bias=b is not None)

        with torch.no_grad():
            self.q_proj.weight.copy_(w[:d, :])
            self.k_proj.weight.copy_(w[d:2*d, :])
            self.v_proj.weight.copy_(w[2*d:, :])
            if b is not None:
                self.q_proj.bias.copy_(b[:d])
                self.k_proj.bias.copy_(b[d:2*d])
                self.v_proj.bias.copy_(b[2*d:])

        # out_proj はそのままコピー
        self.out_proj = copy.deepcopy(mha.out_proj)

    def forward(self, query, key, value, need_weights=False, attn_mask=None):
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        attn_output, attn_output_weights = nn.functional.multi_head_attention_forward(
            query=q,
            key=k,
            value=v,
            embed_dim_to_check=self.embed_dim,
            num_heads=self.num_heads,
            in_proj_weight=None,
            in_proj_bias=None,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=self.dropout,
            out_proj_weight=self.out_proj.weight,
            out_proj_bias=self.out_proj.bias,
            training=self.training,
            key_padding_mask=None,
            need_weights=need_weights,
            attn_mask=attn_mask,
            use_separate_proj_weight=True,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            static_k=None,
            static_v=None
        )
        return attn_output, attn_output_weights


def convert_mha_to_linear_proj(model):
    """
    モデル内の nn.MultiheadAttention を MHAWithExplicitLinear に置換
    """
    targets = []
    for name, module in model.named_modules():
        if isinstance(module, nn.MultiheadAttention):
            targets.append((name, module))

    for full_name, old_mod in targets:
        new_mod = MHAWithExplicitLinear(old_mod)
        # 親モジュールをたどって置換
        name_parts = full_name.split(".")
        parent = model
        for p in name_parts[:-1]:
            parent = getattr(parent, p)
        setattr(parent, name_parts[-1], new_mod)
        print(f"Replaced {full_name} -> MHAWithExplicitLinear")

    return model


# --- 量子化関数 ---
def apply_mixed_quantization(model, qcfg_qkv, qcfg_out, qcfg_other):
    """
    Q/K/V, out_proj, その他 Linear をそれぞれ異なる qconfig で量子化
    """
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            if any(x in name for x in ["q_proj", "k_proj", "v_proj"]):
                quantize_(module, qcfg_qkv)
                print(f"{name} -> Q/K/V qconfig")
            elif "out_proj" in name:
                quantize_(module, qcfg_out)
                print(f"{name} -> out_proj qconfig")
            else:
                quantize_(module, qcfg_other)
                print(f"{name} -> other Linear qconfig")

    return model

In [3]:
# モデル読み込み
model = torch.load("model_full.pth", weights_only=False)
model.eval()

model_fp32 = copy.deepcopy(model).to(torch.float32)

# MHA を Q/K/V/Out の Linear に展開
model_fp32 = convert_mha_to_linear_proj(model_fp32)

# 好きな qconfig を指定可能
qcfg_qkv   = Int8WeightOnlyConfig()   # 例: Q/K/V は Int8
qcfg_out   = Int8DynamicActivationInt8WeightConfig()   # 例: out_proj も Int8
qcfg_other = Int8DynamicActivationInt8WeightConfig()   # 例: その他 Linear は Int4

model_fp32 = apply_mixed_quantization(model_fp32, qcfg_qkv, qcfg_out, qcfg_other)
model_fp32.eval()

torch.save(model_fp32, "model_custom_quant.pth")

Replaced encoder.layers.encoder_layer_0.self_attention -> MHAWithExplicitLinear
Replaced encoder.layers.encoder_layer_1.self_attention -> MHAWithExplicitLinear
Replaced encoder.layers.encoder_layer_2.self_attention -> MHAWithExplicitLinear
Replaced encoder.layers.encoder_layer_3.self_attention -> MHAWithExplicitLinear
Replaced encoder.layers.encoder_layer_4.self_attention -> MHAWithExplicitLinear
Replaced encoder.layers.encoder_layer_5.self_attention -> MHAWithExplicitLinear
Replaced encoder.layers.encoder_layer_6.self_attention -> MHAWithExplicitLinear
Replaced encoder.layers.encoder_layer_7.self_attention -> MHAWithExplicitLinear
Replaced encoder.layers.encoder_layer_8.self_attention -> MHAWithExplicitLinear
Replaced encoder.layers.encoder_layer_9.self_attention -> MHAWithExplicitLinear
Replaced encoder.layers.encoder_layer_10.self_attention -> MHAWithExplicitLinear
Replaced encoder.layers.encoder_layer_11.self_attention -> MHAWithExplicitLinear
encoder.layers.encoder_layer_0.self_at

In [4]:
weights = models.ViT_B_16_Weights.DEFAULT
transform = weights.transforms()

train_dataset = datasets.ImageFolder("dataset/train", transform=transform)
val_dataset   = datasets.ImageFolder("dataset/val", transform=transform)
test_dataset  = datasets.ImageFolder("dataset/test", transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [5]:
device =torch.device("cpu")
model_fp32.to(device)

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MHAWithExplicitLinear(
          (q_proj): Linear(in_features=768, out_features=768, weight=AffineQuantizedTensor(shape=torch.Size([768, 768]), block_size=(1, 768), device=cpu, _layout=PlainLayout(), tensor_impl_dtype=torch.int8, quant_min=None, quant_max=None))
          (k_proj): Linear(in_features=768, out_features=768, weight=AffineQuantizedTensor(shape=torch.Size([768, 768]), block_size=(1, 768), device=cpu, _layout=PlainLayout(), tensor_impl_dtype=torch.int8, quant_min=None, quant_max=None))
          (v_proj): Linear(in_features=768, out_features=768, weight=AffineQuantizedTensor(shape=torch.Size([768, 768]), block_size=(1, 768), device=cpu, _layout=PlainLayout(), tens

In [6]:
# --- テスト評価（あなたの test_loader をそのまま使ってください）---
correct, total = 0, 0
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing quantized model"):
        images, labels = images.to(device), labels.to(device)
        outputs = model_fp32(images)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
test_acc = correct / total
print(f"INT4 weight-only Test Accuracy: {test_acc:.4f}")

Testing quantized model: 100%|██████████| 14/14 [00:20<00:00,  1.49s/it]

INT4 weight-only Test Accuracy: 0.2437



