In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import os
import shutil
import random
import matplotlib.pyplot as plt
import copy
from torchao.quantization import Int4DynamicActivationInt4WeightConfig, Int8WeightOnlyConfig , quantize_

AttributeError: module 'torch' has no attribute 'int1'

In [None]:
# Take the preprocessing transform that ships with the default ViT-B/16
# weights and apply the same transform to every split.
weights = models.ViT_B_16_Weights.DEFAULT
transform = weights.transforms()

# One ImageFolder dataset per split, each read from its own directory.
train_dataset = datasets.ImageFolder("dataset/train", transform=transform)
val_dataset   = datasets.ImageFolder("dataset/val", transform=transform)
test_dataset  = datasets.ImageFolder("dataset/test", transform=transform)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader  = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# Every later cell moves the model and the batches to this device (CPU).
device = torch.device("cpu")

In [None]:
# SECURITY: weights_only=False unpickles arbitrary Python objects, which can
# execute code embedded in the file. Only load checkpoints you created or trust.
# map_location=device remaps tensors that were saved on another device (for
# example CUDA) onto the device this notebook uses, so loading does not fail
# on a machine without that device.
model = torch.load("model_full.pth", map_location=device, weights_only=False)
model.eval()

In [None]:
# --- 1) Helper for swapping modules ---
def replace_nondynamic_linear_with_nn_linear(model):
    """
    Replace every submodule whose class name is 'NonDynamicallyQuantizableLinear'
    with a plain nn.Linear that carries the same weight and bias values.

    The swap is done in place on ``model``. The replacement is created on the
    source weight's device and keeps the source module's training flag and the
    ``requires_grad`` state of its parameters.

    A matching module is left untouched (and a message is printed) when:
      * it is the root of ``model`` (there is no parent attribute to rebind),
      * it has no ``in_features`` / ``out_features``,
      * its parameters cannot be copied into the new layer. In that case the
        original module is kept, so a failed copy never leaves a randomly
        initialised layer in the model.

    Args:
        model: the nn.Module to patch in place.

    Returns:
        int: the number of modules that were actually replaced.
    """
    # Snapshot the matches first so the module tree is not mutated while
    # named_modules() is still walking it.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if module.__class__.__name__ == "NonDynamicallyQuantizableLinear"
    ]

    if not targets:
        print("No NonDynamicallyQuantizableLinear found.")
        return 0

    replaced = 0
    for full_name, old_mod in targets:
        if not full_name:
            # named_modules() reports the root under the empty name; it cannot
            # be rebound from inside this function.
            print("Skipping root module: it has no parent to attach the replacement to.")
            continue

        in_f = getattr(old_mod, "in_features", None)
        out_f = getattr(old_mod, "out_features", None)
        if in_f is None or out_f is None:
            print(f"Skipping {full_name}: cannot find in/out features.")
            continue

        old_bias = getattr(old_mod, "bias", None)
        has_bias = old_bias is not None

        try:
            old_weight = old_mod.weight
            # Build the replacement on the source device so the model does not
            # end up split across devices.
            new_mod = nn.Linear(in_f, out_f, bias=has_bias, device=old_weight.device)
            with torch.no_grad():
                # Cast to the new layer's dtype before copying.
                new_mod.weight.copy_(old_weight.to(new_mod.weight.dtype))
                if has_bias:
                    new_mod.bias.copy_(old_bias.to(new_mod.bias.dtype))
            # Keep frozen parameters frozen.
            new_mod.weight.requires_grad_(old_weight.requires_grad)
            if has_bias:
                new_mod.bias.requires_grad_(old_bias.requires_grad)
        except Exception as e:
            # Installing a layer whose parameters were not copied would silently
            # corrupt the model, so keep the original module instead.
            print(f"Warning copying params for {full_name}: {e}; keeping the original module.")
            continue

        new_mod.train(old_mod.training)

        # Split "a.b.c" into parent path "a.b" and attribute "c"; an empty
        # parent path resolves to ``model`` itself.
        parent_name, _, attr_name = full_name.rpartition(".")
        parent = model.get_submodule(parent_name)
        setattr(parent, attr_name, new_mod)
        replaced += 1
        print(f"Replaced {full_name} -> nn.Linear({in_f},{out_f}, bias={has_bias})")

    return replaced


In [None]:
# Work on an independent deep copy of the loaded model, cast to float32.
model_fp32 = copy.deepcopy(model)
model_fp32.to(dtype=torch.float32)
model_fp32.to(device)


In [None]:
# Swap every NonDynamicallyQuantizableLinear inside model_fp32 for a plain
# nn.Linear (in place). The cell output is the helper's return value.
replace_nondynamic_linear_with_nn_linear(model_fp32)

In [None]:
def apply_mixed_quantization(model):
    """
    Quantize ``model`` in place with two torchao configs and return it.

    * ``out_proj`` of every nn.MultiheadAttention (when it is an nn.Linear)
      is quantized with Int8WeightOnlyConfig.
    * Every other nn.Linear whose qualified name does not end in "out_proj"
      is quantized with Int4DynamicActivationInt4WeightConfig.

    Args:
        model: the nn.Module to quantize in place.

    Returns:
        The same ``model`` object that was passed in.

    Raises:
        RuntimeError: if any nn.MultiheadAttention has ``out_proj`` set to
            None after the quantization pass.
    """
    act4_weight4_cfg = Int4DynamicActivationInt4WeightConfig()
    weight8_cfg = Int8WeightOnlyConfig()

    for name, module in model.named_modules():
        if isinstance(module, nn.MultiheadAttention):
            projection = module.out_proj
            if isinstance(projection, nn.Linear):
                # The result of quantize_ is deliberately not assigned back:
                # the module is converted in place.
                quantize_(projection, weight8_cfg)
                print(f"{name}.out_proj -> Int8 (in-place)")
            continue

        if not isinstance(module, nn.Linear):
            continue

        # named_modules() also yields each out_proj as its own entry; skip it
        # by name so the Int8 choice made above is not overwritten with Int4.
        if name.endswith("out_proj"):
            continue

        quantize_(module, act4_weight4_cfg)
        print(f"{name} -> Int4 (in-place)")

    # Sanity check: no attention block may have lost its out_proj.
    for name, module in model.named_modules():
        if isinstance(module, nn.MultiheadAttention) and module.out_proj is None:
            raise RuntimeError(f"Error: {name}.out_proj is None after quantization!")

    return model

In [None]:
# apply_mixed_quantization returns the model it quantized in place; switch
# that model to eval mode right away.
model_fp32 = apply_mixed_quantization(model_fp32).eval()

model_fp32.to(device)

In [None]:
# Save the quantized model (model_fp32). `model` is the unquantized original
# loaded from model_full.pth, so saving it here would write the wrong network
# under the quantized filename.
torch.save(model_fp32, "model_MHAweight8bit_otherweightact4bit.pth")

In [None]:
# --- Test evaluation (uses the existing test_loader as-is) ---
# Count correct top-1 predictions of the quantized model over test_loader.
correct, total = 0, 0
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing quantized model"):
        images, labels = images.to(device), labels.to(device)
        outputs = model_fp32(images)
        # Predicted class = index of the maximum along dim 1.
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

# Fail with a clear message instead of a bare ZeroDivisionError.
if total == 0:
    raise RuntimeError("test_loader yielded no samples; cannot compute accuracy.")

test_acc = correct / total
# The model is mixed precision (Int8 weight-only on attention out_proj, Int4
# dynamic-activation / Int4-weight on the other Linear layers), not
# "INT4 weight-only" as the label used to say.
print(f"Mixed INT8 (MHA out_proj) / INT4 (other Linear) Test Accuracy: {test_acc:.4f}")