# Train, Prune, and Quantize

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import helper
from squeezenet_model import SqueezeNetCIFAR10, SqueezeNetCIFAR10_QAT
from alexnet_model import AlexNetCIFAR10, AlexNetCIFAR10_QAT
from resnet32_model import ResNet, ResNetQAT

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device={device}")

Device=cuda


In [2]:
# Architecture to run the notebook for; switch by uncommenting one of the alternatives.
mname = "squeezenet"
# mname = "alexnet"
# mname = "resnet"

# Each supported name maps to (FP32 model constructor, QAT model constructor).
model_factories = {
    "squeezenet": (SqueezeNetCIFAR10, SqueezeNetCIFAR10_QAT),
    "alexnet": (AlexNetCIFAR10, AlexNetCIFAR10_QAT),
    "resnet": (ResNet, ResNetQAT),
}

# Fail immediately on a typo instead of leaving get_model / get_model_qat
# undefined (or stale from an earlier run) for later cells to trip over.
if mname not in model_factories:
    raise ValueError(
        f"Unknown model name {mname!r}; expected one of {sorted(model_factories)}"
    )

get_model, get_model_qat = model_factories[mname]

In [None]:
# Mini-batch size handed to helper.load_dataset, which returns the train and test loaders.
BATCH_SIZE = 128
train_loader, test_loader = helper.load_dataset(batch_size=BATCH_SIZE)

In [None]:
# model_fp32 = get_model()
# # model_fp32.load_model('squeezenet_bn_cifar10_fp32.pth')

# total_params = sum(p.numel() for p in model_fp32.parameters())
# print(f"Total parameters: {total_params}")

# trainable_params = sum(p.numel() for p in model_fp32.parameters() if p.requires_grad)
# print(f"Trainable parameters: {trainable_params}")

Total parameters: 734986
Trainable parameters: 734986


## Training

In [None]:
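# NOTE: FP32 training is currently disabled (commented out). The Pruning section
# loads f"{mname}_fp32.pth", so that checkpoint must already exist on disk
# before the rest of the notebook is run.
# train, test = True, True
# epochs = 100
# fp32_metrics = helper.train_model(model=model_fp32,train_loader=train_loader,test_loader=test_loader,train=train,test=test,device=device,epochs=epochs)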

In [None]:
# # model_fp32.save_model(f"{mname}_fp32.pth")
# torch.save(model_fp32.state_dict(), f"{mname}_fp32.pth")

In [None]:
# helper.plot_metrics(fp32_metrics)

In [None]:
# helper.plot_weight_histogram(model_fp32)

## Pruning

In [None]:
# Rebuild the selected architecture and restore its saved FP32 weights;
# this model is the starting point that every pruning level copies.
fp32_checkpoint = f"{mname}_fp32.pth"
model = get_model()
model.load_model(fp32_checkpoint, device=device)
model.train()

In [None]:
import copy
import torch.nn.utils.prune as prune

# (fraction of weights to prune, suffix used in the checkpoint filename)
prune_levels = [(0.1,'10'), (0.3,'30'), (0.5,'50'), (0.7,'70')]
# One (fraction, training metrics) entry per pruning level, consumed by the plotting cell.
results = []

# Highest evaluation score seen so far and the checkpoint that produced it;
# best_name is what the QAT section later loads.
best = 0.0
best_name = ""

for p, pname in prune_levels:
    # Every level starts from an untouched copy of the loaded baseline model.
    m = copy.deepcopy(model)

    # Prune the weight tensor of every convolutional and fully connected layer.
    parameters_to_prune = [
        (module, "weight")
        for module in m.modules()
        if isinstance(module, (nn.Conv2d, nn.Linear))
    ]

    # Global L1 pruning: the fraction `p` is applied across all listed tensors
    # together rather than to each layer separately.
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=p,
    )

    # Fine-tune with the masks still attached (no epochs passed, so
    # helper.train_model's own default applies).
    metrics = helper.train_model(
        model=m,
        train_loader=train_loader,
        test_loader=test_loader,
        device=device,
    )

    # Make the pruning permanent on exactly the tensors that were pruned, so the
    # saved state_dict holds plain "weight" entries.
    for module, param_name in parameters_to_prune:
        prune.remove(module, param_name)

    fname = f"{mname}_{pname}.pth"
    # m.save_model(fname)
    torch.save(m.state_dict(), fname)

    # Track the checkpoint with the highest score returned by helper.evaluate.
    acc = helper.evaluate(m, test_loader, device)
    if acc > best:
        best, best_name = acc, fname

    results.append((p, metrics))

In [None]:
# Show the fine-tuning curves recorded for each pruning level, in the order they were run.
for amount, history in results:
    print(f"Metrics for pruning with p={amount}")
    helper.plot_metrics(history)

## Quantization Aware Training

In [None]:
import torch
from torch.ao.quantization import QConfigMapping, get_default_qat_qconfig
from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

# The pruning cell leaves best_name empty when no pruned checkpoint scored above
# 0.0; stop here with a clear message instead of failing inside torch.load("").
if not best_name:
    raise RuntimeError(
        "No pruned checkpoint was selected; run the pruning cell before the QAT section."
    )

model_qat = get_model_qat()
# model_qat.load_model(best_name, device='cpu')
# map_location="cpu" keeps the tensors on the CPU, where the model is traced and
# prepared, and lets a checkpoint saved from a GPU run load without CUDA.
state_dict = torch.load(best_name, map_location="cpu")
model_qat.load_state_dict(state_dict)
# NOTE(review): PyTorch's prepare_qat_fx example prepares a model that is in
# train mode; confirm that calling eval() before preparation is intentional.
model_qat.eval()

# example input for FX tracing
example_inputs = (torch.randn(1, 3, 32, 32, device='cpu'),)

# Apply the library's default 'fbgemm' QAT qconfig to the whole model. A
# QConfigMapping replaces the deprecated raw {"": qconfig} dictionary.
qconfig_mapping = QConfigMapping().set_global(get_default_qat_qconfig('fbgemm'))

# prepare the model for QAT
model_qat_prepared = prepare_qat_fx(model_qat, qconfig_mapping, example_inputs=example_inputs)

# Move the prepared model to the training device and switch it to train mode.
model_qat_prepared.to(device)
model_qat_prepared.train()

In [None]:
# Quantization-aware fine-tuning of the prepared model for 20 epochs;
# the returned metrics are plotted in the next cell.
qat_metrics = helper.train_model(
    model=model_qat_prepared,
    train_loader=train_loader,
    test_loader=test_loader,
    device=device,
    epochs=20,
)

In [None]:
# Plot the metrics returned by the QAT training run.
helper.plot_metrics(qat_metrics)

In [None]:
# Plot the weight histogram of the QAT-prepared (not yet converted) model.
helper.plot_weight_histogram(model_qat_prepared)

In [16]:
# Put the QAT-prepared model in eval mode and save its state_dict before any
# int8 conversion (the conversion lines below are left disabled).
preconvert_path = f"{mname}_qat_preconvert.pth"
model_qat_prepared.eval()
torch.save(model_qat_prepared.state_dict(), preconvert_path)
# model_int8 = convert_fx(model_qat_prepared.cpu())
# torch.save(model_int8.state_dict(), "squeezenet_int8_qat.pth")