In [53]:
import torch
from torch.utils.data import DataLoader, Dataset, random_split
import torchvision
import torchvision.transforms as transforms
from torchvision import models
from torch.optim.lr_scheduler import CosineAnnealingLR
import torchvision.models as models
from tqdm import tqdm
import random
import numpy as np
import os
import torchao
from copy import deepcopy
import time
from torchao.quantization import (
    quantize_,
    Int8WeightOnlyConfig,
    Int4WeightOnlyConfig
)
from torchao.quantization.qat import QATConfig
from torchao.dtypes import Int4CPULayout
from torch import nn
import torchao

from torchao.quantization import Int8DynamicActivationInt8WeightConfig

In [54]:
NUM_EPOCHS = 5
LEARNING_RATE = 0.001
BATCH_SIZE = 128
NUM_WORKERS = 0

In [55]:
def train_epoch(model, loader, criterion, optimizer, device):
  model.to(device)
  model.train()
  running_loss, correct, total = 0.0, 0, 0

  for inputs, labels in tqdm(loader):
    inputs, labels = inputs.to(device), labels.to(device)

    optimizer.zero_grad()
    logits = model(inputs)              # forward pass
    loss = criterion(logits, labels)
    loss.backward()
    optimizer.step()

    running_loss += loss.item() * inputs.size(0)
    _, predicted = torch.max(logits.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

  avg_loss = running_loss / total
  accuracy = 100. * correct / total
  return avg_loss, accuracy


def evaluate(model, test_loader, device):
  model.eval()
  model.to(device)

  correct = 0
  total = 0

  with torch.no_grad():
    for images, labels in tqdm(test_loader):
      images = images.to(device)
      labels = labels.to(device)

      outputs = model(images)
      _, predicted = torch.max(outputs.data, 1)

      total += labels.size(0)
      correct += (predicted == labels).sum().item()

  accuracy = 100 * correct / total
  return accuracy

In [56]:
def fix_random_seed(seed=42):
  random.seed(seed)
  np.random.seed(seed)
  os.environ["PYTHONHASHSEED"] = str(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
  torch.backends.cudnn.deterministic = True
  torch.backends.cudnn.benchmark = False
  print(f"Fixed random seed: {seed}")

fix_random_seed(42)

# For deterministic DataLoader behavior
def seed_worker(worker_id):
  worker_seed = torch.initial_seed() % 2**32
  np.random.seed(worker_seed)
  random.seed(worker_seed)

def get_model_size(model, path="/tmp/temp_model.pt"):
  """
  Save a PyTorch model temporarily and return its size in MB.

  Args:
      model: PyTorch model (nn.Module)
      path: Temporary file path to save the model

  Returns:
      Size of the model in MB (float)
  """
  torch.save(model, path)
  size_mb = os.path.getsize(path) / 1024 / 1024
  return size_mb

Fixed random seed: 42


In [57]:
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

g = torch.Generator()
g.manual_seed(42)

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.9, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

full_train = torchvision.datasets.CIFAR100(root='./data', train=True, download=True)
test_dataset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True)
train_subset, val_subset = random_split(full_train, [45000, 5000], generator=g)

train_subset.dataset.transform = train_transform
val_subset.dataset.transform = test_transform
test_dataset.transform = test_transform

print("Training set size:", len(train_subset))
print("Validation set size:", len(val_subset))
print("Test set size:", len(test_dataset))

def get_loader(dataset, shuffle):
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        worker_init_fn=seed_worker,
        generator=g
    )

train_loader = get_loader(train_subset, shuffle=True)
val_loader = get_loader(val_subset, shuffle=False)
test_loader = get_loader(test_dataset, shuffle=False)

Training set size: 45000
Validation set size: 5000
Test set size: 10000


In [58]:
model_base = models.vgg11(weights=models.VGG11_Weights.IMAGENET1K_V1)
model_base.classifier[6] = torch.nn.Linear(4096, 100)
model_base.classifier[5] = torch.nn.Dropout(p=0.5) # Dropout
device = "cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu"
model_base = model_base.to(device)

# base_lr = 0.04
# weight_decay = 5e-4
# num_epochs = 6
# optim = torch.optim.SGD(
#     model_base.parameters(),
#     lr=base_lr,
#     momentum=0.9,
#     weight_decay=weight_decay
# )

# scheduler = CosineAnnealingLR(optim, T_max=num_epochs, eta_min=1e-5)

# criterion = torch.nn.CrossEntropyLoss()

In [59]:
model = deepcopy(model_base)
model.load_state_dict(torch.load("vgg11_cifar100_baseline_5e.pt", map_location=device))

<All keys matched successfully>

In [60]:
# Alternative: Manually quantize specific layers
def apply_mixed_precision_simple(model):
    """
    Apply INT8 to specific linear layers only
    """
    # Quantize classifier[0] (first linear layer)
    temp_0 = nn.Sequential(model.classifier[0])
    quantize_(temp_0, Int8WeightOnlyConfig())
    model.classifier[0] = temp_0[0]
    print("Quantized classifier[0] to INT8")
    
    # Quantize classifier[3] (second linear layer)
    temp_3 = nn.Sequential(model.classifier[3])
    quantize_(temp_3, Int8WeightOnlyConfig())
    model.classifier[3] = temp_3[0]
    print("Quantized classifier[3] to INT8")
    
    # Keep classifier[6] as FP32 (don't quantize)
    print("Kept classifier[6] as FP32")
    
    # Keep all conv layers as FP32 (don't quantize features)
    print("Kept all conv layers as FP32")
    
    return model

model_mixed = deepcopy(model)
model_mixed.to(device)
model_mixed = apply_mixed_precision_simple(model_mixed)
# model_mixed = torch.compile(model_mixed, mode='max-autotune')

for name, layer in model_mixed.named_modules():
    if isinstance(layer, torch.nn.Linear):
        # Check if weight has tensor_impl (quantized) or not (FP32)
        if hasattr(layer.weight, 'tensor_impl'):
            print(f"{name}: INT8 (quantized) - {layer.weight.tensor_impl.dtype}")
        else:
            print(f"{name}: FP32 (not quantized) - {layer.weight.dtype}")
    elif isinstance(layer, torch.nn.Conv2d):
        print(f"{name}: FP32 (conv layer) - {layer.weight.dtype}")

correct = 0
total = 0

start_time = time.time()

with torch.no_grad():
    for inputs, labels in tqdm(test_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model_mixed(inputs)

        _, predicted = torch.max(outputs.data, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

end_time = time.time()
inf_time = end_time - start_time

test_acc = 100 * correct / total
model_size = get_model_size(model_mixed, "vgg11_mixed.pt")
print(f"\n Test Accuracy = {test_acc:.2f}, Inference Time = {inf_time}, Size = {model_size:.2f} MB")

Quantized classifier[0] to INT8
Quantized classifier[3] to INT8
Kept classifier[6] as FP32
Kept all conv layers as FP32
features.0: FP32 (conv layer) - torch.float32
features.3: FP32 (conv layer) - torch.float32
features.6: FP32 (conv layer) - torch.float32
features.8: FP32 (conv layer) - torch.float32
features.11: FP32 (conv layer) - torch.float32
features.13: FP32 (conv layer) - torch.float32
features.16: FP32 (conv layer) - torch.float32
features.18: FP32 (conv layer) - torch.float32
classifier.0: INT8 (quantized) - torch.int8
classifier.3: INT8 (quantized) - torch.int8
classifier.6: FP32 (not quantized) - torch.float32


100%|██████████| 79/79 [01:03<00:00,  1.24it/s]



 Test Accuracy = 67.77, Inference Time = 63.709638833999634, Size = 150.88 MB


In [None]:
def compute_activation_variance(model, data_loader, device):
    """
    Compute per-layer activation variance using forward hooks.
    """
    activation_variances = {}

    def hook_fn(name):
        def hook(module, input, output):
            # Compute variance of activations (flattened)
            with torch.no_grad():
                var = output.detach().float().var().item()
                if name in activation_variances:
                    activation_variances[name].append(var)
                else:
                    activation_variances[name] = [var]
        return hook

    hooks = []
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            hooks.append(module.register_forward_hook(hook_fn(name)))

    # Run a few batches through to collect stats
    model.eval()
    with torch.no_grad():
        for i, (inputs, _) in enumerate(data_loader):
            inputs = inputs.to(device)
            model(inputs)
            if i >= 5:  # use a few batches for estimation
                break

    # Average the variances for each layer
    for name in activation_variances:
        activation_variances[name] = sum(activation_variances[name]) / len(activation_variances[name])

    # Remove hooks
    for h in hooks:
        h.remove()

    return activation_variances

def get_module_and_parent(root: nn.Module, dotted_name: str):
    """
    Return (parent_module, child_name, child_module)
    For dotted_name "a.b.c" returns (module_at_a.b, "c", module_at_a.b.c)
    If dotted_name has no dot, parent is root and child_name is dotted_name.
    """
    parts = dotted_name.split('.')
    if len(parts) == 1:
        parent = root
        child_name = parts[0]
        child = parent._modules.get(child_name, None)
        return parent, child_name, child

    parent = root
    for p in parts[:-1]:
        parent = parent._modules.get(p)
        if parent is None:
            return None, None, None
    child_name = parts[-1]
    child = parent._modules.get(child_name, None)
    return parent, child_name, child

def set_submodule_by_name(root: nn.Module, dotted_name: str, new_module: nn.Module, device=None):
    """
    Replace submodule located at dotted_name with new_module.
    Moves new_module to device if provided.
    """
    parent, child_name, child = get_module_and_parent(root, dotted_name)
    if parent is None:
        raise KeyError(f"Parent for '{dotted_name}' not found.")
    if device is None:
        for p in parent.parameters(recurse=False):
            device = p.device
            break
    if device is not None:
        new_module.to(device)
    parent._modules[child_name] = new_module

class FP16Wrapper(nn.Module):
    """Wrap a module to run internal computation in float16 while exposing float32 interface."""
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = deepcopy(module).half()

    def forward(self, x):
        if x.dtype != torch.float16:
            x_h = x.to(torch.float16)
        else:
            x_h = x
        out_h = self.module(x_h)
        if out_h.dtype != torch.float32:
            return out_h.to(torch.float32)
        return out_h

def apply_adaptive_mixed_precision_fixed(model: nn.Module, act_variances: dict,
                                         device=None, keep_conv_as_fp32=True,
                                         low_pct=30, high_pct=70):
    """
    Adaptive precision selection with correct submodule replacement.
    - low_pct / high_pct define percentile thresholds.
    - keep_conv_as_fp32: if False, convs will be considered for FP16/INT8 as well.
    """
    vars_list = list(act_variances.values()) if len(act_variances) > 0 else [0.0]
    low_var_threshold = float(np.percentile(vars_list, low_pct))
    high_var_threshold = float(np.percentile(vars_list, high_pct))

    print(f"Adaptive thresholds (percentiles {low_pct}/{high_pct}):"
          f" low={low_var_threshold:.6f}, high={high_var_threshold:.6f}")

    named_modules = list(model.named_modules())

    for name, module in named_modules:
        if name == '':
            continue

        if isinstance(module, nn.Linear) or (not keep_conv_as_fp32 and isinstance(module, nn.Conv2d)):
            var = float(act_variances.get(name, 0.0))
            if var < low_var_threshold:
                q_temp = deepcopy(module)
                q_temp_seq = nn.Sequential(q_temp)
                quantize_(q_temp_seq, Int8WeightOnlyConfig())
                new_mod = q_temp_seq[0]
                print(f"{name}: -> INT8 (var={var:.6f} < {low_var_threshold:.6f})")

            elif var < high_var_threshold:
                new_mod = FP16Wrapper(module)
                print(f"{name}: -> FP16 (var={var:.6f} between {low_var_threshold:.6f} and {high_var_threshold:.6f})")

            else:
                new_mod = deepcopy(module).float()
                print(f"{name}: keep FP32 (var={var:.6f} >= {high_var_threshold:.6f})")

            try:
                set_submodule_by_name(model, name, new_mod, device=device)
            except KeyError as e:
                print("WARNING: could not set submodule:", name, e)

        else:
            if isinstance(module, nn.Conv2d):
                print(f"{name}: kept Conv2d FP32")

    return model

def dump_layer_dtypes(root: nn.Module, filter_prefix=None):
    print("\nLayer weight dtypes (first 200 layers):")
    for i, (name, m) in enumerate(root.named_modules()):
        if i > 200:
            break
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            w = getattr(m, "weight", None)
            if w is not None:
                print(f"{name:40s} : weight dtype={w.dtype}  module_type={type(m).__name__}")

model_adaptive = deepcopy(model)
model_adaptive.to(device)
model_adaptive.eval()

print("Collecting activation variance per layer...")
act_variances = compute_activation_variance(model_adaptive, test_loader, device)

print("Applying adaptive FP32/FP16/INT8 quantization (fixed replacer)...")
model_adaptive = apply_adaptive_mixed_precision_fixed(
    model_adaptive, act_variances, device=device, keep_conv_as_fp32=True)

dump_layer_dtypes(model_adaptive)

correct = 0
total = 0
start_time = time.time()

with torch.no_grad():
    for inputs, labels in tqdm(test_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model_adaptive(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

end_time = time.time()
inf_time = end_time - start_time

test_acc = 100 * correct / total
model_size = get_model_size(model_adaptive, "vgg11_adaptive_fp16_int8.pt")

print(f"\nAdaptive Mixed Precision (FP32/FP16/INT8): "
      f"Accuracy = {test_acc:.2f}%, Inference Time = {inf_time:.2f}s, Size = {model_size:.2f} MB")


Collecting activation variance per layer...
Applying adaptive FP32/FP16/INT8 quantization (fixed replacer)...
Adaptive thresholds (percentiles 30/70): low=3.450456, high=10.158555
features.0: kept Conv2d FP32
features.3: kept Conv2d FP32
features.6: kept Conv2d FP32
features.8: kept Conv2d FP32
features.11: kept Conv2d FP32
features.13: kept Conv2d FP32
features.16: kept Conv2d FP32
features.18: kept Conv2d FP32
classifier.0: -> FP16 (var=3.450456 between 3.450456 and 10.158555)
classifier.3: -> INT8 (var=0.794579 < 3.450456)
classifier.6: keep FP32 (var=25.867520 >= 10.158555)

Layer weight dtypes (first 200 layers):
features.0                               : weight dtype=torch.float32  module_type=Conv2d
features.3                               : weight dtype=torch.float32  module_type=Conv2d
features.6                               : weight dtype=torch.float32  module_type=Conv2d
features.8                               : weight dtype=torch.float32  module_type=Conv2d
features.11   

100%|██████████| 79/79 [05:23<00:00,  4.09s/it]



Adaptive Mixed Precision (FP32/FP16/INT8): Accuracy = 71.01%, Inference Time = 323.41s, Size = 248.82 MB
