In [46]:
import copy
import numpy as np
import torch
import tensorly as tl
from tensorly.decomposition import CP
import tntorch as tn

from safetensors.torch import load_file as safe_load_file, save_file as safe_save_file

adapter_path = "./dpo_intel/meta-llama/Llama-2-7B-hf_lora_r_8_alpha_16"

# Load the LoRA adapter (tensor name -> tensor).
state_dict = safe_load_file(f"{adapter_path}/adapter_model.safetensors")

# Copy that will receive the CP-decomposed A and B; `state_dict` stays the reference.
new_state_dict = copy.deepcopy(state_dict)

# Parameters (Llama-2-7B attention geometry)
d_in = 4096     # Input dimension (hidden size)
d_out = 4096    # Output dimension of q_proj / v_proj
num_layers = 32
num_heads = 32
d_out_per_head = d_out // num_heads
assert d_out_per_head * num_heads == d_out, "d_out must be divisible by num_heads"
components = ['q_proj', 'v_proj']  # for illustration
num_components = len(components)

# Dense LoRA updates, one 3-D slice per (layer, component).
# shape: [d_in, num_layers, num_components, d_out_per_head, num_heads]
# NOTE: in float32 this buffer is ~4.3 GB (4096*32*2*128*32 elements * 4 bytes).
dWs = torch.zeros(d_in, num_layers, num_components, d_out_per_head, num_heads)

# Collect the original LoRA A and B, form the dense update, and split it per head.
for layer_idx in range(num_layers):
    for comp_idx, comp_name in enumerate(components):
        A_key = f"base_model.model.model.layers.{layer_idx}.self_attn.{comp_name}.lora_A.weight"
        B_key = f"base_model.model.model.layers.{layer_idx}.self_attn.{comp_name}.lora_B.weight"

        # Upcast so the product is computed in float32 even if the adapter is stored in fp16/bf16.
        A_orig = state_dict[A_key].float()  # shape: (rank, d_in)
        B_orig = state_dict[B_key].float()  # shape: (d_out, rank)
        assert A_orig.shape[1] == d_in and B_orig.shape[0] == d_out, (
            f"unexpected LoRA shapes for {comp_name} in layer {layer_idx}: "
            f"A={tuple(A_orig.shape)}, B={tuple(B_orig.shape)}"
        )

        # LoRA update is delta_W = B @ A with shape (d_out, d_in); transpose to (d_in, d_out).
        dW_2D = (B_orig @ A_orig).T

        # The output axis is head-major: feature h * d_out_per_head + j belongs to head h
        # (HF Llama splits q/v outputs with .view(..., num_heads, head_dim)). Split it as
        # (num_heads, d_out_per_head), then swap, so that
        #   dW_3D[i, j, h] == dW_2D[i, h * d_out_per_head + j].
        # A direct reshape to (d_in, d_out_per_head, num_heads) would interleave heads
        # (feature j * num_heads + h) and not match a head-wise stacked lora_B.
        dW_3D = dW_2D.reshape(d_in, num_heads, d_out_per_head).permute(0, 2, 1)

        # Place this slice into dWs
        dWs[:, layer_idx, comp_idx, :, :] = dW_3D



In [None]:
# Decomposition parameters
rank = 8  # CP rank; equal to the LoRA rank so the new A/B keep the original shapes

rel_errors = []
r2_scores = []

# One estimator reused for every slice; fit_transform re-fits from scratch on each call.
cp = CP(rank=rank, verbose=False)

##############################################################################
# Loop over each layer and each component, perform CP decomposition,
# and update new_state_dict with the new A and B.
##############################################################################
for layer_idx in range(num_layers):
    for comp_idx, comp_name in enumerate(components):
        # Slice for this (layer, component): (d_in, d_out_per_head, num_heads)
        data = dWs[:, layer_idx, comp_idx, :, :]
        # With TensorLy's default NumPy backend this is a NumPy array; the code below
        # relies on that (np.* ops on the factors, torch.from_numpy on the results).
        tensor = tl.tensor(data)

        # CP decomposition on the 3D slice -> (weights, [factor1, factor2, factor3])
        factors = cp.fit_transform(tensor)
        weights, (A_factor, B_factor, C_factor) = factors
        # A_factor: (d_in, rank), B_factor: (d_out_per_head, rank), C_factor: (num_heads, rank)

        # lora_A is (rank, d_in). `.T` may be a non-contiguous view, and safetensors
        # refuses to save non-contiguous tensors, so force a C-contiguous copy.
        A_cp = np.ascontiguousarray(A_factor.T)

        # lora_B is (d_out, rank) with head h occupying rows
        # [h * d_out_per_head, (h + 1) * d_out_per_head):
        #   B_cp[h * d_out_per_head + j, r] = weights[r] * C_factor[h, r] * B_factor[j, r]
        # The CP weights are folded in here. They are all ones with the default
        # normalize_factors=False, but keeping them makes B_cp @ A_cp match the CP
        # reconstruction regardless of that setting.
        B_cp = np.ascontiguousarray(
            (C_factor[:, None, :] * B_factor[None, :, :] * weights)
            .reshape(num_heads * d_out_per_head, rank)
        )

        # Reconstruction quality of the 3D slice
        reconstructed = torch.from_numpy(tl.cp_to_tensor(factors))
        data_torch = data.to(reconstructed.dtype)

        # Store plain Python floats so the summary step aggregates them without torch objects.
        rel_errors.append(float(tn.relative_error(data_torch, reconstructed)))
        r2_scores.append(float(tn.r_squared(data_torch, reconstructed)))

        # The key names match the original LoRA structure:
        A_key = f"base_model.model.model.layers.{layer_idx}.self_attn.{comp_name}.lora_A.weight"
        B_key = f"base_model.model.model.layers.{layer_idx}.self_attn.{comp_name}.lora_B.weight"

        # Store as contiguous torch.float32 tensors
        new_state_dict[A_key] = torch.from_numpy(A_cp).float()
        new_state_dict[B_key] = torch.from_numpy(B_cp).float()
        # No early exit: every component of every layer must be replaced
        # (a leftover `break` here used to skip everything after q_proj).


In [None]:
##############################################################################
# Save new state_dict to disk (safetensors)
##############################################################################
output_safetensors_path = f"{adapter_path}/adapter_model_cp.safetensors"
safe_save_file(new_state_dict, output_safetensors_path)
print(f"New CP-decomposed model saved to: {output_safetensors_path}")

##############################################################################
# Dump the error metrics to a file (mean, median, max, std; plus min for R^2)
##############################################################################
# float() accepts Python floats, NumPy scalars and 0-d torch tensors alike, so this
# works whatever the metric functions returned and stays valid if the cell is re-run.
rel_errors = np.asarray([float(v) for v in rel_errors], dtype=np.float64)
r2_scores = np.asarray([float(v) for v in r2_scores], dtype=np.float64)
if rel_errors.size == 0 or r2_scores.size == 0:
    raise ValueError("No decomposition metrics were collected; run the CP decomposition cell first.")

# For relative error, larger is worse, so max is the worst case.
rel_err_mean = rel_errors.mean()
rel_err_median = np.median(rel_errors)
rel_err_max = rel_errors.max()
rel_err_std = rel_errors.std()

# For R^2, smaller is worse, so min is the worst case; max is kept for the existing report format.
r2_mean = r2_scores.mean()
r2_median = np.median(r2_scores)
r2_min = r2_scores.min()
r2_max = r2_scores.max()
r2_std = r2_scores.std()

stats_filename = "cp_decomp_stats.txt"
with open(stats_filename, "w") as f:
    f.write("Relative error stats:\n")
    f.write(f"  mean   = {rel_err_mean:.6f}\n")
    f.write(f"  median = {rel_err_median:.6f}\n")
    f.write(f"  max    = {rel_err_max:.6f}\n")
    f.write(f"  std    = {rel_err_std:.6f}\n\n")

    f.write("R^2 score stats:\n")
    f.write(f"  mean   = {r2_mean:.6f}\n")
    f.write(f"  median = {r2_median:.6f}\n")
    f.write(f"  min    = {r2_min:.6f}\n")
    f.write(f"  max    = {r2_max:.6f}\n")
    f.write(f"  std    = {r2_std:.6f}\n")

print(f"Decomposition stats saved to: {stats_filename}")