In [None]:
from __future__ import print_function
import h5py
from numpy import *
from matplotlib.pyplot import *
import torch
import torch.utils.data
from torch import nn, optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np

import os
from torchinfo import summary

In [None]:
from hawq.utils.quantization_utils.quant_modules import QuantConv2d, QuantLinear, QuantAct

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

In [None]:
# Index window 500..1499 (inclusive).
# NOTE(review): presumably the slice of each record that is fed to the network — confirm.
csr = range(500, 1500)

# Number of indices in the window.
sr = len(csr)

# Twice the window length; Classifier uses sr * 2 as its input width and
# hn / 8 as its hidden width.
hn = 2 * sr


class Classifier(nn.Module):
    """Small fully connected network with two output values.

    Layout: Linear(sr * 2 -> hn / 8) -> ReLU -> BatchNorm1d -> Linear(-> 2) -> ReLU.
    The widths come from the module-level constants ``sr`` and ``hn``.
    """

    def __init__(self):
        super().__init__()
        hidden_width = int(hn / 8)

        # Hidden stage: projection, non-linearity, then normalisation.
        self.linear1 = nn.Linear(sr * 2, hidden_width)
        self.relu1 = nn.ReLU()
        self.bn = nn.BatchNorm1d(hidden_width, affine=True)

        # Output stage: two values per sample, clamped at zero by the ReLU.
        self.linear2 = nn.Linear(hidden_width, 2)
        self.relu2 = nn.ReLU()

    def forward(self, sig):
        """Map a batch ``sig`` with ``sr * 2`` features per row to two outputs per row."""
        hidden = self.bn(self.relu1(self.linear1(sig)))
        return self.relu2(self.linear2(hidden))

In [None]:
# Rebuild the float architecture, then restore its trained weights.
base_model = Classifier()
# map_location="cpu" lets a checkpoint that was saved from CUDA tensors be
# deserialised on a machine without a GPU; load_state_dict then copies the
# values into base_model's own parameters.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
base_model.load_state_dict(
    torch.load("checkpoints/checkpoint_tiny_affine.pth", map_location="cpu")
)

In [None]:
class Q_Classifier(nn.Module):
    """Quantized counterpart of ``Classifier`` assembled from HAWQ modules.

    The two linear layers of ``model`` are wrapped in ``QuantLinear``
    (6-bit weights, 8-bit biases) and the activations go through 12-bit
    ``QuantAct`` modules. The BatchNorm layer stays a regular float
    ``nn.BatchNorm1d``.

    Args:
        model: A float network exposing ``linear1`` and ``linear2``. If it
            also has a ``bn`` attribute, its parameters and running
            statistics are copied into this module's BatchNorm.
    """

    def __init__(self, model):
        super(Q_Classifier, self).__init__()

        self.quant_input = QuantAct(activation_bit=12)
        self.q_relu1 = QuantAct(activation_bit=12)
        self.q_relu2 = QuantAct(activation_bit=12)

        # Wrap each float linear layer; set_param receives the source layer
        # so the quantized layer is initialised from it.
        for layer_name in ('linear1', 'linear2'):
            hawq_layer = QuantLinear(weight_bit=6, bias_bit=8)
            hawq_layer.set_param(getattr(model, layer_name))
            setattr(self, layer_name, hawq_layer)

        # ReLU is stateless, so one instance serves both stages.
        self.relu = nn.ReLU()

        self.bn = nn.BatchNorm1d(int(hn / 8), affine=True)
        # Carry over the trained affine parameters and running statistics;
        # without this the BatchNorm would start from its default initialisation.
        trained_bn = getattr(model, 'bn', None)
        if trained_bn is not None:
            self.bn.load_state_dict(trained_bn.state_dict())

    def forward(self, sig):
        """Run the quantized forward pass and return the output tensor.

        ``p_sf`` carries the scaling factor returned by the most recent
        ``QuantAct``; ``w_sf`` carries the one returned by the most recent
        ``QuantLinear``.
        NOTE(review): assumes the installed HAWQ build returns
        ``(tensor, scaling_factor)`` from both module types, as the
        ``quant_input`` and ``linear1`` calls below expect — confirm.
        """
        x, p_sf = self.quant_input(sig)

        x, w_sf = self.linear1(x, p_sf)
        x = self.relu(x)
        x, p_sf = self.q_relu1(x, p_sf, w_sf)

        # NOTE(review): BatchNorm runs in float here, so after it p_sf no
        # longer describes x exactly — confirm whether a re-quantization
        # step is wanted before linear2.
        x = self.bn(x)

        # Mirror the first stage: pass the activation scale in, unpack the
        # layer's own scale, and hand both to the following QuantAct.
        x, w_sf = self.linear2(x, p_sf)
        x = self.relu(x)
        x, p_sf = self.q_relu2(x, p_sf, w_sf)
        return x

In [None]:
# Build the quantized network from the float model whose checkpoint was loaded above.
model = Q_Classifier(base_model)