Sample pruning experiment that demonstrates the memory savings from pruning feed-forward weights.
It uses random weights rather than actual model weights, so it is for demonstration only. See `prune.py` for the script that prunes the Llama feed-forward (FF) layers.

In [None]:
import struct
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

from model import ModelArgs
from convert import serialize_fp32
from prune import serialize_bitvec

In [None]:
def serialize_pruned_tensor(tensor, filename, bitvec=False):
    """Write a pruned tensor to ``filename`` in a sparse on-disk layout.

    File layout, in order:

    1. the total element count, packed with ``struct`` format ``"I"``
       (native unsigned int);
    2. a nonzero mask holding ONE BYTE per element (nonzero where the
       flattened tensor is nonzero) -- N bytes for N elements, not a packed
       bit vector;
    3. zero padding that rounds the mask length up to a multiple of 8 bytes;
    4. only the nonzero values, in flattened order, written by
       ``serialize_fp32`` (imported from ``convert``; presumably 4 bytes per
       value -- confirm there).

    Args:
        tensor: torch tensor to serialize; it is flattened before writing.
        filename: output path of the main file.
        bitvec: when true, additionally hand the boolean mask to
            ``serialize_bitvec`` (imported from ``prune``), writing to
            ``"bin_" + filename``. The prefix is prepended to the whole
            string, so ``filename`` should be a bare file name in that case.
    """
    # Flatten and build the nonzero mask once; both are reused below.
    flat = tensor.flatten()
    mask = flat.bool()

    num_elements = tensor.numel()
    print("num elements", num_elements)
    mask_bytes = mask.cpu().numpy().tobytes()

    if bitvec:
        # Use a context manager so the side file is flushed and closed
        # instead of leaving the handle open.
        with open("bin_" + filename, "wb+") as bin_file:
            serialize_bitvec(bin_file, mask)

    num_nonzero = mask.sum().item()

    with open(filename, "wb+") as f:
        f.write(struct.pack("I", num_elements))
        f.write(mask_bytes)
        # -n % 8 is the number of bytes needed to reach the next multiple of 8.
        f.write(b"\x00" * (-len(mask_bytes) % 8))
        print("writing", num_nonzero, "nonzero values")
        serialize_fp32(f, flat[mask])

In [None]:
# Build a ModelArgs instance and override the size-related fields used by the cells below.
config = ModelArgs()

_overrides = {
    # transformer shape
    "dim": 4096,
    "n_layers": 32,
    "n_heads": 32,
    "hidden_dim": 11008,
    # vocabulary / context
    "vocab_size": 32000,
    "max_seq_len": 2048,
}
for _field, _value in _overrides.items():
    setattr(config, _field, _value)

In [None]:
# same dims as model.layers[i].feed_forward.w1 (or w2 or w3)
# nn.Linear stores its weight as (out_features, in_features), so the layer is
# built dim -> hidden_dim: its declared shape then agrees with the
# (hidden_dim, dim) random tensor assigned to it below.
ff1 = nn.Linear(config.dim, config.hidden_dim)
w = torch.randn(config.hidden_dim, config.dim)
ff1.weight.data = w
# commented out to avoid accidents, but if u want to test, uncomment
# serialize_pruned_tensor(ff1.weight.data, 'full.bin')

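full.bin should be 225.4 MB: 11008 × 4096 = 45,088,768 elements, each costing 1 mask byte plus 4 bytes for its fp32 value (nothing is pruned yet), plus a 4-byte header.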

In [None]:
# Clone the layer for later: an independent copy of the same random weights,
# used by the more aggressive pruning run further down.
# Built dim -> hidden_dim so the layer's (out_features, in_features) weight
# shape agrees with the (hidden_dim, dim) tensor it receives.
ff1_clone = nn.Linear(config.dim, config.hidden_dim)
ff1_clone.weight = nn.Parameter(w.clone())

Llama with 20% of its weights pruned only exhibits a performance drop from 73% to 69%.  
src: https://arxiv.org/pdf/2305.11627.pdf

`prune_amount = 0.2`

In [None]:
# Unstructured L1 pruning of 20% of ff1's weight entries.
# prune.remove is not called here, so the pruning re-parametrization stays
# attached to ff1; ff1.weight still reads back as the masked weights.
prune.l1_unstructured(module=ff1, name="weight", amount=0.2)
# serialize_pruned_tensor(ff1.weight.data, 'pruned-02.bin')

pruned-02.bin should be 189.4 MB


In [None]:
# Extremely aggressive pruning: drop half of the cloned layer's weight entries.
prune.l1_unstructured(module=ff1_clone, name="weight", amount=0.5)
# Make the pruning permanent, discarding the original (unpruned) weights.
prune.remove(module=ff1_clone, name="weight")
# serialize_pruned_tensor(ff1_clone.weight.data, 'pruned-05.bin')

pruned-05.bin should be 135.3 MB