# Quantization

In [None]:
import torch
from torch.autograd.function import InplaceFunction, Function
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
import math
import numpy as np
import torch.nn.init as init
import torchvision
from tqdm import tqdm
import json
from pathlib import Path

In [None]:
def quantize_model(model, quantize = False, bits = 8, qmode = "dynamic"):
    """Switch every quantization-aware layer of ``model`` on or off.

    A layer counts as quantization-aware when it exposes a callable
    ``mode()`` (its current on/off state) and a callable ``change_mod()``.
    Every other module is skipped, including modules whose ``mode``
    attribute is plain data rather than a method.

    Args:
        model: Object providing ``modules()``.
        quantize: ``True`` enables quantization on layers that are currently
            off; ``False`` disables it on layers that are currently on.
        bits: Bit width passed to ``change_mod`` when enabling.
        qmode: Quantizer mode passed to ``change_mod`` when enabling.

    Returns:
        The same ``model`` object, modified in place.
    """
    print("Quantize mode on" if quantize else "Quantize mode off")
    for layer in model.modules():
        # Explicit capability check instead of a bare ``except``: errors
        # raised inside ``change_mod`` (and KeyboardInterrupt) now surface
        # instead of being silently discarded.
        if not callable(getattr(layer, "mode", None)):
            continue
        if not callable(getattr(layer, "change_mod", None)):
            continue
        enabled = layer.mode()
        if quantize and not enabled:
            layer.change_mod(True, bits, qmode)
        elif not quantize and enabled:
            layer.change_mod(False, 0)
    return model

In [None]:
def qsin_activation_mode(model):
    """Enable the QSIN activation loss on every layer that supports it.

    A layer is selected when it exposes both a callable ``mode()`` and a
    callable ``qsinmode()``; all other modules are left untouched.

    Args:
        model: Object providing ``modules()``.

    Returns:
        The same ``model`` object, modified in place.
    """
    print("QSIN activation mode on")
    for layer in model.modules():
        # Capability check instead of a bare ``except`` so that failures
        # inside ``qsinmode`` are not silently swallowed.
        supports_mode = callable(getattr(layer, "mode", None))
        supports_qsin = callable(getattr(layer, "qsinmode", None))
        if supports_mode and supports_qsin:
            layer.qsinmode()
    return model

In [None]:
class MyRound(torch.autograd.Function):
    """Element-wise rounding whose backward pass is the identity.

    The forward pass rounds like ``torch.round``; the backward pass hands the
    incoming gradient through unchanged, so rounding does not block gradient
    flow.
    """

    @staticmethod
    def forward(ctx, input):
        # Nothing is stored on ``ctx``: backward does not need the input.
        rounded = input.round()
        return rounded

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient of the rounding is treated as 1 everywhere.
        return grad_output

In [None]:
def Quantize_tensor(input_tensor, max_abs_val = None, num_bits = 8):
    """Quantize ``input_tensor`` to a signed ``num_bits`` grid and scale it back.

    The tensor is divided by ``scale``, rounded with ``MyRound`` (identity
    gradient), clamped to ``[qmin, qmax]`` and multiplied by ``scale`` again,
    where ``qmin = -2**num_bits / 2``, ``qmax = -qmin - 1`` and
    ``scale = max_abs_val / ((qmax - qmin) / 2)``.

    Args:
        input_tensor: Tensor to quantize.
        max_abs_val: Clipping threshold (number or tensor). When ``None`` the
            max absolute value of ``input_tensor`` is used; previously the
            default raised a ``TypeError``.
        num_bits: Bit width of the integer grid.

    Returns:
        Tensor with the shape of ``input_tensor`` whose values lie on the
        quantization grid. When the threshold is zero the result is all
        zeros instead of NaN.
    """
    if max_abs_val is None:
        max_abs_val = torch.max(torch.abs(input_tensor.detach()))
    my_round = MyRound.apply
    qmin = -1.0 * (2**num_bits) / 2
    qmax = -qmin - 1
    scale = max_abs_val / ((qmax - qmin) / 2)
    # Guard the division only: a zero threshold would otherwise give 0/0 = NaN.
    # Non-zero scales go through unchanged, so gradients w.r.t. a trainable
    # threshold are the same as before.
    if torch.is_tensor(scale):
        divisor = torch.where(scale == 0, torch.ones_like(scale), scale)
    else:
        divisor = scale if scale != 0 else 1.0
    levels = my_round(torch.div(input_tensor, divisor))
    levels = torch.clamp(levels, qmin, qmax)
    return torch.mul(levels, scale)

In [None]:
class Quant(nn.Module):
    """Quantization module with a per-tensor clipping threshold.

    Two modes are supported:

    * ``"dynamic"`` - the threshold is the max absolute value of the current
      input, recomputed on every call.
    * ``"static"`` - the first ``static_count`` calls return the input
      unchanged while recording its max absolute value; the next call
      averages the recorded values into ``max_abs_tr``, and from then on
      every input is quantized with that threshold.

    Args:
        num_bits: Bit width forwarded to ``Quantize_tensor``.
        mode: ``"dynamic"`` or ``"static"``.
        static_count: Number of calibration calls in static mode.
    """

    def __init__(self, num_bits=8, mode = "dynamic", static_count = 30):
        super(Quant, self).__init__()
        self.num_bits = num_bits
        self.mode = mode
        self.static_count = static_count
        # Static-mode call counter; stops at static_count + 1.
        self.static_cur = 0
        # Per-call max absolute values recorded during calibration.
        self.stat_values = []
        self.max_abs = 0
        if mode != "dynamic":
            # IMPORTANT: the threshold is a Parameter with requires_grad=True.
            # It is created empty and only receives a value once calibration
            # has finished (see forward).
            self.max_abs_tr = nn.Parameter(torch.zeros(0), requires_grad=True)

    def forward(self, input):
        """Quantize ``input`` according to the configured mode.

        Raises:
            ValueError: If ``mode`` is neither ``"dynamic"`` nor ``"static"``
                (previously this silently returned ``None``).
        """
        if self.mode == "dynamic":
            self.max_abs = torch.max(torch.abs(input.detach()))
            return Quantize_tensor(input, self.max_abs, self.num_bits)

        if self.mode == "static":
            if self.static_cur < self.static_count:
                # Calibration phase: record the range, pass the data through.
                # The max is reduced where the tensor lives and only the
                # scalar is moved to the host.
                self.static_cur += 1
                self.stat_values.append(input.detach().abs().max().item())
                return input
            if self.static_cur == self.static_count:
                # First call after calibration: freeze the averaged threshold.
                self.max_abs = np.mean(self.stat_values)
                self.max_abs_tr.data = torch.tensor(self.max_abs, dtype=torch.float).to(self.max_abs_tr.device)
                self.static_cur += 1
            return Quantize_tensor(input, self.max_abs_tr, self.num_bits)

        raise ValueError(f"Unsupported quantization mode: {self.mode!r}")

In [None]:
def QSin(x, num_bits = 8):
    """Return the QSIN penalty of ``x`` for a signed ``num_bits`` grid.

    With ``qmin = -2**num_bits / 2`` and ``qmax = -qmin - 1`` the result is
    the sum of three terms:

    * ``sin(pi * x)**2`` over elements inside ``[qmin, qmax]``,
    * ``pi**2 * (x - qmin)**2`` over elements below ``qmin``,
    * ``pi**2 * (x - qmax)**2`` over elements above ``qmax``.
    """
    pi = torch.tensor(np.pi)
    pi_squared = torch.square(pi)
    qmin = -1.0 * (2**num_bits) / 2
    qmax = -qmin - 1

    inside = torch.logical_and(x >= qmin, x <= qmax)
    below = x < qmin
    above = x > qmax

    in_range_term = torch.square(torch.sin(torch.mul(pi, x[inside]))).sum()
    below_term = torch.mul(pi_squared, torch.square(x[below] - qmin)).sum()
    above_term = torch.mul(pi_squared, torch.square(x[above] - qmax)).sum()
    return in_range_term + below_term + above_term

In [None]:
class Linear(nn.Linear):
    """``nn.Linear`` with optional quantization of weights and inputs.

    When quantization is on, the input goes through ``Quantize_input`` (always
    8 bits) and the weight through ``Quantize_weights`` (``q_bits`` bits)
    before the linear map. When the QSIN activation mode is on, each quantized
    forward pass also stores a QSIN penalty of the raw input in
    ``qsin_loss_A``.

    Args:
        in_features, out_features, bias: As for ``nn.Linear``.
        quantization: Start with quantization enabled.
        q_bits: Bit width of the weight quantizer.
        qsin_activation: Start with the QSIN activation loss enabled.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, quantization: bool = False, q_bits: int = 8, qsin_activation = False):
        super(Linear, self).__init__(in_features, out_features, bias)

        self.quantize = bool(quantization)
        self.QsinA = bool(qsin_activation)

        if self.quantize:
            self.bits = q_bits
            self.Quantize_weights = Quant(self.bits)
            self.Quantize_input = Quant(8)
        else:
            self.bits = 'FP'

        if self.QsinA:
            self.qsin_loss_A = 0

    def init(self, input):
        """Remember the shape of ``input`` in ``inputW``."""
        self.inputW = input.shape

    def change_mod(self, value, bits = 8, mode = "dynamic"):
        """Set the quantization flag and rebuild both quantizers.

        The quantizers are recreated from scratch, so any previous
        calibration or trained threshold is discarded.
        """
        self.quantize = value
        self.bits = bits
        self.Quantize_weights = Quant(bits, mode)
        self.Quantize_input = Quant(8, mode)

    def qsinmode(self):
        """Enable the QSIN activation loss and reset the stored value."""
        self.QsinA = True
        self.qsin_loss_A = 0

    def mode(self):
        """Return whether quantization is currently enabled."""
        return self.quantize

    def _activation_qsin_loss(self, input):
        """Compute the QSIN penalty of ``input`` on the activation grid."""
        quantizer = self.Quantize_input
        act_bits = quantizer.num_bits
        threshold = getattr(quantizer, "max_abs_tr", None)
        if threshold is None:
            # Dynamic quantizers have no trainable threshold; use the range
            # measured on this very input (previously an AttributeError).
            threshold = quantizer.max_abs
        elif threshold.numel() == 0:
            # Static quantizer still calibrating: no threshold yet, hence no
            # penalty (previously a broadcasting error).
            return 0
        qmin = -1.0 * (2**act_bits) / 2
        qmax = -qmin - 1
        scale = threshold / ((qmax - qmin) / 2)
        sq_scale = torch.square(scale)
        return torch.mul(sq_scale, QSin(torch.div(input, scale), act_bits))

    def forward(self, input):
        if not self.quantize:
            return nn.functional.linear(input, self.weight, self.bias)

        qinput = self.Quantize_input(input)
        qweight = self.Quantize_weights(self.weight)

        # Count the QSIN loss on the activation.
        if self.QsinA:
            self.qsin_loss_A = self._activation_qsin_loss(input)

        return nn.functional.linear(qinput, qweight, self.bias)

In [None]:
class Embedding(nn.Embedding):
    """``nn.Embedding`` with optional quantization of the embedding table.

    Args:
        num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
        scale_grad_by_freq, sparse: As for ``nn.Embedding``.
        quantization: Start with quantization enabled.
        q_bits: Bit width of the weight quantizer.
    """

    def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
                 norm_type=2.0, scale_grad_by_freq=False, sparse=False,
                 quantization: bool = False, q_bits: int = 8):
        # Forward every lookup option: previously max_norm, norm_type,
        # scale_grad_by_freq and sparse were accepted but silently dropped.
        super(Embedding, self).__init__(
            num_embeddings,
            embedding_dim,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
        )

        self.quantize = bool(quantization)

        if self.quantize:
            self.bits = q_bits
            self.Quantize_weights = Quant(self.bits)
        else:
            self.bits = 'FP'

    def init(self, input):
        """Remember the shape of ``input`` in ``inputW``."""
        self.inputW = input.shape

    def change_mod(self, value, bits = 8, mode = "dynamic"):
        """Set the quantization flag and rebuild the weight quantizer."""
        self.quantize = value
        self.bits = bits
        self.Quantize_weights = Quant(bits, mode)

    def mode(self):
        """Return whether quantization is currently enabled."""
        return self.quantize

    def forward(self, input):
        # Only the table differs between the two modes; the lookup is shared.
        weight = self.Quantize_weights(self.weight) if self.quantize else self.weight
        return nn.functional.embedding(input, weight, self.padding_idx, self.max_norm,
                 self.norm_type, self.scale_grad_by_freq, self.sparse)
        

In [None]:
def Qsin_W(model, bits = 8):
    """Sum the QSIN weight penalty over all layers with a trained threshold.

    A layer contributes when it has a ``Quantize_weights`` quantizer holding a
    non-empty ``max_abs_tr`` threshold and a ``weight``. Layers without a
    threshold (e.g. dynamic quantizers) or whose threshold is still empty are
    skipped.

    Args:
        model: Object providing ``modules()``.
        bits: Bit width of the grid used for every contributing layer.

    Returns:
        The accumulated penalty, or the integer ``0`` when no layer
        contributes.
    """
    qmin = -1.0 * (2**bits) / 2
    qmax = -qmin - 1
    loss = 0
    for layer in model.modules():
        # Explicit checks instead of a bare ``except`` so that genuine errors
        # (device or shape mismatches, ...) are no longer hidden as a zero loss.
        quantizer = getattr(layer, "Quantize_weights", None)
        threshold = getattr(quantizer, "max_abs_tr", None)
        weight = getattr(layer, "weight", None)
        if threshold is None or weight is None:
            continue
        if threshold.numel() == 0:
            # Static quantizer that has not finished calibrating.
            continue
        scale = threshold / ((qmax - qmin) / 2)
        sq_scale = torch.square(scale)
        loss = loss + torch.mul(sq_scale, QSin(torch.div(weight, scale), bits))
    return loss

In [None]:
def Qsin_A(model):
    """Sum the ``qsin_loss_A`` values stored on the modules of ``model``.

    Modules without a ``qsin_loss_A`` attribute (or with it set to ``None``)
    are skipped.

    Args:
        model: Object providing ``modules()``.

    Returns:
        The accumulated loss, or the integer ``0`` when nothing contributes.
    """
    loss = 0
    for layer in model.modules():
        # getattr with a default replaces the former bare ``except``, which
        # also hid unrelated failures.
        term = getattr(layer, "qsin_loss_A", None)
        if term is not None:
            loss = loss + term
    return loss

In [None]:
def get_custom_Linear(in_features, out_features, bias, weight):
    """Create a custom ``Linear`` that reuses existing parameters.

    The layer is built with the given dimensions and its ``weight`` and
    ``bias`` attributes are then rebound to the objects passed in, so the
    result uses exactly those parameters.
    """
    replacement = Linear(in_features, out_features)
    replacement.bias = bias
    replacement.weight = weight
    return replacement

def get_custom_Embeding(num_embeddings, embedding_dim, padding_idx, weight):
    """Create a custom ``Embedding`` that reuses an existing weight table.

    The layer is built with the given sizes and padding index, then its
    ``weight`` attribute is rebound to the object passed in.
    """
    replacement = Embedding(num_embeddings, embedding_dim, padding_idx)
    replacement.weight = weight
    return replacement

def change_layers(model):
    """Recursively swap ``nn.Linear`` / ``nn.Embedding`` children for the custom layers.

    The replacement happens in place and reuses the original parameters.
    Children that are already custom ``Linear`` / ``Embedding`` instances are
    left alone, so calling this twice on the same model no longer resets
    their quantization state.

    Args:
        model: Module whose direct and nested children are converted.
    """
    for name, layer in model.named_children():
        if isinstance(layer, (Linear, Embedding)):
            # Already converted (the custom classes subclass the nn ones).
            continue

        if isinstance(layer, nn.Linear):
            setattr(model, name, get_custom_Linear(
                layer.in_features,
                layer.out_features,
                layer.bias,
                layer.weight,
            ))
        elif isinstance(layer, nn.Embedding):
            replacement = get_custom_Embeding(
                layer.num_embeddings,
                layer.embedding_dim,
                layer.padding_idx,
                layer.weight,
            )
            # Carry over the lookup options of the original layer, which the
            # factory signature does not cover.
            replacement.max_norm = layer.max_norm
            replacement.norm_type = layer.norm_type
            replacement.scale_grad_by_freq = layer.scale_grad_by_freq
            replacement.sparse = layer.sparse
            setattr(model, name, replacement)
        else:
            # Only containers need recursion; freshly built replacements have
            # no children to convert.
            change_layers(layer)

In [None]:
# Layer groups of interest (matched by substring of the module name):
#pooler
#classifier
#attention

In [None]:
def quantize_model(model, quantize = False, bits = 8, qmode = "dynamic"):
    """Switch quantization on or off with a fixed per-layer bit assignment.

    When enabling, the bit width is chosen from the module name:

    * names containing ``pooler``, ``attention`` or ``token_type_embeddings``
      get 4 bits,
    * names containing ``classifier`` are left unquantized,
    * every other quantization-aware layer gets 8 bits.

    A layer counts as quantization-aware when it exposes a callable
    ``mode()`` and a callable ``change_mod()``.

    NOTE(review): ``bits`` is accepted for signature compatibility but is not
    used; the 4/8 split above is hard-coded. Confirm this is intended.

    Args:
        model: Object providing ``named_modules()`` and ``modules()``.
        quantize: ``True`` to enable on layers that are off, ``False`` to
            disable on layers that are on.
        bits: Unused (see note above).
        qmode: Quantizer mode passed to ``change_mod`` when enabling.

    Returns:
        The same ``model`` object, modified in place.
    """
    def is_quantizable(layer):
        # Capability check replacing the former bare ``except``, so errors
        # raised inside ``change_mod`` are no longer swallowed.
        return callable(getattr(layer, "mode", None)) and callable(getattr(layer, "change_mod", None))

    low_precision_markers = ('pooler', 'attention', 'token_type_embeddings')

    if quantize:
        print("Quantize mode on")
        for name, layer in model.named_modules():
            if not is_quantizable(layer) or layer.mode():
                continue
            if any(marker in name for marker in low_precision_markers):
                layer.change_mod(True, 4, qmode)
            elif 'classifier' in name:
                continue
            else:
                layer.change_mod(True, 8, qmode)
    else:
        print("Quantize mode off")
        for layer in model.modules():
            if is_quantizable(layer) and layer.mode():
                layer.change_mod(False, 0)
    return model

In [None]:
#pooler
#classifier
#attention

        #if name == 'intermediate' or \
        #name == 'output'or name == 'embeddings':
        #    continue

# SQUAD metric

In [None]:
""" SQuAD metric. """

import datasets

from evaluate import evaluate


# Metadata strings handed to the metric's MetricInfo and docstring decorator;
# all three are left empty here.
_CITATION = ""

_DESCRIPTION = ""

_KWARGS_DESCRIPTION = ""

# Metric wrapper that reshapes predictions/references and delegates the
# scoring to the imported `evaluate` function.
@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Squad(datasets.Metric):
    def _info(self):
        """Describe the schema of the predictions and references."""
        prediction_schema = {"id": datasets.Value("string"), "prediction_text": datasets.Value("string")}
        answers_schema = datasets.features.Sequence(
            {
                "text": datasets.Value("string"),
                "answer_start": datasets.Value("int32"),
            }
        )
        reference_schema = {
            "id": datasets.Value("string"),
            "answers": answers_schema,
        }
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": prediction_schema,
                    "references": reference_schema,
                }
            ),
            codebase_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
            reference_urls=["https://rajpurkar.github.io/SQuAD-explorer/"],
        )

    def _compute(self, predictions, references):
        """Score ``predictions`` against ``references`` with ``evaluate``.

        Predictions become an ``id -> prediction_text`` mapping and the
        references are wrapped into a single article / single paragraph
        structure holding one question entry per reference.
        """
        text_by_id = {}
        for prediction in predictions:
            text_by_id[prediction["id"]] = prediction["prediction_text"]

        qas = []
        for ref in references:
            answers = [{"text": answer_text} for answer_text in ref["answers"]["text"]]
            qas.append({"answers": answers, "id": ref["id"]})

        dataset = [{"paragraphs": [{"qas": qas}]}]
        return evaluate(dataset=dataset, predictions=text_by_id)

In [None]:
from tqdm.auto import tqdm
import collections

def eval_squad(trainer, validation_features, datasets, metric):
    """Predict on ``validation_features``, extract answer texts and score them.

    Args:
        trainer: Object whose ``predict`` returns a result with a
            ``predictions`` attribute holding ``(start_logits, end_logits)``.
        validation_features: Tokenized validation features; each row provides
            ``example_id``, ``offset_mapping`` and ``input_ids``.
        datasets: Mapping whose ``"validation"`` entry holds the raw examples
            (``id``, ``context``, ``answers``).
        metric: Object whose ``compute(predictions=..., references=...)``
            returns the scores.

    Returns:
        The value returned by ``metric.compute`` (it is also printed).

    Note:
        The CLS token id is read from the notebook-level ``tokenizer`` global.
    """
    raw_predictions = trainer.predict(validation_features)
    # Make every column visible again before post-processing, which needs
    # example_id and offset_mapping (presumably hidden by predict - confirm).
    validation_features.set_format(type=validation_features.format["type"], columns=list(validation_features.features.keys()))

    def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size = 20, max_answer_length = 30):
        """Map start/end logits to one answer string per example id."""
        all_start_logits, all_end_logits = raw_predictions
        # Build a map from each example to its corresponding features.
        example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
        features_per_example = collections.defaultdict(list)
        for i, feature in enumerate(features):
            features_per_example[example_id_to_index[feature["example_id"]]].append(i)

        # The dictionary we have to fill.
        predictions = collections.OrderedDict()

        # Logging.
        print(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

        for example_index, example in enumerate(tqdm(examples)):
            # Indices of the features associated with the current example.
            feature_indices = features_per_example[example_index]

            # NOTE(review): despite its name this keeps the LARGEST null score
            # seen over the example's features.
            min_null_score = None
            valid_answers = []

            context = example["context"]
            for feature_index in feature_indices:
                # Model predictions for this feature.
                start_logits = all_start_logits[feature_index]
                end_logits = all_end_logits[feature_index]
                # Fetch the feature row once instead of once per field.
                feature = features[feature_index]
                # Maps token positions back to character spans of the context.
                offset_mapping = feature["offset_mapping"]

                # Update the null (CLS) prediction score.
                cls_index = feature["input_ids"].index(tokenizer.cls_token_id)
                feature_null_score = start_logits[cls_index] + end_logits[cls_index]
                if min_null_score is None or min_null_score < feature_null_score:
                    min_null_score = feature_null_score

                # Candidate positions: the `n_best_size` largest start and end logits.
                start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
                end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
                for start_index in start_indexes:
                    for end_index in end_indexes:
                        # Skip out-of-scope answers: indices out of bounds or
                        # pointing at input_ids that are not part of the context.
                        if (
                            start_index >= len(offset_mapping)
                            or end_index >= len(offset_mapping)
                            or offset_mapping[start_index] is None
                            or offset_mapping[end_index] is None
                        ):
                            continue
                        # Skip answers with a length that is < 0 or > max_answer_length.
                        if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                            continue

                        start_char = offset_mapping[start_index][0]
                        end_char = offset_mapping[end_index][1]
                        valid_answers.append(
                            {
                                "score": start_logits[start_index] + end_logits[end_index],
                                "text": context[start_char: end_char]
                            }
                        )

            if valid_answers:
                # max() returns the first best-scoring candidate, the same one
                # a stable descending sort would put first.
                best_answer = max(valid_answers, key=lambda x: x["score"])
            else:
                # Rare edge case: not a single non-null prediction, so create
                # a fake prediction to avoid failure.
                best_answer = {"text": "", "score": 0.0}

            # NOTE(review): the null-score threshold is applied unconditionally.
            # An example without any feature has no null score to compare
            # against; its (empty) best answer is kept instead of raising.
            keep_answer = min_null_score is None or best_answer["score"] > min_null_score
            predictions[example["id"]] = best_answer["text"] if keep_answer else ""

        return predictions

    final_predictions = postprocess_qa_predictions(datasets["validation"], validation_features, raw_predictions.predictions)
    formatted_predictions = [{"id": k, "prediction_text": v} for k, v in final_predictions.items()]
    references = [{"id": ex["id"], "answers": ex["answers"]} for ex in datasets["validation"]]
    scores = metric.compute(predictions=formatted_predictions, references=references)
    print(scores)
    return scores

# BERT

In [None]:
from transformers import AutoModelForQuestionAnswering
# Load the question-answering model from a local directory (local_files_only=True).
model = AutoModelForQuestionAnswering.from_pretrained('./models/squad/model-bert-base/', local_files_only=True)

In [None]:
from transformers import DistilBertTokenizer

# NOTE(review): a DistilBERT tokenizer class is loaded from a directory named
# "tokenizer-bert-base" - confirm this pairing is intended.
tokenizer = DistilBertTokenizer.from_pretrained('./models/tokenizer-bert-base/', local_files_only=True)

# NOTE(review): these three values are not read anywhere in the visible part
# of this notebook - presumably the settings used when the data was tokenized.
batch_size = 16
max_length = 384
doc_stride = 128

In [None]:
from datasets import load_from_disk
# Dataset saved on disk; later cells read its "train" and "validation" splits.
encoded_dataset = load_from_disk('cur_squad_data')

In [None]:
# Swap the model's nn.Linear / nn.Embedding layers for the custom quantizable ones (in place).
change_layers(model)

In [None]:
from transformers import TrainingArguments, Trainer
# Fine-tuning configuration: output directory "test-squad", evaluation once
# per epoch, 3 epochs, batch size 32 for both training and evaluation.
args = TrainingArguments(
    f"test-squad",
    evaluation_strategy = "epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    weight_decay=0.01,
)

In [None]:
from transformers import default_data_collator

# Collator reused by the Trainer instances built in later cells.
data_collator = default_data_collator

In [None]:
# Trainer for the full-precision fine-tuning run.
trainer = Trainer(
    model,
    args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [None]:
# Run the fine-tuning configured above.
trainer.train()

In [None]:
# Inputs for eval_squad: tokenized validation features, the metric, and the
# raw dataset.
# NOTE(review): the chained assignment also rebinds `encoded_dataset` to the
# validation features, and `datasets` now shadows the imported `datasets`
# module; cells that need the train/validation splits must reload
# 'cur_squad_data' first.
validation_features = encoded_dataset = load_from_disk('cur_squad_val_data')
metric = Squad()
datasets = load_from_disk("cur_squad_set_data/")

In [None]:
# Score the model held by `trainer` on the validation set.
eval_squad(trainer, validation_features, datasets, metric)

In [None]:
# Release cached CUDA memory before the next model is loaded.
torch.cuda.empty_cache()

In [None]:
from transformers import AutoModelForQuestionAnswering
# Reload the model from the saved checkpoint directory (local files only).
model = AutoModelForQuestionAnswering.from_pretrained('./test-squad/checkpoint-8000/', local_files_only=True)

In [None]:
from transformers import TrainingArguments, Trainer
# Reload the train/validation splits: `encoded_dataset` was rebound to the
# validation features by the earlier evaluation-setup cell.
encoded_dataset = load_from_disk('cur_squad_data')
args = TrainingArguments(
    f"test-squad",
    evaluation_strategy = "epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    weight_decay=0.01,
)
# Trainer wrapping the reloaded checkpoint model.
trainer = Trainer(
    model,
    args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [None]:
# Convert the reloaded model's layers in place; `trainer` holds the same model object.
change_layers(model)

In [None]:
# Score the converted model while quantization is still off.
eval_squad(trainer, validation_features, datasets, metric)

In [None]:
# Switch quantization off, then on again with the default "dynamic" mode.
# NOTE(review): the quantize_model definition in effect assigns 4 or 8 bits
# by layer name and does not read the `bits` argument.
model = quantize_model(model, quantize=False, bits = 4)
model = quantize_model(model, quantize=True, bits = 4)

In [None]:
# Evaluation-only Trainer (no train_dataset) for the dynamically quantized model.
encoded_dataset = load_from_disk('cur_squad_data')
trainer_DQ = Trainer(
    model,
    args,
    eval_dataset=encoded_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [None]:
# Score the dynamically quantized model.
eval_squad(trainer_DQ, validation_features, datasets, metric)

In [None]:
# Switch quantization off, then on again in "static" mode; this rebuilds the
# quantizers, which then need calibration forward passes.
model = quantize_model(model, quantize=False, bits = 4)
model = quantize_model(model, quantize=True, bits = 4, qmode = "static")

In [None]:
# Training split used to calibrate the static quantizers.
encoded_dataset = load_from_disk('cur_squad_data')
train_enc = encoded_dataset['train']

In [None]:
from torch.utils.data import DataLoader
# Expose the listed columns as torch tensors and batch them (8 per batch, shuffled).
train_enc.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'])
train_loader = torch.utils.data.DataLoader(train_enc, batch_size=8, shuffle = True)

In [None]:
# Calibrate the static quantizers by running forward passes over training
# batches. Each Quant module records ranges for its first 30 calls (the
# default static_count) and fixes its threshold on the 31st, so 32 batches
# cover calibration plus one fully quantized pass.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
model.eval()
i = 0
# Only forward passes are needed: disable autograd so no graphs are built
# and kept alive for the calibration batches.
with torch.no_grad():
    for batch in tqdm(train_loader):
        i += 1
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)
        # NOTE(review): token_type_ids is loaded by the DataLoader but not
        # passed to the model here - confirm whether calibration should
        # include it.
        outputs = model(input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions)
        if i == 32:
            break

In [None]:
# Evaluation-only Trainer (no train_dataset) for the statically quantized model.
encoded_dataset = load_from_disk('cur_squad_data')
trainer_SQ = Trainer(
    model,
    args,
    eval_dataset=encoded_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [None]:
# Score the statically quantized model.
eval_squad(trainer_SQ, validation_features, datasets, metric)

# QSIN

In [None]:
# Turn on the QSIN activation loss on every layer that supports it.
qsin_activation_mode(model)
# Trailing print() presumably keeps the notebook from echoing the returned model.
print()

In [None]:
# Release cached CUDA memory before QSIN training.
torch.cuda.empty_cache()

In [None]:
import torch
from transformers import Trainer

class QSinTrainer(Trainer):
    """Trainer whose loss adds the QSIN weight and activation penalties.

    Each penalty is divided by a power of ten chosen so that it lands on the
    order of magnitude of the task loss (the activation penalty one further
    decade below).
    """

    @staticmethod
    def _rescaled(penalty, task_loss, extra_decades=0):
        """Return ``penalty`` divided by the balancing power of ten."""
        # float() reads the scalar on whatever device it lives on; the former
        # ``.cuda().tolist()`` required a CUDA device and a tensor.
        penalty_value = float(penalty)
        if penalty_value == 0:
            # No contributing layer: log10(0) would turn the loss into NaN.
            return 0
        decades = np.round(np.log10(penalty_value) - np.log10(float(task_loss))) + extra_decades
        return penalty / (10 ** decades)

    def compute_loss(self, model, inputs, return_outputs=False):
        """Compute task loss plus the rescaled QSIN penalties.

        Args:
            model: Model called as ``model(**inputs)``; its first output is
                taken as the task loss.
            inputs: Keyword arguments for the model.
            return_outputs: When true, return ``(loss, outputs)``.

        Returns:
            The combined loss, optionally together with the model outputs.
        """
        outputs = model(**inputs)
        # The weight penalty is evaluated on a 4-bit grid for every layer.
        Qsin_W_loss = Qsin_W(model, 4)
        Qsin_A_loss = Qsin_A(model)
        L = outputs[0]
        loss = L + self._rescaled(Qsin_W_loss, L) + self._rescaled(Qsin_A_loss, L, extra_decades=1)
        return (loss, outputs) if return_outputs else loss

In [None]:
# Inputs for eval_squad.
# NOTE(review): the chained assignment also rebinds `encoded_dataset` to the
# validation features.
validation_features = encoded_dataset = load_from_disk('cur_squad_val_data')
metric = Squad()
datasets = load_from_disk("cur_squad_set_data/")

In [None]:
from transformers import TrainingArguments, Trainer
# Reload the train/validation splits, since `encoded_dataset` was rebound
# by the evaluation-setup cell.
encoded_dataset = load_from_disk('cur_squad_data')
# QSIN training configuration: output directory "qsin_train_tmp", evaluation
# once per epoch, 3 epochs, training batch size 12.
args = TrainingArguments(
    "qsin_train_tmp",
    evaluation_strategy = "epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=12,
    per_device_eval_batch_size=32,
    #eval_steps=10,
    num_train_epochs=3,
    weight_decay=0.01
)

# Trainer using the QSIN-regularized loss; no data_collator is passed here.
trainer_QSin = QSinTrainer(
    model,
    args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset['validation'],
    tokenizer=tokenizer
)

In [None]:
from transformers import TrainerCallback

class MyCallback(TrainerCallback):
    "A callback that runs the SQuAD evaluation (eval_squad) on every evaluate event"

    def on_evaluate(self, args, state, control, model, **kwargs):
        print("Starting evaluate:")
        # A fresh QSinTrainer is built around the current model just to run
        # predictions; it reads the notebook-level encoded_dataset, tokenizer,
        # validation_features, datasets and metric globals.
        cur_tr = QSinTrainer(model,args,train_dataset=encoded_dataset["train"],
                             eval_dataset=encoded_dataset['validation'],tokenizer=tokenizer)
        eval_squad(cur_tr, validation_features, datasets, metric)
        



In [None]:
# Register the callback; the class itself is passed rather than an instance.
trainer_QSin.add_callback(MyCallback)

In [None]:
# Train with the QSIN-regularized loss.
trainer_QSin.train()

In [None]:
# Score the model after QSIN training.
eval_squad(trainer_QSin, validation_features, datasets, metric)

# QAT

In [None]:
from transformers import TrainingArguments, Trainer
# QAT configuration: output directory "qat_train_tmp", evaluation once per
# epoch, 3 epochs, training batch size 12.
args = TrainingArguments(
    "qat_train_tmp",
    evaluation_strategy = "epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=12,
    per_device_eval_batch_size=32,
    #eval_steps=10,
    num_train_epochs=3,
    weight_decay=0.01
)

# Plain Trainer (task loss only).
# NOTE(review): `model` is whatever the preceding cells left behind - confirm
# it is the intended starting point for this run.
trainer_qat = Trainer(
    model,
    args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset['validation'],
    tokenizer=tokenizer
)

In [None]:
# Release cached CUDA memory before QAT training.
torch.cuda.empty_cache()

In [None]:
from transformers import TrainerCallback

class MyCallback(TrainerCallback):
    "A callback that runs the SQuAD evaluation (eval_squad) on every evaluate event"

    def on_evaluate(self, args, state, control, model, **kwargs):
        print("Starting evaluate:")
        # A fresh plain Trainer is built around the current model just to run
        # predictions; it reads the notebook-level encoded_dataset, tokenizer,
        # validation_features, datasets and metric globals.
        cur_tr = Trainer(model,args,train_dataset=encoded_dataset["train"],
                             eval_dataset=encoded_dataset['validation'],tokenizer=tokenizer)
        eval_squad(cur_tr, validation_features, datasets, metric)
        
# Register the callback; the class itself is passed rather than an instance.
trainer_qat.add_callback(MyCallback)


In [None]:
# Train with quantization active in the forward pass (task loss only).
trainer_qat.train()

In [None]:
# Reload the evaluation inputs and score the model after QAT.
# NOTE(review): the chained assignment also rebinds `encoded_dataset` to the
# validation features.
validation_features = encoded_dataset = load_from_disk('cur_squad_val_data')
metric = Squad()
datasets = load_from_disk("cur_squad_set_data/")
eval_squad(trainer_qat, validation_features, datasets, metric)