In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face Hub id of the checkpoint used throughout this notebook.
MODEL_ID = "meta-llama/Meta-Llama-3-8B"

# Weights are loaded in bfloat16; with no device_map the model stays on the default (CPU) device.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype="bfloat16")

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 4/4 [00:00<00:00, 120.07it/s]


In [2]:
# List every parameter whose name contains "proj" or "lm_head", then print how many there are.
matched_params = [
    (param_name, tensor.shape)
    for param_name, tensor in model.named_parameters()
    if "proj" in param_name or "lm_head" in param_name
]
for param_name, shape in matched_params:
    print(param_name, shape)

length = len(matched_params)
print(length)

# An earlier variant enumerated layers via double_sparse.modelutils.find_layers(model).

model.layers.0.self_attn.q_proj.weight torch.Size([4096, 4096])
model.layers.0.self_attn.k_proj.weight torch.Size([1024, 4096])
model.layers.0.self_attn.v_proj.weight torch.Size([1024, 4096])
model.layers.0.self_attn.o_proj.weight torch.Size([4096, 4096])
model.layers.0.mlp.gate_proj.weight torch.Size([14336, 4096])
model.layers.0.mlp.up_proj.weight torch.Size([14336, 4096])
model.layers.0.mlp.down_proj.weight torch.Size([4096, 14336])
model.layers.1.self_attn.q_proj.weight torch.Size([4096, 4096])
model.layers.1.self_attn.k_proj.weight torch.Size([1024, 4096])
model.layers.1.self_attn.v_proj.weight torch.Size([1024, 4096])
model.layers.1.self_attn.o_proj.weight torch.Size([4096, 4096])
model.layers.1.mlp.gate_proj.weight torch.Size([14336, 4096])
model.layers.1.mlp.up_proj.weight torch.Size([14336, 4096])
model.layers.1.mlp.down_proj.weight torch.Size([4096, 14336])
model.layers.2.self_attn.q_proj.weight torch.Size([4096, 4096])
model.layers.2.self_attn.k_proj.weight torch.Size([1024,

In [3]:
# Decoder layer whose query-projection weight is studied below.
# (Another matrix can be swapped in, e.g. `.self_attn.k_proj.weight` of layer 0.)
LAYER_IDX = 3
q_proj_weight = model.model.layers[LAYER_IDX].self_attn.q_proj.weight

# Shape first (torch.Size([4096, 4096]) for q_proj here), then the values.
print(q_proj_weight.shape, q_proj_weight, sep="\n")

torch.Size([4096, 4096])
Parameter containing:
tensor([[ 0.0092,  0.0076,  0.0015,  ...,  0.0166,  0.0034,  0.0076],
        [ 0.0125,  0.0153,  0.0129,  ..., -0.0004,  0.0128,  0.0056],
        [ 0.0192,  0.0037, -0.0167,  ...,  0.0115,  0.0078, -0.0008],
        ...,
        [-0.0164,  0.0254,  0.0026,  ...,  0.0089, -0.0062,  0.0029],
        [-0.0457,  0.0171, -0.0064,  ..., -0.0184, -0.0020, -0.0018],
        [-0.0135, -0.0003, -0.0234,  ..., -0.0082,  0.0088,  0.0267]],
       dtype=torch.bfloat16, requires_grad=True)


In [4]:
from doublesparse import factorizef, mag_prune
from ignite.metrics.regression.mean_absolute_relative_error import MeanAbsoluteRelativeError
import torch
import torch.nn as nn
import numpy as np

# Release cached, unused CUDA allocator memory and block until all queued GPU work has finished.
# NOTE(review): in the cells above the model is loaded without a device_map and nothing has
# been moved to the GPU yet, so empty_cache() is presumably a no-op here -- confirm intent.
torch.cuda.empty_cache()
torch.cuda.synchronize()

In [5]:
# Copy the selected weight onto the current CUDA device; the dtype is left unchanged.
q_proj_weight = q_proj_weight.to("cuda")
print(q_proj_weight.device, q_proj_weight.dtype, sep="\n")

cuda:0
torch.bfloat16


In [6]:
@torch.no_grad()
def test_factorize(mask=None, weight=None):
    """Run ``factorizef`` on a weight matrix and print reconstruction error and sparsity.

    Args:
        mask: Forwarded unchanged to ``factorizef`` as ``fixmask``; ``None`` passes
            no fixed mask.
        weight: 2-D matrix to factorize. Defaults to the notebook-global
            ``q_proj_weight``, which preserves the behaviour of existing calls.

    Prints the Frobenius norm of ``prod - matrix`` (``prod`` is the first value
    returned by ``factorizef``), the leading 4x4 corner of the factors ``A`` and
    ``B``, the ``prod`` and input matrices, and the count/percentage of non-zero
    entries in ``A`` and ``B``.
    """
    torch.cuda.synchronize()

    matrix = (q_proj_weight if weight is None else weight).to(dtype=torch.float32)

    # The identity used to be hard-coded as 4096x4096 on "cuda". It is now sized by the
    # number of columns and placed on the weight's own device.
    # NOTE(review): assumes factorizef expects a (cols x cols) second argument -- TODO
    # confirm. For the square q_proj weight this is identical to the old 4096.
    n_cols = matrix.shape[1]
    identity = torch.eye(n_cols, device=matrix.device)
    prod, A, B = factorizef(matrix, identity, fixmask=mask)

    # torch.norm is deprecated; matrix_norm with ord="fro" is the documented replacement.
    frobenius = torch.linalg.matrix_norm(prod - matrix, ord="fro")
    print(f"Frobenius norm: {frobenius.item()}")

    prod = prod.cpu()
    A = A.cpu()
    B = B.cpu()

    nz_count_A = torch.count_nonzero(A).item()
    nz_count_B = torch.count_nonzero(B).item()

    print(f"A: {A[:4, :4]}")
    print(f"B: {B[:4, :4]}")
    print(f"AB: {prod}")
    print(f"Input matrix was: {matrix}")

    # Density is relative to each factor's own element count (previously SIZE**2),
    # so it stays correct when A or B is not 4096x4096.
    pct_A = round(nz_count_A / A.numel() * 100, 1)
    pct_B = round(nz_count_B / B.numel() * 100, 1)

    del A, B, matrix, identity
    torch.cuda.empty_cache()

    print(f"A has {nz_count_A} non-zero entries ({pct_A}%)")
    print(f"B has {nz_count_B} non-zero entries ({pct_B}%)")


In [7]:
@torch.no_grad()
def test_mag_prune(mask=None, weight=None):
    """Apply ``mag_prune`` to a weight matrix and print error, result, and density.

    Args:
        mask: Accepted for signature parity with ``test_factorize`` only. It is NOT
            forwarded to ``mag_prune``, so a warning is emitted when it is given.
        weight: 2-D matrix to prune. Defaults to the notebook-global
            ``q_proj_weight``, which preserves the behaviour of existing calls.

    Prints the Frobenius norm of ``matrix - pruned``, the pruned and input
    matrices, and the number/percentage of non-zero entries kept.
    """
    import warnings

    if mask is not None:
        # Previously the mask was silently dropped; make that visible instead.
        warnings.warn("test_mag_prune ignores `mask`: mag_prune is called without it.")

    torch.cuda.synchronize()

    # The former unused torch.eye(4096) allocation (~64 MB of GPU memory) was removed.
    matrix = (q_proj_weight if weight is None else weight).to(dtype=torch.float32)
    pruned = mag_prune(matrix)

    frobenius = torch.linalg.matrix_norm(matrix - pruned, ord="fro")
    print(f"Frobenius norm: {frobenius.item()}")

    nz_count = torch.count_nonzero(pruned).item()
    pruned = pruned.cpu()
    # The result is a single pruned matrix, not a product of factors, so it is not labelled "AB".
    print(f"Pruned: {pruned}")
    print(f"Input matrix was: {matrix}")

    pct = round(nz_count / pruned.numel() * 100, 1)

    del matrix
    torch.cuda.empty_cache()

    print(f"Pruned matrix has {nz_count} non-zero entries ({pct}%)")


In [8]:
# Factorize the selected weight without a fixed sparsity mask (default mask=None).
test_factorize()

Frobenius norm: 11.107711791992188
A: tensor([[-0.0000,  0.0000,  0.0000,  0.0000],
        [-0.0000, -0.0000,  0.0000, -0.0000],
        [ 0.0000,  0.0000, -0.0000, -0.0159],
        [ 0.0444,  0.0000, -0.0000,  0.0000]])
B: tensor([[0.5037, 0.0000, 0.0256, -0.0000],
        [0.0000, 0.4861, 0.0000, -0.0000],
        [-0.0000, 0.0000, 0.4536, 0.0000],
        [-0.0000, 0.0000, -0.0000, 0.4842]])
AB: tensor([[ 0.0110,  0.0053, -0.0034,  ...,  0.0195,  0.0077,  0.0048],
        [ 0.0144,  0.0151,  0.0122,  ...,  0.0012,  0.0163,  0.0088],
        [ 0.0154,  0.0084, -0.0134,  ...,  0.0136,  0.0020,  0.0028],
        ...,
        [-0.0203,  0.0232,  0.0048,  ...,  0.0073, -0.0018,  0.0028],
        [-0.0429,  0.0223, -0.0047,  ..., -0.0177, -0.0004, -0.0025],
        [-0.0090,  0.0005, -0.0237,  ..., -0.0118,  0.0090,  0.0314]])
Input matrix was: tensor([[ 0.0092,  0.0076,  0.0015,  ...,  0.0166,  0.0034,  0.0076],
        [ 0.0125,  0.0153,  0.0129,  ..., -0.0004,  0.0128,  0.0056],
    

In [9]:
# Comparison run: doublesparse.mag_prune on the same weight (default mask=None).
test_mag_prune()

Frobenius norm: 22.477331161499023
AB: tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0166,  0.0000,  0.0000],
        [ 0.0000,  0.0153,  0.0000,  ..., -0.0000,  0.0000,  0.0000],
        [ 0.0192,  0.0000, -0.0167,  ...,  0.0000,  0.0000, -0.0000],
        ...,
        [-0.0164,  0.0254,  0.0000,  ...,  0.0000, -0.0000,  0.0000],
        [-0.0457,  0.0171, -0.0000,  ..., -0.0184, -0.0000, -0.0000],
        [-0.0135, -0.0000, -0.0234,  ..., -0.0000,  0.0000,  0.0267]])
Input matrix was: tensor([[ 0.0092,  0.0076,  0.0015,  ...,  0.0166,  0.0034,  0.0076],
        [ 0.0125,  0.0153,  0.0129,  ..., -0.0004,  0.0128,  0.0056],
        [ 0.0192,  0.0037, -0.0167,  ...,  0.0115,  0.0078, -0.0008],
        ...,
        [-0.0164,  0.0254,  0.0026,  ...,  0.0089, -0.0062,  0.0029],
        [-0.0457,  0.0171, -0.0064,  ..., -0.0184, -0.0020, -0.0018],
        [-0.0135, -0.0003, -0.0234,  ..., -0.0082,  0.0088,  0.0267]],
       device='cuda:0')


In [10]:
# Fixed sparsity mask: a 64x64 tile whose top-left 32x32 quadrant is ones (everything else
# zero), repeated 64 times along each axis to cover a 4096x4096 matrix (25% density).
# NOTE(review): despite the name, this is NOT block-diagonal -- every 64x64 tile, on and off
# the diagonal, has its top-left quadrant set. Confirm this is the intended pattern.
mask_tile = torch.zeros([64, 64], device="cuda")
mask_tile[:32, :32] = 1
dblock = mask_tile.repeat(64, 64)

test_factorize(mask=dblock)

Frobenius norm: 51.09311294555664
A: tensor([[ 0.0000,  0.0081, -0.0000,  0.0000],
        [ 0.0000, -0.0000, -0.0000, -0.0000],
        [ 0.0000,  0.0000,  0.0000, -0.0000],
        [ 0.0185,  0.0135, -0.0000, -0.0098]])
B: tensor([[ 0.7362, -0.0020,  0.0396, -0.0077],
        [-0.0203,  0.7104,  0.0035,  0.0284],
        [-0.0092,  0.0020,  0.7983,  0.0492],
        [ 0.0215,  0.0018,  0.0504,  0.7454]])
AB: tensor([[ 0.0103,  0.0104,  0.0024,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0119,  0.0131,  0.0107,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0200,  0.0011, -0.0178,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.0174,  0.0235,  0.0026,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0439,  0.0166, -0.0084,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0129,  0.0006, -0.0227,  ...,  0.0000,  0.0000,  0.0000]])
Input matrix was: tensor([[ 0.0092,  0.0076,  0.0015,  ...,  0.0166,  0.0034,  0.0076],
        [ 0.0125,  0.0153,  0.0129,  ..., -0.0004,  0.0128,  0.0

In [11]:
# import gc

# gc.collect()
# torch.cuda.empty_cache()
# torch.cuda.reset_peak_memory_stats()