In [1]:


import time

import torch
import torch.nn as nn

from quant import *
from sparsegpt import *
from modelutils import *
from datautils import *
from transformers import OPTForCausalLM
from dataclasses import dataclass




In [2]:
@dataclass
class Args(object):
    nsamples: int = 128
    sparsity = 0.3
    prunen: int = 0
    prunem: int = 0
    percdamp = .01
    blocksize: int = 128
    batch_size: int = 64
    num_layers: int = 5
    input_size: int = 784
    output_size: int = 10
    minlayer: int = -1
    maxlayer: int = 1000
    prune_only: str = ""
    invert: bool = True
    dataset: str = "ptb"
    seed: int = 0
    
args = Args()

In [3]:

model = "facebook/opt-125m"
model = OPTForCausalLM.from_pretrained(model, torch_dtype='auto')
model.seqlen = model.config.max_position_embeddings

In [4]:
model.model.decoder.layers

ModuleList(
  (0-11): 12 x OPTDecoderLayer(
    (self_attn): OPTAttention(
      (k_proj): Linear(in_features=768, out_features=768, bias=True)
      (v_proj): Linear(in_features=768, out_features=768, bias=True)
      (q_proj): Linear(in_features=768, out_features=768, bias=True)
      (out_proj): Linear(in_features=768, out_features=768, bias=True)
    )
    (activation_fn): ReLU()
    (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (fc1): Linear(in_features=768, out_features=3072, bias=True)
    (fc2): Linear(in_features=3072, out_features=768, bias=True)
    (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
)

In [None]:
class FCBlock(nn.Module):
    def __init__(self, input_size = 28*28, output_size = 28*28):
        super(FCBlock, self).__init__()
        self.fc1 = nn.Linear(input_size, output_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(output_size, output_size)
        self.relu2 = nn.ReLU()
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        return x

class MLP(nn.Module):
    def __init__(self, input_size=28*28, output_size=10, num_blocks = 4):
        super(MLP, self).__init__()
        
        self.input_size = input_size
        self.output_size = output_size

        self.layers = [FCBlock() for i in range(num_blocks)]
        out = nn.Linear(input_size, output_size)
        self.layers.append(out)
        self.layers = nn.ModuleList(self.layers)
        
    def forward(self, x):
        x = x.view(-1, self.input_size)
        for layer in self.layers[:-1]:
            x = layer(x)
        return self.layers[-1](x)


In [5]:

@torch.no_grad()
def opt_sequential(model, dataloader, dev):
    print('Starting ...')

    use_cache = model.config.use_cache
    model.config.use_cache = False
    layers = model.model.decoder.layers
    print("layers: ", layers)
    model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(dev) 
    model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(dev)
    if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
        model.model.decoder.project_out = model.model.decoder.project_out.to(dev) 
    if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
        model.model.decoder.project_in = model.model.decoder.project_in.to(dev) 
    print("layers[0]: ", layers[0])
    layers[0] = layers[0].to(dev)
    print("model.seqlen: ", model.seqlen)
    dtype = next(iter(model.parameters())).dtype
    inps = torch.zeros(
        (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev
    )
    print(f"hidden size: {model.config.hidden_size}")
    cache = {'i': 0, 'attention_mask': None}

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module
        def forward(self, inp, **kwargs):
            inps[cache['i']] = inp
            print(f"cache['i']: {cache['i']}")
            print(f"inps[cache['i']]: {inps[cache['i']]}")
            cache['i'] += 1

            cache['attention_mask'] = kwargs['attention_mask']
            raise ValueError
    layers[0] = Catcher(layers[0])
    print("Catcher(layers[0]): ", layers[0])
    for batch in dataloader:
        try:
            print(batch[0])
            model(batch[0].to(dev))
        except ValueError:
            pass
    layers[0] = layers[0].module
    print("layers[0].module: ", layers[0])
    layers[0] = layers[0].cpu()
    model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.cpu()
    model.model.decoder.embed_positions = model.model.decoder.embed_positions.cpu()
    if hasattr(model.model.decoder, 'project_out') and model.model.decoder.project_out:
        model.model.decoder.project_out = model.model.decoder.project_out.cpu()
    if hasattr(model.model.decoder, 'project_in') and model.model.decoder.project_in:
        model.model.decoder.project_in = model.model.decoder.project_in.cpu()
    torch.cuda.empty_cache()

    outs = torch.zeros_like(inps)
    attention_mask = cache['attention_mask']

    print('Ready.')

    for i in range(len(layers)):
        layer = layers[i].to(dev)
        print(f"layer[{i}]: ", layer)
        subset = find_layers(layer)
        
        gpts = {}
        for name in subset:
            if (not (args.minlayer <= i < args.maxlayer and args.prune_only in name)) == (not args.invert):
              continue
            gpts[name] = SparseGPT(subset[name])
            print(name, subset[name])
            if args.wbits < 16:
                gpts[name].quantizer = Quantizer()
                gpts[name].quantizer.configure(
                    args.wbits, perchannel=True, sym=False, mse=False
                )

        def add_batch(name):
            def tmp(_, inp, out):
                gpts[name].add_batch(inp[0].data, out.data)

            return tmp
        handles = []
        for name in gpts:
            handles.append(subset[name].register_forward_hook(add_batch(name)))
        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]
        for h in handles:
            h.remove()

        for name in gpts:
            print(i, name)
            print('Pruning ...')
            sparsity = args.sparsity
            gpts[name].fasterprune(
                sparsity, prunen=args.prunen, prunem=args.prunem, percdamp=args.percdamp, blocksize=args.blocksize
            )
            gpts[name].free()

        for j in range(args.nsamples):
            outs[j] = layer(inps[j].unsqueeze(0), attention_mask=attention_mask)[0]

        layers[i] = layer.cpu()
        del layer
        torch.cuda.empty_cache()

        inps, outs = outs, inps

    model.config.use_cache = use_cache

In [6]:
tokenizer = get_tokenizer("facebook/opt-125m")

In [7]:
dataloader, testloader = get_ptb(args.nsamples, args.seed, model.config.max_position_embeddings, "facebook/opt-125m", tokenizer)

Reusing dataset ptb_text_only (C:\Users\igor-\.cache\huggingface\datasets\ptb_text_only\penn_treebank\1.1.0\8d1b97746fb9765d140e569ec5ddd35e20af4d37761f5e1bf357ea0b081f2c1f)
Reusing dataset ptb_text_only (C:\Users\igor-\.cache\huggingface\datasets\ptb_text_only\penn_treebank\1.1.0\8d1b97746fb9765d140e569ec5ddd35e20af4d37761f5e1bf357ea0b081f2c1f)


In [8]:
opt_sequential(model, dataloader, torch.device('cuda:0'))

Starting ...
layers:  ModuleList(
  (0-11): 12 x OPTDecoderLayer(
    (self_attn): OPTAttention(
      (k_proj): Linear(in_features=768, out_features=768, bias=True)
      (v_proj): Linear(in_features=768, out_features=768, bias=True)
      (q_proj): Linear(in_features=768, out_features=768, bias=True)
      (out_proj): Linear(in_features=768, out_features=768, bias=True)
    )
    (activation_fn): ReLU()
    (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (fc1): Linear(in_features=768, out_features=3072, bias=True)
    (fc2): Linear(in_features=3072, out_features=768, bias=True)
    (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
)
layers[0]:  OPTDecoderLayer(
  (self_attn): OPTAttention(
    (k_proj): Linear(in_features=768, out_features=768, bias=True)
    (v_proj): Linear(in_features=768, out_features=768, bias=True)
    (q_proj): Linear(in_features=768, out_features=768, bias=True)
    (out_proj): Linear(in_features=7

In [9]:
opt_sequential(model, dataloader, torch.device('cuda:0'))

Starting ...
layers:  ModuleList(
  (0-11): 12 x OPTDecoderLayer(
    (self_attn): OPTAttention(
      (k_proj): Linear(in_features=768, out_features=768, bias=True)
      (v_proj): Linear(in_features=768, out_features=768, bias=True)
      (q_proj): Linear(in_features=768, out_features=768, bias=True)
      (out_proj): Linear(in_features=768, out_features=768, bias=True)
    )
    (activation_fn): ReLU()
    (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (fc1): Linear(in_features=768, out_features=3072, bias=True)
    (fc2): Linear(in_features=3072, out_features=768, bias=True)
    (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
)
layers[0]:  OPTDecoderLayer(
  (self_attn): OPTAttention(
    (k_proj): Linear(in_features=768, out_features=768, bias=True)
    (v_proj): Linear(in_features=768, out_features=768, bias=True)
    (q_proj): Linear(in_features=768, out_features=768, bias=True)
    (out_proj): Linear(in_features=7