In [3]:
import copy
import matplotlib.pyplot as plt
import numpy as np
import random
import time
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.gpt2.modeling_gpt2 import GPT2Model

In [4]:
def generate_token_with_past(inputs):
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    last_logits = logits[0, -1, :]
    next_token_id = last_logits.argmax()
    return next_token_id, outputs.past_key_values

def generate(inputs, max_tokens):
    generated_tokens = []
    next_inputs = inputs
    for _ in range(max_tokens):
        next_token_id, past_key_values = generate_token_with_past(next_inputs)
    
        next_inputs = { 
            "input_ids": next_token_id.reshape((1, 1)),
            "attention_mask": torch.cat(
                [next_inputs["attention_mask"], torch.tensor([[1]])],
                dim=1),
            "past_key_values": past_key_values,
            }
        next_token = tokenizer.decode(next_token_id)
        generated_tokens.append(next_token)
    return "".join(generated_tokens)

In [5]:
model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

In [6]:
# Define PAD Token = EOS Token = 50256
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

# Pad on the left so we can append new tokens on the right
tokenizer.padding_side = "left"
tokenizer.truncation_siede = "left"

In [7]:
# fit dtype post quantization to "pretend" to be fp32
def get_float32_dtype(self):
    return torch.float32
print(GPT2Model.dtype)
GPT2Model.dtype = property(get_float32_dtype)
print(GPT2Model.dtype)

<property object at 0x0000020DCC466F20>
<property object at 0x0000020DB1543330>


In [8]:
model.get_memory_footprint()

510342192

# Qauntize and Dequantize Basic

In [9]:
def quantize(t):
    # obtain range of values in the tensor to map between 0 and 255
    min_val, max_val = t.min(), t.max()

    # determine the "zero-point" or value in the tensor to map to 0
    scale = (max_val - min_val)/255
    zero_point = min_val

    # quantize and clamp to ensure we are in [0, 255]
    t_quant = (t - zero_point) / scale
    t_quant = torch.clamp(t_quant, min=0, max=255)

    # keep track of scale and zero_point for reversing quantization
    state = (scale, zero_point)

    # cast to uint8 and return 
    t_quant = t_quant.type(torch.uint8)
    return t_quant, state

In [10]:
# original t
t = model.transformer.h[0].attn.c_attn.weight.data
t, t.shape

(tensor([[-0.4738, -0.2614, -0.0978,  ...,  0.0513, -0.0584,  0.0250],
         [ 0.0874,  0.1473,  0.2387,  ..., -0.0525, -0.0113, -0.0156],
         [ 0.0039,  0.0695,  0.3668,  ...,  0.1143,  0.0363, -0.0318],
         ...,
         [-0.2592, -0.0164,  0.1991,  ...,  0.0095, -0.0516,  0.0319],
         [ 0.1517,  0.2170,  0.1043,  ...,  0.0293, -0.0429, -0.0475],
         [-0.4100, -0.1924, -0.2400,  ..., -0.0046,  0.0070,  0.0198]]),
 torch.Size([768, 2304]))

In [11]:
# quantize t
t_q, state = quantize(t)
t_q, t_q.min(), t_q.max(), state

(tensor([[107, 116, 124,  ..., 130, 125, 129],
         [132, 135, 139,  ..., 126, 128, 127],
         [128, 131, 145,  ..., 133, 130, 127],
         ...,
         [116, 127, 137,  ..., 129, 126, 130],
         [135, 138, 133,  ..., 129, 126, 126],
         [110, 119, 117,  ..., 128, 128, 129]], dtype=torch.uint8),
 tensor(0, dtype=torch.uint8),
 tensor(255, dtype=torch.uint8),
 (tensor(0.0221), tensor(-2.8436)))

In [12]:
def dequantize(t, state):
    scale, zero_point = state
    return t.to(torch.float32) * scale + zero_point

In [13]:
# dequantize t  (퀀타이즈 전과 완전히 동일하지는 않다.)
t_rev = dequantize(t_q, state)
t_rev

tensor([[-0.4774, -0.2783, -0.1014,  ...,  0.0313, -0.0793,  0.0092],
        [ 0.0755,  0.1419,  0.2303,  ..., -0.0572, -0.0129, -0.0351],
        [-0.0129,  0.0534,  0.3630,  ...,  0.0976,  0.0313, -0.0351],
        ...,
        [-0.2783, -0.0351,  0.1861,  ...,  0.0092, -0.0572,  0.0313],
        [ 0.1419,  0.2082,  0.0976,  ...,  0.0092, -0.0572, -0.0572],
        [-0.4110, -0.2120, -0.2562,  ..., -0.0129, -0.0129,  0.0092]])

In [14]:
# show errors (퀀타이즈 전 후 값의 차이)
torch.abs(t - t_rev)

tensor([[0.0035, 0.0170, 0.0036,  ..., 0.0200, 0.0209, 0.0158],
        [0.0119, 0.0055, 0.0084,  ..., 0.0046, 0.0017, 0.0195],
        [0.0168, 0.0161, 0.0038,  ..., 0.0167, 0.0050, 0.0032],
        ...,
        [0.0191, 0.0187, 0.0131,  ..., 0.0004, 0.0056, 0.0006],
        [0.0098, 0.0088, 0.0067,  ..., 0.0202, 0.0143, 0.0097],
        [0.0010, 0.0196, 0.0162,  ..., 0.0084, 0.0199, 0.0107]])

# Apply Quantization to Model

In [46]:
bprompt = "The quick brown fox jumped over the"
inputs = tokenizer(prompt, return_tensors='pt')

def generate_token_with_past(inputs):
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    last_logits = logits[0, -1, :]
    next_token_id = last_logits.argmax()
    return next_token_id, outputs.past_key_values

def generate(inputs, max_tokens):
    generated_tokens = []
    next_inputs = inputs
    for _ in range(max_tokens):
        next_token_id, past_key_values = generate_token_with_past(next_inputs)
    
        next_inputs = { 
            "input_ids": next_token_id.reshape((1, 1)),
            "attention_mask": torch.cat(
                [next_inputs["attention_mask"], torch.tensor([[1]])],
                dim=1),
            "past_key_values": past_key_values,
            }
        next_token = tokenizer.decode(next_token_id)
        generated_tokens.append(next_token)
    return "".join(generated_tokens)

In [47]:
def full_generate(model, tokenizer, inputs):
    tokend_inputs = tokenizer(inputs, return_tensors='pt')
    tokens = generate(tokend_inputs, max_tokens=10)
    result = ''.join([inputs, tokens])
    return result

In [48]:
inputs = "The qick brown fox jumped over the"
response_expected = full_generate(model, tokenizer, inputs)
response_expected

'The qick brown fox jumped over the fence and ran to the other side of the fence'

In [51]:
def quntize_model(model):
    states = {}
    for name, param in model.named_parameters():
        param.requires_grad = False
        param.data, state = quantize(param.data)
        states[name] = state
    return model, states

In [52]:
q_model, states = quntize_model(model)

In [53]:
q_model.get_memory_footprint()

137022768

In [56]:
# 기여도 측정
def size_in_bytes(t):
    return t.numel() * t.element_size()

In [55]:
sum([
    size_in_bytes(v[0]) + size_in_bytes(v[1])
    for v in states.values()
])

1181

In [57]:
def dequntize_model(model):
    for name, param in model.named_parameters():
        state = states[name]
        param.data = dequantize(param.data, state)
    return model

In [58]:
deq_model = dequntize_model(q_model)

In [59]:
deq_model.get_memory_footprint()

510342192

In [64]:
# 한번 퀀타이즈후 디퀀타이즈해도.. 모델 성능에는 데미지가 있다.
inputs = "The qick brown fox jumped over the"
response_expected = full_generate(deq_model, tokenizer, inputs)
response_expected

'The qick brown fox jumped over the same same same same same same same same same same'