# Quantization

In [1]:
import torch
from torch.autograd.function import InplaceFunction, Function
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable
import math
import numpy as np
import torch.nn.init as init
import torchvision
from tqdm import tqdm
import json
from pathlib import Path

In [2]:
def quantize_model(model, quantize = False, bits = 8, qmode = "dynamic"):
    """Switch every quantization-aware layer of ``model`` into or out of fake quantization.

    A layer qualifies when it exposes a callable ``mode()`` (returning whether it is
    currently quantized) and a ``change_mod`` method.  ``Quant`` modules store
    ``mode`` as a string, so they are skipped by the ``callable`` check.

    Args:
        model: module tree to update in place.
        quantize: True to enable quantization on layers that are not yet quantized,
            False to disable it on layers that currently are.
        bits: bit width forwarded to ``change_mod`` when enabling.
        qmode: quantizer mode (``"dynamic"`` or ``"static"``) forwarded when enabling.

    Returns:
        The same ``model`` object.
    """
    print("Quantize mode on" if quantize else "Quantize mode off")
    for layer in model.modules():
        mode_fn = getattr(layer, "mode", None)
        if not callable(mode_fn) or not hasattr(layer, "change_mod"):
            continue
        if quantize and not mode_fn():
            layer.change_mod(True, bits, qmode)
        elif not quantize and mode_fn():
            layer.change_mod(False, 0)
    return model

In [3]:
def qsin_activation_mode(model):
    """Enable the QSIN activation loss on every layer of ``model`` that supports it.

    A layer qualifies when it has a callable ``mode()`` and a callable ``qsinmode()``.
    Errors raised by ``qsinmode`` itself are no longer swallowed.

    Returns:
        The same ``model`` object.
    """
    print("QSIN activation mode on")
    for layer in model.modules():
        if callable(getattr(layer, "mode", None)) and callable(getattr(layer, "qsinmode", None)):
            layer.qsinmode()
    return model

In [4]:
class MyRound(torch.autograd.Function):
    # Rounding with a straight-through gradient: hard round forward, identity backward.

    @staticmethod
    def forward(ctx, input):
        # Nearest integer (torch.round resolves ties to the even integer).
        return input.round()

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: treat rounding as the identity for gradients.
        return grad_output

In [5]:
def Quantize_tensor(input_tensor, max_abs_val = None, num_bits = 8):
    """Fake-quantize ``input_tensor`` onto a symmetric ``num_bits`` integer grid.

    With ``qmin = -2**(num_bits-1)`` and ``qmax = 2**(num_bits-1) - 1`` the step size is
    ``scale = max_abs_val / ((qmax - qmin) / 2)``.  Values are divided by ``scale``,
    rounded with a straight-through gradient (``MyRound``), clamped to ``[qmin, qmax]``
    and multiplied back by ``scale``.

    Args:
        input_tensor: tensor to quantize.
        max_abs_val: clipping magnitude (tensor or number).  ``None`` uses
            ``max|input_tensor|`` (detached).  When it is a tensor that requires grad,
            gradients flow into it.
        num_bits: bit width of the signed integer code.

    Returns:
        Tensor with the same shape as ``input_tensor``.  A zero clipping magnitude
        yields all zeros instead of NaN from 0/0.
    """
    my_round = MyRound.apply
    qmin = -1.0 * (2**num_bits) / 2
    qmax = -qmin - 1
    if max_abs_val is None:
        max_abs_val = torch.max(torch.abs(input_tensor.detach()))
    scale = max_abs_val / ((qmax - qmin) / 2)
    if not torch.is_tensor(scale):
        scale = torch.tensor(scale, dtype=input_tensor.dtype, device=input_tensor.device)
    # Divide by 1 where scale == 0 so zero entries do not become NaN. The final
    # multiplication by the real scale still maps everything to 0 in that case, and
    # gradients w.r.t. a non-zero scale are unchanged.
    safe_scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    input_tensor = torch.div(input_tensor, safe_scale)
    input_tensor = my_round(input_tensor)
    input_tensor = torch.clamp(input_tensor, qmin, qmax)
    return torch.mul(input_tensor, scale)

In [6]:
class Quant(nn.Module):
    """Fake-quantize a tensor to ``num_bits`` with a symmetric per-tensor range.

    ``mode="dynamic"``: every call takes the clipping magnitude from the current
    input (``max|input|``, detached) and quantizes with it.

    ``mode="static"``: the first ``static_count`` calls only record ``max|input|``
    and return the input unchanged (calibration).  The next call stores the mean of
    the recorded maxima in the trainable parameter ``max_abs_tr``.  From then on every
    call quantizes with ``max_abs_tr``.

    Raises:
        ValueError: if ``mode`` is neither ``"dynamic"`` nor ``"static"``.
    """

    _MODES = ("dynamic", "static")

    def __init__(self, num_bits=8, mode = "dynamic", static_count = 30):
        super(Quant, self).__init__()
        if mode not in self._MODES:
            raise ValueError(f"Unknown quantization mode {mode!r}; expected one of {self._MODES}")
        self.num_bits = num_bits
        self.mode = mode
        self.static_count = static_count
        self.static_cur = 0      # number of forward calls seen so far in static mode
        self.stat_values = []    # per-call max|input| collected during calibration
        self.max_abs = 0         # most recently computed clipping magnitude
        if mode == "static":
            # IMPORTANT: registered as an nn.Parameter so it is saved and trained.
            # It is an empty placeholder until calibration replaces ``.data`` with a scalar.
            self.max_abs_tr = nn.Parameter(torch.zeros(0), requires_grad=True)

    def forward(self, input):
        if self.mode == "dynamic":
            self.max_abs = torch.max(torch.abs(input.detach()))
            return Quantize_tensor(input, self.max_abs, self.num_bits)

        # Static mode, calibration phase: record the range and pass the input through.
        if self.static_cur < self.static_count:
            self.static_cur += 1
            # Reduce on the tensor's own device; only the scalar is moved to the host.
            self.stat_values.append(torch.max(torch.abs(input.detach())).item())
            return input

        # First call after calibration: freeze the clipping value.
        if self.static_cur == self.static_count:
            self.max_abs = np.mean(self.stat_values)
            self.max_abs_tr.data = torch.tensor(self.max_abs, dtype=torch.float).to(self.max_abs_tr.device)
            self.static_cur += 1

        return Quantize_tensor(input, self.max_abs_tr, self.num_bits)

In [7]:
def QSin(x, num_bits = 8):
    """QSIN regulariser: a penalty that is zero exactly on the integer grid.

    Inside the signed ``num_bits`` integer range ``[lo, hi]`` each element contributes
    ``sin(pi * x) ** 2``.  Outside it, each element contributes ``pi**2 * d**2``, where
    ``d`` is the distance to the nearest range end.  Returns the sum as a 0-d tensor.
    """
    pi = torch.tensor(np.pi)
    lo = -1.0 * (2 ** num_bits) / 2
    hi = -lo - 1
    in_range = torch.logical_and(x >= lo, x <= hi)
    total = torch.sum(torch.square(torch.sin(torch.mul(pi, x[in_range]))))
    # Quadratic tails for values that fall outside the representable range.
    for outside, edge in ((x < lo, lo), (x > hi, hi)):
        total = total + torch.sum(torch.mul(torch.square(pi), torch.square(x[outside] - edge)))
    return total

In [8]:
class Linear(nn.Linear):
    """``nn.Linear`` with optional fake quantization and a QSIN activation loss.

    When quantizing, the weight passes through ``Quantize_weights`` (``q_bits`` bits)
    and the input through ``Quantize_input`` (8 bits) before ``F.linear``.  When
    ``QsinA`` is on, each quantized forward also stores ``qsin_loss_A``: the QSIN
    penalty of the raw input on the input quantizer's grid, rescaled by ``scale**2``.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = True, quantization: bool = False, q_bits: int = 8, qsin_activation = False):
        super(Linear, self).__init__(in_features, out_features, bias)

        self.quantize = bool(quantization)
        self.QsinA = bool(qsin_activation)

        if self.quantize:
            self.bits = q_bits
            self.Quantize_weights = Quant(self.bits)
            self.Quantize_input = Quant(8)
        else:
            self.bits = 'FP'

        if self.QsinA:
            self.qsin_loss_A = 0

    def init(self, input):
        """Record the shape of ``input`` as ``self.inputW``."""
        self.inputW = input.shape

    def change_mod(self, value, bits = 8, mode = "dynamic"):
        """Set quantization on/off and rebuild both quantizers (fresh calibration state)."""
        self.quantize = value
        self.bits = bits
        self.Quantize_weights = Quant(bits, mode)
        self.Quantize_input = Quant(8, mode)

    def qsinmode(self):
        """Enable the QSIN activation loss and reset its stored value."""
        self.QsinA = True
        self.qsin_loss_A = 0

    def mode(self):
        """Return whether fake quantization is currently enabled."""
        return self.quantize

    def _activation_scale(self):
        """Return the input quantizer's step size used by the QSIN activation loss."""
        quantizer = self.Quantize_input
        # Static quantizers hold a trainable clipping value. Dynamic ones only keep
        # the detached max|x| computed during their latest forward call.
        max_abs = quantizer.max_abs if quantizer.mode == "dynamic" else quantizer.max_abs_tr
        qmin = -1.0 * (2 ** quantizer.num_bits) / 2
        qmax = -qmin - 1
        return max_abs / ((qmax - qmin) / 2)

    def forward(self, input):
        if not self.quantize:
            return nn.functional.linear(input, self.weight, self.bias)

        qinput = self.Quantize_input(input)
        qweight = self.Quantize_weights(self.weight)

        if self.QsinA:
            # Computed after Quantize_input ran, so a dynamic quantizer's max_abs is current.
            scale = self._activation_scale()
            self.qsin_loss_A = torch.mul(
                torch.square(scale),
                QSin(torch.div(input, scale), self.Quantize_input.num_bits),
            )

        return nn.functional.linear(qinput, qweight, self.bias)

In [9]:
class Embedding(nn.Embedding):
    """``nn.Embedding`` whose weight table can be fake-quantized on every lookup."""

    def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
                 norm_type=2.0, scale_grad_by_freq=False, sparse=False,
                 quantization: bool = False, q_bits: int = 8):
        # Forward every nn.Embedding option; they used to be accepted but dropped.
        super(Embedding, self).__init__(
            num_embeddings,
            embedding_dim,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
        )

        self.quantize = bool(quantization)

        if self.quantize:
            self.bits = q_bits
            self.Quantize_weights = Quant(self.bits)
        else:
            self.bits = 'FP'

    def init(self, input):
        """Record the shape of ``input`` as ``self.inputW``."""
        self.inputW = input.shape

    def change_mod(self, value, bits = 8, mode = "dynamic"):
        """Set quantization on/off and rebuild the weight quantizer."""
        self.quantize = value
        self.bits = bits
        self.Quantize_weights = Quant(bits, mode)

    def mode(self):
        """Return whether fake quantization is currently enabled."""
        return self.quantize

    def forward(self, input):
        weight = self.Quantize_weights(self.weight) if self.quantize else self.weight
        return nn.functional.embedding(input, weight, self.padding_idx, self.max_norm,
                                       self.norm_type, self.scale_grad_by_freq, self.sparse)
        

In [10]:
def Qsin_W(model, bits = 8):
    """Sum the QSIN weight regulariser over all calibrated, statically quantized layers.

    A layer contributes when it has a ``weight`` and a ``Quantize_weights`` quantizer
    whose trainable clipping value ``max_abs_tr`` exists (static mode) and has been
    calibrated (non-empty).  Each contribution is ``scale**2 * QSin(weight / scale)``
    with ``scale = max_abs_tr / ((qmax - qmin) / 2)``.

    Args:
        model: module tree to scan.
        bits: bit width of the grid the weights are pulled towards.  Pass ``None``
            to use each layer's own ``Quantize_weights.num_bits``.

    Returns:
        The summed loss tensor, or the integer ``0`` if no layer contributes.
    """
    loss = 0
    for layer in model.modules():
        quantizer = getattr(layer, "Quantize_weights", None)
        max_abs = getattr(quantizer, "max_abs_tr", None)
        weight = getattr(layer, "weight", None)
        # Dynamic quantizers have no max_abs_tr, and uncalibrated static ones hold an
        # empty placeholder; neither has a usable scale, so skip them explicitly.
        if weight is None or max_abs is None or max_abs.numel() == 0:
            continue
        layer_bits = quantizer.num_bits if bits is None else bits
        qmin = -1.0 * (2 ** layer_bits) / 2
        qmax = -qmin - 1
        scale = max_abs / ((qmax - qmin) / 2)
        loss = loss + torch.mul(torch.square(scale), QSin(torch.div(weight, scale), layer_bits))
    return loss

In [11]:
def Qsin_A(model):
    """Sum the ``qsin_loss_A`` values stored on the layers of ``model``.

    Returns the integer ``0`` when no layer has a ``qsin_loss_A`` attribute.
    """
    loss = 0
    for layer in model.modules():
        # Only layers with the QSIN activation loss enabled carry this attribute.
        if hasattr(layer, "qsin_loss_A"):
            loss = loss + layer.qsin_loss_A
    return loss

In [12]:
def get_custom_Linear(in_features, out_features, bias, weight):
    """Build a custom ``Linear`` that reuses the given ``weight`` and ``bias`` parameters."""
    layer = Linear(in_features, out_features)
    # Point the freshly initialised layer at the original parameter objects (no copy).
    layer.weight = weight
    layer.bias = bias
    return layer

def get_custom_Embeding(num_embeddings, embedding_dim, padding_idx, weight):
    """Build a custom ``Embedding`` that reuses the given ``weight`` parameter."""
    table = Embedding(num_embeddings, embedding_dim, padding_idx)
    # Share the original parameter object instead of the freshly initialised one.
    table.weight = weight
    return table

def change_layers(model):
    """Recursively replace ``nn.Linear``/``nn.Embedding`` children of ``model`` with the
    custom quantization-aware ``Linear``/``Embedding`` wrappers.

    The wrappers reuse the original ``weight``/``bias`` parameter objects, so no values
    are copied.  Modules that are already custom ``Linear``/``Embedding`` instances are
    left untouched, which makes repeated calls safe.  Without that check a converted
    layer, and any quantizer state it holds, would be replaced by a fresh unquantized one.
    Mutates ``model`` in place and returns ``None``.
    """
    for name, layer in model.named_children():
        if isinstance(layer, nn.Linear) and not isinstance(layer, Linear):
            setattr(model, name, get_custom_Linear(
                layer.in_features,
                layer.out_features,
                layer.bias,
                layer.weight,
            ))
        elif isinstance(layer, nn.Embedding) and not isinstance(layer, Embedding):
            setattr(model, name, get_custom_Embeding(
                layer.num_embeddings,
                layer.embedding_dim,
                layer.padding_idx,
                layer.weight,
            ))
        # Recurse into whatever now occupies this slot (wrapper or original module).
        change_layers(getattr(model, name))

In [13]:
#pooler
#classifier
#attention

In [14]:
def quantize_model(model, quantize = False, bits = 8, qmode = "dynamic", low_bits = 4, high_bits = 8):
    """Switch the model's quantization-aware layers into or out of fake quantization.

    When ``quantize`` is true, a mixed-precision scheme is applied by module name to
    layers that are not yet quantized:

    * names containing ``'pooler'``, ``'attention'`` or ``'token_type_embeddings'``
      get ``low_bits``;
    * names containing ``'classifier'`` stay in floating point;
    * everything else gets ``high_bits``.

    ``bits`` is kept for signature compatibility but is NOT used by this scheme.
    When ``quantize`` is false, every currently quantized layer is switched back.
    Layers qualify when they expose a callable ``mode()`` and ``change_mod``;
    ``Quant`` modules (whose ``mode`` is a string) are skipped.

    Returns:
        The same ``model`` object.
    """
    print("Quantize mode on" if quantize else "Quantize mode off")
    for name, layer in model.named_modules():
        mode_fn = getattr(layer, "mode", None)
        if not callable(mode_fn) or not hasattr(layer, "change_mod"):
            continue
        if quantize:
            if mode_fn():
                continue
            if 'pooler' in name or 'attention' in name or 'token_type_embeddings' in name:
                layer.change_mod(True, low_bits, qmode)
            elif 'classifier' in name:
                continue
            else:
                layer.change_mod(True, high_bits, qmode)
        elif mode_fn():
            layer.change_mod(False, 0)
    return model

In [15]:
#pooler
#classifier
#attention

        #if name == 'intermediate' or \
        #name == 'output'or name == 'embeddings':
        #    continue

# GLUE metric

In [16]:
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import f1_score, matthews_corrcoef

import datasets


_CITATION = """\
@inproceedings{wang2019glue,
  title={{GLUE}: A Multi-Task Benchmark and Analysis Platform for Natural Language Understanding},
  author={Wang, Alex and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel R.},
  note={In the Proceedings of ICLR.},
  year={2019}
}
"""

_DESCRIPTION = """\
GLUE, the General Language Understanding Evaluation benchmark
(https://gluebenchmark.com/) is a collection of resources for training,
evaluating, and analyzing natural language understanding systems.
"""

_KWARGS_DESCRIPTION = """
Compute GLUE evaluation metric associated to each GLUE dataset.
Args:
    predictions: list of predictions to score.
        Each translation should be tokenized into a list of tokens.
    references: list of lists of references for each translation.
        Each reference should be tokenized into a list of tokens.
Returns: depending on the GLUE subset, one or several of:
    "accuracy": Accuracy
    "f1": F1 score
    "pearson": Pearson Correlation
    "spearmanr": Spearman Correlation
    "matthews_correlation": Matthew Correlation
Examples:
    >>> glue_metric = datasets.load_metric('glue', 'sst2')  # 'sst2' or any of ["mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans"]
    >>> references = [0, 1]
    >>> predictions = [0, 1]
    >>> results = glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'accuracy': 1.0}
    >>> glue_metric = datasets.load_metric('glue', 'mrpc')  # 'mrpc' or 'qqp'
    >>> references = [0, 1]
    >>> predictions = [0, 1]
    >>> results = glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'accuracy': 1.0, 'f1': 1.0}
    >>> glue_metric = datasets.load_metric('glue', 'stsb')
    >>> references = [0., 1., 2., 3., 4., 5.]
    >>> predictions = [0., 1., 2., 3., 4., 5.]
    >>> results = glue_metric.compute(predictions=predictions, references=references)
    >>> print({"pearson": round(results["pearson"], 2), "spearmanr": round(results["spearmanr"], 2)})
    {'pearson': 1.0, 'spearmanr': 1.0}
    >>> glue_metric = datasets.load_metric('glue', 'cola')
    >>> references = [0, 1]
    >>> predictions = [0, 1]
    >>> results = glue_metric.compute(predictions=predictions, references=references)
    >>> print(results)
    {'matthews_correlation': 1.0}
"""


def simple_accuracy(preds, labels):
    """Return the fraction of positions where ``preds`` equals ``labels``.

    Inputs go through ``np.asarray``, so plain Python lists work as well as arrays.
    With two lists, ``preds == labels`` would be a single bool with no ``.mean()``.
    """
    return (np.asarray(preds) == np.asarray(labels)).mean()


def acc_and_f1(preds, labels):
    """Return accuracy and (binary) F1 score of ``preds`` against ``labels``."""
    scores = {"accuracy": simple_accuracy(preds, labels)}
    scores["f1"] = f1_score(y_true=labels, y_pred=preds)
    return scores


def pearson_and_spearman(preds, labels):
    """Return the Pearson and Spearman correlation coefficients of ``preds`` vs ``labels``."""
    return {
        "pearson": pearsonr(preds, labels)[0],
        "spearmanr": spearmanr(preds, labels)[0],
    }


@datasets.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class Glue(datasets.Metric):
    # GLUE metric: chooses the score(s) matching ``self.config_name``.
    # (Kept as comments so the docstring produced by add_start_docstrings is unchanged.)

    # Configs scored with plain accuracy.
    _ACCURACY_CONFIGS = ("sst2", "mnli", "mnli_mismatched", "mnli_matched", "qnli", "rte", "wnli", "hans")
    # Every configuration this metric accepts.
    _SUPPORTED_CONFIGS = (
        "sst2", "mnli", "mnli_mismatched", "mnli_matched", "cola", "stsb",
        "mrpc", "qqp", "qnli", "rte", "wnli", "hans",
    )
    _UNSUPPORTED_MESSAGE = (
        "You should supply a configuration name selected in "
        '["sst2", "mnli", "mnli_mismatched", "mnli_matched", '
        '"cola", "stsb", "mrpc", "qqp", "qnli", "rte", "wnli", "hans"]'
    )

    def _info(self):
        if self.config_name not in self._SUPPORTED_CONFIGS:
            raise KeyError(self._UNSUPPORTED_MESSAGE)
        # stsb is scored on float similarity values; every other task on integer labels.
        value_type = "float32" if self.config_name == "stsb" else "int64"
        features = datasets.Features(
            {
                "predictions": datasets.Value(value_type),
                "references": datasets.Value(value_type),
            }
        )
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=features,
            codebase_urls=[],
            reference_urls=[],
            format="numpy",
        )

    def _compute(self, predictions, references):
        config = self.config_name
        if config == "cola":
            return {"matthews_correlation": matthews_corrcoef(references, predictions)}
        if config == "stsb":
            return pearson_and_spearman(predictions, references)
        if config in ("mrpc", "qqp"):
            return acc_and_f1(predictions, references)
        if config in self._ACCURACY_CONFIGS:
            return {"accuracy": simple_accuracy(predictions, references)}
        raise KeyError(self._UNSUPPORTED_MESSAGE)

# BERT

In [17]:
from transformers import AutoModelForSequenceClassification
# Load the sequence-classification model from a local directory only (no hub download).
model = AutoModelForSequenceClassification.from_pretrained('./models/model-bert-base/', local_files_only=True)

In [17]:
from datasets import load_dataset, load_metric
from transformers import DistilBertTokenizer
    
# Per-device batch size reused by the later evaluation-only TrainingArguments.
batch_size = 16

# NOTE(review): a DistilBertTokenizer is loaded from a BERT tokenizer directory for a BERT
# model; DistilBERT's tokenizer does not list token_type_ids among its model inputs.
# Confirm BertTokenizer/AutoTokenizer was not intended.
tokenizer = DistilBertTokenizer.from_pretrained('./models/tokenizer-bert-base/', local_files_only=True)

In [18]:
from datasets import load_from_disk
# Load the pre-encoded dataset stored in the local 'cur_glue_data' directory.
encoded_dataset = load_from_disk('cur_glue_data')


In [19]:
from transformers import TrainingArguments, Trainer
import numpy as np
# GLUE task under evaluation; selects the metric and the validation split.
task = 'cola'
# Metric used to pick the best checkpoint (load_best_model_at_end).
metric_name = "pearson" if task == "stsb" else "matthews_correlation" if task == "cola" else "accuracy"
metric = Glue(task)


def compute_metrics(eval_pred):
    """Turn Trainer logits into task predictions and score them with the global ``metric``."""
    logits, labels = eval_pred
    # stsb is regression (single output column); classification takes the argmax class.
    preds = logits[:, 0] if task == "stsb" else np.argmax(logits, axis=1)
    return metric.compute(predictions=preds, references=labels)



In [56]:
# Swap nn.Linear / nn.Embedding layers for the quantization-aware wrappers (sharing the original parameters).
change_layers(model)

In [22]:
# Element count of every state_dict entry (parameters and buffers), reported as an fp32 size in MB.
num_parameters = sum(tensor.numel() for tensor in model.state_dict().values())
print(num_parameters * 32 / 8 * 1e-6)

437.93716


In [23]:
# Element count per embedding-block tensor, then the block's fp32 size in MB.
embedding_state = model.bert.embeddings.state_dict()
w = 0
for name, p in embedding_state.items():
    count = p.numel()
    print(name, ":", count)
    w += count
print(w * 32 / 8 * 1e-6)

position_ids : 512
word_embeddings.weight : 23440896
position_embeddings.weight : 393216
token_type_embeddings.weight : 1536
LayerNorm.weight : 768
LayerNorm.bias : 768
95.35078399999999


In [24]:
w = 0
for name, p in model.bert.encoder.layer[0].state_dict().items():
    w += p.numel()
    print(name, ":", p.numel())
print(w * 12 * 32 / 8 * 1e-6)

attention.self.query.weight : 589824
attention.self.query.bias : 768
attention.self.key.weight : 589824
attention.self.key.bias : 768
attention.self.value.weight : 589824
attention.self.value.bias : 768
attention.output.dense.weight : 589824
attention.output.dense.bias : 768
attention.output.LayerNorm.weight : 768
attention.output.LayerNorm.bias : 768
intermediate.dense.weight : 2359296
intermediate.dense.bias : 3072
output.dense.weight : 2359296
output.dense.bias : 768
output.LayerNorm.weight : 768
output.LayerNorm.bias : 768
340.217856


In [25]:
w = 0
for name, p in model.classifier.state_dict().items():
    w += p.numel()
    print(name, ":", p.numel())
print(w * 32 / 8 * 1e-6)

weight : 1536
bias : 2
0.0061519999999999995


In [26]:
w = 0
for name, p in model.bert.pooler.state_dict().items():
    w += p.numel()
    print(name, ":", p.numel())
print(w * 32 / 8 * 1e-6)

dense.weight : 589824
dense.bias : 768
2.362368


In [197]:
# Element counts copied from the embedding-block breakdown.
position_ids = 512
word_embeddings_weight = 23440896
position_embeddings_weight = 393216
token_type_embeddings_weight = 1536
LayerNorm_weight = 768
LayerNorm_bias = 768

# Embedding-block size in bits for the mixed scheme: word/position tables at 8 bits,
# token-type table at 4 bits, position ids and LayerNorm counted at 32 bits.
embd = sum((
    position_ids * 32,
    word_embeddings_weight * 8,
    position_embeddings_weight * 8,
    token_type_embeddings_weight * 4,
    LayerNorm_weight * 32,
    LayerNorm_bias * 32,
))
print(embd / 8 * 1e-6)

23.843072


In [198]:
attention_self_query_weight = 589824
intermediate_dense_weight = 2359296
output_dense_weight = 2359296

# Per-encoder-layer size in bits for the mixed scheme:
#   q/k/v/attention-output dense weights (4 matrices of query size) at 4 bits,
#   intermediate and output dense weights at 8 bits,
#   768-wide vectors and the 3072-wide intermediate bias kept in fp32.
# NOTE(review): 6 vectors of width 768 are counted, but a layer holds 9 (5 biases plus
# two LayerNorm weight/bias pairs). Confirm whether this undercount is intended.
att = sum((
    attention_self_query_weight * 4 * 4,  # bits * count
    intermediate_dense_weight * 8,
    output_dense_weight * 8,
    768 * 32 * 6,
    3072 * 32,
))
print(att / 8 * 1e-6 * 12)

71.14752


In [199]:
# pooler
dense_weight = 589824
dense_bias = 768
# Pooler size in bits: dense weight at 4 bits, bias kept at 32 bits.
pooler = sum((dense_weight * 4, dense_bias * 32))
print(pooler / 8 * 1e-6)

0.29798399999999997


In [182]:
# all 32 bits
# fp32 size (MB): embeddings + encoder (12 layers) + pooler, typed in from earlier outputs.
# NOTE(review): the encoder term 340.107264 differs from the 340.217856 printed by the
# per-layer breakdown; confirm which value was intended.
95.35078399999999 + 340.107264 + 2.362368

437.82041599999997

In [186]:
# all 8 bits scheme
# Total size (MB) under the all-8-bit scheme: typed-in embeddings + encoder + pooler terms.
23.84384 + 85.30329599999999 + 0.592896

109.74003199999999

In [200]:
# mix 4-8 bits
# Total size (MB) under the mixed 4/8-bit scheme: the embeddings, encoder and pooler
# estimates printed by the three cells above.
23.843072 + 71.14752 + 0.29798399999999997

95.28857599999999

In [202]:
# Compression ratio: fp32 size over the all-8-bit size.
437.82041599999997 / 109.74003199999999

3.9896144371454167

In [203]:
# Compression ratio: fp32 size over the mixed 4/8-bit size.
437.82041599999997 / 95.28857599999999

4.59467896760258

In [21]:
# Validation split name; MNLI has separate matched / mismatched splits.
if task == "mnli-mm":
    validation_key = "validation_mismatched"
elif task == "mnli":
    validation_key = "validation_matched"
else:
    validation_key = "validation"

In [25]:
# Full-precision fine-tuning: evaluate every epoch and restore the best checkpoint by metric_name.
# NOTE(review): newer transformers releases spell `evaluation_strategy` as `eval_strategy`;
# confirm against the installed version.
args = TrainingArguments(
    "test-glue",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=4,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
)

trainer = Trainer(
    model,
    args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset[validation_key],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [24]:
# Fine-tune in full precision; the best checkpoint by metric_name is reloaded at the end.
trainer.train()

Epoch,Training Loss,Validation Loss,Matthews Correlation,Runtime,Samples Per Second
1,No log,0.459633,0.485139,0.731,1426.82
2,0.438300,0.42599,0.53116,0.6741,1547.138
3,0.438300,0.467304,0.552074,0.6744,1546.664
4,0.205400,0.512932,0.573205,0.6752,1544.764


TrainOutput(global_step=1072, training_loss=0.3110503678891196, metrics={'train_runtime': 126.4256, 'train_samples_per_second': 8.479, 'total_flos': 1051111929774804.0, 'epoch': 4.0})

In [22]:
# Return cached, unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()

In [23]:
from transformers import AutoModelForSequenceClassification
# Reload the final fine-tuning checkpoint from disk as a fresh `model` object.
# NOTE(review): `trainer` still holds its own model object, so trainer.evaluate()
# does not evaluate this reloaded instance.
model = AutoModelForSequenceClassification.from_pretrained('./test-glue/checkpoint-1072/',
                                                           local_files_only=True)

In [26]:
# Evaluate the Trainer's current (full-precision) model on the validation split.
trainer.evaluate()

{'eval_loss': 0.5129320621490479,
 'eval_matthews_correlation': 0.5732046470010711,
 'eval_runtime': 0.7093,
 'eval_samples_per_second': 1470.519}

In [27]:
# Convert the reloaded model's Linear/Embedding layers to the quantization-aware wrappers.
change_layers(model)

In [28]:
# Reset every custom layer to full precision, then enable dynamic fake quantization.
# NOTE(review): the quantize_model defined last in this notebook assigns 4 or 8 bits by
# layer name, so `bits = 4` does not set a uniform bit width.
model = quantize_model(model, quantize=False, bits = 4)
model = quantize_model(model, quantize=True, bits = 4)

Quantize mode off
Quantize mode on


In [29]:
# Evaluation-only Trainer (no train_dataset) for the dynamically quantized model.
Eval_DQ = Trainer(
    model,
    args,
    eval_dataset=encoded_dataset[validation_key],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [30]:
# Validation metrics with dynamic fake quantization.
Eval_DQ.evaluate()

{'eval_loss': 0.4798387289047241,
 'eval_matthews_correlation': 0.41682112176346214,
 'eval_runtime': 1.073,
 'eval_samples_per_second': 972.048}

In [31]:
# Switch to static quantization. Toggling off and on rebuilds every Quant module via
# change_mod, so calibration starts from scratch. Also load the calibration dataset.
model = quantize_model(model, quantize=False, bits = 4)
model = quantize_model(model, quantize=True, bits = 4, qmode = "static")
encoded_dataset_static = load_from_disk('cur_glue_data_st')

Quantize mode off
Quantize mode on


In [32]:
# Training split of the calibration dataset.
train_enc = encoded_dataset_static['train']

In [33]:
from torch.utils.data import DataLoader
# Expose only the model-input columns as torch tensors and batch them in shuffled order.
train_enc.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'label'])
train_loader = DataLoader(train_enc, batch_size=8, shuffle = True)

In [34]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
model.eval()
# Static quantizers record max|x| during their first `static_count` calls (30 by default)
# and freeze the clipping value on the next call, so 32 batches complete calibration.
# No gradients are needed, so skip building autograd graphs.
num_calibration_batches = 32
with torch.no_grad():
    for i, batch in enumerate(tqdm(train_loader), start=1):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Pass segment ids too, so calibration sees the same activations as a
        # forward pass that receives them (sentence-pair tasks).
        token_type_ids = batch['token_type_ids'].to(device)
        labels = batch['label'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask,
                        token_type_ids=token_type_ids, labels=labels)
        if i == num_calibration_batches:
            break

  3%|▎         | 31/1069 [00:33<18:30,  1.07s/it]


In [35]:
# Evaluate the statically quantized model with evaluation-only settings
# (reuses the "test-glue" output directory).
args = TrainingArguments(
    "test-glue",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    metric_for_best_model=metric_name,
)

Eval_SQ = Trainer(
    model,
    args,
    eval_dataset=encoded_dataset[validation_key],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)
Eval_SQ.evaluate()

{'eval_loss': 0.5073927044868469,
 'eval_matthews_correlation': 0.34197644445477027,
 'eval_runtime': 1.3166,
 'eval_samples_per_second': 792.194}

# QSIN

In [None]:
# Enable the QSIN activation loss on every supporting layer. The trailing print()
# presumably just keeps the notebook from displaying the returned model.
qsin_activation_mode(model)
print()

In [None]:
# Return cached, unused GPU memory before QSIN training.
torch.cuda.empty_cache()

In [None]:
import torch
from transformers import Trainer

class QSinTrainer(Trainer):
    """``Trainer`` whose training loss adds QSIN regularisers for weights and activations.

    Each regulariser is divided by a power of ten chosen so it lands on the same order
    of magnitude as the task loss; the activation term is placed one decade lower.
    """

    @staticmethod
    def _rescale(term, reference, extra_decades=0):
        """Return ``term / 10**(round(log10(term) - log10(reference)) + extra_decades)``.

        A non-positive ``term`` (e.g. the integer ``0`` returned when no layer
        contributes) is returned unchanged instead of hitting ``log10(0)``.
        Works on any device: values are read with ``float()`` rather than ``.cuda()``.
        """
        term_value = float(term)
        if term_value <= 0:
            return term
        exponent = np.round(np.log10(term_value) - np.log10(float(reference))) + extra_decades
        return term / 10 ** exponent

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # num_items_in_batch is accepted for newer Trainer versions that pass it; unused here.
        outputs = model(**inputs)
        # Assumes `inputs` contain labels, so the model's first output is the task loss.
        L = outputs[0]
        loss = L + self._rescale(Qsin_W(model, 4), L) + self._rescale(Qsin_A(model), L, extra_decades=1)
        return (loss, outputs) if return_outputs else loss

In [38]:
# QSIN fine-tuning: 6 epochs at lr 2e-5, evaluation every epoch, best checkpoint by metric_name.
args = TrainingArguments(
    "qsin_train_tmp",
    evaluation_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    #eval_steps=10, 
    num_train_epochs=6,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
)

# Same data and metrics as fine-tuning, but the loss also carries the QSIN regularisers.
trainer_QSin = QSinTrainer(
    model,
    args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset[validation_key],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [39]:
# Train with the task loss plus the QSIN weight and activation regularisers.
trainer_QSin.train()

Epoch,Training Loss,Validation Loss,Pearson,Spearmanr,Runtime,Samples Per Second
1,No log,1.338755,0.874005,0.872234,12.6678,118.41
2,1.252400,1.177294,0.881382,0.87906,12.6631,118.454
3,0.760000,1.130876,0.884721,0.881053,12.6464,118.611
4,0.760000,1.188228,0.88659,0.882593,12.6259,118.804
5,0.482500,1.154018,0.886627,0.882877,12.6224,118.836
6,0.286000,1.118148,0.890415,0.886659,12.6817,118.28


Some weights of the model checkpoint at qsin_train_tmp/checkpoint-2160 were not used when initializing BertForSequenceClassification: ['bert.embeddings.word_embeddings.Quantize_weights.max_abs_tr', 'bert.embeddings.position_embeddings.Quantize_weights.max_abs_tr', 'bert.embeddings.token_type_embeddings.Quantize_weights.max_abs_tr', 'bert.encoder.layer.0.attention.self.query.Quantize_weights.max_abs_tr', 'bert.encoder.layer.0.attention.self.query.Quantize_input.max_abs_tr', 'bert.encoder.layer.0.attention.self.key.Quantize_weights.max_abs_tr', 'bert.encoder.layer.0.attention.self.key.Quantize_input.max_abs_tr', 'bert.encoder.layer.0.attention.self.value.Quantize_weights.max_abs_tr', 'bert.encoder.layer.0.attention.self.value.Quantize_input.max_abs_tr', 'bert.encoder.layer.0.attention.output.dense.Quantize_weights.max_abs_tr', 'bert.encoder.layer.0.attention.output.dense.Quantize_input.max_abs_tr', 'bert.encoder.layer.0.intermediate.dense.Quantize_weights.max_abs_tr', 'bert.encoder.layer

TrainOutput(global_step=2160, training_loss=0.6629010659677012, metrics={'train_runtime': 1940.7698, 'train_samples_per_second': 1.113, 'total_flos': 2788098539560632.0, 'epoch': 6.0})

# QAT

In [36]:
# QAT baseline: plain Trainer on the fake-quantized model, 6 epochs at lr 5e-5.
# NOTE(review): this passes the same `model` object given to trainer_QSin, so when the
# cells run top to bottom it starts from the QSIN-trained weights; confirm that is intended.
args = TrainingArguments(
    "qat_train_tmp",
    evaluation_strategy = "epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    #eval_steps=10,
    num_train_epochs=6,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
)


trainer_QAT = Trainer(
    model,
    args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset[validation_key],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [37]:
# Return cached, unused GPU memory before QAT training.
torch.cuda.empty_cache()

In [38]:
# Quantization-aware training with the plain task loss (no QSIN terms).
trainer_QAT.train()

Epoch,Training Loss,Validation Loss,Pearson,Spearmanr,Runtime,Samples Per Second
1,No log,0.56463,0.879557,0.874765,2.981,503.189
2,0.516000,0.539674,0.879536,0.876675,2.9845,502.593
3,0.262800,0.486788,0.88708,0.881744,2.9799,503.375
4,0.262800,0.481755,0.886697,0.883014,2.9825,502.936
5,0.150100,0.464847,0.890952,0.88718,2.9848,502.546
6,0.090600,0.47217,0.890047,0.8858,2.9817,503.077


Some weights of the model checkpoint at qat_train_tmp/checkpoint-1800 were not used when initializing BertForSequenceClassification: ['bert.embeddings.word_embeddings.Quantize_weights.max_abs_tr', 'bert.embeddings.position_embeddings.Quantize_weights.max_abs_tr', 'bert.embeddings.token_type_embeddings.Quantize_weights.max_abs_tr', 'bert.encoder.layer.0.attention.self.query.Quantize_weights.max_abs_tr', 'bert.encoder.layer.0.attention.self.query.Quantize_input.max_abs_tr', 'bert.encoder.layer.0.attention.self.key.Quantize_weights.max_abs_tr', 'bert.encoder.layer.0.attention.self.key.Quantize_input.max_abs_tr', 'bert.encoder.layer.0.attention.self.value.Quantize_weights.max_abs_tr', 'bert.encoder.layer.0.attention.self.value.Quantize_input.max_abs_tr', 'bert.encoder.layer.0.attention.output.dense.Quantize_weights.max_abs_tr', 'bert.encoder.layer.0.attention.output.dense.Quantize_input.max_abs_tr', 'bert.encoder.layer.0.intermediate.dense.Quantize_weights.max_abs_tr', 'bert.encoder.layer.

TrainOutput(global_step=2160, training_loss=0.24185576218145866, metrics={'train_runtime': 456.1213, 'train_samples_per_second': 4.736, 'total_flos': 2788098539560632.0, 'epoch': 6.0})