<a href="https://colab.research.google.com/github/honicky/quantization-experiments/blob/main/Rank_experiements.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install wandb datasets huggingface_hub


# Quantization functions

Borrowed from https://huggingface.co/1bitLLM/bitnet_b1_58-3B/blob/main/utils_quant.py

In [3]:
from torch import nn
import torch
import torch.autograd as autograd

def weight_quant_158b(weight, num_bits=1):
    """Fake-quantize ``weight`` to three levels: -1, 0 and +1 times a scale.

    The tensor is scaled by the reciprocal of its mean absolute value
    (floored at 1e-5), rounded, clamped to [-1, 1] and then rescaled back,
    so the result stays on the scale of the input.

    Args:
        weight: Tensor to quantize.
        num_bits: Not read by this function; present so it can be called
            with the same arguments as ``weight_quant``.

    Returns:
        A tensor with the same shape and dtype as ``weight``.
    """
    original_dtype = weight.dtype
    as_float = weight.float()
    # Reciprocal of the mean magnitude; the clamp guards against dividing by zero.
    inv_scale = 1 / as_float.abs().mean().clamp(min=1e-5)
    ternary = torch.clamp(torch.round(as_float * inv_scale), -1, 1)
    return (ternary / inv_scale).type(original_dtype)

def weight_quant(x, num_bits=8):
    """Fake-quantize ``x`` onto a signed ``num_bits``-bit integer grid.

    A separate scale is derived for every slice along the last dimension from
    that slice's largest magnitude (floored at 1e-5). Values are scaled,
    rounded, clamped to ``[-2**(num_bits-1), 2**(num_bits-1) - 1]`` and then
    divided by the scale again.

    Args:
        x: Tensor to quantize.
        num_bits: Bit width of the integer grid.

    Returns:
        A tensor with the same shape and dtype as ``x``.
    """
    original_dtype = x.dtype
    values = x.float()
    half_range = 2 ** (num_bits - 1)
    lower, upper = -half_range, half_range - 1
    # Largest magnitude per slice of the last dimension (kept broadcastable).
    peak = values.abs().max(dim=-1, keepdim=True).values
    scale = upper / peak.clamp(min=1e-5)
    quantized = torch.clamp(torch.round(values * scale), lower, upper)
    return (quantized / scale).type(original_dtype)


# Load the pythia model

In [5]:
import torch
from transformers import GPTNeoXForCausalLM

# Load the pretrained checkpoint identified by `model_name`
# (the cell output shows config.json and model.safetensors being fetched).
model_name = "EleutherAI/pythia-160m-deduped"
model = GPTNeoXForCausalLM.from_pretrained(model_name)


config.json:   0%|          | 0.00/569 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/375M [00:00<?, ?B/s]

In [6]:
# Display the module tree of the loaded model.
model

GPTNeoXForCausalLM(
  (gpt_neox): GPTNeoXModel(
    (embed_in): Embedding(50304, 768)
    (emb_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-11): 12 x GPTNeoXLayer(
        (input_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (post_attention_dropout): Dropout(p=0.0, inplace=False)
        (post_mlp_dropout): Dropout(p=0.0, inplace=False)
        (attention): GPTNeoXSdpaAttention(
          (rotary_emb): GPTNeoXRotaryEmbedding()
          (query_key_value): Linear(in_features=768, out_features=2304, bias=True)
          (dense): Linear(in_features=768, out_features=768, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (mlp): GPTNeoXMLP(
          (dense_h_to_4h): Linear(in_features=768, out_features=3072, bias=True)
          (dense_4h_to_h): Linear(in_features=3072, out_features=768, bias=True)
      

In [7]:
# List every parameter name together with its shape.
for name, param in model.named_parameters():
  print(name, param.shape)

gpt_neox.embed_in.weight torch.Size([50304, 768])
gpt_neox.layers.0.input_layernorm.weight torch.Size([768])
gpt_neox.layers.0.input_layernorm.bias torch.Size([768])
gpt_neox.layers.0.post_attention_layernorm.weight torch.Size([768])
gpt_neox.layers.0.post_attention_layernorm.bias torch.Size([768])
gpt_neox.layers.0.attention.query_key_value.weight torch.Size([2304, 768])
gpt_neox.layers.0.attention.query_key_value.bias torch.Size([2304])
gpt_neox.layers.0.attention.dense.weight torch.Size([768, 768])
gpt_neox.layers.0.attention.dense.bias torch.Size([768])
gpt_neox.layers.0.mlp.dense_h_to_4h.weight torch.Size([3072, 768])
gpt_neox.layers.0.mlp.dense_h_to_4h.bias torch.Size([3072])
gpt_neox.layers.0.mlp.dense_4h_to_h.weight torch.Size([768, 3072])
gpt_neox.layers.0.mlp.dense_4h_to_h.bias torch.Size([768])
gpt_neox.layers.1.input_layernorm.weight torch.Size([768])
gpt_neox.layers.1.input_layernorm.bias torch.Size([768])
gpt_neox.layers.1.post_attention_layernorm.weight torch.Size([768])

In [8]:
# Pull the weight and bias of one MLP projection (layer 10) out of the model.
state_dict = model.state_dict()
weight_tensor = state_dict["gpt_neox.layers.10.mlp.dense_h_to_4h.weight"]
bias_tensor = state_dict["gpt_neox.layers.10.mlp.dense_h_to_4h.bias"]

# The weight matrix is laid out as (out_features, in_features).
out_features, in_features = weight_tensor.shape

print(f"in_features: {in_features}, out_features: {out_features}")



in_features: 768, out_features: 3072


# Plot the top singular values at different quantization levels

We iterate through a few layers of the `pythia` model and plot the normalized cumulative sum of the top singular values, comparing the original weight matrix, the quantized weight matrix, and the residual (original minus quantized) matrix. We can use this to get a sense of how quantization impacts the rank of our weight matrices.

In [103]:
import torch
import plotly.express as px
from plotly.subplots import make_subplots

def _cumulative_sv_fraction(singular_values, sv_count):
  """Return the first ``sv_count`` entries of the normalized cumulative sum.

  Entry ``k`` is the sum of the first ``k + 1`` singular values divided by
  the sum of all of them. The result is a NumPy array.
  """
  values = singular_values.detach().cpu()
  running = values.cumsum(dim=0).numpy()
  total = values.numpy().sum()
  if total == 0:
    # All singular values are zero (e.g. a lossless quantization residual);
    # the running sum is already all zeros, so skip the 0/0 division.
    return running[:sv_count]
  return (running / total)[:sv_count]


def svd_plot(weight_and_name, quant_bits, colors, sv_count):
  """Compute cumulative singular-value curves for one weight matrix.

  The matrix is quantized with ``weight_quant_158b`` when ``quant_bits`` is
  1.58 and with ``weight_quant`` otherwise. Singular values are then computed
  for the original matrix, the quantized matrix, and their difference.

  Args:
    weight_and_name: ``(name, tensor)`` pair; ``name`` is printed as a
      progress message.
    quant_bits: Bit width passed to the quantizer.
    colors: Unused here; kept so existing callers keep working.
    sv_count: Number of leading cumulative values to return per curve.

  Returns:
    A 3-tuple of NumPy arrays ``(original, quantized, residual)``, each
    holding the first ``sv_count`` values of the cumulative sum of singular
    values divided by the total sum.
  """
  name, weight_tensor = weight_and_name
  print(name)

  # Detach so that tensors which still track gradients can be turned into
  # NumPy arrays further down.
  weight_tensor = weight_tensor.detach()

  quant = weight_quant_158b if quant_bits == 1.58 else weight_quant
  weight_tensor_q = quant(weight_tensor, quant_bits)

  # Only the singular values are needed, so skip computing U and Vh.
  S = torch.linalg.svdvals(weight_tensor)
  Sq = torch.linalg.svdvals(weight_tensor_q)
  Sqr = torch.linalg.svdvals(weight_tensor - weight_tensor_q)

  return (
    _cumulative_sv_fraction(S, sv_count),
    _cumulative_sv_fraction(Sq, sv_count),
    _cumulative_sv_fraction(Sqr, sv_count),
  )

def plot_svd_grid(weight_tensors, quant_levels, colors, sv_count=100):
  """Build a grid of cumulative singular-value plots.

  Each row corresponds to one ``(name, tensor)`` pair in ``weight_tensors``
  and each column to one entry of ``quant_levels``. Every cell overlays the
  three curves returned by ``svd_plot`` (traces named "S", "Sq" and "Sqr"),
  drawn with ``colors[0]``, ``colors[1]`` and ``colors[2]``. All cells of a
  row share the same y-axis range, and the row's name is used as the y-axis
  title of its first column.

  Args:
    weight_tensors: Sequence of ``(name, tensor)`` pairs, one per row.
    quant_levels: Sequence of bit widths, one per column.
    colors: Sequence with at least three line colors.
    sv_count: Number of leading values per curve, forwarded to ``svd_plot``.

  Returns:
    The plotly figure holding the grid.
  """
  n_rows, n_cols = len(weight_tensors), len(quant_levels)

  # Pixel size of a single grid cell.
  cell_height = 265
  cell_width = 300

  # Equal relative sizes for all rows and columns, with tight spacing.
  fig = make_subplots(
    rows=n_rows,
    cols=n_cols,
    subplot_titles=[f"{quant_bits} bits" for quant_bits in quant_levels],
    row_heights=[1] * n_rows,
    column_widths=[1] * n_cols,
    vertical_spacing=0.02,
    horizontal_spacing=0.05,
  )

  trace_names = ("S", "Sq", "Sqr")

  for row, entry in enumerate(weight_tensors, start=1):
    label, _ = entry
    upper, lower = 0, 1
    curves_by_col = []

    # First pass: compute the curves and the y-range shared by the whole row.
    for level in quant_levels:
      s_curve, sq_curve, sqr_curve = svd_plot(entry, level, colors, sv_count)
      curves = (s_curve, sq_curve, sqr_curve)
      curves_by_col.append(curves)
      upper = max(upper, *(curve.max() for curve in curves))
      lower = min(lower, *(curve.min() for curve in curves))

    # Second pass: draw the three traces of every cell on the shared range.
    for col, curves in enumerate(curves_by_col, start=1):
      for k, (trace_name, curve) in enumerate(zip(trace_names, curves)):
        fig.add_scatter(
          y=curve,
          mode='lines',
          name=trace_name,
          line=dict(color=colors[k]),
          row=row,
          col=col
        )
      fig.update_yaxes(range=[lower, upper], row=row, col=col)

    # Label the row through the y-axis title of its first column.
    fig.update_yaxes(title_text=label, row=row, col=1, title_standoff=10)

  fig.update_layout(
    width=cell_width * n_cols,
    height=cell_height * n_rows,
    margin=dict(l=100, r=50, t=50, b=50),  # wide left margin for the row labels
    showlegend=False,
    title_text="SVD Cumulative Sum Grid",
    title_x=0.5
  )

  return fig


In [104]:
# Keep the state-dict entries whose name contains "weight" but neither "embed"
# nor "norm". The label slices off len("gpt_neox.layers.") leading characters
# and len(".weight") trailing characters from each name.
weight_tensors = [
    (name[len("gpt_neox.layers."):-len(".weight")], param)
    for name, param in state_dict.items()
    if "weight" in name and "embed" not in name and "norm" not in name
]
# Bit widths to compare; 1.58 makes svd_plot use weight_quant_158b.
quant_bits = [1.58, 2, 3, 4]
# Line colors for the original, quantized and residual curves, in that order.
colors = ['blue', 'green', 'red']
# Only the first 8 matrices are plotted (layers 0 and 1 in the output below).
grid_fig = plot_svd_grid(weight_tensors[:8], quant_bits, colors, sv_count=200)
grid_fig.show()



0.attention.query_key_value
0.attention.query_key_value
0.attention.query_key_value
0.attention.query_key_value
0.attention.dense
0.attention.dense
0.attention.dense
0.attention.dense
0.mlp.dense_h_to_4h
0.mlp.dense_h_to_4h
0.mlp.dense_h_to_4h
0.mlp.dense_h_to_4h
0.mlp.dense_4h_to_h
0.mlp.dense_4h_to_h
0.mlp.dense_4h_to_h
0.mlp.dense_4h_to_h
1.attention.query_key_value
1.attention.query_key_value
1.attention.query_key_value
1.attention.query_key_value
1.attention.dense
1.attention.dense
1.attention.dense
1.attention.dense
1.mlp.dense_h_to_4h
1.mlp.dense_h_to_4h
1.mlp.dense_h_to_4h
1.mlp.dense_h_to_4h
1.mlp.dense_4h_to_h
1.mlp.dense_4h_to_h
1.mlp.dense_4h_to_h
1.mlp.dense_4h_to_h


# Compare to an actually low rank matrix

Create a truly low-rank matrix and look at the cumulative sum of its singular values, for comparison with the plots above.

In [12]:
import torch
import plotly.express as px

# Create a 768 x 768 matrix of rank at most 4 as the product of two thin
# random factors. A dedicated, seeded generator makes the figure reproducible
# from run to run without touching the global RNG state.
low_rank_rng = torch.Generator().manual_seed(0)
low_rank_matrix = (
  torch.randn(768, 4, generator=low_rank_rng)
  @ torch.randn(4, 768, generator=low_rank_rng)
)
Ulr, Slr, Vhr = torch.linalg.svd(low_rank_matrix)

# Cumulative sum of the first 200 singular values, divided by the total sum.
fig = px.line(Slr.cumsum(dim=0).numpy()[:200] / Slr.numpy().sum())
fig.update_layout(
  width=420,
  height=400,
  showlegend=False,  # single curve, so the legend adds nothing
  xaxis_title="singular value index",
  yaxis_title="cumulative fraction",
)
fig.show()