In [97]:
from transformers import AutoModelForCausalLM
import torch

# Checkpoint whose attention weights are inspected below; loaded in fp16.
model_name = "microsoft/Phi-3.5-mini-instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    # NOTE(review): trust_remote_code executes Python shipped with the Hub repo;
    # consider pinning a `revision` for security and reproducibility.
    trust_remote_code=True,
)

`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.93it/s]


In [98]:
import numpy as np

def get_oproj_weight(model, layer_idx):
    """Return decoder layer ``layer_idx``'s ``self_attn.o_proj`` weight as a NumPy array.

    Matches the exact parameter name (``model.layers.<idx>.self_attn.o_proj.weight``,
    as yielded by ``model.named_parameters()``) rather than a ``".2." in name``
    substring test, which depends on the surrounding name layout.

    Raises:
        KeyError: if the model has no parameter with that name. The previous loop
            would instead leave the variable unassigned and fail later with a NameError.
    """
    target = f"model.layers.{layer_idx}.self_attn.o_proj.weight"
    for name, param in model.named_parameters():
        if name == target:
            print(name)
            return param.detach().cpu().numpy()
    raise KeyError(f"parameter {target!r} not found in model")


# 0-indexed decoder layers to compare. NOTE: despite their names, layer_1_oproj /
# layer_2_oproj hold layers LAYER_A and LAYER_B (2 and 3); names kept for compatibility.
LAYER_A, LAYER_B = 2, 3
layer_1_oproj = get_oproj_weight(model, LAYER_A)
layer_2_oproj = get_oproj_weight(model, LAYER_B)

# Upcast the fp16 weights to float64 so the SVD-based analysis below runs in double precision.
A = np.array(layer_1_oproj, dtype=np.float64, copy=True)
B = np.array(layer_2_oproj, dtype=np.float64, copy=True)

model.layers.2.self_attn.o_proj.weight
model.layers.3.self_attn.o_proj.weight
total A 81.69725469101107


In [None]:
def get_all_layers(module=None, proj_name="self_attn.o_proj"):
    """Collect consecutive-layer pairs of a projection weight for layer-to-layer comparison.

    NOTE(review): this cell was left unfinished (it stopped after the ``for`` header
    and raised a SyntaxError). This completion assumes the intent was to gather
    ``(layer i, next layer)`` weight pairs to feed to ``procrustes`` -- confirm.

    Args:
        module: model to scan; defaults to the notebook-global ``model``.
        proj_name: dotted path of the projection inside each decoder layer,
            e.g. ``"self_attn.o_proj"`` or ``"mlp.down_proj"``.

    Returns:
        List of ``(W_i, W_next)`` float64 NumPy arrays, ordered by layer index.
        Returns an empty list when fewer than two matching layers exist.
    """
    import re

    if module is None:
        module = model  # notebook global loaded in the first cell
    pattern = re.compile(r"(?:^|\.)layers\.(\d+)\." + re.escape(proj_name) + r"\.weight$")

    weights = {}
    for name, param in module.named_parameters():
        match = pattern.search(name)
        if match:
            weights[int(match.group(1))] = np.array(
                param.detach().cpu().numpy(), dtype=np.float64, copy=True
            )

    layer_ids = sorted(weights)
    pairs = [(weights[i], weights[j]) for i, j in zip(layer_ids, layer_ids[1:])]
    return pairs
        


In [137]:
# Display the module tree to confirm parameter names (e.g. model.layers.N.self_attn.o_proj) and layer shapes.
model

Phi3ForCausalLM(
  (model): Phi3Model(
    (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
    (embed_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-31): 32 x Phi3DecoderLayer(
        (self_attn): Phi3Attention(
          (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
          (qkv_proj): Linear(in_features=3072, out_features=9216, bias=False)
          (rotary_emb): Phi3LongRoPEScaledRotaryEmbedding()
        )
        (mlp): Phi3MLP(
          (gate_up_proj): Linear(in_features=3072, out_features=16384, bias=False)
          (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
          (activation_fn): SiLU()
        )
        (input_layernorm): Phi3RMSNorm()
        (resid_attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
        (post_attention_layernorm): Phi3RMSNorm()
      )
    )
    (norm): Phi3RMSNorm()
  )
  (lm_head): Linear(in_features=3072, out

In [20]:
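# We have two weight matrices, A and B, each of shape (3072, 3072).
# Orthogonal Procrustes problem: find the orthogonal matrix Q that most closely maps A to B,
# i.e. Q minimizes ||QA - B||_F (Frobenius norm) subject to Q^T Q = I.
# Solution: Q = U V^T, where U Σ V^T is the SVD of B A^T.
# NOTE: the code below uses scipy's right-multiplication convention instead:
# R minimizes ||A R - B||_F, with R = U V^T where U Σ V^T is the SVD of A^T B.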


In [99]:
import scipy
import scipy.linalg

def orthogonal_procrustes(A, B):
    """Solve the orthogonal Procrustes problem (scipy's right-multiplication convention).

    Finds the orthogonal matrix ``R`` minimising ``||A @ R - B||_F``, matching
    ``scipy.linalg.orthogonal_procrustes``.

    Args:
        A, B: 2-D arrays of the same shape ``(m, n)``.

    Returns:
        R: ``(n, n)`` orthogonal matrix ``U @ Vt``, where ``U S Vt = svd(A.T @ B)``.
        scale: sum of the singular values of ``A.T @ B``.

    Raises:
        ValueError: if ``A`` or ``B`` is not 2-D, or their shapes differ.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    if A.ndim != 2 or B.ndim != 2:
        raise ValueError(f"expected 2-D inputs, got ndim {A.ndim} and {B.ndim}")
    if A.shape != B.shape:
        raise ValueError(f"the shapes of A and B differ ({A.shape} vs {B.shape})")

    # A.T is a view, so forming A.T @ B needs no extra copy of A.
    u, w, vt = scipy.linalg.svd(A.T @ B)
    R = u @ vt
    scale = w.sum()
    return R, scale

In [100]:
def procrustes(data1, data2):
    """Procrustes-align ``data2`` to ``data1`` after centring and normalising both.

    Follows ``scipy.spatial.procrustes`` but also returns the intermediate
    quantities, so the normalisation can be undone later.

    Both inputs are copied to float64. Each is centred by subtracting its column
    means (mean over axis 0) and then divided by its Frobenius norm. The
    normalised ``data2`` is rotated by the orthogonal matrix from
    ``orthogonal_procrustes`` and multiplied by the returned scale.

    Returns:
        ``(mtx1, mtx2, mtx2_scaled, R, s, mean_1, mean_2, norm1, norm2, disparity)``:
        ``mtx1`` is normalised ``data1``; ``mtx2`` is normalised, rotated and scaled
        ``data2``; ``mtx2_scaled`` is normalised ``data2`` before rotation; ``R`` and
        ``s`` are from ``orthogonal_procrustes``; ``mean_*`` / ``norm*`` are the
        column means and Frobenius norms removed from each input; ``disparity``
        is the sum of squared differences between ``mtx1`` and ``mtx2``.

    Raises:
        ValueError: if the inputs are not 2-D, differ in shape, are empty, or are
            all-zero after centring.
    """
    first = np.array(data1, dtype=np.float64, copy=True)
    second = np.array(data2, dtype=np.float64, copy=True)

    # Checks run in a fixed order; the first failing one determines the error.
    if first.ndim != 2 or second.ndim != 2:
        raise ValueError("Input matrices must be two-dimensional")
    if first.shape != second.shape:
        raise ValueError("Input matrices must be of same shape")
    if first.size == 0:
        raise ValueError("Input matrices must be >0 rows and >0 cols")

    # Move each point cloud (rows) to the origin.
    centre_1 = first.mean(axis=0)
    centre_2 = second.mean(axis=0)
    first = first - centre_1
    second = second - centre_2

    frob_1 = np.linalg.norm(first)
    frob_2 = np.linalg.norm(second)
    if frob_1 == 0 or frob_2 == 0:
        raise ValueError("Input matrices must contain >1 unique points")

    # Rescale the whole matrix (not per row) to unit Frobenius norm: trace(M @ M.T) == 1.
    first = first / frob_1
    second = second / frob_2
    second_unrotated = second.copy()

    # R maps first -> second (first @ R ~ second), so R.T maps second back onto first.
    rotation, scale = orthogonal_procrustes(first, second)
    aligned = np.dot(second, rotation.T) * scale

    disparity = np.sum(np.square(first - aligned))
    return (first, aligned, second_unrotated, rotation, scale,
            centre_1, centre_2, frob_1, frob_2, disparity)

In [122]:
# Align B to A. A_final / B_final are the centred, unit-Frobenius-norm matrices (B_final also rotated
# and scaled); mean_1 / norm_1 are kept so A's normalisation can be undone later.
A_final, B_final, B_scaled, R, scale, mean_1, mean_2, norm_1, norm_2, disparity = procrustes(A, B)

In [123]:
# Residual between the aligned matrices, in procrustes' normalised frame.
delta = A_final - B_final
# Raw difference without alignment; not used further in the visible cells.
orig_delta = A - B

In [124]:
def get_low_rank_matrix(U, S, VT, rank=128):
    """Rebuild the rank-``rank`` approximation ``U_k @ diag(S_k) @ VT_k`` from SVD factors.

    Works with both thin (``full_matrices=False``) and full SVD factors. ``rank``
    is clamped to ``len(S)``, so requesting more components than exist returns the
    full reconstruction instead of a shape error on non-square inputs.

    Args:
        U: left singular vectors, shape ``(m, >= len(S))``.
        S: 1-D array of singular values.
        VT: right singular vectors as rows, shape ``(>= len(S), n)``.
        rank: number of leading singular triplets to keep (must be >= 0).

    Returns:
        ``(m, n)`` array.

    Raises:
        ValueError: if ``rank`` is negative.
    """
    if rank < 0:
        raise ValueError(f"rank must be non-negative, got {rank}")
    k = min(rank, len(S))
    # Scale U's columns by S via broadcasting instead of materialising a k x k diagonal matrix.
    return (U[:, :k] * S[:k]) @ VT[:k, :]

In [125]:
# Full SVDs of the aligned residual and of A, reused by the low-rank reconstructions below.
Ud, Sd, VTd = scipy.linalg.svd(delta)
U, S, VT = scipy.linalg.svd(A)

In [126]:
# rank=3072 keeps every singular triplet of these 3072x3072 matrices, so these are exact reconstructions.
# NOTE(review): neither result is used in the visible cells below.
low_rank_delta = get_low_rank_matrix(Ud, Sd, VTd, rank=3072)
low_rank_A = get_low_rank_matrix(U, S, VT, rank=3072)

In [127]:
def approx_error_calc_original(matrix, k=128):
    """Frobenius-norm error of the rank-``k`` SVD approximation of ``matrix``.

    Args:
        matrix: 2-D array to approximate.
        k: number of leading singular triplets kept.

    Returns:
        ``||matrix - matrix_k||_F`` as a NumPy float.
    """
    factors = scipy.linalg.svd(matrix)
    approximation = get_low_rank_matrix(*factors, k)
    residual = matrix - approximation
    return np.linalg.norm(residual, 'fro')

In [128]:
# Sanity check: with k=3072 (every singular triplet of the 3072x3072 A) the error should be ~0.
approx_error_calc_original(A, 3072)

np.float64(3.9086753689514116e-13)

In [129]:
def approx_error_calc_rotated(A, B_rotated, delta, k=128, norm=None, mean=None):
    """Frobenius error of rebuilding ``A`` from ``B_rotated`` plus a rank-``k`` residual.

    Inverts the normalisation done by ``procrustes``:
    ``A_approx = (low_rank_k(delta) + B_rotated) * norm + mean``.

    Args:
        A: original (un-normalised) matrix.
        B_rotated: normalised, rotated and scaled ``B`` (``B_final`` from ``procrustes``).
        delta: ``A_final - B_final`` in the normalised frame.
        k: rank kept from the SVD of ``delta``.
        norm: Frobenius norm removed from ``A`` by ``procrustes``. Defaults to the
            notebook global ``norm_1``; pass it explicitly to avoid hidden state.
        mean: column means removed from ``A`` by ``procrustes``. Defaults to the
            notebook global ``mean_1``.

    Returns:
        ``||A - A_approx||_F`` as a NumPy float.
    """
    if norm is None:
        norm = norm_1  # notebook global from the `procrustes(A, B)` cell
    if mean is None:
        mean = mean_1
    Ud, Sd, VTd = scipy.linalg.svd(delta)
    rotated_delta_approx = get_low_rank_matrix(Ud, Sd, VTd, k)
    A_approx = (rotated_delta_approx + B_rotated) * norm + mean
    return np.linalg.norm(A - A_approx, 'fro')

In [130]:
# Rotate and scale the normalised B; this repeats procrustes' own `np.dot(mtx2, R.T) * s`
# on the same inputs, so B_rotated is identical to B_final.
B_rotated = np.dot(B_scaled, R.T) * scale
approx_error_calc_rotated(A, B_rotated, delta, k=3072)

np.float64(2.4065190971944947e-13)

In [135]:
def test_reconstruction_error():
    """Assert that full-rank reconstructions of A are exact, both directly and via the aligned delta.

    Uses the notebook globals ``A``, ``B_rotated`` and ``delta``. "Full rank" is
    ``min(A.shape)`` (3072 for the o_proj weights here) rather than a hard-coded
    constant, so every singular triplet is kept and both errors should be ~0.

    Raises:
        AssertionError: if either reconstruction error exceeds 1e-12.
    """
    full_rank = min(A.shape)
    error_orig = approx_error_calc_original(A, full_rank)
    error_rotated = approx_error_calc_rotated(A, B_rotated, delta, full_rank)
    np.testing.assert_allclose(error_orig, 0, atol=1e-12)
    np.testing.assert_allclose(error_rotated, 0, atol=1e-12)

In [136]:
# Raises AssertionError if either full-rank reconstruction error exceeds 1e-12; silent on success.
test_reconstruction_error()