In [1]:
import os
import torch
import pandas as pd
import numpy as np
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from torch.nn.utils.rnn import pad_sequence
from transformers import BertForSequenceClassification, AdamW
from tqdm import tqdm

### Dataset 및 DataLoader 생성

In [2]:
# do_lower_case=True is required here:
# bert-base-uncased was trained on English text that had been lower-cased.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

In [3]:
class CoLADataset(Dataset):
    '''
    CoLA sentences tokenized once at construction time, plus their labels.

    Indexing returns [input_ids, token_type_ids, attention_mask, label]. The three
    tensors are stored exactly as the tokenizer returned them for one sentence with
    return_tensors='pt' (collate_fn strips their leading dimension with [0]).
    '''

    def __init__(self, path, tokenizer, is_train=True, is_inference=False):
        '''
        path: root directory of the CoLA dataset; the TSV files are read from its raw/ subfolder
        tokenizer: tokenizer applied to every sentence, e.g. BertTokenizer
        is_train: True to load the training split (raw/in_domain_train.tsv)
        is_inference: only consulted when is_train is False; True loads
            raw/out_of_domain_dev.tsv, False loads raw/in_domain_dev.tsv
        '''
        if is_train:
            split = 'raw/in_domain_train.tsv'
        elif is_inference:
            split = 'raw/out_of_domain_dev.tsv'
        else:
            split = 'raw/in_domain_dev.tsv'
        filename = os.path.join(path, split)

        # The files have no header row, so the column names are supplied here.
        df = pd.read_csv(filename, sep='\t', names=['source', 'label', 'judgement', 'text'])

        self.input_ids = []
        self.token_type_ids = []
        self.attention_mask = []
        for t in df.text:
            # truncation=True caps each sentence at the tokenizer's maximum model
            # length, so an over-long sentence cannot overflow the model's
            # position embeddings later on.
            inp = tokenizer(t, return_tensors='pt', truncation=True)
            self.input_ids.append(inp['input_ids'])
            self.token_type_ids.append(inp['token_type_ids'])
            self.attention_mask.append(inp['attention_mask'])
        self.label = df.label

    def __len__(self):
        '''Number of tokenized sentences.'''
        return len(self.input_ids)

    def __getitem__(self, idx):
        '''Return [input_ids, token_type_ids, attention_mask, label] for position idx.'''
        # .iloc makes the label lookup positional, like the list lookups next to it;
        # plain Series[idx] is label-based and raises KeyError for e.g. idx == -1.
        return [
            self.input_ids[idx],
            self.token_type_ids[idx],
            self.attention_mask[idx],
            self.label.iloc[idx],
        ]

In [4]:
# is_train=False with the default is_inference=False reads raw/in_domain_dev.tsv.
eval_dataset = CoLADataset('../../data/cola_classification', tokenizer, is_train=False)

In [5]:
len(eval_dataset)

527

In [6]:
def collate_fn(batch):
    """Merge dataset items into one zero-padded batch.

    Every item is indexed as (input_ids, token_type_ids, attention_mask, label),
    where the three tensors carry a leading dimension that is dropped with [0].
    The resulting sequences are right-padded with zeros to the longest one in
    the batch; labels are gathered into a single tensor.

    Returns (input_ids, token_type_ids, attention_mask, label).
    """
    targets = torch.tensor([sample[3] for sample in batch])

    def _pad_field(position):
        # Drop the leading dimension of each sample's tensor, then pad the batch.
        sequences = [sample[position][0] for sample in batch]
        return pad_sequence(sequences, batch_first=True)

    padded_ids = _pad_field(0)
    padded_types = _pad_field(1)
    padded_mask = _pad_field(2)
    return padded_ids, padded_types, padded_mask, targets

### quantization

In [7]:
# Build BertForSequenceClassification: the pretrained BERT encoder with a single
# linear classification layer on top, configured for two output labels.
# The fine-tuned weights are loaded from disk in a later cell.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels = 2)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [8]:
# Load the trained model weights onto the CPU.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
model.load_state_dict(torch.load('cola_model.bin', map_location='cpu'))
#model.load_state_dict(torch.load('cola_model_no_pretrained.bin', map_location='cpu'))
# Switch to evaluation mode before quantization and inference.
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [9]:
# Dynamic quantization: replace the nn.Linear modules of the model with their
# dynamically quantized qint8 counterparts; only torch.nn.Linear is targeted.
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)

In [10]:
quantized_model.device

device(type='cpu')

### Inference

학습한 모델을 로딩해서 Inference하는 코드

In [11]:
import time
from sklearn.metrics import confusion_matrix
from sklearn.metrics import matthews_corrcoef

In [12]:
# Load the CoLA split used for testing and wrap it in a DataLoader.
# is_train=False together with is_inference=True selects raw/out_of_domain_dev.tsv.
test_dataset = CoLADataset('../../data/cola_classification', tokenizer, is_train=False, is_inference=True)
test_dataloader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=32, shuffle=False)

In [13]:
def measure_accuracy_and_latency(model, dataloader):
    """Run ``model`` over ``dataloader`` and print accuracy and per-sample latency.

    model: callable invoked as ``model(input_ids=..., attention_mask=...)``;
        element ``[0]`` of its result is treated as per-class scores and the
        prediction is the argmax over dimension 1. The model's train/eval mode
        is not changed here.
    dataloader: iterable of ``(input_ids, token_type_ids, attention_mask, labels)``
        batches (the layout ``collate_fn`` returns). ``token_type_ids`` is
        unpacked but not passed to the model.

    Prints ``acc=<fraction correct> latency=<seconds per sample>`` and returns
    nothing. The timed span is the whole loop, so it includes batch loading and
    progress-bar overhead. If the dataloader yields no samples, both values are
    printed as ``nan``.
    """
    tbar = tqdm(dataloader, desc='Inference', leave=True)

    label_list = []
    pred_list = []
    # perf_counter is a monotonic, high-resolution clock meant for measuring intervals.
    start = time.perf_counter()
    # Inference only: without no_grad every forward pass records an autograd graph,
    # which costs memory and time and distorts the measured latency.
    with torch.no_grad():
        for input_ids, token_type_ids, attention_mask, labels in tbar:
            # do inference
            scores = model(input_ids=input_ids, attention_mask=attention_mask)[0]
            pred = scores.argmax(dim=1)

            label_list.extend(labels.tolist())
            pred_list.extend(pred.tolist())
    elapsed = time.perf_counter() - start

    labels = np.array(label_list)
    preds = np.array(pred_list)

    if len(labels) == 0:
        # Nothing was evaluated: report nan instead of dividing by zero.
        acc = latency = float('nan')
    else:
        acc = (labels == preds).mean()
        latency = elapsed / len(labels)

    print(f'acc={acc:.3f} latency={latency:.4f}')

In [14]:
measure_accuracy_and_latency(model, test_dataloader)

Inference: 100%|██████████| 17/17 [00:02<00:00,  6.07it/s]

acc=0.816 latency=0.0054





In [15]:
measure_accuracy_and_latency(quantized_model, test_dataloader)

Inference: 100%|██████████| 17/17 [00:01<00:00, 13.07it/s]

acc=0.814 latency=0.0025





In [16]:
# torch.save(quantized_model.state_dict(), 'cola_model_quantized.bin')

In [17]:
ls -alh *.bin

-rw-rw-r-- 1 jkfirst jkfirst 418M  6월 23 04:37 cola_model.bin
-rw-rw-r-- 1 jkfirst jkfirst 174M  6월 28 19:19 cola_model_quantized.bin


In [18]:
quantized_model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
              (key): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
              (value): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
              (dropout): Dropout(p=0.1, inplace=False)
            )
      

In [19]:
model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [20]:
quantized_model.bert.encoder.layer[0].attention.self.query

DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)

### layer fusion

In [21]:
# Random float32 input for the attention experiments below, uniform in [0, 1).
# Presumably laid out as (batch=16, seq_len=512, hidden=768) — 768 matches the
# in_features of the attention Linear layers.
# A locally seeded generator keeps the tensor (and every printed value derived
# from it) identical across Restart & Run All without touching global RNG state.
x = torch.from_numpy(np.random.default_rng(0).random((16, 512, 768))).float()

In [22]:
model.bert.encoder.layer[0].attention

BertAttention(
  (self): BertSelfAttention(
    (query): Linear(in_features=768, out_features=768, bias=True)
    (key): Linear(in_features=768, out_features=768, bias=True)
    (value): Linear(in_features=768, out_features=768, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (output): BertSelfOutput(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [23]:
# Bound method of layer 0's self-attention that splits the hidden dimension into
# heads; for the tensors used below it maps (16, 512, 768) -> (16, 12, 512, 64).
transpose_fn = model.bert.encoder.layer[0].attention.self.transpose_for_scores

In [24]:
# Step through layer 0's self-attention by hand on x, using the module's own
# query/key/value Linear layers. The globals defined here (mixed_query_layer,
# attention_probs, ...) are inspected again by later cells.
mixed_query_layer = model.bert.encoder.layer[0].attention.self.query(x)
query_layer = model.bert.encoder.layer[0].attention.self.transpose_for_scores(mixed_query_layer)
mixed_key_layer = model.bert.encoder.layer[0].attention.self.key(x)
key_layer = model.bert.encoder.layer[0].attention.self.transpose_for_scores(mixed_key_layer)
mixed_value_layer = model.bert.encoder.layer[0].attention.self.value(x)
value_layer = model.bert.encoder.layer[0].attention.self.transpose_for_scores(mixed_value_layer)
print(mixed_query_layer.shape, mixed_value_layer.shape, mixed_key_layer.shape)

# Raw dot-product scores per head, softmax-normalized over the key positions.
# Note: no scaling factor and no attention mask are applied in this walkthrough.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
print(query_layer.shape, key_layer.transpose(-1, -2).shape, value_layer.shape)

# Weighted sum of the values, then move the head axis next to the per-head axis.
context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

# Merge the last two axes (heads, per-head size) back into one axis of width 768.
new_context_layer_shape = context_layer.size()[:-2] + (768,)
context_layer = context_layer.view(new_context_layer_shape)

context_layer.shape

torch.Size([16, 512, 768]) torch.Size([16, 512, 768]) torch.Size([16, 512, 768])
torch.Size([16, 12, 512, 64]) torch.Size([16, 12, 64, 512]) torch.Size([16, 12, 512, 64])


torch.Size([16, 512, 768])

In [97]:
def bert_layer_fusion(attention_module, T, hidden_states=None):
    """Recompute a self-attention block by hand from its raw Q/K/V weights.

    attention_module: object exposing ``.self.query``, ``.self.key`` and
        ``.self.value``, each with ``weight`` and ``bias`` tensors.
    T: callable that splits the last dimension of a (batch, seq, hidden) tensor
        into heads, returning (batch, heads, seq, head_size).
    hidden_states: input tensor of shape (batch, seq, hidden). When omitted,
        the notebook-level tensor ``x`` is used, which keeps existing two-argument
        calls working.

    Prints two debug rows that should match: the last head's scores for the first
    batch element and first query position, computed once by slicing the
    un-split projections and once from the head-split tensors.

    Returns the context tensor of shape (batch, seq, heads * head_size).

    NOTE(review): scores are used unscaled and no attention mask is applied —
    confirm this is intended before comparing against a full attention forward pass.
    """
    if hidden_states is None:
        # Backward-compatible fallback to the module-level input tensor.
        hidden_states = x

    self_attn = attention_module.self
    wq, bq = self_attn.query.weight, self_attn.query.bias
    wk, bk = self_attn.key.weight, self_attn.key.bias
    wv, bv = self_attn.value.weight, self_attn.value.bias

    # Project once and reuse the results for both the debug check and the heads.
    mixed_query_layer = torch.matmul(hidden_states, wq.T) + bq
    mixed_key_layer = torch.matmul(hidden_states, wk.T) + bk
    mixed_value_layer = torch.matmul(hidden_states, wv.T) + bv

    query_layer = T(mixed_query_layer)
    key_layer = T(mixed_key_layer)
    value_layer = T(mixed_value_layer)

    # Per-head width is taken from the split tensor instead of being hard-coded.
    head_size = query_layer.shape[-1]

    # Debug check: scores of the last head, obtained by slicing the trailing
    # head_size columns of the un-split projections.
    last_head_q = mixed_query_layer[:, :, -head_size:]
    last_head_k = mixed_key_layer[:, :, -head_size:]
    o = torch.matmul(last_head_q, last_head_k.transpose(-1, -2))
    print(o[0][0])

    # attention_scores, attention_probs -> (batch, heads, seq, seq)
    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_probs = nn.functional.softmax(attention_scores, dim=-1)
    print(attention_scores[:,-1,:,:][0][0])

    context_layer = torch.matmul(attention_probs, value_layer)
    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()

    # Merge (heads, head_size) back into a single trailing axis.
    merged_width = context_layer.size(-2) * context_layer.size(-1)
    new_context_layer_shape = context_layer.size()[:-2] + (merged_width,)
    context_layer = context_layer.view(new_context_layer_shape)

    return context_layer

In [98]:
o = bert_layer_fusion(model.bert.encoder.layer[0].attention, transpose_fn)

tensor([-1.9015e+00,  2.5915e+00,  4.4448e+00, -2.3938e+00,  1.2839e+00,
         2.8127e+00,  1.4193e+00,  1.9148e-02,  2.6554e+00, -7.0722e-01,
         1.6721e+00,  7.4494e-02,  5.8034e+00, -1.1887e+00,  8.1230e-01,
         1.2581e+00, -3.7323e-01,  4.3691e+00,  1.2334e+00,  1.4726e+00,
        -2.7465e+00,  6.5842e-01,  1.9186e+00, -3.3944e+00,  9.1420e-01,
         3.0800e+00, -2.6303e+00,  1.6961e+00, -1.7202e+00, -6.3358e-01,
        -2.1531e+00, -1.3558e+00, -1.7010e-01,  3.5539e+00,  3.3828e+00,
         2.7084e+00, -2.8276e-01, -2.4939e-01,  5.2201e-01, -1.6801e+00,
         5.6782e-01,  2.4558e+00, -5.3414e-02,  2.8917e+00, -3.2387e-01,
         2.6608e+00, -1.1062e-01, -7.4396e-01,  2.6550e+00,  3.9978e+00,
        -6.5622e-01,  1.6588e-01,  1.3661e-01, -1.0841e+00,  3.1562e+00,
         4.4794e-01,  2.6865e+00,  1.9943e+00,  4.8848e+00,  1.7926e+00,
        -4.0432e+00,  7.0576e-01,  5.1248e-01,  3.7472e+00,  2.3168e+00,
        -8.2729e-01, -1.0973e+00,  1.7524e+00,  2.4

In [27]:
# Raw weight and bias of layer 0's query projection; the cell displays their shapes.
wq = model.bert.encoder.layer[0].attention.self.query.weight
bq = model.bert.encoder.layer[0].attention.self.query.bias
wq.shape, bq.shape

(torch.Size([768, 768]), torch.Size([768]))

In [28]:
# Raw weight and bias of layer 0's key projection; the cell displays their shapes.
wk = model.bert.encoder.layer[0].attention.self.key.weight
bk = model.bert.encoder.layer[0].attention.self.key.bias
wk.shape, bk.shape

(torch.Size([768, 768]), torch.Size([768]))

In [29]:
# Raw weight and bias of layer 0's value projection; the cell displays their shapes.
wv = model.bert.encoder.layer[0].attention.self.value.weight
bv = model.bert.encoder.layer[0].attention.self.value.bias
wv.shape, bv.shape

(torch.Size([768, 768]), torch.Size([768]))

In [30]:
mixed_query_layer[0]

tensor([[ 1.4820, -0.4495, -0.0934,  ...,  0.3367, -0.3739, -0.6107],
        [ 0.9407, -0.3381, -0.4547,  ...,  0.2770, -0.3883,  0.5439],
        [ 0.9656, -0.5552, -0.3678,  ...,  0.4237,  0.5679, -0.4176],
        ...,
        [ 0.9892,  0.2435, -0.2735,  ...,  0.4637, -0.9741,  0.5602],
        [ 0.9645, -0.8999, -0.3634,  ...,  0.7425, -0.7519,  0.0691],
        [ 1.2305, -0.4847, -0.1822,  ...,  0.4492,  0.2554,  0.2505]],
       grad_fn=<SelectBackward0>)

In [31]:
# mixed_query_layer -> torch.matmul(x, wq.T) + bq
# mixed_key_layer -> torch.matmul(x, wk.T) + bk
# mixed_value_layer -> torch.matmul(x, wv.T) + bv

In [32]:
# Hand-rolled Q/K/V projections (x @ W^T + b) from the raw weights, kept under
# underscore-prefixed names so they can be compared with the mixed_*_layer
# tensors produced earlier through the Linear modules.
_mixed_query_layer = torch.matmul(x, wq.T) + bq
_mixed_key_layer = torch.matmul(x, wk.T) + bk
_mixed_value_layer = torch.matmul(x, wv.T) + bv

In [33]:
# Split each hand-rolled projection into heads with the module's own helper.
_query_layer = transpose_fn(_mixed_query_layer)
_key_layer = transpose_fn(_mixed_key_layer)
_value_layer = transpose_fn(_mixed_value_layer)

In [34]:
# Unscaled dot-product scores per head, softmax-normalized over the last axis;
# the next cells display these next to attention_probs from the module-based path.
_attention_scores = torch.matmul(_query_layer, _key_layer.transpose(-1, -2))
_attention_probs = nn.functional.softmax(_attention_scores, dim=-1)

In [35]:
_attention_probs[0][0]

tensor([[1.3736e-03, 4.9345e-03, 5.3059e-04,  ..., 5.7736e-06, 3.6539e-03,
         5.1676e-04],
        [4.5039e-03, 3.3418e-04, 1.6283e-04,  ..., 2.8543e-06, 3.2148e-04,
         1.6461e-04],
        [1.5002e-02, 2.9829e-03, 2.2461e-04,  ..., 5.4842e-06, 6.1736e-04,
         3.3424e-04],
        ...,
        [4.5424e-03, 3.2775e-03, 2.4234e-04,  ..., 1.5969e-05, 2.7897e-03,
         4.0480e-04],
        [5.9863e-03, 1.9494e-03, 9.0838e-04,  ..., 4.4690e-05, 8.4949e-04,
         4.9135e-04],
        [1.1367e-03, 1.3965e-03, 5.6448e-04,  ..., 7.6649e-05, 2.4117e-03,
         1.4031e-04]], grad_fn=<SelectBackward0>)

In [36]:
attention_probs[0][0]

tensor([[1.3736e-03, 4.9345e-03, 5.3059e-04,  ..., 5.7736e-06, 3.6539e-03,
         5.1676e-04],
        [4.5039e-03, 3.3418e-04, 1.6283e-04,  ..., 2.8543e-06, 3.2148e-04,
         1.6461e-04],
        [1.5002e-02, 2.9829e-03, 2.2461e-04,  ..., 5.4842e-06, 6.1736e-04,
         3.3424e-04],
        ...,
        [4.5424e-03, 3.2775e-03, 2.4234e-04,  ..., 1.5969e-05, 2.7897e-03,
         4.0480e-04],
        [5.9863e-03, 1.9494e-03, 9.0838e-04,  ..., 4.4690e-05, 8.4949e-04,
         4.9135e-04],
        [1.1367e-03, 1.3965e-03, 5.6448e-04,  ..., 7.6649e-05, 2.4117e-03,
         1.4031e-04]], grad_fn=<SelectBackward0>)