In [None]:
import numpy as np
import torch

# quantize weights for PTQ
def uniform_symmetric_quant(t, b):
    """Fake-quantize ``t`` with symmetric, per-tensor, max-abs uniform quantization.

    The scale is ``s = max(|t|) / (2**(b-1) - 1)``. Values are mapped to the integer
    grid ``round(t / s)`` (round-half-to-even) and back to floats. The result
    therefore has the same shape/dtype/device as ``t``, but only takes values in
    ``{k * s : |k| <= 2**(b-1) - 1}`` (e.g. ``{-s, 0, s}`` for ``b == 2``).

    Args:
        t: Tensor to quantize (e.g. a layer weight). It is detached and never
            modified in place.
        b: Bit width. Must be >= 2. With ``b == 1`` the grid size ``2**0 - 1`` is 0
            and the scale would divide by zero.

    Returns:
        A new tensor, detached from autograd, with the same dtype and device as ``t``.

    Raises:
        ValueError: if ``b < 2``.
    """
    if b < 2:
        raise ValueError(f"bit width must be >= 2 for symmetric quantization, got {b}")

    tensor = t.detach()
    if tensor.numel() == 0:
        return tensor.clone()

    qmax = 2 ** (b - 1) - 1  # narrow symmetric range [-qmax, qmax]
    max_abs = tensor.abs().max()
    if max_abs == 0:
        # All-zero input: the scale would be 0 and t / s would be NaN.
        # Zeros quantize exactly to zeros.
        return torch.zeros_like(tensor)

    scale = max_abs / qmax
    # torch.round, like np.round, rounds exact halves to the nearest even integer.
    return torch.round(tensor / scale) * scale

# transfer weights
def transfer_weights(quant_model, input_dim, width, layer_names=("fc1", "fc2", "fc3", "fc4", "fc5")):
    """Copy a trained QAT model into two unquantized ``QuantNet`` models for comparison.

    Both returned models are built with ``weight_quantizer=None`` and receive the
    biases of ``quant_model`` unchanged. They differ only in their weights:

    * ``model_latent_weights`` gets each layer's ``weight`` parameter, i.e. the
      latent full-precision weights updated by the optimizer.
    * ``model_quant_weights`` gets ``layer.quant_weight().tensor``. These are
      presumably the fake-quantized (dequantized float) weights used in the QAT
      forward pass; confirm against the installed brevitas version.

    Args:
        quant_model: Trained model exposing the layers in ``layer_names``, each with
            ``weight``, ``bias`` and ``quant_weight()``.
        input_dim: Input dimension forwarded to ``QuantNet``; must match ``quant_model``.
        width: Hidden width forwarded to ``QuantNet``; must match ``quant_model``.
        layer_names: Attribute names of the layers to copy.

    Returns:
        Tuple ``(model_latent_weights, model_quant_weights)``.
    """
    model_latent_weights = QuantNet(input_dim=input_dim, width=width, weight_quantizer=None)
    model_quant_weights = QuantNet(input_dim=input_dim, width=width, weight_quantizer=None)

    for layer_name in layer_names:
        source = getattr(quant_model, layer_name)
        latent_layer = getattr(model_latent_weights, layer_name)
        quant_layer = getattr(model_quant_weights, layer_name)

        # .clone() so the new models own their storage. A bare .detach() shares
        # memory with quant_model's parameters, so an in-place update to either
        # model would silently change the other.
        latent_layer.weight.data = source.weight.detach().clone()
        latent_layer.bias.data = source.bias.detach().clone()

        quant_layer.weight.data = source.quant_weight().tensor.detach().clone()
        quant_layer.bias.data = source.bias.detach().clone()

    return model_latent_weights, model_quant_weights

def PTQ(baseline_model, b, layer_names=("fc1", "fc2", "fc3", "fc4", "fc5")):
    """Post-training-quantize the weights of ``baseline_model`` in place.

    Each listed layer's ``weight`` is replaced by
    ``uniform_symmetric_quant(weight, b)`` (per-tensor, symmetric, max-abs scale).
    Biases are left in full precision.

    Args:
        baseline_model: Model to modify in place. Callers that still need the
            unquantized model should pass a copy (e.g. ``copy.deepcopy``).
        b: Bit width passed to ``uniform_symmetric_quant``.
        layer_names: Attribute names of the layers whose weights are quantized.

    Returns:
        ``baseline_model`` itself (the same, now-mutated object), for chaining.
    """
    for layer_name in layer_names:
        layer = getattr(baseline_model, layer_name)
        layer.weight.data = uniform_symmetric_quant(layer.weight, b)
    return baseline_model

In [None]:
import matplotlib.pyplot as plt 
import torch
import copy
import torch.nn as nn
from torch_cka import CKA
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import brevitas.nn as qnn
from brevitas.quant import Int8WeightPerTensorFloat
from brevitas.core.zero_point import ZeroZeroPoint
from brevitas.inject.enum import *
from src.util.util import train, test, build_dataloaders, QuantNet 


# Module-level classification loss (mean-reduced cross-entropy).
# Note: run_experiment creates its own local `criterion`, so it does not use this one.
criterion = nn.CrossEntropyLoss(reduction="mean")

def _make_weight_quantizer(bits, round_mode):
    """Build a brevitas weight-quantizer class: signed, narrow-range, per-tensor, absmax-scaled.

    ``round_mode == "stochastic"`` selects stochastic rounding; any other value
    selects round-to-nearest.
    """
    # Resolved outside the class body. A name assigned inside a class body is looked
    # up in the class namespace/globals rather than the enclosing function, so the
    # right-hand-side names below must differ from the attribute names.
    rounding = (FloatToIntImplType.STOCHASTIC_ROUND if round_mode == "stochastic"
                else FloatToIntImplType.ROUND)

    class WeightQuant(Int8WeightPerTensorFloat):
        quant_type = QuantType.INT  # integer quantization
        bit_width_impl_type = BitWidthImplType.CONST  # constant bit width
        scaling_impl_type = ScalingImplType.STATS  # scale based on statistics
        scaling_stats_op = StatsOp.MAX  # scale statistic is the absmax value
        restrict_scaling_type = RestrictValueType.FP  # scale factor is a floating point value
        scaling_per_output_channel = False  # scale is per tensor
        signed = True  # quantization range is signed
        narrow_range = True  # e.g. [-127, 127] rather than [-128, 127] at 8 bits
        zero_point_impl = ZeroZeroPoint
        bit_width = bits
        float_to_int_impl_type = rounding

    return WeightQuant


def _train_fresh_model(weight_quantizer, input_dim, width, criterion, trainloader, testloader,
                       epochs, baseline, announce=None, lr=0.001):
    """Construct a new QuantNet, train it with Adam, and return ``(model, train(...) result)``.

    ``announce``, if given, is printed after the model/optimizer are built and
    just before training starts.
    """
    model = QuantNet(input_dim=input_dim, width=width, weight_quantizer=weight_quantizer)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    if announce is not None:
        print(announce)
    history = train(model=model, optimizer=optimizer, criterion=criterion,
                    train_loader=trainloader, test_loader=testloader,
                    epochs=epochs, baseline=baseline)
    return model, history


def _plot_cka_heatmaps(reference, models, names, loader, layers, bits, vmin=0, vmax=1):
    """Plot one layer-by-layer CKA heat map per model in ``models`` versus ``reference``.

    The diagonal cells are annotated with their CKA value.
    """
    n_layers = len(layers)
    tick_labels = [str(i) for i in range(1, n_layers + 1)]
    # squeeze=False keeps axs 2-D even when only one model is compared.
    fig, axs = plt.subplots(1, len(models), figsize=(5 * len(models), 5), squeeze=False)

    for ax, other_model, name in zip(axs[0], models, names):
        cka = CKA(reference, other_model,
                  model1_name="FP32",
                  model2_name=name,
                  model1_layers=list(layers),
                  model2_layers=list(layers),
                  device='cpu')
        cka.compare(loader, loader)
        cka_matrix = cka.export()['CKA']

        ax.imshow(cka_matrix, origin='lower', cmap='magma', vmin=vmin, vmax=vmax)
        ax.set_xticks(range(n_layers))
        ax.set_yticks(range(n_layers))
        ax.set_xticklabels(tick_labels)
        ax.set_yticklabels(tick_labels)
        ax.set_xlabel(f"Layers {name}", fontsize=12)
        ax.set_ylabel("Layers FP32", fontsize=12)
        ax.set_title(f"FP32 vs {name}", fontsize=12)
        ax.grid(False)
        # Annotate the diagonal; pick a text color that contrasts with the colormap.
        for i in range(n_layers):
            value = cka_matrix[i, i]
            text_color = 'white' if value < (vmin + vmax) / 2 else 'black'
            ax.text(i, i, f'{value:.2f}', ha='center', va='center', color=text_color)

    fig.suptitle('Ternary' if bits == 2 else f'{bits}bit', fontsize=20)
    plt.show()


def run_experiment(BITS, ROUND, EPOCHS=10, BASELINE=False, input_dim=32*32*3, width=64):
    """Train QAT and FP32 models and plot their layer-wise CKA similarity to an FP32 baseline.

    Trains, in order, on the loaders from ``build_dataloaders()``:
      1. a QAT model, whose training history is returned;
      2. a second QAT model, split by ``transfer_weights`` into latent-weight and
         quantized-weight copies;
      3. an FP32 baseline, plus a post-training-quantized copy made via ``PTQ``;
      4. a second FP32 baseline.
    Then evaluates the models with ``test`` and plots one CKA heat map per compared
    model against the FP32 baseline.

    Args:
        BITS: Weight bit width (2 is labelled "Ternary" in the figure).
        ROUND: "stochastic" for stochastic rounding; any other value means
            round-to-nearest.
        EPOCHS: Training epochs per model.
        BASELINE: If True, the "QAT" models are built without a weight quantizer.
            Also forwarded as ``train(baseline=...)`` for every run.
        input_dim: Flattened input size given to every ``QuantNet``.
        width: Hidden width given to every ``QuantNet``.

    Returns:
        ``(model, train_losses, test_losses, l0_norms, l1_norms, l2_norms)``, where
        ``model`` is the first QAT model and the other items are ``train``'s outputs
        for it.
    """
    layers = ("fc1", "fc2", "fc3", "fc4", "fc5")
    quantizer = None if BASELINE else _make_weight_quantizer(BITS, ROUND)

    trainloader, testloader = build_dataloaders()
    criterion = nn.CrossEntropyLoss()
    # NOTE(review): the FP32 baselines below are also trained with baseline=BASELINE
    # (False by default). Confirm this is what `train` expects for unquantized models.
    train_kwargs = dict(criterion=criterion, trainloader=trainloader, testloader=testloader,
                        epochs=EPOCHS, baseline=BASELINE)

    # 1) QAT model whose training curves / norms are returned to the caller.
    model, history = _train_fresh_model(quantizer, input_dim, width, **train_kwargs)
    train_losses, test_losses, l0_norms, l1_norms, l2_norms = history

    # 2) Independent QAT run, split into latent-weight and quant-weight FP models.
    # NOTE(review): with BASELINE=True there is no quantizer, so quant_weight() in
    # transfer_weights may not be meaningful.
    quantized_model, _ = _train_fresh_model(quantizer, input_dim, width,
                                            announce="Training quantized", **train_kwargs)
    model_latent_weights, model_quant_weights = transfer_weights(quantized_model, input_dim, width)

    # 3) FP32 baseline and its post-training-quantized copy.
    baseline_model, _ = _train_fresh_model(None, input_dim, width,
                                           announce="Training baseline", **train_kwargs)
    baseline_model_PTQ = copy.deepcopy(baseline_model)
    PTQ(baseline_model_PTQ, BITS)
    # NOTE(review): baseline_model_PTQ is not evaluated or plotted below. Add it to
    # `models`/`names` if the PTQ comparison is wanted.

    # 4) Second FP32 baseline, to show run-to-run CKA variation.
    baseline_model2, _ = _train_fresh_model(None, input_dim, width,
                                            announce="Training baseline 2", **train_kwargs)

    models = [baseline_model2, model_latent_weights, model_quant_weights]
    names = ["Second Baseline", "QAT latent weights", "QAT Quant weights"]

    # NOTE(review): accuracies is computed (test() runs for every model) but is not
    # shown in the figure.
    accuracies = {
        name: f'{test(candidate, criterion, testloader):.2f}%'
        for candidate, name in zip([baseline_model] + models, ["FP32"] + names)
    }

    _plot_cka_heatmaps(baseline_model, models, names, testloader, layers, BITS)

    # `model` is the first QAT model, i.e. the one the returned histories belong to.
    # No loop above rebinds it.
    return model, train_losses, test_losses, l0_norms, l1_norms, l2_norms


In [None]:
# Ternary (2-bit) QAT with deterministic rounding: any ROUND value other than
# "stochastic" selects round-to-nearest.
# Binding the results keeps them reusable and stops the cell from echoing the
# whole (model, losses, norms) tuple as output.
qat_model, train_losses, test_losses, l0_norms, l1_norms, l2_norms = run_experiment(
    2, "ROUND", EPOCHS=50, BASELINE=False
)