In [1]:
from transformers import AutoTokenizer
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
bigname = "microsoft/Phi-3.5-mini-instruct"

# Load the slow (use_fast=False) tokenizer for the checkpoint and make sure it
# was not constructed in legacy mode.
tokenizer = AutoTokenizer.from_pretrained(
    bigname,
    use_fast=False,
)
assert not tokenizer.legacy

In [3]:
from transformers import AutoModelForCausalLM

bigname = "microsoft/Phi-3.5-mini-instruct"

# Load the checkpoint on CPU with float16 weights; the cells below only read
# parameters via named_parameters().
# flash-attn is not installed, and loading without an explicit choice warns that
# flash-attention cannot be used; request the eager implementation explicitly.
# NOTE(review): trust_remote_code=True executes Python shipped with the
# checkpoint; recent transformers releases support Phi-3 natively, so this
# flag may be droppable — verify against the installed version.
bigmodel = AutoModelForCausalLM.from_pretrained(
    bigname,
    device_map="cpu",
    torch_dtype=torch.float16,
    trust_remote_code=True,
    attn_implementation="eager",
)


`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.61it/s]


In [4]:
import re

import numpy as np

# Decoder layers to compare.
# NOTE: `layer_5_oproj` / `layer_6_oproj` are historical names — they hold the
# o_proj weights of SRC_LAYER / DST_LAYER (layers 1 and 2). The names are kept
# because later cells read them.
SRC_LAYER = 1
DST_LAYER = 2
# Layers whose o_proj weights are collected into `matrices`.
STACK_LAYERS = range(1, 6)

# Parse the decoder-layer index out of names such as
# "model.layers.2.self_attn.o_proj.weight" rather than substring-matching ".2.",
# which would silently match unrelated components if parameter naming changed.
_LAYER_IDX_RE = re.compile(r"\.layers\.(\d+)\.")

matrices = []
layer_5_oproj = None
layer_6_oproj = None

for name, param in bigmodel.named_parameters():
    match = _LAYER_IDX_RE.search(name)
    if match is None:
        continue  # embeddings, final norm, lm_head, ...
    layer_idx = int(match.group(1))
    is_oproj = "o_proj" in name

    if is_oproj and layer_idx == SRC_LAYER:
        layer_5_oproj = param.detach().cpu().numpy()
        print(name)
    if is_oproj and layer_idx == DST_LAYER:
        layer_6_oproj = param.detach().cpu().numpy()
        print(name)
    if layer_idx == DST_LAYER:
        # Show every parameter of DST_LAYER with its shape.
        print(name, param.shape)
    if is_oproj and layer_idx in STACK_LAYERS:
        matrices.append(param.detach().cpu().numpy())

# Fail loudly (instead of a NameError further down) if the layers were not found.
assert layer_5_oproj is not None and layer_6_oproj is not None, "o_proj weights not found"

# float64 copies for the numerical work below.
A = np.array(layer_5_oproj, dtype=np.float64, copy=True)
B = np.array(layer_6_oproj, dtype=np.float64, copy=True)
print('total A', np.linalg.norm(A, 'fro'))

model.layers.1.self_attn.o_proj.weight
model.layers.2.self_attn.o_proj.weight
model.layers.2.self_attn.o_proj.weight torch.Size([3072, 3072])
model.layers.2.self_attn.qkv_proj.weight torch.Size([9216, 3072])
model.layers.2.mlp.gate_up_proj.weight torch.Size([16384, 3072])
model.layers.2.mlp.down_proj.weight torch.Size([3072, 8192])
model.layers.2.input_layernorm.weight torch.Size([3072])
model.layers.2.post_attention_layernorm.weight torch.Size([3072])
total A 73.3127109662652


In [7]:
# we have two weight matrices, A and B, each of shape (3072, 3072)


(3072, 3072)

In [5]:
def get_all_layers(model=None, key="o_proj"):
    """Collect the weights of every parameter whose name contains ``key``.

    Parameters
    ----------
    model : torch.nn.Module, optional
        Model whose ``named_parameters()`` are scanned. Defaults to the
        notebook-global ``bigmodel`` (the original behaviour).
    key : str, default "o_proj"
        Substring a parameter name must contain to be collected.

    Returns
    -------
    list of numpy.ndarray
        ``param.detach().cpu().numpy()`` for each matching parameter, in
        ``named_parameters()`` order. For parameters already on CPU these
        arrays may share memory with the parameters.
    """
    if model is None:
        model = bigmodel
    return [
        param.detach().cpu().numpy()
        for name, param in model.named_parameters()
        if key in name
    ]

In [7]:
import scipy
import scipy.linalg

def orthogonal_procrustes(data1, data2, check_finite=True):
    """Find the orthogonal ``R`` minimising ``||data1 @ R - data2||_F``.

    Same algorithm as ``scipy.linalg.orthogonal_procrustes``, but additionally
    returns a second, independent copy of the rotation (``R_``), which is what
    callers in this notebook use.

    Parameters
    ----------
    data1, data2 : array_like, shape (m, n)
        Matrices to align; must have identical shapes.
    check_finite : bool, default True
        If True, reject inputs containing inf or NaN.

    Returns
    -------
    R : ndarray, shape (n, n)
        The orthogonal matrix ``u @ vt`` from the SVD of ``data1.T @ data2``.
    scale : float
        Sum of the singular values of ``data1.T @ data2``.
    R_ : ndarray, shape (n, n)
        An independent copy of ``R``.

    Raises
    ------
    ValueError
        If ``data1`` is not 2-D, the shapes differ, or (with ``check_finite``)
        an input contains inf/NaN.
    """
    if check_finite:
        data1 = np.asarray_chkfinite(data1)
        data2 = np.asarray_chkfinite(data2)
    else:
        data1 = np.asanyarray(data1)
        data2 = np.asanyarray(data2)
    if data1.ndim != 2:
        raise ValueError('expected ndim to be 2, but observed %s' % data1.ndim)
    if data1.shape != data2.shape:
        raise ValueError(f'the shapes of A and B differ ({data1.shape} vs {data2.shape})')
    # (data2.T @ data1).T == data1.T @ data2; ordering the transposes this way
    # mirrors SciPy's implementation, which does it to save memory.
    u, w, vt = scipy.linalg.svd(data2.T.dot(data1).T)
    R = u.dot(vt)
    scale = w.sum()
    # R_ is a copy rather than a second ``u @ vt`` product: same values
    # without another O(n^3) matmul on large weight matrices.
    return R, scale, R.copy()

In [5]:
def procrustes(data1, data2):
    """Standardize two matrices and align the second onto the first.

    Follows ``scipy.spatial.procrustes``: both inputs are copied to float64,
    centered column-wise, and divided by their Frobenius norm. The notebook's
    ``orthogonal_procrustes`` then supplies an orthogonal matrix ``R_`` and a
    scale ``s``, and the normalized second matrix is mapped to
    ``mtx2 @ R_.T * s``.

    Returns
    -------
    tuple
        ``(mtx1, mtx2, mtx2_scaled, R_, s, mean_1, mean_2, norm1, norm2,
        disparity)``:

        - ``mtx1``: the normalized first input.
        - ``mtx2``: the rotated and scaled normalized second input.
        - ``mtx2_scaled``: the normalized second input before rotation.
        - ``R_``, ``s``: the rotation and scale.
        - ``mean_1``, ``mean_2``: the column means that were subtracted.
        - ``norm1``, ``norm2``: the Frobenius norms that were divided out.
        - ``disparity``: the sum of squared differences between ``mtx1`` and
          ``mtx2``.

    Raises
    ------
    ValueError
        If the inputs are not 2-D, differ in shape, are empty, or are all zero
        after centering.
    """
    first = np.array(data1, dtype=np.float64, copy=True)
    second = np.array(data2, dtype=np.float64, copy=True)

    if first.ndim != 2 or second.ndim != 2:
        raise ValueError("Input matrices must be two-dimensional")
    if first.shape != second.shape:
        raise ValueError("Input matrices must be of same shape")
    if first.size == 0:
        raise ValueError("Input matrices must be >0 rows and >0 cols")

    # Move both point clouds to the origin (column-wise centering).
    col_mean_1 = first.mean(axis=0)
    col_mean_2 = second.mean(axis=0)
    first -= col_mean_1
    second -= col_mean_2

    frob_1 = np.linalg.norm(first)
    frob_2 = np.linalg.norm(second)
    if frob_1 == 0 or frob_2 == 0:
        raise ValueError("Input matrices must contain >1 unique points")

    # Rescale so that trace(M @ M.T) == 1 for both matrices.
    first /= frob_1
    second /= frob_2
    unrotated = second.copy()

    # Only the copied rotation and the scale are used from the solver.
    _, scale, rotation = orthogonal_procrustes(first, second)
    aligned = second.dot(rotation.T) * scale

    disparity = np.square(first - aligned).sum()

    return (first, aligned, unrotated, rotation, scale,
            col_mean_1, col_mean_2, frob_1, frob_2, disparity)

In [8]:

# Align the second layer's o_proj onto the first (variable names are historical;
# see the extraction cell for which layers these actually are).
mtx1, mtx2, mtx2_scaled, R_, s, mean_1, mean_2, norm_1, norm_2, disparity = procrustes(layer_5_oproj, layer_6_oproj)
A_scaled = mtx1
B_scaled = mtx2_scaled
# Rebuild the aligned second matrix from its stored pieces (normalized B,
# rotation R_, scale s) using the same formula procrustes() applies internally.
mtx2_estimate = np.dot(B_scaled, R_.T) * s
# Invariant: the rebuilt matrix must match procrustes' own aligned output.
assert np.allclose(mtx2_estimate, mtx2), "reconstruction from (B_scaled, R_, s) disagrees with procrustes output"
print(R_.shape)


(3072, 3072)


In [9]:
import scipy

# Residual candidates between the two layers:
#   delta        — normalized first matrix minus the rotated+scaled second
#   orig_delta   — raw difference of the float64 weight copies
#   scaled_delta — normalized first minus normalized (unrotated) second
delta = mtx1 - mtx2
orig_delta = A - B
scaled_delta = mtx1 - mtx2_scaled

# Full SVDs: the baseline matrix A, and the post-Procrustes residual.
U, S, VT = scipy.linalg.svd(A)
Ud, Sd, VTd = scipy.linalg.svd(delta)



In [10]:
def get_approx_matrix(U, S, VT, k=128):
    """Rank-``k`` reconstruction ``U[:, :k] @ diag(S[:k]) @ VT[:k, :]`` from SVD factors.

    ``k`` is clamped to ``len(S)``, so factors from a full-matrices SVD (where
    ``U`` has more columns than there are singular values) can be used with a
    large ``k`` without a shape mismatch.
    """
    k = min(k, len(S))
    # Scaling U's columns by S equals U @ diag(S) but skips the dense k x k matmul.
    return (U[:, :k] * S[:k]) @ VT[:k, :]


def _report_error(label, A, A_approx):
    """Print and return the Frobenius norm of ``A - A_approx``."""
    reconstruction_error = np.linalg.norm(A - A_approx, 'fro')
    print(f"Reconstruction error {label}:", reconstruction_error)
    return reconstruction_error


def _denormalize(A_normalized, norm=None, mean=None):
    """Undo procrustes-style normalization: multiply by ``norm``, then add ``mean``.

    ``norm`` / ``mean`` default to the notebook globals ``norm_1`` / ``mean_1``
    (the first matrix's Frobenius norm and column means from ``procrustes``).
    """
    if norm is None:
        norm = norm_1
    if mean is None:
        mean = mean_1
    return A_normalized * norm + mean


def approx_error_calc_original(A, U, S, VT, k=128):
    """Error of approximating ``A`` by the rank-``k`` SVD ``(U, S, VT)`` of ``A`` itself."""
    return _report_error("original", A, get_approx_matrix(U, S, VT, k))


def approx_error_calc_original_scaled(A, U, S, VT, k=128, norm=None, mean=None):
    """Error when ``(U, S, VT)`` is the SVD of normalized ``A``; the rank-``k``
    approximation is de-normalized before comparing against ``A``."""
    A_approx = _denormalize(get_approx_matrix(U, S, VT, k), norm, mean)
    return _report_error("original scaled", A, A_approx)


def approx_error_calc_original_delta(A, B, U, S, VT, k=128):
    """Error of ``B + rank-k(A - B)`` against ``A``; ``(U, S, VT)`` is the SVD of ``A - B``."""
    A_approx = get_approx_matrix(U, S, VT, k) + B
    return _report_error("original delta", A, A_approx)


def approx_error_calc_scaled_delta(A, B_scaled, U, S, VT, k=128, norm=None, mean=None):
    """Error of ``denormalize(B_scaled + rank-k residual)`` against ``A``.

    ``(U, S, VT)`` is the SVD of the residual between normalized ``A`` and
    ``B_scaled`` (the normalized, unrotated second matrix).
    """
    A_approx = _denormalize(get_approx_matrix(U, S, VT, k) + B_scaled, norm, mean)
    return _report_error("scaled delta", A, A_approx)


def approx_error_calc_rotated(A, B_rotated, U, S, VT, k=128, norm=None, mean=None):
    """Error of ``denormalize(B_rotated + rank-k residual)`` against ``A``.

    ``(U, S, VT)`` is the SVD of the residual between normalized ``A`` and
    ``B_rotated`` (the Procrustes-aligned second matrix).
    """
    A_approx = _denormalize(get_approx_matrix(U, S, VT, k) + B_rotated, norm, mean)
    return _report_error("rotated", A, A_approx)
# Rank kept in both truncated SVDs so the two errors are compared at equal k.
RANK_K = 1000

print('total A', np.linalg.norm(A, 'fro'))
# Baseline: rank-k SVD of A on its own.
approx_error_calc_original(A, U, S, VT, k=RANK_K)
# Procrustes route: rank-k SVD of the normalized residual, added to the aligned
# second matrix (mtx2_estimate) and then de-normalized with norm_1 / mean_1.
# NOTE(review): this route additionally relies on mtx2_estimate, which is built
# from the other layer's full weights plus R_, so the comparison is not at equal
# storage — confirm that is the intended comparison.
approx_error_calc_rotated(A, mtx2_estimate, Ud, Sd, VTd, k=RANK_K)


total A 73.3127109662652
Reconstruction error original: 30.734716702842555
Reconstruction error rotated: 19.235566208154548


np.float64(19.235566208154548)