In [43]:
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image #as mg
import torch.nn.utils.prune as prune
from torchvision import models,transforms,datasets
from torch.utils.data import DataLoader as loader
from torchvision.datasets import CIFAR10


In [78]:
device=torch.device('cuda'if torch.cuda.is_available()else'cpu')
print(device, 'mode')
model=models.resnet18(pretrained=True).to(device).eval()
# 예시: 마지막 레이어만 10개로 교체
model.fc = torch.nn.Linear(model.fc.in_features, 10).to(device)
# 그리고 나서 train_loader로 fine-tune

model

cuda mode


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [80]:
layer=model.layer4[1].conv2 #마지박부분\
prune.ln_structured(
    layer,name='weight',
    amount=0.3,# 30% 채널 제거
    n=1, # L1-norm 기준
    dim=0 # 출력 채널 축
)

prune.remove(layer,'weight')
layer

Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

In [92]:
preprocess=transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])

test_data=CIFAR10(root='./data',train=False,download=True,transform=preprocess)
#test_ds = CIFAR10(root='./data', train=False, download=True)
print(test_data.classes)

test_loader=loader(test_data,batch_size=1,shuffle=False)

img,true_label=next(iter(test_loader))
img=img.to(device)

#img=preprocess(mg.open('dog.jpg')).unsqueeze(0).to(device)


['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


In [94]:
true_label = 3
predicted  = 6
print(f"True class:  {test_data.classes[true_label]}")
print(f"Predicted:   {test_data.classes[predicted]}")
# True class:  cat
# Predicted:   frog


True class:  cat
Predicted:   frog


In [96]:
correct = 0
for x,y in test_loader:
    x = x.to(device)
    with torch.no_grad():
        pred = model(x).argmax(1).item()
    if pred == y:
        correct += 1
print("Test Accuracy:", correct / len(test_ds))


Test Accuracy: 0.0882


In [98]:
with torch.no_grad():
    logit=model(img)
    pred=logit.argmax(1).item()

print(f"True label: {true_label}, Pruned ResNet18 prediction: {pred}")


True label: 3, Pruned ResNet18 prediction: 6


여기서 출력된 값들의 의미는 이렇습니다:

1. **True label: tensor(\[3])**
   이건 **CIFAR-10** 데이터셋에서 가져온 정답 레이블입니다.
   CIFAR-10의 클래스 인덱스는

   ```
   0: airplane  
   1: automobile  
   2: bird  
   3: cat  
   4: deer  
   5: dog  
   6: frog  
   7: horse  
   8: ship  
   9: truck  
   ```

   이므로, `3`은 **cat** (고양이)를 뜻합니다.

2. **Pruned ResNet18 prediction: 381**
   이 숫자 `381`은 **ImageNet**용으로 사전학습된 ResNet18이 출력한 **1000개 중 하나의 클래스 인덱스**입니다.
   즉, 당신이 CIFAR-10 이미지를 `pretrained=True` ResNet18에 바로 넣었기 때문에,
   모델은 “ImageNet”의 1000개 레이블 중 381번(382번째)에 해당하는 카테고리를 선택한 겁니다.

   만약 이 인덱스를 사람이 읽을 수 있는 이름으로 바꾸려면, ImageNet의 클래스 맵 파일을 로드해서 조회하면 됩니다. 예를 들어:

   ```python
   # 1) 미리 다운로드해 두었다고 가정한 imagenet_class_index.json 파일 로드
   import json
   with open("imagenet_class_index.json") as f:
       idx2label = json.load(f)
   # 2) 381번 인덱스에 해당하는 (ID, human-readable name) 조회
   print(idx2label[str(381)])  
   # 예) ["n02119022", "red_fox"]
   ```

---

### 해결책 및 권장 방법

* **ImageNet 모델을 CIFAR-10에 바로 쓰지 마시고**,

  1. **마지막 fc 레이어**만 `nn.Linear(512, 10)` 같이 10개 출력으로 교체 →
  2. CIFAR-10 데이터로 **파인튜닝(fine-tune)** →
  3. 그 다음에 `model(img)` 를 호출하면 `pred` 도 0\~9 범위로 나옵니다.

```python
# 예시: 마지막 레이어만 10개로 교체
model.fc = torch.nn.Linear(model.fc.in_features, 10).to(device)
# 그리고 나서 train_loader로 fine-tune
```

* 이렇게 하면 `pred` 가 0\~9 사이의 **CIFAR-10 클래스 인덱스**(예: “3 → cat”)로 나오게 됩니다.

---

정리하자면:

> **“381”** 은 **ImageNet 1000개 레이블** 중 하나의 번호이고,
> **“3”** 은 **CIFAR-10 10개 레이블** 중 하나의 번호입니다.

두 레이블 체계가 다르니, CIFAR-10 작업을 하실 땐 모델을 **10개 클래스용**으로 바꾼 뒤 반드시 **재학습(fine-tuning)** 하셔야 올바른 예측 결과를 얻으실 수 있습니다.


In [76]:
from torch.quantization import get_default_qconfig, quantize_jit, QuantStub, DeQuantStub, fuse_modules, prepare_qat, convert
import torch.nn.utils.prune as prune
from torchvision import models
import torch
import os
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
from torchvision import models,transforms,datasets
from torch.utils.data import DataLoader as loader
from torchvision.datasets import CIFAR10
import torch.nn as nn
import torch.optim as optim

In [227]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.quantization import get_default_qconfig, QuantStub, DeQuantStub, fuse_modules, prepare_qat, convert

import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader 

In [229]:
batch_size = 64
lr = 0.0001
num_epoch = 5 # FP32 모델 학습 에폭
qat_epoch = 5 # QAT 학습 에폭
save_path = "./qat_resnet18_model.pth"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device} mode")

Using device: cuda mode


In [232]:
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224), 
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 
])

transform_test = transforms.Compose([
    transforms.Resize(224), 
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [235]:
train_data = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_data = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# 학습 속도를 위해 데이터셋 서브셋 사용 (선택 사항)
subset_size_train = 5000
subset_size_test = 1000

train_subset_indices = torch.randperm(len(train_data))[:subset_size_train]
test_subset_indices = torch.randperm(len(test_data))[:subset_size_test]

train_subset = torch.utils.data.Subset(train_data, train_subset_indices)
test_subset = torch.utils.data.Subset(test_data, test_subset_indices)

train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False, num_workers=2)

print(f"Number of training batches (subset): {len(train_loader)}")
print(f"Number of test batches (subset): {len(test_loader)}")


Number of training batches (subset): 79
Number of test batches (subset): 16


In [236]:
def get_res(num_classes=10, pretrained=True): 
    """
    ResNet18 모델을 로드하고, 마지막 Fully Connected Layer를
    지정된 num_classes에 맞게 수정합니다.
    """
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes) 
    return model


In [237]:
def prepare_model_for_qat(model): 
    """
    QAT를 위해 모델에 QuantStub/DeQuantStub를 삽입하고,
    ResNet18의 특정 모듈들을 퓨징합니다.
    """
    model.quant = QuantStub()
    model.dequant = DeQuantStub() 

    # ResNet18의 기본 블록 구조에 맞춰 모듈 퓨징
    # (conv1, bn1, relu)와 각 BasicBlock 내부의 (conv, bn, relu)를 퓨징
    for name, module in model.named_children():
        if isinstance(module, nn.Sequential): 
            for basic_block_name, basic_block_module in module.named_children():
                if isinstance(basic_block_module, models.resnet.BasicBlock):
                    # BasicBlock의 첫 번째 Conv-BN-ReLU 시퀀스
                    torch.quantization.fuse_modules(basic_block_module, [['conv1', 'bn1', 'relu']], inplace=True)
                    # BasicBlock의 두 번째 Conv-BN (ReLU는 여기에 붙지 않음)
                    torch.quantization.fuse_modules(basic_block_module, [['conv2', 'bn2']], inplace=True)
    # 모델의 첫 번째 Conv-BN-ReLU 시퀀스
    torch.quantization.fuse_modules(model, [['conv1', 'bn1', 'relu']], inplace=True)
    return model

print("Model definition and QAT preparation functions defined.")

Model definition and QAT preparation functions defined.


In [238]:
def train_mode(model, train_loader, criter, optim, num_epoch, device, model_name="Model"):
    """
    모델을 학습시키는 일반적인 학습 루프입니다.
    """
    model.train()
    print(f"\n--- {model_name} Training ---")
    for epoch in range(num_epoch):
        running_loss = 0.0
        correct_predictions = 0
        total_predictions = 0
        for batch_idx, (data, target) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epoch} ({model_name})")):
            data, target = data.to(device), target.to(device)
            optim.zero_grad()
            out = model(data)
            loss = criter(out, target)
            loss.backward()
            optim.step()
            running_loss += loss.item()
            _, predicted = torch.max(out.data, 1)
            total_predictions += target.size(0)
            correct_predictions += (predicted == target).sum().item()

        avg_loss = running_loss / len(train_loader)
        acc = 100 * correct_predictions / total_predictions
        print(f"Epoch {epoch+1} Complete: Avg Loss: {avg_loss:.4f}, Accuracy: {acc:.2f}%")
    print(f"{model_name} Training finished!")


In [239]:
def eval_mode(model, test_loader, device, model_name='Model'):
    """
    모델의 정확도를 평가하는 함수입니다.
    """
    model.eval()
    correct = 0.0
    total = 0.0
    with torch.no_grad():
        for data, target in tqdm(test_loader, desc=f"Evaluating {model_name}"):
            data, target = data.to(device), target.to(device)
            out = model(data)
            _, pred = torch.max(out.data, 1)
            total += target.size(0)
            correct += (pred == target).sum().item()
        acc = 100 * correct / total
        print(f'Accuracy on test set ({model_name}): {acc:.2f}%')
    return acc

print("Training and evaluation functions defined.")

Training and evaluation functions defined.


In [245]:
print("\n--- FP32 ResNet18 Model Training and Evaluation ---")
fp32_model = get_res(num_classes=10, pretrained=True).to(device) 
criter = nn.CrossEntropyLoss()
optimy = optim.Adam(fp32_model.parameters(), lr=lr)

train_mode(fp32_model, train_loader, criter, optimy, num_epoch, device, model_name="FP32 ResNet18")
fp32_accuracy = eval_mode(fp32_model, test_loader, device, model_name="FP32 ResNet18")

fp32_model_path = save_path.replace(".pth", "_fp32.pth")
torch.save(fp32_model.state_dict(), fp32_model_path)
print(f"FP32 ResNet18 model saved to {fp32_model_path}")


--- FP32 ResNet18 Model Training and Evaluation ---

--- FP32 ResNet18 Training ---


Epoch 1/5 (FP32 ResNet18):   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 1 Complete: Avg Loss: 1.3560, Accuracy: 52.78%


Epoch 2/5 (FP32 ResNet18):   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 2 Complete: Avg Loss: 0.9098, Accuracy: 68.68%


Epoch 3/5 (FP32 ResNet18):   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 3 Complete: Avg Loss: 0.7525, Accuracy: 74.36%


Epoch 4/5 (FP32 ResNet18):   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 4 Complete: Avg Loss: 0.6734, Accuracy: 77.14%


Epoch 5/5 (FP32 ResNet18):   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 5 Complete: Avg Loss: 0.6106, Accuracy: 78.82%
FP32 ResNet18 Training finished!


Evaluating FP32 ResNet18:   0%|          | 0/16 [00:00<?, ?it/s]

Accuracy on test set (FP32 ResNet18): 85.30%
FP32 ResNet18 model saved to ./qat_resnet18_model_fp32.pth


In [246]:
print("\n--- Quantization-Aware Training (QAT) for ResNet18 ---")
qat_model = get_res(num_classes=10, pretrained=True).to(device)

qat_model.load_state_dict(fp32_model.state_dict())
print("QAT model initialized with FP32 model weights.")

print("Preparing ResNet18 for QAT (fusing modules)...")
# 퓨징 전에 eval() 모드로 전환
qat_model.eval() 
qat_model = prepare_model_for_qat(qat_model) 
print("ResNet18 QAT preparation complete.")

print("Setting QConfig (fbgemm)...")
qat_model.qconfig = get_default_qconfig('fbgemm')
print(f"QConfig set to: {qat_model.qconfig}")

print("Calling torch.quantization.prepare_qat...")
# prepare_qat을 호출하기 전에 다시 train() 모드로 전환
qat_model.train() # #### 변경 사항 8: prepare_qat 호출 전에 모델을 train() 모드로 전환
prepare_qat(qat_model, inplace=True)
print("Model prepared for QAT.")

# 퓨징 및 prepare_qat 완료 후, QAT 학습을 위해 이미 train() 모드이므로 이 줄은 중복입니다.
# 하지만 코드가 명확하게 보이도록 그대로 두거나 제거할 수 있습니다.
# qat_model.train() 

criter_qt = nn.CrossEntropyLoss()
optimy_qt = optim.Adam(qat_model.parameters(), lr=lr * 0.1) 
print(f"\nStarting QAT for {qat_epoch} epochs...")
train_mode(qat_model,train_loader, criter_qt, optimy_qt, qat_epoch, device, model_name="QAT ResNet18")

print("\nConverting QAT model to quantized model...")
qat_model.eval()
quantized_model = convert(qat_model, inplace=True)
print("Model converted to fully quantized (INT8) ResNet18.")


--- Quantization-Aware Training (QAT) for ResNet18 ---
QAT model initialized with FP32 model weights.
Preparing ResNet18 for QAT (fusing modules)...
ResNet18 QAT preparation complete.
Setting QConfig (fbgemm)...
QConfig set to: QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True){}, weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric){})
Calling torch.quantization.prepare_qat...
Model prepared for QAT.

Starting QAT for 5 epochs...

--- QAT ResNet18 Training ---




Epoch 1/5 (QAT ResNet18):   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 1 Complete: Avg Loss: 235951.8902, Accuracy: 15.52%


Epoch 2/5 (QAT ResNet18):   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 2 Complete: Avg Loss: 8159.6862, Accuracy: 23.98%


Epoch 3/5 (QAT ResNet18):   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 3 Complete: Avg Loss: 5892.9919, Accuracy: 27.12%


Epoch 4/5 (QAT ResNet18):   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 4 Complete: Avg Loss: 4801.2242, Accuracy: 28.80%


Epoch 5/5 (QAT ResNet18):   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 5 Complete: Avg Loss: 3820.8045, Accuracy: 31.86%
QAT ResNet18 Training finished!

Converting QAT model to quantized model...


RuntimeError: Unsupported qscheme: per_channel_affine

In [None]:
print("\n--- Quantized ResNet18 Model Evaluation ---")
qat_accuracy = eval_mode(quantized_model, test_loader, device, model_name="Quantized ResNet18 (QAT)")
torch.save(quantized_model.state_dict(), save_path)
print(f"Quantized ResNet18 model saved to {save_path}")

fp32_model_size = os.path.getsize(fp32_model_path)
quantized_model_size = os.path.getsize(save_path)
print(f"\n--- Model Size Comparison ---")
print(f"FP32 ResNet18 Model Size: {fp32_model_size / (1024*1024):.2f} MB")
print(f"Quantized ResNet18 Model Size: {quantized_model_size / (1024*1024):.2f} MB")
print(f"Quantized model is approximately {fp32_model_size / quantized_model_size:.2f}x smaller than FP32 model.")

print(f"\n--- Final Summary ---")
print(f"FP32 ResNet18 Model Accuracy: {fp32_accuracy:.2f}%")
print(f"Quantized ResNet18 Model Accuracy (QAT): {qat_accuracy:.2f}%")

In [260]:
# 필요한 라이브러리 임포트
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
# default_per_channel_weight_observer는 더 이상 직접 사용하지 않으므로 제거하거나 주석 처리합니다.
# from torch.ao.quantization import get_default_qconfig, QuantStub, DeQuantStub, fuse_modules, prepare_qat, convert, default_per_channel_weight_observer 
from torch.ao.quantization import get_default_qconfig, QuantStub, DeQuantStub, fuse_modules, prepare_qat, convert
from torch.ao.quantization import observer # observer 모듈을 직접 임포트합니다.

import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader 
import functools # QConfig 설정에 필수적입니다.

# --- 1. 하이퍼파라미터 및 장치 설정 ---
batch_size = 64
lr = 0.0001
num_epoch = 5 # FP32 모델 학습 에폭
qat_epoch = 5 # QAT 학습 에폭
save_path = "./qat_resnet18_model.pth"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device} mode")

# --- 2. 데이터 로드 및 전처리 ---
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224), # ResNet18은 224x224 입력 기대
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet 평균/분산 사용
])

transform_test = transforms.Compose([
    transforms.Resize(224), # 224x224로 리사이즈
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_data = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_data = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

subset_size_train = 5000
subset_size_test = 1000

train_subset_indices = torch.randperm(len(train_data))[:subset_size_train]
test_subset_indices = torch.randperm(len(test_data))[:subset_size_test]

train_subset = torch.utils.data.Subset(train_data, train_subset_indices)
test_subset = torch.utils.data.Subset(test_data, test_subset_indices)

train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False, num_workers=2)

print(f"Number of training batches (subset): {len(train_loader)}")
print(f"Number of test batches (subset): {len(test_loader)}")

# --- 3. 모델 정의 및 QAT 수정 함수 ---
def get_res(num_classes=10, pretrained=True): # #### 변경 사항 1: num_classes 기본값 변경 (CIFAR10용)
    """
    ResNet18 모델을 로드하고, 마지막 Fully Connected Layer를
    지정된 num_classes에 맞게 수정합니다.
    """
    model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, num_classes) # CIFAR10 클래스 수에 맞게 final layer 조정
    return model

def prepare_model_for_qat(model): # #### 변경 사항 2: 함수명 변경 (pretrain_Model -> prepare_model_for_qat)
    """
    QAT를 위해 모델에 QuantStub/DeQuantStub를 삽입하고,
    ResNet18의 특정 모듈들을 퓨징합니다.
    """
    model.quant = QuantStub()
    model.dequant = DeQuantStub() # #### 변경 사항 3: DeQuantStub()으로 인스턴스화 필요

    # ResNet18의 기본 블록 구조에 맞춰 모듈 퓨징
    # (conv1, bn1, relu)와 각 BasicBlock 내부의 (conv, bn, relu)를 퓨징
    for name, module in model.named_children():
        if isinstance(module, nn.Sequential): # Layer blocks (e.g., layer1, layer2, ...)
            for basic_block_name, basic_block_module in module.named_children():
                if isinstance(basic_block_module, models.resnet.BasicBlock):
                    # BasicBlock의 첫 번째 Conv-BN-ReLU 시퀀스
                    torch.quantization.fuse_modules(basic_block_module, [['conv1', 'bn1', 'relu']], inplace=True)
                    # BasicBlock의 두 번째 Conv-BN (ReLU는 여기에 붙지 않음)
                    torch.quantization.fuse_modules(basic_block_module, [['conv2', 'bn2']], inplace=True)
    # 모델의 첫 번째 Conv-BN-ReLU 시퀀스
    torch.quantization.fuse_modules(model, [['conv1', 'bn1', 'relu']], inplace=True)
    return model

print("Model definition and QAT preparation functions defined.")

# --- 4. 학습 및 평가 함수 ---
def train_mode(model, train_loader, criter, optim, num_epoch, device, model_name="Model"):
    """
    모델을 학습시키는 일반적인 학습 루프입니다.
    """
    model.train()
    print(f"\n--- {model_name} Training ---")
    for epoch in range(num_epoch):
        running_loss = 0.0
        correct_predictions = 0
        total_predictions = 0
        for batch_idx, (data, target) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epoch} ({model_name})")):
            data, target = data.to(device), target.to(device)
            optim.zero_grad()
            out = model(data)
            loss = criter(out, target)
            loss.backward()
            optim.step()
            running_loss += loss.item()
            _, predicted = torch.max(out.data, 1)
            total_predictions += target.size(0)
            correct_predictions += (predicted == target).sum().item()

        avg_loss = running_loss / len(train_loader)
        acc = 100 * correct_predictions / total_predictions
        print(f"Epoch {epoch+1} Complete: Avg Loss: {avg_loss:.4f}, Accuracy: {acc:.2f}%")
    print(f"{model_name} Training finished!")


def eval_mode(model, test_loader, device, model_name='Model'):
    """
    모델의 정확도를 평가하는 함수입니다.
    """
    model.eval()
    correct = 0.0
    total = 0.0
    with torch.no_grad():
        for data, target in tqdm(test_loader, desc=f"Evaluating {model_name}"):
            data, target = data.to(device), target.to(device)
            out = model(data)
            _, pred = torch.max(out.data, 1)
            total += target.size(0)
            correct += (pred == target).sum().item()
        acc = 100 * correct / total
        print(f'Accuracy on test set ({model_name}): {acc:.2f}%')
    return acc

print("Training and evaluation functions defined.")

# --- 5. FP32 (원본) ResNet18 학습 및 평가 ---
print("\n--- FP32 ResNet18 Model Training and Evaluation ---")
fp32_model = get_res(num_classes=10, pretrained=True).to(device) # CIFAR10 클래스 10개
criter = nn.CrossEntropyLoss()
optimy = optim.Adam(fp32_model.parameters(), lr=lr)

train_mode(fp32_model, train_loader, criter, optimy, num_epoch, device, model_name="FP32 ResNet18")
fp32_accuracy = eval_mode(fp32_model, test_loader, device, model_name="FP32 ResNet18")

fp32_model_path = save_path.replace(".pth", "_fp32.pth")
torch.save(fp32_model.state_dict(), fp32_model_path)
print(f"FP32 ResNet18 model saved to {fp32_model_path}")

# --- 6. Quantization-Aware Training (QAT) for ResNet18 ---
print("\n--- Quantization-Aware Training (QAT) for ResNet18 ---")
qat_model = get_res(num_classes=10, pretrained=True).to(device)

qat_model.load_state_dict(fp32_model.state_dict())
print("QAT model initialized with FP32 model weights.")

print("Preparing ResNet18 for QAT (fusing modules)...")
qat_model.eval() 
qat_model = prepare_model_for_qat(qat_model) 
print("ResNet18 QAT preparation complete.")

print("Setting QConfig (fbgemm)...")
# #### 최종 변경 사항 11: QConfig 설정 방식 변경 (AttributeError 및 Unsupported qscheme 재시도)
# `get_default_qconfig('fbgemm')`는 기본적으로 activation=HistogramObserver, weight=PerChannelMinMaxObserver를 사용합니다.
# 하지만 qscheme 이슈가 지속되므로, qscheme을 명시적으로 symmetric으로 설정하고,
# 활성화와 가중치 옵저버를 직접 functools.partial로 구성합니다.
# HistogramObserver for activation usually uses reduce_range=True.
# PerChannelMinMaxObserver for weights needs qscheme=torch.per_channel_symmetric for fbgemm.
qat_model.qconfig = torch.ao.quantization.QConfig(
    activation=functools.partial(observer.HistogramObserver, reduce_range=True),
    weight=functools.partial(observer.PerChannelMinMaxObserver, dtype=torch.qint8, qscheme=torch.per_channel_symmetric)
)


print(f"QConfig set to: {qat_model.qconfig}")

print("Calling torch.quantization.prepare_qat...")
qat_model.train() 
prepare_qat(qat_model, inplace=True)
print("Model prepared for QAT.")

criter_qt = nn.CrossEntropyLoss()
optimy_qt = optim.Adam(qat_model.parameters(), lr=lr * 0.1) 
print(f"\nStarting QAT for {qat_epoch} epochs...")
train_mode(qat_model,train_loader, criter_qt, optimy_qt, qat_epoch, device, model_name="QAT ResNet18")

print("\nConverting QAT model to quantized model...")
qat_model.eval()
quantized_model = convert(qat_model, inplace=True)
print("Model converted to fully quantized (INT8) ResNet18.")

# --- 7. 양자화된 모델 평가 및 크기 비교 ---
print("\n--- Quantized ResNet18 Model Evaluation ---")
qat_accuracy = eval_mode(quantized_model, test_loader, device, model_name="Quantized ResNet18 (QAT)")
torch.save(quantized_model.state_dict(), save_path)
print(f"Quantized ResNet18 model saved to {save_path}")

fp32_model_size = os.path.getsize(fp32_model_path)
quantized_model_size = os.path.getsize(save_path)
print(f"\n--- Model Size Comparison ---")
print(f"FP32 ResNet18 Model Size: {fp32_model_size / (1024*1024):.2f} MB")
print(f"Quantized ResNet18 Model Size: {quantized_model_size / (1024*1024):.2f} MB")
print(f"Quantized model is approximately {fp32_model_size / quantized_model_size:.2f}x smaller than FP32 model.")

print(f"\n--- Final Summary ---")
print(f"FP32 ResNet18 Model Accuracy: {fp32_accuracy:.2f}%")
print(f"Quantized ResNet18 Model Accuracy (QAT): {qat_accuracy:.2f}%")

Using device: cuda mode
Number of training batches (subset): 79
Number of test batches (subset): 16
Model definition and QAT preparation functions defined.
Training and evaluation functions defined.

--- FP32 ResNet18 Model Training and Evaluation ---

--- FP32 ResNet18 Training ---


Epoch 1/5 (FP32 ResNet18):   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 1 Complete: Avg Loss: 1.3737, Accuracy: 52.54%


Epoch 2/5 (FP32 ResNet18):   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 2 Complete: Avg Loss: 0.8948, Accuracy: 68.82%


Epoch 3/5 (FP32 ResNet18):   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 3 Complete: Avg Loss: 0.7637, Accuracy: 73.66%


Epoch 4/5 (FP32 ResNet18):   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 4 Complete: Avg Loss: 0.6551, Accuracy: 78.34%


Epoch 5/5 (FP32 ResNet18):   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 5 Complete: Avg Loss: 0.6062, Accuracy: 79.24%
FP32 ResNet18 Training finished!


Evaluating FP32 ResNet18:   0%|          | 0/16 [00:00<?, ?it/s]

Accuracy on test set (FP32 ResNet18): 87.10%
FP32 ResNet18 model saved to ./qat_resnet18_model_fp32.pth

--- Quantization-Aware Training (QAT) for ResNet18 ---
QAT model initialized with FP32 model weights.
Preparing ResNet18 for QAT (fusing modules)...
ResNet18 QAT preparation complete.
Setting QConfig (fbgemm)...
QConfig set to: QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True), weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
Calling torch.quantization.prepare_qat...
Model prepared for QAT.

Starting QAT for 5 epochs...

--- QAT ResNet18 Training ---


Epoch 1/5 (QAT ResNet18):   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 1 Complete: Avg Loss: 163802.1837, Accuracy: 16.62%


Epoch 2/5 (QAT ResNet18):   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 2 Complete: Avg Loss: 6452.2827, Accuracy: 25.76%


Epoch 3/5 (QAT ResNet18):   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 3 Complete: Avg Loss: 5137.4137, Accuracy: 30.04%


Epoch 4/5 (QAT ResNet18):   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 4 Complete: Avg Loss: 4370.6192, Accuracy: 31.96%


Epoch 5/5 (QAT ResNet18):   0%|          | 0/79 [00:00<?, ?it/s]

Epoch 5 Complete: Avg Loss: 3580.3217, Accuracy: 34.08%
QAT ResNet18 Training finished!

Converting QAT model to quantized model...


RuntimeError: Unsupported qscheme: per_channel_affine

In [256]:
import torch
import torchvision
print(f"torch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")

torch version: 2.7.0+cu126
torchvision version: 0.22.0+cu126


In [258]:
#3. Model Definition and Modification for QAT
# Change the QConfig setting again
print("Setting QConfig (fbgemm)...")
qat_model.qconfig = torch.ao.quantization.QConfig(
    # Try MovingAverageMinMaxObserver for activation
    activation=functools.partial(observer.MovingAverageMinMaxObserver, reduce_range=True),
    # Keep PerChannelMinMaxObserver for weight with symmetric qscheme
    weight=functools.partial(observer.PerChannelMinMaxObserver, dtype=torch.qint8, qscheme=torch.per_channel_symmetric)
)

Setting QConfig (fbgemm)...


In [264]:
# 필요한 라이브러리 임포트
import os
import functools

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
import torch.ao.quantization as quant
from torch.ao.quantization import (
    get_default_qconfig,
    QuantStub,
    DeQuantStub,
    prepare_qat,
    convert,
    observer
)
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# --- 1. 하이퍼파라미터 및 장치 설정 ---
batch_size = 64
lr = 1e-4
num_epoch = 5      # FP32 학습 에폭
qat_epoch = 5      # QAT 학습 에폭
save_path = "./qat_resnet18_model.pth"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# quantization 백엔드 설정
torch.backends.quantized.engine = 'fbgemm'

# --- 2. 데이터 로드 및 전처리 ---
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std= [0.229, 0.224, 0.225]
    )
])

transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std= [0.229, 0.224, 0.225]
    )
])

train_data = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_data  = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# subset 사용 (연습용)
subset_size_train = 5000
subset_size_test  = 1000
train_subset = torch.utils.data.Subset(train_data, torch.randperm(len(train_data))[:subset_size_train])
test_subset  = torch.utils.data.Subset(test_data,  torch.randperm(len(test_data))[:subset_size_test])

train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True,  num_workers=2)
test_loader  = DataLoader(test_subset,  batch_size=batch_size, shuffle=False, num_workers=2)

print(f"Train batches: {len(train_loader)}, Test batches: {len(test_loader)}")

# --- 3. 모델 정의 및 QAT 준비 ---
def get_resnet18(num_classes=10, pretrained=True):
    """ResNet18 로드 & 마지막 FC 레이어 클래스 수에 맞춰 교체"""
    model = models.resnet18(
        weights=models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    )
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, num_classes)
    return model

class QuantizedResNet18(nn.Module):
    """QuantStub/DeQuantStub을 래핑하고, fuse까지 포함"""
    def __init__(self, fp32_model: nn.Module):
        super().__init__()
        self.quant   = QuantStub()
        self.dequant = DeQuantStub()
        self.model   = fp32_model
        self._fuse_modules()

    def _fuse_modules(self):
        # 전체 모델 fuse
        quant.fuse_modules(self.model, [['conv1', 'bn1', 'relu']], inplace=True)

        # BasicBlock fuse (conv1+bn1+relu, conv2+bn2) 및 downsample fuse
        for _, layer in self.model.named_children():
            if isinstance(layer, nn.Sequential):
                for block in layer:
                    if isinstance(block, models.resnet.BasicBlock):
                        quant.fuse_modules(block, [['conv1', 'bn1', 'relu']], inplace=True)
                        quant.fuse_modules(block, [['conv2', 'bn2']], inplace=True)
                        if block.downsample is not None:
                            quant.fuse_modules(block.downsample, ['0', '1'], inplace=True)

    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)
        x = self.dequant(x)
        return x

print("Model definitions ready.")

# --- 4. 학습/평가 루프 정의 ---
def train_mode(model, loader, criterion, optimizer, epochs, device, name="Model"):
    model.train()
    print(f"\n--- {name} Training ---")
    for epoch in range(epochs):
        total_loss = 0.0
        correct, total = 0, 0
        pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}")
        for data, target in pbar:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            preds = outputs.argmax(dim=1)
            correct += (preds == target).sum().item()
            total   += target.size(0)
        print(f"[{name}] Epoch {epoch+1}: Loss={total_loss/len(loader):.4f}, Acc={100*correct/total:.2f}%")
    print(f"{name} training done.")

def eval_mode(model, loader, device, name="Model"):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in tqdm(loader, desc=f"Evaluating {name}"):
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            preds = outputs.argmax(dim=1)
            correct += (preds == target).sum().item()
            total   += target.size(0)
    acc = 100 * correct / total
    print(f"[{name}] Test Accuracy: {acc:.2f}%")
    return acc

# --- 5. FP32 모델 학습 및 평가 ---
print("\n=== FP32 ResNet18 ===")
fp32 = get_resnet18().to(device)
criterion = nn.CrossEntropyLoss()
optimizer_fp32 = optim.Adam(fp32.parameters(), lr=lr)

train_mode(fp32, train_loader, criterion, optimizer_fp32, num_epoch, device, name="FP32")
fp32_acc = eval_mode(fp32, test_loader, device, name="FP32")
fp32_path = save_path.replace(".pth", "_fp32.pth")
torch.save(fp32.state_dict(), fp32_path)
print(f"Saved FP32 model → {fp32_path}")

# --- 6. QAT 준비 및 학습 ---
print("\n=== QAT ResNet18 ===")
# 1) FP32 모델 weight 로드
qat_fp32 = get_resnet18().to('cpu')
qat_fp32.load_state_dict(torch.load(fp32_path, map_location='cpu'))

# 2) QuantWrapper 생성 (fuse 포함)
qat_model = QuantizedResNet18(qat_fp32).to('cpu')
print("Fused and wrapped for QAT.")

# 3) QConfig 설정
qat_model.qconfig = torch.ao.quantization.QConfig(
    activation=functools.partial(observer.HistogramObserver, reduce_range=True),
    weight=functools.partial(
        observer.PerChannelMinMaxObserver,
        dtype=torch.qint8,
        qscheme=torch.per_channel_symmetric
    )
)
print("QConfig:", qat_model.qconfig)

# 4) Prepare QAT
prepare_qat(qat_model, inplace=True)
print("Prepared for QAT (fake-quant enabled).")

# 5) QAT 학습 (GPU에서 가능)
qat_model.to(device)
optimizer_qat = optim.Adam(qat_model.parameters(), lr=lr * 0.1)
train_mode(qat_model, train_loader, criterion, optimizer_qat, qat_epoch, device, name="QAT")

# --- 7. 최종 양자화 및 평가 (CPU 전용) ---
print("\n--- Converting to INT8 ---")
qat_model.to('cpu')
quantized = convert(qat_model, inplace=True)
print("Converted to INT8.")

quantized_acc = eval_mode(quantized, test_loader, torch.device('cpu'), name="Quantized")

# 8. 모델 저장 및 크기 비교
torch.save(quantized.state_dict(), save_path)
print(f"Saved quantized model → {save_path}")

fp32_size = os.path.getsize(fp32_path) / (1024**2)
quant_size = os.path.getsize(save_path) / (1024**2)
print(f"FP32 size: {fp32_size:.2f} MB")
print(f"Quant size: {quant_size:.2f} MB")
print(f"Size reduction: {fp32_size/quant_size:.2f}×")
print(f"FP32 Acc: {fp32_acc:.2f}%, Quant Acc: {quantized_acc:.2f}%")


Using device: cuda
Train batches: 79, Test batches: 16
Model definitions ready.

=== FP32 ResNet18 ===

--- FP32 Training ---


Epoch 1/5:   0%|          | 0/79 [00:00<?, ?it/s]

[FP32] Epoch 1: Loss=1.3870, Acc=51.20%


Epoch 2/5:   0%|          | 0/79 [00:00<?, ?it/s]

[FP32] Epoch 2: Loss=0.8966, Acc=69.34%


Epoch 3/5:   0%|          | 0/79 [00:00<?, ?it/s]

[FP32] Epoch 3: Loss=0.7603, Acc=74.18%


Epoch 4/5:   0%|          | 0/79 [00:00<?, ?it/s]

[FP32] Epoch 4: Loss=0.6360, Acc=78.30%


Epoch 5/5:   0%|          | 0/79 [00:00<?, ?it/s]

[FP32] Epoch 5: Loss=0.6119, Acc=79.16%
FP32 training done.


Evaluating FP32:   0%|          | 0/16 [00:00<?, ?it/s]

[FP32] Test Accuracy: 87.70%
Saved FP32 model → ./qat_resnet18_model_fp32.pth

=== QAT ResNet18 ===


AssertionError: Fusion only for eval!

In [266]:
# 필요한 라이브러리 임포트
import os
import functools

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
import torch.ao.quantization as quant
from torch.ao.quantization import (
    QuantStub,
    DeQuantStub,
    prepare_qat,
    convert,
    observer
)
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

# --- 1. 하이퍼파라미터 및 장치 설정 ---
batch_size = 64
lr = 1e-4
num_epoch = 5      # FP32 학습 에폭
qat_epoch = 5      # QAT 학습 에폭
save_path = "./qat_resnet18_model.pth"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# quantization 백엔드 설정
torch.backends.quantized.engine = 'fbgemm'

# --- 2. 데이터 로드 및 전처리 ---
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std= [0.229, 0.224, 0.225]
    )
])

transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std= [0.229, 0.224, 0.225]
    )
])

train_data = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_data  = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

subset_size_train = 5000
subset_size_test  = 1000
train_subset = torch.utils.data.Subset(train_data, torch.randperm(len(train_data))[:subset_size_train])
test_subset  = torch.utils.data.Subset(test_data,  torch.randperm(len(test_data))[:subset_size_test])

train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True,  num_workers=2)
test_loader  = DataLoader(test_subset,  batch_size=batch_size, shuffle=False, num_workers=2)

print(f"Train batches: {len(train_loader)}, Test batches: {len(test_loader)}")

# --- 3. 모델 정의 및 QAT 준비 함수들 ---
def get_resnet18(num_classes=10, pretrained=True):
    """ResNet18 로드 & 마지막 FC 레이어를 num_classes에 맞춰 교체"""
    # PyTorch 2.x API
    weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    model = models.resnet18(weights=weights)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

class QuantizedResNet18(nn.Module):
    """QuantStub/DeQuantStub 래퍼 + fuse 처리"""
    def __init__(self, fp32_model: nn.Module):
        super().__init__()
        # 1) 기존 FP32 모델
        self.model = fp32_model
        # 2) fake-quant stub
        self.quant   = QuantStub()
        self.dequant = DeQuantStub()

        # 3) fuse를 위해 eval 모드로 전환 → fuse → train 모드 복귀
        self.model.eval()
        self._fuse_modules()
        self.model.train()

    def _fuse_modules(self):
        # 최상위 conv1, bn1, relu
        quant.fuse_modules(self.model, [['conv1', 'bn1', 'relu']], inplace=True)

        # BasicBlock 내부 fuse (conv1+bn1+relu, conv2+bn2) 및 downsample
        for module in self.model.modules():
            if isinstance(module, models.resnet.BasicBlock):
                quant.fuse_modules(module, [['conv1', 'bn1', 'relu']], inplace=True)
                quant.fuse_modules(module, [['conv2', 'bn2']], inplace=True)
                if module.downsample is not None:
                    # downsample: [Conv, BN]
                    quant.fuse_modules(module.downsample, ['0', '1'], inplace=True)

    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)
        x = self.dequant(x)
        return x

# --- 4. 학습/평가 루프 정의 ---
def train_mode(model, loader, criterion, optimizer, epochs, device, name="Model"):
    model.train()
    print(f"\n--- {name} Training ---")
    for epoch in range(epochs):
        total_loss = 0.0
        correct, total = 0, 0
        for data, target in tqdm(loader, desc=f"{name} Epoch {epoch+1}/{epochs}"):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            preds = outputs.argmax(dim=1)
            correct += (preds == target).sum().item()
            total   += target.size(0)

        print(f"[{name}] Epoch {epoch+1}: "
              f"Loss={total_loss/len(loader):.4f}, "
              f"Acc={100*correct/total:.2f}%")
    print(f"{name} training done.")

def eval_mode(model, loader, device, name="Model"):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in tqdm(loader, desc=f"Evaluating {name}"):
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            preds = outputs.argmax(dim=1)
            correct += (preds == target).sum().item()
            total   += target.size(0)
    acc = 100 * correct / total
    print(f"[{name}] Test Accuracy: {acc:.2f}%")
    return acc

# --- 5. FP32 ResNet18 학습 및 평가 ---
print("\n=== FP32 ResNet18 ===")
fp32_model = get_resnet18().to(device)
criterion = nn.CrossEntropyLoss()
optimizer_fp32 = optim.Adam(fp32_model.parameters(), lr=lr)

train_mode(fp32_model, train_loader, criterion, optimizer_fp32, num_epoch, device, name="FP32")
fp32_acc = eval_mode(fp32_model, test_loader, device, name="FP32")

fp32_path = save_path.replace(".pth", "_fp32.pth")
torch.save(fp32_model.state_dict(), fp32_path)
print(f"Saved FP32 model → {fp32_path}")

# --- 6. QAT 준비 및 학습 ---
print


Using device: cuda
Train batches: 79, Test batches: 16

=== FP32 ResNet18 ===

--- FP32 Training ---


FP32 Epoch 1/5:   0%|          | 0/79 [00:00<?, ?it/s]

[FP32] Epoch 1: Loss=1.4009, Acc=50.76%


FP32 Epoch 2/5:   0%|          | 0/79 [00:00<?, ?it/s]

[FP32] Epoch 2: Loss=0.8908, Acc=69.28%


FP32 Epoch 3/5:   0%|          | 0/79 [00:00<?, ?it/s]

[FP32] Epoch 3: Loss=0.7633, Acc=73.28%


FP32 Epoch 4/5:   0%|          | 0/79 [00:00<?, ?it/s]

[FP32] Epoch 4: Loss=0.6708, Acc=77.16%


FP32 Epoch 5/5:   0%|          | 0/79 [00:00<?, ?it/s]

[FP32] Epoch 5: Loss=0.6140, Acc=78.90%
FP32 training done.


Evaluating FP32:   0%|          | 0/16 [00:00<?, ?it/s]

[FP32] Test Accuracy: 87.00%
Saved FP32 model → ./qat_resnet18_model_fp32.pth


<function print(*args, sep=' ', end='\n', file=None, flush=False)>

In [None]:
# 전체 수정된 QAT 파이프라인 코드
# 커널 폭파범 QAT 파이프라인 코드
import os
import functools

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

# quantization 관련
import torch.ao.quantization as quant
from torch.ao.quantization import (
    QuantStub,
    DeQuantStub,
    prepare_qat,
    convert,
    observer
)
# *** 수정된 부분: FloatFunctional import 경로 ***
from torch.nn.quantized import FloatFunctional

from torchvision.models.resnet import BasicBlock
from tqdm.notebook import tqdm

# --- 1. 하이퍼파라미터 & 디바이스 설정 ---
batch_size  = 16
lr          = 1e-4
num_epoch   = 2    # FP32 학습 에폭
qat_epoch   = 2    # QAT 학습 에폭
save_path   = "./qat_resnet18_model.pth"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# INT8 변환에 사용할 백엔드
torch.backends.quantized.engine = 'fbgemm'

# --- 2. 데이터 준비 ---
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])
transform_test = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])

train_dataset = datasets.CIFAR10('./data', train=True,  download=True, transform=transform_train)
test_dataset  = datasets.CIFAR10('./data', train=False, download=True, transform=transform_test)

# 연습용으로 subset
train_subset = torch.utils.data.Subset(train_dataset, torch.randperm(len(train_dataset))[:5000])
test_subset  = torch.utils.data.Subset(test_dataset,  torch.randperm(len(test_dataset))[:1000])

train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True,  num_workers=2)
test_loader  = DataLoader(test_subset,  batch_size=batch_size, shuffle=False, num_workers=2)

print(f"Train batches: {len(train_loader)}, Test batches: {len(test_loader)}")


# --- 3. 모델 정의 Helpers ---
def get_resnet18(num_classes=10, pretrained=True):
    """ResNet18 불러와서 마지막 FC 교체"""
    weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    model = models.resnet18(weights=weights)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

class QBasicBlock(nn.Module):
    """BasicBlock에서 skip connection에 FloatFunctional을 쓰도록 변경"""
    def __init__(self, orig: BasicBlock):
        super().__init__()
        self.conv1      = orig.conv1
        self.bn1        = orig.bn1
        self.relu       = orig.relu
        self.conv2      = orig.conv2
        self.bn2        = orig.bn2
        self.downsample = orig.downsample
        self.skip_add   = FloatFunctional()

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.skip_add.add(out, identity)
        out = self.relu(out)
        return out

class QuantizedResNet18(nn.Module):
    """QuantStub, DeQuantStub 래핑 + fuse + BasicBlock 교체"""
    def __init__(self, fp32_model: nn.Module):
        super().__init__()
        self.model   = fp32_model
        self.quant   = QuantStub()
        self.dequant = DeQuantStub()

        self.model.eval()
        self._fuse_modules()
        self._replace_basic_blocks()
        self.model.train()

    def _fuse_modules(self):
        quant.fuse_modules(self.model, [['conv1','bn1','relu']], inplace=True)
        for module in self.model.modules():
            if isinstance(module, BasicBlock):
                quant.fuse_modules(module, [['conv1','bn1','relu']], inplace=True)
                quant.fuse_modules(module, [['conv2','bn2']],   inplace=True)
                if module.downsample is not None:
                    quant.fuse_modules(module.downsample, ['0','1'], inplace=True)

    def _replace_basic_blocks(self):
        for name, child in list(self.model.named_children()):
            if isinstance(child, nn.Sequential):
                new_seq = []
                for blk in child:
                    if isinstance(blk, BasicBlock):
                        new_seq.append(QBasicBlock(blk))
                    else:
                        new_seq.append(blk)
                setattr(self.model, name, nn.Sequential(*new_seq))

    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)
        x = self.dequant(x)
        return x


# --- 4. 학습/평가 루프 ---
def train_mode(model, loader, criterion, optimizer, epochs, device, name="Model"):
    model.train()
    print(f"\n--- {name} Training ---")
    for epoch in range(epochs):
        total_loss, correct, total = 0.0, 0, 0
        for data, target in tqdm(loader, desc=f"{name} Epoch {epoch+1}/{epochs}"):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            out = model(data)
            loss = criterion(out, target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            preds = out.argmax(1)
            correct   += (preds == target).sum().item()
            total     += target.size(0)

        print(f"[{name}] Epoch {epoch+1}: "
              f"Loss={total_loss/len(loader):.4f}, "
              f"Acc={100*correct/total:.2f}%")
    print(f"{name} training done.")

def eval_mode(model, loader, device, name="Model"):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in tqdm(loader, desc=f"Evaluating {name}"):
            data, target = data.to(device), target.to(device)
            out = model(data)
            preds = out.argmax(1)
            correct += (preds == target).sum().item()
            total   += target.size(0)
    acc = 100 * correct / total
    print(f"[{name}] Test Accuracy: {acc:.2f}%")
    return acc


# --- 5. FP32 학습 & 저장 ---
print("\n=== FP32 ResNet18 ===")
fp32_model   = get_resnet18().to(device)
criterion    = nn.CrossEntropyLoss()
optimizer_fp = optim.Adam(fp32_model.parameters(), lr=lr)

train_mode(fp32_model, train_loader, criterion, optimizer_fp, num_epoch, device, name="FP32")
fp32_acc = eval_mode(fp32_model, test_loader, device, name="FP32")

fp32_path = save_path.replace(".pth","_fp32.pth")
torch.save(fp32_model.state_dict(), fp32_path)
print(f"Saved FP32 → {fp32_path}")


# --- 6. QAT 준비 & 학습 ---
print("\n=== QAT ResNet18 ===")
qat_fp32 = get_resnet18(pretrained=False)
qat_fp32.load_state_dict(torch.load(fp32_path, map_location='cpu'))

qat_model = QuantizedResNet18(qat_fp32).to('cpu')
print("Fused+wrapped for QAT.")

qat_model.qconfig = quant.QConfig(
    activation=functools.partial(observer.HistogramObserver, reduce_range=True),
    weight=functools.partial(
        observer.PerChannelMinMaxObserver,
        dtype=torch.qint8,
        qscheme=torch.per_channel_symmetric
    )
)
print("QConfig:", qat_model.qconfig)

prepare_qat(qat_model, inplace=True)
print("Prepared for QAT.")

qat_model.to(device)
optimizer_qat = optim.Adam(qat_model.parameters(), lr=lr * 0.1)
train_mode(qat_model, train_loader, criterion, optimizer_qat, qat_epoch, device, name="QAT")


# --- 7. INT8 변환 & 평가 (CPU) ---
print("\n--- Converting to INT8 ---")
qat_model.to('cpu')
quantized_model = convert(qat_model, inplace=True)
print("Converted to INT8.")

quant_acc = eval_mode(quantized_model, test_loader, torch.device('cpu'), name="Quantized")


# --- 8. 저장 & 크기 비교 ---
torch.save(quantized_model.state_dict(), save_path)
print(f"Saved Quantized → {save_path}")

#fp32_sz  = os.path.getsize(fp32_path)/(1024**2)
quant_sz = os.path.getsize(save_path)/(1024**2)
#print(f"FP32 size: {fp32_sz:.2f} MB")
print(f"Quant size: {quant_sz:.2f} MB")
#print(f"Size reduction: {fp32_sz/quant_sz:.2f}×")
#print(f"Acc FP32: {fp32_acc:.2f}%, Quant: {quant_acc:.2f}%")


Using device: cuda
Train batches: 313, Test batches: 63

=== FP32 ResNet18 ===

--- FP32 Training ---


FP32 Epoch 1/2:   0%|          | 0/313 [00:00<?, ?it/s]

[FP32] Epoch 1: Loss=1.2930, Acc=54.96%


FP32 Epoch 2/2:   0%|          | 0/313 [00:00<?, ?it/s]

[FP32] Epoch 2: Loss=0.9496, Acc=67.42%
FP32 training done.


Evaluating FP32:   0%|          | 0/63 [00:00<?, ?it/s]

[FP32] Test Accuracy: 84.60%
Saved FP32 → ./qat_resnet18_model_fp32.pth

=== QAT ResNet18 ===
Fused+wrapped for QAT.
QConfig: QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=True), weight=functools.partial(<class 'torch.ao.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric))
Prepared for QAT.

--- QAT Training ---




QAT Epoch 1/2:   0%|          | 0/313 [00:00<?, ?it/s]

[QAT] Epoch 1: Loss=62055.0435, Acc=24.46%


QAT Epoch 2/2:   0%|          | 0/313 [00:00<?, ?it/s]

[QAT] Epoch 2: Loss=2771.6464, Acc=34.64%
QAT training done.

--- Converting to INT8 ---
Converted to INT8.


Evaluating Quantized:   0%|          | 0/63 [00:00<?, ?it/s]

In [None]:
import os
import functools

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

import torch.ao.quantization as quant
from torch.ao.quantization import (
    QuantStub, DeQuantStub,
    prepare_qat, convert, observer
)
from torch.nn.quantized import FloatFunctional
from torchvision.models.resnet import BasicBlock
from tqdm.auto import tqdm

# --- 1. 하이퍼파라미터 & 디바이스 ---
# **수정: 배치 사이즈를 더 적극적으로 줄였습니다.**
# 문제가 지속되면 8, 4까지도 시도해 보세요.
batch_size  = 16 # 이전 32에서 16으로 변경. 메모리 문제가 가장 유력한 원인.
lr          = 1e-4
num_epoch   = 2 # FP32 학습 에폭 (공부용으로 짧게 유지)
qat_epoch   = 2 # QAT 학습 에폭 (공부용으로 짧게 유지)
save_path   = "./qat_resnet18_model.pth"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)

torch.backends.cudnn.benchmark = True # GPU 사용 시 성능 향상
torch.backends.quantized.engine = 'fbgemm' # 양자화 백엔드 설정 (x86 CPU 환경용)

# --- 2. 데이터 로드 ---
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])
transform_test = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])

train_ds = datasets.CIFAR10("./data", train=True, download=True, transform=transform_train)
test_ds  = datasets.CIFAR10("./data", train=False,download=True, transform=transform_test)

# **수정: num_workers를 1로 줄이고 pin_memory=False로 설정.**
# CPU로 모델을 옮겨 양자화 변환 시, 데이터 로딩으로 인한 메모리 부하를 최소화합니다.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          num_workers=1, pin_memory=False)
test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False,
                          num_workers=1, pin_memory=False)

# --- 3. 모델 및 QuantBlock 정의 ---
def get_resnet18(num_classes=10, pretrained=True):
    weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    m = models.resnet18(weights=weights)
    m.fc = nn.Linear(m.fc.in_features, num_classes)
    return m

class QBasicBlock(nn.Module):
    def __init__(self, orig: BasicBlock):
        super().__init__()
        # fused block 그대로 재사용
        self.conv1      = orig.conv1
        self.bn1        = orig.bn1
        self.relu       = orig.relu
        self.conv2      = orig.conv2
        self.bn2        = orig.bn2
        self.downsample = orig.downsample
        self.skip_add   = FloatFunctional()
    def forward(self, x):
        identity = x
        out = self.conv1(x); out = self.bn1(out); out = self.relu(out)
        out = self.conv2(out); out = self.bn2(out)
        if self.downsample: identity = self.downsample(x)
        out = self.skip_add.add(out, identity)
        out = self.relu(out)
        return out

class QuantizedResNet18(nn.Module):
    def __init__(self, fp32_model: nn.Module):
        super().__init__()
        self.model    = fp32_model
        self.quant    = QuantStub()
        self.dequant = DeQuantStub()
        # fuse 전에 eval → fuse → replace → train
        self.model.eval()
        self._fuse_modules()
        self._replace_blocks()
        self.model.train()
    def _fuse_modules(self):
        quant.fuse_modules(self.model, [['conv1','bn1','relu']], inplace=True)
        for m in self.model.modules():
            if isinstance(m, BasicBlock):
                quant.fuse_modules(m, [['conv1','bn1','relu']], inplace=True)
                quant.fuse_modules(m, [['conv2','bn2']],   inplace=True)
                if m.downsample:
                    quant.fuse_modules(m.downsample, ['0','1'], inplace=True)
    def _replace_blocks(self):
        for name, child in list(self.model.named_children()):
            if isinstance(child, nn.Sequential):
                new_seq = []
                for blk in child:
                    new_seq.append(QBasicBlock(blk) if isinstance(blk, BasicBlock) else blk)
                setattr(self.model, name, nn.Sequential(*new_seq))
    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)
        x = self.dequant(x)
        return x

# --- 4. 학습/평가 루프 정의 ---
# FP32 training with AMP
scaler = torch.cuda.amp.GradScaler()
def train_fp32(model, loader, criterion, optimizer, epochs, device):
    model.train()
    for ep in range(epochs):
        tot_loss, corr, tot = 0,0,0
        for x,y in tqdm(loader, desc=f"FP32 Ep{ep+1}/{epochs}"):
            x,y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                out = model(x)
                loss = criterion(out, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            tot_loss += loss.item()
            preds = out.argmax(1)
            corr    += (preds==y).sum().item()
            tot     += y.size(0)
        print(f"  FP32 Ep{ep+1}: Loss={tot_loss/len(loader):.4f} Acc={100*corr/tot:.2f}%")

# QAT training without AMP (올바른 접근 방식)
def train_qat(model, loader, criterion, optimizer, epochs, device):
    model.train()
    for ep in range(epochs):
        tot_loss, corr, tot = 0,0,0
        for x,y in tqdm(loader, desc=f"QAT Ep{ep+1}/{epochs}"):
            x,y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            tot_loss += loss.item()
            preds = out.argmax(1)
            corr    += (preds==y).sum().item()
            tot     += y.size(0)
        print(f"  QAT Ep{ep+1}: Loss={tot_loss/len(loader):.4f} Acc={100*corr/tot:.2f}%")

def eval_model(model, loader, device, name):
    model.eval()
    corr, tot = 0,0
    with torch.no_grad():
        for x,y in tqdm(loader, desc=f"Eval {name}"):
            x,y = x.to(device), y.to(device) # 데이터를 명시된 device로 이동 (CPU 또는 GPU)
            out = model(x)
            corr += (out.argmax(1)==y).sum().item()
            tot  += y.size(0)
    acc = 100*corr/tot
    print(f"  {name} Accuracy: {acc:.2f}%")
    return acc

# --- 5. FP32 학습 & 평가 ---
print("=== FP32 Training ===")
fp32 = get_resnet18().to(device)
crit = nn.CrossEntropyLoss()
opt_fp = optim.Adam(fp32.parameters(), lr=lr)
train_fp32(fp32, train_loader, crit, opt_fp, num_epoch, device)
fp32_acc = eval_model(fp32, test_loader, device, "FP32")
torch.save(fp32.state_dict(), save_path.replace(".pth","_fp32.pth"))

# --- 6. QAT 준비 & 학습 ---
print("=== QAT Setup ===")
qat_fp32 = get_resnet18(pretrained=False)
qat_fp32.load_state_dict(torch.load(save_path.replace(".pth","_fp32.pth"), map_location='cpu'))
qat_model = QuantizedResNet18(qat_fp32).to('cpu') # QAT 모델은 CPU에서 준비

qat_model.qconfig = quant.QConfig(
    activation=functools.partial(observer.HistogramObserver, reduce_range=True),
    weight=functools.partial(observer.PerChannelMinMaxObserver,
                             dtype=torch.qint8,
                             qscheme=torch.per_channel_symmetric)
)
prepare_qat(qat_model, inplace=True)

print("=== QAT Training ===")
qat_model.to(device) # QAT 학습은 GPU에서 진행
opt_qat = optim.Adam(qat_model.parameters(), lr=lr*0.1)
train_qat(qat_model, train_loader, crit, opt_qat, qat_epoch, device)


# --- 7. INT8 변환 & 평가 ---
print("=== Converting to INT8 ===")
qat_model.to('cpu') # 변환 전 CPU로 이동
if torch.cuda.is_available():
    torch.cuda.empty_cache() # **추가: GPU 캐시 비우기**

# **수정: convert 전에 모델을 eval() 모드로 전환하고, inplace=False로 새로운 모델 객체 생성**
quantized = convert(qat_model.eval(), inplace=False)
# 양자화 모델 평가 시 명시적으로 'cpu' device 사용
quant_acc = eval_model(quantized, test_loader, torch.device('cpu'), "Quantized")

# --- 8. 저장 & 크기 비교 ---
torch.save(quantized.state_dict(), save_path)
print("Sizes:",
      f"FP32 {(os.path.getsize(save_path.replace('.pth','_fp32.pth'))/1e6):.2f}MB →",
      f"Quant {(os.path.getsize(save_path)/1e6):.2f}MB")
print(f"Accuracies: FP32 {fp32_acc:.2f}%  Quant {quant_acc:.2f}%")

Device: cuda


  scaler = torch.cuda.amp.GradScaler()


=== FP32 Training ===


FP32 Ep1/2:   0%|          | 0/3125 [00:02<?, ?it/s]

  with torch.cuda.amp.autocast():


  FP32 Ep1: Loss=0.8852 Acc=69.30%


FP32 Ep2/2:   0%|          | 0/3125 [00:02<?, ?it/s]

  FP32 Ep2: Loss=0.6884 Acc=76.23%


Eval FP32:   0%|          | 0/625 [00:02<?, ?it/s]

  FP32 Accuracy: 90.89%
=== QAT Setup ===




=== QAT Training ===


QAT Ep1/2:   0%|          | 0/3125 [00:02<?, ?it/s]

  QAT Ep1: Loss=2941.6958 Acc=39.58%


QAT Ep2/2:   0%|          | 0/3125 [00:02<?, ?it/s]

  QAT Ep2: Loss=44.7330 Acc=34.67%
=== Converting to INT8 ===


Eval Quantized:   0%|          | 0/625 [00:02<?, ?it/s]

In [None]:
import os
import functools

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

import torch.ao.quantization as quant
from torch.ao.quantization import (
    QuantStub, DeQuantStub,
    prepare_qat, convert, observer
)
from torch.nn.quantized import FloatFunctional
from torchvision.models.resnet import BasicBlock
from tqdm.auto import tqdm

# --- 1. 하이퍼파라미터 & 디바이스 ---
batch_size  = 64
lr          = 1e-4
num_epoch   = 2
qat_epoch   = 2
save_path   = "./qat_resnet18_model.pth"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)

torch.backends.cudnn.benchmark = True
torch.backends.quantized.engine = 'fbgemm'

# --- 2. 데이터 로드 ---
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])
transform_test = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])

train_ds = datasets.CIFAR10("./data", train=True, download=True, transform=transform_train)
test_ds  = datasets.CIFAR10("./data", train=False,download=True, transform=transform_test)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          num_workers=8, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False,
                          num_workers=8, pin_memory=True)

# --- 3. 모델 및 QuantBlock 정의 ---
def get_resnet18(num_classes=10, pretrained=True):
    weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
    m = models.resnet18(weights=weights)
    m.fc = nn.Linear(m.fc.in_features, num_classes)
    return m

class QBasicBlock(nn.Module):
    def __init__(self, orig: BasicBlock):
        super().__init__()
        # fused block 그대로 재사용
        self.conv1      = orig.conv1
        self.bn1        = orig.bn1
        self.relu       = orig.relu
        self.conv2      = orig.conv2
        self.bn2        = orig.bn2
        self.downsample = orig.downsample
        self.skip_add   = FloatFunctional()
    def forward(self, x):
        identity = x
        out = self.conv1(x); out = self.bn1(out); out = self.relu(out)
        out = self.conv2(out); out = self.bn2(out)
        if self.downsample: identity = self.downsample(x)
        out = self.skip_add.add(out, identity)
        out = self.relu(out)
        return out

class QuantizedResNet18(nn.Module):
    def __init__(self, fp32_model: nn.Module):
        super().__init__()
        self.model   = fp32_model
        self.quant   = QuantStub()
        self.dequant = DeQuantStub()
        # fuse 전에 eval → fuse → replace → train
        self.model.eval()
        self._fuse_modules()
        self._replace_blocks()
        self.model.train()
    def _fuse_modules(self):
        quant.fuse_modules(self.model, [['conv1','bn1','relu']], inplace=True)
        for m in self.model.modules():
            if isinstance(m, BasicBlock):
                quant.fuse_modules(m, [['conv1','bn1','relu']], inplace=True)
                quant.fuse_modules(m, [['conv2','bn2']],   inplace=True)
                if m.downsample:
                    quant.fuse_modules(m.downsample, ['0','1'], inplace=True)
    def _replace_blocks(self):
        for name, child in list(self.model.named_children()):
            if isinstance(child, nn.Sequential):
                new_seq = []
                for blk in child:
                    new_seq.append(QBasicBlock(blk) if isinstance(blk, BasicBlock) else blk)
                setattr(self.model, name, nn.Sequential(*new_seq))
    def forward(self, x):
        x = self.quant(x)
        x = self.model(x)
        x = self.dequant(x)
        return x

# --- 4. 학습/평가 루프 정의 ---
# FP32 training with AMP
scaler = torch.cuda.amp.GradScaler()
def train_fp32(model, loader, criterion, optimizer, epochs, device):
    model.train()
    for ep in range(epochs):
        tot_loss, corr, tot = 0,0,0
        for x,y in tqdm(loader, desc=f"FP32 Ep{ep+1}/{epochs}"):
            x,y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                out = model(x)
                loss = criterion(out, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            tot_loss += loss.item()
            preds = out.argmax(1)
            corr    += (preds==y).sum().item()
            tot     += y.size(0)
        print(f"  FP32 Ep{ep+1}: Loss={tot_loss/len(loader):.4f} Acc={100*corr/tot:.2f}%")

# QAT training without AMP
def train_qat(model, loader, criterion, optimizer, epochs, device):
    model.train()
    for ep in range(epochs):
        tot_loss, corr, tot = 0,0,0
        for x,y in tqdm(loader, desc=f"QAT Ep{ep+1}/{epochs}"):
            x,y = x.to(device), y.to(device)
            optimizer.zero_grad()
            out = model(x)               # ← 여기서는 FP32 모드
            loss = criterion(out, y)
            loss.backward()
            optimizer.step()
            tot_loss += loss.item()
            preds = out.argmax(1)
            corr    += (preds==y).sum().item()
            tot     += y.size(0)
        print(f"  QAT Ep{ep+1}: Loss={tot_loss/len(loader):.4f} Acc={100*corr/tot:.2f}%")

def eval_model(model, loader, device, name):
    model.eval()
    corr, tot = 0,0
    with torch.no_grad():
        for x,y in tqdm(loader, desc=f"Eval {name}"):
            x,y = x.to(device), y.to(device)
            out = model(x)
            corr += (out.argmax(1)==y).sum().item()
            tot  += y.size(0)
    acc = 100*corr/tot
    print(f"  {name} Accuracy: {acc:.2f}%")
    return acc

# --- 5. FP32 학습 & 평가 ---
print("=== FP32 Training ===")
fp32 = get_resnet18().to(device)
crit = nn.CrossEntropyLoss()
opt_fp = optim.Adam(fp32.parameters(), lr=lr)
train_fp32(fp32, train_loader, crit, opt_fp, num_epoch, device)
fp32_acc = eval_model(fp32, test_loader, device, "FP32")
torch.save(fp32.state_dict(), save_path.replace(".pth","_fp32.pth"))

# --- 6. QAT 준비 & 학습 ---
print("=== QAT Setup ===")
qat_fp32 = get_resnet18(pretrained=False)
qat_fp32.load_state_dict(torch.load(save_path.replace(".pth","_fp32.pth"), map_location='cpu'))
qat_model = QuantizedResNet18(qat_fp32).to('cpu')

qat_model.qconfig = quant.QConfig(
    activation=functools.partial(observer.HistogramObserver, reduce_range=True),
    weight=functools.partial(observer.PerChannelMinMaxObserver,
                             dtype=torch.qint8,
                             qscheme=torch.per_channel_symmetric)
)
prepare_qat(qat_model, inplace=True)

print("=== QAT Training ===")
qat_model.to(device)
opt_qat = optim.Adam(qat_model.parameters(), lr=lr*0.1)
train_qat(qat_model, train_loader, crit, opt_qat, qat_epoch, device)

# --- 7. INT8 변환 & 평가 ---
print("=== Converting to INT8 ===")
qat_model.to('cpu')
quantized = convert(qat_model, inplace=True)
quant_acc = eval_model(quantized, test_loader, torch.device('cpu'), "Quantized")

# --- 8. 저장 & 크기 비교 ---
torch.save(quantized.state_dict(), save_path)
print("Sizes:",
      f"FP32 {(os.path.getsize(save_path.replace('.pth','_fp32.pth'))/1e6):.2f}MB →",
      f"Quant {(os.path.getsize(save_path)/1e6):.2f}MB")
print(f"Accuracies: FP32 {fp32_acc:.2f}%  Quant {quant_acc:.2f}%")


Device: cuda


  scaler = torch.cuda.amp.GradScaler()


=== FP32 Training ===


FP32 Ep1/2:   0%|          | 0/782 [00:22<?, ?it/s]

  with torch.cuda.amp.autocast():


  FP32 Ep1: Loss=0.7929 Acc=72.35%


FP32 Ep2/2:   0%|          | 0/782 [00:22<?, ?it/s]

  FP32 Ep2: Loss=0.5723 Acc=80.07%


Eval FP32:   0%|          | 0/157 [00:20<?, ?it/s]

  FP32 Accuracy: 91.88%
=== QAT Setup ===
=== QAT Training ===




QAT Ep1/2:   0%|          | 0/782 [00:22<?, ?it/s]

  QAT Ep1: Loss=8463.5545 Acc=34.96%


QAT Ep2/2:   0%|          | 0/782 [00:22<?, ?it/s]

  QAT Ep2: Loss=698.6713 Acc=42.55%
=== Converting to INT8 ===


Eval Quantized:   0%|          | 0/157 [00:20<?, ?it/s]

In [2]:
!watch -n 0.5 nvidia-smi


'watch' is not recognized as an internal or external command,
operable program or batch file.
