In [20]:
import torch
from transformers import OPTForCausalLM

In [21]:
model = OPTForCausalLM.from_pretrained('facebook/opt-350m', torch_dtype='auto')
model.eval()
device = torch.device('cuda:0')

In [22]:
layers = model.model.decoder.layers

In [23]:
import torch.nn as nn

def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
    if type(module) in layers:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        res.update(find_layers(
            child, layers=layers, name=name + '.' + name1 if name != '' else name1
        ))
    return res


In [24]:
layer = layers[0].to(device)
subset = find_layers(layer)

In [25]:
def quantize(x, scale, zero, maxq):
    if maxq < 0:
        return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
    q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
    return scale * (q - zero)

class RTN(nn.Module):

    def __init__(self, shape=1):
        super(RTN, self).__init__()
        self.register_buffer('maxq', torch.tensor(0))
        self.register_buffer('scale', torch.zeros(shape))
        self.register_buffer('zero', torch.zeros(shape))

    def configure(
        self,
        bits, perchannel=False, sym=True, 
        mse=False, norm=2.4, grid=100, maxshrink=.8,
        trits=False
    ):
        self.maxq = torch.tensor(2 ** bits - 1)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink 
        if trits:
            self.maxq = torch.tensor(-1) 

    def find_params(self, x, weight=False):
        dev = x.device
        self.maxq = self.maxq.to(dev)

        shape = x.shape
        if self.perchannel:
            if weight:
                x = x.flatten(1)
            else:
                if len(shape) == 4:
                    x = x.permute([1, 0, 2, 3])
                    x = x.flatten(1)
                if len(shape) == 3:
                    x = x.reshape((-1, shape[-1])).t()
                if len(shape) == 2:
                    x = x.t()
        else:
            x = x.flatten().unsqueeze(0)

        tmp = torch.zeros(x.shape[0], device=dev)
        xmin = torch.minimum(x.min(1)[0], tmp)
        xmax = torch.maximum(x.max(1)[0], tmp)

        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax)
            tmp = xmin < 0
            if torch.any(tmp):
                xmin[tmp] = -xmax[tmp]
        tmp = (xmin == 0) & (xmax == 0)
        xmin[tmp] = -1
        xmax[tmp] = +1

        if self.maxq < 0:
          self.scale = xmax
          self.zero = xmin
        else:
          self.scale = (xmax - xmin) / self.maxq
          if self.sym:
              self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
          else:
              self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            best = torch.full([x.shape[0]], float('inf'), device=dev)
            for i in range(int(self.maxshrink * self.grid)):
                p = 1 - i / self.grid 
                xmin1 = p * xmin
                xmax1 = p * xmax
                scale1 = (xmax1 - xmin1) / self.maxq
                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
                q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
                q -= x
                q.abs_()
                q.pow_(self.norm)
                err = torch.sum(q, 1)
                tmp = err < best
                if torch.any(tmp):
                    best[tmp] = err[tmp]
                    self.scale[tmp] = scale1[tmp]
                    self.zero[tmp] = zero1[tmp]
        if not self.perchannel:
            if weight:
                tmp = shape[0]
            else:
                tmp = shape[1] if len(shape) != 3 else shape[2]
            self.scale = self.scale.repeat(tmp)
            self.zero = self.zero.repeat(tmp)

        if weight:
            shape = [-1] + [1] * (len(shape) - 1)
            self.scale = self.scale.reshape(shape)
            self.zero = self.zero.reshape(shape)
            return
        if len(shape) == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        if len(shape) == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1)) 
        if len(shape) == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        if self.ready():
            return quantize(x, self.scale, self.zero, self.maxq)
        return x

    def enabled(self):
        return self.maxq > 0

    def ready(self):
        return torch.all(self.scale != 0)

In [26]:
quantizer = RTN()
quantizer.configure(4, perchannel=True, sym=False, mse=False)


In [27]:
quantizer.maxq

tensor(15)

In [32]:
subset

{'self_attn.k_proj': Linear(in_features=1024, out_features=1024, bias=True),
 'self_attn.v_proj': Linear(in_features=1024, out_features=1024, bias=True),
 'self_attn.q_proj': Linear(in_features=1024, out_features=1024, bias=True),
 'self_attn.out_proj': Linear(in_features=1024, out_features=1024, bias=True),
 'fc1': Linear(in_features=1024, out_features=4096, bias=True),
 'fc2': Linear(in_features=4096, out_features=1024, bias=True)}

In [33]:
W = subset['fc2'].weight.data

In [34]:
W.shape

torch.Size([1024, 4096])

In [35]:
x = W
x = x.flatten(1)

In [36]:
x.shape

torch.Size([1024, 4096])

In [37]:
tmp = torch.zeros(x.shape[0], device=device)
xmin = torch.minimum(x.min(1)[0], tmp)
xmax = torch.maximum(x.max(1)[0], tmp)

In [38]:
tmp = (xmin == 0) & (xmax == 0)
xmin[tmp] = -1
xmax[tmp] = +1

In [42]:
scale = (xmax - xmin) / quantizer.maxq

In [43]:
zero = torch.round(-xmin / scale)

In [44]:
scale = scale.unsqueeze(0)
zero = zero.unsqueeze(0)

In [None]:
subset['fc2'].weight.data = quantize(x.t(), scale, zero, quantizer.maxq)

In [47]:
x_q = quantize(x.t(), scale, zero, quantizer.maxq)

In [54]:
x_q.dtype

torch.float32

In [53]:
next(iter(layer.parameters())).dtype

torch.float16

tensor(-0.1438, device='cuda:0')

In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [80]:
testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m', use_fast=False)
testloader = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')

Found cached dataset wikitext (/home/youpengzhao/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


In [85]:
testloader['input_ids'].shape

torch.Size([1, 287645])

In [87]:
testloader['input_ids'].numel()

287645

In [79]:
len(test_set)

2176

In [57]:
for i in range(len(layers)):
    print('Quantizing layer {}'.format(i))
    layer = layers[i].to(device)
    subset = find_layers(layer)
    for name in subset:
        quantizer = RTN()
        quantizer.configure(
            4, perchannel=True, sym=False, mse=False
        )
        W = subset[name].weight.data
        quantizer.find_params(W, weight=True)
        subset[name].weight.data = quantize(W, 
                                            quantizer.scale, quantizer.zero, quantizer.maxq
                                            ).to(next(iter(layer.parameters())).dtype)
    layers[i] = layer.cpu()
    del layer
    torch.cuda.empty_cache()

Quantizing layer 1
Quantizing layer 2
Quantizing layer 3
Quantizing layer 4
Quantizing layer 5
Quantizing layer 6
Quantizing layer 7
Quantizing layer 8
Quantizing layer 9
Quantizing layer 10
Quantizing layer 11
Quantizing layer 12
Quantizing layer 13
Quantizing layer 14
Quantizing layer 15
Quantizing layer 16
Quantizing layer 17
Quantizing layer 18
Quantizing layer 19
Quantizing layer 20
Quantizing layer 21
Quantizing layer 22
Quantizing layer 23
Quantizing layer 24


In [60]:
model.model.decoder.layers = layers

In [61]:
torch.save(model.state_dict(), 'opt_350m_4_bit.pt') 

In [88]:
num_samples = testloader['input_ids'].numel() // 2048

In [89]:
num_samples

140

In [90]:
model = model.to(device)

In [91]:
losses = []
with torch.no_grad():
    for i in range(num_samples):
        outputs = model(input_ids=testloader['input_ids'][:, (i * 2048):((i + 1) * 2048)].to(device), 
                        labels=testloader['input_ids'][:, (i * 2048):((i + 1) * 2048)].to(device),
                        attention_mask = testloader["attention_mask"][:, (i * 2048):((i + 1) * 2048)].to(device).to(device))
        losses.append(outputs[0])
    loss = torch.mean(torch.stack(losses))
    perplexity = torch.exp(loss)

print(perplexity)


tensor(25.9375, device='cuda:0', dtype=torch.float16)


In [92]:
torch.cuda.empty_cache()

In [93]:
model = OPTForCausalLM.from_pretrained('facebook/opt-350m', torch_dtype='auto')
model.eval()
device = torch.device('cuda:0')



In [95]:
model = model.to(device)

In [96]:
losses = []
with torch.no_grad():
    for i in range(num_samples):
        outputs = model(input_ids=testloader['input_ids'][:, (i * 2048):((i + 1) * 2048)].to(device), 
                        labels=testloader['input_ids'][:, (i * 2048):((i + 1) * 2048)].to(device),
                        attention_mask = testloader["attention_mask"][:, (i * 2048):((i + 1) * 2048)].to(device).to(device))
        losses.append(outputs[0])
    loss = torch.mean(torch.stack(losses))
    perplexity = torch.exp(loss)

print(perplexity)

tensor(22.0156, device='cuda:0', dtype=torch.float16)


In [98]:
from transformers import AutoConfig
config = AutoConfig.from_pretrained('facebook/opt-350m')
model = OPTForCausalLM(config)

In [99]:
ckpt = torch.load('opt_350m_4_bit.pt', map_location='cpu')


In [100]:
model.load_state_dict(ckpt)

<All keys matched successfully>

In [102]:
model = model.to(device)

In [103]:
losses = []
with torch.no_grad():
    for i in range(num_samples):
        outputs = model(input_ids=testloader['input_ids'][:, (i * 2048):((i + 1) * 2048)].to(device), 
                        labels=testloader['input_ids'][:, (i * 2048):((i + 1) * 2048)].to(device),
                        attention_mask = testloader["attention_mask"][:, (i * 2048):((i + 1) * 2048)].to(device).to(device))
        losses.append(outputs[0])
    loss = torch.mean(torch.stack(losses))
    perplexity = torch.exp(loss)

print(perplexity)

tensor(29.4278, device='cuda:0')


In [108]:
print('PPL: {:.2f}'.format(perplexity))

PPL: 29.43


In [1]:
from transformers import BloomForCausalLM
model = BloomForCausalLM.from_pretrained('bigscience/bloom-560M', torch_dtype='auto')

  from .autonotebook import tqdm as notebook_tqdm
Downloading (…)lve/main/config.json: 100%|██████████| 693/693 [00:00<00:00, 101kB/s]
Downloading pytorch_model.bin: 100%|██████████| 1.12G/1.12G [00:21<00:00, 52.1MB/s]


In [8]:
len(model.transformer.h)

24

In [10]:
import torch
from transformers import OPTForCausalLM

model = OPTForCausalLM.from_pretrained('facebook/opt-350m', torch_dtype='auto')
model.eval()
device = torch.device('cuda:0')

In [13]:
layers = model.model.decoder.layers


In [14]:
import torch.nn as nn 

def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
    if type(module) in layers:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        res.update(find_layers(
            child, layers=layers, name=name + '.' + name1 if name != '' else name1
        ))
    return res

layer = layers[0].to(device)
subset = find_layers(layer)

In [15]:
subset

{'self_attn.k_proj': Linear(in_features=1024, out_features=1024, bias=True),
 'self_attn.v_proj': Linear(in_features=1024, out_features=1024, bias=True),
 'self_attn.q_proj': Linear(in_features=1024, out_features=1024, bias=True),
 'self_attn.out_proj': Linear(in_features=1024, out_features=1024, bias=True),
 'fc1': Linear(in_features=1024, out_features=4096, bias=True),
 'fc2': Linear(in_features=4096, out_features=1024, bias=True)}

In [16]:
from quantization.gptq import GPTQ

name = 'fc2'
gptq = {}
gptq[name] = GPTQ(subset[name])

In [19]:
from quantization.rtn import RTN

gptq[name].quantizer = RTN()
gptq[name].quantizer.configure(
                4, perchannel=True, sym=False, mse=False, trits=False
            )

In [59]:
dtype = next(iter(model.parameters())).dtype
inps = torch.zeros(
        (128, 2048, model.config.hidden_size), dtype=dtype, device=device
    )

In [60]:
cache = {'i': 0, 'attention_mask': None}

class Catcher(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module
    def forward(self, inp, **kwargs):
        inps[cache['i']] = inp
        cache['i'] += 1
        cache['attention_mask'] = kwargs['attention_mask']
        raise ValueError
    


In [61]:
layers[0] = Catcher(layers[0])

In [58]:
from data import get_sample_data

trainloader = get_sample_data('wikitext2', 'facebook/opt-350m', 128, 2048)
model = model.to(device)

Found cached dataset wikitext (/home/youpengzhao/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


In [62]:
len(trainloader)

128

In [63]:
for batch in trainloader:
    try:
        model(batch[0].to(device))
    except ValueError:
        pass

In [69]:
batch

(tensor([[6225, 1003,    7,  ...,  645,    7, 5709]]),
 tensor([[-100, -100, -100,  ..., -100, -100, 5709]]))

In [28]:
layers[0] = layers[0].module

In [29]:
layers[0]

OPTDecoderLayer(
  (self_attn): OPTAttention(
    (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
    (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
    (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
    (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
  )
  (activation_fn): ReLU()
  (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  (fc1): Linear(in_features=1024, out_features=4096, bias=True)
  (fc2): Linear(in_features=4096, out_features=1024, bias=True)
  (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)

In [30]:
cache

{'i': 0, 'attention_mask': None}

In [31]:
outs = torch.zeros_like(inps)

In [35]:
def add_batch(name):
    def tmp(_, inp, out):
        gptq[name].add_batch(inp[0].data, out.data)
    return tmp

In [37]:
handles = []